swh/storage/cassandra/storage.py
@@ (42 lines elided) @@ from swh.storage.interface import (
     Sha1,
     VISIT_STATUSES,
 )
 from swh.storage.objstorage import ObjStorage
 from swh.storage.writer import JournalWriter
 from swh.storage.utils import map_optional, now

 from ..exc import StorageArgumentException, HashCollision
-from .common import TOKEN_BEGIN, TOKEN_END
+from .common import TOKEN_BEGIN, TOKEN_END, hash_url, remove_keys
 from . import converters
 from .cql import CqlRunner
 from .schema import HASH_ALGORITHMS
+from .model import (
+    ContentRow,
+    DirectoryEntryRow,
+    DirectoryRow,
+    MetadataAuthorityRow,
+    MetadataFetcherRow,
+    OriginRow,
+    OriginVisitRow,
+    RawExtrinsicMetadataRow,
+    RevisionParentRow,
+    SkippedContentRow,
+    SnapshotBranchRow,
+    SnapshotRow,
+)

 # Max block size of contents to return
 BULK_BLOCK_CONTENT_LEN_MAX = 10000
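
Note: the recurring pattern in this changeset is that CqlRunner endpoints now take typed row objects from `.model` (imported above) instead of model objects, dicts, or positional fields. A runnable stand-in illustrating the conversion idiom used throughout; `ContentRowSketch` and the `remove_keys` contract are assumptions inferred from this diff, not the real classes:

    from dataclasses import dataclass

    @dataclass
    class ContentRowSketch:  # stand-in for swh.storage.cassandra.model.ContentRow
        sha1: bytes
        sha1_git: bytes
        length: int

    def remove_keys(d, keys):  # assumed contract of .common.remove_keys
        return {k: v for k, v in d.items() if k not in keys}

    content_d = {"sha1": b"\x01", "sha1_git": b"\x02", "length": 3, "data": b"raw"}
    row = ContentRowSketch(**remove_keys(content_d, ("data",)))  # drop non-column keys
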
 class CassandraStorage:
     def __init__(self, hosts, keyspace, objstorage, port=9042, journal_writer=None):
@@ (71 lines elided) @@ def _content_add(self, contents: List[Content], with_data: bool) -> Dict:
                     collisions.append(
                         {algo: getattr(row, algo) for algo in HASH_ALGORITHMS}
                     )
                 if collisions:
                     collisions.append(content.hashes())
                     raise HashCollision(algo, content.get_hash(algo), collisions)

-            (token, insertion_finalizer) = self._cql_runner.content_add_prepare(content)
+            (token, insertion_finalizer) = self._cql_runner.content_add_prepare(
+                ContentRow(**remove_keys(content.to_dict(), ("data",)))
+            )

             # Then add to index tables
             for algo in HASH_ALGORITHMS:
                 self._cql_runner.content_index_add_one(algo, content, token)

             # Then to the main table
             insertion_finalizer()
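
Note: content_add_prepare computes the row's token and hands back a finalizer without writing the main row yet; index entries pointing at the token are written first, and the finalizer's main-table write lands last, so a reader that follows an index entry finds either the complete row or nothing. A schematic of the ordering, with hypothetical helper names:

    def add_with_indexes(prepare, index_add_one, algos):
        token, finalize = prepare()      # compute token; main row not yet written
        for algo in algos:
            index_add_one(algo, token)   # index rows reference the token
        finalize()                       # main-table write makes it all visible
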
@@ (47 lines elided) @@ ) -> PagedResult[Content]:
             range_start = int(page_token)
         next_page_token: Optional[str] = None

         rows = self._cql_runner.content_get_token_range(
             range_start, range_end, limit + 1
         )
         contents = []
-        last_id: Optional[int] = None
-        for counter, row in enumerate(rows):
+        for counter, (tok, row) in enumerate(rows):
             if row.status == "absent":
                 continue
-            row_d = row._asdict()
-            last_id = row_d.pop("tok")
+            row_d = row.to_dict()
             if counter >= limit:
-                next_page_token = str(last_id)
+                next_page_token = str(tok)
                 break
             contents.append(Content(**row_d))

         assert len(contents) <= limit
         return PagedResult(results=contents, next_page_token=next_page_token)
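
Note: fetching limit + 1 rows lets the extra row's token double as the next page_token, avoiding a separate lookahead query. A client-side sketch of walking the pages, assuming the interface-level content_get_partition(partition_id, nb_partitions, page_token, limit) signature:

    def iter_partition(storage, partition_id: int, nb_partitions: int):
        page_token = None
        while True:
            page = storage.content_get_partition(
                partition_id, nb_partitions, page_token=page_token
            )
            yield from page.results
            if page.next_page_token is None:
                break
            page_token = page.next_page_token
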
     def content_get(self, contents: List[Sha1]) -> List[Optional[Content]]:
         contents_by_sha1: Dict[Sha1, Optional[Content]] = {}
         for sha1 in contents:
             # Get all (sha1, sha1_git, sha256, blake2s256) whose sha1
             # matches the argument, from the index table ('content_by_sha1')
             for row in self._content_get_from_hash("sha1", sha1):
-                row_d = row._asdict()
+                row_d = row.to_dict()
                 row_d.pop("ctime")
                 content = Content(**row_d)
                 contents_by_sha1[content.sha1] = content
         return [contents_by_sha1.get(sha1) for sha1 in contents]
     def content_find(self, content: Dict[str, Any]) -> List[Content]:
         # Find an algorithm that is common to all the requested contents.
         # It will be used to do an initial filtering efficiently.
@@ (11 lines elided) @@ def content_find(self, content: Dict[str, Any]) -> List[Content]:
             # Re-check all the hashes, in case of collisions (either of the
             # hash of the partition key, or the hashes in it)
             for algo in HASH_ALGORITHMS:
                 if content.get(algo) and getattr(row, algo) != content[algo]:
                     # This hash didn't match; discard the row.
                     break
             else:
                 # All hashes match, keep this row.
-                row_d = row._asdict()
+                row_d = row.to_dict()
                 row_d["ctime"] = row.ctime.replace(tzinfo=datetime.timezone.utc)
                 results.append(Content(**row_d))
         return results
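
Note: the keep/discard logic above relies on Python's for/else: the else suite runs only when the loop finished without hitting break. A self-contained reminder of the idiom:

    def all_match(row: dict, wanted: dict) -> bool:
        for key, value in wanted.items():
            if row.get(key) != value:
                break          # mismatch: the else suite is skipped
        else:
            return True        # no break: every requested hash matched
        return False

    assert all_match({"sha1": b"a"}, {"sha1": b"a"})
    assert not all_match({"sha1": b"a"}, {"sha1": b"b"})
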
     def content_missing(
         self, contents: List[Dict[str, Any]], key_hash: str = "sha1"
     ) -> Iterable[bytes]:
         if key_hash not in DEFAULT_ALGORITHMS:
@@ (46 lines elided) @@ def _skipped_content_add(self, contents: List[SkippedContent]) -> Dict:
             if not self._cql_runner.skipped_content_get_from_pk(c.to_dict())
         ]

         self.journal_writer.skipped_content_add(contents)

         for content in contents:
             # Compute token of the row in the main table
             (token, insertion_finalizer) = self._cql_runner.skipped_content_add_prepare(
-                content
+                SkippedContentRow.from_dict({"origin": None, **content.to_dict()})
             )
             # Then add to index tables
             for algo in HASH_ALGORITHMS:
                 self._cql_runner.skipped_content_index_add_one(algo, content, token)
             # Then to the main table
             insertion_finalizer()
@@ (17 lines elided) @@ def directory_add(self, directories: List[Directory]) -> Dict:
         directories = [dir_ for dir_ in directories if dir_.id in missing]

         self.journal_writer.directory_add(directories)

         for directory in directories:
             # Add directory entries to the 'directory_entry' table
             for entry in directory.entries:
                 self._cql_runner.directory_entry_add_one(
-                    {**entry.to_dict(), "directory_id": directory.id}
+                    DirectoryEntryRow(directory_id=directory.id, **entry.to_dict())
                 )

             # Add the directory *after* adding all the entries, so someone
             # calling snapshot_get_branch in the meantime won't end up
             # with half the entries.
-            self._cql_runner.directory_add_one(directory.id)
+            self._cql_runner.directory_add_one(DirectoryRow(id=directory.id))

         return {"directory:add": len(directories)}
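
Note: writing entries first and the directory row last makes the directory row act as a commit marker: readers bail out through directory_missing() unless it exists, so they never observe a partial entry list. (The comment's mention of snapshot_get_branch reads like a copy-paste from the snapshot code; the guard on the read path here is directory_missing, as _directory_ls below shows.) A toy model of the invariant, not swh code:

    tables = {"directory": set(), "directory_entry": {}}

    def write_directory(dir_id: str, entries: list) -> None:
        tables["directory_entry"].setdefault(dir_id, []).extend(entries)  # children first
        tables["directory"].add(dir_id)                                   # marker last

    def read_directory(dir_id: str):
        if dir_id not in tables["directory"]:  # same role as directory_missing()
            return None
        return tables["directory_entry"][dir_id]
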
     def directory_missing(self, directories: List[Sha1Git]) -> Iterable[Sha1Git]:
         return self._cql_runner.directory_missing(directories)

     def _join_dentry_to_content(self, dentry: DirectoryEntry) -> Dict[str, Any]:
         keys = (
@@ (16 lines elided) @@ class CassandraStorage:
     def _directory_ls(
         self, directory_id: Sha1Git, recursive: bool, prefix: bytes = b""
     ) -> Iterable[Dict[str, Any]]:
         if self.directory_missing([directory_id]):
             return
         rows = list(self._cql_runner.directory_entry_get([directory_id]))

         for row in rows:
+            entry_d = row.to_dict()
             # Build and yield the directory entry dict
-            entry = row._asdict()
-            del entry["directory_id"]
-            entry = DirectoryEntry.from_dict(entry)
+            del entry_d["directory_id"]
+            entry = DirectoryEntry.from_dict(entry_d)
             ret = self._join_dentry_to_content(entry)
             ret["name"] = prefix + ret["name"]
             ret["dir_id"] = directory_id
             yield ret
             if recursive and ret["type"] == "dir":
                 yield from self._directory_ls(
                     ret["target"], True, prefix + ret["name"] + b"/"
@@ (51 lines elided) @@ def revision_add(self, revisions: List[Revision]) -> Dict:
         missing = self.revision_missing([rev.id for rev in revisions])
         revisions = [rev for rev in revisions if rev.id in missing]
         self.journal_writer.revision_add(revisions)

         for revision in revisions:
             revobject = converters.revision_to_db(revision)
             if revobject:
                 # Add parents first
-                for (rank, parent) in enumerate(revobject["parents"]):
+                for (rank, parent) in enumerate(revision.parents):
                     self._cql_runner.revision_parent_add_one(
-                        revobject["id"], rank, parent
+                        RevisionParentRow(
+                            id=revobject.id, parent_rank=rank, parent_id=parent
+                        )
                     )

                 # Then write the main revision row.
                 # Writing this after all parents were written ensures that
                 # read endpoints don't return a partial view while writing
                 # the parents
                 self._cql_runner.revision_add_one(revobject)

         return {"revision:add": len(revisions)}
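
Note: the switch from revobject["id"] to revobject.id implies converters.revision_to_db now returns a row object rather than a dict. Parent order is significant in git, and the parent_rank clustering key is what preserves it across the round trip, e.g.:

    # Rows come back sorted by the clustering key (parent_rank), so the
    # original parent order is recovered without an explicit sort column.
    parents = (b"p0", b"p1", b"p2")
    stored = [(b"rev", rank, p) for rank, p in enumerate(parents)]  # (id, parent_rank, parent_id)
    assert tuple(p for _, _, p in sorted(stored, key=lambda r: r[1])) == parents
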
     def revision_missing(self, revisions: List[Sha1Git]) -> Iterable[Sha1Git]:
         return self._cql_runner.revision_missing(revisions)

     def revision_get(
         self, revisions: List[Sha1Git]
     ) -> Iterable[Optional[Dict[str, Any]]]:
         rows = self._cql_runner.revision_get(revisions)
         revs = {}
         for row in rows:
             # TODO: use a single query to get all parents?
             # (it might have lower latency, but requires more code and more
             # bandwidth, because revision id would be part of each returned
             # row)
-            parent_rows = self._cql_runner.revision_parent_get(row.id)
+            parents = tuple(self._cql_runner.revision_parent_get(row.id))
             # parent_rank is the clustering key, so results are already
             # sorted by rank.
-            parents = tuple(row.parent_id for row in parent_rows)
             rev = converters.revision_from_db(row, parents=parents)
             revs[rev.id] = rev.to_dict()

         for rev_id in revisions:
             yield revs.get(rev_id)
     def _get_parent_revs(
         self,
@@ (10 lines elided) @@ ]:
         if not rev_ids:
             return
         seen |= set(rev_ids)

         # We need this query, even if short=True, to return consistent
         # results (ie. not return only a subset of a revision's parents
         # if it is being written)
         if short:
-            rows = self._cql_runner.revision_get_ids(rev_ids)
+            ids = self._cql_runner.revision_get_ids(rev_ids)
+            for id_ in ids:
+                # TODO: use a single query to get all parents?
+                # (it might have less latency, but requires less code and more
+                # bandwidth (because revision id would be part of each returned
+                # row)
+                parents = tuple(self._cql_runner.revision_parent_get(id_))
+                # parent_rank is the clustering key, so results are already
+                # sorted by rank.
+                yield (id_, parents)
+                yield from self._get_parent_revs(parents, seen, limit, short)
         else:
             rows = self._cql_runner.revision_get(rev_ids)
-        for row in rows:
-            # TODO: use a single query to get all parents?
-            # (it might have less latency, but requires less code and more
-            # bandwidth (because revision id would be part of each returned
-            # row)
-            parent_rows = self._cql_runner.revision_parent_get(row.id)
-            # parent_rank is the clustering key, so results are already
-            # sorted by rank.
-            parents = tuple(row.parent_id for row in parent_rows)
-            if short:
-                yield (row.id, parents)
-            else:
-                rev = converters.revision_from_db(row, parents=parents)
-                yield rev.to_dict()
-            yield from self._get_parent_revs(parents, seen, limit, short)
+            for row in rows:
+                # TODO: use a single query to get all parents?
+                # (it might have less latency, but requires less code and more
+                # bandwidth (because revision id would be part of each returned
+                # row)
+                parents = tuple(self._cql_runner.revision_parent_get(row.id))
+                # parent_rank is the clustering key, so results are already
+                # sorted by rank.
+                rev = converters.revision_from_db(row, parents=parents)
+                yield rev.to_dict()
+                yield from self._get_parent_revs(parents, seen, limit, short)
     def revision_log(
         self, revisions: List[Sha1Git], limit: Optional[int] = None
     ) -> Iterable[Optional[Dict[str, Any]]]:
         seen: Set[Sha1Git] = set()
         yield from self._get_parent_revs(revisions, seen, limit, False)
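
Note: _get_parent_revs walks the history depth-first; the shared `seen` set keeps revisions reached through several merge paths from being yielded (and recursed into) twice. The elided lines presumably hold the seen-filtering of rev_ids and the limit check; a toy of the traversal shape only:

    graph = {b"c": (b"a", b"b"), b"a": (b"r",), b"b": (b"r",), b"r": ()}

    def log(rev_ids, seen):
        rev_ids = [r for r in rev_ids if r not in seen]
        if not rev_ids:
            return
        seen |= set(rev_ids)
        for r in rev_ids:
            yield r
            yield from log(graph[r], seen)

    assert list(log([b"c"], set())) == [b"c", b"a", b"r", b"b"]
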
     def revision_shortlog(
@@ (44 lines elided) @@ def snapshot_add(self, snapshots: List[Snapshot]) -> Dict:
         snapshots = [snp for snp in snapshots if snp.id in missing]

         for snapshot in snapshots:
             self.journal_writer.snapshot_add([snapshot])

             # Add branches
             for (branch_name, branch) in snapshot.branches.items():
                 if branch is None:
-                    target_type = None
-                    target = None
+                    target_type: Optional[str] = None
+                    target: Optional[bytes] = None
                 else:
                     target_type = branch.target_type.value
                     target = branch.target
                 self._cql_runner.snapshot_branch_add_one(
-                    {
-                        "snapshot_id": snapshot.id,
-                        "name": branch_name,
-                        "target_type": target_type,
-                        "target": target,
-                    }
+                    SnapshotBranchRow(
+                        snapshot_id=snapshot.id,
+                        name=branch_name,
+                        target_type=target_type,
+                        target=target,
+                    )
                 )

             # Add the snapshot *after* adding all the branches, so someone
             # calling snapshot_get_branch in the meantime won't end up
             # with half the branches.
-            self._cql_runner.snapshot_add_one(snapshot.id)
+            self._cql_runner.snapshot_add_one(SnapshotRow(id=snapshot.id))

         return {"snapshot:add": len(snapshots)}
     def snapshot_missing(self, snapshots: List[Sha1Git]) -> Iterable[Sha1Git]:
         return self._cql_runner.snapshot_missing(snapshots)

     def snapshot_get(self, snapshot_id: Sha1Git) -> Optional[Dict[str, Any]]:
         d = self.snapshot_get_branches(snapshot_id)
@@ (18 lines elided) @@ ) -> Optional[Dict[str, Any]]:
             return self.snapshot_get(visit_status.snapshot)
         return None

     def snapshot_count_branches(self, snapshot_id: Sha1Git) -> Optional[Dict[str, int]]:
         if self._cql_runner.snapshot_missing([snapshot_id]):
             # Makes sure we don't fetch branches for a snapshot that is
             # being added.
             return None

-        rows = list(self._cql_runner.snapshot_count_branches(snapshot_id))
-        assert len(rows) == 1
-        (nb_none, counts) = rows[0].counts
-        counts = dict(counts)
-        if nb_none:
-            counts[None] = nb_none
-        return counts
+        return self._cql_runner.snapshot_count_branches(snapshot_id)
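
Note: the branch-count post-processing that used to live here (unpacking the (nb_none, counts) pair and folding dangling branches under a None key) has evidently moved behind CqlRunner.snapshot_count_branches, which now returns the final counts mapping directly.
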
     def snapshot_get_branches(
         self,
         snapshot_id: Sha1Git,
         branches_from: bytes = b"",
         branches_count: int = 1000,
         target_types: Optional[List[str]] = None,
     ) -> Optional[PartialBranches]:
@@ (90 lines elided) @@ def origin_get_one(self, origin_url: str) -> Optional[Origin]:
             assert len(rows) == 1
             return Origin(url=rows[0].url)
         else:
             return None

     def origin_get_by_sha1(self, sha1s: List[bytes]) -> List[Optional[Dict[str, Any]]]:
         results = []
         for sha1 in sha1s:
-            rows = self._cql_runner.origin_get_by_sha1(sha1)
-            origin = {"url": rows.one().url} if rows else None
+            rows = list(self._cql_runner.origin_get_by_sha1(sha1))
+            origin = {"url": rows[0].url} if rows else None
             results.append(origin)
         return results
     def origin_list(
         self, page_token: Optional[str] = None, limit: int = 100
     ) -> PagedResult[Origin]:
         # Compute what token to begin the listing from
         start_token = TOKEN_BEGIN
         if page_token:
             start_token = int(page_token)
             if not (TOKEN_BEGIN <= start_token <= TOKEN_END):
                 raise StorageArgumentException("Invalid page_token.")
         next_page_token = None

         origins = []
         # Take one more origin so we can reuse it as the next page token if any
-        for row in self._cql_runner.origin_list(start_token, limit + 1):
+        for (tok, row) in self._cql_runner.origin_list(start_token, limit + 1):
             origins.append(Origin(url=row.url))
             # keep reference of the last id for pagination purposes
-            last_id = row.tok
+            last_id = tok

         if len(origins) > limit:
             # last origin id is the next page token
             next_page_token = str(last_id)
             # excluding that origin from the result to respect the limit size
             origins = origins[:limit]

         assert len(origins) <= limit
         return PagedResult(results=origins, next_page_token=next_page_token)
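
Note: same convention change as content_get_token_range above: CqlRunner listing endpoints now yield (token, row) pairs instead of rows carrying a synthetic tok attribute, so the storage layer no longer has to pop pagination fields back out of row dicts.
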
     def origin_search(
         self,
         url_pattern: str,
         page_token: Optional[str] = None,
         limit: int = 50,
         regexp: bool = False,
         with_visit: bool = False,
     ) -> PagedResult[Origin]:
         # TODO: remove this endpoint, swh-search should be used instead.
         next_page_token = None
         offset = int(page_token) if page_token else 0

-        origins = self._cql_runner.origin_iter_all()
+        origin_rows = [row for row in self._cql_runner.origin_iter_all()]
         if regexp:
             pat = re.compile(url_pattern)
-            origins = [Origin(orig.url) for orig in origins if pat.search(orig.url)]
+            origin_rows = [row for row in origin_rows if pat.search(row.url)]
         else:
-            origins = [Origin(orig.url) for orig in origins if url_pattern in orig.url]
+            origin_rows = [row for row in origin_rows if url_pattern in row.url]

         if with_visit:
-            origins = [Origin(orig.url) for orig in origins if orig.next_visit_id > 1]
+            origin_rows = [row for row in origin_rows if row.next_visit_id > 1]
+
+        origins = [Origin(url=row.url) for row in origin_rows]

         origins = origins[offset : offset + limit + 1]
         if len(origins) > limit:
             # next offset
             next_page_token = str(offset + limit)
             # excluding that origin from the result to respect the limit size
             origins = origins[:limit]

         assert len(origins) <= limit
         return PagedResult(results=origins, next_page_token=next_page_token)
     def origin_add(self, origins: List[Origin]) -> Dict[str, int]:
         to_add = [ori for ori in origins if self.origin_get_one(ori.url) is None]
         self.journal_writer.origin_add(to_add)
         for origin in to_add:
-            self._cql_runner.origin_add_one(origin)
+            self._cql_runner.origin_add_one(
+                OriginRow(sha1=hash_url(origin.url), url=origin.url, next_visit_id=1)
+            )
         return {"origin:add": len(to_add)}
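
Note: seeding next_visit_id at 1 is what makes origin_search's with_visit filter (next_visit_id > 1) work, since the counter is presumably bumped by origin_generate_unique_visit_id on the first visit. hash_url (imported from .common) supplies the sha1 partition key; a plausible shape, assumed rather than read from this diff:

    import hashlib

    def hash_url_sketch(url: str) -> bytes:   # hypothetical; see .common.hash_url
        return hashlib.sha1(url.encode()).digest()
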
     def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]:
         for visit in visits:
             origin = self.origin_get_one(visit.origin)
             if not origin:  # Cannot add a visit without an origin
                 raise StorageArgumentException("Unknown origin %s", visit.origin)

         all_visits = []
         nb_visits = 0
         for visit in visits:
             nb_visits += 1
             if not visit.visit:
                 visit_id = self._cql_runner.origin_generate_unique_visit_id(
                     visit.origin
                 )
                 visit = attr.evolve(visit, visit=visit_id)
             self.journal_writer.origin_visit_add([visit])
-            self._cql_runner.origin_visit_add_one(visit)
+            self._cql_runner.origin_visit_add_one(OriginVisitRow(**visit.to_dict()))
             assert visit.visit is not None
             all_visits.append(visit)
             self._origin_visit_status_add(
                 OriginVisitStatus(
                     origin=visit.origin,
                     visit=visit.visit,
                     date=visit.date,
                     status="created",
                     snapshot=None,
                 )
             )

         return all_visits
     def _origin_visit_status_add(self, visit_status: OriginVisitStatus) -> None:
         """Add an origin visit status"""
         self.journal_writer.origin_visit_status_add([visit_status])
-        self._cql_runner.origin_visit_status_add_one(visit_status)
+        self._cql_runner.origin_visit_status_add_one(
+            converters.visit_status_to_row(visit_status)
+        )
     def origin_visit_status_add(self, visit_statuses: List[OriginVisitStatus]) -> None:
         # First round to check existence (fail early if any is ko)
         for visit_status in visit_statuses:
             origin_url = self.origin_get_one(visit_status.origin)
             if not origin_url:
                 raise StorageArgumentException(f"Unknown origin {visit_status.origin}")
@@ (29 lines elided) @@ def _origin_visit_get_latest_status(self, visit: OriginVisit) -> OriginVisitStatus:
         row = self._cql_runner.origin_visit_status_get_latest(visit.origin, visit.visit)
         assert row is not None
         visit_status = converters.row_to_visit_status(row)
         return attr.evolve(visit_status, origin=visit.origin)

     @staticmethod
     def _format_origin_visit_row(visit):
         return {
-            **visit._asdict(),
+            **visit.to_dict(),
             "origin": visit.origin,
             "date": visit.date.replace(tzinfo=datetime.timezone.utc),
         }
     def origin_visit_get(
         self,
         origin: str,
         page_token: Optional[str] = None,
@@ (119 lines elided) @@ def origin_visit_status_get_latest(
         allowed_statuses: Optional[List[str]] = None,
         require_snapshot: bool = False,
     ) -> Optional[OriginVisitStatus]:
         if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES):
             raise StorageArgumentException(
                 f"Unknown allowed statuses {','.join(allowed_statuses)}, only "
                 f"{','.join(VISIT_STATUSES)} authorized"
             )

-        rows = self._cql_runner.origin_visit_status_get(
-            origin_url, visit, allowed_statuses, require_snapshot
-        )
+        rows = list(
+            self._cql_runner.origin_visit_status_get(
+                origin_url, visit, allowed_statuses, require_snapshot
+            )
+        )
         # filtering is done python side as we cannot do it server side
         if allowed_statuses:
             rows = [row for row in rows if row.status in allowed_statuses]
         if require_snapshot:
             rows = [row for row in rows if row.snapshot is not None]
         if not rows:
             return None
         return converters.row_to_visit_status(rows[0])
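
Note: the new list() wrapper is load-bearing: a generator object is always truthy, so the `if not rows` emptiness check (and the two filtering passes) need a materialized list. For instance:

    def gen():
        return
        yield  # the yield makes this a generator function

    assert bool(gen())       # a generator is truthy even when it yields nothing
    assert not list(gen())   # once materialized, emptiness is observable
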
@@ (46 lines elided) @@ def raw_extrinsic_metadata_add(self, metadata: List[RawExtrinsicMetadata]) -> None:
             if not self._cql_runner.metadata_fetcher_get(
                 metadata_entry.fetcher.name, metadata_entry.fetcher.version
             ):
                 raise StorageArgumentException(
                     f"Unknown fetcher {metadata_entry.fetcher}"
                 )

             try:
-                self._cql_runner.raw_extrinsic_metadata_add(
+                row = RawExtrinsicMetadataRow(
                     type=metadata_entry.type.value,
                     id=str(metadata_entry.id),
                     authority_type=metadata_entry.authority.type.value,
                     authority_url=metadata_entry.authority.url,
                     discovery_date=metadata_entry.discovery_date,
                     fetcher_name=metadata_entry.fetcher.name,
                     fetcher_version=metadata_entry.fetcher.version,
                     format=metadata_entry.format,
                     metadata=metadata_entry.metadata,
                     origin=metadata_entry.origin,
                     visit=metadata_entry.visit,
                     snapshot=map_optional(str, metadata_entry.snapshot),
                     release=map_optional(str, metadata_entry.release),
                     revision=map_optional(str, metadata_entry.revision),
                     path=metadata_entry.path,
                     directory=map_optional(str, metadata_entry.directory),
                 )
+                self._cql_runner.raw_extrinsic_metadata_add(row)
             except TypeError as e:
                 raise StorageArgumentException(*e.args)
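
Note: building the RawExtrinsicMetadataRow first means the except TypeError clause now also converts row-construction errors (e.g. an unexpected keyword argument) into StorageArgumentException, not just errors raised by the add call itself.
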
     def raw_extrinsic_metadata_get(
         self,
         type: MetadataTargetType,
         id: Union[str, SWHID],
         authority: MetadataAuthority,
@@ (89 lines elided) @@ ) -> PagedResult[RawExtrinsicMetadata]:
         next_page_token = None
         return PagedResult(next_page_token=next_page_token, results=results,)

     def metadata_fetcher_add(self, fetchers: List[MetadataFetcher]) -> None:
         self.journal_writer.metadata_fetcher_add(fetchers)
         for fetcher in fetchers:
             self._cql_runner.metadata_fetcher_add(
-                fetcher.name,
-                fetcher.version,
-                json.dumps(map_optional(dict, fetcher.metadata)),
+                MetadataFetcherRow(
+                    name=fetcher.name,
+                    version=fetcher.version,
+                    metadata=json.dumps(map_optional(dict, fetcher.metadata)),
+                )
             )

     def metadata_fetcher_get(
         self, name: str, version: str
     ) -> Optional[MetadataFetcher]:
         fetcher = self._cql_runner.metadata_fetcher_get(name, version)
         if fetcher:
             return MetadataFetcher(
                 name=fetcher.name,
                 version=fetcher.version,
                 metadata=json.loads(fetcher.metadata),
             )
         else:
             return None

     def metadata_authority_add(self, authorities: List[MetadataAuthority]) -> None:
         self.journal_writer.metadata_authority_add(authorities)
         for authority in authorities:
             self._cql_runner.metadata_authority_add(
-                authority.url,
-                authority.type.value,
-                json.dumps(map_optional(dict, authority.metadata)),
+                MetadataAuthorityRow(
+                    url=authority.url,
+                    type=authority.type.value,
+                    metadata=json.dumps(map_optional(dict, authority.metadata)),
+                )
             )
     def metadata_authority_get(
         self, type: MetadataAuthorityType, url: str
     ) -> Optional[MetadataAuthority]:
         authority = self._cql_runner.metadata_authority_get(type.value, url)
         if authority:
             return MetadataAuthority(
@@ (15 lines elided) @@