# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Transport tasks for calculation jobs."""
import asyncio
import functools
import logging
import tempfile
from typing import TYPE_CHECKING, Any, Callable, Optional
import plumpy
import plumpy.futures
import plumpy.process_states
from aiida.common.datastructures import CalcJobState
from aiida.common.exceptions import FeatureNotAvailable, TransportTaskException
from aiida.common.folders import SandboxFolder
from aiida.engine.daemon import execmanager
from aiida.engine.transports import TransportQueue
from aiida.engine.utils import InterruptableFuture, exponential_backoff_retry, interruptable_task
from aiida.manage.configuration import get_config_option
from aiida.orm.nodes.process.calculation.calcjob import CalcJobNode
from aiida.schedulers.datastructures import JobState
from ..process import ProcessState
if TYPE_CHECKING:
from .calcjob import CalcJob
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

# Names of the configuration options that tune the exponential backoff retry wrapped around every transport task.
RETRY_INTERVAL_OPTION = 'transport.task_retry_initial_interval'
MAX_ATTEMPTS_OPTION = 'transport.task_maximum_attempts'

# Identifiers of the transport task commands; most of them are passed as the ``data`` of the ``Waiting`` state.
UPLOAD_COMMAND = 'upload'
SUBMIT_COMMAND = 'submit'
UPDATE_COMMAND = 'update'
RETRIEVE_COMMAND = 'retrieve'
STASH_COMMAND = 'stash'
KILL_COMMAND = 'kill'
class PreSubmitException(Exception):
    """Signals that `CalcJob.presubmit` raised inside the `do_upload` coroutine of `task_upload_job`.

    Such a failure is not transient, so it is listed among the exceptions that bypass the exponential backoff retry.
    """
async def task_upload_job(process: 'CalcJob', transport_queue: TransportQueue, cancellable: InterruptableFuture):
    """Transport task that will attempt to upload the files of a job calculation to the remote.

    The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager
    function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will
    retry after an interval that increases exponentially with the number of retries, for a maximum number of retries.
    If all retries fail, the task will raise a TransportTaskException

    :param process: the job calculation
    :param transport_queue: the TransportQueue from which to request a Transport
    :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled
    :return: the ``skip_submit`` flag of the ``CalcInfo`` returned by ``presubmit`` (``False`` when unset), or
        ``None`` when the node was already marked as ``SUBMITTING`` and the upload was skipped
    :raises: PreSubmitException if ``CalcJob.presubmit`` raised; this is never retried
    :raises: TransportTaskException if after the maximum number of retries the transport task still excepted
    """
    node = process.node

    # This task only sets SUBMITTING after a successful upload (see the end of this function), so there is nothing
    # left to do here.
    if node.get_state() == CalcJobState.SUBMITTING:
        logger.warning(f'CalcJob<{node.pk}> already marked as SUBMITTING, skipping task_upload_job')
        return

    initial_interval = get_config_option(RETRY_INTERVAL_OPTION)
    max_attempts = get_config_option(MAX_ATTEMPTS_OPTION)

    authinfo = node.get_authinfo()

    async def do_upload():
        with transport_queue.request_transport(authinfo) as request:
            transport = await cancellable.with_interrupt(request)

            with SandboxFolder() as folder:
                # Any exception thrown in `presubmit` call is not transient so we circumvent the exponential backoff
                try:
                    calc_info = process.presubmit(folder)
                except Exception as exception:  # pylint: disable=broad-except
                    raise PreSubmitException('exception occurred in presubmit call') from exception
                else:
                    execmanager.upload_calculation(node, transport, calc_info, folder)
                    skip_submit = calc_info.skip_submit or False

            return skip_submit

    # Exceptions that must propagate immediately instead of being retried or wrapped in a TransportTaskException.
    ignore_exceptions = (plumpy.futures.CancelledError, PreSubmitException, plumpy.process_states.Interruption)

    try:
        logger.info(f'scheduled request to upload CalcJob<{node.pk}>')
        skip_submit = await exponential_backoff_retry(
            do_upload, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
        )
    except ignore_exceptions:  # pylint: disable=try-except-raise
        raise
    except Exception as exception:
        logger.warning(f'uploading CalcJob<{node.pk}> failed')
        raise TransportTaskException(f'upload_calculation failed {max_attempts} times consecutively') from exception
    else:
        logger.info(f'uploading CalcJob<{node.pk}> successful')
        node.set_state(CalcJobState.SUBMITTING)
        return skip_submit
async def task_submit_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture):
    """Submit the job of a calculation to its scheduler through a transport task.

    A transport is requested from ``transport_queue``; once it is granted, ``execmanager.submit_calculation`` is
    called. The call runs inside ``exponential_backoff_retry``, so a caught exception triggers a retry after an
    interval that grows exponentially with the number of attempts, up to a configured maximum. When every attempt
    fails, a ``TransportTaskException`` is raised.

    :param node: the node that represents the job calculation
    :param transport_queue: the TransportQueue from which to request a Transport
    :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled
    :return: the result of ``execmanager.submit_calculation``, or the job id already stored on the node when it is
        marked as ``WITHSCHEDULER``
    :raises: TransportTaskException if after the maximum number of retries the transport task still excepted
    """
    if node.get_state() == CalcJobState.WITHSCHEDULER:
        # Submission already happened: return the recorded job id rather than submitting a second time.
        assert node.get_job_id() is not None, 'job is WITHSCHEDULER, however, it does not have a job id'
        logger.warning(f'CalcJob<{node.pk}> already marked as WITHSCHEDULER, skipping task_submit_job')
        return node.get_job_id()

    initial_interval = get_config_option(RETRY_INTERVAL_OPTION)
    max_attempts = get_config_option(MAX_ATTEMPTS_OPTION)
    authinfo = node.get_authinfo()

    # NOTE: the inner coroutine keeps the name `do_submit` because the retry helper uses it in its log messages.
    async def do_submit():
        with transport_queue.request_transport(authinfo) as transport_request:
            granted_transport = await cancellable.with_interrupt(transport_request)
            return execmanager.submit_calculation(node, granted_transport)

    try:
        logger.info(f'scheduled request to submit CalcJob<{node.pk}>')
        submit_result = await exponential_backoff_retry(
            do_submit,
            initial_interval,
            max_attempts,
            logger=node.logger,
            ignore_exceptions=(plumpy.futures.CancelledError, plumpy.process_states.Interruption),
        )
    except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):  # pylint: disable=try-except-raise
        raise
    except Exception as exception:
        logger.warning(f'submitting CalcJob<{node.pk}> failed')
        raise TransportTaskException(f'submit_calculation failed {max_attempts} times consecutively') from exception

    # Every `except` branch above re-raises, so reaching this point means the submission succeeded.
    logger.info(f'submitting CalcJob<{node.pk}> successful')
    node.set_state(CalcJobState.WITHSCHEDULER)
    return submit_result
async def task_update_job(node: CalcJobNode, job_manager, cancellable: InterruptableFuture):
    """Transport task that will attempt to update the scheduler status of the job calculation.

    The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager
    function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will
    retry after an interval that increases exponentially with the number of retries, for a maximum number of retries.
    If all retries fail, the task will raise a TransportTaskException

    :param node: the node that represents the job calculation
    :type node: :class:`aiida.orm.nodes.process.calculation.calcjob.CalcJobNode`
    :param job_manager: The job manager
    :type job_manager: :class:`aiida.engine.processes.calcjobs.manager.JobManager`
    :param cancellable: A cancel flag
    :type cancellable: :class:`aiida.engine.utils.InterruptableFuture`
    :return: ``True`` if the job is considered done (the scheduler reports ``DONE``, no job info was returned, or the
        node is already in the ``RETRIEVING``/``STASHING`` state), ``False`` if the job is still running
    :raises: TransportTaskException if after the maximum number of retries the transport task still excepted
    """
    state = node.get_state()

    # Both states are only reached after the job finished, so there is nothing left to poll.
    if state in (CalcJobState.RETRIEVING, CalcJobState.STASHING):
        logger.warning(f'CalcJob<{node.pk}> already marked as `{state}`, skipping task_update_job')
        return True

    initial_interval = get_config_option(RETRY_INTERVAL_OPTION)
    max_attempts = get_config_option(MAX_ATTEMPTS_OPTION)

    authinfo = node.get_authinfo()
    job_id = node.get_job_id()

    async def do_update():
        # Get the update request
        with job_manager.request_job_info_update(authinfo, job_id) as update_request:
            job_info = await cancellable.with_interrupt(update_request)

        if job_info is None:
            # If the job is computed or not found assume it's done
            node.set_scheduler_state(JobState.DONE)
            job_done = True
        else:
            node.set_last_job_info(job_info)
            node.set_scheduler_state(job_info.job_state)
            job_done = job_info.job_state == JobState.DONE

        return job_done

    try:
        logger.info(f'scheduled request to update CalcJob<{node.pk}>')
        ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
        job_done = await exponential_backoff_retry(
            do_update, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
        )
    except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):  # pylint: disable=try-except-raise
        raise
    except Exception as exception:
        logger.warning(f'updating CalcJob<{node.pk}> failed')
        raise TransportTaskException(f'update_calculation failed {max_attempts} times consecutively') from exception
    else:
        logger.info(f'updating CalcJob<{node.pk}> successful')
        if job_done:
            node.set_state(CalcJobState.STASHING)

        return job_done
async def task_retrieve_job(
    node: CalcJobNode, transport_queue: TransportQueue, retrieved_temporary_folder: str,
    cancellable: InterruptableFuture
):
    """Transport task that will attempt to retrieve all files of a completed job calculation.

    The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager
    function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will
    retry after an interval that increases exponentially with the number of retries, for a maximum number of retries.
    If all retries fail, the task will raise a TransportTaskException

    :param node: the node that represents the job calculation
    :param transport_queue: the TransportQueue from which to request a Transport
    :param retrieved_temporary_folder: the absolute path to a directory to store files
    :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled
    :return: the result of ``execmanager.retrieve_calculation``, or ``None`` if the node was already marked as
        ``PARSING`` and retrieval was skipped
    :raises: TransportTaskException if after the maximum number of retries the transport task still excepted
    """
    if node.get_state() == CalcJobState.PARSING:
        logger.warning(f'CalcJob<{node.pk}> already marked as PARSING, skipping task_retrieve_job')
        return

    initial_interval = get_config_option(RETRY_INTERVAL_OPTION)
    max_attempts = get_config_option(MAX_ATTEMPTS_OPTION)

    authinfo = node.get_authinfo()

    async def do_retrieve():
        with transport_queue.request_transport(authinfo) as request:
            transport = await cancellable.with_interrupt(request)

            # Perform the job accounting and set it on the node if successful. If the scheduler does not implement this
            # still set the attribute but set it to `None`. This way we can distinguish calculation jobs for which the
            # accounting was called but could not be set.
            scheduler = node.computer.get_scheduler()  # type: ignore[union-attr]
            scheduler.set_transport(transport)

            job_id = node.get_job_id()

            # Without a job id there is nothing to query the scheduler for, but the files can still be retrieved.
            if job_id is None:
                logger.warning(f'there is no job id for CalcJobNode<{node.pk}>: skipping `get_detailed_job_info`')
                return execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder)

            try:
                detailed_job_info = scheduler.get_detailed_job_info(job_id)
            except FeatureNotAvailable:
                logger.info(f'detailed job info not available for scheduler of CalcJob<{node.pk}>')
                node.set_detailed_job_info(None)
            else:
                node.set_detailed_job_info(detailed_job_info)

            return execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder)

    try:
        logger.info(f'scheduled request to retrieve CalcJob<{node.pk}>')
        ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
        result = await exponential_backoff_retry(
            do_retrieve, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
        )
    except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):  # pylint: disable=try-except-raise
        raise
    except Exception as exception:
        logger.warning(f'retrieving CalcJob<{node.pk}> failed')
        raise TransportTaskException(f'retrieve_calculation failed {max_attempts} times consecutively') from exception
    else:
        node.set_state(CalcJobState.PARSING)
        logger.info(f'retrieving CalcJob<{node.pk}> successful')
        return result
async def task_stash_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture):
    """Transport task that will optionally stash files of a completed job calculation on the remote.

    The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager
    function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will
    retry after an interval that increases exponentially with the number of retries, for a maximum number of retries.
    If all retries fail, the task will raise a TransportTaskException

    :param node: the node that represents the job calculation
    :param transport_queue: the TransportQueue from which to request a Transport
    :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled
    :type cancellable: :class:`aiida.engine.utils.InterruptableFuture`
    :return: ``None``; on success the node is marked as ``RETRIEVING``
    :raises: TransportTaskException if after the maximum number of retries the transport task still excepted
    """
    # RETRIEVING is only set by this task after a successful stash (see below), so the work is already done.
    if node.get_state() == CalcJobState.RETRIEVING:
        logger.warning(f'calculation<{node.pk}> already marked as RETRIEVING, skipping task_stash_job')
        return

    initial_interval = get_config_option(RETRY_INTERVAL_OPTION)
    max_attempts = get_config_option(MAX_ATTEMPTS_OPTION)

    authinfo = node.get_authinfo()

    async def do_stash():
        with transport_queue.request_transport(authinfo) as request:
            transport = await cancellable.with_interrupt(request)

            logger.info(f'stashing calculation<{node.pk}>')
            return execmanager.stash_calculation(node, transport)

    try:
        logger.info(f'scheduled request to stash calculation<{node.pk}>')
        # Cancellation and interruption are not transient: propagate them immediately, like the other transport tasks.
        ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
        await exponential_backoff_retry(
            do_stash, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
        )
    except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):  # pylint: disable=try-except-raise
        raise
    except Exception as exception:
        logger.warning(f'stashing calculation<{node.pk}> failed')
        raise TransportTaskException(f'stash_calculation failed {max_attempts} times consecutively') from exception
    else:
        node.set_state(CalcJobState.RETRIEVING)
        logger.info(f'stashing calculation<{node.pk}> successful')
        return
async def task_kill_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture):
    """Transport task that will attempt to kill a job calculation.

    The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager
    function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will
    retry after an interval that increases exponentially with the number of retries, for a maximum number of retries.
    If all retries fail, the task will raise a TransportTaskException

    :param node: the node that represents the job calculation
    :param transport_queue: the TransportQueue from which to request a Transport
    :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled
    :return: ``True`` without contacting the remote when the node is still ``UPLOADING`` or ``SUBMITTING``,
        otherwise the result of ``execmanager.kill_calculation``
    :raises: TransportTaskException if after the maximum number of retries the transport task still excepted
    """
    state = node.get_state()

    # The node is only marked WITHSCHEDULER after a successful submit, so in these states no scheduler job is
    # expected to exist and there is nothing to kill remotely.
    if state in (CalcJobState.UPLOADING, CalcJobState.SUBMITTING):
        logger.warning(f'CalcJob<{node.pk}> killed, it was in the {state} state')
        return True

    initial_interval = get_config_option(RETRY_INTERVAL_OPTION)
    max_attempts = get_config_option(MAX_ATTEMPTS_OPTION)

    authinfo = node.get_authinfo()

    async def do_kill():
        with transport_queue.request_transport(authinfo) as request:
            transport = await cancellable.with_interrupt(request)
            return execmanager.kill_calculation(node, transport)

    try:
        logger.info(f'scheduled request to kill CalcJob<{node.pk}>')
        # Without this, an interruption of the kill would be retried (with backoff sleeps) before being re-raised;
        # propagate it immediately, consistent with the other transport tasks.
        ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
        result = await exponential_backoff_retry(
            do_kill, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
        )
    except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):  # pylint: disable=try-except-raise
        raise
    except Exception as exception:
        logger.warning(f'killing CalcJob<{node.pk}> failed')
        raise TransportTaskException(f'kill_calculation failed {max_attempts} times consecutively') from exception
    else:
        logger.info(f'killing CalcJob<{node.pk}> successful')
        node.set_scheduler_state(JobState.DONE)
        return result
class Waiting(plumpy.process_states.Waiting):
    """The waiting state for the `CalcJob` process.

    The state's ``data`` holds one of the ``*_COMMAND`` constants. ``execute`` runs the matching transport task and
    then returns the next state: upload -> submit (or retrieve when ``skip_submit``) -> update -> [stash] ->
    retrieve -> parse.
    """

    def __init__(
        self,
        process: 'CalcJob',
        done_callback: Optional[Callable[..., Any]],
        msg: Optional[str] = None,
        data: Optional[Any] = None
    ):
        """
        :param process: The process this state belongs to
        :param done_callback: forwarded unchanged to the base ``Waiting`` state
        :param msg: optional message, forwarded unchanged to the base ``Waiting`` state
        :param data: the transport command to execute (one of the ``*_COMMAND`` constants), read in ``execute``
        """
        super().__init__(process, done_callback, msg, data)
        # Handle of the currently running transport task; set and cleared by `_launch_task`.
        self._task: Optional[InterruptableFuture] = None
        # Future handed out by `interrupt` for a kill request; resolved to True/False in `execute`.
        self._killing: Optional[plumpy.futures.Future] = None

    @property
    def process(self) -> 'CalcJob':
        """
        :return: The process
        """
        return self.state_machine  # type: ignore[return-value]

    def load_instance_state(self, saved_state, load_context):
        """Restore the state from ``saved_state`` and reset both in-memory task handles to ``None``."""
        super().load_instance_state(saved_state, load_context)
        self._task = None
        self._killing = None

    async def execute(self) -> plumpy.process_states.State:  # type: ignore[override] # pylint: disable=invalid-overridden-method
        """Override the execute coroutine of the base `Waiting` state.

        Runs the transport task selected by ``self.data`` and returns the next state of the process.

        :raises plumpy.process_states.PauseInterruption: if a transport task raised ``TransportTaskException``
        :raises RuntimeError: if ``self.data`` is not a known command
        """
        # pylint: disable=too-many-branches,too-many-statements
        node = self.process.node
        transport_queue = self.process.runner.transport
        result: plumpy.process_states.State = self
        command = self.data

        process_status = f'Waiting for transport task: {command}'

        try:

            if command == UPLOAD_COMMAND:
                node.set_process_status(process_status)
                # The `skip_submit` flag returned by the upload task decides whether the scheduler is bypassed.
                skip_submit = await self._launch_task(task_upload_job, self.process, transport_queue)
                if skip_submit:
                    result = self.retrieve()
                else:
                    result = self.submit()

            elif command == SUBMIT_COMMAND:
                node.set_process_status(process_status)
                await self._launch_task(task_submit_job, node, transport_queue)
                result = self.update()

            elif command == UPDATE_COMMAND:
                job_done = False

                # Poll the scheduler until `task_update_job` reports the job as done, refreshing the status each time.
                while not job_done:
                    scheduler_state = node.get_scheduler_state()
                    scheduler_state_string = scheduler_state.name if scheduler_state else 'UNKNOWN'
                    process_status = f'Monitoring scheduler: job state {scheduler_state_string}'
                    node.set_process_status(process_status)
                    job_done = await self._launch_task(task_update_job, node, self.process.runner.job_manager)

                # Stashing is an optional step, only taken when the `stash` option is set on the node.
                if node.get_option('stash') is not None:
                    result = self.stash()
                else:
                    result = self.retrieve()

            elif command == STASH_COMMAND:
                node.set_process_status(process_status)
                await self._launch_task(task_stash_job, node, transport_queue)
                result = self.retrieve()

            elif command == RETRIEVE_COMMAND:
                node.set_process_status(process_status)
                # The temporary folder is handed to the parse state; its cleanup is presumably done there
                # (not visible here) — TODO confirm.
                temp_folder = tempfile.mkdtemp()
                await self._launch_task(task_retrieve_job, node, transport_queue, temp_folder)
                result = self.parse(temp_folder)

            else:
                raise RuntimeError('Unknown waiting command')

        except TransportTaskException as exception:
            # A transport task that exhausted its retries pauses the process instead of failing it.
            raise plumpy.process_states.PauseInterruption(f'Pausing after failed transport task: {exception}')
        except plumpy.process_states.KillInterruption:
            # Kill the remote job first, then tell whoever called `interrupt` that the kill went through.
            await self._launch_task(task_kill_job, node, transport_queue)
            if self._killing is not None:
                self._killing.set_result(True)
            else:
                logger.warning(f'killed CalcJob<{node.pk}> but async future was None')
            raise
        except (plumpy.futures.CancelledError, asyncio.CancelledError):
            node.set_process_status(f'Transport task {command} was cancelled')
            raise
        except plumpy.process_states.Interruption:
            node.set_process_status(f'Transport task {command} was interrupted')
            raise
        else:
            node.set_process_status(None)
            return result
        finally:
            # If we were trying to kill but we didn't deal with it, make sure it's set here
            if self._killing and not self._killing.done():
                self._killing.set_result(False)

    async def _launch_task(self, coro, *args, **kwargs):
        """Launch a coroutine as a task, making sure to make it interruptable.

        The task is stored in ``self._task`` while it runs so that ``interrupt`` can reach it. ``coro`` is called
        without a ``cancellable`` argument here; presumably ``interruptable_task`` supplies it — confirm in
        ``aiida.engine.utils``.

        :return: whatever the awaited task returns
        """
        task_fn = functools.partial(coro, *args, **kwargs)
        try:
            self._task = interruptable_task(task_fn)
            result = await self._task
            return result
        finally:
            # Clear the handle even on failure so a later `interrupt` does not target a finished task.
            self._task = None

    def upload(self) -> 'Waiting':
        """Return the `Waiting` state that will `upload` the `CalcJob`."""
        msg = 'Waiting for calculation folder upload'
        return self.create_state(ProcessState.WAITING, None, msg=msg, data=UPLOAD_COMMAND)  # type: ignore[return-value]

    def submit(self) -> 'Waiting':
        """Return the `Waiting` state that will `submit` the `CalcJob`."""
        msg = 'Waiting for scheduler submission'
        return self.create_state(ProcessState.WAITING, None, msg=msg, data=SUBMIT_COMMAND)  # type: ignore[return-value]

    def update(self) -> 'Waiting':
        """Return the `Waiting` state that will `update` the `CalcJob`."""
        msg = 'Waiting for scheduler update'
        return self.create_state(ProcessState.WAITING, None, msg=msg, data=UPDATE_COMMAND)  # type: ignore[return-value]

    def retrieve(self) -> 'Waiting':
        """Return the `Waiting` state that will `retrieve` the `CalcJob`."""
        msg = 'Waiting to retrieve'
        return self.create_state(
            ProcessState.WAITING, None, msg=msg, data=RETRIEVE_COMMAND
        )  # type: ignore[return-value]

    def stash(self) -> 'Waiting':
        """Return the `Waiting` state that will `stash` the `CalcJob`."""
        msg = 'Waiting to stash'
        return self.create_state(ProcessState.WAITING, None, msg=msg, data=STASH_COMMAND)  # type: ignore[return-value]

    def parse(self, retrieved_temporary_folder: str) -> plumpy.process_states.Running:
        """Return the `Running` state that will parse the `CalcJob`.

        :param retrieved_temporary_folder: temporary folder used in retrieving that can be used during parsing.
        """
        return self.create_state(
            ProcessState.RUNNING, self.process.parse, retrieved_temporary_folder
        )  # type: ignore[return-value]

    def interrupt(self, reason: Any) -> Optional[plumpy.futures.Future]:  # type: ignore[override]
        """Interrupt the `Waiting` state by calling interrupt on the transport task `InterruptableFuture`.

        :param reason: the interruption; forwarded to the running transport task, if any
        :return: for a ``KillInterruption``, a future that ``execute`` resolves to ``True`` once the kill task
            completed, or to ``False`` if the kill was not handled; ``None`` for any other reason
        """
        if self._task is not None:
            self._task.interrupt(reason)

        if isinstance(reason, plumpy.process_states.KillInterruption):
            # Reuse an existing future so repeated kill requests all observe the same outcome.
            if self._killing is None:
                self._killing = plumpy.futures.Future()
            return self._killing

        return None