# preprocessor.py
from collections import defaultdict
import logging
import os
from os.path import relpath
from pprint import pprint
import re
import sys

_LOG = logging.getLogger("preprocessor")

class PreProcessorError(Exception):
    """Error raised when preprocessing fails (invalid paths or imports)."""

    # Derives from Exception rather than BaseException so that it is handled
    # like any ordinary application error; `except PreProcessorError` keeps
    # working unchanged.
    def __init__(self, msg):
        super().__init__(msg)

class PreProcessor:
    """Merge a Java source file and the classes it imports into one output.

    Lines of the form ``import a.b.C;`` are resolved to
    ``<import_base_dir>/a/b/C.java``.  Each such file is processed the same
    way (recursively) and is emitted after the file that first imported it.
    ``import`` and ``package`` lines are wrapped in ``/*...*/`` and a leading
    ``public class `` is rewritten to ``/*public*/ class ``.
    """

    def __init__(self, src_file: str, import_base_dir: str, dst_file: str):
        """Store the configuration and validate the given paths.

        Args:
            src_file: path of the file to preprocess; must be an existing file.
            import_base_dir: directory below which imported classes are
                looked up; must be an existing directory.
            dst_file: path of the output file, or "-" to print the result.

        Raises:
            PreProcessorError: if the source isn't a file, the import base
                isn't a directory, or the destination is a directory or the
                source file itself.
        """
        self.src_file = src_file
        self.import_base_dir = import_base_dir
        self.dst_file = dst_file
        # Processed file texts; an imported file is appended before its importer.
        self.imported_strs = []
        self._already_imported_classes = {}  # name -> full_name
        self._import_regexp = re.compile(r"import [A-Za-z.0-9]+;")
        self._imported_class_regexp = re.compile(r"[A-Za-z.0-9]+;")
        # relative path of an imported file -> class names of the importing files
        self._imported_classes = defaultdict(list)
        if not os.path.isfile(src_file):
            raise PreProcessorError("Source file '{}' isn't a file".format(src_file))
        if not os.path.isdir(import_base_dir):
            raise PreProcessorError("MJ_IMPORT_DIR '{}' isn't a directory".format(import_base_dir))
        # Only a real destination path needs checking; "-" selects printing.
        if dst_file != "-":
            if os.path.isdir(dst_file):
                raise PreProcessorError("Destination file '{}' isn't a file or '-'".format(dst_file))
            if os.path.realpath(src_file) == os.path.realpath(dst_file):
                raise PreProcessorError("Destination file '{}' is equal to the source file".format(dst_file))

    def preprocess(self) -> None:
        """Process the source file with all its imports and store the result."""
        # NOTE: rejecting a source file that is itself importable
        # (see is_importable_file) is currently not enforced.
        self._preprocess_file(self.src_file)
        self._store_in_dst_file()

    def _preprocess_file(self, file: str) -> None:
        """Process one file and append its resulting text to ``imported_strs``.

        Imported classes are processed (recursively) while the file is read,
        i.e. before the text of ``file`` itself is appended.

        Args:
            file: path of the Java file to process.

        Raises:
            PreProcessorError: if an imported class can't be resolved or
                clashes with an already imported class of the same name.
        """
        body_lines = []

        with open(file, "r", errors="backslashreplace") as f:
            for line in f:
                line = line.rstrip()
                if self._import_regexp.match(line):
                    _LOG.debug("File '%s': parse '%s'", file, line)
                    # The matched text ends with ';', which isn't part of the name.
                    full_name = self._imported_class_regexp.search(line).group(0)[:-1]
                    _LOG.debug("File '%s': import class %s", file, full_name)
                    self._import_file(full_name)
                    imported_path = relpath(self._file_name_for_full_class_name(full_name))
                    self._imported_classes[imported_path].append(_class_name_for_file(file))
                    body_lines.append("/*{}*/".format(line))
                elif line.startswith("public class "):
                    _LOG.debug("File '%s': modify '%s'", file, line)
                    body_lines.append(line.replace("public class ", "/*public*/ class ", 1))
                elif line.startswith("package"):
                    _LOG.debug("File '%s': ignore '%s'", file, line)
                    body_lines.append("/*{}*/".format(line))
                else:
                    body_lines.append(line)

        lines = []
        if file != self.src_file:
            # Imported files get a banner naming their origin.
            lines.append("/* ##################### \n")
            lines.append("   imported from file: {}".format(relpath(file, self.import_base_dir)))
            lines.append("\n   ##################### \n*/")
        lines.extend(body_lines)
        self.imported_strs.append("\n".join(lines))

    def _import_file(self, full_name: str) -> None:
        """Process the file of class ``full_name`` unless it was imported before.

        Args:
            full_name: dotted class name as written in the import statement.

        Raises:
            PreProcessorError: if a different class with the same simple name
                was already imported, or the class file can't be used.
        """
        _LOG.debug("Try to import %s", full_name)
        path = self._file_name_for_full_class_name(full_name)
        name = _class_name_for_file(path)
        previous = self._already_imported_classes.get(name)
        if previous is None:
            # Register before processing so that a cyclic import finds the
            # entry and doesn't process the file again.
            self._already_imported_classes[name] = full_name
            self._preprocess_file(path)
        elif previous != full_name:
            raise PreProcessorError("Can't import {}, as another class with the same name was already imported: {}"
                                    .format(full_name, previous))

    def _file_name_for_full_class_name(self, full_name: str) -> str:
        """Return the path of the ``.java`` file for a dotted class name.

        Args:
            full_name: dotted class name, e.g. ``a.b.C``.

        Returns:
            ``<import_base_dir>/a/b/C.java`` for the example above.

        Raises:
            PreProcessorError: if that file doesn't exist or isn't importable.
        """
        parts = full_name.split(".")
        parts[-1] += ".java"
        path = os.path.join(self.import_base_dir, *parts)
        if not os.path.exists(path):
            raise PreProcessorError("File '{}' doesn't exist, can't import class {}".format(path, _class_name_for_file(path)))
        if not is_importable_file(path):
            raise PreProcessorError("File '{}' isn't importable, can't import class {}".format(path, _class_name_for_file(path)))
        return path

    def _store_in_dst_file(self) -> None:
        """Emit the collected texts, most recently appended first.

        Prints to standard output if the destination is "-", otherwise
        (over)writes the destination file.
        """
        _LOG.debug("Try to output to '%s'", self.dst_file)
        if self.dst_file == "-":
            for text in reversed(self.imported_strs):
                # Two empty lines precede every text.
                print()
                print()
                print(text)
        else:
            with open(self.dst_file, "w", errors="backslashreplace") as f:
                for text in reversed(self.imported_strs):
                    f.write(text)
                    f.write("\n\n")

def _class_name_for_file(file: str):
    return os.path.basename(file).split(".")[0]

def is_importable_file(file: str) -> bool:
    """Tell whether `file` looks like a Java class file that can be imported.

    A file counts as importable when it contains a line starting with
    ``package ``, a line starting with ``public class `` followed by a class
    name, and no line that looks like the declaration of a
    ``public static void main(String[] ...)`` method.

    Args:
        file: path of the Java file to inspect.

    Returns:
        True if all three conditions hold, False otherwise.

    Raises:
        PreProcessorError: if the public class is named differently from the
            file (its base name up to the first dot).
    """
    expected_name = _class_name_for_file(file)
    # A line containing all of these is treated as a main method declaration.
    main_markers = ("String[]", "main", "void", "static", "public")
    has_package = False
    has_public_class = False
    has_main_method = False
    with open(file, "r", errors="backslashreplace") as f:
        for line in f:
            if line.startswith("package "):
                has_package = True
            elif line.startswith("public class "):
                # The first identifier after the prefix is the class name.
                match = re.search(r"[A-Za-z_0-9]+", line.replace("public class ", ""))
                if match:
                    has_public_class = True
                    if match.group(0) != expected_name:
                        raise PreProcessorError("File '{}' has invalid format: expected a public class {}, got {}"
                                                .format(file, expected_name, match.group(0)))
            elif all(marker in line for marker in main_markers):
                has_main_method = True
    return has_package and has_public_class and not has_main_method