diff --git a/swh/storage/listener.py b/swh/storage/listener.py
--- a/swh/storage/listener.py
+++ b/swh/storage/listener.py
@@ -66,7 +66,8 @@
     with db.transaction() as cur:
         for object_type in OBJECT_TYPES:
             db.register_listener('new_%s' % object_type, cur)
-            logging.debug('Registered to notify events %s' % object_type)
+            logging.debug(
+                'Registered to events for object type %s', object_type)
 
 
 def dispatch_notify(topic_prefix, producer, notify):
@@ -83,6 +84,14 @@
     producer.send(topic, value=decode(object_type, notify.payload))
 
 
+def run_once(db, producer, topic_prefix, poll_timeout):
+    """Dispatch each notification from db.listen_notifies, then flush."""
+    for notify in db.listen_notifies(poll_timeout):
+        logging.debug('Notified by event %s', notify)
+        dispatch_notify(topic_prefix, producer, notify)
+    producer.flush()
+
+
 def run_from_config(config):
     """Run the Software Heritage listener from configuration"""
     db = swh.storage.db.Db.connect(config['database'])
@@ -108,10 +117,7 @@
     poll_timeout = config['poll_timeout']
     try:
         while True:
-            for notify in db.listen_notifies(poll_timeout):
-                logging.debug('Notified by event %s' % notify)
-                dispatch_notify(topic_prefix, producer, notify)
-            producer.flush()
+            run_once(db, producer, topic_prefix, poll_timeout)
     except Exception:
         logging.exception("Caught exception")
         producer.flush()
diff --git a/swh/storage/sql/70-swh-triggers.sql b/swh/storage/sql/70-swh-triggers.sql
--- a/swh/storage/sql/70-swh-triggers.sql
+++ b/swh/storage/sql/70-swh-triggers.sql
@@ -116,6 +116,27 @@
   execute procedure notify_new_origin_visit();
 
 
+-- Asynchronous notification of modified origin visits
+-- (sent on the new_origin_visit channel: the listener only listens on new_* channels)
+create function notify_changed_origin_visit()
+  returns trigger
+  language plpgsql
+as $$
+  begin
+    perform pg_notify('new_origin_visit', json_build_object(
+      'origin', new.origin,
+      'visit', new.visit
+    )::text);
+    return null;
+  end;
+$$;
+
+create trigger notify_changed_origin_visit
+  after update on origin_visit
+  for each row
+  execute procedure notify_changed_origin_visit();
+
+
 -- Asynchronous notification of new release insertions
 create function notify_new_release()
   returns trigger
diff --git a/swh/storage/tests/test_listener.py b/swh/storage/tests/test_listener.py
--- a/swh/storage/tests/test_listener.py
+++ b/swh/storage/tests/test_listener.py
@@ -4,9 +4,66 @@
 # See top-level LICENSE file for more information
 
 import json
+import os
 import unittest
 
-from swh.storage.listener import decode
+import pytest
+
+from swh.core.tests.db_testing import SingleDbTestFixture
+from swh.storage.tests.storage_testing import StorageTestFixture
+from swh.storage.tests.test_storage import TestStorageData
+import swh.storage.listener as listener
+from swh.storage.db import Db
+from . import SQL_DIR
+
+
+@pytest.mark.db
+class ListenerTest(StorageTestFixture, SingleDbTestFixture,
+                   TestStorageData, unittest.TestCase):
+    TEST_DB_NAME = 'softwareheritage-test-storage'
+    TEST_DB_DUMP = os.path.join(SQL_DIR, '*.sql')
+
+    def setUp(self):
+        super().setUp()
+        self.db = Db(self.conn)
+
+    def tearDown(self):
+        self.db.conn.close()
+        super().tearDown()
+
+    def test_notify(self):
+        class MockProducer:
+            def send(self, topic, value):
+                sent.append((topic, value))
+
+            def flush(self):
+                pass
+
+        listener.register_all_notifies(self.db)
+
+        # Add an origin and an origin visit
+        origin_id = self.storage.origin_add_one(self.origin)
+        visit = self.storage.origin_visit_add(origin_id, date=self.date_visit1)
+        visit_id = visit['visit']
+
+        sent = []
+        listener.run_once(self.db, MockProducer(), 'swh.tmp_journal.new', 10)
+        self.assertEqual(sent, [
+            ('swh.tmp_journal.new.origin',
+             {'type': 'git', 'url': 'file:///dev/null'}),
+            ('swh.tmp_journal.new.origin_visit',
+             {'origin': origin_id, 'visit': visit_id}),
+        ])
+
+        # Update the status of the origin visit
+        self.storage.origin_visit_update(origin_id, visit_id, status='full')
+
+        sent = []
+        listener.run_once(self.db, MockProducer(), 'swh.tmp_journal.new', 10)
+        self.assertEqual(sent, [
+            ('swh.tmp_journal.new.origin_visit',
+             {'origin': origin_id, 'visit': visit_id}),
+        ])
 
 
 class ListenerUtils(unittest.TestCase):
@@ -42,5 +99,5 @@
         }]
 
         for i, (object_type, obj) in enumerate(inputs):
-            actual_value = decode(object_type, obj)
+            actual_value = listener.decode(object_type, obj)
             self.assertEqual(actual_value, expected_inputs[i])