# Source code for aiida.engine.persistence
# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
# pylint: disable=global-statement
"""Definition of AiiDA's process persister and the necessary object loaders."""
import importlib
import logging
import traceback
from typing import TYPE_CHECKING, Any, Hashable, Optional
from plumpy.exceptions import PersistenceError
import plumpy.loaders
import plumpy.persistence
from aiida.orm.utils import serialize
if TYPE_CHECKING:
from aiida.engine.processes.process import Process
# Public API of this module.
__all__ = ('AiiDAPersister', 'ObjectLoader', 'get_object_loader')

# Module logger; ``AiiDAPersister.save_checkpoint`` emits debug messages through it.
LOGGER = logging.getLogger(__name__)

# Cached loader instance returned by ``get_object_loader``; remains ``None`` until that function is first called.
OBJECT_LOADER: Optional['ObjectLoader'] = None
class ObjectLoader(plumpy.loaders.DefaultObjectLoader):
    """Custom object loader for `aiida-core`."""

    def load_object(self, identifier: str) -> Any:
        """Attempt to load the object identified by the given `identifier`.

        .. note:: We override the `plumpy.DefaultObjectLoader` to be able to throw an `ImportError` instead of a
            `ValueError` which in the context of `aiida-core` is not as apt, since we are loading classes.

        :param identifier: concatenation of module and resource name separated by a single colon,
            e.g. ``'some.module:SomeClass'``
        :return: loaded object
        :raises ImportError: if the identifier is not of the form ``module:name``, or if the module or the object
            cannot be loaded. The underlying exception, if any, is chained as ``__cause__``.
        """
        module_name, separator, name = identifier.partition(':')

        # Without this check a malformed identifier surfaces as an opaque tuple-unpacking ``ValueError``, which is
        # exactly the exception type this override exists to avoid.
        if not separator or not module_name or not name or ':' in name:
            raise ImportError(f"identifier '{identifier}' is not of the form 'module:name'")

        try:
            module = importlib.import_module(module_name)
        except ImportError as exception:
            message = f"module '{module_name}' from identifier '{identifier}' could not be loaded"
            raise ImportError(message) from exception

        try:
            return getattr(module, name)
        except AttributeError as exception:
            message = f"object '{name}' from identifier '{identifier}' could not be loaded"
            raise ImportError(message) from exception
def get_object_loader() -> ObjectLoader:
    """Return the module-wide :class:`ObjectLoader` instance.

    The loader is created on the first call and stored in the module-level ``OBJECT_LOADER``. Every later call
    returns that same stored instance.

    :return: The global object loader
    """
    global OBJECT_LOADER

    # Fast path: the loader has already been created by an earlier call.
    if OBJECT_LOADER is not None:
        return OBJECT_LOADER

    OBJECT_LOADER = ObjectLoader()
    return OBJECT_LOADER
class AiiDAPersister(plumpy.persistence.Persister):
    """Persister to take saved process instance states and persisting them to the database."""

    def save_checkpoint(self, process: 'Process', tag: Optional[str] = None):  # type: ignore[override]
        """Persist a Process instance.

        :param process: :class:`aiida.engine.Process`
        :param tag: optional checkpoint identifier to allow distinguishing multiple checkpoints for the same process
        :return: the :class:`plumpy.Bundle` that was serialized and stored as the checkpoint of ``process.node``
        :raises NotImplementedError: if a ``tag`` is given, since tagged checkpoints are not supported
        :raises: :class:`PersistenceError` Raised if there was a problem saving the checkpoint; the original
            exception is chained as ``__cause__``
        """
        LOGGER.debug('Persisting process<%d>', process.pid)

        if tag is not None:
            raise NotImplementedError('Checkpoint tags not supported yet')

        try:
            bundle = plumpy.persistence.Bundle(process, plumpy.persistence.LoadSaveContext(loader=get_object_loader()))
        except ImportError as exception:
            # Couldn't create the bundle
            raise PersistenceError(
                f"Failed to create a bundle for '{process}': {traceback.format_exc()}"
            ) from exception

        try:
            process.node.set_checkpoint(serialize.serialize(bundle))
        except Exception as exception:
            # Any failure while serializing or storing is reported uniformly as a ``PersistenceError``.
            raise PersistenceError(
                f"Failed to store a checkpoint for '{process}': {traceback.format_exc()}"
            ) from exception

        return bundle

    def load_checkpoint(self, pid: Hashable, tag: Optional[str] = None) -> plumpy.persistence.Bundle:
        """Load a process from a persisted checkpoint by its process id.

        :param pid: the process id of the :class:`plumpy.Process`
        :param tag: optional checkpoint identifier to allow retrieving a specific sub checkpoint
        :return: a bundle with the process state
        :rtype: :class:`plumpy.Bundle`
        :raises NotImplementedError: if a ``tag`` is given, since tagged checkpoints are not supported
        :raises: :class:`PersistenceError` Raised if there was a problem loading the checkpoint; the original
            exception, if any, is chained as ``__cause__``
        """
        from aiida.common.exceptions import MultipleObjectsError, NotExistent
        from aiida.orm import load_node

        if tag is not None:
            raise NotImplementedError('Checkpoint tags not supported yet')

        try:
            calculation = load_node(pid)
        except (MultipleObjectsError, NotExistent) as exception:
            raise PersistenceError(
                f'Failed to load the node for process<{pid}>: {traceback.format_exc()}'
            ) from exception

        checkpoint = calculation.checkpoint

        if checkpoint is None:
            raise PersistenceError(f'Calculation<{calculation.pk}> does not have a saved checkpoint')

        try:
            # NOTE(review): ``deserialize_unsafe`` is only appropriate because the checkpoint was written by
            # ``save_checkpoint`` into the trusted database; never feed it external data.
            bundle = serialize.deserialize_unsafe(checkpoint)
        except Exception as exception:
            raise PersistenceError(
                f'Failed to load the checkpoint for process<{pid}>: {traceback.format_exc()}'
            ) from exception

        return bundle

    def get_checkpoints(self):
        """Return a list of all the current persisted process checkpoints

        .. note:: Not implemented: this method has no body and currently returns ``None``.

        :return: list of PersistedCheckpoint tuples with element containing the process id and optional checkpoint tag.
        """

    def get_process_checkpoints(self, pid: Hashable):
        """Return a list of all the current persisted process checkpoints for the specified process.

        .. note:: Not implemented: this method has no body and currently returns ``None``.

        :param pid: the process pid
        :return: list of PersistedCheckpoint tuples with element containing the process id and optional checkpoint tag.
        """

    def delete_checkpoint(self, pid: Hashable, tag: Optional[str] = None) -> None:  # pylint: disable=unused-argument
        """Delete a persisted process checkpoint, where no error will be raised if the checkpoint does not exist.

        .. note:: ``tag`` is accepted for interface compatibility but ignored. Errors raised by ``load_node``,
            for example when no node exists for ``pid``, are not caught.

        :param pid: the process id of the :class:`plumpy.Process`
        :param tag: optional checkpoint identifier to allow retrieving a specific sub checkpoint
        """
        from aiida.orm import load_node

        calc = load_node(pid)
        calc.delete_checkpoint()

    def delete_process_checkpoints(self, pid: Hashable):
        """Delete all persisted checkpoints related to the given process id.

        .. note:: Not implemented: this method has no body and does nothing.

        :param pid: the process id of the :class:`aiida.engine.processes.process.Process`
        """