diff --git a/swh/search/elasticsearch.py b/swh/search/elasticsearch.py
--- a/swh/search/elasticsearch.py
+++ b/swh/search/elasticsearch.py
@@ -4,6 +4,7 @@
 # See top-level LICENSE file for more information
 
 import base64
+from collections import Counter
 import logging
 import pprint
 from textwrap import dedent
@@ -527,3 +528,28 @@
             results=[{"url": hit["_source"]["url"]} for hit in hits],
             next_page_token=next_page_token,
         )
+
+    def visit_types_count(self) -> Counter:
+        """Returns origin counts per visit type (git, hg, svn, ...)."""
+        body = {
+            "aggs": {
+                "not_blocklisted": {
+                    "filter": {"bool": {"must_not": [{"term": {"blocklisted": True}}]}},
+                    "aggs": {
+                        "visit_types": {"terms": {"field": "visit_types", "size": 1000}}
+                    },
+                }
+            }
+        }
+
+        res = self._backend.search(
+            index=self._get_origin_read_alias(), body=body, size=0
+        )
+
+        buckets = (
+            res.get("aggregations", {})
+            .get("not_blocklisted", {})
+            .get("visit_types", {})
+            .get("buckets", [])
+        )
+        return Counter({bucket["key"]: bucket["doc_count"] for bucket in buckets})
diff --git a/swh/search/in_memory.py b/swh/search/in_memory.py
--- a/swh/search/in_memory.py
+++ b/swh/search/in_memory.py
@@ -3,8 +3,9 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-from collections import defaultdict
+from collections import Counter, defaultdict
 from datetime import datetime, timezone
+from itertools import chain
 import re
 from typing import Any, Dict, Iterable, Iterator, List, Optional
 
@@ -288,11 +289,7 @@
         page_token: Optional[str] = None,
         limit: int = 50,
     ) -> PagedResult[MinimalOriginDict]:
-        hits: Iterator[Dict[str, Any]] = (
-            self._origins[id_]
-            for id_ in self._origin_ids
-            if not self._origins[id_].get("blocklisted")
-        )
+        hits = self._get_hits()
 
         if url_pattern:
             tokens = set(self._url_splitter.split(url_pattern))
@@ -506,3 +503,18 @@
         assert len(origins) <= limit
 
         return PagedResult(results=origins, next_page_token=next_page_token,)
+
+    def visit_types_count(self) -> Counter:
+        """Returns origin counts per visit type (git, hg, svn, ...)."""
+        hits = self._get_hits()
+        # chain.from_iterable iterates lazily instead of materializing a list
+        # and star-unpacking it as call arguments.
+        return Counter(
+            chain.from_iterable(hit.get("visit_types", []) for hit in hits)
+        )
+
+    def _get_hits(self) -> Iterator[Dict[str, Any]]:
+        """Iterates over stored origins that are not blocklisted."""
+        return (
+            self._origins[id_]
+            for id_ in self._origin_ids
+            if not self._origins[id_].get("blocklisted")
+        )
diff --git a/swh/search/interface.py b/swh/search/interface.py
--- a/swh/search/interface.py
+++ b/swh/search/interface.py
@@ -3,6 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+from collections import Counter
 from typing import Iterable, List, Optional, TypeVar
 
 from typing_extensions import TypedDict
@@ -133,3 +134,9 @@
         """
 
         ...
+
+    @remote_api_endpoint("visit_types_count")
+    def visit_types_count(self) -> Counter:
+        """Returns origin counts per visit type (git, hg, svn, ...).
+        """
+        ...
diff --git a/swh/search/tests/test_search.py b/swh/search/tests/test_search.py
--- a/swh/search/tests/test_search.py
+++ b/swh/search/tests/test_search.py
@@ -3,6 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+from collections import Counter
 from datetime import datetime, timedelta, timezone
 from itertools import permutations
 
@@ -1166,3 +1167,23 @@
         result_page = self.search.origin_search(url_pattern="baaz")
         assert result_page.next_page_token is None
         assert result_page.results == []
+
+    def test_visit_types_count(self):
+        assert self.search.visit_types_count() == Counter()
+
+        origins = [
+            {"url": "http://foobar.baz", "visit_types": ["git"], "blocklisted": True}
+        ]
+
+        for idx, visit_type in enumerate(["git", "hg", "svn"]):
+            for i in range(idx + 1):
+                origins.append(
+                    {
+                        "url": f"http://{visit_type}.foobar.baz.{i}",
+                        "visit_types": [visit_type],
+                    }
+                )
+        self.search.origin_update(origins)
+        self.search.flush()
+
+        assert self.search.visit_types_count() == Counter(git=1, hg=2, svn=3)