diff --git a/config/staging.yml b/config/staging.yml index 03e7c28..2e7ee1f 100644 --- a/config/staging.yml +++ b/config/staging.yml @@ -1,11 +1,11 @@ storage: cls: remote url: http://webapp.internal.staging.swh.network:5002 search: cls: remote url: http://webapp.internal.staging.swh.network:5010 -debug: yes +debug: no server-type: wsgi diff --git a/swh/graphql/errors/__init__.py b/swh/graphql/errors/__init__.py index 7bc04e9..43fa532 100644 --- a/swh/graphql/errors/__init__.py +++ b/swh/graphql/errors/__init__.py @@ -1,9 +1,14 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from .errors import ObjectNotFoundError, PaginationError +from .errors import InvalidInputError, ObjectNotFoundError, PaginationError from .handlers import format_error -__all__ = ["ObjectNotFoundError", "PaginationError", "format_error"] +__all__ = [ + "ObjectNotFoundError", + "PaginationError", + "InvalidInputError", + "format_error", +] diff --git a/swh/graphql/errors/errors.py b/swh/graphql/errors/errors.py index 37ecd35..478e353 100644 --- a/swh/graphql/errors/errors.py +++ b/swh/graphql/errors/errors.py @@ -1,18 +1,31 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information class ObjectNotFoundError(Exception): """ """ + msg: str = "Object error" + + def __init__(self, message, errors=None): + super().__init__(f"{self.msg}: {message}") + class PaginationError(Exception): """ """ msg: str = "Pagination error" def __init__(self, message, errors=None): - # FIXME, log this error + super().__init__(f"{self.msg}: {message}") + + +class InvalidInputError(Exception): + """ """ + + msg: str = "Input error" + + def __init__(self, message, errors=None): super().__init__(f"{self.msg}: {message}") diff --git a/swh/graphql/errors/handlers.py b/swh/graphql/errors/handlers.py index c61e593..6c75c74 100644 --- a/swh/graphql/errors/handlers.py +++ b/swh/graphql/errors/handlers.py @@ -1,13 +1,29 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from ariadne import format_error as original_format_error +from graphql import GraphQLError +import sentry_sdk -def format_error(error) -> dict: +from .errors import InvalidInputError, ObjectNotFoundError, PaginationError + + +def format_error(error: GraphQLError, debug: bool = False): """ Response error formatting """ + original_format = original_format_error(error, debug) + if debug: + # If debug is enabled, reuse Ariadne's formatting logic with stack trace + return original_format + + expected_errors = [ObjectNotFoundError, PaginationError, InvalidInputError] formatted = error.formatted - formatted["message"] = "Unknown error" + formatted["message"] = error.message + if type(error.original_error) not in expected_errors: + # a crash, send to sentry + sentry_sdk.capture_exception(error) + # FIXME log the original_format to kibana (with stack trace) return formatted diff --git a/swh/graphql/resolvers/scalars.py b/swh/graphql/resolvers/scalars.py index f832055..0bf757e 100644 --- a/swh/graphql/resolvers/scalars.py +++ b/swh/graphql/resolvers/scalars.py @@ -1,60 +1,58 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime from ariadne import ScalarType +from swh.graphql.errors import InvalidInputError from swh.graphql.utils import utils from swh.model import hashutil from swh.model.model import TimestampWithTimezone from swh.model.swhids import CoreSWHID datetime_scalar = ScalarType("DateTime") swhid_scalar = ScalarType("SWHID") id_scalar = ScalarType("ID") content_hash_scalar = ScalarType("ContentHash") @id_scalar.serializer def serialize_id(value): if type(value) is bytes: return value.hex() return value @datetime_scalar.serializer def serialize_datetime(value): # FIXME, handle error and return None if type(value) == TimestampWithTimezone: value = value.to_datetime() if type(value) == datetime: return utils.get_formatted_date(value) return None @swhid_scalar.value_parser def validate_swhid(value): return CoreSWHID.from_string(value) @swhid_scalar.serializer def serialize_swhid(value): return str(value) @content_hash_scalar.value_parser def validate_content_hash(value): try: hash_type, hash_string = value.split(":") hash_value = hashutil.hash_to_bytes(hash_string) except ValueError as e: - # FIXME, log this error - raise AttributeError("Invalid content checksum", e) - except Exception as e: - # FIXME, log this error - raise AttributeError("Invalid content checksum", e) - # FIXME, add validation for the hash_type + raise InvalidInputError("Invalid content checksum", e) + if hash_type not in hashutil.ALGORITHMS: + raise InvalidInputError("Invalid hash algorithm") return hash_type, hash_value diff --git a/swh/graphql/server.py b/swh/graphql/server.py index d5b40fb..5df83d9 100644 --- a/swh/graphql/server.py +++ b/swh/graphql/server.py @@ -1,88 +1,91 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os from typing import Any, Dict, Optional from swh.core import config from swh.search import get_search as get_swh_search from swh.storage import get_storage as get_swh_storage graphql_cfg = None storage = None search = None def get_storage(): global storage if not storage: storage = get_swh_storage(**graphql_cfg["storage"]) return storage def get_search(): global search if not search: search = get_swh_search(**graphql_cfg["search"]) return search def load_and_check_config(config_path: Optional[str]) -> Dict[str, Any]: """Check the minimal configuration is set to run the api or raise an error explanation. Args: config_path: Path to the configuration file to load Raises: Error if the setup is not as expected Returns: configuration as a dict """ if not config_path: raise EnvironmentError("Configuration file must be defined") if not os.path.exists(config_path): raise FileNotFoundError(f"Configuration file {config_path} does not exist") cfg = config.read(config_path) if "storage" not in cfg: raise KeyError("Missing 'storage' configuration") return cfg def make_app_from_configfile(): """Loading the configuration from a configuration file. SWH_CONFIG_FILENAME environment variable defines the configuration path to load. """ from .app import schema + from .errors.handlers import format_error global graphql_cfg if not graphql_cfg: config_path = os.environ.get("SWH_CONFIG_FILENAME") graphql_cfg = load_and_check_config(config_path) server_type = graphql_cfg.get("server-type") if server_type == "asgi": from ariadne.asgi import GraphQL from starlette.middleware.cors import CORSMiddleware # Enable cors in the asgi version application = CORSMiddleware( - GraphQL(schema), + GraphQL(schema, debug=graphql_cfg["debug"], error_formatter=format_error), allow_origins=["*"], allow_methods=("GET", "POST", "OPTIONS"), ) else: from ariadne.wsgi import GraphQL - application = GraphQL(schema) + application = GraphQL( + schema, debug=graphql_cfg["debug"], error_formatter=format_error + ) return application diff --git a/swh/graphql/tests/functional/test_content.py b/swh/graphql/tests/functional/test_content.py index 901b15e..1de014d 100644 --- a/swh/graphql/tests/functional/test_content.py +++ b/swh/graphql/tests/functional/test_content.py @@ -1,149 +1,163 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest from . import utils from ..data import get_contents @pytest.mark.parametrize("content", get_contents()) def test_get_content_with_swhid(client, content): query_str = """ { content(swhid: "%s") { swhid checksum { blake2s256 sha1 sha1_git sha256 } length status data { url } fileType { encoding } language { lang } license { licenses } } } """ data, _ = utils.get_query_response(client, query_str % content.swhid()) archive_url = "https://archive.softwareheritage.org/api/1/" response = { "swhid": str(content.swhid()), "checksum": { "blake2s256": content.blake2s256.hex(), "sha1": content.sha1.hex(), "sha1_git": content.sha1_git.hex(), "sha256": content.sha256.hex(), }, "length": content.length, "status": content.status, "data": { "url": f"{archive_url}content/sha1:{content.sha1.hex()}/raw/", }, "fileType": None, "language": None, "license": None, } assert data["content"] == response @pytest.mark.parametrize("content", get_contents()) def test_get_content_with_hash(client, content): query_str = """ { contentByHash(checksums: ["blake2s256:%s", "sha1:%s", "sha1_git:%s", "sha256:%s"]) { swhid } } """ data, _ = utils.get_query_response( client, query_str % ( content.blake2s256.hex(), content.sha1.hex(), content.sha1_git.hex(), content.sha256.hex(), ), ) assert data["contentByHash"] == {"swhid": str(content.swhid())} def test_get_content_with_invalid_swhid(client): query_str = """ { content(swhid: "swh:1:cnt:invalid") { swhid } } """ errors = utils.get_error_response(client, query_str) # API will throw an error in case of an invalid SWHID assert len(errors) == 1 assert "Invalid SWHID: invalid syntax" in errors[0]["message"] def test_get_content_with_invalid_hashes(client): content = get_contents()[0] query_str = """ { contentByHash(checksums: ["blake2s256:%s", "sha1:%s", "sha1_git:%s", "sha256:%s"]) { swhid } } """ errors = utils.get_error_response( client, query_str % ( "invalid", # Only one hash is invalid content.sha1.hex(), content.sha1_git.hex(), content.sha256.hex(), ), ) # API will throw an error in case of an invalid content hash assert len(errors) == 1 - assert "Invalid content checksum" in errors[0]["message"] + assert "Input error: Invalid content checksum" in errors[0]["message"] + + +def test_get_content_with_invalid_hash_algorithm(client): + content = get_contents()[0] + query_str = """ + { + contentByHash(checksums: ["test:%s"]) { + swhid + } + } + """ + errors = utils.get_error_response(client, query_str % content.sha1.hex()) + assert len(errors) == 1 + assert "Input error: Invalid hash algorithm" in errors[0]["message"] def test_get_content_as_target(client): # SWHID of a test dir with a file entry directory_swhid = "swh:1:dir:87b339104f7dc2a8163dec988445e3987995545f" query_str = """ { directory(swhid: "%s") { swhid entries(first: 2) { nodes { type target { ...on Content { swhid length } } } } } } """ data, _ = utils.get_query_response(client, query_str % directory_swhid) content_obj = data["directory"]["entries"]["nodes"][1]["target"] assert content_obj == { "length": 4, "swhid": "swh:1:cnt:86bc6b377e9d25f9d26777a4a28d08e63e7c5779", } diff --git a/swh/graphql/tests/functional/utils.py b/swh/graphql/tests/functional/utils.py index 43207dd..340500c 100644 --- a/swh/graphql/tests/functional/utils.py +++ b/swh/graphql/tests/functional/utils.py @@ -1,36 +1,36 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import json from typing import Dict, Tuple def get_response(client, query_str: str): return client.post("/", json={"query": query_str}) def get_query_response(client, query_str: str) -> Tuple[Dict, Dict]: response = get_response(client, query_str) assert response.status_code == 200, response.data result = json.loads(response.data) return result.get("data"), result.get("errors") def assert_missing_object(client, query_str: str, obj_type: str) -> None: data, errors = get_query_response(client, query_str) assert data[obj_type] is None assert len(errors) == 1 - assert errors[0]["message"] == "Requested object is not available" + assert errors[0]["message"] == "Object error: Requested object is not available" def get_error_response(client, query_str: str, error_code: int = 400) -> Dict: response = get_response(client, query_str) assert response.status_code == error_code return json.loads(response.data)["errors"] def get_query_params_from_args(**args) -> str: # build a GraphQL query parameters string from arguments return ",".join([f"{key}: {val}" for (key, val) in args.items()]) diff --git a/swh/graphql/tests/unit/errors/test_errors.py b/swh/graphql/tests/unit/errors/test_errors.py new file mode 100644 index 0000000..9d12300 --- /dev/null +++ b/swh/graphql/tests/unit/errors/test_errors.py @@ -0,0 +1,53 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from graphql import GraphQLError +import pytest +import sentry_sdk + +from swh.graphql import errors + + +def test_errors(): + err = errors.ObjectNotFoundError("test error") + assert str(err) == "Object error: test error" + + err = errors.PaginationError("test error") + assert str(err) == "Pagination error: test error" + + err = errors.InvalidInputError("test error") + assert str(err) == "Input error: test error" + + +def test_format_error_with_debug(): + err = GraphQLError("test error") + response = errors.format_error(err, debug=True) + assert "extensions" in response + + +def test_format_error_without_debug(): + err = GraphQLError("test error") + response = errors.format_error(err) + assert "extensions" not in response + + +def test_format_error_sent_to_sentry(mocker): + mocked_senty_call = mocker.patch.object(sentry_sdk, "capture_exception") + err = GraphQLError("test error") + err.original_error = NameError("test error") # not an expected error + errors.format_error(err) + mocked_senty_call.assert_called_once_with(err) + + +@pytest.mark.parametrize( + "error", + [errors.ObjectNotFoundError, errors.PaginationError, errors.InvalidInputError], +) +def test_format_error_skip_sentry(mocker, error): + mocked_senty_call = mocker.patch.object(sentry_sdk, "capture_exception") + err = GraphQLError("test error") + err.original_error = error("test error") + errors.format_error(err) + mocked_senty_call.assert_not_called