diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -457,7 +457,9 @@
         )
 
     @_prepared_select_statement(RevisionRow, "WHERE id IN ?")
-    def revision_get(self, revision_ids, *, statement) -> Iterable[RevisionRow]:
+    def revision_get(
+        self, revision_ids: List[Sha1Git], *, statement
+    ) -> Iterable[RevisionRow]:
         return map(
             RevisionRow.from_dict, self._execute_with_retries(statement, [revision_ids])
         )
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -25,7 +25,6 @@
     Iterator,
     List,
     Optional,
-    Set,
     Tuple,
     Type,
     TypeVar,
@@ -39,7 +38,6 @@
 from swh.model.model import (
     Content,
     SkippedContent,
-    Revision,
     Release,
     Snapshot,
     OriginVisit,
@@ -60,6 +58,8 @@
     DirectoryRow,
     DirectoryEntryRow,
     ObjectCountRow,
+    RevisionRow,
+    RevisionParentRow,
     SkippedContentRow,
 )
 from swh.storage.interface import (
@@ -240,6 +240,8 @@
         self._skipped_content_indexes = defaultdict(lambda: defaultdict(set))
         self._directories = Table(DirectoryRow)
         self._directory_entries = Table(DirectoryEntryRow)
+        self._revisions = Table(RevisionRow)
+        self._revision_parents = Table(RevisionParentRow)
         self._stat_counters = defaultdict(int)
 
     def increment_counter(self, object_type: str, nb: int):
@@ -294,7 +296,6 @@
         for id_ in ids:
             if id_ not in self._content_indexes["sha1_git"]:
                 missing.append(id_)
-
         return missing
 
     def content_index_add_one(self, algo: str, content: Content, token: int) -> None:
@@ -344,7 +345,6 @@
         for id_ in ids:
             if self._directories.get_from_primary_key((id_,)) is None:
                 missing.append(id_)
-
         return missing
 
     def directory_add_one(self, directory: DirectoryRow) -> None:
@@ -371,8 +371,48 @@
     # 'revision' table
     ##########################
 
-    def revision_missing(self, ids: List[bytes]) -> List[bytes]:
-        return ids
+    def revision_missing(self, ids: List[bytes]) -> Iterable[bytes]:
+        """Return the ids from ``ids`` that have no row in the revision table."""
+        missing = []
+        for id_ in ids:
+            if self._revisions.get_from_primary_key((id_,)) is None:
+                missing.append(id_)
+        return missing
+
+    def revision_add_one(self, revision: RevisionRow) -> None:
+        """Insert one revision row and increment the "revision" counter by 1."""
+        self._revisions.insert(revision)
+        self.increment_counter("revision", 1)
+
+    def revision_get_ids(self, revision_ids: List[Sha1Git]) -> Iterable[Sha1Git]:
+        """Yield the ids from ``revision_ids`` that have a row in the table."""
+        for id_ in revision_ids:
+            if self._revisions.get_from_primary_key((id_,)) is not None:
+                yield id_
+
+    def revision_get(self, revision_ids: List[Sha1Git]) -> Iterable[RevisionRow]:
+        """Yield the row of each id in ``revision_ids``; unknown ids are skipped."""
+        for id_ in revision_ids:
+            row = self._revisions.get_from_primary_key((id_,))
+            if row is not None:
+                yield row
+
+    def revision_get_random(self) -> Optional[RevisionRow]:
+        """Return the row picked by the revision table's ``get_random()``."""
+        return self._revisions.get_random()
+
+    ##########################
+    # 'revision_parent' table
+    ##########################
+
+    def revision_parent_add_one(self, revision_parent: RevisionParentRow) -> None:
+        """Insert one row into the revision_parent table."""
+        self._revision_parents.insert(revision_parent)
+
+    def revision_parent_get(self, revision_id: Sha1Git) -> Iterable[bytes]:
+        """Yield ``parent_id`` of every row in the partition of ``revision_id``."""
+        for parent in self._revision_parents.get_from_partition_key((revision_id,)):
+            yield parent.parent_id
 
     ##########################
     # 'release' table
@@ -391,7 +431,6 @@
 
     def reset(self):
         self._cql_runner = InMemoryCqlRunner()
-        self._revisions = {}
         self._releases = {}
         self._snapshots = {}
         self._origins = {}
@@ -435,69 +474,6 @@
     def check_config(self, *, check_write: bool) -> bool:
         return True
 
-    def revision_add(self, revisions: List[Revision]) -> Dict:
-        revisions = [rev for rev in revisions if rev.id not in self._revisions]
-        self.journal_writer.revision_add(revisions)
-
-        count = 0
-        for revision in revisions:
-            revision = attr.evolve(
-                revision,
-                committer=self._person_add(revision.committer),
-                author=self._person_add(revision.author),
-            )
-            self._revisions[revision.id] = revision
-            self._objects[revision.id].append(("revision", revision.id))
-            count += 1
-
-        self._cql_runner.increment_counter("revision", len(revisions))
-
-        return {"revision:add": count}
-
-    def revision_missing(self, revisions: List[Sha1Git]) -> Iterable[Sha1Git]:
-        for id in revisions:
-            if id not in self._revisions:
-                yield id
-
-    def revision_get(
-        self, revisions: List[Sha1Git]
-    ) -> Iterable[Optional[Dict[str, Any]]]:
-        for id in revisions:
-            if id in self._revisions:
-                yield self._revisions.get(id).to_dict()
-            else:
-                yield None
-
-    def __get_parent_revs(
-        self, rev_id: Sha1Git, seen: Set[Sha1Git], limit: Optional[int]
-    ) -> Iterable[Dict[str, Any]]:
-        if limit and len(seen) >= limit:
-            return
-        if rev_id in seen or rev_id not in self._revisions:
-            return
-        seen.add(rev_id)
-        yield self._revisions[rev_id].to_dict()
-        for parent in self._revisions[rev_id].parents:
-            yield from self.__get_parent_revs(parent, seen, limit)
-
-    def revision_log(
-        self, revisions: List[Sha1Git], limit: Optional[int] = None
-    ) -> Iterable[Optional[Dict[str, Any]]]:
-        seen: Set[Sha1Git] = set()
-        for rev_id in revisions:
-            yield from self.__get_parent_revs(rev_id, seen, limit)
-
-    def revision_shortlog(
-        self, revisions: List[Sha1Git], limit: Optional[int] = None
-    ) -> Iterable[Optional[Tuple[Sha1Git, Tuple[Sha1Git, ...]]]]:
-        yield from (
-            (rev["id"], rev["parents"]) if rev else None
-            for rev in self.revision_log(revisions, limit)
-        )
-
-    def revision_get_random(self) -> Sha1Git:
-        return random.choice(list(self._revisions))
-
     def release_add(self, releases: List[Release]) -> Dict:
         to_add = []
         for rel in releases:
diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py
--- a/swh/storage/tests/test_api_client.py
+++ b/swh/storage/tests/test_api_client.py
@@ -64,3 +64,10 @@
     @pytest.mark.skip("content_update is not yet implemented for Cassandra")
     def test_content_update(self):
         pass
+
+    @pytest.mark.skip(
+        'The "person" table of the pgsql is a legacy thing, and not '
+        "supported by the cassandra backend."
+    )
+    def test_person_fullname_unicity(self):
+        pass
diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py
--- a/swh/storage/tests/test_in_memory.py
+++ b/swh/storage/tests/test_in_memory.py
@@ -158,6 +158,13 @@
 
 
 class TestInMemoryStorage(_TestStorage):
+    @pytest.mark.skip(
+        'The "person" table of the pgsql is a legacy thing, and not '
+        "supported by the cassandra backend."
+    )
+    def test_person_fullname_unicity(self):
+        pass
+
     @pytest.mark.skip("content_update is not yet implemented for Cassandra")
     def test_content_update(self):
         pass
diff --git a/swh/storage/tests/test_replay.py b/swh/storage/tests/test_replay.py
--- a/swh/storage/tests/test_replay.py
+++ b/swh/storage/tests/test_replay.py
@@ -206,7 +206,6 @@
     assert got_persons == expected_persons
 
     for attr_ in (
-        "revisions",
         "releases",
         "snapshots",
         "origins",
@@ -223,6 +222,7 @@
         "contents",
         "skipped_contents",
         "directories",
+        "revisions",
     ):
         if exclude and attr_ in exclude:
             continue
@@ -380,7 +380,6 @@
     assert got_persons == expected_persons
 
     for attr_ in (
-        "revisions",
         "releases",
         "snapshots",
         "origins",
@@ -399,6 +398,7 @@
         "contents",
         "skipped_contents",
         "directories",
+        "revisions",
     ):
         expected_objects = [
             (id, nullify_ctime(maybe_anonymize(attr_, obj)))