diff --git a/java/server/src/main/java/org/softwareheritage/graph/Entry.java b/java/server/src/main/java/org/softwareheritage/graph/Entry.java index 03aae5e..2186d6d 100644 --- a/java/server/src/main/java/org/softwareheritage/graph/Entry.java +++ b/java/server/src/main/java/org/softwareheritage/graph/Entry.java @@ -1,143 +1,163 @@ package org.softwareheritage.graph; import java.util.ArrayList; -import java.util.Map; import java.io.DataOutputStream; import java.io.FileOutputStream; import java.io.IOException; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.PropertyNamingStrategy; -import py4j.GatewayServer; +import org.softwareheritage.graph.algo.NodeIdConsumer; -import org.softwareheritage.graph.Graph; -import org.softwareheritage.graph.Node; import org.softwareheritage.graph.algo.Stats; -import org.softwareheritage.graph.algo.NodeIdsConsumer; import org.softwareheritage.graph.algo.Traversal; public class Entry { - Graph graph; + private Graph graph; - final long PATH_SEPARATOR_ID = -1; + private final long PATH_SEPARATOR_ID = -1; public void load_graph(String graphBasename) throws IOException { System.err.println("Loading graph " + graphBasename + " ..."); this.graph = new Graph(graphBasename); System.err.println("Graph loaded."); } public Graph get_graph() { return graph.copy(); } public String stats() { try { Stats stats = new Stats(graph.getPath()); ObjectMapper objectMapper = new ObjectMapper(); objectMapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE); - String res = objectMapper.writeValueAsString(stats); - return res; + return objectMapper.writeValueAsString(stats); } catch (IOException e) { throw new RuntimeException("Cannot read stats: " + e); } } + private interface NodeCountVisitor { + void accept(long nodeId, NodeIdConsumer consumer); + } + + private int count_visitor(NodeCountVisitor f, long srcNodeId) { + int count[] = { 0 }; + f.accept(srcNodeId, (node) -> { count[0]++; }); + return count[0]; + } + + public int count_leaves(String direction, String edgesFmt, long srcNodeId) { + Traversal t = new Traversal(this.graph, direction, edgesFmt); + return count_visitor(t::leavesVisitor, srcNodeId); + } + + public int count_neighbors(String direction, String edgesFmt, long srcNodeId) { + Traversal t = new Traversal(this.graph, direction, edgesFmt); + return count_visitor(t::neighborsVisitor, srcNodeId); + } + + public int count_visit_nodes(String direction, String edgesFmt, long srcNodeId) { + Traversal t = new Traversal(this.graph, direction, edgesFmt); + return count_visitor(t::visitNodesVisitor, srcNodeId); + } + public QueryHandler get_handler(String clientFIFO) { return new QueryHandler(this.graph.copy(), clientFIFO); } public class QueryHandler { Graph graph; DataOutputStream out; String clientFIFO; public QueryHandler(Graph graph, String clientFIFO) { this.graph = graph; this.clientFIFO = clientFIFO; this.out = null; } public void writeNode(long nodeId) { try { out.writeLong(nodeId); } catch (IOException e) { throw new RuntimeException("Cannot write response to client: " + e); } } public void writePath(ArrayList path) { for (Long nodeId : path) { writeNode(nodeId); } writeNode(PATH_SEPARATOR_ID); } public void open() { try { FileOutputStream file = new FileOutputStream(this.clientFIFO); this.out = new DataOutputStream(file); } catch (IOException e) { throw new RuntimeException("Cannot open client FIFO: " + e); } } public void close() { try { out.close(); } catch (IOException e) { throw new RuntimeException("Cannot write response to client: " + e); } } public void leaves(String direction, String edgesFmt, long srcNodeId) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt); t.leavesVisitor(srcNodeId, this::writeNode); close(); } public void neighbors(String direction, String edgesFmt, long srcNodeId) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt); t.neighborsVisitor(srcNodeId, this::writeNode); close(); } public void visit_nodes(String direction, String edgesFmt, long srcNodeId) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt); t.visitNodesVisitor(srcNodeId, this::writeNode); close(); } public void visit_paths(String direction, String edgesFmt, long srcNodeId) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt); t.visitPathsVisitor(srcNodeId, this::writePath); close(); } public void walk(String direction, String edgesFmt, String algorithm, long srcNodeId, long dstNodeId) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt); for (Long nodeId : t.walk(srcNodeId, dstNodeId, algorithm)) { writeNode(nodeId); } close(); } public void walk_type(String direction, String edgesFmt, String algorithm, long srcNodeId, String dst) { open(); Node.Type dstType = Node.Type.fromStr(dst); Traversal t = new Traversal(this.graph, direction, edgesFmt); for (Long nodeId : t.walk(srcNodeId, dstType, algorithm)) { writeNode(nodeId); } close(); } } } diff --git a/swh/graph/backend.py b/swh/graph/backend.py index 0f5f3ef..6d64eec 100644 --- a/swh/graph/backend.py +++ b/swh/graph/backend.py @@ -1,182 +1,186 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import asyncio import contextlib import io import os import pathlib import struct import subprocess import sys import tempfile from py4j.java_gateway import JavaGateway from swh.graph.pid import IntToPidMap, PidToIntMap from swh.model.identifiers import PID_TYPES BUF_SIZE = 64*1024 BIN_FMT = '>q' # 64 bit integer, big endian PATH_SEPARATOR_ID = -1 NODE2PID_EXT = 'node2pid.bin' PID2NODE_EXT = 'pid2node.bin' def find_graph_jar(): swh_graph_root = pathlib.Path(__file__).parents[2] try_paths = [ swh_graph_root / 'java/server/target/', pathlib.Path(sys.prefix) / 'share/swh-graph/', ] for path in try_paths: glob = list(path.glob('swh-graph-*.jar')) if glob: return str(glob[0]) raise RuntimeError("swh-graph-*.jar not found. Have you run `make java`?") def _get_pipe_stderr(): # Get stderr if possible, or pipe to stdout if running with Jupyter. try: sys.stderr.fileno() except io.UnsupportedOperation: return subprocess.STDOUT else: return sys.stderr class Backend: def __init__(self, graph_path): self.gateway = None self.entry = None self.graph_path = graph_path def __enter__(self): self.gateway = JavaGateway.launch_gateway( java_path=None, classpath=find_graph_jar(), die_on_exit=True, redirect_stdout=sys.stdout, redirect_stderr=_get_pipe_stderr(), ) self.entry = self.gateway.jvm.org.softwareheritage.graph.Entry() self.entry.load_graph(self.graph_path) self.node2pid = IntToPidMap(self.graph_path + '.' + NODE2PID_EXT) self.pid2node = PidToIntMap(self.graph_path + '.' + PID2NODE_EXT) self.stream_proxy = JavaStreamProxy(self.entry) return self def __exit__(self, exc_type, exc_value, tb): self.gateway.shutdown() def stats(self): return self.entry.stats() + def count(self, ttype, direction, edges_fmt, src): + method = getattr(self.entry, 'count_' + ttype) + return method(direction, edges_fmt, src) + async def simple_traversal(self, ttype, direction, edges_fmt, src): assert ttype in ('leaves', 'neighbors', 'visit_nodes') method = getattr(self.stream_proxy, ttype) async for node_id in method(direction, edges_fmt, src): yield node_id async def walk(self, direction, edges_fmt, algo, src, dst): if dst in PID_TYPES: it = self.stream_proxy.walk_type(direction, edges_fmt, algo, src, dst) else: it = self.stream_proxy.walk(direction, edges_fmt, algo, src, dst) async for node_id in it: yield node_id async def visit_paths(self, direction, edges_fmt, src): path = [] async for node in self.stream_proxy.visit_paths( direction, edges_fmt, src): if node == PATH_SEPARATOR_ID: yield path path = [] else: path.append(node) class JavaStreamProxy: """A proxy class for the org.softwareheritage.graph.Entry Java class that takes care of the setup and teardown of the named-pipe FIFO communication between Python and Java. Initialize JavaStreamProxy using: proxy = JavaStreamProxy(swh_entry_class_instance) Then you can call an Entry method and iterate on the FIFO results like this: async for value in proxy.java_method(arg1, arg2): print(value) """ def __init__(self, entry): self.entry = entry async def read_node_ids(self, fname): loop = asyncio.get_event_loop() open_thread = loop.run_in_executor(None, open, fname, 'rb') # Since the open() call on the FIFO is blocking until it is also opened # on the Java side, we await it with a timeout in case there is an # exception that prevents the write-side open(). with (await asyncio.wait_for(open_thread, timeout=2)) as f: while True: data = await loop.run_in_executor(None, f.read, BUF_SIZE) if not data: break for data in struct.iter_unpack(BIN_FMT, data): yield data[0] class _HandlerWrapper: def __init__(self, handler): self._handler = handler def __getattr__(self, name): func = getattr(self._handler, name) async def java_call(*args, **kwargs): loop = asyncio.get_event_loop() await loop.run_in_executor(None, lambda: func(*args, **kwargs)) def java_task(*args, **kwargs): return asyncio.create_task(java_call(*args, **kwargs)) return java_task @contextlib.contextmanager def get_handler(self): with tempfile.TemporaryDirectory(prefix='swh-graph-') as tmpdirname: cli_fifo = os.path.join(tmpdirname, 'swh-graph.fifo') os.mkfifo(cli_fifo) reader = self.read_node_ids(cli_fifo) query_handler = self.entry.get_handler(cli_fifo) handler = self._HandlerWrapper(query_handler) yield (handler, reader) def __getattr__(self, name): async def java_call_iterator(*args, **kwargs): with self.get_handler() as (handler, reader): java_task = getattr(handler, name)(*args, **kwargs) try: async for value in reader: yield value except asyncio.TimeoutError: # If the read-side open() timeouts, an exception on the # Java side probably happened that prevented the # write-side open(). We propagate this exception here if # that is the case. task_exc = java_task.exception() if task_exc: raise task_exc raise await java_task return java_call_iterator diff --git a/swh/graph/client.py b/swh/graph/client.py index a2a8095..8254fe7 100644 --- a/swh/graph/client.py +++ b/swh/graph/client.py @@ -1,82 +1,106 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import json from swh.core.api import RPCClient class GraphAPIError(Exception): """Graph API Error""" def __str__(self): return ('An unexpected error occurred in the Graph backend: {}' .format(self.args)) class RemoteGraphClient(RPCClient): """Client to the Software Heritage Graph.""" def __init__(self, url, timeout=None): super().__init__( api_exception=GraphAPIError, url=url, timeout=timeout) def raw_verb_lines(self, verb, endpoint, **kwargs): response = self.raw_verb(verb, endpoint, stream=True, **kwargs) for line in response.iter_lines(): yield line.decode().lstrip('\n') def get_lines(self, endpoint, **kwargs): yield from self.raw_verb_lines('get', endpoint, **kwargs) # Web API endpoints def stats(self): return self.get('stats') def leaves(self, src, edges="*", direction="forward"): return self.get_lines( 'leaves/{}'.format(src), params={ 'edges': edges, 'direction': direction }) def neighbors(self, src, edges="*", direction="forward"): return self.get_lines( 'neighbors/{}'.format(src), params={ 'edges': edges, 'direction': direction }) def visit_nodes(self, src, edges="*", direction="forward"): return self.get_lines( 'visit/nodes/{}'.format(src), params={ 'edges': edges, 'direction': direction }) def visit_paths(self, src, edges="*", direction="forward"): def decode_path_wrapper(it): for e in it: yield json.loads(e) return decode_path_wrapper( self.get_lines( 'visit/paths/{}'.format(src), params={ 'edges': edges, 'direction': direction })) def walk(self, src, dst, edges="*", traversal="dfs", direction="forward"): return self.get_lines( 'walk/{}/{}'.format(src, dst), params={ 'edges': edges, 'traversal': traversal, 'direction': direction }) + + def count_leaves(self, src, edges="*", direction="forward"): + return self.get( + 'leaves/count/{}'.format(src), + params={ + 'edges': edges, + 'direction': direction + }) + + def count_neighbors(self, src, edges="*", direction="forward"): + return self.get( + 'neighbors/count/{}'.format(src), + params={ + 'edges': edges, + 'direction': direction + }) + + def count_visit_nodes(self, src, edges="*", direction="forward"): + return self.get( + 'visit/nodes/count/{}'.format(src), + params={ + 'edges': edges, + 'direction': direction + }) diff --git a/swh/graph/graph.py b/swh/graph/graph.py index b0dda0b..ed042d3 100644 --- a/swh/graph/graph.py +++ b/swh/graph/graph.py @@ -1,157 +1,166 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import asyncio import contextlib +import functools from swh.graph.backend import Backend from swh.graph.dot import dot_to_svg, graph_dot, KIND_TO_SHAPE -KIND_TO_URL = { - 'ori': 'https://archive.softwareheritage.org/browse/origin/{}', - 'snp': 'https://archive.softwareheritage.org/browse/snapshot/{}', - 'rel': 'https://archive.softwareheritage.org/browse/release/{}', - 'rev': 'https://archive.softwareheritage.org/browse/revision/{}', - 'dir': 'https://archive.softwareheritage.org/browse/directory/{}', - 'cnt': 'https://archive.softwareheritage.org/browse/content/sha1_git:{}/', +BASE_URL = 'https://archive.softwareheritage.org/browse' +KIND_TO_URL_FRAGMENT = { + 'ori': '/origin/{}', + 'snp': '/snapshot/{}', + 'rel': '/release/{}', + 'rev': '/revision/{}', + 'dir': '/directory/{}', + 'cnt': '/content/sha1_git:{}/', } def call_async_gen(generator, *args, **kwargs): loop = asyncio.get_event_loop() it = generator(*args, **kwargs).__aiter__() while True: try: res = loop.run_until_complete(it.__anext__()) yield res except StopAsyncIteration: break class Neighbors: """Neighbor iterator with custom O(1) length method""" def __init__(self, graph, iterator, length_func): self.graph = graph self.iterator = iterator self.length_func = length_func def __iter__(self): return self def __next__(self): succ = self.iterator.nextLong() if succ == -1: raise StopIteration return GraphNode(self.graph, succ) def __len__(self): return self.length_func() class GraphNode: """Node in the SWH graph""" def __init__(self, graph, node_id): self.graph = graph self.id = node_id def children(self): return Neighbors( self.graph, self.graph.java_graph.successors(self.id), lambda: self.graph.java_graph.outdegree(self.id)) def parents(self): return Neighbors( self.graph, self.graph.java_graph.predecessors(self.id), lambda: self.graph.java_graph.indegree(self.id)) def simple_traversal(self, ttype, direction='forward', edges='*'): for node in call_async_gen( self.graph.backend.simple_traversal, ttype, direction, edges, self.id ): yield self.graph[node] def leaves(self, *args, **kwargs): yield from self.simple_traversal('leaves', *args, **kwargs) def visit_nodes(self, *args, **kwargs): yield from self.simple_traversal('visit_nodes', *args, **kwargs) def visit_paths(self, direction='forward', edges='*'): for path in call_async_gen( self.graph.backend.visit_paths, direction, edges, self.id ): yield [self.graph[node] for node in path] def walk(self, dst, direction='forward', edges='*', traversal='dfs'): for node in call_async_gen( self.graph.backend.walk, direction, edges, traversal, self.id, dst ): yield self.graph[node] + def _count(self, ttype, direction='forward', edges='*'): + return self.graph.backend.count(ttype, direction, edges, self.id) + + count_leaves = functools.partialmethod(_count, ttype='leaves') + count_neighbors = functools.partialmethod(_count, ttype='neighbors') + count_visit_nodes = functools.partialmethod(_count, ttype='visit_nodes') + @property def pid(self): return self.graph.node2pid[self.id] @property def kind(self): return self.pid.split(':')[2] def __str__(self): return self.pid def __repr__(self): return '<{}>'.format(self.pid) def dot_fragment(self): swh, version, kind, hash = self.pid.split(':') label = '{}:{}..{}'.format(kind, hash[0:2], hash[-2:]) - url = KIND_TO_URL[kind].format(hash) + url = BASE_URL + KIND_TO_URL_FRAGMENT[kind].format(hash) shape = KIND_TO_SHAPE[kind] return ('{} [label="{}", href="{}", target="_blank", shape="{}"];' .format(self.id, label, url, shape)) def _repr_svg_(self): nodes = [self, *list(self.children()), *list(self.parents())] dot = graph_dot(nodes) svg = dot_to_svg(dot) return svg class Graph: def __init__(self, backend, node2pid, pid2node): self.backend = backend self.java_graph = backend.entry.get_graph() self.node2pid = node2pid self.pid2node = pid2node def stats(self): return self.backend.stats() @property def path(self): return self.java_graph.getPath() def __len__(self): return self.java_graph.getNbNodes() def __getitem__(self, node_id): if isinstance(node_id, int): self.node2pid[node_id] # check existence return GraphNode(self, node_id) elif isinstance(node_id, str): node_id = self.pid2node[node_id] return GraphNode(self, node_id) @contextlib.contextmanager def load(graph_path): with Backend(graph_path) as backend: yield Graph(backend, backend.node2pid, backend.pid2node) diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py index 8e57946..c5d4759 100644 --- a/swh/graph/server/app.py +++ b/swh/graph/server/app.py @@ -1,120 +1,144 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information """ A proxy HTTP server for swh-graph, talking to the Java code via py4j, and using FIFO as a transport to stream integers between the two languages. """ import json import contextlib import aiohttp.web from swh.core.api.asynchronous import RPCServerApp from swh.model.identifiers import PID_TYPES @contextlib.asynccontextmanager async def stream_response(request, *args, **kwargs): response = aiohttp.web.StreamResponse(*args, **kwargs) await response.prepare(request) yield response await response.write_eof() async def index(request): return aiohttp.web.Response( content_type='text/html', body=""" Software Heritage storage server

You have reached the Software Heritage graph API server.

See its API documentation for more information.

""") async def stats(request): stats = request.app['backend'].stats() return aiohttp.web.Response(body=stats, content_type='application/json') def get_simple_traversal_handler(ttype): async def simple_traversal(request): backend = request.app['backend'] src = request.match_info['src'] edges = request.query.get('edges', '*') direction = request.query.get('direction', 'forward') src_node = backend.pid2node[src] async with stream_response(request) as response: async for res_node in backend.simple_traversal( ttype, direction, edges, src_node ): res_pid = backend.node2pid[res_node] await response.write('{}\n'.format(res_pid).encode()) return response return simple_traversal async def walk(request): backend = request.app['backend'] src = request.match_info['src'] dst = request.match_info['dst'] edges = request.query.get('edges', '*') direction = request.query.get('direction', 'forward') algo = request.query.get('traversal', 'dfs') src_node = backend.pid2node[src] if dst not in PID_TYPES: dst = backend.pid2node[dst] async with stream_response(request) as response: async for res_node in backend.walk( direction, edges, algo, src_node, dst ): res_pid = backend.node2pid[res_node] await response.write('{}\n'.format(res_pid).encode()) return response async def visit_paths(request): backend = request.app['backend'] src = request.match_info['src'] edges = request.query.get('edges', '*') direction = request.query.get('direction', 'forward') src_node = backend.pid2node[src] it = backend.visit_paths(direction, edges, src_node) async with stream_response(request) as response: async for res_path in it: res_path_pid = [backend.node2pid[n] for n in res_path] line = json.dumps(res_path_pid) await response.write('{}\n'.format(line).encode()) return response +def get_count_handler(ttype): + async def count(request): + backend = request.app['backend'] + + src = request.match_info['src'] + edges = request.query.get('edges', '*') + direction = request.query.get('direction', 'forward') + + src_node = backend.pid2node[src] + cnt = backend.count(ttype, direction, edges, src_node) + return aiohttp.web.Response(body=str(cnt), + content_type='application/json') + + return count + + def make_app(backend, **kwargs): app = RPCServerApp(**kwargs) app.router.add_route('GET', '/', index) app.router.add_route('GET', '/graph/stats', stats) + app.router.add_route('GET', '/graph/leaves/{src}', get_simple_traversal_handler('leaves')) app.router.add_route('GET', '/graph/neighbors/{src}', get_simple_traversal_handler('neighbors')) app.router.add_route('GET', '/graph/visit/nodes/{src}', get_simple_traversal_handler('visit_nodes')) app.router.add_route('GET', '/graph/visit/paths/{src}', visit_paths) app.router.add_route('GET', '/graph/walk/{src}/{dst}', walk) + app.router.add_route('GET', '/graph/neighbors/count/{src}', + get_count_handler('neighbors')) + app.router.add_route('GET', '/graph/leaves/count/{src}', + get_count_handler('leaves')) + app.router.add_route('GET', '/graph/visit/nodes/count/{src}', + get_count_handler('visit_nodes')) + app['backend'] = backend return app diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py index bbf55a2..8ff3cf4 100644 --- a/swh/graph/tests/test_api_client.py +++ b/swh/graph/tests/test_api_client.py @@ -1,104 +1,122 @@ def test_stats(graph_client): stats = graph_client.stats() assert set(stats.keys()) == {'counts', 'ratios', 'indegree', 'outdegree'} assert set(stats['counts'].keys()) == {'nodes', 'edges'} assert set(stats['ratios'].keys()) == {'compression', 'bits_per_node', 'bits_per_edge', 'avg_locality'} assert set(stats['indegree'].keys()) == {'min', 'max', 'avg'} assert set(stats['outdegree'].keys()) == {'min', 'max', 'avg'} assert stats['counts']['nodes'] == 21 assert stats['counts']['edges'] == 23 assert isinstance(stats['ratios']['compression'], float) assert isinstance(stats['ratios']['bits_per_node'], float) assert isinstance(stats['ratios']['bits_per_edge'], float) assert isinstance(stats['ratios']['avg_locality'], float) assert stats['indegree']['min'] == 0 assert stats['indegree']['max'] == 3 assert isinstance(stats['indegree']['avg'], float) assert stats['outdegree']['min'] == 0 assert stats['outdegree']['max'] == 3 assert isinstance(stats['outdegree']['avg'], float) def test_leaves(graph_client): actual = list(graph_client.leaves( 'swh:1:ori:0000000000000000000000000000000000000021' )) expected = [ 'swh:1:cnt:0000000000000000000000000000000000000001', 'swh:1:cnt:0000000000000000000000000000000000000004', 'swh:1:cnt:0000000000000000000000000000000000000005', 'swh:1:cnt:0000000000000000000000000000000000000007' ] assert set(actual) == set(expected) def test_neighbors(graph_client): actual = list(graph_client.neighbors( 'swh:1:rev:0000000000000000000000000000000000000009', direction='backward' )) expected = [ 'swh:1:snp:0000000000000000000000000000000000000020', 'swh:1:rel:0000000000000000000000000000000000000010', 'swh:1:rev:0000000000000000000000000000000000000013' ] assert set(actual) == set(expected) def test_visit_nodes(graph_client): actual = list(graph_client.visit_nodes( 'swh:1:rel:0000000000000000000000000000000000000010', edges='rel:rev,rev:rev' )) expected = [ 'swh:1:rel:0000000000000000000000000000000000000010', 'swh:1:rev:0000000000000000000000000000000000000009', 'swh:1:rev:0000000000000000000000000000000000000003' ] assert set(actual) == set(expected) def test_visit_paths(graph_client): actual = list(graph_client.visit_paths( 'swh:1:snp:0000000000000000000000000000000000000020', edges='snp:*,rev:*')) actual = [tuple(path) for path in actual] expected = [ ( 'swh:1:snp:0000000000000000000000000000000000000020', 'swh:1:rev:0000000000000000000000000000000000000009', 'swh:1:rev:0000000000000000000000000000000000000003', 'swh:1:dir:0000000000000000000000000000000000000002' ), ( 'swh:1:snp:0000000000000000000000000000000000000020', 'swh:1:rev:0000000000000000000000000000000000000009', 'swh:1:dir:0000000000000000000000000000000000000008' ), ( 'swh:1:snp:0000000000000000000000000000000000000020', 'swh:1:rel:0000000000000000000000000000000000000010' ) ] assert set(actual) == set(expected) def test_walk(graph_client): actual = list(graph_client.walk( 'swh:1:dir:0000000000000000000000000000000000000016', 'rel', edges='dir:dir,dir:rev,rev:*', direction='backward', traversal='bfs' )) expected = [ 'swh:1:dir:0000000000000000000000000000000000000016', 'swh:1:dir:0000000000000000000000000000000000000017', 'swh:1:rev:0000000000000000000000000000000000000018', 'swh:1:rel:0000000000000000000000000000000000000019' ] assert set(actual) == set(expected) + + +def test_count(graph_client): + print(graph_client) + actual = graph_client.count_leaves( + 'swh:1:ori:0000000000000000000000000000000000000021' + ) + assert actual == 4 + actual = graph_client.count_visit_nodes( + 'swh:1:rel:0000000000000000000000000000000000000010', + edges='rel:rev,rev:rev' + ) + assert actual == 3 + actual = graph_client.count_neighbors( + 'swh:1:rev:0000000000000000000000000000000000000009', + direction='backward' + ) + assert actual == 3 diff --git a/swh/graph/tests/test_graph.py b/swh/graph/tests/test_graph.py index ef5919b..4fe848d 100644 --- a/swh/graph/tests/test_graph.py +++ b/swh/graph/tests/test_graph.py @@ -1,103 +1,112 @@ import pytest def test_graph(graph): assert len(graph) == 21 obj = 'swh:1:dir:0000000000000000000000000000000000000008' node = graph[obj] assert str(node) == obj assert len(node.children()) == 3 assert len(node.parents()) == 2 actual = {p.pid for p in node.children()} expected = { 'swh:1:cnt:0000000000000000000000000000000000000001', 'swh:1:dir:0000000000000000000000000000000000000006', 'swh:1:cnt:0000000000000000000000000000000000000007' } assert expected == actual actual = {p.pid for p in node.parents()} expected = { 'swh:1:rev:0000000000000000000000000000000000000009', 'swh:1:dir:0000000000000000000000000000000000000012', } assert expected == actual def test_invalid_pid(graph): with pytest.raises(IndexError): graph[1337] with pytest.raises(IndexError): graph[len(graph) + 1] with pytest.raises(KeyError): graph['swh:1:dir:0000000000000000000000000000000420000012'] def test_leaves(graph): actual = list(graph['swh:1:ori:0000000000000000000000000000000000000021'] .leaves()) actual = [p.pid for p in actual] expected = [ 'swh:1:cnt:0000000000000000000000000000000000000001', 'swh:1:cnt:0000000000000000000000000000000000000004', 'swh:1:cnt:0000000000000000000000000000000000000005', 'swh:1:cnt:0000000000000000000000000000000000000007' ] assert set(actual) == set(expected) def test_visit_nodes(graph): actual = list(graph['swh:1:rel:0000000000000000000000000000000000000010'] .visit_nodes(edges='rel:rev,rev:rev')) actual = [p.pid for p in actual] expected = [ 'swh:1:rel:0000000000000000000000000000000000000010', 'swh:1:rev:0000000000000000000000000000000000000009', 'swh:1:rev:0000000000000000000000000000000000000003' ] assert set(actual) == set(expected) def test_visit_paths(graph): actual = list(graph['swh:1:snp:0000000000000000000000000000000000000020'] .visit_paths(edges='snp:*,rev:*')) actual = [tuple(n.pid for n in path) for path in actual] expected = [ ( 'swh:1:snp:0000000000000000000000000000000000000020', 'swh:1:rev:0000000000000000000000000000000000000009', 'swh:1:rev:0000000000000000000000000000000000000003', 'swh:1:dir:0000000000000000000000000000000000000002' ), ( 'swh:1:snp:0000000000000000000000000000000000000020', 'swh:1:rev:0000000000000000000000000000000000000009', 'swh:1:dir:0000000000000000000000000000000000000008' ), ( 'swh:1:snp:0000000000000000000000000000000000000020', 'swh:1:rel:0000000000000000000000000000000000000010' ) ] assert set(actual) == set(expected) def test_walk(graph): actual = list(graph['swh:1:dir:0000000000000000000000000000000000000016'] .walk('rel', edges='dir:dir,dir:rev,rev:*', direction='backward', traversal='bfs')) actual = [p.pid for p in actual] expected = [ 'swh:1:dir:0000000000000000000000000000000000000016', 'swh:1:dir:0000000000000000000000000000000000000017', 'swh:1:rev:0000000000000000000000000000000000000018', 'swh:1:rel:0000000000000000000000000000000000000019' ] assert set(actual) == set(expected) + + +def test_count(graph): + assert (graph['swh:1:ori:0000000000000000000000000000000000000021'] + .count_leaves() == 4) + assert (graph['swh:1:rel:0000000000000000000000000000000000000010'] + .count_visit_nodes(edges='rel:rev,rev:rev') == 3) + assert (graph['swh:1:rev:0000000000000000000000000000000000000009'] + .count_neighbors(direction='backward') == 3)