Reformat using black

This commit is contained in:
cuom1999 2022-05-14 12:57:27 -05:00
parent efee4ad081
commit a87fb49918
221 changed files with 19127 additions and 7310 deletions

View file

@ -5,14 +5,16 @@ from dmoj.decorators import timeit
class CollabFilter:
DOT = 'dot'
COSINE = 'cosine'
DOT = "dot"
COSINE = "cosine"
# name = 'collab_filter' or 'collab_filter_time'
@timeit
def __init__(self, name, **kwargs):
embeddings = np.load(os.path.join(settings.ML_OUTPUT_PATH, name + '/embeddings.npz'),
allow_pickle=True)
embeddings = np.load(
os.path.join(settings.ML_OUTPUT_PATH, name + "/embeddings.npz"),
allow_pickle=True,
)
arr0, arr1 = embeddings.files
self.user_embeddings = embeddings[arr0]
self.problem_embeddings = embeddings[arr1]
@ -42,9 +44,10 @@ class CollabFilter:
if uid >= len(self.user_embeddings):
uid = 0
scores = self.compute_scores(
self.user_embeddings[uid], self.problem_embeddings, measure)
res = [] # [(score, problem)]
self.user_embeddings[uid], self.problem_embeddings, measure
)
res = [] # [(score, problem)]
for problem in problems:
pid = problem.id
if pid < len(scores):
@ -53,17 +56,17 @@ class CollabFilter:
res.sort(reverse=True, key=lambda x: x[0])
return res[:limit]
# return a list of pid
def problems_neighbors(self, problem, problemset, measure=DOT, limit=None):
pid = problem.id
if pid >= len(self.problem_embeddings):
return None
scores = self.compute_scores(
self.problem_embeddings[pid], self.problem_embeddings, measure)
self.problem_embeddings[pid], self.problem_embeddings, measure
)
res = []
for p in problemset:
if p.id < len(scores):
res.append((scores[p.id], p))
res.sort(reverse=True, key=lambda x: x[0])
return res[:limit]
return res[:limit]