D3770.id13325.diff
Attached to D3770: in_memory: Remove InMemoryStorage.content_* and implement InMemoryCqlRunner.content_*
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -8,6 +8,7 @@
import collections
import copy
import datetime
+import functools
import itertools
import random
import re
@@ -46,19 +47,21 @@
OriginVisit,
OriginVisitStatus,
Origin,
- SHA1_SIZE,
MetadataAuthority,
MetadataAuthorityType,
MetadataFetcher,
MetadataTargetType,
RawExtrinsicMetadata,
- Sha1,
Sha1Git,
)
-from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex
+from swh.model.hashutil import DEFAULT_ALGORITHMS
from swh.storage.cassandra import CassandraStorage
-from swh.storage.cassandra.model import BaseRow, ObjectCountRow
+from swh.storage.cassandra.model import (
+ BaseRow,
+ ContentRow,
+ ObjectCountRow,
+)
from swh.storage.interface import (
ListOrder,
PagedResult,
@@ -69,8 +72,7 @@
from swh.storage.utils import now
from .converters import origin_url_to_sha1
-from .exc import StorageArgumentException, HashCollision
-from .utils import get_partition_bounds_bytes
+from .exc import StorageArgumentException
from .writer import JournalWriter
# Max block size of contents to return
@@ -222,6 +224,8 @@
class InMemoryCqlRunner:
def __init__(self):
+ self._contents = Table(ContentRow)
+ self._content_indexes = defaultdict(lambda: defaultdict(set))
self._stat_counters = defaultdict(int)
def increment_counter(self, object_type: str, nb: int):
@@ -231,12 +235,67 @@
for (object_type, count) in self._stat_counters.items():
yield ObjectCountRow(partition_key=0, object_type=object_type, count=count)
+ ##########################
+ # 'content' table
+ ##########################
+
+ def _content_add_finalize(self, content: ContentRow) -> None:
+ self._contents.insert(content)
+ self.increment_counter("content", 1)
+
+ def content_add_prepare(self, content: ContentRow):
+ finalizer = functools.partial(self._content_add_finalize, content)
+ return (self._contents.token(self._contents.partition_key(content)), finalizer)
+
+ def content_get_from_pk(
+ self, content_hashes: Dict[str, bytes]
+ ) -> Optional[ContentRow]:
+ primary_key = self._contents.primary_key_from_dict(content_hashes)
+ return self._contents.get_from_primary_key(primary_key)
+
+ def content_get_from_token(self, token: int) -> Iterable[ContentRow]:
+ return self._contents.get_from_token(token)
+
+ def content_get_random(self) -> Optional[ContentRow]:
+ return random.choice(
+ [
+ row
+ for partition in self._contents.data.values()
+ for row in partition.values()
+ ]
+ )
+
+ def content_get_token_range(
+ self, start: int, end: int, limit: int,
+ ) -> Iterable[Tuple[int, ContentRow]]:
+ matches = [
+ (token, row)
+ for (token, partition) in self._contents.data.items()
+ for (clustering_key, row) in partition.items()
+ if start <= token <= end
+ ]
+ matches.sort()
+ return matches[0:limit]
+
##########################
# 'content_by_*' tables
##########################
def content_missing_by_sha1_git(self, ids: List[bytes]) -> List[bytes]:
- return ids
+ missing = []
+ for id_ in ids:
+ if id_ not in self._content_indexes["sha1_git"]:
+ missing.append(id_)
+
+ return missing
+
+ def content_index_add_one(self, algo: str, content: Content, token: int) -> None:
+ self._content_indexes[algo][content.get_hash(algo)].add(token)
+
+ def content_get_tokens_from_single_hash(
+ self, algo: str, hash_: bytes
+ ) -> Iterable[int]:
+ return self._content_indexes[algo][hash_]
##########################
# 'directory' table
@@ -269,8 +328,6 @@
def reset(self):
self._cql_runner = InMemoryCqlRunner()
- self._contents = {}
- self._content_indexes = defaultdict(lambda: defaultdict(set))
self._skipped_contents = {}
self._skipped_content_indexes = defaultdict(lambda: defaultdict(set))
self._directories = {}
@@ -318,184 +375,6 @@
def check_config(self, *, check_write: bool) -> bool:
return True
- def _content_add(self, contents: List[Content], with_data: bool) -> Dict:
- self.journal_writer.content_add(contents)
-
- content_add = 0
- if with_data:
- summary = self.objstorage.content_add(
- c for c in contents if c.status != "absent"
- )
- content_add_bytes = summary["content:add:bytes"]
-
- for content in contents:
- key = self._content_key(content)
- if key in self._contents:
- continue
- for algorithm in DEFAULT_ALGORITHMS:
- hash_ = content.get_hash(algorithm)
- if hash_ in self._content_indexes[algorithm] and (
- algorithm not in {"blake2s256", "sha256"}
- ):
- colliding_content_hashes = []
- # Add the already stored contents
- for content_hashes_set in self._content_indexes[algorithm][hash_]:
- hashes = dict(content_hashes_set)
- colliding_content_hashes.append(hashes)
- # Add the new colliding content
- colliding_content_hashes.append(content.hashes())
- raise HashCollision(algorithm, hash_, colliding_content_hashes)
- for algorithm in DEFAULT_ALGORITHMS:
- hash_ = content.get_hash(algorithm)
- self._content_indexes[algorithm][hash_].add(key)
- self._objects[content.sha1_git].append(("content", content.sha1))
- self._contents[key] = content
- self._sorted_sha1s.add(content.sha1)
- self._contents[key] = attr.evolve(self._contents[key], data=None)
- content_add += 1
-
- self._cql_runner.increment_counter("content", content_add)
-
- summary = {
- "content:add": content_add,
- }
- if with_data:
- summary["content:add:bytes"] = content_add_bytes
-
- return summary
-
- def content_add(self, content: List[Content]) -> Dict:
- content = [attr.evolve(c, ctime=now()) for c in content]
- return self._content_add(content, with_data=True)
-
- def content_update(
- self, contents: List[Dict[str, Any]], keys: List[str] = []
- ) -> None:
- self.journal_writer.content_update(contents)
-
- for cont_update in contents:
- cont_update = cont_update.copy()
- sha1 = cont_update.pop("sha1")
- for old_key in self._content_indexes["sha1"][sha1]:
- old_cont = self._contents.pop(old_key)
-
- for algorithm in DEFAULT_ALGORITHMS:
- hash_ = old_cont.get_hash(algorithm)
- self._content_indexes[algorithm][hash_].remove(old_key)
-
- new_cont = attr.evolve(old_cont, **cont_update)
- new_key = self._content_key(new_cont)
-
- self._contents[new_key] = new_cont
-
- for algorithm in DEFAULT_ALGORITHMS:
- hash_ = new_cont.get_hash(algorithm)
- self._content_indexes[algorithm][hash_].add(new_key)
-
- def content_add_metadata(self, content: List[Content]) -> Dict:
- return self._content_add(content, with_data=False)
-
- def content_get_data(self, content: Sha1) -> Optional[bytes]:
- # FIXME: Make this method support slicing the `data`
- return self.objstorage.content_get(content)
-
- def content_get_partition(
- self,
- partition_id: int,
- nb_partitions: int,
- page_token: Optional[str] = None,
- limit: int = 1000,
- ) -> PagedResult[Content]:
- if limit is None:
- raise StorageArgumentException("limit should not be None")
- (start, end) = get_partition_bounds_bytes(
- partition_id, nb_partitions, SHA1_SIZE
- )
- if page_token:
- start = hash_to_bytes(page_token)
- if end is None:
- end = b"\xff" * SHA1_SIZE
-
- next_page_token: Optional[str] = None
- sha1s = (
- (sha1, content_key)
- for sha1 in self._sorted_sha1s.iter_from(start)
- for content_key in self._content_indexes["sha1"][sha1]
- )
- contents: List[Content] = []
- for counter, (sha1, key) in enumerate(sha1s):
- if sha1 > end:
- break
- if counter >= limit:
- next_page_token = hash_to_hex(sha1)
- break
- contents.append(self._contents[key])
-
- assert len(contents) <= limit
- return PagedResult(results=contents, next_page_token=next_page_token)
-
- def content_get(self, contents: List[Sha1]) -> List[Optional[Content]]:
- contents_by_sha1: Dict[Sha1, Optional[Content]] = {}
- for sha1 in contents:
- if sha1 in self._content_indexes["sha1"]:
- objs = self._content_indexes["sha1"][sha1]
- # only 1 element as content_add_metadata would have raised a
- # hash collision otherwise
- assert len(objs) == 1
- for key in objs:
- content = attr.evolve(self._contents[key], data=None, ctime=None)
- contents_by_sha1[sha1] = content
- return [contents_by_sha1.get(sha1) for sha1 in contents]
-
- def content_find(self, content: Dict[str, Any]) -> List[Content]:
- if not set(content).intersection(DEFAULT_ALGORITHMS):
- raise StorageArgumentException(
- "content keys must contain at least one "
- f"of: {', '.join(sorted(DEFAULT_ALGORITHMS))}"
- )
- found = []
- for algo in DEFAULT_ALGORITHMS:
- hash = content.get(algo)
- if hash and hash in self._content_indexes[algo]:
- found.append(self._content_indexes[algo][hash])
-
- if not found:
- return []
-
- keys = list(set.intersection(*found))
- return [self._contents[key] for key in keys]
-
- def content_missing(
- self, contents: List[Dict[str, Any]], key_hash: str = "sha1"
- ) -> Iterable[bytes]:
- if key_hash not in DEFAULT_ALGORITHMS:
- raise StorageArgumentException(
- "key_hash should be one of {','.join(DEFAULT_ALGORITHMS)}"
- )
-
- for content in contents:
- for (algo, hash_) in content.items():
- if algo not in DEFAULT_ALGORITHMS:
- continue
- if hash_ not in self._content_indexes.get(algo, []):
- yield content[key_hash]
- break
-
- def content_missing_per_sha1(self, contents: List[bytes]) -> Iterable[bytes]:
- for content in contents:
- if content not in self._content_indexes["sha1"]:
- yield content
-
- def content_missing_per_sha1_git(
- self, contents: List[Sha1Git]
- ) -> Iterable[Sha1Git]:
- for content in contents:
- if content not in self._content_indexes["sha1_git"]:
- yield content
-
- def content_get_random(self) -> Sha1Git:
- return random.choice(list(self._content_indexes["sha1_git"]))
-
def _skipped_content_add(self, contents: List[SkippedContent]) -> Dict:
self.journal_writer.skipped_content_add(contents)
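Note on the hunks above (not part of the patch): content rows now live in a Table(ContentRow) plus per-algorithm token indexes on InMemoryCqlRunner, mirroring the Cassandra schema, and insertion is split into a prepare step (which returns the row's token and a deferred finalizer) and a finalize step. The sketch below shows how a storage front end such as CassandraStorage could drive this two-phase API; row_from_content is a hypothetical conversion helper and the exact collision-handling order is an assumption for illustration, not the actual CassandraStorage code path.

from swh.model.hashutil import DEFAULT_ALGORITHMS
from swh.storage.exc import HashCollision


def add_one_content(cql_runner, content, row_from_content):
    # 1. prepare: compute the Cassandra-style token for the row and obtain a
    #    deferred finalizer that will perform the actual insertion
    token, finalizer = cql_runner.content_add_prepare(row_from_content(content))

    # 2. collision check against the per-algorithm secondary indexes
    for algo in ("sha1", "sha1_git"):
        hash_ = content.get_hash(algo)
        for other_token in cql_runner.content_get_tokens_from_single_hash(algo, hash_):
            if other_token != token:
                raise HashCollision(algo, hash_, [content.hashes()])

    # 3. index the content under every hash algorithm, then commit the row
    for algo in DEFAULT_ALGORITHMS:
        cql_runner.content_index_add_one(algo, content, token)
    finalizer()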
diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py
--- a/swh/storage/tests/test_api_client.py
+++ b/swh/storage/tests/test_api_client.py
@@ -3,8 +3,6 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
-from unittest.mock import patch
-
import pytest
import swh.storage.api.server as server
@@ -63,8 +61,6 @@
class TestStorage(_TestStorage):
- def test_content_update(self, swh_storage, app_server, sample_data):
- # TODO, journal_writer not supported
- swh_storage.journal_writer.journal = None
- with patch.object(server.storage.journal_writer, "journal", None):
- super().test_content_update(swh_storage, sample_data)
+ @pytest.mark.skip("content_update is not yet implemented for Cassandra")
+ def test_content_update(self):
+ pass
diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py
--- a/swh/storage/tests/test_in_memory.py
+++ b/swh/storage/tests/test_in_memory.py
@@ -9,7 +9,8 @@
from swh.storage.cassandra.model import BaseRow
from swh.storage.in_memory import SortedList, Table
-from swh.storage.tests.test_storage import TestStorage, TestStorageGeneratedData # noqa
+from swh.storage.tests.test_storage import TestStorage as _TestStorage
+from swh.storage.tests.test_storage import TestStorageGeneratedData # noqa
# tests are executed using imported classes (TestStorage and
@@ -144,3 +145,9 @@
assert len(all_rows) == 3
for row in (row1, row2, row3):
assert (table.primary_key(row), row) in all_rows
+
+
+class TestInMemoryStorage(_TestStorage):
+ @pytest.mark.skip("content_update is not yet implemented for Cassandra")
+ def test_content_update(self):
+ pass
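Note (not part of the patch): both test modules above reuse the backend-agnostic suite by aliasing it on import (TestStorage as _TestStorage) so pytest no longer collects the imported class directly, then defining a per-backend subclass that opts out of unimplemented features with pytest.mark.skip. A minimal, self-contained sketch of that pattern, with made-up test names:

import pytest


class _SharedSuite:
    # written once against the storage interface; not collected directly
    # because the class name does not start with "Test"
    def test_some_feature(self):
        assert True

    def test_content_update(self):
        assert True


class TestBackendWithoutUpdate(_SharedSuite):
    # this backend inherits the whole suite but skips the unsupported test
    @pytest.mark.skip("content_update is not yet implemented for this backend")
    def test_content_update(self):
        pass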
diff --git a/swh/storage/tests/test_replay.py b/swh/storage/tests/test_replay.py
--- a/swh/storage/tests/test_replay.py
+++ b/swh/storage/tests/test_replay.py
@@ -3,16 +3,19 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
+import dataclasses
import datetime
import functools
import logging
from typing import Any, Container, Dict, Optional
+import attr
import pytest
from swh.model.hashutil import hash_to_hex, MultiHash, DEFAULT_ALGORITHMS
from swh.storage import get_storage
+from swh.storage.cassandra.model import ContentRow, SkippedContentRow
from swh.storage.in_memory import InMemoryStorage
from swh.storage.replay import process_replay_objects
@@ -28,6 +31,13 @@
UTC = datetime.timezone.utc
+def nullify_ctime(obj):
+ if isinstance(obj, (ContentRow, SkippedContentRow)):
+ return dataclasses.replace(obj, ctime=None)
+ else:
+ return obj
+
+
@pytest.fixture()
def replayer_storage_and_client(
kafka_prefix: str, kafka_consumer_group: str, kafka_server: str
@@ -120,6 +130,8 @@
for content in DUPLICATE_CONTENTS:
topic = f"{prefix}.content"
key = content.sha1
+ now = datetime.datetime.now(tz=UTC)
+ content = attr.evolve(content, ctime=now)
producer.produce(
topic=topic, key=key_to_kafka(key), value=value_to_kafka(content.to_dict()),
)
@@ -162,7 +174,10 @@
# all objects from the src should exists in the dst storage
_check_replayed(src, dst, exclude=["contents"])
# but the dst has one content more (one of the 2 colliding ones)
- assert len(src._contents) == len(dst._contents) - 1
+ assert (
+ len(list(src._cql_runner._contents.iter_all()))
+ == len(list(dst._cql_runner._contents.iter_all())) - 1
+ )
def test_replay_skipped_content(replayer_storage_and_client):
@@ -190,8 +205,7 @@
got_persons = set(dst._persons.values())
assert got_persons == expected_persons
- for attr in (
- "contents",
+ for attr_ in (
"skipped_contents",
"directories",
"revisions",
@@ -201,11 +215,24 @@
"origin_visits",
"origin_visit_statuses",
):
- if exclude and attr in exclude:
+ if exclude and attr_ in exclude:
continue
- expected_objects = sorted(getattr(src, f"_{attr}").items())
- got_objects = sorted(getattr(dst, f"_{attr}").items())
- assert got_objects == expected_objects, f"Mismatch object list for {attr}"
+ expected_objects = sorted(getattr(src, f"_{attr_}").items())
+ got_objects = sorted(getattr(dst, f"_{attr_}").items())
+ assert got_objects == expected_objects, f"Mismatch object list for {attr_}"
+
+ for attr_ in ("contents",):
+ if exclude and attr_ in exclude:
+ continue
+ expected_objects = [
+ (id, nullify_ctime(obj))
+ for id, obj in sorted(getattr(src._cql_runner, f"_{attr_}").iter_all())
+ ]
+ got_objects = [
+ (id, nullify_ctime(obj))
+ for id, obj in sorted(getattr(dst._cql_runner, f"_{attr_}").iter_all())
+ ]
+ assert got_objects == expected_objects, f"Mismatch object list for {attr_}"
def _check_replay_skipped_content(storage, replayer, topic):
@@ -329,17 +356,28 @@
"""
- def maybe_anonymize(obj):
+ def maybe_anonymize(attr_, row):
if expected_anonymized:
- return obj.anonymize() or obj
- return obj
-
- expected_persons = {maybe_anonymize(person) for person in src._persons.values()}
+ if hasattr(row, "anonymize"):
+ # for model objects; cases below are for BaseRow objects
+ row = row.anonymize() or row
+ elif attr_ == "releases":
+ row = dataclasses.replace(row, author=row.author.anonymize())
+ elif attr_ == "revisions":
+ row = dataclasses.replace(
+ row,
+ author=row.author.anonymize(),
+ committer=row.committer.anonymize(),
+ )
+ return row
+
+ expected_persons = {
+ maybe_anonymize("persons", person) for person in src._persons.values()
+ }
got_persons = set(dst._persons.values())
assert got_persons == expected_persons
- for attr in (
- "contents",
+ for attr_ in (
"skipped_contents",
"directories",
"revisions",
@@ -349,10 +387,21 @@
"origin_visit_statuses",
):
expected_objects = [
- (id, maybe_anonymize(obj))
- for id, obj in sorted(getattr(src, f"_{attr}").items())
+ (id, maybe_anonymize(attr_, obj))
+ for id, obj in sorted(getattr(src, f"_{attr_}").items())
+ ]
+ got_objects = [
+ (id, obj) for id, obj in sorted(getattr(dst, f"_{attr_}").items())
+ ]
+ assert got_objects == expected_objects, f"Mismatch object list for {attr_}"
+
+ for attr_ in ("contents",):
+ expected_objects = [
+ (id, nullify_ctime(maybe_anonymize(attr_, obj)))
+ for id, obj in sorted(getattr(src._cql_runner, f"_{attr_}").iter_all())
]
got_objects = [
- (id, obj) for id, obj in sorted(getattr(dst, f"_{attr}").items())
+ (id, nullify_ctime(obj))
+ for id, obj in sorted(getattr(dst._cql_runner, f"_{attr_}").iter_all())
]
- assert got_objects == expected_objects, f"Mismatch object list for {attr}"
+ assert got_objects == expected_objects, f"Mismatch object list for {attr_}"
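Note (not part of the patch): content rows now carry a ctime set at insertion time, which will generally differ between the source and destination storages, so a row-for-row comparison in the replay tests only makes sense after clearing ctime on both sides; that is what the nullify_ctime helper added above does. A minimal, self-contained illustration of the idea, using a stand-in dataclass rather than the real ContentRow:

import dataclasses
import datetime
from typing import Optional


@dataclasses.dataclass(frozen=True)
class FakeRow:
    # stand-in for swh.storage.cassandra.model.ContentRow (assumption)
    sha1: bytes
    ctime: Optional[datetime.datetime] = None


def nullify_ctime(row):
    # same idea as the helper added in the diff above
    return dataclasses.replace(row, ctime=None)


src_row = FakeRow(b"\x01" * 20, ctime=datetime.datetime(2020, 8, 1))
dst_row = FakeRow(b"\x01" * 20, ctime=datetime.datetime(2020, 8, 2))
assert src_row != dst_row  # insertion timestamps make the rows differ
assert nullify_ctime(src_row) == nullify_ctime(dst_row)  # equal once normalized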
diff --git a/swh/storage/tests/test_retry.py b/swh/storage/tests/test_retry.py
--- a/swh/storage/tests/test_retry.py
+++ b/swh/storage/tests/test_retry.py
@@ -13,6 +13,7 @@
from swh.model.model import MetadataTargetType
from swh.storage.exc import HashCollision, StorageArgumentException
+from swh.storage.utils import now
@pytest.fixture
@@ -120,7 +121,7 @@
content_metadata = swh_storage.content_get([pk])
assert content_metadata == [None]
- s = swh_storage.content_add_metadata([content])
+ s = swh_storage.content_add_metadata([attr.evolve(content, ctime=now())])
assert s == {
"content:add": 1,
}