diff --git a/swh/storage/validate.py b/swh/storage/validate.py
--- a/swh/storage/validate.py
+++ b/swh/storage/validate.py
@@ -4,7 +4,7 @@
 # See top-level LICENSE file for more information
 
 import contextlib
-from typing import Dict, Iterable, Iterator, List, Optional, Tuple
+from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union
 
 from swh.model.model import (
     SkippedContent,
@@ -42,6 +42,7 @@
     """Storage implementation converts dictionaries to swh-model objects
     before calling its backend, and back to dicts before returning results
 
+    For test purposes.
     """
 
     def __init__(self, storage):
@@ -113,15 +114,24 @@
     def origin_visit_add(self, visits: Iterable[OriginVisit]) -> Iterable[OriginVisit]:
         return self.storage.origin_visit_add(visits)
 
-    def origin_add(self, origins: Iterable[Dict]) -> List[Dict]:
-        with convert_validation_exceptions():
-            origins = [Origin.from_dict(o) for o in origins]
-        return self.storage.origin_add(origins)
-
-    def origin_add_one(self, origin: Dict) -> int:
-        with convert_validation_exceptions():
-            origin = Origin.from_dict(origin)
-        return self.storage.origin_add_one(origin)
+    @staticmethod
+    def _to_origin(origin: Union[Dict, Origin]) -> Origin:
+        """Return ``origin`` as an Origin object.
+
+        Dicts are parsed with ``Origin.from_dict`` inside
+        ``convert_validation_exceptions()``; any other value is assumed to
+        already be an Origin and is passed through unchanged.
+        """
+        if isinstance(origin, dict):
+            with convert_validation_exceptions():
+                return Origin.from_dict(origin)
+        return origin
+
+    def origin_add(self, origins: Union[Iterable[Dict], Iterable[Origin]]) -> List:
+        return self.storage.origin_add([self._to_origin(o) for o in origins])
+
+    def origin_add_one(self, origin: Union[Dict, Origin]) -> int:
+        return self.storage.origin_add_one(self._to_origin(origin))
 
     def clear_buffers(self, object_types: Optional[Iterable[str]] = None) -> None:
         return self.storage.clear_buffers(object_types)