diff --git a/swh/storage/algos/origin.py b/swh/storage/algos/origin.py new file mode 100644 index 00000000..efecb010 --- /dev/null +++ b/swh/storage/algos/origin.py @@ -0,0 +1,33 @@ +# Copyright (C) 2019 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + + +def iter_origins(storage, origin_from=1, origin_to=None, batch_size=10000): + """Iterates over all origins in the storage. + + Args: + storage: the storage object used for queries. + batch_size: number of origins per query + Yields: + dict: the origin dictionary with the keys: + + - id: origin's id + - type: origin's type + - url: origin's url + """ + start = origin_from + while True: + if origin_to: + origin_count = min(origin_to - start, batch_size) + else: + origin_count = batch_size + origins = list(storage.origin_get_range( + origin_from=start, origin_count=origin_count)) + if not origins: + break + start = origins[-1]['id'] + 1 + yield from origins + if origin_to and start > origin_to: + break diff --git a/swh/storage/tests/algos/test_origin.py b/swh/storage/tests/algos/test_origin.py new file mode 100644 index 00000000..8ad5e501 --- /dev/null +++ b/swh/storage/tests/algos/test_origin.py @@ -0,0 +1,74 @@ +# Copyright (C) 2019 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from unittest.mock import patch + +from swh.storage.in_memory import Storage +from swh.storage.algos.origin import iter_origins + + +def assert_list_eq(left, right, msg=None): + assert list(left) == list(right), msg + + +def test_iter_origins(): + storage = Storage() + origins = storage.origin_add([ + {'type': 'foo', 'url': 'bar'}, + {'type': 'baz', 'url': 'qux'}, + {'type': 'quux', 'url': 'quuz'}, + ]) + assert_list_eq(iter_origins(storage), origins) + assert_list_eq(iter_origins(storage, batch_size=1), origins) + assert_list_eq(iter_origins(storage, batch_size=2), origins) + + for i in range(1, 5): + assert_list_eq( + iter_origins(storage, origin_from=i+1), + origins[i:], + i) + + assert_list_eq( + iter_origins(storage, origin_from=i+1, batch_size=1), + origins[i:], + i) + + assert_list_eq( + iter_origins(storage, origin_from=i+1, batch_size=2), + origins[i:], + i) + + for j in range(i, 5): + assert_list_eq( + iter_origins( + storage, origin_from=i+1, origin_to=j+1), + origins[i:j], + (i, j)) + + assert_list_eq( + iter_origins( + storage, origin_from=i+1, origin_to=j+1, batch_size=1), + origins[i:j], + (i, j)) + + assert_list_eq( + iter_origins( + storage, origin_from=i+1, origin_to=j+1, batch_size=2), + origins[i:j], + (i, j)) + + +@patch('swh.storage.in_memory.Storage.origin_get_range') +def test_iter_origins_batch_size(mock_origin_get_range): + storage = Storage() + mock_origin_get_range.return_value = [] + + list(iter_origins(storage)) + mock_origin_get_range.assert_called_with( + origin_from=1, origin_count=10000) + + list(iter_origins(storage, batch_size=42)) + mock_origin_get_range.assert_called_with( + origin_from=1, origin_count=42)