diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py --- a/swh/storage/cassandra/cql.py +++ b/swh/storage/cassandra/cql.py @@ -19,6 +19,7 @@ Tuple, Type, TypeVar, + Union, ) from cassandra import CoordinationFailure @@ -173,6 +174,32 @@ ) +def _prepared_select_statements( + row_class: Type[BaseRow], queries: Dict[Any, str], +) -> Callable[[Callable[..., TRet]], Callable[..., TRet]]: + """Like _prepared_statement, but supports multiple statements, passed a dict, + and passes a dict of prepared statements to the decorated method""" + cols = row_class.cols() + + statement_start = f"SELECT {', '.join(cols)} FROM {row_class.TABLE} " + + def decorator(f): + @functools.wraps(f) + def newf(self, *args, **kwargs) -> TRet: + if f.__name__ not in self._prepared_statements: + self._prepared_statements[f.__name__] = { + key: self._session.prepare(statement_start + query) + for (key, query) in queries.items() + } + return f( + self, *args, **kwargs, statements=self._prepared_statements[f.__name__] + ) + + return newf + + return decorator + + class CqlRunner: """Class managing prepared statements and building queries to be sent to Cassandra.""" @@ -188,7 +215,13 @@ self._cluster.register_user_type(keyspace, "microtimestamp", Timestamp) self._cluster.register_user_type(keyspace, "person", Person) - self._prepared_statements: Dict[str, PreparedStatement] = {} + # directly a PreparedStatement for methods decorated with + # @_prepared_statement (and its wrappers, _prepared_insert_statement, + # _prepared_exists_statement, and _prepared_select_statement); + # and a dict of PreparedStatements with @_prepared_select_statements + self._prepared_statements: Dict[ + str, Union[PreparedStatement, Dict[Any, PreparedStatement]] + ] = {} ########################## # Common utility functions @@ -658,54 +691,39 @@ # 'origin_visit' table ########################## - @_prepared_select_statement( - OriginVisitRow, "WHERE origin = ? AND visit > ? ORDER BY visit ASC LIMIT ?" 
- ) - def _origin_visit_get_pagination_asc( - self, origin_url: str, last_visit: int, limit: int, *, statement - ) -> ResultSet: - return self._execute_with_retries(statement, [origin_url, last_visit, limit]) - - @_prepared_select_statement( - OriginVisitRow, "WHERE origin = ? AND visit < ? ORDER BY visit DESC LIMIT ?" - ) - def _origin_visit_get_pagination_desc( - self, origin_url: str, last_visit: int, limit: int, *, statement - ) -> ResultSet: - return self._execute_with_retries(statement, [origin_url, last_visit, limit]) - - @_prepared_select_statement( - OriginVisitRow, "WHERE origin = ? ORDER BY visit ASC LIMIT ?" - ) - def _origin_visit_get_no_pagination_asc( - self, origin_url: str, limit: int, *, statement - ) -> ResultSet: - return self._execute_with_retries(statement, [origin_url, limit]) - - @_prepared_select_statement( - OriginVisitRow, "WHERE origin = ? ORDER BY visit DESC LIMIT ?" + @_prepared_select_statements( + OriginVisitRow, + { + (True, ListOrder.ASC): ( + "WHERE origin = ? AND visit > ? ORDER BY visit ASC LIMIT ?" + ), + (True, ListOrder.DESC): ( + "WHERE origin = ? AND visit < ? ORDER BY visit DESC LIMIT ?" + ), + (False, ListOrder.ASC): "WHERE origin = ? ORDER BY visit ASC LIMIT ?", + (False, ListOrder.DESC): "WHERE origin = ? 
ORDER BY visit DESC LIMIT ?", + }, ) - def _origin_visit_get_no_pagination_desc( - self, origin_url: str, limit: int, *, statement - ) -> ResultSet: - return self._execute_with_retries(statement, [origin_url, limit]) - def origin_visit_get( - self, origin_url: str, last_visit: Optional[int], limit: int, order: ListOrder, + self, + origin_url: str, + last_visit: Optional[int], + limit: int, + order: ListOrder, + *, + statements, ) -> Iterable[OriginVisitRow]: args: List[Any] = [origin_url] if last_visit is not None: - page_name = "pagination" args.append(last_visit) - else: - page_name = "no_pagination" args.append(limit) - method_name = f"_origin_visit_get_{page_name}_{order.value}" - origin_visit_get_method = getattr(self, method_name) - return map(OriginVisitRow.from_dict, origin_visit_get_method(*args)) + statement = statements[(last_visit is not None, order)] + return map( + OriginVisitRow.from_dict, self._execute_with_retries(statement, args) + ) @_prepared_insert_statement(OriginVisitRow) def origin_visit_add_one(self, visit: OriginVisitRow, *, statement) -> None: @@ -757,54 +775,25 @@ # 'origin_visit_status' table ########################## - @_prepared_select_statement( - OriginVisitStatusRow, - "WHERE origin = ? AND visit = ? AND date >= ? ORDER BY date ASC LIMIT ?", - ) - def _origin_visit_status_get_with_date_asc_limit( - self, - origin: str, - visit: int, - date_from: datetime.datetime, - limit: int, - *, - statement, - ) -> ResultSet: - return self._execute_with_retries(statement, [origin, visit, date_from, limit]) - - @_prepared_select_statement( - OriginVisitStatusRow, - "WHERE origin = ? AND visit = ? AND date <= ? 
ORDER BY visit DESC LIMIT ?", - ) - def _origin_visit_status_get_with_date_desc_limit( - self, - origin: str, - visit: int, - date_from: datetime.datetime, - limit: int, - *, - statement, - ) -> ResultSet: - return self._execute_with_retries(statement, [origin, visit, date_from, limit]) - - @_prepared_select_statement( - OriginVisitStatusRow, - "WHERE origin = ? AND visit = ? ORDER BY visit ASC LIMIT ?", - ) - def _origin_visit_status_get_with_no_date_asc_limit( - self, origin: str, visit: int, limit: int, *, statement - ) -> ResultSet: - return self._execute_with_retries(statement, [origin, visit, limit]) - - @_prepared_select_statement( + @_prepared_select_statements( OriginVisitStatusRow, - "WHERE origin = ? AND visit = ? ORDER BY visit DESC LIMIT ?", + { + (True, ListOrder.ASC): ( + "WHERE origin = ? AND visit = ? AND date >= ? " + "ORDER BY visit ASC LIMIT ?" + ), + (True, ListOrder.DESC): ( + "WHERE origin = ? AND visit = ? AND date <= ? " + "ORDER BY visit DESC LIMIT ?" + ), + (False, ListOrder.ASC): ( + "WHERE origin = ? AND visit = ? ORDER BY visit ASC LIMIT ?" + ), + (False, ListOrder.DESC): ( + "WHERE origin = ? AND visit = ? ORDER BY visit DESC LIMIT ?" 
+ ), + }, ) - def _origin_visit_status_get_with_no_date_desc_limit( - self, origin: str, visit: int, limit: int, *, statement - ) -> ResultSet: - return self._execute_with_retries(statement, [origin, visit, limit]) - def origin_visit_status_get_range( self, origin: str, @@ -812,21 +801,20 @@ date_from: Optional[datetime.datetime], limit: int, order: ListOrder, + *, + statements, ) -> Iterable[OriginVisitStatusRow]: args: List[Any] = [origin, visit] if date_from is not None: - date_name = "date" args.append(date_from) - else: - date_name = "no_date" args.append(limit) - method_name = f"_origin_visit_status_get_with_{date_name}_{order.value}_limit" - origin_visit_status_get_method = getattr(self, method_name) + statement = statements[(date_from is not None, order)] + return map( - OriginVisitStatusRow.from_dict, origin_visit_status_get_method(*args) + OriginVisitStatusRow.from_dict, self._execute_with_retries(statement, args) ) @_prepared_insert_statement(OriginVisitStatusRow)