# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida_core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
from abc import ABCMeta
import collections
import functools
import numbers
import operator

try:
    from collections.abc import MutableSequence
except ImportError:  # Python 2: the ABCs live directly in ``collections``
    from collections import MutableSequence

try:
    from functools import singledispatch
except ImportError:  # Python 2: use the ``singledispatch`` backport
    from singledispatch import singledispatch

from past.builtins import basestring
import numpy as np

from aiida.orm import Data
from aiida.orm.data.parameter import ParameterData
class BaseType(Data):
    """
    Store a base python type as a AiiDA node in the DB.

    Concrete subclasses must define the ``_type`` class attribute (e.g.
    ``int``, ``float``, ``str``, ``bool``). Every value assigned through the
    :py:attr:`value` property is coerced with ``_type``.
    Provide the ``.value`` property to get the actual value.
    """
    __metaclass__ = ABCMeta

    def __init__(self, *args, **kwargs):
        """
        Create the node.

        Accepts a single positional value, or a ``typevalue=(type, value)``
        keyword, or a ``dbnode`` keyword on its own; the arguments are
        normalised by :py:meth:`_create_init_args` before being handed to the
        parent constructor.

        :raises RuntimeError: if the concrete class does not define ``_type``.
        """
        try:
            getattr(self, '_type')
        except AttributeError:
            raise RuntimeError("Derived class must define the _type class member")
        super(BaseType, self).__init__(**self._create_init_args(*args, **kwargs))

    def set_typevalue(self, typevalue):
        """
        Set the type and the value of the node from a ``(type, value)`` pair.

        NOTE(review): presumably invoked by the parent constructor for the
        ``typevalue`` keyword produced by :py:meth:`_create_init_args` -- confirm.

        :param typevalue: tuple ``(type, value)``. If ``value`` is ``None``,
            the default value of the type (``type()``) is stored instead.
        """
        _type, value = typevalue
        self._type = _type
        # Test for None explicitly: relying on truthiness would replace falsy
        # but meaningful values (e.g. ``-0.0``) with ``_type()``.
        if value is not None:
            self.value = value
        else:
            self.value = _type()

    @property
    def value(self):
        """The stored python value (read from the ``value`` attribute)."""
        return self.get_attr('value')

    @value.setter
    def value(self, value):
        # Coerce through the node's type so the attribute always holds a _type.
        self._set_attr('value', self._type(value))

    def __str__(self):
        return self.value.__str__()

    def __repr__(self):
        return self.value.__repr__()

    def __eq__(self, other):
        """Compare by value; ``other`` may be another BaseType or a raw value."""
        if isinstance(other, BaseType):
            return self.value == other.value
        else:
            return self.value == other

    def __ne__(self, other):
        """Negated value comparison (Python 2 does not derive it from __eq__)."""
        if isinstance(other, BaseType):
            return self.value != other.value
        else:
            return self.value != other

    def new(self, value=None):
        """
        Return a new node of the same class and type.

        :param value: the value of the new node; ``None`` means ``_type()``.
        """
        return self.__class__(typevalue=(self._type, value))

    def _create_init_args(self, *args, **kwargs):
        """
        Normalise the constructor arguments into keyword form.

        * one positional argument ``v`` -> ``typevalue=(self._type, self._type(v))``
        * ``typevalue=(t, v)`` -> ``t`` must be ``self._type``; ``v`` is
          coerced with ``self._type`` unless it is ``None``
        * neither ``typevalue`` nor ``dbnode`` -> ``typevalue=(self._type, None)``
        * ``dbnode=...`` must be the only keyword and is passed through as is

        :return: the keyword arguments for the parent constructor.
        :raises AssertionError: on an invalid combination of arguments.
        """
        if args:
            assert not kwargs, "Cannot have positional arguments and kwargs"
            assert len(args) == 1, \
                "Simple data can only take at most one positional argument"
            kwargs['typevalue'] = (self._type, self._type(args[0]))
        elif 'dbnode' not in kwargs:
            if 'typevalue' in kwargs:
                assert kwargs['typevalue'][0] is self._type
                if kwargs['typevalue'][1] is not None:
                    kwargs['typevalue'] = \
                        (self._type, self._type(kwargs['typevalue'][1]))
            else:
                kwargs['typevalue'] = (self._type, None)
        else:
            assert len(kwargs) == 1, \
                "When specifying dbnode it can be the only kwarg"
        return kwargs
def _left_operator(func):
    """
    Decorator for binary operators in which the node is the left operand.

    The decorated ``func(l, r)`` is called with the raw value of ``self`` and,
    when ``other`` is a :py:class:`NumericType`, its raw value (any other
    ``other`` is passed unchanged). The result is converted back into an AiiDA
    node with :py:func:`to_aiida_type`.

    :param func: plain function implementing the operator on raw values.
    :return: the wrapping method, carrying ``func``'s name and docstring.
    """
    # functools.wraps keeps __name__/__doc__ of the operator (e.g. '__add__')
    # instead of exposing every decorated method as 'inner'.
    @functools.wraps(func)
    def inner(self, other):
        left = self.value
        right = other.value if isinstance(other, NumericType) else other
        return to_aiida_type(func(left, right))
    return inner
def _right_operator(func):
    """
    Decorator for reflected binary operators (``__radd__``, ``__rsub__``, ...).

    The decorated ``func(self_value, other)`` is called with the raw value of
    the node and the left operand ``other`` unchanged. The result is converted
    back into an AiiDA node with :py:func:`to_aiida_type`.

    :param func: plain function implementing the reflected operator.
    :return: the wrapping method, carrying ``func``'s name and docstring.
    """
    # functools.wraps keeps __name__/__doc__ of the decorated operator.
    @functools.wraps(func)
    def inner(self, other):
        # Python only calls a reflected operator when the left operand did not
        # handle the operation; a NumericType on the left is handled by the
        # _left_operator path, so reaching here with one is unexpected.
        assert not isinstance(other, NumericType)
        return to_aiida_type(func(self.value, other))
    return inner
class NumericType(BaseType):
    """
    Specific subclass of :py:class:`BaseType` to store numbers,
    overloading common operators (``+``, ``-``, ``*``, ``/``, ``//``, ``%``,
    ``**``, ``abs``, unary ``-`` and the ordering comparisons).

    Operators act on the underlying raw values and the result is wrapped back
    into an AiiDA node through :py:func:`to_aiida_type`; comparisons therefore
    return a node as well.
    """

    # Inside the decorated functions below, ``self`` and ``other`` are the raw
    # python values supplied by _left_operator / _right_operator, not nodes.

    @_left_operator
    def __add__(self, other):
        return self + other

    @_right_operator
    def __radd__(self, other):
        return other + self

    @_left_operator
    def __sub__(self, other):
        return self - other

    @_right_operator
    def __rsub__(self, other):
        return other - self

    @_left_operator
    def __mul__(self, other):
        return self * other

    @_right_operator
    def __rmul__(self, other):
        return other * self

    @_left_operator
    def __truediv__(self, other):
        # operator.truediv gives true division on Python 2 as well, where
        # this module does not enable ``from __future__ import division``.
        return operator.truediv(self, other)

    @_right_operator
    def __rtruediv__(self, other):
        return operator.truediv(other, self)

    @_left_operator
    def __div__(self, other):
        # Python 2 only: classic ``/`` for callers without future division.
        return self / other

    @_right_operator
    def __rdiv__(self, other):
        # Python 2 only: reflected classic ``/``.
        return other / self

    @_left_operator
    def __floordiv__(self, other):
        return self // other

    @_right_operator
    def __rfloordiv__(self, other):
        return other // self

    @_left_operator
    def __mod__(self, other):
        return self % other

    @_right_operator
    def __rmod__(self, other):
        return other % self

    @_left_operator
    def __pow__(self, power):
        return self ** power

    @_right_operator
    def __rpow__(self, other):
        return other ** self

    @_left_operator
    def __lt__(self, other):
        return self < other

    @_left_operator
    def __le__(self, other):
        return self <= other

    @_left_operator
    def __gt__(self, other):
        return self > other

    @_left_operator
    def __ge__(self, other):
        return self >= other

    def __abs__(self):
        return to_aiida_type(abs(self.value))

    def __neg__(self):
        return to_aiida_type(-self.value)

    def __float__(self):
        return float(self.value)

    def __int__(self):
        return int(self.value)
class Float(NumericType):
    """
    AiiDA node wrapping a python ``float``; assigned values are coerced with ``float()``.
    """
    # Coercion type consumed by BaseType whenever a value is set.
    _type = float
class Int(NumericType):
    """
    AiiDA node wrapping a python ``int``; assigned values are coerced with ``int()``.
    """
    # Coercion type consumed by BaseType whenever a value is set.
    _type = int
class Str(BaseType):
    """
    AiiDA node wrapping a python ``str``; assigned values are coerced with ``str()``.
    """
    # Coercion type consumed by BaseType whenever a value is set.
    _type = str
class Bool(BaseType):
    """
    AiiDA node wrapping a python ``bool``.

    Truth-testing the node (``if node: ...``) and ``int(node)`` both use the
    stored value.
    """
    _type = bool

    def __bool__(self):
        # Python 3 truth protocol: report the stored boolean.
        return self.value

    def __nonzero__(self):
        # Python 2 spelling of the truth protocol; delegate to __bool__.
        return self.__bool__()

    def __int__(self):
        # Truth-testing ``self`` goes through __bool__ / __nonzero__.
        return 1 if self else 0
class List(Data, MutableSequence):
    """
    Class to store python lists as AiiDA nodes.

    The list lives in the node attribute named by ``_LIST_KEY``. Every
    mutating method fetches the list, changes it and writes it back, unless
    :py:meth:`_using_list_reference` reports that the fetched list is a live
    reference (then the in-place change is already visible).
    """
    _LIST_KEY = 'list'

    def __getitem__(self, item):
        return self._get_list()[item]

    def __setitem__(self, key, value):
        data = self._get_list()
        data[key] = value
        self._sync_list(data)

    def __delitem__(self, key):
        data = self._get_list()
        del data[key]
        self._sync_list(data)

    def __len__(self):
        return len(self._get_list())

    def __str__(self):
        return self._get_list().__str__()

    def append(self, value):
        data = self._get_list()
        data.append(value)
        self._sync_list(data)

    def extend(self, L):
        data = self._get_list()
        data.extend(L)
        self._sync_list(data)

    def insert(self, i, value):
        data = self._get_list()
        data.insert(i, value)
        self._sync_list(data)

    def remove(self, value):
        """
        Remove the first occurrence of ``value`` (by value, like ``list.remove``).

        :raises ValueError: if ``value`` is not in the list.
        """
        data = self._get_list()
        data.remove(value)
        self._sync_list(data)

    def pop(self, index=-1):
        """
        Remove and return the item at ``index`` (the last item by default).

        :raises IndexError: if the list is empty or ``index`` is out of range.
        """
        data = self._get_list()
        item = data.pop(index)
        self._sync_list(data)
        return item

    def index(self, value):
        return self._get_list().index(value)

    def count(self, value):
        return self._get_list().count(value)

    def sort(self, cmp=None, key=None, reverse=False):
        """
        Sort the list in place.

        :param cmp: optional Python 2 style comparison function; it is
            converted with :py:func:`functools.cmp_to_key` so it also works
            on Python 3. When ``key`` is given too, ``cmp`` compares the keys
            (as in Python 2).
        :param key: optional key function.
        :param reverse: sort in descending order if True.
        """
        data = self._get_list()
        if cmp is not None:
            cmp_key = functools.cmp_to_key(cmp)
            if key is None:
                key = cmp_key
            else:
                item_key = key
                key = lambda item: cmp_key(item_key(item))
        # list.sort only accepts keyword arguments on Python 3.
        data.sort(key=key, reverse=reverse)
        self._sync_list(data)

    def reverse(self):
        data = self._get_list()
        data.reverse()
        self._sync_list(data)

    def _get_list(self):
        """
        Return the stored list, first initialising it to ``[]`` if the
        attribute lookup raises AttributeError (presumably: never set yet).
        """
        try:
            return self.get_attr(self._LIST_KEY)
        except AttributeError:
            self._set_list(list())
            return self.get_attr(self._LIST_KEY)

    def _set_list(self, list_):
        """
        Store ``list_`` as the node's list attribute.

        :raises TypeError: if ``list_`` is not a ``list``.
        """
        if not isinstance(list_, list):
            raise TypeError("Must supply list type")
        self._set_attr(self._LIST_KEY, list_)

    def _sync_list(self, list_):
        """Write ``list_`` back unless it is already the live attribute value."""
        if not self._using_list_reference():
            self._set_list(list_)

    def _using_list_reference(self):
        """
        This function tells the class if we are using a list reference. This
        means that calls to self._get_list return a reference rather than a copy
        of the underlying list and therefore self._set_list need not be called.
        This knowledge is essential to make sure this class is performant.

        Currently the implementation assumes that if the node needs to be
        stored then it is using the attributes cache which is a reference.

        :return: True if using self._get_list returns a reference to the
            underlying sequence. False otherwise.
        :rtype: bool
        """
        return self._to_be_stored
def get_true_node():
    """
    Build and return a new :py:class:`Bool` node holding ``True``.

    This is a factory rather than a module-level singleton: a singleton would
    be created at import time, possibly before any user exists in the DB
    (e.g. in the tests, or on the very first use of AiiDA), and AiiDA requires
    a user in the DB before new Nodes can be created.
    """
    return Bool(typevalue=(bool, True))
def get_false_node():
    """
    Build and return a new :py:class:`Bool` node holding ``False``.

    This is a factory rather than a module-level singleton: a singleton would
    be created at import time, possibly before any user exists in the DB
    (e.g. in the tests, or on the very first use of AiiDA), and AiiDA requires
    a user in the DB before new Nodes can be created.
    """
    return Bool(typevalue=(bool, False))
@singledispatch
def to_aiida_type(value):
    """
    Turns basic Python types into the corresponding AiiDA types.

    Dispatch is on the type of ``value``:

    * strings -> :py:class:`Str`
    * integers (:py:class:`numbers.Integral`) -> :py:class:`Int`
    * reals (:py:class:`numbers.Real`) -> :py:class:`Float`
    * ``bool`` / ``numpy.bool_`` -> :py:class:`Bool`
    * ``dict`` -> :py:class:`ParameterData`

    :raises TypeError: for any other type.
    """
    raise TypeError("Cannot convert value of type {} to AiiDA type.".format(type(value)))


# singledispatch picks an implementation from the argument's class hierarchy
# (MRO / issubclass), not from isinstance checks. On Python 3,
# ``past.builtins.basestring`` is an emulation whose subclass relationship to
# ``str`` is not guaranteed, so ``str`` is registered explicitly as well.
@to_aiida_type.register(str)
@to_aiida_type.register(basestring)
def _(value):
    return Str(value)


@to_aiida_type.register(numbers.Integral)
def _(value):
    return Int(value)


@to_aiida_type.register(numbers.Real)
def _(value):
    return Float(value)


# bool is a subclass of int; the exact registration takes precedence over
# numbers.Integral, so booleans become Bool rather than Int.
@to_aiida_type.register(bool)
@to_aiida_type.register(np.bool_)
def _(value):
    return Bool(value)


@to_aiida_type.register(dict)
def _(value):
    return ParameterData(dict=value)