# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida_core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
# pylint: disable=cyclic-import
"""The AiiDA process class"""
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
import abc
import collections
import enum
import inspect
import uuid
import traceback
import six
from six.moves import filter, range
from pika.exceptions import ConnectionClosed
import plumpy
from plumpy import ProcessState
from aiida.common import exceptions
from aiida.common.extendeddicts import AttributeDict
from aiida.common.lang import classproperty, override, protected
from aiida.common.links import LinkType
from aiida.common.log import LOG_LEVEL_REPORT
from aiida import orm
from aiida.orm import ProcessNode, CalculationNode, WorkflowNode
from aiida.orm.utils import serialize
from .. import utils
from .exit_code import ExitCode
from .builder import ProcessBuilder
from .ports import InputPort, PortNamespace
from .process_spec import ProcessSpec
__all__ = ('Process', 'ProcessState')
def instantiate_process(runner, process, *args, **inputs):
    """
    Return an instance of the process with the given inputs. The function can deal with various types
    of the `process`:

        * Process instance: will simply return the instance
        * ProcessBuilder instance: will instantiate the Process from the class and inputs defined within it
        * Process class: will instantiate with the specified inputs

    If anything else is passed, a ValueError will be raised

    :param runner: the runner to instantiate the process with; for a Process instance it must be the instance's runner
    :param process: Process instance or class, or ProcessBuilder instance
    :param args: positional arguments; only checked to be empty when `process` is an instance, otherwise ignored
    :param inputs: the inputs for the process to be instantiated with
    :return: the process instance
    :raises ValueError: if `process` is not a Process instance, Process subclass or ProcessBuilder instance
    """
    if isinstance(process, Process):
        # An existing instance is returned unchanged, so it cannot receive extra inputs or a different runner
        assert not args
        assert not inputs
        assert runner is process.runner
        return process

    if isinstance(process, ProcessBuilder):
        builder = process
        process_class = builder.process_class
        # Inputs stored in the builder take precedence over explicitly passed keyword inputs
        inputs.update(**builder)
    elif inspect.isclass(process) and issubclass(process, Process):
        # ``issubclass`` raises TypeError for non-class arguments, so check for a class first
        process_class = process
    else:
        raise ValueError('invalid process {}, needs to be Process or ProcessBuilder'.format(type(process)))

    process = process_class(runner=runner, inputs=inputs)

    return process
@plumpy.auto_persist('_parent_pid', '_enable_persistence')
@six.add_metaclass(abc.ABCMeta)
class Process(plumpy.Process):
    """
    This class represents an AiiDA process which can be executed and will
    have full provenance saved in the database.
    """
    # pylint: disable=too-many-public-methods

    _node_class = ProcessNode
    _spec_type = ProcessSpec

    SINGLE_OUTPUT_LINKNAME = 'result'

    class SaveKeys(enum.Enum):
        """
        Keys used to identify things in the saved instance state bundle.
        """
        # pylint: disable=too-few-public-methods
        CALC_ID = 'calc_id'

    @classmethod
    def define(cls, spec):
        """Define the process specification.

        Adds the ``metadata`` input namespace (with its ``options`` sub-namespace and the ``store_provenance``,
        ``description`` and ``label`` inputs) and restricts the valid input and output types to ``orm.Data``.

        :param spec: the process specification to populate
        """
        super(Process, cls).define(spec)
        spec.input_namespace(spec.metadata_key, required=False, non_db=True, default={})
        spec.input_namespace('{}.{}'.format(spec.metadata_key, spec.options_key), required=False)
        spec.input('{}.store_provenance'.format(spec.metadata_key), valid_type=bool, default=True, non_db=True)
        spec.input(
            '{}.description'.format(spec.metadata_key), valid_type=six.string_types[0], required=False, non_db=True)
        spec.input('{}.label'.format(spec.metadata_key), valid_type=six.string_types[0], required=False, non_db=True)
        spec.inputs.valid_type = (orm.Data,)
        spec.outputs.valid_type = (orm.Data,)

    @classmethod
    def get_builder(cls):
        """Return a new ``ProcessBuilder`` for this process class."""
        return ProcessBuilder(cls)

    @classmethod
    def get_or_create_db_record(cls):
        """
        Create a database calculation node that represents what happened in
        this process.

        :return: a new, unstored instance of ``cls._node_class``
        """
        return cls._node_class()

    def __init__(self, inputs=None, logger=None, runner=None, parent_pid=None, enable_persistence=True):
        """Construct the process.

        :param inputs: the raw inputs; they are serialized through the input ports of the spec first
        :param logger: optional logger; when none is set, ``init`` falls back to the logger of the process node
        :param runner: the runner to use; defaults to the runner of the global manager
        :param parent_pid: the pk of the parent process, if any
        :param enable_persistence: whether to save checkpoints; forced off if the runner has no persister
        """
        from aiida.manage import manager

        self._runner = runner if runner is not None else manager.get_manager().get_runner()

        super(Process, self).__init__(
            inputs=self.spec().inputs.serialize(inputs),
            logger=logger,
            loop=self._runner.loop,
            communicator=self.runner.communicator)

        self._node = None
        self._parent_pid = parent_pid
        self._enable_persistence = enable_persistence
        if self._enable_persistence and self.runner.persister is None:
            self.logger.warning('Disabling persistence, runner does not have a persister')
            self._enable_persistence = False

    def init(self):
        """Initialise the process, using the logger of the process node if no logger has been set."""
        super(Process, self).init()
        if self._logger is None:
            self.set_logger(self.node.logger)

    @classproperty
    def exit_codes(cls):  # pylint: disable=no-self-argument
        """
        Return the namespace of exit codes defined for this process through its ProcessSpec.

        The namespace supports getitem and getattr operations with an ExitCode label to retrieve a specific code.
        Additionally, the namespace can also be called with the integer exit status to retrieve the exit code.

        :returns: ExitCodesNamespace of ExitCode named tuples
        """
        return cls.spec().exit_codes

    @property
    def node(self):
        """Return the ProcessNode used by this process to represent itself in the database.

        :return: instance of sub class of ProcessNode
        """
        return self._node

    @property
    def uuid(self):
        """Return the UUID of the process which corresponds to the UUID of its associated `ProcessNode`.

        :return: the UUID associated to this process instance
        """
        return self.node.uuid

    @property
    def metadata(self):
        """Return the metadata passed when launching this process.

        :return: metadata dictionary, or an empty ``AttributeDict`` if the inputs have no metadata
        """
        try:
            return self.inputs.metadata
        except AttributeError:
            return AttributeDict()

    @property
    def options(self):
        """Return the options of the metadata passed when launching this process.

        :return: options dictionary, or an empty ``AttributeDict`` if the metadata has no options
        """
        try:
            return self.metadata.options
        except AttributeError:
            return AttributeDict()

    def _save_checkpoint(self):
        """
        Save the current state in a checkpoint if persistence is enabled and the process state is not terminal.

        If the persistence call excepts with a PersistenceError, it will be caught and the exception will be logged.
        """
        if self._enable_persistence and not self._state.is_terminal():
            try:
                self.runner.persister.save_checkpoint(self)
            except plumpy.PersistenceError:
                self.logger.exception("Exception trying to save checkpoint, this means you will "
                                      "not be able to restart in case of a crash until the next successful checkpoint.")

    @override
    def save_instance_state(self, out_state, save_context):
        """Save the instance state and record the pid of this process under the ``CALC_ID`` key.

        :param out_state: the bundle to write the state into
        :param save_context: the save context
        """
        super(Process, self).save_instance_state(out_state, save_context)

        if self.metadata.store_provenance:
            assert self.node.is_stored

        out_state[self.SaveKeys.CALC_ID.value] = self.pid

    @override
    def load_instance_state(self, saved_state, load_context):
        """Restore the process from a saved state.

        The runner is taken from ``load_context`` if present, otherwise from the global manager. If the saved state
        contains a ``CALC_ID``, the existing node is reloaded; otherwise a new database record is created.

        :param saved_state: the saved instance state bundle
        :param load_context: the load context
        """
        from aiida.manage import manager

        if 'runner' in load_context:
            self._runner = load_context.runner
        else:
            self._runner = manager.get_manager().get_runner()

        load_context = load_context.copyextend(loop=self._runner.loop, communicator=self._runner.communicator)
        super(Process, self).load_instance_state(saved_state, load_context)

        if self.SaveKeys.CALC_ID.value in saved_state:
            self._node = orm.load_node(saved_state[self.SaveKeys.CALC_ID.value])
            self._pid = self.node.pk
        else:
            self._pid = self._create_and_setup_db_record()

        self.node.logger.info('Loaded process<{}> from saved state'.format(self.node.pk))

    def kill(self, msg=None):
        """
        Kill the process and all the child processes it called.

        Children are only killed if this process itself could be killed, i.e. the parent ``kill`` did not return
        ``False`` and the process had not already terminated.

        :param msg: optional message stating why the process is killed
        :return: the result of killing this process, or a gathered future over all pending kill requests (of this
            process and of its children) if there are any
        """
        self.node.logger.debug('Request to kill Process<{}>'.format(self.node.pk))

        had_been_terminated = self.has_terminated()

        result = super(Process, self).kill(msg)

        # Only kill children if we could be killed ourselves
        if result is not False and not had_been_terminated:
            killing = []
            for child in self.node.called:
                try:
                    # Use a separate name so our own kill ``result`` is not overwritten by the children's
                    child_result = self.runner.controller.kill_process(
                        child.pk, 'Killed by parent<{}>'.format(self.node.pk))
                except ConnectionClosed:
                    self.logger.info('no connection available to kill child<%s>', child.pk)
                else:
                    if isinstance(child_result, plumpy.Future):
                        killing.append(child_result)

            if isinstance(result, plumpy.Future):
                # We ourselves are waiting to be killed so add it to the list
                killing.append(result)

            if killing:
                # We are waiting for things to be killed, so return the 'gathered' future
                result = plumpy.gather(*killing)

        return result

    @override
    def out(self, output_port, value=None):
        """Emit an output value on the given port.

        If ``value`` is omitted, ``output_port`` is taken to be the value and it is emitted on the port named
        ``SINGLE_OUTPUT_LINKNAME``.

        :param output_port: the output port name, or the value itself when ``value`` is ``None``
        :param value: the value to emit
        """
        if value is None:
            # In this case assume that output_port is the actual value and there is just one return value
            value = output_port
            output_port = self.SINGLE_OUTPUT_LINKNAME

        return super(Process, self).out(output_port, value)

    def out_many(self, out_dict):
        """
        Add all values given in ``out_dict`` to the outputs. The keys of the dictionary will be used as output names.
        """
        for key, value in out_dict.items():
            self.out(key, value)

    # region Process event hooks

    def on_create(self):
        """Determine the parent pid if not supplied and create the database record of this process."""
        super(Process, self).on_create()
        # If parent PID hasn't been supplied try to get it from the stack
        if self._parent_pid is None:
            current = Process.current()
            if isinstance(current, Process):
                self._parent_pid = current.pid
        self._pid = self._create_and_setup_db_record()

    @override
    def on_entering(self, state):
        super(Process, self).on_entering(state)

    def on_entered(self, from_state):
        """Called after a new state has been entered.

        Updates the node attributes, saves a checkpoint and records the process state change timestamp.

        :param from_state: the previous state
        """
        # Update the node attributes every time we enter a new state
        self.update_node_state(self._state)
        self._save_checkpoint()
        # Update the latest process state change timestamp
        utils.set_process_state_change_timestamp(self)
        super(Process, self).on_entered(from_state)

    @override
    def on_terminated(self):
        """
        Called when a Process enters a terminal state.

        Deletes the checkpoint (if persistence is enabled) and seals the node.
        """
        super(Process, self).on_terminated()
        if self._enable_persistence:
            try:
                self.runner.persister.delete_checkpoint(self.pid)
            except BaseException:
                self.logger.exception('Failed to delete checkpoint')

        try:
            self.node.seal()
        except exceptions.ModificationNotAllowed:
            pass

    @override
    def on_except(self, exc_info):
        """
        Log the exception by calling the report method with formatted stack trace from exception info object
        and store the exception string as a node attribute

        :param exc_info: the sys.exc_info() object
        """
        super(Process, self).on_except(exc_info)
        # Only the exception type and message are stored on the node; the full traceback goes to the report
        self.node.set_exception(''.join(traceback.format_exception(exc_info[0], exc_info[1], None)))
        self.report(''.join(traceback.format_exception(*exc_info)))

    @override
    def on_finish(self, result, successful):
        """
        Set the finish status on the process node

        :param result: the result of the process: ``None``, an integer exit status or an ``ExitCode``
        :param successful: whether the process finished successfully
        :raises ValueError: if ``result`` is of any other type
        """
        super(Process, self).on_finish(result, successful)

        if result is None or isinstance(result, int):
            self.node.set_exit_status(result)
        elif isinstance(result, ExitCode):
            self.node.set_exit_status(result.status)
            self.node.set_exit_message(result.message)
        else:
            raise ValueError('the result should be an integer, ExitCode or None, got {} {} {}'.format(
                type(result), result, self.pid))

    @override
    def on_paused(self, msg=None):
        """
        The Process was paused so set the paused attribute on the process node
        """
        super(Process, self).on_paused(msg)
        self._save_checkpoint()
        self.node.pause()

    @override
    def on_playing(self):
        """
        The Process was unpaused so remove the paused attribute on the process node
        """
        super(Process, self).on_playing()
        self.node.unpause()

    @override
    def on_output_emitting(self, output_port, value):
        """
        The process has emitted a value on the given output port.

        :param output_port: The output port name the value was emitted on
        :param value: The value emitted
        :raises TypeError: if ``value`` is not an instance of ``orm.Data``
        """
        super(Process, self).on_output_emitting(output_port, value)

        if not isinstance(value, orm.Data):
            raise TypeError('Values output from process must be instances of AiiDA orm.Data types, got {}'.format(
                value.__class__))

    # end region

    def set_status(self, status):
        """
        The status of the Process is about to be changed, so we reflect this in the node's attribute proxy.

        :param status: the status message
        """
        super(Process, self).set_status(status)
        self.node.set_process_status(status)

    def submit(self, process, *args, **kwargs):
        """Submit ``process`` through the runner of this process and return the runner's result."""
        return self.runner.submit(process, *args, **kwargs)

    @property
    def runner(self):
        """Return the runner of this process."""
        return self._runner

    @protected
    def get_parent_calc(self):
        """
        Get the parent process node

        :return: the parent process node if there is one, otherwise ``None``
        :rtype: :class:`aiida.orm.nodes.process.process.ProcessNode`
        """
        # Can't get it if we don't know our parent
        if self._parent_pid is None:
            return None

        return orm.load_node(pk=self._parent_pid)

    @classmethod
    def build_process_type(cls):
        """
        The process type.

        :return: string of the process type
        :rtype: str

        Note: This could be made into a property 'process_type' but in order to have it be a property of the class
        it would need to be defined in the metaclass, see https://bugs.python.org/issue20659
        """
        from aiida.plugins.entry_point import get_entry_point_string_from_class

        class_module = cls.__module__
        class_name = cls.__name__

        # If the process is a registered plugin the corresponding entry point will be used as process type
        process_type = get_entry_point_string_from_class(class_module, class_name)

        # If no entry point was found, default to fully qualified path name
        if process_type is None:
            return '{}.{}'.format(class_module, class_name)

        return process_type

    @protected
    def report(self, msg, *args, **kwargs):
        """Log a message to the logger, which should get saved to the database through the attached DbLogHandler.

        The pk, class name and function name of the caller are prepended to the given message

        :param msg: message to log
        :param args: args to pass to the log call
        :param kwargs: kwargs to pass to the log call
        """
        message = '[{}|{}|{}]: {}'.format(self.node.pk, self.__class__.__name__, inspect.stack()[1][3], msg)
        self.logger.log(LOG_LEVEL_REPORT, message, *args, **kwargs)

    def _create_and_setup_db_record(self):
        """
        Create and setup the database record for this process

        :return: the pk of the process node if it has one, otherwise a ``uuid.UUID`` built from the node's uuid
        """
        self._node = self.get_or_create_db_record()
        self._setup_db_record()
        if self.metadata.store_provenance:
            try:
                self.node.store_all()
                # NOTE(review): the node can presumably only be finished here if storing reused an existing
                # (e.g. cached) node; in that case its outputs are re-emitted on this process -- confirm
                if self.node.is_finished_ok:
                    self._state = ProcessState.FINISHED
                    for entry in self.node.get_outgoing(link_type=LinkType.RETURN):
                        if entry.link_label.endswith('_{pk}'.format(pk=entry.node.pk)):
                            continue
                        self.out(entry.link_label, entry.node)
                    # This is needed for CalcJob. In that case, the outputs are
                    # returned regardless of whether they end in '_<pk>'
                    for entry in self.node.get_outgoing(link_type=LinkType.CREATE):
                        self.out(entry.link_label, entry.node)
            except exceptions.ModificationNotAllowed:
                # The calculation was already stored
                pass

        if self.node.pk is not None:
            return self.node.pk

        return uuid.UUID(self.node.uuid)

    def update_node_state(self, state):
        """Attach any new outputs and set the process state label of ``state`` on the node."""
        self.update_outputs()
        self.node.set_process_state(state.LABEL)

    def update_outputs(self):
        """Attach any new outputs to the node since the last time this was called, if store provenance is True."""
        if self.metadata.store_provenance is False:
            return

        outputs_stored = self.node.get_outgoing(link_type=(LinkType.CREATE, LinkType.RETURN)).all_link_labels()
        outputs_new = set(self.outputs.keys()) - set(outputs_stored)

        for link_label in outputs_new:

            output = self.outputs[link_label]

            if isinstance(self.node, CalculationNode):
                output.add_incoming(self.node, LinkType.CREATE, link_label)
            elif isinstance(self.node, WorkflowNode):
                output.add_incoming(self.node, LinkType.RETURN, link_label)

            output.store()

    def _setup_db_record(self):
        """
        Create the database record for this process and the links with respect to its inputs

        This function will set various attributes on the node that serve as a proxy for attributes of the Process.
        This is essential as otherwise this information could only be introspected through the Process itself, which
        is only available to the interpreter that has it in memory. To make this data introspectable from any
        interpreter, for example for the command line interface, certain Process attributes are proxied through the
        calculation node.

        In addition, the parent calculation will be setup with a CALL link if applicable and all inputs will be
        linked up as well.

        :raises InvalidOperation: if the parent process node is a calculation node
        """
        assert self.inputs is not None
        assert not self.node.is_sealed, 'process node cannot be sealed when setting up the database record'

        # Store important process attributes in the node proxy
        self.node.set_process_state(None)
        self.node.set_process_label(self.__class__.__name__)
        self.node.set_process_type(self.__class__.build_process_type())

        parent_calc = self.get_parent_calc()

        if parent_calc and self.metadata.store_provenance:

            if isinstance(parent_calc, CalculationNode):
                raise exceptions.InvalidOperation('calling processes from a calculation type process is forbidden.')

            if isinstance(self.node, CalculationNode):
                self.node.add_incoming(parent_calc, LinkType.CALL_CALC, 'CALL_CALC')

            elif isinstance(self.node, WorkflowNode):
                self.node.add_incoming(parent_calc, LinkType.CALL_WORK, 'CALL_WORK')

        self._setup_metadata()
        self._setup_inputs()

    def exposed_outputs(self, process_instance, process_class, namespace=None, agglomerate=True):
        """
        Gather the outputs which were exposed from the ``process_class`` and emitted by the specific
        ``process_instance`` in a dictionary.

        :param process_instance: the node of the called process; its RETURN links are used as the outputs
        :param process_class: the process class whose outputs were exposed
        :param namespace: Namespace in which to search for exposed outputs.
        :type namespace: str
        :param agglomerate: If set to true, all parent namespaces of the given ``namespace`` will also
            be searched for outputs. Outputs in lower-lying namespaces take precedence.
        :type agglomerate: bool
        :return: dictionary mapping the (namespaced) output names to the output nodes
        """
        namespace_separator = self.spec().namespace_separator

        output_key_map = {}
        # maps the exposed name to all outputs that belong to it
        top_namespace_map = collections.defaultdict(list)
        process_outputs_dict = {
            entry.link_label: entry.node for entry in process_instance.get_outgoing(link_type=LinkType.RETURN)
        }

        for port_name in process_outputs_dict:
            top_namespace = port_name.split(namespace_separator)[0]
            top_namespace_map[top_namespace].append(port_name)

        # Later (deeper) namespaces overwrite earlier ones, which gives lower-lying namespaces precedence
        for nspace in self._get_namespace_list(namespace=namespace, agglomerate=agglomerate):
            # only the top-level key is stored in _exposed_outputs
            for top_name in top_namespace_map:
                if top_name in self.spec()._exposed_outputs[nspace][process_class]:  # pylint: disable=protected-access
                    output_key_map[top_name] = nspace

        result = {}

        for top_name, nspace in output_key_map.items():
            # collect all outputs belonging to the given top_name
            for port_name in top_namespace_map[top_name]:
                if nspace is None:
                    result[port_name] = process_outputs_dict[port_name]
                else:
                    result[nspace + namespace_separator + port_name] = process_outputs_dict[port_name]

        return result

    @staticmethod
    def _get_namespace_list(namespace=None, agglomerate=True):
        """Get the list of namespaces to search, from the outermost to ``namespace`` itself.

        :param namespace: dot-separated namespace, or ``None`` for the top level
        :param agglomerate: if False, only ``[namespace]`` is returned; otherwise ``None`` followed by every
            dot-separated prefix of ``namespace``
        :return: list of namespaces
        """
        if not agglomerate:
            return [namespace]

        namespace_list = [None]
        if namespace is not None:
            split_ns = namespace.split('.')
            namespace_list.extend(['.'.join(split_ns[:i]) for i in range(1, len(split_ns) + 1)])
        return namespace_list
def get_query_string_from_process_type_string(process_type_string):  # pylint: disable=invalid-name
    """
    Convert the process type string of a Node into a string suitable for a 'LIKE <string>' query.

    Entry point strings (containing a ':') are kept whole; for fully qualified class paths, up to the last two
    dot-separated components are dropped. In both cases a trailing '.' is appended.

    :param process_type_string: the process type string
    :return: string that can be used to query for subclasses of the process type using 'LIKE <string>'
    """
    is_entry_point = ':' in process_type_string
    prefix = process_type_string if is_entry_point else process_type_string.rsplit('.', 2)[0]
    return prefix + '.'