utils.py 3.64 KB
Newer Older
Johannes Bechberger's avatar
Johannes Bechberger committed
1 2 3
import logging
from os import path
import sys
4
from typing import Tuple, Optional, Any, List, Callable
5
import re
6

7
COLOR_OUTPUT_IF_POSSIBLE = False

# Colouring is only attempted when stdout is an interactive terminal and the
# optional termcolor package can be imported.
if sys.stdout.isatty():
    try:
        import termcolor
    except ImportError:
        COLOR_OUTPUT_IF_POSSIBLE = False
    else:
        COLOR_OUTPUT_IF_POSSIBLE = True


def force_colored_output():
    """
    Enable coloured output even when stdout is not a terminal.

    At module load time termcolor is only imported when stdout is a TTY. Forcing
    colours from a non-TTY context must therefore import it here. Otherwise
    colored()/cprint() would fail with a NameError on ``termcolor``.
    If termcolor is not installed, colouring stays disabled and a warning is logged.
    """
    global COLOR_OUTPUT_IF_POSSIBLE, termcolor
    try:
        import termcolor  # binds the module-level name because of the global declaration
    except ImportError:
        logging.warning("termcolor is not installed, can't force colored output")
        COLOR_OUTPUT_IF_POSSIBLE = False
        return
    COLOR_OUTPUT_IF_POSSIBLE = True


def get_mjtest_basedir() -> str:
    """
    Return the directory two levels above the one containing this module.

    Symlinks in this module's path are resolved first. Judging by the name,
    this is presumably the mjtest base/repository directory.
    """
    base = path.realpath(__file__)
    # Strip the file name, then two more directory levels.
    for _ in range(3):
        base = path.dirname(base)
    return base


def colored(text: str, *args, **kwargs):
    """
    Colour ``text`` with termcolor.colored when colouring is enabled.

    All extra positional and keyword arguments are passed to termcolor
    unchanged. When colouring is disabled, ``text`` is returned untouched.
    """
    if not COLOR_OUTPUT_IF_POSSIBLE:
        return text
    return termcolor.colored(text, *args, **kwargs)


def cprint(text: str, *args, **kwargs):
    """
    Print ``text``, coloured via termcolor.cprint when colouring is enabled.

    When colouring is on, all arguments are forwarded to termcolor.cprint.
    termcolor.cprint passes its non-colour keyword arguments (e.g. ``file``,
    ``end``, ``flush``) on to ``print``. The plain-text path therefore honours
    those too, and ignores only the colour options: positional ``args`` and the
    ``color``, ``on_color``, ``attrs``, ``no_color`` and ``force_color`` keywords.
    """
    if COLOR_OUTPUT_IF_POSSIBLE:
        termcolor.cprint(text, *args, **kwargs)
        return
    # Drop colour-only options; keep everything print() understands.
    color_only = ("color", "on_color", "attrs", "no_color", "force_color")
    print_kwargs = {key: value for key, value in kwargs.items() if key not in color_only}
    print(text, **print_kwargs)


def get_main_class_name(file: str) -> Optional[str]:
    """
    Return the name of the class that declares the ``main`` method in a
    Java-like source file, or None if no such method is found.

    The scan is line based and heuristic:

    - A line starting with ``class `` or ``/*public*/ class `` opens a class.
      Its first identifier is taken as the class name.
    - The first later line containing ``String[]``, ``main``, ``void``,
      ``static`` and ``public`` is treated as the main method of the most
      recently seen class.

    :param file: path of the source file to scan
    :return: name of that class; None if no main-like line exists, or if it
             appears before any class header
    """
    identifier = re.compile(r"[A-Za-z_0-9]+")
    main_markers = ("String[]", "main", "void", "static", "public")
    current_class = None
    # Decode explicitly as UTF-8 (like decode()) instead of relying on the
    # platform locale. Undecodable bytes become backslash escapes rather than raising.
    with open(file, "r", encoding="utf-8", errors="backslashreplace") as f:
        for line in f:
            if line.startswith(("class ", "/*public*/ class ")):
                match = identifier.search(line.replace("class ", "").replace("/*public*/", ""))
                if match:
                    current_class = match.group(0)
            elif all(marker in line for marker in main_markers):
                return current_class
    return None


class InsertionTimeOrderedDict:
    """
    A dictionary whose entries are kept in insertion order.

    Re-assigning an existing key updates its value but keeps its original
    position. Deleting and re-inserting a key moves it to the end.
    """

    def __init__(self):
        self._dict = {}  # key -> value
        self._keys = []  # keys in insertion order

    def __delitem__(self, key):
        """ Remove the entry with the passed key (raises KeyError if it is absent) """
        del self._dict[key]
        self._keys.remove(key)

    def __getitem__(self, key):
        """ Get the entry with the passed key """
        return self._dict[key]

    def __setitem__(self, key, value):
        """ Set the value of the item with the passed key """
        if key not in self._dict:
            self._keys.append(key)
        self._dict[key] = value

    def __contains__(self, key) -> bool:
        """ Check whether the passed key is present (hash lookup instead of scanning the key list) """
        return key in self._dict

    def __iter__(self):
        """ Iterate over all keys in insertion order """
        return iter(self._keys)

    def values(self) -> List:
        """ Returns all values of this dictionary, sorted by their insertion time. """
        return [self._dict[key] for key in self._keys]

    def keys(self) -> List:
        """
        Returns all keys of this dictionary, sorted by their insertion time.

        A copy is returned so that callers cannot corrupt the internal key order
        (e.g. by appending to it or by deleting entries while iterating it).
        """
        return list(self._keys)

    def __len__(self):
        """ Returns the number of items in this dictionary """
        return len(self._keys)

    @classmethod
    def from_list(cls, items: Optional[list], key_func: Callable[[Any], Any]) -> 'InsertionTimeOrderedDict':
        """
        Creates an ordered dict out of a list of elements.
        :param items: list of elements (None yields an empty dict)
        :param key_func: function that returns a key for each passed list element
        :return: created ordered dict with the elements in the same order as in the passed list;
                 for duplicate keys the last element wins but keeps the first element's position
        """
        ret = cls()
        if items is not None:
            for item in items:
                ret[key_func(item)] = item
        return ret


def decode(arr: bytes) -> str:
    """
    Turn raw bytes into text using UTF-8.

    Byte sequences that are not valid UTF-8 do not raise. They appear in the
    result as backslash escapes (e.g. ``b"\\xff"`` becomes ``"\\xff"``).
    """
    text = arr.decode(encoding="utf-8", errors="backslashreplace")
    return text