diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py --- a/swh/provenance/provenance.py +++ b/swh/provenance/provenance.py @@ -4,6 +4,7 @@ # See top-level LICENSE file for more information from datetime import datetime +from hashlib import sha1 import logging import os from types import TracebackType @@ -252,7 +253,7 @@ ) revs = { - sha1 + sha1: RevisionData(date=None, origin=None) for sha1, date in self.cache["revision"]["data"].items() if sha1 in self.cache["revision"]["added"] and date is not None } @@ -267,7 +268,7 @@ ) paths = { - path + sha1(path).digest(): path for _, _, path in self.cache["content_in_revision"] | self.cache["content_in_directory"] | self.cache["directory_in_revision"] @@ -465,10 +466,10 @@ } ) dates: Dict[Sha1Git, datetime] = {} - for sha1 in ids: - date = cache["data"].setdefault(sha1, None) + for sha1sum in ids: + date = cache["data"].setdefault(sha1sum, None) if date is not None: - dates[sha1] = date + dates[sha1sum] = date return dates def open(self) -> None: diff --git a/swh/provenance/storage/interface.py b/swh/provenance/storage/interface.py --- a/swh/provenance/storage/interface.py +++ b/swh/provenance/storage/interface.py @@ -9,7 +9,7 @@ from datetime import datetime import enum from types import TracebackType -from typing import Dict, Generator, Iterable, List, Optional, Set, Type, Union +from typing import Dict, Generator, Iterable, List, Optional, Set, Type from typing_extensions import Protocol, runtime_checkable @@ -151,12 +151,12 @@ ... @remote_api_endpoint("location_add") - def location_add(self, paths: Iterable[bytes]) -> bool: + def location_add(self, paths: Dict[Sha1Git, bytes]) -> bool: """Register the given `paths` in the storage.""" ... @remote_api_endpoint("location_get_all") - def location_get_all(self) -> Set[bytes]: + def location_get_all(self) -> Dict[Sha1Git, bytes]: """Retrieve all paths present in the provenance model. This method is used only in tests.""" ... @@ -180,9 +180,7 @@ ... 
@remote_api_endpoint("revision_add") - def revision_add( - self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]] - ) -> bool: + def revision_add(self, revs: Dict[Sha1Git, RevisionData]) -> bool: """Add revisions identified by sha1 ids, with optional associated date or origin (as paired in `revs`) to the provenance storage. Return a boolean stating if the information was successfully stored. diff --git a/swh/provenance/storage/postgresql.py b/swh/provenance/storage/postgresql.py --- a/swh/provenance/storage/postgresql.py +++ b/swh/provenance/storage/postgresql.py @@ -8,10 +8,11 @@ from contextlib import contextmanager from datetime import datetime from functools import wraps +from hashlib import sha1 import itertools import logging from types import TracebackType -from typing import Dict, Generator, Iterable, List, Optional, Set, Type, Union +from typing import Dict, Generator, Iterable, List, Optional, Set, Type import psycopg2.extensions import psycopg2.extras @@ -219,9 +220,9 @@ @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_add"}) @handle_raise_on_commit - def location_add(self, paths: Iterable[bytes]) -> bool: + def location_add(self, paths: Dict[Sha1Git, bytes]) -> bool: if self.with_path(): - values = [(path,) for path in paths] + values = [(path,) for path in paths.values()] if values: sql = """ INSERT INTO location(path) VALUES %s @@ -235,10 +236,10 @@ return True @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_get_all"}) - def location_get_all(self) -> Set[bytes]: + def location_get_all(self) -> Dict[Sha1Git, bytes]: with self.transaction(readonly=True) as cursor: cursor.execute("SELECT location.path AS path FROM location") - return {row["path"] for row in cursor} + return {sha1(row["path"]).digest(): row["path"] for row in cursor} @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "origin_add"}) @handle_raise_on_commit @@ -284,14 +285,9 @@ @statsd.timed(metric=STORAGE_DURATION_METRIC, 
tags={"method": "revision_add"}) @handle_raise_on_commit - def revision_add( - self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]] - ) -> bool: - if isinstance(revs, dict): + def revision_add(self, revs: Dict[Sha1Git, RevisionData]) -> bool: + if revs: data = [(sha1, rev.date, rev.origin) for sha1, rev in revs.items()] - else: - data = [(sha1, None, None) for sha1 in revs] - if data: sql = """ INSERT INTO revision(sha1, date, origin) (SELECT V.rev AS sha1, V.date::timestamptz AS date, O.id AS origin diff --git a/swh/provenance/storage/rabbitmq/client.py b/swh/provenance/storage/rabbitmq/client.py --- a/swh/provenance/storage/rabbitmq/client.py +++ b/swh/provenance/storage/rabbitmq/client.py @@ -71,6 +71,7 @@ if isinstance(data, dict): items = set(data.items()) else: + # TODO this is probably not used any more items = {(item,) for item in data} for id, *rest in items: key = ProvenanceStorageRabbitMQServer.get_routing_key(id, meth_name) diff --git a/swh/provenance/storage/rabbitmq/server.py b/swh/provenance/storage/rabbitmq/server.py --- a/swh/provenance/storage/rabbitmq/server.py +++ b/swh/provenance/storage/rabbitmq/server.py @@ -480,7 +480,7 @@ elif meth_name == "directory_add": return resolve_directory elif meth_name == "location_add": - return lambda data: set(data) # just remove duplicates + return lambda data: dict(data) elif meth_name == "origin_add": return lambda data: dict(data) # last processed value is good enough elif meth_name == "revision_add": diff --git a/swh/provenance/tests/test_provenance_storage.py b/swh/provenance/tests/test_provenance_storage.py --- a/swh/provenance/tests/test_provenance_storage.py +++ b/swh/provenance/tests/test_provenance_storage.py @@ -4,6 +4,7 @@ # See top-level LICENSE file for more information from datetime import datetime, timezone +from hashlib import sha1 import inspect import os from typing import Any, Dict, Iterable, Optional, Set, Tuple @@ -99,13 +100,17 @@ # Add all names of entries present in the
directories of the current repo as paths # to the storage. Then check that the returned results when querying are the same. - paths = {entry["name"] for dir in data["directory"] for entry in dir["entries"]} + paths = { + sha1(entry["name"]).digest(): entry["name"] + for dir in data["directory"] + for entry in dir["entries"] + } assert provenance_storage.location_add(paths) if provenance_storage.with_path(): assert provenance_storage.location_get_all() == paths else: - assert provenance_storage.location_get_all() == set() + assert not provenance_storage.location_get_all() @pytest.mark.origin_layer def test_provenance_storage_origin( @@ -143,22 +148,22 @@ # Origin must be inserted in advance. assert provenance_storage.origin_add({origin.id: origin.url}) - revs = {rev["id"] for idx, rev in enumerate(data["revision"]) if idx % 6 == 0} + revs = {rev["id"] for idx, rev in enumerate(data["revision"])} rev_data = { rev["id"]: RevisionData( date=ts2dt(rev["date"]) if idx % 2 != 0 else None, origin=origin.id if idx % 3 != 0 else None, ) for idx, rev in enumerate(data["revision"]) - if idx % 6 != 0 } assert revs - assert provenance_storage.revision_add(revs) assert provenance_storage.revision_add(rev_data) - assert provenance_storage.revision_get(set(rev_data.keys())) == rev_data - assert provenance_storage.entity_get_all(EntityType.REVISION) == revs | set( - rev_data.keys() - ) + assert provenance_storage.revision_get(set(rev_data.keys())) == { + k: v + for (k, v) in rev_data.items() + if v.date is not None or v.origin is not None + } + assert provenance_storage.entity_get_all(EntityType.REVISION) == set(rev_data) def test_provenance_storage_relation_revision_layer( self, @@ -476,7 +481,12 @@ assert entity_add(storage, EntityType(dst), dsts) if storage.with_path(): assert storage.location_add( - {rel.path for rels in data.values() for rel in rels if rel.path is not None} + { + sha1(rel.path).digest(): rel.path + for rels in data.values() + for rel in rels + if rel.path is 
not None + } ) assert data diff --git a/swh/provenance/tests/test_revision_content_layer.py b/swh/provenance/tests/test_revision_content_layer.py --- a/swh/provenance/tests/test_revision_content_layer.py +++ b/swh/provenance/tests/test_revision_content_layer.py @@ -317,9 +317,9 @@ rows["location"] |= set(x["path"].encode() for x in synth_rev["R_C"]) rows["location"] |= set(x["path"].encode() for x in synth_rev["D_C"]) rows["location"] |= set(x["path"].encode() for x in synth_rev["R_D"]) - assert rows["location"] == provenance.storage.location_get_all(), synth_rev[ - "msg" - ] + assert rows["location"] == set( + provenance.storage.location_get_all().values() + ), synth_rev["msg"] @pytest.mark.parametrize(