diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -17,13 +17,8 @@
     Revision, Release, Directory, DirectoryEntry, Content, SkippedContent,
     OriginVisit, Snapshot, Origin
 )
-try:
-    from swh.journal.writer import get_journal_writer
-except ImportError:
-    get_journal_writer = None  # type: ignore
-    # mypy limitation, see https://github.com/python/mypy/issues/1153
-
 from swh.storage.objstorage import ObjStorage
+from swh.storage.writer import JournalWriter
 
 from .. import HashCollision
 from ..exc import StorageArgumentException
@@ -47,11 +42,7 @@
     def __init__(self, hosts, keyspace, objstorage, port=9042,
                  journal_writer=None):
         self._cql_runner = CqlRunner(hosts, keyspace, port)
-
-        if journal_writer:
-            self.journal_writer = get_journal_writer(**journal_writer)
-        else:
-            self.journal_writer = None
+        self.journal_writer = JournalWriter(journal_writer)
         self.objstorage = ObjStorage(objstorage)
 
     def check_config(self, *, check_write):
@@ -64,12 +55,7 @@
         contents = [c for c in contents
                     if not self._cql_runner.content_get_from_pk(c.to_dict())]
 
-        if self.journal_writer:
-            for content in contents:
-                cont = content.to_dict()
-                if 'data' in cont:
-                    del cont['data']
-                self.journal_writer.write_addition('content', cont)
+        self.journal_writer.content_add(contents)
 
         if with_data:
             # First insert to the objstorage, if the endpoint is
@@ -249,12 +235,7 @@
             c for c in contents
             if not self._cql_runner.skipped_content_get_from_pk(c.to_dict())]
 
-        if self.journal_writer:
-            for content in contents:
-                cont = content.to_dict()
-                if 'data' in cont:
-                    del cont['data']
-                self.journal_writer.write_addition('content', cont)
+        self.journal_writer.skipped_content_add(contents)
 
         for content in contents:
             # Add to index tables
@@ -285,8 +266,7 @@
         missing = self.directory_missing([dir_.id for dir_ in directories])
         directories = [dir_ for dir_ in directories if dir_.id in missing]
 
-        if self.journal_writer:
-            self.journal_writer.write_additions('directory', directories)
+        self.journal_writer.directory_add(directories)
 
         for directory in directories:
             # Add directory entries to the 'directory_entry' table
@@ -390,8 +370,7 @@
         missing = self.revision_missing([rev.id for rev in revisions])
         revisions = [rev for rev in revisions if rev.id in missing]
 
-        if self.journal_writer:
-            self.journal_writer.write_additions('revision', revisions)
+        self.journal_writer.revision_add(revisions)
 
         for revision in revisions:
             revision = revision_to_db(revision)
@@ -483,8 +462,7 @@
         missing = self.release_missing([rel.id for rel in releases])
         releases = [rel for rel in releases if rel.id in missing]
 
-        if self.journal_writer:
-            self.journal_writer.write_additions('release', releases)
+        self.journal_writer.release_add(releases)
 
         for release in releases:
             if release:
@@ -516,8 +494,7 @@
         snapshots = [snp for snp in snapshots if snp.id in missing]
 
         for snapshot in snapshots:
-            if self.journal_writer:
-                self.journal_writer.write_addition('snapshot', snapshot)
+            self.journal_writer.snapshot_add(snapshot)
 
             # Add branches
             for (branch_name, branch) in snapshot.branches.items():
@@ -763,8 +740,7 @@
         if known_origin:
             origin_url = known_origin['url']
         else:
-            if self.journal_writer:
-                self.journal_writer.write_addition('origin', origin)
+            self.journal_writer.origin_add_one(origin)
 
             self._cql_runner.origin_add_one(origin)
             origin_url = origin.url
@@ -797,9 +773,7 @@
             })
         except (KeyError, TypeError, ValueError) as e:
             raise StorageArgumentException(*e.args)
-
-        if self.journal_writer:
-            self.journal_writer.write_addition('origin_visit', visit)
+        self.journal_writer.origin_visit_add(visit)
 
         self._cql_runner.origin_visit_add_one(visit)
@@ -835,8 +809,7 @@
         except (KeyError, TypeError, ValueError) as e:
             raise StorageArgumentException(*e.args)
 
-        if self.journal_writer:
-            self.journal_writer.write_update('origin_visit', visit)
+        self.journal_writer.origin_visit_update(visit)
 
         self._cql_runner.origin_visit_update(origin_url, visit_id, updates)
@@ -846,9 +819,7 @@
             if isinstance(visit['date'], str):
                 visit['date'] = dateutil.parser.parse(visit['date'])
 
-        if self.journal_writer:
-            for visit in visits:
-                self.journal_writer.write_addition('origin_visit', visit)
+        self.journal_writer.origin_visit_upsert(visits)
 
         for visit in visits:
             visit = visit.copy()
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -26,9 +26,10 @@
 from . import HashCollision
 from .exc import StorageArgumentException
-from .storage import get_journal_writer
+
 from .converters import origin_url_to_sha1
 from .utils import get_partition_bounds_bytes
+from .writer import JournalWriter
 
 # Max block size of contents to return
 BULK_BLOCK_CONTENT_LEN_MAX = 10000
@@ -46,11 +47,7 @@
         self._skipped_content_indexes = defaultdict(lambda: defaultdict(set))
         self.reset()
-
-        if journal_writer:
-            self.journal_writer = get_journal_writer(**journal_writer)
-        else:
-            self.journal_writer = None
+        self.journal_writer = JournalWriter(journal_writer)
 
     def reset(self):
         self._directories = {}
@@ -77,10 +74,7 @@
     def _content_add(
             self, contents: Iterable[Content], with_data: bool) -> Dict:
-        if self.journal_writer:
-            for content in contents:
-                content = attr.evolve(content, data=None)
-                self.journal_writer.write_addition('content', content)
+        self.journal_writer.content_add(contents)
 
         content_add = 0
         content_add_bytes = 0
@@ -125,9 +119,7 @@
         return self._content_add(content, with_data=True)
 
     def content_update(self, content, keys=[]):
-        if self.journal_writer:
-            raise NotImplementedError(
-                'content_update is not yet supported with a journal_writer.')
+        self.journal_writer.content_update(content)
 
         for cont_update in content:
             cont_update = cont_update.copy()
@@ -260,9 +252,7 @@
         return random.choice(list(self._content_indexes['sha1_git']))
 
     def _skipped_content_add(self, contents: Iterable[SkippedContent]) -> Dict:
-        if self.journal_writer:
-            for cont in contents:
-                self.journal_writer.write_addition('content', cont)
+        self.journal_writer.skipped_content_add(contents)
 
         summary = {
             'skipped_content:add': 0
@@ -301,20 +291,16 @@
         return self._skipped_content_add(content)
 
     def directory_add(self, directories: Iterable[Directory]) -> Dict:
-        directories = list(directories)
-        if self.journal_writer:
-            self.journal_writer.write_additions(
-                'directory',
-                (dir_ for dir_ in directories
-                 if dir_.id not in self._directories))
+        directories = [dir_ for dir_ in directories
+                       if dir_.id not in self._directories]
+        self.journal_writer.directory_add(directories)
 
         count = 0
         for directory in directories:
-            if directory.id not in self._directories:
-                count += 1
-                self._directories[directory.id] = directory
-                self._objects[directory.id].append(
-                    ('directory', directory.id))
+            count += 1
+            self._directories[directory.id] = directory
+            self._objects[directory.id].append(
+                ('directory', directory.id))
 
         return {'directory:add': count}
@@ -392,24 +378,20 @@
             first_item['target'], paths[1:], prefix + paths[0] + b'/')
 
     def revision_add(self, revisions: Iterable[Revision]) -> Dict:
-        revisions = list(revisions)
-        if self.journal_writer:
-            self.journal_writer.write_additions(
-                'revision',
-                (rev for rev in revisions
-                 if rev.id not in self._revisions))
+        revisions = [rev for rev in revisions
+                     if rev.id not in self._revisions]
+        self.journal_writer.revision_add(revisions)
 
         count = 0
         for revision in revisions:
-            if revision.id not in self._revisions:
-                revision = attr.evolve(
-                    revision,
-                    committer=self._person_add(revision.committer),
-                    author=self._person_add(revision.author))
-                self._revisions[revision.id] = revision
-                self._objects[revision.id].append(
-                    ('revision', revision.id))
-                count += 1
+            revision = attr.evolve(
+                revision,
+                committer=self._person_add(revision.committer),
+                author=self._person_add(revision.author))
+            self._revisions[revision.id] = revision
+            self._objects[revision.id].append(
+                ('revision', revision.id))
+            count += 1
 
         return {'revision:add': count}
@@ -448,22 +430,18 @@
         return random.choice(list(self._revisions))
 
     def release_add(self, releases: Iterable[Release]) -> Dict:
-        releases = list(releases)
-        if self.journal_writer:
-            self.journal_writer.write_additions(
-                'release',
-                (rel for rel in releases
-                 if rel.id not in self._releases))
+        releases = [rel for rel in releases
+                    if rel.id not in self._releases]
+        self.journal_writer.release_add(releases)
 
         count = 0
         for rel in releases:
-            if rel.id not in self._releases:
-                if rel.author:
-                    self._person_add(rel.author)
-                self._objects[rel.id].append(
-                    ('release', rel.id))
-                self._releases[rel.id] = rel
-                count += 1
+            if rel.author:
+                self._person_add(rel.author)
+            self._objects[rel.id].append(
+                ('release', rel.id))
+            self._releases[rel.id] = rel
+            count += 1
 
         return {'release:add': count}
@@ -485,9 +463,7 @@
         snapshots = (snap for snap in snapshots
                      if snap.id not in self._snapshots)
         for snapshot in snapshots:
-            if self.journal_writer:
-                self.journal_writer.write_addition('snapshot', snapshot)
-
+            self.journal_writer.snapshot_add(snapshot)
             sorted_branch_names = sorted(snapshot.branches)
             self._snapshots[snapshot.id] = (snapshot, sorted_branch_names)
             self._objects[snapshot.id].append(('snapshot', snapshot.id))
@@ -699,9 +675,7 @@
     def origin_add_one(self, origin: Origin) -> str:
         if origin.url not in self._origins:
-            if self.journal_writer:
-                self.journal_writer.write_addition('origin', origin)
-
+            self.journal_writer.origin_add_one(origin)
             # generate an origin_id because it is needed by origin_get_range.
             # TODO: remove this when we remove origin_get_range
             origin_id = len(self._origins) + 1
@@ -752,8 +726,7 @@
             self._objects[(origin_url, visit_id)].append(
                 ('origin_visit', None))
 
-            if self.journal_writer:
-                self.journal_writer.write_addition('origin_visit', visit)
+            self.journal_writer.origin_visit_add(visit)
 
         return visit_ret
@@ -783,8 +756,7 @@
         except (KeyError, TypeError, ValueError) as e:
             raise StorageArgumentException(*e.args)
 
-        if self.journal_writer:
-            self.journal_writer.write_update('origin_visit', visit)
+        self.journal_writer.origin_visit_update(visit)
 
         self._origin_visits[origin_url][visit_id-1] = visit
@@ -798,9 +770,7 @@
         except (KeyError, TypeError, ValueError) as e:
             raise StorageArgumentException(*e.args)
 
-        if self.journal_writer:
-            for visit in visits:
-                self.journal_writer.write_addition('origin_visit', visit)
+        self.journal_writer.origin_visit_upsert(visits)
 
         for visit in visits:
             visit_id = visit.visit
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -25,11 +25,6 @@
     Snapshot, Origin, SHA1_SIZE
 )
 from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex
-try:
-    from swh.journal.writer import get_journal_writer
-except ImportError:
-    get_journal_writer = None  # type: ignore
-    # mypy limitation, see https://github.com/python/mypy/issues/1153
 from swh.storage.objstorage import ObjStorage
 
 from . import converters, HashCollision
@@ -39,6 +34,7 @@
 from .algos import diff
 from .metrics import timed, send_metric, process_metrics
 from .utils import get_partition_bounds_bytes
+from .writer import JournalWriter
 
 
 # Max block size of contents to return
@@ -95,14 +91,7 @@
         except psycopg2.OperationalError as e:
             raise StorageDBError(e)
 
-        if journal_writer:
-            if get_journal_writer is None:
-                raise EnvironmentError(
-                    'You need the swh.journal package to use the '
-                    'journal_writer feature')
-            self.journal_writer = get_journal_writer(**journal_writer)
-        else:
-            self.journal_writer = None
+        self.journal_writer = JournalWriter(journal_writer)
         self.objstorage = ObjStorage(objstorage)
 
     def get_db(self):
@@ -186,17 +175,13 @@
     def content_add(
             self, content: Iterable[Content], db=None, cur=None) -> Dict:
         now = datetime.datetime.now(tz=datetime.timezone.utc)
-        content = [attr.evolve(c, ctime=now) for c in content]
+        contents = [attr.evolve(c, ctime=now) for c in content]
 
         missing = list(self.content_missing(
-            map(Content.to_dict, content), key_hash='sha1_git'))
-        content = [c for c in content if c.sha1_git in missing]
+            map(Content.to_dict, contents), key_hash='sha1_git'))
+        contents = [c for c in contents if c.sha1_git in missing]
 
-        if self.journal_writer:
-            for item in content:
-                if item.data:
-                    item = attr.evolve(item, data=None)
-                self.journal_writer.write_addition('content', item)
+        self.journal_writer.content_add(contents)
 
         def add_to_objstorage():
             """Add to objstorage the new missing_content
@@ -206,20 +191,20 @@
             objstorage. Content present twice is only sent once.
 
             """
-            summary = self.objstorage.content_add(content)
+            summary = self.objstorage.content_add(contents)
             return summary['content:add:bytes']
 
         with ThreadPoolExecutor(max_workers=1) as executor:
             added_to_objstorage = executor.submit(add_to_objstorage)
 
-            self._content_add_metadata(db, cur, content)
+            self._content_add_metadata(db, cur, contents)
 
             # Wait for objstorage addition before returning from the
            # transaction, bubbling up any exception
             content_bytes_added = added_to_objstorage.result()
 
         return {
-            'content:add': len(content),
+            'content:add': len(contents),
             'content:add:bytes': content_bytes_added,
         }
@@ -228,10 +213,7 @@
     def content_update(self, content, keys=[], db=None, cur=None):
         # TODO: Add a check on input keys. How to properly implement
         # this? We don't know yet the new columns.
-
-        if self.journal_writer:
-            raise NotImplementedError(
-                'content_update is not yet supported with a journal_writer.')
+        self.journal_writer.content_update(content)
 
         db.mktemp('content', cur)
         select_keys = list(set(db.content_get_metadata_keys).union(set(keys)))
@@ -245,20 +227,16 @@
     @db_transaction()
     def content_add_metadata(self, content: Iterable[Content],
                              db=None, cur=None) -> Dict:
-        content = list(content)
+        contents = list(content)
         missing = self.content_missing(
-            (c.to_dict() for c in content), key_hash='sha1_git')
-        content = [c for c in content if c.sha1_git in missing]
+            (c.to_dict() for c in contents), key_hash='sha1_git')
+        contents = [c for c in contents if c.sha1_git in missing]
 
-        if self.journal_writer:
-            for item in itertools.chain(content):
-                assert item.data is None
-                self.journal_writer.write_addition('content', item)
-
-        self._content_add_metadata(db, cur, content)
+        self.journal_writer.content_add_metadata(contents)
+        self._content_add_metadata(db, cur, contents)
 
         return {
-            'content:add': len(content),
+            'content:add': len(contents),
         }
 
     @timed
@@ -429,10 +407,7 @@
                 for algo in DEFAULT_ALGORITHMS)
             for missing_content in missing_contents)]
 
-        if self.journal_writer:
-            for item in content:
-                self.journal_writer.write_addition('content', item)
-
+        self.journal_writer.skipped_content_add(content)
         self._skipped_content_add_metadata(db, cur, content)
 
         return {
@@ -473,11 +448,10 @@
         if not dirs_missing:
             return summary
 
-        if self.journal_writer:
-            self.journal_writer.write_additions(
-                'directory',
-                (dir_ for dir_ in directories
-                 if dir_.id in dirs_missing))
+        self.journal_writer.directory_add(
+            dir_ for dir_ in directories
+            if dir_.id in dirs_missing
+        )
 
         # Copy directory ids
         dirs_missing_dict = ({'id': dir} for dir in dirs_missing)
@@ -557,8 +531,7 @@
             revision for revision in revisions
             if revision.id in revisions_missing]
 
-        if self.journal_writer:
-            self.journal_writer.write_additions('revision', revisions_filtered)
+        self.journal_writer.revision_add(revisions_filtered)
 
         revisions_filtered = \
             list(map(converters.revision_to_db, revisions_filtered))
@@ -644,8 +617,7 @@
             if release.id in releases_missing
         ]
 
-        if self.journal_writer:
-            self.journal_writer.write_additions('release', releases_filtered)
+        self.journal_writer.release_add(releases_filtered)
 
         releases_filtered = \
             list(map(converters.release_to_db, releases_filtered))
@@ -713,8 +685,7 @@
             except VALIDATION_EXCEPTIONS + (KeyError,) as e:
                 raise StorageArgumentException(*e.args)
 
-            if self.journal_writer:
-                self.journal_writer.write_addition('snapshot', snapshot)
+            self.journal_writer.snapshot_add(snapshot)
 
             db.snapshot_add(snapshot.id, cur)
             count += 1
@@ -830,13 +801,18 @@
         with convert_validation_exceptions():
             visit_id = db.origin_visit_add(origin_url, date, type, cur)
 
-        if self.journal_writer:
-            # We can write to the journal only after inserting to the
-            # DB, because we want the id of the visit
-            self.journal_writer.write_addition('origin_visit', {
-                'origin': origin_url, 'date': date, 'type': type,
-                'visit': visit_id,
-                'status': 'ongoing', 'metadata': None, 'snapshot': None})
+        # We can write to the journal only after inserting to the
+        # DB, because we want the id of the visit
+        visit = {
+            'origin': origin_url,
+            'date': date,
+            'type': type,
+            'visit': visit_id,
+            'status': 'ongoing',
+            'metadata': None,
+            'snapshot': None
+        }
+        self.journal_writer.origin_visit_add(visit)
 
         send_metric('origin_visit:add', count=1, method_name='origin_visit')
         return {
@@ -871,9 +847,8 @@
             updates['snapshot'] = snapshot
 
         if updates:
-            if self.journal_writer:
-                self.journal_writer.write_update('origin_visit', {
-                    **visit, **updates})
+            updated_visit = {**visit, **updates}
+            self.journal_writer.origin_visit_update(updated_visit)
 
             with convert_validation_exceptions():
                 db.origin_visit_update(origin_url, visit_id, updates, cur)
@@ -890,9 +865,7 @@
                     "visit['origin'] must be a string, not %r"
                     % (visit['origin'],))
 
-        if self.journal_writer:
-            for visit in visits:
-                self.journal_writer.write_addition('origin_visit', visit)
+        self.journal_writer.origin_visit_upsert(visits)
 
         for visit in visits:
             # TODO: upsert them all in a single query
@@ -1055,8 +1028,7 @@
         if origin_url:
             return origin_url
 
-        if self.journal_writer:
-            self.journal_writer.write_addition('origin', origin)
+        self.journal_writer.origin_add_one(origin)
 
         origins = db.origin_add(origin.url, cur)
         send_metric('origin:add', count=len(origins), method_name='origin_add')
diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py
--- a/swh/storage/tests/test_api_client.py
+++ b/swh/storage/tests/test_api_client.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2015-2018 The Software Heritage developers
+# Copyright (C) 2015-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -63,6 +63,7 @@
 
 class TestStorage(_TestStorage):
     def test_content_update(self, swh_storage, app_server):
-        swh_storage.journal_writer = None  # TODO, journal_writer not supported
-        with patch.object(server.storage.storage, 'journal_writer', None):
+        # TODO, journal_writer not supported
+        swh_storage.journal_writer.journal = None
+        with patch.object(server.storage.journal_writer, 'journal', None):
             super().test_content_update(swh_storage)
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -151,7 +151,7 @@
         expected_cont = data.cont
         del expected_cont['data']
 
-        journal_objects = list(swh_storage.journal_writer.objects)
+        journal_objects = list(swh_storage.journal_writer.journal.objects)
         for (obj_type, obj) in journal_objects:
             assert insertion_start_time <= obj['ctime']
             assert obj['ctime'] <= insertion_end_time
@@ -241,14 +241,14 @@
             'content:add': 1,
             'content:add:bytes': data.cont['length'],
         }
-        assert len(swh_storage.journal_writer.objects) == 1
+        assert len(swh_storage.journal_writer.journal.objects) == 1
 
         actual_result = swh_storage.content_add([data.cont, data.cont2])
         assert actual_result == {
             'content:add': 1,
             'content:add:bytes': data.cont2['length'],
         }
-        assert 2 <= len(swh_storage.journal_writer.objects) <= 3
+        assert 2 <= len(swh_storage.journal_writer.journal.objects) <= 3
 
         assert len(swh_storage.content_find(data.cont)) == 1
         assert len(swh_storage.content_find(data.cont2)) == 1
@@ -269,7 +269,7 @@
     def test_content_update(self, swh_storage):
         if hasattr(swh_storage, 'storage'):
-            swh_storage.storage.journal_writer = None  # TODO, not supported
+            swh_storage.journal_writer.journal = None  # TODO, not supported
 
         cont = copy.deepcopy(data.cont)
@@ -300,7 +300,8 @@
             cont['sha1']: [expected_cont]
         }
 
-        assert list(swh_storage.journal_writer.objects) == [('content', cont)]
+        assert list(swh_storage.journal_writer.journal.objects) == [
+            ('content', cont)]
 
     def test_content_add_metadata_different_input(self, swh_storage):
         cont = data.cont
@@ -551,7 +552,7 @@
         actual_result = swh_storage.directory_add([data.dir])
         assert actual_result == {'directory:add': 1}
 
-        assert list(swh_storage.journal_writer.objects) == \
+        assert list(swh_storage.journal_writer.journal.objects) == \
             [('directory', data.dir)]
 
         actual_data = list(swh_storage.directory_ls(data.dir['id']))
@@ -573,7 +574,7 @@
         actual_result = swh_storage.directory_add(directories=_dir_gen())
         assert actual_result == {'directory:add': 1}
 
-        assert list(swh_storage.journal_writer.objects) == \
+        assert list(swh_storage.journal_writer.journal.objects) == \
             [('directory', data.dir)]
 
         swh_storage.refresh_stat_counters()
@@ -599,13 +600,13 @@
         actual_result = swh_storage.directory_add([data.dir])
         assert actual_result == {'directory:add': 1}
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.journal_writer.journal.objects) \
             == [('directory', data.dir)]
 
         actual_result = swh_storage.directory_add([data.dir])
         assert actual_result == {'directory:add': 0}
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.journal_writer.journal.objects) \
             == [('directory', data.dir)]
@@ -616,7 +617,7 @@
             [data.dir, data.dir2, data.dir3])
         assert actual_result == {'directory:add': 3}
 
-        assert list(swh_storage.journal_writer.objects) == [
+        assert list(swh_storage.journal_writer.journal.objects) == [
             ('directory', data.dir),
             ('directory', data.dir2),
             ('directory', data.dir3)]
@@ -653,7 +654,7 @@
             [data.dir, data.dir2, data.dir3])
         assert actual_result == {'directory:add': 3}
 
-        assert list(swh_storage.journal_writer.objects) == [
+        assert list(swh_storage.journal_writer.journal.objects) == [
             ('directory', data.dir),
             ('directory', data.dir2),
             ('directory', data.dir3)]
@@ -767,7 +768,7 @@
 
         normalized_revision = Revision.from_dict(data.revision).to_dict()
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.journal_writer.journal.objects) \
             == [('revision', normalized_revision)]
 
         # already there so nothing added
@@ -825,14 +826,14 @@
         normalized_revision = Revision.from_dict(data.revision).to_dict()
         normalized_revision2 = Revision.from_dict(data.revision2).to_dict()
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.journal_writer.journal.objects) \
             == [('revision', normalized_revision)]
 
         actual_result = swh_storage.revision_add(
            [data.revision, data.revision2])
         assert actual_result == {'revision:add': 1}
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.journal_writer.journal.objects) \
             == [('revision', normalized_revision),
                 ('revision', normalized_revision2)]
@@ -877,7 +878,7 @@
         normalized_revision3 = Revision.from_dict(data.revision3).to_dict()
         normalized_revision4 = Revision.from_dict(data.revision4).to_dict()
 
-        assert list(swh_storage.journal_writer.objects) == [
+        assert list(swh_storage.journal_writer.journal.objects) == [
             ('revision', normalized_revision3),
             ('revision', normalized_revision4)]
@@ -974,7 +975,7 @@
                                                    data.release2['id']])
         assert list(end_missing) == []
 
-        assert list(swh_storage.journal_writer.objects) == [
+        assert list(swh_storage.journal_writer.journal.objects) == [
             ('release', normalized_release),
             ('release', normalized_release2)]
@@ -996,7 +997,7 @@
         actual_result = swh_storage.release_add(_rel_gen())
         assert actual_result == {'release:add': 2}
 
-        assert list(swh_storage.journal_writer.objects) == [
+        assert list(swh_storage.journal_writer.journal.objects) == [
             ('release', normalized_release),
             ('release', normalized_release2)]
@@ -1015,7 +1016,7 @@
         end_missing = swh_storage.release_missing([data.release['id']])
         assert list(end_missing) == []
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.journal_writer.journal.objects) \
             == [('release', release)]
 
     def test_release_add_validation(self, swh_storage):
@@ -1045,13 +1046,13 @@
         normalized_release = Release.from_dict(data.release).to_dict()
         normalized_release2 = Release.from_dict(data.release2).to_dict()
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.journal_writer.journal.objects) \
             == [('release', normalized_release)]
 
         actual_result = swh_storage.release_add([data.release, data.release2])
         assert actual_result == {'release:add': 1}
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.journal_writer.journal.objects) \
             == [('release', normalized_release),
                 ('release', normalized_release2)]
@@ -1133,7 +1134,7 @@
         del actual_origin['id']
         del actual_origin2['id']
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.journal_writer.journal.objects) \
             == [('origin', actual_origin),
                 ('origin', actual_origin2)]
@@ -1161,7 +1162,7 @@
         del actual_origin['id']
         del actual_origin2['id']
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.journal_writer.journal.objects) \
             == [('origin', actual_origin),
                 ('origin', actual_origin2)]
@@ -1170,12 +1171,12 @@
     def test_origin_add_twice(self, swh_storage):
         add1 = swh_storage.origin_add([data.origin, data.origin2])
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.journal_writer.journal.objects) \
             == [('origin', data.origin),
                 ('origin', data.origin2)]
 
         add2 = swh_storage.origin_add([data.origin, data.origin2])
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.journal_writer.journal.objects) \
             == [('origin', data.origin),
                 ('origin', data.origin2)]
@@ -1431,7 +1432,7 @@
             'metadata': None,
             'snapshot': None,
         }
-        objects = list(swh_storage.journal_writer.objects)
+        objects = list(swh_storage.journal_writer.journal.objects)
         assert ('origin', data.origin2) in objects
         assert ('origin_visit', origin_visit) in objects
@@ -1494,7 +1495,7 @@
         for visit in expected_visits:
             assert visit in actual_origin_visits
 
-        objects = list(swh_storage.journal_writer.objects)
+        objects = list(swh_storage.journal_writer.journal.objects)
         assert ('origin', data.origin2) in objects
 
         for visit in expected_visits:
@@ -1667,7 +1668,7 @@
             'metadata': None,
             'snapshot': None,
         }
-        objects = list(swh_storage.journal_writer.objects)
+        objects = list(swh_storage.journal_writer.journal.objects)
         assert ('origin', data.origin) in objects
         assert ('origin', data.origin2) in objects
         assert ('origin_visit', data1) in objects
@@ -1888,7 +1889,7 @@
             'metadata': None,
             'snapshot': None,
         }
-        assert list(swh_storage.journal_writer.objects) == [
+        assert list(swh_storage.journal_writer.journal.objects) == [
             ('origin', data.origin2),
             ('origin_visit', data1),
             ('origin_visit', data2)]
@@ -1949,7 +1950,7 @@
             'metadata': None,
             'snapshot': None,
         }
-        assert list(swh_storage.journal_writer.objects) == [
+        assert list(swh_storage.journal_writer.journal.objects) == [
             ('origin', data.origin2),
             ('origin_visit', data1),
             ('origin_visit', data2)]
@@ -2126,7 +2127,7 @@
             'metadata': None,
             'snapshot': data.empty_snapshot['id'],
         }
-        assert list(swh_storage.journal_writer.objects) == \
+        assert list(swh_storage.journal_writer.journal.objects) == \
             [('origin', data.origin),
              ('origin_visit', data1),
              ('snapshot', data.empty_snapshot),
@@ -2196,13 +2197,13 @@
         actual_result = swh_storage.snapshot_add([data.empty_snapshot])
         assert actual_result == {'snapshot:add': 1}
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.journal_writer.journal.objects) \
             == [('snapshot', data.empty_snapshot)]
 
         actual_result = swh_storage.snapshot_add([data.snapshot])
         assert actual_result == {'snapshot:add': 1}
 
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.journal_writer.journal.objects) \
             == [('snapshot', data.empty_snapshot),
                 ('snapshot', data.snapshot)]
@@ -2438,7 +2439,7 @@
         swh_storage.origin_add_one(data.origin)
         visit_id = 54164461156
 
-        swh_storage.journal_writer.objects[:] = []
+        swh_storage.journal_writer.journal.objects[:] = []
 
         swh_storage.snapshot_add([data.snapshot])
@@ -2446,7 +2447,7 @@
             swh_storage.origin_visit_update(
                 origin_url, visit_id, snapshot=data.snapshot['id'])
 
-        assert list(swh_storage.journal_writer.objects) == [
+        assert list(swh_storage.journal_writer.journal.objects) == [
             ('snapshot', data.snapshot)]
 
     def test_snapshot_add_twice__by_origin_visit(self, swh_storage):
@@ -2517,7 +2518,7 @@
             'metadata': None,
             'snapshot': data.snapshot['id'],
         }
-        assert list(swh_storage.journal_writer.objects) \
+        assert list(swh_storage.journal_writer.journal.objects) \
             == [('origin', data.origin),
                 ('origin_visit', data1),
                 ('snapshot', data.snapshot),
@@ -3704,7 +3705,7 @@
     """
 
     def test_content_update_with_new_cols(self, swh_storage):
-        swh_storage.storage.journal_writer = None  # TODO, not supported
+        swh_storage.journal_writer.journal = None  # TODO, not supported
 
         with db_transaction(swh_storage) as (_, cur):
             cur.execute("""alter table content
@@ -3758,7 +3759,7 @@
         expected_cont = cont.copy()
         del expected_cont['data']
 
-        journal_objects = list(swh_storage.journal_writer.objects)
+        journal_objects = list(swh_storage.journal_writer.journal.objects)
         for (obj_type, obj) in journal_objects:
             del obj['ctime']
         assert journal_objects == [('content', expected_cont)]
@@ -3784,7 +3785,8 @@
         assert datum == (cont['sha1'], cont['sha1_git'], cont['sha256'],
                          cont['length'], 'visible')
 
-        assert list(swh_storage.journal_writer.objects) == [('content', cont)]
+        assert list(swh_storage.journal_writer.journal.objects) == [
+            ('content', cont)]
 
     def test_skipped_content_add_db(self, swh_storage):
         cont = data.skipped_cont
diff --git a/swh/storage/writer.py b/swh/storage/writer.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/writer.py
@@ -0,0 +1,99 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from typing import Iterable, Union
+from swh.model.model import (
+    Origin, OriginVisit, Snapshot, Directory, Revision, Release, Content
+)
+
+try:
+    from swh.journal.writer import get_journal_writer
+except ImportError:
+    get_journal_writer = None  # type: ignore
+    # mypy limitation, see https://github.com/python/mypy/issues/1153
+
+
+class JournalWriter:
+    """Journal writer storage collaborator. It's in charge of adding objects to
+    the journal.
+
+    """
+    def __init__(self, journal_writer):
+        if journal_writer:
+            if get_journal_writer is None:
+                raise EnvironmentError(
+                    'You need the swh.journal package to use the '
+                    'journal_writer feature')
+            self.journal = get_journal_writer(**journal_writer)
+        else:
+            self.journal = None
+
+    def content_add(self, contents: Iterable[Content]) -> None:
+        """Add contents to the journal. Drop the data field if provided.
+
+        """
+        if not self.journal:
+            return
+        for item in contents:
+            content = item.to_dict()
+            if 'data' in content:
+                del content['data']
+            self.journal.write_addition('content', content)
+
+    def content_update(self, contents: Iterable[Content]) -> None:
+        if not self.journal:
+            return
+        raise NotImplementedError(
+            'content_update is not yet supported with a journal writer.')
+
+    def content_add_metadata(
+            self, contents: Iterable[Content]) -> None:
+        return self.content_add(contents)
+
+    def skipped_content_add(
+            self, contents: Iterable[Content]) -> None:
+        return self.content_add(contents)
+
+    def directory_add(self, directories: Iterable[Directory]) -> None:
+        if not self.journal:
+            return
+        self.journal.write_additions('directory', directories)
+
+    def revision_add(self, revisions: Iterable[Revision]) -> None:
+        if not self.journal:
+            return
+        self.journal.write_additions('revision', revisions)
+
+    def release_add(self, releases: Iterable[Release]) -> None:
+        if not self.journal:
+            return
+        self.journal.write_additions('release', releases)
+
+    def snapshot_add(
+            self, snapshots: Union[Iterable[Snapshot], Snapshot]) -> None:
+        if not self.journal:
+            return
+        snaps = snapshots if isinstance(snapshots, list) else [snapshots]
+        self.journal.write_additions('snapshot', snaps)
+
+    def origin_visit_add(self, visit: OriginVisit):
+        if not self.journal:
+            return
+        self.journal.write_addition('origin_visit', visit)
+
+    def origin_visit_update(self, visit: OriginVisit):
+        if not self.journal:
+            return
+        self.journal.write_update('origin_visit', visit)
+
+    def origin_visit_upsert(self, visits: Iterable[OriginVisit]):
+        if not self.journal:
+            return
+        self.journal.write_additions('origin_visit', visits)
+
+    def origin_add_one(self, origin: Origin):
+        if not self.journal:
+            return
+        self.journal.write_addition('origin', origin)
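
Usage sketch (illustrative only, not part of the patch): the `journal_writer` argument is the same configuration dict the storage backends previously forwarded to `swh.journal.writer.get_journal_writer`, and a falsy value turns every method into a no-op, which is what lets the `if self.journal_writer:` guards above disappear. The `{'cls': 'memory'}` configuration and the `.journal.objects` attribute are the ones the updated tests rely on; the origin URL below is purely hypothetical.

    from swh.model.model import Origin
    from swh.storage.writer import JournalWriter

    # Without configuration the writer keeps journal=None and every method
    # returns early, so callers no longer need their own guard.
    writer = JournalWriter(None)
    writer.origin_add_one(Origin(url='https://example.org/project.git'))  # no-op

    # A configuration dict is forwarded to get_journal_writer(**journal_writer);
    # 'memory' is the backend the test suite configures, and it records
    # additions in the `objects` list used by the assertions above.
    writer = JournalWriter({'cls': 'memory'})
    writer.origin_add_one(Origin(url='https://example.org/project.git'))
    assert len(writer.journal.objects) == 1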