diff --git a/swh/storage/validate.py b/swh/storage/validate.py
--- a/swh/storage/validate.py
+++ b/swh/storage/validate.py
@@ -4,7 +4,7 @@
 # See top-level LICENSE file for more information
 
 import contextlib
-from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Type, TypeVar, Union
 
 from swh.model.model import (
     SkippedContent,
@@ -38,6 +38,30 @@
         raise StorageArgumentException(str(e))
 
 
+ModelObject = TypeVar(
+    "ModelObject",
+    Content,
+    SkippedContent,
+    Directory,
+    Revision,
+    Release,
+    Snapshot,
+    OriginVisit,
+    Origin,
+)
+
+
+def dict_converter(
+    model: Type[ModelObject], obj: Union[Dict, ModelObject]
+) -> ModelObject:
+    """Convert dicts to model objects; Passes through model objects as well."""
+    if isinstance(obj, dict):
+        with convert_validation_exceptions():
+            return model.from_dict(obj)
+    else:
+        return obj
+
+
 class ValidatingProxyStorage:
     """Storage implementation converts dictionaries to swh-model objects before
     calling its backend, and back to dicts before returning results
@@ -53,30 +77,30 @@
             raise AttributeError(key)
         return getattr(self.storage, key)
 
-    def content_add(self, content: Iterable[Dict]) -> Dict:
-        with convert_validation_exceptions():
-            contents = [Content.from_dict(c) for c in content]
-        return self.storage.content_add(contents)
+    def content_add(self, content: Iterable[Union[Content, Dict]]) -> Dict:
+        return self.storage.content_add([dict_converter(Content, c) for c in content])
 
-    def content_add_metadata(self, content: Iterable[Dict]) -> Dict:
-        with convert_validation_exceptions():
-            contents = [Content.from_dict(c) for c in content]
-        return self.storage.content_add_metadata(contents)
+    def content_add_metadata(self, content: Iterable[Union[Content, Dict]]) -> Dict:
+        return self.storage.content_add_metadata(
+            [dict_converter(Content, c) for c in content]
+        )
 
-    def skipped_content_add(self, content: Iterable[Dict]) -> Dict:
-        with convert_validation_exceptions():
-            contents = [SkippedContent.from_dict(c) for c in content]
-        return self.storage.skipped_content_add(contents)
+    def skipped_content_add(
+        self, content: Iterable[Union[SkippedContent, Dict]]
+    ) -> Dict:
+        return self.storage.skipped_content_add(
+            [dict_converter(SkippedContent, c) for c in content]
+        )
 
-    def directory_add(self, directories: Iterable[Dict]) -> Dict:
-        with convert_validation_exceptions():
-            directories = [Directory.from_dict(d) for d in directories]
-        return self.storage.directory_add(directories)
+    def directory_add(self, directories: Iterable[Union[Directory, Dict]]) -> Dict:
+        return self.storage.directory_add(
+            [dict_converter(Directory, d) for d in directories]
+        )
 
-    def revision_add(self, revisions: Iterable[Dict]) -> Dict:
-        with convert_validation_exceptions():
-            revisions = [Revision.from_dict(r) for r in revisions]
-        return self.storage.revision_add(revisions)
+    def revision_add(self, revisions: Iterable[Union[Revision, Dict]]) -> Dict:
+        return self.storage.revision_add(
+            [dict_converter(Revision, r) for r in revisions]
+        )
 
     def revision_get(self, revisions: Iterable[bytes]) -> Iterator[Optional[Dict]]:
         rev_dicts = self.storage.revision_get(revisions)
@@ -101,39 +125,24 @@
         for rev, parents in self.storage.revision_shortlog(revisions, limit):
             yield (rev, tuple(parents))
 
-    def release_add(self, releases: Iterable[Dict]) -> Dict:
-        with convert_validation_exceptions():
-            releases = [Release.from_dict(r) for r in releases]
-        return self.storage.release_add(releases)
+    def release_add(self, releases: Iterable[Union[Dict, Release]]) -> Dict:
+        return self.storage.release_add(
+            [dict_converter(Release, release) for release in releases]
+        )
 
-    def snapshot_add(self, snapshots: Iterable[Dict]) -> Dict:
-        with convert_validation_exceptions():
-            snapshots = [Snapshot.from_dict(s) for s in snapshots]
-        return self.storage.snapshot_add(snapshots)
+    def snapshot_add(self, snapshots: Iterable[Union[Dict, Snapshot]]) -> Dict:
+        return self.storage.snapshot_add(
+            [dict_converter(Snapshot, snapshot) for snapshot in snapshots]
+        )
 
     def origin_visit_add(self, visits: Iterable[OriginVisit]) -> Iterable[OriginVisit]:
         return self.storage.origin_visit_add(visits)
 
-    def origin_add(self, origins: Union[Iterable[Dict], Iterable[Origin]]) -> List:
-        origins_: List[Origin] = []
-        for o in origins:
-            ori: Origin
-            if isinstance(o, Dict):
-                with convert_validation_exceptions():
-                    ori = Origin.from_dict(o)
-            else:
-                ori = o
-            origins_.append(ori)
-        return self.storage.origin_add(origins_)
+    def origin_add(self, origins: Iterable[Union[Dict, Origin]]) -> List:
+        return self.storage.origin_add([dict_converter(Origin, o) for o in origins])
 
     def origin_add_one(self, origin: Union[Dict, Origin]) -> int:
-        origin_: Origin
-        if isinstance(origin, Dict):
-            with convert_validation_exceptions():
-                origin_ = Origin.from_dict(origin)
-        else:
-            origin_ = origin
-        return self.storage.origin_add_one(origin_)
+        return self.storage.origin_add_one(dict_converter(Origin, origin))
 
     def clear_buffers(self, object_types: Optional[Iterable[str]] = None) -> None:
         return self.storage.clear_buffers(object_types)