diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -1269,31 +1269,32 @@ rows = self._cql_runner.origin_visit_get(origin, visit_from, extra_limit, order) visits: List[OriginVisit] = [converters.row_to_visit(row) for row in rows] - assert visits[0].visit is not None - assert visits[-1].visit is not None - visit_from = min(visits[0].visit, visits[-1].visit) - visit_to = max(visits[0].visit, visits[-1].visit) - - # Then, fetch all statuses associated to these visits - statuses_rows = self._cql_runner.origin_visit_status_get_all_range( - origin, visit_from, visit_to - ) - visit_statuses: Dict[int, List[OriginVisitStatus]] = defaultdict(list) - for status_row in statuses_rows: - if allowed_statuses and status_row.status not in allowed_statuses: - continue - if require_snapshot and status_row.snapshot is None: - continue - visit_status = converters.row_to_visit_status(status_row) - visit_statuses[visit_status.visit].append(visit_status) - - # Add pagination if there are more visits - assert len(visits) <= extra_limit - if len(visits) == extra_limit: - # excluding that visit from the result to respect the limit size - visits = visits[:limit] - # last visit id is the next page token - next_page_token = str(visits[-1].visit) + if visits: + assert visits[0].visit is not None + assert visits[-1].visit is not None + visit_from = min(visits[0].visit, visits[-1].visit) + visit_to = max(visits[0].visit, visits[-1].visit) + + # Then, fetch all statuses associated to these visits + statuses_rows = self._cql_runner.origin_visit_status_get_all_range( + origin, visit_from, visit_to + ) + visit_statuses: Dict[int, List[OriginVisitStatus]] = defaultdict(list) + for status_row in statuses_rows: + if allowed_statuses and status_row.status not in allowed_statuses: + continue + if require_snapshot and status_row.snapshot is None: + continue + visit_status = converters.row_to_visit_status(status_row) + visit_statuses[visit_status.visit].append(visit_status) + + # Add pagination if there are more visits + assert len(visits) <= extra_limit + if len(visits) == extra_limit: + # excluding that visit from the result to respect the limit size + visits = visits[:limit] + # last visit id is the next page token + next_page_token = str(visits[-1].visit) results = [ OriginVisitWithStatuses(visit=visit, statuses=visit_statuses[visit.visit]) diff --git a/swh/storage/postgresql/storage.py b/swh/storage/postgresql/storage.py --- a/swh/storage/postgresql/storage.py +++ b/swh/storage/postgresql/storage.py @@ -1204,12 +1204,10 @@ visit_statuses[row_d["visit"]].append(OriginVisitStatus(**row_d)) - results = [ - OriginVisitWithStatuses( - visit=visit, statuses=visit_statuses[visit.visit] - ) - for visit in visits - ] + results = [ + OriginVisitWithStatuses(visit=visit, statuses=visit_statuses[visit.visit]) + for visit in visits + ] return PagedResult(results=results, next_page_token=next_page_token) diff --git a/swh/storage/tests/storage_tests.py b/swh/storage/tests/storage_tests.py --- a/swh/storage/tests/storage_tests.py +++ b/swh/storage/tests/storage_tests.py @@ -2349,6 +2349,16 @@ assert actual_page.next_page_token is None assert actual_page.results == [ovws1] + # should return empty results if page_token is last visit + actual_page = swh_storage.origin_visit_get_with_statuses( + origin.url, + allowed_statuses=allowed_statuses, + require_snapshot=require_snapshot, + page_token=str(ov3.visit), + ) + assert actual_page.next_page_token is None + assert actual_page.results == [] + def test_origin_visit_status_get__unknown_cases(self, swh_storage, sample_data): origin = sample_data.origin actual_page = swh_storage.origin_visit_status_get("foobar", 1)