diff --git a/swh/storage/algos/origin.py b/swh/storage/algos/origin.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/algos/origin.py
@@ -0,0 +1,42 @@
+# Copyright (C) 2019 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+
+def iter_origins(storage, origin_from=1, origin_to=None, batch_size=10000):
+    """Iterates over origins in the storage, fetching them in batches.
+
+    Args:
+        storage: the storage object used for queries.
+        origin_from: id of the first origin to yield (inclusive).
+        origin_to: id at which to stop (exclusive); if None (or 0),
+            iterate until the storage returns no more origins.
+        batch_size: maximum number of origins requested per query
+    Yields:
+        dict: the origin dictionary with the keys:
+
+        - id: origin's id
+        - type: origin's type
+        - url: origin's url
+    """
+    start = origin_from
+    while True:
+        if origin_to:
+            # origin_to is exclusive: stop here rather than querying the
+            # storage for an empty (or negative-sized) range.
+            if start >= origin_to:
+                break
+            origin_count = min(origin_to - start, batch_size)
+        else:
+            origin_count = batch_size
+        origins = list(storage.origin_get_range(
+            origin_from=start, origin_count=origin_count))
+        # NOTE(review): an empty batch ends the iteration; this assumes
+        # origin_get_range does not return empty for a range of unused ids
+        # when later origins exist -- confirm against the storage backends.
+        if not origins:
+            break
+        # Resume right after the last origin received.
+        start = origins[-1]['id'] + 1
+        yield from origins
diff --git a/swh/storage/tests/algos/test_origin.py b/swh/storage/tests/algos/test_origin.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/tests/algos/test_origin.py
@@ -0,0 +1,68 @@
+# Copyright (C) 2019 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from unittest.mock import patch
+
+from swh.storage.in_memory import Storage
+from swh.storage.algos.origin import iter_origins
+
+
+def test_iter_origins():
+    storage = Storage()
+    origins = storage.origin_add([
+        {'type': 'foo', 'url': 'bar'},
+        {'type': 'baz', 'url': 'qux'},
+        {'type': 'quux', 'url': 'quuz'},
+    ])
+    assert list(iter_origins(
+        storage)) == origins
+    assert list(iter_origins(
+        storage, batch_size=1)) == origins
+    assert list(iter_origins(
+        storage, batch_size=2)) == origins
+
+    for i in range(1, 5):
+        assert list(iter_origins(
+            storage, origin_from=i+1)) == origins[i:], i
+        assert list(iter_origins(
+            storage, origin_from=i+1, batch_size=1)) == origins[i:], i
+        assert list(iter_origins(
+            storage, origin_from=i+1, batch_size=2)) == origins[i:], i
+
+        for j in range(i, 5):
+            assert list(iter_origins(
+                storage, origin_from=i+1, origin_to=j+1)
+            ) == origins[i:j], (i, j)
+            assert list(iter_origins(
+                storage, origin_from=i+1, origin_to=j+1, batch_size=1)
+            ) == origins[i:j], (i, j)
+            assert list(iter_origins(
+                storage, origin_from=i+1, origin_to=j+1, batch_size=2)
+            ) == origins[i:j], (i, j)
+
+
+@patch('swh.storage.in_memory.Storage.origin_get_range')
+def test_iter_origins_batch_size(mock_origin_get_range):
+    storage = Storage()
+    mock_origin_get_range.return_value = []
+
+    list(iter_origins(storage))
+    mock_origin_get_range.assert_called_with(
+        origin_from=1, origin_count=10000)
+
+    list(iter_origins(storage, batch_size=42))
+    mock_origin_get_range.assert_called_with(
+        origin_from=1, origin_count=42)
+
+
+@patch('swh.storage.in_memory.Storage.origin_get_range')
+def test_iter_origins_empty_range(mock_origin_get_range):
+    storage = Storage()
+    mock_origin_get_range.return_value = []
+
+    # origin_to is exclusive, so these ranges are empty: no query is needed
+    assert list(iter_origins(storage, origin_from=3, origin_to=3)) == []
+    assert list(iter_origins(storage, origin_from=5, origin_to=3)) == []
+    mock_origin_get_range.assert_not_called()