diff --git a/config/dev.yml b/config/dev.yml --- a/config/dev.yml +++ b/config/dev.yml @@ -9,3 +9,7 @@ debug: yes introspection: yes + +max_query_cost: + anonymous: 500 + diff --git a/config/staging.yml b/config/staging.yml --- a/config/staging.yml +++ b/config/staging.yml @@ -9,3 +9,7 @@ debug: no introspection: yes + +max_query_cost: + anonymous: 500 + diff --git a/swh/graphql/app.py b/swh/graphql/app.py --- a/swh/graphql/app.py +++ b/swh/graphql/app.py @@ -8,6 +8,7 @@ from pathlib import Path from ariadne import gql, load_schema_from_path, make_executable_schema +from ariadne.validation import cost_validator from .resolvers import resolvers, scalars @@ -41,3 +42,15 @@ scalars.datetime_scalar, scalars.swhid_scalar, ) + + +def validation_rules(context_value, document, data): + from .server import graphql_cfg + + # TODO: choose the limit by user type; unset/falsy max_query_cost disables it + max_query_cost = (graphql_cfg.get("max_query_cost") or {}).get("anonymous") + if max_query_cost: + return [ + cost_validator(maximum_cost=max_query_cost, variables=data.get("variables")) + ] + return None diff --git a/swh/graphql/schema/schema.graphql b/swh/graphql/schema/schema.graphql --- a/swh/graphql/schema/schema.graphql +++ b/swh/graphql/schema/schema.graphql @@ -1,3 +1,5 @@ +directive @cost(complexity: Int, multipliers: [String!], useMultipliers: Boolean) on FIELD | FIELD_DEFINITION + """ SoftWare Heritage persistent Identifier """ @@ -129,7 +131,7 @@ Returns the page after this cursor """ after: String - ): VisitConnection + ): VisitConnection!
@cost(complexity: 1, multipliers: ["first"]) """ Latest visit object for the origin @@ -149,7 +151,7 @@ If True, the latest visit with a snapshot will be returned """ requireSnapshot: Boolean - ): Visit + ): Visit @cost(complexity: 1) """ Connection to all the snapshots for the origin @@ -164,7 +166,7 @@ Returns the page after this cursor """ after: String - ): SnapshotConnection + ): SnapshotConnection @cost(complexity: 2, multipliers: ["first"]) # This costs more because of local (graphql level) pagination } """ @@ -259,7 +261,7 @@ Returns the page after this cursor """ after: String - ): VisitStatusConnection + ): VisitStatusConnection @cost(complexity: 3) # here first is optional, hence adding a higher value for cost """ Latest status object for the Visit @@ -274,7 +276,7 @@ Filter by the availability of a snapshot in the status """ requireSnapshot: Boolean - ): VisitStatus + ): VisitStatus @cost(complexity: 1) } """ @@ -337,7 +339,7 @@ """ Snapshot object """ - snapshot: Snapshot + snapshot: Snapshot @cost(complexity: 1) """ Type of the origin visited. 
Eg: git/hg/svn/tar/deb @@ -427,7 +429,7 @@ Do not return branches whose name contains the given prefix """ nameExcludePrefix: String - ): BranchConnection + ): BranchConnection @cost(complexity: 2, multipliers: ["first"]) # This costs more because of local (graphql level) pagination } """ @@ -527,7 +529,7 @@ """ Branch target object """ - target: BranchTarget + target: BranchTarget @cost(complexity: 1) } """ @@ -632,7 +634,7 @@ """ The unique directory object that revision points to """ - directory: Directory + directory: Directory @cost(complexity: 1) """ Connection to all the parents of the revision @@ -647,7 +649,7 @@ Returns the page after this cursor """ after: String - ): RevisionConnection + ): RevisionConnection @cost(complexity: 3) # here first is not mandatory, hence adding a higher value for cost """ Connection to all the revisions heading to this one @@ -663,7 +665,7 @@ Returns the page after the cursor """ after: String - ): RevisionConnection + ): RevisionConnection @cost(complexity: 2, multipliers: ["first"]) # This costs more because of local (graphql level) pagination } """ @@ -723,7 +725,7 @@ """ Release target object """ - target: ReleaseTarget + target: ReleaseTarget @cost(complexity: 1) } """ @@ -797,7 +799,7 @@ """ Directory entry target object """ - target: DirectoryEntryTarget + target: DirectoryEntryTarget @cost(complexity: 1) } """ @@ -832,7 +834,7 @@ Filter by entry name """ nameInclude: String - ): DirectoryEntryConnection + ): DirectoryEntryConnection @cost(complexity: 2, multipliers: ["first"]) # pagination is local, hence adding a higher value for cost } """ @@ -997,7 +999,7 @@ """ Result target object """ - target: SearchResultTarget + target: SearchResultTarget @cost(complexity: 1) } """ @@ -1012,7 +1014,7 @@ URL of the Origin """ url: String! 
- ): Origin + ): Origin @cost(complexity: 1) """ Get a Connection to all the origins @@ -1032,7 +1034,7 @@ Filter origins with a URL pattern """ urlPattern: String - ): OriginConnection + ): OriginConnection @cost(complexity: 1, multipliers: ["first"]) """ Get the visit object with an origin URL and a visit id @@ -1047,7 +1049,7 @@ Visit id to get """ visitId: Int! - ): Visit + ): Visit @cost(complexity: 1) """ Get the snapshot with a SWHID @@ -1057,7 +1059,7 @@ SWHID of the snapshot object """ swhid: SWHID! - ): Snapshot + ): Snapshot @cost(complexity: 1) """ Get the revision with a SWHID @@ -1067,7 +1069,7 @@ SWHID of the revision object """ swhid: SWHID! - ): Revision + ): Revision @cost(complexity: 1) """ Get the release with a SWHID @@ -1077,7 +1079,7 @@ SWHID of the release object """ swhid: SWHID! - ): Release + ): Release @cost(complexity: 1) """ Get the directory with a SWHID @@ -1087,7 +1089,7 @@ SWHID of the directory object """ swhid: SWHID! - ): Directory + ): Directory @cost(complexity: 1) """ Get a directory entry with directory SWHID and a path @@ -1102,7 +1104,7 @@ Relative path to the requested object """ path: String! - ): DirectoryEntry + ): DirectoryEntry @cost(complexity: 2) # This costs more because path can be any level deep """ Get a list of contents for the given SWHID @@ -1112,7 +1114,7 @@ SWHID to look for """ swhid: SWHID! - ): [Content] + ): [Content] @cost(complexity: 1) """ Get contents with hashes @@ -1126,7 +1128,7 @@ sha1_git: String blake2s256: String - ): [Content] + ): [Content] @cost(complexity: 1) """ Get a content that match all the given hashes. @@ -1142,7 +1144,7 @@ sha1_git: String! blake2s256: String! - ): Content + ): Content @cost(complexity: 1) """ Resolve the given SWHID to an object @@ -1152,7 +1154,7 @@ SWHID to look for """ swhid: SWHID! 
- ): [SearchResult] + ): [SearchResult] @cost(complexity: 1) """ Search in SWH @@ -1172,5 +1174,5 @@ Returns the page after the cursor """ after: String - ): SearchResultConnection + ): SearchResultConnection! @cost(complexity: 1, multipliers: ["first"]) } diff --git a/swh/graphql/server.py b/swh/graphql/server.py --- a/swh/graphql/server.py +++ b/swh/graphql/server.py @@ -67,7 +67,7 @@ from ariadne.asgi import GraphQL from starlette.middleware.cors import CORSMiddleware - from .app import schema + from .app import schema, validation_rules from .errors import format_error global graphql_cfg @@ -80,8 +80,9 @@ GraphQL( schema, debug=graphql_cfg["debug"], - error_formatter=format_error, introspection=graphql_cfg["introspection"], + validation_rules=validation_rules, + error_formatter=format_error, ), # FIXME, restrict origins after deploying the JS client allow_origins=["*"], diff --git a/swh/graphql/tests/conftest.py b/swh/graphql/tests/conftest.py --- a/swh/graphql/tests/conftest.py +++ b/swh/graphql/tests/conftest.py @@ -9,7 +9,7 @@ import pytest from swh.graphql import server as app_server -from swh.graphql.app import schema +from swh.graphql.app import schema, validation_rules from swh.graphql.errors import format_error from swh.search import get_search as get_swh_search from swh.storage import get_storage as get_swh_storage @@ -38,6 +38,11 @@ return search +@pytest.fixture(autouse=True) +def config(): + app_server.graphql_cfg = {"max_query_cost": {"anonymous": 100}} + + @pytest.fixture(scope="session") def test_app(storage, search): app = Flask(__name__) @@ -51,6 +56,7 @@ data, context_value=request, debug=app.debug, + validation_rules=validation_rules, error_formatter=format_error, ) status_code = 200 if success else 400 diff --git a/swh/graphql/tests/functional/test_directory_entry.py b/swh/graphql/tests/functional/test_directory_entry.py --- a/swh/graphql/tests/functional/test_directory_entry.py +++ b/swh/graphql/tests/functional/test_directory_entry.py @@ 
-107,7 +107,7 @@ query getDirectory($swhid: SWHID!) { directory(swhid: $swhid) { swhid - entries { + entries(first: 10) { nodes { targetType name { diff --git a/swh/graphql/tests/functional/test_pagination.py b/swh/graphql/tests/functional/test_pagination.py --- a/swh/graphql/tests/functional/test_pagination.py +++ b/swh/graphql/tests/functional/test_pagination.py @@ -50,7 +50,12 @@ def test_too_big_first_arg(client): - data, errors = get_origin_nodes(client, first=1001) # max page size is 1000 + from swh.graphql import server as app_server + + # set the query cost limit to a higher value for this test + app_server.graphql_cfg = {"max_query_cost": {"anonymous": 2000}} + + data, errors = get_origin_nodes(client, 1001) # max page size is 1000 assert data["origins"] is None assert (len(errors)) == 2 assert ( diff --git a/swh/graphql/tests/functional/test_query_cost.py b/swh/graphql/tests/functional/test_query_cost.py new file mode 100644 --- /dev/null +++ b/swh/graphql/tests/functional/test_query_cost.py @@ -0,0 +1,113 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from . import utils +from ..data import get_snapshots + + +def test_valid_query(client): + query_str = """ + query getOrigins { + origins(first: 2) { + nodes { + url + } + } + } + """ + response, _ = utils.get_query_response(client, query_str) + assert len(response["origins"]["nodes"]) == 2 + + +def test_query_cost_simple(client): + query_str = """ + query getOrigins { + origins(first: 1000) { + nodes { + url + } + } + } + """ + errors = utils.get_error_response(client, query_str, response_code=400) + assert ( + "The query exceeds the maximum cost of 100. 
Actual cost is 1000" + in errors[0]["message"] + ) + + +def test_query_cost_origin(client): + query_str = """ + query getOrigins { + origins(first: 10) { + nodes { + url + latestVisit { + date + } + visits(first: 5) { + nodes { + date + statuses { + nodes { + date + } + } + } + } + snapshots(first: 5) { + nodes { + swhid + } + } + } + } + } + """ + # Total cost here is 320 + # 10 (origin) + 10 (latestVisit) + 10*5 (visits) + 10 * 5 * 3 (status) + + # 10 * 5*2 (snapshots) = 320 + errors = utils.get_error_response(client, query_str, response_code=400) + assert ( + "The query exceeds the maximum cost of 100. Actual cost is 320" + in errors[0]["message"] + ) + + +def test_query_cost_snapshots(client): + query_str = """ + query getSnapshot($swhid: SWHID!) { + snapshot(swhid: $swhid) { + branches(first: 50) { + nodes { + target { + ...on Revision { + swhid + } + ...on Directory { + swhid + entries(first: 3) { + nodes { + targetType + } + } + } + } + } + } + } + } + """ + # Total cost here is 157 + # 1 (snapshot) + 2 *50 (branches) + 50 * 1 (revision or Directory) + 3 * 2 = 157 + # parent multiplier is not applied when schema introspection is used + # ie: directory entry connection cost is 3 * 2 and not 50 * 3 * 2 + errors = utils.get_error_response( + client, query_str, swhid=str(get_snapshots()[0].swhid()), response_code=400 + ) + assert ( + "The query exceeds the maximum cost of 100. Actual cost is 157" + in errors[0]["message"] + ) diff --git a/swh/graphql/tests/functional/test_release_node.py b/swh/graphql/tests/functional/test_release_node.py --- a/swh/graphql/tests/functional/test_release_node.py +++ b/swh/graphql/tests/functional/test_release_node.py @@ -96,7 +96,7 @@ errors = utils.get_error_response(client, query_str, swhid="swh:1:rel:invalid") # API will throw an error in case of an invalid SWHID assert len(errors) == 1 - assert "Expected type 'SWHID'.
Input error: Invalid SWHID" in errors[0]["message"] + assert "Input error: Invalid SWHID" in errors[0]["message"] @pytest.mark.parametrize("release_with_target", get_releases_with_target()) diff --git a/swh/graphql/tests/functional/utils.py b/swh/graphql/tests/functional/utils.py --- a/swh/graphql/tests/functional/utils.py +++ b/swh/graphql/tests/functional/utils.py @@ -9,10 +9,12 @@ from ariadne import gql -def get_query_response(client, query_str: str, **kwargs) -> Tuple[Dict, Dict]: +def get_query_response( + client, query_str: str, response_code: int = 200, **kwargs +) -> Tuple[Dict, Dict]: query = gql(query_str) response = client.post("/", json={"query": query, "variables": kwargs}) - assert response.status_code == 200, response.data + assert response.status_code == response_code, response.data result = json.loads(response.data) return result.get("data"), result.get("errors") @@ -25,7 +27,11 @@ assert errors[0]["path"] == [obj_type] -def get_error_response(client, query_str: str, **kwargs) -> Dict: - data, errors = get_query_response(client, query_str, **kwargs) +def get_error_response( + client, query_str: str, response_code: int = 200, **kwargs +) -> Dict: + data, errors = get_query_response( + client, query_str, response_code=response_code, **kwargs + ) assert len(errors) > 0 return errors diff --git a/swh/graphql/utils/__init__.py b/swh/graphql/utils/__init__.py new file mode 100644