###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Module which provides decorators for AiiDA ORM entity -> DB field mappings."""
import datetime
import typing as t
from abc import ABCMeta
from copy import deepcopy
from functools import singledispatchmethod
from pprint import pformat
from aiida.common.lang import isidentifier
# Names exported by ``from <module> import *``. Anything else defined in this
# file has to be imported explicitly by name.
__all__ = (
'QbField',
'QbFieldFilters',
'QbFields',
)
class QbField:
    """A field of an ORM entity, accessible via the ``QueryBuilder``.

    Comparing a field with ``==`` or ``!=`` does not return a boolean: it builds a ``QbFieldFilters``
    describing the comparison. Because ``__eq__`` is overridden, ``__hash__`` is defined explicitly and
    is based on ``key`` and ``backend_key``.
    """

    __slots__ = (
        '_key',
        '_backend_key',
        '_doc',
        '_dtype',
        '_is_attribute',
        '_is_subscriptable',
    )

    def __init__(
        self,
        key: str,
        alias: t.Optional[str] = None,
        *,
        dtype: t.Optional[t.Any] = None,
        doc: str = '',
        is_attribute: bool = True,
        is_subscriptable: bool = False,
    ) -> None:
        """Initialise a ORM entity field, accessible via the ``QueryBuilder``

        :param key: The key of the field on the ORM entity
        :param alias: The alias in the storage backend for the key, if not equal to ``key``
        :param dtype: The data type of the field. If None, the field is of variable type.
        :param doc: A docstring for the field
        :param is_attribute: If True, the ``backend_key`` property will prepend "attributes." to field name
        :param is_subscriptable: If True, a new field can be created by ``field["subkey"]``
        """
        self._key = key
        self._backend_key = alias if alias is not None else key
        self._doc = doc
        self._dtype = dtype
        self._is_attribute = is_attribute
        self._is_subscriptable = is_subscriptable

    @property
    def key(self) -> str:
        """Return the key of the field on the ORM entity."""
        return self._key

    @property
    def backend_key(self) -> str:
        """Return the backend key (alias or key), prefixed with ``attributes.`` for attribute fields."""
        if self._is_attribute:
            return f'attributes.{self._backend_key}'
        return self._backend_key

    @property
    def doc(self) -> str:
        """Return the docstring of the field."""
        return self._doc

    @property
    def dtype(self) -> t.Optional[t.Any]:
        """Return the primitive root type."""
        return extract_root_type(self._dtype)

    @property
    def annotation(self) -> t.Optional[t.Any]:
        """Return the full type annotation."""
        return self._dtype

    @property
    def is_attribute(self) -> bool:
        """Return whether the backend key is prefixed with ``attributes.``."""
        return self._is_attribute

    @property
    def is_subscriptable(self) -> bool:
        """Return whether the field was declared as subscriptable."""
        return self._is_subscriptable

    def __repr__(self) -> str:
        """Return a constructor-like representation; optional parts are shown only when set."""
        parts = [repr(self.key)]
        if self._backend_key != self.key:
            parts.append(repr(self._backend_key))
        if self._dtype:
            parts.append(f'dtype={self._get_dtype_as_str()}')
        parts.append(f'is_attribute={self.is_attribute}')
        if self.is_subscriptable:
            parts.append(f'is_subscriptable={self.is_subscriptable!r}')
        return f'{self.__class__.__name__}({", ".join(parts)})'

    def __str__(self) -> str:
        """Return ``Class(backend_key) -> type``; subscriptable fields get a ``.*`` suffix on the key."""
        class_name = self.__class__.__name__
        field_name = f'{self.backend_key}{".*" if self.is_subscriptable else ""}'
        type_str = self._get_dtype_as_str()
        return f'{class_name}({field_name}) -> {type_str}'

    def __hash__(self):
        """Hash on ``(key, backend_key)``."""
        return hash((self.key, self.backend_key))

    def __eq__(self, value):
        """Return a filter for values equal to ``value`` (not a boolean)."""
        return QbFieldFilters(((self, '==', value),))

    def __ne__(self, value):
        """Return a filter for values not equal to ``value`` (not a boolean)."""
        return QbFieldFilters(((self, '!==', value),))

    def in_(self, value: t.Iterable[t.Any]):
        """Return a filter for only values in the list

        :param value: The iterable of accepted values. One-shot iterators are materialised into a list.
        :raises TypeError: If ``value`` is not iterable
        """
        # ``iter`` checks iterability without consuming the items and without
        # requiring them to be hashable (both were side effects of ``set(value)``).
        try:
            iterator = iter(value)
        except TypeError:
            raise TypeError('in_ must be iterable') from None
        if iterator is value:
            # A one-shot iterator can only be read once; store its items.
            value = list(iterator)
        return QbFieldFilters(((self, 'in', value),))

    def _get_dtype_as_str(self) -> str:
        """Return field type as processed string.

        >>> None -> ?
        >>> str -> str
        >>> typing.List[str] -> List[str]

        Types without a ``__name__`` (e.g. a string forward reference) fall back to ``str(dtype)``.
        """
        if not self._dtype:
            return '?'
        if t.get_origin(self._dtype):
            return str(self._dtype).replace('typing.', '')
        name = getattr(self._dtype, '__name__', None)
        return name if name is not None else str(self._dtype)

    if t.TYPE_CHECKING:

        def __getitem__(self, key: str) -> 'QbField': ...
class QbNumericField(QbField):
    """A numeric (`int`, `float`, `datetime`) flavor of `QbField`.

    Adds the ordering operators; each one builds a `QbFieldFilters` holding a
    single ``(field, comparator, value)`` entry instead of returning a boolean.
    """

    def __lt__(self, value):
        """Return a filter for values strictly smaller than ``value``."""
        condition = (self, '<', value)
        return QbFieldFilters((condition,))

    def __le__(self, value):
        """Return a filter for values smaller than or equal to ``value``."""
        condition = (self, '<=', value)
        return QbFieldFilters((condition,))

    def __gt__(self, value):
        """Return a filter for values strictly greater than ``value``."""
        condition = (self, '>', value)
        return QbFieldFilters((condition,))

    def __ge__(self, value):
        """Return a filter for values greater than or equal to ``value``."""
        condition = (self, '>=', value)
        return QbFieldFilters((condition,))
class QbArrayField(QbField):
    """An array (`list`) flavor of `QbField`.

    Every method builds a `QbFieldFilters` with a single entry for this field.
    """

    def contains(self, value):
        """Return a filter for only values containing these items"""
        condition = (self, 'contains', value)
        return QbFieldFilters((condition,))

    def of_length(self, value: int):
        """Return a filter for only array values of this length.

        :raises TypeError: If ``value`` is not an integer
        """
        if isinstance(value, int):
            return QbFieldFilters(((self, 'of_length', value),))
        raise TypeError('of_length must be an integer')

    def longer(self, value: int):
        """Return a filter for only array values longer than this length.

        :raises TypeError: If ``value`` is not an integer
        """
        if isinstance(value, int):
            return QbFieldFilters(((self, 'longer', value),))
        raise TypeError('longer must be an integer')

    def shorter(self, value: int):
        """Return a filter for only array values shorter than this length.

        :raises TypeError: If ``value`` is not an integer
        """
        if isinstance(value, int):
            return QbFieldFilters(((self, 'shorter', value),))
        raise TypeError('shorter must be an integer')
class QbStrField(QbField):
    """A string (`str`) flavor of `QbField`.

    Both methods build a `QbFieldFilters` with a single wildcard-match entry.
    """

    def like(self, value: str):
        """Return a filter for only string values matching the wildcard string.

        - The percent sign (`%`) represents zero, one, or multiple characters
        - The underscore sign (`_`) represents one, single character

        :raises TypeError: If ``value`` is not a string
        """
        if isinstance(value, str):
            return QbFieldFilters(((self, 'like', value),))
        raise TypeError('like must be a string')

    def ilike(self, value: str):
        """Return a filter for only string values matching the (case-insensitive) wildcard string.

        - The percent sign (`%`) represents zero, one, or multiple characters
        - The underscore sign (`_`) represents one, single character

        :raises TypeError: If ``value`` is not a string
        """
        if isinstance(value, str):
            return QbFieldFilters(((self, 'ilike', value),))
        raise TypeError('ilike must be a string')
class QbDictField(QbField):
    """A dictionary (`dict`) flavor of `QbField`."""

    def has_key(self, value):
        """Return a filter for only values with these keys"""
        condition = (self, 'has_key', value)
        return QbFieldFilters((condition,))

    def __getitem__(self, key: str) -> 'QbAttrField':
        """Return a new `QbAttrField` addressing the nested ``key`` of this field.

        The nested key is ``<key>.<subkey>``. For attribute fields the backend alias is extended
        the same way; otherwise no alias is passed. Only ``is_attribute`` is forwarded, so the
        other options of the new field keep their defaults.

        :raises IndexError: If this field was not declared as subscriptable
        """
        if self.is_subscriptable:
            nested_key = f'{self.key}.{key}'
            nested_alias = f'{self._backend_key}.{key}' if self.is_attribute else None
            return QbAttrField(nested_key, nested_alias, is_attribute=self.is_attribute)
        raise IndexError('This field is not subscriptable')
class QbAttrField(QbNumericField, QbArrayField, QbStrField, QbDictField):
    """A generic flavor of `QbField` covering all operations.

    Combines the numeric, array, string and dictionary flavors and adds a type filter.
    """

    def of_type(self, value):
        """Return a filter for only values of this type."""
        condition = (self, 'of_type', value)
        return QbFieldFilters((condition,))
class QbFieldFilters:
    """A representation of a list of fields and their comparators.

    The filters are stored in ``filters`` as a dictionary mapping a backend key to
    ``{comparator: value}``, or a logical key (``and``, ``or``, ``!and``, ``!or``) to a list of
    such dictionaries. The ``&``, ``|`` and ``~`` operators return a result without modifying
    their operands.
    """

    __slots__ = ('filters',)

    # Each logical key mapped to its negated counterpart (used by ``__invert__``).
    _NEGATED_LOGICAL = {'and': '!and', 'or': '!or', '!and': 'and', '!or': 'or'}

    def __init__(
        self,
        filters: t.Union[t.Sequence[t.Tuple['QbField', str, t.Any]], dict],
    ):
        """Initialise the filters.

        :param filters: Either a sequence of ``(field, comparator, value)`` entries, or a filters dictionary
        """
        self.filters: t.Dict[str, t.Any] = {}
        self.add_filters(filters)

    def as_dict(self) -> t.Dict[str, t.Any]:
        """Return the filters dictionary."""
        return self.filters

    def items(self):
        """Return an items view of the filters for use in the QueryBuilder."""
        return self.filters.items()

    @singledispatchmethod
    def add_filters(self, filters: dict):
        """Merge a filters dictionary into the stored filters (default dispatch target)."""
        self.filters.update(filters)

    @add_filters.register(list)
    @add_filters.register(tuple)
    def _(self, filters):
        """Add ``(field, comparator, value)`` entries.

        A second condition on a field that already has one is combined with the first under the
        ``and`` key, since a dictionary can only hold one entry per field.
        """
        field: 'QbField'
        for field, comparator, value in filters:
            qb_field = field.backend_key
            condition = {qb_field: {comparator: value}}
            if qb_field in self.filters:
                previous = {qb_field: self.filters.pop(qb_field)}
                # Extend an existing ``and`` list instead of replacing it, otherwise the
                # conditions collected for a previously repeated field would be lost.
                # A new list is built so a list shared with the caller is not mutated.
                self.filters['and'] = [*self.filters.get('and', ()), previous, condition]
            else:
                self.filters.update(condition)

    def __repr__(self) -> str:
        return f'QbFieldFilters({self.filters})'

    def __getitem__(self, key: str) -> t.Any:
        return self.filters[key]

    def __contains__(self, key: str) -> bool:
        return key in self.filters

    def __eq__(self, other: object) -> bool:
        """``a == b`` checks if `a.filters == b.filters`.

        :raises TypeError: If ``other`` is not a ``QbFieldFilters``
        """
        if not isinstance(other, QbFieldFilters):
            raise TypeError(f'Cannot compare QbFieldFilters to {type(other)}')
        return self.filters == other.filters

    def __and__(self, other: 'QbFieldFilters') -> 'QbFieldFilters':
        """``a & b`` -> {'and': [`a.filters`, `b.filters`]}."""
        combined = self._resolve_redundancy(other, 'and')
        if combined is not None:
            return combined
        return QbFieldFilters({'and': [self.filters, other.filters]})

    def __or__(self, other: 'QbFieldFilters') -> 'QbFieldFilters':
        """``a | b`` -> {'or': [`a.filters`, `b.filters`]}."""
        combined = self._resolve_redundancy(other, 'or')
        if combined is not None:
            return combined
        return QbFieldFilters({'or': [self.filters, other.filters]})

    def __invert__(self) -> 'QbFieldFilters':
        """~(a > b) -> a !> b; ~(a !> b) -> a > b

        A single logical key is toggled (``and`` <-> ``!and``, ``or`` <-> ``!or``) and a single
        ``{field: {operator: value}}`` entry gets its operator toggled. Anything else (several
        entries, or several operators on one field) is negated as a whole by wrapping all entries
        in ``!and``, so that no condition is dropped.

        :raises ValueError: If there are no filters to invert
        """
        filters = deepcopy(self.filters)
        if not filters:
            raise ValueError('Cannot invert an empty QbFieldFilters')
        if len(filters) == 1:
            key, args = next(iter(filters.items()))
            if key in self._NEGATED_LOGICAL:
                return QbFieldFilters({self._NEGATED_LOGICAL[key]: args})
            if isinstance(args, dict) and len(args) == 1:
                operator, value = next(iter(args.items()))
                operator = operator[1:] if operator.startswith('!') else f'!{operator}'
                return QbFieldFilters({key: {operator: value}})
        # Entries of one dictionary are implicitly and-ed, hence ``!and`` negates all of them.
        return QbFieldFilters({'!and': [{key: args} for key, args in filters.items()]})

    def _resolve_redundancy(self, other: 'QbFieldFilters', logical: str) -> t.Optional['QbFieldFilters']:
        """Resolve redundant filters and nested logical operators.

        :param other: The filters to combine with
        :param logical: The logical key to combine under, ``and`` or ``or``
        :return: ``self`` if both operands are equal, a new flattened ``QbFieldFilters`` if one operand
            is already wrapped in ``logical``, otherwise ``None``
        :raises TypeError: If ``other`` is not a ``QbFieldFilters``
        """
        if not isinstance(other, QbFieldFilters):
            raise TypeError(f'Cannot combine QbFieldFilters and {type(other)}')

        # same filters
        if other == self:
            return self

        def can_flatten(filters: t.Dict[str, t.Any]) -> bool:
            # Sibling entries are implicitly and-ed with the logical list: joining an ``and``
            # list is always equivalent, joining an ``or`` list only if it is the sole entry.
            return logical in filters and (logical == 'and' or len(filters) == 1)

        # self is already wrapped in `logical`: append other, in a copy so self stays untouched
        if can_flatten(self.filters):
            merged = dict(self.filters)
            merged[logical] = [*self.filters[logical], other.filters]
            return QbFieldFilters(merged)

        # other is already wrapped in `logical`: put self first, in a copy so other stays untouched
        if can_flatten(other.filters):
            merged = dict(other.filters)
            merged[logical] = [self.filters, *other.filters[logical]]
            return QbFieldFilters(merged)

        return None
[docs]
class QbFields:
"""A readonly class for mapping attributes to database fields of an AiiDA entity."""
__isabstractmethod__ = False
[docs]
def __init__(self, fields: t.Optional[t.Dict[str, QbField]] = None):
self._fields = fields or {}
[docs]
def __repr__(self) -> str:
return pformat({key: str(value) for key, value in self._fields.items()})
[docs]
def __str__(self) -> str:
return str({key: str(value) for key, value in self._fields.items()})
[docs]
def __getitem__(self, key: str) -> QbField:
"""Return an QbField by key."""
return self._fields[key]
[docs]
def __getattr__(self, key: str) -> QbField:
"""Return an QbField by key."""
try:
return self._fields[key]
except KeyError:
raise AttributeError(key)
[docs]
def __contains__(self, key: str) -> bool:
"""Return if the field key exists"""
return key in self._fields
[docs]
def __len__(self) -> int:
"""Return the number of fields"""
return len(self._fields)
[docs]
def __iter__(self):
"""Iterate through the field keys"""
return iter(self._fields)
[docs]
def __dir__(self):
"""Return keys for tab competion."""
return list(self._fields) + ['_dict']
@property
def _dict(self):
"""Return a copy of the internal mapping"""
return deepcopy(self._fields)
class QbFieldArguments(t.TypedDict):
    """Typed dictionary of the arguments accepted by the ``QbField`` constructor.

    Used by ``add_field`` to collect the arguments once and forward them to the selected field class.
    """

    # The key of the field on the ORM entity
    key: str
    # The alias in the storage backend for the key, if not equal to ``key``
    alias: t.Optional[str]
    # The data type of the field; ``None`` for a field of variable type
    dtype: t.Optional[t.Any]
    # A docstring for the field
    doc: str
    # Whether the backend key is prefixed with "attributes."
    is_attribute: bool
    # Whether a new field can be created by ``field["subkey"]``
    is_subscriptable: bool
def add_field(
    key: str,
    alias: t.Optional[str] = None,
    *,
    dtype: t.Optional[t.Any] = None,
    doc: str = '',
    is_attribute: bool = True,
    is_subscriptable: bool = False,
) -> QbField:
    """Add a `dtype`-dependent `QbField` representation of a field.

    The field class is chosen from the root type of ``dtype``: ``int``, ``float`` and ``datetime``
    give a `QbNumericField`, ``list`` a `QbArrayField`, ``str`` a `QbStrField` and ``dict`` a
    `QbDictField`. A missing ``dtype`` or any other root type gives a plain `QbField`.

    :param key: The key of the field on the ORM entity
    :param alias: The alias in the storage backend for the key, if not equal to ``key``
    :param dtype: The data type of the field. If None, the field is of variable type.
    :param doc: A docstring for the field
    :param is_attribute: If True, the ``backend_key`` property will prepend "attributes." to field name
    :param is_subscriptable: If True, a new field can be created by ``field["subkey"]``
    :return: The field instance of the selected class
    :raises ValueError: If ``key`` is not a valid python identifier, or if an alias is given for a
        field that is not an attribute
    """
    if not isidentifier(key):
        raise ValueError(f'{key} is not a valid python identifier')
    if alias and not is_attribute:
        raise ValueError('only attribute fields may be aliased')

    arguments: QbFieldArguments = {
        'key': key,
        'alias': alias,
        'dtype': dtype,
        'doc': doc,
        'is_attribute': is_attribute,
        'is_subscriptable': is_subscriptable,
    }

    # Without a data type there is nothing to specialise on.
    if not dtype:
        return QbField(**arguments)

    root_type = extract_root_type(dtype)
    if root_type in (int, float, datetime.datetime):
        field_class = QbNumericField
    else:
        # Container and string flavors are matched by identity; unknown types stay generic.
        flavors = ((list, QbArrayField), (str, QbStrField), (dict, QbDictField))
        field_class = next((flavor for root, flavor in flavors if root_type is root), QbField)
    return field_class(**arguments)