# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Manage code objects with lazy loading of the db env"""
import enum
import os
from aiida.cmdline.utils.decorators import with_dbenv
from aiida.common.utils import ErrorAccumulator
class CodeBuilder:
    """Build a code with validation of attribute combinations.

    Public attributes assigned on the builder (through the constructor keyword
    arguments or plain attribute assignment) are recorded in a code
    specification. Every assignment re-validates the whole specification; an
    assignment that makes it invalid is rejected with a
    :py:class:`CodeBuilder.CodeValidationError` and the previous specification
    is kept.
    """

    def __init__(self, **kwargs):
        """Create a builder from code attributes.

        :param kwargs: code attributes. ``code_type`` is required and is set
            before all the others, because the validators of the other
            attributes depend on it.
        :raises KeyError: if ``code_type`` is not given.
        :raises CodeBuilder.CodeValidationError: if an attribute is invalid or
            incompatible with the attributes set before it.
        """
        self._err_acc = ErrorAccumulator(self.CodeValidationError)
        self._code_spec = {}

        # code_type must go first
        self.__setattr__('code_type', kwargs.pop('code_type'))

        # then set the rest
        for key, value in kwargs.items():
            self.__setattr__(key, value)

    def validate(self, raise_error=True):
        """Validate the combination of code attributes set so far.

        Every call starts from a fresh error accumulator, so errors found by an
        earlier validation (e.g. of an attribute that was rejected) do not make
        later validations fail.

        :param raise_error: if True, ask the accumulator to raise
            :py:class:`CodeBuilder.CodeValidationError` when a validator failed.
        :returns: the ``(success, errors)`` pair returned by ``ErrorAccumulator.result``.
        """
        self._err_acc = ErrorAccumulator(self.CodeValidationError)
        for validator in (self.validate_code_type, self.validate_upload, self.validate_installed):
            self._err_acc.run(validator)
        return self._err_acc.result(raise_error=self.CodeValidationError if raise_error else False)

    @with_dbenv()
    def new(self):
        """Build and return a new code instance (not stored).

        :raises CodeBuilder.CodeValidationError: if the attribute combination is
            invalid, or if attributes were set (not None) that are not used to
            build this type of code.
        :raises KeyError: if an attribute needed for this type of code was not set.
        """
        self.validate()

        from aiida.orm import Code

        # Will be used at the end to check if all keys are known (those that are not None)
        passed_keys = {key for key, value in self._code_spec.items() if value is not None}
        used = set()

        if self._get_and_count('code_type', used) == self.CodeType.STORE_AND_UPLOAD:
            # Read the folder once from the spec and use it both for listing and joining.
            code_folder = self._get_and_count('code_folder', used)
            file_list = [os.path.realpath(os.path.join(code_folder, f)) for f in os.listdir(code_folder)]
            code = Code(local_executable=self._get_and_count('code_rel_path', used), files=file_list)
        else:
            code = Code(
                remote_computer_exec=(
                    self._get_and_count('computer', used), self._get_and_count('remote_abs_path', used)
                )
            )

        code.label = self._get_and_count('label', used)
        code.description = self._get_and_count('description', used)
        code.set_input_plugin_name(self._get_and_count('input_plugin', used))
        code.set_prepend_text(self._get_and_count('prepend_text', used))
        code.set_append_text(self._get_and_count('append_text', used))

        # Complain if there are keys that are passed but not used
        unused_keys = passed_keys - used
        if unused_keys:
            raise self.CodeValidationError(
                f"Unknown parameters passed to the CodeBuilder: {', '.join(sorted(unused_keys))}"
            )

        return code

    @staticmethod
    def from_code(code):
        """Create CodeBuilder from existing code instance.

        See also :py:func:`~CodeBuilder.get_code_spec`
        """
        spec = CodeBuilder.get_code_spec(code)
        return CodeBuilder(**spec)

    @staticmethod
    def get_code_spec(code):
        """Get code attributes from existing code instance.

        These attributes can be used to create a new CodeBuilder::

            spec = CodeBuilder.get_code_spec(old_code)
            builder = CodeBuilder(**spec)
            new_code = builder.new()

        :param code: the existing code; ``code.is_local()`` selects which
            location-specific attributes are included.
        :returns: dict of code attributes, including ``code_type``.
        """
        spec = {
            'label': code.label,
            'description': code.description,
            'input_plugin': code.get_input_plugin_name(),
            'prepend_text': code.get_prepend_text(),
            'append_text': code.get_append_text(),
        }

        if code.is_local():
            spec['code_type'] = CodeBuilder.CodeType.STORE_AND_UPLOAD
            spec['code_folder'] = code.get_code_folder()
            spec['code_rel_path'] = code.get_code_rel_path()
        else:
            spec['code_type'] = CodeBuilder.CodeType.ON_COMPUTER
            spec['computer'] = code.get_remote_computer()
            spec['remote_abs_path'] = code.get_remote_exec_path()

        return spec

    class _MissingAttributeError(KeyError, AttributeError):
        """Raised for code attributes that have not been set.

        Subclasses ``KeyError`` (what callers historically caught) and
        ``AttributeError`` (what ``hasattr``/``getattr(obj, name, default)``
        expect from ``__getattr__``).
        """

    def __getattr__(self, key):
        """Access code attributes used to build the code.

        Only called when normal attribute lookup fails.

        :returns: the spec value for a public ``key``; ``None`` for other
            private (single-underscore) names.
        :raises AttributeError: for dunder names, so protocol lookups such as
            ``__setstate__`` (used by ``copy`` and ``pickle``) are reported as absent.
        :raises KeyError: (also an ``AttributeError``) if a public attribute has not been set.
        """
        if key.startswith('__') and key.endswith('__'):
            raise AttributeError(key)
        if key.startswith('_'):
            return None
        try:
            return self._code_spec[key]
        except KeyError:
            raise self._MissingAttributeError(f"Attribute '{key}' not set") from None

    def _get(self, key):
        """
        Return a spec, or None if not defined

        :param key: name of a code spec
        """
        return self._code_spec.get(key)

    def _get_and_count(self, key, used):
        """
        Return a spec, or raise if not defined.
        Moreover, add the key to the ``used`` set.

        :param key: name of a code spec
        :param used: should be a set of keys that you want to track.
            ``key`` will be added to this set if the value exists in the spec and can be retrieved.
        :raises KeyError: if ``key`` has not been set.
        """
        retval = self.__getattr__(key)
        # I first get a retval, so if I get an exception, I don't add it to the 'used' set
        used.add(key)
        return retval

    def __setattr__(self, key, value):
        """Set an attribute; public ones are validated and recorded in the code spec first."""
        if not key.startswith('_'):
            self._set_code_attr(key, value)
            # Keep the instance attribute identical to the (possibly normalised) spec entry,
            # e.g. a ``None`` description is stored as ``''``.
            value = self._code_spec[key]
        super().__setattr__(key, value)

    def _set_code_attr(self, key, value):
        """Set a code attribute, if it passes validation.

        Checks compatibility with other code attributes. If the new value makes
        the combination invalid, the previous spec is restored and the
        validation errors are raised.

        :raises CodeBuilder.CodeValidationError: if the resulting combination is invalid.
        """
        if key == 'description' and value is None:
            value = ''

        backup = self._code_spec.copy()
        self._code_spec[key] = value
        success, _ = self.validate(raise_error=False)
        if not success:
            self._code_spec = backup
            # ``validate`` uses a fresh accumulator each time, so re-validating the restored
            # spec would succeed; raise the errors collected for the rejected value instead.
            self._err_acc.result(raise_error=self.CodeValidationError)

    def validate_code_type(self):
        """Make sure the code type, if set, is a member of :py:class:`CodeBuilder.CodeType`."""
        # Read the spec (not the instance attribute) so a value being assigned is the one checked;
        # isinstance avoids the TypeError that ``in Enum`` raises for non-members on Python < 3.12.
        code_type = self._get('code_type')
        if code_type and not isinstance(code_type, self.CodeType):
            raise self.CodeValidationError(
                f'invalid code type: must be one of {list(self.CodeType)}, not {code_type}'
            )

    def validate_upload(self):
        """If the code is stored and uploaded, catch invalid on-computer attributes"""
        messages = []
        if self.is_local():
            if self._get('computer'):
                messages.append('invalid option for store-and-upload code: "computer"')
            if self._get('remote_abs_path'):
                messages.append('invalid option for store-and-upload code: "remote_abs_path"')
        if messages:
            raise self.CodeValidationError(f'{messages}')

    def validate_installed(self):
        """If the code is on-computer, catch invalid store-and-upload attributes"""
        messages = []
        if self._get('code_type') == self.CodeType.ON_COMPUTER:
            if self._get('code_folder'):
                messages.append('invalid options for on-computer code: "code_folder"')
            if self._get('code_rel_path'):
                messages.append('invalid options for on-computer code: "code_rel_path"')
        if messages:
            raise self.CodeValidationError(f'{messages}')

    class CodeValidationError(ValueError):
        """
        A CodeBuilder instance may raise this

        * when an attribute is assigned that is invalid or incompatible with the
          other code attributes
        * when asked to instantiate a code with attribute combinations that are
          invalid, or with attributes that are not used for that type of code
        """

        def __init__(self, msg):
            # Pass the message on so that ``args`` is populated (needed e.g. for pickling).
            super().__init__(msg)
            self.msg = msg

        def __str__(self):
            return str(self.msg)

        def __repr__(self):
            return f'<CodeValidationError: {self}>'

    def is_local(self):
        """Analogous to Code.is_local()

        :raises KeyError: if ``code_type`` has not been set.
        """
        return self.__getattr__('code_type') == self.CodeType.STORE_AND_UPLOAD

    class CodeType(enum.Enum):
        """Where the code executable lives."""

        STORE_AND_UPLOAD = 'store in the db and upload'
        ON_COMPUTER = 'on computer'