Page Menu | Home | Software Heritage

D8447.id30838.diff
No One | Temporary

D8447.id30838.diff

diff --git a/swh/graph/http_client.py b/swh/graph/http_client.py
--- a/swh/graph/http_client.py
+++ b/swh/graph/http_client.py
@@ -52,7 +52,13 @@
return self.get("stats")
def leaves(
- self, src, edges="*", direction="forward", max_edges=0, return_types="*"
+ self,
+ src,
+ edges="*",
+ direction="forward",
+ max_edges=0,
+ return_types="*",
+ max_matching_nodes=0,
):
return self.get_lines(
"leaves/{}".format(src),
@@ -61,6 +67,7 @@
"direction": direction,
"max_edges": max_edges,
"return_types": return_types,
+ "max_matching_nodes": max_matching_nodes,
},
)
@@ -137,10 +144,14 @@
},
)
- def count_leaves(self, src, edges="*", direction="forward"):
+ def count_leaves(self, src, edges="*", direction="forward", max_matching_nodes=0):
return self.get(
"leaves/count/{}".format(src),
- params={"edges": edges, "direction": direction},
+ params={
+ "edges": edges,
+ "direction": direction,
+ "max_matching_nodes": max_matching_nodes,
+ },
)
def count_neighbors(self, src, edges="*", direction="forward"):
diff --git a/swh/graph/http_naive_client.py b/swh/graph/http_naive_client.py
--- a/swh/graph/http_naive_client.py
+++ b/swh/graph/http_naive_client.py
@@ -1,10 +1,11 @@
-# Copyright (C) 2021 The Software Heritage developers
+# Copyright (C) 2021-2022 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import functools
import inspect
+import itertools
import re
import statistics
from typing import (
@@ -150,9 +151,10 @@
direction: str = "forward",
max_edges: int = 0,
return_types: str = "*",
+ max_matching_nodes: int = 0,
) -> Iterator[str]:
# TODO: max_edges
- yield from filter_node_types(
+ leaves = filter_node_types(
return_types,
[
node
@@ -161,6 +163,11 @@
],
)
+ if max_matching_nodes > 0:
+ leaves = itertools.islice(leaves, max_matching_nodes)
+
+ return leaves
+
@check_arguments
def neighbors(
self,
@@ -250,9 +257,19 @@
@check_arguments
def count_leaves(
- self, src: str, edges: str = "*", direction: str = "forward"
+ self,
+ src: str,
+ edges: str = "*",
+ direction: str = "forward",
+ max_matching_nodes: int = 0,
) -> int:
- return len(list(self.leaves(src, edges, direction)))
+ return len(
+ list(
+ self.leaves(
+ src, edges, direction, max_matching_nodes=max_matching_nodes
+ )
+ )
+ )
@check_arguments
def count_neighbors(
diff --git a/swh/graph/http_rpc_server.py b/swh/graph/http_rpc_server.py
--- a/swh/graph/http_rpc_server.py
+++ b/swh/graph/http_rpc_server.py
@@ -144,13 +144,15 @@
else:
return s
- def get_limit(self):
- """Validate HTTP query parameter `limit`, i.e., number of results"""
- s = self.request.query.get("limit", "0")
+ def get_max_matching_nodes(self):
+ """Validate HTTP query parameter `max_matching_nodes`, i.e., number of results"""
+ s = self.request.query.get("max_matching_nodes", "0")
try:
return int(s)
except ValueError:
- raise aiohttp.web.HTTPBadRequest(text=f"invalid limit value: {s}")
+ raise aiohttp.web.HTTPBadRequest(
+ text=f"invalid max_matching_nodes value: {s}"
+ )
def get_max_edges(self):
"""Validate HTTP query parameter 'max_edges', i.e.,
@@ -249,6 +251,7 @@
direction=self.get_direction(),
return_nodes=NodeFilter(types=self.get_return_types()),
mask=FieldMask(paths=["swhid"]),
+ max_matching_nodes=self.get_max_matching_nodes(),
)
if self.get_max_edges():
self.traversal_request.max_edges = self.get_max_edges()
@@ -307,6 +310,7 @@
direction=self.get_direction(),
return_nodes=NodeFilter(types=self.get_return_types()),
mask=FieldMask(paths=["swhid"]),
+ max_matching_nodes=self.get_max_matching_nodes(),
)
if self.get_max_edges():
self.traversal_request.max_edges = self.get_max_edges()
diff --git a/swh/graph/tests/test_http_client.py b/swh/graph/tests/test_http_client.py
--- a/swh/graph/tests/test_http_client.py
+++ b/swh/graph/tests/test_http_client.py
@@ -43,6 +43,25 @@
assert set(actual) == set(expected)
+@pytest.mark.parametrize("max_matching_nodes", [0, 1, 2, 3, 4, 5, 10, 1 << 31])
+def test_leaves_with_limit(graph_client, max_matching_nodes):
+ actual = list(
+ graph_client.leaves(TEST_ORIGIN_ID, max_matching_nodes=max_matching_nodes)
+ )
+ expected = [
+ "swh:1:cnt:0000000000000000000000000000000000000001",
+ "swh:1:cnt:0000000000000000000000000000000000000004",
+ "swh:1:cnt:0000000000000000000000000000000000000005",
+ "swh:1:cnt:0000000000000000000000000000000000000007",
+ ]
+
+ if max_matching_nodes == 0:
+ assert set(actual) == set(expected)
+ else:
+ assert set(actual) <= set(expected)
+ assert len(actual) == min(4, max_matching_nodes)
+
+
def test_neighbors(graph_client):
actual = list(
graph_client.neighbors(
@@ -326,6 +345,17 @@
assert actual == 3
+@pytest.mark.parametrize("max_matching_nodes", [0, 1, 2, 3, 4, 5, 10, 1 << 31])
+def test_count_with_limit(graph_client, max_matching_nodes):
+ actual = graph_client.count_leaves(
+ TEST_ORIGIN_ID, max_matching_nodes=max_matching_nodes
+ )
+ if max_matching_nodes == 0:
+ assert actual == 4
+ else:
+ assert actual == min(4, max_matching_nodes)
+
+
def test_param_validation(graph_client):
with raises(GraphArgumentException) as exc_info: # SWHID not found
list(graph_client.leaves("swh:1:rel:00ffffffff000000000000000000000000000010"))

File Metadata

Mime Type
text/plain
Expires
Wed, Dec 18, 4:36 AM (22 h, 32 m ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3222850

Event Timeline