diff --git a/mypy.ini b/mypy.ini --- a/mypy.ini +++ b/mypy.ini @@ -17,6 +17,9 @@ [mypy-msgpack.*] ignore_missing_imports = True +[mypy-pika.*] +ignore_missing_imports = True + [mypy-pkg_resources.*] ignore_missing_imports = True diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py --- a/swh/provenance/__init__.py +++ b/swh/provenance/__init__.py @@ -96,12 +96,12 @@ db = MongoClient(**kwargs["db"]).get_database(dbname) return ProvenanceStorageMongoDb(db) - elif cls == "remote": - from .api.client import RemoteProvenanceStorage + elif cls == "rabbitmq": + from .api.client import ProvenanceStorageRabbitMQClient - storage = RemoteProvenanceStorage(**kwargs) - assert isinstance(storage, ProvenanceStorageInterface) - return storage + rmq_storage = ProvenanceStorageRabbitMQClient(**kwargs) + if TYPE_CHECKING: + assert isinstance(rmq_storage, ProvenanceStorageInterface) + return rmq_storage - else: - raise ValueError + raise ValueError diff --git a/swh/provenance/api/client.py b/swh/provenance/api/client.py --- a/swh/provenance/api/client.py +++ b/swh/provenance/api/client.py @@ -3,15 +3,521 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from swh.core.api import RPCClient +import functools +import inspect +import logging +import queue +import threading +import time +from typing import Any, Dict, Optional +import uuid + +import pika +import pika.channel +import pika.connection +import pika.frame +import pika.spec + +from swh.core.api.serializers import encode_data_client as encode_data +from swh.core.api.serializers import msgpack_loads as decode_data +from swh.provenance import get_provenance_storage from ..interface import ProvenanceStorageInterface from .serializers import DECODERS, ENCODERS +from .server import ProvenanceStorageRabbitMQServer + +LOG_FORMAT = ( + "%(levelname) -10s %(asctime)s %(name) -30s %(funcName) " + "-35s %(lineno) -5d: %(message)s" +) +LOGGER = logging.getLogger(__name__) + + +class ConfigurationError(Exception): + pass + + +class ResponseTimeout(Exception): + pass + + +class TerminateSignal(Exception): + pass + + +class MetaRabbitMQClient(type): + def __new__(cls, name, bases, attributes): + # For each method wrapped with @remote_api_endpoint in an API backend + # (eg. :class:`swh.indexer.storage.IndexerStorage`), add a new + # method in RemoteStorage, with the same documentation. + # + # Note that, despite the usage of decorator magic (eg. functools.wrap), + # this never actually calls an IndexerStorage method. + backend_class = attributes.get("backend_class", None) + for base in bases: + if backend_class is not None: + break + backend_class = getattr(base, "backend_class", None) + if backend_class: + for meth_name, meth in backend_class.__dict__.items(): + if hasattr(meth, "_endpoint_path"): + cls.__add_endpoint(meth_name, meth, attributes) + return super().__new__(cls, name, bases, attributes) + + @staticmethod + def __add_endpoint(meth_name, meth, attributes): + wrapped_meth = inspect.unwrap(meth) + + @functools.wraps(meth) # Copy signature and doc + def meth_(*args, **kwargs): + # Match arguments and parameters + data = inspect.getcallargs(wrapped_meth, *args, **kwargs) + + # Remove arguments that should not be passed + self = data.pop("self") + + # Call storage method with remaining arguments + return getattr(self._storage, meth_name)(**data) + + @functools.wraps(meth) # Copy signature and doc + def write_meth_(*args, **kwargs): + # Match arguments and parameters + post_data = inspect.getcallargs(wrapped_meth, *args, **kwargs) + + # Remove arguments that should not be passed + self = post_data.pop("self") + relation = post_data.pop("relation", None) + assert len(post_data) == 1 + if relation is not None: + items = [ + (src, rel.dst, rel.path) + for src, dsts in next(iter(post_data.values())).items() + for rel in dsts + ] + else: + data = next(iter(post_data.values())) + items = ( + list(data.items()) + if isinstance(data, dict) + else list({(item,) for item in data}) + ) + + acks_expected = len(items) + self._correlation_id = str(uuid.uuid4()) + exchange = ProvenanceStorageRabbitMQServer.get_exchange(meth_name, relation) + for item in items: + routing_key = ProvenanceStorageRabbitMQServer.get_routing_key( + item, meth_name, relation + ) + # FIXME: this is running in a different thread! Hence, if + # self._connection drops, there is no guarantee that the request can + # be sent for the current elements. This situation should be handled + # properly. + while self._callback_queue is None: + time.sleep(0.1) + self._connection.ioloop.add_callback_threadsafe( + functools.partial( + ProvenanceStorageRabbitMQClient.request, + channel=self._channel, + reply_to=self._callback_queue, + exchange=exchange, + routing_key=routing_key, + correlation_id=self._correlation_id, + data=item, + ) + ) + return self.wait_for_acks(acks_expected) + + if meth_name not in attributes: + attributes[meth_name] = ( + write_meth_ + if ProvenanceStorageRabbitMQServer.is_write_method(meth_name) + else meth_ + ) + +class ProvenanceStorageRabbitMQClient(metaclass=MetaRabbitMQClient): + """This is an example publisher that will handle unexpected interactions + with RabbitMQ such as channel and connection closures. -class RemoteProvenanceStorage(RPCClient): - """Proxy to a remote provenance storage API""" + If RabbitMQ closes the connection, it will reopen it. You should + look at the output, as there are limited reasons why the connection may + be closed, which usually are tied to permission related issues or + socket timeouts. + + It uses delivery confirmations and illustrates one way to keep track of + messages that have been sent and if they've been confirmed by RabbitMQ. + + """ backend_class = ProvenanceStorageInterface + extra_type_decoders = DECODERS extra_type_encoders = ENCODERS + + def __init__(self, url: str, storage_config: Dict[str, Any]) -> None: + """Setup the example publisher object, passing in the URL we will use + to connect to RabbitMQ. + + :param str url: The URL for connecting to RabbitMQ + :param str routing_key: The routing key name from which this worker will + consume messages + :param str storage_config: Configuration parameters for the underlying + ``ProvenanceStorage`` object + + """ + self._connection = None + self._callback_queue: Optional[str] = None + self._channel = None + self._closing = False + self._consumer_tag = None + self._consuming = False + self._correlation_id = str(uuid.uuid4()) + self._prefetch_count = 100 + + self._response_queue: queue.Queue = queue.Queue() + self._storage = get_provenance_storage(**storage_config) + self._url = url + + self._consumer_thread = threading.Thread(target=self.run) + self._consumer_thread.start() + + def __del__(self) -> None: + assert self._connection is not None + self._connection.ioloop.add_callback_threadsafe(self._request_terminate) + self._consumer_thread.join() + + def _request_terminate(self) -> None: + raise TerminateSignal + + def connect(self) -> pika.SelectConnection: + """This method connects to RabbitMQ, returning the connection handle. + When the connection is established, the on_connection_open method + will be invoked by pika. + + :rtype: pika.SelectConnection + + """ + LOGGER.info("Connecting to %s", self._url) + return pika.SelectConnection( + parameters=pika.URLParameters(self._url), + on_open_callback=self.on_connection_open, + on_open_error_callback=self.on_connection_open_error, + on_close_callback=self.on_connection_closed, + ) + + def close_connection(self) -> None: + assert self._connection is not None + self._consuming = False + if self._connection.is_closing or self._connection.is_closed: + LOGGER.info("Connection is closing or already closed") + else: + LOGGER.info("Closing connection") + self._connection.close() + + def on_connection_open(self, _unused_connection: pika.SelectConnection) -> None: + """This method is called by pika once the connection to RabbitMQ has + been established. It passes the handle to the connection object in + case we need it, but in this case, we'll just mark it unused. + + :param pika.SelectConnection _unused_connection: The connection + + """ + LOGGER.info("Connection opened") + self.open_channel() + + def on_connection_open_error( + self, _unused_connection: pika.SelectConnection, err: Exception + ) -> None: + """This method is called by pika if the connection to RabbitMQ + can't be established. + + :param pika.SelectConnection _unused_connection: The connection + :param Exception err: The error + + """ + LOGGER.error("Connection open failed, reopening in 5 seconds: %s", err) + assert self._connection is not None + self._connection.ioloop.call_later(5, self._connection.ioloop.stop) + + def on_connection_closed(self, _unused_connection: pika.SelectConnection, reason): + """This method is invoked by pika when the connection to RabbitMQ is + closed unexpectedly. Since it is unexpected, we will reconnect to + RabbitMQ if it disconnects. + + :param pika.connection.Connection connection: The closed connection obj + :param Exception reason: exception representing reason for loss of + connection. + + """ + assert self._connection is not None + self._channel = None + if self._closing: + self._connection.ioloop.stop() + else: + LOGGER.warning("Connection closed, reopening in 5 seconds: %s", reason) + self._connection.ioloop.call_later(5, self._connection.ioloop.stop) + + def open_channel(self) -> None: + """Open a new channel with RabbitMQ by issuing the Channel.Open RPC + command. When RabbitMQ responds that the channel is open, the + on_channel_open callback will be invoked by pika. + + """ + LOGGER.info("Creating a new channel") + assert self._connection is not None + self._connection.channel(on_open_callback=self.on_channel_open) + + def on_channel_open(self, channel: pika.channel.Channel) -> None: + """This method is invoked by pika when the channel has been opened. + The channel object is passed in so we can make use of it. + + Since the channel is now open, we'll declare the exchange to use. + + :param pika.channel.Channel channel: The channel object + + """ + LOGGER.info("Channel opened") + self._channel = channel + LOGGER.info("Adding channel close callback") + assert self._channel is not None + self._channel.add_on_close_callback(callback=self.on_channel_closed) + self.setup_queue() + + def on_channel_closed( + self, channel: pika.channel.Channel, reason: Exception + ) -> None: + """Invoked by pika when RabbitMQ unexpectedly closes the channel. + Channels are usually closed if you attempt to do something that + violates the protocol, such as re-declare an exchange or queue with + different parameters. In this case, we'll close the connection + to shutdown the object. + + :param pika.channel.Channel: The closed channel + :param Exception reason: why the channel was closed + + """ + LOGGER.warning("Channel %i was closed: %s", channel, reason) + self.close_connection() + + def setup_queue(self) -> None: + """Setup the queue on RabbitMQ by invoking the Queue.Declare RPC + command. When it is complete, the on_queue_declare_ok method will + be invoked by pika. + + """ + LOGGER.info("Declaring callback queue") + assert self._channel is not None + self._channel.queue_declare( + queue="", exclusive=True, callback=self.on_queue_declare_ok + ) + + def on_queue_declare_ok(self, frame: pika.frame.Method) -> None: + """Method invoked by pika when the Queue.Declare RPC call made in + setup_queue has completed. This method sets up the consumer prefetch + to only be delivered one message at a time. The consumer must + acknowledge this message before RabbitMQ will deliver another one. + You should experiment with different prefetch values to achieve desired + performance. + + :param pika.frame.Method method_frame: The Queue.DeclareOk frame + + """ + LOGGER.info("Binding queue to default exchanger") + assert self._channel is not None + self._callback_queue = frame.method.queue + self._channel.basic_qos( + prefetch_count=self._prefetch_count, callback=self.on_basic_qos_ok + ) + + def on_basic_qos_ok(self, _unused_frame: pika.frame.Method) -> None: + """Invoked by pika when the Basic.QoS method has completed. At this + point we will start consuming messages by calling start_consuming + which will invoke the needed RPC commands to start the process. + + :param pika.frame.Method _unused_frame: The Basic.QosOk response frame + + """ + LOGGER.info("QOS set to: %d", self._prefetch_count) + self.start_consuming() + + def start_consuming(self) -> None: + """This method sets up the consumer by first calling + add_on_cancel_callback so that the object is notified if RabbitMQ + cancels the consumer. It then issues the Basic.Consume RPC command + which returns the consumer tag that is used to uniquely identify the + consumer with RabbitMQ. We keep the value to use it when we want to + cancel consuming. The on_response method is passed in as a callback pika + will invoke when a message is fully received. + + """ + LOGGER.info("Issuing consumer related RPC commands") + LOGGER.info("Adding consumer cancellation callback") + assert self._channel is not None + self._channel.add_on_cancel_callback(callback=self.on_consumer_cancelled) + assert self._callback_queue is not None + self._consumer_tag = self._channel.basic_consume( + queue=self._callback_queue, on_message_callback=self.on_response + ) + self._consuming = True + + def on_consumer_cancelled(self, method_frame: pika.frame.Method) -> None: + """Invoked by pika when RabbitMQ sends a Basic.Cancel for a consumer + receiving messages. + + :param pika.frame.Method method_frame: The Basic.Cancel frame + + """ + LOGGER.info("Consumer was cancelled remotely, shutting down: %r", method_frame) + if self._channel: + self._channel.close() + + def on_response( + self, + channel: pika.channel.Channel, + deliver: pika.spec.Basic.Deliver, + properties: pika.spec.BasicProperties, + body: bytes, + ) -> None: + """Invoked by pika when a message is delivered from RabbitMQ. The + channel is passed for your convenience. The deliver object that + is passed in carries the exchange, routing key, delivery tag and + a redelivered flag for the message. The properties passed in is an + instance of BasicProperties with the message properties and the body + is the message that was sent. + + :param pika.channel.Channel channel: The channel object + :param pika.spec.Basic.Deliver: deliver method + :param pika.spec.BasicProperties: properties + :param bytes body: The message body + + """ + LOGGER.info( + "Received message # %s from %s: %s", + deliver.delivery_tag, + properties.app_id, + body, + ) + self._response_queue.put( + ( + properties.correlation_id, + decode_data(body, extra_decoders=self.extra_type_decoders), + ) + ) + LOGGER.info("Acknowledging message %s", deliver.delivery_tag) + channel.basic_ack(delivery_tag=deliver.delivery_tag) + + def stop_consuming(self) -> None: + """Tell RabbitMQ that you would like to stop consuming by sending the + Basic.Cancel RPC command. + + """ + if self._channel: + LOGGER.info("Sending a Basic.Cancel RPC command to RabbitMQ") + self._channel.basic_cancel(self._consumer_tag, self.on_cancel_ok) + + def on_cancel_ok(self, _unused_frame: pika.frame.Method) -> None: + """This method is invoked by pika when RabbitMQ acknowledges the + cancellation of a consumer. At this point we will close the channel. + This will invoke the on_channel_closed method once the channel has been + closed, which will in-turn close the connection. + + :param pika.frame.Method _unused_frame: The Basic.CancelOk frame + :param str|unicode consumer_tag: Tag of the consumer to be stopped + + """ + self._consuming = False + LOGGER.info( + "RabbitMQ acknowledged the cancellation of the consumer: %s", + self._consumer_tag, + ) + LOGGER.info("Closing the channel") + assert self._channel is not None + self._channel.close() + + def run(self) -> None: + """Run the example code by connecting and then starting the IOLoop.""" + + while not self._closing: + try: + self._connection = self.connect() + assert self._connection is not None + self._connection.ioloop.start() + except KeyboardInterrupt: + LOGGER.info("Connection closed by keyboard interruption, reopening") + if self._connection is not None: + self._connection.ioloop.stop() + except TerminateSignal: + self.stop() + if self._connection is not None and not self._connection.is_closed: + # Finish closing + self._connection.ioloop.start() + LOGGER.info("Stopped") + + def stop(self) -> None: + """Cleanly shutdown the connection to RabbitMQ by stopping the consumer + with RabbitMQ. When RabbitMQ confirms the cancellation, on_cancel_ok + will be invoked by pika, which will then closing the channel and + connection. The IOLoop is started again because this method is invoked + by raising a TerminateSignal exception. This exception stops the IOLoop + which needs to be running for pika to communicate with RabbitMQ. All of + the commands issued prior to starting the IOLoop will be buffered but + not processed. + + """ + assert self._connection is not None + if not self._closing: + self._closing = True + LOGGER.info("Stopping") + if self._consuming: + self.stop_consuming() + self._connection.ioloop.start() + else: + self._connection.ioloop.stop() + LOGGER.info("Stopped") + + @staticmethod + def request( + channel: pika.channel.Channel, + reply_to: str, + exchange: str, + routing_key: str, + correlation_id: str, + **kwargs, + ) -> None: + channel.basic_publish( + exchange=exchange, + routing_key=routing_key, + properties=pika.BasicProperties( + content_type="application/msgpack", + correlation_id=correlation_id, + reply_to=reply_to, + ), + body=encode_data( + kwargs, + extra_encoders=ProvenanceStorageRabbitMQClient.extra_type_encoders, + ), + ) + + def wait_for_acks(self, acks_expected: int) -> bool: + acks_received = 0 + while acks_received < acks_expected: + try: + acks_received += self.wait_for_response() + except ResponseTimeout: + LOGGER.warning( + "Timed out waiting for acks, %s received, %s expected", + acks_received, + acks_expected, + ) + return False + return acks_received == acks_expected + + def wait_for_response(self, timeout: float = 60.0) -> Any: + while True: + try: + correlation_id, response = self._response_queue.get(timeout=timeout) + if correlation_id == self._correlation_id: + return response + except queue.Empty: + raise ResponseTimeout diff --git a/swh/provenance/api/server.py b/swh/provenance/api/server.py --- a/swh/provenance/api/server.py +++ b/swh/provenance/api/server.py @@ -3,79 +3,761 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from collections import Counter +from datetime import datetime +from enum import Enum +import functools import logging +import multiprocessing import os -from typing import Any, Dict, List, Optional - -from werkzeug.routing import Rule +import queue +import threading +from typing import Any, Callable +from typing import Counter as TCounter +from typing import Dict, Generator, Iterable, List, Optional, Set, Tuple, Union, cast + +import pika +import pika.channel +import pika.connection +import pika.exceptions +from pika.exchange_type import ExchangeType +import pika.frame +import pika.spec from swh.core import config -from swh.core.api import JSONFormatter, MsgpackFormatter, RPCServerApp, negotiate +from swh.core.api.serializers import encode_data_client as encode_data +from swh.core.api.serializers import msgpack_loads as decode_data +from swh.model.model import Sha1Git from swh.provenance import get_provenance_storage -from swh.provenance.interface import ProvenanceStorageInterface +from swh.provenance.interface import ( + EntityType, + ProvenanceStorageInterface, + RelationData, + RelationType, + RevisionData, +) from .serializers import DECODERS, ENCODERS -storage: Optional[ProvenanceStorageInterface] = None - +LOG_FORMAT = ( + "%(levelname) -10s %(asctime)s %(name) -30s %(funcName) " + "-35s %(lineno) -5d: %(message)s" +) +LOGGER = logging.getLogger(__name__) + +TERMINATE = object() + + +class ServerCommand(Enum): + TERMINATE = "terminate" + + +class TerminateSignal(Exception): + pass + + +def resolve_dates( + dates: Iterable[Union[Tuple[Sha1Git, Optional[datetime]], Tuple[Sha1Git]]] +) -> Dict[Sha1Git, Optional[datetime]]: + result: Dict[Sha1Git, Optional[datetime]] = {} + for row in dates: + sha1 = row[0] + date = ( + cast(Tuple[Sha1Git, Optional[datetime]], row)[1] if len(row) > 1 else None + ) + known = result.setdefault(sha1, None) + if date is not None and (known is None or date < known): + result[sha1] = date + return result + + +def resolve_revision( + data: Iterable[Union[Tuple[Sha1Git, RevisionData], Tuple[Sha1Git]]] +) -> Dict[Sha1Git, RevisionData]: + result: Dict[Sha1Git, RevisionData] = {} + for row in data: + sha1 = row[0] + rev = ( + cast(Tuple[Sha1Git, RevisionData], row)[1] + if len(row) > 1 + else RevisionData(date=None, origin=None) + ) + known = result.setdefault(sha1, RevisionData(date=None, origin=None)) + value = known + if rev.date is not None and (known.date is None or rev.date < known.date): + value = RevisionData(date=rev.date, origin=value.origin) + if rev.origin is not None: + value = RevisionData(date=value.date, origin=rev.origin) + if value != known: + result[sha1] = value + return result + + +def resolve_relation( + data: Iterable[Tuple[Sha1Git, Sha1Git, bytes]] +) -> Dict[Sha1Git, Set[RelationData]]: + result: Dict[Sha1Git, Set[RelationData]] = {} + for src, dst, path in data: + result.setdefault(src, set()).add(RelationData(dst=dst, path=path)) + return result + + +class ProvenanceStorageRabbitMQWorker(multiprocessing.Process): + """This is an example publisher that will handle unexpected interactions + with RabbitMQ such as channel and connection closures. + + If RabbitMQ closes the connection, it will reopen it. You should + look at the output, as there are limited reasons why the connection may + be closed, which usually are tied to permission related issues or + socket timeouts. + + It uses delivery confirmations and illustrates one way to keep track of + messages that have been sent and if they've been confirmed by RabbitMQ. -def get_global_provenance_storage() -> ProvenanceStorageInterface: - global storage - if storage is None: - storage = get_provenance_storage(**app.config["provenance"]["storage"]) - return storage + """ + EXCHANGE_TYPE = ExchangeType.topic -class ProvenanceStorageServerApp(RPCServerApp): extra_type_decoders = DECODERS extra_type_encoders = ENCODERS + def __init__( + self, url: str, exchange: str, range: int, storage_config: Dict[str, Any] + ) -> None: + """Setup the example publisher object, passing in the URL we will use + to connect to RabbitMQ. + + :param str url: The URL for connecting to RabbitMQ + :param str routing_key: The routing key name from which this worker will + consume messages + :param str storage_config: Configuration parameters for the underlying + ``ProvenanceStorage`` object + + """ + super().__init__(name=f"{exchange}_{range:x}") + + self._connection = None + self._channel = None + self._closing = False + self._consumer_tag: Dict[str, str] = {} + self._consuming: Dict[str, bool] = {} + self._prefetch_count = 100 + + self._url = url + self._exchange = exchange + self._binding_keys = list( + ProvenanceStorageRabbitMQServer.get_binding_keys(self._exchange, range) + ) + self._queues: Dict[str, str] = {} + self._storage_config = storage_config + self._batch_size = 100 + + self.command: multiprocessing.Queue = multiprocessing.Queue() + + def connect(self) -> pika.SelectConnection: + """This method connects to RabbitMQ, returning the connection handle. + When the connection is established, the on_connection_open method + will be invoked by pika. + + :rtype: pika.SelectConnection + + """ + LOGGER.info("Connecting to %s", self._url) + return pika.SelectConnection( + parameters=pika.URLParameters(self._url), + on_open_callback=self.on_connection_open, + on_open_error_callback=self.on_connection_open_error, + on_close_callback=self.on_connection_closed, + ) + + def close_connection(self) -> None: + assert self._connection is not None + self._consuming = {binding_key: False for binding_key in self._binding_keys} + if self._connection.is_closing or self._connection.is_closed: + LOGGER.info("Connection is closing or already closed") + else: + LOGGER.info("Closing connection") + self._connection.close() + + def on_connection_open(self, _unused_connection: pika.SelectConnection) -> None: + """This method is called by pika once the connection to RabbitMQ has + been established. It passes the handle to the connection object in + case we need it, but in this case, we'll just mark it unused. + + :param pika.SelectConnection _unused_connection: The connection + + """ + LOGGER.info("Connection opened") + self.open_channel() + + def on_connection_open_error( + self, _unused_connection: pika.SelectConnection, err: Exception + ) -> None: + """This method is called by pika if the connection to RabbitMQ + can't be established. + + :param pika.SelectConnection _unused_connection: The connection + :param Exception err: The error + + """ + LOGGER.error("Connection open failed, reopening in 5 seconds: %s", err) + assert self._connection is not None + self._connection.ioloop.call_later(5, self._connection.ioloop.stop) + + def on_connection_closed(self, _unused_connection: pika.SelectConnection, reason): + """This method is invoked by pika when the connection to RabbitMQ is + closed unexpectedly. Since it is unexpected, we will reconnect to + RabbitMQ if it disconnects. + + :param pika.connection.Connection connection: The closed connection obj + :param Exception reason: exception representing reason for loss of + connection. + + """ + assert self._connection is not None + self._channel = None + if self._closing: + self._connection.ioloop.stop() + else: + LOGGER.warning("Connection closed, reopening in 5 seconds: %s", reason) + self._connection.ioloop.call_later(5, self._connection.ioloop.stop) + + def open_channel(self) -> None: + """Open a new channel with RabbitMQ by issuing the Channel.Open RPC + command. When RabbitMQ responds that the channel is open, the + on_channel_open callback will be invoked by pika. + + """ + LOGGER.info("Creating a new channel") + assert self._connection is not None + self._connection.channel(on_open_callback=self.on_channel_open) + + def on_channel_open(self, channel: pika.channel.Channel) -> None: + """This method is invoked by pika when the channel has been opened. + The channel object is passed in so we can make use of it. + + Since the channel is now open, we'll declare the exchange to use. + + :param pika.channel.Channel channel: The channel object + + """ + LOGGER.info("Channel opened") + self._channel = channel + LOGGER.info("Adding channel close callback") + assert self._channel is not None + self._channel.add_on_close_callback(callback=self.on_channel_closed) + self.setup_exchange() + + def on_channel_closed( + self, channel: pika.channel.Channel, reason: Exception + ) -> None: + """Invoked by pika when RabbitMQ unexpectedly closes the channel. + Channels are usually closed if you attempt to do something that + violates the protocol, such as re-declare an exchange or queue with + different parameters. In this case, we'll close the connection + to shutdown the object. + + :param pika.channel.Channel: The closed channel + :param Exception reason: why the channel was closed + + """ + LOGGER.warning("Channel %i was closed: %s", channel, reason) + self.close_connection() + + def setup_exchange(self) -> None: + """Setup the exchange on RabbitMQ by invoking the Exchange.Declare RPC + command. When it is complete, the on_exchange_declare_ok method will + be invoked by pika. + + """ + LOGGER.info("Declaring exchange %s", self._exchange) + assert self._channel is not None + self._channel.exchange_declare( + exchange=self._exchange, + exchange_type=self.EXCHANGE_TYPE, + callback=self.on_exchange_declare_ok, + ) + + def on_exchange_declare_ok(self, _unused_frame: pika.frame.Method) -> None: + """Invoked by pika when RabbitMQ has finished the Exchange.Declare RPC + command. + + :param pika.frame.Method unused_frame: Exchange.DeclareOk response frame + + """ + LOGGER.info("Exchange declared: %s", self._exchange) + self.setup_queues() + + def setup_queues(self) -> None: + """Setup the queues on RabbitMQ by invoking the Queue.Declare RPC + command. When it is complete, the on_queue_declare_ok method will + be invoked by pika. + + """ + for binding_key in self._binding_keys: + LOGGER.info("Declaring queue %s", binding_key) + assert self._channel is not None + callback = functools.partial( + self.on_queue_declare_ok, + binding_key=binding_key, + ) + self._channel.queue_declare(queue="", exclusive=True, callback=callback) + + def on_queue_declare_ok(self, frame: pika.frame.Method, binding_key: str) -> None: + """Method invoked by pika when the Queue.Declare RPC call made in + setup_queue has completed. In this method we will bind the queue + and exchange together with the routing key by issuing the Queue.Bind + RPC command. When this command is complete, the on_bind_ok method will + be invoked by pika. + + :param pika.frame.Method frame: The Queue.DeclareOk frame + :param str|unicode binding_key: Binding key of the queue to declare + + """ + LOGGER.info( + "Binding queue %s to exchange %s with routing key %s", + frame.method.queue, + self._exchange, + binding_key, + ) + assert self._channel is not None + callback = functools.partial(self.on_bind_ok, queue_name=frame.method.queue) + self._queues[binding_key] = frame.method.queue + self._channel.queue_bind( + queue=frame.method.queue, + exchange=self._exchange, + routing_key=binding_key, + callback=callback, + ) + + def on_bind_ok(self, _unused_frame: pika.frame.Method, queue_name: str) -> None: + """Invoked by pika when the Queue.Bind method has completed. At this + point we will set the prefetch count for the channel. + + :param pika.frame.Method _unused_frame: The Queue.BindOk response frame + :param str|unicode queue_name: The name of the queue to declare + + """ + LOGGER.info("Queue bound: %s", queue_name) + self.set_qos() + + def set_qos(self) -> None: + """This method sets up the consumer prefetch to only be delivered + one message at a time. The consumer must acknowledge this message + before RabbitMQ will deliver another one. You should experiment + with different prefetch values to achieve desired performance. + + """ + assert self._channel is not None + self._channel.basic_qos( + prefetch_count=self._prefetch_count, callback=self.on_basic_qos_ok + ) + + def on_basic_qos_ok(self, _unused_frame: pika.frame.Method) -> None: + """Invoked by pika when the Basic.QoS method has completed. At this + point we will start consuming messages by calling start_consuming + which will invoke the needed RPC commands to start the process. + + :param pika.frame.Method _unused_frame: The Basic.QosOk response frame + + """ + LOGGER.info("QOS set to: %d", self._prefetch_count) + self.start_consuming() + + def start_consuming(self) -> None: + """This method sets up the consumer by first calling + add_on_cancel_callback so that the object is notified if RabbitMQ + cancels the consumer. It then issues the Basic.Consume RPC command + which returns the consumer tag that is used to uniquely identify the + consumer with RabbitMQ. We keep the value to use it when we want to + cancel consuming. The on_request method is passed in as a callback pika + will invoke when a message is fully received. + + """ + LOGGER.info("Issuing consumer related RPC commands") + LOGGER.info("Adding consumer cancellation callback") + assert self._channel is not None + self._channel.add_on_cancel_callback(callback=self.on_consumer_cancelled) + for binding_key in self._binding_keys: + self._consumer_tag[binding_key] = self._channel.basic_consume( + queue=self._queues[binding_key], on_message_callback=self.on_request + ) + self._consuming[binding_key] = True + + def on_consumer_cancelled(self, method_frame: pika.frame.Method) -> None: + """Invoked by pika when RabbitMQ sends a Basic.Cancel for a consumer + receiving messages. + + :param pika.frame.Method method_frame: The Basic.Cancel frame + + """ + LOGGER.info("Consumer was cancelled remotely, shutting down: %r", method_frame) + if self._channel: + self._channel.close() + + def on_request( + self, + channel: pika.channel.Channel, + deliver: pika.spec.Basic.Deliver, + properties: pika.spec.BasicProperties, + body: bytes, + ) -> None: + """Invoked by pika when a message is delivered from RabbitMQ. The + channel is passed for your convenience. The deliver object that + is passed in carries the exchange, routing key, delivery tag and + a redelivered flag for the message. The properties passed in is an + instance of BasicProperties with the message properties and the body + is the message that was sent. + + :param pika.channel.Channel channel: The channel object + :param pika.spec.Basic.Deliver: deliver method + :param pika.spec.BasicProperties: properties + :param bytes body: The message body + + """ + LOGGER.info( + "Received message # %s from %s: %s", + deliver.delivery_tag, + properties.app_id, + body, + ) + # XXX: for some reason this function is returning lists instead of tuples + # (the client send tuples) + item = decode_data(data=body, extra_decoders=self.extra_type_decoders)["data"] + self._request_queues[deliver.routing_key].put( + (tuple(item), (properties.correlation_id, properties.reply_to)) + ) + LOGGER.info("Acknowledging message %s", deliver.delivery_tag) + channel.basic_ack(delivery_tag=deliver.delivery_tag) + + def stop_consuming(self) -> None: + """Tell RabbitMQ that you would like to stop consuming by sending the + Basic.Cancel RPC command. + + """ + if self._channel: + LOGGER.info("Sending a Basic.Cancel RPC command to RabbitMQ") + for binding_key in self._binding_keys: + callback = functools.partial(self.on_cancel_ok, binding_key=binding_key) + self._channel.basic_cancel( + self._consumer_tag[binding_key], callback=callback + ) -app = ProvenanceStorageServerApp( - __name__, - backend_class=ProvenanceStorageInterface, - backend_factory=get_global_provenance_storage, -) + def on_cancel_ok(self, _unused_frame: pika.frame.Method, binding_key: str) -> None: + """This method is invoked by pika when RabbitMQ acknowledges the + cancellation of a consumer. At this point we will close the channel. + This will invoke the on_channel_closed method once the channel has been + closed, which will in-turn close the connection. + + :param pika.frame.Method _unused_frame: The Basic.CancelOk frame + :param str|unicode binding_key: Binding key of of the consumer to be stopped + + """ + self._consuming[binding_key] = False + LOGGER.info( + "RabbitMQ acknowledged the cancellation of the consumer: %s", + self._consuming[binding_key], + ) + LOGGER.info("Closing the channel") + assert self._channel is not None + self._channel.close() + + def run(self) -> None: + """Run the example code by connecting and then starting the IOLoop.""" + + self._command_thread = threading.Thread(target=self.run_command_thread) + self._command_thread.start() + + self._request_queues: Dict[str, queue.Queue] = {} + self._request_threads: Dict[str, threading.Thread] = {} + for binding_key in self._binding_keys: + meth_name, relation = ProvenanceStorageRabbitMQServer.get_meth_name( + binding_key + ) + self._request_queues[binding_key] = queue.Queue() + self._request_threads[binding_key] = threading.Thread( + target=self.run_request_thread, + args=(binding_key, meth_name, relation), + ) + self._request_threads[binding_key].start() + + while not self._closing: + try: + self._connection = self.connect() + assert self._connection is not None + self._connection.ioloop.start() + except KeyboardInterrupt: + LOGGER.info("Connection closed by keyboard interruption, reopening") + if self._connection is not None: + self._connection.ioloop.stop() + except TerminateSignal: + self.stop() + if self._connection is not None and not self._connection.is_closed: + # Finish closing + self._connection.ioloop.start() + + for binding_key in self._binding_keys: + self._request_queues[binding_key].put(TERMINATE) + self._request_threads[binding_key].join() + + self._command_thread.join() + LOGGER.info("Stopped") + + def run_command_thread(self) -> None: + while True: + try: + command = self.command.get() + if command == ServerCommand.TERMINATE: + assert self._connection is not None + self._connection.ioloop.add_callback_threadsafe( + self._request_terminate + ) + break + except queue.Empty: + pass + + def _request_terminate(self) -> None: + raise TerminateSignal + + def run_request_thread( + self, binding_key: str, meth_name: str, relation: Optional[RelationType] + ) -> None: + storage = get_provenance_storage(**self._storage_config) + + request_queue = self._request_queues[binding_key] + merge_items = ProvenanceStorageRabbitMQWorker.get_conflicts_func(meth_name) + + while True: + terminate = False + elements = [] + while True: + try: + # TODO: consider reducing this timeout or removing it + elem = request_queue.get(timeout=0.1) + if elem is TERMINATE: + terminate = True + break + elements.append(elem) + except queue.Empty: + break + + if len(elements) >= self._batch_size: + break + + if terminate: + break + + if not elements: + continue + + items, props = zip(*elements) + acks_count: TCounter[Tuple[str, str]] = Counter(props) + data = merge_items(items) + if not data: + print("Elements", elements) + print("Props", props) + print("Items", items) + print("Data", data) + + args = (relation, data) if relation is not None else (data,) + if getattr(storage, meth_name)(*args): + for (correlation_id, reply_to), count in acks_count.items(): + # FIXME: this is running in a different thread! Hence, if + # self._connection drops, there is no guarantee that the response + # can be sent for the current elements. This situation should be + # handled properly. + assert self._connection is not None + self._connection.ioloop.add_callback_threadsafe( + functools.partial( + ProvenanceStorageRabbitMQServer.respond, + channel=self._channel, + correlation_id=correlation_id, + reply_to=reply_to, + response=count, + ) + ) + else: + LOGGER.warning("Unable to process elements for queue %s", binding_key) + for elem in elements: + request_queue.put(elem) + + def stop(self) -> None: + """Cleanly shutdown the connection to RabbitMQ by stopping the consumer + with RabbitMQ. When RabbitMQ confirms the cancellation, on_cancel_ok + will be invoked by pika, which will then closing the channel and + connection. The IOLoop is started again because this method is invoked + by raising a TerminateSignal exception. This exception stops the IOLoop + which needs to be running for pika to communicate with RabbitMQ. All of + the commands issued prior to starting the IOLoop will be buffered but + not processed. + + """ + assert self._connection is not None + if not self._closing: + self._closing = True + LOGGER.info("Stopping") + if any(self._consuming): + self.stop_consuming() + self._connection.ioloop.start() + else: + self._connection.ioloop.stop() + LOGGER.info("Stopped") + + @staticmethod + def get_conflicts_func(meth_name: str) -> Callable[[Iterable[Any]], Any]: + if meth_name in ["content_add", "directory_add"]: + return resolve_dates + elif meth_name == "location_add": + return lambda data: set(data) # just remove duplicates + elif meth_name == "origin_add": + return lambda data: dict(data) # last processed value is good enough + elif meth_name == "revision_add": + return resolve_revision + elif meth_name == "relation_add": + return resolve_relation + else: + LOGGER.warning( + "Unexpected conflict resolution function request for method %s", + meth_name, + ) + return lambda x: x -def has_no_empty_params(rule: Rule) -> bool: - return len(rule.defaults or ()) >= len(rule.arguments or ()) - - -@app.route("/") -def index() -> str: - return """<html> -<head><title>Software Heritage provenance storage RPC server</title></head> -<body> -<p>You have reached the -<a href="https://www.softwareheritage.org/">Software Heritage</a> -provenance storage RPC server.<br /> -See its -<a href="https://docs.softwareheritage.org/devel/swh-provenance/">documentation -and API</a> for more information</p> -</body> -</html>""" - - -@app.route("/site-map") -@negotiate(MsgpackFormatter) -@negotiate(JSONFormatter) -def site_map() -> List[Dict[str, Any]]: - links = [] - for rule in app.url_map.iter_rules(): - if has_no_empty_params(rule) and hasattr( - ProvenanceStorageInterface, rule.endpoint - ): - links.append( - dict( - rule=rule.rule, - description=getattr( - ProvenanceStorageInterface, rule.endpoint - ).__doc__, + +class ProvenanceStorageRabbitMQServer: + backend_class = ProvenanceStorageInterface + extra_type_decoders = DECODERS + extra_type_encoders = ENCODERS + + queue_count = 16 + + def __init__(self, url: str, storage_config: Dict[str, Any]) -> None: + self._workers: List[ProvenanceStorageRabbitMQWorker] = [] + for exchange in ProvenanceStorageRabbitMQServer.get_exchanges(): + for range in ProvenanceStorageRabbitMQServer.get_ranges(): + worker = ProvenanceStorageRabbitMQWorker( + url, exchange, range, storage_config ) - ) - # links is now a list of url, endpoint tuples - return links + self._workers.append(worker) + self._running = False + + def start(self) -> None: + if not self._running: + for worker in self._workers: + worker.start() + LOGGER.info("Start serving") + self._running = True + + def stop(self) -> None: + if self._running: + for worker in self._workers: + worker.command.put(ServerCommand.TERMINATE) + worker.join() + LOGGER.info("Stop serving") + self._running = False + + @staticmethod + def ack(channel: pika.channel.Channel, delivery_tag: int) -> None: + channel.basic_ack(delivery_tag=delivery_tag) + + @staticmethod + def get_binding_keys(exchange: str, range: int) -> Generator[str, None, None]: + for meth_name, relation in ProvenanceStorageRabbitMQServer.get_meth_names( + exchange + ): + if relation is None: + yield f"{meth_name}.unknown.{range:x}".lower() + else: + yield f"{meth_name}.{relation.value}.{range:x}".lower() + + @staticmethod + def get_exchange(meth_name: str, relation: Optional[RelationType] = None) -> str: + if meth_name == "relation_add": + assert relation is not None + split = relation.value + else: + split = meth_name + exchange, *_ = split.split("_") + return exchange + + @staticmethod + def get_exchanges() -> Generator[str, None, None]: + yield from [entity.value for entity in EntityType] + ["location"] + + @staticmethod + def get_meth_name( + binding_key: str, + ) -> Tuple[str, Optional[RelationType]]: + meth_name, relation, *_ = binding_key.split(".") + return meth_name, (RelationType(relation) if relation != "unknown" else None) + + @staticmethod + def get_meth_names( + exchange: str, + ) -> Generator[Tuple[str, Optional[RelationType]], None, None]: + if exchange == EntityType.CONTENT.value: + yield from [ + ("content_add", None), + ("relation_add", RelationType.CNT_EARLY_IN_REV), + ("relation_add", RelationType.CNT_IN_DIR), + ] + elif exchange == EntityType.DIRECTORY.value: + yield from [ + ("directory_add", None), + ("relation_add", RelationType.DIR_IN_REV), + ] + elif exchange == EntityType.ORIGIN.value: + yield from [("origin_add", None)] + elif exchange == EntityType.REVISION.value: + yield from [ + ("revision_add", None), + ("relation_add", RelationType.REV_BEFORE_REV), + ("relation_add", RelationType.REV_IN_ORG), + ] + elif exchange == "location": + yield "location_add", None + + @staticmethod + def get_ranges() -> Generator[int, None, None]: + yield from range(ProvenanceStorageRabbitMQServer.queue_count) + + @staticmethod + def get_routing_key( + data: Tuple[bytes, ...], meth_name: str, relation: Optional[RelationType] = None + ) -> str: + idx = ( + data[0][0] // (256 // ProvenanceStorageRabbitMQServer.queue_count) + if data and data[0] + else 0 + ) + if relation is None: + return f"{meth_name}.unknown.{idx:x}".lower() + else: + return f"{meth_name}.{relation.value}.{idx:x}".lower() + + @staticmethod + def is_write_method(meth_name: str) -> bool: + return "_add" in meth_name + + @staticmethod + def respond( + channel: pika.channel.Channel, + correlation_id: str, + reply_to: str, + response: Any, + ): + channel.basic_publish( + exchange="", + routing_key=reply_to, + properties=pika.BasicProperties( + content_type="application/msgpack", + correlation_id=correlation_id, + ), + body=encode_data( + response, + extra_encoders=ProvenanceStorageRabbitMQServer.extra_type_encoders, + ), + ) def load_and_check_config( @@ -127,22 +809,10 @@ return cfg -api_cfg: Optional[Dict[str, Any]] = None - - -def make_app_from_configfile() -> ProvenanceStorageServerApp: - """Run the WSGI app from the webserver, loading the configuration from - a configuration file. - - SWH_CONFIG_FILENAME environment variable defines the - configuration path to load. - - """ - global api_cfg - if api_cfg is None: - config_path = os.environ.get("SWH_CONFIG_FILENAME") - api_cfg = load_and_check_config(config_path) - app.config.update(api_cfg) - handler = logging.StreamHandler() - app.logger.addHandler(handler) - return app +def make_server_from_configfile() -> ProvenanceStorageRabbitMQServer: + config_path = os.environ.get("SWH_CONFIG_FILENAME") + server_cfg = load_and_check_config(config_path) + return ProvenanceStorageRabbitMQServer( + url=server_cfg["provenance"]["rabbitmq"]["url"], + storage_config=server_cfg["provenance"]["storage"], + ) diff --git a/swh/provenance/cli.py b/swh/provenance/cli.py --- a/swh/provenance/cli.py +++ b/swh/provenance/cli.py @@ -42,21 +42,30 @@ }, "storage": { # Local PostgreSQL Storage - "cls": "postgresql", - "db": { - "host": "localhost", - "user": "postgres", - "password": "postgres", - "dbname": "provenance", - }, + # "cls": "postgresql", + # "db": { + # "host": "localhost", + # "user": "postgres", + # "password": "postgres", + # "dbname": "provenance", + # }, # Local MongoDB Storage # "cls": "mongodb", # "db": { # "dbname": "provenance", # }, - # Remote REST-API/PostgreSQL - # "cls": "remote", - # "url": "http://localhost:8080/%2f", + # Remote RabbitMQ/PostgreSQL Storage + "cls": "rabbitmq", + "url": "amqp://localhost:5672/%2f", + "storage_config": { + "cls": "postgresql", + "db": { + "host": "localhost", + "user": "postgres", + "password": "postgres", + "dbname": "dummy", + }, + }, }, } } diff --git a/swh/provenance/tests/conftest.py b/swh/provenance/tests/conftest.py --- a/swh/provenance/tests/conftest.py +++ b/swh/provenance/tests/conftest.py @@ -5,7 +5,7 @@ from datetime import datetime, timedelta, timezone from os import path -from typing import Any, Dict, Iterable, Iterator +from typing import Any, Dict, Iterable from _pytest.fixtures import SubRequest import msgpack @@ -16,8 +16,6 @@ from swh.journal.serializers import msgpack_ext_hook from swh.provenance import get_provenance, get_provenance_storage -from swh.provenance.api.client import RemoteProvenanceStorage -import swh.provenance.api.server as server from swh.provenance.archive import ArchiveInterface from swh.provenance.interface import ProvenanceInterface, ProvenanceStorageInterface from swh.provenance.storage.archive import ArchiveStorage @@ -46,38 +44,15 @@ return postgresql.get_dsn_parameters() -# the Flask app used as server in these tests -@pytest.fixture -def app( - provenance_postgresqldb: Dict[str, str] -) -> Iterator[server.ProvenanceStorageServerApp]: - assert hasattr(server, "storage") - server.storage = get_provenance_storage( - cls="postgresql", db=provenance_postgresqldb - ) - yield server.app - - -# the RPCClient class used as client used in these tests -@pytest.fixture -def swh_rpc_client_class() -> type: - return RemoteProvenanceStorage - - -@pytest.fixture(params=["mongodb", "postgresql", "remote"]) +@pytest.fixture(params=["mongodb", "postgresql"]) def provenance_storage( request: SubRequest, provenance_postgresqldb: Dict[str, str], mongodb: pymongo.database.Database, - swh_rpc_client: RemoteProvenanceStorage, ) -> ProvenanceStorageInterface: """Return a working and initialized ProvenanceStorageInterface object""" - if request.param == "remote": - assert isinstance(swh_rpc_client, ProvenanceStorageInterface) - return swh_rpc_client - - elif request.param == "mongodb": + if request.param == "mongodb": from swh.provenance.mongo.backend import ProvenanceStorageMongoDb return ProvenanceStorageMongoDb(mongodb)