Changeset View
Changeset View
Standalone View
Standalone View
swh/storage/cassandra/cql.py
# Copyright (C) 2019-2020 The Software Heritage developers | # Copyright (C) 2019-2020 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
import datetime | |||||
import functools | import functools | ||||
import json | import json | ||||
import logging | import logging | ||||
import random | import random | ||||
from typing import ( | from typing import ( | ||||
Any, | Any, | ||||
Callable, | Callable, | ||||
Dict, | Dict, | ||||
Show All 19 Lines | |||||
from swh.model.model import ( | from swh.model.model import ( | ||||
Sha1Git, | Sha1Git, | ||||
TimestampWithTimezone, | TimestampWithTimezone, | ||||
Timestamp, | Timestamp, | ||||
Person, | Person, | ||||
Content, | Content, | ||||
SkippedContent, | SkippedContent, | ||||
OriginVisit, | OriginVisit, | ||||
OriginVisitUpdate, | |||||
Origin, | Origin, | ||||
) | ) | ||||
from .common import Row, TOKEN_BEGIN, TOKEN_END, hash_url | from .common import Row, TOKEN_BEGIN, TOKEN_END, hash_url | ||||
from .schema import CREATE_TABLES_QUERIES, HASH_ALGORITHMS | from .schema import CREATE_TABLES_QUERIES, HASH_ALGORITHMS | ||||
logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
▲ Show 20 Lines • Show All 633 Lines • ▼ Show 20 Lines | ) -> ResultSet: | ||||
if last_visit is None: | if last_visit is None: | ||||
last_visit = -1 | last_visit = -1 | ||||
if limit is None: | if limit is None: | ||||
return self._origin_visit_get_no_limit(origin_url, last_visit) | return self._origin_visit_get_no_limit(origin_url, last_visit) | ||||
else: | else: | ||||
return self._origin_visit_get_limit(origin_url, last_visit, limit) | return self._origin_visit_get_limit(origin_url, last_visit, limit) | ||||
def origin_visit_update( | @_prepared_insert_statement("origin_visit", _origin_visit_keys) | ||||
self, origin_url: str, visit_id: int, updates: Dict[str, Any] | def origin_visit_add_one(self, visit: OriginVisit, *, statement) -> None: | ||||
) -> None: | self._add_one(statement, "origin_visit", visit, self._origin_visit_keys) | ||||
set_parts = [] | |||||
args: List[Any] = [] | _origin_visit_update_table_keys = [ | ||||
for (column, value) in updates.items(): | "origin", | ||||
set_parts.append(f"{column} = %s") | "visit", | ||||
if column == "metadata": | "date", | ||||
args.append(json.dumps(value)) | "status", | ||||
else: | "snapshot", | ||||
args.append(value) | "metadata", | ||||
] | |||||
if not set_parts: | @_prepared_insert_statement("origin_visit_update", _origin_visit_update_table_keys) | ||||
return | def origin_visit_update_add_one( | ||||
self, visit_update: OriginVisitUpdate, *, statement | |||||
) -> None: | |||||
assert self._origin_visit_update_table_keys[-1] == "metadata" | |||||
keys = self._origin_visit_update_table_keys | |||||
query = ( | metadata = json.dumps(visit_update.metadata) | ||||
"UPDATE origin_visit SET " | self._execute_with_retries( | ||||
+ ", ".join(set_parts) | statement, [getattr(visit_update, key) for key in keys[:-1]] + [metadata] | ||||
+ " WHERE origin = %s AND visit = %s" | |||||
) | ) | ||||
self._execute_with_retries(query, args + [origin_url, visit_id]) | |||||
@_prepared_insert_statement("origin_visit", _origin_visit_keys) | def _format_origin_visit_update_row( | ||||
def origin_visit_add_one(self, visit: OriginVisit, *, statement) -> None: | self, visit_update: ResultSet | ||||
self._add_one(statement, "origin_visit", visit, self._origin_visit_keys) | ) -> Dict[str, Any]: | ||||
"""Format a row visit_update into an origin_visit_update dict | |||||
""" | |||||
return { | |||||
**visit_update._asdict(), | |||||
"origin": visit_update.origin, | |||||
"date": visit_update.date.replace(tzinfo=datetime.timezone.utc), | |||||
"metadata": ( | |||||
json.loads(visit_update.metadata) if visit_update.metadata else None | |||||
), | |||||
} | |||||
@_prepared_statement( | |||||
"SELECT * FROM origin_visit_update " | |||||
"WHERE origin = ? AND visit = ? " | |||||
"ORDER BY date DESC " | |||||
"LIMIT 1" | |||||
) | |||||
def origin_visit_update_get_latest( | |||||
self, origin: str, visit: int, *, statement | |||||
) -> Optional[Dict[str, Any]]: | |||||
"""Given an origin visit id, return its latest origin_visit_update | |||||
""" | |||||
rows = list(self._execute_with_retries(statement, [origin, visit])) | |||||
if rows: | |||||
return self._format_origin_visit_update_row(rows[0]) | |||||
else: | |||||
return None | |||||
@_prepared_statement( | @_prepared_statement( | ||||
"UPDATE origin_visit SET " | "UPDATE origin_visit SET " | ||||
+ ", ".join("%s = ?" % key for key in _origin_visit_update_keys) | + ", ".join("%s = ?" % key for key in _origin_visit_update_keys) | ||||
+ " WHERE origin = ? AND visit = ?" | + " WHERE origin = ? AND visit = ?" | ||||
) | ) | ||||
def origin_visit_upsert(self, visit: OriginVisit, *, statement) -> None: | def origin_visit_upsert(self, visit: OriginVisit, *, statement) -> None: | ||||
args: List[Any] = [] | args: List[Any] = [] | ||||
Show All 17 Lines | ) -> Optional[Row]: | ||||
return rows[0] | return rows[0] | ||||
else: | else: | ||||
return None | return None | ||||
@_prepared_statement("SELECT * FROM origin_visit " "WHERE origin = ?") | @_prepared_statement("SELECT * FROM origin_visit " "WHERE origin = ?") | ||||
def origin_visit_get_all(self, origin_url: str, *, statement) -> ResultSet: | def origin_visit_get_all(self, origin_url: str, *, statement) -> ResultSet: | ||||
return self._execute_with_retries(statement, [origin_url]) | return self._execute_with_retries(statement, [origin_url]) | ||||
@_prepared_statement("SELECT * FROM origin_visit WHERE origin = ?") | |||||
def origin_visit_get_latest( | |||||
self, | |||||
origin: str, | |||||
allowed_statuses: Optional[Iterable[str]], | |||||
require_snapshot: bool, | |||||
*, | |||||
statement, | |||||
) -> Optional[Row]: | |||||
# TODO: do the ordering and filtering in Cassandra | |||||
rows = list(self._execute_with_retries(statement, [origin])) | |||||
rows.sort(key=lambda row: (row.date, row.visit), reverse=True) | |||||
for row in rows: | |||||
if require_snapshot and row.snapshot is None: | |||||
continue | |||||
if allowed_statuses is not None and row.status not in allowed_statuses: | |||||
continue | |||||
if row.snapshot is not None and self.snapshot_missing([row.snapshot]): | |||||
raise ValueError("visit references unknown snapshot") | |||||
return row | |||||
else: | |||||
return None | |||||
@_prepared_statement("SELECT * FROM origin_visit WHERE token(origin) >= ?") | @_prepared_statement("SELECT * FROM origin_visit WHERE token(origin) >= ?") | ||||
def _origin_visit_iter_from(self, min_token: int, *, statement) -> Iterator[Row]: | def _origin_visit_iter_from(self, min_token: int, *, statement) -> Iterator[Row]: | ||||
yield from self._execute_with_retries(statement, [min_token]) | yield from self._execute_with_retries(statement, [min_token]) | ||||
@_prepared_statement("SELECT * FROM origin_visit WHERE token(origin) < ?") | @_prepared_statement("SELECT * FROM origin_visit WHERE token(origin) < ?") | ||||
def _origin_visit_iter_to(self, max_token: int, *, statement) -> Iterator[Row]: | def _origin_visit_iter_to(self, max_token: int, *, statement) -> Iterator[Row]: | ||||
yield from self._execute_with_retries(statement, [max_token]) | yield from self._execute_with_retries(statement, [max_token]) | ||||
▲ Show 20 Lines • Show All 49 Lines • Show Last 20 Lines |