Source code for aiida.backends.tests.test_base_dataclasses

# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida_core #
# For further information on the license, see the LICENSE.txt file        #
# For further information please visit http://www.aiida.net               #
###########################################################################
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import operator

from six.moves import range, zip

from aiida.backends.testbase import AiidaTestCase
from aiida.common.exceptions import ModificationNotAllowed
from aiida.orm import load_node, List, Bool, Float, Int, Str, NumericType
from aiida.orm.nodes.data.bool import get_true_node, get_false_node


[docs]class TestList(AiidaTestCase):
[docs] def test_creation(self): node = List() self.assertEqual(len(node), 0) with self.assertRaises(IndexError): node[0]
[docs] def test_append(self): def do_checks(node): self.assertEqual(len(node), 1) self.assertEqual(node[0], 4) node = List() node.append(4) do_checks(node) # Try the same after storing node = List() node.append(4) node.store() do_checks(node)
[docs] def test_extend(self): lst = [1, 2, 3] def do_checks(node): self.assertEqual(len(node), len(lst)) # Do an element wise comparison for x, y in zip(lst, node): self.assertEqual(x, y) node = List() node.extend(lst) do_checks(node) # Further extend node.extend(lst) self.assertEqual(len(node), len(lst) * 2) # Do an element wise comparison for i in range(len(lst)): self.assertEqual(lst[i], node[i]) self.assertEqual(lst[i], node[i % len(lst)]) # Now try after storing node = List() node.extend(lst) node.store() do_checks(node)
[docs] def test_mutability(self): node = List() node.append(5) node.store() # Test all mutable calls are now disallowed with self.assertRaises(ModificationNotAllowed): node.append(5) with self.assertRaises(ModificationNotAllowed): node.extend([5]) with self.assertRaises(ModificationNotAllowed): node.insert(0, 2) with self.assertRaises(ModificationNotAllowed): node.remove(0) with self.assertRaises(ModificationNotAllowed): node.pop() with self.assertRaises(ModificationNotAllowed): node.sort() with self.assertRaises(ModificationNotAllowed): node.reverse()
[docs] def test_store_load(self): node = List(list=[1, 2, 3]) node.store() node_loaded = load_node(node.pk) assert node.get_list() == node_loaded.get_list()
[docs]class TestFloat(AiidaTestCase):
[docs] def setUp(self): super(TestFloat, self).setUp() self.value = Float() self.all_types = [Int, Float, Bool, Str]
[docs] def test_create(self): a = Float() # Check that initial value is zero self.assertAlmostEqual(a.value, 0.0) f = Float(6.0) self.assertAlmostEqual(f.value, 6.) self.assertAlmostEqual(f, Float(6.0)) i = Int() self.assertAlmostEqual(i.value, 0) i = Int(6) self.assertAlmostEqual(i.value, 6) self.assertAlmostEqual(f, i) b = Bool() self.assertAlmostEqual(b.value, False) b = Bool(False) self.assertAlmostEqual(b.value, False) self.assertAlmostEqual(b.value, get_false_node()) b = Bool(True) self.assertAlmostEqual(b.value, True) self.assertAlmostEqual(b.value, get_true_node()) s = Str() self.assertAlmostEqual(s.value, "") s = Str('Hello') self.assertAlmostEqual(s.value, 'Hello')
[docs] def test_load(self): for t in self.all_types: node = t() node.store() loaded = load_node(node.pk) self.assertAlmostEqual(node, loaded)
[docs] def test_add(self): a = Float(4) b = Float(5) # Check adding two db Floats res = a + b self.assertIsInstance(res, NumericType) self.assertAlmostEqual(res, 9.0) # Check adding db Float and native (both ways) res = a + 5.0 self.assertIsInstance(res, NumericType) self.assertAlmostEqual(res, 9.0) res = 5.0 + a self.assertIsInstance(res, NumericType) self.assertAlmostEqual(res, 9.0) # Inplace a = Float(4) a += b self.assertAlmostEqual(a, 9.0) a = Float(4) a += 5 self.assertAlmostEqual(a, 9.0)
[docs] def test_mul(self): a = Float(4) b = Float(5) # Check adding two db Floats res = a * b self.assertIsInstance(res, NumericType) self.assertAlmostEqual(res, 20.0) # Check adding db Float and native (both ways) res = a * 5.0 self.assertIsInstance(res, NumericType) self.assertAlmostEqual(res, 20) res = 5.0 * a self.assertIsInstance(res, NumericType) self.assertAlmostEqual(res, 20.0) # Inplace a = Float(4) a *= b self.assertAlmostEqual(a, 20) a = Float(4) a *= 5 self.assertAlmostEqual(a, 20)
[docs] def test_power(self): a = Float(4) b = Float(2) res = a ** b self.assertAlmostEqual(res.value, 16.)
[docs] def test_modulo(self): a = Float(12.0) b = Float(10.0) self.assertAlmostEqual(a % b, 2.0) self.assertIsInstance(a % b, NumericType) self.assertAlmostEqual(a % 10.0, 2.0) self.assertIsInstance(a % 10.0, NumericType) self.assertAlmostEqual(12.0 % b, 2.0) self.assertIsInstance(12.0 % b, NumericType)
[docs]class TestFloatIntMix(AiidaTestCase):
[docs] def test_operator(self): a = Float(2.2) b = Int(3) for op in [operator.add, operator.mul, operator.pow, operator.lt, operator.le, operator.gt, operator.ge, operator.iadd, operator.imul]: for x, y in [(a, b), (b, a)]: c = op(x, y) c_val = op(x.value, y.value) self.assertEqual(c._type, type(c_val)) self.assertEqual(c, op(x.value, y.value))
[docs]class TestInt(AiidaTestCase):
[docs] def test_modulo(self): a = Int(12) b = Int(10) self.assertEqual(a % b, 2) self.assertIsInstance(a % b, NumericType) self.assertEqual(a % 10, 2) self.assertIsInstance(a % 10, NumericType) self.assertEqual(12 % b, 2) self.assertIsInstance(12 % b, NumericType)