diff --git a/swh/indexer/indexer.py b/swh/indexer/indexer.py
--- a/swh/indexer/indexer.py
+++ b/swh/indexer/indexer.py
@@ -14,6 +14,7 @@
 from swh.core import utils
 from swh.core.config import load_from_envvar, merge_configs
 from swh.indexer.storage import INDEXER_CFG_KEY, PagedResult, Sha1, get_indexer_storage
+from swh.indexer.storage.model import BaseRow
 from swh.model import hashutil
 from swh.model.model import Revision
 from swh.objstorage.exc import ObjNotFoundError
@@ -109,7 +110,7 @@
 
     """
 
-    results: List[Dict]
+    results: List[Union[Dict, BaseRow]]
 
     USE_TOOLS = True
 
@@ -211,7 +212,7 @@
 
     def index(
         self, id: Union[bytes, Dict, Revision], data: Optional[bytes] = None, **kwargs
-    ) -> Dict[str, Any]:
+    ) -> Union[Dict[str, Any], BaseRow]:
         """Index computation for the id and associated raw data.
 
         Args:
@@ -385,7 +386,7 @@
 
     def _index_contents(
         self, partition_id: int, nb_partitions: int, indexed: Set[Sha1], **kwargs: Any
-    ) -> Iterator[Dict]:
+    ) -> Iterator[Union[BaseRow, Dict]]:
         """Index the contents within the partition_id.
 
         Args:
@@ -405,7 +406,8 @@
                 continue
             res = self.index(sha1, raw_content, **kwargs)
             if res:
-                if not isinstance(res["id"], bytes):
+                # TODO: remove this check when all endpoints moved away from dicts.
+                if isinstance(res, dict) and not isinstance(res["id"], bytes):
                     raise TypeError(
                         "%r.index should return ids as bytes, not %r"
                         % (self.__class__.__name__, res["id"])
@@ -414,7 +416,7 @@
 
     def _index_with_skipping_already_done(
         self, partition_id: int, nb_partitions: int
-    ) -> Iterator[Dict]:
+    ) -> Iterator[Union[BaseRow, Dict]]:
         """Index not already indexed contents within the partition partition_id
 
         Args:
@@ -538,7 +540,9 @@
             summary.update(summary_persist)
         return summary
 
-    def index_list(self, origins: List[Any], **kwargs: Any) -> List[Dict]:
+    def index_list(
+        self, origins: List[Any], **kwargs: Any
+    ) -> List[Union[Dict, BaseRow]]:
         results = []
         for origin in origins:
             try:
diff --git a/swh/indexer/metadata.py b/swh/indexer/metadata.py
--- a/swh/indexer/metadata.py
+++ b/swh/indexer/metadata.py
@@ -268,6 +268,7 @@
             )
             # on the fly possibility:
             for result in c_metadata_indexer.results:
+                assert isinstance(result, dict)  # TODO: remove this
                 local_metadata = result["metadata"]
                 metadata.append(local_metadata)
 
diff --git a/swh/indexer/mimetype.py b/swh/indexer/mimetype.py
--- a/swh/indexer/mimetype.py
+++ b/swh/indexer/mimetype.py
@@ -9,6 +9,7 @@
 
 from swh.core.config import merge_configs
 from swh.indexer.storage.interface import PagedResult, Sha1
+from swh.indexer.storage.model import ContentMimetypeRow
 from swh.model.model import Revision
 
 from .indexer import ContentIndexer, ContentPartitionIndexer
@@ -20,7 +21,7 @@
 )
 
 
-def compute_mimetype_encoding(raw_content: bytes) -> Dict[str, bytes]:
+def compute_mimetype_encoding(raw_content: bytes) -> Dict[str, str]:
     """Determine mimetype and encoding from the raw content.
 
     Args:
@@ -68,7 +69,7 @@
 
     def index(
         self, id: Union[bytes, Dict, Revision], data: Optional[bytes] = None, **kwargs
-    ) -> Dict[str, Any]:
+    ) -> ContentMimetypeRow:
        """Index sha1s' content and store result.

        Args:
@@ -83,13 +84,15 @@

            - encoding: encoding in bytes

        """
+        assert isinstance(id, bytes)
         assert data is not None
         properties = compute_mimetype_encoding(data)
-        assert isinstance(id, bytes)
-        properties.update(
-            {"id": id, "indexer_configuration_id": self.tool["id"],}
+        return ContentMimetypeRow(
+            id=id,
+            indexer_configuration_id=self.tool["id"],
+            mimetype=properties["mimetype"],
+            encoding=properties["encoding"],
         )
-        return properties
 
     def persist_index_computations(
         self, results: List[Dict], policy_update: str
diff --git a/swh/indexer/storage/__init__.py b/swh/indexer/storage/__init__.py
--- a/swh/indexer/storage/__init__.py
+++ b/swh/indexer/storage/__init__.py
@@ -7,7 +7,7 @@
 from collections import Counter, defaultdict
 import itertools
 import json
-from typing import Dict, List, Optional
+from typing import Dict, Iterable, Iterator, List, Optional, Tuple
 
 import psycopg2
 import psycopg2.pool
@@ -144,7 +144,9 @@
 
     @timed
     @db_transaction_generator()
-    def content_mimetype_missing(self, mimetypes, db=None, cur=None):
+    def content_mimetype_missing(
+        self, mimetypes: Iterable[Dict], db=None, cur=None
+    ) -> Iterator[Tuple[Sha1, int]]:
         for obj in db.content_mimetype_missing_from_list(mimetypes, cur):
             yield obj[0]
 
@@ -245,7 +247,11 @@
     @process_metrics
     @db_transaction()
     def content_mimetype_add(
-        self, mimetypes: List[Dict], conflict_update: bool = False, db=None, cur=None
+        self,
+        mimetypes: List[ContentMimetypeRow],
+        conflict_update: bool = False,
+        db=None,
+        cur=None,
     ) -> Dict[str, int]:
         """Add mimetypes to the storage (if conflict_update is True, this will
            override existing data if any).
@@ -254,11 +260,11 @@
             A dict with the number of new elements added to the storage.
 
         """
-        check_id_duplicates(map(ContentMimetypeRow.from_dict, mimetypes))
-        mimetypes.sort(key=lambda m: m["id"])
+        check_id_duplicates(mimetypes)
+        mimetypes.sort(key=lambda m: m.id)
         db.mktemp_content_mimetype(cur)
         db.copy_to(
-            mimetypes,
+            [m.to_dict() for m in mimetypes],
             "tmp_content_mimetype",
             ["id", "mimetype", "encoding", "indexer_configuration_id"],
             cur,
@@ -268,9 +274,13 @@
 
     @timed
     @db_transaction_generator()
-    def content_mimetype_get(self, ids, db=None, cur=None):
+    def content_mimetype_get(
+        self, ids: Iterable[Sha1], db=None, cur=None
+    ) -> Iterator[ContentMimetypeRow]:
         for c in db.content_mimetype_get_from_list(ids, cur):
-            yield converters.db_to_mimetype(dict(zip(db.content_mimetype_cols, c)))
+            yield ContentMimetypeRow.from_dict(
+                converters.db_to_mimetype(dict(zip(db.content_mimetype_cols, c)))
+            )
 
     @timed
     @db_transaction_generator()
diff --git a/swh/indexer/storage/api/client.py b/swh/indexer/storage/api/client.py
--- a/swh/indexer/storage/api/client.py
+++ b/swh/indexer/storage/api/client.py
@@ -11,6 +11,7 @@
 )
 
 from ..interface import IndexerStorageInterface
+from .serializers import DECODERS, ENCODERS
 
 
 class RemoteStorage(RPCClient):
@@ -19,3 +20,5 @@
     backend_class = IndexerStorageInterface
     api_exception = IndexerStorageAPIError
     reraise_exceptions = [IndexerStorageArgumentException, DuplicateId]
+    extra_type_decoders = DECODERS
+    extra_type_encoders = ENCODERS
diff --git a/swh/indexer/storage/api/serializers.py b/swh/indexer/storage/api/serializers.py
new file
--- /dev/null
+++ b/swh/indexer/storage/api/serializers.py
@@ -0,0 +1,26 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+"""Decoder and encoders for swh-model objects."""
+
+from typing import Callable, Dict, List, Tuple
+
+import swh.indexer.storage.model as idx_model
+
+
+def _encode_model_object(obj):
+    d = obj.to_dict()
+    d["__type__"] = type(obj).__name__
+    return d
+
+
+ENCODERS: List[Tuple[type, str, Callable]] = [
+    (idx_model.BaseRow, "idx_model", _encode_model_object),
+]
+
+
+DECODERS: Dict[str, Callable] = {
+    "idx_model": lambda d: getattr(idx_model, d.pop("__type__")).from_dict(d),
+}
diff --git a/swh/indexer/storage/api/server.py b/swh/indexer/storage/api/server.py
--- a/swh/indexer/storage/api/server.py
+++ b/swh/indexer/storage/api/server.py
@@ -14,6 +14,8 @@
 from swh.indexer.storage.exc import IndexerStorageArgumentException
 from swh.indexer.storage.interface import IndexerStorageInterface
 
+from .serializers import DECODERS, ENCODERS
+
 
 def get_storage():
     global storage
@@ -23,7 +25,12 @@
     return storage
 
 
-app = RPCServerApp(
+class IndexerStorageServerApp(RPCServerApp):
+    extra_type_decoders = DECODERS
+    extra_type_encoders = ENCODERS
+
+
+app = IndexerStorageServerApp(
     __name__, backend_class=IndexerStorageInterface, backend_factory=get_storage
 )
 storage = None
diff --git a/swh/indexer/storage/db.py b/swh/indexer/storage/db.py
--- a/swh/indexer/storage/db.py
+++ b/swh/indexer/storage/db.py
@@ -3,10 +3,14 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+from typing import Dict, Iterable, Iterator, List
+
 from swh.core.db import BaseDb
 from swh.core.db.db_utils import execute_values_generator, stored_procedure
 from swh.model import hashutil
 
+from .interface import Sha1
+
 
 class Db(BaseDb):
     """Proxy to the SWH Indexer DB, with wrappers around stored procedures
@@ -15,14 +19,16 @@
 
     content_mimetype_hash_keys = ["id", "indexer_configuration_id"]
 
-    def _missing_from_list(self, table, data, hash_keys, cur=None):
+    def _missing_from_list(
+        self, table: str, data: Iterable[Dict], hash_keys: List[str], cur=None
+    ):
         """Read from table the data with hash_keys that are missing.
 
         Args:
-            table (str): Table name (e.g content_mimetype, content_language,
+            table: Table name (e.g content_mimetype, content_language,
               etc...)
-            data (dict): Dict of data to read from
-            hash_keys ([str]): List of keys to read in the data dict.
+            data: Dict of data to read from
+            hash_keys: List of keys to read in the data dict.
 
         Yields:
             The data which is missing from the db.
@@ -44,7 +50,9 @@
             (tuple(m[k] for k in hash_keys) for m in data),
         )
 
-    def content_mimetype_missing_from_list(self, mimetypes, cur=None):
+    def content_mimetype_missing_from_list(
+        self, mimetypes: Iterable[Dict], cur=None
+    ) -> Iterator[Sha1]:
        """List missing mimetypes.

        """
diff --git a/swh/indexer/storage/in_memory.py b/swh/indexer/storage/in_memory.py
--- a/swh/indexer/storage/in_memory.py
+++ b/swh/indexer/storage/in_memory.py
@@ -257,7 +257,9 @@
     def check_config(self, *, check_write):
         return True
 
-    def content_mimetype_missing(self, mimetypes):
+    def content_mimetype_missing(
+        self, mimetypes: Iterable[Dict]
+    ) -> Iterator[Tuple[Sha1, int]]:
         yield from self._mimetypes.missing(mimetypes)
 
     def content_mimetype_get_partition(
@@ -273,16 +275,13 @@
         )
 
     def content_mimetype_add(
-        self, mimetypes: List[Dict], conflict_update: bool = False
+        self, mimetypes: List[ContentMimetypeRow], conflict_update: bool = False
     ) -> Dict[str, int]:
-        check_id_types(mimetypes)
-        added = self._mimetypes.add(
-            map(ContentMimetypeRow.from_dict, mimetypes), conflict_update
-        )
+        added = self._mimetypes.add(mimetypes, conflict_update)
         return {"content_mimetype:add": added}
 
-    def content_mimetype_get(self, ids):
-        yield from (obj.to_dict() for obj in self._mimetypes.get(ids))
+    def content_mimetype_get(self, ids: Iterable[Sha1]) -> Iterator[ContentMimetypeRow]:
+        yield from self._mimetypes.get(ids)
 
     def content_language_missing(self, languages):
         yield from self._languages.missing(languages)
diff --git a/swh/indexer/storage/interface.py b/swh/indexer/storage/interface.py
--- a/swh/indexer/storage/interface.py
+++ b/swh/indexer/storage/interface.py
@@ -3,10 +3,11 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-from typing import Dict, List, Optional, TypeVar
+from typing import Dict, Iterable, Iterator, List, Optional, Tuple, TypeVar
 
 from swh.core.api import remote_api_endpoint
 from swh.core.api.classes import PagedResult as CorePagedResult
+from swh.indexer.storage.model import ContentMimetypeRow
 
 TResult = TypeVar("TResult")
 PagedResult = CorePagedResult[TResult, str]
@@ -22,7 +23,9 @@
         ...
 
     @remote_api_endpoint("content_mimetype/missing")
-    def content_mimetype_missing(self, mimetypes):
+    def content_mimetype_missing(
+        self, mimetypes: Iterable[Dict]
+    ) -> Iterator[Tuple[Sha1, int]]:
         """Generate mimetypes missing from storage.
 
         Args:
@@ -70,21 +73,15 @@
 
     @remote_api_endpoint("content_mimetype/add")
     def content_mimetype_add(
-        self, mimetypes: List[Dict], conflict_update: bool = False
+        self, mimetypes: List[ContentMimetypeRow], conflict_update: bool = False
     ) -> Dict[str, int]:
         """Add mimetypes not present in storage.
 
         Args:
-            mimetypes (iterable): dictionaries with keys:
-
-              - **id** (bytes): sha1 identifier
-              - **mimetype** (bytes): raw content's mimetype
-              - **encoding** (bytes): raw content's encoding
-              - **indexer_configuration_id** (int): tool's id used to
-                compute the results
-
-            conflict_update (bool): Flag to determine if we want to
-              overwrite (``True``) or skip duplicates (``False``, the
-              default)
+            mimetypes: mimetype rows to be added
+            conflict_update: Flag to determine if we want to
+              overwrite (``True``) or skip duplicates (``False``, the
+              default)
 
         Returns:
             Dict summary of number of rows added
@@ -93,7 +90,7 @@
         ...
 
     @remote_api_endpoint("content_mimetype")
-    def content_mimetype_get(self, ids):
+    def content_mimetype_get(self, ids: Iterable[Sha1]) -> Iterator[ContentMimetypeRow]:
        """Retrieve full content mimetype per ids.

        Args:
diff --git a/swh/indexer/tests/storage/conftest.py b/swh/indexer/tests/storage/conftest.py
--- a/swh/indexer/tests/storage/conftest.py
+++ b/swh/indexer/tests/storage/conftest.py
@@ -8,6 +8,7 @@
 import pytest
 
 from swh.indexer.storage import get_indexer_storage
+from swh.indexer.storage.model import ContentMimetypeRow
 from swh.model.hashutil import hash_to_bytes
 from swh.storage.pytest_plugin import postgresql_fact
 
@@ -47,7 +48,7 @@
     data.origin_url_2 = "file:///dev/1/one"  # 44434342
     data.origin_url_3 = "file:///dev/2/two"  # 54974445
     data.mimetypes = [
-        {**mimetype_obj, "indexer_configuration_id": tools["file"]["id"]}
+        ContentMimetypeRow(indexer_configuration_id=tools["file"]["id"], **mimetype_obj)
         for mimetype_obj in MIMETYPE_OBJECTS
     ]
     swh_indexer_storage.content_mimetype_add(data.mimetypes)
diff --git a/swh/indexer/tests/storage/test_storage.py b/swh/indexer/tests/storage/test_storage.py
--- a/swh/indexer/tests/storage/test_storage.py
+++ b/swh/indexer/tests/storage/test_storage.py
@@ -6,12 +6,13 @@
 import inspect
 import math
 import threading
-from typing import Dict
+from typing import Callable, Dict, Union
 
 import pytest
 
 from swh.indexer.storage.exc import DuplicateId, IndexerStorageArgumentException
 from swh.indexer.storage.interface import IndexerStorageInterface
+from swh.indexer.storage.model import BaseRow, ContentMimetypeRow
 from swh.model.hashutil import hash_to_bytes
 
 
@@ -22,12 +23,12 @@
     mimetypes = []
     for c in fossology_licenses:
         mimetypes.append(
-            {
-                "id": c["id"],
-                "mimetype": "text/plain",  # for filtering on textual data to work
-                "encoding": "utf-8",
-                "indexer_configuration_id": c["indexer_configuration_id"],
-            }
+            ContentMimetypeRow(
+                id=c["id"],
+                mimetype="text/plain",  # for filtering on textual data to work
+                encoding="utf-8",
+                indexer_configuration_id=c["indexer_configuration_id"],
+            )
         )
     return mimetypes
 
@@ -110,6 +111,18 @@
     See below for example usage.
     """
 
+    # For old endpoints (which use dicts), this is just the identity function.
+    # For new endpoints, it returns a BaseRow instance.
+    row_from_dict: Callable[[Dict], Union[Dict, BaseRow]] = (
+        staticmethod(lambda x: x)  # type: ignore
+    )
+
+    # Inverse function of row_from_dict
+    # TODO: remove this once all endpoints are migrated to rows
+    dict_from_row: Callable[[Union[Dict, BaseRow]], Dict] = (
+        staticmethod(lambda x: x)  # type: ignore
+    )
+
     def test_missing(self, swh_indexer_storage_with_data):
         storage, data = swh_indexer_storage_with_data
         etype = self.endpoint_type
@@ -131,11 +144,13 @@
         # now, when we add one of them
         summary = endpoint(storage, etype, "add")(
             [
-                {
-                    "id": data.sha1_2,
-                    **self.example_data[0],
-                    "indexer_configuration_id": tool_id,
-                }
+                self.row_from_dict(
+                    {
+                        "id": data.sha1_2,
+                        **self.example_data[0],
+                        "indexer_configuration_id": tool_id,
+                    }
+                )
             ]
         )
 
@@ -156,24 +171,26 @@
             **self.example_data[0],
             "indexer_configuration_id": tool_id,
         }
-        summary = endpoint(storage, etype, "add")([data_v1])
+        summary = endpoint(storage, etype, "add")([self.row_from_dict(data_v1)])
         assert summary == expected_summary(1, etype)
 
         # should be able to retrieve it
         actual_data = list(endpoint(storage, etype, "get")([data.sha1_2]))
         expected_data_v1 = [
-            {
-                "id": data.sha1_2,
-                **self.example_data[0],
-                "tool": data.tools[self.tool_name],
-            }
+            self.row_from_dict(
+                {
+                    "id": data.sha1_2,
+                    **self.example_data[0],
+                    "tool": data.tools[self.tool_name],
+                }
+            )
         ]
         assert actual_data == expected_data_v1
 
         # now if we add a modified version of the same object (same id)
         data_v2 = data_v1.copy()
         data_v2.update(self.example_data[1])
-        summary2 = endpoint(storage, etype, "add")([data_v2])
+        summary2 = endpoint(storage, etype, "add")([self.row_from_dict(data_v2)])
         assert summary2 == expected_summary(0, etype)  # not added
 
         # we expect to retrieve the original data, not the modified one
@@ -192,13 +209,17 @@
         }
 
         # given
-        summary = endpoint(storage, etype, "add")([data_v1])
+        summary = endpoint(storage, etype, "add")([self.row_from_dict(data_v1)])
         assert summary == expected_summary(1, etype)  # not added
 
         # when
         actual_data = list(endpoint(storage, etype, "get")([data.sha1_2]))
 
-        expected_data_v1 = [{"id": data.sha1_2, **self.example_data[0], "tool": tool,}]
+        expected_data_v1 = [
+            self.row_from_dict(
+                {"id": data.sha1_2, **self.example_data[0], "tool": tool}
+            )
+        ]
 
         # then
         assert actual_data == expected_data_v1
@@ -207,12 +228,18 @@
         data_v2 = data_v1.copy()
         data_v2.update(self.example_data[1])
 
-        endpoint(storage, etype, "add")([data_v2], conflict_update=True)
+        endpoint(storage, etype, "add")(
+            [self.row_from_dict(data_v2)], conflict_update=True
+        )
         assert summary == expected_summary(1, etype)  # modified so counted
 
         actual_data = list(endpoint(storage, etype, "get")([data.sha1_2]))
 
-        expected_data_v2 = [{"id": data.sha1_2, **self.example_data[1], "tool": tool,}]
+        expected_data_v2 = [
+            self.row_from_dict(
+                {"id": data.sha1_2, **self.example_data[1], "tool": tool,}
+            )
+        ]
 
         # data did change as the v2 was used to overwrite v1
         assert actual_data == expected_data_v2
@@ -228,19 +255,23 @@
         ]
 
         data_v1 = [
-            {
-                "id": hash_,
-                **self.example_data[0],
-                "indexer_configuration_id": tool["id"],
-            }
+            self.row_from_dict(
+                {
+                    "id": hash_,
+                    **self.example_data[0],
+                    "indexer_configuration_id": tool["id"],
+                }
+            )
             for hash_ in hashes
         ]
         data_v2 = [
-            {
-                "id": hash_,
-                **self.example_data[1],
-                "indexer_configuration_id": tool["id"],
-            }
+            self.row_from_dict(
+                {
+                    "id": hash_,
+                    **self.example_data[1],
+                    "indexer_configuration_id": tool["id"],
+                }
+            )
             for hash_ in hashes
         ]
 
@@ -256,7 +287,8 @@
         actual_data = list(endpoint(storage, etype, "get")(hashes))
 
         expected_data_v1 = [
-            {"id": hash_, **self.example_data[0], "tool": tool,} for hash_ in hashes
+            self.row_from_dict({"id": hash_, **self.example_data[0], "tool": tool})
+            for hash_ in hashes
         ]
 
         # then
@@ -278,11 +310,12 @@
         t2.join()
 
         actual_data = sorted(
-            endpoint(storage, etype, "get")(hashes), key=lambda x: x["id"]
+            map(self.dict_from_row, endpoint(storage, etype, "get")(hashes)),
+            key=lambda x: x["id"],
         )
 
         expected_data_v2 = [
-            {"id": hash_, **self.example_data[1], "tool": tool,} for hash_ in hashes
+            {"id": hash_, **self.example_data[1], "tool": tool} for hash_ in hashes
         ]
 
         assert actual_data == expected_data_v2
@@ -292,17 +325,21 @@
         etype = self.endpoint_type
         tool = data.tools[self.tool_name]
 
-        data_rev1 = {
-            "id": data.revision_id_2,
-            **self.example_data[0],
-            "indexer_configuration_id": tool["id"],
-        }
+        data_rev1 = self.row_from_dict(
+            {
+                "id": data.revision_id_2,
+                **self.example_data[0],
+                "indexer_configuration_id": tool["id"],
+            }
+        )
 
-        data_rev2 = {
-            "id": data.revision_id_2,
-            **self.example_data[1],
-            "indexer_configuration_id": tool["id"],
-        }
+        data_rev2 = self.row_from_dict(
+            {
+                "id": data.revision_id_2,
+                **self.example_data[1],
+                "indexer_configuration_id": tool["id"],
+            }
+        )
 
         # when
         summary = endpoint(storage, etype, "add")([data_rev1])
@@ -319,7 +356,9 @@
         )
 
         expected_data = [
-            {"id": data.revision_id_2, **self.example_data[0], "tool": tool,}
+            self.row_from_dict(
+                {"id": data.revision_id_2, **self.example_data[0], "tool": tool}
+            )
         ]
         assert actual_data == expected_data
 
@@ -329,11 +368,13 @@
         tool = data.tools[self.tool_name]
         query = [data.sha1_2, data.sha1_1]
 
-        data1 = {
-            "id": data.sha1_2,
-            **self.example_data[0],
-            "indexer_configuration_id": tool["id"],
-        }
+        data1 = self.row_from_dict(
+            {
+                "id": data.sha1_2,
+                **self.example_data[0],
+                "indexer_configuration_id": tool["id"],
+            }
+        )
 
         # when
         summary = endpoint(storage, etype, "add")([data1])
@@ -343,7 +384,11 @@
         actual_data = list(endpoint(storage, etype, "get")(query))
 
         # then
-        expected_data = [{"id": data.sha1_2, **self.example_data[0], "tool": tool,}]
+        expected_data = [
+            self.row_from_dict(
+                {"id": data.sha1_2, **self.example_data[0], "tool": tool}
+            )
+        ]
 
         assert actual_data == expected_data
 
@@ -358,6 +403,8 @@
         {"mimetype": "text/plain", "encoding": "utf-8",},
         {"mimetype": "text/html", "encoding": "us-ascii",},
     ]
+    row_from_dict = ContentMimetypeRow.from_dict
+    dict_from_row = staticmethod(lambda x: x.to_dict())  # type: ignore
 
     def test_generate_content_mimetype_get_partition_failure(self, swh_indexer_storage):
         """get_partition call with wrong limit input should fail"""
@@ -377,8 +424,8 @@
         storage, data = swh_indexer_storage_with_data
         mimetypes = data.mimetypes
 
-        expected_ids = set([c["id"] for c in mimetypes])
-        indexer_configuration_id = mimetypes[0]["indexer_configuration_id"]
+        expected_ids = set([c.id for c in mimetypes])
+        indexer_configuration_id = mimetypes[0].indexer_configuration_id
 
         assert len(mimetypes) == 16
         nb_partitions = 16
@@ -403,8 +450,8 @@
        """
         storage, data = swh_indexer_storage_with_data
         mimetypes = data.mimetypes
-        expected_ids = set([c["id"] for c in mimetypes])
-        indexer_configuration_id = mimetypes[0]["indexer_configuration_id"]
+        expected_ids = set([c.id for c in mimetypes])
+        indexer_configuration_id = mimetypes[0].indexer_configuration_id
 
         actual_result = storage.content_mimetype_get_partition(
             indexer_configuration_id, 0, 1
@@ -421,8 +468,8 @@
         """get_partition when at least one of the partitions is empty"""
         storage, data = swh_indexer_storage_with_data
         mimetypes = data.mimetypes
-        expected_ids = set([c["id"] for c in mimetypes])
-        indexer_configuration_id = mimetypes[0]["indexer_configuration_id"]
+        expected_ids = set([c.id for c in mimetypes])
+        indexer_configuration_id = mimetypes[0].indexer_configuration_id
 
         # nb_partitions = smallest power of 2 such that at least one of
         # the partitions is empty
@@ -455,8 +502,8 @@
        """
         storage, data = swh_indexer_storage_with_data
         mimetypes = data.mimetypes
-        expected_ids = set([c["id"] for c in mimetypes])
-        indexer_configuration_id = mimetypes[0]["indexer_configuration_id"]
+        expected_ids = set([c.id for c in mimetypes])
+        indexer_configuration_id = mimetypes[0].indexer_configuration_id
 
         nb_partitions = 4
diff --git a/swh/indexer/tests/test_fossology_license.py b/swh/indexer/tests/test_fossology_license.py
--- a/swh/indexer/tests/test_fossology_license.py
+++ b/swh/indexer/tests/test_fossology_license.py
@@ -15,6 +15,7 @@
     FossologyLicensePartitionIndexer,
     compute_license,
 )
+from swh.indexer.storage.model import ContentLicenseRow
 from swh.indexer.tests.utils import (
     BASE_TEST_CONFIG,
     SHA1_TO_LICENSES,
@@ -136,6 +137,21 @@
         super().tearDown()
         fossology_license.compute_license = self.orig_compute_license
 
+    def assert_results_ok(self, partition_id, nb_partitions, actual_results):
+        # TODO: remove this method when fossology_license endpoints moved away
+        # from dicts.
+        actual_result_rows = []
+        for res in actual_results:
+            for license in res["licenses"]:
+                actual_result_rows.append(
+                    ContentLicenseRow(
+                        id=res["id"],
+                        indexer_configuration_id=res["indexer_configuration_id"],
+                        license=license,
+                    )
+                )
+        super().assert_results_ok(partition_id, nb_partitions, actual_result_rows)
+
 
 def test_fossology_w_no_tool():
     with pytest.raises(ValueError):
diff --git a/swh/indexer/tests/test_mimetype.py b/swh/indexer/tests/test_mimetype.py
--- a/swh/indexer/tests/test_mimetype.py
+++ b/swh/indexer/tests/test_mimetype.py
@@ -55,7 +55,7 @@
     legacy_get_format = True
 
     def get_indexer_results(self, ids):
-        yield from self.idx_storage.content_mimetype_get(ids)
+        yield from (x.to_dict() for x in self.idx_storage.content_mimetype_get(ids))
 
     def setUp(self):
         self.indexer = MimetypeIndexer(config=CONFIG)
@@ -107,6 +107,8 @@
 
     """
 
+    row_from_dict = staticmethod(lambda x: x)  # type: ignore
+
     def setUp(self):
         super().setUp()
         self.indexer = MimetypePartitionIndexer(config=RANGE_CONFIG)
diff --git a/swh/indexer/tests/utils.py b/swh/indexer/tests/utils.py
--- a/swh/indexer/tests/utils.py
+++ b/swh/indexer/tests/utils.py
@@ -5,13 +5,14 @@
 
 import abc
 import functools
-from typing import Any, Dict
+from typing import Any, Callable, Dict, Union
 import unittest
 
 from hypothesis import strategies
 
 from swh.core.api.classes import stream_results
 from swh.indexer.storage import INDEXER_CFG_KEY
+from swh.indexer.storage.model import BaseRow
 from swh.model import hashutil
 from swh.model.hashutil import hash_to_bytes
 from swh.model.model import (
@@ -684,6 +685,9 @@
 
     """
 
+    # TODO: remove this when all endpoints moved away from dicts
+    row_from_dict: Callable[[Union[Dict, BaseRow]], BaseRow]
+
     def setUp(self):
         self.contents = sorted(OBJ_STORAGE_DATA)
 
@@ -699,11 +703,10 @@
         actual_results = list(actual_results)
 
         for indexed_data in actual_results:
-            _id = indexed_data["id"]
-            assert isinstance(_id, bytes)
+            _id = indexed_data.id
             assert _id in expected_ids
 
-            _tool_id = indexed_data["indexer_configuration_id"]
+            _tool_id = indexed_data.indexer_configuration_id
             assert _tool_id == self.indexer.tool["id"]
 
     def test__index_contents(self):
@@ -728,12 +731,13 @@
         # first pass
         actual_results = list(
-            self.indexer._index_contents(partition_id, nb_partitions, indexed={})
+            self.indexer._index_contents(partition_id, nb_partitions, indexed={}),
        )

        self.assert_results_ok(partition_id, nb_partitions, actual_results)

-        indexed_ids = set(res["id"] for res in actual_results)
+        # TODO: unconditionally use res.id when all endpoints moved away from dicts
+        indexed_ids = {getattr(res, "id", None) or res["id"] for res in actual_results}

         actual_results = list(
             self.indexer._index_contents(
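
As a quick illustration of the serializer design introduced above, here is a minimal usage sketch. It is not part of the patch: the sha1 and indexer_configuration_id are invented example values, and the round-trip assertion assumes to_dict()/from_dict() are inverses, which the RPC layer relies on.

# Illustrative sketch only -- not part of the patch. Assumes the
# serializers module and ContentMimetypeRow exactly as added above.
from swh.indexer.storage.api.serializers import DECODERS, ENCODERS
from swh.indexer.storage.model import ContentMimetypeRow

row = ContentMimetypeRow(
    id=bytes.fromhex("34973274ccef6ab4dfaaf86599792fa9c3fe4689"),  # example sha1
    mimetype="text/plain",
    encoding="utf-8",
    indexer_configuration_id=1,  # example tool id
)

# The single ENCODERS entry tags the row's dict with its concrete class
# name, so one encoder covers every BaseRow subclass.
_, tag, encode = ENCODERS[0]
encoded = encode(row)
assert encoded["__type__"] == "ContentMimetypeRow"

# The decoder pops the tag back off and looks the class up in
# swh.indexer.storage.model, rebuilding the same row type on the other
# side of the RPC boundary.
decoded = DECODERS[tag](encoded)
assert decoded == row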