diff --git a/requirements.txt b/requirements.txt --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -ariadne +ariadne>=0.15 diff --git a/swh/graphql/app.py b/swh/graphql/app.py --- a/swh/graphql/app.py +++ b/swh/graphql/app.py @@ -8,6 +8,7 @@ from pathlib import Path from ariadne import gql, load_schema_from_path, make_executable_schema +from ariadne.validation import cost_directive, cost_validator from .resolvers import resolvers, scalars @@ -19,22 +20,26 @@ ) schema = make_executable_schema( - type_defs, - resolvers.query, - resolvers.origin, - resolvers.visit, - resolvers.visit_status, - resolvers.snapshot, - resolvers.snapshot_branch, - resolvers.revision, - resolvers.release, - resolvers.directory, - resolvers.directory_entry, - resolvers.branch_target, - resolvers.release_target, - resolvers.directory_entry_target, - resolvers.binary_string, - scalars.id_scalar, - scalars.datetime_scalar, - scalars.swhid_scalar, + [type_defs, cost_directive], + [ + resolvers.query, + resolvers.origin, + resolvers.visit, + resolvers.visit_status, + resolvers.snapshot, + resolvers.snapshot_branch, + resolvers.revision, + resolvers.release, + resolvers.directory, + resolvers.directory_entry, + resolvers.branch_target, + resolvers.release_target, + resolvers.directory_entry_target, + resolvers.binary_string, + scalars.id_scalar, + scalars.datetime_scalar, + scalars.swhid_scalar, + ], ) + +validation_rules = [cost_validator(maximum_cost=200)] diff --git a/swh/graphql/schema/schema.graphql b/swh/graphql/schema/schema.graphql --- a/swh/graphql/schema/schema.graphql +++ b/swh/graphql/schema/schema.graphql @@ -127,12 +127,12 @@ Returns the page after this cursor """ after: String - ): VisitConnection! + ): VisitConnection! @cost(complexity: 2) """ Latest visit object for the origin """ - latestVisit: Visit + latestVisit: Visit @cost(complexity: 1) """ Connection to all the snapshots for the origin @@ -147,7 +147,7 @@ Returns the page after this cursor """ after: String - ): SnapshotConnection + ): SnapshotConnection @cost(complexity: 2) } """ @@ -227,12 +227,12 @@ Returns the page after this cursor """ after: String - ): VisitStatusConnection + ): VisitStatusConnection @cost(complexity: 2) """ Latest status object for the Visit """ - latestStatus: VisitStatus + latestStatus: VisitStatus @cost(complexity: 1) } """ @@ -292,7 +292,7 @@ """ Snapshot object """ - snapshot: Snapshot + snapshot: Snapshot @cost(complexity: 1) """ Type of the origin visited. Eg: git/hg/svn/tar/deb @@ -377,7 +377,7 @@ Filter by branch name """ nameInclude: String - ): BranchConnection + ): BranchConnection @cost(complexity: 2) } """ @@ -474,7 +474,7 @@ """ Branch target object """ - target: BranchTarget + target: BranchTarget @cost(complexity: 1) } """ @@ -557,7 +557,7 @@ """ The unique directory object that revision points to """ - directory: Directory + directory: Directory @cost(complexity: 1) """ Connection to all the parents of the revision @@ -572,7 +572,7 @@ Returns the page after this cursor """ after: String - ): RevisionConnection + ): RevisionConnection @cost(complexity: 2) """ Connection to all the revisions heading to this one @@ -588,7 +588,7 @@ Returns the page after the cursor """ after: String - ): RevisionConnection + ): RevisionConnection @cost(complexity: 2) } """ @@ -647,7 +647,7 @@ """ Release target object """ - target: ReleaseTarget + target: ReleaseTarget @cost(complexity: 1) } """ @@ -721,7 +721,7 @@ """ Directory entry target object """ - target: DirectoryEntryTarget + target: DirectoryEntryTarget @cost(complexity: 1) } """ @@ -751,7 +751,7 @@ Returns the page after this cursor """ after: String - ): DirectoryEntryConnection + ): DirectoryEntryConnection @cost(complexity: 2) } """ @@ -817,7 +817,7 @@ URL of the Origin """ url: String! - ): Origin + ): Origin @cost(complexity: 1) """ Get a Connection to all the origins @@ -837,7 +837,7 @@ Filter origins with a URL pattern """ urlPattern: String - ): OriginConnection + ): OriginConnection @cost(complexity: 2) """ Get the visit object with an origin URL and a visit id @@ -852,7 +852,7 @@ Visit id to get """ visitId: Int! - ): Visit + ): Visit @cost(complexity: 1) """ Get the snapshot with a SWHID @@ -862,7 +862,7 @@ SWHID of the snapshot object """ swhid: SWHID! - ): Snapshot + ): Snapshot @cost(complexity: 1) """ Get the revision with a SWHID @@ -872,7 +872,7 @@ SWHID of the revision object """ swhid: SWHID! - ): Revision + ): Revision @cost(complexity: 1) """ Get the release with a SWHID @@ -882,7 +882,7 @@ SWHID of the release object """ swhid: SWHID! - ): Release + ): Release @cost(complexity: 1) """ Get the directory with a SWHID @@ -892,7 +892,7 @@ SWHID of the directory object """ swhid: SWHID! - ): Directory + ): Directory @cost(complexity: 1) """ Get the content with a SWHID @@ -902,5 +902,5 @@ SWHID of the content object """ swhid: SWHID! - ): Content + ): Content @cost(complexity: 1) } diff --git a/swh/graphql/server.py b/swh/graphql/server.py --- a/swh/graphql/server.py +++ b/swh/graphql/server.py @@ -55,7 +55,7 @@ configuration path to load. """ - from .app import schema + from .app import schema, validation_rules global graphql_cfg @@ -67,9 +67,9 @@ if server_type == "asgi": from ariadne.asgi import GraphQL - application = GraphQL(schema) + application = GraphQL(schema, validation_rules=validation_rules) else: from ariadne.wsgi import GraphQL - application = GraphQL(schema) + application = GraphQL(schema, validation_rules=validation_rules) return application diff --git a/swh/graphql/tests/conftest.py b/swh/graphql/tests/conftest.py --- a/swh/graphql/tests/conftest.py +++ b/swh/graphql/tests/conftest.py @@ -9,7 +9,7 @@ import pytest from swh.graphql import server as app_server -from swh.graphql.app import schema +from swh.graphql.app import schema, validation_rules from swh.storage import get_storage as get_swhstorage from .data import populate_dummy_data @@ -33,8 +33,13 @@ def graphql_server(): # GraphQL queries are always sent as POST data = request.get_json() + success, result = graphql_sync( - schema, data, context_value=request, debug=app.debug + schema, + data, + context_value=request, + debug=app.debug, + validation_rules=validation_rules, ) status_code = 200 if success else 400 return jsonify(result), status_code diff --git a/swh/graphql/tests/functional/test_rate_limiter.py b/swh/graphql/tests/functional/test_rate_limiter.py new file mode 100644 --- /dev/null +++ b/swh/graphql/tests/functional/test_rate_limiter.py @@ -0,0 +1,42 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from itertools import permutations + +from ..data import get_origins +from .utils import get_error_response, get_query_response + + +def get_origin_query(count: int) -> str: + url = get_origins()[0].url + # query to get one origin with the URL + origin_query = '%s: origin(url: "%s") { url }' + # String permutations will have 720 elements (6!) + perms = ["".join(p) for p in permutations("abcdef")] + # Create {count} instances of the same query using alias + # eg { abcdef: origin(..) {..} bacdef: origin(..) ..} + # total cost for this query will be count itself (each node costs 1 unit) + return "{%s}" % (" ".join([origin_query % (i, url) for i in perms[:count]])) + + +# Using Origin object to run functional tests for rate limiting +def test_get_origins_within_limit(client): + # Create a query requesting 100 nodes + query_str = get_origin_query(100) + data, error = get_query_response(client, query_str) + assert error is None + assert len(data) == 100 # query cost is 100 in this case + + +def test_get_origins_over_limit(client): + # Create a query requesting more than 200 nodes + query_str = get_origin_query(201) + # This will throw a 400 error + errors = get_error_response(client, query_str, error_code=400) + assert len(errors) == 1 + assert ( + errors[0]["message"] + == "The query exceeds the maximum cost of 200. Actual cost is 201" + ) diff --git a/swh/graphql/tests/functional/test_snapshot_node.py b/swh/graphql/tests/functional/test_snapshot_node.py --- a/swh/graphql/tests/functional/test_snapshot_node.py +++ b/swh/graphql/tests/functional/test_snapshot_node.py @@ -53,5 +53,7 @@ } """ errors = get_error_response(client, query_str) - assert len(errors) == 1 - assert "Invalid SWHID: invalid syntax" in errors[0]["message"] + assert ( + len(errors) == 2 + ) # FIXME, a bug in ariadne validator is causing this duplicate error + assert "Invalid SWHID: invalid syntax" in errors[1]["message"]