diff --git a/swh/deposit/api/private/__init__.py b/swh/deposit/api/private/__init__.py
--- a/swh/deposit/api/private/__init__.py
+++ b/swh/deposit/api/private/__init__.py
@@ -1,15 +1,16 @@
-# Copyright (C) 2017-2019 The Software Heritage developers
+# Copyright (C) 2017-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-from swh.deposit import utils
+from typing import Dict, List
+
+from rest_framework.permissions import AllowAny
 
 from ...config import METADATA_TYPE, SWHDefaultConfig
 from ...models import DepositRequest, Deposit
 
-from rest_framework.permissions import AllowAny
-
+from swh.deposit import utils
 from swh.deposit.api.common import SWHAPIView
 from swh.deposit.errors import make_error_dict, NOT_FOUND
 
@@ -40,7 +41,7 @@
         for deposit_request in deposit_requests:
             yield deposit_request
 
-    def _metadata_get(self, deposit):
+    def _metadata_get(self, deposit: Deposit) -> Dict:
         """Given a deposit, aggregate all metadata requests.
 
         Args:
@@ -57,6 +58,22 @@
         )
         return utils.merge(*metadata)
 
+    def _raw_metadata_get(self, deposit: Deposit) -> List[bytes]:
+        """Given a deposit, returns the metadata as they were received.
+
+        Args:
+            deposit (Deposit): The deposit instance to extract
+                metadata from.
+
+        Returns:
+            list of deposited metadata
+
+        """
+        return [
+            m.raw_metadata.encode("utf8")
+            for m in self._deposit_requests(deposit, request_type=METADATA_TYPE)
+        ]
+
 
 class SWHPrivateAPIView(SWHDefaultConfig, SWHAPIView):
     """Mixin intended as private api (so no authentication) based API view
diff --git a/swh/deposit/tests/loader/common.py b/swh/deposit/tests/loader/common.py
--- a/swh/deposit/tests/loader/common.py
+++ b/swh/deposit/tests/loader/common.py
@@ -5,10 +5,12 @@
 
 import json
 
-from typing import Dict
+from typing import Dict, Optional
 
 from swh.deposit.client import PrivateApiDepositClient
 from swh.model.hashutil import hash_to_bytes, hash_to_hex
+from swh.model.model import SnapshotBranch, TargetType
+from swh.storage.algos.snapshot import snapshot_get_all_branches
 
 CLIENT_TEST_CONFIG = {
     "url": "http://nowhere:9000/",
@@ -85,18 +87,18 @@
     return {k: stats.get(k) for k in keys}
 
 
-def decode_target(target):
+def decode_target(branch: Optional[SnapshotBranch]) -> Optional[Dict]:
     """Test helper to ease readability in test
 
     """
-    if not target:
-        return target
-    target_type = target["target_type"]
+    if not branch:
+        return None
+    target_type = branch.target_type
 
-    if target_type == "alias":
-        decoded_target = target["target"].decode("utf-8")
+    if target_type == TargetType.ALIAS:
+        decoded_target = branch.target.decode("utf-8")
     else:
-        decoded_target = hash_to_hex(target["target"])
+        decoded_target = hash_to_hex(branch.target)
 
-    return {"target": decoded_target, "target_type": target_type}
+    return {"target": decoded_target, "target_type": target_type.value}
 
@@ -114,7 +116,7 @@
     """
     expected_snapshot_id = expected_snapshot["id"]
     expected_branches = expected_snapshot["branches"]
-    snap = storage.snapshot_get(hash_to_bytes(expected_snapshot_id))
+    snap = snapshot_get_all_branches(storage, hash_to_bytes(expected_snapshot_id))
     if snap is None:
         # display known snapshots instead if possible
         if hasattr(storage, "_snapshots"):  # in-mem storage
@@ -132,7 +134,7 @@
         raise AssertionError("Snapshot is not found")
 
     branches = {
-        branch.decode("utf-8"): decode_target(target)
-        for branch, target in snap["branches"].items()
+        branch_name.decode("utf-8"): decode_target(branch)
+        for branch_name, branch in snap.branches.items()
     }
     assert expected_branches == branches