diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -1123,6 +1123,14 @@
             self._execute_with_retries(statement, [origin, visit]),
         )
 
+    @_prepared_statement("SELECT snapshot FROM origin_visit_status WHERE origin = ?")
+    def origin_snapshot_get_all(self, origin: str, *, statement) -> Iterable[Sha1Git]:
+        yield from {
+            d["snapshot"]
+            for d in self._execute_with_retries(statement, [origin])
+            if d["snapshot"] is not None
+        }
+
     ##########################
     # 'metadata_authority' table
     ##########################
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -1090,6 +1090,10 @@
             "The Cassandra backend does not implement origin_count"
         )
 
+    @timed
+    def origin_snapshot_get_all(self, origin_url: str) -> List[Sha1Git]:
+        return list(self._cql_runner.origin_snapshot_get_all(origin_url))
+
     @timed
     @process_metrics
     def origin_add(self, origins: List[Origin]) -> Dict[str, int]:
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -583,6 +583,18 @@
         statuses.sort(key=lambda s: s.date, reverse=True)
         return iter(statuses)
 
+    def origin_snapshot_get_all(self, origin: str) -> Iterator[Sha1Git]:
+        """Return an iterator over the unique, non-null snapshot ids found in
+        the visit statuses of ``origin`` (iteration order is unspecified).
+        """
+        return iter(
+            {
+                s.snapshot
+                for s in self._origin_visit_statuses.get_from_partition_key((origin,))
+                if s.snapshot is not None
+            }
+        )
+
     ##########################
     # 'metadata_authority' table
     ##########################
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -1144,6 +1144,19 @@
         """
         ...
 
+    @remote_api_endpoint("origin/snapshot/get")
+    def origin_snapshot_get_all(self, origin_url: str) -> List[Sha1Git]:
+        """Return all unique snapshot identifiers resulting from origin visits.
+
+        Args:
+            origin_url: origin URL
+
+        Returns:
+            list of distinct snapshot ids (sha1_git), in no particular order
+
+        """
+        ...
+
     @remote_api_endpoint("origin/add_multi")
     def origin_add(self, origins: List[Origin]) -> Dict[str, int]:
         """Add origins to the storage
diff --git a/swh/storage/postgresql/db.py b/swh/storage/postgresql/db.py
--- a/swh/storage/postgresql/db.py
+++ b/swh/storage/postgresql/db.py
@@ -1209,6 +1209,19 @@
     ]
     release_get_cols = release_add_cols
 
+    def origin_snapshot_get_all(self, origin_url: str, cur=None) -> Iterable[Sha1Git]:
+        """Yield distinct non-null snapshot ids of the visits of ``origin_url``."""
+        cur = self._cursor(cur)
+        query = """\
+        SELECT DISTINCT snapshot FROM origin_visit_status ovs
+        INNER JOIN origin o ON o.id = ovs.origin
+        WHERE o.url = %s and snapshot IS NOT NULL;
+        """
+        # Bind the URL as a query parameter instead of interpolating it into the
+        # SQL text: URLs come from callers and may contain quotes (SQL injection).
+        cur.execute(query, (origin_url,))
+        yield from (row[0] for row in cur)
+
     def release_get_from_list(self, releases, cur=None):
         cur = self._cursor(cur)
         query_keys = ", ".join(
diff --git a/swh/storage/postgresql/storage.py b/swh/storage/postgresql/storage.py
--- a/swh/storage/postgresql/storage.py
+++ b/swh/storage/postgresql/storage.py
@@ -1350,6 +1350,13 @@
     ) -> int:
         return db.origin_count(url_pattern, regexp, with_visit, cur)
 
+    @timed
+    @db_transaction()
+    def origin_snapshot_get_all(
+        self, origin_url: str, *, db: Db, cur=None
+    ) -> List[Sha1Git]:
+        return list(db.origin_snapshot_get_all(origin_url, cur))
+
     @timed
     @process_metrics
     @db_transaction()
diff --git a/swh/storage/tests/storage_tests.py b/swh/storage/tests/storage_tests.py
--- a/swh/storage/tests/storage_tests.py
+++ b/swh/storage/tests/storage_tests.py
@@ -2034,6 +2034,54 @@
         random_origin_visit = swh_storage.origin_visit_status_get_random(visit_type)
         assert random_origin_visit is None
 
+    def test_origin_snapshot_get_all(self, swh_storage, sample_data):
+        origin = sample_data.origins[0]
+        swh_storage.origin_add([origin])
+
+        # add some random visits within the selection range
+        visits = self._generate_random_visits()
+        visit_type = "git"
+
+        # set first visit to a null snapshot
+        visit = swh_storage.origin_visit_add(
+            [OriginVisit(origin=origin.url, date=visits[0], type=visit_type,)]
+        )[0]
+        swh_storage.origin_visit_status_add(
+            [
+                OriginVisitStatus(
+                    origin=origin.url,
+                    visit=visit.visit,
+                    date=now(),
+                    status="created",
+                    snapshot=None,
+                )
+            ]
+        )
+
+        # add visits to origin
+        snapshots = set()
+        for date_visit in visits[1:]:
+            visit = swh_storage.origin_visit_add(
+                [OriginVisit(origin=origin.url, date=date_visit, type=visit_type,)]
+            )[0]
+            # pick a random snapshot and keep track of it
+            snapshot = random.choice(sample_data.snapshots).id
+            snapshots.add(snapshot)
+            swh_storage.origin_visit_status_add(
+                [
+                    OriginVisitStatus(
+                        origin=origin.url,
+                        visit=visit.visit,
+                        date=now(),
+                        status="full",
+                        snapshot=snapshot,
+                    )
+                ]
+            )
+
+        # check expected snapshots are returned
+        assert set(swh_storage.origin_snapshot_get_all(origin.url)) == snapshots
+
     def test_origin_get_by_sha1(self, swh_storage, sample_data):
         origin = sample_data.origin
         assert swh_storage.origin_get([origin.url])[0] is None