diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -104,7 +104,7 @@
 
         return summary
 
-    def content_add(self, contents):
+    def content_add(self, content):
         """Add content blobs to the storage
 
         Args:
@@ -133,13 +133,13 @@
                 skipped_content:add: New skipped contents (no data) added
 
         """
-        contents = [dict(c.items()) for c in contents]  # semi-shallow copy
+        content = [dict(c.items()) for c in content]  # semi-shallow copy
         now = datetime.datetime.now(tz=datetime.timezone.utc)
-        for item in contents:
+        for item in content:
             item['ctime'] = now
-        return self._content_add(contents, with_data=True)
+        return self._content_add(content, with_data=True)
 
-    def content_add_metadata(self, contents):
+    def content_add_metadata(self, content):
         """Add content metadata to the storage (like `content_add`, but
         without inserting to the objstorage).
@@ -168,9 +168,9 @@
                 skipped_content:add: New skipped contents (no data) added
 
         """
-        return self._content_add(contents, with_data=False)
+        return self._content_add(content, with_data=False)
 
-    def content_get(self, ids):
+    def content_get(self, content):
         """Retrieve in bulk contents and their data.
 
         This function may yield more blobs than provided sha1 identifiers,
@@ -192,10 +192,10 @@
 
         """
         # FIXME: Make this method support slicing the `data`.
-        if len(ids) > BULK_BLOCK_CONTENT_LEN_MAX:
+        if len(content) > BULK_BLOCK_CONTENT_LEN_MAX:
             raise ValueError(
                 "Sending at most %s contents." % BULK_BLOCK_CONTENT_LEN_MAX)
-        for obj_id in ids:
+        for obj_id in content:
             try:
                 data = self.objstorage.get(obj_id)
             except ObjNotFoundError:
@@ -248,7 +248,7 @@
             'next': next_content,
         }
 
-    def content_get_metadata(self, sha1s):
+    def content_get_metadata(self, content):
         """Retrieve content metadata in bulk
 
         Args:
@@ -259,7 +259,7 @@
         """
         # FIXME: the return value should be a mapping from search key to found
         # content*s*
-        for sha1 in sha1s:
+        for sha1 in content:
             if sha1 in self._content_indexes['sha1']:
                 objs = self._content_indexes['sha1'][sha1]
                 # FIXME: rather than selecting one of the objects with that
@@ -296,7 +296,7 @@
         keys = list(set.intersection(*found))
         return copy.deepcopy([self._contents[key] for key in keys])
 
-    def content_missing(self, contents, key_hash='sha1'):
+    def content_missing(self, content, key_hash='sha1'):
         """List content missing from storage
 
         Args:
@@ -313,17 +313,17 @@
             iterable ([bytes]): missing content ids (as per the
             key_hash column)
 
         """
-        for content in contents:
-            for (algo, hash_) in content.items():
+        for cont in content:
+            for (algo, hash_) in cont.items():
                 if algo not in DEFAULT_ALGORITHMS:
                     continue
                 if hash_ not in self._content_indexes.get(algo, []):
-                    yield content[key_hash]
+                    yield cont[key_hash]
                     break
             else:
-                for result in self.content_find(content):
+                for result in self.content_find(cont):
                     if result['status'] == 'missing':
-                        yield content[key_hash]
+                        yield cont[key_hash]
 
     def content_missing_per_sha1(self, contents):
         """List content missing from storage based only on sha1.
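
The renames above are not cosmetic: the RPC layer serializes storage-method arguments by keyword, so the in-memory backend is only interchangeable with the postgresql one behind swh.storage.api if both spell their parameters the same way (`content`, not `contents`/`ids`/`sha1s`). Below is a minimal sketch of that name-based dispatch; `FakeServer` and `handle` are hypothetical stand-ins for the real server glue, not part of the codebase:

    import inspect


    class FakeServer:
        """Toy RPC server: re-applies the keyword arguments sent by a
        client to whichever storage backend is configured."""

        def __init__(self, backend):
            self.backend = backend

        def handle(self, method_name, kwargs):
            method = getattr(self.backend, method_name)
            # bind() raises TypeError when the backend spells a parameter
            # differently (e.g. `contents` where the client sent `content`)
            inspect.signature(method).bind(**kwargs)
            return method(**kwargs)

With the renames applied, `handle('content_get', {'content': [sha1]})` resolves the same way against the in-memory backend as against swh.storage.storage.Storage.
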
@@ -379,7 +379,7 @@
 
         return {'directory:add': count}
 
-    def directory_missing(self, directory_ids):
+    def directory_missing(self, directories):
         """List directories missing from storage
 
         Args:
@@ -389,7 +389,7 @@
             missing directory ids
 
         """
-        for id in directory_ids:
+        for id in directories:
             if id not in self._directories:
                 yield id
 
@@ -423,7 +423,7 @@
                 yield from self._directory_ls(
                     ret['target'], True, prefix + ret['name'] + b'/')
 
-    def directory_ls(self, directory_id, recursive=False):
+    def directory_ls(self, directory, recursive=False):
         """Get entries for one directory.
 
         Args:
@@ -436,7 +436,7 @@
         If `recursive=True`, names in the path of a dir/file not at the root
         are concatenated with a slash (`/`).
         """
-        yield from self._directory_ls(directory_id, recursive)
+        yield from self._directory_ls(directory, recursive)
 
     def directory_entry_get_by_path(self, directory, paths):
         """Get the directory entry (either file or dir) from directory with
         path.
@@ -529,7 +529,7 @@
 
         return {'revision:add': count}
 
-    def revision_missing(self, revision_ids):
+    def revision_missing(self, revisions):
         """List revisions missing from storage
 
         Args:
@@ -539,12 +539,12 @@
             missing revision ids
 
         """
-        for id in revision_ids:
+        for id in revisions:
             if id not in self._revisions:
                 yield id
 
-    def revision_get(self, revision_ids):
-        for id in revision_ids:
+    def revision_get(self, revisions):
+        for id in revisions:
             yield copy.deepcopy(self._revisions.get(id))
 
     def _get_parent_revs(self, rev_id, seen, limit):
@@ -557,7 +557,7 @@
             for parent in self._revisions[rev_id]['parents']:
                 yield from self._get_parent_revs(parent, seen, limit)
 
-    def revision_log(self, revision_ids, limit=None):
+    def revision_log(self, revisions, limit=None):
         """Fetch revision entry from the given root revisions.
 
         Args:
@@ -569,7 +569,7 @@
 
         """
         seen = set()
-        for rev_id in revision_ids:
+        for rev_id in revisions:
             yield from self._get_parent_revs(rev_id, seen, limit)
 
     def revision_shortlog(self, revisions, limit=None):
@@ -655,7 +655,7 @@
         for rel_id in releases:
             yield copy.deepcopy(self._releases.get(rel_id))
 
-    def snapshot_add(self, snapshots, legacy_arg1=None, legacy_arg2=None):
+    def snapshot_add(self, snapshots, origin=None, visit=None):
        """Add a snapshot to the storage
 
         Args:
@@ -684,12 +684,23 @@
             snapshot_added: Count of object actually stored in db
 
         """
-        if legacy_arg1:
-            assert legacy_arg2
-            (origin, visit, snapshots) = \
-                (snapshots, legacy_arg1, [legacy_arg2])
-        else:
-            origin = visit = None
+        if origin:
+            if not visit:
+                raise TypeError(
+                    'snapshot_add expects one argument (or, as a legacy '
+                    'behavior, three arguments), not two')
+            if isinstance(snapshots, (int, bytes)):
+                # Called by legacy code that uses the new api/client.py
+                (origin_id, visit_id, snapshots) = \
+                    (snapshots, origin, [visit])
+            else:
+                # Called by legacy code that uses the old api/client.py
+                origin_id = origin
+                visit_id = visit
+                snapshots = [snapshots]
+        else:
+            # Called by new code that uses the new api/client.py
+            origin_id = visit_id = None
 
         count = 0
         for snapshot in snapshots:
@@ -706,10 +717,10 @@
 
                 self._objects[snapshot_id].append(('snapshot', snapshot_id))
                 count += 1
 
-        if origin:
+        if visit_id:
             # Legacy API, there can be only one snapshot
             self.origin_visit_update(
-                origin, visit, snapshot=snapshots[0]['id'])
+                origin_id, visit_id, snapshot=snapshots[0]['id'])
 
         return {'snapshot:add': count}
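
The branching above keeps three calling conventions alive during the deprecation window. Assuming `storage` is any backend, `snapshot` a snapshot dict and `origin_id`/`visit_id` the identifiers of an existing visit, the accepted forms are (a sketch, not exhaustive tests):

    storage.snapshot_add([snapshot])                     # new API: one list argument
    storage.snapshot_add(origin_id, visit_id, snapshot)  # legacy order, new api/client.py
    storage.snapshot_add(snapshot, origin_id, visit_id)  # legacy order, old api/client.py

The `isinstance(snapshots, (int, bytes))` test is what tells the two legacy layouts apart: if the first positional argument is an origin id (an int or bytes value) rather than a snapshot dict, the arguments arrived in the old `(origin, visit, snapshot)` order and are shuffled back before insertion.
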
@@ -1026,7 +1037,7 @@
         origins = self._origins
         if regexp:
             pat = re.compile(url_pattern)
-            origins = [orig for orig in origins if pat.match(orig['url'])]
+            origins = [orig for orig in origins if pat.search(orig['url'])]
         else:
             origins = [orig for orig in origins if url_pattern in orig['url']]
         if with_visit:
@@ -1531,9 +1542,9 @@
 
     @staticmethod
     def _tool_key(tool):
-        return (tool['name'], tool['version'],
-                tuple(sorted(tool['configuration'].items())))
+        return '%r %r %r' % (tool['name'], tool['version'],
+                             tuple(sorted(tool['configuration'].items())))
 
     @staticmethod
     def _metadata_provider_key(provider):
-        return (provider['name'], provider['url'])
+        return '%r %r' % (provider['name'], provider['url'])
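
Two notes on the hunks above. First, `re.match` only matches at the beginning of a string while `re.search` matches anywhere, so switching to `search` presumably aligns `origin_search` with the unanchored regexp matching of the postgresql backend. Second, the key helpers move from tuples to `%r`-formatted strings, most likely because tuple keys break as soon as a configuration value is unhashable, e.g. a list, which is exactly what msgpack/JSON deserialization on the remote API produces; a formatted string is always hashable. A short illustration with a made-up tool dict:

    import re

    url = 'https://www.softwareheritage.org'
    assert re.search('heritage', url)         # matches anywhere, like an unanchored SQL regexp
    assert re.match('heritage', url) is None  # only matches at position 0

    tool = {'name': 'nomos', 'version': '3.1.0',
            'configuration': {'command_line': ['nomossa', '<filename>']}}
    # A tuple key would raise TypeError here, since it would contain the
    # unhashable list ['nomossa', '<filename>']; the string form never fails:
    key = '%r %r %r' % (tool['name'], tool['version'],
                        tuple(sorted(tool['configuration'].items())))
    index = {key: tool}
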
diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py
--- a/swh/storage/tests/test_api_client.py
+++ b/swh/storage/tests/test_api_client.py
@@ -3,15 +3,18 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-import pytest
 import shutil
 import tempfile
 import unittest
+import unittest.mock
 
+import pytest
+
 from swh.core.api.tests.server_testing import ServerTestFixture
 import swh.storage.storage as storage
 from swh.storage.journal_writer import \
     get_journal_writer, InMemoryJournalWriter
+from swh.storage.in_memory import Storage as InMemoryStorage
 from swh.storage.api.client import RemoteStorage
 import swh.storage.api.server as server
 from swh.storage.api.server import app
@@ -19,7 +22,7 @@
     CommonTestStorage, CommonPropTestStorage, StorageTestDbFixture
 
 
-class RemoteStorageFixture(ServerTestFixture, StorageTestDbFixture,
+class RemoteStorageFixture(ServerTestFixture,
                            unittest.TestCase):
     """Test the remote storage API.
 
@@ -29,6 +32,13 @@
 
     """
     def setUp(self):
+        self.app = app
+        super().setUp()
+        self.storage = RemoteStorage(self.url())
+
+
+class RemotePgStorageFixture(StorageTestDbFixture, RemoteStorageFixture):
+    def setUp(self):
         def mock_get_journal_writer(cls, args=None):
             assert cls == 'inmemory'
             return journal_writer
@@ -43,6 +53,7 @@
         # To avoid confusion, override the self.objroot to a
         # one chosen in this class.
         self.storage_base = tempfile.mkdtemp()
+        self.objroot = self.storage_base
         self.config = {
             'storage': {
                 'cls': 'local',
@@ -61,19 +72,63 @@
                 }
             }
         }
-        self.app = app
         super().setUp()
-        self.storage = RemoteStorage(self.url())
-        self.objroot = self.storage_base
 
     def tearDown(self):
-        storage.get_journal_writer = get_journal_writer
         super().tearDown()
         shutil.rmtree(self.storage_base)
+        storage.get_journal_writer = get_journal_writer
+
+
+class RemoteMemStorageFixture(RemoteStorageFixture):
+    def setUp(self):
+        self.config = {
+            'storage': {
+                'cls': 'memory',
+                'args': {
+                    'journal_writer': {
+                        'cls': 'inmemory',
+                    }
+                }
+            }
+        }
+        self.__storage = InMemoryStorage(journal_writer={'cls': 'inmemory'})
+        self._get_storage_patcher = unittest.mock.patch(
+            'swh.storage.api.server.get_storage', return_value=self.__storage)
+        self._get_storage_patcher.start()
+        super().setUp()
+        self.journal_writer = self.__storage.journal_writer
+
+    def tearDown(self):
+        super().tearDown()
+        self._get_storage_patcher.stop()
+
+
+class TestRemoteMemStorage(CommonTestStorage, RemoteMemStorageFixture):
+    @pytest.mark.skip('refresh_stat_counters not available in the remote api.')
+    def test_stat_counters(self):
+        pass
+
+    @pytest.mark.skip('postgresql-specific test')
+    def test_content_add_db(self):
+        pass
+
+    @pytest.mark.skip('postgresql-specific test')
+    def test_skipped_content_add_db(self):
+        pass
+
+    @pytest.mark.skip('postgresql-specific test')
+    def test_content_add_metadata_db(self):
+        pass
+
+    @pytest.mark.skip(
+        'not implemented, see https://forge.softwareheritage.org/T1633')
+    def test_skipped_content_add(self):
+        pass
 
 
 @pytest.mark.db
-class TestRemoteStorage(CommonTestStorage, RemoteStorageFixture):
+class TestRemotePgStorage(CommonTestStorage, RemotePgStorageFixture):
     @pytest.mark.skip('refresh_stat_counters not available in the remote api.')
     def test_stat_counters(self):
         pass
@@ -81,7 +136,7 @@
 
 @pytest.mark.db
 @pytest.mark.property_based
-class PropTestRemoteStorage(CommonPropTestStorage, RemoteStorageFixture):
+class PropTestRemotePgStorage(CommonPropTestStorage, RemotePgStorageFixture):
     @pytest.mark.skip('too slow')
     def test_add_arbitrary(self):
         pass
diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py
--- a/swh/storage/tests/test_in_memory.py
+++ b/swh/storage/tests/test_in_memory.py
@@ -42,7 +42,6 @@
         pass
 
 
-@pytest.mark.db
 @pytest.mark.property_based
 class PropTestInMemoryStorage(CommonPropTestStorage, unittest.TestCase):
     """Test the in-memory storage API
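
The new `RemoteMemStorageFixture` relies on the test server running in the same process as the test itself: patching `swh.storage.api.server.get_storage` pins the exact `Storage` instance the HTTP endpoints will use, so the test keeps a direct handle on it and on its `journal_writer`. A condensed sketch of that pattern outside of unittest, using only names that appear in the diff:

    import unittest.mock

    from swh.storage.in_memory import Storage as InMemoryStorage

    storage = InMemoryStorage(journal_writer={'cls': 'inmemory'})
    patcher = unittest.mock.patch(
        'swh.storage.api.server.get_storage', return_value=storage)
    patcher.start()
    try:
        # every request served by swh.storage.api.server now hits `storage`;
        # a test can write through RemoteStorage(url) and then inspect
        # storage.journal_writer directly, without a round trip
        pass
    finally:
        patcher.stop()
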