diff --git a/sql/upgrades/168.sql b/sql/upgrades/168.sql new file mode 100644 --- /dev/null +++ b/sql/upgrades/168.sql @@ -0,0 +1,43 @@ +-- SWH DB schema upgrade +-- from_version: 167 +-- to_version: 168 +-- description: Add branch name filtering to swh_snapshot_get_by_id + +insert into dbversion(version, release, description) + values(168, now(), 'Work In Progress'); + +-- The argument list of swh_snapshot_get_by_id changes: drop the old overload, +-- otherwise calls relying on default arguments would be ambiguous. +drop function if exists swh_snapshot_get_by_id(sha1_git, bytea, bigint, snapshot_target[]); + +create or replace function swh_snapshot_get_by_id(id sha1_git, + branches_from bytea default '', branches_count bigint default null, + target_types snapshot_target[] default NULL, + branch_name_include_pattern text default NULL, + branch_name_exclude_pattern text default NULL) + returns setof snapshot_result + language sql + stable +as $$ + -- with small limits, the "naive" version of this query can degenerate into + -- using the deduplication index on snapshot_branch (name, target, + -- target_type); The planner happily scans several hundred million rows. + + -- Do the query in two steps: first pull the relevant branches for the given + -- snapshot (filtering them by type), then do the limiting. This two-step + -- process guides the planner into using the proper index.
+ with filtered_snapshot_branches as ( + select swh_snapshot_get_by_id.id as snapshot_id, name, target, target_type + from snapshot_branches + inner join snapshot_branch on snapshot_branches.branch_id = snapshot_branch.object_id + where snapshot_id = (select object_id from snapshot where snapshot.id = swh_snapshot_get_by_id.id) + and (target_types is null or target_type = any(target_types)) + order by name + ) + select snapshot_id, name, target, target_type + from filtered_snapshot_branches + where name >= branches_from + and (branch_name_include_pattern is null or convert_from(name, 'utf-8') ilike branch_name_include_pattern) + and (branch_name_exclude_pattern is null or convert_from(name, 'utf-8') not ilike branch_name_exclude_pattern) + order by name limit branches_count; +$$; diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -678,6 +678,8 @@ branches_from: bytes = b"", branches_count: int = 1000, target_types: Optional[List[str]] = None, + branch_name_include_pattern: Optional[str] = None, + branch_name_exclude_pattern: Optional[str] = None, ) -> Optional[PartialBranches]: if self._cql_runner.snapshot_missing([snapshot_id]): # Makes sure we don't fetch branches for a snapshot that is @@ -707,6 +709,23 @@ if branch.target is not None and branch.target_type in target_types ] + # Filter by branch name: case-insensitive substring match, consistent + # with the ILIKE filter of the postgresql backend; as in that backend, + # empty patterns are ignored. + include_pattern = (branch_name_include_pattern or "").lower() + exclude_pattern = (branch_name_exclude_pattern or "").lower() + if include_pattern or exclude_pattern: + new_branches_filtered = [ + branch + for branch in new_branches_filtered + if branch.name is not None + and include_pattern in branch.name.decode("utf-8").lower() + and not ( + exclude_pattern + and exclude_pattern in branch.name.decode("utf-8").lower() + ) + ] + branches.extend(new_branches_filtered) if len(new_branches) < branches_count + 1: diff --git a/swh/storage/interface.py
b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -701,6 +701,8 @@ branches_from: bytes = b"", branches_count: int = 1000, target_types: Optional[List[str]] = None, + branch_name_include_pattern: Optional[str] = None, + branch_name_exclude_pattern: Optional[str] = None, ) -> Optional[PartialBranches]: """Get the content, possibly partial, of a snapshot with the given id @@ -717,6 +719,10 @@ target types of branch to return (possible values that can be contained in that list are `'content', 'directory', 'revision', 'release', 'snapshot', 'alias'`) + branch_name_include_pattern: if provided, only return branches whose name + contains the given substring + branch_name_exclude_pattern: if provided, do not return branches whose name + contains the given substring Returns: dict: None if the snapshot does not exist; diff --git a/swh/storage/postgresql/db.py b/swh/storage/postgresql/db.py --- a/swh/storage/postgresql/db.py +++ b/swh/storage/postgresql/db.py @@ -28,7 +28,7 @@ """ - current_version = 167 + current_version = 168 def mktemp_dir_entry(self, entry_type, cur=None): self._cursor(cur).execute( @@ -254,17 +254,38 @@ branches_from=b"", branches_count=None, target_types=None, + branch_name_include_pattern=None, + branch_name_exclude_pattern=None, cur=None, ): cur = self._cursor(cur) query = """\ - SELECT %s - FROM swh_snapshot_get_by_id(%%s, %%s, %%s, %%s :: snapshot_target[]) + SELECT %s + FROM swh_snapshot_get_by_id(%%s, %%s, %%s, %%s :: snapshot_target[], %%s, %%s) """ % ", ".join( self.snapshot_get_cols ) - cur.execute(query, (snapshot_id, branches_from, branches_count, target_types)) + def to_like_pattern(substring): + """Build an ILIKE pattern matching ``substring`` literally, or None + (which disables the filter in SQL) if ``substring`` is empty.""" + if not substring: + return None + for char in ("\\", "%", "_"): + substring = substring.replace(char, "\\" + char) + return f"%{substring}%" + + cur.execute( + query, + ( + snapshot_id, + branches_from, + branches_count, + target_types, + to_like_pattern(branch_name_include_pattern), + to_like_pattern(branch_name_exclude_pattern), + ), + ) yield from cur diff --git a/swh/storage/postgresql/storage.py
b/swh/storage/postgresql/storage.py --- a/swh/storage/postgresql/storage.py +++ b/swh/storage/postgresql/storage.py @@ -766,6 +766,8 @@ branches_from: bytes = b"", branches_count: int = 1000, target_types: Optional[List[str]] = None, + branch_name_include_pattern: Optional[str] = None, + branch_name_exclude_pattern: Optional[str] = None, db=None, cur=None, ) -> Optional[PartialBranches]: @@ -781,6 +783,8 @@ branches_from=branches_from, branches_count=branches_count + 1, target_types=target_types, + branch_name_include_pattern=branch_name_include_pattern, + branch_name_exclude_pattern=branch_name_exclude_pattern, cur=cur, ) ) diff --git a/swh/storage/sql/30-schema.sql b/swh/storage/sql/30-schema.sql --- a/swh/storage/sql/30-schema.sql +++ b/swh/storage/sql/30-schema.sql @@ -17,7 +17,7 @@ -- latest schema version insert into dbversion(version, release, description) - values(167, now(), 'Work In Progress'); + values(168, now(), 'Work In Progress'); -- a SHA1 checksum create domain sha1 as bytea check (length(value) = 20); diff --git a/swh/storage/sql/40-funcs.sql b/swh/storage/sql/40-funcs.sql --- a/swh/storage/sql/40-funcs.sql +++ b/swh/storage/sql/40-funcs.sql @@ -657,7 +657,9 @@ create or replace function swh_snapshot_get_by_id(id sha1_git, branches_from bytea default '', branches_count bigint default null, - target_types snapshot_target[] default NULL) + target_types snapshot_target[] default NULL, + branch_name_include_pattern text default NULL, + branch_name_exclude_pattern text default NULL) returns setof snapshot_result language sql stable @@ -680,6 +682,8 @@ select snapshot_id, name, target, target_type from filtered_snapshot_branches where name >= branches_from + and (branch_name_include_pattern is null or convert_from(name, 'utf-8') ilike branch_name_include_pattern) + and (branch_name_exclude_pattern is null or convert_from(name, 'utf-8') not ilike branch_name_exclude_pattern) order by name limit branches_count; $$; diff --git
a/swh/storage/tests/storage_tests.py b/swh/storage/tests/storage_tests.py --- a/swh/storage/tests/storage_tests.py +++ b/swh/storage/tests/storage_tests.py @@ -30,6 +30,7 @@ Revision, SkippedContent, Snapshot, + SnapshotBranch, TargetType, ) from swh.storage import get_storage @@ -2987,6 +2988,113 @@ assert len(branches) == 1 assert alias1 in branches + def test_snapshot_add_get_by_branches_name_pattern(self, swh_storage, sample_data): + snapshot = Snapshot( + branches={ + b"refs/heads/master": SnapshotBranch( + target=sample_data.revision.id, target_type=TargetType.REVISION, + ), + b"refs/heads/incoming": SnapshotBranch( + target=sample_data.revision.id, target_type=TargetType.REVISION, + ), + b"refs/pull/1": SnapshotBranch( + target=sample_data.revision.id, target_type=TargetType.REVISION, + ), + b"refs/pull/2": SnapshotBranch( + target=sample_data.revision.id, target_type=TargetType.REVISION, + ), + b"dangling": None, + }, + ) + swh_storage.snapshot_add([snapshot]) + + for include_pattern, exclude_pattern, nb_results in ( + ("pull", None, 2), + ("incoming", None, 1), + ("dangling", None, 1), + # does not contain heads + (None, "heads", 3), + # contains refs but not master + ("refs", "master", 3), + ): + branches = swh_storage.snapshot_get_branches( + snapshot.id, + branch_name_include_pattern=include_pattern, + branch_name_exclude_pattern=exclude_pattern, + )["branches"] + assert len(branches) == nb_results + for branch_name in branches: + branch_name_decoded = branch_name.decode("utf-8") + if include_pattern: + assert include_pattern in branch_name_decoded + if exclude_pattern: + assert exclude_pattern not in branch_name_decoded + + def test_snapshot_add_get_by_branches_name_pattern_filtered_paginated( + self, swh_storage, sample_data + ): + pattern = "foo" + nb_branches_by_target_type = 10 + branches = {} + for i in
range(nb_branches_by_target_type): + branches[f"branch/directory/bar{i}".encode()] = SnapshotBranch( + target=sample_data.directory.id, target_type=TargetType.DIRECTORY, + ) + branches[f"branch/revision/bar{i}".encode()] = SnapshotBranch( + target=sample_data.revision.id, target_type=TargetType.REVISION, + ) + branches[f"branch/directory/{pattern}{i}".encode()] = SnapshotBranch( + target=sample_data.directory.id, target_type=TargetType.DIRECTORY, + ) + branches[f"branch/revision/{pattern}{i}".encode()] = SnapshotBranch( + target=sample_data.revision.id, target_type=TargetType.REVISION, + ) + + snapshot = Snapshot(branches=branches) + swh_storage.snapshot_add([snapshot]) + + branches_count = nb_branches_by_target_type // 2 + + for target_type in ( + TargetType.DIRECTORY, + TargetType.REVISION, + ): + target_type_str = target_type.value + partial_branches = swh_storage.snapshot_get_branches( + snapshot.id, + branch_name_include_pattern=pattern, + target_types=[target_type_str], + branches_count=branches_count, + ) + branches = partial_branches["branches"] + + assert len(branches) == branches_count + for branch_name, branch_data in branches.items(): + assert pattern in branch_name.decode("utf-8") + assert branch_data.target_type == target_type + for i in range(branches_count): + assert f"branch/{target_type_str}/{pattern}{i}".encode() in branches + assert ( + partial_branches["next_branch"] + == f"branch/{target_type_str}/{pattern}{branches_count}".encode() + ) + + partial_branches = swh_storage.snapshot_get_branches( + snapshot.id, + branch_name_include_pattern=pattern, + target_types=[target_type_str], + branches_from=partial_branches["next_branch"], + ) + branches = partial_branches["branches"] + + assert len(branches) == branches_count + for branch_name, branch_data in branches.items(): + assert pattern in branch_name.decode("utf-8") + assert branch_data.target_type == target_type + for i in range(branches_count, 2 *
branches_count): + assert f"branch/{target_type_str}/{pattern}{i}".encode() in branches + assert partial_branches["next_branch"] is None + def test_snapshot_add_get(self, swh_storage, sample_data): snapshot = sample_data.snapshot origin = sample_data.origin