# utils.py 3.64 KB
# Newer Older
# Johannes Bechberger's avatar
# Johannes Bechberger committed
1
2
3
import logging
from os import path
import sys
4
from typing import Tuple, Optional, Any, List, Callable
5
import re
6

7
# Import the optional ``termcolor`` package unconditionally (when available).
# It used to be imported only for tty output, which made
# ``force_colored_output()`` crash with a NameError when stdout was not a tty
# (e.g. piped output). ``termcolor`` is ``None`` if the package is missing.
try:
    import termcolor
except ImportError:
    termcolor = None

# Colors are enabled by default only for an interactive terminal, and only if
# termcolor could be imported. ``sys.stdout`` can be ``None`` (e.g. pythonw).
COLOR_OUTPUT_IF_POSSIBLE = (
    termcolor is not None
    and sys.stdout is not None
    and sys.stdout.isatty()
)
15
16


17
def force_colored_output():
    """
    Enable colored output even if stdout is not an interactive terminal.

    Colors still require the optional ``termcolor`` package. If it cannot be
    imported, a warning is logged and output stays uncolored, instead of
    ``colored``/``cprint`` failing later on an undefined ``termcolor``.
    """
    global COLOR_OUTPUT_IF_POSSIBLE, termcolor
    try:
        # The module-level import may have been skipped (stdout not a tty), so
        # bind ``termcolor`` at module level before turning colors on.
        import termcolor
    except ImportError:
        logging.warning("termcolor is not installed; colored output stays disabled")
        return
    COLOR_OUTPUT_IF_POSSIBLE = True
20

# Johannes Bechberger's avatar
# Johannes Bechberger committed
21
22

def get_mjtest_basedir() -> str:
    """
    Return the directory three levels above this module's file.

    The module path is resolved with ``realpath`` first, so symlinks in the
    path to this file do not influence the result. Presumably this is the
    mjtest checkout root — confirm against the repository layout.
    """
    return path.dirname(path.dirname(path.dirname(path.realpath(__file__))))
# Johannes Bechberger's avatar
# Johannes Bechberger committed
24
25


26
27
28
29
def colored(text: str, *args, **kwargs) -> str:
    """
    Wrapper around ``termcolor.colored`` that degrades gracefully.

    When colored output is enabled (``COLOR_OUTPUT_IF_POSSIBLE``), all
    arguments are forwarded to ``termcolor.colored``. Otherwise ``text`` is
    returned unchanged and the color arguments are ignored.

    :param text: text to color
    :return: the colored text, or ``text`` itself if colors are disabled
    """
    # Only reading the module-level flag, so no ``global`` declaration needed.
    if COLOR_OUTPUT_IF_POSSIBLE:
        return termcolor.colored(text, *args, **kwargs)
    return text


def cprint(text: str, *args, **kwargs):
    """
    Wrapper around ``termcolor.cprint`` that degrades gracefully.

    When colored output is enabled (``COLOR_OUTPUT_IF_POSSIBLE``), all
    arguments are forwarded to ``termcolor.cprint``. Otherwise ``text`` is
    printed plainly: color arguments are dropped, but the ``print`` keyword
    arguments (``sep``, ``end``, ``file``, ``flush``) are still honored.

    :param text: text to print
    """
    if COLOR_OUTPUT_IF_POSSIBLE:
        termcolor.cprint(text, *args, **kwargs)
    else:
        # termcolor.cprint passes its extra keyword arguments on to print();
        # do the same here so e.g. end="" or file=... behave identically with
        # and without colors.
        print_kwargs = {key: value for key, value in kwargs.items()
                        if key in ("sep", "end", "file", "flush")}
        print(text, **print_kwargs)


def get_main_class_name(file: str) -> Optional[str]:
    """
    Find the class declaring the ``main`` method in a Java-like source file.

    The file is scanned line by line. A line starting with ``class `` or
    ``/*public*/ class `` sets the current class to the first identifier that
    follows the prefix. The first later line mentioning ``public``,
    ``static``, ``void``, ``main`` and ``String[]`` is taken as the main
    method, and the current class at that point is returned.

    :param file: path of the source file (decoded as UTF-8; undecodable bytes
                 are backslash-escaped instead of raising)
    :return: name of the class containing main, or None if no main method was
             found (also None if main appears before any class declaration)
    """
    class_prefixes = ("class ", "/*public*/ class ")
    main_tokens = ("String[]", "main", "void", "static", "public")
    identifier = re.compile(r"[A-Za-z_0-9]+")
    current_class = None
    # Explicit UTF-8 (like ``decode``) so results don't depend on the locale.
    with open(file, "r", encoding="utf-8", errors="backslashreplace") as f:
        for line in f:
            prefix = next((p for p in class_prefixes if line.startswith(p)), None)
            if prefix is not None:
                # Strip only the leading prefix. A global str.replace would
                # also cut "class " out of names like "Subclass ".
                match = identifier.search(line[len(prefix):])
                if match:
                    current_class = match.group(0)
            elif all(token in line for token in main_tokens):
                return current_class
    return None
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115


class InsertionTimeOrderedDict:
    """
    A dictionary whose elements are ordered by their insertion time.

    Overwriting the value of an existing key keeps the key's original
    position; deleting a key and inserting it again moves it to the end.
    """

    def __init__(self):
        self._dict = {}  # key -> value
        self._keys = []  # keys in insertion order; always mirrors _dict's keys

    def __delitem__(self, key):
        """ Remove the entry with the passed key (raises KeyError if absent) """
        # Delete from the dict first so a missing key raises KeyError before
        # the key list is touched.
        del self._dict[key]
        self._keys.remove(key)

    def __getitem__(self, key):
        """ Get the entry with the passed key """
        return self._dict[key]

    def __setitem__(self, key, value):
        """ Set the value of the item with the passed key """
        if key not in self._dict:
            self._keys.append(key)
        self._dict[key] = value

    def __contains__(self, key) -> bool:
        """ Check whether the passed key is present (hash lookup, not a list scan) """
        return key in self._dict

    def __iter__(self):
        """ Iterate over all keys in insertion order """
        return iter(self._keys)

    def values(self) -> List:
        """ Returns all values of this dictionary. They are sorted by their insertion time. """
        return [self._dict[key] for key in self._keys]

    def keys(self) -> List:
        """
        Returns all keys of this dictionary. They are sorted by their insertion time.

        A copy is returned: mutating it cannot corrupt this dictionary, and it
        is safe to delete entries while iterating over it.
        """
        return list(self._keys)

    def __len__(self):
        """ Returns the number of items in this dictionary """
        return len(self._keys)

    @classmethod
    def from_list(cls, items: Optional[list], key_func: Callable[[Any], Any]) -> 'InsertionTimeOrderedDict':
        """
        Creates an ordered dict out of a list of elements.
        :param items: list of elements (None is treated as an empty list)
        :param key_func: function that returns a key for each passed list element
        :return: created ordered dict (an instance of ``cls``) with the elements in
                 the same order as in the passed list; for duplicate keys the last
                 element wins but the key keeps its first position
        """
        ret = cls()
        if items is None:
            return ret
        for item in items:
            ret[key_func(item)] = item
        return ret


def decode(arr: bytes) -> str:
    """
    Turn raw bytes into text, interpreting them as UTF-8.

    Byte sequences that are not valid UTF-8 do not raise; they are kept in
    the result as backslash escape sequences.
    """
    text = arr.decode(encoding="utf-8", errors="backslashreplace")
    return text