diff --git a/swh/provenance/tests/conftest.py b/swh/provenance/tests/conftest.py
--- a/swh/provenance/tests/conftest.py
+++ b/swh/provenance/tests/conftest.py
@@ -5,7 +5,7 @@
 
 from datetime import datetime, timedelta, timezone
 from os import path
-from typing import Any, Dict, Generator, Iterable
+from typing import Any, Dict, Generator, List
 
 from _pytest.fixtures import SubRequest
 import mongomock.database
@@ -15,12 +15,13 @@
 from pytest_postgresql.factories import postgresql
 
 from swh.journal.serializers import msgpack_ext_hook
+from swh.model.model import BaseModel
 from swh.provenance import get_provenance, get_provenance_storage
 from swh.provenance.archive import ArchiveInterface
 from swh.provenance.interface import ProvenanceInterface, ProvenanceStorageInterface
 from swh.provenance.storage.archive import ArchiveStorage
 from swh.storage.interface import StorageInterface
-from swh.storage.replay import OBJECT_CONVERTERS, process_replay_objects
+from swh.storage.replay import OBJECT_CONVERTERS, OBJECT_FIXERS, process_replay_objects
 
 
 @pytest.fixture(
@@ -100,12 +101,28 @@
     return ArchiveStorage(swh_storage)
 
 
+def fill_storage(storage: StorageInterface, data: Dict[str, List[dict]]) -> None:
+    """Replay the archive objects described by ``data`` into ``storage``.
+
+    ``data`` maps each object type to a list of dict representations (the
+    shape returned by :func:`load_repo_data`); each dict is turned into a
+    model object with :func:`objs_from_dict` before being replayed.
+    """
+    objects = {
+        objtype: [objs_from_dict(objtype, d) for d in dicts]
+        for objtype, dicts in data.items()
+    }
+    process_replay_objects(objects, storage=storage)
+
+
 def get_datafile(fname: str) -> str:
     return path.join(path.dirname(__file__), "data", fname)
 
 
-def load_repo_data(repo: str) -> Dict[str, Any]:
-    data: Dict[str, Any] = {}
+# TODO: this should return Dict[str, List[BaseModel]] directly, but it requires
+# refactoring several tests
+def load_repo_data(repo: str) -> Dict[str, List[dict]]:
+    data: Dict[str, List[dict]] = {}
     with open(get_datafile(f"{repo}.msgpack"), "rb") as fobj:
         unpacker = msgpack.Unpacker(
             fobj,
@@ -119,16 +136,15 @@
     return data
 
 
-def filter_dict(d: Dict[Any, Any], keys: Iterable[Any]) -> Dict[Any, Any]:
-    return {k: v for (k, v) in d.items() if k in keys}
-
-
-def fill_storage(storage: StorageInterface, data: Dict[str, Any]) -> None:
-    data = {
-        object_type: [OBJECT_CONVERTERS[object_type](d) for d in values]
-        for object_type, values in data.items()
-    }
-    process_replay_objects(data, storage=storage)
+def objs_from_dict(object_type: str, dict_repr: dict) -> BaseModel:
+    """Build the model object of type ``object_type`` from its dict form.
+
+    The dict is first passed through ``OBJECT_FIXERS[object_type]`` when such
+    a fixer exists, then converted with ``OBJECT_CONVERTERS[object_type]``.
+    """
+    if object_type in OBJECT_FIXERS:
+        dict_repr = OBJECT_FIXERS[object_type](dict_repr)
+    return OBJECT_CONVERTERS[object_type](dict_repr)
 
 
 # TODO: remove this function in favour of TimestampWithTimezone.to_datetime