diff --git a/swh/lister/core/models.py b/swh/lister/core/models.py index ed6a591..27eb080 100644 --- a/swh/lister/core/models.py +++ b/swh/lister/core/models.py @@ -1,80 +1,79 @@ # Copyright (C) 2015-2019 the Software Heritage developers # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import abc from datetime import datetime import logging from sqlalchemy import Column, DateTime, Integer, String from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base from typing import Type, Union from .abstractattribute import AbstractAttribute SQLBase = declarative_base() logger = logging.getLogger(__name__) class ABCSQLMeta(abc.ABCMeta, DeclarativeMeta): pass class ModelBase(SQLBase, metaclass=ABCSQLMeta): """a common repository""" __abstract__ = True __tablename__ = \ AbstractAttribute # type: Union[Type[AbstractAttribute], str] uid = AbstractAttribute( 'Column(, primary_key=True)' ) # type: Union[AbstractAttribute, Column] name = Column(String, index=True) full_name = Column(String, index=True) html_url = Column(String) origin_url = Column(String) origin_type = Column(String) last_seen = Column(DateTime, nullable=False) task_id = Column(Integer) def __init__(self, **kw): kw['last_seen'] = datetime.now() super().__init__(**kw) class IndexingModelBase(ModelBase, metaclass=ABCSQLMeta): __abstract__ = True __tablename__ = \ AbstractAttribute # type: Union[Type[AbstractAttribute], str] # The value used for sorting, segmenting, or api query paging, # because uids aren't always sequential. indexable = AbstractAttribute( 'Column(, index=True)' ) # type: Union[AbstractAttribute, Column] def initialize(db_engine, drop_tables=False, **kwargs): """Default database initialization function for a lister. Typically called from the lister's initialization hook. Args: models (list): list of SQLAlchemy tables/models to drop/create. db_engine (): the SQLAlchemy DB engine. drop_tables (bool): if True, tables will be dropped before (re)creating them. """ - if drop_tables: logger.info('Dropping tables') SQLBase.metadata.drop_all(db_engine, checkfirst=True) logger.info('Creating tables') SQLBase.metadata.create_all(db_engine, checkfirst=True) diff --git a/swh/lister/debian/__init__.py b/swh/lister/debian/__init__.py index b1398f6..6dbb553 100644 --- a/swh/lister/debian/__init__.py +++ b/swh/lister/debian/__init__.py @@ -1,58 +1,54 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from typing import Any, List, Mapping -def debian_init(db_engine, lister=None, +def debian_init(db_engine, override_conf: Mapping[str, Any] = {}, distributions: List[str] = ['stretch', 'buster'], area_names: List[str] = ['main', 'contrib', 'non-free']): """Initialize the debian data model. Args: db_engine: SQLAlchemy manipulation database object - lister: Debian lister instance. None by default. override_conf: Override conf to pass to instantiate a lister distributions: Default distribution to build """ distribution_name = 'Debian' from swh.lister.debian.models import Distribution, Area + from sqlalchemy.orm import sessionmaker + db_session = sessionmaker(bind=db_engine)() + + existing_distrib = db_session \ + .query(Distribution) \ + .filter(Distribution.name == distribution_name) \ + .one_or_none() + if not existing_distrib: + distrib = Distribution(name=distribution_name, + type='deb', + mirror_uri='http://deb.debian.org/debian/') + db_session.add(distrib) - if lister is None: - from .lister import DebianLister - lister = DebianLister(distribution=distribution_name, - override_config=override_conf) - - if not lister.db_session\ - .query(Distribution)\ - .filter(Distribution.name == distribution_name)\ - .one_or_none(): - - d = Distribution( - name=distribution_name, - type='deb', - mirror_uri='http://deb.debian.org/debian/') - lister.db_session.add(d) - - areas = [] for distribution_name in distributions: for area_name in area_names: - areas.append(Area( + area = Area( name='%s/%s' % (distribution_name, area_name), - distribution=d, - )) - lister.db_session.add_all(areas) - lister.db_session.commit() + distribution=distrib, + ) + db_session.add(area) + + db_session.commit() + db_session.close() def register() -> Mapping[str, Any]: from .lister import DebianLister return {'models': [DebianLister.MODEL], 'lister': DebianLister, 'task_modules': ['%s.tasks' % __name__], 'init': debian_init} diff --git a/swh/lister/debian/tests/conftest.py b/swh/lister/debian/tests/conftest.py index 6a54149..a479ed1 100644 --- a/swh/lister/debian/tests/conftest.py +++ b/swh/lister/debian/tests/conftest.py @@ -1,30 +1,30 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest from swh.lister.core.tests.conftest import * # noqa from swh.lister.debian import debian_init @pytest.fixture def lister_debian(swh_listers): lister = swh_listers['debian'] # Initialize the debian data model - debian_init(lister.db_engine, lister=lister, + debian_init(lister.db_engine, distributions=['stretch'], area_names=['main', 'contrib']) # Add the load-deb-package in the scheduler backend lister.scheduler.create_task_type({ 'type': 'load-deb-package', 'description': 'Load a Debian package', 'backend_name': 'swh.loader.debian.tasks.LoaderDebianPackage', 'default_interval': '1 day', }) return lister