Add static type annotations to the GraphQL backends and server accessors.

Summary (everything below is taken from the hunks in this patch):
  * backends/archive.py: annotate every Archive method shown and pass
    arguments to the storage API by keyword instead of by position.
    get_origin() now wraps the origin_get() result in list() before indexing.
  * backends/search.py: annotate Search.search and Search.get_origins().
  * server.py: annotate the module-level graphql_cfg/storage/search globals
    and the get_storage()/get_search() accessors; graphql_cfg now defaults
    to an empty dict instead of None.

NOTE(review): get_revision_log() drops its `after` parameter; callers that
still pass it will raise TypeError -- confirm none remain.
NOTE(review): the server.py hunk uses Dict, Any and Optional, but no typing
import is visible in this hunk (it starts at line 8) -- confirm it exists.
NOTE(review): Search.__init__ is left without `-> None`, unlike
Archive.__init__ -- presumably an oversight.

diff --git a/swh/graphql/backends/archive.py b/swh/graphql/backends/archive.py
--- a/swh/graphql/backends/archive.py
+++ b/swh/graphql/backends/archive.py
@@ -4,61 +4,95 @@
 # See top-level LICENSE file for more information
 
 import os
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Iterable, List, Optional
 
 from swh.graphql import server
-from swh.model.model import Sha1Git
+from swh.model.model import (
+    Content,
+    DirectoryEntry,
+    Origin,
+    OriginVisit,
+    OriginVisitStatus,
+    Release,
+    Revision,
+    Sha1,
+    Sha1Git,
+)
 from swh.model.swhids import ObjectType
+from swh.storage.interface import PagedResult, PartialBranches, StorageInterface
 
 
 class Archive:
-    def __init__(self):
-        self.storage = server.get_storage()
+    def __init__(self) -> None:
+        self.storage: StorageInterface = server.get_storage()
 
-    def get_origin(self, url):
-        return self.storage.origin_get([url])[0]
+    def get_origin(self, url: str) -> Optional[Origin]:
+        return list(self.storage.origin_get(origins=[url]))[0]
 
-    def get_origins(self, after=None, first=50):
+    def get_origins(
+        self, after: Optional[str] = None, first: int = 50
+    ) -> PagedResult[Origin]:
         return self.storage.origin_list(page_token=after, limit=first)
 
-    def get_origin_visits(self, origin_url, after=None, first=50):
-        return self.storage.origin_visit_get(origin_url, page_token=after, limit=first)
+    def get_origin_visits(
+        self, origin_url: str, after: Optional[str] = None, first: int = 50
+    ) -> PagedResult[OriginVisit]:
+        return self.storage.origin_visit_get(
+            origin=origin_url, page_token=after, limit=first
+        )
 
-    def get_origin_visit(self, origin_url, visit_id):
-        return self.storage.origin_visit_get_by(origin_url, visit_id)
+    def get_origin_visit(self, origin_url: str, visit_id: int) -> Optional[OriginVisit]:
+        return self.storage.origin_visit_get_by(origin=origin_url, visit=visit_id)
 
-    def get_origin_latest_visit(self, origin_url):
-        return self.storage.origin_visit_get_latest(origin_url)
+    def get_origin_latest_visit(self, origin_url: str) -> Optional[OriginVisit]:
+        return self.storage.origin_visit_get_latest(origin=origin_url)
 
-    def get_visit_status(self, origin_url, visit_id, after=None, first=50):
+    def get_visit_status(
+        self,
+        origin_url: str,
+        visit_id: int,
+        after: Optional[str] = None,
+        first: int = 50,
+    ) -> PagedResult[OriginVisitStatus]:
         return self.storage.origin_visit_status_get(
-            origin_url, visit_id, page_token=after, limit=first
+            origin=origin_url, visit=visit_id, page_token=after, limit=first
         )
 
-    def get_latest_visit_status(self, origin_url, visit_id):
-        return self.storage.origin_visit_status_get_latest(origin_url, visit_id)
+    def get_latest_visit_status(
+        self, origin_url: str, visit_id: int
+    ) -> Optional[OriginVisitStatus]:
+        return self.storage.origin_visit_status_get_latest(
+            origin_url=origin_url, visit=visit_id
+        )
 
-    def get_origin_snapshots(self, origin_url):
-        return self.storage.origin_snapshot_get_all(origin_url)
+    def get_origin_snapshots(self, origin_url: str) -> List[Sha1Git]:
+        return self.storage.origin_snapshot_get_all(origin_url=origin_url)
 
     def get_snapshot_branches(
-        self, snapshot, after=b"", first=50, target_types=None, name_include=None
-    ):
+        self,
+        snapshot: Sha1Git,
+        after: bytes = b"",
+        first: int = 50,
+        target_types: Optional[List[str]] = None,
+        name_include: Optional[bytes] = None,
+    ) -> Optional[PartialBranches]:
         return self.storage.snapshot_get_branches(
-            snapshot,
+            snapshot_id=snapshot,
             branches_from=after,
             branches_count=first,
             target_types=target_types,
             branch_name_include_substring=name_include,
         )
 
-    def get_revisions(self, revision_ids):
+    def get_revisions(self, revision_ids: List[Sha1Git]) -> List[Optional[Revision]]:
         return self.storage.revision_get(revision_ids=revision_ids)
 
-    def get_revision_log(self, revision_ids, after=None, first=50):
+    def get_revision_log(
+        self, revision_ids: List[Sha1Git], first: int = 50
+    ) -> Iterable[Optional[Dict[str, Any]]]:
         return self.storage.revision_log(revisions=revision_ids, limit=first)
 
-    def get_releases(self, release_ids):
+    def get_releases(self, release_ids: List[Sha1Git]) -> List[Optional[Release]]:
         return self.storage.release_get(releases=release_ids)
 
     def get_directory_entry_by_path(
@@ -69,12 +103,14 @@
             directory=directory_id, paths=paths
         )
 
-    def get_directory_entries(self, directory_id, after=None, first=50):
+    def get_directory_entries(
+        self, directory_id: Sha1Git, after: Optional[bytes] = None, first: int = 50
+    ) -> Optional[PagedResult[DirectoryEntry]]:
         return self.storage.directory_get_entries(
-            directory_id, limit=first, page_token=after
+            directory_id=directory_id, limit=first, page_token=after
         )
 
-    def is_object_available(self, object_id: str, object_type: ObjectType) -> bool:
+    def is_object_available(self, object_id: bytes, object_type: ObjectType) -> bool:
         mapping = {
             ObjectType.CONTENT: self.storage.content_missing_per_sha1_git,
             ObjectType.DIRECTORY: self.storage.directory_missing,
@@ -84,8 +120,8 @@
         }
         return not list(mapping[object_type]([object_id]))
 
-    def get_contents(self, checksums: dict):
-        return self.storage.content_find(checksums)
+    def get_contents(self, checksums: Dict[str, Any]) -> List[Content]:
+        return self.storage.content_find(content=checksums)
 
-    def get_content_data(self, content_sha1):
-        return self.storage.content_get_data(content_sha1)
+    def get_content_data(self, content_sha1: Sha1) -> Optional[bytes]:
+        return self.storage.content_get_data(content=content_sha1)
diff --git a/swh/graphql/backends/search.py b/swh/graphql/backends/search.py
--- a/swh/graphql/backends/search.py
+++ b/swh/graphql/backends/search.py
@@ -3,14 +3,20 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+from typing import Optional
+
 from swh.graphql import server
+from swh.search.interface import MinimalOriginDict, SearchInterface
+from swh.storage.interface import PagedResult
 
 
 class Search:
     def __init__(self):
-        self.search = server.get_search()
+        self.search: SearchInterface = server.get_search()
 
-    def get_origins(self, query: str, after=None, first=50):
+    def get_origins(
+        self, query: str, after: Optional[str] = None, first: int = 50
+    ) -> PagedResult[MinimalOriginDict]:
         return self.search.origin_search(
             url_pattern=query,
             page_token=after,
diff --git a/swh/graphql/server.py b/swh/graphql/server.py
--- a/swh/graphql/server.py
+++ b/swh/graphql/server.py
@@ -8,21 +8,23 @@
 
 from swh.core import config
 from swh.search import get_search as get_swh_search
+from swh.search.interface import SearchInterface
 from swh.storage import get_storage as get_swh_storage
+from swh.storage.interface import StorageInterface
 
-graphql_cfg = None
-storage = None
-search = None
+graphql_cfg: Dict[str, Any] = {}
+storage: Optional[StorageInterface] = None
+search: Optional[SearchInterface] = None
 
 
-def get_storage():
+def get_storage() -> StorageInterface:
     global storage
     if not storage:
         storage = get_swh_storage(**graphql_cfg["storage"])
     return storage
 
 
-def get_search():
+def get_search() -> SearchInterface:
     global search
     if not search:
         search = get_swh_search(**graphql_cfg["search"])