# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
# pylint: disable=global-statement
"""Runners that can run and submit processes."""
import asyncio
import functools
import logging
import signal
import threading
from typing import Any, Callable, Dict, NamedTuple, Optional, Tuple, Type, Union
import uuid
import kiwipy
from plumpy.communications import wrap_communicator
from plumpy.events import reset_event_loop_policy, set_event_loop_policy
from plumpy.persistence import Persister
from plumpy.process_comms import RemoteProcessThreadController
from aiida.common import exceptions
from aiida.orm import ProcessNode, load_node
from aiida.plugins.utils import PluginVersionProvider
from . import transports, utils
from .processes import Process, ProcessBuilder, ProcessState, futures
from .processes.calcjobs import manager
# Logger shared by everything in this module.
LOGGER = logging.getLogger(name=__name__)

# The runner is the only name exported by ``from ... import *``.
__all__ = ('Runner',)
class ResultAndNode(NamedTuple):
    """The outputs of a completed process together with its process node.

    Returned by :meth:`Runner.run_get_node`; can be unpacked as ``result, node = ...``.
    """

    # outputs of the completed process
    result: Dict[str, Any]
    # node representing the completed process
    node: ProcessNode
class ResultAndPk(NamedTuple):
    """The outputs of a completed process together with the pk of its process node.

    Returned by :meth:`Runner.run_get_pk`; can be unpacked as ``result, pk = ...``.
    """

    # outputs of the completed process
    result: Dict[str, Any]
    # primary key of the node representing the completed process
    pk: int
# Process-like objects accepted by ``Runner.run`` and its variants.
# run can also be process function, but it is not clear what type this should be
TYPE_RUN_PROCESS = Union[  # pylint: disable=invalid-name
    Process,
    Type[Process],
    ProcessBuilder,
]

# Process-like objects accepted by ``Runner.submit`` and ``Runner.schedule`` (process functions are rejected there).
TYPE_SUBMIT_PROCESS = Union[  # pylint: disable=invalid-name
    Process,
    Type[Process],
    ProcessBuilder,
]
class Runner:  # pylint: disable=too-many-public-methods
    """Class that can launch processes by running in the current interpreter or by submitting them to the daemon.

    Processes can be:

    * run to completion, blocking the caller (:meth:`run`, :meth:`run_get_node`, :meth:`run_get_pk`);
    * scheduled on this runner's event loop without blocking (:meth:`schedule`);
    * submitted (:meth:`submit`).

    When the runner has a communicator and ``rmq_submit`` is enabled, :meth:`submit` checkpoints the process through
    the persister and asks the controller to continue it, instead of scheduling it on the local loop.
    """

    _persister: Optional[Persister] = None
    _communicator: Optional[kiwipy.Communicator] = None
    _controller: Optional[RemoteProcessThreadController] = None
    _closed: bool = False

    def __init__(
        self,
        poll_interval: Union[int, float] = 0,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        communicator: Optional[kiwipy.Communicator] = None,
        rmq_submit: bool = False,
        persister: Optional[Persister] = None
    ):
        """Construct a new runner.

        :param poll_interval: interval in seconds between polling for status of active sub processes
        :param loop: an asyncio event loop; if none is supplied, the loop returned by ``asyncio.get_event_loop()``
            is used (after plumpy's event loop policy has been installed)
        :param communicator: the communicator to use
        :param rmq_submit: if True, processes will be submitted to RabbitMQ, otherwise they will be scheduled here.
            It is reset to False, with a warning, when no communicator is provided.
        :param persister: the persister to use to persist processes; required when ``rmq_submit`` is True
        """
        assert not (rmq_submit and persister is None), \
            'Must supply a persister if you want to submit using communicator'

        # Install plumpy's policy first: ``asyncio.get_event_loop`` below goes through the active policy.
        set_event_loop_policy()
        self._loop = loop if loop is not None else asyncio.get_event_loop()
        self._poll_interval = poll_interval
        self._rmq_submit = rmq_submit
        self._transport = transports.TransportQueue(self._loop)
        self._job_manager = manager.JobManager(self._transport)
        self._persister = persister
        self._plugin_version_provider = PluginVersionProvider()

        if communicator is not None:
            # The runner keeps a copy wrapped for its own loop; the thread controller receives the communicator
            # exactly as it was passed in.
            self._communicator = wrap_communicator(communicator, self._loop)
            self._controller = RemoteProcessThreadController(communicator)
        elif self._rmq_submit:
            LOGGER.warning('Disabling RabbitMQ submission, no communicator provided')
            self._rmq_submit = False

    def __enter__(self) -> 'Runner':
        """Return the runner itself so that it can be used as a context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the runner when leaving the ``with`` block.

        If the runner was already closed explicitly inside the block, nothing is done. This avoids tripping the
        single-close assertion in :meth:`close`.
        """
        if not self._closed:
            self.close()

    @property
    def loop(self) -> asyncio.AbstractEventLoop:
        """Get the event loop of this runner."""
        return self._loop

    @property
    def transport(self) -> transports.TransportQueue:
        """Get the transport queue of this runner, which is bound to the runner's event loop."""
        return self._transport

    @property
    def persister(self) -> Optional[Persister]:
        """Get the persister used by this runner."""
        return self._persister

    @property
    def communicator(self) -> Optional[kiwipy.Communicator]:
        """Get the communicator used by this runner (wrapped for the runner's loop), or None if there is none."""
        return self._communicator

    @property
    def plugin_version_provider(self) -> PluginVersionProvider:
        """Get the plugin version provider of this runner."""
        return self._plugin_version_provider

    @property
    def job_manager(self) -> manager.JobManager:
        """Get the job manager of this runner, which was constructed around the runner's transport queue."""
        return self._job_manager

    @property
    def controller(self) -> Optional[RemoteProcessThreadController]:
        """Get the controller used by this runner."""
        return self._controller

    @property
    def is_daemon_runner(self) -> bool:
        """Return whether the runner is a daemon runner, which means it submits processes over RabbitMQ.

        :return: True if the runner is a daemon runner
        """
        return self._rmq_submit

    def is_closed(self) -> bool:
        """Return whether :meth:`close` has been called on this runner."""
        return self._closed

    def start(self) -> None:
        """Start the internal event loop; this blocks until the loop is stopped, e.g. through :meth:`stop`."""
        self._loop.run_forever()

    def stop(self) -> None:
        """Stop the internal event loop."""
        self._loop.stop()

    def run_until_complete(self, future: asyncio.Future) -> Any:
        """Run the loop until the future has finished and return the result.

        The loop is run inside ``utils.loop_scope`` for this runner's loop.
        """
        with utils.loop_scope(self._loop):
            return self._loop.run_until_complete(future)

    def close(self) -> None:
        """Close the runner by stopping the loop and resetting the event loop policy.

        May only be called once; a second call fails the ``not self._closed`` assertion.
        """
        assert not self._closed
        self.stop()
        reset_event_loop_policy()
        self._closed = True

    def instantiate_process(self, process: TYPE_RUN_PROCESS, *args, **inputs):
        """Instantiate ``process`` with the given inputs, passing this runner to ``utils.instantiate_process``.

        :param process: the process instance, class or builder to instantiate
        :param inputs: the inputs to be passed to the process
        :return: the process instance
        """
        # Imported locally, as in the original module layout.
        from .utils import instantiate_process  # pylint: disable=no-name-in-module
        return instantiate_process(self, process, *args, **inputs)

    def submit(self, process: TYPE_SUBMIT_PROCESS, *args: Any, **inputs: Any) -> ProcessNode:
        """
        Submit the process with the supplied inputs to this runner immediately returning control to
        the interpreter. The return value will be the node of the submitted process.

        :param process: the process class to submit
        :param inputs: the inputs to be passed to the process
        :return: the process node of the process
        :raises aiida.common.exceptions.InvalidOperation: if the process metadata has ``store_provenance=False``
            or ``dry_run=True``
        """
        assert not utils.is_process_function(process), 'Cannot submit a process function'
        assert not self._closed

        process_inited = self.instantiate_process(process, *args, **inputs)

        if not process_inited.metadata.store_provenance:
            raise exceptions.InvalidOperation('cannot submit a process with `store_provenance=False`')

        if process_inited.metadata.get('dry_run', False):
            raise exceptions.InvalidOperation('cannot submit a process from within another with `dry_run=True`')

        if self._rmq_submit:
            assert self.persister is not None, 'runner does not have a persister'
            assert self.controller is not None, 'runner does not have a controller'
            # Checkpoint and close the local instance, then ask the controller to continue the process by its pid.
            self.persister.save_checkpoint(process_inited)
            process_inited.close()
            self.controller.continue_process(process_inited.pid, nowait=False, no_reply=True)
        else:
            self.loop.create_task(process_inited.step_until_terminated())

        return process_inited.node

    def schedule(self, process: TYPE_SUBMIT_PROCESS, *args: Any, **inputs: Any) -> ProcessNode:
        """
        Schedule a process to be executed by this runner.

        Unlike :meth:`submit`, the process is always scheduled on this runner's own loop, even for a daemon runner.

        :param process: the process class to submit
        :param inputs: the inputs to be passed to the process
        :return: the process node of the process
        """
        assert not utils.is_process_function(process), 'Cannot submit a process function'
        assert not self._closed

        process_inited = self.instantiate_process(process, *args, **inputs)
        self.loop.create_task(process_inited.step_until_terminated())
        return process_inited.node

    def _run(self, process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> Tuple[Dict[str, Any], ProcessNode]:
        """
        Run the process with the supplied inputs in this runner that will block until the process is completed.
        The return value will be the results of the completed process.

        :param process: the process class or process function to run
        :param inputs: the inputs to be passed to the process
        :return: tuple of the outputs of the process and the process node
        """
        assert not self._closed

        # Process functions are not instantiated here; they are run through their own ``run_get_node``.
        if utils.is_process_function(process):
            result, node = process.run_get_node(*args, **inputs)  # type: ignore[union-attr]
            return result, node

        with utils.loop_scope(self.loop):
            process_inited = self.instantiate_process(process, *args, **inputs)

            def kill_process(_num, _frame):
                """Send the kill signal to the process in the current scope."""
                if process_inited.is_killing:
                    LOGGER.warning('runner received interrupt, process %s already being killed', process_inited.pid)
                    return

                LOGGER.critical('runner received interrupt, killing process %s', process_inited.pid)
                process_inited.kill(msg='Process was killed because the runner received an interrupt')

            # While the process executes, SIGINT and SIGTERM are routed to ``kill_process``; the previous
            # handlers are always restored afterwards, even if execution raises.
            original_handler_int = signal.getsignal(signal.SIGINT)
            original_handler_term = signal.getsignal(signal.SIGTERM)

            try:
                signal.signal(signal.SIGINT, kill_process)
                signal.signal(signal.SIGTERM, kill_process)
                process_inited.execute()
            finally:
                signal.signal(signal.SIGINT, original_handler_int)
                signal.signal(signal.SIGTERM, original_handler_term)

            return process_inited.outputs, process_inited.node

    def run(self, process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> Dict[str, Any]:
        """
        Run the process with the supplied inputs in this runner that will block until the process is completed.
        The return value will be the results of the completed process.

        :param process: the process class or process function to run
        :param inputs: the inputs to be passed to the process
        :return: the outputs of the process
        """
        result, _ = self._run(process, *args, **inputs)
        return result

    def run_get_node(self, process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> ResultAndNode:
        """
        Run the process with the supplied inputs in this runner that will block until the process is completed.
        The return value will be the results of the completed process.

        :param process: the process class or process function to run
        :param inputs: the inputs to be passed to the process
        :return: tuple of the outputs of the process and the process node
        """
        result, node = self._run(process, *args, **inputs)
        return ResultAndNode(result, node)

    def run_get_pk(self, process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> ResultAndPk:
        """
        Run the process with the supplied inputs in this runner that will block until the process is completed.
        The return value will be the results of the completed process.

        :param process: the process class or process function to run
        :param inputs: the inputs to be passed to the process
        :return: tuple of the outputs of the process and process node pk
        """
        result, node = self._run(process, *args, **inputs)
        return ResultAndPk(result, node.pk)

    def call_on_process_finish(self, pk: int, callback: Callable[[], Any]) -> None:
        """Schedule a callback when the process of the given pk is terminated.

        This method will add a broadcast subscriber that will listen for state changes of the target process to be
        terminated. As a fail-safe, a polling-mechanism is used to check the state of the process, should the broadcast
        message be missed by the subscriber, in order to prevent the caller to wait indefinitely.

        :param pk: pk of the process
        :param callback: function to be called upon process termination
        """
        assert self.communicator is not None, 'communicator not set for runner'

        node = load_node(pk=pk)
        subscriber_identifier = str(uuid.uuid4())
        # Shared by the broadcast path and the polling path so that ``callback`` runs at most once.
        event = threading.Event()

        def inline_callback(event, *args, **kwargs):  # pylint: disable=unused-argument
            """Callback to wrap the actual callback, that will always remove the subscriber that will be registered.

            As soon as the callback is called successfully once, the `event` instance is toggled, such that if this
            inline callback is called a second time, the actual callback is not called again.
            """
            if event.is_set():
                return

            try:
                callback()
            finally:
                event.set()
                self.communicator.remove_broadcast_subscriber(subscriber_identifier)  # type: ignore[union-attr]

        broadcast_filter = kiwipy.BroadcastFilter(functools.partial(inline_callback, event), sender=pk)
        for state in [ProcessState.FINISHED, ProcessState.KILLED, ProcessState.EXCEPTED]:
            broadcast_filter.add_subject_filter(f'state_changed.*.{state.value}')

        LOGGER.info('adding subscriber for broadcasts of %d', pk)
        self.communicator.add_broadcast_subscriber(broadcast_filter, subscriber_identifier)
        self._poll_process(node, functools.partial(inline_callback, event))

    def get_process_future(self, pk: int) -> futures.ProcessFuture:
        """Return a future for a process.

        The future will have the process node as the result when finished.

        :param pk: pk of the process
        :return: A future representing the completion of the process node
        """
        return futures.ProcessFuture(pk, self._loop, self._poll_interval, self._communicator)

    def _poll_process(self, node, callback):
        """Check whether the process state of the node is terminated and call the callback or reschedule it.

        While the node is not terminated, the check is rescheduled on the loop every ``poll_interval`` seconds.

        :param node: the process node
        :param callback: callback to be called when process is terminated
        """
        if node.is_terminated:
            args = [node.__class__.__name__, node.pk]
            LOGGER.info('%s<%d> confirmed to be terminated by backup polling mechanism', *args)
            self._loop.call_soon(callback)
        else:
            self._loop.call_later(self._poll_interval, self._poll_process, node, callback)