Source code for aiida.tools.visualization.graphviz

# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida_core #
# For further information on the license, see the LICENSE.txt file        #
# For further information please visit http://www.aiida.net               #
###########################################################################
"""
Draw the provenance graphs
"""
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
import io
import os
import subprocess
import tempfile


def draw_graph(origin_node,
               ancestor_depth=None,
               descendant_depth=None,
               image_format='dot',
               include_calculation_inputs=False,
               include_calculation_outputs=False):
    """
    Draw the provenance graph around a node and render it with graphviz.

    The algorithm starts from the original node and goes both input-ward and
    output-ward via a breadth-first algorithm. The graph is first written to a
    temporary ``.dot`` file, which is then converted by the ``dot`` executable
    into ``<origin_node.pk>.<image_format>`` (a path relative to the current
    working directory). The temporary file is always removed afterwards.

    :param origin_node: An Aiida node, the starting point for drawing the graph
    :param int ancestor_depth: The maximum depth of the ancestors drawn. If left to None,
        we recurse until the graph is fully explored
    :param int descendant_depth: The maximum depth of the descendants drawn. If left to None,
        we recurse until the graph is fully explored
    :param str image_format: The output plot format, by default dot
    :param bool include_calculation_inputs: If True, while walking towards the descendants
        also draw all the inputs of every process node encountered (these inputs are drawn
        but not used as starting points for further traversal)
    :param bool include_calculation_outputs: If True, while walking towards the ancestors
        also draw all the outputs of every process node encountered (these outputs are drawn
        but not used as starting points for further traversal)

    :returns: a tuple ``(exit_code, output_file_name)`` with the exit code of the
        ``subprocess.call()`` that ran ``dot`` and the file name of the final output

    .. note:: If an invalid format is provided graphviz prints a helpful message, so this
        doesn't need to be implemented here.
    """
    # pylint: disable=too-many-locals,too-many-statements,too-many-branches
    from aiida.orm import ProcessNode
    from aiida.orm import Code
    from aiida.orm import Node
    from aiida.common.links import LinkType
    from aiida.orm.querybuilder import QueryBuilder

    def draw_node_settings(node, **kwargs):
        """
        Returns a string with all infos needed in a .dot file to define a node of a graph.

        :param node: the node to describe
        :param kwargs: Additional key-value pairs to be added to the returned string
        :return: a string
        """
        if isinstance(node, ProcessNode):
            shape = "shape=polygon,sides=4"
        elif isinstance(node, Code):
            shape = "shape=diamond"
        else:
            shape = "shape=ellipse"

        if kwargs:
            additional_params = ",{}".format(",".join('{}="{}"'.format(k, v) for k, v in kwargs.items()))
        else:
            additional_params = ""

        # Show the user-defined label when there is one, the description otherwise.
        if node.label:
            label_string = "\\n'{}'".format(node.label)
            additional_string = ""
        else:
            additional_string = "\\n {}".format(node.get_description())
            label_string = ""

        labelstring = 'label="{} ({}){}{}"'.format(node.__class__.__name__, node.pk, label_string,
                                                   additional_string)
        return "N{} [{},{}{}];".format(node.pk, shape, labelstring, additional_params)

    def draw_link_settings(inp_id, out_id, link_label, link_type):
        """Return the .dot statement for a link, styled according to its type."""
        if link_type in (LinkType.CREATE.value, LinkType.INPUT_CALC.value, LinkType.INPUT_WORK.value):
            style = 'solid'  # Solid lines and black colors
            color = "0.0 0.0 0.0"  # for CREATE and INPUT (The provenance graph)
        elif link_type == LinkType.RETURN.value:
            style = 'dotted'  # Dotted lines of
            color = "0.0 0.0 0.0"  # black color for Returns
        elif link_type in (LinkType.CALL_CALC.value, LinkType.CALL_WORK.value):
            # A membership test is required here: ``a == (b or c)`` only ever compares
            # against one of the two operands, so one call type would never match.
            style = 'bold'  # Bold lines and
            color = "0.0 1.0 1.0"  # Bright red for calls
        else:
            style = 'solid'  # Solid and
            color = "0.0 0.0 0.5"  # grey lines for unspecified links!
        return ' {} -> {} [label="{}", color="{}", style="{}"];'.format("N{}".format(inp_id), "N{}".format(out_id),
                                                                        link_label, color, style)

    def iter_incoming(node):
        """Iterate over (input node, link id, link label, link type) for all links entering ``node``."""
        query = QueryBuilder()
        query.append(Node, filters={'id': node.pk}, tag='n')
        query.append(Node, with_outgoing='n', edge_project=('id', 'label', 'type'), project='*', tag='inp')
        return query.iterall()

    def iter_outgoing(node):
        """Iterate over (output node, link id, link label, link type) for all links leaving ``node``."""
        query = QueryBuilder()
        query.append(Node, filters={'id': node.pk}, tag='n')
        query.append(Node, with_incoming='n', edge_project=('id', 'label', 'type'), project='*', tag='out')
        return query.iterall()

    # Breadth-first search of all ancestors and descendant nodes of a given node
    links = {}  # Accumulate links here, keyed by link id
    # Accumulate node specs here, keyed by pk; the origin node is highlighted
    nodes = {origin_node.pk: draw_node_settings(origin_node, style='filled', color='lightblue')}
    # Additional nodes (the ones added when either include_calculation_inputs or
    # include_calculation_outputs is set to True). They are kept in a different dictionary
    # because `nodes` is the one used for the recursion, whereas these should not be
    # used for the recursion.
    additional_nodes = {}

    # Go through the graph input-ward (i.e. look at inputs)
    last_nodes = [origin_node]  # The nodes whose links have not been scanned yet
    depth = 0
    while last_nodes:
        # The depth is increased at every new level of the search
        depth += 1
        # Stop here if the maximum ancestor depth is set and exceeded
        if ancestor_depth is not None and depth > ancestor_depth:
            break

        new_nodes = []
        for node in last_nodes:
            for inp, link_id, link_label, link_type in iter_incoming(node):
                # No check on link_id: the entry is simply (over)written
                links[link_id] = draw_link_settings(inp.pk, node.pk, link_label, link_type)
                # The same node may be referred to multiple times: draw and visit it only once
                if inp.pk not in nodes:
                    nodes[inp.pk] = draw_node_settings(inp)
                    new_nodes.append(inp)

            # Optionally include all the outputs of a calculation into the drawing
            if include_calculation_outputs and isinstance(node, ProcessNode):
                for out, link_id, link_label, link_type in iter_outgoing(node):
                    # This link might have been drawn already, because the output
                    # is maybe already drawn
                    if link_id not in links:
                        links[link_id] = draw_link_settings(node.pk, out.pk, link_label, link_type)
                    if out.pk not in nodes and out.pk not in additional_nodes:
                        additional_nodes[out.pk] = draw_node_settings(out)

        last_nodes = new_nodes

    # Go through the graph output-ward (i.e. look at outputs)
    last_nodes = [origin_node]
    depth = 0
    while last_nodes:
        depth += 1
        # Also here, stop if the maximum descendant depth is set and exceeded
        if descendant_depth is not None and depth > descendant_depth:
            break

        new_nodes = []
        for node in last_nodes:
            for out, link_id, link_label, link_type in iter_outgoing(node):
                # Draw the link
                links[link_id] = draw_link_settings(node.pk, out.pk, link_label, link_type)
                if out.pk not in nodes:
                    nodes[out.pk] = draw_node_settings(out)
                    new_nodes.append(out)

            # Optionally include all the inputs of a calculation into the drawing
            if include_calculation_inputs and isinstance(node, ProcessNode):
                for inp, link_id, link_label, link_type in iter_incoming(node):
                    if link_id not in links:
                        links[link_id] = draw_link_settings(inp.pk, node.pk, link_label, link_type)
                    if inp.pk not in nodes and inp.pk not in additional_nodes:
                        additional_nodes[inp.pk] = draw_node_settings(inp)

        last_nodes = new_nodes

    output_file_name = "{0}.{1}".format(origin_node.pk, image_format)

    # mkstemp returns an *open* OS-level file descriptor together with the path.
    # The descriptor is handed to io.open so that it gets closed with the handle
    # instead of being leaked.
    file_descriptor, fname = tempfile.mkstemp(suffix='.dot')
    try:
        # Writing the graph to the temporary file
        with io.open(file_descriptor, 'w', encoding='utf8') as fhandle:
            fhandle.write(u"digraph G {\n")
            for link_spec in links.values():
                fhandle.write(u' {}\n'.format(link_spec))
            for node_spec in nodes.values():
                fhandle.write(u" {}\n".format(node_spec))
            for node_spec in additional_nodes.values():
                fhandle.write(u" {}\n".format(node_spec))
            fhandle.write(u"}\n")

        # Try and convert the .dot file using the `dot` utility from graphviz
        try:
            exit_code = subprocess.call(['dot', '-T', image_format, fname, '-o', output_file_name])
        except OSError:
            from aiida.cmdline.utils import echo
            echo.echo_critical('Operating system error - perhaps Graphviz is not installed?')
            # NOTE(review): echo_critical presumably terminates the process. Should it
            # return, re-raise so that we never reach the return statement with
            # `exit_code` unbound.
            raise
    finally:
        # Cleaning up by removing the temporary file, also when something went wrong
        os.remove(fname)

    return exit_code, output_file_name