diff --git a/conftest.py b/conftest.py --- a/conftest.py +++ b/conftest.py @@ -1,10 +1,10 @@ -# Copyright (C) 2020 The Software Heritage developers +# Copyright (C) 2020-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os -pytest_plugins = ["swh.scheduler.pytest_plugin", "swh.lister.pytest_plugin"] +pytest_plugins = ["swh.scheduler.pytest_plugin"] os.environ["LC_ALL"] = "C.UTF-8" diff --git a/mypy.ini b/mypy.ini --- a/mypy.ini +++ b/mypy.ini @@ -2,10 +2,6 @@ namespace_packages = True warn_unused_ignores = True -# support for sqlalchemy magic: see https://github.com/dropbox/sqlalchemy-stubs -plugins = sqlmypy - - # 3rd party libraries without stubs (yet) [mypy-bs4.*] diff --git a/requirements-test.txt b/requirements-test.txt --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,5 +1,3 @@ pytest pytest-mock requests_mock -sqlalchemy-stubs -testing.postgresql diff --git a/requirements.txt b/requirements.txt --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,3 @@ -SQLAlchemy -arrow python_debian requests setuptools diff --git a/swh/lister/cli.py b/swh/lister/cli.py --- a/swh/lister/cli.py +++ b/swh/lister/cli.py @@ -1,4 +1,4 @@ -# Copyright (C) 2018-2020 The Software Heritage developers +# Copyright (C) 2018-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -14,31 +14,11 @@ from swh.core.cli import CONTEXT_SETTINGS from swh.core.cli import swh as swh_cli_group -from swh.lister import LISTERS, SUPPORTED_LISTERS, get_lister +from swh.lister import SUPPORTED_LISTERS, get_lister logger = logging.getLogger(__name__) -# the key in this dict is the suffix used to match new task-type to be added. -# For example for a task which function name is "list_gitlab_full', the default -# value used when inserting a new task-type in the scheduler db will be the one -# under the 'full' key below (because it matches xxx_full). 
-DEFAULT_TASK_TYPE = { - "full": { # for tasks like 'list_xxx_full()' - "default_interval": "90 days", - "min_interval": "90 days", - "max_interval": "90 days", - "backoff_factor": 1, - }, - "*": { # value if not suffix matches - "default_interval": "1 day", - "min_interval": "1 day", - "max_interval": "1 day", - "backoff_factor": 1, - }, -} - - @swh_cli_group.group(name="lister", context_settings=CONTEXT_SETTINGS) @click.option( "--config-file", @@ -47,15 +27,8 @@ type=click.Path(exists=True, dir_okay=False,), help="Configuration file.", ) -@click.option( - "--db-url", - "-d", - default=None, - help="SQLAlchemy DB URL; see " - "", -) # noqa @click.pass_context -def lister(ctx, config_file, db_url): +def lister(ctx, config_file): """Software Heritage Lister tools.""" from swh.core import config @@ -64,51 +37,8 @@ if not config_file: config_file = os.environ.get("SWH_CONFIG_FILENAME") conf = config.read(config_file) - if db_url: - conf["lister"] = {"cls": "local", "args": {"db": db_url}} - ctx.obj["config"] = conf - - -@lister.command(name="db-init", context_settings=CONTEXT_SETTINGS) -@click.option( - "--drop-tables", - "-D", - is_flag=True, - default=False, - help="Drop tables before creating the database schema", -) -@click.pass_context -def db_init(ctx, drop_tables): - """Initialize the database model for given listers. - """ - from sqlalchemy import create_engine - - from swh.lister.core.models import initialize - - cfg = ctx.obj["config"] - lister_cfg = cfg["lister"] - if lister_cfg["cls"] != "local": - click.echo("A local lister configuration is required") - ctx.exit(1) - - db_url = lister_cfg["args"]["db"] - db_engine = create_engine(db_url) - - registry = {} - for lister, entrypoint in LISTERS.items(): - logger.info("Loading lister %s", lister) - registry[lister] = entrypoint.load()() - - logger.info("Initializing database") - initialize(db_engine, drop_tables) - - for lister, entrypoint in LISTERS.items(): - registry_entry = registry[lister] - init_hook = registry_entry.get("init") - if callable(init_hook): - logger.info("Calling init hook for %s", lister) - init_hook(db_engine) + ctx.obj["config"] = conf @lister.command( @@ -122,17 +52,9 @@ @click.option( "--lister", "-l", help="Lister to run", type=click.Choice(SUPPORTED_LISTERS) ) -@click.option( - "--priority", - "-p", - default="high", - type=click.Choice(["high", "medium", "low"]), - help="Task priority for the listed repositories to ingest", -) -@click.option("--legacy", help="Allow unported lister to run with such flag") @click.argument("options", nargs=-1) @click.pass_context -def run(ctx, lister, priority, options, legacy): +def run(ctx, lister, options): from swh.scheduler.cli.utils import parse_options config = deepcopy(ctx.obj["config"]) @@ -140,10 +62,6 @@ if options: config.update(parse_options(options)[1]) - if legacy: - config["priority"] = priority - config["policy"] = "oneshot" - get_lister(lister, **config).run() diff --git a/swh/lister/core/__init__.py b/swh/lister/core/__init__.py deleted file mode 100644 diff --git a/swh/lister/core/abstractattribute.py b/swh/lister/core/abstractattribute.py deleted file mode 100644 --- a/swh/lister/core/abstractattribute.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (C) 2017 the Software Heritage developers -# License: GNU General Public License version 3, or any later version -# See top-level LICENSE file for more information - - -class AbstractAttribute: - """AbstractAttributes in a base class must be overridden by the subclass. 
- - It's like the :func:`abc.abstractmethod` decorator, but for things that - are explicitly attributes/properties, not methods, without the need for - empty method def boilerplate. Like abc.abstractmethod, the class containing - AbstractAttributes must inherit from :class:`abc.ABC` or use the - :class:`abc.ABCMeta` metaclass. - - Usage example:: - - import abc - class ClassContainingAnAbstractAttribute(abc.ABC): - foo: Union[AbstractAttribute, Any] = \ - AbstractAttribute('docstring for foo') - - """ - - __isabstractmethod__ = True - - def __init__(self, docstring=None): - if docstring is not None: - self.__doc__ = "AbstractAttribute: " + docstring diff --git a/swh/lister/core/lister_base.py b/swh/lister/core/lister_base.py deleted file mode 100644 --- a/swh/lister/core/lister_base.py +++ /dev/null @@ -1,508 +0,0 @@ -# Copyright (C) 2015-2020 the Software Heritage developers -# License: GNU General Public License version 3, or any later version -# See top-level LICENSE file for more information - -import abc -import datetime -import gzip -import json -import logging -import os -import re -import time -from typing import Any, Dict, List, Optional, Type, Union - -from requests import Response -from sqlalchemy import create_engine, func -from sqlalchemy.orm import sessionmaker - -from swh.core import config -from swh.core.utils import grouper -from swh.scheduler import get_scheduler, utils - -from .abstractattribute import AbstractAttribute - -logger = logging.getLogger(__name__) - - -def utcnow(): - return datetime.datetime.now(tz=datetime.timezone.utc) - - -class FetchError(RuntimeError): - def __init__(self, response): - self.response = response - - def __str__(self): - return repr(self.response) - - -DEFAULT_CONFIG = { - "scheduler": {"cls": "memory"}, - "lister": {"cls": "local", "args": {"db": "postgresql:///lister",},}, - "credentials": {}, - "cache_responses": False, -} - - -class ListerBase(abc.ABC): - """Lister core base class. - Generally a source code hosting service provides an API endpoint - for listing the set of stored repositories. A Lister is the discovery - service responsible for finding this list, all at once or sequentially - by parts, and queueing local tasks to fetch and ingest the referenced - repositories. - - The core method in this class is ingest_data. Any subclasses should be - calling this method one or more times to fetch and ingest data from API - endpoints. See swh.lister.core.lister_base.IndexingLister for - example usage. - - This class cannot be instantiated. Any instantiable Lister descending - from ListerBase must provide at least the required overrides. - (see member docstrings for details): - - Required Overrides: - MODEL - def transport_request - def transport_response_to_string - def transport_response_simplified - def transport_quota_check - - Optional Overrides: - def filter_before_inject - def is_within_bounds - """ - - MODEL = AbstractAttribute( - "Subclass type (not instance) of swh.lister.core.models.ModelBase " - "customized for a specific service." - ) # type: Union[AbstractAttribute, Type[Any]] - LISTER_NAME = AbstractAttribute( - "Lister's name" - ) # type: Union[AbstractAttribute, str] - - def transport_request(self, identifier): - """Given a target endpoint identifier to query, try once to request it. - - Implementation of this method determines the network request protocol. - - Args: - identifier (string): unique identifier for an endpoint query. - e.g. 
If the service indexes lists of repositories by date and - time of creation, this might be that as a formatted string. Or - it might be an integer UID. Or it might be nothing. - It depends on what the service needs. - Returns: - the entire request response - Raises: - Will catch internal transport-dependent connection exceptions and - raise swh.lister.core.lister_base.FetchError instead. Other - non-connection exceptions should propagate unchanged. - """ - pass - - def transport_response_to_string(self, response): - """Convert the server response into a formatted string for logging. - - Implementation of this method depends on the shape of the network - response object returned by the transport_request method. - - Args: - response: the server response - Returns: - a pretty string of the response - """ - pass - - def transport_response_simplified(self, response): - """Convert the server response into list of a dict for each repo in the - response, mapping columns in the lister's MODEL class to repo data. - - Implementation of this method depends on the server API spec and the - shape of the network response object returned by the transport_request - method. - - Args: - response: response object from the server. - Returns: - list of repo MODEL dicts - ( eg. [{'uid': r['id'], etc.} for r in response.json()] ) - """ - pass - - def transport_quota_check(self, response): - """Check server response to see if we're hitting request rate limits. - - Implementation of this method depends on the server communication - protocol and API spec and the shape of the network response object - returned by the transport_request method. - - Args: - response (session response): complete API query response - Returns: - 1) must retry request? True/False - 2) seconds to delay if True - """ - pass - - def filter_before_inject(self, models_list: List[Dict]) -> List[Dict]: - """Filter models_list entries prior to injection in the db. - This is ran directly after `transport_response_simplified`. - - Default implementation is to have no filtering. - - Args: - models_list: list of dicts returned by - transport_response_simplified. - Returns: - models_list with entries changed according to custom logic. - - """ - return models_list - - def do_additional_checks(self, models_list: List[Dict]) -> List[Dict]: - """Execute some additional checks on the model list (after the - filtering). - - Default implementation is to run no check at all and to return - the input as is. - - Args: - models_list: list of dicts returned by - transport_response_simplified. - - Returns: - models_list with entries if checks ok, False otherwise - - """ - return models_list - - def is_within_bounds( - self, inner: int, lower: Optional[int] = None, upper: Optional[int] = None - ) -> bool: - """See if a sortable value is inside the range [lower,upper]. - - MAY BE OVERRIDDEN, for example if the server indexable* key is - technically sortable but not automatically so. 
- - * - ( see: swh.lister.core.indexing_lister.IndexingLister ) - - Args: - inner (sortable type): the value being checked - lower (sortable type): optional lower bound - upper (sortable type): optional upper bound - Returns: - whether inner is confined by the optional lower and upper bounds - """ - try: - if lower is None and upper is None: - return True - elif lower is None: - ret = inner <= upper # type: ignore - elif upper is None: - ret = inner >= lower - else: - ret = lower <= inner <= upper - - self.string_pattern_check(inner, lower, upper) - except Exception as e: - logger.error( - str(e) - + ": %s, %s, %s" - % ( - ("inner=%s%s" % (type(inner), inner)), - ("lower=%s%s" % (type(lower), lower)), - ("upper=%s%s" % (type(upper), upper)), - ) - ) - raise - - return ret - - # You probably don't need to override anything below this line. - - INITIAL_BACKOFF = 10 - MAX_RETRIES = 7 - CONN_SLEEP = 10 - - def __init__(self, override_config=None): - self.backoff = self.INITIAL_BACKOFF - self.config = config.load_from_envvar(DEFAULT_CONFIG) - if self.config["cache_responses"]: - cache_dir = self.config.get( - "cache_dir", f"~/.cache/swh/lister/{self.LISTER_NAME}" - ) - self.config["cache_dir"] = os.path.expanduser(cache_dir) - config.prepare_folders(self.config, "cache_dir") - - if override_config: - self.config.update(override_config) - - logger.debug("%s CONFIG=%s" % (self, self.config)) - self.scheduler = get_scheduler(**self.config["scheduler"]) - self.db_engine = create_engine(self.config["lister"]["args"]["db"]) - self.mk_session = sessionmaker(bind=self.db_engine) - self.db_session = self.mk_session() - - def reset_backoff(self): - """Reset exponential backoff timeout to initial level.""" - self.backoff = self.INITIAL_BACKOFF - - def back_off(self) -> int: - """Get next exponential backoff timeout.""" - ret = self.backoff - self.backoff *= 10 - return ret - - def safely_issue_request(self, identifier: int) -> Optional[Response]: - """Make network request with retries, rate quotas, and response logs. - - Protocol is handled by the implementation of the transport_request - method. - - Args: - identifier: resource identifier - Returns: - server response - """ - retries_left = self.MAX_RETRIES - do_cache = self.config["cache_responses"] - r = None - while retries_left > 0: - try: - r = self.transport_request(identifier) - except FetchError: - # network-level connection error, try again - logger.warning( - "connection error on %s: sleep for %d seconds" - % (identifier, self.CONN_SLEEP) - ) - time.sleep(self.CONN_SLEEP) - retries_left -= 1 - continue - - if do_cache: - self.save_response(r) - - # detect throttling - must_retry, delay = self.transport_quota_check(r) - if must_retry: - logger.warning( - "rate limited on %s: sleep for %f seconds" % (identifier, delay) - ) - time.sleep(delay) - else: # request ok - break - - retries_left -= 1 - - if not retries_left: - logger.warning("giving up on %s: max retries exceeded" % identifier) - - return r - - def db_query_equal(self, key: Any, value: Any): - """Look in the db for a row with key == value - - Args: - key: column key to look at - value: value to look for in that column - Returns: - sqlalchemy.ext.declarative.declarative_base object - with the given key == value - """ - if isinstance(key, str): - key = self.MODEL.__dict__[key] - return self.db_session.query(self.MODEL).filter(key == value).first() - - def winnow_models(self, mlist, key, to_remove): - """Given a list of models, remove any with matching - some member of a list of values. 
- - Args: - mlist (list of model rows): the initial list of models - key (column): the column to filter on - to_remove (list): if anything in mlist has column equal to - one of the values in to_remove, it will be removed from the - result - Returns: - A list of model rows starting from mlist minus any matching rows - """ - if isinstance(key, str): - key = self.MODEL.__dict__[key] - - if to_remove: - return mlist.filter(~key.in_(to_remove)).all() - else: - return mlist.all() - - def db_num_entries(self): - """Return the known number of entries in the lister db""" - return self.db_session.query(func.count("*")).select_from(self.MODEL).scalar() - - def db_inject_repo(self, model_dict): - """Add/update a new repo to the db and mark it last_seen now. - - Args: - model_dict: dictionary mapping model keys to values - - Returns: - new or updated sqlalchemy.ext.declarative.declarative_base - object associated with the injection - - """ - sql_repo = self.db_query_equal("uid", model_dict["uid"]) - - if not sql_repo: - sql_repo = self.MODEL(**model_dict) - self.db_session.add(sql_repo) - else: - for k in model_dict: - setattr(sql_repo, k, model_dict[k]) - sql_repo.last_seen = utcnow() - - return sql_repo - - def task_dict(self, origin_type: str, origin_url: str, **kwargs) -> Dict[str, Any]: - """Return special dict format for the tasks list - - Args: - origin_type (string) - origin_url (string) - Returns: - the same information in a different form - """ - logger.debug("origin-url: %s, type: %s", origin_url, origin_type) - _type = "load-%s" % origin_type - _policy = kwargs.get("policy", "recurring") - priority = kwargs.get("priority") - kw = {"priority": priority} if priority else {} - return utils.create_task_dict(_type, _policy, url=origin_url, **kw) - - def string_pattern_check(self, a, b, c=None): - """When comparing indexable types in is_within_bounds, complex strings - may not be allowed to differ in basic structure. If they do, it - could be a sign of not understanding the data well. For instance, - an ISO 8601 time string cannot be compared against its urlencoded - equivalent, but this is an easy mistake to accidentally make. This - method acts as a friendly sanity check. - - Args: - a (string): inner component of the is_within_bounds method - b (string): lower component of the is_within_bounds method - c (string): upper component of the is_within_bounds method - Returns: - nothing - Raises: - TypeError if strings a, b, and c don't conform to the same basic - pattern. - """ - if isinstance(a, str): - a_pattern = re.sub("[a-zA-Z0-9]", "[a-zA-Z0-9]", re.escape(a)) - if ( - isinstance(b, str) - and (re.match(a_pattern, b) is None) - or isinstance(c, str) - and (re.match(a_pattern, c) is None) - ): - logger.debug(a_pattern) - raise TypeError("incomparable string patterns detected") - - def inject_repo_data_into_db(self, models_list: List[Dict]) -> Dict: - """Inject data into the db. - - Args: - models_list: list of dicts mapping keys from the db model - for each repo to be injected - Returns: - dict of uid:sql_repo pairs - - """ - injected_repos = {} - for m in models_list: - injected_repos[m["uid"]] = self.db_inject_repo(m) - return injected_repos - - def schedule_missing_tasks( - self, models_list: List[Dict], injected_repos: Dict - ) -> None: - """Schedule any newly created db entries that do not have been - scheduled yet. 
- - Args: - models_list: List of dicts mapping keys in the db model - for each repo - injected_repos: Dict of uid:sql_repo pairs that have just - been created - - Returns: - Nothing. (Note that it Modifies injected_repos to set the new - task_id). - - """ - tasks = {} - - def _task_key(m): - return "%s-%s" % (m["type"], json.dumps(m["arguments"], sort_keys=True)) - - for m in models_list: - ir = injected_repos[m["uid"]] - if not ir.task_id: - # Patching the model instance to add the policy/priority task - # scheduling - if "policy" in self.config: - m["policy"] = self.config["policy"] - if "priority" in self.config: - m["priority"] = self.config["priority"] - task_dict = self.task_dict(**m) - task_dict.setdefault("retries_left", 3) - tasks[_task_key(task_dict)] = (ir, m, task_dict) - - gen_tasks = (task_dicts for (_, _, task_dicts) in tasks.values()) - for grouped_tasks in grouper(gen_tasks, n=1000): - new_tasks = self.scheduler.create_tasks(list(grouped_tasks)) - for task in new_tasks: - ir, m, _ = tasks[_task_key(task)] - ir.task_id = task["id"] - - def ingest_data(self, identifier: int, checks: bool = False): - """The core data fetch sequence. Request server endpoint. Simplify and - filter response list of repositories. Inject repo information into - local db. Queue loader tasks for linked repositories. - - Args: - identifier: Resource identifier. - checks (bool): Additional checks required - """ - # Request (partial?) list of repositories info - response = self.safely_issue_request(identifier) - if not response: - return response, [] - models_list = self.transport_response_simplified(response) - models_list = self.filter_before_inject(models_list) - if checks: - models_list = self.do_additional_checks(models_list) - if not models_list: - return response, [] - # inject into local db - injected = self.inject_repo_data_into_db(models_list) - # queue workers - self.schedule_missing_tasks(models_list, injected) - return response, injected - - def save_response(self, response): - """Log the response from a server request to a cache dir. - - Args: - response: full server response - cache_dir: system path for cache dir - Returns: - nothing - """ - datepath = utcnow().isoformat() - - fname = os.path.join(self.config["cache_dir"], datepath + ".gz",) - - with gzip.open(fname, "w") as f: - f.write(bytes(self.transport_response_to_string(response), "UTF-8")) diff --git a/swh/lister/core/lister_transports.py b/swh/lister/core/lister_transports.py deleted file mode 100644 --- a/swh/lister/core/lister_transports.py +++ /dev/null @@ -1,233 +0,0 @@ -# Copyright (C) 2017-2018 the Software Heritage developers -# License: GNU General Public License version 3, or any later version -# See top-level LICENSE file for more information - -import abc -from datetime import datetime -from email.utils import parsedate -import logging -from pprint import pformat -import random -from typing import Any, Dict, List, Optional, Union - -import requests -from requests import Response -import xmltodict - -from swh.lister import USER_AGENT_TEMPLATE, __version__ - -from .abstractattribute import AbstractAttribute -from .lister_base import FetchError - -logger = logging.getLogger(__name__) - - -class ListerHttpTransport(abc.ABC): - """Use the Requests library for making Lister endpoint requests. - - To be used in conjunction with ListerBase or a subclass of it. 
- """ - - DEFAULT_URL = None # type: Optional[str] - PATH_TEMPLATE = AbstractAttribute( - "string containing a python string format pattern that produces" - " the API endpoint path for listing stored repositories when given" - ' an index, e.g., "/repositories?after=%s". To be implemented in' - " the API-specific class inheriting this." - ) # type: Union[AbstractAttribute, Optional[str]] - - EXPECTED_STATUS_CODES = (200, 429, 403, 404) - - def request_headers(self) -> Dict[str, Any]: - """Returns dictionary of any request headers needed by the server. - - MAY BE OVERRIDDEN if request headers are needed. - """ - return {"User-Agent": USER_AGENT_TEMPLATE % self.lister_version} - - def request_instance_credentials(self) -> List[Dict[str, Any]]: - """Returns dictionary of any credentials configuration needed by the - forge instance to list. - - The 'credentials' configuration is expected to be a dict of multiple - levels. The first level is the lister's name, the second is the - lister's instance name, which value is expected to be a list of - credential structures (typically a couple username/password). - - For example:: - - credentials: - github: # github lister - github: # has only one instance (so far) - - username: some - password: somekey - - username: one - password: onekey - - ... - gitlab: # gitlab lister - riseup: # has many instances - - username: someone - password: ... - - ... - gitlab: - - username: someone - password: ... - - ... - - Returns: - list of credential dicts for the current lister. - - """ - all_creds = self.config.get("credentials") # type: ignore - if not all_creds: - return [] - lister_creds = all_creds.get(self.LISTER_NAME, {}) # type: ignore - creds = lister_creds.get(self.instance, []) # type: ignore - return creds - - def request_uri(self, identifier: str) -> str: - """Get the full request URI given the transport_request identifier. - - MAY BE OVERRIDDEN if something more complex than the PATH_TEMPLATE is - required. - """ - path = self.PATH_TEMPLATE % identifier # type: ignore - return self.url + path - - def request_params(self, identifier: str) -> Dict[str, Any]: - """Get the full parameters passed to requests given the - transport_request identifier. - - This uses credentials if any are provided (see - request_instance_credentials). - - MAY BE OVERRIDDEN if something more complex than the request headers - is needed. - - """ - params = {} - params["headers"] = self.request_headers() or {} - creds = self.request_instance_credentials() - if not creds: - return params - auth = random.choice(creds) if creds else None - if auth: - params["auth"] = ( - auth["username"], # type: ignore - auth["password"], - ) - return params - - def transport_quota_check(self, response): - """Implements ListerBase.transport_quota_check with standard 429 - code check for HTTP with Requests library. - - MAY BE OVERRIDDEN if the server notifies about rate limits in a - non-standard way that doesn't use HTTP 429 and the Retry-After - response header. 
( https://tools.ietf.org/html/rfc6585#section-4 ) - - """ - if response.status_code == 429: # HTTP too many requests - retry_after = response.headers.get("Retry-After", self.back_off()) - try: - # might be seconds - return True, float(retry_after) - except Exception: - # might be http-date - at_date = datetime(*parsedate(retry_after)[:6]) - from_now = (at_date - datetime.today()).total_seconds() + 5 - return True, max(0, from_now) - else: # response ok - self.reset_backoff() - return False, 0 - - def __init__(self, url=None): - if not url: - url = self.config.get("url") - if not url: - url = self.DEFAULT_URL - if not url: - raise NameError("HTTP Lister Transport requires an url.") - self.url = url # eg. 'https://api.github.com' - self.session = requests.Session() - self.lister_version = __version__ - - def _transport_action(self, identifier: str, method: str = "get") -> Response: - """Permit to ask information to the api prior to actually executing - query. - - """ - path = self.request_uri(identifier) - params = self.request_params(identifier) - - logger.debug("path: %s", path) - logger.debug("params: %s", params) - logger.debug("method: %s", method) - try: - if method == "head": - response = self.session.head(path, **params) - else: - response = self.session.get(path, **params) - except requests.exceptions.ConnectionError as e: - logger.warning("Failed to fetch %s: %s", path, e) - raise FetchError(e) - else: - if response.status_code not in self.EXPECTED_STATUS_CODES: - raise FetchError(response) - return response - - def transport_head(self, identifier: str) -> Response: - """Retrieve head information on api. - - """ - return self._transport_action(identifier, method="head") - - def transport_request(self, identifier: str) -> Response: - """Implements ListerBase.transport_request for HTTP using Requests. - - Retrieve get information on api. - - """ - return self._transport_action(identifier) - - def transport_response_to_string(self, response: Response) -> str: - """Implements ListerBase.transport_response_to_string for HTTP given - Requests responses. - """ - s = pformat(response.request.path_url) - s += "\n#\n" + pformat(response.request.headers) - s += "\n#\n" + pformat(response.status_code) - s += "\n#\n" + pformat(response.headers) - s += "\n#\n" - try: # json? - s += pformat(response.json()) - except Exception: # not json - try: # xml? - s += pformat(xmltodict.parse(response.text)) - except Exception: # not xml - s += pformat(response.text) - return s - - -class ListerOnePageApiTransport(ListerHttpTransport): - """Leverage requests library to retrieve basic html page and parse - result. - - To be used in conjunction with ListerBase or a subclass of it. - - """ - - PAGE = AbstractAttribute( - "URL of the API's unique page to retrieve and parse " "for information" - ) # type: Union[AbstractAttribute, str] - PATH_TEMPLATE = None # we do not use it - - def __init__(self, url=None): - self.session = requests.Session() - self.lister_version = __version__ - - def request_uri(self, _): - """Get the full request URI given the transport_request identifier. 
- - """ - return self.PAGE diff --git a/swh/lister/core/models.py b/swh/lister/core/models.py deleted file mode 100644 --- a/swh/lister/core/models.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (C) 2015-2019 the Software Heritage developers -# License: GNU General Public License version 3, or any later version -# See top-level LICENSE file for more information - -import abc -from datetime import datetime -import logging -from typing import Type, Union - -from sqlalchemy import Column, DateTime, Integer, String -from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base - -from .abstractattribute import AbstractAttribute - -SQLBase = declarative_base() - - -logger = logging.getLogger(__name__) - - -class ABCSQLMeta(abc.ABCMeta, DeclarativeMeta): - pass - - -class ModelBase(SQLBase, metaclass=ABCSQLMeta): - """a common repository""" - - __abstract__ = True - __tablename__ = AbstractAttribute # type: Union[Type[AbstractAttribute], str] - - uid = AbstractAttribute( - "Column(, primary_key=True)" - ) # type: Union[AbstractAttribute, Column] - - name = Column(String, index=True) - full_name = Column(String, index=True) - html_url = Column(String) - origin_url = Column(String) - origin_type = Column(String) - - last_seen = Column(DateTime, nullable=False) - - task_id = Column(Integer) - - def __init__(self, **kw): - kw["last_seen"] = datetime.now() - super().__init__(**kw) - - -class IndexingModelBase(ModelBase, metaclass=ABCSQLMeta): - __abstract__ = True - __tablename__ = AbstractAttribute # type: Union[Type[AbstractAttribute], str] - - # The value used for sorting, segmenting, or api query paging, - # because uids aren't always sequential. - indexable = AbstractAttribute( - "Column(, index=True)" - ) # type: Union[AbstractAttribute, Column] - - -def initialize(db_engine, drop_tables=False, **kwargs): - """Default database initialization function for a lister. - - Typically called from the lister's initialization hook. - - Args: - models (list): list of SQLAlchemy tables/models to drop/create. - db_engine (): the SQLAlchemy DB engine. - drop_tables (bool): if True, tables will be dropped before - (re)creating them. - """ - if drop_tables: - logger.info("Dropping tables") - SQLBase.metadata.drop_all(db_engine, checkfirst=True) - - logger.info("Creating tables") - SQLBase.metadata.create_all(db_engine, checkfirst=True) diff --git a/swh/lister/core/simple_lister.py b/swh/lister/core/simple_lister.py deleted file mode 100644 --- a/swh/lister/core/simple_lister.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (C) 2018-2019 The Software Heritage developers -# See the AUTHORS file at the top-level directory of this distribution -# License: GNU General Public License version 3, or any later version -# See top-level LICENSE file for more information - -import logging -from typing import Any, List - -from swh.core import utils - -from .lister_base import ListerBase - -logger = logging.getLogger(__name__) - - -class SimpleLister(ListerBase): - """Lister* intermediate class for any service that follows the simple, - 'list in oneshot information' pattern. - - - Client sends a request to list repositories in oneshot - - - Client receives structured (json/xml/etc) response with - information and stores those in db - - """ - - flush_packet_db = 2 - """Number of iterations in-between write flushes of lister repositories to - db (see fn:`ingest_data`). - """ - - def list_packages(self, response: Any) -> List[Any]: - """Listing packages method. 
- - """ - pass - - def ingest_data(self, identifier, checks=False): - """Rework the base ingest_data. - Request server endpoint which gives all in one go. - - Simplify and filter response list of repositories. Inject - repo information into local db. Queue loader tasks for - linked repositories. - - Args: - identifier: Resource identifier (unused) - checks (bool): Additional checks required (unused) - - """ - response = self.safely_issue_request(identifier) - response = self.list_packages(response) - if not response: - return response, [] - models_list = self.transport_response_simplified(response) - models_list = self.filter_before_inject(models_list) - all_injected = [] - for i, models in enumerate(utils.grouper(models_list, n=100), start=1): - models = list(models) - logging.debug("models: %s" % len(models)) - # inject into local db - injected = self.inject_repo_data_into_db(models) - # queue workers - self.schedule_missing_tasks(models, injected) - all_injected.append(injected) - if (i % self.flush_packet_db) == 0: - logger.debug("Flushing updates at index %s", i) - self.db_session.commit() - self.db_session = self.mk_session() - - return response, all_injected - - def transport_response_simplified(self, response): - """Transform response to list for model manipulation - - """ - return [self.get_model_from_repo(repo_name) for repo_name in response] - - def run(self): - """Query the server which answers in one query. Stores the - information, dropping actual redundant information we - already have. - - Returns: - nothing - - """ - dump_not_used_identifier = 0 - response, injected_repos = self.ingest_data(dump_not_used_identifier) - if not response and not injected_repos: - logging.info("No response from api server, stopping") - status = "uneventful" - else: - status = "eventful" - - return {"status": status} diff --git a/swh/lister/core/tests/__init__.py b/swh/lister/core/tests/__init__.py deleted file mode 100644 diff --git a/swh/lister/core/tests/test_abstractattribute.py b/swh/lister/core/tests/test_abstractattribute.py deleted file mode 100644 --- a/swh/lister/core/tests/test_abstractattribute.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (C) 2017 the Software Heritage developers -# License: GNU General Public License version 3, or any later version -# See top-level LICENSE file for more information - -import abc -from typing import Any -import unittest - -from swh.lister.core.abstractattribute import AbstractAttribute - - -class BaseClass(abc.ABC): - v1 = AbstractAttribute # type: Any - v2 = AbstractAttribute() # type: Any - v3 = AbstractAttribute("changed docstring") # type: Any - v4 = "qux" - - -class BadSubclass1(BaseClass): - pass - - -class BadSubclass2(BaseClass): - v1 = "foo" - v2 = "bar" - - -class BadSubclass3(BaseClass): - v2 = "bar" - v3 = "baz" - - -class GoodSubclass(BaseClass): - v1 = "foo" - v2 = "bar" - v3 = "baz" - - -class TestAbstractAttributes(unittest.TestCase): - def test_aa(self): - with self.assertRaises(TypeError): - BaseClass() - - with self.assertRaises(TypeError): - BadSubclass1() - - with self.assertRaises(TypeError): - BadSubclass2() - - with self.assertRaises(TypeError): - BadSubclass3() - - self.assertIsInstance(GoodSubclass(), GoodSubclass) - gsc = GoodSubclass() - - self.assertEqual(gsc.v1, "foo") - self.assertEqual(gsc.v2, "bar") - self.assertEqual(gsc.v3, "baz") - self.assertEqual(gsc.v4, "qux") - - def test_aa_docstrings(self): - self.assertEqual(BaseClass.v1.__doc__, AbstractAttribute.__doc__) - self.assertEqual(BaseClass.v2.__doc__, 
AbstractAttribute.__doc__) - self.assertEqual(BaseClass.v3.__doc__, "AbstractAttribute: changed docstring") diff --git a/swh/lister/core/tests/test_lister.py b/swh/lister/core/tests/test_lister.py deleted file mode 100644 --- a/swh/lister/core/tests/test_lister.py +++ /dev/null @@ -1,453 +0,0 @@ -# Copyright (C) 2019 the Software Heritage developers -# License: GNU General Public License version 3, or any later version -# See top-level LICENSE file for more information - -import abc -import datetime -import time -from typing import Any, Callable, Optional, Pattern, Type, Union -from unittest import TestCase -from unittest.mock import Mock, patch - -import requests_mock -from sqlalchemy import create_engine - -import swh.lister -from swh.lister.core.abstractattribute import AbstractAttribute -from swh.lister.tests.test_utils import init_db - - -def noop(*args, **kwargs): - pass - - -def test_version_generation(): - assert ( - swh.lister.__version__ != "devel" - ), "Make sure swh.lister is installed (e.g. pip install -e .)" - - -class HttpListerTesterBase(abc.ABC): - """Testing base class for listers. - This contains methods for both :class:`HttpSimpleListerTester` and - :class:`HttpListerTester`. - - See :class:`swh.lister.gitlab.tests.test_lister` for an example of how - to customize for a specific listing service. - - """ - - Lister = AbstractAttribute( - "Lister class to test" - ) # type: Union[AbstractAttribute, Type[Any]] - lister_subdir = AbstractAttribute( - "bitbucket, github, etc." - ) # type: Union[AbstractAttribute, str] - good_api_response_file = AbstractAttribute( - "Example good response body" - ) # type: Union[AbstractAttribute, str] - LISTER_NAME = "fake-lister" - - # May need to override this if the headers are used for something - def response_headers(self, request): - return {} - - # May need to override this if the server uses non-standard rate limiting - # method. - # Please keep the requested retry delay reasonably low. - def mock_rate_quota(self, n, request, context): - self.rate_limit += 1 - context.status_code = 429 - context.headers["Retry-After"] = "1" - return '{"error":"dummy"}' - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.rate_limit = 1 - self.response = None - self.fl = None - self.helper = None - self.scheduler_tasks = [] - if self.__class__ != HttpListerTesterBase: - self.run = TestCase.run.__get__(self, self.__class__) - else: - self.run = noop - - def mock_limit_n_response(self, n, request, context): - self.fl.reset_backoff() - if self.rate_limit <= n: - return self.mock_rate_quota(n, request, context) - else: - return self.mock_response(request, context) - - def mock_limit_twice_response(self, request, context): - return self.mock_limit_n_response(2, request, context) - - def get_api_response(self, identifier): - fl = self.get_fl() - if self.response is None: - self.response = fl.safely_issue_request(identifier) - return self.response - - def get_fl(self, override_config=None): - """Retrieve an instance of fake lister (fl). 
- - """ - if override_config or self.fl is None: - self.fl = self.Lister( - url="https://fakeurl", override_config=override_config - ) - self.fl.INITIAL_BACKOFF = 1 - - self.fl.reset_backoff() - self.scheduler_tasks = [] - return self.fl - - def disable_scheduler(self, fl): - fl.schedule_missing_tasks = Mock(return_value=None) - - def mock_scheduler(self, fl): - def _create_tasks(tasks): - task_id = 0 - current_nb_tasks = len(self.scheduler_tasks) - if current_nb_tasks > 0: - task_id = self.scheduler_tasks[-1]["id"] + 1 - for task in tasks: - scheduler_task = dict(task) - scheduler_task.update( - { - "status": "next_run_not_scheduled", - "retries_left": 0, - "priority": None, - "id": task_id, - "current_interval": datetime.timedelta(days=64), - } - ) - self.scheduler_tasks.append(scheduler_task) - task_id = task_id + 1 - return self.scheduler_tasks[current_nb_tasks:] - - def _disable_tasks(task_ids): - for task_id in task_ids: - self.scheduler_tasks[task_id]["status"] = "disabled" - - fl.scheduler.create_tasks = Mock(wraps=_create_tasks) - fl.scheduler.disable_tasks = Mock(wraps=_disable_tasks) - - def disable_db(self, fl): - fl.winnow_models = Mock(return_value=[]) - fl.db_inject_repo = Mock(return_value=fl.MODEL()) - fl.disable_deleted_repo_tasks = Mock(return_value=None) - - def init_db(self, db, model): - engine = create_engine(db.url()) - model.metadata.create_all(engine) - - @requests_mock.Mocker() - def test_is_within_bounds(self, http_mocker): - fl = self.get_fl() - self.assertFalse(fl.is_within_bounds(1, 2, 3)) - self.assertTrue(fl.is_within_bounds(2, 1, 3)) - self.assertTrue(fl.is_within_bounds(1, 1, 1)) - self.assertTrue(fl.is_within_bounds(1, None, None)) - self.assertTrue(fl.is_within_bounds(1, None, 2)) - self.assertTrue(fl.is_within_bounds(1, 0, None)) - self.assertTrue(fl.is_within_bounds("b", "a", "c")) - self.assertFalse(fl.is_within_bounds("a", "b", "c")) - self.assertTrue(fl.is_within_bounds("a", None, "c")) - self.assertTrue(fl.is_within_bounds("a", None, None)) - self.assertTrue(fl.is_within_bounds("b", "a", None)) - self.assertFalse(fl.is_within_bounds("a", "b", None)) - self.assertTrue(fl.is_within_bounds("aa:02", "aa:01", "aa:03")) - self.assertFalse(fl.is_within_bounds("aa:12", None, "aa:03")) - with self.assertRaises(TypeError): - fl.is_within_bounds(1.0, "b", None) - with self.assertRaises(TypeError): - fl.is_within_bounds("A:B", "A::B", None) - - -class HttpListerTester(HttpListerTesterBase, abc.ABC): - """Base testing class for subclass of - - :class:`swh.lister.core.indexing_lister.IndexingHttpLister` - - See :class:`swh.lister.github.tests.test_gh_lister` for an example of how - to customize for a specific listing service. - - """ - - last_index = AbstractAttribute( - "Last index " "in good_api_response" - ) # type: Union[AbstractAttribute, int] - first_index = AbstractAttribute( - "First index in " " good_api_response" - ) # type: Union[AbstractAttribute, Optional[int]] - bad_api_response_file = AbstractAttribute( - "Example bad response body" - ) # type: Union[AbstractAttribute, str] - entries_per_page = AbstractAttribute( - "Number of results in " "good response" - ) # type: Union[AbstractAttribute, int] - test_re = AbstractAttribute( - "Compiled regex matching the server url. Must capture the " "index value." 
- ) # type: Union[AbstractAttribute, Pattern] - convert_type = str # type: Callable[..., Any] - """static method used to convert the "request_index" to its right type (for - indexing listers for example, this is in accordance with the model's - "indexable" column). - - """ - - def mock_response(self, request, context): - self.fl.reset_backoff() - self.rate_limit = 1 - context.status_code = 200 - custom_headers = self.response_headers(request) - context.headers.update(custom_headers) - req_index = self.request_index(request) - - if req_index == self.first_index: - response_file = self.good_api_response_file - else: - response_file = self.bad_api_response_file - - with open( - "swh/lister/%s/tests/%s" % (self.lister_subdir, response_file), - "r", - encoding="utf-8", - ) as r: - return r.read() - - def request_index(self, request): - m = self.test_re.search(request.path_url) - if m and (len(m.groups()) > 0): - return self.convert_type(m.group(1)) - - def create_fl_with_db(self, http_mocker): - http_mocker.get(self.test_re, text=self.mock_response) - db = init_db() - - fl = self.get_fl( - override_config={"lister": {"cls": "local", "args": {"db": db.url()}}} - ) - fl.db = db - self.init_db(db, fl.MODEL) - - self.mock_scheduler(fl) - return fl - - @requests_mock.Mocker() - def test_fetch_no_bounds_yesdb(self, http_mocker): - fl = self.create_fl_with_db(http_mocker) - - fl.run() - - self.assertEqual(fl.db_last_index(), self.last_index) - ingested_repos = list(fl.db_query_range(self.first_index, self.last_index)) - self.assertEqual(len(ingested_repos), self.entries_per_page) - - @requests_mock.Mocker() - def test_fetch_multiple_pages_yesdb(self, http_mocker): - - fl = self.create_fl_with_db(http_mocker) - fl.run(min_bound=self.first_index) - - self.assertEqual(fl.db_last_index(), self.last_index) - - partitions = fl.db_partition_indices(5) - self.assertGreater(len(partitions), 0) - for k in partitions: - self.assertLessEqual(len(k), 5) - self.assertGreater(len(k), 0) - - @requests_mock.Mocker() - def test_fetch_none_nodb(self, http_mocker): - http_mocker.get(self.test_re, text=self.mock_response) - fl = self.get_fl() - - self.disable_scheduler(fl) - self.disable_db(fl) - - fl.run(min_bound=1, max_bound=1) # stores no results - # FIXME: Determine what this method tries to test and add checks to - # actually test - - @requests_mock.Mocker() - def test_fetch_one_nodb(self, http_mocker): - http_mocker.get(self.test_re, text=self.mock_response) - fl = self.get_fl() - - self.disable_scheduler(fl) - self.disable_db(fl) - - fl.run(min_bound=self.first_index, max_bound=self.first_index) - # FIXME: Determine what this method tries to test and add checks to - # actually test - - @requests_mock.Mocker() - def test_fetch_multiple_pages_nodb(self, http_mocker): - http_mocker.get(self.test_re, text=self.mock_response) - fl = self.get_fl() - - self.disable_scheduler(fl) - self.disable_db(fl) - - fl.run(min_bound=self.first_index) - # FIXME: Determine what this method tries to test and add checks to - # actually test - - @requests_mock.Mocker() - def test_repos_list(self, http_mocker): - """Test the number of repos listed by the lister - - """ - http_mocker.get(self.test_re, text=self.mock_response) - li = self.get_fl().transport_response_simplified( - self.get_api_response(self.first_index) - ) - self.assertIsInstance(li, list) - self.assertEqual(len(li), self.entries_per_page) - - @requests_mock.Mocker() - def test_model_map(self, http_mocker): - """Check if all the keys of model are present in the model 
created by - the `transport_response_simplified` - - """ - http_mocker.get(self.test_re, text=self.mock_response) - fl = self.get_fl() - li = fl.transport_response_simplified(self.get_api_response(self.first_index)) - di = li[0] - self.assertIsInstance(di, dict) - pubs = [k for k in vars(fl.MODEL).keys() if not k.startswith("_")] - for k in pubs: - if k not in ["last_seen", "task_id", "id"]: - self.assertIn(k, di) - - @requests_mock.Mocker() - def test_api_request(self, http_mocker): - """Test API request for rate limit handling - - """ - http_mocker.get(self.test_re, text=self.mock_limit_twice_response) - with patch.object(time, "sleep", wraps=time.sleep) as sleepmock: - self.get_api_response(self.first_index) - self.assertEqual(sleepmock.call_count, 2) - - @requests_mock.Mocker() - def test_request_headers(self, http_mocker): - fl = self.create_fl_with_db(http_mocker) - fl.run() - self.assertNotEqual(len(http_mocker.request_history), 0) - for request in http_mocker.request_history: - assert "User-Agent" in request.headers - user_agent = request.headers["User-Agent"] - assert "Software Heritage Lister" in user_agent - assert swh.lister.__version__ in user_agent - - def scheduled_tasks_test( - self, next_api_response_file, next_last_index, http_mocker - ): - """Check that no loading tasks get disabled when processing a new - page of repositories returned by a forge API - """ - fl = self.create_fl_with_db(http_mocker) - - # process first page of repositories listing - fl.run() - - # process second page of repositories listing - prev_last_index = self.last_index - self.first_index = self.last_index - self.last_index = next_last_index - self.good_api_response_file = next_api_response_file - fl.run(min_bound=prev_last_index) - - # check expected number of ingested repos and loading tasks - ingested_repos = list(fl.db_query_range(0, self.last_index)) - self.assertEqual(len(ingested_repos), len(self.scheduler_tasks)) - self.assertEqual(len(ingested_repos), 2 * self.entries_per_page) - - # check tasks are not disabled - for task in self.scheduler_tasks: - self.assertTrue(task["status"] != "disabled") - - -class HttpSimpleListerTester(HttpListerTesterBase, abc.ABC): - """Base testing class for subclass of - :class:`swh.lister.core.simple)_lister.SimpleLister` - - See :class:`swh.lister.pypi.tests.test_lister` for an example of how - to customize for a specific listing service. - - """ - - entries = AbstractAttribute( - "Number of results " "in good response" - ) # type: Union[AbstractAttribute, int] - PAGE = AbstractAttribute( - "URL of the server api's unique page to retrieve and " "parse for information" - ) # type: Union[AbstractAttribute, str] - - def get_fl(self, override_config=None): - """Retrieve an instance of fake lister (fl). 
- - """ - if override_config or self.fl is None: - self.fl = self.Lister(override_config=override_config) - self.fl.INITIAL_BACKOFF = 1 - - self.fl.reset_backoff() - return self.fl - - def mock_response(self, request, context): - self.fl.reset_backoff() - self.rate_limit = 1 - context.status_code = 200 - custom_headers = self.response_headers(request) - context.headers.update(custom_headers) - response_file = self.good_api_response_file - - with open( - "swh/lister/%s/tests/%s" % (self.lister_subdir, response_file), - "r", - encoding="utf-8", - ) as r: - return r.read() - - @requests_mock.Mocker() - def test_api_request(self, http_mocker): - """Test API request for rate limit handling - - """ - http_mocker.get(self.PAGE, text=self.mock_limit_twice_response) - with patch.object(time, "sleep", wraps=time.sleep) as sleepmock: - self.get_api_response(0) - self.assertEqual(sleepmock.call_count, 2) - - @requests_mock.Mocker() - def test_model_map(self, http_mocker): - """Check if all the keys of model are present in the model created by - the `transport_response_simplified` - - """ - http_mocker.get(self.PAGE, text=self.mock_response) - fl = self.get_fl() - li = fl.list_packages(self.get_api_response(0)) - li = fl.transport_response_simplified(li) - di = li[0] - self.assertIsInstance(di, dict) - pubs = [k for k in vars(fl.MODEL).keys() if not k.startswith("_")] - for k in pubs: - if k not in ["last_seen", "task_id", "id"]: - self.assertIn(k, di) - - @requests_mock.Mocker() - def test_repos_list(self, http_mocker): - """Test the number of packages listed by the lister - - """ - http_mocker.get(self.PAGE, text=self.mock_response) - li = self.get_fl().list_packages(self.get_api_response(0)) - self.assertIsInstance(li, list) - self.assertEqual(len(li), self.entries) diff --git a/swh/lister/core/tests/test_model.py b/swh/lister/core/tests/test_model.py deleted file mode 100644 --- a/swh/lister/core/tests/test_model.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (C) 2017 the Software Heritage developers -# License: GNU General Public License version 3, or any later version -# See top-level LICENSE file for more information - -import unittest - -from sqlalchemy import Column, Integer - -from swh.lister.core.models import IndexingModelBase, ModelBase - - -class BadSubclass1(ModelBase): - __abstract__ = True - pass - - -class BadSubclass2(ModelBase): - __abstract__ = True - __tablename__ = "foo" - - -class BadSubclass3(BadSubclass2): - __abstract__ = True - pass - - -class GoodSubclass(BadSubclass2): - uid = Column(Integer, primary_key=True) - indexable = Column(Integer, index=True) - - -class IndexingBadSubclass(IndexingModelBase): - __abstract__ = True - pass - - -class IndexingBadSubclass2(IndexingModelBase): - __abstract__ = True - __tablename__ = "foo" - - -class IndexingBadSubclass3(IndexingBadSubclass2): - __abstract__ = True - pass - - -class IndexingGoodSubclass(IndexingModelBase): - uid = Column(Integer, primary_key=True) - indexable = Column(Integer, index=True) - __tablename__ = "bar" - - -class TestModel(unittest.TestCase): - def test_model_instancing(self): - with self.assertRaises(TypeError): - ModelBase() - - with self.assertRaises(TypeError): - BadSubclass1() - - with self.assertRaises(TypeError): - BadSubclass2() - - with self.assertRaises(TypeError): - BadSubclass3() - - self.assertIsInstance(GoodSubclass(), GoodSubclass) - gsc = GoodSubclass(uid="uid") - - self.assertEqual(gsc.__tablename__, "foo") - self.assertEqual(gsc.uid, "uid") - - def test_indexing_model_instancing(self): - with 
self.assertRaises(TypeError): - IndexingModelBase() - - with self.assertRaises(TypeError): - IndexingBadSubclass() - - with self.assertRaises(TypeError): - IndexingBadSubclass2() - - with self.assertRaises(TypeError): - IndexingBadSubclass3() - - self.assertIsInstance(IndexingGoodSubclass(), IndexingGoodSubclass) - gsc = IndexingGoodSubclass(uid="uid", indexable="indexable") - - self.assertEqual(gsc.__tablename__, "bar") - self.assertEqual(gsc.uid, "uid") - self.assertEqual(gsc.indexable, "indexable") diff --git a/swh/lister/pytest_plugin.py b/swh/lister/pytest_plugin.py deleted file mode 100644 --- a/swh/lister/pytest_plugin.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright (C) 2019-2020 The Software Heritage developers -# See the AUTHORS file at the top-level directory of this distribution -# License: GNU General Public License version 3, or any later version -# See top-level LICENSE file for more information - -import logging -import os - -import pytest -from sqlalchemy import create_engine -import yaml - -from swh.lister import SUPPORTED_LISTERS, get_lister -from swh.lister.core.models import initialize - -logger = logging.getLogger(__name__) - - -@pytest.fixture -def lister_db_url(postgresql): - db_params = postgresql.get_dsn_parameters() - db_url = "postgresql://{user}@{host}:{port}/{dbname}".format(**db_params) - logger.debug("lister db_url: %s", db_url) - return db_url - - -@pytest.fixture -def lister_under_test(): - """Fixture to determine which lister to test""" - return "core" - - -@pytest.fixture -def swh_lister_config(lister_db_url, swh_scheduler_config): - return { - "scheduler": {"cls": "local", **swh_scheduler_config}, - "lister": {"cls": "local", "args": {"db": lister_db_url},}, - "credentials": {}, - "cache_responses": False, - } - - -@pytest.fixture(autouse=True) -def swh_config(swh_lister_config, monkeypatch, tmp_path): - conf_path = os.path.join(str(tmp_path), "lister.yml") - with open(conf_path, "w") as f: - f.write(yaml.dump(swh_lister_config)) - monkeypatch.setenv("SWH_CONFIG_FILENAME", conf_path) - return conf_path - - -@pytest.fixture -def engine(lister_db_url): - engine = create_engine(lister_db_url) - initialize(engine, drop_tables=True) - return engine - - -@pytest.fixture -def swh_lister(engine, lister_db_url, lister_under_test, swh_config): - assert lister_under_test in SUPPORTED_LISTERS - return get_lister(lister_under_test, db_url=lister_db_url) diff --git a/swh/lister/tests/test_cli.py b/swh/lister/tests/test_cli.py --- a/swh/lister/tests/test_cli.py +++ b/swh/lister/tests/test_cli.py @@ -1,4 +1,4 @@ -# Copyright (C) 2019-2020 The Software Heritage developers +# Copyright (C) 2019-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -7,8 +7,6 @@ from swh.lister.cli import SUPPORTED_LISTERS, get_lister -from .test_utils import init_db - lister_args = { "cgit": {"url": "https://git.eclipse.org/c/",}, "phabricator": { @@ -33,13 +31,11 @@ """Instantiating a supported lister should be ok """ - db_url = init_db().url() # Drop launchpad lister from the lister to check, its test setup is more involved # than the other listers and it's not currently done here for lister_name in SUPPORTED_LISTERS: lst = get_lister( lister_name, - db_url, scheduler={"cls": "local", **swh_scheduler_config}, **lister_args.get(lister_name, {}), ) diff --git a/swh/lister/tests/test_utils.py 
b/swh/lister/tests/test_utils.py --- a/swh/lister/tests/test_utils.py +++ b/swh/lister/tests/test_utils.py @@ -6,7 +6,6 @@ import requests from requests.status_codes import codes from tenacity.wait import wait_fixed -from testing.postgresql import Postgresql from swh.lister.utils import ( MAX_NUMBER_ATTEMPTS, @@ -37,18 +36,6 @@ next(split_range(total_pages, nb_pages)) -def init_db(): - """Factorize the db_url instantiation - - Returns: - db object to ease db manipulation - - """ - initdb_args = Postgresql.DEFAULT_SETTINGS["initdb_args"] - initdb_args = " ".join([initdb_args, "-E UTF-8"]) - return Postgresql(initdb_args=initdb_args) - - TEST_URL = "https://example.og/api/repositories"
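
Note: after this patch a lister no longer needs a local SQLAlchemy database. The simplified `run` command builds its configuration from the config file plus any `key=value` arguments and ends with `get_lister(lister, **config).run()`. A minimal sketch of the equivalent programmatic call, mirroring the `test_cli.py` hunk above; the cgit instance URL comes from the test arguments, and the in-memory scheduler (taken from the removed DEFAULT_CONFIG) is only a placeholder for a real swh.scheduler backend:

    from swh.lister import get_lister

    # Instantiate the cgit lister against an example instance; the scheduler
    # configuration is a stand-in and should point at a real scheduler in
    # production deployments.
    lister = get_lister(
        "cgit",
        url="https://git.eclipse.org/c/",
        scheduler={"cls": "memory"},
    )
    lister.run()

The same result is obtained from the CLI with `swh lister run --lister cgit url=https://git.eclipse.org/c/`, the extra `key=value` arguments being parsed by `swh.scheduler.cli.utils.parse_options` as shown in the `run` command above.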