SortedListItem = TypeVar("SortedListItem")
SortedListKey = TypeVar("SortedListKey")


class SortedList(collections.UserList, Generic[SortedListKey, SortedListItem]):
    """A list of items kept ordered by a key computed from each item.

    Internally, ``self.data`` holds ``(key, item)`` pairs sorted by key;
    iterating over the list yields only the items.

    Only the keys are ever compared. Items whose keys are equal keep their
    insertion order, and the items themselves do not need to be orderable.

    NOTE(review): only ``add``, iteration and ``iter_from`` are order-aware.
    Other methods inherited from ``UserList`` (indexing, ``append``, ``in``,
    ...) operate on the raw ``(key, item)`` pairs stored in ``self.data``.
    """

    data: List[Tuple[Any, SortedListItem]]

    # https://github.com/python/mypy/issues/708
    # key: Callable[[SortedListItem], SortedListKey]

    def __init__(
        self,
        data: Optional[Iterable[SortedListItem]] = None,
        key: Optional[Callable[[SortedListItem], SortedListKey]] = None,
    ):
        """Build the list from the initial items.

        Args:
            data: initial items, in any order; ``None`` means empty.
            key: function computing the sort key of an item; defaults to
                the identity (items are then their own key).
        """
        if key is None:

            def key(item):
                return item

        assert key is not None  # for mypy
        pairs = [(key(item), item) for item in (data if data is not None else ())]
        # Sort on the key alone: sorting the pairs themselves would fall back
        # to comparing items whenever two keys are equal.
        pairs.sort(key=lambda pair: pair[0])
        super().__init__(pairs)

        self.key: Callable[[SortedListItem], SortedListKey] = key

    def _index_of_key(self, key: Any, right: bool) -> int:
        """Binary-search ``self.data`` by key only.

        Returns the insertion index before (``right=False``) or after
        (``right=True``) any entries whose key equals ``key``, mirroring
        ``bisect.bisect_left`` / ``bisect.bisect_right``. This is written by
        hand because the ``key=`` argument of the ``bisect`` functions is
        only available from Python 3.10 on.
        """
        lo, hi = 0, len(self.data)
        while lo < hi:
            mid = (lo + hi) // 2
            mid_key = self.data[mid][0]
            # Like bisect, rely on "<" only.
            go_right = not (key < mid_key) if right else mid_key < key
            if go_right:
                lo = mid + 1
            else:
                hi = mid
        return lo

    def add(self, item: SortedListItem) -> None:
        """Insert `item` at the position given by its key, after any
        already-present items with an equal key."""
        k = self.key(item)
        self.data.insert(self._index_of_key(k, right=True), (k, item))

    def __iter__(self) -> Iterator[SortedListItem]:
        """Yield the items (without their keys) in ascending key order."""
        for (_, item) in self.data:
            yield item

    def iter_from(self, start_key: SortedListKey) -> Iterator[SortedListItem]:
        """Returns an iterator over all the elements whose key is greater
        or equal to `start_key`.

        (This is an efficient equivalent to:
        `(x for x in L if key(x) >= start_key)`)
        """
        from_index = self._index_of_key(start_key, right=False)
        for (_, item) in itertools.islice(self.data, from_index, None):
            yield item
@parametrize
def test_sorted_list_iter(items):
    """Iteration is in ascending order, whether the items were added one by
    one or handed to the constructor."""
    ascending = sorted(items)

    incremental = SortedList()
    for value in items:
        incremental.add(value)
    assert list(incremental) == ascending

    prefilled = SortedList(items)
    assert list(prefilled) == ascending


@parametrize
def test_sorted_list_iter__key(items):
    """With a negating key, iteration is in descending item order."""

    def negate(value):
        return -value

    descending = sorted(items)[::-1]

    incremental = SortedList(key=negate)
    for value in items:
        incremental.add(value)
    assert list(incremental) == descending

    prefilled = SortedList(items, key=negate)
    assert list(prefilled) == descending


@parametrize
def test_sorted_list_iter_from(items):
    """iter_from(k) yields, in order, exactly the items that are >= k."""
    list_ = SortedList(items)
    ascending = sorted(items)
    for split in items:
        expected = [value for value in ascending if value >= split]
        assert list(list_.iter_from(split)) == expected, split


@parametrize
def test_sorted_list_iter_from__key(items):
    """With a negating key, iter_from(-k) yields the items <= k, largest
    first."""
    list_ = SortedList(items, key=lambda item: -item)
    descending = sorted(items)[::-1]
    for split in items:
        expected = [value for value in descending if value <= split]
        assert list(list_.iter_from(-split)) == expected, split