Source code for aiida.storage.psql_dos.backend

# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file        #
# For further information please visit http://www.aiida.net               #
###########################################################################
"""SqlAlchemy implementation of `aiida.orm.implementation.backends.Backend`."""
# pylint: disable=missing-function-docstring
from contextlib import contextmanager, nullcontext
import functools
from typing import TYPE_CHECKING, Iterator, List, Optional, Sequence, Set, Union

from disk_objectstore import Container
from sqlalchemy import table
from sqlalchemy.orm import Session, scoped_session, sessionmaker

from aiida.common.exceptions import ClosedStorage, IntegrityError
from aiida.manage.configuration.profile import Profile
from aiida.orm import User
from aiida.orm.entities import EntityTypes
from aiida.orm.implementation import BackendEntity, StorageBackend
from aiida.storage.log import STORAGE_LOGGER
from aiida.storage.psql_dos.migrator import REPOSITORY_UUID_KEY, PsqlDostoreMigrator
from aiida.storage.psql_dos.models import base

from .orm import authinfos, comments, computers, convert, groups, logs, nodes, querybuilder, users

if TYPE_CHECKING:
    from aiida.repository.backend import DiskObjectStoreRepositoryBackend

__all__ = ('PsqlDosBackend',)

CONTAINER_DEFAULTS: dict = {
    'pack_size_target': 4 * 1024 * 1024 * 1024,
    'loose_prefix_len': 2,
    'hash_type': 'sha256',
    'compression_algorithm': 'zlib+1'
}


class PsqlDosBackend(StorageBackend):  # pylint: disable=too-many-public-methods
    """An AiiDA storage backend that stores data in a PostgreSQL database and disk-objectstore repository.

    Note, there were originally two such backends, `sqlalchemy` and `django`.
    The `django` backend was removed to consolidate access to this storage.
    """

    migrator = PsqlDostoreMigrator

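    # A minimal usage sketch, assuming a configured AiiDA profile; the names `profile` and
    # `storage` are illustrative. In normal operation the storage backend is created and
    # managed by AiiDA itself rather than instantiated by hand:
    #
    #     from aiida import load_profile
    #
    #     profile = load_profile()           # load the default (or a named) profile
    #     storage = PsqlDosBackend(profile)  # validates the schema version and opens a session
    #     print(storage)                     # Storage for '<name>' [open] @ <db-url> / <repo-uri>
    #     storage.close()                    # dispose of the engine and session factory
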
    @classmethod
    def version_head(cls) -> str:
        return cls.migrator.get_schema_version_head()

    @classmethod
    def version_profile(cls, profile: Profile) -> Optional[str]:
        return cls.migrator(profile).get_schema_version_profile(check_legacy=True)

    @classmethod
    def migrate(cls, profile: Profile) -> None:
        cls.migrator(profile).migrate()

    def __init__(self, profile: Profile) -> None:
        super().__init__(profile)

        # check that the storage is reachable and at the correct version
        self.migrator(profile).validate_storage()

        self._session_factory: Optional[scoped_session] = None
        self._initialise_session()
        # save the URL of the database, for use in the __str__ method
        self._db_url = self.get_session().get_bind().url  # type: ignore

        self._authinfos = authinfos.SqlaAuthInfoCollection(self)
        self._comments = comments.SqlaCommentCollection(self)
        self._computers = computers.SqlaComputerCollection(self)
        self._groups = groups.SqlaGroupCollection(self)
        self._logs = logs.SqlaLogCollection(self)
        self._nodes = nodes.SqlaNodeCollection(self)
        self._users = users.SqlaUserCollection(self)

    @property
    def is_closed(self) -> bool:
        return self._session_factory is None

    def __str__(self) -> str:
        repo_uri = self.profile.storage_config['repository_uri']
        state = 'closed' if self.is_closed else 'open'
        return f'Storage for {self.profile.name!r} [{state}] @ {self._db_url!r} / {repo_uri}'

    def _initialise_session(self):
        """Initialise the SQLAlchemy session factory.

        Only one session factory is ever associated with a given class instance,
        i.e. once the instance is closed, it cannot be reopened.

        The session factory returns a session that is bound to the current thread.
        Multi-thread support is currently required by the REST API, although in the future
        we may want to move the multi-thread handling higher up in the AiiDA stack.
        """
        from aiida.storage.psql_dos.utils import create_sqlalchemy_engine
        engine = create_sqlalchemy_engine(self._profile.storage_config)
        self._session_factory = scoped_session(sessionmaker(bind=engine, future=True, expire_on_commit=True))

    def get_session(self) -> Session:
        """Return an SQLAlchemy session bound to the current thread."""
        if self._session_factory is None:
            raise ClosedStorage(str(self))
        return self._session_factory()

    def close(self) -> None:
        if self._session_factory is None:
            return  # the instance is already closed, and so this is a no-op

        # reset the cached default user instance, since it will now have no associated session
        User.objects(self).reset()

        # close the connection
        # pylint: disable=no-member
        engine = self._session_factory.bind
        if engine is not None:
            engine.dispose()  # type: ignore
        self._session_factory.expunge_all()
        self._session_factory.close()
        self._session_factory = None

    def _clear(self, recreate_user: bool = True) -> None:
        from aiida.storage.psql_dos.models.settings import DbSetting
        from aiida.storage.psql_dos.models.user import DbUser

        super()._clear(recreate_user)

        session = self.get_session()

        # clear the database
        with self.transaction():

            # save the default user
            default_user_kwargs = None
            if recreate_user:
                default_user = User.objects(self).get_default()
                if default_user is not None:
                    default_user_kwargs = {
                        'email': default_user.email,
                        'first_name': default_user.first_name,
                        'last_name': default_user.last_name,
                        'institution': default_user.institution,
                    }

            # now clear the database
            for table_name in (
                'db_dbgroup_dbnodes', 'db_dbgroup', 'db_dblink', 'db_dbnode', 'db_dblog', 'db_dbauthinfo', 'db_dbuser',
                'db_dbcomputer'
            ):
                session.execute(table(table_name).delete())
            session.expunge_all()

            # restore the default user
            if recreate_user and default_user_kwargs:
                session.add(DbUser(**default_user_kwargs))

            # clear aiida's cache of the default user
            User.objects(self).reset()

        # Clear the repository and reset the repository UUID
        container = Container(self.profile.repository_path / 'container')
        container.init_container(clear=True, **CONTAINER_DEFAULTS)
        container_id = container.container_id
        with self.transaction():
            session.execute(
                DbSetting.__table__.update().where(DbSetting.key == REPOSITORY_UUID_KEY).values(val=container_id)
            )

    def get_repository(self) -> 'DiskObjectStoreRepositoryBackend':
        from aiida.repository.backend import DiskObjectStoreRepositoryBackend

        container = Container(self.profile.repository_path / 'container')
        return DiskObjectStoreRepositoryBackend(container=container)

    @property
    def authinfos(self):
        return self._authinfos

    @property
    def comments(self):
        return self._comments

    @property
    def computers(self):
        return self._computers

    @property
    def groups(self):
        return self._groups

    @property
    def logs(self):
        return self._logs

    @property
    def nodes(self):
        return self._nodes

    def query(self):
        return querybuilder.SqlaQueryBuilder(self)

    @property
    def users(self):
        return self._users

    @contextmanager
    def transaction(self) -> Iterator[Session]:
        """Open a transaction to be used as a context manager.

        If there is an exception within the context then the changes will be rolled back and the state will
        be as before entering. Transactions can be nested.
        """
        session = self.get_session()
        if session.in_transaction():
            with session.begin_nested():
                yield session
            session.commit()
        else:
            with session.begin():
                with session.begin_nested():
                    yield session

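    # A sketch of the nesting behaviour described in the docstring above; `backend` and
    # `do_work` are illustrative names. An exception inside an inner ``transaction()`` rolls
    # back only the inner SAVEPOINT, while the outer transaction can still commit:
    #
    #     with backend.transaction() as session:
    #         do_work(session)                  # kept if the outer block exits cleanly
    #         try:
    #             with backend.transaction():   # nested transaction (SAVEPOINT)
    #                 raise ValueError('inner failure')
    #         except ValueError:
    #             pass                          # only the inner changes were rolled back
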
    @property
    def in_transaction(self) -> bool:
        return self.get_session().in_nested_transaction()

    @staticmethod
    @functools.lru_cache(maxsize=18)
    def _get_mapper_from_entity(entity_type: EntityTypes, with_pk: bool):
        """Return the Sqlalchemy mapper and fields corresponding to the given entity.

        :param with_pk: if True, the fields returned will include the primary key
        """
        from sqlalchemy import inspect

        from aiida.storage.psql_dos.models.authinfo import DbAuthInfo
        from aiida.storage.psql_dos.models.comment import DbComment
        from aiida.storage.psql_dos.models.computer import DbComputer
        from aiida.storage.psql_dos.models.group import DbGroup, DbGroupNode
        from aiida.storage.psql_dos.models.log import DbLog
        from aiida.storage.psql_dos.models.node import DbLink, DbNode
        from aiida.storage.psql_dos.models.user import DbUser

        model = {
            EntityTypes.AUTHINFO: DbAuthInfo,
            EntityTypes.COMMENT: DbComment,
            EntityTypes.COMPUTER: DbComputer,
            EntityTypes.GROUP: DbGroup,
            EntityTypes.LOG: DbLog,
            EntityTypes.NODE: DbNode,
            EntityTypes.USER: DbUser,
            EntityTypes.LINK: DbLink,
            EntityTypes.GROUP_NODE: DbGroupNode,
        }[entity_type]
        mapper = inspect(model).mapper
        keys = {key for key, col in mapper.c.items() if with_pk or col not in mapper.primary_key}
        return mapper, keys

    def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], allow_defaults: bool = False) -> List[int]:
        mapper, keys = self._get_mapper_from_entity(entity_type, False)
        if not rows:
            return []
        if entity_type in (EntityTypes.COMPUTER, EntityTypes.LOG, EntityTypes.AUTHINFO):
            for row in rows:
                row['_metadata'] = row.pop('metadata')
        if allow_defaults:
            for row in rows:
                if not keys.issuperset(row):
                    raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
        else:
            for row in rows:
                if set(row) != keys:
                    raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} != {keys}')
        # note for postgresql+psycopg2 we could also use `save_all` + `flush` with minimal performance degradation, see
        # https://docs.sqlalchemy.org/en/14/changelog/migration_14.html#orm-batch-inserts-with-psycopg2-now-batch-statements-with-returning-in-most-cases
        # by contrast, in sqlite, bulk_insert is faster: https://docs.sqlalchemy.org/en/14/faq/performance.html
        session = self.get_session()
        with (nullcontext() if self.in_transaction else self.transaction()):
            session.bulk_insert_mappings(mapper, rows, render_nulls=True, return_defaults=True)
        return [row['id'] for row in rows]

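    # A sketch of calling ``bulk_insert``; the row values are illustrative. Each row must
    # supply exactly the non-primary-key columns of the mapped table, or a subset of them
    # when ``allow_defaults=True``:
    #
    #     pks = backend.bulk_insert(
    #         EntityTypes.USER,
    #         [{'email': 'jane@example.com', 'first_name': 'Jane', 'last_name': 'Doe', 'institution': 'EPFL'}],
    #         allow_defaults=True,
    #     )
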
    def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None:  # pylint: disable=no-self-use
        mapper, keys = self._get_mapper_from_entity(entity_type, True)
        if not rows:
            return None
        for row in rows:
            if 'id' not in row:
                raise IntegrityError(f"'id' field not given for {entity_type}: {set(row)}")
            if not keys.issuperset(row):
                raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
        session = self.get_session()
        with (nullcontext() if self.in_transaction else self.transaction()):
            session.bulk_update_mappings(mapper, rows)

    def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None:  # pylint: disable=no-self-use
        # pylint: disable=no-value-for-parameter
        from aiida.storage.psql_dos.models.group import DbGroupNode
        from aiida.storage.psql_dos.models.node import DbLink, DbNode

        if not self.in_transaction:
            raise AssertionError('Cannot delete nodes and links outside a transaction')

        session = self.get_session()
        # Delete the membership of these nodes to groups.
        session.query(DbGroupNode).filter(DbGroupNode.dbnode_id.in_(list(pks_to_delete))
                                          ).delete(synchronize_session='fetch')
        # Delete the links coming out of the nodes marked for deletion.
        session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
        # Delete the links pointing to the nodes marked for deletion.
        session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
        # Delete the actual nodes
        session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')

    def get_backend_entity(self, model: base.Base) -> BackendEntity:
        """Return the backend entity that corresponds to the given Model instance.

        :param model: the ORM model instance to promote to a backend instance
        :return: the backend entity corresponding to the given model
        """
        return convert.get_backend_entity(model, self)

    def set_global_variable(
        self, key: str, value: Union[None, str, int, float], description: Optional[str] = None, overwrite=True
    ) -> None:
        from aiida.storage.psql_dos.models.settings import DbSetting

        session = self.get_session()
        with (nullcontext() if self.in_transaction else self.transaction()):
            if session.query(DbSetting).filter(DbSetting.key == key).count():
                if overwrite:
                    session.query(DbSetting).filter(DbSetting.key == key).update(dict(val=value))
                else:
                    raise ValueError(f'The setting {key} already exists')
            else:
                session.add(DbSetting(key=key, val=value, description=description or ''))

    def get_global_variable(self, key: str) -> Union[None, str, int, float]:
        from aiida.storage.psql_dos.models.settings import DbSetting

        session = self.get_session()
        with (nullcontext() if self.in_transaction else self.transaction()):
            setting = session.query(DbSetting).filter(DbSetting.key == key).one_or_none()
            if setting is None:
                raise KeyError(f'No setting found with key {key}')
            return setting.val

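    # A sketch of the two settings methods above; the key is an illustrative placeholder.
    # Values live in the ``DbSetting`` table and are overwritten by default:
    #
    #     backend.set_global_variable('my_plugin|some_setting', 'value', description='demo')
    #     backend.get_global_variable('my_plugin|some_setting')                            # -> 'value'
    #     backend.set_global_variable('my_plugin|some_setting', 'other', overwrite=False)  # raises ValueError
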
    def maintain(self, full: bool = False, dry_run: bool = False, **kwargs) -> None:
        from aiida.manage.profile_access import ProfileAccessManager

        repository = self.get_repository()

        if full:
            maintenance_context = ProfileAccessManager(self._profile).lock
        else:
            maintenance_context = nullcontext

        with maintenance_context():
            unreferenced_objects = self.get_unreferenced_keyset()
            STORAGE_LOGGER.info(f'Deleting {len(unreferenced_objects)} unreferenced objects ...')
            if not dry_run:
                repository.delete_objects(list(unreferenced_objects))

            STORAGE_LOGGER.info('Starting repository-specific operations ...')
            repository.maintain(live=not full, dry_run=dry_run, **kwargs)

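    # A sketch of a maintenance run (illustrative): the default "live" pass only deletes
    # unreferenced objects and performs light repository maintenance, while ``full=True``
    # locks the profile via ``ProfileAccessManager`` so the repository can be repacked safely:
    #
    #     backend.maintain(dry_run=True)   # report what would be deleted, touch nothing
    #     backend.maintain(full=True)      # requires exclusive access to the profile
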
    def get_unreferenced_keyset(self, check_consistency: bool = True) -> Set[str]:
        """Returns the keyset of objects that exist in the repository but are not tracked by AiiDA.

        This should be all the soft-deleted files.

        :param check_consistency: toggle for a check that raises if there are references in the database with no
            actual object in the underlying repository.

        :return: a set with all the objects in the underlying repository that are not referenced in the database.
        """
        from aiida import orm

        STORAGE_LOGGER.info('Obtaining unreferenced object keys ...')

        repository = self.get_repository()

        keyset_repository = set(repository.list_objects())
        keyset_database = set(orm.Node.objects(self).iter_repo_keys())

        if check_consistency:
            keyset_missing = keyset_database - keyset_repository
            if len(keyset_missing) > 0:
                raise RuntimeError(
                    'There are objects referenced in the database that are not present in the repository. Aborting!'
                )

        return keyset_repository - keyset_database

    def get_info(self, detailed: bool = False) -> dict:
        results = super().get_info(detailed=detailed)
        results['repository'] = self.get_repository().get_info(detailed)
        return results