Source code for aiida.work.process

# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida_core #
# For further information on the license, see the LICENSE.txt file        #
# For further information please visit http://www.aiida.net               #
###########################################################################

import collections
import uuid
from enum import Enum
import itertools

import plum.port as port
import plum.process
from plum.process_monitor import MONITOR
import plum.process_monitor

import voluptuous
from abc import ABCMeta
from aiida.common.extendeddicts import FixedFieldsAttributeDict
import aiida.common.exceptions as exceptions
from aiida.common.lang import override, protected
from aiida.common.links import LinkType
from aiida.utils.calculation import add_source_info
from aiida.work.defaults import class_loader
import aiida.work.util
from aiida.work.util import PROCESS_LABEL_ATTR, get_or_create_output_group
from aiida.orm.calculation import Calculation
from aiida.orm.data.parameter import ParameterData
from aiida import LOG_LEVEL_REPORT



class DictSchema(object):
    def __init__(self, schema):
        self._schema = voluptuous.Schema(schema)

    def __call__(self, value):
        """
        Call this to validate the value against the schema.

        :param value: a regular dictionary or a ParameterData instance 
        :return: tuple (success, msg).  success is True if the value is valid
            and False otherwise, in which case msg will contain information about
            the validation failure.
        :rtype: tuple
        """
        try:
            if isinstance(value, ParameterData):
                value = value.get_dict()
            self._schema(value)
            return True, None
        except voluptuous.Invalid as e:
            return False, str(e)

    def get_template(self):
        return self._get_template(self._schema.schema)

    def _get_template(self, dict):
        template = type(
            "{}Inputs".format(self.__class__.__name__),
            (FixedFieldsAttributeDict,),
            {'_valid_fields': dict.keys()})()

        for key, value in dict.iteritems():
            if isinstance(key, (voluptuous.Optional, voluptuous.Required)):
                if key.default is not voluptuous.UNDEFINED:
                    template[key.schema] = key.default
                else:
                    template[key.schema] = None
            if isinstance(value, collections.Mapping):
                template[key] = self._get_template(value)
        return template


class ProcessSpec(plum.process.ProcessSpec):
    def __init__(self):
        super(ProcessSpec, self).__init__()
        self._fastforwardable = False

    def is_fastforwardable(self):
        return self._fastforwardable

    def fastforwardable(self):
        self._fastforwardable = True

    def get_inputs_template(self):
        """
        Get an object that represents a template of the known inputs and their
        defaults for the :class:`Process`.

        :return: An object with attributes that represent the known inputs for
            this process.  Default values will be filled in.
        """
        template = type(
            "{}Inputs".format(self.__class__.__name__),
            (FixedFieldsAttributeDict,),
            {'_valid_fields': self.inputs.keys()})()

        # Now fill in any default values
        for name, value_spec in self.inputs.iteritems():
            if isinstance(value_spec.validator, DictSchema):
                template[name] = value_spec.validator.get_template()
            elif value_spec.default is not None:
                template[name] = value_spec.default
            else:
                template[name] = None

        return template


[docs]class Process(plum.process.Process): """ This class represents an AiiDA process which can be executed and will have full provenance saved in the database. """ __metaclass__ = ABCMeta SINGLE_RETURN_LINKNAME = '_return'
[docs] class SaveKeys(Enum): """ Keys used to identify things in the saved instance state bundle. """ CALC_ID = 'calc_id' PARENT_CALC_PID = 'parent_calc_pid'
@classmethod def define(cls, spec): import aiida.orm super(Process, cls).define(spec) spec.input("_store_provenance", valid_type=bool, default=True, required=False) spec.input("_description", valid_type=basestring, required=False) spec.input("_label", valid_type=basestring, required=False) spec.dynamic_input(valid_type=(aiida.orm.Data, aiida.orm.Calculation)) spec.dynamic_output(valid_type=aiida.orm.Data) @classmethod def get_inputs_template(cls): return cls.spec().get_inputs_template() @classmethod def _create_default_exec_engine(cls): from aiida.work.defaults import serial_engine return serial_engine @classmethod
[docs] def create_db_record(cls): """ Create a database calculation node that represents what happened in this process. :return: """ from aiida.orm.calculation.work import WorkCalculation calc = WorkCalculation() return calc
_spec_type = ProcessSpec def __init__(self): super(Process, self).__init__() self._calc = None self._parent_pid = None @property def calc(self): return self._calc @override def save_instance_state(self, bundle): super(Process, self).save_instance_state(bundle) if self.inputs._store_provenance: assert self.calc.is_stored bundle[self.SaveKeys.CALC_ID.value] = self.pid bundle.set_class_loader(class_loader) def run_after_queueing(self, wait_on): return self._run def get_provenance_inputs_iterator(self): return itertools.ifilter(lambda kv: not kv[0].startswith('_'), self.inputs.iteritems()) @override def out(self, output_port, value=None): if value is None: # In this case assume that output_port is the actual value and there # is just one return value return super(Process, self).out(self.SINGLE_RETURN_LINKNAME, output_port) else: return super(Process, self).out(output_port, value) # Messages ##################################################### @override def on_create(self, pid, inputs, saved_instance_state): from aiida.orm import load_node super(Process, self).on_create(pid, inputs, saved_instance_state) if saved_instance_state is None: # Get the parent from the top of the process stack try: self._parent_pid = aiida.work.util.ProcessStack.top().pid except IndexError: pass self._pid = self._create_and_setup_db_record() else: if self.SaveKeys.CALC_ID.value in saved_instance_state: self._calc = load_node(saved_instance_state[self.SaveKeys.CALC_ID.value]) self._pid = self.calc.pk else: self._pid = self._create_and_setup_db_record() if self.SaveKeys.PARENT_CALC_PID.value in saved_instance_state: self._parent_pid = saved_instance_state[ self.SaveKeys.PARENT_CALC_PID.value] if self._logger is None: self.set_logger(self.calc.logger) @override def on_start(self): super(Process, self).on_start() aiida.work.util.ProcessStack.push(self) @override def on_finish(self): super(Process, self).on_finish() self.calc.seal() @override def _on_output_emitted(self, output_port, value, dynamic): """ The process has emitted a value on the given output port. :param output_port: The output port name the value was emitted on :param value: The value emitted :param dynamic: Was the output port a dynamic one (i.e. not known beforehand?) """ from aiida.orm import Data super(Process, self)._on_output_emitted(output_port, value, dynamic) assert isinstance(value, Data), \ "Values outputted from process must be instances of AiiDA Data" \ "types. Got: {}".format(value.__class__) if not value.is_stored: value.add_link_from(self.calc, output_port, LinkType.CREATE) if self.inputs._store_provenance: value.store() value.add_link_from(self.calc, output_port, LinkType.RETURN) ################################################################# @override def do_run(self): # Exclude all private inputs ins = {k: v for k, v in self.inputs.iteritems() if not k.startswith('_')} return self._run(**ins) @protected def get_parent_calc(self): from aiida.orm import load_node # Can't get it if we don't know our parent if self._parent_pid is None: return None # First try and get the process from the registry in case it is running try: return MONITOR.get_process(self._parent_pid).calc except ValueError: pass # Ok, maybe the pid is actually a pk... try: return load_node(pk=self._parent_pid) except exceptions.NotExistent: pass # Out of options return None @protected
[docs] def report(self, msg, *args, **kwargs): """ """ self.logger.log(LOG_LEVEL_REPORT, msg, *args, **kwargs)
# @override # def create_input_args(self, inputs): # parsed = super(Process, self).create_input_args(inputs) # # Now remove any that have a leading underscore # for name in parsed.keys(): # if name.startswith('_'): # del parsed[name] # return parsed def _create_and_setup_db_record(self): self._calc = self.create_db_record() self._setup_db_record() if self.inputs._store_provenance: self.calc.store_all() if self.calc.pk is not None: return self.calc.pk else: return uuid.UUID(self.calc.uuid) def _setup_db_record(self): assert self.inputs is not None assert not self.calc.is_sealed, \ "Calculation cannot be sealed when setting up the database record" # Save the name of this process self.calc._set_attr(PROCESS_LABEL_ATTR, self.__class__.__name__) parent_calc = self.get_parent_calc() # First get a dictionary of all the inputs to link, this is needed to # deal with things like input groups to_link = {} for name, input in self.inputs.iteritems(): # Ignore all inputs starting with a leading underscore if name.startswith('_'): continue if self.spec().has_input(name): if isinstance(self.spec().get_input(name), port.InputGroupPort): to_link.update( {"{}_{}".format(name, k): v for k, v in input.iteritems()}) else: to_link[name] = input else: # It's not in the spec, so we better support dynamic inputs assert self.spec().has_dynamic_input() to_link[name] = input for name, input in to_link.iteritems(): if isinstance(input, Calculation): input = get_or_create_output_group(input) if not input.is_stored: # If the input isn't stored then assume our parent created it if parent_calc: input.add_link_from(parent_calc, "CREATE", link_type=LinkType.CREATE) if self.inputs._store_provenance: input.store() self.calc.add_link_from(input, name) if parent_calc: self.calc.add_link_from(parent_calc, "CALL", link_type=LinkType.CALL) if self.raw_inputs: if '_description' in self.raw_inputs: self.calc.description = self.raw_inputs._description if '_label' in self.raw_inputs: self.calc.label = self.raw_inputs._label def _can_fast_forward(self, inputs): return False def _fast_forward(self): node = None # Here we should find the old node for k, v in node.get_output_dict(): self.out(k, v)
class FunctionProcess(Process): _func_args = None @staticmethod def _func(*args, **kwargs): """ This is used internally to store the actual function that is being wrapped and will be replaced by the build method. """ return {} @staticmethod def build(func, **kwargs): """ Build a Process from the given function. All function arguments will be assigned as process inputs. If keyword arguments are specified then these will also become inputs. :param func: The function to build a process from :param kwargs: Optional keyword arguments that will become additional inputs to the process :return: A Process class that represents the function :rtype: :class:`Process` """ import inspect from aiida.orm.data import Data args, varargs, keywords, defaults = inspect.getargspec(func) def _define(cls, spec): super(FunctionProcess, cls).define(spec) for i in range(len(args)): default = None if defaults and len(defaults) - len(args) + i >= 0: default = defaults[i] spec.input(args[i], valid_type=Data, default=default) # Make sure to get rid of the argument from the keywords dict kwargs.pop(args[i], None) for k, v in kwargs.iteritems(): spec.input(k) # If the function support kwargs then allow dynamic inputs, # otherwise disallow if keywords is not None: spec.dynamic_input() else: spec.no_dynamic_input() # We don't know what a function will return so keep it dynamic spec.dynamic_output(valid_type=Data) return type(func.__name__, (FunctionProcess,), {'_func': staticmethod(func), Process.define.__name__: classmethod(_define), '_func_args': args}) @classmethod def args_to_dict(cls, *args): """ Create an input dictionary (i.e. label: value) from supplied args. :param args: The values to use :return: A label: value dictionary """ assert (len(args) == len(cls._func_args)) return dict(zip(cls._func_args, args)) @override def _setup_db_record(self): super(FunctionProcess, self)._setup_db_record() add_source_info(self.calc, self._func) # Save the name of the function self.calc._set_attr(PROCESS_LABEL_ATTR, self._func.__name__) @override def _run(self, **kwargs): from aiida.orm.data import Data args = [] for arg in self._func_args: args.append(kwargs.pop(arg)) outs = self._func(*args, **kwargs) if outs is not None: if isinstance(outs, Data): self.out(self.SINGLE_RETURN_LINKNAME, outs) elif isinstance(outs, collections.Mapping): for name, value in outs.iteritems(): self.out(name, value) else: raise TypeError( "Workfunction returned unsupported type '{}'\n" "Must be a Data type or a Mapping of string => Data". format(outs.__class__)) class _ProcessFinaliser(plum.process_monitor.ProcessMonitorListener): """ Take care of finalising a process when it finishes either through successful completion or because of a failure caused by an exception. """ def __init__(self): MONITOR.add_monitor_listener(self) @override def on_monitored_process_destroying(self, process): aiida.work.util.ProcessStack.pop(process) @override def on_monitored_process_failed(self, pid): from aiida.orm import load_node try: calc_node = load_node(pk=pid) except ValueError: pass else: calc_node.seal() aiida.work.util.ProcessStack.pop(pid=pid) # Have a global singleton to take care of finalising all processes _finaliser = _ProcessFinaliser()