diff --git a/swh/indexer/indexer.py b/swh/indexer/indexer.py --- a/swh/indexer/indexer.py +++ b/swh/indexer/indexer.py @@ -57,7 +57,6 @@ } -# TODO: should be bound=Optional[BaseRow] when all endpoints move away from dicts TResult = TypeVar("TResult") @@ -412,14 +411,7 @@ self.log.warning(f"Content {sha1.hex()} not found in objstorage") continue results = self.index(sha1, raw_content, **kwargs) - for res in results: - # TODO: remove this check when all endpoints moved away from dicts. - if isinstance(res, dict) and not isinstance(res["id"], bytes): - raise TypeError( - "%r.index should return ids as bytes, not %r" - % (self.__class__.__name__, res["id"]) - ) - yield res + yield from results def _index_with_skipping_already_done( self, partition_id: int, nb_partitions: int diff --git a/swh/indexer/tests/storage/test_storage.py b/swh/indexer/tests/storage/test_storage.py --- a/swh/indexer/tests/storage/test_storage.py +++ b/swh/indexer/tests/storage/test_storage.py @@ -6,7 +6,7 @@ import inspect import math import threading -from typing import Any, Dict, List, Tuple, Union, cast +from typing import Any, Dict, List, Tuple, Type, cast import attr import pytest @@ -126,19 +126,7 @@ endpoint_type: str tool_name: str example_data: List[Dict] - - # For old endpoints (which use dicts), this is just the identity function. - # For new endpoints, it returns a BaseRow instance. - @staticmethod - def row_from_dict(d: Dict) -> Union[Dict, BaseRow]: - return d - - # Inverse function of row_from_dict - # TODO: remove this once all endpoints are migrated to rows - @staticmethod - def dict_from_row(r: Union[Dict, BaseRow]) -> Dict: - assert isinstance(r, Dict) - return r + row_class: Type[BaseRow] def test_missing( self, swh_indexer_storage_with_data: Tuple[IndexerStorageInterface, Any] @@ -163,7 +151,7 @@ # now, when we add one of them summary = endpoint(storage, etype, "add")( [ - self.row_from_dict( + self.row_class.from_dict( { "id": data.sha1_2, **self.example_data[0], @@ -192,13 +180,13 @@ **self.example_data[0], "indexer_configuration_id": tool_id, } - summary = endpoint(storage, etype, "add")([self.row_from_dict(data_v1)]) + summary = endpoint(storage, etype, "add")([self.row_class.from_dict(data_v1)]) assert summary == expected_summary(1, etype) # should be able to retrieve it actual_data = list(endpoint(storage, etype, "get")([data.sha1_2])) expected_data_v1 = [ - self.row_from_dict( + self.row_class.from_dict( { "id": data.sha1_2, **self.example_data[0], @@ -211,7 +199,7 @@ # now if we add a modified version of the same object (same id) data_v2 = data_v1.copy() data_v2.update(self.example_data[1]) - summary2 = endpoint(storage, etype, "add")([self.row_from_dict(data_v2)]) + summary2 = endpoint(storage, etype, "add")([self.row_class.from_dict(data_v2)]) assert summary2 == expected_summary(0, etype) # not added # we expect to retrieve the original data, not the modified one @@ -232,14 +220,14 @@ } # given - summary = endpoint(storage, etype, "add")([self.row_from_dict(data_v1)]) + summary = endpoint(storage, etype, "add")([self.row_class.from_dict(data_v1)]) assert summary == expected_summary(1, etype) # not added # when actual_data = list(endpoint(storage, etype, "get")([data.sha1_2])) expected_data_v1 = [ - self.row_from_dict( + self.row_class.from_dict( {"id": data.sha1_2, **self.example_data[0], "tool": tool} ) ] @@ -252,14 +240,14 @@ data_v2.update(self.example_data[1]) endpoint(storage, etype, "add")( - [self.row_from_dict(data_v2)], conflict_update=True + [self.row_class.from_dict(data_v2)], conflict_update=True ) assert summary == expected_summary(1, etype) # modified so counted actual_data = list(endpoint(storage, etype, "get")([data.sha1_2])) expected_data_v2 = [ - self.row_from_dict( + self.row_class.from_dict( {"id": data.sha1_2, **self.example_data[1], "tool": tool,} ) ] @@ -280,7 +268,7 @@ ] data_v1 = [ - self.row_from_dict( + self.row_class.from_dict( { "id": hash_, **self.example_data[0], @@ -290,7 +278,7 @@ for hash_ in hashes ] data_v2 = [ - self.row_from_dict( + self.row_class.from_dict( { "id": hash_, **self.example_data[1], @@ -312,7 +300,9 @@ actual_data = list(endpoint(storage, etype, "get")(hashes)) expected_data_v1 = [ - self.row_from_dict({"id": hash_, **self.example_data[0], "tool": tool}) + self.row_class.from_dict( + {"id": hash_, **self.example_data[0], "tool": tool} + ) for hash_ in hashes ] @@ -335,7 +325,7 @@ t2.join() actual_data = sorted( - map(self.dict_from_row, endpoint(storage, etype, "get")(hashes)), + (row.to_dict() for row in endpoint(storage, etype, "get")(hashes)), key=lambda x: x["id"], ) @@ -352,7 +342,7 @@ etype = self.endpoint_type tool = data.tools[self.tool_name] - data_rev1 = self.row_from_dict( + data_rev1 = self.row_class.from_dict( { "id": data.revision_id_2, **self.example_data[0], @@ -360,7 +350,7 @@ } ) - data_rev2 = self.row_from_dict( + data_rev2 = self.row_class.from_dict( { "id": data.revision_id_2, **self.example_data[1], @@ -383,7 +373,7 @@ ) expected_data = [ - self.row_from_dict( + self.row_class.from_dict( {"id": data.revision_id_2, **self.example_data[0], "tool": tool} ) ] @@ -397,7 +387,7 @@ tool = data.tools[self.tool_name] query = [data.sha1_2, data.sha1_1] - data1 = self.row_from_dict( + data1 = self.row_class.from_dict( { "id": data.sha1_2, **self.example_data[0], @@ -414,7 +404,7 @@ # then expected_data = [ - self.row_from_dict( + self.row_class.from_dict( {"id": data.sha1_2, **self.example_data[0], "tool": tool} ) ] @@ -432,8 +422,7 @@ {"mimetype": "text/plain", "encoding": "utf-8",}, {"mimetype": "text/html", "encoding": "us-ascii",}, ] - row_from_dict = ContentMimetypeRow.from_dict - dict_from_row = staticmethod(lambda x: x.to_dict()) # type: ignore + row_class = ContentMimetypeRow def test_generate_content_mimetype_get_partition_failure( self, swh_indexer_storage: IndexerStorageInterface @@ -569,8 +558,7 @@ {"lang": "haskell",}, {"lang": "common-lisp",}, ] - row_from_dict = ContentLanguageRow.from_dict - dict_from_row = staticmethod(lambda x: x.to_dict()) # type: ignore + row_class = ContentLanguageRow class TestIndexerStorageContentCTags(StorageETypeTester): @@ -584,8 +572,7 @@ {"name": "done", "kind": "variable", "line": 100, "lang": "Python",}, {"name": "main", "kind": "function", "line": 119, "lang": "Python",}, ] - row_from_dict = ContentCtagsRow.from_dict - dict_from_row = staticmethod(lambda x: x.to_dict()) # type: ignore + row_class = ContentCtagsRow # the following tests are disabled because CTAGS behaves differently @pytest.mark.skip @@ -830,8 +817,7 @@ }, {"metadata": {"other": {}, "name": "test_metadata", "version": "0.0.1"},}, ] - row_from_dict = ContentMetadataRow.from_dict - dict_from_row = staticmethod(lambda x: x.to_dict()) # type: ignore + row_class = ContentMetadataRow class TestIndexerStorageRevisionIntrinsicMetadata(StorageETypeTester): @@ -859,8 +845,7 @@ "mappings": ["mapping2"], }, ] - row_from_dict = RevisionIntrinsicMetadataRow.from_dict - dict_from_row = staticmethod(lambda x: x.to_dict()) # type: ignore + row_class = RevisionIntrinsicMetadataRow def test_revision_intrinsic_metadata_delete( self, swh_indexer_storage_with_data: Tuple[IndexerStorageInterface, Any] @@ -906,8 +891,7 @@ endpoint_type = "content_fossology_license" tool_name = "nomos" - row_from_dict = ContentLicenseRow.from_dict - dict_from_row = staticmethod(lambda x: x.to_dict()) + row_class = ContentLicenseRow def test_content_fossology_license_add__new_license_added( self, swh_indexer_storage_with_data: Tuple[IndexerStorageInterface, Any] diff --git a/swh/indexer/tests/test_mimetype.py b/swh/indexer/tests/test_mimetype.py --- a/swh/indexer/tests/test_mimetype.py +++ b/swh/indexer/tests/test_mimetype.py @@ -107,8 +107,6 @@ """ - row_from_dict = staticmethod(lambda x: x) # type: ignore - def setUp(self): super().setUp() self.indexer = MimetypePartitionIndexer(config=RANGE_CONFIG) diff --git a/swh/indexer/tests/utils.py b/swh/indexer/tests/utils.py --- a/swh/indexer/tests/utils.py +++ b/swh/indexer/tests/utils.py @@ -5,14 +5,13 @@ import abc import functools -from typing import Any, Callable, Dict, Union +from typing import Any, Dict import unittest from hypothesis import strategies from swh.core.api.classes import stream_results from swh.indexer.storage import INDEXER_CFG_KEY -from swh.indexer.storage.model import BaseRow from swh.model import hashutil from swh.model.hashutil import hash_to_bytes from swh.model.model import ( @@ -628,11 +627,10 @@ self.indexer.run(sha1s, policy_update="update-dups") # then - # TODO: unconditionally use res.id when all endpoints moved away from dicts expected_results = [ res for res in self.expected_results - if hashutil.hash_to_hex(getattr(res, "id", None) or res["id"]) in sha1s + if hashutil.hash_to_hex(res.id) in sha1s ] self.assert_results_ok(sha1s, expected_results) @@ -643,9 +641,6 @@ """ - # TODO: remove this when all endpoints moved away from dicts - row_from_dict: Callable[[Union[Dict, BaseRow]], BaseRow] - def setUp(self): self.contents = sorted(OBJ_STORAGE_DATA) @@ -694,8 +689,7 @@ self.assert_results_ok(partition_id, nb_partitions, actual_results) - # TODO: unconditionally use res.id when all endpoints moved away from dicts - indexed_ids = {getattr(res, "id", None) or res["id"] for res in actual_results} + indexed_ids = {res.id for res in actual_results} actual_results = list( self.indexer._index_contents(