269 lines
8.7 KiB
Python
269 lines
8.7 KiB
Python
|
import os
|
||
|
import random
|
||
|
from judge.views.test_formatter import tf_utils as utils
|
||
|
|
||
|
# Upper bound on how many file names are randomly sampled when candidate
# prefixes/suffixes are enumerated in get_all_star_pattern_pairs().
SAMPLE_SIZE = 16

# Numbered wildcards. The integer value of the wildcard is the id of the first
# test and its length is the zero-padded width (see
# Pattern.get_test_id_from_index()).
NUMBERED_MM = ["0", "1", "00", "01", "000", "001", "0000", "0001"]

# Every wildcard a Pattern accepts: the free-form "*" followed by the
# numbered ones.
VALID_MM = ["*", *NUMBERED_MM]

# Error messages; "{}" is filled in with str.format() at the raise site.
MSG_TOO_MANY_OCCURRENCES = (
    "400: Invalid pattern: "
    "Pattern cannot have more than one '{}'"
)
MSG_MM_NOT_FOUND = (
    "400: Invalid pattern: "
    "Wildcard not found. Wildcard list: {}"
)
|
||
|
|
||
|
|
||
|
class Pattern:
    """A file-name template of the form ``ll + mm + rr``.

    ``ll`` and ``rr`` are the literal prefix and suffix of the file name and
    ``mm`` is the wildcard that stands for the test id. The wildcard is either
    ``"*"`` (any test id) or one of ``NUMBERED_MM`` (digit-only test ids that
    are at least as long as the wildcard).
    """

    def __init__(self, ll, mm, rr):
        """Store the three parts; ``mm`` must be one of ``VALID_MM``."""
        assert mm in VALID_MM, "Invalid wildcard"
        self.ll = ll
        self.mm = mm
        self.rr = rr

    def __repr__(self):
        return "Pattern('{}', '{}', '{}')".format(self.ll, self.mm, self.rr)

    def __eq__(self, other):
        """Compare the three parts structurally.

        Comparing ``repr()`` strings is ambiguous when a literal part contains
        the ``', '`` separator, so the fields themselves are compared.
        """
        if not isinstance(other, Pattern):
            return NotImplemented
        return (self.ll, self.mm, self.rr) == (other.ll, other.mm, other.rr)

    def __hash__(self):
        # Must stay consistent with __eq__: hash exactly the compared fields.
        return hash((self.ll, self.mm, self.rr))

    @classmethod
    def from_string(cls, text):
        """Parse ``text`` into a Pattern.

        ``"*"`` is looked for first, then the numbered wildcards from longest
        to shortest. The first wildcard found in ``text`` is used.

        Raises:
            Exception: if the chosen wildcard occurs more than once, or if no
                wildcard occurs in ``text`` at all.
        """
        for mm in ["*"] + sorted(NUMBERED_MM, key=len, reverse=True):
            if mm not in text:
                continue
            if text.count(mm) > 1:
                raise Exception(MSG_TOO_MANY_OCCURRENCES.format(mm))
            # Exactly one occurrence: split around it.
            ll, _, rr = text.partition(mm)
            return cls(ll, mm, rr)
        raise Exception(MSG_MM_NOT_FOUND.format(",".join(VALID_MM)))

    def to_string(self):
        """Return the textual form ``ll + mm + rr``."""
        return self.ll + self.mm + self.rr

    def is_valid_test_id(self, test_id):
        """Return whether ``test_id`` is acceptable for this wildcard.

        ``"*"`` accepts everything; a numbered wildcard accepts digit-only
        ids whose length is at least the wildcard's length.
        """
        if self.mm == "*":
            return True
        if self.mm in NUMBERED_MM:
            return test_id.isdigit() and len(test_id) >= len(self.mm)
        raise NotImplementedError

    def matched(self, name):
        """Return whether ``name`` fits this pattern."""
        return (
            name.startswith(self.ll)
            and name.endswith(self.rr)
            # Prefix and suffix must not overlap inside ``name``.
            and len(name) >= len(self.ll) + len(self.rr)
            and self.is_valid_test_id(self.get_test_id(name))
        )

    def get_test_id(self, name):
        """Return the part of ``name`` between the prefix and the suffix."""
        return name[len(self.ll):len(name) - len(self.rr)]

    def get_test_id_from_index(self, index):
        """Return the test id of the ``index``-th test for a numbered wildcard.

        The wildcard's integer value is the id of test 0 and its length is
        the zero-padded width, e.g. ``"001"`` and index 9 give ``"010"``.
        """
        assert self.mm in NUMBERED_MM, "Wildcard is not a number"
        return str(int(self.mm) + index).zfill(len(self.mm))

    def get_name(self, test_id, index=None, use_index=False):
        """Build a file name from ``test_id``.

        When ``use_index`` is true and the wildcard is numbered, the id is
        derived from ``index`` instead of ``test_id``.
        """
        if use_index and self.mm in NUMBERED_MM:
            return self.ll + self.get_test_id_from_index(index) + self.rr
        return self.ll + test_id + self.rr

    def matches(self, names, returns):
        """Return the test ids of all ``names`` that fit this pattern.

        Only ``returns="test_id"`` is supported; any other value raises
        ``NotImplementedError``.
        """
        if returns != "test_id":
            raise NotImplementedError
        return [self.get_test_id(n) for n in names if self.matched(n)]
|
||
|
|
||
|
|
||
|
class PatternPair:
    """An input-file pattern ``x`` paired with an output-file pattern ``y``.

    Both patterns must use the same wildcard so that a test id extracted
    from an input name can be used to build the matching output name.
    """

    def __init__(self, x: "Pattern", y: "Pattern"):
        assert x.mm == y.mm, "Input wildcard and output wildcard must be equal"
        self.x = x
        self.y = y

    def __repr__(self):
        return "PatternPair({}, {})".format(self.x, self.y)

    def __eq__(self, other):
        """Compare both patterns structurally.

        Comparing ``repr()`` strings is ambiguous when file-name parts contain
        the separators used by the repr, so the members are compared directly.
        """
        if not isinstance(other, PatternPair):
            return NotImplemented
        return (self.x, self.y) == (other.x, other.y)

    def __hash__(self):
        # Must stay consistent with __eq__: hash exactly the compared members.
        return hash((self.x, self.y))

    @classmethod
    def from_string_pair(cls, inp_format, out_format):
        """Build a pair by parsing both format strings with Pattern.from_string."""
        return cls(Pattern.from_string(inp_format), Pattern.from_string(out_format))

    def matches(self, names, returns):
        """Match ``names`` against both patterns.

        ``returns`` selects the result:

        * ``"fast_count"``: number of test ids common to both patterns; for a
          numbered wildcard only ids that continue the expected consecutive
          sequence are counted. File-name collisions are not checked.
        * ``"count"``: number of fully validated test ids.
        * ``"test_id"``: list of fully validated test ids.
        * ``"test_id_with_extra_files"``: ``(test_ids, extra_files)`` where
          ``extra_files`` are the names not consumed by any validated test.

        Any other value raises ``NotImplementedError``.
        """
        x_test_ids = self.x.matches(names, returns="test_id")
        y_test_ids = self.y.matches(names, returns="test_id")

        # Only ids present on both sides can form a test.
        # NOTE(review): ordering relies on utils.natural_sorting_key, which is
        # defined outside this file - presumably a human-friendly numeric sort.
        test_ids = sorted(
            set(x_test_ids) & set(y_test_ids), key=utils.natural_sorting_key
        )

        if returns == "fast_count":
            if self.x.mm == "*":
                return len(test_ids)
            elif self.x.mm in NUMBERED_MM:
                count_valid = 0
                for t in test_ids:
                    # Accept an id only if it is the next one in sequence.
                    if t == self.x.get_test_id_from_index(count_valid):
                        count_valid += 1
                return count_valid

        extra_files = list(names)
        valid_test_ids = []
        for t in test_ids:
            if self.x.mm in NUMBERED_MM:
                # Numbered tests must be consecutive; skip anything else.
                if t != self.x.get_test_id_from_index(len(valid_test_ids)):
                    continue

            inp_name = self.x.get_name(t)
            out_name = self.y.get_name(t)

            # Input and output must be two distinct, still-unused files.
            if inp_name == out_name:
                continue
            if inp_name not in extra_files:
                continue
            if out_name not in extra_files:
                continue

            valid_test_ids.append(t)
            extra_files.remove(inp_name)
            extra_files.remove(out_name)

        if returns == "count":
            return len(valid_test_ids)
        elif returns == "test_id":
            return valid_test_ids
        elif returns == "test_id_with_extra_files":
            return valid_test_ids, extra_files
        else:
            raise NotImplementedError

    def score(self, names):
        """Return a tuple used to rank candidate pairs (larger is better).

        The tuple is ``(count_score, specific_score, len_score, zero_score,
        vowel_score)``:

        * ``count_score``: ``matches(names, "fast_count")``.
        * ``specific_score``: 0 for ``"*"``, otherwise the wildcard length.
        * ``len_score``: number of non-``"0"`` characters in the literal parts.
        * ``zero_score``: minus the number of ``"0"`` characters in them.
        * ``vowel_score``: vowel weight of ``x`` minus that of ``y``, where
          ``i`` weighs +1 and ``a``, ``e``, ``o``, ``u`` weigh -1
          (case-insensitive). Presumably this prefers names like "in"/"input"
          for ``x`` and "out"/"ans" for ``y`` - an assumption, not stated in
          the code.
        """

        def ls(s):
            return len(s) - s.count("0")

        def zs(s):
            return -s.count("0")

        def vs(s):
            return sum(
                s.lower().count(c) * w
                for c, w in [("a", -1), ("e", -1), ("i", +1), ("o", -1), ("u", -1)]
            )

        count_score = self.matches(names, returns="fast_count")

        literals = self.x.ll + self.x.rr + self.y.ll + self.y.rr
        len_score = ls(literals)
        zero_score = zs(literals)

        assert self.x.mm in ["*"] + NUMBERED_MM
        specific_score = 0 if self.x.mm == "*" else len(self.x.mm)

        vowel_score = vs(self.x.ll + self.x.rr) - vs(self.y.ll + self.y.rr)

        return count_score, specific_score, len_score, zero_score, vowel_score

    def is_string_safe(self):
        """Return whether the pair survives a to_string()/from_string() round trip.

        Any parsing or construction error counts as "not safe".
        """
        try:
            x = Pattern.from_string(self.x.to_string())
            y = Pattern.from_string(self.y.to_string())
            return self == PatternPair(x, y)
        except Exception:
            # Deliberately broad, but no longer swallows KeyboardInterrupt or
            # SystemExit the way a bare ``except:`` did.
            return False
|
||
|
|
||
|
|
||
|
def maximal(a, key):
    """Return the unique element of ``a`` with the largest ``key`` value.

    ``key`` is evaluated exactly once per element, so ``a`` may be any
    iterable (including a one-shot iterator) and ``key`` may be expensive.

    Raises:
        ValueError: if ``a`` is empty (raised by ``max``).
        Exception: if more than one element attains the maximum; the tied
            elements are printed before raising.
    """
    scored = [(key(item), item) for item in a]
    max_score = max(score for score, _ in scored)
    result = [item for score, item in scored if score == max_score]
    if len(result) == 1:
        return result[0]
    print(result)
    raise Exception("More than one maximum values")
|
||
|
|
||
|
|
||
|
def get_all_star_pattern_pairs(names):
    """Enumerate candidate ``"*"`` pattern pairs for ``names``.

    At most ``SAMPLE_SIZE`` names are sampled at random and every prefix and
    suffix of the sampled names is tried. Whenever exactly two names share a
    prefix (or suffix), one PatternPair is produced for every way of
    shortening that shared part, with the two names' differing remainders
    kept as literals. Duplicate pairs are removed before returning.
    """
    sample = random.sample(names, min(len(names), SAMPLE_SIZE))

    prefixes = sorted({name[:cut] for name in sample for cut in range(len(name) + 1)})
    suffixes = sorted({name[cut:] for name in sample for cut in range(len(name) + 1)})

    candidates = []

    for prefix in prefixes:
        hits = [name for name in names if name.startswith(prefix)]
        if len(hits) != 2:
            continue
        first, second = hits
        first_tail = first[len(prefix):]
        second_tail = second[len(prefix):]
        for cut in range(len(prefix) + 1):
            head = prefix[:cut]
            inp = Pattern(head, "*", first_tail)
            out = Pattern(head, "*", second_tail)
            candidates.append(PatternPair(inp, out))

    for suffix in suffixes:
        hits = [name for name in names if name.endswith(suffix)]
        if len(hits) != 2:
            continue
        first, second = hits
        first_head = first[:len(first) - len(suffix)]
        second_head = second[:len(second) - len(suffix)]
        for cut in range(len(suffix) + 1):
            tail = suffix[cut:]
            inp = Pattern(first_head, "*", tail)
            out = Pattern(second_head, "*", tail)
            candidates.append(PatternPair(inp, out))

    # De-duplicate; the resulting order is whatever the set yields.
    return list(set(candidates))
|
||
|
|
||
|
|
||
|
def get_variant_pattern_pairs(pp):
    """Return every wildcard variant of ``pp``, in both orientations.

    The literal parts of ``pp.x`` and ``pp.y`` are kept and combined with each
    wildcard in ``VALID_MM``: first with ``x`` as input and ``y`` as output,
    then with the two roles swapped.
    """
    variants = []
    for first, second in ((pp.x, pp.y), (pp.y, pp.x)):
        for mm in VALID_MM:
            inp = Pattern(first.ll, mm, first.rr)
            out = Pattern(second.ll, mm, second.rr)
            variants.append(PatternPair(inp, out))
    return variants
|
||
|
|
||
|
|
||
|
def find_best_pattern_pair(names):
    """Pick the highest-scoring input/output pattern pair for ``names``.

    Star-wildcard candidates that match fewer than two tests are discarded.
    If none remain, the catch-all pair ``("*", "*")`` is returned. Otherwise
    the best star pair is expanded into all wildcard variants, variants that
    do not survive a string round trip are dropped, and the best remaining
    variant is returned. Ties are reported by ``maximal`` as an exception.
    """

    def score_of(pair):
        return pair.score(names)

    candidates = [
        pair
        for pair in get_all_star_pattern_pairs(names)
        if pair.matches(names, returns="fast_count") >= 2
    ]
    if not candidates:
        return PatternPair(Pattern("", "*", ""), Pattern("", "*", ""))

    best_star = maximal(candidates, key=score_of)

    safe_variants = [
        variant
        for variant in get_variant_pattern_pairs(best_star)
        if variant.is_string_safe()
    ]
    return maximal(safe_variants, key=score_of)
|
||
|
|
||
|
|
||
|
def list_dir_recursively(folder):
    """Return all files under ``folder`` as paths relative to it.

    Paths are prefixed with the current-directory marker, e.g. ``./a.txt``
    and ``./sub/b.txt``. The process working directory is never changed, so
    the function is safe to call while other code relies on the cwd.

    Raises:
        OSError: if ``folder`` itself cannot be listed (missing, not a
            directory, ...). Errors in sub-directories are skipped.
    """
    top = os.fspath(folder)

    def _fail_on_top(error):
        # os.walk swallows listing errors by default; surface the one for
        # the top folder so a bad ``folder`` still fails loudly.
        if error.filename == top:
            raise error

    result = []
    for root, _, filenames in os.walk(top, onerror=_fail_on_top):
        rel_root = os.path.relpath(root, top)
        if rel_root != os.curdir:
            rel_root = os.path.join(os.curdir, rel_root)
        result.extend(os.path.join(rel_root, filename) for filename in filenames)
    return result
|
||
|
|
||
|
|
||
|
def test_with_dir(folder):
    """Print ``folder`` together with the best pattern pair found for its files."""
    best_pair = find_best_pattern_pair(list_dir_recursively(folder))
    print(folder, best_pair)
|