diff --git a/swh/scanner/cli.py b/swh/scanner/cli.py
--- a/swh/scanner/cli.py
+++ b/swh/scanner/cli.py
@@ -44,6 +44,21 @@
 {CONFIG_FILE_HELP}"""
 
 
+def setup_config(ctx, api_url):
+    """Return the scanner configuration stored in ``ctx.obj["config"]``.
+
+    When ``api_url`` is given, it overrides the Web API URL (updated in place),
+    with a trailing slash appended if missing.
+    """
+    config = ctx.obj["config"]
+    if api_url:
+        if not api_url.endswith("/"):
+            api_url += "/"
+        config["web-api"]["url"] = api_url
+
+    return config
+
+
 @swh_cli_group.group(
     name="scanner", context_settings=CONTEXT_SETTINGS, help=SCANNER_HELP,
 )
@@ -121,15 +136,62 @@
 present in the archive"""
     import swh.scanner.scanner as scanner
 
-    config = ctx.obj["config"]
-    if api_url:
-        if not api_url.endswith("/"):
-            api_url += "/"
-        config["web-api"]["url"] = api_url
-
+    config = setup_config(ctx, api_url)
     scanner.scan(config, root_path, patterns, out_fmt, interactive)
 
 
+@scanner.group("db")
+@click.pass_context
+def db(ctx):
+    """Manage a local database of SWHIDs known to the archive."""
+
+
+@db.command("import")
+@click.option(
+    "-u",
+    "--api-url",
+    default=None,
+    metavar="API_URL",
+    show_default=True,
+    help="URL for the api request",
+)
+@click.option(
+    "-i",
+    "--input",
+    "input_file",
+    metavar="INPUT_FILE",
+    required=True,
+    type=click.File("r"),
+    help="A file containing SWHIDs",
+)
+@click.option(
+    "-o",
+    "--output",
+    "output_file_db",
+    metavar="OUTPUT_DB_FILE",
+    default="SWHID_DB.sqlite",
+    show_default=True,
+    help="The name of the generated sqlite database",
+)
+@click.pass_context
+def import_(ctx, api_url, input_file, output_file_db):
+    """Parse an input list of SWHIDs to generate a local sqlite database.
+
+    Only the SWHIDs reported as known by the Web API are stored.
+    """
+    from .db import Db
+
+    config = setup_config(ctx, api_url)
+    # named ``local_db`` so as not to shadow the ``db`` command group
+    local_db = Db(output_file_db)
+    try:
+        cur = local_db.conn.cursor()
+        local_db.create_from(config, input_file, cur)
+    finally:
+        # also runs when create_from() aborts through sys.exit()
+        local_db.close()
+
+
 def main():
     return scanner(auto_envvar_prefix="SWH_SCANNER")
 
diff --git a/swh/scanner/db.py b/swh/scanner/db.py
new file mode 100644
--- /dev/null
+++ b/swh/scanner/db.py
@@ -0,0 +1,122 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+"""
+This module is an interface to interact with the local database
+where the SWHIDs will be saved for the local API.
+
+SWHIDs can be added directly from an input file.
+"""
+
+import asyncio
+from io import TextIOWrapper
+import os
+from pathlib import Path
+import sqlite3
+import sys
+from typing import Any, Dict, Union
+
+import aiohttp
+
+from .exceptions import APIError
+from .scanner import swhids_discovery
+
+
+class Db:
+    """Local database interface"""
+
+    def __init__(self, db_file: Union[str, Path]):
+        self.db_file: Union[str, Path] = db_file
+        self.conn: sqlite3.Connection = sqlite3.connect(
+            db_file, check_same_thread=False
+        )
+
+    def close(self):
+        """Close the connection to the database."""
+        self.conn.close()
+
+    def create_table(self, cur: sqlite3.Cursor):
+        """Create the table where the SWHIDs will be stored, and index it.
+
+        The index keeps the duplicate check done by :meth:`add` and the lookup
+        done by :meth:`check` from scanning the whole table for each SWHID.
+        """
+        cur.execute("""CREATE TABLE IF NOT EXISTS swhid_db (swhid text)""")
+        cur.execute(
+            """CREATE INDEX IF NOT EXISTS swhid_db_swhid_idx ON swhid_db (swhid)"""
+        )
+
+    def add(self, swhid: str, cur: sqlite3.Cursor):
+        """Insert the SWHID inside the database, unless it is already there."""
+        cur.execute(
+            """INSERT INTO swhid_db SELECT (?)
+            WHERE NOT EXISTS (SELECT 1 FROM swhid_db WHERE swhid=?)""",
+            (swhid, swhid),
+        )
+
+    async def query_swhids_from(
+        self, input_file: TextIOWrapper, config: Dict[str, Any], cur: sqlite3.Cursor
+    ):
+        """Query all the SWHIDs present inside the input file to the Web API and fill
+        the local database only with known SWHIDs.
+
+        The input file holds one SWHID per line. Surrounding whitespace is stripped
+        and blank lines are skipped, so that no empty SWHID is sent to the API.
+        """
+        stripped_lines = (line.strip() for line in input_file)
+        swhids = [swhid for swhid in stripped_lines if swhid]
+
+        api_url = config["web-api"]["url"]
+        auth_token = config["web-api"].get("auth-token")
+        if auth_token:
+            headers = {"Authorization": f"Bearer {auth_token}"}
+        else:
+            headers = {}
+
+        async with aiohttp.ClientSession(headers=headers) as session:
+            parsed_swhids = await swhids_discovery(swhids, session, api_url)
+            for swhid, attr in parsed_swhids.items():
+                if attr["known"]:
+                    self.add(swhid, cur)
+
+    def create_from(
+        self, config: Dict[str, Any], input_file: TextIOWrapper, cur: sqlite3.Cursor
+    ):
+        """Create a new database with the SWHIDs present inside the input file.
+
+        On failure, the error is reported on stderr, the database file is
+        removed and the process exits with status 1.
+        """
+        self.create_table(cur)
+
+        try:
+            loop = asyncio.get_event_loop()
+            loop.run_until_complete(self.query_swhids_from(input_file, config, cur))
+            cur.close()
+            self.conn.commit()
+        except APIError as exc:
+            self._abort(f"Error during the api call: {exc}")
+        except Exception as exc:
+            self._abort(f"Failed to create database: {exc}")
+
+    def _abort(self, message: str):
+        """Print ``message`` on stderr, remove the database file and exit(1)."""
+        print(message, file=sys.stderr)
+        # Close the connection before removing the file it is attached to.
+        self.close()
+        os.remove(self.db_file)
+        sys.exit(1)
+
+    def check(self, swhid: str, cur: sqlite3.Cursor) -> bool:
+        """Check if a given SWHID is present or not inside the local database.
+
+        The cursor ``cur`` is closed once the lookup is done.
+        """
+        cur.execute("""SELECT 1 FROM swhid_db WHERE swhid=?""", (swhid,))
+        res = cur.fetchone()
+        cur.close()
+
+        return res is not None
diff --git a/swh/scanner/exceptions.py b/swh/scanner/exceptions.py
--- a/swh/scanner/exceptions.py
+++ b/swh/scanner/exceptions.py
@@ -12,6 +12,17 @@
     pass
 
 
+class BadInputExc(TypeError):
+    """Raised when an input (e.g. a SWHID sent to the API) is malformed.
+
+    ``status_code`` is the HTTP status code associated with the error.
+    """
+
+    def __init__(self, *args, status_code: int = 400):
+        super().__init__(*args)
+        self.status_code = status_code
+
+
 class APIError(Exception):
     def __str__(self):
         return '"%s"' % self.args
diff --git a/swh/scanner/tests/conftest.py b/swh/scanner/tests/conftest.py
--- a/swh/scanner/tests/conftest.py
+++ b/swh/scanner/tests/conftest.py
@@ -15,6 +15,7 @@
 from swh.model.cli import swhid_of_dir, swhid_of_file
 from swh.scanner.model import Tree
 
+from .data import present_swhids
 from .flask_api import create_app
 
 
@@ -136,6 +137,22 @@
     return test_sample_folder
 
 
+@pytest.fixture
+def test_swhids_sample(tmp_path):
+    """Yield the opened "swhids_sample.txt" file, filled with the SWHIDs of
+    data.py plus one SWHID unknown to the test API; it is closed on teardown.
+    """
+    test_swhids_sample = tmp_path / "swhids_sample.txt"
+    not_present_swhid = "swh:1:cnt:fa8eacf43d8646129ae8adfa1648f9307d999999"
+    swhids = present_swhids + [not_present_swhid]
+
+    test_swhids_sample.write_text("\n".join(swhids))
+    assert test_swhids_sample.exists()
+
+    with open(test_swhids_sample, "r") as f:
+        yield f
+
+
 @pytest.fixture(scope="session")
 def app():
     """Flask backend API (used by live_server)."""
diff --git a/swh/scanner/tests/flask_api.py b/swh/scanner/tests/flask_api.py
--- a/swh/scanner/tests/flask_api.py
+++ b/swh/scanner/tests/flask_api.py
@@ -5,6 +5,9 @@
 
 from flask import Flask, request
 
+from swh.model.exceptions import ValidationError
+from swh.model.identifiers import parse_swhid
+from swh.scanner.exceptions import BadInputExc
 from swh.web.common.exc import LargePayloadExc
 
 from .data import present_swhids
@@ -24,8 +27,14 @@
 
         res = {swhid: {"known": False} for swhid in swhids}
         for swhid in swhids:
+            try:
+                parse_swhid(swhid)
+            except ValidationError as err:
+                raise BadInputExc(
+                    "An invalid SWHID was provided", status_code=400
+                ) from err
             if swhid in present_swhids:
                 res[swhid]["known"] = True
 
         return res
 
diff --git a/swh/scanner/tests/test_cli.py b/swh/scanner/tests/test_cli.py
--- a/swh/scanner/tests/test_cli.py
+++ b/swh/scanner/tests/test_cli.py
@@ -3,6 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+import os
 from pathlib import Path
 from unittest.mock import Mock, call
 
@@ -12,6 +13,8 @@
 import swh.scanner.cli as cli
 import swh.scanner.scanner as scanner
 
+from .data import present_swhids
+
 DATADIR = Path(__file__).absolute().parent / "data"
 CONFIG_PATH_GOOD = str(DATADIR / "global.yml")
 CONFIG_PATH_GOOD2 = str(DATADIR / "global2.yml")  # alternative to global.yml
@@ -41,6 +44,31 @@
     return
CliRunner(env={"SWH_CONFIG_FILE": None})
 
 
+@pytest.fixture(scope="function")
+def bad_input_file(tmp_path):
+    bad_input_file = Path(os.path.join(tmp_path, "input_file.txt"))
+    # SWHID whose hash is malformed (too long)
+    bad_swhid = "swh:1:cnt:fa8eacf43d8646129ae8adfa1648f9307d999999bbbbb"
+    swhids = present_swhids + [bad_swhid]
+
+    with open(bad_input_file, "w") as f:
+        f.write("\n".join(swhids))
+
+    assert bad_input_file.exists()
+    return bad_input_file
+
+
+@pytest.fixture(scope="function")
+def good_input_file(tmp_path):
+    good_input_file = Path(os.path.join(tmp_path, "input_file.txt"))
+
+    with open(good_input_file, "w") as f:
+        f.write("\n".join(present_swhids))
+
+    assert good_input_file.exists()
+    return good_input_file
+
+
 # TEST BEGIN
 
 # For nominal code paths, check that the right config file is loaded
@@ -119,3 +147,22 @@
     res = cli_runner.invoke(cli.scanner, ["scan", ROOTPATH_GOOD, "-u", API_URL])
     assert res.exit_code == 0
     assert m_scanner.scan.call_count == 1
+
+
+def test_db_option_bad_input_file(cli_runner, bad_input_file, live_server, tmp_path):
+    api_url = live_server.url() + "/"
+    output_db = tmp_path / "swhids_db.sqlite"
+    args = ["db", "import", "--input", str(bad_input_file), "-u", api_url]
+    res = cli_runner.invoke(cli.scanner, args + ["-o", str(output_db)])
+    assert res.exit_code != 0
+    # the partially created database must not be left behind
+    assert not output_db.exists()
+
+
+def test_db_option_good_input_file(cli_runner, good_input_file, live_server, tmp_path):
+    api_url = live_server.url() + "/"
+    output_db = tmp_path / "swhids_db.sqlite"
+    args = ["db", "import", "--input", str(good_input_file), "-u", api_url]
+    res = cli_runner.invoke(cli.scanner, args + ["-o", str(output_db)])
+    assert res.exit_code == 0
+    assert output_db.exists()
diff --git a/swh/scanner/tests/test_db.py b/swh/scanner/tests/test_db.py
new file mode 100644
--- /dev/null
+++ b/swh/scanner/tests/test_db.py
@@ -0,0 +1,84 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from swh.scanner.db import Db
+
+from .data import present_swhids
+
+
+def test_db_create_from(tmp_path, live_server, test_swhids_sample):
+    tmp_dbfile = tmp_path / "tmp_db.sqlite"
+    api_url = live_server.url() + "/"
+    config = {"web-api": {"url": api_url, "auth-token": None}}
+
+    db = Db(tmp_dbfile)
+    cur = db.conn.cursor()
+    db.create_from(config, test_swhids_sample, cur)
+
+    for swhid in present_swhids:
+        cur = db.conn.cursor()
+        assert db.check(swhid, cur)
+
+    db.close()
+
+
+def test_db_create_from_one_not_present(tmp_path, live_server, test_swhids_sample):
+    not_present_swhid = "swh:1:cnt:fa8eacf43d8646129ae8adfa1648f9307d999999"
+    # Work on a copy: appending to the shared ``present_swhids`` list would leak
+    # the unknown SWHID into every test running afterwards.
+    swhids = present_swhids + [not_present_swhid]
+
+    tmp_dbfile = tmp_path / "tmp_db.sqlite"
+    api_url = live_server.url() + "/"
+    config = {"web-api": {"url": api_url, "auth-token": None}}
+
+    db = Db(tmp_dbfile)
+    cur = db.conn.cursor()
+    db.create_from(config, test_swhids_sample, cur)
+
+    for swhid in swhids:
+        cur = db.conn.cursor()
+        if swhid != not_present_swhid:
+            assert db.check(swhid, cur)
+        else:
+            assert not db.check(swhid, cur)
+
+    db.close()
+
+
+def test_db_add(tmp_path):
+    swhid = "swh:1:cnt:fa8eacf43d8646129ae8adfa1648f9307d999999"
+    tmp_dbfile = tmp_path / "tmp_db.sqlite"
+    db = Db(tmp_dbfile)
+    cur = db.conn.cursor()
+
+    db.create_table(cur)
+    db.add(swhid, cur)
+    assert db.check(swhid, cur)
+    db.close()
+
+
+def test_db_add_non_present_swhid(tmp_path):
+    swhid = "swh:1:cnt:fa8eacf43d8646129ae8adfa1648f9307d99999"
+    tmp_dbfile = tmp_path / "tmp_db.sqlite"
+    db = Db(tmp_dbfile)
+
+    cur = db.conn.cursor()
+    db.create_table(cur)
+    assert not db.check(swhid, cur)
+    db.close()
+
+
+def test_db_add_twice(tmp_path):
+    swhid = "swh:1:cnt:fa8eacf43d8646129ae8adfa1648f9307d999999"
+    db = Db(tmp_path / "tmp_db.sqlite")
+    cur = db.conn.cursor()
+
+    db.create_table(cur)
+    db.add(swhid, cur)
+    db.add(swhid, cur)
+    cur.execute("SELECT COUNT(*) FROM swhid_db WHERE swhid=?", (swhid,))
+    assert cur.fetchone() == (1,)
+    db.close()