Page MenuHomeSoftware Heritage

D8077.diff
No OneTemporary

D8077.diff

diff --git a/config/dev.yml b/config/dev.yml
--- a/config/dev.yml
+++ b/config/dev.yml
@@ -9,3 +9,7 @@
debug: yes
introspection: yes
+
+max_query_cost:
+ anonymous: 500
+
diff --git a/config/staging.yml b/config/staging.yml
--- a/config/staging.yml
+++ b/config/staging.yml
@@ -9,3 +9,7 @@
debug: no
introspection: yes
+
+max_query_cost:
+ anonymous: 500
+
diff --git a/swh/graphql/app.py b/swh/graphql/app.py
--- a/swh/graphql/app.py
+++ b/swh/graphql/app.py
@@ -8,6 +8,7 @@
from pathlib import Path
from ariadne import gql, load_schema_from_path, make_executable_schema
+from ariadne.validation import cost_validator
from .resolvers import resolvers, scalars
@@ -41,3 +42,16 @@
scalars.datetime_scalar,
scalars.swhid_scalar,
)
+
+
+def validation_rules(context_value, document, data):
+ from .server import graphql_cfg
+
+ # add logic to set max_query_cost depending on user type
+ max_query_cost = graphql_cfg["max_query_cost"]["anonymous"]
+ if max_query_cost:
+ return [
+ cost_validator(maximum_cost=max_query_cost, variables=data.get("variables"))
+ ]
+ # no limit is applied when max_query_cost is set to 0 or None
+ return None
diff --git a/swh/graphql/schema/schema.graphql b/swh/graphql/schema/schema.graphql
--- a/swh/graphql/schema/schema.graphql
+++ b/swh/graphql/schema/schema.graphql
@@ -1,3 +1,5 @@
+directive @cost(complexity: Int, multipliers: [String!], useMultipliers: Boolean) on FIELD | FIELD_DEFINITION
+
"""
SoftWare Heritage persistent Identifier
"""
@@ -129,7 +131,7 @@
Returns the page after this cursor
"""
after: String
- ): VisitConnection
+ ): VisitConnection! @cost(complexity: 1, multipliers: ["first"])
"""
Latest visit object for the origin
@@ -149,7 +151,7 @@
If True, the latest visit with a snapshot will be returned
"""
requireSnapshot: Boolean
- ): Visit
+ ): Visit @cost(complexity: 1)
"""
Connection to all the snapshots for the origin
@@ -164,7 +166,7 @@
Returns the page after this cursor
"""
after: String
- ): SnapshotConnection
+ ): SnapshotConnection @cost(complexity: 2, multipliers: ["first"]) # This costs more because of local (graphql level) pagination
}
"""
@@ -259,7 +261,7 @@
Returns the page after this cursor
"""
after: String
- ): VisitStatusConnection
+ ): VisitStatusConnection @cost(complexity: 3) # here first is optional, hence adding a higher value for cost
"""
Latest status object for the Visit
@@ -274,7 +276,7 @@
Filter by the availability of a snapshot in the status
"""
requireSnapshot: Boolean
- ): VisitStatus
+ ): VisitStatus @cost(complexity: 1)
}
"""
@@ -337,7 +339,7 @@
"""
Snapshot object
"""
- snapshot: Snapshot
+ snapshot: Snapshot @cost(complexity: 1)
"""
Type of the origin visited. Eg: git/hg/svn/tar/deb
@@ -427,7 +429,7 @@
Do not return branches whose name contains the given prefix
"""
nameExcludePrefix: String
- ): BranchConnection
+ ): BranchConnection @cost(complexity: 2, multipliers: ["first"]) # This costs more because of local (graphql level) pagination
}
"""
@@ -527,7 +529,7 @@
"""
Branch target object
"""
- target: BranchTarget
+ target: BranchTarget @cost(complexity: 1)
}
"""
@@ -632,7 +634,7 @@
"""
The unique directory object that revision points to
"""
- directory: Directory
+ directory: Directory @cost(complexity: 1)
"""
Connection to all the parents of the revision
@@ -647,7 +649,7 @@
Returns the page after this cursor
"""
after: String
- ): RevisionConnection
+ ): RevisionConnection @cost(complexity: 3) # here first is not mandatory, hence adding a higher value for cost
"""
Connection to all the revisions heading to this one
@@ -663,7 +665,7 @@
Returns the page after the cursor
"""
after: String
- ): RevisionConnection
+ ): RevisionConnection @cost(complexity: 2, multipliers: ["first"]) # This costs more because of local (graphql level) pagination
}
"""
@@ -723,7 +725,7 @@
"""
Release target object
"""
- target: ReleaseTarget
+ target: ReleaseTarget @cost(complexity: 1)
}
"""
@@ -797,7 +799,7 @@
"""
Directory entry target object
"""
- target: DirectoryEntryTarget
+ target: DirectoryEntryTarget @cost(complexity: 1)
}
"""
@@ -832,7 +834,7 @@
Filter by entry name
"""
nameInclude: String
- ): DirectoryEntryConnection
+ ): DirectoryEntryConnection @cost(complexity: 2, multipliers: ["first"]) # pagination is local, hence adding a higher value for cost
}
"""
@@ -997,7 +999,7 @@
"""
Result target object
"""
- target: SearchResultTarget
+ target: SearchResultTarget @cost(complexity: 1)
}
"""
@@ -1012,7 +1014,7 @@
URL of the Origin
"""
url: String!
- ): Origin
+ ): Origin @cost(complexity: 1)
"""
Get a Connection to all the origins
@@ -1032,7 +1034,7 @@
Filter origins with a URL pattern
"""
urlPattern: String
- ): OriginConnection
+ ): OriginConnection @cost(complexity: 1, multipliers: ["first"])
"""
Get the visit object with an origin URL and a visit id
@@ -1047,7 +1049,7 @@
Visit id to get
"""
visitId: Int!
- ): Visit
+ ): Visit @cost(complexity: 1)
"""
Get the snapshot with a SWHID
@@ -1057,7 +1059,7 @@
SWHID of the snapshot object
"""
swhid: SWHID!
- ): Snapshot
+ ): Snapshot @cost(complexity: 1)
"""
Get the revision with a SWHID
@@ -1067,7 +1069,7 @@
SWHID of the revision object
"""
swhid: SWHID!
- ): Revision
+ ): Revision @cost(complexity: 1)
"""
Get the release with a SWHID
@@ -1077,7 +1079,7 @@
SWHID of the release object
"""
swhid: SWHID!
- ): Release
+ ): Release @cost(complexity: 1)
"""
Get the directory with a SWHID
@@ -1087,7 +1089,7 @@
SWHID of the directory object
"""
swhid: SWHID!
- ): Directory
+ ): Directory @cost(complexity: 1)
"""
Get a directory entry with directory SWHID and a path
@@ -1102,7 +1104,7 @@
Relative path to the requested object
"""
path: String!
- ): DirectoryEntry
+ ): DirectoryEntry @cost(complexity: 2) # This costs more because path can be any level deep
"""
Get a list of contents for the given SWHID
@@ -1112,7 +1114,7 @@
SWHID to look for
"""
swhid: SWHID!
- ): [Content]
+ ): [Content] @cost(complexity: 1)
"""
Get contents with hashes
@@ -1126,7 +1128,7 @@
sha1_git: String
blake2s256: String
- ): [Content]
+ ): [Content] @cost(complexity: 1)
"""
Get a content that match all the given hashes.
@@ -1142,7 +1144,7 @@
sha1_git: String!
blake2s256: String!
- ): Content
+ ): Content @cost(complexity: 1)
"""
Resolve the given SWHID to an object
@@ -1152,7 +1154,7 @@
SWHID to look for
"""
swhid: SWHID!
- ): [SearchResult]
+ ): [SearchResult] @cost(complexity: 1)
"""
Search in SWH
@@ -1172,5 +1174,5 @@
Returns the page after the cursor
"""
after: String
- ): SearchResultConnection
+ ): SearchResultConnection! @cost(complexity: 1, multipliers: ["first"])
}
diff --git a/swh/graphql/server.py b/swh/graphql/server.py
--- a/swh/graphql/server.py
+++ b/swh/graphql/server.py
@@ -67,7 +67,7 @@
from ariadne.asgi import GraphQL
from starlette.middleware.cors import CORSMiddleware
- from .app import schema
+ from .app import schema, validation_rules
from .errors import format_error
global graphql_cfg
@@ -80,8 +80,9 @@
GraphQL(
schema,
debug=graphql_cfg["debug"],
- error_formatter=format_error,
introspection=graphql_cfg["introspection"],
+ validation_rules=validation_rules,
+ error_formatter=format_error,
),
# FIXME, restrict origins after deploying the JS client
allow_origins=["*"],
diff --git a/swh/graphql/tests/conftest.py b/swh/graphql/tests/conftest.py
--- a/swh/graphql/tests/conftest.py
+++ b/swh/graphql/tests/conftest.py
@@ -9,7 +9,7 @@
import pytest
from swh.graphql import server as app_server
-from swh.graphql.app import schema
+from swh.graphql.app import schema, validation_rules
from swh.graphql.errors import format_error
from swh.search import get_search as get_swh_search
from swh.storage import get_storage as get_swh_storage
@@ -38,6 +38,16 @@
return search
+@pytest.fixture(autouse=True)
+def max_query_cost_config():
+ app_server.graphql_cfg = {"max_query_cost": {"anonymous": 100}}
+
+
+@pytest.fixture
+def max_query_cost_none_config():
+ app_server.graphql_cfg = {"max_query_cost": {"anonymous": 0}}
+
+
@pytest.fixture(scope="session")
def test_app(storage, search):
app = Flask(__name__)
@@ -51,6 +61,7 @@
data,
context_value=request,
debug=app.debug,
+ validation_rules=validation_rules,
error_formatter=format_error,
)
status_code = 200 if success else 400
diff --git a/swh/graphql/tests/functional/test_directory_entry.py b/swh/graphql/tests/functional/test_directory_entry.py
--- a/swh/graphql/tests/functional/test_directory_entry.py
+++ b/swh/graphql/tests/functional/test_directory_entry.py
@@ -107,7 +107,7 @@
query getDirectory($swhid: SWHID!) {
directory(swhid: $swhid) {
swhid
- entries {
+ entries(first: 10) {
nodes {
targetType
name {
diff --git a/swh/graphql/tests/functional/test_pagination.py b/swh/graphql/tests/functional/test_pagination.py
--- a/swh/graphql/tests/functional/test_pagination.py
+++ b/swh/graphql/tests/functional/test_pagination.py
@@ -50,7 +50,12 @@
def test_too_big_first_arg(client):
- data, errors = get_origin_nodes(client, first=1001) # max page size is 1000
+ from swh.graphql import server as app_server
+
+ # set the query cost limit to a higher value for this test
+ app_server.graphql_cfg = {"max_query_cost": {"anonymous": 2000}}
+
+ data, errors = get_origin_nodes(client, 1001) # max page size is 1000
assert data["origins"] is None
assert (len(errors)) == 2
assert (
diff --git a/swh/graphql/tests/functional/test_query_cost.py b/swh/graphql/tests/functional/test_query_cost.py
new file mode 100644
--- /dev/null
+++ b/swh/graphql/tests/functional/test_query_cost.py
@@ -0,0 +1,127 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from . import utils
+from ..data import get_snapshots
+
+
+def test_valid_query(client):
+ """A small query (origins(first: 2)) stays under the cost limit of 100 and succeeds."""
+ query_str = """
+ query getOrigins {
+ origins(first: 2) {
+ nodes {
+ url
+ }
+ }
+ }
+ """
+ response, _ = utils.get_query_response(client, query_str)
+ assert len(response["origins"]["nodes"]) == 2
+
+
+def test_query_cost_simple(client):
+ """origins(first: 1000) costs 1000, above the limit of 100, so the request is rejected with HTTP 400."""
+ query_str = """
+ query getOrigins {
+ origins(first: 1000) {
+ nodes {
+ url
+ }
+ }
+ }
+ """
+ errors = utils.get_error_response(client, query_str, response_code=400)
+ assert (
+ "The query exceeds the maximum cost of 100. Actual cost is 1000"
+ in errors[0]["message"]
+ )
+
+
+def test_query_cost_with_no_limit(client, max_query_cost_none_config):
+ """With max_query_cost set to 0 no limit applies, so the costly query succeeds."""
+ query_str = """
+ query getOrigins {
+ origins(first: 1000) {
+ nodes {
+ url
+ }
+ }
+ }
+ """
+ response, _ = utils.get_query_response(client, query_str)
+ # presumably the test data holds exactly 2 origins, hence 2 nodes for first: 1000
+ assert len(response["origins"]["nodes"]) == 2
+
+
+def test_query_cost_origin(client):
+ """Nested connections multiply their cost by the parents' `first` values."""
+ query_str = """
+ query getOrigins {
+ origins(first: 10) {
+ nodes {
+ url
+ latestVisit {
+ date
+ }
+ visits(first: 5) {
+ nodes {
+ date
+ statuses {
+ nodes {
+ date
+ }
+ }
+ }
+ }
+ snapshots(first: 5) {
+ nodes {
+ swhid
+ }
+ }
+ }
+ }
+ }
+ """
+ # Total cost here is 320:
+ # 10 (origins) + 10 (latestVisit) + 10 * 5 (visits) + 10 * 5 * 3 (statuses) +
+ # 10 * 5 * 2 (snapshots) = 320
+ errors = utils.get_error_response(client, query_str, response_code=400)
+ assert (
+ "The query exceeds the maximum cost of 100. Actual cost is 320"
+ in errors[0]["message"]
+ )
+
+
+def test_query_cost_snapshots(client):
+ """Cost of a snapshot query whose branch targets use inline fragments."""
+ query_str = """
+ query getSnapshot($swhid: SWHID!) {
+ snapshot(swhid: $swhid) {
+ branches(first: 50) {
+ nodes {
+ target {
+ ...on Revision {
+ swhid
+ }
+ ...on Directory {
+ swhid
+ entries(first: 3) {
+ nodes {
+ targetType
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ """
+ # Total cost here is 157:
+ # 1 (snapshot) + 2 * 50 (branches) + 50 * 1 (revision or Directory) + 3 * 2 = 157
+ # the parent multiplier (50) is not applied inside the `...on Directory` inline
+ # fragment (presumably ariadne's cost validator does not propagate multipliers
+ # into inline fragments; verify), ie: directory entry connection cost is 3 * 2
+ # and not 50 * 3 * 2
+ errors = utils.get_error_response(
+ client, query_str, swhid=str(get_snapshots()[0].swhid()), response_code=400
+ )
+ assert (
+ "The query exceeds the maximum cost of 100. Actual cost is 157"
+ in errors[0]["message"]
+ )
diff --git a/swh/graphql/tests/functional/test_release_node.py b/swh/graphql/tests/functional/test_release_node.py
--- a/swh/graphql/tests/functional/test_release_node.py
+++ b/swh/graphql/tests/functional/test_release_node.py
@@ -96,7 +96,7 @@
errors = utils.get_error_response(client, query_str, swhid="swh:1:rel:invalid")
# API will throw an error in case of an invalid SWHID
assert len(errors) == 1
- assert "Expected type 'SWHID'. Input error: Invalid SWHID" in errors[0]["message"]
+ assert "Input error: Invalid SWHID" in errors[0]["message"]
@pytest.mark.parametrize("release_with_target", get_releases_with_target())
diff --git a/swh/graphql/tests/functional/utils.py b/swh/graphql/tests/functional/utils.py
--- a/swh/graphql/tests/functional/utils.py
+++ b/swh/graphql/tests/functional/utils.py
@@ -9,10 +9,12 @@
from ariadne import gql
-def get_query_response(client, query_str: str, **kwargs) -> Tuple[Dict, Dict]:
+def get_query_response(
+ client, query_str: str, response_code: int = 200, **kwargs
+) -> Tuple[Dict, Dict]:
+ """POST ``query_str`` with ``kwargs`` as GraphQL variables; return (data, errors).
+
+ Asserts that the HTTP status equals ``response_code`` (200 by default).
+ NOTE(review): ``response_code`` is a named parameter, so a GraphQL variable
+ called ``response_code`` can no longer be passed through ``kwargs``.
+ NOTE(review): both values come from ``dict.get`` and may be None; callers
+ index ``errors`` like a list, so the ``Tuple[Dict, Dict]`` hint looks inaccurate.
+ """
 query = gql(query_str)
 response = client.post("/", json={"query": query, "variables": kwargs})
- assert response.status_code == 200, response.data
+ assert response.status_code == response_code, response.data
 result = json.loads(response.data)
 return result.get("data"), result.get("errors")
@@ -25,7 +27,11 @@
assert errors[0]["path"] == [obj_type]
-def get_error_response(client, query_str: str, **kwargs) -> Dict:
- data, errors = get_query_response(client, query_str, **kwargs)
+def get_error_response(
+ client, query_str: str, response_code: int = 200, **kwargs
+) -> Dict:
+ """Run the query, assert it returned at least one error, and return the errors.
+
+ ``response_code`` is the expected HTTP status (400 when validation, e.g. the
+ query cost check, rejects the query; 200 for resolver-level errors).
+ NOTE(review): the value returned is the GraphQL ``errors`` list, not a Dict.
+ """
+ data, errors = get_query_response(
+ client, query_str, response_code=response_code, **kwargs
+ )
 assert len(errors) > 0
 return errors
diff --git a/swh/graphql/utils/__init__.py b/swh/graphql/utils/__init__.py
new file mode 100644

File Metadata

Mime Type
text/plain
Expires
Nov 5 2024, 8:20 AM (11 w, 18 h ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3221215

Event Timeline