Source code for aiida.cmdline.params.types.plugin

###########################################################################
# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file        #
# For further information please visit http://www.aiida.net               #
###########################################################################
"""Click parameter type for AiiDA Plugins."""

from __future__ import annotations

import functools
import typing as t

import click

from aiida.common import exceptions
from aiida.plugins import factories
from aiida.plugins.entry_point import (
    ENTRY_POINT_GROUP_PREFIX,
    ENTRY_POINT_STRING_SEPARATOR,
    EntryPointFormat,
    format_entry_point_string,
    get_entry_point,
    get_entry_point_groups,
    get_entry_point_string_format,
    get_entry_points,
)

from .strings import EntryPointType

if t.TYPE_CHECKING:
    from importlib_metadata import EntryPoint

# Public API of this module: only the click parameter type is exported.
__all__ = ('PluginParamType',)


class PluginParamType(EntryPointType):
    """AiiDA Plugin name parameter type.

    :param group: string or tuple (or list) of strings, where each is a valid entry point group. Adding the
        `aiida.` prefix is optional. If it is not detected it will be prepended internally.
    :param load: when set to True, convert will not return the entry point, but the loaded entry point

    Usage::

        click.option(... type=PluginParamType(group='aiida.calculations')

    or::

        click.option(... type=PluginParamType(group=('calculations', 'data'))

    """

    name = 'plugin'

    # Factory callable for each supported entry point group.
    _factory_mapping = {
        'aiida.calculations': factories.CalculationFactory,
        'aiida.data': factories.DataFactory,
        'aiida.groups': factories.GroupFactory,
        'aiida.parsers': factories.ParserFactory,
        'aiida.schedulers': factories.SchedulerFactory,
        'aiida.transports': factories.TransportFactory,
        'aiida.tools.dbimporters': factories.DbImporterFactory,
        'aiida.tools.data.orbitals': factories.OrbitalFactory,
        'aiida.workflows': factories.WorkflowFactory,
    }

    def __init__(
        self,
        group: str | tuple[str, ...] | list[str] | None = None,
        load: bool = False,
        *args,
        **kwargs,
    ):
        """Group should be either a string or a tuple of valid entry point groups.

        If it is not specified we use the tuple of all recognized entry point groups. The groups are validated
        lazily, on first access of :attr:`groups`, not here.

        :param group: entry point group(s) accepted by this parameter, with or without the `aiida.` prefix.
        :param load: if True, :meth:`convert` returns the loaded entry point instead of the entry point itself.
        """
        self.load = load
        self._input_group = group
        super().__init__(*args, **kwargs)

    @functools.cached_property
    def groups(self) -> tuple[str, ...]:
        """Return the tuple of valid entry point groups for this instance.

        The result is computed on first access and cached. Names missing the `aiida.` prefix get it prepended,
        and duplicates (e.g. `'data'` and `'aiida.data'`) are collapsed while keeping first-seen order.

        :raises ValueError: if the group argument has an unsupported type or names an unrecognized group.
        """
        group = self._input_group
        valid_entry_point_groups = get_entry_point_groups()

        if group is None:
            return tuple(valid_entry_point_groups)

        if isinstance(group, str):
            unvalidated_groups = (group,)
        elif isinstance(group, (tuple, list)):
            unvalidated_groups = tuple(group)
        else:
            raise ValueError('invalid type for group')

        groups: list[str] = []
        for grp in unvalidated_groups:
            full_group = grp if grp.startswith(ENTRY_POINT_GROUP_PREFIX) else ENTRY_POINT_GROUP_PREFIX + grp
            if full_group not in valid_entry_point_groups:
                raise ValueError(f'entry point group {full_group} is not recognized')
            # A repeated group would list each of its entry points twice, which would make every name in it look
            # ambiguous (see `has_potential_ambiguity`) and duplicate the completion suggestions.
            if full_group not in groups:
                groups.append(full_group)

        return tuple(groups)

    @functools.cached_property
    def _entry_points(self) -> list[tuple[str, EntryPoint]]:
        """Return every `(group, entry_point)` pair for the configured groups (cached after first access)."""
        return [(group, entry_point) for group in self.groups for entry_point in get_entry_points(group)]

    @functools.cached_property
    def _entry_point_names(self) -> list[str]:
        """Return the names of all entry points in the configured groups, possibly with repeats across groups."""
        return [entry_point.name for _, entry_point in self._entry_points]

    @property
    def has_potential_ambiguity(self) -> bool:
        """Returns whether the set of supported entry point groups can lead to ambiguity when only an entry point
        name is specified. This will happen if one or more groups share an entry point with a common name.
        """
        return len(self._entry_point_names) != len(set(self._entry_point_names))

    def get_valid_arguments(self) -> list[str]:
        """Return a list of all available plugin names for the groups configured for this PluginParamType instance.

        If the entry point names are not unique, because there are multiple groups that contain an entry point
        that has an identical name, the names are prefixed with the full group name.

        :returns: sorted list of valid entry point strings
        """
        if self.has_potential_ambiguity:
            return sorted(
                format_entry_point_string(group, entry_point.name, fmt=EntryPointFormat.FULL)
                for group, entry_point in self._entry_points
            )

        return sorted(self._entry_point_names)

    def get_possibilities(self, incomplete: str = '') -> list[str]:
        """Return a list of plugins starting with incomplete.

        :param incomplete: prefix typed so far; an empty string returns all valid arguments.
        :returns: matching entry point strings
        """
        if incomplete == '':
            return self.get_valid_arguments()

        # If there is a chance of ambiguity we always return the entry point string in FULL format, otherwise
        # return the possibilities in the same format as the incomplete. Note that this may have some unexpected
        # effects. For example if incomplete equals `aiida.` or `calculations` it will be detected as the MINIMAL
        # format, even though they would also be the valid beginnings of a FULL or PARTIAL format, except that we
        # cannot know that for sure at this time
        if self.has_potential_ambiguity:
            return [eps for eps in self.get_valid_arguments() if eps.startswith(incomplete)]

        fmt = get_entry_point_string_format(incomplete)
        possibilities = []
        for group, entry_point in self._entry_points:
            entry_point_string = format_entry_point_string(group, entry_point.name, fmt=fmt)
            if entry_point_string.startswith(incomplete):
                possibilities.append(entry_point_string)

        return possibilities

    def shell_complete(
        self, ctx: click.Context | None, param: click.Parameter | None, incomplete: str
    ) -> list[click.shell_completion.CompletionItem]:
        """Return possible completions based on an incomplete value.

        :returns: list of completion items, one per valid entry point string that starts with `incomplete`
        """
        return [click.shell_completion.CompletionItem(p) for p in self.get_possibilities(incomplete=incomplete)]

    def get_missing_message(self, param: click.Parameter) -> str:
        """Return the help text shown when the parameter is missing: the list of valid arguments."""
        return 'Possible arguments are:\n\n' + '\n'.join(self.get_valid_arguments())

    def get_entry_point_from_string(self, entry_point_string: str) -> EntryPoint:
        """Validate a given entry point string, which means that it should have a valid entry point string format
        and that the entry point unambiguously corresponds to an entry point in the groups configured for this
        instance of PluginParamType.

        :returns: the entry point if valid
        :raises ValueError: if the entry point string is invalid, ambiguous, or its group is not supported
        """
        entry_point_format = get_entry_point_string_format(entry_point_string)

        if entry_point_format in (EntryPointFormat.FULL, EntryPointFormat.PARTIAL):
            group, name = entry_point_string.split(ENTRY_POINT_STRING_SEPARATOR)

            if entry_point_format == EntryPointFormat.PARTIAL:
                group = ENTRY_POINT_GROUP_PREFIX + group

            self.validate_entry_point_group(group)

        elif entry_point_format == EntryPointFormat.MINIMAL:
            name = entry_point_string
            # Sorted so that error messages are deterministic (a bare set has no stable iteration order).
            matching_groups = sorted({group for group, entry_point in self._entry_points if entry_point.name == name})

            if len(matching_groups) > 1:
                group_list = ' '.join(matching_groups)
                # Show the fully qualified strings the user can pick from, not the raw entry point objects.
                candidates = ', '.join(
                    format_entry_point_string(group, name, fmt=EntryPointFormat.FULL) for group in matching_groups
                )
                raise ValueError(
                    f"entry point '{name}' matches more than one valid entry point group [{group_list}], "
                    f'please specify an explicit group prefix: {candidates}'
                )
            if not matching_groups:
                raise ValueError(
                    "entry point '{}' is not valid for any of the allowed entry point groups: {}".format(
                        name, ' '.join(self.groups)
                    )
                )

            group = matching_groups[0]

        else:
            raise ValueError(f'invalid entry point string format: {entry_point_string}')

        try:
            return get_entry_point(group, name)
        except exceptions.EntryPointError as exception:
            raise ValueError(exception) from exception

    def validate_entry_point_group(self, group: str) -> None:
        """Check that `group` is one of the groups configured for this parameter.

        :raises ValueError: if the group is not supported
        """
        if group not in self.groups:
            raise ValueError(f'entry point group `{group}` is not supported by this parameter.')

    def convert(
        self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None
    ) -> t.Union[EntryPoint, t.Any]:
        """Convert the string value to an entry point instance, if the value can be successfully parsed
        into an actual entry point. Will raise click.BadParameter if validation fails.

        :returns: the entry point, or the loaded entry point if `self.load` is True
        :raises click.BadParameter: if the value is invalid, its group is unsupported, or loading fails
        """
        from importlib_metadata import EntryPoint

        # If the value is already of the expected return type, simply return it. This behavior is new in `click==8.0`:
        # https://click.palletsprojects.com/en/8.0.x/parameters/#implementing-custom-types
        # NOTE(review): such a value is returned as-is, i.e. not loaded even when `self.load` is True.
        if isinstance(value, EntryPoint):
            try:
                self.validate_entry_point_group(value.group)
            except ValueError as exception:
                raise click.BadParameter(str(exception)) from exception
            return value

        value = super().convert(value, param, ctx)

        try:
            entry_point = self.get_entry_point_from_string(value)
            self.validate_entry_point_group(entry_point.group)
        except ValueError as exception:
            raise click.BadParameter(str(exception)) from exception

        if not self.load:
            return entry_point

        try:
            return entry_point.load()
        except exceptions.LoadingEntryPointError as exception:
            raise click.BadParameter(str(exception)) from exception