diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -445,7 +445,9 @@
     def _get_parent_revs(self, rev_id, seen, limit):
         if limit and len(seen) >= limit:
             return
-        if rev_id in seen:
+        # Stop on ids already visited or not present in storage, instead
+        # of raising KeyError on the self._revisions lookup below.
+        if rev_id in seen or rev_id not in self._revisions:
             return
         seen.add(rev_id)
         yield self._revisions[rev_id]
@@ -812,7 +814,9 @@
             raise ValueError('Origin must have either id or (type and url).')
         origin = None
         # self._origin_id can return None
-        if origin_id is not None:
+        # Ids are 1-based indices into self._origins; reject unknown and
+        # non-positive ids so they cannot wrap around via negative indexing.
+        if origin_id is not None and 0 < origin_id <= len(self._origins):
             origin = copy.deepcopy(self._origins[origin_id-1])
             origin['id'] = origin_id
         return origin
@@ -1017,14 +1021,17 @@
             List of visits.
 
         """
-        visits = self._origin_visits[origin-1]
-        if last_visit is not None:
-            visits = visits[last_visit:]
-        if limit is not None:
-            visits = visits[:limit]
-        for visit in visits:
-            visit_id = visit['visit']
-            yield copy.deepcopy(self._origin_visits[origin-1][visit_id-1])
+        # Origin ids are 1-based indices into self._origin_visits; unknown
+        # or non-positive ids yield no visits instead of wrapping around.
+        if 0 < origin <= len(self._origin_visits):
+            visits = self._origin_visits[origin-1]
+            if last_visit is not None:
+                visits = visits[last_visit:]
+            if limit is not None:
+                visits = visits[:limit]
+            for visit in visits:
+                visit_id = visit['visit']
+                yield copy.deepcopy(self._origin_visits[origin-1][visit_id-1])
 
     def origin_visit_get_by(self, origin, visit):
         """Retrieve origin visit's information.
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -848,6 +848,10 @@ self.assertEqual(len(actual_results), 1) self.assertEqual(actual_results[0], self.revision4) + def test_revision_log_unknown_revision(self): + rev_log = list(self.storage.revision_log([self.revision['id']])) + self.assertEqual(rev_log, []) + @staticmethod def _short_revision(revision): return [revision['id'], revision['parents']] @@ -2139,6 +2143,16 @@ keys_to_check) + def test_origin_get_invalid_id(self): + + invalid_origin_id = 1 + + origin_info = self.storage.origin_get({'id': invalid_origin_id}) + self.assertIsNone(origin_info) + + origin_visits = list(self.storage.origin_visit_get(invalid_origin_id)) + self.assertEqual(origin_visits, []) + @given(gen_origins(min_size=100, max_size=100)) def test_origin_get_range(self, new_origins):