Source code for aiida.orm.implementation.sqlalchemy.utils

# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida_core #
# For further information on the license, see the LICENSE.txt file        #
# For further information please visit http://www.aiida.net               #
###########################################################################
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
import contextlib

import six
from sqlalchemy import inspect
from sqlalchemy.orm.attributes import flag_modified
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.types import Integer, Boolean
import sqlalchemy.exc

from aiida.common import exceptions
from aiida.backends.sqlalchemy import get_scoped_session

__all__ = ('django_filter', 'get_attr')


class ModelWrapper(object):
    """
    This wraps a SQLA model, delegating all attribute getting and setting to
    the underlying model instance, BUT it will make sure that if the model is
    stored then:
    * before every read it has the latest value from the database, and
    * after every write the updated value is flushed to the database.

    An illustrative usage sketch follows the class definition.
    """

    def __init__(self, model, auto_flush=()):
        """Construct the ModelWrapper.

        :param model: the database model instance to wrap
        :param auto_flush: an optional tuple of database model fields that are always to be flushed, in addition to
            the field that corresponds to the attribute being set through `__setattr__`.
        """
        super(ModelWrapper, self).__init__()
        # Have to do it this way because we overwrite __setattr__
        object.__setattr__(self, '_model', model)
        object.__setattr__(self, '_auto_flush', auto_flush)

    def __getattr__(self, item):
        # Python 3's implementation of copy.copy does not call __init__ on the new object
        # but manually restores attributes instead. Make sure we never get into a recursive
        # loop by protecting the only special variable here: _model
        if item == '_model':
            raise AttributeError()

        if self.is_saved() and not self._in_transaction() and self._is_model_field(item):
            self._ensure_model_uptodate(fields=(item,))

        return getattr(self._model, item)

    def __setattr__(self, key, value):
        setattr(self._model, key, value)
        if self.is_saved() and self._is_model_field(key):
            fields = set((key,) + self._auto_flush)
            self._flush(fields=fields)

    def is_saved(self):
        """Return whether the wrapped model has been stored, i.e. has a primary key."""
        return self._model.id is not None

    def save(self):
        """Store the model (possibly updating values if changed)."""
        try:
            commit = not self._in_transaction()
            self._model.save(commit=commit)
        except sqlalchemy.exc.IntegrityError as e:
            self._model.session.rollback()
            raise exceptions.IntegrityError(str(e))

    def _is_model_field(self, name):
        """Return whether `name` is a mapped field of the wrapped model class."""
        return inspect(self._model.__class__).has_property(name)

    def _flush(self, fields=()):
        """If the model is stored then save the current value."""
        if self.is_saved():
            for field in fields:
                flag_modified(self._model, field)

            self.save()

    def _ensure_model_uptodate(self, fields=None):
        """Expire the given fields (or all of them if `fields` is None) so they are reloaded from the database."""
        if self.is_saved():
            self._model.session.expire(self._model, attribute_names=fields)

    @staticmethod
    def _in_transaction():
        """Return whether the scoped session is currently inside a nested transaction (savepoint)."""
        return get_scoped_session().transaction.nested
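
# Illustrative usage sketch (not part of the original module): the model and
# field names below are hypothetical and only meant to show the delegation and
# flushing behaviour of ModelWrapper.
#
#     node = DbNode(...)                           # some mapped SQLA instance
#     wrapped = ModelWrapper(node, auto_flush=('mtime',))
#     wrapped.label = 'test'                       # unstored: plain attribute set, nothing flushed
#     wrapped.save()                               # stores the model (commits unless in a nested transaction)
#     wrapped.label = 'renamed'                    # stored: flushed immediately, together with 'mtime'
#     current = wrapped.label                      # stored: value is refreshed from the database first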


@contextlib.contextmanager
def disable_expire_on_commit(session):
    """
    Context manager that disables expire_on_commit and restores the original value on exit

    :param session: The SQLA session
    :type session: :class:`sqlalchemy.orm.session.Session`
    """
    current_value = session.expire_on_commit
    session.expire_on_commit = False
    try:
        yield session
    finally:
        session.expire_on_commit = current_value
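
# Illustrative usage sketch (not part of the original module), assuming a plain
# SQLAlchemy session: keep ORM instances loaded after a commit, e.g. to read
# their attributes without triggering a refresh from the database.
#
#     from sqlalchemy import create_engine
#     from sqlalchemy.orm import sessionmaker
#
#     engine = create_engine('sqlite://')      # any engine; in-memory sqlite just for illustration
#     session = sessionmaker(bind=engine)()    # expire_on_commit defaults to True
#
#     with disable_expire_on_commit(session) as s:
#         s.commit()                           # instances are not expired by this commit
#     # on exit, expire_on_commit is restored to its previous value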


def iter_dict(attrs):
    """Iterate depth-first over a nested dict/list structure, yielding (dotted key, leaf value) pairs."""
    if isinstance(attrs, dict):
        for key in sorted(attrs.keys()):
            it = iter_dict(attrs[key])
            for k, v in it:
                new_key = key
                if k:
                    new_key += "." + str(k)
                yield new_key, v
    elif isinstance(attrs, list):
        for i, val in enumerate(attrs):
            it = iter_dict(val)
            for k, v in it:
                new_key = str(i)
                if k:
                    new_key += "." + str(k)
                yield new_key, v
    else:
        yield "", attrs


def get_attr(attrs, key):
    path = key.split('.')
    d = attrs
    for p in path:
        if p.isdigit():
            p = int(p)
        # Let it raise the appropriate exception
        d = d[p]
    return d
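
# Example (illustrative): get_attr resolves the dotted keys produced by
# iter_dict, treating purely numeric path components as list indices.
#
#     >>> get_attr({'a': {'b': [10, 20]}}, 'a.b.1')
#     20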

def _create_op_func(op):
    def f(attr, val):
        return getattr(attr, op)(val)

    return f


_from_op = {
    'in': _create_op_func('in_'),
    'gte': _create_op_func('__ge__'),
    'gt': _create_op_func('__gt__'),
    'lte': _create_op_func('__le__'),
    'lt': _create_op_func('__lt__'),
    'eq': _create_op_func('__eq__'),
    'startswith': lambda attr, val: attr.like('{}%'.format(val)),
    'contains': lambda attr, val: attr.like('%{}%'.format(val)),
    'endswith': lambda attr, val: attr.like('%{}'.format(val)),
    'istartswith': lambda attr, val: attr.ilike('{}%'.format(val)),
    'icontains': lambda attr, val: attr.ilike('%{}%'.format(val)),
    'iendswith': lambda attr, val: attr.ilike('%{}'.format(val))
}
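
# Illustrative sketch (not part of the original module) of how the operator map
# is used: each entry turns a Django-style suffix into a SQLAlchemy clause on a
# column; the throwaway column below is only for demonstration.
#
#     from sqlalchemy import Column, Integer
#
#     id_column = Column('id', Integer)
#     clause = _from_op['gte'](id_column, 5)        # equivalent to id_column >= 5
#     clause = _from_op['in'](id_column, [1, 2])    # equivalent to id_column.in_([1, 2])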

def django_filter(cls_query, **kwargs):
    # Pass the query object you want to use.
    # This also assumes an AND between each argument.
    cls = inspect(cls_query)._entity_zero().class_
    q = cls_query

    # We regroup all the filters on a relationship at the same place, so that
    # when a join is done, we can filter it, and then reset to the original
    # query.
    current_join = None
    tmp_attr = dict(key=None, val=None)
    tmp_extra = dict(key=None, val=None)

    for key in sorted(kwargs.keys()):
        val = kwargs[key]

        join, field, op = [None] * 3
        splits = key.split("__")
        if len(splits) > 3:
            raise ValueError("Too many parameters to handle.")
        # something like "computer__id__in"
        elif len(splits) == 3:
            join, field, op = splits
        # we have either "computer__id", which means join + field equality, or
        # "id__gte", which means field + op
        elif len(splits) == 2:
            if splits[1] in _from_op.keys():
                field, op = splits
            else:
                join, field = splits
        else:
            field = splits[0]

        if "dbattributes" == join:
            if "val" in field:
                field = "val"
            if field in ["key", "val"]:
                tmp_attr[field] = val
            continue
        elif "dbextras" == join:
            if "val" in field:
                field = "val"
            if field in ["key", "val"]:
                tmp_extra[field] = val
            continue

        current_cls = cls
        if join:
            if current_join != join:
                q = q.join(join, aliased=True)
                current_join = join

            current_cls = [r for r in inspect(cls).relationships.items()
                           if r[0] == join][0][1].argument
            if isinstance(current_cls, Mapper):
                current_cls = current_cls.class_
            else:
                current_cls = current_cls()
        else:
            if current_join is not None:
                # Filter on the queried class again
                q = q.reset_joinpoint()
                current_join = None

        if field == "pk":
            field = "id"
        filtered_field = getattr(current_cls, field)
        if not op:
            op = "eq"
        f = _from_op[op]

        q = q.filter(f(filtered_field, val))

    # We reset one last time
    q = q.reset_joinpoint()

    key = tmp_attr["key"]
    if key:
        val = tmp_attr["val"]
        if val:
            q = q.filter(apply_json_cast(cls.attributes[key], val) == val)
        else:
            q = q.filter(tmp_attr["key"] in cls.attributes)
    key = tmp_extra["key"]
    if key:
        val = tmp_extra["val"]
        if val:
            q = q.filter(apply_json_cast(cls.extras[key], val) == val)
        else:
            q = q.filter(tmp_extra["key"] in cls.extras)

    return q
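
# Illustrative usage sketch (not part of the original module; the model and
# field names are hypothetical): filters are ANDed together, and double
# underscores separate relationship, field and operator, Django-style.
#
#     query = session.query(DbNode)
#     query = django_filter(query,
#                           label__startswith='calc',     # field + operator
#                           dbcomputer__id__in=[1, 2],    # join + field + operator
#                           pk__gte=100)                  # 'pk' is translated to 'id'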

def apply_json_cast(attr, val):
    if isinstance(val, six.string_types):
        attr = attr.astext
    if isinstance(val, six.integer_types):
        attr = attr.astext.cast(Integer)
    if isinstance(val, bool):
        attr = attr.astext.cast(Boolean)
    return attr
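
# Illustrative sketch (not part of the original module; the column and key are
# hypothetical): apply_json_cast adapts a JSON(B) path expression to the Python
# type of the value it will be compared against.
#
#     attr = DbNode.attributes['num_cpus']    # a JSON(B) path expression
#     attr = apply_json_cast(attr, 4)         # -> attr.astext.cast(Integer)
#     query = query.filter(attr == 4)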