# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida_core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
from abc import ABCMeta, abstractmethod
import inspect
from enum import Enum
from aiida.work.defaults import registry
from aiida.work.run import RunningType, RunningInfo
from aiida.work.process import Process, ProcessSpec
from aiida.work.legacy.wait_on import WaitOnWorkflow
from aiida.common.lang import override
from aiida.common.utils import get_class_string, get_object_string, \
get_object_from_string
from aiida.orm import load_node, load_workflow, Node
from aiida.utils.serialize import serialize_data, deserialize_data
from plum.wait_ons import Checkpoint, WaitOnAll, WaitOnProcess
from plum.wait import WaitOn
from plum.persistence.bundle import Bundle
from collections import namedtuple
from aiida.work.interstep import *
from plum.engine.execution_engine import Future
class _WorkChainSpec(ProcessSpec):
    """
    Process specification of a work chain.

    Extends ``ProcessSpec`` with an *outline*: the block of instructions
    that the work chain steps through when it is run.
    """

    def __init__(self):
        super(_WorkChainSpec, self).__init__()
        # Stays None until outline() has been called
        self._outline = None

    def get_description(self):
        """
        Get the description of the parent spec, followed by an "Outline"
        section when an outline has been defined.

        :return: The description
        :rtype: str
        """
        desc = [super(_WorkChainSpec, self).get_description()]
        if self._outline:
            desc.append("Outline")
            desc.append("=======")
            desc.append(self._outline.get_description())
        return "\n".join(desc)

    def outline(self, *commands):
        """
        Define the outline that describes this work chain.

        :param commands: One or more functions that make up this work chain.
        """
        # ``commands`` is the varargs tuple and therefore can never itself be
        # an _Instruction, so it is always wrapped in a block.
        self._outline = _Block(commands)

    def get_outline(self):
        """
        Get the outline that was set through :meth:`outline`.

        :return: The outline block, or None if no outline was defined
        """
        return self._outline
class WorkChain(Process):
    """
    A WorkChain, the base class for AiiDA workflows.

    A work chain steps through the outline defined in its spec. Between
    steps it can wait on barriers and run intersteps, and it keeps a
    context (``self.ctx``) in which steps can store values.
    """
    _spec_type = _WorkChainSpec

    # Keys used in the saved instance state
    _CONTEXT = 'context'
    _STEPPER_STATE = 'stepper_state'
    _BARRIERS = 'barriers'
    _INTERSTEPS = 'intersteps'
    _ABORTED = 'aborted'

    @classmethod
    def define(cls, spec):
        super(WorkChain, cls).define(spec)
        # For now workchains can accept any input and emit any output
        # If this changes in the future the spec should be updated here.
        spec.dynamic_input()
        spec.dynamic_output()

    class Context(object):
        """
        Container for the variables of a work chain. The values live in an
        internal dictionary and can be reached both as items
        (``ctx['key']``) and as attributes (``ctx.key``).
        """

        def __init__(self, value=None):
            """
            :param value: Optional mapping whose items are copied into the
                context
            """
            # Have to do it this way otherwise our setattr will be called
            # causing infinite recursion.
            # See http://rafekettler.com/magicmethods.html
            super(WorkChain.Context, self).__setattr__('_content', {})
            if value is not None:
                for k, v in value.iteritems():
                    self._content[k] = v

        def _get_dict(self):
            """Return the internal dictionary itself (not a copy)."""
            return self._content

        def __getitem__(self, item):
            return self._content[item]

        def __setitem__(self, key, value):
            self._content[key] = value

        def __delitem__(self, key):
            del self._content[key]

        def __getattr__(self, name):
            # Only called when normal attribute lookup fails, so methods and
            # '_content' itself are never looked up in the dictionary.
            try:
                return self._content[name]
            except KeyError:
                raise AttributeError("Context does not have a variable {}"
                                     .format(name))

        def __delattr__(self, item):
            del self._content[item]

        def __setattr__(self, name, value):
            # Every attribute assignment is stored as a context variable
            self._content[name] = value

        def __dir__(self):
            return sorted(self._content.keys())

        def __iter__(self):
            for k in self._content:
                yield k

        def get(self, key, default=None):
            """Return the value stored under ``key`` or ``default``."""
            return self._content.get(key, default)

        def setdefault(self, key, default=None):
            """Same as ``dict.setdefault`` on the internal dictionary."""
            return self._content.setdefault(key, default)

        def save_instance_state(self, out_state):
            """
            Write a serialized copy of every context variable to
            ``out_state``. Values that are unstored nodes are stored first.

            :param out_state: The mapping to write to
            """
            for k, v in self._content.iteritems():
                if isinstance(v, Node) and not v.is_stored:
                    v.store()
                out_state[k] = serialize_data(v)

    def __init__(self):
        super(WorkChain, self).__init__()
        # The context and stepper are created in on_create() and _run()
        self._context = None
        self._stepper = None
        self._barriers = []
        self._intersteps = []

    @property
    def ctx(self):
        """The context of this work chain."""
        return self._context

    @override
    def save_instance_state(self, out_state):
        """
        Save the context, intersteps, barriers, stepper position and the
        aborted flag to ``out_state``.
        """
        super(WorkChain, self).save_instance_state(out_state)

        # Ask the context to save itself
        bundle = Bundle()
        self.ctx.save_instance_state(bundle)
        out_state[self._CONTEXT] = bundle

        # Save intersteps
        for interstep in self._intersteps:
            bundle = Bundle()
            interstep.save_instance_state(bundle)
            out_state.setdefault(self._INTERSTEPS, []).append(bundle)

        # Save barriers
        for barrier in self._barriers:
            bundle = Bundle()
            barrier.save_instance_state(bundle)
            out_state.setdefault(self._BARRIERS, []).append(bundle)

        # Ask the stepper to save itself
        if self._stepper is not None:
            bundle = Bundle()
            self._stepper.save_position(bundle)
            out_state[self._STEPPER_STATE] = bundle

        out_state[self._ABORTED] = self._aborted

    def insert_barrier(self, wait_on):
        """
        Insert a barrier that will cause the workchain to wait until the wait
        on is finished before continuing to the next step.

        :param wait_on: The thing to wait on (of type plum.wait.wait_on)
        """
        self._barriers.append(wait_on)

    def remove_barrier(self, wait_on):
        """
        Remove a barrier.

        Precondition: must be a barrier that was previously inserted

        :param wait_on: The wait on to remove (of type plum.wait.wait_on)
        :raises ValueError: if ``wait_on`` is not one of the barriers
        """
        # The barriers are kept in a list, so the wait on has to be removed by
        # value; ``del self._barriers[wait_on]`` would use it as an index.
        self._barriers.remove(wait_on)

    def insert_intersteps(self, intersteps):
        """
        Insert one or more intersteps to be executed after the current step
        ends but before the next step ends.

        :param intersteps: A single :class:`Interstep` or a list of them
        """
        if not isinstance(intersteps, list):
            intersteps = [intersteps]

        for interstep in intersteps:
            self._intersteps.append(interstep)

    def to_context(self, **kwargs):
        """
        This is a convenience method that provides syntactic sugar, for
        a user to add multiple intersteps that will assign a certain value
        to the corresponding key in the context of the workchain

        :param kwargs: Context key -> value pairs; values that are not an
            ``UpdateContextBuilder`` are wrapped with ``assign_``
        """
        intersteps = []
        for key, value in kwargs.iteritems():
            if not isinstance(value, UpdateContextBuilder):
                value = assign_(value)
            interstep = value.build(key)
            intersteps.append(interstep)

        self.insert_intersteps(intersteps)

    @override
    def _run(self, **kwargs):
        """Create a stepper for the outline and execute the first step."""
        self._stepper = self.spec().get_outline().create_stepper(self)
        return self._do_step()

    @property
    def _do_abort(self):
        """The DO_ABORT_KEY attribute of the calc node, False if not set."""
        return self.calc.get_attr(self.calc.DO_ABORT_KEY, False)

    @property
    def _aborted(self):
        """The ABORTED_KEY attribute of the calc node, False if not set."""
        return self.calc.get_attr(self.calc.ABORTED_KEY, False)

    @_aborted.setter
    def _aborted(self, value):
        # One is not allowed to unabort an aborted WorkChain. Truthiness is
        # used so that falsy non-bool values such as None cannot unabort.
        if self._aborted and not value:
            self.logger.warning('trying to unset the abort flag on an already aborted workchain which is not allowed')
            return

        self.calc._set_attr(self.calc.ABORTED_KEY, value)

    def _do_step(self, wait_on=None):
        """
        Execute the next step of the outline.

        :param wait_on: Not used by this method
        :return: A wait on that resumes in this method when the outline has
            not finished yet, otherwise None
        :raises TypeError: if a step returns something other than None, an
            Interstep or a list of Intersteps
        """
        self._handle_do_abort()
        if self._aborted:
            return

        # Give the intersteps of the previous step their turn, then reset
        for interstep in self._intersteps:
            interstep.on_next_step_starting(self)
        self._intersteps = []
        self._barriers = []

        try:
            finished, retval = self._stepper.step()
        except _PropagateReturn:
            finished, retval = True, None

        # Could have aborted during the step
        self._handle_do_abort()
        if self._aborted:
            return

        if not finished:
            if retval is not None:
                is_interstep_list = isinstance(retval, list) and \
                    all(isinstance(interstep, Interstep) for interstep in retval)
                if is_interstep_list or isinstance(retval, Interstep):
                    self.insert_intersteps(retval)
                else:
                    raise TypeError(
                        "Invalid value returned from step '{}'".format(retval))

            for interstep in self._intersteps:
                interstep.on_last_step_finished(self)

            if self._barriers:
                return WaitOnAll(self._do_step.__name__, self._barriers)
            else:
                return Checkpoint(self._do_step.__name__)

    @override
    def on_create(self, pid, inputs, saved_state):
        """
        Create a fresh context or, when ``saved_state`` is given, restore the
        context, stepper, intersteps, barriers and aborted flag from it.
        """
        super(WorkChain, self).on_create(pid, inputs, saved_state)

        if saved_state is None:
            self._context = self.Context()
        else:
            # Recreate the context
            self._context = self.Context(deserialize_data(saved_state[self._CONTEXT]))

            # Recreate the stepper
            if self._STEPPER_STATE in saved_state:
                self._stepper = self.spec().get_outline().create_stepper(self)
                self._stepper.load_position(
                    saved_state[self._STEPPER_STATE])

            try:
                self._intersteps = [load_with_classloader(b) for
                                    b in saved_state[self._INTERSTEPS]]
            except KeyError:
                self._intersteps = []

            try:
                self._barriers = [WaitOn.create_from(b) for
                                  b in saved_state[self._BARRIERS]]
            except KeyError:
                # No barriers were saved: keep the current (empty) list
                pass

            self._aborted = saved_state[self._ABORTED]

    def _handle_do_abort(self):
        """
        Check whether a request to abort has been registered, by checking whether the DO_ABORT_KEY
        attribute has been set, and if so call self.abort and remove the DO_ABORT_KEY attribute
        """
        do_abort = self._do_abort
        if do_abort:
            # The value of the attribute is used as the abort message
            self.abort(do_abort)
            self.calc._del_attr(self.calc.DO_ABORT_KEY)

    def abort_nowait(self, msg=None):
        """
        Abort the workchain: add the abort message to the report, set the
        aborted flag and stop the process.

        :param msg: The abort message
        :type msg: str
        """
        self.report("Aborting: {}".format(msg))
        self._aborted = True
        self.stop()

    def abort(self, msg=None, timeout=None):
        """
        Abort the workchain: add the abort message to the report, set the
        aborted flag and stop the process.

        :param msg: The abort message
        :type msg: str
        :param timeout: Accepted but not used by this implementation
        :type timeout: float
        :return: None
        """
        self.report("Aborting: {}".format(msg))
        self._aborted = True
        self.stop()
def ToContext(**kwargs):
    """
    Build one UpdateContext interstep per keyword argument and return them
    as a list.

    NOTE: This mirrors the WorkChain.to_context method and is kept for
    backwards compatibility, but should eventually be deprecated.

    :param kwargs: Context key -> value pairs; values that are not an
        ``UpdateContextBuilder`` are wrapped with ``assign_``
    :return: The list of intersteps
    """
    def _as_builder(candidate):
        # Anything that is not yet a builder is turned into an assignment
        if isinstance(candidate, UpdateContextBuilder):
            return candidate
        return assign_(candidate)

    return [_as_builder(target).build(name)
            for name, target in kwargs.iteritems()]
class _InterstepFactory(object):
    """
    Factory that recreates an Interstep from a bundle, selecting the type
    by the class string that was written to the bundle.
    """

    def create(self, bundle):
        """
        Recreate the interstep stored in ``bundle``.

        :param bundle: The bundle holding the class string and its data
        :raises ValueError: if the class string in the bundle is not known
        """
        class_string = bundle[Bundle.CLASS]
        if class_string != get_class_string(ToContext):
            raise ValueError(
                "Unknown interstep class type '{}'".format(class_string))

        # NOTE(review): ToContext is a plain function in this module and no
        # TO_ASSIGN attribute is set on it here - confirm that this branch
        # can still be reached.
        return ToContext(**bundle[ToContext.TO_ASSIGN])
# Module-level instance of the interstep factory
_INTERSTEP_FACTORY = _InterstepFactory()
class Stepper(object):
    """
    Abstract interface of the objects that walk through the instructions of
    a work chain outline one step at a time.
    """
    __metaclass__ = ABCMeta

    def __init__(self, workflow):
        """
        :param workflow: The work chain instance this stepper acts on
        """
        self._workflow = workflow

    @abstractmethod
    def step(self):
        """
        Execute one step of the instructions.

        :return: A 2-tuple with entries
            0. True if the stepper has finished, False otherwise
            1. The return value from the executed step
        :rtype: tuple
        """

    @abstractmethod
    def save_position(self, out_position):
        """
        Write the current position of the stepper to ``out_position``.

        :param out_position: The mapping to write the position to
        """

    @abstractmethod
    def load_position(self, bundle):
        """
        Restore the position of the stepper from ``bundle``.

        :param bundle: A mapping as written by :meth:`save_position`
        """
[docs]class _Instruction(object):
"""
This class represents an instruction in a a workchain. To step through the
step you need to get a stepper by calling ``create_stepper()`` from which
you can call the :class:`~Stepper.step()` method.
"""
__metaclass__ = ABCMeta
[docs] @abstractmethod
def create_stepper(self, workflow):
pass
[docs] def __str__(self):
return self.get_description()
[docs] @abstractmethod
def get_description(self):
"""
Get a text description of these instructions.
:return: The description
:rtype: str
"""
pass
[docs] @staticmethod
def check_command(command):
if not isinstance(command, _Instruction):
assert issubclass(command.im_class, Process)
args = inspect.getargspec(command)[0]
assert len(args) == 1, "Instruction must take one argument only: self"
class _Block(_Instruction):
    """
    A sequence of commands that are executed one after the other. A command
    is either another instruction or a method of the work chain.
    """

    class Stepper(Stepper):
        """Stepper that walks through the commands of a block in order."""
        _POSITION = 'pos'
        _STEPPER_POS = 'stepper_pos'

        def __init__(self, workflow, commands):
            super(_Block.Stepper, self).__init__(workflow)
            for candidate in commands:
                _Instruction.check_command(candidate)
            self._commands = commands
            # Stepper of the command at _pos when that is an instruction
            self._current_stepper = None
            self._pos = 0

        def step(self):
            """
            Execute one step of the command at the current position and move
            on to the next command once the current one has finished.

            :return: (True if the whole block has finished, return value of
                the executed step)
            """
            n_commands = len(self._commands)
            assert self._pos != n_commands, \
                "Can't call step after the block is finished"

            current = self._commands[self._pos]

            if self._current_stepper is None and isinstance(current, _Instruction):
                self._current_stepper = current.create_stepper(self._workflow)

            if self._current_stepper is None:
                # A plain method: call it directly, it finishes in one step
                finished, retval = True, current(self._workflow)
            else:
                # An instruction: delegate to its own stepper
                finished, retval = self._current_stepper.step()

            if finished:
                self._pos += 1
                self._current_stepper = None

            return self._pos == n_commands, retval

        def save_position(self, out_position):
            """Save the position and, if present, that of the sub stepper."""
            out_position[self._POSITION] = self._pos

            # Only instructions have a stepper of their own to be saved
            if self._current_stepper is None:
                return

            nested = Bundle()
            self._current_stepper.save_position(nested)
            out_position[self._STEPPER_POS] = nested

        def load_position(self, bundle):
            """Load the position and, if present, that of the sub stepper."""
            self._pos = bundle[self._POSITION]

            # Without a saved stepper position there is nothing else to load
            if self._STEPPER_POS not in bundle:
                return

            instruction = self._commands[self._pos]
            self._current_stepper = instruction.create_stepper(self._workflow)
            self._current_stepper.load_position(bundle[self._STEPPER_POS])

    def __init__(self, commands):
        """
        :param commands: The instructions and/or methods of this block
        :raises ValueError: if a command is neither an instruction nor a
            method
        """
        for command in commands:
            if isinstance(command, _Instruction) or inspect.ismethod(command):
                continue
            raise ValueError(
                "Workflow commands {} is not a class method.".
                format(command))
        self._commands = commands

    @override
    def create_stepper(self, workflow):
        return self.Stepper(workflow, self._commands)

    @override
    def get_description(self, indent_level=0, indent_increment=4):
        """
        Get a text description of the commands in this block.

        :param indent_level: Number of indentation levels put in front of
            the lines describing plain methods
        :param indent_increment: Number of spaces per indentation level
        :return: The description
        :rtype: str
        """
        prefix = ' ' * (indent_level * indent_increment)
        lines = []
        for command in self._commands:
            if isinstance(command, _Instruction):
                lines.append(command.get_description())
                continue

            lines.append('{}* {}'.format(prefix, command.__name__))
            if command.__doc__:
                lines.append('{}{}'.format(prefix, command.__doc__))

        return '\n'.join(lines)
[docs]class _Conditional(object):
"""
Object that represents some condition with the corresponding body to be
executed if the condition.
E.g. ::
if(condition):
body
or::
while(condition):
body
"""
[docs] def __init__(self, parent, condition):
self._parent = parent
self._condition = condition
self._body = None
@property
def body(self):
return self._body
@property
def condition(self):
return self._condition
[docs] def is_true(self, workflow):
return self._condition(workflow)
[docs] def __call__(self, *commands):
assert self._body is None
self._body = _Block(commands)
return self._parent
class _If(_Instruction):
    """
    An if / elif / else instruction: the body of the first conditional whose
    condition is true is executed.
    """

    class Stepper(Stepper):
        """Stepper that picks the matching branch and steps through it."""
        _POSITION = 'pos'
        _STEPPER_POS = 'stepper_pos'

        def __init__(self, workflow, if_spec):
            super(_If.Stepper, self).__init__(workflow)
            self._if_spec = if_spec
            # Index of the branch being executed, -1 when none is selected
            self._pos = -1
            self._current_stepper = None

        def step(self):
            """
            Execute one step of the selected branch, selecting the branch
            first if that has not been done yet.

            :return: (True if the instruction has finished, return value of
                the executed step)
            """
            if self._current_stepper is None:
                self._create_stepper()

            # If we can't get a stepper then no conditions match, return
            if self._current_stepper is None:
                return True, None

            finished, retval = self._current_stepper.step()
            if finished:
                self._current_stepper = None
                self._pos = -1

            return finished, retval

        def save_position(self, out_position):
            """Save the branch index and, if present, the branch position."""
            out_position[self._POSITION] = self._pos
            if self._current_stepper is not None:
                stepper_pos = Bundle()
                self._current_stepper.save_position(stepper_pos)
                out_position[self._STEPPER_POS] = stepper_pos

        def load_position(self, bundle):
            """Load the branch index and, if present, the branch position."""
            self._pos = bundle[self._POSITION]
            if self._STEPPER_POS in bundle:
                self._create_stepper()
                self._current_stepper.load_position(bundle[self._STEPPER_POS])
            else:
                self._current_stepper = None

        def _create_stepper(self):
            """
            Create the stepper of the branch at the current position or, if
            no branch is selected, of the first branch whose condition is
            true. The stepper is left as None when no condition is true.
            """
            if self._pos == -1:
                self._current_stepper = None
                # Check the conditions until we find one that is true
                for idx, condition in enumerate(self._if_spec.conditionals):
                    if condition.is_true(self._workflow):
                        stepper = condition.body.create_stepper(self._workflow)
                        self._pos = idx
                        self._current_stepper = stepper
                        return
            else:
                branch = self._if_spec.conditionals[self._pos]
                self._current_stepper = branch.body.create_stepper(self._workflow)

    def __init__(self, condition):
        """
        :param condition: The condition of the leading ``if`` branch
        """
        super(_If, self).__init__()
        self._ifs = [_Conditional(self, condition)]
        # The conditional added by else_(), needed to describe it as 'else'
        self._else = None
        self._sealed = False

    def __call__(self, *commands):
        """
        This is how the commands for the if(...) body are set

        :param commands: The commands to run on the original if.
        :return: This instance.
        """
        self._ifs[0](*commands)
        return self

    def elif_(self, condition):
        """
        Add an ``elif`` branch.

        :param condition: The condition of the new branch
        :return: The new conditional; call it with the commands of its body,
            which returns this instance again
        """
        self._ifs.append(_Conditional(self, condition))
        return self._ifs[-1]

    def else_(self, *commands):
        """
        Add the ``else`` branch; this can be done only once.

        :param commands: The commands of the else body
        :return: This instance
        """
        assert not self._sealed
        # Create a dummy conditional that always returns True
        cond = _Conditional(self, lambda wf: True)
        cond(*commands)
        self._ifs.append(cond)
        self._else = cond
        # Can't do any more after the else
        self._sealed = True
        return self

    def create_stepper(self, workflow):
        return self.Stepper(workflow, self)

    @property
    def conditionals(self):
        """The conditionals of all branches, in the order they were added."""
        return self._ifs

    @override
    def get_description(self):
        """
        Get a text description of all branches of this instruction.

        :return: The description
        :rtype: str
        """
        first = self._ifs[0]
        description = ['if {}:\n{}'.format(
            first.condition.__name__, first.body.get_description(indent_level=1))]

        for conditional in self._ifs[1:]:
            body = conditional.body.get_description(indent_level=1)
            if conditional is self._else:
                # The else condition is a dummy lambda, don't print its name
                description.append('else:\n{}'.format(body))
            else:
                description.append('elif {}:\n{}'.format(
                    conditional.condition.__name__, body))

        return '\n'.join(description)
class _While(_Conditional, _Instruction):
    """
    A while loop instruction: the body is executed repeatedly for as long as
    the condition is true.
    """

    class Stepper(Stepper):
        """Stepper that re-checks the condition each time the body ends."""
        _STEPPER_POS = 'stepper_pos'
        _CHECK_CONDITION = 'check_condition'
        _FINISHED = 'finished'

        def __init__(self, workflow, while_spec):
            super(_While.Stepper, self).__init__(workflow)
            self._spec = while_spec
            # Stepper of the loop body for the current iteration
            self._stepper = None
            # Whether the condition has to be evaluated before the next step
            self._check_condition = True
            self._finished = False

        def step(self):
            """
            Execute one step of the loop body, evaluating the condition
            first when a new iteration is about to start.

            :return: (True if the loop has finished, return value of the
                executed step)
            """
            assert not self._finished, \
                "Can't call step after the loop has finished"

            if self._check_condition is True:
                # Start of an iteration: decide whether to enter the body
                self._check_condition = False

                if not self._spec.is_true(self._workflow):
                    self._finished = True
                    return True, None

                self._stepper = self._spec.body.create_stepper(self._workflow)

            body_done, retval = self._stepper.step()
            if body_done:
                # The condition decides if there will be another iteration
                self._check_condition = True
                self._stepper = None

            return self._finished, retval

        def save_position(self, out_position):
            """Save the loop flags and, if present, the body position."""
            if self._stepper is not None:
                body_pos = Bundle()
                self._stepper.save_position(body_pos)
                out_position[self._STEPPER_POS] = body_pos

            out_position[self._CHECK_CONDITION] = self._check_condition
            out_position[self._FINISHED] = self._finished

        def load_position(self, bundle):
            """Load the loop flags and, if present, the body position."""
            if self._STEPPER_POS in bundle:
                self._stepper = self._spec.body.create_stepper(self._workflow)
                self._stepper.load_position(bundle[self._STEPPER_POS])

            self._finished = bundle[self._FINISHED]
            self._check_condition = bundle[self._CHECK_CONDITION]

        @property
        def _body_stepper(self):
            """The stepper of the loop body, created on first access."""
            if self._stepper is None:
                self._stepper = \
                    self._spec.body.create_stepper(self._workflow)
            return self._stepper

    def __init__(self, condition):
        """
        :param condition: The condition that is checked before each iteration
        """
        # The loop is its own parent, so setting the body returns the loop
        super(_While, self).__init__(self, condition)

    @override
    def create_stepper(self, workflow):
        return self.Stepper(workflow, self)

    @override
    def get_description(self):
        """
        Get a text description of the loop and its body.

        :return: The description
        :rtype: str
        """
        condition_name = self.condition.__name__
        body_description = self.body.get_description(indent_level=1)
        return "while {}:\n{}".format(condition_name, body_description)
[docs]class _PropagateReturn(BaseException):
pass
class _ReturnStepper(Stepper):
    """Stepper of the return instruction; it has no internal state."""

    def step(self):
        """
        Signal that the outline should be left.

        :raises _PropagateReturn: always
        """
        raise _PropagateReturn()

    def save_position(self, out_position):
        """
        Nothing to be done: no internal state.

        :param out_position: Left untouched
        """
        pass

    def load_position(self, bundle):
        """
        Nothing to be done: no internal state.

        :param bundle: Not read
        """
        pass
class _Return(_Instruction):
    """
    A return instruction to tell the workchain to stop stepping through the
    outline and cease execution immediately.
    """

    def create_stepper(self, workflow):
        """
        Create the stepper that raises the return signal when stepped.

        :param workflow: The work chain instance to create the stepper for
        """
        stepper = _ReturnStepper(workflow)
        return stepper

    def get_description(self):
        """
        Get a text description of these instructions.

        :return: The description
        :rtype: str
        """
        description = "Return from the outline immediately"
        return description
def if_(condition):
    """
    Create a conditional instruction for use in a workchain outline.

    Use as::

        if_(cls.conditional)(
            cls.step1,
            cls.step2
        )

    Each step can, of course, also be any valid workchain step e.g. conditional.

    :param condition: The workchain method that will return True or False
    :return: The if instruction, to be called with the steps of its body
    """
    instruction = _If(condition)
    return instruction
def while_(condition):
    """
    Create a while loop instruction for use in a workchain outline.

    Use as::

        while_(cls.conditional)(
            cls.step1,
            cls.step2
        )

    Each step can, of course, also be any valid workchain step e.g. conditional.

    :param condition: The workchain method that will return True or False
    :return: The while instruction, to be called with the steps of its body
    """
    loop = _While(condition)
    return loop
# Global singleton for return statements in workchain outlines: put
# ``return_`` in an outline to stop stepping through it at that point.
return_ = _Return()