diff --git a/java/src/main/java/org/softwareheritage/graph/Entry.java b/java/src/main/java/org/softwareheritage/graph/Entry.java index 0db1bb6..e110941 100644 --- a/java/src/main/java/org/softwareheritage/graph/Entry.java +++ b/java/src/main/java/org/softwareheritage/graph/Entry.java @@ -1,196 +1,193 @@ package org.softwareheritage.graph; -import java.io.DataOutputStream; -import java.io.FileOutputStream; -import java.io.IOException; +import java.io.*; import java.util.ArrayList; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.PropertyNamingStrategy; public class Entry { - private final long PATH_SEPARATOR_ID = -1; private Graph graph; public void load_graph(String graphBasename) throws IOException { System.err.println("Loading graph " + graphBasename + " ..."); this.graph = Graph.loadMapped(graphBasename); System.err.println("Graph loaded."); } public Graph get_graph() { return graph.copy(); } public String stats() { try { Stats stats = new Stats(graph.getPath()); ObjectMapper objectMapper = new ObjectMapper(); objectMapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE); return objectMapper.writeValueAsString(stats); } catch (IOException e) { throw new RuntimeException("Cannot read stats: " + e); } } + public void check_swhid(String src) { + graph.getNodeId(new SWHID(src)); + } + private int count_visitor(NodeCountVisitor f, long srcNodeId) { int[] count = {0}; f.accept(srcNodeId, (node) -> { count[0]++; }); return count[0]; } - public int count_leaves(String direction, String edgesFmt, long srcNodeId) { - Traversal t = new Traversal(this.graph.copy(), direction, edgesFmt); + public int count_leaves(String direction, String edgesFmt, String src) { + long srcNodeId = graph.getNodeId(new SWHID(src)); + Traversal t = new Traversal(graph.copy(), direction, edgesFmt); return count_visitor(t::leavesVisitor, srcNodeId); } - public int count_neighbors(String direction, String edgesFmt, long srcNodeId) { - Traversal t = new Traversal(this.graph.copy(), direction, edgesFmt); + public int count_neighbors(String direction, String edgesFmt, String src) { + long srcNodeId = graph.getNodeId(new SWHID(src)); + Traversal t = new Traversal(graph.copy(), direction, edgesFmt); return count_visitor(t::neighborsVisitor, srcNodeId); } - public int count_visit_nodes(String direction, String edgesFmt, long srcNodeId) { - Traversal t = new Traversal(this.graph.copy(), direction, edgesFmt); + public int count_visit_nodes(String direction, String edgesFmt, String src) { + long srcNodeId = graph.getNodeId(new SWHID(src)); + Traversal t = new Traversal(graph.copy(), direction, edgesFmt); return count_visitor(t::visitNodesVisitor, srcNodeId); } public QueryHandler get_handler(String clientFIFO) { - return new QueryHandler(this.graph.copy(), clientFIFO); + return new QueryHandler(graph.copy(), clientFIFO); } private interface NodeCountVisitor { void accept(long nodeId, Traversal.NodeIdConsumer consumer); } public class QueryHandler { Graph graph; - DataOutputStream out; + BufferedWriter out; String clientFIFO; public QueryHandler(Graph graph, String clientFIFO) { this.graph = graph; this.clientFIFO = clientFIFO; this.out = null; } - public void writeNode(long nodeId) { + public void writeNode(SWHID swhid) { try { - out.writeLong(nodeId); + out.write(swhid.toString() + "\n"); } catch (IOException e) { throw new RuntimeException("Cannot write response to client: " + e); } } - public void writeEdge(long srcId, long dstId) { - writeNode(srcId); - writeNode(dstId); - } - - public void writePath(ArrayList path) { - for (Long nodeId : path) { - writeNode(nodeId); + public void writeEdge(SWHID src, SWHID dst) { + try { + out.write(src.toString() + " " + dst.toString() + "\n"); + } catch (IOException e) { + throw new RuntimeException("Cannot write response to client: " + e); } - writeNode(PATH_SEPARATOR_ID); } public void open() { try { FileOutputStream file = new FileOutputStream(this.clientFIFO); - this.out = new DataOutputStream(file); + this.out = new BufferedWriter(new OutputStreamWriter(file)); } catch (IOException e) { throw new RuntimeException("Cannot open client FIFO: " + e); } } public void close() { try { out.close(); } catch (IOException e) { throw new RuntimeException("Cannot write response to client: " + e); } } - public void leaves(String direction, String edgesFmt, long srcNodeId, long maxEdges, String returnTypes) { + public void leaves(String direction, String edgesFmt, String src, long maxEdges, String returnTypes) { + long srcNodeId = graph.getNodeId(new SWHID(src)); open(); - Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges, returnTypes); + Traversal t = new Traversal(graph, direction, edgesFmt, maxEdges, returnTypes); for (Long nodeId : t.leaves(srcNodeId)) { - writeNode(nodeId); + writeNode(graph.getSWHID(nodeId)); } close(); } - public void neighbors(String direction, String edgesFmt, long srcNodeId, long maxEdges, String returnTypes) { + public void neighbors(String direction, String edgesFmt, String src, long maxEdges, String returnTypes) { + long srcNodeId = graph.getNodeId(new SWHID(src)); open(); - Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges, returnTypes); + Traversal t = new Traversal(graph, direction, edgesFmt, maxEdges, returnTypes); for (Long nodeId : t.neighbors(srcNodeId)) { - writeNode(nodeId); + writeNode(graph.getSWHID(nodeId)); } close(); } - public void visit_nodes(String direction, String edgesFmt, long srcNodeId, long maxEdges, String returnTypes) { + public void visit_nodes(String direction, String edgesFmt, String src, long maxEdges, String returnTypes) { + long srcNodeId = graph.getNodeId(new SWHID(src)); open(); - Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges, returnTypes); + Traversal t = new Traversal(graph, direction, edgesFmt, maxEdges, returnTypes); for (Long nodeId : t.visitNodes(srcNodeId)) { - writeNode(nodeId); + writeNode(graph.getSWHID(nodeId)); } close(); } - public void visit_edges(String direction, String edgesFmt, long srcNodeId, long maxEdges) { + public void visit_edges(String direction, String edgesFmt, String src, long maxEdges, String returnTypes) { + long srcNodeId = graph.getNodeId(new SWHID(src)); open(); - Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges); - t.visitNodesVisitor(srcNodeId, null, this::writeEdge); + Traversal t = new Traversal(graph, direction, edgesFmt, maxEdges); + t.visitNodesVisitor(srcNodeId, null, (srcId, dstId) -> { + writeEdge(graph.getSWHID(srcId), graph.getSWHID(dstId)); + }); close(); } - public void visit_paths(String direction, String edgesFmt, long srcNodeId, long maxEdges) { + public void walk(String direction, String edgesFmt, String algorithm, String src, String dst) { + long srcNodeId = graph.getNodeId(new SWHID(src)); open(); - Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges); - t.visitPathsVisitor(srcNodeId, this::writePath); - close(); - } - - public void walk(String direction, String edgesFmt, String algorithm, long srcNodeId, long dstNodeId) { - open(); - Traversal t = new Traversal(this.graph, direction, edgesFmt); - for (Long nodeId : t.walk(srcNodeId, dstNodeId, algorithm)) { - writeNode(nodeId); + ArrayList res; + if (dst.matches("ori|snp|rel|rev|dir|cnt")) { + Node.Type dstType = Node.Type.fromStr(dst); + Traversal t = new Traversal(graph, direction, edgesFmt); + res = t.walk(srcNodeId, dstType, algorithm); + } else { + long dstNodeId = graph.getNodeId(new SWHID(dst)); + Traversal t = new Traversal(graph, direction, edgesFmt); + res = t.walk(srcNodeId, dstNodeId, algorithm); } - close(); - } - - public void walk_type(String direction, String edgesFmt, String algorithm, long srcNodeId, String dst) { - open(); - Node.Type dstType = Node.Type.fromStr(dst); - Traversal t = new Traversal(this.graph, direction, edgesFmt); - for (Long nodeId : t.walk(srcNodeId, dstType, algorithm)) { - writeNode(nodeId); + for (Long nodeId : res) { + writeNode(graph.getSWHID(nodeId)); } close(); } - public void random_walk(String direction, String edgesFmt, int retries, long srcNodeId, long dstNodeId, - String returnTypes) { + public void random_walk(String direction, String edgesFmt, int retries, String src, String dst) { + long srcNodeId = graph.getNodeId(new SWHID(src)); open(); - Traversal t = new Traversal(this.graph, direction, edgesFmt, 0, returnTypes); - for (Long nodeId : t.randomWalk(srcNodeId, dstNodeId, retries)) { - writeNode(nodeId); + ArrayList res; + if (dst.matches("ori|snp|rel|rev|dir|cnt")) { + Node.Type dstType = Node.Type.fromStr(dst); + Traversal t = new Traversal(graph, direction, edgesFmt); + res = t.randomWalk(srcNodeId, dstType, retries); + } else { + long dstNodeId = graph.getNodeId(new SWHID(dst)); + Traversal t = new Traversal(graph, direction, edgesFmt); + res = t.randomWalk(srcNodeId, dstNodeId, retries); } - close(); - } - - public void random_walk_type(String direction, String edgesFmt, int retries, long srcNodeId, String dst, - String returnTypes) { - open(); - Node.Type dstType = Node.Type.fromStr(dst); - Traversal t = new Traversal(this.graph, direction, edgesFmt, 0, returnTypes); - for (Long nodeId : t.randomWalk(srcNodeId, dstType, retries)) { - writeNode(nodeId); + for (Long nodeId : res) { + writeNode(graph.getSWHID(nodeId)); } close(); } } } diff --git a/swh/graph/backend.py b/swh/graph/backend.py index de54810..b123238 100644 --- a/swh/graph/backend.py +++ b/swh/graph/backend.py @@ -1,206 +1,176 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import asyncio import contextlib import io import os -import struct +import re import subprocess import sys import tempfile from py4j.java_gateway import JavaGateway +from py4j.protocol import Py4JJavaError from swh.graph.config import check_config -from swh.graph.swhid import NodeToSwhidMap, SwhidToNodeMap -from swh.model.swhids import EXTENDED_SWHID_TYPES -BUF_SIZE = 64 * 1024 -BIN_FMT = ">q" # 64 bit integer, big endian -PATH_SEPARATOR_ID = -1 -NODE2SWHID_EXT = "node2swhid.bin" -SWHID2NODE_EXT = "swhid2node.bin" +BUF_LINES = 1024 def _get_pipe_stderr(): # Get stderr if possible, or pipe to stdout if running with Jupyter. try: sys.stderr.fileno() except io.UnsupportedOperation: return subprocess.STDOUT else: return sys.stderr class Backend: def __init__(self, graph_path, config=None): self.gateway = None self.entry = None self.graph_path = graph_path self.config = check_config(config or {}) def start_gateway(self): self.gateway = JavaGateway.launch_gateway( java_path=None, javaopts=self.config["java_tool_options"].split(), classpath=self.config["classpath"], die_on_exit=True, redirect_stdout=sys.stdout, redirect_stderr=_get_pipe_stderr(), ) self.entry = self.gateway.jvm.org.softwareheritage.graph.Entry() self.entry.load_graph(self.graph_path) - self.node2swhid = NodeToSwhidMap(self.graph_path + "." + NODE2SWHID_EXT) - self.swhid2node = SwhidToNodeMap(self.graph_path + "." + SWHID2NODE_EXT) self.stream_proxy = JavaStreamProxy(self.entry) def stop_gateway(self): self.gateway.shutdown() def __enter__(self): self.start_gateway() return self def __exit__(self, exc_type, exc_value, tb): self.stop_gateway() def stats(self): return self.entry.stats() + def check_swhid(self, swhid): + try: + self.entry.check_swhid(swhid) + except Py4JJavaError as e: + m = re.search(r"malformed SWHID: (\w+)", str(e)) + if m: + raise ValueError(f"malformed SWHID: {m[1]}") + m = re.search(r"Unknown SWHID: (\w+)", str(e)) + if m: + raise NameError(f"Unknown SWHID: {m[1]}") + raise + def count(self, ttype, direction, edges_fmt, src): method = getattr(self.entry, "count_" + ttype) return method(direction, edges_fmt, src) - async def simple_traversal( - self, ttype, direction, edges_fmt, src, max_edges, return_types - ): - assert ttype in ("leaves", "neighbors", "visit_nodes") + async def traversal(self, ttype, *args): method = getattr(self.stream_proxy, ttype) - async for node_id in method(direction, edges_fmt, src, max_edges, return_types): - yield node_id - - async def walk(self, direction, edges_fmt, algo, src, dst): - if dst in EXTENDED_SWHID_TYPES: - it = self.stream_proxy.walk_type(direction, edges_fmt, algo, src, dst) - else: - it = self.stream_proxy.walk(direction, edges_fmt, algo, src, dst) - async for node_id in it: - yield node_id - - async def random_walk(self, direction, edges_fmt, retries, src, dst, return_types): - if dst in EXTENDED_SWHID_TYPES: - it = self.stream_proxy.random_walk_type( - direction, edges_fmt, retries, src, dst, return_types - ) - else: - it = self.stream_proxy.random_walk( - direction, edges_fmt, retries, src, dst, return_types - ) - async for node_id in it: # TODO return 404 if path is empty - yield node_id - - async def visit_edges(self, direction, edges_fmt, src, max_edges): - it = self.stream_proxy.visit_edges(direction, edges_fmt, src, max_edges) - # convert stream a, b, c, d -> (a, b), (c, d) - prevNode = None - async for node in it: - if prevNode is not None: - yield (prevNode, node) - prevNode = None - else: - prevNode = node - - async def visit_paths(self, direction, edges_fmt, src, max_edges): - path = [] - async for node in self.stream_proxy.visit_paths( - direction, edges_fmt, src, max_edges - ): - if node == PATH_SEPARATOR_ID: - yield path - path = [] - else: - path.append(node) + async for line in method(*args): + yield line.decode().rstrip("\n") class JavaStreamProxy: """A proxy class for the org.softwareheritage.graph.Entry Java class that takes care of the setup and teardown of the named-pipe FIFO communication between Python and Java. Initialize JavaStreamProxy using: proxy = JavaStreamProxy(swh_entry_class_instance) Then you can call an Entry method and iterate on the FIFO results like this: async for value in proxy.java_method(arg1, arg2): print(value) """ def __init__(self, entry): self.entry = entry async def read_node_ids(self, fname): loop = asyncio.get_event_loop() open_thread = loop.run_in_executor(None, open, fname, "rb") # Since the open() call on the FIFO is blocking until it is also opened # on the Java side, we await it with a timeout in case there is an # exception that prevents the write-side open(). with (await asyncio.wait_for(open_thread, timeout=2)) as f: + + def read_n_lines(f, n): + buf = [] + for _ in range(n): + try: + buf.append(next(f)) + except StopIteration: + break + return buf + while True: - data = await loop.run_in_executor(None, f.read, BUF_SIZE) - if not data: + lines = await loop.run_in_executor(None, read_n_lines, f, BUF_LINES) + if not lines: break - for data in struct.iter_unpack(BIN_FMT, data): - yield data[0] + for line in lines: + yield line class _HandlerWrapper: def __init__(self, handler): self._handler = handler def __getattr__(self, name): func = getattr(self._handler, name) async def java_call(*args, **kwargs): loop = asyncio.get_event_loop() await loop.run_in_executor(None, lambda: func(*args, **kwargs)) def java_task(*args, **kwargs): return asyncio.create_task(java_call(*args, **kwargs)) return java_task @contextlib.contextmanager def get_handler(self): with tempfile.TemporaryDirectory(prefix="swh-graph-") as tmpdirname: cli_fifo = os.path.join(tmpdirname, "swh-graph.fifo") os.mkfifo(cli_fifo) reader = self.read_node_ids(cli_fifo) query_handler = self.entry.get_handler(cli_fifo) handler = self._HandlerWrapper(query_handler) yield (handler, reader) def __getattr__(self, name): async def java_call_iterator(*args, **kwargs): with self.get_handler() as (handler, reader): java_task = getattr(handler, name)(*args, **kwargs) try: async for value in reader: yield value except asyncio.TimeoutError: # If the read-side open() timeouts, an exception on the # Java side probably happened that prevented the # write-side open(). We propagate this exception here if # that is the case. task_exc = java_task.exception() if task_exc: raise task_exc raise await java_task return java_call_iterator diff --git a/swh/graph/graph.py b/swh/graph/graph.py deleted file mode 100644 index 3fd853b..0000000 --- a/swh/graph/graph.py +++ /dev/null @@ -1,193 +0,0 @@ -# Copyright (C) 2019 The Software Heritage developers -# See the AUTHORS file at the top-level directory of this distribution -# License: GNU General Public License version 3, or any later version -# See top-level LICENSE file for more information - -import asyncio -import contextlib -import functools - -from swh.graph.backend import Backend -from swh.graph.dot import KIND_TO_SHAPE, dot_to_svg, graph_dot - -BASE_URL = "https://archive.softwareheritage.org/browse" -KIND_TO_URL_FRAGMENT = { - "ori": "/origin/{}", - "snp": "/snapshot/{}", - "rel": "/release/{}", - "rev": "/revision/{}", - "dir": "/directory/{}", - "cnt": "/content/sha1_git:{}/", -} - - -def call_async_gen(generator, *args, **kwargs): - loop = asyncio.get_event_loop() - it = generator(*args, **kwargs).__aiter__() - while True: - try: - res = loop.run_until_complete(it.__anext__()) - yield res - except StopAsyncIteration: - break - - -class Neighbors: - """Neighbor iterator with custom O(1) length method""" - - def __init__(self, graph, iterator, length_func): - self.graph = graph - self.iterator = iterator - self.length_func = length_func - - def __iter__(self): - return self - - def __next__(self): - succ = self.iterator.nextLong() - if succ == -1: - raise StopIteration - return GraphNode(self.graph, succ) - - def __len__(self): - return self.length_func() - - -class GraphNode: - """Node in the SWH graph""" - - def __init__(self, graph, node_id): - self.graph = graph - self.id = node_id - - def children(self): - return Neighbors( - self.graph, - self.graph.java_graph.successors(self.id), - lambda: self.graph.java_graph.outdegree(self.id), - ) - - def parents(self): - return Neighbors( - self.graph, - self.graph.java_graph.predecessors(self.id), - lambda: self.graph.java_graph.indegree(self.id), - ) - - def simple_traversal( - self, ttype, direction="forward", edges="*", max_edges=0, return_types="*" - ): - for node in call_async_gen( - self.graph.backend.simple_traversal, - ttype, - direction, - edges, - self.id, - max_edges, - return_types, - ): - yield self.graph[node] - - def leaves(self, *args, **kwargs): - yield from self.simple_traversal("leaves", *args, **kwargs) - - def visit_nodes(self, *args, **kwargs): - yield from self.simple_traversal("visit_nodes", *args, **kwargs) - - def visit_edges(self, direction="forward", edges="*", max_edges=0): - for src, dst in call_async_gen( - self.graph.backend.visit_edges, direction, edges, self.id, max_edges - ): - yield (self.graph[src], self.graph[dst]) - - def visit_paths(self, direction="forward", edges="*", max_edges=0): - for path in call_async_gen( - self.graph.backend.visit_paths, direction, edges, self.id, max_edges - ): - yield [self.graph[node] for node in path] - - def walk(self, dst, direction="forward", edges="*", traversal="dfs"): - for node in call_async_gen( - self.graph.backend.walk, direction, edges, traversal, self.id, dst - ): - yield self.graph[node] - - def _count(self, ttype, direction="forward", edges="*"): - return self.graph.backend.count(ttype, direction, edges, self.id) - - count_leaves = functools.partialmethod(_count, ttype="leaves") - count_neighbors = functools.partialmethod(_count, ttype="neighbors") - count_visit_nodes = functools.partialmethod(_count, ttype="visit_nodes") - - @property - def swhid(self): - return self.graph.node2swhid[self.id] - - @property - def kind(self): - return self.swhid.split(":")[2] - - def __str__(self): - return self.swhid - - def __repr__(self): - return "<{}>".format(self.swhid) - - def dot_fragment(self): - swh, version, kind, hash = self.swhid.split(":") - label = "{}:{}..{}".format(kind, hash[0:2], hash[-2:]) - url = BASE_URL + KIND_TO_URL_FRAGMENT[kind].format(hash) - shape = KIND_TO_SHAPE[kind] - return '{} [label="{}", href="{}", target="_blank", shape="{}"];'.format( - self.id, label, url, shape - ) - - def _repr_svg_(self): - nodes = [self, *list(self.children()), *list(self.parents())] - dot = graph_dot(nodes) - svg = dot_to_svg(dot) - return svg - - -class Graph: - def __init__(self, backend, node2swhid, swhid2node): - self.backend = backend - self.java_graph = backend.entry.get_graph() - self.node2swhid = node2swhid - self.swhid2node = swhid2node - - def stats(self): - return self.backend.stats() - - @property - def path(self): - return self.java_graph.getPath() - - def __len__(self): - return self.java_graph.numNodes() - - def __getitem__(self, node_id): - if isinstance(node_id, int): - self.node2swhid[node_id] # check existence - return GraphNode(self, node_id) - elif isinstance(node_id, str): - node_id = self.swhid2node[node_id] - return GraphNode(self, node_id) - - def __iter__(self): - for swhid, pos in self.backend.swhid2node: - yield self[swhid] - - def iter_prefix(self, prefix): - for swhid, pos in self.backend.swhid2node.iter_prefix(prefix): - yield self[swhid] - - def iter_type(self, swhid_type): - for swhid, pos in self.backend.swhid2node.iter_type(swhid_type): - yield self[swhid] - - -@contextlib.contextmanager -def load(graph_path): - with Backend(graph_path) as backend: - yield Graph(backend, backend.node2swhid, backend.swhid2node) diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py index bd21952..be07fc0 100644 --- a/swh/graph/server/app.py +++ b/swh/graph/server/app.py @@ -1,402 +1,360 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information """ A proxy HTTP server for swh-graph, talking to the Java code via py4j, and using FIFO as a transport to stream integers between the two languages. """ import asyncio from collections import deque -import json import os from typing import Optional import aiohttp.web from swh.core.api.asynchronous import RPCServerApp from swh.core.config import read as config_read from swh.graph.backend import Backend -from swh.model.exceptions import ValidationError from swh.model.swhids import EXTENDED_SWHID_TYPES try: from contextlib import asynccontextmanager except ImportError: # Compatibility with 3.6 backport from async_generator import asynccontextmanager # type: ignore # maximum number of retries for random walks RANDOM_RETRIES = 5 # TODO make this configurable via rpc-serve configuration class GraphServerApp(RPCServerApp): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.on_startup.append(self._start_gateway) self.on_shutdown.append(self._stop_gateway) @staticmethod async def _start_gateway(app): # Equivalent to entering `with app["backend"]:` app["backend"].start_gateway() @staticmethod async def _stop_gateway(app): # Equivalent to exiting `with app["backend"]:` with no error app["backend"].stop_gateway() async def index(request): return aiohttp.web.Response( content_type="text/html", body=""" Software Heritage graph server

You have reached the Software Heritage graph API server.

See its API documentation for more information.

""", ) class GraphView(aiohttp.web.View): """Base class for views working on the graph, with utility functions""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.backend = self.request.app["backend"] - def node_of_swhid(self, swhid): - """Lookup a SWHID in a swhid2node map, failing in an HTTP-nice way if - needed.""" - try: - return self.backend.swhid2node[swhid] - except KeyError: - raise aiohttp.web.HTTPNotFound(text=f"SWHID not found: {swhid}") - except ValidationError: - raise aiohttp.web.HTTPBadRequest(text=f"malformed SWHID: {swhid}") - - def swhid_of_node(self, node): - """Lookup a node in a node2swhid map, failing in an HTTP-nice way if - needed.""" - try: - return self.backend.node2swhid[node] - except KeyError: - raise aiohttp.web.HTTPInternalServerError( - text=f"reverse lookup failed for node id: {node}" - ) - def get_direction(self): """Validate HTTP query parameter `direction`""" s = self.request.query.get("direction", "forward") if s not in ("forward", "backward"): raise aiohttp.web.HTTPBadRequest(text=f"invalid direction: {s}") return s def get_edges(self): """Validate HTTP query parameter `edges`, i.e., edge restrictions""" s = self.request.query.get("edges", "*") if any( [ node_type != "*" and node_type not in EXTENDED_SWHID_TYPES for edge in s.split(":") for node_type in edge.split(",", maxsplit=1) ] ): raise aiohttp.web.HTTPBadRequest(text=f"invalid edge restriction: {s}") return s def get_return_types(self): """Validate HTTP query parameter 'return types', i.e, a set of types which we will filter the query results with""" s = self.request.query.get("return_types", "*") if any( node_type != "*" and node_type not in EXTENDED_SWHID_TYPES for node_type in s.split(",") ): raise aiohttp.web.HTTPBadRequest( text=f"invalid type for filtering res: {s}" ) # if the user puts a star, # then we filter nothing, we don't need the other information if "*" in s: return "*" else: return s def get_traversal(self): """Validate HTTP query parameter `traversal`, i.e., visit order""" s = self.request.query.get("traversal", "dfs") if s not in ("bfs", "dfs"): raise aiohttp.web.HTTPBadRequest(text=f"invalid traversal order: {s}") return s def get_limit(self): """Validate HTTP query parameter `limit`, i.e., number of results""" s = self.request.query.get("limit", "0") try: return int(s) except ValueError: raise aiohttp.web.HTTPBadRequest(text=f"invalid limit value: {s}") def get_max_edges(self): """Validate HTTP query parameter 'max_edges', i.e., the limit of the number of edges that can be visited""" s = self.request.query.get("max_edges", "0") try: return int(s) except ValueError: raise aiohttp.web.HTTPBadRequest(text=f"invalid max_edges value: {s}") + def check_swhid(self, swhid): + """Validate that the given SWHID exists in the graph""" + try: + self.backend.check_swhid(swhid) + except (NameError, ValueError) as e: + raise aiohttp.web.HTTPBadRequest(text=str(e)) + class StreamingGraphView(GraphView): """Base class for views streaming their response line by line.""" content_type = "text/plain" @asynccontextmanager async def response_streamer(self, *args, **kwargs): """Context manager to prepare then close a StreamResponse""" response = aiohttp.web.StreamResponse(*args, **kwargs) response.content_type = self.content_type await response.prepare(self.request) yield response await response.write_eof() async def get(self): await self.prepare_response() async with self.response_streamer() as self.response_stream: self._buf = [] try: await self.stream_response() finally: await self._flush_buffer() return self.response_stream async def prepare_response(self): """This can be overridden with some setup to be run before the response actually starts streaming. """ pass async def stream_response(self): """Override this to perform the response streaming. Implementations of this should await self.stream_line(line) to write each line. """ raise NotImplementedError async def stream_line(self, line): """Write a line in the response stream.""" self._buf.append(line) if len(self._buf) > 100: await self._flush_buffer() async def _flush_buffer(self): await self.response_stream.write("\n".join(self._buf).encode() + b"\n") self._buf = [] class StatsView(GraphView): """View showing some statistics on the graph""" async def get(self): stats = self.backend.stats() return aiohttp.web.Response(body=stats, content_type="application/json") class SimpleTraversalView(StreamingGraphView): """Base class for views of simple traversals""" simple_traversal_type: Optional[str] = None async def prepare_response(self): - src = self.request.match_info["src"] - self.src_node = self.node_of_swhid(src) - + self.src = self.request.match_info["src"] self.edges = self.get_edges() self.direction = self.get_direction() self.max_edges = self.get_max_edges() self.return_types = self.get_return_types() + self.check_swhid(self.src) async def stream_response(self): - async for res_node in self.backend.simple_traversal( + async for res_line in self.backend.traversal( self.simple_traversal_type, self.direction, self.edges, - self.src_node, + self.src, self.max_edges, self.return_types, ): - res_swhid = self.swhid_of_node(res_node) - await self.stream_line(res_swhid) + await self.stream_line(res_line) class LeavesView(SimpleTraversalView): simple_traversal_type = "leaves" class NeighborsView(SimpleTraversalView): simple_traversal_type = "neighbors" class VisitNodesView(SimpleTraversalView): simple_traversal_type = "visit_nodes" +class VisitEdgesView(SimpleTraversalView): + simple_traversal_type = "visit_edges" + + class WalkView(StreamingGraphView): async def prepare_response(self): - src = self.request.match_info["src"] - dst = self.request.match_info["dst"] - self.src_node = self.node_of_swhid(src) - if dst not in EXTENDED_SWHID_TYPES: - self.dst_thing = self.node_of_swhid(dst) - else: - self.dst_thing = dst + self.src = self.request.match_info["src"] + self.dst = self.request.match_info["dst"] self.edges = self.get_edges() self.direction = self.get_direction() self.algo = self.get_traversal() self.limit = self.get_limit() - self.return_types = self.get_return_types() + + self.check_swhid(self.src) + if self.dst not in EXTENDED_SWHID_TYPES: + self.check_swhid(self.dst) async def get_walk_iterator(self): - return self.backend.walk( - self.direction, self.edges, self.algo, self.src_node, self.dst_thing + return self.backend.traversal( + "walk", self.direction, self.edges, self.algo, self.src, self.dst ) async def stream_response(self): it = self.get_walk_iterator() if self.limit < 0: queue = deque(maxlen=-self.limit) - async for res_node in it: - res_swhid = self.swhid_of_node(res_node) + async for res_swhid in it: queue.append(res_swhid) while queue: await self.stream_line(queue.popleft()) else: count = 0 - async for res_node in it: + async for res_swhid in it: if self.limit == 0 or count < self.limit: - res_swhid = self.swhid_of_node(res_node) await self.stream_line(res_swhid) count += 1 else: break class RandomWalkView(WalkView): def get_walk_iterator(self): - return self.backend.random_walk( + return self.backend.traversal( + "random_walk", self.direction, self.edges, RANDOM_RETRIES, - self.src_node, - self.dst_thing, - self.return_types, - ) - - -class VisitEdgesView(SimpleTraversalView): - async def stream_response(self): - it = self.backend.visit_edges( - self.direction, self.edges, self.src_node, self.max_edges - ) - async for (res_src, res_dst) in it: - res_src_swhid = self.swhid_of_node(res_src) - res_dst_swhid = self.swhid_of_node(res_dst) - await self.stream_line("{} {}".format(res_src_swhid, res_dst_swhid)) - - -class VisitPathsView(SimpleTraversalView): - content_type = "application/x-ndjson" - - async def stream_response(self): - it = self.backend.visit_paths( - self.direction, self.edges, self.src_node, self.max_edges + self.src, + self.dst, ) - async for res_path in it: - res_path_swhid = [self.swhid_of_node(n) for n in res_path] - line = json.dumps(res_path_swhid) - await self.stream_line(line) class CountView(GraphView): """Base class for counting views.""" count_type: Optional[str] = None async def get(self): - src = self.request.match_info["src"] - self.src_node = self.node_of_swhid(src) + self.src = self.request.match_info["src"] + self.check_swhid(self.src) self.edges = self.get_edges() self.direction = self.get_direction() loop = asyncio.get_event_loop() cnt = await loop.run_in_executor( None, self.backend.count, self.count_type, self.direction, self.edges, - self.src_node, + self.src, ) return aiohttp.web.Response(body=str(cnt), content_type="application/json") class CountNeighborsView(CountView): count_type = "neighbors" class CountLeavesView(CountView): count_type = "leaves" class CountVisitNodesView(CountView): count_type = "visit_nodes" def make_app(config=None, backend=None, **kwargs): if (config is None) == (backend is None): raise ValueError("make_app() expects exactly one of 'config' or 'backend'") if backend is None: backend = Backend(graph_path=config["graph"]["path"], config=config["graph"]) app = GraphServerApp(**kwargs) app.add_routes( [ aiohttp.web.get("/", index), aiohttp.web.get("/graph", index), aiohttp.web.view("/graph/stats", StatsView), aiohttp.web.view("/graph/leaves/{src}", LeavesView), aiohttp.web.view("/graph/neighbors/{src}", NeighborsView), aiohttp.web.view("/graph/visit/nodes/{src}", VisitNodesView), aiohttp.web.view("/graph/visit/edges/{src}", VisitEdgesView), - aiohttp.web.view("/graph/visit/paths/{src}", VisitPathsView), # temporarily disabled in wait of a proper fix for T1969 # aiohttp.web.view("/graph/walk/{src}/{dst}", WalkView) aiohttp.web.view("/graph/randomwalk/{src}/{dst}", RandomWalkView), aiohttp.web.view("/graph/neighbors/count/{src}", CountNeighborsView), aiohttp.web.view("/graph/leaves/count/{src}", CountLeavesView), aiohttp.web.view("/graph/visit/nodes/count/{src}", CountVisitNodesView), ] ) app["backend"] = backend return app def make_app_from_configfile(): """Load configuration and then build application to run """ config_file = os.environ.get("SWH_CONFIG_FILENAME") config = config_read(config_file) return make_app(config=config) diff --git a/swh/graph/tests/conftest.py b/swh/graph/tests/conftest.py index 5a7bb92..fed877b 100644 --- a/swh/graph/tests/conftest.py +++ b/swh/graph/tests/conftest.py @@ -1,68 +1,59 @@ # Copyright (C) 2019-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import csv import multiprocessing from pathlib import Path from aiohttp.test_utils import TestClient, TestServer, loop_context import pytest from swh.graph.client import RemoteGraphClient from swh.graph.naive_client import NaiveClient SWH_GRAPH_TESTS_ROOT = Path(__file__).parents[0] TEST_GRAPH_PATH = SWH_GRAPH_TESTS_ROOT / "dataset/output/example" class GraphServerProcess(multiprocessing.Process): def __init__(self, q, *args, **kwargs): self.q = q super().__init__(*args, **kwargs) def run(self): # Lazy import to allow debian packaging from swh.graph.backend import Backend from swh.graph.server.app import make_app try: backend = Backend(graph_path=str(TEST_GRAPH_PATH)) with loop_context() as loop: app = make_app(backend=backend, debug=True) client = TestClient(TestServer(app), loop=loop) loop.run_until_complete(client.start_server()) url = client.make_url("/graph/") self.q.put(url) loop.run_forever() except Exception as e: self.q.put(e) @pytest.fixture(scope="module", params=["remote", "naive"]) def graph_client(request): if request.param == "remote": queue = multiprocessing.Queue() server = GraphServerProcess(queue) server.start() res = queue.get() if isinstance(res, Exception): raise res yield RemoteGraphClient(str(res)) server.terminate() else: with open(SWH_GRAPH_TESTS_ROOT / "dataset/example.nodes.csv") as fd: nodes = [node for (node,) in csv.reader(fd, delimiter=" ")] with open(SWH_GRAPH_TESTS_ROOT / "dataset/example.edges.csv") as fd: edges = list(csv.reader(fd, delimiter=" ")) yield NaiveClient(nodes=nodes, edges=edges) - - -@pytest.fixture(scope="module") -def graph(): - # Lazy import to allow debian packaging - from swh.graph.graph import load as graph_load - - with graph_load(str(TEST_GRAPH_PATH)) as g: - yield g diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py index 90f9a0a..46e0227 100644 --- a/swh/graph/tests/test_api_client.py +++ b/swh/graph/tests/test_api_client.py @@ -1,375 +1,379 @@ import pytest from pytest import raises from swh.core.api import RemoteException from swh.graph.client import GraphArgumentException def test_stats(graph_client): stats = graph_client.stats() assert set(stats.keys()) == {"counts", "ratios", "indegree", "outdegree"} assert set(stats["counts"].keys()) == {"nodes", "edges"} assert set(stats["ratios"].keys()) == { "compression", "bits_per_node", "bits_per_edge", "avg_locality", } assert set(stats["indegree"].keys()) == {"min", "max", "avg"} assert set(stats["outdegree"].keys()) == {"min", "max", "avg"} assert stats["counts"]["nodes"] == 21 assert stats["counts"]["edges"] == 23 assert isinstance(stats["ratios"]["compression"], float) assert isinstance(stats["ratios"]["bits_per_node"], float) assert isinstance(stats["ratios"]["bits_per_edge"], float) assert isinstance(stats["ratios"]["avg_locality"], float) assert stats["indegree"]["min"] == 0 assert stats["indegree"]["max"] == 3 assert isinstance(stats["indegree"]["avg"], float) assert stats["outdegree"]["min"] == 0 assert stats["outdegree"]["max"] == 3 assert isinstance(stats["outdegree"]["avg"], float) def test_leaves(graph_client): actual = list( graph_client.leaves("swh:1:ori:0000000000000000000000000000000000000021") ) expected = [ "swh:1:cnt:0000000000000000000000000000000000000001", "swh:1:cnt:0000000000000000000000000000000000000004", "swh:1:cnt:0000000000000000000000000000000000000005", "swh:1:cnt:0000000000000000000000000000000000000007", ] assert set(actual) == set(expected) def test_neighbors(graph_client): actual = list( graph_client.neighbors( "swh:1:rev:0000000000000000000000000000000000000009", direction="backward" ) ) expected = [ "swh:1:snp:0000000000000000000000000000000000000020", "swh:1:rel:0000000000000000000000000000000000000010", "swh:1:rev:0000000000000000000000000000000000000013", ] assert set(actual) == set(expected) def test_visit_nodes(graph_client): actual = list( graph_client.visit_nodes( "swh:1:rel:0000000000000000000000000000000000000010", edges="rel:rev,rev:rev", ) ) expected = [ "swh:1:rel:0000000000000000000000000000000000000010", "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:rev:0000000000000000000000000000000000000003", ] assert set(actual) == set(expected) def test_visit_nodes_filtered(graph_client): actual = list( graph_client.visit_nodes( "swh:1:rel:0000000000000000000000000000000000000010", return_types="dir", ) ) expected = [ "swh:1:dir:0000000000000000000000000000000000000002", "swh:1:dir:0000000000000000000000000000000000000008", "swh:1:dir:0000000000000000000000000000000000000006", ] assert set(actual) == set(expected) def test_visit_nodes_filtered_star(graph_client): actual = list( graph_client.visit_nodes( "swh:1:rel:0000000000000000000000000000000000000010", return_types="*", ) ) expected = [ "swh:1:rel:0000000000000000000000000000000000000010", "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:rev:0000000000000000000000000000000000000003", "swh:1:dir:0000000000000000000000000000000000000002", "swh:1:cnt:0000000000000000000000000000000000000001", "swh:1:dir:0000000000000000000000000000000000000008", "swh:1:cnt:0000000000000000000000000000000000000007", "swh:1:dir:0000000000000000000000000000000000000006", "swh:1:cnt:0000000000000000000000000000000000000004", "swh:1:cnt:0000000000000000000000000000000000000005", ] assert set(actual) == set(expected) def test_visit_edges(graph_client): actual = list( graph_client.visit_edges( "swh:1:rel:0000000000000000000000000000000000000010", edges="rel:rev,rev:rev,rev:dir", ) ) expected = [ ( "swh:1:rel:0000000000000000000000000000000000000010", "swh:1:rev:0000000000000000000000000000000000000009", ), ( "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:rev:0000000000000000000000000000000000000003", ), ( "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:dir:0000000000000000000000000000000000000008", ), ( "swh:1:rev:0000000000000000000000000000000000000003", "swh:1:dir:0000000000000000000000000000000000000002", ), ] assert set(actual) == set(expected) def test_visit_edges_limited(graph_client): actual = list( graph_client.visit_edges( "swh:1:rel:0000000000000000000000000000000000000010", max_edges=4, edges="rel:rev,rev:rev,rev:dir", ) ) expected = [ ( "swh:1:rel:0000000000000000000000000000000000000010", "swh:1:rev:0000000000000000000000000000000000000009", ), ( "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:rev:0000000000000000000000000000000000000003", ), ( "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:dir:0000000000000000000000000000000000000008", ), ( "swh:1:rev:0000000000000000000000000000000000000003", "swh:1:dir:0000000000000000000000000000000000000002", ), ] # As there are four valid answers (up to reordering), we cannot check for # equality. Instead, we check the client returned all edges but one. assert set(actual).issubset(set(expected)) assert len(actual) == 3 def test_visit_edges_diamond_pattern(graph_client): actual = list( graph_client.visit_edges( "swh:1:rev:0000000000000000000000000000000000000009", edges="*", ) ) expected = [ ( "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:rev:0000000000000000000000000000000000000003", ), ( "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:dir:0000000000000000000000000000000000000008", ), ( "swh:1:rev:0000000000000000000000000000000000000003", "swh:1:dir:0000000000000000000000000000000000000002", ), ( "swh:1:dir:0000000000000000000000000000000000000002", "swh:1:cnt:0000000000000000000000000000000000000001", ), ( "swh:1:dir:0000000000000000000000000000000000000008", "swh:1:cnt:0000000000000000000000000000000000000001", ), ( "swh:1:dir:0000000000000000000000000000000000000008", "swh:1:cnt:0000000000000000000000000000000000000007", ), ( "swh:1:dir:0000000000000000000000000000000000000008", "swh:1:dir:0000000000000000000000000000000000000006", ), ( "swh:1:dir:0000000000000000000000000000000000000006", "swh:1:cnt:0000000000000000000000000000000000000004", ), ( "swh:1:dir:0000000000000000000000000000000000000006", "swh:1:cnt:0000000000000000000000000000000000000005", ), ] assert set(actual) == set(expected) -def test_visit_paths(graph_client): - actual = list( - graph_client.visit_paths( - "swh:1:snp:0000000000000000000000000000000000000020", edges="snp:*,rev:*" - ) - ) - actual = [tuple(path) for path in actual] - expected = [ - ( - "swh:1:snp:0000000000000000000000000000000000000020", - "swh:1:rev:0000000000000000000000000000000000000009", - "swh:1:rev:0000000000000000000000000000000000000003", - "swh:1:dir:0000000000000000000000000000000000000002", - ), - ( - "swh:1:snp:0000000000000000000000000000000000000020", - "swh:1:rev:0000000000000000000000000000000000000009", - "swh:1:dir:0000000000000000000000000000000000000008", - ), - ( - "swh:1:snp:0000000000000000000000000000000000000020", - "swh:1:rel:0000000000000000000000000000000000000010", - ), - ] - assert set(actual) == set(expected) - - @pytest.mark.skip(reason="currently disabled due to T1969") def test_walk(graph_client): args = ("swh:1:dir:0000000000000000000000000000000000000016", "rel") kwargs = { "edges": "dir:dir,dir:rev,rev:*", "direction": "backward", "traversal": "bfs", } actual = list(graph_client.walk(*args, **kwargs)) expected = [ "swh:1:dir:0000000000000000000000000000000000000016", "swh:1:dir:0000000000000000000000000000000000000017", "swh:1:rev:0000000000000000000000000000000000000018", "swh:1:rel:0000000000000000000000000000000000000019", ] assert set(actual) == set(expected) kwargs2 = kwargs.copy() kwargs2["limit"] = -1 actual = list(graph_client.walk(*args, **kwargs2)) expected = ["swh:1:rel:0000000000000000000000000000000000000019"] assert set(actual) == set(expected) kwargs2 = kwargs.copy() kwargs2["limit"] = 2 actual = list(graph_client.walk(*args, **kwargs2)) expected = [ "swh:1:dir:0000000000000000000000000000000000000016", "swh:1:dir:0000000000000000000000000000000000000017", ] assert set(actual) == set(expected) -def test_random_walk(graph_client): +def test_random_walk_dst_is_type(graph_client): """as the walk is random, we test a visit from a cnt node to the only origin in the dataset, and only check the final node of the path (i.e., the origin) """ args = ("swh:1:cnt:0000000000000000000000000000000000000001", "ori") kwargs = {"direction": "backward"} expected_root = "swh:1:ori:0000000000000000000000000000000000000021" actual = list(graph_client.random_walk(*args, **kwargs)) assert len(actual) > 1 # no origin directly links to a content assert actual[0] == args[0] assert actual[-1] == expected_root kwargs2 = kwargs.copy() kwargs2["limit"] = -1 actual = list(graph_client.random_walk(*args, **kwargs2)) assert actual == [expected_root] kwargs2["limit"] = -2 actual = list(graph_client.random_walk(*args, **kwargs2)) assert len(actual) == 2 assert actual[-1] == expected_root kwargs2["limit"] = 3 actual = list(graph_client.random_walk(*args, **kwargs2)) assert len(actual) == 3 +def test_random_walk_dst_is_node(graph_client): + """Same as test_random_walk_dst_is_type, but we target the specific origin + node instead of a type + """ + args = ( + "swh:1:cnt:0000000000000000000000000000000000000001", + "swh:1:ori:0000000000000000000000000000000000000021", + ) + kwargs = {"direction": "backward"} + expected_root = "swh:1:ori:0000000000000000000000000000000000000021" + + actual = list(graph_client.random_walk(*args, **kwargs)) + assert len(actual) > 1 # no origin directly links to a content + assert actual[0] == args[0] + assert actual[-1] == expected_root + + kwargs2 = kwargs.copy() + kwargs2["limit"] = -1 + actual = list(graph_client.random_walk(*args, **kwargs2)) + assert actual == [expected_root] + + kwargs2["limit"] = -2 + actual = list(graph_client.random_walk(*args, **kwargs2)) + assert len(actual) == 2 + assert actual[-1] == expected_root + + kwargs2["limit"] = 3 + actual = list(graph_client.random_walk(*args, **kwargs2)) + assert len(actual) == 3 + + def test_count(graph_client): actual = graph_client.count_leaves( "swh:1:ori:0000000000000000000000000000000000000021" ) assert actual == 4 actual = graph_client.count_visit_nodes( "swh:1:rel:0000000000000000000000000000000000000010", edges="rel:rev,rev:rev" ) assert actual == 3 actual = graph_client.count_neighbors( "swh:1:rev:0000000000000000000000000000000000000009", direction="backward" ) assert actual == 3 def test_param_validation(graph_client): with raises(GraphArgumentException) as exc_info: # SWHID not found list(graph_client.leaves("swh:1:ori:fff0000000000000000000000000000000000021")) if exc_info.value.response: assert exc_info.value.response.status_code == 404 with raises(GraphArgumentException) as exc_info: # malformed SWHID list( graph_client.neighbors("swh:1:ori:fff000000zzzzzz0000000000000000000000021") ) if exc_info.value.response: assert exc_info.value.response.status_code == 400 with raises(GraphArgumentException) as exc_info: # malformed edge specificaiton list( graph_client.visit_nodes( "swh:1:dir:0000000000000000000000000000000000000016", edges="dir:notanodetype,dir:rev,rev:*", direction="backward", ) ) if exc_info.value.response: assert exc_info.value.response.status_code == 400 with raises(GraphArgumentException) as exc_info: # malformed direction list( graph_client.visit_nodes( "swh:1:dir:0000000000000000000000000000000000000016", edges="dir:dir,dir:rev,rev:*", direction="notadirection", ) ) if exc_info.value.response: assert exc_info.value.response.status_code == 400 @pytest.mark.skip(reason="currently disabled due to T1969") def test_param_validation_walk(graph_client): """test validation of walk-specific parameters only""" with raises(RemoteException) as exc_info: # malformed traversal order list( graph_client.walk( "swh:1:dir:0000000000000000000000000000000000000016", "rel", edges="dir:dir,dir:rev,rev:*", direction="backward", traversal="notatraversalorder", ) ) assert exc_info.value.response.status_code == 400 diff --git a/swh/graph/tests/test_graph.py b/swh/graph/tests/test_graph.py deleted file mode 100644 index c752580..0000000 --- a/swh/graph/tests/test_graph.py +++ /dev/null @@ -1,166 +0,0 @@ -import pytest - - -def test_graph(graph): - assert len(graph) == 21 - - obj = "swh:1:dir:0000000000000000000000000000000000000008" - node = graph[obj] - - assert str(node) == obj - assert len(node.children()) == 3 - assert len(node.parents()) == 2 - - actual = {p.swhid for p in node.children()} - expected = { - "swh:1:cnt:0000000000000000000000000000000000000001", - "swh:1:dir:0000000000000000000000000000000000000006", - "swh:1:cnt:0000000000000000000000000000000000000007", - } - assert expected == actual - - actual = {p.swhid for p in node.parents()} - expected = { - "swh:1:rev:0000000000000000000000000000000000000009", - "swh:1:dir:0000000000000000000000000000000000000012", - } - assert expected == actual - - -def test_invalid_swhid(graph): - with pytest.raises(IndexError): - graph[1337] - - with pytest.raises(IndexError): - graph[len(graph) + 1] - - with pytest.raises(KeyError): - graph["swh:1:dir:0000000000000000000000000000000420000012"] - - -def test_leaves(graph): - actual = list(graph["swh:1:ori:0000000000000000000000000000000000000021"].leaves()) - actual = [p.swhid for p in actual] - expected = [ - "swh:1:cnt:0000000000000000000000000000000000000001", - "swh:1:cnt:0000000000000000000000000000000000000004", - "swh:1:cnt:0000000000000000000000000000000000000005", - "swh:1:cnt:0000000000000000000000000000000000000007", - ] - assert set(actual) == set(expected) - - -def test_visit_nodes(graph): - actual = list( - graph["swh:1:rel:0000000000000000000000000000000000000010"].visit_nodes( - edges="rel:rev,rev:rev" - ) - ) - actual = [p.swhid for p in actual] - expected = [ - "swh:1:rel:0000000000000000000000000000000000000010", - "swh:1:rev:0000000000000000000000000000000000000009", - "swh:1:rev:0000000000000000000000000000000000000003", - ] - assert set(actual) == set(expected) - - -def test_visit_edges(graph): - actual = list( - graph["swh:1:rel:0000000000000000000000000000000000000010"].visit_edges( - edges="rel:rev,rev:rev,rev:dir" - ) - ) - actual = [(src.swhid, dst.swhid) for src, dst in actual] - expected = [ - ( - "swh:1:rel:0000000000000000000000000000000000000010", - "swh:1:rev:0000000000000000000000000000000000000009", - ), - ( - "swh:1:rev:0000000000000000000000000000000000000009", - "swh:1:rev:0000000000000000000000000000000000000003", - ), - ( - "swh:1:rev:0000000000000000000000000000000000000009", - "swh:1:dir:0000000000000000000000000000000000000008", - ), - ( - "swh:1:rev:0000000000000000000000000000000000000003", - "swh:1:dir:0000000000000000000000000000000000000002", - ), - ] - assert set(actual) == set(expected) - - -def test_visit_paths(graph): - actual = list( - graph["swh:1:snp:0000000000000000000000000000000000000020"].visit_paths( - edges="snp:*,rev:*" - ) - ) - actual = [tuple(n.swhid for n in path) for path in actual] - expected = [ - ( - "swh:1:snp:0000000000000000000000000000000000000020", - "swh:1:rev:0000000000000000000000000000000000000009", - "swh:1:rev:0000000000000000000000000000000000000003", - "swh:1:dir:0000000000000000000000000000000000000002", - ), - ( - "swh:1:snp:0000000000000000000000000000000000000020", - "swh:1:rev:0000000000000000000000000000000000000009", - "swh:1:dir:0000000000000000000000000000000000000008", - ), - ( - "swh:1:snp:0000000000000000000000000000000000000020", - "swh:1:rel:0000000000000000000000000000000000000010", - ), - ] - assert set(actual) == set(expected) - - -def test_walk(graph): - actual = list( - graph["swh:1:dir:0000000000000000000000000000000000000016"].walk( - "rel", edges="dir:dir,dir:rev,rev:*", direction="backward", traversal="bfs" - ) - ) - actual = [p.swhid for p in actual] - expected = [ - "swh:1:dir:0000000000000000000000000000000000000016", - "swh:1:dir:0000000000000000000000000000000000000017", - "swh:1:rev:0000000000000000000000000000000000000018", - "swh:1:rel:0000000000000000000000000000000000000019", - ] - assert set(actual) == set(expected) - - -def test_count(graph): - assert ( - graph["swh:1:ori:0000000000000000000000000000000000000021"].count_leaves() == 4 - ) - assert ( - graph["swh:1:rel:0000000000000000000000000000000000000010"].count_visit_nodes( - edges="rel:rev,rev:rev" - ) - == 3 - ) - assert ( - graph["swh:1:rev:0000000000000000000000000000000000000009"].count_neighbors( - direction="backward" - ) - == 3 - ) - - -def test_iter_type(graph): - rev_list = list(graph.iter_type("rev")) - actual = [n.swhid for n in rev_list] - expected = [ - "swh:1:rev:0000000000000000000000000000000000000003", - "swh:1:rev:0000000000000000000000000000000000000009", - "swh:1:rev:0000000000000000000000000000000000000013", - "swh:1:rev:0000000000000000000000000000000000000018", - ] - assert expected == actual