diff --git a/mypy.ini b/mypy.ini
--- a/mypy.ini
+++ b/mypy.ini
@@ -17,6 +17,9 @@
 [mypy-msgpack.*]
 ignore_missing_imports = True
 
+[mypy-pika.*]
+ignore_missing_imports = True
+
 [mypy-pkg_resources.*]
 ignore_missing_imports = True
 
diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py
--- a/swh/provenance/__init__.py
+++ b/swh/provenance/__init__.py
@@ -96,12 +96,12 @@
         db = MongoClient(**kwargs["db"]).get_database(dbname)
         return ProvenanceStorageMongoDb(db)
 
-    elif cls == "remote":
-        from .api.client import RemoteProvenanceStorage
+    elif cls == "rabbitmq":
+        from .api.client import ProvenanceStorageRabbitMQClient
 
-        storage = RemoteProvenanceStorage(**kwargs)
-        assert isinstance(storage, ProvenanceStorageInterface)
-        return storage
+        rmq_storage = ProvenanceStorageRabbitMQClient(**kwargs)
+        if TYPE_CHECKING:
+            assert isinstance(rmq_storage, ProvenanceStorageInterface)
+        return rmq_storage
 
-    else:
-        raise ValueError
+    raise ValueError
diff --git a/swh/provenance/api/client.py b/swh/provenance/api/client.py
--- a/swh/provenance/api/client.py
+++ b/swh/provenance/api/client.py
@@ -3,15 +3,546 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-from swh.core.api import RPCClient
+import functools
+import inspect
+import logging
+import queue
+import threading
+import time
+from typing import Any, Dict
+import uuid
+
+import pika
+import pika.channel
+import pika.connection
+import pika.frame
+import pika.spec
+
+from swh.core.api.serializers import encode_data_client as encode_data
+from swh.core.api.serializers import msgpack_loads as decode_data
+from swh.provenance import get_provenance_storage
 
 from ..interface import ProvenanceStorageInterface
 from .serializers import DECODERS, ENCODERS
+from .server import ProvenanceStorageRabbitMQServer
+
+LOG_FORMAT = (
+    "%(levelname) -10s %(asctime)s %(name) -30s %(funcName) "
+    "-35s %(lineno) -5d: %(message)s"
+)
+LOGGER = logging.getLogger(__name__)
+
+
+class ConfigurationError(Exception):
+    pass
+
+
+class ResponseTimeout(Exception):
+    pass
+
+
+class MetaRabbitMQClient(type):
+    def __new__(cls, name, bases, attributes):
+        # For each method wrapped with @remote_api_endpoint in the backend class
+        # (eg. ProvenanceStorageInterface), add a method with the same signature
+        # and documentation that either publishes to RabbitMQ or forwards the
+        # call to the local storage (see __add_endpoint).
+        #
+        # Note that this never calls the (abstract) backend_class method itself.
+        backend_class = attributes.get("backend_class", None)
+        for base in bases:
+            if backend_class is not None:
+                break
+            backend_class = getattr(base, "backend_class", None)
+        if backend_class:
+            for meth_name, meth in backend_class.__dict__.items():
+                if hasattr(meth, "_endpoint_path"):
+                    cls.__add_endpoint(meth_name, meth, attributes)
+        return super().__new__(cls, name, bases, attributes)
+
+    @staticmethod
+    def __add_endpoint(meth_name, meth, attributes):
+        wrapped_meth = inspect.unwrap(meth)
+
+        @functools.wraps(meth)  # Copy signature and doc
+        def meth_(*args, **kwargs):
+            # Match arguments and parameters
+            data = inspect.getcallargs(wrapped_meth, *args, **kwargs)
 
+            # Remove arguments that should not be passed
+            self = data.pop("self")
 
-class RemoteProvenanceStorage(RPCClient):
-    """Proxy to a remote provenance storage API"""
+            # Call storage method with remaining arguments
+            return getattr(self._storage, meth_name)(**data)
+
+        @functools.wraps(meth)  # Copy signature and doc
+        def write_meth_(*args, **kwargs):
+            # Match arguments and parameters
+            post_data = inspect.getcallargs(wrapped_meth, *args, **kwargs)
+
+            # Remove arguments that should not be passed
+            self = post_data.pop("self")
+            relation = post_data.pop("relation", None)
+            assert len(post_data) == 1
+            if relation is not None:
+                data = [
+                    (row.src, row.dst, row.path)
+                    for row in next(iter(post_data.values()))
+                ]
+            else:
+                data = list(next(iter(post_data.values())).items())
+
+            acks_expected = len(data)
+            self._correlation_id = str(uuid.uuid4())
+            exchange = ProvenanceStorageRabbitMQServer.get_exchange(meth_name, relation)
+            for item in data:
+                # FIXME: this is running in a different thread! Hence, if
+                # self._connection drops, there is no guarantee that the request can
+                # be sent for the current elements. This situation should be handled
+                # properly.
+                self._connection.ioloop.add_callback_threadsafe(
+                    functools.partial(
+                        ProvenanceStorageRabbitMQClient.request,
+                        channel=self._channel,
+                        reply_to=self._callback_queue,
+                        exchange=exchange,
+                        routing_key=ProvenanceStorageRabbitMQServer.get_routing_key(
+                            item
+                        ),
+                        correlation_id=self._correlation_id,
+                        data=item,
+                    )
+                )
+            return self.wait_for_acks(acks_expected)
+
+        if meth_name not in attributes:
+            attributes[meth_name] = (
+                write_meth_
+                if ProvenanceStorageRabbitMQServer.is_write_method(meth_name)
+                else meth_
+            )
+
+
+class ProvenanceStorageRabbitMQClient(metaclass=MetaRabbitMQClient):
+    """RabbitMQ-backed client for ``ProvenanceStorageInterface``.
+
+    Methods flagged by ``ProvenanceStorageRabbitMQServer.is_write_method``
+    are published to RabbitMQ, one message per item, and the call then waits
+    (see ``wait_for_acks``) for the expected number of acknowledgements on an
+    exclusive callback queue. All other methods are delegated to a local
+    ``ProvenanceStorage`` object built from ``storage_config``.
+
+    A background thread runs the pika IOLoop; if the connection to RabbitMQ
+    is closed unexpectedly, it is reopened after 5 seconds.
+
+    """
 
     backend_class = ProvenanceStorageInterface
     extra_type_decoders = DECODERS
     extra_type_encoders = ENCODERS
+
+    def __init__(self, url: str, storage_config: Dict[str, Any]) -> None:
+        """Set up the client object, and start the background thread that
+        connects to RabbitMQ and consumes responses on a callback queue.
+
+        :param str url: The URL for connecting to RabbitMQ
+        :param dict storage_config: Configuration parameters for the underlying
+            ``ProvenanceStorage`` object; it is built right away (in the calling
+            thread) and serves every method that is not published to RabbitMQ
+            (see ``ProvenanceStorageRabbitMQServer.is_write_method``)
+
+        """
+        self._connection = None
+        self._channel = None
+        self._closing = False
+        self._consumer_tag = None
+        self._consuming = False
+        self._prefetch_count = 100
+
+        self._url = url
+        self._storage = get_provenance_storage(**storage_config)
+        self._response_queue: queue.Queue = queue.Queue()
+        self._correlation_id = str(uuid.uuid4())
+
+        self._consumer_thread = threading.Thread(target=self.run)
+        self._consumer_thread.start()
+
+    def connect(self) -> pika.SelectConnection:
+        """This method connects to RabbitMQ, returning the connection handle.
+        When the connection is established, the on_connection_open method
+        will be invoked by pika.
+
+        :rtype: pika.SelectConnection
+
+        """
+        LOGGER.info("Connecting to %s", self._url)
+        return pika.SelectConnection(
+            parameters=pika.URLParameters(self._url),
+            on_open_callback=self.on_connection_open,
+            on_open_error_callback=self.on_connection_open_error,
+            on_close_callback=self.on_connection_closed,
+        )
+
+    def close_connection(self) -> None:
+        assert self._connection is not None
+        self._consuming = False
+        if self._connection.is_closing or self._connection.is_closed:
+            LOGGER.info("Connection is closing or already closed")
+        else:
+            LOGGER.info("Closing connection")
+            self._connection.close()
+
+    def on_connection_open(self, _unused_connection: pika.SelectConnection) -> None:
+        """This method is called by pika once the connection to RabbitMQ has
+        been established. It passes the handle to the connection object in
+        case we need it, but in this case, we'll just mark it unused.
+
+        :param pika.SelectConnection _unused_connection: The connection
+
+        """
+        LOGGER.info("Connection opened")
+        self.open_channel()
+
+    def on_connection_open_error(
+        self, _unused_connection: pika.SelectConnection, err: Exception
+    ) -> None:
+        """This method is called by pika if the connection to RabbitMQ
+        can't be established.
+
+        :param pika.SelectConnection _unused_connection: The connection
+        :param Exception err: The error
+
+        """
+        LOGGER.error("Connection open failed, reopening in 5 seconds: %s", err)
+        assert self._connection is not None
+        self._connection.ioloop.call_later(5, self._connection.ioloop.stop)
+
+    def on_connection_closed(self, _unused_connection: pika.SelectConnection, reason):
+        """This method is invoked by pika when the connection to RabbitMQ is
+        closed unexpectedly. Since it is unexpected, we will reconnect to
+        RabbitMQ if it disconnects.
+
+        :param pika.connection.Connection connection: The closed connection obj
+        :param Exception reason: exception representing reason for loss of
+            connection.
+
+        """
+        assert self._connection is not None
+        self._channel = None
+        if self._closing:
+            self._connection.ioloop.stop()
+        else:
+            LOGGER.warning("Connection closed, reopening in 5 seconds: %s", reason)
+            self._connection.ioloop.call_later(5, self._connection.ioloop.stop)
+
+    def open_channel(self) -> None:
+        """Open a new channel with RabbitMQ by issuing the Channel.Open RPC
+        command. When RabbitMQ responds that the channel is open, the
+        on_channel_open callback will be invoked by pika.
+
+        """
+        LOGGER.info("Creating a new channel")
+        assert self._connection is not None
+        self._connection.channel(on_open_callback=self.on_channel_open)
+
+    def on_channel_open(self, channel: pika.channel.Channel) -> None:
+        """This method is invoked by pika when the channel has been opened.
+        The channel object is passed in so we can make use of it.
+
+        Since the channel is now open, we'll declare the callback queue.
+
+        :param pika.channel.Channel channel: The channel object
+
+        """
+        LOGGER.info("Channel opened")
+        self._channel = channel
+        self.add_on_channel_close_callback()
+        self.setup_queue()
+
+    def add_on_channel_close_callback(self) -> None:
+        """This method tells pika to call the on_channel_closed method if
+        RabbitMQ unexpectedly closes the channel.
+
+        """
+        LOGGER.info("Adding channel close callback")
+        assert self._channel is not None
+        self._channel.add_on_close_callback(callback=self.on_channel_closed)
+
+    def on_channel_closed(
+        self, channel: pika.channel.Channel, reason: Exception
+    ) -> None:
+        """Invoked by pika when RabbitMQ unexpectedly closes the channel.
+        Channels are usually closed if you attempt to do something that
+        violates the protocol, such as re-declare an exchange or queue with
+        different parameters. In this case, we'll close the connection
+        to shutdown the object.
+
+        :param pika.channel.Channel: The closed channel
+        :param Exception reason: why the channel was closed
+
+        """
+        LOGGER.warning("Channel %i was closed: %s", channel, reason)
+        self.close_connection()
+
+    def setup_queue(self) -> None:
+        """Setup the queue on RabbitMQ by invoking the Queue.Declare RPC
+        command. When it is complete, the on_queue_declareok method will
+        be invoked by pika.
+
+        """
+        LOGGER.info("Declaring callback queue")
+        assert self._channel is not None
+        self._channel.queue_declare(
+            queue="", exclusive=True, callback=self.on_queue_declareok
+        )
+
+    def on_queue_declareok(self, frame: pika.frame.Method) -> None:
+        """Method invoked by pika when the Queue.Declare RPC call made in
+        setup_queue has completed. This method records the name RabbitMQ
+        generated for the exclusive callback queue, then sets the consumer
+        prefetch so that at most ``self._prefetch_count`` unacknowledged
+        messages are delivered at a time. Consuming starts once RabbitMQ
+        confirms the QoS setting (see on_basic_qos_ok).
+
+        :param pika.frame.Method frame: The Queue.DeclareOk frame
+
+        """
+        LOGGER.info("Binding queue to default exchanger")
+        assert self._channel is not None
+        self._callback_queue = frame.method.queue
+        self._channel.basic_qos(
+            prefetch_count=self._prefetch_count, callback=self.on_basic_qos_ok
+        )
+
+    def on_basic_qos_ok(self, _unused_frame: pika.frame.Method) -> None:
+        """Invoked by pika when the Basic.QoS method has completed. At this
+        point we will start consuming messages by calling start_consuming
+        which will invoke the needed RPC commands to start the process.
+
+        :param pika.frame.Method _unused_frame: The Basic.QosOk response frame
+
+        """
+        LOGGER.info("QOS set to: %d", self._prefetch_count)
+        self.start_consuming()
+
+    def start_consuming(self) -> None:
+        """This method sets up the consumer by first calling
+        add_on_cancel_callback so that the object is notified if RabbitMQ
+        cancels the consumer. It then issues the Basic.Consume RPC command
+        which returns the consumer tag that is used to uniquely identify the
+        consumer with RabbitMQ. We keep the value to use it when we want to
+        cancel consuming. The on_response method is passed in as a callback pika
+        will invoke when a message is fully received.
+
+        """
+        LOGGER.info("Issuing consumer related RPC commands")
+        assert self._channel is not None
+        self.add_on_cancel_callback()
+
+        self._consumer_tag = self._channel.basic_consume(
+            queue=self._callback_queue, on_message_callback=self.on_response
+        )
+        self._consuming = True
+
+    def add_on_cancel_callback(self) -> None:
+        """Add a callback that will be invoked if RabbitMQ cancels the consumer
+        for some reason. If RabbitMQ does cancel the consumer,
+        on_consumer_cancelled will be invoked by pika.
+
+        """
+        LOGGER.info("Adding consumer cancellation callback")
+        assert self._channel is not None
+        self._channel.add_on_cancel_callback(callback=self.on_consumer_cancelled)
+
+    def on_consumer_cancelled(self, method_frame: pika.frame.Method) -> None:
+        """Invoked by pika when RabbitMQ sends a Basic.Cancel for a consumer
+        receiving messages.
+
+        :param pika.frame.Method method_frame: The Basic.Cancel frame
+
+        """
+        LOGGER.info("Consumer was cancelled remotely, shutting down: %r", method_frame)
+        if self._channel:
+            self._channel.close()
+
+    def on_response(
+        self,
+        _unused_channel: pika.channel.Channel,
+        basic_deliver: pika.spec.Basic.Deliver,
+        properties: pika.spec.BasicProperties,
+        body: bytes,
+    ) -> None:
+        """Invoked by pika when a message is delivered from RabbitMQ. The
+        channel is passed for your convenience. The basic_deliver object that
+        is passed in carries the exchange, routing key, delivery tag and
+        a redelivered flag for the message. The properties passed in is an
+        instance of BasicProperties with the message properties and the body
+        is the message that was sent.
+
+        :param pika.channel.Channel _unused_channel: The channel object
+        :param pika.spec.Basic.Deliver: basic_deliver method
+        :param pika.spec.BasicProperties: properties
+        :param bytes body: The message body
+
+        """
+        LOGGER.info(
+            "Received message # %s from %s: %s",
+            basic_deliver.delivery_tag,
+            properties.app_id,
+            body,
+        )
+        self._response_queue.put(
+            (
+                properties.correlation_id,
+                decode_data(body, extra_decoders=self.extra_type_decoders),
+            )
+        )
+        self.acknowledge_message(delivery_tag=basic_deliver.delivery_tag)
+
+    def acknowledge_message(self, delivery_tag: int) -> None:
+        """Acknowledge the message delivery from RabbitMQ by sending a
+        Basic.Ack RPC method for the delivery tag.
+
+        :param int delivery_tag: The delivery tag from the Basic.Deliver frame
+
+        """
+        LOGGER.info("Acknowledging message %s", delivery_tag)
+        assert self._channel is not None
+        self._channel.basic_ack(delivery_tag=delivery_tag)
+
+    def stop_consuming(self) -> None:
+        """Tell RabbitMQ that you would like to stop consuming by sending the
+        Basic.Cancel RPC command.
+
+        """
+        if self._channel:
+            LOGGER.info("Sending a Basic.Cancel RPC command to RabbitMQ")
+            self._channel.basic_cancel(self._consumer_tag, self.on_cancelok)
+
+    def on_cancelok(self, _unused_frame: pika.frame.Method) -> None:
+        """This method is invoked by pika when RabbitMQ acknowledges the
+        cancellation of a consumer. At this point we will close the channel.
+        This will invoke the on_channel_closed method once the channel has been
+        closed, which will in-turn close the connection.
+
+        :param pika.frame.Method _unused_frame: The Basic.CancelOk frame
+            (the cancelled consumer is the one tagged ``self._consumer_tag``)
+
+        """
+        self._consuming = False
+        LOGGER.info(
+            "RabbitMQ acknowledged the cancellation of the consumer: %s",
+            self._consumer_tag,
+        )
+        self.close_channel()
+
+    def close_channel(self) -> None:
+        """Call to close the channel with RabbitMQ cleanly by issuing the
+        Channel.Close RPC command.
+
+        """
+        LOGGER.info("Closing the channel")
+        assert self._channel is not None
+        self._channel.close()
+
+    def run(self) -> None:
+        """Connect to RabbitMQ and run the IOLoop, reconnecting until stopped."""
+
+        while not self._closing:
+            try:
+                self._connection = self.connect()
+                assert self._connection is not None
+                self._connection.ioloop.start()
+            except KeyboardInterrupt:
+                self.stop()
+                if self._connection is not None and not self._connection.is_closed:
+                    # Finish closing
+                    self._connection.ioloop.start()
+        LOGGER.info("Stopped")
+
+    def stop(self) -> None:
+        """Cleanly shutdown the connection to RabbitMQ by stopping the consumer
+        with RabbitMQ. When RabbitMQ confirms the cancellation, on_cancelok
+        will be invoked by pika, which will then close the channel and
+        connection. The IOLoop is started again because this method is invoked
+        when CTRL-C is pressed raising a KeyboardInterrupt exception. This
+        exception stops the IOLoop which needs to be running for pika to
+        communicate with RabbitMQ. All of the commands issued prior to starting
+        the IOLoop will be buffered but not processed.
+
+        """
+        assert self._connection is not None
+        if not self._closing:
+            self._closing = True
+            LOGGER.info("Stopping")
+            if self._consuming:
+                self.stop_consuming()
+                self._connection.ioloop.start()
+            else:
+                self._connection.ioloop.stop()
+            LOGGER.info("Stopped")
+
+    @staticmethod
+    def request(
+        channel: pika.channel.Channel,
+        reply_to: str,
+        exchange: str,
+        routing_key: str,
+        correlation_id: str,
+        **kwargs,
+    ) -> None:
+        channel.basic_publish(
+            exchange=exchange,
+            routing_key=routing_key,
+            properties=pika.BasicProperties(
+                content_type="application/msgpack",
+                correlation_id=correlation_id,
+                reply_to=reply_to,
+            ),
+            body=encode_data(
+                kwargs,
+                extra_encoders=ProvenanceStorageRabbitMQClient.extra_type_encoders,
+            ),
+        )
+
+    def wait_for_acks(self, acks_expected: int) -> bool:
+        """Wait until ``acks_expected`` acknowledgements have been counted.
+
+        Each matching response is added to the running count, so it is
+        presumably the number of items acknowledged by a server worker.
+        Returns False on timeout, otherwise whether the count is exact.
+        """
+        acks_received = 0
+        while acks_received < acks_expected:
+            try:
+                acks_received += self.wait_for_response()
+            except ResponseTimeout:
+                LOGGER.warning(
+                    "Timed out waiting for acks, %s received, %s expected",
+                    acks_received,
+                    acks_expected,
+                )
+                return False
+        return acks_received == acks_expected
+
+    def wait_for_response(self, timeout: float = 60.0) -> Any:
+        """Return the next response matching the current correlation id.
+
+        Responses carrying another correlation id are discarded.
+
+        :raises ResponseTimeout: if no matching response is received within
+            ``timeout`` seconds
+        """
+        start = time.monotonic()
+        while True:
+            try:
+                # Block briefly rather than spinning on a non-blocking get(),
+                # which kept a CPU core busy for the whole wait.
+                correlation_id, response = self._response_queue.get(timeout=0.1)
+                if correlation_id == self._correlation_id:
+                    return response
+            except queue.Empty:
+                pass
+
+            if self._response_queue.empty() and time.monotonic() > start + timeout:
+                raise ResponseTimeout
diff --git a/swh/provenance/api/server.py
b/swh/provenance/api/server.py --- a/swh/provenance/api/server.py +++ b/swh/provenance/api/server.py @@ -3,79 +3,645 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from collections import Counter +from datetime import datetime +import functools import logging +import multiprocessing import os -from typing import Any, Dict, List, Optional - -from werkzeug.routing import Rule +import queue +import threading +import time +from typing import Any, Callable +from typing import Counter as TCounter +from typing import Dict, Generator, List, Optional, Tuple + +import pika +import pika.channel +import pika.connection +import pika.exceptions +from pika.exchange_type import ExchangeType +import pika.frame +import pika.spec from swh.core import config -from swh.core.api import JSONFormatter, MsgpackFormatter, RPCServerApp, negotiate +from swh.core.api.serializers import encode_data_client as encode_data +from swh.core.api.serializers import msgpack_loads as decode_data +from swh.model.model import Sha1Git from swh.provenance import get_provenance_storage -from swh.provenance.interface import ProvenanceStorageInterface +from swh.provenance.interface import ( + ProvenanceStorageInterface, + RelationData, + RelationType, +) from .serializers import DECODERS, ENCODERS -storage: Optional[ProvenanceStorageInterface] = None +LOG_FORMAT = ( + "%(levelname) -10s %(asctime)s %(name) -30s %(funcName) " + "-35s %(lineno) -5d: %(message)s" +) +LOGGER = logging.getLogger(__name__) +TERMINATE = object() -def get_global_provenance_storage() -> ProvenanceStorageInterface: - global storage - if storage is None: - storage = get_provenance_storage(**app.config["provenance"]["storage"]) - return storage +def resolve_dates(dates: List[Tuple[Sha1Git, datetime]]) -> Dict[Sha1Git, datetime]: + result: Dict[Sha1Git, datetime] = {} + for sha1, date in dates: + if sha1 in result and result[sha1] < date: + continue + else: + 
result[sha1] = date + return result -class ProvenanceStorageServerApp(RPCServerApp): - extra_type_decoders = DECODERS - extra_type_encoders = ENCODERS +class ProvenanceStorageRabbitMQWorker(multiprocessing.Process): + """This is an example publisher that will handle unexpected interactions + with RabbitMQ such as channel and connection closures. -app = ProvenanceStorageServerApp( - __name__, - backend_class=ProvenanceStorageInterface, - backend_factory=get_global_provenance_storage, -) + If RabbitMQ closes the connection, it will reopen it. You should + look at the output, as there are limited reasons why the connection may + be closed, which usually are tied to permission related issues or + socket timeouts. + It uses delivery confirmations and illustrates one way to keep track of + messages that have been sent and if they've been confirmed by RabbitMQ. -def has_no_empty_params(rule: Rule) -> bool: - return len(rule.defaults or ()) >= len(rule.arguments or ()) - - -@app.route("/") -def index() -> str: - return """ -Software Heritage provenance storage RPC server - -

You have reached the -Software Heritage -provenance storage RPC server.
-See its -documentation -and API for more information

- -""" - - -@app.route("/site-map") -@negotiate(MsgpackFormatter) -@negotiate(JSONFormatter) -def site_map() -> List[Dict[str, Any]]: - links = [] - for rule in app.url_map.iter_rules(): - if has_no_empty_params(rule) and hasattr( - ProvenanceStorageInterface, rule.endpoint - ): - links.append( - dict( - rule=rule.rule, - description=getattr( - ProvenanceStorageInterface, rule.endpoint - ).__doc__, + """ + + EXCHANGE_TYPE = ExchangeType.direct + + extra_type_decoders = DECODERS + extra_type_encoders = ENCODERS + + def __init__( + self, url: str, exchange: str, routing_key: str, storage_config: Dict[str, Any] + ) -> None: + """Setup the example publisher object, passing in the URL we will use + to connect to RabbitMQ. + + :param str url: The URL for connecting to RabbitMQ + :param str routing_key: The routing key name from which this worker will + consume messages + :param str storage_config: Configuration parameters for the underlying + ``ProvenanceStorage`` object + + """ + name = f"{exchange}_{routing_key}" + super().__init__(name=name) + + self._connection = None + self._channel = None + self._closing = False + self._consumer_tag = None + self._consuming = False + self._prefetch_count = 100 + + self._url = url + self._exchange = exchange + self._routing_key = routing_key + self._queue = name + self._storage_config = storage_config + self._batch_size = 100 + + def connect(self) -> pika.SelectConnection: + """This method connects to RabbitMQ, returning the connection handle. + When the connection is established, the on_connection_open method + will be invoked by pika. 
+ + :rtype: pika.SelectConnection + + """ + LOGGER.info("Connecting to %s", self._url) + return pika.SelectConnection( + parameters=pika.URLParameters(self._url), + on_open_callback=self.on_connection_open, + on_open_error_callback=self.on_connection_open_error, + on_close_callback=self.on_connection_closed, + ) + + def close_connection(self) -> None: + assert self._connection is not None + self._consuming = False + if self._connection.is_closing or self._connection.is_closed: + LOGGER.info("Connection is closing or already closed") + else: + LOGGER.info("Closing connection") + self._connection.close() + + def on_connection_open(self, _unused_connection: pika.SelectConnection) -> None: + """This method is called by pika once the connection to RabbitMQ has + been established. It passes the handle to the connection object in + case we need it, but in this case, we'll just mark it unused. + + :param pika.SelectConnection _unused_connection: The connection + + """ + LOGGER.info("Connection opened") + self.open_channel() + + def on_connection_open_error( + self, _unused_connection: pika.SelectConnection, err: Exception + ) -> None: + """This method is called by pika if the connection to RabbitMQ + can't be established. + + :param pika.SelectConnection _unused_connection: The connection + :param Exception err: The error + + """ + LOGGER.error("Connection open failed, reopening in 5 seconds: %s", err) + assert self._connection is not None + self._connection.ioloop.call_later(5, self._connection.ioloop.stop) + + def on_connection_closed(self, _unused_connection: pika.SelectConnection, reason): + """This method is invoked by pika when the connection to RabbitMQ is + closed unexpectedly. Since it is unexpected, we will reconnect to + RabbitMQ if it disconnects. + + :param pika.connection.Connection connection: The closed connection obj + :param Exception reason: exception representing reason for loss of + connection. 
+ + """ + assert self._connection is not None + self._channel = None + if self._closing: + self._connection.ioloop.stop() + else: + LOGGER.warning("Connection closed, reopening in 5 seconds: %s", reason) + self._connection.ioloop.call_later(5, self._connection.ioloop.stop) + + def open_channel(self) -> None: + """Open a new channel with RabbitMQ by issuing the Channel.Open RPC + command. When RabbitMQ responds that the channel is open, the + on_channel_open callback will be invoked by pika. + + """ + LOGGER.info("Creating a new channel") + assert self._connection is not None + self._connection.channel(on_open_callback=self.on_channel_open) + + def on_channel_open(self, channel: pika.channel.Channel) -> None: + """This method is invoked by pika when the channel has been opened. + The channel object is passed in so we can make use of it. + + Since the channel is now open, we'll declare the exchange to use. + + :param pika.channel.Channel channel: The channel object + + """ + LOGGER.info("Channel opened") + self._channel = channel + self.add_on_channel_close_callback() + self.setup_exchange() + + def add_on_channel_close_callback(self) -> None: + """This method tells pika to call the on_channel_closed method if + RabbitMQ unexpectedly closes the channel. + + """ + LOGGER.info("Adding channel close callback") + assert self._channel is not None + self._channel.add_on_close_callback(callback=self.on_channel_closed) + + def on_channel_closed( + self, channel: pika.channel.Channel, reason: Exception + ) -> None: + """Invoked by pika when RabbitMQ unexpectedly closes the channel. + Channels are usually closed if you attempt to do something that + violates the protocol, such as re-declare an exchange or queue with + different parameters. In this case, we'll close the connection + to shutdown the object. 
+ + :param pika.channel.Channel: The closed channel + :param Exception reason: why the channel was closed + + """ + LOGGER.warning("Channel %i was closed: %s", channel, reason) + self.close_connection() + + def setup_exchange(self) -> None: + """Setup the exchange on RabbitMQ by invoking the Exchange.Declare RPC + command. When it is complete, the on_exchange_declareok method will + be invoked by pika. + + """ + LOGGER.info("Declaring exchange %s", self._routing_key) + assert self._channel is not None + self._channel.exchange_declare( + exchange=self._exchange, + exchange_type=self.EXCHANGE_TYPE, + callback=self.on_exchange_declareok, + ) + + def on_exchange_declareok(self, _unused_frame: pika.frame.Method) -> None: + """Invoked by pika when RabbitMQ has finished the Exchange.Declare RPC + command. + + :param pika.frame.Method unused_frame: Exchange.DeclareOk response frame + + """ + LOGGER.info("Exchange declared: %s", self._exchange) + self.setup_queue() + + def setup_queue(self) -> None: + """Setup the queue on RabbitMQ by invoking the Queue.Declare RPC + command. When it is complete, the on_queue_declareok method will + be invoked by pika. + + """ + LOGGER.info("Declaring queue %s", self._queue) + assert self._channel is not None + self._channel.queue_declare(queue=self._queue, callback=self.on_queue_declareok) + + def on_queue_declareok(self, _unused_frame: pika.frame.Method) -> None: + """Method invoked by pika when the Queue.Declare RPC call made in + setup_queue has completed. In this method we will bind the queue + and exchange together with the routing key by issuing the Queue.Bind + RPC command. When this command is complete, the on_bindok method will + be invoked by pika. 
+ + :param pika.frame.Method method_frame: The Queue.DeclareOk frame + + """ + LOGGER.info( + "Binding queue %s to exchange %s with routing key %s", + self._queue, + self._exchange, + self._routing_key, + ) + assert self._channel is not None + self._channel.queue_bind( + queue=self._queue, + exchange=self._exchange, + routing_key=self._routing_key, + callback=self.on_bindok, + ) + + def on_bindok(self, _unused_frame: pika.frame.Method) -> None: + """Invoked by pika when the Queue.Bind method has completed. At this + point we will set the prefetch count for the channel. + + :param pika.frame.Method _unused_frame: The Queue.BindOk response frame + :param str|unicode queue_name: The name of the queue to declare + + """ + LOGGER.info("Queue bound: %s", self._queue) + self.set_qos() + + def set_qos(self) -> None: + """This method sets up the consumer prefetch to only be delivered + one message at a time. The consumer must acknowledge this message + before RabbitMQ will deliver another one. You should experiment + with different prefetch values to achieve desired performance. + + """ + assert self._channel is not None + self._channel.basic_qos( + prefetch_count=self._prefetch_count, callback=self.on_basic_qos_ok + ) + + def on_basic_qos_ok(self, _unused_frame: pika.frame.Method) -> None: + """Invoked by pika when the Basic.QoS method has completed. At this + point we will start consuming messages by calling start_consuming + which will invoke the needed RPC commands to start the process. + + :param pika.frame.Method _unused_frame: The Basic.QosOk response frame + + """ + LOGGER.info("QOS set to: %d", self._prefetch_count) + self.start_consuming() + + def start_consuming(self) -> None: + """This method sets up the consumer by first calling + add_on_cancel_callback so that the object is notified if RabbitMQ + cancels the consumer. It then issues the Basic.Consume RPC command + which returns the consumer tag that is used to uniquely identify the + consumer with RabbitMQ. 
We keep the value to use it when we want to + cancel consuming. The on_request method is passed in as a callback pika + will invoke when a message is fully received. + + """ + LOGGER.info("Issuing consumer related RPC commands") + assert self._channel is not None + self.add_on_cancel_callback() + self._consumer_tag = self._channel.basic_consume( + queue=self._queue, on_message_callback=self.on_request + ) + self._consuming = True + + def add_on_cancel_callback(self) -> None: + """Add a callback that will be invoked if RabbitMQ cancels the consumer + for some reason. If RabbitMQ does cancel the consumer, + on_consumer_cancelled will be invoked by pika. + + """ + LOGGER.info("Adding consumer cancellation callback") + assert self._channel is not None + self._channel.add_on_cancel_callback(callback=self.on_consumer_cancelled) + + def on_consumer_cancelled(self, method_frame: pika.frame.Method) -> None: + """Invoked by pika when RabbitMQ sends a Basic.Cancel for a consumer + receiving messages. + + :param pika.frame.Method method_frame: The Basic.Cancel frame + + """ + LOGGER.info("Consumer was cancelled remotely, shutting down: %r", method_frame) + if self._channel: + self._channel.close() + + def on_request( + self, + _unused_channel: pika.channel.Channel, + basic_deliver: pika.spec.Basic.Deliver, + properties: pika.spec.BasicProperties, + body: bytes, + ) -> None: + """Invoked by pika when a message is delivered from RabbitMQ. The + channel is passed for your convenience. The basic_deliver object that + is passed in carries the exchange, routing key, delivery tag and + a redelivered flag for the message. The properties passed in is an + instance of BasicProperties with the message properties and the body + is the message that was sent. 
+ + :param pika.channel.Channel _unused_channel: The channel object + :param pika.spec.Basic.Deliver: basic_deliver method + :param pika.spec.BasicProperties: properties + :param bytes body: The message body + + """ + LOGGER.info( + "Received message # %s from %s: %s", + basic_deliver.delivery_tag, + properties.app_id, + body, + ) + item = decode_data(data=body, extra_decoders=self.extra_type_decoders)["data"] + self._request_queue.put( + (item, (properties.correlation_id, properties.reply_to)) + ) + self.acknowledge_message(delivery_tag=basic_deliver.delivery_tag) + + def acknowledge_message(self, delivery_tag: int) -> None: + """Acknowledge the message delivery from RabbitMQ by sending a + Basic.Ack RPC method for the delivery tag. + + :param int delivery_tag: The delivery tag from the Basic.Deliver frame + + """ + LOGGER.info("Acknowledging message %s", delivery_tag) + assert self._channel is not None + self._channel.basic_ack(delivery_tag=delivery_tag) + + def stop_consuming(self) -> None: + """Tell RabbitMQ that you would like to stop consuming by sending the + Basic.Cancel RPC command. + + """ + if self._channel: + LOGGER.info("Sending a Basic.Cancel RPC command to RabbitMQ") + self._channel.basic_cancel(self._consumer_tag, self.on_cancelok) + + def on_cancelok(self, _unused_frame: pika.frame.Method) -> None: + """This method is invoked by pika when RabbitMQ acknowledges the + cancellation of a consumer. At this point we will close the channel. + This will invoke the on_channel_closed method once the channel has been + closed, which will in-turn close the connection. 
+ + :param pika.frame.Method _unused_frame: The Basic.CancelOk frame + :param str|unicode consumer_tag: Tag of the consumer to be stopped + + """ + self._consuming = False + LOGGER.info( + "RabbitMQ acknowledged the cancellation of the consumer: %s", + self._consumer_tag, + ) + self.close_channel() + + def close_channel(self) -> None: + """Call to close the channel with RabbitMQ cleanly by issuing the + Channel.Close RPC command. + + """ + LOGGER.info("Closing the channel") + assert self._channel is not None + self._channel.close() + + def run(self) -> None: + """Run the example code by connecting and then starting the IOLoop.""" + + self._request_queue: queue.Queue = queue.Queue() + self._storage_thread = threading.Thread(target=self.run_storage_thread) + self._storage_thread.start() + + while not self._closing: + try: + self._connection = self.connect() + assert self._connection is not None + self._connection.ioloop.start() + except KeyboardInterrupt: + self.stop() + if self._connection is not None and not self._connection.is_closed: + # Finish closing + self._connection.ioloop.start() + + self._request_queue.put(TERMINATE) + self._storage_thread.join() + LOGGER.info("Stopped") + + def run_storage_thread(self) -> None: + storage = get_provenance_storage(**self._storage_config) + + meth_name, relation = ProvenanceStorageRabbitMQServer.get_meth_name( + self._exchange + ) + resolve_conflicts = ProvenanceStorageRabbitMQWorker.get_conflicts_func( + meth_name + ) + + while True: + terminate = False + elements = [] + while True: + try: + elem = self._request_queue.get(timeout=0.1) + if elem is TERMINATE: + terminate = True + break + elements.append(elem) + except queue.Empty: + break + + if len(elements) >= self._batch_size: + break + + if terminate: + break + + if not elements: + continue + + items, props = zip(*elements) + acks_count: TCounter[Tuple[str, str]] = Counter(props) + data = resolve_conflicts(items) + + args = (relation, data) if relation is not None else 
(data,) + if getattr(storage, meth_name)(*args): + for (correlation_id, reply_to), count in acks_count.items(): + # FIXME: this is running in a different thread! Hence, if + # self._connection drops, there is no guarantee that the response + # can be sent for the current elements. This situation should be + # handled properly. + assert self._connection is not None + self._connection.ioloop.add_callback_threadsafe( + functools.partial( + ProvenanceStorageRabbitMQServer.respond, + channel=self._channel, + correlation_id=correlation_id, + reply_to=reply_to, + response=count, + ) + ) + else: + LOGGER.warning( + "Unable to process elements for queue %s", self._routing_key ) - ) - # links is now a list of url, endpoint tuples - return links + for elem in elements: + self._request_queue.put(elem) + + def stop(self) -> None: + """Cleanly shutdown the connection to RabbitMQ by stopping the consumer + with RabbitMQ. When RabbitMQ confirms the cancellation, on_cancelok + will be invoked by pika, which will then closing the channel and + connection. The IOLoop is started again because this method is invoked + when CTRL-C is pressed raising a KeyboardInterrupt exception. This + exception stops the IOLoop which needs to be running for pika to + communicate with RabbitMQ. All of the commands issued prior to starting + the IOLoop will be buffered but not processed. 
+ + """ + assert self._connection is not None + if not self._closing: + self._closing = True + LOGGER.info("Stopping") + if self._consuming: + self.stop_consuming() + self._connection.ioloop.start() + else: + self._connection.ioloop.stop() + LOGGER.info("Stopped") + + @staticmethod + def get_conflicts_func(meth_name: str) -> Callable[[List[Any]], Any]: + if meth_name.startswith("relation_add"): + # Create RelationData structures from tuples and deduplicate + return lambda data: {RelationData(*row) for row in data} + else: + # Dates should be resolved to the earliest one, + # otherwise last processed value is good enough + return resolve_dates if "_date" in meth_name else lambda data: dict(data) + + +class ProvenanceStorageRabbitMQServer: + backend_class = ProvenanceStorageInterface + extra_type_decoders = DECODERS + extra_type_encoders = ENCODERS + + queue_count = 16 + + def __init__(self, url: str, storage_config: Dict[str, Any]) -> None: + workers: List[ProvenanceStorageRabbitMQWorker] = [] + for meth_name, meth in self.backend_class.__dict__.items(): + if hasattr( + meth, "_endpoint_path" + ) and ProvenanceStorageRabbitMQServer.is_write_method(meth_name): + for exchange in ProvenanceStorageRabbitMQServer.get_exchanges( + meth_name + ): + for ( + routing_key + ) in ProvenanceStorageRabbitMQServer.get_routing_keys(): + worker = ProvenanceStorageRabbitMQWorker( + url, exchange, routing_key, storage_config + ) + worker.start() + workers.append(worker) + + LOGGER.info("Start consuming") + while True: + try: + time.sleep(1.0) + except KeyboardInterrupt: + break + LOGGER.info("Stop consuming") + for worker in workers: + worker.terminate() + worker.join() + + @staticmethod + def ack(channel: pika.channel.Channel, delivery_tag: int) -> None: + channel.basic_ack(delivery_tag=delivery_tag) + + @staticmethod + def get_exchange(meth_name: str, relation: Optional[RelationType] = None) -> str: + if relation is not None: + return f"{meth_name}_{relation.value}".lower() + else: 
+ return meth_name + + @staticmethod + def get_exchanges(meth_name: str) -> Generator[str, None, None]: + if meth_name.startswith("relation_add"): + for relation in RelationType: + yield ProvenanceStorageRabbitMQServer.get_exchange(meth_name, relation) + else: + yield ProvenanceStorageRabbitMQServer.get_exchange(meth_name) + + @staticmethod + def get_meth_name(exchange: str) -> Tuple[str, Optional[RelationType]]: + if exchange.startswith("relation_add"): + *_, relation = exchange.split("_", 2) + return "relation_add", RelationType(relation) + else: + return exchange, None + + @staticmethod + def get_routing_key(data: Any) -> str: + idx = data[0][0] // (256 // ProvenanceStorageRabbitMQServer.queue_count) + return f"{idx:x}".lower() + + @staticmethod + def get_routing_keys() -> Generator[str, None, None]: + for idx in range(ProvenanceStorageRabbitMQServer.queue_count): + yield f"{idx:x}".lower() + + @staticmethod + def is_write_method(meth_name: str) -> bool: + return meth_name.startswith("relation_add") or "_set_" in meth_name + + @staticmethod + def respond( + channel: pika.channel.Channel, + correlation_id: str, + reply_to: str, + response: Any, + ): + channel.basic_publish( + exchange="", + routing_key=reply_to, + properties=pika.BasicProperties( + content_type="application/msgpack", + correlation_id=correlation_id, + ), + body=encode_data( + response, + extra_encoders=ProvenanceStorageRabbitMQServer.extra_type_encoders, + ), + ) def load_and_check_config( @@ -127,22 +693,10 @@ return cfg -api_cfg: Optional[Dict[str, Any]] = None - - -def make_app_from_configfile() -> ProvenanceStorageServerApp: - """Run the WSGI app from the webserver, loading the configuration from - a configuration file. - - SWH_CONFIG_FILENAME environment variable defines the - configuration path to load. 
- - """ - global api_cfg - if api_cfg is None: - config_path = os.environ.get("SWH_CONFIG_FILENAME") - api_cfg = load_and_check_config(config_path) - app.config.update(api_cfg) - handler = logging.StreamHandler() - app.logger.addHandler(handler) - return app +def make_server_from_configfile() -> None: + config_path = os.environ.get("SWH_CONFIG_FILENAME") + server_cfg = load_and_check_config(config_path) + ProvenanceStorageRabbitMQServer( + url=server_cfg["provenance"]["rabbitmq"]["url"], + storage_config=server_cfg["provenance"]["storage"], + ) diff --git a/swh/provenance/cli.py b/swh/provenance/cli.py --- a/swh/provenance/cli.py +++ b/swh/provenance/cli.py @@ -42,21 +42,30 @@ }, "storage": { # Local PostgreSQL Storage - "cls": "postgresql", - "db": { - "host": "localhost", - "user": "postgres", - "password": "postgres", - "dbname": "provenance", - }, + # "cls": "postgresql", + # "db": { + # "host": "localhost", + # "user": "postgres", + # "password": "postgres", + # "dbname": "provenance", + # }, # Local MongoDB Storage # "cls": "mongodb", # "db": { # "dbname": "provenance", # }, - # Remote REST-API/PostgreSQL - # "cls": "remote", - # "url": "http://localhost:8080/%2f", + # Remote RabbitMQ/PostgreSQL Storage + "cls": "rabbitmq", + "url": "amqp://localhost:5672/%2f", + "storage_config": { + "cls": "postgresql", + "db": { + "host": "localhost", + "user": "postgres", + "password": "postgres", + "dbname": "dummy", + }, + }, }, } } diff --git a/swh/provenance/graph.py b/swh/provenance/graph.py --- a/swh/provenance/graph.py +++ b/swh/provenance/graph.py @@ -6,7 +6,6 @@ from __future__ import annotations from datetime import datetime, timezone -import logging import os from typing import Any, Dict, Optional, Set @@ -187,9 +186,6 @@ root_date = provenance.directory_get_date_in_isochrone_frontier(directory) root = IsochroneNode(directory, dbdate=root_date) stack = [root] - logging.debug( - f"Recursively creating isochrone graph for revision {revision.id.hex()}..." 
- ) fdates: Dict[Sha1Git, datetime] = {} # map {file_id: date} while stack: current = stack.pop() @@ -198,12 +194,6 @@ # is greater or equal to the current revision's one, it should be ignored as # the revision is being processed out of order. if current.dbdate is not None and current.dbdate > revision.date: - logging.debug( - f"Invalidating frontier on {current.entry.id.hex()}" - f" (date {current.dbdate})" - f" when processing revision {revision.id.hex()}" - f" (date {revision.date})" - ) current.invalidate() # Pre-query all known dates for directories in the current directory @@ -220,12 +210,8 @@ fdates.update(provenance.content_get_early_dates(current.entry.files)) - logging.debug( - f"Isochrone graph for revision {revision.id.hex()} successfully created!" - ) # Precalculate max known date for each node in the graph (only directory nodes are # pushed to the stack). - logging.debug(f"Computing maxdates for revision {revision.id.hex()}...") stack = [root] while stack: @@ -276,5 +262,4 @@ # node should be treated as unknown current.maxdate = revision.date current.known = False - logging.debug(f"Maxdates for revision {revision.id.hex()} successfully computed!") return root diff --git a/swh/provenance/origin.py b/swh/provenance/origin.py --- a/swh/provenance/origin.py +++ b/swh/provenance/origin.py @@ -4,8 +4,6 @@ # See top-level LICENSE file for more information from itertools import islice -import logging -import time from typing import Generator, Iterable, Iterator, List, Optional, Tuple from swh.model.model import Sha1Git @@ -49,21 +47,13 @@ archive: ArchiveInterface, origins: List[OriginEntry], ) -> None: - start = time.time() for origin in origins: provenance.origin_add(origin) origin.retrieve_revisions(archive) for revision in origin.revisions: graph = HistoryGraph(archive, provenance, revision) origin_add_revision(provenance, origin, graph) - done = time.time() provenance.flush() - stop = time.time() - logging.debug( - "Origins " - ";".join([origin.id.hex() 
+ ":" + origin.snapshot.hex() for origin in origins]) - + f" were processed in {stop - start} secs (commit took {stop - done} secs)!" - ) def origin_add_revision( diff --git a/swh/provenance/postgresql/provenance.py b/swh/provenance/postgresql/provenance.py --- a/swh/provenance/postgresql/provenance.py +++ b/swh/provenance/postgresql/provenance.py @@ -23,6 +23,8 @@ RevisionData, ) +LOGGER = logging.getLogger(__name__) + class ProvenanceStoragePostgreSql: def __init__( @@ -91,7 +93,6 @@ try: if urls: sql = """ - LOCK TABLE ONLY origin; INSERT INTO origin(sha1, url) VALUES %s ON CONFLICT DO NOTHING """ @@ -99,7 +100,7 @@ return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message - logging.exception("Unexpected error") + LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False @@ -126,7 +127,6 @@ try: if origins: sql = """ - LOCK TABLE ONLY revision; INSERT INTO revision(sha1, origin) (SELECT V.rev AS sha1, O.id AS origin FROM (VALUES %s) AS V(rev, org) @@ -138,7 +138,7 @@ return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message - logging.exception("Unexpected error") + LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False @@ -176,7 +176,6 @@ # non-null information srcs = tuple(set((sha1,) for (sha1, _, _) in rows)) sql = f""" - LOCK TABLE ONLY {src_table}; INSERT INTO {src_table}(sha1) VALUES %s ON CONFLICT DO NOTHING """ @@ -187,7 +186,6 @@ # non-null information dsts = tuple(set((sha1,) for (_, sha1, _) in rows)) sql = f""" - LOCK TABLE ONLY {dst_table}; INSERT INTO {dst_table}(sha1) VALUES %s ON CONFLICT DO NOTHING """ @@ -211,7 +209,7 @@ return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message - logging.exception("Unexpected error") + LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False @@ -252,7 +250,6 @@ try: if data: sql = f""" - LOCK TABLE ONLY {entity}; 
INSERT INTO {entity}(sha1, date) VALUES %s ON CONFLICT (sha1) DO UPDATE SET date=LEAST(EXCLUDED.date,{entity}.date) @@ -261,7 +258,7 @@ return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message - logging.exception("Unexpected error") + LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py --- a/swh/provenance/provenance.py +++ b/swh/provenance/provenance.py @@ -20,6 +20,8 @@ ) from .model import DirectoryEntry, FileEntry, OriginEntry, RevisionEntry +LOGGER = logging.getLogger(__name__) + class DatetimeCache(TypedDict): data: Dict[Sha1Git, Optional[datetime]] @@ -79,41 +81,44 @@ # For this layer, relations need to be inserted first so that, in case of # failure, reprocessing the input does not generated an inconsistent database. - while not self.storage.relation_add( - RelationType.CNT_EARLY_IN_REV, - ( - RelationData(src=src, dst=dst, path=path) - for src, dst, path in self.cache["content_in_revision"] - ), - ): - logging.warning( - f"Unable to write {RelationType.CNT_EARLY_IN_REV} rows to the storage. " - f"Data: {self.cache['content_in_revision']}. Retrying..." - ) - - while not self.storage.relation_add( - RelationType.CNT_IN_DIR, - ( - RelationData(src=src, dst=dst, path=path) - for src, dst, path in self.cache["content_in_directory"] - ), - ): - logging.warning( - f"Unable to write {RelationType.CNT_IN_DIR} rows to the storage. " - f"Data: {self.cache['content_in_directory']}. Retrying..." - ) - - while not self.storage.relation_add( - RelationType.DIR_IN_REV, - ( - RelationData(src=src, dst=dst, path=path) - for src, dst, path in self.cache["directory_in_revision"] - ), - ): - logging.warning( - f"Unable to write {RelationType.DIR_IN_REV} rows to the storage. " - f"Data: {self.cache['directory_in_revision']}. Retrying..." 
- ) + if self.cache["content_in_revision"]: + while not self.storage.relation_add( + RelationType.CNT_EARLY_IN_REV, + ( + RelationData(src=src, dst=dst, path=path) + for src, dst, path in self.cache["content_in_revision"] + ), + ): + LOGGER.warning( + "Unable to write %s rows to the storage. Retrying...", + RelationType.CNT_EARLY_IN_REV, + ) + + if self.cache["content_in_directory"]: + while not self.storage.relation_add( + RelationType.CNT_IN_DIR, + ( + RelationData(src=src, dst=dst, path=path) + for src, dst, path in self.cache["content_in_directory"] + ), + ): + LOGGER.warning( + "Unable to write %s rows to the storage. Retrying...", + RelationType.CNT_IN_DIR, + ) + + if self.cache["directory_in_revision"]: + while not self.storage.relation_add( + RelationType.DIR_IN_REV, + ( + RelationData(src=src, dst=dst, path=path) + for src, dst, path in self.cache["directory_in_revision"] + ), + ): + LOGGER.warning( + "Unable to write %s rows to the storage. Retrying...", + RelationType.DIR_IN_REV, + ) # After relations, dates for the entities can be safely set, acknowledging that # these entities won't need to be reprocessed in case of failure. @@ -122,33 +127,33 @@ for sha1, date in self.cache["content"]["data"].items() if sha1 in self.cache["content"]["added"] and date is not None } - while not self.storage.content_set_date(dates): - logging.warning( - f"Unable to write content dates to the storage. " - f"Data: {dates}. Retrying..." - ) + if dates: + while not self.storage.content_set_date(dates): + LOGGER.warning( + "Unable to write content dates to the storage. Retrying..." + ) dates = { sha1: date for sha1, date in self.cache["directory"]["data"].items() if sha1 in self.cache["directory"]["added"] and date is not None } - while not self.storage.directory_set_date(dates): - logging.warning( - f"Unable to write directory dates to the storage. " - f"Data: {dates}. Retrying..." 
- ) + if dates: + while not self.storage.directory_set_date(dates): + LOGGER.warning( + "Unable to write directory dates to the storage. Retrying..." + ) dates = { sha1: date for sha1, date in self.cache["revision"]["data"].items() if sha1 in self.cache["revision"]["added"] and date is not None } - while not self.storage.revision_set_date(dates): - logging.warning( - f"Unable to write revision dates to the storage. " - f"Data: {dates}. Retrying..." - ) + if dates: + while not self.storage.revision_set_date(dates): + LOGGER.warning( + "Unable to write revision dates to the storage. Retrying..." + ) # Origin-revision layer insertions ############################################# @@ -159,11 +164,11 @@ for sha1, url in self.cache["origin"]["data"].items() if sha1 in self.cache["origin"]["added"] } - while not self.storage.origin_set_url(urls): - logging.warning( - f"Unable to write origins urls to the storage. " - f"Data: {urls}. Retrying..." - ) + if urls: + while not self.storage.origin_set_url(urls): + LOGGER.warning( + "Unable to write origins urls to the storage. Retrying..." + ) # Second, flat models for revisions' histories (ie. revision-before-revision). data: Iterable[RelationData] = sum( @@ -176,11 +181,12 @@ ], [], ) - while not self.storage.relation_add(RelationType.REV_BEFORE_REV, data): - logging.warning( - f"Unable to write {RelationType.REV_BEFORE_REV} rows to the storage. " - f"Data: {data}. Retrying..." - ) + if data: + while not self.storage.relation_add(RelationType.REV_BEFORE_REV, data): + LOGGER.warning( + "Unable to write %s rows to the storage. Retrying...", + RelationType.REV_BEFORE_REV, + ) # Heads (ie. revision-in-origin entries) should be inserted once flat models for # their histories were already added. 
This is to guarantee consistent results if @@ -190,11 +196,12 @@ RelationData(src=rev, dst=org, path=None) for rev, org in self.cache["revision_in_origin"] ) - while not self.storage.relation_add(RelationType.REV_IN_ORG, data): - logging.warning( - f"Unable to write {RelationType.REV_IN_ORG} rows to the storage. " - f"Data: {data}. Retrying..." - ) + if data: + while not self.storage.relation_add(RelationType.REV_IN_ORG, data): + LOGGER.warning( + "Unable to write %s rows to the storage. Retrying...", + RelationType.REV_IN_ORG, + ) # Finally, preferred origins for the visited revisions are set (this step can be # reordered if required). @@ -202,11 +209,11 @@ sha1: self.cache["revision_origin"]["data"][sha1] for sha1 in self.cache["revision_origin"]["added"] } - while not self.storage.revision_set_origin(origins): - logging.warning( - f"Unable to write preferred origins to the storage. " - f"Data: {origins}. Retrying..." - ) + if origins: + while not self.storage.revision_set_origin(origins): + LOGGER.warning( + "Unable to write preferred origins to the storage. Retrying..." + ) # clear local cache ############################################################ self.clear_caches() diff --git a/swh/provenance/revision.py b/swh/provenance/revision.py --- a/swh/provenance/revision.py +++ b/swh/provenance/revision.py @@ -4,9 +4,7 @@ # See top-level LICENSE file for more information from datetime import datetime, timezone -import logging import os -import time from typing import Generator, Iterable, Iterator, List, Optional, Tuple from swh.model.model import Sha1Git @@ -59,17 +57,12 @@ mindepth: int = 1, commit: bool = True, ) -> None: - start = time.time() for revision in revisions: assert revision.date is not None assert revision.root is not None # Processed content starting from the revision's root directory. 
date = provenance.revision_get_date(revision) if date is None or revision.date < date: - logging.debug( - f"Processing revisions {revision.id.hex()}" - f" (known date {date} / revision date {revision.date})..." - ) graph = build_isochrone_graph( archive, provenance, @@ -86,14 +79,8 @@ lower=lower, mindepth=mindepth, ) - done = time.time() if commit: provenance.flush() - stop = time.time() - logging.debug( - f"Revisions {';'.join([revision.id.hex() for revision in revisions])} " - f" were processed in {stop - start} secs (commit took {stop - done} secs)!" - ) def revision_process_content( diff --git a/swh/provenance/sql/40-funcs.sql b/swh/provenance/sql/40-funcs.sql --- a/swh/provenance/sql/40-funcs.sql +++ b/swh/provenance/sql/40-funcs.sql @@ -99,7 +99,6 @@ join_location text; begin if src_table in ('content'::regclass, 'directory'::regclass) then - lock table only location; insert into location(path) select V.path from tmp_relation_add as V @@ -113,15 +112,14 @@ end if; execute format( - 'lock table only %s; - insert into %s + 'insert into %s select S.id, ' || select_fields || ' from tmp_relation_add as V inner join %s as S on (S.sha1 = V.src) inner join %s as D on (D.sha1 = V.dst) ' || join_location || ' on conflict do nothing', - rel_table, rel_table, src_table, dst_table + rel_table, src_table, dst_table ); end; $$; @@ -254,14 +252,13 @@ as $$ begin execute format( - 'lock table only %s; - insert into %s + 'insert into %s select S.id, D.id from tmp_relation_add as V inner join %s as S on (S.sha1 = V.src) inner join %s as D on (D.sha1 = V.dst) on conflict do nothing', - rel_table, rel_table, src_table, dst_table + rel_table, src_table, dst_table ); end; $$; @@ -422,7 +419,6 @@ on_conflict text; begin if src_table in ('content'::regclass, 'directory'::regclass) then - lock table only location; insert into location(path) select V.path from tmp_relation_add as V @@ -448,8 +444,7 @@ end if; execute format( - 'lock table only %s; - insert into %s + 'insert into %s 
select S.id, ' || select_fields || ' from tmp_relation_add as V inner join %s as S on (S.sha1 = V.src) @@ -457,7 +452,7 @@ ' || join_location || ' ' || group_entries || ' on conflict ' || on_conflict, - rel_table, rel_table, src_table, dst_table + rel_table, src_table, dst_table ); end; $$; @@ -641,15 +636,14 @@ end if; execute format( - 'lock table only %s; - insert into %s + 'insert into %s select S.id, ' || select_fields || ' from tmp_relation_add as V inner join %s as S on (S.sha1 = V.src) inner join %s as D on (D.sha1 = V.dst) ' || group_entries || ' on conflict ' || on_conflict, - rel_table, rel_table, src_table, dst_table + rel_table, src_table, dst_table ); end; $$; diff --git a/swh/provenance/tests/conftest.py b/swh/provenance/tests/conftest.py --- a/swh/provenance/tests/conftest.py +++ b/swh/provenance/tests/conftest.py @@ -5,7 +5,7 @@ from datetime import datetime, timedelta, timezone from os import path -from typing import Any, Dict, Iterable, Iterator +from typing import Any, Dict, Iterable from _pytest.fixtures import SubRequest import msgpack @@ -16,8 +16,6 @@ from swh.journal.serializers import msgpack_ext_hook from swh.provenance import get_provenance, get_provenance_storage -from swh.provenance.api.client import RemoteProvenanceStorage -import swh.provenance.api.server as server from swh.provenance.archive import ArchiveInterface from swh.provenance.interface import ProvenanceInterface, ProvenanceStorageInterface from swh.provenance.storage.archive import ArchiveStorage @@ -46,38 +44,15 @@ return postgresql.get_dsn_parameters() -# the Flask app used as server in these tests -@pytest.fixture -def app( - provenance_postgresqldb: Dict[str, str] -) -> Iterator[server.ProvenanceStorageServerApp]: - assert hasattr(server, "storage") - server.storage = get_provenance_storage( - cls="postgresql", db=provenance_postgresqldb - ) - yield server.app - - -# the RPCClient class used as client used in these tests -@pytest.fixture -def swh_rpc_client_class() -> 
type: - return RemoteProvenanceStorage - - -@pytest.fixture(params=["mongodb", "postgresql", "remote"]) +@pytest.fixture(params=["mongodb", "postgresql"]) def provenance_storage( request: SubRequest, provenance_postgresqldb: Dict[str, str], mongodb: pymongo.database.Database, - swh_rpc_client: RemoteProvenanceStorage, ) -> ProvenanceStorageInterface: """Return a working and initialized ProvenanceStorageInterface object""" - if request.param == "remote": - assert isinstance(swh_rpc_client, ProvenanceStorageInterface) - return swh_rpc_client - - elif request.param == "mongodb": + if request.param == "mongodb": from swh.provenance.mongo.backend import ProvenanceStorageMongoDb return ProvenanceStorageMongoDb(mongodb)