NOTE(review): the line structure of this patch was rebuilt from a copy in
which all newlines and indentation had been collapsed. Indentation of the
context lines is inferred; verify against the target tree before applying.

diff --git a/swh/scrubber/db.py b/swh/scrubber/db.py
--- a/swh/scrubber/db.py
+++ b/swh/scrubber/db.py
@@ -53,12 +53,26 @@
         cur = self.cursor()
         cur.execute(
             """
-            INSERT INTO datastore (package, class, instance)
-            VALUES (%s, %s, %s)
-            ON CONFLICT DO NOTHING
-            RETURNING id
+            WITH inserted AS (
+                INSERT INTO datastore (package, class, instance)
+                VALUES (%(package)s, %(cls)s, %(instance)s)
+                ON CONFLICT DO NOTHING
+                RETURNING id
+            )
+            SELECT id
+            FROM inserted
+            UNION (
+                -- If the datastore already exists, we need to fetch its id
+                SELECT id
+                FROM datastore
+                WHERE
+                    package=%(package)s
+                    AND class=%(cls)s
+                    AND instance=%(instance)s
+            )
+            LIMIT 1
             """,
-            (datastore.package, datastore.cls, datastore.instance),
+            (dataclasses.asdict(datastore)),
         )
         (id_,) = cur.fetchone()
         return id_
diff --git a/swh/scrubber/tests/test_storage_postgresql.py b/swh/scrubber/tests/test_storage_postgresql.py
--- a/swh/scrubber/tests/test_storage_postgresql.py
+++ b/swh/scrubber/tests/test_storage_postgresql.py
@@ -80,7 +80,7 @@
 
 
 @patch_byte_ranges
-def test_corrupt_snapshots(scrubber_db, swh_storage):
+def test_corrupt_snapshots_same_batch(scrubber_db, swh_storage):
     snapshots = list(swh_model_data.SNAPSHOTS)
     for i in (0, 1):
         snapshots[i] = attr.evolve(snapshots[i], id=bytes([i]) * 20)
@@ -103,3 +103,43 @@
             "swh:1:snp:0101010101010101010101010101010101010101",
         ]
     }
+
+
+@patch_byte_ranges
+def test_corrupt_snapshots_different_batches(scrubber_db, swh_storage):
+    snapshots = list(swh_model_data.SNAPSHOTS)
+    for i in (0, 1):
+        snapshots[i] = attr.evolve(snapshots[i], id=bytes([i * 255]) * 20)
+    swh_storage.snapshot_add(snapshots)
+
+    StorageChecker(
+        db=scrubber_db,
+        storage=swh_storage,
+        object_type="snapshot",
+        start_object="00" * 20,
+        end_object="87" * 20,
+    ).run()
+
+    corrupt_objects = list(scrubber_db.corrupt_object_iter())
+    assert len(corrupt_objects) == 1
+
+    # Simulates resuming from a different process, with an empty lru_cache
+    scrubber_db.datastore_get_or_add.cache_clear()
+
+    StorageChecker(
+        db=scrubber_db,
+        storage=swh_storage,
+        object_type="snapshot",
+        start_object="88" * 20,
+        end_object="ff" * 20,
+    ).run()
+
+    corrupt_objects = list(scrubber_db.corrupt_object_iter())
+    assert len(corrupt_objects) == 2
+    assert {co.id for co in corrupt_objects} == {
+        swhids.CoreSWHID.from_string(swhid)
+        for swhid in [
+            "swh:1:snp:0000000000000000000000000000000000000000",
+            "swh:1:snp:ffffffffffffffffffffffffffffffffffffffff",
+        ]
+    }