diff --git a/sql/upgrades/165.sql b/sql/upgrades/165.sql new file mode 100644 --- /dev/null +++ b/sql/upgrades/165.sql @@ -0,0 +1,35 @@ +-- SWH DB schema upgrade +-- from_version: 164 +-- to_version: 165 +-- description: add branches_name_pattern parameter to swh_snapshot_get_by_id + +insert into dbversion(version, release, description) + values(165, now(), 'Work In Progress'); + +-- Adding a parameter changes the function signature: "create or replace" +-- alone would keep the old 4-argument function as a separate overload, and +-- calls omitting branches_name_pattern would then fail as ambiguous. +drop function if exists swh_snapshot_get_by_id(sha1_git, bytea, bigint, snapshot_target[]); + +create or replace function swh_snapshot_get_by_id(id sha1_git, + branches_from bytea default '', branches_count bigint default null, + target_types snapshot_target[] default NULL, + branches_name_pattern text default NULL) + returns setof snapshot_result + language sql + stable +as $$ + with filtered_snapshot_branches as ( + select swh_snapshot_get_by_id.id as snapshot_id, name, target, target_type + from snapshot_branches + inner join snapshot_branch on snapshot_branches.branch_id = snapshot_branch.object_id + where snapshot_id = (select object_id from snapshot where snapshot.id = swh_snapshot_get_by_id.id) + and (target_types is null or target_type = any(target_types)) + order by name + ) + select snapshot_id, name, target, target_type + from filtered_snapshot_branches + where name >= branches_from + and (branches_name_pattern is null or convert_from(name, 'utf-8') ~* branches_name_pattern) + order by name limit branches_count; +$$; diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -677,6 +677,7 @@ branches_from: bytes = b"", branches_count: int = 1000, target_types: Optional[List[str]] = None, + branches_name_pattern: Optional[str] = None, ) -> Optional[PartialBranches]: if self._cql_runner.snapshot_missing([snapshot_id]): # Makes sure we don't fetch branches for a snapshot that is @@ -706,6 +707,17 @@ if branch.target is not None and branch.target_type in target_types ] + # Filter by branches_name_pattern, case-insensitively to stay + # consistent with the ~* operator used by the postgresql backend + if branches_name_pattern: + pattern = 
re.compile(branches_name_pattern, re.IGNORECASE) + new_branches_filtered = [ + branch + for branch in new_branches_filtered + if branch.name is not None + and pattern.search(branch.name.decode("utf-8")) + ] + branches.extend(new_branches_filtered) if len(new_branches) < branches_count + 1: diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -702,6 +702,7 @@ branches_from: bytes = b"", branches_count: int = 1000, target_types: Optional[List[str]] = None, + branches_name_pattern: Optional[str] = None, ) -> Optional[PartialBranches]: """Get the content, possibly partial, of a snapshot with the given id @@ -718,6 +719,10 @@ target types of branch to return (possible values that can be contained in that list are `'content', 'directory', 'revision', 'release', 'snapshot', 'alias'`) + branches_name_pattern: if provided, only return branches whose + names, decoded as UTF-8, contain a case-insensitive match for + this regular expression (the regex dialect is backend-specific: + PostgreSQL ARE or Python's re module) Returns: dict: None if the snapshot does not exist; diff --git a/swh/storage/postgresql/db.py b/swh/storage/postgresql/db.py --- a/swh/storage/postgresql/db.py +++ b/swh/storage/postgresql/db.py @@ -28,7 +28,7 @@ """ - current_version = 164 + current_version = 165 def mktemp_dir_entry(self, entry_type, cur=None): self._cursor(cur).execute( @@ -254,17 +254,27 @@ branches_from=b"", branches_count=None, target_types=None, + branches_name_pattern=None, cur=None, ): cur = self._cursor(cur) query = """\ SELECT %s - FROM swh_snapshot_get_by_id(%%s, %%s, %%s, %%s :: snapshot_target[]) + FROM swh_snapshot_get_by_id(%%s, %%s, %%s, %%s :: snapshot_target[], %%s) """ % ", ".join( self.snapshot_get_cols ) - cur.execute(query, (snapshot_id, branches_from, branches_count, target_types)) + cur.execute( + query, + ( + snapshot_id, + branches_from, + branches_count, + target_types, + branches_name_pattern, + ), + ) yield from cur diff --git 
a/swh/storage/postgresql/storage.py b/swh/storage/postgresql/storage.py --- a/swh/storage/postgresql/storage.py +++ 
b/swh/storage/postgresql/storage.py @@ -767,6 +767,7 @@ branches_from: bytes = b"", branches_count: int = 1000, target_types: Optional[List[str]] = None, + branches_name_pattern: Optional[str] = None, db=None, cur=None, ) -> Optional[PartialBranches]: @@ -782,6 +783,7 @@ branches_from=branches_from, branches_count=branches_count + 1, target_types=target_types, + branches_name_pattern=branches_name_pattern, cur=cur, ) ) diff --git a/swh/storage/sql/30-schema.sql b/swh/storage/sql/30-schema.sql --- a/swh/storage/sql/30-schema.sql +++ b/swh/storage/sql/30-schema.sql @@ -17,7 +17,7 @@ -- latest schema version insert into dbversion(version, release, description) - values(164, now(), 'Work In Progress'); + values(165, now(), 'Work In Progress'); -- a SHA1 checksum create domain sha1 as bytea check (length(value) = 20); diff --git a/swh/storage/sql/40-funcs.sql b/swh/storage/sql/40-funcs.sql --- a/swh/storage/sql/40-funcs.sql +++ b/swh/storage/sql/40-funcs.sql @@ -657,7 +657,8 @@ create or replace function swh_snapshot_get_by_id(id sha1_git, branches_from bytea default '', branches_count bigint default null, - target_types snapshot_target[] default NULL) + target_types snapshot_target[] default NULL, + branches_name_pattern text default NULL) returns setof snapshot_result language sql stable @@ -680,6 +681,7 @@ select snapshot_id, name, target, target_type from filtered_snapshot_branches where name >= branches_from + and (branches_name_pattern is null or convert_from(name, 'utf-8') ~* branches_name_pattern) order by name limit branches_count; $$; diff --git a/swh/storage/tests/storage_tests.py b/swh/storage/tests/storage_tests.py --- a/swh/storage/tests/storage_tests.py +++ b/swh/storage/tests/storage_tests.py @@ -10,6 +10,7 @@ import itertools import math import random +import re from typing import Any, ClassVar, Dict, Iterator, Optional import attr @@ -31,6 +32,7 @@ Revision, SkippedContent, Snapshot, + SnapshotBranch, TargetType, ) from swh.storage import get_storage
@@ -2915,6 +2917,109 @@ assert len(branches) == 1 assert alias1 in branches + def test_snapshot_add_get_by_branches_name_pattern(self, swh_storage, sample_data): + snapshot = Snapshot( + branches={ + b"refs/heads/master": SnapshotBranch( + target=sample_data.revision.id, target_type=TargetType.REVISION, + ), + b"refs/heads/incoming": SnapshotBranch( + target=sample_data.revision.id, target_type=TargetType.REVISION, + ), + b"refs/pull/1": SnapshotBranch( + target=sample_data.revision.id, target_type=TargetType.REVISION, + ), + b"refs/pull/2": SnapshotBranch( + target=sample_data.revision.id, target_type=TargetType.REVISION, + ), + b"dangling": None,  # dangling branch, still filtered by name + }, + ) + swh_storage.snapshot_add([snapshot]) + + for regexp, nb_results in ( + ("pull", 2),  # refs/pull/1 and refs/pull/2 + ("incoming$", 1),  # refs/heads/incoming + ("^dangling$", 1),  # the dangling branch + # does not contain heads + ("^((?!heads).)*$", 3), + # contains refs but not master + ("^(?:(?!(?:master)).)*(?:refs)(?:(?!(?:master)).)*$", 3), + ): + pattern = re.compile(regexp)  # for client-side checks of the results + branches = swh_storage.snapshot_get_branches( + snapshot.id, branches_name_pattern=regexp + )["branches"] + assert len(branches) == nb_results + for branch_name in branches:  # every returned name must match + assert pattern.search(branch_name.decode("utf-8")) + + def test_snapshot_add_get_by_branches_name_pattern_filtered_paginated( + self, swh_storage, sample_data + ): + pattern = "foo"  # matches only the ".../foo{i}" branch names + nb_branches_by_target_type = 10 + branches = {}  # name -> SnapshotBranch, used to build the snapshot + for i in range(nb_branches_by_target_type):  # only foo* names match + branches[f"branch/directory/bar{i}".encode()] = SnapshotBranch( + target=sample_data.directory.id, target_type=TargetType.DIRECTORY, + ) + branches[f"branch/revision/bar{i}".encode()] = SnapshotBranch( + target=sample_data.revision.id, target_type=TargetType.REVISION, + ) + branches[f"branch/directory/{pattern}{i}".encode()] = SnapshotBranch( + target=sample_data.directory.id, target_type=TargetType.DIRECTORY, + ) + branches[f"branch/revision/{pattern}{i}".encode()] = SnapshotBranch( + target=sample_data.revision.id, target_type=TargetType.REVISION, + ) +
snapshot = Snapshot(branches=branches) + swh_storage.snapshot_add([snapshot]) + + regexp = re.compile(pattern)  # to check returned names client-side + branches_count = nb_branches_by_target_type // 2  # two pages per target type + + for target_type in ( + TargetType.DIRECTORY, + TargetType.REVISION, + ): + target_type_str = target_type.value + partial_branches = swh_storage.snapshot_get_branches( + snapshot.id, + branches_name_pattern=pattern, + target_types=[target_type_str], + branches_count=branches_count,  # first page + ) + branches = partial_branches["branches"]  # this page only + + assert len(branches) == branches_count + for branch_name, branch_data in branches.items(): + assert regexp.search(branch_name.decode("utf-8")) + assert branch_data.target_type == target_type + for i in range(branches_count): + assert f"branch/{target_type_str}/{pattern}{i}".encode() in branches + assert (  # next page starts at the first unreturned match + partial_branches["next_branch"] + == f"branch/{target_type_str}/{pattern}{branches_count}".encode() + ) + + partial_branches = swh_storage.snapshot_get_branches( + snapshot.id, + branches_name_pattern=pattern, + target_types=[target_type_str], + branches_from=partial_branches["next_branch"],  # second page + ) + branches = partial_branches["branches"] + + assert len(branches) == branches_count + for branch_name, branch_data in branches.items(): + assert regexp.search(branch_name.decode("utf-8")) + assert branch_data.target_type == target_type + for i in range(branches_count, 2 * branches_count): + assert f"branch/{target_type_str}/{pattern}{i}".encode() in branches + assert partial_branches["next_branch"] is None  # no further pages + def test_snapshot_add_get(self, swh_storage, sample_data): snapshot = sample_data.snapshot origin = sample_data.origin