diff --git a/swh/loader/tests/__init__.py b/swh/loader/tests/__init__.py
--- a/swh/loader/tests/__init__.py
+++ b/swh/loader/tests/__init__.py
@@ -7,11 +7,12 @@
 import subprocess
 
 from pathlib import PosixPath
-from typing import Dict, Optional, Union
+from typing import Any, Dict, Optional, Union
 
-from swh.model.model import OriginVisitStatus
+from swh.model.model import OriginVisitStatus, Snapshot
 from swh.model.hashutil import hash_to_bytes, hash_to_hex
 
+from swh.storage.interface import StorageInterface
 from swh.storage.algos.origin import origin_get_latest_visit_status
 
 
@@ -81,48 +82,70 @@
     return repo_url
 
 
-def decode_target(target):
+def decode_target(target: Optional[Dict]) -> Optional[Dict]:
     """Test helper to ease readability in test
 
     """
     if not target:
         return target
     target_type = target["target_type"]
-
-    if target_type == "alias":
-        decoded_target = target["target"].decode("utf-8")
+    target_data = target["target"]
+    if target_type == "alias" and isinstance(target_data, bytes):
+        decoded_target = target_data.decode("utf-8")
+    elif isinstance(target_data, bytes):
+        decoded_target = hash_to_hex(target_data)
     else:
-        decoded_target = hash_to_hex(target["target"])
+        decoded_target = target_data
 
     return {"target": decoded_target, "target_type": target_type}
 
 
-def check_snapshot(expected_snapshot, storage):
+def check_snapshot(
+    snapshot: Union[Dict[str, Any], Snapshot], storage: StorageInterface
+):
     """Check for snapshot match.
 
-    Provide the hashes as hexadecimal, the conversion is done
-    within the method.
+    Hashes (snapshot id, branch targets) may be given either as hex strings or as
+    bytes; both sides are normalized before being compared.
 
     Args:
-        expected_snapshot (dict): full snapshot with hex ids
-        storage (Storage): expected storage
+        snapshot: expected snapshot, as a Snapshot model object or its dict form
+        storage: storage to look the snapshot up in
 
     Returns:
         the snapshot stored in the storage for further test assertion if any is
         needed.
+
+    Raises:
+        AssertionError: if ``snapshot`` is neither a Snapshot nor a dict, if no
+            snapshot with that id exists in storage, or if the branches differ.
 
     """
+    if isinstance(snapshot, Snapshot):
+        expected_snapshot = snapshot.to_dict()
+    elif isinstance(snapshot, dict):
+        expected_snapshot = snapshot
+    else:
+        raise AssertionError(f"variable 'snapshot' must be a snapshot: {snapshot!r}")
+
-    expected_snapshot_id = expected_snapshot["id"]
-    expected_branches = expected_snapshot["branches"]
-    snap = storage.snapshot_get(hash_to_bytes(expected_snapshot_id))
+    # normalize the id (hex str or bytes) once, for both the lookup and the message
+    snapshot_id = hash_to_bytes(expected_snapshot["id"])
+    snap = storage.snapshot_get(snapshot_id)
     if snap is None:
-        raise AssertionError(f"Snapshot {expected_snapshot_id} is not found")
+        raise AssertionError(f"Snapshot {hash_to_hex(snapshot_id)} is not found")
 
-    branches = {
-        branch.decode("utf-8"): decode_target(target)
-        for branch, target in snap["branches"].items()
-    }
-    assert expected_branches == branches
+    expected_branches = {}
+    for branch, target in expected_snapshot["branches"].items():
+        if isinstance(branch, bytes):
+            branch = branch.decode("utf-8")
+        expected_branches[branch] = decode_target(target)
+
+    snapshot_branches = {}
+    for branch, target in snap["branches"].items():
+        if isinstance(branch, bytes):
+            branch = branch.decode("utf-8")
+        snapshot_branches[branch] = decode_target(target)
+    assert expected_branches == snapshot_branches
     return snap
 
 
diff --git a/swh/loader/tests/test_init.py b/swh/loader/tests/test_init.py
--- a/swh/loader/tests/test_init.py
+++ b/swh/loader/tests/test_init.py
@@ -177,29 +177,30 @@
 
 
 def test_decode_target():
-    actual_alias_decode_target = decode_target(
-        {"target_type": "alias", "target": b"something",}
-    )
-
-    assert actual_alias_decode_target == {
-        "target_type": "alias",
-        "target": "something",
-    }
-
-    actual_decode_target = decode_target(
-        {"target_type": "revision", "target": hash_to_bytes(hash_hex),}
-    )
-
-    assert actual_decode_target == {
-        "target_type": "revision",
-        "target": hash_hex,
-    }
+    for target_alias in ["something", b"something"]:
+        target = {
+            "target_type": "alias",
+            "target": target_alias,
+        }
+        actual_alias_decode_target = decode_target(target)
+        assert actual_alias_decode_target == {
+            "target_type": "alias",
+            "target": "something",
+        }
+
+    for hash_ in [hash_hex, hash_to_bytes(hash_hex)]:
+        target = {"target_type": "revision", "target": hash_}
+        actual_decode_target = decode_target(target)
+        assert actual_decode_target == {
+            "target_type": "revision",
+            "target": hash_hex,
+        }
 
 
 def test_check_snapshot(swh_storage):
-    snap_id = "2498dbf535f882bc7f9a18fb16c9ad27fda7bab7"
+    """check_snapshot should not raise when everything is fine"""
     snapshot = Snapshot(
-        id=hash_to_bytes(snap_id),
+        id=hash_to_bytes("2498dbf535f882bc7f9a18fb16c9ad27fda7bab7"),
         branches={
             b"master": SnapshotBranch(
                 target=hash_to_bytes(hash_hex), target_type=TargetType.REVISION,
@@ -212,16 +213,15 @@
         "snapshot:add": 1,
     }
 
-    expected_snapshot = {
-        "id": snap_id,
-        "branches": {"master": {"target": hash_hex, "target_type": "revision",}},
-    }
-    check_snapshot(expected_snapshot, swh_storage)
+    for snap in [snapshot, snapshot.to_dict()]:
+        check_snapshot(snap, swh_storage)
 
 
 def test_check_snapshot_failure(swh_storage):
+    """check_snapshot should raise if something goes wrong"""
+    snap_id_hex = "2498dbf535f882bc7f9a18fb16c9ad27fda7bab7"
     snapshot = Snapshot(
-        id=hash_to_bytes("2498dbf535f882bc7f9a18fb16c9ad27fda7bab7"),
+        id=hash_to_bytes(snap_id_hex),
         branches={
             b"master": SnapshotBranch(
                 target=hash_to_bytes(hash_hex), target_type=TargetType.REVISION,
@@ -241,10 +241,20 @@
         },
     }
 
-    with pytest.raises(AssertionError, match="Differing items"):
-        check_snapshot(unexpected_snapshot, swh_storage)
+    # id is correct, the branch is wrong, that should raise nonetheless
+    for snap_id in [snap_id_hex, snapshot.id]:
+        unexpected_snapshot["id"] = snap_id
+        with pytest.raises(AssertionError, match="Differing items"):
+            check_snapshot(unexpected_snapshot, swh_storage)
 
     # snapshot id which does not exist
-    unexpected_snapshot["id"] = "999666f535f882bc7f9a18fb16c9ad27fda7bab7"
-    with pytest.raises(AssertionError, match="is not found"):
-        check_snapshot(unexpected_snapshot, swh_storage)
+    wrong_snap_id_hex = "999666f535f882bc7f9a18fb16c9ad27fda7bab7"
+    not_found_msg = f"Snapshot {wrong_snap_id_hex} is not found"
+    for snap_id in [wrong_snap_id_hex, hash_to_bytes(wrong_snap_id_hex)]:
+        unexpected_snapshot["id"] = snap_id
+        with pytest.raises(AssertionError, match=not_found_msg):
+            check_snapshot(unexpected_snapshot, swh_storage)
+
+    # not a Snapshot object, raise!
+    with pytest.raises(AssertionError, match="variable 'snapshot' must be a snapshot"):
+        check_snapshot(ORIGIN_VISIT, swh_storage)