Page MenuHomeSoftware Heritage

No OneTemporary

diff --git a/setup.py b/setup.py
index 1728606..5312f69 100644
--- a/setup.py
+++ b/setup.py
@@ -1,68 +1,68 @@
#!/usr/bin/env python3
-# Copyright (C) 2019-2021 The Software Heritage developers
+# Copyright (C) 2022 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from io import open
from os import path
from setuptools import find_packages, setup
# Directory containing this setup.py
here = path.abspath(path.dirname(__file__))

# The long description of the package is the content of README.md
readme_path = path.join(here, "README.md")
with open(readme_path, encoding="utf-8") as readme:
    long_description = readme.read()
def parse_requirements(*names):
    """
    Collect the requirement specifiers listed in requirements files.

    A falsy name (e.g. None) stands for "requirements.txt"; any other name
    stands for "requirements-<name>.txt". Files are looked up relative to the
    current working directory. Blank lines and lines starting with "#" are
    skipped. A missing file contributes nothing, and the remaining names are
    still processed.

    Returns the list of requirement lines, stripped, in file order.
    """
    requirements = []
    for name in names:
        reqf = "requirements-%s.txt" % name if name else "requirements.txt"
        if not path.exists(reqf):
            # Skip only this file: returning here would silently drop the
            # requirements of every following name.
            continue
        with open(reqf) as f:
            for line in f:
                line = line.strip()
                if line and not line.startswith("#"):
                    requirements.append(line)
    return requirements
setup(
    name="swh.graphql",
    description="Software Heritage GraphQL Apis",
    long_description=long_description,
    # long_description is the content of README.md, i.e. Markdown
    long_description_content_type="text/markdown",
    python_requires=">=3.7",
    author="Software Heritage developers",
    author_email="swh-devel@inria.fr",
    url="https://forge.softwareheritage.org/diffusion/DGQL",
    packages=find_packages(),
    install_requires=parse_requirements(None, "swh"),
    tests_require=parse_requirements("test"),
    setup_requires=["setuptools-scm"],
    use_scm_version=True,
    extras_require={"testing": parse_requirements("test")},
    include_package_data=True,
    classifiers=[
        "Programming Language :: Python :: 3",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
        "Operating System :: OS Independent",
        "Development Status :: 3 - Alpha",
    ],
    project_urls={
        "Bug Reports": "https://forge.softwareheritage.org/maniphest",
        "Funding": "https://www.softwareheritage.org/donate",
        "Source": "https://forge.softwareheritage.org/source/swh-graphql",
        "Documentation": "https://docs.softwareheritage.org/devel/swh-graphql/",
    },
)
diff --git a/swh/graphql/app.py b/swh/graphql/app.py
index f438767..326e1b2 100644
--- a/swh/graphql/app.py
+++ b/swh/graphql/app.py
@@ -1,35 +1,40 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
# import pkg_resources
import os
from pathlib import Path
from ariadne import gql, load_schema_from_path, make_executable_schema
from .resolvers import resolvers, scalars
# The SDL schema file lives in the "schema" directory next to this module
type_defs = gql(
    load_schema_from_path(
        os.path.join(Path(__file__).parent.resolve(), "schema", "schema.graphql")
    )
)

# Bind every resolver and scalar to the type definitions
schema = make_executable_schema(
    type_defs,
    # object types
    resolvers.query,
    resolvers.origin,
    resolvers.visit,
    resolvers.visit_status,
    resolvers.snapshot,
    resolvers.snapshot_branch,
    resolvers.revision,
    resolvers.release,
    resolvers.directory,
    resolvers.directory_entry,
    # union types
    resolvers.branch_target,
    resolvers.release_target,
    resolvers.directory_entry_target,
    # scalars
    scalars.id_scalar,
    scalars.string_scalar,
    scalars.datetime_scalar,
    scalars.swhid_scalar,
)
diff --git a/swh/graphql/backends/archive.py b/swh/graphql/backends/archive.py
index 7ee5b50..644c29c 100644
--- a/swh/graphql/backends/archive.py
+++ b/swh/graphql/backends/archive.py
@@ -1,72 +1,77 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from swh.graphql import server
class Archive:
    """
    Facade over the swh storage used by the GraphQL resolvers.

    Every method forwards to the storage returned by server.get_storage().
    """

    def __init__(self):
        self.storage = server.get_storage()

    def get_origin(self, url):
        return self.storage.origin_get([url])[0]

    def get_origins(self, after=None, first=50, url_pattern=None):
        # STORAGE-TODO
        # Make them a single function in the backend
        if url_pattern is None:
            return self.storage.origin_list(page_token=after, limit=first)
        return self.storage.origin_search(
            url_pattern=url_pattern, page_token=after, limit=first
        )

    def get_origin_visits(self, origin_url, after=None, first=50):
        return self.storage.origin_visit_get(origin_url, page_token=after, limit=first)

    def get_origin_visit(self, origin_url, visit_id):
        return self.storage.origin_visit_get_by(origin_url, visit_id)

    def get_origin_latest_visit(self, origin_url):
        return self.storage.origin_visit_get_latest(origin_url)

    def get_visit_status(self, origin_url, visit_id, after=None, first=50):
        return self.storage.origin_visit_status_get(
            origin_url, visit_id, page_token=after, limit=first
        )

    def get_latest_visit_status(self, origin_url, visit_id):
        return self.storage.origin_visit_status_get_latest(origin_url, visit_id)

    def get_origin_snapshots(self, origin_url):
        return self.storage.origin_snapshot_get_all(origin_url)

    def is_snapshot_available(self, snapshot_ids):
        """
        Return True when none of the given snapshot ids is missing.
        """
        # The backend result may be any iterable, e.g. a generator, which is
        # truthy even when empty: materialize it before testing emptiness.
        return not list(self.storage.snapshot_missing(snapshot_ids))

    def get_snapshot_branches(self, snapshot, after, first, target_types, name_include):
        return self.storage.snapshot_get_branches(
            snapshot,
            branches_from=after,
            branches_count=first,
            target_types=target_types,
            branch_name_include_substring=name_include,
        )

    def get_revisions(self, revision_ids):
        return self.storage.revision_get(revision_ids=revision_ids)

    def get_revision_log(self, revision_ids, after=None, first=50):
        # NOTE: `after` is not forwarded to the backend (ignored for now)
        return self.storage.revision_log(revisions=revision_ids, limit=first)

    def get_releases(self, release_ids):
        return self.storage.release_get(releases=release_ids)

    def is_directory_available(self, directory_ids):
        """
        Return True when none of the given directory ids is missing.
        """
        # Same as is_snapshot_available: never rely on the truthiness of a
        # possibly lazy iterable.
        return not list(self.storage.directory_missing(directory_ids))

    def get_directory_entries(self, directory_id, after=None, first=50):
        return self.storage.directory_get_entries(
            directory_id, limit=first, page_token=after
        )

    def get_content(self, content_id):
        # FIXME, only for tests
        return self.storage.content_find({"sha1_git": content_id})
diff --git a/swh/graphql/errors/__init__.py b/swh/graphql/errors/__init__.py
index 103f1bd..3e8bd41 100644
--- a/swh/graphql/errors/__init__.py
+++ b/swh/graphql/errors/__init__.py
@@ -1,4 +1,9 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from .errors import ObjectNotFoundError
from .handlers import format_error
__all__ = ["ObjectNotFoundError", "format_error"]
diff --git a/swh/graphql/errors/errors.py b/swh/graphql/errors/errors.py
index 6a587f5..8036b65 100644
--- a/swh/graphql/errors/errors.py
+++ b/swh/graphql/errors/errors.py
@@ -1,2 +1,8 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+
class ObjectNotFoundError(Exception):
    """
    Raised when a requested object is not available.
    """
diff --git a/swh/graphql/errors/handlers.py b/swh/graphql/errors/handlers.py
index 8d13115..c61e593 100644
--- a/swh/graphql/errors/handlers.py
+++ b/swh/graphql/errors/handlers.py
@@ -1,7 +1,13 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+
def format_error(error) -> dict:
    """
    Format an error for the GraphQL response.

    Every key of `error.formatted` is kept, except the message, which is
    replaced by the generic "Unknown error". The formatted dict is updated
    in place and returned.
    """
    response = error.formatted
    response.update({"message": "Unknown error"})
    return response
diff --git a/swh/graphql/middlewares/graphql.py b/swh/graphql/middlewares/graphql.py
index a203330..9314edf 100644
--- a/swh/graphql/middlewares/graphql.py
+++ b/swh/graphql/middlewares/graphql.py
@@ -1,11 +1,16 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
"""
To implement graphql middleware
"""
class CostError(Exception):
    """
    Placeholder error for query cost limiting (not implemented yet).

    It derives from Exception so that it can actually be raised and caught.
    """
def cost_limiter():
    """
    Placeholder for a query cost limiting middleware; does nothing yet.
    """
    return None
diff --git a/swh/graphql/qns.txt b/swh/graphql/qns.txt
deleted file mode 100644
index 1e87ec6..0000000
--- a/swh/graphql/qns.txt
+++ /dev/null
@@ -1,14 +0,0 @@
-Questions
-=========
-* Idea for a homogeneous ID for node (can we expose primary key from postgres)
-* Why the visit status is a paginated list in storage?, and not in v1
-* visit id should be visit number
-... Schema related questions
-
-* What should we include in pageinfo (start token, haspreviouspage etc)
-* Query cost calculator logic (could be a bit complex)
-* Throttling based on Query cost calculator
-* Authentication and Authorization
-
-* Datetime as a time stamp
-* Other scalar types needed (swhid)
diff --git a/swh/graphql/resolvers/base_connection.py b/swh/graphql/resolvers/base_connection.py
index 617d2a7..f5f56b3 100644
--- a/swh/graphql/resolvers/base_connection.py
+++ b/swh/graphql/resolvers/base_connection.py
@@ -1,107 +1,112 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Type
from swh.graphql.utils import utils
from .base_node import BaseNode
@dataclass
class PageInfo:
    """
    Pagination details of a connection (field names follow the schema).
    """

    # whether the backend reported a next page
    hasNextPage: bool
    # encoded next-page token; decoded back when given as the `after` argument
    endCursor: str
class BaseConnection(ABC):
    """
    Base class for all the connection resolvers.

    Subclasses implement `_get_paged_result`. Its result (a PagedResult) is
    fetched lazily, cached, and exposed through the schema fields `edges`,
    `nodes`, `pageInfo` and `totalCount`.
    """

    _node_class: Optional[Type[BaseNode]] = None
    _page_size = 50  # default page size

    def __init__(self, obj, info, paged_data=None, **kwargs):
        self.obj = obj
        self.info = info
        self.kwargs = kwargs
        self._paged_data = paged_data

    def __call__(self, *args, **kw):
        return self

    @property
    def edges(self):
        return self._get_edges()

    @property
    def nodes(self):
        """
        Override if needed; return a list of objects.

        If a node class is set, return a list of its (Node) instances,
        else a list of raw results.
        """
        results = self.get_paged_data().results
        if self._node_class is None:
            return results
        return [
            self._node_class(self.obj, self.info, node_data=result, **self.kwargs)
            for result in results
        ]

    @property
    def pageInfo(self):  # To support the schema naming convention
        # FIXME, add more details like startCursor
        paged_data = self.get_paged_data()
        return PageInfo(
            hasNextPage=bool(paged_data.next_page_token),
            endCursor=utils.get_encoded_cursor(paged_data.next_page_token),
        )

    @property
    def totalCount(self):  # To support the schema naming convention
        return self._get_total_count()

    def _get_total_count(self):
        """
        Will be None for most of the connections;
        override if needed/possible.
        """
        return None

    def get_paged_data(self):
        """
        Return the PagedResult, calling the backend (_get_paged_result)
        only once and caching the result.
        """
        if self._paged_data is None:
            # FIXME, make this call async (not for v1)
            self._paged_data = self._get_paged_result()
        return self._paged_data

    @abstractmethod
    def _get_paged_result(self):
        """
        Override for desired behaviour;
        return a PagedResult object.
        """
        # FIXME, make this call async (not for v1)
        return None

    def _get_edges(self):
        # FIXME, make cursor work per item
        # Cursor can't be None here
        return [{"cursor": "dummy", "node": node} for node in self.nodes]

    def _get_after_arg(self):
        """
        Return the decoded next page token;
        override to use a specific token.
        """
        return utils.get_decoded_cursor(self.kwargs.get("after"))

    def _get_first_arg(self):
        """
        Return the requested page size, or `_page_size` (50 by default)
        when `first` is omitted or explicitly null.
        """
        first = self.kwargs.get("first")
        # An explicit `first: null` arrives as None; treat it like an omitted
        # argument instead of passing None down to the backend.
        return self._page_size if first is None else first
diff --git a/swh/graphql/resolvers/base_node.py b/swh/graphql/resolvers/base_node.py
index 9f389f3..ff3c2b3 100644
--- a/swh/graphql/resolvers/base_node.py
+++ b/swh/graphql/resolvers/base_node.py
@@ -1,76 +1,81 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from abc import ABC
from collections import namedtuple
from swh.graphql.errors import ObjectNotFoundError
class BaseNode(ABC):
    """
    Base class for all the Node resolvers.

    The backend object is taken from `node_data`, or loaded with
    `_get_node_data` when the node is created. Attributes not defined on
    the node class are read from that object.
    """

    def __init__(self, obj, info, node_data=None, **kwargs):
        self.obj = obj
        self.info = info
        self.kwargs = kwargs
        self._node = self._get_node(node_data)
        # handle the errors, if any, after _node is set
        self._handle_node_errors()

    def _get_node(self, node_data):
        """
        Get the node object from the given data.

        If the data (node_data) is None, call _get_node_data to get the
        data from the backend.
        """
        if node_data is None:
            node_data = self._get_node_data()
        return self._get_node_from_data(node_data)

    def _get_node_from_data(self, node_data):
        """
        Get the object from node_data.

        A plain dict is converted to a namedtuple so that its keys can be
        read as attributes. Override to support different data structures.
        """
        if type(node_data) is dict:
            return namedtuple("NodeObj", node_data.keys())(*node_data.values())
        return node_data

    def _handle_node_errors(self):
        """
        Handle any error related to node data.

        Raise ObjectNotFoundError when the node object is None.
        Override for specific behaviour.
        """
        if self._node is None:
            raise ObjectNotFoundError("Requested object is not available")

    def __call__(self, *args, **kw):
        return self

    def _get_node_data(self):
        """
        Override for desired behaviour.

        This is called only when node_data is None.
        """
        # FIXME, make this call async (not for v1)
        return None

    def __getattr__(self, name):
        """
        Read attributes missing on the node from the wrapped object (_node).

        Any property defined in the sub-class takes precedence over the
        _node attributes, because __getattr__ only runs when normal lookup
        fails.
        """
        if name == "_node":
            # _node is not set yet: the instance was created without
            # __init__ (e.g. copy/pickle), or an attribute was looked up
            # before _node was assigned. Looking _node up again here would
            # recurse until RecursionError.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute '_node'"
            )
        return getattr(self._node, name)

    def is_type_of(self):
        return self.__class__.__name__
class BaseSWHNode(BaseNode):
    """
    Base class for the node resolvers exposing a SWHID field.
    """

    @property
    def SWHID(self):  # To support the schema naming convention
        # the wrapped object computes its own SWHID
        wrapped = self._node
        return wrapped.swhid()
diff --git a/swh/graphql/resolvers/content.py b/swh/graphql/resolvers/content.py
index 68d71e0..08d8637 100644
--- a/swh/graphql/resolvers/content.py
+++ b/swh/graphql/resolvers/content.py
@@ -1,44 +1,49 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from swh.graphql.backends import archive
from .base_node import BaseSWHNode
class BaseContentNode(BaseSWHNode):
    """
    Common behaviour of the Content node resolvers.
    """

    def _get_content_by_id(self, content_id):
        # get_content returns a sequence of matches; only the first is used
        matches = archive.Archive().get_content(content_id)
        if not matches:
            return None
        return matches[0]

    @property
    def checksum(self):
        # FIXME, return a Node object
        hashes = self._node.hashes()
        return {algo: digest.hex() for algo, digest in hashes.items()}

    @property
    def id(self):
        return self._node.sha1_git

    def is_type_of(self):
        return "Content"
class ContentNode(BaseContentNode):
    """
    Content requested directly with its SWHID.
    """

    def _get_node_data(self):
        content_swhid = self.kwargs.get("SWHID")
        return self._get_content_by_id(content_swhid.object_id)
class TargetContentNode(BaseContentNode):
    """
    Content reached from a directory entry or as a release target.

    self.obj is the parent object; its targetHash is the content id.
    """

    def _get_node_data(self):
        target_id = self.obj.targetHash
        return self._get_content_by_id(target_id)
diff --git a/swh/graphql/resolvers/directory.py b/swh/graphql/resolvers/directory.py
index 68c5d78..5c345b7 100644
--- a/swh/graphql/resolvers/directory.py
+++ b/swh/graphql/resolvers/directory.py
@@ -1,50 +1,55 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from swh.graphql.backends import archive
from swh.model.model import Directory
from .base_node import BaseSWHNode
class BaseDirectoryNode(BaseSWHNode):
    """
    Common behaviour of the Directory node resolvers.
    """

    def _get_directory_by_id(self, directory_id):
        # Build a bare Directory model object without any backend call;
        # its entries tuple is left empty (same pattern as in snapshot).
        return Directory(id=directory_id, entries=())

    def is_type_of(self):
        return "Directory"
class DirectoryNode(BaseDirectoryNode):
    """
    Directory requested directly with its SWHID.

    Resolves to None (object not found) when the archive reports the
    directory as missing.
    """

    def _get_node_data(self):
        directory_id = self.kwargs.get("SWHID").object_id
        if not archive.Archive().is_directory_available([directory_id]):
            return None
        return self._get_directory_by_id(directory_id)
class RevisionDirectoryNode(BaseDirectoryNode):
    """
    Directory requested from a revision.

    self.obj is the revision node; its directorySWHID holds the SWHID of
    the requested directory.
    """

    def _get_node_data(self):
        dir_swhid = self.obj.directorySWHID
        return self._get_directory_by_id(dir_swhid.object_id)
class TargetDirectoryNode(BaseDirectoryNode):
    """
    Directory requested as a target.

    self.obj can be a Release or a DirectoryEntry; its targetHash is the
    requested directory id.
    """

    def _get_node_data(self):
        target_id = self.obj.targetHash
        return self._get_directory_by_id(target_id)
diff --git a/swh/graphql/resolvers/directory_entry.py b/swh/graphql/resolvers/directory_entry.py
index b4b9712..bbb94c9 100644
--- a/swh/graphql/resolvers/directory_entry.py
+++ b/swh/graphql/resolvers/directory_entry.py
@@ -1,34 +1,39 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from swh.graphql.backends import archive
from swh.graphql.utils import utils
from .base_connection import BaseConnection
from .base_node import BaseNode
class DirectoryEntryNode(BaseNode):
    """
    Node resolver of a single directory entry.
    """

    @property
    def targetHash(self):  # To support the schema naming convention
        # the entry's `target` id is exposed as targetHash
        entry = self._node
        return entry.target
class DirectoryEntryConnection(BaseConnection):
    """
    Entries of a directory.

    self.obj is the directory node; self.obj.SWHID is the directory SWHID.
    """

    _node_class = DirectoryEntryNode
    # number of entries requested per backend call while collecting all the
    # entries of the directory
    _backend_page_size = 1000

    def _get_paged_result(self):
        # FIXME, using dummy(local) pagination, move pagination to backend
        # To remove localpagination, just drop the paginated call
        # STORAGE-TODO
        entries = self._get_all_entries(self.obj.SWHID.object_id)
        return utils.paginated(entries, self._get_first_arg(), self._get_after_arg())

    def _get_all_entries(self, directory_id):
        """
        Collect every entry of the directory by following the backend pages.

        Local pagination needs the complete list: a single backend call
        (50 entries with the Archive defaults) would hide the remaining
        entries. Return an empty list when the backend returns None
        (presumably an unknown directory).
        """
        backend = archive.Archive()
        entries = []
        page_token = None
        while True:
            page = backend.get_directory_entries(
                directory_id, after=page_token, first=self._backend_page_size
            )
            if page is None:
                break
            entries.extend(page.results)
            page_token = page.next_page_token
            if page_token is None:
                break
        return entries
diff --git a/swh/graphql/resolvers/person.py b/swh/graphql/resolvers/person.py
index ca55b85..c805924 100644
--- a/swh/graphql/resolvers/person.py
+++ b/swh/graphql/resolvers/person.py
@@ -1,5 +1,10 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from .base_node import BaseNode
class PersonNode(BaseNode):
    """
    Node resolver of a person; every field is read from the wrapped object.
    """
diff --git a/swh/graphql/resolvers/release.py b/swh/graphql/resolvers/release.py
index 767b197..2187e2f 100644
--- a/swh/graphql/resolvers/release.py
+++ b/swh/graphql/resolvers/release.py
@@ -1,45 +1,50 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from swh.graphql.backends import archive
from .base_node import BaseSWHNode
class BaseReleaseNode(BaseSWHNode):
    """
    Common behaviour of the Release node resolvers.
    """

    def _get_release_by_id(self, release_id):
        """
        Return the release with the given id, or None if the backend
        returns nothing for it.
        """
        # An empty backend result must give None (the node then reports
        # the object as not found), not a TypeError from indexing None.
        releases = archive.Archive().get_releases([release_id])
        return releases[0] if releases else None

    @property
    def targetHash(self):  # To support the schema naming convention
        return self._node.target

    @property
    def targetType(self):  # To support the schema naming convention
        return self._node.target_type.value

    def is_type_of(self):
        """
        is_type_of is required only when resolving a UNION type.

        This is for ariadne to return the right type.
        """
        return "Release"
class ReleaseNode(BaseReleaseNode):
    """
    Release requested directly with its SWHID.
    """

    def _get_node_data(self):
        release_swhid = self.kwargs.get("SWHID")
        return self._get_release_by_id(release_swhid.object_id)
class TargetReleaseNode(BaseReleaseNode):
    """
    Release requested as a target.

    self.obj could be a snapshot branch or a release; its targetHash is the
    requested release id.
    """

    def _get_node_data(self):
        target_id = self.obj.targetHash
        return self._get_release_by_id(target_id)
diff --git a/swh/graphql/resolvers/resolver_factory.py b/swh/graphql/resolvers/resolver_factory.py
index 300da90..286573e 100644
--- a/swh/graphql/resolvers/resolver_factory.py
+++ b/swh/graphql/resolvers/resolver_factory.py
@@ -1,60 +1,65 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from .content import ContentNode, TargetContentNode
from .directory import DirectoryNode, RevisionDirectoryNode, TargetDirectoryNode
from .directory_entry import DirectoryEntryConnection
from .origin import OriginConnection, OriginNode
from .release import ReleaseNode, TargetReleaseNode
from .revision import (
LogRevisionConnection,
ParentRevisionConnection,
RevisionNode,
TargetRevisionNode,
)
from .snapshot import OriginSnapshotConnection, SnapshotNode, VisitSnapshotNode
from .snapshot_branch import SnapshotBranchConnection
from .visit import LatestVisitNode, OriginVisitConnection, OriginVisitNode
from .visit_status import LatestVisitStatusNode, VisitStatusConnection
def get_node_resolver(resolver_type):
    """
    Return the node resolver class registered under `resolver_type`.

    Raise AttributeError for an unknown type.
    """
    # FIXME, replace with a proper factory method
    node_resolvers = {
        # nodes requested from the root Query
        "origin": OriginNode,
        "visit": OriginVisitNode,
        "snapshot": SnapshotNode,
        "revision": RevisionNode,
        "release": ReleaseNode,
        "directory": DirectoryNode,
        "content": ContentNode,
        # nodes reached from a parent object
        "latest-visit": LatestVisitNode,
        "latest-status": LatestVisitStatusNode,
        "visit-snapshot": VisitSnapshotNode,
        "revision-directory": RevisionDirectoryNode,
        # targets of snapshot branches, releases and directory entries
        "branch-revision": TargetRevisionNode,
        "branch-release": TargetReleaseNode,
        "release-revision": TargetRevisionNode,
        "release-release": TargetReleaseNode,
        "release-directory": TargetDirectoryNode,
        "release-content": TargetContentNode,
        "dir-entry-dir": TargetDirectoryNode,
        "dir-entry-file": TargetContentNode,
    }
    resolver_class = node_resolvers.get(resolver_type)
    if resolver_class is None:
        raise AttributeError(f"Invalid node type: {resolver_type}")
    return resolver_class
def get_connection_resolver(resolver_type):
    """
    Return the connection resolver class registered under `resolver_type`.

    Raise AttributeError for an unknown type.
    """
    # FIXME, replace with a proper factory method
    connection_resolvers = {
        "origins": OriginConnection,
        "origin-visits": OriginVisitConnection,
        "origin-snapshots": OriginSnapshotConnection,
        "visit-status": VisitStatusConnection,
        "snapshot-branches": SnapshotBranchConnection,
        "revision-parents": ParentRevisionConnection,
        "revision-log": LogRevisionConnection,
        "directory-entries": DirectoryEntryConnection,
    }
    resolver_class = connection_resolvers.get(resolver_type)
    if resolver_class is None:
        raise AttributeError(f"Invalid connection type: {resolver_type}")
    return resolver_class
diff --git a/swh/graphql/resolvers/resolvers.py b/swh/graphql/resolvers/resolvers.py
index ad85786..3675fd3 100644
--- a/swh/graphql/resolvers/resolvers.py
+++ b/swh/graphql/resolvers/resolvers.py
@@ -1,240 +1,245 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
"""
High level resolvers
Any schema attribute can be resolved by any of the following ways
and in the following priority order
- In this module using an annotation (eg: @visitstatus.field("snapshot"))
- As a property in the Node object (eg: resolvers.visit.OriginVisitNode.id)
- As an attribute/item in the object/dict returned by the backend (eg: Origin.url)
"""
from ariadne import ObjectType, UnionType
from graphql.type import GraphQLResolveInfo
from swh.graphql import resolvers as rs
from .resolver_factory import get_connection_resolver, get_node_resolver
query: ObjectType = ObjectType("Query")
origin: ObjectType = ObjectType("Origin")
visit: ObjectType = ObjectType("Visit")
visit_status: ObjectType = ObjectType("VisitStatus")
snapshot: ObjectType = ObjectType("Snapshot")
snapshot_branch: ObjectType = ObjectType("Branch")
revision: ObjectType = ObjectType("Revision")
release: ObjectType = ObjectType("Release")
directory: ObjectType = ObjectType("Directory")
directory_entry: ObjectType = ObjectType("DirectoryEntry")
branch_target: UnionType = UnionType("BranchTarget")
release_target: UnionType = UnionType("ReleaseTarget")
directory_entry_target: UnionType = UnionType("DirectoryEntryTarget")
# Node resolvers
# A node resolver should return an instance of BaseNode
@query.field("origin")
def origin_resolver(obj: None, info: GraphQLResolveInfo, **kw) -> rs.origin.OriginNode:
""" """
resolver = get_node_resolver("origin")
return resolver(obj, info, **kw)()
@origin.field("latestVisit")
def latest_visit_resolver(
obj: rs.origin.OriginNode, info: GraphQLResolveInfo, **kw
) -> rs.visit.LatestVisitNode:
""" """
resolver = get_node_resolver("latest-visit")
return resolver(obj, info, **kw)()
@query.field("visit")
def visit_resolver(
obj: None, info: GraphQLResolveInfo, **kw
) -> rs.visit.OriginVisitNode:
""" """
resolver = get_node_resolver("visit")
return resolver(obj, info, **kw)()
@visit.field("latestStatus")
def latest_visit_status_resolver(
obj, info: GraphQLResolveInfo, **kw
) -> rs.visit_status.LatestVisitStatusNode:
""" """
resolver = get_node_resolver("latest-status")
return resolver(obj, info, **kw)()
@query.field("snapshot")
def snapshot_resolver(
obj: None, info: GraphQLResolveInfo, **kw
) -> rs.snapshot.SnapshotNode:
""" """
resolver = get_node_resolver("snapshot")
return resolver(obj, info, **kw)()
@visit_status.field("snapshot")
def visit_snapshot_resolver(
obj, info: GraphQLResolveInfo, **kw
) -> rs.snapshot.VisitSnapshotNode:
resolver = get_node_resolver("visit-snapshot")
return resolver(obj, info, **kw)()
@snapshot_branch.field("target")
def snapshot_branch_target_resolver(
obj: rs.snapshot_branch.SnapshotBranchNode, info: GraphQLResolveInfo, **kw
):
"""
Snapshot branch target can be a revision or a release
"""
resolver_type = f"branch-{obj.type}"
resolver = get_node_resolver(resolver_type)
return resolver(obj, info, **kw)()
@query.field("revision")
def revision_resolver(
obj: None, info: GraphQLResolveInfo, **kw
) -> rs.revision.RevisionNode:
resolver = get_node_resolver("revision")
return resolver(obj, info, **kw)()
@revision.field("directory")
def revision_directory_resolver(
obj, info: GraphQLResolveInfo, **kw
) -> rs.directory.RevisionDirectoryNode:
resolver = get_node_resolver("revision-directory")
return resolver(obj, info, **kw)()
@query.field("release")
def release_resolver(
obj: None, info: GraphQLResolveInfo, **kw
) -> rs.release.ReleaseNode:
resolver = get_node_resolver("release")
return resolver(obj, info, **kw)()
@release.field("target")
def release_target_resolver(obj, info: GraphQLResolveInfo, **kw):
"""
release target can be a release, revision,
directory or content
obj is release here, target type is
obj.target_type
"""
resolver_type = f"release-{obj.target_type.value}"
resolver = get_node_resolver(resolver_type)
return resolver(obj, info, **kw)()
@query.field("directory")
def directory_resolver(
obj: None, info: GraphQLResolveInfo, **kw
) -> rs.directory.DirectoryNode:
resolver = get_node_resolver("directory")
return resolver(obj, info, **kw)()
@directory_entry.field("target")
def directory_entry_target_resolver(
obj: rs.directory_entry.DirectoryEntryNode, info: GraphQLResolveInfo, **kw
):
"""
directory entry target can be a directory or a content
"""
resolver_type = f"dir-entry-{obj.type}"
resolver = get_node_resolver(resolver_type)
return resolver(obj, info, **kw)()
@query.field("content")
def content_resolver(
obj: None, info: GraphQLResolveInfo, **kw
) -> rs.content.ContentNode:
resolver = get_node_resolver("content")
return resolver(obj, info, **kw)()
# Connection resolvers
# A connection resolver should return an instance of BaseConnection
@query.field("origins")
def origins_resolver(
obj: None, info: GraphQLResolveInfo, **kw
) -> rs.origin.OriginConnection:
resolver = get_connection_resolver("origins")
return resolver(obj, info, **kw)()
@origin.field("visits")
def visits_resolver(
obj: rs.origin.OriginNode, info: GraphQLResolveInfo, **kw
) -> rs.visit.OriginVisitConnection:
resolver = get_connection_resolver("origin-visits")
return resolver(obj, info, **kw)()
@origin.field("snapshots")
def origin_snapshots_resolver(
obj: rs.origin.OriginNode, info: GraphQLResolveInfo, **kw
) -> rs.snapshot.OriginSnapshotConnection:
""" """
resolver = get_connection_resolver("origin-snapshots")
return resolver(obj, info, **kw)()
@visit.field("status")
def visitstatus_resolver(
obj, info: GraphQLResolveInfo, **kw
) -> rs.visit_status.VisitStatusConnection:
resolver = get_connection_resolver("visit-status")
return resolver(obj, info, **kw)()
@snapshot.field("branches")
def snapshot_branches_resolver(
obj, info: GraphQLResolveInfo, **kw
) -> rs.snapshot_branch.SnapshotBranchConnection:
resolver = get_connection_resolver("snapshot-branches")
return resolver(obj, info, **kw)()
@revision.field("parents")
def revision_parents_resolver(
obj, info: GraphQLResolveInfo, **kw
) -> rs.revision.ParentRevisionConnection:
resolver = get_connection_resolver("revision-parents")
return resolver(obj, info, **kw)()
# @revision.field("revisionLog")
# def revision_log_resolver(obj, info, **kw):
# resolver = get_connection_resolver("revision-log")
# return resolver(obj, info, **kw)()
@directory.field("entries")
def directory_entry_resolver(
obj, info: GraphQLResolveInfo, **kw
) -> rs.directory_entry.DirectoryEntryConnection:
resolver = get_connection_resolver("directory-entries")
return resolver(obj, info, **kw)()
# Any other type of resolver
@release_target.type_resolver
@directory_entry_target.type_resolver
@branch_target.type_resolver
def union_resolver(obj, *_) -> str:
    """
    Generic type resolver shared by all the union types.

    The resolved node reports its own GraphQL type name via is_type_of().
    """
    type_name = obj.is_type_of()
    return type_name
diff --git a/swh/graphql/resolvers/revision.py b/swh/graphql/resolvers/revision.py
index 8cb6cd1..99da788 100644
--- a/swh/graphql/resolvers/revision.py
+++ b/swh/graphql/resolvers/revision.py
@@ -1,97 +1,102 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from swh.graphql.backends import archive
from swh.graphql.utils import utils
from swh.model.swhids import CoreSWHID, ObjectType
from .base_connection import BaseConnection
from .base_node import BaseSWHNode
class BaseRevisionNode(BaseSWHNode):
    """
    Common behaviour of the Revision node resolvers.
    """

    def _get_revision_by_id(self, revision_id):
        """
        Return the revision with the given id, or None if the backend
        returns nothing for it.
        """
        # An empty backend result must give None (the node then reports
        # the object as not found), not a TypeError from indexing None.
        revisions = archive.Archive().get_revisions([revision_id])
        return revisions[0] if revisions else None

    @property
    def parentSWHIDs(self):  # To support the schema naming convention
        return [
            CoreSWHID(object_type=ObjectType.REVISION, object_id=parent_id)
            for parent_id in self._node.parents
        ]

    @property
    def directorySWHID(self):  # To support the schema naming convention
        """
        SWHID of the directory referenced by the revision.
        """
        return CoreSWHID(
            object_type=ObjectType.DIRECTORY, object_id=self._node.directory
        )

    @property
    def type(self):
        return self._node.type.value

    def is_type_of(self):
        """
        is_type_of is required only when resolving a UNION type.

        This is for ariadne to return the right type.
        """
        return "Revision"
class RevisionNode(BaseRevisionNode):
    """
    Revision requested directly with its SWHID.
    """

    def _get_node_data(self):
        revision_swhid = self.kwargs.get("SWHID")
        return self._get_revision_by_id(revision_swhid.object_id)
class TargetRevisionNode(BaseRevisionNode):
    """
    Revision requested as a target.

    self.obj could be a snapshot branch or a release; its targetHash is the
    requested revision id.
    """

    def _get_node_data(self):
        target_id = self.obj.targetHash
        return self._get_revision_by_id(target_id)
class ParentRevisionConnection(BaseConnection):
    """
    Connection listing the parents of a revision.

    ``self.obj`` is the current (child) revision node; its ``parentSWHIDs``
    property lists the SWHIDs of the parents to fetch.
    """

    _node_class = BaseRevisionNode

    def _get_paged_result(self):
        parent_ids = [swhid.object_id for swhid in self.obj.parentSWHIDs]
        # FIXME, using dummy(local) pagination, move pagination to backend
        # To remove localpagination, just drop the paginated call
        # STORAGE-TODO (pagination)
        parent_revisions = archive.Archive().get_revisions(parent_ids)
        return utils.paginated(
            parent_revisions, self._get_first_arg(), self._get_after_arg()
        )
class LogRevisionConnection(BaseConnection):
    """
    Connection listing the revision log requested from a revision.

    ``self.obj`` is the revision node the log starts from; the id passed to
    the archive is ``self.obj.SWHID.object_id``.
    """

    _node_class = BaseRevisionNode

    def _get_paged_result(self):
        start_id = self.obj.SWHID.object_id
        # STORAGE-TODO (date in revisionlog is a dict)
        revision_log = archive.Archive().get_revision_log([start_id])
        # FIXME, using dummy(local) pagination, move pagination to backend
        # To remove localpagination, just drop the paginated call
        # STORAGE-TODO (pagination)
        return utils.paginated(
            revision_log, self._get_first_arg(), self._get_after_arg()
        )
diff --git a/swh/graphql/resolvers/scalars.py b/swh/graphql/resolvers/scalars.py
index e21ffec..5d34594 100644
--- a/swh/graphql/resolvers/scalars.py
+++ b/swh/graphql/resolvers/scalars.py
@@ -1,46 +1,51 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from datetime import datetime
from ariadne import ScalarType
from swh.graphql.utils import utils
from swh.model.model import TimestampWithTimezone
from swh.model.swhids import CoreSWHID
# Scalar types whose serializers / parsers are attached below.
# "ID" and "String" are GraphQL built-in scalar names overridden here so raw
# bytes can be serialized; "DateTime" and "SWHID" are custom scalars.
id_scalar = ScalarType("ID")
string_scalar = ScalarType("String")
datetime_scalar = ScalarType("DateTime")
swhid_scalar = ScalarType("SWHID")
@id_scalar.serializer
def serialize_id(value):
    """Render raw ``bytes`` ids as lowercase hex; pass any other value through."""
    return value.hex() if type(value) is bytes else value
@string_scalar.serializer
def serialize_string(value):
    """
    Serialize a GraphQL ``String`` value.

    ``bytes`` values are decoded as UTF-8 and any other value is returned
    unchanged. Invalid UTF-8 sequences are replaced with U+FFFD instead of
    raising UnicodeDecodeError. A single undecodable byte string therefore
    no longer fails serialization of the whole field.
    """
    if type(value) is bytes:
        return value.decode("utf-8", errors="replace")
    return value
@datetime_scalar.serializer
def serialize_datetime(value):
    """
    Serialize a ``DateTime`` value with ``utils.get_formatted_date``.

    A ``TimestampWithTimezone`` is first converted with ``to_datetime()``.
    A ``datetime`` (including subclasses) is then formatted. Any other value
    serializes to None.
    """
    # FIXME, handle error and return None
    if isinstance(value, TimestampWithTimezone):
        value = value.to_datetime()
    if isinstance(value, datetime):
        return utils.get_formatted_date(value)
    return None
@swhid_scalar.value_parser
def validate_swhid(value):
    """Parse an incoming SWHID string into a ``CoreSWHID`` object."""
    parsed_swhid = CoreSWHID.from_string(value)
    return parsed_swhid
@swhid_scalar.serializer
def serialize_swhid(value):
    """Render a SWHID as its string form via ``str()``."""
    as_text = str(value)
    return as_text
diff --git a/swh/graphql/resolvers/snapshot.py b/swh/graphql/resolvers/snapshot.py
index 568a636..d0e81e6 100644
--- a/swh/graphql/resolvers/snapshot.py
+++ b/swh/graphql/resolvers/snapshot.py
@@ -1,54 +1,59 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from swh.graphql.backends import archive
from swh.graphql.utils import utils
from swh.model.model import Snapshot
from .base_connection import BaseConnection
from .base_node import BaseSWHNode
class BaseSnapshotNode(BaseSWHNode):
    """Common base for snapshot nodes."""

    def _get_snapshot_by_id(self, snapshot_id):
        """
        Build a ``Snapshot`` model object for ``snapshot_id``.

        The branches mapping is left empty on purpose (the same pattern is
        used in directory); branches are fetched by a separate connection.
        """
        return Snapshot(branches={}, id=snapshot_id)
class SnapshotNode(BaseSnapshotNode):
    """
    For directly accessing a snapshot with its SWHID
    """

    def _get_node_data(self):
        """
        Return the snapshot named by the ``SWHID`` argument, or None when
        ``is_snapshot_available`` reports it as unavailable.
        """
        snapshot_id = self.kwargs.get("SWHID").object_id
        if not archive.Archive().is_snapshot_available([snapshot_id]):
            return None
        return self._get_snapshot_by_id(snapshot_id)
class VisitSnapshotNode(BaseSnapshotNode):
    """
    For accessing a snapshot from a visitstatus type
    """

    def _get_node_data(self):
        """
        self.obj is visitstatus here
        self.obj.snapshotSWHID is the requested snapshot SWHID

        Returns None when the visit status exposes no snapshot SWHID, instead
        of failing with AttributeError on ``None.object_id``.
        """
        snapshot_swhid = self.obj.snapshotSWHID
        if snapshot_swhid is None:
            # NOTE(review): a visit status may carry no snapshot (e.g. an
            # unfinished or failed visit) — confirm against the schema.
            return None
        return self._get_snapshot_by_id(snapshot_swhid.object_id)
class OriginSnapshotConnection(BaseConnection):
    """
    Connection listing the snapshots of an origin.

    ``self.obj`` is the origin; its ``url`` is used for the lookup.
    """

    _node_class = BaseSnapshotNode

    def _get_paged_result(self):
        """
        Wrap every snapshot id returned for the origin in a ``Snapshot``
        model object (with empty branches) and paginate the list locally.
        """
        snapshot_ids = archive.Archive().get_origin_snapshots(self.obj.url)
        snapshot_objs = [
            Snapshot(id=snapshot_id, branches={}) for snapshot_id in snapshot_ids
        ]
        # FIXME, using dummy(local) pagination, move pagination to backend
        # To remove localpagination, just drop the paginated call
        # STORAGE-TODO
        return utils.paginated(
            snapshot_objs, self._get_first_arg(), self._get_after_arg()
        )
diff --git a/swh/graphql/resolvers/snapshot_branch.py b/swh/graphql/resolvers/snapshot_branch.py
index 2c49e24..0bcd37b 100644
--- a/swh/graphql/resolvers/snapshot_branch.py
+++ b/swh/graphql/resolvers/snapshot_branch.py
@@ -1,73 +1,78 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from collections import namedtuple
from swh.graphql.backends import archive
from swh.graphql.utils import utils
from swh.storage.interface import PagedResult
from .base_connection import BaseConnection
from .base_node import BaseNode
class SnapshotBranchNode(BaseNode):
    """
    target field for this Node is a UNION in the schema
    It is resolved in resolvers.resolvers.py
    """

    # Record type for the flattened branch data. It is built once per class
    # instead of on every call, because creating a namedtuple type is costly.
    _NodeObj = namedtuple("NodeObj", ["name", "type", "target"])

    def _get_node_from_data(self, node_data):
        """
        node_data is not a dict in this case
        overriding to support this special data structure

        ``node_data`` is a ``(branch_name, branch_obj)`` pair. It is returned
        as an object with ``name``, ``type`` (the target type's ``value``)
        and ``target`` attributes.
        """
        # STORAGE-TODO; return an object in the normal format
        branch_name, branch_obj = node_data
        if branch_obj is None:
            # NOTE(review): presumably a dangling branch (no target); expose
            # null type/target instead of crashing on ``None.target_type``.
            # Confirm these fields are nullable in the schema.
            return self._NodeObj(name=branch_name, type=None, target=None)
        return self._NodeObj(
            name=branch_name,
            type=branch_obj.target_type.value,
            target=branch_obj.target,
        )

    @property
    def targetHash(self):  # To support the schema naming convention
        """Raw id of the object the branch points to."""
        return self._node.target
class SnapshotBranchConnection(BaseConnection):
    """
    Connection listing the branches of a snapshot.

    It uses its own byte-based cursor: the name of the next branch,
    hex-encoded.
    """

    _node_class = SnapshotBranchNode

    def _get_paged_result(self):
        """
        When branches requested from a snapshot
        self.obj.SWHID is the snapshot SWHID here
        (as returned from resolvers/snapshot.py)
        """
        result = archive.Archive().get_snapshot_branches(
            self.obj.SWHID.object_id,
            after=self._get_after_arg(),
            first=self._get_first_arg(),
            target_types=self.kwargs.get("types"),
            name_include=self.kwargs.get("nameInclude"),
        )
        # FIXME Cursor must be a hex to be consistent with
        # the base class, hack to make that work
        next_branch = result["next_branch"]
        end_cursor = next_branch.hex() if next_branch is not None else None
        # FIXME, this pagination is not consistent with other connections
        # FIX in swh-storage to return PagedResult
        # STORAGE-TODO
        return PagedResult(
            results=result["branches"].items(), next_page_token=end_cursor
        )

    def _get_after_arg(self):
        """
        Snapshot branch is using a different cursor; logic to handle that

        Returns the decoded ``after`` cursor as raw bytes. A missing cursor,
        an empty one or an explicit null all give ``b""``.
        """
        # FIXME Cursor must be a hex to be consistent with
        # the base class, hack to make that work
        # ``after`` may be present but None (explicit null argument); treat it
        # like a missing cursor rather than passing None to bytes.fromhex.
        after = utils.get_decoded_cursor(self.kwargs.get("after") or "")
        return bytes.fromhex(after)
diff --git a/swh/graphql/resolvers/visit.py b/swh/graphql/resolvers/visit.py
index 877f1fd..a2222d1 100644
--- a/swh/graphql/resolvers/visit.py
+++ b/swh/graphql/resolvers/visit.py
@@ -1,51 +1,56 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from swh.graphql.backends import archive
from swh.graphql.utils import utils
from .base_connection import BaseConnection
from .base_node import BaseNode
class BaseVisitNode(BaseNode):
    """Common base for origin visit nodes."""

    @property
    def id(self):
        """Opaque node id: Base64 of ``"<origin>-<visit>"``."""
        # FIXME, use a better id
        raw_id = f"{self.origin}-{self.visit}"
        return utils.b64encode(raw_id)

    @property
    def visitId(self):  # To support the schema naming convention
        """Visit number of the underlying visit object."""
        return self._node.visit
class OriginVisitNode(BaseVisitNode):
    """
    Visit fetched directly from an origin URL (``originUrl``) and a visit
    number (``visitId``, converted with ``int``).
    """

    def _get_node_data(self):
        origin_url = self.kwargs.get("originUrl")
        visit_number = int(self.kwargs.get("visitId"))
        return archive.Archive().get_origin_visit(origin_url, visit_number)
class LatestVisitNode(BaseVisitNode):
    """
    Most recent visit of an origin.

    ``self.obj`` is the origin object; its ``url`` identifies the origin.
    """

    def _get_node_data(self):
        origin_url = self.obj.url
        return archive.Archive().get_origin_latest_visit(origin_url)
class OriginVisitConnection(BaseConnection):
    """Connection listing the visits of an origin."""

    _node_class = BaseVisitNode

    def _get_paged_result(self):
        """
        Fetch one page of visits for the parent origin (``self.obj.url``).
        """
        origin_url = self.obj.url
        return archive.Archive().get_origin_visits(
            origin_url,
            after=self._get_after_arg(),
            first=self._get_first_arg(),
        )
diff --git a/swh/graphql/resolvers/visit_status.py b/swh/graphql/resolvers/visit_status.py
index 970a69f..7c9b33d 100644
--- a/swh/graphql/resolvers/visit_status.py
+++ b/swh/graphql/resolvers/visit_status.py
@@ -1,43 +1,48 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
from swh.graphql.backends import archive
from swh.model.swhids import CoreSWHID, ObjectType
from .base_connection import BaseConnection
from .base_node import BaseNode
class BaseVisitStatusNode(BaseNode):
    """
    Common base for visit status nodes.
    """

    @property
    def snapshotSWHID(self):  # To support the schema naming convention
        """
        SWHID of the snapshot recorded by this visit status, or None when
        the status has no snapshot id.
        """
        snapshot_id = self._node.snapshot
        if snapshot_id is None:
            # NOTE(review): assumes statuses without a snapshot store None here;
            # building a CoreSWHID from None would fail validation.
            return None
        return CoreSWHID(object_type=ObjectType.SNAPSHOT, object_id=snapshot_id)
class LatestVisitStatusNode(BaseVisitStatusNode):
    """
    Latest status of a visit.

    ``self.obj`` is the visit node: ``origin`` is the origin URL and
    ``visitId`` the visit number.
    """

    def _get_node_data(self):
        visit = self.obj
        return archive.Archive().get_latest_visit_status(visit.origin, visit.visitId)
class VisitStatusConnection(BaseConnection):
    """
    Connection listing the statuses of a visit.

    ``self.obj`` is the visit node: ``origin`` is the origin URL and
    ``visitId`` the visit number.
    """

    _node_class = BaseVisitStatusNode

    def _get_paged_result(self):
        visit = self.obj
        return archive.Archive().get_visit_status(
            visit.origin,
            visit.visitId,
            after=self._get_after_arg(),
            first=self._get_first_arg(),
        )
diff --git a/swh/graphql/tests/unit/resolvers/test_base_node.py b/swh/graphql/tests/unit/resolvers/test_base_node.py
index c5f5968..94b2e29 100644
--- a/swh/graphql/tests/unit/resolvers/test_base_node.py
+++ b/swh/graphql/tests/unit/resolvers/test_base_node.py
@@ -1,25 +1,30 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
"""
"""
class TestBaseNode:
    """Placeholder tests for the base node resolver; none is implemented yet."""

    def test_init(self):
        pass

    def test_get_node(self):
        pass

    def test_get_node_from_data(self):
        pass

    def test_handle_node_errors(self):
        pass

    def test_get_node_data(self):
        pass

    def test_getattr(self):
        pass

    def test_is_type_of(self):
        pass
diff --git a/swh/graphql/tests/unit/resolvers/test_resolver_factory.py b/swh/graphql/tests/unit/resolvers/test_resolver_factory.py
index 4d39ce9..7bb4400 100644
--- a/swh/graphql/tests/unit/resolvers/test_resolver_factory.py
+++ b/swh/graphql/tests/unit/resolvers/test_resolver_factory.py
@@ -1,58 +1,63 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
import pytest
from swh.graphql.resolvers import resolver_factory
class TestFactory:
    """Check that resolver_factory maps type keys to the right resolver classes."""

    @pytest.mark.parametrize(
        "input_type, expected",
        [
            ("origin", "OriginNode"),
            ("visit", "OriginVisitNode"),
            ("latest-visit", "LatestVisitNode"),
            ("latest-status", "LatestVisitStatusNode"),
            ("visit-snapshot", "VisitSnapshotNode"),
            ("snapshot", "SnapshotNode"),
            ("branch-revision", "TargetRevisionNode"),
            ("branch-release", "TargetReleaseNode"),
            ("revision", "RevisionNode"),
            ("revision-directory", "RevisionDirectoryNode"),
            ("release", "ReleaseNode"),
            ("release-revision", "TargetRevisionNode"),
            ("release-release", "TargetReleaseNode"),
            ("release-directory", "TargetDirectoryNode"),
            ("release-content", "TargetContentNode"),
            ("directory", "DirectoryNode"),
            ("content", "ContentNode"),
            ("dir-entry-dir", "TargetDirectoryNode"),
            ("dir-entry-file", "TargetContentNode"),
        ],
    )
    def test_get_node_resolver(self, input_type, expected):
        response = resolver_factory.get_node_resolver(input_type)
        assert response.__name__ == expected

    def test_get_node_resolver_invalid_type(self):
        with pytest.raises(AttributeError):
            resolver_factory.get_node_resolver("invalid")

    @pytest.mark.parametrize(
        "input_type, expected",
        [
            ("origins", "OriginConnection"),
            ("origin-visits", "OriginVisitConnection"),
            ("origin-snapshots", "OriginSnapshotConnection"),
            ("visit-status", "VisitStatusConnection"),
            ("snapshot-branches", "SnapshotBranchConnection"),
            ("revision-parents", "ParentRevisionConnection"),
            ("revision-log", "LogRevisionConnection"),
            ("directory-entries", "DirectoryEntryConnection"),
        ],
    )
    def test_get_connection_resolver(self, input_type, expected):
        response = resolver_factory.get_connection_resolver(input_type)
        assert response.__name__ == expected

    def test_get_connection_resolver_invalid_type(self):
        with pytest.raises(AttributeError):
            resolver_factory.get_connection_resolver("invalid")
diff --git a/swh/graphql/tests/unit/resolvers/test_resolvers.py b/swh/graphql/tests/unit/resolvers/test_resolvers.py
index 6ac5464..ff29a8f 100644
--- a/swh/graphql/tests/unit/resolvers/test_resolvers.py
+++ b/swh/graphql/tests/unit/resolvers/test_resolvers.py
@@ -1,115 +1,120 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
import pytest
from swh.graphql import resolvers
from swh.graphql.resolvers import resolvers as rs
class TestResolvers:
    """
    Check that each resolver function builds the expected node or
    connection class.
    """

    @pytest.fixture
    def dummy_node(self):
        return {"test": "test"}

    @pytest.mark.parametrize(
        "resolver_func, node_cls",
        [
            (rs.origin_resolver, resolvers.origin.OriginNode),
            (rs.visit_resolver, resolvers.visit.OriginVisitNode),
            (rs.latest_visit_resolver, resolvers.visit.LatestVisitNode),
            (
                rs.latest_visit_status_resolver,
                resolvers.visit_status.LatestVisitStatusNode,
            ),
            (rs.snapshot_resolver, resolvers.snapshot.SnapshotNode),
            (rs.visit_snapshot_resolver, resolvers.snapshot.VisitSnapshotNode),
            (rs.revision_resolver, resolvers.revision.RevisionNode),
            (rs.revision_directory_resolver, resolvers.directory.RevisionDirectoryNode),
            (rs.release_resolver, resolvers.release.ReleaseNode),
            (rs.directory_resolver, resolvers.directory.DirectoryNode),
            (rs.content_resolver, resolvers.content.ContentNode),
        ],
    )
    def test_node_resolver(self, mocker, dummy_node, resolver_func, node_cls):
        mock_get = mocker.patch.object(node_cls, "_get_node", return_value=dummy_node)
        node_obj = resolver_func(None, None)
        # assert the _get_node method is called on the right object
        assert isinstance(node_obj, node_cls)
        # assert_called must be *invoked*; merely referencing the bound method
        # is always truthy and verifies nothing.
        mock_get.assert_called()

    @pytest.mark.parametrize(
        "resolver_func, connection_cls",
        [
            (rs.origins_resolver, resolvers.origin.OriginConnection),
            (rs.visits_resolver, resolvers.visit.OriginVisitConnection),
            (rs.origin_snapshots_resolver, resolvers.snapshot.OriginSnapshotConnection),
            (rs.visitstatus_resolver, resolvers.visit_status.VisitStatusConnection),
            (
                rs.snapshot_branches_resolver,
                resolvers.snapshot_branch.SnapshotBranchConnection,
            ),
            (rs.revision_parents_resolver, resolvers.revision.ParentRevisionConnection),
            # (rs.revision_log_resolver, resolvers.revision.LogRevisionConnection),
            (
                rs.directory_entry_resolver,
                resolvers.directory_entry.DirectoryEntryConnection,
            ),
        ],
    )
    def test_connection_resolver(self, resolver_func, connection_cls):
        connection_obj = resolver_func(None, None)
        # assert the right object is returned
        assert isinstance(connection_obj, connection_cls)

    @pytest.mark.parametrize(
        "branch_type, node_cls",
        [
            ("revision", resolvers.revision.TargetRevisionNode),
            ("release", resolvers.release.TargetReleaseNode),
        ],
    )
    def test_snapshot_branch_target_resolver(
        self, mocker, dummy_node, branch_type, node_cls
    ):
        obj = mocker.Mock(type=branch_type)
        mock_get = mocker.patch.object(node_cls, "_get_node", return_value=dummy_node)
        node_obj = rs.snapshot_branch_target_resolver(obj, None)
        assert isinstance(node_obj, node_cls)
        mock_get.assert_called()

    @pytest.mark.parametrize(
        "target_type, node_cls",
        [
            ("revision", resolvers.revision.TargetRevisionNode),
            ("release", resolvers.release.TargetReleaseNode),
            ("directory", resolvers.directory.TargetDirectoryNode),
            ("content", resolvers.content.TargetContentNode),
        ],
    )
    def test_release_target_resolver(self, mocker, dummy_node, target_type, node_cls):
        obj = mocker.Mock(target_type=(mocker.Mock(value=target_type)))
        mock_get = mocker.patch.object(node_cls, "_get_node", return_value=dummy_node)
        node_obj = rs.release_target_resolver(obj, None)
        assert isinstance(node_obj, node_cls)
        mock_get.assert_called()

    @pytest.mark.parametrize(
        "target_type, node_cls",
        [
            ("dir", resolvers.directory.TargetDirectoryNode),
            ("file", resolvers.content.TargetContentNode),
        ],
    )
    def test_directory_entry_target_resolver(
        self, mocker, dummy_node, target_type, node_cls
    ):
        obj = mocker.Mock(type=target_type)
        mock_get = mocker.patch.object(node_cls, "_get_node", return_value=dummy_node)
        node_obj = rs.directory_entry_target_resolver(obj, None)
        assert isinstance(node_obj, node_cls)
        mock_get.assert_called()

    def test_union_resolver(self, mocker):
        obj = mocker.Mock()
        obj.is_type_of.return_value = "test"
        assert rs.union_resolver(obj) == "test"
diff --git a/swh/graphql/tests/unit/utils/test_utils.py b/swh/graphql/tests/unit/utils/test_utils.py
index a692a8d..bc7c5ac 100644
--- a/swh/graphql/tests/unit/utils/test_utils.py
+++ b/swh/graphql/tests/unit/utils/test_utils.py
@@ -1,61 +1,66 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
import datetime
from swh.graphql.utils import utils
class TestUtils:
    """Unit tests for the helpers in ``swh.graphql.utils.utils``."""

    def test_b64encode(self):
        encoded = utils.b64encode("testing")
        assert encoded == "dGVzdGluZw=="

    def test_get_encoded_cursor_is_none(self):
        assert utils.get_encoded_cursor(None) is None

    def test_get_encoded_cursor(self):
        assert utils.get_encoded_cursor(None) is None
        assert utils.get_encoded_cursor("testing") == "dGVzdGluZw=="

    def test_get_decoded_cursor_is_none(self):
        assert utils.get_decoded_cursor(None) is None

    def test_get_decoded_cursor(self):
        decoded = utils.get_decoded_cursor("dGVzdGluZw==")
        assert decoded == "testing"

    def test_str_to_sha1(self):
        expected = b" \x8fa\xccz]\xbc\x98y\xaen\\/\x95\x89\x1e'\x0f\t\xef"
        hex_sha1 = "208f61cc7a5dbc9879ae6e5c2f95891e270f09ef"
        assert utils.str_to_sha1(hex_sha1) == expected

    def test_get_formatted_date(self):
        when = datetime.datetime(
            2015, 8, 4, 22, 26, 14, 804009, tzinfo=datetime.timezone.utc
        )
        formatted = utils.get_formatted_date(when)
        assert formatted == "2015-08-04T22:26:14.804009+00:00"

    def test_paginated(self):
        items = [1, 2, 3, 4, 5]
        page = utils.paginated(items, first=50)
        assert page.results == items
        assert page.next_page_token is None

    def test_paginated_first_arg(self):
        items = [1, 2, 3, 4, 5]
        page = utils.paginated(items, first=2)
        assert page.results == items[:2]
        assert page.next_page_token == "2"

    def test_paginated_after_arg(self):
        items = [1, 2, 3, 4, 5]
        first_page = utils.paginated(items, first=2, after="2")
        assert first_page.results == [3, 4]
        assert first_page.next_page_token == "4"
        last_page = utils.paginated(items, first=2, after="3")
        assert last_page.results == [4, 5]
        assert last_page.next_page_token is None

    def test_paginated_endcursor_outside(self):
        items = [1, 2, 3, 4, 5]
        page = utils.paginated(items, first=2, after="10")
        assert page.results == []
        assert page.next_page_token is None
diff --git a/swh/graphql/utils/utils.py b/swh/graphql/utils/utils.py
index 77748a7..c6f4094 100644
--- a/swh/graphql/utils/utils.py
+++ b/swh/graphql/utils/utils.py
@@ -1,49 +1,54 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
import base64
from datetime import datetime
from typing import List, Optional

from swh.storage.interface import PagedResult
def b64encode(text: str) -> str:
    """Return the standard Base64 encoding of ``text``'s UTF-8 bytes, as a str."""
    raw_bytes = bytes(text, "utf-8")
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
def get_encoded_cursor(cursor: Optional[str]) -> Optional[str]:
    """
    Encode a pagination cursor before handing it to clients.

    Returns the Base64 form of ``cursor`` (via ``b64encode``), or None when
    there is no cursor.
    """
    if cursor is None:
        return None
    return b64encode(cursor)
def get_decoded_cursor(cursor: Optional[str]) -> Optional[str]:
    """
    Decode a client-supplied pagination cursor.

    Reverses ``get_encoded_cursor``: Base64-decodes ``cursor`` and returns
    the result as UTF-8 text, or None when there is no cursor.
    """
    if cursor is None:
        return None
    return base64.b64decode(cursor).decode("utf-8")
def str_to_sha1(sha1: str) -> bytearray:
    """Convert a hex-encoded SHA1 string into its raw bytes, as a ``bytearray``."""
    # FIXME, use core function
    raw = bytearray.fromhex(sha1)
    return raw
def get_formatted_date(date: datetime) -> str:
    """Format ``date`` as an ISO 8601 string via ``datetime.isoformat``."""
    # FIXME, handle error + return other formats
    formatted = date.isoformat()
    return formatted
def paginated(source: List, first: int, after=0) -> PagedResult:
    """
    Pagination at the GraphQL level
    This is a temporary fix and inefficient.
    Should eventually be moved to the
    backend (storage) level

    ``after`` is an offset into ``source`` (None or missing means 0; strings
    are converted with ``int``). Returns up to ``first`` items starting at
    that offset. ``next_page_token`` is the stringified end offset when more
    items remain, else None.
    """
    # FIXME, handle data errors here
    start = int(after) if after is not None else 0
    stop = start + first
    next_page_token = str(stop) if len(source) > stop else None
    return PagedResult(results=source[start:stop], next_page_token=next_page_token)

File Metadata

Mime Type
text/x-diff
Expires
Fri, Jul 4, 1:52 PM (4 d, 6 h ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3262774

Event Timeline