diff --git a/swh/storage/listener.py b/swh/storage/listener.py
--- a/swh/storage/listener.py
+++ b/swh/storage/listener.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2016-2017 The Software Heritage developers
+# Copyright (C) 2016-2018 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -23,18 +23,39 @@
 }
 
 
-def decode_sha(value):
-    """Decode the textual representation of a SHA hash"""
-    if isinstance(value, str):
-        return bytes.fromhex(value)
-    return value
+def decode_simple(value):
+    """Decode simple values (no hex identifiers in their midst)
 
+    Args:
+        value (str): json representation
 
-def decode_json(value):
-    """Decode a JSON value containing hashes and other types"""
-    value = json.loads(value)
+    Returns:
+        dict representation
 
-    return {k: decode_sha(v) for k, v in value.items()}
+    """
+    return json.loads(value)
+
+
+def decode_with_identifier(value):
+    """Decode a JSON value containing hashes and other types
+
+    Args:
+        value (str): json dict representation whose values might be hex
+            identifiers.
+
+    Returns:
+        dict representation whose identifiers are hexadecimal strings
+        (bytes values, if any, are converted with bytes.hex()).
+
+    """
+    value = decode_simple(value)
+
+    m = {}
+    for k, v in value.items():
+        if isinstance(v, bytes):
+            v = v.hex()
+        m[k] = v
+    return m
 
 
 OBJECT_TYPES = {
@@ -53,11 +74,15 @@
     """Register to notifications for all object types listed in OBJECT_TYPES"""
     with db.transaction() as cur:
         for object_type in OBJECT_TYPES:
+            logging.debug('Register to notify events %s', object_type)
             db.register_listener('new_%s' % object_type, cur)
 
 
 def dispatch_notify(topic_prefix, producer, notify):
     """Dispatch a notification to the proper topic"""
+    logging.debug('topic_prefix: %s', topic_prefix)
+    logging.debug('producer: %s', producer)
+    logging.debug('notify: %s', notify)
     channel = notify.channel
     if not channel.startswith('new_') or channel[4:] not in OBJECT_TYPES:
         logging.warn("Got unexpected notify %s" % notify)
@@ -66,7 +91,13 @@
     object_type = channel[4:]
     topic = '%s.%s' % (topic_prefix, object_type)
 
-    data = decode_json(notify.payload)
+    # mapping function depending on the object type
+    mapping_fn = {
+        'origin': decode_simple,
+        'origin_visit': decode_simple,
+    }
+    mapping_callable = mapping_fn.get(object_type, decode_with_identifier)
+    data = mapping_callable(notify.payload)
     producer.send(topic, value=data)
 
 
@@ -96,6 +127,7 @@
     try:
         while True:
             for notify in db.listen_notifies(poll_timeout):
+                logging.debug('Notified by event %s', notify)
                 dispatch_notify(topic_prefix, producer, notify)
             producer.flush()
     except Exception:
@@ -104,9 +136,17 @@
 
 
 if __name__ == '__main__':
-    logging.basicConfig(
-        level=logging.INFO,
-        format='%(asctime)s %(process)d %(levelname)s %(message)s'
-    )
-    config = load_named_config(CONFIG_BASENAME, DEFAULT_CONFIG)
-    run_from_config(config)
+    import click
+
+    @click.command()
+    @click.option('--verbose', is_flag=True, default=False,
+                  help='Be verbose if asked.')
+    def main(verbose):
+        logging.basicConfig(
+            level=logging.DEBUG if verbose else logging.INFO,
+            format='%(asctime)s %(process)d %(levelname)s %(message)s'
+        )
+        config = load_named_config(CONFIG_BASENAME, DEFAULT_CONFIG)
+        run_from_config(config)
+
+    main()
diff --git a/swh/storage/tests/test_listener.py b/swh/storage/tests/test_listener.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/tests/test_listener.py
@@ -0,0 +1,49 @@
+# Copyright (C) 2018 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import json
+import unittest
+
+from swh.storage.listener import (
+    decode_with_identifier, decode_simple
+)
+
+
+class ListenerUtils(unittest.TestCase):
+    def test_decode_simple(self):
+        payload = json.dumps({
+            'url': 'https://some/origin',
+            'type': 'svn',
+        })
+
+        actual_value = decode_simple(payload)
+
+        self.assertEqual(actual_value, {
+            'url': 'https://some/origin',
+            'type': 'svn',
+        })
+
+    def test_decode_with_identifier(self):
+        inputs = map(json.dumps, [{
+            'id': '34973274ccef6ab4dfaaf86599792fa9c3fe4689',
+            'url': 'https://some/origin',
+            'type': 'svn',
+        }, {
+            'url': 'https://some/origin',
+            'type': 'svn',
+        }])
+
+        expected_inputs = [{
+            'id': '34973274ccef6ab4dfaaf86599792fa9c3fe4689',
+            'url': 'https://some/origin',
+            'type': 'svn',
+        }, {
+            'url': 'https://some/origin',
+            'type': 'svn',
+        }]
+
+        for payload, expected in zip(inputs, expected_inputs):
+            actual_value = decode_with_identifier(payload)
+            self.assertEqual(actual_value, expected)