diff --git a/swh/scanner/cli.py b/swh/scanner/cli.py
--- a/swh/scanner/cli.py
+++ b/swh/scanner/cli.py
@@ -6,6 +6,7 @@
 # WARNING: do not import unnecessary things here to keep cli startup time under
 # control
 import os
+import sys
 from typing import Any, Dict, Optional
 
 import click
@@ -15,6 +16,8 @@
 from swh.core.cli import CONTEXT_SETTINGS
 from swh.core.cli import swh as swh_cli_group
 
+from .exceptions import DBError
+
 # All generic config code should reside in swh.core.config
 CONFIG_ENVVAR = "SWH_CONFIG_FILE"
 DEFAULT_CONFIG_PATH = os.path.join(click.get_app_dir("swh"), "global.yml")
@@ -44,6 +47,17 @@
 {CONFIG_FILE_HELP}"""
 
 
+def setup_config(ctx, api_url):
+    """Return the scanner configuration, overriding the web API URL if given"""
+    config = ctx.obj["config"]
+    if api_url:
+        if not api_url.endswith("/"):
+            api_url += "/"
+        config["web-api"]["url"] = api_url
+
+    return config
+
+
 @swh_cli_group.group(
     name="scanner", context_settings=CONTEXT_SETTINGS, help=SCANNER_HELP,
 )
@@ -121,15 +135,70 @@
     present in the archive"""
     import swh.scanner.scanner as scanner
 
-    config = ctx.obj["config"]
-    if api_url:
-        if not api_url.endswith("/"):
-            api_url += "/"
-        config["web-api"]["url"] = api_url
-
+    config = setup_config(ctx, api_url)
     scanner.scan(config, root_path, patterns, out_fmt, interactive)
 
 
+@scanner.group("db")
+@click.pass_context
+def db(ctx):
+    """Manage the local database of known SWHIDs"""
+
+
+@db.command("import")
+@click.option(
+    "-s",
+    "--chunk-size",
+    "chunk_size",
+    default=10000,
+    metavar="SIZE",
+    show_default=True,
+    type=int,
+    help="The number of SWHIDs inserted per chunk",
+)
+@click.option(
+    "-i",
+    "--input",
+    "input_file",
+    metavar="INPUT_FILE",
+    required=True,
+    type=click.File("r"),
+    help="A file containing SWHIDs",
+)
+@click.option(
+    "-o",
+    "--output",
+    "output_file_db",
+    metavar="OUTPUT_DB_FILE",
+    default="SWHID_DB.sqlite",
+    show_default=True,
+    help="The name of the generated sqlite database",
+)
+@click.pass_context
+def import_(ctx, chunk_size, input_file, output_file_db):
+    """Parse an input list of SWHID to generate a local sqlite database
+    """
+    from .db import Db
+
+    # a database that already existed must survive a failed import
+    db_existed = os.path.exists(output_file_db)
+    swhid_db = Db(output_file_db)
+    try:
+        swhid_db.create_from(input_file, chunk_size, swhid_db.conn.cursor())
+        imported = True
+    except DBError:
+        imported = False
+    finally:
+        # always release the file, in particular before removing it
+        swhid_db.close()
+
+    if not imported:
+        click.echo("Failed to create database", err=True)
+        if not db_existed:
+            os.remove(output_file_db)
+        sys.exit(1)
+
+
 def main():
     return scanner(auto_envvar_prefix="SWH_SCANNER")
 
diff --git a/swh/scanner/db.py b/swh/scanner/db.py
new file mode 100644
--- /dev/null
+++ b/swh/scanner/db.py
@@ -0,0 +1,83 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+"""
+This module is an interface to interact with the local database
+where the SWHIDs will be saved for the local API.
+
+SWHIDs can be added directly from an input file.
+"""
+
+from io import TextIOWrapper
+from pathlib import Path
+import sqlite3
+from typing import Iterable, Union
+
+from swh.core.utils import grouper
+
+from .exceptions import DBError
+
+
+class Db:
+    """Local database interface"""
+
+    def __init__(self, db_file: Union[str, Path]):
+        self.db_file: Union[str, Path] = db_file
+        self.conn: sqlite3.Connection = sqlite3.connect(
+            db_file, check_same_thread=False
+        )
+
+    def close(self):
+        """Close the connection to the database."""
+        self.conn.close()
+
+    def create_table(self, cur: sqlite3.Cursor):
+        """Create the table where the SWHIDs will be stored."""
+        cur.execute("""CREATE TABLE IF NOT EXISTS swhids (swhid text PRIMARY KEY)""")
+
+    def add(self, swhids: Iterable[str], chunk_size: int, cur: sqlite3.Cursor):
+        """Insert the SWHIDs inside the database, in chunks of ``chunk_size``."""
+        for swhids_chunk in grouper(swhids, chunk_size):
+            cur.executemany(
+                """INSERT INTO swhids VALUES (?)""",
+                [(swhid,) for swhid in swhids_chunk],
+            )
+
+    def create_from(
+        self, input_file: TextIOWrapper, chunk_size: int, cur: sqlite3.Cursor
+    ):
+        """Create a new database with the SWHIDs present inside the input file.
+
+        The file holds one SWHID per line; blank lines are skipped and
+        duplicated SWHIDs are inserted only once. The cursor is closed once
+        the SWHIDs have been inserted.
+
+        Raises:
+            DBError: if the table cannot be created or the SWHIDs cannot be
+                inserted and committed
+        """
+        # use a set: a duplicated SWHID would violate the primary key
+        swhids = {line.strip() for line in input_file}
+        swhids.discard("")
+
+        try:
+            self.create_table(cur)
+            self.add(swhids, chunk_size, cur)
+            cur.close()
+            self.conn.commit()
+        except Exception as err:
+            raise DBError(f"cannot import SWHIDs into {self.db_file}") from err
+
+    def known(self, swhid: str, cur: sqlite3.Cursor) -> bool:
+        """Check if a given SWHID is present or not inside the local database.
+
+        The given cursor is closed before returning, a new one is needed for
+        each lookup.
+        """
+        cur.execute("""SELECT 1 FROM swhids WHERE swhid=?""", (swhid,))
+        res = cur.fetchone()
+        cur.close()
+
+        return res is not None
diff --git a/swh/scanner/exceptions.py b/swh/scanner/exceptions.py
--- a/swh/scanner/exceptions.py
+++ b/swh/scanner/exceptions.py
@@ -12,6 +12,10 @@
     pass
 
 
+class DBError(Exception):
+    pass
+
+
 class APIError(Exception):
     def __str__(self):
         return '"%s"' % self.args
diff --git a/swh/scanner/tests/conftest.py b/swh/scanner/tests/conftest.py
--- a/swh/scanner/tests/conftest.py
+++ b/swh/scanner/tests/conftest.py
@@ -15,6 +15,7 @@
 from swh.model.cli import swhid_of_dir, swhid_of_file
 from swh.scanner.model import Tree
 
+from .data import present_swhids
 from .flask_api import create_app
 
 
@@ -136,6 +137,18 @@
     return test_sample_folder
 
 
+@pytest.fixture
+def test_swhids_sample(tmp_path):
+    """Yield the "swhids_sample" file opened for reading, filled with the
+    SWHIDs listed in data.py (one per line); the file is closed on teardown.
+    """
+    test_swhids_sample = tmp_path / "swhids_sample.txt"
+    test_swhids_sample.write_text("\n".join(present_swhids))
+
+    with open(test_swhids_sample, "r") as f:
+        yield f
+
+
 @pytest.fixture(scope="session")
 def app():
     """Flask backend API (used by live_server)."""
diff --git a/swh/scanner/tests/test_cli.py b/swh/scanner/tests/test_cli.py
--- a/swh/scanner/tests/test_cli.py
+++ b/swh/scanner/tests/test_cli.py
@@ -12,6 +12,8 @@
 import swh.scanner.cli as cli
 import swh.scanner.scanner as scanner
 
+from .data import present_swhids
+
 DATADIR = Path(__file__).absolute().parent / "data"
 CONFIG_PATH_GOOD = str(DATADIR / "global.yml")
 CONFIG_PATH_GOOD2 = str(DATADIR / "global2.yml")  # alternative to global.yml
@@ -41,6 +43,14 @@
     return CliRunner(env={"SWH_CONFIG_FILE": None})
 
 
+@pytest.fixture(scope="function")
+def swhids_input_file(tmp_path):
+    """Write the known SWHIDs, one per line, to a file and return its path."""
+    swhids_input_file = tmp_path / "input_file.txt"
+    swhids_input_file.write_text("\n".join(present_swhids))
+    return swhids_input_file
+
+
 # TEST BEGIN
 
 # For nominal code paths, check that the right config file is loaded
@@ -119,3 +129,20 @@
     res = cli_runner.invoke(cli.scanner, ["scan", ROOTPATH_GOOD, "-u", API_URL])
     assert res.exit_code == 0
     assert m_scanner.scan.call_count == 1
+
+
+def test_db_option(cli_runner, swhids_input_file, tmp_path):
+    output_db = tmp_path / "test_db.sqlite"
+    res = cli_runner.invoke(
+        cli.scanner,
+        [
+            "db",
+            "import",
+            "--input",
+            str(swhids_input_file),
+            "--output",
+            str(output_db),
+        ],
+    )
+    assert res.exit_code == 0
+    assert output_db.exists()
diff --git a/swh/scanner/tests/test_db.py b/swh/scanner/tests/test_db.py
new file mode 100644
--- /dev/null
+++ b/swh/scanner/tests/test_db.py
@@ -0,0 +1,53 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from swh.scanner.db import Db
+
+from .data import present_swhids
+
+CHUNK_SIZE = 1000
+
+
+def test_db_create_from(tmp_path, test_swhids_sample):
+    tmp_dbfile = tmp_path / "tmp_db.sqlite"
+
+    db = Db(tmp_dbfile)
+    cur = db.conn.cursor()
+    db.create_from(test_swhids_sample, CHUNK_SIZE, cur)
+
+    for swhid in present_swhids:
+        cur = db.conn.cursor()
+        assert db.known(swhid, cur)
+
+
+def test_db_create_from_one_not_present(tmp_path, test_swhids_sample):
+    not_present_swhid = "swh:1:cnt:fa8eacf43d8646129ae8adfa1648f9307d999999"
+    swhids = present_swhids + [not_present_swhid]
+
+    tmp_dbfile = tmp_path / "tmp_db.sqlite"
+
+    db = Db(tmp_dbfile)
+    cur = db.conn.cursor()
+    db.create_from(test_swhids_sample, CHUNK_SIZE, cur)
+
+    for swhid in swhids:
+        cur = db.conn.cursor()
+        if swhid != not_present_swhid:
+            assert db.known(swhid, cur)
+        else:
+            assert not db.known(swhid, cur)
+
+
+def test_db_create_from_skips_blank_lines(tmp_path):
+    input_file = tmp_path / "swhids.txt"
+    input_file.write_text("\n\n".join(present_swhids) + "\n\n")
+
+    db = Db(tmp_path / "tmp_db.sqlite")
+    with open(input_file) as f:
+        db.create_from(f, CHUNK_SIZE, db.conn.cursor())
+
+    assert not db.known("", db.conn.cursor())
+    for swhid in present_swhids:
+        assert db.known(swhid, db.conn.cursor())