diff --git a/swh/lister/core/lister_base.py b/swh/lister/core/lister_base.py index 22a35f8..93b78b8 100644 --- a/swh/lister/core/lister_base.py +++ b/swh/lister/core/lister_base.py @@ -1,545 +1,520 @@ # Copyright (C) 2015-2018 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import abc import datetime import gzip import json import logging import os import re import time from sqlalchemy import create_engine, func from sqlalchemy.orm import sessionmaker from swh.core import config from swh.scheduler import get_scheduler, utils from swh.storage import get_storage from .abstractattribute import AbstractAttribute logger = logging.getLogger(__name__) def utcnow(): return datetime.datetime.now(tz=datetime.timezone.utc) class FetchError(RuntimeError): def __init__(self, response): self.response = response def __str__(self): return repr(self.response) class SWHListerBase(abc.ABC, config.SWHConfig): """Lister core base class. Generally a source code hosting service provides an API endpoint for listing the set of stored repositories. A Lister is the discovery service responsible for finding this list, all at once or sequentially by parts, and queueing local tasks to fetch and ingest the referenced repositories. The core method in this class is ingest_data. Any subclasses should be calling this method one or more times to fetch and ingest data from API endpoints. See swh.lister.core.lister_base.SWHIndexingLister for example usage. This class cannot be instantiated. Any instantiable Lister descending from SWHListerBase must provide at least the required overrides. (see member docstrings for details): Required Overrides: MODEL def transport_request def transport_response_to_string def transport_response_simplified def transport_quota_check Optional Overrides: def filter_before_inject def is_within_bounds """ MODEL = AbstractAttribute('Subclass type (not instance)' ' of swh.lister.core.models.ModelBase' ' customized for a specific service.') LISTER_NAME = AbstractAttribute("Lister's name") def transport_request(self, identifier): """Given a target endpoint identifier to query, try once to request it. Implementation of this method determines the network request protocol. Args: identifier (string): unique identifier for an endpoint query. e.g. If the service indexes lists of repositories by date and time of creation, this might be that as a formatted string. Or it might be an integer UID. Or it might be nothing. It depends on what the service needs. Returns: the entire request response Raises: Will catch internal transport-dependent connection exceptions and raise swh.lister.core.lister_base.FetchError instead. Other non-connection exceptions should propagate unchanged. """ pass def transport_response_to_string(self, response): """Convert the server response into a formatted string for logging. Implementation of this method depends on the shape of the network response object returned by the transport_request method. Args: response: the server response Returns: a pretty string of the response """ pass def transport_response_simplified(self, response): """Convert the server response into list of a dict for each repo in the response, mapping columns in the lister's MODEL class to repo data. Implementation of this method depends on the server API spec and the shape of the network response object returned by the transport_request method. Args: response: response object from the server. Returns: list of repo MODEL dicts ( eg. [{'uid': r['id'], etc.} for r in response.json()] ) """ pass def transport_quota_check(self, response): """Check server response to see if we're hitting request rate limits. Implementation of this method depends on the server communication protocol and API spec and the shape of the network response object returned by the transport_request method. Args: response (session response): complete API query response Returns: 1) must retry request? True/False 2) seconds to delay if True """ pass def filter_before_inject(self, models_list): """Function run after transport_response_simplified but before injection into the local db and creation of workers. Can be used to eliminate some of the results if necessary. MAY BE OVERRIDDEN if an intermediate Lister class needs to filter results before injection without requiring every child class to do so. Args: models_list: list of dicts returned by transport_response_simplified. Returns: models_list with entries changed according to custom logic. """ return models_list def do_additional_checks(self, models_list): """Execute some additional checks on the model list. For example, to check for existing repositories in the db. MAY BE OVERRIDDEN if an intermediate Lister class needs to check some more the results before injection. Checks are fine by default, returns the models_list as is by default. Args: models_list: list of dicts returned by transport_response_simplified. Returns: models_list with entries if checks ok, False otherwise """ return models_list def is_within_bounds(self, inner, lower=None, upper=None): """See if a sortable value is inside the range [lower,upper]. MAY BE OVERRIDDEN, for example if the server indexable* key is technically sortable but not automatically so. * - ( see: swh.lister.core.indexing_lister.SWHIndexingLister ) Args: inner (sortable type): the value being checked lower (sortable type): optional lower bound upper (sortable type): optional upper bound Returns: whether inner is confined by the optional lower and upper bounds """ try: if lower is None and upper is None: return True elif lower is None: ret = inner <= upper elif upper is None: ret = inner >= lower else: ret = lower <= inner <= upper self.string_pattern_check(inner, lower, upper) except Exception as e: logger.error(str(e) + ': %s, %s, %s' % (('inner=%s%s' % (type(inner), inner)), ('lower=%s%s' % (type(lower), lower)), ('upper=%s%s' % (type(upper), upper))) ) raise return ret # You probably don't need to override anything below this line. DEFAULT_CONFIG = { 'storage': ('dict', { 'cls': 'remote', 'args': { 'url': 'http://localhost:5002/' }, }), 'scheduler': ('dict', { 'cls': 'remote', 'args': { 'url': 'http://localhost:5008/' }, }), 'lister': ('dict', { 'cls': 'local', 'args': { 'db': 'postgresql:///lister', }, }), } @property def CONFIG_BASE_FILENAME(self): # noqa: N802 return 'lister_%s' % self.LISTER_NAME @property def ADDITIONAL_CONFIG(self): # noqa: N802 return { 'credentials': ('dict', {}), 'cache_responses': ('bool', False), 'cache_dir': ('str', '~/.cache/swh/lister/%s' % self.LISTER_NAME), } INITIAL_BACKOFF = 10 MAX_RETRIES = 7 CONN_SLEEP = 10 def __init__(self, override_config=None): self.backoff = self.INITIAL_BACKOFF logger.debug('Loading config from %s' % self.CONFIG_BASE_FILENAME) self.config = self.parse_config_file( base_filename=self.CONFIG_BASE_FILENAME, additional_configs=[self.ADDITIONAL_CONFIG] ) self.config['cache_dir'] = os.path.expanduser(self.config['cache_dir']) if self.config['cache_responses']: config.prepare_folders(self.config, 'cache_dir') if override_config: self.config.update(override_config) logger.debug('%s CONFIG=%s' % (self, self.config)) self.storage = get_storage(**self.config['storage']) self.scheduler = get_scheduler(**self.config['scheduler']) self.db_engine = create_engine(self.config['lister']['args']['db']) self.mk_session = sessionmaker(bind=self.db_engine) self.db_session = self.mk_session() def reset_backoff(self): """Reset exponential backoff timeout to initial level.""" self.backoff = self.INITIAL_BACKOFF def back_off(self): """Get next exponential backoff timeout.""" ret = self.backoff self.backoff *= 10 return ret def safely_issue_request(self, identifier): """Make network request with retries, rate quotas, and response logs. Protocol is handled by the implementation of the transport_request method. Args: identifier: resource identifier Returns: server response """ retries_left = self.MAX_RETRIES do_cache = self.config['cache_responses'] r = None while retries_left > 0: try: r = self.transport_request(identifier) except FetchError: # network-level connection error, try again logger.warning( 'connection error on %s: sleep for %d seconds' % (identifier, self.CONN_SLEEP)) time.sleep(self.CONN_SLEEP) retries_left -= 1 continue if do_cache: self.save_response(r) # detect throttling must_retry, delay = self.transport_quota_check(r) if must_retry: logger.warning( 'rate limited on %s: sleep for %f seconds' % (identifier, delay)) time.sleep(delay) else: # request ok break retries_left -= 1 if not retries_left: logger.warning( 'giving up on %s: max retries exceeded' % identifier) return r def db_query_equal(self, key, value): """Look in the db for a row with key == value Args: key: column key to look at value: value to look for in that column Returns: sqlalchemy.ext.declarative.declarative_base object with the given key == value """ if isinstance(key, str): key = self.MODEL.__dict__[key] return self.db_session.query(self.MODEL) \ .filter(key == value).first() def winnow_models(self, mlist, key, to_remove): """Given a list of models, remove any with matching some member of a list of values. Args: mlist (list of model rows): the initial list of models key (column): the column to filter on to_remove (list): if anything in mlist has column equal to one of the values in to_remove, it will be removed from the result Returns: A list of model rows starting from mlist minus any matching rows """ if isinstance(key, str): key = self.MODEL.__dict__[key] if to_remove: return mlist.filter(~key.in_(to_remove)).all() else: return mlist.all() def db_num_entries(self): """Return the known number of entries in the lister db""" return self.db_session.query(func.count('*')).select_from(self.MODEL) \ .scalar() def db_inject_repo(self, model_dict): """Add/update a new repo to the db and mark it last_seen now. Args: model_dict: dictionary mapping model keys to values Returns: new or updated sqlalchemy.ext.declarative.declarative_base object associated with the injection """ sql_repo = self.db_query_equal('uid', model_dict['uid']) if not sql_repo: sql_repo = self.MODEL(**model_dict) self.db_session.add(sql_repo) else: for k in model_dict: setattr(sql_repo, k, model_dict[k]) sql_repo.last_seen = utcnow() return sql_repo - def origin_dict(self, origin_type, origin_url, **kwargs): - """Return special dict format for the origins list - - Args: - origin_type (string) - origin_url (string) - Returns: - the same information in a different form - """ - return { - 'type': origin_type, - 'url': origin_url, - } - def task_dict(self, origin_type, origin_url, **kwargs): """Return special dict format for the tasks list Args: origin_type (string) origin_url (string) Returns: the same information in a different form """ _type = 'load-%s' % origin_type _policy = 'recurring' return utils.create_task_dict(_type, _policy, origin_url) def string_pattern_check(self, a, b, c=None): """When comparing indexable types in is_within_bounds, complex strings may not be allowed to differ in basic structure. If they do, it could be a sign of not understanding the data well. For instance, an ISO 8601 time string cannot be compared against its urlencoded equivalent, but this is an easy mistake to accidentally make. This method acts as a friendly sanity check. Args: a (string): inner component of the is_within_bounds method b (string): lower component of the is_within_bounds method c (string): upper component of the is_within_bounds method Returns: nothing Raises: TypeError if strings a, b, and c don't conform to the same basic pattern. """ if isinstance(a, str): a_pattern = re.sub('[a-zA-Z0-9]', '[a-zA-Z0-9]', re.escape(a)) if (isinstance(b, str) and (re.match(a_pattern, b) is None) or isinstance(c, str) and (re.match(a_pattern, c) is None)): logger.debug(a_pattern) raise TypeError('incomparable string patterns detected') def inject_repo_data_into_db(self, models_list): """Inject data into the db. Args: models_list: list of dicts mapping keys from the db model for each repo to be injected Returns: dict of uid:sql_repo pairs """ injected_repos = {} for m in models_list: injected_repos[m['uid']] = self.db_inject_repo(m) return injected_repos - def create_missing_origins_and_tasks(self, models_list, injected_repos): - """Find any newly created db entries that don't yet have tasks or - origin objects assigned. + def schedule_missing_tasks(self, models_list, injected_repos): + """Find any newly created db entries that do not have been scheduled + yet. Args: - models_list: a list of dicts mapping keys in the db model for - each repo - injected_repos: dict of uid:sql_repo pairs that have just - been created + models_list ([Model]): List of dicts mapping keys in the db model + for each repo + injected_repos ([dict]): Dict of uid:sql_repo pairs that have just + been created + Returns: Nothing. Modifies injected_repos. + """ - origins = {} tasks = {} - def _origin_key(m): - _type = m.get('origin_type', m.get('type')) - _url = m.get('origin_url', m.get('url')) - return '%s-%s' % (_type, _url) - def _task_key(m): - return '%s-%s' % (m['type'], - json.dumps(m['arguments'], sort_keys=True)) + return '%s-%s' % ( + m['type'], + json.dumps(m['arguments'], sort_keys=True) + ) for m in models_list: ir = injected_repos[m['uid']] - if not ir.origin_id: - origin_dict = self.origin_dict(**m) - origins[_origin_key(m)] = (ir, m, origin_dict) if not ir.task_id: task_dict = self.task_dict(**m) tasks[_task_key(task_dict)] = (ir, m, task_dict) - new_origins = self.storage.origin_add( - (origin_dicts for (_, _, origin_dicts) in origins.values())) - for origin in new_origins: - ir, m, _ = origins[_origin_key(origin)] - ir.origin_id = origin['id'] - new_tasks = self.scheduler.create_tasks( (task_dicts for (_, _, task_dicts) in tasks.values())) for task in new_tasks: ir, m, _ = tasks[_task_key(task)] ir.task_id = task['id'] def ingest_data(self, identifier, checks=False): """The core data fetch sequence. Request server endpoint. Simplify and filter response list of repositories. Inject repo information into local db. Queue loader tasks for linked repositories. Args: identifier: Resource identifier. checks (bool): Additional checks required """ # Request (partial?) list of repositories info response = self.safely_issue_request(identifier) if not response: return response, [] models_list = self.transport_response_simplified(response) models_list = self.filter_before_inject(models_list) if checks: models_list = self.do_additional_checks(models_list) if not models_list: return response, [] # inject into local db injected = self.inject_repo_data_into_db(models_list) # queue workers - self.create_missing_origins_and_tasks(models_list, injected) + self.schedule_missing_tasks(models_list, injected) return response, injected def save_response(self, response): """Log the response from a server request to a cache dir. Args: response: full server response cache_dir: system path for cache dir Returns: nothing """ datepath = utcnow().isoformat() fname = os.path.join( self.config['cache_dir'], datepath + '.gz', ) with gzip.open(fname, 'w') as f: f.write(bytes( self.transport_response_to_string(response), 'UTF-8' )) diff --git a/swh/lister/core/simple_lister.py b/swh/lister/core/simple_lister.py index 11060eb..40c47b2 100644 --- a/swh/lister/core/simple_lister.py +++ b/swh/lister/core/simple_lister.py @@ -1,74 +1,74 @@ # Copyright (C) 2018 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging from swh.core import utils from .lister_base import SWHListerBase class SimpleLister(SWHListerBase): """Lister* intermediate class for any service that follows the simple, 'list in oneshot information' pattern. - Client sends a request to list repositories in oneshot - Client receives structured (json/xml/etc) response with information and stores those in db """ def list_packages(self, *args): """Listing packages method. """ pass def ingest_data(self, identifier, checks=False): """Rework the base ingest_data. Request server endpoint which gives all in one go. Simplify and filter response list of repositories. Inject repo information into local db. Queue loader tasks for linked repositories. Args: identifier: Resource identifier (unused) checks (bool): Additional checks required (unused) """ response = self.safely_issue_request(identifier) response = self.list_packages(response) if not response: return response, [] models_list = self.transport_response_simplified(response) models_list = self.filter_before_inject(models_list) all_injected = [] for models in utils.grouper(models_list, n=10000): models = list(models) logging.debug('models: %s' % len(models)) # inject into local db injected = self.inject_repo_data_into_db(models) # queue workers - self.create_missing_origins_and_tasks(models, injected) + self.schedule_missing_tasks(models, injected) all_injected.append(injected) # flush self.db_session.commit() self.db_session = self.mk_session() return response, all_injected def run(self): """Query the server which answers in one query. Stores the information, dropping actual redundant information we already have. Returns: nothing """ dump_not_used_identifier = 0 response, injected_repos = self.ingest_data(dump_not_used_identifier) if not response and not injected_repos: logging.info('No response from api server, stopping') diff --git a/swh/lister/core/tests/test_lister.py b/swh/lister/core/tests/test_lister.py index 29dcd2a..5b93b64 100644 --- a/swh/lister/core/tests/test_lister.py +++ b/swh/lister/core/tests/test_lister.py @@ -1,234 +1,234 @@ # Copyright (C) 2017-2018 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import abc import time from unittest import TestCase from unittest.mock import Mock, patch import requests_mock from sqlalchemy import create_engine from testing.postgresql import Postgresql from swh.lister.core.abstractattribute import AbstractAttribute def noop(*args, **kwargs): pass @requests_mock.Mocker() class HttpListerTesterBase(abc.ABC): """Base testing class for subclasses of swh.lister.core.indexing_lister.SWHIndexingHttpLister. swh.lister.core.page_by_page_lister.PageByPageHttpLister See swh.lister.github.tests.test_gh_lister for an example of how to customize for a specific listing service. """ Lister = AbstractAttribute('The lister class to test') test_re = AbstractAttribute('Compiled regex matching the server url. Must' ' capture the index value.') lister_subdir = AbstractAttribute('bitbucket, github, etc.') good_api_response_file = AbstractAttribute('Example good response body') bad_api_response_file = AbstractAttribute('Example bad response body') first_index = AbstractAttribute('First index in good_api_response') entries_per_page = AbstractAttribute('Number of results in good response') LISTER_NAME = 'fake-lister' # May need to override this if the headers are used for something def response_headers(self, request): return {} # May need to override this if the server uses non-standard rate limiting # method. # Please keep the requested retry delay reasonably low. def mock_rate_quota(self, n, request, context): self.rate_limit += 1 context.status_code = 429 context.headers['Retry-After'] = '1' return '{"error":"dummy"}' def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.rate_limit = 1 self.response = None self.fl = None self.helper = None if self.__class__ != HttpListerTesterBase: self.run = TestCase.run.__get__(self, self.__class__) else: self.run = noop def request_index(self, request): m = self.test_re.search(request.path_url) if m and (len(m.groups()) > 0): return m.group(1) else: return None def mock_response(self, request, context): self.fl.reset_backoff() self.rate_limit = 1 context.status_code = 200 custom_headers = self.response_headers(request) context.headers.update(custom_headers) if self.request_index(request) == str(self.first_index): with open('swh/lister/%s/tests/%s' % (self.lister_subdir, self.good_api_response_file), 'r', encoding='utf-8') as r: return r.read() else: with open('swh/lister/%s/tests/%s' % (self.lister_subdir, self.bad_api_response_file), 'r', encoding='utf-8') as r: return r.read() def mock_limit_n_response(self, n, request, context): self.fl.reset_backoff() if self.rate_limit <= n: return self.mock_rate_quota(n, request, context) else: return self.mock_response(request, context) def mock_limit_once_response(self, request, context): return self.mock_limit_n_response(1, request, context) def mock_limit_twice_response(self, request, context): return self.mock_limit_n_response(2, request, context) def get_fl(self, override_config=None): """Retrieve an instance of fake lister (fl). """ if override_config or self.fl is None: self.fl = self.Lister(api_baseurl='https://fakeurl', override_config=override_config) self.fl.INITIAL_BACKOFF = 1 self.fl.reset_backoff() return self.fl def get_api_response(self): fl = self.get_fl() if self.response is None: self.response = fl.safely_issue_request(self.first_index) return self.response def test_is_within_bounds(self, http_mocker): fl = self.get_fl() self.assertFalse(fl.is_within_bounds(1, 2, 3)) self.assertTrue(fl.is_within_bounds(2, 1, 3)) self.assertTrue(fl.is_within_bounds(1, 1, 1)) self.assertTrue(fl.is_within_bounds(1, None, None)) self.assertTrue(fl.is_within_bounds(1, None, 2)) self.assertTrue(fl.is_within_bounds(1, 0, None)) self.assertTrue(fl.is_within_bounds("b", "a", "c")) self.assertFalse(fl.is_within_bounds("a", "b", "c")) self.assertTrue(fl.is_within_bounds("a", None, "c")) self.assertTrue(fl.is_within_bounds("a", None, None)) self.assertTrue(fl.is_within_bounds("b", "a", None)) self.assertFalse(fl.is_within_bounds("a", "b", None)) self.assertTrue(fl.is_within_bounds("aa:02", "aa:01", "aa:03")) self.assertFalse(fl.is_within_bounds("aa:12", None, "aa:03")) with self.assertRaises(TypeError): fl.is_within_bounds(1.0, "b", None) with self.assertRaises(TypeError): fl.is_within_bounds("A:B", "A::B", None) def test_api_request(self, http_mocker): http_mocker.get(self.test_re, text=self.mock_limit_twice_response) with patch.object(time, 'sleep', wraps=time.sleep) as sleepmock: self.get_api_response() self.assertEqual(sleepmock.call_count, 2) def test_repos_list(self, http_mocker): http_mocker.get(self.test_re, text=self.mock_response) li = self.get_fl().transport_response_simplified( self.get_api_response() ) self.assertIsInstance(li, list) self.assertEqual(len(li), self.entries_per_page) def test_model_map(self, http_mocker): http_mocker.get(self.test_re, text=self.mock_response) fl = self.get_fl() li = fl.transport_response_simplified(self.get_api_response()) di = li[0] self.assertIsInstance(di, dict) pubs = [k for k in vars(fl.MODEL).keys() if not k.startswith('_')] for k in pubs: if k not in ['last_seen', 'task_id', 'origin_id', 'id']: self.assertIn(k, di) - def disable_storage_and_scheduler(self, fl): - fl.create_missing_origins_and_tasks = Mock(return_value=None) + def disable_scheduler(self, fl): + fl.schedule_missing_tasks = Mock(return_value=None) def disable_db(self, fl): fl.winnow_models = Mock(return_value=[]) fl.db_inject_repo = Mock(return_value=fl.MODEL()) fl.disable_deleted_repo_tasks = Mock(return_value=None) def test_fetch_none_nodb(self, http_mocker): http_mocker.get(self.test_re, text=self.mock_response) fl = self.get_fl() - self.disable_storage_and_scheduler(fl) + self.disable_scheduler(fl) self.disable_db(fl) fl.run(min_bound=1, max_bound=1) # stores no results def test_fetch_one_nodb(self, http_mocker): http_mocker.get(self.test_re, text=self.mock_response) fl = self.get_fl() - self.disable_storage_and_scheduler(fl) + self.disable_scheduler(fl) self.disable_db(fl) fl.run(min_bound=self.first_index, max_bound=self.first_index) def test_fetch_multiple_pages_nodb(self, http_mocker): http_mocker.get(self.test_re, text=self.mock_response) fl = self.get_fl() - self.disable_storage_and_scheduler(fl) + self.disable_scheduler(fl) self.disable_db(fl) fl.run(min_bound=self.first_index) def init_db(self, db, model): engine = create_engine(db.url()) model.metadata.create_all(engine) class HttpListerTester(HttpListerTesterBase, abc.ABC): last_index = AbstractAttribute('Last index in good_api_response') @requests_mock.Mocker() def test_fetch_multiple_pages_yesdb(self, http_mocker): http_mocker.get(self.test_re, text=self.mock_response) initdb_args = Postgresql.DEFAULT_SETTINGS['initdb_args'] initdb_args = ' '.join([initdb_args, '-E UTF-8']) db = Postgresql(initdb_args=initdb_args) fl = self.get_fl(override_config={ 'lister': { 'cls': 'local', 'args': {'db': db.url()} } }) self.init_db(db, fl.MODEL) - self.disable_storage_and_scheduler(fl) + self.disable_scheduler(fl) fl.run(min_bound=self.first_index) self.assertEqual(fl.db_last_index(), self.last_index) partitions = fl.db_partition_indices(5) self.assertGreater(len(partitions), 0) for k in partitions: self.assertLessEqual(len(k), 5) self.assertGreater(len(k), 0) diff --git a/swh/lister/debian/lister.py b/swh/lister/debian/lister.py index 5b72c6f..a0dc11a 100644 --- a/swh/lister/debian/lister.py +++ b/swh/lister/debian/lister.py @@ -1,237 +1,237 @@ # Copyright (C) 2017 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import bz2 from collections import defaultdict import datetime import gzip import lzma import logging from debian.deb822 import Sources from sqlalchemy.orm import joinedload, load_only from sqlalchemy.schema import CreateTable, DropTable from swh.storage.schemata.distribution import ( AreaSnapshot, Distribution, DistributionSnapshot, Package, TempPackage, ) from swh.lister.core.lister_base import SWHListerBase, FetchError from swh.lister.core.lister_transports import SWHListerHttpTransport decompressors = { 'gz': lambda f: gzip.GzipFile(fileobj=f), 'bz2': bz2.BZ2File, 'xz': lzma.LZMAFile, } class DebianLister(SWHListerHttpTransport, SWHListerBase): MODEL = Package PATH_TEMPLATE = None LISTER_NAME = 'debian' instance = 'debian' def __init__(self, override_config=None): SWHListerHttpTransport.__init__(self, api_baseurl="bogus") SWHListerBase.__init__(self, override_config=override_config) def transport_request(self, identifier): """Subvert SWHListerHttpTransport.transport_request, to try several index URIs in turn. The Debian repository format supports several compression algorithms across the ages, so we try several URIs. Once we have found a working URI, we break and set `self.decompressor` to the one that matched. Returns: a requests Response object. Raises: FetchError: when all the URIs failed to be retrieved. """ response = None compression = None for uri, compression in self.area.index_uris(): response = super().transport_request(uri) if response.status_code == 200: break else: raise FetchError( "Could not retrieve index for %s" % self.area ) self.decompressor = decompressors.get(compression) return response def request_uri(self, identifier): # In the overridden transport_request, we pass # SWHListerBase.transport_request() the full URI as identifier, so we # need to return it here. return identifier def request_params(self, identifier): # Enable streaming to allow wrapping the response in the decompressor # in transport_response_simplified. params = super().request_params(identifier) params['stream'] = True return params def transport_response_simplified(self, response): """Decompress and parse the package index fetched in `transport_request`. For each package, we "pivot" the file list entries (Files, Checksums-Sha1, Checksums-Sha256), to return a files dict mapping filenames to their checksums. """ if self.decompressor: data = self.decompressor(response.raw) else: data = response.raw for src_pkg in Sources.iter_paragraphs(data.readlines()): files = defaultdict(dict) for field in src_pkg._multivalued_fields: if field.startswith('checksums-'): sum_name = field[len('checksums-'):] else: sum_name = 'md5sum' if field in src_pkg: for entry in src_pkg[field]: name = entry['name'] files[name]['name'] = entry['name'] files[name]['size'] = int(entry['size'], 10) files[name][sum_name] = entry[sum_name] yield { 'name': src_pkg['Package'], 'version': src_pkg['Version'], 'directory': src_pkg['Directory'], 'files': files, } def inject_repo_data_into_db(self, models_list): """Generate the Package entries that didn't previously exist. Contrary to SWHListerBase, we don't actually insert the data in - database. `create_missing_origins_and_tasks` does it once we have the + database. `schedule_missing_tasks` does it once we have the origin and task identifiers. """ by_name_version = {} temp_packages = [] area_id = self.area.id for model in models_list: name = model['name'] version = model['version'] temp_packages.append({ 'area_id': area_id, 'name': name, 'version': version, }) by_name_version[name, version] = model # Add all the listed packages to a temporary table self.db_session.execute(CreateTable(TempPackage.__table__)) self.db_session.bulk_insert_mappings(TempPackage, temp_packages) def exists_tmp_pkg(db_session, model): return ( db_session.query(model) .filter(Package.area_id == TempPackage.area_id) .filter(Package.name == TempPackage.name) .filter(Package.version == TempPackage.version) .exists() ) # Filter out the packages that already exist in the main Package table new_packages = self.db_session\ .query(TempPackage)\ .options(load_only('name', 'version'))\ .filter(~exists_tmp_pkg(self.db_session, Package))\ .all() self.old_area_packages = self.db_session.query(Package).filter( exists_tmp_pkg(self.db_session, TempPackage) ).all() self.db_session.execute(DropTable(TempPackage.__table__)) added_packages = [] for package in new_packages: model = by_name_version[package.name, package.version] added_packages.append(Package(area=self.area, **model)) self.db_session.add_all(added_packages) return added_packages - def create_missing_origins_and_tasks(self, models_list, added_packages): + def schedule_missing_tasks(self, models_list, added_packages): """We create tasks at the end of the full snapshot processing""" return def create_tasks_for_snapshot(self, snapshot): tasks = [ snapshot.task_for_package(name, versions) for name, versions in snapshot.get_packages().items() ] return self.scheduler.create_tasks(tasks) def run(self, distribution, date=None): """Run the lister for a given (distribution, area) tuple. Args: distribution (str): name of the distribution (e.g. "Debian") date (datetime.datetime): date the snapshot is taken (defaults to now) """ distribution = self.db_session\ .query(Distribution)\ .options(joinedload(Distribution.areas))\ .filter(Distribution.name == distribution)\ .one_or_none() if not distribution: raise ValueError("Distribution %s is not registered" % distribution) if not distribution.type == 'deb': raise ValueError("Distribution %s is not a Debian derivative" % distribution) date = date or datetime.datetime.now(tz=datetime.timezone.utc) logging.debug('Creating snapshot for distribution %s on date %s' % (distribution, date)) snapshot = DistributionSnapshot(date=date, distribution=distribution) self.db_session.add(snapshot) for area in distribution.areas: if not area.active: continue self.area = area logging.debug('Processing area %s' % area) _, new_area_packages = self.ingest_data(None) area_snapshot = AreaSnapshot(snapshot=snapshot, area=area) self.db_session.add(area_snapshot) area_snapshot.packages.extend(new_area_packages) area_snapshot.packages.extend(self.old_area_packages) self.create_tasks_for_snapshot(snapshot) self.db_session.commit() return True diff --git a/swh/lister/phabricator/lister.py b/swh/lister/phabricator/lister.py index d6e062e..c02103d 100644 --- a/swh/lister/phabricator/lister.py +++ b/swh/lister/phabricator/lister.py @@ -1,143 +1,143 @@ # Copyright (C) 2019 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import urllib.parse from swh.lister.core.indexing_lister import SWHIndexingHttpLister from swh.lister.phabricator.models import PhabricatorModel from collections import defaultdict class PhabricatorLister(SWHIndexingHttpLister): PATH_TEMPLATE = '&order=oldest&attachments[uris]=1&after=%s' MODEL = PhabricatorModel LISTER_NAME = 'phabricator' def __init__(self, forge_url, api_token, instance=None, override_config=None): if forge_url.endswith("/"): forge_url = forge_url[:-1] self.forge_url = forge_url api_endpoint = ('api/diffusion.repository.' 'search?api.token=%s') % api_token api_baseurl = '%s/%s' % (forge_url, api_endpoint) if not instance: instance = urllib.parse.urlparse(forge_url).hostname self.instance = instance super().__init__(api_baseurl=api_baseurl, override_config=override_config) def request_headers(self): """ (Override) Set requests headers to send when querying the Phabricator API """ return {'User-Agent': 'Software Heritage phabricator lister', 'Accept': 'application/json'} def get_model_from_repo(self, repo): url = get_repo_url(repo['attachments']['uris']['uris']) if url is None: return None return { 'uid': self.forge_url + str(repo['id']), 'indexable': repo['id'], 'name': repo['fields']['shortName'], 'full_name': repo['fields']['name'], 'html_url': url, 'origin_url': url, 'description': None, 'origin_type': repo['fields']['vcs'] } def get_next_target_from_response(self, response): body = response.json()['result']['cursor'] if body['after'] != 'null': return body['after'] else: return None def transport_response_simplified(self, response): repos = response.json() if repos['result'] is None: raise ValueError( 'Problem during information fetch: %s' % repos['error_code']) repos = repos['result']['data'] return [self.get_model_from_repo(repo) for repo in repos] def filter_before_inject(self, models_list): """ (Overrides) SWHIndexingLister.filter_before_inject Bounds query results by this Lister's set max_index. """ models_list = [m for m in models_list if m is not None] return super().filter_before_inject(models_list) def _bootstrap_repositories_listing(self): """ Method called when no min_bound value has been provided to the lister. Its purpose is to: 1. get the first repository data hosted on the Phabricator instance 2. inject them into the lister database 3. return the first repository index to start the listing after that value Returns: int: The first repository index """ params = '&order=oldest&limit=1' response = self.safely_issue_request(params) models_list = self.transport_response_simplified(response) self.max_index = models_list[0]['indexable'] models_list = self.filter_before_inject(models_list) injected = self.inject_repo_data_into_db(models_list) - self.create_missing_origins_and_tasks(models_list, injected) + self.schedule_missing_tasks(models_list, injected) return self.max_index def run(self, min_bound=None, max_bound=None): """ (Override) Run the lister on the specified Phabricator instance Args: min_bound (int): Optional repository index to start the listing after it max_bound (int): Optional repository index to stop the listing after it """ # initial call to the lister, we need to bootstrap it in that case if min_bound is None: min_bound = self._bootstrap_repositories_listing() super().run(min_bound, max_bound) def get_repo_url(attachments): """ Return url for a hosted repository from its uris attachments according to the following priority lists: * protocol: https > http * identifier: shortname > callsign > id """ processed_urls = defaultdict(dict) for uri in attachments: protocol = uri['fields']['builtin']['protocol'] url = uri['fields']['uri']['effective'] identifier = uri['fields']['builtin']['identifier'] if protocol in ('http', 'https'): processed_urls[protocol][identifier] = url elif protocol is None: for protocol in ('https', 'http'): if url.startswith(protocol): processed_urls[protocol]['undefined'] = url break for protocol in ['https', 'http']: for identifier in ['shortname', 'callsign', 'id', 'undefined']: if (protocol in processed_urls and identifier in processed_urls[protocol]): return processed_urls[protocol][identifier] return None