Add cache wrapper

This commit is contained in:
cuom1999 2023-04-05 12:49:23 -05:00
parent 9d645841ae
commit 57a6233779
5 changed files with 145 additions and 119 deletions

View file

@ -4,13 +4,15 @@ import os
from django.core.cache import cache
import hashlib
from judge.caching import cache_wrapper
class CollabFilter:
DOT = "dot"
COSINE = "cosine"
# name = 'collab_filter' or 'collab_filter_time'
def __init__(self, name, **kwargs):
def __init__(self, name):
embeddings = np.load(
os.path.join(settings.ML_OUTPUT_PATH, name + "/embeddings.npz"),
allow_pickle=True,
@ -20,6 +22,9 @@ class CollabFilter:
self.user_embeddings = embeddings[arr0]
self.problem_embeddings = embeddings[arr1]
def __str__(self):
return self.name
def compute_scores(self, query_embedding, item_embeddings, measure=DOT):
"""Computes the scores of the candidates given a query.
Args:
@ -39,14 +44,9 @@ class CollabFilter:
scores = u.dot(V.T)
return scores
def user_recommendations(self, user, problems, measure=DOT, limit=None, **kwargs):
@cache_wrapper(prefix="user_recommendations", timeout=3600)
def user_recommendations(self, user, problems, measure=DOT, limit=None):
uid = user.id
problems_hash = hashlib.sha1(str(list(problems)).encode()).hexdigest()
cache_key = ":".join(map(str, [self.name, uid, measure, limit, problems_hash]))
value = cache.get(cache_key)
if value:
return value
if uid >= len(self.user_embeddings):
uid = 0
scores = self.compute_scores(
@ -61,7 +61,6 @@ class CollabFilter:
res.sort(reverse=True, key=lambda x: x[0])
res = res[:limit]
cache.set(cache_key, res, 3600)
return res
# return a list of pid