NDOJ/judge/utils/raw_sql.py

103 lines
2.6 KiB
Python
Raw Normal View History

2020-01-21 06:35:58 +00:00
from django.db import connections
from django.db.models.sql.constants import INNER, LOUTER
from django.db.models.sql.datastructures import Join
from judge.utils.cachedict import CacheDict
class RawSQLJoin(Join):
    """A ``Join`` whose "table" is a raw SQL subquery with bound parameters.

    The subquery text is passed to ``Join.__init__`` in the position of its
    first argument, while the subquery's parameters are kept on the instance
    and placed ahead of the parameters produced by the parent ``as_sql``.

    NOTE(review): Django's ``Join.relabeled_clone()`` appears to rebuild
    instances through ``self.__class__(...)`` using ``Join``'s own positional
    signature, which would not line up with this constructor -- confirm
    before relying on relabelling or promotion of this join.
    """

    def __init__(
        self,
        subquery,
        subquery_params,
        parent_alias,
        table_alias,
        join_type,
        join_field,
        nullable,
        filtered_relation=None,
    ):
        """Remember ``subquery_params`` and forward the rest to ``Join``."""
        self.subquery_params = subquery_params
        super().__init__(
            subquery,
            parent_alias,
            table_alias,
            join_type,
            join_field,
            nullable,
            filtered_relation,
        )

    def as_sql(self, compiler, connection):
        """Return ``(sql, params)`` with the subquery parameters first.

        The parameters are always returned as a new list.
        """
        # Seed the compiler's quote cache so that ``self.table_name`` maps to
        # itself wrapped in parentheses (presumably so the subquery text is
        # emitted as-is rather than quoted as an identifier).
        compiler.quote_cache[self.table_name] = "(%s)" % self.table_name
        sql, params = super().as_sql(compiler, connection)
        # Unpack instead of using ``+`` so that a tuple of subquery
        # parameters can be combined with the parent's parameter list.
        return sql, [*self.subquery_params, *params]
class FakeJoinField:
    """Minimal stand-in for a relation field, built for ``RawSQLJoin``.

    It only carries the joining columns and the related model it was
    constructed with, and never contributes an extra join restriction.
    """

    def __init__(self, joining_columns, related_model):
        """Store ``joining_columns`` and ``related_model`` unchanged."""
        self.joining_columns = joining_columns
        self.related_model = related_model

    def get_joining_columns(self):
        """Return the joining columns exactly as given to the constructor."""
        return self.joining_columns

    def get_extra_restriction(self, *args, **kwargs):
        """Return ``None``: this field adds no extra join condition.

        Any arguments are accepted and ignored, so the hook works both when
        called as ``(where_class, alias, remote_alias)`` and when called
        with only the two aliases, as newer Django releases do.
        """
        return None
2022-05-14 17:57:27 +00:00
def join_sql_subquery(
    queryset,
    subquery,
    params,
    join_fields,
    alias,
    related_model,
    join_type=INNER,
    parent_model=None,
):
    """Attach a raw SQL subquery to ``queryset`` as a join, in place.

    Args:
        queryset: queryset whose ``query`` receives the join.
        subquery: SQL text handed to ``RawSQLJoin`` as the subquery.
        params: parameters belonging to ``subquery``.
        join_fields: joining columns, wrapped in a ``FakeJoinField``.
        alias: alias registered as external and assigned to the join.
        related_model: model stored on the ``FakeJoinField``.
        join_type: join keyword; the join is nullable only for ``LOUTER``.
        parent_model: when given, its ``_meta.db_table`` is used as the
            parent alias; otherwise the query's initial alias is used.

    Returns:
        None. ``queryset.query`` is modified directly.
    """
    query = queryset.query
    parent_alias = (
        query.get_initial_alias()
        if parent_model is None
        else parent_model._meta.db_table
    )

    # ``external_aliases`` is either a dict or a set-like object.
    external = query.external_aliases
    if isinstance(external, dict):  # Django 3.x
        external[alias] = True
    else:
        external.add(alias)

    fake_field = FakeJoinField(join_fields, related_model)
    is_nullable = join_type == LOUTER
    join = RawSQLJoin(
        subquery,
        params,
        parent_alias,
        alias,
        join_type,
        fake_field,
        is_nullable,
    )
    query.join(join)
    # Re-assert the requested alias after ``query.join()`` (which presumably
    # may have assigned a different one -- confirm against Django's Query).
    join.table_alias = alias
def make_straight_join_query(QueryType):
    """Build a ``QueryType`` subclass that rewrites INNER joins.

    The returned class overrides ``join`` so that, after delegating to the
    parent, a join recorded with type ``INNER`` is switched to
    ``"STRAIGHT_JOIN"``. The alias from the parent call is returned as-is.
    """

    class Query(QueryType):
        def join(self, join, *args, **kwargs):
            result_alias = super().join(join, *args, **kwargs)
            # Look the join up by alias rather than using the argument:
            # the entry stored in ``alias_map`` is the one that gets patched.
            recorded = self.alias_map[result_alias]
            if recorded.join_type != INNER:
                return result_alias
            recorded.join_type = "STRAIGHT_JOIN"
            return result_alias

    return Query
# Lookup of STRAIGHT_JOIN query subclasses; use_straight_join() indexes it by
# the queryset's query class.
# NOTE(review): presumably CacheDict calls the factory on a missing key and
# stores the result -- confirm in judge.utils.cachedict.
straight_join_cache = CacheDict(make_straight_join_query)
def use_straight_join(queryset):
    """Swap ``queryset.query`` for its STRAIGHT_JOIN variant, in place.

    Nothing happens unless the vendor of the queryset's connection is
    ``"mysql"``. Returns ``None`` in every case.
    """
    if connections[queryset.db].vendor == "mysql":
        query = queryset.query
        # Use ``chain`` when the query object has it, otherwise ``clone``.
        make_copy = query.chain if hasattr(query, "chain") else query.clone
        queryset.query = make_copy(straight_join_cache[type(query)])