# bench.py — benchmark test integration for mjtest.
import hashlib
import logging
import os
import shutil
import signal
from os import path
from typing import List, Tuple

import math

import subprocess

import time

from mjtest.environment import TestMode, Environment
from mjtest.test.syntax_tests import BasicSyntaxTest
from mjtest.test.tests import TestCase, BasicDiffTestResult, BasicTestResult, ExtensibleTestResult
from mjtest.util.shell import SigKill
from mjtest.util.utils import get_main_class_name, InsertionTimeOrderedDict

# Module-level logger for benchmark test diagnostics.
_LOG = logging.getLogger("bench_tests")

class _RunResult:

    def __init__(self, runs: List[float], is_correct: bool):
        self.runs = runs
        self.is_correct = is_correct

    def mean(self) -> float:
        return sum(self.runs) / self.number()

    def stddev(self) -> float:
        m = self.mean()
        return math.sqrt(sum(map(lambda x: (x - m) ** 2, self.runs)) / self.number())

    def min(self) -> float:
        return min(self.runs)

    def number(self) -> int:
        return len(self.runs)


class BenchExecTest(BasicSyntaxTest):
    """
    Simple benchmark test. The new compiler mode shouldn't be slower than the old ones (or javac)
    """

    FILE_ENDINGS = [".java", ".mj"]
    INVALID_FILE_ENDINGS = [".inf.java", ".inf.mj"]
    MODE = TestMode.compile_firm

    def __init__(self, env: Environment, type: str, file: str, preprocessed_file: str):
        super().__init__(env, type, file, preprocessed_file)
        self._should_succeed = True

    def _bench_command(self, cmd: str, *args: str) -> _RunResult:
        """Run ``cmd`` with ``args`` ``env.bench_runs`` times, timing each run.

        Returns a failed (empty) result as soon as one invocation exits non-zero.
        """
        runs = []  # type: List[float]
        for _ in range(self.env.bench_runs):
            start = time.time()
            try:
                subprocess.check_call([cmd] + list(args), stdout=subprocess.DEVNULL)
            except subprocess.CalledProcessError:
                return _RunResult([], False)
            runs.append(time.time() - start)
        return _RunResult(runs, True)

    def run(self) -> BasicDiffTestResult:
        """Build the test file with each configured compiler flag (or javac),
        benchmark the resulting binaries and compare their timings.

        The first configuration is expected to be slower than the second by
        more than one relative standard deviation for the test to succeed.
        """
        is_big_testcase = "big" in self.file
        timeout = self.env.big_timeout if is_big_testcase else self.env.timeout
        base_filename = path.basename(self.file).split(".")[0]
        tmp_dir = self.env.create_pid_local_tmpdir()
        shutil.copy(self.preprocessed_file, path.join(tmp_dir, base_filename + ".java"))
        cwd = os.getcwd()
        os.chdir(tmp_dir)
        try:
            return self._benchmark_in_cwd(base_filename, timeout)
        finally:
            # Restore the working directory on every exit path; the original
            # code leaked the chdir when e.g. _bench_command raised.
            os.chdir(cwd)

    def _benchmark_in_cwd(self, base_filename: str, timeout: int) -> BasicDiffTestResult:
        """Benchmark body; expects the current working directory to be the tmp dir."""
        test_result = ExtensibleTestResult(self)

        results = []  # type: List[_RunResult]

        for compiler_flag in self.env.bench_compiler_flags:
            if compiler_flag == "javac":
                _, err, javac_rtcode = \
                    self.env.run_command("javac", base_filename + ".java", timeout=timeout)
                if javac_rtcode != 0:
                    _LOG.error("File \"{}\" isn't valid Java".format(self.preprocessed_file))
                    test_result.incorrect_msg = "invalid java code, but output file missing"
                    test_result.set_error_code(javac_rtcode)
                    test_result.add_long_text("Javac error message", err.decode())
                    test_result.add_file("Source file", self.preprocessed_file)
                    return test_result
                main_class = get_main_class_name(base_filename + ".java")
                if not main_class:
                    _LOG.debug("Can't find a main class, using the file name instead")
                    main_class = base_filename
                results.append(self._bench_command("java", main_class))
            else:
                # Flags arrive escaped ("\-") so they are not mistaken for options.
                compiler_flag = compiler_flag.replace("\\-", "-")
                try:
                    out, err, rtcode = self.env.run_command(self.env.mj_run_cmd, compiler_flag,
                                                            base_filename + ".java", timeout=timeout)
                except SigKill as sig:
                    test_result.incorrect_msg = "file can't be compiled: " + sig.name
                    test_result.set_error_code(sig.retcode)
                    test_result.add_file("Source file", self.preprocessed_file)
                    return test_result
                if rtcode != 0:
                    test_result.incorrect_msg = "file can't be compiled"
                    test_result.set_error_code(rtcode)
                    test_result.add_long_text("Error output", err.decode())
                    test_result.add_long_text("Output", out.decode())
                    test_result.add_file("Source file", self.preprocessed_file)
                    return test_result
                results.append(self._bench_command("./a.out"))

        # Exactly two configurations are compared against each other.
        assert len(results) == 2
        if not results[0].is_correct or not results[1].is_correct:
            incorrect_flags = [self.env.bench_compiler_flags[i]
                               for (i, res) in enumerate(results) if not res.is_correct]
            test_result.incorrect_msg = "Running with {} failed".format(", ".join(incorrect_flags))
            test_result.has_succeeded = False
            return test_result

        msg_parts = []
        # Use the larger relative spread of the two result sets as the noise measure.
        stddev = max(results[0].stddev() / results[0].mean(), results[1].stddev() / results[1].mean())
        rel_min = results[0].min() / results[1].min()
        msg_parts.append("min(0/1) = {:.0%}".format(rel_min))
        rel_mean = results[0].mean() / results[1].mean()
        msg_parts.append("mean(0/1) = {:.0%}".format(rel_mean))
        msg_parts.append("-1 / std = {:.0%}".format((rel_mean - 1) / stddev))
        for (i, res) in enumerate(results):
            test_result.add_short_text("min({})".format(i), res.min())
            test_result.add_short_text("mean({})".format(i), res.mean())
            test_result.add_short_text("stddev({})".format(i), res.stddev())
        # Require the speedup of the second configuration to exceed the noise
        # level (one relative standard deviation); otherwise report a failure.
        if (rel_mean - 1) / stddev <= 1:
            msg_parts.append("first faster")
            test_result.incorrect_msg = ", ".join(msg_parts)
            test_result.has_succeeded = False
            return test_result
        test_result.correct_msg = ", ".join(msg_parts)
        test_result.has_succeeded = True
        return test_result



# Register this test class so the "bench" test mode discovers and runs it.
TestCase.TEST_CASE_CLASSES[TestMode.bench].append(BenchExecTest)