class Archive:
    """Facade over the storage object returned by ``server.get_storage()``.

    Every method forwards to a single ``StorageInterface`` call; where a
    method takes ``after``/``first`` they are passed on as the storage
    page token (or starting branch) and result limit.
    """

    def __init__(self) -> None:
        self.storage: StorageInterface = server.get_storage()

    def get_origin(self, url: str) -> Optional[Origin]:
        """Return the first item storage yields for ``url``."""
        return list(self.storage.origin_get(origins=[url]))[0]

    def get_origins(
        self, after: Optional[str] = None, first: int = 50
    ) -> PagedResult[Origin]:
        """Return one page of origins."""
        return self.storage.origin_list(page_token=after, limit=first)

    def get_origin_visits(
        self, origin_url: str, after: Optional[str] = None, first: int = 50
    ) -> PagedResult[OriginVisit]:
        """Return one page of visits of ``origin_url``."""
        return self.storage.origin_visit_get(
            origin=origin_url, page_token=after, limit=first
        )

    def get_origin_visit(self, origin_url: str, visit_id: int) -> Optional[OriginVisit]:
        """Return the visit ``visit_id`` of ``origin_url``."""
        return self.storage.origin_visit_get_by(origin=origin_url, visit=visit_id)

    def get_origin_latest_visit(
        self,
        origin_url: str,
        visit_type: Optional[str] = None,
        allowed_statuses: Optional[List[str]] = None,
        require_snapshot: bool = False,
    ) -> Optional[OriginVisit]:
        """Return the latest visit of ``origin_url`` matching the filters."""
        return self.storage.origin_visit_get_latest(
            origin=origin_url,
            type=visit_type,
            allowed_statuses=allowed_statuses,
            require_snapshot=require_snapshot,
        )

    def get_visit_status(
        self,
        origin_url: str,
        visit_id: int,
        after: Optional[str] = None,
        first: int = 50,
    ) -> PagedResult[OriginVisitStatus]:
        """Return one page of statuses of the given visit."""
        return self.storage.origin_visit_status_get(
            origin=origin_url, visit=visit_id, page_token=after, limit=first
        )

    def get_latest_visit_status(
        self,
        origin_url: str,
        visit_id: int,
        allowed_statuses: Optional[List[str]] = None,
        require_snapshot: bool = False,
    ) -> Optional[OriginVisitStatus]:
        """Return the latest status of the given visit matching the filters."""
        return self.storage.origin_visit_status_get_latest(
            origin_url=origin_url,
            visit=visit_id,
            allowed_statuses=allowed_statuses,
            require_snapshot=require_snapshot,
        )

    def get_origin_snapshots(self, origin_url: str) -> List[Sha1Git]:
        """Return the ids of all the snapshots of ``origin_url``."""
        return self.storage.origin_snapshot_get_all(origin_url=origin_url)

    def get_snapshot_branches(
        self,
        snapshot: Sha1Git,
        after: bytes = b"",
        first: int = 50,
        target_types: Optional[List[str]] = None,
        name_include: Optional[bytes] = None,
        name_exclude_prefix: Optional[bytes] = None,
    ) -> Optional[PartialBranches]:
        """Return up to ``first`` branches of ``snapshot``, starting at ``after``."""
        return self.storage.snapshot_get_branches(
            snapshot_id=snapshot,
            branches_from=after,
            branches_count=first,
            target_types=target_types,
            branch_name_include_substring=name_include,
            branch_name_exclude_prefix=name_exclude_prefix,
        )

    def get_revisions(self, revision_ids: List[Sha1Git]) -> List[Optional[Revision]]:
        """Return the revisions with the given ids."""
        return self.storage.revision_get(revision_ids=revision_ids)

    def get_revision_log(
        self, revision_ids: List[Sha1Git], first: int = 50
    ) -> Iterable[Optional[Dict[str, Any]]]:
        """Return the log of ``revision_ids``, limited to ``first`` entries."""
        return self.storage.revision_log(revisions=revision_ids, limit=first)

    def get_releases(self, release_ids: List[Sha1Git]) -> List[Optional[Release]]:
        """Return the releases with the given ids."""
        return self.storage.release_get(releases=release_ids)

    def get_directory_entry_by_path(
        self, directory_id: Sha1Git, path: str
    ) -> Optional[Dict[str, Any]]:
        """Return the entry found at ``path`` below ``directory_id``.

        ``path`` is stripped of leading/trailing separators, split into
        its components and each component is encoded to bytes before
        being handed to storage.
        """
        # The path designates an entry inside the archive, not a file on the
        # host, so it is split on "/" instead of the platform-dependent
        # os.path.sep (identical on POSIX hosts).
        # NOTE(review): assumes API clients always send "/"-separated paths.
        separator = "/"
        paths = [x.encode() for x in path.strip(separator).split(separator)]
        return self.storage.directory_entry_get_by_path(
            directory=directory_id, paths=paths
        )

    def get_directory_entries(
        self, directory_id: Sha1Git, after: Optional[bytes] = None, first: int = 50
    ) -> Optional[PagedResult[DirectoryEntry]]:
        """Return one page of entries of ``directory_id``."""
        return self.storage.directory_get_entries(
            directory_id=directory_id, limit=first, page_token=after
        )

    def is_object_available(self, object_id: bytes, object_type: ObjectType) -> bool:
        """Tell whether ``object_id`` is not reported missing by storage.

        Raises:
            KeyError: if ``object_type`` has no ``*_missing`` storage endpoint
                in the mapping below
        """
        mapping = {
            ObjectType.CONTENT: self.storage.content_missing_per_sha1_git,
            ObjectType.DIRECTORY: self.storage.directory_missing,
            ObjectType.RELEASE: self.storage.release_missing,
            ObjectType.REVISION: self.storage.revision_missing,
            ObjectType.SNAPSHOT: self.storage.snapshot_missing,
        }
        # The endpoints yield the ids that are missing: an empty result
        # means the object is available.
        return not list(mapping[object_type]([object_id]))

    def get_contents(self, checksums: Dict[str, Any]) -> List[Content]:
        """Return the contents matching the given checksums."""
        return self.storage.content_find(content=checksums)
@id_scalar.serializer
def serialize_id(value) -> str:
    """Serialize an ``ID`` scalar value as a hexadecimal string.

    Args:
        value: the raw identifier. ``str`` values are first encoded to bytes
            with ``str.encode()``; any other value must provide a ``hex()``
            method (e.g. ``bytes``).

    Returns:
        the hexadecimal representation of the value

    Raises:
        AttributeError: if the value is neither a string nor an object with
            a ``hex()`` method
    """
    # isinstance instead of an exact type comparison: instances of str
    # subclasses have no hex() method either and must be encoded as well.
    if isinstance(value, str):
        value = value.encode()
    return value.hex()
def load_and_check_config(config_path: Optional[str]) -> Dict[str, Any]:
    """Check the minimal configuration is set to run the api or raise an
    error explanation.

    Args:
        config_path: Path to the configuration file to load

    Raises:
        EnvironmentError: if no configuration path is given
        FileNotFoundError: if the configuration path does not exist
        KeyError: if the loaded configuration has no 'storage' entry

    Returns:
        configuration as a dict

    """
    if not config_path:
        raise EnvironmentError("Configuration file must be defined")

    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Configuration file {config_path} does not exist")

    cfg = config.read(config_path)
    if "storage" not in cfg:
        raise KeyError("Missing 'storage' configuration")

    return cfg


def make_app_from_configfile():
    """Build the application, loading the configuration from a file if needed.

    SWH_CONFIG_FILENAME environment variable defines the
    configuration path to load. The file is only read when the module-level
    ``graphql_cfg`` is still empty; the loaded configuration is then kept in
    that global.

    Raises:
        whatever ``load_and_check_config`` raises when the configuration has
        to be loaded and is missing or incomplete

    Returns:
        the GraphQL application wrapped in the CORS middleware
    """
    from ariadne.asgi import GraphQL
    from starlette.middleware.cors import CORSMiddleware

    from .app import schema
    from .errors import format_error

    global graphql_cfg

    if not graphql_cfg:
        config_path = os.environ.get("SWH_CONFIG_FILENAME")
        graphql_cfg = load_and_check_config(config_path)

    application = CORSMiddleware(
        GraphQL(
            schema,
            # Only 'storage' is checked when loading the configuration, so
            # 'debug' may be absent: fall back to False instead of failing
            # with a KeyError.
            debug=graphql_cfg.get("debug", False),
            error_formatter=format_error,
        ),
        # FIXME, restrict origins after deploying the JS client
        allow_origins=["*"],
        allow_methods=("GET", "POST", "OPTIONS"),
    )
    return application
@pytest.mark.parametrize("content", get_contents())
def test_get_content_with_hash(client, content):
    """Query ``contentByHash`` with all four checksums of a content and
    check the returned SWHID."""
    query_str = """
    query getContent($checksums: [ContentHash]!) {
      contentByHash(checksums: $checksums) {
        swhid
      }
    }
    """
    # One "<algorithm>:<hex digest>" item per checksum attribute, in
    # blake2s256, sha1, sha1_git, sha256 order.
    hash_names = ("blake2s256", "sha1", "sha1_git", "sha256")
    checksums = [f"{name}:{getattr(content, name).hex()}" for name in hash_names]

    response, _errors = utils.get_query_response(
        client, query_str, checksums=checksums
    )

    expected = {"swhid": str(content.swhid())}
    assert response["contentByHash"] == expected
def test_get_content_with_unknown_swhid(client):
    """Query a well-formed content SWHID built from a made-up hash and check
    the response through ``utils.assert_missing_object``."""
    unknown_sha1 = "1" * 40
    # The operation is named getContent, like the other content queries of
    # this module, since it targets the `content` field.
    query_str = """
    query getContent($swhid: SWHID!) {
      content(swhid: $swhid) {
        swhid
      }
    }
    """
    utils.assert_missing_object(
        client,
        query_str,
        obj_type="content",
        swhid=f"swh:1:cnt:{unknown_sha1}",
    )
# The tests below replace the module-level state of ``server`` through
# ``mocker.patch.object`` so that the original values are restored once each
# test is over, instead of leaking into the tests running afterwards.


def test_get_storage(mocker):
    """The storage is built from the 'storage' configuration when unset."""
    mocker.patch.object(server, "storage", None)
    mocker.patch.object(server, "graphql_cfg", {"storage": {"test": "test"}})
    mocker.patch("swh.graphql.server.get_swh_storage", return_value="dummy-storage")
    assert server.get_storage() == "dummy-storage"


def test_get_global_storage(mocker):
    """An already initialized storage is returned as is."""
    mocker.patch.object(server, "storage", "existing-storage")
    assert server.get_storage() == "existing-storage"


def test_get_search(mocker):
    """The search is built from the 'search' configuration when unset."""
    mocker.patch.object(server, "search", None)
    mocker.patch.object(server, "graphql_cfg", {"search": {"test": "test"}})
    mocker.patch("swh.graphql.server.get_swh_search", return_value="dummy-search")
    assert server.get_search() == "dummy-search"


def test_get_global_search(mocker):
    """An already initialized search is returned as is."""
    mocker.patch.object(server, "search", "existing-search")
    assert server.get_search() == "existing-search"


def test_load_and_check_config_no_config():
    with pytest.raises(EnvironmentError):
        server.load_and_check_config(config_path=None)


def test_load_and_check_config_missing_config_file():
    with pytest.raises(FileNotFoundError):
        server.load_and_check_config(config_path="invalid")


def test_load_and_check_config_missing_storage_config(mocker):
    mocker.patch("swh.core.config.read", return_value={"test": "test"})
    with pytest.raises(KeyError):
        server.load_and_check_config(config_path="/tmp")


def test_load_and_check_config(mocker):
    mocker.patch("swh.core.config.read", return_value={"storage": {"test": "test"}})
    cfg = server.load_and_check_config(config_path="/tmp")
    assert cfg == {"storage": {"test": "test"}}


def test_make_app_from_configfile_with_config(mocker):
    """With a configuration already set, the app is the CORS-wrapped object."""
    mocker.patch.object(
        server, "graphql_cfg", {"storage": {"test": "test"}, "debug": True}
    )
    mocker.patch("starlette.middleware.cors.CORSMiddleware", return_value="dummy-app")
    assert server.make_app_from_configfile() == "dummy-app"


def test_make_app_from_configfile_missing_config(mocker, monkeypatch):
    """Without configuration nor SWH_CONFIG_FILENAME, building the app fails."""
    mocker.patch.object(server, "graphql_cfg", None)
    # Make sure the variable is really unset, whatever the environment the
    # test suite runs in.
    monkeypatch.delenv("SWH_CONFIG_FILENAME", raising=False)
    with pytest.raises(EnvironmentError):
        # trying to load config from a non existing env var
        server.make_app_from_configfile()