diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -197,22 +197,23 @@
 
         next_page_token: Optional[str] = None
 
-        rows = self._cql_runner.content_get_token_range(range_start, range_end, limit)
+        rows = self._cql_runner.content_get_token_range(
+            range_start, range_end, limit + 1
+        )
 
         contents = []
         last_id: Optional[int] = None
-        for row in rows:
+        for counter, row in enumerate(rows):
             if row.status == "absent":
                 continue
             row_d = row._asdict()
             last_id = row_d.pop("tok")
+            if counter >= limit:
+                next_page_token = str(last_id)
+                break
             contents.append(Content(**row_d))
 
-        if len(contents) == limit:
-            assert last_id is not None
-            next_page_token = str(last_id + 1)
-        assert len(contents) <= limit
-        return PagedResult(results=contents, next_page_token=next_page_token,)
+        return PagedResult(results=contents, next_page_token=next_page_token)
 
     def content_get_metadata(self, contents: List[bytes]) -> Dict[bytes, List[Dict]]:
         result: Dict[bytes, List[Dict]] = {sha1: [] for sha1 in contents}
diff --git a/swh/storage/tests/test_cassandra.py b/swh/storage/tests/test_cassandra.py
--- a/swh/storage/tests/test_cassandra.py
+++ b/swh/storage/tests/test_cassandra.py
@@ -11,9 +11,11 @@
 import time
 
 from collections import namedtuple
+from typing import Dict
 
 import pytest
 
+from swh.core.api.classes import stream_results
 from swh.storage import get_storage
 from swh.storage.cassandra import create_keyspace
 from swh.storage.cassandra.schema import TABLES, HASH_ALGORITHMS
@@ -322,6 +324,57 @@
         # but cont2 should be filtered out
         assert actual_result == [expected_content]
 
+    def test_content_get_partition_murmur3_collision(
+        self, swh_storage, mocker, sample_data
+    ):
+        """The Murmur3 token is used as link from index tables to the main table; and
+        non-matching contents with colliding murmur3-hash are filtered-out when reading
+        the main table.
+
+        This test checks the content_get_partition endpoints return all contents, even
+        the collisions.
+
+        """
+        called = 0
+
+        rows: Dict[int, Dict] = {}
+        for tok, content in enumerate(sample_data.contents):
+            cont = attr.evolve(content, data=None)
+            row_d = {**cont.to_dict(), "tok": tok}
+            rows[tok] = row_d
+
+        # For all tokens, always return cont
+        keys = set(["tok"] + list(content.to_dict().keys())).difference(set(["data"]))
+        Row = namedtuple("Row", keys)
+
+        def mock_content_get_token_range(range_start, range_end, limit):
+            nonlocal called
+            called += 1
+
+            for tok in list(rows.keys()) * 3:  # yield multiple times the same tok
+                row_d = rows[tok]
+                yield Row(**row_d)
+
+        mocker.patch.object(
+            swh_storage._cql_runner,
+            "content_get_token_range",
+            mock_content_get_token_range,
+        )
+
+        actual_results = list(
+            stream_results(
+                swh_storage.content_get_partition, partition_id=0, nb_partitions=1
+            )
+        )
+
+        assert called > 0
+
+        # everything is listed, even collisions
+        assert len(actual_results) == 3 * len(sample_data.contents)
+        # as we duplicated the returned results, dropping duplicate should yield
+        # the original length
+        assert len(set(actual_results)) == len(sample_data.contents)
+
     @pytest.mark.skip("content_update is not yet implemented for Cassandra")
     def test_content_update(self):
         pass