Changeset View
Changeset View
Standalone View
Standalone View
swh/storage/storage.py
# Copyright (C) 2015-2020 The Software Heritage developers | # Copyright (C) 2015-2020 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
import contextlib | import contextlib | ||||
import copy | import copy | ||||
import datetime | import datetime | ||||
import itertools | import itertools | ||||
import json | import json | ||||
from collections import defaultdict | from collections import defaultdict | ||||
from concurrent.futures import ThreadPoolExecutor | from concurrent.futures import ThreadPoolExecutor | ||||
from contextlib import contextmanager | from contextlib import contextmanager | ||||
from typing import Any, Dict, List, Optional | from typing import Any, Dict, Iterable, List, Optional, Union | ||||
import attr | |||||
import dateutil.parser | import dateutil.parser | ||||
import psycopg2 | import psycopg2 | ||||
import psycopg2.pool | import psycopg2.pool | ||||
import psycopg2.errors | import psycopg2.errors | ||||
from swh.model.model import SHA1_SIZE | from swh.model.model import ( | ||||
from swh.model.hashutil import ALGORITHMS, hash_to_bytes, hash_to_hex | SkippedContent, Content, Directory, Revision, Release, | ||||
Snapshot, Origin, SHA1_SIZE | |||||
) | |||||
from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex | |||||
from swh.objstorage import get_objstorage | from swh.objstorage import get_objstorage | ||||
from swh.objstorage.exc import ObjNotFoundError | from swh.objstorage.exc import ObjNotFoundError | ||||
try: | try: | ||||
from swh.journal.writer import get_journal_writer | from swh.journal.writer import get_journal_writer | ||||
except ImportError: | except ImportError: | ||||
get_journal_writer = None # type: ignore | get_journal_writer = None # type: ignore | ||||
# mypy limitation, see https://github.com/python/mypy/issues/1153 | # mypy limitation, see https://github.com/python/mypy/issues/1153 | ||||
▲ Show 20 Lines • Show All 114 Lines • ▼ Show 20 Lines | def _content_unique_key(self, hash, db): | ||||
aggregation of keys. | aggregation of keys. | ||||
""" | """ | ||||
keys = db.content_hash_keys | keys = db.content_hash_keys | ||||
if isinstance(hash, tuple): | if isinstance(hash, tuple): | ||||
return hash | return hash | ||||
return tuple([hash[k] for k in keys]) | return tuple([hash[k] for k in keys]) | ||||
@staticmethod | |||||
def _content_normalize(d): | |||||
d = d.copy() | |||||
if 'status' not in d: | |||||
d['status'] = 'visible' | |||||
return d | |||||
@staticmethod | |||||
def _content_validate(d): | |||||
"""Sanity checks on status / reason / length, that postgresql | |||||
doesn't enforce.""" | |||||
if d['status'] not in ('visible', 'hidden'): | |||||
raise StorageArgumentException( | |||||
'Invalid content status: {}'.format(d['status'])) | |||||
if d.get('reason') is not None: | |||||
raise StorageArgumentException( | |||||
'Must not provide a reason if content is present.') | |||||
if d['length'] is None or d['length'] < 0: | |||||
raise StorageArgumentException('Content length must be positive.') | |||||
def _content_add_metadata(self, db, cur, content): | def _content_add_metadata(self, db, cur, content): | ||||
"""Add content to the postgresql database but not the object storage. | """Add content to the postgresql database but not the object storage. | ||||
""" | """ | ||||
# create temporary table for metadata injection | # create temporary table for metadata injection | ||||
db.mktemp('content', cur) | db.mktemp('content', cur) | ||||
with convert_validation_exceptions(): | db.copy_to((c.to_dict() for c in content), 'tmp_content', | ||||
db.copy_to(content, 'tmp_content', | |||||
db.content_add_keys, cur) | db.content_add_keys, cur) | ||||
# move metadata in place | # move metadata in place | ||||
try: | try: | ||||
db.content_add_from_temp(cur) | db.content_add_from_temp(cur) | ||||
except psycopg2.IntegrityError as e: | except psycopg2.IntegrityError as e: | ||||
if e.diag.sqlstate == '23505' and \ | if e.diag.sqlstate == '23505' and \ | ||||
e.diag.table_name == 'content': | e.diag.table_name == 'content': | ||||
constraint_to_hash_name = { | constraint_to_hash_name = { | ||||
'content_pkey': 'sha1', | 'content_pkey': 'sha1', | ||||
'content_sha1_git_idx': 'sha1_git', | 'content_sha1_git_idx': 'sha1_git', | ||||
'content_sha256_idx': 'sha256', | 'content_sha256_idx': 'sha256', | ||||
} | } | ||||
colliding_hash_name = constraint_to_hash_name \ | colliding_hash_name = constraint_to_hash_name \ | ||||
.get(e.diag.constraint_name) | .get(e.diag.constraint_name) | ||||
raise HashCollision(colliding_hash_name) from None | raise HashCollision(colliding_hash_name) from None | ||||
else: | else: | ||||
raise | raise | ||||
@timed | @timed | ||||
@process_metrics | @process_metrics | ||||
@db_transaction() | @db_transaction() | ||||
def content_add(self, content, db=None, cur=None): | def content_add( | ||||
content = [dict(c.items()) for c in content] # semi-shallow copy | self, content: Iterable[Content], db=None, cur=None) -> Dict: | ||||
now = datetime.datetime.now(tz=datetime.timezone.utc) | now = datetime.datetime.now(tz=datetime.timezone.utc) | ||||
for item in content: | content = [attr.evolve(c, ctime=now) for c in content] | ||||
item['ctime'] = now | |||||
content = [self._content_normalize(c) for c in content] | missing = list(self.content_missing( | ||||
for c in content: | map(Content.to_dict, content), key_hash='sha1_git')) | ||||
self._content_validate(c) | content = [c for c in content if c.sha1_git in missing] | ||||
missing = list(self.content_missing(content, key_hash='sha1_git')) | |||||
content = [c for c in content if c['sha1_git'] in missing] | |||||
if self.journal_writer: | if self.journal_writer: | ||||
for item in content: | for item in content: | ||||
if 'data' in item: | if item.data: | ||||
item = item.copy() | item = attr.evolve(item, data=None) | ||||
del item['data'] | |||||
self.journal_writer.write_addition('content', item) | self.journal_writer.write_addition('content', item) | ||||
def add_to_objstorage(): | def add_to_objstorage(): | ||||
"""Add to objstorage the new missing_content | """Add to objstorage the new missing_content | ||||
Returns: | Returns: | ||||
Sum of all the content's data length pushed to the | Sum of all the content's data length pushed to the | ||||
objstorage. Content present twice is only sent once. | objstorage. Content present twice is only sent once. | ||||
""" | """ | ||||
content_bytes_added = 0 | content_bytes_added = 0 | ||||
data = {} | data = {} | ||||
for cont in content: | for cont in content: | ||||
if cont['sha1'] not in data: | if cont.sha1 not in data: | ||||
data[cont['sha1']] = cont['data'] | data[cont.sha1] = cont.data | ||||
content_bytes_added += max(0, cont['length']) | content_bytes_added += max(0, cont.length) | ||||
# FIXME: Since we do the filtering anyway now, we might as | # FIXME: Since we do the filtering anyway now, we might as | ||||
# well make the objstorage's add_batch call return what we | # well make the objstorage's add_batch call return what we | ||||
# want here (real bytes added)... that'd simplify this... | # want here (real bytes added)... that'd simplify this... | ||||
self.objstorage.add_batch(data) | self.objstorage.add_batch(data) | ||||
return content_bytes_added | return content_bytes_added | ||||
with ThreadPoolExecutor(max_workers=1) as executor: | with ThreadPoolExecutor(max_workers=1) as executor: | ||||
Show All 25 Lines | def content_update(self, content, keys=[], db=None, cur=None): | ||||
with convert_validation_exceptions(): | with convert_validation_exceptions(): | ||||
db.copy_to(content, 'tmp_content', select_keys, cur) | db.copy_to(content, 'tmp_content', select_keys, cur) | ||||
db.content_update_from_temp(keys_to_update=keys, | db.content_update_from_temp(keys_to_update=keys, | ||||
cur=cur) | cur=cur) | ||||
@timed | @timed | ||||
@process_metrics | @process_metrics | ||||
@db_transaction() | @db_transaction() | ||||
def content_add_metadata(self, content, db=None, cur=None): | def content_add_metadata(self, content: Iterable[Content], | ||||
content = [self._content_normalize(c) for c in content] | db=None, cur=None) -> Dict: | ||||
for c in content: | content = list(content) | ||||
self._content_validate(c) | missing = self.content_missing( | ||||
(c.to_dict() for c in content), key_hash='sha1_git') | |||||
missing = self.content_missing(content, key_hash='sha1_git') | content = [c for c in content if c.sha1_git in missing] | ||||
content = [c for c in content if c['sha1_git'] in missing] | |||||
if self.journal_writer: | if self.journal_writer: | ||||
for item in itertools.chain(content): | for item in itertools.chain(content): | ||||
assert 'data' not in content | assert item.data is None | ||||
ardumont: here the bug i mentioned orally (i was a bit fuzzy and wrong about my description though ;)
we… | |||||
self.journal_writer.write_addition('content', item) | self.journal_writer.write_addition('content', item) | ||||
self._content_add_metadata(db, cur, content) | self._content_add_metadata(db, cur, content) | ||||
return { | return { | ||||
'content:add': len(content), | 'content:add': len(content), | ||||
} | } | ||||
▲ Show 20 Lines • Show All 93 Lines • ▼ Show 20 Lines | class Storage(): | ||||
@db_transaction_generator() | @db_transaction_generator() | ||||
def content_missing_per_sha1_git(self, contents, db=None, cur=None): | def content_missing_per_sha1_git(self, contents, db=None, cur=None): | ||||
for obj in db.content_missing_per_sha1_git(contents, cur): | for obj in db.content_missing_per_sha1_git(contents, cur): | ||||
yield obj[0] | yield obj[0] | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def content_find(self, content, db=None, cur=None): | def content_find(self, content, db=None, cur=None): | ||||
if not set(content).intersection(ALGORITHMS): | if not set(content).intersection(DEFAULT_ALGORITHMS): | ||||
raise StorageArgumentException( | raise StorageArgumentException( | ||||
'content keys must contain at least one of: ' | 'content keys must contain at least one of: ' | ||||
'sha1, sha1_git, sha256, blake2s256') | 'sha1, sha1_git, sha256, blake2s256') | ||||
contents = db.content_find(sha1=content.get('sha1'), | contents = db.content_find(sha1=content.get('sha1'), | ||||
sha1_git=content.get('sha1_git'), | sha1_git=content.get('sha1_git'), | ||||
sha256=content.get('sha256'), | sha256=content.get('sha256'), | ||||
blake2s256=content.get('blake2s256'), | blake2s256=content.get('blake2s256'), | ||||
Show All 29 Lines | def _skipped_content_validate(d): | ||||
if d.get('reason') is None: | if d.get('reason') is None: | ||||
raise StorageArgumentException( | raise StorageArgumentException( | ||||
'Must provide a reason if content is absent.') | 'Must provide a reason if content is absent.') | ||||
if d['length'] < -1: | if d['length'] < -1: | ||||
raise StorageArgumentException( | raise StorageArgumentException( | ||||
'Content length must be positive or -1.') | 'Content length must be positive or -1.') | ||||
def _skipped_content_add_metadata(self, db, cur, content): | def _skipped_content_add_metadata( | ||||
content = \ | self, db, cur, content: Iterable[SkippedContent]): | ||||
[cont.copy() for cont in content] | |||||
origin_ids = db.origin_id_get_by_url( | origin_ids = db.origin_id_get_by_url( | ||||
[cont.get('origin') for cont in content], | [cont.origin for cont in content], | ||||
cur=cur) | cur=cur) | ||||
for (cont, origin_id) in zip(content, origin_ids): | content = [attr.evolve(c, origin=origin_id) | ||||
if 'origin' in cont: | for (c, origin_id) in zip(content, origin_ids)] | ||||
cont['origin'] = origin_id | |||||
db.mktemp('skipped_content', cur) | db.mktemp('skipped_content', cur) | ||||
with convert_validation_exceptions(): | db.copy_to([c.to_dict() for c in content], 'tmp_skipped_content', | ||||
db.copy_to(content, 'tmp_skipped_content', | |||||
db.skipped_content_keys, cur) | db.skipped_content_keys, cur) | ||||
# move metadata in place | # move metadata in place | ||||
db.skipped_content_add_from_temp(cur) | db.skipped_content_add_from_temp(cur) | ||||
@timed | @timed | ||||
@process_metrics | @process_metrics | ||||
@db_transaction() | @db_transaction() | ||||
def skipped_content_add(self, content, db=None, cur=None): | def skipped_content_add(self, content: Iterable[SkippedContent], | ||||
content = [dict(c.items()) for c in content] # semi-shallow copy | db=None, cur=None) -> Dict: | ||||
now = datetime.datetime.now(tz=datetime.timezone.utc) | now = datetime.datetime.now(tz=datetime.timezone.utc) | ||||
for item in content: | content = [attr.evolve(c, ctime=now) for c in content] | ||||
item['ctime'] = now | |||||
content = [self._skipped_content_normalize(c) for c in content] | |||||
for c in content: | |||||
self._skipped_content_validate(c) | |||||
missing_contents = self.skipped_content_missing(content) | missing_contents = self.skipped_content_missing( | ||||
c.to_dict() for c in content) | |||||
content = [c for c in content | content = [c for c in content | ||||
if any(all(c.get(algo) == missing_content.get(algo) | if any(all(c.get_hash(algo) == missing_content.get(algo) | ||||
for algo in ALGORITHMS) | for algo in DEFAULT_ALGORITHMS) | ||||
for missing_content in missing_contents)] | for missing_content in missing_contents)] | ||||
if self.journal_writer: | if self.journal_writer: | ||||
for item in content: | for item in content: | ||||
self.journal_writer.write_addition('content', item) | self.journal_writer.write_addition('content', item) | ||||
self._skipped_content_add_metadata(db, cur, content) | self._skipped_content_add_metadata(db, cur, content) | ||||
return { | return { | ||||
'skipped_content:add': len(content), | 'skipped_content:add': len(content), | ||||
} | } | ||||
@timed | @timed | ||||
@db_transaction_generator() | @db_transaction_generator() | ||||
def skipped_content_missing(self, contents, db=None, cur=None): | def skipped_content_missing(self, contents, db=None, cur=None): | ||||
contents = list(contents) | |||||
for content in db.skipped_content_missing(contents, cur): | for content in db.skipped_content_missing(contents, cur): | ||||
yield dict(zip(db.content_hash_keys, content)) | yield dict(zip(db.content_hash_keys, content)) | ||||
@timed | @timed | ||||
@process_metrics | @process_metrics | ||||
@db_transaction() | @db_transaction() | ||||
def directory_add(self, directories, db=None, cur=None): | def directory_add(self, directories: Iterable[Directory], | ||||
db=None, cur=None) -> Dict: | |||||
directories = list(directories) | directories = list(directories) | ||||
summary = {'directory:add': 0} | summary = {'directory:add': 0} | ||||
dirs = set() | dirs = set() | ||||
dir_entries = { | dir_entries: Dict[str, defaultdict] = { | ||||
'file': defaultdict(list), | 'file': defaultdict(list), | ||||
'dir': defaultdict(list), | 'dir': defaultdict(list), | ||||
'rev': defaultdict(list), | 'rev': defaultdict(list), | ||||
} | } | ||||
for cur_dir in directories: | for cur_dir in directories: | ||||
dir_id = cur_dir['id'] | dir_id = cur_dir.id | ||||
dirs.add(dir_id) | dirs.add(dir_id) | ||||
for src_entry in cur_dir['entries']: | for src_entry in cur_dir.entries: | ||||
entry = src_entry.copy() | entry = src_entry.to_dict() | ||||
entry['dir_id'] = dir_id | entry['dir_id'] = dir_id | ||||
if entry['type'] not in ('file', 'dir', 'rev'): | |||||
raise StorageArgumentException( | |||||
'Entry type must be file, dir, or rev; not %s' | |||||
% entry['type']) | |||||
dir_entries[entry['type']][dir_id].append(entry) | dir_entries[entry['type']][dir_id].append(entry) | ||||
dirs_missing = set(self.directory_missing(dirs, db=db, cur=cur)) | dirs_missing = set(self.directory_missing(dirs, db=db, cur=cur)) | ||||
if not dirs_missing: | if not dirs_missing: | ||||
return summary | return summary | ||||
if self.journal_writer: | if self.journal_writer: | ||||
self.journal_writer.write_additions( | self.journal_writer.write_additions( | ||||
'directory', | 'directory', | ||||
(dir_ for dir_ in directories | (dir_ for dir_ in directories | ||||
if dir_['id'] in dirs_missing)) | if dir_.id in dirs_missing)) | ||||
# Copy directory ids | # Copy directory ids | ||||
dirs_missing_dict = ({'id': dir} for dir in dirs_missing) | dirs_missing_dict = ({'id': dir} for dir in dirs_missing) | ||||
db.mktemp('directory', cur) | db.mktemp('directory', cur) | ||||
with convert_validation_exceptions(): | |||||
db.copy_to(dirs_missing_dict, 'tmp_directory', ['id'], cur) | db.copy_to(dirs_missing_dict, 'tmp_directory', ['id'], cur) | ||||
# Copy entries | # Copy entries | ||||
for entry_type, entry_list in dir_entries.items(): | for entry_type, entry_list in dir_entries.items(): | ||||
entries = itertools.chain.from_iterable( | entries = itertools.chain.from_iterable( | ||||
entries_for_dir | entries_for_dir | ||||
for dir_id, entries_for_dir | for dir_id, entries_for_dir | ||||
in entry_list.items() | in entry_list.items() | ||||
if dir_id in dirs_missing) | if dir_id in dirs_missing) | ||||
db.mktemp_dir_entry(entry_type) | db.mktemp_dir_entry(entry_type) | ||||
db.copy_to( | db.copy_to( | ||||
entries, | entries, | ||||
'tmp_directory_entry_%s' % entry_type, | 'tmp_directory_entry_%s' % entry_type, | ||||
['target', 'name', 'perms', 'dir_id'], | ['target', 'name', 'perms', 'dir_id'], | ||||
cur, | cur, | ||||
) | ) | ||||
# Do the final copy | # Do the final copy | ||||
db.directory_add_from_temp(cur) | db.directory_add_from_temp(cur) | ||||
summary['directory:add'] = len(dirs_missing) | summary['directory:add'] = len(dirs_missing) | ||||
return summary | return summary | ||||
@timed | @timed | ||||
@db_transaction_generator() | @db_transaction_generator() | ||||
def directory_missing(self, directories, db=None, cur=None): | def directory_missing(self, directories, db=None, cur=None): | ||||
for obj in db.directory_missing_from_list(directories, cur): | for obj in db.directory_missing_from_list(directories, cur): | ||||
Show All 20 Lines | class Storage(): | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def directory_get_random(self, db=None, cur=None): | def directory_get_random(self, db=None, cur=None): | ||||
return db.directory_get_random(cur) | return db.directory_get_random(cur) | ||||
@timed | @timed | ||||
@process_metrics | @process_metrics | ||||
@db_transaction() | @db_transaction() | ||||
def revision_add(self, revisions, db=None, cur=None): | def revision_add(self, revisions: Iterable[Revision], | ||||
db=None, cur=None) -> Dict: | |||||
revisions = list(revisions) | revisions = list(revisions) | ||||
summary = {'revision:add': 0} | summary = {'revision:add': 0} | ||||
revisions_missing = set(self.revision_missing( | revisions_missing = set(self.revision_missing( | ||||
set(revision['id'] for revision in revisions), | set(revision.id for revision in revisions), | ||||
db=db, cur=cur)) | db=db, cur=cur)) | ||||
if not revisions_missing: | if not revisions_missing: | ||||
return summary | return summary | ||||
db.mktemp_revision(cur) | db.mktemp_revision(cur) | ||||
revisions_filtered = [ | revisions_filtered = [ | ||||
revision for revision in revisions | revision for revision in revisions | ||||
if revision['id'] in revisions_missing] | if revision.id in revisions_missing] | ||||
if self.journal_writer: | if self.journal_writer: | ||||
self.journal_writer.write_additions('revision', revisions_filtered) | self.journal_writer.write_additions('revision', revisions_filtered) | ||||
revisions_filtered = map(converters.revision_to_db, revisions_filtered) | revisions_filtered = \ | ||||
list(map(converters.revision_to_db, revisions_filtered)) | |||||
parents_filtered = [] | parents_filtered: List[bytes] = [] | ||||
with convert_validation_exceptions(): | with convert_validation_exceptions(): | ||||
db.copy_to( | db.copy_to( | ||||
revisions_filtered, 'tmp_revision', db.revision_add_cols, | revisions_filtered, 'tmp_revision', db.revision_add_cols, | ||||
cur, | cur, | ||||
lambda rev: parents_filtered.extend(rev['parents'])) | lambda rev: parents_filtered.extend(rev['parents'])) | ||||
db.revision_add_from_temp(cur) | db.revision_add_from_temp(cur) | ||||
▲ Show 20 Lines • Show All 45 Lines • ▼ Show 20 Lines | class Storage(): | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def revision_get_random(self, db=None, cur=None): | def revision_get_random(self, db=None, cur=None): | ||||
return db.revision_get_random(cur) | return db.revision_get_random(cur) | ||||
@timed | @timed | ||||
@process_metrics | @process_metrics | ||||
@db_transaction() | @db_transaction() | ||||
def release_add(self, releases, db=None, cur=None): | def release_add( | ||||
self, releases: Iterable[Release], db=None, cur=None) -> Dict: | |||||
releases = list(releases) | releases = list(releases) | ||||
summary = {'release:add': 0} | summary = {'release:add': 0} | ||||
release_ids = set(release['id'] for release in releases) | release_ids = set(release.id for release in releases) | ||||
releases_missing = set(self.release_missing(release_ids, | releases_missing = set(self.release_missing(release_ids, | ||||
db=db, cur=cur)) | db=db, cur=cur)) | ||||
if not releases_missing: | if not releases_missing: | ||||
return summary | return summary | ||||
db.mktemp_release(cur) | db.mktemp_release(cur) | ||||
releases_missing = list(releases_missing) | |||||
releases_filtered = [ | releases_filtered = [ | ||||
release for release in releases | release for release in releases | ||||
if release['id'] in releases_missing | if release.id in releases_missing | ||||
] | ] | ||||
if self.journal_writer: | if self.journal_writer: | ||||
self.journal_writer.write_additions('release', releases_filtered) | self.journal_writer.write_additions('release', releases_filtered) | ||||
releases_filtered = map(converters.release_to_db, releases_filtered) | releases_filtered = \ | ||||
list(map(converters.release_to_db, releases_filtered)) | |||||
with convert_validation_exceptions(): | with convert_validation_exceptions(): | ||||
db.copy_to(releases_filtered, 'tmp_release', db.release_add_cols, | db.copy_to(releases_filtered, 'tmp_release', db.release_add_cols, | ||||
cur) | cur) | ||||
db.release_add_from_temp(cur) | db.release_add_from_temp(cur) | ||||
return {'release:add': len(releases_missing)} | return {'release:add': len(releases_missing)} | ||||
Show All 19 Lines | class Storage(): | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def release_get_random(self, db=None, cur=None): | def release_get_random(self, db=None, cur=None): | ||||
return db.release_get_random(cur) | return db.release_get_random(cur) | ||||
@timed | @timed | ||||
@process_metrics | @process_metrics | ||||
@db_transaction() | @db_transaction() | ||||
def snapshot_add(self, snapshots, db=None, cur=None): | def snapshot_add( | ||||
self, snapshots: Iterable[Snapshot], db=None, cur=None) -> Dict: | |||||
created_temp_table = False | created_temp_table = False | ||||
count = 0 | count = 0 | ||||
for snapshot in snapshots: | for snapshot in snapshots: | ||||
if not db.snapshot_exists(snapshot['id'], cur): | if not db.snapshot_exists(snapshot.id, cur): | ||||
if not created_temp_table: | if not created_temp_table: | ||||
db.mktemp_snapshot_branch(cur) | db.mktemp_snapshot_branch(cur) | ||||
created_temp_table = True | created_temp_table = True | ||||
try: | try: | ||||
db.copy_to( | db.copy_to( | ||||
( | ( | ||||
{ | { | ||||
'name': name, | 'name': name, | ||||
'target': info['target'] if info else None, | 'target': info.target if info else None, | ||||
'target_type': (info['target_type'] | 'target_type': (info.target_type.value | ||||
if info else None), | if info else None), | ||||
} | } | ||||
for name, info in snapshot['branches'].items() | for name, info in snapshot.branches.items() | ||||
), | ), | ||||
'tmp_snapshot_branch', | 'tmp_snapshot_branch', | ||||
['name', 'target', 'target_type'], | ['name', 'target', 'target_type'], | ||||
cur, | cur, | ||||
) | ) | ||||
except VALIDATION_EXCEPTIONS + (KeyError,) as e: | except VALIDATION_EXCEPTIONS + (KeyError,) as e: | ||||
raise StorageArgumentException(*e.args) | raise StorageArgumentException(*e.args) | ||||
if self.journal_writer: | if self.journal_writer: | ||||
self.journal_writer.write_addition('snapshot', snapshot) | self.journal_writer.write_addition('snapshot', snapshot) | ||||
db.snapshot_add(snapshot['id'], cur) | db.snapshot_add(snapshot.id, cur) | ||||
count += 1 | count += 1 | ||||
return {'snapshot:add': count} | return {'snapshot:add': count} | ||||
@timed | @timed | ||||
@db_transaction_generator() | @db_transaction_generator() | ||||
def snapshot_missing(self, snapshots, db=None, cur=None): | def snapshot_missing(self, snapshots, db=None, cur=None): | ||||
for obj in db.snapshot_missing_from_list(snapshots, cur): | for obj in db.snapshot_missing_from_list(snapshots, cur): | ||||
▲ Show 20 Lines • Show All 85 Lines • ▼ Show 20 Lines | class Storage(): | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def snapshot_get_random(self, db=None, cur=None): | def snapshot_get_random(self, db=None, cur=None): | ||||
return db.snapshot_get_random(cur) | return db.snapshot_get_random(cur) | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def origin_visit_add(self, origin, date, type, | def origin_visit_add( | ||||
db=None, cur=None): | self, origin, date, type, db=None, cur=None | ||||
) -> Optional[Dict[str, Union[str, int]]]: | |||||
origin_url = origin | origin_url = origin | ||||
if isinstance(date, str): | if isinstance(date, str): | ||||
# FIXME: Converge on iso8601 at some point | # FIXME: Converge on iso8601 at some point | ||||
date = dateutil.parser.parse(date) | date = dateutil.parser.parse(date) | ||||
with convert_validation_exceptions(): | with convert_validation_exceptions(): | ||||
visit_id = db.origin_visit_add(origin_url, date, type, cur) | visit_id = db.origin_visit_add(origin_url, date, type, cur) | ||||
Show All 9 Lines | def origin_visit_add( | ||||
send_metric('origin_visit:add', count=1, method_name='origin_visit') | send_metric('origin_visit:add', count=1, method_name='origin_visit') | ||||
return { | return { | ||||
'origin': origin_url, | 'origin': origin_url, | ||||
'visit': visit_id, | 'visit': visit_id, | ||||
} | } | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def origin_visit_update(self, origin, visit_id, status=None, | def origin_visit_update(self, origin: str, visit_id: int, | ||||
metadata=None, snapshot=None, | status: Optional[str] = None, | ||||
metadata: Optional[Dict] = None, | |||||
snapshot: Optional[bytes] = None, | |||||
db=None, cur=None): | db=None, cur=None): | ||||
if not isinstance(origin, str): | if not isinstance(origin, str): | ||||
raise StorageArgumentException( | raise StorageArgumentException( | ||||
'origin must be a string, not %r' % (origin,)) | 'origin must be a string, not %r' % (origin,)) | ||||
origin_url = origin | origin_url = origin | ||||
visit = db.origin_visit_get(origin_url, visit_id, cur=cur) | visit = db.origin_visit_get(origin_url, visit_id, cur=cur) | ||||
if not visit: | if not visit: | ||||
raise StorageArgumentException('Invalid visit_id for this origin.') | raise StorageArgumentException('Invalid visit_id for this origin.') | ||||
visit = dict(zip(db.origin_visit_get_cols, visit)) | visit = dict(zip(db.origin_visit_get_cols, visit)) | ||||
updates = {} | updates: Dict[str, Any] = {} | ||||
if status and status != visit['status']: | if status and status != visit['status']: | ||||
updates['status'] = status | updates['status'] = status | ||||
if metadata and metadata != visit['metadata']: | if metadata and metadata != visit['metadata']: | ||||
updates['metadata'] = metadata | updates['metadata'] = metadata | ||||
if snapshot and snapshot != visit['snapshot']: | if snapshot and snapshot != visit['snapshot']: | ||||
updates['snapshot'] = snapshot | updates['snapshot'] = snapshot | ||||
if updates: | if updates: | ||||
▲ Show 20 Lines • Show All 159 Lines • ▼ Show 20 Lines | class Storage(): | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def origin_count(self, url_pattern, regexp=False, | def origin_count(self, url_pattern, regexp=False, | ||||
with_visit=False, db=None, cur=None): | with_visit=False, db=None, cur=None): | ||||
return db.origin_count(url_pattern, regexp, with_visit, cur) | return db.origin_count(url_pattern, regexp, with_visit, cur) | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def origin_add(self, origins, db=None, cur=None): | def origin_add( | ||||
origins = copy.deepcopy(list(origins)) | self, origins: Iterable[Origin], db=None, cur=None) -> List[Dict]: | ||||
origins = list(origins) | |||||
for origin in origins: | for origin in origins: | ||||
self.origin_add_one(origin, db=db, cur=cur) | self.origin_add_one(origin, db=db, cur=cur) | ||||
send_metric('origin:add', count=len(origins), method_name='origin_add') | send_metric('origin:add', count=len(origins), method_name='origin_add') | ||||
return origins | return [o.to_dict() for o in origins] | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def origin_add_one(self, origin, db=None, cur=None): | def origin_add_one(self, origin: Origin, db=None, cur=None) -> str: | ||||
if 'url' not in origin: | origin_row = list(db.origin_get_by_url([origin.url], cur))[0] | ||||
raise StorageArgumentException('Missing origin url') | |||||
origin_row = list(db.origin_get_by_url([origin['url']], cur))[0] | |||||
origin_url = dict(zip(db.origin_cols, origin_row))['url'] | origin_url = dict(zip(db.origin_cols, origin_row))['url'] | ||||
if origin_url: | if origin_url: | ||||
return origin_url | return origin_url | ||||
if self.journal_writer: | if self.journal_writer: | ||||
self.journal_writer.write_addition('origin', origin) | self.journal_writer.write_addition('origin', origin) | ||||
origins = db.origin_add(origin['url'], cur) | origins = db.origin_add(origin.url, cur) | ||||
send_metric('origin:add', count=len(origins), method_name='origin_add') | send_metric('origin:add', count=len(origins), method_name='origin_add') | ||||
return origins | return origins | ||||
@db_transaction(statement_timeout=500) | @db_transaction(statement_timeout=500) | ||||
def stat_counters(self, db=None, cur=None): | def stat_counters(self, db=None, cur=None): | ||||
return {k: v for (k, v) in db.stat_counters()} | return {k: v for (k, v) in db.stat_counters()} | ||||
@db_transaction() | @db_transaction() | ||||
▲ Show 20 Lines • Show All 104 Lines • Show Last 20 Lines |
here the bug i mentioned orally (i was a bit fuzzy and wrong about my description though ;)
we checked 'data' not in content instead of item.