diff --git a/sql/upgrades/144.sql b/sql/upgrades/144.sql
new file mode 100644
--- /dev/null
+++ b/sql/upgrades/144.sql
@@ -0,0 +1,10 @@
+-- SWH DB schema upgrade
+-- from_version: 143
+-- to_version: 144
+-- description: add index on sha1(origin.url)
+
+insert into dbversion(version, release, description)
+      values(144, now(), 'Work In Progress');
+
+create index concurrently on origin using btree(digest(url, 'sha1'));
+
diff --git a/swh/storage/api/client.py b/swh/storage/api/client.py
--- a/swh/storage/api/client.py
+++ b/swh/storage/api/client.py
@@ -139,6 +139,9 @@
                           DeprecationWarning)
         return self.post('origin/get', {'origins': origins})
 
+    def origin_get_sha1(self, sha1s):
+        return self.post('origin/get_sha1', {'sha1s': sha1s})
+
     def origin_search(self, url_pattern, offset=0, limit=50,
                       regexp=False, with_visit=False):
         return self.post('origin/search', {'url_pattern': url_pattern,
diff --git a/swh/storage/api/server.py b/swh/storage/api/server.py
--- a/swh/storage/api/server.py
+++ b/swh/storage/api/server.py
@@ -351,6 +351,13 @@
     return encode_data(get_storage().origin_get(**decode_request(request)))
 
 
+@app.route('/origin/get_sha1', methods=['POST'])
+@timed
+def origin_get_sha1():
+    return encode_data(get_storage().origin_get_sha1(
+        **decode_request(request)))
+
+
 @app.route('/origin/get_range', methods=['POST'])
 @timed
 def origin_get_range():
diff --git a/swh/storage/converters.py b/swh/storage/converters.py
--- a/swh/storage/converters.py
+++ b/swh/storage/converters.py
@@ -7,6 +7,7 @@
 
 from swh.core.utils import decode_with_escape, encode_with_unescape
 from swh.model import identifiers
+from swh.model.hashutil import MultiHash
 
 
 DEFAULT_AUTHOR = {
@@ -310,3 +311,10 @@
         ret['object_id'] = db_release['object_id']
 
     return ret
+
+
+def origin_url_to_sha1(origin_url):
+    """Convert an origin URL to a sha1. Encodes URL to utf-8."""
+    return MultiHash.from_data(
+        origin_url.encode('utf-8'), {'sha1'}
+    ).digest()['sha1']
diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -637,6 +637,17 @@
         yield from execute_values_generator(
             cur, query, ((url,) for url in origins))
 
+    def origin_get_by_sha1(self, sha1s, cur=None):
+        """Retrieve origins from the sha1s of their urls, if found."""
+        cur = self._cursor(cur)
+
+        query = """SELECT %s FROM (VALUES %%s) as t(sha1)
+                   LEFT JOIN origin ON t.sha1 = digest(origin.url, 'sha1')
+                """ % ','.join('origin.' + col for col in self.origin_cols)
+
+        yield from execute_values_generator(
+            cur, query, ((sha1,) for sha1 in sha1s))
+
     def origin_id_get_by_url(self, origins, cur=None):
         """Retrieve origin `(type, url)` from urls if found."""
         cur = self._cursor(cur)
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -22,6 +22,7 @@
 from swh.objstorage.exc import ObjNotFoundError
 
 from .storage import get_journal_writer
+from .converters import origin_url_to_sha1
 
 # Max block size of contents to return
 BULK_BLOCK_CONTENT_LEN_MAX = 10000
@@ -52,6 +53,7 @@
         self._snapshots = {}
         self._origins = {}
         self._origins_by_id = []
+        self._origins_by_sha1 = {}
         self._origin_visits = {}
         self._persons = []
         self._origin_metadata = defaultdict(list)
@@ -1071,6 +1073,13 @@
         else:
             return results
 
+    def origin_get_sha1(self, sha1s):
+        """Return origins matching the given sha1s"""
+        return [
+            self._convert_origin(self._origins_by_sha1.get(sha1))
+            for sha1 in sha1s
+        ]
+
     def origin_get_range(self, origin_from=1, origin_count=100):
         """Retrieve ``origin_count`` origins whose ids are greater
         or equal than ``origin_from``.
@@ -1196,6 +1205,7 @@
 
             assert len(self._origins_by_id) == origin_id
             self._origins[origin.url] = origin
+            self._origins_by_sha1[origin_url_to_sha1(origin.url)] = origin
             self._origin_visits[origin.url] = []
             self._objects[origin.url].append(('origin', origin.url))
 
diff --git a/swh/storage/sql/30-swh-schema.sql b/swh/storage/sql/30-swh-schema.sql
--- a/swh/storage/sql/30-swh-schema.sql
+++ b/swh/storage/sql/30-swh-schema.sql
@@ -17,7 +17,7 @@
 
 -- latest schema version
 insert into dbversion(version, release, description)
-      values(143, now(), 'Work In Progress');
+      values(144, now(), 'Work In Progress');
 
 -- a SHA1 checksum
 create domain sha1 as bytea check (length(value) = 20);
diff --git a/swh/storage/sql/60-swh-indexes.sql b/swh/storage/sql/60-swh-indexes.sql
--- a/swh/storage/sql/60-swh-indexes.sql
+++ b/swh/storage/sql/60-swh-indexes.sql
@@ -16,6 +16,7 @@
 
 create index concurrently on origin using gin (url gin_trgm_ops);
 create index concurrently on origin using hash (url);
+create index concurrently on origin using btree(digest(url, 'sha1'));
 
 -- skipped_content
 
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -1488,6 +1488,24 @@
         else:
             return [None if res['url'] is None else res for res in results]
 
+    @db_transaction_generator(statement_timeout=500)
+    def origin_get_sha1(self, sha1s, db=None, cur=None):
+        """Return origins, identified by the sha1 of their URLs.
+
+        Args:
+            sha1s (list[bytes]): a list of sha1s
+
+        Yields:
+            dict: for each sha1, the origin dictionary with the keys:
+                - url: origin's url
+            or None if the origin was not found
+        """
+        for line in db.origin_get_by_sha1(sha1s, cur):
+            if line[0] is not None:
+                yield dict(zip(db.origin_cols, line))
+            else:
+                yield None
+
     @db_transaction_generator()
     def origin_get_range(self, origin_from=1, origin_count=100,
                          db=None, cur=None):
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -23,6 +23,7 @@
 from swh.model.hashutil import hash_to_bytes
 from swh.model.hypothesis_strategies import objects
 from swh.storage import HashCollision
+from swh.storage.converters import origin_url_to_sha1 as sha1
 
 from .storage_data import data
 
@@ -935,6 +936,24 @@
         assert len(actual_origin0) == 1
         assert actual_origin0[0]['url'] == data.origin['url']
 
+    def test_origin_get_sha1(self, swh_storage):
+        assert swh_storage.origin_get(data.origin) is None
+        swh_storage.origin_add_one(data.origin)
+
+        origins = list(swh_storage.origin_get_sha1([
+            sha1(data.origin['url'])
+        ]))
+        assert len(origins) == 1
+        assert origins[0]['url'] == data.origin['url']
+
+    def test_origin_get_sha1_not_found(self, swh_storage):
+        assert swh_storage.origin_get(data.origin) is None
+        origins = list(swh_storage.origin_get_sha1([
+            sha1(data.origin['url'])
+        ]))
+        assert len(origins) == 1
+        assert origins[0] is None
+
     def test_origin_search_single_result(self, swh_storage):
         found_origins = list(swh_storage.origin_search(data.origin['url']))
         assert len(found_origins) == 0