# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
# pylint: disable=too-many-lines
"""Sqla query builder implementation"""
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
import uuid
import warnings
from sqlalchemy import and_
from sqlalchemy import func as sa_func
from sqlalchemy import not_, or_
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.engine import Row
from sqlalchemy.exc import CompileError, SAWarning
from sqlalchemy.orm import aliased
from sqlalchemy.orm.attributes import InstrumentedAttribute, QueryableAttribute
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.sql.compiler import SQLCompiler
from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList, Cast, ColumnClause, ColumnElement, Label
from sqlalchemy.sql.expression import case, text
from sqlalchemy.types import Boolean, DateTime, Float, Integer, String
from aiida.common.exceptions import NotExistent
from aiida.orm.entities import EntityTypes
from aiida.orm.implementation.querybuilder import QUERYBUILD_LOGGER, BackendQueryBuilder, QueryDictType
from .joiner import JoinReturn, SqlaJoiner
jsonb_typeof = sa_func.jsonb_typeof
jsonb_array_length = sa_func.jsonb_array_length
array_length = sa_func.array_length
@dataclass
class BuiltQuery:
"""A class to store the query and the corresponding projections."""
query: Query
tag_to_alias: Dict[str, Optional[AliasedClass]]
tag_to_projected: Dict[str, Dict[str, int]]
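
# A sketch of the shapes stored above (hypothetical values): for a query projecting
# ``id`` and ``uuid`` of a node tagged 'n', one would have
#   tag_to_alias == {'n': <AliasedClass of DbNode>}
#   tag_to_projected == {'n': {'id': 0, 'uuid': 1}}
# where the integers are the positions of each field in a SQL result row.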
class SqlaQueryBuilder(BackendQueryBuilder):
"""
QueryBuilder to use with SQLAlchemy-backend and
schema defined in backends.sqlalchemy.models
"""
# pylint: disable=too-many-public-methods,invalid-name
def __init__(self, backend):
super().__init__(backend)
self._joiner = SqlaJoiner(self, self.build_filters)
# Set conversions between the field names in the database and used by the `QueryBuilder`
# table -> field -> field
self.inner_to_outer_schema: Dict[str, Dict[str, str]] = {
'db_dbauthinfo': {
'_metadata': 'metadata'
},
'db_dbcomputer': {
'_metadata': 'metadata'
},
'db_dblog': {
'_metadata': 'metadata'
},
}
self.outer_to_inner_schema: Dict[str, Dict[str, str]] = {
'db_dbauthinfo': {
'metadata': '_metadata'
},
'db_dbcomputer': {
'metadata': '_metadata'
},
'db_dblog': {
'metadata': '_metadata'
},
}
# Hashing the internal query representation avoids rebuilding a query
self._query_cache: Optional[BuiltQuery] = None
self._query_hash: Optional[str] = None
@property
def Node(self):
import aiida.storage.psql_dos.models.node
return aiida.storage.psql_dos.models.node.DbNode
@property
def Link(self):
import aiida.storage.psql_dos.models.node
return aiida.storage.psql_dos.models.node.DbLink
@property
def Computer(self):
import aiida.storage.psql_dos.models.computer
return aiida.storage.psql_dos.models.computer.DbComputer
@property
def User(self):
import aiida.storage.psql_dos.models.user
return aiida.storage.psql_dos.models.user.DbUser
@property
def Group(self):
import aiida.storage.psql_dos.models.group
return aiida.storage.psql_dos.models.group.DbGroup
@property
def AuthInfo(self):
import aiida.storage.psql_dos.models.authinfo
return aiida.storage.psql_dos.models.authinfo.DbAuthInfo
@property
def Comment(self):
import aiida.storage.psql_dos.models.comment
return aiida.storage.psql_dos.models.comment.DbComment
@property
def Log(self):
import aiida.storage.psql_dos.models.log
return aiida.storage.psql_dos.models.log.DbLog
@property
def table_groups_nodes(self):
import aiida.storage.psql_dos.models.group
return aiida.storage.psql_dos.models.group.table_groups_nodes
def get_session(self) -> Session:
"""Get the connection to the database"""
return self._backend.get_session()
def count(self, data: QueryDictType) -> int:
with self.query_session(data) as build:
result = build.query.count()
return result
def first(self, data: QueryDictType) -> Optional[List[Any]]:
with self.query_session(data) as build:
result = build.query.first()
if result is None:
return result
            # SQLAlchemy returns a `Row` if specific columns are requested,
            # or a database model instance if the full entity is requested
if not isinstance(result, Row):
result = (result,)
return [self.to_backend(r) for r in result]
def iterall(self, data: QueryDictType, batch_size: Optional[int]) -> Iterable[List[Any]]:
"""Return an iterator over all the results of a list of lists."""
with self.query_session(data) as build:
stmt = build.query.statement.execution_options(yield_per=batch_size)
session = self.get_session()
# Open a session transaction unless already inside one. This prevents the `ModelWrapper` from calling commit
# on the session when a yielded row is mutated. This would reset the cursor invalidating it and causing an
# exception to be raised in the next batch of rows in the iteration.
# See https://github.com/python/mypy/issues/10109 for the reason of the type warning.
            with nullcontext() if session.in_nested_transaction() else self._backend.transaction():  # type: ignore[attr-defined]
for resultrow in session.execute(stmt):
yield [self.to_backend(rowitem) for rowitem in resultrow]
def iterdict(self, data: QueryDictType, batch_size: Optional[int]) -> Iterable[Dict[str, Dict[str, Any]]]:
"""Return an iterator over all the results of a list of dictionaries."""
with self.query_session(data) as build:
stmt = build.query.statement.execution_options(yield_per=batch_size)
session = self.get_session()
# Open a session transaction unless already inside one. This prevents the `ModelWrapper` from calling commit
# on the session when a yielded row is mutated. This would reset the cursor invalidating it and causing an
# exception to be raised in the next batch of rows in the iteration.
# See https://github.com/python/mypy/issues/10109 for the reason of the type warning.
            with nullcontext() if session.in_nested_transaction() else self._backend.transaction():  # type: ignore[attr-defined]
                for row in session.execute(stmt):
# build the yield result
yield_result: Dict[str, Dict[str, Any]] = {}
for (
tag,
projected_entities_dict,
) in build.tag_to_projected.items():
yield_result[tag] = {}
for attrkey, project_index in projected_entities_dict.items():
alias = build.tag_to_alias.get(tag)
if alias is None:
raise ValueError(f'No alias found for tag {tag}')
field_name = get_corresponding_property(
get_table_name(alias),
attrkey,
self.inner_to_outer_schema,
)
yield_result[tag][field_name] = self.to_backend(row[project_index])
yield yield_result
def get_query(self, data: QueryDictType) -> BuiltQuery:
"""Return the built query.
        To avoid unnecessary re-builds of the query, the hash of the given query dictionary
        is compared to the hash of the last built query, which is cached.
"""
from aiida.common.hashing import make_hash
query_hash = make_hash(data)
if not (self._query_cache and self._query_hash and self._query_hash == query_hash):
self._query_cache = self._build(data)
self._query_hash = query_hash
return self._query_cache
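
    # Sketch: repeated calls with an unchanged ``data`` dictionary reuse the cached query;
    # any change to the query dictionary alters ``make_hash(data)`` and triggers a rebuild.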
@contextmanager
def query_session(self, data: QueryDictType) -> Iterator[BuiltQuery]:
"""Yield the built query, ensuring the session is closed on an exception."""
query = self.get_query(data)
try:
yield query
except Exception:
self.get_session().close()
raise
def _build(self, data: QueryDictType) -> BuiltQuery:
"""Build the query and return."""
# pylint: disable=too-many-branches,too-many-locals
# generate aliases for tags
tag_to_alias: Dict[str, Optional[AliasedClass]] = {}
cls_map = {
EntityTypes.AUTHINFO.value: self.AuthInfo,
EntityTypes.COMMENT.value: self.Comment,
EntityTypes.COMPUTER.value: self.Computer,
EntityTypes.GROUP.value: self.Group,
EntityTypes.NODE.value: self.Node,
EntityTypes.LOG.value: self.Log,
EntityTypes.USER.value: self.User,
EntityTypes.LINK.value: self.Link,
}
for path in data['path']:
            # An `SAWarning` is currently emitted:
            # "relationship 'DbNode.input_links' will copy column db_dbnode.id to column db_dblink.output_id,
            # which conflicts with relationship(s): 'DbNode.outputs' (copies db_dbnode.id to db_dblink.output_id)"
            # This should eventually be fixed
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=SAWarning)
tag_to_alias[path['tag']] = aliased(cls_map[path['orm_base']])
# now create joins first, so we can populate edge tag aliases
joins = generate_joins(data, tag_to_alias, self._joiner)
for join in joins:
if join.aliased_edge is not None:
tag_to_alias[join.edge_tag] = join.aliased_edge
# generate the projections
projections, tag_to_projected = generate_projections(
data, tag_to_alias, self.outer_to_inner_schema, self._get_projectable_entity
)
# initialise the query
query = self.get_session().query()
# add the projections
for projection, is_entity in projections:
if is_entity:
query = query.add_entity(projection)
else:
query = query.add_columns(projection)
# and the starting table, from which joins will be made
starting_table = tag_to_alias.get(data['path'][0]['tag'])
if starting_table is None:
raise ValueError('starting tag not found')
query = query.select_from(starting_table)
# add the joins
for join in joins:
query = join.join(query)
# add the filters
for tag, filter_specs in data['filters'].items():
if not filter_specs:
continue
alias = tag_to_alias.get(tag)
if not alias:
raise ValueError(f'Unknown tag {tag!r} in filters, known: {list(tag_to_alias)}')
filters = self.build_filters(alias, filter_specs)
if filters is not None:
query = query.filter(filters)
# set the ordering
for order_spec in data['order_by']:
for tag, entity_list in order_spec.items():
alias = tag_to_alias.get(tag)
if not alias:
raise ValueError(f'Unknown tag {tag!r} in order_by, known: {list(tag_to_alias)}')
for entitydict in entity_list:
for entitytag, entityspec in entitydict.items():
query = query.order_by(self._create_order_by(alias, entitytag, entityspec))
# final setup for the query
if data['limit'] is not None:
query = query.limit(data['limit'])
if data['offset'] is not None:
query = query.offset(data['offset'])
if data['distinct']:
query = query.distinct()
return BuiltQuery(query, tag_to_alias, tag_to_projected)
def _create_order_by(self, alias: AliasedClass, field_key: str,
entityspec: dict) -> Union[ColumnElement, InstrumentedAttribute]:
"""Build the order_by parameter of the query."""
column_name = field_key.split('.')[0]
attrpath = field_key.split('.')[1:]
        if attrpath and 'cast' not in entityspec:
            # values nested inside JSONB fields (addressed by a '.'-delimited path) must be cast to be ordered on
            raise ValueError(
                f'To order_by {field_key!r}, the value has to be cast, '
                "but no 'cast' key has been specified."
            )
entity = self._get_projectable_entity(alias, column_name, attrpath, cast=entityspec.get('cast'))
order = entityspec.get('order', 'asc')
if order == 'desc':
entity = entity.desc()
elif order != 'asc':
raise ValueError(f"Unknown 'order' key: {order!r}, must be one of: 'asc', 'desc'")
return entity
@staticmethod
def _get_projectable_entity(
alias: AliasedClass,
column_name: str,
attrpath: List[str],
cast: Optional[str] = None,
) -> Union[ColumnElement, InstrumentedAttribute]:
"""Return projectable entity for a given alias and column name."""
if not (attrpath or column_name in ('attributes', 'extras')):
return get_column(column_name, alias)
entity: ColumnElement = get_column(column_name, alias)[attrpath]
if cast is None:
pass
elif cast == 'f':
entity = entity.astext.cast(Float)
elif cast == 'i':
entity = entity.astext.cast(Integer)
elif cast == 'b':
entity = entity.astext.cast(Boolean)
elif cast == 't':
entity = entity.astext
elif cast == 'j':
entity = entity.astext.cast(JSONB)
elif cast == 'd':
entity = entity.astext.cast(DateTime)
else:
raise ValueError(f'Unknown casting key {cast}')
return entity
def build_filters(self, alias: AliasedClass, filter_spec: Dict[str, Any]) -> Optional[BooleanClauseList]: # pylint: disable=too-many-branches
"""Recurse through the filter specification and apply filter operations.
:param alias: The alias of the ORM class the filter will be applied on
:param filter_spec: the specification of the filter
:returns: an sqlalchemy expression.
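
        For example, a (hypothetical) specification combining a plain column filter
        with a negated logical block::

            filter_spec = {
                'label': {'like': 'relax%'},
                '!or': [
                    {'id': {'==': 2}},
                    {'attributes.energy': {'<': 0.0}},
                ],
            }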
"""
expressions: List[Any] = []
for path_spec, filter_operation_dict in filter_spec.items():
if path_spec in ('and', 'or', '~or', '~and', '!and', '!or'):
subexpressions = []
for sub_filter_spec in filter_operation_dict:
filters = self.build_filters(alias, sub_filter_spec)
if filters is not None:
subexpressions.append(filters)
if subexpressions:
if path_spec == 'and':
expressions.append(and_(*subexpressions))
elif path_spec == 'or':
expressions.append(or_(*subexpressions))
elif path_spec in ('~and', '!and'):
expressions.append(not_(and_(*subexpressions)))
elif path_spec in ('~or', '!or'):
expressions.append(not_(or_(*subexpressions)))
else:
column_name = path_spec.split('.')[0]
attr_key = path_spec.split('.')[1:]
is_jsonb = bool(attr_key) or column_name in ('attributes', 'extras')
column: Optional[InstrumentedAttribute]
try:
column = get_column(column_name, alias)
except (ValueError, TypeError):
if is_jsonb:
column = None
else:
raise
if not isinstance(filter_operation_dict, dict):
filter_operation_dict = {'==': filter_operation_dict}
for operator, value in filter_operation_dict.items():
expressions.append(
self.get_filter_expr(
operator,
value,
attr_key,
is_jsonb=is_jsonb,
column=column,
column_name=column_name,
alias=alias,
)
)
return and_(*expressions) if expressions else None
def get_filter_expr(
self,
operator: str,
value: Any,
attr_key: List[str],
is_jsonb: bool,
alias=None,
column=None,
column_name=None,
):
"""Applies a filter on the alias given.
Expects the alias of the ORM-class on which to filter, and filter_spec.
Filter_spec contains the specification on the filter.
Expects:
:param operator: The operator to apply, see below for further details
:param value:
The value for the right side of the expression,
the value you want to compare with.
        :param attr_key: The path to the value to filter on, within the JSON field
        :param is_jsonb: Whether the value is in a JSON column, or in an attribute-like table.
Implemented and valid operators:
* for any type:
* == (compare single value, eg: '==':5.0)
            * in (compare whether in list, eg: 'in':[5, 6, 34])
* for floats and integers:
* >
* <
* <=
* >=
* for strings:
            * like (case-sensitive), for example
'like':'node.calc.%' will match node.calc.relax and
node.calc.RELAX and node.calc. but
not node.CALC.relax
            * ilike (case-insensitive)
will also match node.CaLc.relax in the above example
            .. note::
                The character % is a reserved special character in SQL,
                and acts as a wildcard. If you specifically
                want to match a literal ``%`` in the string, escape it as ``\%``
* for arrays and dictionaries (only for the
SQLAlchemy implementation):
* contains: pass a list with all the items that
the array should contain, or that should be among
                the keys, eg: 'contains': ['N', 'H']
* has_key: pass an element that the list has to contain
                or that has to be a key, eg: 'has_key': 'N'
* for arrays only (SQLAlchemy version):
* of_length
* longer
* shorter
        All the above filters invoke a negation of the
        expression if preceded by ``~`` or ``!``::
# first example:
filter_spec = {
'name' : {
'~in':[
'halle',
'lujah'
]
} # Name not 'halle' or 'lujah'
}
# second example:
filter_spec = {
'id' : {
'~==': 2
}
} # id is not 2
"""
# pylint: disable=too-many-arguments, too-many-branches
expr: Any = None
if operator.startswith('~'):
negation = True
operator = operator.lstrip('~')
elif operator.startswith('!'):
negation = True
operator = operator.lstrip('!')
else:
negation = False
if operator in ('longer', 'shorter', 'of_length'):
if not isinstance(value, int):
raise TypeError('You have to give an integer when comparing to a length')
elif operator in ('like', 'ilike'):
if not isinstance(value, str):
raise TypeError(f'Value for operator {operator} has to be a string (you gave {value})')
elif operator == 'in':
            try:
                value_type_set = set(type(i) for i in value)
            except TypeError as exc:
                raise TypeError('Value for operator `in` could not be iterated') from exc
if not value_type_set:
raise ValueError('Value for operator `in` is an empty list')
if len(value_type_set) > 1:
raise ValueError(f'Value for operator `in` contains more than one type: {value}')
elif operator in ('and', 'or'):
expressions_for_this_path = []
for filter_operation_dict in value:
for newoperator, newvalue in filter_operation_dict.items():
expressions_for_this_path.append(
self.get_filter_expr(
newoperator,
newvalue,
attr_key=attr_key,
is_jsonb=is_jsonb,
alias=alias,
column=column,
column_name=column_name,
)
)
if operator == 'and':
expr = and_(*expressions_for_this_path)
elif operator == 'or':
expr = or_(*expressions_for_this_path)
if expr is None:
if is_jsonb:
expr = self.get_filter_expr_from_jsonb(
operator,
value,
attr_key,
column=column,
column_name=column_name,
alias=alias,
)
else:
if column is None:
                    if alias is None or column_name is None:
                        raise RuntimeError('Cannot get the column without both the alias and the column name')
column = get_column(column_name, alias)
expr = self.get_filter_expr_from_column(operator, value, column)
if negation:
return not_(expr)
return expr
@staticmethod
def get_filter_expr_from_jsonb(
operator: str,
value,
attr_key: List[str],
column=None,
column_name=None,
alias=None,
):
"""Return a filter expression"""
# pylint: disable=too-many-branches, too-many-arguments, too-many-statements
        def cast_according_to_type(path_in_json, value):
            """Cast the JSONB value at the given path according to the Python type of ``value``."""
            if isinstance(value, bool):
                type_filter = jsonb_typeof(path_in_json) == 'boolean'
                casted_entity = path_in_json.astext.cast(Boolean)
            elif isinstance(value, (int, float)):
                type_filter = jsonb_typeof(path_in_json) == 'number'
                casted_entity = path_in_json.astext.cast(Float)
            elif isinstance(value, dict):
                type_filter = jsonb_typeof(path_in_json) == 'object'
                casted_entity = path_in_json.astext.cast(JSONB)
            elif isinstance(value, (list, tuple)):
                type_filter = jsonb_typeof(path_in_json) == 'array'
                casted_entity = path_in_json.astext.cast(JSONB)
            elif isinstance(value, str):
                type_filter = jsonb_typeof(path_in_json) == 'string'
                casted_entity = path_in_json.astext
            elif value is None:
                type_filter = jsonb_typeof(path_in_json) == 'null'
                casted_entity = path_in_json.astext.cast(JSONB)
            else:
                raise TypeError(f'Unknown type {type(value)}')
            return type_filter, casted_entity
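
        # Sketch of the SQL shape produced below (PostgreSQL, for a hypothetical
        # filter ``attributes.energy < 0.0``):
        #   CASE WHEN jsonb_typeof(attributes #> '{energy}') = 'number'
        #        THEN CAST(attributes #>> '{energy}' AS FLOAT) < 0.0
        #        ELSE false END
        # The jsonb_typeof guard avoids cast errors when the stored value has a different JSON type.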
if column is None:
column = get_column(column_name, alias)
database_entity = column[tuple(attr_key)]
expr: Any
if operator == '==':
type_filter, casted_entity = cast_according_to_type(database_entity, value)
expr = case((type_filter, casted_entity == value), else_=False)
elif operator == '>':
type_filter, casted_entity = cast_according_to_type(database_entity, value)
expr = case((type_filter, casted_entity > value), else_=False)
elif operator == '<':
type_filter, casted_entity = cast_according_to_type(database_entity, value)
expr = case((type_filter, casted_entity < value), else_=False)
elif operator in ('>=', '=>'):
type_filter, casted_entity = cast_according_to_type(database_entity, value)
expr = case((type_filter, casted_entity >= value), else_=False)
elif operator in ('<=', '=<'):
type_filter, casted_entity = cast_according_to_type(database_entity, value)
expr = case((type_filter, casted_entity <= value), else_=False)
elif operator == 'of_type':
# http://www.postgresql.org/docs/9.5/static/functions-json.html
# Possible types are object, array, string, number, boolean, and null.
valid_types = ('object', 'array', 'string', 'number', 'boolean', 'null')
if value not in valid_types:
raise ValueError(f'value {value} for of_type is not among valid types\n{valid_types}')
expr = jsonb_typeof(database_entity) == value
elif operator == 'like':
type_filter, casted_entity = cast_according_to_type(database_entity, value)
expr = case((type_filter, casted_entity.like(value)), else_=False)
elif operator == 'ilike':
type_filter, casted_entity = cast_according_to_type(database_entity, value)
expr = case((type_filter, casted_entity.ilike(value)), else_=False)
elif operator == 'in':
type_filter, casted_entity = cast_according_to_type(database_entity, value[0])
expr = case((type_filter, casted_entity.in_(value)), else_=False)
elif operator == 'contains':
expr = database_entity.cast(JSONB).contains(value)
elif operator == 'has_key':
expr = database_entity.cast(JSONB).has_key(value) # noqa
elif operator == 'of_length':
expr = case(
(
jsonb_typeof(database_entity) == 'array',
jsonb_array_length(database_entity.cast(JSONB)) == value,
),
else_=False,
)
elif operator == 'longer':
expr = case(
(
jsonb_typeof(database_entity) == 'array',
jsonb_array_length(database_entity.cast(JSONB)) > value,
),
else_=False,
)
elif operator == 'shorter':
expr = case(
(
jsonb_typeof(database_entity) == 'array',
jsonb_array_length(database_entity.cast(JSONB)) < value,
),
else_=False,
)
else:
raise ValueError(f'Unknown operator {operator} for filters in JSON field')
return expr
@staticmethod
def get_filter_expr_from_column(operator: str, value: Any, column) -> BinaryExpression:
"""A method that returns an valid SQLAlchemy expression.
:param operator: The operator provided by the user ('==', '>', ...)
:param value: The value to compare with, e.g. (5.0, 'foo', ['a','b'])
:param column: an instance of sqlalchemy.orm.attributes.InstrumentedAttribute or
"""
# Label is used because it is what is returned for the
# 'state' column by the hybrid_column construct
if not isinstance(
column,
(Cast, InstrumentedAttribute, QueryableAttribute, Label, ColumnClause),
):
raise TypeError(f'column ({type(column)}) {column} is not a valid column')
database_entity = column
if operator == '==':
expr = database_entity == value
elif operator == '>':
expr = database_entity > value
elif operator == '<':
expr = database_entity < value
elif operator == '>=':
expr = database_entity >= value
elif operator == '<=':
expr = database_entity <= value
elif operator == 'like':
# the like operator expects a string, so we cast to avoid problems
# with fields like UUID, which don't support the like operator
expr = database_entity.cast(String).like(value)
elif operator == 'ilike':
expr = database_entity.ilike(value)
elif operator == 'in':
expr = database_entity.in_(value)
else:
raise ValueError(f'Unknown operator {operator} for filters on columns')
return expr
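
    # Usage sketch: ``get_filter_expr_from_column('like', 'calc%', alias.uuid)`` would
    # compile to ``CAST(uuid AS VARCHAR) LIKE 'calc%'``; the cast above makes ``like``
    # usable on non-string columns such as UUIDs.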
def to_backend(self, res) -> Any:
"""Convert results to return backend specific objects.
- convert `DbModel` instances to `BackendEntity` instances.
- convert UUIDs to strings
:param res: the result returned by the query
        :returns: a backend-compatible instance
"""
if isinstance(res, uuid.UUID):
return str(res)
try:
return self._backend.get_backend_entity(res)
except TypeError:
return res
def as_sql(self, data: QueryDictType, inline: bool = False) -> str:
"""Return the SQL query as a string."""
with self.query_session(data) as build:
compiled = compile_query(build.query, literal_binds=inline)
if inline:
return compiled.string + '\n'
return f'{compiled.string!r} % {compiled.params!r}\n'
def analyze_query(self, data: QueryDictType, execute: bool = True, verbose: bool = False) -> str:
"""Analyze the query and return the result as a string."""
with self.query_session(data) as build:
if build.query.session.bind.dialect.name != 'postgresql': # type: ignore[union-attr]
raise NotImplementedError('Only PostgreSQL is supported for this method')
compiled = compile_query(build.query, literal_binds=True)
options = ', '.join((['ANALYZE'] if execute else []) + (['VERBOSE'] if verbose else []))
options = f' ({options})' if options else ''
rows = (self.get_session().execute(text(f'EXPLAIN{options} {compiled.string}')).fetchall())
return '\n'.join(row[0] for row in rows)
def get_creation_statistics(self, user_pk: Optional[int] = None) -> Dict[str, Any]:
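        """Return statistics about node creation, optionally limited to the user with the given pk.

        The returned dictionary contains the keys ``total`` (total number of nodes), ``types``
        (mapping of node type string to count) and ``ctime_by_day`` (mapping of ``YYYY-MM-DD``
        creation date to count; days on which no node was created are omitted).
        """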
session = self.get_session()
retdict: Dict[Any, Any] = {}
total_query = session.query(self.Node)
types_query = session.query(self.Node.node_type.label('typestring'), sa_func.count(self.Node.id)) # pylint: disable=no-member,not-callable
stat_query = session.query(
sa_func.date_trunc('day', self.Node.ctime).label('cday'), # pylint: disable=no-member
sa_func.count(self.Node.id), # pylint: disable=no-member,not-callable
)
if user_pk is not None:
total_query = total_query.filter(self.Node.user_id == user_pk)
types_query = types_query.filter(self.Node.user_id == user_pk)
stat_query = stat_query.filter(self.Node.user_id == user_pk)
# Total number of nodes
retdict['total'] = total_query.count()
# Nodes per type
retdict['types'] = dict(types_query.group_by('typestring').all())
        # Nodes created per day; days on which no node was created are not included
        stat = stat_query.group_by('cday').order_by('cday').all()
        retdict['ctime_by_day'] = {cday.strftime('%Y-%m-%d'): count for cday, count in stat}
        return retdict
def get_corresponding_property(entity_table: str, given_property: str, mapper: Dict[str, Dict[str, str]]) -> str:
"""
    This function returns the mapped name for a given property.
    If there is no mapping for the property, the given property name is returned unchanged.
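
    For example, with a mapper like ``SqlaQueryBuilder.inner_to_outer_schema``
    (defined in ``__init__`` above)::

        get_corresponding_property('db_dblog', '_metadata', mapper)  # -> 'metadata'
        get_corresponding_property('db_dblog', 'message', mapper)    # -> 'message'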
"""
    try:
        # Get the mapping for the specific entity_table
        property_mapping = mapper[entity_table]
    except KeyError:
        # No mapping exists for this table, so the given property remains unchanged
        return given_property
    # If there is no mapping for this specific property, it remains unchanged
    return property_mapping.get(given_property, given_property)
def get_corresponding_properties(entity_table: str, given_properties: List[str], mapper: Dict[str, Dict[str, str]]):
"""
This method returns a list of updated properties for a given list of properties.
If there is no update for the property, the given property is returned in the list.
"""
if entity_table in mapper:
res = []
for given_property in given_properties:
res.append(get_corresponding_property(entity_table, given_property, mapper))
return res
return given_properties
def modify_expansions(
alias: AliasedClass,
expansions: List[str],
    outer_to_inner_schema: Dict[str, Dict[str, str]],
) -> List[str]:
"""Modify names of projections if `**` was specified.
This is important for the schema having attributes in a different table.
    In SQLAlchemy, ``metadata`` must be renamed to ``_metadata`` to be in line with the database schema
"""
# pylint: disable=protected-access
    # The following check is added to avoid unnecessary calls to get_inner_property for QB edge queries
# The update of expansions makes sense only when AliasedClass is provided
if hasattr(alias, '_sa_class_manager'):
if '_metadata' in expansions:
raise NotExistent(f"_metadata doesn't exist for {alias}. Please try metadata.")
return get_corresponding_properties(alias.__tablename__, expansions, outer_to_inner_schema)
return expansions
def get_column(colname: str, alias: AliasedClass) -> InstrumentedAttribute:
"""
Return the column for a given projection.
"""
try:
return getattr(alias, colname)
except AttributeError as exc:
raise ValueError(
'{} is not a column of {}\n'
'Valid columns are:\n'
'{}'.format(colname, alias, '\n'.join(alias._sa_class_manager.mapper.c.keys())) # pylint: disable=protected-access
) from exc
def get_column_names(alias: AliasedClass) -> List[str]:
"""
Given the backend specific alias, return the column names that correspond to the aliased table.
"""
return [str(c).replace(f'{alias.__table__.name}.', '') for c in alias.__table__.columns]
def get_table_name(aliased_class: AliasedClass) -> str:
""" Returns the table name given an Aliased class"""
return aliased_class.__tablename__
def compile_query(query: Query, literal_binds: bool = False) -> SQLCompiler:
"""Compile the query to the SQL executable.
    :param literal_binds: Inline bound parameters (this is normally handled by the Python DBAPI).
"""
dialect = query.session.bind.dialect # type: ignore[union-attr]
class _Compiler(dialect.statement_compiler): # type: ignore[name-defined]
"""Override the compiler with additional literal value renderers."""
def render_literal_value(self, value, type_):
"""Render the value of a bind parameter as a quoted literal.
See https://www.postgresql.org/docs/current/functions-json.html for serialisation specs
"""
from datetime import date, datetime, timedelta
try:
return super().render_literal_value(value, type_)
# sqlalchemy<1.4.45 raises NotImplementedError, sqlalchemy>=1.4.45 raises CompileError
except (NotImplementedError, CompileError):
if isinstance(value, list):
values = ','.join(self.render_literal_value(item, type_) for item in value)
return f"'[{values}]'"
if isinstance(value, int):
return str(value)
if isinstance(value, (str, date, datetime, timedelta)):
escaped = str(value).replace('"', '\\"')
return f'"{escaped}"'
raise
return _Compiler(dialect, query.statement, compile_kwargs=dict(literal_binds=literal_binds))
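
# Usage sketch: ``compile_query(query, literal_binds=True).string`` returns the SQL text
# with bound parameters inlined, which is how ``as_sql`` and ``analyze_query`` above
# obtain a human-readable statement.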
def generate_joins(data: QueryDictType, aliases: Dict[str, Optional[AliasedClass]],
joiner: SqlaJoiner) -> List[JoinReturn]:
"""Generate the joins for the query."""
joins: List[JoinReturn] = []
# Start on second path item, since there is nothing to join if that is the first table
for index, verticespec in enumerate(data['path'][1:], start=1):
join_to = aliases[verticespec['tag']]
if join_to is None:
raise ValueError(f'No alias found for tag {verticespec["tag"]}')
calling_entity = data['path'][index]['orm_base']
joining_keyword = verticespec['joining_keyword']
joining_value = verticespec['joining_value']
        try:
            join_func = joiner.get_join_func(calling_entity, joining_keyword)
        except KeyError as exc:
            raise ValueError(
                f"'{joining_keyword}' is not a valid joining keyword for a '{calling_entity}' type entity"
            ) from exc
if not isinstance(joining_value, str):
raise ValueError(f"'joining_value' value is not a string: {joining_value}")
join_tag = aliases.get(joining_value, None)
if not join_tag:
raise ValueError(f'no alias found for joining_value tag {joining_value!r}')
edge_tag = verticespec['edge_tag']
        # The 'with_ancestors' and 'with_descendants' joins require a filter_dict,
        # to help the recursive function find a good starting point.
        filter_dict = data['filters'].get(verticespec['joining_value'], {})
# Also find out whether the path is used in a filter or a project and, if so,
# instruct the recursive function to build the path on the fly.
# The default is False, because it's super expensive
        expand_path = ((data['filters'][edge_tag].get('path', None) is not None) or
                       any('path' in d for d in data['project'][edge_tag]))
join = join_func(
join_tag,
join_to,
isouterjoin=verticespec.get('outerjoin'), # type: ignore
filter_dict=filter_dict,
expand_path=expand_path,
)
join.edge_tag = edge_tag
joins.append(join)
return joins
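
# Sketch of a path vertex consumed above (hypothetical values):
#   {'tag': 'calc', 'orm_base': 'node', 'joining_keyword': 'with_incoming',
#    'joining_value': 'struct', 'edge_tag': 'struct--calc', 'outerjoin': False}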
def generate_projections(
data: QueryDictType, aliases: Dict[str, Optional[AliasedClass]], outer_to_inner_schema, get_projectable_entity
):
"""Generate the projections for the query."""
# mapping of tag -> field -> projection_index
tag_to_projected_fields: Dict[str, Dict[str, int]] = {}
projections: List[Tuple[Union[AliasedClass, ColumnElement], bool]] = []
QUERYBUILD_LOGGER.debug('projections data: %s', data['project'])
if not any(data['project'].values()):
# If user has not set projection,
# I will simply project the last item specified!
# Don't change, path traversal querying relies on this behavior!
projections += _create_projections(
data['path'][-1]['tag'], aliases, len(projections), [{
'*': {}
}], get_projectable_entity, tag_to_projected_fields, outer_to_inner_schema
)
else:
for vertex in data['path']:
project_dict = data['project'].get(vertex['tag'], [])
projections += _create_projections(
vertex['tag'], aliases, len(projections), project_dict, get_projectable_entity, tag_to_projected_fields,
outer_to_inner_schema
)
for vertex in data['path'][1:]:
edge_tag = vertex.get('edge_tag', None)
QUERYBUILD_LOGGER.debug(
'Checking projections for edges: This is edge %s from %s, %s of %s',
edge_tag,
vertex.get('tag'),
vertex.get('joining_keyword'),
vertex.get('joining_value'),
)
if edge_tag is not None:
project_dict = data['project'].get(edge_tag, [])
projections += _create_projections(
edge_tag, aliases, len(projections), project_dict, get_projectable_entity, tag_to_projected_fields,
outer_to_inner_schema
)
# check the consistency of projections
    projection_index_to_field = {
        index_in_sql_result: attrkey
        for projected_entities_dict in tag_to_projected_fields.values()
        for attrkey, index_in_sql_result in projected_entities_dict.items()
    }
if len(projections) > len(projection_index_to_field):
raise ValueError('You are projecting the same key multiple times within the same node')
if not projection_index_to_field:
raise ValueError('No projections requested')
return projections, tag_to_projected_fields
def _create_projections(
tag: str,
aliases: Dict[str, Optional[AliasedClass]],
projection_count: int,
project_dict: List[Dict[str, dict]],
get_projectable_entity,
tag_to_projected_fields,
outer_to_inner_schema,
) -> List[Tuple[Union[AliasedClass, ColumnElement], bool]]:
"""Build the projections for a given tag.
:param tag: the tag of the node for which to build the projections
:param projection_count: the number of previous projections
:param items_to_project: the list of projections to build, if None, use the projections specified in the query
"""
# Return here if there is nothing to project, reduces number of key in return dictionary
QUERYBUILD_LOGGER.debug('projection for %s: %s', tag, project_dict)
if not project_dict:
return []
alias = aliases.get(tag)
if alias is None:
raise ValueError(f'No alias found for tag {tag}')
tag_to_projected_fields[tag] = {}
projections = []
for projectable_spec in project_dict:
for projectable_entity_name, extraspec in projectable_spec.items():
property_names = []
if projectable_entity_name == '**':
# Need to expand
property_names.extend(modify_expansions(alias, get_column_names(alias), outer_to_inner_schema))
else:
property_names.extend(modify_expansions(alias, [projectable_entity_name], outer_to_inner_schema))
for property_name in property_names:
projections.append(_get_projection(alias, property_name, get_projectable_entity, **extraspec))
tag_to_projected_fields[tag][property_name] = projection_count
projection_count += 1
return projections
def _get_projection(
alias: AliasedClass,
projectable_entity_name: str,
get_projectable_entity,
cast: Optional[str] = None,
func: Optional[str] = None,
**_kw: Any,
) -> Tuple[Union[AliasedClass, ColumnElement], bool]:
"""
:param alias: An alias for an ormclass
    :param projectable_entity_name:
        User specification of what to project; the corresponding entity is
        appended to the query so that it is returned by the query
:param cast: Cast the value to a different type
:param func: Apply a function to the projection
:return: The projection
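
    For example (a sketch), ``func='count'`` on the ``id`` column yields
    ``(sa_func.count(alias.id), False)``, while projecting ``'*'`` yields the
    aliased class itself as ``(alias, True)``.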
"""
column_name = projectable_entity_name.split('.')[0]
attr_key = projectable_entity_name.split('.')[1:]
if column_name == '*':
if func is not None:
            raise ValueError(
                "Functions cannot be applied to the whole aliased class (you specified '*'); "
                "apply the function to a specific column instead, e.g. 'id'"
            )
return alias, True
entity_to_project = get_projectable_entity(alias, column_name, attr_key, cast=cast)
if func is None:
pass
elif func == 'max':
entity_to_project = sa_func.max(entity_to_project) # pylint: disable=not-callable
    elif func == 'min':
        entity_to_project = sa_func.min(entity_to_project)  # pylint: disable=not-callable
elif func == 'count':
entity_to_project = sa_func.count(entity_to_project) # pylint: disable=not-callable
    else:
        raise ValueError(f'Invalid function specification {func}')
return entity_to_project, False