Changeset View
Changeset View
Standalone View
Standalone View
swh/storage/postgresql/storage.py
Show First 20 Lines • Show All 140 Lines • ▼ Show 20 Lines | def db(self): | ||||
db = self.get_db() | db = self.get_db() | ||||
yield db | yield db | ||||
finally: | finally: | ||||
if db: | if db: | ||||
self.put_db(db) | self.put_db(db) | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def check_config(self, *, check_write: bool, db=None, cur=None) -> bool: | def check_config(self, *, check_write: bool, db: Db, cur=None) -> bool: | ||||
if not self.objstorage.check_config(check_write=check_write): | if not self.objstorage.check_config(check_write=check_write): | ||||
return False | return False | ||||
if not db.check_dbversion(): | if not db.check_dbversion(): | ||||
return False | return False | ||||
# Check permissions on one of the tables | # Check permissions on one of the tables | ||||
▲ Show 20 Lines • Show All 84 Lines • ▼ Show 20 Lines | def content_add(self, content: List[Content]) -> Dict[str, int]: | ||||
return { | return { | ||||
"content:add": len(contents), | "content:add": len(contents), | ||||
"content:add:bytes": objstorage_summary["content:add:bytes"], | "content:add:bytes": objstorage_summary["content:add:bytes"], | ||||
} | } | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def content_update( | def content_update( | ||||
self, contents: List[Dict[str, Any]], keys: List[str] = [], db=None, cur=None | self, contents: List[Dict[str, Any]], keys: List[str] = [], *, db: Db, cur=None | ||||
) -> None: | ) -> None: | ||||
# TODO: Add a check on input keys. How to properly implement | # TODO: Add a check on input keys. How to properly implement | ||||
# this? We don't know yet the new columns. | # this? We don't know yet the new columns. | ||||
self.journal_writer.content_update(contents) | self.journal_writer.content_update(contents) | ||||
db.mktemp("content", cur) | db.mktemp("content", cur) | ||||
select_keys = list(set(db.content_get_metadata_keys).union(set(keys))) | select_keys = list(set(db.content_get_metadata_keys).union(set(keys))) | ||||
with convert_validation_exceptions(): | with convert_validation_exceptions(): | ||||
db.copy_to(contents, "tmp_content", select_keys, cur) | db.copy_to(contents, "tmp_content", select_keys, cur) | ||||
db.content_update_from_temp(keys_to_update=keys, cur=cur) | db.content_update_from_temp(keys_to_update=keys, cur=cur) | ||||
@timed | @timed | ||||
@process_metrics | @process_metrics | ||||
@db_transaction() | @db_transaction() | ||||
def content_add_metadata( | def content_add_metadata( | ||||
self, content: List[Content], db=None, cur=None | self, content: List[Content], *, db: Db, cur=None | ||||
) -> Dict[str, int]: | ) -> Dict[str, int]: | ||||
missing = self.content_missing( | missing = self.content_missing( | ||||
(c.to_dict() for c in content), key_hash="sha1_git", db=db, cur=cur, | (c.to_dict() for c in content), key_hash="sha1_git", db=db, cur=cur, | ||||
) | ) | ||||
contents = [c for c in content if c.sha1_git in missing] | contents = [c for c in content if c.sha1_git in missing] | ||||
self.journal_writer.content_add_metadata(contents) | self.journal_writer.content_add_metadata(contents) | ||||
self._content_add_metadata(db, cur, contents) | self._content_add_metadata(db, cur, contents) | ||||
Show All 10 Lines | class Storage: | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def content_get_partition( | def content_get_partition( | ||||
self, | self, | ||||
partition_id: int, | partition_id: int, | ||||
nb_partitions: int, | nb_partitions: int, | ||||
page_token: Optional[str] = None, | page_token: Optional[str] = None, | ||||
limit: int = 1000, | limit: int = 1000, | ||||
db=None, | *, | ||||
db: Db, | |||||
cur=None, | cur=None, | ||||
) -> PagedResult[Content]: | ) -> PagedResult[Content]: | ||||
if limit is None: | if limit is None: | ||||
raise StorageArgumentException("limit should not be None") | raise StorageArgumentException("limit should not be None") | ||||
(start, end) = get_partition_bounds_bytes( | (start, end) = get_partition_bounds_bytes( | ||||
partition_id, nb_partitions, SHA1_SIZE | partition_id, nb_partitions, SHA1_SIZE | ||||
) | ) | ||||
if page_token: | if page_token: | ||||
Show All 13 Lines | ) -> PagedResult[Content]: | ||||
contents.append(content) | contents.append(content) | ||||
assert len(contents) <= limit | assert len(contents) <= limit | ||||
return PagedResult(results=contents, next_page_token=next_page_token) | return PagedResult(results=contents, next_page_token=next_page_token) | ||||
@timed | @timed | ||||
@db_transaction(statement_timeout=500) | @db_transaction(statement_timeout=500) | ||||
def content_get( | def content_get( | ||||
self, contents: List[bytes], algo: str = "sha1", db=None, cur=None | self, contents: List[bytes], algo: str = "sha1", *, db: Db, cur=None | ||||
) -> List[Optional[Content]]: | ) -> List[Optional[Content]]: | ||||
contents_by_hash: Dict[bytes, Optional[Content]] = {} | contents_by_hash: Dict[bytes, Optional[Content]] = {} | ||||
if algo not in DEFAULT_ALGORITHMS: | if algo not in DEFAULT_ALGORITHMS: | ||||
raise StorageArgumentException( | raise StorageArgumentException( | ||||
"algo should be one of {','.join(DEFAULT_ALGORITHMS)}" | "algo should be one of {','.join(DEFAULT_ALGORITHMS)}" | ||||
) | ) | ||||
rows = db.content_get_metadata_from_hashes(contents, algo, cur) | rows = db.content_get_metadata_from_hashes(contents, algo, cur) | ||||
key = operator.attrgetter(algo) | key = operator.attrgetter(algo) | ||||
for row in rows: | for row in rows: | ||||
row_d = dict(zip(db.content_get_metadata_keys, row)) | row_d = dict(zip(db.content_get_metadata_keys, row)) | ||||
content = Content(**row_d) | content = Content(**row_d) | ||||
contents_by_hash[key(content)] = content | contents_by_hash[key(content)] = content | ||||
return [contents_by_hash.get(sha1) for sha1 in contents] | return [contents_by_hash.get(sha1) for sha1 in contents] | ||||
@timed | @timed | ||||
@db_transaction_generator() | @db_transaction_generator() | ||||
def content_missing( | def content_missing( | ||||
self, contents: List[Dict[str, Any]], key_hash: str = "sha1", db=None, cur=None | self, | ||||
contents: List[Dict[str, Any]], | |||||
key_hash: str = "sha1", | |||||
*, | |||||
db: Db, | |||||
cur=None, | |||||
) -> Iterable[bytes]: | ) -> Iterable[bytes]: | ||||
if key_hash not in DEFAULT_ALGORITHMS: | if key_hash not in DEFAULT_ALGORITHMS: | ||||
raise StorageArgumentException( | raise StorageArgumentException( | ||||
"key_hash should be one of {','.join(DEFAULT_ALGORITHMS)}" | "key_hash should be one of {','.join(DEFAULT_ALGORITHMS)}" | ||||
) | ) | ||||
keys = db.content_hash_keys | keys = db.content_hash_keys | ||||
key_hash_idx = keys.index(key_hash) | key_hash_idx = keys.index(key_hash) | ||||
for obj in db.content_missing_from_list(contents, cur): | for obj in db.content_missing_from_list(contents, cur): | ||||
yield obj[key_hash_idx] | yield obj[key_hash_idx] | ||||
@timed | @timed | ||||
@db_transaction_generator() | @db_transaction_generator() | ||||
def content_missing_per_sha1( | def content_missing_per_sha1( | ||||
self, contents: List[bytes], db=None, cur=None | self, contents: List[bytes], *, db: Db, cur=None | ||||
) -> Iterable[bytes]: | ) -> Iterable[bytes]: | ||||
for obj in db.content_missing_per_sha1(contents, cur): | for obj in db.content_missing_per_sha1(contents, cur): | ||||
yield obj[0] | yield obj[0] | ||||
@timed | @timed | ||||
@db_transaction_generator() | @db_transaction_generator() | ||||
def content_missing_per_sha1_git( | def content_missing_per_sha1_git( | ||||
self, contents: List[bytes], db=None, cur=None | self, contents: List[bytes], *, db: Db, cur=None | ||||
) -> Iterable[Sha1Git]: | ) -> Iterable[Sha1Git]: | ||||
for obj in db.content_missing_per_sha1_git(contents, cur): | for obj in db.content_missing_per_sha1_git(contents, cur): | ||||
yield obj[0] | yield obj[0] | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def content_find(self, content: Dict[str, Any], db=None, cur=None) -> List[Content]: | def content_find( | ||||
self, content: Dict[str, Any], *, db: Db, cur=None | |||||
) -> List[Content]: | |||||
if not set(content).intersection(DEFAULT_ALGORITHMS): | if not set(content).intersection(DEFAULT_ALGORITHMS): | ||||
raise StorageArgumentException( | raise StorageArgumentException( | ||||
"content keys must contain at least one " | "content keys must contain at least one " | ||||
f"of: {', '.join(sorted(DEFAULT_ALGORITHMS))}" | f"of: {', '.join(sorted(DEFAULT_ALGORITHMS))}" | ||||
) | ) | ||||
rows = db.content_find( | rows = db.content_find( | ||||
sha1=content.get("sha1"), | sha1=content.get("sha1"), | ||||
sha1_git=content.get("sha1_git"), | sha1_git=content.get("sha1_git"), | ||||
sha256=content.get("sha256"), | sha256=content.get("sha256"), | ||||
blake2s256=content.get("blake2s256"), | blake2s256=content.get("blake2s256"), | ||||
cur=cur, | cur=cur, | ||||
) | ) | ||||
contents = [] | contents = [] | ||||
for row in rows: | for row in rows: | ||||
row_d = dict(zip(db.content_find_cols, row)) | row_d = dict(zip(db.content_find_cols, row)) | ||||
contents.append(Content(**row_d)) | contents.append(Content(**row_d)) | ||||
return contents | return contents | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def content_get_random(self, db=None, cur=None) -> Sha1Git: | def content_get_random(self, *, db: Db, cur=None) -> Sha1Git: | ||||
return db.content_get_random(cur) | return db.content_get_random(cur) | ||||
@staticmethod | @staticmethod | ||||
def _skipped_content_normalize(d): | def _skipped_content_normalize(d): | ||||
d = d.copy() | d = d.copy() | ||||
if d.get("status") is None: | if d.get("status") is None: | ||||
d["status"] = "absent" | d["status"] = "absent" | ||||
Show All 19 Lines | def _skipped_content_add_metadata(self, db, cur, content: List[SkippedContent]): | ||||
# move metadata in place | # move metadata in place | ||||
db.skipped_content_add_from_temp(cur) | db.skipped_content_add_from_temp(cur) | ||||
@timed | @timed | ||||
@process_metrics | @process_metrics | ||||
@db_transaction() | @db_transaction() | ||||
def skipped_content_add( | def skipped_content_add( | ||||
self, content: List[SkippedContent], db=None, cur=None | self, content: List[SkippedContent], *, db: Db, cur=None | ||||
) -> Dict[str, int]: | ) -> Dict[str, int]: | ||||
ctime = now() | ctime = now() | ||||
content = [attr.evolve(c, ctime=ctime) for c in content] | content = [attr.evolve(c, ctime=ctime) for c in content] | ||||
missing_contents = self.skipped_content_missing( | missing_contents = self.skipped_content_missing( | ||||
(c.to_dict() for c in content), db=db, cur=cur, | (c.to_dict() for c in content), db=db, cur=cur, | ||||
) | ) | ||||
content = [ | content = [ | ||||
Show All 13 Lines | ) -> Dict[str, int]: | ||||
return { | return { | ||||
"skipped_content:add": len(content), | "skipped_content:add": len(content), | ||||
} | } | ||||
@timed | @timed | ||||
@db_transaction_generator() | @db_transaction_generator() | ||||
def skipped_content_missing( | def skipped_content_missing( | ||||
self, contents: List[Dict[str, Any]], db=None, cur=None | self, contents: List[Dict[str, Any]], *, db: Db, cur=None | ||||
) -> Iterable[Dict[str, Any]]: | ) -> Iterable[Dict[str, Any]]: | ||||
contents = list(contents) | contents = list(contents) | ||||
for content in db.skipped_content_missing(contents, cur): | for content in db.skipped_content_missing(contents, cur): | ||||
yield dict(zip(db.content_hash_keys, content)) | yield dict(zip(db.content_hash_keys, content)) | ||||
@timed | @timed | ||||
@process_metrics | @process_metrics | ||||
@db_transaction() | @db_transaction() | ||||
def directory_add( | def directory_add( | ||||
self, directories: List[Directory], db=None, cur=None | self, directories: List[Directory], *, db: Db, cur=None | ||||
) -> Dict[str, int]: | ) -> Dict[str, int]: | ||||
summary = {"directory:add": 0} | summary = {"directory:add": 0} | ||||
dirs = set() | dirs = set() | ||||
dir_entries: Dict[str, defaultdict] = { | dir_entries: Dict[str, defaultdict] = { | ||||
"file": defaultdict(list), | "file": defaultdict(list), | ||||
"dir": defaultdict(list), | "dir": defaultdict(list), | ||||
"rev": defaultdict(list), | "rev": defaultdict(list), | ||||
▲ Show 20 Lines • Show All 41 Lines • ▼ Show 20 Lines | ) -> Dict[str, int]: | ||||
db.directory_add_from_temp(cur) | db.directory_add_from_temp(cur) | ||||
summary["directory:add"] = len(dirs_missing) | summary["directory:add"] = len(dirs_missing) | ||||
return summary | return summary | ||||
@timed | @timed | ||||
@db_transaction_generator() | @db_transaction_generator() | ||||
def directory_missing( | def directory_missing( | ||||
self, directories: List[Sha1Git], db=None, cur=None | self, directories: List[Sha1Git], *, db: Db, cur=None | ||||
) -> Iterable[Sha1Git]: | ) -> Iterable[Sha1Git]: | ||||
for obj in db.directory_missing_from_list(directories, cur): | for obj in db.directory_missing_from_list(directories, cur): | ||||
yield obj[0] | yield obj[0] | ||||
@timed | @timed | ||||
@db_transaction_generator(statement_timeout=20000) | @db_transaction_generator(statement_timeout=20000) | ||||
def directory_ls( | def directory_ls( | ||||
self, directory: Sha1Git, recursive: bool = False, db=None, cur=None | self, directory: Sha1Git, recursive: bool = False, *, db: Db, cur=None | ||||
) -> Iterable[Dict[str, Any]]: | ) -> Iterable[Dict[str, Any]]: | ||||
if recursive: | if recursive: | ||||
res_gen = db.directory_walk(directory, cur=cur) | res_gen = db.directory_walk(directory, cur=cur) | ||||
else: | else: | ||||
res_gen = db.directory_walk_one(directory, cur=cur) | res_gen = db.directory_walk_one(directory, cur=cur) | ||||
for line in res_gen: | for line in res_gen: | ||||
yield dict(zip(db.directory_ls_cols, line)) | yield dict(zip(db.directory_ls_cols, line)) | ||||
@timed | @timed | ||||
@db_transaction(statement_timeout=2000) | @db_transaction(statement_timeout=2000) | ||||
def directory_entry_get_by_path( | def directory_entry_get_by_path( | ||||
self, directory: Sha1Git, paths: List[bytes], db=None, cur=None | self, directory: Sha1Git, paths: List[bytes], *, db: Db, cur=None | ||||
) -> Optional[Dict[str, Any]]: | ) -> Optional[Dict[str, Any]]: | ||||
res = db.directory_entry_get_by_path(directory, paths, cur) | res = db.directory_entry_get_by_path(directory, paths, cur) | ||||
return dict(zip(db.directory_ls_cols, res)) if res else None | return dict(zip(db.directory_ls_cols, res)) if res else None | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def directory_get_random(self, db=None, cur=None) -> Sha1Git: | def directory_get_random(self, *, db: Db, cur=None) -> Sha1Git: | ||||
return db.directory_get_random(cur) | return db.directory_get_random(cur) | ||||
@db_transaction() | @db_transaction() | ||||
def directory_get_entries( | def directory_get_entries( | ||||
self, | self, | ||||
directory_id: Sha1Git, | directory_id: Sha1Git, | ||||
page_token: Optional[bytes] = None, | page_token: Optional[bytes] = None, | ||||
limit: int = 1000, | limit: int = 1000, | ||||
db=None, | *, | ||||
db: Db, | |||||
cur=None, | cur=None, | ||||
) -> Optional[PagedResult[DirectoryEntry]]: | ) -> Optional[PagedResult[DirectoryEntry]]: | ||||
if list(self.directory_missing([directory_id], db=db, cur=cur)): | if list(self.directory_missing([directory_id], db=db, cur=cur)): | ||||
return None | return None | ||||
if page_token is not None: | if page_token is not None: | ||||
raise StorageArgumentException("Unsupported page token") | raise StorageArgumentException("Unsupported page token") | ||||
# TODO: actually paginate | # TODO: actually paginate | ||||
rows = db.directory_get_entries(directory_id, cur=cur) | rows = db.directory_get_entries(directory_id, cur=cur) | ||||
return PagedResult( | return PagedResult( | ||||
results=[ | results=[ | ||||
DirectoryEntry(**dict(zip(db.directory_get_entries_cols, row))) | DirectoryEntry(**dict(zip(db.directory_get_entries_cols, row))) | ||||
for row in rows | for row in rows | ||||
], | ], | ||||
next_page_token=None, | next_page_token=None, | ||||
) | ) | ||||
@timed | @timed | ||||
@process_metrics | @process_metrics | ||||
@db_transaction() | @db_transaction() | ||||
def revision_add( | def revision_add( | ||||
self, revisions: List[Revision], db=None, cur=None | self, revisions: List[Revision], *, db: Db, cur=None | ||||
) -> Dict[str, int]: | ) -> Dict[str, int]: | ||||
summary = {"revision:add": 0} | summary = {"revision:add": 0} | ||||
revisions_missing = set( | revisions_missing = set( | ||||
self.revision_missing( | self.revision_missing( | ||||
set(revision.id for revision in revisions), db=db, cur=cur | set(revision.id for revision in revisions), db=db, cur=cur | ||||
) | ) | ||||
) | ) | ||||
if not revisions_missing: | if not revisions_missing: | ||||
return summary | return summary | ||||
db.mktemp_revision(cur) | db.mktemp_revision(cur) | ||||
revisions_filtered = [ | revisions_filtered = [ | ||||
revision for revision in revisions if revision.id in revisions_missing | revision for revision in revisions if revision.id in revisions_missing | ||||
] | ] | ||||
self.journal_writer.revision_add(revisions_filtered) | self.journal_writer.revision_add(revisions_filtered) | ||||
db_revisions_filtered = list(map(converters.revision_to_db, revisions_filtered)) | db_revisions_filtered = list(map(converters.revision_to_db, revisions_filtered)) | ||||
parents_filtered: List[bytes] = [] | parents_filtered: List[Dict[str, Any]] = [] | ||||
with convert_validation_exceptions(): | with convert_validation_exceptions(): | ||||
db.copy_to( | db.copy_to( | ||||
db_revisions_filtered, | db_revisions_filtered, | ||||
"tmp_revision", | "tmp_revision", | ||||
db.revision_add_cols, | db.revision_add_cols, | ||||
cur, | cur, | ||||
lambda rev: parents_filtered.extend(rev["parents"]), | lambda rev: parents_filtered.extend(rev["parents"]), | ||||
) | ) | ||||
db.revision_add_from_temp(cur) | db.revision_add_from_temp(cur) | ||||
db.copy_to( | db.copy_to( | ||||
parents_filtered, | parents_filtered, | ||||
"revision_history", | "revision_history", | ||||
["id", "parent_id", "parent_rank"], | ["id", "parent_id", "parent_rank"], | ||||
cur, | cur, | ||||
) | ) | ||||
return {"revision:add": len(revisions_missing)} | return {"revision:add": len(revisions_missing)} | ||||
@timed | @timed | ||||
@db_transaction_generator() | @db_transaction_generator() | ||||
def revision_missing( | def revision_missing( | ||||
self, revisions: List[Sha1Git], db=None, cur=None | self, revisions: List[Sha1Git], *, db: Db, cur=None | ||||
) -> Iterable[Sha1Git]: | ) -> Iterable[Sha1Git]: | ||||
if not revisions: | if not revisions: | ||||
return None | return None | ||||
for obj in db.revision_missing_from_list(revisions, cur): | for obj in db.revision_missing_from_list(revisions, cur): | ||||
yield obj[0] | yield obj[0] | ||||
@timed | @timed | ||||
@db_transaction(statement_timeout=1000) | @db_transaction(statement_timeout=1000) | ||||
def revision_get( | def revision_get( | ||||
self, revision_ids: List[Sha1Git], db=None, cur=None | self, revision_ids: List[Sha1Git], *, db: Db, cur=None | ||||
) -> List[Optional[Revision]]: | ) -> List[Optional[Revision]]: | ||||
revisions = [] | revisions = [] | ||||
for line in db.revision_get_from_list(revision_ids, cur): | for line in db.revision_get_from_list(revision_ids, cur): | ||||
revision = converters.db_to_revision(dict(zip(db.revision_get_cols, line))) | revision = converters.db_to_revision(dict(zip(db.revision_get_cols, line))) | ||||
revisions.append(revision) | revisions.append(revision) | ||||
return revisions | return revisions | ||||
@timed | @timed | ||||
@db_transaction_generator(statement_timeout=2000) | @db_transaction_generator(statement_timeout=2000) | ||||
def revision_log( | def revision_log( | ||||
self, revisions: List[Sha1Git], limit: Optional[int] = None, db=None, cur=None | self, revisions: List[Sha1Git], limit: Optional[int] = None, *, db: Db, cur=None | ||||
) -> Iterable[Optional[Dict[str, Any]]]: | ) -> Iterable[Optional[Dict[str, Any]]]: | ||||
for line in db.revision_log(revisions, limit, cur): | for line in db.revision_log(revisions, limit, cur): | ||||
data = converters.db_to_revision(dict(zip(db.revision_get_cols, line))) | data = converters.db_to_revision(dict(zip(db.revision_get_cols, line))) | ||||
if not data: | if not data: | ||||
yield None | yield None | ||||
continue | continue | ||||
yield data.to_dict() | yield data.to_dict() | ||||
@timed | @timed | ||||
@db_transaction_generator(statement_timeout=2000) | @db_transaction_generator(statement_timeout=2000) | ||||
def revision_shortlog( | def revision_shortlog( | ||||
self, revisions: List[Sha1Git], limit: Optional[int] = None, db=None, cur=None | self, revisions: List[Sha1Git], limit: Optional[int] = None, *, db: Db, cur=None | ||||
) -> Iterable[Optional[Tuple[Sha1Git, Tuple[Sha1Git, ...]]]]: | ) -> Iterable[Optional[Tuple[Sha1Git, Tuple[Sha1Git, ...]]]]: | ||||
yield from db.revision_shortlog(revisions, limit, cur) | yield from db.revision_shortlog(revisions, limit, cur) | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def revision_get_random(self, db=None, cur=None) -> Sha1Git: | def revision_get_random(self, *, db: Db, cur=None) -> Sha1Git: | ||||
return db.revision_get_random(cur) | return db.revision_get_random(cur) | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def extid_get_from_extid( | def extid_get_from_extid( | ||||
self, id_type: str, ids: List[bytes], db=None, cur=None | self, id_type: str, ids: List[bytes], *, db: Db, cur=None | ||||
) -> List[ExtID]: | ) -> List[ExtID]: | ||||
extids = [] | extids = [] | ||||
for row in db.extid_get_from_extid_list(id_type, ids, cur): | for row in db.extid_get_from_extid_list(id_type, ids, cur): | ||||
if row[0] is not None: | if row[0] is not None: | ||||
extids.append(converters.db_to_extid(dict(zip(db.extid_cols, row)))) | extids.append(converters.db_to_extid(dict(zip(db.extid_cols, row)))) | ||||
return extids | return extids | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def extid_get_from_target( | def extid_get_from_target( | ||||
self, target_type: ObjectType, ids: List[Sha1Git], db=None, cur=None | self, target_type: ObjectType, ids: List[Sha1Git], *, db: Db, cur=None | ||||
) -> List[ExtID]: | ) -> List[ExtID]: | ||||
extids = [] | extids = [] | ||||
for row in db.extid_get_from_swhid_list(target_type.value, ids, cur): | for row in db.extid_get_from_swhid_list(target_type.value, ids, cur): | ||||
if row[0] is not None: | if row[0] is not None: | ||||
extids.append(converters.db_to_extid(dict(zip(db.extid_cols, row)))) | extids.append(converters.db_to_extid(dict(zip(db.extid_cols, row)))) | ||||
return extids | return extids | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def extid_add(self, ids: List[ExtID], db=None, cur=None) -> Dict[str, int]: | def extid_add(self, ids: List[ExtID], *, db: Db, cur=None) -> Dict[str, int]: | ||||
extid = [ | extid = [ | ||||
{ | { | ||||
"extid": extid.extid, | "extid": extid.extid, | ||||
"extid_type": extid.extid_type, | "extid_type": extid.extid_type, | ||||
"target": extid.target.object_id, | "target": extid.target.object_id, | ||||
"target_type": extid.target.object_type.name.lower(), # arghh | "target_type": extid.target.object_type.name.lower(), # arghh | ||||
} | } | ||||
for extid in ids | for extid in ids | ||||
] | ] | ||||
db.mktemp("extid", cur) | db.mktemp("extid", cur) | ||||
self.journal_writer.extid_add(ids) | self.journal_writer.extid_add(ids) | ||||
db.copy_to(extid, "tmp_extid", db.extid_cols, cur) | db.copy_to(extid, "tmp_extid", db.extid_cols, cur) | ||||
# move metadata in place | # move metadata in place | ||||
db.extid_add_from_temp(cur) | db.extid_add_from_temp(cur) | ||||
return {"extid:add": len(extid)} | return {"extid:add": len(extid)} | ||||
@timed | @timed | ||||
@process_metrics | @process_metrics | ||||
@db_transaction() | @db_transaction() | ||||
def release_add(self, releases: List[Release], db=None, cur=None) -> Dict[str, int]: | def release_add( | ||||
self, releases: List[Release], *, db: Db, cur=None | |||||
) -> Dict[str, int]: | |||||
summary = {"release:add": 0} | summary = {"release:add": 0} | ||||
release_ids = set(release.id for release in releases) | release_ids = set(release.id for release in releases) | ||||
releases_missing = set(self.release_missing(release_ids, db=db, cur=cur)) | releases_missing = set(self.release_missing(release_ids, db=db, cur=cur)) | ||||
if not releases_missing: | if not releases_missing: | ||||
return summary | return summary | ||||
Show All 12 Lines | ) -> Dict[str, int]: | ||||
db.release_add_from_temp(cur) | db.release_add_from_temp(cur) | ||||
return {"release:add": len(releases_missing)} | return {"release:add": len(releases_missing)} | ||||
@timed | @timed | ||||
@db_transaction_generator() | @db_transaction_generator() | ||||
def release_missing( | def release_missing( | ||||
self, releases: List[Sha1Git], db=None, cur=None | self, releases: List[Sha1Git], *, db: Db, cur=None | ||||
) -> Iterable[Sha1Git]: | ) -> Iterable[Sha1Git]: | ||||
if not releases: | if not releases: | ||||
return | return | ||||
for obj in db.release_missing_from_list(releases, cur): | for obj in db.release_missing_from_list(releases, cur): | ||||
yield obj[0] | yield obj[0] | ||||
@timed | @timed | ||||
@db_transaction(statement_timeout=500) | @db_transaction(statement_timeout=500) | ||||
def release_get( | def release_get( | ||||
self, releases: List[Sha1Git], db=None, cur=None | self, releases: List[Sha1Git], *, db: Db, cur=None | ||||
) -> List[Optional[Release]]: | ) -> List[Optional[Release]]: | ||||
rels = [] | rels = [] | ||||
for release in db.release_get_from_list(releases, cur): | for release in db.release_get_from_list(releases, cur): | ||||
data = converters.db_to_release(dict(zip(db.release_get_cols, release))) | data = converters.db_to_release(dict(zip(db.release_get_cols, release))) | ||||
rels.append(data if data else None) | rels.append(data if data else None) | ||||
return rels | return rels | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def release_get_random(self, db=None, cur=None) -> Sha1Git: | def release_get_random(self, *, db: Db, cur=None) -> Sha1Git: | ||||
return db.release_get_random(cur) | return db.release_get_random(cur) | ||||
@timed | @timed | ||||
@process_metrics | @process_metrics | ||||
@db_transaction() | @db_transaction() | ||||
def snapshot_add( | def snapshot_add( | ||||
self, snapshots: List[Snapshot], db=None, cur=None | self, snapshots: List[Snapshot], *, db: Db, cur=None | ||||
) -> Dict[str, int]: | ) -> Dict[str, int]: | ||||
created_temp_table = False | created_temp_table = False | ||||
count = 0 | count = 0 | ||||
for snapshot in snapshots: | for snapshot in snapshots: | ||||
if not db.snapshot_exists(snapshot.id, cur): | if not db.snapshot_exists(snapshot.id, cur): | ||||
if not created_temp_table: | if not created_temp_table: | ||||
db.mktemp_snapshot_branch(cur) | db.mktemp_snapshot_branch(cur) | ||||
Show All 21 Lines | ) -> Dict[str, int]: | ||||
db.snapshot_add(snapshot.id, cur) | db.snapshot_add(snapshot.id, cur) | ||||
count += 1 | count += 1 | ||||
return {"snapshot:add": count} | return {"snapshot:add": count} | ||||
@timed | @timed | ||||
@db_transaction_generator() | @db_transaction_generator() | ||||
def snapshot_missing( | def snapshot_missing( | ||||
self, snapshots: List[Sha1Git], db=None, cur=None | self, snapshots: List[Sha1Git], *, db: Db, cur=None | ||||
) -> Iterable[Sha1Git]: | ) -> Iterable[Sha1Git]: | ||||
for obj in db.snapshot_missing_from_list(snapshots, cur): | for obj in db.snapshot_missing_from_list(snapshots, cur): | ||||
yield obj[0] | yield obj[0] | ||||
@timed | @timed | ||||
@db_transaction(statement_timeout=2000) | @db_transaction(statement_timeout=2000) | ||||
def snapshot_get( | def snapshot_get( | ||||
self, snapshot_id: Sha1Git, db=None, cur=None | self, snapshot_id: Sha1Git, *, db: Db, cur=None | ||||
) -> Optional[Dict[str, Any]]: | ) -> Optional[Dict[str, Any]]: | ||||
d = self.snapshot_get_branches(snapshot_id) | d = self.snapshot_get_branches(snapshot_id) | ||||
if d is None: | if d is None: | ||||
return d | return d | ||||
return { | return { | ||||
"id": d["id"], | "id": d["id"], | ||||
"branches": { | "branches": { | ||||
name: branch.to_dict() if branch else None | name: branch.to_dict() if branch else None | ||||
for (name, branch) in d["branches"].items() | for (name, branch) in d["branches"].items() | ||||
}, | }, | ||||
"next_branch": d["next_branch"], | "next_branch": d["next_branch"], | ||||
} | } | ||||
@timed | @timed | ||||
@db_transaction(statement_timeout=2000) | @db_transaction(statement_timeout=2000) | ||||
def snapshot_count_branches( | def snapshot_count_branches( | ||||
self, | self, | ||||
snapshot_id: Sha1Git, | snapshot_id: Sha1Git, | ||||
branch_name_exclude_prefix: Optional[bytes] = None, | branch_name_exclude_prefix: Optional[bytes] = None, | ||||
db=None, | *, | ||||
db: Db, | |||||
cur=None, | cur=None, | ||||
) -> Optional[Dict[Optional[str], int]]: | ) -> Optional[Dict[Optional[str], int]]: | ||||
return dict( | return dict( | ||||
[ | [ | ||||
bc | bc | ||||
for bc in db.snapshot_count_branches( | for bc in db.snapshot_count_branches( | ||||
snapshot_id, branch_name_exclude_prefix, cur, | snapshot_id, branch_name_exclude_prefix, cur, | ||||
) | ) | ||||
] | ] | ||||
) | ) | ||||
@timed | @timed | ||||
@db_transaction(statement_timeout=2000) | @db_transaction(statement_timeout=2000) | ||||
def snapshot_get_branches( | def snapshot_get_branches( | ||||
self, | self, | ||||
snapshot_id: Sha1Git, | snapshot_id: Sha1Git, | ||||
branches_from: bytes = b"", | branches_from: bytes = b"", | ||||
branches_count: int = 1000, | branches_count: int = 1000, | ||||
target_types: Optional[List[str]] = None, | target_types: Optional[List[str]] = None, | ||||
branch_name_include_substring: Optional[bytes] = None, | branch_name_include_substring: Optional[bytes] = None, | ||||
branch_name_exclude_prefix: Optional[bytes] = None, | branch_name_exclude_prefix: Optional[bytes] = None, | ||||
db=None, | *, | ||||
db: Db, | |||||
cur=None, | cur=None, | ||||
) -> Optional[PartialBranches]: | ) -> Optional[PartialBranches]: | ||||
if snapshot_id == EMPTY_SNAPSHOT_ID: | if snapshot_id == EMPTY_SNAPSHOT_ID: | ||||
return PartialBranches(id=snapshot_id, branches={}, next_branch=None,) | return PartialBranches(id=snapshot_id, branches={}, next_branch=None,) | ||||
branches = {} | branches = {} | ||||
next_branch = None | next_branch = None | ||||
Show All 34 Lines | ) -> Optional[PartialBranches]: | ||||
return PartialBranches( | return PartialBranches( | ||||
id=snapshot_id, branches=branches, next_branch=next_branch, | id=snapshot_id, branches=branches, next_branch=next_branch, | ||||
) | ) | ||||
return None | return None | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def snapshot_get_random(self, db=None, cur=None) -> Sha1Git: | def snapshot_get_random(self, *, db: Db, cur=None) -> Sha1Git: | ||||
return db.snapshot_get_random(cur) | return db.snapshot_get_random(cur) | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def origin_visit_add( | def origin_visit_add( | ||||
self, visits: List[OriginVisit], db=None, cur=None | self, visits: List[OriginVisit], *, db: Db, cur=None | ||||
) -> Iterable[OriginVisit]: | ) -> Iterable[OriginVisit]: | ||||
for visit in visits: | for visit in visits: | ||||
origin = self.origin_get([visit.origin], db=db, cur=cur)[0] | origin = self.origin_get([visit.origin], db=db, cur=cur)[0] | ||||
if not origin: # Cannot add a visit without an origin | if not origin: # Cannot add a visit without an origin | ||||
raise StorageArgumentException("Unknown origin %s", visit.origin) | raise StorageArgumentException("Unknown origin %s", visit.origin) | ||||
all_visits = [] | all_visits = [] | ||||
nb_visits = 0 | nb_visits = 0 | ||||
Show All 30 Lines | ) -> None: | ||||
"""Add an origin visit status""" | """Add an origin visit status""" | ||||
self.journal_writer.origin_visit_status_add([visit_status]) | self.journal_writer.origin_visit_status_add([visit_status]) | ||||
db.origin_visit_status_add(visit_status, cur=cur) | db.origin_visit_status_add(visit_status, cur=cur) | ||||
@timed | @timed | ||||
@process_metrics | @process_metrics | ||||
@db_transaction() | @db_transaction() | ||||
def origin_visit_status_add( | def origin_visit_status_add( | ||||
self, visit_statuses: List[OriginVisitStatus], db=None, cur=None, | self, visit_statuses: List[OriginVisitStatus], *, db: Db, cur=None, | ||||
) -> Dict[str, int]: | ) -> Dict[str, int]: | ||||
visit_statuses_ = [] | visit_statuses_ = [] | ||||
# First round to check existence (fail early if any is ko) | # First round to check existence (fail early if any is ko) | ||||
for visit_status in visit_statuses: | for visit_status in visit_statuses: | ||||
origin_url = self.origin_get([visit_status.origin], db=db, cur=cur)[0] | origin_url = self.origin_get([visit_status.origin], db=db, cur=cur)[0] | ||||
if not origin_url: | if not origin_url: | ||||
raise StorageArgumentException(f"Unknown origin {visit_status.origin}") | raise StorageArgumentException(f"Unknown origin {visit_status.origin}") | ||||
Show All 21 Lines | class Storage: | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def origin_visit_status_get_latest( | def origin_visit_status_get_latest( | ||||
self, | self, | ||||
origin_url: str, | origin_url: str, | ||||
visit: int, | visit: int, | ||||
allowed_statuses: Optional[List[str]] = None, | allowed_statuses: Optional[List[str]] = None, | ||||
require_snapshot: bool = False, | require_snapshot: bool = False, | ||||
db=None, | *, | ||||
db: Db, | |||||
cur=None, | cur=None, | ||||
) -> Optional[OriginVisitStatus]: | ) -> Optional[OriginVisitStatus]: | ||||
if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES): | if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES): | ||||
raise StorageArgumentException( | raise StorageArgumentException( | ||||
f"Unknown allowed statuses {','.join(allowed_statuses)}, only " | f"Unknown allowed statuses {','.join(allowed_statuses)}, only " | ||||
f"{','.join(VISIT_STATUSES)} authorized" | f"{','.join(VISIT_STATUSES)} authorized" | ||||
) | ) | ||||
row_d = db.origin_visit_status_get_latest( | row_d = db.origin_visit_status_get_latest( | ||||
origin_url, visit, allowed_statuses, require_snapshot, cur=cur | origin_url, visit, allowed_statuses, require_snapshot, cur=cur | ||||
) | ) | ||||
if not row_d: | if not row_d: | ||||
return None | return None | ||||
return OriginVisitStatus(**row_d) | return OriginVisitStatus(**row_d) | ||||
@timed | @timed | ||||
@db_transaction(statement_timeout=500) | @db_transaction(statement_timeout=500) | ||||
def origin_visit_get( | def origin_visit_get( | ||||
self, | self, | ||||
origin: str, | origin: str, | ||||
page_token: Optional[str] = None, | page_token: Optional[str] = None, | ||||
order: ListOrder = ListOrder.ASC, | order: ListOrder = ListOrder.ASC, | ||||
limit: int = 10, | limit: int = 10, | ||||
db=None, | *, | ||||
db: Db, | |||||
cur=None, | cur=None, | ||||
) -> PagedResult[OriginVisit]: | ) -> PagedResult[OriginVisit]: | ||||
page_token = page_token or "0" | page_token = page_token or "0" | ||||
if not isinstance(order, ListOrder): | if not isinstance(order, ListOrder): | ||||
raise StorageArgumentException("order must be a ListOrder value") | raise StorageArgumentException("order must be a ListOrder value") | ||||
if not isinstance(page_token, str): | if not isinstance(page_token, str): | ||||
raise StorageArgumentException("page_token must be a string.") | raise StorageArgumentException("page_token must be a string.") | ||||
Show All 20 Lines | ) -> PagedResult[OriginVisit]: | ||||
visits = visits[:limit] | visits = visits[:limit] | ||||
next_page_token = str(visits[-1].visit) | next_page_token = str(visits[-1].visit) | ||||
return PagedResult(results=visits, next_page_token=next_page_token) | return PagedResult(results=visits, next_page_token=next_page_token) | ||||
@timed | @timed | ||||
@db_transaction(statement_timeout=500) | @db_transaction(statement_timeout=500) | ||||
def origin_visit_find_by_date( | def origin_visit_find_by_date( | ||||
self, origin: str, visit_date: datetime.datetime, db=None, cur=None | self, origin: str, visit_date: datetime.datetime, *, db: Db, cur=None | ||||
) -> Optional[OriginVisit]: | ) -> Optional[OriginVisit]: | ||||
row_d = db.origin_visit_find_by_date(origin, visit_date, cur=cur) | row_d = db.origin_visit_find_by_date(origin, visit_date, cur=cur) | ||||
if not row_d: | if not row_d: | ||||
return None | return None | ||||
return OriginVisit( | return OriginVisit( | ||||
origin=row_d["origin"], | origin=row_d["origin"], | ||||
visit=row_d["visit"], | visit=row_d["visit"], | ||||
date=row_d["date"], | date=row_d["date"], | ||||
type=row_d["type"], | type=row_d["type"], | ||||
) | ) | ||||
@timed | @timed | ||||
@db_transaction(statement_timeout=500) | @db_transaction(statement_timeout=500) | ||||
def origin_visit_get_by( | def origin_visit_get_by( | ||||
self, origin: str, visit: int, db=None, cur=None | self, origin: str, visit: int, *, db: Db, cur=None | ||||
) -> Optional[OriginVisit]: | ) -> Optional[OriginVisit]: | ||||
row = db.origin_visit_get(origin, visit, cur) | row = db.origin_visit_get(origin, visit, cur) | ||||
if row: | if row: | ||||
row_d = dict(zip(db.origin_visit_get_cols, row)) | row_d = dict(zip(db.origin_visit_get_cols, row)) | ||||
return OriginVisit( | return OriginVisit( | ||||
origin=row_d["origin"], | origin=row_d["origin"], | ||||
visit=row_d["visit"], | visit=row_d["visit"], | ||||
date=row_d["date"], | date=row_d["date"], | ||||
type=row_d["type"], | type=row_d["type"], | ||||
) | ) | ||||
return None | return None | ||||
@timed | @timed | ||||
@db_transaction(statement_timeout=4000) | @db_transaction(statement_timeout=4000) | ||||
def origin_visit_get_latest( | def origin_visit_get_latest( | ||||
self, | self, | ||||
origin: str, | origin: str, | ||||
type: Optional[str] = None, | type: Optional[str] = None, | ||||
allowed_statuses: Optional[List[str]] = None, | allowed_statuses: Optional[List[str]] = None, | ||||
require_snapshot: bool = False, | require_snapshot: bool = False, | ||||
db=None, | *, | ||||
db: Db, | |||||
cur=None, | cur=None, | ||||
) -> Optional[OriginVisit]: | ) -> Optional[OriginVisit]: | ||||
if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES): | if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES): | ||||
raise StorageArgumentException( | raise StorageArgumentException( | ||||
f"Unknown allowed statuses {','.join(allowed_statuses)}, only " | f"Unknown allowed statuses {','.join(allowed_statuses)}, only " | ||||
f"{','.join(VISIT_STATUSES)} authorized" | f"{','.join(VISIT_STATUSES)} authorized" | ||||
) | ) | ||||
Show All 19 Lines | |||||
@db_transaction(statement_timeout=500) | @db_transaction(statement_timeout=500) | ||||
def origin_visit_status_get( | def origin_visit_status_get( | ||||
self, | self, | ||||
origin: str, | origin: str, | ||||
visit: int, | visit: int, | ||||
page_token: Optional[str] = None, | page_token: Optional[str] = None, | ||||
order: ListOrder = ListOrder.ASC, | order: ListOrder = ListOrder.ASC, | ||||
limit: int = 10, | limit: int = 10, | ||||
db=None, | *, | ||||
db: Db, | |||||
cur=None, | cur=None, | ||||
) -> PagedResult[OriginVisitStatus]: | ) -> PagedResult[OriginVisitStatus]: | ||||
next_page_token = None | next_page_token = None | ||||
date_from = None | date_from = None | ||||
if page_token is not None: | if page_token is not None: | ||||
date_from = datetime.datetime.fromisoformat(page_token) | date_from = datetime.datetime.fromisoformat(page_token) | ||||
visit_statuses: List[OriginVisitStatus] = [] | visit_statuses: List[OriginVisitStatus] = [] | ||||
Show All 10 Lines | ) -> PagedResult[OriginVisitStatus]: | ||||
# excluding that visit status from the result to respect the limit size | # excluding that visit status from the result to respect the limit size | ||||
visit_statuses = visit_statuses[:limit] | visit_statuses = visit_statuses[:limit] | ||||
return PagedResult(results=visit_statuses, next_page_token=next_page_token) | return PagedResult(results=visit_statuses, next_page_token=next_page_token) | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def origin_visit_status_get_random( | def origin_visit_status_get_random( | ||||
self, type: str, db=None, cur=None | self, type: str, *, db: Db, cur=None | ||||
) -> Optional[OriginVisitStatus]: | ) -> Optional[OriginVisitStatus]: | ||||
row = db.origin_visit_get_random(type, cur) | row = db.origin_visit_get_random(type, cur) | ||||
if row is not None: | if row is not None: | ||||
row_d = dict(zip(db.origin_visit_status_cols, row)) | row_d = dict(zip(db.origin_visit_status_cols, row)) | ||||
return OriginVisitStatus(**row_d) | return OriginVisitStatus(**row_d) | ||||
return None | return None | ||||
@timed | @timed | ||||
@db_transaction(statement_timeout=2000) | @db_transaction(statement_timeout=2000) | ||||
def object_find_by_sha1_git( | def object_find_by_sha1_git( | ||||
self, ids: List[Sha1Git], db=None, cur=None | self, ids: List[Sha1Git], *, db: Db, cur=None | ||||
) -> Dict[Sha1Git, List[Dict]]: | ) -> Dict[Sha1Git, List[Dict]]: | ||||
ret: Dict[Sha1Git, List[Dict]] = {id: [] for id in ids} | ret: Dict[Sha1Git, List[Dict]] = {id: [] for id in ids} | ||||
for retval in db.object_find_by_sha1_git(ids, cur=cur): | for retval in db.object_find_by_sha1_git(ids, cur=cur): | ||||
if retval[1]: | if retval[1]: | ||||
ret[retval[0]].append( | ret[retval[0]].append( | ||||
dict(zip(db.object_find_by_sha1_git_cols, retval)) | dict(zip(db.object_find_by_sha1_git_cols, retval)) | ||||
) | ) | ||||
return ret | return ret | ||||
@timed | @timed | ||||
@db_transaction(statement_timeout=500) | @db_transaction(statement_timeout=500) | ||||
def origin_get( | def origin_get( | ||||
self, origins: List[str], db=None, cur=None | self, origins: List[str], *, db: Db, cur=None | ||||
) -> Iterable[Optional[Origin]]: | ) -> Iterable[Optional[Origin]]: | ||||
rows = db.origin_get_by_url(origins, cur) | rows = db.origin_get_by_url(origins, cur) | ||||
result: List[Optional[Origin]] = [] | result: List[Optional[Origin]] = [] | ||||
for row in rows: | for row in rows: | ||||
origin_d = dict(zip(db.origin_cols, row)) | origin_d = dict(zip(db.origin_cols, row)) | ||||
url = origin_d["url"] | url = origin_d["url"] | ||||
result.append(None if url is None else Origin(url=url)) | result.append(None if url is None else Origin(url=url)) | ||||
return result | return result | ||||
@timed | @timed | ||||
@db_transaction(statement_timeout=500) | @db_transaction(statement_timeout=500) | ||||
def origin_get_by_sha1( | def origin_get_by_sha1( | ||||
self, sha1s: List[bytes], db=None, cur=None | self, sha1s: List[bytes], *, db: Db, cur=None | ||||
) -> List[Optional[Dict[str, Any]]]: | ) -> List[Optional[Dict[str, Any]]]: | ||||
return [ | return [ | ||||
dict(zip(db.origin_cols, row)) if row[0] else None | dict(zip(db.origin_cols, row)) if row[0] else None | ||||
for row in db.origin_get_by_sha1(sha1s, cur) | for row in db.origin_get_by_sha1(sha1s, cur) | ||||
] | ] | ||||
@timed | @timed | ||||
@db_transaction_generator() | @db_transaction_generator() | ||||
def origin_get_range(self, origin_from=1, origin_count=100, db=None, cur=None): | def origin_get_range(self, origin_from=1, origin_count=100, *, db: Db, cur=None): | ||||
for origin in db.origin_get_range(origin_from, origin_count, cur): | for origin in db.origin_get_range(origin_from, origin_count, cur): | ||||
yield dict(zip(db.origin_get_range_cols, origin)) | yield dict(zip(db.origin_get_range_cols, origin)) | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def origin_list( | def origin_list( | ||||
self, page_token: Optional[str] = None, limit: int = 100, *, db=None, cur=None | self, page_token: Optional[str] = None, limit: int = 100, *, db: Db, cur=None | ||||
) -> PagedResult[Origin]: | ) -> PagedResult[Origin]: | ||||
page_token = page_token or "0" | page_token = page_token or "0" | ||||
if not isinstance(page_token, str): | if not isinstance(page_token, str): | ||||
raise StorageArgumentException("page_token must be a string.") | raise StorageArgumentException("page_token must be a string.") | ||||
origin_from = int(page_token) | origin_from = int(page_token) | ||||
next_page_token = None | next_page_token = None | ||||
origins: List[Origin] = [] | origins: List[Origin] = [] | ||||
Show All 17 Lines | |||||
def origin_search( | def origin_search( | ||||
self, | self, | ||||
url_pattern: str, | url_pattern: str, | ||||
page_token: Optional[str] = None, | page_token: Optional[str] = None, | ||||
limit: int = 50, | limit: int = 50, | ||||
regexp: bool = False, | regexp: bool = False, | ||||
with_visit: bool = False, | with_visit: bool = False, | ||||
visit_types: Optional[List[str]] = None, | visit_types: Optional[List[str]] = None, | ||||
db=None, | *, | ||||
db: Db, | |||||
cur=None, | cur=None, | ||||
) -> PagedResult[Origin]: | ) -> PagedResult[Origin]: | ||||
next_page_token = None | next_page_token = None | ||||
offset = int(page_token) if page_token else 0 | offset = int(page_token) if page_token else 0 | ||||
origins = [] | origins = [] | ||||
# Take one more origin so we can reuse it as the next page token if any | # Take one more origin so we can reuse it as the next page token if any | ||||
for origin in db.origin_search( | for origin in db.origin_search( | ||||
Show All 14 Lines | |||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def origin_count( | def origin_count( | ||||
self, | self, | ||||
url_pattern: str, | url_pattern: str, | ||||
regexp: bool = False, | regexp: bool = False, | ||||
with_visit: bool = False, | with_visit: bool = False, | ||||
db=None, | *, | ||||
db: Db, | |||||
cur=None, | cur=None, | ||||
) -> int: | ) -> int: | ||||
return db.origin_count(url_pattern, regexp, with_visit, cur) | return db.origin_count(url_pattern, regexp, with_visit, cur) | ||||
@timed | @timed | ||||
@process_metrics | @process_metrics | ||||
@db_transaction() | @db_transaction() | ||||
def origin_add(self, origins: List[Origin], db=None, cur=None) -> Dict[str, int]: | def origin_add(self, origins: List[Origin], *, db: Db, cur=None) -> Dict[str, int]: | ||||
urls = [o.url for o in origins] | urls = [o.url for o in origins] | ||||
known_origins = set(url for (url,) in db.origin_get_by_url(urls, cur)) | known_origins = set(url for (url,) in db.origin_get_by_url(urls, cur)) | ||||
# keep only one occurrence of each given origin while keeping the list | # keep only one occurrence of each given origin while keeping the list | ||||
# sorted as originally given | # sorted as originally given | ||||
to_add = sorted(set(urls) - known_origins, key=urls.index) | to_add = sorted(set(urls) - known_origins, key=urls.index) | ||||
self.journal_writer.origin_add([Origin(url=url) for url in to_add]) | self.journal_writer.origin_add([Origin(url=url) for url in to_add]) | ||||
added = 0 | added = 0 | ||||
for url in to_add: | for url in to_add: | ||||
if db.origin_add(url, cur): | if db.origin_add(url, cur): | ||||
added += 1 | added += 1 | ||||
return {"origin:add": added} | return {"origin:add": added} | ||||
@db_transaction(statement_timeout=500) | @db_transaction(statement_timeout=500) | ||||
def stat_counters(self, db=None, cur=None): | def stat_counters(self, *, db: Db, cur=None): | ||||
return {k: v for (k, v) in db.stat_counters()} | return {k: v for (k, v) in db.stat_counters()} | ||||
@db_transaction() | @db_transaction() | ||||
def refresh_stat_counters(self, db=None, cur=None): | def refresh_stat_counters(self, *, db: Db, cur=None): | ||||
keys = [ | keys = [ | ||||
"content", | "content", | ||||
"directory", | "directory", | ||||
"directory_entry_dir", | "directory_entry_dir", | ||||
"directory_entry_file", | "directory_entry_file", | ||||
"directory_entry_rev", | "directory_entry_rev", | ||||
"origin", | "origin", | ||||
"origin_visit", | "origin_visit", | ||||
▲ Show 20 Lines • Show All 48 Lines • ▼ Show 20 Lines | |||||
@db_transaction() | @db_transaction() | ||||
def raw_extrinsic_metadata_get( | def raw_extrinsic_metadata_get( | ||||
self, | self, | ||||
target: ExtendedSWHID, | target: ExtendedSWHID, | ||||
authority: MetadataAuthority, | authority: MetadataAuthority, | ||||
after: Optional[datetime.datetime] = None, | after: Optional[datetime.datetime] = None, | ||||
page_token: Optional[bytes] = None, | page_token: Optional[bytes] = None, | ||||
limit: int = 1000, | limit: int = 1000, | ||||
db=None, | *, | ||||
db: Db, | |||||
cur=None, | cur=None, | ||||
) -> PagedResult[RawExtrinsicMetadata]: | ) -> PagedResult[RawExtrinsicMetadata]: | ||||
if page_token: | if page_token: | ||||
(after_time, after_fetcher) = msgpack_loads(base64.b64decode(page_token)) | (after_time, after_fetcher) = msgpack_loads(base64.b64decode(page_token)) | ||||
if after and after_time < after: | if after and after_time < after: | ||||
raise StorageArgumentException( | raise StorageArgumentException( | ||||
"page_token is inconsistent with the value of 'after'." | "page_token is inconsistent with the value of 'after'." | ||||
) | ) | ||||
else: | else: | ||||
after_time = after | after_time = after | ||||
after_fetcher = None | after_fetcher = None | ||||
authority_id = self._get_authority_id(authority, db, cur) | authority_id = self._get_authority_id(authority, db, cur) | ||||
if not authority_id: | if not authority_id: | ||||
return PagedResult(next_page_token=None, results=[],) | return PagedResult(next_page_token=None, results=[],) | ||||
rows = db.raw_extrinsic_metadata_get( | rows = db.raw_extrinsic_metadata_get( | ||||
type, str(target), authority_id, after_time, after_fetcher, limit + 1, cur, | str(target), authority_id, after_time, after_fetcher, limit + 1, cur, | ||||
) | ) | ||||
rows = [dict(zip(db.raw_extrinsic_metadata_get_cols, row)) for row in rows] | rows = [dict(zip(db.raw_extrinsic_metadata_get_cols, row)) for row in rows] | ||||
results = [] | results = [] | ||||
for row in rows: | for row in rows: | ||||
assert str(target) == row["raw_extrinsic_metadata.target"] | assert str(target) == row["raw_extrinsic_metadata.target"] | ||||
results.append(converters.db_to_raw_extrinsic_metadata(row)) | results.append(converters.db_to_raw_extrinsic_metadata(row)) | ||||
if len(results) > limit: | if len(results) > limit: | ||||
Show All 10 Lines | ) -> PagedResult[RawExtrinsicMetadata]: | ||||
).decode() | ).decode() | ||||
else: | else: | ||||
next_page_token = None | next_page_token = None | ||||
return PagedResult(next_page_token=next_page_token, results=results,) | return PagedResult(next_page_token=next_page_token, results=results,) | ||||
@db_transaction() | @db_transaction() | ||||
def raw_extrinsic_metadata_get_by_ids( | def raw_extrinsic_metadata_get_by_ids( | ||||
self, ids: List[Sha1Git], db=None, cur=None, | self, ids: List[Sha1Git], *, db: Db, cur=None, | ||||
) -> List[RawExtrinsicMetadata]: | ) -> List[RawExtrinsicMetadata]: | ||||
return [ | return [ | ||||
converters.db_to_raw_extrinsic_metadata( | converters.db_to_raw_extrinsic_metadata( | ||||
dict(zip(db.raw_extrinsic_metadata_get_cols, row)) | dict(zip(db.raw_extrinsic_metadata_get_cols, row)) | ||||
) | ) | ||||
for row in db.raw_extrinsic_metadata_get_by_ids(ids) | for row in db.raw_extrinsic_metadata_get_by_ids(ids) | ||||
] | ] | ||||
@db_transaction() | @db_transaction() | ||||
def raw_extrinsic_metadata_get_authorities( | def raw_extrinsic_metadata_get_authorities( | ||||
self, target: ExtendedSWHID, db=None, cur=None, | self, target: ExtendedSWHID, *, db: Db, cur=None, | ||||
) -> List[MetadataAuthority]: | ) -> List[MetadataAuthority]: | ||||
return [ | return [ | ||||
MetadataAuthority( | MetadataAuthority( | ||||
type=MetadataAuthorityType(authority_type), url=authority_url | type=MetadataAuthorityType(authority_type), url=authority_url | ||||
) | ) | ||||
for ( | for ( | ||||
authority_type, | authority_type, | ||||
authority_url, | authority_url, | ||||
) in db.raw_extrinsic_metadata_get_authorities(str(target), cur) | ) in db.raw_extrinsic_metadata_get_authorities(str(target), cur) | ||||
] | ] | ||||
@timed | @timed | ||||
@process_metrics | @process_metrics | ||||
@db_transaction() | @db_transaction() | ||||
def metadata_fetcher_add( | def metadata_fetcher_add( | ||||
self, fetchers: List[MetadataFetcher], db=None, cur=None | self, fetchers: List[MetadataFetcher], *, db: Db, cur=None | ||||
) -> Dict[str, int]: | ) -> Dict[str, int]: | ||||
fetchers = list(fetchers) | fetchers = list(fetchers) | ||||
self.journal_writer.metadata_fetcher_add(fetchers) | self.journal_writer.metadata_fetcher_add(fetchers) | ||||
count = 0 | count = 0 | ||||
for fetcher in fetchers: | for fetcher in fetchers: | ||||
db.metadata_fetcher_add(fetcher.name, fetcher.version, cur=cur) | db.metadata_fetcher_add(fetcher.name, fetcher.version, cur=cur) | ||||
count += 1 | count += 1 | ||||
return {"metadata_fetcher:add": count} | return {"metadata_fetcher:add": count} | ||||
@timed | @timed | ||||
@db_transaction(statement_timeout=500) | @db_transaction(statement_timeout=500) | ||||
def metadata_fetcher_get( | def metadata_fetcher_get( | ||||
self, name: str, version: str, db=None, cur=None | self, name: str, version: str, *, db: Db, cur=None | ||||
) -> Optional[MetadataFetcher]: | ) -> Optional[MetadataFetcher]: | ||||
row = db.metadata_fetcher_get(name, version, cur=cur) | row = db.metadata_fetcher_get(name, version, cur=cur) | ||||
if not row: | if not row: | ||||
return None | return None | ||||
return MetadataFetcher.from_dict(dict(zip(db.metadata_fetcher_cols, row))) | return MetadataFetcher.from_dict(dict(zip(db.metadata_fetcher_cols, row))) | ||||
@timed | @timed | ||||
@process_metrics | @process_metrics | ||||
@db_transaction() | @db_transaction() | ||||
def metadata_authority_add( | def metadata_authority_add( | ||||
self, authorities: List[MetadataAuthority], db=None, cur=None | self, authorities: List[MetadataAuthority], *, db: Db, cur=None | ||||
) -> Dict[str, int]: | ) -> Dict[str, int]: | ||||
authorities = list(authorities) | authorities = list(authorities) | ||||
self.journal_writer.metadata_authority_add(authorities) | self.journal_writer.metadata_authority_add(authorities) | ||||
count = 0 | count = 0 | ||||
for authority in authorities: | for authority in authorities: | ||||
db.metadata_authority_add(authority.type.value, authority.url, cur=cur) | db.metadata_authority_add(authority.type.value, authority.url, cur=cur) | ||||
count += 1 | count += 1 | ||||
return {"metadata_authority:add": count} | return {"metadata_authority:add": count} | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def metadata_authority_get( | def metadata_authority_get( | ||||
self, type: MetadataAuthorityType, url: str, db=None, cur=None | self, type: MetadataAuthorityType, url: str, *, db: Db, cur=None | ||||
) -> Optional[MetadataAuthority]: | ) -> Optional[MetadataAuthority]: | ||||
row = db.metadata_authority_get(type.value, url, cur=cur) | row = db.metadata_authority_get(type.value, url, cur=cur) | ||||
if not row: | if not row: | ||||
return None | return None | ||||
return MetadataAuthority.from_dict(dict(zip(db.metadata_authority_cols, row))) | return MetadataAuthority.from_dict(dict(zip(db.metadata_authority_cols, row))) | ||||
def clear_buffers(self, object_types: Sequence[str] = ()) -> None: | def clear_buffers(self, object_types: Sequence[str] = ()) -> None: | ||||
"""Do nothing | """Do nothing | ||||
Show All 20 Lines |