diff --git a/swh/storage/algos/origin.py b/swh/storage/algos/origin.py --- a/swh/storage/algos/origin.py +++ b/swh/storage/algos/origin.py @@ -3,7 +3,7 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from typing import Optional, Iterable, Iterator, Tuple +from typing import Iterator, List, Optional, Tuple from swh.model.model import Origin, OriginVisit, OriginVisitStatus from swh.storage.interface import StorageInterface @@ -50,7 +50,7 @@ storage: StorageInterface, origin_url: str, type: Optional[str] = None, - allowed_statuses: Optional[Iterable[str]] = None, + allowed_statuses: Optional[List[str]] = None, require_snapshot: bool = False, ) -> Optional[Tuple[OriginVisit, OriginVisitStatus]]: """Get the latest origin visit (and status) of an origin. Optionally, a combination of diff --git a/swh/storage/algos/snapshot.py b/swh/storage/algos/snapshot.py --- a/swh/storage/algos/snapshot.py +++ b/swh/storage/algos/snapshot.py @@ -3,7 +3,7 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from typing import Iterable, Optional +from typing import List, Optional from swh.model.model import Snapshot @@ -39,7 +39,7 @@ def snapshot_get_latest( storage, origin: str, - allowed_statuses: Optional[Iterable[str]] = None, + allowed_statuses: Optional[List[str]] = None, branches_count: Optional[int] = None, ) -> Optional[Snapshot]: """Get the latest snapshot for the given origin, optionally only from visits that have diff --git a/swh/storage/buffer.py b/swh/storage/buffer.py --- a/swh/storage/buffer.py +++ b/swh/storage/buffer.py @@ -9,6 +9,7 @@ from swh.core.utils import grouper from swh.model.model import Content, BaseModel from swh.storage import get_storage +from swh.storage.interface import StorageInterface class BufferingProxyStorage: @@ -36,7 +37,7 @@ """ def __init__(self, storage, min_batch_size=None): - self.storage = get_storage(**storage) + self.storage: StorageInterface = get_storage(**storage) if min_batch_size is None: min_batch_size = {} @@ -67,7 +68,7 @@ raise AttributeError(key) return getattr(self.storage, key) - def content_add(self, content: Iterable[Content]) -> Dict: + def content_add(self, content: List[Content]) -> Dict: """Enqueue contents to write to the storage. Following policies apply: @@ -79,7 +80,6 @@ threshold is hit. If it is flush content to the storage. """ - content = list(content) s = self.object_add( content, object_type="content", @@ -93,14 +93,14 @@ return s - def skipped_content_add(self, content: Iterable[Content]) -> Dict: + def skipped_content_add(self, content: List[Content]) -> Dict: return self.object_add( content, object_type="skipped_content", keys=["sha1", "sha1_git", "sha256", "blake2s256"], ) - def flush(self, object_types: Optional[Iterable[str]] = None) -> Dict: + def flush(self, object_types: Optional[List[str]] = None) -> Dict: summary: Dict[str, int] = self.storage.flush(object_types) if object_types is None: object_types = self.object_types @@ -132,7 +132,7 @@ return {} - def clear_buffers(self, object_types: Optional[Iterable[str]] = None) -> None: + def clear_buffers(self, object_types: Optional[List[str]] = None) -> None: """Clear objects from current buffer. WARNING: diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -146,7 +146,7 @@ return summary - def content_add(self, content: Iterable[Content]) -> Dict: + def content_add(self, content: List[Content]) -> Dict: contents = [attr.evolve(c, ctime=now()) for c in content] return self._content_add(list(contents), with_data=True) @@ -155,8 +155,8 @@ "content_update is not supported by the Cassandra backend" ) - def content_add_metadata(self, content: Iterable[Content]) -> Dict: - return self._content_add(list(content), with_data=False) + def content_add_metadata(self, content: List[Content]) -> Dict: + return self._content_add(content, with_data=False) def content_get(self, content): if len(content) > BULK_BLOCK_CONTENT_LEN_MAX: @@ -280,7 +280,7 @@ if getattr(row, algo) == hash_: yield row - def _skipped_content_add(self, contents: Iterable[SkippedContent]) -> Dict: + def _skipped_content_add(self, contents: List[SkippedContent]) -> Dict: # Filter-out content already in the database. contents = [ c @@ -305,7 +305,7 @@ return {"skipped_content:add": len(contents)} - def skipped_content_add(self, content: Iterable[SkippedContent]) -> Dict: + def skipped_content_add(self, content: List[SkippedContent]) -> Dict: contents = [attr.evolve(c, ctime=now()) for c in content] return self._skipped_content_add(contents) @@ -314,9 +314,7 @@ if not self._cql_runner.skipped_content_get_from_pk(content): yield {algo: content[algo] for algo in DEFAULT_ALGORITHMS} - def directory_add(self, directories: Iterable[Directory]) -> Dict: - directories = list(directories) - + def directory_add(self, directories: List[Directory]) -> Dict: # Filter out directories that are already inserted. missing = self.directory_missing([dir_.id for dir_ in directories]) directories = [dir_ for dir_ in directories if dir_.id in missing] @@ -419,9 +417,7 @@ def directory_get_random(self): return self._cql_runner.directory_get_random().id - def revision_add(self, revisions: Iterable[Revision]) -> Dict: - revisions = list(revisions) - + def revision_add(self, revisions: List[Revision]) -> Dict: # Filter-out revisions already in the database missing = self.revision_missing([rev.id for rev in revisions]) revisions = [rev for rev in revisions if rev.id in missing] @@ -510,7 +506,7 @@ def revision_get_random(self): return self._cql_runner.revision_get_random().id - def release_add(self, releases: Iterable[Release]) -> Dict: + def release_add(self, releases: List[Release]) -> Dict: to_add = [] for rel in releases: if rel not in to_add: @@ -542,8 +538,7 @@ def release_get_random(self): return self._cql_runner.release_get_random().id - def snapshot_add(self, snapshots: Iterable[Snapshot]) -> Dict: - snapshots = list(snapshots) + def snapshot_add(self, snapshots: List[Snapshot]) -> Dict: missing = self._cql_runner.snapshot_missing([snp.id for snp in snapshots]) snapshots = [snp for snp in snapshots if snp.id in missing] @@ -685,7 +680,7 @@ return results - def origin_get(self, origins: Iterable[str]) -> Iterable[Optional[Origin]]: + def origin_get(self, origins: List[str]) -> Iterable[Optional[Origin]]: return [self.origin_get_one(origin) for origin in origins] def origin_get_one(self, origin_url: str) -> Optional[Origin]: @@ -746,15 +741,14 @@ return [{"url": orig.url,} for orig in origins[offset : offset + limit]] - def origin_add(self, origins: Iterable[Origin]) -> Dict[str, int]: - origins = list(origins) + def origin_add(self, origins: List[Origin]) -> Dict[str, int]: to_add = [ori for ori in origins if self.origin_get_one(ori.url) is None] self.journal_writer.origin_add(to_add) for origin in to_add: self._cql_runner.origin_add_one(origin) return {"origin:add": len(to_add)} - def origin_visit_add(self, visits: Iterable[OriginVisit]) -> Iterable[OriginVisit]: + def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]: for visit in visits: origin = self.origin_get_one(visit.origin) if not origin: # Cannot add a visit without an origin @@ -790,9 +784,7 @@ self.journal_writer.origin_visit_status_add([visit_status]) self._cql_runner.origin_visit_status_add_one(visit_status) - def origin_visit_status_add( - self, visit_statuses: Iterable[OriginVisitStatus] - ) -> None: + def origin_visit_status_add(self, visit_statuses: List[OriginVisitStatus]) -> None: # First round to check existence (fail early if any is ko) for visit_status in visit_statuses: origin_url = self.origin_get_one(visit_status.origin) @@ -969,10 +961,7 @@ def refresh_stat_counters(self): pass - def raw_extrinsic_metadata_add( - self, metadata: Iterable[RawExtrinsicMetadata] - ) -> None: - metadata = list(metadata) + def raw_extrinsic_metadata_add(self, metadata: List[RawExtrinsicMetadata]) -> None: self.journal_writer.raw_extrinsic_metadata_add(metadata) for metadata_entry in metadata: if not self._cql_runner.metadata_authority_get( @@ -1109,8 +1098,7 @@ "results": results, } - def metadata_fetcher_add(self, fetchers: Iterable[MetadataFetcher]) -> None: - fetchers = list(fetchers) + def metadata_fetcher_add(self, fetchers: List[MetadataFetcher]) -> None: self.journal_writer.metadata_fetcher_add(fetchers) for fetcher in fetchers: self._cql_runner.metadata_fetcher_add( @@ -1132,8 +1120,7 @@ else: return None - def metadata_authority_add(self, authorities: Iterable[MetadataAuthority]) -> None: - authorities = list(authorities) + def metadata_authority_add(self, authorities: List[MetadataAuthority]) -> None: self.journal_writer.metadata_authority_add(authorities) for authority in authorities: self._cql_runner.metadata_authority_add( @@ -1155,11 +1142,11 @@ else: return None - def clear_buffers(self, object_types: Optional[Iterable[str]] = None) -> None: + def clear_buffers(self, object_types: Optional[List[str]] = None) -> None: """Do nothing """ return None - def flush(self, object_types: Optional[Iterable[str]] = None) -> Dict: + def flush(self, object_types: Optional[List[str]] = None) -> Dict: return {} diff --git a/swh/storage/filter.py b/swh/storage/filter.py --- a/swh/storage/filter.py +++ b/swh/storage/filter.py @@ -4,7 +4,7 @@ # See top-level LICENSE file for more information -from typing import Dict, Iterable, Set +from typing import Dict, Iterable, List, Set from swh.model.model import ( Content, @@ -14,6 +14,7 @@ ) from swh.storage import get_storage +from swh.storage.interface import StorageInterface class FilteringProxyStorage: @@ -35,38 +36,34 @@ object_types = ["content", "skipped_content", "directory", "revision"] def __init__(self, storage): - self.storage = get_storage(**storage) + self.storage: StorageInterface = get_storage(**storage) def __getattr__(self, key): if key == "storage": raise AttributeError(key) return getattr(self.storage, key) - def content_add(self, content: Iterable[Content]) -> Dict: - contents = list(content) - contents_to_add = self._filter_missing_contents(contents) + def content_add(self, content: List[Content]) -> Dict: + contents_to_add = self._filter_missing_contents(content) return self.storage.content_add( - x for x in contents if x.sha256 in contents_to_add + x for x in content if x.sha256 in contents_to_add ) - def skipped_content_add(self, content: Iterable[SkippedContent]) -> Dict: - contents = list(content) - contents_to_add = self._filter_missing_skipped_contents(contents) + def skipped_content_add(self, content: List[SkippedContent]) -> Dict: + contents_to_add = self._filter_missing_skipped_contents(content) return self.storage.skipped_content_add( - x for x in contents if x.sha1_git is None or x.sha1_git in contents_to_add + x for x in content if x.sha1_git is None or x.sha1_git in contents_to_add ) - def directory_add(self, directories: Iterable[Directory]) -> Dict: - directories = list(directories) + def directory_add(self, directories: List[Directory]) -> Dict: missing_ids = self._filter_missing_ids("directory", (d.id for d in directories)) return self.storage.directory_add(d for d in directories if d.id in missing_ids) - def revision_add(self, revisions: Iterable[Revision]) -> Dict: - revisions = list(revisions) + def revision_add(self, revisions: List[Revision]) -> Dict: missing_ids = self._filter_missing_ids("revision", (r.id for r in revisions)) return self.storage.revision_add(r for r in revisions if r.id in missing_ids) - def _filter_missing_contents(self, contents: Iterable[Content]) -> Set[bytes]: + def _filter_missing_contents(self, contents: List[Content]) -> Set[bytes]: """Return only the content keys missing from swh Args: @@ -81,7 +78,7 @@ return set(self.storage.content_missing(missing_contents, key_hash="sha256",)) def _filter_missing_skipped_contents( - self, contents: Iterable[SkippedContent] + self, contents: List[SkippedContent] ) -> Set[bytes]: """Return only the content keys missing from swh @@ -106,7 +103,7 @@ Args: object_type: object type to use {revision, directory} - ids: Iterable of object_type ids + ids: List of object_type ids Returns: Missing ids from the storage for object_type diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -176,7 +176,7 @@ def check_config(self, *, check_write): return True - def _content_add(self, contents: Iterable[Content], with_data: bool) -> Dict: + def _content_add(self, contents: List[Content], with_data: bool) -> Dict: self.journal_writer.content_add(contents) content_add = 0 @@ -220,7 +220,7 @@ return summary - def content_add(self, content: Iterable[Content]) -> Dict: + def content_add(self, content: List[Content]) -> Dict: content = [attr.evolve(c, ctime=now()) for c in content] return self._content_add(content, with_data=True) @@ -246,7 +246,7 @@ hash_ = new_cont.get_hash(algorithm) self._content_indexes[algorithm][hash_].add(new_key) - def content_add_metadata(self, content: Iterable[Content]) -> Dict: + def content_add_metadata(self, content: List[Content]) -> Dict: return self._content_add(content, with_data=False) def content_get(self, content): @@ -381,7 +381,7 @@ return summary - def skipped_content_add(self, content: Iterable[SkippedContent]) -> Dict: + def skipped_content_add(self, content: List[SkippedContent]) -> Dict: content = [attr.evolve(c, ctime=now()) for c in content] return self._skipped_content_add(content) @@ -399,7 +399,7 @@ if not matches: yield {algo: content[algo] for algo in DEFAULT_ALGORITHMS} - def directory_add(self, directories: Iterable[Directory]) -> Dict: + def directory_add(self, directories: List[Directory]) -> Dict: directories = [dir_ for dir_ in directories if dir_.id not in self._directories] self.journal_writer.directory_add(directories) @@ -486,7 +486,7 @@ first_item["target"], paths[1:], prefix + paths[0] + b"/" ) - def revision_add(self, revisions: Iterable[Revision]) -> Dict: + def revision_add(self, revisions: List[Revision]) -> Dict: revisions = [rev for rev in revisions if rev.id not in self._revisions] self.journal_writer.revision_add(revisions) @@ -538,7 +538,7 @@ def revision_get_random(self): return random.choice(list(self._revisions)) - def release_add(self, releases: Iterable[Release]) -> Dict: + def release_add(self, releases: List[Release]) -> Dict: to_add = [] for rel in releases: if rel.id not in self._releases and rel not in to_add: @@ -566,9 +566,9 @@ def release_get_random(self): return random.choice(list(self._releases)) - def snapshot_add(self, snapshots: Iterable[Snapshot]) -> Dict: + def snapshot_add(self, snapshots: List[Snapshot]) -> Dict: count = 0 - snapshots = (snap for snap in snapshots if snap.id not in self._snapshots) + snapshots = [snap for snap in snapshots if snap.id not in self._snapshots] for snapshot in snapshots: self.journal_writer.snapshot_add([snapshot]) self._snapshots[snapshot.id] = snapshot @@ -673,7 +673,7 @@ def origin_get_one(self, origin_url: str) -> Optional[Origin]: return self._origins.get(origin_url) - def origin_get(self, origins: Iterable[str]) -> Iterable[Optional[Origin]]: + def origin_get(self, origins: List[str]) -> Iterable[Optional[Origin]]: return [self.origin_get_one(origin_url) for origin_url in origins] def origin_get_by_sha1(self, sha1s): @@ -743,8 +743,7 @@ ) ) - def origin_add(self, origins: Iterable[Origin]) -> Dict[str, int]: - origins = list(origins) + def origin_add(self, origins: List[Origin]) -> Dict[str, int]: added = 0 for origin in origins: if origin.url not in self._origins: @@ -769,7 +768,7 @@ return origin.url - def origin_visit_add(self, visits: Iterable[OriginVisit]) -> Iterable[OriginVisit]: + def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]: for visit in visits: origin = self.origin_get_one(visit.origin) if not origin: # Cannot add a visit without an origin @@ -818,9 +817,7 @@ if visit_status not in visit_statuses: visit_statuses.append(visit_status) - def origin_visit_status_add( - self, visit_statuses: Iterable[OriginVisitStatus], - ) -> None: + def origin_visit_status_add(self, visit_statuses: List[OriginVisitStatus],) -> None: # First round to check existence (fail early if any is ko) for visit_status in visit_statuses: origin_url = self.origin_get_one(visit_status.origin) @@ -1018,10 +1015,7 @@ def refresh_stat_counters(self): pass - def raw_extrinsic_metadata_add( - self, metadata: Iterable[RawExtrinsicMetadata], - ) -> None: - metadata = list(metadata) + def raw_extrinsic_metadata_add(self, metadata: List[RawExtrinsicMetadata],) -> None: self.journal_writer.raw_extrinsic_metadata_add(metadata) for metadata_entry in metadata: authority_key = self._metadata_authority_key(metadata_entry.authority) @@ -1132,8 +1126,7 @@ "results": results, } - def metadata_fetcher_add(self, fetchers: Iterable[MetadataFetcher]) -> None: - fetchers = list(fetchers) + def metadata_fetcher_add(self, fetchers: List[MetadataFetcher]) -> None: self.journal_writer.metadata_fetcher_add(fetchers) for fetcher in fetchers: if fetcher.metadata is None: @@ -1151,8 +1144,7 @@ self._metadata_fetcher_key(MetadataFetcher(name=name, version=version)) ) - def metadata_authority_add(self, authorities: Iterable[MetadataAuthority]) -> None: - authorities = list(authorities) + def metadata_authority_add(self, authorities: List[MetadataAuthority]) -> None: self.journal_writer.metadata_authority_add(authorities) for authority in authorities: if authority.metadata is None: @@ -1208,11 +1200,11 @@ def diff_revision(self, revision, track_renaming=False): raise NotImplementedError("InMemoryStorage.diff_revision") - def clear_buffers(self, object_types: Optional[Iterable[str]] = None) -> None: + def clear_buffers(self, object_types: Optional[List[str]] = None) -> None: """Do nothing """ return None - def flush(self, object_types: Optional[Iterable[str]] = None) -> Dict: + def flush(self, object_types: Optional[List[str]] = None) -> Dict: return {} diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -39,7 +39,7 @@ ... @remote_api_endpoint("content/add") - def content_add(self, content: Iterable[Content]) -> Dict: + def content_add(self, content: List[Content]) -> Dict: """Add content blobs to the storage Args: @@ -97,7 +97,7 @@ ... @remote_api_endpoint("content/add_metadata") - def content_add_metadata(self, content: Iterable[Content]) -> Dict: + def content_add_metadata(self, content: List[Content]) -> Dict: """Add content metadata to the storage (like `content_add`, but without inserting to the objstorage). @@ -248,7 +248,7 @@ """List content missing from storage based only on sha1. Args: - contents: Iterable of sha1 to check for absence. + contents: List of sha1 to check for absence. Returns: iterable: missing ids @@ -264,7 +264,7 @@ """List content missing from storage based only on sha1_git. Args: - contents (Iterable): An iterable of content id (sha1_git) + contents (List): An iterable of content id (sha1_git) Yields: missing contents sha1_git @@ -301,7 +301,7 @@ ... @remote_api_endpoint("content/skipped/add") - def skipped_content_add(self, content: Iterable[SkippedContent]) -> Dict: + def skipped_content_add(self, content: List[SkippedContent]) -> Dict: """Add contents to the skipped_content list, which contains (partial) information about content missing from the archive. @@ -352,7 +352,7 @@ ... @remote_api_endpoint("directory/add") - def directory_add(self, directories: Iterable[Directory]) -> Dict: + def directory_add(self, directories: List[Directory]) -> Dict: """Add directories to the storage Args: @@ -434,11 +434,11 @@ ... @remote_api_endpoint("revision/add") - def revision_add(self, revisions: Iterable[Revision]) -> Dict: + def revision_add(self, revisions: List[Revision]) -> Dict: """Add revisions to the storage Args: - revisions (Iterable[dict]): iterable of dictionaries representing + revisions (List[dict]): iterable of dictionaries representing the individual revisions to add. Each dict has the following keys: @@ -538,11 +538,11 @@ ... @remote_api_endpoint("release/add") - def release_add(self, releases: Iterable[Release]) -> Dict: + def release_add(self, releases: List[Release]) -> Dict: """Add releases to the storage Args: - releases (Iterable[dict]): iterable of dictionaries representing + releases (List[dict]): iterable of dictionaries representing the individual releases to add. Each dict has the following keys: @@ -603,7 +603,7 @@ ... @remote_api_endpoint("snapshot/add") - def snapshot_add(self, snapshots: Iterable[Snapshot]) -> Dict: + def snapshot_add(self, snapshots: List[Snapshot]) -> Dict: """Add snapshots to the storage. Args: @@ -755,26 +755,24 @@ ... @remote_api_endpoint("origin/visit/add") - def origin_visit_add(self, visits: Iterable[OriginVisit]) -> Iterable[OriginVisit]: + def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]: """Add visits to storage. If the visits have no id, they will be created and assigned one. The resulted visits are visits with their visit id set. Args: - visits: Iterable of OriginVisit objects to add + visits: List of OriginVisit objects to add Raises: StorageArgumentException if some origin visit reference unknown origins Returns: - Iterable[OriginVisit] stored + List[OriginVisit] stored """ ... @remote_api_endpoint("origin/visit_status/add") - def origin_visit_status_add( - self, visit_statuses: Iterable[OriginVisitStatus], - ) -> None: + def origin_visit_status_add(self, visit_statuses: List[OriginVisitStatus],) -> None: """Add origin visit statuses. If there is already a status for the same origin and visit id at the same @@ -937,7 +935,7 @@ ... @remote_api_endpoint("origin/get") - def origin_get(self, origins: Iterable[str]) -> Iterable[Optional[Origin]]: + def origin_get(self, origins: List[str]) -> Iterable[Optional[Origin]]: """Return origins. Args: @@ -1044,7 +1042,7 @@ ... @remote_api_endpoint("origin/add_multi") - def origin_add(self, origins: Iterable[Origin]) -> Dict[str, int]: + def origin_add(self, origins: List[Origin]) -> Dict[str, int]: """Add origins to the storage Args: @@ -1077,9 +1075,7 @@ ... @remote_api_endpoint("raw_extrinsic_metadata/add") - def raw_extrinsic_metadata_add( - self, metadata: Iterable[RawExtrinsicMetadata], - ) -> None: + def raw_extrinsic_metadata_add(self, metadata: List[RawExtrinsicMetadata],) -> None: """Add extrinsic metadata on objects (contents, directories, ...). The authority and fetcher must be known to the storage before @@ -1125,7 +1121,7 @@ ... @remote_api_endpoint("metadata_fetcher/add") - def metadata_fetcher_add(self, fetchers: Iterable[MetadataFetcher],) -> None: + def metadata_fetcher_add(self, fetchers: List[MetadataFetcher],) -> None: """Add new metadata fetchers to the storage. Their `name` and `version` together are unique identifiers of this @@ -1157,7 +1153,7 @@ ... @remote_api_endpoint("metadata_authority/add") - def metadata_authority_add(self, authorities: Iterable[MetadataAuthority]) -> None: + def metadata_authority_add(self, authorities: List[MetadataAuthority]) -> None: """Add new metadata authorities to the storage. Their `type` and `url` together are unique identifiers of this @@ -1242,7 +1238,7 @@ ... @remote_api_endpoint("clear/buffer") - def clear_buffers(self, object_types: Optional[Iterable[str]] = None) -> None: + def clear_buffers(self, object_types: Optional[List[str]] = None) -> None: """For backend storages (pg, storage, in-memory), this is a noop operation. For proxy storages (especially filter, buffer), this is an operation which cleans internal state. @@ -1250,7 +1246,7 @@ """ @remote_api_endpoint("flush") - def flush(self, object_types: Optional[Iterable[str]] = None) -> Dict: + def flush(self, object_types: Optional[List[str]] = None) -> Dict: """For backend storages (pg, storage, in-memory), this is expected to be a noop operation. For proxy storages (especially buffer), this is expected to trigger actual writes to the backend. diff --git a/swh/storage/retry.py b/swh/storage/retry.py --- a/swh/storage/retry.py +++ b/swh/storage/retry.py @@ -6,7 +6,7 @@ import logging import traceback -from typing import Dict, Iterable, Optional +from typing import Dict, Iterable, List, Optional from tenacity import ( retry, @@ -29,6 +29,7 @@ from swh.storage import get_storage from swh.storage.exc import StorageArgumentException +from swh.storage.interface import StorageInterface logger = logging.getLogger(__name__) @@ -85,7 +86,7 @@ """ def __init__(self, storage): - self.storage = get_storage(**storage) + self.storage: StorageInterface = get_storage(**storage) def __getattr__(self, key): if key == "storage": @@ -93,55 +94,53 @@ return getattr(self.storage, key) @swh_retry - def content_add(self, content: Iterable[Content]) -> Dict: + def content_add(self, content: List[Content]) -> Dict: return self.storage.content_add(content) @swh_retry - def content_add_metadata(self, content: Iterable[Content]) -> Dict: + def content_add_metadata(self, content: List[Content]) -> Dict: return self.storage.content_add_metadata(content) @swh_retry - def skipped_content_add(self, content: Iterable[SkippedContent]) -> Dict: + def skipped_content_add(self, content: List[SkippedContent]) -> Dict: return self.storage.skipped_content_add(content) @swh_retry - def origin_visit_add(self, visits: Iterable[OriginVisit]) -> Iterable[OriginVisit]: + def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]: return self.storage.origin_visit_add(visits) @swh_retry - def metadata_fetcher_add(self, fetchers: Iterable[MetadataFetcher],) -> None: + def metadata_fetcher_add(self, fetchers: List[MetadataFetcher],) -> None: return self.storage.metadata_fetcher_add(fetchers) @swh_retry - def metadata_authority_add(self, authorities: Iterable[MetadataAuthority]) -> None: + def metadata_authority_add(self, authorities: List[MetadataAuthority]) -> None: return self.storage.metadata_authority_add(authorities) @swh_retry - def raw_extrinsic_metadata_add( - self, metadata: Iterable[RawExtrinsicMetadata], - ) -> None: + def raw_extrinsic_metadata_add(self, metadata: List[RawExtrinsicMetadata],) -> None: return self.storage.raw_extrinsic_metadata_add(metadata) @swh_retry - def directory_add(self, directories: Iterable[Directory]) -> Dict: + def directory_add(self, directories: List[Directory]) -> Dict: return self.storage.directory_add(directories) @swh_retry - def revision_add(self, revisions: Iterable[Revision]) -> Dict: + def revision_add(self, revisions: List[Revision]) -> Dict: return self.storage.revision_add(revisions) @swh_retry - def release_add(self, releases: Iterable[Release]) -> Dict: + def release_add(self, releases: List[Release]) -> Dict: return self.storage.release_add(releases) @swh_retry - def snapshot_add(self, snapshots: Iterable[Snapshot]) -> Dict: + def snapshot_add(self, snapshots: List[Snapshot]) -> Dict: return self.storage.snapshot_add(snapshots) - def clear_buffers(self, object_types: Optional[Iterable[str]] = None) -> None: + def clear_buffers(self, object_types: Optional[List[str]] = None) -> None: return self.storage.clear_buffers(object_types) - def flush(self, object_types: Optional[Iterable[str]] = None) -> Dict: + def flush(self, object_types: Optional[List[str]] = None) -> Dict: """Specific case for buffer proxy storage failing to flush data """ diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -204,7 +204,7 @@ @timed @process_metrics - def content_add(self, content: Iterable[Content]) -> Dict: + def content_add(self, content: List[Content]) -> Dict: ctime = now() contents = [attr.evolve(c, ctime=ctime) for c in content] @@ -247,14 +247,11 @@ @timed @process_metrics @db_transaction() - def content_add_metadata( - self, content: Iterable[Content], db=None, cur=None - ) -> Dict: - contents = list(content) + def content_add_metadata(self, content: List[Content], db=None, cur=None) -> Dict: missing = self.content_missing( - (c.to_dict() for c in contents), key_hash="sha1_git", db=db, cur=cur, + (c.to_dict() for c in content), key_hash="sha1_git", db=db, cur=cur, ) - contents = [c for c in contents if c.sha1_git in missing] + contents = [c for c in content if c.sha1_git in missing] self.journal_writer.content_add_metadata(contents) self._content_add_metadata(db, cur, contents) @@ -393,7 +390,7 @@ return d - def _skipped_content_add_metadata(self, db, cur, content: Iterable[SkippedContent]): + def _skipped_content_add_metadata(self, db, cur, content: List[SkippedContent]): origin_ids = db.origin_id_get_by_url([cont.origin for cont in content], cur=cur) content = [ attr.evolve(c, origin=origin_id) @@ -414,7 +411,7 @@ @process_metrics @db_transaction() def skipped_content_add( - self, content: Iterable[SkippedContent], db=None, cur=None + self, content: List[SkippedContent], db=None, cur=None ) -> Dict: ctime = now() content = [attr.evolve(c, ctime=ctime) for c in content] @@ -451,10 +448,7 @@ @timed @process_metrics @db_transaction() - def directory_add( - self, directories: Iterable[Directory], db=None, cur=None - ) -> Dict: - directories = list(directories) + def directory_add(self, directories: List[Directory], db=None, cur=None) -> Dict: summary = {"directory:add": 0} dirs = set() @@ -540,8 +534,7 @@ @timed @process_metrics @db_transaction() - def revision_add(self, revisions: Iterable[Revision], db=None, cur=None) -> Dict: - revisions = list(revisions) + def revision_add(self, revisions: List[Revision], db=None, cur=None) -> Dict: summary = {"revision:add": 0} revisions_missing = set( @@ -628,8 +621,7 @@ @timed @process_metrics @db_transaction() - def release_add(self, releases: Iterable[Release], db=None, cur=None) -> Dict: - releases = list(releases) + def release_add(self, releases: List[Release], db=None, cur=None) -> Dict: summary = {"release:add": 0} release_ids = set(release.id for release in releases) @@ -679,7 +671,7 @@ @timed @process_metrics @db_transaction() - def snapshot_add(self, snapshots: Iterable[Snapshot], db=None, cur=None) -> Dict: + def snapshot_add(self, snapshots: List[Snapshot], db=None, cur=None) -> Dict: created_temp_table = False count = 0 @@ -799,7 +791,7 @@ @timed @db_transaction() def origin_visit_add( - self, visits: Iterable[OriginVisit], db=None, cur=None + self, visits: List[OriginVisit], db=None, cur=None ) -> Iterable[OriginVisit]: for visit in visits: origin = self.origin_get([visit.origin], db=db, cur=cur)[0] @@ -847,7 +839,7 @@ @timed @db_transaction() def origin_visit_status_add( - self, visit_statuses: Iterable[OriginVisitStatus], db=None, cur=None, + self, visit_statuses: List[OriginVisitStatus], db=None, cur=None, ) -> None: # First round to check existence (fail early if any is ko) for visit_status in visit_statuses: @@ -995,10 +987,9 @@ @timed @db_transaction(statement_timeout=500) def origin_get( - self, origins: Iterable[str], db=None, cur=None + self, origins: List[str], db=None, cur=None ) -> Iterable[Optional[Origin]]: - origin_urls = list(origins) - rows = db.origin_get_by_url(origin_urls, cur) + rows = db.origin_get_by_url(origins, cur) result: List[Optional[Origin]] = [] for row in rows: origin_d = dict(zip(db.origin_cols, row)) @@ -1073,9 +1064,7 @@ @timed @process_metrics @db_transaction() - def origin_add( - self, origins: Iterable[Origin], db=None, cur=None - ) -> Dict[str, int]: + def origin_add(self, origins: List[Origin], db=None, cur=None) -> Dict[str, int]: urls = [o.url for o in origins] known_origins = set(url for (url,) in db.origin_get_by_url(urls, cur)) # use lists here to keep origins sorted; some tests depend on this @@ -1115,7 +1104,7 @@ @db_transaction() def raw_extrinsic_metadata_add( - self, metadata: Iterable[RawExtrinsicMetadata], db, cur, + self, metadata: List[RawExtrinsicMetadata], db, cur, ) -> None: metadata = list(metadata) self.journal_writer.raw_extrinsic_metadata_add(metadata) @@ -1255,7 +1244,7 @@ @timed @db_transaction() def metadata_fetcher_add( - self, fetchers: Iterable[MetadataFetcher], db=None, cur=None + self, fetchers: List[MetadataFetcher], db=None, cur=None ) -> None: fetchers = list(fetchers) self.journal_writer.metadata_fetcher_add(fetchers) @@ -1284,7 +1273,7 @@ @timed @db_transaction() def metadata_authority_add( - self, authorities: Iterable[MetadataAuthority], db=None, cur=None + self, authorities: List[MetadataAuthority], db=None, cur=None ) -> None: authorities = list(authorities) self.journal_writer.metadata_authority_add(authorities) @@ -1325,13 +1314,13 @@ def diff_revision(self, revision, track_renaming=False): return diff.diff_revision(self, revision, track_renaming) - def clear_buffers(self, object_types: Optional[Iterable[str]] = None) -> None: + def clear_buffers(self, object_types: Optional[List[str]] = None) -> None: """Do nothing """ return None - def flush(self, object_types: Optional[Iterable[str]] = None) -> Dict: + def flush(self, object_types: Optional[List[str]] = None) -> Dict: return {} def _get_authority_id(self, authority: MetadataAuthority, db, cur): diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -187,22 +187,6 @@ swh_storage.refresh_stat_counters() assert swh_storage.stat_counters()["content"] == 1 - def test_content_add_from_generator(self, swh_storage, sample_data): - cont = sample_data.content - - def _cnt_gen(): - yield cont - - actual_result = swh_storage.content_add(_cnt_gen()) - - assert actual_result == { - "content:add": 1, - "content:add:bytes": cont.length, - } - - swh_storage.refresh_stat_counters() - assert swh_storage.stat_counters()["content"] == 1 - def test_content_add_from_lazy_content(self, swh_storage, sample_data): cont = sample_data.content lazy_content = LazyContent.from_dict(cont.to_dict()) @@ -677,22 +661,6 @@ swh_storage.refresh_stat_counters() assert swh_storage.stat_counters()["directory"] == 1 - def test_directory_add_from_generator(self, swh_storage, sample_data): - directory = sample_data.directories[1] - - def _dir_gen(): - yield directory - - actual_result = swh_storage.directory_add(directories=_dir_gen()) - assert actual_result == {"directory:add": 1} - - assert list(swh_storage.journal_writer.journal.objects) == [ - ("directory", directory) - ] - - swh_storage.refresh_stat_counters() - assert swh_storage.stat_counters()["directory"] == 1 - def test_directory_add_twice(self, swh_storage, sample_data): directory = sample_data.directories[1] @@ -881,18 +849,6 @@ swh_storage.refresh_stat_counters() assert swh_storage.stat_counters()["revision"] == 1 - def test_revision_add_from_generator(self, swh_storage, sample_data): - revision = sample_data.revision - - def _rev_gen(): - yield revision - - actual_result = swh_storage.revision_add(_rev_gen()) - assert actual_result == {"revision:add": 1} - - swh_storage.refresh_stat_counters() - assert swh_storage.stat_counters()["revision"] == 1 - def test_revision_add_twice(self, swh_storage, sample_data): revision, revision2 = sample_data.revisions[:2] @@ -1062,24 +1018,6 @@ swh_storage.refresh_stat_counters() assert swh_storage.stat_counters()["release"] == 2 - def test_release_add_from_generator(self, swh_storage, sample_data): - release, release2 = sample_data.releases[:2] - - def _rel_gen(): - yield release - yield release2 - - actual_result = swh_storage.release_add(_rel_gen()) - assert actual_result == {"release:add": 2} - - assert list(swh_storage.journal_writer.journal.objects) == [ - ("release", release), - ("release", release2), - ] - - swh_storage.refresh_stat_counters() - assert swh_storage.stat_counters()["release"] == 2 - def test_release_add_no_author_date(self, swh_storage, sample_data): full_release = sample_data.release @@ -1187,26 +1125,6 @@ swh_storage.refresh_stat_counters() assert swh_storage.stat_counters()["origin"] == 2 - def test_origin_add_from_generator(self, swh_storage, sample_data): - origin, origin2 = sample_data.origins[:2] - - def _ori_gen(): - yield origin - yield origin2 - - stats = swh_storage.origin_add(_ori_gen()) - assert stats == {"origin:add": 2} - - actual_origins = swh_storage.origin_get([origin.url, origin2.url]) - assert actual_origins == [origin, origin2] - - assert set(swh_storage.journal_writer.journal.objects) == set( - [("origin", origin), ("origin", origin2),] - ) - - swh_storage.refresh_stat_counters() - assert swh_storage.stat_counters()["origin"] == 2 - def test_origin_add_twice(self, swh_storage, sample_data): origin, origin2 = sample_data.origins[:2] @@ -2377,18 +2295,6 @@ swh_storage.refresh_stat_counters() assert swh_storage.stat_counters()["snapshot"] == 2 - def test_snapshot_add_many_from_generator(self, swh_storage, sample_data): - snapshot, _, complete_snapshot = sample_data.snapshots[:3] - - def _snp_gen(): - yield from [snapshot, complete_snapshot] - - actual_result = swh_storage.snapshot_add(_snp_gen()) - assert actual_result == {"snapshot:add": 2} - - swh_storage.refresh_stat_counters() - assert swh_storage.stat_counters()["snapshot"] == 2 - def test_snapshot_add_many_incremental(self, swh_storage, sample_data): snapshot, _, complete_snapshot = sample_data.snapshots[:3]