diff --git a/sql/upgrades/169.sql b/sql/upgrades/169.sql
new file mode 100644
--- /dev/null
+++ b/sql/upgrades/169.sql
@@ -0,0 +1,45 @@
+-- SWH DB schema upgrade
+-- from_version: 168
+-- to_version: 169
+-- description: add branch name filters to swh_snapshot_get_by_id
+
+insert into dbversion(version, release, description)
+      values(169, now(), 'Work In Progress');
+
+-- The signature changes, so "create or replace" would only add an overload
+-- next to the old function: drop it first to keep a single candidate.
+drop function if exists swh_snapshot_get_by_id(sha1_git, bytea, bigint, snapshot_target[]);
+
+create or replace function swh_snapshot_get_by_id(id sha1_git,
+    branches_from bytea default '', branches_count bigint default null,
+    target_types snapshot_target[] default NULL,
+    branch_name_include_pattern bytea default NULL,
+    branch_name_exclude_prefix bytea default NULL)
+  returns setof snapshot_result
+  language sql
+  stable
+as $$
+  -- with small limits, the "naive" version of this query can degenerate into
+  -- using the deduplication index on snapshot_branch (name, target,
+  -- target_type); The planner happily scans several hundred million rows.
+
+  -- Do the query in two steps: first pull the relevant branches for the given
+  -- snapshot (filtering them by type), then do the limiting. This two-step
+  -- process guides the planner into using the proper index.
+  with filtered_snapshot_branches as (
+    select swh_snapshot_get_by_id.id as snapshot_id, name, target, target_type
+    from snapshot_branches
+    inner join snapshot_branch on snapshot_branches.branch_id = snapshot_branch.object_id
+    where snapshot_id = (select object_id from snapshot where snapshot.id = swh_snapshot_get_by_id.id)
+      and (target_types is null or target_type = any(target_types))
+    order by name
+  )
+  select snapshot_id, name, target, target_type
+  from filtered_snapshot_branches
+  where name >= branches_from
+    -- position() rather than like, so that %, _ and \ in the pattern/prefix
+    -- are matched literally, like the other storage backends do
+    and (branch_name_include_pattern is null or position(branch_name_include_pattern in name) > 0)
+    and (branch_name_exclude_prefix is null or position(branch_name_exclude_prefix in name) <> 1)
+  order by name limit branches_count;
+$$;
diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -630,7 +630,7 @@
     @_prepared_select_statement(
         SnapshotBranchRow, "WHERE snapshot_id = ? AND name >= ? LIMIT ?"
     )
-    def snapshot_branch_get(
+    def snapshot_branch_get_from_name(
         self, snapshot_id: Sha1Git, from_: bytes, limit: int, *, statement
     ) -> Iterable[SnapshotBranchRow]:
         return map(
@@ -638,6 +638,57 @@
             self._execute_with_retries(statement, [snapshot_id, from_, limit]),
         )
 
+    @_prepared_select_statement(
+        SnapshotBranchRow, "WHERE snapshot_id = ? AND name >= ? AND name < ? LIMIT ?"
+    )
+    def snapshot_branch_get_before_name(
+        self,
+        snapshot_id: Sha1Git,
+        from_: bytes,
+        before: bytes,
+        limit: int,
+        *,
+        statement,
+    ) -> Iterable[SnapshotBranchRow]:
+        return map(
+            SnapshotBranchRow.from_dict,
+            self._execute_with_retries(statement, [snapshot_id, from_, before, limit]),
+        )
+
+    def snapshot_branch_get(
+        self,
+        snapshot_id: Sha1Git,
+        from_: bytes,
+        limit: int,
+        branch_name_exclude_prefix: Optional[bytes] = None,
+    ) -> Iterable[SnapshotBranchRow]:
+        """Returns at most ``limit`` branches of the snapshot, in name order,
+        starting from ``from_`` and skipping those whose name starts with
+        ``branch_name_exclude_prefix`` (if provided)."""
+        if branch_name_exclude_prefix is None:
+            return self.snapshot_branch_get_from_name(snapshot_id, from_, limit)
+
+        prefix = branch_name_exclude_prefix
+        # Names starting with the prefix form the contiguous range
+        # [prefix, successor): fetch branches before it, then after it.
+        branches = list(
+            self.snapshot_branch_get_before_name(snapshot_id, from_, prefix, limit)
+        )
+        # The successor is the smallest byte string greater than all names
+        # starting with the prefix: strip trailing 0xff bytes, then increment
+        # the last one. It does not exist if the prefix is empty or only made
+        # of 0xff bytes, as then every name >= prefix starts with the prefix.
+        stripped_prefix = prefix.rstrip(b"\xff")
+        if len(branches) < limit and stripped_prefix:
+            successor = stripped_prefix[:-1] + bytes([stripped_prefix[-1] + 1])
+            # never restart before from_, which may be past the excluded range
+            branches.extend(
+                self.snapshot_branch_get_from_name(
+                    snapshot_id, max(from_, successor), limit - len(branches)
+                )
+            )
+        return branches
+
     ##########################
     # 'origin' table
     ##########################
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -681,6 +681,8 @@
         branches_from: bytes = b"",
         branches_count: int = 1000,
         target_types: Optional[List[str]] = None,
+        branch_name_include_pattern: Optional[bytes] = None,
+        branch_name_exclude_prefix: Optional[bytes] = None,
     ) -> Optional[PartialBranches]:
         if self._cql_runner.snapshot_missing([snapshot_id]):
             # Makes sure we don't fetch branches for a snapshot that is
@@ -691,7 +693,10 @@
         while len(branches) < branches_count + 1:
             new_branches = list(
                 self._cql_runner.snapshot_branch_get(
-                    snapshot_id, branches_from, branches_count + 1
+                    snapshot_id,
+                    branches_from,
+                    branches_count + 1,
+                    branch_name_exclude_prefix,
                 )
             )
 
@@ -710,6 +715,15 @@
                     if branch.target is not None and branch.target_type in target_types
                 ]
 
+            # Filter by branch name (plain substring match)
+            if branch_name_include_pattern:
+                new_branches_filtered = [
+                    branch
+                    for branch in new_branches_filtered
+                    if branch.name is not None
+                    and branch_name_include_pattern in branch.name
+                ]
+
             branches.extend(new_branches_filtered)
 
             if len(new_branches) < branches_count + 1:
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -402,11 +402,18 @@
         return counts
 
     def snapshot_branch_get(
-        self, snapshot_id: Sha1Git, from_: bytes, limit: int
+        self,
+        snapshot_id: Sha1Git,
+        from_: bytes,
+        limit: int,
+        branch_name_exclude_prefix: Optional[bytes] = None,
     ) -> Iterable[SnapshotBranchRow]:
+        prefix = branch_name_exclude_prefix
         count = 0
         for branch in self._snapshot_branches.get_from_partition_key((snapshot_id,)):
-            if branch.name >= from_:
+            if branch.name >= from_ and (
+                prefix is None or not branch.name.startswith(prefix)
+            ):
                 count += 1
                 yield branch
                 if count >= limit:
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -748,6 +748,8 @@
         branches_from: bytes = b"",
         branches_count: int = 1000,
         target_types: Optional[List[str]] = None,
+        branch_name_include_pattern: Optional[bytes] = None,
+        branch_name_exclude_prefix: Optional[bytes] = None,
     ) -> Optional[PartialBranches]:
         """Get the content, possibly partial, of a snapshot with the given id
 
@@ -764,6 +766,10 @@
                 target types of branch to return (possible values that can be
                 contained in that list are `'content', 'directory',
                 'revision', 'release', 'snapshot', 'alias'`)
+            branch_name_include_pattern: if provided, only return branches whose name
+                contains given pattern (plain substring match, no wildcards)
+            branch_name_exclude_prefix: if provided, do not return branches whose name
+                starts with given prefix
 
         Returns:
             dict: None if the snapshot does not exist;
diff --git a/swh/storage/postgresql/db.py b/swh/storage/postgresql/db.py
--- a/swh/storage/postgresql/db.py
+++ b/swh/storage/postgresql/db.py
@@ -29,7 +29,7 @@
 
     """
 
-    current_version = 168
+    current_version = 169
 
     def mktemp_dir_entry(self, entry_type, cur=None):
         self._cursor(cur).execute(
@@ -259,17 +259,29 @@
         branches_from=b"",
         branches_count=None,
         target_types=None,
+        branch_name_include_pattern=None,
+        branch_name_exclude_prefix=None,
         cur=None,
     ):
         cur = self._cursor(cur)
 
         query = """\
-           SELECT %s
-           FROM swh_snapshot_get_by_id(%%s, %%s, %%s, %%s :: snapshot_target[])
+            SELECT %s
+            FROM swh_snapshot_get_by_id(%%s, %%s, %%s, %%s :: snapshot_target[], %%s, %%s)
         """ % ", ".join(
             self.snapshot_get_cols
         )
 
-        cur.execute(query, (snapshot_id, branches_from, branches_count, target_types))
+        cur.execute(
+            query,
+            (
+                snapshot_id,
+                branches_from,
+                branches_count,
+                target_types,
+                branch_name_include_pattern,
+                branch_name_exclude_prefix,
+            ),
+        )
 
         yield from cur
 
diff --git a/swh/storage/postgresql/storage.py b/swh/storage/postgresql/storage.py
--- a/swh/storage/postgresql/storage.py
+++ b/swh/storage/postgresql/storage.py
@@ -816,6 +816,8 @@
         branches_from: bytes = b"",
         branches_count: int = 1000,
         target_types: Optional[List[str]] = None,
+        branch_name_include_pattern: Optional[bytes] = None,
+        branch_name_exclude_prefix: Optional[bytes] = None,
         db=None,
         cur=None,
     ) -> Optional[PartialBranches]:
@@ -834,6 +836,8 @@
                 # optimal performances
                 branches_count=max(branches_count + 1, 10),
                 target_types=target_types,
+                branch_name_include_pattern=branch_name_include_pattern,
+                branch_name_exclude_prefix=branch_name_exclude_prefix,
                 cur=cur,
             )
         )
diff --git a/swh/storage/sql/30-schema.sql b/swh/storage/sql/30-schema.sql
--- a/swh/storage/sql/30-schema.sql
+++ b/swh/storage/sql/30-schema.sql
@@ -17,7 +17,7 @@
 
 -- latest schema version
 insert into dbversion(version, release, description)
-      values(168, now(), 'Work In Progress');
+      values(169, now(), 'Work In Progress');
 
 -- a SHA1 checksum
 create domain sha1 as bytea check (length(value) = 20);
diff --git a/swh/storage/sql/40-funcs.sql b/swh/storage/sql/40-funcs.sql
--- a/swh/storage/sql/40-funcs.sql
+++ b/swh/storage/sql/40-funcs.sql
@@ -674,7 +674,9 @@
 
 create or replace function swh_snapshot_get_by_id(id sha1_git,
     branches_from bytea default '', branches_count bigint default null,
-    target_types snapshot_target[] default NULL)
+    target_types snapshot_target[] default NULL,
+    branch_name_include_pattern bytea default NULL,
+    branch_name_exclude_prefix bytea default NULL)
   returns setof snapshot_result
   language sql
   stable
@@ -697,6 +699,10 @@
   select snapshot_id, name, target, target_type
   from filtered_snapshot_branches
   where name >= branches_from
+    -- position() rather than like, so that %, _ and \ in the pattern/prefix
+    -- are matched literally, like the other storage backends do
+    and (branch_name_include_pattern is null or position(branch_name_include_pattern in name) > 0)
+    and (branch_name_exclude_prefix is null or position(branch_name_exclude_prefix in name) <> 1)
   order by name limit branches_count;
 $$;
 
diff --git a/swh/storage/tests/storage_tests.py b/swh/storage/tests/storage_tests.py
--- a/swh/storage/tests/storage_tests.py
+++ b/swh/storage/tests/storage_tests.py
@@ -10,6 +10,7 @@
 import itertools
 import math
 import random
+import re
 from typing import Any, ClassVar, Dict, Iterator, Optional
 
 import attr
@@ -32,6 +33,7 @@
     Revision,
     SkippedContent,
     Snapshot,
+    SnapshotBranch,
     TargetType,
 )
 from swh.storage import get_storage
@@ -3159,6 +3161,162 @@
         assert len(branches) == 1
         assert alias1 in branches
 
+    def test_snapshot_add_get_by_branches_name_pattern(self, swh_storage, sample_data):
+        snapshot = Snapshot(
+            branches={
+                b"refs/heads/master": SnapshotBranch(
+                    target=sample_data.revision.id, target_type=TargetType.REVISION,
+                ),
+                b"refs/heads/incoming": SnapshotBranch(
+                    target=sample_data.revision.id, target_type=TargetType.REVISION,
+                ),
+                b"refs/pull/1": SnapshotBranch(
+                    target=sample_data.revision.id, target_type=TargetType.REVISION,
+                ),
+                b"refs/pull/2": SnapshotBranch(
+                    target=sample_data.revision.id, target_type=TargetType.REVISION,
+                ),
+                b"dangling": None,
+            },
+        )
+        swh_storage.snapshot_add([snapshot])
+
+        for include_pattern, exclude_prefix, nb_results in (
+            (b"pull", None, 2),
+            (b"incoming", None, 1),
+            (b"dangling", None, 1),
+            (None, b"refs/heads/", 3),
+            (b"refs", b"refs/heads/master", 3),
+        ):
+            branches = swh_storage.snapshot_get_branches(
+                snapshot.id,
+                branch_name_include_pattern=include_pattern,
+                branch_name_exclude_prefix=exclude_prefix,
+            )["branches"]
+            assert len(branches) == nb_results
+            for branch_name in branches:
+                if include_pattern:
+                    assert include_pattern in branch_name
+                if exclude_prefix:
+                    assert not branch_name.startswith(exclude_prefix)
+
+    def test_snapshot_add_get_by_branches_name_pattern_filtered_paginated(
+        self, swh_storage, sample_data
+    ):
+        pattern = "foo"
+        nb_branches_by_target_type = 10
+        branches = {}
+        for i in range(nb_branches_by_target_type):
+            branches[f"branch/directory/bar{i}".encode()] = SnapshotBranch(
+                target=sample_data.directory.id, target_type=TargetType.DIRECTORY,
+            )
+            branches[f"branch/revision/bar{i}".encode()] = SnapshotBranch(
+                target=sample_data.revision.id, target_type=TargetType.REVISION,
+            )
+            branches[f"branch/directory/{pattern}{i}".encode()] = SnapshotBranch(
+                target=sample_data.directory.id, target_type=TargetType.DIRECTORY,
+            )
+            branches[f"branch/revision/{pattern}{i}".encode()] = SnapshotBranch(
+                target=sample_data.revision.id, target_type=TargetType.REVISION,
+            )
+
+        snapshot = Snapshot(branches=branches)
+        swh_storage.snapshot_add([snapshot])
+
+        regexp = re.compile(pattern)
+        branches_count = nb_branches_by_target_type // 2
+
+        for target_type in (
+            TargetType.DIRECTORY,
+            TargetType.REVISION,
+        ):
+            target_type_str = target_type.value
+            partial_branches = swh_storage.snapshot_get_branches(
+                snapshot.id,
+                branch_name_include_pattern=pattern.encode(),
+                target_types=[target_type_str],
+                branches_count=branches_count,
+            )
+            branches = partial_branches["branches"]
+
+            assert len(branches) == branches_count
+            for branch_name, branch_data in branches.items():
+                assert regexp.search(branch_name.decode("utf-8"))
+                assert branch_data.target_type == target_type
+            for i in range(branches_count):
+                assert f"branch/{target_type_str}/{pattern}{i}".encode() in branches
+            assert (
+                partial_branches["next_branch"]
+                == f"branch/{target_type_str}/{pattern}{branches_count}".encode()
+            )
+
+            partial_branches = swh_storage.snapshot_get_branches(
+                snapshot.id,
+                branch_name_include_pattern=pattern.encode(),
+                target_types=[target_type_str],
+                branches_from=partial_branches["next_branch"],
+            )
+            branches = partial_branches["branches"]
+
+            assert len(branches) == branches_count
+            for branch_name, branch_data in branches.items():
+                assert regexp.search(branch_name.decode("utf-8"))
+                assert branch_data.target_type == target_type
+            for i in range(branches_count, 2 * branches_count):
+                assert f"branch/{target_type_str}/{pattern}{i}".encode() in branches
+            assert partial_branches["next_branch"] is None
+
+    def test_snapshot_get_branches_name_pattern_special_chars(
+        self, swh_storage, sample_data
+    ):
+        # '%', '_' and '\' must be matched literally, not as wildcards
+        branch = SnapshotBranch(
+            target=sample_data.revision.id, target_type=TargetType.REVISION,
+        )
+        names = [
+            b"refs/heads/a_b",
+            b"refs/heads/axb",
+            b"refs/heads/a%b",
+            b"refs/heads/a\\b",
+        ]
+        snapshot = Snapshot(branches={name: branch for name in names})
+        swh_storage.snapshot_add([snapshot])
+
+        for include_pattern, exclude_prefix, expected_names in (
+            (b"a_b", None, {b"refs/heads/a_b"}),
+            (b"a%b", None, {b"refs/heads/a%b"}),
+            (b"a\\b", None, {b"refs/heads/a\\b"}),
+            (None, b"refs/heads/a_", set(names) - {b"refs/heads/a_b"}),
+        ):
+            branches = swh_storage.snapshot_get_branches(
+                snapshot.id,
+                branch_name_include_pattern=include_pattern,
+                branch_name_exclude_prefix=exclude_prefix,
+            )["branches"]
+            assert set(branches) == expected_names
+
+    def test_snapshot_get_branches_exclude_prefix_paginated(
+        self, swh_storage, sample_data
+    ):
+        branch = SnapshotBranch(
+            target=sample_data.revision.id, target_type=TargetType.REVISION,
+        )
+        names = [b"a", b"b/1", b"b/2", b"c", b"d", b"\xff\xff"]
+        snapshot = Snapshot(branches={name: branch for name in names})
+        swh_storage.snapshot_add([snapshot])
+
+        # resuming past the excluded range must not yield earlier branches
+        branches = swh_storage.snapshot_get_branches(
+            snapshot.id, branches_from=b"d", branch_name_exclude_prefix=b"b/",
+        )["branches"]
+        assert set(branches) == {b"d", b"\xff\xff"}
+
+        # a prefix made of 0xff bytes has no successor byte string
+        branches = swh_storage.snapshot_get_branches(
+            snapshot.id, branch_name_exclude_prefix=b"\xff",
+        )["branches"]
+        assert set(branches) == {b"a", b"b/1", b"b/2", b"c", b"d"}
+
     def test_snapshot_add_get(self, swh_storage, sample_data):
         snapshot = sample_data.snapshot
         origin = sample_data.origin