diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -19,12 +19,7 @@
 )
 from swh.objstorage import get_objstorage
 from swh.objstorage.exc import ObjNotFoundError
-try:
-    from swh.journal.writer import get_journal_writer
-except ImportError:
-    get_journal_writer = None  # type: ignore
-    # mypy limitation, see https://github.com/python/mypy/issues/1153
-
+from swh.storage.writer import JournalWriter
 
 from .common import TOKEN_BEGIN, TOKEN_END
 from .converters import (
@@ -48,11 +43,7 @@
         self._cql_runner = CqlRunner(hosts, keyspace, port)
 
         self.objstorage = get_objstorage(**objstorage)
-
-        if journal_writer:
-            self.journal_writer = get_journal_writer(**journal_writer)
-        else:
-            self.journal_writer = None
+        self.writer = JournalWriter(journal_writer)
 
     def check_config(self, *, check_write):
         self._cql_runner.check_read()
@@ -66,12 +57,7 @@
         contents = [c for c in contents
                     if not self._cql_runner.content_get_from_pk(c.to_dict())]
 
-        if self.journal_writer:
-            for content in contents:
-                content = content.to_dict()
-                if 'data' in content:
-                    del content['data']
-                self.journal_writer.write_addition('content', content)
+        self.writer.content_add(c.to_dict() for c in contents)
 
         count_contents = 0
         count_content_added = 0
@@ -267,12 +253,7 @@
             c for c in contents
             if not self._cql_runner.skipped_content_get_from_pk(c.to_dict())]
 
-        if self.journal_writer:
-            for content in contents:
-                content = content.to_dict()
-                if 'data' in content:
-                    del content['data']
-                self.journal_writer.write_addition('content', content)
+        self.writer.skipped_content_add(c.to_dict() for c in contents)
 
         for content in contents:
             # Add to index tables
@@ -306,8 +287,7 @@
         missing = self.directory_missing([dir_['id'] for dir_ in directories])
         directories = [dir_ for dir_ in directories if dir_['id'] in missing]
 
-        if self.journal_writer:
-            self.journal_writer.write_additions('directory', directories)
+        self.writer.directory_add(directories)
 
         for directory in directories:
             directory = Directory.from_dict(directory)
@@ -412,8 +392,7 @@
         missing = self.revision_missing([rev['id'] for rev in revisions])
         revisions = [rev for rev in revisions if rev['id'] in missing]
 
-        if self.journal_writer:
-            self.journal_writer.write_additions('revision', revisions)
+        self.writer.revision_add(revisions)
 
         for revision in revisions:
             revision = revision_to_db(revision)
@@ -507,8 +486,7 @@
         missing = self.release_missing([rel['id'] for rel in releases])
         releases = [rel for rel in releases if rel['id'] in missing]
 
-        if self.journal_writer:
-            self.journal_writer.write_additions('release', releases)
+        self.writer.release_add(releases)
 
         for release in releases:
             release = release_to_db(release)
@@ -542,8 +520,7 @@
         snapshots = [snp for snp in snapshots if snp['id'] in missing]
 
         for snapshot in snapshots:
-            if self.journal_writer:
-                self.journal_writer.write_addition('snapshot', snapshot)
+            self.writer.snapshot_add(snapshot)
 
             # Add branches
             for (branch_name, branch) in snapshot['branches'].items():
@@ -785,8 +762,7 @@
         if known_origin:
             origin_url = known_origin['url']
         else:
-            if self.journal_writer:
-                self.journal_writer.write_addition('origin', origin)
+            self.writer.origin_add_one(origin)
 
             self._cql_runner.origin_add_one(origin)
             origin_url = origin['url']
@@ -815,9 +791,7 @@
             'metadata': None,
             'visit': visit_id
         }
-
-        if self.journal_writer:
-            self.journal_writer.write_addition('origin_visit', visit)
+        self.writer.origin_visit_add(visit, date, type)
 
         self._cql_runner.origin_visit_add_one(visit)
@@ -846,8 +820,8 @@
 
         visit = attr.evolve(visit, **updates)
 
-        if self.journal_writer:
-            self.journal_writer.write_update('origin_visit', visit)
+        self.writer.origin_visit_update(
+            visit.to_dict(), visit_id, status, metadata, snapshot)
 
         self._cql_runner.origin_visit_update(origin_url, visit_id, updates)
@@ -857,9 +831,7 @@
             if isinstance(visit['date'], str):
                 visit['date'] = dateutil.parser.parse(visit['date'])
 
-        if self.journal_writer:
-            for visit in visits:
-                self.journal_writer.write_addition('origin_visit', visit)
+        self.writer.origin_visit_upsert(visits)
 
         for visit in visits:
             visit = visit.copy()
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -25,9 +25,9 @@
 from swh.objstorage import get_objstorage
 from swh.objstorage.exc import ObjNotFoundError
 
-from .storage import get_journal_writer
 from .converters import origin_url_to_sha1
 from .utils import get_partition_bounds_bytes
+from .writer import JournalWriter
 
 # Max block size of contents to return
 BULK_BLOCK_CONTENT_LEN_MAX = 10000
@@ -45,11 +45,7 @@
         self._skipped_content_indexes = defaultdict(lambda: defaultdict(set))
 
         self.reset()
-
-        if journal_writer:
-            self.journal_writer = get_journal_writer(**journal_writer)
-        else:
-            self.journal_writer = None
+        self.writer = JournalWriter(journal_writer)
 
     def reset(self):
         self._directories = {}
@@ -83,10 +79,8 @@
             if content.length is None:
                 raise ValueError('content with length=None')
 
-        if self.journal_writer:
-            for content in contents:
-                content = attr.evolve(content, data=None)
-                self.journal_writer.write_addition('content', content)
+        self.writer.content_add(
+            attr.evolve(c, data=None).to_dict() for c in contents)
 
         summary = {
             'content:add': 0,
@@ -130,9 +124,7 @@
         return self._content_add(content, with_data=True)
 
     def content_update(self, content, keys=[]):
-        if self.journal_writer:
-            raise NotImplementedError(
-                'content_update is not yet supported with a journal_writer.')
+        self.writer.content_update(content, keys=keys)
 
         for cont_update in content:
             cont_update = cont_update.copy()
@@ -280,9 +272,7 @@
             if content.status != 'absent':
                 raise ValueError(f'Content with status={content.status}')
 
-        if self.journal_writer:
-            for content in contents:
-                self.journal_writer.write_addition('content', content)
+        self.writer.skipped_content_add(contents)
 
         summary = {
             'skipped_content:add': 0
         }
@@ -323,11 +313,10 @@
     def directory_add(self, directories):
         directories = list(directories)
 
-        if self.journal_writer:
-            self.journal_writer.write_additions(
-                'directory',
-                (dir_ for dir_ in directories
-                 if dir_['id'] not in self._directories))
+        self.writer.directory_add(
+            dir_ for dir_ in directories
+            if dir_['id'] not in self._directories
+        )
 
         directories = [Directory.from_dict(d) for d in directories]
@@ -416,11 +405,10 @@
    def revision_add(self, revisions):
        revisions = list(revisions)

-        if self.journal_writer:
-            self.journal_writer.write_additions(
-                'revision',
-                (rev for rev in revisions
-                 if rev['id'] not in self._revisions))
+        self.writer.revision_add(
+            rev for rev in revisions
+            if rev['id'] not in self._revisions
+        )
 
         revisions = [Revision.from_dict(rev) for rev in revisions]
@@ -474,11 +462,10 @@
     def release_add(self, releases):
         releases = list(releases)
 
-        if self.journal_writer:
-            self.journal_writer.write_additions(
-                'release',
-                (rel for rel in releases
-                 if rel['id'] not in self._releases))
+        self.writer.release_add(
+            rel for rel in releases
+            if rel['id'] not in self._releases
+        )
 
         releases = [Release.from_dict(rel) for rel in releases]
@@ -513,9 +500,7 @@
         snapshots = (snap for snap in snapshots
                      if snap.id not in self._snapshots)
         for snapshot in snapshots:
-            if self.journal_writer:
-                self.journal_writer.write_addition('snapshot', snapshot)
-
+            self.writer.snapshot_add(snapshot.to_dict())
             sorted_branch_names = sorted(snapshot.branches)
             self._snapshots[snapshot.id] = (snapshot, sorted_branch_names)
             self._objects[snapshot.id].append(('snapshot', snapshot.id))
@@ -728,9 +713,7 @@
     def origin_add_one(self, origin):
         origin = Origin.from_dict(origin)
         if origin.url not in self._origins:
-            if self.journal_writer:
-                self.journal_writer.write_addition('origin', origin)
-
+            self.writer.origin_add_one(origin.to_dict())
             # generate an origin_id because it is needed by origin_get_range.
             # TODO: remove this when we remove origin_get_range
             origin_id = len(self._origins) + 1
@@ -779,8 +762,7 @@
         self._objects[(origin_url, visit_id)].append(
             ('origin_visit', None))
 
-        if self.journal_writer:
-            self.journal_writer.write_addition('origin_visit', visit)
+        self.writer.origin_visit_add(visit.to_dict(), date, type)
 
         return visit_ret
@@ -808,8 +790,8 @@
 
         visit = attr.evolve(visit, **updates)
 
-        if self.journal_writer:
-            self.journal_writer.write_update('origin_visit', visit)
+        self.writer.origin_visit_update(
+            visit.to_dict(), visit_id, status, metadata, snapshot)
 
         self._origin_visits[origin_url][visit_id-1] = visit
@@ -820,9 +802,7 @@
                             % (visit['origin'],))
         visits = [OriginVisit.from_dict(d) for d in visits]
 
-        if self.journal_writer:
-            for visit in visits:
-                self.journal_writer.write_addition('origin_visit', visit)
+        self.writer.origin_visit_upsert(v.to_dict() for v in visits)
 
         for visit in visits:
             visit_id = visit.visit
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -21,11 +21,6 @@
 from swh.model.hashutil import ALGORITHMS, hash_to_bytes, hash_to_hex
 from swh.objstorage import get_objstorage
 from swh.objstorage.exc import ObjNotFoundError
-try:
-    from swh.journal.writer import get_journal_writer
-except ImportError:
-    get_journal_writer = None  # type: ignore
-    # mypy limitation, see https://github.com/python/mypy/issues/1153
 
 from . import converters
 from .common import db_transaction_generator, db_transaction
@@ -34,6 +29,7 @@
 from .algos import diff
 from .metrics import timed, send_metric, process_metrics
 from .utils import get_partition_bounds_bytes
+from .writer import JournalWriter
 
 
 # Max block size of contents to return
@@ -69,14 +65,7 @@
             raise StorageDBError(e)
 
         self.objstorage = get_objstorage(**objstorage)
-        if journal_writer:
-            if get_journal_writer is None:
-                raise EnvironmentError(
-                    'You need the swh.journal package to use the '
-                    'journal_writer feature')
-            self.journal_writer = get_journal_writer(**journal_writer)
-        else:
-            self.journal_writer = None
+        self.writer = JournalWriter(journal_writer)
 
     def get_db(self):
         if self._db:
@@ -193,12 +182,7 @@
         missing = list(self.content_missing(content, key_hash='sha1_git'))
         content = [c for c in content if c['sha1_git'] in missing]
 
-        if self.journal_writer:
-            for item in content:
-                if 'data' in item:
-                    item = item.copy()
-                    del item['data']
-                self.journal_writer.write_addition('content', item)
+        self.writer.content_add(content)
 
         def add_to_objstorage():
             """Add to objstorage the new missing_content
@@ -240,10 +224,7 @@
     def content_update(self, content, keys=[], db=None, cur=None):
         # TODO: Add a check on input keys. How to properly implement
         # this? We don't know yet the new columns.
-
-        if self.journal_writer:
-            raise NotImplementedError(
-                'content_update is not yet supported with a journal_writer.')
+        self.writer.content_update(content, keys=keys)
 
         db.mktemp('content', cur)
         select_keys = list(set(db.content_get_metadata_keys).union(set(keys)))
@@ -262,11 +243,7 @@
         missing = self.content_missing(content, key_hash='sha1_git')
         content = [c for c in content if c['sha1_git'] in missing]
 
-        if self.journal_writer:
-            for item in itertools.chain(content):
-                assert 'data' not in content
-                self.journal_writer.write_addition('content', item)
-
+        self.writer.content_add_metadata(content)
         self._content_add_metadata(db, cur, content)
 
         return {
@@ -451,10 +428,7 @@
                         for algo in ALGORITHMS)
                 for missing_content in missing_contents)]
 
-        if self.journal_writer:
-            for item in content:
-                self.journal_writer.write_addition('content', item)
-
+        self.writer.skipped_content_add(content)
         self._skipped_content_add_metadata(db, cur, content)
 
         return {
@@ -497,11 +471,10 @@
         if not dirs_missing:
             return summary
 
-        if self.journal_writer:
-            self.journal_writer.write_additions(
-                'directory',
-                (dir_ for dir_ in directories
-                 if dir_['id'] in dirs_missing))
+        self.writer.directory_add(
+            dir_ for dir_ in directories
+            if dir_['id'] in dirs_missing
+        )
 
         # Copy directory ids
         dirs_missing_dict = ({'id': dir} for dir in dirs_missing)
@@ -580,8 +553,7 @@
             revision for revision in revisions
             if revision['id'] in revisions_missing]
 
-        if self.journal_writer:
-            self.journal_writer.write_additions('revision', revisions_filtered)
+        self.writer.revision_add(revisions_filtered)
 
         revisions_filtered = map(converters.revision_to_db, revisions_filtered)
@@ -666,8 +638,7 @@
             if release['id'] in releases_missing
         ]
 
-        if self.journal_writer:
-            self.journal_writer.write_additions('release', releases_filtered)
+        self.writer.release_add(releases_filtered)
 
         releases_filtered = map(converters.release_to_db, releases_filtered)
@@ -729,8 +700,7 @@
                     cur,
                 )
 
-                if self.journal_writer:
-                    self.journal_writer.write_addition('snapshot', snapshot)
+                self.writer.snapshot_add(snapshot)
 
                 db.snapshot_add(snapshot['id'], cur)
                 count += 1
@@ -844,13 +814,18 @@
 
         visit_id = db.origin_visit_add(origin_url, date, type, cur)
 
-        if self.journal_writer:
-            # We can write to the journal only after inserting to the
-            # DB, because we want the id of the visit
-            self.journal_writer.write_addition('origin_visit', {
-                'origin': origin_url, 'date': date, 'type': type,
-                'visit': visit_id,
-                'status': 'ongoing', 'metadata': None, 'snapshot': None})
+        # We can write to the journal only after inserting to the
+        # DB, because we want the id of the visit
+        visit = {
+            'origin': origin_url,
+            'date': date,
+            'type': type,
+            'visit': visit_id,
+            'status': 'ongoing',
+            'metadata': None,
+            'snapshot': None
+        }
+        self.writer.origin_visit_add(visit, date, type)
 
         send_metric('origin_visit:add', count=1, method_name='origin_visit')
         return {
@@ -882,9 +857,9 @@
             updates['snapshot'] = snapshot
 
         if updates:
-            if self.journal_writer:
-                self.journal_writer.write_update('origin_visit', {
-                    **visit, **updates})
+            updated_visit = {**visit, **updates}
+            self.writer.origin_visit_update(
+                updated_visit, visit_id, status, metadata, snapshot)
 
             db.origin_visit_update(origin_url, visit_id, updates, cur)
@@ -899,9 +874,7 @@
                 raise TypeError("visit['origin'] must be a string, not %r"
                                 % (visit['origin'],))
 
-        if self.journal_writer:
-            for visit in visits:
-                self.journal_writer.write_addition('origin_visit', visit)
+        self.writer.origin_visit_upsert(visits)
 
         for visit in visits:
             # TODO: upsert them all in a single query
@@ -1063,8 +1036,7 @@
         if origin_url:
             return origin_url
 
-        if self.journal_writer:
-            self.journal_writer.write_addition('origin', origin)
+        self.writer.origin_add_one(origin)
 
         origins = db.origin_add(origin['url'], cur)
         send_metric('origin:add', count=len(origins), method_name='origin_add')
diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py
--- a/swh/storage/tests/test_api_client.py
+++ b/swh/storage/tests/test_api_client.py
@@ -52,14 +52,15 @@
     # in-memory backend storage is attached to the RemoteStorage as its
     # journal_writer attribute.
     storage = swh_rpc_client
-    journal_writer = getattr(storage, 'journal_writer', None)
-    storage.journal_writer = app_server.storage.journal_writer
+    journal_writer = getattr(storage, 'writer', None)
+    storage.writer = app_server.storage.writer
     yield storage
-    storage.journal_writer = journal_writer
+    storage.writer = journal_writer
 
 
 class TestStorage(_TestStorage):
     def test_content_update(self, swh_storage, app_server):
-        swh_storage.journal_writer = None  # TODO, journal_writer not supported
-        with patch.object(server.storage, 'journal_writer', None):
+        # TODO, journal_writer not supported
+        swh_storage.writer.journal = None
+        with patch.object(server.storage.writer, 'journal', None):
             super().test_content_update(swh_storage)
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -148,7 +148,7 @@
         expected_cont = data.cont
         del expected_cont['data']
 
-        journal_objects = list(swh_storage.journal_writer.objects)
+        journal_objects = list(swh_storage.writer.journal.objects)
         for (obj_type, obj) in journal_objects:
             assert insertion_start_time <= obj['ctime']
             assert obj['ctime'] <= insertion_end_time
@@ -241,14 +241,14 @@
             'content:add': 1,
             'content:add:bytes': data.cont['length'],
         }
-        assert len(swh_storage.journal_writer.objects) == 1
+        assert len(swh_storage.writer.journal.objects) == 1
 
         actual_result = swh_storage.content_add([data.cont, data.cont2])
         assert actual_result == {
             'content:add': 1,
             'content:add:bytes': data.cont2['length'],
         }
-        assert 2 <= len(swh_storage.journal_writer.objects) <= 3
+        assert 2 <= len(swh_storage.writer.journal.objects) <= 3
 
         assert len(swh_storage.content_find(data.cont)) == 1
         assert len(swh_storage.content_find(data.cont2)) == 1
@@ -268,7 +268,7 @@
             assert cm.value.args[0] in ['sha1', 'sha1_git', 'blake2s256']
 
     def test_content_update(self, swh_storage):
-        swh_storage.journal_writer = None  # TODO, not supported
+        swh_storage.writer.journal = None  # TODO, not supported
 
         cont = copy.deepcopy(data.cont)
@@ -299,7 +299,7 @@
             cont['sha1']: [expected_cont]
         }
 
-        assert list(swh_storage.journal_writer.objects) == [('content', cont)]
+        assert list(swh_storage.writer.journal.objects) == [('content', cont)]
 
     def test_content_add_metadata_different_input(self, swh_storage):
         cont = data.cont
@@ -550,7 +550,7 @@
         actual_result = swh_storage.directory_add([data.dir])
         assert actual_result == {'directory:add': 1}
 
-        assert list(swh_storage.journal_writer.objects) == \
+        assert list(swh_storage.writer.journal.objects) == \
             [('directory', data.dir)]
 
         actual_data = list(swh_storage.directory_ls(data.dir['id']))
@@ -572,7 +572,7 @@
         actual_result = swh_storage.directory_add(directories=_dir_gen())
         assert actual_result == {'directory:add': 1}
 
-        assert list(swh_storage.journal_writer.objects) == \
+        assert list(swh_storage.writer.journal.objects) == \
             [('directory', data.dir)]
 
         swh_storage.refresh_stat_counters()
@@ -599,13 +599,13 @@
         actual_result = swh_storage.directory_add([data.dir])
         assert actual_result == {'directory:add': 1}
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.writer.journal.objects) \
             == [('directory', data.dir)]
 
         actual_result = swh_storage.directory_add([data.dir])
         assert actual_result == {'directory:add': 0}
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.writer.journal.objects) \
             == [('directory', data.dir)]
 
     def test_directory_get_recursive(self, swh_storage):
@@ -616,7 +616,7 @@
             [data.dir, data.dir2, data.dir3])
         assert actual_result == {'directory:add': 3}
 
-        assert list(swh_storage.journal_writer.objects) == [
+        assert list(swh_storage.writer.journal.objects) == [
             ('directory', data.dir),
             ('directory', data.dir2),
             ('directory', data.dir3)]
@@ -653,7 +653,7 @@
             [data.dir, data.dir2, data.dir3])
         assert actual_result == {'directory:add': 3}
 
-        assert list(swh_storage.journal_writer.objects) == [
+        assert list(swh_storage.writer.journal.objects) == [
             ('directory', data.dir),
             ('directory', data.dir2),
             ('directory', data.dir3)]
@@ -765,7 +765,7 @@
         end_missing = swh_storage.revision_missing([data.revision['id']])
         assert list(end_missing) == []
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.writer.journal.objects) \
             == [('revision', data.revision)]
 
         # already there so nothing added
@@ -823,14 +823,14 @@
         actual_result = swh_storage.revision_add([data.revision])
         assert actual_result == {'revision:add': 1}
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.writer.journal.objects) \
             == [('revision', data.revision)]
 
         actual_result = swh_storage.revision_add(
             [data.revision, data.revision2])
         assert actual_result == {'revision:add': 1}
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.writer.journal.objects) \
             == [('revision', data.revision),
                 ('revision', data.revision2)]
@@ -872,7 +872,7 @@
         assert actual_results[0] == normalize_entity(data.revision4)
         assert actual_results[1] == normalize_entity(data.revision3)
 
-        assert list(swh_storage.journal_writer.objects) == [
+        assert list(swh_storage.writer.journal.objects) == [
             ('revision', data.revision3),
             ('revision', data.revision4)]
@@ -966,7 +966,7 @@
                                                    data.release2['id']])
         assert list(end_missing) == []
 
-        assert list(swh_storage.journal_writer.objects) == [
+        assert list(swh_storage.writer.journal.objects) == [
             ('release', data.release),
             ('release', data.release2)]
@@ -985,7 +985,7 @@
         actual_result = swh_storage.release_add(_rel_gen())
         assert actual_result == {'release:add': 2}
 
-        assert list(swh_storage.journal_writer.objects) == [
+        assert list(swh_storage.writer.journal.objects) == [
             ('release', data.release),
             ('release', data.release2)]
@@ -1004,7 +1004,7 @@
         end_missing = swh_storage.release_missing([data.release['id']])
         assert list(end_missing) == []
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.writer.journal.objects) \
             == [('release', release)]
 
     def test_release_add_validation(self, swh_storage):
@@ -1033,13 +1033,13 @@
         actual_result = swh_storage.release_add([data.release])
         assert actual_result == {'release:add': 1}
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.writer.journal.objects) \
             == [('release', data.release)]
 
         actual_result = swh_storage.release_add([data.release, data.release2])
         assert actual_result == {'release:add': 1}
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.writer.journal.objects) \
             == [('release', data.release),
                 ('release', data.release2)]
@@ -1121,7 +1121,7 @@
         del actual_origin['id']
         del actual_origin2['id']
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.writer.journal.objects) \
             == [('origin', actual_origin),
                 ('origin', actual_origin2)]
@@ -1149,7 +1149,7 @@
         del actual_origin['id']
         del actual_origin2['id']
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.writer.journal.objects) \
             == [('origin', actual_origin),
                 ('origin', actual_origin2)]
@@ -1158,12 +1158,12 @@
     def test_origin_add_twice(self, swh_storage):
         add1 = swh_storage.origin_add([data.origin, data.origin2])
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.writer.journal.objects) \
             == [('origin', data.origin),
                 ('origin', data.origin2)]
 
         add2 = swh_storage.origin_add([data.origin, data.origin2])
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.writer.journal.objects) \
             == [('origin', data.origin),
                 ('origin', data.origin2)]
@@ -1419,7 +1419,7 @@
             'metadata': None,
             'snapshot': None,
         }
-        objects = list(swh_storage.journal_writer.objects)
+        objects = list(swh_storage.writer.journal.objects)
         assert ('origin', data.origin2) in objects
         assert ('origin_visit', origin_visit) in objects
@@ -1482,7 +1482,7 @@
         for visit in expected_visits:
             assert visit in actual_origin_visits
 
-        objects = list(swh_storage.journal_writer.objects)
+        objects = list(swh_storage.writer.journal.objects)
         assert ('origin', data.origin2) in objects
 
         for visit in expected_visits:
@@ -1655,7 +1655,7 @@
             'metadata': None,
             'snapshot': None,
         }
-        objects = list(swh_storage.journal_writer.objects)
+        objects = list(swh_storage.writer.journal.objects)
         assert ('origin', data.origin) in objects
         assert ('origin', data.origin2) in objects
         assert ('origin_visit', data1) in objects
@@ -1877,7 +1877,7 @@
             'metadata': None,
             'snapshot': None,
         }
-        assert list(swh_storage.journal_writer.objects) == [
+        assert list(swh_storage.writer.journal.objects) == [
             ('origin', data.origin2),
             ('origin_visit', data1),
             ('origin_visit', data2)]
@@ -1938,7 +1938,7 @@
             'metadata': None,
             'snapshot': None,
         }
-        assert list(swh_storage.journal_writer.objects) == [
+        assert list(swh_storage.writer.journal.objects) == [
             ('origin', data.origin2),
             ('origin_visit', data1),
             ('origin_visit', data2)]
@@ -2115,7 +2115,7 @@
             'metadata': None,
             'snapshot': data.empty_snapshot['id'],
         }
-        assert list(swh_storage.journal_writer.objects) == \
+        assert list(swh_storage.writer.journal.objects) == \
             [('origin', data.origin),
              ('origin_visit', data1),
              ('snapshot', data.empty_snapshot),
@@ -2185,13 +2185,13 @@
         actual_result = swh_storage.snapshot_add([data.empty_snapshot])
         assert actual_result == {'snapshot:add': 1}
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.writer.journal.objects) \
             == [('snapshot', data.empty_snapshot)]
 
         actual_result = swh_storage.snapshot_add([data.snapshot])
         assert actual_result == {'snapshot:add': 1}
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.writer.journal.objects) \
             == [('snapshot', data.empty_snapshot),
                 ('snapshot', data.snapshot)]
@@ -2427,7 +2427,7 @@
         swh_storage.origin_add_one(data.origin)
         visit_id = 54164461156
 
-        swh_storage.journal_writer.objects[:] = []
+        swh_storage.writer.journal.objects[:] = []
 
         swh_storage.snapshot_add([data.snapshot])
@@ -2435,7 +2435,7 @@
         swh_storage.origin_visit_update(
             origin_url, visit_id, snapshot=data.snapshot['id'])
 
-        assert list(swh_storage.journal_writer.objects) == [
+        assert list(swh_storage.writer.journal.objects) == [
             ('snapshot', data.snapshot)]
 
     def test_snapshot_add_twice__by_origin_visit(self, swh_storage):
@@ -2506,7 +2506,7 @@
             'metadata': None,
             'snapshot': data.snapshot['id'],
         }
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.writer.journal.objects) \
             == [('origin', data.origin),
                 ('origin_visit', data1),
                 ('snapshot', data.snapshot),
@@ -3691,7 +3691,7 @@
         """
 
     def test_content_update_with_new_cols(self, swh_storage):
-        swh_storage.journal_writer = None  # TODO, not supported
+        swh_storage.writer.journal = None  # TODO, not supported
 
         with db_transaction(swh_storage) as (_, cur):
             cur.execute("""alter table content
@@ -3745,7 +3745,7 @@
         expected_cont = cont.copy()
         del expected_cont['data']
 
-        journal_objects = list(swh_storage.journal_writer.objects)
+        journal_objects = list(swh_storage.writer.journal.objects)
         for (obj_type, obj) in journal_objects:
             del obj['ctime']
         assert journal_objects == [('content', expected_cont)]
@@ -3771,7 +3771,7 @@
         assert datum == (cont['sha1'], cont['sha1_git'], cont['sha256'],
                          cont['length'], 'visible')
 
-        assert list(swh_storage.journal_writer.objects) == [('content', cont)]
+        assert list(swh_storage.writer.journal.objects) == [('content', cont)]
 
     def test_skipped_content_add_db(self, swh_storage):
         cont = data.skipped_cont
diff --git a/swh/storage/writer.py b/swh/storage/writer.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/writer.py
@@ -0,0 +1,94 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+try:
+    from swh.journal.writer import get_journal_writer
+except ImportError:
+    get_journal_writer = None  # type: ignore
+    # mypy limitation, see https://github.com/python/mypy/issues/1153
+
+
+class JournalWriter:
+    """Journal writer storage collaborator. It's in charge of adding objects to
+    the journal.
+ + """ + def __init__(self, journal_writer): + if journal_writer: + if get_journal_writer is None: + raise EnvironmentError( + 'You need the swh.journal package to use the ' + 'journal_writer feature') + self.journal = get_journal_writer(**journal_writer) + else: + self.journal = None + + def content_add(self, contents): + if not self.journal: + return + for item in contents: + if 'data' in item: + item = item.copy() + del item['data'] + self.journal.write_addition('content', item) + + def content_update(self, content, keys=[]): + if not self.journal: + return + raise NotImplementedError( + 'content_update is not yet supported with a journal writer.') + + def content_add_metadata(self, contents): + return self.content_add(contents) + + def skipped_content_add(self, contents): + if not self.journal: + return + for item in contents: + self.journal.write_addition('content', item) + + def directory_add(self, directories): + if not self.journal: + return + self.journal.write_additions('directory', directories) + + def revision_add(self, revisions): + if not self.journal: + return + self.journal.write_additions('revision', revisions) + + def release_add(self, releases): + if not self.journal: + return + self.journal.write_additions('release', releases) + + def snapshot_add(self, snapshots): + if not self.journal: + return + snaps = snapshots if isinstance(snapshots, list) else [snapshots] + for snapshot in snaps: + self.journal.write_addition('snapshot', snapshot) + + def origin_visit_add(self, origin, date, type): + if not self.journal: + return + self.journal.write_addition('origin_visit', origin) + + def origin_visit_update(self, origin, visit_id, status=None, + metadata=None, snapshot=None): + if not self.journal: + return + self.journal.write_update('origin_visit', origin) + + def origin_visit_upsert(self, visits): + if not self.journal: + return + for visit in visits: + self.journal.write_addition('origin_visit', visit) + + def origin_add_one(self, origin): + if not self.journal: + return + self.journal.write_addition('origin', origin)