###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Abstraction for an archive file format."""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, BinaryIO, Dict, List, Literal, Optional, Type, TypeVar, Union, overload
if TYPE_CHECKING:
    # These imports are only needed to resolve the (string) type annotations used below;
    # guarding them keeps them from being executed when the module is imported at runtime.
    from aiida.orm import QueryBuilder
    from aiida.orm.entities import Entity, EntityTypes
    from aiida.orm.implementation import StorageBackend
    from aiida.tools.visualization.graph import Graph

# Type variable used to annotate ``__enter__`` so that it is typed as returning the same type as ``self``.
SelfType = TypeVar('SelfType')
# Type variable for front-end entity classes; bound (as a forward reference) to ``Entity``.
EntityType = TypeVar('EntityType', bound='Entity')
# Names exported by ``from <module> import *``.
# NOTE(review): ``ArchiveFormatAbstract`` and ``get_format`` are not defined in the visible part of this
# file -- presumably they are defined further down; confirm.
__all__ = ('ArchiveFormatAbstract', 'ArchiveReaderAbstract', 'ArchiveWriterAbstract', 'get_format')
class ArchiveWriterAbstract(ABC):
    """Writer of an archive, that will be used as a context manager.

    Concrete subclasses have to implement :meth:`bulk_insert`, :meth:`put_object` and :meth:`delete_object`.
    """

    def __init__(
        self,
        path: Union[str, Path],
        fmt: 'ArchiveFormatAbstract',
        *,
        mode: Literal['x', 'w', 'a'] = 'x',
        compression: int = 6,
        **kwargs: Any,
    ):
        """Initialise the writer.

        :param path: archive path
        :param fmt: the archive format object; it is stored on the writer as ``_format``
        :param mode: mode to open the archive in: 'x' (exclusive), 'w' (write) or 'a' (append)
        :param compression: default level of compression to use (integer from 0 to 9)
        :param kwargs: any additional keyword arguments; they are stored unchanged as ``_kwargs``
        :raises ValueError: if ``mode`` is not one of 'x', 'w', 'a' or ``compression`` is not in the range 0-9
        """
        self._path = Path(path)
        # The mode is validated before the compression level, so an invalid mode is reported first.
        if mode not in ('x', 'w', 'a'):
            raise ValueError(f'mode not in x, w, a: {mode}')
        self._mode = mode
        if compression not in range(10):
            raise ValueError(f'compression not in range 0-9: {compression}')
        self._compression = compression
        self._format = fmt
        self._kwargs = kwargs

    @property
    def path(self) -> Path:
        """Return the path to the archive."""
        return self._path

    @property
    def mode(self) -> Literal['x', 'w', 'a']:
        """Return the mode of the archive."""
        return self._mode

    @property
    def compression(self) -> int:
        """Return the compression level."""
        return self._compression

    def __enter__(self: SelfType) -> SelfType:
        """Start writing to the archive.

        The base implementation only returns the writer itself.
        """
        return self

    def __exit__(self, *args, **kwargs) -> None:
        """Finalise the archive.

        The base implementation does nothing and returns ``None``, so exceptions raised inside the
        ``with`` block are not suppressed.
        """

    @abstractmethod
    def bulk_insert(
        self,
        entity_type: 'EntityTypes',
        rows: List[Dict[str, Any]],
        allow_defaults: bool = False,
    ) -> None:
        """Add multiple rows of entity data to the archive.

        :param entity_type: The type of the entity
        :param rows: A list of dictionaries, containing all fields of the backend model,
            except the `id` field (a.k.a primary key), which will be generated dynamically
        :param allow_defaults: If ``False``, assert that each row contains all fields,
            otherwise, allow default values for missing fields.
        :raises: ``IntegrityError`` if the keys in a row are not a subset of the columns in the table
        """

    @abstractmethod
    def put_object(self, stream: BinaryIO, *, buffer_size: Optional[int] = None, key: Optional[str] = None) -> str:
        """Add an object to the archive.

        :param stream: byte stream to read the object from
        :param buffer_size: Number of bytes to buffer when read/writing
        :param key: key to use for the object (if None will be auto-generated)
        :return: the key of the object
        """

    @abstractmethod
    def delete_object(self, key: str) -> None:
        """Delete the object from the archive.

        :param key: fully qualified identifier for the object within the repository.
        :raise OSError: if the file could not be deleted.
        """
class ArchiveReaderAbstract(ABC):
    """Reader of an archive, that will be used as a context manager.

    Concrete subclasses have to implement :meth:`get_backend`; the convenience methods
    :meth:`querybuilder`, :meth:`get` and :meth:`graph` are built on top of it.
    """

    def __init__(self, path: Union[str, Path], **kwargs: Any):
        """Initialise the reader.

        :param path: archive path
        :param kwargs: additional keyword arguments; accepted but not used by this base class
        """
        self._path = Path(path)

    @property
    def path(self) -> Path:
        """Return the path to the archive."""
        return self._path

    def __enter__(self: SelfType) -> SelfType:
        """Start reading from the archive.

        The base implementation only returns the reader itself.
        """
        return self

    def __exit__(self, *args, **kwargs) -> None:
        """Finalise the archive.

        The base implementation does nothing and returns ``None``, so exceptions raised inside the
        ``with`` block are not suppressed.
        """

    @abstractmethod
    def get_backend(self) -> 'StorageBackend':
        """Return a 'read-only' backend for the archive."""

    # below are convenience methods for some common use cases

    def querybuilder(self, **kwargs: Any) -> 'QueryBuilder':
        """Return a ``QueryBuilder`` instance, initialised with the archive backend.

        :param kwargs: additional keyword arguments passed on to the ``QueryBuilder`` constructor
        """
        # Imported here rather than at module level (the module only imports it under ``TYPE_CHECKING``).
        from aiida.orm import QueryBuilder

        return QueryBuilder(backend=self.get_backend(), **kwargs)

    def get(self, entity_cls: Type[EntityType], **filters: Any) -> EntityType:
        """Return the entity for the given filters.

        Example::

            reader.get(orm.Node, pk=1)

        :param entity_cls: The type of the front-end entity
        :param filters: the filters identifying the object to get
        :return: the first element of the single result returned by ``QueryBuilder.one()``
        """
        # The query is expressed on the ``id`` field, so the ``pk`` alias is translated first.
        if 'pk' in filters:
            filters['id'] = filters.pop('pk')
        # NOTE(review): ``one()`` is expected to raise unless exactly one entity matches -- see QueryBuilder docs.
        return self.querybuilder().append(entity_cls, filters=filters).one()[0]

    def graph(self, **kwargs: Any) -> 'Graph':
        """Return a provenance graph generator for the archive.

        :param kwargs: additional keyword arguments passed on to the ``Graph`` constructor
        """
        # Imported here rather than at module level (the module only imports it under ``TYPE_CHECKING``).
        from aiida.tools.visualization.graph import Graph

        return Graph(backend=self.get_backend(), **kwargs)