diff --git a/swh/storage/algos/origin.py b/swh/storage/algos/origin.py --- a/swh/storage/algos/origin.py +++ b/swh/storage/algos/origin.py @@ -10,41 +10,18 @@ from swh.storage.interface import ListOrder, StorageInterface -def iter_origins( - storage: StorageInterface, - origin_from: int = 1, - origin_to: Optional[int] = None, - batch_size: int = 10000, -) -> Iterator[Origin]: - """Iterates over all origins in the storage. +def iter_origins(storage: StorageInterface, limit: int = 10000,) -> Iterator[Origin]: + """Iterates over origins in the storage. Args: storage: the storage object used for queries. - origin_from: lower interval boundary - origin_to: upper interval boundary - batch_size: number of origins per query + limit: maximum number of origins per page Yields: - origin within the boundary [origin_to, origin_from] in batch_size + origin model objects from the storage in pages of `limit` origins """ - start = origin_from - while True: - if origin_to: - origin_count = min(origin_to - start, batch_size) - else: - origin_count = batch_size - origins = list( - storage.origin_get_range(origin_from=start, origin_count=origin_count) - ) - if not origins: - break - start = origins[-1]["id"] + 1 - for origin in origins: - del origin["id"] - yield Origin.from_dict(origin) - if origin_to and start > origin_to: - break + yield from stream_results(storage.origin_list, limit=limit) def origin_get_latest_visit_status( diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -137,7 +137,6 @@ self._releases = {} self._snapshots = {} self._origins = {} - self._origins_by_id = [] self._origins_by_sha1 = {} self._origin_visits = {} self._origin_visit_statuses: Dict[Tuple[str, int], List[OriginVisitStatus]] = {} @@ -681,16 +680,6 @@ def origin_get_by_sha1(self, sha1s): return [self._convert_origin(self._origins_by_sha1.get(sha1)) for sha1 in sha1s] - def origin_get_range(self, origin_from=1, origin_count=100): 
- origin_from = max(origin_from, 1) - if origin_from <= len(self._origins_by_id): - max_idx = origin_from + origin_count - 1 - if max_idx > len(self._origins_by_id): - max_idx = len(self._origins_by_id) - for idx in range(origin_from - 1, max_idx): - origin = self._convert_origin(self._origins[self._origins_by_id[idx]]) - yield {"id": idx + 1, **origin} - def origin_list( self, page_token: Optional[str] = None, limit: int = 100 ) -> PagedResult[Origin]: @@ -776,12 +765,6 @@ def origin_add_one(self, origin: Origin) -> str: if origin.url not in self._origins: self.journal_writer.origin_add([origin]) - # generate an origin_id because it is needed by origin_get_range. - # TODO: remove this when we remove origin_get_range - origin_id = len(self._origins) + 1 - self._origins_by_id.append(origin.url) - assert len(self._origins_by_id) == origin_id - self._origins[origin.url] = origin self._origins_by_sha1[origin_url_to_sha1(origin.url)] = origin self._origin_visits[origin.url] = [] diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -1015,24 +1015,6 @@ """ ... - @deprecated - @remote_api_endpoint("origin/get_range") - def origin_get_range(self, origin_from=1, origin_count=100): - """Retrieve ``origin_count`` origins whose ids are greater - or equal than ``origin_from``. - - Origins are sorted by id before retrieving them. - - Args: - origin_from (int): the minimum id of origins to retrieve - origin_count (int): the maximum number of origins to retrieve - - Yields: - dicts containing origin information as returned - by :meth:`swh.storage.storage.Storage.origin_get`. - """ - ... 
- @remote_api_endpoint("origin/list") def origin_list( self, page_token: Optional[str] = None, limit: int = 100 diff --git a/swh/storage/tests/algos/test_origin.py b/swh/storage/tests/algos/test_origin.py --- a/swh/storage/tests/algos/test_origin.py +++ b/swh/storage/tests/algos/test_origin.py @@ -19,10 +19,6 @@ from swh.storage.tests.test_storage import round_to_milliseconds -def assert_list_eq(left, right, msg=None): - assert list(left) == list(right), msg - - def test_iter_origins(swh_storage): origins = [ Origin(url="bar"), @@ -30,43 +26,11 @@ Origin(url="quuz"), ] assert swh_storage.origin_add(origins) == {"origin:add": 3} - assert_list_eq(iter_origins(swh_storage), origins) - assert_list_eq(iter_origins(swh_storage, batch_size=1), origins) - assert_list_eq(iter_origins(swh_storage, batch_size=2), origins) - - for i in range(1, 5): - assert_list_eq(iter_origins(swh_storage, origin_from=i + 1), origins[i:], i) - - assert_list_eq( - iter_origins(swh_storage, origin_from=i + 1, batch_size=1), origins[i:], i - ) - - assert_list_eq( - iter_origins(swh_storage, origin_from=i + 1, batch_size=2), origins[i:], i - ) - - for j in range(i, 5): - assert_list_eq( - iter_origins(swh_storage, origin_from=i + 1, origin_to=j + 1), - origins[i:j], - (i, j), - ) - assert_list_eq( - iter_origins( - swh_storage, origin_from=i + 1, origin_to=j + 1, batch_size=1 - ), - origins[i:j], - (i, j), - ) - - assert_list_eq( - iter_origins( - swh_storage, origin_from=i + 1, origin_to=j + 1, batch_size=2 - ), - origins[i:j], - (i, j), - ) + # this returns all the origins, only the number of page calls is different + assert list(iter_origins(swh_storage)) == origins + assert list(iter_origins(swh_storage, limit=1)) == origins + assert list(iter_origins(swh_storage, limit=2)) == origins def test_origin_get_latest_visit_status_none(swh_storage, sample_data): diff --git a/swh/storage/tests/test_cassandra.py b/swh/storage/tests/test_cassandra.py --- a/swh/storage/tests/test_cassandra.py +++ 
b/swh/storage/tests/test_cassandra.py @@ -351,14 +351,6 @@ def test_origin_count(self): pass - @pytest.mark.skip("Not supported by Cassandra") - def test_origin_get_range(self): - pass - - @pytest.mark.skip("Not supported by Cassandra") - def test_origin_get_range_from_zero(self): - pass - @pytest.mark.skip("Not supported by Cassandra") def test_generate_content_get_range_limit(self): pass diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -3968,42 +3968,6 @@ assert_contents_ok([contents_map[get_sha1s[-1]]], actual_contents2, ["sha1"]) - def test_origin_get_range_from_zero(self, swh_storage, swh_origins): - actual_origins = list( - swh_storage.origin_get_range(origin_from=0, origin_count=0) - ) - assert len(actual_origins) == 0 - - actual_origins = list( - swh_storage.origin_get_range(origin_from=0, origin_count=1) - ) - assert len(actual_origins) == 1 - assert actual_origins[0]["id"] == 1 - assert actual_origins[0]["url"] == swh_origins[0].url - - @pytest.mark.parametrize( - "origin_from,origin_count", - [(1, 1), (1, 10), (1, 20), (1, 101), (11, 0), (11, 10), (91, 11)], - ) - def test_origin_get_range( - self, swh_storage, swh_origins, origin_from, origin_count - ): - actual_origins = list( - swh_storage.origin_get_range( - origin_from=origin_from, origin_count=origin_count - ) - ) - - origins_with_id = list(enumerate(swh_origins, start=1)) - expected_origins = [ - {"url": origin.url, "id": origin_id,} - for (origin_id, origin) in origins_with_id[ - origin_from - 1 : origin_from + origin_count - 1 - ] - ] - - assert actual_origins == expected_origins - @pytest.mark.parametrize("limit", [1, 7, 10, 100, 1000]) def test_origin_list(self, swh_storage, swh_origins, limit): returned_origins = []