diff --git a/requirements-swh.txt b/requirements-swh.txt --- a/requirements-swh.txt +++ b/requirements-swh.txt @@ -1,3 +1,3 @@ swh.core[db,http] >= 0.0.65 -swh.model >= 0.0.32 +swh.model >= 0.0.41 swh.objstorage >= 0.0.17 diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -15,6 +15,9 @@ import random import warnings +import attr + +from swh.model.model import Content from swh.model.hashutil import DEFAULT_ALGORITHMS from swh.model.identifiers import normalize_timestamp from swh.objstorage import get_objstorage @@ -74,21 +77,19 @@ def _content_add(self, contents, with_data): if self.journal_writer: for content in contents: - if 'data' in content: - content = content.copy() - del content['data'] + content = attr.evolve(content, data=None) self.journal_writer.write_addition('content', content) content_with_data = [] content_without_data = [] for content in contents: - if 'status' not in content: - content['status'] = 'visible' - if 'length' not in content: - content['length'] = -1 - if content['status'] == 'visible': + if content.status is None: + content.status = 'visible' + if content.length is None: + content.length = -1 + if content.status == 'visible': content_with_data.append(content) - elif content['status'] == 'absent': + elif content.status == 'absent': content_without_data.append(content) count_content_added, count_content_bytes_added = \ @@ -116,21 +117,24 @@ if key in self._contents: continue for algorithm in DEFAULT_ALGORITHMS: - if content[algorithm] in self._content_indexes[algorithm]\ + hash_ = content.get_hash(algorithm) + if hash_ in self._content_indexes[algorithm]\ and (algorithm not in {'blake2s256', 'sha256'}): from . import HashCollision - raise HashCollision(algorithm, content[algorithm], key) + raise HashCollision(algorithm, hash_, key) for algorithm in DEFAULT_ALGORITHMS: - self._content_indexes[algorithm][content[algorithm]].add(key) - self._objects[content['sha1_git']].append( - ('content', content['sha1'])) - self._contents[key] = copy.deepcopy(content) - bisect.insort(self._sorted_sha1s, content['sha1']) + hash_ = content.get_hash(algorithm) + self._content_indexes[algorithm][hash_].add(key) + self._objects[content.sha1_git].append( + ('content', content.sha1)) + self._contents[key] = content + bisect.insort(self._sorted_sha1s, content.sha1) count_content_added += 1 if with_data: - content_data = self._contents[key].pop('data') + content_data = self._contents[key].data + self._contents[key].data = None count_content_bytes_added += len(content_data) - self.objstorage.add(content_data, content['sha1']) + self.objstorage.add(content_data, content.sha1) return (count_content_added, count_content_bytes_added) @@ -140,8 +144,9 @@ for content in skipped_content_missing: key = self._content_key(content) for algo in DEFAULT_ALGORITHMS: - self._skipped_content_indexes[algo][content[algo]].add(key) - self._skipped_contents[key] = copy.deepcopy(content) + self._skipped_content_indexes[algo][content.get_hash(algo)] \ + .add(key) + self._skipped_contents[key] = content count += 1 return count @@ -175,10 +180,10 @@ skipped_content:add: New skipped contents (no data) added """ - content = [dict(c.items()) for c in content] # semi-shallow copy + content = [Content.from_dict(c) for c in content] now = datetime.datetime.now(tz=datetime.timezone.utc) for item in content: - item['ctime'] = now + item.ctime = now return self._content_add(content, with_data=True) def content_add_metadata(self, content): @@ -210,6 +215,7 @@ skipped_content:add: New skipped contents (no data) added """ + content = [Content.from_dict(c) for c in content] return self._content_add(content, with_data=False) def content_get(self, content): @@ -282,9 +288,7 @@ if len(matched) >= limit: next_content = sha1 break - matched.append({ - **self._contents[key], - }) + matched.append(self._contents[key].to_dict()) return { 'contents': matched, 'next': next_content, @@ -308,9 +312,9 @@ # hash, we should return all of them. See: # https://forge.softwareheritage.org/D645?id=1994#inline-3389 key = random.sample(objs, 1)[0] - data = copy.deepcopy(self._contents[key]) - data.pop('ctime') - yield data + d = self._contents[key].to_dict() + del d['ctime'] + yield d else: # FIXME: should really be None yield { @@ -336,7 +340,7 @@ return [] keys = list(set.intersection(*found)) - return copy.deepcopy([self._contents[key] for key in keys]) + return [self._contents[key].to_dict() for key in keys] def content_missing(self, content, key_hash='sha1'): """List content missing from storage @@ -1560,8 +1564,9 @@ for item in self._origin_metadata[origin_id]: item = copy.deepcopy(item) provider = self.metadata_provider_get(item['provider_id']) - for attr in ('name', 'type', 'url'): - item['provider_' + attr] = provider['provider_' + attr] + for attr_name in ('name', 'type', 'url'): + item['provider_' + attr_name] = \ + provider['provider_' + attr_name] metadata.append(item) return metadata @@ -1692,11 +1697,14 @@ @staticmethod def _content_key(content): """A stable key for a content""" - return tuple(content.get(key) for key in sorted(DEFAULT_ALGORITHMS)) + return tuple(getattr(content, key) + for key in sorted(DEFAULT_ALGORITHMS)) @staticmethod def _content_key_algorithm(content): """ A stable key and the algorithm for a content""" + if isinstance(content, Content): + content = content.to_dict() return tuple((content.get(key), key) for key in sorted(DEFAULT_ALGORITHMS)) diff --git a/swh/storage/journal_writer.py b/swh/storage/journal_writer.py --- a/swh/storage/journal_writer.py +++ b/swh/storage/journal_writer.py @@ -7,6 +7,9 @@ from multiprocessing import Manager +from swh.model.model import BaseModel + + class InMemoryJournalWriter: def __init__(self): # Share the list of objects across processes, for RemoteAPI tests. @@ -14,6 +17,8 @@ self.objects = self.manager.list() def write_addition(self, object_type, object_): + if isinstance(object_, BaseModel): + object_ = object_.to_dict() self.objects.append((object_type, copy.deepcopy(object_))) write_update = write_addition diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -105,16 +105,37 @@ return hash return tuple([hash[k] for k in keys]) + def _normalize_content(self, d): + d = d.copy() + + if 'status' not in d: + d['status'] = 'visible' + + if 'length' not in d: + d['length'] = -1 + + return d + + def _validate_content(self, d): + if d['status'] not in ('visible', 'absent', 'hidden'): + raise ValueError('Invalid content status: {}'.format(d['status'])) + + if d['status'] != 'absent' and d.get('reason') is not None: + raise ValueError( + 'Must not provide a reason if content is not absent.') + + if d['length'] < -1: + raise ValueError('Content length must be positive or -1.') + def _filter_new_content(self, content, db, cur): content_by_status = defaultdict(list) for d in content: - if 'status' not in d: - d['status'] = 'visible' - if 'length' not in d: - d['length'] = -1 + d = self._normalize_content(d) + self._validate_content(d) content_by_status[d['status']].append(d) - content_with_data = content_by_status['visible'] + content_with_data = content_by_status['visible'] \ + + content_by_status['hidden'] content_without_data = content_by_status['absent'] missing_content = set(self.content_missing(content_with_data, diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -11,6 +11,7 @@ from collections import defaultdict from unittest.mock import Mock, patch +import psycopg2.errors import pytest from hypothesis import given, strategies, settings, HealthCheck @@ -596,6 +597,25 @@ self.assertEqual(journal_objects, [('content', expected_cont)]) + def test_content_add_validation(self): + cont = self.cont + + with self.assertRaisesRegex(ValueError, 'status'): + self.storage.content_add([{**cont, 'status': 'foobar'}]) + + with self.assertRaisesRegex(ValueError, "(?i)length"): + self.storage.content_add([{**cont, 'length': -2}]) + + with self.assertRaisesRegex( + (ValueError, psycopg2.errors.NotNullViolation), + "reason"): + self.storage.content_add([{**cont, 'status': 'absent'}]) + + with self.assertRaisesRegex( + ValueError, + "^Must not provide a reason if content is not absent.$"): + self.storage.content_add([{**cont, 'reason': 'foobar'}]) + def test_content_get_missing(self): cont = self.cont @@ -3764,7 +3784,8 @@ origin_id = self.storage.origin_add_one(obj.pop('origin')) if 'visit' in obj: del obj['visit'] - self.storage.origin_visit_add(origin_id, **obj) + self.storage.origin_visit_add( + origin_id, obj['date'], obj['type']) else: method = getattr(self.storage, obj_type + '_add') try: