from copy import copy

from django.db import connections
from django.db.models import Field
from django.db.models.expressions import RawSQL
from django.db.models.sql.constants import INNER, LOUTER
from django.db.models.sql.datastructures import Join
from django.utils import six

from judge.utils.cachedict import CacheDict


def _resolve_parent_alias(queryset, parent_model):
    """Return the alias that a new join should hang off.

    When ``parent_model`` is given, its database table name is used as the
    alias; otherwise the query's initial alias is used.
    """
    if parent_model is not None:
        return parent_model._meta.db_table
    return queryset.query.get_initial_alias()


def unique_together_left_join(queryset, model, link_field_name, filter_field_name, filter_value, parent_model=None):
    """Left-join ``model`` onto ``queryset`` with an extra equality condition in the join.

    The join follows the relation named ``link_field_name`` on ``model`` and
    additionally restricts the joined rows to those where
    ``filter_field_name`` equals ``filter_value``.

    Args:
        queryset: The queryset whose underlying query receives the join
            (mutated in place).
        model: The model whose table is joined in.
        link_field_name: Name of the relation field on ``model`` used to link
            the two tables.
        filter_field_name: Name of the field on ``model`` to constrain.
        filter_value: Value ``filter_field_name`` must equal.
        parent_model: Optional model whose table name is used as the parent
            alias; defaults to the query's initial alias.

    Returns:
        Whatever ``queryset.query.join`` returns (the alias of the join).
    """
    # Work on a copy: the object is patched below, and the model's own
    # relation object must keep its original ``get_extra_restriction``.
    link_field = copy(model._meta.get_field(link_field_name).remote_field)
    filter_field = model._meta.get_field(filter_field_name)

    def restrictions(where_class, alias, related_alias):
        # Build ``<alias>.<filter column> = filter_value`` as the extra join condition.
        cond = where_class()
        cond.add(filter_field.get_lookup('exact')(filter_field.get_col(alias), filter_value), 'AND')
        return cond

    link_field.get_extra_restriction = restrictions

    parent_alias = _resolve_parent_alias(queryset, parent_model)
    return queryset.query.join(Join(model._meta.db_table, parent_alias, None, LOUTER, link_field, True))


class RawSQLJoin(Join):
    """A ``Join`` whose "table" is a raw SQL subquery with its own parameters."""

    def __init__(self, subquery, subquery_params, parent_alias, table_alias, join_type, join_field, nullable,
                 filtered_relation=None):
        """Store the subquery parameters and pass the subquery text on as the table name.

        Args:
            subquery: Raw SQL text of the subquery; handed to ``Join`` in the
                table-name position.
            subquery_params: Sequence of parameters for ``subquery``.
            parent_alias, table_alias, join_type, join_field, nullable,
            filtered_relation: Forwarded unchanged to ``Join.__init__``.
        """
        self.subquery_params = subquery_params
        super().__init__(subquery, parent_alias, table_alias, join_type, join_field, nullable, filtered_relation)

    def as_sql(self, compiler, connection):
        """Render the join, placing the subquery parameters before the parent's parameters.

        Returns:
            A ``(sql, params)`` pair where ``params`` is a list.
        """
        # Pre-seed the compiler's quote cache so the subquery text maps to
        # itself wrapped in parentheses.
        # NOTE(review): presumably this stops the compiler from quoting the
        # subquery as an identifier -- confirm against the Django version in use.
        compiler.quote_cache[self.table_name] = '(%s)' % self.table_name
        sql, params = super().as_sql(compiler, connection)
        # Normalise both sides to lists: ``tuple + list`` (or the reverse)
        # raises TypeError, so concatenating them directly only works when
        # the caller's sequence type happens to match the parent's.
        return sql, list(self.subquery_params) + list(params)


class FakeJoinField:
    """Minimal stand-in for a join field that only carries explicit joining columns."""

    def __init__(self, joining_columns):
        """Remember ``joining_columns`` so it can be handed back verbatim."""
        self.joining_columns = joining_columns

    def get_joining_columns(self):
        """Return the joining columns supplied at construction time."""
        return self.joining_columns

    def get_extra_restriction(self, where_class, alias, remote_alias):
        """Contribute no extra join restriction (always returns ``None``)."""
        pass


def join_sql_subquery(queryset, subquery, params, join_fields, alias, join_type=INNER, parent_model=None):
    """Join a raw SQL subquery onto ``queryset`` under the given ``alias``.

    Args:
        queryset: The queryset whose underlying query receives the join
            (mutated in place).
        subquery: Raw SQL text of the subquery.
        params: Sequence of parameters for ``subquery``.
        join_fields: Joining columns, wrapped in a ``FakeJoinField``.
        alias: Alias under which the subquery is exposed.
        join_type: Join type constant; the join is marked nullable only when
            this is ``LOUTER``.
        parent_model: Optional model whose table name is used as the parent
            alias; defaults to the query's initial alias.
    """
    parent_alias = _resolve_parent_alias(queryset, parent_model)
    queryset.query.external_aliases.add(alias)
    join = RawSQLJoin(subquery, params, parent_alias, alias, join_type, FakeJoinField(join_fields), join_type == LOUTER)
    queryset.query.join(join)
    # Deliberately set after ``query.join``.
    # NOTE(review): presumably ``query.join`` assigns its own generated alias
    # to the join object, which this overrides -- confirm before reordering.
    join.table_alias = alias


def RawSQLColumn(model, field=None):
    """Build a ``RawSQL`` expression referring to ``<table>.<column>``.

    Args:
        model: Either a model class, or a ``Field`` instance (in which case
            the model is taken from the field and ``field`` is ignored).
        field: A ``Field`` instance or a field name on ``model``.

    Returns:
        A ``RawSQL`` expression with no parameters. The table and column
        names are interpolated as-is, without quoting.
    """
    if isinstance(model, Field):
        field = model
        model = field.model
    if isinstance(field, six.string_types):
        field = model._meta.get_field(field)
    return RawSQL('%s.%s' % (model._meta.db_table, field.get_attname_column()[1]), ())


def make_straight_join_query(QueryType):
    """Create a subclass of ``QueryType`` that turns inner joins into ``STRAIGHT_JOIN``."""

    class Query(QueryType):
        def join(self, join, *args, **kwargs):
            """Delegate to the parent ``join``, then rewrite the join type if it is ``INNER``."""
            alias = super().join(join, *args, **kwargs)
            # Look the join up by alias rather than using the argument: the
            # entry stored in ``alias_map`` is the one that gets rewritten.
            registered = self.alias_map[alias]
            if registered.join_type == INNER:
                registered.join_type = 'STRAIGHT_JOIN'
            return alias

    return Query


# Maps a query class to its STRAIGHT_JOIN-producing subclass.
# NOTE(review): presumably CacheDict builds each entry once via the factory
# and reuses it afterwards -- confirm in judge.utils.cachedict.
straight_join_cache = CacheDict(make_straight_join_query)


def use_straight_join(queryset):
    """Switch ``queryset`` to a query class that emits ``STRAIGHT_JOIN`` for inner joins.

    Does nothing unless the queryset's database vendor is ``'mysql'``. The
    queryset is modified in place (its ``query`` attribute is replaced);
    nothing is returned.
    """
    if connections[queryset.db].vendor != 'mysql':
        return
    # Use ``chain`` when the query object has it, otherwise fall back to ``clone``.
    try:
        cloner = queryset.query.chain
    except AttributeError:
        cloner = queryset.query.clone
    queryset.query = cloner(straight_join_cache[type(queryset.query)])