# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""This module contains the AiiDA backend ORM classes for the SQLite backend.
It reuses the classes already defined in the ``psql_dos`` backend (for PostgreSQL),
but binds them to the SQLite-compatible SQLAlchemy models instead.
"""
from functools import singledispatch
import json
from typing import Any, List, Optional, Tuple

from sqlalchemy import JSON, case, false, func
from sqlalchemy.sql import ColumnElement

from aiida.common.lang import type_check
from aiida.storage.psql_dos.orm import authinfos, comments, computers, entities, groups, logs, nodes, users, utils
from aiida.storage.psql_dos.orm.querybuilder.main import SqlaQueryBuilder

from . import models
from .utils import ReadOnlyError
class SqliteEntityOverride:
    """Mixin that overrides type-checking of psql_dos ``Entity`` classes for the SQLite models.

    It is placed before the psql_dos entity class in the bases of every SQLite entity, so that:

    * the configured ``MODEL_CLASS`` must be a subclass of ``models.SqliteBase``;
    * ``from_dbmodel`` wraps an existing SQLite model instance without running the psql_dos constructors;
    * ``store`` refuses to write when the backend is flagged read-only.
    """

    MODEL_CLASS: Any
    _model: utils.ModelWrapper

    @classmethod
    def _class_check(cls):
        """Assert that the class is correctly configured (``MODEL_CLASS`` is a SQLite SQLA model)."""
        assert issubclass(
            cls.MODEL_CLASS, models.SqliteBase
        ), 'Must set the MODEL_CLASS in the derived class to a SQLA model'

    @classmethod
    def from_dbmodel(cls, dbmodel, backend):
        """Create an AiiDA Entity from the corresponding SQLA ORM model and storage backend.

        :param dbmodel: the SQLAlchemy model to create the entity from; must be an instance of ``MODEL_CLASS``
        :param backend: the corresponding storage backend
        :return: the AiiDA entity
        :raises AssertionError: if ``MODEL_CLASS`` is not a subclass of ``models.SqliteBase``
        """
        cls._class_check()
        type_check(dbmodel, cls.MODEL_CLASS)
        entity = cls.__new__(cls)
        # Skip every ``__init__`` up to and including ``entities.SqlaModelEntity`` in the MRO and only run the
        # initialisers after it; presumably because those constructors build a new model, whereas here the
        # model already exists and is attached directly below.
        super(entities.SqlaModelEntity, entity).__init__(backend)  # type: ignore # # pylint: disable=bad-super-call
        entity._model = utils.ModelWrapper(dbmodel, backend)  # pylint: disable=protected-access
        return entity

    def store(self, *args, **kwargs):
        """Store the entity, unless its backend is read-only.

        All arguments are forwarded unchanged to the parent ``store`` implementation.

        :return: whatever the parent ``store`` returns (presumably the entity itself, as for psql_dos entities)
        :raises ReadOnlyError: if the backend has a truthy ``_read_only`` attribute
        """
        backend = self._model._backend  # pylint: disable=protected-access
        if getattr(backend, '_read_only', False):
            raise ReadOnlyError(f'Cannot store entity in read-only backend: {backend}')
        # Propagate the parent's return value; previously it was silently dropped and ``None`` was returned.
        return super().store(*args, **kwargs)  # type: ignore # pylint: disable=no-member
class SqliteUser(SqliteEntityOverride, users.SqlaUser):
    """User backend entity for SQLite: the psql_dos ``SqlaUser`` bound to the SQLite ``DbUser`` model."""

    MODEL_CLASS = models.DbUser
class SqliteUserCollection(users.SqlaUserCollection):
    """User collection whose entities are :class:`SqliteUser` instances."""

    ENTITY_CLASS = SqliteUser
class SqliteComputer(SqliteEntityOverride, computers.SqlaComputer):
    """Computer backend entity for SQLite: the psql_dos ``SqlaComputer`` bound to the SQLite ``DbComputer`` model."""

    MODEL_CLASS = models.DbComputer
class SqliteComputerCollection(computers.SqlaComputerCollection):
    """Computer collection whose entities are :class:`SqliteComputer` instances."""

    ENTITY_CLASS = SqliteComputer
class SqliteAuthInfo(SqliteEntityOverride, authinfos.SqlaAuthInfo):
    """AuthInfo backend entity for SQLite.

    Binds the psql_dos ``SqlaAuthInfo`` to the SQLite ``DbAuthInfo`` model and to the SQLite user and
    computer entity classes.
    """

    MODEL_CLASS = models.DbAuthInfo
    COMPUTER_CLASS = SqliteComputer
    USER_CLASS = SqliteUser
class SqliteAuthInfoCollection(authinfos.SqlaAuthInfoCollection):
    """AuthInfo collection whose entities are :class:`SqliteAuthInfo` instances."""

    ENTITY_CLASS = SqliteAuthInfo


# ``get_backend_entity`` dispatches ``models.DbComment`` to ``SqliteComment``; without this definition that
# lookup fails with a ``NameError``.
class SqliteComment(SqliteEntityOverride, comments.SqlaComment):
    """Comment backend entity for SQLite: the psql_dos ``SqlaComment`` bound to the SQLite ``DbComment`` model."""

    MODEL_CLASS = models.DbComment
    USER_CLASS = SqliteUser


class SqliteCommentCollection(comments.SqlaCommentCollection):
    """Comment collection whose entities are :class:`SqliteComment` instances."""

    ENTITY_CLASS = SqliteComment
class SqliteGroup(SqliteEntityOverride, groups.SqlaGroup):
    """Group backend entity for SQLite: the psql_dos ``SqlaGroup`` bound to the SQLite ``DbGroup`` model and user class."""

    USER_CLASS = SqliteUser
    MODEL_CLASS = models.DbGroup
class SqliteGroupCollection(groups.SqlaGroupCollection):
    """Group collection whose entities are :class:`SqliteGroup` instances."""

    ENTITY_CLASS = SqliteGroup
class SqliteLog(SqliteEntityOverride, logs.SqlaLog):
    """Log backend entity for SQLite: the psql_dos ``SqlaLog`` bound to the SQLite ``DbLog`` model."""

    MODEL_CLASS = models.DbLog
class SqliteLogCollection(logs.SqlaLogCollection):
    """Log collection whose entities are :class:`SqliteLog` instances."""

    ENTITY_CLASS = SqliteLog
class SqliteNode(SqliteEntityOverride, nodes.SqlaNode):
    """SQLA Node backend entity, bound to the SQLite models.

    ``USER_CLASS`` and ``COMPUTER_CLASS`` are SQLite entity classes, whereas ``LINK_CLASS`` is the SQLite
    ``DbLink`` SQLAlchemy model itself.
    """

    MODEL_CLASS = models.DbNode
    LINK_CLASS = models.DbLink
    USER_CLASS = SqliteUser
    COMPUTER_CLASS = SqliteComputer
class SqliteNodeCollection(nodes.SqlaNodeCollection):
    """Node collection whose entities are :class:`SqliteNode` instances."""

    ENTITY_CLASS = SqliteNode
class SqliteQueryBuilder(SqlaQueryBuilder):
    """QueryBuilder to use with SQLAlchemy-backend, adapted for SQLite.

    The ORM-class properties return the SQLite SQLAlchemy models. The JSON projection and filter
    methods are re-implemented on top of SQLite's JSON1 functions (``json_type``, ``json_each``,
    ``json_array_length``).
    """

    @property
    def Node(self):
        return models.DbNode

    @property
    def Link(self):
        return models.DbLink

    @property
    def Computer(self):
        return models.DbComputer

    @property
    def User(self):
        return models.DbUser

    @property
    def Group(self):
        return models.DbGroup

    @property
    def AuthInfo(self):
        return models.DbAuthInfo

    @property
    def Comment(self):
        return models.DbComment

    @property
    def Log(self):
        return models.DbLog

    @property
    def table_groups_nodes(self):
        return models.DbGroupNodes.__table__  # type: ignore[attr-defined] # pylint: disable=no-member

    def get_projectable_attribute(
        self, alias, column_name: str, attrpath: List[str], cast: Optional[str] = None
    ) -> ColumnElement:
        """Return an attribute stored in a JSON field of the given column.

        :param alias: the alias passed to ``get_column`` to resolve the column
        :param column_name: name of the JSON column
        :param attrpath: path of keys leading to the attribute inside the JSON value
        :param cast: optional cast key: ``'f'`` (float), ``'i'`` (integer), ``'b'`` (boolean),
            ``'t'`` (string) or ``'j'`` (JSON); ``None`` returns the JSON element uncast
        :raises NotImplementedError: for the date cast ``'d'``, which is not implemented for SQLite
        :raises ValueError: for any other cast key
        """
        entity = self.get_column(column_name, alias)[attrpath]
        if cast is None:
            return entity
        if cast == 'f':
            return entity.as_float()
        if cast == 'i':
            return entity.as_integer()
        if cast == 'b':
            return entity.as_boolean()
        if cast == 't':
            return entity.as_string()
        if cast == 'j':
            return entity.as_json()
        if cast == 'd':
            raise NotImplementedError('Date casting (d) for JSON key, not implemented for sqlite backend')
        raise ValueError(f'Unknown casting key {cast}')

    def get_filter_expr_from_jsonb(  # pylint: disable=too-many-return-statements,too-many-branches
        self, operator: str, value, attr_key: List[str], column=None, column_name=None, alias=None
    ):
        """Return a filter expression on a (nested) key of a JSON column.

        Value comparisons are wrapped in a ``CASE`` that first checks the SQLite ``json_type`` of the
        element, so elements of a different JSON type evaluate to false.

        See: https://www.sqlite.org/json1.html

        :param operator: the QueryBuilder filter operator (e.g. ``'=='``, ``'>'``, ``'of_type'``, ``'in'``)
        :param value: the value the element is compared against
        :param attr_key: path of keys leading to the element inside the JSON value
        :param column: the JSON column; if ``None`` it is resolved from ``column_name`` and ``alias``
        :param column_name: name of the column, used when ``column`` is not given
        :param alias: alias of the entity, used when ``column`` is not given
        :raises TypeError: if ``value`` has a Python type that is not supported for comparison
        :raises ValueError: for an unknown ``of_type`` value or an unsupported operator
        """
        if column is None:
            column = self.get_column(column_name, alias)

        # Human-readable form of the query, only used in error messages.
        query_str = f'{alias or ""}.{column_name or ""}.{attr_key} {operator} {value}'

        def _cast_json_type(comparator: JSON.Comparator, value: Any) -> Tuple[ColumnElement, JSON.Comparator]:
            """Return a ``json_type`` check matching ``value``'s Python type, and the comparator cast to that type."""
            if isinstance(value, bool):
                # SQLite booleans in JSON evaluate to 0/1, see:
                # https://dba.stackexchange.com/questions/287377/how-can-i-set-a-json-value-to-a-boolean-in-sqlite
                return func.json_type(comparator) == 'integer', comparator.as_boolean()
            if isinstance(value, int):
                return func.json_type(comparator).in_(['integer', 'real']), comparator.as_integer()
            if isinstance(value, float):
                return func.json_type(comparator).in_(['integer', 'real']), comparator.as_float()
            if isinstance(value, str):
                return func.json_type(comparator) == 'text', comparator.as_string()
            if isinstance(value, list):
                return func.json_type(comparator) == 'array', comparator.as_json()
            if isinstance(value, dict):
                return func.json_type(comparator) == 'object', comparator.as_json()
            raise TypeError(f'Unsupported type {type(value)} for SQLite query: {query_str}')

        database_entity: JSON.Comparator = column[tuple(attr_key)]

        if operator == '==':
            # to-do: non-existent keys also equate to json_type null, so should check it exists also
            # if value is None:
            #     return func.json_type(database_entity) == 'null'
            type_filter, casted_entity = _cast_json_type(database_entity, value)
            if isinstance(value, (list, dict)):
                return case((type_filter, casted_entity == func.json(json.dumps(value))), else_=False)
            return case((type_filter, casted_entity == value), else_=False)

        if operator == '>':
            type_filter, casted_entity = _cast_json_type(database_entity, value)
            return case((type_filter, casted_entity > value), else_=False)

        if operator == '<':
            type_filter, casted_entity = _cast_json_type(database_entity, value)
            return case((type_filter, casted_entity < value), else_=False)

        if operator in ('>=', '=>'):
            type_filter, casted_entity = _cast_json_type(database_entity, value)
            return case((type_filter, casted_entity >= value), else_=False)

        if operator in ('<=', '=<'):
            type_filter, casted_entity = _cast_json_type(database_entity, value)
            return case((type_filter, casted_entity <= value), else_=False)

        if operator == 'of_type':
            # convert from postgres types http://www.postgresql.org/docs/9.5/static/functions-json.html
            # for consistency with other backends
            valid_types = ('object', 'array', 'string', 'number', 'boolean', 'null')
            type_map = {'object': 'object', 'array': 'array', 'string': 'text', 'null': 'null'}
            if value in type_map:
                return func.json_type(database_entity) == type_map[value]
            if value == 'boolean':
                # Booleans show up as the integers 0/1 (see ``_cast_json_type``), so accept only integers
                # whose boolean cast is one of those two values.
                type_filter = func.json_type(database_entity) == 'integer'
                value_filter = database_entity.as_boolean().in_([True, False])
                # Previously ``value_filter <= value`` compared against the string 'boolean', which SQLite
                # always evaluates as true, so every integer was reported as a boolean.
                return case((type_filter, value_filter), else_=False)
            if value == 'number':
                return func.json_type(database_entity).in_(['integer', 'real'])
            raise ValueError(f'value {value!r} for `of_type` is not among valid types: {valid_types}')

        if operator == 'like':
            type_filter, casted_entity = _cast_json_type(database_entity, value)
            return case((type_filter, casted_entity.like(value)), else_=False)

        if operator == 'ilike':
            type_filter, casted_entity = _cast_json_type(database_entity, value)
            return case((type_filter, casted_entity.ilike(value)), else_=False)

        # if operator == 'contains':
        # to-do, see: https://github.com/sqlalchemy/sqlalchemy/discussions/7836

        if operator == 'has_key':
            return case((
                func.json_type(database_entity) == 'object',
                func.json_each(database_entity).table_valued('key', joins_implicitly=True).c.key == value,
            ),
                        else_=False)

        if operator == 'in':
            if not value:
                # ``x IN ()`` can never hold; also avoids an ``IndexError`` on ``value[0]`` below.
                return false()
            # The JSON type to check for is taken from the first candidate value.
            type_filter, casted_entity = _cast_json_type(database_entity, value[0])
            return case((type_filter, casted_entity.in_(value)), else_=False)

        if operator == 'of_length':
            return case((
                func.json_type(database_entity) == 'array',
                func.json_array_length(database_entity.as_json()) == value,
            ),
                        else_=False)

        if operator == 'longer':
            return case((
                func.json_type(database_entity) == 'array',
                func.json_array_length(database_entity.as_json()) > value,
            ),
                        else_=False)

        if operator == 'shorter':
            return case((
                func.json_type(database_entity) == 'array',
                func.json_array_length(database_entity.as_json()) < value,
            ),
                        else_=False)

        raise ValueError(f'SQLite does not support JSON query: {query_str}')
@singledispatch
def get_backend_entity(dbmodel, backend):  # pylint: disable=unused-argument
    """Return the AiiDA backend entity corresponding to the SQLite model instance ``dbmodel``.

    This is the fallback implementation. Specific model classes are handled by implementations
    registered with ``get_backend_entity.register``.

    :raises TypeError: if no implementation is registered for the type of ``dbmodel``
    """
    model_name = dbmodel.__class__.__name__
    raise TypeError(f"No corresponding AiiDA backend class exists for the model class '{model_name}'")
@get_backend_entity.register(models.DbUser)  # type: ignore[call-overload]
def _(dbmodel, backend):
    """Wrap a ``DbUser`` model instance in a :class:`SqliteUser` backend entity."""
    user = SqliteUser.from_dbmodel(dbmodel, backend)
    return user
@get_backend_entity.register(models.DbGroup)  # type: ignore[call-overload]
def _(dbmodel, backend):
    """Wrap a ``DbGroup`` model instance in a :class:`SqliteGroup` backend entity."""
    group = SqliteGroup.from_dbmodel(dbmodel, backend)
    return group
@get_backend_entity.register(models.DbComputer)  # type: ignore[call-overload]
def _(dbmodel, backend):
    """Wrap a ``DbComputer`` model instance in a :class:`SqliteComputer` backend entity."""
    computer = SqliteComputer.from_dbmodel(dbmodel, backend)
    return computer
@get_backend_entity.register(models.DbNode)  # type: ignore[call-overload]
def _(dbmodel, backend):
    """Wrap a ``DbNode`` model instance in a :class:`SqliteNode` backend entity."""
    node = SqliteNode.from_dbmodel(dbmodel, backend)
    return node
@get_backend_entity.register(models.DbAuthInfo)  # type: ignore[call-overload]
def _(dbmodel, backend):
    """Wrap a ``DbAuthInfo`` model instance in a :class:`SqliteAuthInfo` backend entity."""
    authinfo = SqliteAuthInfo.from_dbmodel(dbmodel, backend)
    return authinfo
@get_backend_entity.register(models.DbComment)  # type: ignore[call-overload]
def _(dbmodel, backend):
    """Wrap a ``DbComment`` model instance in a ``SqliteComment`` backend entity."""
    comment = SqliteComment.from_dbmodel(dbmodel, backend)
    return comment
@get_backend_entity.register(models.DbLog)  # type: ignore[call-overload]
def _(dbmodel, backend):
    """Wrap a ``DbLog`` model instance in a :class:`SqliteLog` backend entity."""
    log = SqliteLog.from_dbmodel(dbmodel, backend)
    return log
@get_backend_entity.register(models.DbLink)  # type: ignore[call-overload]
def _(dbmodel, backend):  # pylint: disable=unused-argument
    """Convert a ``DbLink`` model instance into a plain ``LinkQuadruple`` rather than a backend entity.

    ``backend`` is accepted for dispatch-signature compatibility but is not used.
    """
    # Imported at call time, presumably to avoid a circular import at module load — TODO confirm.
    from aiida.orm.utils.links import LinkQuadruple

    return LinkQuadruple(
        dbmodel.input_id,
        dbmodel.output_id,
        dbmodel.type,
        dbmodel.label,
    )