diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -1307,6 +1307,11 @@ List of visits. """ + if isinstance(origin, str): + origin = self.origin_get([{'url': origin}])[0] + if not origin: + return + origin = origin['id'] if origin <= len(self._origin_visits): visits = self._origin_visits[origin-1] if last_visit is not None: diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -1316,6 +1316,11 @@ List of visits. """ + if isinstance(origin, str): + origin = self.origin_get([{'url': origin}], db=db, cur=cur)[0] + if not origin: + return + origin = origin['id'] for line in db.origin_visit_get_all( origin, last_visit=last_visit, limit=limit, cur=cur): data = dict(zip(db.origin_visit_get_cols, line)) diff --git a/swh/storage/tests/storage_testing.py b/swh/storage/tests/storage_testing.py --- a/swh/storage/tests/storage_testing.py +++ b/swh/storage/tests/storage_testing.py @@ -56,6 +56,7 @@ self.storage = None super().tearDown() - def reset_storage_tables(self): + def reset_storage(self): excluded = {'dbversion', 'tool'} self.reset_db_tables(self.TEST_DB_NAME, excluded=excluded) + self.journal_writer.objects[:] = [] diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py --- a/swh/storage/tests/test_api_client.py +++ b/swh/storage/tests/test_api_client.py @@ -66,6 +66,11 @@ shutil.rmtree(self.storage_base) storage.get_journal_writer = get_journal_writer + def reset_storage(self): + excluded = {'dbversion', 'tool'} + self.reset_db_tables(self.TEST_DB_NAME, excluded=excluded) + self.journal_writer.objects[:] = [] + class RemoteMemStorageFixture(ServerTestFixture, unittest.TestCase): def setUp(self): @@ -79,7 +84,9 @@ } } } - self.__storage = InMemoryStorage(journal_writer={'cls': 'inmemory'}) + self.__storage = InMemoryStorage( + journal_writer={'cls': 'inmemory'}) + self._get_storage_patcher = unittest.mock.patch( 'swh.storage.api.server.get_storage', return_value=self.__storage) self._get_storage_patcher.start() @@ -92,6 +99,10 @@ super().tearDown() self._get_storage_patcher.stop() + def reset_storage(self): + self.storage.reset() + self.journal_writer.objects[:] = [] + @pytest.mark.network class TestRemoteMemStorage(CommonTestStorage, RemoteMemStorageFixture): diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py --- a/swh/storage/tests/test_in_memory.py +++ b/swh/storage/tests/test_in_memory.py @@ -21,8 +21,7 @@ """ def setUp(self): super().setUp() - self.storage = Storage(journal_writer={'cls': 'inmemory'}) - self.journal_writer = self.storage.journal_writer + self.reset_storage() @pytest.mark.skip('postgresql-specific test') def test_content_add_db(self): @@ -41,6 +40,10 @@ def test_skipped_content_add(self): pass + def reset_storage(self): + self.storage = Storage(journal_writer={'cls': 'inmemory'}) + self.journal_writer = self.storage.journal_writer + @pytest.mark.property_based class PropTestInMemoryStorage(CommonPropTestStorage, unittest.TestCase): @@ -54,5 +57,5 @@ super().setUp() self.storage = Storage() - def reset_storage_tables(self): + def reset_storage(self): self.storage = Storage() diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -36,7 +36,7 @@ self.maxDiff = None def tearDown(self): - self.reset_storage_tables() + self.reset_storage() super().tearDown() @@ -1556,60 +1556,26 @@ self.assertEqual(len(found_origins), 1) self.assertEqual(found_origins[0], origin2_data) - def test_origin_visit_add(self): - # given - self.assertIsNone(self.storage.origin_get([self.origin2])[0]) - - origin_id = self.storage.origin_add_one(self.origin2) - self.assertIsNotNone(origin_id) - - # when - origin_visit1 = self.storage.origin_visit_add( - origin_id, - type='git', - date=self.date_visit2) - - actual_origin_visits = list(self.storage.origin_visit_get(origin_id)) - self.assertEqual(actual_origin_visits, - [{ - 'origin': origin_id, - 'date': self.date_visit2, - 'visit': origin_visit1['visit'], - 'type': 'git', - 'status': 'ongoing', - 'metadata': None, - 'snapshot': None, - }]) - - expected_origin = self.origin2.copy() - data = { - 'origin': expected_origin, - 'date': self.date_visit2, - 'visit': origin_visit1['visit'], - 'type': 'git', - 'status': 'ongoing', - 'metadata': None, - 'snapshot': None, - } - self.assertEqual(list(self.journal_writer.objects), - [('origin', expected_origin), - ('origin_visit', data)]) + @given(strategies.booleans()) + def test_origin_visit_add(self, use_url): + self.reset_storage() - def test_origin_visit_add_from_url(self): # given self.assertIsNone(self.storage.origin_get([self.origin2])[0]) origin_id = self.storage.origin_add_one(self.origin2) - origin_url = self.origin2['url'] self.assertIsNotNone(origin_id) + origin_id_or_url = self.origin2['url'] if use_url else origin_id + # when origin_visit1 = self.storage.origin_visit_add( - origin_url, + origin_id_or_url, type='git', date=self.date_visit2) - actual_origin_visits = list(self.storage.origin_visit_get(origin_id)) + actual_origin_visits = list(self.storage.origin_visit_get( + origin_id_or_url)) self.assertEqual(actual_origin_visits, [{ 'origin': origin_id, @@ -2791,8 +2757,12 @@ bogus_visit_id) self.assertIsNone(by_ov) - def test_snapshot_get_latest(self): + @given(strategies.booleans()) + def test_snapshot_get_latest(self, use_url): + self.reset_storage() + origin_id = self.storage.origin_add_one(self.origin) + origin_id_or_url = self.origin['url'] if use_url else origin_id origin_visit1 = self.storage.origin_visit_add(origin_id, self.date_visit1) visit1_id = origin_visit1['visit'] @@ -2806,28 +2776,32 @@ visit3_id = origin_visit3['visit'] # Two visits, both with no snapshot: latest snapshot is None - self.assertIsNone(self.storage.snapshot_get_latest(origin_id)) + self.assertIsNone(self.storage.snapshot_get_latest( + origin_id_or_url)) # Add snapshot to visit1, latest snapshot = visit 1 snapshot self.storage.snapshot_add([self.complete_snapshot]) self.storage.origin_visit_update( origin_id, visit1_id, snapshot=self.complete_snapshot['id']) self.assertEqual(self.complete_snapshot, - self.storage.snapshot_get_latest(origin_id)) + self.storage.snapshot_get_latest( + origin_id_or_url)) # Status filter: all three visits are status=ongoing, so no snapshot # returned self.assertIsNone( - self.storage.snapshot_get_latest(origin_id, - allowed_statuses=['full']) + self.storage.snapshot_get_latest( + origin_id_or_url, + allowed_statuses=['full']) ) # Mark the first visit as completed and check status filter again self.storage.origin_visit_update(origin_id, visit1_id, status='full') self.assertEqual( self.complete_snapshot, - self.storage.snapshot_get_latest(origin_id, - allowed_statuses=['full']), + self.storage.snapshot_get_latest( + origin_id_or_url, + allowed_statuses=['full']), ) # Add snapshot to visit2 and check that the new snapshot is returned @@ -2840,8 +2814,9 @@ # Check that the status filter is still working self.assertEqual( self.complete_snapshot, - self.storage.snapshot_get_latest(origin_id, - allowed_statuses=['full']), + self.storage.snapshot_get_latest( + origin_id_or_url, + allowed_statuses=['full']), ) # Add snapshot to visit3 (same date as visit2) and check that @@ -2850,69 +2825,8 @@ self.storage.origin_visit_update( origin_id, visit3_id, snapshot=self.complete_snapshot['id']) self.assertEqual(self.complete_snapshot, - self.storage.snapshot_get_latest(origin_id)) - - def test_snapshot_get_latest_from_url(self): - self.storage.origin_add_one(self.origin) - origin_url = self.origin['url'] - origin_visit1 = self.storage.origin_visit_add(origin_url, - self.date_visit1) - visit1_id = origin_visit1['visit'] - origin_visit2 = self.storage.origin_visit_add(origin_url, - self.date_visit2) - visit2_id = origin_visit2['visit'] - - # Add a visit with the same date as the previous one - origin_visit3 = self.storage.origin_visit_add(origin_url, - self.date_visit2) - visit3_id = origin_visit3['visit'] - - # Two visits, both with no snapshot: latest snapshot is None - self.assertIsNone(self.storage.snapshot_get_latest(origin_url)) - - # Add snapshot to visit1, latest snapshot = visit 1 snapshot - self.storage.snapshot_add([self.complete_snapshot]) - self.storage.origin_visit_update( - origin_url, visit1_id, snapshot=self.complete_snapshot['id']) - self.assertEqual(self.complete_snapshot, - self.storage.snapshot_get_latest(origin_url)) - - # Status filter: both visits are status=ongoing, so no snapshot - # returned - self.assertIsNone( - self.storage.snapshot_get_latest(origin_url, - allowed_statuses=['full']) - ) - - # Mark the first visit as completed and check status filter again - self.storage.origin_visit_update(origin_url, visit1_id, status='full') - self.assertEqual( - self.complete_snapshot, - self.storage.snapshot_get_latest(origin_url, - allowed_statuses=['full']), - ) - - # Add snapshot to visit2 and check that the new snapshot is returned - self.storage.snapshot_add([self.empty_snapshot]) - self.storage.origin_visit_update( - origin_url, visit2_id, snapshot=self.empty_snapshot['id']) - self.assertEqual(self.empty_snapshot, - self.storage.snapshot_get_latest(origin_url)) - - # Check that the status filter is still working - self.assertEqual( - self.complete_snapshot, - self.storage.snapshot_get_latest(origin_url, - allowed_statuses=['full']), - ) - - # Add snapshot to visit3 (same date as visit2) and check that - # the new snapshot is returned - self.storage.snapshot_add([self.complete_snapshot]) - self.storage.origin_visit_update( - origin_url, visit3_id, snapshot=self.complete_snapshot['id']) - self.assertEqual(self.complete_snapshot, - self.storage.snapshot_get_latest(origin_url)) + self.storage.snapshot_get_latest( + origin_id_or_url)) def test_snapshot_get_latest__missing_snapshot(self): origin_id = self.storage.origin_add_one(self.origin) @@ -3696,7 +3610,7 @@ @given(gen_contents(min_size=1, max_size=4)) def test_generate_content_get(self, contents): - self.reset_storage_tables() + self.reset_storage() # add contents to storage self.storage.content_add(contents) @@ -3710,7 +3624,7 @@ @given(gen_contents(min_size=1, max_size=4)) def test_generate_content_get_metadata(self, contents): - self.reset_storage_tables() + self.reset_storage() # add contents to storage self.storage.content_add(contents) @@ -3765,7 +3679,7 @@ @given(gen_contents(min_size=1, max_size=4)) def test_generate_content_get_range_no_limit(self, contents): """content_get_range returns contents within range provided""" - self.reset_storage_tables() + self.reset_storage() # add contents to storage self.storage.content_add(contents) @@ -3790,7 +3704,7 @@ @given(gen_contents(min_size=4, max_size=4)) def test_generate_content_get_range_limit(self, contents): """content_get_range paginates results if limit exceeded""" - self.reset_storage_tables() + self.reset_storage() contents_map = {c['sha1']: c for c in contents} # add contents to storage @@ -3850,7 +3764,7 @@ @given(strategies.sets(origins().map(lambda x: tuple(x.to_dict().items())), min_size=6, max_size=15)) def test_origin_get_range(self, new_origins): - self.reset_storage_tables() + self.reset_storage() new_origins = list(map(dict, new_origins)) nb_origins = len(new_origins) @@ -3924,7 +3838,7 @@ @settings(suppress_health_check=[HealthCheck.too_slow]) @given(strategies.lists(objects(), max_size=2)) def test_add_arbitrary(self, objects): - self.reset_storage_tables() + self.reset_storage() for (obj_type, obj) in objects: obj = obj.to_dict() if obj_type == 'origin_visit': @@ -3949,6 +3863,8 @@ # datetimes for the remote server @given(strategies.booleans()) def test_fetch_history(self, use_url): + self.reset_storage() + origin = self.storage.origin_add_one(self.origin) if use_url: origin_id = self.origin['url']