import collections
import difflib
import logging
import os
import shutil
from os import path
from typing import Dict, Tuple

from mjtest.environment import Environment, TestMode
from mjtest.test.syntax_tests import BasicSyntaxTest
from mjtest.test.tests import TestCase, BasicDiffTestResult, BasicTestResult

_LOG = logging.getLogger("tests")


class ASTPrettyPrintTest(BasicSyntaxTest):
    """Check that the AST pretty printer is idempotent and does not lose tokens.

    The input file is pretty printed, and that output is pretty printed a
    second time. The test fails if either pretty print run exits with a
    non-zero code, if lexing fails, if the two pretty printed outputs differ,
    or if some lexer output line occurs more often when lexing the original
    file than when lexing the second pretty printed output.
    """

    FILE_ENDINGS = [".mj", ".valid.mj"]
    INVALID_FILE_ENDINGS = [".invalid.mj"]

    def __init__(self, env: Environment, type: str, file: str):
        super().__init__(env, type, file)

    def run(self) -> BasicTestResult:
        """Run both pretty print rounds and the lexer comparison.

        :return: the compiler's output wrapped in a BasicTestResult if a
                 pretty print round exits non-zero; otherwise a result whose
                 incorrect_msg is empty on success or names the first failed
                 check, with both rounds' output and diffs attached.
        """
        tmp_file = self.env.create_tmpfile()
        tmp_file2 = None
        try:
            rtcode, out, err = self._pretty_print(self.file, tmp_file)
            # != 0 rather than > 0: processes killed by a signal report a
            # negative return code and must count as failures too.
            if rtcode != 0:
                return BasicTestResult(self, rtcode, out, err)

            tmp_file2 = self.env.create_tmpfile()
            rtcode, out2, err2 = self._pretty_print(tmp_file, tmp_file2)
            if rtcode != 0:
                btr = BasicTestResult(self, rtcode, out2, err2)
                btr.add_additional_text("Prior out", out)
                btr.add_additional_text("Prior err", err)
                return btr

            rtcode_lex, out_lex = self._lex(self.file)
            rtcode_lex2, out_lex2 = self._lex(tmp_file2)
        finally:
            # Remove the temporary files on every exit path, including the
            # early returns above and any exception raised by the commands.
            self._remove_if_exists(tmp_file)
            if tmp_file2 is not None:
                self._remove_if_exists(tmp_file2)

        counts_orig = self._line_count(out_lex)
        counts_second = self._line_count(out_lex2)
        tokens_preserved, missing_report = self._comp_dicts(counts_orig, counts_second)

        incorrect_msg, rtcode = "", 0
        # Check each return code separately; summing them could cancel out.
        if rtcode_lex != 0 or rtcode_lex2 != 0:
            incorrect_msg, rtcode = "Lexing failed", 1
        elif out != out2:
            incorrect_msg, rtcode = "Not idempotent", 1
        elif not tokens_preserved:
            incorrect_msg, rtcode = "Sorted and lexed second pretty print differs from original", 1

        btr = BasicTestResult(self, rtcode, incorrect_msg=incorrect_msg)
        btr.add_additional_text("First round output", out)
        btr.add_additional_text("Second round output", out2)
        btr.add_additional_text("Diff", self._diff(out, out2))
        btr.add_additional_text("Original file, sorted and lexed", self._dict_to_str(counts_orig))
        btr.add_additional_text("Second round output, sorted and lexed", self._dict_to_str(counts_second))
        # Distinct title so it cannot be confused with the output diff above.
        btr.add_additional_text("Diff of sorted and lexed outputs", missing_report)
        return btr

    @staticmethod
    def _remove_if_exists(file: str):
        """Delete file if it is present; do nothing otherwise."""
        if path.exists(file):
            os.remove(file)

    def _diff(self, first: str, second: str) -> str:
        """Return a line-by-line difflib.Differ comparison of the two texts."""
        return "".join(difflib.Differ().compare(first.splitlines(True), second.splitlines(True)))

    def _sort_lexed(self, lexed: str) -> str:
        """Return the lines of lexed sorted lexicographically, line endings kept."""
        return "".join(sorted(lexed.splitlines(True)))

    def _line_count(self, lexed: str) -> Dict[str, int]:
        """Map each distinct line of lexed to the number of times it occurs."""
        return dict(collections.Counter(lexed.splitlines()))

    def _dict_to_str(self, d: Dict[str, int]) -> str:
        """Render d as "key count" lines, sorted by key."""
        return "\n".join("{} {}".format(k, d[k]) for k in sorted(d))

    def _comp_dicts(self, d_orig: Dict[str, int], d_second: Dict[str, int]) -> Tuple[bool, str]:
        """Report lines that occur less often in d_second than in d_orig.

        :return: (True if no line of d_orig is under-represented in d_second,
                  newline-separated description of each such line)
        """
        strs = []
        for k in sorted(d_orig):
            c_orig = d_orig[k]
            c_second = d_second.get(k, 0)
            if c_orig > c_second:
                strs.append("{} [original file contained {}, pretty printed only {}]".format(k, c_orig, c_second))
        return not strs, "\n".join(strs)

    def _lex(self, input_file: str) -> Tuple[int, str]:
        """Run the lexer on input_file.

        The result is unpacked in the same (out, err, rtcode) order that
        _pretty_print uses for run_mj_command.

        :return: (return code, decoded standard output)
        """
        out, _err, rtcode = self.env.run_mj_command(TestMode.lexer, input_file)
        return rtcode, out.decode()

    def _pretty_print(self, input_file: str, output_file: str) -> Tuple[int, str, str]:
        """Pretty print input_file in AST mode and write the raw stdout to output_file.

        :return: (return code, decoded stdout, decoded stderr)
        """
        out, err, rtcode = self.env.run_mj_command(TestMode.ast, input_file)
        with open(output_file, "wb") as f:
            f.write(out)
        return rtcode, out.decode(), err.decode()


TestCase.TEST_CASE_CLASSES["ast"].append(ASTPrettyPrintTest)