Page Menu
Home
Software Heritage
Search
Configure Global Search
Log In
Files
F9343804
No One
Temporary
Actions
View File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Flag For Later
Size
60 KB
Subscribers
None
View Options
diff --git a/setup.py b/setup.py
index 1728606..5312f69 100644
--- a/setup.py
+++ b/setup.py
@@ -1,68 +1,68 @@
#!/usr/bin/env python3
-# Copyright (C) 2019-2021 The Software Heritage developers
+# Copyright (C) 2022 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from io import open
from os import path
from setuptools import find_packages, setup
here = path.abspath(path.dirname(__file__))

# The README is published as the package long description
with open(path.join(here, "README.md"), encoding="utf-8") as readme:
    long_description = readme.read()
def parse_requirements(*names):
    """
    Collect requirement specifiers from requirements files

    A falsy name selects ``requirements.txt``; any other name selects
    ``requirements-<name>.txt``. Paths are relative to the current working
    directory. Missing files are skipped (the remaining names are still
    read), and blank lines and ``#`` comment lines are ignored.
    """
    requirements = []
    for name in names:
        if name:
            reqf = "requirements-%s.txt" % name
        else:
            reqf = "requirements.txt"
        if not path.exists(reqf):
            # Skip this file only; returning here would silently drop the
            # requirements of every following name
            continue
        with open(reqf) as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                requirements.append(line)
    return requirements
setup(
    name="swh.graphql",
    description="Software Heritage GraphQL Apis",
    long_description=long_description,
    # long_description is read from README.md, so it is Markdown; declaring
    # "text/x-rst" makes PyPI/twine treat the Markdown as reStructuredText
    long_description_content_type="text/markdown",
    python_requires=">=3.7",
    author="Software Heritage developers",
    author_email="swh-devel@inria.fr",
    url="https://forge.softwareheritage.org/diffusion/DGQL",
    packages=find_packages(),
    install_requires=parse_requirements(None, "swh"),
    tests_require=parse_requirements("test"),
    setup_requires=["setuptools-scm"],
    use_scm_version=True,
    extras_require={"testing": parse_requirements("test")},
    include_package_data=True,
    classifiers=[
        "Programming Language :: Python :: 3",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
        "Operating System :: OS Independent",
        "Development Status :: 3 - Alpha",
    ],
    project_urls={
        "Bug Reports": "https://forge.softwareheritage.org/maniphest",
        "Funding": "https://www.softwareheritage.org/donate",
        "Source": "https://forge.softwareheritage.org/source/swh-graphql",
        "Documentation": "https://docs.softwareheritage.org/devel/swh-graphql/",
    },
)
diff --git a/swh/graphql/app.py b/swh/graphql/app.py
index f438767..326e1b2 100644
--- a/swh/graphql/app.py
+++ b/swh/graphql/app.py
@@ -1,35 +1,40 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
# import pkg_resources
import os
from pathlib import Path
from ariadne import gql, load_schema_from_path, make_executable_schema
from .resolvers import resolvers, scalars
# SDL schema file shipped in the "schema" directory next to this module
type_defs = gql(
    load_schema_from_path(
        os.path.join(Path(__file__).parent.resolve(), "schema", "schema.graphql")
    )
)

# Bind the SDL to every resolver bindable and custom scalar
schema = make_executable_schema(
    type_defs,
    # object types
    resolvers.query,
    resolvers.origin,
    resolvers.visit,
    resolvers.visit_status,
    resolvers.snapshot,
    resolvers.snapshot_branch,
    resolvers.revision,
    resolvers.release,
    resolvers.directory,
    resolvers.directory_entry,
    # union types
    resolvers.branch_target,
    resolvers.release_target,
    resolvers.directory_entry_target,
    # custom scalars
    scalars.id_scalar,
    scalars.string_scalar,
    scalars.datetime_scalar,
    scalars.swhid_scalar,
)
diff --git a/swh/graphql/backends/archive.py b/swh/graphql/backends/archive.py
index 7ee5b50..644c29c 100644
--- a/swh/graphql/backends/archive.py
+++ b/swh/graphql/backends/archive.py
@@ -1,72 +1,77 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from swh.graphql import server
class Archive:
    """
    Thin wrapper around the storage returned by ``server.get_storage()``

    Each method forwards to a single storage call. Paginated methods use the
    GraphQL naming: ``after`` is the backend page token and ``first`` the
    page size.
    """

    def __init__(self):
        self.storage = server.get_storage()

    def get_origin(self, url):
        return self.storage.origin_get([url])[0]

    def get_origins(self, after=None, first=50, url_pattern=None):
        # STORAGE-TODO
        # Make them a single function in the backend
        if url_pattern is None:
            return self.storage.origin_list(page_token=after, limit=first)
        return self.storage.origin_search(
            url_pattern=url_pattern, page_token=after, limit=first
        )

    def get_origin_visits(self, origin_url, after=None, first=50):
        return self.storage.origin_visit_get(origin_url, page_token=after, limit=first)

    def get_origin_visit(self, origin_url, visit_id):
        return self.storage.origin_visit_get_by(origin_url, visit_id)

    def get_origin_latest_visit(self, origin_url):
        return self.storage.origin_visit_get_latest(origin_url)

    def get_visit_status(self, origin_url, visit_id, after=None, first=50):
        return self.storage.origin_visit_status_get(
            origin_url, visit_id, page_token=after, limit=first
        )

    def get_latest_visit_status(self, origin_url, visit_id):
        return self.storage.origin_visit_status_get_latest(origin_url, visit_id)

    def get_origin_snapshots(self, origin_url):
        return self.storage.origin_snapshot_get_all(origin_url)

    def is_snapshot_available(self, snapshot_ids):
        """
        Return True when none of ``snapshot_ids`` is reported missing
        """
        # snapshot_missing may return a lazy iterator (e.g. a generator),
        # which is always truthy; materialize it so emptiness is really tested
        return not list(self.storage.snapshot_missing(snapshot_ids))

    def get_snapshot_branches(self, snapshot, after, first, target_types, name_include):
        return self.storage.snapshot_get_branches(
            snapshot,
            branches_from=after,
            branches_count=first,
            target_types=target_types,
            branch_name_include_substring=name_include,
        )

    def get_revisions(self, revision_ids):
        return self.storage.revision_get(revision_ids=revision_ids)

    def get_revision_log(self, revision_ids, after=None, first=50):
        # NOTE: ``after`` is currently ignored; at most ``first`` log entries
        # are requested from the backend
        return self.storage.revision_log(revisions=revision_ids, limit=first)

    def get_releases(self, release_ids):
        return self.storage.release_get(releases=release_ids)

    def is_directory_available(self, directory_ids):
        """
        Return True when none of ``directory_ids`` is reported missing
        """
        # Same as is_snapshot_available: never rely on the truthiness of a
        # possibly lazy iterable
        return not list(self.storage.directory_missing(directory_ids))

    def get_directory_entries(self, directory_id, after=None, first=50):
        return self.storage.directory_get_entries(
            directory_id, limit=first, page_token=after
        )

    def get_content(self, content_id):
        # FIXME, only for tests
        return self.storage.content_find({"sha1_git": content_id})
diff --git a/swh/graphql/errors/__init__.py b/swh/graphql/errors/__init__.py
index 103f1bd..3e8bd41 100644
--- a/swh/graphql/errors/__init__.py
+++ b/swh/graphql/errors/__init__.py
@@ -1,4 +1,9 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from .errors import ObjectNotFoundError
from .handlers import format_error
# Public API of the swh.graphql.errors package
__all__ = [
    "ObjectNotFoundError",
    "format_error",
]
diff --git a/swh/graphql/errors/errors.py b/swh/graphql/errors/errors.py
index 6a587f5..8036b65 100644
--- a/swh/graphql/errors/errors.py
+++ b/swh/graphql/errors/errors.py
@@ -1,2 +1,8 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+
class ObjectNotFoundError(Exception):
    """
    Raised when a requested object cannot be found in the archive
    """
diff --git a/swh/graphql/errors/handlers.py b/swh/graphql/errors/handlers.py
index 8d13115..c61e593 100644
--- a/swh/graphql/errors/handlers.py
+++ b/swh/graphql/errors/handlers.py
@@ -1,7 +1,13 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+
def format_error(error) -> dict:
    """
    Format an error for the GraphQL response

    Every key produced by ``error.formatted`` is kept, but the message is
    always replaced by a generic one (presumably so internal details are not
    exposed to clients; confirm before relying on it)
    """
    response = error.formatted
    response.update(message="Unknown error")
    return response
diff --git a/swh/graphql/middlewares/graphql.py b/swh/graphql/middlewares/graphql.py
index a203330..9314edf 100644
--- a/swh/graphql/middlewares/graphql.py
+++ b/swh/graphql/middlewares/graphql.py
@@ -1,11 +1,16 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
"""
To implement graphql middleware
"""
class CostError(Exception):
    """
    Error for queries rejected by the (not yet implemented) cost limiter

    It subclasses Exception so that it can actually be raised and caught.
    """
def cost_limiter():
    """
    Placeholder for the query cost limiting middleware; does nothing yet
    """
    return None
diff --git a/swh/graphql/qns.txt b/swh/graphql/qns.txt
deleted file mode 100644
index 1e87ec6..0000000
--- a/swh/graphql/qns.txt
+++ /dev/null
@@ -1,14 +0,0 @@
-Questions
-=========
-* Idea for a homogeneous ID for node (can we expose primary key from postgres)
-* Why the visit status is a paginated list in storage?, and not in v1
-* visit id should be visit number
-... Schema related questions
-
-* What should we include in pageinfo (start token, haspreviouspage etc)
-* Query cost calculator logic (could be a bit complex)
-* Throttling based on Query cost calculator
-* Authentication and Authorization
-
-* Datetime as a time stamp
-* Other scalar types needed (swhid)
diff --git a/swh/graphql/resolvers/base_connection.py b/swh/graphql/resolvers/base_connection.py
index 617d2a7..f5f56b3 100644
--- a/swh/graphql/resolvers/base_connection.py
+++ b/swh/graphql/resolvers/base_connection.py
@@ -1,107 +1,112 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Type
from swh.graphql.utils import utils
from .base_node import BaseNode
@dataclass
class PageInfo:
    """
    Pagination details exposed on a connection's ``pageInfo`` field
    """

    # True when another page can be requested
    hasNextPage: bool
    # Cursor to pass as ``after`` to get the next page
    endCursor: str
class BaseConnection(ABC):
    """
    Base class for all the connection resolvers

    Subclasses implement ``_get_paged_result`` (an object exposing
    ``results`` and ``next_page_token``) and may set ``_node_class`` so that
    every raw result is wrapped in a node resolver.
    """

    _node_class: Optional[Type[BaseNode]] = None
    _page_size = 50  # default page size

    def __init__(self, obj, info, paged_data=None, **kwargs):
        self.obj = obj
        self.info = info
        self.kwargs = kwargs
        # Optional pre-fetched paged result; fetched lazily when None
        self._paged_data = paged_data

    def __call__(self, *args, **kw):
        return self

    @property
    def edges(self):
        return self._get_edges()

    @property
    def nodes(self):
        """
        Override if needed; return a list of objects
        If a node class is set, return a list of its (Node) instances
        else a list of raw results
        """
        results = self.get_paged_data().results
        if self._node_class is not None:
            return [
                self._node_class(self.obj, self.info, node_data=result, **self.kwargs)
                for result in results
            ]
        return results

    @property
    def pageInfo(self):  # To support the schema naming convention
        # FIXME, add more details like startCursor
        next_page_token = self.get_paged_data().next_page_token
        return PageInfo(
            hasNextPage=bool(next_page_token),
            endCursor=utils.get_encoded_cursor(next_page_token),
        )

    @property
    def totalCount(self):  # To support the schema naming convention
        return self._get_total_count()

    def _get_total_count(self):
        """
        Will be None for most of the connections
        override if needed/possible
        """
        return None

    def get_paged_data(self):
        """
        Cache to avoid multiple calls to
        the backend (_get_paged_result)
        return a PagedResult object
        """
        if self._paged_data is None:
            # FIXME, make this call async (not for v1)
            self._paged_data = self._get_paged_result()
        return self._paged_data

    @abstractmethod
    def _get_paged_result(self):
        """
        Override for desired behaviour
        return a PagedResult object
        """
        # FIXME, make this call async (not for v1)
        return None

    def _get_edges(self):
        # FIXME, make cursor work per item
        # Cursor can't be None here
        return [{"cursor": "dummy", "node": node} for node in self.nodes]

    def _get_after_arg(self):
        """
        Return the decoded next page token
        override to use a specific token
        """
        return utils.get_decoded_cursor(self.kwargs.get("after"))

    def _get_first_arg(self):
        """
        Return the requested page size

        Falls back to ``_page_size`` (50 unless overridden) when ``first`` is
        missing or explicitly null, so a null argument never reaches the
        backend as an unbounded ``limit=None``
        """
        first = self.kwargs.get("first")
        return self._page_size if first is None else first
diff --git a/swh/graphql/resolvers/base_node.py b/swh/graphql/resolvers/base_node.py
index 9f389f3..ff3c2b3 100644
--- a/swh/graphql/resolvers/base_node.py
+++ b/swh/graphql/resolvers/base_node.py
@@ -1,76 +1,81 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from abc import ABC
from collections import namedtuple
from swh.graphql.errors import ObjectNotFoundError
class BaseNode(ABC):
    """
    Base class for all the Node resolvers

    A node wraps one backend object (``_node``); attributes that are not
    defined on the node class itself are looked up on that object.
    """

    def __init__(self, obj, info, node_data=None, **kwargs):
        self.obj = obj
        self.info = info
        self.kwargs = kwargs
        self._node = self._get_node(node_data)
        # handle the errors, if any, after _node is set
        self._handle_node_errors()

    def _get_node(self, node_data):
        """
        Get the node object from the given data
        if the data (node_data) is none make
        a function call to get data from backend
        """
        if node_data is None:
            node_data = self._get_node_data()
        return self._get_node_from_data(node_data)

    def _get_node_from_data(self, node_data):
        """
        Get the object from node_data
        In case of a dict, convert it to an object
        Override to support different data structures
        """
        if type(node_data) is dict:
            return namedtuple("NodeObj", node_data.keys())(*node_data.values())
        return node_data

    def _handle_node_errors(self):
        """
        Handle any error related to node data
        raise an error in case the object returned is None
        override for specific behaviour
        """
        if self._node is None:
            raise ObjectNotFoundError("Requested object is not available")

    def __call__(self, *args, **kw):
        return self

    def _get_node_data(self):
        """
        Override for desired behaviour
        This will be called only when
        node_data is None
        """
        # FIXME, make this call async (not for v1)
        return None

    def __getattr__(self, name):
        """
        Any property defined in the sub-class
        will get precedence over the _node attributes
        """
        # __getattr__ only runs for attributes missing from the instance and
        # its class. When _node itself is missing (while _get_node_data runs
        # in __init__, or on an instance built by copy/unpickle without
        # __init__), looking it up here would recurse forever; fail fast
        # with an AttributeError instead.
        if name == "_node":
            raise AttributeError(f"{type(self).__name__} has no node data yet")
        return getattr(self._node, name)

    def is_type_of(self):
        return self.__class__.__name__
class BaseSWHNode(BaseNode):
    """
    Node whose wrapped object can compute its own SWHID via ``swhid()``
    """

    @property
    def SWHID(self):
        wrapped = self._node
        return wrapped.swhid()
diff --git a/swh/graphql/resolvers/content.py b/swh/graphql/resolvers/content.py
index 68d71e0..08d8637 100644
--- a/swh/graphql/resolvers/content.py
+++ b/swh/graphql/resolvers/content.py
@@ -1,44 +1,49 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from swh.graphql.backends import archive
from .base_node import BaseSWHNode
class BaseContentNode(BaseSWHNode):
    """
    Shared resolver logic for Content nodes
    """

    def _get_content_by_id(self, content_id):
        # Keep the first match, or None when the lookup found nothing
        matches = archive.Archive().get_content(content_id)
        if not matches:
            return None
        return matches[0]

    @property
    def checksum(self):
        # FIXME, return a Node object
        hex_hashes = {}
        for algo, digest in self._node.hashes().items():
            hex_hashes[algo] = digest.hex()
        return hex_hashes

    @property
    def id(self):
        return self._node.sha1_git

    def is_type_of(self):
        return "Content"
class ContentNode(BaseContentNode):
    """
    Content requested directly through the ``SWHID`` argument
    """

    def _get_node_data(self):
        """
        Look the content up by the object id of the requested SWHID
        """
        requested = self.kwargs.get("SWHID")
        return self._get_content_by_id(requested.object_id)
class TargetContentNode(BaseContentNode):
    """
    Content reached as the target of a directory entry or of a release
    """

    def _get_node_data(self):
        """
        self.obj is the parent (entry or release); its targetHash is the
        id of the requested content
        """
        return self._get_content_by_id(self.obj.targetHash)
diff --git a/swh/graphql/resolvers/directory.py b/swh/graphql/resolvers/directory.py
index 68c5d78..5c345b7 100644
--- a/swh/graphql/resolvers/directory.py
+++ b/swh/graphql/resolvers/directory.py
@@ -1,50 +1,55 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from swh.graphql.backends import archive
from swh.model.model import Directory
from .base_node import BaseSWHNode
class BaseDirectoryNode(BaseSWHNode):
    """
    Shared resolver logic for Directory nodes
    """

    def _get_directory_by_id(self, directory_id):
        # Build a Directory model object with no entries; the entries are
        # resolved separately through the "entries" field.
        # Same pattern is used in snapshot
        return Directory(entries=(), id=directory_id)

    def is_type_of(self):
        return "Directory"
class DirectoryNode(BaseDirectoryNode):
    def _get_node_data(self):
        """
        When a directory is requested directly with its SWHID

        Returns None when the archive does not hold the directory, so the
        node reports it as not found
        """
        directory_id = self.kwargs.get("SWHID").object_id
        if not archive.Archive().is_directory_available([directory_id]):
            return None
        return self._get_directory_by_id(directory_id)
class RevisionDirectoryNode(BaseDirectoryNode):
    def _get_node_data(self):
        """
        Root directory of a revision

        self.obj is the revision node; its directorySWHID property
        (resolvers.revision.BaseRevisionNode) holds the directory SWHID
        """
        root_swhid = self.obj.directorySWHID
        return self._get_directory_by_id(root_swhid.object_id)
class TargetDirectoryNode(BaseDirectoryNode):
    def _get_node_data(self):
        """
        Directory reached as a target

        self.obj is a Release or a DirectoryEntry node; its targetHash is
        the requested directory id
        """
        target_id = self.obj.targetHash
        return self._get_directory_by_id(target_id)
diff --git a/swh/graphql/resolvers/directory_entry.py b/swh/graphql/resolvers/directory_entry.py
index b4b9712..bbb94c9 100644
--- a/swh/graphql/resolvers/directory_entry.py
+++ b/swh/graphql/resolvers/directory_entry.py
@@ -1,34 +1,39 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from swh.graphql.backends import archive
from swh.graphql.utils import utils
from .base_connection import BaseConnection
from .base_node import BaseNode
class DirectoryEntryNode(BaseNode):
    """
    Node resolver for a single entry of a directory
    """

    @property
    def targetHash(self):  # To support the schema naming convention
        entry = self._node
        return entry.target
class DirectoryEntryConnection(BaseConnection):
    """
    Connection resolver for the entries of a directory
    """

    _node_class = DirectoryEntryNode
    # Number of entries requested per backend call while collecting a listing
    _backend_page_size = 1000

    def _get_paged_result(self):
        """
        When entries requested from a directory
        self.obj.SWHID is the directory SWHID here

        Client-side pagination is still local (dummy): the full listing is
        collected from swh-storage first, then sliced with utils.paginated
        """
        # FIXME, using dummy(local) pagination, move pagination to backend
        # To remove localpagination, just drop the paginated call
        # STORAGE-TODO
        entries = self._get_all_entries()
        return utils.paginated(entries, self._get_first_arg(), self._get_after_arg())

    def _get_all_entries(self):
        """
        Return every entry of the directory by following the backend pages

        Archive.get_directory_entries returns at most ``first`` entries per
        call (50 by default), so a single call would hide every entry past
        the first backend page from the local pagination
        """
        backend = archive.Archive()
        directory_id = self.obj.SWHID.object_id
        entries = []
        page_token = None
        while True:
            page = backend.get_directory_entries(
                directory_id, after=page_token, first=self._backend_page_size
            )
            if page is None:
                # NOTE(review): guards against a backend returning no page
                # (e.g. for an unknown directory) - confirm storage semantics
                break
            entries.extend(page.results)
            page_token = page.next_page_token
            if not page_token:
                break
        return entries
diff --git a/swh/graphql/resolvers/person.py b/swh/graphql/resolvers/person.py
index ca55b85..c805924 100644
--- a/swh/graphql/resolvers/person.py
+++ b/swh/graphql/resolvers/person.py
@@ -1,5 +1,10 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from .base_node import BaseNode
class PersonNode(BaseNode):
    """
    Node resolver for a person; its fields are read from the wrapped object
    """
diff --git a/swh/graphql/resolvers/release.py b/swh/graphql/resolvers/release.py
index 767b197..2187e2f 100644
--- a/swh/graphql/resolvers/release.py
+++ b/swh/graphql/resolvers/release.py
@@ -1,45 +1,50 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from swh.graphql.backends import archive
from .base_node import BaseSWHNode
class BaseReleaseNode(BaseSWHNode):
    """
    Shared resolver logic for Release nodes
    """

    def _get_release_by_id(self, release_id):
        """
        Return the release with the given id, or None when nothing is found

        An empty response is mapped to None instead of being indexed, so
        BaseNode._handle_node_errors raises ObjectNotFoundError rather than
        this method failing with a TypeError
        """
        releases = archive.Archive().get_releases([release_id])
        return releases[0] if releases else None

    @property
    def targetHash(self):  # To support the schema naming convention
        return self._node.target

    @property
    def targetType(self):  # To support the schema naming convention
        return self._node.target_type.value

    def is_type_of(self):
        """
        is_type_of is required only when resolving
        a UNION type
        This is for ariadne to return the right type
        """
        return "Release"
class ReleaseNode(BaseReleaseNode):
    """
    Release requested directly through the ``SWHID`` argument
    """

    def _get_node_data(self):
        requested = self.kwargs.get("SWHID")
        return self._get_release_by_id(requested.object_id)
class TargetReleaseNode(BaseReleaseNode):
    """
    Release reached as a target

    self.obj is a snapshot branch or a release; its targetHash is the id of
    the requested release
    """

    def _get_node_data(self):
        target_id = self.obj.targetHash
        return self._get_release_by_id(target_id)
diff --git a/swh/graphql/resolvers/resolver_factory.py b/swh/graphql/resolvers/resolver_factory.py
index 300da90..286573e 100644
--- a/swh/graphql/resolvers/resolver_factory.py
+++ b/swh/graphql/resolvers/resolver_factory.py
@@ -1,60 +1,65 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from .content import ContentNode, TargetContentNode
from .directory import DirectoryNode, RevisionDirectoryNode, TargetDirectoryNode
from .directory_entry import DirectoryEntryConnection
from .origin import OriginConnection, OriginNode
from .release import ReleaseNode, TargetReleaseNode
from .revision import (
LogRevisionConnection,
ParentRevisionConnection,
RevisionNode,
TargetRevisionNode,
)
from .snapshot import OriginSnapshotConnection, SnapshotNode, VisitSnapshotNode
from .snapshot_branch import SnapshotBranchConnection
from .visit import LatestVisitNode, OriginVisitConnection, OriginVisitNode
from .visit_status import LatestVisitStatusNode, VisitStatusConnection
def get_node_resolver(resolver_type):
    """
    Return the node resolver class registered for ``resolver_type``

    Raises AttributeError when the type is unknown
    """
    # FIXME, replace with a proper factory method
    node_resolvers = {
        "origin": OriginNode,
        "visit": OriginVisitNode,
        "latest-visit": LatestVisitNode,
        "latest-status": LatestVisitStatusNode,
        "visit-snapshot": VisitSnapshotNode,
        "snapshot": SnapshotNode,
        "branch-revision": TargetRevisionNode,
        "branch-release": TargetReleaseNode,
        "revision": RevisionNode,
        "revision-directory": RevisionDirectoryNode,
        "release": ReleaseNode,
        "release-revision": TargetRevisionNode,
        "release-release": TargetReleaseNode,
        "release-directory": TargetDirectoryNode,
        "release-content": TargetContentNode,
        "directory": DirectoryNode,
        "content": ContentNode,
        "dir-entry-dir": TargetDirectoryNode,
        "dir-entry-file": TargetContentNode,
    }
    resolver_class = node_resolvers.get(resolver_type)
    if resolver_class is None:
        raise AttributeError(f"Invalid node type: {resolver_type}")
    return resolver_class
def get_connection_resolver(resolver_type):
    """
    Return the connection resolver class registered for ``resolver_type``

    Raises AttributeError when the type is unknown
    """
    # FIXME, replace with a proper factory method
    connection_resolvers = {
        "origins": OriginConnection,
        "origin-visits": OriginVisitConnection,
        "origin-snapshots": OriginSnapshotConnection,
        "visit-status": VisitStatusConnection,
        "snapshot-branches": SnapshotBranchConnection,
        "revision-parents": ParentRevisionConnection,
        "revision-log": LogRevisionConnection,
        "directory-entries": DirectoryEntryConnection,
    }
    resolver_class = connection_resolvers.get(resolver_type)
    if resolver_class is None:
        raise AttributeError(f"Invalid connection type: {resolver_type}")
    return resolver_class
diff --git a/swh/graphql/resolvers/resolvers.py b/swh/graphql/resolvers/resolvers.py
index ad85786..3675fd3 100644
--- a/swh/graphql/resolvers/resolvers.py
+++ b/swh/graphql/resolvers/resolvers.py
@@ -1,240 +1,245 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
"""
High level resolvers
Any schema attribute can be resolved by any of the following ways
and in the following priority order
- In this module using an annotation (eg: @visitstatus.field("snapshot"))
- As a property in the Node object (eg: resolvers.visit.OriginVisitNode.id)
- As an attribute/item in the object/dict returned by the backend (eg: Origin.url)
"""
from ariadne import ObjectType, UnionType
from graphql.type import GraphQLResolveInfo
from swh.graphql import resolvers as rs
from .resolver_factory import get_connection_resolver, get_node_resolver
# ariadne bindables, passed to make_executable_schema in app.py
# Object types: field resolvers are attached with the @<type>.field decorators
query: ObjectType = ObjectType("Query")
origin: ObjectType = ObjectType("Origin")
visit: ObjectType = ObjectType("Visit")
visit_status: ObjectType = ObjectType("VisitStatus")
snapshot: ObjectType = ObjectType("Snapshot")
snapshot_branch: ObjectType = ObjectType("Branch")
revision: ObjectType = ObjectType("Revision")
release: ObjectType = ObjectType("Release")
directory: ObjectType = ObjectType("Directory")
directory_entry: ObjectType = ObjectType("DirectoryEntry")
# Union types: the concrete type is picked by union_resolver in this module
branch_target: UnionType = UnionType("BranchTarget")
release_target: UnionType = UnionType("ReleaseTarget")
directory_entry_target: UnionType = UnionType("DirectoryEntryTarget")
# Node resolvers
# A node resolver should return an instance of BaseNode
@query.field("origin")
def origin_resolver(obj: None, info: GraphQLResolveInfo, **kw) -> rs.origin.OriginNode:
    """
    Resolve the top-level ``origin`` query field
    """
    node_class = get_node_resolver("origin")
    node = node_class(obj, info, **kw)
    return node()
@origin.field("latestVisit")
def latest_visit_resolver(
    obj: rs.origin.OriginNode, info: GraphQLResolveInfo, **kw
) -> rs.visit.LatestVisitNode:
    """
    Resolve ``Origin.latestVisit``; obj is the parent origin node
    """
    node_class = get_node_resolver("latest-visit")
    node = node_class(obj, info, **kw)
    return node()
@query.field("visit")
def visit_resolver(
    obj: None, info: GraphQLResolveInfo, **kw
) -> rs.visit.OriginVisitNode:
    """
    Resolve the top-level ``visit`` query field
    """
    node_class = get_node_resolver("visit")
    node = node_class(obj, info, **kw)
    return node()
@visit.field("latestStatus")
def latest_visit_status_resolver(
    obj, info: GraphQLResolveInfo, **kw
) -> rs.visit_status.LatestVisitStatusNode:
    """
    Resolve ``Visit.latestStatus``; obj is the parent visit
    """
    node_class = get_node_resolver("latest-status")
    node = node_class(obj, info, **kw)
    return node()
@query.field("snapshot")
def snapshot_resolver(
    obj: None, info: GraphQLResolveInfo, **kw
) -> rs.snapshot.SnapshotNode:
    """
    Resolve the top-level ``snapshot`` query field
    """
    node_class = get_node_resolver("snapshot")
    node = node_class(obj, info, **kw)
    return node()
@visit_status.field("snapshot")
def visit_snapshot_resolver(
    obj, info: GraphQLResolveInfo, **kw
) -> rs.snapshot.VisitSnapshotNode:
    """
    Resolve ``VisitStatus.snapshot``; obj is the parent visit status
    """
    node_class = get_node_resolver("visit-snapshot")
    node = node_class(obj, info, **kw)
    return node()
@snapshot_branch.field("target")
def snapshot_branch_target_resolver(
    obj: rs.snapshot_branch.SnapshotBranchNode, info: GraphQLResolveInfo, **kw
):
    """
    Snapshot branch target can be a revision or a release

    The node resolver is looked up as "branch-<obj.type>".
    NOTE(review): only "branch-revision" and "branch-release" are registered
    in resolver_factory; any other branch type raises AttributeError there
    """
    resolver_type = f"branch-{obj.type}"
    target_class = get_node_resolver(resolver_type)
    target = target_class(obj, info, **kw)
    return target()
@query.field("revision")
def revision_resolver(
    obj: None, info: GraphQLResolveInfo, **kw
) -> rs.revision.RevisionNode:
    """
    Resolve the top-level ``revision`` query field
    """
    node_class = get_node_resolver("revision")
    node = node_class(obj, info, **kw)
    return node()
@revision.field("directory")
def revision_directory_resolver(
    obj, info: GraphQLResolveInfo, **kw
) -> rs.directory.RevisionDirectoryNode:
    """
    Resolve ``Revision.directory``; obj is the parent revision node
    """
    node_class = get_node_resolver("revision-directory")
    node = node_class(obj, info, **kw)
    return node()
@query.field("release")
def release_resolver(
    obj: None, info: GraphQLResolveInfo, **kw
) -> rs.release.ReleaseNode:
    """
    Resolve the top-level ``release`` query field
    """
    node_class = get_node_resolver("release")
    node = node_class(obj, info, **kw)
    return node()
@release.field("target")
def release_target_resolver(obj, info: GraphQLResolveInfo, **kw):
    """
    Resolve ``Release.target``

    The target can be a release, revision, directory or content; obj is
    the release node and the resolver is looked up as
    "release-<obj.target_type.value>"
    """
    resolver_type = f"release-{obj.target_type.value}"
    target_class = get_node_resolver(resolver_type)
    target = target_class(obj, info, **kw)
    return target()
@query.field("directory")
def directory_resolver(
    obj: None, info: GraphQLResolveInfo, **kw
) -> rs.directory.DirectoryNode:
    """
    Resolve the top-level ``directory`` query field
    """
    node_class = get_node_resolver("directory")
    node = node_class(obj, info, **kw)
    return node()
@directory_entry.field("target")
def directory_entry_target_resolver(
    obj: rs.directory_entry.DirectoryEntryNode, info: GraphQLResolveInfo, **kw
):
    """
    directory entry target can be a directory or a content

    The resolver is looked up as "dir-entry-<obj.type>".
    NOTE(review): only "dir-entry-dir" and "dir-entry-file" are registered in
    resolver_factory; any other entry type raises AttributeError there
    """
    resolver_type = f"dir-entry-{obj.type}"
    target_class = get_node_resolver(resolver_type)
    target = target_class(obj, info, **kw)
    return target()
@query.field("content")
def content_resolver(
    obj: None, info: GraphQLResolveInfo, **kw
) -> rs.content.ContentNode:
    """
    Resolve the top-level ``content`` query field
    """
    node_class = get_node_resolver("content")
    node = node_class(obj, info, **kw)
    return node()
# Connection resolvers
# A connection resolver should return an instance of BaseConnection
@query.field("origins")
def origins_resolver(
    obj: None, info: GraphQLResolveInfo, **kw
) -> rs.origin.OriginConnection:
    """
    Resolve the top-level ``origins`` connection
    """
    connection_class = get_connection_resolver("origins")
    connection = connection_class(obj, info, **kw)
    return connection()
@origin.field("visits")
def visits_resolver(
    obj: rs.origin.OriginNode, info: GraphQLResolveInfo, **kw
) -> rs.visit.OriginVisitConnection:
    """
    Resolve ``Origin.visits``; obj is the parent origin node
    """
    connection_class = get_connection_resolver("origin-visits")
    connection = connection_class(obj, info, **kw)
    return connection()
@origin.field("snapshots")
def origin_snapshots_resolver(
    obj: rs.origin.OriginNode, info: GraphQLResolveInfo, **kw
) -> rs.snapshot.OriginSnapshotConnection:
    """
    Resolve ``Origin.snapshots``; obj is the parent origin node
    """
    connection_class = get_connection_resolver("origin-snapshots")
    connection = connection_class(obj, info, **kw)
    return connection()
@visit.field("status")
def visitstatus_resolver(
    obj, info: GraphQLResolveInfo, **kw
) -> rs.visit_status.VisitStatusConnection:
    """
    Resolve ``Visit.status``; obj is the parent visit
    """
    connection_class = get_connection_resolver("visit-status")
    connection = connection_class(obj, info, **kw)
    return connection()
@snapshot.field("branches")
def snapshot_branches_resolver(
    obj, info: GraphQLResolveInfo, **kw
) -> rs.snapshot_branch.SnapshotBranchConnection:
    """
    Resolve ``Snapshot.branches``; obj is the parent snapshot
    """
    connection_class = get_connection_resolver("snapshot-branches")
    connection = connection_class(obj, info, **kw)
    return connection()
@revision.field("parents")
def revision_parents_resolver(
    obj, info: GraphQLResolveInfo, **kw
) -> rs.revision.ParentRevisionConnection:
    """
    Resolve ``Revision.parents``; obj is the child revision node
    """
    connection_class = get_connection_resolver("revision-parents")
    connection = connection_class(obj, info, **kw)
    return connection()
# @revision.field("revisionLog")
# def revision_log_resolver(obj, info, **kw):
# resolver = get_connection_resolver("revision-log")
# return resolver(obj, info, **kw)()
@directory.field("entries")
def directory_entry_resolver(
    obj, info: GraphQLResolveInfo, **kw
) -> rs.directory_entry.DirectoryEntryConnection:
    """
    Resolve ``Directory.entries``; obj is the parent directory node
    """
    connection_class = get_connection_resolver("directory-entries")
    connection = connection_class(obj, info, **kw)
    return connection()
# Any other type of resolver
@release_target.type_resolver
@directory_entry_target.type_resolver
@branch_target.type_resolver
def union_resolver(obj, *_) -> str:
    """
    Generic resolver for all the union types

    Every node class names its own GraphQL type through is_type_of()
    """
    type_name = obj.is_type_of()
    return type_name
diff --git a/swh/graphql/resolvers/revision.py b/swh/graphql/resolvers/revision.py
index 8cb6cd1..99da788 100644
--- a/swh/graphql/resolvers/revision.py
+++ b/swh/graphql/resolvers/revision.py
@@ -1,97 +1,102 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from swh.graphql.backends import archive
from swh.graphql.utils import utils
from swh.model.swhids import CoreSWHID, ObjectType
from .base_connection import BaseConnection
from .base_node import BaseSWHNode
class BaseRevisionNode(BaseSWHNode):
    """
    Shared resolver logic for Revision nodes
    """

    def _get_revision_by_id(self, revision_id):
        """
        Return the revision with the given id, or None when nothing is found

        An empty response is mapped to None instead of being indexed, so
        BaseNode._handle_node_errors raises ObjectNotFoundError rather than
        this method failing with a TypeError
        """
        revisions = archive.Archive().get_revisions([revision_id])
        return revisions[0] if revisions else None

    @property
    def parentSWHIDs(self):  # To support the schema naming convention
        return [
            CoreSWHID(object_type=ObjectType.REVISION, object_id=parent_id)
            for parent_id in self._node.parents
        ]

    @property
    def directorySWHID(self):  # To support the schema naming convention
        """
        SWHID of the root directory of this revision
        """
        return CoreSWHID(
            object_type=ObjectType.DIRECTORY, object_id=self._node.directory
        )

    @property
    def type(self):
        return self._node.type.value

    def is_type_of(self):
        """
        is_type_of is required only when resolving
        a UNION type
        This is for ariadne to return the right type
        """
        return "Revision"
class RevisionNode(BaseRevisionNode):
    """
    Revision requested directly through the ``SWHID`` argument
    """

    def _get_node_data(self):
        requested = self.kwargs.get("SWHID")
        return self._get_revision_by_id(requested.object_id)
class TargetRevisionNode(BaseRevisionNode):
    """
    Revision reached as the target of another object.

    ``self.obj`` may be a snapshot branch or a release; its ``targetHash``
    attribute holds the id of the requested revision.
    """

    def _get_node_data(self):
        revision_id = self.obj.targetHash
        return self._get_revision_by_id(revision_id)
class ParentRevisionConnection(BaseConnection):
    """
    Paginated parents of a revision.

    ``self.obj`` is the current (child) revision node and
    ``self.obj.parentSWHIDs`` lists the SWHIDs of its parents.
    """

    _node_class = BaseRevisionNode

    def _get_paged_result(self):
        # FIXME, using dummy(local) pagination, move pagination to backend
        # To remove localpagination, just drop the paginated call
        # STORAGE-TODO (pagination)
        backend = archive.Archive()
        parent_ids = [swhid.object_id for swhid in self.obj.parentSWHIDs]
        parents = backend.get_revisions(parent_ids)
        return utils.paginated(parents, self._get_first_arg(), self._get_after_arg())
class LogRevisionConnection(BaseConnection):
    """
    Paginated revision log, requested from a revision.

    ``self.obj`` is the current revision node; the log is fetched for
    ``self.obj.SWHID.object_id``.
    """

    _node_class = BaseRevisionNode

    def _get_paged_result(self):
        # STORAGE-TODO (date in revisionlog is a dict)
        # utils.paginated slices its source and calls len() on it, so turn
        # whatever iterable the backend returns into a list first.
        log = list(archive.Archive().get_revision_log([self.obj.SWHID.object_id]))
        # FIXME, using dummy(local) pagination, move pagination to backend
        # To remove localpagination, just drop the paginated call
        # STORAGE-TODO (pagination)
        return utils.paginated(log, self._get_first_arg(), self._get_after_arg())
diff --git a/swh/graphql/resolvers/scalars.py b/swh/graphql/resolvers/scalars.py
index e21ffec..5d34594 100644
--- a/swh/graphql/resolvers/scalars.py
+++ b/swh/graphql/resolvers/scalars.py
@@ -1,46 +1,51 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from datetime import datetime
from ariadne import ScalarType
from swh.graphql.utils import utils
from swh.model.model import TimestampWithTimezone
from swh.model.swhids import CoreSWHID
datetime_scalar = ScalarType("DateTime")
swhid_scalar = ScalarType("SWHID")
id_scalar = ScalarType("ID")
string_scalar = ScalarType("String")
@id_scalar.serializer
def serialize_id(value):
    """
    Serialize an ID scalar.

    ``bytes`` values are returned as their hex string (``bytes.hex()``);
    any other value is returned unchanged.
    """
    if isinstance(value, bytes):
        return value.hex()
    return value
@string_scalar.serializer
def serialize_string(value):
    """
    Serialize a String scalar.

    ``bytes`` values are decoded as UTF-8, with undecodable bytes replaced
    by U+FFFD; any other value is returned unchanged.
    """
    if isinstance(value, bytes):
        # Raw bytes coming from the archive (e.g. names) are not guaranteed
        # to be valid UTF-8; a strict decode would raise UnicodeDecodeError
        # and fail the whole field.
        return value.decode("utf-8", errors="replace")
    return value
@datetime_scalar.serializer
def serialize_datetime(value):
    """
    Serialize a DateTime scalar.

    A ``TimestampWithTimezone`` is first converted with ``to_datetime()``;
    a ``datetime`` (including subclasses) is formatted with
    ``utils.get_formatted_date``. Any other value yields None.
    """
    # FIXME, handle error and return None
    if isinstance(value, TimestampWithTimezone):
        value = value.to_datetime()
    if isinstance(value, datetime):
        return utils.get_formatted_date(value)
    return None
@swhid_scalar.value_parser
def validate_swhid(value):
    """Parse an incoming SWHID string into a ``CoreSWHID`` object."""
    parsed_swhid = CoreSWHID.from_string(value)
    return parsed_swhid
@swhid_scalar.serializer
def serialize_swhid(value):
    """Serialize a SWHID as its string representation (``str(value)``)."""
    swhid_text = str(value)
    return swhid_text
diff --git a/swh/graphql/resolvers/snapshot.py b/swh/graphql/resolvers/snapshot.py
index 568a636..d0e81e6 100644
--- a/swh/graphql/resolvers/snapshot.py
+++ b/swh/graphql/resolvers/snapshot.py
@@ -1,54 +1,59 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from swh.graphql.backends import archive
from swh.graphql.utils import utils
from swh.model.model import Snapshot
from .base_connection import BaseConnection
from .base_node import BaseSWHNode
class BaseSnapshotNode(BaseSWHNode):
    """Common base for the snapshot node types."""

    def _get_snapshot_by_id(self, snapshot_id):
        # Build a Snapshot model object with an empty ``branches`` mapping;
        # the same pattern is used for directories.
        return Snapshot(branches={}, id=snapshot_id)
class SnapshotNode(BaseSnapshotNode):
    """
    Snapshot accessed directly by its SWHID (the ``SWHID`` argument).
    """

    def _get_node_data(self):
        """
        Return the snapshot model object when the archive reports the
        snapshot as available, otherwise None.
        """
        snapshot_id = self.kwargs.get("SWHID").object_id
        if not archive.Archive().is_snapshot_available([snapshot_id]):
            return None
        return self._get_snapshot_by_id(snapshot_id)
class VisitSnapshotNode(BaseSnapshotNode):
    """
    For accessing a snapshot from a visitstatus type
    """

    def _get_node_data(self):
        """
        self.obj is visitstatus here
        self.obj.snapshotSWHID is the requested snapshot SWHID

        Returns None when the visit status exposes no snapshot SWHID.
        """
        snapshot_swhid = self.obj.snapshotSWHID
        # A visit status may carry no snapshot; return None (as SnapshotNode
        # does for unavailable snapshots) rather than fail on None.object_id.
        if snapshot_swhid is None:
            return None
        return self._get_snapshot_by_id(snapshot_swhid.object_id)
class OriginSnapshotConnection(BaseConnection):
    """Paginated snapshots of the parent origin (``self.obj``)."""

    _node_class = BaseSnapshotNode

    def _get_paged_result(self):
        """
        Wrap each snapshot id returned for ``self.obj.url`` in a Snapshot
        model object (with empty branches) and paginate the list locally.
        """
        snapshot_ids = archive.Archive().get_origin_snapshots(self.obj.url)
        snapshots = [
            Snapshot(id=snapshot_id, branches={}) for snapshot_id in snapshot_ids
        ]
        # FIXME, using dummy(local) pagination, move pagination to backend
        # To remove localpagination, just drop the paginated call
        # STORAGE-TODO
        return utils.paginated(snapshots, self._get_first_arg(), self._get_after_arg())
diff --git a/swh/graphql/resolvers/snapshot_branch.py b/swh/graphql/resolvers/snapshot_branch.py
index 2c49e24..0bcd37b 100644
--- a/swh/graphql/resolvers/snapshot_branch.py
+++ b/swh/graphql/resolvers/snapshot_branch.py
@@ -1,73 +1,78 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from collections import namedtuple
from swh.graphql.backends import archive
from swh.graphql.utils import utils
from swh.storage.interface import PagedResult
from .base_connection import BaseConnection
from .base_node import BaseNode
class SnapshotBranchNode(BaseNode):
    """
    target field for this Node is a UNION in the schema
    It is resolved in resolvers.resolvers.py
    """

    # Record type used as the node object. It is built once for the class
    # instead of creating a new namedtuple type on every call.
    _NodeObj = namedtuple("NodeObj", ["name", "type", "target"])

    def _get_node_from_data(self, node_data):
        """
        node_data is not a dict in this case
        overriding to support this special data structure

        node_data is a ``(branch_name, branch_obj)`` pair; the returned
        object exposes ``name``, ``type`` (``branch_obj.target_type.value``)
        and ``target``.
        """
        # STORAGE-TODO; return an object in the normal format
        branch_name, branch_obj = node_data
        # NOTE(review): swh snapshots can contain dangling branches whose
        # value is None; branch_obj.target_type would then raise. Confirm
        # whether the backend filters those out.
        return self._NodeObj(
            name=branch_name,
            type=branch_obj.target_type.value,
            target=branch_obj.target,
        )

    @property
    def targetHash(self):  # To support the schema naming convention
        """Id of the object this branch targets (``_node.target``)."""
        return self._node.target
class SnapshotBranchConnection(BaseConnection):
    """
    Paginated branches of a snapshot.

    Uses the backend's own branch-name cursor rather than the offset cursor
    used by the other connections.
    """

    _node_class = SnapshotBranchNode

    def _get_paged_result(self):
        """
        When branches requested from a snapshot
        self.obj.SWHID is the snapshot SWHID here
        (as returned from resolvers/snapshot.py)
        """
        result = archive.Archive().get_snapshot_branches(
            self.obj.SWHID.object_id,
            after=self._get_after_arg(),
            first=self._get_first_arg(),
            target_types=self.kwargs.get("types"),
            name_include=self.kwargs.get("nameInclude"),
        )
        # FIXME Cursor must be a hex to be consistent with
        # the base class, hack to make that work
        next_branch = result["next_branch"]
        end_cursor = next_branch.hex() if next_branch is not None else None
        # FIXME, this pagination is not consistent with other connections
        # FIX in swh-storage to return PagedResult
        # STORAGE-TODO
        return PagedResult(
            results=list(result["branches"].items()), next_page_token=end_cursor
        )

    def _get_after_arg(self):
        """
        Snapshot branch is using a different cursor; logic to handle that

        Decodes the ``after`` argument with ``utils.get_decoded_cursor`` and
        interprets the result as hex, returning raw bytes (b"" when no
        cursor is given).
        """
        # FIXME Cursor must be a hex to be consistent with
        # the base class, hack to make that work
        # ``or ""`` also covers an explicit ``after: null``, which previously
        # reached bytes.fromhex(None) and raised TypeError.
        after = utils.get_decoded_cursor(self.kwargs.get("after") or "")
        return bytes.fromhex(after)
diff --git a/swh/graphql/resolvers/visit.py b/swh/graphql/resolvers/visit.py
index 877f1fd..a2222d1 100644
--- a/swh/graphql/resolvers/visit.py
+++ b/swh/graphql/resolvers/visit.py
@@ -1,51 +1,56 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from swh.graphql.backends import archive
from swh.graphql.utils import utils
from .base_connection import BaseConnection
from .base_node import BaseNode
class BaseVisitNode(BaseNode):
    """Common base for the origin visit node types."""

    @property
    def id(self):
        """Base64 of ``"<origin>-<visit>"`` built from the node's attributes."""
        # FIXME, use a better id
        raw_id = f"{self.origin}-{self.visit!s}"
        return utils.b64encode(raw_id)

    @property
    def visitId(self):  # To support the schema naming convention
        """Visit number taken from ``_node.visit``."""
        return self._node.visit
class OriginVisitNode(BaseVisitNode):
    """
    A single visit, looked up by the ``originUrl`` and ``visitId`` arguments.
    """

    def _get_node_data(self):
        backend = archive.Archive()
        origin_url = self.kwargs.get("originUrl")
        visit_id = int(self.kwargs.get("visitId"))
        return backend.get_origin_visit(origin_url, visit_id)
class LatestVisitNode(BaseVisitNode):
    """
    Latest visit of an origin.

    ``self.obj`` is the parent origin object; its ``url`` selects the visits.
    """

    def _get_node_data(self):
        backend = archive.Archive()
        return backend.get_origin_latest_visit(self.obj.url)
class OriginVisitConnection(BaseConnection):
    """Paginated visits of the parent origin (``self.obj``)."""

    _node_class = BaseVisitNode

    def _get_paged_result(self):
        """
        Fetch one page of visits for ``self.obj.url`` from the archive,
        forwarding the connection's ``after`` and ``first`` arguments.
        """
        backend = archive.Archive()
        origin_url = self.obj.url
        return backend.get_origin_visits(
            origin_url,
            after=self._get_after_arg(),
            first=self._get_first_arg(),
        )
diff --git a/swh/graphql/resolvers/visit_status.py b/swh/graphql/resolvers/visit_status.py
index 970a69f..7c9b33d 100644
--- a/swh/graphql/resolvers/visit_status.py
+++ b/swh/graphql/resolvers/visit_status.py
@@ -1,43 +1,48 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from swh.graphql.backends import archive
from swh.model.swhids import CoreSWHID, ObjectType
from .base_connection import BaseConnection
from .base_node import BaseNode
class BaseVisitStatusNode(BaseNode):
    """
    Common base for the visit status node types.
    """

    @property
    def snapshotSWHID(self):  # To support the schema naming convention
        """
        SWHID of the snapshot recorded by this visit status, or None when
        the status has no snapshot.
        """
        # The model's ``snapshot`` field is optional; CoreSWHID cannot be
        # built from a None object_id, so report "no snapshot" as None.
        if self._node.snapshot is None:
            return None
        return CoreSWHID(object_type=ObjectType.SNAPSHOT, object_id=self._node.snapshot)
class LatestVisitStatusNode(BaseVisitStatusNode):
    """
    Latest status of a visit.

    ``self.obj`` is the parent visit node; ``origin`` is the origin URL and
    ``visitId`` the visit number.
    """

    def _get_node_data(self):
        visit = self.obj
        return archive.Archive().get_latest_visit_status(visit.origin, visit.visitId)
class VisitStatusConnection(BaseConnection):
    """
    Paginated statuses of a visit.

    ``self.obj`` is the parent visit node; ``self.obj.origin`` is the
    origin URL.
    """

    _node_class = BaseVisitStatusNode

    def _get_paged_result(self):
        visit = self.obj
        return archive.Archive().get_visit_status(
            visit.origin,
            visit.visitId,
            after=self._get_after_arg(),
            first=self._get_first_arg(),
        )
diff --git a/swh/graphql/tests/unit/resolvers/test_base_node.py b/swh/graphql/tests/unit/resolvers/test_base_node.py
index c5f5968..94b2e29 100644
--- a/swh/graphql/tests/unit/resolvers/test_base_node.py
+++ b/swh/graphql/tests/unit/resolvers/test_base_node.py
@@ -1,25 +1,30 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
"""
"""
class TestBaseNode:
    """Placeholder test cases for BaseNode; every method is still a stub."""

    def test_init(self):
        pass

    def test_get_node(self):
        pass

    def test_get_node_from_data(self):
        pass

    def test_handle_node_errors(self):
        pass

    def test_get_node_data(self):
        pass

    def test_getattr(self):
        pass

    def test_is_type_of(self):
        pass
diff --git a/swh/graphql/tests/unit/resolvers/test_resolver_factory.py b/swh/graphql/tests/unit/resolvers/test_resolver_factory.py
index 4d39ce9..7bb4400 100644
--- a/swh/graphql/tests/unit/resolvers/test_resolver_factory.py
+++ b/swh/graphql/tests/unit/resolvers/test_resolver_factory.py
@@ -1,58 +1,63 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
import pytest
from swh.graphql.resolvers import resolver_factory
# Tests for resolver_factory: each resolver key must map to the class whose
# __name__ is the expected string, and unknown keys raise AttributeError.
class TestFactory:
# Node resolvers: key -> expected node class name
@pytest.mark.parametrize(
- "input_type, expexted",
+ "input_type, expected",
[
("origin", "OriginNode"),
("visit", "OriginVisitNode"),
("latest-visit", "LatestVisitNode"),
("latest-status", "LatestVisitStatusNode"),
("visit-snapshot", "VisitSnapshotNode"),
("snapshot", "SnapshotNode"),
("branch-revision", "TargetRevisionNode"),
("branch-release", "TargetReleaseNode"),
("revision", "RevisionNode"),
("revision-directory", "RevisionDirectoryNode"),
("release", "ReleaseNode"),
("release-revision", "TargetRevisionNode"),
("release-release", "TargetReleaseNode"),
("release-directory", "TargetDirectoryNode"),
("release-content", "TargetContentNode"),
("directory", "DirectoryNode"),
("content", "ContentNode"),
("dir-entry-dir", "TargetDirectoryNode"),
("dir-entry-file", "TargetContentNode"),
],
)
- def test_get_node_resolver(self, input_type, expexted):
+ def test_get_node_resolver(self, input_type, expected):
response = resolver_factory.get_node_resolver(input_type)
- assert response.__name__ == expexted
+ assert response.__name__ == expected
# An unknown node key must raise AttributeError
def test_get_node_resolver_invalid_type(self):
with pytest.raises(AttributeError):
resolver_factory.get_node_resolver("invalid")
# Connection resolvers: key -> expected connection class name
@pytest.mark.parametrize(
- "input_type, expexted",
+ "input_type, expected",
[
("origins", "OriginConnection"),
("origin-visits", "OriginVisitConnection"),
("origin-snapshots", "OriginSnapshotConnection"),
("visit-status", "VisitStatusConnection"),
("snapshot-branches", "SnapshotBranchConnection"),
("revision-parents", "ParentRevisionConnection"),
("revision-log", "LogRevisionConnection"),
("directory-entries", "DirectoryEntryConnection"),
],
)
- def test_get_connection_resolver(self, input_type, expexted):
+ def test_get_connection_resolver(self, input_type, expected):
response = resolver_factory.get_connection_resolver(input_type)
- assert response.__name__ == expexted
+ assert response.__name__ == expected
# An unknown connection key must raise AttributeError
def test_get_connection_resolver_invalid_type(self):
with pytest.raises(AttributeError):
resolver_factory.get_connection_resolver("invalid")
diff --git a/swh/graphql/tests/unit/resolvers/test_resolvers.py b/swh/graphql/tests/unit/resolvers/test_resolvers.py
index 6ac5464..ff29a8f 100644
--- a/swh/graphql/tests/unit/resolvers/test_resolvers.py
+++ b/swh/graphql/tests/unit/resolvers/test_resolvers.py
@@ -1,115 +1,120 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
import pytest
from swh.graphql import resolvers
from swh.graphql.resolvers import resolvers as rs
class TestResolvers:
    """
    Unit tests for the resolver functions in swh.graphql.resolvers.resolvers.

    Node tests patch ``_get_node`` on the expected class so no archive
    access happens, then check the resolver built an instance of that class.
    """

    @pytest.fixture
    def dummy_node(self):
        return {"test": "test"}

    @pytest.mark.parametrize(
        "resolver_func, node_cls",
        [
            (rs.origin_resolver, resolvers.origin.OriginNode),
            (rs.visit_resolver, resolvers.visit.OriginVisitNode),
            (rs.latest_visit_resolver, resolvers.visit.LatestVisitNode),
            (
                rs.latest_visit_status_resolver,
                resolvers.visit_status.LatestVisitStatusNode,
            ),
            (rs.snapshot_resolver, resolvers.snapshot.SnapshotNode),
            (rs.visit_snapshot_resolver, resolvers.snapshot.VisitSnapshotNode),
            (rs.revision_resolver, resolvers.revision.RevisionNode),
            (rs.revision_directory_resolver, resolvers.directory.RevisionDirectoryNode),
            (rs.release_resolver, resolvers.release.ReleaseNode),
            (rs.directory_resolver, resolvers.directory.DirectoryNode),
            (rs.content_resolver, resolvers.content.ContentNode),
        ],
    )
    def test_node_resolver(self, mocker, dummy_node, resolver_func, node_cls):
        mock_get = mocker.patch.object(node_cls, "_get_node", return_value=dummy_node)
        node_obj = resolver_func(None, None)
        # assert the _get_node method is called on the right object
        assert isinstance(node_obj, node_cls)
        # ``assert mock_get.assert_called`` only tested that the bound method
        # is truthy (always true); calling it performs the real check.
        mock_get.assert_called()

    @pytest.mark.parametrize(
        "resolver_func, connection_cls",
        [
            (rs.origins_resolver, resolvers.origin.OriginConnection),
            (rs.visits_resolver, resolvers.visit.OriginVisitConnection),
            (rs.origin_snapshots_resolver, resolvers.snapshot.OriginSnapshotConnection),
            (rs.visitstatus_resolver, resolvers.visit_status.VisitStatusConnection),
            (
                rs.snapshot_branches_resolver,
                resolvers.snapshot_branch.SnapshotBranchConnection,
            ),
            (rs.revision_parents_resolver, resolvers.revision.ParentRevisionConnection),
            # (rs.revision_log_resolver, resolvers.revision.LogRevisionConnection),
            (
                rs.directory_entry_resolver,
                resolvers.directory_entry.DirectoryEntryConnection,
            ),
        ],
    )
    def test_connection_resolver(self, resolver_func, connection_cls):
        connection_obj = resolver_func(None, None)
        # assert the right object is returned
        assert isinstance(connection_obj, connection_cls)

    @pytest.mark.parametrize(
        "branch_type, node_cls",
        [
            ("revision", resolvers.revision.TargetRevisionNode),
            ("release", resolvers.release.TargetReleaseNode),
        ],
    )
    def test_snapshot_branch_target_resolver(
        self, mocker, dummy_node, branch_type, node_cls
    ):
        obj = mocker.Mock(type=branch_type)
        mock_get = mocker.patch.object(node_cls, "_get_node", return_value=dummy_node)
        node_obj = rs.snapshot_branch_target_resolver(obj, None)
        assert isinstance(node_obj, node_cls)
        mock_get.assert_called()

    @pytest.mark.parametrize(
        "target_type, node_cls",
        [
            ("revision", resolvers.revision.TargetRevisionNode),
            ("release", resolvers.release.TargetReleaseNode),
            ("directory", resolvers.directory.TargetDirectoryNode),
            ("content", resolvers.content.TargetContentNode),
        ],
    )
    def test_release_target_resolver(self, mocker, dummy_node, target_type, node_cls):
        obj = mocker.Mock(target_type=mocker.Mock(value=target_type))
        mock_get = mocker.patch.object(node_cls, "_get_node", return_value=dummy_node)
        node_obj = rs.release_target_resolver(obj, None)
        assert isinstance(node_obj, node_cls)
        mock_get.assert_called()

    @pytest.mark.parametrize(
        "target_type, node_cls",
        [
            ("dir", resolvers.directory.TargetDirectoryNode),
            ("file", resolvers.content.TargetContentNode),
        ],
    )
    def test_directory_entry_target_resolver(
        self, mocker, dummy_node, target_type, node_cls
    ):
        obj = mocker.Mock(type=target_type)
        mock_get = mocker.patch.object(node_cls, "_get_node", return_value=dummy_node)
        node_obj = rs.directory_entry_target_resolver(obj, None)
        assert isinstance(node_obj, node_cls)
        mock_get.assert_called()

    def test_union_resolver(self, mocker):
        obj = mocker.Mock()
        obj.is_type_of.return_value = "test"
        assert rs.union_resolver(obj) == "test"
diff --git a/swh/graphql/tests/unit/utils/test_utils.py b/swh/graphql/tests/unit/utils/test_utils.py
index a692a8d..bc7c5ac 100644
--- a/swh/graphql/tests/unit/utils/test_utils.py
+++ b/swh/graphql/tests/unit/utils/test_utils.py
@@ -1,61 +1,66 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
import datetime
from swh.graphql.utils import utils
class TestUtils:
    """Unit tests for the helpers in swh.graphql.utils.utils."""

    def test_b64encode(self):
        expected = "dGVzdGluZw=="
        assert utils.b64encode("testing") == expected

    def test_get_encoded_cursor_is_none(self):
        assert utils.get_encoded_cursor(None) is None

    def test_get_encoded_cursor(self):
        assert utils.get_encoded_cursor(None) is None
        encoded = utils.get_encoded_cursor("testing")
        assert encoded == "dGVzdGluZw=="

    def test_get_decoded_cursor_is_none(self):
        assert utils.get_decoded_cursor(None) is None

    def test_get_decoded_cursor(self):
        decoded = utils.get_decoded_cursor("dGVzdGluZw==")
        assert decoded == "testing"

    def test_str_to_sha1(self):
        hex_sha1 = "208f61cc7a5dbc9879ae6e5c2f95891e270f09ef"
        expected = b" \x8fa\xccz]\xbc\x98y\xaen\\/\x95\x89\x1e'\x0f\t\xef"
        assert utils.str_to_sha1(hex_sha1) == expected

    def test_get_formatted_date(self):
        utc_date = datetime.datetime(
            2015, 8, 4, 22, 26, 14, 804009, tzinfo=datetime.timezone.utc
        )
        expected = "2015-08-04T22:26:14.804009+00:00"
        assert utils.get_formatted_date(utc_date) == expected

    def test_paginated(self):
        items = [1, 2, 3, 4, 5]
        page = utils.paginated(items, first=50)
        assert page.results == items
        assert page.next_page_token is None

    def test_paginated_first_arg(self):
        items = [1, 2, 3, 4, 5]
        page = utils.paginated(items, first=2)
        assert page.results == items[:2]
        assert page.next_page_token == "2"

    def test_paginated_after_arg(self):
        items = [1, 2, 3, 4, 5]
        page = utils.paginated(items, first=2, after="2")
        assert page.results == [3, 4]
        assert page.next_page_token == "4"
        last_page = utils.paginated(items, first=2, after="3")
        assert last_page.results == [4, 5]
        assert last_page.next_page_token is None

    def test_paginated_endcursor_outside(self):
        items = [1, 2, 3, 4, 5]
        page = utils.paginated(items, first=2, after="10")
        assert page.results == []
        assert page.next_page_token is None
diff --git a/swh/graphql/utils/utils.py b/swh/graphql/utils/utils.py
index 77748a7..c6f4094 100644
--- a/swh/graphql/utils/utils.py
+++ b/swh/graphql/utils/utils.py
@@ -1,49 +1,54 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
import base64
from datetime import datetime
from typing import List, Optional

from swh.storage.interface import PagedResult
def b64encode(text: str) -> str:
    """Return the standard base64 encoding of *text*'s UTF-8 bytes, as a str."""
    encoded = base64.b64encode(text.encode("utf-8"))
    return encoded.decode("utf-8")
def get_encoded_cursor(cursor: Optional[str]) -> Optional[str]:
    """
    Encode a pagination cursor for clients.

    Returns the base64 form of *cursor* (via ``b64encode``); None is passed
    through unchanged.
    """
    if cursor is None:
        return None
    return b64encode(cursor)
def get_decoded_cursor(cursor: Optional[str]) -> Optional[str]:
    """
    Decode a base64 pagination cursor received from a client.

    Returns the decoded text (UTF-8); None is passed through unchanged.
    Invalid base64 input raises ``binascii.Error``.
    """
    if cursor is None:
        return None
    return base64.b64decode(cursor).decode("utf-8")
def str_to_sha1(sha1: str) -> bytearray:
    """Convert a hexadecimal sha1 string into a mutable ``bytearray``."""
    # FIXME, use core function
    raw = bytes.fromhex(sha1)
    return bytearray(raw)
def get_formatted_date(date: datetime) -> str:
    """Return *date* formatted as an ISO 8601 string (``date.isoformat()``)."""
    # FIXME, handle error + return other formats
    formatted = date.isoformat()
    return formatted
def paginated(source: List, first: int, after=0) -> PagedResult:
    """
    Pagination at the GraphQL level
    This is a temporary fix and inefficient.
    Should eventually be moved to the
    backend (storage) level

    :param source: the complete list of items to paginate
    :param first: maximum number of items to return; negative values are
        treated as 0
    :param after: offset (int or numeric string, e.g. a decoded cursor) of
        the first item to return; None means 0 and negative values are
        treated as 0
    :returns: PagedResult holding the selected items and, when more items
        remain after them, the next offset as a string (otherwise None)
    """
    # FIXME, handle data errors here
    # Clamp to >= 0: a negative offset or page size would otherwise slice
    # from the end of the list (e.g. after="-2") and return the wrong page.
    start = 0 if after is None else max(int(after), 0)
    end_cursor = start + max(first, 0)
    results = source[start:end_cursor]
    next_page_token = None
    if len(source) > end_cursor:
        next_page_token = str(end_cursor)
    return PagedResult(results=results, next_page_token=next_page_token)
File Metadata
Details
Attached
Mime Type
text/x-diff
Expires
Fri, Jul 4, 1:52 PM (4 d, 6 h ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3262774
Attached To
rDGQL GraphQL API
Event Timeline
Log In to Comment