diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -1335,6 +1335,22 @@
     def extid_get_from_token(self, token: int, *, statement) -> Iterable[ExtIDRow]:
         return map(ExtIDRow.from_dict, self._execute_with_retries(statement, [token]),)
 
+    # Rows are partitioned by token(extid_type, extid), then ordered (aka. "clustered")
+    # by (extid_type, extid, extid_version, ...). This means that, without knowing the
+    # exact extid_type and extid, we need to scan the whole partition, which should be
+    # reasonably small. We can change the schema later if this becomes an issue.
+    @_prepared_select_statement(
+        ExtIDRow,
+        "WHERE token(extid_type, extid) = ? AND extid_version = ? ALLOW FILTERING",
+    )
+    def extid_get_from_token_and_extid_version(
+        self, token: int, extid_version: int, *, statement
+    ) -> Iterable[ExtIDRow]:
+        return map(
+            ExtIDRow.from_dict,
+            self._execute_with_retries(statement, [token, extid_version]),
+        )
+
     @_prepared_select_statement(
         ExtIDRow, "WHERE extid_type=? AND extid=?",
     )
@@ -1346,17 +1362,50 @@
             self._execute_with_retries(statement, [extid_type, extid]),
         )
 
+    @_prepared_select_statement(
+        ExtIDRow, "WHERE extid_type=? AND extid=? AND extid_version = ?",
+    )
+    def extid_get_from_extid_and_version(
+        self, extid_type: str, extid: bytes, extid_version: int, *, statement
+    ) -> Iterable[ExtIDRow]:
+        return map(
+            ExtIDRow.from_dict,
+            self._execute_with_retries(statement, [extid_type, extid, extid_version]),
+        )
+
     def extid_get_from_target(
-        self, target_type: str, target: bytes
+        self,
+        target_type: str,
+        target: bytes,
+        extid_type: Optional[str] = None,
+        extid_version: Optional[int] = None,
     ) -> Iterable[ExtIDRow]:
         for token in self._extid_get_tokens_from_target(target_type, target):
             if token is not None:
-                for extid in self.extid_get_from_token(token):
+                if extid_type is not None and extid_version is not None:
+                    extids = self.extid_get_from_token_and_extid_version(
+                        token, extid_version
+                    )
+                else:
+                    extids = self.extid_get_from_token(token)
+
+                for extid in extids:
                     # re-check the extid against target (in case of murmur3 collision)
                     if (
                         extid is not None
                         and extid.target_type == target_type
                         and extid.target == target
+                        and (
+                            (extid_version is None and extid_type is None)
+                            or (
+                                (
+                                    extid_version is not None
+                                    and extid.extid_version == extid_version
+                                    and extid_type is not None
+                                    and extid.extid_type == extid_type
+                                )
+                            )
+                        )
                     ):
                         yield extid
 
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -1617,9 +1617,11 @@
         for extid in extids:
             target_type = extid.target.object_type.value
             target = extid.target.object_id
+            extid_version = extid.extid_version
+            extid_type = extid.extid_type
             extidrow = ExtIDRow(
-                extid_type=extid.extid_type,
-                extid_version=extid.extid_version,
+                extid_type=extid_type,
+                extid_version=extid_version,
                 extid=extid.extid,
                 target_type=target_type,
                 target=target,
@@ -1634,10 +1636,17 @@
         return {"extid:add": inserted}
 
     @timed
-    def extid_get_from_extid(self, id_type: str, ids: List[bytes]) -> List[ExtID]:
+    def extid_get_from_extid(
+        self, id_type: str, ids: List[bytes], version: Optional[int] = None
+    ) -> List[ExtID]:
         result: List[ExtID] = []
         for extid in ids:
-            extidrows = list(self._cql_runner.extid_get_from_extid(id_type, extid))
+            if version is not None:
+                extidrows = self._cql_runner.extid_get_from_extid_and_version(
+                    id_type, extid, version
+                )
+            else:
+                extidrows = self._cql_runner.extid_get_from_extid(id_type, extid)
             result.extend(
                 ExtID(
                     extid_type=extidrow.extid_type,
@@ -1653,13 +1662,26 @@
 
     @timed
     def extid_get_from_target(
-        self, target_type: SwhidObjectType, ids: List[Sha1Git]
+        self,
+        target_type: SwhidObjectType,
+        ids: List[Sha1Git],
+        extid_type: Optional[str] = None,
+        extid_version: Optional[int] = None,
     ) -> List[ExtID]:
+        if (extid_version is not None and extid_type is None) or (
+            extid_version is None and extid_type is not None
+        ):
+            raise ValueError("You must provide both extid_type and extid_version")
+
         result: List[ExtID] = []
         for target in ids:
-            extidrows = list(
-                self._cql_runner.extid_get_from_target(target_type.value, target)
+            extidrows = self._cql_runner.extid_get_from_target(
+                target_type.value,
+                target,
+                extid_type=extid_type,
+                extid_version=extid_version,
             )
+
             result.extend(
                 ExtID(
                     extid_type=extidrow.extid_type,
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2015-2020  The Software Heritage developers
+# Copyright (C) 2015-2021  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -706,15 +706,40 @@
         )
         return self._extid.get_from_primary_key(primary_key)
 
-    def extid_get_from_extid(self, extid_type: str, extid: bytes) -> Iterable[ExtIDRow]:
+    def extid_get_from_extid(
+        self, extid_type: str, extid: bytes,
+    ) -> Iterable[ExtIDRow]:
         return (
             row
             for pk, row in self._extid.iter_all()
             if row.extid_type == extid_type and row.extid == extid
         )
 
-    def extid_get_from_target(
-        self, target_type: str, target: bytes
+    def extid_get_from_extid_and_version(
+        self, extid_type: str, extid: bytes, extid_version: int,
+    ) -> Iterable[ExtIDRow]:
+        return (
+            row
+            for pk, row in self._extid.iter_all()
+            if row.extid_type == extid_type
+            and row.extid == extid
+            and row.extid_version == extid_version
+        )
+
+    def _extid_get_from_target_with_type_and_version(
+        self, target_type: str, target: bytes, extid_type: str, extid_version: int,
+    ) -> Iterable[ExtIDRow]:
+        return (
+            row
+            for pk, row in self._extid.iter_all()
+            if row.target_type == target_type
+            and row.target == target
+            and row.extid_version == extid_version
+            and row.extid_type == extid_type
+        )
+
+    def _extid_get_from_target(
+        self, target_type: str, target: bytes,
     ) -> Iterable[ExtIDRow]:
         return (
             row
@@ -722,6 +747,26 @@
             if row.target_type == target_type and row.target == target
         )
 
+    def extid_get_from_target(
+        self,
+        target_type: str,
+        target: bytes,
+        extid_type: Optional[str] = None,
+        extid_version: Optional[int] = None,
+    ) -> Iterable[ExtIDRow]:
+        if (extid_version is not None and extid_type is None) or (
+            extid_version is None and extid_type is not None
+        ):
+            raise ValueError("You must provide both extid_type and extid_version")
+
+        if extid_type is not None and extid_version is not None:
+            extids = self._extid_get_from_target_with_type_and_version(
+                target_type, target, extid_type, extid_version
+            )
+        else:
+            extids = self._extid_get_from_target(target_type, target)
+        return extids
+
 
 class InMemoryStorage(CassandraStorage):
     _cql_runner: InMemoryCqlRunner  # type: ignore
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2015-2020  The Software Heritage developers
+# Copyright (C) 2015-2021  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -532,12 +532,15 @@
         ...
 
     @remote_api_endpoint("extid/from_extid")
-    def extid_get_from_extid(self, id_type: str, ids: List[bytes]) -> List[ExtID]:
+    def extid_get_from_extid(
+        self, id_type: str, ids: List[bytes], version: Optional[int] = None
+    ) -> List[ExtID]:
         """Get ExtID objects from external IDs
 
         Args:
             id_type: type of the given external identifiers (e.g. 'mercurial')
             ids: list of external IDs
+            version: (Optional) extid_version to filter on (no filtering if None)
 
         Returns:
             list of ExtID objects
@@ -547,13 +550,24 @@
 
     @remote_api_endpoint("extid/from_target")
     def extid_get_from_target(
-        self, target_type: ObjectType, ids: List[Sha1Git]
+        self,
+        target_type: ObjectType,
+        ids: List[Sha1Git],
+        extid_type: Optional[str] = None,
+        extid_version: Optional[int] = None,
     ) -> List[ExtID]:
         """Get ExtID objects from target IDs and target_type
 
         Args:
             target_type: type the SWH object
             ids: list of target IDs
+            extid_type: (Optional) extid_type to use as filter. This cannot be empty if
+                extid_version is provided.
+            extid_version: (Optional) version to use as filter. This cannot be empty if
+                extid_type is provided.
+
+        Raises:
+            ValueError if extid_version is provided without extid_type and vice versa.
 
         Returns:
             list of ExtID objects
diff --git a/swh/storage/postgresql/db.py b/swh/storage/postgresql/db.py
--- a/swh/storage/postgresql/db.py
+++ b/swh/storage/postgresql/db.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2015-2020  The Software Heritage developers
+# Copyright (C) 2015-2021  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -845,19 +845,32 @@
 
     extid_cols = ["extid", "extid_version", "extid_type", "target", "target_type"]
 
-    def extid_get_from_extid_list(self, extid_type, ids, cur=None):
+    def extid_get_from_extid_list(
+        self, extid_type: str, ids: List[bytes], version: Optional[int] = None, cur=None
+    ):
         cur = self._cursor(cur)
         query_keys = ", ".join(
             self.mangle_query_key(k, "extid") for k in self.extid_cols
         )
-        sql = """
-            SELECT %s
-            FROM (VALUES %%s) as t(sortkey, extid, extid_type)
+        filter_query = ""
+        if version is not None:
+            # Bind the caller-supplied value (it may come through the remote API)
+            # instead of formatting it into the SQL text, and double any "%" since
+            # execute_values_generator presumably parses "%s" / "%%" like psycopg2's
+            # execute_values does (TODO confirm).
+            filter_query = (
+                cur.mogrify("WHERE extid_version=%s", (version,))
+                .decode()
+                .replace("%", "%%")
+            )
+
+        sql = f"""
+            SELECT {query_keys}
+            FROM (VALUES %s) as t(sortkey, extid, extid_type)
             LEFT JOIN extid USING (extid, extid_type)
+            {filter_query}
             ORDER BY sortkey
-            """ % (
-            query_keys,
-        )
+            """
 
         yield from execute_values_generator(
             cur,
@@ -865,7 +878,14 @@
             (((sortkey, extid, extid_type) for sortkey, extid in enumerate(ids))),
         )
 
-    def extid_get_from_swhid_list(self, target_type, ids, cur=None):
+    def extid_get_from_swhid_list(
+        self,
+        target_type: str,
+        ids: List[bytes],
+        extid_version: Optional[int] = None,
+        extid_type: Optional[str] = None,
+        cur=None,
+    ):
         cur = self._cursor(cur)
         target_type = ObjectType(
             target_type
@@ -873,14 +893,27 @@
         query_keys = ", ".join(
             self.mangle_query_key(k, "extid") for k in self.extid_cols
         )
-        sql = """
-            SELECT %s
-            FROM (VALUES %%s) as t(sortkey, target, target_type)
+        filter_query = ""
+        if extid_version is not None and extid_type is not None:
+            # Bind the values with mogrify and double any "%" (e.g. inside
+            # extid_type) so the text is not read as a placeholder afterwards.
+            filter_query = (
+                cur.mogrify(
+                    "WHERE extid_version=%s AND extid_type=%s",
+                    (extid_version, extid_type),
+                )
+                .decode()
+                .replace("%", "%%")
+            )
+
+        sql = f"""
+            SELECT {query_keys}
+            FROM (VALUES %s) as t(sortkey, target, target_type)
             LEFT JOIN extid USING (target, target_type)
+            {filter_query}
             ORDER BY sortkey
-            """ % (
-            query_keys,
-        )
+            """
+
         yield from execute_values_generator(
             cur,
             sql,
diff --git a/swh/storage/postgresql/storage.py b/swh/storage/postgresql/storage.py
--- a/swh/storage/postgresql/storage.py
+++ b/swh/storage/postgresql/storage.py
@@ -692,10 +692,16 @@
     @timed
     @db_transaction()
     def extid_get_from_extid(
-        self, id_type: str, ids: List[bytes], *, db: Db, cur=None
+        self,
+        id_type: str,
+        ids: List[bytes],
+        version: Optional[int] = None,
+        *,
+        db: Db,
+        cur=None,
     ) -> List[ExtID]:
         extids = []
-        for row in db.extid_get_from_extid_list(id_type, ids, cur):
+        for row in db.extid_get_from_extid_list(id_type, ids, version=version, cur=cur):
             if row[0] is not None:
                 extids.append(converters.db_to_extid(dict(zip(db.extid_cols, row))))
         return extids
@@ -703,10 +709,28 @@
     @timed
     @db_transaction()
     def extid_get_from_target(
-        self, target_type: ObjectType, ids: List[Sha1Git], *, db: Db, cur=None
+        self,
+        target_type: ObjectType,
+        ids: List[Sha1Git],
+        extid_type: Optional[str] = None,
+        extid_version: Optional[int] = None,
+        *,
+        db: Db,
+        cur=None,
     ) -> List[ExtID]:
         extids = []
-        for row in db.extid_get_from_swhid_list(target_type.value, ids, cur):
+        if (extid_version is not None and extid_type is None) or (
+            extid_version is None and extid_type is not None
+        ):
+            raise ValueError("You must provide both extid_type and extid_version")
+
+        for row in db.extid_get_from_swhid_list(
+            target_type.value,
+            ids,
+            extid_version=extid_version,
+            extid_type=extid_type,
+            cur=cur,
+        ):
             if row[0] is not None:
                 extids.append(converters.db_to_extid(dict(zip(db.extid_cols, row))))
         return extids
diff --git a/swh/storage/tests/storage_tests.py b/swh/storage/tests/storage_tests.py
--- a/swh/storage/tests/storage_tests.py
+++ b/swh/storage/tests/storage_tests.py
@@ -17,6 +17,7 @@
 from hypothesis import HealthCheck, given, settings, strategies
 import pytest
 
+from swh.core.api import RemoteException
 from swh.core.api.classes import stream_results
 from swh.model import from_disk
 from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes
@@ -1350,6 +1351,7 @@
             ExtID(
                 extid=extid,
                 extid_type="git",
+                extid_version=0,
                 target=CoreSWHID(object_id=extid, object_type=ObjectType.REVISION,),
             )
             for extid in ids
@@ -1373,6 +1375,52 @@
             objs = swh_storage.extid_get_from_target(ObjectType.REVISION, [swhid])
             assert len(objs) == 2
             assert set(obj.extid_version for obj in objs) == {0, 1}
+        for version in [0, 1]:
+            for git_id in ids:
+                objs = swh_storage.extid_get_from_extid(
+                    "git", [git_id], version=version
+                )
+                assert len(objs) == 1
+                assert objs[0].extid_version == version
+            for swhid in ids:
+                objs = swh_storage.extid_get_from_target(
+                    ObjectType.REVISION,
+                    [swhid],
+                    extid_version=version,
+                    extid_type="git",
+                )
+                assert len(objs) == 1
+                assert objs[0].extid_version == version
+                assert objs[0].extid_type == "git"
+
+        # A "%" inside a filter value is data, not a query placeholder
+        objs = swh_storage.extid_get_from_target(
+            ObjectType.REVISION, [ids[0]], extid_version=0, extid_type="git%s"
+        )
+        assert objs == []
+
+    def test_extid_version_behavior_failure(self, swh_storage, sample_data):
+        """Calls with wrong input should raise"""
+        ids = [
+            revision.id
+            for revision in sample_data.revisions
+            if revision.type.value == "git"
+        ]
+
+        # Other edge cases
+        with pytest.raises(
+            (ValueError, RemoteException), match="both extid_type and extid_version"
+        ):
+            swh_storage.extid_get_from_target(
+                ObjectType.REVISION, [ids[0]], extid_version=0
+            )
+
+        with pytest.raises(
+            (ValueError, RemoteException), match="both extid_type and extid_version"
+        ):
+            swh_storage.extid_get_from_target(
+                ObjectType.REVISION, [ids[0]], extid_type="git"
+            )
 
     def test_release_add(self, swh_storage, sample_data):
         release, release2 = sample_data.releases[:2]