diff --git a/swh/graph/backend.py b/swh/graph/backend.py
index f8cbc02..22d4036 100644
--- a/swh/graph/backend.py
+++ b/swh/graph/backend.py
@@ -1,200 +1,206 @@
 # Copyright (C) 2019-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import asyncio
 import contextlib
 import io
 import os
 import struct
 import subprocess
 import sys
 import tempfile

 from py4j.java_gateway import JavaGateway

 from swh.graph.config import check_config
 from swh.graph.swhid import NodeToSwhidMap, SwhidToNodeMap
 from swh.model.identifiers import EXTENDED_SWHID_TYPES

 BUF_SIZE = 64 * 1024
 BIN_FMT = ">q"  # 64 bit integer, big endian
 PATH_SEPARATOR_ID = -1
 NODE2SWHID_EXT = "node2swhid.bin"
 SWHID2NODE_EXT = "swhid2node.bin"


 def _get_pipe_stderr():
     # Get stderr if possible, or pipe to stdout if running with Jupyter.
     try:
         sys.stderr.fileno()
     except io.UnsupportedOperation:
         return subprocess.STDOUT
     else:
         return sys.stderr


 class Backend:
     def __init__(self, graph_path, config=None):
         self.gateway = None
         self.entry = None
         self.graph_path = graph_path
         self.config = check_config(config or {})

-    def __enter__(self):
+    def start_gateway(self):
         self.gateway = JavaGateway.launch_gateway(
             java_path=None,
             javaopts=self.config["java_tool_options"].split(),
             classpath=self.config["classpath"],
             die_on_exit=True,
             redirect_stdout=sys.stdout,
             redirect_stderr=_get_pipe_stderr(),
         )
         self.entry = self.gateway.jvm.org.softwareheritage.graph.Entry()
         self.entry.load_graph(self.graph_path)
         self.node2swhid = NodeToSwhidMap(self.graph_path + "." + NODE2SWHID_EXT)
         self.swhid2node = SwhidToNodeMap(self.graph_path + "." + SWHID2NODE_EXT)
         self.stream_proxy = JavaStreamProxy(self.entry)
+
+    def stop_gateway(self):
+        self.gateway.shutdown()
+
+    def __enter__(self):
+        self.start_gateway()
         return self

     def __exit__(self, exc_type, exc_value, tb):
-        self.gateway.shutdown()
+        self.stop_gateway()

     def stats(self):
         return self.entry.stats()

     def count(self, ttype, direction, edges_fmt, src):
         method = getattr(self.entry, "count_" + ttype)
         return method(direction, edges_fmt, src)

     async def simple_traversal(
         self, ttype, direction, edges_fmt, src, max_edges, return_types
     ):
         assert ttype in ("leaves", "neighbors", "visit_nodes")
         method = getattr(self.stream_proxy, ttype)
         async for node_id in method(direction, edges_fmt, src, max_edges, return_types):
             yield node_id

     async def walk(self, direction, edges_fmt, algo, src, dst):
         if dst in EXTENDED_SWHID_TYPES:
             it = self.stream_proxy.walk_type(direction, edges_fmt, algo, src, dst)
         else:
             it = self.stream_proxy.walk(direction, edges_fmt, algo, src, dst)
         async for node_id in it:
             yield node_id

     async def random_walk(self, direction, edges_fmt, retries, src, dst, return_types):
         if dst in EXTENDED_SWHID_TYPES:
             it = self.stream_proxy.random_walk_type(
                 direction, edges_fmt, retries, src, dst, return_types
             )
         else:
             it = self.stream_proxy.random_walk(
                 direction, edges_fmt, retries, src, dst, return_types
             )
         async for node_id in it:
             # TODO return 404 if path is empty
             yield node_id

     async def visit_edges(self, direction, edges_fmt, src, max_edges):
         it = self.stream_proxy.visit_edges(direction, edges_fmt, src, max_edges)
         # convert stream a, b, c, d -> (a, b), (c, d)
         prevNode = None
         async for node in it:
             if prevNode is not None:
                 yield (prevNode, node)
                 prevNode = None
             else:
                 prevNode = node

     async def visit_paths(self, direction, edges_fmt, src, max_edges):
         path = []
         async for node in self.stream_proxy.visit_paths(
             direction, edges_fmt, src, max_edges
         ):
             if node == PATH_SEPARATOR_ID:
                 yield path
                 path = []
             else:
                 path.append(node)
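Note (editor's sketch, not part of the diff): the change above splits the gateway lifecycle out of `__enter__`/`__exit__` so that callers other than a `with` block can drive it, which is what the server-side startup/shutdown hooks later in this diff rely on. The snippet below illustrates the two equivalent usage styles; the `leaves_of` helper, the graph path argument and the traversal parameters are invented for the example, while the `Backend` methods are the ones defined above.

```python
import asyncio

from swh.graph.backend import Backend


async def leaves_of(backend: Backend, swhid: str) -> list:
    """Collect the SWHIDs of leaves reachable from `swhid` on a running backend."""
    node = backend.swhid2node[swhid]
    return [
        backend.node2swhid[n]
        async for n in backend.simple_traversal(
            "leaves", "forward", "*", node, max_edges=0, return_types="*"
        )
    ]


def with_context_manager(graph_path: str, swhid: str):
    # Pre-existing style: the context manager still works; __enter__ now
    # simply delegates to start_gateway().
    with Backend(graph_path) as backend:
        return asyncio.run(leaves_of(backend, swhid))


def with_explicit_calls(graph_path: str, swhid: str):
    # New style: explicit lifecycle calls, usable from code that cannot wrap
    # the whole server lifetime in a single `with` block (e.g. aiohttp hooks).
    backend = Backend(graph_path)
    backend.start_gateway()
    try:
        return asyncio.run(leaves_of(backend, swhid))
    finally:
        backend.stop_gateway()
```

Behavior of the existing context-manager API is unchanged; it is now a thin wrapper around the two explicit calls.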


 class JavaStreamProxy:
     """A proxy class for the org.softwareheritage.graph.Entry Java class that
     takes care of the setup and teardown of the named-pipe FIFO communication
     between Python and Java.

     Initialize JavaStreamProxy using:

         proxy = JavaStreamProxy(swh_entry_class_instance)

     Then you can call an Entry method and iterate on the FIFO results like
     this:

         async for value in proxy.java_method(arg1, arg2):
             print(value)

     """

     def __init__(self, entry):
         self.entry = entry

     async def read_node_ids(self, fname):
         loop = asyncio.get_event_loop()
         open_thread = loop.run_in_executor(None, open, fname, "rb")

         # Since the open() call on the FIFO blocks until it is also opened on
         # the Java side, we await it with a timeout in case an exception
         # prevents the write-side open().
         with (await asyncio.wait_for(open_thread, timeout=2)) as f:
             while True:
                 data = await loop.run_in_executor(None, f.read, BUF_SIZE)
                 if not data:
                     break
                 for data in struct.iter_unpack(BIN_FMT, data):
                     yield data[0]

     class _HandlerWrapper:
         def __init__(self, handler):
             self._handler = handler

         def __getattr__(self, name):
             func = getattr(self._handler, name)

             async def java_call(*args, **kwargs):
                 loop = asyncio.get_event_loop()
                 await loop.run_in_executor(None, lambda: func(*args, **kwargs))

             def java_task(*args, **kwargs):
                 return asyncio.create_task(java_call(*args, **kwargs))

             return java_task

     @contextlib.contextmanager
     def get_handler(self):
         with tempfile.TemporaryDirectory(prefix="swh-graph-") as tmpdirname:
             cli_fifo = os.path.join(tmpdirname, "swh-graph.fifo")
             os.mkfifo(cli_fifo)
             reader = self.read_node_ids(cli_fifo)
             query_handler = self.entry.get_handler(cli_fifo)
             handler = self._HandlerWrapper(query_handler)
             yield (handler, reader)

     def __getattr__(self, name):
         async def java_call_iterator(*args, **kwargs):
             with self.get_handler() as (handler, reader):
                 java_task = getattr(handler, name)(*args, **kwargs)
                 try:
                     async for value in reader:
                         yield value
                 except asyncio.TimeoutError:
                     # If the read-side open() times out, an exception on the
                     # Java side probably happened that prevented the
                     # write-side open(). We propagate this exception here if
                     # that is the case.
                     task_exc = java_task.exception()
                     if task_exc:
                         raise task_exc
                     raise
                 await java_task

         return java_call_iterator
diff --git a/swh/graph/cli.py b/swh/graph/cli.py
index ea3dae1..d6631f3 100644
--- a/swh/graph/cli.py
+++ b/swh/graph/cli.py
@@ -1,446 +1,443 @@
 # Copyright (C) 2019-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import logging
 from pathlib import Path
 import sys
 from typing import TYPE_CHECKING, Any, Dict, Set, Tuple

 # WARNING: do not import unnecessary things here to keep cli startup time under
 # control
 import click

 from swh.core.cli import CONTEXT_SETTINGS, AliasedGroup
 from swh.core.cli import swh as swh_cli_group

 if TYPE_CHECKING:
     from swh.graph.webgraph import CompressionStep  # noqa


 class StepOption(click.ParamType):
     """click type for specifying a compression step on the CLI

     parse either individual steps, specified as step names or integers, or
     step ranges

     """

     name = "compression step"

     def convert(self, value, param, ctx):  # type: (...)
-> Set[CompressionStep] from swh.graph.webgraph import COMP_SEQ, CompressionStep # noqa steps: Set[CompressionStep] = set() specs = value.split(",") for spec in specs: if "-" in spec: # step range (raw_l, raw_r) = spec.split("-", maxsplit=1) if raw_l == "": # no left endpoint raw_l = COMP_SEQ[0].name if raw_r == "": # no right endpoint raw_r = COMP_SEQ[-1].name l_step = self.convert(raw_l, param, ctx) r_step = self.convert(raw_r, param, ctx) if len(l_step) != 1 or len(r_step) != 1: self.fail(f"invalid step specification: {value}, " f"see --help") l_idx = l_step.pop() r_idx = r_step.pop() steps = steps.union( set(CompressionStep(i) for i in range(l_idx.value, r_idx.value + 1)) ) else: # singleton step try: steps.add(CompressionStep(int(spec))) # integer step except ValueError: try: steps.add(CompressionStep[spec.upper()]) # step name except KeyError: self.fail( f"invalid step specification: {value}, " f"see --help" ) return steps class PathlibPath(click.Path): """A Click path argument that returns a pathlib Path, not a string""" def convert(self, value, param, ctx): return Path(super().convert(value, param, ctx)) DEFAULT_CONFIG: Dict[str, Tuple[str, Any]] = {"graph": ("dict", {})} @swh_cli_group.group(name="graph", context_settings=CONTEXT_SETTINGS, cls=AliasedGroup) @click.option( "--config-file", "-C", default=None, type=click.Path(exists=True, dir_okay=False,), help="YAML configuration file", ) @click.pass_context def graph_cli_group(ctx, config_file): """Software Heritage graph tools.""" from swh.core import config ctx.ensure_object(dict) conf = config.read(config_file, DEFAULT_CONFIG) if "graph" not in conf: raise ValueError( 'no "graph" stanza found in configuration file %s' % config_file ) ctx.obj["config"] = conf @graph_cli_group.command("api-client") @click.option("--host", default="localhost", help="Graph server host") @click.option("--port", default="5009", help="Graph server port") @click.pass_context def api_client(ctx, host, port): """client for the graph RPC service""" from swh.graph import client url = "http://{}:{}".format(host, port) app = client.RemoteGraphClient(url) # TODO: run web app print(app.stats()) @graph_cli_group.group("map") @click.pass_context def map(ctx): """Manage swh-graph on-disk maps""" pass def dump_swhid2node(filename): from swh.graph.swhid import SwhidToNodeMap for (swhid, int) in SwhidToNodeMap(filename): print("{}\t{}".format(swhid, int)) def dump_node2swhid(filename): from swh.graph.swhid import NodeToSwhidMap for (int, swhid) in NodeToSwhidMap(filename): print("{}\t{}".format(int, swhid)) def restore_swhid2node(filename): """read a textual SWHID->int map from stdin and write its binary version to filename """ from swh.graph.swhid import SwhidToNodeMap with open(filename, "wb") as dst: for line in sys.stdin: (str_swhid, str_int) = line.split() SwhidToNodeMap.write_record(dst, str_swhid, int(str_int)) def restore_node2swhid(filename, length): """read a textual int->SWHID map from stdin and write its binary version to filename """ from swh.graph.swhid import NodeToSwhidMap node2swhid = NodeToSwhidMap(filename, mode="wb", length=length) for line in sys.stdin: (str_int, str_swhid) = line.split() node2swhid[int(str_int)] = str_swhid node2swhid.close() @map.command("dump") @click.option( "--type", "-t", "map_type", required=True, type=click.Choice(["swhid2node", "node2swhid"]), help="type of map to dump", ) @click.argument("filename", required=True, type=click.Path(exists=True)) @click.pass_context def dump_map(ctx, map_type, filename): """Dump a binary 
SWHID<->node map to textual format.""" if map_type == "swhid2node": dump_swhid2node(filename) elif map_type == "node2swhid": dump_node2swhid(filename) else: raise ValueError("invalid map type: " + map_type) pass @map.command("restore") @click.option( "--type", "-t", "map_type", required=True, type=click.Choice(["swhid2node", "node2swhid"]), help="type of map to dump", ) @click.option( "--length", "-l", type=int, help="""map size in number of logical records (required for node2swhid maps)""", ) @click.argument("filename", required=True, type=click.Path()) @click.pass_context def restore_map(ctx, map_type, length, filename): """Restore a binary SWHID<->node map from textual format.""" if map_type == "swhid2node": restore_swhid2node(filename) elif map_type == "node2swhid": if length is None: raise click.UsageError( "map length is required when restoring {} maps".format(map_type), ctx ) restore_node2swhid(filename, length) else: raise ValueError("invalid map type: " + map_type) @map.command("write") @click.option( "--type", "-t", "map_type", required=True, type=click.Choice(["swhid2node", "node2swhid"]), help="type of map to write", ) @click.argument("filename", required=True, type=click.Path()) @click.pass_context def write(ctx, map_type, filename): """Write a map to disk sequentially. read from stdin a textual SWHID->node mapping (for swhid2node, or a simple sequence of SWHIDs for node2swhid) and write it to disk in the requested binary map format note that no sorting is applied, so the input should already be sorted as required by the chosen map type (by SWHID for swhid2node, by int for node2swhid) """ from swh.graph.swhid import NodeToSwhidMap, SwhidToNodeMap with open(filename, "wb") as f: if map_type == "swhid2node": for line in sys.stdin: (swhid, int_str) = line.rstrip().split(maxsplit=1) SwhidToNodeMap.write_record(f, swhid, int(int_str)) elif map_type == "node2swhid": for line in sys.stdin: swhid = line.rstrip() NodeToSwhidMap.write_record(f, swhid) else: raise ValueError("invalid map type: " + map_type) @map.command("lookup") @click.option( "--graph", "-g", required=True, metavar="GRAPH", help="compressed graph basename" ) @click.argument("identifiers", nargs=-1) def map_lookup(graph, identifiers): """Lookup identifiers using on-disk maps. Depending on the identifier type lookup either a SWHID into a SWHID->node (and return the node integer identifier) or, vice-versa, lookup a node integer identifier into a node->SWHID (and return the SWHID). The desired behavior is chosen depending on the syntax of each given identifier. Identifiers can be passed either directly on the command line or on standard input, separate by blanks. Logical lines (as returned by readline()) in stdin will be preserved in stdout. 
""" from swh.graph.backend import NODE2SWHID_EXT, SWHID2NODE_EXT from swh.graph.swhid import NodeToSwhidMap, SwhidToNodeMap import swh.model.exceptions from swh.model.identifiers import ExtendedSWHID success = True # no identifiers failed to be looked up swhid2node = SwhidToNodeMap(f"{graph}.{SWHID2NODE_EXT}") node2swhid = NodeToSwhidMap(f"{graph}.{NODE2SWHID_EXT}") def lookup(identifier): nonlocal success, swhid2node, node2swhid is_swhid = None try: int(identifier) is_swhid = False except ValueError: try: ExtendedSWHID.from_string(identifier) is_swhid = True except swh.model.exceptions.ValidationError: success = False logging.error(f'invalid identifier: "{identifier}", skipping') try: if is_swhid: return str(swhid2node[identifier]) else: return node2swhid[int(identifier)] except KeyError: success = False logging.error(f'identifier not found: "{identifier}", skipping') if identifiers: # lookup identifiers passed via CLI for identifier in identifiers: print(lookup(identifier)) else: # lookup identifiers passed via stdin, preserving logical lines for line in sys.stdin: results = [lookup(id) for id in line.rstrip().split()] if results: # might be empty if all IDs on the same line failed print(" ".join(results)) sys.exit(0 if success else 1) @graph_cli_group.command(name="rpc-serve") @click.option( "--host", "-h", default="0.0.0.0", metavar="IP", show_default=True, help="host IP address to bind the server on", ) @click.option( "--port", "-p", default=5009, type=click.INT, metavar="PORT", show_default=True, help="port to bind the server on", ) @click.option( "--graph", "-g", required=True, metavar="GRAPH", help="compressed graph basename" ) @click.pass_context def serve(ctx, host, port, graph): """run the graph RPC service""" import aiohttp - from swh.graph.backend import Backend from swh.graph.server.app import make_app - backend = Backend(graph_path=graph, config=ctx.obj["config"]) - app = make_app(backend=backend) + app = make_app(config=ctx.obj["config"]) - with backend: - aiohttp.web.run_app(app, host=host, port=port) + aiohttp.web.run_app(app, host=host, port=port) @graph_cli_group.command() @click.option( "--graph", "-g", required=True, metavar="GRAPH", type=PathlibPath(), help="input graph basename", ) @click.option( "--outdir", "-o", "out_dir", required=True, metavar="DIR", type=PathlibPath(), help="directory where to store compressed graph", ) @click.option( "--steps", "-s", metavar="STEPS", type=StepOption(), help="run only these compression steps (default: all steps)", ) @click.pass_context def compress(ctx, graph, out_dir, steps): """Compress a graph using WebGraph Input: a pair of files g.nodes.csv.gz, g.edges.csv.gz Output: a directory containing a WebGraph compressed graph Compression steps are: (1) mph, (2) bv, (3) bv_obl, (4) bfs, (5) permute, (6) permute_obl, (7) stats, (8) transpose, (9) transpose_obl, (10) maps, (11) clean_tmp. Compression steps can be selected by name or number using --steps, separating them with commas; step ranges (e.g., 3-9, 6-, etc.) are also supported. 
""" from swh.graph import webgraph graph_name = graph.name in_dir = graph.parent try: conf = ctx.obj["config"]["graph"]["compress"] except KeyError: conf = {} # use defaults webgraph.compress(graph_name, in_dir, out_dir, steps, conf) @graph_cli_group.command(name="cachemount") @click.option( "--graph", "-g", required=True, metavar="GRAPH", help="compressed graph basename" ) @click.option( "--cache", "-c", default="/dev/shm/swh-graph/default", metavar="CACHE", type=PathlibPath(), help="Memory cache path (defaults to /dev/shm/swh-graph/default)", ) @click.pass_context def cachemount(ctx, graph, cache): """ Cache the mmapped files of the compressed graph in a tmpfs. This command creates a new directory at the path given by CACHE that has the same structure as the compressed graph basename, except it copies the files that require mmap access (:file:`{*}.graph`) but uses symlinks from the source for all the other files (:file:`{*}.map`, :file:`{*}.bin`, ...). The command outputs the path to the memory cache directory (particularly useful when relying on the default value). """ import shutil cache.mkdir(parents=True) for src in Path(graph).parent.glob("*"): dst = cache / src.name if src.suffix == ".graph": shutil.copy2(src, dst) else: dst.symlink_to(src.resolve()) print(cache) def main(): return graph_cli_group(auto_envvar_prefix="SWH_GRAPH") if __name__ == "__main__": main() diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py index 227524a..d071fd4 100644 --- a/swh/graph/server/app.py +++ b/swh/graph/server/app.py @@ -1,359 +1,392 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information """ A proxy HTTP server for swh-graph, talking to the Java code via py4j, and using FIFO as a transport to stream integers between the two languages. """ import asyncio from collections import deque import json +import os from typing import Optional import aiohttp.web from swh.core.api.asynchronous import RPCServerApp +from swh.core.config import read as config_read +from swh.graph.backend import Backend from swh.model.exceptions import ValidationError from swh.model.identifiers import EXTENDED_SWHID_TYPES try: from contextlib import asynccontextmanager except ImportError: # Compatibility with 3.6 backport from async_generator import asynccontextmanager # type: ignore # maximum number of retries for random walks RANDOM_RETRIES = 5 # TODO make this configurable via rpc-serve configuration +class GraphServerApp(RPCServerApp): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.on_startup.append(self._start_gateway) + self.on_shutdown.append(self._stop_gateway) + + @staticmethod + async def _start_gateway(app): + # Equivalent to entering `with app["backend"]:` + app["backend"].start_gateway() + + @staticmethod + async def _stop_gateway(app): + # Equivalent to exiting `with app["backend"]:` with no error + app["backend"].stop_gateway() + + async def index(request): return aiohttp.web.Response( content_type="text/html", body=""" Software Heritage storage server

You have reached the Software Heritage graph API server.

See its API documentation for more information.

""", ) class GraphView(aiohttp.web.View): """Base class for views working on the graph, with utility functions""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.backend = self.request.app["backend"] def node_of_swhid(self, swhid): """Lookup a SWHID in a swhid2node map, failing in an HTTP-nice way if needed.""" try: return self.backend.swhid2node[swhid] except KeyError: raise aiohttp.web.HTTPNotFound(text=f"SWHID not found: {swhid}") except ValidationError: raise aiohttp.web.HTTPBadRequest(text=f"malformed SWHID: {swhid}") def swhid_of_node(self, node): """Lookup a node in a node2swhid map, failing in an HTTP-nice way if needed.""" try: return self.backend.node2swhid[node] except KeyError: raise aiohttp.web.HTTPInternalServerError( text=f"reverse lookup failed for node id: {node}" ) def get_direction(self): """Validate HTTP query parameter `direction`""" s = self.request.query.get("direction", "forward") if s not in ("forward", "backward"): raise aiohttp.web.HTTPBadRequest(text=f"invalid direction: {s}") return s def get_edges(self): """Validate HTTP query parameter `edges`, i.e., edge restrictions""" s = self.request.query.get("edges", "*") if any( [ node_type != "*" and node_type not in EXTENDED_SWHID_TYPES for edge in s.split(":") for node_type in edge.split(",", maxsplit=1) ] ): raise aiohttp.web.HTTPBadRequest(text=f"invalid edge restriction: {s}") return s def get_return_types(self): """Validate HTTP query parameter 'return types', i.e, a set of types which we will filter the query results with""" s = self.request.query.get("return_types", "*") if any( node_type != "*" and node_type not in EXTENDED_SWHID_TYPES for node_type in s.split(",") ): raise aiohttp.web.HTTPBadRequest( text=f"invalid type for filtering res: {s}" ) # if the user puts a star, # then we filter nothing, we don't need the other information if "*" in s: return "*" else: return s def get_traversal(self): """Validate HTTP query parameter `traversal`, i.e., visit order""" s = self.request.query.get("traversal", "dfs") if s not in ("bfs", "dfs"): raise aiohttp.web.HTTPBadRequest(text=f"invalid traversal order: {s}") return s def get_limit(self): """Validate HTTP query parameter `limit`, i.e., number of results""" s = self.request.query.get("limit", "0") try: return int(s) except ValueError: raise aiohttp.web.HTTPBadRequest(text=f"invalid limit value: {s}") def get_max_edges(self): """Validate HTTP query parameter 'max_edges', i.e., the limit of the number of edges that can be visited""" s = self.request.query.get("max_edges", "0") try: return int(s) except ValueError: raise aiohttp.web.HTTPBadRequest(text=f"invalid max_edges value: {s}") class StreamingGraphView(GraphView): """Base class for views streaming their response line by line.""" content_type = "text/plain" @asynccontextmanager async def response_streamer(self, *args, **kwargs): """Context manager to prepare then close a StreamResponse""" response = aiohttp.web.StreamResponse(*args, **kwargs) response.content_type = self.content_type await response.prepare(self.request) yield response await response.write_eof() async def get(self): await self.prepare_response() async with self.response_streamer() as self.response_stream: await self.stream_response() return self.response_stream async def prepare_response(self): """This can be overridden with some setup to be run before the response actually starts streaming. """ pass async def stream_response(self): """Override this to perform the response streaming. 
Implementations of this should await self.stream_line(line) to write each line. """ raise NotImplementedError async def stream_line(self, line): """Write a line in the response stream.""" await self.response_stream.write((line + "\n").encode()) class StatsView(GraphView): """View showing some statistics on the graph""" async def get(self): stats = self.backend.stats() return aiohttp.web.Response(body=stats, content_type="application/json") class SimpleTraversalView(StreamingGraphView): """Base class for views of simple traversals""" simple_traversal_type: Optional[str] = None async def prepare_response(self): src = self.request.match_info["src"] self.src_node = self.node_of_swhid(src) self.edges = self.get_edges() self.direction = self.get_direction() self.max_edges = self.get_max_edges() self.return_types = self.get_return_types() async def stream_response(self): async for res_node in self.backend.simple_traversal( self.simple_traversal_type, self.direction, self.edges, self.src_node, self.max_edges, self.return_types, ): res_swhid = self.swhid_of_node(res_node) await self.stream_line(res_swhid) class LeavesView(SimpleTraversalView): simple_traversal_type = "leaves" class NeighborsView(SimpleTraversalView): simple_traversal_type = "neighbors" class VisitNodesView(SimpleTraversalView): simple_traversal_type = "visit_nodes" class WalkView(StreamingGraphView): async def prepare_response(self): src = self.request.match_info["src"] dst = self.request.match_info["dst"] self.src_node = self.node_of_swhid(src) if dst not in EXTENDED_SWHID_TYPES: self.dst_thing = self.node_of_swhid(dst) else: self.dst_thing = dst self.edges = self.get_edges() self.direction = self.get_direction() self.algo = self.get_traversal() self.limit = self.get_limit() self.return_types = self.get_return_types() async def get_walk_iterator(self): return self.backend.walk( self.direction, self.edges, self.algo, self.src_node, self.dst_thing ) async def stream_response(self): it = self.get_walk_iterator() if self.limit < 0: queue = deque(maxlen=-self.limit) async for res_node in it: res_swhid = self.swhid_of_node(res_node) queue.append(res_swhid) while queue: await self.stream_line(queue.popleft()) else: count = 0 async for res_node in it: if self.limit == 0 or count < self.limit: res_swhid = self.swhid_of_node(res_node) await self.stream_line(res_swhid) count += 1 else: break class RandomWalkView(WalkView): def get_walk_iterator(self): return self.backend.random_walk( self.direction, self.edges, RANDOM_RETRIES, self.src_node, self.dst_thing, self.return_types, ) class VisitEdgesView(SimpleTraversalView): async def stream_response(self): it = self.backend.visit_edges( self.direction, self.edges, self.src_node, self.max_edges ) async for (res_src, res_dst) in it: res_src_swhid = self.swhid_of_node(res_src) res_dst_swhid = self.swhid_of_node(res_dst) await self.stream_line("{} {}".format(res_src_swhid, res_dst_swhid)) class VisitPathsView(SimpleTraversalView): content_type = "application/x-ndjson" async def stream_response(self): it = self.backend.visit_paths( self.direction, self.edges, self.src_node, self.max_edges ) async for res_path in it: res_path_swhid = [self.swhid_of_node(n) for n in res_path] line = json.dumps(res_path_swhid) await self.stream_line(line) class CountView(GraphView): """Base class for counting views.""" count_type: Optional[str] = None async def get(self): src = self.request.match_info["src"] self.src_node = self.node_of_swhid(src) self.edges = self.get_edges() self.direction = self.get_direction() loop 
= asyncio.get_event_loop()
         cnt = await loop.run_in_executor(
             None,
             self.backend.count,
             self.count_type,
             self.direction,
             self.edges,
             self.src_node,
         )
         return aiohttp.web.Response(body=str(cnt), content_type="application/json")


 class CountNeighborsView(CountView):
     count_type = "neighbors"


 class CountLeavesView(CountView):
     count_type = "leaves"


 class CountVisitNodesView(CountView):
     count_type = "visit_nodes"


-def make_app(backend, **kwargs):
-    app = RPCServerApp(**kwargs)
+def make_app(config=None, backend=None, **kwargs):
+    if (config is None) == (backend is None):
+        raise ValueError("make_app() expects exactly one of 'config' or 'backend'")
+    if backend is None:
+        backend = Backend(graph_path=config["graph"]["path"], config=config["graph"])
+    app = GraphServerApp(**kwargs)
     app.add_routes(
         [
             aiohttp.web.get("/", index),
             aiohttp.web.get("/graph", index),
             aiohttp.web.view("/graph/stats", StatsView),
             aiohttp.web.view("/graph/leaves/{src}", LeavesView),
             aiohttp.web.view("/graph/neighbors/{src}", NeighborsView),
             aiohttp.web.view("/graph/visit/nodes/{src}", VisitNodesView),
             aiohttp.web.view("/graph/visit/edges/{src}", VisitEdgesView),
             aiohttp.web.view("/graph/visit/paths/{src}", VisitPathsView),
             # temporarily disabled in wait of a proper fix for T1969
             # aiohttp.web.view("/graph/walk/{src}/{dst}", WalkView)
             aiohttp.web.view("/graph/randomwalk/{src}/{dst}", RandomWalkView),
             aiohttp.web.view("/graph/neighbors/count/{src}", CountNeighborsView),
             aiohttp.web.view("/graph/leaves/count/{src}", CountLeavesView),
             aiohttp.web.view("/graph/visit/nodes/count/{src}", CountVisitNodesView),
         ]
     )
     app["backend"] = backend
     return app
+
+
+def make_app_from_configfile():
+    """Load configuration and then build application to run
+
+    """
+    config_file = os.environ.get("SWH_CONFIG_FILENAME")
+    config = config_read(config_file)
+    return make_app(config=config)
diff --git a/swh/graph/tests/conftest.py b/swh/graph/tests/conftest.py
index e66a789..aad4fdb 100644
--- a/swh/graph/tests/conftest.py
+++ b/swh/graph/tests/conftest.py
@@ -1,65 +1,64 @@
 # Copyright (C) 2019-2021 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import csv
 import multiprocessing
 from pathlib import Path

 from aiohttp.test_utils import TestClient, TestServer, loop_context
 import pytest

 from swh.graph.backend import Backend
 from swh.graph.client import RemoteGraphClient
 from swh.graph.graph import load as graph_load
 from swh.graph.naive_client import NaiveClient
 from swh.graph.server.app import make_app

 SWH_GRAPH_TESTS_ROOT = Path(__file__).parents[0]
 TEST_GRAPH_PATH = SWH_GRAPH_TESTS_ROOT / "dataset/output/example"


 class GraphServerProcess(multiprocessing.Process):
     def __init__(self, q, *args, **kwargs):
         self.q = q
         super().__init__(*args, **kwargs)

     def run(self):
         try:
             backend = Backend(graph_path=str(TEST_GRAPH_PATH))
-            with backend:
-                with loop_context() as loop:
-                    app = make_app(backend=backend, debug=True)
-                    client = TestClient(TestServer(app), loop=loop)
-                    loop.run_until_complete(client.start_server())
-                    url = client.make_url("/graph/")
-                    self.q.put(url)
-                    loop.run_forever()
+            with loop_context() as loop:
+                app = make_app(backend=backend, debug=True)
+                client = TestClient(TestServer(app), loop=loop)
+                loop.run_until_complete(client.start_server())
+                url = client.make_url("/graph/")
+                self.q.put(url)
+                loop.run_forever()
         except Exception as e:
             self.q.put(e)
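Note (editor's sketch, not part of the diff): the test above can drop the `with backend:` block because `GraphServerApp` now starts and stops the py4j gateway from aiohttp's `on_startup`/`on_shutdown` signals. The stripped-down example below shows just that aiohttp pattern; `FakeBackend` is invented here to make the hook ordering visible and stands in for the real `Backend`.

```python
import aiohttp.web


class FakeBackend:
    """Stand-in for swh.graph.backend.Backend, used only to show hook order."""

    def start_gateway(self):
        print("gateway started")

    def stop_gateway(self):
        print("gateway stopped")


async def _start(app: aiohttp.web.Application):
    # Runs once, after run_app() has set up the event loop and before the
    # server starts accepting requests.
    app["backend"].start_gateway()


async def _stop(app: aiohttp.web.Application):
    # Runs once, on graceful shutdown (e.g. SIGINT/SIGTERM or cleanup()).
    app["backend"].stop_gateway()


app = aiohttp.web.Application()
app["backend"] = FakeBackend()
app.on_startup.append(_start)
app.on_shutdown.append(_stop)

if __name__ == "__main__":
    aiohttp.web.run_app(app, port=8080)
```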


 @pytest.fixture(scope="module", params=["remote", "naive"])
 def graph_client(request):
     if request.param == "remote":
         queue = multiprocessing.Queue()
         server = GraphServerProcess(queue)
         server.start()
         res = queue.get()
         if isinstance(res, Exception):
             raise res
         yield RemoteGraphClient(str(res))
         server.terminate()
     else:
         with open(SWH_GRAPH_TESTS_ROOT / "dataset/example.nodes.csv") as fd:
             nodes = [node for (node,) in csv.reader(fd, delimiter=" ")]
         with open(SWH_GRAPH_TESTS_ROOT / "dataset/example.edges.csv") as fd:
             edges = list(csv.reader(fd, delimiter=" "))
         yield NaiveClient(nodes=nodes, edges=edges)


 @pytest.fixture(scope="module")
 def graph():
     with graph_load(str(TEST_GRAPH_PATH)) as g:
         yield g
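Note (editor's sketch, not part of the diff): `make_app_from_configfile()` introduced above reads the path of a configuration file from the `SWH_CONFIG_FILENAME` environment variable. The snippet below shows one way it could be driven; the file path is a placeholder, and the YAML layout (a `graph` stanza with a `path` key) is inferred from `make_app()` and the CLI code above, not taken from documentation.

```python
import os

import aiohttp.web

from swh.graph.server.app import make_app_from_configfile

# The referenced YAML file (placeholder path) would contain something like:
#
#   graph:
#     path: /srv/softwareheritage/graph/example
#
os.environ["SWH_CONFIG_FILENAME"] = "/etc/softwareheritage/graph.yml"

# config_read() loads the YAML file and make_app() builds the Backend from it;
# the gateway itself is only started by the app's on_startup hook.
app = make_app_from_configfile()

if __name__ == "__main__":
    aiohttp.web.run_app(app, host="0.0.0.0", port=5009)
```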