diff --git a/swh/dataset/graph.py b/swh/dataset/graph.py
--- a/swh/dataset/graph.py
+++ b/swh/dataset/graph.py
@@ -3,6 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+import contextlib
 import functools
 import os
 import os.path
@@ -13,12 +14,12 @@
 import uuid
 
 from swh.dataset.exporter import ParallelExporter
-from swh.dataset.utils import ZSTFile
+from swh.dataset.utils import SQLiteSet, ZSTFile
 from swh.model.identifiers import origin_identifier, persistent_identifier
 from swh.storage.fixer import fix_objects
 
 
-def process_messages(messages, config, node_writer, edge_writer):
+def process_messages(messages, config, node_writer, edge_writer, node_set):
     """
     Args:
         messages: A sequence of messages to process
@@ -47,10 +48,14 @@
 
     for visit in messages.get("origin_visit", []):
         origin_id = origin_identifier({"url": visit["origin"]})
+        if not node_set.add(origin_id):
+            continue
         write_node(("origin", origin_id))
         write_edge(("origin", origin_id), ("snapshot", visit["snapshot"]))
 
     for snapshot in messages.get("snapshot", []):
+        if not node_set.add(snapshot["id"]):
+            continue
         write_node(("snapshot", snapshot["id"]))
         for branch_name, branch in snapshot["branches"].items():
             while branch and branch.get("target_type") == "alias":
@@ -68,18 +73,24 @@
             )
 
     for release in messages.get("release", []):
+        if not node_set.add(release["id"]):
+            continue
         write_node(("release", release["id"]))
         write_edge(
             ("release", release["id"]), (release["target_type"], release["target"])
         )
 
     for revision in messages.get("revision", []):
+        if not node_set.add(revision["id"]):
+            continue
         write_node(("revision", revision["id"]))
         write_edge(("revision", revision["id"]), ("directory", revision["directory"]))
         for parent in revision["parents"]:
             write_edge(("revision", revision["id"]), ("revision", parent))
 
     for directory in messages.get("directory", []):
+        if not node_set.add(directory["id"]):
+            continue
         write_node(("directory", directory["id"]))
         for entry in directory["entries"]:
             entry_type_mapping = {
@@ -93,6 +104,8 @@
             )
 
     for content in messages.get("content", []):
+        if not node_set.add(content["sha1_git"]):
+            continue
         write_node(("content", content["sha1_git"]))
 
 
@@ -107,17 +120,22 @@
     def export_worker(self, export_path, **kwargs):
         dataset_path = pathlib.Path(export_path)
         dataset_path.mkdir(exist_ok=True, parents=True)
-        nodes_file = dataset_path / ("graph-{}.nodes.csv.zst".format(str(uuid.uuid4())))
-        edges_file = dataset_path / ("graph-{}.edges.csv.zst".format(str(uuid.uuid4())))
+        unique_id = str(uuid.uuid4())
+        nodes_file = dataset_path / ("graph-{}.nodes.csv.zst".format(unique_id))
+        edges_file = dataset_path / ("graph-{}.edges.csv.zst".format(unique_id))
+        node_set_file = dataset_path / (".set-nodes-{}.sqlite3".format(unique_id))
+
+        with contextlib.ExitStack() as stack:
+            nodes_writer = stack.enter_context(ZSTFile(nodes_file, "w"))
+            edges_writer = stack.enter_context(ZSTFile(edges_file, "w"))
+            node_set = stack.enter_context(SQLiteSet(node_set_file))
 
-        with ZSTFile(nodes_file, "w") as nodes_writer, ZSTFile(
-            edges_file, "w"
-        ) as edges_writer:
             process_fn = functools.partial(
                 process_messages,
                 config=self.config,
-                nodes_writer=nodes_writer,
-                edges_writer=edges_writer,
+                node_writer=nodes_writer,
+                edge_writer=edges_writer,
+                node_set=node_set,
             )
             self.process(process_fn, **kwargs)
 
diff --git a/swh/dataset/test/test_graph.py b/swh/dataset/test/test_graph.py
--- a/swh/dataset/test/test_graph.py
+++ b/swh/dataset/test/test_graph.py
@@ -68,13 +68,19 @@
 
 @pytest.fixture
 def exporter():
-    def wrapped(messages, config=None) -> Tuple[Mock, Mock]:
+    def wrapped(messages, config=None, nodes_already_there=False) -> Tuple[Mock, Mock]:
         if config is None:
             config = {}
         node_writer = Mock()
         edge_writer = Mock()
+        node_set = Mock()
+        node_set.add.return_value = not nodes_already_there
         process_messages(
-            messages, config=config, node_writer=node_writer, edge_writer=edge_writer,
+            messages,
+            config=config,
+            node_writer=node_writer,
+            edge_writer=edge_writer,
+            node_set=node_set,
         )
         return node_writer.write, edge_writer.write
 
@@ -343,6 +349,18 @@
     assert edge_writer.mock_calls == []
 
 
+def test_export_already_there_nodes(exporter):
+    node_writer, edge_writer = exporter(
+        {
+            "content": [{**TEST_CONTENT, "sha1_git": binhash("cnt1")}],
+            "directory": [{"id": binhash("dir2"), "entries": []}],
+        },
+        nodes_already_there=True,
+    )
+    assert node_writer.mock_calls == []
+    assert edge_writer.mock_calls == []
+
+
 def zstwrite(fp, lines):
     with ZSTFile(fp, "w") as writer:
         for l in lines:
diff --git a/swh/dataset/test/test_utils.py b/swh/dataset/test/test_utils.py
new file mode 100644
--- /dev/null
+++ b/swh/dataset/test/test_utils.py
@@ -0,0 +1,15 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from swh.dataset.utils import SQLiteSet
+
+
+def test_sqliteset(tmp_path):
+    f = tmp_path / "test.sqlite3"
+
+    with SQLiteSet(f) as s:
+        assert s.add(b"a")
+        assert not s.add(b"a")
+        assert s.add(b"b")
diff --git a/swh/dataset/utils.py b/swh/dataset/utils.py
--- a/swh/dataset/utils.py
+++ b/swh/dataset/utils.py
@@ -4,9 +4,16 @@
 # See top-level LICENSE file for more information
 
+import os
+import sqlite3
 import subprocess
 
 
 class ZSTFile:
+    """
+    Object-like wrapper around a ZST file. Uses a subprocess of the "zstd"
+    command to compress and decompress the objects.
+    """
+
     def __init__(self, path, mode="r"):
         if mode not in ("r", "rb", "w", "wb"):
             raise ValueError(f"ZSTFile mode {mode} is invalid.")
@@ -35,3 +42,40 @@
 
     def write(self, buf):
         self.process.stdin.write(buf)
+
+
+class SQLiteSet:
+    """
+    On-disk Set object for hashes using SQLite as an indexer backend. Used to
+    deduplicate objects when processing large queues with duplicates.
+    """
+
+    def __init__(self, db_path: os.PathLike):
+        self.db_path = db_path
+
+    def __enter__(self):
+        self.db = sqlite3.connect(str(self.db_path))
+        self.db.execute(
+            "CREATE TABLE tmpset (val TEXT NOT NULL PRIMARY KEY) WITHOUT ROWID"
+        )
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.db.close()
+
+    def add(self, v: bytes) -> bool:
+        """
+        Add an item to the set.
+
+        Args:
+            v: The value to add to the set.
+
+        Returns:
+            True if the value was added to the set, False if it was already present.
+        """
+        try:
+            self.db.execute("INSERT INTO tmpset(val) VALUES (?)", (v.hex(),))
+        except sqlite3.IntegrityError:
+            return False
+        else:
+            return True