diff --git a/setup.py b/setup.py index 1728606..5312f69 100644 --- a/setup.py +++ b/setup.py @@ -1,68 +1,68 @@ #!/usr/bin/env python3 -# Copyright (C) 2019-2021 The Software Heritage developers +# Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from io import open from os import path from setuptools import find_packages, setup here = path.abspath(path.dirname(__file__)) # Get the long description from the README file with open(path.join(here, "README.md"), encoding="utf-8") as f: long_description = f.read() def parse_requirements(*names): requirements = [] for name in names: if name: reqf = "requirements-%s.txt" % name else: reqf = "requirements.txt" if not path.exists(reqf): return requirements with open(reqf) as f: for line in f.readlines(): line = line.strip() if not line or line.startswith("#"): continue requirements.append(line) return requirements setup( name="swh.graphql", description="Software Heritage GraphQL Apis", long_description=long_description, long_description_content_type="text/x-rst", python_requires=">=3.7", author="Software Heritage developers", author_email="swh-devel@inria.fr", url="https://forge.softwareheritage.org/diffusion/DGQL", packages=find_packages(), install_requires=parse_requirements(None, "swh"), tests_require=parse_requirements("test"), setup_requires=["setuptools-scm"], use_scm_version=True, extras_require={"testing": parse_requirements("test")}, include_package_data=True, classifiers=[ "Programming Language :: Python :: 3", "Intended Audience :: Developers", "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", "Operating System :: OS Independent", "Development Status :: 3 - Alpha", ], project_urls={ "Bug Reports": "https://forge.softwareheritage.org/maniphest", "Funding": "https://www.softwareheritage.org/donate", "Source": "https://forge.softwareheritage.org/source/swh-graphql", "Documentation": "https://docs.softwareheritage.org/devel/swh-graphql/", }, ) diff --git a/swh/graphql/app.py b/swh/graphql/app.py index f438767..326e1b2 100644 --- a/swh/graphql/app.py +++ b/swh/graphql/app.py @@ -1,35 +1,40 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + # import pkg_resources import os from pathlib import Path from ariadne import gql, load_schema_from_path, make_executable_schema from .resolvers import resolvers, scalars type_defs = gql( # pkg_resources.resource_string("swh.graphql", "schem/schema.graphql").decode() load_schema_from_path( os.path.join(Path(__file__).parent.resolve(), "schema", "schema.graphql") ) ) schema = make_executable_schema( type_defs, resolvers.query, resolvers.origin, resolvers.visit, resolvers.visit_status, resolvers.snapshot, resolvers.snapshot_branch, resolvers.revision, resolvers.release, resolvers.directory, resolvers.directory_entry, resolvers.branch_target, resolvers.release_target, resolvers.directory_entry_target, scalars.id_scalar, scalars.string_scalar, scalars.datetime_scalar, scalars.swhid_scalar, ) diff --git a/swh/graphql/backends/archive.py b/swh/graphql/backends/archive.py index 7ee5b50..644c29c 100644 --- a/swh/graphql/backends/archive.py +++ b/swh/graphql/backends/archive.py @@ -1,72 +1,77 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + from swh.graphql import server class Archive: def __init__(self): self.storage = server.get_storage() def get_origin(self, url): return self.storage.origin_get([url])[0] def get_origins(self, after=None, first=50, url_pattern=None): # STORAGE-TODO # Make them a single function in the backend if url_pattern is None: return self.storage.origin_list(page_token=after, limit=first) return self.storage.origin_search( url_pattern=url_pattern, page_token=after, limit=first ) def get_origin_visits(self, origin_url, after=None, first=50): return self.storage.origin_visit_get(origin_url, page_token=after, limit=first) def get_origin_visit(self, origin_url, visit_id): return self.storage.origin_visit_get_by(origin_url, visit_id) def get_origin_latest_visit(self, origin_url): return self.storage.origin_visit_get_latest(origin_url) def get_visit_status(self, origin_url, visit_id, after=None, first=50): return self.storage.origin_visit_status_get( origin_url, visit_id, page_token=after, limit=first ) def get_latest_visit_status(self, origin_url, visit_id): return self.storage.origin_visit_status_get_latest(origin_url, visit_id) def get_origin_snapshots(self, origin_url): return self.storage.origin_snapshot_get_all(origin_url) def is_snapshot_available(self, snapshot_ids): return not self.storage.snapshot_missing(snapshot_ids) def get_snapshot_branches(self, snapshot, after, first, target_types, name_include): return self.storage.snapshot_get_branches( snapshot, branches_from=after, branches_count=first, target_types=target_types, branch_name_include_substring=name_include, ) def get_revisions(self, revision_ids): return self.storage.revision_get(revision_ids=revision_ids) def get_revision_log(self, revision_ids, after=None, first=50): return self.storage.revision_log(revisions=revision_ids, limit=first) def get_releases(self, release_ids): return self.storage.release_get(releases=release_ids) def is_directory_available(self, directory_ids): return not self.storage.directory_missing(directory_ids) def get_directory_entries(self, directory_id, after=None, first=50): return self.storage.directory_get_entries( directory_id, limit=first, page_token=after ) def get_content(self, content_id): # FIXME, only for tests return self.storage.content_find({"sha1_git": content_id}) diff --git a/swh/graphql/errors/__init__.py b/swh/graphql/errors/__init__.py index 103f1bd..3e8bd41 100644 --- a/swh/graphql/errors/__init__.py +++ b/swh/graphql/errors/__init__.py @@ -1,4 +1,9 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + from .errors import ObjectNotFoundError from .handlers import format_error __all__ = ["ObjectNotFoundError", "format_error"] diff --git a/swh/graphql/errors/errors.py b/swh/graphql/errors/errors.py index 6a587f5..8036b65 100644 --- a/swh/graphql/errors/errors.py +++ b/swh/graphql/errors/errors.py @@ -1,2 +1,8 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + + class ObjectNotFoundError(Exception): """ """ diff --git a/swh/graphql/errors/handlers.py b/swh/graphql/errors/handlers.py index 8d13115..c61e593 100644 --- a/swh/graphql/errors/handlers.py +++ b/swh/graphql/errors/handlers.py @@ -1,7 +1,13 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + + def format_error(error) -> dict: """ Response error formatting """ formatted = error.formatted formatted["message"] = "Unknown error" return formatted diff --git a/swh/graphql/middlewares/graphql.py b/swh/graphql/middlewares/graphql.py index a203330..9314edf 100644 --- a/swh/graphql/middlewares/graphql.py +++ b/swh/graphql/middlewares/graphql.py @@ -1,11 +1,16 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + """ To implement graphql middleware """ class CostError: pass def cost_limiter(): pass diff --git a/swh/graphql/qns.txt b/swh/graphql/qns.txt deleted file mode 100644 index 1e87ec6..0000000 --- a/swh/graphql/qns.txt +++ /dev/null @@ -1,14 +0,0 @@ -Questions -========= -* Idea for a homogeneous ID for node (can we expose primary key from postgres) -* Why the visit status is a paginated list in storage?, and not in v1 -* visit id should be visit number -... Schema related questions - -* What should we include in pageinfo (start token, haspreviouspage etc) -* Query cost calculator logic (could be a bit complex) -* Throttling based on Query cost calculator -* Authentication and Authorization - -* Datetime as a time stamp -* Other scalar types needed (swhid) diff --git a/swh/graphql/resolvers/base_connection.py b/swh/graphql/resolvers/base_connection.py index 617d2a7..f5f56b3 100644 --- a/swh/graphql/resolvers/base_connection.py +++ b/swh/graphql/resolvers/base_connection.py @@ -1,107 +1,112 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Optional, Type from swh.graphql.utils import utils from .base_node import BaseNode @dataclass class PageInfo: hasNextPage: bool endCursor: str class BaseConnection(ABC): """ Base class for all the connection resolvers """ _node_class: Optional[Type[BaseNode]] = None _page_size = 50 # default page size def __init__(self, obj, info, paged_data=None, **kwargs): self.obj = obj self.info = info self.kwargs = kwargs self._paged_data = paged_data def __call__(self, *args, **kw): return self @property def edges(self): return self._get_edges() @property def nodes(self): """ Override if needed; return a list of objects If a node class is set, return a list of its (Node) instances else a list of raw results """ if self._node_class is not None: return [ self._node_class(self.obj, self.info, node_data=result, **self.kwargs) for result in self.get_paged_data().results ] return self.get_paged_data().results @property def pageInfo(self): # To support the schema naming convention # FIXME, add more details like startCursor return PageInfo( hasNextPage=bool(self.get_paged_data().next_page_token), endCursor=utils.get_encoded_cursor(self.get_paged_data().next_page_token), ) @property def totalCount(self): # To support the schema naming convention return self._get_total_count() def _get_total_count(self): """ Will be None for most of the connections override if needed/possible """ return None def get_paged_data(self): """ Cache to avoid multiple calls to the backend (_get_paged_result) return a PagedResult object """ if self._paged_data is None: # FIXME, make this call async (not for v1) self._paged_data = self._get_paged_result() return self._paged_data @abstractmethod def _get_paged_result(self): """ Override for desired behaviour return a PagedResult object """ # FIXME, make this call async (not for v1) return None def _get_edges(self): # FIXME, make cursor work per item # Cursor can't be None here return [{"cursor": "dummy", "node": node} for node in self.nodes] def _get_after_arg(self): """ Return the decoded next page token override to use a specific token """ return utils.get_decoded_cursor(self.kwargs.get("after")) def _get_first_arg(self): """ page_size is set to 50 by default """ return self.kwargs.get("first", self._page_size) diff --git a/swh/graphql/resolvers/base_node.py b/swh/graphql/resolvers/base_node.py index 9f389f3..ff3c2b3 100644 --- a/swh/graphql/resolvers/base_node.py +++ b/swh/graphql/resolvers/base_node.py @@ -1,76 +1,81 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + from abc import ABC from collections import namedtuple from swh.graphql.errors import ObjectNotFoundError class BaseNode(ABC): """ Base class for all the Node resolvers """ def __init__(self, obj, info, node_data=None, **kwargs): self.obj = obj self.info = info self.kwargs = kwargs self._node = self._get_node(node_data) # handle the errors, if any, after _node is set self._handle_node_errors() def _get_node(self, node_data): """ Get the node object from the given data if the data (node_data) is none make a function call to get data from backend """ if node_data is None: node_data = self._get_node_data() return self._get_node_from_data(node_data) def _get_node_from_data(self, node_data): """ Get the object from node_data In case of a dict, convert it to an object Override to support different data structures """ if type(node_data) is dict: return namedtuple("NodeObj", node_data.keys())(*node_data.values()) return node_data def _handle_node_errors(self): """ Handle any error related to node data raise an error in case the object returned is None override for specific behaviour """ if self._node is None: raise ObjectNotFoundError("Requested object is not available") def __call__(self, *args, **kw): return self def _get_node_data(self): """ Override for desired behaviour This will be called only when node_data is None """ # FIXME, make this call async (not for v1) return None def __getattr__(self, name): """ Any property defined in the sub-class will get precedence over the _node attributes """ return getattr(self._node, name) def is_type_of(self): return self.__class__.__name__ class BaseSWHNode(BaseNode): @property def SWHID(self): return self._node.swhid() diff --git a/swh/graphql/resolvers/content.py b/swh/graphql/resolvers/content.py index 68d71e0..08d8637 100644 --- a/swh/graphql/resolvers/content.py +++ b/swh/graphql/resolvers/content.py @@ -1,44 +1,49 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + from swh.graphql.backends import archive from .base_node import BaseSWHNode class BaseContentNode(BaseSWHNode): """ """ def _get_content_by_id(self, content_id): content = archive.Archive().get_content(content_id) return content[0] if content else None @property def checksum(self): # FIXME, return a Node object return {k: v.hex() for (k, v) in self._node.hashes().items()} @property def id(self): return self._node.sha1_git def is_type_of(self): return "Content" class ContentNode(BaseContentNode): def _get_node_data(self): """ When a content is requested directly with its SWHID """ return self._get_content_by_id(self.kwargs.get("SWHID").object_id) class TargetContentNode(BaseContentNode): def _get_node_data(self): """ When a content is requested from a directory entry or from a release target content id is obj.targetHash here """ content_id = self.obj.targetHash return self._get_content_by_id(content_id) diff --git a/swh/graphql/resolvers/directory.py b/swh/graphql/resolvers/directory.py index 68c5d78..5c345b7 100644 --- a/swh/graphql/resolvers/directory.py +++ b/swh/graphql/resolvers/directory.py @@ -1,50 +1,55 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + from swh.graphql.backends import archive from swh.model.model import Directory from .base_node import BaseSWHNode class BaseDirectoryNode(BaseSWHNode): def _get_directory_by_id(self, directory_id): # Return a Directory model object # entries is initialized as empty # Same pattern is used in snapshot return Directory(id=directory_id, entries=()) def is_type_of(self): return "Directory" class DirectoryNode(BaseDirectoryNode): def _get_node_data(self): """ When a directory is requested directly with its SWHID """ directory_id = self.kwargs.get("SWHID").object_id # path = "" if archive.Archive().is_directory_available([directory_id]): return self._get_directory_by_id(directory_id) return None class RevisionDirectoryNode(BaseDirectoryNode): def _get_node_data(self): """ When a directory is requested from a revision self.obj is revision here self.obj.directorySWHID is the required dir SWHID (set from resolvers.revision.py:BaseRevisionNode) """ directory_id = self.obj.directorySWHID.object_id return self._get_directory_by_id(directory_id) class TargetDirectoryNode(BaseDirectoryNode): def _get_node_data(self): """ When a directory is requested as a target self.obj can be a Release or a DirectoryEntry obj.targetHash is the requested directory id here """ return self._get_directory_by_id(self.obj.targetHash) diff --git a/swh/graphql/resolvers/directory_entry.py b/swh/graphql/resolvers/directory_entry.py index b4b9712..bbb94c9 100644 --- a/swh/graphql/resolvers/directory_entry.py +++ b/swh/graphql/resolvers/directory_entry.py @@ -1,34 +1,39 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + from swh.graphql.backends import archive from swh.graphql.utils import utils from .base_connection import BaseConnection from .base_node import BaseNode class DirectoryEntryNode(BaseNode): """ """ @property def targetHash(self): # To support the schema naming convention return self._node.target class DirectoryEntryConnection(BaseConnection): _node_class = DirectoryEntryNode def _get_paged_result(self): """ When entries requested from a directory self.obj.SWHID is the directory SWHID here This is not paginated from swh-storgae using dummy pagination """ # FIXME, using dummy(local) pagination, move pagination to backend # To remove localpagination, just drop the paginated call # STORAGE-TODO entries = ( archive.Archive().get_directory_entries(self.obj.SWHID.object_id).results ) return utils.paginated(entries, self._get_first_arg(), self._get_after_arg()) diff --git a/swh/graphql/resolvers/person.py b/swh/graphql/resolvers/person.py index ca55b85..c805924 100644 --- a/swh/graphql/resolvers/person.py +++ b/swh/graphql/resolvers/person.py @@ -1,5 +1,10 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + from .base_node import BaseNode class PersonNode(BaseNode): """ """ diff --git a/swh/graphql/resolvers/release.py b/swh/graphql/resolvers/release.py index 767b197..2187e2f 100644 --- a/swh/graphql/resolvers/release.py +++ b/swh/graphql/resolvers/release.py @@ -1,45 +1,50 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + from swh.graphql.backends import archive from .base_node import BaseSWHNode class BaseReleaseNode(BaseSWHNode): def _get_release_by_id(self, release_id): return (archive.Archive().get_releases([release_id]) or None)[0] @property def targetHash(self): # To support the schema naming convention return self._node.target @property def targetType(self): # To support the schema naming convention return self._node.target_type.value def is_type_of(self): """ is_type_of is required only when resolving a UNION type This is for ariadne to return the right type """ return "Release" class ReleaseNode(BaseReleaseNode): """ When the release is requested directly with its SWHID """ def _get_node_data(self): return self._get_release_by_id(self.kwargs.get("SWHID").object_id) class TargetReleaseNode(BaseReleaseNode): """ When a release is requested as a target self.obj could be a snapshotbranch or a release self.obj.targetHash is the requested release id here """ def _get_node_data(self): return self._get_release_by_id(self.obj.targetHash) diff --git a/swh/graphql/resolvers/resolver_factory.py b/swh/graphql/resolvers/resolver_factory.py index 300da90..286573e 100644 --- a/swh/graphql/resolvers/resolver_factory.py +++ b/swh/graphql/resolvers/resolver_factory.py @@ -1,60 +1,65 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + from .content import ContentNode, TargetContentNode from .directory import DirectoryNode, RevisionDirectoryNode, TargetDirectoryNode from .directory_entry import DirectoryEntryConnection from .origin import OriginConnection, OriginNode from .release import ReleaseNode, TargetReleaseNode from .revision import ( LogRevisionConnection, ParentRevisionConnection, RevisionNode, TargetRevisionNode, ) from .snapshot import OriginSnapshotConnection, SnapshotNode, VisitSnapshotNode from .snapshot_branch import SnapshotBranchConnection from .visit import LatestVisitNode, OriginVisitConnection, OriginVisitNode from .visit_status import LatestVisitStatusNode, VisitStatusConnection def get_node_resolver(resolver_type): # FIXME, replace with a proper factory method mapping = { "origin": OriginNode, "visit": OriginVisitNode, "latest-visit": LatestVisitNode, "latest-status": LatestVisitStatusNode, "visit-snapshot": VisitSnapshotNode, "snapshot": SnapshotNode, "branch-revision": TargetRevisionNode, "branch-release": TargetReleaseNode, "revision": RevisionNode, "revision-directory": RevisionDirectoryNode, "release": ReleaseNode, "release-revision": TargetRevisionNode, "release-release": TargetReleaseNode, "release-directory": TargetDirectoryNode, "release-content": TargetContentNode, "directory": DirectoryNode, "content": ContentNode, "dir-entry-dir": TargetDirectoryNode, "dir-entry-file": TargetContentNode, } if resolver_type not in mapping: raise AttributeError(f"Invalid node type: {resolver_type}") return mapping[resolver_type] def get_connection_resolver(resolver_type): # FIXME, replace with a proper factory method mapping = { "origins": OriginConnection, "origin-visits": OriginVisitConnection, "origin-snapshots": OriginSnapshotConnection, "visit-status": VisitStatusConnection, "snapshot-branches": SnapshotBranchConnection, "revision-parents": ParentRevisionConnection, "revision-log": LogRevisionConnection, "directory-entries": DirectoryEntryConnection, } if resolver_type not in mapping: raise AttributeError(f"Invalid connection type: {resolver_type}") return mapping[resolver_type] diff --git a/swh/graphql/resolvers/resolvers.py b/swh/graphql/resolvers/resolvers.py index ad85786..3675fd3 100644 --- a/swh/graphql/resolvers/resolvers.py +++ b/swh/graphql/resolvers/resolvers.py @@ -1,240 +1,245 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + """ High level resolvers Any schema attribute can be resolved by any of the following ways and in the following priority order - In this module using an annotation (eg: @visitstatus.field("snapshot")) - As a property in the Node object (eg: resolvers.visit.OriginVisitNode.id) - As an attribute/item in the object/dict returned by the backend (eg: Origin.url) """ from ariadne import ObjectType, UnionType from graphql.type import GraphQLResolveInfo from swh.graphql import resolvers as rs from .resolver_factory import get_connection_resolver, get_node_resolver query: ObjectType = ObjectType("Query") origin: ObjectType = ObjectType("Origin") visit: ObjectType = ObjectType("Visit") visit_status: ObjectType = ObjectType("VisitStatus") snapshot: ObjectType = ObjectType("Snapshot") snapshot_branch: ObjectType = ObjectType("Branch") revision: ObjectType = ObjectType("Revision") release: ObjectType = ObjectType("Release") directory: ObjectType = ObjectType("Directory") directory_entry: ObjectType = ObjectType("DirectoryEntry") branch_target: UnionType = UnionType("BranchTarget") release_target: UnionType = UnionType("ReleaseTarget") directory_entry_target: UnionType = UnionType("DirectoryEntryTarget") # Node resolvers # A node resolver should return an instance of BaseNode @query.field("origin") def origin_resolver(obj: None, info: GraphQLResolveInfo, **kw) -> rs.origin.OriginNode: """ """ resolver = get_node_resolver("origin") return resolver(obj, info, **kw)() @origin.field("latestVisit") def latest_visit_resolver( obj: rs.origin.OriginNode, info: GraphQLResolveInfo, **kw ) -> rs.visit.LatestVisitNode: """ """ resolver = get_node_resolver("latest-visit") return resolver(obj, info, **kw)() @query.field("visit") def visit_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.visit.OriginVisitNode: """ """ resolver = get_node_resolver("visit") return resolver(obj, info, **kw)() @visit.field("latestStatus") def latest_visit_status_resolver( obj, info: GraphQLResolveInfo, **kw ) -> rs.visit_status.LatestVisitStatusNode: """ """ resolver = get_node_resolver("latest-status") return resolver(obj, info, **kw)() @query.field("snapshot") def snapshot_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.snapshot.SnapshotNode: """ """ resolver = get_node_resolver("snapshot") return resolver(obj, info, **kw)() @visit_status.field("snapshot") def visit_snapshot_resolver( obj, info: GraphQLResolveInfo, **kw ) -> rs.snapshot.VisitSnapshotNode: resolver = get_node_resolver("visit-snapshot") return resolver(obj, info, **kw)() @snapshot_branch.field("target") def snapshot_branch_target_resolver( obj: rs.snapshot_branch.SnapshotBranchNode, info: GraphQLResolveInfo, **kw ): """ Snapshot branch target can be a revision or a release """ resolver_type = f"branch-{obj.type}" resolver = get_node_resolver(resolver_type) return resolver(obj, info, **kw)() @query.field("revision") def revision_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.revision.RevisionNode: resolver = get_node_resolver("revision") return resolver(obj, info, **kw)() @revision.field("directory") def revision_directory_resolver( obj, info: GraphQLResolveInfo, **kw ) -> rs.directory.RevisionDirectoryNode: resolver = get_node_resolver("revision-directory") return resolver(obj, info, **kw)() @query.field("release") def release_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.release.ReleaseNode: resolver = get_node_resolver("release") return resolver(obj, info, **kw)() @release.field("target") def release_target_resolver(obj, info: GraphQLResolveInfo, **kw): """ release target can be a release, revision, directory or content obj is release here, target type is obj.target_type """ resolver_type = f"release-{obj.target_type.value}" resolver = get_node_resolver(resolver_type) return resolver(obj, info, **kw)() @query.field("directory") def directory_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.directory.DirectoryNode: resolver = get_node_resolver("directory") return resolver(obj, info, **kw)() @directory_entry.field("target") def directory_entry_target_resolver( obj: rs.directory_entry.DirectoryEntryNode, info: GraphQLResolveInfo, **kw ): """ directory entry target can be a directory or a content """ resolver_type = f"dir-entry-{obj.type}" resolver = get_node_resolver(resolver_type) return resolver(obj, info, **kw)() @query.field("content") def content_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.content.ContentNode: resolver = get_node_resolver("content") return resolver(obj, info, **kw)() # Connection resolvers # A connection resolver should return an instance of BaseConnection @query.field("origins") def origins_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.origin.OriginConnection: resolver = get_connection_resolver("origins") return resolver(obj, info, **kw)() @origin.field("visits") def visits_resolver( obj: rs.origin.OriginNode, info: GraphQLResolveInfo, **kw ) -> rs.visit.OriginVisitConnection: resolver = get_connection_resolver("origin-visits") return resolver(obj, info, **kw)() @origin.field("snapshots") def origin_snapshots_resolver( obj: rs.origin.OriginNode, info: GraphQLResolveInfo, **kw ) -> rs.snapshot.OriginSnapshotConnection: """ """ resolver = get_connection_resolver("origin-snapshots") return resolver(obj, info, **kw)() @visit.field("status") def visitstatus_resolver( obj, info: GraphQLResolveInfo, **kw ) -> rs.visit_status.VisitStatusConnection: resolver = get_connection_resolver("visit-status") return resolver(obj, info, **kw)() @snapshot.field("branches") def snapshot_branches_resolver( obj, info: GraphQLResolveInfo, **kw ) -> rs.snapshot_branch.SnapshotBranchConnection: resolver = get_connection_resolver("snapshot-branches") return resolver(obj, info, **kw)() @revision.field("parents") def revision_parents_resolver( obj, info: GraphQLResolveInfo, **kw ) -> rs.revision.ParentRevisionConnection: resolver = get_connection_resolver("revision-parents") return resolver(obj, info, **kw)() # @revision.field("revisionLog") # def revision_log_resolver(obj, info, **kw): # resolver = get_connection_resolver("revision-log") # return resolver(obj, info, **kw)() @directory.field("entries") def directory_entry_resolver( obj, info: GraphQLResolveInfo, **kw ) -> rs.directory_entry.DirectoryEntryConnection: resolver = get_connection_resolver("directory-entries") return resolver(obj, info, **kw)() # Any other type of resolver @release_target.type_resolver @directory_entry_target.type_resolver @branch_target.type_resolver def union_resolver(obj, *_) -> str: """ Generic resolver for all the union types """ return obj.is_type_of() diff --git a/swh/graphql/resolvers/revision.py b/swh/graphql/resolvers/revision.py index 8cb6cd1..99da788 100644 --- a/swh/graphql/resolvers/revision.py +++ b/swh/graphql/resolvers/revision.py @@ -1,97 +1,102 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + from swh.graphql.backends import archive from swh.graphql.utils import utils from swh.model.swhids import CoreSWHID, ObjectType from .base_connection import BaseConnection from .base_node import BaseSWHNode class BaseRevisionNode(BaseSWHNode): def _get_revision_by_id(self, revision_id): return (archive.Archive().get_revisions([revision_id]) or None)[0] @property def parentSWHIDs(self): # To support the schema naming convention return [ CoreSWHID(object_type=ObjectType.REVISION, object_id=parent_id) for parent_id in self._node.parents ] @property def directorySWHID(self): # To support the schema naming convention """ """ return CoreSWHID( object_type=ObjectType.DIRECTORY, object_id=self._node.directory ) @property def type(self): return self._node.type.value def is_type_of(self): """ is_type_of is required only when resolving a UNION type This is for ariadne to return the right type """ return "Revision" class RevisionNode(BaseRevisionNode): """ When the revision is requested directly with its SWHID """ def _get_node_data(self): return self._get_revision_by_id(self.kwargs.get("SWHID").object_id) class TargetRevisionNode(BaseRevisionNode): """ When a revision is requested as a target self.obj could be a snapshotbranch or a release self.obj.targetHash is the requested revision id here """ def _get_node_data(self): return self._get_revision_by_id(self.obj.targetHash) class ParentRevisionConnection(BaseConnection): """ When parent revisions is requested from a revision self.obj is the current(child) revision self.obj.parentSWHIDs is the list of parent SWHIDs """ _node_class = BaseRevisionNode def _get_paged_result(self): # FIXME, using dummy(local) pagination, move pagination to backend # To remove localpagination, just drop the paginated call # STORAGE-TODO (pagination) parents = archive.Archive().get_revisions( [x.object_id for x in self.obj.parentSWHIDs] ) return utils.paginated(parents, self._get_first_arg(), self._get_after_arg()) class LogRevisionConnection(BaseConnection): """ When revisionslog is requested from a revision self.obj is the current revision id """ _node_class = BaseRevisionNode def _get_paged_result(self): # STORAGE-TODO (date in revisionlog is a dict) log = archive.Archive().get_revision_log([self.obj.SWHID.object_id]) # FIXME, using dummy(local) pagination, move pagination to backend # To remove localpagination, just drop the paginated call # STORAGE-TODO (pagination) return utils.paginated(log, self._get_first_arg(), self._get_after_arg()) diff --git a/swh/graphql/resolvers/scalars.py b/swh/graphql/resolvers/scalars.py index e21ffec..5d34594 100644 --- a/swh/graphql/resolvers/scalars.py +++ b/swh/graphql/resolvers/scalars.py @@ -1,46 +1,51 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + from datetime import datetime from ariadne import ScalarType from swh.graphql.utils import utils from swh.model.model import TimestampWithTimezone from swh.model.swhids import CoreSWHID datetime_scalar = ScalarType("DateTime") swhid_scalar = ScalarType("SWHID") id_scalar = ScalarType("ID") string_scalar = ScalarType("String") @id_scalar.serializer def serialize_id(value): if type(value) is bytes: return value.hex() return value @string_scalar.serializer def serialize_string(value): if type(value) is bytes: return value.decode("utf-8") return value @datetime_scalar.serializer def serialize_datetime(value): # FIXME, handle error and return None if type(value) == TimestampWithTimezone: value = value.to_datetime() if type(value) == datetime: return utils.get_formatted_date(value) return None @swhid_scalar.value_parser def validate_swhid(value): return CoreSWHID.from_string(value) @swhid_scalar.serializer def serialize_swhid(value): return str(value) diff --git a/swh/graphql/resolvers/snapshot.py b/swh/graphql/resolvers/snapshot.py index 568a636..d0e81e6 100644 --- a/swh/graphql/resolvers/snapshot.py +++ b/swh/graphql/resolvers/snapshot.py @@ -1,54 +1,59 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + from swh.graphql.backends import archive from swh.graphql.utils import utils from swh.model.model import Snapshot from .base_connection import BaseConnection from .base_node import BaseSWHNode class BaseSnapshotNode(BaseSWHNode): def _get_snapshot_by_id(self, snapshot_id): # Return a Snapshot model object # branches is initialized as empty # Same pattern is used in directory return Snapshot(id=snapshot_id, branches={}) class SnapshotNode(BaseSnapshotNode): """ For directly accessing a snapshot with its SWHID """ def _get_node_data(self): """ """ snapshot_id = self.kwargs.get("SWHID").object_id if archive.Archive().is_snapshot_available([snapshot_id]): return self._get_snapshot_by_id(snapshot_id) return None class VisitSnapshotNode(BaseSnapshotNode): """ For accessing a snapshot from a visitstatus type """ def _get_node_data(self): """ self.obj is visitstatus here self.obj.snapshotSWHID is the requested snapshot SWHID """ snapshot_id = self.obj.snapshotSWHID.object_id return self._get_snapshot_by_id(snapshot_id) class OriginSnapshotConnection(BaseConnection): _node_class = BaseSnapshotNode def _get_paged_result(self): """ """ results = archive.Archive().get_origin_snapshots(self.obj.url) snapshots = [Snapshot(id=snapshot, branches={}) for snapshot in results] # FIXME, using dummy(local) pagination, move pagination to backend # To remove localpagination, just drop the paginated call # STORAGE-TODO return utils.paginated(snapshots, self._get_first_arg(), self._get_after_arg()) diff --git a/swh/graphql/resolvers/snapshot_branch.py b/swh/graphql/resolvers/snapshot_branch.py index 2c49e24..0bcd37b 100644 --- a/swh/graphql/resolvers/snapshot_branch.py +++ b/swh/graphql/resolvers/snapshot_branch.py @@ -1,73 +1,78 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + from collections import namedtuple from swh.graphql.backends import archive from swh.graphql.utils import utils from swh.storage.interface import PagedResult from .base_connection import BaseConnection from .base_node import BaseNode class SnapshotBranchNode(BaseNode): """ target field for this Node is a UNION in the schema It is resolved in resolvers.resolvers.py """ def _get_node_from_data(self, node_data): """ node_data is not a dict in this case overriding to support this special data structure """ # STORAGE-TODO; return an object in the normal format branch_name, branch_obj = node_data node = { "name": branch_name, "type": branch_obj.target_type.value, "target": branch_obj.target, } return namedtuple("NodeObj", node.keys())(*node.values()) @property def targetHash(self): # To support the schema naming convention return self._node.target class SnapshotBranchConnection(BaseConnection): _node_class = SnapshotBranchNode def _get_paged_result(self): """ When branches requested from a snapshot self.obj.SWHID is the snapshot SWHID here (as returned from resolvers/snapshot.py) """ result = archive.Archive().get_snapshot_branches( self.obj.SWHID.object_id, after=self._get_after_arg(), first=self._get_first_arg(), target_types=self.kwargs.get("types"), name_include=self.kwargs.get("nameInclude"), ) # FIXME Cursor must be a hex to be consistent with # the base class, hack to make that work end_cusrsor = ( result["next_branch"].hex() if result["next_branch"] is not None else None ) # FIXME, this pagination is not consistent with other connections # FIX in swh-storage to return PagedResult # STORAGE-TODO return PagedResult( results=result["branches"].items(), next_page_token=end_cusrsor ) def _get_after_arg(self): """ Snapshot branch is using a different cursor; logic to handle that """ # FIXME Cursor must be a hex to be consistent with # the base class, hack to make that work after = utils.get_decoded_cursor(self.kwargs.get("after", "")) return bytes.fromhex(after) diff --git a/swh/graphql/resolvers/visit.py b/swh/graphql/resolvers/visit.py index 877f1fd..a2222d1 100644 --- a/swh/graphql/resolvers/visit.py +++ b/swh/graphql/resolvers/visit.py @@ -1,51 +1,56 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + from swh.graphql.backends import archive from swh.graphql.utils import utils from .base_connection import BaseConnection from .base_node import BaseNode class BaseVisitNode(BaseNode): @property def id(self): # FIXME, use a better id return utils.b64encode(f"{self.origin}-{str(self.visit)}") @property def visitId(self): # To support the schema naming convention return self._node.visit class OriginVisitNode(BaseVisitNode): """ Get the visit directly with an origin URL and a visit ID """ def _get_node_data(self): return archive.Archive().get_origin_visit( self.kwargs.get("originUrl"), int(self.kwargs.get("visitId")) ) class LatestVisitNode(BaseVisitNode): """ Get the latest visit for an origin self.obj is the origin object here self.obj.url is the origin URL """ def _get_node_data(self): return archive.Archive().get_origin_latest_visit(self.obj.url) class OriginVisitConnection(BaseConnection): _node_class = BaseVisitNode def _get_paged_result(self): """ Get the visits for the given origin parent obj (self.obj) is origin here """ return archive.Archive().get_origin_visits( self.obj.url, after=self._get_after_arg(), first=self._get_first_arg() ) diff --git a/swh/graphql/resolvers/visit_status.py b/swh/graphql/resolvers/visit_status.py index 970a69f..7c9b33d 100644 --- a/swh/graphql/resolvers/visit_status.py +++ b/swh/graphql/resolvers/visit_status.py @@ -1,43 +1,48 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + from swh.graphql.backends import archive from swh.model.swhids import CoreSWHID, ObjectType from .base_connection import BaseConnection from .base_node import BaseNode class BaseVisitStatusNode(BaseNode): """ """ @property def snapshotSWHID(self): # To support the schema naming convention return CoreSWHID(object_type=ObjectType.SNAPSHOT, object_id=self._node.snapshot) class LatestVisitStatusNode(BaseVisitStatusNode): """ Get the latest visit status for a visit self.obj is the visit object here self.obj.origin is the origin URL """ def _get_node_data(self): return archive.Archive().get_latest_visit_status( self.obj.origin, self.obj.visitId ) class VisitStatusConnection(BaseConnection): """ self.obj is the visit object self.obj.origin is the origin URL """ _node_class = BaseVisitStatusNode def _get_paged_result(self): return archive.Archive().get_visit_status( self.obj.origin, self.obj.visitId, after=self._get_after_arg(), first=self._get_first_arg(), ) diff --git a/swh/graphql/tests/unit/resolvers/test_base_node.py b/swh/graphql/tests/unit/resolvers/test_base_node.py index c5f5968..94b2e29 100644 --- a/swh/graphql/tests/unit/resolvers/test_base_node.py +++ b/swh/graphql/tests/unit/resolvers/test_base_node.py @@ -1,25 +1,30 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + """ """ class TestBaseNode: def test_init(self): pass def test_get_node(self): pass def test_get_node_from_data(self): pass def test_handle_node_errors(self): pass def test_get_node_data(self): pass def test_getattr(self): pass def test_is_type_of(sellf): pass diff --git a/swh/graphql/tests/unit/resolvers/test_resolver_factory.py b/swh/graphql/tests/unit/resolvers/test_resolver_factory.py index 4d39ce9..7bb4400 100644 --- a/swh/graphql/tests/unit/resolvers/test_resolver_factory.py +++ b/swh/graphql/tests/unit/resolvers/test_resolver_factory.py @@ -1,58 +1,63 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + import pytest from swh.graphql.resolvers import resolver_factory class TestFactory: @pytest.mark.parametrize( - "input_type, expexted", + "input_type, expected", [ ("origin", "OriginNode"), ("visit", "OriginVisitNode"), ("latest-visit", "LatestVisitNode"), ("latest-status", "LatestVisitStatusNode"), ("visit-snapshot", "VisitSnapshotNode"), ("snapshot", "SnapshotNode"), ("branch-revision", "TargetRevisionNode"), ("branch-release", "TargetReleaseNode"), ("revision", "RevisionNode"), ("revision-directory", "RevisionDirectoryNode"), ("release", "ReleaseNode"), ("release-revision", "TargetRevisionNode"), ("release-release", "TargetReleaseNode"), ("release-directory", "TargetDirectoryNode"), ("release-content", "TargetContentNode"), ("directory", "DirectoryNode"), ("content", "ContentNode"), ("dir-entry-dir", "TargetDirectoryNode"), ("dir-entry-file", "TargetContentNode"), ], ) - def test_get_node_resolver(self, input_type, expexted): + def test_get_node_resolver(self, input_type, expected): response = resolver_factory.get_node_resolver(input_type) - assert response.__name__ == expexted + assert response.__name__ == expected def test_get_node_resolver_invalid_type(self): with pytest.raises(AttributeError): resolver_factory.get_node_resolver("invalid") @pytest.mark.parametrize( - "input_type, expexted", + "input_type, expected", [ ("origins", "OriginConnection"), ("origin-visits", "OriginVisitConnection"), ("origin-snapshots", "OriginSnapshotConnection"), ("visit-status", "VisitStatusConnection"), ("snapshot-branches", "SnapshotBranchConnection"), ("revision-parents", "ParentRevisionConnection"), ("revision-log", "LogRevisionConnection"), ("directory-entries", "DirectoryEntryConnection"), ], ) - def test_get_connection_resolver(self, input_type, expexted): + def test_get_connection_resolver(self, input_type, expected): response = resolver_factory.get_connection_resolver(input_type) - assert response.__name__ == expexted + assert response.__name__ == expected def test_get_connection_resolver_invalid_type(self): with pytest.raises(AttributeError): resolver_factory.get_connection_resolver("invalid") diff --git a/swh/graphql/tests/unit/resolvers/test_resolvers.py b/swh/graphql/tests/unit/resolvers/test_resolvers.py index 6ac5464..ff29a8f 100644 --- a/swh/graphql/tests/unit/resolvers/test_resolvers.py +++ b/swh/graphql/tests/unit/resolvers/test_resolvers.py @@ -1,115 +1,120 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + import pytest from swh.graphql import resolvers from swh.graphql.resolvers import resolvers as rs class TestResolvers: """ """ @pytest.fixture def dummy_node(self): return {"test": "test"} @pytest.mark.parametrize( "resolver_func, node_cls", [ (rs.origin_resolver, resolvers.origin.OriginNode), (rs.visit_resolver, resolvers.visit.OriginVisitNode), (rs.latest_visit_resolver, resolvers.visit.LatestVisitNode), ( rs.latest_visit_status_resolver, resolvers.visit_status.LatestVisitStatusNode, ), (rs.snapshot_resolver, resolvers.snapshot.SnapshotNode), (rs.visit_snapshot_resolver, resolvers.snapshot.VisitSnapshotNode), (rs.revision_resolver, resolvers.revision.RevisionNode), (rs.revision_directory_resolver, resolvers.directory.RevisionDirectoryNode), (rs.release_resolver, resolvers.release.ReleaseNode), (rs.directory_resolver, resolvers.directory.DirectoryNode), (rs.content_resolver, resolvers.content.ContentNode), ], ) def test_node_resolver(self, mocker, dummy_node, resolver_func, node_cls): mock_get = mocker.patch.object(node_cls, "_get_node", return_value=dummy_node) node_obj = resolver_func(None, None) # assert the _get_node method is called on the right object assert isinstance(node_obj, node_cls) assert mock_get.assert_called @pytest.mark.parametrize( "resolver_func, connection_cls", [ (rs.origins_resolver, resolvers.origin.OriginConnection), (rs.visits_resolver, resolvers.visit.OriginVisitConnection), (rs.origin_snapshots_resolver, resolvers.snapshot.OriginSnapshotConnection), (rs.visitstatus_resolver, resolvers.visit_status.VisitStatusConnection), ( rs.snapshot_branches_resolver, resolvers.snapshot_branch.SnapshotBranchConnection, ), (rs.revision_parents_resolver, resolvers.revision.ParentRevisionConnection), # (rs.revision_log_resolver, resolvers.revision.LogRevisionConnection), ( rs.directory_entry_resolver, resolvers.directory_entry.DirectoryEntryConnection, ), ], ) def test_connection_resolver(self, resolver_func, connection_cls): connection_obj = resolver_func(None, None) # assert the right object is returned assert isinstance(connection_obj, connection_cls) @pytest.mark.parametrize( "branch_type, node_cls", [ ("revision", resolvers.revision.TargetRevisionNode), ("release", resolvers.release.TargetReleaseNode), ], ) def test_snapshot_branch_target_resolver( self, mocker, dummy_node, branch_type, node_cls ): obj = mocker.Mock(type=branch_type) mock_get = mocker.patch.object(node_cls, "_get_node", return_value=dummy_node) node_obj = rs.snapshot_branch_target_resolver(obj, None) assert isinstance(node_obj, node_cls) assert mock_get.assert_called @pytest.mark.parametrize( "target_type, node_cls", [ ("revision", resolvers.revision.TargetRevisionNode), ("release", resolvers.release.TargetReleaseNode), ("directory", resolvers.directory.TargetDirectoryNode), ("content", resolvers.content.TargetContentNode), ], ) def test_release_target_resolver(self, mocker, dummy_node, target_type, node_cls): obj = mocker.Mock(target_type=(mocker.Mock(value=target_type))) mock_get = mocker.patch.object(node_cls, "_get_node", return_value=dummy_node) node_obj = rs.release_target_resolver(obj, None) assert isinstance(node_obj, node_cls) assert mock_get.assert_called @pytest.mark.parametrize( "target_type, node_cls", [ ("dir", resolvers.directory.TargetDirectoryNode), ("file", resolvers.content.TargetContentNode), ], ) def test_directory_entry_target_resolver( self, mocker, dummy_node, target_type, node_cls ): obj = mocker.Mock(type=target_type) mock_get = mocker.patch.object(node_cls, "_get_node", return_value=dummy_node) node_obj = rs.directory_entry_target_resolver(obj, None) assert isinstance(node_obj, node_cls) assert mock_get.assert_called def test_unit_resolver(self, mocker): obj = mocker.Mock() obj.is_type_of.return_value = "test" assert rs.union_resolver(obj) == "test" diff --git a/swh/graphql/tests/unit/utils/test_utils.py b/swh/graphql/tests/unit/utils/test_utils.py index a692a8d..bc7c5ac 100644 --- a/swh/graphql/tests/unit/utils/test_utils.py +++ b/swh/graphql/tests/unit/utils/test_utils.py @@ -1,61 +1,66 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + import datetime from swh.graphql.utils import utils class TestUtils: def test_b64encode(self): assert utils.b64encode("testing") == "dGVzdGluZw==" def test_get_encoded_cursor_is_none(self): assert utils.get_encoded_cursor(None) is None def test_get_encoded_cursor(self): assert utils.get_encoded_cursor(None) is None assert utils.get_encoded_cursor("testing") == "dGVzdGluZw==" def test_get_decoded_cursor_is_none(self): assert utils.get_decoded_cursor(None) is None def test_get_decoded_cursor(self): assert utils.get_decoded_cursor("dGVzdGluZw==") == "testing" def test_str_to_sha1(self): assert ( utils.str_to_sha1("208f61cc7a5dbc9879ae6e5c2f95891e270f09ef") == b" \x8fa\xccz]\xbc\x98y\xaen\\/\x95\x89\x1e'\x0f\t\xef" ) def test_get_formatted_date(self): date = datetime.datetime( 2015, 8, 4, 22, 26, 14, 804009, tzinfo=datetime.timezone.utc ) assert utils.get_formatted_date(date) == "2015-08-04T22:26:14.804009+00:00" def test_paginated(self): source = [1, 2, 3, 4, 5] response = utils.paginated(source, first=50) assert response.results == source assert response.next_page_token is None def test_paginated_first_arg(self): source = [1, 2, 3, 4, 5] response = utils.paginated(source, first=2) assert response.results == source[:2] assert response.next_page_token == "2" def test_paginated_after_arg(self): source = [1, 2, 3, 4, 5] response = utils.paginated(source, first=2, after="2") assert response.results == [3, 4] assert response.next_page_token == "4" response = utils.paginated(source, first=2, after="3") assert response.results == [4, 5] assert response.next_page_token is None def test_paginated_endcursor_outside(self): source = [1, 2, 3, 4, 5] response = utils.paginated(source, first=2, after="10") assert response.results == [] assert response.next_page_token is None diff --git a/swh/graphql/utils/utils.py b/swh/graphql/utils/utils.py index 77748a7..c6f4094 100644 --- a/swh/graphql/utils/utils.py +++ b/swh/graphql/utils/utils.py @@ -1,49 +1,54 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + import base64 from datetime import datetime from typing import List from swh.storage.interface import PagedResult def b64encode(text: str) -> str: return base64.b64encode(bytes(text, "utf-8")).decode("utf-8") def get_encoded_cursor(cursor: str) -> str: if cursor is None: return None return b64encode(cursor) def get_decoded_cursor(cursor: str) -> str: if cursor is None: return None return base64.b64decode(cursor).decode("utf-8") def str_to_sha1(sha1: str) -> bytearray: # FIXME, use core function return bytearray.fromhex(sha1) def get_formatted_date(date: datetime) -> str: # FIXME, handle error + return other formats return date.isoformat() def paginated(source: List, first: int, after=0) -> PagedResult: """ Pagination at the GraphQL level This is a temporary fix and inefficient. Should eventually be moved to the backend (storage) level """ # FIXME, handle data errors here after = 0 if after is None else int(after) end_cursor = after + first results = source[after:end_cursor] next_page_token = None if len(source) > end_cursor: next_page_token = str(end_cursor) return PagedResult(results=results, next_page_token=next_page_token)