diff --git a/swh/web/browse/snapshot_context.py b/swh/web/browse/snapshot_context.py
--- a/swh/web/browse/snapshot_context.py
+++ b/swh/web/browse/snapshot_context.py
@@ -8,7 +8,6 @@
 from collections import defaultdict
 from typing import Any, Dict, List, Optional, Tuple
 
-from django.core.cache import cache
 from django.shortcuts import render
 from django.utils.html import escape
 
@@ -38,6 +37,7 @@
     SWHObjectInfo,
 )
 from swh.web.common.utils import (
+    django_cache,
     format_utc_iso_date,
     gen_path_info,
     reverse,
@@ -280,6 +280,7 @@
     return ret_branches, ret_releases, resolved_aliases
 
 
+@django_cache()
 def get_snapshot_content(
     snapshot_id: str,
 ) -> Tuple[List[SnapshotBranchInfo], List[SnapshotReleaseInfo], Dict[str, Any]]:
@@ -303,15 +304,6 @@
     Raises:
         NotFoundExc if the snapshot does not exist
     """
-    cache_entry_id = "swh_snapshot_%s" % snapshot_id
-    cache_entry = cache.get(cache_entry_id)
-
-    if cache_entry:
-        return (
-            cache_entry["branches"],
-            cache_entry["releases"],
-            cache_entry.get("aliases", {}),
-        )
 
     branches: List[SnapshotBranchInfo] = []
     releases: List[SnapshotReleaseInfo] = []
@@ -325,10 +317,6 @@
     )
     branches, releases, aliases = process_snapshot_branches(snapshot)
 
-    cache.set(
-        cache_entry_id, {"branches": branches, "releases": releases, "aliases": aliases}
-    )
-
     return branches, releases, aliases
 
 
@@ -483,11 +471,11 @@
 
     releases = list(reversed(releases))
 
-    snapshot_sizes_cache_id = f"swh_snapshot_{snapshot_id}_sizes"
-    snapshot_sizes = cache.get(snapshot_sizes_cache_id)
-    if snapshot_sizes is None:
-        snapshot_sizes = archive.lookup_snapshot_sizes(snapshot_id)
-        cache.set(snapshot_sizes_cache_id, snapshot_sizes)
+    @django_cache()
+    def _get_snapshot_sizes(snapshot_id):
+        return archive.lookup_snapshot_sizes(snapshot_id)
+
+    snapshot_sizes = _get_snapshot_sizes(snapshot_id)
 
     is_empty = (snapshot_sizes["release"] + snapshot_sizes["revision"]) == 0
 
diff --git a/swh/web/browse/utils.py b/swh/web/browse/utils.py
--- a/swh/web/browse/utils.py
+++ b/swh/web/browse/utils.py
@@ -12,7 +12,6 @@
 import magic
 import sentry_sdk
 
-from django.core.cache import cache
 from django.utils.html import escape
 from django.utils.safestring import mark_safe
 
@@ -20,6 +19,7 @@
 from swh.web.common.exc import NotFoundExc
 from swh.web.common.utils import (
     browsers_supported_image_mimes,
+    django_cache,
     format_utc_iso_date,
     reverse,
     rst_to_html,
@@ -27,6 +27,7 @@
 from swh.web.config import get_config
 
 
+@django_cache()
 def get_directory_entries(sha1_git):
     """Function that retrieves the content of a directory
     from the archive.
@@ -44,12 +45,6 @@
     Raises:
         NotFoundExc if the directory is not found
     """
-    cache_entry_id = "directory_entries_%s" % sha1_git
-    cache_entry = cache.get(cache_entry_id)
-
-    if cache_entry:
-        return cache_entry
-
     entries = list(archive.lookup_directory(sha1_git))
     for e in entries:
         e["perms"] = stat.filemode(e["perms"])
@@ -64,8 +59,6 @@
     dirs = sorted(dirs, key=lambda d: d["name"])
     files = sorted(files, key=lambda f: f["name"])
 
-    cache.set(cache_entry_id, (dirs, files))
-
     return dirs, files
 
 
@@ -717,18 +710,15 @@
     # convert rst README to html server side as there is
     # no viable solution to perform that task client side
     if readme_name and readme_name.endswith(".rst"):
-        cache_entry_id = "readme_%s" % readme_sha1
-        cache_entry = cache.get(cache_entry_id)
-        if cache_entry:
-            readme_html = cache_entry
-        else:
-            try:
-                rst_doc = request_content(readme_sha1)
-                readme_html = rst_to_html(rst_doc["raw_data"])
-                cache.set(cache_entry_id, readme_html)
-            except Exception as exc:
-                sentry_sdk.capture_exception(exc)
-                readme_html = "Readme bytes are not available"
+        @django_cache(
+            catch_exception=True,
+            exception_return_value="Readme bytes are not available",
+        )
+        def _rst_readme_to_html(readme_sha1):
+            rst_doc = request_content(readme_sha1)
+            return rst_to_html(rst_doc["raw_data"])
+
+        readme_html = _rst_readme_to_html(readme_sha1)
 
     return readme_name, readme_url, readme_html
 
diff --git a/swh/web/common/utils.py b/swh/web/common/utils.py
--- a/swh/web/common/utils.py
+++ b/swh/web/common/utils.py
@@ -4,9 +4,11 @@
 # See top-level LICENSE file for more information
 
 from datetime import datetime, timezone
+import functools
+import hashlib
 import os
 import re
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 import urllib.parse
 from xml.etree import ElementTree
 
@@ -20,8 +22,10 @@
 from prometheus_client.registry import CollectorRegistry
 import requests
 from requests.auth import HTTPBasicAuth
+import sentry_sdk
 
 from django.core.cache import cache
+from django.core.cache.backends.base import DEFAULT_TIMEOUT
 from django.http import HttpRequest, QueryDict
 from django.shortcuts import redirect
 from django.urls import resolve
@@ -400,6 +404,66 @@
     return BeautifulSoup(html, "lxml").prettify()
 
 
+def django_cache(
+    timeout: int = DEFAULT_TIMEOUT,
+    catch_exception: bool = False,
+    exception_return_value: Any = None,
+    invalidate_cache_pred: Callable[[Any], bool] = lambda val: False,
+):
+    """Decorator to put the result of a function call in Django cache,
+    subsequent calls will directly return the cached value.
+
+    The cache key is a digest of the decorated function qualified name and
+    of the ``repr`` of its arguments (expected to be deterministic, e.g.
+    strings or integers). Every value influencing the result must thus be
+    passed as an argument: values captured from an enclosing scope are not
+    part of the key. A :const:`None` result is never served from the cache.
+
+    Args:
+        timeout: The number of seconds the value will be held in cache
+        catch_exception: If :const:`True`, any exception raised by
+            the decorated function will be caught and not reraised
+        exception_return_value: The value to return if previous
+            parameter is set to :const:`True`
+        invalidate_cache_pred: A predicate function enabling to
+            invalidate the cache under certain conditions, decorated
+            function will then be called again
+
+    Returns:
+        The returned value of the decorated function for the specified
+        parameters
+
+    """
+
+    def inner(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            # The builtin hash() of str objects is salted per process
+            # (PYTHONHASHSEED), so it cannot be used to build keys shared
+            # by several workers or surviving a restart: use a stable digest.
+            key_material = repr(
+                (func.__module__, func.__qualname__, args, sorted(kwargs.items()))
+            )
+            cache_key = hashlib.sha256(key_material.encode()).hexdigest()
+            ret = cache.get(cache_key)
+            if ret is None or invalidate_cache_pred(ret):
+                try:
+                    ret = func(*args, **kwargs)
+                except Exception as exc:
+                    sentry_sdk.capture_exception(exc)
+                    if catch_exception:
+                        return exception_return_value
+                    else:
+                        raise
+                else:
+                    cache.set(cache_key, ret, timeout=timeout)
+            return ret
+
+        return wrapper
+
+    return inner
+
+
 def _deposits_list_url(
     deposits_list_base_url: str, page_size: int, username: Optional[str]
 ) -> str:
@@ -426,31 +490,27 @@
         deposits_list_url, auth=deposits_list_auth, timeout=30
     ).json()["count"]
 
-    cache_key = f"swh-deposit-list-{username}"
-    deposits_data = cache.get(cache_key)
-    if not deposits_data or deposits_data["count"] != nb_deposits:
+    # username is passed as a parameter so that it is part of the cache key,
+    # otherwise deposits lists of different users would share the same entry
+    @django_cache(invalidate_cache_pred=lambda data: data["count"] != nb_deposits)
+    def _get_deposits_data(username):
         deposits_list_url = _deposits_list_url(
             deposits_list_base_url, page_size=nb_deposits, username=username
         )
-        deposits_data = requests.get(
+        return requests.get(
             deposits_list_url, auth=deposits_list_auth, timeout=30,
         ).json()
-        cache.set(cache_key, deposits_data)
+
+    deposits_data = _get_deposits_data(username)
 
     return deposits_data["results"]
 
 
+@django_cache()
 def get_deposit_raw_metadata(deposit_id: int) -> Optional[str]:
-    cache_key = f"swh-deposit-raw-metadata-{deposit_id}"
-    metadata = cache.get(cache_key)
-    if metadata is None:
-        config = get_config()["deposit"]
-
-        url = f"{config['private_api_url']}/{deposit_id}/meta"
-        metadata = requests.get(url).json()["raw_metadata"]
-        cache.set(cache_key, metadata)
-
-    return metadata
+    config = get_config()["deposit"]
+    url = f"{config['private_api_url']}/{deposit_id}/meta"
+    return requests.get(url).json()["raw_metadata"]
 
 
 def origin_visit_types() -> List[str]:
diff --git a/swh/web/misc/coverage.py b/swh/web/misc/coverage.py
--- a/swh/web/misc/coverage.py
+++ b/swh/web/misc/coverage.py
@@ -7,10 +7,7 @@
 from typing import Any, Dict, List, Tuple
 from urllib.parse import urlparse
 
-import sentry_sdk
-
 from django.conf.urls import url
-from django.core.cache import cache
 from django.http.request import HttpRequest
 from django.http.response import HttpResponse
 from django.shortcuts import render
@@ -20,7 +17,12 @@
 from swh.scheduler.model import SchedulerMetrics
 from swh.web.common import archive
 from swh.web.common.origin_save import get_savable_visit_types
-from swh.web.common.utils import get_deposits_list, is_swh_web_production, reverse
+from swh.web.common.utils import (
+    django_cache,
+    get_deposits_list,
+    is_swh_web_production,
+    reverse,
+)
 from swh.web.config import scheduler
 
 _swh_arch_overview_doc = (
@@ -238,24 +240,26 @@
     Dict[lister_name, List[Tuple[instance_name, SchedulerMetrics]]]
     as a lister instance has one SchedulerMetrics object per visit type.
     """
-    cache_key = "lister_metrics"
-    listers_metrics = cache.get(cache_key, {})
-    if not listers_metrics:
+
+    @django_cache(
+        timeout=_cache_timeout,
+        catch_exception=True,
+        exception_return_value={},
+        invalidate_cache_pred=lambda m: not cache_metrics,
+    )
+    def _get_listers_metrics_internal():
         listers_metrics = defaultdict(list)
-        try:
-            listers = scheduler().get_listers()
-            scheduler_metrics = scheduler().get_metrics()
-            for lister in listers:
-                for metrics in filter(
-                    lambda m: m.lister_id == lister.id, scheduler_metrics
-                ):
-                    listers_metrics[lister.name].append((lister.instance_name, metrics))
-            if cache_metrics:
-                cache.set(cache_key, listers_metrics, timeout=_cache_timeout)
-        except Exception as e:
-            sentry_sdk.capture_exception(e)
-
-    return listers_metrics
+        listers = scheduler().get_listers()
+        scheduler_metrics = scheduler().get_metrics()
+        for lister in listers:
+            for metrics in filter(
+                lambda m: m.lister_id == lister.id, scheduler_metrics
+            ):
+                listers_metrics[lister.name].append((lister.instance_name, metrics))
+
+        return listers_metrics
+
+    return _get_listers_metrics_internal()
 
 
 def _get_deposits_netloc_counts(cache_counts: bool = False) -> Counter:
@@ -271,42 +275,49 @@
             netloc += "/" + parsed_url.path.split("/")[1]
         return netloc
 
-    cache_key = "deposits_netloc_counts"
-    deposits_netloc_counts = cache.get(cache_key, Counter())
-    if not deposits_netloc_counts:
+    @django_cache(
+        timeout=_cache_timeout,
+        catch_exception=True,
+        exception_return_value=Counter(),
+        invalidate_cache_pred=lambda m: not cache_counts,
+    )
+    def _get_deposits_netloc_counts_internal():
         netlocs = []
-        try:
-            deposits = get_deposits_list()
-            netlocs = [
-                _process_origin_url(d["origin_url"])
-                for d in deposits
-                if d["status"] == "done"
-            ]
-            deposits_netloc_counts = Counter(netlocs)
-            if cache_counts:
-                cache.set(cache_key, deposits_netloc_counts, timeout=_cache_timeout)
-        except Exception as e:
-            sentry_sdk.capture_exception(e)
+        deposits = get_deposits_list()
+        netlocs = [
+            _process_origin_url(d["origin_url"])
+            for d in deposits
+            if d["status"] == "done"
+        ]
+        deposits_netloc_counts = Counter(netlocs)
+        return deposits_netloc_counts
 
-    return deposits_netloc_counts
+    return _get_deposits_netloc_counts_internal()
 
 
 def _get_nixguix_origins_count(origin_url: str, cache_count: bool = False) -> int:
     """Returns number of archived tarballs for NixOS, aka the number
     of branches in a dedicated origin in the archive.
     """
-    cache_key = f"nixguix_origins_count_{origin_url}"
-    nixguix_origins_count = cache.get(cache_key, 0)
-    if not nixguix_origins_count:
+
+    # origin_url is passed as a parameter so that it is part of the cache key,
+    # otherwise counts of different origins would share the same cache entry
+    @django_cache(
+        timeout=_cache_timeout,
+        catch_exception=True,
+        exception_return_value=0,
+        invalidate_cache_pred=lambda m: not cache_count,
+    )
+    def _get_nixguix_origins_count_internal(origin_url):
         snapshot = archive.lookup_latest_origin_snapshot(origin_url)
         if snapshot:
             snapshot_sizes = archive.lookup_snapshot_sizes(snapshot["id"])
             nixguix_origins_count = snapshot_sizes["release"]
         else:
             nixguix_origins_count = 0
-        if cache_count:
-            cache.set(cache_key, nixguix_origins_count, timeout=_cache_timeout)
-    return nixguix_origins_count
+        return nixguix_origins_count
+
+    return _get_nixguix_origins_count_internal(origin_url)
 
 
 def _search_url(query: str, visit_type: str) -> str:
diff --git a/swh/web/tests/common/test_utils.py b/swh/web/tests/common/test_utils.py
--- a/swh/web/tests/common/test_utils.py
+++ b/swh/web/tests/common/test_utils.py
@@ -4,7 +4,9 @@
 # See top-level LICENSE file for more information
 from base64 import b64encode
 import datetime
+import math
 from os.path import join
+import sys
 from urllib.parse import quote
 
 import pytest
@@ -355,3 +357,95 @@
 
     actual_url = utils.parse_swh_deposit_origin(raw_metadata)
     assert actual_url == expected_url
+
+
+def add(x, y):
+    return x + y
+
+
+def test_django_cache(mocker):
+    """Decorated function should be called once and returned value
+    put in django cache."""
+    utils.cache.clear()
+    spy_add = mocker.spy(sys.modules[__name__], "add")
+    spy_cache_set = mocker.spy(utils.cache, "set")
+
+    cached_add = utils.django_cache()(add)
+
+    val = cached_add(1, 2)
+    val2 = cached_add(1, 2)
+
+    assert val == val2 == 3
+    assert spy_add.call_count == 1
+    assert spy_cache_set.call_count == 1
+
+
+def test_django_cache_distinct_args(mocker):
+    """Calls with different arguments should not share a cache entry."""
+    utils.cache.clear()
+    spy_add = mocker.spy(sys.modules[__name__], "add")
+
+    cached_add = utils.django_cache()(add)
+
+    assert cached_add(1, 2) == 3
+    assert cached_add(2, 2) == 4
+    assert cached_add(1, 2) == 3
+    assert cached_add(2, 2) == 4
+    assert spy_add.call_count == 2
+
+
+def test_django_cache_invalidate_cache_pred(mocker):
+    """Decorated function should be called twice and returned value
+    put in django cache twice."""
+    utils.cache.clear()
+    spy_add = mocker.spy(sys.modules[__name__], "add")
+    spy_cache_set = mocker.spy(utils.cache, "set")
+
+    cached_add = utils.django_cache(invalidate_cache_pred=lambda val: val == 3)(add)
+
+    val = cached_add(1, 2)
+    val2 = cached_add(1, 2)
+
+    assert val == val2 == 3
+    assert spy_add.call_count == 2
+    assert spy_cache_set.call_count == 2
+
+
+def test_django_cache_raise_exception(mocker):
+    """Decorated function should be called twice, exceptions should be
+    raised and no value put in django cache"""
+    utils.cache.clear()
+    spy_add = mocker.spy(sys.modules[__name__], "add")
+    spy_cache_set = mocker.spy(utils.cache, "set")
+
+    cached_add = utils.django_cache()(add)
+
+    with pytest.raises(TypeError):
+        cached_add(1, "2")
+
+    with pytest.raises(TypeError):
+        cached_add(1, "2")
+
+    assert spy_add.call_count == 2
+    assert spy_cache_set.call_count == 0
+
+
+def test_django_cache_catch_exception(mocker):
+    """Decorated function should be called twice, exceptions should not be
+    raised, specified fallback value should be returned and no value put
+    in django cache"""
+    utils.cache.clear()
+    spy_add = mocker.spy(sys.modules[__name__], "add")
+    spy_cache_set = mocker.spy(utils.cache, "set")
+
+    cached_add = utils.django_cache(
+        catch_exception=True, exception_return_value=math.nan
+    )(add)
+
+    val = cached_add(1, "2")
+    val2 = cached_add(1, "2")
+
+    assert math.isnan(val)
+    assert math.isnan(val2)
+    assert spy_add.call_count == 2
+    assert spy_cache_set.call_count == 0