in_memory: store contents as swh.model.model.Content objects instead of dicts;
  'hidden' contents are stored with their data, as in storage.py.
journal_writer: convert BaseModel objects with to_dict() before recording them.
storage: normalize and validate content dicts (status, reason, length) and
  treat 'hidden' contents as contents with data.
tests: add test_content_add_validation.

diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -15,6 +15,9 @@
 import random
 import warnings
 
+import attr
+
+from swh.model.model import Content
 from swh.model.hashutil import DEFAULT_ALGORITHMS
 from swh.model.identifiers import normalize_timestamp
 from swh.objstorage import get_objstorage
@@ -74,21 +77,19 @@
     def _content_add(self, contents, with_data):
         if self.journal_writer:
             for content in contents:
-                if 'data' in content:
-                    content = content.copy()
-                    del content['data']
+                content = attr.evolve(content, data=None)
                 self.journal_writer.write_addition('content', content)
 
         content_with_data = []
         content_without_data = []
         for content in contents:
-            if 'status' not in content:
-                content['status'] = 'visible'
-            if 'length' not in content:
-                content['length'] = -1
-            if content['status'] == 'visible':
+            if content.status is None:
+                content.status = 'visible'
+            if content.length is None:
+                content.length = -1
+            if content.status in ('visible', 'hidden'):
                 content_with_data.append(content)
-            elif content['status'] == 'absent':
+            elif content.status == 'absent':
                 content_without_data.append(content)
 
         count_content_added, count_content_bytes_added = \
@@ -116,21 +117,24 @@
             if key in self._contents:
                 continue
             for algorithm in DEFAULT_ALGORITHMS:
-                if content[algorithm] in self._content_indexes[algorithm]\
+                hash_ = content.get_hash(algorithm)
+                if hash_ in self._content_indexes[algorithm]\
                    and (algorithm not in {'blake2s256', 'sha256'}):
                     from . import HashCollision
-                    raise HashCollision(algorithm, content[algorithm], key)
+                    raise HashCollision(algorithm, hash_, key)
             for algorithm in DEFAULT_ALGORITHMS:
-                self._content_indexes[algorithm][content[algorithm]].add(key)
-            self._objects[content['sha1_git']].append(
-                ('content', content['sha1']))
+                hash_ = content.get_hash(algorithm)
+                self._content_indexes[algorithm][hash_].add(key)
+            self._objects[content.sha1_git].append(
+                ('content', content.sha1))
             self._contents[key] = copy.deepcopy(content)
-            bisect.insort(self._sorted_sha1s, content['sha1'])
+            bisect.insort(self._sorted_sha1s, content.sha1)
             count_content_added += 1
             if with_data:
-                content_data = self._contents[key].pop('data')
+                content_data = self._contents[key].data
+                self._contents[key].data = None
                 count_content_bytes_added += len(content_data)
-                self.objstorage.add(content_data, content['sha1'])
+                self.objstorage.add(content_data, content.sha1)
 
         return (count_content_added, count_content_bytes_added)
 
@@ -140,7 +144,8 @@
         for content in skipped_content_missing:
             key = self._content_key(content)
             for algo in DEFAULT_ALGORITHMS:
-                self._skipped_content_indexes[algo][content[algo]].add(key)
+                self._skipped_content_indexes[algo][content.get_hash(algo)] \
+                    .add(key)
             self._skipped_contents[key] = copy.deepcopy(content)
             count += 1
 
@@ -175,10 +180,10 @@
                 skipped_content:add: New skipped contents (no data) added
 
         """
-        content = [dict(c.items()) for c in content]  # semi-shallow copy
+        content = [Content(**c) for c in content]
         now = datetime.datetime.now(tz=datetime.timezone.utc)
         for item in content:
-            item['ctime'] = now
+            item.ctime = now
         return self._content_add(content, with_data=True)
 
     def content_add_metadata(self, content):
@@ -210,6 +215,7 @@
                 skipped_content:add: New skipped contents (no data) added
 
         """
+        content = [Content(**c) for c in content]
         return self._content_add(content, with_data=False)
 
     def content_get(self, content):
@@ -282,9 +288,7 @@
             if len(matched) >= limit:
                 next_content = sha1
                 break
-            matched.append({
-                **self._contents[key],
-            })
+            matched.append(self._contents[key].to_dict())
         return {
             'contents': matched,
             'next': next_content,
@@ -308,8 +312,8 @@
                 # hash, we should return all of them. See:
                 # https://forge.softwareheritage.org/D645?id=1994#inline-3389
                 key = random.sample(objs, 1)[0]
-                data = copy.deepcopy(self._contents[key])
-                data.pop('ctime')
+                data = copy.deepcopy(self._contents[key].to_dict())
+                data.pop('ctime', None)
                 yield data
             else:
                 # FIXME: should really be None
@@ -336,7 +340,7 @@
             return []
 
         keys = list(set.intersection(*found))
-        return copy.deepcopy([self._contents[key] for key in keys])
+        return copy.deepcopy([self._contents[key].to_dict() for key in keys])
 
     def content_missing(self, content, key_hash='sha1'):
         """List content missing from storage
@@ -1576,8 +1580,9 @@
         for item in self._origin_metadata[origin_id]:
             item = copy.deepcopy(item)
             provider = self.metadata_provider_get(item['provider_id'])
-            for attr in ('name', 'type', 'url'):
-                item['provider_' + attr] = provider['provider_' + attr]
+            for attr_name in ('name', 'type', 'url'):
+                item['provider_' + attr_name] = \
+                    provider['provider_' + attr_name]
             metadata.append(item)
         return metadata
 
@@ -1709,11 +1714,14 @@
     @staticmethod
     def _content_key(content):
         """A stable key for a content"""
-        return tuple(content.get(key) for key in sorted(DEFAULT_ALGORITHMS))
+        return tuple(getattr(content, key)
+                     for key in sorted(DEFAULT_ALGORITHMS))
 
     @staticmethod
     def _content_key_algorithm(content):
         """ A stable key and the algorithm for a content"""
+        if isinstance(content, Content):
+            content = content.to_dict()
         return tuple((content.get(key), key)
                      for key in sorted(DEFAULT_ALGORITHMS))
 
diff --git a/swh/storage/journal_writer.py b/swh/storage/journal_writer.py
--- a/swh/storage/journal_writer.py
+++ b/swh/storage/journal_writer.py
@@ -7,6 +7,9 @@
 from multiprocessing import Manager
 
 
+from swh.model.model import BaseModel
+
+
 class InMemoryJournalWriter:
     def __init__(self):
         # Share the list of objects across processes, for RemoteAPI tests.
@@ -14,6 +17,8 @@
         self.objects = self.manager.list()
 
     def write_addition(self, object_type, object_):
+        if isinstance(object_, BaseModel):
+            object_ = object_.to_dict()
         self.objects.append((object_type, copy.deepcopy(object_)))
 
     write_update = write_addition
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -105,16 +105,40 @@
             return hash
         return tuple([hash[k] for k in keys])
 
+    def _normalize_content(self, d):
+        d = d.copy()
+
+        if 'status' not in d:
+            d['status'] = 'visible'
+
+        if 'length' not in d:
+            d['length'] = -1
+
+        return d
+
+    def _validate_content(self, d):
+        if d['status'] not in ('visible', 'absent', 'hidden'):
+            raise ValueError('Invalid content status: {}'.format(d['status']))
+
+        if d['status'] == 'absent' and d.get('reason') is None:
+            raise ValueError(
+                'Must provide a reason if content is absent.')
+        if d['status'] != 'absent' and d.get('reason') is not None:
+            raise ValueError(
+                'Must not provide a reason if content is not absent.')
+
+        if d['length'] < -1:
+            raise ValueError('Content length must be positive, zero, or -1.')
+
     def _filter_new_content(self, content, db, cur):
         content_by_status = defaultdict(list)
         for d in content:
-            if 'status' not in d:
-                d['status'] = 'visible'
-            if 'length' not in d:
-                d['length'] = -1
+            d = self._normalize_content(d)
+            self._validate_content(d)
             content_by_status[d['status']].append(d)
 
-        content_with_data = content_by_status['visible']
+        content_with_data = content_by_status['visible'] \
+            + content_by_status['hidden']
         content_without_data = content_by_status['absent']
 
         missing_content = set(self.content_missing(content_with_data,
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -596,6 +596,25 @@
         self.assertEqual(journal_objects,
                          [('content', expected_cont)])
 
+    def test_content_add_validation(self):
+        cont = self.cont
+
+        with self.assertRaisesRegex(ValueError, 'status'):
+            self.storage.content_add([{**cont, 'status': 'foobar'}])
+
+        with self.assertRaisesRegex(ValueError, "(?i)length"):
+            self.storage.content_add([{**cont, 'length': -2}])
+
+        with self.assertRaisesRegex(
+                ValueError,
+                "^Must provide a reason if content is absent.$"):
+            self.storage.content_add([{**cont, 'status': 'absent'}])
+
+        with self.assertRaisesRegex(
+                ValueError,
+                "^Must not provide a reason if content is not absent.$"):
+            self.storage.content_add([{**cont, 'reason': 'foobar'}])
+
     def test_content_get_missing(self):
         cont = self.cont
 