###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""AiiDA specific implementation of plumpy Ports and PortNamespaces for the ProcessSpec."""
from __future__ import annotations
import re
import warnings
from collections.abc import Mapping
from typing import Any, Callable, Dict, Optional, Sequence
from plumpy import ports
from plumpy.ports import breadcrumbs_to_port
from aiida.common.links import validate_link_label
from aiida.orm import Data, Node, to_aiida_type
# Public names of this module, i.e. what ``from <module> import *`` exposes.
__all__ = (
'PortNamespace',
'InputPort',
'OutputPort',
'CalcJobOutputPort',
'WithNonDb',
'WithSerialize',
'PORT_NAMESPACE_SEPARATOR',
)
# Longest run of consecutive underscores allowed in a port name; enforced by ``PortNamespace.validate_port_name``.
# NOTE(review): presumably kept below the length of ``PORT_NAMESPACE_SEPARATOR`` so that a port name can never
# contain the separator itself - confirm.
PORT_NAME_MAX_CONSECUTIVE_UNDERSCORES = 1
PORT_NAMESPACE_SEPARATOR = '__' # The character sequence to represent a nested port namespace in a flat link label
# Re-export of the plumpy ``OutputPort`` class, unchanged, so that it can be imported from this module.
OutputPort = ports.OutputPort
class WithNonDb:
    """A mixin that adds support to a port to flag it should not be stored in the database using the ``non_db`` flag.

    The mixin consumes the optional ``non_db`` keyword argument and forwards all remaining arguments to the next class
    in the method resolution order. It additionally remembers whether the flag was given explicitly, which allows a
    deliberate ``non_db=False`` to be told apart from the default value.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Construct the port, stripping the ``non_db`` keyword before delegating to the parent constructor.

        :param non_db: optional keyword argument; ``True`` flags the value of the port as one that should not be
            stored in the database. Defaults to ``False`` when omitted.
        """
        # The membership test has to come before ``pop``, which removes the key from ``kwargs``.
        self._non_db_explicitly_set: bool = 'non_db' in kwargs
        non_db = kwargs.pop('non_db', False)
        super().__init__(*args, **kwargs)
        self._non_db: bool = non_db

    @property
    def non_db_explicitly_set(self) -> bool:
        """Return whether ``non_db`` was explicitly defined, either at construction or through the setter.

        :return: ``True`` if ``non_db`` was passed to the constructor or assigned afterwards, ``False`` otherwise
        """
        return self._non_db_explicitly_set

    @property
    def non_db(self) -> bool:
        """Return whether the value of this port is flagged as one that should not be stored in the database.

        :return: ``True`` if the value should not be stored, ``False`` if it should be stored
        """
        return self._non_db

    @non_db.setter
    def non_db(self, non_db: bool) -> None:
        """Set the ``non_db`` flag and mark it as explicitly set.

        :param non_db: ``True`` if the value of this port should not be stored in the database, ``False`` otherwise
        """
        self._non_db_explicitly_set = True
        self._non_db = non_db
class WithSerialize:
    """A mixin that adds support for a serialization function which is automatically applied on inputs
    that are not AiiDA data types.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Construct the port, stripping the ``serializer`` keyword before delegating to the parent constructor.

        :param serializer: optional keyword argument; a callable that converts a raw value into a ``Data`` instance.
            Defaults to ``None``, in which case :meth:`serialize` returns values unchanged.
        """
        serializer = kwargs.pop('serializer', None)
        super().__init__(*args, **kwargs)
        # The serializer is optional: ``None`` is the default and means that no conversion is applied.
        self._serializer: Optional[Callable[[Any], 'Data']] = serializer

    @property
    def serializer(self) -> Optional[Callable[[Any], 'Data']]:
        """Return the serializer, or ``None`` if no serializer is defined."""
        return self._serializer

    def serialize(self, value: Any) -> 'Data':
        """Serialize the given value, unless it is ``None``, already a Data type, or no serializer function is defined.

        :param value: the value to be serialized
        :returns: a serialized version of the value or the unchanged value
        """
        # Nothing to convert when there is no serializer, no value, or the value is already a ``Data`` instance.
        if self._serializer is None or value is None or isinstance(value, Data):
            return value

        return self._serializer(value)
class CalcJobOutputPort(ports.OutputPort):
    """Sub class of plumpy.OutputPort which adds the `_pass_to_parser` attribute."""

    def __init__(self, *args, **kwargs) -> None:
        """Construct the port, consuming the optional ``pass_to_parser`` keyword (``False`` when omitted).

        All other positional and keyword arguments are handed to the parent constructor untouched.
        """
        # Take the keyword out of ``kwargs`` first, so that it is not forwarded to the parent constructor.
        forward_to_parser = kwargs.pop('pass_to_parser', False)
        super().__init__(*args, **kwargs)
        self._pass_to_parser: bool = forward_to_parser

    @property
    def pass_to_parser(self) -> bool:
        """Return the value of the ``pass_to_parser`` flag given at construction."""
        return self._pass_to_parser
class PortNamespace(WithMetadata, WithNonDb, ports.PortNamespace):
    """Sub class of plumpy.PortNamespace which implements the serialize method to support automatic recursive
    serialization of a given mapping onto the ports of the PortNamespace.
    """

    def __setitem__(self, key: str, port: ports.Port) -> None:
        """Add a port under ``key``, letting it inherit namespace defaults that it did not explicitly define itself.

        Three things are applied to the port before it is stored:

        * ``is_metadata`` is copied from this namespace, unless the port reports it was explicitly set.
        * ``non_db`` is copied from this namespace, unless the port reports it was explicitly set. The reasoning is
          that if a `PortNamespace` has `non_db=True`, which is different from the default value, very often all
          leaves should be also `non_db=True`. To prevent a user from having to specify it manually every time, the
          value is overloaded here.
        * If the port is neither metadata nor ``non_db`` and exposes a ``serializer`` that is ``None``, its serializer
          is set to ``to_aiida_type``.

        Note that these attributes are not present for all `Port` sub classes so each one is checked for first.

        :param key: the name under which to add the port
        :param port: the port to add
        :raises TypeError: if ``port`` is not an instance of ``Port``
        :raises ValueError: if ``key`` is not a valid port name
        """
        if not isinstance(port, ports.Port):
            raise TypeError('port needs to be an instance of Port')

        self.validate_port_name(key)

        if hasattr(port, 'is_metadata_explicitly_set') and not port.is_metadata_explicitly_set:  # type: ignore[attr-defined]
            port.is_metadata = self.is_metadata  # type: ignore[attr-defined]

        if hasattr(port, 'non_db_explicitly_set') and not port.non_db_explicitly_set:  # type: ignore[attr-defined]
            port.non_db = self.non_db  # type: ignore[attr-defined]

        # If the port is not metadata (signified by ``is_metadata`` and ``non_db`` being ``False`` if defined) and it
        # does not already define a serializer, set the default serializer to ``to_aiida_type``.
        if (
            ((hasattr(port, 'is_metadata') and not port.is_metadata) and (hasattr(port, 'non_db') and not port.non_db))
            and hasattr(port, 'serializer')
            and port.serializer is None
        ):
            port._serializer = to_aiida_type

        super().__setitem__(key, port)

    @staticmethod
    def validate_port_name(port_name: str) -> None:
        """Validate the given port name.

        Valid port names adhere to the following restrictions:

            * Is a valid link label (see below)
            * Does not contain two or more consecutive underscores

        Valid link labels adhere to the following restrictions:

            * Has to be a valid python identifier
            * Can only contain alphanumeric characters and underscores
            * Can not start or end with an underscore

        :param port_name: the proposed name of the port to be added
        :raise TypeError: if the port name is not a string type
        :raise ValueError: if the port name is invalid
        """
        try:
            validate_link_label(port_name)
        except ValueError as exception:
            # Chain the original exception so the underlying cause remains visible in the traceback.
            raise ValueError(f'invalid port name `{port_name}`: {exception}') from exception

        # Collect every run of consecutive underscores and reject the name if any run exceeds the allowed length.
        underscore_runs = (match.group() for match in re.finditer(r'_+', port_name))

        if any(len(run) > PORT_NAME_MAX_CONSECUTIVE_UNDERSCORES for run in underscore_runs):
            raise ValueError(
                f'invalid port name `{port_name}`: contains more than {PORT_NAME_MAX_CONSECUTIVE_UNDERSCORES} '
                'consecutive underscore(s)'
            )

    def serialize(self, mapping: Optional[Dict[str, Any]], breadcrumbs: Sequence[str] = ()) -> Optional[Dict[str, Any]]:
        """Serialize the given mapping onto this `Portnamespace`.

        It will recursively call this function on any nested `PortNamespace` or the serialize function on any `Ports`.
        Entries whose name does not correspond to a port in this namespace are copied to the result unchanged.

        :param mapping: a mapping of values to be serialized
        :param breadcrumbs: a tuple with the namespaces of parent namespaces
        :returns: the serialized mapping, or ``None`` if ``mapping`` is ``None``
        :raises TypeError: if ``mapping`` is neither ``None`` nor a mapping
        :raises AssertionError: if a matching port is neither a ``PortNamespace`` nor an ``InputPort``
        """
        if mapping is None:
            return None

        # Extend the trail with this namespace so that error messages can report the full port path.
        breadcrumbs = (*breadcrumbs, self.name)

        if not isinstance(mapping, Mapping):
            port_name = breadcrumbs_to_port(breadcrumbs)
            raise TypeError(f'port namespace `{port_name}` received `{type(mapping)}` instead of a dictionary')

        result = {}

        for name, value in mapping.items():
            if name not in self:
                # No port is defined for this name: pass the value through untouched.
                result[name] = value
                continue

            port = self[name]

            if isinstance(port, PortNamespace):
                result[name] = port.serialize(value, breadcrumbs)
            elif isinstance(port, InputPort):
                result[name] = port.serialize(value)
            else:
                raise AssertionError(f'port does not have a serialize method: {port}')

        return result