diff --git a/swh/provenance/storage/__init__.py b/swh/provenance/storage/__init__.py
--- a/swh/provenance/storage/__init__.py
+++ b/swh/provenance/storage/__init__.py
@@ -48,5 +48,14 @@
         if TYPE_CHECKING:
             assert isinstance(rmq_storage, ProvenanceStorageInterface)
         return rmq_storage
+    elif cls == "journal":
+        from swh.journal.writer import get_journal_writer
+        from swh.provenance.storage.journal import ProvenanceStorageJournal
+
+        storage = get_provenance_storage(**kwargs["storage"])
+        journal = get_journal_writer(**kwargs["journal_writer"])
+
+        ret = ProvenanceStorageJournal(storage=storage, journal=journal)
+        return ret
 
     raise ValueError
diff --git a/swh/provenance/storage/journal.py b/swh/provenance/storage/journal.py
new file mode 100644
--- /dev/null
+++ b/swh/provenance/storage/journal.py
@@ -0,0 +1,164 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from __future__ import annotations
+
+from dataclasses import asdict
+from datetime import datetime
+from types import TracebackType
+from typing import Dict, Generator, Iterable, List, Optional, Set, Type
+
+from swh.model.model import Sha1Git
+from swh.provenance.storage.interface import (
+    DirectoryData,
+    EntityType,
+    ProvenanceResult,
+    ProvenanceStorageInterface,
+    RelationData,
+    RelationType,
+    RevisionData,
+)
+
+
+class JournalMessage:
+    """Journal payload pairing an object id with its value.
+
+    Provides the ``anonymize``, ``unique_key`` and ``to_dict`` methods.
+    """
+
+    def __init__(self, id, value):
+        self.id = id
+        self.value = value
+
+    def anonymize(self):
+        return None
+
+    def unique_key(self):
+        return self.id
+
+    def to_dict(self):
+        return {
+            "id": self.id,
+            "value": self.value,
+        }
+
+
+class ProvenanceStorageJournal:
+    """Provenance storage proxy that mirrors additions to a journal.
+
+    Every ``*_add`` call is first sent to ``journal.write_additions`` and then
+    forwarded to the wrapped ``storage``; all other calls are forwarded to the
+    wrapped ``storage`` unchanged.
+    """
+
+    def __init__(self, storage, journal):
+        self.storage = storage
+        self.journal = journal
+
+    def __enter__(self) -> ProvenanceStorageInterface:
+        self.storage.__enter__()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        return self.storage.__exit__(exc_type, exc_val, exc_tb)
+
+    def open(self) -> None:
+        self.storage.open()
+
+    def close(self) -> None:
+        self.storage.close()
+
+    def content_add(self, cnts: Dict[Sha1Git, datetime]) -> bool:
+        self.journal.write_additions(
+            "content", [JournalMessage(key, value) for (key, value) in cnts.items()]
+        )
+        return self.storage.content_add(cnts)
+
+    def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]:
+        return self.storage.content_find_first(id)
+
+    def content_find_all(
+        self, id: Sha1Git, limit: Optional[int] = None
+    ) -> Generator[ProvenanceResult, None, None]:
+        return self.storage.content_find_all(id, limit)
+
+    def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]:
+        return self.storage.content_get(ids)
+
+    def directory_add(self, dirs: Dict[Sha1Git, DirectoryData]) -> bool:
+        self.journal.write_additions(
+            "directory",
+            [JournalMessage(key, asdict(value)) for (key, value) in dirs.items()],
+        )
+        return self.storage.directory_add(dirs)
+
+    def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, DirectoryData]:
+        return self.storage.directory_get(ids)
+
+    def directory_iter_not_flattenned(
+        self, limit: int, start_id: Sha1Git
+    ) -> List[Sha1Git]:
+        return self.storage.directory_iter_not_flattenned(limit, start_id)
+
+    def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]:
+        return self.storage.entity_get_all(entity)
+
+    def location_add(self, paths: Dict[Sha1Git, bytes]) -> bool:
+        self.journal.write_additions(
+            "location", [JournalMessage(key, value) for (key, value) in paths.items()]
+        )
+        return self.storage.location_add(paths)
+
+    def location_get_all(self) -> Dict[Sha1Git, bytes]:
+        return self.storage.location_get_all()
+
+    def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool:
+        self.journal.write_additions(
+            "origin", [JournalMessage(key, value) for (key, value) in orgs.items()]
+        )
+        return self.storage.origin_add(orgs)
+
+    def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]:
+        return self.storage.origin_get(ids)
+
+    def revision_add(self, revs: Dict[Sha1Git, RevisionData]) -> bool:
+        self.journal.write_additions(
+            "revision",
+            [JournalMessage(key, asdict(value)) for (key, value) in revs.items()],
+        )
+        return self.storage.revision_add(revs)
+
+    def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]:
+        return self.storage.revision_get(ids)
+
+    def relation_add(
+        self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]]
+    ) -> bool:
+        self.journal.write_additions(
+            relation.value,
+            [
+                JournalMessage(key, [asdict(reldata) for reldata in value])
+                for (key, value) in data.items()
+            ],
+        )
+        return self.storage.relation_add(relation, data)
+
+    def relation_get(
+        self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False
+    ) -> Dict[Sha1Git, Set[RelationData]]:
+        return self.storage.relation_get(relation, ids, reverse)
+
+    def relation_get_all(
+        self, relation: RelationType
+    ) -> Dict[Sha1Git, Set[RelationData]]:
+        return self.storage.relation_get_all(relation)
+
+    def with_path(self) -> bool:
+        return self.storage.with_path()
diff --git a/swh/provenance/tests/test_provenance_journal_writer.py b/swh/provenance/tests/test_provenance_journal_writer.py
new file mode 100644
--- /dev/null
+++ b/swh/provenance/tests/test_provenance_journal_writer.py
@@ -0,0 +1,169 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from dataclasses import asdict
+from typing import Dict, Generator
+
+import pytest
+
+from swh.provenance import get_provenance_storage
+from swh.provenance.storage.interface import (
+    EntityType,
+    ProvenanceStorageInterface,
+    RelationType,
+)
+
+from .test_provenance_storage import TestProvenanceStorage as _TestProvenanceStorage
+
+
+@pytest.fixture()
+def provenance_storage(
+    provenance_postgresqldb: Dict[str, str],
+) -> Generator[ProvenanceStorageInterface, None, None]:
+    """Yield a "journal" storage: PostgreSQL backend plus "memory" journal."""
+    cfg = {
+        "storage": {
+            "cls": "postgresql",
+            "db": provenance_postgresqldb,
+            "raise_on_commit": True,
+        },
+        "journal_writer": {
+            "cls": "memory",
+        },
+    }
+    with get_provenance_storage(cls="journal", **cfg) as storage:
+        yield storage
+
+
+def _assert_journal_matches_relation(storage, objtype, relation):
+    """Check that journal messages of type *objtype* equal stored *relation* rows.
+
+    Both sides are normalized to ``{id: {tuple(dict.items()), ...}}`` so that
+    the journal's lists of dicts can be compared with the storage's sets of
+    ``RelationData``.
+    """
+    journal_rels = {
+        obj.id: {tuple(v.items()) for v in obj.value}
+        for (journal_objtype, obj) in storage.journal.objects
+        if journal_objtype == objtype
+    }
+    prov_rels = {
+        k: {tuple(asdict(reldata).items()) for reldata in v}
+        for k, v in storage.relation_get_all(relation).items()
+    }
+    assert prov_rels == journal_rels
+
+
+class TestProvenanceStorageJournal(_TestProvenanceStorage):
+    """Run the inherited storage tests and compare the journal with the storage."""
+
+    def test_provenance_storage_content(self, provenance_storage):
+        super().test_provenance_storage_content(provenance_storage)
+        assert provenance_storage.journal
+        objtypes = {objtype for (objtype, obj) in provenance_storage.journal.objects}
+        assert objtypes == {"content"}
+
+        journal_objs = {
+            obj.id
+            for (objtype, obj) in provenance_storage.journal.objects
+            if objtype == "content"
+        }
+        assert provenance_storage.entity_get_all(EntityType.CONTENT) == journal_objs
+
+    def test_provenance_storage_directory(self, provenance_storage):
+        super().test_provenance_storage_directory(provenance_storage)
+        assert provenance_storage.journal
+        objtypes = {objtype for (objtype, obj) in provenance_storage.journal.objects}
+        assert objtypes == {"directory"}
+
+        journal_objs = {
+            obj.id
+            for (objtype, obj) in provenance_storage.journal.objects
+            if objtype == "directory"
+        }
+        assert provenance_storage.entity_get_all(EntityType.DIRECTORY) == journal_objs
+
+    def test_provenance_storage_location(self, provenance_storage):
+        super().test_provenance_storage_location(provenance_storage)
+        assert provenance_storage.journal
+        objtypes = {objtype for (objtype, obj) in provenance_storage.journal.objects}
+        assert objtypes == {"location"}
+
+        journal_objs = {
+            obj.id: obj.value
+            for (objtype, obj) in provenance_storage.journal.objects
+            if objtype == "location"
+        }
+        assert provenance_storage.location_get_all() == journal_objs
+
+    def test_provenance_storage_origin(self, provenance_storage):
+        super().test_provenance_storage_origin(provenance_storage)
+        assert provenance_storage.journal
+        objtypes = {objtype for (objtype, obj) in provenance_storage.journal.objects}
+        assert objtypes == {"origin"}
+
+        journal_objs = {
+            obj.id
+            for (objtype, obj) in provenance_storage.journal.objects
+            if objtype == "origin"
+        }
+        assert provenance_storage.entity_get_all(EntityType.ORIGIN) == journal_objs
+
+    def test_provenance_storage_revision(self, provenance_storage):
+        super().test_provenance_storage_revision(provenance_storage)
+        assert provenance_storage.journal
+        objtypes = {objtype for (objtype, obj) in provenance_storage.journal.objects}
+        assert objtypes == {"revision", "origin"}
+
+        journal_objs = {
+            obj.id
+            for (objtype, obj) in provenance_storage.journal.objects
+            if objtype == "revision"
+        }
+        assert provenance_storage.entity_get_all(EntityType.REVISION) == journal_objs
+
+    def test_provenance_storage_relation_revision_layer(self, provenance_storage):
+        super().test_provenance_storage_relation_revision_layer(provenance_storage)
+        assert provenance_storage.journal
+        objtypes = {objtype for (objtype, obj) in provenance_storage.journal.objects}
+        assert objtypes == {
+            "location",
+            "content",
+            "directory",
+            "revision",
+            "content_in_revision",
+            "content_in_directory",
+            "directory_in_revision",
+        }
+
+        _assert_journal_matches_relation(
+            provenance_storage, "content_in_revision", RelationType.CNT_EARLY_IN_REV
+        )
+        _assert_journal_matches_relation(
+            provenance_storage, "content_in_directory", RelationType.CNT_IN_DIR
+        )
+        _assert_journal_matches_relation(
+            provenance_storage, "directory_in_revision", RelationType.DIR_IN_REV
+        )
+
+    def test_provenance_storage_relation_origin_layer(self, provenance_storage):
+        # NOTE(review): the inherited test is spelled "orign_layer", so this
+        # method adds a test rather than overriding the inherited one.
+        super().test_provenance_storage_relation_orign_layer(provenance_storage)
+        assert provenance_storage.journal
+        objtypes = {objtype for (objtype, obj) in provenance_storage.journal.objects}
+        assert objtypes == {
+            "origin",
+            "revision",
+            "revision_in_origin",
+            "revision_before_revision",
+        }
+
+        _assert_journal_matches_relation(
+            provenance_storage, "revision_in_origin", RelationType.REV_IN_ORG
+        )
+        _assert_journal_matches_relation(
+            provenance_storage, "revision_before_revision", RelationType.REV_BEFORE_REV
+        )
diff --git a/swh/provenance/tests/test_provenance_journal_writer_kafka.py b/swh/provenance/tests/test_provenance_journal_writer_kafka.py
new file mode 100644
--- /dev/null
+++ b/swh/provenance/tests/test_provenance_journal_writer_kafka.py
@@ -0,0 +1,44 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from typing import Dict, Generator
+
+from confluent_kafka import Consumer
+import pytest
+
+from swh.provenance import get_provenance_storage
+from swh.provenance.storage.interface import ProvenanceStorageInterface
+
+from .test_provenance_storage import (  # noqa
+    TestProvenanceStorage as TestProvenanceStorage,
+)
+
+
+@pytest.fixture()
+def provenance_storage(
+    provenance_postgresqldb: Dict[str, str],
+    kafka_prefix: str,
+    kafka_server: str,
+    consumer: Consumer,
+) -> Generator[ProvenanceStorageInterface, None, None]:
+    """Yield a "journal" storage: PostgreSQL backend plus Kafka journal writer."""
+    cfg = {
+        "storage": {
+            "cls": "postgresql",
+            "db": provenance_postgresqldb,
+            "raise_on_commit": True,
+        },
+        "journal_writer": {
+            "cls": "kafka",
+            "brokers": [kafka_server],
+            "client_id": "kafka_writer",
+            # NOTE(review): ``kafka_prefix`` and ``consumer`` are requested but
+            # not used below; the topic prefix is hardcoded instead; confirm.
+            "prefix": "swh.provenance",
+            "anonymize": False,
+        },
+    }
+    with get_provenance_storage(cls="journal", **cfg) as storage:
+        yield storage