Changeset View
Changeset View
Standalone View
Standalone View
swh/search/in_memory.py
# Copyright (C) 2019 The Software Heritage developers | # Copyright (C) 2019-2020 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
import base64 | |||||
from collections import defaultdict | |||||
import itertools | import itertools | ||||
import re | import re | ||||
from typing import Any, Dict, Iterable, Iterator, List, Optional | |||||
import msgpack | from collections import defaultdict | ||||
from typing import Any, Dict, Iterable, Iterator, List, Optional | |||||
from swh.core.api import remote_api_endpoint | from swh.core.api import remote_api_endpoint | ||||
from swh.model.identifiers import origin_identifier | from swh.model.identifiers import origin_identifier | ||||
from swh.search.interface import PagedResult | |||||
def _sanitize_origin(origin): | def _sanitize_origin(origin): | ||||
origin = origin.copy() | origin = origin.copy() | ||||
res = {"url": origin.pop("url")} | res = {"url": origin.pop("url")} | ||||
for field_name in ("type", "intrinsic_metadata"): | for field_name in ("type", "intrinsic_metadata"): | ||||
if field_name in origin: | if field_name in origin: | ||||
res[field_name] = origin.pop(field_name) | res[field_name] = origin.pop(field_name) | ||||
return res | return res | ||||
ardumont: dead code ^ (unrelated to this diff though) | |||||
class InMemorySearch:
    """In-memory implementation of the search backend, for tests.

    Origin documents are accumulated in a dict keyed by their intrinsic
    identifier; insertion order is tracked in a parallel list so that
    search results paginate deterministically.
    """

    def __init__(self):
        pass

    @remote_api_endpoint("check")
    def check(self):
        """Health check; the in-memory backend is always available."""
        return True

    def deinitialize(self) -> None:
        """Drop all indexed data (inverse of :meth:`initialize`)."""
        if hasattr(self, "_origins"):
            del self._origins
            del self._origin_ids

    def initialize(self) -> None:
        """Create the (empty) in-memory storage structures."""
        # origin id -> merged origin document
        self._origins: Dict[str, Dict[str, Any]] = defaultdict(dict)
        # origin ids in insertion order, used for stable pagination
        self._origin_ids: List[str] = []

    def flush(self) -> None:
        """No-op: in-memory writes are visible immediately."""
        pass

    # Splits a URL on every non-word character to produce its token set.
    _url_splitter = re.compile(r"\W")

    @remote_api_endpoint("origin/update")
    def origin_update(self, documents: Iterable[dict]) -> None:
        """Insert or update origin documents.

        Documents sharing an origin identifier are merged; later fields
        overwrite earlier ones.
        """
        for document in documents:
            document = document.copy()
            id_ = origin_identifier(document)
            if "url" in document:
                document["_url_tokens"] = set(
                    self._url_splitter.split(document["url"])
                )
            # _origins and _origin_ids are kept in lockstep, so a dict
            # membership test (O(1)) replaces the previous list scan of
            # _origin_ids (O(n) per document, quadratic over a bulk load).
            is_new = id_ not in self._origins
            self._origins[id_].update(document)
            if is_new:
                self._origin_ids.append(id_)

    @remote_api_endpoint("origin/search")
    def origin_search(
        self,
        *,
        url_pattern: Optional[str] = None,
        metadata_pattern: Optional[str] = None,
        with_visit: bool = False,
        page_token: Optional[str] = None,
        limit: int = 50,
    ) -> PagedResult[Dict[str, Any]]:
        """Search indexed origins, matching URL tokens.

        Args:
            url_pattern: whitespace/punctuation-separated tokens to match
                against origin URLs; at most one token (typically the last
                one typed) may match a URL token by prefix only.
            metadata_pattern: not supported by this backend.
            with_visit: if True, restrict results to origins whose
                document has a truthy ``has_visits`` field.
            page_token: continuation token from a previous page (here,
                the stringified start index).
            limit: maximum number of results per page.

        Returns:
            A PagedResult of ``{"url": ...}`` dicts.

        Raises:
            ValueError: if neither ``url_pattern`` nor
                ``metadata_pattern`` is provided.
            NotImplementedError: if ``metadata_pattern`` is provided.
        """
        if not url_pattern and not metadata_pattern:
            raise ValueError(
                "At least one of url_pattern and metadata_pattern must be provided."
            )
        if metadata_pattern:
            raise NotImplementedError(
                "Metadata search is not implemented in the in-memory backend."
            )

        hits: Iterator[Dict[str, Any]] = (
            self._origins[id_] for id_ in self._origin_ids
        )

        if url_pattern:
            tokens = set(self._url_splitter.split(url_pattern))

            def predicate(match):
                # Every query token must appear in the URL's token set,
                # except that a single missing token may still match as a
                # prefix of some URL token.
                missing_tokens = tokens - match["_url_tokens"]
                if len(missing_tokens) == 0:
                    return True
                elif len(missing_tokens) > 1:
                    return False
                else:
                    # There is one missing token, look up by prefix.
                    (missing_token,) = missing_tokens
                    return any(
                        token.startswith(missing_token)
                        for token in match["_url_tokens"]
                    )

            hits = filter(predicate, hits)

        if with_visit:
            hits = filter(lambda o: o.get("has_visits"), hits)

        start_at_index = int(page_token) if page_token else 0

        origins = [
            {"url": hit["url"]}
            for hit in itertools.islice(hits, start_at_index, start_at_index + limit)
        ]

        next_page_token: Optional[str] = None
        if len(origins) == limit:
            # A full page was produced: more results may follow.
            next_page_token = str(start_at_index + limit)

        assert len(origins) <= limit
        return PagedResult(results=origins, next_page_token=next_page_token)
dead code ^ (unrelated to this diff though)