diff --git a/swh/indexer/storage/__init__.py b/swh/indexer/storage/__init__.py --- a/swh/indexer/storage/__init__.py +++ b/swh/indexer/storage/__init__.py @@ -442,7 +442,7 @@ yield converters.db_to_ctags(dict(zip(db.content_ctags_cols, obj))) @remote_api_endpoint('content/fossology_license') - @db_transaction_generator() + @db_transaction() def content_fossology_license_get(self, ids, db=None, cur=None): """Retrieve licenses per id. @@ -450,8 +450,8 @@ ids (iterable): sha1 checksums Yields: - dict: ``{id: facts}`` where ``facts`` is a dict with the - following keys: + dict: ``{id: facts}`` where ``facts`` is a list of + dicts with the following keys: - **licenses** ([str]): associated licenses for that content - **tool** (dict): Tool used to compute the license @@ -464,8 +464,7 @@ id_ = license['id'] d[id_].append(converters.db_to_fossology_license(license)) - for id_, facts in d.items(): - yield {id_: facts} + return d @remote_api_endpoint('content/fossology_license/add') @db_transaction() diff --git a/swh/indexer/storage/in_memory.py b/swh/indexer/storage/in_memory.py --- a/swh/indexer/storage/in_memory.py +++ b/swh/indexer/storage/in_memory.py @@ -435,11 +435,10 @@ # the new one. SubStorage.get should be updated once all other # *_get methods use the new format. # See: https://forge.softwareheritage.org/T1433 - res = {} + res = defaultdict(list) for d in self._licenses.get(ids): - res.setdefault(d.pop('id'), []).append(d) - for (id_, facts) in res.items(): - yield {id_: facts} + res[d.pop('id')].append(d) + return res def content_fossology_license_add(self, licenses, conflict_update=False): """Add licenses not present in storage. diff --git a/swh/indexer/tests/storage/test_storage.py b/swh/indexer/tests/storage/test_storage.py --- a/swh/indexer/tests/storage/test_storage.py +++ b/swh/indexer/tests/storage/test_storage.py @@ -884,8 +884,8 @@ self.storage.content_fossology_license_add([license_v1]) # when - actual_licenses = list(self.storage.content_fossology_license_get( - [self.sha1_1])) + actual_licenses = self.storage.content_fossology_license_get( + [self.sha1_1]) # then expected_license = { @@ -894,7 +894,7 @@ 'tool': tool, }] } - self.assertEqual(actual_licenses, [expected_license]) + self.assertEqual(actual_licenses, expected_license) # given license_v2 = license_v1.copy() @@ -904,8 +904,8 @@ self.storage.content_fossology_license_add([license_v2]) - actual_licenses = list(self.storage.content_fossology_license_get( - [self.sha1_1])) + actual_licenses = self.storage.content_fossology_license_get( + [self.sha1_1]) expected_license = { self.sha1_1: [{ @@ -915,7 +915,7 @@ } # license did not change as the v2 was dropped. - self.assertEqual(actual_licenses, [expected_license]) + self.assertEqual(actual_licenses, expected_license) # content_metadata tests ( diff --git a/swh/indexer/tests/test_fossology_license.py b/swh/indexer/tests/test_fossology_license.py --- a/swh/indexer/tests/test_fossology_license.py +++ b/swh/indexer/tests/test_fossology_license.py @@ -8,6 +8,7 @@ import pytest +from swh.model.hashutil import hash_to_bytes from swh.indexer import fossology_license from swh.indexer.fossology_license import ( FossologyLicenseIndexer, FossologyLicenseRangeIndexer, @@ -80,7 +81,7 @@ """ def get_indexer_results(self, ids): - yield from self.idx_storage.content_fossology_license_get(ids) + return self.idx_storage.content_fossology_license_get(ids) def setUp(self): super().setUp() @@ -102,24 +103,32 @@ for (k, v) in self.indexer.tool.items()} # then self.expected_results = { - self.id0: { + self.id0: [{ 'tool': tool, 'licenses': SHA1_TO_LICENSES[self.id0], - }, - self.id1: { + }], + self.id1: [{ 'tool': tool, 'licenses': SHA1_TO_LICENSES[self.id1], - }, - self.id2: { + }], + self.id2: [{ 'tool': tool, 'licenses': SHA1_TO_LICENSES[self.id2], - } + }] } def tearDown(self): super().tearDown() fossology_license.compute_license = self.orig_compute_license + def assert_results_ok(self, sha1s, expected_results=None): + sha1s = [sha1 if isinstance(sha1, bytes) else hash_to_bytes(sha1) + for sha1 in sha1s] + if not expected_results: + expected_results = self.expected_results + actual_results = self.get_indexer_results(sha1s) + self.assertEqual(expected_results, actual_results) + class TestFossologyLicenseRangeIndexer( CommonContentIndexerRangeTest, unittest.TestCase):