swh/storage/storage.py
# Copyright (C) 2015-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

import contextlib
import copy
import datetime
import itertools
import json

from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Dict, List, Optional

import dateutil.parser
import psycopg2
import psycopg2.pool
import psycopg2.errors

from swh.model.model import SHA1_SIZE
from swh.model.hashutil import ALGORITHMS, hash_to_bytes, hash_to_hex
from swh.objstorage import get_objstorage
from swh.objstorage.exc import ObjNotFoundError
try:
    from swh.journal.writer import get_journal_writer
except ImportError:
    get_journal_writer = None  # type: ignore
    # mypy limitation, see https://github.com/python/mypy/issues/1153

from . import converters, HashCollision
from .common import db_transaction_generator, db_transaction
from .db import Db
from .exc import StorageArgumentException, StorageDBError
from .algos import diff
from .metrics import timed, send_metric, process_metrics
from .utils import get_partition_bounds_bytes
# Max block size of contents to return
BULK_BLOCK_CONTENT_LEN_MAX = 10000

EMPTY_SNAPSHOT_ID = hash_to_bytes('1a8893e6a86f444e8be8e7bda6cb34fb1735a00e')
"""Identifier for the empty snapshot"""
VALIDATION_EXCEPTIONS = (
    psycopg2.errors.CheckViolation,
    psycopg2.errors.IntegrityError,
    psycopg2.errors.InvalidTextRepresentation,
    psycopg2.errors.NotNullViolation,
    psycopg2.errors.NumericValueOutOfRange,
    psycopg2.errors.UndefinedFunction,  # (raised on wrong argument types)
)
"""Exceptions raised by postgresql when validation of the arguments
fails."""
@contextlib.contextmanager
def convert_validation_exceptions():
    """Catches postgresql errors related to invalid arguments, and
    re-raises a StorageArgumentException."""
    try:
        yield
    except VALIDATION_EXCEPTIONS as e:
        raise StorageArgumentException(*e.args)
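
# Usage sketch (illustrative only; `db`, `cur` and `rows` are hypothetical
# stand-ins for a real db proxy, cursor and row iterable): any psycopg2
# validation error raised inside the block resurfaces as a
# StorageArgumentException, which callers can handle uniformly.
#
#     with convert_validation_exceptions():
#         db.copy_to(rows, 'tmp_content', db.content_add_keys, cur)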


class Storage():
    """SWH storage proxy, encompassing DB and object storage
    """

    def __init__(self, db, objstorage, min_pool_conns=1, max_pool_conns=10,
                 journal_writer=None):
        """
        # [… 82 lines elided; resuming inside _content_normalize(d) …]
        return d
    @staticmethod
    def _content_validate(d):
        """Sanity checks on status / reason / length, that postgresql
        doesn't enforce."""
        if d['status'] not in ('visible', 'hidden'):
            raise StorageArgumentException(
                'Invalid content status: {}'.format(d['status']))

        if d.get('reason') is not None:
            raise StorageArgumentException(
                'Must not provide a reason if content is present.')

        if d['length'] is None or d['length'] < 0:
            raise StorageArgumentException('Content length must be positive.')
    def _content_add_metadata(self, db, cur, content):
        """Add content to the postgresql database but not the object storage.
        """
        # create temporary table for metadata injection
        db.mktemp('content', cur)

        with convert_validation_exceptions():
            db.copy_to(content, 'tmp_content',
                       db.content_add_keys, cur)

            # move metadata in place
            try:
                db.content_add_from_temp(cur)
            except psycopg2.IntegrityError as e:
                if e.diag.sqlstate == '23505' and \
                        e.diag.table_name == 'content':
                    constraint_to_hash_name = {
                        'content_pkey': 'sha1',
                        'content_sha1_git_idx': 'sha1_git',
                        'content_sha256_idx': 'sha256',
                    }
                    colliding_hash_name = constraint_to_hash_name \
                        .get(e.diag.constraint_name)
                    raise HashCollision(colliding_hash_name) from None
                else:
                    raise
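
    # Note on the collision handling above: sqlstate '23505' is PostgreSQL's
    # unique_violation, and the violated constraint name identifies which
    # hash collided (e.g. a duplicate sha1_git trips 'content_sha1_git_idx',
    # yielding HashCollision('sha1_git')).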
    @timed
    @process_metrics
    @db_transaction()
    def content_add(self, content, db=None, cur=None):
        content = [dict(c.items()) for c in content]  # semi-shallow copy
        now = datetime.datetime.now(tz=datetime.timezone.utc)
        for item in content:
        # [… 55 lines elided; resuming inside content_update(self, content, keys=[], db=None, cur=None) …]
        # this? We don't know yet the new columns.

        if self.journal_writer:
            raise NotImplementedError(
                'content_update is not yet supported with a journal_writer.')

        db.mktemp('content', cur)
        select_keys = list(set(db.content_get_metadata_keys).union(set(keys)))
        with convert_validation_exceptions():
            db.copy_to(content, 'tmp_content', select_keys, cur)
            db.content_update_from_temp(keys_to_update=keys,
                                        cur=cur)
    @timed
    @process_metrics
    @db_transaction()
    def content_add_metadata(self, content, db=None, cur=None):
        content = [self._content_normalize(c) for c in content]
        for c in content:
            self._content_validate(c)

        # [… 11 lines elided; resuming inside content_add_metadata …]
        return {
            'content:add': len(content),
        }
    @timed
    def content_get(self, content):
        # FIXME: Make this method support slicing the `data`.
        if len(content) > BULK_BLOCK_CONTENT_LEN_MAX:
            raise StorageArgumentException(
                "Send at maximum %s contents." % BULK_BLOCK_CONTENT_LEN_MAX)

        for obj_id in content:
            try:
                data = self.objstorage.get(obj_id)
            except ObjNotFoundError:
                yield None
                continue

            yield {'sha1': obj_id, 'data': data}
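
    # Usage sketch (hypothetical `storage` instance and sha1 values):
    # missing objects yield None rather than raising, so callers can zip
    # results back to their requests.
    #
    #     for item in storage.content_get([sha1_a, sha1_b]):
    #         if item is not None:
    #             process(item['data'])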
    @timed
    @db_transaction()
    def content_get_range(self, start, end, limit=1000, db=None, cur=None):
        if limit is None:
            raise StorageArgumentException('limit should not be None')
        contents = []
        next_content = None
        for counter, content_row in enumerate(
                db.content_get_range(start, end, limit+1, cur)):
            content = dict(zip(db.content_get_metadata_keys, content_row))
            if counter >= limit:
                # take the last content for the next page starting from this
                next_content = content['sha1']
                break
            contents.append(content)
        return {
            'contents': contents,
            'next': next_content,
        }
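
    # Pagination sketch (hypothetical bounds and `storage` instance): walk
    # the sha1 range page by page until 'next' comes back as None.
    #
    #     start, end = b'\x00' * 20, b'\xff' * 20
    #     while start is not None:
    #         page = storage.content_get_range(start, end, limit=1000)
    #         handle(page['contents'])
    #         start = page['next']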
    @timed
    @db_transaction()
    def content_get_partition(
            self, partition_id: int, nb_partitions: int, limit: int = 1000,
            page_token: str = None, db=None, cur=None):
        if limit is None:
            raise StorageArgumentException('limit should not be None')
        (start, end) = get_partition_bounds_bytes(
            partition_id, nb_partitions, SHA1_SIZE)
        if page_token:
            start = hash_to_bytes(page_token)
        if end is None:
            end = b'\xff'*SHA1_SIZE
        result = self.content_get_range(start, end, limit)
        result2 = {
        # [… 16 lines elided; resuming inside content_get_metadata(…) …]
        return result
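
    # Partition sketch (hypothetical; the 'next_page_token' result key is an
    # assumption about the paging convention): each of 16 workers can walk
    # its own slice of the sha1 space, split evenly by
    # get_partition_bounds_bytes.
    #
    #     token = None
    #     while True:
    #         res = storage.content_get_partition(0, 16, page_token=token)
    #         handle(res['contents'])
    #         token = res.get('next_page_token')
    #         if token is None:
    #             break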
    @timed
    @db_transaction_generator()
    def content_missing(self, content, key_hash='sha1', db=None, cur=None):
        keys = db.content_hash_keys
        if key_hash not in keys:
            raise StorageArgumentException(
                "key_hash should be one of %s" % keys)

        key_hash_idx = keys.index(key_hash)

        if not content:
            return

        for obj in db.content_missing_from_list(content, cur):
            yield obj[key_hash_idx]
    # [… 9 lines elided …]

    def content_missing_per_sha1_git(self, contents, db=None, cur=None):
        for obj in db.content_missing_per_sha1_git(contents, cur):
            yield obj[0]
    @timed
    @db_transaction()
    def content_find(self, content, db=None, cur=None):
        if not set(content).intersection(ALGORITHMS):
            raise StorageArgumentException(
                'content keys must contain at least one of: '
                'sha1, sha1_git, sha256, blake2s256')

        contents = db.content_find(sha1=content.get('sha1'),
                                   sha1_git=content.get('sha1_git'),
                                   sha256=content.get('sha256'),
                                   blake2s256=content.get('blake2s256'),
                                   cur=cur)
        return [dict(zip(db.content_find_cols, content))
                for content in contents]
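
    # Lookup sketch (hypothetical hash value and `storage` instance): any of
    # the supported hash algorithms can key the query.
    #
    #     rows = storage.content_find({'sha1_git': some_sha1_git_bytes})
    #     if rows:
    #         print(rows[0]['length'])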
    # [… 15 lines elided; resuming inside _skipped_content_normalize(d) …]
        return d
    @staticmethod
    def _skipped_content_validate(d):
        """Sanity checks on status / reason / length, that postgresql
        doesn't enforce."""
        if d['status'] != 'absent':
            raise StorageArgumentException(
                'Invalid content status: {}'.format(d['status']))

        if d.get('reason') is None:
            raise StorageArgumentException(
                'Must provide a reason if content is absent.')

        if d['length'] < -1:
            raise StorageArgumentException(
                'Content length must be positive or -1.')
    def _skipped_content_add_metadata(self, db, cur, content):
        content = [cont.copy() for cont in content]
        origin_ids = db.origin_id_get_by_url(
            [cont.get('origin') for cont in content],
            cur=cur)
        for (cont, origin_id) in zip(content, origin_ids):
            if 'origin' in cont:
                cont['origin'] = origin_id

        db.mktemp('skipped_content', cur)
        with convert_validation_exceptions():
            db.copy_to(content, 'tmp_skipped_content',
                       db.skipped_content_keys, cur)

            # move metadata in place
            db.skipped_content_add_from_temp(cur)
    @timed
    @process_metrics
    @db_transaction()
    def skipped_content_add(self, content, db=None, cur=None):
        content = [dict(c.items()) for c in content]  # semi-shallow copy
        now = datetime.datetime.now(tz=datetime.timezone.utc)
        for item in content:
        # [… 41 lines elided; resuming inside directory_add(self, directories, db=None, cur=None) …]
        for cur_dir in directories:
            dir_id = cur_dir['id']
            dirs.add(dir_id)
            for src_entry in cur_dir['entries']:
                entry = src_entry.copy()
                entry['dir_id'] = dir_id
                if entry['type'] not in ('file', 'dir', 'rev'):
                    raise StorageArgumentException(
                        'Entry type must be file, dir, or rev; not %s'
                        % entry['type'])
                dir_entries[entry['type']][dir_id].append(entry)
        dirs_missing = set(self.directory_missing(dirs, db=db, cur=cur))
        if not dirs_missing:
            return summary

        if self.journal_writer:
            self.journal_writer.write_additions(
                'directory',
                (dir_ for dir_ in directories
                 if dir_['id'] in dirs_missing))

        # Copy directory ids
        dirs_missing_dict = ({'id': dir} for dir in dirs_missing)
        db.mktemp('directory', cur)
        with convert_validation_exceptions():
            db.copy_to(dirs_missing_dict, 'tmp_directory', ['id'], cur)

            # Copy entries
            for entry_type, entry_list in dir_entries.items():
                entries = itertools.chain.from_iterable(
                    entries_for_dir
                    for dir_id, entries_for_dir
                    in entry_list.items()
                    if dir_id in dirs_missing)

                db.mktemp_dir_entry(entry_type)

                db.copy_to(
                    entries,
                    'tmp_directory_entry_%s' % entry_type,
                    ['target', 'name', 'perms', 'dir_id'],
                    cur,
                )

            # Do the final copy
            db.directory_add_from_temp(cur)

        summary['directory:add'] = len(dirs_missing)

        return summary
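
    # Directory sketch (hypothetical ids and `storage` instance): each entry
    # targets a content ('file'), a sub-directory ('dir') or a revision
    # ('rev').
    #
    #     storage.directory_add([{
    #         'id': dir_sha1_git,
    #         'entries': [{
    #             'name': b'README',
    #             'type': 'file',
    #             'target': content_sha1_git,
    #             'perms': 0o100644,
    #         }],
    #     }])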
    @timed
    @db_transaction_generator()
    def directory_missing(self, directories, db=None, cur=None):
        for obj in db.directory_missing_from_list(directories, cur):
        # [… 44 lines elided; resuming inside revision_add(self, revisions, db=None, cur=None) …]
        if self.journal_writer:
            self.journal_writer.write_additions('revision', revisions_filtered)

        revisions_filtered = map(converters.revision_to_db, revisions_filtered)

        parents_filtered = []

        with convert_validation_exceptions():
            db.copy_to(
                revisions_filtered, 'tmp_revision', db.revision_add_cols,
                cur,
                lambda rev: parents_filtered.extend(rev['parents']))

            db.revision_add_from_temp(cur)

            db.copy_to(parents_filtered, 'revision_history',
                       ['id', 'parent_id', 'parent_rank'], cur)

        return {'revision:add': len(revisions_missing)}
    @timed
    @db_transaction_generator()
    def revision_missing(self, revisions, db=None, cur=None):
        if not revisions:
            return

        # [… 59 lines elided; resuming inside release_add(self, releases, db=None, cur=None) …]
            if release['id'] in releases_missing
        ]

        if self.journal_writer:
            self.journal_writer.write_additions('release', releases_filtered)

        releases_filtered = map(converters.release_to_db, releases_filtered)

        with convert_validation_exceptions():
            db.copy_to(releases_filtered, 'tmp_release', db.release_add_cols,
                       cur)

            db.release_add_from_temp(cur)

        return {'release:add': len(releases_missing)}
    @timed
    @db_transaction_generator()
    def release_missing(self, releases, db=None, cur=None):
        if not releases:
            return

        # [… 23 lines elided; resuming inside snapshot_add(self, snapshots, db=None, cur=None) …]
        count = 0
        for snapshot in snapshots:
            if not db.snapshot_exists(snapshot['id'], cur):
                if not created_temp_table:
                    db.mktemp_snapshot_branch(cur)
                    created_temp_table = True

                try:
                    db.copy_to(
                        (
                            {
                                'name': name,
                                'target': info['target'] if info else None,
                                'target_type': (info['target_type']
                                                if info else None),
                            }
                            for name, info in snapshot['branches'].items()
                        ),
                        'tmp_snapshot_branch',
                        ['name', 'target', 'target_type'],
                        cur,
                    )
                except VALIDATION_EXCEPTIONS + (KeyError,) as e:
                    raise StorageArgumentException(*e.args)
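
                # Note: KeyError is converted alongside the postgresql
                # validation errors because a malformed branch dict (missing
                # 'target' or 'target_type') raises it from the generator
                # above, and that too is a caller error.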
                if self.journal_writer:
                    self.journal_writer.write_addition('snapshot', snapshot)

                db.snapshot_add(snapshot['id'], cur)
                count += 1

        return {'snapshot:add': count}
    # [… 32 lines elided; resuming inside snapshot_get_latest(self, origin, allowed_statuses=None, …) …]
        origin_visit = self.origin_visit_get_latest(
            origin, allowed_statuses=allowed_statuses, require_snapshot=True,
            db=db, cur=cur)

        if origin_visit and origin_visit['snapshot']:
            snapshot = self.snapshot_get(
                origin_visit['snapshot'], db=db, cur=cur)
            if not snapshot:
                raise StorageArgumentException(
                    'last origin visit references an unknown snapshot')
            return snapshot
    @timed
    @db_transaction(statement_timeout=2000)
    def snapshot_count_branches(self, snapshot_id, db=None, cur=None):
        return dict([bc for bc in
                     db.snapshot_count_branches(snapshot_id, cur)])
    # [… 49 lines elided …]

    def origin_visit_add(self, origin, date, type,
                         db=None, cur=None):
        origin_url = origin

        if isinstance(date, str):
            # FIXME: Converge on iso8601 at some point
            date = dateutil.parser.parse(date)

        with convert_validation_exceptions():
            visit_id = db.origin_visit_add(origin_url, date, type, cur)

        if self.journal_writer:
            # We can write to the journal only after inserting to the
            # DB, because we want the id of the visit
            self.journal_writer.write_addition('origin_visit', {
                'origin': origin_url, 'date': date, 'type': type,
                'visit': visit_id,
                'status': 'ongoing', 'metadata': None, 'snapshot': None})

        send_metric('origin_visit:add', count=1, method_name='origin_visit')
        return {
            'origin': origin_url,
            'visit': visit_id,
        }
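
    # Visit sketch (hypothetical origin URL and `storage` instance): the
    # returned dict carries the visit id needed by origin_visit_update.
    #
    #     visit = storage.origin_visit_add(
    #         'https://example.org/repo.git',
    #         date='2020-01-01T00:00:00+00:00', type='git')
    #     storage.origin_visit_update(
    #         visit['origin'], visit['visit'], status='full')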
    @timed
    @db_transaction()
    def origin_visit_update(self, origin, visit_id, status=None,
                            metadata=None, snapshot=None,
                            db=None, cur=None):
        if not isinstance(origin, str):
            raise StorageArgumentException(
                'origin must be a string, not %r' % (origin,))
        origin_url = origin
        visit = db.origin_visit_get(origin_url, visit_id, cur=cur)

        if not visit:
            raise StorageArgumentException('Invalid visit_id for this origin.')

        visit = dict(zip(db.origin_visit_get_cols, visit))

        updates = {}
        if status and status != visit['status']:
            updates['status'] = status
        if metadata and metadata != visit['metadata']:
            updates['metadata'] = metadata
        if snapshot and snapshot != visit['snapshot']:
            updates['snapshot'] = snapshot

        if updates:
            if self.journal_writer:
                self.journal_writer.write_update('origin_visit', {
                    **visit, **updates})

            with convert_validation_exceptions():
                db.origin_visit_update(origin_url, visit_id, updates, cur)
    @timed
    @db_transaction()
    def origin_visit_upsert(self, visits, db=None, cur=None):
        visits = copy.deepcopy(visits)
        for visit in visits:
            if isinstance(visit['date'], str):
                visit['date'] = dateutil.parser.parse(visit['date'])
            if not isinstance(visit['origin'], str):
                raise StorageArgumentException(
                    "visit['origin'] must be a string, not %r"
                    % (visit['origin'],))

        if self.journal_writer:
            for visit in visits:
                self.journal_writer.write_addition('origin_visit', visit)

        for visit in visits:
            # TODO: upsert them all in a single query
            db.origin_visit_upsert(**visit, cur=cur)
        # [… 99 lines elided; resuming inside origin_get_range(self, origin_from=1, origin_count=100, …) …]
            yield dict(zip(db.origin_get_range_cols, origin))
    @timed
    @db_transaction()
    def origin_list(self, page_token: Optional[str] = None, limit: int = 100,
                    *, db=None, cur=None) -> dict:
        page_token = page_token or '0'
        if not isinstance(page_token, str):
            raise StorageArgumentException('page_token must be a string.')
        origin_from = int(page_token)
        result: Dict[str, Any] = {
            'origins': [
                dict(zip(db.origin_get_range_cols, origin))
                for origin in db.origin_get_range(origin_from, limit, cur)
            ],
        }
        # [… 28 lines elided; resuming inside origin_add(self, origins, db=None, cur=None) …]
            self.origin_add_one(origin, db=db, cur=cur)

        send_metric('origin:add', count=len(origins), method_name='origin_add')
        return origins
    @timed
    @db_transaction()
    def origin_add_one(self, origin, db=None, cur=None):
        if 'url' not in origin:
            raise StorageArgumentException('Missing origin url')
        origin_row = list(db.origin_get_by_url([origin['url']], cur))[0]
        origin_url = dict(zip(db.origin_cols, origin_row))['url']
        if origin_url:
            return origin_url

        if self.journal_writer:
            self.journal_writer.write_addition('origin', origin)
    # [… 43 lines elided …]

    def origin_metadata_get_by(self, origin_url, provider_type=None, db=None,
                               cur=None):
        for line in db.origin_metadata_get_by(origin_url, provider_type, cur):
            yield dict(zip(db.origin_metadata_get_cols, line))
    @timed
    @db_transaction()
    def tool_add(self, tools, db=None, cur=None):
        db.mktemp_tool(cur)
        with convert_validation_exceptions():
            db.copy_to(tools, 'tmp_tool',
                       ['name', 'version', 'configuration'],
                       cur)

            tools = db.tool_add_from_temp(cur)

        results = [dict(zip(db.tool_cols, line)) for line in tools]

        send_metric('tool:add', count=len(results), method_name='tool_add')
        return results
    @timed
    @db_transaction(statement_timeout=500)
    def tool_get(self, tool, db=None, cur=None):
        tool_conf = tool['configuration']
        # [… 48 lines elided to the end of the file …]