# preprocessor.py
from collections import defaultdict
import logging
import os
from os.path import relpath
from pprint import pprint
import re
import sys

_LOG = logging.getLogger("preprocessor")

class PreProcessorError(Exception):
    """Error raised when preprocessing cannot proceed.

    Derives from ``Exception`` (not ``BaseException``) so that ordinary
    ``except Exception`` handlers see it; ``BaseException`` is reserved for
    interpreter-level events such as ``KeyboardInterrupt`` and ``SystemExit``.
    """

    def __init__(self, msg):
        """Create the error with the human readable message *msg*."""
        super().__init__(msg)

class PreProcessor:
    """Inline imported Java classes into one output text.

    The source file is read line by line. A line that starts with
    ``import <name>;`` makes the file
    ``<import_base_dir>/<name with dots as path separators>.java`` be
    processed the same way (recursively). Import lines and lines starting
    with ``package`` are wrapped in ``/*...*/``, and a leading
    ``public class `` is rewritten to ``/*public*/ class ``. The text of
    every processed file is collected in ``imported_strs`` and finally
    written to the destination in reverse order of collection.
    """

    def __init__(self, src_file: str, import_base_dir: str, dst_file: str):
        """Validate the paths and set up the (empty) import bookkeeping.

        Args:
            src_file: Path of the file to preprocess; must be an existing file.
            import_base_dir: Directory that imported class names are resolved
                against; must be an existing directory.
            dst_file: Path of the output file, or ``"-"`` for standard output.

        Raises:
            PreProcessorError: If ``src_file`` is not a file, if
                ``import_base_dir`` is not a directory, or if ``dst_file``
                (when it is not ``"-"``) is a directory or resolves to the
                same path as ``src_file``.
        """
        self.src_file = src_file
        self.import_base_dir = import_base_dir
        self.dst_file = dst_file
        # Text of every processed file, in the order processing finished.
        self.imported_strs = []
        self._already_imported_classes = {}  # name -> full_name
        self._import_regexp = re.compile("import [A-Za-z.]+;")
        self._imported_class_regexp = re.compile("[A-Za-z.]+;")
        # relative path of an imported file -> class names of the files importing it
        self._imported_classes = defaultdict(list)
        if not os.path.isfile(src_file):
            raise PreProcessorError("Source file '{}' isn't a file".format(src_file))
        if not os.path.isdir(import_base_dir):
            raise PreProcessorError("MJ_IMPORT_DIR '{}' isn't a directory".format(import_base_dir))
        # "-" means standard output, so the path checks only apply to real
        # destination paths.
        if dst_file != "-":
            if os.path.isdir(dst_file):
                raise PreProcessorError("Destination file '{}' isn't a file or '-'".format(dst_file))
            # Writing to the source file would destroy the input.
            if os.path.realpath(src_file) == os.path.realpath(dst_file):
                raise PreProcessorError("Destination file '{}' is equal to the source file".format(dst_file))

    def preprocess(self) -> None:
        """Process the source file (with all its imports) and write the result.

        Raises:
            PreProcessorError: If an imported class cannot be resolved or
                clashes with an already imported class of the same name.
        """
        #if is_importable_file(self.src_file):
        #    raise PreProcessorError("Can't pre process importable file '{}'".format(self.src_file))
        self._preprocess_file(self.src_file)
        self._store_in_dst_file()

    def _preprocess_file(self, file: str) -> None:
        """Rewrite *file* and append the resulting text to ``imported_strs``.

        Imports found in *file* are processed before the text of *file* itself
        is appended. Every file other than the source file gets a comment
        header naming its path relative to ``import_base_dir``.
        """
        lines = []
        middle_lines = []

        def add_commented(line: str) -> None:
            middle_lines.append("/*{}*/".format(line))

        with open(file, "r") as f:
            for line in f:
                line = line.rstrip()
                if self._import_regexp.match(line):
                    _LOG.debug("File '%s': parse '%s'", file, line)
                    # The match includes the trailing ';', which is cut off.
                    full_name = self._imported_class_regexp.search(line).group(0)[:-1]
                    _LOG.debug("File '%s': import class %s", file, full_name)
                    # _import_file already resolved and validated the path,
                    # so reuse it instead of validating the file a second time.
                    imported_path = self._import_file(full_name)
                    self._imported_classes[relpath(imported_path)].append(_class_name_for_file(file))
                    add_commented(line)
                elif line.startswith("public class "):
                    _LOG.debug("File '%s': modify '%s'", file, line)
                    line = line.replace("public class ", "/*public*/ class ", 1)
                    middle_lines.append(line)
                elif line.startswith("package"):
                    _LOG.debug("File '%s': ignore '%s'", file, line)
                    add_commented(line)
                else:
                    middle_lines.append(line)

        if file != self.src_file:
            lines.append("/* ##################### \n")
            lines.append("   imported from file: {}".format(relpath(file, self.import_base_dir)))
            #lines.append("   imported by: {}".format(",".join(sorted(self._imported_classes[relpath(file)]))))
            lines.append("\n   ##################### \n*/")
        lines.extend(middle_lines)
        self.imported_strs.append("\n".join(lines))

    def _import_file(self, full_name: str) -> str:
        """Process the file of class *full_name* unless it was imported before.

        Args:
            full_name: Dot separated class name as written in the import line.

        Returns:
            The path of the file that declares the class.

        Raises:
            PreProcessorError: If the file cannot be imported, or if a class
                with the same simple name but another full name was imported
                earlier.
        """
        _LOG.debug("Try to import %s", full_name)
        path = self._file_name_for_full_class_name(full_name)
        name = _class_name_for_file(path)
        if name in self._already_imported_classes:
            if full_name != self._already_imported_classes[name]:
                raise PreProcessorError("Can't import {}, as another class with the same name was already imported: {}"
                           .format(full_name, self._already_imported_classes[name]))
        else:
            # Register before recursing so the class is only processed once.
            self._already_imported_classes[name] = full_name
            self._preprocess_file(path)
        return path

    def _file_name_for_full_class_name(self, full_name: str) -> str:
        """Map a dot separated class name to its ``.java`` file path.

        Raises:
            PreProcessorError: If the file does not exist or
                ``is_importable_file`` rejects it.
        """
        parts = full_name.split(".")
        parts[-1] += ".java"
        path = os.path.join(self.import_base_dir, *parts)
        if not os.path.exists(path):
            raise PreProcessorError("File '{}' doesn't exist, can't import class {}".format(path, _class_name_for_file(path)))
        if not is_importable_file(path):
            raise PreProcessorError("File '{}' isn't importable, can't import class {}".format(path, _class_name_for_file(path)))
        return path

    def _store_in_dst_file(self) -> None:
        """Emit the collected texts in reverse order of collection.

        For ``"-"`` every text is printed to standard output preceded by two
        empty lines; otherwise every text is written to ``dst_file`` followed
        by two newline characters.
        """
        _LOG.debug("Try to output to '%s'", self.dst_file)
        if self.dst_file == "-":
            for text in reversed(self.imported_strs):
                print()
                print()
                print(text)
        else:
            # Leaving the with block flushes and closes the file.
            with open(self.dst_file, "w") as f:
                for text in reversed(self.imported_strs):
                    f.write(text)
                    f.write("\n\n")

def _class_name_for_file(file: str):
    return os.path.basename(file).split(".")[0]

def is_importable_file(file: str) -> bool:
    """Tell whether *file* has the shape required of an importable class file.

    A file is importable when it contains a line starting with ``package ``,
    a line starting with ``public class `` and no line that looks like a
    ``public static void main(String[] ...)`` declaration.

    Args:
        file: Path of the file to inspect.

    Returns:
        ``True`` if all three conditions hold, ``False`` otherwise.

    Raises:
        PreProcessorError: If the name of the public class differs from the
            file's base name (the part before the first dot).
    """
    name = _class_name_for_file(file)
    has_package = False
    has_public_class = False
    has_main_method = False
    # A line containing all of these is treated as a main method declaration.
    main_markers = ("String[]", "main", "void", "static", "public")
    with open(file, "r") as f:
        for line in f:
            if line.startswith("package "):
                has_package = True
            elif line.startswith("public class "):
                # Identifiers may contain digits after the first character;
                # without them "Foo2" would be read as "Foo" and rejected.
                match = re.search(r"[A-Za-z_][A-Za-z0-9_]*", line.replace("public class ", ""))
                if match:
                    has_public_class = True
                    if match.group(0) != name:
                        raise PreProcessorError("File '{}' has invalid format: expected a public class {}, got {}"
                                   .format(file, name, match.group(0)))
            elif all(marker in line for marker in main_markers):
                has_main_method = True
    # Disabled consistency check, kept for reference:
    #if (has_package or has_public_class) == has_main_method:
    #    raise PreProcessorError("File '{}' has invalid format: "
    #                            "package={}, 'public class'={}, main method={}"
    #                           .format(file, has_package, has_public_class, has_main_method))
    return has_package and has_public_class and not has_main_method