diff --git a/swh/dataset/exporter.py b/swh/dataset/exporter.py
--- a/swh/dataset/exporter.py
+++ b/swh/dataset/exporter.py
@@ -3,239 +3,40 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-import concurrent.futures
-from concurrent.futures import FIRST_EXCEPTION, ProcessPoolExecutor
-import multiprocessing
-import time
-from typing import Mapping, Sequence, Tuple
+from types import TracebackType
+from typing import Any, Dict, Optional, Type
 
-from confluent_kafka import TopicPartition
-import tqdm
 
-from swh.journal.client import JournalClient
+class Exporter:
+    def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any) -> None:
+        self.config: Dict[str, Any] = config
 
+    def __enter__(self) -> "Exporter":
+        return self
 
-class JournalClientOffsetRanges(JournalClient):
-    """
-    A subclass of JournalClient reading only inside some specific offset
-    range. Partition assignments have to be manually given to the class.
-
-    This client can only read a single topic at a time.
-    """
-
-    def __init__(
+    def __exit__(
         self,
-        *args,
-        offset_ranges: Mapping[int, Tuple[int, int]] = None,
-        assignment: Sequence[int] = None,
-        progress_queue: multiprocessing.Queue = None,
-        refresh_every: int = 200,
-        **kwargs,
-    ):
-        """
-        Args:
-            offset_ranges: A mapping of partition_id -> (low, high) offsets
-                that define the boundaries of the messages to consume.
-            assignment: The list of partitions to assign to this client.
-            progress_queue: a multiprocessing.Queue where the current
-                progress will be reported.
-            refresh_every: the refreshing rate of the progress reporting.
-        """
-        self.offset_ranges = offset_ranges
-        self.progress_queue = progress_queue
-        self.refresh_every = refresh_every
-        self.assignment = assignment
-        self.count = None
-        self.topic_name = None
-        super().__init__(*args, **kwargs)
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> Optional[bool]:
+        pass
 
-    def subscribe(self):
-        self.topic_name = self.subscription[0]
-        time.sleep(0.1)  # https://github.com/edenhill/librdkafka/issues/1983
-        self.consumer.assign(
-            [TopicPartition(self.topic_name, pid) for pid in self.assignment]
-        )
-
-    def process(self, *args, **kwargs):
-        self.count = 0
-        try:
-            self.handle_committed_offsets()
-            super().process(*args, **kwargs)
-        except EOFError:
-            pass
-        finally:
-            self.progress_queue.put(None)
-
-    def handle_committed_offsets(self,):
-        """
-        Handle already committed partition offsets before starting processing.
+    def process_object(self, object_type: str, object: Dict[str, Any]) -> None:
         """
-        committed = self.consumer.committed(
-            [TopicPartition(self.topic_name, pid) for pid in self.assignment]
-        )
-        for tp in committed:
-            self.handle_offset(tp.partition, tp.offset)
+        Process a SWH object to export.
 
-    def handle_offset(self, partition_id, offset):
+        Override this with your custom worker.
         """
-        Check whether the client has reached the end of the current
-        partition, and trigger a reassignment if that is the case.
-        Raise EOFError if all the partitions have reached the end.
- """ - if offset < 0: # Uninitialized partition offset - return - - if self.count % self.refresh_every == 0: - self.progress_queue.put({partition_id: offset}) - - if offset >= self.offset_ranges[partition_id][1] - 1: - self.assignment = [pid for pid in self.assignment if pid != partition_id] - self.subscribe() - - if not self.assignment: - raise EOFError - - def deserialize_message(self, message): - """ - Override of the message deserialization to hook the handling of the - message offset. - """ - self.handle_offset(message.partition(), message.offset()) - self.count += 1 - return super().deserialize_message(message) + raise NotImplementedError -class ParallelExporter: +class ExporterDispatch(Exporter): """ - Base class for all the Journal exporters. - - Each exporter should override the `export_worker` function with an - implementation of how to run the message processing. + Like Exporter, but dispatches each object type to a different function. """ - def __init__(self, config, export_id: str, obj_type, processes=1): - """ - Args: - config: the exporter config, which should also include the - JournalClient configuration. - export_id: a unique identifier for the export that will be used - as part of a Kafka consumer group ID. - obj_type: The type of SWH object to export. - processes: The number of processes to run. - """ - self.config = config - self.export_id = "swh-dataset-export-{}".format(export_id) - self.obj_type = obj_type - self.processes = processes - self.offsets = None - - def get_offsets(self): - """ - First pass to fetch all the current low and high offsets of each - partition to define the consumption boundaries. - """ - if self.offsets is None: - client = JournalClient( - **self.config["journal"], - object_types=[self.obj_type], - group_id=self.export_id, - ) - topic_name = client.subscription[0] - topics = client.consumer.list_topics(topic_name).topics - partitions = topics[topic_name].partitions - - self.offsets = {} - for partition_id in tqdm.tqdm( - partitions.keys(), desc=" - Partition offsets" - ): - tp = TopicPartition(topic_name, partition_id) - (lo, hi) = client.consumer.get_watermark_offsets(tp) - self.offsets[partition_id] = (lo, hi) - return self.offsets - - def run(self, *args): - """ - Run the parallel export. - """ - offsets = self.get_offsets() - to_assign = list(offsets.keys()) - - manager = multiprocessing.Manager() - q = manager.Queue() - - with ProcessPoolExecutor(self.processes + 1) as pool: - futures = [] - for i in range(self.processes): - futures.append( - pool.submit( - self.export_worker, - *args, - assignment=to_assign[i :: self.processes], - queue=q, - ) - ) - futures.append(pool.submit(self.progress_worker, queue=q)) - - concurrent.futures.wait(futures, return_when=FIRST_EXCEPTION) - for f in futures: - if f.running(): - continue - exc = f.exception() - if exc: - pool.shutdown(wait=False) - f.result() - raise exc - - def progress_worker(self, *args, queue=None): - """ - An additional worker process that reports the current progress of the - export between all the different parallel consumers and across all the - partitions, by consuming the shared progress reporting Queue. 
- """ - d = {} - active_workers = self.processes - offset_diff = sum((hi - lo) for lo, hi in self.offsets.values()) - with tqdm.tqdm(total=offset_diff, desc=" - Journal export") as pbar: - while active_workers: - item = queue.get() - if item is None: - active_workers -= 1 - continue - d.update(item) - progress = sum(n - self.offsets[p][0] for p, n in d.items()) - pbar.set_postfix( - active_workers=active_workers, total_workers=self.processes - ) - pbar.update(progress - pbar.n) - - def process(self, callback, assignment=None, queue=None): - client = JournalClientOffsetRanges( - **self.config["journal"], - object_types=[self.obj_type], - group_id=self.export_id, - debug="cgrp,broker", - offset_ranges=self.offsets, - assignment=assignment, - progress_queue=queue, - **{"message.max.bytes": str(500 * 1024 * 1024)}, - ) - client.process(callback) - - def export_worker(self, *args, **kwargs): - """ - Override this with a custom implementation of a worker function. - - A worker function should call `self.process(fn, **kwargs)` with `fn` - being a callback that will be called in the same fashion as with - `JournalClient.process()`. - - A simple exporter to print all the objects in the log would look like - this: - - ``` - class PrintExporter(ParallelExporter): - def export_worker(self, **kwargs): - self.process(print, **kwargs) - ``` - """ - raise NotImplementedError + def process_object(self, object_type: str, object: Dict[str, Any]) -> None: + method_name = "process_" + object_type + if hasattr(self, method_name): + getattr(self, method_name)(object) diff --git a/swh/dataset/graph.py b/swh/dataset/graph.py --- a/swh/dataset/graph.py +++ b/swh/dataset/graph.py @@ -4,8 +4,6 @@ # See top-level LICENSE file for more information import base64 -import contextlib -import functools import os import os.path import pathlib @@ -14,29 +12,48 @@ import tempfile import uuid -from swh.dataset.exporter import ParallelExporter -from swh.dataset.utils import SQLiteSet, ZSTFile +from swh.dataset.exporter import ExporterDispatch +from swh.dataset.journalprocessor import ParallelJournalProcessor +from swh.dataset.utils import ZSTFile from swh.model.identifiers import origin_identifier, swhid -from swh.storage.fixer import fix_objects -def process_messages(messages, config, node_writer, edge_writer, node_set): +class GraphEdgesExporter(ExporterDispatch): """ - Args: - messages: A sequence of messages to process - config: The exporter configuration - node_writer: A file-like object where to write nodes - edge_writer: A file-like object where to write edges + Implementation of an exporter which writes all the graph edges + of a specific type in a Zstandard-compressed CSV file. 
+
+    Each row of the CSV is in the format: `<src SWHID> <dst SWHID>`
     """
 
-    def write_node(node):
+    def __init__(self, config, export_path, **kwargs):
+        super().__init__(config)
+        self.export_path = export_path
+
+    def __enter__(self):
+        dataset_path = pathlib.Path(self.export_path)
+        dataset_path.mkdir(exist_ok=True, parents=True)
+        unique_id = str(uuid.uuid4())
+        nodes_file = dataset_path / ("graph-{}.nodes.csv.zst".format(unique_id))
+        edges_file = dataset_path / ("graph-{}.edges.csv.zst".format(unique_id))
+        self.node_writer = ZSTFile(nodes_file, "w")
+        self.edge_writer = ZSTFile(edges_file, "w")
+        self.node_writer.__enter__()
+        self.edge_writer.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.node_writer.__exit__(exc_type, exc_value, traceback)
+        self.edge_writer.__exit__(exc_type, exc_value, traceback)
+
+    def write_node(self, node):
         node_type, node_id = node
         if node_id is None:
             return
         node_swhid = swhid(object_type=node_type, object_id=node_id)
-        node_writer.write("{}\n".format(node_swhid))
+        self.node_writer.write("{}\n".format(node_swhid))
 
-    def write_edge(src, dst, *, labels=None):
+    def write_edge(self, src, dst, *, labels=None):
         src_type, src_id = src
         dst_type, dst_id = dst
         if src_id is None or dst_id is None:
@@ -44,22 +61,21 @@
         src_swhid = swhid(object_type=src_type, object_id=src_id)
         dst_swhid = swhid(object_type=dst_type, object_id=dst_id)
         edge_line = " ".join([src_swhid, dst_swhid] + (labels if labels else []))
-        edge_writer.write("{}\n".format(edge_line))
+        self.edge_writer.write("{}\n".format(edge_line))
 
-    messages = {k: fix_objects(k, v) for k, v in messages.items()}
+    def process_origin(self, origin):
+        origin_id = origin_identifier({"url": origin["origin"]})
+        self.write_node(("origin", origin_id))
 
-    for visit_status in messages.get("origin_visit_status", []):
+    def process_origin_visit_status(self, visit_status):
         origin_id = origin_identifier({"url": visit_status["origin"]})
-        visit_id = visit_status["visit"]
-        if not node_set.add("{}:{}".format(origin_id, visit_id).encode()):
-            continue
-        write_node(("origin", origin_id))
-        write_edge(("origin", origin_id), ("snapshot", visit_status["snapshot"]))
-
-    for snapshot in messages.get("snapshot", []):
-        if not node_set.add(snapshot["id"]):
-            continue
-        write_node(("snapshot", snapshot["id"]))
+        self.write_edge(("origin", origin_id), ("snapshot", visit_status["snapshot"]))
+
+    def process_snapshot(self, snapshot):
+        if self.config.get("remove_pull_requests"):
+            self.remove_pull_requests(snapshot)
+
+        self.write_node(("snapshot", snapshot["id"]))
         for branch_name, branch in snapshot["branches"].items():
             original_branch_name = branch_name
             while branch and branch.get("target_type") == "alias":
@@ -67,95 +83,67 @@
                 branch = snapshot["branches"][branch_name]
             if branch is None or not branch_name:
                 continue
-            # Heuristic to filter out pull requests in snapshots: remove all
-            # branches that start with refs/ but do not start with refs/heads or
-            # refs/tags.
- if config.get("remove_pull_requests") and ( - branch_name.startswith(b"refs/") - and not ( - branch_name.startswith(b"refs/heads") - or branch_name.startswith(b"refs/tags") - ) - ): - continue - write_edge( + self.write_edge( ("snapshot", snapshot["id"]), (branch["target_type"], branch["target"]), labels=[base64.b64encode(original_branch_name).decode(),], ) - for release in messages.get("release", []): - if not node_set.add(release["id"]): - continue - write_node(("release", release["id"])) - write_edge( + def process_release(self, release): + self.write_node(("release", release["id"])) + self.write_edge( ("release", release["id"]), (release["target_type"], release["target"]) ) - for revision in messages.get("revision", []): - if not node_set.add(revision["id"]): - continue - write_node(("revision", revision["id"])) - write_edge(("revision", revision["id"]), ("directory", revision["directory"])) + def process_revision(self, revision): + self.write_node(("revision", revision["id"])) + self.write_edge( + ("revision", revision["id"]), ("directory", revision["directory"]) + ) for parent in revision["parents"]: - write_edge(("revision", revision["id"]), ("revision", parent)) + self.write_edge(("revision", revision["id"]), ("revision", parent)) - for directory in messages.get("directory", []): - if not node_set.add(directory["id"]): - continue - write_node(("directory", directory["id"])) + def process_directory(self, directory): + self.write_node(("directory", directory["id"])) for entry in directory["entries"]: entry_type_mapping = { "file": "content", "dir": "directory", "rev": "revision", } - write_edge( + self.write_edge( ("directory", directory["id"]), (entry_type_mapping[entry["type"]], entry["target"]), labels=[base64.b64encode(entry["name"]).decode(), str(entry["perms"]),], ) - for content in messages.get("content", []): - if not node_set.add(content["sha1_git"]): - continue - write_node(("content", content["sha1_git"])) - + def process_content(self, content): + self.write_node(("content", content["sha1_git"])) -class GraphEdgeExporter(ParallelExporter): - """ - Implementation of ParallelExporter which writes all the graph edges - of a specific type in a Zstandard-compressed CSV file. - - Each row of the CSV is in the format: ` - """ - - def export_worker(self, export_path, **kwargs): - dataset_path = pathlib.Path(export_path) - dataset_path.mkdir(exist_ok=True, parents=True) - unique_id = str(uuid.uuid4()) - nodes_file = dataset_path / ("graph-{}.nodes.csv.zst".format(unique_id)) - edges_file = dataset_path / ("graph-{}.edges.csv.zst".format(unique_id)) - node_set_file = dataset_path / (".set-nodes-{}.sqlite3".format(unique_id)) - - with contextlib.ExitStack() as stack: - node_writer = stack.enter_context(ZSTFile(nodes_file, "w")) - edge_writer = stack.enter_context(ZSTFile(edges_file, "w")) - node_set = stack.enter_context(SQLiteSet(node_set_file)) - - process_fn = functools.partial( - process_messages, - config=self.config, - node_writer=node_writer, - edge_writer=edge_writer, - node_set=node_set, - ) - self.process(process_fn, **kwargs) + def remove_pull_requests(self, snapshot): + """ + Heuristic to filter out pull requests in snapshots: remove all branches + that start with refs/ but do not start with refs/heads or refs/tags. 
+ """ + # Copy the items with list() to remove items during iteration + for branch_name, branch in list(snapshot["branches"].items()): + original_branch_name = branch_name + while branch and branch.get("target_type") == "alias": + branch_name = branch["target"] + branch = snapshot["branches"][branch_name] + if branch is None or not branch_name: + continue + if branch_name.startswith(b"refs/") and not ( + branch_name.startswith(b"refs/heads") + or branch_name.startswith(b"refs/tags") + ): + snapshot["branches"].pop(original_branch_name) def export_edges(config, export_path, export_id, processes): """Run the edge exporter for each edge type.""" object_types = [ + "origin", "origin_visit_status", "snapshot", "release", @@ -165,8 +153,13 @@ ] for obj_type in object_types: print("{} edges:".format(obj_type)) - exporter = GraphEdgeExporter(config, export_id, obj_type, processes) - exporter.run(os.path.join(export_path, obj_type)) + exporters = [ + (GraphEdgesExporter, {"export_path": os.path.join(export_path, obj_type)}), + ] + parallel_exporter = ParallelJournalProcessor( + config, exporters, export_id, obj_type, processes + ) + parallel_exporter.run() def sort_graph_nodes(export_path, config): diff --git a/swh/dataset/exporter.py b/swh/dataset/journalprocessor.py copy from swh/dataset/exporter.py copy to swh/dataset/journalprocessor.py --- a/swh/dataset/exporter.py +++ b/swh/dataset/journalprocessor.py @@ -5,14 +5,21 @@ import concurrent.futures from concurrent.futures import FIRST_EXCEPTION, ProcessPoolExecutor +from hashlib import sha1 import multiprocessing +from pathlib import Path import time -from typing import Mapping, Sequence, Tuple +from typing import Any, Dict, Mapping, Sequence, Tuple, Type from confluent_kafka import TopicPartition import tqdm +from swh.dataset.exporter import Exporter +from swh.dataset.utils import SQLiteSet from swh.journal.client import JournalClient +from swh.journal.serializers import kafka_to_value +from swh.model.identifiers import origin_identifier +from swh.storage.fixer import fix_objects class JournalClientOffsetRanges(JournalClient): @@ -99,34 +106,51 @@ """ Override of the message deserialization to hook the handling of the message offset. + We also return the raw objects instead of deserializing them because we + will need the partition ID later. """ + # XXX: this is a bad hack that we have to do because of how the + # journal API works. Ideally it would be better to change the API so + # that journal clients can know the partition of the message they are + # handling. self.handle_offset(message.partition(), message.offset()) self.count += 1 - return super().deserialize_message(message) + # return super().deserialize_message(message) + return message -class ParallelExporter: +class ParallelJournalProcessor: """ - Base class for all the Journal exporters. - - Each exporter should override the `export_worker` function with an - implementation of how to run the message processing. + Reads the given object type from the journal in parallel. + It creates one JournalExportWorker per process. """ - def __init__(self, config, export_id: str, obj_type, processes=1): + def __init__( + self, + config, + exporters: Sequence[Tuple[Type[Exporter], Dict[str, Any]]], + export_id: str, + obj_type: str, + processes: int = 1, + node_sets_path: Path = None, + ): """ Args: config: the exporter config, which should also include the JournalClient configuration. 
+            exporters: a list of (Exporter class, kwargs) pairs used to
+                process the objects
             export_id: a unique identifier for the export that will be used
                 as part of a Kafka consumer group ID.
             obj_type: The type of SWH object to export.
             processes: The number of processes to run.
+            node_sets_path: A directory where to store the node sets.
         """
         self.config = config
+        self.exporters = exporters
         self.export_id = "swh-dataset-export-{}".format(export_id)
         self.obj_type = obj_type
         self.processes = processes
+        self.node_sets_path = node_sets_path
         self.offsets = None
 
     def get_offsets(self):
@@ -153,7 +177,7 @@
                 self.offsets[partition_id] = (lo, hi)
         return self.offsets
 
-    def run(self, *args):
+    def run(self):
         """
         Run the parallel export.
         """
@@ -169,7 +193,6 @@
                 futures.append(
                     pool.submit(
                         self.export_worker,
-                        *args,
                         assignment=to_assign[i :: self.processes],
-                        queue=q,
+                        progress_queue=q,
                     )
                 )
@@ -208,34 +231,131 @@
                 )
                 pbar.update(progress - pbar.n)
 
-    def process(self, callback, assignment=None, queue=None):
+    def export_worker(self, assignment, progress_queue):
+        worker = JournalProcessorWorker(
+            self.config,
+            self.exporters,
+            self.export_id,
+            self.obj_type,
+            self.offsets,
+            assignment,
+            progress_queue,
+            self.node_sets_path,
+        )
+        with worker:
+            worker.run()
+
+
+class JournalProcessorWorker:
+    """
+    Worker process that processes all the messages and calls the given exporters
+    for each object read from the journal.
+    """
+
+    def __init__(
+        self,
+        config,
+        exporters: Sequence[Tuple[Type[Exporter], Dict[str, Any]]],
+        export_id: str,
+        obj_type: str,
+        offsets: Dict[int, Tuple[int, int]],
+        assignment: Sequence[int],
+        progress_queue: multiprocessing.Queue,
+        node_sets_path: Path,
+    ):
+        self.config = config
+        self.export_id = export_id
+        self.obj_type = obj_type
+        self.offsets = offsets
+        self.assignment = assignment
+        self.progress_queue = progress_queue
+
+        self.node_sets_path = node_sets_path
+        self.node_sets: Dict[int, SQLiteSet] = {}
+
+        self.exporters = [
+            exporter_class(self.config, **kwargs) for exporter_class, kwargs in exporters
+        ]
+
+    def __enter__(self):
+        for exporter in self.exporters:
+            exporter.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        for exporter in self.exporters:
+            exporter.__exit__(exc_type, exc_value, traceback)
+        # Close the per-partition node sets opened in get_node_set_for_partition()
+        for node_set in self.node_sets.values():
+            node_set.__exit__(exc_type, exc_value, traceback)
+
+    def get_node_set_for_partition(self, partition_id: int):
+        """
+        Return an on-disk set object, which stores the nodes that have
+        already been processed.
+
+        Node sets are sharded by partition ID, as each object is guaranteed to
+        be assigned to a deterministic Kafka partition.
+        """
+        if partition_id not in self.node_sets:
+            self.node_sets_path.mkdir(exist_ok=True, parents=True)
+            node_set_file = self.node_sets_path / "nodes-part-{}.sqlite".format(
+                partition_id
+            )
+            node_set = SQLiteSet(node_set_file)
+            node_set.__enter__()  # open the underlying SQLite database
+            self.node_sets[partition_id] = node_set
+        return self.node_sets[partition_id]
+
+    def run(self):
+        """
+        Start a Journal client on the given assignment and process all the
+        incoming messages.
+        """
         client = JournalClientOffsetRanges(
             **self.config["journal"],
             object_types=[self.obj_type],
             group_id=self.export_id,
             debug="cgrp,broker",
             offset_ranges=self.offsets,
-            assignment=assignment,
-            progress_queue=queue,
+            assignment=self.assignment,
+            progress_queue=self.progress_queue,
             **{"message.max.bytes": str(500 * 1024 * 1024)},
         )
-        client.process(callback)
+        client.process(self.process_messages)
 
-    def export_worker(self, *args, **kwargs):
+    def process_messages(self, messages):
         """
-        Override this with a custom implementation of a worker function.
-
-        A worker function should call `self.process(fn, **kwargs)` with `fn`
-        being a callback that will be called in the same fashion as with
-        `JournalClient.process()`.
+        Process the incoming Kafka messages.
+        """
+        messages = {k: fix_objects(k, v) for k, v in messages.items()}
+        for object_type, message_list in messages.items():
+            for message in message_list:
+                self.process_message(object_type, message)
 
-        A simple exporter to print all the objects in the log would look like
-        this:
+    def process_message(self, object_type, message):
+        """
+        Process a single incoming Kafka message if the object it refers to has
+        not been processed yet.
 
-        ```
-        class PrintExporter(ParallelExporter):
-            def export_worker(self, **kwargs):
-                self.process(print, **kwargs)
-        ```
+        It uses an on-disk set to make sure that each object is only ever
+        processed once.
         """
-        raise NotImplementedError
+        object = kafka_to_value(message.value())
+        partition = message.partition()
+        node_set = self.get_node_set_for_partition(partition)
+
+        if object_type == "origin_visit_status":
+            origin_id = origin_identifier({"url": object["origin"]})
+            visit = object["visit"]
+            node_id = sha1("{}:{}".format(origin_id, visit).encode()).digest()
+        elif object_type == "content":
+            node_id = object["sha1_git"]
+        else:
+            node_id = object["id"]
+        if not node_set.add(node_id):
+            # Node already processed, skipping.
+            return
+
+        for exporter in self.exporters:
+            exporter.process_object(object_type, object)
diff --git a/swh/dataset/test/test_graph.py b/swh/dataset/test/test_graph.py
--- a/swh/dataset/test/test_graph.py
+++ b/swh/dataset/test/test_graph.py
@@ -11,7 +11,7 @@
 
 import pytest
 
-from swh.dataset.graph import process_messages, sort_graph_nodes
+from swh.dataset.graph import GraphEdgesExporter, sort_graph_nodes
 from swh.dataset.utils import ZSTFile
 from swh.model.hashutil import MultiHash, hash_to_bytes
 
@@ -89,17 +89,13 @@
     def wrapped(messages, config=None) -> Tuple[Mock, Mock]:
         if config is None:
             config = {}
-        node_writer = Mock()
-        edge_writer = Mock()
-        node_set = FakeDiskSet()
-        process_messages(
-            messages,
-            config=config,
-            node_writer=node_writer,
-            edge_writer=edge_writer,
-            node_set=node_set,
-        )
-        return node_writer.write, edge_writer.write
+        exporter = GraphEdgesExporter(config, "/dummy_path")
+        exporter.node_writer = Mock()
+        exporter.edge_writer = Mock()
+        for object_type, objects in messages.items():
+            for obj in objects:
+                exporter.process_object(object_type, obj)
+        return exporter.node_writer.write, exporter.edge_writer.write
 
     return wrapped
 
@@ -116,6 +112,17 @@
     return b64encode(s.encode()).decode()
 
 
+def test_export_origin(exporter):
+    node_writer, edge_writer = exporter(
+        {"origin": [{"origin": "ori1"}, {"origin": "ori2"},]}
+    )
+    assert node_writer.mock_calls == [
+        call(f"swh:1:ori:{hexhash('ori1')}\n"),
+        call(f"swh:1:ori:{hexhash('ori2')}\n"),
+    ]
+    assert edge_writer.mock_calls == []
+
+
 def test_export_origin_visit_status(exporter):
     node_writer, edge_writer = exporter(
         {
@@ -133,10 +140,7 @@
             ]
         }
     )
-    assert node_writer.mock_calls == [
-        call(f"swh:1:ori:{hexhash('ori1')}\n"),
-        call(f"swh:1:ori:{hexhash('ori2')}\n"),
-    ]
+    assert node_writer.mock_calls == []
     assert edge_writer.mock_calls == [
         call(f"swh:1:ori:{hexhash('ori1')} swh:1:snp:{hexhash('snp1')}\n"),
         call(f"swh:1:ori:{hexhash('ori2')} swh:1:snp:{hexhash('snp2')}\n"),
@@ -465,45 +469,6 @@
     assert edge_writer.mock_calls == []
 
 
-def test_export_duplicate_node(exporter):
-    node_writer, edge_writer = exporter(
-        {
-            "content": [
-                {**TEST_CONTENT, "sha1_git": binhash("cnt1")},
-                {**TEST_CONTENT, "sha1_git": binhash("cnt1")},
-                {**TEST_CONTENT, "sha1_git": binhash("cnt1")},
-            ],
-        },
-    )
-    assert node_writer.mock_calls == [
call(f"swh:1:cnt:{hexhash('cnt1')}\n"), - ] - assert edge_writer.mock_calls == [] - - -def test_export_duplicate_visit(exporter): - node_writer, edge_writer = exporter( - { - "origin_visit_status": [ - {**TEST_ORIGIN_VISIT_STATUS, "origin": "ori1", "visit": 1}, - {**TEST_ORIGIN_VISIT_STATUS, "origin": "ori2", "visit": 1}, - {**TEST_ORIGIN_VISIT_STATUS, "origin": "ori1", "visit": 1}, - {**TEST_ORIGIN_VISIT_STATUS, "origin": "ori2", "visit": 1}, - {**TEST_ORIGIN_VISIT_STATUS, "origin": "ori1", "visit": 2}, - {**TEST_ORIGIN_VISIT_STATUS, "origin": "ori2", "visit": 2}, - {**TEST_ORIGIN_VISIT_STATUS, "origin": "ori2", "visit": 2}, - ], - }, - ) - assert node_writer.mock_calls == [ - call(f"swh:1:ori:{hexhash('ori1')}\n"), - call(f"swh:1:ori:{hexhash('ori2')}\n"), - call(f"swh:1:ori:{hexhash('ori1')}\n"), - call(f"swh:1:ori:{hexhash('ori2')}\n"), - ] - assert edge_writer.mock_calls == [] - - def zstwrite(fp, lines): with ZSTFile(fp, "w") as writer: for line in lines: diff --git a/swh/dataset/utils.py b/swh/dataset/utils.py --- a/swh/dataset/utils.py +++ b/swh/dataset/utils.py @@ -14,13 +14,13 @@ command to compress and deflate the objects. """ - def __init__(self, path, mode="r"): + def __init__(self, path: str, mode: str = "r"): if mode not in ("r", "rb", "w", "wb"): raise ValueError(f"ZSTFile mode {mode} is invalid.") self.path = path self.mode = mode - def __enter__(self): + def __enter__(self) -> "ZSTFile": is_text = not (self.mode in ("rb", "wb")) writing = self.mode in ("w", "wb") if writing: @@ -50,13 +50,15 @@ deduplicate objects when processing large queues with duplicates. """ - def __init__(self, db_path: os.PathLike): + def __init__(self, db_path: os.PathLike[str]): self.db_path = db_path def __enter__(self): self.db = sqlite3.connect(str(self.db_path)) self.db.execute( - "CREATE TABLE tmpset (val TEXT NOT NULL PRIMARY KEY) WITHOUT ROWID" + "CREATE TABLE IF NOT EXISTS" + " tmpset (val TEXT NOT NULL PRIMARY KEY)" + " WITHOUT ROWID" ) return self