diff --git a/swh/search/elasticsearch.py b/swh/search/elasticsearch.py --- a/swh/search/elasticsearch.py +++ b/swh/search/elasticsearch.py @@ -10,7 +10,7 @@ from textwrap import dedent from typing import Any, Dict, Iterable, List, Optional, cast -from elasticsearch import Elasticsearch, helpers +from elasticsearch import Elasticsearch, NotFoundError, helpers import msgpack from swh.indexer import codemeta @@ -387,6 +387,17 @@ "document:index_error", count=len(errors), method_name="origin_update" ) + def origin_get(self, url: str) -> Optional[Dict[str, Any]]: + origin_id = hash_to_hex(model.Origin(url=url).id) + try: + document = self._backend.get( + index=self._get_origin_read_alias(), id=origin_id + ) + except NotFoundError: + return None + else: + return document["_source"] + @timed def origin_search( self, diff --git a/swh/search/in_memory.py b/swh/search/in_memory.py --- a/swh/search/in_memory.py +++ b/swh/search/in_memory.py @@ -171,8 +171,11 @@ def origin_update(self, documents: Iterable[OriginDict]) -> None: for source_document in documents: - document: Dict[str, Any] = dict(source_document) - id_ = hash_to_hex(model.Origin(url=document["url"]).id) + id_ = hash_to_hex(model.Origin(url=source_document["url"]).id) + document: Dict[str, Any] = { + **source_document, + "sha1": id_, + } if "url" in document: document["_url_tokens"] = set( self._url_splitter.split(source_document["url"]) @@ -521,6 +524,14 @@ next_page_token=next_page_token, ) + def origin_get(self, url: str) -> Optional[Dict[str, Any]]: + origin_id = hash_to_hex(model.Origin(url=url).id) + document = self._origins.get(origin_id) + if document is None: + return None + else: + return {k: v for (k, v) in document.items() if k != "_url_tokens"} + def visit_types_count(self) -> Counter: hits = self._get_hits() return Counter(chain(*[hit.get("visit_types", []) for hit in hits])) diff --git a/swh/search/interface.py b/swh/search/interface.py --- a/swh/search/interface.py +++ b/swh/search/interface.py @@
-4,7 +4,7 @@ # See top-level LICENSE file for more information from collections import Counter -from typing import Iterable, List, Optional, TypeVar +from typing import Any, Dict, Iterable, List, Optional, TypeVar from typing_extensions import TypedDict @@ -132,6 +132,13 @@ """ ... + @remote_api_endpoint("origin/get") + def origin_get(self, url: str) -> Optional[Dict[str, Any]]: + """Returns the full document indexed for the given origin URL. + + Returns None if no document is indexed for this URL. + """ + @remote_api_endpoint("visit_types_count") def visit_types_count(self) -> Counter: """Returns origin counts per visit type (git, hg, svn, ...).""" diff --git a/swh/search/tests/test_search.py b/swh/search/tests/test_search.py --- a/swh/search/tests/test_search.py +++ b/swh/search/tests/test_search.py @@ -5,6 +5,7 @@ from collections import Counter from datetime import datetime, timedelta, timezone +import hashlib from itertools import permutations from hypothesis import given, settings, strategies @@ -1227,7 +1228,7 @@ assert result_page.next_page_token is None assert result_page.results == [] - def test_filter_keyword_in_filter(self): + def test_search_filter_keyword_in_filter(self): origin1 = { "url": "foo language in ['foo baz'] bar", } @@ -1242,6 +1243,94 @@ assert result_page.next_page_token is None assert result_page.results == [] + def test_origin_get(self): + """Checks origin_get returns the stored documents (with expanded jsonld + and sha1), and None for unknown origins.""" + origin1 = {"url": "http://origin1"} + origin2 = {"url": "http://origin2"} + origin3 = {"url": "http://origin3"} + origins = [ + { + **origin1, + "jsonld": { + "@context": "https://doi.org/10.5063/schema/codemeta-2.0", + "author": { + "familyName": "Foo", + "givenName": "Bar", + }, + }, + }, + { + **origin2, + "jsonld": { + "@context": "https://doi.org/10.5063/schema/codemeta-2.0", + "author": "Bar Baz", + }, + }, + { + **origin3, + "jsonld": { + "@context":
"https://doi.org/10.5063/schema/codemeta-2.0", + "author": ["Baz", "Qux"], + }, + }, + ] + + expanded_origins = [ + { + **origin1, + "sha1": hashlib.sha1(origin1["url"].encode()).hexdigest(), + "jsonld": [ + { + "http://schema.org/author": [ + { + "@list": [ + { + "http://schema.org/familyName": [ + {"@value": "Foo"} + ], + "http://schema.org/givenName": [ + {"@value": "Bar"} + ], + } + ] + } + ], + } + ], + }, + { + **origin2, + "sha1": hashlib.sha1(origin2["url"].encode()).hexdigest(), + "jsonld": [ + { + "http://schema.org/author": [ + {"@list": [{"@value": "Bar Baz"}]} + ], + } + ], + }, + { + **origin3, + "sha1": hashlib.sha1(origin3["url"].encode()).hexdigest(), + "jsonld": [ + { + "http://schema.org/author": [ + {"@list": [{"@value": "Baz"}, {"@value": "Qux"}]} + ], + } + ], + }, + ] + + self.search.origin_update(origins) + self.search.flush() + + assert self.search.origin_get(origin1["url"]) == expanded_origins[0] + assert self.search.origin_get(origin2["url"]) == expanded_origins[1] + assert self.search.origin_get(origin3["url"]) == expanded_origins[2] + assert self.search.origin_get("http://origin4") is None + def test_visit_types_count(self): assert self.search.visit_types_count() == Counter()