SortedListItem = TypeVar("SortedListItem")
SortedListKey = TypeVar("SortedListKey")


class SortedList(collections.UserList, Generic[SortedListKey, SortedListItem]):
    """List of items kept in ascending order of a key computed from each item.

    Internally, ``self.data`` holds ``(key, item)`` pairs sorted as tuples, so
    that :mod:`bisect` can locate a key without calling ``key`` on every
    element. Iterating over the list (``__iter__`` / ``iter_from``) yields the
    items only, never the pairs.

    Because pairs are compared as tuples, two items with equal keys are
    ordered by comparing the items themselves; such items must therefore be
    mutually orderable.

    NOTE: accessors inherited from ``collections.UserList`` that are not
    overridden here (indexing, ``in``, ``append``, ...) operate on the raw
    ``(key, item)`` pairs stored in ``self.data``.
    """

    data: List[Tuple[SortedListKey, SortedListItem]]

    # https://github.com/python/mypy/issues/708
    # key: Callable[[SortedListItem], SortedListKey]

    def __init__(
        self,
        data: Optional[Iterable[SortedListItem]] = None,
        key: Optional[Callable[[SortedListItem], SortedListKey]] = None,
    ):
        """Build a sorted list from optional initial items.

        Args:
            data: initial items, in any order; ``None`` means empty.
            key: function computing the sort key of an item; defaults to the
                identity function (items are their own key).
        """
        if key is None:

            def key(item):
                return item

        assert key is not None  # for mypy
        self.key: Callable[[SortedListItem], SortedListKey] = key

        # Store (key, item) pairs, the representation that add(), __iter__()
        # and iter_from() rely on. Storing bare items here would make
        # iteration fail (or silently mis-unpack) for any initial data.
        initial_items = data if data is not None else ()
        super().__init__(sorted((key(item), item) for item in initial_items))

    def add(self, item: SortedListItem):
        """Insert ``item`` at the position that keeps the list sorted."""
        k = self.key(item)
        bisect.insort(self.data, (k, item))

    def __iter__(self) -> Iterator[SortedListItem]:
        """Yield all items in ascending key order."""
        for (k, item) in self.data:
            yield item

    def iter_from(self, start_key: SortedListKey) -> Iterator[SortedListItem]:
        """Yield, in order, the items whose key is >= ``start_key``."""
        # A 1-tuple sorts before any pair sharing its first element, so
        # bisect_left lands on the first pair whose key is >= start_key
        # without ever comparing items.
        from_index = bisect.bisect_left(self.data, (start_key,))
        for (k, item) in itertools.islice(self.data, from_index, None):
            yield item
self._objects[content.sha1_git].append(("content", content.sha1)) self._contents[key] = content - bisect.insort(self._sorted_sha1s, content.sha1) + self._sorted_sha1s.add(content.sha1) self._contents[key] = attr.evolve(self._contents[key], data=None) content_add += 1 @@ -163,11 +212,9 @@ def content_get_range(self, start, end, limit=1000): if limit is None: raise StorageArgumentException("limit should not be None") - from_index = bisect.bisect_left(self._sorted_sha1s, start) - sha1s = itertools.islice(self._sorted_sha1s, from_index, None) sha1s = ( (sha1, content_key) - for sha1 in sha1s + for sha1 in self._sorted_sha1s.iter_from(start) for content_key in self._content_indexes["sha1"][sha1] ) matched = []