diff --git a/swh/dataset/exporters/edges.py b/swh/dataset/exporters/edges.py --- a/swh/dataset/exporters/edges.py +++ b/swh/dataset/exporters/edges.py @@ -9,20 +9,20 @@ import shlex import subprocess import tempfile +from typing import Tuple import uuid from swh.dataset.exporter import ExporterDispatch from swh.dataset.utils import ZSTFile, remove_pull_requests -from swh.model.hashutil import hash_to_bytes -from swh.model.identifiers import ExtendedObjectType, ExtendedSWHID, origin_identifier +from swh.model.hashutil import hash_to_bytes, hash_to_hex +from swh.model.identifiers import ExtendedObjectType, origin_identifier def swhid(object_type, object_id): - return str( - ExtendedSWHID( - object_type=ExtendedObjectType[object_type.upper()], object_id=object_id - ) - ) + # We use string interpolation here instead of using ExtendedSWHID to format, + # as building temporary ExtendedSWHID objects has a non-negligible impact + # on performance. + return f"swh:1:{object_type.value}:{hash_to_hex(object_id)}" class GraphEdgesExporter(ExporterDispatch): @@ -37,9 +37,9 @@ super().__init__(*args, **kwargs) self.writers = {} - def get_writers_for(self, obj_type: str): + def get_writers_for(self, obj_type: ExtendedObjectType): if obj_type not in self.writers: - dataset_path = self.export_path / obj_type + dataset_path = self.export_path / obj_type.name.lower() dataset_path.mkdir(exist_ok=True) unique_id = str(uuid.uuid4()) nodes_file = dataset_path / ("graph-{}.nodes.csv.zst".format(unique_id)) @@ -49,13 +49,13 @@ self.writers[obj_type] = (node_writer, edge_writer) return self.writers[obj_type] - def get_node_writer_for(self, obj_type: str): + def get_node_writer_for(self, obj_type: ExtendedObjectType): return self.get_writers_for(obj_type)[0] - def get_edge_writer_for(self, obj_type: str): + def get_edge_writer_for(self, obj_type: ExtendedObjectType): return self.get_writers_for(obj_type)[1] - def write_node(self, node): + def write_node(self, node: Tuple[ExtendedObjectType, 
bytes]): node_type, node_id = node if node_id is None: return @@ -63,7 +63,13 @@ node_writer = self.get_node_writer_for(node_type) node_writer.write("{}\n".format(node_swhid)) - def write_edge(self, src, dst, *, labels=None): + def write_edge( + self, + src: Tuple[ExtendedObjectType, bytes], + dst: Tuple[ExtendedObjectType, bytes], + *, + labels=None, + ): src_type, src_id = src dst_type, dst_id = dst if src_id is None or dst_id is None: @@ -76,17 +82,20 @@ def process_origin(self, origin): origin_id = hash_to_bytes(origin_identifier({"url": origin["url"]})) - self.write_node(("origin", origin_id)) + self.write_node((ExtendedObjectType.ORIGIN, origin_id)) def process_origin_visit_status(self, visit_status): origin_id = hash_to_bytes(origin_identifier({"url": visit_status["origin"]})) - self.write_edge(("origin", origin_id), ("snapshot", visit_status["snapshot"])) + self.write_edge( + (ExtendedObjectType.ORIGIN, origin_id), + (ExtendedObjectType.SNAPSHOT, visit_status["snapshot"]), + ) def process_snapshot(self, snapshot): if self.config.get("remove_pull_requests"): remove_pull_requests(snapshot) - self.write_node(("snapshot", snapshot["id"])) + self.write_node((ExtendedObjectType.SNAPSHOT, snapshot["id"])) for branch_name, branch in snapshot["branches"].items(): original_branch_name = branch_name while branch and branch.get("target_type") == "alias": @@ -95,41 +104,46 @@ if branch is None or not branch_name: continue self.write_edge( - ("snapshot", snapshot["id"]), - (branch["target_type"], branch["target"]), + (ExtendedObjectType.SNAPSHOT, snapshot["id"]), + (ExtendedObjectType[branch["target_type"].upper()], branch["target"]), labels=[base64.b64encode(original_branch_name).decode(),], ) def process_release(self, release): - self.write_node(("release", release["id"])) + self.write_node((ExtendedObjectType.RELEASE, release["id"])) self.write_edge( - ("release", release["id"]), (release["target_type"], release["target"]) + (ExtendedObjectType.RELEASE, 
release["id"]), + (ExtendedObjectType[release["target_type"].upper()], release["target"]), ) def process_revision(self, revision): - self.write_node(("revision", revision["id"])) + self.write_node((ExtendedObjectType.REVISION, revision["id"])) self.write_edge( - ("revision", revision["id"]), ("directory", revision["directory"]) + (ExtendedObjectType.REVISION, revision["id"]), + (ExtendedObjectType.DIRECTORY, revision["directory"]), ) for parent in revision["parents"]: - self.write_edge(("revision", revision["id"]), ("revision", parent)) + self.write_edge( + (ExtendedObjectType.REVISION, revision["id"]), + (ExtendedObjectType.REVISION, parent), + ) def process_directory(self, directory): - self.write_node(("directory", directory["id"])) + self.write_node((ExtendedObjectType.DIRECTORY, directory["id"])) for entry in directory["entries"]: entry_type_mapping = { - "file": "content", - "dir": "directory", - "rev": "revision", + "file": ExtendedObjectType.CONTENT, + "dir": ExtendedObjectType.DIRECTORY, + "rev": ExtendedObjectType.REVISION, } self.write_edge( - ("directory", directory["id"]), + (ExtendedObjectType.DIRECTORY, directory["id"]), (entry_type_mapping[entry["type"]], entry["target"]), labels=[base64.b64encode(entry["name"]).decode(), str(entry["perms"])], ) def process_content(self, content): - self.write_node(("content", content["sha1_git"])) + self.write_node((ExtendedObjectType.CONTENT, content["sha1_git"])) def sort_graph_nodes(export_path, config):