Refactor problem feed code
This commit is contained in:
parent
b6c9ce4763
commit
0b4eeb8751
4 changed files with 134 additions and 67 deletions
|
@ -1,7 +1,9 @@
|
|||
import numpy as np
|
||||
from django.conf import settings
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.conf import settings
|
||||
|
||||
from judge.caching import cache_wrapper
|
||||
|
||||
|
@ -12,14 +14,13 @@ class CollabFilter:
|
|||
|
||||
# name = 'collab_filter' or 'collab_filter_time'
|
||||
def __init__(self, name):
|
||||
embeddings = np.load(
|
||||
self.embeddings = np.load(
|
||||
os.path.join(settings.ML_OUTPUT_PATH, name + "/embeddings.npz"),
|
||||
allow_pickle=True,
|
||||
)
|
||||
arr0, arr1 = embeddings.files
|
||||
_, problem_arr = self.embeddings.files
|
||||
self.name = name
|
||||
self.user_embeddings = embeddings[arr0]
|
||||
self.problem_embeddings = embeddings[arr1]
|
||||
self.problem_embeddings = self.embeddings[problem_arr]
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
@ -43,18 +44,32 @@ class CollabFilter:
|
|||
scores = u.dot(V.T)
|
||||
return scores
|
||||
|
||||
def _get_embedding_version(self):
|
||||
first_problem = self.problem_embeddings[0]
|
||||
array_bytes = first_problem.tobytes()
|
||||
hash_object = hashlib.sha256(array_bytes)
|
||||
hash_bytes = hash_object.digest()
|
||||
return hash_bytes.hex()[:5]
|
||||
|
||||
@cache_wrapper(prefix="CFgue", timeout=86400)
|
||||
def _get_user_embedding(self, user_id, embedding_version):
|
||||
user_arr, _ = self.embeddings.files
|
||||
user_embeddings = self.embeddings[user_arr]
|
||||
if user_id >= len(user_embeddings):
|
||||
return user_embeddings[0]
|
||||
return user_embeddings[user_id]
|
||||
|
||||
def get_user_embedding(self, user_id):
|
||||
version = self._get_embedding_version()
|
||||
return self._get_user_embedding(user_id, version)
|
||||
|
||||
@cache_wrapper(prefix="user_recommendations", timeout=3600)
|
||||
def user_recommendations(self, user, problems, measure=DOT, limit=None):
|
||||
uid = user.id
|
||||
if uid >= len(self.user_embeddings):
|
||||
uid = 0
|
||||
scores = self.compute_scores(
|
||||
self.user_embeddings[uid], self.problem_embeddings, measure
|
||||
)
|
||||
def user_recommendations(self, user_id, problems, measure=DOT, limit=None):
|
||||
user_embedding = self.get_user_embedding(user_id)
|
||||
scores = self.compute_scores(user_embedding, self.problem_embeddings, measure)
|
||||
|
||||
res = [] # [(score, problem)]
|
||||
for pid in problems:
|
||||
# pid = problem.id
|
||||
if pid < len(scores):
|
||||
res.append((scores[pid], pid))
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue