diff --git a/sql/upgrades/145.sql b/sql/upgrades/145.sql
new file mode 100644
--- /dev/null
+++ b/sql/upgrades/145.sql
@@ -0,0 +1,9 @@
+-- SWH DB schema upgrade
+-- from_version: 144
+-- to_version: 145
+-- description: Improve query on origin_visit
+
+insert into dbversion(version, release, description)
+      values(145, now(), 'Work In Progress');
+
+create index concurrently on origin_visit(type, status, date);
diff --git a/swh/storage/api/client.py b/swh/storage/api/client.py
--- a/swh/storage/api/client.py
+++ b/swh/storage/api/client.py
@@ -185,6 +185,11 @@
             'origin': origin, 'last_visit': last_visit,
             'limit': limit})
 
+    def origin_visit_get_random(self, type):
+        return self.post('origin/visit/get_random', {
+            'type': type,
+        })
+
     def origin_visit_find_by_date(self, origin, visit_date, limit=None):
         return self.post('origin/visit/find_by_date', {
             'origin': origin, 'visit_date': visit_date})
diff --git a/swh/storage/api/server.py b/swh/storage/api/server.py
--- a/swh/storage/api/server.py
+++ b/swh/storage/api/server.py
@@ -402,6 +402,13 @@
         **decode_request(request)))
 
 
+@app.route('/origin/visit/get_random', methods=['POST'])
+@timed
+def origin_visit_get_random():
+    return encode_data(get_storage().origin_visit_get_random(
+        **decode_request(request)))
+
+
 @app.route('/origin/visit/find_by_date', methods=['POST'])
 @timed
 def origin_visit_find_by_date():
diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -648,6 +648,30 @@
         yield from execute_values_generator(
             cur, query, ((sha1,) for sha1 in sha1s))
 
+    def origin_visit_get_random(self, type, cur=None):
+        """Randomly select one origin whose last visit was full in the last 3
+        months
+
+        """
+        cur = self._cursor(cur)
+        columns = ','.join(self.origin_visit_select_cols)
+        query = f"""with visits as (
+            select *
+            from origin_visit
+            where origin_visit.status='full' and
+                  origin_visit.type=%s and
+                  origin_visit.date > now() - '3 months'::interval
+        )
+        select {columns}
+        from visits as origin_visit
+        inner join origin
+        on origin_visit.origin=origin.id
+        where random() < 0.1
+        limit 1
+        """
+        cur.execute(query, (type, ))
+        return cur.fetchone()
+
     def origin_id_get_by_url(self, origins, cur=None):
         """Retrieve origin `(type, url)` from urls if found."""
         cur = self._cursor(cur)
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -7,12 +7,15 @@
 import bisect
 import dateutil
 import collections
-from collections import defaultdict
 import copy
 import datetime
 import itertools
 import random
 
+from collections import defaultdict
+from datetime import timedelta
+from typing import Any, Dict, Mapping
+
 import attr
 
 from swh.model.model import \
@@ -1089,6 +1092,36 @@
                 for sha1 in sha1s
             ]
 
+    def _select_random_origin_by_type(self, type: str) -> str:
+        """Pick a random origin url whose first visit has the given type."""
+        while True:
+            url = random.choice(list(self._origin_visits.keys()))
+            random_origin_visits = self._origin_visits[url]
+            if random_origin_visits[0].type == type:
+                return url
+
+    def origin_visit_get_random(self, type: str) -> Mapping[str, Any]:
+        """Randomly select one origin whose visit was successful
+        in the last 3 months.
+
+        Returns:
+            origin dict selected randomly on the dataset
+
+        """
+        random_visit: Dict[str, Any] = {}
+        if not self._origin_visits:  # empty dataset
+            return random_visit
+        url = self._select_random_origin_by_type(type)
+        random_origin_visits = copy.deepcopy(self._origin_visits[url])
+        random_origin_visits.reverse()
+        back_in_the_day = now() - timedelta(weeks=12)  # 3 months back
+        # This should be enough for tests
+        for visit in random_origin_visits:
+            if visit.date > back_in_the_day and visit.status == 'full':
+                random_visit = visit.to_dict()
+                break
+        return random_visit
+
     def origin_get_range(self, origin_from=1, origin_count=100):
         """Retrieve ``origin_count`` origins whose ids are greater or equal
         than ``origin_from``.
diff --git a/swh/storage/sql/30-swh-schema.sql b/swh/storage/sql/30-swh-schema.sql
--- a/swh/storage/sql/30-swh-schema.sql
+++ b/swh/storage/sql/30-swh-schema.sql
@@ -17,7 +17,7 @@
 
 -- latest schema version
 insert into dbversion(version, release, description)
-  values(144, now(), 'Work In Progress');
+  values(145, now(), 'Work In Progress');
 
 -- a SHA1 checksum
 create domain sha1 as bytea check (length(value) = 20);
@@ -466,4 +466,3 @@
 comment on column object_counts_bucketed.bucket_end is 'Upper bound (exclusive) for the bucket';
 comment on column object_counts_bucketed.value is 'Count of objects in the bucket';
 comment on column object_counts_bucketed.last_update is 'Last update for the object count in this bucket';
-
diff --git a/swh/storage/sql/60-swh-indexes.sql b/swh/storage/sql/60-swh-indexes.sql
--- a/swh/storage/sql/60-swh-indexes.sql
+++ b/swh/storage/sql/60-swh-indexes.sql
@@ -125,6 +125,7 @@
 alter table origin_visit add primary key using index origin_visit_pkey;
 
 create index concurrently on origin_visit(date);
+create index concurrently on origin_visit(type, status, date);
 
 alter table origin_visit add constraint origin_visit_origin_fkey foreign key (origin) references origin(id) not valid;
 alter table origin_visit validate constraint origin_visit_origin_fkey;
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -3,15 +3,16 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-
-from collections import defaultdict
 import copy
-from concurrent.futures import ThreadPoolExecutor
-from contextlib import contextmanager
 import datetime
 import itertools
 import json
 
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import contextmanager
+from typing import Any, Dict, Mapping
+
 import dateutil.parser
 import psycopg2
 import psycopg2.pool
@@ -1507,6 +1508,21 @@
         else:
             yield None
 
+    @db_transaction()
+    def origin_visit_get_random(
+            self, type, db=None, cur=None) -> Mapping[str, Any]:
+        """Randomly select one origin from the archive
+
+        Returns:
+            origin dict selected randomly on the dataset if found
+
+        """
+        data: Dict[str, Any] = {}
+        result = db.origin_visit_get_random(type, cur)
+        if result:
+            data = dict(zip(db.origin_visit_get_cols, result))
+        return data
+
     @db_transaction_generator()
     def origin_get_range(self, origin_from=1, origin_count=100,
                          db=None, cur=None):
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -8,8 +8,11 @@
 import datetime
 import itertools
 import queue
+import random
 import threading
+
 from collections import defaultdict
+from datetime import timedelta
 from unittest.mock import Mock
 
 import psycopg2
@@ -936,6 +939,67 @@
         assert len(actual_origin0) == 1
         assert actual_origin0[0]['url'] == data.origin['url']
 
+    def _generate_random_visits(self, nb_visits=100, start=0, end=7):
+        """Generate random visits within the last 2 months (to avoid
+        computations)
+
+        """
+        visits = []
+        today = datetime.datetime.now(tz=datetime.timezone.utc)
+        for _ in range(nb_visits):
+            hours = random.randint(0, 23)
+            minutes = random.randint(0, 59)
+            seconds = random.randint(0, 59)
+            days = random.randint(0, 28)
+            weeks = random.randint(start, end)
+            date_visit = today - timedelta(
+                weeks=weeks, hours=hours, minutes=minutes,
+                seconds=seconds, days=days)
+            visits.append(date_visit)
+        return visits
+
+    def test_origin_visit_get_random(self, swh_storage):
+        swh_storage.origin_add(data.origins)
+        # Add some random visits within the selection range
+        visits = self._generate_random_visits()
+        visit_type = 'git'
+
+        # Add visits to those origins
+        for origin in data.origins:
+            for date_visit in visits:
+                visit = swh_storage.origin_visit_add(
+                    origin['url'], date=date_visit, type=visit_type)
+                swh_storage.origin_visit_update(
+                    origin['url'], visit_id=visit['visit'], status='full')
+
+        swh_storage.refresh_stat_counters()
+
+        stats = swh_storage.stat_counters()
+        assert stats['origin'] == len(data.origins)
+        assert stats['origin_visit'] == len(data.origins) * len(visits)
+
+        random_origin_visit = swh_storage.origin_visit_get_random(visit_type)
+        assert random_origin_visit
+        assert random_origin_visit['origin'] is not None
+        original_urls = [o['url'] for o in data.origins]
+        assert random_origin_visit['origin'] in original_urls
+
+    def test_origin_visit_get_random_nothing_found(self, swh_storage):
+        swh_storage.origin_add(data.origins)
+        visit_type = 'hg'
+        # Add some visits outside of the random generation selection so nothing
+        # will be found by the random selection
+        visits = self._generate_random_visits(nb_visits=3, start=13, end=24)
+        for origin in data.origins:
+            for date_visit in visits:
+                visit = swh_storage.origin_visit_add(
+                    origin['url'], date=date_visit, type=visit_type)
+                swh_storage.origin_visit_update(
+                    origin['url'], visit_id=visit['visit'], status='full')
+
+        random_origin_visit = swh_storage.origin_visit_get_random(visit_type)
+        assert random_origin_visit == {}
+
     def test_origin_get_by_sha1(self, swh_storage):
         assert swh_storage.origin_get(data.origin) is None
         swh_storage.origin_add_one(data.origin)
diff --git a/tox.ini b/tox.ini
--- a/tox.ini
+++ b/tox.ini
@@ -6,6 +6,7 @@
   testing
 deps =
   pytest-cov
+  dev: ipdb
 commands =
   pytest \
     !slow: --hypothesis-profile=fast \