diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -102,6 +102,21 @@
         for (k, item) in itertools.islice(self.data, from_index, None):
             yield item
 
+    def iter_after(self, start_key: SortedListKey) -> Iterator[SortedListItem]:
+        """Iterate over items whose key is strictly greater than ``start_key``.
+
+        Same as iter_from, but using a strict inequality: the leading items
+        whose key is not greater than ``start_key`` are skipped, and
+        everything after them is yielded unchanged.
+        """
+
+        def not_after(item: SortedListItem) -> bool:
+            return not (self.key(item) > start_key)  # type: ignore
+
+        # dropwhile stops testing at the first item past start_key, so the
+        # remaining items are passed through without further comparisons.
+        yield from itertools.dropwhile(not_after, self.iter_from(start_key))
+
 
 class InMemoryStorage:
     def __init__(self, journal_writer=None):
diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py
--- a/swh/storage/tests/test_in_memory.py
+++ b/swh/storage/tests/test_in_memory.py
@@ -68,3 +68,21 @@
     for split in items:
         expected = reversed(sorted(item for item in items if item <= split))
         assert list(list_.iter_from(-split)) == list(expected), f"split: {split}"
+
+
+@parametrize
+def test_sorted_list_iter_after(items):
+    """iter_after yields, in ascending order, only items greater than split."""
+    list_ = SortedList(items)
+    for split in items:
+        above = [item for item in sorted(items) if item > split]
+        assert list(list_.iter_after(split)) == above, f"split: {split}"
+
+
+@parametrize
+def test_sorted_list_iter_after__key(items):
+    """With a negating key, iter_after yields items below split, descending."""
+    list_ = SortedList(items, key=lambda item: -item)
+    for split in items:
+        below = sorted((item for item in items if item < split), reverse=True)
+        assert list(list_.iter_after(-split)) == below, f"split: {split}"