diff --git a/swh/loader/core/loader.py b/swh/loader/core/loader.py --- a/swh/loader/core/loader.py +++ b/swh/loader/core/loader.py @@ -12,8 +12,10 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union from swh.core import config +from swh.model.model import ( + BaseContent, Content, SkippedContent, Directory, Origin, Revision, + Release, Snapshot) from swh.storage import get_storage -from swh.loader.core.converters import prepare_contents class BaseLoader(config.SWHConfig, metaclass=ABCMeta): @@ -90,7 +92,7 @@ # possibly overridden in self.prepare method self.visit_date: Optional[Union[str, datetime.datetime]] = None - self.origin: Dict[str, Any] = {} + self.origin: Optional[Origin] = None if not hasattr(self, 'visit_type'): self.visit_type: Optional[str] = None @@ -114,7 +116,8 @@ if not hasattr(self, '__save_data_path'): year = str(self.visit_date.year) # type: ignore - url = self.origin['url'].encode('utf-8') + assert self.origin + url = self.origin.url.encode('utf-8') origin_url_hash = hashlib.sha1(url).hexdigest() path = '%s/sha1:%s/%s/%s' % ( @@ -161,13 +164,13 @@ self.visit references. """ - origin = self.origin.copy() - self.storage.origin_add_one(origin) + assert self.origin + self.storage.origin_add_one(self.origin) if not self.visit_date: # now as default visit_date if not provided self.visit_date = datetime.datetime.now(tz=datetime.timezone.utc) self.origin_visit = self.storage.origin_visit_add( - origin['url'], self.visit_date, self.visit_type) + self.origin.url, self.visit_date, self.visit_type) self.visit = self.origin_visit['visit'] @abstractmethod @@ -178,7 +181,7 @@ """ pass - def get_origin(self) -> Dict[str, Any]: + def get_origin(self) -> Origin: """Get the origin that is currently being loaded. self.origin should be set in :func:`prepare_origin` @@ -186,6 +189,7 @@ dict: an origin ready to be sent to storage by :func:`origin_add_one`. """ + assert self.origin return self.origin @abstractmethod @@ -286,6 +290,8 @@ self.prepare_origin_visit(*args, **kwargs) self._store_origin_visit() + assert self.origin + try: self.prepare(*args, **kwargs) @@ -297,7 +303,7 @@ self.store_metadata() self.storage.origin_visit_update( - self.origin['url'], self.visit, self.visit_status() + self.origin.url, self.visit, self.visit_status() ) self.post_load() except Exception: @@ -307,7 +313,7 @@ 'swh_task_kwargs': kwargs, }) self.storage.origin_visit_update( - self.origin['url'], self.visit, 'partial' + self.origin.url, self.visit, 'partial' ) self.post_load(success=False) return {'status': 'failed'} @@ -338,7 +344,7 @@ """Checks whether we need to load contents""" return True - def get_contents(self) -> Iterable[Dict[str, Any]]: + def get_contents(self) -> Iterable[BaseContent]: """Get the contents that need to be loaded""" raise NotImplementedError @@ -346,7 +352,7 @@ """Checks whether we need to load directories""" return True - def get_directories(self) -> Iterable[Dict[str, Any]]: + def get_directories(self) -> Iterable[Directory]: """Get the directories that need to be loaded""" raise NotImplementedError @@ -354,7 +360,7 @@ """Checks whether we need to load revisions""" return True - def get_revisions(self) -> Iterable[Dict[str, Any]]: + def get_revisions(self) -> Iterable[Revision]: """Get the revisions that need to be loaded""" raise NotImplementedError @@ -362,11 +368,11 @@ """Checks whether we need to load releases""" return True - def get_releases(self) -> Iterable[Dict[str, Any]]: + def get_releases(self) -> Iterable[Release]: """Get the releases that need to be loaded""" raise NotImplementedError - def get_snapshot(self) -> Dict[str, Any]: + def get_snapshot(self) -> Snapshot: """Get the snapshot that needs to be loaded""" raise NotImplementedError @@ -375,13 +381,20 @@ raise NotImplementedError def store_data(self) -> None: + assert self.origin if self.config['save_data']: self.save_data() if self.has_contents(): - contents, skipped_contents = prepare_contents( - self.get_contents(), max_content_size=self.max_content_size, - origin_url=self.origin['url']) + contents = [] + skipped_contents = [] + for obj in self.get_contents(): + if isinstance(obj, Content): + contents.append(obj) + elif isinstance(obj, SkippedContent): + skipped_contents.append(obj) + else: + raise TypeError(f'Unexpected content type: {obj}') self.storage.skipped_content_add(skipped_contents) self.storage.content_add(contents) if self.has_directories(): @@ -394,5 +407,5 @@ snapshot = self.get_snapshot() self.storage.snapshot_add([snapshot]) self.storage.origin_visit_update( - self.origin['url'], self.visit, snapshot=snapshot['id']) + self.origin.url, self.visit, snapshot=snapshot.id) self.flush() diff --git a/swh/loader/core/tests/test_loader.py b/swh/loader/core/tests/test_loader.py --- a/swh/loader/core/tests/test_loader.py +++ b/swh/loader/core/tests/test_loader.py @@ -8,6 +8,8 @@ import logging import pytest +from swh.model.model import Origin + from swh.loader.core.loader import BaseLoader, DVCSLoader @@ -25,9 +27,9 @@ pass def prepare_origin_visit(self, *args, **kwargs): - origin = {'url': 'some-url'} + origin = Origin(url='some-url') self.origin = origin - self.origin_url = origin['url'] + self.origin_url = origin.url self.visit_date = datetime.datetime.utcnow() self.visit_type = 'git' self.storage.origin_visit_add(self.origin_url, self.visit_date, @@ -44,9 +46,6 @@ 'storage': { 'cls': 'pipeline', 'steps': [ - { - 'cls': 'validate', - }, { 'cls': 'retry', }, @@ -71,9 +70,6 @@ 'storage': { 'cls': 'pipeline', 'steps': [ - { - 'cls': 'validate', - }, { 'cls': 'retry', }, @@ -134,9 +130,7 @@ def test_loader_save_data_path(tmp_path): loader = DummyBaseLoader('some.logger.name.1') url = 'http://bitbucket.org/something' - loader.origin = { - 'url': url, - } + loader.origin = Origin(url=url) loader.visit_date = datetime.datetime(year=2019, month=10, day=1) loader.config = { 'save_data_path': tmp_path, diff --git a/swh/loader/tests/conftest.py b/swh/loader/tests/conftest.py --- a/swh/loader/tests/conftest.py +++ b/swh/loader/tests/conftest.py @@ -14,9 +14,6 @@ 'storage': { 'cls': 'pipeline', 'steps': [ - { - 'cls': 'validate', - }, { 'cls': 'memory', },