# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Entities for the AiiDA Graph Explorer utility"""
from abc import ABCMeta, abstractmethod
from collections import namedtuple
from aiida import orm
from aiida.orm.utils.links import LinkQuadruple
# AiiDA ORM classes accepted as the content of an AiidaEntitySet and as the
# endpoint classes of a DirectedEdgeSet.
VALID_ENTITY_CLASSES = (
    orm.Node,
    orm.Group,
)

# Record type for a single node <-> group membership connection.
GroupNodeEdge = namedtuple('GroupNodeEdge', 'node_id group_id')
class AbstractSetContainer(metaclass=ABCMeta):
    """Abstract Class

    This class defines a set that can be specialized to contain either
    nodes (either AiiDA nodes or AiiDA groups) or edges (AiiDA links,
    or records of the connections between groups and nodes).

    Instances of this class reference a subset of entities in a database
    via a unique identifier (the keys stored in :py:attr:`keyset`).

    There are also a few operators defined, for simplicity, to do
    set-additions (unions) and deletions (differences).

    The underlying Python-class is **set**, which means that adding
    an instance to an AiidaEntitySet that is already contained by it
    will not create a duplicate.
    """

    def __init__(self):
        """Initialization method"""
        super().__init__()
        self._keyset = None
        self._additional_identifiers = ()

    @abstractmethod
    def _check_self_and_other(self, other):
        """Utility function

        When called, will check whether self and other instance are compatible.
        (i.e. if they contain the same AiiDA class, same identifiers, length of
        tuples, etc)

        :type other: :py:class:`aiida.tools.graph.age_entities.AbstractSetContainer`
        :param other: the other entity to check for compatibility.
        """

    def _check_input_for_set(self, input_for_set):
        """Validate one entity passed to :py:meth:`set_entities` / :py:meth:`add_entities`.

        Subclasses are expected to override this hook: it must check that
        ``input_for_set`` is compatible with the container and return the key
        that will be stored in :py:attr:`keyset`.
        It is intentionally not declared abstract so that subclasses which do not
        override it can still be instantiated; only adding entities will fail.

        :param input_for_set: an entity (or its identifier) to be stored.
        :return: the key to store in the keyset.
        :raises NotImplementedError: if the subclass does not override this method.
        """
        raise NotImplementedError(
            f'{type(self).__name__} does not implement `_check_input_for_set`, '
            'so entities cannot be added through `set_entities`/`add_entities`'
        )

    @abstractmethod
    def get_template(self):
        """Create new instance with the same defining attributes."""

    @property
    def keyset(self):
        """Set containing the keys of the entities"""
        return self._keyset

    @keyset.setter
    def keyset(self, inpset):
        """Setter for the keyset

        Use with care! There is no way to check if the keys are consistent ids here.
        Checks should be performed upstream in the code, previous to calling this setter.

        :type inpset: set or None
        :param inpset: input set of identifiers that will become the new set contained
        :raises ValueError: if ``inpset`` is neither a ``set`` nor ``None``.
        """
        if not (isinstance(inpset, set) or inpset is None):
            raise ValueError('keyset must be assigned a set or None')
        self._keyset = inpset

    @property
    def additional_identifiers(self):
        """Additional identifiers for the entities"""
        return self._additional_identifiers

    def __add__(self, other):
        """Addition (return = self + other): defined as the set union"""
        self._check_self_and_other(other)
        new = self.get_template()
        new.keyset = self.keyset.union(other.keyset)
        return new

    def __iadd__(self, other):
        """Addition inplace (self += other)"""
        self._check_self_and_other(other)
        self.keyset = self.keyset.union(other.keyset)
        return self

    def __sub__(self, other):
        """Subtraction (return = self - other): defined as the set-difference"""
        self._check_self_and_other(other)
        new = self.get_template()
        new.keyset = self.keyset.difference(other.keyset)
        return new

    def __isub__(self, other):
        """Subtraction inplace (self -= other)"""
        self._check_self_and_other(other)
        self.keyset = self.keyset.difference(other.keyset)
        return self

    def __len__(self):
        return len(self.keyset)

    def __repr__(self):
        return f"{{{','.join(map(str, self.keyset))}}}"

    def __eq__(self, other):
        # Defer to Python's default handling for unrelated types instead of
        # crashing on a missing ``keyset`` attribute.
        if not isinstance(other, AbstractSetContainer):
            return NotImplemented
        return self.keyset == other.keyset

    def __ne__(self, other):
        return not self == other

    def set_entities(self, new_entitites):
        """
        Replaces contained set with the new entities.

        :param new_entitites: entities which will replace the ones contained
            by the EntitySet. Must be an AiiDA instance (Node or Group) or
            an appropriate identifier (ID).
        """
        self.keyset = set(map(self._check_input_for_set, new_entitites))

    def add_entities(self, new_entitites):
        """
        Add new entities to the existing set of self.

        :param new_entitites: an iterable of new entities to add.
        """
        # set.union accepts any iterable, no need to materialize a list first.
        self.keyset = self.keyset.union(map(self._check_input_for_set, new_entitites))

    def copy(self):
        """Create new instance with the same defining attributes and the same keyset."""
        new = self.get_template()
        new.keyset = self.keyset.copy()
        return new

    def empty(self):
        """Resets the contained set to be an empty set"""
        self.keyset = set()
class AiidaEntitySet(AbstractSetContainer):
    """Extension of AbstractSetContainer

    This class is used to store `graph nodes` (AiiDA nodes or AiiDA groups),
    referenced through their identifier (currently always ``id``).
    """

    def __init__(self, aiida_cls):
        """Initialization method

        :param aiida_cls: a valid AiiDA ORM class (Node or Group supported).
        :raises TypeError: if ``aiida_cls`` is not one of ``VALID_ENTITY_CLASSES``.
        """
        super().__init__()
        if aiida_cls not in VALID_ENTITY_CLASSES:
            raise TypeError(f'aiida_cls has to be among:{VALID_ENTITY_CLASSES}')
        self._aiida_cls = aiida_cls
        self.keyset = set()
        self._identifier = 'id'
        self._identifier_type = int

    def _check_self_and_other(self, other):
        """Check that ``other`` is an AiidaEntitySet compatible with ``self``.

        :raises TypeError: if ``other`` is not an AiidaEntitySet, or its AiiDA class
            or identifier type differ from those of ``self``.
        :raises ValueError: if the identifier name differs.
        :return: True when the two sets are compatible.
        """
        if not isinstance(other, AiidaEntitySet):
            raise TypeError('Other class is not an instance of AiidaEntitySet')
        if self.aiida_cls != other.aiida_cls:
            raise TypeError('The two instances do not have the same aiida type!')
        if self.identifier != other.identifier:
            raise ValueError('The two instances do not have the same identifier!')
        if self._identifier_type != other.identifier_type:
            raise TypeError('The two instances do not have the same identifier type!')
        return True

    def _check_input_for_set(self, input_for_set):
        """Convert one entity into the key stored in the keyset.

        :param input_for_set: either an instance of :py:attr:`aiida_cls` (its
            :py:attr:`identifier` attribute is read) or an identifier that is already
            of :py:attr:`identifier_type`.
        :return: the identifier to store.
        :raises ValueError: if the input is neither an instance of the class nor an identifier.
        """
        if isinstance(input_for_set, self._aiida_cls):
            return getattr(input_for_set, self._identifier)
        if isinstance(input_for_set, self._identifier_type):
            return input_for_set
        raise ValueError(
            f'{input_for_set} is not a valid input\n'
            f'You can either pass an instance of {self._aiida_cls} '
            f'or an identifier of type {self._identifier_type}'
        )

    def get_template(self):
        """Create a new, empty AiidaEntitySet for the same AiiDA class."""
        return AiidaEntitySet(aiida_cls=self.aiida_cls)

    @property
    def identifier(self):
        """Identifier used for the nodes or groups (currently always id)"""
        return self._identifier

    @property
    def identifier_type(self):
        """Type of identifier for the node or group (for id is int)"""
        return self._identifier_type

    @property
    def aiida_cls(self):
        """Class of nodes contained in the entity set (node or group)"""
        return self._aiida_cls
class DirectedEdgeSet(AbstractSetContainer):
    """Extension of AbstractSetContainer

    This class is used to store `graph edges`: either node-node links
    (stored as ``LinkQuadruple``) or node-group connections (stored as
    ``GroupNodeEdge``).
    """

    def __init__(self, aiida_cls_to, aiida_cls_from):
        """Initialization method

        The classes that the link connects must be provided.

        :param aiida_cls_to: a valid AiiDA ORM class (Node or Group supported).
        :param aiida_cls_from: a valid AiiDA ORM class (Node supported).
        :raises TypeError: if either class is not valid or the combination is not supported.
        """
        super().__init__()
        for aiida_cls in (aiida_cls_to, aiida_cls_from):
            if aiida_cls not in VALID_ENTITY_CLASSES:
                raise TypeError(f'aiida_cls has to be among:{VALID_ENTITY_CLASSES}')
        self._aiida_cls_to = aiida_cls_to
        self._aiida_cls_from = aiida_cls_from
        self.keyset = set()

        # The identifiers for the edge are hardcoded for the supported combinations.
        if aiida_cls_from is orm.Node and aiida_cls_to is orm.Node:
            # Node-Node connection: store information on the links.
            self._edge_identifiers = (('edge', 'input_id'), ('edge', 'output_id'), ('edge', 'type'),
                                      ('edge', 'label'))
            self._edge_namedtuple = LinkQuadruple
        elif aiida_cls_from is orm.Node and aiida_cls_to is orm.Group:
            # Node-Group connection: store the node id and the group id.
            self._edge_identifiers = (('nodes', 'id'), ('groups', 'id'))
            self._edge_namedtuple = GroupNodeEdge
        else:
            raise TypeError(f'Unexpected types aiida_cls_from={aiida_cls_from} and aiida_cls_to={aiida_cls_to}')

    def _check_self_and_other(self, other):
        """Check that ``other`` is a DirectedEdgeSet compatible with ``self``.

        :raises TypeError: if ``other`` is not a DirectedEdgeSet or connects different classes.
        :raises ValueError: if the edge identifiers (namedtuple types) differ.
        :return: True when the two sets are compatible.
        """
        if not isinstance(other, DirectedEdgeSet):
            raise TypeError('Other class is not an instance of DirectedEdgeSet')
        if self.aiida_cls_to != other.aiida_cls_to:
            raise TypeError('The two instances do not have the same aiida type!')
        if self.aiida_cls_from != other.aiida_cls_from:
            raise TypeError('The two instances do not have the same aiida type!')
        if self.edge_namedtuple != other.edge_namedtuple:
            raise ValueError('The two instances do not have the same identifiers!')
        return True

    def _check_input_for_set(self, input_for_set):
        """Validate one edge before it is stored in the keyset.

        :param input_for_set: a tuple (for example an instance of :py:attr:`edge_namedtuple`)
            holding one value per entry of :py:attr:`edge_identifiers`.
        :return: the tuple itself, unchanged.
        :raises TypeError: if the input is not a tuple.
        :raises ValueError: if the tuple length does not match the number of edge identifiers.
        """
        if not isinstance(input_for_set, tuple):
            raise TypeError(f'value for `input_for_set` {input_for_set} is not a tuple')
        if len(input_for_set) != len(self._edge_identifiers):
            raise ValueError(
                f'tuple passed has len = {len(input_for_set)}, '
                f'but there are {len(self._edge_identifiers)} identifiers'
            )
        return input_for_set

    def get_template(self):
        """Create a new, empty DirectedEdgeSet connecting the same classes."""
        return DirectedEdgeSet(aiida_cls_to=self.aiida_cls_to, aiida_cls_from=self.aiida_cls_from)

    @property
    def aiida_cls_to(self):
        """The class of nodes which the edge points to"""
        return self._aiida_cls_to

    @property
    def aiida_cls_from(self):
        """The class of nodes which the edge points from"""
        return self._aiida_cls_from

    @property
    def edge_namedtuple(self):
        """The namedtuple type used for the edges' identifiers"""
        return self._edge_namedtuple

    @property
    def edge_identifiers(self):
        """The identifiers for the edges"""
        return self._edge_identifiers
class Basket:
    """Container for several instances of
    :py:class:`aiida.tools.graph.age_entities.AiidaEntitySet` and
    :py:class:`aiida.tools.graph.age_entities.DirectedEdgeSet`.

    In the current implementation, it contains one EntitySet for Nodes and one for Groups,
    and one EdgeSet for Node-Node edges (links) and one for Group-Node connections.
    """

    def __init__(self, nodes=None, groups=None, nodes_nodes=None, groups_nodes=None):
        """Initialization method

        During initialization of the basket, both the sets of nodes and the set of
        groups can be provided as one of the following: an AiidaEntitySet with the
        respective type (node or group) or a list/set/tuple with the ids of the
        respective node or group. The edge sets can be provided as a DirectedEdgeSet
        with the matching classes, or omitted.

        :param nodes: AiiDA nodes provided in an acceptable way.
        :param groups: AiiDA groups provided in an acceptable way.
        :param nodes_nodes: DirectedEdgeSet of Node -> Node edges (links), or None.
        :param groups_nodes: DirectedEdgeSet of Node -> Group connections, or None.
        :raises TypeError: if a provided set holds the wrong AiiDA class(es).
        :raises ValueError: if ``nodes``/``groups`` is of an unsupported type.
        """

        def get_check_set_entity_set(input_object, keyword, aiida_class):
            # Normalize ``input_object`` into an AiidaEntitySet of ``aiida_class``.
            if input_object is None:
                return AiidaEntitySet(aiida_class)
            if isinstance(input_object, (list, tuple, set)):
                output_set = AiidaEntitySet(aiida_class)
                output_set.set_entities(input_object)
                return output_set
            if isinstance(input_object, AiidaEntitySet):
                if input_object.aiida_cls is aiida_class:
                    return input_object
                raise TypeError(f'{keyword} has to have {aiida_class} as aiida_cls')
            raise ValueError(
                'Input object is of type {}.\n'
                'Instead, it should be either None or one of:\n'
                ' - {}\n - {}\n - {}\n - {}\n'.format(type(input_object), AiidaEntitySet, list, tuple, set)
            )

        def get_check_set_directed_edge_set(var, keyword, cls_from, cls_to):
            # Normalize ``var`` into a DirectedEdgeSet connecting ``cls_from`` -> ``cls_to``.
            if var is None:
                return DirectedEdgeSet(aiida_cls_to=cls_to, aiida_cls_from=cls_from)
            if not isinstance(var, DirectedEdgeSet):
                raise TypeError(f'{keyword} has to be an instance of DirectedEdgeSet')
            if var.aiida_cls_from is not cls_from:
                raise TypeError(f'{keyword} has to have {cls_from} as aiida_cls_from')
            if var.aiida_cls_to is not cls_to:
                raise TypeError(f'{keyword} has to have {cls_to} as aiida_cls_to')
            return var

        nodes = get_check_set_entity_set(nodes, 'nodes', orm.Node)
        groups = get_check_set_entity_set(groups, 'groups', orm.Group)
        nodes_nodes = get_check_set_directed_edge_set(nodes_nodes, 'nodes-nodes', orm.Node, orm.Node)
        groups_nodes = get_check_set_directed_edge_set(groups_nodes, 'groups-nodes', orm.Node, orm.Group)
        self._dict = dict(nodes=nodes, groups=groups, nodes_nodes=nodes_nodes, groups_nodes=groups_nodes)

    @property
    def sets(self):
        """
        All sets in the basket returned as a tuple ordered alphabetically by key.

        The order is: 'groups', 'groups_nodes', 'nodes', 'nodes_nodes'.
        """
        return tuple(self._dict[key] for key in sorted(self._dict))

    @property
    def dict(self):
        """
        All sets in the basket returned as a dictionary.

        This includes the keys 'nodes', 'groups', 'nodes_nodes' and 'groups_nodes'.
        """
        return self._dict

    @property
    def nodes(self):
        """Set of nodes stored in the basket"""
        return self._dict['nodes']

    @property
    def groups(self):
        """Set of groups stored in the basket"""
        return self._dict['groups']

    def __getitem__(self, key):
        return self._dict[key]

    def __setitem__(self, key, val):
        self._dict[key] = val

    def __add__(self, other):
        """Return a new basket whose subsets are the unions of the respective subsets."""
        return Basket(**{key: value + other[key] for key, value in self._dict.items()})

    def __iadd__(self, other):
        """In-place union of every subset with the matching subset of ``other``."""
        for key in self._dict:
            self[key] += other[key]
        return self

    def __sub__(self, other):
        """Return a new basket whose subsets are the differences of the respective subsets."""
        return Basket(**{key: value - other[key] for key, value in self._dict.items()})

    def __isub__(self, other):
        """In-place difference of every subset with the matching subset of ``other``."""
        # Iterate over our own keys, consistently with ``__iadd__`` and ``__sub__``.
        for key in self._dict:
            self[key] -= other[key]
        return self

    def __len__(self):
        """Total number of entries over all subsets."""
        return sum(len(set_) for set_ in self.sets)

    def __eq__(self, other):
        if not isinstance(other, Basket):
            return NotImplemented
        return all(self[key] == other[key] for key in self._dict)

    def __ne__(self, other):
        return not self == other

    def __repr__(self):
        return ''.join(f' {key}: {val}\n' for key, val in self._dict.items())

    def empty(self):
        """Empty every subset from its content"""
        for set_ in self._dict.values():
            set_.empty()

    def get_template(self):
        """Create new basket with the same defining attributes for its internal containers."""
        return Basket(**{key: val.get_template() for key, val in self._dict.items()})

    def copy(self):
        """Create new instance with the same defining attributes and content."""
        return Basket(**{key: val.copy() for key, val in self._dict.items()})