Source code for aiida.orm.utils.links

# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file        #
# For further information please visit http://www.aiida.net               #
###########################################################################
"""Utilities for dealing with links between nodes."""
from collections import namedtuple, OrderedDict
from collections.abc import Mapping

from aiida.common import exceptions
from aiida.common.lang import type_check

__all__ = ('LinkPair', 'LinkTriple', 'LinkManager', 'validate_link')

# Record of a link described only by its type and label.
LinkPair = namedtuple('LinkPair', ('link_type', 'link_label'))
# Record of a link together with the node on the other end of it.
LinkTriple = namedtuple('LinkTriple', ('node', 'link_type', 'link_label'))
# Record of a link identified by the ids of both endpoints plus its type and label.
LinkQuadruple = namedtuple('LinkQuadruple', ('source_id', 'target_id', 'link_type', 'link_label'))


def link_triple_exists(source, target, link_type, link_label):
    """Check whether ``source`` is linked to ``target`` with the given link type and label.

    The incoming-link cache of ``target`` is consulted first, so links towards an unstored target are detected as
    well. The database is only queried when both nodes are stored.

    :param source: node from which the link is outgoing
    :param target: node to which the link is incoming
    :param link_type: the link type; its ``value`` attribute is used as the edge filter in the database query
    :param link_label: the link label
    :return: boolean, True if the link triple exists, False otherwise
    """
    from aiida.orm import Node, QueryBuilder

    # A matching entry in the target's incoming-link cache settles the question without touching the database.
    incoming_cache = target._incoming_cache  # pylint: disable=protected-access
    if incoming_cache and LinkTriple(source, link_type, link_label) in incoming_cache:
        return True

    # An unstored node (no pk) cannot take part in a link that lives in the database.
    if source.pk is None or target.pk is None:
        return False

    # Both nodes are stored: look the link up in the database. One hit is enough, hence the limit.
    query = QueryBuilder()
    query.append(Node, filters={'id': source.id}, project=['id'])
    query.append(Node, filters={'id': target.id}, edge_filters={'type': link_type.value, 'label': link_label})
    query.limit(1)

    return bool(query.count())





class LinkManager:
    """Class to convert a list of LinkTriple tuples into an iterator.

    It defines convenience methods to retrieve certain subsets of LinkTriple while checking for consistency.
    For example::

        LinkManager.one(): returns the only entry in the list or it raises an exception
        LinkManager.first(): returns the first entry from the list
        LinkManager.all(): returns all entries from list

    The methods `all_nodes` and `all_link_labels` are syntactic sugar wrappers around `all` to get a list of only the
    incoming nodes or link labels, respectively.
    """

    def __init__(self, link_triples):
        """Initialise the collection.

        :param link_triples: list of LinkTriple instances. It is stored as-is, and several methods index it and call
            ``len`` on it, so it should be a sequence rather than a one-shot iterator.
        """
        self.link_triples = link_triples

    def __iter__(self):
        """Return an iterator of LinkTriple instances.

        :return: iterator of LinkTriple instances
        """
        return iter(self.link_triples)

    def __next__(self):
        """Return a generator over all LinkTriple instances in the collection.

        NOTE(review): because this method contains ``yield``, calling it returns a generator over *all* entries rather
        than a single LinkTriple, so ``next(manager)`` produces a generator object. The behaviour is kept unchanged for
        backwards compatibility. Use ``iter(manager)`` for a real iterator.

        :return: generator of LinkTriple instances
        """
        for link_triple in self.link_triples:
            yield link_triple

    def __bool__(self):
        """Return True if the collection contains at least one entry."""
        return bool(len(self.link_triples))

    def next(self):
        """Return a generator over all LinkTriple instances (see ``__next__`` for why this is not a single entry).

        :return: generator of LinkTriple instances
        """
        return self.__next__()

    def one(self):
        """Return a single entry from the iterator.

        If the iterator contains no or more than one entry, an exception will be raised.

        :return: LinkTriple instance
        :raises ValueError: if the iterator contains anything but one entry
        """
        if self.link_triples:
            if len(self.link_triples) > 1:
                raise ValueError('more than one entry found')
            return self.link_triples[0]

        raise ValueError('no entries found')

    def first(self):
        """Return the first entry from the iterator.

        :return: LinkTriple instance or None if no entries were matched
        """
        if self.link_triples:
            return self.link_triples[0]

        return None

    def all(self):
        """Return all entries from the list.

        :return: list of LinkTriple instances
        """
        return self.link_triples

    def all_nodes(self):
        """Return a list of all nodes.

        :return: list of nodes
        """
        return [entry.node for entry in self.all()]

    def all_link_labels(self):
        """Return a list of all link labels.

        :return: list of link labels
        """
        return [entry.link_label for entry in self.all()]

    def get_node_by_label(self, label):
        """Return the node from list for given label.

        :param label: the link label to look up
        :return: node that corresponds to the given label
        :raises aiida.common.NotExistent: if the label is not present among the link_triples
        :raises aiida.common.MultipleObjectsError: if more than one entry carries the given label
        """
        matching_entry = None
        for entry in self.link_triples:
            if entry.link_label == label:
                if matching_entry is None:
                    matching_entry = entry.node
                else:
                    raise exceptions.MultipleObjectsError(f'more than one neighbor with the label {label} found')

        if matching_entry is None:
            raise exceptions.NotExistent(f'no neighbor with the label {label} found')

        return matching_entry

    def nested(self, sort=True):
        """Construct (nested) dictionary of matched nodes that mirrors the original nesting of link namespaces.

        Process input and output namespaces can be nested. However, the link labels that represent them in the
        database have a flat hierarchy, so the link labels are flattened representations of the nested namespaces.
        This function reconstructs the original node nesting based on the flattened links.

        :param sort: if True, the top level is returned as an ``OrderedDict`` with namespaces before nodes, each group
            ordered by label. Nested namespaces are not reordered.
        :return: dictionary of nested namespaces
        :raises KeyError: if there are duplicate link labels in a namespace, or if a label is used both for a node and
            for a namespace
        """
        from aiida.engine.processes.ports import PORT_NAMESPACE_SEPARATOR

        nested = {}

        for entry in self.link_triples:

            current_namespace = nested
            breadcrumbs = entry.link_label.split(PORT_NAMESPACE_SEPARATOR)

            # The last element is the "leaf" port name, the preceding elements are nested port namespaces
            port_name = breadcrumbs[-1]
            port_namespaces = breadcrumbs[:-1]

            # Get the nested namespace, creating intermediate namespaces where needed
            for depth, subspace in enumerate(port_namespaces):
                current_namespace = current_namespace.setdefault(subspace, {})
                # If an earlier link already used ``subspace`` as a leaf label, ``setdefault`` returned that node
                # instead of a namespace dictionary. Descending into it would operate on the node object, so fail
                # explicitly instead.
                if not isinstance(current_namespace, dict):
                    parent = '.'.join(port_namespaces[:depth])
                    raise KeyError(f"label '{subspace}' in namespace '{parent}' is used both for a node and a namespace")

            # Insert the node at the given port name
            if port_name in current_namespace:
                raise KeyError(f"duplicate label '{port_name}' in namespace '{'.'.join(port_namespaces)}'")

            current_namespace[port_name] = entry.node

        if sort:
            # Namespaces (mappings) sort first because ``not isinstance(...)`` is False for them. Labels are unique
            # dictionary keys, so ordering by label never needs to compare the values themselves.
            return OrderedDict(sorted(nested.items(), key=lambda item: (not isinstance(item[1], Mapping), item[0])))

        return nested