from collections import namedtuple import shutil from typing import Optional, List, Tuple, T, Union, Dict import collections from mjtest.environment import Environment, TestMode, TEST_MODES from os.path import join, exists, basename import logging import os import multiprocessing from mjtest.util.parallelism import available_cpu_count from mjtest.util.utils import cprint, colored from pprint import pprint import shutil import difflib _LOG = logging.getLogger("tests") RunResult = namedtuple("RunResult", ['count', 'failed']) class TestSuite: """ The whole set of tests. """ def __init__(self, env: Environment): self.env = env self.test_cases = {} # type: Dict[str, List[TestCase]] self.correct_test_cases = collections.defaultdict(set) # type: Dict[str, Set[str]] self._load_test_cases() def _load_test_cases(self): types = TEST_MODES[TEST_MODES.index(self.env.mode):] for type in types: self._load_test_case_type(type) def _load_test_case_type(self, type: str): dir = join(self.env.test_dir, type) if exists(dir): self._load_test_case_dir(type, dir) else: _LOG.info("Test folder {} doesn't exist".format(dir)) def _load_test_case_dir(self, mode: str, dir: str): self.test_cases[mode] = [] correct_test_cases = set() log_file = self._log_file_for_type(mode) if exists(log_file): with open(log_file) as f: correct_test_cases = set() for t in f.readlines(): t = t.strip() if len(t) > 0: self.correct_test_cases[mode].add(t) correct_test_cases.add(t) for file in sorted(os.listdir(dir)): if not TestCase.has_valid_file_ending(mode, file): _LOG.debug("Skip file " + file) elif self.env.only_incorrect_tests and file in correct_test_cases: _LOG.info("Skip file {} as its test case was executed correctly the last run") else: test_case = TestCase.create_from_file(self.env, self.env.mode, join(dir, file)) if not test_case.can_run(): _LOG.debug("Skip test case '{}' because it isn't suited".format(test_case.name())) else: self.test_cases[mode].append(test_case) if len(self.test_cases[mode]) == 0: del self.test_cases[mode] def _log_file_for_type(self, type: str): return join(self.env.test_dir, type, ".mjtest_correct_testcases") def _add_correct_test_case(self, test_case: 'TestCase'): self.correct_test_cases[test_case.type].add(basename(test_case.file)) def run(self) -> RunResult: ret = RunResult(0, 0) try: for mode in self.test_cases.keys(): single_ret = RunResult(0, 0) if self.env.parallel: single_ret = self._run_parallel(mode, available_cpu_count()) else: single_ret = self._run_sequential(mode) ret = RunResult(ret.count + single_ret.count, ret.failed + single_ret.failed) except BaseException: logging.exception("") finally: print("-" * 40) if ret.failed > 0: # some tests failed print(colored("Ran {} tests, of which ".format(ret.count), "red") + colored("{} failed.".format(ret.failed), "red", attrs=["bold"])) else: cprint("All {} run tests succeeded".format(ret.count), "green") if self.env.produce_reports and (self.env.produce_all_reports or ret.failed > 0): report_dir = self.env.report_dir + "." + ("successful" if ret.failed == 0 else "failed") os.rename(self.env.report_dir, report_dir) print("A full report for each test can be found at {}".format( os.path.relpath(report_dir))) return ret def _run_sequential(self, mode: str) -> RunResult: failed = 0 count = 0 for test_case in self.test_cases[mode]: ret = self._run_test_case(test_case) if ret is False or not ret.is_correct(): failed += 1 else: self._add_correct_test_case(test_case) count += 1 return RunResult(count, failed) def _func(self, test_case: 'TestCase'): ret = self._run_test_case(test_case) if ret is not False and ret.is_correct(): return 0, test_case return 1, test_case def _run_parallel(self, mode: str, parallel_jobs: int) -> RunResult: pool = multiprocessing.Pool(parallel_jobs) rets = pool.map(self._func, self.test_cases[mode]) result = RunResult(len(rets), sum(map(lambda x: x[0], rets))) for (suc, test_case) in rets: if suc == 0: self._add_correct_test_case(test_case) return result def _run_test_case(self, test_case: 'TestCase') -> Optional['TestResult']: try: ret = test_case.run() color = "green" if ret.is_correct() else "red" print(colored("[{result:7s}] {tc:40s}".format( result="SUCCESS" if ret.is_correct() else "FAIL", tc=test_case.name()), color, attrs=["bold"]) + colored("" if ret.is_correct() else ret.short_message(), color)) try: if self.env.produce_reports and (self.env.produce_all_reports or not ret.is_correct()): if not exists(self.env.report_dir): os.mkdir(self.env.report_dir) rep_dir = join(self.env.report_dir, test_case.type) if not exists(rep_dir): try: os.mkdir(rep_dir) except IOError: pass suffix = ".correct" if ret.is_correct() else ".incorrect" ret.store_at(join(rep_dir, test_case.short_name() + suffix)) if self.env.output_incorrect_reports and not ret.is_correct(): print(colored("Report for failing test case {}".format(test_case.short_name()), "red", attrs=["bold"])) print(colored(ret.long_message(), "red")) return ret except IOError: _LOG.exception("Caught i/o error while trying to store the report for '{}'" .format(test_case.name())) return False except KeyboardInterrupt: raise except BaseException: _LOG.exception("At test case '{}'".format(test_case.short_name())) return False def store(self): for mode in self.correct_test_cases.keys(): log_file = self._log_file_for_type(mode) try: with open(log_file, "w") as f: f.write("\n".join(self.correct_test_cases[mode])) except IOError as e: _LOG.exception("Caught i/o error while storing {}".format(log_file)) class TestCase: """ A single test case. """ TEST_CASE_CLASSES = dict((k, []) for k in TEST_MODES) FILE_ENDINGS = [] def __init__(self, env: Environment, type: str, file: str): self.env = env self.type = type self.file = file def should_succeed(self) -> bool: raise NotImplementedError() def can_run(self, mode: str = "") -> bool: mode = mode or self.env.mode return self.type == mode or \ (self.type in TEST_MODES[TEST_MODES.index(self.env.mode):] and self.should_succeed()) def run(self) -> 'TestResult': raise NotImplementedError() @classmethod def create_from_file(cls, env: Environment, mode: str, file: str) -> Optional['TestCase']: if cls.has_valid_file_ending(mode, file): return cls._test_case_class_for_file(mode, file)(env, mode, file) return None def name(self): return "{}:{}".format(self.type, self.short_name()) def short_name(self) -> str: raise NotImplementedError() @classmethod def _test_case_class_for_file(cls, type: str, file: str): for t in cls.TEST_CASE_CLASSES[type]: if any(file.endswith(e) for e in t.FILE_ENDINGS): return t return False @classmethod def has_valid_file_ending(cls, type: str, file: str): return cls._test_case_class_for_file(type, file) != False class TestResult: def __init__(self, test_case: TestCase, error_code: int): self.test_case = test_case self.error_code = error_code def is_correct(self) -> bool: return self.succeeded() == self.test_case.should_succeed() def succeeded(self) -> bool: return self.error_code == 0 def store_at(self, file: str): with open(file, "w") as f: print(self.long_message(), file=f) def short_message(self) -> str: raise NotImplementedError() def long_message(self) -> str: raise NotImplementedError() class BasicTestResult(TestResult): def __init__(self, test_case: TestCase, error_code: int, output: str, error_output: str): super().__init__(test_case, error_code) self._contains_error_str = "error" in error_output self.error_output = error_output self.output = output self.other_texts = [] # type: List[Tuple[str, str, bool]] def is_correct(self): if self.succeeded(): return super().is_correct() else: return super().is_correct() and self._contains_error_str def short_message(self) -> str: if self.is_correct(): return "correct" else: if not self.succeeded() and not self.test_case.should_succeed() and not self._contains_error_str: return "the error output doesn't contain the word \"error\"" return "incorrect return code" def long_message(self) -> str: file_content = [] with open(self.test_case.file, "r") as f: file_content = [line.rstrip() for line in f] others = [] for title, content, long_text in self.other_texts: if long_text: others.append(""" {}: {} """.format(title, self._ident(content))) else: others.append("""{}: {}\n""".format(title, content)) return """{} Source file: {} Output: {} Error output: {} Return code: {} {} """.format(self.short_message().capitalize(), self._ident(file_content), self._ident(self.output), self._ident(self.error_output), self.error_code, "\n".join(others)) def add_additional_text(self, title: str, content: str): self.other_texts.append((title, content, True)) def add_additional_text_line(self, title: str, content: str): self.other_texts.append((title, content, False)) def _ident(self, text: Union[str,List[str]]) -> str: arr = text if isinstance(text, list) else text.split("\n") if len(arr) == 0 or text == "": return "" arr = ["[{:04d}] {:s}".format(i + 1, l) for (i, l) in enumerate(arr)] return "\n".join(arr) class BasicDiffTestResult(BasicTestResult): def __init__(self, test_case: TestCase, error_code: int, output: str, error_output: str, expected_output: str): super().__init__(test_case, error_code, output, error_output) self.expected_output = expected_output self._is_output_correct = self.expected_output.strip() == self.output if self.is_correct(): self.add_additional_text("Expected and actual output", self.output) elif self.succeeded() and self.test_case.should_succeed(): self.add_additional_text("Diff[expected output, actual output]", self._output_diff()) self.add_additional_text("Expected output", self.expected_output) self.add_additional_text("Actual output", self.output) def is_correct(self): if self.succeeded(): return super().is_correct() and self.is_output_correct() else: return super().is_correct() and self._contains_error_str def _output_diff(self) -> str: return difflib.Differ().compare(self.expected_output, self.output) def is_output_correct(self) -> str: return self._is_output_correct def short_message(self) -> str: if self.is_correct(): return "correct" else: if not self.succeeded() and not self.test_case.should_succeed() and not self._contains_error_str: return "the error output doesn't contain the word \"error\"" if self.succeeded() and self.test_case.should_succeed(): return "the actual output differs from the expected" return "incorrect return code" import mjtest.test.syntax_tests import mjtest.test.ast_tests