Review notes. Free text before the first "diff --git" line is ignored by
git apply, so these notes do not change the patch.

* Scope: the patch adds Table(SnapshotRow) and Table(SnapshotBranchRow),
  plus snapshot accessors, to InMemoryCqlRunner. It deletes InMemoryStorage's
  own snapshot_* methods and its self._snapshots dict. Presumably the
  inherited CassandraStorage implementations take over; confirm.
* Formatting: the patch below has been collapsed onto a few long lines, so
  hunk bodies have lost their line breaks and indentation. It cannot be
  applied as-is, and the nesting of the added Python has to be inferred.
  The visible content also does not always match the hunk headers. For
  example, the "-38,7 +38,6" hunk apparently shows only 6 old-side lines
  after the header text, so some lines may be missing.
* snapshot_count_branches: the if/else on "branch.target_type is None"
  assigns a value equal to branch.target_type on both paths. The loop body
  is equivalent to "counts[branch.target_type] += 1".
* snapshot_count_branches returns the defaultdict itself. A caller that
  indexes a missing type therefore gets 0 and inserts that key.
  NOTE(review): confirm that no caller depends on the key being absent.
* snapshot_branch_get: "count >= limit" is checked only after a branch has
  been yielded. As a result, limit <= 0 still yields the first branch whose
  name is >= from_. Checking before the yield (or returning early when
  limit <= 0) avoids this.
* snapshot_branch_get filters on "branch.name >= from_" and stops after
  "limit" matches. This presumably assumes that get_from_partition_key
  yields rows in branch-name order. TODO: confirm against Table.

diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -38,7 +38,6 @@ from swh.model.model import ( Content, SkippedContent, - Snapshot, OriginVisit, OriginVisitStatus, Origin, @@ -61,11 +60,12 @@ RevisionRow, RevisionParentRow, SkippedContentRow, + SnapshotRow, + SnapshotBranchRow, ) from swh.storage.interface import ( ListOrder, PagedResult, - PartialBranches, VISIT_STATUSES, ) from swh.storage.objstorage import ObjStorage @@ -243,6 +243,8 @@ self._revisions = Table(RevisionRow) self._revision_parents = Table(RevisionParentRow) self._releases = Table(ReleaseRow) + self._snapshots = Table(SnapshotRow) + self._snapshot_branches = Table(SnapshotBranchRow) self._stat_counters = defaultdict(int) def increment_counter(self, object_type: str, nb: int): @@ -432,6 +434,54 @@ def release_get_random(self) -> Optional[ReleaseRow]: return self._releases.get_random() + ########################## + # 'snapshot' table + ########################## + + def snapshot_missing(self, ids: List[bytes]) -> List[bytes]: + missing = [] + for id_ in ids: + if self._snapshots.get_from_primary_key((id_,)) is None: + missing.append(id_) + return missing + + def snapshot_add_one(self, snapshot: SnapshotRow) -> None: + self._snapshots.insert(snapshot) + self.increment_counter("snapshot", 1) + + def snapshot_get_random(self) -> Optional[SnapshotRow]: + return self._snapshots.get_random() + + ########################## + # 'snapshot_branch' table + ########################## + + def snapshot_branch_add_one(self, branch: SnapshotBranchRow) -> None: + self._snapshot_branches.insert(branch) + + def snapshot_count_branches(self, snapshot_id: Sha1Git) -> Dict[Optional[str], int]: + """Returns a dictionary from type names to the number of branches + of that type.""" + counts: Dict[Optional[str], int] = defaultdict(int) + for branch in self._snapshot_branches.get_from_partition_key((snapshot_id,)): + if branch.target_type is None: 
+ target_type = None + else: + target_type = branch.target_type + counts[target_type] += 1 + return counts + + def snapshot_branch_get( + self, snapshot_id: Sha1Git, from_: bytes, limit: int + ) -> Iterable[SnapshotBranchRow]: + count = 0 + for branch in self._snapshot_branches.get_from_partition_key((snapshot_id,)): + if branch.name >= from_: + count += 1 + yield branch + if count >= limit: + break + class InMemoryStorage(CassandraStorage): _cql_runner: InMemoryCqlRunner # type: ignore @@ -442,7 +492,6 @@ def reset(self): self._cql_runner = InMemoryCqlRunner() - self._snapshots = {} self._origins = {} self._origins_by_sha1 = {} self._origin_visits = {} @@ -484,91 +533,6 @@ def check_config(self, *, check_write: bool) -> bool: return True - def snapshot_add(self, snapshots: List[Snapshot]) -> Dict: - count = 0 - snapshots = [snap for snap in snapshots if snap.id not in self._snapshots] - for snapshot in snapshots: - self.journal_writer.snapshot_add([snapshot]) - self._snapshots[snapshot.id] = snapshot - self._objects[snapshot.id].append(("snapshot", snapshot.id)) - count += 1 - - self._cql_runner.increment_counter("snapshot", len(snapshots)) - - return {"snapshot:add": count} - - def snapshot_missing(self, snapshots: List[Sha1Git]) -> Iterable[Sha1Git]: - for id in snapshots: - if id not in self._snapshots: - yield id - - def snapshot_get(self, snapshot_id: Sha1Git) -> Optional[Dict[str, Any]]: - d = self.snapshot_get_branches(snapshot_id) - if d is None: - return None - return { - "id": d["id"], - "branches": { - name: branch.to_dict() if branch else None - for (name, branch) in d["branches"].items() - }, - "next_branch": d["next_branch"], - } - - def snapshot_count_branches( - self, snapshot_id: Sha1Git - ) -> Optional[Dict[Optional[str], int]]: - snapshot = self._snapshots[snapshot_id] - return collections.Counter( - branch.target_type.value if branch else None - for branch in snapshot.branches.values() - ) - - def snapshot_get_branches( - self, - snapshot_id: 
Sha1Git, - branches_from: bytes = b"", - branches_count: int = 1000, - target_types: Optional[List[str]] = None, - ) -> Optional[PartialBranches]: - snapshot = self._snapshots.get(snapshot_id) - if snapshot is None: - return None - sorted_branches = sorted(snapshot.branches.items()) - sorted_branch_names = [k for (k, v) in sorted_branches] - from_index = bisect.bisect_left(sorted_branch_names, branches_from) - if target_types: - next_branch = None - branches: Dict = {} - for (branch_name, branch) in sorted_branches: - if branch_name in sorted_branch_names[from_index:]: - if branch and branch.target_type.value in target_types: - if len(branches) < branches_count: - branches[branch_name] = branch - else: - next_branch = branch_name - break - else: - # As there is no 'target_types', we can do that much faster - to_index = from_index + branches_count - returned_branch_names = frozenset(sorted_branch_names[from_index:to_index]) - branches = dict( - (branch_name, branch) - for (branch_name, branch) in snapshot.branches.items() - if branch_name in returned_branch_names - ) - if to_index >= len(sorted_branch_names): - next_branch = None - else: - next_branch = sorted_branch_names[to_index] - - return PartialBranches( - id=snapshot_id, branches=branches, next_branch=next_branch, - ) - - def snapshot_get_random(self) -> Sha1Git: - return random.choice(list(self._snapshots)) - def _convert_origin(self, t): if t is None: return None @@ -631,7 +595,7 @@ ) for ov in visits: snapshot = ov["snapshot"] - if snapshot and snapshot in self._snapshots: + if snapshot and not list(self.snapshot_missing([snapshot])): filtered_origins.append(orig) break else: diff --git a/swh/storage/tests/test_replay.py b/swh/storage/tests/test_replay.py --- a/swh/storage/tests/test_replay.py +++ b/swh/storage/tests/test_replay.py @@ -206,7 +206,6 @@ assert got_persons == expected_persons for attr_ in ( - "snapshots", "origins", "origin_visits", "origin_visit_statuses", @@ -223,6 +222,7 @@ "directories", 
"revisions", "releases", + "snapshots", ): if exclude and attr_ in exclude: continue @@ -380,7 +380,6 @@ assert got_persons == expected_persons for attr_ in ( - "snapshots", "origins", "origin_visit_statuses", ): @@ -399,6 +398,7 @@ "directories", "revisions", "releases", + "snapshots", ): expected_objects = [ (id, nullify_ctime(maybe_anonymize(attr_, obj)))