diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -39,7 +39,6 @@
 from swh.model.model import (
     Content,
     SkippedContent,
-    Directory,
     Revision,
     Release,
     Snapshot,
@@ -58,6 +57,8 @@
 from swh.storage.cassandra.model import (
     BaseRow,
     ContentRow,
+    DirectoryRow,
+    DirectoryEntryRow,
     ObjectCountRow,
     SkippedContentRow,
 )
@@ -199,6 +200,13 @@
 
         return (partition_key, clustering_key)
 
+    def get_from_partition_key(self, partition_key: Tuple) -> Iterable[TRow]:
+        """Returns all rows with the given partition key."""
+        token = self.token(partition_key)
+        for row in self.get_from_token(token):
+            if self.partition_key(row) == partition_key:
+                yield row
+
     def get_from_primary_key(self, primary_key: Tuple) -> Optional[TRow]:
         """Returns at most one row, from its primary key."""
         (partition_key, clustering_key) = self.split_primary_key(primary_key)
@@ -220,6 +228,9 @@
             for (clustering_key, row) in partition.items()
         )
 
+    def get_random(self) -> Optional[TRow]:
+        return random.choice([row for (pk, row) in self.iter_all()])
+
 
 class InMemoryCqlRunner:
     def __init__(self):
@@ -227,6 +238,8 @@
         self._content_indexes = defaultdict(lambda: defaultdict(set))
         self._skipped_contents = Table(ContentRow)
         self._skipped_content_indexes = defaultdict(lambda: defaultdict(set))
+        self._directories = Table(DirectoryRow)
+        self._directory_entries = Table(DirectoryEntryRow)
         self._stat_counters = defaultdict(int)
 
     def increment_counter(self, object_type: str, nb: int):
@@ -258,13 +271,7 @@
         return self._contents.get_from_token(token)
 
     def content_get_random(self) -> Optional[ContentRow]:
-        return random.choice(
-            [
-                row
-                for partition in self._contents.data.values()
-                for row in partition.values()
-            ]
-        )
+        return self._contents.get_random()
 
     def content_get_token_range(
         self, start: int, end: int, limit: int,
@@ -333,7 +340,32 @@
     ##########################
 
     def directory_missing(self, ids: List[bytes]) -> List[bytes]:
-        return ids
+        missing = []
+        for id_ in ids:
+            if self._directories.get_from_primary_key((id_,)) is None:
+                missing.append(id_)
+
+        return missing
+
+    def directory_add_one(self, directory: DirectoryRow) -> None:
+        self._directories.insert(directory)
+        self.increment_counter("directory", 1)
+
+    def directory_get_random(self) -> Optional[DirectoryRow]:
+        return self._directories.get_random()
+
+    ##########################
+    # 'directory_entry' table
+    ##########################
+
+    def directory_entry_add_one(self, entry: DirectoryEntryRow) -> None:
+        self._directory_entries.insert(entry)
+
+    def directory_entry_get(
+        self, directory_ids: List[Sha1Git]
+    ) -> Iterable[DirectoryEntryRow]:
+        for id_ in directory_ids:
+            yield from self._directory_entries.get_from_partition_key((id_,))
 
     ##########################
     # 'revision' table
     ##########################
@@ -359,7 +391,6 @@
 
     def reset(self):
         self._cql_runner = InMemoryCqlRunner()
-        self._directories = {}
         self._revisions = {}
         self._releases = {}
         self._snapshots = {}
@@ -404,80 +435,6 @@
     def check_config(self, *, check_write: bool) -> bool:
         return True
 
-    def directory_add(self, directories: List[Directory]) -> Dict:
-        directories = [dir_ for dir_ in directories if dir_.id not in self._directories]
-        self.journal_writer.directory_add(directories)
-
-        count = 0
-        for directory in directories:
-            count += 1
-            self._directories[directory.id] = directory
-            self._objects[directory.id].append(("directory", directory.id))
-
-        self._cql_runner.increment_counter("directory", len(directories))
-
-        return {"directory:add": count}
-
-    def directory_missing(self, directories: List[Sha1Git]) -> Iterable[Sha1Git]:
-        for id in directories:
-            if id not in self._directories:
-                yield id
-
-    def _directory_ls(self, directory_id, recursive, prefix=b""):
-        if directory_id in self._directories:
-            for entry in self._directories[directory_id].entries:
-                ret = self._join_dentry_to_content(entry)
-                ret["name"] = prefix + ret["name"]
-                ret["dir_id"] = directory_id
-                yield ret
-                if recursive and ret["type"] == "dir":
-                    yield from self._directory_ls(
-                        ret["target"], True, prefix + ret["name"] + b"/"
-                    )
-
-    def directory_ls(
-        self, directory: Sha1Git, recursive: bool = False
-    ) -> Iterable[Dict[str, Any]]:
-        yield from self._directory_ls(directory, recursive)
-
-    def directory_entry_get_by_path(
-        self, directory: Sha1Git, paths: List[bytes]
-    ) -> Optional[Dict[str, Any]]:
-        return self._directory_entry_get_by_path(directory, paths, b"")
-
-    def directory_get_random(self) -> Sha1Git:
-        return random.choice(list(self._directories))
-
-    def _directory_entry_get_by_path(
-        self, directory: Sha1Git, paths: List[bytes], prefix: bytes
-    ) -> Optional[Dict[str, Any]]:
-        if not paths:
-            return None
-
-        contents = list(self.directory_ls(directory))
-
-        if not contents:
-            return None
-
-        def _get_entry(entries, name):
-            for entry in entries:
-                if entry["name"] == name:
-                    entry = entry.copy()
-                    entry["name"] = prefix + entry["name"]
-                    return entry
-
-        first_item = _get_entry(contents, paths[0])
-
-        if len(paths) == 1:
-            return first_item
-
-        if not first_item or first_item["type"] != "dir":
-            return None
-
-        return self._directory_entry_get_by_path(
-            first_item["target"], paths[1:], prefix + paths[0] + b"/"
-        )
-
     def revision_add(self, revisions: List[Revision]) -> Dict:
         revisions = [rev for rev in revisions if rev.id not in self._revisions]
         self.journal_writer.revision_add(revisions)
diff --git a/swh/storage/tests/test_filter.py b/swh/storage/tests/test_filter.py
--- a/swh/storage/tests/test_filter.py
+++ b/swh/storage/tests/test_filter.py
@@ -114,7 +114,7 @@
 def test_filtering_proxy_storage_directory(swh_storage, sample_data):
     sample_directory = sample_data.directory
 
-    directory = next(swh_storage.directory_missing([sample_directory.id]))
+    directory = list(swh_storage.directory_missing([sample_directory.id]))[0]
     assert directory
 
     s = swh_storage.directory_add([sample_directory])
diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py
--- a/swh/storage/tests/test_in_memory.py
+++ b/swh/storage/tests/test_in_memory.py
@@ -125,25 +125,35 @@
     row1 = Row(col1="foo", col2="bar", col3="baz", col4="qux", col5="quux", col6=4)
     row2 = Row(col1="foo", col2="bar", col3="baz", col4="qux2", col5="quux", col6=4)
     row3 = Row(col1="foo", col2="bar", col3="baz", col4="qux1", col5="quux", col6=4)
+    row4 = Row(col1="foo", col2="bar2", col3="baz", col4="qux1", col5="quux", col6=4)
     partition_key = ("foo", "bar")
+    partition_key4 = ("foo", "bar2")
     primary_key1 = ("foo", "bar", "baz", "qux")
     primary_key2 = ("foo", "bar", "baz", "qux2")
     primary_key3 = ("foo", "bar", "baz", "qux1")
+    primary_key4 = ("foo", "bar2", "baz", "qux1")
 
     table.insert(row1)
     table.insert(row2)
     table.insert(row3)
+    table.insert(row4)
 
     assert table.get_from_primary_key(primary_key1) == row1
     assert table.get_from_primary_key(primary_key2) == row2
     assert table.get_from_primary_key(primary_key3) == row3
+    assert table.get_from_primary_key(primary_key4) == row4
 
     # order matters
     assert list(table.get_from_token(table.token(partition_key))) == [row1, row3, row2]
 
+    # order matters
+    assert list(table.get_from_partition_key(partition_key)) == [row1, row3, row2]
+
+    assert list(table.get_from_partition_key(partition_key4)) == [row4]
+
     all_rows = list(table.iter_all())
-    assert len(all_rows) == 3
-    for row in (row1, row2, row3):
+    assert len(all_rows) == 4
+    for row in (row1, row2, row3, row4):
         assert (table.primary_key(row), row) in all_rows
diff --git a/swh/storage/tests/test_replay.py b/swh/storage/tests/test_replay.py
--- a/swh/storage/tests/test_replay.py
+++ b/swh/storage/tests/test_replay.py
@@ -206,7 +206,6 @@
     assert got_persons == expected_persons
 
     for attr_ in (
-        "directories",
         "revisions",
         "releases",
         "snapshots",
@@ -223,6 +222,7 @@
     for attr_ in (
         "contents",
         "skipped_contents",
+        "directories",
     ):
         if exclude and attr_ in exclude:
             continue
@@ -380,7 +380,6 @@
     assert got_persons == expected_persons
 
     for attr_ in (
-        "directories",
         "revisions",
         "releases",
         "snapshots",
@@ -399,6 +398,7 @@
     for attr_ in (
         "contents",
         "skipped_contents",
+        "directories",
     ):
         expected_objects = [
             (id, nullify_ctime(maybe_anonymize(attr_, obj)))
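Usage sketch (not part of the patch): a minimal example of how the new in-memory directory tables are expected to be exercised, assuming swh.storage is importable and that the DirectoryRow / DirectoryEntryRow dataclasses in swh/storage/cassandra/model.py accept the fields shown below; the field names are assumptions for illustration.

    # Illustration only: exercises the InMemoryCqlRunner directory API added by this diff.
    from swh.storage.cassandra.model import DirectoryEntryRow, DirectoryRow
    from swh.storage.in_memory import InMemoryCqlRunner

    runner = InMemoryCqlRunner()
    dir_id = b"\x01" * 20  # hypothetical sha1_git of a directory

    # One row in the 'directory' table, plus one row per entry in 'directory_entry'.
    runner.directory_add_one(DirectoryRow(id=dir_id))
    runner.directory_entry_add_one(
        DirectoryEntryRow(
            directory_id=dir_id,  # partition key, as used by directory_entry_get
            name=b"README",
            target=b"\x02" * 20,
            perms=0o100644,
            type="file",
        )
    )

    # directory_missing filters out ids now present in the 'directory' table.
    assert runner.directory_missing([dir_id, b"\x03" * 20]) == [b"\x03" * 20]

    # directory_entry_get yields every entry row sharing the directory_id partition key.
    assert [e.name for e in runner.directory_entry_get([dir_id])] == [b"README"]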