# Source code for aiida.plugins.factories

# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file        #
# For further information please visit http://www.aiida.net               #
###########################################################################
# pylint: disable=invalid-name,inconsistent-return-statements,cyclic-import
"""Definition of factories to load classes from the various plugin groups."""

from inspect import isclass
from aiida.common.exceptions import InvalidEntryPointTypeError

__all__ = (
    'BaseFactory', 'CalculationFactory', 'DataFactory', 'DbImporterFactory', 'GroupFactory', 'OrbitalFactory',
    'ParserFactory', 'SchedulerFactory', 'TransportFactory', 'WorkflowFactory'
)


def raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes):
    """Unconditionally raise an `InvalidEntryPointTypeError` for an entry point of unexpected type.

    The message lists the ``__name__`` of every object in ``valid_classes``. Each entry must therefore expose that
    attribute; this holds for classes and for decorator functions such as ``calcfunction``.

    :param entry_point_name: name of the entry point that was loaded
    :param entry_point_group: name of the entry point group it is registered in
    :param valid_classes: tuple of the objects that would have been acceptable for the given group
    :raises aiida.common.InvalidEntryPointTypeError: always
    """
    accepted = ', '.join(valid.__name__ for valid in valid_classes)
    message = 'entry point `{}` registered in group `{}` is invalid because its type is not one of: {}'.format(
        entry_point_name, entry_point_group, accepted
    )
    raise InvalidEntryPointTypeError(message)


def BaseFactory(group, name):
    """Load and return the plugin registered as ``name`` within the entry point ``group``.

    This delegates to ``load_entry_point`` from the sibling ``entry_point`` module. The import is done inside the
    function; presumably this avoids an import cycle, since the module disables pylint's ``cyclic-import`` check.

    :param group: entry point group
    :param name: entry point name
    :return: the plugin class
    :raises aiida.common.MissingEntryPointError: entry point was not registered
    :raises aiida.common.MultipleEntryPointError: entry point could not be uniquely resolved
    :raises aiida.common.LoadingEntryPointError: entry point could not be loaded
    """
    from .entry_point import load_entry_point as _load

    plugin = _load(group, name)
    return plugin


def CalculationFactory(entry_point_name):
    """Return the `CalcJob` sub class or `calcfunction` registered under the given entry point.

    :param entry_point_name: the entry point name
    :return: sub class of :py:class:`~aiida.engine.processes.calcjobs.calcjob.CalcJob` or a `calcfunction`
    :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid.
    """
    from aiida.engine import CalcJob, calcfunction, is_process_function
    from aiida.orm import CalcFunctionNode

    entry_point_group = 'aiida.calculations'
    entry_point = BaseFactory(entry_point_group, entry_point_name)
    # `calcfunction` is only used to name the accepted kind in the error message below; the actual check for
    # process functions is performed on their `node_class`.
    valid_classes = (CalcJob, calcfunction)

    # `isclass` must come first because `issubclass` raises `TypeError` for non-class arguments.
    if isclass(entry_point) and issubclass(entry_point, CalcJob):
        return entry_point

    # Accept a process function only if it creates `CalcFunctionNode`s. Any other process function registered in
    # this group is rejected, e.g. one whose `node_class` is `WorkFunctionNode`.
    if is_process_function(entry_point) and entry_point.node_class is CalcFunctionNode:
        return entry_point

    raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)


def DataFactory(entry_point_name):
    """Load the `Data` sub class registered under ``entry_point_name`` in the ``aiida.data`` group.

    :param entry_point_name: the entry point name
    :return: sub class of :py:class:`~aiida.orm.nodes.data.data.Data`
    :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid.
    """
    from aiida.orm import Data

    group = 'aiida.data'
    plugin = BaseFactory(group, entry_point_name)

    # Short-circuit on `isclass` so `issubclass` never sees a non-class (it would raise `TypeError`).
    if not isclass(plugin) or not issubclass(plugin, Data):
        raise_invalid_type_error(entry_point_name, group, (Data,))
    else:
        return plugin


def DbImporterFactory(entry_point_name):
    """Load the `DbImporter` sub class registered under ``entry_point_name`` in ``aiida.tools.dbimporters``.

    :param entry_point_name: the entry point name
    :return: sub class of :py:class:`~aiida.tools.dbimporters.baseclasses.DbImporter`
    :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid.
    """
    from aiida.tools.dbimporters import DbImporter

    group = 'aiida.tools.dbimporters'
    plugin = BaseFactory(group, entry_point_name)

    # Short-circuit on `isclass` so `issubclass` never sees a non-class (it would raise `TypeError`).
    if not isclass(plugin) or not issubclass(plugin, DbImporter):
        raise_invalid_type_error(entry_point_name, group, (DbImporter,))
    else:
        return plugin


def GroupFactory(entry_point_name):
    """Load the `Group` sub class registered under ``entry_point_name`` in the ``aiida.groups`` group.

    :param entry_point_name: the entry point name
    :return: sub class of :py:class:`~aiida.orm.groups.Group`
    :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid.
    """
    from aiida.orm import Group

    group = 'aiida.groups'
    plugin = BaseFactory(group, entry_point_name)

    # Short-circuit on `isclass` so `issubclass` never sees a non-class (it would raise `TypeError`).
    if not isclass(plugin) or not issubclass(plugin, Group):
        raise_invalid_type_error(entry_point_name, group, (Group,))
    else:
        return plugin


def OrbitalFactory(entry_point_name):
    """Load the `Orbital` sub class registered under ``entry_point_name`` in ``aiida.tools.data.orbitals``.

    :param entry_point_name: the entry point name
    :return: sub class of :py:class:`~aiida.tools.data.orbital.orbital.Orbital`
    :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid.
    """
    from aiida.tools.data.orbital import Orbital

    group = 'aiida.tools.data.orbitals'
    plugin = BaseFactory(group, entry_point_name)

    # Short-circuit on `isclass` so `issubclass` never sees a non-class (it would raise `TypeError`).
    if not isclass(plugin) or not issubclass(plugin, Orbital):
        raise_invalid_type_error(entry_point_name, group, (Orbital,))
    else:
        return plugin


def ParserFactory(entry_point_name):
    """Load the `Parser` sub class registered under ``entry_point_name`` in the ``aiida.parsers`` group.

    :param entry_point_name: the entry point name
    :return: sub class of :py:class:`~aiida.parsers.parser.Parser`
    :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid.
    """
    from aiida.parsers import Parser

    group = 'aiida.parsers'
    plugin = BaseFactory(group, entry_point_name)

    # Short-circuit on `isclass` so `issubclass` never sees a non-class (it would raise `TypeError`).
    if not isclass(plugin) or not issubclass(plugin, Parser):
        raise_invalid_type_error(entry_point_name, group, (Parser,))
    else:
        return plugin


def SchedulerFactory(entry_point_name):
    """Load the `Scheduler` sub class registered under ``entry_point_name`` in the ``aiida.schedulers`` group.

    :param entry_point_name: the entry point name
    :return: sub class of :py:class:`~aiida.schedulers.scheduler.Scheduler`
    :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid.
    """
    from aiida.schedulers import Scheduler

    group = 'aiida.schedulers'
    plugin = BaseFactory(group, entry_point_name)

    # Short-circuit on `isclass` so `issubclass` never sees a non-class (it would raise `TypeError`).
    if not isclass(plugin) or not issubclass(plugin, Scheduler):
        raise_invalid_type_error(entry_point_name, group, (Scheduler,))
    else:
        return plugin


def TransportFactory(entry_point_name):
    """Load the `Transport` sub class registered under ``entry_point_name`` in the ``aiida.transports`` group.

    :param entry_point_name: the entry point name
    :return: sub class of :py:class:`~aiida.transports.transport.Transport`
    :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid.
    """
    from aiida.transports import Transport

    group = 'aiida.transports'
    plugin = BaseFactory(group, entry_point_name)

    # Short-circuit on `isclass` so `issubclass` never sees a non-class (it would raise `TypeError`).
    if not isclass(plugin) or not issubclass(plugin, Transport):
        raise_invalid_type_error(entry_point_name, group, (Transport,))
    else:
        return plugin


def WorkflowFactory(entry_point_name):
    """Load the `WorkChain` sub class or `workfunction` registered under ``entry_point_name`` in ``aiida.workflows``.

    :param entry_point_name: the entry point name
    :return: sub class of :py:class:`~aiida.engine.processes.workchains.workchain.WorkChain` or a `workfunction`
    :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid.
    """
    from aiida.engine import WorkChain, is_process_function, workfunction
    from aiida.orm import WorkFunctionNode

    group = 'aiida.workflows'
    plugin = BaseFactory(group, entry_point_name)

    # `isclass` is checked first because `issubclass` raises `TypeError` for non-class arguments.
    is_work_chain = isclass(plugin) and issubclass(plugin, WorkChain)
    if is_work_chain:
        return plugin

    # Only process functions that create `WorkFunctionNode`s are accepted here.
    is_work_function = is_process_function(plugin) and plugin.node_class is WorkFunctionNode
    if is_work_function:
        return plugin

    # `workfunction` is listed only so its `__name__` appears among the accepted kinds in the error message.
    raise_invalid_type_error(entry_point_name, group, (WorkChain, workfunction))