D6165: Add new RabbitMQ-based client/server API
File: D6165.id22648.diff (61 KB)
diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py
--- a/swh/provenance/__init__.py
+++ b/swh/provenance/__init__.py
@@ -104,19 +104,4 @@
assert isinstance(rmq_storage, ProvenanceStorageInterface)
return rmq_storage
- elif cls in ["remote", "restapi"]:
- from .api.client import ProvenanceStorageRPCClient
-
- if cls == "remote":
- warnings.warn(
- '"remote" class is deprecated for provenance storage, please '
- 'use "restapi" class instead.',
- DeprecationWarning,
- )
-
- rpc_storage = ProvenanceStorageRPCClient(**kwargs)
- if TYPE_CHECKING:
- assert isinstance(rpc_storage, ProvenanceStorageInterface)
- return rpc_storage
-
raise ValueError
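
The hunk above drops the deprecated "remote"/"restapi" RPC backend from the storage factory, leaving "rabbitmq" as the remote class. A minimal sketch of obtaining a remote storage through the factory after this change, assuming the keyword arguments introduced elsewhere in this diff; the broker URL and database parameters are illustrative placeholders:

    from swh.provenance import get_provenance_storage

    # The "rabbitmq" class publishes write requests to the server, which
    # applies them to the backend described by storage_config; read methods
    # are served by a direct connection to that same backend.
    storage = get_provenance_storage(
        cls="rabbitmq",
        url="amqp://localhost:5672/%2f",
        storage_config={
            "cls": "postgresql",
            "db": {"host": "localhost", "dbname": "provenance"},
        },
    )
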
diff --git a/swh/provenance/api/client.py b/swh/provenance/api/client.py
--- a/swh/provenance/api/client.py
+++ b/swh/provenance/api/client.py
@@ -5,28 +5,40 @@
import functools
import inspect
-from typing import Any
+import logging
+import queue
+import threading
+import time
+from typing import Any, Dict, Optional
import uuid
import pika
import pika.channel
import pika.connection
+import pika.frame
import pika.spec
-from swh.core.api import RPCClient
from swh.core.api.serializers import encode_data_client as encode_data
from swh.core.api.serializers import msgpack_loads as decode_data
+from swh.provenance import get_provenance_storage
from ..interface import ProvenanceStorageInterface
from .serializers import DECODERS, ENCODERS
+from .server import ProvenanceStorageRabbitMQServer
+LOG_FORMAT = (
+ "%(levelname) -10s %(asctime)s %(name) -30s %(funcName) "
+ "-35s %(lineno) -5d: %(message)s"
+)
+LOGGER = logging.getLogger(__name__)
-class ProvenanceStorageRPCClient(RPCClient):
- """Proxy to a remote provenance storage API"""
- backend_class = ProvenanceStorageInterface
- extra_type_decoders = DECODERS
- extra_type_encoders = ENCODERS
+class ConfigurationError(Exception):
+ pass
+
+
+class ResponseTimeout(Exception):
+ pass
class MetaRabbitMQClient(type):
@@ -43,7 +55,7 @@
break
backend_class = getattr(base, "backend_class", None)
if backend_class:
- for (meth_name, meth) in backend_class.__dict__.items():
+ for meth_name, meth in backend_class.__dict__.items():
if hasattr(meth, "_endpoint_path"):
cls.__add_endpoint(meth_name, meth, attributes)
return super().__new__(cls, name, bases, attributes)
@@ -54,61 +66,478 @@
@functools.wraps(meth) # Copy signature and doc
def meth_(*args, **kwargs):
+ # Match arguments and parameters
+ data = inspect.getcallargs(wrapped_meth, *args, **kwargs)
+
+ # Remove arguments that should not be passed
+ self = data.pop("self")
+
+ # Call storage method with remaining arguments
+ return getattr(self._storage, meth_name)(**data)
+
+ @functools.wraps(meth) # Copy signature and doc
+ def write_meth_(*args, **kwargs):
# Match arguments and parameters
post_data = inspect.getcallargs(wrapped_meth, *args, **kwargs)
# Remove arguments that should not be passed
self = post_data.pop("self")
- post_data.pop("cur", None)
- post_data.pop("db", None)
+ relation = post_data.pop("relation", None)
+ assert len(post_data) == 1
+ if relation is not None:
+ items = [
+ (src, rel.dst, rel.path)
+ for src, dsts in next(iter(post_data.values())).items()
+ for rel in dsts
+ ]
+ else:
+ data = next(iter(post_data.values()))
+ items = (
+ list(data.items())
+ if isinstance(data, dict)
+ else list({(item,) for item in data})
+ )
- # Send the request.
- return self.request(meth_name, **post_data)
+ acks_expected = len(items)
+ self._correlation_id = str(uuid.uuid4())
+ exchange = ProvenanceStorageRabbitMQServer.get_exchange(meth_name, relation)
+ for item in items:
+ routing_key = ProvenanceStorageRabbitMQServer.get_routing_key(
+ item, meth_name, relation
+ )
+ # FIXME: this is running in a different thread! Hence, if
+ # self._connection drops, there is no guarantee that the request can
+ # be sent for the current elements. This situation should be handled
+ # properly.
+ while self._callback_queue is None:
+ time.sleep(0.1)
+ self._connection.ioloop.add_callback_threadsafe(
+ functools.partial(
+ ProvenanceStorageRabbitMQClient.request,
+ channel=self._channel,
+ reply_to=self._callback_queue,
+ exchange=exchange,
+ routing_key=routing_key,
+ correlation_id=self._correlation_id,
+ data=item,
+ )
+ )
+ return self.wait_for_acks(acks_expected)
if meth_name not in attributes:
- attributes[meth_name] = meth_
+ attributes[meth_name] = (
+ write_meth_
+ if ProvenanceStorageRabbitMQServer.is_write_method(meth_name)
+ else meth_
+ )
class ProvenanceStorageRabbitMQClient(metaclass=MetaRabbitMQClient):
+ """This is an example publisher that will handle unexpected interactions
+ with RabbitMQ such as channel and connection closures.
+
+ If RabbitMQ closes the connection, it will reopen it. You should
+ look at the output, as there are limited reasons why the connection may
+ be closed, which usually are tied to permission related issues or
+ socket timeouts.
+
+ It uses delivery confirmations and illustrates one way to keep track of
+ messages that have been sent and if they've been confirmed by RabbitMQ.
+
+ """
+
backend_class = ProvenanceStorageInterface
extra_type_decoders = DECODERS
extra_type_encoders = ENCODERS
- def __init__(self, url: str) -> None:
- self.conn = pika.BlockingConnection(pika.connection.URLParameters(url))
- self.channel = self.conn.channel()
+ def __init__(self, url: str, storage_config: Dict[str, Any]) -> None:
+ """Setup the example publisher object, passing in the URL we will use
+ to connect to RabbitMQ.
+
+ :param str url: The URL for connecting to RabbitMQ
+ :param str routing_key: The routing key name from which this worker will
+ consume messages
+ :param str storage_config: Configuration parameters for the underlying
+ ``ProvenanceStorage`` object
+
+ """
+ self._connection = None
+ self._callback_queue: Optional[str] = None
+ self._channel = None
+ self._closing = False
+ self._consumer_tag = None
+ self._consuming = False
+ self._correlation_id = str(uuid.uuid4())
+ self._prefetch_count = 100
+
+ self._url = url
+ self._storage = get_provenance_storage(**storage_config)
+ self._response_queue: queue.Queue = queue.Queue()
- result = self.channel.queue_declare(queue="", exclusive=True)
- self.callback_queue = result.method.queue
+ self._consumer_thread = threading.Thread(target=self.run)
+ self._consumer_thread.start()
- self.channel.basic_consume(
- queue=self.callback_queue,
- on_message_callback=self.on_response,
- auto_ack=True,
+ def connect(self) -> pika.SelectConnection:
+ """This method connects to RabbitMQ, returning the connection handle.
+ When the connection is established, the on_connection_open method
+ will be invoked by pika.
+
+ :rtype: pika.SelectConnection
+
+ """
+ LOGGER.info("Connecting to %s", self._url)
+ return pika.SelectConnection(
+ parameters=pika.URLParameters(self._url),
+ on_open_callback=self.on_connection_open,
+ on_open_error_callback=self.on_connection_open_error,
+ on_close_callback=self.on_connection_closed,
)
- def request(self, method: str, **kwargs) -> Any:
- self.response = None
- self.corr_id = str(uuid.uuid4())
- self.channel.basic_publish(
- exchange="",
- routing_key=method,
- properties=pika.BasicProperties(
- reply_to=self.callback_queue,
- correlation_id=self.corr_id,
- ),
- body=encode_data(kwargs, extra_encoders=self.extra_type_encoders),
+ def close_connection(self) -> None:
+ assert self._connection is not None
+ self._consuming = False
+ if self._connection.is_closing or self._connection.is_closed:
+ LOGGER.info("Connection is closing or already closed")
+ else:
+ LOGGER.info("Closing connection")
+ self._connection.close()
+
+ def on_connection_open(self, _unused_connection: pika.SelectConnection) -> None:
+ """This method is called by pika once the connection to RabbitMQ has
+ been established. It passes the handle to the connection object in
+ case we need it, but in this case, we'll just mark it unused.
+
+ :param pika.SelectConnection _unused_connection: The connection
+
+ """
+ LOGGER.info("Connection opened")
+ self.open_channel()
+
+ def on_connection_open_error(
+ self, _unused_connection: pika.SelectConnection, err: Exception
+ ) -> None:
+ """This method is called by pika if the connection to RabbitMQ
+ can't be established.
+
+ :param pika.SelectConnection _unused_connection: The connection
+ :param Exception err: The error
+
+ """
+ LOGGER.error("Connection open failed, reopening in 5 seconds: %s", err)
+ assert self._connection is not None
+ self._connection.ioloop.call_later(5, self._connection.ioloop.stop)
+
+ def on_connection_closed(self, _unused_connection: pika.SelectConnection, reason):
+ """This method is invoked by pika when the connection to RabbitMQ is
+ closed unexpectedly. Since it is unexpected, we will reconnect to
+ RabbitMQ if it disconnects.
+
+ :param pika.connection.Connection connection: The closed connection obj
+ :param Exception reason: exception representing reason for loss of
+ connection.
+
+ """
+ assert self._connection is not None
+ self._channel = None
+ if self._closing:
+ self._connection.ioloop.stop()
+ else:
+ LOGGER.warning("Connection closed, reopening in 5 seconds: %s", reason)
+ self._connection.ioloop.call_later(5, self._connection.ioloop.stop)
+
+ def open_channel(self) -> None:
+ """Open a new channel with RabbitMQ by issuing the Channel.Open RPC
+ command. When RabbitMQ responds that the channel is open, the
+ on_channel_open callback will be invoked by pika.
+
+ """
+ LOGGER.info("Creating a new channel")
+ assert self._connection is not None
+ self._connection.channel(on_open_callback=self.on_channel_open)
+
+ def on_channel_open(self, channel: pika.channel.Channel) -> None:
+ """This method is invoked by pika when the channel has been opened.
+ The channel object is passed in so we can make use of it.
+
+ Since the channel is now open, we'll declare the exchange to use.
+
+ :param pika.channel.Channel channel: The channel object
+
+ """
+ LOGGER.info("Channel opened")
+ self._channel = channel
+ self.add_on_channel_close_callback()
+ self.setup_queue()
+
+ def add_on_channel_close_callback(self) -> None:
+ """This method tells pika to call the on_channel_closed method if
+ RabbitMQ unexpectedly closes the channel.
+
+ """
+ LOGGER.info("Adding channel close callback")
+ assert self._channel is not None
+ self._channel.add_on_close_callback(callback=self.on_channel_closed)
+
+ def on_channel_closed(
+ self, channel: pika.channel.Channel, reason: Exception
+ ) -> None:
+ """Invoked by pika when RabbitMQ unexpectedly closes the channel.
+ Channels are usually closed if you attempt to do something that
+ violates the protocol, such as re-declare an exchange or queue with
+ different parameters. In this case, we'll close the connection
+ to shutdown the object.
+
+ :param pika.channel.Channel channel: The closed channel
+ :param Exception reason: why the channel was closed
+
+ """
+ LOGGER.warning("Channel %i was closed: %s", channel, reason)
+ self.close_connection()
+
+ def setup_queue(self) -> None:
+ """Setup the queue on RabbitMQ by invoking the Queue.Declare RPC
+ command. When it is complete, the on_queue_declareok method will
+ be invoked by pika.
+
+ """
+ LOGGER.info("Declaring callback queue")
+ assert self._channel is not None
+ self._channel.queue_declare(
+ queue="", exclusive=True, callback=self.on_queue_declareok
)
- while self.response is None:
- self.conn.process_data_events()
- return self.response
+
+ def on_queue_declareok(self, frame: pika.frame.Method) -> None:
+ """Method invoked by pika when the Queue.Declare RPC call made in
+ setup_queue has completed. This method sets up the consumer prefetch
+ to only be delivered one message at a time. The consumer must
+ acknowledge this message before RabbitMQ will deliver another one.
+ You should experiment with different prefetch values to achieve desired
+ performance.
+
+ :param pika.frame.Method method_frame: The Queue.DeclareOk frame
+
+ """
+ LOGGER.info("Binding queue to default exchanger")
+ assert self._channel is not None
+ self._callback_queue = frame.method.queue
+ self._channel.basic_qos(
+ prefetch_count=self._prefetch_count, callback=self.on_basic_qos_ok
+ )
+
+ def on_basic_qos_ok(self, _unused_frame: pika.frame.Method) -> None:
+ """Invoked by pika when the Basic.QoS method has completed. At this
+ point we will start consuming messages by calling start_consuming
+ which will invoke the needed RPC commands to start the process.
+
+ :param pika.frame.Method _unused_frame: The Basic.QosOk response frame
+
+ """
+ LOGGER.info("QOS set to: %d", self._prefetch_count)
+ self.start_consuming()
+
+ def start_consuming(self) -> None:
+ """This method sets up the consumer by first calling
+ add_on_cancel_callback so that the object is notified if RabbitMQ
+ cancels the consumer. It then issues the Basic.Consume RPC command
+ which returns the consumer tag that is used to uniquely identify the
+ consumer with RabbitMQ. We keep the value to use it when we want to
+ cancel consuming. The on_response method is passed in as a callback pika
+ will invoke when a message is fully received.
+
+ """
+ LOGGER.info("Issuing consumer related RPC commands")
+ assert self._channel is not None
+ self.add_on_cancel_callback()
+ assert self._callback_queue is not None
+ self._consumer_tag = self._channel.basic_consume(
+ queue=self._callback_queue, on_message_callback=self.on_response
+ )
+ self._consuming = True
+
+ def add_on_cancel_callback(self) -> None:
+ """Add a callback that will be invoked if RabbitMQ cancels the consumer
+ for some reason. If RabbitMQ does cancel the consumer,
+ on_consumer_cancelled will be invoked by pika.
+
+ """
+ LOGGER.info("Adding consumer cancellation callback")
+ assert self._channel is not None
+ self._channel.add_on_cancel_callback(callback=self.on_consumer_cancelled)
+
+ def on_consumer_cancelled(self, method_frame: pika.frame.Method) -> None:
+ """Invoked by pika when RabbitMQ sends a Basic.Cancel for a consumer
+ receiving messages.
+
+ :param pika.frame.Method method_frame: The Basic.Cancel frame
+
+ """
+ LOGGER.info("Consumer was cancelled remotely, shutting down: %r", method_frame)
+ if self._channel:
+ self._channel.close()
def on_response(
self,
- channel: pika.channel.Channel,
- method: pika.spec.Basic.Deliver,
- props: pika.spec.BasicProperties,
+ _unused_channel: pika.channel.Channel,
+ basic_deliver: pika.spec.Basic.Deliver,
+ properties: pika.spec.BasicProperties,
body: bytes,
) -> None:
- if self.corr_id == props.correlation_id:
- self.response = decode_data(body, extra_decoders=self.extra_type_decoders)
+ """Invoked by pika when a message is delivered from RabbitMQ. The
+ channel is passed for your convenience. The basic_deliver object that
+ is passed in carries the exchange, routing key, delivery tag and
+ a redelivered flag for the message. The properties passed in is an
+ instance of BasicProperties with the message properties and the body
+ is the message that was sent.
+
+ :param pika.channel.Channel _unused_channel: The channel object
+ :param pika.spec.Basic.Deliver basic_deliver: The Basic.Deliver method
+ :param pika.spec.BasicProperties properties: The message properties
+ :param bytes body: The message body
+
+ """
+ LOGGER.info(
+ "Received message # %s from %s: %s",
+ basic_deliver.delivery_tag,
+ properties.app_id,
+ body,
+ )
+ self._response_queue.put(
+ (
+ properties.correlation_id,
+ decode_data(body, extra_decoders=self.extra_type_decoders),
+ )
+ )
+ self.acknowledge_message(delivery_tag=basic_deliver.delivery_tag)
+
+ def acknowledge_message(self, delivery_tag: int) -> None:
+ """Acknowledge the message delivery from RabbitMQ by sending a
+ Basic.Ack RPC method for the delivery tag.
+
+ :param int delivery_tag: The delivery tag from the Basic.Deliver frame
+
+ """
+ LOGGER.info("Acknowledging message %s", delivery_tag)
+ assert self._channel is not None
+ self._channel.basic_ack(delivery_tag=delivery_tag)
+
+ def stop_consuming(self) -> None:
+ """Tell RabbitMQ that you would like to stop consuming by sending the
+ Basic.Cancel RPC command.
+
+ """
+ if self._channel:
+ LOGGER.info("Sending a Basic.Cancel RPC command to RabbitMQ")
+ self._channel.basic_cancel(self._consumer_tag, self.on_cancelok)
+
+ def on_cancelok(self, _unused_frame: pika.frame.Method) -> None:
+ """This method is invoked by pika when RabbitMQ acknowledges the
+ cancellation of a consumer. At this point we will close the channel.
+ This will invoke the on_channel_closed method once the channel has been
+ closed, which will in turn close the connection.
+
+ :param pika.frame.Method _unused_frame: The Basic.CancelOk frame
+
+ """
+ self._consuming = False
+ LOGGER.info(
+ "RabbitMQ acknowledged the cancellation of the consumer: %s",
+ self._consumer_tag,
+ )
+ self.close_channel()
+
+ def close_channel(self) -> None:
+ """Call to close the channel with RabbitMQ cleanly by issuing the
+ Channel.Close RPC command.
+
+ """
+ LOGGER.info("Closing the channel")
+ assert self._channel is not None
+ self._channel.close()
+
+ def run(self) -> None:
+ """Run the example code by connecting and then starting the IOLoop."""
+
+ while not self._closing:
+ try:
+ self._connection = self.connect()
+ assert self._connection is not None
+ self._connection.ioloop.start()
+ except KeyboardInterrupt:
+ self.stop()
+ if self._connection is not None and not self._connection.is_closed:
+ # Finish closing
+ self._connection.ioloop.start()
+ LOGGER.info("Stopped")
+
+ def stop(self) -> None:
+ """Cleanly shutdown the connection to RabbitMQ by stopping the consumer
+ with RabbitMQ. When RabbitMQ confirms the cancellation, on_cancelok
+ will be invoked by pika, which will then close the channel and
+ connection. The IOLoop is started again because this method is invoked
+ when CTRL-C is pressed raising a KeyboardInterrupt exception. This
+ exception stops the IOLoop which needs to be running for pika to
+ communicate with RabbitMQ. All of the commands issued prior to starting
+ the IOLoop will be buffered but not processed.
+
+ """
+ assert self._connection is not None
+ if not self._closing:
+ self._closing = True
+ LOGGER.info("Stopping")
+ if self._consuming:
+ self.stop_consuming()
+ self._connection.ioloop.start()
+ else:
+ self._connection.ioloop.stop()
+ LOGGER.info("Stopped")
+
+ @staticmethod
+ def request(
+ channel: pika.channel.Channel,
+ reply_to: str,
+ exchange: str,
+ routing_key: str,
+ correlation_id: str,
+ **kwargs,
+ ) -> None:
+ channel.basic_publish(
+ exchange=exchange,
+ routing_key=routing_key,
+ properties=pika.BasicProperties(
+ content_type="application/msgpack",
+ correlation_id=correlation_id,
+ reply_to=reply_to,
+ ),
+ body=encode_data(
+ kwargs,
+ extra_encoders=ProvenanceStorageRabbitMQClient.extra_type_encoders,
+ ),
+ )
+
+ def wait_for_acks(self, acks_expected: int) -> bool:
+ acks_received = 0
+ while acks_received < acks_expected:
+ try:
+ acks_received += self.wait_for_response()
+ except ResponseTimeout:
+ LOGGER.warning(
+ "Timed out waiting for acks, %s received, %s expected",
+ acks_received,
+ acks_expected,
+ )
+ return False
+ return acks_received == acks_expected
+
+ def wait_for_response(self, timeout: float = 60.0) -> Any:
+ start = time.monotonic()
+ while True:
+ try:
+ correlation_id, response = self._response_queue.get(block=False)
+ if correlation_id == self._correlation_id:
+ return response
+ except queue.Empty:
+ pass
+
+ if self._response_queue.empty() and time.monotonic() > start + timeout:
+ raise ResponseTimeout
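
In the write wrapper above, each batch of items gets a fresh correlation id, every item is published individually, and the call blocks until one acknowledgement per item has come back through the callback queue. The following self-contained sketch isolates that ack-counting protocol from pika; all names here are illustrative and not part of the patch:

    import queue
    import time
    import uuid

    class AckWaiter:
        """Sketch of the client's ack counting: responses arrive on a
        thread-safe queue as (correlation_id, payload) pairs and are matched
        against the id of the current batch of requests."""

        def __init__(self) -> None:
            self._responses: queue.Queue = queue.Queue()
            self._correlation_id = str(uuid.uuid4())

        def deliver(self, correlation_id: str, payload: int) -> None:
            # Called from the consumer thread (on_response in the real client).
            self._responses.put((correlation_id, payload))

        def wait_for_acks(self, expected: int, timeout: float = 60.0) -> bool:
            received = 0
            deadline = time.monotonic() + timeout
            while received < expected and time.monotonic() < deadline:
                try:
                    correlation_id, payload = self._responses.get(timeout=0.1)
                except queue.Empty:
                    continue
                if correlation_id == self._correlation_id:
                    # The server replies with the number of items it processed.
                    received += payload
            return received == expected

The real wait_for_response additionally raises ResponseTimeout when nothing relevant arrives in time, which wait_for_acks turns into a False return value.
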
diff --git a/swh/provenance/api/server.py b/swh/provenance/api/server.py
--- a/swh/provenance/api/server.py
+++ b/swh/provenance/api/server.py
@@ -3,85 +3,727 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
+from collections import Counter
+from datetime import datetime
+import functools
import logging
+import multiprocessing
import os
-from typing import Any, Dict, List, Optional
+import queue
+import threading
+import time
+from typing import Any, Callable
+from typing import Counter as TCounter
+from typing import Dict, Generator, Iterable, List, Optional, Set, Tuple, Union, cast
import pika
import pika.channel
import pika.connection
+import pika.exceptions
+from pika.exchange_type import ExchangeType
+import pika.frame
import pika.spec
-from werkzeug.routing import Rule
from swh.core import config
-from swh.core.api import JSONFormatter, MsgpackFormatter, RPCServerApp, negotiate
from swh.core.api.serializers import encode_data_client as encode_data
from swh.core.api.serializers import msgpack_loads as decode_data
+from swh.model.model import Sha1Git
from swh.provenance import get_provenance_storage
-from swh.provenance.interface import ProvenanceStorageInterface
+from swh.provenance.interface import (
+ EntityType,
+ ProvenanceStorageInterface,
+ RelationData,
+ RelationType,
+ RevisionData,
+)
from .serializers import DECODERS, ENCODERS
-storage: Optional[ProvenanceStorageInterface] = None
+LOG_FORMAT = (
+ "%(levelname) -10s %(asctime)s %(name) -30s %(funcName) "
+ "-35s %(lineno) -5d: %(message)s"
+)
+LOGGER = logging.getLogger(__name__)
+
+TERMINATE = object()
-def get_global_provenance_storage() -> ProvenanceStorageInterface:
- global storage
- if storage is None:
- storage = get_provenance_storage(**app.config["provenance"]["storage"])
- return storage
+def resolve_dates(
+ dates: Iterable[Union[Tuple[Sha1Git, Optional[datetime]], Tuple[Sha1Git]]]
+) -> Dict[Sha1Git, Optional[datetime]]:
+ result: Dict[Sha1Git, Optional[datetime]] = {}
+ for row in dates:
+ sha1 = row[0]
+ date = (
+ cast(Tuple[Sha1Git, Optional[datetime]], row)[1] if len(row) > 1 else None
+ )
+ known = result.setdefault(sha1, None)
+ if date is not None and (known is None or date < known):
+ result[sha1] = date
+ return result
+
+
+def resolve_revision(
+ data: Iterable[Union[Tuple[Sha1Git, RevisionData], Tuple[Sha1Git]]]
+) -> Dict[Sha1Git, RevisionData]:
+ result: Dict[Sha1Git, RevisionData] = {}
+ for row in data:
+ sha1 = row[0]
+ rev = (
+ cast(Tuple[Sha1Git, RevisionData], row)[1]
+ if len(row) > 1
+ else RevisionData(date=None, origin=None)
+ )
+ known = result.setdefault(sha1, RevisionData(date=None, origin=None))
+ value = known
+ if rev.date is not None and (known.date is None or rev.date < known.date):
+ value = RevisionData(date=rev.date, origin=value.origin)
+ if rev.origin is not None:
+ value = RevisionData(date=value.date, origin=rev.origin)
+ if value != known:
+ result[sha1] = value
+ return result
+
+
+def resolve_relation(
+ data: Iterable[Tuple[Sha1Git, Sha1Git, bytes]]
+) -> Dict[Sha1Git, Set[RelationData]]:
+ result: Dict[Sha1Git, Set[RelationData]] = {}
+ for src, dst, path in data:
+ result.setdefault(src, set()).add(RelationData(dst=dst, path=path))
+ return result
+
+
+class ProvenanceStorageRabbitMQWorker(multiprocessing.Process):
+ """This is an example publisher that will handle unexpected interactions
+ with RabbitMQ such as channel and connection closures.
+
+ If RabbitMQ closes the connection, it will reopen it. You should
+ look at the output, as there are limited reasons why the connection may
+ be closed, which usually are tied to permission related issues or
+ socket timeouts.
+
+ It uses delivery confirmations and illustrates one way to keep track of
+ messages that have been sent and if they've been confirmed by RabbitMQ.
+
+ """
+ EXCHANGE_TYPE = ExchangeType.topic
-class ProvenanceStorageRPCServerApp(RPCServerApp):
extra_type_decoders = DECODERS
extra_type_encoders = ENCODERS
+ def __init__(
+ self, url: str, exchange: str, range: int, storage_config: Dict[str, Any]
+ ) -> None:
+ """Setup the example publisher object, passing in the URL we will use
+ to connect to RabbitMQ.
+
+ :param str url: The URL for connecting to RabbitMQ
+ :param str routing_key: The routing key name from which this worker will
+ consume messages
+ :param str storage_config: Configuration parameters for the underlying
+ ``ProvenanceStorage`` object
+
+ """
+ super().__init__(name=f"{exchange}_{range:x}")
+
+ self._connection = None
+ self._channel = None
+ self._closing = False
+ self._consumer_tag: Dict[str, str] = {}
+ self._consuming: Dict[str, bool] = {}
+ self._prefetch_count = 100
+
+ self._url = url
+ self._exchange = exchange
+ self._binding_keys = list(
+ ProvenanceStorageRabbitMQServer.get_binding_keys(self._exchange, range)
+ )
+ self._queues: Dict[str, str] = {}
+ self._storage_config = storage_config
+ self._batch_size = 100
+
+ def connect(self) -> pika.SelectConnection:
+ """This method connects to RabbitMQ, returning the connection handle.
+ When the connection is established, the on_connection_open method
+ will be invoked by pika.
+
+ :rtype: pika.SelectConnection
+
+ """
+ LOGGER.info("Connecting to %s", self._url)
+ return pika.SelectConnection(
+ parameters=pika.URLParameters(self._url),
+ on_open_callback=self.on_connection_open,
+ on_open_error_callback=self.on_connection_open_error,
+ on_close_callback=self.on_connection_closed,
+ )
-app = ProvenanceStorageRPCServerApp(
- __name__,
- backend_class=ProvenanceStorageInterface,
- backend_factory=get_global_provenance_storage,
-)
+ def close_connection(self) -> None:
+ assert self._connection is not None
+ self._consuming = {binding_key: False for binding_key in self._binding_keys}
+ if self._connection.is_closing or self._connection.is_closed:
+ LOGGER.info("Connection is closing or already closed")
+ else:
+ LOGGER.info("Closing connection")
+ self._connection.close()
+ def on_connection_open(self, _unused_connection: pika.SelectConnection) -> None:
+ """This method is called by pika once the connection to RabbitMQ has
+ been established. It passes the handle to the connection object in
+ case we need it, but in this case, we'll just mark it unused.
-def has_no_empty_params(rule: Rule) -> bool:
- return len(rule.defaults or ()) >= len(rule.arguments or ())
-
-
-@app.route("/")
-def index() -> str:
- return """<html>
-<head><title>Software Heritage provenance storage RPC server</title></head>
-<body>
-<p>You have reached the
-<a href="https://www.softwareheritage.org/">Software Heritage</a>
-provenance storage RPC server.<br />
-See its
-<a href="https://docs.softwareheritage.org/devel/swh-provenance/">documentation
-and API</a> for more information</p>
-</body>
-</html>"""
-
-
-@app.route("/site-map")
-@negotiate(MsgpackFormatter)
-@negotiate(JSONFormatter)
-def site_map() -> List[Dict[str, Any]]:
- links = []
- for rule in app.url_map.iter_rules():
- if has_no_empty_params(rule) and hasattr(
- ProvenanceStorageInterface, rule.endpoint
- ):
- links.append(
- dict(
- rule=rule.rule,
- description=getattr(
- ProvenanceStorageInterface, rule.endpoint
- ).__doc__,
+ :param pika.SelectConnection _unused_connection: The connection
+
+ """
+ LOGGER.info("Connection opened")
+ self.open_channel()
+
+ def on_connection_open_error(
+ self, _unused_connection: pika.SelectConnection, err: Exception
+ ) -> None:
+ """This method is called by pika if the connection to RabbitMQ
+ can't be established.
+
+ :param pika.SelectConnection _unused_connection: The connection
+ :param Exception err: The error
+
+ """
+ LOGGER.error("Connection open failed, reopening in 5 seconds: %s", err)
+ assert self._connection is not None
+ self._connection.ioloop.call_later(5, self._connection.ioloop.stop)
+
+ def on_connection_closed(self, _unused_connection: pika.SelectConnection, reason):
+ """This method is invoked by pika when the connection to RabbitMQ is
+ closed unexpectedly. Since it is unexpected, we will reconnect to
+ RabbitMQ if it disconnects.
+
+ :param pika.connection.Connection connection: The closed connection obj
+ :param Exception reason: exception representing reason for loss of
+ connection.
+
+ """
+ assert self._connection is not None
+ self._channel = None
+ if self._closing:
+ self._connection.ioloop.stop()
+ else:
+ LOGGER.warning("Connection closed, reopening in 5 seconds: %s", reason)
+ self._connection.ioloop.call_later(5, self._connection.ioloop.stop)
+
+ def open_channel(self) -> None:
+ """Open a new channel with RabbitMQ by issuing the Channel.Open RPC
+ command. When RabbitMQ responds that the channel is open, the
+ on_channel_open callback will be invoked by pika.
+
+ """
+ LOGGER.info("Creating a new channel")
+ assert self._connection is not None
+ self._connection.channel(on_open_callback=self.on_channel_open)
+
+ def on_channel_open(self, channel: pika.channel.Channel) -> None:
+ """This method is invoked by pika when the channel has been opened.
+ The channel object is passed in so we can make use of it.
+
+ Since the channel is now open, we'll declare the exchange to use.
+
+ :param pika.channel.Channel channel: The channel object
+
+ """
+ LOGGER.info("Channel opened")
+ self._channel = channel
+ self.add_on_channel_close_callback()
+ self.setup_exchange()
+
+ def add_on_channel_close_callback(self) -> None:
+ """This method tells pika to call the on_channel_closed method if
+ RabbitMQ unexpectedly closes the channel.
+
+ """
+ LOGGER.info("Adding channel close callback")
+ assert self._channel is not None
+ self._channel.add_on_close_callback(callback=self.on_channel_closed)
+
+ def on_channel_closed(
+ self, channel: pika.channel.Channel, reason: Exception
+ ) -> None:
+ """Invoked by pika when RabbitMQ unexpectedly closes the channel.
+ Channels are usually closed if you attempt to do something that
+ violates the protocol, such as re-declare an exchange or queue with
+ different parameters. In this case, we'll close the connection
+ to shutdown the object.
+
+ :param pika.channel.Channel channel: The closed channel
+ :param Exception reason: why the channel was closed
+
+ """
+ LOGGER.warning("Channel %i was closed: %s", channel, reason)
+ self.close_connection()
+
+ def setup_exchange(self) -> None:
+ """Setup the exchange on RabbitMQ by invoking the Exchange.Declare RPC
+ command. When it is complete, the on_exchange_declareok method will
+ be invoked by pika.
+
+ """
+ LOGGER.info("Declaring exchange %s", self._exchange)
+ assert self._channel is not None
+ self._channel.exchange_declare(
+ exchange=self._exchange,
+ exchange_type=self.EXCHANGE_TYPE,
+ callback=self.on_exchange_declareok,
+ )
+
+ def on_exchange_declareok(self, _unused_frame: pika.frame.Method) -> None:
+ """Invoked by pika when RabbitMQ has finished the Exchange.Declare RPC
+ command.
+
+ :param pika.frame.Method unused_frame: Exchange.DeclareOk response frame
+
+ """
+ LOGGER.info("Exchange declared: %s", self._exchange)
+ self.setup_queues()
+
+ def setup_queues(self) -> None:
+ """Setup the queues on RabbitMQ by invoking the Queue.Declare RPC
+ command. When it is complete, the on_queue_declareok method will
+ be invoked by pika.
+
+ """
+ for binding_key in self._binding_keys:
+ LOGGER.info("Declaring queue %s", binding_key)
+ assert self._channel is not None
+ callback = functools.partial(
+ self.on_queue_declareok,
+ binding_key=binding_key,
+ )
+ self._channel.queue_declare(queue="", exclusive=True, callback=callback)
+
+ def on_queue_declareok(self, frame: pika.frame.Method, binding_key: str) -> None:
+ """Method invoked by pika when the Queue.Declare RPC call made in
+ setup_queue has completed. In this method we will bind the queue
+ and exchange together with the routing key by issuing the Queue.Bind
+ RPC command. When this command is complete, the on_bindok method will
+ be invoked by pika.
+
+ :param pika.frame.Method frame: The Queue.DeclareOk frame
+ :param str|unicode binding_key: Binding key of the queue to declare
+
+ """
+ LOGGER.info(
+ "Binding queue %s to exchange %s with routing key %s",
+ frame.method.queue,
+ self._exchange,
+ binding_key,
+ )
+ assert self._channel is not None
+ callback = functools.partial(self.on_bindok, queue_name=frame.method.queue)
+ self._queues[binding_key] = frame.method.queue
+ self._channel.queue_bind(
+ queue=frame.method.queue,
+ exchange=self._exchange,
+ routing_key=binding_key,
+ callback=callback,
+ )
+
+ def on_bindok(self, _unused_frame: pika.frame.Method, queue_name: str) -> None:
+ """Invoked by pika when the Queue.Bind method has completed. At this
+ point we will set the prefetch count for the channel.
+
+ :param pika.frame.Method _unused_frame: The Queue.BindOk response frame
+ :param str|unicode queue_name: The name of the queue to declare
+
+ """
+ LOGGER.info("Queue bound: %s", queue_name)
+ self.set_qos()
+
+ def set_qos(self) -> None:
+ """This method sets up the consumer prefetch to only be delivered
+ one message at a time. The consumer must acknowledge this message
+ before RabbitMQ will deliver another one. You should experiment
+ with different prefetch values to achieve desired performance.
+
+ """
+ assert self._channel is not None
+ self._channel.basic_qos(
+ prefetch_count=self._prefetch_count, callback=self.on_basic_qos_ok
+ )
+
+ def on_basic_qos_ok(self, _unused_frame: pika.frame.Method) -> None:
+ """Invoked by pika when the Basic.QoS method has completed. At this
+ point we will start consuming messages by calling start_consuming
+ which will invoke the needed RPC commands to start the process.
+
+ :param pika.frame.Method _unused_frame: The Basic.QosOk response frame
+
+ """
+ LOGGER.info("QOS set to: %d", self._prefetch_count)
+ self.start_consuming()
+
+ def start_consuming(self) -> None:
+ """This method sets up the consumer by first calling
+ add_on_cancel_callback so that the object is notified if RabbitMQ
+ cancels the consumer. It then issues the Basic.Consume RPC command
+ which returns the consumer tag that is used to uniquely identify the
+ consumer with RabbitMQ. We keep the value to use it when we want to
+ cancel consuming. The on_request method is passed in as a callback pika
+ will invoke when a message is fully received.
+
+ """
+ LOGGER.info("Issuing consumer related RPC commands")
+ LOGGER.info("Adding consumer cancellation callback")
+ assert self._channel is not None
+ self._channel.add_on_cancel_callback(callback=self.on_consumer_cancelled)
+ for binding_key in self._binding_keys:
+ self._consumer_tag[binding_key] = self._channel.basic_consume(
+ queue=self._queues[binding_key], on_message_callback=self.on_request
+ )
+ self._consuming[binding_key] = True
+
+ def on_consumer_cancelled(self, method_frame: pika.frame.Method) -> None:
+ """Invoked by pika when RabbitMQ sends a Basic.Cancel for a consumer
+ receiving messages.
+
+ :param pika.frame.Method method_frame: The Basic.Cancel frame
+
+ """
+ LOGGER.info("Consumer was cancelled remotely, shutting down: %r", method_frame)
+ if self._channel:
+ self._channel.close()
+
+ def on_request(
+ self,
+ channel: pika.channel.Channel,
+ deliver: pika.spec.Basic.Deliver,
+ properties: pika.spec.BasicProperties,
+ body: bytes,
+ ) -> None:
+ """Invoked by pika when a message is delivered from RabbitMQ. The
+ channel is passed for your convenience. The basic_deliver object that
+ is passed in carries the exchange, routing key, delivery tag and
+ a redelivered flag for the message. The properties passed in is an
+ instance of BasicProperties with the message properties and the body
+ is the message that was sent.
+
+ :param pika.channel.Channel channel: The channel object
+ :param pika.spec.Basic.Deliver deliver: The Basic.Deliver method
+ :param pika.spec.BasicProperties properties: The message properties
+ :param bytes body: The message body
+
+ """
+ LOGGER.info(
+ "Received message # %s from %s: %s",
+ deliver.delivery_tag,
+ properties.app_id,
+ body,
+ )
+ item = decode_data(data=body, extra_decoders=self.extra_type_decoders)["data"]
+ self._request_queues[deliver.routing_key].put(
+ (tuple(item), (properties.correlation_id, properties.reply_to))
+ )
+ LOGGER.info("Acknowledging message %s", deliver.delivery_tag)
+ channel.basic_ack(delivery_tag=deliver.delivery_tag)
+
+ def stop_consuming(self) -> None:
+ """Tell RabbitMQ that you would like to stop consuming by sending the
+ Basic.Cancel RPC command.
+
+ """
+ if self._channel:
+ LOGGER.info("Sending a Basic.Cancel RPC command to RabbitMQ")
+ for binding_key in self._binding_keys:
+ callback = functools.partial(self.on_cancelok, binding_key=binding_key)
+ self._channel.basic_cancel(
+ self._consumer_tag[binding_key], callback=callback
)
+
+ def on_cancelok(self, _unused_frame: pika.frame.Method, binding_key: str) -> None:
+ """This method is invoked by pika when RabbitMQ acknowledges the
+ cancellation of a consumer. At this point we will close the channel.
+ This will invoke the on_channel_closed method once the channel has been
+ closed, which will in turn close the connection.
+
+ :param pika.frame.Method _unused_frame: The Basic.CancelOk frame
+ :param str|unicode binding_key: Binding key of the consumer to be stopped
+
+ """
+ self._consuming[binding_key] = False
+ LOGGER.info(
+ "RabbitMQ acknowledged the cancellation of the consumer: %s",
+ self._consumer_tag[binding_key],
+ )
+ LOGGER.info("Closing the channel")
+ assert self._channel is not None
+ self._channel.close()
+
+ def run(self) -> None:
+ """Run the example code by connecting and then starting the IOLoop."""
+
+ self._request_queues: Dict[str, queue.Queue] = {}
+ self._storage_threads: Dict[str, threading.Thread] = {}
+ for binding_key in self._binding_keys:
+ meth_name, relation = ProvenanceStorageRabbitMQServer.get_meth_name(
+ binding_key
)
- # links is now a list of url, endpoint tuples
- return links
+
+ self._request_queues[binding_key] = queue.Queue()
+ self._storage_threads[binding_key] = threading.Thread(
+ target=self.run_storage_thread,
+ args=(binding_key, meth_name, relation),
+ )
+ self._storage_threads[binding_key].start()
+
+ while not self._closing:
+ try:
+ self._connection = self.connect()
+ assert self._connection is not None
+ self._connection.ioloop.start()
+ except KeyboardInterrupt:
+ self.stop()
+ if self._connection is not None and not self._connection.is_closed:
+ # Finish closing
+ self._connection.ioloop.start()
+
+ for binding_key in self._binding_keys:
+ self._request_queues[binding_key].put(TERMINATE)
+ self._storage_threads[binding_key].join()
+ LOGGER.info("Stopped")
+
+ def run_storage_thread(
+ self, binding_key: str, meth_name: str, relation: Optional[RelationType]
+ ) -> None:
+ storage = get_provenance_storage(**self._storage_config)
+
+ request_queue = self._request_queues[binding_key]
+ merge_items = ProvenanceStorageRabbitMQWorker.get_conflicts_func(meth_name)
+
+ while True:
+ terminate = False
+ elements = []
+ while True:
+ try:
+ # TODO: consider reducing this timeout or removing it
+ elem = request_queue.get(timeout=0.1)
+ if elem is TERMINATE:
+ terminate = True
+ break
+ elements.append(elem)
+ except queue.Empty:
+ break
+
+ if len(elements) >= self._batch_size:
+ break
+
+ if terminate:
+ break
+
+ if not elements:
+ continue
+
+ items, props = zip(*elements)
+ acks_count: TCounter[Tuple[str, str]] = Counter(props)
+ data = merge_items(items)
+ if not data:
+ LOGGER.debug("Elements: %s", elements)
+ LOGGER.debug("Props: %s", props)
+ LOGGER.debug("Items: %s", items)
+ LOGGER.debug("Data: %s", data)
+
+ args = (relation, data) if relation is not None else (data,)
+ if getattr(storage, meth_name)(*args):
+ for (correlation_id, reply_to), count in acks_count.items():
+ # FIXME: this is running in a different thread! Hence, if
+ # self._connection drops, there is no guarantee that the response
+ # can be sent for the current elements. This situation should be
+ # handled properly.
+ assert self._connection is not None
+ self._connection.ioloop.add_callback_threadsafe(
+ functools.partial(
+ ProvenanceStorageRabbitMQServer.respond,
+ channel=self._channel,
+ correlation_id=correlation_id,
+ reply_to=reply_to,
+ response=count,
+ )
+ )
+ else:
+ LOGGER.warning("Unable to process elements for queue %s", binding_key)
+ for elem in elements:
+ request_queue.put(elem)
+
+ def stop(self) -> None:
+ """Cleanly shutdown the connection to RabbitMQ by stopping the consumer
+ with RabbitMQ. When RabbitMQ confirms the cancellation, on_cancelok
+ will be invoked by pika, which will then close the channel and
+ connection. The IOLoop is started again because this method is invoked
+ when CTRL-C is pressed raising a KeyboardInterrupt exception. This
+ exception stops the IOLoop which needs to be running for pika to
+ communicate with RabbitMQ. All of the commands issued prior to starting
+ the IOLoop will be buffered but not processed.
+
+ """
+ assert self._connection is not None
+ if not self._closing:
+ self._closing = True
+ LOGGER.info("Stopping")
+ if any(self._consuming):
+ self.stop_consuming()
+ self._connection.ioloop.start()
+ else:
+ self._connection.ioloop.stop()
+ LOGGER.info("Stopped")
+
+ @staticmethod
+ def get_conflicts_func(meth_name: str) -> Callable[[Iterable[Any]], Any]:
+ if meth_name in ["content_add", "directory_add"]:
+ return resolve_dates
+ elif meth_name == "location_add":
+ return lambda data: set(data) # just remove duplicates
+ elif meth_name == "origin_add":
+ return lambda data: dict(data) # last processed value is good enough
+ elif meth_name == "revision_add":
+ return resolve_revision
+ elif meth_name == "relation_add":
+ return resolve_relation
+ else:
+ LOGGER.warning(
+ "Unexpected conflict resolution function request for method %s",
+ meth_name,
+ )
+ return lambda x: x
+
+
+class ProvenanceStorageRabbitMQServer:
+ backend_class = ProvenanceStorageInterface
+ extra_type_decoders = DECODERS
+ extra_type_encoders = ENCODERS
+
+ queue_count = 16
+
+ def __init__(self, url: str, storage_config: Dict[str, Any]) -> None:
+ workers: List[ProvenanceStorageRabbitMQWorker] = []
+ for exchange in ProvenanceStorageRabbitMQServer.get_exchanges():
+ for range in ProvenanceStorageRabbitMQServer.get_ranges():
+ worker = ProvenanceStorageRabbitMQWorker(
+ url, exchange, range, storage_config
+ )
+ worker.start()
+ workers.append(worker)
+
+ LOGGER.info("Start consuming")
+ while True:
+ try:
+ time.sleep(1.0)
+ except KeyboardInterrupt:
+ break
+ LOGGER.info("Stop consuming")
+ for worker in workers:
+ worker.terminate()
+ worker.join()
+
+ @staticmethod
+ def ack(channel: pika.channel.Channel, delivery_tag: int) -> None:
+ channel.basic_ack(delivery_tag=delivery_tag)
+
+ @staticmethod
+ def get_binding_keys(exchange: str, range: int) -> Generator[str, None, None]:
+ for meth_name, relation in ProvenanceStorageRabbitMQServer.get_meth_names(
+ exchange
+ ):
+ if relation is None:
+ yield f"{meth_name}.unknown.{range:x}".lower()
+ else:
+ yield f"{meth_name}.{relation.value}.{range:x}".lower()
+
+ @staticmethod
+ def get_exchange(meth_name: str, relation: Optional[RelationType] = None) -> str:
+ if meth_name == "relation_add":
+ assert relation is not None
+ split = relation.value
+ else:
+ split = meth_name
+ exchange, *_ = split.split("_")
+ return exchange
+
+ @staticmethod
+ def get_exchanges() -> Generator[str, None, None]:
+ yield from [entity.value for entity in EntityType] + ["location"]
+
+ @staticmethod
+ def get_meth_name(
+ binding_key: str,
+ ) -> Tuple[str, Optional[RelationType]]:
+ meth_name, relation, *_ = binding_key.split(".")
+ return meth_name, (RelationType(relation) if relation != "unknown" else None)
+
+ @staticmethod
+ def get_meth_names(
+ exchange: str,
+ ) -> Generator[Tuple[str, Optional[RelationType]], None, None]:
+ if exchange == EntityType.CONTENT.value:
+ yield from [
+ ("content_add", None),
+ ("relation_add", RelationType.CNT_EARLY_IN_REV),
+ ("relation_add", RelationType.CNT_IN_DIR),
+ ]
+ elif exchange == EntityType.DIRECTORY.value:
+ yield from [
+ ("directory_add", None),
+ ("relation_add", RelationType.DIR_IN_REV),
+ ]
+ elif exchange == EntityType.ORIGIN.value:
+ yield from [("origin_add", None)]
+ elif exchange == EntityType.REVISION.value:
+ yield from [
+ ("revision_add", None),
+ ("relation_add", RelationType.REV_BEFORE_REV),
+ ("relation_add", RelationType.REV_IN_ORG),
+ ]
+ elif exchange == "location":
+ yield "location_add", None
+
+ @staticmethod
+ def get_ranges() -> Generator[int, None, None]:
+ yield from range(ProvenanceStorageRabbitMQServer.queue_count)
+
+ @staticmethod
+ def get_routing_key(
+ data: Tuple[bytes, ...], meth_name: str, relation: Optional[RelationType] = None
+ ) -> str:
+ idx = (
+ data[0][0] // (256 // ProvenanceStorageRabbitMQServer.queue_count)
+ if data and data[0]
+ else 0
+ )
+ if relation is None:
+ return f"{meth_name}.unknown.{idx:x}".lower()
+ else:
+ return f"{meth_name}.{relation.value}.{idx:x}".lower()
+
+ @staticmethod
+ def is_write_method(meth_name: str) -> bool:
+ return "_add" in meth_name
+
+ @staticmethod
+ def respond(
+ channel: pika.channel.Channel,
+ correlation_id: str,
+ reply_to: str,
+ response: Any,
+ ):
+ channel.basic_publish(
+ exchange="",
+ routing_key=reply_to,
+ properties=pika.BasicProperties(
+ content_type="application/msgpack",
+ correlation_id=correlation_id,
+ ),
+ body=encode_data(
+ response,
+ extra_encoders=ProvenanceStorageRabbitMQServer.extra_type_encoders,
+ ),
+ )
def load_and_check_config(
@@ -133,70 +775,10 @@
return cfg
-api_cfg: Optional[Dict[str, Any]] = None
-
-
-def make_app_from_configfile() -> ProvenanceStorageRPCServerApp:
- """Run the WSGI app from the webserver, loading the configuration from
- a configuration file.
-
- SWH_CONFIG_FILENAME environment variable defines the
- configuration path to load.
-
- """
- global api_cfg
- if api_cfg is None:
- config_path = os.environ.get("SWH_CONFIG_FILENAME")
- api_cfg = load_and_check_config(config_path)
- app.config.update(api_cfg)
- handler = logging.StreamHandler()
- app.logger.addHandler(handler)
- return app
-
-
-class ProvenanceStorageRabbitMQServer:
- backend_class = ProvenanceStorageInterface
- extra_type_decoders = DECODERS
- extra_type_encoders = ENCODERS
-
- def __init__(self, url: str, storage: ProvenanceStorageInterface) -> None:
- self.storage = storage
-
- self.conn = pika.BlockingConnection(pika.connection.URLParameters(url))
- self.channel = self.conn.channel()
-
- self.channel.basic_qos(prefetch_count=1)
- for (meth_name, meth) in self.backend_class.__dict__.items():
- if hasattr(meth, "_endpoint_path"):
- self.channel.queue_declare(queue=meth_name)
- self.channel.basic_consume(
- queue=meth_name, on_message_callback=self.on_request
- )
- self.channel.start_consuming()
-
- def on_request(
- self,
- channel: pika.channel.Channel,
- method: pika.spec.Basic.Deliver,
- props: pika.spec.BasicProperties,
- body: bytes,
- ) -> None:
- response = getattr(self.storage, method.routing_key)(
- **decode_data(body, extra_decoders=self.extra_type_decoders)
- )
- channel.basic_publish(
- exchange="",
- routing_key=props.reply_to,
- properties=pika.BasicProperties(correlation_id=props.correlation_id),
- body=encode_data(response, extra_encoders=self.extra_type_encoders),
- )
- channel.basic_ack(delivery_tag=method.delivery_tag)
-
-
def make_server_from_configfile() -> None:
config_path = os.environ.get("SWH_CONFIG_FILENAME")
server_cfg = load_and_check_config(config_path)
- storage = get_provenance_storage(**server_cfg["provenance"]["storage"])
ProvenanceStorageRabbitMQServer(
- url=server_cfg["provenance"]["rabbitmq"]["url"], storage=storage
+ url=server_cfg["provenance"]["rabbitmq"]["url"],
+ storage_config=server_cfg["provenance"]["storage"],
)
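
Routing in this server is sharded: get_routing_key buckets each item by the first byte of its sha1 into queue_count (16) hash ranges, so writes touching the same range always reach the same worker queue and can be batched and merged by the conflict resolution functions above. A quick check of the bucketing arithmetic, mirroring get_routing_key with made-up hashes:

    QUEUE_COUNT = 16  # ProvenanceStorageRabbitMQServer.queue_count

    def routing_key(sha1: bytes, meth_name: str, relation: str = "unknown") -> str:
        # The first byte selects one of 16 buckets, 256 // 16 = 16 ids each.
        idx = sha1[0] // (256 // QUEUE_COUNT)
        return f"{meth_name}.{relation}.{idx:x}".lower()

    assert routing_key(bytes.fromhex("00" * 20), "content_add") == "content_add.unknown.0"
    assert routing_key(bytes.fromhex("ff" * 20), "content_add") == "content_add.unknown.f"

Since is_write_method only matches the *_add endpoints, only writes are sharded this way; the client serves read methods from its local storage connection without going through RabbitMQ.
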
diff --git a/swh/provenance/cli.py b/swh/provenance/cli.py
--- a/swh/provenance/cli.py
+++ b/swh/provenance/cli.py
@@ -57,9 +57,15 @@
# Remote RabbitMQ/PostgreSQL Storage
"cls": "rabbitmq",
"url": "amqp://localhost:5672/%2f",
- # Remote REST-API/PostgreSQL
- # "cls": "restapi",
- # "url": "http://localhost:8080/%2f",
+ "storage_config": {
+ "cls": "postgresql",
+ "db": {
+ "host": "localhost",
+ "user": "postgres",
+ "password": "postgres",
+ "dbname": "dummy",
+ },
+ },
},
}
}
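
The server entry point (make_server_from_configfile above) reads SWH_CONFIG_FILENAME and expects the broker URL and the concrete storage side by side under the provenance key. A sketch of a matching configuration file, assuming the usual YAML format consumed by swh.core.config; hosts and credentials are placeholders:

    provenance:
      rabbitmq:
        url: amqp://localhost:5672/%2f
      storage:
        cls: postgresql
        db:
          host: localhost
          user: postgres
          password: postgres
          dbname: provenance
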
diff --git a/swh/provenance/interface.py b/swh/provenance/interface.py
--- a/swh/provenance/interface.py
+++ b/swh/provenance/interface.py
@@ -67,7 +67,7 @@
class ProvenanceStorageInterface(Protocol):
@remote_api_endpoint("content_add")
def content_add(
- self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]]
+ self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]]
) -> bool:
"""Add blobs identified by sha1 ids, with an optional associated date (as paired
in `cnts`) to the provenance storage. Return a boolean stating whether the
@@ -96,7 +96,7 @@
@remote_api_endpoint("directory_add")
def directory_add(
- self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]]
+ self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]]
) -> bool:
"""Add directories identified by sha1 ids, with an optional associated date (as
paired in `dirs`) to the provenance storage. Return a boolean stating if the
diff --git a/swh/provenance/mongo/backend.py b/swh/provenance/mongo/backend.py
--- a/swh/provenance/mongo/backend.py
+++ b/swh/provenance/mongo/backend.py
@@ -26,7 +26,7 @@
self.db = db
def content_add(
- self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]]
+ self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]]
) -> bool:
data = cnts if isinstance(cnts, dict) else dict.fromkeys(cnts)
existing = {
@@ -149,7 +149,7 @@
}
def directory_add(
- self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]]
+ self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]]
) -> bool:
data = dirs if isinstance(dirs, dict) else dict.fromkeys(dirs)
existing = {
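
The widened Optional[datetime] signatures match what the implementations actually build: when a plain iterable of ids is passed, dict.fromkeys maps every id to None, so the value type cannot be a bare datetime. A small illustration of that normalization step:

    from datetime import datetime
    from typing import Dict, Iterable, Optional, Union

    Sha1Git = bytes  # as in swh.model.model

    def normalize(
        cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]]
    ) -> Dict[Sha1Git, Optional[datetime]]:
        # Mirrors content_add/directory_add: ids without dates map to None.
        return cnts if isinstance(cnts, dict) else dict.fromkeys(cnts)

    assert normalize([b"\x01", b"\x02"]) == {b"\x01": None, b"\x02": None}
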
diff --git a/swh/provenance/postgresql/provenance.py b/swh/provenance/postgresql/provenance.py
--- a/swh/provenance/postgresql/provenance.py
+++ b/swh/provenance/postgresql/provenance.py
@@ -61,7 +61,7 @@
return "denormalized" in self.flavor
def content_add(
- self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]]
+ self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]]
) -> bool:
return self._entity_set_date("content", cnts)
@@ -84,7 +84,7 @@
return self._entity_get_date("content", ids)
def directory_add(
- self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]]
+ self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]]
) -> bool:
return self._entity_set_date("directory", dirs)
@@ -209,9 +209,7 @@
def relation_add(
self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]]
) -> bool:
- rows = [
- (src, rel.dst, rel.path) for src, dsts in data.items() for rel in dsts
- ]
+ rows = [(src, rel.dst, rel.path) for src, dsts in data.items() for rel in dsts]
try:
if rows:
rel_table = relation.value
@@ -270,7 +268,7 @@
def _entity_set_date(
self,
entity: Literal["content", "directory"],
- dates: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]],
+ dates: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]],
) -> bool:
data = dates if isinstance(dates, dict) else dict.fromkeys(dates)
try:
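
relation_add flattens the mapping of source id to RelationData entries into (src, dst, path) rows, the same shape the RabbitMQ client splits relation writes into in client.py above. A short illustration, assuming RelationData as defined in swh.provenance.interface:

    from swh.provenance.interface import RelationData

    data = {
        b"\x01" * 20: {
            RelationData(dst=b"\x02" * 20, path=b"a/b"),
            RelationData(dst=b"\x03" * 20, path=None),
        }
    }
    rows = [(src, rel.dst, rel.path) for src, dsts in data.items() for rel in dsts]
    assert len(rows) == 2  # one row per (src, dst, path) relation
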
diff --git a/swh/provenance/tests/conftest.py b/swh/provenance/tests/conftest.py
--- a/swh/provenance/tests/conftest.py
+++ b/swh/provenance/tests/conftest.py
@@ -5,7 +5,7 @@
from datetime import datetime, timedelta, timezone
from os import path
-from typing import Any, Dict, Iterable, Iterator
+from typing import Any, Dict, Iterable
from _pytest.fixtures import SubRequest
import msgpack
@@ -16,8 +16,6 @@
from swh.journal.serializers import msgpack_ext_hook
from swh.provenance import get_provenance, get_provenance_storage
-from swh.provenance.api.client import ProvenanceStorageRPCClient
-import swh.provenance.api.server as server
from swh.provenance.archive import ArchiveInterface
from swh.provenance.interface import ProvenanceInterface, ProvenanceStorageInterface
from swh.provenance.storage.archive import ArchiveStorage
@@ -46,38 +44,15 @@
return postgresql.get_dsn_parameters()
-# the Flask app used as server in these tests
-@pytest.fixture
-def app(
- provenance_postgresqldb: Dict[str, str]
-) -> Iterator[server.ProvenanceStorageRPCServerApp]:
- assert hasattr(server, "storage")
- server.storage = get_provenance_storage(
- cls="postgresql", db=provenance_postgresqldb
- )
- yield server.app
-
-
-# the RPCClient class used as client used in these tests
-@pytest.fixture
-def swh_rpc_client_class() -> type:
- return ProvenanceStorageRPCClient
-
-
-@pytest.fixture(params=["mongodb", "postgresql", "restapi"])
+@pytest.fixture(params=["mongodb", "postgresql"])
def provenance_storage(
request: SubRequest,
provenance_postgresqldb: Dict[str, str],
mongodb: pymongo.database.Database,
- swh_rpc_client: ProvenanceStorageRPCClient,
) -> ProvenanceStorageInterface:
"""Return a working and initialized ProvenanceStorageInterface object"""
- if request.param == "restapi":
- assert isinstance(swh_rpc_client, ProvenanceStorageInterface)
- return swh_rpc_client
-
- elif request.param == "mongodb":
+ if request.param == "mongodb":
from swh.provenance.mongo.backend import ProvenanceStorageMongoDb
return ProvenanceStorageMongoDb(mongodb)