diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -6,7 +6,10 @@
 import copy
 import datetime
 import itertools
+import queue
 import random
+import threading
+import traceback
 import unittest
 from collections import defaultdict
 from unittest.mock import Mock, patch
@@ -3971,6 +3974,59 @@
         self.assertEqual(missing, [self.cont['sha1']])
 
 
+@pytest.mark.db
+class TestStorageRaceConditions(TestStorageData, StorageTestDbFixture,
+                                unittest.TestCase):
+    """Concurrency tests that exercise a real database."""
+
+    def test_content_add_race(self):
+        """Two concurrent content_add() calls for the same content must
+        both succeed.
+
+        Each worker uses its own connection and transaction; a barrier
+        releases both insertions together so that the transactions are
+        likely to overlap.
+        """
+        results = queue.Queue()
+        barrier = threading.Barrier(2)
+
+        def add_content(name):
+            # Always report an outcome on the queue -- ('data', ret) or
+            # ('exc', traceback text) -- so worker failures surface in
+            # the main thread instead of being lost.
+            db = None
+            try:
+                db = self.storage.get_db()
+                with db.transaction() as cur:
+                    barrier.wait(timeout=10)
+                    ret = self.storage.content_add([self.cont], db=db,
+                                                   cur=cur)
+                results.put((name, 'data', ret))
+            except Exception:
+                # Don't leave the other worker blocked on the barrier.
+                barrier.abort()
+                results.put((name, 'exc', traceback.format_exc()))
+            finally:
+                if db is not None:
+                    self.storage.put_db(db)
+
+        threads = [threading.Thread(target=add_content, args=(name,))
+                   for name in ('thread1', 'thread2')]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        r1 = results.get(block=False)
+        r2 = results.get(block=False)
+
+        with pytest.raises(queue.Empty):
+            results.get(block=False)
+        assert {r1[0], r2[0]} == {'thread1', 'thread2'}
+        assert r1[1] == 'data', 'Got exception in %s:\n%s' % (r1[0], r1[2])
+        assert r2[1] == 'data', 'Got exception in %s:\n%s' % (r2[0], r2[2])
+
+
 @pytest.mark.db
 @pytest.mark.property_based
 class PropTestLocalStorage(CommonPropTestStorage, StorageTestDbFixture,