diff --git a/swh/graph/client.py b/swh/graph/client.py
index 26f05ed..7f32546 100644
--- a/swh/graph/client.py
+++ b/swh/graph/client.py
@@ -1,111 +1,110 @@
# Copyright (C) 2019 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import json
from swh.core.api import RPCClient
class GraphAPIError(Exception):
    """Graph API Error"""

    def __str__(self):
        template = "An unexpected error occurred in the Graph backend: {}"
        return template.format(self.args)
class RemoteGraphClient(RPCClient):
    """Client to the Software Heritage Graph.

    Endpoint methods that stream results return generators of lines (or of
    values decoded from lines); `count_*` and `stats` return plain values.
    """

    def __init__(self, url, timeout=None):
        super().__init__(api_exception=GraphAPIError, url=url, timeout=timeout)

    def raw_verb_lines(self, verb, endpoint, **kwargs):
        """Perform a streaming HTTP request and yield the response body
        line by line, decoded and stripped of leading newlines.
        """
        response = self.raw_verb(verb, endpoint, stream=True, **kwargs)
        self.raise_for_status(response)
        for line in response.iter_lines():
            yield line.decode().lstrip("\n")

    def get_lines(self, endpoint, **kwargs):
        """Streaming GET on ``endpoint``, yielding decoded lines."""
        yield from self.raw_verb_lines("get", endpoint, **kwargs)

    # Web API endpoints

    def stats(self):
        """Return graph statistics (decoded from the JSON response)."""
        return self.get("stats")

    def leaves(self, src, edges="*", direction="forward"):
        """Yield the leaves reachable from ``src``, one identifier per line."""
        return self.get_lines(
            "leaves/{}".format(src), params={"edges": edges, "direction": direction}
        )

    def neighbors(self, src, edges="*", direction="forward"):
        """Yield the neighbors of ``src``, one identifier per line."""
        return self.get_lines(
            "neighbors/{}".format(src), params={"edges": edges, "direction": direction}
        )

    def visit_nodes(self, src, edges="*", direction="forward"):
        """Yield the nodes visited by a traversal starting from ``src``."""
        return self.get_lines(
            "visit/nodes/{}".format(src),
            params={"edges": edges, "direction": direction},
        )

    def visit_edges(self, src, edges="*", direction="forward"):
        """Yield the edges visited by a traversal starting from ``src``.

        Each response line is a whitespace-separated "src dst" pair, yielded
        here as a tuple.
        """
        # NOTE: a stray debug print() of each edge was removed here.
        for edge in self.get_lines(
            "visit/edges/{}".format(src),
            params={"edges": edges, "direction": direction},
        ):
            yield tuple(edge.split())

    def visit_paths(self, src, edges="*", direction="forward"):
        """Yield the paths visited by a traversal starting from ``src``.

        Each response line is a JSON-encoded path, decoded before yielding.
        """
        def decode_path_wrapper(it):
            for e in it:
                yield json.loads(e)

        return decode_path_wrapper(
            self.get_lines(
                "visit/paths/{}".format(src),
                params={"edges": edges, "direction": direction},
            )
        )

    def walk(
        self, src, dst, edges="*", traversal="dfs", direction="forward", limit=None
    ):
        """Yield the nodes along a walk from ``src`` towards ``dst``."""
        endpoint = "walk/{}/{}"
        return self.get_lines(
            endpoint.format(src, dst),
            params={
                "edges": edges,
                "traversal": traversal,
                "direction": direction,
                "limit": limit,
            },
        )

    def random_walk(self, src, dst, edges="*", direction="forward", limit=None):
        """Yield the nodes along a random walk from ``src`` towards ``dst``."""
        endpoint = "randomwalk/{}/{}"
        return self.get_lines(
            endpoint.format(src, dst),
            params={"edges": edges, "direction": direction, "limit": limit},
        )

    def count_leaves(self, src, edges="*", direction="forward"):
        """Return the number of leaves reachable from ``src``."""
        return self.get(
            "leaves/count/{}".format(src),
            params={"edges": edges, "direction": direction},
        )

    def count_neighbors(self, src, edges="*", direction="forward"):
        """Return the number of neighbors of ``src``."""
        return self.get(
            "neighbors/count/{}".format(src),
            params={"edges": edges, "direction": direction},
        )

    def count_visit_nodes(self, src, edges="*", direction="forward"):
        """Return the number of nodes visited by a traversal from ``src``."""
        return self.get(
            "visit/nodes/count/{}".format(src),
            params={"edges": edges, "direction": direction},
        )
diff --git a/swh/graph/config.py b/swh/graph/config.py
index a9146d1..93c17c4 100644
--- a/swh/graph/config.py
+++ b/swh/graph/config.py
@@ -1,111 +1,110 @@
# Copyright (C) 2019 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import logging
import psutil
import sys
from pathlib import Path
def find_graph_jar():
    """Find swh-graph.jar, containing the Java part of swh-graph.

    Look both in development directories and installed data (for in-production
    deployments that fetched the JAR from PyPI).

    Returns:
        the path of the first matching JAR, as a string.

    Raises:
        RuntimeError: if no JAR could be found in any candidate directory.
    """
    swh_graph_root = Path(__file__).parents[2]
    try_paths = [
        swh_graph_root / "java/target/",
        Path(sys.prefix) / "share/swh-graph/",
        Path(sys.prefix) / "local/share/swh-graph/",
    ]
    for path in try_paths:
        glob = list(path.glob("swh-graph-*.jar"))
        if glob:
            if len(glob) > 1:
                # logging.warn() is a deprecated alias; use warning()
                logging.warning(
                    "found multiple swh-graph JARs, arbitrarily picking one"
                )
            # lazy %-formatting: the message is built only if INFO is enabled
            logging.info("using swh-graph JAR: %s", glob[0])
            return str(glob[0])
    raise RuntimeError("swh-graph JAR not found. Have you run `make java`?")
def check_config(conf):
    """Validate the configuration and fill in defaults.

    Operates on (and returns) a copy of ``conf``; the caller's dict is left
    untouched.
    """
    conf = dict(conf)
    conf.setdefault("batch_size", "1000000000")  # 1 billion
    if "max_ram" not in conf:
        conf["max_ram"] = str(psutil.virtual_memory().total)
    if "java_tool_options" not in conf:
        default_opts = (
            "-Xmx{max_ram}",
            "-XX:PretenureSizeThreshold=512M",
            "-XX:MaxNewSize=4G",
            "-XX:+UseLargePages",
            "-XX:+UseTransparentHugePages",
            "-XX:+UseNUMA",
            "-XX:+UseTLAB",
            "-XX:+ResizeTLAB",
        )
        conf["java_tool_options"] = " ".join(default_opts)
    # expand the {max_ram} placeholder, whether defaulted or user-supplied
    conf["java_tool_options"] = conf["java_tool_options"].format(
        max_ram=conf["max_ram"]
    )
    conf.setdefault("java", "java")
    if "classpath" not in conf:
        conf["classpath"] = find_graph_jar()
    return conf
def check_config_compress(config, graph_name, in_dir, out_dir):
"""check compression-specific configuration and initialize its execution
environment.
"""
conf = check_config(config)
conf["graph_name"] = graph_name
conf["in_dir"] = str(in_dir)
conf["out_dir"] = str(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
if "tmp_dir" not in conf:
tmp_dir = out_dir / "tmp"
conf["tmp_dir"] = str(tmp_dir)
else:
tmp_dir = Path(conf["tmp_dir"])
tmp_dir.mkdir(parents=True, exist_ok=True)
if "logback" not in conf:
logback_confpath = tmp_dir / "logback.xml"
with open(logback_confpath, "w") as conffile:
conffile.write(
"""
You have reached the Software Heritage graph API server.
See its API documentation for more information.
""", ) async def stats(request): stats = request.app["backend"].stats() return aiohttp.web.Response(body=stats, content_type="application/json") def get_direction(request): """validate HTTP query parameter `direction`""" s = request.query.get("direction", "forward") if s not in ("forward", "backward"): raise aiohttp.web.HTTPBadRequest(body=f"invalid direction: {s}") return s def get_edges(request): """validate HTTP query parameter `edges`, i.e., edge restrictions""" s = request.query.get("edges", "*") if any( [ node_type != "*" and node_type not in PID_TYPES for edge in s.split(":") for node_type in edge.split(",", maxsplit=1) ] ): raise aiohttp.web.HTTPBadRequest(body=f"invalid edge restriction: {s}") return s def get_traversal(request): """validate HTTP query parameter `traversal`, i.e., visit order""" s = request.query.get("traversal", "dfs") if s not in ("bfs", "dfs"): raise aiohttp.web.HTTPBadRequest(body=f"invalid traversal order: {s}") return s def get_limit(request): """validate HTTP query parameter `limit`, i.e., number of results""" s = request.query.get("limit", "0") try: return int(s) except ValueError: raise aiohttp.web.HTTPBadRequest(body=f"invalid limit value: {s}") def node_of_pid(pid, backend): """lookup a PID in a pid2node map, failing in an HTTP-nice way if needed""" try: return backend.pid2node[pid] except KeyError: raise aiohttp.web.HTTPNotFound(body=f"PID not found: {pid}") except ValidationError: raise aiohttp.web.HTTPBadRequest(body=f"malformed PID: {pid}") def pid_of_node(node, backend): """lookup a node in a node2pid map, failing in an HTTP-nice way if needed """ try: return backend.node2pid[node] except KeyError: raise aiohttp.web.HTTPInternalServerError( body=f"reverse lookup failed for node id: {node}" ) def get_simple_traversal_handler(ttype): async def simple_traversal(request): backend = request.app["backend"] src = request.match_info["src"] edges = get_edges(request) direction = get_direction(request) src_node = node_of_pid(src, 
backend) async with stream_response(request) as response: async for res_node in backend.simple_traversal( ttype, direction, edges, src_node ): res_pid = pid_of_node(res_node, backend) await response.write("{}\n".format(res_pid).encode()) return response return simple_traversal def get_walk_handler(random=False): async def walk(request): backend = request.app["backend"] src = request.match_info["src"] dst = request.match_info["dst"] edges = get_edges(request) direction = get_direction(request) algo = get_traversal(request) limit = get_limit(request) src_node = node_of_pid(src, backend) if dst not in PID_TYPES: dst = node_of_pid(dst, backend) async with stream_response(request) as response: if random: it = backend.random_walk( direction, edges, RANDOM_RETRIES, src_node, dst ) else: it = backend.walk(direction, edges, algo, src_node, dst) if limit < 0: queue = deque(maxlen=-limit) async for res_node in it: res_pid = pid_of_node(res_node, backend) queue.append("{}\n".format(res_pid).encode()) while queue: await response.write(queue.popleft()) else: count = 0 async for res_node in it: if limit == 0 or count < limit: res_pid = pid_of_node(res_node, backend) await response.write("{}\n".format(res_pid).encode()) count += 1 else: break return response return walk async def visit_paths(request): backend = request.app["backend"] src = request.match_info["src"] edges = get_edges(request) direction = get_direction(request) src_node = node_of_pid(src, backend) it = backend.visit_paths(direction, edges, src_node) async with stream_response( request, content_type="application/x-ndjson" ) as response: async for res_path in it: res_path_pid = [pid_of_node(n, backend) for n in res_path] line = json.dumps(res_path_pid) await response.write("{}\n".format(line).encode()) return response async def visit_edges(request): backend = request.app["backend"] src = request.match_info["src"] edges = get_edges(request) direction = get_direction(request) src_node = node_of_pid(src, backend) it = 
backend.visit_edges(direction, edges, src_node) - print(it) async with stream_response(request) as response: async for (res_src, res_dst) in it: res_src_pid = pid_of_node(res_src, backend) res_dst_pid = pid_of_node(res_dst, backend) await response.write("{} {}\n".format(res_src_pid, res_dst_pid).encode()) return response def get_count_handler(ttype): async def count(request): loop = asyncio.get_event_loop() backend = request.app["backend"] src = request.match_info["src"] edges = get_edges(request) direction = get_direction(request) src_node = node_of_pid(src, backend) cnt = await loop.run_in_executor( None, backend.count, ttype, direction, edges, src_node ) return aiohttp.web.Response(body=str(cnt), content_type="application/json") return count def make_app(backend, **kwargs): app = RPCServerApp(**kwargs) app.router.add_get("/", index) app.router.add_get("/graph", index) app.router.add_get("/graph/stats", stats) app.router.add_get("/graph/leaves/{src}", get_simple_traversal_handler("leaves")) app.router.add_get( "/graph/neighbors/{src}", get_simple_traversal_handler("neighbors") ) app.router.add_get( "/graph/visit/nodes/{src}", get_simple_traversal_handler("visit_nodes") ) app.router.add_get("/graph/visit/edges/{src}", visit_edges) app.router.add_get("/graph/visit/paths/{src}", visit_paths) # temporarily disabled in wait of a proper fix for T1969 # app.router.add_get('/graph/walk/{src}/{dst}', # get_walk_handler(random=False)) # app.router.add_get('/graph/walk/last/{src}/{dst}', # get_walk_handler(random=False, last=True)) app.router.add_get("/graph/randomwalk/{src}/{dst}", get_walk_handler(random=True)) app.router.add_get("/graph/neighbors/count/{src}", get_count_handler("neighbors")) app.router.add_get("/graph/leaves/count/{src}", get_count_handler("leaves")) app.router.add_get( "/graph/visit/nodes/count/{src}", get_count_handler("visit_nodes") ) app["backend"] = backend return app