# bench.py
import hashlib
import logging
import math
import os
import shutil
import signal
import subprocess
import time
from os import path
from typing import List, Tuple

from mjtest.environment import TestMode, Environment
from mjtest.test.syntax_tests import BasicSyntaxTest
from mjtest.test.tests import TestCase, BasicDiffTestResult, BasicTestResult, ExtensibleTestResult
from mjtest.util.shell import SigKill
from mjtest.util.utils import get_main_class_name, InsertionTimeOrderedDict, decode

_LOG = logging.getLogger("bench_tests")

class _RunResult:
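    """Wall-clock times collected from repeated runs of one benchmark command."""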

    def __init__(self, runs: List[float], is_correct: bool):
        self.runs = runs
        self.is_correct = is_correct

    def mean(self) -> float:
        return sum(self.runs) / self.number()

    def stddev(self) -> float:
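        # Population standard deviation (divides by N rather than N - 1) of the run times.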
        m = self.mean()
        return math.sqrt(sum(map(lambda x: (x - m) ** 2, self.runs)) / self.number())

    def min(self) -> float:
        return min(self.runs)

    def number(self) -> int:
        return len(self.runs)


class BenchExecTest(BasicSyntaxTest):
    """
    Simple benchmark test. The new compiler mode shouldn't be slower than the old ones (or javac)
    """

    FILE_ENDINGS = [".java", ".mj"]
    INVALID_FILE_ENDINGS = [".inf.java", ".inf.mj"]
    MODE = TestMode.compile_firm

    def __init__(self, env: Environment, type: str, file: str, preprocessed_file: str):
        super().__init__(env, type, file, preprocessed_file)
        self._should_succeed = True

    def _bench_command(self, cmd: str, *args: str) -> _RunResult:
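        """Run *cmd* ``bench_runs`` times and record the wall-clock time of each run.

        Returns an empty, failed result as soon as one run exits with a non-zero status.
        """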
        runs = []  # type: List[float]
        for _ in range(self.env.bench_runs):
            try:
                start = time.time()
                subprocess.check_call([cmd] + list(args), stdout=subprocess.DEVNULL)
                runs.append(time.time() - start)
            except subprocess.CalledProcessError:
                return _RunResult([], False)
        return _RunResult(runs, True)

    def run(self) -> BasicDiffTestResult:
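        """Benchmark the program once per configured compiler flag and compare
        the two resulting timing distributions."""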
        is_big_testcase = "big" in self.file
        timeout = self.env.big_timeout if is_big_testcase else self.env.timeout
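        # Copy the preprocessed source into a pid-local temporary directory and work there.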
        base_filename = path.basename(self.file).split(".")[0]
        tmp_dir = self.env.create_pid_local_tmpdir()
        shutil.copy(self.preprocessed_file, path.join(tmp_dir, base_filename + ".java"))
        cwd = os.getcwd()
        os.chdir(tmp_dir)

        test_result = ExtensibleTestResult(self)

        results = []  # type: List[_RunResult]

        for compiler_flag in self.env.bench_compiler_flags:
            if compiler_flag == "javac":
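                # Reference measurement: compile with javac, then time the JVM running the main class.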
                _, err, javac_rtcode = \
                    self.env.run_command("javac", base_filename + ".java", timeout=timeout)
                if javac_rtcode != 0:
                    _LOG.error("File \"{}\" isn't valid Java".format(self.preprocessed_file))
                    test_result.incorrect_msg = "invalid java code, but output file missing"
                    test_result.set_error_code(javac_rtcode)
                    test_result.add_long_text("Javac error message", decode(err))
                    test_result.add_file("Source file", self.preprocessed_file)
                    os.chdir(cwd)
                    return test_result
                main_class = get_main_class_name(base_filename + ".java")
                if not main_class:
                    _LOG.debug("Can't find a main class, using the file name instead")
                    main_class = base_filename
                results.append(self._bench_command("java", main_class))
            else:
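                # Measurement for the compiler under test: compile with the given flag, then time ./a.out.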
                try:
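                    # Restore leading dashes that were escaped ("\-" -> "-") when the flag was passed in.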
                    compiler_flag = compiler_flag.replace("\\-", "-")
                    out, err, rtcode = self.env.run_command(self.env.mj_run_cmd, compiler_flag, base_filename + ".java",
                                                            timeout=timeout)
                    if rtcode != 0:
                        test_result.incorrect_msg = "file can't be compiled"
                        test_result.set_error_code(rtcode)
                        test_result.add_long_text("Error output", decode(err))
                        test_result.add_long_text("Output", decode(out))
                        test_result.add_file("Source file", self.preprocessed_file)
                        os.chdir(cwd)
                        return test_result
                except SigKill as sig:
                    test_result.incorrect_msg = "file can't be compiled: " + sig.name
                    test_result.set_error_code(sig.retcode)
                    test_result.add_file("Source file", self.preprocessed_file)
                    os.chdir(cwd)
                    return test_result
                except:
                    os.chdir(cwd)
                    raise
                results.append(self._bench_command("./a.out"))
        os.chdir(cwd)
        assert len(results) == 2
        if not results[0].is_correct or not results[1].is_correct:
            incorrect_flags = [self.env.bench_compiler_flags[i] for (i, res) in enumerate(results) if not res.is_correct]
            test_result.incorrect_msg = "Running with {} failed".format(", ".join(incorrect_flags))
            test_result.has_succeeded = False
            return test_result
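        # Compare the two configurations: ratios of the minimum and mean run times, and how far
        # the mean ratio lies above 1 in units of the (relative) standard deviation.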
        msg_parts = []
        stddev = max(results[0].stddev() / results[0].mean(), results[1].stddev() / results[1].mean())
        rel_min = results[0].min() / results[1].min()
        msg_parts.append("min(0/1) = {:.0%}".format(rel_min))
        rel_mean = results[0].mean() / results[1].mean()
        msg_parts.append("mean(0/1) = {:.0%}".format(rel_mean))
        msg_parts.append("-1 / std = {:.0%}".format((rel_mean - 1) / stddev))
        for (i, res) in enumerate(results):
            test_result.add_short_text("min({})".format(i), res.min())
            test_result.add_short_text("mean({})".format(i), res.mean())
            test_result.add_short_text("stddev({})".format(i), res.stddev())
        if (rel_mean - 1) / stddev <= 1:
            msg_parts.append("first faster")
            test_result.incorrect_msg = ", ".join(msg_parts)
            test_result.has_succeeded = False
            return test_result
        test_result.correct_msg = ", ".join(msg_parts)
        test_result.has_succeeded = True
        return test_result



TestCase.TEST_CASE_CLASSES[TestMode.bench].append(BenchExecTest)