diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -19,6 +19,8 @@
 from swh.objstorage import get_objstorage
 from swh.objstorage.exc import ObjNotFoundError
 
+from .journal_writer import get_journal_writer
+
 # Max block size of contents to return
 BULK_BLOCK_CONTENT_LEN_MAX = 10000
 
@@ -28,7 +30,7 @@
 
 
 class Storage:
-    def __init__(self):
+    def __init__(self, journal_writer=None):
         self._contents = {}
         self._content_indexes = defaultdict(lambda: defaultdict(set))
@@ -48,6 +50,10 @@
         self._sorted_sha1s = []
 
         self.objstorage = get_objstorage('memory', {})
+        if journal_writer:
+            self.journal_writer = get_journal_writer(**journal_writer)
+        else:
+            self.journal_writer = None
 
     def check_config(self, *, check_write):
         """Check that the storage is configured and ready to go."""
@@ -72,6 +78,12 @@
                 content in
 
         """
+        if self.journal_writer:
+            for content in contents:
+                if 'data' in content:
+                    content = content.copy()
+                    del content['data']
+                self.journal_writer.write_addition('content', content)
         for content in contents:
             key = self._content_key(content)
             if key in self._contents:
@@ -283,6 +295,9 @@
                   directory entry
                 - perms (int): entry permissions
         """
+        if self.journal_writer:
+            self.journal_writer.write_additions('directory', directories)
+
         for directory in directories:
             if directory['id'] not in self._directories:
                 self._directories[directory['id']] = copy.deepcopy(directory)
@@ -413,6 +428,9 @@
             date dictionaries have the form defined in :mod:`swh.model`.
 
         """
+        if self.journal_writer:
+            self.journal_writer.write_additions('revision', revisions)
+
         for revision in revisions:
             if revision['id'] not in self._revisions:
                 self._revisions[revision['id']] = rev = copy.deepcopy(revision)
@@ -501,6 +519,9 @@
             the date dictionary has the form defined in :mod:`swh.model`.
 
         """
+        if self.journal_writer:
+            self.journal_writer.write_additions('release', releases)
+
         for rel in releases:
             rel = copy.deepcopy(rel)
             rel['date'] = normalize_timestamp(rel['date'])
@@ -560,6 +581,7 @@
         Raises:
             ValueError: if the origin's or visit's identifier does not exist.
         """
+
         snapshot_id = snapshot['id']
         if snapshot_id not in self._snapshots:
             self._snapshots[snapshot_id] = {
@@ -571,7 +593,16 @@
             }
             self._objects[snapshot_id].append(('snapshot', snapshot_id))
 
         if origin <= len(self._origin_visits) and \
-                visit <= len(self._origin_visits[origin-1]):
+                visit <= len(self._origin_visits[origin-1]):
+
+            if self.journal_writer:
+                self.journal_writer.write_addition(
+                    'snapshot', snapshot)
+                self.journal_writer.write_update('origin_visit', {
+                    **self._origin_visits[origin-1][visit-1],
+                    'origin': self._origins[origin-1],
+                    'snapshot': snapshot_id})
+
             self._origin_visits[origin-1][visit-1]['snapshot'] = snapshot_id
         else:
             raise ValueError('Origin with id %s does not exist or has no visit'
@@ -965,6 +996,12 @@
             self._origin_visits.append([])
             key = (origin['type'], origin['url'])
             self._objects[key].append(('origin', origin_id))
+        else:
+            origin['id'] = origin_id
+
+        if self.journal_writer:
+            self.journal_writer.write_addition('origin', origin)
+
         return origin_id
 
     def fetch_history_start(self, origin_id):
@@ -1009,28 +1046,35 @@
                           DeprecationWarning)
             date = ts
 
+        origin_id = origin  # TODO: rename the argument
+
         if isinstance(date, str):
             date = dateutil.parser.parse(date)
 
         visit_ret = None
-        if origin <= len(self._origin_visits):
+        if origin_id <= len(self._origin_visits):
             # visit ids are in the range [1, +inf[
-            visit_id = len(self._origin_visits[origin-1]) + 1
+            visit_id = len(self._origin_visits[origin_id-1]) + 1
             status = 'ongoing'
             visit = {
-                'origin': origin,
+                'origin': origin_id,
                 'date': date,
                 'status': status,
                 'snapshot': None,
                 'metadata': None,
                 'visit': visit_id
             }
-            self._origin_visits[origin-1].append(visit)
+            self._origin_visits[origin_id-1].append(visit)
             visit_ret = {
-                'origin': origin,
+                'origin': origin_id,
                 'visit': visit_id,
             }
 
+        if self.journal_writer:
+            origin = self.origin_get([{'id': origin_id}])[0]
+            self.journal_writer.write_addition('origin_visit', {
+                **visit, 'origin': origin})
+
         return visit_ret
 
     def origin_visit_update(self, origin, visit_id, status, metadata=None):
@@ -1046,10 +1090,18 @@
             None
 
         """
-        if origin > len(self._origin_visits) or \
-                visit_id > len(self._origin_visits[origin-1]):
+        origin_id = origin  # TODO: rename the argument
+
+        if self.journal_writer:
+            origin = self.origin_get([{'id': origin_id}])[0]
+            self.journal_writer.write_update('origin_visit', {
+                **self._origin_visits[origin_id-1][visit_id-1],
+                'origin': origin, 'visit': visit_id,
+                'status': status, 'metadata': metadata})
+        if origin_id > len(self._origin_visits) or \
+                visit_id > len(self._origin_visits[origin_id-1]):
             return
-        self._origin_visits[origin-1][visit_id-1].update({
+        self._origin_visits[origin_id-1][visit_id-1].update({
             'status': status,
             'metadata': metadata})
diff --git a/swh/storage/journal_writer.py b/swh/storage/journal_writer.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/journal_writer.py
@@ -0,0 +1,35 @@
+# Copyright (C) 2019 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import copy
+from multiprocessing import Manager
+
+
+class InMemoryJournalWriter:
+    def __init__(self):
+        # Share the list of objects across processes, for RemoteAPI tests.
+        self.manager = Manager()
+        self.objects = self.manager.list()
+
+    def write_addition(self, object_type, object_):
+        self.objects.append((object_type, copy.deepcopy(object_)))
+
+    write_update = write_addition
+
+    def write_additions(self, object_type, objects):
+        for object_ in objects:
+            self.write_addition(object_type, object_)
+
+
+def get_journal_writer(cls, args={}):
+    if cls == 'inmemory':
+        JournalWriter = InMemoryJournalWriter
+    elif cls == 'kafka':
+        from swh.journal.direct_writer import DirectKafkaWriter \
+            as JournalWriter
+    else:
+        raise ValueError('Unknown journal writer class `%s`' % cls)
+
+    return JournalWriter(**args)
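The storage classes always invoke the factory above as get_journal_writer(**config), so a config of {'cls': 'inmemory'} resolves to an InMemoryJournalWriter. A usage sketch (the origin dict is illustrative):

    from swh.storage.journal_writer import get_journal_writer

    writer = get_journal_writer(cls='inmemory')
    writer.write_addition('origin', {'type': 'git',
                                     'url': 'https://example.org/repo.git'})

    # Entries accumulate as (object_type, object) tuples in a
    # multiprocessing.Manager list, so a RemoteAPI test server running in
    # another process sees the same journal.
    assert list(writer.objects) == \
        [('origin', {'type': 'git', 'url': 'https://example.org/repo.git'})]

Any object exposing write_addition, write_update and write_additions with these signatures can stand in for a writer; the storage backends rely only on this duck-typed interface.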
""" + origin_id = origin + visit_id = visit + if not db.snapshot_exists(snapshot['id'], cur): + db.mktemp_snapshot_branch(cur) db.copy_to( ( @@ -783,11 +813,30 @@ ['name', 'target', 'target_type'], cur, ) - if not db.origin_visit_exists(origin, visit): - raise ValueError('Not origin visit with ids (%s, %s)' % - (origin, visit)) - db.snapshot_add(origin, visit, snapshot['id'], cur) + if self.journal_writer: + visit = db.origin_visit_get(origin_id, visit_id, cur=cur) + visit_exists = visit is not None + else: + visit_exists = db.origin_visit_exists(origin_id, visit_id) + + if not visit_exists: + raise ValueError('Not origin visit with ids (%s, %s)' % + (origin_id, visit_id)) + + if self.journal_writer: + # Send the snapshot before the origin: in case of a crash, + # it's better to have an orphan snapshot than have the + # origin_visit have a dangling reference to a snapshot + origin = self.origin_get([{'id': origin_id}], db=db, cur=cur)[0] + visit = dict(zip(db.origin_visit_get_cols, visit)) + self.journal_writer.write_addition('snapshot', snapshot) + self.journal_writer.write_update('origin_visit', { + 'origin': origin, 'visit': visit_id, + 'status': visit['status'], 'metadata': visit['metadata'], + 'date': visit['date'], 'snapshot': snapshot['id']}) + + db.snapshot_add(origin_id, visit_id, snapshot['id'], cur) @db_transaction(statement_timeout=2000) def snapshot_get(self, snapshot_id, db=None, cur=None): @@ -989,12 +1038,24 @@ DeprecationWarning) date = ts + origin_id = origin # TODO: rename the argument + if isinstance(date, str): date = dateutil.parser.parse(date) + visit = db.origin_visit_add(origin, date, cur) + + if self.journal_writer: + # We can write to the journal only after inserting to the + # DB, because we want the id of the visit + origin = self.origin_get([{'id': origin_id}], db=db, cur=cur)[0] + self.journal_writer.write_addition('origin_visit', { + 'origin': origin, 'date': date, 'visit': visit, + 'status': 'ongoing', 'metadata': None, 'snapshot': None}) + return { - 'origin': origin, - 'visit': db.origin_visit_add(origin, date, cur) + 'origin': origin_id, + 'visit': visit, } @db_transaction() @@ -1012,7 +1073,18 @@ None """ - return db.origin_visit_update(origin, visit_id, status, metadata, cur) + origin_id = origin # TODO: rename the argument + + if self.journal_writer: + origin = self.origin_get([{'id': origin_id}], db=db, cur=cur)[0] + visit = db.origin_visit_get(origin_id, visit_id, cur=cur) + visit = dict(zip(db.origin_visit_get_cols, visit)) + self.journal_writer.write_update('origin_visit', { + 'origin': origin, 'visit': visit_id, + 'status': status, 'metadata': metadata, + 'date': visit['date'], 'snapshot': None}) + return db.origin_visit_update( + origin_id, visit_id, status, metadata, cur) @db_transaction_generator(statement_timeout=500) def origin_visit_get(self, origin, last_visit=None, limit=None, db=None, @@ -1243,6 +1315,7 @@ """ for origin in origins: origin['id'] = self.origin_add_one(origin, db=db, cur=cur) + return origins @db_transaction() @@ -1266,7 +1339,12 @@ if origin_id: return origin_id - return db.origin_add(origin['type'], origin['url'], cur) + origin['id'] = db.origin_add(origin['type'], origin['url'], cur) + + if self.journal_writer: + self.journal_writer.write_addition('origin', origin) + + return origin['id'] @db_transaction() def fetch_history_start(self, origin_id, db=None, cur=None): diff --git a/swh/storage/tests/storage_testing.py b/swh/storage/tests/storage_testing.py --- a/swh/storage/tests/storage_testing.py +++ 
diff --git a/swh/storage/tests/storage_testing.py b/swh/storage/tests/storage_testing.py
--- a/swh/storage/tests/storage_testing.py
+++ b/swh/storage/tests/storage_testing.py
@@ -43,9 +43,13 @@
                         'slicing': '0:1/1:5',
                     },
                 },
+                'journal_writer': {
+                    'cls': 'inmemory',
+                },
             },
         }
         self.storage = get_storage(**self.storage_config)
+        self.journal_writer = self.storage.journal_writer
 
     def tearDown(self):
         self.objtmp.cleanup()
diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py
--- a/swh/storage/tests/test_api_client.py
+++ b/swh/storage/tests/test_api_client.py
@@ -9,7 +9,11 @@
 import unittest
 
 from swh.core.tests.server_testing import ServerTestFixture
+import swh.storage.storage as storage
+from swh.storage.journal_writer import \
+    get_journal_writer, InMemoryJournalWriter
 from swh.storage.api.client import RemoteStorage
+import swh.storage.api.server as server
 from swh.storage.api.server import app
 from swh.storage.tests.test_storage import \
     CommonTestStorage, CommonPropTestStorage, StorageTestDbFixture
@@ -25,6 +29,14 @@
 
     """
     def setUp(self):
+        def mock_get_journal_writer(cls, args=None):
+            assert cls == 'inmemory'
+            return journal_writer
+        server.storage = None
+        storage.get_journal_writer = mock_get_journal_writer
+        journal_writer = InMemoryJournalWriter()
+        self.journal_writer = journal_writer
+
         # ServerTestFixture needs to have self.objroot for
         # setUp() method, but this field is defined in
         # AbstractTestStorage's setUp()
@@ -43,6 +55,9 @@
                             'slicing': '0:2',
                         },
                     },
+                    'journal_writer': {
+                        'cls': 'inmemory',
+                    }
                 }
             }
         }
@@ -52,6 +67,7 @@
         self.objroot = self.storage_base
 
     def tearDown(self):
+        storage.get_journal_writer = get_journal_writer
         super().tearDown()
         shutil.rmtree(self.storage_base)
diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py
--- a/swh/storage/tests/test_in_memory.py
+++ b/swh/storage/tests/test_in_memory.py
@@ -21,10 +21,11 @@
     """
    def setUp(self):
         super().setUp()
-        self.storage = Storage()
+        self.storage = Storage(journal_writer={'cls': 'inmemory'})
+        self.journal_writer = self.storage.journal_writer
 
     @pytest.mark.skip('postgresql-specific test')
-    def test_content_add(self):
+    def test_content_add_db(self):
         pass
 
     @pytest.mark.skip('postgresql-specific test')
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -553,6 +553,18 @@
         cont = self.cont
 
         self.storage.content_add([cont])
+        self.assertEqual(list(self.storage.content_get([cont['sha1']])),
+                         [{'sha1': cont['sha1'], 'data': cont['data']}])
+
+        expected_cont = cont.copy()
+        del expected_cont['data']
+        self.assertEqual(list(self.journal_writer.objects),
+                         [('content', expected_cont)])
+
+    def test_content_add_db(self):
+        cont = self.cont
+
+        self.storage.content_add([cont])
         if hasattr(self.storage, 'objstorage'):
             self.assertIn(cont['sha1'], self.storage.objstorage)
         self.cursor.execute('SELECT sha1, sha1_git, sha256, length, status'
@@ -565,6 +577,11 @@
                          (cont['sha1'], cont['sha1_git'], cont['sha256'],
                           cont['length'], 'visible'))
 
+        expected_cont = cont.copy()
+        del expected_cont['data']
+        self.assertEqual(list(self.journal_writer.objects),
+                         [('content', expected_cont)])
+
     def test_content_add_collision(self):
         cont1 = self.cont
 
@@ -706,6 +723,8 @@
         self.assertEqual([self.dir['id']], init_missing)
 
         self.storage.directory_add([self.dir])
+        self.assertEqual(list(self.journal_writer.objects),
+                         [('directory', self.dir)])
 
         actual_data = list(self.storage.directory_ls(self.dir['id']))
         expected_data = list(self._transform_entries(self.dir))
@@ -719,6 +738,10 @@
         self.assertEqual([self.dir['id']], init_missing)
 
         self.storage.directory_add([self.dir, self.dir2, self.dir3])
+        self.assertEqual(list(self.journal_writer.objects),
+                         [('directory', self.dir),
+                          ('directory', self.dir2),
+                          ('directory', self.dir3)])
 
         actual_data = list(self.storage.directory_ls(
             self.dir['id'], recursive=True))
@@ -807,6 +830,9 @@
         end_missing = self.storage.revision_missing([self.revision['id']])
         self.assertEqual([], list(end_missing))
 
+        self.assertEqual(list(self.journal_writer.objects),
+                         [('revision', self.revision)])
+
     def test_revision_log(self):
         # given
         # self.revision4 -is-child-of-> self.revision3
@@ -830,6 +856,10 @@
         self.assertEqual(actual_results[1],
                          self.normalize_entity(self.revision3))
 
+        self.assertEqual(list(self.journal_writer.objects),
+                         [('revision', self.revision3),
+                          ('revision', self.revision4)])
+
     def test_revision_log_with_limit(self):
         # given
         # self.revision4 -is-child-of-> self.revision3
@@ -921,6 +951,10 @@
                                                    self.release2['id']])
         self.assertEqual([], list(end_missing))
 
+        self.assertEqual(list(self.journal_writer.objects),
+                         [('release', self.release),
+                          ('release', self.release2)])
+
     def test_release_get(self):
         # given
         self.storage.release_add([self.release, self.release2])
@@ -975,6 +1009,10 @@
             }])[0]
         self.assertEqual(actual_origin2['id'], origin2['id'])
 
+        self.assertEqual(list(self.journal_writer.objects),
+                         [('origin', actual_origin),
+                          ('origin', actual_origin2)])
+
     def test_origin_add_twice(self):
         add1 = self.storage.origin_add([self.origin, self.origin2])
         add2 = self.storage.origin_add([self.origin, self.origin2])
@@ -1107,6 +1145,20 @@
             'snapshot': None,
         }])
 
+        expected_origin = self.origin2.copy()
+        expected_origin['id'] = origin_id
+        data = {
+            'origin': expected_origin,
+            'date': self.date_visit2,
+            'visit': origin_visit1['visit'],
+            'status': 'ongoing',
+            'metadata': None,
+            'snapshot': None,
+        }
+        self.assertEqual(list(self.journal_writer.objects),
+                         [('origin', expected_origin),
+                          ('origin_visit', data)])
+
     def test_origin_visit_update(self):
         # given
         origin_id = self.storage.origin_add_one(self.origin2)
@@ -1188,6 +1240,59 @@
             'snapshot': None,
         }])
 
+        expected_origin = self.origin2.copy()
+        expected_origin['id'] = origin_id
+        expected_origin2 = self.origin.copy()
+        expected_origin2['id'] = origin_id2
+        data1 = {
+            'origin': expected_origin,
+            'date': self.date_visit2,
+            'visit': origin_visit1['visit'],
+            'status': 'ongoing',
+            'metadata': None,
+            'snapshot': None,
+        }
+        data2 = {
+            'origin': expected_origin,
+            'date': self.date_visit3,
+            'visit': origin_visit2['visit'],
+            'status': 'ongoing',
+            'metadata': None,
+            'snapshot': None,
+        }
+        data3 = {
+            'origin': expected_origin2,
+            'date': self.date_visit3,
+            'visit': origin_visit3['visit'],
+            'status': 'ongoing',
+            'metadata': None,
+            'snapshot': None,
+        }
+        data4 = {
+            'origin': expected_origin,
+            'date': self.date_visit2,
+            'visit': origin_visit1['visit'],
+            'metadata': visit1_metadata,
+            'status': 'full',
+            'snapshot': None,
+        }
+        data5 = {
+            'origin': expected_origin2,
+            'date': self.date_visit3,
+            'visit': origin_visit3['visit'],
+            'status': 'partial',
+            'metadata': None,
+            'snapshot': None,
+        }
+        self.assertEqual(list(self.journal_writer.objects),
+                         [('origin', expected_origin),
+                          ('origin', expected_origin2),
+                          ('origin_visit', data1),
+                          ('origin_visit', data2),
+                          ('origin_visit', data3),
+                          ('origin_visit', data4),
+                          ('origin_visit', data5)])
+
     def test_origin_visit_get_by(self):
         origin_id = self.storage.origin_add_one(self.origin2)
         origin_id2 = self.storage.origin_add_one(self.origin)
@@ -1313,6 +1418,30 @@
         by_ov = self.storage.snapshot_get_by_origin_visit(origin_id, visit_id)
         self.assertEqual(by_ov, self.empty_snapshot)
 
+        expected_origin = self.origin.copy()
+        expected_origin['id'] = origin_id
+        data1 = {
+            'origin': expected_origin,
+            'date': self.date_visit1,
+            'visit': origin_visit1['visit'],
+            'status': 'ongoing',
+            'metadata': None,
+            'snapshot': None,
+        }
+        data2 = {
+            'origin': expected_origin,
+            'date': self.date_visit1,
+            'visit': origin_visit1['visit'],
+            'status': 'ongoing',
+            'metadata': None,
+            'snapshot': self.empty_snapshot['id'],
+        }
+        self.assertEqual(list(self.journal_writer.objects),
+                         [('origin', expected_origin),
+                          ('origin_visit', data1),
+                          ('snapshot', self.empty_snapshot),
+                          ('origin_visit', data2)])
+
     def test_snapshot_add_get_complete(self):
         origin_id = self.storage.origin_add_one(self.origin)
         origin_visit1 = self.storage.origin_visit_add(origin_id,
@@ -1467,6 +1596,20 @@
         origin_id = self.storage.origin_add_one(self.origin)
         visit_id = 54164461156
 
+        self.journal_writer.objects[:] = []
+
+        with self.assertRaises(ValueError):
+            self.storage.snapshot_add(origin_id, visit_id, self.snapshot)
+
+        self.assertEqual(list(self.journal_writer.objects), [])
+
+    def test_snapshot_add_nonexistent_visit_no_journal(self):
+        # Same test as before, but uses a different code path for checking
+        # the origin visit exists.
+        self.storage.journal_writer = None
+        origin_id = self.storage.origin_add_one(self.origin)
+        visit_id = 54164461156
+
         with self.assertRaises(ValueError):
             self.storage.snapshot_add(origin_id, visit_id, self.snapshot)
 
@@ -1491,6 +1634,49 @@
                                                        visit2_id)
         self.assertEqual(by_ov2, self.snapshot)
 
+        expected_origin = self.origin.copy()
+        expected_origin['id'] = origin_id
+        data1 = {
+            'origin': expected_origin,
+            'date': self.date_visit1,
+            'visit': origin_visit1['visit'],
+            'status': 'ongoing',
+            'metadata': None,
+            'snapshot': None,
+        }
+        data2 = {
+            'origin': expected_origin,
+            'date': self.date_visit1,
+            'visit': origin_visit1['visit'],
+            'status': 'ongoing',
+            'metadata': None,
+            'snapshot': self.snapshot['id'],
+        }
+        data3 = {
+            'origin': expected_origin,
+            'date': self.date_visit2,
+            'visit': origin_visit2['visit'],
+            'status': 'ongoing',
+            'metadata': None,
+            'snapshot': None,
+        }
+        data4 = {
+            'origin': expected_origin,
+            'date': self.date_visit2,
+            'visit': origin_visit2['visit'],
+            'status': 'ongoing',
+            'metadata': None,
+            'snapshot': self.snapshot['id'],
+        }
+        self.assertEqual(list(self.journal_writer.objects),
+                         [('origin', expected_origin),
+                          ('origin_visit', data1),
+                          ('snapshot', self.snapshot),
+                          ('origin_visit', data2),
+                          ('origin_visit', data3),
+                          ('snapshot', self.snapshot),
+                          ('origin_visit', data4)])
+
     def test_snapshot_get_nonexistent(self):
         bogus_snapshot_id = b'bogus snapshot id 00'
         bogus_origin_id = 1
@@ -2319,6 +2505,8 @@
     """
     def test_content_update(self):
+        self.storage.journal_writer = None  # TODO, not supported
+
         cont = copy.deepcopy(self.cont)
 
         self.storage.content_add([cont])
@@ -2341,6 +2529,8 @@
                           cont['length'], 'visible'))
 
     def test_content_update_with_new_cols(self):
+        self.storage.journal_writer = None  # TODO, not supported
+
         with self.storage.get_db().transaction() as cur:
             cur.execute("""alter table content
                            add column test text default null,
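Taken together, the snapshot_add tests encode the ordering guarantee implemented in storage.py: the snapshot is journalled before the origin_visit that references it, so a crash can at worst leave an orphan snapshot, never a visit with a dangling snapshot reference. A condensed sketch of the observable sequence against the in-memory backend (the date and snapshot values are illustrative):

    from swh.storage.in_memory import Storage

    storage = Storage(journal_writer={'cls': 'inmemory'})
    origin_id = storage.origin_add_one(
        {'type': 'git', 'url': 'https://example.org/repo.git'})
    visit = storage.origin_visit_add(origin_id, '2019-01-01T00:00:00+00:00')
    storage.snapshot_add(origin_id, visit['visit'],
                         {'id': b'\x00' * 20, 'branches': {}})

    types = [object_type for object_type, _ in storage.journal_writer.objects]
    assert types == ['origin', 'origin_visit', 'snapshot', 'origin_visit']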