diff --git a/swh/loader/tests/__init__.py b/swh/loader/tests/__init__.py
--- a/swh/loader/tests/__init__.py
+++ b/swh/loader/tests/__init__.py
@@ -7,11 +7,12 @@
 import subprocess
 
 from pathlib import PosixPath
-from typing import Dict, Optional, Union
+from typing import Any, Dict, Optional, Union
 
-from swh.model.model import OriginVisitStatus
-from swh.model.hashutil import hash_to_bytes, hash_to_hex
+from swh.model.model import OriginVisitStatus, Snapshot
+from swh.model.hashutil import hash_to_bytes
 
+from swh.storage.interface import StorageInterface
 from swh.storage.algos.origin import origin_get_latest_visit_status
 
 
@@ -81,49 +82,67 @@
     return repo_url
 
 
-def decode_target(target):
+def encode_target(target: Optional[Dict]) -> Optional[Dict]:
     """Test helper to ease readability in test
 
     """
     if not target:
         return target
     target_type = target["target_type"]
-
-    if target_type == "alias":
-        decoded_target = target["target"].decode("utf-8")
+    target_data = target["target"]
+    if target_type == "alias" and isinstance(target_data, str):
+        encoded_target = target_data.encode("utf-8")
+    elif isinstance(target_data, str):
+        encoded_target = hash_to_bytes(target_data)
     else:
-        decoded_target = hash_to_hex(target["target"])
+        encoded_target = target_data
 
-    return {"target": decoded_target, "target_type": target_type}
+    return {"target": encoded_target, "target_type": target_type}
 
 
-def check_snapshot(expected_snapshot, storage):
+def check_snapshot(
+    snapshot: Union[Dict[str, Any], Snapshot], storage: StorageInterface
+):
     """Check for snapshot match.
 
-    Provide the hashes as hexadecimal, the conversion is done
-    within the method.
+    Hashes (snapshot id and branch targets) may be given either as hex strings or
+    as bytes; they are converted as needed before the comparison.
 
     Args:
-        expected_snapshot (dict): full snapshot with hex ids
-        storage (Storage): expected storage
+        snapshot: full snapshot (object or dict) to check for existence and equality
+        storage: storage in which to look the snapshot up
 
     Returns:
         the snapshot stored in the storage for further test assertion if any is
         needed.
 
     """
-    expected_snapshot_id = expected_snapshot["id"]
-    expected_branches = expected_snapshot["branches"]
-    snap = storage.snapshot_get(hash_to_bytes(expected_snapshot_id))
+    if isinstance(snapshot, Snapshot):
+        expected_snapshot = snapshot
+    elif isinstance(snapshot, dict):
+        # dict must be snapshot compliant (hex or bytes ids, str or bytes names)
+        snapshot_dict = {"id": hash_to_bytes(snapshot["id"])}
+        branches = {}
+        for branch, target in snapshot["branches"].items():
+            if isinstance(branch, str):
+                branch = branch.encode("utf-8")
+            branches[branch] = encode_target(target)
+        snapshot_dict["branches"] = branches
+        expected_snapshot = Snapshot.from_dict(snapshot_dict)
+    else:
+        raise AssertionError(f"variable 'snapshot' must be a snapshot: {snapshot!r}")
+
+    snap = storage.snapshot_get(expected_snapshot.id)
     if snap is None:
-        raise AssertionError(f"Snapshot {expected_snapshot_id} is not found")
-
-    branches = {
-        branch.decode("utf-8"): decode_target(target)
-        for branch, target in snap["branches"].items()
-    }
-    assert expected_branches == branches
-    return snap
+        raise AssertionError(f"Snapshot {expected_snapshot.id.hex()} is not found")
+
+    assert snap["next_branch"] is None  # we don't deal with large snapshots in tests
+    snap_dict = {k: v for k, v in snap.items() if k != "next_branch"}
+    actual_snap = Snapshot.from_dict(snap_dict)
+
+    assert expected_snapshot == actual_snap
+
+    return snap  # raw storage dict for retro compat; remove once clients migrate
 
 
 def get_stats(storage) -> Dict:
diff --git a/swh/loader/tests/test_init.py b/swh/loader/tests/test_init.py
--- a/swh/loader/tests/test_init.py
+++ b/swh/loader/tests/test_init.py
@@ -9,7 +9,6 @@
 import os
 import subprocess
 
-from swh.loader.tests import prepare_repository_from_archive, assert_last_visit_matches
 from swh.model.model import (
     OriginVisit,
     OriginVisitStatus,
@@ -20,8 +19,10 @@
 from swh.model.hashutil import hash_to_bytes
 
 from swh.loader.tests import (
-    decode_target,
+    assert_last_visit_matches,
+    encode_target,
     check_snapshot,
+    prepare_repository_from_archive,
 )
 
 
@@ -172,34 +173,33 @@
     assert os.path.exists(expected_uncompressed_archive_path)
 
 
-def test_decode_target_edge():
-    assert not decode_target(None)
+def test_encode_target():
+    assert encode_target(None) is None
 
+    for target_alias in ["something", b"something"]:
+        target = {
+            "target_type": "alias",
+            "target": target_alias,
+        }
+        actual_alias_encode_target = encode_target(target)
+        assert actual_alias_encode_target == {
+            "target_type": "alias",
+            "target": b"something",
+        }
 
-def test_decode_target():
-    actual_alias_decode_target = decode_target(
-        {"target_type": "alias", "target": b"something",}
-    )
-
-    assert actual_alias_decode_target == {
-        "target_type": "alias",
-        "target": "something",
-    }
-
-    actual_decode_target = decode_target(
-        {"target_type": "revision", "target": hash_to_bytes(hash_hex),}
-    )
-
-    assert actual_decode_target == {
-        "target_type": "revision",
-        "target": hash_hex,
-    }
+    for hash_ in [hash_hex, hash_to_bytes(hash_hex)]:
+        target = {"target_type": "revision", "target": hash_}
+        actual_encode_target = encode_target(target)
+        assert actual_encode_target == {
+            "target_type": "revision",
+            "target": hash_to_bytes(hash_hex),
+        }
 
 
 def test_check_snapshot(swh_storage):
-    snap_id = "2498dbf535f882bc7f9a18fb16c9ad27fda7bab7"
+    """check_snapshot should not raise when the snapshot matches storage"""
     snapshot = Snapshot(
-        id=hash_to_bytes(snap_id),
+        id=hash_to_bytes("2498dbf535f882bc7f9a18fb16c9ad27fda7bab7"),
         branches={
             b"master": SnapshotBranch(
                 target=hash_to_bytes(hash_hex), target_type=TargetType.REVISION,
@@ -212,16 +212,15 @@
         "snapshot:add": 1,
     }
 
-    expected_snapshot = {
-        "id": snap_id,
-        "branches": {"master": {"target": hash_hex, "target_type": "revision",}},
-    }
-    check_snapshot(expected_snapshot, swh_storage)
+    for snap in [snapshot, snapshot.to_dict()]:
+        check_snapshot(snap, swh_storage)
 
 
 def test_check_snapshot_failure(swh_storage):
+    """check_snapshot should raise if something goes wrong"""
+    snap_id_hex = "2498dbf535f882bc7f9a18fb16c9ad27fda7bab7"
     snapshot = Snapshot(
-        id=hash_to_bytes("2498dbf535f882bc7f9a18fb16c9ad27fda7bab7"),
+        id=hash_to_bytes(snap_id_hex),
         branches={
             b"master": SnapshotBranch(
                 target=hash_to_bytes(hash_hex), target_type=TargetType.REVISION,
@@ -241,10 +240,19 @@
         },
     }
 
-    with pytest.raises(AssertionError, match="Differing items"):
-        check_snapshot(unexpected_snapshot, swh_storage)
+    # id is correct (hex or bytes) but the branch is wrong: that must still raise
+    for snap_id in [snap_id_hex, snapshot.id]:
+        unexpected_snapshot["id"] = snap_id
+        with pytest.raises(AssertionError, match="Differing attributes"):
+            check_snapshot(unexpected_snapshot, swh_storage)
 
     # snapshot id which does not exist
-    unexpected_snapshot["id"] = "999666f535f882bc7f9a18fb16c9ad27fda7bab7"
-    with pytest.raises(AssertionError, match="is not found"):
-        check_snapshot(unexpected_snapshot, swh_storage)
+    wrong_snap_id_hex = "999666f535f882bc7f9a18fb16c9ad27fda7bab7"
+    for snap_id in [wrong_snap_id_hex, hash_to_bytes(wrong_snap_id_hex)]:
+        unexpected_snapshot["id"] = snap_id
+        with pytest.raises(AssertionError, match="is not found"):
+            check_snapshot(unexpected_snapshot, swh_storage)
+
+    # not a Snapshot object nor a dict: must raise
+    with pytest.raises(AssertionError, match="variable 'snapshot' must be a snapshot"):
+        check_snapshot(ORIGIN_VISIT, swh_storage)