Commit 76c8fbb6 authored by Johannes Bechberger's avatar Johannes Bechberger

Fix ast mode tests

parent 4c8ebdac
...@@ -72,8 +72,7 @@ Test types for the ast mode ...@@ -72,8 +72,7 @@ Test types for the ast mode
<tr><th>File ending(s) of test cases</th><th>Expected behaviour to complete a test of this type</th></tr> <tr><th>File ending(s) of test cases</th><th>Expected behaviour to complete a test of this type</th></tr>
<tr> <tr>
<td><code>.valid.mj</code> <code>.mj</code> <code>.valid.java</code> <code>.java</code></td> <td><code>.valid.mj</code> <code>.mj</code> <code>.valid.java</code> <code>.java</code></td>
<td>Pretty printing the source file should result in the same output as pretty printing the already pretty printed file. <td>Pretty printing the source file should result in the same output as pretty printing the already pretty printed file.</td>
All lines in the lexer output for the source file should be present in the lexer output of the pretty printed file.</td>
</tr> </tr>
</table> </table>
......
import logging import logging
import os import os
import random
import shutil import shutil
import tempfile import tempfile
from datetime import datetime, time from datetime import datetime
import time
from mjtest.util.shell import execute from mjtest.util.shell import execute
from mjtest.util.utils import get_mjtest_basedir from mjtest.util.utils import get_mjtest_basedir
from typing import Tuple, List from typing import Tuple, List
...@@ -83,9 +85,12 @@ class Environment: ...@@ -83,9 +85,12 @@ class Environment:
self.output_incorrect_reports = not output_no_incorrect_reports self.output_incorrect_reports = not output_no_incorrect_reports
self.produce_all_reports = produce_all_reports self.produce_all_reports = produce_all_reports
self.ci_testing = ci_testing self.ci_testing = ci_testing
self._tmp_file_ctr = 0
def create_tmpfile(self) -> str: def create_tmpfile(self) -> str:
return os.path.join(self.tmp_dir, str(os.times())) self._tmp_file_ctr += 1
return os.path.join(self.tmp_dir, str(round(time.time() * 100000))
+ str(random.randrange(0, 10000, 1)) + str(self._tmp_file_ctr))
def create_tmpdir(self) -> str: def create_tmpdir(self) -> str:
dir = self.create_tmpfile() dir = self.create_tmpfile()
......
...@@ -2,7 +2,7 @@ import difflib ...@@ -2,7 +2,7 @@ import difflib
import os import os
import shutil, logging import shutil, logging
from typing import Tuple, Dict from typing import Tuple, Dict
from mjtest.environment import Environment, TestMode from mjtest.environment import Environment, TestMode, TEST_MODES
from mjtest.test.syntax_tests import BasicSyntaxTest from mjtest.test.syntax_tests import BasicSyntaxTest
from mjtest.test.tests import TestCase, BasicDiffTestResult, BasicTestResult from mjtest.test.tests import TestCase, BasicDiffTestResult, BasicTestResult
from os import path from os import path
...@@ -17,6 +17,8 @@ class ASTPrettyPrintTest(BasicSyntaxTest): ...@@ -17,6 +17,8 @@ class ASTPrettyPrintTest(BasicSyntaxTest):
def __init__(self, env: Environment, type: str, file: str): def __init__(self, env: Environment, type: str, file: str):
super().__init__(env, type, file) super().__init__(env, type, file)
if type != TestMode.ast and TEST_MODES.index(TestMode.ast) < TEST_MODES.index(type):
self._should_succeed = True
def run(self) -> BasicTestResult: def run(self) -> BasicTestResult:
tmp_file = self.env.create_tmpfile() tmp_file = self.env.create_tmpfile()
...@@ -28,65 +30,35 @@ class ASTPrettyPrintTest(BasicSyntaxTest): ...@@ -28,65 +30,35 @@ class ASTPrettyPrintTest(BasicSyntaxTest):
tmp_file2 = self.env.create_tmpfile() tmp_file2 = self.env.create_tmpfile()
rtcode, out2, err2 = self._pretty_print(tmp_file, tmp_file2) rtcode, out2, err2 = self._pretty_print(tmp_file, tmp_file2)
if rtcode > 0: if rtcode > 0:
os.remove(tmp_file)
os.remove(tmp_file2) os.remove(tmp_file2)
btr = BasicTestResult(self, rtcode, out2, err2) btr = BasicTestResult(self, rtcode, out2, err2)
btr.add_additional_text("Prior out", out) btr.add_additional_text("Prior out", out)
btr.add_additional_text("Prior err", err) btr.add_additional_text("Prior err", err)
return btr return btr
rtcode_lex, out_lex, err_lex = self.env.run_mj_command(TestMode.lexer, self.file)
rtcode_lex2, out_lex2, err_lex2 = self.env.run_mj_command(TestMode.lexer, tmp_file2) #out_lex = self._line_count(out_lex.decode())
os.remove(tmp_file2) #out_lex2 = self._line_count(out_lex2.decode())
out_lex = self._line_count(out_lex.decode()) #comp = self._comp_dicts(out_lex, out_lex2)
out_lex2 = self._line_count(out_lex2.decode())
comp = self._comp_dicts(out_lex, out_lex2)
incorrect_msg, rtcode = "", 0 incorrect_msg, rtcode = "", 0
if rtcode_lex + rtcode_lex2: if out2 != out:
incorrect_msg, rtcode = "Lexing failed", 1
elif out != out2:
incorrect_msg, rtcode = "Not idempotent", 1 incorrect_msg, rtcode = "Not idempotent", 1
elif comp[0]:
incorrect_msg, rtcode = "Sorted and lexed second pretty print differs from original", 1
btr = BasicTestResult(self, rtcode, incorrect_msg=incorrect_msg) btr = BasicTestResult(self, rtcode, incorrect_msg=incorrect_msg)
btr.add_additional_text("First round output", out) btr.add_additional_text("First round output", out)
btr.add_additional_text("Second round output", out2) btr.add_additional_text("Second round output", out2)
btr.add_additional_text("Diff", self._diff(out, out2)) btr.add_additional_text("Diff", self._diff(out, out2))
btr.add_additional_text("Original file, sorted and lexed", self._dict_to_str(out_lex)) os.remove(tmp_file)
btr.add_additional_text("Second round output, sorted and lexed", self._dict_to_str(out_lex2)) os.remove(tmp_file2)
btr.add_additional_text("Diff", comp[1])
return btr return btr
def _diff(self, first: str, second: str) -> str: def _diff(self, first: str, second: str) -> str:
return "".join(difflib.Differ().compare(first.splitlines(True), second.splitlines(True))) return "".join(difflib.Differ().compare(first.splitlines(True), second.splitlines(True)))
def _sort_lexed(self, lexed: str) -> str:
#return "".join(difflib.Differ().compare(self.expected_output.splitlines(True), self.output.splitlines(True)))
return "".join(sorted(lexed.splitlines(True)))
def _line_count(self, lexed: str) -> Dict[str, int]:
d = {}
for l in lexed.splitlines():
if l not in d:
d[l] = 1
else:
d[l] += 1
return d
def _dict_to_str(self, d: Dict[str, int]) -> str:
return "\n".join("{} {}".format(k, d[k]) for k in sorted(d.keys()))
def _comp_dicts(self, d_orig: Dict[str, int], d_second: Dict[str, int]) -> (bool, str):
strs = []
for k in sorted(d_orig.keys()):
c_orig = d_orig[k]
c_second = d_second[k] if k in d_second else 0
if c_orig > c_second:
strs.append("{} [original file contained {}, pretty printed only {}]".format(c_orig, c_second))
return len(strs) == 0, "\n".join(strs)
def _pretty_print(self, input_file: str, output_file: str) -> Tuple[int, str, str]: def _pretty_print(self, input_file: str, output_file: str) -> Tuple[int, str, str]:
out, err, rtcode = self.env.run_mj_command(TestMode.ast, input_file) out, err, rtcode = self.env.run_mj_command(TestMode.ast, input_file)
with open(output_file, "wb") as f: with open(output_file, "wb") as f:
f.write(out) f.write(out)
f.flush()
return rtcode, out.decode(), err.decode() return rtcode, out.decode(), err.decode()
TestCase.TEST_CASE_CLASSES["ast"].append(ASTPrettyPrintTest) TestCase.TEST_CASE_CLASSES["ast"].append(ASTPrettyPrintTest)
...@@ -15,6 +15,7 @@ class BasicSyntaxTest(TestCase): ...@@ -15,6 +15,7 @@ class BasicSyntaxTest(TestCase):
self._should_succeed = True self._should_succeed = True
else: else:
self._should_succeed = not file.endswith(".invalid.mj") and not file.endswith(".invalid.java") self._should_succeed = not file.endswith(".invalid.mj") and not file.endswith(".invalid.java")
def should_succeed(self) -> bool: def should_succeed(self) -> bool:
return self._should_succeed return self._should_succeed
......
...@@ -189,6 +189,10 @@ class TestSuite: ...@@ -189,6 +189,10 @@ class TestSuite:
for mode in self.correct_test_cases.keys(): for mode in self.correct_test_cases.keys():
log_file = self._log_file_for_type(mode) log_file = self._log_file_for_type(mode)
try: try:
try:
os.mkdir(os.path.dirname(log_file))
except IOError:
pass
with open(log_file, "w") as f: with open(log_file, "w") as f:
f.write("\n".join(self.correct_test_cases[mode])) f.write("\n".join(self.correct_test_cases[mode]))
except IOError as e: except IOError as e:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment