diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -14,7 +14,7 @@
 
 from collections import defaultdict
 from datetime import timedelta
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, List, Optional
 
 import attr
 
@@ -1597,7 +1597,7 @@
             if random_origin_visits[0].type == type:
                 return url
 
-    def origin_visit_get_random(self, type: str) -> Mapping[str, Any]:
+    def origin_visit_get_random(self, type: str) -> Optional[Dict[str, Any]]:
         """Randomly select one successful origin visit with <type>
         made in the last 3 months.
 
@@ -1606,9 +1606,8 @@
-            `origin_visit_get`.
+            `origin_visit_get`, or None if no such visit was found.
 
         """
-        random_visit: Dict[str, Any] = {}
         if not self._origin_visits:  # empty dataset
-            return random_visit
+            return None
         url = self._select_random_origin_visit_by_type(type)
         random_origin_visits = copy.deepcopy(self._origin_visits[url])
         random_origin_visits.reverse()
@@ -1616,9 +1615,8 @@
         # This should be enough for tests
         for visit in random_origin_visits:
             if visit.date > back_in_the_day and visit.status == 'full':
-                random_visit = visit.to_dict()
-                break
-        return random_visit
+                return visit.to_dict()
+        return None
 
     def stat_counters(self):
         """compute statistics about the number of tuples in various tables
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -11,7 +11,7 @@
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, List, Optional
 
 import dateutil.parser
 import psycopg2
@@ -1610,7 +1610,7 @@
     @timed
     @db_transaction()
     def origin_visit_get_random(
-            self, type: str, db=None, cur=None) -> Mapping[str, Any]:
+            self, type: str, db=None, cur=None) -> Optional[Dict[str, Any]]:
         """Randomly select one successful origin visit with <type>
         made in the last 3 months.
 
@@ -1619,11 +1619,10 @@
-            :py:meth:`origin_visit_get`.
+            :py:meth:`origin_visit_get`, or None if no such visit was found.
 
         """
-        data: Dict[str, Any] = {}
         result = db.origin_visit_get_random(type, cur)
         if result:
-            data = dict(zip(db.origin_visit_get_cols, result))
-        return data
+            return dict(zip(db.origin_visit_get_cols, result))
+        return None
 
     @remote_api_endpoint('object/find_by_sha1_git')
     @timed
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -1188,7 +1188,10 @@
             origin['url'], visit_id=visit['visit'], status='full')
 
         random_origin_visit = swh_storage.origin_visit_get_random(visit_type)
-        assert random_origin_visit == {}
+        assert random_origin_visit is None
+
+    def test_origin_visit_get_random_empty_storage(self, swh_storage):
+        assert swh_storage.origin_visit_get_random('git') is None
 
     def test_origin_get_by_sha1(self, swh_storage):
         assert swh_storage.origin_get(data.origin) is None