Source code for aiida.engine.persistence

# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file        #
# For further information please visit http://www.aiida.net               #
###########################################################################
# pylint: disable=global-statement
"""Definition of AiiDA's process persister and the necessary object loaders."""

import importlib
import logging
import traceback

import plumpy

from aiida.orm.utils import serialize

__all__ = ('AiiDAPersister', 'ObjectLoader', 'get_object_loader')

LOGGER = logging.getLogger(__name__)
OBJECT_LOADER = None


class ObjectLoader(plumpy.DefaultObjectLoader):
    """Custom object loader for `aiida-core`."""

    def load_object(self, identifier):
        """Attempt to load the object identified by the given `identifier`.

        .. note:: We override the `plumpy.DefaultObjectLoader` to be able to throw an `ImportError` instead of a
            `ValueError` which in the context of `aiida-core` is not as apt, since we are loading classes.

        :param identifier: concatenation of module and resource name, separated by a single colon,
            e.g. ``'package.module:name'``
        :return: loaded object
        :raises ImportError: if the identifier is malformed or the object cannot be loaded
        """
        module_name, separator, name = identifier.partition(':')

        # A valid identifier has exactly one colon with a non-empty part on each side. Without this check a malformed
        # identifier would surface as a `ValueError` (from unpacking or `import_module('')`), which is exactly what this
        # override is meant to avoid.
        if not separator or not module_name or not name or ':' in name:
            raise ImportError("identifier '{}' is not of the form 'module:name'".format(identifier))

        try:
            module = importlib.import_module(module_name)
        except ImportError as exception:
            raise ImportError(
                "module '{}' from identifier '{}' could not be loaded".format(module_name, identifier)
            ) from exception

        try:
            return getattr(module, name)
        except AttributeError as exception:
            raise ImportError(
                "object '{}' from identifier '{}' could not be loaded".format(name, identifier)
            ) from exception
def get_object_loader():
    """Return the global AiiDA object loader, creating it on the first call.

    The instance is cached in the module-level ``OBJECT_LOADER`` variable, so every later call returns the same object.

    :return: The global object loader
    :rtype: :class:`plumpy.ObjectLoader`
    """
    global OBJECT_LOADER

    # The loader was already created by an earlier call: hand back the cached instance.
    if OBJECT_LOADER is not None:
        return OBJECT_LOADER

    OBJECT_LOADER = ObjectLoader()
    return OBJECT_LOADER
class AiiDAPersister(plumpy.Persister):
    """Persister to take saved process instance states and persisting them to the database."""

    def save_checkpoint(self, process, tag=None):
        """Persist a Process instance.

        The process state is bundled with the global AiiDA object loader, serialized and stored on the process node.

        :param process: :class:`aiida.engine.Process`
        :param tag: optional checkpoint identifier to allow distinguishing multiple checkpoints for the same process
        :return: the bundle that was serialized and stored as the checkpoint
        :rtype: :class:`plumpy.Bundle`
        :raises NotImplementedError: if a `tag` is specified, since checkpoint tags are not supported yet
        :raises: :class:`plumpy.PersistenceError` Raised if there was a problem saving the checkpoint
        """
        LOGGER.debug('Persisting process<%d>', process.pid)

        if tag is not None:
            raise NotImplementedError('Checkpoint tags not supported yet')

        try:
            bundle = plumpy.Bundle(process, plumpy.LoadSaveContext(loader=get_object_loader()))
        except ImportError as exception:
            # Couldn't create the bundle; chain the cause so the original error is not lost.
            raise plumpy.PersistenceError(
                "Failed to create a bundle for '{}': {}".format(process, traceback.format_exc())
            ) from exception

        try:
            process.node.set_checkpoint(serialize.serialize(bundle))
        except Exception as exception:
            raise plumpy.PersistenceError(
                "Failed to store a checkpoint for '{}': {}".format(process, traceback.format_exc())
            ) from exception

        return bundle

    def load_checkpoint(self, pid, tag=None):
        """Load a process from a persisted checkpoint by its process id.

        :param pid: the process id of the :class:`plumpy.Process`
        :param tag: optional checkpoint identifier to allow retrieving a specific sub checkpoint
        :return: a bundle with the process state
        :rtype: :class:`plumpy.Bundle`
        :raises NotImplementedError: if a `tag` is specified, since checkpoint tags are not supported yet
        :raises: :class:`plumpy.PersistenceError` Raised if there was a problem loading the checkpoint
        """
        from aiida.common.exceptions import MultipleObjectsError, NotExistent
        from aiida.orm import load_node

        if tag is not None:
            raise NotImplementedError('Checkpoint tags not supported yet')

        try:
            calculation = load_node(pid)
        except (MultipleObjectsError, NotExistent) as exception:
            raise plumpy.PersistenceError(
                'Failed to load the node for process<{}>: {}'.format(pid, traceback.format_exc())
            ) from exception

        checkpoint = calculation.checkpoint

        if checkpoint is None:
            raise plumpy.PersistenceError('Calculation<{}> does not have a saved checkpoint'.format(calculation.pk))

        try:
            bundle = serialize.deserialize(checkpoint)
        except Exception as exception:
            raise plumpy.PersistenceError(
                'Failed to load the checkpoint for process<{}>: {}'.format(pid, traceback.format_exc())
            ) from exception

        return bundle

    def get_checkpoints(self):
        """Return a list of all the current persisted process checkpoints

        .. note:: Not implemented: this method currently has no body and therefore returns ``None``.

        :return: list of PersistedCheckpoint tuples with element containing the process id and optional checkpoint tag.
        """

    def get_process_checkpoints(self, pid):
        """Return a list of all the current persisted process checkpoints for the specified process.

        .. note:: Not implemented: this method currently has no body and therefore returns ``None``.

        :param pid: the process pid
        :return: list of PersistedCheckpoint tuples with element containing the process id and optional checkpoint tag.
        """

    def delete_checkpoint(self, pid, tag=None):  # pylint: disable=unused-argument
        """Delete a persisted process checkpoint, where no error will be raised if the checkpoint does not exist.

        The node of the process is loaded and its checkpoint is deleted. Exceptions raised by `load_node`, e.g. when
        no node with the given `pid` exists, are not caught here and propagate to the caller.

        :param pid: the process id of the :class:`plumpy.Process`
        :param tag: optional checkpoint identifier of a specific sub checkpoint to delete; currently ignored, since
            checkpoint tags are not supported yet
        """
        from aiida.orm import load_node

        calc = load_node(pid)
        calc.delete_checkpoint()

    def delete_process_checkpoints(self, pid):
        """Delete all persisted checkpoints related to the given process id.

        .. note:: Not implemented: this method currently has no body and therefore deletes nothing.

        :param pid: the process id of the :class:`aiida.engine.processes.process.Process`
        """