Update collab filter to use dict result

This commit is contained in:
cuom1999 2024-03-26 12:23:13 -05:00
parent a4c2fad04f
commit 45469ff103

View file

@ -20,29 +20,21 @@ class CollabFilter:
)
_, problem_arr = self.embeddings.files
self.name = name
self.problem_embeddings = self.embeddings[problem_arr]
self.problem_embeddings = self.embeddings[problem_arr].item()
def __str__(self):
return self.name
def compute_scores(self, query_embedding, item_embeddings, measure=DOT):
"""Computes the scores of the candidates given a query.
Args:
query_embedding: a vector of shape [k], representing the query embedding.
item_embeddings: a matrix of shape [N, k], such that row i is the embedding
of item i.
measure: a string specifying the similarity measure to be used. Can be
either DOT or COSINE.
Returns:
scores: a vector of shape [N], such that scores[i] is the score of item i.
"""
"""Return {id: score}"""
u = query_embedding
V = item_embeddings
V = np.stack(list(item_embeddings.values()))
if measure == self.COSINE:
V = V / np.linalg.norm(V, axis=1, keepdims=True)
u = u / np.linalg.norm(u)
scores = u.dot(V.T)
return scores
scores_by_id = {id_: s for id_, s in zip(item_embeddings.keys(), scores)}
return scores_by_id
def _get_embedding_version(self):
first_problem = self.problem_embeddings[0]
@ -54,8 +46,8 @@ class CollabFilter:
@cache_wrapper(prefix="CFgue", timeout=86400)
def _get_user_embedding(self, user_id, embedding_version):
user_arr, _ = self.embeddings.files
user_embeddings = self.embeddings[user_arr]
if user_id >= len(user_embeddings):
user_embeddings = self.embeddings[user_arr].item()
if user_id not in user_embeddings:
return user_embeddings[0]
return user_embeddings[user_id]
@ -67,27 +59,24 @@ class CollabFilter:
def user_recommendations(self, user_id, problems, measure=DOT, limit=None):
user_embedding = self.get_user_embedding(user_id)
scores = self.compute_scores(user_embedding, self.problem_embeddings, measure)
res = [] # [(score, problem)]
for pid in problems:
if pid < len(scores):
if pid in scores:
res.append((scores[pid], pid))
res.sort(reverse=True, key=lambda x: x[0])
res = res[:limit]
return res
return res[:limit]
# return a list of pid
def problem_neighbors(self, problem, problemset, measure=DOT, limit=None):
pid = problem.id
if pid >= len(self.problem_embeddings):
return []
scores = self.compute_scores(
self.problem_embeddings[pid], self.problem_embeddings, measure
)
if pid not in self.problem_embeddings:
return None
embedding = self.problem_embeddings[pid]
scores = self.compute_scores(embedding, self.problem_embeddings, measure)
res = []
for p in problemset:
if p < len(scores):
if p in scores:
res.append((scores[p], p))
res.sort(reverse=True, key=lambda x: x[0])
return res[:limit]