Source code for aiida.backends.tests.tools.importexport.orm.test_extras

# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file        #
# For further information please visit http://www.aiida.net               #
###########################################################################
"""Extras tests for the export and import routines"""
# pylint: disable=attribute-defined-outside-init
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
from __future__ import with_statement

import os
import shutil
import tempfile

from aiida import orm
from aiida.backends.testbase import AiidaTestCase
from aiida.tools.importexport import import_data, export


[docs]class TestExtras(AiidaTestCase): """Test ex-/import cases related to Extras"""
[docs] @classmethod def setUpClass(cls, *args, **kwargs): """Only run to prepare an export file""" super(TestExtras, cls).setUpClass() data = orm.Data() data.label = 'my_test_data_node' data.store() data.set_extra_many({'b': 2, 'c': 3}) cls.tmp_folder = tempfile.mkdtemp() cls.export_file = os.path.join(cls.tmp_folder, 'export.aiida') export([data], outfile=cls.export_file, silent=True)
[docs] @classmethod def tearDownClass(cls, *args, **kwargs): """Remove tmp_folder""" super(TestExtras, cls).tearDownClass() shutil.rmtree(cls.tmp_folder, ignore_errors=True)
[docs] def setUp(self): """This function runs before every test execution""" self.clean_db() self.insert_data()
[docs] def import_extras(self, mode_new='import'): """Import an aiida database""" import_data(self.export_file, silent=True, extras_mode_new=mode_new) builder = orm.QueryBuilder().append(orm.Data, filters={'label': 'my_test_data_node'}) self.assertEqual(builder.count(), 1) self.imported_node = builder.all()[0][0]
[docs] def modify_extras(self, mode_existing): """Import the same aiida database again""" self.imported_node.set_extra('a', 1) self.imported_node.set_extra('b', 1000) self.imported_node.delete_extra('c') import_data(self.export_file, silent=True, extras_mode_existing=mode_existing) # Query again the database builder = orm.QueryBuilder().append(orm.Data, filters={'label': 'my_test_data_node'}) self.assertEqual(builder.count(), 1) return builder.all()[0][0]
[docs] def tearDown(self): pass
[docs] def test_import_of_extras(self): """Check if extras are properly imported""" self.import_extras() self.assertEqual(self.imported_node.get_extra('b'), 2) self.assertEqual(self.imported_node.get_extra('c'), 3)
[docs] def test_absence_of_extras(self): """Check whether extras are not imported if the mode is set to 'none'""" self.import_extras(mode_new='none') with self.assertRaises(AttributeError): # the extra 'b' should not exist self.imported_node.get_extra('b') with self.assertRaises(AttributeError): # the extra 'c' should not exist self.imported_node.get_extra('c')
[docs] def test_extras_import_mode_keep_existing(self): """Check if old extras are not modified in case of name collision""" self.import_extras() imported_node = self.modify_extras(mode_existing='kcl') # Check that extras are imported correctly self.assertEqual(imported_node.get_extra('a'), 1) self.assertEqual(imported_node.get_extra('b'), 1000) self.assertEqual(imported_node.get_extra('c'), 3)
[docs] def test_extras_import_mode_update_existing(self): """Check if old extras are modified in case of name collision""" self.import_extras() imported_node = self.modify_extras(mode_existing='kcu') # Check that extras are imported correctly self.assertEqual(imported_node.get_extra('a'), 1) self.assertEqual(imported_node.get_extra('b'), 2) self.assertEqual(imported_node.get_extra('c'), 3)
[docs] def test_extras_import_mode_mirror(self): """Check if old extras are fully overwritten by the imported ones""" self.import_extras() imported_node = self.modify_extras(mode_existing='ncu') # Check that extras are imported correctly with self.assertRaises(AttributeError): # the extra # 'a' should not exist, as the extras were fully mirrored with respect to # the imported node imported_node.get_extra('a') self.assertEqual(imported_node.get_extra('b'), 2) self.assertEqual(imported_node.get_extra('c'), 3)
[docs] def test_extras_import_mode_none(self): """Check if old extras are fully overwritten by the imported ones""" self.import_extras() imported_node = self.modify_extras(mode_existing='knl') # Check if extras are imported correctly self.assertEqual(imported_node.get_extra('b'), 1000) self.assertEqual(imported_node.get_extra('a'), 1) with self.assertRaises(AttributeError): # the extra # 'c' should not exist, as the extras were keept untached imported_node.get_extra('c')
[docs] def test_extras_import_mode_strange(self): """Check a mode that is probably does not make much sense but is still available""" self.import_extras() imported_node = self.modify_extras(mode_existing='kcd') # Check if extras are imported correctly self.assertEqual(imported_node.get_extra('a'), 1) self.assertEqual(imported_node.get_extra('c'), 3) with self.assertRaises(AttributeError): # the extra # 'b' should not exist, as the collided extras are deleted imported_node.get_extra('b')
[docs] def test_extras_import_mode_correct(self): """Test all possible import modes except 'ask' """ self.import_extras() for mode1 in ['k', 'n']: # keep or not keep old extras for mode2 in ['n', 'c']: # create or not create new extras for mode3 in ['l', 'u', 'd']: # leave old, update or delete collided extras mode = mode1 + mode2 + mode3 import_data(self.export_file, silent=True, extras_mode_existing=mode)
[docs] def test_extras_import_mode_wrong(self): """Check a mode that is wrong""" from aiida.tools.importexport.common.exceptions import ImportValidationError self.import_extras() with self.assertRaises(ImportValidationError): import_data(self.export_file, silent=True, extras_mode_existing='xnd') # first letter is wrong with self.assertRaises(ImportValidationError): import_data(self.export_file, silent=True, extras_mode_existing='nxd') # second letter is wrong with self.assertRaises(ImportValidationError): import_data(self.export_file, silent=True, extras_mode_existing='nnx') # third letter is wrong with self.assertRaises(ImportValidationError): import_data(self.export_file, silent=True, extras_mode_existing='n') # too short with self.assertRaises(ImportValidationError): import_data(self.export_file, silent=True, extras_mode_existing='nndnn') # too long with self.assertRaises(ImportValidationError): import_data(self.export_file, silent=True, extras_mode_existing=5) # wrong type