diff --git a/requirements.txt b/requirements.txt --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ # dependency lines, see https://pip.readthedocs.org/en/1.1/requirements.html click elasticsearch>=7.0.0,<8.0.0 +typing-extensions diff --git a/swh/search/elasticsearch.py b/swh/search/elasticsearch.py --- a/swh/search/elasticsearch.py +++ b/swh/search/elasticsearch.py @@ -12,7 +12,7 @@ from swh.indexer import codemeta from swh.model import model from swh.model.identifiers import origin_identifier -from swh.search.interface import PagedResult +from swh.search.interface import MinimalOriginDict, OriginDict, PagedResult from swh.search.metrics import send_metric, timed @@ -125,7 +125,7 @@ self._backend.indices.refresh(index=self.origin_index) @timed - def origin_update(self, documents: Iterable[Dict]) -> None: + def origin_update(self, documents: Iterable[OriginDict]) -> None: documents = map(_sanitize_origin, documents) documents_with_sha1 = ( (origin_identifier(document), document) for document in documents @@ -191,7 +191,7 @@ visit_types: Optional[List[str]] = None, page_token: Optional[str] = None, limit: int = 50, - ) -> PagedResult[Dict[str, Any]]: + ) -> PagedResult[MinimalOriginDict]: query_clauses: List[Dict[str, Any]] = [] if url_pattern: diff --git a/swh/search/in_memory.py b/swh/search/in_memory.py --- a/swh/search/in_memory.py +++ b/swh/search/in_memory.py @@ -9,7 +9,7 @@ from typing import Any, Dict, Iterable, Iterator, List, Optional from swh.model.identifiers import origin_identifier -from swh.search.interface import PagedResult +from swh.search.interface import MinimalOriginDict, OriginDict, PagedResult _words_regexp = re.compile(r"\w+") @@ -53,14 +53,16 @@ _url_splitter = re.compile(r"\W") - def origin_update(self, documents: Iterable[Dict]) -> None: - for document in documents: - document = document.copy() + def origin_update(self, documents: Iterable[OriginDict]) -> None: + for source_document in documents: + document: Dict[str, Any] = 
dict(source_document) id_ = origin_identifier(document) if "url" in document: - document["_url_tokens"] = set(self._url_splitter.split(document["url"])) + document["_url_tokens"] = set( + self._url_splitter.split(source_document["url"]) + ) if "visit_types" in document: - document["visit_types"] = set(document["visit_types"]) + document["visit_types"] = set(source_document["visit_types"]) if "visit_types" in self._origins[id_]: document["visit_types"].update(self._origins[id_]["visit_types"]) self._origins[id_].update(document) @@ -77,7 +79,7 @@ visit_types: Optional[List[str]] = None, page_token: Optional[str] = None, limit: int = 50, - ) -> PagedResult[Dict[str, Any]]: + ) -> PagedResult[MinimalOriginDict]: hits: Iterator[Dict[str, Any]] = ( self._origins[id_] for id_ in self._origin_ids ) diff --git a/swh/search/interface.py b/swh/search/interface.py --- a/swh/search/interface.py +++ b/swh/search/interface.py @@ -3,7 +3,9 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from typing import Any, Dict, Iterable, List, Optional, TypeVar + +from typing import Iterable, List, Optional, TypeVar + +from typing_extensions import TypedDict from swh.core.api import remote_api_endpoint from swh.core.api.classes import PagedResult as CorePagedResult @@ -12,6 +14,19 @@ PagedResult = CorePagedResult[TResult, str] +class MinimalOriginDict(TypedDict): + """Mandatory keys of an :class:`OriginDict`.""" + + url: str + + +class OriginDict(MinimalOriginDict, total=False): + """A single document passed to :meth:`SearchInterface.origin_update`.""" + + visit_types: List[str] + has_visits: bool + + class SearchInterface: @remote_api_endpoint("check") def check(self): @@ -29,7 +44,7 @@ ... @remote_api_endpoint("origin/update") - def origin_update(self, documents: Iterable[Dict]) -> None: + def origin_update(self, documents: Iterable[OriginDict]) -> None: """Persist documents to the search backend.
""" @@ -45,7 +60,7 @@ visit_types: Optional[List[str]] = None, page_token: Optional[str] = None, limit: int = 50, - ) -> PagedResult[Dict[str, Any]]: + ) -> PagedResult[MinimalOriginDict]: """Searches for origins matching the `url_pattern`. Args: