Source code for aiida.tools.graph.graph_traversers

# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file        #
# For further information please visit http://www.aiida.net               #
###########################################################################
"""Module for functions to traverse AiiDA graphs."""

from numpy import inf
from aiida.common.links import GraphTraversalRules, LinkType


def get_nodes_delete(starting_pks, get_links=False, **kwargs):
    """
    Collect every node reachable from the starting nodes under the DELETE traversal rules.

    The keyword arguments are validated against ``GraphTraversalRules.DELETE`` and the
    resulting forward/backward link lists are handed to :func:`traverse_graph`.

    :type starting_pks: list or tuple or set
    :param starting_pks: Contains the (valid) pks of the starting nodes.

    :param bool get_links:
        Pass True to also return the links between all nodes (found + initial).

    :param bool create_forward: will traverse CREATE links in the forward direction.
    :param bool call_calc_forward: will traverse CALL_CALC links in the forward direction.
    :param bool call_work_forward: will traverse CALL_WORK links in the forward direction.

    :return: dictionary with the keys ``nodes`` and ``links`` (taken from the traversal
        output) and ``rules`` (the rules that were applied after validation).
    """
    parsed_rules = validate_traversal_rules(GraphTraversalRules.DELETE, **kwargs)

    traversal = traverse_graph(
        starting_pks,
        get_links=get_links,
        links_forward=parsed_rules['forward'],
        links_backward=parsed_rules['backward'],
    )

    return {
        'nodes': traversal['nodes'],
        'links': traversal['links'],
        'rules': parsed_rules['rules_applied'],
    }
def get_nodes_export(starting_pks, get_links=False, **kwargs):
    """
    Collect every node reachable from the starting nodes under the EXPORT traversal rules.

    The keyword arguments are validated against ``GraphTraversalRules.EXPORT`` and the
    resulting forward/backward link lists are handed to :func:`traverse_graph`. The
    links and the parsed traversal rules are returned alongside the nodes.

    :type starting_pks: list or tuple or set
    :param starting_pks: Contains the (valid) pks of the starting nodes.

    :param bool get_links:
        Pass True to also return the links between all nodes (found + initial).

    :param bool input_calc_forward: will traverse INPUT_CALC links in the forward direction.
    :param bool create_backward: will traverse CREATE links in the backward direction.
    :param bool return_backward: will traverse RETURN links in the backward direction.
    :param bool input_work_forward: will traverse INPUT_WORK links in the forward direction.
    :param bool call_calc_backward: will traverse CALL_CALC links in the backward direction.
    :param bool call_work_backward: will traverse CALL_WORK links in the backward direction.

    :return: dictionary with the keys ``nodes`` and ``links`` (taken from the traversal
        output) and ``rules`` (the rules that were applied after validation).
    """
    parsed_rules = validate_traversal_rules(GraphTraversalRules.EXPORT, **kwargs)

    traversal = traverse_graph(
        starting_pks,
        get_links=get_links,
        links_forward=parsed_rules['forward'],
        links_backward=parsed_rules['backward'],
    )

    return {
        'nodes': traversal['nodes'],
        'links': traversal['links'],
        'rules': parsed_rules['rules_applied'],
    }
def validate_traversal_rules(ruleset=GraphTraversalRules.DEFAULT, **kwargs):
    """
    Check the given keywords against a ruleset template and build the parsed rule dictionary.

    Every rule of the template starts from its default value; a keyword with the same
    name overrides it, provided the rule is toggleable and the value is a boolean.
    Keywords that do not match any rule of the template are rejected.

    :type ruleset: :py:class:`aiida.common.links.GraphTraversalRules`
    :param ruleset: Ruleset template used to validate the set of rules.

    :param bool input_calc_forward: will traverse INPUT_CALC links in the forward direction.
    :param bool input_calc_backward: will traverse INPUT_CALC links in the backward direction.
    :param bool create_forward: will traverse CREATE links in the forward direction.
    :param bool create_backward: will traverse CREATE links in the backward direction.
    :param bool return_forward: will traverse RETURN links in the forward direction.
    :param bool return_backward: will traverse RETURN links in the backward direction.
    :param bool input_work_forward: will traverse INPUT_WORK links in the forward direction.
    :param bool input_work_backward: will traverse INPUT_WORK links in the backward direction.
    :param bool call_calc_forward: will traverse CALL_CALC links in the forward direction.
    :param bool call_calc_backward: will traverse CALL_CALC links in the backward direction.
    :param bool call_work_forward: will traverse CALL_WORK links in the forward direction.
    :param bool call_work_backward: will traverse CALL_WORK links in the backward direction.

    :return: dictionary with the keys ``rules_applied`` (rule name -> bool), ``forward``
        and ``backward`` (lists of the link types to follow in each direction).

    :raises TypeError: if ``ruleset`` is not a ``GraphTraversalRules`` member.
    :raises ValueError: if a keyword targets a non-toggleable rule or is not a boolean.
    :raises aiida.common.exceptions.InternalError: if a followed rule has an unknown direction.
    :raises aiida.common.exceptions.ValidationError: if unknown keywords remain.
    """
    from aiida.common import exceptions

    if not isinstance(ruleset, GraphTraversalRules):
        raise TypeError(
            'ruleset input must be of type aiida.common.links.GraphTraversalRules\ninstead, it is: {}'.format(
                type(ruleset)
            )
        )

    forward_links = []
    backward_links = []
    applied = {}

    for rule_name, rule in ruleset.value.items():
        should_follow = rule.default

        # A user-supplied keyword overrides the template default, if the rule allows it.
        if rule_name in kwargs:
            if not rule.toggleable:
                raise ValueError('input rule {} is not toggleable for ruleset {}'.format(rule_name, ruleset))

            should_follow = kwargs.pop(rule_name)

            if not isinstance(should_follow, bool):
                raise ValueError(
                    'the value of rule {} must be boolean, but it is: {}'.format(rule_name, should_follow)
                )

        if should_follow:
            direction = rule.direction
            if direction == 'forward':
                target = forward_links
            elif direction == 'backward':
                target = backward_links
            else:
                raise exceptions.InternalError(
                    'unrecognized direction `{}` for graph traversal rule'.format(direction)
                )
            target.append(rule.link_type)

        applied[rule_name] = should_follow

    # Recognised keywords were popped above, so whatever is left matches no rule.
    if kwargs:
        raise exceptions.ValidationError('unrecognized keywords: {}'.format(', '.join(kwargs.keys())))

    return {
        'rules_applied': applied,
        'forward': forward_links,
        'backward': backward_links,
    }
def _link_type_values(link_types, argument_name):
    """Return the ``.value`` of every entry in ``link_types``, raising TypeError on non-``LinkType`` entries."""
    values = []
    for linktype in link_types:
        if not isinstance(linktype, LinkType):
            raise TypeError(
                '{} should contain links, but one of them is: {}'.format(argument_name, type(linktype))
            )
        values.append(linktype.value)
    return values


def traverse_graph(starting_pks, max_iterations=None, get_links=False, links_forward=(), links_backward=()):
    """
    Return the set of all nodes that can be connected to a list of initial nodes
    through any sequence of the specified links. Optionally, also return the links
    that connect these nodes.

    :type starting_pks: list or tuple or set
    :param starting_pks: Contains the (valid) pks of the starting nodes.

    :type max_iterations: int or None
    :param max_iterations:
        The number of iterations to apply the set of rules (a value of 'None' will
        iterate until no new nodes are added). A float infinity is accepted as well
        and is treated like 'None'.

    :param bool get_links:
        Pass True to also return the links between all nodes (found + initial).

    :type links_forward: iterable of aiida.common.links.LinkType
    :param links_forward:
        List with all the links that should be traversed in the forward direction.

    :type links_backward: iterable of aiida.common.links.LinkType
    :param links_backward:
        List with all the links that should be traversed in the backward direction.

    :return: dictionary with the key ``nodes`` (set of pks found, including the starting
        ones) and the key ``links`` (the links found, or ``None`` if ``get_links`` is False).

    :raises TypeError: if ``max_iterations`` is neither an integer nor infinity, if any of
        the link lists contains something that is not a ``LinkType``, if ``starting_pks``
        is not a list, set or tuple, or if one of the pks is not an int.
    :raises aiida.common.exceptions.NotExistent: if any of the pks is not in the database.
    """
    # pylint: disable=too-many-locals,too-many-statements,too-many-branches
    if max_iterations is None or (isinstance(max_iterations, float) and max_iterations == inf):
        # Compare infinity by value, not identity: `float('inf')` and `math.inf` are
        # distinct objects from the imported `inf`. Normalise to the module's `inf`
        # so that downstream code receives the same object as before.
        max_iterations = inf
    elif not isinstance(max_iterations, int):
        raise TypeError('Max_iterations has to be an integer or infinity')

    filters_forwards = {'type': {'in': _link_type_values(links_forward, 'links_forward')}}
    filters_backwards = {'type': {'in': _link_type_values(links_backward, 'links_backward')}}

    if not isinstance(starting_pks, (list, set, tuple)):
        raise TypeError('starting_pks must be of type list, set or tuple\ninstead, it is {}'.format(type(starting_pks)))

    # Nothing to traverse: answer without touching the database.
    if not starting_pks:
        return {'nodes': set(), 'links': set() if get_links else None}

    if any(not isinstance(pk, int) for pk in starting_pks):
        raise TypeError('one of the starting_pks is not of type int:\n {}'.format(starting_pks))

    # Project imports are deferred until here so that the argument validation above
    # (and the trivial empty case) does not require loading the ORM.
    from aiida import orm
    from aiida.tools.graph.age_entities import Basket
    from aiida.tools.graph.age_rules import UpdateRule, RuleSequence, RuleSaveWalkers, RuleSetWalkers
    from aiida.common import exceptions

    operational_set = set(starting_pks)
    query_nodes = orm.QueryBuilder()
    query_nodes.append(orm.Node, project=['id'], filters={'id': {'in': operational_set}})
    existing_pks = {pk[0] for pk in query_nodes.all()}
    missing_pks = operational_set.difference(existing_pks)
    if missing_pks:
        raise exceptions.NotExistent(
            'The following pks are not in the database and must be pruned before this call: {}'.format(missing_pks)
        )

    rules = []
    basket = Basket(nodes=operational_set)

    # When max_iterations is finite, the order of traversal may affect the result
    # (it is not the same to first go backwards and then forwards as vice versa).
    # In order to make it order-independent, the result of the first operation needs
    # to be stashed and the second operation must be performed only on the nodes
    # that were already in the set at the beginning of the iteration: this way, both
    # rules are applied on the same set of nodes and the order doesn't matter.
    # The way to do this is saving and setting the walkers at the right moments, only
    # when both forwards and backwards rules are present.
    if links_forward and links_backward:
        stash = basket.get_template()
        rules += [RuleSaveWalkers(stash)]

    if links_forward:
        query_outgoing = orm.QueryBuilder()
        query_outgoing.append(orm.Node, tag='sources')
        query_outgoing.append(orm.Node, edge_filters=filters_forwards, with_incoming='sources')
        rule_outgoing = UpdateRule(query_outgoing, max_iterations=1, track_edges=get_links)
        rules += [rule_outgoing]

    if links_forward and links_backward:
        rules += [RuleSetWalkers(stash)]

    if links_backward:
        query_incoming = orm.QueryBuilder()
        query_incoming.append(orm.Node, tag='sources')
        query_incoming.append(orm.Node, edge_filters=filters_backwards, with_outgoing='sources')
        rule_incoming = UpdateRule(query_incoming, max_iterations=1, track_edges=get_links)
        rules += [rule_incoming]

    rulesequence = RuleSequence(rules, max_iterations=max_iterations)

    results = rulesequence.run(basket)

    output = {}
    output['nodes'] = results.nodes.keyset
    output['links'] = None
    if get_links:
        output['links'] = results['nodes_nodes'].keyset

    return output