diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -3,7 +3,6 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import copy import datetime import inspect import itertools @@ -4035,22 +4034,14 @@ assert swh_storage.origin_count("github", regexp=True, with_visit=True) == 1 @settings(suppress_health_check=[HealthCheck.too_slow]) - @given(strategies.lists(objects(), max_size=2)) + @given(strategies.lists(objects(split_content=True), max_size=2)) def test_add_arbitrary(self, swh_storage, objects): for (obj_type, obj) in objects: - obj = obj.to_dict() - if obj_type == "origin_visit": - origin_url = obj.pop("origin") - swh_storage.origin_add([{"url": origin_url}]) - if "visit" in obj: - del obj["visit"] - visit = OriginVisit( - origin=origin_url, date=obj["date"], type=obj["type"], - ) + if obj.object_type == "origin_visit": + swh_storage.origin_add([Origin(url=obj.origin)]) + visit = OriginVisit(origin=obj.origin, date=obj.date, type=obj.type,) swh_storage.origin_visit_add([visit]) else: - if obj_type == "content" and obj["status"] == "absent": - obj_type = "skipped_content" method = getattr(swh_storage, obj_type + "_add") try: method([obj]) @@ -4064,30 +4055,32 @@ # This test is only relevant on the local storage, with an actual # objstorage raising an exception - def test_content_add_objstorage_exception(self, swh_storage): + def test_content_add_objstorage_exception(self, swh_storage, sample_data_model): + content = sample_data_model["content"][0] + swh_storage.objstorage.content_add = Mock( side_effect=Exception("mocked broken objstorage") ) - with pytest.raises(Exception) as e: - swh_storage.content_add([data.cont]) + with pytest.raises(Exception, match="mocked broken"): + swh_storage.content_add([content]) - assert e.value.args == ("mocked broken objstorage",) - missing = list(swh_storage.content_missing([data.cont])) - assert missing == [data.cont["sha1"]] + missing = list(swh_storage.content_missing([content.hashes()])) + assert missing == [content.sha1] @pytest.mark.db class TestStorageRaceConditions: @pytest.mark.xfail - def test_content_add_race(self, swh_storage): + def test_content_add_race(self, swh_storage, sample_data_model): + content = sample_data_model["content"][0] results = queue.Queue() def thread(): try: with db_transaction(swh_storage) as (db, cur): - ret = swh_storage.content_add([data.cont], db=db, cur=cur) + ret = swh_storage.content_add([content], db=db, cur=cur) results.put((threading.get_ident(), "data", ret)) except Exception as e: results.put((threading.get_ident(), "exc", e)) @@ -4121,7 +4114,9 @@ """ - def test_content_update_with_new_cols(self, swh_storage): + def test_content_update_with_new_cols(self, swh_storage, sample_data_model): + content, content2 = sample_data_model["content"][:2] + swh_storage.journal_writer.journal = None # TODO, not supported with db_transaction(swh_storage) as (_, cur): @@ -4131,8 +4126,9 @@ add column test2 text default null""" ) - cont = copy.deepcopy(data.cont2) - swh_storage.content_add([cont]) + swh_storage.content_add([content]) + + cont = content.to_dict() cont["test"] = "value-1" cont["test2"] = "value-2" @@ -4163,73 +4159,66 @@ drop column test2""" ) - def test_content_add_db(self, swh_storage): - cont = data.cont + def test_content_add_db(self, swh_storage, sample_data_model): + content = sample_data_model["content"][0] - actual_result = swh_storage.content_add([cont]) + actual_result = swh_storage.content_add([content]) assert actual_result == { "content:add": 1, - "content:add:bytes": cont["length"], + "content:add:bytes": content.length, } if hasattr(swh_storage, "objstorage"): - assert cont["sha1"] in swh_storage.objstorage.objstorage + assert content.sha1 in swh_storage.objstorage.objstorage with db_transaction(swh_storage) as (_, cur): cur.execute( "SELECT sha1, sha1_git, sha256, length, status" " FROM content WHERE sha1 = %s", - (cont["sha1"],), + (content.sha1,), ) datum = cur.fetchone() assert datum == ( - cont["sha1"], - cont["sha1_git"], - cont["sha256"], - cont["length"], + content.sha1, + content.sha1_git, + content.sha256, + content.length, "visible", ) - expected_cont = cont.copy() - del expected_cont["data"] contents = [ obj for (obj_type, obj) in swh_storage.journal_writer.journal.objects if obj_type == "content" ] assert len(contents) == 1 - for obj in contents: - obj_d = obj.to_dict() - del obj_d["ctime"] - assert obj_d == expected_cont + assert contents[0] == attr.evolve(content, data=None) - def test_content_add_metadata_db(self, swh_storage): - cont = data.cont - del cont["data"] - cont["ctime"] = now() + def test_content_add_metadata_db(self, swh_storage, sample_data_model): + content = attr.evolve(sample_data_model["content"][0], data=None, ctime=now()) - actual_result = swh_storage.content_add_metadata([cont]) + actual_result = swh_storage.content_add_metadata([content]) assert actual_result == { "content:add": 1, } if hasattr(swh_storage, "objstorage"): - assert cont["sha1"] not in swh_storage.objstorage.objstorage + assert content.sha1 not in swh_storage.objstorage.objstorage with db_transaction(swh_storage) as (_, cur): cur.execute( "SELECT sha1, sha1_git, sha256, length, status" " FROM content WHERE sha1 = %s", - (cont["sha1"],), + (content.sha1,), ) datum = cur.fetchone() assert datum == ( - cont["sha1"], - cont["sha1_git"], - cont["sha256"], - cont["length"], + content.sha1, + content.sha1_git, + content.sha256, + content.length, "visible", ) @@ -4239,16 +4228,13 @@ if obj_type == "content" ] assert len(contents) == 1 - for obj in contents: - obj_d = obj.to_dict() - assert obj_d == cont + assert contents[0] == content - def test_skipped_content_add_db(self, swh_storage): - cont = data.skipped_cont - cont2 = data.skipped_cont2 - cont2["blake2s256"] = None + def test_skipped_content_add_db(self, swh_storage, sample_data_model): + content, cont2 = sample_data_model["skipped_content"][:2] + content2 = attr.evolve(cont2, blake2s256=None) - actual_result = swh_storage.skipped_content_add([cont, cont, cont2]) + actual_result = swh_storage.skipped_content_add([content, content, content2]) assert 2 <= actual_result.pop("skipped_content:add") <= 3 assert actual_result == {} @@ -4264,21 +4250,21 @@ assert len(dbdata) == 2 assert dbdata[0] == ( - cont["sha1"], - cont["sha1_git"], - cont["sha256"], - cont["blake2s256"], - cont["length"], + content.sha1, + content.sha1_git, + content.sha256, + content.blake2s256, + content.length, "absent", "Content too long", ) assert dbdata[1] == ( - cont2["sha1"], - cont2["sha1_git"], - cont2["sha256"], - cont2["blake2s256"], - cont2["length"], + content2.sha1, + content2.sha1_git, + content2.sha256, + content2.blake2s256, + content2.length, "absent", "Content too long", )