Replace the LoaderNoStorage test mixin with swh.storage's in-memory Storage.

setUp now creates self.storage = Storage(), and every assert* helper queries
that storage instead of the loader's private per-type state lists. Expected
identifiers are passed as hex strings and converted with hash_to_bytes before
being handed to the storage. The LoaderNoStorage mixin is removed.

diff --git a/swh/loader/core/tests/__init__.py b/swh/loader/core/tests/__init__.py
--- a/swh/loader/core/tests/__init__.py
+++ b/swh/loader/core/tests/__init__.py
@@ -12,6 +12,9 @@
 from unittest import TestCase
 
 from swh.model import hashutil
+from swh.model.hashutil import hash_to_bytes
+
+from swh.storage.in_memory import Storage
 
 
 @pytest.mark.fs
@@ -49,6 +52,7 @@
     def setUp(self, archive_name, *, start_path, filename=None,
               resources_path='resources', prefix_tmp_folder_name='',
               uncompress_archive=True):
+        self.storage = Storage()
         repo_path = os.path.join(start_path, resources_path, archive_name)
         if not uncompress_archive:
             # In that case, simply sets the archive's path
@@ -78,14 +82,13 @@
         if self.tmp_root_path and os.path.exists(self.tmp_root_path):
             shutil.rmtree(self.tmp_root_path)
 
-    def state(self, _type):
-        return self.loader.state(_type)
-
     def _assertCountOk(self, type, expected_length, msg=None):
         """Check typed 'type' state to have the same expected length.
 
         """
-        self.assertEqual(len(self.state(type)), expected_length, msg=msg)
+        self.storage.refresh_stat_counters()
+        self.assertEqual(self.storage.stat_counters()[type],
+                         expected_length, msg=msg)
 
     def assertCountContents(self, len_expected_contents, msg=None):
         self._assertCountOk('content', len_expected_contents, msg=msg)
@@ -104,15 +107,16 @@
 
     def assertContentsOk(self, expected_contents):
         self._assertCountOk('content', len(expected_contents))
-        for content in self.state('content'):
-            content_id = hashutil.hash_to_hex(content['sha1'])
-            self.assertIn(content_id, expected_contents)
+        missing = list(self.storage.content_missing(
+            {'sha1': hash_to_bytes(content_hash)}
+            for content_hash in expected_contents))
+        self.assertEqual(missing, [])
 
     def assertDirectoriesOk(self, expected_directories):
         self._assertCountOk('directory', len(expected_directories))
-        for _dir in self.state('directory'):
-            _dir_id = hashutil.hash_to_hex(_dir['id'])
-            self.assertIn(_dir_id, expected_directories)
+        missing = list(self.storage.directory_missing(
+            hash_to_bytes(dir_id) for dir_id in expected_directories))
+        self.assertEqual(missing, [])
 
     def assertReleasesOk(self, expected_releases):
         """Check the loader's releases match the expected releases.
@@ -122,9 +126,9 @@
 
         """
         self._assertCountOk('release', len(expected_releases))
-        for i, rel in enumerate(self.state('release')):
-            rel_id = hashutil.hash_to_hex(rel['id'])
-            self.assertEqual(expected_releases[i], rel_id)
+        missing = list(self.storage.release_missing(
+            hash_to_bytes(rel_id) for rel_id in expected_releases))
+        self.assertEqual(missing, [])
 
     def assertRevisionsOk(self, expected_revisions):
         """Check the loader's revisions match the expected revisions.
@@ -138,11 +142,14 @@
 
         """
         self._assertCountOk('revision', len(expected_revisions))
-        for rev in self.state('revision'):
-            rev_id = hashutil.hash_to_hex(rev['id'])
-            directory_id = hashutil.hash_to_hex(rev['directory'])
 
-            self.assertEqual(expected_revisions[rev_id], directory_id)
+        revs = list(self.storage.revision_get(
+            hash_to_bytes(rev_id) for rev_id in expected_revisions))
+        self.assertNotIn(None, revs)
+        self.assertEqual(
+            {rev['id']: rev['directory'] for rev in revs},
+            {hash_to_bytes(rev_id): hash_to_bytes(rev_dir)
+             for (rev_id, rev_dir) in expected_revisions.items()})
 
     def assertSnapshotOk(self, expected_snapshot, expected_branches=[]):
         """Check for snapshot match.
@@ -165,12 +172,10 @@
         else:
             expected_snapshot_id = expected_snapshot
 
-        snapshots = self.state('snapshot')
-        self.assertEqual(len(snapshots), 1)
+        self._assertCountOk('snapshot', 1)
 
-        snap = snapshots[0]
-        snap_id = hashutil.hash_to_hex(snap['id'])
-        self.assertEqual(snap_id, expected_snapshot_id)
+        snap = self.storage.snapshot_get(hash_to_bytes(expected_snapshot_id))
+        self.assertIsNotNone(snap)
 
         def decode_target(target):
             if not target:
@@ -192,95 +197,3 @@
             for branch, target in snap['branches'].items()
         }
         self.assertEqual(expected_branches, branches)
-
-
-class LoaderNoStorage:
-    """Mixin class to inhibit the persistence and keep in memory the data
-    sent for storage (for testing purposes).
-
-    This overrides the core loader's behavior to store in a dict the
-    swh objects.
-
-    cf. :class:`HgLoaderNoStorage`, :class:`SvnLoaderNoStorage`, etc...
-
-    """
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._state = {
-            'content': [],
-            'directory': [],
-            'revision': [],
-            'release': [],
-            'snapshot': [],
-        }
-
-    def state(self, type):
-        return self._state[type]
-
-    def _add(self, type, l):
-        """Add without duplicates and keeping the insertion order.
-
-        Args:
-            type (str): Type of objects concerned by the action
-            l ([object]): List of 'type' object
-
-        """
-        col = self.state(type)
-        for o in l:
-            if o in col:
-                continue
-            col.append(o)
-
-    def maybe_load_contents(self, all_contents):
-        self._add('content', all_contents)
-
-    def maybe_load_directories(self, all_directories):
-        self._add('directory', all_directories)
-
-    def maybe_load_revisions(self, all_revisions):
-        self._add('revision', all_revisions)
-
-    def maybe_load_releases(self, all_releases):
-        self._add('release', all_releases)
-
-    def maybe_load_snapshot(self, snapshot):
-        self._add('snapshot', [snapshot])
-
-    def send_batch_contents(self, all_contents):
-        self._add('content', all_contents)
-
-    def send_batch_directories(self, all_directories):
-        self._add('directory', all_directories)
-
-    def send_batch_revisions(self, all_revisions):
-        self._add('revision', all_revisions)
-
-    def send_batch_releases(self, all_releases):
-        self._add('release', all_releases)
-
-    def send_snapshot(self, snapshot):
-        self._add('snapshot', [snapshot])
-
-    def _store_origin_visit(self):
-        pass
-
-    def open_fetch_history(self):
-        pass
-
-    def close_fetch_history_success(self, fetch_history_id):
-        pass
-
-    def close_fetch_history_failure(self, fetch_history_id):
-        pass
-
-    def update_origin_visit(self, origin_id, visit, status):
-        pass
-
-    def close_failure(self):
-        pass
-
-    def close_success(self):
-        pass
-
-    def pre_cleanup(self):
-        pass