D2634.diff

diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -17,13 +17,8 @@
Revision, Release, Directory, DirectoryEntry, Content, SkippedContent,
OriginVisit, Snapshot, Origin
)
-try:
- from swh.journal.writer import get_journal_writer
-except ImportError:
- get_journal_writer = None # type: ignore
- # mypy limitation, see https://github.com/python/mypy/issues/1153
-
from swh.storage.objstorage import ObjStorage
+from swh.storage.writer import JournalWriter
from .. import HashCollision
from ..exc import StorageArgumentException
@@ -47,11 +42,7 @@
def __init__(self, hosts, keyspace, objstorage,
port=9042, journal_writer=None):
self._cql_runner = CqlRunner(hosts, keyspace, port)
-
- if journal_writer:
- self.journal_writer = get_journal_writer(**journal_writer)
- else:
- self.journal_writer = None
+ self.journal_writer = JournalWriter(journal_writer)
self.objstorage = ObjStorage(objstorage)
def check_config(self, *, check_write):
@@ -64,12 +55,7 @@
contents = [c for c in contents
if not self._cql_runner.content_get_from_pk(c.to_dict())]
- if self.journal_writer:
- for content in contents:
- cont = content.to_dict()
- if 'data' in cont:
- del cont['data']
- self.journal_writer.write_addition('content', cont)
+ self.journal_writer.content_add(contents)
if with_data:
# First insert to the objstorage, if the endpoint is
@@ -249,12 +235,7 @@
c for c in contents
if not self._cql_runner.skipped_content_get_from_pk(c.to_dict())]
- if self.journal_writer:
- for content in contents:
- cont = content.to_dict()
- if 'data' in cont:
- del cont['data']
- self.journal_writer.write_addition('content', cont)
+ self.journal_writer.skipped_content_add(contents)
for content in contents:
# Add to index tables
@@ -285,8 +266,7 @@
missing = self.directory_missing([dir_.id for dir_ in directories])
directories = [dir_ for dir_ in directories if dir_.id in missing]
- if self.journal_writer:
- self.journal_writer.write_additions('directory', directories)
+ self.journal_writer.directory_add(directories)
for directory in directories:
# Add directory entries to the 'directory_entry' table
@@ -390,8 +370,7 @@
missing = self.revision_missing([rev.id for rev in revisions])
revisions = [rev for rev in revisions if rev.id in missing]
- if self.journal_writer:
- self.journal_writer.write_additions('revision', revisions)
+ self.journal_writer.revision_add(revisions)
for revision in revisions:
revision = revision_to_db(revision)
@@ -483,8 +462,7 @@
missing = self.release_missing([rel.id for rel in releases])
releases = [rel for rel in releases if rel.id in missing]
- if self.journal_writer:
- self.journal_writer.write_additions('release', releases)
+ self.journal_writer.release_add(releases)
for release in releases:
if release:
@@ -516,8 +494,7 @@
snapshots = [snp for snp in snapshots if snp.id in missing]
for snapshot in snapshots:
- if self.journal_writer:
- self.journal_writer.write_addition('snapshot', snapshot)
+ self.journal_writer.snapshot_add(snapshot)
# Add branches
for (branch_name, branch) in snapshot.branches.items():
@@ -763,8 +740,7 @@
if known_origin:
origin_url = known_origin['url']
else:
- if self.journal_writer:
- self.journal_writer.write_addition('origin', origin)
+ self.journal_writer.origin_add_one(origin)
self._cql_runner.origin_add_one(origin)
origin_url = origin.url
@@ -797,9 +773,7 @@
})
except (KeyError, TypeError, ValueError) as e:
raise StorageArgumentException(*e.args)
-
- if self.journal_writer:
- self.journal_writer.write_addition('origin_visit', visit)
+ self.journal_writer.origin_visit_add(visit)
self._cql_runner.origin_visit_add_one(visit)
@@ -835,8 +809,7 @@
except (KeyError, TypeError, ValueError) as e:
raise StorageArgumentException(*e.args)
- if self.journal_writer:
- self.journal_writer.write_update('origin_visit', visit)
+ self.journal_writer.origin_visit_update(visit)
self._cql_runner.origin_visit_update(origin_url, visit_id, updates)
@@ -846,9 +819,7 @@
if isinstance(visit['date'], str):
visit['date'] = dateutil.parser.parse(visit['date'])
- if self.journal_writer:
- for visit in visits:
- self.journal_writer.write_addition('origin_visit', visit)
+ self.journal_writer.origin_visit_upsert(visits)
for visit in visits:
visit = visit.copy()
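The cassandra backend above shows the pattern that the rest of this diff repeats for every object type and every backend: the inline `if self.journal_writer:` guard plus the low-level `write_addition`/`write_additions` call is replaced by a single call on a JournalWriter collaborator (added in swh/storage/writer.py further down), which performs the same check internally. A condensed sketch of the change, using the revision hunk above:

# Before: guard and journal call inlined in each storage method.
if self.journal_writer:
    self.journal_writer.write_additions('revision', revisions)

# After: self.journal_writer is a JournalWriter; the guard lives inside it.
self.journal_writer.revision_add(revisions)
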
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -26,9 +26,10 @@
from . import HashCollision
from .exc import StorageArgumentException
-from .storage import get_journal_writer
+
from .converters import origin_url_to_sha1
from .utils import get_partition_bounds_bytes
+from .writer import JournalWriter
# Max block size of contents to return
BULK_BLOCK_CONTENT_LEN_MAX = 10000
@@ -46,11 +47,7 @@
self._skipped_content_indexes = defaultdict(lambda: defaultdict(set))
self.reset()
-
- if journal_writer:
- self.journal_writer = get_journal_writer(**journal_writer)
- else:
- self.journal_writer = None
+ self.journal_writer = JournalWriter(journal_writer)
def reset(self):
self._directories = {}
@@ -77,10 +74,7 @@
def _content_add(
self, contents: Iterable[Content], with_data: bool) -> Dict:
- if self.journal_writer:
- for content in contents:
- content = attr.evolve(content, data=None)
- self.journal_writer.write_addition('content', content)
+ self.journal_writer.content_add(contents)
content_add = 0
content_add_bytes = 0
@@ -125,9 +119,7 @@
return self._content_add(content, with_data=True)
def content_update(self, content, keys=[]):
- if self.journal_writer:
- raise NotImplementedError(
- 'content_update is not yet supported with a journal_writer.')
+ self.journal_writer.content_update(content)
for cont_update in content:
cont_update = cont_update.copy()
@@ -260,9 +252,7 @@
return random.choice(list(self._content_indexes['sha1_git']))
def _skipped_content_add(self, contents: Iterable[SkippedContent]) -> Dict:
- if self.journal_writer:
- for cont in contents:
- self.journal_writer.write_addition('content', cont)
+ self.journal_writer.skipped_content_add(contents)
summary = {
'skipped_content:add': 0
@@ -301,20 +291,16 @@
return self._skipped_content_add(content)
def directory_add(self, directories: Iterable[Directory]) -> Dict:
- directories = list(directories)
- if self.journal_writer:
- self.journal_writer.write_additions(
- 'directory',
- (dir_ for dir_ in directories
- if dir_.id not in self._directories))
+ directories = [dir_ for dir_ in directories
+ if dir_.id not in self._directories]
+ self.journal_writer.directory_add(directories)
count = 0
for directory in directories:
- if directory.id not in self._directories:
- count += 1
- self._directories[directory.id] = directory
- self._objects[directory.id].append(
- ('directory', directory.id))
+ count += 1
+ self._directories[directory.id] = directory
+ self._objects[directory.id].append(
+ ('directory', directory.id))
return {'directory:add': count}
@@ -392,24 +378,20 @@
first_item['target'], paths[1:], prefix + paths[0] + b'/')
def revision_add(self, revisions: Iterable[Revision]) -> Dict:
- revisions = list(revisions)
- if self.journal_writer:
- self.journal_writer.write_additions(
- 'revision',
- (rev for rev in revisions
- if rev.id not in self._revisions))
+ revisions = [rev for rev in revisions
+ if rev.id not in self._revisions]
+ self.journal_writer.revision_add(revisions)
count = 0
for revision in revisions:
- if revision.id not in self._revisions:
- revision = attr.evolve(
- revision,
- committer=self._person_add(revision.committer),
- author=self._person_add(revision.author))
- self._revisions[revision.id] = revision
- self._objects[revision.id].append(
- ('revision', revision.id))
- count += 1
+ revision = attr.evolve(
+ revision,
+ committer=self._person_add(revision.committer),
+ author=self._person_add(revision.author))
+ self._revisions[revision.id] = revision
+ self._objects[revision.id].append(
+ ('revision', revision.id))
+ count += 1
return {'revision:add': count}
@@ -448,22 +430,18 @@
return random.choice(list(self._revisions))
def release_add(self, releases: Iterable[Release]) -> Dict:
- releases = list(releases)
- if self.journal_writer:
- self.journal_writer.write_additions(
- 'release',
- (rel for rel in releases
- if rel.id not in self._releases))
+ releases = [rel for rel in releases
+ if rel.id not in self._releases]
+ self.journal_writer.release_add(releases)
count = 0
for rel in releases:
- if rel.id not in self._releases:
- if rel.author:
- self._person_add(rel.author)
- self._objects[rel.id].append(
- ('release', rel.id))
- self._releases[rel.id] = rel
- count += 1
+ if rel.author:
+ self._person_add(rel.author)
+ self._objects[rel.id].append(
+ ('release', rel.id))
+ self._releases[rel.id] = rel
+ count += 1
return {'release:add': count}
@@ -485,9 +463,7 @@
snapshots = (snap for snap in snapshots
if snap.id not in self._snapshots)
for snapshot in snapshots:
- if self.journal_writer:
- self.journal_writer.write_addition('snapshot', snapshot)
-
+ self.journal_writer.snapshot_add(snapshot)
sorted_branch_names = sorted(snapshot.branches)
self._snapshots[snapshot.id] = (snapshot, sorted_branch_names)
self._objects[snapshot.id].append(('snapshot', snapshot.id))
@@ -699,9 +675,7 @@
def origin_add_one(self, origin: Origin) -> str:
if origin.url not in self._origins:
- if self.journal_writer:
- self.journal_writer.write_addition('origin', origin)
-
+ self.journal_writer.origin_add_one(origin)
# generate an origin_id because it is needed by origin_get_range.
# TODO: remove this when we remove origin_get_range
origin_id = len(self._origins) + 1
@@ -752,8 +726,7 @@
self._objects[(origin_url, visit_id)].append(
('origin_visit', None))
- if self.journal_writer:
- self.journal_writer.write_addition('origin_visit', visit)
+ self.journal_writer.origin_visit_add(visit)
return visit_ret
@@ -783,8 +756,7 @@
except (KeyError, TypeError, ValueError) as e:
raise StorageArgumentException(*e.args)
- if self.journal_writer:
- self.journal_writer.write_update('origin_visit', visit)
+ self.journal_writer.origin_visit_update(visit)
self._origin_visits[origin_url][visit_id-1] = visit
@@ -798,9 +770,7 @@
except (KeyError, TypeError, ValueError) as e:
raise StorageArgumentException(*e.args)
- if self.journal_writer:
- for visit in visits:
- self.journal_writer.write_addition('origin_visit', visit)
+ self.journal_writer.origin_visit_upsert(visits)
for visit in visits:
visit_id = visit.visit
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -25,11 +25,6 @@
Snapshot, Origin, SHA1_SIZE
)
from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex
-try:
- from swh.journal.writer import get_journal_writer
-except ImportError:
- get_journal_writer = None # type: ignore
- # mypy limitation, see https://github.com/python/mypy/issues/1153
from swh.storage.objstorage import ObjStorage
from . import converters, HashCollision
@@ -39,6 +34,7 @@
from .algos import diff
from .metrics import timed, send_metric, process_metrics
from .utils import get_partition_bounds_bytes
+from .writer import JournalWriter
# Max block size of contents to return
@@ -95,14 +91,7 @@
except psycopg2.OperationalError as e:
raise StorageDBError(e)
- if journal_writer:
- if get_journal_writer is None:
- raise EnvironmentError(
- 'You need the swh.journal package to use the '
- 'journal_writer feature')
- self.journal_writer = get_journal_writer(**journal_writer)
- else:
- self.journal_writer = None
+ self.journal_writer = JournalWriter(journal_writer)
self.objstorage = ObjStorage(objstorage)
def get_db(self):
@@ -186,17 +175,13 @@
def content_add(
self, content: Iterable[Content], db=None, cur=None) -> Dict:
now = datetime.datetime.now(tz=datetime.timezone.utc)
- content = [attr.evolve(c, ctime=now) for c in content]
+ contents = [attr.evolve(c, ctime=now) for c in content]
missing = list(self.content_missing(
- map(Content.to_dict, content), key_hash='sha1_git'))
- content = [c for c in content if c.sha1_git in missing]
+ map(Content.to_dict, contents), key_hash='sha1_git'))
+ contents = [c for c in contents if c.sha1_git in missing]
- if self.journal_writer:
- for item in content:
- if item.data:
- item = attr.evolve(item, data=None)
- self.journal_writer.write_addition('content', item)
+ self.journal_writer.content_add(contents)
def add_to_objstorage():
"""Add to objstorage the new missing_content
@@ -206,20 +191,20 @@
objstorage. Content present twice is only sent once.
"""
- summary = self.objstorage.content_add(content)
+ summary = self.objstorage.content_add(contents)
return summary['content:add:bytes']
with ThreadPoolExecutor(max_workers=1) as executor:
added_to_objstorage = executor.submit(add_to_objstorage)
- self._content_add_metadata(db, cur, content)
+ self._content_add_metadata(db, cur, contents)
# Wait for objstorage addition before returning from the
# transaction, bubbling up any exception
content_bytes_added = added_to_objstorage.result()
return {
- 'content:add': len(content),
+ 'content:add': len(contents),
'content:add:bytes': content_bytes_added,
}
@@ -228,10 +213,7 @@
def content_update(self, content, keys=[], db=None, cur=None):
# TODO: Add a check on input keys. How to properly implement
# this? We don't know yet the new columns.
-
- if self.journal_writer:
- raise NotImplementedError(
- 'content_update is not yet supported with a journal_writer.')
+ self.journal_writer.content_update(content)
db.mktemp('content', cur)
select_keys = list(set(db.content_get_metadata_keys).union(set(keys)))
@@ -245,20 +227,16 @@
@db_transaction()
def content_add_metadata(self, content: Iterable[Content],
db=None, cur=None) -> Dict:
- content = list(content)
+ contents = list(content)
missing = self.content_missing(
- (c.to_dict() for c in content), key_hash='sha1_git')
- content = [c for c in content if c.sha1_git in missing]
+ (c.to_dict() for c in contents), key_hash='sha1_git')
+ contents = [c for c in contents if c.sha1_git in missing]
- if self.journal_writer:
- for item in itertools.chain(content):
- assert item.data is None
- self.journal_writer.write_addition('content', item)
-
- self._content_add_metadata(db, cur, content)
+ self.journal_writer.content_add_metadata(contents)
+ self._content_add_metadata(db, cur, contents)
return {
- 'content:add': len(content),
+ 'content:add': len(contents),
}
@timed
@@ -429,10 +407,7 @@
for algo in DEFAULT_ALGORITHMS)
for missing_content in missing_contents)]
- if self.journal_writer:
- for item in content:
- self.journal_writer.write_addition('content', item)
-
+ self.journal_writer.skipped_content_add(content)
self._skipped_content_add_metadata(db, cur, content)
return {
@@ -473,11 +448,10 @@
if not dirs_missing:
return summary
- if self.journal_writer:
- self.journal_writer.write_additions(
- 'directory',
- (dir_ for dir_ in directories
- if dir_.id in dirs_missing))
+ self.journal_writer.directory_add(
+ dir_ for dir_ in directories
+ if dir_.id in dirs_missing
+ )
# Copy directory ids
dirs_missing_dict = ({'id': dir} for dir in dirs_missing)
@@ -557,8 +531,7 @@
revision for revision in revisions
if revision.id in revisions_missing]
- if self.journal_writer:
- self.journal_writer.write_additions('revision', revisions_filtered)
+ self.journal_writer.revision_add(revisions_filtered)
revisions_filtered = \
list(map(converters.revision_to_db, revisions_filtered))
@@ -644,8 +617,7 @@
if release.id in releases_missing
]
- if self.journal_writer:
- self.journal_writer.write_additions('release', releases_filtered)
+ self.journal_writer.release_add(releases_filtered)
releases_filtered = \
list(map(converters.release_to_db, releases_filtered))
@@ -713,8 +685,7 @@
except VALIDATION_EXCEPTIONS + (KeyError,) as e:
raise StorageArgumentException(*e.args)
- if self.journal_writer:
- self.journal_writer.write_addition('snapshot', snapshot)
+ self.journal_writer.snapshot_add(snapshot)
db.snapshot_add(snapshot.id, cur)
count += 1
@@ -830,13 +801,18 @@
with convert_validation_exceptions():
visit_id = db.origin_visit_add(origin_url, date, type, cur)
- if self.journal_writer:
- # We can write to the journal only after inserting to the
- # DB, because we want the id of the visit
- self.journal_writer.write_addition('origin_visit', {
- 'origin': origin_url, 'date': date, 'type': type,
- 'visit': visit_id,
- 'status': 'ongoing', 'metadata': None, 'snapshot': None})
+ # We can write to the journal only after inserting to the
+ # DB, because we want the id of the visit
+ visit = {
+ 'origin': origin_url,
+ 'date': date,
+ 'type': type,
+ 'visit': visit_id,
+ 'status': 'ongoing',
+ 'metadata': None,
+ 'snapshot': None
+ }
+ self.journal_writer.origin_visit_add(visit)
send_metric('origin_visit:add', count=1, method_name='origin_visit')
return {
@@ -871,9 +847,8 @@
updates['snapshot'] = snapshot
if updates:
- if self.journal_writer:
- self.journal_writer.write_update('origin_visit', {
- **visit, **updates})
+ updated_visit = {**visit, **updates}
+ self.journal_writer.origin_visit_update(updated_visit)
with convert_validation_exceptions():
db.origin_visit_update(origin_url, visit_id, updates, cur)
@@ -890,9 +865,7 @@
"visit['origin'] must be a string, not %r"
% (visit['origin'],))
- if self.journal_writer:
- for visit in visits:
- self.journal_writer.write_addition('origin_visit', visit)
+ self.journal_writer.origin_visit_upsert(visits)
for visit in visits:
# TODO: upsert them all in a single query
@@ -1055,8 +1028,7 @@
if origin_url:
return origin_url
- if self.journal_writer:
- self.journal_writer.write_addition('origin', origin)
+ self.journal_writer.origin_add_one(origin)
origins = db.origin_add(origin.url, cur)
send_metric('origin:add', count=len(origins), method_name='origin_add')
diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py
--- a/swh/storage/tests/test_api_client.py
+++ b/swh/storage/tests/test_api_client.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2015-2018 The Software Heritage developers
+# Copyright (C) 2015-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
@@ -63,6 +63,7 @@
class TestStorage(_TestStorage):
def test_content_update(self, swh_storage, app_server):
- swh_storage.journal_writer = None # TODO, journal_writer not supported
- with patch.object(server.storage.storage, 'journal_writer', None):
+ # TODO, journal_writer not supported
+ swh_storage.journal_writer.journal = None
+ with patch.object(server.storage.journal_writer, 'journal', None):
super().test_content_update(swh_storage)
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -151,7 +151,7 @@
expected_cont = data.cont
del expected_cont['data']
- journal_objects = list(swh_storage.journal_writer.objects)
+ journal_objects = list(swh_storage.journal_writer.journal.objects)
for (obj_type, obj) in journal_objects:
assert insertion_start_time <= obj['ctime']
assert obj['ctime'] <= insertion_end_time
@@ -241,14 +241,14 @@
'content:add': 1,
'content:add:bytes': data.cont['length'],
}
- assert len(swh_storage.journal_writer.objects) == 1
+ assert len(swh_storage.journal_writer.journal.objects) == 1
actual_result = swh_storage.content_add([data.cont, data.cont2])
assert actual_result == {
'content:add': 1,
'content:add:bytes': data.cont2['length'],
}
- assert 2 <= len(swh_storage.journal_writer.objects) <= 3
+ assert 2 <= len(swh_storage.journal_writer.journal.objects) <= 3
assert len(swh_storage.content_find(data.cont)) == 1
assert len(swh_storage.content_find(data.cont2)) == 1
@@ -269,7 +269,7 @@
def test_content_update(self, swh_storage):
if hasattr(swh_storage, 'storage'):
- swh_storage.storage.journal_writer = None # TODO, not supported
+ swh_storage.journal_writer.journal = None # TODO, not supported
cont = copy.deepcopy(data.cont)
@@ -300,7 +300,8 @@
cont['sha1']: [expected_cont]
}
- assert list(swh_storage.journal_writer.objects) == [('content', cont)]
+ assert list(swh_storage.journal_writer.journal.objects) == [
+ ('content', cont)]
def test_content_add_metadata_different_input(self, swh_storage):
cont = data.cont
@@ -551,7 +552,7 @@
actual_result = swh_storage.directory_add([data.dir])
assert actual_result == {'directory:add': 1}
- assert list(swh_storage.journal_writer.objects) == \
+ assert list(swh_storage.journal_writer.journal.objects) == \
[('directory', data.dir)]
actual_data = list(swh_storage.directory_ls(data.dir['id']))
@@ -573,7 +574,7 @@
actual_result = swh_storage.directory_add(directories=_dir_gen())
assert actual_result == {'directory:add': 1}
- assert list(swh_storage.journal_writer.objects) == \
+ assert list(swh_storage.journal_writer.journal.objects) == \
[('directory', data.dir)]
swh_storage.refresh_stat_counters()
@@ -599,13 +600,13 @@
actual_result = swh_storage.directory_add([data.dir])
assert actual_result == {'directory:add': 1}
- assert list(swh_storage.journal_writer.objects) \
+ assert list(swh_storage.journal_writer.journal.objects) \
== [('directory', data.dir)]
actual_result = swh_storage.directory_add([data.dir])
assert actual_result == {'directory:add': 0}
- assert list(swh_storage.journal_writer.objects) \
+ assert list(swh_storage.journal_writer.journal.objects) \
== [('directory', data.dir)]
def test_directory_get_recursive(self, swh_storage):
@@ -616,7 +617,7 @@
[data.dir, data.dir2, data.dir3])
assert actual_result == {'directory:add': 3}
- assert list(swh_storage.journal_writer.objects) == [
+ assert list(swh_storage.journal_writer.journal.objects) == [
('directory', data.dir),
('directory', data.dir2),
('directory', data.dir3)]
@@ -653,7 +654,7 @@
[data.dir, data.dir2, data.dir3])
assert actual_result == {'directory:add': 3}
- assert list(swh_storage.journal_writer.objects) == [
+ assert list(swh_storage.journal_writer.journal.objects) == [
('directory', data.dir),
('directory', data.dir2),
('directory', data.dir3)]
@@ -767,7 +768,7 @@
normalized_revision = Revision.from_dict(data.revision).to_dict()
- assert list(swh_storage.journal_writer.objects) \
+ assert list(swh_storage.journal_writer.journal.objects) \
== [('revision', normalized_revision)]
# already there so nothing added
@@ -825,14 +826,14 @@
normalized_revision = Revision.from_dict(data.revision).to_dict()
normalized_revision2 = Revision.from_dict(data.revision2).to_dict()
- assert list(swh_storage.journal_writer.objects) \
+ assert list(swh_storage.journal_writer.journal.objects) \
== [('revision', normalized_revision)]
actual_result = swh_storage.revision_add(
[data.revision, data.revision2])
assert actual_result == {'revision:add': 1}
- assert list(swh_storage.journal_writer.objects) \
+ assert list(swh_storage.journal_writer.journal.objects) \
== [('revision', normalized_revision),
('revision', normalized_revision2)]
@@ -877,7 +878,7 @@
normalized_revision3 = Revision.from_dict(data.revision3).to_dict()
normalized_revision4 = Revision.from_dict(data.revision4).to_dict()
- assert list(swh_storage.journal_writer.objects) == [
+ assert list(swh_storage.journal_writer.journal.objects) == [
('revision', normalized_revision3),
('revision', normalized_revision4)]
@@ -974,7 +975,7 @@
data.release2['id']])
assert list(end_missing) == []
- assert list(swh_storage.journal_writer.objects) == [
+ assert list(swh_storage.journal_writer.journal.objects) == [
('release', normalized_release),
('release', normalized_release2)]
@@ -996,7 +997,7 @@
actual_result = swh_storage.release_add(_rel_gen())
assert actual_result == {'release:add': 2}
- assert list(swh_storage.journal_writer.objects) == [
+ assert list(swh_storage.journal_writer.journal.objects) == [
('release', normalized_release),
('release', normalized_release2)]
@@ -1015,7 +1016,7 @@
end_missing = swh_storage.release_missing([data.release['id']])
assert list(end_missing) == []
- assert list(swh_storage.journal_writer.objects) \
+ assert list(swh_storage.journal_writer.journal.objects) \
== [('release', release)]
def test_release_add_validation(self, swh_storage):
@@ -1045,13 +1046,13 @@
normalized_release = Release.from_dict(data.release).to_dict()
normalized_release2 = Release.from_dict(data.release2).to_dict()
- assert list(swh_storage.journal_writer.objects) \
+ assert list(swh_storage.journal_writer.journal.objects) \
== [('release', normalized_release)]
actual_result = swh_storage.release_add([data.release, data.release2])
assert actual_result == {'release:add': 1}
- assert list(swh_storage.journal_writer.objects) \
+ assert list(swh_storage.journal_writer.journal.objects) \
== [('release', normalized_release),
('release', normalized_release2)]
@@ -1133,7 +1134,7 @@
del actual_origin['id']
del actual_origin2['id']
- assert list(swh_storage.journal_writer.objects) \
+ assert list(swh_storage.journal_writer.journal.objects) \
== [('origin', actual_origin),
('origin', actual_origin2)]
@@ -1161,7 +1162,7 @@
del actual_origin['id']
del actual_origin2['id']
- assert list(swh_storage.journal_writer.objects) \
+ assert list(swh_storage.journal_writer.journal.objects) \
== [('origin', actual_origin),
('origin', actual_origin2)]
@@ -1170,12 +1171,12 @@
def test_origin_add_twice(self, swh_storage):
add1 = swh_storage.origin_add([data.origin, data.origin2])
- assert list(swh_storage.journal_writer.objects) \
+ assert list(swh_storage.journal_writer.journal.objects) \
== [('origin', data.origin),
('origin', data.origin2)]
add2 = swh_storage.origin_add([data.origin, data.origin2])
- assert list(swh_storage.journal_writer.objects) \
+ assert list(swh_storage.journal_writer.journal.objects) \
== [('origin', data.origin),
('origin', data.origin2)]
@@ -1431,7 +1432,7 @@
'metadata': None,
'snapshot': None,
}
- objects = list(swh_storage.journal_writer.objects)
+ objects = list(swh_storage.journal_writer.journal.objects)
assert ('origin', data.origin2) in objects
assert ('origin_visit', origin_visit) in objects
@@ -1494,7 +1495,7 @@
for visit in expected_visits:
assert visit in actual_origin_visits
- objects = list(swh_storage.journal_writer.objects)
+ objects = list(swh_storage.journal_writer.journal.objects)
assert ('origin', data.origin2) in objects
for visit in expected_visits:
@@ -1667,7 +1668,7 @@
'metadata': None,
'snapshot': None,
}
- objects = list(swh_storage.journal_writer.objects)
+ objects = list(swh_storage.journal_writer.journal.objects)
assert ('origin', data.origin) in objects
assert ('origin', data.origin2) in objects
assert ('origin_visit', data1) in objects
@@ -1888,7 +1889,7 @@
'metadata': None,
'snapshot': None,
}
- assert list(swh_storage.journal_writer.objects) == [
+ assert list(swh_storage.journal_writer.journal.objects) == [
('origin', data.origin2),
('origin_visit', data1),
('origin_visit', data2)]
@@ -1949,7 +1950,7 @@
'metadata': None,
'snapshot': None,
}
- assert list(swh_storage.journal_writer.objects) == [
+ assert list(swh_storage.journal_writer.journal.objects) == [
('origin', data.origin2),
('origin_visit', data1),
('origin_visit', data2)]
@@ -2126,7 +2127,7 @@
'metadata': None,
'snapshot': data.empty_snapshot['id'],
}
- assert list(swh_storage.journal_writer.objects) == \
+ assert list(swh_storage.journal_writer.journal.objects) == \
[('origin', data.origin),
('origin_visit', data1),
('snapshot', data.empty_snapshot),
@@ -2196,13 +2197,13 @@
actual_result = swh_storage.snapshot_add([data.empty_snapshot])
assert actual_result == {'snapshot:add': 1}
- assert list(swh_storage.journal_writer.objects) \
+ assert list(swh_storage.journal_writer.journal.objects) \
== [('snapshot', data.empty_snapshot)]
actual_result = swh_storage.snapshot_add([data.snapshot])
assert actual_result == {'snapshot:add': 1}
- assert list(swh_storage.journal_writer.objects) \
+ assert list(swh_storage.journal_writer.journal.objects) \
== [('snapshot', data.empty_snapshot),
('snapshot', data.snapshot)]
@@ -2438,7 +2439,7 @@
swh_storage.origin_add_one(data.origin)
visit_id = 54164461156
- swh_storage.journal_writer.objects[:] = []
+ swh_storage.journal_writer.journal.objects[:] = []
swh_storage.snapshot_add([data.snapshot])
@@ -2446,7 +2447,7 @@
swh_storage.origin_visit_update(
origin_url, visit_id, snapshot=data.snapshot['id'])
- assert list(swh_storage.journal_writer.objects) == [
+ assert list(swh_storage.journal_writer.journal.objects) == [
('snapshot', data.snapshot)]
def test_snapshot_add_twice__by_origin_visit(self, swh_storage):
@@ -2517,7 +2518,7 @@
'metadata': None,
'snapshot': data.snapshot['id'],
}
- assert list(swh_storage.journal_writer.objects) \
+ assert list(swh_storage.journal_writer.journal.objects) \
== [('origin', data.origin),
('origin_visit', data1),
('snapshot', data.snapshot),
@@ -3704,7 +3705,7 @@
"""
def test_content_update_with_new_cols(self, swh_storage):
- swh_storage.storage.journal_writer = None # TODO, not supported
+ swh_storage.journal_writer.journal = None # TODO, not supported
with db_transaction(swh_storage) as (_, cur):
cur.execute("""alter table content
@@ -3758,7 +3759,7 @@
expected_cont = cont.copy()
del expected_cont['data']
- journal_objects = list(swh_storage.journal_writer.objects)
+ journal_objects = list(swh_storage.journal_writer.journal.objects)
for (obj_type, obj) in journal_objects:
del obj['ctime']
assert journal_objects == [('content', expected_cont)]
@@ -3784,7 +3785,8 @@
assert datum == (cont['sha1'], cont['sha1_git'], cont['sha256'],
cont['length'], 'visible')
- assert list(swh_storage.journal_writer.objects) == [('content', cont)]
+ assert list(swh_storage.journal_writer.journal.objects) == [
+ ('content', cont)]
def test_skipped_content_add_db(self, swh_storage):
cont = data.skipped_cont
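The test updates are mechanical: swh_storage.journal_writer is now the JournalWriter wrapper rather than the raw journal, so tests reach the in-memory journal one attribute deeper and disable journaling by clearing .journal instead of replacing the whole attribute. In short, as the hunks above show:

# Inspecting written objects goes through the wrapper's .journal:
journal_objects = list(swh_storage.journal_writer.journal.objects)

# Disabling the journal for a test (e.g. content_update, which the
# journal does not support yet) clears .journal, not journal_writer:
swh_storage.journal_writer.journal = None
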
diff --git a/swh/storage/writer.py b/swh/storage/writer.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/writer.py
@@ -0,0 +1,99 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from typing import Iterable, Union
+from swh.model.model import (
+ Origin, OriginVisit, Snapshot, Directory, Revision, Release, Content
+)
+
+try:
+ from swh.journal.writer import get_journal_writer
+except ImportError:
+ get_journal_writer = None # type: ignore
+ # mypy limitation, see https://github.com/python/mypy/issues/1153
+
+
+class JournalWriter:
+ """Journal writer storage collaborator. It's in charge of adding objects to
+ the journal.
+
+ """
+ def __init__(self, journal_writer):
+ if journal_writer:
+ if get_journal_writer is None:
+ raise EnvironmentError(
+ 'You need the swh.journal package to use the '
+ 'journal_writer feature')
+ self.journal = get_journal_writer(**journal_writer)
+ else:
+ self.journal = None
+
+ def content_add(self, contents: Iterable[Content]) -> None:
+ """Add contents to the journal. Drop the data field if provided.
+
+ """
+ if not self.journal:
+ return
+ for item in contents:
+ content = item.to_dict()
+ if 'data' in content:
+ del content['data']
+ self.journal.write_addition('content', content)
+
+ def content_update(self, contents: Iterable[Content]) -> None:
+ if not self.journal:
+ return
+ raise NotImplementedError(
+ 'content_update is not yet supported with a journal writer.')
+
+ def content_add_metadata(
+ self, contents: Iterable[Content]) -> None:
+ return self.content_add(contents)
+
+ def skipped_content_add(
+ self, contents: Iterable[Content]) -> None:
+ return self.content_add(contents)
+
+ def directory_add(self, directories: Iterable[Directory]) -> None:
+ if not self.journal:
+ return
+ self.journal.write_additions('directory', directories)
+
+ def revision_add(self, revisions: Iterable[Revision]) -> None:
+ if not self.journal:
+ return
+ self.journal.write_additions('revision', revisions)
+
+ def release_add(self, releases: Iterable[Release]) -> None:
+ if not self.journal:
+ return
+ self.journal.write_additions('release', releases)
+
+ def snapshot_add(
+ self, snapshots: Union[Iterable[Snapshot], Snapshot]) -> None:
+ if not self.journal:
+ return
+ snaps = snapshots if isinstance(snapshots, list) else [snapshots]
+ self.journal.write_additions('snapshot', snaps)
+
+ def origin_visit_add(self, visit: OriginVisit):
+ if not self.journal:
+ return
+ self.journal.write_addition('origin_visit', visit)
+
+ def origin_visit_update(self, visit: OriginVisit):
+ if not self.journal:
+ return
+ self.journal.write_update('origin_visit', visit)
+
+ def origin_visit_upsert(self, visits: Iterable[OriginVisit]):
+ if not self.journal:
+ return
+ self.journal.write_additions('origin_visit', visits)
+
+ def origin_add_one(self, origin: Origin):
+ if not self.journal:
+ return
+ self.journal.write_addition('origin', origin)
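
Taken together, the new swh/storage/writer.py centralizes the journal-writing policy: construction delegates to swh.journal.writer.get_journal_writer when a configuration dict is given, and every method silently returns when no journal is configured. A minimal usage sketch, assuming the in-memory journal backend used by the test suite; the example URL is illustrative only:

from swh.model.model import Origin
from swh.storage.writer import JournalWriter

# No configuration: .journal stays None and every call is a no-op,
# which is what lets the storage backends drop their repeated
# `if self.journal_writer:` guards.
writer = JournalWriter(None)
writer.origin_add_one(Origin(url='https://example.org/project.git'))  # no-op

# With a configuration dict, construction goes through
# swh.journal.writer.get_journal_writer(**config); the in-memory backend
# exposes what was written, as the tests do via
# swh_storage.journal_writer.journal.objects.
writer = JournalWriter({'cls': 'memory'})
writer.origin_add_one(Origin(url='https://example.org/project.git'))
print(list(writer.journal.objects))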
