diff --git a/mypy.ini b/mypy.ini
--- a/mypy.ini
+++ b/mypy.ini
@@ -17,6 +17,9 @@
[mypy-msgpack.*]
ignore_missing_imports = True
+[mypy-pika.*]
+ignore_missing_imports = True
+
[mypy-pkg_resources.*]
ignore_missing_imports = True
diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py
--- a/swh/provenance/__init__.py
+++ b/swh/provenance/__init__.py
@@ -96,12 +96,12 @@
db = MongoClient(**kwargs["db"]).get_database(dbname)
return ProvenanceStorageMongoDb(db)
- elif cls == "remote":
- from .api.client import RemoteProvenanceStorage
+ elif cls == "rabbitmq":
+ from .api.client import ProvenanceStorageRabbitMQClient
- storage = RemoteProvenanceStorage(**kwargs)
- assert isinstance(storage, ProvenanceStorageInterface)
- return storage
+ rmq_storage = ProvenanceStorageRabbitMQClient(**kwargs)
+ if TYPE_CHECKING:
+ assert isinstance(rmq_storage, ProvenanceStorageInterface)
+ return rmq_storage
- else:
- raise ValueError
+    raise ValueError(f"Unknown provenance storage class: {cls}")
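
With this change, a RabbitMQ-backed storage is obtained through the regular factory. A minimal sketch, assuming a broker on the default port and vhost (the client constructor performs a `get_storage_config` round trip, so a running server is required):

```python
from swh.provenance import get_provenance_storage

# Hypothetical broker URL; credentials and vhost depend on the deployment.
storage = get_provenance_storage(cls="rabbitmq", url="amqp://localhost:5672/%2f")
```
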
diff --git a/swh/provenance/api/client.py b/swh/provenance/api/client.py
--- a/swh/provenance/api/client.py
+++ b/swh/provenance/api/client.py
@@ -3,15 +3,667 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
-from swh.core.api import RPCClient
+import functools
+import inspect
+import logging
+import queue
+import time
+from typing import Any, Optional
+import uuid
+
+import pika
+import pika.channel
+import pika.connection
+import pika.frame
+import pika.spec
+
+from swh.core.api.serializers import encode_data_client as encode_data
+from swh.core.api.serializers import msgpack_loads as decode_data
+from swh.provenance import get_provenance_storage
from ..interface import ProvenanceStorageInterface
from .serializers import DECODERS, ENCODERS
+from .server import ProvenanceStorageRabbitMQServer
+
+LOG_FORMAT = (
+ "%(levelname) -10s %(asctime)s %(name) -30s %(funcName) "
+ "-35s %(lineno) -5d: %(message)s"
+)
+LOGGER = logging.getLogger(__name__)
+
+
+class MetaRabbitMQClient(type):
+ def __new__(cls, name, bases, attributes):
+        # For each method wrapped with @remote_api_endpoint in the API backend
+        # (ie. :class:`ProvenanceStorageInterface`), add a new method in the
+        # client class, with the same signature and documentation.
+        #
+        # Note that, despite the usage of decorator magic (eg. functools.wraps),
+        # building the class never actually calls any backend method.
+ backend_class = attributes.get("backend_class", None)
+ for base in bases:
+ if backend_class is not None:
+ break
+ backend_class = getattr(base, "backend_class", None)
+ if backend_class:
+ for meth_name, meth in backend_class.__dict__.items():
+ if hasattr(meth, "_endpoint_path"):
+ cls.__add_endpoint(meth_name, meth, attributes)
+ return super().__new__(cls, name, bases, attributes)
+
+ @staticmethod
+ def __add_endpoint(meth_name, meth, attributes):
+ wrapped_meth = inspect.unwrap(meth)
+
+ @functools.wraps(meth) # Copy signature and doc
+ def meth_(*args, **kwargs):
+ # Match arguments and parameters
+ data = inspect.getcallargs(wrapped_meth, *args, **kwargs)
+
+ # Remove arguments that should not be passed
+ self = data.pop("self")
+
+ # Call storage method with remaining arguments
+ return getattr(self._storage, meth_name)(**data)
+
+ @functools.wraps(meth) # Copy signature and doc
+ def write_meth_(*args, **kwargs):
+ # Match arguments and parameters
+ post_data = inspect.getcallargs(wrapped_meth, *args, **kwargs)
+
+ # Remove arguments that should not be passed
+ self = post_data.pop("self")
+ relation = post_data.pop("relation", None)
+ assert len(post_data) == 1
+ if relation is not None:
+ data = [
+ (row.src, row.dst, row.path)
+ for row in next(iter(post_data.values()))
+ ]
+ else:
+ data = list(next(iter(post_data.values())).items())
+ acks_expected = len(data)
+ correlation_id = str(uuid.uuid4())
+ for item in data:
+ routing_key = ProvenanceStorageRabbitMQServer.get_routing_key(
+ meth_name, relation, item
+ )
+ self.request(routing_key, data=item, correlation_id=correlation_id)
+ return self.wait_for_acks(acks_expected)
-class RemoteProvenanceStorage(RPCClient):
- """Proxy to a remote provenance storage API"""
+ if meth_name not in attributes:
+ attributes[meth_name] = (
+ write_meth_
+ if ProvenanceStorageRabbitMQServer.is_write_method(meth_name)
+ else meth_
+ )
+
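
The metaclass above generates one client method per `@remote_api_endpoint` of the backend interface: read endpoints are bound to `meth_` and served by the client-local storage backend, while write endpoints are bound to `write_meth_`, which publishes one message per item and then waits for acknowledgements. A sketch of the resulting behavior, assuming the interface exposes the `content_get` read endpoint and the `content_set_date` write endpoint:

```python
from datetime import datetime, timezone

# Requires a reachable server: the constructor (below) fetches the storage
# configuration from it before the client becomes usable.
client = ProvenanceStorageRabbitMQClient(url="amqp://localhost:5672/%2f")

sha1 = bytes.fromhex("00" * 20)  # placeholder Sha1Git value
client.content_get([sha1])  # read: forwarded to self._storage.content_get()
client.content_set_date({sha1: datetime.now(timezone.utc)})  # write: publish + wait_for_acks
```
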
+class ConfigurationError(Exception):
+ pass
+
+
+class ResponseTimeout(Exception):
+ pass
+
+
+class ProvenanceStorageRabbitMQClient(metaclass=MetaRabbitMQClient):
backend_class = ProvenanceStorageInterface
extra_type_decoders = DECODERS
extra_type_encoders = ENCODERS
+
+ def __init__(self, url: str, **kwargs) -> None:
+ self.conn = pika.BlockingConnection(pika.connection.URLParameters(url))
+ self.channel = self.conn.channel()
+
+ result = self.channel.queue_declare(queue="", exclusive=True)
+ self.callback_queue = result.method.queue
+ self.channel.basic_consume(
+ queue=self.callback_queue,
+ on_message_callback=self.on_response,
+ auto_ack=True,
+ )
+
+ self.response_queue: queue.Queue = queue.Queue()
+
+ # Get storage configuration from server.
+ self.request("get_storage_config")
+ try:
+ self._storage = get_provenance_storage(**self.wait_for_response())
+ except ResponseTimeout:
+ LOGGER.warning("Timed out waiting for response on get_storage_config")
+ raise ConfigurationError
+
+ def on_response(
+ self,
+ channel: pika.channel.Channel,
+ method: pika.spec.Basic.Deliver,
+ props: pika.spec.BasicProperties,
+ body: bytes,
+ ) -> None:
+ self.response_queue.put(
+ (
+ props.correlation_id,
+ decode_data(body, extra_decoders=self.extra_type_decoders),
+ )
+ )
+
+ def request(
+ self, routing_key: str, correlation_id: Optional[str] = None, **kwargs
+ ) -> Any:
+ self.correlation_id = (
+ correlation_id if correlation_id is not None else str(uuid.uuid4())
+ )
+ self.channel.basic_publish(
+ exchange="",
+ routing_key=routing_key,
+ properties=pika.BasicProperties(
+ reply_to=self.callback_queue,
+ correlation_id=self.correlation_id,
+ ),
+ body=encode_data(kwargs, extra_encoders=self.extra_type_encoders),
+ )
+
+ def wait_for_acks(self, acks_expected: int) -> bool:
+ acks_received = 0
+ while acks_received < acks_expected:
+ try:
+ acks_received += self.wait_for_response()
+ except ResponseTimeout:
+ LOGGER.warning(
+ "Timed out waiting for acks, %s received, %s expected",
+ acks_received,
+ acks_expected,
+ )
+ return False
+ return acks_received == acks_expected
+
+ def wait_for_response(self, timeout: float = 60.0) -> Any:
+ start = time.monotonic()
+ while True:
+ try:
+ correlation_id, response = self.response_queue.get(block=False)
+ if correlation_id == self.correlation_id:
+ return response
+ except queue.Empty:
+                # Block for the time remaining before the timeout, but always
+                # for at least 1 second
+                time_limit = max(timeout - (time.monotonic() - start), 1.0)
+ self.conn.process_data_events(time_limit=time_limit)
+
+ if self.response_queue.empty() and time.monotonic() > start + timeout:
+ raise ResponseTimeout
+
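
The client implements the classic RabbitMQ RPC pattern: an exclusive, broker-named callback queue, a fresh `correlation_id` per request, and replies matched against it in `wait_for_response`. A stripped-down sketch of the same pattern in plain pika, for reference (queue name and payload are illustrative):

```python
import uuid

import pika

conn = pika.BlockingConnection(pika.connection.URLParameters("amqp://localhost:5672/%2f"))
channel = conn.channel()
# Broker-named exclusive queue on which replies will be delivered.
callback_queue = channel.queue_declare(queue="", exclusive=True).method.queue

corr_id = str(uuid.uuid4())
channel.basic_publish(
    exchange="",
    routing_key="get_storage_config",  # queue declared by the server (see server.py)
    properties=pika.BasicProperties(reply_to=callback_queue, correlation_id=corr_id),
    body=b"\x80",  # msgpack-encoded kwargs in the real client; empty map here
)
# The reply carrying the same correlation_id is then read from callback_queue.
```
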
diff --git a/swh/provenance/api/server.py b/swh/provenance/api/server.py
--- a/swh/provenance/api/server.py
+++ b/swh/provenance/api/server.py
@@ -3,79 +3,678 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
+from collections import Counter
+from datetime import datetime
+import functools
import logging
+import multiprocessing
import os
-from typing import Any, Dict, List, Optional
-
-from werkzeug.routing import Rule
+import queue
+import re
+import threading
+from typing import Any, Callable
+from typing import Counter as TCounter
+from typing import Dict, Generator, List, Optional, Tuple
+
+import pika
+import pika.channel
+import pika.connection
+import pika.exceptions
+from pika.exchange_type import ExchangeType
+import pika.frame
+import pika.spec
from swh.core import config
-from swh.core.api import JSONFormatter, MsgpackFormatter, RPCServerApp, negotiate
+from swh.core.api.serializers import encode_data_client as encode_data
+from swh.core.api.serializers import msgpack_loads as decode_data
+from swh.model.model import Sha1Git
from swh.provenance import get_provenance_storage
-from swh.provenance.interface import ProvenanceStorageInterface
+from swh.provenance.interface import (
+ ProvenanceStorageInterface,
+ RelationData,
+ RelationType,
+)
from .serializers import DECODERS, ENCODERS
-storage: Optional[ProvenanceStorageInterface] = None
+LOG_FORMAT = (
+ "%(levelname) -10s %(asctime)s %(name) -30s %(funcName) "
+ "-35s %(lineno) -5d: %(message)s"
+)
+LOGGER = logging.getLogger(__name__)
+TERMINATE = object()
-def get_global_provenance_storage() -> ProvenanceStorageInterface:
- global storage
- if storage is None:
- storage = get_provenance_storage(**app.config["provenance"]["storage"])
- return storage
+def resolve_dates(dates: List[Tuple[Sha1Git, datetime]]) -> Dict[Sha1Git, datetime]:
+    """Keep the earliest date seen for each sha1."""
+    result: Dict[Sha1Git, datetime] = {}
+    for sha1, date in dates:
+        if sha1 not in result or date < result[sha1]:
+            result[sha1] = date
+    return result
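
`resolve_dates` collapses duplicate sha1 entries from a batch, keeping the earliest date. For instance:

```python
from datetime import datetime

sha1 = bytes.fromhex("00" * 20)  # placeholder Sha1Git value
dates = [(sha1, datetime(2021, 6, 2)), (sha1, datetime(2021, 6, 1))]
assert resolve_dates(dates) == {sha1: datetime(2021, 6, 1)}
```
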
-class ProvenanceStorageServerApp(RPCServerApp):
- extra_type_decoders = DECODERS
- extra_type_encoders = ENCODERS
+class ProvenanceStorageRabbitMQWorker(multiprocessing.Process):
+ """This is an example publisher that will handle unexpected interactions
+ with RabbitMQ such as channel and connection closures.
-app = ProvenanceStorageServerApp(
- __name__,
- backend_class=ProvenanceStorageInterface,
- backend_factory=get_global_provenance_storage,
-)
+ If RabbitMQ closes the connection, it will reopen it. You should
+ look at the output, as there are limited reasons why the connection may
+ be closed, which usually are tied to permission related issues or
+ socket timeouts.
+    It uses message acknowledgements to keep track of the requests that have
+    been received and processed by the underlying storage.
-def has_no_empty_params(rule: Rule) -> bool:
- return len(rule.defaults or ()) >= len(rule.arguments or ())
-
-
-@app.route("/")
-def index() -> str:
- return """
-<html>
-<head><title>Software Heritage provenance storage RPC server</title></head>
-<body>
-<p>You have reached the
-<a href="https://www.softwareheritage.org/">Software Heritage</a>
-provenance storage RPC server.<br/>
-See its
-<a href="https://docs.softwareheritage.org/devel/swh-provenance/">documentation
-and API</a> for more information</p>
-</body>
-</html>
-"""
-
-
-@app.route("/site-map")
-@negotiate(MsgpackFormatter)
-@negotiate(JSONFormatter)
-def site_map() -> List[Dict[str, Any]]:
- links = []
- for rule in app.url_map.iter_rules():
- if has_no_empty_params(rule) and hasattr(
- ProvenanceStorageInterface, rule.endpoint
- ):
- links.append(
- dict(
- rule=rule.rule,
- description=getattr(
- ProvenanceStorageInterface, rule.endpoint
- ).__doc__,
+ """
+
+ EXCHANGE_TYPE = ExchangeType.topic
+
+ extra_type_decoders = DECODERS
+ extra_type_encoders = ENCODERS
+
+ def __init__(
+ self, url: str, routing_key: str, storage_config: Dict[str, Any]
+ ) -> None:
+ """Setup the example publisher object, passing in the URL we will use
+ to connect to RabbitMQ.
+
+ :param str url: The URL for connecting to RabbitMQ
+ :param str routing_key: The routing key name from which this worker will
+ consume messages
+ :param str storage_config: Configuration parameters for the underlying
+ ``ProvenanceStorage`` object
+
+ """
+ super().__init__(name=routing_key)
+
+ self.should_reconnect = False
+ self.was_consuming = False
+
+ self._connection = None
+ self._channel = None
+ self._closing = False
+ self._consumer_tag = None
+ self._consuming = False
+ self._prefetch_count = 100
+
+ self._url = url
+ self._routing_key = routing_key
+ self._storage_config = storage_config
+ self._batch_size = 100
+
+ def connect(self) -> pika.SelectConnection:
+ """This method connects to RabbitMQ, returning the connection handle.
+ When the connection is established, the on_connection_open method
+ will be invoked by pika.
+
+ :rtype: pika.SelectConnection
+
+ """
+ LOGGER.info("Connecting to %s", self._url)
+ return pika.SelectConnection(
+ parameters=pika.URLParameters(self._url),
+ on_open_callback=self.on_connection_open,
+ on_open_error_callback=self.on_connection_open_error,
+ on_close_callback=self.on_connection_closed,
+ )
+
+ def close_connection(self) -> None:
+ assert self._connection is not None
+ self._consuming = False
+ if self._connection.is_closing or self._connection.is_closed:
+ LOGGER.info("Connection is closing or already closed")
+ else:
+ LOGGER.info("Closing connection")
+ self._connection.close()
+
+ def on_connection_open(self, _unused_connection: pika.SelectConnection) -> None:
+ """This method is called by pika once the connection to RabbitMQ has
+ been established. It passes the handle to the connection object in
+ case we need it, but in this case, we'll just mark it unused.
+
+ :param pika.SelectConnection _unused_connection: The connection
+
+ """
+ LOGGER.info("Connection opened")
+ self.open_channel()
+
+ def on_connection_open_error(
+ self, _unused_connection: pika.SelectConnection, err: Exception
+ ) -> None:
+ """This method is called by pika if the connection to RabbitMQ
+ can't be established.
+
+ :param pika.SelectConnection _unused_connection: The connection
+ :param Exception err: The error
+
+ """
+ LOGGER.error("Connection open failed: %s", err)
+ self.reconnect()
+
+ def on_connection_closed(self, _unused_connection: pika.SelectConnection, reason):
+ """This method is invoked by pika when the connection to RabbitMQ is
+ closed unexpectedly. Since it is unexpected, we will reconnect to
+ RabbitMQ if it disconnects.
+
+ :param pika.connection.Connection connection: The closed connection obj
+ :param Exception reason: exception representing reason for loss of
+ connection.
+
+ """
+ self._channel = None
+ if self._closing:
+ assert self._connection is not None
+ self._connection.ioloop.stop()
+ else:
+ LOGGER.warning("Connection closed, reconnect necessary: %s", reason)
+ self.reconnect()
+
+ def reconnect(self) -> None:
+ """Will be invoked if the connection can't be opened or is
+ closed. Indicates that a reconnect is necessary then stops the
+ ioloop.
+
+ """
+ self.should_reconnect = True
+ self.stop()
+
+ def open_channel(self) -> None:
+ """Open a new channel with RabbitMQ by issuing the Channel.Open RPC
+ command. When RabbitMQ responds that the channel is open, the
+ on_channel_open callback will be invoked by pika.
+
+ """
+ LOGGER.info("Creating a new channel")
+ assert self._connection is not None
+ self._connection.channel(on_open_callback=self.on_channel_open)
+
+ def on_channel_open(self, channel: pika.channel.Channel) -> None:
+ """This method is invoked by pika when the channel has been opened.
+ The channel object is passed in so we can make use of it.
+
+ Since the channel is now open, we'll declare the exchange to use.
+
+ :param pika.channel.Channel channel: The channel object
+
+ """
+ LOGGER.info("Channel opened")
+ self._channel = channel
+ self.add_on_channel_close_callback()
+ self.setup_exchange()
+
+ def add_on_channel_close_callback(self) -> None:
+ """This method tells pika to call the on_channel_closed method if
+ RabbitMQ unexpectedly closes the channel.
+
+ """
+ LOGGER.info("Adding channel close callback")
+ assert self._channel is not None
+ self._channel.add_on_close_callback(callback=self.on_channel_closed)
+
+ def on_channel_closed(
+ self, channel: pika.channel.Channel, reason: Exception
+ ) -> None:
+ """Invoked by pika when RabbitMQ unexpectedly closes the channel.
+ Channels are usually closed if you attempt to do something that
+ violates the protocol, such as re-declare an exchange or queue with
+ different parameters. In this case, we'll close the connection
+ to shutdown the object.
+
+ :param pika.channel.Channel: The closed channel
+ :param Exception reason: why the channel was closed
+
+ """
+ LOGGER.warning("Channel %i was closed: %s", channel, reason)
+ self.close_connection()
+
+ def setup_exchange(self) -> None:
+ """Setup the exchange on RabbitMQ by invoking the Exchange.Declare RPC
+ command. When it is complete, the on_exchange_declareok method will
+ be invoked by pika.
+
+ """
+ LOGGER.info("Declaring exchange %s", self._routing_key)
+ assert self._channel is not None
+ self._channel.exchange_declare(
+ exchange=self._routing_key,
+ exchange_type=self.EXCHANGE_TYPE,
+ callback=self.on_exchange_declareok,
+ )
+
+ def on_exchange_declareok(self, _unused_frame: pika.frame.Method) -> None:
+ """Invoked by pika when RabbitMQ has finished the Exchange.Declare RPC
+ command.
+
+ :param pika.frame.Method unused_frame: Exchange.DeclareOk response frame
+
+ """
+ LOGGER.info("Exchange declared: %s", self._routing_key)
+ self.setup_queue()
+
+ def setup_queue(self) -> None:
+ """Setup the queue on RabbitMQ by invoking the Queue.Declare RPC
+ command. When it is complete, the on_queue_declareok method will
+ be invoked by pika.
+
+ """
+ LOGGER.info("Declaring queue %s", self._routing_key)
+ assert self._channel is not None
+ self._channel.queue_declare(
+ queue=self._routing_key, callback=self.on_queue_declareok
+ )
+
+ def on_queue_declareok(self, _unused_frame: pika.frame.Method) -> None:
+ """Method invoked by pika when the Queue.Declare RPC call made in
+ setup_queue has completed. In this method we will bind the queue
+ and exchange together with the routing key by issuing the Queue.Bind
+ RPC command. When this command is complete, the on_bindok method will
+ be invoked by pika.
+
+ :param pika.frame.Method method_frame: The Queue.DeclareOk frame
+
+ """
+ LOGGER.info(
+ "Binding queue %s to exchange %s with routing key %s",
+ self._routing_key,
+ self._routing_key,
+ self._routing_key,
+ )
+ assert self._channel is not None
+ self._channel.queue_bind(
+ queue=self._routing_key,
+ exchange=self._routing_key,
+ routing_key=self._routing_key,
+ callback=self.on_bindok,
+ )
+
+ def on_bindok(self, _unused_frame: pika.frame.Method) -> None:
+ """Invoked by pika when the Queue.Bind method has completed. At this
+ point we will set the prefetch count for the channel.
+
+ :param pika.frame.Method _unused_frame: The Queue.BindOk response frame
+
+ """
+ LOGGER.info("Queue bound: %s", self._routing_key)
+ self.set_qos()
+
+ def set_qos(self) -> None:
+ """This method sets up the consumer prefetch to only be delivered
+ one message at a time. The consumer must acknowledge this message
+ before RabbitMQ will deliver another one. You should experiment
+ with different prefetch values to achieve desired performance.
+
+ """
+ assert self._channel is not None
+ self._channel.basic_qos(
+ prefetch_count=self._prefetch_count, callback=self.on_basic_qos_ok
+ )
+
+ def on_basic_qos_ok(self, _unused_frame: pika.frame.Method) -> None:
+ """Invoked by pika when the Basic.QoS method has completed. At this
+ point we will start consuming messages by calling start_consuming
+ which will invoke the needed RPC commands to start the process.
+
+ :param pika.frame.Method _unused_frame: The Basic.QosOk response frame
+
+ """
+ LOGGER.info("QOS set to: %d", self._prefetch_count)
+ self.start_consuming()
+
+ def start_consuming(self) -> None:
+ """This method sets up the consumer by first calling
+ add_on_cancel_callback so that the object is notified if RabbitMQ
+ cancels the consumer. It then issues the Basic.Consume RPC command
+ which returns the consumer tag that is used to uniquely identify the
+ consumer with RabbitMQ. We keep the value to use it when we want to
+ cancel consuming. The on_message method is passed in as a callback pika
+ will invoke when a message is fully received.
+
+ """
+ LOGGER.info("Issuing consumer related RPC commands")
+ assert self._channel is not None
+ self.add_on_cancel_callback()
+ self._consumer_tag = self._channel.basic_consume(
+ queue=self._routing_key, on_message_callback=self.on_message
+ )
+ self.was_consuming = True
+ self._consuming = True
+
+ def add_on_cancel_callback(self) -> None:
+ """Add a callback that will be invoked if RabbitMQ cancels the consumer
+ for some reason. If RabbitMQ does cancel the consumer,
+ on_consumer_cancelled will be invoked by pika.
+
+ """
+ LOGGER.info("Adding consumer cancellation callback")
+ assert self._channel is not None
+ self._channel.add_on_cancel_callback(callback=self.on_consumer_cancelled)
+
+ def on_consumer_cancelled(self, method_frame: pika.frame.Method) -> None:
+ """Invoked by pika when RabbitMQ sends a Basic.Cancel for a consumer
+ receiving messages.
+
+ :param pika.frame.Method method_frame: The Basic.Cancel frame
+
+ """
+ LOGGER.info("Consumer was cancelled remotely, shutting down: %r", method_frame)
+ if self._channel:
+ self._channel.close()
+
+ def on_message(
+ self,
+ _unused_channel: pika.channel.Channel,
+ basic_deliver: pika.spec.Basic.Deliver,
+ properties: pika.spec.BasicProperties,
+ body: bytes,
+ ) -> None:
+ """Invoked by pika when a message is delivered from RabbitMQ. The
+ channel is passed for your convenience. The basic_deliver object that
+ is passed in carries the exchange, routing key, delivery tag and
+ a redelivered flag for the message. The properties passed in is an
+ instance of BasicProperties with the message properties and the body
+ is the message that was sent.
+
+ :param pika.channel.Channel _unused_channel: The channel object
+ :param pika.spec.Basic.Deliver: basic_deliver method
+ :param pika.spec.BasicProperties: properties
+ :param bytes body: The message body
+
+ """
+ LOGGER.info(
+ "Received message # %s from %s: %s",
+ basic_deliver.delivery_tag,
+ properties.app_id,
+ body,
+ )
+ item = decode_data(data=body, extra_decoders=self.extra_type_decoders)["data"]
+ self._request_queue.put(
+ (item, (properties.correlation_id, properties.reply_to))
+ )
+ self.acknowledge_message(delivery_tag=basic_deliver.delivery_tag)
+
+ def acknowledge_message(self, delivery_tag: int) -> None:
+ """Acknowledge the message delivery from RabbitMQ by sending a
+ Basic.Ack RPC method for the delivery tag.
+
+ :param int delivery_tag: The delivery tag from the Basic.Deliver frame
+
+ """
+ LOGGER.info("Acknowledging message %s", delivery_tag)
+ assert self._channel is not None
+ self._channel.basic_ack(delivery_tag=delivery_tag)
+
+ def stop_consuming(self) -> None:
+ """Tell RabbitMQ that you would like to stop consuming by sending the
+ Basic.Cancel RPC command.
+
+ """
+ if self._channel:
+ LOGGER.info("Sending a Basic.Cancel RPC command to RabbitMQ")
+ self._channel.basic_cancel(self._consumer_tag, self.on_cancelok)
+
+ def on_cancelok(self, _unused_frame: pika.frame.Method) -> None:
+ """This method is invoked by pika when RabbitMQ acknowledges the
+ cancellation of a consumer. At this point we will close the channel.
+ This will invoke the on_channel_closed method once the channel has been
+ closed, which will in-turn close the connection.
+
+ :param pika.frame.Method _unused_frame: The Basic.CancelOk frame
+
+ """
+ self._consuming = False
+ LOGGER.info(
+ "RabbitMQ acknowledged the cancellation of the consumer: %s",
+ self._consumer_tag,
+ )
+ self.close_channel()
+
+ def close_channel(self) -> None:
+ """Call to close the channel with RabbitMQ cleanly by issuing the
+ Channel.Close RPC command.
+
+ """
+ LOGGER.info("Closing the channel")
+ assert self._channel is not None
+ self._channel.close()
+
+ def run(self) -> None:
+ """Run the example code by connecting and then starting the IOLoop."""
+
+ self._request_queue: queue.Queue = queue.Queue()
+ self._storage_thread = threading.Thread(target=self.run_storage_thread)
+ self._storage_thread.start()
+
+ while not self._closing:
+ try:
+ self._connection = self.connect()
+ assert self._connection is not None
+ self._connection.ioloop.start()
+ except KeyboardInterrupt:
+ self.stop()
+ if self._connection is not None and not self._connection.is_closed:
+ # Finish closing
+ self._connection.ioloop.start()
+
+ self._request_queue.put(TERMINATE)
+ self._storage_thread.join()
+ LOGGER.info("Stopped")
+
+ def run_storage_thread(self) -> None:
+ storage = get_provenance_storage(**self._storage_config)
+
+ meth_name, relation = ProvenanceStorageRabbitMQServer.get_meth_name(
+ self._routing_key
+ )
+ resolve_conflicts = ProvenanceStorageRabbitMQWorker.get_conflicts_func(
+ meth_name
+ )
+
+ while True:
+ terminate = False
+ elements = []
+ while True:
+ try:
+ elem = self._request_queue.get(timeout=0.1)
+ if elem is TERMINATE:
+ terminate = True
+ break
+ elements.append(elem)
+ except queue.Empty:
+ break
+
+ if len(elements) >= self._batch_size:
+ break
+
+ if terminate:
+ break
+
+ if not elements:
+ continue
+
+ # FIXME: this is running in a different thread! Hence, if self._connection
+ # drops, there is no guarantee that the response can be sent for the current
+ # elements. This situation should be handled properly.
+ assert self._connection is not None
+
+ items, props = zip(*elements)
+ acks_count: TCounter[Tuple[str, str]] = Counter(props)
+ data = resolve_conflicts(items)
+
+ args = (relation, data) if relation is not None else (data,)
+ if getattr(storage, meth_name)(*args):
+ for (correlation_id, reply_to), count in acks_count.items():
+ self._connection.ioloop.add_callback_threadsafe(
+ functools.partial(
+ ProvenanceStorageRabbitMQServer.respond,
+ channel=self._channel,
+ correlation_id=correlation_id,
+ reply_to=reply_to,
+ response=count,
+ )
+ )
+ else:
+ LOGGER.warning(
+ "Unable to process elements for queue %s", self._routing_key
)
- )
- # links is now a list of url, endpoint tuples
- return links
+ for elem in elements:
+ self._request_queue.put(elem)
+
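
Responses are coalesced per client: the `Counter` above groups identical `(correlation_id, reply_to)` pairs, and each client receives a single message whose payload is the number of its items in the batch; `wait_for_acks` on the client side sums those payloads. A small illustration with made-up identifiers:

```python
from collections import Counter

# Two requests from client A and one from client B ended up in the same batch:
props = [("corr-A", "amq.gen-A"), ("corr-A", "amq.gen-A"), ("corr-B", "amq.gen-B")]
acks_count = Counter(props)
assert acks_count[("corr-A", "amq.gen-A")] == 2  # client A is sent the value 2
```
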
+ def stop(self) -> None:
+ """Cleanly shutdown the connection to RabbitMQ by stopping the consumer
+ with RabbitMQ. When RabbitMQ confirms the cancellation, on_cancelok
+        will be invoked by pika, which will then close the channel and
+ connection. The IOLoop is started again because this method is invoked
+ when CTRL-C is pressed raising a KeyboardInterrupt exception. This
+ exception stops the IOLoop which needs to be running for pika to
+ communicate with RabbitMQ. All of the commands issued prior to starting
+ the IOLoop will be buffered but not processed.
+
+ """
+ assert self._connection is not None
+ if not self._closing:
+ self._closing = True
+ LOGGER.info("Stopping")
+ if self._consuming:
+ self.stop_consuming()
+ self._connection.ioloop.start()
+ else:
+ self._connection.ioloop.stop()
+ LOGGER.info("Stopped")
+
+ @staticmethod
+ def get_conflicts_func(meth_name: str) -> Callable[[List[Any]], Any]:
+ if meth_name.startswith("relation_add"):
+ # Create RelationData structures from tuples and deduplicate
+ return lambda data: {RelationData(*row) for row in data}
+ else:
+ # Dates should be resolved to the earliest one,
+ # otherwise last processed value is good enough
+ return resolve_dates if "_date" in meth_name else lambda data: dict(data)
+
+
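
`get_conflicts_func` selects the deduplication strategy per endpoint: relation rows are deduplicated as `RelationData` values, `*_date` endpoints keep the earliest date, and anything else collapses into a plain dict where the last value wins. A quick sketch, assuming `RelationData` fields are `(src, dst, path)` as used above:

```python
dedup = ProvenanceStorageRabbitMQWorker.get_conflicts_func("relation_add")
rows = [(b"src", b"dst", b"path"), (b"src", b"dst", b"path")]
assert len(dedup(rows)) == 1  # duplicate tuples collapse into one RelationData

resolve = ProvenanceStorageRabbitMQWorker.get_conflicts_func("content_set_date")
assert resolve is resolve_dates  # "_date" endpoints resolve to the earliest date
```
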
+class ProvenanceStorageRabbitMQServer:
+ backend_class = ProvenanceStorageInterface
+ extra_type_decoders = DECODERS
+ extra_type_encoders = ENCODERS
+
+ queue_count = 16
+    relation_add_regex = re.compile(r"relation_add_(?P<relation>\w+)_[0-9a-fA-F]+")
+
+ def __init__(self, url: str, storage_config: Dict[str, Any]) -> None:
+ self._url = url
+ self._connection = pika.BlockingConnection(
+ pika.connection.URLParameters(self._url)
+ )
+ self._channel = self._connection.channel()
+ self._workers: List[ProvenanceStorageRabbitMQWorker] = []
+
+ self._channel.basic_qos(prefetch_count=1)
+
+ self._storage_config = storage_config
+ self._channel.queue_declare(queue="get_storage_config")
+ self._channel.basic_consume(
+ queue="get_storage_config", on_message_callback=self.get_storage_config
+ )
+
+ for meth_name, meth in self.backend_class.__dict__.items():
+ if hasattr(
+ meth, "_endpoint_path"
+ ) and ProvenanceStorageRabbitMQServer.is_write_method(meth_name):
+ for routing_key in ProvenanceStorageRabbitMQServer.get_routing_keys(
+ meth_name
+ ):
+ worker = ProvenanceStorageRabbitMQWorker(
+ self._url, routing_key, self._storage_config
+ )
+ worker.start()
+ self._workers.append(worker)
+
+ try:
+ LOGGER.info("Start consuming")
+ self._channel.start_consuming()
+ finally:
+ LOGGER.info("Stop consuming")
+ for worker in self._workers:
+ worker.terminate()
+ worker.join()
+ self._channel.close()
+ self._connection.close()
+
+ @staticmethod
+ def ack(channel: pika.channel.Channel, delivery_tag: int) -> None:
+ channel.basic_ack(delivery_tag=delivery_tag)
+
+ @staticmethod
+ def get_meth_name(routing_key: str) -> Tuple[str, Optional[RelationType]]:
+ match = ProvenanceStorageRabbitMQServer.relation_add_regex.match(routing_key)
+ if match:
+ return "relation_add", RelationType(match.group("relation"))
+ else:
+            return routing_key.rsplit("_", 1)[0], None
+
+ @staticmethod
+ def get_routing_key(
+ meth_name: str, relation: Optional[RelationType], data: Any
+ ) -> str:
+ idx = data[0][0] // (256 // ProvenanceStorageRabbitMQServer.queue_count)
+ if relation is not None:
+ return f"{meth_name}_{relation.value}_{idx:x}".lower()
+ else:
+ return f"{meth_name}_{idx:x}".lower()
+
+ @staticmethod
+ def get_routing_keys(meth_name: str) -> Generator[str, None, None]:
+ if meth_name.startswith("relation_add"):
+ for relation in RelationType:
+ for idx in range(ProvenanceStorageRabbitMQServer.queue_count):
+ yield f"{meth_name}_{relation.value}_{idx:x}".lower()
+ else:
+ for idx in range(ProvenanceStorageRabbitMQServer.queue_count):
+ yield f"{meth_name}_{idx:x}".lower()
+
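
Write requests are sharded over `queue_count = 16` queues per endpoint, keyed on the first byte of the sha1 (`256 // 16 = 16` values per shard), so routing keys end in a single hex digit. For example:

```python
sha1 = bytes.fromhex("a1" + "00" * 19)  # placeholder Sha1Git value
# 0xa1 // (256 // 16) == 10, i.e. shard suffix "a"
key = ProvenanceStorageRabbitMQServer.get_routing_key("content_set_date", None, (sha1, None))
assert key == "content_set_date_a"
```
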
+ def get_storage_config(
+ self,
+ channel: pika.channel.Channel,
+ method: pika.spec.Basic.Deliver,
+ props: pika.spec.BasicProperties,
+ body: bytes,
+ ) -> None:
+ assert method.routing_key == "get_storage_config"
+ ProvenanceStorageRabbitMQServer.respond(
+ channel=channel,
+ correlation_id=props.correlation_id,
+ reply_to=props.reply_to,
+ response=self._storage_config,
+ )
+ ProvenanceStorageRabbitMQServer.ack(
+ channel=channel, delivery_tag=method.delivery_tag
+ )
+
+ @staticmethod
+ def is_write_method(meth_name: str) -> bool:
+ return meth_name.startswith("relation_add") or "_set_" in meth_name
+
+ @staticmethod
+ def respond(
+ channel: pika.channel.Channel,
+ correlation_id: str,
+ reply_to: str,
+ response: Any,
+ ):
+ channel.basic_publish(
+ exchange="",
+ routing_key=reply_to,
+ properties=pika.BasicProperties(correlation_id=correlation_id),
+ body=encode_data(
+ response,
+ extra_encoders=ProvenanceStorageRabbitMQServer.extra_type_encoders,
+ ),
+ )
def load_and_check_config(
@@ -127,22 +726,10 @@
return cfg
-api_cfg: Optional[Dict[str, Any]] = None
-
-
-def make_app_from_configfile() -> ProvenanceStorageServerApp:
- """Run the WSGI app from the webserver, loading the configuration from
- a configuration file.
-
- SWH_CONFIG_FILENAME environment variable defines the
- configuration path to load.
-
- """
- global api_cfg
- if api_cfg is None:
- config_path = os.environ.get("SWH_CONFIG_FILENAME")
- api_cfg = load_and_check_config(config_path)
- app.config.update(api_cfg)
- handler = logging.StreamHandler()
- app.logger.addHandler(handler)
- return app
+def make_server_from_configfile() -> None:
+ config_path = os.environ.get("SWH_CONFIG_FILENAME")
+ server_cfg = load_and_check_config(config_path)
+ ProvenanceStorageRabbitMQServer(
+ url=server_cfg["provenance"]["rabbitmq"]["url"],
+ storage_config=server_cfg["provenance"]["storage"],
+ )
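
`make_server_from_configfile` reads the path from `SWH_CONFIG_FILENAME`; the `provenance` section must carry both the broker URL and the backend storage settings. An equivalent direct construction, with illustrative values:

```python
server_cfg = {
    "provenance": {
        "rabbitmq": {"url": "amqp://localhost:5672/%2f"},
        # Backend the workers actually write to; values are illustrative.
        "storage": {"cls": "postgresql", "db": {"dbname": "provenance"}},
    }
}
ProvenanceStorageRabbitMQServer(
    url=server_cfg["provenance"]["rabbitmq"]["url"],
    storage_config=server_cfg["provenance"]["storage"],
)  # blocks in start_consuming() until interrupted
```
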
diff --git a/swh/provenance/cli.py b/swh/provenance/cli.py
--- a/swh/provenance/cli.py
+++ b/swh/provenance/cli.py
@@ -42,21 +42,21 @@
},
"storage": {
# Local PostgreSQL Storage
- "cls": "postgresql",
- "db": {
- "host": "localhost",
- "user": "postgres",
- "password": "postgres",
- "dbname": "provenance",
- },
+ # "cls": "postgresql",
+ # "db": {
+ # "host": "localhost",
+ # "user": "postgres",
+ # "password": "postgres",
+ # "dbname": "provenance",
+ # },
# Local MongoDB Storage
# "cls": "mongodb",
# "db": {
# "dbname": "provenance",
# },
- # Remote REST-API/PostgreSQL
- # "cls": "remote",
- # "url": "http://localhost:8080/%2f",
+ # Remote RabbitMQ/PostgreSQL Storage
+ "cls": "rabbitmq",
+ "url": "amqp://localhost:5672/%2f",
},
}
}
diff --git a/swh/provenance/graph.py b/swh/provenance/graph.py
--- a/swh/provenance/graph.py
+++ b/swh/provenance/graph.py
@@ -6,7 +6,6 @@
from __future__ import annotations
from datetime import datetime, timezone
-import logging
import os
from typing import Any, Dict, Optional, Set
@@ -187,9 +186,6 @@
root_date = provenance.directory_get_date_in_isochrone_frontier(directory)
root = IsochroneNode(directory, dbdate=root_date)
stack = [root]
- logging.debug(
- f"Recursively creating isochrone graph for revision {revision.id.hex()}..."
- )
fdates: Dict[Sha1Git, datetime] = {} # map {file_id: date}
while stack:
current = stack.pop()
@@ -198,12 +194,6 @@
# is greater or equal to the current revision's one, it should be ignored as
# the revision is being processed out of order.
if current.dbdate is not None and current.dbdate > revision.date:
- logging.debug(
- f"Invalidating frontier on {current.entry.id.hex()}"
- f" (date {current.dbdate})"
- f" when processing revision {revision.id.hex()}"
- f" (date {revision.date})"
- )
current.invalidate()
# Pre-query all known dates for directories in the current directory
@@ -220,12 +210,8 @@
fdates.update(provenance.content_get_early_dates(current.entry.files))
- logging.debug(
- f"Isochrone graph for revision {revision.id.hex()} successfully created!"
- )
# Precalculate max known date for each node in the graph (only directory nodes are
# pushed to the stack).
- logging.debug(f"Computing maxdates for revision {revision.id.hex()}...")
stack = [root]
while stack:
@@ -276,5 +262,4 @@
# node should be treated as unknown
current.maxdate = revision.date
current.known = False
- logging.debug(f"Maxdates for revision {revision.id.hex()} successfully computed!")
return root
diff --git a/swh/provenance/origin.py b/swh/provenance/origin.py
--- a/swh/provenance/origin.py
+++ b/swh/provenance/origin.py
@@ -4,8 +4,6 @@
# See top-level LICENSE file for more information
from itertools import islice
-import logging
-import time
from typing import Generator, Iterable, Iterator, List, Optional, Tuple
from swh.model.model import Sha1Git
@@ -49,21 +47,13 @@
archive: ArchiveInterface,
origins: List[OriginEntry],
) -> None:
- start = time.time()
for origin in origins:
provenance.origin_add(origin)
origin.retrieve_revisions(archive)
for revision in origin.revisions:
graph = HistoryGraph(archive, provenance, revision)
origin_add_revision(provenance, origin, graph)
- done = time.time()
provenance.flush()
- stop = time.time()
- logging.debug(
- "Origins "
- ";".join([origin.id.hex() + ":" + origin.snapshot.hex() for origin in origins])
- + f" were processed in {stop - start} secs (commit took {stop - done} secs)!"
- )
def origin_add_revision(
diff --git a/swh/provenance/postgresql/provenance.py b/swh/provenance/postgresql/provenance.py
--- a/swh/provenance/postgresql/provenance.py
+++ b/swh/provenance/postgresql/provenance.py
@@ -23,6 +23,8 @@
RevisionData,
)
+LOGGER = logging.getLogger(__name__)
+
class ProvenanceStoragePostgreSql:
def __init__(
@@ -91,7 +93,6 @@
try:
if urls:
sql = """
- LOCK TABLE ONLY origin;
INSERT INTO origin(sha1, url) VALUES %s
ON CONFLICT DO NOTHING
"""
@@ -99,7 +100,7 @@
return True
except: # noqa: E722
# Unexpected error occurred, rollback all changes and log message
- logging.exception("Unexpected error")
+ LOGGER.exception("Unexpected error")
if self.raise_on_commit:
raise
return False
@@ -126,7 +127,6 @@
try:
if origins:
sql = """
- LOCK TABLE ONLY revision;
INSERT INTO revision(sha1, origin)
(SELECT V.rev AS sha1, O.id AS origin
FROM (VALUES %s) AS V(rev, org)
@@ -138,7 +138,7 @@
return True
except: # noqa: E722
# Unexpected error occurred, rollback all changes and log message
- logging.exception("Unexpected error")
+ LOGGER.exception("Unexpected error")
if self.raise_on_commit:
raise
return False
@@ -176,7 +176,6 @@
# non-null information
srcs = tuple(set((sha1,) for (sha1, _, _) in rows))
sql = f"""
- LOCK TABLE ONLY {src_table};
INSERT INTO {src_table}(sha1) VALUES %s
ON CONFLICT DO NOTHING
"""
@@ -187,22 +186,19 @@
# non-null information
dsts = tuple(set((sha1,) for (_, sha1, _) in rows))
sql = f"""
- LOCK TABLE ONLY {dst_table};
INSERT INTO {dst_table}(sha1) VALUES %s
ON CONFLICT DO NOTHING
"""
psycopg2.extras.execute_values(self.cursor, sql, dsts)
sql = """
- SELECT * FROM swh_provenance_relation_add(
- %s, %s, %s, %s::rel_row[]
- )
+ SELECT * FROM swh_provenance_relation_add(%s, %s, %s, %s::rel_row[])
"""
self.cursor.execute(sql, (rel_table, src_table, dst_table, rows))
return True
except: # noqa: E722
# Unexpected error occurred, rollback all changes and log message
- logging.exception("Unexpected error")
+ LOGGER.exception("Unexpected error")
if self.raise_on_commit:
raise
return False
@@ -243,7 +239,6 @@
try:
if data:
sql = f"""
- LOCK TABLE ONLY {entity};
INSERT INTO {entity}(sha1, date) VALUES %s
ON CONFLICT (sha1) DO
UPDATE SET date=LEAST(EXCLUDED.date,{entity}.date)
@@ -252,7 +247,7 @@
return True
except: # noqa: E722
# Unexpected error occurred, rollback all changes and log message
- logging.exception("Unexpected error")
+ LOGGER.exception("Unexpected error")
if self.raise_on_commit:
raise
return False
diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py
--- a/swh/provenance/provenance.py
+++ b/swh/provenance/provenance.py
@@ -20,6 +20,8 @@
)
from .model import DirectoryEntry, FileEntry, OriginEntry, RevisionEntry
+LOGGER = logging.getLogger(__name__)
+
class DatetimeCache(TypedDict):
data: Dict[Sha1Git, Optional[datetime]]
@@ -79,41 +81,44 @@
# For this layer, relations need to be inserted first so that, in case of
# failure, reprocessing the input does not generated an inconsistent database.
- while not self.storage.relation_add(
- RelationType.CNT_EARLY_IN_REV,
- (
- RelationData(src=src, dst=dst, path=path)
- for src, dst, path in self.cache["content_in_revision"]
- ),
- ):
- logging.warning(
- f"Unable to write {RelationType.CNT_EARLY_IN_REV} rows to the storage. "
- f"Data: {self.cache['content_in_revision']}. Retrying..."
- )
-
- while not self.storage.relation_add(
- RelationType.CNT_IN_DIR,
- (
- RelationData(src=src, dst=dst, path=path)
- for src, dst, path in self.cache["content_in_directory"]
- ),
- ):
- logging.warning(
- f"Unable to write {RelationType.CNT_IN_DIR} rows to the storage. "
- f"Data: {self.cache['content_in_directory']}. Retrying..."
- )
-
- while not self.storage.relation_add(
- RelationType.DIR_IN_REV,
- (
- RelationData(src=src, dst=dst, path=path)
- for src, dst, path in self.cache["directory_in_revision"]
- ),
- ):
- logging.warning(
- f"Unable to write {RelationType.DIR_IN_REV} rows to the storage. "
- f"Data: {self.cache['directory_in_revision']}. Retrying..."
- )
+ if self.cache["content_in_revision"]:
+ while not self.storage.relation_add(
+ RelationType.CNT_EARLY_IN_REV,
+ (
+ RelationData(src=src, dst=dst, path=path)
+ for src, dst, path in self.cache["content_in_revision"]
+ ),
+ ):
+ LOGGER.warning(
+ "Unable to write %s rows to the storage. Retrying...",
+ RelationType.CNT_EARLY_IN_REV,
+ )
+
+ if self.cache["content_in_directory"]:
+ while not self.storage.relation_add(
+ RelationType.CNT_IN_DIR,
+ (
+ RelationData(src=src, dst=dst, path=path)
+ for src, dst, path in self.cache["content_in_directory"]
+ ),
+ ):
+ LOGGER.warning(
+ "Unable to write %s rows to the storage. Retrying...",
+ RelationType.CNT_IN_DIR,
+ )
+
+ if self.cache["directory_in_revision"]:
+ while not self.storage.relation_add(
+ RelationType.DIR_IN_REV,
+ (
+ RelationData(src=src, dst=dst, path=path)
+ for src, dst, path in self.cache["directory_in_revision"]
+ ),
+ ):
+ LOGGER.warning(
+ "Unable to write %s rows to the storage. Retrying...",
+ RelationType.DIR_IN_REV,
+ )
# After relations, dates for the entities can be safely set, acknowledging that
# these entities won't need to be reprocessed in case of failure.
@@ -122,33 +127,33 @@
for sha1, date in self.cache["content"]["data"].items()
if sha1 in self.cache["content"]["added"] and date is not None
}
- while not self.storage.content_set_date(dates):
- logging.warning(
- f"Unable to write content dates to the storage. "
- f"Data: {dates}. Retrying..."
- )
+ if dates:
+ while not self.storage.content_set_date(dates):
+ LOGGER.warning(
+ "Unable to write content dates to the storage. Retrying..."
+ )
dates = {
sha1: date
for sha1, date in self.cache["directory"]["data"].items()
if sha1 in self.cache["directory"]["added"] and date is not None
}
- while not self.storage.directory_set_date(dates):
- logging.warning(
- f"Unable to write directory dates to the storage. "
- f"Data: {dates}. Retrying..."
- )
+ if dates:
+ while not self.storage.directory_set_date(dates):
+ LOGGER.warning(
+ "Unable to write directory dates to the storage. Retrying..."
+ )
dates = {
sha1: date
for sha1, date in self.cache["revision"]["data"].items()
if sha1 in self.cache["revision"]["added"] and date is not None
}
- while not self.storage.revision_set_date(dates):
- logging.warning(
- f"Unable to write revision dates to the storage. "
- f"Data: {dates}. Retrying..."
- )
+ if dates:
+ while not self.storage.revision_set_date(dates):
+ LOGGER.warning(
+ "Unable to write revision dates to the storage. Retrying..."
+ )
# Origin-revision layer insertions #############################################
@@ -159,11 +164,11 @@
for sha1, url in self.cache["origin"]["data"].items()
if sha1 in self.cache["origin"]["added"]
}
- while not self.storage.origin_set_url(urls):
- logging.warning(
- f"Unable to write origins urls to the storage. "
- f"Data: {urls}. Retrying..."
- )
+ if urls:
+ while not self.storage.origin_set_url(urls):
+ LOGGER.warning(
+                    "Unable to write origin URLs to the storage. Retrying..."
+ )
# Second, flat models for revisions' histories (ie. revision-before-revision).
data: Iterable[RelationData] = sum(
@@ -176,11 +181,12 @@
],
[],
)
- while not self.storage.relation_add(RelationType.REV_BEFORE_REV, data):
- logging.warning(
- f"Unable to write {RelationType.REV_BEFORE_REV} rows to the storage. "
- f"Data: {data}. Retrying..."
- )
+ if data:
+ while not self.storage.relation_add(RelationType.REV_BEFORE_REV, data):
+ LOGGER.warning(
+ "Unable to write %s rows to the storage. Retrying...",
+ RelationType.REV_BEFORE_REV,
+ )
    # Heads (ie. revision-in-origin entries) should be inserted only after the flat
    # models for their histories have been added. This is to guarantee consistent results if
@@ -190,11 +196,12 @@
RelationData(src=rev, dst=org, path=None)
for rev, org in self.cache["revision_in_origin"]
)
- while not self.storage.relation_add(RelationType.REV_IN_ORG, data):
- logging.warning(
- f"Unable to write {RelationType.REV_IN_ORG} rows to the storage. "
- f"Data: {data}. Retrying..."
- )
+ if data:
+ while not self.storage.relation_add(RelationType.REV_IN_ORG, data):
+ LOGGER.warning(
+ "Unable to write %s rows to the storage. Retrying...",
+ RelationType.REV_IN_ORG,
+ )
# Finally, preferred origins for the visited revisions are set (this step can be
# reordered if required).
@@ -202,11 +209,11 @@
sha1: self.cache["revision_origin"]["data"][sha1]
for sha1 in self.cache["revision_origin"]["added"]
}
- while not self.storage.revision_set_origin(origins):
- logging.warning(
- f"Unable to write preferred origins to the storage. "
- f"Data: {origins}. Retrying..."
- )
+ if origins:
+ while not self.storage.revision_set_origin(origins):
+ LOGGER.warning(
+ "Unable to write preferred origins to the storage. Retrying..."
+ )
# clear local cache ############################################################
self.clear_caches()
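A minimal sketch of the flush pattern this hunk converges on, with
hypothetical storage and argument names: empty batches are skipped (avoiding
a pointless round trip), a non-empty batch is retried until the storage
acknowledges it, and warnings use lazy %-style formatting so that, unlike the
removed f-strings, the message is only built when the record is emitted. Note
that in the hunk above the generator expression lives inside the while
condition and is therefore re-created on every retry; taking a list here
sidesteps that subtlety.

    import logging
    from typing import List

    LOGGER = logging.getLogger(__name__)

    def flush_relation(storage, relation_type, rows: List) -> None:
        # Skip empty batches entirely; the storage is not contacted at all.
        if not rows:
            return
        # Retry until the storage reports success. Only the relation type is
        # interpolated into the message, never the (potentially huge) batch.
        while not storage.relation_add(relation_type, rows):
            LOGGER.warning(
                "Unable to write %s rows to the storage. Retrying...",
                relation_type,
            )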
diff --git a/swh/provenance/revision.py b/swh/provenance/revision.py
--- a/swh/provenance/revision.py
+++ b/swh/provenance/revision.py
@@ -4,9 +4,7 @@
# See top-level LICENSE file for more information
from datetime import datetime, timezone
-import logging
import os
-import time
from typing import Generator, Iterable, Iterator, List, Optional, Tuple
from swh.model.model import Sha1Git
@@ -59,17 +57,12 @@
mindepth: int = 1,
commit: bool = True,
) -> None:
- start = time.time()
for revision in revisions:
assert revision.date is not None
assert revision.root is not None
# Processed content starting from the revision's root directory.
date = provenance.revision_get_date(revision)
if date is None or revision.date < date:
- logging.debug(
- f"Processing revisions {revision.id.hex()}"
- f" (known date {date} / revision date {revision.date})..."
- )
graph = build_isochrone_graph(
archive,
provenance,
@@ -86,14 +79,8 @@
lower=lower,
mindepth=mindepth,
)
- done = time.time()
if commit:
provenance.flush()
- stop = time.time()
- logging.debug(
- f"Revisions {';'.join([revision.id.hex() for revision in revisions])} "
- f" were processed in {stop - start} secs (commit took {stop - done} secs)!"
- )
def revision_process_content(
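The hunk above removes the timing instrumentation together with its eager
f-string formatting rather than converting it. Were such timing wanted again,
a sketch along these lines (function and parameter names are hypothetical)
would keep the costly hex-encoding and join behind an isEnabledFor guard:

    import logging
    import time

    LOGGER = logging.getLogger(__name__)

    def process_with_timing(revisions, process, flush) -> None:
        start = time.time()
        process(revisions)
        done = time.time()
        flush()
        stop = time.time()
        # Hex-encoding and joining the ids is only paid when DEBUG is on.
        if LOGGER.isEnabledFor(logging.DEBUG):
            LOGGER.debug(
                "Revisions %s were processed in %s secs (commit took %s secs)",
                ";".join(revision.id.hex() for revision in revisions),
                stop - start,
                stop - done,
            )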
diff --git a/swh/provenance/sql/40-funcs.sql b/swh/provenance/sql/40-funcs.sql
--- a/swh/provenance/sql/40-funcs.sql
+++ b/swh/provenance/sql/40-funcs.sql
@@ -89,7 +89,6 @@
join_location text;
begin
if src_table in ('content'::regclass, 'directory'::regclass) then
- lock table only location;
insert into location(path)
select V.path
from unnest(rel_data) as V
@@ -103,15 +102,14 @@
end if;
execute format(
- 'lock table only %s;
- insert into %s
+ 'insert into %s
select S.id, ' || select_fields || '
from unnest($1) as V
inner join %s as S on (S.sha1 = V.src)
inner join %s as D on (D.sha1 = V.dst)
' || join_location || '
on conflict do nothing',
- rel_table, rel_table, src_table, dst_table
+ rel_table, src_table, dst_table
) using rel_data;
end;
$$;
@@ -244,14 +242,13 @@
as $$
begin
execute format(
- 'lock table only %s;
- insert into %s
+ 'insert into %s
select S.id, D.id
from unnest($1) as V
inner join %s as S on (S.sha1 = V.src)
inner join %s as D on (D.sha1 = V.dst)
on conflict do nothing',
- rel_table, rel_table, src_table, dst_table
+ rel_table, src_table, dst_table
) using rel_data;
end;
$$;
@@ -412,7 +409,6 @@
on_conflict text;
begin
if src_table in ('content'::regclass, 'directory'::regclass) then
- lock table only location;
insert into location(path)
select V.path
from unnest(rel_data) as V
@@ -438,8 +434,7 @@
end if;
execute format(
- 'lock table only %s;
- insert into %s
+ 'insert into %s
select S.id, ' || select_fields || '
from unnest($1) as V
inner join %s as S on (S.sha1 = V.src)
@@ -447,7 +442,7 @@
' || join_location || '
' || group_entries || '
on conflict ' || on_conflict,
- rel_table, rel_table, src_table, dst_table
+ rel_table, src_table, dst_table
) using rel_data;
end;
$$;
@@ -631,15 +626,14 @@
end if;
execute format(
- 'lock table only %s;
- insert into %s
+ 'insert into %s
select S.id, ' || select_fields || '
from unnest($1) as V
inner join %s as S on (S.sha1 = V.src)
inner join %s as D on (D.sha1 = V.dst)
' || group_entries || '
on conflict ' || on_conflict,
- rel_table, rel_table, src_table, dst_table
+ rel_table, src_table, dst_table
) using rel_data;
end;
$$;
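Dropping the "lock table only" statements leaves ON CONFLICT DO NOTHING as
the sole guard against concurrent writers: inserts become idempotent rather
than serialized behind a table lock. A caller-side sketch of the resulting
retry discipline, with hypothetical table and column names (psycopg2 is
assumed to be available; psycopg2.errors.DeadlockDetected is a subclass of
OperationalError), mirroring the retry loops added to provenance.py above:

    import logging
    import time
    from typing import Sequence

    import psycopg2

    LOGGER = logging.getLogger(__name__)

    def insert_relation_rows(conn, rows: Sequence[tuple], wait: float = 1.0) -> None:
        while True:
            try:
                with conn:  # commit on success, roll back on error
                    with conn.cursor() as cursor:
                        # Idempotent insert: rows already written by a
                        # concurrent batch are silently skipped, so no table
                        # lock is required for retries to stay consistent.
                        cursor.executemany(
                            "insert into content_in_revision"
                            " (content, revision, location)"
                            " values (%s, %s, %s)"
                            " on conflict do nothing",
                            rows,
                        )
                return
            except psycopg2.OperationalError:
                # e.g. a deadlock between concurrent writers: the failed
                # transaction has been rolled back; wait briefly and retry.
                LOGGER.warning("Transient error while inserting rows. Retrying...")
                time.sleep(wait)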
diff --git a/swh/provenance/tests/conftest.py b/swh/provenance/tests/conftest.py
--- a/swh/provenance/tests/conftest.py
+++ b/swh/provenance/tests/conftest.py
@@ -5,7 +5,7 @@
from datetime import datetime, timedelta, timezone
from os import path
-from typing import Any, Dict, Iterable, Iterator
+from typing import Any, Dict, Iterable
from _pytest.fixtures import SubRequest
import msgpack
@@ -16,8 +16,6 @@
from swh.journal.serializers import msgpack_ext_hook
from swh.provenance import get_provenance, get_provenance_storage
-from swh.provenance.api.client import RemoteProvenanceStorage
-import swh.provenance.api.server as server
from swh.provenance.archive import ArchiveInterface
from swh.provenance.interface import ProvenanceInterface, ProvenanceStorageInterface
from swh.provenance.storage.archive import ArchiveStorage
@@ -46,38 +44,15 @@
return postgresql.get_dsn_parameters()
-# the Flask app used as server in these tests
-@pytest.fixture
-def app(
- provenance_postgresqldb: Dict[str, str]
-) -> Iterator[server.ProvenanceStorageServerApp]:
- assert hasattr(server, "storage")
- server.storage = get_provenance_storage(
- cls="postgresql", db=provenance_postgresqldb
- )
- yield server.app
-
-
-# the RPCClient class used as client used in these tests
-@pytest.fixture
-def swh_rpc_client_class() -> type:
- return RemoteProvenanceStorage
-
-
-@pytest.fixture(params=["mongodb", "postgresql", "remote"])
+@pytest.fixture(params=["mongodb", "postgresql"])
def provenance_storage(
request: SubRequest,
provenance_postgresqldb: Dict[str, str],
mongodb: pymongo.database.Database,
- swh_rpc_client: RemoteProvenanceStorage,
) -> ProvenanceStorageInterface:
"""Return a working and initialized ProvenanceStorageInterface object"""
- if request.param == "remote":
- assert isinstance(swh_rpc_client, ProvenanceStorageInterface)
- return swh_rpc_client
-
- elif request.param == "mongodb":
+ if request.param == "mongodb":
from swh.provenance.mongo.backend import ProvenanceStorageMongoDb
return ProvenanceStorageMongoDb(mongodb)
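The tail of this hunk is truncated by the diff context, but the removed app
fixture above built the postgresql backend with
get_provenance_storage(cls="postgresql", db=provenance_postgresqldb), so a
plausible shape for the slimmed-down fixture (a sketch, not the literal patch
content) is:

    import pytest

    from swh.provenance import get_provenance_storage

    @pytest.fixture(params=["mongodb", "postgresql"])
    def provenance_storage(request, provenance_postgresqldb, mongodb):
        """Return a working and initialized ProvenanceStorageInterface object"""
        if request.param == "mongodb":
            from swh.provenance.mongo.backend import ProvenanceStorageMongoDb

            return ProvenanceStorageMongoDb(mongodb)
        # Assumed to mirror the removed app fixture above.
        return get_provenance_storage(cls="postgresql", db=provenance_postgresqldb)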