Changeset View
Changeset View
Standalone View
Standalone View
swh/provenance/api/server.py
# Copyright (C) 2021 The Software Heritage developers | # Copyright (C) 2021 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
from collections import Counter | |||||
from datetime import datetime | |||||
from enum import Enum | |||||
import functools | |||||
import logging | import logging | ||||
import multiprocessing | |||||
import os | import os | ||||
from typing import Any, Dict, List, Optional | import queue | ||||
import threading | |||||
from typing import Any, Callable | |||||
from typing import Counter as TCounter | |||||
from typing import Dict, Generator, Iterable, List, Optional, Set, Tuple, Union, cast | |||||
import pika | |||||
import pika.channel | |||||
import pika.connection | |||||
import pika.exceptions | |||||
from pika.exchange_type import ExchangeType | |||||
import pika.frame | |||||
import pika.spec | |||||
from werkzeug.routing import Rule | from werkzeug.routing import Rule | ||||
from swh.core import config | from swh.core import config | ||||
from swh.core.api import JSONFormatter, MsgpackFormatter, RPCServerApp, negotiate | from swh.core.api import JSONFormatter, MsgpackFormatter, RPCServerApp, negotiate | ||||
from swh.core.api.serializers import encode_data_client as encode_data | |||||
from swh.core.api.serializers import msgpack_loads as decode_data | |||||
from swh.model.model import Sha1Git | |||||
from swh.provenance import get_provenance_storage | from swh.provenance import get_provenance_storage | ||||
from swh.provenance.interface import ProvenanceStorageInterface | from swh.provenance.interface import ( | ||||
EntityType, | |||||
ProvenanceStorageInterface, | |||||
RelationData, | |||||
RelationType, | |||||
RevisionData, | |||||
) | |||||
from .serializers import DECODERS, ENCODERS | from .serializers import DECODERS, ENCODERS | ||||
# Log line layout: level, timestamp, logger name, function name, line number
# and message; the first four and the line number are left-justified to a
# minimum width.
LOG_FORMAT = (
    "%(levelname) -10s %(asctime)s %(name) -30s "
    "%(funcName) -35s %(lineno) -5d: %(message)s"
)

# Module-level logger used by the RabbitMQ worker and server classes below.
LOGGER = logging.getLogger(__name__)

# Unique sentinel put on a worker's request queue to make the thread reading
# it leave its processing loop (compared by identity).
TERMINATE = object()
# Lazily-created provenance storage cached for the lifetime of the process.
storage: Optional[ProvenanceStorageInterface] = None


def get_global_provenance_storage() -> ProvenanceStorageInterface:
    """Return the cached provenance storage, creating it on first use.

    The storage is built from the ``["provenance"]["storage"]`` section of the
    application configuration and kept in the module-level ``storage``
    variable so later calls reuse the same object.
    """
    global storage
    if storage is not None:
        return storage
    storage = get_provenance_storage(**app.config["provenance"]["storage"])
    return storage
▲ Show 20 Lines • Show All 114 Lines • ▼ Show 20 Lines | def make_app_from_configfile() -> ProvenanceStorageRPCServerApp: | ||||
global api_cfg | global api_cfg | ||||
if api_cfg is None: | if api_cfg is None: | ||||
config_path = os.environ.get("SWH_CONFIG_FILENAME") | config_path = os.environ.get("SWH_CONFIG_FILENAME") | ||||
api_cfg = load_and_check_config(config_path) | api_cfg = load_and_check_config(config_path) | ||||
app.config.update(api_cfg) | app.config.update(api_cfg) | ||||
handler = logging.StreamHandler() | handler = logging.StreamHandler() | ||||
app.logger.addHandler(handler) | app.logger.addHandler(handler) | ||||
return app | return app | ||||
class ServerCommand(Enum):
    """Messages exchanged between the server and its worker processes."""

    # Put on a worker's ``command`` queue to ask it to shut down.
    TERMINATE = "terminate"
    # Put on a worker's ``signal`` queue once its consumers are registered.
    CONSUMING = "consuming"
class TerminateSignal(BaseException):
    """Raised inside the IOLoop to request termination of a worker.

    It derives from ``BaseException`` rather than ``Exception`` so that
    ``except Exception`` handlers do not intercept it.
    """
def resolve_dates(
    dates: Iterable[Union[Tuple[Sha1Git, Optional[datetime]], Tuple[Sha1Git]]]
) -> Dict[Sha1Git, Optional[datetime]]:
    """Collapse rows about the same identifier, keeping the earliest date.

    Each row is either ``(sha1,)`` or ``(sha1, date)``. Every identifier seen
    appears in the result; its value is the smallest non-``None`` date found
    for it, or ``None`` when no row carried a date.
    """
    earliest: Dict[Sha1Git, Optional[datetime]] = {}
    for entry in dates:
        key = entry[0]
        candidate: Optional[datetime] = None
        if len(entry) > 1:
            candidate = cast("Tuple[Sha1Git, Optional[datetime]]", entry)[1]
        # Register the identifier even when this row carries no date.
        current = earliest.setdefault(key, None)
        if candidate is None:
            continue
        if current is None or candidate < current:
            earliest[key] = candidate
    return earliest
def resolve_revision(
    data: Iterable[Union[Tuple[Sha1Git, RevisionData], Tuple[Sha1Git]]]
) -> Dict[Sha1Git, RevisionData]:
    """Collapse rows about the same revision into one ``RevisionData`` each.

    Each row is either ``(sha1,)`` or ``(sha1, revision_data)``. For every
    identifier the result keeps the smallest non-``None`` date seen and the
    last non-``None`` origin seen; identifiers without any data map to a
    ``RevisionData`` whose ``date`` and ``origin`` are both ``None``.
    """
    merged: Dict[Sha1Git, RevisionData] = {}
    for entry in data:
        key = entry[0]
        if len(entry) > 1:
            incoming = cast("Tuple[Sha1Git, RevisionData]", entry)[1]
        else:
            incoming = RevisionData(date=None, origin=None)

        current = merged.setdefault(key, RevisionData(date=None, origin=None))
        updated = current

        newer_date = incoming.date
        if newer_date is not None and (
            current.date is None or newer_date < current.date
        ):
            updated = RevisionData(date=newer_date, origin=updated.origin)
        if incoming.origin is not None:
            updated = RevisionData(date=updated.date, origin=incoming.origin)

        # Only write back when this row actually changed something.
        if updated != current:
            merged[key] = updated
    return merged
def resolve_relation(
    data: Iterable[Tuple[Sha1Git, Sha1Git, bytes]]
) -> Dict[Sha1Git, Set[RelationData]]:
    """Group ``(src, dst, path)`` rows by source identifier.

    The result maps each ``src`` to the set of ``RelationData(dst, path)``
    built from its rows.
    """
    grouped: Dict[Sha1Git, Set[RelationData]] = {}
    for source, target, location in data:
        targets = grouped.setdefault(source, set())
        targets.add(RelationData(dst=target, path=location))
    return grouped
class ProvenanceStorageRabbitMQWorker(multiprocessing.Process):
    """Worker process consuming provenance write requests from RabbitMQ.

    A worker connects to RabbitMQ, declares one exchange, then declares,
    binds and consumes one queue per binding key computed for its
    ``(exchange, range)`` pair. Each delivered message is decoded and pushed
    on an in-process queue dedicated to its routing key; one thread per
    binding key batches those items, merges them and forwards the batch to
    the matching ``ProvenanceStorage`` method.

    If the connection closes unexpectedly, the IOLoop is stopped after 5
    seconds and :meth:`run` connects again. Termination is requested by
    putting ``ServerCommand.TERMINATE`` on :attr:`command`; the worker puts
    ``ServerCommand.CONSUMING`` on :attr:`signal` once its consumers are
    registered.
    """

    EXCHANGE_TYPE = ExchangeType.direct

    extra_type_decoders = DECODERS
    extra_type_encoders = ENCODERS

    def __init__(
        self, url: str, exchange: str, range: int, storage_config: Dict[str, Any]
    ) -> None:
        """Set up the worker; no connection is opened until the process runs.

        :param str url: The URL for connecting to RabbitMQ
        :param str exchange: Name of the exchange this worker declares and
            binds its queues to
        :param int range: Index used, together with ``exchange``, to compute
            the binding keys this worker consumes from
        :param dict storage_config: Configuration parameters for the
            underlying ``ProvenanceStorage`` object
        """
        super().__init__(name=f"{exchange}_{range:x}")

        self._connection = None
        self._channel = None
        self._closing = False
        # Per binding key: consumer tag returned by Basic.Consume, and
        # whether that consumer is currently active.
        self._consumer_tag: Dict[str, str] = {}
        self._consuming: Dict[str, bool] = {}
        self._prefetch_count = 100

        self._url = url
        self._exchange = exchange
        self._binding_keys = list(
            ProvenanceStorageRabbitMQServer.get_binding_keys(self._exchange, range)
        )
        # Binding key -> name of the queue declared for it.
        self._queues: Dict[str, str] = {}
        self._storage_config = storage_config
        # Maximum number of requests merged into a single storage call.
        self._batch_size = 100

        # Inter-process queues: commands sent to the worker, and signals
        # sent back by the worker.
        self.command: multiprocessing.Queue = multiprocessing.Queue()
        self.signal: multiprocessing.Queue = multiprocessing.Queue()

    def connect(self) -> pika.SelectConnection:
        """Connect to RabbitMQ and return the connection handle.

        When the connection is established, the on_connection_open method
        will be invoked by pika.

        :rtype: pika.SelectConnection
        """
        LOGGER.info("Connecting to %s", self._url)
        return pika.SelectConnection(
            parameters=pika.URLParameters(self._url),
            on_open_callback=self.on_connection_open,
            on_open_error_callback=self.on_connection_open_error,
            on_close_callback=self.on_connection_closed,
        )

    def close_connection(self) -> None:
        """Mark every consumer as inactive and close the connection.

        Nothing is closed if the connection is already closing or closed.
        """
        assert self._connection is not None
        self._consuming = {binding_key: False for binding_key in self._binding_keys}
        if self._connection.is_closing or self._connection.is_closed:
            LOGGER.info("Connection is closing or already closed")
        else:
            LOGGER.info("Closing connection")
            self._connection.close()

    def on_connection_open(self, _unused_connection: pika.SelectConnection) -> None:
        """Invoked by pika once the connection to RabbitMQ is established.

        The connection handle is passed in but not needed here; a channel is
        opened right away.

        :param pika.SelectConnection _unused_connection: The connection
        """
        LOGGER.info("Connection opened")
        self.open_channel()

    def on_connection_open_error(
        self, _unused_connection: pika.SelectConnection, err: Exception
    ) -> None:
        """Invoked by pika if the connection to RabbitMQ can't be established.

        The IOLoop is scheduled to stop in 5 seconds, which lets :meth:`run`
        attempt a new connection.

        :param pika.SelectConnection _unused_connection: The connection
        :param Exception err: The error
        """
        LOGGER.error("Connection open failed, reopening in 5 seconds: %s", err)
        assert self._connection is not None
        self._connection.ioloop.call_later(5, self._connection.ioloop.stop)

    def on_connection_closed(
        self, _unused_connection: pika.SelectConnection, reason: Exception
    ) -> None:
        """Invoked by pika when the connection to RabbitMQ is closed.

        When the worker is shutting down the IOLoop is stopped immediately.
        Otherwise the closure is unexpected and the IOLoop is stopped after 5
        seconds so that :meth:`run` reconnects.

        :param pika.SelectConnection _unused_connection: The closed connection
        :param Exception reason: exception representing reason for loss of
            connection.
        """
        assert self._connection is not None
        self._channel = None
        if self._closing:
            self._connection.ioloop.stop()
        else:
            LOGGER.warning("Connection closed, reopening in 5 seconds: %s", reason)
            self._connection.ioloop.call_later(5, self._connection.ioloop.stop)

    def open_channel(self) -> None:
        """Open a new channel with RabbitMQ by issuing the Channel.Open RPC
        command. When RabbitMQ responds that the channel is open, the
        on_channel_open callback will be invoked by pika.
        """
        LOGGER.info("Creating a new channel")
        assert self._connection is not None
        self._connection.channel(on_open_callback=self.on_channel_open)

    def on_channel_open(self, channel: pika.channel.Channel) -> None:
        """Invoked by pika when the channel has been opened.

        The channel is stored, a close callback is registered on it, and the
        exchange is declared.

        :param pika.channel.Channel channel: The channel object
        """
        LOGGER.info("Channel opened")
        self._channel = channel
        LOGGER.info("Adding channel close callback")
        assert self._channel is not None
        self._channel.add_on_close_callback(callback=self.on_channel_closed)
        self.setup_exchange()

    def on_channel_closed(
        self, channel: pika.channel.Channel, reason: Exception
    ) -> None:
        """Invoked by pika when the channel is closed.

        Channels are usually closed if you attempt to do something that
        violates the protocol, such as re-declare an exchange or queue with
        different parameters. In this case, we'll close the connection
        to shutdown the object.

        :param pika.channel.Channel channel: The closed channel
        :param Exception reason: why the channel was closed
        """
        LOGGER.warning("Channel %i was closed: %s", channel, reason)
        self.close_connection()

    def setup_exchange(self) -> None:
        """Setup the exchange on RabbitMQ by invoking the Exchange.Declare RPC
        command. When it is complete, the on_exchange_declare_ok method will
        be invoked by pika.
        """
        LOGGER.info("Declaring exchange %s", self._exchange)
        assert self._channel is not None
        self._channel.exchange_declare(
            exchange=self._exchange,
            exchange_type=self.EXCHANGE_TYPE,
            callback=self.on_exchange_declare_ok,
        )

    def on_exchange_declare_ok(self, _unused_frame: pika.frame.Method) -> None:
        """Invoked by pika when RabbitMQ has finished the Exchange.Declare RPC
        command.

        :param pika.frame.Method _unused_frame: Exchange.DeclareOk response frame
        """
        LOGGER.info("Exchange declared: %s", self._exchange)
        self.setup_queues()

    def setup_queues(self) -> None:
        """Declare one queue per binding key with the Queue.Declare RPC
        command. Each time a declaration completes, the on_queue_declare_ok
        method will be invoked by pika with the corresponding binding key.
        """
        for binding_key in self._binding_keys:
            LOGGER.info("Declaring queue %s", binding_key)
            assert self._channel is not None
            callback = functools.partial(
                self.on_queue_declare_ok,
                binding_key=binding_key,
            )
            self._channel.queue_declare(queue=binding_key, callback=callback)

    def on_queue_declare_ok(self, frame: pika.frame.Method, binding_key: str) -> None:
        """Invoked by pika when a Queue.Declare RPC call made in setup_queues
        has completed.

        The declared queue is recorded and bound to the exchange with the
        binding key as routing key by issuing the Queue.Bind RPC command.
        When this command is complete, the on_bind_ok method will be invoked
        by pika.

        :param pika.frame.Method frame: The Queue.DeclareOk frame
        :param str binding_key: Binding key of the declared queue
        """
        LOGGER.info(
            "Binding queue %s to exchange %s with routing key %s",
            frame.method.queue,
            self._exchange,
            binding_key,
        )
        assert self._channel is not None
        callback = functools.partial(self.on_bind_ok, queue_name=frame.method.queue)
        self._queues[binding_key] = frame.method.queue
        self._channel.queue_bind(
            queue=frame.method.queue,
            exchange=self._exchange,
            routing_key=binding_key,
            callback=callback,
        )

    def on_bind_ok(self, _unused_frame: pika.frame.Method, queue_name: str) -> None:
        """Invoked by pika when a Queue.Bind method has completed. At this
        point we will set the prefetch count for the channel.

        :param pika.frame.Method _unused_frame: The Queue.BindOk response frame
        :param str queue_name: The name of the queue that was bound
        """
        LOGGER.info("Queue bound: %s", queue_name)
        # NOTE(review): this runs once per bound queue, so the
        # set_qos -> on_basic_qos_ok -> start_consuming chain is triggered
        # once per binding key, and start_consuming itself loops over every
        # binding key. Confirm whether repeated consumer registration is
        # intended.
        self.set_qos()

    def set_qos(self) -> None:
        """Set the consumer prefetch count on the channel.

        The Basic.QoS call is issued with ``self._prefetch_count``. When it
        completes, the on_basic_qos_ok method will be invoked by pika.
        """
        assert self._channel is not None
        self._channel.basic_qos(
            prefetch_count=self._prefetch_count, callback=self.on_basic_qos_ok
        )

    def on_basic_qos_ok(self, _unused_frame: pika.frame.Method) -> None:
        """Invoked by pika when the Basic.QoS method has completed. At this
        point we will start consuming messages by calling start_consuming
        which will invoke the needed RPC commands to start the process.

        :param pika.frame.Method _unused_frame: The Basic.QosOk response frame
        """
        LOGGER.info("QOS set to: %d", self._prefetch_count)
        self.start_consuming()

    def start_consuming(self) -> None:
        """Register a consumer on every queue of this worker.

        A cancellation callback is added first so that the object is
        notified if RabbitMQ cancels a consumer. Then Basic.Consume is issued
        for each binding key, with on_request as message callback; the
        returned consumer tags are kept so the consumers can be cancelled
        later. Finally ``ServerCommand.CONSUMING`` is put on :attr:`signal`.
        """
        LOGGER.info("Issuing consumer related RPC commands")
        LOGGER.info("Adding consumer cancellation callback")
        assert self._channel is not None
        self._channel.add_on_cancel_callback(callback=self.on_consumer_cancelled)
        for binding_key in self._binding_keys:
            self._consumer_tag[binding_key] = self._channel.basic_consume(
                queue=self._queues[binding_key], on_message_callback=self.on_request
            )
            self._consuming[binding_key] = True
        self.signal.put(ServerCommand.CONSUMING)

    def on_consumer_cancelled(self, method_frame: pika.frame.Method) -> None:
        """Invoked by pika when RabbitMQ sends a Basic.Cancel for a consumer
        receiving messages. The channel, if any, is closed.

        :param pika.frame.Method method_frame: The Basic.Cancel frame
        """
        LOGGER.info("Consumer was cancelled remotely, shutting down: %r", method_frame)
        if self._channel:
            self._channel.close()

    def on_request(
        self,
        channel: pika.channel.Channel,
        deliver: pika.spec.Basic.Deliver,
        properties: pika.spec.BasicProperties,
        body: bytes,
    ) -> None:
        """Invoked by pika when a message is delivered from RabbitMQ.

        The body is decoded and its ``"data"`` entry is pushed, together with
        the message's correlation id and reply-to queue, on the request queue
        matching the delivery's routing key. The message is acknowledged
        right after being enqueued, i.e. before it is processed.

        :param pika.channel.Channel channel: The channel object
        :param pika.spec.Basic.Deliver deliver: deliver method, carrying the
            routing key and delivery tag
        :param pika.spec.BasicProperties properties: message properties
        :param bytes body: The message body
        """
        LOGGER.info(
            "Received message # %s from %s: %s",
            deliver.delivery_tag,
            properties.app_id,
            body,
        )
        # XXX: for some reason this function is returning lists instead of tuples
        #      (the client send tuples)
        item = decode_data(data=body, extra_decoders=self.extra_type_decoders)["data"]
        self._request_queues[deliver.routing_key].put(
            (tuple(item), (properties.correlation_id, properties.reply_to))
        )
        LOGGER.info("Acknowledging message %s", deliver.delivery_tag)
        channel.basic_ack(delivery_tag=deliver.delivery_tag)

    def stop_consuming(self) -> None:
        """Tell RabbitMQ that you would like to stop consuming by sending the
        Basic.Cancel RPC command for the consumer of every binding key.
        """
        if self._channel:
            LOGGER.info("Sending a Basic.Cancel RPC command to RabbitMQ")
            for binding_key in self._binding_keys:
                callback = functools.partial(self.on_cancel_ok, binding_key=binding_key)
                self._channel.basic_cancel(
                    self._consumer_tag[binding_key], callback=callback
                )

    def on_cancel_ok(self, _unused_frame: pika.frame.Method, binding_key: str) -> None:
        """Invoked by pika when RabbitMQ acknowledges the cancellation of a
        consumer. The consumer is marked inactive and the channel is closed.
        This will invoke the on_channel_closed method once the channel has
        been closed, which will in-turn close the connection.

        :param pika.frame.Method _unused_frame: The Basic.CancelOk frame
        :param str binding_key: Binding key of the consumer that was stopped
        """
        self._consuming[binding_key] = False
        LOGGER.info(
            "RabbitMQ acknowledged the cancellation of the consumer: %s",
            self._consuming[binding_key],
        )
        LOGGER.info("Closing the channel")
        assert self._channel is not None
        self._channel.close()

    def run(self) -> None:
        """Process entry point: start helper threads and run the IOLoop.

        One thread listens for commands and one thread per binding key
        processes queued requests. The connection loop then keeps
        (re)connecting until the worker is closing; afterwards every request
        thread is told to terminate and all threads are joined.
        """
        self._command_thread = threading.Thread(target=self.run_command_thread)
        self._command_thread.start()

        self._request_queues: Dict[str, queue.Queue] = {}
        self._request_threads: Dict[str, threading.Thread] = {}
        for binding_key in self._binding_keys:
            meth_name, relation = ProvenanceStorageRabbitMQServer.get_meth_name(
                binding_key
            )
            self._request_queues[binding_key] = queue.Queue()
            self._request_threads[binding_key] = threading.Thread(
                target=self.run_request_thread,
                args=(binding_key, meth_name, relation),
            )
            self._request_threads[binding_key].start()

        while not self._closing:
            try:
                self._connection = self.connect()
                assert self._connection is not None
                self._connection.ioloop.start()
            except KeyboardInterrupt:
                LOGGER.info("Connection closed by keyboard interruption, reopening")
                if self._connection is not None:
                    self._connection.ioloop.stop()
            except TerminateSignal as ex:
                LOGGER.info("Termination requested: %s", ex)
                self.stop()
                if self._connection is not None and not self._connection.is_closed:
                    # Finish closing
                    self._connection.ioloop.start()
            except BaseException as ex:
                LOGGER.warning("Unexpected exception, terminating: %s", ex)
                self.stop()
                if self._connection is not None and not self._connection.is_closed:
                    # Finish closing
                    self._connection.ioloop.start()

        for binding_key in self._binding_keys:
            self._request_queues[binding_key].put(TERMINATE)
        for binding_key in self._binding_keys:
            self._request_threads[binding_key].join()
        self._command_thread.join()
        LOGGER.info("Stopped")

    def run_command_thread(self) -> None:
        """Wait for commands on :attr:`command` until termination.

        A ``ServerCommand.TERMINATE`` command, or any exception raised while
        waiting, requests termination of the worker and ends the thread.
        """
        while True:
            try:
                command = self.command.get()
                if command == ServerCommand.TERMINATE:
                    self.request_termination()
                    break
            except queue.Empty:
                pass
            except BaseException as ex:
                self.request_termination(str(ex))
                break

    def request_termination(self, reason: str = "Normal shutdown") -> None:
        """Schedule a ``TerminateSignal`` to be raised from the IOLoop.

        The callback is registered with ``add_callback_threadsafe`` because
        this method is called from the command and request threads.

        :param str reason: message carried by the raised ``TerminateSignal``
        """
        assert self._connection is not None

        def termination_callback():
            raise TerminateSignal(reason)

        self._connection.ioloop.add_callback_threadsafe(termination_callback)

    def run_request_thread(
        self, binding_key: str, meth_name: str, relation: Optional[RelationType]
    ) -> None:
        """Batch queued requests for one binding key and store them.

        Requests are read from the binding key's queue until it stays empty
        for 0.1 second or ``self._batch_size`` items are collected. The batch
        is merged with the conflict resolution function of ``meth_name`` and
        passed to the storage method of that name. On success one response
        per ``(correlation_id, reply_to)`` pair is scheduled, carrying the
        number of requests received for that pair; on failure the items are
        put back on the queue. The loop ends when ``TERMINATE`` is read or
        when an exception occurs, in which case termination of the worker is
        requested.

        :param str binding_key: key of the request queue to read from
        :param str meth_name: name of the storage method to call
        :param relation: relation passed as first argument to the storage
            method, when not None
        """
        storage = get_provenance_storage(**self._storage_config)

        request_queue = self._request_queues[binding_key]
        merge_items = ProvenanceStorageRabbitMQWorker.get_conflicts_func(meth_name)

        while True:
            terminate = False
            elements = []
            while True:
                try:
                    # TODO: consider reducing this timeout or removing it
                    elem = request_queue.get(timeout=0.1)
                    if elem is TERMINATE:
                        terminate = True
                        break
                    elements.append(elem)
                except queue.Empty:
                    break

                if len(elements) >= self._batch_size:
                    break

            if terminate:
                break

            if not elements:
                continue

            try:
                items, props = zip(*elements)
                acks_count: TCounter[Tuple[str, str]] = Counter(props)
                data = merge_items(items)

                args = (relation, data) if relation is not None else (data,)
                if getattr(storage, meth_name)(*args):
                    for (correlation_id, reply_to), count in acks_count.items():
                        # FIXME: this is running in a different thread! Hence, if
                        # self._connection drops, there is no guarantee that the
                        # response can be sent for the current elements. This
                        # situation should be handled properly.
                        assert self._connection is not None
                        self._connection.ioloop.add_callback_threadsafe(
                            functools.partial(
                                ProvenanceStorageRabbitMQServer.respond,
                                channel=self._channel,
                                correlation_id=correlation_id,
                                reply_to=reply_to,
                                response=count,
                            )
                        )
                else:
                    LOGGER.warning(
                        "Unable to process elements for queue %s", binding_key
                    )
                    for elem in elements:
                        request_queue.put(elem)
            except BaseException as ex:
                self.request_termination(str(ex))
                break

        storage.close()

    def stop(self) -> None:
        """Cleanly shutdown the connection to RabbitMQ by stopping the consumer
        with RabbitMQ. When RabbitMQ confirms the cancellation, on_cancel_ok
        will be invoked by pika, which will then close the channel and
        connection. The IOLoop is started again because this method is invoked
        by raising a TerminateSignal exception. This exception stops the IOLoop
        which needs to be running for pika to communicate with RabbitMQ. All of
        the commands issued prior to starting the IOLoop will be buffered but
        not processed.
        """
        assert self._connection is not None
        if not self._closing:
            self._closing = True
            LOGGER.info("Stopping")
            # Look at the per-consumer flags: iterating the dict itself would
            # yield its keys, which are truthy as soon as a consumer was ever
            # registered, even after all of them were marked inactive.
            if any(self._consuming.values()):
                self.stop_consuming()
                self._connection.ioloop.start()
            else:
                self._connection.ioloop.stop()
            LOGGER.info("Stopped")

    @staticmethod
    def get_conflicts_func(meth_name: str) -> Callable[[Iterable[Any]], Any]:
        """Return the function merging a batch of items for ``meth_name``.

        Unknown method names are logged and get an identity function.

        :param str meth_name: name of a storage method
        """
        if meth_name in ["content_add", "directory_add"]:
            return resolve_dates
        elif meth_name == "location_add":
            return lambda data: set(data)  # just remove duplicates
        elif meth_name == "origin_add":
            return lambda data: dict(data)  # last processed value is good enough
        elif meth_name == "revision_add":
            return resolve_revision
        elif meth_name == "relation_add":
            return resolve_relation
        else:
            LOGGER.warning(
                "Unexpected conflict resolution function request for method %s",
                meth_name,
            )
            return lambda x: x
class ProvenanceStorageRabbitMQServer:
    """Spawn and supervise one RabbitMQ worker per ``(exchange, range)`` pair.

    The static helpers define how storage method names, relations and range
    indexes map to exchanges, binding keys and routing keys.
    """

    backend_class = ProvenanceStorageInterface
    extra_type_decoders = DECODERS
    extra_type_encoders = ENCODERS

    # Number of range indexes per exchange.
    queue_count = 16

    def __init__(self, url: str, storage_config: Dict[str, Any]) -> None:
        """Create (without starting) every worker process.

        :param str url: The URL for connecting to RabbitMQ
        :param dict storage_config: configuration handed to each worker
        """
        self._workers: List[ProvenanceStorageRabbitMQWorker] = [
            ProvenanceStorageRabbitMQWorker(url, exchange, bucket, storage_config)
            for exchange in ProvenanceStorageRabbitMQServer.get_exchanges()
            for bucket in ProvenanceStorageRabbitMQServer.get_ranges()
        ]
        self._running = False

    def start(self) -> None:
        """Start all workers and wait until each one reports it is consuming.

        Does nothing when already running. If a worker sends no signal within
        60 seconds, the server is stopped again.
        """
        if self._running:
            return
        self._running = True

        for process in self._workers:
            process.start()

        for process in self._workers:
            try:
                signal = process.signal.get(timeout=60)
                assert signal == ServerCommand.CONSUMING
            except queue.Empty:
                LOGGER.error(
                    "Could not initialize worker %s. Leaving...", process.name
                )
                self.stop()
                return
        LOGGER.info("Start serving")

    def stop(self) -> None:
        """Ask every worker to terminate and wait for all of them to exit.

        Does nothing when the server is not running.
        """
        if not self._running:
            return
        for process in self._workers:
            process.command.put(ServerCommand.TERMINATE)
        for process in self._workers:
            process.join()
        LOGGER.info("Stop serving")
        self._running = False

    @staticmethod
    def ack(channel: pika.channel.Channel, delivery_tag: int) -> None:
        """Acknowledge the message identified by ``delivery_tag``."""
        channel.basic_ack(delivery_tag=delivery_tag)

    @staticmethod
    def get_binding_keys(exchange: str, range: int) -> Generator[str, None, None]:
        """Yield the lowercase ``<method>.<relation>.<range in hex>`` keys of
        an exchange; ``unknown`` stands for methods without a relation.
        """
        for meth_name, relation in ProvenanceStorageRabbitMQServer.get_meth_names(
            exchange
        ):
            label = "unknown" if relation is None else relation.value
            yield f"{meth_name}.{label}.{range:x}".lower()

    @staticmethod
    def get_exchange(meth_name: str, relation: Optional[RelationType] = None) -> str:
        """Return the exchange name: the text before the first underscore of
        the relation value for ``relation_add``, of the method name otherwise.
        """
        if meth_name == "relation_add":
            assert relation is not None
            source = relation.value
        else:
            source = meth_name
        return source.split("_")[0]

    @staticmethod
    def get_exchanges() -> Generator[str, None, None]:
        """Yield the value of every ``EntityType``, then ``"location"``."""
        for entity in EntityType:
            yield entity.value
        yield "location"

    @staticmethod
    def get_meth_name(
        binding_key: str,
    ) -> Tuple[str, Optional[RelationType]]:
        """Split a binding key back into its method name and relation.

        The relation is ``None`` when the key's second field is ``unknown``.
        """
        meth_name, label, *_ = binding_key.split(".")
        relation = None if label == "unknown" else RelationType(label)
        return meth_name, relation

    @staticmethod
    def get_meth_names(
        exchange: str,
    ) -> Generator[Tuple[str, Optional[RelationType]], None, None]:
        """Yield the ``(method name, relation)`` pairs served by an exchange.

        Nothing is yielded for an exchange that is not listed.
        """
        methods_by_exchange: Dict[str, List[Tuple[str, Optional[RelationType]]]] = {
            EntityType.CONTENT.value: [
                ("content_add", None),
                ("relation_add", RelationType.CNT_EARLY_IN_REV),
                ("relation_add", RelationType.CNT_IN_DIR),
            ],
            EntityType.DIRECTORY.value: [
                ("directory_add", None),
                ("relation_add", RelationType.DIR_IN_REV),
            ],
            EntityType.ORIGIN.value: [("origin_add", None)],
            EntityType.REVISION.value: [
                ("revision_add", None),
                ("relation_add", RelationType.REV_BEFORE_REV),
                ("relation_add", RelationType.REV_IN_ORG),
            ],
            "location": [("location_add", None)],
        }
        yield from methods_by_exchange.get(exchange, [])

    @staticmethod
    def get_ranges() -> Generator[int, None, None]:
        """Yield every range index, from 0 to ``queue_count - 1``."""
        for index in range(ProvenanceStorageRabbitMQServer.queue_count):
            yield index

    @staticmethod
    def get_routing_key(
        data: Tuple[bytes, ...], meth_name: str, relation: Optional[RelationType] = None
    ) -> str:
        """Return the routing key for ``data``.

        The range index is the first byte of the first element modulo
        ``queue_count``, or 0 when ``data`` or its first element is empty.
        """
        if data and data[0]:
            idx = int(data[0][0]) % ProvenanceStorageRabbitMQServer.queue_count
        else:
            idx = 0
        label = "unknown" if relation is None else relation.value
        return f"{meth_name}.{label}.{idx:x}".lower()

    @staticmethod
    def is_write_method(meth_name: str) -> bool:
        """Tell whether ``meth_name`` contains ``_add``."""
        return "_add" in meth_name

    @staticmethod
    def respond(
        channel: pika.channel.Channel,
        correlation_id: str,
        reply_to: str,
        response: Any,
    ):
        """Publish ``response`` on the default exchange to queue ``reply_to``.

        The body is msgpack-encoded and tagged with ``correlation_id``.
        """
        properties = pika.BasicProperties(
            content_type="application/msgpack",
            correlation_id=correlation_id,
        )
        body = encode_data(
            response,
            extra_encoders=ProvenanceStorageRabbitMQServer.extra_type_encoders,
        )
        channel.basic_publish(
            exchange="",
            routing_key=reply_to,
            properties=properties,
            body=body,
        )
def make_server_from_configfile() -> ProvenanceStorageRabbitMQServer:
    """Build a RabbitMQ server from the file named by ``SWH_CONFIG_FILENAME``.

    The RabbitMQ URL and the storage configuration are both read from the
    ``provenance`` section of the loaded configuration.
    """
    config_path = os.environ.get("SWH_CONFIG_FILENAME")
    provenance_cfg = load_and_check_config(config_path)["provenance"]
    return ProvenanceStorageRabbitMQServer(
        url=provenance_cfg["rabbitmq"]["url"],
        storage_config=provenance_cfg["storage"],
    )