###########################################################################
# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file        #
# For further information please visit http://www.aiida.net               #
###########################################################################
"""Class and decorators to generate processes out of simple python functions."""
from __future__ import annotations
import collections
import functools
import inspect
import itertools
import logging
import signal
import sys
import types
import typing as t
from typing import TYPE_CHECKING
import docstring_parser
from aiida.common.lang import override
from aiida.manage import get_manager
from aiida.orm import (
Bool,
CalcFunctionNode,
Data,
Dict,
Float,
Int,
List,
ProcessNode,
Str,
WorkFunctionNode,
to_aiida_type,
)
from aiida.orm.utils.mixins import FunctionCalculationMixin
from .process import Process
from .process_spec import ProcessSpec
try:
UnionType = types.UnionType
except AttributeError:
# This type is not available for Python 3.9 and older
UnionType = None # type: ignore[assignment,misc]
try:
from typing import ParamSpec
except ImportError:
# Fallback for Python 3.9 and older
from typing_extensions import ParamSpec # type: ignore[assignment]
try:
get_annotations = inspect.get_annotations
except AttributeError:
# This is the backport for Python 3.9 and older
from get_annotations import get_annotations # type: ignore[no-redef]
if TYPE_CHECKING:
from .exit_code import ExitCode
__all__ = ('calcfunction', 'workfunction', 'FunctionProcess')
LOGGER = logging.getLogger(__name__)
FunctionType = t.TypeVar('FunctionType', bound=t.Callable[..., t.Any])
def get_stack_size(size: int = 2) -> int: # type: ignore[return]
"""Return the stack size for the caller's frame.
This solution is taken from https://stackoverflow.com/questions/34115298/ as a more performant alternative to the
naive ``len(inspect.stack())`` solution. This implementation is about three orders of magnitude faster compared to
the naive solution and it scales especially well for larger stacks, which will usually be the case for the usage
of ``aiida-core``. However, it does use the internal ``_getframe`` of the ``sys`` standard library. If this ever
were to stop working, simply switch to using ``len(inspect.stack())``.
:param size: Hint for the expected stack size.
:returns: The stack size for the caller's frame.
"""
frame = sys._getframe(size)
try:
for size in itertools.count(size, 8):
frame = frame.f_back.f_back.f_back.f_back.f_back.f_back.f_back.f_back # type: ignore[assignment,union-attr]
except AttributeError:
while frame: # type: ignore[truthy-bool]
frame = frame.f_back # type: ignore[assignment]
size += 1
return size - 1
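# Minimal usage sketch (kept as a comment so that importing this module has no side effects). The absolute value
# returned depends on how deeply the caller itself is nested:
#
#     def caller():
#         return get_stack_size()
#
#     # ``caller()`` returns a small number at the interpreter top level and a correspondingly larger value when
#     # invoked from a deeply nested chain of process functions.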
P = ParamSpec('P')
R_co = t.TypeVar('R_co', covariant=True)
N = t.TypeVar('N', bound=ProcessNode)
class ProcessFunctionType(t.Protocol, t.Generic[P, R_co, N]):
"""Protocol for a decorated process function."""
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co: ...
def run(self, *args: P.args, **kwargs: P.kwargs) -> R_co: ...
def run_get_pk(self, *args: P.args, **kwargs: P.kwargs) -> tuple[dict[str, t.Any] | None, int]: ...
def run_get_node(self, *args: P.args, **kwargs: P.kwargs) -> tuple[dict[str, t.Any] | None, N]: ...
is_process_function: bool
node_class: t.Type[N]
process_class: t.Type[Process]
recreate_from: t.Callable[[N], Process]
spec: t.Callable[[], ProcessSpec]
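# Hedged sketch of how these protocol members surface on a decorated function; the function name below is purely
# illustrative and actually invoking the launchers requires a configured AiiDA profile:
#
#     @calcfunction
#     def add(x, y):
#         return x + y
#
#     add.is_process_function            # True
#     add.node_class                     # CalcFunctionNode
#     results, node = add.run_get_node(Int(1), Int(2))
#     results, pk = add.run_get_pk(Int(1), Int(2))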
def calcfunction(function: t.Callable[P, R_co]) -> ProcessFunctionType[P, R_co, CalcFunctionNode]:
"""A decorator to turn a standard python function into a calcfunction.
Example usage:
>>> from aiida.orm import Int
>>>
>>> # Define the calcfunction
>>> @calcfunction
... def sum(a, b):
...     return a + b
>>> # Run it with some input
>>> r = sum(Int(4), Int(5))
>>> print(r)
9
>>> r.base.links.get_incoming().all() # doctest: +SKIP
[Neighbor(link_type='', link_label='result',
node=<CalcFunctionNode: uuid: ce0c63b3-1c84-4bb8-ba64-7b70a36adf34 (pk: 3567)>)]
>>> r.base.links.get_incoming().get_node_by_label('result').base.links.get_incoming().all_nodes()
[4, 5]
:param function: The function to decorate.
:return: The decorated function.
"""
return process_function(node_class=CalcFunctionNode)(function) # type: ignore[arg-type]
def workfunction(function: t.Callable[P, R_co]) -> ProcessFunctionType[P, R_co, WorkFunctionNode]:
"""A decorator to turn a standard python function into a workfunction.
Example usage:
>>> from aiida.orm import Int
>>>
>>> # Define the workfunction
>>> @workfunction
... def select(a, b):
...     return a
>>> # Run it with some input
>>> r = select(Int(4), Int(5))
>>> print(r)
4
>>> r.base.links.get_incoming().all() # doctest: +SKIP
[Neighbor(link_type='', link_label='result',
node=<WorkFunctionNode: uuid: ce0c63b3-1c84-4bb8-ba64-7b70a36adf34 (pk: 3567)>)]
>>> r.base.links.get_incoming().get_node_by_label('result').base.links.get_incoming().all_nodes()
[4, 5]
:param function: The function to decorate.
:return: The decorated function.
"""
return process_function(node_class=WorkFunctionNode)(function) # type: ignore[arg-type]
def process_function(node_class: t.Type['ProcessNode']) -> t.Callable[[FunctionType], FunctionType]:
"""The base function decorator to create a FunctionProcess out of a normal python function.
:param node_class: the ORM class to be used as the Node record for the FunctionProcess
"""
def decorator(function: FunctionType) -> FunctionType:
"""Turn the decorated function into a FunctionProcess.
:param callable function: the actual decorated function that the FunctionProcess represents
:return callable: The decorated function.
"""
process_class = FunctionProcess.build(function, node_class=node_class)
def run_get_node(*args, **kwargs) -> tuple[dict[str, t.Any] | None, 'ProcessNode']:
"""Run the FunctionProcess with the supplied inputs in a local runner.
:param args: input arguments to construct the FunctionProcess
:param kwargs: input keyword arguments to construct the FunctionProcess
:return: tuple of the outputs of the process and the process node
"""
frame_delta = 1000
frame_count = get_stack_size()
stack_limit = sys.getrecursionlimit()
LOGGER.info('Executing process function, current stack status: %d frames of %d', frame_count, stack_limit)
# If the current frame count is more than 80% of the stack limit, or comes within 200 frames, increase the
# stack limit by ``frame_delta``.
if frame_count > min(0.8 * stack_limit, stack_limit - 200):
LOGGER.warning(
'Current stack contains %d frames which is close to the limit of %d. Increasing the limit by %d',
frame_count,
stack_limit,
frame_delta,
)
sys.setrecursionlimit(stack_limit + frame_delta)
manager = get_manager()
runner = manager.get_runner()
inputs = process_class.create_inputs(*args, **kwargs)
# Remove all the known inputs from the kwargs
for port in process_class.spec().inputs:
kwargs.pop(port, None)
# If any kwargs remain, the spec should be dynamic, so we raise if it isn't
if kwargs and not process_class.spec().inputs.dynamic:
raise ValueError(f'{function.__name__} does not support these kwargs: {kwargs.keys()}')
process = process_class(inputs=inputs, runner=runner)
# Only add a handler for the interrupt signal to kill the process if we are in a local and not a daemon runner.
# Without this check, process functions running in a daemon worker would be killed if the daemon is shut down.
current_runner = manager.get_runner()
original_handler = None
kill_signal = signal.SIGINT
if not current_runner.is_daemon_runner:
def kill_process(_num, _frame):
"""Send the kill signal to the process in the current scope."""
LOGGER.critical('runner received interrupt, killing process %s', process.pid)
result = process.kill(msg='Process was killed because the runner received an interrupt')
return result
# Store the current handler on the signal such that it can be restored after process has terminated
original_handler = signal.getsignal(kill_signal)
signal.signal(kill_signal, kill_process)
try:
result = process.execute()
finally:
# If the `original_handler` is set, that means the `kill_process` was bound, which needs to be reset
if original_handler:
signal.signal(signal.SIGINT, original_handler)
store_provenance = inputs.get('metadata', {}).get('store_provenance', True)
if not store_provenance:
process.node._storable = False
process.node._unstorable_message = 'cannot store node because it was run with `store_provenance=False`'
return result, process.node
def run_get_pk(*args, **kwargs) -> tuple[dict[str, t.Any] | None, int]:
"""Recreate the `run_get_pk` utility launcher.
:param args: input arguments to construct the FunctionProcess
:param kwargs: input keyword arguments to construct the FunctionProcess
:return: tuple of the outputs of the process and the process node pk
"""
result, node = run_get_node(*args, **kwargs)
assert node.pk is not None
return result, node.pk
@functools.wraps(function)
def decorated_function(*args, **kwargs):
"""This wrapper function is the actual function that is called."""
result, _ = run_get_node(*args, **kwargs)
return result
decorated_function.run = decorated_function # type: ignore[attr-defined]
decorated_function.run_get_pk = run_get_pk # type: ignore[attr-defined]
decorated_function.run_get_node = run_get_node # type: ignore[attr-defined]
decorated_function.is_process_function = True # type: ignore[attr-defined]
decorated_function.node_class = node_class # type: ignore[attr-defined]
decorated_function.process_class = process_class # type: ignore[attr-defined]
decorated_function.recreate_from = process_class.recreate_from # type: ignore[attr-defined]
decorated_function.spec = process_class.spec # type: ignore[attr-defined]
return decorated_function # type: ignore[return-value]
return decorator
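# Hedged sketch of how this decorator factory is used: ``calcfunction`` and ``workfunction`` above are thin wrappers
# around exactly this pattern (the decorated function name is illustrative):
#
#     @process_function(node_class=CalcFunctionNode)
#     def multiply(x, y):
#         return x * y
#
#     results, node = multiply.run_get_node(Int(2), Int(3))  # outputs plus the ``CalcFunctionNode``
#     results, pk = multiply.run_get_pk(Int(2), Int(3))      # outputs plus the pk of that node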
def infer_valid_type_from_type_annotation(annotation: t.Any) -> tuple[t.Any, ...]:
"""Infer the value for the ``valid_type`` of an input port from the given function argument annotation.
:param annotation: The annotation of a function argument as returned by ``inspect.get_annotations``.
:returns: A tuple of valid types. If no valid types were defined or they could not be successfully parsed, an empty
tuple is returned.
"""
def get_type_from_annotation(annotation):
valid_type_map = {
bool: Bool,
dict: Dict,
t.Dict: Dict,
float: Float,
int: Int,
list: List,
t.List: List,
str: Str,
}
if inspect.isclass(annotation) and issubclass(annotation, Data):
return annotation
return valid_type_map.get(annotation)
inferred_valid_type: tuple[t.Any, ...] = ()
if inspect.isclass(annotation):
inferred_valid_type = (get_type_from_annotation(annotation),)
elif t.get_origin(annotation) is t.Union or t.get_origin(annotation) is UnionType:
inferred_valid_type = tuple(get_type_from_annotation(valid_type) for valid_type in t.get_args(annotation))
elif t.get_origin(annotation) is t.Optional:
inferred_valid_type = (t.get_args(annotation),)
return tuple(valid_type for valid_type in inferred_valid_type if valid_type is not None)
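# Illustrative examples of the inference above (kept as comments; assuming the mapping defined in
# ``get_type_from_annotation``):
#
#     infer_valid_type_from_type_annotation(int)                  # -> (Int,)
#     infer_valid_type_from_type_annotation(t.Union[Int, Float])  # -> (Int, Float)
#     infer_valid_type_from_type_annotation(str | list)           # -> (Str, List) on Python 3.10 and up
#     infer_valid_type_from_type_annotation(object)               # -> (), callers then fall back to ``(Data,)``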
class FunctionProcess(Process):
"""Function process class used for turning functions into a Process"""
_func_args: t.Sequence[str] = ()
_var_positional: str | None = None
_var_keyword: str | None = None
@staticmethod
def _func(*_args, **_kwargs) -> dict:
"""This is used internally to store the actual function that is being
wrapped and will be replaced by the build method.
"""
return {}
@staticmethod
def build(func: FunctionType, node_class: t.Type['ProcessNode']) -> t.Type['FunctionProcess']:
"""Build a Process from the given function.
All function arguments will be assigned as process inputs. If keyword arguments are specified then
these will also become inputs.
:param func: The function to build a process from
:param node_class: Provide a custom node class to be used, has to be constructable with no arguments. It has to
be a subclass of `ProcessNode` that includes the mixin :class:`~aiida.orm.utils.mixins.FunctionCalculationMixin`.
:return: A Process class that represents the function
"""
if (
not issubclass(node_class, ProcessNode) # type: ignore[redundant-expr]
or not issubclass(node_class, FunctionCalculationMixin) # type: ignore[unreachable]
):
raise TypeError('the node_class should be a sub class of `ProcessNode` and `FunctionCalculationMixin`')
signature = inspect.signature(func) # type: ignore[unreachable]
args: list[str] = []
var_positional: str | None = None
var_keyword: str | None = None
try:
annotations = get_annotations(func, eval_str=True)
except Exception as exception:
# Since we are running with ``eval_str=True`` to unstringize the annotations, the call can except if the
# annotations are incorrect. In this case we simply want to log a warning and continue with type inference.
LOGGER.warning(f'function `{func.__name__}` has invalid type hints: {exception}')
annotations = {}
try:
parsed_docstring = docstring_parser.parse(func.__doc__)
except Exception as exception:
LOGGER.warning(f'function `{func.__name__}` has a docstring that could not be parsed: {exception}')
param_help_string = {}
namespace_help_string = None
else:
param_help_string = {param.arg_name: param.description for param in parsed_docstring.params}
namespace_help_string = parsed_docstring.short_description if parsed_docstring.short_description else ''
if parsed_docstring.long_description is not None:
namespace_help_string += f'\n\n{parsed_docstring.long_description}'
for key, parameter in signature.parameters.items():
if parameter.kind in [parameter.POSITIONAL_ONLY, parameter.POSITIONAL_OR_KEYWORD, parameter.KEYWORD_ONLY]:
args.append(key)
if parameter.kind is parameter.VAR_POSITIONAL:
var_positional = key
if parameter.kind is parameter.VAR_KEYWORD:
var_keyword = key
def define(cls, spec):
"""Define the spec dynamically"""
from plumpy.ports import UNSPECIFIED
super().define(spec)
for parameter in signature.parameters.values():
if parameter.kind in [parameter.VAR_POSITIONAL, parameter.VAR_KEYWORD]:
continue
annotation = annotations.get(parameter.name)
valid_type = infer_valid_type_from_type_annotation(annotation) or (Data,)
help_string = param_help_string.get(parameter.name, None)
default = parameter.default if parameter.default is not parameter.empty else UNSPECIFIED
# If the keyword was already specified, simply override the default
if spec.has_input(parameter.name):
spec.inputs[parameter.name].default = default
continue
# If the default is ``None`` make sure that the port also accepts a ``NoneType``. Note that we cannot
# use ``None`` because the validation will call ``isinstance`` which does not work when passing ``None``
# but it does work with ``NoneType`` which is returned by calling ``type(None)``.
if default is None:
valid_type += (type(None),)
# If a default is defined and it is not a ``Data`` instance, it should be serialized, but this should be
# done lazily through a wrapper function, since port defaults should not define node instances directly, as is
# also checked by the ``spec.input`` call.
if (
default is not None
and default != UNSPECIFIED
and not isinstance(default, Data)
and not callable(default)
):
def indirect_default(value=default):
return to_aiida_type(value)
else:
indirect_default = default
spec.input(
parameter.name,
valid_type=valid_type,
default=indirect_default,
serializer=to_aiida_type,
help=help_string,
)
# Set defaults for label and description based on function name and docstring, if not explicitly defined
port_label = spec.inputs['metadata']['label']
if not port_label.has_default():
port_label.default = func.__name__
spec.inputs.help = namespace_help_string
# If the function supports varargs or kwargs then allow dynamic inputs, otherwise disallow
spec.inputs.dynamic = var_positional is not None or var_keyword is not None
# Function processes must have a dynamic output namespace since we do not know beforehand what outputs
# will be returned. The valid types for the values should be `Data` nodes as well as dictionaries, because
# the output namespace can be nested.
spec.outputs.valid_type = (Data, dict)
return type(
func.__qualname__,
(FunctionProcess,),
{
'__module__': func.__module__,
'__name__': func.__name__,
'__qualname__': func.__qualname__,
'_func': staticmethod(func),
Process.define.__name__: classmethod(define),
'_func_args': args,
'_var_positional': var_positional,
'_var_keyword': var_keyword,
'_node_class': node_class,
},
)
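# Hedged sketch of what ``build`` produces (the function name is illustrative and instantiating the resulting class
# requires a loaded profile):
#
#     def add(x: int, y: int = 5):
#         return x + y
#
#     AddProcess = FunctionProcess.build(add, node_class=CalcFunctionNode)
#     spec = AddProcess.spec()
#     # ``spec.inputs`` now contains ports 'x' and 'y' with ``valid_type`` inferred as ``(Int,)``; 'y' carries a
#     # lazily serialized default of ``5`` and the namespace is not dynamic since there are no varargs or kwargs.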
@classmethod
def get_or_create_db_record(cls) -> 'ProcessNode':
return cls._node_class()
def __init__(self, *args, **kwargs) -> None:
if kwargs.get('enable_persistence', False):
raise RuntimeError('Cannot persist a function process')
super().__init__(enable_persistence=False, *args, **kwargs) # type: ignore[misc]
@property
def process_class(self) -> t.Callable[..., t.Any]:
"""Return the class that represents this Process, for the FunctionProcess this is the function itself.
For a standard Process or sub class of Process, this is the class itself. However, for legacy reasons,
the Process class is a wrapper around another class. This function returns that original class, i.e. the
class that really represents what was being executed.
:return: A Process class that represents the function
"""
return self._func
def execute(self) -> dict[str, t.Any] | None:
"""Execute the process."""
result = super().execute()
# A FunctionProcess can return a single value as output instead of a dictionary, in which case that value is returned directly
if result and len(result) == 1 and self.SINGLE_OUTPUT_LINKNAME in result:
return result[self.SINGLE_OUTPUT_LINKNAME]
return result
@override
def _setup_db_record(self) -> None:
"""Set up the database record for the process."""
super()._setup_db_record()
self.node.store_source_info(self._func)
@override
def run(self) -> 'ExitCode' | None:
"""Run the process."""
from .exit_code import ExitCode
# The following conditional is required for caching to work properly. Even if the source node has a process
# state of `Finished`, the cached process will still enter the running state. The process state will then have
# been overridden by the engine to `Running`, so we cannot check that, but if the `exit_status` is anything other
# than `None`, it means this node was taken from the cache, so the process should not be rerun.
if self.node.exit_status is not None:
return ExitCode(self.node.exit_status, self.node.exit_message)
# Now the original function's arguments need to be reconstructed from the inputs to the process, as they were
# passed to the original function call. To do so, all positional parameters are popped from the inputs
# dictionary and added to the positional arguments list.
args = []
kwargs: dict[str, Data] = {}
inputs = dict(self.inputs or {})
for name, parameter in inspect.signature(self._func).parameters.items():
if parameter.kind in [parameter.POSITIONAL_ONLY, parameter.POSITIONAL_OR_KEYWORD]:
args.append(inputs.pop(name))
elif parameter.kind is parameter.VAR_POSITIONAL:
for key in [key for key in inputs.keys() if key.startswith(f'{name}_')]:
args.append(inputs.pop(key))
# Any inputs that correspond to metadata ports were not part of the original function signature but were added
# by the process function decorator, so these have to be removed.
for key in [key for key, port in self.spec().inputs.items() if port.is_metadata]:
inputs.pop(key, None)
# The remaining inputs have to be keyword arguments.
kwargs.update(**inputs)
result = self._func(*args, **kwargs)
if result is None or isinstance(result, ExitCode): # type: ignore[redundant-expr]
return result # type: ignore[unreachable]
if isinstance(result, Data): # type: ignore[unreachable]
self.out(self.SINGLE_OUTPUT_LINKNAME, result) # type: ignore[unreachable]
elif isinstance(result, collections.abc.Mapping):
for name, value in result.items():
self.out(name, value)
else:
raise TypeError(
"Function process returned an output with unsupported type '{}'\n"
'Must be a Data type or a mapping of {{string: Data}}'.format(result.__class__)
)
return ExitCode()
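# For reference, a hedged summary of the return values handled above when writing a process function (the function
# name is illustrative):
#
#     @calcfunction
#     def example(x):
#         return Int(1)                             # a single ``Data`` output, attached under ``SINGLE_OUTPUT_LINKNAME``
#         # return {'sum': Int(1), 'prod': Int(2)}  # one output per key of the mapping
#         # return ExitCode(400, 'failed')          # terminates the process with a non-zero exit status
#         # return None                             # no outputs are attached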