diff --git a/swh/storage/listener.py b/swh/storage/listener.py
--- a/swh/storage/listener.py
+++ b/swh/storage/listener.py
@@ -19,7 +19,7 @@
 DEFAULT_CONFIG = {
     'database': ('str', 'service=softwareheritage'),
     'brokers': ('list[str]', ['getty.internal.softwareheritage.org']),
-    'topic_prefix': ('str', 'swh.tmp_journal.new'),
+    'topic_prefix': ('str', 'swh.tmp_journal'),
     'poll_timeout': ('int', 10),
 }
 
@@ -60,13 +60,25 @@
     'origin',
 }
 
+MUTABLE_OBJECT_TYPES = {
+    'origin_visit',
+}
+
 
 def register_all_notifies(db):
-    """Register to notifications for all object types listed in OBJECT_TYPES"""
+    """Register to notifications for all object types listed in OBJECT_TYPES
+    (``new_<type>`` channels) and MUTABLE_OBJECT_TYPES (``changed_<type>``
+    channels)"""
     with db.transaction() as cur:
         for object_type in OBJECT_TYPES:
             db.register_listener('new_%s' % object_type, cur)
-            logging.debug('Registered to notify events %s' % object_type)
+            logging.debug(
+                'Registered to notify events on new %s', object_type)
+
+        for object_type in MUTABLE_OBJECT_TYPES:
+            db.register_listener('changed_%s' % object_type, cur)
+            logging.debug(
+                'Registered to notify events on changed %s', object_type)
 
 
 def dispatch_notify(topic_prefix, producer, notify):
@@ -74,15 +86,30 @@
     logging.debug('topic_prefix: %s, producer: %s, notify: %s' % (
         topic_prefix, producer, notify))
     channel = notify.channel
-    if not channel.startswith('new_') or channel[4:] not in OBJECT_TYPES:
-        logging.warn("Got unexpected notify %s" % notify)
+    # Channels are named '<notification_type>_<object_type>' (see
+    # register_all_notifies), e.g. 'new_origin' or 'changed_origin_visit'.
+    notification_type, _, object_type = channel.partition('_')
+    expected_types = {
+        'new': OBJECT_TYPES,
+        'changed': MUTABLE_OBJECT_TYPES,
+    }.get(notification_type, ())
+    if object_type not in expected_types:
+        logging.warning("Got unexpected notify %s", notify)
         return
 
-    object_type = channel[4:]
-    topic = '%s.%s' % (topic_prefix, object_type)
+    topic = '%s.%s.%s' % (topic_prefix, notification_type, object_type)
     producer.send(topic, value=decode(object_type, notify.payload))
 
 
+def run_once(db, producer, topic_prefix, poll_timeout):
+    """Send every notify yielded by ``db.listen_notifies(poll_timeout)``
+    to its topic through `producer`, then flush `producer`"""
+    for notify in db.listen_notifies(poll_timeout):
+        logging.debug('Notified by event %s', notify)
+        dispatch_notify(topic_prefix, producer, notify)
+    producer.flush()
+
+
 def run_from_config(config):
     """Run the Software Heritage listener from configuration"""
     db = swh.storage.db.Db.connect(config['database'])
@@ -108,10 +135,7 @@
     poll_timeout = config['poll_timeout']
     try:
         while True:
-            for notify in db.listen_notifies(poll_timeout):
-                logging.debug('Notified by event %s' % notify)
-                dispatch_notify(topic_prefix, producer, notify)
-            producer.flush()
+            run_once(db, producer, topic_prefix, poll_timeout)
     except Exception:
         logging.exception("Caught exception")
         producer.flush()
diff --git a/swh/storage/sql/70-swh-triggers.sql b/swh/storage/sql/70-swh-triggers.sql
--- a/swh/storage/sql/70-swh-triggers.sql
+++ b/swh/storage/sql/70-swh-triggers.sql
@@ -116,6 +116,26 @@
   execute procedure notify_new_origin_visit();
 
 
+-- Asynchronous notification of modified origin visits
+create function notify_changed_origin_visit()
+  returns trigger
+  language plpgsql
+as $$
+  begin
+    perform pg_notify('changed_origin_visit', json_build_object(
+      'origin', new.origin,
+      'visit', new.visit
+    )::text);
+    return null;
+  end;
+$$;
+
+create trigger notify_changed_origin_visit
+  after update on origin_visit
+  for each row
+  execute procedure notify_changed_origin_visit();
+
+
 -- Asynchronous notification of new release insertions
 create function notify_new_release()
   returns trigger
diff --git a/swh/storage/tests/test_listener.py b/swh/storage/tests/test_listener.py
--- a/swh/storage/tests/test_listener.py
+++ b/swh/storage/tests/test_listener.py
@@ -4,9 +4,67 @@
 # See top-level LICENSE file for more information
 
 import json
+import os
 import unittest
+import unittest.mock
 
-from swh.storage.listener import decode
+import pytest
+
+from swh.core.tests.db_testing import SingleDbTestFixture
+from swh.storage.tests.storage_testing import StorageTestFixture
+from swh.storage.tests.test_storage import TestStorageData
+import swh.storage.listener as listener
+from swh.storage.db import Db
+from . import SQL_DIR
+
+
+@pytest.mark.db
+class ListenerTest(StorageTestFixture, SingleDbTestFixture,
+                   TestStorageData, unittest.TestCase):
+    TEST_DB_NAME = 'softwareheritage-test-storage'
+    TEST_DB_DUMP = os.path.join(SQL_DIR, '*.sql')
+
+    def setUp(self):
+        super().setUp()
+        self.db = Db(self.conn)
+
+    def tearDown(self):
+        self.db.conn.close()
+        super().tearDown()
+
+    def test_notify(self):
+        class MockProducer:
+            def send(self, topic, value):
+                sent.append((topic, value))
+
+            def flush(self):
+                pass
+
+        listener.register_all_notifies(self.db)
+
+        # Add an origin and an origin visit
+        origin_id = self.storage.origin_add_one(self.origin)
+        visit = self.storage.origin_visit_add(origin_id, date=self.date_visit1)
+        visit_id = visit['visit']
+
+        sent = []
+        listener.run_once(self.db, MockProducer(), 'swh.tmp_journal', 10)
+        self.assertEqual(sent, [
+            ('swh.tmp_journal.new.origin',
+             {'type': 'git', 'url': 'file:///dev/null'}),
+            ('swh.tmp_journal.new.origin_visit',
+             {'origin': origin_id, 'visit': visit_id}),
+        ])
+
+        # Update the status of the origin visit
+        self.storage.origin_visit_update(origin_id, visit_id, status='full')
+
+        sent = []
+        listener.run_once(self.db, MockProducer(), 'swh.tmp_journal', 10)
+        self.assertEqual(sent, [
+            ('swh.tmp_journal.changed.origin_visit',
+             {'origin': origin_id, 'visit': visit_id}),
+        ])
 
 
 class ListenerUtils(unittest.TestCase):
@@ -42,5 +100,5 @@
         }]
 
         for i, (object_type, obj) in enumerate(inputs):
-            actual_value = decode(object_type, obj)
+            actual_value = listener.decode(object_type, obj)
             self.assertEqual(actual_value, expected_inputs[i])