# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
# pylint: disable=invalid-name,inconsistent-return-statements,cyclic-import
"""Definition of factories to load classes from the various plugin groups."""
from inspect import isclass
from aiida.common.exceptions import InvalidEntryPointTypeError
# Public API of this module: the generic entry point loader plus one factory per plugin group.
__all__ = (
    'BaseFactory',
    'CalculationFactory',
    'DataFactory',
    'DbImporterFactory',
    'GroupFactory',
    'OrbitalFactory',
    'ParserFactory',
    'SchedulerFactory',
    'TransportFactory',
    'WorkflowFactory',
)
def raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes):
    """Unconditionally raise an `InvalidEntryPointTypeError` describing a badly typed entry point.

    The message names the entry point, its group and the ``__name__`` of every accepted type.

    :param entry_point_name: name of the entry point
    :param entry_point_group: name of the entry point group
    :param valid_classes: tuple of classes (or decorator functions) accepted for the group; only
        their ``__name__`` attribute is used
    :raises aiida.common.InvalidEntryPointTypeError: always
    """
    accepted = ', '.join(candidate.__name__ for candidate in valid_classes)
    message = 'entry point `{name}` registered in group `{group}` is invalid because its type is not one of: {accepted}'
    raise InvalidEntryPointTypeError(message.format(name=entry_point_name, group=entry_point_group, accepted=accepted))
def BaseFactory(group, name):
    """Load and return the plugin registered as ``name`` within the entry point ``group``.

    This is the generic loader that every group-specific factory in this module builds on; it does
    not validate the type of what it loads.

    :param group: entry point group
    :param name: entry point name
    :return: the plugin class
    :raises aiida.common.MissingEntryPointError: entry point was not registered
    :raises aiida.common.MultipleEntryPointError: entry point could not be uniquely resolved
    :raises aiida.common.LoadingEntryPointError: entry point could not be loaded
    """
    # Imported inside the function rather than at module level; presumably to sidestep the cyclic
    # import that the module-level pylint pragma mentions.
    from .entry_point import load_entry_point as _load

    plugin = _load(group, name)
    return plugin
def CalculationFactory(entry_point_name):
    """Return the `CalcJob` sub class or `calcfunction` registered under the given entry point.

    The plugin is loaded from the ``aiida.calculations`` entry point group. Errors raised by
    :func:`BaseFactory` while resolving or loading the entry point propagate unchanged.

    :param entry_point_name: the entry point name
    :return: sub class of :py:class:`~aiida.engine.processes.calcjobs.calcjob.CalcJob` or a `calcfunction`
    :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid.
    """
    from aiida.engine import CalcJob, calcfunction, is_process_function
    from aiida.orm import CalcFunctionNode

    entry_point_group = 'aiida.calculations'
    entry_point = BaseFactory(entry_point_group, entry_point_name)
    # `calcfunction` is listed only so its name shows up in the error message.
    valid_classes = (CalcJob, calcfunction)

    if isclass(entry_point) and issubclass(entry_point, CalcJob):
        return entry_point

    # Process functions carry a `node_class`; checking it distinguishes calcfunctions from workfunctions.
    if is_process_function(entry_point) and entry_point.node_class is CalcFunctionNode:
        return entry_point

    raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
def DataFactory(entry_point_name):
    """Load the `Data` plugin registered in the ``aiida.data`` entry point group.

    Errors raised by :func:`BaseFactory` while resolving or loading the entry point propagate unchanged.

    :param entry_point_name: the entry point name
    :return: sub class of :py:class:`~aiida.orm.nodes.data.data.Data`
    :raises aiida.common.InvalidEntryPointTypeError: if the loaded object is not a `Data` sub class.
    """
    from aiida.orm import Data

    plugin_group = 'aiida.data'
    plugin = BaseFactory(plugin_group, entry_point_name)

    # Guard clause: reject anything that is not a class deriving from `Data` (the helper always raises).
    if not (isclass(plugin) and issubclass(plugin, Data)):
        raise_invalid_type_error(entry_point_name, plugin_group, (Data,))

    return plugin
def DbImporterFactory(entry_point_name):
    """Load the `DbImporter` plugin registered in the ``aiida.tools.dbimporters`` entry point group.

    Errors raised by :func:`BaseFactory` while resolving or loading the entry point propagate unchanged.

    :param entry_point_name: the entry point name
    :return: sub class of :py:class:`~aiida.tools.dbimporters.baseclasses.DbImporter`
    :raises aiida.common.InvalidEntryPointTypeError: if the loaded object is not a `DbImporter` sub class.
    """
    from aiida.tools.dbimporters import DbImporter

    base_class = DbImporter
    plugin_group = 'aiida.tools.dbimporters'
    plugin = BaseFactory(plugin_group, entry_point_name)

    # Guard clause: the helper always raises, so only valid classes reach the return.
    if not isclass(plugin) or not issubclass(plugin, base_class):
        raise_invalid_type_error(entry_point_name, plugin_group, (base_class,))

    return plugin
def GroupFactory(entry_point_name):
    """Load the `Group` plugin registered in the ``aiida.groups`` entry point group.

    Errors raised by :func:`BaseFactory` while resolving or loading the entry point propagate unchanged.

    :param entry_point_name: the entry point name
    :return: sub class of :py:class:`~aiida.orm.groups.Group`
    :raises aiida.common.InvalidEntryPointTypeError: if the loaded object is not a `Group` sub class.
    """
    from aiida.orm import Group

    plugin_group = 'aiida.groups'
    plugin = BaseFactory(plugin_group, entry_point_name)

    is_valid = isclass(plugin) and issubclass(plugin, Group)
    if is_valid:
        return plugin

    raise_invalid_type_error(entry_point_name, plugin_group, (Group,))
def OrbitalFactory(entry_point_name):
    """Load the `Orbital` plugin registered in the ``aiida.tools.data.orbitals`` entry point group.

    Errors raised by :func:`BaseFactory` while resolving or loading the entry point propagate unchanged.

    :param entry_point_name: the entry point name
    :return: sub class of :py:class:`~aiida.tools.data.orbital.orbital.Orbital`
    :raises aiida.common.InvalidEntryPointTypeError: if the loaded object is not an `Orbital` sub class.
    """
    from aiida.tools.data.orbital import Orbital

    plugin_group = 'aiida.tools.data.orbitals'
    plugin = BaseFactory(plugin_group, entry_point_name)

    # Guard clause: the helper always raises, so only valid classes reach the return.
    if not (isclass(plugin) and issubclass(plugin, Orbital)):
        raise_invalid_type_error(entry_point_name, plugin_group, (Orbital,))

    return plugin
def ParserFactory(entry_point_name):
    """Load the `Parser` plugin registered in the ``aiida.parsers`` entry point group.

    Errors raised by :func:`BaseFactory` while resolving or loading the entry point propagate unchanged.

    :param entry_point_name: the entry point name
    :return: sub class of :py:class:`~aiida.parsers.parser.Parser`
    :raises aiida.common.InvalidEntryPointTypeError: if the loaded object is not a `Parser` sub class.
    """
    from aiida.parsers import Parser

    base_class = Parser
    plugin_group = 'aiida.parsers'
    plugin = BaseFactory(plugin_group, entry_point_name)

    if isclass(plugin) and issubclass(plugin, base_class):
        return plugin

    raise_invalid_type_error(entry_point_name, plugin_group, (base_class,))
def SchedulerFactory(entry_point_name):
    """Load the `Scheduler` plugin registered in the ``aiida.schedulers`` entry point group.

    Errors raised by :func:`BaseFactory` while resolving or loading the entry point propagate unchanged.

    :param entry_point_name: the entry point name
    :return: sub class of :py:class:`~aiida.schedulers.scheduler.Scheduler`
    :raises aiida.common.InvalidEntryPointTypeError: if the loaded object is not a `Scheduler` sub class.
    """
    from aiida.schedulers import Scheduler

    plugin_group = 'aiida.schedulers'
    plugin = BaseFactory(plugin_group, entry_point_name)

    # Guard clause: the helper always raises, so only valid classes reach the return.
    if not isclass(plugin) or not issubclass(plugin, Scheduler):
        raise_invalid_type_error(entry_point_name, plugin_group, (Scheduler,))

    return plugin
def TransportFactory(entry_point_name):
    """Load the `Transport` plugin registered in the ``aiida.transports`` entry point group.

    Errors raised by :func:`BaseFactory` while resolving or loading the entry point propagate unchanged.

    :param entry_point_name: the entry point name
    :return: sub class of :py:class:`~aiida.transports.transport.Transport`
    :raises aiida.common.InvalidEntryPointTypeError: if the loaded object is not a `Transport` sub class.
    """
    from aiida.transports import Transport

    plugin_group = 'aiida.transports'
    plugin = BaseFactory(plugin_group, entry_point_name)

    is_valid = isclass(plugin) and issubclass(plugin, Transport)
    if is_valid:
        return plugin

    raise_invalid_type_error(entry_point_name, plugin_group, (Transport,))
def WorkflowFactory(entry_point_name):
    """Return the `WorkChain` sub class or `workfunction` registered under the given entry point.

    The plugin is loaded from the ``aiida.workflows`` entry point group. Errors raised by
    :func:`BaseFactory` while resolving or loading the entry point propagate unchanged.

    :param entry_point_name: the entry point name
    :return: sub class of :py:class:`~aiida.engine.processes.workchains.workchain.WorkChain` or a `workfunction`
    :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid.
    """
    from aiida.engine import WorkChain, is_process_function, workfunction
    from aiida.orm import WorkFunctionNode

    plugin_group = 'aiida.workflows'
    plugin = BaseFactory(plugin_group, entry_point_name)

    # Accept either a WorkChain class, or a process function whose `node_class` marks it as a workfunction
    # (checking `node_class` is what separates workfunctions from calcfunctions).
    if (isclass(plugin) and issubclass(plugin, WorkChain)) or (
        is_process_function(plugin) and plugin.node_class is WorkFunctionNode
    ):
        return plugin

    # `workfunction` is listed only so its name shows up in the error message.
    raise_invalid_type_error(entry_point_name, plugin_group, (WorkChain, workfunction))