diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -543,7 +543,7 @@ snapshots = [snp for snp in snapshots if snp.id in missing] for snapshot in snapshots: - self.journal_writer.snapshot_add(snapshot) + self.journal_writer.snapshot_add([snapshot]) # Add branches for (branch_name, branch) in snapshot.branches.items(): @@ -787,7 +787,7 @@ if known_origin: origin_url = known_origin["url"] else: - self.journal_writer.origin_add_one(origin) + self.journal_writer.origin_add([origin]) self._cql_runner.origin_add_one(origin) origin_url = origin.url @@ -821,7 +821,7 @@ } ) - self.journal_writer.origin_visit_add(visit) + self.journal_writer.origin_visit_add([visit]) self._cql_runner.origin_visit_add_one(visit) with convert_validation_exceptions(): @@ -868,7 +868,7 @@ with convert_validation_exceptions(): visit = attr.evolve(visit, **updates) - self.journal_writer.origin_visit_update(visit) + self.journal_writer.origin_visit_update([visit]) last_visit_update = self._origin_visit_get_updated(visit.origin, visit.visit) assert last_visit_update is not None diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -529,7 +529,7 @@ count = 0 snapshots = (snap for snap in snapshots if snap.id not in self._snapshots) for snapshot in snapshots: - self.journal_writer.snapshot_add(snapshot) + self.journal_writer.snapshot_add([snapshot]) sorted_branch_names = sorted(snapshot.branches) self._snapshots[snapshot.id] = (snapshot, sorted_branch_names) self._objects[snapshot.id].append(("snapshot", snapshot.id)) @@ -756,7 +756,7 @@ def origin_add_one(self, origin: Origin) -> str: if origin.url not in self._origins: - self.journal_writer.origin_add_one(origin) + self.journal_writer.origin_add([origin]) # generate an origin_id because it is needed by origin_get_range. # TODO: remove this when we remove origin_get_range origin_id = len(self._origins) + 1 @@ -816,7 +816,7 @@ self._objects[visit_key].append(("origin_visit", None)) - self.journal_writer.origin_visit_add(visit) + self.journal_writer.origin_visit_add([visit]) # return last visit return visit @@ -859,7 +859,7 @@ self._origin_visit_statuses[visit_key].append(visit_update) self.journal_writer.origin_visit_update( - self._origin_visit_get_updated(origin_url, visit_id) + [self._origin_visit_get_updated(origin_url, visit_id)] ) self._origin_visits[origin_url][visit_id - 1] = visit diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -707,7 +707,7 @@ cur, ) - self.journal_writer.snapshot_add(snapshot) + self.journal_writer.snapshot_add([snapshot]) db.snapshot_add(snapshot.id, cur) count += 1 @@ -871,7 +871,7 @@ ) self._origin_visit_status_add(visit_status, db=db, cur=cur) - self.journal_writer.origin_visit_add(visit) + self.journal_writer.origin_visit_add([visit]) send_metric("origin_visit:add", count=1, method_name="origin_visit") return visit @@ -922,7 +922,7 @@ if updates: with convert_validation_exceptions(): updated_visit = OriginVisit.from_dict({**visit, **updates}) - self.journal_writer.origin_visit_update(updated_visit) + self.journal_writer.origin_visit_update([updated_visit]) # Write updates to origin visit (backward compatibility) db.origin_visit_update(origin, visit_id, updates) @@ -1202,7 +1202,7 @@ if origin_url: return origin_url - self.journal_writer.origin_add_one(origin) + self.journal_writer.origin_add([origin]) url = db.origin_add(origin.url, cur) send_metric("origin:add", count=1, method_name="origin_add_one") diff --git a/swh/storage/writer.py b/swh/storage/writer.py --- a/swh/storage/writer.py +++ b/swh/storage/writer.py @@ -3,7 +3,7 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from typing import Iterable, Union +from typing import Iterable from attr import evolve @@ -42,67 +42,47 @@ else: self.journal = None + def write_additions(self, obj_type, values) -> None: + if self.journal: + self.journal.write_additions(obj_type, values) + def content_add(self, contents: Iterable[Content]) -> None: """Add contents to the journal. Drop the data field if provided. """ - if not self.journal: - return contents = [evolve(item, data=None) for item in contents] - self.journal.write_additions("content", contents) + self.write_additions("content", contents) def content_update(self, contents: Iterable[Content]) -> None: - if not self.journal: - return - raise NotImplementedError( - "content_update is not yet supported with a journal writer." - ) + if self.journal: + raise NotImplementedError("content_update is not supported by the journal.") def content_add_metadata(self, contents: Iterable[Content]) -> None: - return self.content_add(contents) + self.content_add(contents) def skipped_content_add(self, contents: Iterable[SkippedContent]) -> None: - if not self.journal: - return - self.journal.write_additions("content", contents) + self.write_additions("content", contents) def directory_add(self, directories: Iterable[Directory]) -> None: - if not self.journal: - return - self.journal.write_additions("directory", directories) + self.write_additions("directory", directories) def revision_add(self, revisions: Iterable[Revision]) -> None: - if not self.journal: - return - self.journal.write_additions("revision", revisions) + self.write_additions("revision", revisions) def release_add(self, releases: Iterable[Release]) -> None: - if not self.journal: - return - self.journal.write_additions("release", releases) - - def snapshot_add(self, snapshots: Union[Iterable[Snapshot], Snapshot]) -> None: - if not self.journal: - return - snaps = snapshots if isinstance(snapshots, list) else [snapshots] - self.journal.write_additions("snapshot", snaps) - - def origin_visit_add(self, visit: OriginVisit): - if not self.journal: - return - self.journal.write_addition("origin_visit", visit) - - def origin_visit_update(self, visit: OriginVisit): - if not self.journal: - return - self.journal.write_update("origin_visit", visit) - - def origin_visit_upsert(self, visits: Iterable[OriginVisit]): - if not self.journal: - return - self.journal.write_additions("origin_visit", visits) - - def origin_add_one(self, origin: Origin): - if not self.journal: - return - self.journal.write_addition("origin", origin) + self.write_additions("release", releases) + + def snapshot_add(self, snapshots: Iterable[Snapshot]) -> None: + self.write_additions("snapshot", snapshots) + + def origin_visit_add(self, visits: Iterable[OriginVisit]) -> None: + self.write_additions("origin_visit", visits) + + def origin_visit_update(self, visits: Iterable[OriginVisit]) -> None: + self.write_additions("origin_visit", visits) + + def origin_visit_upsert(self, visits: Iterable[OriginVisit]) -> None: + self.write_additions("origin_visit", visits) + + def origin_add(self, origins: Iterable[Origin]) -> None: + self.write_additions("origin", origins)