diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -17,7 +17,7 @@ import attr -from swh.model.model import Content, Directory, Revision, Release +from swh.model.model import Content, Directory, Revision, Release, Snapshot from swh.model.hashutil import DEFAULT_ALGORITHMS from swh.objstorage import get_objstorage from swh.objstorage.exc import ObjNotFoundError @@ -762,19 +762,17 @@ """ count = 0 + snapshots = (Snapshot.from_dict(d) for d in snapshots) + snapshots = (snap for snap in snapshots + if snap.id not in self._snapshots) for snapshot in snapshots: - snapshot_id = snapshot['id'] - if snapshot_id not in self._snapshots: - if self.journal_writer: - self.journal_writer.write_addition('snapshot', snapshot) - - self._snapshots[snapshot_id] = { - 'id': snapshot_id, - 'branches': copy.deepcopy(snapshot['branches']), - '_sorted_branch_names': sorted(snapshot['branches']) - } - self._objects[snapshot_id].append(('snapshot', snapshot_id)) - count += 1 + if self.journal_writer: + self.journal_writer.write_addition('snapshot', snapshot) + + sorted_branch_names = sorted(snapshot.branches) + self._snapshots[snapshot.id] = (snapshot, sorted_branch_names) + self._objects[snapshot.id].append(('snapshot', snapshot.id)) + count += 1 return {'snapshot:add': count} @@ -893,9 +891,9 @@ dict: A dict whose keys are the target types of branches and values their corresponding amount """ - branches = list(self._snapshots[snapshot_id]['branches'].values()) - return collections.Counter(branch['target_type'] if branch else None - for branch in branches) + (snapshot, _) = self._snapshots[snapshot_id] + return collections.Counter(branch.target_type.value if branch else None + for branch in snapshot.branches.values()) def snapshot_get_branches(self, snapshot_id, branches_from=b'', branches_count=1000, target_types=None): @@ -924,18 +922,18 @@ or :const:`None` if the snapshot has less than `branches_count` branches after `branches_from` included. """ - snapshot = self._snapshots.get(snapshot_id) - if snapshot is None: + res = self._snapshots.get(snapshot_id) + if res is None: return None - sorted_branch_names = snapshot['_sorted_branch_names'] + (snapshot, sorted_branch_names) = res from_index = bisect.bisect_left( sorted_branch_names, branches_from) if target_types: next_branch = None branches = {} for branch_name in sorted_branch_names[from_index:]: - branch = snapshot['branches'][branch_name] - if branch and branch['target_type'] in target_types: + branch = snapshot.branches[branch_name] + if branch and branch.target_type.value in target_types: if len(branches) < branches_count: branches[branch_name] = branch else: @@ -945,12 +943,16 @@ # As there is no 'target_types', we can do that much faster to_index = from_index + branches_count returned_branch_names = sorted_branch_names[from_index:to_index] - branches = {branch_name: snapshot['branches'][branch_name] + branches = {branch_name: snapshot.branches[branch_name] for branch_name in returned_branch_names} if to_index >= len(sorted_branch_names): next_branch = None else: next_branch = sorted_branch_names[to_index] + + branches = {name: branch.to_dict() if branch else None + for (name, branch) in branches.items()} + return { 'id': snapshot_id, 'branches': branches, diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -2453,6 +2453,19 @@ {**self.snapshot, 'next_branch': None}, self.storage.snapshot_get(self.snapshot['id'])) + def test_snapshot_add_validation(self): + snap = copy.deepcopy(self.snapshot) + snap['branches'][b'foo'] = {'target_type': 'revision'} + + with self.assertRaisesRegex(KeyError, 'target'): + self.storage.snapshot_add([snap]) + + snap = copy.deepcopy(self.snapshot) + snap['branches'][b'foo'] = {'target': b'\x42'*20} + + with self.assertRaisesRegex(KeyError, 'target_type'): + self.storage.snapshot_add([snap]) + def test_snapshot_add_count_branches(self): origin_id = self.storage.origin_add_one(self.origin) origin_visit1 = self.storage.origin_visit_add(origin_id,