diff --git a/swh/web/tests/api/views/test_origin.py b/swh/web/tests/api/views/test_origin.py
--- a/swh/web/tests/api/views/test_origin.py
+++ b/swh/web/tests/api/views/test_origin.py
@@ -3,10 +3,12 @@
 # License: GNU Affero General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-from hypothesis import given
-from rest_framework.test import APITestCase
 from unittest.mock import patch
 
+from hypothesis import given, strategies
+from requests.utils import parse_header_links
+from rest_framework.test import APITestCase
+
 from swh.storage.exc import StorageDBError, StorageAPIError
 
 from swh.web.common.utils import reverse
@@ -15,6 +17,7 @@
     origin, new_origin, visit_dates, new_snapshots
 )
 from swh.web.tests.testcase import WebTestCase
+from swh.web.tests.data import get_tests_data
 
 
 class OriginApiTestCase(WebTestCase, APITestCase):
@@ -295,6 +298,70 @@
             (origin['url'], max_visit_id+1)
         })
 
+    def _api_origins_get(self, url):
+        """Fetch one page of /origins/ and check the properties shared by
+        every successful response; return the response."""
+        rv = self.client.get(url)
+        self.assertEqual(rv.status_code, 200, rv.data)
+        self.assertEqual(rv['Content-Type'], 'application/json')
+        return rv
+
+    def _api_origins_scroll(self, url, max_pages):
+        """Follow the 'rel=next' links of the Link header starting from url
+        and return the concatenated results of all pages.
+
+        Fails the test, instead of looping forever, if pagination has not
+        ended after max_pages pages."""
+        results = []
+        for _ in range(max_pages):
+            rv = self._api_origins_get(url)
+            results.extend(rv.data)
+            next_urls = [link['url']
+                         for link in parse_header_links(rv.get('Link', ''))
+                         if link.get('rel') == 'next']
+            if not next_urls:
+                # No Link header, or no link with 'rel=next': last page
+                return results
+            url = next_urls[0]
+        self.fail('/origins/ pagination did not end after %d pages'
+                  % max_pages)
+
+    def test_api_origins(self):
+        origins = get_tests_data()['origins']
+        origin_urls = {origin['url'] for origin in origins}
+
+        # Get only one
+        url = reverse('api-1-origins',
+                      query_params={'origin_count': 1})
+        rv = self._api_origins_get(url)
+        self.assertEqual(len(rv.data), 1)
+        self.assertLess({origin['url'] for origin in rv.data}, origin_urls)
+
+        # Get all, then "all + 10": requesting more origins than exist
+        # must still return every origin and nothing else
+        for origin_count in (len(origins), len(origins) + 10):
+            url = reverse('api-1-origins',
+                          query_params={'origin_count': origin_count})
+            rv = self._api_origins_get(url)
+            self.assertEqual(len(rv.data), len(origins))
+            self.assertEqual({origin['url'] for origin in rv.data},
+                             origin_urls)
+
+    @given(strategies.integers(min_value=1))
+    def test_api_origins_scroll(self, origin_count):
+        origins = get_tests_data()['origins']
+        origin_urls = {origin['url'] for origin in origins}
+
+        url = reverse('api-1-origins',
+                      query_params={'origin_count': origin_count})
+
+        # Assuming every page but the last holds at least one origin,
+        # len(origins) + 1 pages always suffice to read all of them.
+        results = self._api_origins_scroll(url, len(origins) + 1)
+
+        self.assertEqual(len(results), len(origins))
+        self.assertEqual({origin['url'] for origin in results}, origin_urls)
+
     @given(origin())
     def test_api_origin_by_url(self, origin):