Changeset View
Changeset View
Standalone View
Standalone View
swh/provenance/api/server.py
# Copyright (C) 2021 The Software Heritage developers | # Copyright (C) 2021 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
from collections import Counter | |||||
from datetime import datetime | |||||
import functools | |||||
import logging | import logging | ||||
import multiprocessing | |||||
import os | import os | ||||
from typing import Any, Dict, List, Optional | import queue | ||||
import re | |||||
import threading | |||||
from typing import Any, Callable | |||||
from typing import Counter as TCounter | |||||
from typing import Dict, Generator, List, Optional, Tuple | |||||
import pika | |||||
import pika.channel | |||||
import pika.connection | |||||
import pika.exceptions | |||||
from pika.exchange_type import ExchangeType | |||||
import pika.frame | |||||
import pika.spec | |||||
from werkzeug.routing import Rule | from werkzeug.routing import Rule | ||||
from swh.core import config | from swh.core import config | ||||
from swh.core.api import JSONFormatter, MsgpackFormatter, RPCServerApp, negotiate | from swh.core.api import JSONFormatter, MsgpackFormatter, RPCServerApp, negotiate | ||||
from swh.core.api.serializers import encode_data_client as encode_data | |||||
from swh.core.api.serializers import msgpack_loads as decode_data | |||||
from swh.model.model import Sha1Git | |||||
from swh.provenance import get_provenance_storage | from swh.provenance import get_provenance_storage | ||||
from swh.provenance.interface import ProvenanceStorageInterface | from swh.provenance.interface import ( | ||||
ProvenanceStorageInterface, | |||||
RelationData, | |||||
RelationType, | |||||
) | |||||
from .serializers import DECODERS, ENCODERS | from .serializers import DECODERS, ENCODERS | ||||
# Process-wide ProvenanceStorage instance, lazily created by
# get_global_provenance_storage() on first request.
storage: Optional[ProvenanceStorageInterface] = None
def get_global_provenance_storage() -> ProvenanceStorageInterface:
    """Return the process-wide provenance storage, instantiating it lazily.

    The storage backend is built from the ``provenance.storage`` section of
    the app configuration on first use, then cached in the module-level
    ``storage`` global.
    """
    global storage
    if storage is not None:
        return storage
    storage = get_provenance_storage(**app.config["provenance"]["storage"])
    return storage
class ProvenanceStorageRPCServerApp(RPCServerApp):
    """RPC server app exposing ``ProvenanceStorageInterface`` over HTTP.

    Extends the generic RPC app with the provenance-specific extra
    (de)serializers so provenance model objects survive the wire format.
    """

    extra_type_decoders = DECODERS
    extra_type_encoders = ENCODERS
# Module-level WSGI application; endpoints are generated from the methods of
# ProvenanceStorageInterface, with the backend created lazily per process.
app = ProvenanceStorageRPCServerApp(
    __name__,
    backend_class=ProvenanceStorageInterface,
    backend_factory=get_global_provenance_storage,
)
def has_no_empty_params(rule: Rule) -> bool:
    """Return True when every URL argument of *rule* has a default value.

    Such rules can be built with no parameters at all, e.g. when listing
    reachable endpoints.
    """
    defaults = rule.defaults or ()
    arguments = rule.arguments or ()
    return len(defaults) >= len(arguments)
▲ Show 20 Lines • Show All 82 Lines • ▼ Show 20 Lines | if type == "local": | ||||
raise KeyError("Invalid configuration; missing 'db' config entry") | raise KeyError("Invalid configuration; missing 'db' config entry") | ||||
return cfg | return cfg | ||||
# Cached server configuration, loaded once by make_app_from_configfile().
api_cfg: Optional[Dict[str, Any]] = None
def make_app_from_configfile() -> ProvenanceStorageRPCServerApp:
    """Run the WSGI app from the webserver, loading the configuration from
    a configuration file.

    The SWH_CONFIG_FILENAME environment variable defines the configuration
    path to load; the parsed config is cached in the module-level ``api_cfg``
    so subsequent calls skip the file read.
    """
    global api_cfg
    if api_cfg is None:
        api_cfg = load_and_check_config(os.environ.get("SWH_CONFIG_FILENAME"))
        app.config.update(api_cfg)
    app.logger.addHandler(logging.StreamHandler())
    return app
def resolve_dates(dates: List[Tuple[Sha1Git, datetime]]) -> Dict[Sha1Git, datetime]:
    """Deduplicate ``(sha1, date)`` pairs, keeping the earliest date per sha1.

    Later pairs only replace an existing entry when their date is not later
    than the stored one.
    """
    result: Dict[Sha1Git, datetime] = {}
    for sha1, date in dates:
        known = result.get(sha1)
        if known is None or known >= date:
            result[sha1] = date
    return result
# Verbose log format including function name and line number — useful when
# tracing the pika callback chain below.
LOG_FORMAT = (
    "%(levelname) -10s %(asctime)s %(name) -30s %(funcName) "
    "-35s %(lineno) -5d: %(message)s"
)
LOGGER = logging.getLogger(__name__)

# Sentinel pushed on the worker request queue to tell the storage thread
# to shut down (compared with ``is``).
TERMINATE = object()
class ProvenanceStorageRabbitMQWorker(multiprocessing.Process):
    """Consumer process that handles unexpected interactions with RabbitMQ,
    such as channel and connection closures.

    If RabbitMQ closes the connection, it will reopen it. You should look at
    the output, as there are limited reasons why the connection may be
    closed, which usually are tied to permission related issues or socket
    timeouts.

    Each worker consumes from one routing key and batches incoming write
    requests into calls on an underlying ``ProvenanceStorage`` instance,
    performed by a dedicated thread (see ``run_storage_thread``).
    """

    EXCHANGE_TYPE = ExchangeType.topic

    extra_type_decoders = DECODERS
    extra_type_encoders = ENCODERS

    def __init__(
        self, url: str, routing_key: str, storage_config: Dict[str, Any]
    ) -> None:
        """Setup the worker object, passing in the URL we will use to connect
        to RabbitMQ.

        :param str url: The URL for connecting to RabbitMQ
        :param str routing_key: The routing key name from which this worker will
            consume messages
        :param str storage_config: Configuration parameters for the underlying
            ``ProvenanceStorage`` object
        """
        super().__init__(name=routing_key)

        # Reconnection/consumption state flags, inspected by the supervisor.
        self.should_reconnect = False
        self.was_consuming = False

        self._connection = None
        self._channel = None
        self._closing = False
        self._consumer_tag = None
        self._consuming = False
        # Number of unacknowledged messages RabbitMQ may deliver at once.
        self._prefetch_count = 100

        self._url = url
        self._routing_key = routing_key
        self._storage_config = storage_config
        # Maximum number of requests grouped into a single storage call.
        self._batch_size = 100

    def connect(self) -> pika.SelectConnection:
        """This method connects to RabbitMQ, returning the connection handle.

        When the connection is established, the on_connection_open method
        will be invoked by pika.

        :rtype: pika.SelectConnection
        """
        LOGGER.info("Connecting to %s", self._url)
        return pika.SelectConnection(
            parameters=pika.URLParameters(self._url),
            on_open_callback=self.on_connection_open,
            on_open_error_callback=self.on_connection_open_error,
            on_close_callback=self.on_connection_closed,
        )

    def close_connection(self) -> None:
        """Close the connection to RabbitMQ, unless it is already closing."""
        assert self._connection is not None
        self._consuming = False
        if self._connection.is_closing or self._connection.is_closed:
            LOGGER.info("Connection is closing or already closed")
        else:
            LOGGER.info("Closing connection")
            self._connection.close()

    def on_connection_open(self, _unused_connection: pika.SelectConnection) -> None:
        """This method is called by pika once the connection to RabbitMQ has
        been established. It passes the handle to the connection object in
        case we need it, but in this case, we'll just mark it unused.

        :param pika.SelectConnection _unused_connection: The connection
        """
        LOGGER.info("Connection opened")
        self.open_channel()

    def on_connection_open_error(
        self, _unused_connection: pika.SelectConnection, err: Exception
    ) -> None:
        """This method is called by pika if the connection to RabbitMQ
        can't be established.

        :param pika.SelectConnection _unused_connection: The connection
        :param Exception err: The error
        """
        LOGGER.error("Connection open failed: %s", err)
        self.reconnect()

    def on_connection_closed(
        self, _unused_connection: pika.SelectConnection, reason
    ) -> None:
        """This method is invoked by pika when the connection to RabbitMQ is
        closed unexpectedly. Since it is unexpected, we will reconnect to
        RabbitMQ if it disconnects.

        :param pika.connection.Connection connection: The closed connection obj
        :param Exception reason: exception representing reason for loss of
            connection.
        """
        self._channel = None
        if self._closing:
            # Expected closure: just stop the event loop.
            assert self._connection is not None
            self._connection.ioloop.stop()
        else:
            LOGGER.warning("Connection closed, reconnect necessary: %s", reason)
            self.reconnect()

    def reconnect(self) -> None:
        """Will be invoked if the connection can't be opened or is
        closed. Indicates that a reconnect is necessary then stops the
        ioloop.
        """
        self.should_reconnect = True
        self.stop()

    def open_channel(self) -> None:
        """Open a new channel with RabbitMQ by issuing the Channel.Open RPC
        command. When RabbitMQ responds that the channel is open, the
        on_channel_open callback will be invoked by pika.
        """
        LOGGER.info("Creating a new channel")
        assert self._connection is not None
        self._connection.channel(on_open_callback=self.on_channel_open)

    def on_channel_open(self, channel: pika.channel.Channel) -> None:
        """This method is invoked by pika when the channel has been opened.
        The channel object is passed in so we can make use of it.

        Since the channel is now open, we'll declare the exchange to use.

        :param pika.channel.Channel channel: The channel object
        """
        LOGGER.info("Channel opened")
        self._channel = channel
        self.add_on_channel_close_callback()
        self.setup_exchange()

    def add_on_channel_close_callback(self) -> None:
        """This method tells pika to call the on_channel_closed method if
        RabbitMQ unexpectedly closes the channel.
        """
        LOGGER.info("Adding channel close callback")
        assert self._channel is not None
        self._channel.add_on_close_callback(callback=self.on_channel_closed)

    def on_channel_closed(
        self, channel: pika.channel.Channel, reason: Exception
    ) -> None:
        """Invoked by pika when RabbitMQ unexpectedly closes the channel.
        Channels are usually closed if you attempt to do something that
        violates the protocol, such as re-declare an exchange or queue with
        different parameters. In this case, we'll close the connection
        to shutdown the object.

        :param pika.channel.Channel: The closed channel
        :param Exception reason: why the channel was closed
        """
        LOGGER.warning("Channel %i was closed: %s", channel, reason)
        self.close_connection()

    def setup_exchange(self) -> None:
        """Setup the exchange on RabbitMQ by invoking the Exchange.Declare RPC
        command. When it is complete, the on_exchange_declareok method will
        be invoked by pika.
        """
        # NOTE: exchange, queue and routing key all share the same name here,
        # giving one dedicated exchange+queue pair per worker.
        LOGGER.info("Declaring exchange %s", self._routing_key)
        assert self._channel is not None
        self._channel.exchange_declare(
            exchange=self._routing_key,
            exchange_type=self.EXCHANGE_TYPE,
            callback=self.on_exchange_declareok,
        )

    def on_exchange_declareok(self, _unused_frame: pika.frame.Method) -> None:
        """Invoked by pika when RabbitMQ has finished the Exchange.Declare RPC
        command.

        :param pika.frame.Method unused_frame: Exchange.DeclareOk response frame
        """
        LOGGER.info("Exchange declared: %s", self._routing_key)
        self.setup_queue()

    def setup_queue(self) -> None:
        """Setup the queue on RabbitMQ by invoking the Queue.Declare RPC
        command. When it is complete, the on_queue_declareok method will
        be invoked by pika.
        """
        LOGGER.info("Declaring queue %s", self._routing_key)
        assert self._channel is not None
        self._channel.queue_declare(
            queue=self._routing_key, callback=self.on_queue_declareok
        )

    def on_queue_declareok(self, _unused_frame: pika.frame.Method) -> None:
        """Method invoked by pika when the Queue.Declare RPC call made in
        setup_queue has completed. In this method we will bind the queue
        and exchange together with the routing key by issuing the Queue.Bind
        RPC command. When this command is complete, the on_bindok method will
        be invoked by pika.

        :param pika.frame.Method method_frame: The Queue.DeclareOk frame
        """
        LOGGER.info(
            "Binding queue %s to exchange %s with routing key %s",
            self._routing_key,
            self._routing_key,
            self._routing_key,
        )
        assert self._channel is not None
        self._channel.queue_bind(
            queue=self._routing_key,
            exchange=self._routing_key,
            routing_key=self._routing_key,
            callback=self.on_bindok,
        )

    def on_bindok(self, _unused_frame: pika.frame.Method) -> None:
        """Invoked by pika when the Queue.Bind method has completed. At this
        point we will set the prefetch count for the channel.

        :param pika.frame.Method _unused_frame: The Queue.BindOk response frame
        """
        LOGGER.info("Queue bound: %s", self._routing_key)
        self.set_qos()

    def set_qos(self) -> None:
        """This method sets up the consumer prefetch to limit the number of
        in-flight messages. The consumer must acknowledge messages before
        RabbitMQ will deliver more. You should experiment with different
        prefetch values to achieve desired performance.
        """
        assert self._channel is not None
        self._channel.basic_qos(
            prefetch_count=self._prefetch_count, callback=self.on_basic_qos_ok
        )

    def on_basic_qos_ok(self, _unused_frame: pika.frame.Method) -> None:
        """Invoked by pika when the Basic.QoS method has completed. At this
        point we will start consuming messages by calling start_consuming
        which will invoke the needed RPC commands to start the process.

        :param pika.frame.Method _unused_frame: The Basic.QosOk response frame
        """
        LOGGER.info("QOS set to: %d", self._prefetch_count)
        self.start_consuming()

    def start_consuming(self) -> None:
        """This method sets up the consumer by first calling
        add_on_cancel_callback so that the object is notified if RabbitMQ
        cancels the consumer. It then issues the Basic.Consume RPC command
        which returns the consumer tag that is used to uniquely identify the
        consumer with RabbitMQ. We keep the value to use it when we want to
        cancel consuming. The on_message method is passed in as a callback pika
        will invoke when a message is fully received.
        """
        LOGGER.info("Issuing consumer related RPC commands")
        assert self._channel is not None
        self.add_on_cancel_callback()
        self._consumer_tag = self._channel.basic_consume(
            queue=self._routing_key, on_message_callback=self.on_message
        )
        self.was_consuming = True
        self._consuming = True

    def add_on_cancel_callback(self) -> None:
        """Add a callback that will be invoked if RabbitMQ cancels the consumer
        for some reason. If RabbitMQ does cancel the consumer,
        on_consumer_cancelled will be invoked by pika.
        """
        LOGGER.info("Adding consumer cancellation callback")
        assert self._channel is not None
        self._channel.add_on_cancel_callback(callback=self.on_consumer_cancelled)

    def on_consumer_cancelled(self, method_frame: pika.frame.Method) -> None:
        """Invoked by pika when RabbitMQ sends a Basic.Cancel for a consumer
        receiving messages.

        :param pika.frame.Method method_frame: The Basic.Cancel frame
        """
        LOGGER.info("Consumer was cancelled remotely, shutting down: %r", method_frame)
        if self._channel:
            self._channel.close()

    def on_message(
        self,
        _unused_channel: pika.channel.Channel,
        basic_deliver: pika.spec.Basic.Deliver,
        properties: pika.spec.BasicProperties,
        body: bytes,
    ) -> None:
        """Invoked by pika when a message is delivered from RabbitMQ. The
        channel is passed for your convenience. The basic_deliver object that
        is passed in carries the exchange, routing key, delivery tag and
        a redelivered flag for the message. The properties passed in is an
        instance of BasicProperties with the message properties and the body
        is the message that was sent.

        :param pika.channel.Channel _unused_channel: The channel object
        :param pika.spec.Basic.Deliver: basic_deliver method
        :param pika.spec.BasicProperties: properties
        :param bytes body: The message body
        """
        LOGGER.info(
            "Received message # %s from %s: %s",
            basic_deliver.delivery_tag,
            properties.app_id,
            body,
        )
        # Decoded payload is handed over to the storage thread along with the
        # reply coordinates needed to respond to the caller later on.
        item = decode_data(data=body, extra_decoders=self.extra_type_decoders)["data"]
        self._request_queue.put(
            (item, (properties.correlation_id, properties.reply_to))
        )
        # NOTE(review): the message is acked before the storage call happens;
        # a crash between ack and processing would lose the request.
        self.acknowledge_message(delivery_tag=basic_deliver.delivery_tag)

    def acknowledge_message(self, delivery_tag: int) -> None:
        """Acknowledge the message delivery from RabbitMQ by sending a
        Basic.Ack RPC method for the delivery tag.

        :param int delivery_tag: The delivery tag from the Basic.Deliver frame
        """
        LOGGER.info("Acknowledging message %s", delivery_tag)
        assert self._channel is not None
        self._channel.basic_ack(delivery_tag=delivery_tag)

    def stop_consuming(self) -> None:
        """Tell RabbitMQ that you would like to stop consuming by sending the
        Basic.Cancel RPC command.
        """
        if self._channel:
            LOGGER.info("Sending a Basic.Cancel RPC command to RabbitMQ")
            self._channel.basic_cancel(self._consumer_tag, self.on_cancelok)

    def on_cancelok(self, _unused_frame: pika.frame.Method) -> None:
        """This method is invoked by pika when RabbitMQ acknowledges the
        cancellation of a consumer. At this point we will close the channel.
        This will invoke the on_channel_closed method once the channel has been
        closed, which will in-turn close the connection.

        :param pika.frame.Method _unused_frame: The Basic.CancelOk frame
        """
        self._consuming = False
        LOGGER.info(
            "RabbitMQ acknowledged the cancellation of the consumer: %s",
            self._consumer_tag,
        )
        self.close_channel()

    def close_channel(self) -> None:
        """Call to close the channel with RabbitMQ cleanly by issuing the
        Channel.Close RPC command.
        """
        LOGGER.info("Closing the channel")
        assert self._channel is not None
        self._channel.close()

    def run(self) -> None:
        """Process entry point: start the storage thread, then run the pika
        IOLoop until asked to stop, reconnecting as needed.

        Note: this runs in the child process, so the queue and thread below
        are created after the fork, not in ``__init__``.
        """
        self._request_queue: queue.Queue = queue.Queue()
        self._storage_thread = threading.Thread(target=self.run_storage_thread)
        self._storage_thread.start()
        while not self._closing:
            try:
                self._connection = self.connect()
                assert self._connection is not None
                self._connection.ioloop.start()
            except KeyboardInterrupt:
                self.stop()
                if self._connection is not None and not self._connection.is_closed:
                    # Finish closing
                    self._connection.ioloop.start()
        self._request_queue.put(TERMINATE)
        self._storage_thread.join()
        LOGGER.info("Stopped")

    def run_storage_thread(self) -> None:
        """Drain the request queue in batches and apply them to the underlying
        ``ProvenanceStorage``, replying with per-caller counts on success.

        Runs in a dedicated thread inside the worker process until the
        ``TERMINATE`` sentinel is received.
        """
        storage = get_provenance_storage(**self._storage_config)

        # The routing key encodes which storage method (and, for
        # relation_add, which relation) this worker serves.
        meth_name, relation = ProvenanceStorageRabbitMQServer.get_meth_name(
            self._routing_key
        )
        resolve_conflicts = ProvenanceStorageRabbitMQWorker.get_conflicts_func(
            meth_name
        )

        while True:
            terminate = False
            elements = []
            # Gather up to _batch_size requests; a 0.1s idle timeout flushes
            # whatever has accumulated so far.
            while True:
                try:
                    elem = self._request_queue.get(timeout=0.1)
                    if elem is TERMINATE:
                        terminate = True
                        break
                    elements.append(elem)
                except queue.Empty:
                    break

                if len(elements) >= self._batch_size:
                    break

            if terminate:
                break

            if not elements:
                continue

            # FIXME: this is running in a different thread! Hence, if self._connection
            # drops, there is no guarantee that the response can be sent for the current
            # elements. This situation should be handled properly.
            assert self._connection is not None

            items, props = zip(*elements)
            # Count how many requests share each (correlation_id, reply_to)
            # pair so each caller gets the number of items handled for it.
            acks_count: TCounter[Tuple[str, str]] = Counter(props)
            data = resolve_conflicts(items)

            args = (relation, data) if relation is not None else (data,)
            if getattr(storage, meth_name)(*args):
                for (correlation_id, reply_to), count in acks_count.items():
                    # Publishing must happen on the IOLoop thread.
                    self._connection.ioloop.add_callback_threadsafe(
                        functools.partial(
                            ProvenanceStorageRabbitMQServer.respond,
                            channel=self._channel,
                            correlation_id=correlation_id,
                            reply_to=reply_to,
                            response=count,
                        )
                    )
            else:
                # Storage refused the batch: re-enqueue all elements so they
                # are retried on a later iteration.
                logging.warning(
                    "Unable to process elements for queue %s", self._routing_key
                )
                for elem in elements:
                    self._request_queue.put(elem)

    def stop(self) -> None:
        """Cleanly shutdown the connection to RabbitMQ by stopping the consumer
        with RabbitMQ. When RabbitMQ confirms the cancellation, on_cancelok
        will be invoked by pika, which will then closing the channel and
        connection. The IOLoop is started again because this method is invoked
        when CTRL-C is pressed raising a KeyboardInterrupt exception. This
        exception stops the IOLoop which needs to be running for pika to
        communicate with RabbitMQ. All of the commands issued prior to starting
        the IOLoop will be buffered but not processed.
        """
        assert self._connection is not None
        if not self._closing:
            self._closing = True
            LOGGER.info("Stopping")
            if self._consuming:
                self.stop_consuming()
                self._connection.ioloop.start()
            else:
                self._connection.ioloop.stop()
            LOGGER.info("Stopped")

    @staticmethod
    def get_conflicts_func(meth_name: str) -> Callable[[List[Any]], Any]:
        """Return the function used to merge a batch of requests for the given
        storage method into a single deduplicated argument.
        """
        if meth_name.startswith("relation_add"):
            # Create RelationData structures from tuples and deduplicate
            return lambda data: {RelationData(*row) for row in data}
        else:
            # Dates should be resolved to the earliest one,
            # otherwise last processed value is good enough
            return resolve_dates if "_date" in meth_name else lambda data: dict(data)
class ProvenanceStorageRabbitMQServer:
    """Front-end of the RabbitMQ-based provenance storage service.

    On construction it spawns one ``ProvenanceStorageRabbitMQWorker`` process
    per (write method, partition) routing key, serves the storage
    configuration to clients on the ``get_storage_config`` queue, and blocks
    consuming until interrupted, at which point workers are terminated.
    """

    backend_class = ProvenanceStorageInterface
    extra_type_decoders = DECODERS
    extra_type_encoders = ENCODERS

    # Number of queues (hence worker processes) per write method; requests
    # are partitioned into this many buckets (see get_routing_key).
    queue_count = 16

    # Matches worker routing keys of relation_add methods, capturing the
    # relation name; the trailing hex digits are the partition index.
    relation_add_regex = re.compile(r"relation_add_(?P<relation>\w+)_[0-9a-fA-F]+")

    def __init__(self, url: str, storage_config: Dict[str, Any]) -> None:
        # :param url: AMQP URL of the RabbitMQ broker
        # :param storage_config: configuration passed to each worker's
        #     get_provenance_storage() and served to clients on request
        self._url = url
        self._connection = pika.BlockingConnection(
            pika.connection.URLParameters(self._url)
        )
        self._channel = self._connection.channel()
        self._workers: List[ProvenanceStorageRabbitMQWorker] = []
        self._channel.basic_qos(prefetch_count=1)
        self._storage_config = storage_config
        # Clients query this queue to discover the storage configuration.
        self._channel.queue_declare(queue="get_storage_config")
        self._channel.basic_consume(
            queue="get_storage_config", on_message_callback=self.get_storage_config
        )
        # Spawn one worker process per routing key of each write endpoint
        # (endpoints are detected via the _endpoint_path marker attribute).
        for meth_name, meth in self.backend_class.__dict__.items():
            if hasattr(
                meth, "_endpoint_path"
            ) and ProvenanceStorageRabbitMQServer.is_write_method(meth_name):
                for routing_key in ProvenanceStorageRabbitMQServer.get_routing_keys(
                    meth_name
                ):
                    worker = ProvenanceStorageRabbitMQWorker(
                        self._url, routing_key, self._storage_config
                    )
                    worker.start()
                    self._workers.append(worker)
        try:
            LOGGER.info("Start consuming")
            self._channel.start_consuming()
        finally:
            LOGGER.info("Stop consuming")
            for worker in self._workers:
                worker.terminate()
                worker.join()
            self._channel.close()
            self._connection.close()

    @staticmethod
    def ack(channel: pika.channel.Channel, delivery_tag: int) -> None:
        """Acknowledge a delivered message on *channel*."""
        channel.basic_ack(delivery_tag=delivery_tag)

    @staticmethod
    def get_meth_name(routing_key: str) -> Tuple[str, Optional[RelationType]]:
        """Map a worker routing key back to the storage method it serves,
        plus the targeted relation for relation_add keys (None otherwise).
        """
        match = ProvenanceStorageRabbitMQServer.relation_add_regex.match(routing_key)
        if match:
            return "relation_add", RelationType(match.group("relation"))
        else:
            # Strip the trailing "_<idx>" partition suffix.
            return next(iter(routing_key.rsplit("_", 1))), None

    @staticmethod
    def get_routing_key(
        meth_name: str, relation: Optional[RelationType], data: Any
    ) -> str:
        """Compute the routing key for a request, partitioning by the first
        byte of the first item's id.

        # assumes data is a non-empty sequence of tuples whose first element
        # is a byte string (Sha1Git) — TODO confirm against callers
        """
        idx = data[0][0] // (256 // ProvenanceStorageRabbitMQServer.queue_count)
        if relation is not None:
            return f"{meth_name}_{relation.value}_{idx:x}".lower()
        else:
            return f"{meth_name}_{idx:x}".lower()

    @staticmethod
    def get_routing_keys(meth_name: str) -> Generator[str, None, None]:
        """Yield every routing key a given write method may be sharded onto:
        one per partition index, times one per relation for relation_add.
        """
        if meth_name.startswith("relation_add"):
            for relation in RelationType:
                for idx in range(ProvenanceStorageRabbitMQServer.queue_count):
                    yield f"{meth_name}_{relation.value}_{idx:x}".lower()
        else:
            for idx in range(ProvenanceStorageRabbitMQServer.queue_count):
                yield f"{meth_name}_{idx:x}".lower()

    def get_storage_config(
        self,
        channel: pika.channel.Channel,
        method: pika.spec.Basic.Deliver,
        props: pika.spec.BasicProperties,
        body: bytes,
    ) -> None:
        """Consumer callback for the ``get_storage_config`` queue: reply with
        the server's storage configuration and ack the request.
        """
        assert method.routing_key == "get_storage_config"
        ProvenanceStorageRabbitMQServer.respond(
            channel=channel,
            correlation_id=props.correlation_id,
            reply_to=props.reply_to,
            response=self._storage_config,
        )
        ProvenanceStorageRabbitMQServer.ack(
            channel=channel, delivery_tag=method.delivery_tag
        )

    @staticmethod
    def is_write_method(meth_name: str) -> bool:
        """Return True for endpoint names that mutate storage (relation_add
        and the *_set_* family), which are the ones served by workers.
        """
        return meth_name.startswith("relation_add") or "_set_" in meth_name

    @staticmethod
    def respond(
        channel: pika.channel.Channel,
        correlation_id: str,
        reply_to: str,
        response: Any,
    ):
        """Publish *response* to the caller's reply queue, tagged with its
        correlation id, using the default exchange.
        """
        channel.basic_publish(
            exchange="",
            routing_key=reply_to,
            properties=pika.BasicProperties(correlation_id=correlation_id),
            body=encode_data(
                response,
                extra_encoders=ProvenanceStorageRabbitMQServer.extra_type_encoders,
            ),
        )
def make_server_from_configfile() -> None:
    """Start a RabbitMQ-based provenance storage server, loading its
    configuration from the file named by the SWH_CONFIG_FILENAME environment
    variable. Blocks until the server shuts down.
    """
    server_cfg = load_and_check_config(os.environ.get("SWH_CONFIG_FILENAME"))
    provenance_cfg = server_cfg["provenance"]
    ProvenanceStorageRabbitMQServer(
        url=provenance_cfg["rabbitmq"]["url"],
        storage_config=provenance_cfg["storage"],
    )