swh/provenance/api/client.py
# Copyright (C) 2021  The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import functools | |||||
import inspect | |||||
import logging | |||||
import queue | |||||
import threading | |||||
import time | |||||
from typing import Any, Dict, Optional | |||||
import uuid | |||||
import pika | |||||
import pika.channel | |||||
import pika.connection | |||||
import pika.frame | |||||
import pika.spec | |||||
from swh.core.api.serializers import encode_data_client as encode_data | |||||
from swh.core.api.serializers import msgpack_loads as decode_data | |||||
from swh.core.statsd import statsd | |||||
from .. import get_provenance_storage | |||||
from ..interface import ProvenanceStorageInterface | |||||
from .serializers import DECODERS, ENCODERS | |||||
from .server import ProvenanceStorageRabbitMQServer | |||||
# Column-aligned log line layout, available for consumers that configure
# logging for this module themselves.
LOG_FORMAT = (
    "%(levelname) -10s %(asctime)s %(name) -30s %(funcName) "
    "-35s %(lineno) -5d: %(message)s"
)
# Module-level logger (named after this module, per logging convention).
LOGGER = logging.getLogger(__name__)

# statsd timing metric under which every storage method call is reported;
# the method name is attached as a tag.
STORAGE_DURATION_METRIC = "swh_provenance_storage_rabbitmq_duration_seconds"
class ResponseTimeout(Exception):
    """Raised when waiting for a server response exceeds the allowed timeout."""
class TerminateSignal(Exception):
    """Raised on the IO loop thread to request a clean client shutdown."""
class MetaRabbitMQClient(type):
    """Metaclass that auto-generates one client method per endpoint declared
    on the ``backend_class`` attribute of the class being created.

    Read endpoints are forwarded synchronously to the local ``self._storage``
    object; write endpoints are split into per-item messages published to
    RabbitMQ on the connection's IO loop thread.
    """

    def __new__(cls, name, bases, attributes):
        # For each method wrapped with @remote_api_endpoint in an API backend
        # (eg. :class:`swh.indexer.storage.IndexerStorage`), add a new
        # method in RemoteStorage, with the same documentation.
        #
        # Note that, despite the usage of decorator magic (eg. functools.wrap),
        # this never actually calls an IndexerStorage method.
        backend_class = attributes.get("backend_class", None)
        # If the class being created does not define backend_class itself,
        # fall back to the first base class that provides one.
        for base in bases:
            if backend_class is not None:
                break
            backend_class = getattr(base, "backend_class", None)
        if backend_class:
            for meth_name, meth in backend_class.__dict__.items():
                # Only methods decorated with @remote_api_endpoint carry
                # the _endpoint_path marker attribute.
                if hasattr(meth, "_endpoint_path"):
                    cls.__add_endpoint(meth_name, meth, attributes)
        return super().__new__(cls, name, bases, attributes)

    @staticmethod
    def __add_endpoint(meth_name, meth, attributes):
        # Unwrap decorator layers so getcallargs sees the real signature.
        wrapped_meth = inspect.unwrap(meth)

        @functools.wraps(meth)  # Copy signature and doc
        def meth_(*args, **kwargs):
            # Read endpoint: time the call and delegate directly to the
            # local storage backend.
            with statsd.timed(
                metric=STORAGE_DURATION_METRIC, tags={"method": meth_name}
            ):
                # Match arguments and parameters
                data = inspect.getcallargs(wrapped_meth, *args, **kwargs)
                # Remove arguments that should not be passed
                self = data.pop("self")
                # Call storage method with remaining arguments
                return getattr(self._storage, meth_name)(**data)

        @functools.wraps(meth)  # Copy signature and doc
        def write_meth_(*args, **kwargs):
            # Write endpoint: fan the payload out as one RabbitMQ message
            # per item, then block until all acks are received.
            with statsd.timed(
                metric=STORAGE_DURATION_METRIC, tags={"method": meth_name}
            ):
                # Match arguments and parameters
                post_data = inspect.getcallargs(wrapped_meth, *args, **kwargs)
                try:
                    # Remove arguments that should not be passed
                    self = post_data.pop("self")
                    relation = post_data.pop("relation", None)
                    # After popping self/relation, exactly one payload
                    # argument is expected to remain.
                    assert len(post_data) == 1
                    if relation is not None:
                        # Relation payload: a dict mapping each source id to
                        # its destination entries; flatten to
                        # (src, dst, path) triples.
                        items = [
                            (src, rel.dst, rel.path)
                            for src, dsts in next(iter(post_data.values())).items()
                            for rel in dsts
                        ]
                    else:
                        data = next(iter(post_data.values()))
                        # Dicts become (key, value) pairs; other iterables
                        # are wrapped in 1-tuples and deduplicated via a set.
                        items = (
                            list(data.items())
                            if isinstance(data, dict)
                            else list({(item,) for item in data})
                        )
                    acks_expected = len(items)
                    # Fresh correlation id so responses to this batch can be
                    # told apart from stale ones in the response queue.
                    self._correlation_id = str(uuid.uuid4())
                    exchange = ProvenanceStorageRabbitMQServer.get_exchange(
                        meth_name, relation
                    )
                    for item in items:
                        routing_key = ProvenanceStorageRabbitMQServer.get_routing_key(
                            item, meth_name, relation
                        )
                        # FIXME: this is running in a different thread! Hence, if
                        # self._connection drops, there is no guarantee that the
                        # request can be sent for the current elements. This
                        # situation should be handled properly.
                        self._connection.ioloop.add_callback_threadsafe(
                            functools.partial(
                                ProvenanceStorageRabbitMQClient.request,
                                channel=self._channel,
                                reply_to=self._callback_queue,
                                exchange=exchange,
                                routing_key=routing_key,
                                correlation_id=self._correlation_id,
                                data=item,
                            )
                        )
                    return self.wait_for_acks(acks_expected)
                except BaseException as ex:
                    # Best-effort shutdown on any failure; the write is
                    # reported as unsuccessful rather than raising.
                    self.request_termination(str(ex))
                    return False

        # Register the generated method unless the class already defines
        # one explicitly; write endpoints get the publishing wrapper.
        if meth_name not in attributes:
            attributes[meth_name] = (
                write_meth_
                if ProvenanceStorageRabbitMQServer.is_write_method(meth_name)
                else meth_
            )
class ProvenanceStorageRabbitMQClient(threading.Thread, metaclass=MetaRabbitMQClient):
    """RabbitMQ-based client for a remote provenance storage.

    Runs a pika ``SelectConnection`` IO loop on its own thread and handles
    unexpected interactions with RabbitMQ such as channel and connection
    closures: if RabbitMQ closes the connection, it will be reopened.  There
    are limited reasons why the connection may be closed, usually tied to
    permission related issues or socket timeouts.

    Write endpoints (generated by the metaclass) publish one message per item
    and track the acknowledgements sent back by the server.
    """

    # Interface whose @remote_api_endpoint methods the metaclass mirrors.
    backend_class = ProvenanceStorageInterface
    # msgpack (de)serialization hooks for swh-specific types.
    extra_type_decoders = DECODERS
    extra_type_encoders = ENCODERS

    def __init__(self, url: str, storage_config: Dict[str, Any]) -> None:
        """Set up the client, passing in the URL used to connect to RabbitMQ.

        :param str url: The URL for connecting to RabbitMQ
        :param dict storage_config: Configuration parameters for the
            underlying ``ProvenanceStorage`` object
        """
        super().__init__()
        self._connection = None
        # Name of the exclusive reply queue; set asynchronously once the
        # Queue.Declare RPC completes (see on_queue_declare_ok).
        self._callback_queue: Optional[str] = None
        self._channel = None
        self._closing = False
        self._consumer_tag = None
        self._consuming = False
        # Correlation id of the current write batch; replaced per batch.
        self._correlation_id = str(uuid.uuid4())
        self._prefetch_count = 100
        # Thread-safe queue carrying (correlation_id, payload) pairs from
        # the IO loop thread (on_response) to the caller thread.
        self._response_queue: queue.Queue = queue.Queue()
        self._storage = get_provenance_storage(**storage_config)
        self._url = url

    @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "open"})
    def open(self) -> None:
        # Start the IO loop thread, then busy-wait until the reply queue has
        # been declared (signalled by _callback_queue becoming non-None).
        self.start()
        while self._callback_queue is None:
            time.sleep(0.1)
        self._storage.open()

    @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "close"})
    def close(self) -> None:
        # Schedule a clean shutdown on the IO loop thread, wait for the
        # thread to finish, then close the underlying storage.
        assert self._connection is not None
        self._connection.ioloop.add_callback_threadsafe(self.request_termination)
        self.join()
        self._storage.close()

    def request_termination(self, reason: str = "Normal shutdown") -> None:
        """Ask the IO loop thread to shut down by raising ``TerminateSignal``
        from within the loop (caught in :meth:`run`)."""
        assert self._connection is not None

        def termination_callback():
            raise TerminateSignal(reason)

        self._connection.ioloop.add_callback_threadsafe(termination_callback)

    def connect(self) -> pika.SelectConnection:
        """This method connects to RabbitMQ, returning the connection handle.
        When the connection is established, the on_connection_open method
        will be invoked by pika.

        :rtype: pika.SelectConnection
        """
        LOGGER.info("Connecting to %s", self._url)
        return pika.SelectConnection(
            parameters=pika.URLParameters(self._url),
            on_open_callback=self.on_connection_open,
            on_open_error_callback=self.on_connection_open_error,
            on_close_callback=self.on_connection_closed,
        )

    def close_connection(self) -> None:
        # Stop consuming and close the connection unless it is already on
        # its way down.
        assert self._connection is not None
        self._consuming = False
        if self._connection.is_closing or self._connection.is_closed:
            LOGGER.info("Connection is closing or already closed")
        else:
            LOGGER.info("Closing connection")
            self._connection.close()

    def on_connection_open(self, _unused_connection: pika.SelectConnection) -> None:
        """This method is called by pika once the connection to RabbitMQ has
        been established. It passes the handle to the connection object in
        case we need it, but in this case, we'll just mark it unused.

        :param pika.SelectConnection _unused_connection: The connection
        """
        LOGGER.info("Connection opened")
        self.open_channel()

    def on_connection_open_error(
        self, _unused_connection: pika.SelectConnection, err: Exception
    ) -> None:
        """This method is called by pika if the connection to RabbitMQ
        can't be established.

        :param pika.SelectConnection _unused_connection: The connection
        :param Exception err: The error
        """
        LOGGER.error("Connection open failed, reopening in 5 seconds: %s", err)
        assert self._connection is not None
        # Stopping the IO loop lets run()'s outer while-loop reconnect.
        self._connection.ioloop.call_later(5, self._connection.ioloop.stop)

    def on_connection_closed(self, _unused_connection: pika.SelectConnection, reason):
        """This method is invoked by pika when the connection to RabbitMQ is
        closed unexpectedly. Since it is unexpected, we will reconnect to
        RabbitMQ if it disconnects.

        :param pika.connection.Connection _unused_connection: The closed
            connection obj
        :param Exception reason: exception representing reason for loss of
            connection.
        """
        assert self._connection is not None
        self._channel = None
        if self._closing:
            # Deliberate shutdown: just stop the IO loop.
            self._connection.ioloop.stop()
        else:
            LOGGER.warning("Connection closed, reopening in 5 seconds: %s", reason)
            self._connection.ioloop.call_later(5, self._connection.ioloop.stop)

    def open_channel(self) -> None:
        """Open a new channel with RabbitMQ by issuing the Channel.Open RPC
        command. When RabbitMQ responds that the channel is open, the
        on_channel_open callback will be invoked by pika.
        """
        LOGGER.info("Creating a new channel")
        assert self._connection is not None
        self._connection.channel(on_open_callback=self.on_channel_open)

    def on_channel_open(self, channel: pika.channel.Channel) -> None:
        """This method is invoked by pika when the channel has been opened.
        The channel object is passed in so we can make use of it.

        Since the channel is now open, we'll declare the reply queue to use.

        :param pika.channel.Channel channel: The channel object
        """
        LOGGER.info("Channel opened")
        self._channel = channel
        LOGGER.info("Adding channel close callback")
        assert self._channel is not None
        self._channel.add_on_close_callback(callback=self.on_channel_closed)
        self.setup_queue()

    def on_channel_closed(
        self, channel: pika.channel.Channel, reason: Exception
    ) -> None:
        """Invoked by pika when RabbitMQ unexpectedly closes the channel.
        Channels are usually closed if you attempt to do something that
        violates the protocol, such as re-declare an exchange or queue with
        different parameters. In this case, we'll close the connection
        to shutdown the object.

        :param pika.channel.Channel channel: The closed channel
        :param Exception reason: why the channel was closed
        """
        LOGGER.warning("Channel %i was closed: %s", channel, reason)
        self.close_connection()

    def setup_queue(self) -> None:
        """Setup the queue on RabbitMQ by invoking the Queue.Declare RPC
        command. When it is complete, the on_queue_declare_ok method will
        be invoked by pika.
        """
        LOGGER.info("Declaring callback queue")
        assert self._channel is not None
        # queue="" asks the broker for an auto-named exclusive reply queue.
        self._channel.queue_declare(
            queue="", exclusive=True, callback=self.on_queue_declare_ok
        )

    def on_queue_declare_ok(self, frame: pika.frame.Method) -> None:
        """Method invoked by pika when the Queue.Declare RPC call made in
        setup_queue has completed. This method records the broker-assigned
        queue name and sets up the consumer prefetch window. The consumer
        must acknowledge messages before RabbitMQ will deliver more than
        the prefetch count. You should experiment with different prefetch
        values to achieve desired performance.

        :param pika.frame.Method frame: The Queue.DeclareOk frame
        """
        LOGGER.info("Binding queue to default exchanger")
        assert self._channel is not None
        # Unblocks open(), which polls for this attribute.
        self._callback_queue = frame.method.queue
        self._channel.basic_qos(
            prefetch_count=self._prefetch_count, callback=self.on_basic_qos_ok
        )

    def on_basic_qos_ok(self, _unused_frame: pika.frame.Method) -> None:
        """Invoked by pika when the Basic.QoS method has completed. At this
        point we will start consuming messages by calling start_consuming
        which will invoke the needed RPC commands to start the process.

        :param pika.frame.Method _unused_frame: The Basic.QosOk response frame
        """
        LOGGER.info("QOS set to: %d", self._prefetch_count)
        self.start_consuming()

    def start_consuming(self) -> None:
        """This method sets up the consumer by first calling
        add_on_cancel_callback so that the object is notified if RabbitMQ
        cancels the consumer. It then issues the Basic.Consume RPC command
        which returns the consumer tag that is used to uniquely identify the
        consumer with RabbitMQ. We keep the value to use it when we want to
        cancel consuming. The on_response method is passed in as a callback pika
        will invoke when a message is fully received.
        """
        LOGGER.info("Issuing consumer related RPC commands")
        LOGGER.info("Adding consumer cancellation callback")
        assert self._channel is not None
        self._channel.add_on_cancel_callback(callback=self.on_consumer_cancelled)
        assert self._callback_queue is not None
        self._consumer_tag = self._channel.basic_consume(
            queue=self._callback_queue, on_message_callback=self.on_response
        )
        self._consuming = True

    def on_consumer_cancelled(self, method_frame: pika.frame.Method) -> None:
        """Invoked by pika when RabbitMQ sends a Basic.Cancel for a consumer
        receiving messages.

        :param pika.frame.Method method_frame: The Basic.Cancel frame
        """
        LOGGER.info("Consumer was cancelled remotely, shutting down: %r", method_frame)
        if self._channel:
            self._channel.close()

    def on_response(
        self,
        channel: pika.channel.Channel,
        deliver: pika.spec.Basic.Deliver,
        properties: pika.spec.BasicProperties,
        body: bytes,
    ) -> None:
        """Invoked by pika when a message is delivered from RabbitMQ. The
        channel is passed for your convenience. The deliver object that
        is passed in carries the exchange, routing key, delivery tag and
        a redelivered flag for the message. The properties passed in is an
        instance of BasicProperties with the message properties and the body
        is the message that was sent.

        The decoded payload is handed to the caller thread through
        ``self._response_queue`` together with its correlation id.

        :param pika.channel.Channel channel: The channel object
        :param pika.spec.Basic.Deliver deliver: deliver method
        :param pika.spec.BasicProperties properties: properties
        :param bytes body: The message body
        """
        LOGGER.info(
            "Received message # %s from %s: %s",
            deliver.delivery_tag,
            properties.app_id,
            body,
        )
        self._response_queue.put(
            (
                properties.correlation_id,
                decode_data(body, extra_decoders=self.extra_type_decoders),
            )
        )
        LOGGER.info("Acknowledging message %s", deliver.delivery_tag)
        channel.basic_ack(delivery_tag=deliver.delivery_tag)

    def stop_consuming(self) -> None:
        """Tell RabbitMQ that you would like to stop consuming by sending the
        Basic.Cancel RPC command.
        """
        if self._channel:
            LOGGER.info("Sending a Basic.Cancel RPC command to RabbitMQ")
            self._channel.basic_cancel(self._consumer_tag, self.on_cancel_ok)

    def on_cancel_ok(self, _unused_frame: pika.frame.Method) -> None:
        """This method is invoked by pika when RabbitMQ acknowledges the
        cancellation of a consumer. At this point we will close the channel.
        This will invoke the on_channel_closed method once the channel has been
        closed, which will in-turn close the connection.

        :param pika.frame.Method _unused_frame: The Basic.CancelOk frame
        """
        self._consuming = False
        LOGGER.info(
            "RabbitMQ acknowledged the cancellation of the consumer: %s",
            self._consumer_tag,
        )
        LOGGER.info("Closing the channel")
        assert self._channel is not None
        self._channel.close()

    def run(self) -> None:
        """Thread entry point: connect and run the IOLoop, reconnecting
        until a clean shutdown is requested."""
        while not self._closing:
            try:
                self._connection = self.connect()
                assert self._connection is not None
                self._connection.ioloop.start()
            except KeyboardInterrupt:
                LOGGER.info("Connection closed by keyboard interruption, reopening")
                if self._connection is not None:
                    self._connection.ioloop.stop()
            except TerminateSignal as ex:
                # Raised from request_termination's callback inside the loop.
                LOGGER.info("Termination requested: %s", ex)
                self.stop()
                if self._connection is not None and not self._connection.is_closed:
                    # Finish closing
                    self._connection.ioloop.start()
            except BaseException as ex:
                LOGGER.warning("Unexpected exception, terminating: %s", ex)
                self.stop()
                if self._connection is not None and not self._connection.is_closed:
                    # Finish closing
                    self._connection.ioloop.start()
        LOGGER.info("Stopped")

    def stop(self) -> None:
        """Cleanly shutdown the connection to RabbitMQ by stopping the consumer
        with RabbitMQ. When RabbitMQ confirms the cancellation, on_cancel_ok
        will be invoked by pika, which will then closing the channel and
        connection. The IOLoop is started again because this method is invoked
        by raising a TerminateSignal exception. This exception stops the IOLoop
        which needs to be running for pika to communicate with RabbitMQ. All of
        the commands issued prior to starting the IOLoop will be buffered but
        not processed.
        """
        assert self._connection is not None
        if not self._closing:
            self._closing = True
            LOGGER.info("Stopping")
            if self._consuming:
                self.stop_consuming()
                self._connection.ioloop.start()
            else:
                self._connection.ioloop.stop()
            LOGGER.info("Stopped")

    @staticmethod
    def request(
        channel: pika.channel.Channel,
        reply_to: str,
        exchange: str,
        routing_key: str,
        correlation_id: str,
        **kwargs,
    ) -> None:
        # Publish a msgpack-encoded request; the server replies on the
        # `reply_to` queue, tagged with `correlation_id`.
        channel.basic_publish(
            exchange=exchange,
            routing_key=routing_key,
            properties=pika.BasicProperties(
                content_type="application/msgpack",
                correlation_id=correlation_id,
                reply_to=reply_to,
            ),
            body=encode_data(
                kwargs,
                extra_encoders=ProvenanceStorageRabbitMQClient.extra_type_encoders,
            ),
        )

    def wait_for_acks(self, acks_expected: int) -> bool:
        # Accumulate ack counts returned by the server until the expected
        # total is reached; False on timeout.
        acks_received = 0
        while acks_received < acks_expected:
            try:
                acks_received += self.wait_for_response()
            except ResponseTimeout:
                LOGGER.warning(
                    "Timed out waiting for acks, %s received, %s expected",
                    acks_received,
                    acks_expected,
                )
                return False
        return acks_received == acks_expected

    def wait_for_response(self, timeout: float = 60.0) -> Any:
        # Block for the next response matching the current batch's
        # correlation id; responses from stale batches are discarded.
        while True:
            try:
                correlation_id, response = self._response_queue.get(timeout=timeout)
                if correlation_id == self._correlation_id:
                    return response
            except queue.Empty:
                raise ResponseTimeout