diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -818,6 +818,7 @@
         limit: int = 50,
         regexp: bool = False,
         with_visit: bool = False,
+        visit_types: Optional[List[str]] = None,
     ) -> PagedResult[Origin]:
         # TODO: remove this endpoint, swh-search should be used instead.
         next_page_token = None
@@ -833,6 +834,23 @@
         if with_visit:
             origin_rows = [row for row in origin_rows if row.next_visit_id > 1]
 
+        if visit_types:
+
+            def _has_visit_types(origin, visit_types):
+                page_token = None
+                while True:
+                    page = self.origin_visit_get(origin, page_token=page_token)
+                    for origin_visit in page.results:
+                        if origin_visit.type in visit_types:
+                            return True
+                    page_token = page.next_page_token
+                    if page_token is None:
+                        return False
+
+            origin_rows = [
+                row for row in origin_rows if _has_visit_types(row.url, visit_types)
+            ]
+
         origins = [Origin(url=row.url) for row in origin_rows]
 
         origins = origins[offset : offset + limit + 1]
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -1003,6 +1003,7 @@
         limit: int = 50,
         regexp: bool = False,
         with_visit: bool = False,
+        visit_types: Optional[List[str]] = None,
     ) -> PagedResult[Origin]:
         """Search for origins whose urls contain a provided string pattern
         or match a provided regular expression.
@@ -1015,6 +1016,8 @@
             regexp: if True, consider the provided pattern as a regular expression
                 and return origins whose urls match it
             with_visit: if True, filter out origins with no visit
+            visit_types: Only origins having any of the provided visit types
+                (e.g. git, svn, pypi) will be returned
 
         Yields:
             PagedResult of Origin
diff --git a/swh/storage/postgresql/db.py b/swh/storage/postgresql/db.py
--- a/swh/storage/postgresql/db.py
+++ b/swh/storage/postgresql/db.py
@@ -982,6 +982,7 @@
         limit=50,
         regexp=False,
         with_visit=False,
+        visit_types=None,
         cur=None,
     ):
         """
@@ -1014,16 +1015,38 @@
             FROM filtered_origins AS o
             """
 
-        if with_visit:
-            query += """
-                WHERE EXISTS (
-                    SELECT 1
-                    FROM origin_visit ov
-                    INNER JOIN origin_visit_status ovs USING (origin, visit)
-                    INNER JOIN snapshot ON ovs.snapshot=snapshot.id
-                    WHERE ov.origin=o.id
-                    )
-                """
+        if with_visit or visit_types:
+            visit_predicat = (
+                """
+                INNER JOIN origin_visit_status ovs USING (origin, visit)
+                INNER JOIN snapshot ON ovs.snapshot=snapshot.id
+                """
+                if with_visit
+                else ""
+            )
+
+            type_predicat = ""
+            if visit_types:
+                # Bind the visit types as a query parameter instead of pasting
+                # the Python list repr into the SQL text: the driver then quotes
+                # and escapes each value, and a "%" in a value can no longer be
+                # mistaken for a placeholder.
+                # NOTE(review): assumes the positional parameters later given to
+                # cur.execute() are collected in a list named `query_params`,
+                # and that the OFFSET/LIMIT values are appended to it after this
+                # point - confirm against the part of the function not shown
+                # in this patch.
+                type_predicat = "AND ov.type = ANY(%s)"
+                query_params.append(list(visit_types))
+
+            query += f"""
+                WHERE EXISTS (
+                    SELECT 1
+                    FROM origin_visit ov
+                    {visit_predicat}
+                    WHERE ov.origin=o.id {type_predicat}
+                    )
+                """
 
         if not count:
             query += "OFFSET %s LIMIT %s"
@@ -1038,6 +1061,7 @@
         limit: int = 50,
         regexp: bool = False,
         with_visit: bool = False,
+        visit_types: Optional[List[str]] = None,
         cur=None,
     ):
         """Search for origins whose urls contain a provided string pattern
@@ -1060,6 +1084,7 @@
             limit=limit,
             regexp=regexp,
             with_visit=with_visit,
+            visit_types=visit_types,
             cur=cur,
         )
         yield from cur
diff --git a/swh/storage/postgresql/storage.py b/swh/storage/postgresql/storage.py
--- a/swh/storage/postgresql/storage.py
+++ b/swh/storage/postgresql/storage.py
@@ -1147,6 +1147,7 @@
         limit: int = 50,
         regexp: bool = False,
         with_visit: bool = False,
+        visit_types: Optional[List[str]] = None,
         db=None,
         cur=None,
     ) -> PagedResult[Origin]:
@@ -1156,7 +1157,7 @@
         origins = []
         # Take one more origin so we can reuse it as the next page token if any
         for origin in db.origin_search(
-            url_pattern, offset, limit + 1, regexp, with_visit, cur
+            url_pattern, offset, limit + 1, regexp, with_visit, visit_types, cur
         ):
             row_d = dict(zip(db.origin_cols, origin))
             origins.append(Origin(url=row_d["url"]))
diff --git a/swh/storage/tests/storage_tests.py b/swh/storage/tests/storage_tests.py
--- a/swh/storage/tests/storage_tests.py
+++ b/swh/storage/tests/storage_tests.py
@@ -1748,6 +1748,52 @@
         assert actual_page.next_page_token is None
         assert actual_page.results == [origin2]
 
+    def test_origin_search_no_visit_types(self, swh_storage, sample_data):
+        origin = sample_data.origins[0]
+        swh_storage.origin_add([origin])
+        actual_page = swh_storage.origin_search(origin.url, visit_types=["git"])
+        assert actual_page.next_page_token is None
+        assert actual_page.results == []
+
+    def test_origin_search_with_visit_types(self, swh_storage, sample_data):
+        origin, origin2 = sample_data.origins[:2]
+        swh_storage.origin_add([origin, origin2])
+        swh_storage.origin_visit_add(
+            [
+                OriginVisit(origin=origin.url, date=now(), type="git"),
+                OriginVisit(origin=origin2.url, date=now(), type="svn"),
+            ]
+        )
+        actual_page = swh_storage.origin_search(origin.url, visit_types=["git"])
+        assert actual_page.next_page_token is None
+        assert actual_page.results == [origin]
+
+        actual_page = swh_storage.origin_search(origin2.url, visit_types=["svn"])
+        assert actual_page.next_page_token is None
+        assert actual_page.results == [origin2]
+
+    def test_origin_search_multiple_visit_types(self, swh_storage, sample_data):
+        origin = sample_data.origins[0]
+        swh_storage.origin_add([origin])
+
+        def _add_visit_type(visit_type):
+            swh_storage.origin_visit_add(
+                [OriginVisit(origin=origin.url, date=now(), type=visit_type)]
+            )
+
+        def _check_visit_types(visit_types):
+            actual_page = swh_storage.origin_search(origin.url, visit_types=visit_types)
+            assert actual_page.next_page_token is None
+            assert actual_page.results == [origin]
+
+        _add_visit_type("git")
+        _check_visit_types(["git"])
+        _check_visit_types(["git", "hg"])
+
+        _add_visit_type("hg")
+        _check_visit_types(["hg"])
+        _check_visit_types(["git", "hg"])
+
     def test_origin_visit_add(self, swh_storage, sample_data):
         origin1 = sample_data.origins[1]
         swh_storage.origin_add([origin1])