Changeset View
Changeset View
Standalone View
Standalone View
swh/storage/cassandra/storage.py
# Copyright (C) 2019-2021  The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import base64 | import base64 | ||||
import datetime | import datetime | ||||
import itertools | import itertools | ||||
import operator | import operator | ||||
▲ Show 20 Lines • Show All 42 Lines • ▼ Show 20 Lines | |||||
) | ) | ||||
from swh.storage.interface import ( | from swh.storage.interface import ( | ||||
VISIT_STATUSES, | VISIT_STATUSES, | ||||
ListOrder, | ListOrder, | ||||
PagedResult, | PagedResult, | ||||
PartialBranches, | PartialBranches, | ||||
Sha1, | Sha1, | ||||
) | ) | ||||
from swh.storage.metrics import process_metrics, send_metric, timed | |||||
from swh.storage.objstorage import ObjStorage | from swh.storage.objstorage import ObjStorage | ||||
from swh.storage.utils import map_optional, now | from swh.storage.utils import map_optional, now | ||||
from swh.storage.writer import JournalWriter | from swh.storage.writer import JournalWriter | ||||
from . import converters | from . import converters | ||||
from ..exc import HashCollision, StorageArgumentException | from ..exc import HashCollision, StorageArgumentException | ||||
from ..utils import remove_keys | from ..utils import remove_keys | ||||
from .common import TOKEN_BEGIN, TOKEN_END, hash_url | from .common import TOKEN_BEGIN, TOKEN_END, hash_url | ||||
▲ Show 20 Lines • Show All 63 Lines • ▼ Show 20 Lines | ): | ||||
self._allow_overwrite = allow_overwrite | self._allow_overwrite = allow_overwrite | ||||
def _set_cql_runner(self): | def _set_cql_runner(self): | ||||
"""Used by tests when they need to reset the CqlRunner""" | """Used by tests when they need to reset the CqlRunner""" | ||||
self._cql_runner: CqlRunner = CqlRunner( | self._cql_runner: CqlRunner = CqlRunner( | ||||
self._hosts, self._keyspace, self._port, self._consistency_level | self._hosts, self._keyspace, self._port, self._consistency_level | ||||
) | ) | ||||
@timed | |||||
def check_config(self, *, check_write: bool) -> bool: | def check_config(self, *, check_write: bool) -> bool: | ||||
self._cql_runner.check_read() | self._cql_runner.check_read() | ||||
return True | return True | ||||
def _content_get_from_hash(self, algo, hash_) -> Iterable: | def _content_get_from_hash(self, algo, hash_) -> Iterable: | ||||
"""From the name of a hash algorithm and a value of that hash, | """From the name of a hash algorithm and a value of that hash, | ||||
looks up the "hash -> token" secondary table (content_by_{algo}) | looks up the "hash -> token" secondary table (content_by_{algo}) | ||||
▲ Show 20 Lines • Show All 86 Lines • ▼ Show 20 Lines | def _content_add(self, contents: List[Content], with_data: bool) -> Dict[str, int]: | ||||
"content:add": content_add, | "content:add": content_add, | ||||
} | } | ||||
if with_data: | if with_data: | ||||
summary["content:add:bytes"] = content_add_bytes | summary["content:add:bytes"] = content_add_bytes | ||||
return summary | return summary | ||||
@timed | |||||
@process_metrics | |||||
def content_add(self, content: List[Content]) -> Dict[str, int]: | def content_add(self, content: List[Content]) -> Dict[str, int]: | ||||
to_add = { | to_add = { | ||||
(c.sha1, c.sha1_git, c.sha256, c.blake2s256): c for c in content | (c.sha1, c.sha1_git, c.sha256, c.blake2s256): c for c in content | ||||
}.values() | }.values() | ||||
contents = [attr.evolve(c, ctime=now()) for c in to_add] | contents = [attr.evolve(c, ctime=now()) for c in to_add] | ||||
return self._content_add(list(contents), with_data=True) | return self._content_add(list(contents), with_data=True) | ||||
@timed | |||||
def content_update( | def content_update( | ||||
self, contents: List[Dict[str, Any]], keys: List[str] = [] | self, contents: List[Dict[str, Any]], keys: List[str] = [] | ||||
) -> None: | ) -> None: | ||||
raise NotImplementedError( | raise NotImplementedError( | ||||
"content_update is not supported by the Cassandra backend" | "content_update is not supported by the Cassandra backend" | ||||
) | ) | ||||
@timed | |||||
@process_metrics | |||||
def content_add_metadata(self, content: List[Content]) -> Dict[str, int]: | def content_add_metadata(self, content: List[Content]) -> Dict[str, int]: | ||||
return self._content_add(content, with_data=False) | return self._content_add(content, with_data=False) | ||||
@timed | |||||
def content_get_data(self, content: Sha1) -> Optional[bytes]: | def content_get_data(self, content: Sha1) -> Optional[bytes]: | ||||
# FIXME: Make this method support slicing the `data` | # FIXME: Make this method support slicing the `data` | ||||
return self.objstorage.content_get(content) | return self.objstorage.content_get(content) | ||||
@timed | |||||
def content_get_partition( | def content_get_partition( | ||||
self, | self, | ||||
partition_id: int, | partition_id: int, | ||||
nb_partitions: int, | nb_partitions: int, | ||||
page_token: Optional[str] = None, | page_token: Optional[str] = None, | ||||
limit: int = 1000, | limit: int = 1000, | ||||
) -> PagedResult[Content]: | ) -> PagedResult[Content]: | ||||
if limit is None: | if limit is None: | ||||
Show All 25 Lines | ) -> PagedResult[Content]: | ||||
next_page_token = str(tok) | next_page_token = str(tok) | ||||
break | break | ||||
row_d.pop("ctime") | row_d.pop("ctime") | ||||
contents.append(Content(**row_d)) | contents.append(Content(**row_d)) | ||||
assert len(contents) <= limit | assert len(contents) <= limit | ||||
return PagedResult(results=contents, next_page_token=next_page_token) | return PagedResult(results=contents, next_page_token=next_page_token) | ||||
@timed | |||||
def content_get( | def content_get( | ||||
self, contents: List[bytes], algo: str = "sha1" | self, contents: List[bytes], algo: str = "sha1" | ||||
) -> List[Optional[Content]]: | ) -> List[Optional[Content]]: | ||||
if algo not in DEFAULT_ALGORITHMS: | if algo not in DEFAULT_ALGORITHMS: | ||||
raise StorageArgumentException( | raise StorageArgumentException( | ||||
"algo should be one of {','.join(DEFAULT_ALGORITHMS)}" | "algo should be one of {','.join(DEFAULT_ALGORITHMS)}" | ||||
) | ) | ||||
key = operator.attrgetter(algo) | key = operator.attrgetter(algo) | ||||
contents_by_hash: Dict[Sha1, Optional[Content]] = {} | contents_by_hash: Dict[Sha1, Optional[Content]] = {} | ||||
for hash_ in contents: | for hash_ in contents: | ||||
# Get all (sha1, sha1_git, sha256, blake2s256) whose sha1/sha1_git | # Get all (sha1, sha1_git, sha256, blake2s256) whose sha1/sha1_git | ||||
# matches the argument, from the index table ('content_by_*') | # matches the argument, from the index table ('content_by_*') | ||||
for row in self._content_get_from_hash(algo, hash_): | for row in self._content_get_from_hash(algo, hash_): | ||||
row_d = row.to_dict() | row_d = row.to_dict() | ||||
row_d.pop("ctime") | row_d.pop("ctime") | ||||
content = Content(**row_d) | content = Content(**row_d) | ||||
contents_by_hash[key(content)] = content | contents_by_hash[key(content)] = content | ||||
return [contents_by_hash.get(hash_) for hash_ in contents] | return [contents_by_hash.get(hash_) for hash_ in contents] | ||||
@timed | |||||
def content_find(self, content: Dict[str, Any]) -> List[Content]: | def content_find(self, content: Dict[str, Any]) -> List[Content]: | ||||
# Find an algorithm that is common to all the requested contents. | # Find an algorithm that is common to all the requested contents. | ||||
# It will be used to do an initial filtering efficiently. | # It will be used to do an initial filtering efficiently. | ||||
filter_algos = list(set(content).intersection(HASH_ALGORITHMS)) | filter_algos = list(set(content).intersection(HASH_ALGORITHMS)) | ||||
if not filter_algos: | if not filter_algos: | ||||
raise StorageArgumentException( | raise StorageArgumentException( | ||||
"content keys must contain at least one " | "content keys must contain at least one " | ||||
f"of: {', '.join(sorted(HASH_ALGORITHMS))}" | f"of: {', '.join(sorted(HASH_ALGORITHMS))}" | ||||
Show All 11 Lines | def content_find(self, content: Dict[str, Any]) -> List[Content]: | ||||
break | break | ||||
else: | else: | ||||
# All hashes match, keep this row. | # All hashes match, keep this row. | ||||
row_d = row.to_dict() | row_d = row.to_dict() | ||||
row_d["ctime"] = row.ctime.replace(tzinfo=datetime.timezone.utc) | row_d["ctime"] = row.ctime.replace(tzinfo=datetime.timezone.utc) | ||||
results.append(Content(**row_d)) | results.append(Content(**row_d)) | ||||
return results | return results | ||||
@timed | |||||
def content_missing( | def content_missing( | ||||
self, contents: List[Dict[str, Any]], key_hash: str = "sha1" | self, contents: List[Dict[str, Any]], key_hash: str = "sha1" | ||||
) -> Iterable[bytes]: | ) -> Iterable[bytes]: | ||||
if key_hash not in DEFAULT_ALGORITHMS: | if key_hash not in DEFAULT_ALGORITHMS: | ||||
raise StorageArgumentException( | raise StorageArgumentException( | ||||
"key_hash should be one of {','.join(DEFAULT_ALGORITHMS)}" | "key_hash should be one of {','.join(DEFAULT_ALGORITHMS)}" | ||||
) | ) | ||||
Show All 12 Lines | ) -> Iterable[bytes]: | ||||
yield content[key_hash] | yield content[key_hash] | ||||
# For these, we need the expensive index lookups + main table. | # For these, we need the expensive index lookups + main table. | ||||
for content in contents_with_missing_hashes: | for content in contents_with_missing_hashes: | ||||
res = self.content_find(content) | res = self.content_find(content) | ||||
if not res: | if not res: | ||||
yield content[key_hash] | yield content[key_hash] | ||||
@timed | |||||
def content_missing_per_sha1(self, contents: List[bytes]) -> Iterable[bytes]: | def content_missing_per_sha1(self, contents: List[bytes]) -> Iterable[bytes]: | ||||
return self.content_missing([{"sha1": c} for c in contents]) | return self.content_missing([{"sha1": c} for c in contents]) | ||||
@timed | |||||
def content_missing_per_sha1_git( | def content_missing_per_sha1_git( | ||||
self, contents: List[Sha1Git] | self, contents: List[Sha1Git] | ||||
) -> Iterable[Sha1Git]: | ) -> Iterable[Sha1Git]: | ||||
return self.content_missing( | return self.content_missing( | ||||
[{"sha1_git": c} for c in contents], key_hash="sha1_git" | [{"sha1_git": c} for c in contents], key_hash="sha1_git" | ||||
) | ) | ||||
@timed | |||||
def content_get_random(self) -> Sha1Git: | def content_get_random(self) -> Sha1Git: | ||||
content = self._cql_runner.content_get_random() | content = self._cql_runner.content_get_random() | ||||
assert content, "Could not find any content" | assert content, "Could not find any content" | ||||
return content.sha1_git | return content.sha1_git | ||||
def _skipped_content_add(self, contents: List[SkippedContent]) -> Dict[str, int]: | def _skipped_content_add(self, contents: List[SkippedContent]) -> Dict[str, int]: | ||||
# Filter-out content already in the database. | # Filter-out content already in the database. | ||||
if not self._allow_overwrite: | if not self._allow_overwrite: | ||||
Show All 15 Lines | def _skipped_content_add(self, contents: List[SkippedContent]) -> Dict[str, int]: | ||||
for algo in HASH_ALGORITHMS: | for algo in HASH_ALGORITHMS: | ||||
self._cql_runner.skipped_content_index_add_one(algo, content, token) | self._cql_runner.skipped_content_index_add_one(algo, content, token) | ||||
# Then to the main table | # Then to the main table | ||||
insertion_finalizer() | insertion_finalizer() | ||||
return {"skipped_content:add": len(contents)} | return {"skipped_content:add": len(contents)} | ||||
@timed | |||||
@process_metrics | |||||
def skipped_content_add(self, content: List[SkippedContent]) -> Dict[str, int]: | def skipped_content_add(self, content: List[SkippedContent]) -> Dict[str, int]: | ||||
contents = [attr.evolve(c, ctime=now()) for c in content] | contents = [attr.evolve(c, ctime=now()) for c in content] | ||||
return self._skipped_content_add(contents) | return self._skipped_content_add(contents) | ||||
@timed | |||||
def skipped_content_missing( | def skipped_content_missing( | ||||
self, contents: List[Dict[str, Any]] | self, contents: List[Dict[str, Any]] | ||||
) -> Iterable[Dict[str, Any]]: | ) -> Iterable[Dict[str, Any]]: | ||||
for content in contents: | for content in contents: | ||||
if not self._cql_runner.skipped_content_get_from_pk(content): | if not self._cql_runner.skipped_content_get_from_pk(content): | ||||
yield {algo: content[algo] for algo in DEFAULT_ALGORITHMS} | yield {algo: content[algo] for algo in DEFAULT_ALGORITHMS} | ||||
@timed | |||||
@process_metrics | |||||
def directory_add(self, directories: List[Directory]) -> Dict[str, int]: | def directory_add(self, directories: List[Directory]) -> Dict[str, int]: | ||||
to_add = {d.id: d for d in directories}.values() | to_add = {d.id: d for d in directories}.values() | ||||
if not self._allow_overwrite: | if not self._allow_overwrite: | ||||
# Filter out directories that are already inserted. | # Filter out directories that are already inserted. | ||||
missing = self.directory_missing([dir_.id for dir_ in to_add]) | missing = self.directory_missing([dir_.id for dir_ in to_add]) | ||||
directories = [dir_ for dir_ in directories if dir_.id in missing] | directories = [dir_ for dir_ in directories if dir_.id in missing] | ||||
self.journal_writer.directory_add(directories) | self.journal_writer.directory_add(directories) | ||||
for directory in directories: | for directory in directories: | ||||
# Add directory entries to the 'directory_entry' table | # Add directory entries to the 'directory_entry' table | ||||
for entry in directory.entries: | for entry in directory.entries: | ||||
self._cql_runner.directory_entry_add_one( | self._cql_runner.directory_entry_add_one( | ||||
DirectoryEntryRow(directory_id=directory.id, **entry.to_dict()) | DirectoryEntryRow(directory_id=directory.id, **entry.to_dict()) | ||||
) | ) | ||||
# Add the directory *after* adding all the entries, so someone | # Add the directory *after* adding all the entries, so someone | ||||
# calling snapshot_get_branch in the meantime won't end up | # calling snapshot_get_branch in the meantime won't end up | ||||
# with half the entries. | # with half the entries. | ||||
self._cql_runner.directory_add_one(DirectoryRow(id=directory.id)) | self._cql_runner.directory_add_one(DirectoryRow(id=directory.id)) | ||||
return {"directory:add": len(directories)} | return {"directory:add": len(directories)} | ||||
@timed | |||||
def directory_missing(self, directories: List[Sha1Git]) -> Iterable[Sha1Git]: | def directory_missing(self, directories: List[Sha1Git]) -> Iterable[Sha1Git]: | ||||
return self._cql_runner.directory_missing(directories) | return self._cql_runner.directory_missing(directories) | ||||
def _join_dentry_to_content(self, dentry: DirectoryEntry) -> Dict[str, Any]: | def _join_dentry_to_content(self, dentry: DirectoryEntry) -> Dict[str, Any]: | ||||
contents: Union[List[Content], List[SkippedContentRow]] | contents: Union[List[Content], List[SkippedContentRow]] | ||||
keys = ( | keys = ( | ||||
"status", | "status", | ||||
"sha1", | "sha1", | ||||
Show All 38 Lines | ) -> Iterable[Dict[str, Any]]: | ||||
ret["dir_id"] = directory_id | ret["dir_id"] = directory_id | ||||
yield ret | yield ret | ||||
if recursive and ret["type"] == "dir": | if recursive and ret["type"] == "dir": | ||||
yield from self._directory_ls( | yield from self._directory_ls( | ||||
ret["target"], True, prefix + ret["name"] + b"/" | ret["target"], True, prefix + ret["name"] + b"/" | ||||
) | ) | ||||
@timed | |||||
def directory_entry_get_by_path( | def directory_entry_get_by_path( | ||||
self, directory: Sha1Git, paths: List[bytes] | self, directory: Sha1Git, paths: List[bytes] | ||||
) -> Optional[Dict[str, Any]]: | ) -> Optional[Dict[str, Any]]: | ||||
return self._directory_entry_get_by_path(directory, paths, b"") | return self._directory_entry_get_by_path(directory, paths, b"") | ||||
def _directory_entry_get_by_path( | def _directory_entry_get_by_path( | ||||
self, directory: Sha1Git, paths: List[bytes], prefix: bytes | self, directory: Sha1Git, paths: List[bytes], prefix: bytes | ||||
) -> Optional[Dict[str, Any]]: | ) -> Optional[Dict[str, Any]]: | ||||
Show All 23 Lines | ) -> Optional[Dict[str, Any]]: | ||||
if not first_item or first_item["type"] != "dir": | if not first_item or first_item["type"] != "dir": | ||||
return None | return None | ||||
return self._directory_entry_get_by_path( | return self._directory_entry_get_by_path( | ||||
first_item["target"], paths[1:], prefix + paths[0] + b"/" | first_item["target"], paths[1:], prefix + paths[0] + b"/" | ||||
) | ) | ||||
@timed | |||||
def directory_ls( | def directory_ls( | ||||
self, directory: Sha1Git, recursive: bool = False | self, directory: Sha1Git, recursive: bool = False | ||||
) -> Iterable[Dict[str, Any]]: | ) -> Iterable[Dict[str, Any]]: | ||||
yield from self._directory_ls(directory, recursive) | yield from self._directory_ls(directory, recursive) | ||||
@timed | |||||
def directory_get_entries( | def directory_get_entries( | ||||
self, | self, | ||||
directory_id: Sha1Git, | directory_id: Sha1Git, | ||||
page_token: Optional[bytes] = None, | page_token: Optional[bytes] = None, | ||||
limit: int = 1000, | limit: int = 1000, | ||||
) -> Optional[PagedResult[DirectoryEntry]]: | ) -> Optional[PagedResult[DirectoryEntry]]: | ||||
if self.directory_missing([directory_id]): | if self.directory_missing([directory_id]): | ||||
return None | return None | ||||
entries_from: bytes = page_token or b"" | entries_from: bytes = page_token or b"" | ||||
rows = self._cql_runner.directory_entry_get_from_name( | rows = self._cql_runner.directory_entry_get_from_name( | ||||
directory_id, entries_from, limit + 1 | directory_id, entries_from, limit + 1 | ||||
) | ) | ||||
entries = [ | entries = [ | ||||
DirectoryEntry.from_dict(remove_keys(row.to_dict(), ("directory_id",))) | DirectoryEntry.from_dict(remove_keys(row.to_dict(), ("directory_id",))) | ||||
for row in rows | for row in rows | ||||
] | ] | ||||
if len(entries) > limit: | if len(entries) > limit: | ||||
last_entry = entries.pop() | last_entry = entries.pop() | ||||
next_page_token = last_entry.name | next_page_token = last_entry.name | ||||
else: | else: | ||||
next_page_token = None | next_page_token = None | ||||
return PagedResult(results=entries, next_page_token=next_page_token) | return PagedResult(results=entries, next_page_token=next_page_token) | ||||
@timed | |||||
def directory_get_random(self) -> Sha1Git: | def directory_get_random(self) -> Sha1Git: | ||||
directory = self._cql_runner.directory_get_random() | directory = self._cql_runner.directory_get_random() | ||||
assert directory, "Could not find any directory" | assert directory, "Could not find any directory" | ||||
return directory.id | return directory.id | ||||
@timed | |||||
@process_metrics | |||||
def revision_add(self, revisions: List[Revision]) -> Dict[str, int]: | def revision_add(self, revisions: List[Revision]) -> Dict[str, int]: | ||||
# Filter-out revisions already in the database | # Filter-out revisions already in the database | ||||
if not self._allow_overwrite: | if not self._allow_overwrite: | ||||
to_add = {r.id: r for r in revisions}.values() | to_add = {r.id: r for r in revisions}.values() | ||||
missing = self.revision_missing([rev.id for rev in to_add]) | missing = self.revision_missing([rev.id for rev in to_add]) | ||||
revisions = [rev for rev in revisions if rev.id in missing] | revisions = [rev for rev in revisions if rev.id in missing] | ||||
self.journal_writer.revision_add(revisions) | self.journal_writer.revision_add(revisions) | ||||
Show All 11 Lines | def revision_add(self, revisions: List[Revision]) -> Dict[str, int]: | ||||
# Then write the main revision row. | # Then write the main revision row. | ||||
# Writing this after all parents were written ensures that | # Writing this after all parents were written ensures that | ||||
# read endpoints don't return a partial view while writing | # read endpoints don't return a partial view while writing | ||||
# the parents | # the parents | ||||
self._cql_runner.revision_add_one(revobject) | self._cql_runner.revision_add_one(revobject) | ||||
return {"revision:add": len(revisions)} | return {"revision:add": len(revisions)} | ||||
@timed | |||||
def revision_missing(self, revisions: List[Sha1Git]) -> Iterable[Sha1Git]: | def revision_missing(self, revisions: List[Sha1Git]) -> Iterable[Sha1Git]: | ||||
return self._cql_runner.revision_missing(revisions) | return self._cql_runner.revision_missing(revisions) | ||||
@timed | |||||
def revision_get(self, revision_ids: List[Sha1Git]) -> List[Optional[Revision]]: | def revision_get(self, revision_ids: List[Sha1Git]) -> List[Optional[Revision]]: | ||||
rows = self._cql_runner.revision_get(revision_ids) | rows = self._cql_runner.revision_get(revision_ids) | ||||
revisions: Dict[Sha1Git, Revision] = {} | revisions: Dict[Sha1Git, Revision] = {} | ||||
for row in rows: | for row in rows: | ||||
# TODO: use a single query to get all parents? | # TODO: use a single query to get all parents? | ||||
# (it might have lower latency, but requires more code and more | # (it might have lower latency, but requires more code and more | ||||
# bandwidth, because revision id would be part of each returned | # bandwidth, because revision id would be part of each returned | ||||
# row) | # row) | ||||
▲ Show 20 Lines • Show All 50 Lines • ▼ Show 20 Lines | ]: | ||||
# parent_rank is the clustering key, so results are already | # parent_rank is the clustering key, so results are already | ||||
# sorted by rank. | # sorted by rank. | ||||
rev = converters.revision_from_db(row, parents=parents) | rev = converters.revision_from_db(row, parents=parents) | ||||
yield rev.to_dict() | yield rev.to_dict() | ||||
yield from self._get_parent_revs(parents, seen, limit, short) | yield from self._get_parent_revs(parents, seen, limit, short) | ||||
@timed | |||||
def revision_log( | def revision_log( | ||||
self, revisions: List[Sha1Git], limit: Optional[int] = None | self, revisions: List[Sha1Git], limit: Optional[int] = None | ||||
) -> Iterable[Optional[Dict[str, Any]]]: | ) -> Iterable[Optional[Dict[str, Any]]]: | ||||
seen: Set[Sha1Git] = set() | seen: Set[Sha1Git] = set() | ||||
yield from self._get_parent_revs(revisions, seen, limit, False) | yield from self._get_parent_revs(revisions, seen, limit, False) | ||||
@timed | |||||
def revision_shortlog( | def revision_shortlog( | ||||
self, revisions: List[Sha1Git], limit: Optional[int] = None | self, revisions: List[Sha1Git], limit: Optional[int] = None | ||||
) -> Iterable[Optional[Tuple[Sha1Git, Tuple[Sha1Git, ...]]]]: | ) -> Iterable[Optional[Tuple[Sha1Git, Tuple[Sha1Git, ...]]]]: | ||||
seen: Set[Sha1Git] = set() | seen: Set[Sha1Git] = set() | ||||
yield from self._get_parent_revs(revisions, seen, limit, True) | yield from self._get_parent_revs(revisions, seen, limit, True) | ||||
@timed | |||||
def revision_get_random(self) -> Sha1Git: | def revision_get_random(self) -> Sha1Git: | ||||
revision = self._cql_runner.revision_get_random() | revision = self._cql_runner.revision_get_random() | ||||
assert revision, "Could not find any revision" | assert revision, "Could not find any revision" | ||||
return revision.id | return revision.id | ||||
@timed | |||||
@process_metrics | |||||
def release_add(self, releases: List[Release]) -> Dict[str, int]: | def release_add(self, releases: List[Release]) -> Dict[str, int]: | ||||
if not self._allow_overwrite: | if not self._allow_overwrite: | ||||
to_add = {r.id: r for r in releases}.values() | to_add = {r.id: r for r in releases}.values() | ||||
missing = set(self.release_missing([rel.id for rel in to_add])) | missing = set(self.release_missing([rel.id for rel in to_add])) | ||||
releases = [rel for rel in to_add if rel.id in missing] | releases = [rel for rel in to_add if rel.id in missing] | ||||
self.journal_writer.release_add(releases) | self.journal_writer.release_add(releases) | ||||
for release in releases: | for release in releases: | ||||
if release: | if release: | ||||
self._cql_runner.release_add_one(converters.release_to_db(release)) | self._cql_runner.release_add_one(converters.release_to_db(release)) | ||||
return {"release:add": len(releases)} | return {"release:add": len(releases)} | ||||
@timed | |||||
def release_missing(self, releases: List[Sha1Git]) -> Iterable[Sha1Git]: | def release_missing(self, releases: List[Sha1Git]) -> Iterable[Sha1Git]: | ||||
return self._cql_runner.release_missing(releases) | return self._cql_runner.release_missing(releases) | ||||
@timed | |||||
def release_get(self, releases: List[Sha1Git]) -> List[Optional[Release]]: | def release_get(self, releases: List[Sha1Git]) -> List[Optional[Release]]: | ||||
rows = self._cql_runner.release_get(releases) | rows = self._cql_runner.release_get(releases) | ||||
rels: Dict[Sha1Git, Release] = {} | rels: Dict[Sha1Git, Release] = {} | ||||
for row in rows: | for row in rows: | ||||
release = converters.release_from_db(row) | release = converters.release_from_db(row) | ||||
rels[row.id] = release | rels[row.id] = release | ||||
return [rels.get(rel_id) for rel_id in releases] | return [rels.get(rel_id) for rel_id in releases] | ||||
@timed | |||||
def release_get_random(self) -> Sha1Git: | def release_get_random(self) -> Sha1Git: | ||||
release = self._cql_runner.release_get_random() | release = self._cql_runner.release_get_random() | ||||
assert release, "Could not find any release" | assert release, "Could not find any release" | ||||
return release.id | return release.id | ||||
@timed | |||||
@process_metrics | |||||
def snapshot_add(self, snapshots: List[Snapshot]) -> Dict[str, int]: | def snapshot_add(self, snapshots: List[Snapshot]) -> Dict[str, int]: | ||||
if not self._allow_overwrite: | if not self._allow_overwrite: | ||||
to_add = {s.id: s for s in snapshots}.values() | to_add = {s.id: s for s in snapshots}.values() | ||||
missing = self._cql_runner.snapshot_missing([snp.id for snp in to_add]) | missing = self._cql_runner.snapshot_missing([snp.id for snp in to_add]) | ||||
snapshots = [snp for snp in snapshots if snp.id in missing] | snapshots = [snp for snp in snapshots if snp.id in missing] | ||||
for snapshot in snapshots: | for snapshot in snapshots: | ||||
self.journal_writer.snapshot_add([snapshot]) | self.journal_writer.snapshot_add([snapshot]) | ||||
Show All 17 Lines | def snapshot_add(self, snapshots: List[Snapshot]) -> Dict[str, int]: | ||||
# Add the snapshot *after* adding all the branches, so someone | # Add the snapshot *after* adding all the branches, so someone | ||||
# calling snapshot_get_branch in the meantime won't end up | # calling snapshot_get_branch in the meantime won't end up | ||||
# with half the branches. | # with half the branches. | ||||
self._cql_runner.snapshot_add_one(SnapshotRow(id=snapshot.id)) | self._cql_runner.snapshot_add_one(SnapshotRow(id=snapshot.id)) | ||||
return {"snapshot:add": len(snapshots)} | return {"snapshot:add": len(snapshots)} | ||||
@timed | |||||
def snapshot_missing(self, snapshots: List[Sha1Git]) -> Iterable[Sha1Git]: | def snapshot_missing(self, snapshots: List[Sha1Git]) -> Iterable[Sha1Git]: | ||||
return self._cql_runner.snapshot_missing(snapshots) | return self._cql_runner.snapshot_missing(snapshots) | ||||
@timed | |||||
def snapshot_get(self, snapshot_id: Sha1Git) -> Optional[Dict[str, Any]]: | def snapshot_get(self, snapshot_id: Sha1Git) -> Optional[Dict[str, Any]]: | ||||
d = self.snapshot_get_branches(snapshot_id) | d = self.snapshot_get_branches(snapshot_id) | ||||
if d is None: | if d is None: | ||||
return None | return None | ||||
return { | return { | ||||
"id": d["id"], | "id": d["id"], | ||||
"branches": { | "branches": { | ||||
name: branch.to_dict() if branch else None | name: branch.to_dict() if branch else None | ||||
for (name, branch) in d["branches"].items() | for (name, branch) in d["branches"].items() | ||||
}, | }, | ||||
"next_branch": d["next_branch"], | "next_branch": d["next_branch"], | ||||
} | } | ||||
@timed | |||||
def snapshot_count_branches( | def snapshot_count_branches( | ||||
self, snapshot_id: Sha1Git, branch_name_exclude_prefix: Optional[bytes] = None, | self, snapshot_id: Sha1Git, branch_name_exclude_prefix: Optional[bytes] = None, | ||||
) -> Optional[Dict[Optional[str], int]]: | ) -> Optional[Dict[Optional[str], int]]: | ||||
if self._cql_runner.snapshot_missing([snapshot_id]): | if self._cql_runner.snapshot_missing([snapshot_id]): | ||||
# Makes sure we don't fetch branches for a snapshot that is | # Makes sure we don't fetch branches for a snapshot that is | ||||
# being added. | # being added. | ||||
return None | return None | ||||
return self._cql_runner.snapshot_count_branches( | return self._cql_runner.snapshot_count_branches( | ||||
snapshot_id, branch_name_exclude_prefix | snapshot_id, branch_name_exclude_prefix | ||||
) | ) | ||||
@timed | |||||
def snapshot_get_branches( | def snapshot_get_branches( | ||||
self, | self, | ||||
snapshot_id: Sha1Git, | snapshot_id: Sha1Git, | ||||
branches_from: bytes = b"", | branches_from: bytes = b"", | ||||
branches_count: int = 1000, | branches_count: int = 1000, | ||||
target_types: Optional[List[str]] = None, | target_types: Optional[List[str]] = None, | ||||
branch_name_include_substring: Optional[bytes] = None, | branch_name_include_substring: Optional[bytes] = None, | ||||
branch_name_exclude_prefix: Optional[bytes] = None, | branch_name_exclude_prefix: Optional[bytes] = None, | ||||
▲ Show 20 Lines • Show All 59 Lines • ▼ Show 20 Lines | ) -> Optional[PartialBranches]: | ||||
else SnapshotBranch( | else SnapshotBranch( | ||||
target=branch.target, target_type=TargetType(branch.target_type) | target=branch.target, target_type=TargetType(branch.target_type) | ||||
) | ) | ||||
for branch in branches | for branch in branches | ||||
}, | }, | ||||
next_branch=last_branch, | next_branch=last_branch, | ||||
) | ) | ||||
@timed | |||||
def snapshot_get_random(self) -> Sha1Git: | def snapshot_get_random(self) -> Sha1Git: | ||||
snapshot = self._cql_runner.snapshot_get_random() | snapshot = self._cql_runner.snapshot_get_random() | ||||
assert snapshot, "Could not find any snapshot" | assert snapshot, "Could not find any snapshot" | ||||
return snapshot.id | return snapshot.id | ||||
@timed | |||||
def object_find_by_sha1_git(self, ids: List[Sha1Git]) -> Dict[Sha1Git, List[Dict]]: | def object_find_by_sha1_git(self, ids: List[Sha1Git]) -> Dict[Sha1Git, List[Dict]]: | ||||
results: Dict[Sha1Git, List[Dict]] = {id_: [] for id_ in ids} | results: Dict[Sha1Git, List[Dict]] = {id_: [] for id_ in ids} | ||||
missing_ids = set(ids) | missing_ids = set(ids) | ||||
# Mind the order, revision is the most likely one for a given ID, | # Mind the order, revision is the most likely one for a given ID, | ||||
# so we check revisions first. | # so we check revisions first. | ||||
queries: List[Tuple[str, Callable[[List[Sha1Git]], List[Sha1Git]]]] = [ | queries: List[Tuple[str, Callable[[List[Sha1Git]], List[Sha1Git]]]] = [ | ||||
("revision", self._cql_runner.revision_missing), | ("revision", self._cql_runner.revision_missing), | ||||
Show All 11 Lines | def object_find_by_sha1_git(self, ids: List[Sha1Git]) -> Dict[Sha1Git, List[Dict]]: | ||||
missing_ids.remove(sha1_git) | missing_ids.remove(sha1_git) | ||||
if not missing_ids: | if not missing_ids: | ||||
# We found everything, skipping the next queries. | # We found everything, skipping the next queries. | ||||
break | break | ||||
return results | return results | ||||
@timed | |||||
def origin_get(self, origins: List[str]) -> Iterable[Optional[Origin]]: | def origin_get(self, origins: List[str]) -> Iterable[Optional[Origin]]: | ||||
return [self.origin_get_one(origin) for origin in origins] | return [self.origin_get_one(origin) for origin in origins] | ||||
@timed | |||||
def origin_get_one(self, origin_url: str) -> Optional[Origin]: | def origin_get_one(self, origin_url: str) -> Optional[Origin]: | ||||
"""Given an origin url, return the origin if it exists, None otherwise | """Given an origin url, return the origin if it exists, None otherwise | ||||
""" | """ | ||||
rows = list(self._cql_runner.origin_get_by_url(origin_url)) | rows = list(self._cql_runner.origin_get_by_url(origin_url)) | ||||
if rows: | if rows: | ||||
assert len(rows) == 1 | assert len(rows) == 1 | ||||
return Origin(url=rows[0].url) | return Origin(url=rows[0].url) | ||||
else: | else: | ||||
return None | return None | ||||
@timed | |||||
def origin_get_by_sha1(self, sha1s: List[bytes]) -> List[Optional[Dict[str, Any]]]: | def origin_get_by_sha1(self, sha1s: List[bytes]) -> List[Optional[Dict[str, Any]]]: | ||||
results = [] | results = [] | ||||
for sha1 in sha1s: | for sha1 in sha1s: | ||||
rows = list(self._cql_runner.origin_get_by_sha1(sha1)) | rows = list(self._cql_runner.origin_get_by_sha1(sha1)) | ||||
origin = {"url": rows[0].url} if rows else None | origin = {"url": rows[0].url} if rows else None | ||||
results.append(origin) | results.append(origin) | ||||
return results | return results | ||||
@timed | |||||
def origin_list( | def origin_list( | ||||
self, page_token: Optional[str] = None, limit: int = 100 | self, page_token: Optional[str] = None, limit: int = 100 | ||||
) -> PagedResult[Origin]: | ) -> PagedResult[Origin]: | ||||
# Compute what token to begin the listing from | # Compute what token to begin the listing from | ||||
start_token = TOKEN_BEGIN | start_token = TOKEN_BEGIN | ||||
if page_token: | if page_token: | ||||
start_token = int(page_token) | start_token = int(page_token) | ||||
if not (TOKEN_BEGIN <= start_token <= TOKEN_END): | if not (TOKEN_BEGIN <= start_token <= TOKEN_END): | ||||
Show All 12 Lines | ) -> PagedResult[Origin]: | ||||
next_page_token = str(last_id) | next_page_token = str(last_id) | ||||
# excluding that origin from the result to respect the limit size | # excluding that origin from the result to respect the limit size | ||||
origins = origins[:limit] | origins = origins[:limit] | ||||
assert len(origins) <= limit | assert len(origins) <= limit | ||||
return PagedResult(results=origins, next_page_token=next_page_token) | return PagedResult(results=origins, next_page_token=next_page_token) | ||||
@timed | |||||
def origin_search( | def origin_search( | ||||
self, | self, | ||||
url_pattern: str, | url_pattern: str, | ||||
page_token: Optional[str] = None, | page_token: Optional[str] = None, | ||||
limit: int = 50, | limit: int = 50, | ||||
regexp: bool = False, | regexp: bool = False, | ||||
with_visit: bool = False, | with_visit: bool = False, | ||||
visit_types: Optional[List[str]] = None, | visit_types: Optional[List[str]] = None, | ||||
Show All 31 Lines | ) -> PagedResult[Origin]: | ||||
# next offset | # next offset | ||||
next_page_token = str(offset + limit) | next_page_token = str(offset + limit) | ||||
# excluding that origin from the result to respect the limit size | # excluding that origin from the result to respect the limit size | ||||
origins = origins[:limit] | origins = origins[:limit] | ||||
assert len(origins) <= limit | assert len(origins) <= limit | ||||
return PagedResult(results=origins, next_page_token=next_page_token) | return PagedResult(results=origins, next_page_token=next_page_token) | ||||
@timed | |||||
def origin_count( | def origin_count( | ||||
self, url_pattern: str, regexp: bool = False, with_visit: bool = False | self, url_pattern: str, regexp: bool = False, with_visit: bool = False | ||||
) -> int: | ) -> int: | ||||
raise NotImplementedError( | raise NotImplementedError( | ||||
"The Cassandra backend does not implement origin_count" | "The Cassandra backend does not implement origin_count" | ||||
) | ) | ||||
@timed | |||||
@process_metrics | |||||
def origin_add(self, origins: List[Origin]) -> Dict[str, int]: | def origin_add(self, origins: List[Origin]) -> Dict[str, int]: | ||||
if not self._allow_overwrite: | if not self._allow_overwrite: | ||||
to_add = {o.url: o for o in origins}.values() | to_add = {o.url: o for o in origins}.values() | ||||
origins = [ori for ori in to_add if self.origin_get_one(ori.url) is None] | origins = [ori for ori in to_add if self.origin_get_one(ori.url) is None] | ||||
self.journal_writer.origin_add(origins) | self.journal_writer.origin_add(origins) | ||||
for origin in origins: | for origin in origins: | ||||
self._cql_runner.origin_add_one( | self._cql_runner.origin_add_one( | ||||
OriginRow(sha1=hash_url(origin.url), url=origin.url, next_visit_id=1) | OriginRow(sha1=hash_url(origin.url), url=origin.url, next_visit_id=1) | ||||
) | ) | ||||
return {"origin:add": len(origins)} | return {"origin:add": len(origins)} | ||||
@timed | |||||
def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]: | def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]: | ||||
for visit in visits: | for visit in visits: | ||||
origin = self.origin_get_one(visit.origin) | origin = self.origin_get_one(visit.origin) | ||||
if not origin: # Cannot add a visit without an origin | if not origin: # Cannot add a visit without an origin | ||||
raise StorageArgumentException("Unknown origin %s", visit.origin) | raise StorageArgumentException("Unknown origin %s", visit.origin) | ||||
all_visits = [] | all_visits = [] | ||||
nb_visits = 0 | nb_visits = 0 | ||||
Show All 17 Lines | def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]: | ||||
origin=visit.origin, | origin=visit.origin, | ||||
visit=visit.visit, | visit=visit.visit, | ||||
date=visit.date, | date=visit.date, | ||||
type=visit.type, | type=visit.type, | ||||
status="created", | status="created", | ||||
snapshot=None, | snapshot=None, | ||||
) | ) | ||||
) | ) | ||||
send_metric("origin_visit:add", count=nb_visits, method_name="origin_visit") | |||||
return all_visits | return all_visits | ||||
def _origin_visit_status_add(self, visit_status: OriginVisitStatus) -> None: | def _origin_visit_status_add(self, visit_status: OriginVisitStatus) -> None: | ||||
"""Add an origin visit status""" | """Add an origin visit status""" | ||||
if visit_status.type is None: | if visit_status.type is None: | ||||
visit_row = self._cql_runner.origin_visit_get_one( | visit_row = self._cql_runner.origin_visit_get_one( | ||||
visit_status.origin, visit_status.visit | visit_status.origin, visit_status.visit | ||||
) | ) | ||||
if visit_row is None: | if visit_row is None: | ||||
raise StorageArgumentException( | raise StorageArgumentException( | ||||
f"Unknown origin visit {visit_status.visit} " | f"Unknown origin visit {visit_status.visit} " | ||||
f"of origin {visit_status.origin}" | f"of origin {visit_status.origin}" | ||||
) | ) | ||||
visit_status = attr.evolve(visit_status, type=visit_row.type) | visit_status = attr.evolve(visit_status, type=visit_row.type) | ||||
self.journal_writer.origin_visit_status_add([visit_status]) | self.journal_writer.origin_visit_status_add([visit_status]) | ||||
self._cql_runner.origin_visit_status_add_one( | self._cql_runner.origin_visit_status_add_one( | ||||
converters.visit_status_to_row(visit_status) | converters.visit_status_to_row(visit_status) | ||||
) | ) | ||||
@timed | |||||
@process_metrics | |||||
def origin_visit_status_add( | def origin_visit_status_add( | ||||
self, visit_statuses: List[OriginVisitStatus] | self, visit_statuses: List[OriginVisitStatus] | ||||
) -> Dict[str, int]: | ) -> Dict[str, int]: | ||||
# First round to check existence (fail early if any is ko) | # First round to check existence (fail early if any is ko) | ||||
for visit_status in visit_statuses: | for visit_status in visit_statuses: | ||||
origin_url = self.origin_get_one(visit_status.origin) | origin_url = self.origin_get_one(visit_status.origin) | ||||
if not origin_url: | if not origin_url: | ||||
raise StorageArgumentException(f"Unknown origin {visit_status.origin}") | raise StorageArgumentException(f"Unknown origin {visit_status.origin}") | ||||
Show All 37 Lines | |||||
@staticmethod | @staticmethod | ||||
def _format_origin_visit_row(visit): | def _format_origin_visit_row(visit): | ||||
return { | return { | ||||
**visit.to_dict(), | **visit.to_dict(), | ||||
"origin": visit.origin, | "origin": visit.origin, | ||||
"date": visit.date.replace(tzinfo=datetime.timezone.utc), | "date": visit.date.replace(tzinfo=datetime.timezone.utc), | ||||
} | } | ||||
@timed | |||||
def origin_visit_get( | def origin_visit_get( | ||||
self, | self, | ||||
origin: str, | origin: str, | ||||
page_token: Optional[str] = None, | page_token: Optional[str] = None, | ||||
order: ListOrder = ListOrder.ASC, | order: ListOrder = ListOrder.ASC, | ||||
limit: int = 10, | limit: int = 10, | ||||
) -> PagedResult[OriginVisit]: | ) -> PagedResult[OriginVisit]: | ||||
if not isinstance(order, ListOrder): | if not isinstance(order, ListOrder): | ||||
Show All 12 Lines | ) -> PagedResult[OriginVisit]: | ||||
assert len(visits) <= extra_limit | assert len(visits) <= extra_limit | ||||
if len(visits) == extra_limit: | if len(visits) == extra_limit: | ||||
visits = visits[:limit] | visits = visits[:limit] | ||||
next_page_token = str(visits[-1].visit) | next_page_token = str(visits[-1].visit) | ||||
return PagedResult(results=visits, next_page_token=next_page_token) | return PagedResult(results=visits, next_page_token=next_page_token) | ||||
@timed | |||||
def origin_visit_status_get( | def origin_visit_status_get( | ||||
self, | self, | ||||
origin: str, | origin: str, | ||||
visit: int, | visit: int, | ||||
page_token: Optional[str] = None, | page_token: Optional[str] = None, | ||||
order: ListOrder = ListOrder.ASC, | order: ListOrder = ListOrder.ASC, | ||||
limit: int = 10, | limit: int = 10, | ||||
) -> PagedResult[OriginVisitStatus]: | ) -> PagedResult[OriginVisitStatus]: | ||||
Show All 10 Lines | ) -> PagedResult[OriginVisitStatus]: | ||||
if len(visit_statuses) > limit: | if len(visit_statuses) > limit: | ||||
# last visit status date is the next page token | # last visit status date is the next page token | ||||
next_page_token = str(visit_statuses[-1].date) | next_page_token = str(visit_statuses[-1].date) | ||||
# excluding that visit status from the result to respect the limit size | # excluding that visit status from the result to respect the limit size | ||||
visit_statuses = visit_statuses[:limit] | visit_statuses = visit_statuses[:limit] | ||||
return PagedResult(results=visit_statuses, next_page_token=next_page_token) | return PagedResult(results=visit_statuses, next_page_token=next_page_token) | ||||
@timed | |||||
def origin_visit_find_by_date( | def origin_visit_find_by_date( | ||||
self, origin: str, visit_date: datetime.datetime | self, origin: str, visit_date: datetime.datetime | ||||
) -> Optional[OriginVisit]: | ) -> Optional[OriginVisit]: | ||||
# Iterator over all the visits of the origin | # Iterator over all the visits of the origin | ||||
# This should be ok for now, as there aren't too many visits | # This should be ok for now, as there aren't too many visits | ||||
# per origin. | # per origin. | ||||
rows = list(self._cql_runner.origin_visit_get_all(origin)) | rows = list(self._cql_runner.origin_visit_get_all(origin)) | ||||
def key(visit): | def key(visit): | ||||
dt = visit.date.replace(tzinfo=datetime.timezone.utc) - visit_date | dt = visit.date.replace(tzinfo=datetime.timezone.utc) - visit_date | ||||
return (abs(dt), -visit.visit) | return (abs(dt), -visit.visit) | ||||
if rows: | if rows: | ||||
return converters.row_to_visit(min(rows, key=key)) | return converters.row_to_visit(min(rows, key=key)) | ||||
return None | return None | ||||
@timed | |||||
def origin_visit_get_by(self, origin: str, visit: int) -> Optional[OriginVisit]: | def origin_visit_get_by(self, origin: str, visit: int) -> Optional[OriginVisit]: | ||||
row = self._cql_runner.origin_visit_get_one(origin, visit) | row = self._cql_runner.origin_visit_get_one(origin, visit) | ||||
if row: | if row: | ||||
return converters.row_to_visit(row) | return converters.row_to_visit(row) | ||||
return None | return None | ||||
@timed | |||||
def origin_visit_get_latest( | def origin_visit_get_latest( | ||||
self, | self, | ||||
origin: str, | origin: str, | ||||
type: Optional[str] = None, | type: Optional[str] = None, | ||||
allowed_statuses: Optional[List[str]] = None, | allowed_statuses: Optional[List[str]] = None, | ||||
require_snapshot: bool = False, | require_snapshot: bool = False, | ||||
) -> Optional[OriginVisit]: | ) -> Optional[OriginVisit]: | ||||
if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES): | if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES): | ||||
Show All 31 Lines | ) -> Optional[OriginVisit]: | ||||
return None | return None | ||||
return OriginVisit( | return OriginVisit( | ||||
origin=latest_visit["origin"], | origin=latest_visit["origin"], | ||||
visit=latest_visit["visit"], | visit=latest_visit["visit"], | ||||
date=latest_visit["date"], | date=latest_visit["date"], | ||||
type=latest_visit["type"], | type=latest_visit["type"], | ||||
) | ) | ||||
@timed | |||||
def origin_visit_status_get_latest( | def origin_visit_status_get_latest( | ||||
self, | self, | ||||
origin_url: str, | origin_url: str, | ||||
visit: int, | visit: int, | ||||
allowed_statuses: Optional[List[str]] = None, | allowed_statuses: Optional[List[str]] = None, | ||||
require_snapshot: bool = False, | require_snapshot: bool = False, | ||||
) -> Optional[OriginVisitStatus]: | ) -> Optional[OriginVisitStatus]: | ||||
if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES): | if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES): | ||||
raise StorageArgumentException( | raise StorageArgumentException( | ||||
f"Unknown allowed statuses {','.join(allowed_statuses)}, only " | f"Unknown allowed statuses {','.join(allowed_statuses)}, only " | ||||
f"{','.join(VISIT_STATUSES)} authorized" | f"{','.join(VISIT_STATUSES)} authorized" | ||||
) | ) | ||||
rows = list(self._cql_runner.origin_visit_status_get(origin_url, visit)) | rows = list(self._cql_runner.origin_visit_status_get(origin_url, visit)) | ||||
# filtering is done python side as we cannot do it server side | # filtering is done python side as we cannot do it server side | ||||
if allowed_statuses: | if allowed_statuses: | ||||
rows = [row for row in rows if row.status in allowed_statuses] | rows = [row for row in rows if row.status in allowed_statuses] | ||||
if require_snapshot: | if require_snapshot: | ||||
rows = [row for row in rows if row.snapshot is not None] | rows = [row for row in rows if row.snapshot is not None] | ||||
if not rows: | if not rows: | ||||
return None | return None | ||||
return converters.row_to_visit_status(rows[0]) | return converters.row_to_visit_status(rows[0]) | ||||
@timed | |||||
def origin_visit_status_get_random(self, type: str) -> Optional[OriginVisitStatus]: | def origin_visit_status_get_random(self, type: str) -> Optional[OriginVisitStatus]: | ||||
back_in_the_day = now() - datetime.timedelta(weeks=12) # 3 months back | back_in_the_day = now() - datetime.timedelta(weeks=12) # 3 months back | ||||
# Random position to start iteration at | # Random position to start iteration at | ||||
start_token = random.randint(TOKEN_BEGIN, TOKEN_END) | start_token = random.randint(TOKEN_BEGIN, TOKEN_END) | ||||
# Iterator over all visits, ordered by token(origins) then visit_id | # Iterator over all visits, ordered by token(origins) then visit_id | ||||
rows = self._cql_runner.origin_visit_iter(start_token) | rows = self._cql_runner.origin_visit_iter(start_token) | ||||
for row in rows: | for row in rows: | ||||
visit = converters.row_to_visit(row) | visit = converters.row_to_visit(row) | ||||
visit_status = self._origin_visit_get_latest_status(visit) | visit_status = self._origin_visit_get_latest_status(visit) | ||||
if visit.date > back_in_the_day and visit_status.status == "full": | if visit.date > back_in_the_day and visit_status.status == "full": | ||||
return visit_status | return visit_status | ||||
return None | return None | ||||
@timed | |||||
def stat_counters(self): | def stat_counters(self): | ||||
rows = self._cql_runner.stat_counters() | rows = self._cql_runner.stat_counters() | ||||
keys = ( | keys = ( | ||||
"content", | "content", | ||||
"directory", | "directory", | ||||
"origin", | "origin", | ||||
"origin_visit", | "origin_visit", | ||||
"release", | "release", | ||||
"revision", | "revision", | ||||
"skipped_content", | "skipped_content", | ||||
"snapshot", | "snapshot", | ||||
) | ) | ||||
stats = {key: 0 for key in keys} | stats = {key: 0 for key in keys} | ||||
stats.update({row.object_type: row.count for row in rows}) | stats.update({row.object_type: row.count for row in rows}) | ||||
return stats | return stats | ||||
@timed | |||||
def refresh_stat_counters(self): | def refresh_stat_counters(self): | ||||
pass | pass | ||||
@timed | |||||
@process_metrics | |||||
def raw_extrinsic_metadata_add( | def raw_extrinsic_metadata_add( | ||||
self, metadata: List[RawExtrinsicMetadata] | self, metadata: List[RawExtrinsicMetadata] | ||||
) -> Dict[str, int]: | ) -> Dict[str, int]: | ||||
self.journal_writer.raw_extrinsic_metadata_add(metadata) | self.journal_writer.raw_extrinsic_metadata_add(metadata) | ||||
counter = Counter[ExtendedObjectType]() | counter = Counter[ExtendedObjectType]() | ||||
for metadata_entry in metadata: | for metadata_entry in metadata: | ||||
if not self._cql_runner.metadata_authority_get( | if not self._cql_runner.metadata_authority_get( | ||||
metadata_entry.authority.type.value, metadata_entry.authority.url | metadata_entry.authority.type.value, metadata_entry.authority.url | ||||
▲ Show 20 Lines • Show All 43 Lines • ▼ Show 20 Lines | ) -> Dict[str, int]: | ||||
# Then to the main table | # Then to the main table | ||||
self._cql_runner.raw_extrinsic_metadata_add(row) | self._cql_runner.raw_extrinsic_metadata_add(row) | ||||
counter[metadata_entry.target.object_type] += 1 | counter[metadata_entry.target.object_type] += 1 | ||||
return { | return { | ||||
f"{type.value}_metadata:add": count for (type, count) in counter.items() | f"{type.value}_metadata:add": count for (type, count) in counter.items() | ||||
} | } | ||||
@timed | |||||
def raw_extrinsic_metadata_get( | def raw_extrinsic_metadata_get( | ||||
self, | self, | ||||
target: ExtendedSWHID, | target: ExtendedSWHID, | ||||
authority: MetadataAuthority, | authority: MetadataAuthority, | ||||
after: Optional[datetime.datetime] = None, | after: Optional[datetime.datetime] = None, | ||||
page_token: Optional[bytes] = None, | page_token: Optional[bytes] = None, | ||||
limit: int = 1000, | limit: int = 1000, | ||||
) -> PagedResult[RawExtrinsicMetadata]: | ) -> PagedResult[RawExtrinsicMetadata]: | ||||
Show All 31 Lines | ) -> PagedResult[RawExtrinsicMetadata]: | ||||
next_page_token: Optional[str] = base64.b64encode( | next_page_token: Optional[str] = base64.b64encode( | ||||
msgpack_dumps((last_result.discovery_date, last_result.id,)) | msgpack_dumps((last_result.discovery_date, last_result.id,)) | ||||
).decode() | ).decode() | ||||
else: | else: | ||||
next_page_token = None | next_page_token = None | ||||
return PagedResult(next_page_token=next_page_token, results=results,) | return PagedResult(next_page_token=next_page_token, results=results,) | ||||
@timed | |||||
def raw_extrinsic_metadata_get_by_ids( | def raw_extrinsic_metadata_get_by_ids( | ||||
self, ids: List[Sha1Git] | self, ids: List[Sha1Git] | ||||
) -> List[RawExtrinsicMetadata]: | ) -> List[RawExtrinsicMetadata]: | ||||
keys = self._cql_runner.raw_extrinsic_metadata_get_by_ids(ids) | keys = self._cql_runner.raw_extrinsic_metadata_get_by_ids(ids) | ||||
results: Set[RawExtrinsicMetadata] = set() | results: Set[RawExtrinsicMetadata] = set() | ||||
for key in keys: | for key in keys: | ||||
candidates = self._cql_runner.raw_extrinsic_metadata_get( | candidates = self._cql_runner.raw_extrinsic_metadata_get( | ||||
key.target, key.authority_type, key.authority_url | key.target, key.authority_type, key.authority_url | ||||
) | ) | ||||
candidates = [ | candidates = [ | ||||
candidate for candidate in candidates if candidate.id == key.id | candidate for candidate in candidates if candidate.id == key.id | ||||
] | ] | ||||
if len(candidates) > 1: | if len(candidates) > 1: | ||||
raise Exception( | raise Exception( | ||||
"Found multiple RawExtrinsicMetadata objects with the same id: " | "Found multiple RawExtrinsicMetadata objects with the same id: " | ||||
+ hash_to_hex(key.id) | + hash_to_hex(key.id) | ||||
) | ) | ||||
results.update(map(converters.row_to_raw_extrinsic_metadata, candidates)) | results.update(map(converters.row_to_raw_extrinsic_metadata, candidates)) | ||||
return list(results) | return list(results) | ||||
@timed | |||||
def raw_extrinsic_metadata_get_authorities( | def raw_extrinsic_metadata_get_authorities( | ||||
self, target: ExtendedSWHID | self, target: ExtendedSWHID | ||||
) -> List[MetadataAuthority]: | ) -> List[MetadataAuthority]: | ||||
return [ | return [ | ||||
MetadataAuthority( | MetadataAuthority( | ||||
type=MetadataAuthorityType(authority_type), url=authority_url | type=MetadataAuthorityType(authority_type), url=authority_url | ||||
) | ) | ||||
for (authority_type, authority_url) in set( | for (authority_type, authority_url) in set( | ||||
self._cql_runner.raw_extrinsic_metadata_get_authorities(str(target)) | self._cql_runner.raw_extrinsic_metadata_get_authorities(str(target)) | ||||
) | ) | ||||
] | ] | ||||
@timed | |||||
@process_metrics | |||||
def metadata_fetcher_add(self, fetchers: List[MetadataFetcher]) -> Dict[str, int]: | def metadata_fetcher_add(self, fetchers: List[MetadataFetcher]) -> Dict[str, int]: | ||||
self.journal_writer.metadata_fetcher_add(fetchers) | self.journal_writer.metadata_fetcher_add(fetchers) | ||||
for fetcher in fetchers: | for fetcher in fetchers: | ||||
self._cql_runner.metadata_fetcher_add( | self._cql_runner.metadata_fetcher_add( | ||||
MetadataFetcherRow(name=fetcher.name, version=fetcher.version,) | MetadataFetcherRow(name=fetcher.name, version=fetcher.version,) | ||||
) | ) | ||||
return {"metadata_fetcher:add": len(fetchers)} | return {"metadata_fetcher:add": len(fetchers)} | ||||
@timed | |||||
def metadata_fetcher_get( | def metadata_fetcher_get( | ||||
self, name: str, version: str | self, name: str, version: str | ||||
) -> Optional[MetadataFetcher]: | ) -> Optional[MetadataFetcher]: | ||||
fetcher = self._cql_runner.metadata_fetcher_get(name, version) | fetcher = self._cql_runner.metadata_fetcher_get(name, version) | ||||
if fetcher: | if fetcher: | ||||
return MetadataFetcher(name=fetcher.name, version=fetcher.version,) | return MetadataFetcher(name=fetcher.name, version=fetcher.version,) | ||||
else: | else: | ||||
return None | return None | ||||
@timed | |||||
@process_metrics | |||||
def metadata_authority_add( | def metadata_authority_add( | ||||
self, authorities: List[MetadataAuthority] | self, authorities: List[MetadataAuthority] | ||||
) -> Dict[str, int]: | ) -> Dict[str, int]: | ||||
self.journal_writer.metadata_authority_add(authorities) | self.journal_writer.metadata_authority_add(authorities) | ||||
for authority in authorities: | for authority in authorities: | ||||
self._cql_runner.metadata_authority_add( | self._cql_runner.metadata_authority_add( | ||||
MetadataAuthorityRow(url=authority.url, type=authority.type.value,) | MetadataAuthorityRow(url=authority.url, type=authority.type.value,) | ||||
) | ) | ||||
return {"metadata_authority:add": len(authorities)} | return {"metadata_authority:add": len(authorities)} | ||||
@timed | |||||
def metadata_authority_get( | def metadata_authority_get( | ||||
self, type: MetadataAuthorityType, url: str | self, type: MetadataAuthorityType, url: str | ||||
) -> Optional[MetadataAuthority]: | ) -> Optional[MetadataAuthority]: | ||||
authority = self._cql_runner.metadata_authority_get(type.value, url) | authority = self._cql_runner.metadata_authority_get(type.value, url) | ||||
if authority: | if authority: | ||||
return MetadataAuthority( | return MetadataAuthority( | ||||
type=MetadataAuthorityType(authority.type), url=authority.url, | type=MetadataAuthorityType(authority.type), url=authority.url, | ||||
) | ) | ||||
else: | else: | ||||
return None | return None | ||||
# ExtID tables | # ExtID tables | ||||
@timed | |||||
@process_metrics | |||||
def extid_add(self, ids: List[ExtID]) -> Dict[str, int]: | def extid_add(self, ids: List[ExtID]) -> Dict[str, int]: | ||||
if not self._allow_overwrite: | if not self._allow_overwrite: | ||||
extids = [ | extids = [ | ||||
extid | extid | ||||
for extid in ids | for extid in ids | ||||
if not self._cql_runner.extid_get_from_pk( | if not self._cql_runner.extid_get_from_pk( | ||||
extid_type=extid.extid_type, | extid_type=extid.extid_type, | ||||
extid_version=extid.extid_version, | extid_version=extid.extid_version, | ||||
Show All 21 Lines | def extid_add(self, ids: List[ExtID]) -> Dict[str, int]: | ||||
indexrow = ExtIDByTargetRow( | indexrow = ExtIDByTargetRow( | ||||
target_type=target_type, target=target, target_token=token, | target_type=target_type, target=target, target_token=token, | ||||
) | ) | ||||
self._cql_runner.extid_index_add_one(indexrow) | self._cql_runner.extid_index_add_one(indexrow) | ||||
insertion_finalizer() | insertion_finalizer() | ||||
inserted += 1 | inserted += 1 | ||||
return {"extid:add": inserted} | return {"extid:add": inserted} | ||||
@timed | |||||
def extid_get_from_extid(self, id_type: str, ids: List[bytes]) -> List[ExtID]: | def extid_get_from_extid(self, id_type: str, ids: List[bytes]) -> List[ExtID]: | ||||
result: List[ExtID] = [] | result: List[ExtID] = [] | ||||
for extid in ids: | for extid in ids: | ||||
extidrows = list(self._cql_runner.extid_get_from_extid(id_type, extid)) | extidrows = list(self._cql_runner.extid_get_from_extid(id_type, extid)) | ||||
result.extend( | result.extend( | ||||
ExtID( | ExtID( | ||||
extid_type=extidrow.extid_type, | extid_type=extidrow.extid_type, | ||||
extid_version=extidrow.extid_version, | extid_version=extidrow.extid_version, | ||||
extid=extidrow.extid, | extid=extidrow.extid, | ||||
target=CoreSWHID( | target=CoreSWHID( | ||||
object_type=extidrow.target_type, object_id=extidrow.target, | object_type=extidrow.target_type, object_id=extidrow.target, | ||||
), | ), | ||||
) | ) | ||||
for extidrow in extidrows | for extidrow in extidrows | ||||
) | ) | ||||
return result | return result | ||||
@timed | |||||
def extid_get_from_target( | def extid_get_from_target( | ||||
self, target_type: SwhidObjectType, ids: List[Sha1Git] | self, target_type: SwhidObjectType, ids: List[Sha1Git] | ||||
) -> List[ExtID]: | ) -> List[ExtID]: | ||||
result: List[ExtID] = [] | result: List[ExtID] = [] | ||||
for target in ids: | for target in ids: | ||||
extidrows = list( | extidrows = list( | ||||
self._cql_runner.extid_get_from_target(target_type.value, target) | self._cql_runner.extid_get_from_target(target_type.value, target) | ||||
) | ) | ||||
Show All 23 Lines |