from django.db import connections
from django.db.models.sql.constants import INNER, LOUTER
from django.db.models.sql.datastructures import Join

from judge.utils.cachedict import CacheDict


class RawSQLJoin(Join):
    """A ``Join`` whose "table" is a raw SQL subquery with its own parameters.

    The subquery text is passed to ``Join`` in the table-name position, and
    its bound parameters are kept separately in ``subquery_params`` so they
    can be placed ahead of the parameters produced by the join itself.
    """

    def __init__(
        self,
        subquery,
        subquery_params,
        parent_alias,
        table_alias,
        join_type,
        join_field,
        nullable,
        filtered_relation=None,
    ):
        # Stored before delegating so the attribute exists as soon as the
        # base class finishes initialising.
        self.subquery_params = subquery_params
        super().__init__(
            subquery,
            parent_alias,
            table_alias,
            join_type,
            join_field,
            nullable,
            filtered_relation,
        )

    def as_sql(self, compiler, connection):
        """Return ``(sql, params)`` for the join, subquery params first.

        The returned params are always a new list, so ``subquery_params`` may
        be given as either a list or a tuple.
        """
        # Seed the compiler's quote cache so the subquery text is emitted
        # wrapped in parentheses (presumably instead of being quoted like a
        # plain table identifier by the base implementation).
        compiler.quote_cache[self.table_name] = "(%s)" % self.table_name
        sql, params = super().as_sql(compiler, connection)
        # The subquery appears in the SQL before the ON clause, so its
        # parameters must precede the ones generated by the base class.
        return sql, [*self.subquery_params, *params]


class FakeJoinField:
    """Minimal stand-in for a relation field, as consumed by ``Join``.

    It only carries the joining column pairs and the related model.
    """

    def __init__(self, joining_columns, related_model):
        self.joining_columns = joining_columns
        self.related_model = related_model

    def get_joining_columns(self):
        """Return the joining columns supplied at construction time."""
        return self.joining_columns

    def get_extra_restriction(self, where_class=None, alias=None, remote_alias=None):
        """Return ``None``: this fake field adds no extra join condition.

        All arguments are ignored. They default to ``None`` so the hook
        accepts both the three-argument form ``(where_class, alias,
        remote_alias)`` and the shorter two-argument form used by newer
        Django releases (NOTE(review): believed to be Django 4.0+ - confirm
        against the targeted Django version).
        """
        return None


def join_sql_subquery(
    queryset,
    subquery,
    params,
    join_fields,
    alias,
    related_model,
    join_type=INNER,
    parent_model=None,
):
    """Join a raw SQL subquery onto ``queryset`` in place.

    Args:
        queryset: Queryset whose underlying ``query`` is modified.
        subquery: SQL text of the subquery to join against.
        params: Parameters bound to placeholders inside ``subquery``.
        join_fields: Joining columns handed to ``FakeJoinField``.
        alias: Alias under which the subquery is joined.
        related_model: Model recorded on the fake join field.
        join_type: Join type constant; defaults to ``INNER``.
        parent_model: If given, its ``db_table`` is used as the parent alias;
            otherwise the query's initial alias is used.

    Returns:
        None. The queryset's query is mutated directly.
    """
    query = queryset.query

    if parent_model is None:
        parent_alias = query.get_initial_alias()
    else:
        parent_alias = parent_model._meta.db_table

    # external_aliases is a dict on some Django versions (the original note
    # says Django 3.x) and a set on others; register the alias either way.
    if isinstance(query.external_aliases, dict):
        query.external_aliases[alias] = True
    else:
        query.external_aliases.add(alias)

    join = RawSQLJoin(
        subquery,
        params,
        parent_alias,
        alias,
        join_type,
        FakeJoinField(join_fields, related_model),
        # Only LEFT OUTER joins are marked nullable.
        join_type == LOUTER,
    )
    query.join(join)
    # Re-assert the requested alias after registration (presumably because
    # Query.join() may assign a different one - verify against Django).
    join.table_alias = alias


def make_straight_join_query(QueryType):
    """Build a subclass of ``QueryType`` that emits STRAIGHT_JOIN.

    The subclass overrides ``join`` so that every join registered as
    ``INNER`` has its ``join_type`` replaced with ``"STRAIGHT_JOIN"``.
    """

    class Query(QueryType):
        def join(self, join, *args, **kwargs):
            alias = super().join(join, *args, **kwargs)
            # Look the join up again through alias_map: the object stored
            # under the returned alias is the one whose type is rewritten.
            join = self.alias_map[alias]
            if join.join_type == INNER:
                join.join_type = "STRAIGHT_JOIN"
            return alias

    return Query


# Maps a query class to its STRAIGHT_JOIN subclass (presumably CacheDict
# builds each entry once via make_straight_join_query - see
# judge.utils.cachedict).
straight_join_cache = CacheDict(make_straight_join_query)


def use_straight_join(queryset):
    """Switch ``queryset`` to a STRAIGHT_JOIN-emitting query class, in place.

    Does nothing unless the queryset's database vendor is ``"mysql"``.
    """
    if connections[queryset.db].vendor != "mysql":
        return

    # Prefer Query.chain when it exists; fall back to Query.clone on Query
    # classes that do not provide it.
    try:
        cloner = queryset.query.chain
    except AttributeError:
        cloner = queryset.query.clone
    queryset.query = cloner(straight_join_cache[type(queryset.query)])