Source code for aiida.plugins.factories

###########################################################################
# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file        #
# For further information please visit http://www.aiida.net               #
###########################################################################
# ruff: noqa: N802
"""Definition of factories to load classes from the various plugin groups."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, Tuple, Type, Union, overload

from aiida.common.exceptions import InvalidEntryPointTypeError

__all__ = (
    'BaseFactory',
    'BrokerFactory',
    'CalculationFactory',
    'CalcJobImporterFactory',
    'DataFactory',
    'DbImporterFactory',
    'GroupFactory',
    'OrbitalFactory',
    'ParserFactory',
    'SchedulerFactory',
    'StorageFactory',
    'TransportFactory',
    'WorkflowFactory',
)

if TYPE_CHECKING:
    from importlib_metadata import EntryPoint

    from aiida.brokers import Broker
    from aiida.engine import CalcJob, CalcJobImporter, WorkChain
    from aiida.orm import Data, Group
    from aiida.orm.implementation import StorageBackend
    from aiida.parsers import Parser
    from aiida.schedulers import Scheduler
    from aiida.tools.data.orbital import Orbital
    from aiida.tools.dbimporters import DbImporter
    from aiida.transports import Transport


def raise_invalid_type_error(entry_point_name: str, entry_point_group: str, valid_classes: Tuple[Any, ...]) -> NoReturn:
    """Unconditionally raise an ``InvalidEntryPointTypeError`` for an entry point of an unaccepted type.

    :param entry_point_name: name of the offending entry point
    :param entry_point_group: name of the group in which the entry point is registered
    :param valid_classes: the accepted classes (or callables) for the group; their ``__name__`` values are listed in
        the error message
    :raises aiida.common.InvalidEntryPointTypeError: always
    """
    accepted = ', '.join(cls.__name__ for cls in valid_classes)
    message = (
        f'entry point `{entry_point_name}` registered in group `{entry_point_group}` '
        f'is invalid because its type is not one of: {accepted}'
    )
    raise InvalidEntryPointTypeError(message)
def BaseFactory(group: str, name: str, load: bool = True) -> Union[EntryPoint, Any]:
    """Return the plugin class registered under a given entry point group and name.

    :param group: entry point group
    :param name: entry point name
    :param load: if True, load the matched entry point and return the loaded resource instead of the entry point
        itself.
    :return: the loaded plugin resource if ``load`` is truthy, otherwise the matched entry point itself.
    :raises aiida.common.MissingEntryPointError: entry point was not registered
    :raises aiida.common.MultipleEntryPointError: entry point could not be uniquely resolved
    :raises aiida.common.LoadingEntryPointError: entry point could not be loaded
    """
    from .entry_point import get_entry_point, load_entry_point

    # Test truthiness rather than ``load is True``: the factories in this module decide whether to validate the
    # result with ``if not load``, so an identity check would hand them an unloaded ``EntryPoint`` whenever ``load``
    # is truthy but not the ``True`` singleton (e.g. ``1``), causing a spurious ``InvalidEntryPointTypeError``.
    if load:
        return load_entry_point(group, name)
    return get_entry_point(group, name)
@overload
def BrokerFactory(entry_point_name: str, load: Literal[True] = True) -> Type['Broker']: ...


@overload
def BrokerFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint: ...


def BrokerFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoint, Type['Broker'], Callable]:
    """Look up a plugin registered in the ``aiida.brokers`` entry point group.

    :param entry_point_name: name under which the plugin is registered.
    :param load: if True, load the matched entry point and return the validated resource; if False, return the
        entry point itself without validation.
    :return: sub class of :py:class:`~aiida.brokers.broker.Broker` when loading, otherwise the entry point.
    :raises aiida.common.InvalidEntryPointTypeError: if the loaded resource is not a ``Broker`` sub class.
    """
    from inspect import isclass

    from aiida.brokers import Broker

    group = 'aiida.brokers'
    resource = BaseFactory(group, entry_point_name, load=load)

    if load:
        # Only classes deriving from ``Broker`` are acceptable plugins for this group.
        if isclass(resource) and issubclass(resource, Broker):
            return resource
        raise_invalid_type_error(entry_point_name, group, (Broker,))

    return resource
@overload
def CalculationFactory(entry_point_name: str, load: Literal[True] = True) -> Union[Type['CalcJob'], Callable]: ...


@overload
def CalculationFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint: ...


def CalculationFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoint, Type['CalcJob'], Callable]:
    """Return the `CalcJob` sub class or `calcfunction` registered under the given entry point.

    :param entry_point_name: the entry point name.
    :param load: if True, load the matched entry point and return the loaded resource instead of the entry point
        itself.
    :return: sub class of :py:class:`~aiida.engine.processes.calcjobs.calcjob.CalcJob` or a `calcfunction`
    :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid.
    """
    from inspect import isclass

    from aiida.engine import CalcJob, calcfunction, is_process_function
    from aiida.orm import CalcFunctionNode

    entry_point_group = 'aiida.calculations'
    entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
    valid_classes = (CalcJob, calcfunction)

    if not load:
        return entry_point

    # Accept either a ``CalcJob`` sub class, or a process function whose ``node_class`` is ``CalcFunctionNode``;
    # process functions backed by any other node class are rejected.
    if (isclass(entry_point) and issubclass(entry_point, CalcJob)) or (
        is_process_function(entry_point) and entry_point.node_class is CalcFunctionNode  # type: ignore[union-attr]
    ):
        return entry_point

    raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
@overload
def CalcJobImporterFactory(entry_point_name: str, load: Literal[True] = True) -> Type['CalcJobImporter']: ...


@overload
def CalcJobImporterFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint: ...


def CalcJobImporterFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoint, Type['CalcJobImporter']]:
    """Return the `CalcJobImporter` sub class registered under the given entry point.

    :param entry_point_name: the entry point name.
    :param load: if True, load the matched entry point and return the loaded resource instead of the entry point
        itself.
    :return: sub class of :py:class:`~aiida.engine.processes.calcjobs.importer.CalcJobImporter`
    :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid.
    """
    from inspect import isclass

    from aiida.engine import CalcJobImporter

    entry_point_group = 'aiida.calculations.importers'
    entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
    valid_classes = (CalcJobImporter,)

    # Without loading there is no resource to validate, so the entry point is returned as is.
    if not load:
        return entry_point

    if isclass(entry_point) and issubclass(entry_point, CalcJobImporter):
        return entry_point

    raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
@overload
def DataFactory(entry_point_name: str, load: Literal[True] = True) -> Type['Data']: ...


@overload
def DataFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint: ...


def DataFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoint, Type['Data']]:
    """Look up a plugin registered in the ``aiida.data`` entry point group.

    :param entry_point_name: name under which the plugin is registered.
    :param load: if True, load the matched entry point and return the validated resource; if False, return the
        entry point itself without validation.
    :return: sub class of :py:class:`~aiida.orm.nodes.data.data.Data` when loading, otherwise the entry point.
    :raises aiida.common.InvalidEntryPointTypeError: if the loaded resource is not a ``Data`` sub class.
    """
    from inspect import isclass

    from aiida.orm import Data

    group = 'aiida.data'
    resource = BaseFactory(group, entry_point_name, load=load)

    if load:
        # Only classes deriving from ``Data`` are acceptable plugins for this group.
        if isclass(resource) and issubclass(resource, Data):
            return resource
        raise_invalid_type_error(entry_point_name, group, (Data,))

    return resource
@overload
def DbImporterFactory(entry_point_name: str, load: Literal[True] = True) -> Type['DbImporter']: ...


@overload
def DbImporterFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint: ...


def DbImporterFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoint, Type['DbImporter']]:
    """Look up a plugin registered in the ``aiida.tools.dbimporters`` entry point group.

    :param entry_point_name: name under which the plugin is registered.
    :param load: if True, load the matched entry point and return the validated resource; if False, return the
        entry point itself without validation.
    :return: sub class of :py:class:`~aiida.tools.dbimporters.baseclasses.DbImporter` when loading, otherwise the
        entry point.
    :raises aiida.common.InvalidEntryPointTypeError: if the loaded resource is not a ``DbImporter`` sub class.
    """
    from inspect import isclass

    from aiida.tools.dbimporters import DbImporter

    group = 'aiida.tools.dbimporters'
    resource = BaseFactory(group, entry_point_name, load=load)

    if load:
        # Only classes deriving from ``DbImporter`` are acceptable plugins for this group.
        if isclass(resource) and issubclass(resource, DbImporter):
            return resource
        raise_invalid_type_error(entry_point_name, group, (DbImporter,))

    return resource
@overload
def GroupFactory(entry_point_name: str, load: Literal[True] = True) -> Type['Group']: ...


@overload
def GroupFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint: ...


def GroupFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoint, Type['Group']]:
    """Look up a plugin registered in the ``aiida.groups`` entry point group.

    :param entry_point_name: name under which the plugin is registered.
    :param load: if True, load the matched entry point and return the validated resource; if False, return the
        entry point itself without validation.
    :return: sub class of :py:class:`~aiida.orm.groups.Group` when loading, otherwise the entry point.
    :raises aiida.common.InvalidEntryPointTypeError: if the loaded resource is not a ``Group`` sub class.
    """
    from inspect import isclass

    from aiida.orm import Group

    group = 'aiida.groups'
    resource = BaseFactory(group, entry_point_name, load=load)

    if load:
        # Only classes deriving from ``Group`` are acceptable plugins for this group.
        if isclass(resource) and issubclass(resource, Group):
            return resource
        raise_invalid_type_error(entry_point_name, group, (Group,))

    return resource
@overload
def OrbitalFactory(entry_point_name: str, load: Literal[True] = True) -> Type['Orbital']: ...


@overload
def OrbitalFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint: ...


def OrbitalFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoint, Type['Orbital']]:
    """Look up a plugin registered in the ``aiida.tools.data.orbitals`` entry point group.

    :param entry_point_name: name under which the plugin is registered.
    :param load: if True, load the matched entry point and return the validated resource; if False, return the
        entry point itself without validation.
    :return: sub class of :py:class:`~aiida.tools.data.orbital.orbital.Orbital` when loading, otherwise the entry
        point.
    :raises aiida.common.InvalidEntryPointTypeError: if the loaded resource is not an ``Orbital`` sub class.
    """
    from inspect import isclass

    from aiida.tools.data.orbital import Orbital

    group = 'aiida.tools.data.orbitals'
    resource = BaseFactory(group, entry_point_name, load=load)

    if load:
        # Only classes deriving from ``Orbital`` are acceptable plugins for this group.
        if isclass(resource) and issubclass(resource, Orbital):
            return resource
        raise_invalid_type_error(entry_point_name, group, (Orbital,))

    return resource
@overload
def ParserFactory(entry_point_name: str, load: Literal[True] = True) -> Type['Parser']: ...


@overload
def ParserFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint: ...


def ParserFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoint, Type['Parser']]:
    """Look up a plugin registered in the ``aiida.parsers`` entry point group.

    :param entry_point_name: name under which the plugin is registered.
    :param load: if True, load the matched entry point and return the validated resource; if False, return the
        entry point itself without validation.
    :return: sub class of :py:class:`~aiida.parsers.parser.Parser` when loading, otherwise the entry point.
    :raises aiida.common.InvalidEntryPointTypeError: if the loaded resource is not a ``Parser`` sub class.
    """
    from inspect import isclass

    from aiida.parsers import Parser

    group = 'aiida.parsers'
    resource = BaseFactory(group, entry_point_name, load=load)

    if load:
        # Only classes deriving from ``Parser`` are acceptable plugins for this group.
        if isclass(resource) and issubclass(resource, Parser):
            return resource
        raise_invalid_type_error(entry_point_name, group, (Parser,))

    return resource
@overload
def SchedulerFactory(entry_point_name: str, load: Literal[True] = True) -> Type['Scheduler']: ...


@overload
def SchedulerFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint: ...


def SchedulerFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoint, Type['Scheduler']]:
    """Look up a plugin registered in the ``aiida.schedulers`` entry point group.

    :param entry_point_name: name under which the plugin is registered.
    :param load: if True, load the matched entry point and return the validated resource; if False, return the
        entry point itself without validation.
    :return: sub class of :py:class:`~aiida.schedulers.scheduler.Scheduler` when loading, otherwise the entry point.
    :raises aiida.common.InvalidEntryPointTypeError: if the loaded resource is not a ``Scheduler`` sub class.
    """
    from inspect import isclass

    from aiida.schedulers import Scheduler

    group = 'aiida.schedulers'
    resource = BaseFactory(group, entry_point_name, load=load)

    if load:
        # Only classes deriving from ``Scheduler`` are acceptable plugins for this group.
        if isclass(resource) and issubclass(resource, Scheduler):
            return resource
        raise_invalid_type_error(entry_point_name, group, (Scheduler,))

    return resource
@overload
def StorageFactory(entry_point_name: str, load: Literal[True] = True) -> Type['StorageBackend']: ...


@overload
def StorageFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint: ...


def StorageFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoint, Type['StorageBackend']]:
    """Look up a plugin registered in the ``aiida.storage`` entry point group.

    :param entry_point_name: name under which the plugin is registered.
    :param load: if True, load the matched entry point and return the validated resource; if False, return the
        entry point itself without validation.
    :return: sub class of :py:class:`~aiida.orm.implementation.storage_backend.StorageBackend` when loading,
        otherwise the entry point.
    :raises aiida.common.InvalidEntryPointTypeError: if the loaded resource is not a ``StorageBackend`` sub class.
    """
    from inspect import isclass

    from aiida.orm.implementation import StorageBackend

    group = 'aiida.storage'
    resource = BaseFactory(group, entry_point_name, load=load)

    if load:
        # Only classes deriving from ``StorageBackend`` are acceptable plugins for this group.
        if isclass(resource) and issubclass(resource, StorageBackend):
            return resource
        raise_invalid_type_error(entry_point_name, group, (StorageBackend,))

    return resource
@overload
def TransportFactory(entry_point_name: str, load: Literal[True] = True) -> Type['Transport']: ...


@overload
def TransportFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint: ...


def TransportFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoint, Type['Transport']]:
    """Return the `Transport` sub class registered under the given entry point.

    :param entry_point_name: the entry point name.
    :param load: if True, load the matched entry point and return the loaded resource instead of the entry point
        itself.
    :return: sub class of :py:class:`~aiida.transports.Transport`
    :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid.
    """
    from inspect import isclass

    from aiida.transports import Transport

    entry_point_group = 'aiida.transports'
    entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
    valid_classes = (Transport,)

    # Without loading there is no resource to validate, so the entry point is returned as is.
    if not load:
        return entry_point

    if isclass(entry_point) and issubclass(entry_point, Transport):
        return entry_point

    raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
@overload
def WorkflowFactory(entry_point_name: str, load: Literal[True] = True) -> Union[Type['WorkChain'], Callable]: ...


@overload
def WorkflowFactory(entry_point_name: str, load: Literal[False]) -> EntryPoint: ...


def WorkflowFactory(entry_point_name: str, load: bool = True) -> Union[EntryPoint, Type['WorkChain'], Callable]:
    """Look up a plugin registered in the ``aiida.workflows`` entry point group.

    :param entry_point_name: name under which the plugin is registered.
    :param load: if True, load the matched entry point and return the validated resource; if False, return the
        entry point itself without validation.
    :return: sub class of :py:class:`~aiida.engine.processes.workchains.workchain.WorkChain` or a `workfunction`
        when loading, otherwise the entry point.
    :raises aiida.common.InvalidEntryPointTypeError: if the loaded resource is neither a ``WorkChain`` sub class nor
        a process function whose node class is ``WorkFunctionNode``.
    """
    from inspect import isclass

    from aiida.engine import WorkChain, is_process_function, workfunction
    from aiida.orm import WorkFunctionNode

    group = 'aiida.workflows'
    resource = BaseFactory(group, entry_point_name, load=load)

    if not load:
        return resource

    # First accepted form: a ``WorkChain`` sub class.
    if isclass(resource) and issubclass(resource, WorkChain):
        return resource

    # Second accepted form: a process function backed by ``WorkFunctionNode``; checked only if the first failed.
    if is_process_function(resource) and resource.node_class is WorkFunctionNode:  # type: ignore[union-attr]
        return resource

    raise_invalid_type_error(entry_point_name, group, (WorkChain, workfunction))