Move remove_keys() from swh.storage.cassandra.common to swh.storage.utils
and use it in the RawExtrinsicMetadata storage tests.

swh/storage/cassandra/common.py: removes remove_keys() together with the
"from typing import Any, Dict, Tuple" line.

swh/storage/cassandra/{converters,cql,storage}.py: import remove_keys from
..utils instead of .common; TOKEN_BEGIN, TOKEN_END and hash_url are still
imported from .common.

swh/storage/utils.py: adds remove_keys(), now annotated generically as
Dict[T1, T2] / Tuple[T1, ...] instead of Dict[str, Any] / Tuple[str, ...].
NOTE(review): T1 and T2 are not defined in this hunk; confirm they are
already declared as TypeVars earlier in utils.py, otherwise importing the
module would fail with NameError.

swh/storage/tests/storage_tests.py: replaces attr.evolve(...) on
RawExtrinsicMetadata objects with
RawExtrinsicMetadata.from_dict({**remove_keys(obj.to_dict(), ("id",)), ...}).
Dropping "id" presumably lets from_dict compute a fresh id from the new
field values (per the "# recompute id" comments), whereas attr.evolve
keeps every field that is not overridden, including a stored id.
Since the dict is in to_dict() form, overrides are passed serialized:
str(content2_swhid) for "target" and fetcher .to_dict() for "fetcher".
NOTE(review): the "discovery_date" overrides re-set a value that to_dict()
presumably already contains; they look like leftovers from the old
attr.evolve(discovery_date=...) calls and could likely be dropped.
NOTE(review): the extra "assert result.results[0].to_dict() == ..." is
added to the content test only, not to its origin counterpart; it looks
like a debugging aid (field-level diff on failure). Consider keeping the
two tests symmetric.

diff --git a/swh/storage/cassandra/common.py b/swh/storage/cassandra/common.py --- a/swh/storage/cassandra/common.py +++ b/swh/storage/cassandra/common.py @@ -5,7 +5,6 @@ import hashlib -from typing import Any, Dict, Tuple TOKEN_BEGIN = -(2 ** 63) """Minimum value returned by the CQL function token()""" @@ -15,7 +14,3 @@ def hash_url(url: str) -> bytes: return hashlib.sha1(url.encode("ascii")).digest() - - -def remove_keys(d: Dict[str, Any], keys: Tuple[str, ...]) -> Dict[str, Any]: - return {k: v for (k, v) in d.items() if k not in keys} diff --git a/swh/storage/cassandra/converters.py b/swh/storage/cassandra/converters.py --- a/swh/storage/cassandra/converters.py +++ b/swh/storage/cassandra/converters.py @@ -21,7 +21,7 @@ Sha1Git, ) -from .common import remove_keys +from ..utils import remove_keys from .model import OriginVisitRow, OriginVisitStatusRow, ReleaseRow, RevisionRow diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py --- a/swh/storage/cassandra/cql.py +++ b/swh/storage/cassandra/cql.py @@ -44,7 +44,8 @@ ) from swh.storage.interface import ListOrder -from .common import TOKEN_BEGIN, TOKEN_END, hash_url, remove_keys +from ..utils import remove_keys +from .common import TOKEN_BEGIN, TOKEN_END, hash_url from .model import ( MAGIC_NULL_PK, BaseRow, diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -60,7 +60,8 @@ from . 
import converters from ..exc import HashCollision, StorageArgumentException -from .common import TOKEN_BEGIN, TOKEN_END, hash_url, remove_keys +from ..utils import remove_keys +from .common import TOKEN_BEGIN, TOKEN_END, hash_url from .cql import CqlRunner from .model import ( ContentRow, diff --git a/swh/storage/tests/storage_tests.py b/swh/storage/tests/storage_tests.py --- a/swh/storage/tests/storage_tests.py +++ b/swh/storage/tests/storage_tests.py @@ -28,6 +28,7 @@ OriginVisit, OriginVisitStatus, Person, + RawExtrinsicMetadata, Revision, SkippedContent, Snapshot, @@ -37,7 +38,12 @@ from swh.storage.common import origin_url_to_sha1 as sha1 from swh.storage.exc import HashCollision, StorageArgumentException from swh.storage.interface import ListOrder, PagedResult, StorageInterface -from swh.storage.utils import content_hex_hashes, now, round_to_milliseconds +from swh.storage.utils import ( + content_hex_hashes, + now, + remove_keys, + round_to_milliseconds, +) def transform_entries( @@ -3364,8 +3370,12 @@ object_type="content", object_id=hash_to_bytes(content.sha1_git) ) - new_content_metadata2 = attr.evolve( - content_metadata2, format="new-format", metadata=b"new-metadata", + new_content_metadata2 = RawExtrinsicMetadata.from_dict( + { + **remove_keys(content_metadata2.to_dict(), ("id",)), # recompute id + "format": "new-format", + "metadata": b"new-metadata", + } ) swh_storage.metadata_fetcher_add([fetcher]) @@ -3399,7 +3409,12 @@ content1_swhid = SWHID(object_type="content", object_id=content.sha1_git) content2_swhid = SWHID(object_type="content", object_id=content2.sha1_git) - content2_metadata = attr.evolve(content1_metadata2, target=content2_swhid) + content2_metadata = RawExtrinsicMetadata.from_dict( + { + **remove_keys(content1_metadata2.to_dict(), ("id",)), # recompute id + "target": str(content2_swhid), + } + ) swh_storage.metadata_authority_add([authority, authority2]) swh_storage.metadata_fetcher_add([fetcher, fetcher2]) @@ -3519,10 +3534,12 @@ 
swh_storage.metadata_fetcher_add([fetcher1, fetcher2]) swh_storage.metadata_authority_add([authority]) - new_content_metadata2 = attr.evolve( - content_metadata2, - discovery_date=content_metadata2.discovery_date, - fetcher=attr.evolve(fetcher2, metadata=None), + new_content_metadata2 = RawExtrinsicMetadata.from_dict( + { + **remove_keys(content_metadata2.to_dict(), ("id",)), # recompute id + "discovery_date": content_metadata2.discovery_date, + "fetcher": attr.evolve(fetcher2, metadata=None).to_dict(), + } ) swh_storage.raw_extrinsic_metadata_add( @@ -3543,6 +3560,7 @@ page_token=result.next_page_token, ) assert result.next_page_token is None + assert result.results[0].to_dict() == new_content_metadata2.to_dict() assert result.results == [new_content_metadata2] def test_content_metadata_get__invalid_id(self, swh_storage, sample_data): @@ -3601,8 +3619,12 @@ origin_metadata, origin_metadata2 = sample_data.origin_metadata[:2] assert swh_storage.origin_add([origin]) == {"origin:add": 1} - new_origin_metadata2 = attr.evolve( - origin_metadata2, format="new-format", metadata=b"new-metadata", + new_origin_metadata2 = RawExtrinsicMetadata.from_dict( + { + **remove_keys(origin_metadata2.to_dict(), ("id",)), # recompute id + "format": "new-format", + "metadata": b"new-metadata", + } ) swh_storage.metadata_fetcher_add([fetcher]) @@ -3637,7 +3659,12 @@ assert swh_storage.origin_add([origin, origin2]) == {"origin:add": 2} - origin2_metadata = attr.evolve(origin1_metadata2, target=origin2.url) + origin2_metadata = RawExtrinsicMetadata.from_dict( + { + **remove_keys(origin1_metadata2.to_dict(), ("id",)), # recompute id + "target": origin2.url, + } + ) swh_storage.metadata_authority_add([authority, authority2]) swh_storage.metadata_fetcher_add([fetcher, fetcher2]) @@ -3752,10 +3779,12 @@ swh_storage.metadata_fetcher_add([fetcher1, fetcher2]) swh_storage.metadata_authority_add([authority]) - new_origin_metadata2 = attr.evolve( - origin_metadata2, - 
discovery_date=origin_metadata2.discovery_date, - fetcher=attr.evolve(fetcher2, metadata=None), + new_origin_metadata2 = RawExtrinsicMetadata.from_dict( + { + **remove_keys(origin_metadata2.to_dict(), ("id",)), # recompute id + "discovery_date": origin_metadata2.discovery_date, + "fetcher": attr.evolve(fetcher2, metadata=None).to_dict(), + } ) swh_storage.raw_extrinsic_metadata_add([origin_metadata, new_origin_metadata2]) diff --git a/swh/storage/utils.py b/swh/storage/utils.py --- a/swh/storage/utils.py +++ b/swh/storage/utils.py @@ -98,6 +98,11 @@ return {algo: hash_to_bytes(content[algo]) for algo in DEFAULT_ALGORITHMS} +def remove_keys(d: Dict[T1, T2], keys: Tuple[T1, ...]) -> Dict[T1, T2]: + """Returns a copy of ``d`` minus the given keys.""" + return {k: v for (k, v) in d.items() if k not in keys} + + def round_to_milliseconds(date): """Round datetime to milliseconds before insertion, so equality doesn't fail after a round-trip through a DB (eg. Cassandra)