NDOJ/judge/utils/raw_sql.py
2022-05-14 12:57:27 -05:00

130 lines
3.6 KiB
Python

from copy import copy
from django.db import connections
from django.db.models import Field
from django.db.models.expressions import RawSQL
from django.db.models.sql.constants import INNER, LOUTER
from django.db.models.sql.datastructures import Join
from django.utils import six
from judge.utils.cachedict import CacheDict
def unique_together_left_join(
    queryset, model, link_field_name, filter_field_name, filter_value, parent_model=None
):
    """Add a LEFT OUTER JOIN to ``model``'s table onto ``queryset`` (in place).

    The join follows the relation of ``model``'s ``link_field_name`` field and
    carries an extra restriction ``<filter_field_name> = filter_value`` on the
    joined table. The function name suggests the (link, filter) pair is
    unique together on ``model``, so at most one row would join per parent
    row; that is NOT checked here.

    Args:
        queryset: queryset whose ``query`` receives the join (mutated).
        model: model whose ``db_table`` is joined.
        link_field_name: name of the field on ``model`` whose ``remote_field``
            is used as the join field.
        filter_field_name: name of the field on ``model`` compared with
            ``filter_value`` (exact lookup).
        filter_value: value the joined rows must match.
        parent_model: when given, join from this model's ``db_table`` instead
            of the query's initial alias.

    Returns:
        Whatever ``queryset.query.join`` returns (the join's alias).
    """
    meta = model._meta
    # Patch a copy rather than the field's own remote_field object, so the
    # get_extra_restriction override below applies only to this join.
    relation = copy(meta.get_field(link_field_name).remote_field)
    restricted_field = meta.get_field(filter_field_name)

    def restrictions(where_class, alias, related_alias):
        # Signature of the pre-Django-4.0 get_extra_restriction hook, which
        # receives the where class; confirm if Django is upgraded.
        condition = where_class()
        exact = restricted_field.get_lookup("exact")
        condition.add(exact(restricted_field.get_col(alias), filter_value), "AND")
        return condition

    relation.get_extra_restriction = restrictions

    parent_alias = (
        queryset.query.get_initial_alias()
        if parent_model is None
        else parent_model._meta.db_table
    )
    # Join(table_name, parent_alias, table_alias=None, LOUTER, join_field, nullable=True)
    outer_join = Join(meta.db_table, parent_alias, None, LOUTER, relation, True)
    return queryset.query.join(outer_join)
class RawSQLJoin(Join):
    """A ``Join`` whose right-hand side is raw SQL text instead of a table.

    The subquery text is stored as the join's ``table_name`` and rendered
    wrapped in parentheses. Its bind parameters are emitted ahead of the
    parameters produced by the base ``Join`` rendering.

    NOTE(review): the constructor takes an extra ``subquery_params`` argument
    in second position compared to ``Join``. Any inherited Django code that
    rebuilds a join via ``self.__class__(...)`` with ``Join``'s argument order
    (e.g. alias relabeling) would not work with this class. Confirm before
    using these joins in queries Django relabels or combines.
    """

    def __init__(
        self,
        subquery,
        subquery_params,
        parent_alias,
        table_alias,
        join_type,
        join_field,
        nullable,
        filtered_relation=None,
    ):
        # Any sequence of bind parameters for the subquery text; normalized
        # to a list when rendering.
        self.subquery_params = subquery_params
        super().__init__(
            subquery,
            parent_alias,
            table_alias,
            join_type,
            join_field,
            nullable,
            filtered_relation,
        )

    def as_sql(self, compiler, connection):
        """Render the join with the subquery inlined as ``(subquery)``.

        Returns:
            ``(sql, params)`` where ``params`` is a list: the subquery's
            params followed by the base join's params.
        """
        # Pre-seed the compiler's quote cache so that, when the base Join
        # renders table_name, the raw subquery text is emitted parenthesized
        # instead of being quoted as an identifier (relies on Django compiler
        # internals; confirm per Django version).
        compiler.quote_cache[self.table_name] = "(%s)" % self.table_name
        sql, params = super().as_sql(compiler, connection)
        # The subquery text precedes the ON clause in the rendered join, so its
        # params go first. Both sides are normalized to lists so callers may
        # pass subquery params as a tuple (tuple + list would raise TypeError).
        return sql, list(self.subquery_params) + list(params)
class FakeJoinField:
    """Minimal stand-in for a relation field, used as a join's ``join_field``.

    It exposes only the two hooks this module relies on: the joining columns
    (returned exactly as given) and an extra restriction that is always empty.
    """

    def __init__(self, joining_columns):
        # Kept verbatim; presumably a sequence of (parent_column, joined_column)
        # pairs. Confirm against callers of join_sql_subquery.
        self.joining_columns = joining_columns

    def get_joining_columns(self):
        """Return the joining columns supplied at construction, unchanged."""
        return self.joining_columns

    def get_extra_restriction(self, where_class, alias, remote_alias):
        """Return ``None``: this fake field adds no extra join condition."""
        return None
def join_sql_subquery(
    queryset, subquery, params, join_fields, alias, join_type=INNER, parent_model=None
):
    """Join the raw SQL ``subquery`` into ``queryset`` under ``alias`` (in place).

    Args:
        queryset: queryset whose ``query`` is mutated; nothing is returned.
        subquery: raw SQL text of the subquery, rendered as ``(subquery)``.
        params: bind parameters for ``subquery``.
        join_fields: joining columns handed to a ``FakeJoinField``.
        alias: SQL alias the subquery is referred to by.
        join_type: ``INNER`` (default) or ``LOUTER``. The join is marked
            nullable exactly when this is ``LOUTER``.
        parent_model: when given, join from this model's ``db_table`` instead
            of the query's initial alias.
    """
    query = queryset.query
    if parent_model is None:
        parent_alias = query.get_initial_alias()
    else:
        parent_alias = parent_model._meta.db_table

    # Register the alias as external to the query; presumably keeps Django
    # from treating it as one of its own generated aliases. Confirm per
    # Django version.
    query.external_aliases.add(alias)

    subquery_join = RawSQLJoin(
        subquery,
        params,
        parent_alias,
        alias,
        join_type,
        FakeJoinField(join_fields),
        join_type == LOUTER,  # nullable only for outer joins
    )
    query.join(subquery_join)
    # Query.join() presumably assigns its own alias to the join object;
    # restore the caller-chosen alias so the SQL refers to the subquery by it.
    subquery_join.table_alias = alias
def RawSQLColumn(model, field=None):
    """Return a ``RawSQL`` expression for ``<db_table>.<column>``.

    Call either with a model field instance alone (its ``model`` attribute
    supplies the table), or with a model plus a field instance or field name.

    The table and column names are interpolated as-is (not quoted), and the
    expression carries no parameters.

    Args:
        model: a model class, or a ``Field`` instance.
        field: a ``Field`` instance or field name. Ignored when ``model`` is
            itself a ``Field``.

    Returns:
        ``RawSQL("<db_table>.<column>", ())``.
    """
    if isinstance(model, Field):
        field = model
        model = field.model
    # Plain ``str`` replaces ``six.string_types``: this module is Python 3 only
    # (zero-argument super()), where six.string_types == (str,), and
    # django.utils.six no longer exists in Django 3.0+.
    if isinstance(field, str):
        field = model._meta.get_field(field)
    _, column = field.get_attname_column()
    return RawSQL("%s.%s" % (model._meta.db_table, column), ())
def make_straight_join_query(QueryType):
    """Create a subclass of ``QueryType`` that emits ``STRAIGHT_JOIN``.

    The subclass overrides ``join()``. After the base implementation registers
    the join, the join stored under the returned alias has its ``join_type``
    switched from ``INNER`` to ``"STRAIGHT_JOIN"``. Other join types are left
    as they are. A new class is created on every call.

    NOTE(review): the registered join object is mutated in place. If that
    object is shared with another query (e.g. a reused join from a clone),
    the change is visible there too; confirm this is acceptable.
    """

    class Query(QueryType):
        def join(self, join, *args, **kwargs):
            new_alias = super().join(join, *args, **kwargs)
            # Inspect the join actually registered under the alias, which
            # may be an already-present one rather than the one passed in.
            registered = self.alias_map[new_alias]
            if registered.join_type == INNER:
                registered.join_type = "STRAIGHT_JOIN"
            return new_alias

    return Query
# STRAIGHT_JOIN query subclasses, looked up by the original query class
# (use_straight_join indexes it with type(queryset.query)).
# Presumably CacheDict (judge.utils.cachedict, not shown here) calls
# make_straight_join_query(key) for a missing key and memoizes the result,
# so each Query type gets a single subclass. Confirm in cachedict.
straight_join_cache = CacheDict(make_straight_join_query)
def use_straight_join(queryset):
    """Make ``queryset`` emit ``STRAIGHT_JOIN`` for its INNER joins (MySQL only).

    On a MySQL connection, ``queryset.query`` is replaced in place by a copy
    whose class is the cached STRAIGHT_JOIN subclass of its current query
    class. Presumably only joins added afterwards through ``join()`` are
    converted. On any other database vendor this does nothing. Returns
    ``None`` either way.
    """
    if connections[queryset.db].vendor == "mysql":
        query = queryset.query
        # Query.chain() exists on newer Django; older versions only have
        # Query.clone(). Either is called with the target query class.
        clone_with = getattr(query, "chain", None) or query.clone
        queryset.query = clone_with(straight_join_cache[type(query)])