Source code for aiida.storage.psql_dos.orm.groups

###########################################################################
# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file        #
# For further information please visit http://www.aiida.net               #
###########################################################################
"""SQLA groups"""

import logging

from aiida.common.exceptions import UniquenessError
from aiida.common.lang import type_check
from aiida.orm.implementation.groups import BackendGroup, BackendGroupCollection
from aiida.storage.psql_dos.models.group import DbGroup, DbGroupNode

from . import entities, users, utils
from .extras_mixin import ExtrasMixin
from .nodes import SqlaNode

_LOGGER = logging.getLogger(__name__)


# Unfortunately the linter doesn't seem to be able to pick up on the fact that the abstract property 'id'
# of BackendGroup is actually implemented in SqlaModelEntity so disable the abstract check
[docs] class SqlaGroup(entities.SqlaModelEntity[DbGroup], ExtrasMixin, BackendGroup): """The SQLAlchemy Group object""" MODEL_CLASS = DbGroup USER_CLASS = users.SqlaUser NODE_CLASS = SqlaNode GROUP_NODE_CLASS = DbGroupNode
[docs] def __init__(self, backend, label, user, description='', type_string=''): """Construct a new SQLA group :param backend: the backend to use :param label: the group label :param user: the owner of the group :param description: an optional group description :param type_string: an optional type for the group to contain """ type_check(user, self.USER_CLASS) super().__init__(backend) dbgroup = self.MODEL_CLASS(label=label, description=description, user=user.bare_model, type_string=type_string) self._model = utils.ModelWrapper(dbgroup, backend)
@property def label(self): return self.model.label @label.setter def label(self, label): """Attempt to change the label of the group instance. If the group is already stored and the another group of the same type already exists with the desired label, a UniquenessError will be raised :param label: the new group label :raises aiida.common.UniquenessError: if another group of same type and label already exists """ self.model.label = label if self.is_stored: try: self.model.save() except Exception: raise UniquenessError(f'a group of the same type with the label {label} already exists') from Exception @property def description(self): return self.model.description @description.setter def description(self, value): self.model.description = value # Update the entry in the DB, if the group is already stored if self.is_stored: self.model.save() @property def type_string(self): return self.model.type_string @property def user(self): return self.backend.users.ENTITY_CLASS.from_dbmodel(self.model.user, self.backend) @user.setter def user(self, new_user): type_check(new_user, self.USER_CLASS) self.model.user = new_user.bare_model @property def pk(self): return self.model.id @property def uuid(self): return str(self.model.uuid) @property def is_stored(self): return self.pk is not None
[docs] def store(self): self.model.save() return self
[docs] def count(self): """Return the number of entities in this group. :return: integer number of entities contained within the group """ from sqlalchemy.orm import aliased group = aliased(self.MODEL_CLASS) nodes = aliased(self.GROUP_NODE_CLASS) session = self.backend.get_session() return session.query(group).join(nodes, nodes.dbgroup_id == group.id).filter(group.id == self.pk).count()
[docs] def clear(self): """Remove all the nodes from this group.""" session = self.backend.get_session() # Note we have to call `bare_model` to circumvent flushing data to the database self.bare_model.dbnodes = [] session.commit()
@property def nodes(self): """Get an iterator to all the nodes in the group""" class Iterator: """Nodes iterator""" def __init__(self, dbnodes, backend): self._backend = backend self._dbnodes = dbnodes self.generator = self._genfunction() def _genfunction(self): for node in self._dbnodes: yield self._backend.get_backend_entity(node) def __iter__(self): return self def __len__(self): return self._dbnodes.count() def __getitem__(self, value): if isinstance(value, slice): return [self._backend.get_backend_entity(n) for n in self._dbnodes[value]] return self._backend.get_backend_entity(self._dbnodes[value]) def __next__(self): return next(self.generator) return Iterator(self.model.dbnodes, self._backend)
[docs] def add_nodes(self, nodes, **kwargs): """Add a node or a set of nodes to the group. :note: all the nodes *and* the group itself have to be stored. :param nodes: a list of `BackendNode` instance to be added to this group :param kwargs: skip_orm: When the flag is on, the SQLA ORM is skipped and SQLA is used to create a direct SQL INSERT statement to the group-node relationship table (to improve speed). """ from sqlalchemy.dialects.postgresql import insert from sqlalchemy.exc import IntegrityError super().add_nodes(nodes) skip_orm = kwargs.get('skip_orm', False) def check_node(given_node): """Check if given node is of correct type and stored""" if not isinstance(given_node, self.NODE_CLASS): raise TypeError(f'invalid type {type(given_node)}, has to be {self.NODE_CLASS}') if not given_node.is_stored: raise ValueError('At least one of the provided nodes is unstored, stopping...') with utils.disable_expire_on_commit(self.backend.get_session()) as session: if not skip_orm: # Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database dbnodes = self.model.dbnodes for node in nodes: check_node(node) # Use pattern as suggested here: # http://docs.sqlalchemy.org/en/latest/orm/session_transaction.html#using-savepoint try: with session.begin_nested(): dbnodes.append(node.bare_model) session.flush() except IntegrityError: # Duplicate entry, skip pass else: ins_dict = [] for node in nodes: check_node(node) ins_dict.append({'dbnode_id': node.id, 'dbgroup_id': self.id}) table = self.GROUP_NODE_CLASS.__table__ ins = insert(table).values(ins_dict) session.execute(ins.on_conflict_do_nothing(index_elements=['dbnode_id', 'dbgroup_id'])) # Commit everything as up till now we've just flushed if not session.in_nested_transaction(): session.commit()
[docs] def remove_nodes(self, nodes, **kwargs): """Remove a node or a set of nodes from the group. :note: all the nodes *and* the group itself have to be stored. :param nodes: a list of `BackendNode` instance to be added to this group :param kwargs: skip_orm: When the flag is set to `True`, the SQLA ORM is skipped and SQLA is used to create a direct SQL DELETE statement to the group-node relationship table in order to improve speed. """ from sqlalchemy import and_ super().remove_nodes(nodes) # Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database dbnodes = self.model.dbnodes skip_orm = kwargs.get('skip_orm', False) def check_node(node): if not isinstance(node, self.NODE_CLASS): raise TypeError(f'invalid type {type(node)}, has to be {self.NODE_CLASS}') if node.id is None: raise ValueError('At least one of the provided nodes is unstored, stopping...') list_nodes = [] with utils.disable_expire_on_commit(self.backend.get_session()) as session: if not skip_orm: for node in nodes: check_node(node) # Check first, if SqlA issues a DELETE statement for an unexisting key it will result in an error if node.bare_model in dbnodes: list_nodes.append(node.bare_model) for node in list_nodes: dbnodes.remove(node) else: table = self.GROUP_NODE_CLASS.__table__ for node in nodes: check_node(node) clause = and_(table.c.dbnode_id == node.id, table.c.dbgroup_id == self.id) statement = table.delete().where(clause) session.execute(statement) if not session.in_nested_transaction(): session.commit()
[docs] class SqlaGroupCollection(BackendGroupCollection): """The SLQA collection of groups""" ENTITY_CLASS = SqlaGroup
[docs] def delete(self, id): session = self.backend.get_session() row = session.get(self.ENTITY_CLASS.MODEL_CLASS, id) session.delete(row) session.commit()