diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -191,6 +191,14 @@ db.copy_to(content_with_data, 'tmp_content', db.content_add_keys, cur) + # Create a read/write dependency between transactions that would + # write the same content, so that we get a SerializationFailure + # (read/write conflict) instead of an IntegrityError (write/write + # conflict) + cur.execute('SELECT 1 FROM content WHERE sha1 IN %s', + (tuple(cont['sha1'] for cont in content_with_data),)) + list(cur) + # move metadata in place try: db.content_add_from_temp(cur) @@ -264,6 +272,8 @@ content:add:bytes: Sum of the contents' length data skipped_content:add: New skipped contents (no data) added """ + cur.execute('SET TRANSACTION ISOLATION LEVEL SERIALIZABLE') + content = [dict(c.items()) for c in content] # semi-shallow copy now = datetime.datetime.now(tz=datetime.timezone.utc) for item in content: @@ -379,6 +389,7 @@ content:add: New contents added skipped_content:add: New skipped contents (no data) added """ + cur.execute('SET TRANSACTION ISOLATION LEVEL SERIALIZABLE') content = [self._normalize_content(c) for c in content] for c in content: diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -3238,7 +3238,6 @@ @pytest.mark.db class TestStorageRaceConditions: - @pytest.mark.xfail def test_content_add_race(self, swh_storage): results = queue.Queue() @@ -3267,9 +3266,16 @@ with pytest.raises(queue.Empty): results.get(block=False) - assert r1[0] != r2[0] - assert r1[1] == 'data', 'Got exception %r in Thread%s' % (r1[2], r1[0]) - assert r2[1] == 'data', 'Got exception %r in Thread%s' % (r2[2], r2[0]) + assert r1[0] != r2[0] # Ident is unique + assert r1[1] != r2[1] # Don't have same result + if r1[1] == 'data': + (r_data, r_error) = (r1, r2) + else: + (r_data, r_error) = (r2, r1) + assert r_data[1] == 'data' + assert r_error[1] == 'exc', 'Got no exception %r in Thread%s' % ( + r_error[2], r_error[0]) + assert isinstance(r_error[2], psycopg2.errors.SerializationFailure) @pytest.mark.db