diff --git a/swh/graph/client.py b/swh/graph/client.py --- a/swh/graph/client.py +++ b/swh/graph/client.py @@ -16,7 +16,7 @@ class GraphArgumentException(Exception): - def __init__(self, *args, response): + def __init__(self, *args, response=None): super().__init__(*args) self.response = response diff --git a/swh/graph/standalone_client.py b/swh/graph/standalone_client.py --- a/swh/graph/standalone_client.py +++ b/swh/graph/standalone_client.py @@ -3,43 +3,45 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import collections +import functools +import inspect +import re import statistics -from typing import AsyncIterator, Dict, Iterable, Iterator, List, Set, Tuple +from typing import Callable, Dict, Iterator, List, Optional, Set, Tuple, TypeVar -from swh.model.identifiers import ExtendedSWHID +from swh.model.identifiers import ExtendedSWHID, ValidationError +from .client import GraphArgumentException -class Swhid2NodeDict(collections.UserDict): - def iter_type(self, swhid_type: str) -> Iterator[Tuple[str, str]]: - prefix = "swh:1:{}:".format(swhid_type) - for (swhid, node) in self.items(): - if swhid.startswith(prefix): - yield (swhid, node) +_NODE_TYPES = "ori|snp|rel|rev|dir|cnt" +EDGES_RE = re.compile(fr"(\*|{_NODE_TYPES}):(\*|{_NODE_TYPES})") - def __getitem__(self, swhid): - ExtendedSWHID.from_string(swhid) # Raises ValidationError, caught by server - return self.data[swhid] +T = TypeVar("T", bound=Callable) -class Node2SwhidDict(collections.UserDict): - def __getitem__(self, key): - try: - return self.data[key] - except KeyError: - # Pretend to be a list - raise IndexError(key) from None +def check_arguments(f: T) -> T: + """Decorator for generic argument checking for methods of StandaloneClient. + Checks ``src`` is a valid and known SWHID, and ``edges`` has the right format.""" + signature = inspect.signature(f) + + @functools.wraps(f) + def newf(*args, **kwargs): + bound_args = signature.bind(*args, **kwargs) + self = bound_args.arguments["self"] -class JavaIterator: - def __init__(self, iterator: Iterable): - self.iterator = iter(iterator) + src = bound_args.arguments.get("src") + if src: + self._check_swhid(src) - def nextLong(self): - return next(self.iterator) + edges = bound_args.arguments.get("edges") + if edges: + if edges != "*" and not EDGES_RE.match(edges): + raise GraphArgumentException(f"invalid edge restriction: {edges}") - def __getattr__(self, name): - return getattr(self.iterator, name) + return f(*args, **kwargs) + + return newf # type: ignore class StandaloneClient: @@ -53,17 +55,25 @@ def __init__(self, *, nodes: List[str], edges: List[Tuple[str, str]]): self.graph = Graph(nodes, edges) + def _check_swhid(self, swhid): + try: + ExtendedSWHID.from_string(swhid) + except ValidationError as e: + raise GraphArgumentException(*e.args) from None + if swhid not in self.graph.nodes: + raise GraphArgumentException(f"SWHID not found: {swhid}") + def stats(self) -> Dict: return { "counts": { "nodes": len(self.graph.nodes), - "edges": len(self.graph.forward_edges), + "edges": sum(map(len, self.graph.forward_edges.values())), }, "ratios": { "compression": 1.0, - "bits_per_edge": 100, - "bits_per_node": 100, - "avg_locality": 0, + "bits_per_edge": 100.0, + "bits_per_node": 100.0, + "avg_locality": 0.0, }, "indegree": { "min": min(map(len, self.graph.backward_edges.values())), @@ -77,68 +87,107 @@ }, } - def count_neighbors(self, ttype, direction, edges_fmt, src) -> int: - return len(self.graph.get_filtered_neighbors(direction, edges_fmt, src)) - - def count_visit_nodes(self, ttype, direction, edges_fmt, src) -> int: - return len(self.graph.get_subgraph(direction, edges_fmt, src)) - - def count_leaves(self, ttype, direction, edges_fmt, src) -> int: - return len(list(self.leaves(direction, edges_fmt, src))) - - async def simple_traversal(self, ttype, direction, edges_fmt, src, max_edges): - # TODO: max_edges? - if ttype == "visit_nodes": - for node in self.graph.get_subgraph(direction, edges_fmt, src): - yield node - elif ttype == "leaves": - for node in self.leaves(direction, edges_fmt, src): - yield node - else: - assert False, f"unknown ttype {ttype!r}" - - def leaves(self, direction, edges_fmt, src) -> Iterator[str]: + @check_arguments + def leaves( + self, src: str, edges: str = "*", direction: str = "forward", max_edges: int = 0 + ) -> Iterator[str]: + # TODO: max_edges yield from [ node - for node in self.graph.get_subgraph(direction, edges_fmt, src) - if not self.graph.get_filtered_neighbors(direction, edges_fmt, node) + for node in self.graph.get_subgraph(src, edges, direction) + if not self.graph.get_filtered_neighbors(node, edges, direction) ] - async def walk(self, direction, edges_fmt, algo, src, dst) -> AsyncIterator[str]: + @check_arguments + def neighbors( + self, src: str, edges: str = "*", direction: str = "forward", max_edges: int = 0 + ) -> Iterator[str]: + # TODO: max_edges + yield from self.graph.get_filtered_neighbors(src, edges, direction) + + @check_arguments + def visit_nodes( + self, src: str, edges: str = "*", direction: str = "forward", max_edges: int = 0 + ) -> Iterator[str]: + # TODO: max_edges + yield from self.graph.get_subgraph(src, edges, direction) + + @check_arguments + def visit_edges( + self, src: str, edges: str = "*", direction: str = "forward", max_edges: int = 0 + ) -> Iterator[Tuple[str, str]]: + if max_edges == 0: + max_edges = None # type: ignore + else: + max_edges -= 1 + yield from list(self.graph.iter_edges_dfs(direction, edges, src))[:max_edges] + + @check_arguments + def visit_paths( + self, src: str, edges: str = "*", direction: str = "forward", max_edges: int = 0 + ) -> Iterator[List[str]]: + # TODO: max_edges + for path in self.graph.iter_paths_dfs(direction, edges, src): + if path[-1] in self.leaves(src, edges, direction): + yield list(path) + + @check_arguments + def walk( + self, + src: str, + dst: str, + edges: str = "*", + traversal: str = "dfs", + direction: str = "forward", + limit: Optional[int] = None, + ) -> Iterator[str]: # TODO: implement algo="bfs" + # TODO: limit + match_path: Callable[[str], bool] if ":" in dst: match_path = dst.__eq__ + self._check_swhid(dst) else: match_path = lambda node: node.startswith(f"swh:1:{dst}:") # noqa - for path in self.graph.iter_paths_dfs(direction, edges_fmt, src): + for path in self.graph.iter_paths_dfs(direction, edges, src): if match_path(path[-1]): - for node in path: - yield node - - async def random_walk( - self, direction, edges_fmt, retries, src, dst - ) -> AsyncIterator[str]: - async for node in self.walk(direction, edges_fmt, "dfs", src, dst): - yield node - - async def visit_paths( - self, direction, edges_fmt, src, max_edges - ) -> AsyncIterator[List[str]]: - # TODO: max_edges? - for path in self.graph.iter_paths_dfs(direction, edges_fmt, src): - if path[-1] in self.leaves(direction, edges_fmt, src): - yield list(path) - - async def visit_edges( - self, direction, edges_fmt, src, max_edges - ) -> AsyncIterator[Tuple[str, str]]: - if max_edges == 0: - max_edges = None - else: - max_edges -= 1 - edges = list(self.graph.iter_edges_dfs(direction, edges_fmt, src)) - for (from_, to) in edges[:max_edges]: - yield (from_, to) + if not limit: + # 0 or None + yield from path + elif limit > 0: + yield from path[0:limit] + else: + yield from path[limit:] + + @check_arguments + def random_walk( + self, + src: str, + dst: str, + edges: str = "*", + direction: str = "forward", + limit: Optional[int] = None, + ): + # TODO: limit + yield from self.walk(src, dst, edges, "dfs", direction, limit) + + @check_arguments + def count_leaves( + self, src: str, edges: str = "*", direction: str = "forward" + ) -> int: + return len(list(self.leaves(src, edges, direction))) + + @check_arguments + def count_neighbors( + self, src: str, edges: str = "*", direction: str = "forward" + ) -> int: + return len(self.graph.get_filtered_neighbors(src, edges, direction)) + + @check_arguments + def count_visit_nodes( + self, src: str, edges: str = "*", direction: str = "forward" + ) -> int: + return len(self.graph.get_subgraph(src, edges, direction)) class Graph: @@ -146,34 +195,22 @@ self.nodes = nodes self.forward_edges: Dict[str, List[str]] = {} self.backward_edges: Dict[str, List[str]] = {} + for node in nodes: + self.forward_edges[node] = [] + self.backward_edges[node] = [] for (src, dst) in edges: - self.forward_edges.setdefault(src, []).append(dst) - self.backward_edges.setdefault(dst, []).append(src) - - def numNodes(self) -> int: - return len(self.nodes) - - def successors(self, node: str) -> Iterator[str]: - return JavaIterator(self.forward_edges[node]) - - def outdegree(self, node: str) -> int: - return len(self.forward_edges[node]) - - def predecessors(self, node: str) -> Iterator[str]: - return JavaIterator(self.backward_edges[node]) - - def indegree(self, node: str) -> int: - return len(self.backward_edges[node]) + self.forward_edges[src].append(dst) + self.backward_edges[dst].append(src) def get_filtered_neighbors( - self, direction: str, edges_fmt: str, src: str + self, src: str, edges_fmt: str, direction: str, ) -> Set[str]: if direction == "forward": edges = self.forward_edges elif direction == "backward": edges = self.backward_edges else: - assert False, f"unknown direction {direction!r}" + raise GraphArgumentException(f"invalid direction: {direction}") neighbors = edges.get(src, []) @@ -194,13 +231,13 @@ ) return filtered_neighbors - def get_subgraph(self, direction: str, edges_fmt: str, src: str) -> Set[str]: + def get_subgraph(self, src: str, edges_fmt: str, direction: str) -> Set[str]: seen = set() to_visit = {src} while to_visit: node = to_visit.pop() seen.add(node) - neighbors = set(self.get_filtered_neighbors(direction, edges_fmt, node)) + neighbors = set(self.get_filtered_neighbors(node, edges_fmt, direction)) new_nodes = neighbors - seen to_visit.update(new_nodes) @@ -214,7 +251,7 @@ def iter_edges_dfs( self, direction: str, edges_fmt: str, src: str - ) -> Iterator[Tuple[str, ...]]: + ) -> Iterator[Tuple[str, str]]: for (path, node) in DfsSubgraphIterator(self, direction, edges_fmt, src): if len(path) > 0: yield (path[-1], node) @@ -248,7 +285,7 @@ if node not in self.seen: neighbors = self.graph.get_filtered_neighbors( - self.direction, self.edges_fmt, node + node, self.edges_fmt, self.direction ) # We want to visit the first neighbor first, and to_visit is a stack; diff --git a/swh/graph/tests/conftest.py b/swh/graph/tests/conftest.py --- a/swh/graph/tests/conftest.py +++ b/swh/graph/tests/conftest.py @@ -1,3 +1,9 @@ +# Copyright (C) 2019-2021 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import csv import multiprocessing from pathlib import Path @@ -8,6 +14,7 @@ from swh.graph.client import RemoteGraphClient from swh.graph.graph import load as graph_load from swh.graph.server.app import make_app +from swh.graph.standalone_client import StandaloneClient SWH_GRAPH_TESTS_ROOT = Path(__file__).parents[0] TEST_GRAPH_PATH = SWH_GRAPH_TESTS_ROOT / "dataset/output/example" @@ -33,16 +40,23 @@ self.q.put(e) -@pytest.fixture(scope="module") -def graph_client(): - queue = multiprocessing.Queue() - server = GraphServerProcess(queue) - server.start() - res = queue.get() - if isinstance(res, Exception): - raise res - yield RemoteGraphClient(str(res)) - server.terminate() +@pytest.fixture(scope="module", params=["remote", "standalone"]) +def graph_client(request): + if request.param == "remote": + queue = multiprocessing.Queue() + server = GraphServerProcess(queue) + server.start() + res = queue.get() + if isinstance(res, Exception): + raise res + yield RemoteGraphClient(str(res)) + server.terminate() + else: + with open(SWH_GRAPH_TESTS_ROOT / "dataset/example.nodes.csv") as fd: + nodes = [node for (node,) in csv.reader(fd, delimiter=" ")] + with open(SWH_GRAPH_TESTS_ROOT / "dataset/example.edges.csv") as fd: + edges = list(csv.reader(fd, delimiter=" ")) + yield StandaloneClient(nodes=nodes, edges=edges) @pytest.fixture(scope="module") diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py --- a/swh/graph/tests/test_api_client.py +++ b/swh/graph/tests/test_api_client.py @@ -291,13 +291,15 @@ def test_param_validation(graph_client): with raises(GraphArgumentException) as exc_info: # SWHID not found list(graph_client.leaves("swh:1:ori:fff0000000000000000000000000000000000021")) - assert exc_info.value.response.status_code == 404 + if exc_info.value.response: + assert exc_info.value.response.status_code == 404 with raises(GraphArgumentException) as exc_info: # malformed SWHID list( graph_client.neighbors("swh:1:ori:fff000000zzzzzz0000000000000000000000021") ) - assert exc_info.value.response.status_code == 400 + if exc_info.value.response: + assert exc_info.value.response.status_code == 400 with raises(GraphArgumentException) as exc_info: # malformed edge specificaiton list( @@ -307,7 +309,8 @@ direction="backward", ) ) - assert exc_info.value.response.status_code == 400 + if exc_info.value.response: + assert exc_info.value.response.status_code == 400 with raises(GraphArgumentException) as exc_info: # malformed direction list( @@ -317,7 +320,8 @@ direction="notadirection", ) ) - assert exc_info.value.response.status_code == 400 + if exc_info.value.response: + assert exc_info.value.response.status_code == 400 @pytest.mark.skip(reason="currently disabled due to T1969")