diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -19,12 +19,7 @@ ) from swh.objstorage import get_objstorage from swh.objstorage.exc import ObjNotFoundError -try: - from swh.journal.writer import get_journal_writer -except ImportError: - get_journal_writer = None # type: ignore - # mypy limitation, see https://github.com/python/mypy/issues/1153 - +from swh.storage.writer import JournalWriter from .. import HashCollision from ..exc import StorageArgumentException @@ -50,11 +45,7 @@ self._cql_runner = CqlRunner(hosts, keyspace, port) self.objstorage = get_objstorage(**objstorage) - - if journal_writer: - self.journal_writer = get_journal_writer(**journal_writer) - else: - self.journal_writer = None + self.journal_writer = JournalWriter(journal_writer) def check_config(self, *, check_write): self._cql_runner.check_read() @@ -71,12 +62,7 @@ contents = [c for c in contents if not self._cql_runner.content_get_from_pk(c.to_dict())] - if self.journal_writer: - for content in contents: - content = content.to_dict() - if 'data' in content: - del content['data'] - self.journal_writer.write_addition('content', content) + self.journal_writer.content_add(c.to_dict() for c in contents) count_contents = 0 count_content_added = 0 @@ -275,12 +261,7 @@ c for c in contents if not self._cql_runner.skipped_content_get_from_pk(c.to_dict())] - if self.journal_writer: - for content in contents: - content = content.to_dict() - if 'data' in content: - del content['data'] - self.journal_writer.write_addition('content', content) + self.journal_writer.skipped_content_add(c.to_dict() for c in contents) for content in contents: # Add to index tables @@ -314,8 +295,7 @@ missing = self.directory_missing([dir_['id'] for dir_ in directories]) directories = [dir_ for dir_ in directories if dir_['id'] in missing] - if self.journal_writer: - self.journal_writer.write_additions('directory', directories) + self.journal_writer.directory_add(directories) for directory in directories: try: @@ -423,8 +403,7 @@ missing = self.revision_missing([rev['id'] for rev in revisions]) revisions = [rev for rev in revisions if rev['id'] in missing] - if self.journal_writer: - self.journal_writer.write_additions('revision', revisions) + self.journal_writer.revision_add(revisions) for revision in revisions: try: @@ -521,8 +500,7 @@ missing = self.release_missing([rel['id'] for rel in releases]) releases = [rel for rel in releases if rel['id'] in missing] - if self.journal_writer: - self.journal_writer.write_additions('release', releases) + self.journal_writer.release_add(releases) for release in releases: try: @@ -562,8 +540,7 @@ snapshots = [snp for snp in snapshots if snp.id in missing] for snapshot in snapshots: - if self.journal_writer: - self.journal_writer.write_addition('snapshot', snapshot) + self.journal_writer.snapshot_add(snapshot) # Add branches for (branch_name, branch) in snapshot.branches.items(): @@ -814,8 +791,7 @@ if known_origin: origin_url = known_origin['url'] else: - if self.journal_writer: - self.journal_writer.write_addition('origin', origin) + self.journal_writer.origin_add_one(origin) self._cql_runner.origin_add_one(origin) origin_url = origin['url'] @@ -844,9 +820,7 @@ 'metadata': None, 'visit': visit_id } - - if self.journal_writer: - self.journal_writer.write_addition('origin_visit', visit) + self.journal_writer.origin_visit_add(visit, date, type) try: visit = OriginVisit.from_dict(visit) @@ -886,8 +860,8 @@ 
except (KeyError, TypeError, ValueError) as e: raise StorageArgumentException(*e.args) - if self.journal_writer: - self.journal_writer.write_update('origin_visit', visit) + self.journal_writer.origin_visit_update( + visit.to_dict(), visit_id, status, metadata, snapshot) self._cql_runner.origin_visit_update(origin_url, visit_id, updates) @@ -897,9 +871,7 @@ if isinstance(visit['date'], str): visit['date'] = dateutil.parser.parse(visit['date']) - if self.journal_writer: - for visit in visits: - self.journal_writer.write_addition('origin_visit', visit) + self.journal_writer.origin_visit_upsert(visits) for visit in visits: visit = visit.copy() diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -27,9 +27,10 @@ from . import HashCollision from .exc import StorageArgumentException -from .storage import get_journal_writer + from .converters import origin_url_to_sha1 from .utils import get_partition_bounds_bytes +from .writer import JournalWriter # Max block size of contents to return BULK_BLOCK_CONTENT_LEN_MAX = 10000 @@ -47,11 +48,7 @@ self._skipped_content_indexes = defaultdict(lambda: defaultdict(set)) self.reset() - - if journal_writer: - self.journal_writer = get_journal_writer(**journal_writer) - else: - self.journal_writer = None + self.journal_writer = JournalWriter(journal_writer) def reset(self): self._directories = {} @@ -84,14 +81,12 @@ raise StorageArgumentException('content with status=absent') if content.length is None: raise StorageArgumentException('content with length=None') + try: + attr.evolve(content, data=None) + except (KeyError, TypeError, ValueError) as e: + raise StorageArgumentException(*e.args) - if self.journal_writer: - for content in contents: - try: - content = attr.evolve(content, data=None) - except (KeyError, TypeError, ValueError) as e: - raise StorageArgumentException(*e.args) - self.journal_writer.write_addition('content', content) + self.journal_writer.content_add(c.to_dict() for c in contents) summary = { 'content:add': 0, @@ -140,9 +135,7 @@ return self._content_add(content, with_data=True) def content_update(self, content, keys=[]): - if self.journal_writer: - raise NotImplementedError( - 'content_update is not yet supported with a journal_writer.') + self.journal_writer.content_update(content, keys=keys) for cont_update in content: cont_update = cont_update.copy() @@ -295,9 +288,7 @@ raise StorageArgumentException( f'Content with status={content.status}') - if self.journal_writer: - for content in contents: - self.journal_writer.write_addition('content', content) + self.journal_writer.skipped_content_add(contents) summary = { 'skipped_content:add': 0 @@ -341,11 +332,10 @@ def directory_add(self, directories): directories = list(directories) - if self.journal_writer: - self.journal_writer.write_additions( - 'directory', - (dir_ for dir_ in directories - if dir_['id'] not in self._directories)) + self.journal_writer.directory_add( + dir_ for dir_ in directories + if dir_['id'] not in self._directories + ) try: directories = [Directory.from_dict(d) for d in directories] @@ -437,11 +427,10 @@ def revision_add(self, revisions): revisions = list(revisions) - if self.journal_writer: - self.journal_writer.write_additions( - 'revision', - (rev for rev in revisions - if rev['id'] not in self._revisions)) + self.journal_writer.revision_add( + rev for rev in revisions + if rev['id'] not in self._revisions + ) try: revisions = [Revision.from_dict(rev) for rev in revisions] @@ -498,11 
+487,10 @@ def release_add(self, releases): releases = list(releases) - if self.journal_writer: - self.journal_writer.write_additions( - 'release', - (rel for rel in releases - if rel['id'] not in self._releases)) + self.journal_writer.release_add( + rel for rel in releases + if rel['id'] not in self._releases + ) try: releases = [Release.from_dict(rel) for rel in releases] @@ -543,9 +531,7 @@ snapshots = (snap for snap in snapshots if snap.id not in self._snapshots) for snapshot in snapshots: - if self.journal_writer: - self.journal_writer.write_addition('snapshot', snapshot) - + self.journal_writer.snapshot_add(snapshot.to_dict()) sorted_branch_names = sorted(snapshot.branches) self._snapshots[snapshot.id] = (snapshot, sorted_branch_names) self._objects[snapshot.id].append(('snapshot', snapshot.id)) @@ -761,9 +747,7 @@ except (KeyError, TypeError, ValueError) as e: raise StorageArgumentException(*e.args) if origin.url not in self._origins: - if self.journal_writer: - self.journal_writer.write_addition('origin', origin) - + self.journal_writer.origin_add_one(origin.to_dict()) # generate an origin_id because it is needed by origin_get_range. # TODO: remove this when we remove origin_get_range origin_id = len(self._origins) + 1 @@ -813,8 +797,7 @@ self._objects[(origin_url, visit_id)].append( ('origin_visit', None)) - if self.journal_writer: - self.journal_writer.write_addition('origin_visit', visit) + self.journal_writer.origin_visit_add(visit.to_dict(), date, type) return visit_ret @@ -845,8 +828,8 @@ except (KeyError, TypeError, ValueError) as e: raise StorageArgumentException(*e.args) - if self.journal_writer: - self.journal_writer.write_update('origin_visit', visit) + self.journal_writer.origin_visit_update( + visit.to_dict(), visit_id, status, metadata, snapshot) self._origin_visits[origin_url][visit_id-1] = visit @@ -860,9 +843,7 @@ except (KeyError, TypeError, ValueError) as e: raise StorageArgumentException(*e.args) - if self.journal_writer: - for visit in visits: - self.journal_writer.write_addition('origin_visit', visit) + self.journal_writer.origin_visit_upsert(v.to_dict() for v in visits) for visit in visits: visit_id = visit.visit diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -23,11 +23,6 @@ from swh.model.hashutil import ALGORITHMS, hash_to_bytes, hash_to_hex from swh.objstorage import get_objstorage from swh.objstorage.exc import ObjNotFoundError -try: - from swh.journal.writer import get_journal_writer -except ImportError: - get_journal_writer = None # type: ignore - # mypy limitation, see https://github.com/python/mypy/issues/1153 from . 
import converters, HashCollision from .common import db_transaction_generator, db_transaction @@ -36,6 +31,7 @@ from .algos import diff from .metrics import timed, send_metric, process_metrics from .utils import get_partition_bounds_bytes +from .writer import JournalWriter # Max block size of contents to return @@ -93,14 +89,7 @@ raise StorageDBError(e) self.objstorage = get_objstorage(**objstorage) - if journal_writer: - if get_journal_writer is None: - raise EnvironmentError( - 'You need the swh.journal package to use the ' - 'journal_writer feature') - self.journal_writer = get_journal_writer(**journal_writer) - else: - self.journal_writer = None + self.journal_writer = JournalWriter(journal_writer) def get_db(self): if self._db: @@ -218,12 +207,7 @@ missing = list(self.content_missing(content, key_hash='sha1_git')) content = [c for c in content if c['sha1_git'] in missing] - if self.journal_writer: - for item in content: - if 'data' in item: - item = item.copy() - del item['data'] - self.journal_writer.write_addition('content', item) + self.journal_writer.content_add(content) def add_to_objstorage(): """Add to objstorage the new missing_content @@ -265,10 +249,7 @@ def content_update(self, content, keys=[], db=None, cur=None): # TODO: Add a check on input keys. How to properly implement # this? We don't know yet the new columns. - - if self.journal_writer: - raise NotImplementedError( - 'content_update is not yet supported with a journal_writer.') + self.journal_writer.content_update(content, keys=keys) db.mktemp('content', cur) select_keys = list(set(db.content_get_metadata_keys).union(set(keys))) @@ -288,11 +269,7 @@ missing = self.content_missing(content, key_hash='sha1_git') content = [c for c in content if c['sha1_git'] in missing] - if self.journal_writer: - for item in itertools.chain(content): - assert 'data' not in content - self.journal_writer.write_addition('content', item) - + self.journal_writer.content_add_metadata(content) self._content_add_metadata(db, cur, content) return { @@ -482,10 +459,7 @@ for algo in ALGORITHMS) for missing_content in missing_contents)] - if self.journal_writer: - for item in content: - self.journal_writer.write_addition('content', item) - + self.journal_writer.skipped_content_add(content) self._skipped_content_add_metadata(db, cur, content) return { @@ -528,11 +502,10 @@ if not dirs_missing: return summary - if self.journal_writer: - self.journal_writer.write_additions( - 'directory', - (dir_ for dir_ in directories - if dir_['id'] in dirs_missing)) + self.journal_writer.directory_add( + dir_ for dir_ in directories + if dir_['id'] in dirs_missing + ) # Copy directory ids dirs_missing_dict = ({'id': dir} for dir in dirs_missing) @@ -612,8 +585,7 @@ revision for revision in revisions if revision['id'] in revisions_missing] - if self.journal_writer: - self.journal_writer.write_additions('revision', revisions_filtered) + self.journal_writer.revision_add(revisions_filtered) revisions_filtered = map(converters.revision_to_db, revisions_filtered) @@ -699,8 +671,7 @@ if release['id'] in releases_missing ] - if self.journal_writer: - self.journal_writer.write_additions('release', releases_filtered) + self.journal_writer.release_add(releases_filtered) releases_filtered = map(converters.release_to_db, releases_filtered) @@ -766,8 +737,7 @@ except VALIDATION_EXCEPTIONS + (KeyError,) as e: raise StorageArgumentException(*e.args) - if self.journal_writer: - self.journal_writer.write_addition('snapshot', snapshot) + self.journal_writer.snapshot_add(snapshot) 
db.snapshot_add(snapshot['id'], cur) count += 1 @@ -882,13 +852,18 @@ with convert_validation_exceptions(): visit_id = db.origin_visit_add(origin_url, date, type, cur) - if self.journal_writer: - # We can write to the journal only after inserting to the - # DB, because we want the id of the visit - self.journal_writer.write_addition('origin_visit', { - 'origin': origin_url, 'date': date, 'type': type, - 'visit': visit_id, - 'status': 'ongoing', 'metadata': None, 'snapshot': None}) + # We can write to the journal only after inserting to the + # DB, because we want the id of the visit + visit = { + 'origin': origin_url, + 'date': date, + 'type': type, + 'visit': visit_id, + 'status': 'ongoing', + 'metadata': None, + 'snapshot': None + } + self.journal_writer.origin_visit_add(visit, date, type) send_metric('origin_visit:add', count=1, method_name='origin_visit') return { @@ -921,9 +896,9 @@ updates['snapshot'] = snapshot if updates: - if self.journal_writer: - self.journal_writer.write_update('origin_visit', { - **visit, **updates}) + updated_visit = {**visit, **updates} + self.journal_writer.origin_visit_update( + updated_visit, visit_id, status, metadata, snapshot) with convert_validation_exceptions(): db.origin_visit_update(origin_url, visit_id, updates, cur) @@ -940,9 +915,7 @@ "visit['origin'] must be a string, not %r" % (visit['origin'],)) - if self.journal_writer: - for visit in visits: - self.journal_writer.write_addition('origin_visit', visit) + self.journal_writer.origin_visit_upsert(visits) for visit in visits: # TODO: upsert them all in a single query @@ -1106,8 +1079,7 @@ if origin_url: return origin_url - if self.journal_writer: - self.journal_writer.write_addition('origin', origin) + self.journal_writer.origin_add_one(origin) origins = db.origin_add(origin['url'], cur) send_metric('origin:add', count=len(origins), method_name='origin_add') diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py --- a/swh/storage/tests/test_api_client.py +++ b/swh/storage/tests/test_api_client.py @@ -1,4 +1,4 @@ -# Copyright (C) 2015-2018 The Software Heritage developers +# Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -60,6 +60,7 @@ class TestStorage(_TestStorage): def test_content_update(self, swh_storage, app_server): - swh_storage.journal_writer = None # TODO, journal_writer not supported - with patch.object(server.storage, 'journal_writer', None): + # TODO, journal_writer not supported + swh_storage.journal_writer.journal = None + with patch.object(server.storage.journal_writer, 'journal', None): super().test_content_update(swh_storage) diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -149,7 +149,7 @@ expected_cont = data.cont del expected_cont['data'] - journal_objects = list(swh_storage.journal_writer.objects) + journal_objects = list(swh_storage.journal_writer.journal.objects) for (obj_type, obj) in journal_objects: assert insertion_start_time <= obj['ctime'] assert obj['ctime'] <= insertion_end_time @@ -239,14 +239,14 @@ 'content:add': 1, 'content:add:bytes': data.cont['length'], } - assert len(swh_storage.journal_writer.objects) == 1 + assert len(swh_storage.journal_writer.journal.objects) == 1 actual_result = 
swh_storage.content_add([data.cont, data.cont2]) assert actual_result == { 'content:add': 1, 'content:add:bytes': data.cont2['length'], } - assert 2 <= len(swh_storage.journal_writer.objects) <= 3 + assert 2 <= len(swh_storage.journal_writer.journal.objects) <= 3 assert len(swh_storage.content_find(data.cont)) == 1 assert len(swh_storage.content_find(data.cont2)) == 1 @@ -266,7 +266,7 @@ assert cm.value.args[0] in ['sha1', 'sha1_git', 'blake2s256'] def test_content_update(self, swh_storage): - swh_storage.journal_writer = None # TODO, not supported + swh_storage.journal_writer.journal = None # TODO, not supported cont = copy.deepcopy(data.cont) @@ -297,7 +297,8 @@ cont['sha1']: [expected_cont] } - assert list(swh_storage.journal_writer.objects) == [('content', cont)] + assert list(swh_storage.journal_writer.journal.objects) == [ + ('content', cont)] def test_content_add_metadata_different_input(self, swh_storage): cont = data.cont @@ -548,7 +549,7 @@ actual_result = swh_storage.directory_add([data.dir]) assert actual_result == {'directory:add': 1} - assert list(swh_storage.journal_writer.objects) == \ + assert list(swh_storage.journal_writer.journal.objects) == \ [('directory', data.dir)] actual_data = list(swh_storage.directory_ls(data.dir['id'])) @@ -570,7 +571,7 @@ actual_result = swh_storage.directory_add(directories=_dir_gen()) assert actual_result == {'directory:add': 1} - assert list(swh_storage.journal_writer.objects) == \ + assert list(swh_storage.journal_writer.journal.objects) == \ [('directory', data.dir)] swh_storage.refresh_stat_counters() @@ -596,13 +597,13 @@ actual_result = swh_storage.directory_add([data.dir]) assert actual_result == {'directory:add': 1} - assert list(swh_storage.journal_writer.objects) \ + assert list(swh_storage.journal_writer.journal.objects) \ == [('directory', data.dir)] actual_result = swh_storage.directory_add([data.dir]) assert actual_result == {'directory:add': 0} - assert list(swh_storage.journal_writer.objects) \ + assert list(swh_storage.journal_writer.journal.objects) \ == [('directory', data.dir)] def test_directory_get_recursive(self, swh_storage): @@ -613,7 +614,7 @@ [data.dir, data.dir2, data.dir3]) assert actual_result == {'directory:add': 3} - assert list(swh_storage.journal_writer.objects) == [ + assert list(swh_storage.journal_writer.journal.objects) == [ ('directory', data.dir), ('directory', data.dir2), ('directory', data.dir3)] @@ -650,7 +651,7 @@ [data.dir, data.dir2, data.dir3]) assert actual_result == {'directory:add': 3} - assert list(swh_storage.journal_writer.objects) == [ + assert list(swh_storage.journal_writer.journal.objects) == [ ('directory', data.dir), ('directory', data.dir2), ('directory', data.dir3)] @@ -762,7 +763,7 @@ end_missing = swh_storage.revision_missing([data.revision['id']]) assert list(end_missing) == [] - assert list(swh_storage.journal_writer.objects) \ + assert list(swh_storage.journal_writer.journal.objects) \ == [('revision', data.revision)] # already there so nothing added @@ -817,14 +818,14 @@ actual_result = swh_storage.revision_add([data.revision]) assert actual_result == {'revision:add': 1} - assert list(swh_storage.journal_writer.objects) \ + assert list(swh_storage.journal_writer.journal.objects) \ == [('revision', data.revision)] actual_result = swh_storage.revision_add( [data.revision, data.revision2]) assert actual_result == {'revision:add': 1} - assert list(swh_storage.journal_writer.objects) \ + assert list(swh_storage.journal_writer.journal.objects) \ == [('revision', data.revision), 
('revision', data.revision2)] @@ -866,7 +867,7 @@ assert actual_results[0] == normalize_entity(data.revision4) assert actual_results[1] == normalize_entity(data.revision3) - assert list(swh_storage.journal_writer.objects) == [ + assert list(swh_storage.journal_writer.journal.objects) == [ ('revision', data.revision3), ('revision', data.revision4)] @@ -960,7 +961,7 @@ data.release2['id']]) assert list(end_missing) == [] - assert list(swh_storage.journal_writer.objects) == [ + assert list(swh_storage.journal_writer.journal.objects) == [ ('release', data.release), ('release', data.release2)] @@ -979,7 +980,7 @@ actual_result = swh_storage.release_add(_rel_gen()) assert actual_result == {'release:add': 2} - assert list(swh_storage.journal_writer.objects) == [ + assert list(swh_storage.journal_writer.journal.objects) == [ ('release', data.release), ('release', data.release2)] @@ -998,7 +999,7 @@ end_missing = swh_storage.release_missing([data.release['id']]) assert list(end_missing) == [] - assert list(swh_storage.journal_writer.objects) \ + assert list(swh_storage.journal_writer.journal.objects) \ == [('release', release)] def test_release_add_validation(self, swh_storage): @@ -1025,13 +1026,13 @@ actual_result = swh_storage.release_add([data.release]) assert actual_result == {'release:add': 1} - assert list(swh_storage.journal_writer.objects) \ + assert list(swh_storage.journal_writer.journal.objects) \ == [('release', data.release)] actual_result = swh_storage.release_add([data.release, data.release2]) assert actual_result == {'release:add': 1} - assert list(swh_storage.journal_writer.objects) \ + assert list(swh_storage.journal_writer.journal.objects) \ == [('release', data.release), ('release', data.release2)] @@ -1113,7 +1114,7 @@ del actual_origin['id'] del actual_origin2['id'] - assert list(swh_storage.journal_writer.objects) \ + assert list(swh_storage.journal_writer.journal.objects) \ == [('origin', actual_origin), ('origin', actual_origin2)] @@ -1141,7 +1142,7 @@ del actual_origin['id'] del actual_origin2['id'] - assert list(swh_storage.journal_writer.objects) \ + assert list(swh_storage.journal_writer.journal.objects) \ == [('origin', actual_origin), ('origin', actual_origin2)] @@ -1150,12 +1151,12 @@ def test_origin_add_twice(self, swh_storage): add1 = swh_storage.origin_add([data.origin, data.origin2]) - assert list(swh_storage.journal_writer.objects) \ + assert list(swh_storage.journal_writer.journal.objects) \ == [('origin', data.origin), ('origin', data.origin2)] add2 = swh_storage.origin_add([data.origin, data.origin2]) - assert list(swh_storage.journal_writer.objects) \ + assert list(swh_storage.journal_writer.journal.objects) \ == [('origin', data.origin), ('origin', data.origin2)] @@ -1411,7 +1412,7 @@ 'metadata': None, 'snapshot': None, } - objects = list(swh_storage.journal_writer.objects) + objects = list(swh_storage.journal_writer.journal.objects) assert ('origin', data.origin2) in objects assert ('origin_visit', origin_visit) in objects @@ -1474,7 +1475,7 @@ for visit in expected_visits: assert visit in actual_origin_visits - objects = list(swh_storage.journal_writer.objects) + objects = list(swh_storage.journal_writer.journal.objects) assert ('origin', data.origin2) in objects for visit in expected_visits: @@ -1647,7 +1648,7 @@ 'metadata': None, 'snapshot': None, } - objects = list(swh_storage.journal_writer.objects) + objects = list(swh_storage.journal_writer.journal.objects) assert ('origin', data.origin) in objects assert ('origin', data.origin2) in objects assert 
('origin_visit', data1) in objects @@ -1868,7 +1869,7 @@ 'metadata': None, 'snapshot': None, } - assert list(swh_storage.journal_writer.objects) == [ + assert list(swh_storage.journal_writer.journal.objects) == [ ('origin', data.origin2), ('origin_visit', data1), ('origin_visit', data2)] @@ -1929,7 +1930,7 @@ 'metadata': None, 'snapshot': None, } - assert list(swh_storage.journal_writer.objects) == [ + assert list(swh_storage.journal_writer.journal.objects) == [ ('origin', data.origin2), ('origin_visit', data1), ('origin_visit', data2)] @@ -2106,7 +2107,7 @@ 'metadata': None, 'snapshot': data.empty_snapshot['id'], } - assert list(swh_storage.journal_writer.objects) == \ + assert list(swh_storage.journal_writer.journal.objects) == \ [('origin', data.origin), ('origin_visit', data1), ('snapshot', data.empty_snapshot), @@ -2176,13 +2177,13 @@ actual_result = swh_storage.snapshot_add([data.empty_snapshot]) assert actual_result == {'snapshot:add': 1} - assert list(swh_storage.journal_writer.objects) \ + assert list(swh_storage.journal_writer.journal.objects) \ == [('snapshot', data.empty_snapshot)] actual_result = swh_storage.snapshot_add([data.snapshot]) assert actual_result == {'snapshot:add': 1} - assert list(swh_storage.journal_writer.objects) \ + assert list(swh_storage.journal_writer.journal.objects) \ == [('snapshot', data.empty_snapshot), ('snapshot', data.snapshot)] @@ -2418,7 +2419,7 @@ swh_storage.origin_add_one(data.origin) visit_id = 54164461156 - swh_storage.journal_writer.objects[:] = [] + swh_storage.journal_writer.journal.objects[:] = [] swh_storage.snapshot_add([data.snapshot]) @@ -2426,7 +2427,7 @@ swh_storage.origin_visit_update( origin_url, visit_id, snapshot=data.snapshot['id']) - assert list(swh_storage.journal_writer.objects) == [ + assert list(swh_storage.journal_writer.journal.objects) == [ ('snapshot', data.snapshot)] def test_snapshot_add_twice__by_origin_visit(self, swh_storage): @@ -2497,7 +2498,7 @@ 'metadata': None, 'snapshot': data.snapshot['id'], } - assert list(swh_storage.journal_writer.objects) \ + assert list(swh_storage.journal_writer.journal.objects) \ == [('origin', data.origin), ('origin_visit', data1), ('snapshot', data.snapshot), @@ -3684,7 +3685,7 @@ """ def test_content_update_with_new_cols(self, swh_storage): - swh_storage.journal_writer = None # TODO, not supported + swh_storage.journal_writer.journal = None # TODO, not supported with db_transaction(swh_storage) as (_, cur): cur.execute("""alter table content @@ -3738,7 +3739,7 @@ expected_cont = cont.copy() del expected_cont['data'] - journal_objects = list(swh_storage.journal_writer.objects) + journal_objects = list(swh_storage.journal_writer.journal.objects) for (obj_type, obj) in journal_objects: del obj['ctime'] assert journal_objects == [('content', expected_cont)] @@ -3764,7 +3765,8 @@ assert datum == (cont['sha1'], cont['sha1_git'], cont['sha256'], cont['length'], 'visible') - assert list(swh_storage.journal_writer.objects) == [('content', cont)] + assert list(swh_storage.journal_writer.journal.objects) == [ + ('content', cont)] def test_skipped_content_add_db(self, swh_storage): cont = data.skipped_cont diff --git a/swh/storage/writer.py b/swh/storage/writer.py new file mode 100644 --- /dev/null +++ b/swh/storage/writer.py @@ -0,0 +1,94 @@ +# Copyright (C) 2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + 
+try:
+    from swh.journal.writer import get_journal_writer
+except ImportError:
+    get_journal_writer = None  # type: ignore
+    # mypy limitation, see https://github.com/python/mypy/issues/1153
+
+
+class JournalWriter:
+    """Journal writer storage collaborator. It's in charge of adding objects to
+    the journal.
+
+    """
+    def __init__(self, journal_writer):
+        if journal_writer:
+            if get_journal_writer is None:
+                raise EnvironmentError(
+                    'You need the swh.journal package to use the '
+                    'journal_writer feature')
+            self.journal = get_journal_writer(**journal_writer)
+        else:
+            self.journal = None
+
+    def content_add(self, contents):
+        if not self.journal:
+            return
+        for item in contents:
+            if 'data' in item:
+                item = item.copy()
+                del item['data']
+            self.journal.write_addition('content', item)
+
+    def content_update(self, content, keys=[]):
+        if not self.journal:
+            return
+        raise NotImplementedError(
+            'content_update is not yet supported with a journal writer.')
+
+    def content_add_metadata(self, contents):
+        return self.content_add(contents)
+
+    def skipped_content_add(self, contents):
+        if not self.journal:
+            return
+        for item in contents:
+            self.journal.write_addition('content', item)
+
+    def directory_add(self, directories):
+        if not self.journal:
+            return
+        self.journal.write_additions('directory', directories)
+
+    def revision_add(self, revisions):
+        if not self.journal:
+            return
+        self.journal.write_additions('revision', revisions)
+
+    def release_add(self, releases):
+        if not self.journal:
+            return
+        self.journal.write_additions('release', releases)
+
+    def snapshot_add(self, snapshots):
+        if not self.journal:
+            return
+        snaps = snapshots if isinstance(snapshots, list) else [snapshots]
+        for snapshot in snaps:
+            self.journal.write_addition('snapshot', snapshot)
+
+    def origin_visit_add(self, origin, date, type):
+        if not self.journal:
+            return
+        self.journal.write_addition('origin_visit', origin)
+
+    def origin_visit_update(self, origin, visit_id, status=None,
+                            metadata=None, snapshot=None):
+        if not self.journal:
+            return
+        self.journal.write_update('origin_visit', origin)
+
+    def origin_visit_upsert(self, visits):
+        if not self.journal:
+            return
+        for visit in visits:
+            self.journal.write_addition('origin_visit', visit)
+
+    def origin_add_one(self, origin):
+        if not self.journal:
+            return
+        self.journal.write_addition('origin', origin)
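
Usage note (not part of the patch): the minimal sketch below illustrates how the new
JournalWriter collaborator behaves with and without a configured backend. It assumes
the optional swh.journal package is installed and that its in-memory writer is
selected with {'cls': 'memory'}, which is the configuration the updated tests appear
to rely on when they read journal_writer.journal.objects; with no configuration every
method is a no-op, which is what lets the storage backends drop their
`if self.journal_writer:` guards.

    from swh.storage.writer import JournalWriter

    # Journal disabled: every write method returns immediately.
    noop_writer = JournalWriter(None)
    noop_writer.origin_add_one({'url': 'https://example.org/repo.git'})
    assert noop_writer.journal is None

    # Journal enabled, using the in-memory backend (assumed available from
    # swh.journal) so the written objects can be inspected directly.
    writer = JournalWriter({'cls': 'memory'})
    writer.origin_add_one({'url': 'https://example.org/repo.git'})
    writer.content_add([{'sha1': b'\x00' * 20, 'length': 3, 'data': b'foo'}])

    # content_add strips the 'data' key from a copy before writing, mirroring
    # what each backend previously did inline before this refactoring.
    for obj_type, obj in writer.journal.objects:
        assert 'data' not in obj

The example dicts are illustrative only; real callers pass the same content, origin
and visit dicts they previously handed to write_addition/write_additions.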