diff --git a/swh/storage/api/client.py b/swh/storage/api/client.py --- a/swh/storage/api/client.py +++ b/swh/storage/api/client.py @@ -133,6 +133,10 @@ 'regexp': regexp, 'with_visit': with_visit}) + def origin_get_range(self, origin_from=1, origin_count=100): + return self.post('origin/get_range', {'origin_from': origin_from, + 'origin_count': origin_count}) + def origin_add(self, origins): return self.post('origin/add_multi', {'origins': origins}) diff --git a/swh/storage/api/server.py b/swh/storage/api/server.py --- a/swh/storage/api/server.py +++ b/swh/storage/api/server.py @@ -225,6 +225,12 @@ return encode_data(get_storage().origin_get(**decode_request(request))) +@app.route('/origin/get_range', methods=['POST']) +def origin_get_range(): + return encode_data(get_storage().origin_get_range( + **decode_request(request))) + + @app.route('/origin/search', methods=['POST']) def origin_search(): return encode_data(get_storage().origin_search(**decode_request(request))) diff --git a/swh/storage/db.py b/swh/storage/db.py --- a/swh/storage/db.py +++ b/swh/storage/db.py @@ -881,6 +881,26 @@ person_cols = ['fullname', 'name', 'email'] person_get_cols = person_cols + ['id'] + def origin_get_range(self, origin_from=1, origin_count=100, cur=None): + """Retrieve at most ``origin_count`` origins whose ids are + greater than or equal to ``origin_from``. + + Origins are returned in ascending id order. + + Args: + origin_from (int): the minimum origin id to retrieve (inclusive) + origin_count (int): the maximum number of origins to retrieve + """ + cur = self._cursor(cur) + + query = """SELECT %s + FROM origin WHERE id >= %%s + ORDER BY id LIMIT %%s + """ % ','.join(self.origin_cols) + + cur.execute(query, (origin_from, origin_count)) + yield from cursor_to_bytes(cur) + def person_get(self, ids, cur=None): """Retrieve the persons identified by the list of ids. 
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -817,6 +817,28 @@ origin['id'] = origin_id return origin + def origin_get_range(self, origin_from=1, origin_count=100): + """Retrieve at most ``origin_count`` origins whose ids are + greater than or equal to ``origin_from``. + + Origins are returned in ascending id order. + + Args: + origin_from (int): the minimum origin id to retrieve (inclusive) + origin_count (int): the maximum number of origins to retrieve + + Yields: + dicts containing origin information as returned + by :meth:`swh.storage.in_memory.Storage.origin_get`. + """ + origin_from = max(origin_from, 1) + if origin_from <= len(self._origins): + max_idx = origin_from + origin_count - 1 + if max_idx > len(self._origins): + max_idx = len(self._origins) + for idx in range(origin_from-1, max_idx): + yield copy.deepcopy(self._origins[idx]) + def origin_search(self, url_pattern, offset=0, limit=50, regexp=False, with_visit=False, db=None, cur=None): """Search for origins whose urls contain a provided string pattern diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -1135,14 +1135,33 @@ expression and return origins whose urls match it with_visit (bool): if True, filter out origins with no visit - Returns: - An iterable of dict containing origin information as returned + Yields: + dicts containing origin information as returned by :meth:`swh.storage.storage.Storage.origin_get`. """ for origin in db.origin_search(url_pattern, offset, limit, regexp, with_visit, cur): yield dict(zip(self.origin_keys, origin)) + @db_transaction_generator() + def origin_get_range(self, origin_from=1, origin_count=100, + db=None, cur=None): + """Retrieve at most ``origin_count`` origins whose ids are + greater than or equal to ``origin_from``. + + Origins are returned in ascending id order.
+ + Args: + origin_from (int): the minimum origin id to retrieve (inclusive) + origin_count (int): the maximum number of origins to retrieve + + Yields: + dicts containing origin information as returned + by :meth:`swh.storage.storage.Storage.origin_get`. + """ + for origin in db.origin_get_range(origin_from, origin_count, cur): + yield dict(zip(self.origin_keys, origin)) + @db_transaction_generator(statement_timeout=500) def person_get(self, person, db=None, cur=None): """Return the persons identified by their ids. diff --git a/swh/storage/tests/generate_data_test.py b/swh/storage/tests/generate_data_test.py --- a/swh/storage/tests/generate_data_test.py +++ b/swh/storage/tests/generate_data_test.py @@ -3,10 +3,14 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from hypothesis.strategies import (binary, composite, sets) +import random + +from hypothesis.strategies import (binary, composite, just, sets) from swh.model.hashutil import MultiHash +from swh.storage.tests.algos.test_snapshot import origins + def gen_raw_content(): """Generate raw content binary. @@ -47,3 +51,31 @@ }) return contents + + +def gen_origins(min_size=10, max_size=100, unique=True): + """Generate a list of origins, wrapped in a ``just`` strategy. + + Args: + min_size (int): Minimal number of origins to generate + (default: 10) + max_size (int): Maximal number of origins to generate + (default: 100) + unique (bool): if True (default), no two origins share a (type, url) + + Returns: + a ``just`` strategy wrapping one list of origin dicts whose + length lies in [min_size, max_size] (both bounds inclusive).
+ """ + size = random.randint(min_size, max_size) + new_origins = [] + origins_set = set() + while len(new_origins) != size: + new_origin = origins().example() + if unique: + key = (new_origin['type'], new_origin['url']) + if key in origins_set: + continue + origins_set.add(key) + new_origins.append(new_origin) + return just(new_origins) diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -5,8 +5,9 @@ import copy import datetime -import unittest import itertools +import random +import unittest from collections import defaultdict from unittest.mock import Mock, patch @@ -19,7 +20,7 @@ from swh.storage.tests.storage_testing import StorageTestFixture from swh.storage import HashCollision -from .generate_data_test import gen_contents +from .generate_data_test import gen_contents, gen_origins @pytest.mark.db @@ -2135,8 +2136,42 @@ self.assertIsNone(actual_next2) self.assert_contents_ok([contents_map[actual_next]], actual_contents2, + keys_to_check) + @given(gen_origins(min_size=100, max_size=100)) + def test_origin_get_range(self, new_origins): + + nb_origins = len(new_origins) + + self.storage.origin_add(new_origins) + + origin_from = random.randint(1, nb_origins) + origin_count = random.randint(1, nb_origins - origin_from + 1) + + expected_origins = [] + for i in range(origin_from, origin_from + origin_count): + expected_origins.append(self.storage.origin_get({'id': i})) + + actual_origins = list( + self.storage.origin_get_range(origin_from=origin_from, + origin_count=origin_count)) + + self.assertEqual(actual_origins, expected_origins) + + origin_from = -1 + origin_count = 10 + origins = list( + self.storage.origin_get_range(origin_from=origin_from, + origin_count=origin_count)) + self.assertEqual(len(origins), origin_count) + + origin_from = 10000 + origins = list( + self.storage.origin_get_range(origin_from=origin_from, + origin_count=origin_count)) + 
self.assertEqual(len(origins), 0) + @pytest.mark.db class TestLocalStorage(CommonTestStorage, StorageTestDbFixture,