diff --git a/judge/ml/collab_filter.py b/judge/ml/collab_filter.py index d19c5e5..0bdafa8 100644 --- a/judge/ml/collab_filter.py +++ b/judge/ml/collab_filter.py @@ -20,29 +20,21 @@ class CollabFilter: ) _, problem_arr = self.embeddings.files self.name = name - self.problem_embeddings = self.embeddings[problem_arr] + self.problem_embeddings = self.embeddings[problem_arr].item() def __str__(self): return self.name def compute_scores(self, query_embedding, item_embeddings, measure=DOT): - """Computes the scores of the candidates given a query. - Args: - query_embedding: a vector of shape [k], representing the query embedding. - item_embeddings: a matrix of shape [N, k], such that row i is the embedding - of item i. - measure: a string specifying the similarity measure to be used. Can be - either DOT or COSINE. - Returns: - scores: a vector of shape [N], such that scores[i] is the score of item i. - """ + """Return {id: score}""" u = query_embedding - V = item_embeddings + V = np.stack(list(item_embeddings.values())) if measure == self.COSINE: V = V / np.linalg.norm(V, axis=1, keepdims=True) u = u / np.linalg.norm(u) scores = u.dot(V.T) - return scores + scores_by_id = {id_: s for id_, s in zip(item_embeddings.keys(), scores)} + return scores_by_id def _get_embedding_version(self): first_problem = self.problem_embeddings[0] @@ -54,8 +46,8 @@ class CollabFilter: @cache_wrapper(prefix="CFgue", timeout=86400) def _get_user_embedding(self, user_id, embedding_version): user_arr, _ = self.embeddings.files - user_embeddings = self.embeddings[user_arr] - if user_id >= len(user_embeddings): + user_embeddings = self.embeddings[user_arr].item() + if user_id not in user_embeddings: return user_embeddings[0] return user_embeddings[user_id] @@ -67,27 +59,24 @@ class CollabFilter: def user_recommendations(self, user_id, problems, measure=DOT, limit=None): user_embedding = self.get_user_embedding(user_id) scores = self.compute_scores(user_embedding, self.problem_embeddings, measure) - res = [] # [(score, problem)] for pid in problems: - if pid < len(scores): + if pid in scores: res.append((scores[pid], pid)) res.sort(reverse=True, key=lambda x: x[0]) - res = res[:limit] - return res + return res[:limit] # return a list of pid def problem_neighbors(self, problem, problemset, measure=DOT, limit=None): pid = problem.id - if pid >= len(self.problem_embeddings): - return [] - scores = self.compute_scores( - self.problem_embeddings[pid], self.problem_embeddings, measure - ) + if pid not in self.problem_embeddings: + return None + embedding = self.problem_embeddings[pid] + scores = self.compute_scores(embedding, self.problem_embeddings, measure) res = [] for p in problemset: - if p < len(scores): + if p in scores: res.append((scores[p], p)) res.sort(reverse=True, key=lambda x: x[0]) return res[:limit]