# File: swh/counters/interface.py
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

from typing import Any, Dict, Iterable, List

from swh.core.api import remote_api_endpoint


class CountersInterface:
    """Declaration of the counters API.

    Every method is decorated with ``remote_api_endpoint(<name>)`` and has a
    placeholder body (``...``); the behavior is supplied by implementations.
    """

    @remote_api_endpoint("check")
    def check(self):
        """Dedicated method to execute some specific check per implementation."""
        ...

    @remote_api_endpoint("add")
    def add(self, collection: str, keys: Iterable[Any]) -> None:
        """Add the provided keys to the collection.

        Only count new keys.

        Args:
            collection: name of the collection receiving the keys
            keys: the keys to add
        """
        ...

    @remote_api_endpoint("count")
    def get_count(self, collection: str) -> int:
        """Return the number of keys for the provided collection."""
        ...

    @remote_api_endpoint("counts")
    def get_counts(self, collections: List[str]) -> Dict[str, int]:
        """Return the number of keys for each of the provided collections.

        Args:
            collections: names of the collections to count

        Returns:
            A dict mapping each requested collection name to its number
            of keys.
        """
        ...

    @remote_api_endpoint("counters")
    def get_counters(self) -> Iterable[str]:
        """Return the list of managed counters."""
        ...


class HistoryInterface:
    """Declaration of the counters history API (placeholder bodies only)."""

    @remote_api_endpoint("history")
    def get_history(self, cache_file: str):
        """Return the content of a history file previously created
        by the refresh_counters method."""
        # NOTE(review): no ``refresh_counters`` method is declared in this
        # interface; presumably ``refresh_history`` is meant -- confirm.
        ...

    @remote_api_endpoint("refresh_history")
    def refresh_history(self, cache_file: str):
        """Refresh the cache file containing the counters historical data.

        It can be an aggregate of live data and static data stored on
        a separate file.
        """
        ...
# File: swh/counters/redis.py
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

import logging
from typing import Any, Dict, Iterable, List

from redis.client import Redis as RedisClient
from redis.exceptions import ConnectionError

DEFAULT_REDIS_PORT = 6379

logger = logging.getLogger(__name__)


class Redis:
    """Redis based implementation of the counters.

    It uses one HyperLogLog collection per counter.
    """

    # Created on first access of ``redis_client`` and reused afterwards.
    _redis_client = None

    def __init__(self, host: str):
        """Parse the server location; no connection is opened here.

        Args:
            host: ``"hostname"`` or ``"hostname:port"``; the port defaults
                to ``DEFAULT_REDIS_PORT`` when omitted

        Raises:
            ValueError: if ``host`` has more than one ``:`` separator or
                its port part is not an integer
        """
        host_port = host.split(":")

        if len(host_port) > 2:
            raise ValueError("Invalid server url `%s`" % host)

        self.host = host_port[0]
        if len(host_port) > 1:
            try:
                self.port = int(host_port[1])
            except ValueError:
                # Report a bad port the same way as any other malformed url
                # instead of leaking the bare int() conversion message.
                raise ValueError("Invalid server url `%s`" % host) from None
        else:
            self.port = DEFAULT_REDIS_PORT

    @property
    def redis_client(self) -> RedisClient:
        """The underlying redis client, instantiated once on first use."""
        if self._redis_client is None:
            self._redis_client = RedisClient(host=self.host, port=self.port)

        return self._redis_client

    def check(self):
        """Ping the redis server.

        Returns:
            The result of the ping, or ``False`` (after logging the
            exception) when the server cannot be reached.
        """
        try:
            return self.redis_client.ping()
        except ConnectionError:
            logger.exception("Unable to connect to the redis server")
            return False

    def add(self, collection: str, keys: Iterable[Any]) -> None:
        """Add ``keys`` to the HyperLogLog named ``collection``.

        One PFADD per key is queued on a non-transactional pipeline and the
        whole batch is sent with a single ``execute()``.
        """
        pipeline = self.redis_client.pipeline(transaction=False)

        for key in keys:
            pipeline.pfadd(collection, key)

        pipeline.execute()

    def get_count(self, collection: str) -> int:
        """Return the PFCOUNT of ``collection``."""
        return self.redis_client.pfcount(collection)

    def get_counts(self, collections: List[str]) -> Dict[str, int]:
        """Return the count of every collection in ``collections``.

        One PFCOUNT per collection is queued on a single non-transactional
        pipeline, so the whole request costs one round-trip instead of one
        per collection.

        Returns:
            A dict mapping each requested collection name to its count.
        """
        names = list(collections)
        pipeline = self.redis_client.pipeline(transaction=False)

        # One command per name: a multi-key PFCOUNT would return the
        # cardinality of the union, not the individual counts.
        for name in names:
            pipeline.pfcount(name)

        return dict(zip(names, pipeline.execute()))

    def get_counters(self) -> Iterable[str]:
        """Return the keys present on the redis server (unfiltered)."""
        # NOTE(review): the test suite in this change checks for ``bytes``
        # entries (e.g. b"counter1"), not ``str`` -- confirm the annotation.
        return self.redis_client.keys()
# File: swh/counters/tests/test_redis.py
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

import pytest
from redis import Redis as RedisClient

from swh.counters.redis import DEFAULT_REDIS_PORT, Redis


def test__redis__constructor():
    """The host string is split into host and port; extra ``:`` is rejected."""
    without_port = Redis("fakehost")
    assert without_port.host == "fakehost"
    assert without_port.port == DEFAULT_REDIS_PORT

    with_port = Redis("host:11")
    assert with_port.host == "host"
    assert with_port.port == 11

    with pytest.raises(ValueError, match="url"):
        Redis("fake:host:port")


def test__redis__only_one_client_instantiation(mocker):
    """The client is built lazily and only once."""
    client_cls = mocker.patch("swh.counters.redis.RedisClient")

    counters = Redis("redishost:1234")

    # ensure lazy loading
    assert counters._redis_client is None

    first_client = counters.redis_client

    assert client_cls.call_count == 1
    call_kwargs = client_cls.call_args[1]
    assert call_kwargs["host"] == "redishost"
    assert call_kwargs["port"] == 1234
    assert counters._redis_client is not None

    second_client = counters.redis_client
    assert client_cls.call_count == 1
    assert first_client == second_client


def test__redis__ping_ko():
    """``check`` reports False when the server is unreachable."""
    counters = Redis("wronghost")
    assert counters.check() is False


def test__redis__ping_ok(local_redis):
    """``check`` reports True against the local test server."""
    counters = Redis("%s:%d" % (local_redis.host, local_redis.port))
    assert counters.check() is True


def test__redis__collection(local_redis):
    """Keys added to a collection are reflected by ``get_count``."""
    counters = Redis("%s:%d" % (local_redis.host, local_redis.port))

    additions = (
        ("c1", [b"k1", b"k2", b"k3"]),
        ("c2", [b"k1"]),
        ("c3", [b"k2"]),
        ("c3", [b"k5"]),
    )
    for collection, keys in additions:
        counters.add(collection, keys)

    # c4 was never written to and must count as empty
    expected_counts = (("c1", 3), ("c2", 1), ("c3", 2), ("c4", 0))
    for collection, expected in expected_counts:
        assert expected == counters.get_count(collection)


def test__redis__collections(local_redis):
    """``get_counters`` lists the collections present on the server."""
    raw_client = RedisClient(host=local_redis.host, port=local_redis.port)
    for name, key in (("counter1", b"k1"), ("counter2", b"k2")):
        raw_client.pfadd(name, key)

    counters = Redis("%s:%d" % (local_redis.host, local_redis.port))

    names = counters.get_counters()

    assert 2 == len(names)
    for expected_name in (b"counter1", b"counter2"):
        assert expected_name in names


def test__redis_counts(local_redis):
    """``get_counts`` returns one entry per requested collection only."""
    raw_client = RedisClient(host=local_redis.host, port=local_redis.port)
    seeded = (
        ("counter1", b"k1"),
        ("counter1", b"k2"),
        ("counter2", b"k1"),
        ("counter2", b"k2"),
        ("counter2", b"k3"),
        ("counter3", b"k3"),
    )
    for name, key in seeded:
        raw_client.pfadd(name, key)

    counters = Redis("%s:%d" % (local_redis.host, local_redis.port))

    # counter3 is deliberately not requested
    counts = counters.get_counts(["counter2", "counter1"])

    assert 2 == len(counts)
    assert 2 == counts["counter1"]
    assert 3 == counts["counter2"]