diff --git a/swh/graph/naive_client.py b/swh/graph/naive_client.py
--- a/swh/graph/naive_client.py
+++ b/swh/graph/naive_client.py
@@ -17,9 +17,10 @@
     Set,
     Tuple,
     TypeVar,
+    Union,
 )
 
-from swh.model.swhids import ExtendedSWHID, ValidationError
+from swh.model.swhids import CoreSWHID, ExtendedSWHID, ValidationError
 
 from .client import GraphArgumentException
 
@@ -29,6 +30,8 @@
 
 
 T = TypeVar("T", bound=Callable)
+# Node identifier forms accepted by NaiveClient/Graph; normalized via str().
+SWHIDlike = Union[CoreSWHID, ExtendedSWHID, str]
 
 
 def check_arguments(f: T) -> T:
@@ -107,7 +110,9 @@
     ['swh:1:rev:3333333333333333333333333333333333333333']
     """
 
-    def __init__(self, *, nodes: List[str], edges: List[Tuple[str, str]]):
+    def __init__(
+        self, *, nodes: List[SWHIDlike], edges: List[Tuple[SWHIDlike, SWHIDlike]]
+    ):
         self.graph = Graph(nodes, edges)
 
     def _check_swhid(self, swhid):
@@ -268,16 +273,21 @@
 
 
 class Graph:
-    def __init__(self, nodes: List[str], edges: List[Tuple[str, str]]):
-        self.nodes = nodes
+    def __init__(
+        self, nodes: List[SWHIDlike], edges: List[Tuple[SWHIDlike, SWHIDlike]]
+    ):
+        # All identifiers are stored as str, whatever SWHIDlike form was given.
+        self.nodes = [str(node) for node in nodes]
         self.forward_edges: Dict[str, List[str]] = {}
         self.backward_edges: Dict[str, List[str]] = {}
-        for node in nodes:
-            self.forward_edges[node] = []
-            self.backward_edges[node] = []
+        # Iterate the normalized list: each node is converted only once, and a
+        # one-shot iterable passed as `nodes` is not consumed twice.
+        for node in self.nodes:
+            self.forward_edges[node] = []
+            self.backward_edges[node] = []
         for (src, dst) in edges:
-            self.forward_edges[src].append(dst)
-            self.backward_edges[dst].append(src)
+            self.forward_edges[str(src)].append(str(dst))
+            self.backward_edges[str(dst)].append(str(src))
 
     def get_filtered_neighbors(
         self, src: str, edges_fmt: str, direction: str,