diff --git a/swh/indexer/fossology_license.py b/swh/indexer/fossology_license.py --- a/swh/indexer/fossology_license.py +++ b/swh/indexer/fossology_license.py @@ -9,6 +9,7 @@ from swh.core.config import merge_configs from swh.indexer.storage.interface import IndexerStorageInterface, PagedResult, Sha1 +from swh.indexer.storage.model import ContentLicenseRow from swh.model import hashutil from swh.model.model import Revision @@ -17,7 +18,7 @@ logger = logging.getLogger(__name__) -def compute_license(path): +def compute_license(path) -> Dict: """Determine license from file at path. Args: @@ -83,7 +84,7 @@ def index( self, id: Union[bytes, Dict, Revision], data: Optional[bytes] = None, **kwargs - ) -> Dict[str, Any]: + ) -> List[ContentLicenseRow]: """Index sha1s' content and store result. Args: @@ -107,10 +108,15 @@ working_directory=self.working_directory, ) as content_path: properties = compute_license(path=content_path) - return [properties] + return [ + ContentLicenseRow( + id=id, indexer_configuration_id=self.tool["id"], license=license, + ) + for license in properties["licenses"] + ] def persist_index_computations( - self, results: List[Dict], policy_update: str + self, results: List[ContentLicenseRow], policy_update: str ) -> Dict[str, int]: """Persist the results in storage. @@ -131,7 +137,9 @@ ) -class FossologyLicenseIndexer(MixinFossologyLicenseIndexer, ContentIndexer[Dict]): +class FossologyLicenseIndexer( + MixinFossologyLicenseIndexer, ContentIndexer[ContentLicenseRow] +): """Indexer in charge of: - filtering out content already indexed @@ -151,7 +159,7 @@ class FossologyLicensePartitionIndexer( - MixinFossologyLicenseIndexer, ContentPartitionIndexer[Dict] + MixinFossologyLicenseIndexer, ContentPartitionIndexer[ContentLicenseRow] ): """FossologyLicense Range Indexer working on range/partition of content identifiers. 
diff --git a/swh/indexer/storage/__init__.py b/swh/indexer/storage/__init__.py --- a/swh/indexer/storage/__init__.py +++ b/swh/indexer/storage/__init__.py @@ -4,7 +4,7 @@ # See top-level LICENSE file for more information -from collections import Counter, defaultdict +from collections import Counter import itertools import json from typing import Dict, Iterable, List, Optional, Tuple @@ -357,33 +357,31 @@ @timed @db_transaction_generator() - def content_fossology_license_get(self, ids, db=None, cur=None): - d = defaultdict(list) + def content_fossology_license_get( + self, ids: Iterable[Sha1], db=None, cur=None + ) -> Iterable[ContentLicenseRow]: for c in db.content_fossology_license_get_from_list(ids, cur): - license = dict(zip(db.content_fossology_license_cols, c)) - - id_ = license["id"] - d[id_].append(converters.db_to_fossology_license(license)) - - for id_, facts in d.items(): - yield {id_: facts} + yield ContentLicenseRow.from_dict( + converters.db_to_fossology_license( + dict(zip(db.content_fossology_license_cols, c)) + ) + ) @timed @process_metrics @db_transaction() def content_fossology_license_add( - self, licenses: List[Dict], conflict_update: bool = False, db=None, cur=None + self, + licenses: List[ContentLicenseRow], + conflict_update: bool = False, + db=None, + cur=None, ) -> Dict[str, int]: - rows = list( - itertools.chain.from_iterable( - map(converters.fossology_license_to_db, licenses) - ) - ) - check_id_duplicates(map(ContentLicenseRow.from_dict, rows)) - licenses.sort(key=lambda m: m["id"]) + check_id_duplicates(licenses) + licenses.sort(key=lambda m: m.id) db.mktemp_content_fossology_license(cur) db.copy_to( - rows, + [license.to_dict() for license in licenses], tblname="tmp_content_fossology_license", columns=["id", "license", "indexer_configuration_id"], cur=cur, diff --git a/swh/indexer/storage/converters.py b/swh/indexer/storage/converters.py --- a/swh/indexer/storage/converters.py +++ b/swh/indexer/storage/converters.py @@ -42,18 +42,6 @@ 
} -def fossology_license_to_db(licenses): - """Similar to ctags_to_db, but for licenses.""" - id = licenses["id"] - tool_id = licenses["indexer_configuration_id"] - for license in licenses["licenses"]: - yield { - "id": id, - "indexer_configuration_id": tool_id, - "license": license, - } - - def db_to_ctags(ctag): """Convert a ctags entry into a ready ctags entry. @@ -142,7 +130,8 @@ def db_to_fossology_license(license): return { - "licenses": license["licenses"], + "id": license["id"], + "license": license["license"], "tool": { "id": license["tool_id"], "name": license["tool_name"], diff --git a/swh/indexer/storage/db.py b/swh/indexer/storage/db.py --- a/swh/indexer/storage/db.py +++ b/swh/indexer/storage/db.py @@ -98,13 +98,15 @@ return "%s.id" % main_table elif key == "tool_id": return "i.id as tool_id" - elif key == "licenses": + elif key == "license": return ( """ - array(select name - from fossology_license - where id = ANY( - array_agg(%s.license_id))) as licenses""" + ( + select name + from fossology_license + where id = %s.license_id + ) + as licenses""" % main_table ) return key @@ -294,7 +296,7 @@ "tool_name", "tool_version", "tool_configuration", - "licenses", + "license", ] @stored_procedure("swh_mktemp_content_fossology_license") @@ -325,8 +327,6 @@ inner join content_fossology_license c on t.id=c.id inner join indexer_configuration i on i.id=c.indexer_configuration_id - group by c.id, i.id, i.tool_name, i.tool_version, - i.tool_configuration; """ % ", ".join(keys), ((_id,) for _id in ids), diff --git a/swh/indexer/storage/in_memory.py b/swh/indexer/storage/in_memory.py --- a/swh/indexer/storage/in_memory.py +++ b/swh/indexer/storage/in_memory.py @@ -345,36 +345,15 @@ for ctag in ctags: yield {"id": id_, "tool": tool, **ctag} - def content_fossology_license_get(self, ids): - # Rewrites the output of SubStorage.get from the old format to - # the new one. SubStorage.get should be updated once all other - # *_get methods use the new format. 
- # See: https://forge.softwareheritage.org/T1433 - for id_ in ids: - items = {} - for obj in self._licenses.get([id_]): - items.setdefault(obj.tool["id"], (obj.tool, []))[1].append(obj.license) - if items: - yield { - id_: [ - {"tool": tool, "licenses": licenses} - for (tool, licenses) in items.values() - ] - } + def content_fossology_license_get( + self, ids: Iterable[Sha1] + ) -> Iterable[ContentLicenseRow]: + return self._licenses.get(ids) def content_fossology_license_add( - self, licenses: List[Dict], conflict_update: bool = False + self, licenses: List[ContentLicenseRow], conflict_update: bool = False ) -> Dict[str, int]: - check_id_types(licenses) - added = self._licenses.add( - map( - ContentLicenseRow.from_dict, - itertools.chain.from_iterable( - map(converters.fossology_license_to_db, licenses) - ), - ), - conflict_update, - ) + added = self._licenses.add(licenses, conflict_update) return {"content_fossology_license:add": added} def content_fossology_license_get_partition( diff --git a/swh/indexer/storage/interface.py b/swh/indexer/storage/interface.py --- a/swh/indexer/storage/interface.py +++ b/swh/indexer/storage/interface.py @@ -7,7 +7,7 @@ from swh.core.api import remote_api_endpoint from swh.core.api.classes import PagedResult as CorePagedResult -from swh.indexer.storage.model import ContentMimetypeRow +from swh.indexer.storage.model import ContentLicenseRow, ContentMimetypeRow TResult = TypeVar("TResult") PagedResult = CorePagedResult[TResult, str] @@ -79,7 +79,7 @@ Args: mimetypes: mimetype rows to be added, with their `tool` attribute set to - not None. + None. conflict_update: Flag to determine if we want to overwrite (``True``) or skip duplicates (``False``, the default) @@ -233,34 +233,30 @@ ... @remote_api_endpoint("content/fossology_license") - def content_fossology_license_get(self, ids): + def content_fossology_license_get( + self, ids: Iterable[Sha1] + ) -> Iterable[ContentLicenseRow]: """Retrieve licenses per id. 
Args: - ids (iterable): sha1 checksums + ids: sha1 identifiers Yields: - dict: ``{id: facts}`` where ``facts`` is a dict with the - following keys: - - - **licenses** ([str]): associated licenses for that content - - **tool** (dict): Tool used to compute the license + license rows; possibly more than one per (sha1, tool_id) if there + are multiple licenses. """ ... @remote_api_endpoint("content/fossology_license/add") def content_fossology_license_add( - self, licenses: List[Dict], conflict_update: bool = False + self, licenses: List[ContentLicenseRow], conflict_update: bool = False ) -> Dict[str, int]: """Add licenses not present in storage. Args: - licenses (iterable): dictionaries with keys: - - - **id**: sha1 - - **licenses** ([bytes]): List of licenses associated to sha1 - - **tool** (str): nomossa + licenses: license rows to be added, with their `tool` attribute set to + None. conflict_update: Flag to determine if we want to overwrite (true) or skip duplicates (false, the default) diff --git a/swh/indexer/tests/storage/conftest.py b/swh/indexer/tests/storage/conftest.py --- a/swh/indexer/tests/storage/conftest.py +++ b/swh/indexer/tests/storage/conftest.py @@ -8,7 +8,7 @@ import pytest from swh.indexer.storage import get_indexer_storage -from swh.indexer.storage.model import ContentMimetypeRow +from swh.indexer.storage.model import ContentLicenseRow, ContentMimetypeRow from swh.model.hashutil import hash_to_bytes from swh.storage.pytest_plugin import postgresql_fact @@ -53,8 +53,13 @@ ] swh_indexer_storage.content_mimetype_add(data.mimetypes) data.fossology_licenses = [ - {**fossology_obj, "indexer_configuration_id": tools["nomos"]["id"]} + ContentLicenseRow( + id=fossology_obj["id"], + indexer_configuration_id=tools["nomos"]["id"], + license=license, + ) for fossology_obj in FOSSOLOGY_LICENSES + for license in fossology_obj["licenses"] ] swh_indexer_storage._test_data = data diff --git a/swh/indexer/tests/storage/test_converters.py 
b/swh/indexer/tests/storage/test_converters.py --- a/swh/indexer/tests/storage/test_converters.py +++ b/swh/indexer/tests/storage/test_converters.py @@ -136,11 +136,12 @@ "tool_name": "nomossa", "tool_version": "5.22", "tool_configuration": {}, - "licenses": ["GPL2.0"], + "license": "GPL2.0", } expected_license = { - "licenses": ["GPL2.0"], + "id": b"some-id", + "license": "GPL2.0", "tool": {"id": 20, "name": "nomossa", "version": "5.22", "configuration": {},}, } diff --git a/swh/indexer/tests/storage/test_storage.py b/swh/indexer/tests/storage/test_storage.py --- a/swh/indexer/tests/storage/test_storage.py +++ b/swh/indexer/tests/storage/test_storage.py @@ -12,11 +12,13 @@ from swh.indexer.storage.exc import DuplicateId, IndexerStorageArgumentException from swh.indexer.storage.interface import IndexerStorageInterface -from swh.indexer.storage.model import BaseRow, ContentMimetypeRow +from swh.indexer.storage.model import BaseRow, ContentLicenseRow, ContentMimetypeRow from swh.model.hashutil import hash_to_bytes -def prepare_mimetypes_from(fossology_licenses: List[Dict]) -> List[ContentMimetypeRow]: +def prepare_mimetypes_from_licenses( + fossology_licenses: List[ContentLicenseRow], +) -> List[ContentMimetypeRow]: """Fossology license needs some consistent data in db to run. 
""" @@ -24,10 +26,10 @@ for c in fossology_licenses: mimetypes.append( ContentMimetypeRow( - id=c["id"], + id=c.id, mimetype="text/plain", # for filtering on textual data to work encoding="utf-8", - indexer_configuration_id=c["indexer_configuration_id"], + indexer_configuration_id=c.indexer_configuration_id, ) ) return mimetypes @@ -1004,6 +1006,9 @@ endpoint_type = "content_fossology_license" tool_name = "nomos" + row_from_dict = ContentLicenseRow.from_dict + dict_from_row = staticmethod(lambda x: x.to_dict()) + def test_content_fossology_license_add__new_license_added( self, swh_indexer_storage_with_data: Tuple[IndexerStorageInterface, Any] ) -> None: @@ -1012,40 +1017,39 @@ tool = data.tools["nomos"] tool_id = tool["id"] - license_v1 = { - "id": data.sha1_1, - "licenses": ["Apache-2.0"], - "indexer_configuration_id": tool_id, - } + license1 = ContentLicenseRow( + id=data.sha1_1, license="Apache-2.0", indexer_configuration_id=tool_id, + ) # given - storage.content_fossology_license_add([license_v1]) + storage.content_fossology_license_add([license1]) # conflict does nothing - storage.content_fossology_license_add([license_v1]) + storage.content_fossology_license_add([license1]) # when actual_licenses = list(storage.content_fossology_license_get([data.sha1_1])) # then - expected_license = {data.sha1_1: [{"licenses": ["Apache-2.0"], "tool": tool,}]} - assert actual_licenses == [expected_license] + expected_licenses = [ + ContentLicenseRow(id=data.sha1_1, license="Apache-2.0", tool=tool,) + ] + assert actual_licenses == expected_licenses # given - license_v2 = license_v1.copy() - license_v2.update( - {"licenses": ["BSD-2-Clause"],} + license2 = ContentLicenseRow( + id=data.sha1_1, license="BSD-2-Clause", indexer_configuration_id=tool_id, ) - storage.content_fossology_license_add([license_v2]) + storage.content_fossology_license_add([license2]) actual_licenses = list(storage.content_fossology_license_get([data.sha1_1])) - expected_license = { - data.sha1_1: 
[{"licenses": ["Apache-2.0", "BSD-2-Clause"], "tool": tool}] - } + expected_licenses.append( + ContentLicenseRow(id=data.sha1_1, license="BSD-2-Clause", tool=tool,) + ) - # license did not change as the v2 was dropped. - assert actual_licenses == [expected_license] + # first license was not removed when the second one was added + assert sorted(actual_licenses) == sorted(expected_licenses) def test_generate_content_fossology_license_get_partition_failure( self, swh_indexer_storage_with_data: Tuple[IndexerStorageInterface, Any] @@ -1067,15 +1071,15 @@ storage, data = swh_indexer_storage_with_data # craft some consistent mimetypes fossology_licenses = data.fossology_licenses - mimetypes = prepare_mimetypes_from(fossology_licenses) - indexer_configuration_id = fossology_licenses[0]["indexer_configuration_id"] + mimetypes = prepare_mimetypes_from_licenses(fossology_licenses) + indexer_configuration_id = fossology_licenses[0].indexer_configuration_id storage.content_mimetype_add(mimetypes, conflict_update=True) # add fossology_licenses to storage storage.content_fossology_license_add(fossology_licenses) # All ids from the db - expected_ids = set([c["id"] for c in fossology_licenses]) + expected_ids = set([c.id for c in fossology_licenses]) assert len(fossology_licenses) == 10 assert len(mimetypes) == 10 @@ -1103,15 +1107,15 @@ storage, data = swh_indexer_storage_with_data # craft some consistent mimetypes fossology_licenses = data.fossology_licenses - mimetypes = prepare_mimetypes_from(fossology_licenses) - indexer_configuration_id = fossology_licenses[0]["indexer_configuration_id"] + mimetypes = prepare_mimetypes_from_licenses(fossology_licenses) + indexer_configuration_id = fossology_licenses[0].indexer_configuration_id storage.content_mimetype_add(mimetypes, conflict_update=True) # add fossology_licenses to storage storage.content_fossology_license_add(fossology_licenses) # All ids from the db - expected_ids = set([c["id"] for c in fossology_licenses]) + expected_ids 
= set([c.id for c in fossology_licenses]) actual_result = storage.content_fossology_license_get_partition( indexer_configuration_id, 0, 1 @@ -1129,15 +1133,15 @@ storage, data = swh_indexer_storage_with_data # craft some consistent mimetypes fossology_licenses = data.fossology_licenses - mimetypes = prepare_mimetypes_from(fossology_licenses) - indexer_configuration_id = fossology_licenses[0]["indexer_configuration_id"] + mimetypes = prepare_mimetypes_from_licenses(fossology_licenses) + indexer_configuration_id = fossology_licenses[0].indexer_configuration_id storage.content_mimetype_add(mimetypes, conflict_update=True) # add fossology_licenses to storage storage.content_fossology_license_add(fossology_licenses) # All ids from the db - expected_ids = set([c["id"] for c in fossology_licenses]) + expected_ids = set([c.id for c in fossology_licenses]) # nb_partitions = smallest power of 2 such that at least one of # the partitions is empty @@ -1171,15 +1175,15 @@ storage, data = swh_indexer_storage_with_data # craft some consistent mimetypes fossology_licenses = data.fossology_licenses - mimetypes = prepare_mimetypes_from(fossology_licenses) - indexer_configuration_id = fossology_licenses[0]["indexer_configuration_id"] + mimetypes = prepare_mimetypes_from_licenses(fossology_licenses) + indexer_configuration_id = fossology_licenses[0].indexer_configuration_id storage.content_mimetype_add(mimetypes, conflict_update=True) # add fossology_licenses to storage storage.content_fossology_license_add(fossology_licenses) # All ids from the db - expected_ids = [c["id"] for c in fossology_licenses] + expected_ids = [c.id for c in fossology_licenses] nb_partitions = 4 @@ -1208,17 +1212,8 @@ ) -> None: (storage, data) = swh_indexer_storage_with_data etype = self.endpoint_type - tool = data.tools[self.tool_name] - summary = endpoint(storage, etype, "add")( - [ - { - "id": data.sha1_2, - "indexer_configuration_id": tool["id"], - "licenses": [], - } - ] - ) + summary = 
endpoint(storage, etype, "add")([]) assert summary == {"content_fossology_license:add": 0} actual_license = list(endpoint(storage, etype, "get")([data.sha1_2])) diff --git a/swh/indexer/tests/test_ctags.py b/swh/indexer/tests/test_ctags.py --- a/swh/indexer/tests/test_ctags.py +++ b/swh/indexer/tests/test_ctags.py @@ -20,6 +20,7 @@ fill_storage, filter_dict, ) +from swh.model.hashutil import hash_to_bytes class BasicTest(unittest.TestCase): @@ -87,8 +88,6 @@ """ - legacy_get_format = True - def get_indexer_results(self, ids): yield from self.idx_storage.content_ctags_get(ids) @@ -107,11 +106,23 @@ tool = {k.replace("tool_", ""): v for (k, v) in self.indexer.tool.items()} - self.expected_results = { - self.id0: {"id": self.id0, "tool": tool, **SHA1_TO_CTAGS[self.id0][0],}, - self.id1: {"id": self.id1, "tool": tool, **SHA1_TO_CTAGS[self.id1][0],}, - self.id2: {"id": self.id2, "tool": tool, **SHA1_TO_CTAGS[self.id2][0],}, - } + self.expected_results = [ + { + "id": hash_to_bytes(self.id0), + "tool": tool, + **SHA1_TO_CTAGS[self.id0][0], + }, + { + "id": hash_to_bytes(self.id1), + "tool": tool, + **SHA1_TO_CTAGS[self.id1][0], + }, + { + "id": hash_to_bytes(self.id2), + "tool": tool, + **SHA1_TO_CTAGS[self.id2][0], + }, + ] self._set_mocks() diff --git a/swh/indexer/tests/test_fossology_license.py b/swh/indexer/tests/test_fossology_license.py --- a/swh/indexer/tests/test_fossology_license.py +++ b/swh/indexer/tests/test_fossology_license.py @@ -25,6 +25,7 @@ fill_storage, filter_dict, ) +from swh.model.hashutil import hash_to_bytes class BasicTest(unittest.TestCase): @@ -98,11 +99,21 @@ tool = {k.replace("tool_", ""): v for (k, v) in self.indexer.tool.items()} # then - self.expected_results = { - self.id0: {"tool": tool, "licenses": SHA1_TO_LICENSES[self.id0],}, - self.id1: {"tool": tool, "licenses": SHA1_TO_LICENSES[self.id1],}, - self.id2: None, - } + self.expected_results = [ + *[ + ContentLicenseRow( + id=hash_to_bytes(self.id0), tool=tool, license=license + ) + for 
license in SHA1_TO_LICENSES[self.id0] + ], + *[ + ContentLicenseRow( + id=hash_to_bytes(self.id1), tool=tool, license=license + ) + for license in SHA1_TO_LICENSES[self.id1] + ], + *[], # self.id2 + ] def tearDown(self): super().tearDown() @@ -137,21 +148,6 @@ super().tearDown() fossology_license.compute_license = self.orig_compute_license - def assert_results_ok(self, partition_id, nb_partitions, actual_results): - # TODO: remove this method when fossology_license endpoints moved away - # from dicts. - actual_result_rows = [] - for res in actual_results: - for license in res["licenses"]: - actual_result_rows.append( - ContentLicenseRow( - id=res["id"], - indexer_configuration_id=res["indexer_configuration_id"], - license=license, - ) - ) - super().assert_results_ok(partition_id, nb_partitions, actual_result_rows) - def test_fossology_w_no_tool(): with pytest.raises(ValueError): diff --git a/swh/indexer/tests/test_mimetype.py b/swh/indexer/tests/test_mimetype.py --- a/swh/indexer/tests/test_mimetype.py +++ b/swh/indexer/tests/test_mimetype.py @@ -13,6 +13,7 @@ MimetypePartitionIndexer, compute_mimetype_encoding, ) +from swh.indexer.storage.model import ContentMimetypeRow from swh.indexer.tests.utils import ( BASE_TEST_CONFIG, CommonContentIndexerPartitionTest, @@ -21,6 +22,7 @@ fill_storage, filter_dict, ) +from swh.model.hashutil import hash_to_bytes def test_compute_mimetype_encoding(): @@ -52,10 +54,8 @@ """ - legacy_get_format = True - def get_indexer_results(self, ids): - yield from (x.to_dict() for x in self.idx_storage.content_mimetype_get(ids)) + yield from self.idx_storage.content_mimetype_get(ids) def setUp(self): self.indexer = MimetypeIndexer(config=CONFIG) @@ -70,26 +70,26 @@ tool = {k.replace("tool_", ""): v for (k, v) in self.indexer.tool.items()} - self.expected_results = { - self.id0: { - "id": self.id0, - "tool": tool, - "mimetype": "text/plain", - "encoding": "us-ascii", - }, - self.id1: { - "id": self.id1, - "tool": tool, - "mimetype": 
"text/plain", - "encoding": "us-ascii", - }, - self.id2: { - "id": self.id2, - "tool": tool, - "mimetype": "application/x-empty", - "encoding": "binary", - }, - } + self.expected_results = [ + ContentMimetypeRow( + id=hash_to_bytes(self.id0), + tool=tool, + mimetype="text/plain", + encoding="us-ascii", + ), + ContentMimetypeRow( + id=hash_to_bytes(self.id1), + tool=tool, + mimetype="text/plain", + encoding="us-ascii", + ), + ContentMimetypeRow( + id=hash_to_bytes(self.id2), + tool=tool, + mimetype="application/x-empty", + encoding="binary", + ), + ] RANGE_CONFIG = dict(list(CONFIG.items()) + [("write_batch_size", 100)]) diff --git a/swh/indexer/tests/utils.py b/swh/indexer/tests/utils.py --- a/swh/indexer/tests/utils.py +++ b/swh/indexer/tests/utils.py @@ -585,44 +585,11 @@ class CommonContentIndexerTest(metaclass=abc.ABCMeta): - legacy_get_format = False - """True if and only if the tested indexer uses the legacy format. - see: https://forge.softwareheritage.org/T1433 - - """ - def get_indexer_results(self, ids): """Override this for indexers that don't have a mock storage.""" return self.indexer.idx_storage.state - def assert_legacy_results_ok(self, sha1s, expected_results=None): - # XXX old format, remove this when all endpoints are - # updated to the new one - # see: https://forge.softwareheritage.org/T1433 - sha1s = [ - sha1 if isinstance(sha1, bytes) else hash_to_bytes(sha1) for sha1 in sha1s - ] - actual_results = list(self.get_indexer_results(sha1s)) - - if expected_results is None: - expected_results = self.expected_results - - self.assertEqual( - len(expected_results), - len(actual_results), - (expected_results, actual_results), - ) - for indexed_data in actual_results: - _id = indexed_data["id"] - expected_data = expected_results[hashutil.hash_to_hex(_id)].copy() - expected_data["id"] = _id - self.assertEqual(indexed_data, expected_data) - def assert_results_ok(self, sha1s, expected_results=None): - if self.legacy_get_format: - 
self.assert_legacy_results_ok(sha1s, expected_results) - return - sha1s = [ sha1 if isinstance(sha1, bytes) else hash_to_bytes(sha1) for sha1 in sha1s ] @@ -631,19 +598,7 @@ if expected_results is None: expected_results = self.expected_results - self.assertEqual( - sum(res is not None for res in expected_results.values()), - sum(sum(map(len, res.values())) for res in actual_results), - (expected_results, actual_results), - ) - for indexed_data in actual_results: - (_id, indexed_data) = list(indexed_data.items())[0] - if expected_results.get(hashutil.hash_to_hex(_id)) is None: - self.assertEqual(indexed_data, []) - else: - expected_data = expected_results[hashutil.hash_to_hex(_id)].copy() - expected_data = [expected_data] - self.assertEqual(indexed_data, expected_data) + self.assertEqual(expected_results, actual_results) def test_index(self): """Known sha1 have their data indexed @@ -673,9 +628,12 @@ self.indexer.run(sha1s, policy_update="update-dups") # then - expected_results = { - k: v for k, v in self.expected_results.items() if k in sha1s - } + # TODO: unconditionally use res.id when all endpoints moved away from dicts + expected_results = [ + res + for res in self.expected_results + if hashutil.hash_to_hex(getattr(res, "id", None) or res["id"]) in sha1s + ] self.assert_results_ok(sha1s, expected_results)