diff --git a/mypy.ini b/mypy.ini --- a/mypy.ini +++ b/mypy.ini @@ -17,6 +17,9 @@ [mypy-msgpack.*] ignore_missing_imports = True +[mypy-pika.*] +ignore_missing_imports = True + [mypy-pkg_resources.*] ignore_missing_imports = True diff --git a/requirements.txt b/requirements.txt --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ iso8601 methodtools mongomock +pika pymongo PyYAML types-click diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py --- a/swh/provenance/__init__.py +++ b/swh/provenance/__init__.py @@ -92,4 +92,12 @@ engine = kwargs.get("engine", "pymongo") return ProvenanceStorageMongoDb(engine=engine, **kwargs["db"]) + elif cls == "rabbitmq": + from .api.client import ProvenanceStorageRabbitMQClient + + rmq_storage = ProvenanceStorageRabbitMQClient(**kwargs) + if TYPE_CHECKING: + assert isinstance(rmq_storage, ProvenanceStorageInterface) + return rmq_storage + raise ValueError diff --git a/swh/provenance/api/client.py b/swh/provenance/api/client.py --- a/swh/provenance/api/client.py +++ b/swh/provenance/api/client.py @@ -2,3 +2,585 @@ # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information + +from __future__ import annotations + +import functools +import inspect +import logging +import queue +import threading +import time +from types import TracebackType +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union +import uuid + +import pika +import pika.channel +import pika.connection +import pika.frame +import pika.spec + +from swh.core.api.serializers import encode_data_client as encode_data +from swh.core.api.serializers import msgpack_loads as decode_data +from swh.core.statsd import statsd + +from .. 
import get_provenance_storage +from ..interface import ProvenanceStorageInterface, RelationData, RelationType +from .serializers import DECODERS, ENCODERS +from .server import ProvenanceStorageRabbitMQServer + +LOG_FORMAT = ( + "%(levelname) -10s %(asctime)s %(name) -30s %(funcName) " + "-35s %(lineno) -5d: %(message)s" +) +LOGGER = logging.getLogger(__name__) + +STORAGE_DURATION_METRIC = "swh_provenance_storage_rabbitmq_duration_seconds" + + +class ResponseTimeout(Exception): + pass + + +class TerminateSignal(Exception): + pass + + +def split_ranges( + data: Iterable[bytes], meth_name: str, relation: Optional[RelationType] = None +) -> Dict[str, List[Tuple[Any, ...]]]: + ranges: Dict[str, List[Tuple[Any, ...]]] = {} + if relation is not None: + assert isinstance(data, dict) + for src, dsts in data.items(): + key = ProvenanceStorageRabbitMQServer.get_routing_key( + (src,), meth_name, relation + ) + for rel in dsts: + assert isinstance(rel, RelationData) + ranges.setdefault(key, []).append((src, rel.dst, rel.path)) + else: + items: Union[List[Tuple[Any, Any]], List[Tuple[Any]]] + if isinstance(data, dict): + items = list(data.items()) + else: + items = list({(item,) for item in data}) + for item in items: + key = ProvenanceStorageRabbitMQServer.get_routing_key(item, meth_name) + ranges.setdefault(key, []).append(item) + return ranges + + +class MetaRabbitMQClient(type): + def __new__(cls, name, bases, attributes): + # For each method wrapped with @remote_api_endpoint in an API backend + # (eg. :class:`swh.indexer.storage.IndexerStorage`), add a new + # method in RemoteStorage, with the same documentation. + # + # Note that, despite the usage of decorator magic (eg. functools.wrap), + # this never actually calls an IndexerStorage method. 
+ backend_class = attributes.get("backend_class", None) + for base in bases: + if backend_class is not None: + break + backend_class = getattr(base, "backend_class", None) + if backend_class: + for meth_name, meth in backend_class.__dict__.items(): + if hasattr(meth, "_endpoint_path"): + cls.__add_endpoint(meth_name, meth, attributes) + return super().__new__(cls, name, bases, attributes) + + @staticmethod + def __add_endpoint(meth_name, meth, attributes): + wrapped_meth = inspect.unwrap(meth) + + @functools.wraps(meth) # Copy signature and doc + def meth_(*args, **kwargs): + with statsd.timed( + metric=STORAGE_DURATION_METRIC, tags={"method": meth_name} + ): + # Match arguments and parameters + data = inspect.getcallargs(wrapped_meth, *args, **kwargs) + + # Remove arguments that should not be passed + self = data.pop("self") + + # Call storage method with remaining arguments + return getattr(self._storage, meth_name)(**data) + + @functools.wraps(meth) # Copy signature and doc + def write_meth_(*args, **kwargs): + with statsd.timed( + metric=STORAGE_DURATION_METRIC, tags={"method": meth_name} + ): + # Match arguments and parameters + post_data = inspect.getcallargs(wrapped_meth, *args, **kwargs) + + try: + # Remove arguments that should not be passed + self = post_data.pop("self") + relation = post_data.pop("relation", None) + assert len(post_data) == 1 + data = next(iter(post_data.values())) + + ranges = split_ranges(data, meth_name, relation) + acks_expected = sum(len(items) for items in ranges.values()) + self._correlation_id = str(uuid.uuid4()) + + exchange = ProvenanceStorageRabbitMQServer.get_exchange( + meth_name, relation + ) + batch_size = 100 + for routing_key, items in ranges.items(): + batches = ( + items[idx : idx + batch_size] + for idx in range(0, len(items), batch_size) + ) + for batch in batches: + # FIXME: this is running in a different thread! 
Hence, if + # self._connection drops, there is no guarantee that the + # request can be sent for the current elements. This + # situation should be handled properly. + self._connection.ioloop.add_callback_threadsafe( + functools.partial( + ProvenanceStorageRabbitMQClient.request, + channel=self._channel, + reply_to=self._callback_queue, + exchange=exchange, + routing_key=routing_key, + correlation_id=self._correlation_id, + data=batch, + ) + ) + return self.wait_for_acks(acks_expected) + except BaseException as ex: + self.request_termination(str(ex)) + return False + + if meth_name not in attributes: + attributes[meth_name] = ( + write_meth_ + if ProvenanceStorageRabbitMQServer.is_write_method(meth_name) + else meth_ + ) + + +class ProvenanceStorageRabbitMQClient(threading.Thread, metaclass=MetaRabbitMQClient): + """This is an example publisher that will handle unexpected interactions + with RabbitMQ such as channel and connection closures. + + If RabbitMQ closes the connection, it will reopen it. You should + look at the output, as there are limited reasons why the connection may + be closed, which usually are tied to permission related issues or + socket timeouts. + + It uses delivery confirmations and illustrates one way to keep track of + messages that have been sent and if they've been confirmed by RabbitMQ. + + """ + + backend_class = ProvenanceStorageInterface + + extra_type_decoders = DECODERS + extra_type_encoders = ENCODERS + + def __init__(self, url: str, storage_config: Dict[str, Any]) -> None: + """Setup the example publisher object, passing in the URL we will use + to connect to RabbitMQ. 
+ + :param str url: The URL for connecting to RabbitMQ + :param dict storage_config: Configuration parameters for the + underlying ``ProvenanceStorage`` object; they are forwarded as + keyword arguments to ``get_provenance_storage`` when the client + is constructed + + """ + super().__init__() + + self._connection = None + self._callback_queue: Optional[str] = None + self._channel = None + self._closing = False + self._consumer_tag = None + self._consuming = False + self._correlation_id = str(uuid.uuid4()) + self._prefetch_count = 100 + + self._response_queue: queue.Queue = queue.Queue() + self._storage = get_provenance_storage(**storage_config) + self._url = url + + def __enter__(self) -> ProvenanceStorageInterface: + self.open() + assert isinstance(self, ProvenanceStorageInterface) + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.close() + + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "open"}) + def open(self) -> None: + self.start() + while self._callback_queue is None: + time.sleep(0.1) + self._storage.open() + + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "close"}) + def close(self) -> None: + assert self._connection is not None + self._connection.ioloop.add_callback_threadsafe(self.request_termination) + self.join() + self._storage.close() + + def request_termination(self, reason: str = "Normal shutdown") -> None: + assert self._connection is not None + + def termination_callback(): + raise TerminateSignal(reason) + + self._connection.ioloop.add_callback_threadsafe(termination_callback) + + def connect(self) -> pika.SelectConnection: + """This method connects to RabbitMQ, returning the connection handle. + When the connection is established, the on_connection_open method + will be invoked by pika.
+ + :rtype: pika.SelectConnection + + """ + LOGGER.info("Connecting to %s", self._url) + return pika.SelectConnection( + parameters=pika.URLParameters(self._url), + on_open_callback=self.on_connection_open, + on_open_error_callback=self.on_connection_open_error, + on_close_callback=self.on_connection_closed, + ) + + def close_connection(self) -> None: + assert self._connection is not None + self._consuming = False + if self._connection.is_closing or self._connection.is_closed: + LOGGER.info("Connection is closing or already closed") + else: + LOGGER.info("Closing connection") + self._connection.close() + + def on_connection_open(self, _unused_connection: pika.SelectConnection) -> None: + """This method is called by pika once the connection to RabbitMQ has + been established. It passes the handle to the connection object in + case we need it, but in this case, we'll just mark it unused. + + :param pika.SelectConnection _unused_connection: The connection + + """ + LOGGER.info("Connection opened") + self.open_channel() + + def on_connection_open_error( + self, _unused_connection: pika.SelectConnection, err: Exception + ) -> None: + """This method is called by pika if the connection to RabbitMQ + can't be established. + + :param pika.SelectConnection _unused_connection: The connection + :param Exception err: The error + + """ + LOGGER.error("Connection open failed, reopening in 5 seconds: %s", err) + assert self._connection is not None + self._connection.ioloop.call_later(5, self._connection.ioloop.stop) + + def on_connection_closed(self, _unused_connection: pika.SelectConnection, reason): + """This method is invoked by pika when the connection to RabbitMQ is + closed unexpectedly. Since it is unexpected, we will reconnect to + RabbitMQ if it disconnects. + + :param pika.connection.Connection connection: The closed connection obj + :param Exception reason: exception representing reason for loss of + connection. 
+ + """ + assert self._connection is not None + self._channel = None + if self._closing: + self._connection.ioloop.stop() + else: + LOGGER.warning("Connection closed, reopening in 5 seconds: %s", reason) + self._connection.ioloop.call_later(5, self._connection.ioloop.stop) + + def open_channel(self) -> None: + """Open a new channel with RabbitMQ by issuing the Channel.Open RPC + command. When RabbitMQ responds that the channel is open, the + on_channel_open callback will be invoked by pika. + + """ + LOGGER.info("Creating a new channel") + assert self._connection is not None + self._connection.channel(on_open_callback=self.on_channel_open) + + def on_channel_open(self, channel: pika.channel.Channel) -> None: + """This method is invoked by pika when the channel has been opened. + The channel object is passed in so we can make use of it. + + Since the channel is now open, we'll declare the exchange to use. + + :param pika.channel.Channel channel: The channel object + + """ + LOGGER.info("Channel opened") + self._channel = channel + LOGGER.info("Adding channel close callback") + assert self._channel is not None + self._channel.add_on_close_callback(callback=self.on_channel_closed) + self.setup_queue() + + def on_channel_closed( + self, channel: pika.channel.Channel, reason: Exception + ) -> None: + """Invoked by pika when RabbitMQ unexpectedly closes the channel. + Channels are usually closed if you attempt to do something that + violates the protocol, such as re-declare an exchange or queue with + different parameters. In this case, we'll close the connection + to shutdown the object. + + :param pika.channel.Channel: The closed channel + :param Exception reason: why the channel was closed + + """ + LOGGER.warning("Channel %i was closed: %s", channel, reason) + self.close_connection() + + def setup_queue(self) -> None: + """Setup the queue on RabbitMQ by invoking the Queue.Declare RPC + command. 
When it is complete, the on_queue_declare_ok method will + be invoked by pika. + + """ + LOGGER.info("Declaring callback queue") + assert self._channel is not None + self._channel.queue_declare( + queue="", exclusive=True, callback=self.on_queue_declare_ok + ) + + def on_queue_declare_ok(self, frame: pika.frame.Method) -> None: + """Method invoked by pika when the Queue.Declare RPC call made in + setup_queue has completed. This method stores the name of the declared + callback queue and sets up the consumer prefetch so that at most + ``_prefetch_count`` unacknowledged messages are delivered at a time. + You should experiment with different prefetch values to achieve desired + performance. + + :param pika.frame.Method frame: The Queue.DeclareOk frame + + """ + LOGGER.info("Binding queue to default exchanger") + assert self._channel is not None + self._callback_queue = frame.method.queue + self._channel.basic_qos( + prefetch_count=self._prefetch_count, callback=self.on_basic_qos_ok + ) + + def on_basic_qos_ok(self, _unused_frame: pika.frame.Method) -> None: + """Invoked by pika when the Basic.QoS method has completed. At this + point we will start consuming messages by calling start_consuming + which will invoke the needed RPC commands to start the process. + + :param pika.frame.Method _unused_frame: The Basic.QosOk response frame + + """ + LOGGER.info("QOS set to: %d", self._prefetch_count) + self.start_consuming() + + def start_consuming(self) -> None: + """This method sets up the consumer by first calling + add_on_cancel_callback so that the object is notified if RabbitMQ + cancels the consumer. It then issues the Basic.Consume RPC command + which returns the consumer tag that is used to uniquely identify the + consumer with RabbitMQ. We keep the value to use it when we want to + cancel consuming. The on_response method is passed in as a callback pika + will invoke when a message is fully received.
+ + """ + LOGGER.info("Issuing consumer related RPC commands") + LOGGER.info("Adding consumer cancellation callback") + assert self._channel is not None + self._channel.add_on_cancel_callback(callback=self.on_consumer_cancelled) + assert self._callback_queue is not None + self._consumer_tag = self._channel.basic_consume( + queue=self._callback_queue, on_message_callback=self.on_response + ) + self._consuming = True + + def on_consumer_cancelled(self, method_frame: pika.frame.Method) -> None: + """Invoked by pika when RabbitMQ sends a Basic.Cancel for a consumer + receiving messages. + + :param pika.frame.Method method_frame: The Basic.Cancel frame + + """ + LOGGER.info("Consumer was cancelled remotely, shutting down: %r", method_frame) + if self._channel: + self._channel.close() + + def on_response( + self, + channel: pika.channel.Channel, + deliver: pika.spec.Basic.Deliver, + properties: pika.spec.BasicProperties, + body: bytes, + ) -> None: + """Invoked by pika when a message is delivered from RabbitMQ. The + channel is passed for your convenience. The deliver object that + is passed in carries the exchange, routing key, delivery tag and + a redelivered flag for the message. The properties passed in is an + instance of BasicProperties with the message properties and the body + is the message that was sent. 
+ + :param pika.channel.Channel channel: The channel object + :param pika.spec.Basic.Deliver: deliver method + :param pika.spec.BasicProperties: properties + :param bytes body: The message body + + """ + LOGGER.info( + "Received message # %s from %s: %s", + deliver.delivery_tag, + properties.app_id, + body, + ) + self._response_queue.put( + ( + properties.correlation_id, + decode_data(body, extra_decoders=self.extra_type_decoders), + ) + ) + LOGGER.info("Acknowledging message %s", deliver.delivery_tag) + channel.basic_ack(delivery_tag=deliver.delivery_tag) + + def stop_consuming(self) -> None: + """Tell RabbitMQ that you would like to stop consuming by sending the + Basic.Cancel RPC command. + + """ + if self._channel: + LOGGER.info("Sending a Basic.Cancel RPC command to RabbitMQ") + self._channel.basic_cancel(self._consumer_tag, self.on_cancel_ok) + + def on_cancel_ok(self, _unused_frame: pika.frame.Method) -> None: + """This method is invoked by pika when RabbitMQ acknowledges the + cancellation of a consumer. At this point we will close the channel. + This will invoke the on_channel_closed method once the channel has been + closed, which will in-turn close the connection. 
+ + :param pika.frame.Method _unused_frame: The Basic.CancelOk frame + + """ + self._consuming = False + LOGGER.info( + "RabbitMQ acknowledged the cancellation of the consumer: %s", + self._consumer_tag, + ) + LOGGER.info("Closing the channel") + assert self._channel is not None + self._channel.close() + + def run(self) -> None: + """Run the client by connecting and then starting the IOLoop.""" + + while not self._closing: + try: + self._connection = self.connect() + assert self._connection is not None + self._connection.ioloop.start() + except KeyboardInterrupt: + LOGGER.info("Connection closed by keyboard interruption, reopening") + if self._connection is not None: + self._connection.ioloop.stop() + except TerminateSignal as ex: + LOGGER.info("Termination requested: %s", ex) + self.stop() + if self._connection is not None and not self._connection.is_closed: + # Finish closing + self._connection.ioloop.start() + except BaseException as ex: + LOGGER.warning("Unexpected exception, terminating: %s", ex) + self.stop() + if self._connection is not None and not self._connection.is_closed: + # Finish closing + self._connection.ioloop.start() + LOGGER.info("Stopped") + + def stop(self) -> None: + """Cleanly shut down the connection to RabbitMQ by stopping the consumer + with RabbitMQ. When RabbitMQ confirms the cancellation, on_cancel_ok + will be invoked by pika, which will then close the channel and + connection. The IOLoop is started again because this method is invoked + by raising a TerminateSignal exception. This exception stops the IOLoop + which needs to be running for pika to communicate with RabbitMQ. All of + the commands issued prior to starting the IOLoop will be buffered but + not processed.
+ + """ + assert self._connection is not None + if not self._closing: + self._closing = True + LOGGER.info("Stopping") + if self._consuming: + self.stop_consuming() + self._connection.ioloop.start() + else: + self._connection.ioloop.stop() + LOGGER.info("Stopped") + + @staticmethod + def request( + channel: pika.channel.Channel, + reply_to: str, + exchange: str, + routing_key: str, + correlation_id: str, + **kwargs, + ) -> None: + channel.basic_publish( + exchange=exchange, + routing_key=routing_key, + properties=pika.BasicProperties( + content_type="application/msgpack", + correlation_id=correlation_id, + reply_to=reply_to, + ), + body=encode_data( + kwargs, + extra_encoders=ProvenanceStorageRabbitMQClient.extra_type_encoders, + ), + ) + + def wait_for_acks(self, acks_expected: int) -> bool: + acks_received = 0 + while acks_received < acks_expected: + try: + acks_received += self.wait_for_response() + except ResponseTimeout: + LOGGER.warning( + "Timed out waiting for acks, %s received, %s expected", + acks_received, + acks_expected, + ) + return False + return acks_received == acks_expected + + def wait_for_response(self, timeout: float = 60.0) -> Any: + while True: + try: + correlation_id, response = self._response_queue.get(timeout=timeout) + if correlation_id == self._correlation_id: + return response + except queue.Empty: + raise ResponseTimeout diff --git a/swh/provenance/api/server.py b/swh/provenance/api/server.py --- a/swh/provenance/api/server.py +++ b/swh/provenance/api/server.py @@ -3,10 +3,785 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from collections import Counter +from datetime import datetime +from enum import Enum +import functools +import logging +import multiprocessing import os -from typing import Any, Dict, Optional +import queue +import threading +from typing import Any, Callable +from typing import Counter as TCounter +from typing import Dict, Generator, 
Iterable, List, Optional, Set, Tuple, Union, cast + +import pika +import pika.channel +import pika.connection +import pika.exceptions +from pika.exchange_type import ExchangeType +import pika.frame +import pika.spec from swh.core import config +from swh.core.api.serializers import encode_data_client as encode_data +from swh.core.api.serializers import msgpack_loads as decode_data +from swh.model.model import Sha1Git + +from .. import get_provenance_storage +from ..interface import ( + EntityType, + ProvenanceStorageInterface, + RelationData, + RelationType, + RevisionData, +) +from ..util import path_id +from .serializers import DECODERS, ENCODERS + +LOG_FORMAT = ( + "%(levelname) -10s %(asctime)s %(name) -30s %(funcName) " + "-35s %(lineno) -5d: %(message)s" +) +LOGGER = logging.getLogger(__name__) + +TERMINATE = object() + + +class ServerCommand(Enum): + TERMINATE = "terminate" + CONSUMING = "consuming" + + +class TerminateSignal(BaseException): + pass + + +def resolve_dates( + dates: Iterable[Union[Tuple[Sha1Git, Optional[datetime]], Tuple[Sha1Git]]] +) -> Dict[Sha1Git, Optional[datetime]]: + result: Dict[Sha1Git, Optional[datetime]] = {} + for row in dates: + sha1 = row[0] + date = ( + cast(Tuple[Sha1Git, Optional[datetime]], row)[1] if len(row) > 1 else None + ) + known = result.setdefault(sha1, None) + if date is not None and (known is None or date < known): + result[sha1] = date + return result + + +def resolve_revision( + data: Iterable[Union[Tuple[Sha1Git, RevisionData], Tuple[Sha1Git]]] +) -> Dict[Sha1Git, RevisionData]: + result: Dict[Sha1Git, RevisionData] = {} + for row in data: + sha1 = row[0] + rev = ( + cast(Tuple[Sha1Git, RevisionData], row)[1] + if len(row) > 1 + else RevisionData(date=None, origin=None) + ) + known = result.setdefault(sha1, RevisionData(date=None, origin=None)) + value = known + if rev.date is not None and (known.date is None or rev.date < known.date): + value = RevisionData(date=rev.date, origin=value.origin) + if rev.origin is 
not None: + value = RevisionData(date=value.date, origin=rev.origin) + if value != known: + result[sha1] = value + return result + + +def resolve_relation( + data: Iterable[Tuple[Sha1Git, Sha1Git, bytes]] +) -> Dict[Sha1Git, Set[RelationData]]: + result: Dict[Sha1Git, Set[RelationData]] = {} + for src, dst, path in data: + result.setdefault(src, set()).add(RelationData(dst=dst, path=path)) + return result + + +class ProvenanceStorageRabbitMQWorker(multiprocessing.Process): + """This is an example publisher that will handle unexpected interactions + with RabbitMQ such as channel and connection closures. + + If RabbitMQ closes the connection, it will reopen it. You should + look at the output, as there are limited reasons why the connection may + be closed, which usually are tied to permission related issues or + socket timeouts. + + It uses delivery confirmations and illustrates one way to keep track of + messages that have been sent and if they've been confirmed by RabbitMQ. + + """ + + EXCHANGE_TYPE = ExchangeType.direct + + extra_type_decoders = DECODERS + extra_type_encoders = ENCODERS + + def __init__( + self, url: str, exchange: str, range: int, storage_config: Dict[str, Any] + ) -> None: + """Setup the example publisher object, passing in the URL we will use + to connect to RabbitMQ. 
+ + :param str url: The URL for connecting to RabbitMQ + :param str exchange: Name of the exchange this worker declares and + binds its queues to + :param int range: Value passed, together with the exchange name, to + ``ProvenanceStorageRabbitMQServer.get_binding_keys`` to obtain the + binding keys this worker consumes from + :param dict storage_config: Configuration parameters for the underlying + ``ProvenanceStorage`` object + + """ + super().__init__(name=f"{exchange}_{range:x}") + + self._connection = None + self._channel = None + self._closing = False + self._consumer_tag: Dict[str, str] = {} + self._consuming: Dict[str, bool] = {} + self._prefetch_count = 100 + + self._url = url + self._exchange = exchange + self._binding_keys = list( + ProvenanceStorageRabbitMQServer.get_binding_keys(self._exchange, range) + ) + self._queues: Dict[str, str] = {} + self._storage_config = storage_config + self._batch_size = 100 + + self.command: multiprocessing.Queue = multiprocessing.Queue() + self.signal: multiprocessing.Queue = multiprocessing.Queue() + + def connect(self) -> pika.SelectConnection: + """This method connects to RabbitMQ, returning the connection handle. + When the connection is established, the on_connection_open method + will be invoked by pika. + + :rtype: pika.SelectConnection + + """ + LOGGER.info("Connecting to %s", self._url) + return pika.SelectConnection( + parameters=pika.URLParameters(self._url), + on_open_callback=self.on_connection_open, + on_open_error_callback=self.on_connection_open_error, + on_close_callback=self.on_connection_closed, + ) + + def close_connection(self) -> None: + assert self._connection is not None + self._consuming = {binding_key: False for binding_key in self._binding_keys} + if self._connection.is_closing or self._connection.is_closed: + LOGGER.info("Connection is closing or already closed") + else: + LOGGER.info("Closing connection") + self._connection.close() + + def on_connection_open(self, _unused_connection: pika.SelectConnection) -> None: + """This method is called by pika once the connection to RabbitMQ has + been established.
It passes the handle to the connection object in + case we need it, but in this case, we'll just mark it unused. + + :param pika.SelectConnection _unused_connection: The connection + + """ + LOGGER.info("Connection opened") + self.open_channel() + + def on_connection_open_error( + self, _unused_connection: pika.SelectConnection, err: Exception + ) -> None: + """This method is called by pika if the connection to RabbitMQ + can't be established. + + :param pika.SelectConnection _unused_connection: The connection + :param Exception err: The error + + """ + LOGGER.error("Connection open failed, reopening in 5 seconds: %s", err) + assert self._connection is not None + self._connection.ioloop.call_later(5, self._connection.ioloop.stop) + + def on_connection_closed(self, _unused_connection: pika.SelectConnection, reason): + """This method is invoked by pika when the connection to RabbitMQ is + closed unexpectedly. Since it is unexpected, we will reconnect to + RabbitMQ if it disconnects. + + :param pika.connection.Connection connection: The closed connection obj + :param Exception reason: exception representing reason for loss of + connection. + + """ + assert self._connection is not None + self._channel = None + if self._closing: + self._connection.ioloop.stop() + else: + LOGGER.warning("Connection closed, reopening in 5 seconds: %s", reason) + self._connection.ioloop.call_later(5, self._connection.ioloop.stop) + + def open_channel(self) -> None: + """Open a new channel with RabbitMQ by issuing the Channel.Open RPC + command. When RabbitMQ responds that the channel is open, the + on_channel_open callback will be invoked by pika. + + """ + LOGGER.info("Creating a new channel") + assert self._connection is not None + self._connection.channel(on_open_callback=self.on_channel_open) + + def on_channel_open(self, channel: pika.channel.Channel) -> None: + """This method is invoked by pika when the channel has been opened. 
+ The channel object is passed in so we can make use of it. + + Since the channel is now open, we'll declare the exchange to use. + + :param pika.channel.Channel channel: The channel object + + """ + LOGGER.info("Channel opened") + self._channel = channel + LOGGER.info("Adding channel close callback") + assert self._channel is not None + self._channel.add_on_close_callback(callback=self.on_channel_closed) + self.setup_exchange() + + def on_channel_closed( + self, channel: pika.channel.Channel, reason: Exception + ) -> None: + """Invoked by pika when RabbitMQ unexpectedly closes the channel. + Channels are usually closed if you attempt to do something that + violates the protocol, such as re-declare an exchange or queue with + different parameters. In this case, we'll close the connection + to shutdown the object. + + :param pika.channel.Channel: The closed channel + :param Exception reason: why the channel was closed + + """ + LOGGER.warning("Channel %i was closed: %s", channel, reason) + self.close_connection() + + def setup_exchange(self) -> None: + """Setup the exchange on RabbitMQ by invoking the Exchange.Declare RPC + command. When it is complete, the on_exchange_declare_ok method will + be invoked by pika. + + """ + LOGGER.info("Declaring exchange %s", self._exchange) + assert self._channel is not None + self._channel.exchange_declare( + exchange=self._exchange, + exchange_type=self.EXCHANGE_TYPE, + callback=self.on_exchange_declare_ok, + ) + + def on_exchange_declare_ok(self, _unused_frame: pika.frame.Method) -> None: + """Invoked by pika when RabbitMQ has finished the Exchange.Declare RPC + command. + + :param pika.frame.Method unused_frame: Exchange.DeclareOk response frame + + """ + LOGGER.info("Exchange declared: %s", self._exchange) + self.setup_queues() + + def setup_queues(self) -> None: + """Setup the queues on RabbitMQ by invoking the Queue.Declare RPC + command. When it is complete, the on_queue_declare_ok method will + be invoked by pika. 
+ + """ + for binding_key in self._binding_keys: + LOGGER.info("Declaring queue %s", binding_key) + assert self._channel is not None + callback = functools.partial( + self.on_queue_declare_ok, + binding_key=binding_key, + ) + self._channel.queue_declare(queue=binding_key, callback=callback) + + def on_queue_declare_ok(self, frame: pika.frame.Method, binding_key: str) -> None: + """Method invoked by pika when the Queue.Declare RPC call made in + setup_queues has completed. In this method we will bind the queue + and exchange together with the routing key by issuing the Queue.Bind + RPC command. When this command is complete, the on_bind_ok method will + be invoked by pika. + + :param pika.frame.Method frame: The Queue.DeclareOk frame + :param str|unicode binding_key: Binding key of the declared queue + + """ + LOGGER.info( + "Binding queue %s to exchange %s with routing key %s", + frame.method.queue, + self._exchange, + binding_key, + ) + assert self._channel is not None + callback = functools.partial(self.on_bind_ok, queue_name=frame.method.queue) + self._queues[binding_key] = frame.method.queue + self._channel.queue_bind( + queue=frame.method.queue, + exchange=self._exchange, + routing_key=binding_key, + callback=callback, + ) + + def on_bind_ok(self, _unused_frame: pika.frame.Method, queue_name: str) -> None: + """Invoked by pika when the Queue.Bind method has completed. At this + point we will set the prefetch count for the channel. + + :param pika.frame.Method _unused_frame: The Queue.BindOk response frame + :param str|unicode queue_name: The name of the queue that was bound + + """ + LOGGER.info("Queue bound: %s", queue_name) + self.set_qos() + + def set_qos(self) -> None: + """This method sets up the consumer prefetch so that at most + ``_prefetch_count`` unacknowledged messages are delivered at a time. + The consumer must acknowledge them before RabbitMQ will deliver more. + You should experiment with different prefetch values to achieve desired performance.
def start_consuming(self) -> None:
    """Register one consumer per binding key and report readiness.

    A cancellation callback is installed first, so that on_consumer_cancelled
    runs if RabbitMQ cancels a consumer. A Basic.Consume RPC command is then
    issued for the queue associated with every binding key; the consumer tag
    it returns is kept in ``self._consumer_tag`` (it is needed to cancel the
    consumer later) and the key is flagged in ``self._consuming``. on_request
    is the callback pika invokes for each delivered message. Finally,
    ``ServerCommand.CONSUMING`` is put on ``self.signal``.

    """
    LOGGER.info("Issuing consumer related RPC commands")
    LOGGER.info("Adding consumer cancellation callback")
    channel = self._channel
    assert channel is not None
    channel.add_on_cancel_callback(callback=self.on_consumer_cancelled)
    for key in self._binding_keys:
        tag = channel.basic_consume(
            queue=self._queues[key], on_message_callback=self.on_request
        )
        self._consumer_tag[key] = tag
        self._consuming[key] = True
    # Report that every consumer of this worker has been registered.
    self.signal.put(ServerCommand.CONSUMING)
def on_request(
    self,
    channel: pika.channel.Channel,
    deliver: pika.spec.Basic.Deliver,
    properties: pika.spec.BasicProperties,
    body: bytes,
) -> None:
    """Dispatch the items of a delivered message to their request queue.

    Invoked by pika when a message is delivered from RabbitMQ. The body is
    decoded and every element of its ``"data"`` entry is put, as a tuple, on
    the request queue registered for the delivery's routing key, paired with
    the ``(correlation_id, reply_to)`` of the message. The message is
    acknowledged once all of its items have been queued.

    :param pika.channel.Channel channel: The channel object
    :param pika.spec.Basic.Deliver deliver: deliver method
    :param pika.spec.BasicProperties properties: message properties
    :param bytes body: The message body

    """
    LOGGER.info(
        "Received message # %s from %s: %s",
        deliver.delivery_tag,
        properties.app_id,
        body,
    )
    # XXX: for some reason the decoder is returning lists instead of tuples
    # (the client sends tuples), hence the explicit conversion below.
    decoded = decode_data(data=body, extra_decoders=self.extra_type_decoders)
    for raw_item in decoded["data"]:
        reply_info = (properties.correlation_id, properties.reply_to)
        # The queue is looked up per item, so an empty batch never touches
        # self._request_queues.
        self._request_queues[deliver.routing_key].put((tuple(raw_item), reply_info))
    LOGGER.info("Acknowledging message %s", deliver.delivery_tag)
    channel.basic_ack(delivery_tag=deliver.delivery_tag)
def on_cancel_ok(self, _unused_frame: pika.frame.Method, binding_key: str) -> None:
    """This method is invoked by pika when RabbitMQ acknowledges the
    cancellation of a consumer. The consumer registered for ``binding_key``
    is marked as stopped. Once every consumer of this worker has been
    cancelled the channel is closed, which will invoke the on_channel_closed
    method, which will in-turn close the connection.

    :param pika.frame.Method _unused_frame: The Basic.CancelOk frame
    :param str|unicode binding_key: Binding key of the consumer that was stopped

    """
    self._consuming[binding_key] = False
    # Log the consumer tag: the consuming flag was just set to False and
    # would not identify which consumer has been cancelled.
    LOGGER.info(
        "RabbitMQ acknowledged the cancellation of the consumer: %s",
        self._consumer_tag[binding_key],
    )
    if any(self._consuming.values()):
        # stop_consuming sends one Basic.Cancel per binding key, so this
        # callback runs once per consumer. Closing the channel on the first
        # acknowledgement would make the following ones call close() on a
        # channel that is already closing, which pika rejects.
        return
    LOGGER.info("Closing the channel")
    assert self._channel is not None
    self._channel.close()
def request_termination(self, reason: str = "Normal shutdown") -> None:
    """Schedule the termination of the worker on its connection's IOLoop.

    A callback that raises ``TerminateSignal(reason)`` is registered through
    the IOLoop's ``add_callback_threadsafe``, so the exception is raised by
    the IOLoop when it runs the callback, not by the caller of this method.

    :param str reason: message carried by the raised TerminateSignal

    """
    connection = self._connection
    assert connection is not None

    def _raise_terminate_signal():
        raise TerminateSignal(reason)

    connection.ioloop.add_callback_threadsafe(_raise_terminate_signal)
def stop(self) -> None:
    """Cleanly shutdown the connection to RabbitMQ by stopping the consumers
    with RabbitMQ. When RabbitMQ confirms the cancellation, on_cancel_ok
    will be invoked by pika, which will then close the channel and
    connection. The IOLoop is started again because this method is invoked
    by raising a TerminateSignal exception. This exception stops the IOLoop
    which needs to be running for pika to communicate with RabbitMQ. All of
    the commands issued prior to starting the IOLoop will be buffered but
    not processed.

    Calling this method again once the shutdown has started is a no-op.

    """
    assert self._connection is not None
    if self._closing:
        return
    self._closing = True
    LOGGER.info("Stopping")
    # self._consuming maps binding keys to booleans. Iterating the mapping
    # itself yields its keys (non-empty strings, always truthy), so the flags
    # have to be inspected through .values() to know whether a consumer is
    # still active.
    if any(self._consuming.values()):
        self.stop_consuming()
        # The IOLoop must run for the Basic.Cancel commands to be sent and
        # acknowledged.
        self._connection.ioloop.start()
    else:
        self._connection.ioloop.stop()
    LOGGER.info("Stopped")
@staticmethod
def get_exchange(meth_name: str, relation: Optional[RelationType] = None) -> str:
    """Return the name of the exchange serving a storage method.

    The exchange is the text preceding the first underscore of the method
    name or, for ``relation_add``, of the relation's value.

    :param str meth_name: name of the storage method
    :param relation: relation handled by the method; it must be provided
        when ``meth_name`` is ``"relation_add"``
    :return: the exchange name

    """
    if meth_name != "relation_add":
        source = meth_name
    else:
        assert relation is not None
        source = relation.value
    # str.split always yields at least one element, so index 0 is safe.
    return source.split("_")[0]
@staticmethod
def get_routing_key(
    data: Tuple[bytes, ...], meth_name: str, relation: Optional[RelationType] = None
) -> str:
    """Compute the routing key under which an item must be published.

    The key has the form ``<meth_name>.<relation>.<idx>``, lowercased, where
    ``<relation>`` is the relation's value (``unknown`` when no relation is
    given) and ``<idx>`` is, in hexadecimal, the first byte of the item's
    identifier modulo the server's ``queue_count``. The identifier is
    ``data[0]``, except for methods whose name starts with ``location``, for
    which it is ``path_id(data[0])``.

    :param data: the item; only its first element is inspected
    :param str meth_name: name of the storage method
    :param relation: relation handled by the method, if any
    :return: the routing key

    """
    first = data[0]
    hashid = path_id(first) if meth_name.startswith("location") else first
    idx = int(hashid[0]) % ProvenanceStorageRabbitMQServer.queue_count
    qualifier = "unknown" if relation is None else relation.value
    return f"{meth_name}.{qualifier}.{idx:x}".lower()
def path_normalize(path: bytes) -> bytes:
    """Return *path* without a leading "current directory" prefix.

    The prefix is ``"."`` followed by ``os.path.sep``, encoded as UTF-8. When
    *path* starts with it, its first two bytes are dropped; otherwise *path*
    is returned unchanged. Only one occurrence of the prefix is removed.

    """
    current_dir_prefix = ("." + os.path.sep).encode("utf-8")
    if path.startswith(current_dir_prefix):
        return path[2:]
    return path