diff --git a/swh/storage/listener.py b/swh/storage/listener.py
index e86c8d0..8ceb4ae 100644
--- a/swh/storage/listener.py
+++ b/swh/storage/listener.py
@@ -1,112 +1,120 @@
-# Copyright (C) 2016-2017 The Software Heritage developers
+# Copyright (C) 2016-2018 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 import json
 import logging
 
 import kafka
 import msgpack
 
 import swh.storage.db
 
 from swh.core.config import load_named_config
 
 CONFIG_BASENAME = 'storage/listener'
 DEFAULT_CONFIG = {
     'database': ('str', 'service=softwareheritage'),
     'brokers': ('list[str]', ['getty.internal.softwareheritage.org']),
     'topic_prefix': ('str', 'swh.tmp_journal.new'),
     'poll_timeout': ('int', 10),
 }
 
 
-def decode_sha(value):
-    """Decode the textual representation of a SHA hash"""
-    if isinstance(value, str):
-        return bytes.fromhex(value)
-    return value
+def decode(object_type, obj):
+    """Decode a JSON obj of nature object_type. Depending on the nature of
+       the object, this can contain hex hashes
 
+    Args:
+        object_type (str): Nature of the object
+        obj (str): json dict representation whose values might be hex
+          identifier.
 
-def decode_json(value):
-    """Decode a JSON value containing hashes and other types"""
-    value = json.loads(value)
+    Returns:
+        dict representation ready for journal serialization
 
-    return {k: decode_sha(v) for k, v in value.items()}
+    """
+    value = json.loads(obj)
+
+    if object_type in ('origin', 'origin_visit'):
+        result = value
+    else:
+        result = {}
+        for k, v in value.items():
+            result[k] = bytes.fromhex(v)
+    return result
 
 
 OBJECT_TYPES = {
     'content',
     'skipped_content',
     'directory',
     'revision',
     'release',
     'snapshot',
     'origin_visit',
     'origin',
 }
 
 
 def register_all_notifies(db):
     """Register to notifications for all object types listed in OBJECT_TYPES"""
     with db.transaction() as cur:
         for object_type in OBJECT_TYPES:
             db.register_listener('new_%s' % object_type, cur)
 
 
 def dispatch_notify(topic_prefix, producer, notify):
     """Dispatch a notification to the proper topic"""
     channel = notify.channel
     if not channel.startswith('new_') or channel[4:] not in OBJECT_TYPES:
         logging.warn("Got unexpected notify %s" % notify)
         return
 
     object_type = channel[4:]
     topic = '%s.%s' % (topic_prefix, object_type)
-    data = decode_json(notify.payload)
-    producer.send(topic, value=data)
+    producer.send(topic, value=decode(object_type, notify.payload))
 
 
 def run_from_config(config):
     """Run the Software Heritage listener from configuration"""
     db = swh.storage.db.Db.connect(config['database'])
 
     def key_to_kafka(key):
         """Serialize a key, possibly a dict, in a predictable way.
 
         Duplicated from swh.journal to avoid a cyclic dependency."""
         p = msgpack.Packer(use_bin_type=True)
         if isinstance(key, dict):
             return p.pack_map_pairs(sorted(key.items()))
         else:
             return p.pack(key)
 
     producer = kafka.KafkaProducer(
         bootstrap_servers=config['brokers'],
         value_serializer=key_to_kafka,
     )
 
     register_all_notifies(db)
 
     topic_prefix = config['topic_prefix']
     poll_timeout = config['poll_timeout']
     try:
         while True:
             for notify in db.listen_notifies(poll_timeout):
                 dispatch_notify(topic_prefix, producer, notify)
             producer.flush()
     except Exception:
         logging.exception("Caught exception")
         producer.flush()
 
 
 if __name__ == '__main__':
     logging.basicConfig(
         level=logging.INFO,
         format='%(asctime)s %(process)d %(levelname)s %(message)s'
     )
     config = load_named_config(CONFIG_BASENAME, DEFAULT_CONFIG)
     run_from_config(config)
diff --git a/swh/storage/tests/test_listener.py b/swh/storage/tests/test_listener.py
new file mode 100644
index 0000000..4b32ea0
--- /dev/null
+++ b/swh/storage/tests/test_listener.py
@@ -0,0 +1,46 @@
+# Copyright (C) 2018 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import json
+import unittest
+
+from swh.storage.listener import decode
+
+
+class ListenerUtils(unittest.TestCase):
+    def test_decode(self):
+        inputs = [
+            ('content', json.dumps({
+                'sha1': '34973274ccef6ab4dfaaf86599792fa9c3fe4689',
+            })),
+            ('origin', json.dumps({
+                'url': 'https://some/origin',
+                'type': 'svn',
+            })),
+            ('origin_visit', json.dumps({
+                'visit': 2,
+                'origin': {
+                    'url': 'https://some/origin',
+                    'type': 'hg',
+                }
+            }))
+        ]
+
+        expected_inputs = [{
+            'sha1': bytes.fromhex('34973274ccef6ab4dfaaf86599792fa9c3fe4689'),
+        }, {
+            'url': 'https://some/origin',
+            'type': 'svn',
+        }, {
+            'visit': 2,
+            'origin': {
+                'url': 'https://some/origin',
+                'type': 'hg'
+            },
+        }]
+
+        for i, (object_type, obj) in enumerate(inputs):
+            actual_value = decode(object_type, obj)
+            self.assertEqual(actual_value, expected_inputs[i])
diff --git a/tox.ini b/tox.ini
index 385a77f..2f47441 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,25 +1,27 @@
 [tox]
 envlist=flake8,py3
 
 [testenv:py3]
 deps =
   .[testing]
+  .[listener]
   pytest-cov
   pifpaf
 commands =
   pifpaf run postgresql -- pytest --hypothesis-profile=fast --cov=swh --cov-branch {posargs}
 
 [testenv:py3-slow]
 deps =
   .[testing]
+  .[listener]
   pytest-cov
   pifpaf
 commands =
   pifpaf run postgresql -- pytest --hypothesis-profile=slow --cov=swh --cov-branch {posargs}
 
 [testenv:flake8]
 skip_install = true
 deps =
   flake8
 commands =
   {envpython} -m flake8
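
As a reading aid, here is a minimal sketch of how the reworked decode() is expected to flow through dispatch_notify(). FakeNotify and FakeProducer are stand-ins invented purely for illustration; in production the notification comes from psycopg2 via db.listen_notifies() and the producer is a kafka.KafkaProducer, as wired up in run_from_config() above. This assumes the package is installed with the [listener] extra so swh.storage.listener imports cleanly.

    import json
    from collections import namedtuple

    from swh.storage.listener import dispatch_notify

    # Illustration-only stand-ins: only the attributes/methods that
    # dispatch_notify actually uses are provided.
    FakeNotify = namedtuple('FakeNotify', ['channel', 'payload'])

    class FakeProducer:
        def __init__(self):
            self.sent = []

        def send(self, topic, value):
            self.sent.append((topic, value))

    producer = FakeProducer()
    notify = FakeNotify(
        channel='new_content',
        payload=json.dumps({
            'sha1': '34973274ccef6ab4dfaaf86599792fa9c3fe4689',
        }),
    )

    dispatch_notify('swh.tmp_journal.new', producer, notify)
    # The hex sha1 was decoded to bytes before reaching the producer, e.g.:
    # [('swh.tmp_journal.new.content', {'sha1': b'4\x97...'})]
    print(producer.sent)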