diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -26,6 +26,7 @@
     Optional,
     Set,
     Tuple,
+    Type,
     TypeVar,
     Union,
 )
@@ -55,6 +56,8 @@
     Sha1Git,
 )
 from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex
+
+from swh.storage.cassandra.model import BaseRow
 from swh.storage.interface import (
     ListOrder,
     PagedResult,
@@ -129,9 +132,95 @@
         yield from it
 
 
+TRow = TypeVar("TRow", bound=BaseRow)
+
+
+class Table(Generic[TRow]):
+    def __init__(self, row_class: Type[TRow]):
+        self.row_class = row_class
+        self.primary_key_cols = row_class.PARTITION_KEY + row_class.CLUSTERING_KEY
+
+        # Map from tokens to clustering keys to rows
+        # These are not actually partitions (or rather, there is one partition
+        # for each token) and they aren't sorted.
+        # But it is good enough if we don't care about performance;
+        # and makes the code a lot simpler.
+        self.data: Dict[int, Dict[Tuple, TRow]] = defaultdict(dict)
+
+    def __repr__(self):
+        return f"<{self.__module__}.Table[{self.row_class.__name__}] object>"
+
+    def partition_key(self, row: Union[TRow, Dict[str, Any]]) -> Tuple:
+        """Returns the partition key of a row (ie. the cells which get hashed
+        into the token)."""
+        if isinstance(row, dict):
+            row_d = row
+        else:
+            row_d = row.to_dict()
+        return tuple(row_d[col] for col in self.row_class.PARTITION_KEY)
+
+    def clustering_key(self, row: Union[TRow, Dict[str, Any]]) -> Tuple:
+        """Returns the clustering key of a row (ie. the cells which are used
+        for sorting rows within a partition)."""
+        if isinstance(row, dict):
+            row_d = row
+        else:
+            row_d = row.to_dict()
+        return tuple(row_d[col] for col in self.row_class.CLUSTERING_KEY)
+
+    def primary_key(self, row):
+        return self.partition_key(row) + self.clustering_key(row)
+
+    def primary_key_from_dict(self, d: Dict[str, Any]) -> Tuple:
+        """Returns the primary key (ie. concatenation of partition key and
+        clustering key) of the given dictionary interpreted as a row."""
+        return tuple(d[col] for col in self.primary_key_cols)
+
+    def token(self, key: Tuple):
+        """Returns the token of a row (ie. the hash of its partition key)."""
+        return hash(key)
+
+    def get_partition(self, token: int) -> Dict[Tuple, TRow]:
+        """Returns the partition that contains this token."""
+        return self.data[token]
+
+    def insert(self, row: TRow):
+        partition = self.data[self.token(self.partition_key(row))]
+        partition[self.clustering_key(row)] = row
+
+    def split_primary_key(self, key: Tuple) -> Tuple[Tuple, Tuple]:
+        """Returns (partition_key, clustering_key) from a primary key."""
+        assert len(key) == len(self.primary_key_cols)
+
+        partition_key = key[0 : len(self.row_class.PARTITION_KEY)]
+        clustering_key = key[len(self.row_class.PARTITION_KEY) :]
+
+        return (partition_key, clustering_key)
+
+    def get_from_primary_key(self, primary_key: Tuple) -> Optional[TRow]:
+        """Returns at most one row, from its primary key."""
+        (partition_key, clustering_key) = self.split_primary_key(primary_key)
+
+        token = self.token(partition_key)
+        partition = self.get_partition(token)
+
+        return partition.get(clustering_key)
+
+    def get_from_token(self, token: int) -> Iterable[TRow]:
+        """Returns all rows whose token (ie. non-cryptographic hash of the
+        partition key) is the one passed as argument."""
+        return (v for (k, v) in sorted(self.get_partition(token).items()))
+
+    def iter_all(self) -> Iterator[Tuple[Tuple, TRow]]:
+        return (
+            (self.primary_key(row), row)
+            for (token, partition) in self.data.items()
+            for (clustering_key, row) in partition.items()
+        )
+
+
 class InMemoryStorage:
     def __init__(self, journal_writer=None):
-        self.reset()
         self.journal_writer = JournalWriter(journal_writer)
diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py
--- a/swh/storage/tests/test_in_memory.py
+++ b/swh/storage/tests/test_in_memory.py
@@ -3,9 +3,12 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+import dataclasses
+
 import pytest
 
-from swh.storage.in_memory import SortedList
+from swh.storage.cassandra.model import BaseRow
+from swh.storage.in_memory import SortedList, Table
 
 from swh.storage.tests.test_storage import TestStorage, TestStorageGeneratedData  # noqa
@@ -84,3 +87,60 @@
     for split in items:
         expected = reversed(sorted(item for item in items if item < split))
         assert list(list_.iter_after(-split)) == list(expected), f"split: {split}"
+
+
+@dataclasses.dataclass
+class Row(BaseRow):
+    PARTITION_KEY = ("col1", "col2")
+    CLUSTERING_KEY = ("col3", "col4")
+
+    col1: str
+    col2: str
+    col3: str
+    col4: str
+    col5: str
+    col6: int
+
+
+def test_table_keys():
+    table = Table(Row)
+
+    primary_key = ("foo", "bar", "baz", "qux")
+    partition_key = ("foo", "bar")
+    clustering_key = ("baz", "qux")
+
+    row = Row(col1="foo", col2="bar", col3="baz", col4="qux", col5="quux", col6=4)
+    assert table.partition_key(row) == partition_key
+    assert table.clustering_key(row) == clustering_key
+    assert table.primary_key(row) == primary_key
+
+    assert table.primary_key_from_dict(row.to_dict()) == primary_key
+    assert table.split_primary_key(primary_key) == (partition_key, clustering_key)
+
+
+def test_table():
+    table = Table(Row)
+
+    row1 = Row(col1="foo", col2="bar", col3="baz", col4="qux", col5="quux", col6=4)
+    row2 = Row(col1="foo", col2="bar", col3="baz", col4="qux2", col5="quux", col6=4)
+    row3 = Row(col1="foo", col2="bar", col3="baz", col4="qux1", col5="quux", col6=4)
+    partition_key = ("foo", "bar")
+    primary_key1 = ("foo", "bar", "baz", "qux")
+    primary_key2 = ("foo", "bar", "baz", "qux2")
+    primary_key3 = ("foo", "bar", "baz", "qux1")
+
+    table.insert(row1)
+    table.insert(row2)
+    table.insert(row3)
+
+    assert table.get_from_primary_key(primary_key1) == row1
+    assert table.get_from_primary_key(primary_key2) == row2
+    assert table.get_from_primary_key(primary_key3) == row3
+
+    # order matters
+    assert list(table.get_from_token(table.token(partition_key))) == [row1, row3, row2]
+
+    all_rows = list(table.iter_all())
+    assert len(all_rows) == 3
+    for row in (row1, row2, row3):
+        assert (table.primary_key(row), row) in all_rows
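
Usage note: a minimal sketch of how the new Table class can be used, assuming only what this diff introduces (Table in swh.storage.in_memory, BaseRow in swh.storage.cassandra.model). ExampleRow, its columns, and the values below are hypothetical and exist only for illustration; they follow the same pattern as the Row class added to the tests.

import dataclasses

from swh.storage.cassandra.model import BaseRow
from swh.storage.in_memory import Table


@dataclasses.dataclass
class ExampleRow(BaseRow):
    # Hypothetical row class, for illustration only; real row classes are
    # defined in swh.storage.cassandra.model.
    PARTITION_KEY = ("url",)
    CLUSTERING_KEY = ()

    url: str
    visit_count: int


table = Table(ExampleRow)
table.insert(ExampleRow(url="https://example.org/repo.git", visit_count=1))

# Point lookup by primary key (partition key + clustering key; the clustering
# key is empty here, so the primary key is just the partition key).
row = table.get_from_primary_key(("https://example.org/repo.git",))
assert row is not None and row.visit_count == 1

# Token-based lookup, mimicking how a Cassandra-backed implementation
# addresses rows: all rows sharing a partition key map to the same token.
token = table.token(table.partition_key(row))
assert list(table.get_from_token(token)) == [row]

As the inline comment in Table.__init__ notes, this design deliberately trades performance for simplicity: each token gets its own unsorted "partition", which is good enough for an in-memory backend.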