You have reached the Software Heritage graph API server.
See its API documentation for more information.
""", ) class GraphView(aiohttp.web.View): """Base class for views working on the graph, with utility functions""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.backend = self.request.app["backend"] def node_of_pid(self, pid): """Lookup a PID in a pid2node map, failing in an HTTP-nice way if needed.""" try: return self.backend.pid2node[pid] except KeyError: raise aiohttp.web.HTTPNotFound(body=f"PID not found: {pid}") except ValidationError: raise aiohttp.web.HTTPBadRequest(body=f"malformed PID: {pid}") def pid_of_node(self, node): """Lookup a node in a node2pid map, failing in an HTTP-nice way if needed.""" try: return self.backend.node2pid[node] except KeyError: raise aiohttp.web.HTTPInternalServerError( body=f"reverse lookup failed for node id: {node}" ) def get_direction(self): """Validate HTTP query parameter `direction`""" s = self.request.query.get("direction", "forward") if s not in ("forward", "backward"): raise aiohttp.web.HTTPBadRequest(body=f"invalid direction: {s}") return s def get_edges(self): """Validate HTTP query parameter `edges`, i.e., edge restrictions""" s = self.request.query.get("edges", "*") if any( [ node_type != "*" and node_type not in PID_TYPES for edge in s.split(":") for node_type in edge.split(",", maxsplit=1) ] ): raise aiohttp.web.HTTPBadRequest(body=f"invalid edge restriction: {s}") return s def get_traversal(self): """Validate HTTP query parameter `traversal`, i.e., visit order""" s = self.request.query.get("traversal", "dfs") if s not in ("bfs", "dfs"): raise aiohttp.web.HTTPBadRequest(body=f"invalid traversal order: {s}") return s def get_limit(self): """Validate HTTP query parameter `limit`, i.e., number of results""" s = self.request.query.get("limit", "0") try: return int(s) except ValueError: raise aiohttp.web.HTTPBadRequest(body=f"invalid limit value: {s}") class StreamingGraphView(GraphView): """Base class for views streaming their response line by line.""" content_type = "text/plain" @asynccontextmanager async def response_streamer(self, *args, **kwargs): """Context manager to prepare then close a StreamResponse""" response = aiohttp.web.StreamResponse(*args, **kwargs) response.content_type = self.content_type await response.prepare(self.request) yield response await response.write_eof() async def get(self): await self.prepare_response() async with self.response_streamer() as self.response_stream: await self.stream_response() return self.response_stream async def prepare_response(self): """This can be overridden with some setup to be run before the response actually starts streaming. """ pass async def stream_response(self): """Override this to perform the response streaming. Implementations of this should await self.stream_line(line) to write each line. """ raise NotImplementedError async def stream_line(self, line): """Write a line in the response stream.""" await self.response_stream.write((line + "\n").encode()) class StatsView(GraphView): """View showing some statistics on the graph""" async def get(self): stats = self.backend.stats() return aiohttp.web.Response(body=stats, content_type="application/json") class SimpleTraversalView(StreamingGraphView): """Base class for views of simple traversals""" simple_traversal_type: Optional[str] = None async def prepare_response(self): src = self.request.match_info["src"] self.src_node = self.node_of_pid(src) self.edges = self.get_edges() self.direction = self.get_direction() async def stream_response(self): async for res_node in self.backend.simple_traversal( self.simple_traversal_type, self.direction, self.edges, self.src_node ): res_pid = self.pid_of_node(res_node) await self.stream_line(res_pid) class LeavesView(SimpleTraversalView): simple_traversal_type = "leaves" class NeighborsView(SimpleTraversalView): simple_traversal_type = "neighbors" class VisitNodesView(SimpleTraversalView): simple_traversal_type = "visit_nodes" class WalkView(StreamingGraphView): async def prepare_response(self): src = self.request.match_info["src"] dst = self.request.match_info["dst"] self.src_node = self.node_of_pid(src) if dst not in PID_TYPES: self.dst_thing = self.node_of_pid(dst) else: self.dst_thing = dst self.edges = self.get_edges() self.direction = self.get_direction() self.algo = self.get_traversal() self.limit = self.get_limit() async def get_walk_iterator(self): return self.backend.walk( self.direction, self.edges, self.algo, self.src_node, self.dst_thing ) async def stream_response(self): it = self.get_walk_iterator() if self.limit < 0: queue = deque(maxlen=-self.limit) async for res_node in it: res_pid = self.pid_of_node(res_node) queue.append(res_pid) while queue: await self.stream_line(queue.popleft()) else: count = 0 async for res_node in it: if self.limit == 0 or count < self.limit: res_pid = self.pid_of_node(res_node) await self.stream_line(res_pid) count += 1 else: break class RandomWalkView(WalkView): def get_walk_iterator(self): return self.backend.random_walk( self.direction, self.edges, RANDOM_RETRIES, self.src_node, self.dst_thing ) class VisitEdgesView(SimpleTraversalView): async def stream_response(self): it = self.backend.visit_edges(self.direction, self.edges, self.src_node) async for (res_src, res_dst) in it: res_src_pid = self.pid_of_node(res_src) res_dst_pid = self.pid_of_node(res_dst) await self.stream_line("{} {}".format(res_src_pid, res_dst_pid)) class VisitPathsView(SimpleTraversalView): content_type = "application/x-ndjson" async def stream_response(self): it = self.backend.visit_paths(self.direction, self.edges, self.src_node) async for res_path in it: res_path_pid = [self.pid_of_node(n) for n in res_path] line = json.dumps(res_path_pid) await self.stream_line(line) class CountView(GraphView): """Base class for counting views.""" count_type: Optional[str] = None async def get(self): src = self.request.match_info["src"] self.src_node = self.node_of_pid(src) self.edges = self.get_edges() self.direction = self.get_direction() loop = asyncio.get_event_loop() cnt = await loop.run_in_executor( None, self.backend.count, self.count_type, self.direction, self.edges, self.src_node, ) return aiohttp.web.Response(body=str(cnt), content_type="application/json") class CountNeighborsView(CountView): count_type = "neighbors" class CountLeavesView(CountView): count_type = "leaves" class CountVisitNodesView(CountView): count_type = "visit_nodes" def make_app(backend, **kwargs): app = RPCServerApp(**kwargs) app.add_routes( [ aiohttp.web.get("/", index), aiohttp.web.get("/graph", index), aiohttp.web.view("/graph/stats", StatsView), aiohttp.web.view("/graph/leaves/{src}", LeavesView), aiohttp.web.view("/graph/neighbors/{src}", NeighborsView), aiohttp.web.view("/graph/visit/nodes/{src}", VisitNodesView), aiohttp.web.view("/graph/visit/edges/{src}", VisitEdgesView), aiohttp.web.view("/graph/visit/paths/{src}", VisitPathsView), # temporarily disabled in wait of a proper fix for T1969 # aiohttp.web.view("/graph/walk/{src}/{dst}", WalkView) aiohttp.web.view("/graph/randomwalk/{src}/{dst}", RandomWalkView), aiohttp.web.view("/graph/neighbors/count/{src}", CountNeighborsView), aiohttp.web.view("/graph/leaves/count/{src}", CountLeavesView), aiohttp.web.view("/graph/visit/nodes/count/{src}", CountVisitNodesView), ] ) app["backend"] = backend return app diff --git a/swh/graph/tests/conftest.py b/swh/graph/tests/conftest.py index 442e32c..497062e 100644 --- a/swh/graph/tests/conftest.py +++ b/swh/graph/tests/conftest.py @@ -1,51 +1,51 @@ import multiprocessing -import pytest - -from aiohttp.test_utils import TestServer, TestClient, loop_context from pathlib import Path -from swh.graph.graph import load as graph_load -from swh.graph.client import RemoteGraphClient +from aiohttp.test_utils import TestClient, TestServer, loop_context +import pytest + from swh.graph.backend import Backend +from swh.graph.client import RemoteGraphClient +from swh.graph.graph import load as graph_load from swh.graph.server.app import make_app SWH_GRAPH_TESTS_ROOT = Path(__file__).parents[0] TEST_GRAPH_PATH = SWH_GRAPH_TESTS_ROOT / "dataset/output/example" class GraphServerProcess(multiprocessing.Process): def __init__(self, q, *args, **kwargs): self.q = q super().__init__(*args, **kwargs) def run(self): try: backend = Backend(graph_path=str(TEST_GRAPH_PATH)) with backend: with loop_context() as loop: app = make_app(backend=backend, debug=True) client = TestClient(TestServer(app), loop=loop) loop.run_until_complete(client.start_server()) url = client.make_url("/graph/") self.q.put(url) loop.run_forever() except Exception as e: self.q.put(e) @pytest.fixture(scope="module") def graph_client(): queue = multiprocessing.Queue() server = GraphServerProcess(queue) server.start() res = queue.get() if isinstance(res, Exception): raise res yield RemoteGraphClient(str(res)) server.terminate() @pytest.fixture(scope="module") def graph(): with graph_load(str(TEST_GRAPH_PATH)) as g: yield g diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py index 96279d0..4bed140 100644 --- a/swh/graph/tests/test_api_client.py +++ b/swh/graph/tests/test_api_client.py @@ -1,305 +1,303 @@ import pytest - - from pytest import raises from swh.core.api import RemoteException def test_stats(graph_client): stats = graph_client.stats() assert set(stats.keys()) == {"counts", "ratios", "indegree", "outdegree"} assert set(stats["counts"].keys()) == {"nodes", "edges"} assert set(stats["ratios"].keys()) == { "compression", "bits_per_node", "bits_per_edge", "avg_locality", } assert set(stats["indegree"].keys()) == {"min", "max", "avg"} assert set(stats["outdegree"].keys()) == {"min", "max", "avg"} assert stats["counts"]["nodes"] == 21 assert stats["counts"]["edges"] == 23 assert isinstance(stats["ratios"]["compression"], float) assert isinstance(stats["ratios"]["bits_per_node"], float) assert isinstance(stats["ratios"]["bits_per_edge"], float) assert isinstance(stats["ratios"]["avg_locality"], float) assert stats["indegree"]["min"] == 0 assert stats["indegree"]["max"] == 3 assert isinstance(stats["indegree"]["avg"], float) assert stats["outdegree"]["min"] == 0 assert stats["outdegree"]["max"] == 3 assert isinstance(stats["outdegree"]["avg"], float) def test_leaves(graph_client): actual = list( graph_client.leaves("swh:1:ori:0000000000000000000000000000000000000021") ) expected = [ "swh:1:cnt:0000000000000000000000000000000000000001", "swh:1:cnt:0000000000000000000000000000000000000004", "swh:1:cnt:0000000000000000000000000000000000000005", "swh:1:cnt:0000000000000000000000000000000000000007", ] assert set(actual) == set(expected) def test_neighbors(graph_client): actual = list( graph_client.neighbors( "swh:1:rev:0000000000000000000000000000000000000009", direction="backward" ) ) expected = [ "swh:1:snp:0000000000000000000000000000000000000020", "swh:1:rel:0000000000000000000000000000000000000010", "swh:1:rev:0000000000000000000000000000000000000013", ] assert set(actual) == set(expected) def test_visit_nodes(graph_client): actual = list( graph_client.visit_nodes( "swh:1:rel:0000000000000000000000000000000000000010", edges="rel:rev,rev:rev", ) ) expected = [ "swh:1:rel:0000000000000000000000000000000000000010", "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:rev:0000000000000000000000000000000000000003", ] assert set(actual) == set(expected) def test_visit_edges(graph_client): actual = list( graph_client.visit_edges( "swh:1:rel:0000000000000000000000000000000000000010", edges="rel:rev,rev:rev,rev:dir", ) ) expected = [ ( "swh:1:rel:0000000000000000000000000000000000000010", "swh:1:rev:0000000000000000000000000000000000000009", ), ( "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:rev:0000000000000000000000000000000000000003", ), ( "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:dir:0000000000000000000000000000000000000008", ), ( "swh:1:rev:0000000000000000000000000000000000000003", "swh:1:dir:0000000000000000000000000000000000000002", ), ] assert set(actual) == set(expected) def test_visit_edges_diamond_pattern(graph_client): actual = list( graph_client.visit_edges( "swh:1:rev:0000000000000000000000000000000000000009", edges="*", ) ) expected = [ ( "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:rev:0000000000000000000000000000000000000003", ), ( "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:dir:0000000000000000000000000000000000000008", ), ( "swh:1:rev:0000000000000000000000000000000000000003", "swh:1:dir:0000000000000000000000000000000000000002", ), ( "swh:1:dir:0000000000000000000000000000000000000002", "swh:1:cnt:0000000000000000000000000000000000000001", ), ( "swh:1:dir:0000000000000000000000000000000000000008", "swh:1:cnt:0000000000000000000000000000000000000001", ), ( "swh:1:dir:0000000000000000000000000000000000000008", "swh:1:cnt:0000000000000000000000000000000000000007", ), ( "swh:1:dir:0000000000000000000000000000000000000008", "swh:1:dir:0000000000000000000000000000000000000006", ), ( "swh:1:dir:0000000000000000000000000000000000000006", "swh:1:cnt:0000000000000000000000000000000000000004", ), ( "swh:1:dir:0000000000000000000000000000000000000006", "swh:1:cnt:0000000000000000000000000000000000000005", ), ] assert set(actual) == set(expected) def test_visit_paths(graph_client): actual = list( graph_client.visit_paths( "swh:1:snp:0000000000000000000000000000000000000020", edges="snp:*,rev:*" ) ) actual = [tuple(path) for path in actual] expected = [ ( "swh:1:snp:0000000000000000000000000000000000000020", "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:rev:0000000000000000000000000000000000000003", "swh:1:dir:0000000000000000000000000000000000000002", ), ( "swh:1:snp:0000000000000000000000000000000000000020", "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:dir:0000000000000000000000000000000000000008", ), ( "swh:1:snp:0000000000000000000000000000000000000020", "swh:1:rel:0000000000000000000000000000000000000010", ), ] assert set(actual) == set(expected) @pytest.mark.skip(reason="currently disabled due to T1969") def test_walk(graph_client): args = ("swh:1:dir:0000000000000000000000000000000000000016", "rel") kwargs = { "edges": "dir:dir,dir:rev,rev:*", "direction": "backward", "traversal": "bfs", } actual = list(graph_client.walk(*args, **kwargs)) expected = [ "swh:1:dir:0000000000000000000000000000000000000016", "swh:1:dir:0000000000000000000000000000000000000017", "swh:1:rev:0000000000000000000000000000000000000018", "swh:1:rel:0000000000000000000000000000000000000019", ] assert set(actual) == set(expected) kwargs2 = kwargs.copy() kwargs2["limit"] = -1 actual = list(graph_client.walk(*args, **kwargs2)) expected = ["swh:1:rel:0000000000000000000000000000000000000019"] assert set(actual) == set(expected) kwargs2 = kwargs.copy() kwargs2["limit"] = 2 actual = list(graph_client.walk(*args, **kwargs2)) expected = [ "swh:1:dir:0000000000000000000000000000000000000016", "swh:1:dir:0000000000000000000000000000000000000017", ] assert set(actual) == set(expected) def test_random_walk(graph_client): """as the walk is random, we test a visit from a cnt node to the only origin in the dataset, and only check the final node of the path (i.e., the origin) """ args = ("swh:1:cnt:0000000000000000000000000000000000000001", "ori") kwargs = {"direction": "backward"} expected_root = "swh:1:ori:0000000000000000000000000000000000000021" actual = list(graph_client.random_walk(*args, **kwargs)) assert len(actual) > 1 # no origin directly links to a content assert actual[0] == args[0] assert actual[-1] == expected_root kwargs2 = kwargs.copy() kwargs2["limit"] = -1 actual = list(graph_client.random_walk(*args, **kwargs2)) assert actual == [expected_root] kwargs2["limit"] = -2 actual = list(graph_client.random_walk(*args, **kwargs2)) assert len(actual) == 2 assert actual[-1] == expected_root kwargs2["limit"] = 3 actual = list(graph_client.random_walk(*args, **kwargs2)) assert len(actual) == 3 def test_count(graph_client): actual = graph_client.count_leaves( "swh:1:ori:0000000000000000000000000000000000000021" ) assert actual == 4 actual = graph_client.count_visit_nodes( "swh:1:rel:0000000000000000000000000000000000000010", edges="rel:rev,rev:rev" ) assert actual == 3 actual = graph_client.count_neighbors( "swh:1:rev:0000000000000000000000000000000000000009", direction="backward" ) assert actual == 3 def test_param_validation(graph_client): with raises(RemoteException) as exc_info: # PID not found list(graph_client.leaves("swh:1:ori:fff0000000000000000000000000000000000021")) assert exc_info.value.response.status_code == 404 with raises(RemoteException) as exc_info: # malformed PID list( graph_client.neighbors("swh:1:ori:fff000000zzzzzz0000000000000000000000021") ) assert exc_info.value.response.status_code == 400 with raises(RemoteException) as exc_info: # malformed edge specificaiton list( graph_client.visit_nodes( "swh:1:dir:0000000000000000000000000000000000000016", edges="dir:notanodetype,dir:rev,rev:*", direction="backward", ) ) assert exc_info.value.response.status_code == 400 with raises(RemoteException) as exc_info: # malformed direction list( graph_client.visit_nodes( "swh:1:dir:0000000000000000000000000000000000000016", edges="dir:dir,dir:rev,rev:*", direction="notadirection", ) ) assert exc_info.value.response.status_code == 400 @pytest.mark.skip(reason="currently disabled due to T1969") def test_param_validation_walk(graph_client): """test validation of walk-specific parameters only""" with raises(RemoteException) as exc_info: # malformed traversal order list( graph_client.walk( "swh:1:dir:0000000000000000000000000000000000000016", "rel", edges="dir:dir,dir:rev,rev:*", direction="backward", traversal="notatraversalorder", ) ) assert exc_info.value.response.status_code == 400 diff --git a/swh/graph/tests/test_cli.py b/swh/graph/tests/test_cli.py index 009a23d..4eaa389 100644 --- a/swh/graph/tests/test_cli.py +++ b/swh/graph/tests/test_cli.py @@ -1,46 +1,45 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from pathlib import Path from tempfile import TemporaryDirectory from typing import Dict from click.testing import CliRunner from swh.graph.cli import cli - DATA_DIR = Path(__file__).parents[0] / "dataset" def read_properties(properties_fname) -> Dict[str, str]: """read a Java .properties file""" with open(properties_fname) as f: keyvalues = ( line.split("=", maxsplit=1) for line in f if not line.strip().startswith("#") ) return dict((k.strip(), v.strip()) for (k, v) in keyvalues) def test_pipeline(): """run full compression pipeline""" # bare bone configuration, to allow testing the compression pipeline # with minimum RAM requirements on trivial graphs config = {"graph": {"compress": {"batch_size": 1000}}} runner = CliRunner() with TemporaryDirectory(suffix=".swh-graph-test") as tmpdir: result = runner.invoke( cli, ["compress", "--graph", DATA_DIR / "example", "--outdir", tmpdir], obj={"config": config}, ) assert result.exit_code == 0, result properties = read_properties(Path(tmpdir) / "example.properties") assert int(properties["nodes"]) == 21 assert int(properties["arcs"]) == 23 diff --git a/swh/graph/tests/test_pid.py b/swh/graph/tests/test_pid.py index 110c61e..f937729 100644 --- a/swh/graph/tests/test_pid.py +++ b/swh/graph/tests/test_pid.py @@ -1,202 +1,200 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from itertools import islice import os import shutil import tempfile import unittest -from itertools import islice - -from swh.graph.pid import str_to_bytes, bytes_to_str -from swh.graph.pid import PidToNodeMap, NodeToPidMap +from swh.graph.pid import NodeToPidMap, PidToNodeMap, bytes_to_str, str_to_bytes from swh.model.identifiers import PID_TYPES class TestPidSerialization(unittest.TestCase): pairs = [ ( "swh:1:cnt:94a9ed024d3859793618152ea559a168bbcbb5e2", bytes.fromhex("01" + "00" + "94a9ed024d3859793618152ea559a168bbcbb5e2"), ), ( "swh:1:dir:d198bc9d7a6bcf6db04f476d29314f157507d505", bytes.fromhex("01" + "01" + "d198bc9d7a6bcf6db04f476d29314f157507d505"), ), ( "swh:1:ori:b63a575fe3faab7692c9f38fb09d4bb45651bb0f", bytes.fromhex("01" + "02" + "b63a575fe3faab7692c9f38fb09d4bb45651bb0f"), ), ( "swh:1:rel:22ece559cc7cc2364edc5e5593d63ae8bd229f9f", bytes.fromhex("01" + "03" + "22ece559cc7cc2364edc5e5593d63ae8bd229f9f"), ), ( "swh:1:rev:309cf2674ee7a0749978cf8265ab91a60aea0f7d", bytes.fromhex("01" + "04" + "309cf2674ee7a0749978cf8265ab91a60aea0f7d"), ), ( "swh:1:snp:c7c108084bc0bf3d81436bf980b46e98bd338453", bytes.fromhex("01" + "05" + "c7c108084bc0bf3d81436bf980b46e98bd338453"), ), ] def test_str_to_bytes(self): for (pid_str, pid_bytes) in self.pairs: self.assertEqual(str_to_bytes(pid_str), pid_bytes) def test_bytes_to_str(self): for (pid_str, pid_bytes) in self.pairs: self.assertEqual(bytes_to_str(pid_bytes), pid_str) def test_round_trip(self): for (pid_str, pid_bytes) in self.pairs: self.assertEqual(pid_str, bytes_to_str(str_to_bytes(pid_str))) self.assertEqual(pid_bytes, str_to_bytes(bytes_to_str(pid_bytes))) def gen_records(types=["cnt", "dir", "ori", "rel", "rev", "snp"], length=10000): """generate sequential PID/int records, suitable for filling int<->pid maps for testing swh-graph on-disk binary databases Args: types (list): list of PID types to be generated, specified as the corresponding 3-letter component in PIDs length (int): number of PIDs to generate *per type* Yields: pairs (pid, int) where pid is a textual PID and int its sequential integer identifier """ pos = 0 for t in sorted(types): for i in range(0, length): seq = format(pos, "x") # current position as hex string pid = "swh:1:{}:{}{}".format(t, "0" * (40 - len(seq)), seq) yield (pid, pos) pos += 1 # pairs PID/position in the sequence generated by :func:`gen_records` above MAP_PAIRS = [ ("swh:1:cnt:0000000000000000000000000000000000000000", 0), ("swh:1:cnt:000000000000000000000000000000000000002a", 42), ("swh:1:dir:0000000000000000000000000000000000002afc", 11004), ("swh:1:ori:00000000000000000000000000000000000056ce", 22222), ("swh:1:rel:0000000000000000000000000000000000008235", 33333), ("swh:1:rev:000000000000000000000000000000000000ad9c", 44444), ("swh:1:snp:000000000000000000000000000000000000ea5f", 59999), ] class TestPidToNodeMap(unittest.TestCase): @classmethod def setUpClass(cls): """create reasonably sized (~2 MB) PID->int map to test on-disk DB """ cls.tmpdir = tempfile.mkdtemp(prefix="swh.graph.test.") cls.fname = os.path.join(cls.tmpdir, "pid2int.bin") with open(cls.fname, "wb") as f: for (pid, i) in gen_records(length=10000): PidToNodeMap.write_record(f, pid, i) @classmethod def tearDownClass(cls): shutil.rmtree(cls.tmpdir) def setUp(self): self.map = PidToNodeMap(self.fname) def tearDown(self): self.map.close() def test_lookup(self): for (pid, pos) in MAP_PAIRS: self.assertEqual(self.map[pid], pos) def test_missing(self): with self.assertRaises(KeyError): self.map["swh:1:ori:0101010100000000000000000000000000000000"], with self.assertRaises(KeyError): self.map["swh:1:cnt:0101010100000000000000000000000000000000"], def test_type_error(self): with self.assertRaises(TypeError): self.map[42] with self.assertRaises(TypeError): self.map[1.2] def test_update(self): fname2 = self.fname + ".update" shutil.copy(self.fname, fname2) # fresh map copy map2 = PidToNodeMap(fname2, mode="rb+") for (pid, int) in islice(map2, 11): # update the first N items new_int = int + 42 map2[pid] = new_int self.assertEqual(map2[pid], new_int) # check updated value os.unlink(fname2) # tmpdir will be cleaned even if we don't reach this def test_iter_type(self): for t in PID_TYPES: first_20 = list(islice(self.map.iter_type(t), 20)) k = first_20[0][1] expected = [("swh:1:{}:{:040x}".format(t, i), i) for i in range(k, k + 20)] assert first_20 == expected def test_iter_prefix(self): for t in PID_TYPES: prefix = self.map.iter_prefix("swh:1:{}:00".format(t)) first_20 = list(islice(prefix, 20)) k = first_20[0][1] expected = [("swh:1:{}:{:040x}".format(t, i), i) for i in range(k, k + 20)] assert first_20 == expected class TestNodeToPidMap(unittest.TestCase): @classmethod def setUpClass(cls): """create reasonably sized (~1 MB) int->PID map to test on-disk DB """ cls.tmpdir = tempfile.mkdtemp(prefix="swh.graph.test.") cls.fname = os.path.join(cls.tmpdir, "int2pid.bin") with open(cls.fname, "wb") as f: for (pid, _i) in gen_records(length=10000): NodeToPidMap.write_record(f, pid) @classmethod def tearDownClass(cls): shutil.rmtree(cls.tmpdir) def setUp(self): self.map = NodeToPidMap(self.fname) def tearDown(self): self.map.close() def test_lookup(self): for (pid, pos) in MAP_PAIRS: self.assertEqual(self.map[pos], pid) def test_out_of_bounds(self): with self.assertRaises(IndexError): self.map[1000000] with self.assertRaises(IndexError): self.map[-1000000] def test_update(self): fname2 = self.fname + ".update" shutil.copy(self.fname, fname2) # fresh map copy map2 = NodeToPidMap(fname2, mode="rb+") for (int, pid) in islice(map2, 11): # update the first N items new_pid = pid.replace(":0", ":f") # mangle first hex digit map2[int] = new_pid self.assertEqual(map2[int], new_pid) # check updated value os.unlink(fname2) # tmpdir will be cleaned even if we don't reach this diff --git a/swh/graph/webgraph.py b/swh/graph/webgraph.py index c27aaac..6390014 100644 --- a/swh/graph/webgraph.py +++ b/swh/graph/webgraph.py @@ -1,226 +1,225 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information """WebGraph driver """ +from datetime import datetime +from enum import Enum import logging import os -import subprocess - -from enum import Enum -from datetime import datetime from pathlib import Path +import subprocess from typing import Dict, List, Set from swh.graph.config import check_config_compress class CompressionStep(Enum): MPH = 1 BV = 2 BV_OBL = 3 BFS = 4 PERMUTE = 5 PERMUTE_OBL = 6 STATS = 7 TRANSPOSE = 8 TRANSPOSE_OBL = 9 MAPS = 10 CLEAN_TMP = 11 def __str__(self): return self.name # full compression pipeline COMP_SEQ = list(CompressionStep) # Mapping from compression steps to shell commands implementing them. Commands # will be executed by the shell, so be careful with meta characters. They are # specified here as lists of tokens that will be joined together only for ease # of line splitting. In commands, {tokens} will be interpolated with # configuration values, see :func:`compress`. STEP_ARGV: Dict[CompressionStep, List[str]] = { CompressionStep.MPH: [ "{java}", "it.unimi.dsi.sux4j.mph.GOVMinimalPerfectHashFunction", "--temp-dir", "{tmp_dir}", "{out_dir}/{graph_name}.mph", "<( zstdcat {in_dir}/{graph_name}.nodes.csv.zst )", ], # use process substitution (and hence FIFO) above as MPH class load the # entire file in memory when reading from stdin CompressionStep.BV: [ "zstdcat", "{in_dir}/{graph_name}.edges.csv.zst", "|", "{java}", "it.unimi.dsi.big.webgraph.ScatteredArcsASCIIGraph", "--temp-dir", "{tmp_dir}", "--function", "{out_dir}/{graph_name}.mph", "{out_dir}/{graph_name}-bv", ], CompressionStep.BV_OBL: [ "{java}", "it.unimi.dsi.big.webgraph.BVGraph", "--list", "{out_dir}/{graph_name}-bv", ], CompressionStep.BFS: [ "{java}", "it.unimi.dsi.law.big.graph.BFS", "{out_dir}/{graph_name}-bv", "{out_dir}/{graph_name}.order", ], CompressionStep.PERMUTE: [ "{java}", "it.unimi.dsi.big.webgraph.Transform", "mapOffline", "{out_dir}/{graph_name}-bv", "{out_dir}/{graph_name}", "{out_dir}/{graph_name}.order", "{batch_size}", "{tmp_dir}", ], CompressionStep.PERMUTE_OBL: [ "{java}", "it.unimi.dsi.big.webgraph.BVGraph", "--list", "{out_dir}/{graph_name}", ], CompressionStep.STATS: [ "{java}", "it.unimi.dsi.big.webgraph.Stats", "{out_dir}/{graph_name}", ], CompressionStep.TRANSPOSE: [ "{java}", "it.unimi.dsi.big.webgraph.Transform", "transposeOffline", "{out_dir}/{graph_name}", "{out_dir}/{graph_name}-transposed", "{batch_size}", "{tmp_dir}", ], CompressionStep.TRANSPOSE_OBL: [ "{java}", "it.unimi.dsi.big.webgraph.BVGraph", "--list", "{out_dir}/{graph_name}-transposed", ], CompressionStep.MAPS: [ "zstdcat", "{in_dir}/{graph_name}.nodes.csv.zst", "|", "{java}", "org.softwareheritage.graph.maps.NodeMapBuilder", "{out_dir}/{graph_name}", "{tmp_dir}", ], CompressionStep.CLEAN_TMP: [ "rm", "-rf", "{out_dir}/{graph_name}-bv.graph", "{out_dir}/{graph_name}-bv.obl", "{out_dir}/{graph_name}-bv.offsets", "{tmp_dir}", ], } def do_step(step, conf): cmd = " ".join(STEP_ARGV[step]).format(**conf) cmd_env = os.environ.copy() cmd_env["JAVA_TOOL_OPTIONS"] = conf["java_tool_options"] cmd_env["CLASSPATH"] = conf["classpath"] logging.info(f"running: {cmd}") process = subprocess.Popen( ["/bin/bash", "-c", cmd], env=cmd_env, encoding="utf8", stdout=subprocess.PIPE, stderr=subprocess.STDOUT, ) with process.stdout as stdout: for line in stdout: logging.info(line.rstrip()) rc = process.wait() if rc != 0: raise RuntimeError( f"compression step {step} returned non-zero " f"exit code {rc}" ) else: return rc def compress( graph_name: str, in_dir: Path, out_dir: Path, steps: Set[CompressionStep] = set(COMP_SEQ), conf: Dict[str, str] = {}, ): """graph compression pipeline driver from nodes/edges files to compressed on-disk representation Args: graph_name: graph base name, relative to in_dir in_dir: input directory, where the uncompressed graph can be found out_dir: output directory, where the compressed graph will be stored steps: compression steps to run (default: all steps) conf: compression configuration, supporting the following keys (all are optional, so an empty configuration is fine and is the default) - batch_size: batch size for `WebGraph transformations