diff --git a/swh/core/collections.py b/swh/core/collections.py index 92fab40..ed7b869 100644 --- a/swh/core/collections.py +++ b/swh/core/collections.py @@ -1,62 +1,61 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import bisect -import collections import itertools from typing import Any, Callable, Generic, Iterator, List, Optional, Tuple, TypeVar SortedListItem = TypeVar("SortedListItem") SortedListKey = TypeVar("SortedListKey") -class SortedList(collections.UserList, Generic[SortedListKey, SortedListItem]): +class SortedList(Generic[SortedListKey, SortedListItem]): data: List[Tuple[SortedListKey, SortedListItem]] # https://github.com/python/mypy/issues/708 # key: Callable[[SortedListItem], SortedListKey] def __init__( self, data: List[SortedListItem] = None, key: Optional[Callable[[SortedListItem], SortedListKey]] = None, ): if key is None: def key(item): return item assert key is not None # for mypy - super().__init__(sorted((key(x), x) for x in data or [])) + self.data = sorted((key(x), x) for x in data or []) self.key: Callable[[SortedListItem], SortedListKey] = key def add(self, item: SortedListItem): k = self.key(item) bisect.insort(self.data, (k, item)) def __iter__(self) -> Iterator[SortedListItem]: for (k, item) in self.data: yield item def iter_from(self, start_key: Any) -> Iterator[SortedListItem]: """Returns an iterator over all the elements whose key is greater or equal to `start_key`. (This is an efficient equivalent to: `(x for x in L if key(x) >= start_key)`) """ from_index = bisect.bisect_left(self.data, (start_key,)) for (k, item) in itertools.islice(self.data, from_index, None): yield item def iter_after(self, start_key: Any) -> Iterator[SortedListItem]: """Same as iter_from, but using a strict inequality.""" it = self.iter_from(start_key) for item in it: if self.key(item) > start_key: yield item break yield from it diff --git a/swh/core/tests/test_collections.py b/swh/core/tests/test_collections.py index 22efbc0..b2b4b21 100644 --- a/swh/core/tests/test_collections.py +++ b/swh/core/tests/test_collections.py @@ -1,71 +1,86 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest from swh.core.collections import SortedList parametrize = pytest.mark.parametrize( "items", [ [1, 2, 3, 4, 5, 6, 10, 100], [10, 100, 6, 5, 4, 3, 2, 1], [10, 4, 5, 6, 1, 2, 3, 100], ], ) @parametrize def test_sorted_list_iter(items): list1 = SortedList() for item in items: list1.add(item) assert list(list1) == sorted(items) list2 = SortedList(items) assert list(list2) == sorted(items) @parametrize def test_sorted_list_iter__key(items): list1 = SortedList(key=lambda item: -item) for item in items: list1.add(item) assert list(list1) == list(reversed(sorted(items))) list2 = SortedList(items, key=lambda item: -item) assert list(list2) == list(reversed(sorted(items))) @parametrize def test_sorted_list_iter_from(items): list_ = SortedList(items) for split in items: expected = sorted(item for item in items if item >= split) assert list(list_.iter_from(split)) == expected, f"split: {split}" @parametrize def test_sorted_list_iter_from__key(items): list_ = SortedList(items, key=lambda item: -item) for split in items: expected = reversed(sorted(item for item in items if item <= split)) assert list(list_.iter_from(-split)) == list(expected), f"split: {split}" @parametrize def test_sorted_list_iter_after(items): list_ = SortedList(items) for split in items: expected = sorted(item for item in items if item > split) assert list(list_.iter_after(split)) == expected, f"split: {split}" @parametrize def test_sorted_list_iter_after__key(items): list_ = SortedList(items, key=lambda item: -item) for split in items: expected = reversed(sorted(item for item in items if item < split)) assert list(list_.iter_after(-split)) == list(expected), f"split: {split}" + + +@parametrize +def test_contains(items): + list_ = SortedList() + for i in range(len(items)): + for item in items[0:i]: + assert item in list_ + for item in items[i:]: + assert item not in list_ + + list_.add(items[i]) + + for item in items: + assert item in list_