diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -19,6 +19,8 @@ from swh.objstorage import get_objstorage from swh.objstorage.exc import ObjNotFoundError +from .journal_writer import get_journal_writer + # Max block size of contents to return BULK_BLOCK_CONTENT_LEN_MAX = 10000 @@ -28,7 +30,7 @@ class Storage: - def __init__(self): + def __init__(self, journal_writer=None): self._contents = {} self._content_indexes = defaultdict(lambda: defaultdict(set)) @@ -48,6 +50,10 @@ self._sorted_sha1s = [] self.objstorage = get_objstorage('memory', {}) + if journal_writer: + self.journal_writer = get_journal_writer(**journal_writer) + else: + self.journal_writer = None def check_config(self, *, check_write): """Check that the storage is configured and ready to go.""" @@ -72,6 +78,13 @@ content in """ + if self.journal_writer: + for content in contents: + if 'data' in content: + content = content.copy() + del content['data'] + self.journal_writer.write_addition( + 'content', content, id_=content['sha1']) for content in contents: key = self._content_key(content) if key in self._contents: @@ -283,6 +296,9 @@ directory entry - perms (int): entry permissions """ + if self.journal_writer: + self.journal_writer.write_additions('directory', directories) + for directory in directories: if directory['id'] not in self._directories: self._directories[directory['id']] = copy.deepcopy(directory) @@ -413,6 +429,9 @@ date dictionaries have the form defined in :mod:`swh.model`. """ + if self.journal_writer: + self.journal_writer.write_additions('revision', revisions) + for revision in revisions: if revision['id'] not in self._revisions: self._revisions[revision['id']] = rev = copy.deepcopy(revision) @@ -501,6 +520,9 @@ the date dictionary has the form defined in :mod:`swh.model`. 
""" + if self.journal_writer: + self.journal_writer.write_additions('release', releases) + for rel in releases: rel = copy.deepcopy(rel) rel['date'] = normalize_timestamp(rel['date']) @@ -560,6 +582,10 @@ Raises: ValueError: if the origin's or visit's identifier does not exist. """ + if self.journal_writer: + self.journal_writer.write_addition( + 'snapshot', (origin, visit, snapshot), id_=snapshot['id']) + snapshot_id = snapshot['id'] if snapshot_id not in self._snapshots: self._snapshots[snapshot_id] = { @@ -965,6 +991,12 @@ self._origin_visits.append([]) key = (origin['type'], origin['url']) self._objects[key].append(('origin', origin_id)) + else: + origin['id'] = origin_id + + if self.journal_writer: + self.journal_writer.write_addition('origin', origin) + return origin_id def fetch_history_start(self, origin_id): @@ -1031,6 +1063,11 @@ 'visit': visit_id, } + if self.journal_writer: + self.journal_writer.write_addition('origin_visit', { + 'origin': origin, 'date': date, 'visit': visit_id}, + id_=b'%d-%d' % (origin, visit_id)) + return visit_ret def origin_visit_update(self, origin, visit_id, status, metadata=None): @@ -1046,6 +1083,11 @@ None """ + if self.journal_writer: + self.journal_writer.write_update('origin_visit', { + 'origin': origin, 'visit': visit_id, + 'status': status, 'metadata': metadata}, + id_=b'%d-%d' % (origin, visit_id)) if origin > len(self._origin_visits) or \ visit_id > len(self._origin_visits[origin-1]): return diff --git a/swh/storage/journal_writer.py b/swh/storage/journal_writer.py new file mode 100644 --- /dev/null +++ b/swh/storage/journal_writer.py @@ -0,0 +1,35 @@ +# Copyright (C) 2019 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from multiprocessing import Manager + + +class InMemoryJournalWriter: + def __init__(self): + # Share the list of objects 
across processes, for RemoteAPI tests. + self.manager = Manager() + self.objects = self.manager.list() + + def write_addition(self, object_type, object_, id_=None): + if id_ is None: + id_ = object_['id'] + self.objects.append((object_type, id_, object_)) + + write_update = write_addition + + def write_additions(self, object_type, objects): + for object_ in objects: + self.write_addition(object_type, object_) + + +def get_journal_writer(cls, args={}): + if cls == 'inmemory': + JournalWriter = InMemoryJournalWriter + elif cls == 'kafka': + from swh.journal.direct_writer import DirectKafkaWriter as JournalWriter + else: + raise ValueError('Unknown journal writer class `%s`' % cls) + + return JournalWriter(**args) diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -20,6 +20,7 @@ from .db import Db from .exc import StorageDBError from .algos import diff +from .journal_writer import get_journal_writer from swh.model.hashutil import ALGORITHMS, hash_to_bytes from swh.objstorage import get_objstorage @@ -37,7 +38,8 @@ """ - def __init__(self, db, objstorage, min_pool_conns=1, max_pool_conns=10): + def __init__(self, db, objstorage, min_pool_conns=1, max_pool_conns=10, + journal_writer=None): """ Args: db_conn: either a libpq connection string, or a psycopg2 connection @@ -57,6 +59,10 @@ raise StorageDBError(e) self.objstorage = get_objstorage(**objstorage) + if journal_writer: + self.journal_writer = get_journal_writer(**journal_writer) + else: + self.journal_writer = None def get_db(self): if self._db: @@ -93,7 +99,7 @@ object storage is idempotent, that should not be a problem. Args: - content (iterable): iterable of dictionaries representing + contents (iterable): iterable of dictionaries representing individual pieces of content to add. 
Each dictionary has the following keys: @@ -108,6 +114,14 @@ content in """ + if self.journal_writer: + for item in content: + if 'data' in item: + item = item.copy() + del item['data'] + self.journal_writer.write_addition( + 'content', item, id_=item['sha1']) + db = self.get_db() def _unique_key(hash, keys=db.content_hash_keys): @@ -215,6 +229,14 @@ # TODO: Add a check on input keys. How to properly implement # this? We don't know yet the new columns. + if self.journal_writer: + for item in content: + if 'data' in item: + item = item.copy() + del item['data'] + self.journal_writer.write_update( + 'content', item, id_=item['sha1']) + db.mktemp('content', cur) select_keys = list(set(db.content_get_metadata_keys).union(set(keys))) db.copy_to(content, 'tmp_content', select_keys, cur) @@ -432,6 +454,9 @@ directory entry - perms (int): entry permissions """ + if self.journal_writer: + self.journal_writer.write_additions('directory', directories) + dirs = set() dir_entries = { 'file': defaultdict(list), @@ -563,6 +588,9 @@ date dictionaries have the form defined in :mod:`swh.model`. """ + if self.journal_writer: + self.journal_writer.write_additions('revision', revisions) + db = self.get_db() revisions_missing = set(self.revision_missing( @@ -684,6 +712,9 @@ the date dictionary has the form defined in :mod:`swh.model`. """ + if self.journal_writer: + self.journal_writer.write_additions('release', releases) + db = self.get_db() release_ids = set(release['id'] for release in releases) @@ -767,6 +798,10 @@ Raises: ValueError: if the origin or visit id does not exist. 
""" + if self.journal_writer: + self.journal_writer.write_addition( + 'snapshot', (origin, visit, snapshot), id_=snapshot['id']) + if not db.snapshot_exists(snapshot['id'], cur): db.mktemp_snapshot_branch(cur) db.copy_to( @@ -991,9 +1026,16 @@ if isinstance(date, str): date = dateutil.parser.parse(date) + visit = db.origin_visit_add(origin, date, cur) + + if self.journal_writer: + self.journal_writer.write_addition('origin_visit', { + 'origin': origin, 'date': date, 'visit': visit}, + id_=b'%d-%d' % (origin, visit)) + return { 'origin': origin, - 'visit': db.origin_visit_add(origin, date, cur) + 'visit': visit, } @db_transaction() @@ -1011,6 +1053,11 @@ None """ + if self.journal_writer: + self.journal_writer.write_update('origin_visit', { + 'origin': origin, 'visit': visit_id, + 'status': status, 'metadata': metadata}, + id_=b'%d-%d' % (origin, visit_id)) return db.origin_visit_update(origin, visit_id, status, metadata, cur) @db_transaction_generator(statement_timeout=500) @@ -1242,6 +1289,7 @@ """ for origin in origins: origin['id'] = self.origin_add_one(origin, db=db, cur=cur) + return origins @db_transaction() @@ -1265,7 +1313,12 @@ if origin_id: return origin_id - return db.origin_add(origin['type'], origin['url'], cur) + origin['id'] = db.origin_add(origin['type'], origin['url'], cur) + + if self.journal_writer: + self.journal_writer.write_addition('origin', origin) + + return origin['id'] @db_transaction() def fetch_history_start(self, origin_id, db=None, cur=None): diff --git a/swh/storage/tests/storage_testing.py b/swh/storage/tests/storage_testing.py --- a/swh/storage/tests/storage_testing.py +++ b/swh/storage/tests/storage_testing.py @@ -43,9 +43,13 @@ 'slicing': '0:1/1:5', }, }, + 'journal_writer': { + 'cls': 'inmemory', + }, }, } self.storage = get_storage(**self.storage_config) + self.journal_writer = self.storage.journal_writer def tearDown(self): self.objtmp.cleanup() diff --git a/swh/storage/tests/test_api_client.py 
b/swh/storage/tests/test_api_client.py --- a/swh/storage/tests/test_api_client.py +++ b/swh/storage/tests/test_api_client.py @@ -9,7 +9,11 @@ import unittest from swh.core.tests.server_testing import ServerTestFixture +import swh.storage.storage as storage +from swh.storage.journal_writer import \ + get_journal_writer, InMemoryJournalWriter from swh.storage.api.client import RemoteStorage +import swh.storage.api.server as server from swh.storage.api.server import app from swh.storage.tests.test_storage import \ CommonTestStorage, CommonPropTestStorage, StorageTestDbFixture @@ -25,6 +29,14 @@ """ def setUp(self): + def mock_get_journal_writer(cls, args=None): + assert cls == 'inmemory' + return journal_writer + server.storage = None + storage.get_journal_writer = mock_get_journal_writer + journal_writer = InMemoryJournalWriter() + self.journal_writer = journal_writer + # ServerTestFixture needs to have self.objroot for # setUp() method, but this field is defined in # AbstractTestStorage's setUp() @@ -43,6 +55,9 @@ 'slicing': '0:2', }, }, + 'journal_writer': { + 'cls': 'inmemory', + } } } } @@ -52,6 +67,7 @@ self.objroot = self.storage_base def tearDown(self): + storage.get_journal_writer = get_journal_writer super().tearDown() shutil.rmtree(self.storage_base) diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py --- a/swh/storage/tests/test_in_memory.py +++ b/swh/storage/tests/test_in_memory.py @@ -21,10 +21,11 @@ """ def setUp(self): super().setUp() - self.storage = Storage() + self.storage = Storage(journal_writer={'cls': 'inmemory'}) + self.journal_writer = self.storage.journal_writer @pytest.mark.skip('postgresql-specific test') - def test_content_add(self): + def test_content_add_db(self): pass @pytest.mark.skip('postgresql-specific test') diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -553,6 +553,18 @@ cont = 
self.cont self.storage.content_add([cont]) + self.assertEqual(list(self.storage.content_get([cont['sha1']])), + [{'sha1': cont['sha1'], 'data': cont['data']}]) + + expected_cont = cont.copy() + del expected_cont['data'] + self.assertEqual(list(self.journal_writer.objects), + [('content', expected_cont['sha1'], expected_cont)]) + + def test_content_add_db(self): + cont = self.cont + + self.storage.content_add([cont]) if hasattr(self.storage, 'objstorage'): self.assertIn(cont['sha1'], self.storage.objstorage) self.cursor.execute('SELECT sha1, sha1_git, sha256, length, status' @@ -565,6 +577,11 @@ (cont['sha1'], cont['sha1_git'], cont['sha256'], cont['length'], 'visible')) + expected_cont = cont.copy() + del expected_cont['data'] + self.assertEqual(list(self.journal_writer.objects), + [('content', expected_cont['sha1'], expected_cont)]) + def test_content_add_collision(self): cont1 = self.cont @@ -706,6 +723,8 @@ self.assertEqual([self.dir['id']], init_missing) self.storage.directory_add([self.dir]) + self.assertEqual(list(self.journal_writer.objects), + [('directory', self.dir['id'], self.dir)]) actual_data = list(self.storage.directory_ls(self.dir['id'])) expected_data = list(self._transform_entries(self.dir)) @@ -719,6 +738,10 @@ self.assertEqual([self.dir['id']], init_missing) self.storage.directory_add([self.dir, self.dir2, self.dir3]) + self.assertEqual(list(self.journal_writer.objects), + [('directory', self.dir['id'], self.dir), + ('directory', self.dir2['id'], self.dir2), + ('directory', self.dir3['id'], self.dir3)]) actual_data = list(self.storage.directory_ls( self.dir['id'], recursive=True)) @@ -807,6 +830,9 @@ end_missing = self.storage.revision_missing([self.revision['id']]) self.assertEqual([], list(end_missing)) + self.assertEqual(list(self.journal_writer.objects), + [('revision', self.revision['id'], self.revision)]) + def test_revision_log(self): # given # self.revision4 -is-child-of-> self.revision3 @@ -830,6 +856,10 @@ 
self.assertEqual(actual_results[1], self.normalize_entity(self.revision3)) + self.assertEqual(list(self.journal_writer.objects), + [('revision', self.revision3['id'], self.revision3), + ('revision', self.revision4['id'], self.revision4)]) + def test_revision_log_with_limit(self): # given # self.revision4 -is-child-of-> self.revision3 @@ -921,6 +951,10 @@ self.release2['id']]) self.assertEqual([], list(end_missing)) + self.assertEqual(list(self.journal_writer.objects), + [('release', self.release['id'], self.release), + ('release', self.release2['id'], self.release2)]) + def test_release_get(self): # given self.storage.release_add([self.release, self.release2]) @@ -975,6 +1009,10 @@ }])[0] self.assertEqual(actual_origin2['id'], origin2['id']) + self.assertEqual(list(self.journal_writer.objects), + [('origin', actual_origin['id'], actual_origin), + ('origin', actual_origin2['id'], actual_origin2)]) + def test_origin_add_twice(self): add1 = self.storage.origin_add([self.origin, self.origin2]) add2 = self.storage.origin_add([self.origin, self.origin2]) @@ -1107,6 +1145,18 @@ 'snapshot': None, }]) + expected_origin = self.origin2.copy() + expected_origin['id'] = origin_id + id_ = b'%d-%d' % (origin_id, origin_visit1['visit']) + data = { + 'origin': origin_id, + 'date': self.date_visit2, + 'visit': origin_visit1['visit'], + } + self.assertEqual(list(self.journal_writer.objects), + [('origin', origin_id, expected_origin), + ('origin_visit', id_, data)]) + def test_origin_visit_update(self): # given origin_id = self.storage.origin_add_one(self.origin2)