diff --git a/swh/storage/api/client.py b/swh/storage/api/client.py --- a/swh/storage/api/client.py +++ b/swh/storage/api/client.py @@ -108,7 +108,7 @@ }) def snapshot_get_branches(self, snapshot_id, branches_from=b'', - branches_count=None, target_types=None): + branches_count=1000, target_types=None): return self.post('snapshot/get_branches', { 'snapshot_id': snapshot_id, 'branches_from': branches_from, diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -20,13 +20,16 @@ from .exc import StorageDBError from .algos import diff -from swh.model.hashutil import ALGORITHMS +from swh.model.hashutil import ALGORITHMS, hash_to_bytes from swh.objstorage import get_objstorage from swh.objstorage.exc import ObjNotFoundError # Max block size of contents to return BULK_BLOCK_CONTENT_LEN_MAX = 10000 +EMPTY_SNAPSHOT_ID = hash_to_bytes('1a8893e6a86f444e8be8e7bda6cb34fb1735a00e') +"""Identifier for the empty snapshot""" + class Storage(): """SWH storage proxy, encompassing DB and object storage @@ -745,39 +748,8 @@ or :const:`None` if the snapshot has less than 1000 branches.
""" - max_branches = 1000 - branches = {} - next_branch = None - fetched_branches = list(db.snapshot_get_by_id( - snapshot_id, branches_count=max_branches+1, cur=cur)) - for branch in fetched_branches[:max_branches]: - branch = dict(zip(db.snapshot_get_cols, branch)) - del branch['snapshot_id'] - name = branch.pop('name') - if branch == {'target': None, 'target_type': None}: - branch = None - branches[name] = branch - - if len(fetched_branches) > max_branches: - branch = dict(zip(db.snapshot_get_cols, fetched_branches[-1])) - next_branch = branch['name'] - if branches: - return { - 'id': snapshot_id, - 'branches': branches, - 'next_branch': next_branch - } - - if db.snapshot_exists(snapshot_id, cur): - # empty snapshot - return { - 'id': snapshot_id, - 'branches': {}, - 'next_branch': None - } - - return None + return self.snapshot_get_branches(snapshot_id, db=db, cur=cur) @db_transaction(statement_timeout=2000) def snapshot_get_by_origin_visit(self, origin, visit, db=None, cur=None): @@ -863,7 +835,7 @@ @db_transaction(statement_timeout=2000) def snapshot_get_branches(self, snapshot_id, branches_from=b'', - branches_count=None, target_types=None, + branches_count=1000, target_types=None, db=None, cur=None): """Get the content, possibly partial, of a snapshot with the given id @@ -881,14 +853,30 @@ contained in that list are `'content', 'directory', 'revision', 'release', 'snapshot', 'alias'`) Returns: - dict: a dict with two keys: + dict: a dict with three keys: * **id**: identifier of the snapshot * **branches**: a dict of branches contained in the snapshot whose keys are the branches' names. + * **next_branch**: the name of the first branch not returned + or :const:`None` if the snapshot has at most + `branches_count` branches from `branches_from` (inclusive).
""" + if snapshot_id == EMPTY_SNAPSHOT_ID: + return { + 'id': snapshot_id, + 'branches': {}, + 'next_branch': None, + } + branches = {} - for branch in db.snapshot_get_by_id(snapshot_id, branches_from, - branches_count, target_types, cur): + next_branch = None + + fetched_branches = list(db.snapshot_get_by_id( + snapshot_id, branches_from=branches_from, + branches_count=branches_count+1, target_types=target_types, + cur=cur, + )) + for branch in fetched_branches[:branches_count]: branch = dict(zip(db.snapshot_get_cols, branch)) del branch['snapshot_id'] name = branch.pop('name') @@ -896,11 +884,16 @@ branch = None branches[name] = branch - if branches: - return {'id': snapshot_id, 'branches': branches} + if len(fetched_branches) > branches_count: + branch = dict(zip(db.snapshot_get_cols, fetched_branches[-1])) + next_branch = branch['name'] - if db.snapshot_exists(snapshot_id, cur): - return {'id': snapshot_id, 'branches': {}} + if branches: + return { + 'id': snapshot_id, + 'branches': branches, + 'next_branch': next_branch, + } return None diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -1228,6 +1228,7 @@ name: branches[name] for name in branch_names[rel_idx:] }, + 'next_branch': None, } self.assertEqual(snapshot, expected_snapshot) @@ -1240,6 +1241,7 @@ 'branches': { branch_names[0]: branches[branch_names[0]], }, + 'next_branch': b'content', } self.assertEqual(snapshot, expected_snapshot) @@ -1253,6 +1255,7 @@ name: branches[name] for name in branch_names[dir_idx:dir_idx + 3] }, + 'next_branch': branch_names[dir_idx + 3], } self.assertEqual(snapshot, expected_snapshot) @@ -1278,6 +1281,7 @@ for name, tgt in branches.items() if tgt and tgt['target_type'] in ['release', 'revision'] }, + 'next_branch': None, } self.assertEqual(snapshot, expected_snapshot) @@ -1292,6 +1296,7 @@ for name, tgt in branches.items() if tgt and tgt['target_type'] == 
'alias' }, + 'next_branch': None, } self.assertEqual(snapshot, expected_snapshot)