Update collab filter to use dict result
This commit is contained in:
parent
a4c2fad04f
commit
45469ff103
1 changed file with 14 additions and 25 deletions
|
@ -20,29 +20,21 @@ class CollabFilter:
|
|||
)
|
||||
_, problem_arr = self.embeddings.files
|
||||
self.name = name
|
||||
self.problem_embeddings = self.embeddings[problem_arr]
|
||||
self.problem_embeddings = self.embeddings[problem_arr].item()
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
def compute_scores(self, query_embedding, item_embeddings, measure=DOT):
|
||||
"""Computes the scores of the candidates given a query.
|
||||
Args:
|
||||
query_embedding: a vector of shape [k], representing the query embedding.
|
||||
item_embeddings: a matrix of shape [N, k], such that row i is the embedding
|
||||
of item i.
|
||||
measure: a string specifying the similarity measure to be used. Can be
|
||||
either DOT or COSINE.
|
||||
Returns:
|
||||
scores: a vector of shape [N], such that scores[i] is the score of item i.
|
||||
"""
|
||||
"""Return {id: score}"""
|
||||
u = query_embedding
|
||||
V = item_embeddings
|
||||
V = np.stack(list(item_embeddings.values()))
|
||||
if measure == self.COSINE:
|
||||
V = V / np.linalg.norm(V, axis=1, keepdims=True)
|
||||
u = u / np.linalg.norm(u)
|
||||
scores = u.dot(V.T)
|
||||
return scores
|
||||
scores_by_id = {id_: s for id_, s in zip(item_embeddings.keys(), scores)}
|
||||
return scores_by_id
|
||||
|
||||
def _get_embedding_version(self):
|
||||
first_problem = self.problem_embeddings[0]
|
||||
|
@ -54,8 +46,8 @@ class CollabFilter:
|
|||
@cache_wrapper(prefix="CFgue", timeout=86400)
|
||||
def _get_user_embedding(self, user_id, embedding_version):
|
||||
user_arr, _ = self.embeddings.files
|
||||
user_embeddings = self.embeddings[user_arr]
|
||||
if user_id >= len(user_embeddings):
|
||||
user_embeddings = self.embeddings[user_arr].item()
|
||||
if user_id not in user_embeddings:
|
||||
return user_embeddings[0]
|
||||
return user_embeddings[user_id]
|
||||
|
||||
|
@ -67,27 +59,24 @@ class CollabFilter:
|
|||
def user_recommendations(self, user_id, problems, measure=DOT, limit=None):
|
||||
user_embedding = self.get_user_embedding(user_id)
|
||||
scores = self.compute_scores(user_embedding, self.problem_embeddings, measure)
|
||||
|
||||
res = [] # [(score, problem)]
|
||||
for pid in problems:
|
||||
if pid < len(scores):
|
||||
if pid in scores:
|
||||
res.append((scores[pid], pid))
|
||||
|
||||
res.sort(reverse=True, key=lambda x: x[0])
|
||||
res = res[:limit]
|
||||
return res
|
||||
return res[:limit]
|
||||
|
||||
# return a list of pid
|
||||
def problem_neighbors(self, problem, problemset, measure=DOT, limit=None):
|
||||
pid = problem.id
|
||||
if pid >= len(self.problem_embeddings):
|
||||
return []
|
||||
scores = self.compute_scores(
|
||||
self.problem_embeddings[pid], self.problem_embeddings, measure
|
||||
)
|
||||
if pid not in self.problem_embeddings:
|
||||
return None
|
||||
embedding = self.problem_embeddings[pid]
|
||||
scores = self.compute_scores(embedding, self.problem_embeddings, measure)
|
||||
res = []
|
||||
for p in problemset:
|
||||
if p < len(scores):
|
||||
if p in scores:
|
||||
res.append((scores[p], p))
|
||||
res.sort(reverse=True, key=lambda x: x[0])
|
||||
return res[:limit]
|
||||
|
|
Loading…
Reference in a new issue