diff --git a/config/dev.yml b/config/dev.yml --- a/config/dev.yml +++ b/config/dev.yml @@ -9,3 +9,6 @@ debug: yes server-type: asgi + +max_query_cost: + anonymous: 500 diff --git a/config/staging.yml b/config/staging.yml --- a/config/staging.yml +++ b/config/staging.yml @@ -9,3 +9,6 @@ debug: yes server-type: wsgi + +max_query_cost: + anonymous: 500 diff --git a/swh/graphql/app.py b/swh/graphql/app.py --- a/swh/graphql/app.py +++ b/swh/graphql/app.py @@ -8,9 +8,12 @@ from pathlib import Path from ariadne import gql, load_schema_from_path, make_executable_schema +from ariadne.validation import cost_validator from .resolvers import resolvers, scalars +# from .server import graphql_cfg + type_defs = gql( # pkg_resources.resource_string("swh.graphql", "schem/schema.graphql").decode() load_schema_from_path( @@ -41,3 +44,10 @@ scalars.swhid_scalar, scalars.content_hash_scalar, ) + + +def validation_rules(context_value, document, data): + from .server import graphql_cfg + + max_query_cost = graphql_cfg["max_query_cost"]["anonymous"] + return [cost_validator(maximum_cost=max_query_cost)] diff --git a/swh/graphql/schema/schema.graphql b/swh/graphql/schema/schema.graphql --- a/swh/graphql/schema/schema.graphql +++ b/swh/graphql/schema/schema.graphql @@ -1,3 +1,5 @@ +directive @cost(complexity: Int, multipliers: [String!], useMultipliers: Boolean) on FIELD | FIELD_DEFINITION + """ SoftWare Heritage persistent Identifier """ @@ -132,12 +134,12 @@ Returns the page after this cursor """ after: String - ): VisitConnection! + ): VisitConnection! @cost(complexity: 1, multipliers: ["first"]) """ Latest visit object for the origin """ - latestVisit: Visit + latestVisit: Visit @cost(complexity: 1) """ Connection to all the snapshots for the origin @@ -152,7 +154,7 @@ Returns the page after this cursor """ after: String - ): SnapshotConnection + ): SnapshotConnection @cost(complexity: 2, multipliers: ["first"]) # This costs more because of local (graphql level) pagination } """ @@ -232,12 +234,12 @@ Returns the page after this cursor """ after: String - ): VisitStatusConnection + ): VisitStatusConnection @cost(complexity: 3) # here first is optional, hence adding a higher value for cost """ Latest status object for the Visit """ - latestStatus: VisitStatus + latestStatus: VisitStatus @cost(complexity: 1) } """ @@ -297,7 +299,7 @@ """ Snapshot object """ - snapshot: Snapshot + snapshot: Snapshot @cost(complexity: 1) """ Type of the origin visited. Eg: git/hg/svn/tar/deb @@ -382,7 +384,7 @@ Filter by branch name """ nameInclude: String - ): BranchConnection + ): BranchConnection @cost(complexity: 2, multipliers: ["first"]) # This costs more because of local (graphql level) pagination } """ @@ -479,7 +481,7 @@ """ Branch target object """ - target: BranchTarget + target: BranchTarget @cost(complexity: 1) } """ @@ -562,7 +564,7 @@ """ The unique directory object that revision points to """ - directory: Directory + directory: Directory @cost(complexity: 1) """ Connection to all the parents of the revision @@ -577,7 +579,7 @@ Returns the page after this cursor """ after: String - ): RevisionConnection + ): RevisionConnection @cost(complexity: 3) # here first is not mandatory, hence adding a higher value for cost """ Connection to all the revisions heading to this one @@ -593,7 +595,7 @@ Returns the page after the cursor """ after: String - ): RevisionConnection + ): RevisionConnection @cost(complexity: 2, multipliers: ["first"]) # This costs more because of local (graphql level) pagination } """ @@ -652,7 +654,7 @@ """ Release target object """ - target: ReleaseTarget + target: ReleaseTarget @cost(complexity: 1) } """ @@ -726,7 +728,7 @@ """ Directory entry target object """ - target: DirectoryEntryTarget + target: DirectoryEntryTarget @cost(complexity: 1) } """ @@ -750,13 +752,13 @@ """ Returns the first _n_ elements from the list """ - first: Int + first: Int! """ Returns the page after this cursor """ after: String - ): DirectoryEntryConnection + ): DirectoryEntryConnection @cost(complexity: 2, multipliers: ["first"]) # pagination is local, hence adding a higher value for cost } """ @@ -918,7 +920,7 @@ """ Result target object """ - target: SearchResultTarget + target: SearchResultTarget @cost(complexity: 1) } """ @@ -933,7 +935,7 @@ URL of the Origin """ url: String! - ): Origin + ): Origin @cost(complexity: 1) """ Get a Connection to all the origins @@ -953,7 +955,7 @@ Filter origins with a URL pattern """ urlPattern: String - ): OriginConnection + ): OriginConnection @cost(complexity: 1, multipliers: ["first"]) """ Get the visit object with an origin URL and a visit id @@ -968,7 +970,7 @@ Visit id to get """ visitId: Int! - ): Visit + ): Visit @cost(complexity: 1) """ Get the snapshot with a SWHID @@ -978,7 +980,7 @@ SWHID of the snapshot object """ swhid: SWHID! - ): Snapshot + ): Snapshot @cost(complexity: 1) """ Get the revision with a SWHID @@ -988,7 +990,7 @@ SWHID of the revision object """ swhid: SWHID! - ): Revision + ): Revision @cost(complexity: 1) """ Get the release with a SWHID @@ -998,7 +1000,7 @@ SWHID of the release object """ swhid: SWHID! - ): Release + ): Release @cost(complexity: 1) """ Get the directory with a SWHID @@ -1008,7 +1010,7 @@ SWHID of the directory object """ swhid: SWHID! - ): Directory + ): Directory @cost(complexity: 1) """ Get the content with a SWHID @@ -1018,7 +1020,7 @@ SWHID of the content object """ swhid: SWHID! - ): Content + ): Content @cost(complexity: 1) """ Get the content by one or more hashes @@ -1029,7 +1031,7 @@ List of hashType:hashValue strings """ checksums: [ContentHash]! - ): Content + ): Content @cost(complexity: 1) """ Resolve the given SWHID to an object @@ -1039,7 +1041,7 @@ SWHID to look for """ swhid: SWHID! - ): SearchResultConnection! + ): SearchResultConnection! @cost(complexity: 1) """ Search in SWH @@ -1059,5 +1061,5 @@ Returns the page after the cursor """ after: String - ): SearchResultConnection! + ): SearchResultConnection! @cost(complexity: 1, multipliers: ["first"]) } diff --git a/swh/graphql/server.py b/swh/graphql/server.py --- a/swh/graphql/server.py +++ b/swh/graphql/server.py @@ -62,7 +62,7 @@ SWH_CONFIG_FILENAME environment variable defines the configuration path to load. """ - from .app import schema + from .app import schema, validation_rules global graphql_cfg @@ -77,12 +77,12 @@ # Enable cors in the asgi version application = CORSMiddleware( - GraphQL(schema), + GraphQL(schema, validation_rules=validation_rules), allow_origins=["*"], allow_methods=("GET", "POST", "OPTIONS"), ) else: from ariadne.wsgi import GraphQL - application = GraphQL(schema) + application = GraphQL(schema, validation_rules=validation_rules) return application diff --git a/swh/graphql/tests/conftest.py b/swh/graphql/tests/conftest.py --- a/swh/graphql/tests/conftest.py +++ b/swh/graphql/tests/conftest.py @@ -9,7 +9,7 @@ import pytest from swh.graphql import server as app_server -from swh.graphql.app import schema +from swh.graphql.app import schema, validation_rules from swh.search import get_search as get_swh_search from swh.storage import get_storage as get_swh_storage @@ -38,7 +38,12 @@ @pytest.fixture(scope="session") -def test_app(storage, search): +def config(): + app_server.graphql_cfg = {"max_query_cost": {"anonymous": 100}} + + +@pytest.fixture(scope="session") +def test_app(storage, search, config): app = Flask(__name__) @app.route("/", methods=["POST"]) @@ -46,7 +51,11 @@ # GraphQL queries are always sent as POST data = request.get_json() success, result = graphql_sync( - schema, data, context_value=request, debug=app.debug + schema, + data, + validation_rules=validation_rules, + context_value=request, + debug=app.debug, ) status_code = 200 if success else 400 return jsonify(result), status_code diff --git a/swh/graphql/tests/functional/test_query_cost.py b/swh/graphql/tests/functional/test_query_cost.py new file mode 100644 --- /dev/null +++ b/swh/graphql/tests/functional/test_query_cost.py @@ -0,0 +1,86 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from . import utils +from ..data import get_snapshots + + +def test_valid_query(client): + query_str = """ + { + origins(first: 2) { + nodes { + url + } + } + } + """ + response, _ = utils.get_query_response(client, query_str) + assert len(response["origins"]["nodes"]) == 2 + + +def test_query_cost_origin(client): + query_str = """ + { + origins(first: 10) { + nodes { + url + latestVisit { + date + } + visits(first: 5) { + nodes { + date + } + } + snapshots(first: 5) { + nodes { + swhid + } + } + } + } + } + """ + # Total cost here is 170 + # 10 (origin) + 10 (latestVisit) + 10*5 (visits) + 10 * 5*2 (snapshots) = 170 + errors = utils.get_error_response(client, query_str) + assert ( + "The query exceeds the maximum cost of 100. Actual cost is 170" + in errors[0]["message"] + ) + + +def test_query_cost_snapshots(client): + query_str = """ + { + snapshot(swhid: "%s") { + branches(first: 50) { + nodes { + target { + ...on Revision { + swhid + } + ...on Directory { + swhid + entries(first: 3) { + nodes { + type + } + } + } + } + } + } + } + } + """ + # Total cost here is 157 + # 1 (snapshot) + 2 *50 (branches) + 50 * 1 (revision or Directory) + 3 * 2 = 157 + errors = utils.get_error_response(client, query_str % get_snapshots()[0].swhid()) + assert ( + "The query exceeds the maximum cost of 100. Actual cost is 157" + in errors[0]["message"] + ) diff --git a/swh/graphql/utils/__init__.py b/swh/graphql/utils/__init__.py new file mode 100644