Update collab filter to use dict result
This commit is contained in:
parent
a4c2fad04f
commit
45469ff103
1 changed files with 14 additions and 25 deletions
|
@ -20,29 +20,21 @@ class CollabFilter:
|
||||||
)
|
)
|
||||||
_, problem_arr = self.embeddings.files
|
_, problem_arr = self.embeddings.files
|
||||||
self.name = name
|
self.name = name
|
||||||
self.problem_embeddings = self.embeddings[problem_arr]
|
self.problem_embeddings = self.embeddings[problem_arr].item()
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
def compute_scores(self, query_embedding, item_embeddings, measure=DOT):
|
def compute_scores(self, query_embedding, item_embeddings, measure=DOT):
|
||||||
"""Computes the scores of the candidates given a query.
|
"""Return {id: score}"""
|
||||||
Args:
|
|
||||||
query_embedding: a vector of shape [k], representing the query embedding.
|
|
||||||
item_embeddings: a matrix of shape [N, k], such that row i is the embedding
|
|
||||||
of item i.
|
|
||||||
measure: a string specifying the similarity measure to be used. Can be
|
|
||||||
either DOT or COSINE.
|
|
||||||
Returns:
|
|
||||||
scores: a vector of shape [N], such that scores[i] is the score of item i.
|
|
||||||
"""
|
|
||||||
u = query_embedding
|
u = query_embedding
|
||||||
V = item_embeddings
|
V = np.stack(list(item_embeddings.values()))
|
||||||
if measure == self.COSINE:
|
if measure == self.COSINE:
|
||||||
V = V / np.linalg.norm(V, axis=1, keepdims=True)
|
V = V / np.linalg.norm(V, axis=1, keepdims=True)
|
||||||
u = u / np.linalg.norm(u)
|
u = u / np.linalg.norm(u)
|
||||||
scores = u.dot(V.T)
|
scores = u.dot(V.T)
|
||||||
return scores
|
scores_by_id = {id_: s for id_, s in zip(item_embeddings.keys(), scores)}
|
||||||
|
return scores_by_id
|
||||||
|
|
||||||
def _get_embedding_version(self):
|
def _get_embedding_version(self):
|
||||||
first_problem = self.problem_embeddings[0]
|
first_problem = self.problem_embeddings[0]
|
||||||
|
@ -54,8 +46,8 @@ class CollabFilter:
|
||||||
@cache_wrapper(prefix="CFgue", timeout=86400)
|
@cache_wrapper(prefix="CFgue", timeout=86400)
|
||||||
def _get_user_embedding(self, user_id, embedding_version):
|
def _get_user_embedding(self, user_id, embedding_version):
|
||||||
user_arr, _ = self.embeddings.files
|
user_arr, _ = self.embeddings.files
|
||||||
user_embeddings = self.embeddings[user_arr]
|
user_embeddings = self.embeddings[user_arr].item()
|
||||||
if user_id >= len(user_embeddings):
|
if user_id not in user_embeddings:
|
||||||
return user_embeddings[0]
|
return user_embeddings[0]
|
||||||
return user_embeddings[user_id]
|
return user_embeddings[user_id]
|
||||||
|
|
||||||
|
@ -67,27 +59,24 @@ class CollabFilter:
|
||||||
def user_recommendations(self, user_id, problems, measure=DOT, limit=None):
|
def user_recommendations(self, user_id, problems, measure=DOT, limit=None):
|
||||||
user_embedding = self.get_user_embedding(user_id)
|
user_embedding = self.get_user_embedding(user_id)
|
||||||
scores = self.compute_scores(user_embedding, self.problem_embeddings, measure)
|
scores = self.compute_scores(user_embedding, self.problem_embeddings, measure)
|
||||||
|
|
||||||
res = [] # [(score, problem)]
|
res = [] # [(score, problem)]
|
||||||
for pid in problems:
|
for pid in problems:
|
||||||
if pid < len(scores):
|
if pid in scores:
|
||||||
res.append((scores[pid], pid))
|
res.append((scores[pid], pid))
|
||||||
|
|
||||||
res.sort(reverse=True, key=lambda x: x[0])
|
res.sort(reverse=True, key=lambda x: x[0])
|
||||||
res = res[:limit]
|
return res[:limit]
|
||||||
return res
|
|
||||||
|
|
||||||
# return a list of pid
|
# return a list of pid
|
||||||
def problem_neighbors(self, problem, problemset, measure=DOT, limit=None):
|
def problem_neighbors(self, problem, problemset, measure=DOT, limit=None):
|
||||||
pid = problem.id
|
pid = problem.id
|
||||||
if pid >= len(self.problem_embeddings):
|
if pid not in self.problem_embeddings:
|
||||||
return []
|
return None
|
||||||
scores = self.compute_scores(
|
embedding = self.problem_embeddings[pid]
|
||||||
self.problem_embeddings[pid], self.problem_embeddings, measure
|
scores = self.compute_scores(embedding, self.problem_embeddings, measure)
|
||||||
)
|
|
||||||
res = []
|
res = []
|
||||||
for p in problemset:
|
for p in problemset:
|
||||||
if p < len(scores):
|
if p in scores:
|
||||||
res.append((scores[p], p))
|
res.append((scores[p], p))
|
||||||
res.sort(reverse=True, key=lambda x: x[0])
|
res.sort(reverse=True, key=lambda x: x[0])
|
||||||
return res[:limit]
|
return res[:limit]
|
||||||
|
|
Loading…
Reference in a new issue