diff --git a/swh/storage/api/client.py b/swh/storage/api/client.py --- a/swh/storage/api/client.py +++ b/swh/storage/api/client.py @@ -133,6 +133,10 @@ 'regexp': regexp, 'with_visit': with_visit}) + def origins_get(self, origin_from=1, origin_count=100): + return self.post('origins/get', {'origin_from': origin_from, + 'origin_count': origin_count}) + def origin_add(self, origins): return self.post('origin/add_multi', {'origins': origins}) diff --git a/swh/storage/api/server.py b/swh/storage/api/server.py --- a/swh/storage/api/server.py +++ b/swh/storage/api/server.py @@ -225,6 +225,11 @@ return encode_data(get_storage().origin_get(**decode_request(request))) +@app.route('/origins/get', methods=['POST']) +def origins_get(): + return encode_data(get_storage().origins_get(**decode_request(request))) + + @app.route('/origin/search', methods=['POST']) def origin_search(): return encode_data(get_storage().origin_search(**decode_request(request))) diff --git a/swh/storage/db.py b/swh/storage/db.py --- a/swh/storage/db.py +++ b/swh/storage/db.py @@ -881,6 +881,28 @@ person_cols = ['fullname', 'name', 'email'] person_get_cols = person_cols + ['id'] + def origins_get(self, origin_from=1, origin_count=100, cur=None): + """Retrieve origins whose ids are greater than or equal to origin_from. + Origins are sorted by id before retrieving them. + + Args: + origin_from (int): the minimum id of origins to retrieve + origin_count (int): the maximum number of origins to retrieve + + Yields: + one row per origin (via cursor_to_bytes), with columns + in the order of self.origin_cols + """ + cur = self._cursor(cur) + + query = """SELECT %s + FROM origin WHERE id >= %%s + ORDER BY id LIMIT %%s + """ % ','.join(self.origin_cols) + + cur.execute(query, (origin_from, origin_count)) + yield from cursor_to_bytes(cur) + def person_get(self, ids, cur=None): """Retrieve the persons identified by the list of ids. 
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -817,6 +817,25 @@ origin['id'] = origin_id return origin + def origins_get(self, origin_from=1, origin_count=100): + """Retrieve origins whose ids are greater than or equal to origin_from. + Origins are sorted by id before retrieving them. + + Args: + origin_from (int): the minimum id of origins to retrieve + origin_count (int): the maximum number of origins to retrieve + + Yields: + dicts containing origin information as returned + by :meth:`swh.storage.in_memory.Storage.origin_get`. + """ + # Origin id i lives at self._origins[i - 1]; clamp origin_from so values + # below 1 behave like the SQL backend's "id >= origin_from" filter. + start = max(origin_from, 1) + stop = min(start + origin_count, len(self._origins) + 1) + for origin_id in range(start, stop): + yield copy.deepcopy(self._origins[origin_id - 1]) + def origin_search(self, url_pattern, offset=0, limit=50, regexp=False, with_visit=False, db=None, cur=None): """Search for origins whose urls contain a provided string pattern diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -1135,14 +1135,30 @@ expression and return origins whose urls match it with_visit (bool): if True, filter out origins with no visit - Returns: - An iterable of dict containing origin information as returned + Yields: + dicts containing origin information as returned by :meth:`swh.storage.storage.Storage.origin_get`. """ for origin in db.origin_search(url_pattern, offset, limit, regexp, with_visit, cur): yield dict(zip(self.origin_keys, origin)) + @db_transaction_generator() + def origins_get(self, origin_from=1, origin_count=100, db=None, cur=None): + """Retrieve origins whose ids are greater than or equal to origin_from. + Origins are sorted by id before retrieving them.
+ + Args: + origin_from (int): the minimum id of origins to retrieve + origin_count (int): the maximum number of origins to retrieve + + Yields: + dicts containing origin information as returned + by :meth:`swh.storage.storage.Storage.origin_get`. + """ + for origin in db.origins_get(origin_from, origin_count, cur): + yield dict(zip(self.origin_keys, origin)) + @db_transaction_generator(statement_timeout=500) def person_get(self, person, db=None, cur=None): """Return the persons identified by their ids. diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -5,8 +5,9 @@ import copy import datetime -import unittest import itertools +import random +import unittest from collections import defaultdict from unittest.mock import Mock, patch @@ -16,6 +17,7 @@ from swh.model import from_disk, identifiers from swh.model.hashutil import hash_to_bytes +from swh.storage.tests.algos.test_snapshot import origins from swh.storage.tests.storage_testing import StorageTestFixture from swh.storage import HashCollision @@ -2137,6 +2139,26 @@ self.assert_contents_ok([contents_map[actual_next]], actual_contents2, keys_to_check) + def test_origins_get(self): + nb_origins = 200 + origins_to_add = [] + for _ in range(nb_origins): + origins_to_add.append(origins().example()) + self.storage.origin_add(origins_to_add) + + origin_from = random.randint(1, nb_origins) + origin_count = random.randint(1, nb_origins - origin_from + 1) + + expected_origins = [] + for i in range(origin_from, origin_from + origin_count): + expected_origins.append(self.storage.origin_get({'id': i})) + + actual_origins = list( + self.storage.origins_get(origin_from=origin_from, + origin_count=origin_count)) + + self.assertEqual(actual_origins, expected_origins) + @pytest.mark.db class TestLocalStorage(CommonTestStorage, StorageTestDbFixture,