# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
# pylint: disable=global-statement
"""Definition of AiiDA's process persister and the necessary object loaders."""
import importlib
import logging
import traceback
import plumpy
from aiida.orm.utils import serialize
__all__ = ('AiiDAPersister', 'ObjectLoader', 'get_object_loader')
LOGGER = logging.getLogger(__name__)
OBJECT_LOADER = None
class ObjectLoader(plumpy.DefaultObjectLoader):
    """Custom object loader for `aiida-core`."""

    def load_object(self, identifier):
        """Attempt to load the object identified by the given `identifier`.

        .. note:: We override the `plumpy.DefaultObjectLoader` to be able to throw an `ImportError` instead of a
            `ValueError` which in the context of `aiida-core` is not as apt, since we are loading classes.

        :param identifier: concatenation of module and resource name, separated by a single colon, i.e. of the form
            ``<module>:<name>``
        :return: loaded object
        :raises ImportError: if the identifier is malformed, or if the module or the object cannot be loaded
        """
        try:
            module_name, name = identifier.split(':')
        except ValueError as exception:
            # Zero or more than one colon makes the unpacking fail with a `ValueError`. Translate it so that callers
            # only ever have to deal with the documented `ImportError`.
            raise ImportError(
                "identifier '{}' is not of the form '<module>:<name>'".format(identifier)
            ) from exception

        try:
            module = importlib.import_module(module_name)
        except ImportError as exception:
            raise ImportError(
                "module '{}' from identifier '{}' could not be loaded".format(module_name, identifier)
            ) from exception

        try:
            return getattr(module, name)
        except AttributeError as exception:
            # The module was imported but does not expose the requested name
            raise ImportError(
                "object '{}' from identifier '{}' could not be loaded".format(name, identifier)
            ) from exception
def get_object_loader():
    """Return the global AiiDA object loader, instantiating it on the first call.

    The instance is cached in the module-level `OBJECT_LOADER` variable, so subsequent calls return the same object.

    :return: The global object loader
    :rtype: :class:`plumpy.ObjectLoader`
    """
    global OBJECT_LOADER

    loader = OBJECT_LOADER

    if loader is None:
        # Nothing cached yet: build the loader and remember it in the module-level global
        loader = OBJECT_LOADER = ObjectLoader()

    return loader
class AiiDAPersister(plumpy.Persister):
    """Persister to take saved process instance states and persisting them to the database."""

    def save_checkpoint(self, process, tag=None):
        """Persist a Process instance.

        The process is wrapped in a :class:`plumpy.Bundle` using the AiiDA object loader, after which the serialized
        bundle is set as the checkpoint on the node of the process.

        :param process: :class:`aiida.engine.Process`
        :param tag: optional checkpoint identifier to allow distinguishing multiple checkpoints for the same process
        :return: the bundle that was created for the process
        :raises NotImplementedError: if a `tag` is specified, since checkpoint tags are not supported yet
        :raises: :class:`plumpy.PersistenceError` Raised if there was a problem saving the checkpoint
        """
        LOGGER.debug('Persisting process<%d>', process.pid)

        if tag is not None:
            raise NotImplementedError('Checkpoint tags not supported yet')

        try:
            bundle = plumpy.Bundle(process, plumpy.LoadSaveContext(loader=get_object_loader()))
        except ImportError as exception:
            # Couldn't create the bundle
            raise plumpy.PersistenceError(
                "Failed to create a bundle for '{}': {}".format(process, traceback.format_exc())
            ) from exception

        try:
            process.node.set_checkpoint(serialize.serialize(bundle))
        except Exception as exception:
            # Deliberately broad: any failure while serializing or setting the checkpoint is reported uniformly as a
            # `PersistenceError`, with the original exception attached as its cause.
            raise plumpy.PersistenceError(
                "Failed to store a checkpoint for '{}': {}".format(process, traceback.format_exc())
            ) from exception

        return bundle

    def load_checkpoint(self, pid, tag=None):
        """Load a process from a persisted checkpoint by its process id.

        :param pid: the process id of the :class:`plumpy.Process`
        :param tag: optional checkpoint identifier to allow retrieving a specific sub checkpoint
        :return: a bundle with the process state
        :rtype: :class:`plumpy.Bundle`
        :raises NotImplementedError: if a `tag` is specified, since checkpoint tags are not supported yet
        :raises: :class:`plumpy.PersistenceError` Raised if there was a problem loading the checkpoint
        """
        # NOTE(review): imported locally rather than at module level, presumably to avoid a circular import - confirm
        from aiida.common.exceptions import MultipleObjectsError, NotExistent
        from aiida.orm import load_node

        if tag is not None:
            raise NotImplementedError('Checkpoint tags not supported yet')

        try:
            calculation = load_node(pid)
        except (MultipleObjectsError, NotExistent) as exception:
            raise plumpy.PersistenceError(
                'Failed to load the node for process<{}>: {}'.format(pid, traceback.format_exc())
            ) from exception

        checkpoint = calculation.checkpoint

        if checkpoint is None:
            raise plumpy.PersistenceError('Calculation<{}> does not have a saved checkpoint'.format(calculation.pk))

        try:
            bundle = serialize.deserialize(checkpoint)
        except Exception as exception:
            # Deliberately broad: any failure while deserializing is reported uniformly as a `PersistenceError`,
            # with the original exception attached as its cause.
            raise plumpy.PersistenceError(
                'Failed to load the checkpoint for process<{}>: {}'.format(pid, traceback.format_exc())
            ) from exception

        return bundle

    def get_checkpoints(self):
        """Return a list of all the current persisted process checkpoints

        .. note:: This method is currently a stub: it has no implementation and therefore returns `None` instead of
            the list described below.

        :return: list of PersistedCheckpoint tuples with element containing the process id and optional checkpoint tag.
        """

    def get_process_checkpoints(self, pid):
        """Return a list of all the current persisted process checkpoints for the specified process.

        .. note:: This method is currently a stub: it has no implementation and therefore returns `None` instead of
            the list described below.

        :param pid: the process pid
        :return: list of PersistedCheckpoint tuples with element containing the process id and optional checkpoint tag.
        """

    def delete_checkpoint(self, pid, tag=None):
        """Delete a persisted process checkpoint, where no error will be raised if the checkpoint does not exist.

        .. note:: Exceptions raised by `load_node` while loading the node for `pid` are not caught here and will
            propagate to the caller.

        :param pid: the process id of the :class:`plumpy.Process`
        :param tag: optional checkpoint identifier to allow retrieving a specific sub checkpoint. It is currently
            ignored by this implementation.
        """
        from aiida.orm import load_node

        calc = load_node(pid)
        calc.delete_checkpoint()

    def delete_process_checkpoints(self, pid):
        """Delete all persisted checkpoints related to the given process id.

        .. note:: This method is currently a stub: it has no implementation and does not delete anything.

        :param pid: the process id of the :class:`aiida.engine.processes.process.Process`
        """