diff --git a/docs/api.rst b/docs/api.rst --- a/docs/api.rst +++ b/docs/api.rst @@ -83,6 +83,8 @@ default to 0 (not restricted) :query string return_types: only return the nodes matching this type; default to ``"*"`` + :query integer max_matching_nodes: how many results to return before stopping; + default to 0 (not restricted) :statuscode 200: success :statuscode 400: invalid query string provided @@ -207,6 +209,8 @@ default to 0 (not restricted) :query string return_types: only return the nodes matching this type; default to ``"*"`` + :query integer max_matching_nodes: how many nodes to return/visit before stopping; + default to 0 (not restricted) :statuscode 200: success :statuscode 400: invalid query string provided diff --git a/swh/graph/cli.py b/swh/graph/cli.py --- a/swh/graph/cli.py +++ b/swh/graph/cli.py @@ -1,9 +1,11 @@ -# Copyright (C) 2019-2020 The Software Heritage developers +# Copyright (C) 2019-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import logging from pathlib import Path +import shlex from typing import TYPE_CHECKING, Any, Dict, Set, Tuple # WARNING: do not import unnecessary things here to keep cli startup time under @@ -16,6 +18,8 @@ if TYPE_CHECKING: from swh.graph.webgraph import CompressionStep # noqa +logger = logging.getLogger(__name__) + class StepOption(click.ParamType): """click type for specifying a compression step on the CLI @@ -171,13 +175,17 @@ config = ctx.obj["config"] config.setdefault("graph", {}) config["graph"]["path"] = graph + + logger.debug("Building gRPC server command line") cmd, port = build_grpc_server_cmdline(**config["graph"]) java_bin = cmd[0] if java_home is not None: java_bin = str(Path(java_home) / "bin" / java_bin) - print(f"Starting the GRPC server on 0.0.0.0:{port}") + # XXX: shlex.join() is in 3.8
+ cmdline = " ".join(shlex.quote(x) for x in [java_bin, *cmd[1:]]) + logger.info("Starting gRPC server on 0.0.0.0:%s, command line: %s", port, cmdline) os.execvp(java_bin, cmd) diff --git a/swh/graph/config.py b/swh/graph/config.py --- a/swh/graph/config.py +++ b/swh/graph/config.py @@ -1,4 +1,4 @@ -# Copyright (C) 2019 The Software Heritage developers +# Copyright (C) 2019-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -9,6 +9,8 @@ import psutil +logger = logging.getLogger(__name__) + def find_graph_jar(): """find swh-graph.jar, containing the Java part of swh-graph @@ -17,6 +19,7 @@ deployments who fecthed the JAR from pypi) """ + logger.debug("Looking for swh-graph JAR") swh_graph_root = Path(__file__).parents[2] try_paths = [ swh_graph_root / "java/target/", @@ -24,13 +27,14 @@ Path(sys.prefix) / "local/share/swh-graph/", ] for path in try_paths: + logger.debug("Looking for swh-graph JAR in %s", path) glob = list(path.glob("swh-graph-*.jar")) if glob: if len(glob) > 1: - logging.warning( + logger.warning( "found multiple swh-graph JARs, " "arbitrarily picking one" ) - logging.info("using swh-graph JAR: {0}".format(glob[0])) + logger.info("using swh-graph JAR: %s", glob[0]) return str(glob[0]) raise RuntimeError("swh-graph JAR not found.
Have you run `make java`?") @@ -42,10 +46,13 @@ # Use 0.1% of the RAM as a batch size: # ~1 billion for big servers, ~10 million for small desktop machines conf["batch_size"] = min(int(psutil.virtual_memory().total / 1000), 2**30 - 1) + logger.debug("batch_size not configured, defaulting to %s", conf["batch_size"]) if "llp_gammas" not in conf: conf["llp_gammas"] = "-0,-1,-2,-3,-4" + logger.debug("llp_gammas not configured, defaulting to %s", conf["llp_gammas"]) if "max_ram" not in conf: conf["max_ram"] = str(int(psutil.virtual_memory().total * 0.9)) + logger.debug("max_ram not configured, defaulting to %s", conf["max_ram"]) if "java_tool_options" not in conf: conf["java_tool_options"] = " ".join( [ @@ -59,6 +66,10 @@ "-XX:+ResizeTLAB", ] ) + logger.debug( + "java_tool_options not configured, defaulting to %s", + conf["java_tool_options"], + ) conf["java_tool_options"] = conf["java_tool_options"].format( max_ram=conf["max_ram"] ) diff --git a/swh/graph/grpc_server.py b/swh/graph/grpc_server.py --- a/swh/graph/grpc_server.py +++ b/swh/graph/grpc_server.py @@ -16,12 +16,17 @@ from swh.graph.config import check_config +logger = logging.getLogger(__name__) + def build_grpc_server_cmdline(**config): port = config.pop("port", None) if port is None: port = aiohttp.test_utils.unused_port() + logger.debug("Port not configured, using random port %s", port) + logger.debug("Checking configuration and populating default values") config = check_config(config) + logger.debug("Configuration: %r", config) cmd = [ "java", "--class-path", @@ -39,8 +44,7 @@ cmd, port = build_grpc_server_cmdline(**config) - print(cmd) # XXX: shlex.join() is in 3.8 - # logging.info("Starting RPC server: %s", shlex.join(cmd)) - logging.info("Starting GRPC server: %s", " ".join(shlex.quote(x) for x in cmd)) + # logger.info("Starting gRPC server: %s", shlex.join(cmd)) + logger.info("Starting gRPC server: %s", " ".join(shlex.quote(x) for x in cmd)) server = subprocess.Popen(cmd) return server, port @@ -50,5
+54,5 @@ try: server.wait(timeout=timeout) except subprocess.TimeoutExpired: - logging.warning("Server did not terminate, sending kill signal...") + logger.warning("Server did not terminate, sending kill signal...") server.kill() diff --git a/swh/graph/http_client.py b/swh/graph/http_client.py --- a/swh/graph/http_client.py +++ b/swh/graph/http_client.py @@ -85,7 +85,13 @@ ) def visit_nodes( - self, src, edges="*", direction="forward", max_edges=0, return_types="*" + self, + src, + edges="*", + direction="forward", + max_edges=0, + return_types="*", + max_matching_nodes=0, ): return self.get_lines( "visit/nodes/{}".format(src), @@ -94,6 +100,7 @@ "direction": direction, "max_edges": max_edges, "return_types": return_types, + "max_matching_nodes": max_matching_nodes, }, ) @@ -160,8 +167,14 @@ params={"edges": edges, "direction": direction}, ) - def count_visit_nodes(self, src, edges="*", direction="forward"): + def count_visit_nodes( + self, src, edges="*", direction="forward", max_matching_nodes=0 + ): return self.get( "visit/nodes/count/{}".format(src), - params={"edges": edges, "direction": direction}, + params={ + "edges": edges, + "direction": direction, + "max_matching_nodes": max_matching_nodes, + }, ) diff --git a/swh/graph/http_naive_client.py b/swh/graph/http_naive_client.py --- a/swh/graph/http_naive_client.py +++ b/swh/graph/http_naive_client.py @@ -190,11 +190,15 @@ direction: str = "forward", max_edges: int = 0, return_types: str = "*", + max_matching_nodes: int = 0, ) -> Iterator[str]: # TODO: max_edges - yield from filter_node_types( + res = filter_node_types( return_types, self.graph.get_subgraph(src, edges, direction) ) + if max_matching_nodes > 0: + res = itertools.islice(res, max_matching_nodes) + return res @check_arguments def visit_edges( @@ -279,9 +283,16 @@ @check_arguments def count_visit_nodes( - self, src: str, edges: str = "*", direction: str = "forward" + self, + src: str, + edges: str = "*", + direction: str = "forward", +
max_matching_nodes: int = 0, ) -> int: - return len(self.graph.get_subgraph(src, edges, direction)) + res = len(self.graph.get_subgraph(src, edges, direction)) + if max_matching_nodes > 0: + res = min(max_matching_nodes, res) + return res class Graph: diff --git a/swh/graph/tests/test_http_client.py b/swh/graph/tests/test_http_client.py --- a/swh/graph/tests/test_http_client.py +++ b/swh/graph/tests/test_http_client.py @@ -91,6 +91,27 @@ assert set(actual) == set(expected) +@pytest.mark.parametrize("max_matching_nodes", [0, 1, 2, 3, 4, 5, 10, 1 << 31]) +def test_visit_nodes_limit(graph_client, max_matching_nodes): + actual = list( + graph_client.visit_nodes( + "swh:1:rel:0000000000000000000000000000000000000010", + edges="rel:rev,rev:rev", + max_matching_nodes=max_matching_nodes, + ) + ) + expected = [ + "swh:1:rel:0000000000000000000000000000000000000010", + "swh:1:rev:0000000000000000000000000000000000000009", + "swh:1:rev:0000000000000000000000000000000000000003", + ] + if max_matching_nodes == 0: + assert set(actual) == set(expected) + else: + assert set(actual) <= set(expected) + assert len(actual) == min(3, max_matching_nodes) + + def test_visit_nodes_filtered(graph_client): actual = list( graph_client.visit_nodes(