Refactor problem feed code

This commit is contained in:
cuom1999 2023-11-09 02:43:11 -06:00
parent b6c9ce4763
commit 0b4eeb8751
4 changed files with 134 additions and 67 deletions

View file

@ -1,7 +1,9 @@
import numpy as np
from django.conf import settings
import os
import hashlib
from django.core.cache import cache
from django.conf import settings
from judge.caching import cache_wrapper
@ -12,14 +14,13 @@ class CollabFilter:
# name = 'collab_filter' or 'collab_filter_time'
def __init__(self, name):
embeddings = np.load(
self.embeddings = np.load(
os.path.join(settings.ML_OUTPUT_PATH, name + "/embeddings.npz"),
allow_pickle=True,
)
arr0, arr1 = embeddings.files
_, problem_arr = self.embeddings.files
self.name = name
self.user_embeddings = embeddings[arr0]
self.problem_embeddings = embeddings[arr1]
self.problem_embeddings = self.embeddings[problem_arr]
def __str__(self):
return self.name
@ -43,18 +44,32 @@ class CollabFilter:
scores = u.dot(V.T)
return scores
def _get_embedding_version(self):
first_problem = self.problem_embeddings[0]
array_bytes = first_problem.tobytes()
hash_object = hashlib.sha256(array_bytes)
hash_bytes = hash_object.digest()
return hash_bytes.hex()[:5]
@cache_wrapper(prefix="CFgue", timeout=86400)
def _get_user_embedding(self, user_id, embedding_version):
user_arr, _ = self.embeddings.files
user_embeddings = self.embeddings[user_arr]
if user_id >= len(user_embeddings):
return user_embeddings[0]
return user_embeddings[user_id]
def get_user_embedding(self, user_id):
version = self._get_embedding_version()
return self._get_user_embedding(user_id, version)
@cache_wrapper(prefix="user_recommendations", timeout=3600)
def user_recommendations(self, user, problems, measure=DOT, limit=None):
uid = user.id
if uid >= len(self.user_embeddings):
uid = 0
scores = self.compute_scores(
self.user_embeddings[uid], self.problem_embeddings, measure
)
def user_recommendations(self, user_id, problems, measure=DOT, limit=None):
user_embedding = self.get_user_embedding(user_id)
scores = self.compute_scores(user_embedding, self.problem_embeddings, measure)
res = [] # [(score, problem)]
for pid in problems:
# pid = problem.id
if pid < len(scores):
res.append((scores[pid], pid))