diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -729,25 +729,47 @@
             # excluding that origin from the result to respect the limit size
             origins = origins[:limit]
 
-        assert (len(origins)) <= limit
+        assert len(origins) <= limit
 
         return PagedResult(results=origins, next_page_token=next_page_token)
 
     def origin_search(
-        self, url_pattern, offset=0, limit=50, regexp=False, with_visit=False
-    ):
+        self,
+        url_pattern: str,
+        page_token: Optional[str] = None,
+        limit: int = 50,
+        regexp: bool = False,
+        with_visit: bool = False,
+    ) -> PagedResult[Origin]:
         # TODO: remove this endpoint, swh-search should be used instead.
+        next_page_token = None
+        offset = int(page_token) if page_token else 0
+
         origins = self._cql_runner.origin_iter_all()
         if regexp:
             pat = re.compile(url_pattern)
             origins = [orig for orig in origins if pat.search(orig.url)]
         else:
             origins = [orig for orig in origins if url_pattern in orig.url]
 
         if with_visit:
+            # next_visit_id is only available on the raw rows returned by
+            # origin_iter_all(), so filter before converting them to Origin
             origins = [orig for orig in origins if orig.next_visit_id > 1]
 
-        return [{"url": orig.url,} for orig in origins[offset : offset + limit]]
+        # Fetch one extra row to detect whether there is a next page
+        rows = origins[offset : offset + limit + 1]
+        if len(rows) > limit:
+            # next offset
+            next_page_token = str(offset + limit)
+            # excluding that row from the result to respect the limit size
+            rows = rows[:limit]
+
+        assert len(rows) <= limit
+        return PagedResult(
+            results=[Origin(url=row.url) for row in rows],
+            next_page_token=next_page_token,
+        )
 
     def origin_add(self, origins: List[Origin]) -> Dict[str, int]:
         to_add = [ori for ori in origins if self.origin_get_one(ori.url) is None]
diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -1023,7 +1023,7 @@
 
         if not regexp:
             query = query % (origin_cols, "ILIKE")
-            query_params = ("%" + url_pattern + "%", offset, limit)
+            query_params = (f"%{url_pattern}%", offset, limit)
         else:
             query = query % (origin_cols, "~*")
             query_params = (url_pattern, offset, limit)
@@ -1034,20 +1034,26 @@
         cur.execute(query, query_params)
 
     def origin_search(
-        self, url_pattern, offset=0, limit=50, regexp=False, with_visit=False, cur=None
+        self,
+        url_pattern: str,
+        offset: int = 0,
+        limit: int = 50,
+        regexp: bool = False,
+        with_visit: bool = False,
+        cur=None,
     ):
         """Search for origins whose urls contain a provided string pattern
         or match a provided regular expression.
         The search is performed in a case insensitive way.
 
         Args:
-            url_pattern (str): the string pattern to search for in origin urls
-            offset (int): number of found origins to skip before returning
+            url_pattern: the string pattern to search for in origin urls
+            offset: number of found origins to skip before returning
                 results
-            limit (int): the maximum number of found origins to return
-            regexp (bool): if True, consider the provided pattern as a regular
+            limit: the maximum number of found origins to return
+            regexp: if True, consider the provided pattern as a regular
                 expression and returns origins whose urls match it
-            with_visit (bool): if True, filter out origins with no visit
+            with_visit: if True, filter out origins with no visit
 
         """
         self._origin_query(
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -711,20 +711,29 @@
         return PagedResult(results=origins, next_page_token=next_page_token)
 
     def origin_search(
-        self, url_pattern, offset=0, limit=50, regexp=False, with_visit=False
-    ):
-        origins = map(self._convert_origin, self._origins.values())
+        self,
+        url_pattern: str,
+        page_token: Optional[str] = None,
+        limit: int = 50,
+        regexp: bool = False,
+        with_visit: bool = False,
+    ) -> PagedResult[Origin]:
+        next_page_token = None
+        offset = int(page_token) if page_token else 0
+
+        origins = self._origins.values()
         if regexp:
             pat = re.compile(url_pattern)
-            origins = [orig for orig in origins if pat.search(orig["url"])]
+            origins = [orig for orig in origins if pat.search(orig.url)]
         else:
-            origins = [orig for orig in origins if url_pattern in orig["url"]]
+            origins = [orig for orig in origins if url_pattern in orig.url]
+
         if with_visit:
             filtered_origins = []
             for orig in origins:
                 visits = (
                     self._origin_visit_get_updated(ov.origin, ov.visit)
-                    for ov in self._origin_visits[orig["url"]]
+                    for ov in self._origin_visits[orig.url]
                 )
                 for ov in visits:
                     snapshot = ov["snapshot"]
@@ -734,17 +743,23 @@
         else:
             filtered_origins = origins
 
-        return filtered_origins[offset : offset + limit]
+        # Fetch one extra origin to detect whether there is a next page
+        origins = filtered_origins[offset : offset + limit + 1]
+        if len(origins) > limit:
+            # next offset
+            next_page_token = str(offset + limit)
+            # excluding that origin from the result to respect the limit size
+            origins = origins[:limit]
+
+        assert len(origins) <= limit
+        return PagedResult(results=origins, next_page_token=next_page_token)
 
     def origin_count(self, url_pattern, regexp=False, with_visit=False):
-        return len(
-            self.origin_search(
-                url_pattern,
-                regexp=regexp,
-                with_visit=with_visit,
-                limit=len(self._origins),
-            )
+        actual_page = self.origin_search(
+            url_pattern, regexp=regexp, with_visit=with_visit, limit=len(self._origins),
         )
+        assert actual_page.next_page_token is None
+        return len(actual_page.results)
 
     def origin_add(self, origins: List[Origin]) -> Dict[str, int]:
         added = 0
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -1040,24 +1040,28 @@
 
     @remote_api_endpoint("origin/search")
     def origin_search(
-        self, url_pattern, offset=0, limit=50, regexp=False, with_visit=False
-    ):
+        self,
+        url_pattern: str,
+        page_token: Optional[str] = None,
+        limit: int = 50,
+        regexp: bool = False,
+        with_visit: bool = False,
+    ) -> PagedResult[Origin]:
         """Search for origins whose urls contain a provided string pattern
         or match a provided regular expression.
         The search is performed in a case insensitive way.
 
         Args:
-            url_pattern (str): the string pattern to search for in origin urls
-            offset (int): number of found origins to skip before returning
-                results
-            limit (int): the maximum number of found origins to return
-            regexp (bool): if True, consider the provided pattern as a regular
+            url_pattern: the string pattern to search for in origin urls
+            page_token: opaque token used for pagination
+            limit: the maximum number of found origins to return
+            regexp: if True, consider the provided pattern as a regular
                 expression and return origins whose urls match it
-            with_visit (bool): if True, filter out origins with no visit
+            with_visit: if True, filter out origins with no visit
 
-        Yields:
-            dicts containing origin information as returned
-                by :meth:`swh.storage.storage.Storage.origin_get`.
+        Returns:
+            PagedResult of Origin
+
         """
         ...
 
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -1106,21 +1106,37 @@
         return PagedResult(results=origins, next_page_token=next_page_token)
 
     @timed
-    @db_transaction_generator()
+    @db_transaction()
     def origin_search(
         self,
-        url_pattern,
-        offset=0,
-        limit=50,
-        regexp=False,
-        with_visit=False,
+        url_pattern: str,
+        page_token: Optional[str] = None,
+        limit: int = 50,
+        regexp: bool = False,
+        with_visit: bool = False,
         db=None,
         cur=None,
-    ):
+    ) -> PagedResult[Origin]:
+        next_page_token = None
+        offset = int(page_token) if page_token else 0
+
+        origins = []
+        # Fetch one extra origin to detect whether there is a next page
         for origin in db.origin_search(
-            url_pattern, offset, limit, regexp, with_visit, cur
+            url_pattern, offset, limit + 1, regexp, with_visit, cur
         ):
-            yield dict(zip(db.origin_cols, origin))
+            row_d = dict(zip(db.origin_cols, origin))
+            origins.append(Origin(url=row_d["url"]))
+
+        if len(origins) > limit:
+            # next offset
+            next_page_token = str(offset + limit)
+            # excluding that origin from the result to respect the limit size
+            origins = origins[:limit]
+
+        assert len(origins) <= limit
+
+        return PagedResult(results=origins, next_page_token=next_page_token)
 
     @timed
     @db_transaction()
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -1576,113 +1576,103 @@
     def test_origin_search_single_result(self, swh_storage, sample_data):
         origin, origin2 = sample_data.origins[:2]
 
-        found_origins = list(swh_storage.origin_search(origin.url))
-        assert len(found_origins) == 0
+        actual_page = swh_storage.origin_search(origin.url)
+        assert actual_page.next_page_token is None
+        assert actual_page.results == []
 
-        found_origins = list(swh_storage.origin_search(origin.url, regexp=True))
-        assert len(found_origins) == 0
+        actual_page = swh_storage.origin_search(origin.url, regexp=True)
+        assert actual_page.next_page_token is None
+        assert actual_page.results == []
 
         swh_storage.origin_add([origin])
-        origin_data = origin.to_dict()
-        found_origins = list(swh_storage.origin_search(origin.url))
-
-        assert len(found_origins) == 1
-        assert found_origins[0] == origin_data
+        actual_page = swh_storage.origin_search(origin.url)
+        assert actual_page.next_page_token is None
+        assert actual_page.results == [origin]
 
-        found_origins = list(
-            swh_storage.origin_search(f".{origin.url[1:-1]}.", regexp=True)
-        )
-        assert len(found_origins) == 1
-        assert found_origins[0] == origin_data
+        actual_page = swh_storage.origin_search(f".{origin.url[1:-1]}.", regexp=True)
+        assert actual_page.next_page_token is None
+        assert actual_page.results == [origin]
 
         swh_storage.origin_add([origin2])
-        origin2_data = origin2.to_dict()
-        found_origins = list(swh_storage.origin_search(origin2.url))
-        assert len(found_origins) == 1
-        assert found_origins[0] == origin2_data
+        actual_page = swh_storage.origin_search(origin2.url)
+        assert actual_page.next_page_token is None
+        assert actual_page.results == [origin2]
 
-        found_origins = list(
-            swh_storage.origin_search(f".{origin2.url[1:-1]}.", regexp=True)
-        )
-        assert len(found_origins) == 1
-        assert found_origins[0] == origin2_data
+        actual_page = swh_storage.origin_search(f".{origin2.url[1:-1]}.", regexp=True)
+        assert actual_page.next_page_token is None
+        assert actual_page.results == [origin2]
 
     def test_origin_search_no_regexp(self, swh_storage, sample_data):
         origin, origin2 = sample_data.origins[:2]
-        origin_dicts = [o.to_dict() for o in [origin, origin2]]
-
         swh_storage.origin_add([origin, origin2])
 
         # no pagination
-        found_origins = list(swh_storage.origin_search("/"))
-        assert len(found_origins) == 2
+        actual_page = swh_storage.origin_search("/")
+        assert actual_page.next_page_token is None
+        assert actual_page.results == [origin, origin2]
 
         # offset=0
-        found_origins0 = list(swh_storage.origin_search("/", offset=0, limit=1))
-        assert len(found_origins0) == 1
-        assert found_origins0[0] in origin_dicts
+        actual_page = swh_storage.origin_search("/", page_token=None, limit=1)
+        next_page_token = actual_page.next_page_token
+        assert next_page_token is not None
+        assert actual_page.results == [origin]
 
         # offset=1
-        found_origins1 = list(swh_storage.origin_search("/", offset=1, limit=1))
-        assert len(found_origins1) == 1
-        assert found_origins1[0] in origin_dicts
-
-        # check both origins were returned
-        assert found_origins0 != found_origins1
+        actual_page = swh_storage.origin_search(
+            "/", page_token=next_page_token, limit=1
+        )
+        assert actual_page.next_page_token is None
+        assert actual_page.results == [origin2]
 
     def test_origin_search_regexp_substring(self, swh_storage, sample_data):
         origin, origin2 = sample_data.origins[:2]
-        origin_dicts = [o.to_dict() for o in [origin, origin2]]
 
         swh_storage.origin_add([origin, origin2])
 
         # no pagination
-        found_origins = list(swh_storage.origin_search("/", regexp=True))
-        assert len(found_origins) == 2
+        actual_page = swh_storage.origin_search("/", regexp=True)
+        assert actual_page.next_page_token is None
+        assert actual_page.results == [origin, origin2]
 
         # offset=0
-        found_origins0 = list(
-            swh_storage.origin_search("/", offset=0, limit=1, regexp=True)
+        actual_page = swh_storage.origin_search(
+            "/", page_token=None, limit=1, regexp=True
         )
-        assert len(found_origins0) == 1
-        assert found_origins0[0] in origin_dicts
+        next_page_token = actual_page.next_page_token
+        assert next_page_token is not None
+        assert actual_page.results == [origin]
 
         # offset=1
-        found_origins1 = list(
-            swh_storage.origin_search("/", offset=1, limit=1, regexp=True)
+        actual_page = swh_storage.origin_search(
+            "/", page_token=next_page_token, limit=1, regexp=True
         )
-        assert len(found_origins1) == 1
-        assert found_origins1[0] in origin_dicts
-
-        # check both origins were returned
-        assert found_origins0 != found_origins1
+        assert actual_page.next_page_token is None
+        assert actual_page.results == [origin2]
 
     def test_origin_search_regexp_fullstring(self, swh_storage, sample_data):
         origin, origin2 = sample_data.origins[:2]
-        origin_dicts = [o.to_dict() for o in [origin, origin2]]
 
         swh_storage.origin_add([origin, origin2])
 
         # no pagination
-        found_origins = list(swh_storage.origin_search(".*/.*", regexp=True))
-        assert len(found_origins) == 2
+        actual_page = swh_storage.origin_search(".*/.*", regexp=True)
+        assert actual_page.next_page_token is None
+        assert actual_page.results == [origin, origin2]
 
         # offset=0
-        found_origins0 = list(
-            swh_storage.origin_search(".*/.*", offset=0, limit=1, regexp=True)
+        actual_page = swh_storage.origin_search(
+            ".*/.*", page_token=None, limit=1, regexp=True
         )
-        assert len(found_origins0) == 1
-        assert found_origins0[0] in origin_dicts
+        next_page_token = actual_page.next_page_token
+        assert next_page_token is not None
+        assert actual_page.results == [origin]
 
         # offset=1
-        found_origins1 = list(
-            swh_storage.origin_search(".*/.*", offset=1, limit=1, regexp=True)
+        actual_page = swh_storage.origin_search(
+            ".*/.*", page_token=next_page_token, limit=1, regexp=True
         )
-        assert len(found_origins1) == 1
-        assert found_origins1[0] in origin_dicts
-
-        # check both origins were returned
-        assert found_origins0 != found_origins1
+        assert actual_page.next_page_token is None
+        assert actual_page.results == [origin2]
 
     def test_origin_visit_add(self, swh_storage, sample_data):
         origin1 = sample_data.origins[1]