diff --git a/sql/upgrades/170.sql b/sql/upgrades/170.sql
new file mode 100644
--- /dev/null
+++ b/sql/upgrades/170.sql
@@ -0,0 +1,24 @@
+-- SWH DB schema upgrade
+-- from_version: 169
+-- to_version: 170
+-- description: Add a branch_name_exclude_prefix parameter to swh_snapshot_count_branches
+
+insert into dbversion(version, release, description)
+      values(170, now(), 'Work In Progress');
+
+-- "create or replace" with a different argument list creates a new overload
+-- instead of replacing the old function; keeping both would make one-argument
+-- calls ambiguous, so drop the old signature first.
+drop function if exists swh_snapshot_count_branches(sha1_git);
+
+create or replace function swh_snapshot_count_branches(id sha1_git,
+                                                       branch_name_exclude_prefix bytea default NULL)
+  returns setof snapshot_size
+  language sql
+  stable
+as $$
+  SELECT target_type, count(name)
+  from swh_snapshot_get_by_id(swh_snapshot_count_branches.id,
+    branch_name_exclude_prefix => swh_snapshot_count_branches.branch_name_exclude_prefix)
+  group by target_type;
+$$;
diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -3,6 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+from collections import Counter
 import dataclasses
 import datetime
 import functools
@@ -202,6 +203,24 @@
     return decorator
 
 
+def _next_prefix(prefix: bytes) -> bytes:
+    """Return the smallest byte string that sorts after every byte string
+    starting with ``prefix``.
+
+    Trailing 0xff bytes cannot be incremented, so they are dropped and the
+    last remaining byte is incremented instead. Leading 0x00 bytes are kept
+    (a round-trip through ``int`` would silently drop them).
+
+    Raises:
+        ValueError: if ``prefix`` is empty or only made of 0xff bytes, as no
+            such byte string exists.
+    """
+    stripped = prefix.rstrip(b"\xff")
+    if not stripped:
+        raise ValueError("No byte string sorts after prefix %r" % (prefix,))
+    return stripped[:-1] + bytes([stripped[-1] + 1])
+
+
 class CqlRunner:
     """Class managing prepared statements and building queries to be sent to
     Cassandra."""
@@ -616,16 +635,51 @@
     @_prepared_statement(
         "SELECT ascii_bins_count(target_type) AS counts "
         "FROM snapshot_branch "
-        "WHERE snapshot_id = ? "
+        "WHERE snapshot_id = ? AND name >= ?"
     )
+    def snapshot_count_branches_from_name(
+        self, snapshot_id: Sha1Git, from_: bytes, *, statement
+    ) -> Dict[Optional[str], int]:
+        row = self._execute_with_retries(statement, [snapshot_id, from_]).one()
+        (nb_none, counts) = row["counts"]
+        return {None: nb_none, **counts}
+
+    @_prepared_statement(
+        "SELECT ascii_bins_count(target_type) AS counts "
+        "FROM snapshot_branch "
+        "WHERE snapshot_id = ? AND name < ?"
+    )
+    def snapshot_count_branches_before_name(
+        self, snapshot_id: Sha1Git, before: bytes, *, statement,
+    ) -> Dict[Optional[str], int]:
+        row = self._execute_with_retries(statement, [snapshot_id, before]).one()
+        (nb_none, counts) = row["counts"]
+        return {None: nb_none, **counts}
+
     def snapshot_count_branches(
-        self, snapshot_id: Sha1Git, *, statement
+        self, snapshot_id: Sha1Git, branch_name_exclude_prefix: Optional[bytes] = None,
     ) -> Dict[Optional[str], int]:
         """Returns a dictionary from type names to the number of branches
         of that type."""
-        row = self._execute_with_retries(statement, [snapshot_id]).one()
-        (nb_none, counts) = row["counts"]
-        return {None: nb_none, **counts}
+        prefix = branch_name_exclude_prefix
+        if prefix is None:
+            return self.snapshot_count_branches_from_name(snapshot_id, b"")
+        else:
+            # Names starting with the prefix form one contiguous range in byte
+            # order: count the names sorting before it, then the names sorting
+            # after every name that starts with it.
+            counts = Counter(
+                self.snapshot_count_branches_before_name(snapshot_id, prefix)
+            )
+
+            # An empty or all-0xff prefix has no upper bound to resume from.
+            if prefix.replace(b"\xff", b"") != b"":
+                counts.update(
+                    self.snapshot_count_branches_from_name(
+                        snapshot_id, _next_prefix(prefix)
+                    )
+                )
+            return counts
 
     @_prepared_select_statement(
         SnapshotBranchRow, "WHERE snapshot_id = ? AND name >= ? LIMIT ?"
@@ -671,13 +711,9 @@
         )
         nb_branches = len(branches)
         if nb_branches < limit and prefix.replace(b"\xff", b"") != b"":
-            next_prefix_int = int.from_bytes(prefix, byteorder="big") + 1
-            next_prefix = next_prefix_int.to_bytes(
-                (next_prefix_int.bit_length() + 7) // 8, byteorder="big"
-            )
             branches += list(
                 self.snapshot_branch_get_from_name(
-                    snapshot_id, next_prefix, limit - nb_branches
+                    snapshot_id, _next_prefix(prefix), limit - nb_branches
                 )
             )
         return branches
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -667,13 +667,16 @@
         }
 
     def snapshot_count_branches(
-        self, snapshot_id: Sha1Git
+        self, snapshot_id: Sha1Git, branch_name_exclude_prefix: Optional[bytes] = None,
     ) -> Optional[Dict[Optional[str], int]]:
         if self._cql_runner.snapshot_missing([snapshot_id]):
             # Makes sure we don't fetch branches for a snapshot that is
             # being added.
             return None
-        return self._cql_runner.snapshot_count_branches(snapshot_id)
+
+        return self._cql_runner.snapshot_count_branches(
+            snapshot_id, branch_name_exclude_prefix
+        )
 
     def snapshot_get_branches(
         self,
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -389,11 +389,17 @@
     def snapshot_branch_add_one(self, branch: SnapshotBranchRow) -> None:
         self._snapshot_branches.insert(branch)
 
-    def snapshot_count_branches(self, snapshot_id: Sha1Git) -> Dict[Optional[str], int]:
+    def snapshot_count_branches(
+        self, snapshot_id: Sha1Git, branch_name_exclude_prefix: Optional[bytes] = None,
+    ) -> Dict[Optional[str], int]:
         """Returns a dictionary from type names to the number of branches
         of that type."""
         counts: Dict[Optional[str], int] = defaultdict(int)
         for branch in self._snapshot_branches.get_from_partition_key((snapshot_id,)):
+            if branch_name_exclude_prefix and branch.name.startswith(
+                branch_name_exclude_prefix
+            ):
+                continue
             if branch.target_type is None:
                 target_type = None
             else:
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -727,12 +727,14 @@
 
     @remote_api_endpoint("snapshot/count_branches")
     def snapshot_count_branches(
-        self, snapshot_id: Sha1Git
+        self, snapshot_id: Sha1Git, branch_name_exclude_prefix: Optional[bytes] = None,
     ) -> Optional[Dict[Optional[str], int]]:
         """Count the number of branches in the snapshot with the given id
 
         Args:
             snapshot_id: snapshot identifier
+            branch_name_exclude_prefix: if provided, do not count branches whose name
+                starts with given prefix
 
         Returns:
             A dict whose keys are the target types of branches and values their
diff --git a/swh/storage/postgresql/db.py b/swh/storage/postgresql/db.py
--- a/swh/storage/postgresql/db.py
+++ b/swh/storage/postgresql/db.py
@@ -29,7 +29,7 @@
 
     """
 
-    current_version = 169
+    current_version = 170
 
     def mktemp_dir_entry(self, entry_type, cur=None):
         self._cursor(cur).execute(
@@ -239,15 +239,17 @@
 
     snapshot_count_cols = ["target_type", "count"]
 
-    def snapshot_count_branches(self, snapshot_id, cur=None):
+    def snapshot_count_branches(
+        self, snapshot_id, branch_name_exclude_prefix=None, cur=None,
+    ):
         cur = self._cursor(cur)
         query = """\
-           SELECT %s FROM swh_snapshot_count_branches(%%s)
+           SELECT %s FROM swh_snapshot_count_branches(%%s, %%s)
         """ % ", ".join(
             self.snapshot_count_cols
         )
 
-        cur.execute(query, (snapshot_id,))
+        cur.execute(query, (snapshot_id, branch_name_exclude_prefix))
 
         yield from cur
 
diff --git a/swh/storage/postgresql/storage.py b/swh/storage/postgresql/storage.py
--- a/swh/storage/postgresql/storage.py
+++ b/swh/storage/postgresql/storage.py
@@ -804,9 +804,15 @@
     @timed
     @db_transaction(statement_timeout=2000)
     def snapshot_count_branches(
-        self, snapshot_id: Sha1Git, db=None, cur=None
+        self,
+        snapshot_id: Sha1Git,
+        branch_name_exclude_prefix: Optional[bytes] = None,
+        db=None,
+        cur=None,
     ) -> Optional[Dict[Optional[str], int]]:
-        return dict([bc for bc in db.snapshot_count_branches(snapshot_id, cur)])
+        return dict(
+            db.snapshot_count_branches(snapshot_id, branch_name_exclude_prefix, cur)
+        )
 
     @timed
     @db_transaction(statement_timeout=2000)
diff --git a/swh/storage/sql/30-schema.sql b/swh/storage/sql/30-schema.sql
--- a/swh/storage/sql/30-schema.sql
+++ b/swh/storage/sql/30-schema.sql
@@ -17,7 +17,7 @@
 
 -- latest schema version
 insert into dbversion(version, release, description)
-      values(169, now(), 'Work In Progress');
+      values(170, now(), 'Work In Progress');
 
 -- a SHA1 checksum
 create domain sha1 as bytea check (length(value) = 20);
diff --git a/swh/storage/sql/40-funcs.sql b/swh/storage/sql/40-funcs.sql
--- a/swh/storage/sql/40-funcs.sql
+++ b/swh/storage/sql/40-funcs.sql
@@ -709,13 +709,15 @@
     count bigint
 );
 
-create or replace function swh_snapshot_count_branches(id sha1_git)
+create or replace function swh_snapshot_count_branches(id sha1_git,
+                                                       branch_name_exclude_prefix bytea default NULL)
     returns setof snapshot_size
     language sql
     stable
 as $$
     SELECT target_type, count(name)
-    from swh_snapshot_get_by_id(swh_snapshot_count_branches.id)
+    from swh_snapshot_get_by_id(swh_snapshot_count_branches.id,
+      branch_name_exclude_prefix => swh_snapshot_count_branches.branch_name_exclude_prefix)
     group by target_type;
 $$;
 
diff --git a/swh/storage/tests/storage_tests.py b/swh/storage/tests/storage_tests.py
--- a/swh/storage/tests/storage_tests.py
+++ b/swh/storage/tests/storage_tests.py
@@ -2954,6 +2954,57 @@
         }
         assert snp_size == expected_snp_size
 
+    def test_snapshot_add_count_branches_with_filtering(self, swh_storage, sample_data):
+        complete_snapshot = sample_data.snapshots[2]
+
+        actual_result = swh_storage.snapshot_add([complete_snapshot])
+        assert actual_result == {"snapshot:add": 1}
+
+        snp_size = swh_storage.snapshot_count_branches(
+            complete_snapshot.id, branch_name_exclude_prefix=b"release"
+        )
+
+        expected_snp_size = {
+            "alias": 1,
+            "content": 1,
+            "directory": 2,
+            "revision": 1,
+            "snapshot": 1,
+            None: 1,
+        }
+        assert snp_size == expected_snp_size
+
+    def test_snapshot_add_count_branches_with_filtering_edge_cases(
+        self, swh_storage, sample_data
+    ):
+        snapshot = Snapshot(
+            branches={
+                b"\xaa\xff": SnapshotBranch(
+                    target=sample_data.revision.id, target_type=TargetType.REVISION,
+                ),
+                b"\xaa\xff\x00": SnapshotBranch(
+                    target=sample_data.revision.id, target_type=TargetType.REVISION,
+                ),
+                b"\xff\xff": SnapshotBranch(
+                    target=sample_data.release.id, target_type=TargetType.RELEASE,
+                ),
+                b"\xff\xff\x00": SnapshotBranch(
+                    target=sample_data.release.id, target_type=TargetType.RELEASE,
+                ),
+                b"dangling": None,
+            },
+        )
+
+        swh_storage.snapshot_add([snapshot])
+
+        assert swh_storage.snapshot_count_branches(
+            snapshot.id, branch_name_exclude_prefix=b"\xaa\xff"
+        ) == {None: 1, "release": 2}
+
+        assert swh_storage.snapshot_count_branches(
+            snapshot.id, branch_name_exclude_prefix=b"\xff\xff"
+        ) == {None: 1, "revision": 2}
+
     def test_snapshot_add_get_paginated(self, swh_storage, sample_data):
         complete_snapshot = sample_data.snapshots[2]
 