diff --git a/swh/dataset/athena.py b/swh/dataset/athena.py index 235cf39..779b231 100644 --- a/swh/dataset/athena.py +++ b/swh/dataset/athena.py @@ -1,273 +1,281 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information """ This module implements the "athena" subcommands for the CLI. It can install and query a remote AWS Athena database. """ import datetime import logging import os import sys import textwrap import time import boto3 import botocore.exceptions from swh.dataset.relational import TABLES def create_database(database_name): return "CREATE DATABASE IF NOT EXISTS {};".format(database_name) def drop_table(database_name, table): return "DROP TABLE IF EXISTS {}.{};".format(database_name, table) def create_table(database_name, table, location_prefix): req = textwrap.dedent( """\ CREATE EXTERNAL TABLE IF NOT EXISTS {db}.{table} ( {fields} ) STORED AS ORC LOCATION '{location}/' TBLPROPERTIES ("orc.compress"="ZSTD"); """ ).format( db=database_name, table=table, fields=",\n".join( [ " `{}` {}".format(col_name, col_type) for col_name, col_type in TABLES[table] ] ), location=os.path.join(location_prefix, "orc", table), ) return req def repair_table(database_name, table): return "MSCK REPAIR TABLE {}.{};".format(database_name, table) def query(client, query_string, *, desc="Querying", delay_secs=0.5, silent=False): def log(*args, **kwargs): if not silent: print(*args, **kwargs, flush=True, file=sys.stderr) log(desc, end="...") query_options = { "QueryString": query_string, "ResultConfiguration": {}, "QueryExecutionContext": {}, } if client.output_location: query_options["ResultConfiguration"]["OutputLocation"] = client.output_location if client.database_name: query_options["QueryExecutionContext"]["Database"] = client.database_name try: res = client.start_query_execution(**query_options) except botocore.exceptions.ClientError as e: raise RuntimeError( str(e) + "\n\nQuery:\n" + textwrap.indent(query_string, " " * 2) ) qid = res["QueryExecutionId"] while True: time.sleep(delay_secs) log(".", end="") execution = client.get_query_execution(QueryExecutionId=qid) status = execution["QueryExecution"]["Status"] if status["State"] in ("SUCCEEDED", "FAILED", "CANCELLED"): break log(" {}.".format(status["State"])) if status["State"] != "SUCCEEDED": raise RuntimeError( status["StateChangeReason"] + "\n\nQuery:\n" + textwrap.indent(query_string, " " * 2) ) return execution["QueryExecution"] def create_tables(database_name, dataset_location, output_location=None, replace=False): """ Create the Software Heritage Dataset tables on AWS Athena. Athena works on external columnar data stored in S3, but requires a schema for each table to run queries. This creates all the necessary tables remotely by using the relational schemas in swh.dataset.relational. """ client = boto3.client("athena") client.output_location = output_location client.database_name = database_name query( client, create_database(database_name), desc="Creating {} database".format(database_name), ) if replace: for table in TABLES: query( client, drop_table(database_name, table), desc="Dropping table {}".format(table), ) for table in TABLES: query( client, create_table(database_name, table, dataset_location), desc="Creating table {}".format(table), ) for table in TABLES: query( client, repair_table(database_name, table), desc="Refreshing table metadata for {}".format(table), ) def human_size(n, units=["bytes", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB"]): - """ Returns a human readable string representation of bytes """ + """Returns a human readable string representation of bytes""" return f"{n} " + units[0] if n < 1024 else human_size(n >> 10, units[1:]) def _s3_url_to_bucket_path(s3_url): loc = s3_url.removeprefix("s3://") bucket, path = loc.split("/", 1) return bucket, path def run_query_get_results( - database_name, query_string, output_location=None, + database_name, + query_string, + output_location=None, ): """ Run a query on AWS Athena and return the resulting data in CSV format. """ athena = boto3.client("athena") athena.output_location = output_location athena.database_name = database_name s3 = boto3.client("s3") result = query(athena, query_string, silent=True) logging.info( "Scanned %s in %s", human_size(result["Statistics"]["DataScannedInBytes"]), datetime.timedelta( milliseconds=result["Statistics"]["TotalExecutionTimeInMillis"] ), ) bucket, path = _s3_url_to_bucket_path( result["ResultConfiguration"]["OutputLocation"] ) return s3.get_object(Bucket=bucket, Key=path)["Body"].read().decode() def generate_subdataset( - dataset_db, subdataset_db, subdataset_s3_path, swhids_file, output_location=None, + dataset_db, + subdataset_db, + subdataset_s3_path, + swhids_file, + output_location=None, ): # Upload list of all the swhids included in the dataset subdataset_bucket, subdataset_path = _s3_url_to_bucket_path(subdataset_s3_path) s3_client = boto3.client("s3") print(f"Uploading {swhids_file} to S3...") s3_client.upload_file( swhids_file, subdataset_bucket, os.path.join(subdataset_path, "swhids", "swhids.csv"), ) athena_client = boto3.client("athena") athena_client.output_location = output_location athena_client.database_name = subdataset_db # Create subdataset database query( athena_client, create_database(subdataset_db), desc="Creating {} database".format(subdataset_db), ) # Create SWHID temporary table create_swhid_table_query = textwrap.dedent( """\ CREATE EXTERNAL TABLE IF NOT EXISTS {newdb}.swhids ( swhprefix string, version int, type string, hash string ) ROW FORMAT DELIMITED FIELDS TERMINATED BY ':' STORED AS TEXTFILE LOCATION '{location}/swhids/' """ ).format(newdb=subdataset_db, location=subdataset_s3_path) query( athena_client, create_swhid_table_query, desc="Creating SWHIDs table of subdataset", ) query( athena_client, repair_table(subdataset_db, "swhids"), desc="Refreshing table metadata for swhids table", ) # Create join tables query_tpl = textwrap.dedent( """\ CREATE TABLE IF NOT EXISTS {newdb}.{table} WITH ( format = 'ORC', write_compression = 'ZSTD', external_location = '{location}/{table}/' ) AS SELECT * FROM {basedb}.{table} WHERE {field} IN (select hash from swhids) """ ) tables_join_field = [ ("origin", "lower(to_hex(sha1(to_utf8(url))))"), ("origin_visit", "lower(to_hex(sha1(to_utf8(origin))))"), ("origin_visit_status", "lower(to_hex(sha1(to_utf8(origin))))"), ("snapshot", "id"), ("snapshot_branch", "snapshot_id"), ("release", "id"), ("revision", "id"), ("revision_history", "id"), ("directory", "id"), ("directory_entry", "directory_id"), ("content", "sha1_git"), ("skipped_content", "sha1_git"), ] for table, join_field in tables_join_field: ctas_query = query_tpl.format( newdb=subdataset_db, basedb=dataset_db, location=subdataset_s3_path, table=table, field=join_field, ) # Temporary fix: Athena no longer supports >32MB rows, but some of # the objects were added to the dataset before this restriction was # in place. if table == "revision": ctas_query += " AND length(message) < 100000" query( - athena_client, ctas_query, desc="Creating join table {}".format(table), + athena_client, + ctas_query, + desc="Creating join table {}".format(table), ) diff --git a/swh/dataset/cli.py b/swh/dataset/cli.py index 8deaed2..f7fa7dc 100644 --- a/swh/dataset/cli.py +++ b/swh/dataset/cli.py @@ -1,229 +1,238 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information # WARNING: do not import unnecessary things here to keep cli startup time under # control import os import pathlib import sys import click from swh.core.cli import CONTEXT_SETTINGS from swh.core.cli import swh as swh_cli_group from swh.dataset.relational import MAIN_TABLES @swh_cli_group.group(name="dataset", context_settings=CONTEXT_SETTINGS) @click.option( "--config-file", "-C", default=None, type=click.Path(exists=True, dir_okay=False), help="Configuration file.", ) @click.pass_context def dataset_cli_group(ctx, config_file): """Dataset Tools. A set of tools to export datasets from the Software Heritage Archive in various formats. """ from swh.core import config ctx.ensure_object(dict) conf = config.read(config_file) ctx.obj["config"] = conf @dataset_cli_group.group("graph") @click.pass_context def graph(ctx): """Manage graph export""" pass AVAILABLE_EXPORTERS = { "edges": "swh.dataset.exporters.edges:GraphEdgesExporter", "orc": "swh.dataset.exporters.orc:ORCExporter", } @graph.command("export") @click.argument("export-path", type=click.Path()) @click.option( "--export-id", "-e", help=( "Unique ID of the export run. This is appended to the kafka " "group_id config file option. If group_id is not set in the " "'journal' section of the config file, defaults to 'swh-dataset-export-'." ), ) @click.option( "--formats", "-f", type=click.STRING, default=",".join(AVAILABLE_EXPORTERS.keys()), show_default=True, help="Formats to export.", ) @click.option("--processes", "-p", default=1, help="Number of parallel processes") @click.option( "--exclude", type=click.STRING, help="Comma-separated list of object types to exclude", ) @click.pass_context def export_graph(ctx, export_path, export_id, formats, exclude, processes): """Export the Software Heritage graph as an edge dataset.""" - import uuid from importlib import import_module + import uuid + from swh.dataset.journalprocessor import ParallelJournalProcessor config = ctx.obj["config"] if not export_id: export_id = str(uuid.uuid4()) exclude_obj_types = {o.strip() for o in (exclude.split(",") if exclude else [])} export_formats = [c.strip() for c in formats.split(",")] for f in export_formats: if f not in AVAILABLE_EXPORTERS: raise click.BadOptionUsage( option_name="formats", message=f"{f} is not an available format." ) def importcls(clspath): mod, cls = clspath.split(":") m = import_module(mod) return getattr(m, cls) exporter_cls = dict( (fmt, importcls(clspath)) for (fmt, clspath) in AVAILABLE_EXPORTERS.items() if fmt in export_formats ) # Run the exporter for each edge type. for obj_type in MAIN_TABLES: if obj_type in exclude_obj_types: continue exporters = [ - (exporter_cls[f], {"export_path": os.path.join(export_path, f)},) + ( + exporter_cls[f], + {"export_path": os.path.join(export_path, f)}, + ) for f in export_formats ] parallel_exporter = ParallelJournalProcessor( config, exporters, export_id, obj_type, node_sets_path=pathlib.Path(export_path) / ".node_sets" / obj_type, processes=processes, ) print("Exporting {}:".format(obj_type)) parallel_exporter.run() @graph.command("sort") @click.argument("export-path", type=click.Path()) @click.pass_context def sort_graph(ctx, export_path): config = ctx.obj["config"] from swh.dataset.exporters.edges import sort_graph_nodes sort_graph_nodes(export_path, config) @dataset_cli_group.group("athena") @click.pass_context def athena(ctx): """Manage and query a remote AWS Athena database""" pass @athena.command("create") @click.option( "--database-name", "-d", default="swh", help="Name of the database to create" ) @click.option( "--location-prefix", "-l", required=True, help="S3 prefix where the dataset can be found", ) @click.option( "-o", "--output-location", help="S3 prefix where results should be stored" ) @click.option( "-r", "--replace-tables", is_flag=True, help="Replace the tables that already exist" ) def athena_create( database_name, location_prefix, output_location=None, replace_tables=False ): """Create tables on AWS Athena pointing to a given graph dataset on S3.""" from swh.dataset.athena import create_tables create_tables( database_name, location_prefix, output_location=output_location, replace=replace_tables, ) @athena.command("query") @click.option( "--database-name", "-d", default="swh", help="Name of the database to query" ) @click.option( "-o", "--output-location", help="S3 prefix where results should be stored" ) @click.argument("query_file", type=click.File("r"), default=sys.stdin) def athena_query( - database_name, query_file, output_location=None, + database_name, + query_file, + output_location=None, ): """Query the AWS Athena database with a given command""" from swh.dataset.athena import run_query_get_results print( run_query_get_results( - database_name, query_file.read(), output_location=output_location, + database_name, + query_file.read(), + output_location=output_location, ), end="", ) # CSV already ends with \n @athena.command("gensubdataset") @click.option("--database", "-d", default="swh", help="Name of the base database") @click.option( - "--subdataset-database", required=True, - help="Name of the subdataset database to create" + "--subdataset-database", + required=True, + help="Name of the subdataset database to create", ) @click.option( "--subdataset-location", required=True, help="S3 prefix where the subdataset should be stored", ) @click.option( "--swhids", required=True, help="File containing the list of SWHIDs to include in the subdataset", ) def athena_gensubdataset(database, subdataset_database, subdataset_location, swhids): """ Generate a subdataset with Athena, from an existing database and a list of SWHIDs. Athena will generate a new dataset with the same tables as in the base dataset, but only containing the objects present in the SWHID list. """ from swh.dataset.athena import generate_subdataset generate_subdataset( database, subdataset_database, subdataset_location, swhids, os.path.join(subdataset_location, "queries"), ) diff --git a/swh/dataset/exporters/edges.py b/swh/dataset/exporters/edges.py index 440d67c..54996fb 100644 --- a/swh/dataset/exporters/edges.py +++ b/swh/dataset/exporters/edges.py @@ -1,229 +1,231 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 import os import os.path import shlex import subprocess import tempfile from typing import Tuple from swh.dataset.exporter import ExporterDispatch from swh.dataset.utils import ZSTFile, remove_pull_requests from swh.model.hashutil import hash_to_hex from swh.model.model import Origin from swh.model.swhids import ExtendedObjectType def swhid(object_type, object_id): # We use string interpolation here instead of using ExtendedSWHID to format, # as building temporary ExtendedSWHID objects has a non-negligeable impact # on performance. return f"swh:1:{object_type.value}:{hash_to_hex(object_id)}" class GraphEdgesExporter(ExporterDispatch): """ Implementation of an exporter which writes all the graph edges of a specific type to a Zstandard-compressed CSV file. Each row of the CSV is in the format: `` ``. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.writers = {} def get_writers_for(self, obj_type: ExtendedObjectType): if obj_type not in self.writers: dataset_path = self.export_path / obj_type.name.lower() dataset_path.mkdir(exist_ok=True) unique_id = self.get_unique_file_id() nodes_file = dataset_path / ("graph-{}.nodes.csv.zst".format(unique_id)) edges_file = dataset_path / ("graph-{}.edges.csv.zst".format(unique_id)) node_writer = self.exit_stack.enter_context(ZSTFile(str(nodes_file), "w")) edge_writer = self.exit_stack.enter_context(ZSTFile(str(edges_file), "w")) self.writers[obj_type] = (node_writer, edge_writer) return self.writers[obj_type] def get_node_writer_for(self, obj_type: ExtendedObjectType): return self.get_writers_for(obj_type)[0] def get_edge_writer_for(self, obj_type: ExtendedObjectType): return self.get_writers_for(obj_type)[1] def write_node(self, node: Tuple[ExtendedObjectType, bytes]): node_type, node_id = node if node_id is None: return node_swhid = swhid(object_type=node_type, object_id=node_id) node_writer = self.get_node_writer_for(node_type) node_writer.write("{}\n".format(node_swhid)) def write_edge( self, src: Tuple[ExtendedObjectType, bytes], dst: Tuple[ExtendedObjectType, bytes], *, labels=None, ): src_type, src_id = src dst_type, dst_id = dst if src_id is None or dst_id is None: return src_swhid = swhid(object_type=src_type, object_id=src_id) dst_swhid = swhid(object_type=dst_type, object_id=dst_id) edge_line = " ".join([src_swhid, dst_swhid] + (labels if labels else [])) edge_writer = self.get_edge_writer_for(src_type) edge_writer.write("{}\n".format(edge_line)) def process_origin(self, origin): origin_id = Origin(url=origin["url"]).id self.write_node((ExtendedObjectType.ORIGIN, origin_id)) def process_origin_visit_status(self, visit_status): origin_id = Origin(url=visit_status["origin"]).id self.write_edge( (ExtendedObjectType.ORIGIN, origin_id), (ExtendedObjectType.SNAPSHOT, visit_status["snapshot"]), ) def process_snapshot(self, snapshot): if self.config.get("remove_pull_requests"): remove_pull_requests(snapshot) self.write_node((ExtendedObjectType.SNAPSHOT, snapshot["id"])) for branch_name, branch in snapshot["branches"].items(): original_branch_name = branch_name while branch and branch.get("target_type") == "alias": branch_name = branch["target"] branch = snapshot["branches"].get(branch_name) if branch is None or not branch_name: continue self.write_edge( (ExtendedObjectType.SNAPSHOT, snapshot["id"]), (ExtendedObjectType[branch["target_type"].upper()], branch["target"]), - labels=[base64.b64encode(original_branch_name).decode(),], + labels=[ + base64.b64encode(original_branch_name).decode(), + ], ) def process_release(self, release): self.write_node((ExtendedObjectType.RELEASE, release["id"])) self.write_edge( (ExtendedObjectType.RELEASE, release["id"]), (ExtendedObjectType[release["target_type"].upper()], release["target"]), ) def process_revision(self, revision): self.write_node((ExtendedObjectType.REVISION, revision["id"])) self.write_edge( (ExtendedObjectType.REVISION, revision["id"]), (ExtendedObjectType.DIRECTORY, revision["directory"]), ) for parent in revision["parents"]: self.write_edge( (ExtendedObjectType.REVISION, revision["id"]), (ExtendedObjectType.REVISION, parent), ) def process_directory(self, directory): self.write_node((ExtendedObjectType.DIRECTORY, directory["id"])) for entry in directory["entries"]: entry_type_mapping = { "file": ExtendedObjectType.CONTENT, "dir": ExtendedObjectType.DIRECTORY, "rev": ExtendedObjectType.REVISION, } self.write_edge( (ExtendedObjectType.DIRECTORY, directory["id"]), (entry_type_mapping[entry["type"]], entry["target"]), labels=[base64.b64encode(entry["name"]).decode(), str(entry["perms"])], ) def process_content(self, content): self.write_node((ExtendedObjectType.CONTENT, content["sha1_git"])) def sort_graph_nodes(export_path, config): """ Generate the node list from the edges files. We cannot solely rely on the object IDs that are read in the journal, as some nodes that are referred to as destinations in the edge file might not be present in the archive (e.g a rev_entry referring to a revision that we do not have crawled yet). The most efficient way of getting all the nodes that are mentioned in the edges file is therefore to use sort(1) on the gigantic edge files to get all the unique node IDs, while using the disk as a temporary buffer. This pipeline does, in order: - concatenate and write all the compressed edges files in graph.edges.csv.zst (using the fact that ZST compression is an additive function) ; - deflate the edges ; - count the number of edges and write it in graph.edges.count.txt ; - count the number of occurrences of each edge type and write them in graph.edges.stats.txt ; - concatenate all the (deflated) nodes from the export with the destination edges, and sort the output to get the list of unique graph nodes ; - count the number of unique graph nodes and write it in graph.nodes.count.txt ; - count the number of occurrences of each node type and write them in graph.nodes.stats.txt ; - compress and write the resulting nodes in graph.nodes.csv.zst. """ # Use awk as a replacement of `sort | uniq -c` to avoid buffering everything # in memory counter_command = "awk '{ t[$0]++ } END { for (i in t) print i,t[i] }'" sort_script = """ pv {export_path}/*/*.edges.csv.zst | tee {export_path}/graph.edges.csv.zst | zstdcat | tee >( wc -l > {export_path}/graph.edges.count.txt ) | tee >( cut -d: -f3,6 | {counter_command} | sort \ > {export_path}/graph.edges.stats.txt ) | tee >( cut -d' ' -f3 | grep . | \ sort -u -S{sort_buffer_size} -T{buffer_path} | \ zstdmt > {export_path}/graph.labels.csv.zst ) | cut -d' ' -f2 | cat - <( zstdcat {export_path}/*/*.nodes.csv.zst ) | sort -u -S{sort_buffer_size} -T{buffer_path} | tee >( wc -l > {export_path}/graph.nodes.count.txt ) | tee >( cut -d: -f3 | {counter_command} | sort \ > {export_path}/graph.nodes.stats.txt ) | zstdmt > {export_path}/graph.nodes.csv.zst """ # Use bytes for the sorting algorithm (faster than being locale-specific) env = { **os.environ.copy(), "LC_ALL": "C", "LC_COLLATE": "C", "LANG": "C", } sort_buffer_size = config.get("sort_buffer_size", "4G") disk_buffer_dir = config.get("disk_buffer_dir", export_path) with tempfile.TemporaryDirectory( prefix=".graph_node_sort_", dir=disk_buffer_dir ) as buffer_path: subprocess.run( [ "bash", "-c", sort_script.format( export_path=shlex.quote(str(export_path)), buffer_path=shlex.quote(str(buffer_path)), sort_buffer_size=shlex.quote(sort_buffer_size), counter_command=counter_command, ), ], env=env, ) diff --git a/swh/dataset/exporters/orc.py b/swh/dataset/exporters/orc.py index 6d99c0d..972494e 100644 --- a/swh/dataset/exporters/orc.py +++ b/swh/dataset/exporters/orc.py @@ -1,372 +1,384 @@ # Copyright (C) 2020-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime import logging import math from types import TracebackType from typing import Any, Optional, Tuple, Type, cast from pkg_resources import get_distribution from pyorc import ( BigInt, Binary, CompressionKind, Int, SmallInt, String, Struct, Timestamp, TypeKind, Writer, ) from pyorc.converters import ORCConverter from swh.dataset.exporter import ExporterDispatch from swh.dataset.relational import MAIN_TABLES, TABLES from swh.dataset.utils import remove_pull_requests from swh.model.hashutil import hash_to_hex from swh.model.model import TimestampWithTimezone from swh.objstorage.factory import get_objstorage from swh.objstorage.objstorage import ID_HASH_ALGO, ObjNotFoundError ORC_TYPE_MAP = { "string": String, "smallint": SmallInt, "int": Int, "bigint": BigInt, "timestamp": Timestamp, "binary": Binary, } EXPORT_SCHEMA = { table_name: Struct( **{ column_name: ORC_TYPE_MAP[column_type]() for column_name, column_type in columns } ) for table_name, columns in TABLES.items() } logger = logging.getLogger(__name__) def hash_to_hex_or_none(hash): return hash_to_hex(hash) if hash is not None else None def swh_date_to_tuple(obj): if obj is None or obj["timestamp"] is None: return (None, None, None) offset_bytes = obj.get("offset_bytes") if offset_bytes is None: offset = obj.get("offset", 0) negative = offset < 0 or obj.get("negative_utc", False) (hours, minutes) = divmod(abs(offset), 60) offset_bytes = f"{'-' if negative else '+'}{hours:02}{minutes:02}".encode() else: offset = TimestampWithTimezone._parse_offset_bytes(offset_bytes) return ( (obj["timestamp"]["seconds"], obj["timestamp"]["microseconds"]), offset, offset_bytes, ) def datetime_to_tuple(obj: Optional[datetime]) -> Optional[Tuple[int, int]]: if obj is None: return None return (math.floor(obj.timestamp()), obj.microsecond) class SWHTimestampConverter: """This is an ORCConverter compatible class to convert timestamps from/to ORC files timestamps in python are given as a couple (seconds, microseconds) and are serialized as a couple (seconds, nanoseconds) in the ORC file. Reimplemented because we do not want the Python object to be converted as ORC timestamp to be Python datatime objects, since swh.model's Timestamp cannot be converted without loss a Python datetime objects. """ # use Any as timezone annotation to make it easier to run mypy on python < # 3.9, plus we do not use the timezone argument here... @staticmethod - def from_orc(seconds: int, nanoseconds: int, timezone: Any,) -> Tuple[int, int]: + def from_orc( + seconds: int, + nanoseconds: int, + timezone: Any, + ) -> Tuple[int, int]: return (seconds, nanoseconds // 1000) @staticmethod def to_orc( - obj: Optional[Tuple[int, int]], timezone: Any, + obj: Optional[Tuple[int, int]], + timezone: Any, ) -> Optional[Tuple[int, int]]: if obj is None: return None return (obj[0], obj[1] * 1000 if obj[1] is not None else None) class ORCExporter(ExporterDispatch): """ Implementation of an exporter which writes the entire graph dataset as ORC files. Useful for large scale processing, notably on cloud instances (e.g BigQuery, Amazon Athena, Azure). """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) config = self.config.get("orc", {}) self.max_rows = config.get("max_rows", {}) invalid_tables = [ table_name for table_name in self.max_rows if table_name not in MAIN_TABLES ] if invalid_tables: raise ValueError( "Limiting the number of secondary table (%s) is not supported " "for now.", invalid_tables, ) self.with_data = config.get("with_data", False) self.objstorage = None if self.with_data: assert "objstorage" in config self.objstorage = get_objstorage(**config["objstorage"]) self._reset() def _reset(self): self.writers = {} self.writer_files = {} self.uuids = {} self.uuid_main_table = {} def __exit__( self, exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> Optional[bool]: for writer in self.writers.values(): writer.close() for fileobj in self.writer_files.values(): fileobj.close() self._reset() return super().__exit__(exc_type, exc_value, traceback) def maybe_close_writer_for(self, table_name: str): uuid = self.uuids.get(table_name) if ( uuid is not None and table_name in self.max_rows and self.writers[table_name].current_row >= self.max_rows[table_name] ): main_table = self.uuid_main_table[uuid] if table_name != main_table: logger.warning( "Limiting the number of secondary table (%s) is not supported " "for now (size limit ignored).", table_name, ) else: # sync/close all tables having the current uuid (aka main and # related tables) for table in [ tname for tname, tuuid in self.uuids.items() if tuuid == uuid ]: # close the writer and remove from the writers dict self.writers.pop(table).close() self.writer_files.pop(table).close() # and clean uuids dicts self.uuids.pop(table) self.uuid_main_table.pop(uuid, None) def get_writer_for(self, table_name: str, unique_id=None): self.maybe_close_writer_for(table_name) if table_name not in self.writers: object_type_dir = self.export_path / table_name object_type_dir.mkdir(exist_ok=True) if unique_id is None: unique_id = self.get_unique_file_id() self.uuid_main_table[unique_id] = table_name export_file = object_type_dir / (f"{table_name}-{unique_id}.orc") export_obj = export_file.open("wb") self.writer_files[table_name] = export_obj self.writers[table_name] = Writer( export_obj, EXPORT_SCHEMA[table_name], compression=CompressionKind.ZSTD, converters={ TypeKind.TIMESTAMP: cast(Type[ORCConverter], SWHTimestampConverter) }, ) self.writers[table_name].set_user_metadata( swh_object_type=table_name.encode(), swh_uuid=unique_id.encode(), swh_model_version=get_distribution("swh.model").version.encode(), swh_dataset_version=get_distribution("swh.dataset").version.encode(), # maybe put a copy of the config (redacted) also? ) self.uuids[table_name] = unique_id return self.writers[table_name] def process_origin(self, origin): origin_writer = self.get_writer_for("origin") origin_writer.write((origin["url"],)) def process_origin_visit(self, visit): origin_visit_writer = self.get_writer_for("origin_visit") origin_visit_writer.write( ( visit["origin"], visit["visit"], datetime_to_tuple(visit["date"]), visit["type"], ) ) def process_origin_visit_status(self, visit_status): origin_visit_status_writer = self.get_writer_for("origin_visit_status") origin_visit_status_writer.write( ( visit_status["origin"], visit_status["visit"], datetime_to_tuple(visit_status["date"]), visit_status["status"], hash_to_hex_or_none(visit_status["snapshot"]), visit_status["type"], ) ) def process_snapshot(self, snapshot): if self.config.get("orc", {}).get("remove_pull_requests"): remove_pull_requests(snapshot) snapshot_writer = self.get_writer_for("snapshot") snapshot_writer.write((hash_to_hex_or_none(snapshot["id"]),)) # we want to store branches in the same directory as snapshot objects, # and have both files have the same UUID. snapshot_branch_writer = self.get_writer_for( - "snapshot_branch", unique_id=self.uuids["snapshot"], + "snapshot_branch", + unique_id=self.uuids["snapshot"], ) for branch_name, branch in snapshot["branches"].items(): if branch is None: continue snapshot_branch_writer.write( ( hash_to_hex_or_none(snapshot["id"]), branch_name, hash_to_hex_or_none(branch["target"]), branch["target_type"], ) ) def process_release(self, release): release_writer = self.get_writer_for("release") release_writer.write( ( hash_to_hex_or_none(release["id"]), release["name"], release["message"], hash_to_hex_or_none(release["target"]), release["target_type"], (release.get("author") or {}).get("fullname"), *swh_date_to_tuple(release["date"]), release.get("raw_manifest"), ) ) def process_revision(self, revision): release_writer = self.get_writer_for("revision") release_writer.write( ( hash_to_hex_or_none(revision["id"]), revision["message"], revision["author"]["fullname"], *swh_date_to_tuple(revision["date"]), revision["committer"]["fullname"], *swh_date_to_tuple(revision["committer_date"]), hash_to_hex_or_none(revision["directory"]), revision["type"], revision.get("raw_manifest"), ) ) revision_history_writer = self.get_writer_for( - "revision_history", unique_id=self.uuids["revision"], + "revision_history", + unique_id=self.uuids["revision"], ) for i, parent_id in enumerate(revision["parents"]): revision_history_writer.write( ( hash_to_hex_or_none(revision["id"]), hash_to_hex_or_none(parent_id), i, ) ) revision_header_writer = self.get_writer_for( - "revision_extra_headers", unique_id=self.uuids["revision"], + "revision_extra_headers", + unique_id=self.uuids["revision"], ) for key, value in revision["extra_headers"]: revision_header_writer.write( (hash_to_hex_or_none(revision["id"]), key, value) ) def process_directory(self, directory): directory_writer = self.get_writer_for("directory") directory_writer.write( - (hash_to_hex_or_none(directory["id"]), directory.get("raw_manifest"),) + ( + hash_to_hex_or_none(directory["id"]), + directory.get("raw_manifest"), + ) ) directory_entry_writer = self.get_writer_for( - "directory_entry", unique_id=self.uuids["directory"], + "directory_entry", + unique_id=self.uuids["directory"], ) for entry in directory["entries"]: directory_entry_writer.write( ( hash_to_hex_or_none(directory["id"]), entry["name"], entry["type"], hash_to_hex_or_none(entry["target"]), entry["perms"], ) ) def process_content(self, content): content_writer = self.get_writer_for("content") data = None if self.with_data: obj_id = content[ID_HASH_ALGO] try: data = self.objstorage.get(obj_id) except ObjNotFoundError: logger.warning(f"Missing object {hash_to_hex(obj_id)}") content_writer.write( ( hash_to_hex_or_none(content["sha1"]), hash_to_hex_or_none(content["sha1_git"]), hash_to_hex_or_none(content["sha256"]), hash_to_hex_or_none(content["blake2s256"]), content["length"], content["status"], data, ) ) def process_skipped_content(self, skipped_content): skipped_content_writer = self.get_writer_for("skipped_content") skipped_content_writer.write( ( hash_to_hex_or_none(skipped_content["sha1"]), hash_to_hex_or_none(skipped_content["sha1_git"]), hash_to_hex_or_none(skipped_content["sha256"]), hash_to_hex_or_none(skipped_content["blake2s256"]), skipped_content["length"], skipped_content["status"], skipped_content["reason"], ) ) diff --git a/swh/dataset/journalprocessor.py b/swh/dataset/journalprocessor.py index ced495e..932b6df 100644 --- a/swh/dataset/journalprocessor.py +++ b/swh/dataset/journalprocessor.py @@ -1,460 +1,462 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import collections import concurrent.futures from concurrent.futures import FIRST_EXCEPTION, ProcessPoolExecutor import contextlib import logging import multiprocessing from pathlib import Path import time from typing import ( Any, Container, Dict, List, Mapping, Optional, Sequence, Set, Tuple, Type, ) from confluent_kafka import Message, TopicPartition import tqdm from swh.dataset.exporter import Exporter from swh.dataset.utils import LevelDBSet from swh.journal.client import JournalClient from swh.journal.serializers import kafka_to_value from swh.storage.fixer import fix_objects logger = logging.getLogger(__name__) class JournalClientOffsetRanges(JournalClient): """ A subclass of JournalClient reading only inside some specific offset range. Partition assignments have to be manually given to the class. This client can only read a single topic at a time. """ def __init__( self, *args, offset_ranges: Mapping[int, Tuple[int, int]] = None, assignment: Sequence[int] = None, progress_queue: multiprocessing.Queue = None, refresh_every: int = 200, **kwargs, ): """ Args: offset_ranges: A mapping of partition_id -> (low, high) offsets that define the boundaries of the messages to consume. assignment: The list of partitions to assign to this client. progress_queue: a multiprocessing.Queue where the current progress will be reported. refresh_every: the refreshing rate of the progress reporting. """ self.offset_ranges = offset_ranges self.progress_queue = progress_queue self.refresh_every = refresh_every self.assignment = assignment self._messages_to_commit: List[Message] = [] self._partitions_to_unsubscribe: Set[int] = set() self.count = None self.topic_name: Optional[str] = None kwargs["stop_on_eof"] = True # Stop when the assignment is empty super().__init__(*args, **kwargs) def subscribe(self): self.topic_name = self.subscription[0] time.sleep(0.1) # https://github.com/edenhill/librdkafka/issues/1983 logger.debug("Changing assignment to %s", str(self.assignment)) self.consumer.assign( [TopicPartition(self.topic_name, pid) for pid in self.assignment] ) def unsubscribe(self, partitions: Container[int]): assert self.assignment is not None self.assignment = [pid for pid in self.assignment if pid not in partitions] self.consumer.assign( [TopicPartition(self.topic_name, pid) for pid in self.assignment] ) def process(self, worker_fn): self.count = 0 try: if self.assignment: super().process(worker_fn) finally: self.progress_queue.put(None) def handle_offset(self, message): """ Check whether the client has reached the end of the current partition, and trigger a reassignment if that is the case. """ offset = message.offset() partition_id = message.partition() if offset < 0: # Uninitialized partition offset return if self.count % self.refresh_every == 0: self.progress_queue.put({partition_id: offset}) if offset >= self.offset_ranges[partition_id][1] - 1: if partition_id in self.assignment: self.progress_queue.put({partition_id: offset}) # unsubscribe from partition but make sure current message's # offset will be committed after executing the worker_fn in # process(); see handle_messages() below self._messages_to_commit.append(message) # delay the unsubcription to handle_messages() to prevent # rdkakfa errors like # # rd_kafka_assignment_partition_stopped: # Assertion `rktp->rktp_started' failed # # in case the unsubscription from parition_id do actually tries # to subscribe an already depleted partition. self._partitions_to_unsubscribe.add(partition_id) def deserialize_message(self, message, object_type=None): """ Override of the message deserialization to hook the handling of the message offset. We also return the raw objects instead of deserializing them because we will need the partition ID later. """ self.handle_offset(message) self.count += 1 return message def handle_messages(self, messages, worker_fn): """Override of the handle_messages() method to get a chance to commit messages. Make sure messages properly handled by `worker_fn` (executed in super()) do get committed in kafka even if their originating partition has been desubscribed from. This helps having a consistent view of the consumption of each partition at the end of the export process (EOF). """ nb_processed, at_eof = super().handle_messages(messages, worker_fn) for msg in self._messages_to_commit: self.consumer.commit(message=msg) self._messages_to_commit.clear() if self._partitions_to_unsubscribe: partitions = list(self._partitions_to_unsubscribe) self._partitions_to_unsubscribe.clear() self.unsubscribe(partitions) return nb_processed, at_eof class ParallelJournalProcessor: """ Reads the given object type from the journal in parallel. It creates one JournalExportWorker per process. """ def __init__( self, config, exporters: Sequence[Tuple[Type[Exporter], Dict[str, Any]]], export_id: str, obj_type: str, node_sets_path: Path, processes: int = 1, ): """ Args: config: the exporter config, which should also include the JournalClient configuration. exporters: a list of Exporter to process the objects export_id: a unique identifier for the export that will be used as part of a Kafka consumer group ID. obj_type: The type of SWH object to export. node_sets_path: A directory where to store the node sets. processes: The number of processes to run. """ self.config = config self.exporters = exporters prefix = self.config["journal"].get("group_id", "swh-dataset-export-") self.group_id = f"{prefix}{export_id}" self.obj_type = obj_type self.processes = processes self.node_sets_path = node_sets_path self.offsets = None def get_offsets(self): """ Compute (lo, high) offset boundaries for all partitions. First pass to fetch all the current low and high watermark offsets of each partition to define the consumption boundaries. If available, use committed offsets as lo offset boundaries. """ if self.offsets is None: cfg = self.config["journal"].copy() cfg["object_types"] = [self.obj_type] cfg["group_id"] = self.group_id client = JournalClient(**cfg) topic_name = client.subscription[0] topics = client.consumer.list_topics(topic_name).topics partitions = topics[topic_name].partitions self.offsets = {} # LOW watermark offset: The offset of the earliest message in the # topic/partition. If no messages have been written to the topic, # the low watermark offset is set to 0. The low watermark will also # be 0 if one message has been written to the partition (with # offset 0). # HIGH watermark offset: the offset of the latest message in the # topic/partition available for consumption + 1 def fetch_insert_partition_id(partition_id): logger.debug("Fetching offset for partition %s", partition_id) tp = TopicPartition(topic_name, partition_id) (lo, hi) = client.consumer.get_watermark_offsets(tp) logger.debug("[%s] (lo,hi)=(%s, %s)", partition_id, lo, hi) if hi > lo: # hi == low means there is nothing in the partition to consume. # If the partition is not empty, retrieve the committed offset, # if any, to use it at lo offset. committed = client.consumer.committed([tp])[0] logger.debug( "[%s] committed offset: %s", partition_id, committed.offset ) lo = max(lo, committed.offset) if hi > lo: # do only process the partition is there are actually new # messages to process (partition not empty and committed # offset is behind the high watermark). self.offsets[partition_id] = (lo, hi) logger.debug( "Fetching partition offsets using %s processes", self.processes ) with concurrent.futures.ThreadPoolExecutor( max_workers=self.processes ) as executor: list( tqdm.tqdm( executor.map(fetch_insert_partition_id, partitions.keys()), total=len(partitions), desc=" - Offset", ) ) client.close() return self.offsets def run(self): """ Run the parallel export. """ offsets = self.get_offsets() to_assign = list(offsets.keys()) if not to_assign: print(f" - Export ({self.obj_type}): skipped (nothing to export)") return manager = multiprocessing.Manager() q = manager.Queue() with ProcessPoolExecutor(self.processes + 1) as pool: futures = [] for i in range(self.processes): futures.append( pool.submit( self.export_worker, assignment=to_assign[i :: self.processes], progress_queue=q, ) ) futures.append(pool.submit(self.progress_worker, queue=q)) concurrent.futures.wait(futures, return_when=FIRST_EXCEPTION) for f in futures: if f.running(): continue exc = f.exception() if exc: pool.shutdown(wait=False) f.result() raise exc def progress_worker(self, queue=None): """ An additional worker process that reports the current progress of the export between all the different parallel consumers and across all the partitions, by consuming the shared progress reporting Queue. """ d = {} active_workers = self.processes offset_diff = sum((hi - lo) for lo, hi in self.offsets.values()) desc = " - Export" with tqdm.tqdm(total=offset_diff, desc=desc, unit_scale=True) as pbar: while active_workers: item = queue.get() if item is None: active_workers -= 1 continue d.update(item) progress = sum(n + 1 - self.offsets[p][0] for p, n in d.items()) - pbar.set_postfix(workers=f"{active_workers}/{self.processes}",) + pbar.set_postfix( + workers=f"{active_workers}/{self.processes}", + ) pbar.update(progress - pbar.n) def export_worker(self, assignment, progress_queue): worker = JournalProcessorWorker( self.config, self.exporters, self.group_id, self.obj_type, self.offsets, assignment, progress_queue, self.node_sets_path, ) with worker: worker.run() class JournalProcessorWorker: """ Worker process that processes all the messages and calls the given exporters for each object read from the journal. """ def __init__( self, config, exporters: Sequence[Tuple[Type[Exporter], Dict[str, Any]]], group_id: str, obj_type: str, offsets: Dict[int, Tuple[int, int]], assignment: Sequence[int], progress_queue: multiprocessing.Queue, node_sets_path: Path, ): self.config = config self.group_id = group_id self.obj_type = obj_type self.offsets = offsets self.assignment = assignment self.progress_queue = progress_queue self.node_sets_path = node_sets_path self.node_sets_path.mkdir(exist_ok=True, parents=True) self.node_sets: Dict[Tuple[int, str], LevelDBSet] = {} self.exporters = [ exporter_class(config, **kwargs) for exporter_class, kwargs in exporters ] self.exit_stack: contextlib.ExitStack = contextlib.ExitStack() def __enter__(self): self.exit_stack.__enter__() for exporter in self.exporters: self.exit_stack.enter_context(exporter) return self def __exit__(self, exc_type, exc_value, traceback): self.exit_stack.__exit__(exc_type, exc_value, traceback) def get_node_set_for_object(self, partition_id: int, object_id: bytes): """ Return an on-disk set object, which stores the nodes that have already been processed. Node sets are sharded by partition ID (as each object is guaranteed to be assigned to a deterministic Kafka partition) then by object ID prefix. The sharding path of each file looks like: .node_sets/{origin..content}/part-{0..256}/nodes-{0..f}.sqlite """ # obj_id_prefix = "{:x}".format(object_id[0] % 16) obj_id_prefix = "all" # disable sharding for now shard_id = (partition_id, obj_id_prefix) if shard_id not in self.node_sets: node_set_dir = ( self.node_sets_path / self.obj_type / ("part-{}".format(str(partition_id))) ) node_set_dir.mkdir(exist_ok=True, parents=True) node_set_file = node_set_dir / "nodes-{}.db".format(obj_id_prefix) node_set = LevelDBSet(node_set_file) self.exit_stack.enter_context(node_set) self.node_sets[shard_id] = node_set return self.node_sets[shard_id] def run(self): """ Start a Journal client on the given assignment and process all the incoming messages. """ logger.debug("Start the JournalProcessorWorker") cfg = self.config["journal"].copy() cfg.update( object_types=[self.obj_type], group_id=self.group_id, debug="cgrp,broker", offset_ranges=self.offsets, assignment=self.assignment, progress_queue=self.progress_queue, **{"message.max.bytes": str(500 * 1024 * 1024)}, ) client = JournalClientOffsetRanges(**cfg) client.process(self.process_messages) def process_messages(self, messages): """ Process the incoming Kafka messages. """ for object_type, message_list in messages.items(): fixed_objects_by_partition = collections.defaultdict(list) for message in message_list: fixed_objects_by_partition[message.partition()].extend( zip( [message.key()], fix_objects(object_type, [kafka_to_value(message.value())]), ) ) for partition, objects in fixed_objects_by_partition.items(): for (key, obj) in objects: self.process_message(object_type, partition, key, obj) def process_message(self, object_type, partition, obj_key, obj): """ Process a single incoming Kafka message if the object it refers to has not been processed yet. It uses an on-disk set to make sure that each object is only ever processed once. """ node_set = self.get_node_set_for_object(partition, obj_key) if not node_set.add(obj_key): # Node already processed, skipping. return for exporter in self.exporters: try: exporter.process_object(object_type, obj) except Exception: logger.exception( "Exporter %s: error while exporting the object: %s", exporter.__class__.__name__, str(obj), ) diff --git a/swh/dataset/test/test_edges.py b/swh/dataset/test/test_edges.py index ec8fa77..bb43369 100644 --- a/swh/dataset/test/test_edges.py +++ b/swh/dataset/test/test_edges.py @@ -1,576 +1,583 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from base64 import b64encode import collections import hashlib from typing import Tuple from unittest.mock import Mock, call import pytest from swh.dataset.exporters.edges import GraphEdgesExporter, sort_graph_nodes from swh.dataset.utils import ZSTFile from swh.model.hashutil import MultiHash, hash_to_bytes DATE = { "timestamp": {"seconds": 1234567891, "microseconds": 0}, "offset": 120, "negative_utc": False, } TEST_CONTENT = { **MultiHash.from_data(b"foo").digest(), "length": 3, "status": "visible", } TEST_REVISION = { "id": hash_to_bytes("7026b7c1a2af56521e951c01ed20f255fa054238"), "message": b"hello", "date": DATE, "committer": {"fullname": b"foo", "name": b"foo", "email": b""}, "author": {"fullname": b"foo", "name": b"foo", "email": b""}, "committer_date": DATE, "type": "git", "directory": b"\x01" * 20, "synthetic": False, "metadata": None, "parents": [], } TEST_RELEASE = { "id": hash_to_bytes("d81cc0710eb6cf9efd5b920a8453e1e07157b6cd"), "name": b"v0.0.1", "date": { "timestamp": {"seconds": 1234567890, "microseconds": 0}, "offset": 120, "negative_utc": False, }, "author": {"author": {"fullname": b"foo", "name": b"foo", "email": b""}}, "target_type": "revision", "target": b"\x04" * 20, "message": b"foo", "synthetic": False, } TEST_ORIGIN = {"url": "https://somewhere.org/den/fox"} TEST_ORIGIN_2 = {"url": "https://somewhere.org/den/fox/2"} TEST_ORIGIN_VISIT_STATUS = { "origin": TEST_ORIGIN["url"], "visit": 1, "date": "2013-05-07 04:20:39.369271+00:00", "snapshot": None, # TODO "status": "ongoing", # TODO "metadata": {"foo": "bar"}, } class FakeDiskSet(set): """ A set with an add() method that returns whether the item has been added or was already there. Used to replace disk sets in unittests. """ def add(self, v): assert isinstance(v, bytes) r = True if v in self: r = False super().add(v) return r @pytest.fixture def exporter(): def wrapped(messages, config=None) -> Tuple[Mock, Mock]: if config is None: config = {} exporter = GraphEdgesExporter(config, "/dummy_path") node_writer = Mock() edge_writer = Mock() exporter.get_writers_for = lambda *a, **k: ( # type: ignore node_writer, edge_writer, ) for object_type, objects in messages.items(): for obj in objects: exporter.process_object(object_type, obj) return node_writer.write, edge_writer.write return wrapped def binhash(s): return hashlib.sha1(s.encode()).digest() def hexhash(s): return hashlib.sha1(s.encode()).hexdigest() def b64e(s: str) -> str: return b64encode(s.encode()).decode() def test_export_origin(exporter): - node_writer, edge_writer = exporter({"origin": [{"url": "ori1"}, {"url": "ori2"},]}) + node_writer, edge_writer = exporter( + { + "origin": [ + {"url": "ori1"}, + {"url": "ori2"}, + ] + } + ) assert node_writer.mock_calls == [ call(f"swh:1:ori:{hexhash('ori1')}\n"), call(f"swh:1:ori:{hexhash('ori2')}\n"), ] assert edge_writer.mock_calls == [] def test_export_origin_visit_status(exporter): node_writer, edge_writer = exporter( { "origin_visit_status": [ { **TEST_ORIGIN_VISIT_STATUS, "origin": "ori1", "snapshot": binhash("snp1"), }, { **TEST_ORIGIN_VISIT_STATUS, "origin": "ori2", "snapshot": binhash("snp2"), }, ] } ) assert node_writer.mock_calls == [] assert edge_writer.mock_calls == [ call(f"swh:1:ori:{hexhash('ori1')} swh:1:snp:{hexhash('snp1')}\n"), call(f"swh:1:ori:{hexhash('ori2')} swh:1:snp:{hexhash('snp2')}\n"), ] def test_export_snapshot_simple(exporter): node_writer, edge_writer = exporter( { "snapshot": [ { "id": binhash("snp1"), "branches": { b"refs/heads/master": { "target": binhash("rev1"), "target_type": "revision", }, b"HEAD": {"target": binhash("rev1"), "target_type": "revision"}, }, }, { "id": binhash("snp2"), "branches": { b"refs/heads/master": { "target": binhash("rev1"), "target_type": "revision", }, b"HEAD": {"target": binhash("rev2"), "target_type": "revision"}, b"bcnt": {"target": binhash("cnt1"), "target_type": "content"}, b"bdir": { "target": binhash("dir1"), "target_type": "directory", }, b"brel": {"target": binhash("rel1"), "target_type": "release"}, b"bsnp": {"target": binhash("snp1"), "target_type": "snapshot"}, }, }, {"id": binhash("snp3"), "branches": {}}, ] } ) assert node_writer.mock_calls == [ call(f"swh:1:snp:{hexhash('snp1')}\n"), call(f"swh:1:snp:{hexhash('snp2')}\n"), call(f"swh:1:snp:{hexhash('snp3')}\n"), ] assert edge_writer.mock_calls == [ call( f"swh:1:snp:{hexhash('snp1')} swh:1:rev:{hexhash('rev1')}" f" {b64e('refs/heads/master')}\n" ), call( f"swh:1:snp:{hexhash('snp1')} swh:1:rev:{hexhash('rev1')}" f" {b64e('HEAD')}\n" ), call( f"swh:1:snp:{hexhash('snp2')} swh:1:rev:{hexhash('rev1')}" f" {b64e('refs/heads/master')}\n" ), call( f"swh:1:snp:{hexhash('snp2')} swh:1:rev:{hexhash('rev2')}" f" {b64e('HEAD')}\n" ), call( f"swh:1:snp:{hexhash('snp2')} swh:1:cnt:{hexhash('cnt1')}" f" {b64e('bcnt')}\n" ), call( f"swh:1:snp:{hexhash('snp2')} swh:1:dir:{hexhash('dir1')}" f" {b64e('bdir')}\n" ), call( f"swh:1:snp:{hexhash('snp2')} swh:1:rel:{hexhash('rel1')}" f" {b64e('brel')}\n" ), call( f"swh:1:snp:{hexhash('snp2')} swh:1:snp:{hexhash('snp1')}" f" {b64e('bsnp')}\n" ), ] def test_export_snapshot_aliases(exporter): node_writer, edge_writer = exporter( { "snapshot": [ { "id": binhash("snp1"), "branches": { b"origin_branch": { "target": binhash("rev1"), "target_type": "revision", }, b"alias1": {"target": b"origin_branch", "target_type": "alias"}, b"alias2": {"target": b"alias1", "target_type": "alias"}, b"alias3": {"target": b"alias2", "target_type": "alias"}, }, }, ] } ) assert node_writer.mock_calls == [call(f"swh:1:snp:{hexhash('snp1')}\n")] assert edge_writer.mock_calls == [ call( f"swh:1:snp:{hexhash('snp1')} swh:1:rev:{hexhash('rev1')}" f" {b64e('origin_branch')}\n" ), call( f"swh:1:snp:{hexhash('snp1')} swh:1:rev:{hexhash('rev1')}" f" {b64e('alias1')}\n" ), call( f"swh:1:snp:{hexhash('snp1')} swh:1:rev:{hexhash('rev1')}" f" {b64e('alias2')}\n" ), call( f"swh:1:snp:{hexhash('snp1')} swh:1:rev:{hexhash('rev1')}" f" {b64e('alias3')}\n" ), ] def test_export_snapshot_no_pull_requests(exporter): snp = { "id": binhash("snp1"), "branches": { b"refs/heads/master": { "target": binhash("rev1"), "target_type": "revision", }, b"refs/pull/42": {"target": binhash("rev2"), "target_type": "revision"}, b"refs/merge-requests/lol": { "target": binhash("rev3"), "target_type": "revision", }, b"refs/tags/v1.0.0": { "target": binhash("rev4"), "target_type": "revision", }, b"refs/patch/123456abc": { "target": binhash("rev5"), "target_type": "revision", }, }, } node_writer, edge_writer = exporter({"snapshot": [snp]}) assert edge_writer.mock_calls == [ call( f"swh:1:snp:{hexhash('snp1')} swh:1:rev:{hexhash('rev1')}" f" {b64e('refs/heads/master')}\n" ), call( f"swh:1:snp:{hexhash('snp1')} swh:1:rev:{hexhash('rev2')}" f" {b64e('refs/pull/42')}\n" ), call( f"swh:1:snp:{hexhash('snp1')} swh:1:rev:{hexhash('rev3')}" f" {b64e('refs/merge-requests/lol')}\n" ), call( f"swh:1:snp:{hexhash('snp1')} swh:1:rev:{hexhash('rev4')}" f" {b64e('refs/tags/v1.0.0')}\n" ), call( f"swh:1:snp:{hexhash('snp1')} swh:1:rev:{hexhash('rev5')}" f" {b64e('refs/patch/123456abc')}\n" ), ] node_writer, edge_writer = exporter( {"snapshot": [snp]}, config={"remove_pull_requests": True} ) assert edge_writer.mock_calls == [ call( f"swh:1:snp:{hexhash('snp1')} swh:1:rev:{hexhash('rev1')}" f" {b64e('refs/heads/master')}\n" ), call( f"swh:1:snp:{hexhash('snp1')} swh:1:rev:{hexhash('rev4')}" f" {b64e('refs/tags/v1.0.0')}\n" ), ] def test_export_releases(exporter): node_writer, edge_writer = exporter( { "release": [ { **TEST_RELEASE, "id": binhash("rel1"), "target": binhash("rev1"), "target_type": "revision", }, { **TEST_RELEASE, "id": binhash("rel2"), "target": binhash("rel1"), "target_type": "release", }, { **TEST_RELEASE, "id": binhash("rel3"), "target": binhash("dir1"), "target_type": "directory", }, { **TEST_RELEASE, "id": binhash("rel4"), "target": binhash("cnt1"), "target_type": "content", }, ] } ) assert node_writer.mock_calls == [ call(f"swh:1:rel:{hexhash('rel1')}\n"), call(f"swh:1:rel:{hexhash('rel2')}\n"), call(f"swh:1:rel:{hexhash('rel3')}\n"), call(f"swh:1:rel:{hexhash('rel4')}\n"), ] assert edge_writer.mock_calls == [ call(f"swh:1:rel:{hexhash('rel1')} swh:1:rev:{hexhash('rev1')}\n"), call(f"swh:1:rel:{hexhash('rel2')} swh:1:rel:{hexhash('rel1')}\n"), call(f"swh:1:rel:{hexhash('rel3')} swh:1:dir:{hexhash('dir1')}\n"), call(f"swh:1:rel:{hexhash('rel4')} swh:1:cnt:{hexhash('cnt1')}\n"), ] def test_export_revision(exporter): node_writer, edge_writer = exporter( { "revision": [ { **TEST_REVISION, "id": binhash("rev1"), "directory": binhash("dir1"), "parents": [binhash("rev2"), binhash("rev3")], }, { **TEST_REVISION, "id": binhash("rev2"), "directory": binhash("dir2"), "parents": [], }, ] } ) assert node_writer.mock_calls == [ call(f"swh:1:rev:{hexhash('rev1')}\n"), call(f"swh:1:rev:{hexhash('rev2')}\n"), ] assert edge_writer.mock_calls == [ call(f"swh:1:rev:{hexhash('rev1')} swh:1:dir:{hexhash('dir1')}\n"), call(f"swh:1:rev:{hexhash('rev1')} swh:1:rev:{hexhash('rev2')}\n"), call(f"swh:1:rev:{hexhash('rev1')} swh:1:rev:{hexhash('rev3')}\n"), call(f"swh:1:rev:{hexhash('rev2')} swh:1:dir:{hexhash('dir2')}\n"), ] def test_export_directory(exporter): node_writer, edge_writer = exporter( { "directory": [ { "id": binhash("dir1"), "entries": [ { "type": "file", "target": binhash("cnt1"), "name": b"cnt1", "perms": 0o644, }, { "type": "dir", "target": binhash("dir2"), "name": b"dir2", "perms": 0o755, }, { "type": "rev", "target": binhash("rev1"), "name": b"rev1", "perms": 0o160000, }, ], }, {"id": binhash("dir2"), "entries": []}, ] } ) assert node_writer.mock_calls == [ call(f"swh:1:dir:{hexhash('dir1')}\n"), call(f"swh:1:dir:{hexhash('dir2')}\n"), ] assert edge_writer.mock_calls == [ call( f"swh:1:dir:{hexhash('dir1')} swh:1:cnt:{hexhash('cnt1')}" f" {b64e('cnt1')} {0o644}\n" ), call( f"swh:1:dir:{hexhash('dir1')} swh:1:dir:{hexhash('dir2')}" f" {b64e('dir2')} {0o755}\n" ), call( f"swh:1:dir:{hexhash('dir1')} swh:1:rev:{hexhash('rev1')}" f" {b64e('rev1')} {0o160000}\n" ), ] def test_export_content(exporter): node_writer, edge_writer = exporter( { "content": [ {**TEST_CONTENT, "sha1_git": binhash("cnt1")}, {**TEST_CONTENT, "sha1_git": binhash("cnt2")}, ] } ) assert node_writer.mock_calls == [ call(f"swh:1:cnt:{hexhash('cnt1')}\n"), call(f"swh:1:cnt:{hexhash('cnt2')}\n"), ] assert edge_writer.mock_calls == [] def zstwrite(fp, lines): with ZSTFile(fp, "w") as writer: for line in lines: writer.write(line + "\n") def zstread(fp): with ZSTFile(fp, "r") as reader: return reader.read() def test_sort_pipeline(tmp_path): short_type_mapping = { "origin_visit_status": "ori", "snapshot": "snp", "release": "rel", "revision": "rev", "directory": "dir", "content": "cnt", } input_nodes = [ f"swh:1:{short}:{hexhash(short + str(x))}" for short in short_type_mapping.values() for x in range(4) ] input_edges = [ f"swh:1:ori:{hexhash('ori1')} swh:1:snp:{hexhash('snp1')}", f"swh:1:ori:{hexhash('ori2')} swh:1:snp:{hexhash('snp2')}", f"swh:1:ori:{hexhash('ori3')} swh:1:snp:{hexhash('snp3')}", f"swh:1:ori:{hexhash('ori4')} swh:1:snp:{hexhash('snpX')}", # missing dest f"swh:1:snp:{hexhash('snp1')} swh:1:rev:{hexhash('rev1')} {b64e('dup1')}", f"swh:1:snp:{hexhash('snp1')} swh:1:rev:{hexhash('rev1')} {b64e('dup2')}", f"swh:1:snp:{hexhash('snp3')} swh:1:cnt:{hexhash('cnt1')} {b64e('c1')}", f"swh:1:snp:{hexhash('snp4')} swh:1:rel:{hexhash('rel1')} {b64e('r1')}", f"swh:1:rel:{hexhash('rel1')} swh:1:rel:{hexhash('rel2')}", f"swh:1:rel:{hexhash('rel2')} swh:1:rev:{hexhash('rev1')}", f"swh:1:rel:{hexhash('rel3')} swh:1:rev:{hexhash('rev2')}", f"swh:1:rel:{hexhash('rel4')} swh:1:dir:{hexhash('dir1')}", f"swh:1:rev:{hexhash('rev1')} swh:1:rev:{hexhash('rev1')}", # dup f"swh:1:rev:{hexhash('rev1')} swh:1:rev:{hexhash('rev1')}", # dup f"swh:1:rev:{hexhash('rev1')} swh:1:rev:{hexhash('rev2')}", f"swh:1:rev:{hexhash('rev2')} swh:1:rev:{hexhash('revX')}", # missing dest f"swh:1:rev:{hexhash('rev3')} swh:1:rev:{hexhash('rev2')}", f"swh:1:rev:{hexhash('rev4')} swh:1:dir:{hexhash('dir1')}", f"swh:1:dir:{hexhash('dir1')} swh:1:cnt:{hexhash('cnt1')} {b64e('c1')} 42", f"swh:1:dir:{hexhash('dir1')} swh:1:dir:{hexhash('dir1')} {b64e('d1')} 1337", f"swh:1:dir:{hexhash('dir1')} swh:1:rev:{hexhash('rev1')} {b64e('r1')} 0", ] for obj_type, short_obj_type in short_type_mapping.items(): p = tmp_path / obj_type p.mkdir() edges = [e for e in input_edges if e.startswith(f"swh:1:{short_obj_type}")] zstwrite(p / "00.edges.csv.zst", edges[0::2]) zstwrite(p / "01.edges.csv.zst", edges[1::2]) nodes = [n for n in input_nodes if n.startswith(f"swh:1:{short_obj_type}")] zstwrite(p / "00.nodes.csv.zst", nodes[0::2]) zstwrite(p / "01.nodes.csv.zst", nodes[1::2]) sort_graph_nodes(tmp_path, config={"sort_buffer_size": "1M"}) output_nodes = zstread(tmp_path / "graph.nodes.csv.zst").split("\n") output_edges = zstread(tmp_path / "graph.edges.csv.zst").split("\n") output_labels = zstread(tmp_path / "graph.labels.csv.zst").split("\n") output_nodes = list(filter(bool, output_nodes)) output_edges = list(filter(bool, output_edges)) output_labels = list(filter(bool, output_labels)) expected_nodes = set(input_nodes) | set(e.split()[1] for e in input_edges) assert output_nodes == sorted(expected_nodes) assert int((tmp_path / "graph.nodes.count.txt").read_text()) == len(expected_nodes) assert sorted(output_edges) == sorted(input_edges) assert int((tmp_path / "graph.edges.count.txt").read_text()) == len(input_edges) expected_labels = set(e[2] for e in [e.split() for e in input_edges] if len(e) > 2) assert output_labels == sorted(expected_labels) actual_node_stats = (tmp_path / "graph.nodes.stats.txt").read_text().strip() expected_node_stats = "\n".join( sorted( "{} {}".format(k, v) for k, v in collections.Counter( node.split(":")[2] for node in expected_nodes ).items() ) ) assert actual_node_stats == expected_node_stats actual_edge_stats = (tmp_path / "graph.edges.stats.txt").read_text().strip() expected_edge_stats = "\n".join( sorted( "{} {}".format(k, v) for k, v in collections.Counter( "{}:{}".format(edge.split(":")[2], edge.split(":")[5]) for edge in input_edges ).items() ) ) assert actual_edge_stats == expected_edge_stats diff --git a/swh/dataset/test/test_orc.py b/swh/dataset/test/test_orc.py index 62ca105..bf62b25 100644 --- a/swh/dataset/test/test_orc.py +++ b/swh/dataset/test/test_orc.py @@ -1,336 +1,353 @@ # Copyright (C) 2020-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import collections from contextlib import contextmanager import math from pathlib import Path import tempfile import pyorc import pytest from swh.dataset.exporters import orc from swh.dataset.relational import MAIN_TABLES, RELATION_TABLES from swh.model.tests.swh_model_data import TEST_OBJECTS from swh.objstorage.factory import get_objstorage @contextmanager def orc_tmpdir(tmpdir): if tmpdir: yield Path(tmpdir) else: with tempfile.TemporaryDirectory() as tmpdir: yield Path(tmpdir) @contextmanager def orc_export(messages, config=None, tmpdir=None): with orc_tmpdir(tmpdir) as tmpdir: if config is None: config = {} with orc.ORCExporter(config, tmpdir) as exporter: for object_type, objects in messages.items(): for obj in objects: exporter.process_object(object_type, obj.to_dict()) yield tmpdir def orc_load(rootdir): res = collections.defaultdict(list) res["rootdir"] = rootdir for obj_type_dir in rootdir.iterdir(): for orc_file in obj_type_dir.iterdir(): with orc_file.open("rb") as orc_obj: reader = pyorc.Reader( orc_obj, converters={pyorc.TypeKind.TIMESTAMP: orc.SWHTimestampConverter}, ) obj_type = reader.user_metadata["swh_object_type"].decode() res[obj_type].extend(reader) return res def exporter(messages, config=None, tmpdir=None): with orc_export(messages, config, tmpdir) as exportdir: return orc_load(exportdir) def test_export_origin(): obj_type = "origin" output = exporter({obj_type: TEST_OBJECTS[obj_type]}) for obj in TEST_OBJECTS[obj_type]: assert (obj.url,) in output[obj_type] def test_export_origin_visit(): obj_type = "origin_visit" output = exporter({obj_type: TEST_OBJECTS[obj_type]}) for obj in TEST_OBJECTS[obj_type]: assert ( obj.origin, obj.visit, orc.datetime_to_tuple(obj.date), obj.type, ) in output[obj_type] def test_export_origin_visit_status(): obj_type = "origin_visit_status" output = exporter({obj_type: TEST_OBJECTS[obj_type]}) for obj in TEST_OBJECTS[obj_type]: assert ( obj.origin, obj.visit, orc.datetime_to_tuple(obj.date), obj.status, orc.hash_to_hex_or_none(obj.snapshot), obj.type, ) in output[obj_type] def test_export_snapshot(): obj_type = "snapshot" output = exporter({obj_type: TEST_OBJECTS[obj_type]}) for obj in TEST_OBJECTS[obj_type]: assert (orc.hash_to_hex_or_none(obj.id),) in output["snapshot"] for branch_name, branch in obj.branches.items(): if branch is None: continue assert ( orc.hash_to_hex_or_none(obj.id), branch_name, orc.hash_to_hex_or_none(branch.target), str(branch.target_type.value), ) in output["snapshot_branch"] def test_export_release(): obj_type = "release" output = exporter({obj_type: TEST_OBJECTS[obj_type]}) for obj in TEST_OBJECTS[obj_type]: assert ( orc.hash_to_hex_or_none(obj.id), obj.name, obj.message, orc.hash_to_hex_or_none(obj.target), obj.target_type.value, obj.author.fullname if obj.author else None, *orc.swh_date_to_tuple( obj.date.to_dict() if obj.date is not None else None ), obj.raw_manifest, ) in output[obj_type] def test_export_revision(): obj_type = "revision" output = exporter({obj_type: TEST_OBJECTS[obj_type]}) for obj in TEST_OBJECTS[obj_type]: assert ( orc.hash_to_hex_or_none(obj.id), obj.message, obj.author.fullname, *orc.swh_date_to_tuple( obj.date.to_dict() if obj.date is not None else None ), obj.committer.fullname, *orc.swh_date_to_tuple( obj.committer_date.to_dict() if obj.committer_date is not None else None ), orc.hash_to_hex_or_none(obj.directory), obj.type.value, obj.raw_manifest, ) in output["revision"] for i, parent in enumerate(obj.parents): assert ( orc.hash_to_hex_or_none(obj.id), orc.hash_to_hex_or_none(parent), i, ) in output["revision_history"] def test_export_directory(): obj_type = "directory" output = exporter({obj_type: TEST_OBJECTS[obj_type]}) for obj in TEST_OBJECTS[obj_type]: assert (orc.hash_to_hex_or_none(obj.id), obj.raw_manifest) in output[ "directory" ] for entry in obj.entries: assert ( orc.hash_to_hex_or_none(obj.id), entry.name, entry.type, orc.hash_to_hex_or_none(entry.target), entry.perms, ) in output["directory_entry"] def test_export_content(): obj_type = "content" output = exporter({obj_type: TEST_OBJECTS[obj_type]}) for obj in TEST_OBJECTS[obj_type]: assert ( orc.hash_to_hex_or_none(obj.sha1), orc.hash_to_hex_or_none(obj.sha1_git), orc.hash_to_hex_or_none(obj.sha256), orc.hash_to_hex_or_none(obj.blake2s256), obj.length, obj.status, None, ) in output[obj_type] def test_export_skipped_content(): obj_type = "skipped_content" output = exporter({obj_type: TEST_OBJECTS[obj_type]}) for obj in TEST_OBJECTS[obj_type]: assert ( orc.hash_to_hex_or_none(obj.sha1), orc.hash_to_hex_or_none(obj.sha1_git), orc.hash_to_hex_or_none(obj.sha256), orc.hash_to_hex_or_none(obj.blake2s256), obj.length, obj.status, obj.reason, ) in output[obj_type] def test_date_to_tuple(): ts = {"seconds": 123456, "microseconds": 1515} assert orc.swh_date_to_tuple({"timestamp": ts, "offset_bytes": b"+0100"}) == ( (123456, 1515), 60, b"+0100", ) assert orc.swh_date_to_tuple( { "timestamp": ts, "offset": 120, "negative_utc": False, "offset_bytes": b"+0100", } ) == ((123456, 1515), 60, b"+0100") assert orc.swh_date_to_tuple( - {"timestamp": ts, "offset": 120, "negative_utc": False,} + { + "timestamp": ts, + "offset": 120, + "negative_utc": False, + } ) == ((123456, 1515), 120, b"+0200") assert orc.swh_date_to_tuple( - {"timestamp": ts, "offset": 0, "negative_utc": True,} - ) == ((123456, 1515), 0, b"-0000",) + { + "timestamp": ts, + "offset": 0, + "negative_utc": True, + } + ) == ( + (123456, 1515), + 0, + b"-0000", + ) # mapping of related tables for each main table (if any) RELATED = { "snapshot": ["snapshot_branch"], "revision": ["revision_history", "revision_extra_headers"], "directory": ["directory_entry"], } @pytest.mark.parametrize( - "obj_type", MAIN_TABLES.keys(), + "obj_type", + MAIN_TABLES.keys(), ) @pytest.mark.parametrize("max_rows", (None, 1, 2, 10000)) def test_export_related_files(max_rows, obj_type, tmpdir): config = {"orc": {}} if max_rows is not None: config["orc"]["max_rows"] = {obj_type: max_rows} exporter({obj_type: TEST_OBJECTS[obj_type]}, config=config, tmpdir=tmpdir) # check there are as many ORC files as objects orcfiles = [fname for fname in (tmpdir / obj_type).listdir(f"{obj_type}-*.orc")] if max_rows is None: assert len(orcfiles) == 1 else: assert len(orcfiles) == math.ceil(len(TEST_OBJECTS[obj_type]) / max_rows) # check the number of related ORC files for related in RELATED.get(obj_type, ()): related_orcfiles = [ fname for fname in (tmpdir / related).listdir(f"{related}-*.orc") ] assert len(related_orcfiles) == len(orcfiles) # for each ORC file, check related files only reference objects in the # corresponding main table for orc_file in orcfiles: with orc_file.open("rb") as orc_obj: reader = pyorc.Reader( orc_obj, converters={pyorc.TypeKind.TIMESTAMP: orc.SWHTimestampConverter}, ) uuid = reader.user_metadata["swh_uuid"].decode() assert orc_file.basename == f"{obj_type}-{uuid}.orc" rows = list(reader) obj_ids = [row[0] for row in rows] # check the related tables for related in RELATED.get(obj_type, ()): orc_file = tmpdir / related / f"{related}-{uuid}.orc" with orc_file.open("rb") as orc_obj: reader = pyorc.Reader( orc_obj, converters={pyorc.TypeKind.TIMESTAMP: orc.SWHTimestampConverter}, ) assert reader.user_metadata["swh_uuid"].decode() == uuid rows = list(reader) # check branches in this file only concern current snapshot (obj_id) for row in rows: assert row[0] in obj_ids @pytest.mark.parametrize( - "obj_type", MAIN_TABLES.keys(), + "obj_type", + MAIN_TABLES.keys(), ) def test_export_related_files_separated(obj_type, tmpdir): exporter({obj_type: TEST_OBJECTS[obj_type]}, tmpdir=tmpdir) # check there are as many ORC files as objects orcfiles = [fname for fname in (tmpdir / obj_type).listdir(f"{obj_type}-*.orc")] assert len(orcfiles) == 1 # check related ORC files are in their own directory for related in RELATED.get(obj_type, ()): related_orcfiles = [ fname for fname in (tmpdir / related).listdir(f"{related}-*.orc") ] assert len(related_orcfiles) == len(orcfiles) @pytest.mark.parametrize("table_name", RELATION_TABLES.keys()) def test_export_invalid_max_rows(table_name): config = {"orc": {"max_rows": {table_name: 10}}} with pytest.raises(ValueError): exporter({}, config=config) def test_export_content_with_data(monkeypatch, tmpdir): obj_type = "content" objstorage = get_objstorage("memory") for content in TEST_OBJECTS[obj_type]: objstorage.add(content.data) def get_objstorage_mock(**kw): if kw.get("cls") == "mock": return objstorage monkeypatch.setattr(orc, "get_objstorage", get_objstorage_mock) config = { - "orc": {"with_data": True, "objstorage": {"cls": "mock"},}, + "orc": { + "with_data": True, + "objstorage": {"cls": "mock"}, + }, } output = exporter({obj_type: TEST_OBJECTS[obj_type]}, config=config, tmpdir=tmpdir) for obj in TEST_OBJECTS[obj_type]: assert ( orc.hash_to_hex_or_none(obj.sha1), orc.hash_to_hex_or_none(obj.sha1_git), orc.hash_to_hex_or_none(obj.sha256), orc.hash_to_hex_or_none(obj.blake2s256), obj.length, obj.status, obj.data, ) in output[obj_type] diff --git a/swh/dataset/utils.py b/swh/dataset/utils.py index d1be461..bd84335 100644 --- a/swh/dataset/utils.py +++ b/swh/dataset/utils.py @@ -1,146 +1,149 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import sqlite3 import subprocess try: # Plyvel shouldn't be a hard dependency if we want to use sqlite instead import plyvel except ImportError: plyvel = None class ZSTFile: """ Object-like wrapper around a ZST file. Uses a subprocess of the "zstd" command to compress and deflate the objects. """ def __init__(self, path: str, mode: str = "r"): if mode not in ("r", "rb", "w", "wb"): raise ValueError(f"ZSTFile mode {mode} is invalid.") self.path = path self.mode = mode def __enter__(self) -> "ZSTFile": is_text = not (self.mode in ("rb", "wb")) writing = self.mode in ("w", "wb") if writing: cmd = ["zstd", "-f", "-q", "-o", self.path] else: cmd = ["zstdcat", self.path] self.process = subprocess.Popen( - cmd, text=is_text, stdin=subprocess.PIPE, stdout=subprocess.PIPE, + cmd, + text=is_text, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, ) return self def __exit__(self, exc_type, exc_value, tb): self.process.stdin.close() self.process.stdout.close() self.process.wait() def read(self, *args): return self.process.stdout.read(*args) def write(self, buf): self.process.stdin.write(buf) class SQLiteSet: """ On-disk Set object for hashes using SQLite as an indexer backend. Used to deduplicate objects when processing large queues with duplicates. """ def __init__(self, db_path): self.db_path = db_path def __enter__(self): self.db = sqlite3.connect(str(self.db_path)) self.db.execute( "CREATE TABLE IF NOT EXISTS" " tmpset (val TEXT NOT NULL PRIMARY KEY)" " WITHOUT ROWID" ) self.db.execute("PRAGMA synchronous = OFF") self.db.execute("PRAGMA journal_mode = OFF") return self def __exit__(self, exc_type, exc_val, exc_tb): self.db.commit() self.db.close() def add(self, v: bytes) -> bool: """ Add an item to the set. Args: v: The value to add to the set. Returns: True if the value was added to the set, False if it was already present. """ try: self.db.execute("INSERT INTO tmpset(val) VALUES (?)", (v.hex(),)) except sqlite3.IntegrityError: return False else: return True class LevelDBSet: """ On-disk Set object for hashes using LevelDB as an indexer backend. Used to deduplicate objects when processing large queues with duplicates. """ def __init__(self, db_path): self.db_path = db_path if plyvel is None: raise ImportError("Plyvel library not found, required for LevelDBSet") def __enter__(self): self.db = plyvel.DB(str(self.db_path), create_if_missing=True) return self def __exit__(self, exc_type, exc_val, exc_tb): self.db.close() def add(self, v: bytes) -> bool: """ Add an item to the set. Args: v: The value to add to the set. Returns: True if the value was added to the set, False if it was already present. """ if self.db.get(v): return False else: self.db.put(v, b"T") return True def remove_pull_requests(snapshot): """ Heuristic to filter out pull requests in snapshots: remove all branches that start with refs/ but do not start with refs/heads or refs/tags. """ # Copy the items with list() to remove items during iteration for branch_name, branch in list(snapshot["branches"].items()): original_branch_name = branch_name while branch and branch.get("target_type") == "alias": branch_name = branch["target"] branch = snapshot["branches"].get(branch_name) if branch is None or not branch_name: continue if branch_name.startswith(b"refs/") and not ( branch_name.startswith(b"refs/heads") or branch_name.startswith(b"refs/tags") ): snapshot["branches"].pop(original_branch_name)