import warnings
from sqlalchemy import orm
import django
from django.db.models.fields.related import (ForeignKey, OneToOneField,
ManyToManyField)
from django.db import connections, router
from django.db.backends import signals
from django.conf import settings
from .core import get_meta, get_engine, Cache
from .table import get_django_models, generate_tables
def get_session(alias='default', recreate=False):
    """Return the SQLAlchemy session stored on the Django connection ``alias``.

    The session is kept on the connection object as ``sa_session``.  A new
    one, bound to ``get_engine(alias)``, is created when the connection has
    no such attribute yet or when ``recreate`` is true.
    """
    django_connection = connections[alias]
    if hasattr(django_connection, 'sa_session') and not recreate:
        return django_connection.sa_session

    session_factory = orm.sessionmaker(bind=get_engine(alias))
    django_connection.sa_session = session_factory()
    return django_connection.sa_session
def new_session(sender, connection, **kw):
    """Signal receiver: rebuild the SQLAlchemy session for ``connection``.

    Connections whose alias is not listed in ``settings.DATABASES`` are
    ignored.
    """
    alias = connection.alias
    if alias not in settings.DATABASES:
        return
    get_session(alias=alias, recreate=True)
# Register new_session as a receiver of the connection_created signal so the
# cached session is recreated (recreate=True) each time the signal fires.
signals.connection_created.connect(new_session)
def get_remote_field(foreign_key):
    """Return the relation object of ``foreign_key`` for the running Django.

    The attribute read depends on ``django.VERSION``: ``related`` before
    1.8, ``rel`` for 1.8, and ``remote_field`` from 1.9 onwards.
    """
    version = django.VERSION
    if version >= (1, 9):
        return foreign_key.remote_field
    if version >= (1, 8):
        return foreign_key.rel
    return foreign_key.related
def _qualified_table_name(metadata, table_name):
    """Return ``table_name`` prefixed with ``metadata.schema`` when one is set."""
    if metadata.schema:
        return metadata.schema + '.' + table_name
    return table_name


def _extract_model_attrs(metadata, model, sa_models):
    """Build the mapper properties for the Django ``model``.

    :param metadata: metadata object whose ``tables`` mapping (keyed by
        schema-qualified table name) already contains the model's table, the
        tables of every related model and any many-to-many link tables.
    :param model: Django model class to describe.
    :param sa_models: mapping of Django model class to the corresponding
        SQLAlchemy class; used as relationship targets.
    :returns: dict mapping field name to an ``orm.column_property`` (plain
        columns) or an ``orm.relationship`` (foreign keys, one-to-one and
        many-to-many fields).
    :raises KeyError: if a required table or target model is missing from
        ``metadata.tables`` / ``sa_models``.
    """
    tables = metadata.tables
    table = tables[_qualified_table_name(metadata, model._meta.db_table)]

    fks = [t for t in model._meta.fields
           if isinstance(t, (ForeignKey, OneToOneField))]
    rel_fields = fks + list(model._meta.many_to_many)
    attrs = {}

    # Plain (non-relational) columns.
    for field in model._meta.fields:
        if isinstance(field, (ForeignKey, OneToOneField)):
            continue
        if field.model != model or field.column not in table.c:
            continue  # Fields from parent model are not supported
        attrs[field.name] = orm.column_property(table.c[field.column])

    # Relationships.
    for fk in rel_fields:
        # Foreign keys need their column on this table; many-to-many fields
        # are joined through a secondary table instead.
        if fk.column not in table.c and not isinstance(fk, ManyToManyField):
            continue

        if django.VERSION < (1, 8):
            parent_model = fk.related.parent_model
        else:
            parent_model = get_remote_field(fk).model

        parent_model_meta = parent_model._meta
        if parent_model_meta.proxy:
            continue

        p_table = tables[
            _qualified_table_name(metadata, parent_model_meta.db_table)]
        p_name = parent_model_meta.pk.column

        # The relation object is ``rel`` before Django 1.9 and
        # ``remote_field`` afterwards.
        rel = fk.rel if django.VERSION < (1, 9) else fk.remote_field
        related_name = rel.related_name
        disable_backref = related_name and related_name.endswith('+')
        backref = related_name.lower().strip('+') if related_name else None

        if not backref and not disable_backref:
            # No usable related_name: fall back to "<model>" for one-to-one
            # fields and "<model>_set" for everything else.
            backref = model._meta.object_name.lower()
            if not isinstance(fk, OneToOneField):
                backref = backref + '_set'
        elif backref and isinstance(fk, OneToOneField):
            backref = orm.backref(backref, uselist=False)

        kw = {}
        if isinstance(fk, ManyToManyField):
            model_pk = model._meta.pk.column
            sec_table = tables[_qualified_table_name(
                metadata, get_remote_field(fk).field.m2m_db_table())]
            sec_column = fk.m2m_column_name()
            p_sec_column = fk.m2m_reverse_name()
            kw.update(
                secondary=sec_table,
                primaryjoin=(sec_table.c[sec_column] == table.c[model_pk]),
                secondaryjoin=(sec_table.c[p_sec_column] == p_table.c[p_name])
            )
            # NOTE(review): ``fk.model()`` instantiates the Django model and
            # compares that instance with the model class.  This looks like
            # it was meant to be ``fk.model != model`` -- confirm before
            # changing, because it decides whether a many-to-many backref is
            # created at all.
            if fk.model() != model:
                backref = None
        else:
            kw.update(
                foreign_keys=[table.c[fk.column]],
                primaryjoin=(table.c[fk.column] == p_table.c[p_name]),
                remote_side=p_table.c[p_name],
            )
        if backref:
            kw.update(backref=backref)
        attrs[fk.name] = orm.relationship(sa_models[parent_model], **kw)
    return attrs
def prepare_models():
    """Build the SQLAlchemy models and publish them.

    The result of ``construct_models`` is stored in ``Cache.sa_models``;
    ``Cache.models`` is rebuilt as a mapping from schema-qualified table name
    to SQLAlchemy model, and every non-proxy Django model gets its
    SQLAlchemy counterpart attached as ``model.sa``.
    """
    metadata = get_meta()

    concrete_models = []
    for candidate in get_django_models():
        if not candidate._meta.proxy:
            concrete_models.append(candidate)

    Cache.sa_models = construct_models(metadata)
    Cache.models = {}

    for django_model in concrete_models:
        db_table = django_model._meta.db_table
        if metadata.schema:
            key = metadata.schema + '.' + db_table
        else:
            key = db_table
        sa_model = Cache.sa_models[django_model]
        Cache.models[key] = sa_model
        django_model.sa = sa_model
def construct_models(metadata):
    """Create and map one SQLAlchemy class per non-proxy Django model.

    Tables are generated first if ``metadata.tables`` is empty.  All classes
    are created in a first pass so that the second pass, which builds the
    mapper properties, can reference any of them as a relationship target.

    :returns: dict mapping Django model class to its SQLAlchemy class.
    """
    if not metadata.tables:
        generate_tables(metadata)
    tables = metadata.tables

    def lookup_table(django_model):
        # Tables are keyed by "<schema>.<name>" when a schema is configured.
        db_table = django_model._meta.db_table
        if metadata.schema:
            return tables[metadata.schema + '.' + db_table]
        return tables[db_table]

    concrete_models = []
    for candidate in get_django_models():
        if not candidate._meta.proxy:
            concrete_models.append(candidate)

    sa_models_by_django_models = {}

    # Pass 1: create the (still unmapped) classes.
    for django_model in concrete_models:
        mixin = getattr(django_model, 'aldjemy_mixin', None)
        if mixin:
            bases = (mixin, BaseSQLAModel)
        else:
            bases = (BaseSQLAModel, )
        # Queries run on the SQLAlchemy side, so a single alias has to serve
        # every query type; the 'read' alias is the one used.
        namespace = {
            'table': lookup_table(django_model),
            'alias': router.db_for_read(django_model),
        }
        sa_models_by_django_models[django_model] = type(
            django_model._meta.object_name, bases, namespace)

    # Pass 2: map each class onto its table with its properties.
    for django_model in concrete_models:
        sa_class = sa_models_by_django_models[django_model]
        mapped_table = lookup_table(django_model)
        properties = _extract_model_attrs(
            metadata, django_model, sa_models_by_django_models)
        orm.mapper(sa_class, mapped_table, properties)

    return sa_models_by_django_models
class BaseSQLAModel(object):
    """Base class providing a ``query`` shortcut for generated models."""

    @classmethod
    def query(cls, *a, **kw):
        """Start a query on the session for this class's database alias.

        With no arguments the query targets ``cls`` itself; otherwise the
        arguments are forwarded unchanged to ``session.query``.  The alias
        is read from ``cls.alias`` and defaults to ``'default'``.
        """
        alias = getattr(cls, 'alias', 'default')
        session = get_session(alias)
        if not (a or kw):
            return session.query(cls)
        return session.query(*a, **kw)