ast_tests.py 3.86 KB
Newer Older
Johannes Bechberger's avatar
Johannes Bechberger committed
1 2
import difflib
import os
Johannes Bechberger's avatar
Johannes Bechberger committed
3
import shutil, logging
4
from typing import Tuple, Dict
Johannes Bechberger's avatar
Johannes Bechberger committed
5
from mjtest.environment import Environment, TestMode
Johannes Bechberger's avatar
Johannes Bechberger committed
6
from mjtest.test.syntax_tests import BasicSyntaxTest
7
from mjtest.test.tests import TestCase, BasicDiffTestResult, BasicTestResult
Johannes Bechberger's avatar
Johannes Bechberger committed
8 9 10 11 12
from os import path

# Module-level logger, obtained under the logger name "tests".
_LOG = logging.getLogger("tests")


Johannes Bechberger's avatar
Johannes Bechberger committed
13
class ASTPrettyPrintTest(BasicSyntaxTest):
    """
    Test case that exercises the AST pretty printer (``TestMode.ast``).

    The test file is pretty printed, then the result is pretty printed again.
    The test passes if both outputs are identical and every line of the lexer
    output of the original file occurs at least as often in the lexer output
    of the second pretty print.
    """

    FILE_ENDINGS = [".mj", ".valid.mj"]
    INVALID_FILE_ENDINGS = [".invalid.mj"]

    def __init__(self, env: Environment, type: str, file: str):
        """Forward all arguments unchanged to the base class constructor."""
        super().__init__(env, type, file)

    def run(self) -> BasicTestResult:
        """
        Pretty print the test file twice and compare the results.

        :return: a failing result carrying the command's return code and
                 output if one of the pretty print runs returns a code > 0,
                 otherwise a result whose return code is 1 (with an
                 explanatory message) if lexing failed, the two pretty print
                 outputs differ or lexer lines got lost, and 0 if not
        """
        tmp_file = self.env.create_tmpfile()
        tmp_file2 = None
        try:
            rtcode, out, err = self._pretty_print(self.file, tmp_file)
            if rtcode > 0:
                return BasicTestResult(self, rtcode, out, err)
            tmp_file2 = self.env.create_tmpfile()
            rtcode, out2, err2 = self._pretty_print(tmp_file, tmp_file2)
            if rtcode > 0:
                btr = BasicTestResult(self, rtcode, out2, err2)
                btr.add_additional_text("Prior out", out)
                btr.add_additional_text("Prior err", err)
                return btr
            # Unpacked as (stdout, stderr, return code), the same order that
            # _pretty_print uses for the very same method.
            # NOTE(review): confirm the order against Environment.run_mj_command.
            out_lex, _, rtcode_lex = self.env.run_mj_command(TestMode.lexer, self.file)
            out_lex2, _, rtcode_lex2 = self.env.run_mj_command(TestMode.lexer, tmp_file2)
        finally:
            # Remove both temporary files on every path (early returns and
            # exceptions included), not only the second one.
            for tmp in (tmp_file, tmp_file2):
                if tmp is not None and os.path.exists(tmp):
                    os.remove(tmp)
        counts_orig = self._line_count(out_lex.decode())
        counts_second = self._line_count(out_lex2.decode())
        all_present, missing_report = self._comp_dicts(counts_orig, counts_second)
        incorrect_msg, rtcode = "", 0
        if rtcode_lex != 0 or rtcode_lex2 != 0:
            incorrect_msg, rtcode = "Lexing failed", 1
        elif out != out2:
            incorrect_msg, rtcode = "Not idempotent", 1
        elif not all_present:
            # _comp_dicts returns True if nothing is missing, so the failure
            # case is the negation.
            incorrect_msg, rtcode = "Sorted and lexed second pretty print differs from original", 1
        btr = BasicTestResult(self, rtcode, incorrect_msg=incorrect_msg)
        btr.add_additional_text("First round output", out)
        btr.add_additional_text("Second round output", out2)
        btr.add_additional_text("Diff", self._diff(out, out2))
        btr.add_additional_text("Original file, sorted and lexed", self._dict_to_str(counts_orig))
        btr.add_additional_text("Second round output, sorted and lexed", self._dict_to_str(counts_second))
        # NOTE(review): "Diff" is used as a title twice; verify that
        # add_additional_text keeps both entries.
        btr.add_additional_text("Diff", missing_report)
        return btr

    def _diff(self, first: str, second: str) -> str:
        """Return the ``difflib.Differ`` line delta between the two strings."""
        return "".join(difflib.Differ().compare(first.splitlines(True), second.splitlines(True)))

    def _sort_lexed(self, lexed: str) -> str:
        """Return the lines of the passed string (line ends kept) in sorted order."""
        return "".join(sorted(lexed.splitlines(True)))

    def _line_count(self, lexed: str) -> Dict[str, int]:
        """Map each distinct line of the passed string to its number of occurrences."""
        counts = {}
        for line in lexed.splitlines():
            counts[line] = counts.get(line, 0) + 1
        return counts

    def _dict_to_str(self, d: Dict[str, int]) -> str:
        """Render the dictionary as one "<key> <count>" line per key, sorted by key."""
        return "\n".join("{} {}".format(k, d[k]) for k in sorted(d.keys()))

    def _comp_dicts(self, d_orig: Dict[str, int], d_second: Dict[str, int]) -> Tuple[bool, str]:
        """
        Check that every line counted in d_orig occurs at least as often in d_second.

        Lines that only occur in d_second, or occur more often there, are ignored.

        :param d_orig: line counts of the original file
        :param d_second: line counts of the pretty printed file
        :return: (True if no line is missing, one report line per missing line)
        """
        strs = []
        for k in sorted(d_orig.keys()):
            c_orig = d_orig[k]
            c_second = d_second.get(k, 0)
            if c_orig > c_second:
                # The message has three placeholders: the line itself and both counts.
                strs.append("{} [original file contained {}, pretty printed only {}]".format(k, c_orig, c_second))
        return len(strs) == 0, "\n".join(strs)

    def _pretty_print(self, input_file: str, output_file: str) -> Tuple[int, str, str]:
        """
        Run the ast mode on input_file and write its raw standard output to output_file.

        :return: (return code, decoded standard output, decoded error output)
        """
        out, err, rtcode = self.env.run_mj_command(TestMode.ast, input_file)
        with open(output_file, "wb") as f:
            f.write(out)
        return rtcode, out.decode(), err.decode()

92
# Append this test class to the list stored under the "ast" key of TestCase.TEST_CASE_CLASSES.
TestCase.TEST_CASE_CLASSES["ast"].append(ASTPrettyPrintTest)