diff --git a/swh/storage/api/client.py b/swh/storage/api/client.py
--- a/swh/storage/api/client.py
+++ b/swh/storage/api/client.py
@@ -127,6 +127,9 @@
             'target_types': target_types
         })
 
+    def origin_get_random(self):
+        return self.get('origin/get_random')
+
     def origin_get(self, origins=None, *, origin=None):
         if origin is None:
             if origins is None:
diff --git a/swh/storage/api/server.py b/swh/storage/api/server.py
--- a/swh/storage/api/server.py
+++ b/swh/storage/api/server.py
@@ -351,6 +351,12 @@
     return encode_data(get_storage().origin_get(**decode_request(request)))
 
 
+@app.route('/origin/get_random', methods=['GET'])
+@timed
+def origin_get_random():
+    return encode_data(get_storage().origin_get_random())
+
+
 @app.route('/origin/get_sha1', methods=['POST'])
 @timed
 def origin_get_by_sha1():
diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -648,6 +648,44 @@
         yield from execute_values_generator(
             cur, query, ((sha1,) for sha1 in sha1s))
 
+    def origin_get_random(self, cur=None):
+        """Randomly select one origin from the archive.
+
+        A random id is drawn in ``[1, count]``, ``count`` being the origin
+        counter read from ``object_counts`` (assumes ids are allocated
+        from 1 upward -- to be confirmed against the schema). The first
+        origin whose id is >= the drawn id is returned, so holes in the
+        id sequence do not produce an empty result.
+
+        Returns:
+            the selected row, with values ordered as ``self.origin_cols``,
+            or None when no origin matches (e.g. empty table)
+
+        """
+        cur = self._cursor(cur)
+
+        columns = ','.join('origin.' + col for col in self.origin_cols)
+        query = f"""with swh_count_origins as (
+                      select value
+                      from object_counts
+                      where object_type='origin'
+                    ),
+                    swh_random_id as (
+                      select floor(random() * (
+                         select * from swh_count_origins)
+                      )::int + 1
+                    )
+                    select {columns}
+                    from origin
+                    where id >= (select * from swh_random_id)
+                    order by id
+                    limit 1;
+                """
+        cur.execute(query)
+        # Hand back the whole row rather than its first column: callers
+        # zip it with ``origin_cols``. fetchone() is None without a match.
+        return cur.fetchone()
+
     def origin_id_get_by_url(self, origins, cur=None):
         """Retrieve origin `(type, url)` from urls if found."""
         cur = self._cursor(cur)
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -7,12 +7,14 @@
 import bisect
 import dateutil
 import collections
-from collections import defaultdict
 import copy
 import datetime
 import itertools
 import random
 
+from collections import defaultdict
+from typing import Any, Mapping, Optional
+
 import attr
 
 from swh.model.model import \
@@ -1089,6 +1091,20 @@
                 for sha1 in sha1s
         ]
 
+    def origin_get_random(self) -> Optional[Mapping[str, Any]]:
+        """Randomly select one origin from the archive
+
+        Returns:
+            origin dict selected randomly on the dataset, or None when
+            no origin is known
+
+        """
+        if not self._origins:
+            # random.choice raises IndexError on an empty sequence
+            return None
+        key = random.choice(list(self._origins))
+        return self._convert_origin(self._origins[key])
+
     def origin_get_range(self, origin_from=1, origin_count=100):
         """Retrieve ``origin_count`` origins whose ids are greater
         or equal than ``origin_from``.
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -3,15 +3,16 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-
-from collections import defaultdict
 import copy
-from concurrent.futures import ThreadPoolExecutor
-from contextlib import contextmanager
 import datetime
 import itertools
 import json
 
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import contextmanager
+from typing import Any, Mapping, Optional
+
 import dateutil.parser
 import psycopg2
 import psycopg2.pool
@@ -1507,6 +1508,21 @@
             else:
                 yield None
 
+    @db_transaction(statement_timeout=500)
+    def origin_get_random(
+            self, db=None, cur=None) -> Optional[Mapping[str, Any]]:
+        """Randomly select one origin from the archive
+
+        Returns:
+            origin dict selected randomly on the dataset, or None when
+            the backend found no origin
+
+        """
+        result = db.origin_get_random(cur)
+        if result is None:
+            return None
+        return dict(zip(db.origin_cols, result))
+
     @db_transaction_generator()
     def origin_get_range(self, origin_from=1, origin_count=100,
                          db=None, cur=None):
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -936,6 +936,14 @@
         assert len(actual_origin0) == 1
         assert actual_origin0[0]['url'] == data.origin['url']
 
+    def test_origin_get_random(self, swh_storage):
+        swh_storage.origin_add(data.origins)
+        swh_storage.refresh_stat_counters()
+        random_origin = swh_storage.origin_get_random()
+        assert random_origin is not None
+        assert random_origin['url'] is not None
+        assert random_origin in data.origins
+
     def test_origin_get_by_sha1(self, swh_storage):
         assert swh_storage.origin_get(data.origin) is None
         swh_storage.origin_add_one(data.origin)