# Patch summary: adds an optional ``branch_name_exclude_prefix`` argument to
# ``snapshot_count_branches`` so that branches whose name starts with the given
# prefix are not counted. Per file, as visible in the hunks below:
#   - cassandra/cql.py: the prepared statement gains a ``name >= ? AND name < ?``
#     range. A first query counts names in [b"", b"\xff"); when a prefix is
#     given, a second query counts names in [prefix, next_prefix) and those
#     counts are subtracted from the first result (entries reaching 0 are removed).
#   - cassandra/storage.py, postgresql/db.py, postgresql/storage.py: forward the
#     new argument to the lower layer.
#   - in_memory.py: skips branches whose name starts with the encoded prefix.
#   - interface.py: documents the new argument on the remote API endpoint.
#   - sql/40-funcs.sql: ``swh_snapshot_count_branches`` takes the prefix and
#     forwards it to ``swh_snapshot_get_by_id``.
#   - tests/storage_tests.py: new test counting branches with prefix "release"
#     excluded.
# NOTE(review): the Cassandra "global" count binds b"" and b"\xff" to
# ``name >= ? AND name < ?``, so branch names sorting at or after b"\xff" look
# like they are left out of the totals -- confirm whether such names can occur.
# NOTE(review): the cql.py signature gives ``branch_name_exclude_prefix`` no
# default while the in_memory.py one defaults to None -- confirm this is intended.
diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py --- a/swh/storage/cassandra/cql.py +++ b/swh/storage/cassandra/cql.py @@ -614,15 +614,40 @@ @_prepared_statement( "SELECT ascii_bins_count(target_type) AS counts " "FROM snapshot_branch " - "WHERE snapshot_id = ? " + "WHERE snapshot_id = ? AND name >= ? AND name < ?" ) def snapshot_count_branches( - self, snapshot_id: Sha1Git, *, statement + self, + snapshot_id: Sha1Git, + branch_name_exclude_prefix: Optional[str], + *, + statement, ) -> Dict[Optional[str], int]: """Returns a dictionary from type names to the number of branches of that type.""" - row = self._execute_with_retries(statement, [snapshot_id]).one() + prefix = b"" + next_prefix = b"\xff" + row = self._execute_with_retries( + statement, [snapshot_id, prefix, next_prefix], + ).one() (nb_none, counts) = row["counts"] + + # count branch names starting with prefix and subtract values + # from global counters + if branch_name_exclude_prefix: + prefix = branch_name_exclude_prefix.encode() + next_prefix = prefix[:-1] + bytes([prefix[-1] + 1]) + row = self._execute_with_retries( + statement, [snapshot_id, prefix, next_prefix], + ).one() + (nb_none_prefix, counts_prefix) = row["counts"] + + for target_type, count in list(counts_prefix.items()): + counts[target_type] -= count + if counts[target_type] == 0: + del counts[target_type] + nb_none -= nb_none_prefix + return {None: nb_none, **counts} @_prepared_select_statement( diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -664,13 +664,16 @@ } def snapshot_count_branches( - self, snapshot_id: Sha1Git + self, snapshot_id: Sha1Git, branch_name_exclude_prefix: Optional[str] = None, ) -> Optional[Dict[Optional[str], int]]: if self._cql_runner.snapshot_missing([snapshot_id]): # Makes sure we don't fetch branches for a snapshot that is # being added. 
return None - return self._cql_runner.snapshot_count_branches(snapshot_id) + + return self._cql_runner.snapshot_count_branches( + snapshot_id, branch_name_exclude_prefix + ) def snapshot_get_branches( self, diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -386,11 +386,17 @@ def snapshot_branch_add_one(self, branch: SnapshotBranchRow) -> None: self._snapshot_branches.insert(branch) - def snapshot_count_branches(self, snapshot_id: Sha1Git) -> Dict[Optional[str], int]: + def snapshot_count_branches( + self, snapshot_id: Sha1Git, branch_name_exclude_prefix: Optional[str] = None, + ) -> Dict[Optional[str], int]: """Returns a dictionary from type names to the number of branches of that type.""" counts: Dict[Optional[str], int] = defaultdict(int) for branch in self._snapshot_branches.get_from_partition_key((snapshot_id,)): + if branch_name_exclude_prefix and branch.name.startswith( + branch_name_exclude_prefix.encode() + ): + continue if branch.target_type is None: target_type = None else: diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -680,12 +680,14 @@ @remote_api_endpoint("snapshot/count_branches") def snapshot_count_branches( - self, snapshot_id: Sha1Git + self, snapshot_id: Sha1Git, branch_name_exclude_prefix: Optional[str] = None, ) -> Optional[Dict[Optional[str], int]]: """Count the number of branches in the snapshot with the given id Args: snapshot_id: snapshot identifier + branch_name_exclude_prefix: if provided, do not count branches whose name + starts with given prefix Returns: A dict whose keys are the target types of branches and values their diff --git a/swh/storage/postgresql/db.py b/swh/storage/postgresql/db.py --- a/swh/storage/postgresql/db.py +++ b/swh/storage/postgresql/db.py @@ -234,15 +234,17 @@ snapshot_count_cols = ["target_type", "count"] - def snapshot_count_branches(self, 
snapshot_id, cur=None): + def snapshot_count_branches( + self, snapshot_id, branch_name_exclude_prefix=None, cur=None, + ): cur = self._cursor(cur) query = """\ - SELECT %s FROM swh_snapshot_count_branches(%%s) + SELECT %s FROM swh_snapshot_count_branches(%%s, %%s) """ % ", ".join( self.snapshot_count_cols ) - cur.execute(query, (snapshot_id,)) + cur.execute(query, (snapshot_id, branch_name_exclude_prefix)) yield from cur diff --git a/swh/storage/postgresql/storage.py b/swh/storage/postgresql/storage.py --- a/swh/storage/postgresql/storage.py +++ b/swh/storage/postgresql/storage.py @@ -754,9 +754,20 @@ @timed @db_transaction(statement_timeout=2000) def snapshot_count_branches( - self, snapshot_id: Sha1Git, db=None, cur=None + self, + snapshot_id: Sha1Git, + branch_name_exclude_prefix: Optional[str] = None, + db=None, + cur=None, ) -> Optional[Dict[Optional[str], int]]: - return dict([bc for bc in db.snapshot_count_branches(snapshot_id, cur)]) + return dict( + [ + bc + for bc in db.snapshot_count_branches( + snapshot_id, branch_name_exclude_prefix, cur, + ) + ] + ) @timed @db_transaction(statement_timeout=2000) diff --git a/swh/storage/sql/40-funcs.sql b/swh/storage/sql/40-funcs.sql --- a/swh/storage/sql/40-funcs.sql +++ b/swh/storage/sql/40-funcs.sql @@ -692,13 +692,15 @@ count bigint ); -create or replace function swh_snapshot_count_branches(id sha1_git) +create or replace function swh_snapshot_count_branches(id sha1_git, + branch_name_exclude_prefix text default NULL) returns setof snapshot_size language sql stable as $$ SELECT target_type, count(name) - from swh_snapshot_get_by_id(swh_snapshot_count_branches.id) + from swh_snapshot_get_by_id(swh_snapshot_count_branches.id, + branch_name_exclude_prefix => swh_snapshot_count_branches.branch_name_exclude_prefix) group by target_type; $$; diff --git a/swh/storage/tests/storage_tests.py b/swh/storage/tests/storage_tests.py --- a/swh/storage/tests/storage_tests.py +++ b/swh/storage/tests/storage_tests.py @@ -2787,6 
+2787,26 @@ } assert snp_size == expected_snp_size + def test_snapshot_add_count_branches_with_filtering(self, swh_storage, sample_data): + complete_snapshot = sample_data.snapshots[2] + + actual_result = swh_storage.snapshot_add([complete_snapshot]) + assert actual_result == {"snapshot:add": 1} + + snp_size = swh_storage.snapshot_count_branches( + complete_snapshot.id, branch_name_exclude_prefix="release" + ) + + expected_snp_size = { + "alias": 1, + "content": 1, + "directory": 2, + "revision": 1, + "snapshot": 1, + None: 1, + } + assert snp_size == expected_snp_size + def test_snapshot_add_get_paginated(self, swh_storage, sample_data): complete_snapshot = sample_data.snapshots[2]