diff --git a/requirements.txt b/requirements.txt
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,4 +9,5 @@
 numpy
 dash
 dash_bootstrap_components
+flask
 dulwich
diff --git a/swh/scanner/api.py b/swh/scanner/api.py
new file mode 100644
--- /dev/null
+++ b/swh/scanner/api.py
@@ -0,0 +1,68 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from flask import Flask, request
+
+from .db import Db
+from .exceptions import LargePayloadExc
+
+# Maximum number of SWHIDs accepted by a single request to the "known" endpoint
+MAX_SWHIDS = 1000
+
+
+def create_app(db: Db):
+    """Build the Flask application serving the local SWHID database.
+
+    Args:
+        db: the database queried by the "known" endpoint
+
+    Returns:
+        a Flask application exposing ``POST /api/1/known/``
+    """
+    app = Flask(__name__)
+
+    @app.errorhandler(LargePayloadExc)
+    def payload_too_large(exc):
+        # Answer with a client error instead of letting the exception surface
+        # as a 500 Internal Server Error
+        return {"error": str(exc)}, 413
+
+    @app.route("/api/1/known/", methods=["POST"])
+    def known():
+        """Check if an input list of SWHIDs are present inside the database"""
+        swhids = request.get_json()
+
+        # The body may be missing or hold any JSON value, only a list of
+        # strings is meaningful here
+        if not isinstance(swhids, list) or not all(
+            isinstance(swhid, str) for swhid in swhids
+        ):
+            return {"error": "The request body must be a JSON list of SWHIDs"}, 400
+
+        if len(swhids) > MAX_SWHIDS:
+            raise LargePayloadExc(
+                "The maximum number of SWHIDs this endpoint can receive is "
+                f"{MAX_SWHIDS}"
+            )
+
+        # A single cursor is enough for all the lookups of one request
+        cur = db.conn.cursor()
+        return {swhid: {"known": bool(db.known(swhid, cur))} for swhid in swhids}
+
+    return app
+
+
+def run(host: str, port: int, db: Db, debug: bool = False):
+    """Serve the local database.
+
+    Args:
+        host: the host the API server binds to
+        port: the port the API server listens on
+        db: the database to serve
+        debug: run Flask in debug mode; off by default because the debugger
+            it exposes lets clients execute arbitrary code
+    """
+    app = create_app(db)
+    app.run(host, port, debug=debug)
diff --git a/swh/scanner/cli.py b/swh/scanner/cli.py
--- a/swh/scanner/cli.py
+++ b/swh/scanner/cli.py
@@ -174,8 +174,7 @@
 )
 @click.pass_context
 def import_(ctx, chunk_size, input_file, output_file_db):
-    """Parse an input list of SWHID to generate a local sqlite database
-    """
+    """Parse an input list of SWHID to generate a local sqlite database"""
     from .db import Db

     db = Db(output_file_db)
@@ -189,6 +188,52 @@
         sys.exit(1)


+@db.command("serve")
+@click.option(
+    "-h",
+    "--host",
+    "host",
+    metavar="HOST",
+    default="127.0.0.1",
+    show_default=True,
+    help="The host of the API server",
+)
+@click.option(
+    "-p",
+    "--port",
+    "port",
+    metavar="PORT",
+    default=5011,
+    type=int,
+    show_default=True,
+    help="The port of the API server",
+)
+@click.option(
+    "-f",
+    "--db-file",
+    "db_file",
+    metavar="DB_FILE",
+    default="SWHID_DB.sqlite",
+    show_default=True,
+    type=click.Path(exists=True),
+    help="An sqlite database file (it can be generated with: 'swh scanner db import')",
+)
+@click.pass_context
+def serve(ctx, host, port, db_file):
+    """Start an API service using the sqlite database generated with the "db import"
+    option."""
+    import swh.scanner.api as api
+
+    from .db import Db
+
+    db = Db(db_file)
+    try:
+        api.run(host, port, db)
+    finally:
+        # Release the database even when the server stops on an error
+        db.close()
+
+
 def main():
     return scanner(auto_envvar_prefix="SWH_SCANNER")

diff --git a/swh/scanner/exceptions.py b/swh/scanner/exceptions.py
--- a/swh/scanner/exceptions.py
+++ b/swh/scanner/exceptions.py
@@ -12,6 +12,10 @@
     pass


+class LargePayloadExc(Exception):
+    pass
+
+
 class DBError(Exception):
     pass

diff --git a/swh/scanner/tests/test_api.py b/swh/scanner/tests/test_api.py
new file mode 100644
--- /dev/null
+++ b/swh/scanner/tests/test_api.py
@@ -0,0 +1,69 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from swh.scanner.api import MAX_SWHIDS, create_app
+from swh.scanner.db import Db
+
+from .data import present_swhids
+
+CHUNK_SIZE = 1000
+
+
+def test_local_api_endpoint_all_present(tmp_path, test_swhids_sample):
+    tmp_dbfile = tmp_path / "tmp_db.sqlite"
+    db = Db(tmp_dbfile)
+    cur = db.conn.cursor()
+    db.create_from(test_swhids_sample, CHUNK_SIZE, cur)
+
+    app = create_app(db)
+
+    with app.test_client() as test_client:
+        res = test_client.post("/api/1/known/", json=present_swhids)
+
+        assert res.status_code == 200
+        # Every requested SWHID must be reported, so the loop cannot pass vacuously
+        assert set(res.json) == set(present_swhids)
+        for attr in res.json.values():
+            assert attr["known"]
+
+
+def test_local_api_endpoint_one_not_present(tmp_path, test_swhids_sample):
+    tmp_dbfile = tmp_path / "tmp_db.sqlite"
+    not_present_swhid = "swh:1:cnt:fa8eacf43d8646129ae8adfa1648f9307d999999"
+    swhids = present_swhids + [not_present_swhid]
+
+    db = Db(tmp_dbfile)
+    cur = db.conn.cursor()
+    db.create_from(test_swhids_sample, CHUNK_SIZE, cur)
+
+    app = create_app(db)
+
+    with app.test_client() as test_client:
+        res = test_client.post("/api/1/known/", json=swhids)
+
+        assert res.status_code == 200
+        assert set(res.json) == set(swhids)
+        for swhid, attr in res.json.items():
+            assert attr["known"] == (swhid != not_present_swhid)
+
+
+def test_local_api_endpoint_rejects_invalid_payload(tmp_path):
+    app = create_app(Db(tmp_path / "tmp_db.sqlite"))
+
+    with app.test_client() as test_client:
+        # Neither a JSON object nor a list of non-strings is a list of SWHIDs
+        for payload in ({"swhid": "known"}, [1, 2, 3]):
+            res = test_client.post("/api/1/known/", json=payload)
+            assert res.status_code == 400
+
+
+def test_local_api_endpoint_rejects_large_payload(tmp_path):
+    app = create_app(Db(tmp_path / "tmp_db.sqlite"))
+    swhids = [f"swh:1:cnt:{i:040x}" for i in range(MAX_SWHIDS + 1)]
+
+    with app.test_client() as test_client:
+        res = test_client.post("/api/1/known/", json=swhids)
+
+        assert res.status_code == 413