diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py --- a/swh/storage/cassandra/cql.py +++ b/swh/storage/cassandra/cql.py @@ -863,6 +863,20 @@ def origin_iter_all(self, *, statement) -> Iterable[OriginRow]: return map(OriginRow.from_dict, self._execute_with_retries(statement, [])) + @_prepared_statement( + f""" + UPDATE {OriginRow.TABLE} + SET next_visit_id=? + WHERE sha1 = ? IF next_visit_id<? + """ + ) + def origin_bump_next_visit_id( + self, origin_url: str, visit_id: int, *, statement + ) -> None: + origin_sha1 = hash_url(origin_url) + next_id = visit_id + 1 + self._execute_with_retries(statement, [next_id, origin_sha1, next_id]) + @_prepared_statement(f"SELECT next_visit_id FROM {OriginRow.TABLE} WHERE sha1 = ?") def _origin_get_next_visit_id(self, origin_sha1: bytes, *, statement) -> int: rows = list(self._execute_with_retries(statement, [origin_sha1])) diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -1014,7 +1014,11 @@ nb_visits = 0 for visit in visits: nb_visits += 1 - if not visit.visit: + if visit.visit: + # Set origin.next_visit_id = max(origin.next_visit_id, visit.visit+1) + # so the next loader run does not reuse the id. 
+ self._cql_runner.origin_bump_next_visit_id(visit.origin, visit.visit) + else: visit_id = self._cql_runner.origin_generate_unique_visit_id( visit.origin ) diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -480,6 +480,10 @@ for (clustering_key, row) in partition.items() ) + def origin_bump_next_visit_id(self, origin_url: str, visit_id: int) -> None: + origin = list(self.origin_get_by_url(origin_url))[0] + origin.next_visit_id = max(origin.next_visit_id, visit_id + 1) + def origin_generate_unique_visit_id(self, origin_url: str) -> int: origin = list(self.origin_get_by_url(origin_url))[0] visit_id = origin.next_visit_id diff --git a/swh/storage/tests/storage_tests.py b/swh/storage/tests/storage_tests.py --- a/swh/storage/tests/storage_tests.py +++ b/swh/storage/tests/storage_tests.py @@ -2149,6 +2149,9 @@ assert ov1 == origin_visit1 assert ov2 == origin_visit2 + assert ov1.visit == 1 + assert ov2.visit == 2 + ovs1 = OriginVisitStatus( origin=ov1.origin, visit=ov1.visit, @@ -2182,6 +2185,51 @@ for obj in expected_objects: assert obj in actual_objects + def test_origin_visit_add_replayed(self, swh_storage, sample_data): + """Tests adding a visit with an id makes sure the next id is higher""" + origin1 = sample_data.origins[1] + swh_storage.origin_add([origin1]) + + date_visit = now() + date_visit2 = date_visit + datetime.timedelta(minutes=1) + + date_visit = round_to_milliseconds(date_visit) + date_visit2 = round_to_milliseconds(date_visit2) + + visit1 = OriginVisit( + origin=origin1.url, date=date_visit, type=sample_data.type_visit1, visit=42 + ) + visit2 = OriginVisit( + origin=origin1.url, date=date_visit2, type=sample_data.type_visit2, + ) + + # add once + ov1, ov2 = swh_storage.origin_visit_add([visit1, visit2]) + # then again (will be ignored as they already exist) + origin_visit1, origin_visit2 = swh_storage.origin_visit_add([ov1, ov2]) + assert ov1 == origin_visit1 + assert ov2 == 
origin_visit2 + + assert ov1.visit == 42 + assert ov2.visit == 43 + + visit3 = OriginVisit( + origin=origin1.url, date=date_visit, type=sample_data.type_visit1, visit=12 + ) + visit4 = OriginVisit( + origin=origin1.url, date=date_visit2, type=sample_data.type_visit2, + ) + + # add once + ov3, ov4 = swh_storage.origin_visit_add([visit3, visit4]) + # then again (will be ignored as they already exist) + origin_visit3, origin_visit4 = swh_storage.origin_visit_add([ov3, ov4]) + assert ov3 == origin_visit3 + assert ov4 == origin_visit4 + + assert ov3.visit == 12 + assert ov4.visit == 44 + def test_origin_visit_add_validation(self, swh_storage, sample_data): """Unknown origin when adding visits should raise""" visit = attr.evolve(sample_data.origin_visit, origin="something-unknonw")