diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -38,7 +38,6 @@ from swh.model.model import ( Content, SkippedContent, - Release, Snapshot, OriginVisit, OriginVisitStatus, @@ -58,6 +57,7 @@ DirectoryRow, DirectoryEntryRow, ObjectCountRow, + ReleaseRow, RevisionRow, RevisionParentRow, SkippedContentRow, @@ -234,6 +234,7 @@ self._directory_entries = Table(DirectoryEntryRow) self._revisions = Table(RevisionRow) self._revision_parents = Table(RevisionParentRow) + self._releases = Table(ReleaseRow) self._stat_counters = defaultdict(int) def increment_counter(self, object_type: str, nb: int): @@ -404,7 +405,24 @@ ########################## def release_missing(self, ids: List[bytes]) -> List[bytes]: - return ids + missing = [] + for id_ in ids: + if self._releases.get_from_primary_key((id_,)) is None: + missing.append(id_) + return missing + + def release_add_one(self, release: ReleaseRow) -> None: + self._releases.insert(release) + self.increment_counter("release", 1) + + def release_get(self, release_ids: List[str]) -> Iterable[ReleaseRow]: + for id_ in release_ids: + row = self._releases.get_from_primary_key((id_,)) + if row is not None: + yield row + + def release_get_random(self) -> Optional[ReleaseRow]: + return self._releases.get_random() class InMemoryStorage(CassandraStorage): @@ -416,7 +434,6 @@ def reset(self): self._cql_runner = InMemoryCqlRunner() - self._releases = {} self._snapshots = {} self._origins = {} self._origins_by_sha1 = {} @@ -459,38 +476,6 @@ def check_config(self, *, check_write: bool) -> bool: return True - def release_add(self, releases: List[Release]) -> Dict: - to_add = [] - for rel in releases: - if rel.id not in self._releases and rel not in to_add: - to_add.append(rel) - self.journal_writer.release_add(to_add) - - for rel in to_add: - if rel.author: - self._person_add(rel.author) - self._objects[rel.id].append(("release", rel.id)) - self._releases[rel.id] = rel - 
self._cql_runner.increment_counter("release", len(to_add)) - - return {"release:add": len(to_add)} - - def release_missing(self, releases: List[Sha1Git]) -> Iterable[Sha1Git]: - yield from (rel for rel in releases if rel not in self._releases) - - def release_get( - self, releases: List[Sha1Git] - ) -> Iterable[Optional[Dict[str, Any]]]: - for rel_id in releases: - if rel_id in self._releases: - yield self._releases[rel_id].to_dict() - else: - yield None - - def release_get_random(self) -> Sha1Git: - return random.choice(list(self._releases)) - def snapshot_add(self, snapshots: List[Snapshot]) -> Dict: count = 0 snapshots = [snap for snap in snapshots if snap.id not in self._snapshots] @@ -595,13 +580,6 @@ def snapshot_get_random(self) -> Sha1Git: return random.choice(list(self._snapshots)) - def object_find_by_sha1_git(self, ids: List[Sha1Git]) -> Dict[Sha1Git, List[Dict]]: - ret = super().object_find_by_sha1_git(ids) - for id_ in ids: - objs = self._objects.get(id_, []) - ret[id_].extend([{"sha1_git": id_, "type": obj[0],} for obj in objs]) - return ret - def _convert_origin(self, t): if t is None: return None @@ -1143,14 +1121,6 @@ else: raise TypeError("origin must be a string.") - def _person_add(self, person): - key = ("person", person.fullname) - if key not in self._objects: - self._persons[person.fullname] = person - self._objects[key].append(key) - - return self._persons[person.fullname] - @staticmethod def _metadata_fetcher_key(fetcher: MetadataFetcher) -> FetcherKey: return (fetcher.name, fetcher.version)