D3770.id13325.diff

diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -8,6 +8,7 @@
import collections
import copy
import datetime
+import functools
import itertools
import random
import re
@@ -46,19 +47,21 @@
OriginVisit,
OriginVisitStatus,
Origin,
- SHA1_SIZE,
MetadataAuthority,
MetadataAuthorityType,
MetadataFetcher,
MetadataTargetType,
RawExtrinsicMetadata,
- Sha1,
Sha1Git,
)
-from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex
+from swh.model.hashutil import DEFAULT_ALGORITHMS
from swh.storage.cassandra import CassandraStorage
-from swh.storage.cassandra.model import BaseRow, ObjectCountRow
+from swh.storage.cassandra.model import (
+ BaseRow,
+ ContentRow,
+ ObjectCountRow,
+)
from swh.storage.interface import (
ListOrder,
PagedResult,
@@ -69,8 +72,7 @@
from swh.storage.utils import now
from .converters import origin_url_to_sha1
-from .exc import StorageArgumentException, HashCollision
-from .utils import get_partition_bounds_bytes
+from .exc import StorageArgumentException
from .writer import JournalWriter
# Max block size of contents to return
@@ -222,6 +224,8 @@
class InMemoryCqlRunner:
def __init__(self):
+ self._contents = Table(ContentRow)
+ self._content_indexes = defaultdict(lambda: defaultdict(set))
self._stat_counters = defaultdict(int)
def increment_counter(self, object_type: str, nb: int):
@@ -231,12 +235,67 @@
for (object_type, count) in self._stat_counters.items():
yield ObjectCountRow(partition_key=0, object_type=object_type, count=count)
+ ##########################
+ # 'content' table
+ ##########################
+
+ def _content_add_finalize(self, content: ContentRow) -> None:
+ self._contents.insert(content)
+ self.increment_counter("content", 1)
+
+ def content_add_prepare(self, content: ContentRow):
+ finalizer = functools.partial(self._content_add_finalize, content)
+ return (self._contents.token(self._contents.partition_key(content)), finalizer)
+
+ def content_get_from_pk(
+ self, content_hashes: Dict[str, bytes]
+ ) -> Optional[ContentRow]:
+ primary_key = self._contents.primary_key_from_dict(content_hashes)
+ return self._contents.get_from_primary_key(primary_key)
+
+ def content_get_from_token(self, token: int) -> Iterable[ContentRow]:
+ return self._contents.get_from_token(token)
+
+ def content_get_random(self) -> Optional[ContentRow]:
+ return random.choice(
+ [
+ row
+ for partition in self._contents.data.values()
+ for row in partition.values()
+ ]
+ )
+
+ def content_get_token_range(
+ self, start: int, end: int, limit: int,
+ ) -> Iterable[Tuple[int, ContentRow]]:
+ matches = [
+ (token, row)
+ for (token, partition) in self._contents.data.items()
+ for (clustering_key, row) in partition.items()
+ if start <= token <= end
+ ]
+ matches.sort()
+ return matches[0:limit]
+
##########################
# 'content_by_*' tables
##########################
def content_missing_by_sha1_git(self, ids: List[bytes]) -> List[bytes]:
- return ids
+ missing = []
+ for id_ in ids:
+ if id_ not in self._content_indexes["sha1_git"]:
+ missing.append(id_)
+
+ return missing
+
+ def content_index_add_one(self, algo: str, content: Content, token: int) -> None:
+ self._content_indexes[algo][content.get_hash(algo)].add(token)
+
+ def content_get_tokens_from_single_hash(
+ self, algo: str, hash_: bytes
+ ) -> Iterable[int]:
+ return self._content_indexes[algo][hash_]
##########################
# 'directory' table
@@ -269,8 +328,6 @@
def reset(self):
self._cql_runner = InMemoryCqlRunner()
- self._contents = {}
- self._content_indexes = defaultdict(lambda: defaultdict(set))
self._skipped_contents = {}
self._skipped_content_indexes = defaultdict(lambda: defaultdict(set))
self._directories = {}
@@ -318,184 +375,6 @@
def check_config(self, *, check_write: bool) -> bool:
return True
- def _content_add(self, contents: List[Content], with_data: bool) -> Dict:
- self.journal_writer.content_add(contents)
-
- content_add = 0
- if with_data:
- summary = self.objstorage.content_add(
- c for c in contents if c.status != "absent"
- )
- content_add_bytes = summary["content:add:bytes"]
-
- for content in contents:
- key = self._content_key(content)
- if key in self._contents:
- continue
- for algorithm in DEFAULT_ALGORITHMS:
- hash_ = content.get_hash(algorithm)
- if hash_ in self._content_indexes[algorithm] and (
- algorithm not in {"blake2s256", "sha256"}
- ):
- colliding_content_hashes = []
- # Add the already stored contents
- for content_hashes_set in self._content_indexes[algorithm][hash_]:
- hashes = dict(content_hashes_set)
- colliding_content_hashes.append(hashes)
- # Add the new colliding content
- colliding_content_hashes.append(content.hashes())
- raise HashCollision(algorithm, hash_, colliding_content_hashes)
- for algorithm in DEFAULT_ALGORITHMS:
- hash_ = content.get_hash(algorithm)
- self._content_indexes[algorithm][hash_].add(key)
- self._objects[content.sha1_git].append(("content", content.sha1))
- self._contents[key] = content
- self._sorted_sha1s.add(content.sha1)
- self._contents[key] = attr.evolve(self._contents[key], data=None)
- content_add += 1
-
- self._cql_runner.increment_counter("content", content_add)
-
- summary = {
- "content:add": content_add,
- }
- if with_data:
- summary["content:add:bytes"] = content_add_bytes
-
- return summary
-
- def content_add(self, content: List[Content]) -> Dict:
- content = [attr.evolve(c, ctime=now()) for c in content]
- return self._content_add(content, with_data=True)
-
- def content_update(
- self, contents: List[Dict[str, Any]], keys: List[str] = []
- ) -> None:
- self.journal_writer.content_update(contents)
-
- for cont_update in contents:
- cont_update = cont_update.copy()
- sha1 = cont_update.pop("sha1")
- for old_key in self._content_indexes["sha1"][sha1]:
- old_cont = self._contents.pop(old_key)
-
- for algorithm in DEFAULT_ALGORITHMS:
- hash_ = old_cont.get_hash(algorithm)
- self._content_indexes[algorithm][hash_].remove(old_key)
-
- new_cont = attr.evolve(old_cont, **cont_update)
- new_key = self._content_key(new_cont)
-
- self._contents[new_key] = new_cont
-
- for algorithm in DEFAULT_ALGORITHMS:
- hash_ = new_cont.get_hash(algorithm)
- self._content_indexes[algorithm][hash_].add(new_key)
-
- def content_add_metadata(self, content: List[Content]) -> Dict:
- return self._content_add(content, with_data=False)
-
- def content_get_data(self, content: Sha1) -> Optional[bytes]:
- # FIXME: Make this method support slicing the `data`
- return self.objstorage.content_get(content)
-
- def content_get_partition(
- self,
- partition_id: int,
- nb_partitions: int,
- page_token: Optional[str] = None,
- limit: int = 1000,
- ) -> PagedResult[Content]:
- if limit is None:
- raise StorageArgumentException("limit should not be None")
- (start, end) = get_partition_bounds_bytes(
- partition_id, nb_partitions, SHA1_SIZE
- )
- if page_token:
- start = hash_to_bytes(page_token)
- if end is None:
- end = b"\xff" * SHA1_SIZE
-
- next_page_token: Optional[str] = None
- sha1s = (
- (sha1, content_key)
- for sha1 in self._sorted_sha1s.iter_from(start)
- for content_key in self._content_indexes["sha1"][sha1]
- )
- contents: List[Content] = []
- for counter, (sha1, key) in enumerate(sha1s):
- if sha1 > end:
- break
- if counter >= limit:
- next_page_token = hash_to_hex(sha1)
- break
- contents.append(self._contents[key])
-
- assert len(contents) <= limit
- return PagedResult(results=contents, next_page_token=next_page_token)
-
- def content_get(self, contents: List[Sha1]) -> List[Optional[Content]]:
- contents_by_sha1: Dict[Sha1, Optional[Content]] = {}
- for sha1 in contents:
- if sha1 in self._content_indexes["sha1"]:
- objs = self._content_indexes["sha1"][sha1]
- # only 1 element as content_add_metadata would have raised a
- # hash collision otherwise
- assert len(objs) == 1
- for key in objs:
- content = attr.evolve(self._contents[key], data=None, ctime=None)
- contents_by_sha1[sha1] = content
- return [contents_by_sha1.get(sha1) for sha1 in contents]
-
- def content_find(self, content: Dict[str, Any]) -> List[Content]:
- if not set(content).intersection(DEFAULT_ALGORITHMS):
- raise StorageArgumentException(
- "content keys must contain at least one "
- f"of: {', '.join(sorted(DEFAULT_ALGORITHMS))}"
- )
- found = []
- for algo in DEFAULT_ALGORITHMS:
- hash = content.get(algo)
- if hash and hash in self._content_indexes[algo]:
- found.append(self._content_indexes[algo][hash])
-
- if not found:
- return []
-
- keys = list(set.intersection(*found))
- return [self._contents[key] for key in keys]
-
- def content_missing(
- self, contents: List[Dict[str, Any]], key_hash: str = "sha1"
- ) -> Iterable[bytes]:
- if key_hash not in DEFAULT_ALGORITHMS:
- raise StorageArgumentException(
- "key_hash should be one of {','.join(DEFAULT_ALGORITHMS)}"
- )
-
- for content in contents:
- for (algo, hash_) in content.items():
- if algo not in DEFAULT_ALGORITHMS:
- continue
- if hash_ not in self._content_indexes.get(algo, []):
- yield content[key_hash]
- break
-
- def content_missing_per_sha1(self, contents: List[bytes]) -> Iterable[bytes]:
- for content in contents:
- if content not in self._content_indexes["sha1"]:
- yield content
-
- def content_missing_per_sha1_git(
- self, contents: List[Sha1Git]
- ) -> Iterable[Sha1Git]:
- for content in contents:
- if content not in self._content_indexes["sha1_git"]:
- yield content
-
- def content_get_random(self) -> Sha1Git:
- return random.choice(list(self._content_indexes["sha1_git"]))
-
def _skipped_content_add(self, contents: List[SkippedContent]) -> Dict:
self.journal_writer.skipped_content_add(contents)
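
Note (not part of the diff): the swh/storage/in_memory.py hunks above move content storage off InMemoryStorage's ad-hoc dicts and into InMemoryCqlRunner, mirroring the Cassandra backend's two-step insertion: content_add_prepare computes the partition token and hands back a finalizer, the per-hash index tables are filled with content_index_add_one, and only the finalizer actually writes the row and bumps the counter. A minimal sketch of how a caller might drive this API follows; the standalone helper and its argument names are assumptions for illustration, not code from the diff:

    from swh.model.hashutil import DEFAULT_ALGORITHMS

    def add_one_content(cql_runner, content_row, content):
        # content_row: a swh.storage.cassandra.model.ContentRow (ctime set)
        # content:     the matching swh.model.model.Content, used for get_hash()

        # 1. Compute the insertion token and obtain a finalizer; nothing is
        #    written to the main 'content' table yet. (The real caller would
        #    also consult content_get_tokens_from_single_hash here to detect
        #    hash collisions before committing.)
        token, finalizer = cql_runner.content_add_prepare(content_row)

        # 2. Record the token under every hash of the content so lookups by
        #    sha1, sha1_git, sha256 or blake2s256 resolve to this row.
        for algo in DEFAULT_ALGORITHMS:
            cql_runner.content_index_add_one(algo, content, token)

        # 3. Insert the row and increment the 'content' object counter.
        finalizer()
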
diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py
--- a/swh/storage/tests/test_api_client.py
+++ b/swh/storage/tests/test_api_client.py
@@ -3,8 +3,6 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
-from unittest.mock import patch
-
import pytest
import swh.storage.api.server as server
@@ -63,8 +61,6 @@
class TestStorage(_TestStorage):
- def test_content_update(self, swh_storage, app_server, sample_data):
- # TODO, journal_writer not supported
- swh_storage.journal_writer.journal = None
- with patch.object(server.storage.journal_writer, "journal", None):
- super().test_content_update(swh_storage, sample_data)
+ @pytest.mark.skip("content_update is not yet implemented for Cassandra")
+ def test_content_update(self):
+ pass
diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py
--- a/swh/storage/tests/test_in_memory.py
+++ b/swh/storage/tests/test_in_memory.py
@@ -9,7 +9,8 @@
from swh.storage.cassandra.model import BaseRow
from swh.storage.in_memory import SortedList, Table
-from swh.storage.tests.test_storage import TestStorage, TestStorageGeneratedData # noqa
+from swh.storage.tests.test_storage import TestStorage as _TestStorage
+from swh.storage.tests.test_storage import TestStorageGeneratedData # noqa
# tests are executed using imported classes (TestStorage and
@@ -144,3 +145,9 @@
assert len(all_rows) == 3
for row in (row1, row2, row3):
assert (table.primary_key(row), row) in all_rows
+
+
+class TestInMemoryStorage(_TestStorage):
+ @pytest.mark.skip("content_update is not yet implemented for Cassandra")
+ def test_content_update(self):
+ pass
diff --git a/swh/storage/tests/test_replay.py b/swh/storage/tests/test_replay.py
--- a/swh/storage/tests/test_replay.py
+++ b/swh/storage/tests/test_replay.py
@@ -3,16 +3,19 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
+import dataclasses
import datetime
import functools
import logging
from typing import Any, Container, Dict, Optional
+import attr
import pytest
from swh.model.hashutil import hash_to_hex, MultiHash, DEFAULT_ALGORITHMS
from swh.storage import get_storage
+from swh.storage.cassandra.model import ContentRow, SkippedContentRow
from swh.storage.in_memory import InMemoryStorage
from swh.storage.replay import process_replay_objects
@@ -28,6 +31,13 @@
UTC = datetime.timezone.utc
+def nullify_ctime(obj):
+ if isinstance(obj, (ContentRow, SkippedContentRow)):
+ return dataclasses.replace(obj, ctime=None)
+ else:
+ return obj
+
+
@pytest.fixture()
def replayer_storage_and_client(
kafka_prefix: str, kafka_consumer_group: str, kafka_server: str
@@ -120,6 +130,8 @@
for content in DUPLICATE_CONTENTS:
topic = f"{prefix}.content"
key = content.sha1
+ now = datetime.datetime.now(tz=UTC)
+ content = attr.evolve(content, ctime=now)
producer.produce(
topic=topic, key=key_to_kafka(key), value=value_to_kafka(content.to_dict()),
)
@@ -162,7 +174,10 @@
# all objects from the src should exists in the dst storage
_check_replayed(src, dst, exclude=["contents"])
# but the dst has one content more (one of the 2 colliding ones)
- assert len(src._contents) == len(dst._contents) - 1
+ assert (
+ len(list(src._cql_runner._contents.iter_all()))
+ == len(list(dst._cql_runner._contents.iter_all())) - 1
+ )
def test_replay_skipped_content(replayer_storage_and_client):
@@ -190,8 +205,7 @@
got_persons = set(dst._persons.values())
assert got_persons == expected_persons
- for attr in (
- "contents",
+ for attr_ in (
"skipped_contents",
"directories",
"revisions",
@@ -201,11 +215,24 @@
"origin_visits",
"origin_visit_statuses",
):
- if exclude and attr in exclude:
+ if exclude and attr_ in exclude:
continue
- expected_objects = sorted(getattr(src, f"_{attr}").items())
- got_objects = sorted(getattr(dst, f"_{attr}").items())
- assert got_objects == expected_objects, f"Mismatch object list for {attr}"
+ expected_objects = sorted(getattr(src, f"_{attr_}").items())
+ got_objects = sorted(getattr(dst, f"_{attr_}").items())
+ assert got_objects == expected_objects, f"Mismatch object list for {attr_}"
+
+ for attr_ in ("contents",):
+ if exclude and attr_ in exclude:
+ continue
+ expected_objects = [
+ (id, nullify_ctime(obj))
+ for id, obj in sorted(getattr(src._cql_runner, f"_{attr_}").iter_all())
+ ]
+ got_objects = [
+ (id, nullify_ctime(obj))
+ for id, obj in sorted(getattr(dst._cql_runner, f"_{attr_}").iter_all())
+ ]
+ assert got_objects == expected_objects, f"Mismatch object list for {attr_}"
def _check_replay_skipped_content(storage, replayer, topic):
@@ -329,17 +356,28 @@
"""
- def maybe_anonymize(obj):
+ def maybe_anonymize(attr_, row):
if expected_anonymized:
- return obj.anonymize() or obj
- return obj
-
- expected_persons = {maybe_anonymize(person) for person in src._persons.values()}
+ if hasattr(row, "anonymize"):
+ # for model objects; cases below are for BaseRow objects
+ row = row.anonymize() or row
+ elif attr_ == "releases":
+ row = dataclasses.replace(row, author=row.author.anonymize())
+ elif attr_ == "revisions":
+ row = dataclasses.replace(
+ row,
+ author=row.author.anonymize(),
+ committer=row.committer.anonymize(),
+ )
+ return row
+
+ expected_persons = {
+ maybe_anonymize("persons", person) for person in src._persons.values()
+ }
got_persons = set(dst._persons.values())
assert got_persons == expected_persons
- for attr in (
- "contents",
+ for attr_ in (
"skipped_contents",
"directories",
"revisions",
@@ -349,10 +387,21 @@
"origin_visit_statuses",
):
expected_objects = [
- (id, maybe_anonymize(obj))
- for id, obj in sorted(getattr(src, f"_{attr}").items())
+ (id, maybe_anonymize(attr_, obj))
+ for id, obj in sorted(getattr(src, f"_{attr_}").items())
+ ]
+ got_objects = [
+ (id, obj) for id, obj in sorted(getattr(dst, f"_{attr_}").items())
+ ]
+ assert got_objects == expected_objects, f"Mismatch object list for {attr_}"
+
+ for attr_ in ("contents",):
+ expected_objects = [
+ (id, nullify_ctime(maybe_anonymize(attr_, obj)))
+ for id, obj in sorted(getattr(src._cql_runner, f"_{attr_}").iter_all())
]
got_objects = [
- (id, obj) for id, obj in sorted(getattr(dst, f"_{attr}").items())
+ (id, nullify_ctime(obj))
+ for id, obj in sorted(getattr(dst._cql_runner, f"_{attr_}").iter_all())
]
- assert got_objects == expected_objects, f"Mismatch object list for {attr}"
+ assert got_objects == expected_objects, f"Mismatch object list for {attr_}"
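
Note (not part of the diff): nullify_ctime is needed because ContentRow and SkippedContentRow carry the insertion timestamp (ctime), which will differ between the source and destination storages being compared, so replayed rows are normalized before the equality assertion. A self-contained illustration of the idea, using a stand-in dataclass rather than the real row classes:

    import dataclasses
    import datetime

    @dataclasses.dataclass(frozen=True)
    class FakeRow:
        sha1: bytes
        length: int
        ctime: datetime.datetime

    def nullify_ctime(row):
        # Drop the write-time timestamp so rows inserted at different times
        # still compare equal.
        return dataclasses.replace(row, ctime=None)

    utc = datetime.timezone.utc
    a = FakeRow(b"\x01" * 20, 3, datetime.datetime(2020, 1, 1, tzinfo=utc))
    b = FakeRow(b"\x01" * 20, 3, datetime.datetime(2020, 1, 2, tzinfo=utc))
    assert a != b
    assert nullify_ctime(a) == nullify_ctime(b)
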
diff --git a/swh/storage/tests/test_retry.py b/swh/storage/tests/test_retry.py
--- a/swh/storage/tests/test_retry.py
+++ b/swh/storage/tests/test_retry.py
@@ -13,6 +13,7 @@
from swh.model.model import MetadataTargetType
from swh.storage.exc import HashCollision, StorageArgumentException
+from swh.storage.utils import now
@pytest.fixture
@@ -120,7 +121,7 @@
content_metadata = swh_storage.content_get([pk])
assert content_metadata == [None]
- s = swh_storage.content_add_metadata([content])
+ s = swh_storage.content_add_metadata([attr.evolve(content, ctime=now())])
assert s == {
"content:add": 1,
}

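Note (not part of the diff): the swh/storage/tests/test_retry.py hunk reflects that the Cassandra-backed code path stores ctime on the content row, so content_add_metadata is now called with a content whose ctime is set. A minimal, hedged usage sketch (the real test uses the sample_data fixture rather than Content.from_data, and an in-memory storage instance is assumed here):

    import attr

    from swh.model.model import Content
    from swh.storage import get_storage
    from swh.storage.utils import now

    storage = get_storage("memory")
    content = Content.from_data(b"hello world")  # ctime is None at this point
    summary = storage.content_add_metadata([attr.evolve(content, ctime=now())])
    assert summary["content:add"] == 1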