diff --git a/swh/storage/listener.py b/swh/storage/listener.py index 52b1ef49f..4d0f844e7 100644 --- a/swh/storage/listener.py +++ b/swh/storage/listener.py @@ -1,137 +1,142 @@ # Copyright (C) 2016-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import json import logging import kafka import msgpack import swh.storage.db from swh.core.config import load_named_config from swh.model import hashutil CONFIG_BASENAME = 'storage/listener' DEFAULT_CONFIG = { 'database': ('str', 'service=softwareheritage'), 'brokers': ('list[str]', ['getty.internal.softwareheritage.org']), 'topic_prefix': ('str', 'swh.tmp_journal.new'), 'poll_timeout': ('int', 10), } def decode(object_type, obj): """Decode a JSON obj of nature object_type. Depending on the nature of the object, this can contain hex hashes (cf. `/swh/storage/sql/70-swh-triggers.sql`). Args: object_type (str): Nature of the object obj (str): json dict representation whose values might be hex identifier. Returns: dict representation ready for journal serialization """ value = json.loads(obj) if object_type in ('origin', 'origin_visit'): result = value else: result = {} for k, v in value.items(): result[k] = hashutil.hash_to_bytes(v) return result OBJECT_TYPES = { 'content', 'skipped_content', 'directory', 'revision', 'release', 'snapshot', 'origin_visit', 'origin', } def register_all_notifies(db): """Register to notifications for all object types listed in OBJECT_TYPES""" with db.transaction() as cur: for object_type in OBJECT_TYPES: db.register_listener('new_%s' % object_type, cur) - logging.debug('Registered to notify events %s' % object_type) + logging.debug( + 'Registered to events for object type %s', object_type) def dispatch_notify(topic_prefix, producer, notify): """Dispatch a notification to the proper topic""" logging.debug('topic_prefix: %s, producer: %s, notify: %s' % ( topic_prefix, producer, notify)) channel = notify.channel if not channel.startswith('new_') or channel[4:] not in OBJECT_TYPES: logging.warn("Got unexpected notify %s" % notify) return object_type = channel[4:] topic = '%s.%s' % (topic_prefix, object_type) producer.send(topic, value=decode(object_type, notify.payload)) +def run_once(db, producer, topic_prefix, poll_timeout): + for notify in db.listen_notifies(poll_timeout): + logging.debug('Notified by event %s' % notify) + dispatch_notify(topic_prefix, producer, notify) + producer.flush() + + def run_from_config(config): """Run the Software Heritage listener from configuration""" db = swh.storage.db.Db.connect(config['database']) def key_to_kafka(key): """Serialize a key, possibly a dict, in a predictable way. Duplicated from swh.journal to avoid a cyclic dependency.""" p = msgpack.Packer(use_bin_type=True) if isinstance(key, dict): return p.pack_map_pairs(sorted(key.items())) else: return p.pack(key) producer = kafka.KafkaProducer( bootstrap_servers=config['brokers'], value_serializer=key_to_kafka, ) register_all_notifies(db) topic_prefix = config['topic_prefix'] poll_timeout = config['poll_timeout'] try: while True: - for notify in db.listen_notifies(poll_timeout): - logging.debug('Notified by event %s' % notify) - dispatch_notify(topic_prefix, producer, notify) - producer.flush() + run_once(db, producer, topic_prefix, poll_timeout) except Exception: logging.exception("Caught exception") producer.flush() if __name__ == '__main__': import click @click.command() @click.option('--verbose', is_flag=True, default=False, help='Be verbose if asked.') def main(verbose): logging.basicConfig( level=logging.DEBUG if verbose else logging.INFO, format='%(asctime)s %(process)d %(levelname)s %(message)s' ) _log = logging.getLogger('kafka') _log.setLevel(logging.INFO) config = load_named_config(CONFIG_BASENAME, DEFAULT_CONFIG) run_from_config(config) main() diff --git a/swh/storage/sql/70-swh-triggers.sql b/swh/storage/sql/70-swh-triggers.sql index e8f955b53..d087bf1c1 100644 --- a/swh/storage/sql/70-swh-triggers.sql +++ b/swh/storage/sql/70-swh-triggers.sql @@ -1,150 +1,170 @@ -- Asynchronous notification of new content insertions create function notify_new_content() returns trigger language plpgsql as $$ begin perform pg_notify('new_content', json_build_object( 'sha1', encode(new.sha1, 'hex'), 'sha1_git', encode(new.sha1_git, 'hex'), 'sha256', encode(new.sha256, 'hex'), 'blake2s256', encode(new.blake2s256, 'hex') )::text); return null; end; $$; create trigger notify_new_content after insert on content for each row execute procedure notify_new_content(); -- Asynchronous notification of new origin insertions create function notify_new_origin() returns trigger language plpgsql as $$ begin perform pg_notify('new_origin', json_build_object( 'url', new.url::text, 'type', new.type::text )::text); return null; end; $$; create trigger notify_new_origin after insert on origin for each row execute procedure notify_new_origin(); -- Asynchronous notification of new skipped content insertions create function notify_new_skipped_content() returns trigger language plpgsql as $$ begin perform pg_notify('new_skipped_content', json_build_object( 'sha1', encode(new.sha1, 'hex'), 'sha1_git', encode(new.sha1_git, 'hex'), 'sha256', encode(new.sha256, 'hex'), 'blake2s256', encode(new.blake2s256, 'hex') )::text); return null; end; $$; create trigger notify_new_skipped_content after insert on skipped_content for each row execute procedure notify_new_skipped_content(); -- Asynchronous notification of new directory insertions create function notify_new_directory() returns trigger language plpgsql as $$ begin perform pg_notify('new_directory', json_build_object('id', encode(new.id, 'hex'))::text); return null; end; $$; create trigger notify_new_directory after insert on directory for each row execute procedure notify_new_directory(); -- Asynchronous notification of new revision insertions create function notify_new_revision() returns trigger language plpgsql as $$ begin perform pg_notify('new_revision', json_build_object('id', encode(new.id, 'hex'))::text); return null; end; $$; create trigger notify_new_revision after insert on revision for each row execute procedure notify_new_revision(); -- Asynchronous notification of new origin visits create function notify_new_origin_visit() returns trigger language plpgsql as $$ begin perform pg_notify('new_origin_visit', json_build_object( 'origin', new.origin, 'visit', new.visit )::text); return null; end; $$; create trigger notify_new_origin_visit after insert on origin_visit for each row execute procedure notify_new_origin_visit(); +-- Asynchronous notification of modified origin visits +create function notify_changed_origin_visit() + returns trigger + language plpgsql +as $$ + begin + perform pg_notify('new_origin_visit', json_build_object( + 'origin', new.origin, + 'visit', new.visit + )::text); + return null; + end; +$$; + +create trigger notify_changed_origin_visit + after update on origin_visit + for each row + execute procedure notify_changed_origin_visit(); + + -- Asynchronous notification of new release insertions create function notify_new_release() returns trigger language plpgsql as $$ begin perform pg_notify('new_release', json_build_object('id', encode(new.id, 'hex'))::text); return null; end; $$; create trigger notify_new_release after insert on release for each row execute procedure notify_new_release(); -- Asynchronous notification of new snapshot insertions create function notify_new_snapshot() returns trigger language plpgsql as $$ begin perform pg_notify('new_snapshot', json_build_object('id', encode(new.id, 'hex'))::text); return null; end; $$; create trigger notify_new_snapshot after insert on snapshot for each row execute procedure notify_new_snapshot(); diff --git a/swh/storage/tests/test_listener.py b/swh/storage/tests/test_listener.py index 4b32ea047..8496a0157 100644 --- a/swh/storage/tests/test_listener.py +++ b/swh/storage/tests/test_listener.py @@ -1,46 +1,104 @@ # Copyright (C) 2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import json +import os import unittest +import unittest.mock -from swh.storage.listener import decode +import pytest + +from swh.core.tests.db_testing import SingleDbTestFixture +from swh.storage.tests.storage_testing import StorageTestFixture +from swh.storage.tests.test_storage import TestStorageData +import swh.storage.listener as listener +from swh.storage.db import Db +from . import SQL_DIR + + +@pytest.mark.db +class ListenerTest(StorageTestFixture, SingleDbTestFixture, + TestStorageData, unittest.TestCase): + TEST_DB_NAME = 'softwareheritage-test-storage' + TEST_DB_DUMP = os.path.join(SQL_DIR, '*.sql') + + def setUp(self): + super().setUp() + self.db = Db(self.conn) + + def tearDown(self): + self.db.conn.close() + super().tearDown() + + def test_notify(self): + class MockProducer: + def send(self, topic, value): + sent.append((topic, value)) + + def flush(self): + pass + + listener.register_all_notifies(self.db) + + # Add an origin and an origin visit + origin_id = self.storage.origin_add_one(self.origin) + visit = self.storage.origin_visit_add(origin_id, date=self.date_visit1) + visit_id = visit['visit'] + + sent = [] + listener.run_once(self.db, MockProducer(), 'swh.tmp_journal.new', 10) + self.assertEqual(sent, [ + ('swh.tmp_journal.new.origin', + {'type': 'git', 'url': 'file:///dev/null'}), + ('swh.tmp_journal.new.origin_visit', + {'origin': 1, 'visit': 1}), + ]) + + # Update the status of the origin visit + self.storage.origin_visit_update(origin_id, visit_id, status='full') + + sent = [] + listener.run_once(self.db, MockProducer(), 'swh.tmp_journal.new', 10) + self.assertEqual(sent, [ + ('swh.tmp_journal.new.origin_visit', + {'origin': 1, 'visit': 1}), + ]) class ListenerUtils(unittest.TestCase): def test_decode(self): inputs = [ ('content', json.dumps({ 'sha1': '34973274ccef6ab4dfaaf86599792fa9c3fe4689', })), ('origin', json.dumps({ 'url': 'https://some/origin', 'type': 'svn', })), ('origin_visit', json.dumps({ 'visit': 2, 'origin': { 'url': 'https://some/origin', 'type': 'hg', } })) ] expected_inputs = [{ 'sha1': bytes.fromhex('34973274ccef6ab4dfaaf86599792fa9c3fe4689'), }, { 'url': 'https://some/origin', 'type': 'svn', }, { 'visit': 2, 'origin': { 'url': 'https://some/origin', 'type': 'hg' }, }] for i, (object_type, obj) in enumerate(inputs): - actual_value = decode(object_type, obj) + actual_value = listener.decode(object_type, obj) self.assertEqual(actual_value, expected_inputs[i])