"""Detect the input/output file-name pattern pair of a set of test files.

A :class:`Pattern` is ``ll + mm + rr`` where ``mm`` is a wildcard; a
:class:`PatternPair` couples an input pattern ``x`` with an output pattern
``y`` sharing the same wildcard.
"""
import os
import random

from judge.views.test_formatter import tf_utils as utils

# Number of names sampled when generating candidate prefixes/suffixes.
SAMPLE_SIZE = 16
# Numbered wildcards: the value is the first test id and the length is the
# zero-padded width (see Pattern.get_test_id_from_index).
NUMBERED_MM = ["0", "1", "00", "01", "000", "001", "0000", "0001"]
VALID_MM = ["*"] + NUMBERED_MM
MSG_TOO_MANY_OCCURRENCES = (
    "400: Invalid pattern: Pattern cannot have more than one '{}'"
)
MSG_MM_NOT_FOUND = "400: Invalid pattern: Wildcard not found. Wildcard list: {}"


class Pattern:
    """A file-name pattern ``ll + mm + rr`` with a single wildcard ``mm``.

    ``mm`` is either ``"*"`` (any test id) or one of NUMBERED_MM, in which
    case test ids are digit strings and the id for index ``i`` is
    ``str(int(mm) + i).zfill(len(mm))``.
    """

    def __init__(self, ll, mm, rr):
        # Explicit raise rather than ``assert`` so validation survives
        # ``python -O``; AssertionError is kept for existing callers.
        if mm not in VALID_MM:
            raise AssertionError("Invalid wildcard")
        self.ll = ll
        self.mm = mm
        self.rr = rr

    def __repr__(self):
        return "Pattern('{}', '{}', '{}')".format(self.ll, self.mm, self.rr)

    def _key(self):
        """Return the fields that define this pattern's identity."""
        return (self.ll, self.mm, self.rr)

    def __eq__(self, other):
        # Compare fields, not reprs: the repr does not escape quotes, so two
        # different patterns could render to the same string.
        if not isinstance(other, Pattern):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        return hash(self._key())

    @classmethod
    def from_string(cls, text):
        """Parse ``text`` into a Pattern by locating its single wildcard.

        ``"*"`` is tried first, then numbered wildcards longest-first, so
        e.g. ``"001"`` is preferred over ``"1"``.

        Raises:
            Exception: if the selected wildcard occurs more than once, or if
                no wildcard occurs at all (messages start with "400:").
        """
        for mm in ["*"] + sorted(NUMBERED_MM, key=len, reverse=True):
            if mm in text:
                if text.count(mm) > 1:
                    raise Exception(MSG_TOO_MANY_OCCURRENCES.format(mm))
                i = text.index(mm)
                return cls(text[:i], mm, text[i + len(mm) :])
        raise Exception(MSG_MM_NOT_FOUND.format(",".join(VALID_MM)))

    def to_string(self):
        """Return the textual form ``ll + mm + rr``."""
        return self.ll + self.mm + self.rr

    def is_valid_test_id(self, test_id):
        """Return whether ``test_id`` may stand in for the wildcard.

        Any id is valid for ``"*"``; numbered wildcards require a digit
        string at least as long as the wildcard itself.
        """
        if self.mm == "*":
            return True
        if self.mm in NUMBERED_MM:
            return test_id.isdigit() and len(test_id) >= len(self.mm)
        raise NotImplementedError

    def matched(self, name):
        """Return whether ``name`` fits ``ll + <valid test id> + rr``."""
        return (
            name.startswith(self.ll)
            and name.endswith(self.rr)
            # Guard against ll and rr overlapping inside a short name.
            and len(name) >= len(self.ll) + len(self.rr)
            and self.is_valid_test_id(self.get_test_id(name))
        )

    def get_test_id(self, name):
        """Strip ``ll`` and ``rr`` from ``name`` (assumes it matched)."""
        return name[len(self.ll) : len(name) - len(self.rr)]

    def get_test_id_from_index(self, index):
        """Return the zero-padded test id of the ``index``-th numbered test."""
        assert self.mm in NUMBERED_MM, "Wildcard is not a number"
        return str(int(self.mm) + index).zfill(len(self.mm))

    def get_name(self, test_id, index=None, use_index=False):
        """Build a file name from ``test_id``.

        With ``use_index`` and a numbered wildcard, the id is derived from
        ``index`` instead and ``test_id`` is ignored.
        """
        if use_index and self.mm in NUMBERED_MM:
            return self.ll + self.get_test_id_from_index(index) + self.rr
        return self.ll + test_id + self.rr

    def matches(self, names, returns):
        """Return the test ids of all matching ``names``, in input order.

        Only ``returns="test_id"`` is supported; anything else raises
        NotImplementedError.
        """
        if returns != "test_id":
            raise NotImplementedError
        return [self.get_test_id(n) for n in names if self.matched(n)]


class PatternPair:
    """An input Pattern ``x`` paired with an output Pattern ``y``.

    Both patterns must use the same wildcard.
    """

    def __init__(self, x: Pattern, y: Pattern):
        # Explicit raise so the check also runs under ``python -O``; this
        # guards user-supplied format pairs (see from_string_pair).
        if x.mm != y.mm:
            raise AssertionError("Input wildcard and output wildcard must be equal")
        self.x = x
        self.y = y

    def __repr__(self):
        return "PatternPair({}, {})".format(self.x, self.y)

    def __eq__(self, other):
        if not isinstance(other, PatternPair):
            return NotImplemented
        return (self.x, self.y) == (other.x, other.y)

    def __hash__(self):
        return hash((self.x, self.y))

    @classmethod
    def from_string_pair(cls, inp_format, out_format):
        """Parse an input/output format pair (see Pattern.from_string)."""
        return cls(Pattern.from_string(inp_format), Pattern.from_string(out_format))

    def matches(self, names, returns):
        """Pair up input/output files among ``names``.

        Test ids matched by both ``x`` and ``y`` are naturally sorted.
        For numbered wildcards only a consecutive run starting at index 0
        is accepted.

        ``returns`` selects the result:
            - ``"fast_count"``: number of shared ids (consecutive run for
              numbered wildcards), without checking that the files exist
              as distinct names.
            - ``"count"``: number of valid test ids.
            - ``"test_id"``: list of valid test ids.
            - ``"test_id_with_extra_files"``: ``(valid ids, unused names)``.
        Any other value raises NotImplementedError.
        """
        x_test_ids = self.x.matches(names, returns="test_id")
        y_test_ids = self.y.matches(names, returns="test_id")
        test_ids = sorted(
            set(x_test_ids) & set(y_test_ids), key=utils.natural_sorting_key
        )

        if returns == "fast_count":
            if self.x.mm == "*":
                return len(test_ids)
            elif self.x.mm in NUMBERED_MM:
                count_valid = 0
                for t in test_ids:
                    if t == self.x.get_test_id_from_index(count_valid):
                        count_valid += 1
                return count_valid

        extra_files = list(names)
        valid_test_ids = []
        for t in test_ids:
            if self.x.mm in NUMBERED_MM:
                # Numbered tests must form an unbroken sequence.
                if t != self.x.get_test_id_from_index(len(valid_test_ids)):
                    continue
            inp_name = self.x.get_name(t)
            out_name = self.y.get_name(t)
            if inp_name == out_name:
                continue
            # Each file may be consumed by at most one test.
            if inp_name not in extra_files:
                continue
            if out_name not in extra_files:
                continue
            valid_test_ids.append(t)
            extra_files.remove(inp_name)
            extra_files.remove(out_name)

        if returns == "count":
            return len(valid_test_ids)
        elif returns == "test_id":
            return valid_test_ids
        elif returns == "test_id_with_extra_files":
            return valid_test_ids, extra_files
        else:
            raise NotImplementedError

    def score(self, names):
        """Return a tuple ranking this pair; larger tuples are preferred.

        Fields, compared in order:
            - count_score: ``matches(names, "fast_count")``.
            - specific_score: 0 for ``"*"``, otherwise the numbered
              wildcard's length.
            - len_score: non-``'0'`` characters in all four affixes.
            - zero_score: minus the ``'0'`` characters in those affixes.
            - vowel_score: vowel weight of x's affixes minus y's, where
              ``'i'`` is +1 and ``a/e/o/u`` are -1 (case-insensitive), e.g.
              favouring x containing "in" and y containing "out".
        """

        def ls(s):
            return len(s) - s.count("0")

        def zs(s):
            return -s.count("0")

        def vs(s):
            return sum(
                s.lower().count(c) * w
                for c, w in [("a", -1), ("e", -1), ("i", +1), ("o", -1), ("u", -1)]
            )

        affixes = self.x.ll + self.x.rr + self.y.ll + self.y.rr
        count_score = self.matches(names, returns="fast_count")
        len_score = ls(affixes)
        zero_score = zs(affixes)
        assert self.x.mm in ["*"] + NUMBERED_MM
        specific_score = 0 if self.x.mm == "*" else len(self.x.mm)
        vowel_score = vs(self.x.ll + self.x.rr) - vs(self.y.ll + self.y.rr)
        return count_score, specific_score, len_score, zero_score, vowel_score

    def is_string_safe(self):
        """Return whether both patterns survive a to_string/from_string
        round trip unchanged."""
        try:
            x = Pattern.from_string(self.x.to_string())
            y = Pattern.from_string(self.y.to_string())
            return self == PatternPair(x, y)
        except Exception:
            # Parsing or wildcard-mismatch errors mean "not safe"; a bare
            # except would also swallow KeyboardInterrupt/SystemExit.
            return False


def maximal(a, key):
    """Return the unique element of ``a`` with the largest ``key``.

    ``key`` is evaluated once per element, since it may be expensive
    (e.g. PatternPair.score).

    Raises:
        ValueError: if ``a`` is empty (raised by ``max``).
        Exception: if several elements share the maximal key; the tied
            elements are printed before raising.
    """
    scored = [(x, key(x)) for x in a]
    max_score = max(s for _, s in scored)
    result = [x for x, s in scored if s == max_score]
    if len(result) == 1:
        return result[0]
    print(result)
    raise Exception("More than one maximum values")


def get_all_star_pattern_pairs(names):
    """Generate candidate ``"*"`` pattern pairs from ``names``.

    Candidate prefixes and suffixes are taken from a random sample of at
    most SAMPLE_SIZE names, so the result may vary between runs when there
    are more names than that. Every prefix (suffix) shared by exactly two
    names yields pairs for each way of cutting it around the wildcard.
    The returned list is de-duplicated and in arbitrary order.
    """
    # Materialize once: random.sample needs a sequence, and names is
    # iterated repeatedly below.
    names = list(names)
    sample = random.sample(names, min(len(names), SAMPLE_SIZE))
    all_prefixes = sorted({n[:i] for n in sample for i in range(len(n) + 1)})
    all_suffixes = sorted({n[i:] for n in sample for i in range(len(n) + 1)})

    star_pattern_pairs = []
    for prefix in all_prefixes:
        matched_names = [n for n in names if n.startswith(prefix)]
        if len(matched_names) == 2:
            mn0, mn1 = matched_names
            for i in range(len(prefix) + 1):
                x = Pattern(prefix[:i], "*", mn0[len(prefix) :])
                y = Pattern(prefix[:i], "*", mn1[len(prefix) :])
                star_pattern_pairs.append(PatternPair(x, y))
    for suffix in all_suffixes:
        matched_names = [n for n in names if n.endswith(suffix)]
        if len(matched_names) == 2:
            mn0, mn1 = matched_names
            for i in range(len(suffix) + 1):
                x = Pattern(mn0[: len(mn0) - len(suffix)], "*", suffix[i:])
                y = Pattern(mn1[: len(mn1) - len(suffix)], "*", suffix[i:])
                star_pattern_pairs.append(PatternPair(x, y))
    return list(set(star_pattern_pairs))


def get_variant_pattern_pairs(pp):
    """Return ``pp`` re-targeted to every wildcard, in both orientations
    (x/y as given, then swapped)."""
    return [
        PatternPair(Pattern(pp.x.ll, mm, pp.x.rr), Pattern(pp.y.ll, mm, pp.y.rr))
        for mm in VALID_MM
    ] + [
        PatternPair(Pattern(pp.y.ll, mm, pp.y.rr), Pattern(pp.x.ll, mm, pp.x.rr))
        for mm in VALID_MM
    ]


def find_best_pattern_pair(names):
    """Return the best-scoring, string-safe PatternPair for ``names``.

    Falls back to ``PatternPair(Pattern("", "*", ""), Pattern("", "*", ""))``
    when no ``"*"`` candidate matches at least two tests.

    Raises:
        Exception: from maximal() when the best score is tied.
        ValueError: from maximal() if no variant of the best candidate is
            string-safe.
    """
    names = list(names)
    star_pattern_pairs = [
        pp
        for pp in get_all_star_pattern_pairs(names)
        if pp.matches(names, returns="fast_count") >= 2
    ]
    if not star_pattern_pairs:
        return PatternPair(Pattern("", "*", ""), Pattern("", "*", ""))

    best_star_pattern_pair = maximal(star_pattern_pairs, key=lambda pp: pp.score(names))
    pattern_pairs = [
        pp
        for pp in get_variant_pattern_pairs(best_star_pattern_pair)
        if pp.is_string_safe()
    ]
    return maximal(pattern_pairs, key=lambda pp: pp.score(names))


def list_dir_recursively(folder):
    """Return every file under ``folder`` as a path relative to it.

    Paths are formatted as ``os.walk(".")`` would yield them from inside
    ``folder``, e.g. ``./a.in`` and ``./sub/b.out``. The process working
    directory is not touched, so this is safe to call from multiple threads.

    Raises:
        OSError: if ``folder`` does not exist or is not a readable directory.
    """
    top = os.fsdecode(folder)
    # os.walk() silently yields nothing for a missing or unreadable top
    # directory; open it explicitly so such errors still propagate.
    with os.scandir(top):
        pass
    result = []
    for root, _, filenames in os.walk(top):
        rel_root = os.path.relpath(root, top)
        prefix = "." if rel_root == "." else os.path.join(".", rel_root)
        result.extend(os.path.join(prefix, filename) for filename in filenames)
    return result


def test_with_dir(folder):
    """Print ``folder`` and the pattern pair detected for its files."""
    names = list_dir_recursively(folder)
    print(folder, find_best_pattern_pair(names))