Page Menu · Home · Software Heritage

D2987.id10834.diff
No One · Temporary

D2987.id10834.diff

diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -14,7 +14,19 @@
from collections import defaultdict
from datetime import timedelta
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Generic,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Tuple,
+ TypeVar,
+ Union,
+)
import attr
@@ -45,6 +57,50 @@
BULK_BLOCK_CONTENT_LEN_MAX = 10000
+SortedListItem = TypeVar("SortedListItem")
+SortedListKey = TypeVar("SortedListKey")
+
+
+class SortedList(collections.UserList, Generic[SortedListKey, SortedListItem]):
+    """A list whose items are kept ordered by ``key(item)``.
+
+    Items are stored in ``self.data`` as ``(key(item), item)`` pairs ordered
+    by key only: items with equal keys keep their insertion order, so the
+    items themselves never need to be comparable with each other.
+
+    Only iteration, ``len()``, :meth:`add` and :meth:`iter_from` are aware
+    of this layout. Other methods inherited from
+    :class:`collections.UserList` (indexing, ``in``, ``append``, ...) operate
+    on the raw ``(key, item)`` pairs and do not preserve the ordering.
+    """
+
+    data: List[Tuple[SortedListKey, SortedListItem]]
+
+    # https://github.com/python/mypy/issues/708
+    # key: Callable[[SortedListItem], SortedListKey]
+
+    def __init__(
+        self,
+        data: Optional[Iterable[SortedListItem]] = None,
+        key: Optional[Callable[[SortedListItem], SortedListKey]] = None,
+    ):
+        """Build the list from the items of ``data`` (empty if None).
+
+        ``key`` maps an item to the value it is ordered by; when omitted,
+        items are ordered by themselves.
+        """
+        if key is None:
+
+            def key(item):
+                return item
+
+        assert key is not None  # for mypy
+        # Sort on the key alone: sorted() is stable, so ties keep their input
+        # order and items are never compared with each other.
+        pairs = [(key(x), x) for x in data or []]
+        super().__init__(sorted(pairs, key=lambda pair: pair[0]))
+
+        self.key: Callable[[SortedListKey, SortedListItem], SortedListKey] = key  # type: ignore
+
+    def _insertion_index(self, k: SortedListKey) -> int:
+        """Return the index just after the last stored pair whose key is <= k.
+
+        This is ``bisect.bisect_right`` applied to the keys alone. Comparing
+        keys only (never items) lets :meth:`add` accept items that do not
+        support ordering, and places items with equal keys in insertion order.
+        """
+        lo, hi = 0, len(self.data)
+        while lo < hi:
+            mid = (lo + hi) // 2
+            if k < self.data[mid][0]:
+                hi = mid
+            else:
+                lo = mid + 1
+        return lo
+
+    def add(self, item: SortedListItem) -> None:
+        """Insert ``item`` at its sorted position, after items with an equal key."""
+        k = self.key(item)
+        self.data.insert(self._insertion_index(k), (k, item))
+
+    def __iter__(self) -> Iterator[SortedListItem]:
+        """Iterate over the items (not the stored pairs) in key order."""
+        for (_, item) in self.data:
+            yield item
+
+    def iter_from(self, start_key: SortedListKey) -> Iterator[SortedListItem]:
+        """Returns an iterator over all the elements whose key is greater
+        or equal to `start_key`.
+        (This is an efficient equivalent to:
+        `(x for x in L if key(x) >= start_key)`)
+        """
+        # `(start_key,)` sorts before every `(start_key, item)` pair because it
+        # is a shorter tuple with an equal prefix, so bisect_left lands on the
+        # first pair with that key without ever comparing items.
+        from_index = bisect.bisect_left(self.data, (start_key,))
+        for (_, item) in itertools.islice(self.data, from_index, None):
+            yield item
+
+
class InMemoryStorage:
def __init__(self, journal_writer=None):
@@ -70,8 +126,7 @@
self._metadata_providers = {}
self._objects = defaultdict(list)
- # ideally we would want a skip list for both fast inserts and searches
- self._sorted_sha1s = []
+ self._sorted_sha1s = SortedList[bytes, bytes]()
self.objstorage = ObjStorage({"cls": "memory", "args": {}})
@@ -111,7 +166,7 @@
self._content_indexes[algorithm][hash_].add(key)
self._objects[content.sha1_git].append(("content", content.sha1))
self._contents[key] = content
- bisect.insort(self._sorted_sha1s, content.sha1)
+ self._sorted_sha1s.add(content.sha1)
self._contents[key] = attr.evolve(self._contents[key], data=None)
content_add += 1
@@ -163,11 +218,9 @@
def content_get_range(self, start, end, limit=1000):
if limit is None:
raise StorageArgumentException("limit should not be None")
- from_index = bisect.bisect_left(self._sorted_sha1s, start)
- sha1s = itertools.islice(self._sorted_sha1s, from_index, None)
sha1s = (
(sha1, content_key)
- for sha1 in sha1s
+ for sha1 in self._sorted_sha1s.iter_from(start)
for content_key in self._content_indexes["sha1"][sha1]
)
matched = []
diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py
--- a/swh/storage/tests/test_in_memory.py
+++ b/swh/storage/tests/test_in_memory.py
@@ -5,6 +5,7 @@
import pytest
+from swh.storage.in_memory import SortedList
from swh.storage.tests.test_storage import TestStorage, TestStorageGeneratedData # noqa
@@ -19,3 +20,51 @@
"cls": "memory",
"journal_writer": {"cls": "memory",},
}
+
+
+# Inputs shared by the SortedList tests: already sorted, reverse sorted,
+# shuffled, plus edge cases (empty, single element, duplicate values that
+# produce ties between keys).
+parametrize = pytest.mark.parametrize(
+    "items",
+    [
+        [1, 2, 3, 4, 5, 6, 10, 100],
+        [10, 100, 6, 5, 4, 3, 2, 1],
+        [10, 4, 5, 6, 1, 2, 3, 100],
+        [],
+        [42],
+        [3, 1, 3, 2, 1, 2],
+    ],
+)
+
+
+@parametrize
+def test_sorted_list_iter(items):
+    """SortedList yields its items in ascending order, whether they were
+    inserted one at a time with add() or given to the constructor."""
+    expected = sorted(items)
+
+    incremental = SortedList()
+    for value in items:
+        incremental.add(value)
+    assert list(incremental) == expected
+
+    assert list(SortedList(items)) == expected
+
+
+@parametrize
+def test_sorted_list_iter__key(items):
+    """With a negating key, SortedList yields its items in descending order,
+    both after incremental add() calls and when built from a list."""
+    expected = sorted(items, reverse=True)
+
+    incremental = SortedList(key=lambda item: -item)
+    for value in items:
+        incremental.add(value)
+    assert list(incremental) == expected
+
+    assert list(SortedList(items, key=lambda item: -item)) == expected
+
+
+@parametrize
+def test_sorted_list_iter_from(items):
+    """iter_from(k) yields exactly the items >= k, in ascending order."""
+    list_ = SortedList(items)
+    # Besides start keys present in the list, also try absent ones: between
+    # two values, below the minimum and above the maximum.
+    splits = set(items) | {item + 0.5 for item in items}
+    splits |= {min(items, default=0) - 1, max(items, default=0) + 1}
+    for split in sorted(splits):
+        expected = sorted(item for item in items if item >= split)
+        assert list(list_.iter_from(split)) == expected, f"split: {split}"
+
+
+@parametrize
+def test_sorted_list_iter_from__key(items):
+    """With a negating key, iter_from(-k) yields exactly the items <= k, in
+    descending order."""
+    list_ = SortedList(items, key=lambda item: -item)
+    # Besides start keys present in the list, also try absent ones: between
+    # two values, below the minimum and above the maximum.
+    splits = set(items) | {item - 0.5 for item in items}
+    splits |= {min(items, default=0) - 1, max(items, default=0) + 1}
+    for split in sorted(splits):
+        expected = sorted((item for item in items if item <= split), reverse=True)
+        assert list(list_.iter_from(-split)) == expected, f"split: {split}"

File Metadata

Mime Type
text/plain
Expires
Wed, Jul 2, 10:43 AM (2 w, 2 h ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3218046

Event Timeline