diff --git a/swh/scanner/backend.py b/swh/scanner/backend.py
index 95169ce..7444b01 100644
--- a/swh/scanner/backend.py
+++ b/swh/scanner/backend.py
@@ -1,40 +1,40 @@
 # Copyright (C) 2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 from flask import Flask, request

 from .db import Db
 from .exceptions import LargePayloadExc
-
-LIMIT = 1000
+from .policy import QUERY_LIMIT


 def create_app(db: Db):
     """Backend for swh-scanner, implementing the /known endpoint of the
     Software Heritage Web API"""
     app = Flask(__name__)

     @app.route("/api/1/known/", methods=["POST"])
     def known():
         swhids = request.get_json()

-        if len(swhids) > LIMIT:
+        if len(swhids) > QUERY_LIMIT:
             raise LargePayloadExc(
-                f"The maximum number of SWHIDs this endpoint can receive is {LIMIT}"
+                f"The maximum number of SWHIDs this endpoint can receive is "
+                f"{QUERY_LIMIT}"
             )

         cur = db.conn.cursor()
         res = {swhid: {"known": db.known(swhid, cur)} for swhid in swhids}
         cur.close()

         return res

     return app


 def run(host: str, port: int, db: Db):
     """Serve the local database"""
     app = create_app(db)
     app.run(host, port, debug=True)
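For context, a minimal sketch of how a client could exercise the /known endpoint served by this backend. It assumes the server is running locally on the default port (BACKEND_DEFAULT_PORT, 5011) and that the `requests` package is installed; the SWHIDs below are placeholders:

    import requests

    # Placeholder SWHIDs; replace with identifiers present in the local database.
    swhids = [
        "swh:1:cnt:94a9ed024d3859793618152ea559a168bbcbb5e2",
        "swh:1:dir:d198bc9d7a6bcf6db04f476d29314f157507d505",
    ]
    # POST the list as JSON; payloads larger than QUERY_LIMIT (1000) SWHIDs
    # are rejected with a LargePayloadExc.
    res = requests.post("http://127.0.0.1:5011/api/1/known/", json=swhids)
    print(res.json())  # e.g. {"swh:1:cnt:94a9...": {"known": true}, ...}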
diff --git a/swh/scanner/cli.py b/swh/scanner/cli.py
index 4547abd..e57da27 100644
--- a/swh/scanner/cli.py
+++ b/swh/scanner/cli.py
@@ -1,249 +1,264 @@
 # Copyright (C) 2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 # WARNING: do not import unnecessary things here to keep cli startup time under
 # control
 import os
 from typing import Any, Dict, Optional

 import click
 from importlib_metadata import version
 import yaml

 from swh.core import config
 from swh.core.cli import CONTEXT_SETTINGS
 from swh.core.cli import swh as swh_cli_group

 from .exceptions import DBError

 # Config for the "serve" option
 BACKEND_DEFAULT_PORT = 5011

 # All generic config code should reside in swh.core.config
 CONFIG_ENVVAR = "SWH_CONFIG_FILE"
 DEFAULT_CONFIG_PATH = os.path.join(click.get_app_dir("swh"), "global.yml")

 DEFAULT_CONFIG: Dict[str, Any] = {
     "web-api": {
         "url": "https://archive.softwareheritage.org/api/1/",
         "auth-token": None,
     }
 }

 CONFIG_FILE_HELP = f"""Configuration file:

 \b
 The CLI option or the environment variable will fail if invalid.
 CLI option is checked first.
 Then, environment variable {CONFIG_ENVVAR} is checked.
 Then, if cannot load the default path, a set of default values are used.
 Default config path is {DEFAULT_CONFIG_PATH}.
 Default config values are:

 \b
 {yaml.dump(DEFAULT_CONFIG)}"""

 SCANNER_HELP = f"""Software Heritage Scanner tools.

 {CONFIG_FILE_HELP}"""


 def setup_config(ctx, api_url):
     config = ctx.obj["config"]
     if api_url:
         if not api_url.endswith("/"):
             api_url += "/"
         config["web-api"]["url"] = api_url

     return config


 @swh_cli_group.group(
     name="scanner", context_settings=CONTEXT_SETTINGS, help=SCANNER_HELP,
 )
 @click.option(
     "-C",
     "--config-file",
     default=None,
     type=click.Path(exists=False, dir_okay=False, path_type=str),
     help="""YAML configuration file""",
 )
 @click.version_option(
     version=version("swh.scanner"), prog_name="swh.scanner",
 )
 @click.pass_context
 def scanner(ctx, config_file: Optional[str]):

     env_config_path = os.environ.get(CONFIG_ENVVAR)

     # read_raw_config do not fail if file does not exist, so check it beforehand
     # while enforcing loading priority
     if config_file:
         if not config.config_exists(config_file):
             raise click.BadParameter(
                 f"File '{config_file}' cannot be opened.", param_hint="--config-file"
             )
     elif env_config_path:
         if not config.config_exists(env_config_path):
             raise click.BadParameter(
                 f"File '{env_config_path}' cannot be opened.", param_hint=CONFIG_ENVVAR
             )
         config_file = env_config_path
     elif config.config_exists(DEFAULT_CONFIG_PATH):
         config_file = DEFAULT_CONFIG_PATH

     conf = DEFAULT_CONFIG
     if config_file is not None:
         conf = config.read_raw_config(config.config_basepath(config_file))
         conf = config.merge_configs(DEFAULT_CONFIG, conf)

     ctx.ensure_object(dict)
     ctx.obj["config"] = conf


 @scanner.command(name="scan")
 @click.argument("root_path", required=True, type=click.Path(exists=True))
 @click.option(
     "-u",
     "--api-url",
     default=None,
     metavar="API_URL",
     show_default=True,
     help="URL for the api request",
 )
 @click.option(
     "--exclude",
     "-x",
     "patterns",
     metavar="PATTERN",
     multiple=True,
     help="Exclude directories using glob patterns \
 (e.g., ``*.git`` to exclude all .git directories)",
 )
 @click.option(
     "-f",
     "--output-format",
     "out_fmt",
     default="text",
     show_default=True,
     type=click.Choice(["text", "json", "ndjson", "sunburst"], case_sensitive=False),
     help="The output format",
 )
 @click.option(
     "-i", "--interactive", is_flag=True, help="Show the result in a dashboard"
 )
 @click.option(
     "-p",
     "--policy",
-    default="bfs",
+    default="auto",
     show_default=True,
-    type=click.Choice(["bfs", "filepriority", "dirpriority"]),
+    type=click.Choice(["auto", "bfs", "filepriority", "dirpriority"]),
     help="The scan policy.",
 )
 @click.pass_context
 def scan(ctx, root_path, api_url, patterns, out_fmt, interactive, policy):
     """Scan a source code project to discover files and directories already
-    present in the archive"""
+    present in the archive.
+
+    The source code project can be checked using different policies, selected
+    with the -p/--policy option:
+
+    auto: selects the best policy based on the source code; for codebases with
+    at most 1000 file/directory contents, all the nodes are queried at once.
+
+    bfs: scan the source code in BFS order, checking unknown directories only.
+
+    filepriority: scan all the source code file contents, then check only the
+    directories whose status is still unset (useful if the codebase contains
+    many source files).
+
+    dirpriority: scan all the source code directories and check only the
+    unknown directory contents.
+    """
     import swh.scanner.scanner as scanner

     config = setup_config(ctx, api_url)
     scanner.scan(config, root_path, patterns, out_fmt, interactive, policy)


 @scanner.group("db", help="Manage local knowledge base for swh-scanner")
 @click.pass_context
 def db(ctx):
     pass


 @db.command("import")
 @click.option(
     "-i",
     "--input",
     "input_file",
     metavar="INPUT_FILE",
     required=True,
     type=click.File("r"),
     help="A file containing SWHIDs",
 )
 @click.option(
     "-o",
     "--output",
     "output_file_db",
     metavar="OUTPUT_DB_FILE",
     required=True,
     show_default=True,
     help="The name of the generated sqlite database",
 )
 @click.option(
     "-s",
     "--chunk-size",
     "chunk_size",
     default="10000",
     metavar="SIZE",
     show_default=True,
     type=int,
     help="The chunk size",
 )
 @click.pass_context
 def import_(ctx, chunk_size, input_file, output_file_db):
     """Create SQLite database of known SWHIDs from a textual list of SWHIDs"""
     from .db import Db

     db = Db(output_file_db)
     cur = db.conn.cursor()
     try:
         db.create_from(input_file, chunk_size, cur)
         db.close()
     except DBError as e:
         ctx.fail("Failed to import SWHIDs into database: {0}".format(e))


 @db.command("serve")
 @click.option(
     "-h",
     "--host",
     metavar="HOST",
     default="127.0.0.1",
     show_default=True,
     help="The host of the API server",
 )
 @click.option(
     "-p",
     "--port",
     metavar="PORT",
     default=f"{BACKEND_DEFAULT_PORT}",
     show_default=True,
     help="The port of the API server",
 )
 @click.option(
     "-f",
     "--db-file",
     "db_file",
     metavar="DB_FILE",
     default="SWHID_DB.sqlite",
     show_default=True,
     type=click.Path(exists=True),
     help="An sqlite database file (it can be generated with: 'swh scanner db import')",
 )
 @click.pass_context
 def serve(ctx, host, port, db_file):
     """Start an API service using the sqlite database generated with the
     "db import" option."""
     import swh.scanner.backend as backend

     from .db import Db

     db = Db(db_file)
     backend.run(host, port, db)
     db.close()


 def main():
     return scanner(auto_envvar_prefix="SWH_SCANNER")


 if __name__ == "__main__":
     main()
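To illustrate the new default policy, here is a hedged sketch of what `swh scanner scan -p auto <root>` does programmatically; the config mirrors DEFAULT_CONFIG and the project path is a placeholder:

    from swh.scanner.scanner import scan

    config = {
        "web-api": {
            "url": "https://archive.softwareheritage.org/api/1/",
            "auth-token": None,
        }
    }
    # With policy="auto", small projects (at most 1000 nodes) are queried in a
    # single request (QueryAll); larger ones fall back to LazyBFS.
    scan(
        config,
        "/path/to/project",  # placeholder
        exclude_patterns=["*.git"],
        out_fmt="text",
        interactive=False,
        policy="auto",
    )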
+ """ import swh.scanner.scanner as scanner config = setup_config(ctx, api_url) scanner.scan(config, root_path, patterns, out_fmt, interactive, policy) @scanner.group("db", help="Manage local knowledge base for swh-scanner") @click.pass_context def db(ctx): pass @db.command("import") @click.option( "-i", "--input", "input_file", metavar="INPUT_FILE", required=True, type=click.File("r"), help="A file containing SWHIDs", ) @click.option( "-o", "--output", "output_file_db", metavar="OUTPUT_DB_FILE", required=True, show_default=True, help="The name of the generated sqlite database", ) @click.option( "-s", "--chunk-size", "chunk_size", default="10000", metavar="SIZE", show_default=True, type=int, help="The chunk size ", ) @click.pass_context def import_(ctx, chunk_size, input_file, output_file_db): """Create SQLite database of known SWHIDs from a textual list of SWHIDs""" from .db import Db db = Db(output_file_db) cur = db.conn.cursor() try: db.create_from(input_file, chunk_size, cur) db.close() except DBError as e: ctx.fail("Failed to import SWHIDs into database: {0}".format(e)) @db.command("serve") @click.option( "-h", "--host", metavar="HOST", default="127.0.0.1", show_default=True, help="The host of the API server", ) @click.option( "-p", "--port", metavar="PORT", default=f"{BACKEND_DEFAULT_PORT}", show_default=True, help="The port of the API server", ) @click.option( "-f", "--db-file", "db_file", metavar="DB_FILE", default="SWHID_DB.sqlite", show_default=True, type=click.Path(exists=True), help="An sqlite database file (it can be generated with: 'swh scanner db import')", ) @click.pass_context def serve(ctx, host, port, db_file): """Start an API service using the sqlite database generated with the "db import" option.""" import swh.scanner.backend as backend from .db import Db db = Db(db_file) backend.run(host, port, db) db.close() def main(): return scanner(auto_envvar_prefix="SWH_SCANNER") if __name__ == "__main__": main() diff --git a/swh/scanner/policy.py b/swh/scanner/policy.py index a107200..dacf0df 100644 --- a/swh/scanner/policy.py +++ b/swh/scanner/policy.py @@ -1,250 +1,286 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import abc import asyncio import itertools from typing import Dict, List, no_type_check import aiohttp from swh.model.from_disk import Directory from swh.model.identifiers import CONTENT, DIRECTORY from .data import MerkleNodeInfo from .exceptions import error_response +# Maximum number of SWHIDs that can be requested by a single call to the +# Web API endpoint /known/ +QUERY_LIMIT = 1000 + async def swhids_discovery( swhids: List[str], session: aiohttp.ClientSession, api_url: str, ) -> Dict[str, Dict[str, bool]]: """API Request to get information about the SoftWare Heritage persistent IDentifiers (SWHIDs) given in input. 

     Args:
         swhids: a list of SWHIDs
         api_url: url for the API request

     Returns:
         A dictionary with:

         key: SWHID searched
         value:
             value['known'] = True if the SWHID is found
             value['known'] = False if the SWHID is not found

     """
     endpoint = api_url + "known/"
-    chunk_size = 1000
     requests = []

     def get_chunk(swhids):
-        for i in range(0, len(swhids), chunk_size):
-            yield swhids[i : i + chunk_size]
+        for i in range(0, len(swhids), QUERY_LIMIT):
+            yield swhids[i : i + QUERY_LIMIT]

     async def make_request(swhids):
         async with session.post(endpoint, json=swhids) as resp:
             if resp.status != 200:
                 error_response(resp.reason, resp.status, endpoint)

             return await resp.json()

-    if len(swhids) > chunk_size:
+    if len(swhids) > QUERY_LIMIT:
         for swhids_chunk in get_chunk(swhids):
             requests.append(asyncio.create_task(make_request(swhids_chunk)))
         res = await asyncio.gather(*requests)
         # concatenate list of dictionaries
         return dict(itertools.chain.from_iterable(e.items() for e in res))
     else:
         return await make_request(swhids)


 class Policy(metaclass=abc.ABCMeta):

     data: MerkleNodeInfo
     """information about contents and directories of the merkle tree"""

     source_tree: Directory
     """representation of a source code project directory in the merkle tree"""

     def __init__(self, source_tree: Directory, data: MerkleNodeInfo):
         self.data = data
         self.source_tree = source_tree
         for node in source_tree.iter_tree():
             self.data[node.swhid()] = {"known": None}  # type: ignore

     @abc.abstractmethod
     async def run(
         self, session: aiohttp.ClientSession, api_url: str,
     ):
         """Scan a source code project"""
         raise NotImplementedError("Must implement run method")


 class LazyBFS(Policy):
+    """Read nodes of the Merkle tree in BFS order. Query only the directories
+    whose status is unknown; when a directory is known, set all of its
+    downstream contents to known.
+    """
+
     async def run(
         self, session: aiohttp.ClientSession, api_url: str,
     ):
         queue = []
         queue.append(self.source_tree)

         while queue:
             swhids = [str(node.swhid()) for node in queue]
             swhids_res = await swhids_discovery(swhids, session, api_url)
             for node in queue.copy():
                 queue.remove(node)
                 self.data[node.swhid()]["known"] = swhids_res[str(node.swhid())][
                     "known"
                 ]
                 if node.object_type == DIRECTORY:
                     if not self.data[node.swhid()]["known"]:
                         children = [n[1] for n in list(node.items())]
                         queue.extend(children)
                     else:
                         for sub_node in node.iter_tree():
                             if sub_node == node:
                                 continue
                             self.data[sub_node.swhid()]["known"] = True  # type: ignore


 class FilePriority(Policy):
+    """Check the Merkle tree by querying all the file contents first; whenever
+    a file content is unknown, set all of its upstream directories to unknown.
+    Finally, check the directories whose status is still unset, and set the
+    sub-directories of known directories to known.
+    """
+ """ + @no_type_check async def run( self, session: aiohttp.ClientSession, api_url: str, ): # get all the files all_contents = list( filter( lambda node: node.object_type == CONTENT, self.source_tree.iter_tree() ) ) all_contents.reverse() # check deepest node first # query the backend to get all file contents status cnt_swhids = [str(node.swhid()) for node in all_contents] cnt_status_res = await swhids_discovery(cnt_swhids, session, api_url) # set all the file contents status for cnt in all_contents: self.data[cnt.swhid()]["known"] = cnt_status_res[str(cnt.swhid())]["known"] # set all the upstream directories of unknown file contents to unknown if not self.data[cnt.swhid()]["known"]: parent = cnt.parents[0] while parent: self.data[parent.swhid()]["known"] = False parent = parent.parents[0] if parent.parents else None # get all unset directories and check their status # (update children directories accordingly) unset_dirs = list( filter( lambda node: node.object_type == DIRECTORY and self.data[node.swhid()]["known"] is None, self.source_tree.iter_tree(), ) ) # check unset directories for dir_ in unset_dirs: if self.data[dir_.swhid()]["known"] is None: # update directory status dir_status = await swhids_discovery( [str(dir_.swhid())], session, api_url ) dir_known = dir_status[str(dir_.swhid())]["known"] self.data[dir_.swhid()]["known"] = dir_known if dir_known: sub_dirs = list( filter( lambda n: n.object_type == DIRECTORY and self.data[n.swhid()]["known"] is None, dir_.iter_tree(), ) ) for node in sub_dirs: self.data[node.swhid()]["known"] = True class DirectoryPriority(Policy): + """Check the Merkle tree querying all the directories that have at least one file + content and set all the upstream directories to unknown in the case a directory + is unknown otherwise set all the downstream contents to known. + Finally check the status of empty directories and all the remaining file + contents. 
+ """ + @no_type_check async def run( self, session: aiohttp.ClientSession, api_url: str, ): # get all directory contents that have at least one file content unknown_dirs = list( filter( lambda dir_: dir_.object_type == DIRECTORY and self.has_contents(dir_), self.source_tree.iter_tree(), ) ) unknown_dirs.reverse() # check deepest node first for dir_ in unknown_dirs: if self.data[dir_.swhid()]["known"] is None: dir_status = await swhids_discovery( [str(dir_.swhid())], session, api_url ) dir_known = dir_status[str(dir_.swhid())]["known"] self.data[dir_.swhid()]["known"] = dir_known # set all the downstream file contents to known if dir_known: for cnt in self.get_contents(dir_): self.data[cnt.swhid()]["known"] = True # otherwise set all the upstream directories to unknown else: parent = dir_.parents[0] while parent: self.data[parent.swhid()]["known"] = False parent = parent.parents[0] if parent.parents else None # get remaining directories that have no file contents empty_dirs = list( filter( lambda n: n.object_type == DIRECTORY and not self.has_contents(n) and self.data[n.swhid()]["known"] is None, self.source_tree.iter_tree(), ) ) empty_dirs_swhids = [str(n.swhid()) for n in empty_dirs] empty_dir_status = await swhids_discovery(empty_dirs_swhids, session, api_url) # update status of directories that have no file contents for dir_ in empty_dirs: self.data[dir_.swhid()]["known"] = empty_dir_status[str(dir_.swhid())][ "known" ] # check unknown file contents unknown_cnts = list( filter( lambda n: n.object_type == CONTENT and self.data[n.swhid()]["known"] is None, self.source_tree.iter_tree(), ) ) unknown_cnts_swhids = [str(n.swhid()) for n in unknown_cnts] unknown_cnts_status = await swhids_discovery( unknown_cnts_swhids, session, api_url ) for cnt in unknown_cnts: self.data[cnt.swhid()]["known"] = unknown_cnts_status[str(cnt.swhid())][ "known" ] def has_contents(self, directory: Directory): """Check if the directory given in input has contents""" for entry in directory.entries: if entry["type"] == "file": return True return False def get_contents(self, dir_: Directory): """Get all the contents of a given directory""" for _, node in list(dir_.items()): if node.object_type == CONTENT: yield node + + +class QueryAll(Policy): + """Check the status of every node in the Merkle tree. 
+ """ + + @no_type_check + async def run( + self, session: aiohttp.ClientSession, api_url: str, + ): + all_nodes = [node for node in self.source_tree.iter_tree()] + all_swhids = [str(node.swhid()) for node in all_nodes] + swhids_res = await swhids_discovery(all_swhids, session, api_url) + for node in all_nodes: + self.data[node.swhid()]["known"] = swhids_res[str(node.swhid())]["known"] diff --git a/swh/scanner/scanner.py b/swh/scanner/scanner.py index 2e77b5b..9540e1e 100644 --- a/swh/scanner/scanner.py +++ b/swh/scanner/scanner.py @@ -1,71 +1,81 @@ # Copyright (C) 2020-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import asyncio from typing import Any, Dict, Iterable import aiohttp from swh.model.cli import model_of_dir from swh.model.from_disk import Directory from .data import MerkleNodeInfo from .output import Output -from .policy import DirectoryPriority, FilePriority, LazyBFS +from .policy import QUERY_LIMIT, DirectoryPriority, FilePriority, LazyBFS, QueryAll async def run(config: Dict[str, Any], policy) -> None: """Scan a given source code according to the policy given in input. Args: root: the root path to scan api_url: url for the API request """ api_url = config["web-api"]["url"] if config["web-api"]["auth-token"]: headers = {"Authorization": f"Bearer {config['web-api']['auth-token']}"} else: headers = {} async with aiohttp.ClientSession(headers=headers, trust_env=True) as session: await policy.run(session, api_url) +def source_size(source_tree: Directory): + return len([n for n in source_tree.iter_tree(dedup=False)]) + + def get_policy_obj(source_tree: Directory, nodes_data: MerkleNodeInfo, policy: str): - if policy == "bfs": + if policy == "auto": + return ( + QueryAll(source_tree, nodes_data) + if source_size(source_tree) <= QUERY_LIMIT + else LazyBFS(source_tree, nodes_data) + ) + elif policy == "bfs": return LazyBFS(source_tree, nodes_data) elif policy == "filepriority": return FilePriority(source_tree, nodes_data) elif policy == "dirpriority": return DirectoryPriority(source_tree, nodes_data) else: raise Exception(f"policy '{policy}' not found") def scan( config: Dict[str, Any], root_path: str, exclude_patterns: Iterable[str], out_fmt: str, interactive: bool, policy: str, ): """Scan a source code project to discover files and directories already present in the archive""" converted_patterns = [pattern.encode() for pattern in exclude_patterns] source_tree = model_of_dir(root_path.encode(), converted_patterns) nodes_data = MerkleNodeInfo() policy = get_policy_obj(source_tree, nodes_data, policy) loop = asyncio.get_event_loop() loop.run_until_complete(run(config, policy)) out = Output(root_path, nodes_data, source_tree) if interactive: out.show("interactive") else: out.show(out_fmt) diff --git a/swh/scanner/tests/conftest.py b/swh/scanner/tests/conftest.py index 3e5b56d..1806279 100644 --- a/swh/scanner/tests/conftest.py +++ b/swh/scanner/tests/conftest.py @@ -1,134 +1,150 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import asyncio import os from pathlib import Path import shutil import aiohttp from aioresponses import aioresponses # type: ignore import pytest from swh.model.cli import model_of_dir 
diff --git a/swh/scanner/tests/conftest.py b/swh/scanner/tests/conftest.py
index 3e5b56d..1806279 100644
--- a/swh/scanner/tests/conftest.py
+++ b/swh/scanner/tests/conftest.py
@@ -1,134 +1,150 @@
 # Copyright (C) 2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import asyncio
 import os
 from pathlib import Path
 import shutil

 import aiohttp
 from aioresponses import aioresponses  # type: ignore
 import pytest

 from swh.model.cli import model_of_dir
 from swh.scanner.data import MerkleNodeInfo
+from swh.scanner.policy import QUERY_LIMIT

 from .data import present_swhids
 from .flask_api import create_app


 @pytest.fixture
 def mock_aioresponse():
     with aioresponses() as m:
         yield m


 @pytest.fixture
 def event_loop():
     """Fixture that generate an asyncio event loop."""
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
     yield loop
     loop.close()


 @pytest.fixture
 async def aiosession():
     """Fixture that generate an aiohttp Client Session."""
     session = aiohttp.ClientSession()
     yield session
     session.detach()


 @pytest.fixture(scope="function")
 def test_sample_folder(datadir, tmp_path):
     """Location of the "data" folder"""
     archive_path = Path(os.path.join(datadir, "sample-folder.tgz"))
     assert archive_path.exists()
     shutil.unpack_archive(archive_path, extract_dir=tmp_path)
     test_sample_folder = Path(os.path.join(tmp_path, "sample-folder"))
     assert test_sample_folder.exists()
     return test_sample_folder


 @pytest.fixture(scope="function")
 def test_sample_folder_policy(datadir, tmp_path):
     """Location of the sample source code project to test the scanner policies"""
     archive_path = Path(os.path.join(datadir, "sample-folder-policy.tgz"))
     assert archive_path.exists()
     shutil.unpack_archive(archive_path, extract_dir=tmp_path)
     test_sample_folder = Path(os.path.join(tmp_path, "sample-folder-policy"))
     assert test_sample_folder.exists()
     return test_sample_folder


 @pytest.fixture(scope="function")
 def source_tree(test_sample_folder):
     """Generate a model.from_disk.Directory object from the test sample
     folder
     """
     return model_of_dir(str(test_sample_folder).encode())


+@pytest.fixture(scope="function")
+def big_source_tree(tmp_path):
+    """Generate a model.from_disk.Directory from a "big" temporary directory
+    (more than 1000 nodes)
+    """
+    dir_ = tmp_path / "big-directory"
+    dir_.mkdir()
+    for i in range(0, QUERY_LIMIT + 1):
+        file_ = dir_ / f"file_{i}.org"
+        file_.touch()
+    dir_obj = model_of_dir(str(dir_).encode())
+    assert len(dir_obj) > QUERY_LIMIT
+    return dir_obj
+
+
 @pytest.fixture(scope="function")
 def source_tree_policy(test_sample_folder_policy):
     """Generate a model.from_disk.Directory object from the test sample
     folder
     """
     return model_of_dir(str(test_sample_folder_policy).encode())


 @pytest.fixture(scope="function")
 def source_tree_dirs(source_tree):
     """Returns a list of all directories contained inside the test sample
     folder
     """
     root = source_tree.data["path"]
     return list(
         map(
             lambda n: Path(n.data["path"].decode()).relative_to(Path(root.decode())),
             filter(
                 lambda n: n.object_type == "directory"
                 and not n.data["path"] == source_tree.data["path"],
                 source_tree.iter_tree(dedup=False),
             ),
         )
     )


 @pytest.fixture(scope="function")
 def nodes_data(source_tree):
     """mock known status of file/dirs in test_sample_folder"""
     nodes_data = MerkleNodeInfo()
     for node in source_tree.iter_tree():
         nodes_data[node.swhid()] = {"known": True}
     return nodes_data


 @pytest.fixture
 def test_swhids_sample(tmp_path):
     """Create and return the opened "swhids_sample" file, filled with present
     swhids present in data.py
     """
     test_swhids_sample = Path(os.path.join(tmp_path, "swhids_sample.txt"))
     with open(test_swhids_sample, "w") as f:
         f.write("\n".join(swhid for swhid in present_swhids))
     assert test_swhids_sample.exists()
     return open(test_swhids_sample, "r")


 @pytest.fixture(scope="session")
 def tmp_requests(tmpdir_factory):
     requests_file = tmpdir_factory.mktemp("data").join("requests.json")
     return requests_file


 @pytest.fixture(scope="session")
 def app(tmp_requests):
     """Flask backend API (used by live_server)."""
     app = create_app(tmp_requests)
     return app
diff --git a/swh/scanner/tests/test_backend.py b/swh/scanner/tests/test_backend.py
index 93d5b5f..c3d942b 100644
--- a/swh/scanner/tests/test_backend.py
+++ b/swh/scanner/tests/test_backend.py
@@ -1,61 +1,62 @@
 # Copyright (C) 2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

-from swh.scanner.backend import LIMIT, create_app
+from swh.scanner.backend import create_app
 from swh.scanner.db import Db
+from swh.scanner.policy import QUERY_LIMIT

 from .data import present_swhids


 def test_backend_endpoint_all_present(tmp_path, live_server, test_swhids_sample):
     tmp_dbfile = tmp_path / "tmp_db.sqlite"
     db = Db(tmp_dbfile)
     cur = db.conn.cursor()
-    db.create_from(test_swhids_sample, LIMIT, cur)
+    db.create_from(test_swhids_sample, QUERY_LIMIT, cur)

     app = create_app(db)

     with app.test_client() as test_client:
         res = test_client.post("/api/1/known/", json=present_swhids)

         for swhid, attr in res.json.items():
             assert attr["known"]


 def test_backend_endpoint_one_not_present(tmp_path, live_server, test_swhids_sample):
     tmp_dbfile = tmp_path / "tmp_db.sqlite"
     not_present_swhid = "swh:1:cnt:fa8eacf43d8646129ae8adfa1648f9307d999999"
     swhids = present_swhids + [not_present_swhid]

     db = Db(tmp_dbfile)
     cur = db.conn.cursor()
-    db.create_from(test_swhids_sample, LIMIT, cur)
+    db.create_from(test_swhids_sample, QUERY_LIMIT, cur)

     app = create_app(db)

     with app.test_client() as test_client:
         res = test_client.post("/api/1/known/", json=swhids)

         for swhid, attr in res.json.items():
             if swhid != not_present_swhid:
                 assert attr["known"]
             else:
                 assert not attr["known"]


 def test_backend_large_payload_exc(tmp_path, live_server, test_swhids_sample):
     tmp_dbfile = tmp_path / "tmp_db.sqlite"
     swhid = "swh:1:cnt:fa8eacf43d8646129ae8adfa1648f9307d999999"
     # the backend supports up to 1000 SWHID requests
     swhids = [swhid for n in range(1001)]

     db = Db(tmp_dbfile)
     cur = db.conn.cursor()
-    db.create_from(test_swhids_sample, LIMIT, cur)
+    db.create_from(test_swhids_sample, QUERY_LIMIT, cur)

     app = create_app(db)

     with app.test_client() as test_client:
         res = test_client.post("/api/1/known/", json=swhids)
         assert res.status_code != 200
diff --git a/swh/scanner/tests/test_db.py b/swh/scanner/tests/test_db.py
index 96a3260..222edcf 100644
--- a/swh/scanner/tests/test_db.py
+++ b/swh/scanner/tests/test_db.py
@@ -1,40 +1,39 @@
 # Copyright (C) 2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 from swh.scanner.db import Db
+from swh.scanner.policy import QUERY_LIMIT

 from .data import present_swhids

-CHUNK_SIZE = 1000
-

 def test_db_create_from(tmp_path, test_swhids_sample):
     tmp_dbfile = tmp_path / "tmp_db.sqlite"

     db = Db(tmp_dbfile)
     cur = db.conn.cursor()
-    db.create_from(test_swhids_sample, CHUNK_SIZE, cur)
+    db.create_from(test_swhids_sample, QUERY_LIMIT, cur)

     for swhid in present_swhids:
         cur = db.conn.cursor()
         assert db.known(swhid, cur)


 def test_db_create_from_one_not_present(tmp_path, test_swhids_sample):
     not_present_swhid = "swh:1:cnt:fa8eacf43d8646129ae8adfa1648f9307d999999"
     swhids = present_swhids + [not_present_swhid]

     tmp_dbfile = tmp_path / "tmp_db.sqlite"
     db = Db(tmp_dbfile)
     cur = db.conn.cursor()
-    db.create_from(test_swhids_sample, CHUNK_SIZE, cur)
+    db.create_from(test_swhids_sample, QUERY_LIMIT, cur)

     for swhid in swhids:
         cur = db.conn.cursor()
         if swhid != not_present_swhid:
             assert db.known(swhid, cur)
         else:
             assert not db.known(swhid, cur)
diff --git a/swh/scanner/tests/test_scanner.py b/swh/scanner/tests/test_scanner.py
index 9e5e59b..903ed0b 100644
--- a/swh/scanner/tests/test_scanner.py
+++ b/swh/scanner/tests/test_scanner.py
@@ -1,60 +1,84 @@
 # Copyright (C) 2020-2021 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 from flask import url_for
 import pytest

 from swh.scanner.data import MerkleNodeInfo
-from swh.scanner.policy import DirectoryPriority, FilePriority, LazyBFS
-from swh.scanner.scanner import run
+from swh.scanner.policy import DirectoryPriority, FilePriority, LazyBFS, QueryAll
+from swh.scanner.scanner import get_policy_obj, run

 from .data import unknown_swhids


 @pytest.mark.options(debug=False)
 def test_app(app):
     assert not app.debug


+def test_get_policy_obj_auto(source_tree, nodes_data):
+    assert isinstance(get_policy_obj(source_tree, nodes_data, "auto"), QueryAll)
+
+
+def test_get_policy_obj_bfs(big_source_tree, nodes_data):
+    # check that the policy object is LazyBFS if the source tree contains
+    # more than 1000 nodes
+    assert isinstance(get_policy_obj(big_source_tree, nodes_data, "auto"), LazyBFS)
+
+
 def test_scanner_result_bfs(live_server, event_loop, source_tree):
     api_url = url_for("index", _external=True)
     config = {"web-api": {"url": api_url, "auth-token": None}}

     nodes_data = MerkleNodeInfo()
     policy = LazyBFS(source_tree, nodes_data)
     event_loop.run_until_complete(run(config, policy))
     for node in source_tree.iter_tree():
         if str(node.swhid()) in unknown_swhids:
             assert nodes_data[node.swhid()]["known"] is False
         else:
             assert nodes_data[node.swhid()]["known"] is True


 def test_scanner_result_file_priority(live_server, event_loop, source_tree):
     api_url = url_for("index", _external=True)
     config = {"web-api": {"url": api_url, "auth-token": None}}

     nodes_data = MerkleNodeInfo()
     policy = FilePriority(source_tree, nodes_data)
     event_loop.run_until_complete(run(config, policy))
     for node in source_tree.iter_tree():
         if str(node.swhid()) in unknown_swhids:
             assert nodes_data[node.swhid()]["known"] is False
         else:
             assert nodes_data[node.swhid()]["known"] is True


 def test_scanner_result_directory_priority(live_server, event_loop, source_tree):
     api_url = url_for("index", _external=True)
     config = {"web-api": {"url": api_url, "auth-token": None}}

     nodes_data = MerkleNodeInfo()
     policy = DirectoryPriority(source_tree, nodes_data)
     event_loop.run_until_complete(run(config, policy))
     for node in source_tree.iter_tree():
         if str(node.swhid()) in unknown_swhids:
             assert nodes_data[node.swhid()]["known"] is False
         else:
             assert nodes_data[node.swhid()]["known"] is True


+def test_scanner_result_query_all(live_server, event_loop, source_tree):
+    api_url = url_for("index", _external=True)
+    config = {"web-api": {"url": api_url, "auth-token": None}}
+
+    nodes_data = MerkleNodeInfo()
+    policy = QueryAll(source_tree, nodes_data)
+    event_loop.run_until_complete(run(config, policy))
+    for node in source_tree.iter_tree():
+        if str(node.swhid()) in unknown_swhids:
+            assert nodes_data[node.swhid()]["known"] is False
+        else:
+            assert nodes_data[node.swhid()]["known"] is True
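Finally, tying the pieces together, a hedged end-to-end sketch of the local knowledge base workflow that these tests cover; the file names are placeholders, and the input file is assumed to contain one SWHID per line:

    from swh.scanner.backend import create_app
    from swh.scanner.db import Db
    from swh.scanner.policy import QUERY_LIMIT

    db = Db("SWHID_DB.sqlite")  # placeholder database file
    cur = db.conn.cursor()
    with open("swhids.txt") as f:  # placeholder input file
        db.create_from(f, QUERY_LIMIT, cur)

    # Equivalent to `swh scanner db serve -f SWHID_DB.sqlite`.
    app = create_app(db)
    app.run("127.0.0.1", 5011)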