Skip to content
Snippets Groups Projects
exchange.py 21.38 KiB
"""

Copied from dashboard-v3-python v0.226 (module ``dashboard.messaging.exchange``).
Type hints were updated to satisfy mypy, and the code was linted with the changes
suggested by ruff.

"""
import json
import logging
import threading
import time
from collections.abc import Callable, Generator, Iterator, Sequence
from json import JSONDecodeError
from typing import Any

import jsonschema
import pika
import pika.channel
from pika.adapters.blocking_connection import BlockingChannel
from pika.exceptions import AMQPError
from pika.exchange_type import ExchangeType

from .queue import setup_channel

logger = logging.getLogger(__name__)


def now() -> float:
    """Return a reading of the monotonic clock, in seconds.

    Only the difference between two readings is meaningful; this module uses it
    for idle detection and for timing the pause between reconnects.
    """
    reading = time.monotonic()
    return reading


def default_rmq_connection_params(
    hostname: str | Sequence[str],
    username: str,
    password: str,
    vhost: str,
    port: int = 5672,
) -> Sequence[pika.ConnectionParameters]:
    """
    Build one ``pika.ConnectionParameters`` per host, for connecting to the cluster.

    :param hostname: a single hostname, or a sequence of hostnames
    :param username: username, shared by every host
    :param password: password, shared by every host
    :param vhost: virtual host to connect to
    :param port: port, shared by every host (default value = 5672)
    :return: a list holding one ``pika.ConnectionParameters`` per hostname, in the
        order the hostnames were given
    """
    if isinstance(hostname, str):
        hosts = [hostname]
    else:
        hosts = hostname
    logger.debug(f"rmq hostnames: {hosts}")

    # a fresh PlainCredentials object is created for every host
    return [
        pika.ConnectionParameters(
            credentials=pika.PlainCredentials(username, password),
            host=host,
            port=port,
            virtual_host=vhost,
        )
        for host in hosts
    ]


class RMQResponse:
    def __init__(
        self,
        body: bytes,
        client: "RabbitMQClient",
        delivery_tag: int = 0,
        reply_to: str | None = None
    ):
        self.body = body
        self.client = client
        self.delivery_tag = delivery_tag
        self.reply_to = reply_to

    def ack(self) -> None:
        # TODO: Handle AMQP error when acking fails
        self.client.maybe_ack(self.delivery_tag)

    def json(self, schema: dict[str, Any] | None = None) -> Any:
        """try and load the response body as a json object.
        :param schema: Optional. A jsonschema dictionary to validate the json response
        if the json parsing or validation fails, this method will ensure that the
        message is still acked
        """
        try:
            result = json.loads(self.body)
            if schema is not None:
                jsonschema.validate(result, schema)
        except (JSONDecodeError, jsonschema.ValidationError):
            self.ack()  # Maybe we should actually nack the message here
            raise
        return result

    def reply(self, body: str | bytes) -> None:
        self.client.send_message(
            body=body,
            exchange_name="",
            exchange_type=ExchangeType.direct,
            routing_key=self.reply_to,
        )


class RabbitMQClient:
    """
    A RabbitMQClient connects to RabbitMQ and persistently maintains that connection:
    it reconnects on connection failures. It has two modes of operation for consuming
    a queue. The first is explicitly requesting messages through the ``get_message``
    method:

    .. code-block:: python

        consumer = RabbitMQClient(
            cp, 'exchange', ExchangeType.fanout, queue_name='queue'
        )
        while True:
            response = consumer.get_message()
            ... # handle message
            response.ack()

    The second is its ``consume_forever`` mode:

    .. code-block:: python

        def mycallback(message: dict):
            pass

        consumer = RabbitMQClient(
            cp, 'exchange', ExchangeType.fanout, queue_name='queue'
        )
        consumer.consume_forever(mycallback, json=True)

    The client is not thread safe and only a single message may be handled at a time.

    For sending messages (producing), use the ``send_message()`` method. It retries
    publishing, reconnecting to RabbitMQ after each ``AMQPError``, until the message
    has been published or ``stop_event`` is set.
    """

    STOP_CHECK_FREQUENCY_S = 5
    PAUSE_BEFORE_RECONNECT_S = 5
    DEFAULT_PREFETCH_COUNT = 50
    connection: pika.BlockingConnection | None = None
    channel: BlockingChannel | None = None
    consumer: Iterator[Any] | None = None
    # monotonic timestamp (see ``now``) of the last received message; used for idle
    # detection
    last_message_ts: float | None = None
    # monotonic timestamp of the last ``close()``; None until the first close
    _close_time: float | None = None

    def __init__(
        self,
        connection_params: (
            dict[str, Any]
            | pika.ConnectionParameters
            | Sequence[pika.ConnectionParameters]
        ),
        exchange_name: str,
        exchange_type: ExchangeType,
        queue_name: str | None = None,
        routing_keys: str | Sequence[str] | None = None,
        exclusive: bool = False,
        single_active_consumer: bool = False,
        quorum_queue: bool = False,
        prefetch_count: int = DEFAULT_PREFETCH_COUNT,
        auto_ack: bool = False,
        reconnect_on_idle_timeout: float | None = None,
        reconnect_callback: Callable[[], None] | None = None,
        error_callback: Callable[[Exception], None] | None = None,
        stop_event: threading.Event | None = None,
        pause_before_reconnect: float | None = PAUSE_BEFORE_RECONNECT_S,
    ):
        """
        :param connection_params: rmq connection parameters. Either a single
            ``pika.ConnectionParameters``, a sequence of them, or a dict with
            ``hostname`` or ``hostnames``, ``username``, ``password``, ``vhost`` and
            an optional ``port`` (see ``default_rmq_connection_params``)
        :param exchange_name: rmq exchange name
        :param exchange_type: rmq exchange type, e.g. ``ExchangeType.direct`` or
            ``ExchangeType.fanout``
        :param queue_name: rmq queue name. Default: None
        :param routing_keys: Either a single string or a sequence of strings that will
            be used as routing key(s) for 'direct' exchanges. Default: None
        :param exclusive: is this an exclusive queue? Exclusive queues are destroyed on
            disconnect. Default: False
        :param single_active_consumer: The queue can have at most 1 active consumer.
            If multiple consumers connect to the same queue, they do not receive any
            messages until the current active consumer disconnects for some reason
            and they become the active consumer. Default: False.
        :param quorum_queue: passed on to ``setup_channel`` as
            ``force_quorum_queue``. Default: False
        :param prefetch_count: Limit the number of messages to prefetch from the
            broker. This is highly recommended. Default: ``DEFAULT_PREFETCH_COUNT``
        :param auto_ack: automatically acknowledge messages. Default: False.
        :param reconnect_on_idle_timeout: Maximum idle time in seconds between
            subsequent messages before triggering a reconnect. Default: None (never)
        :param reconnect_callback: Optional callback that will be called when
            reconnecting. Default: None
        :param error_callback: Optional callback that will be called when an AMQPError
            is raised or an invalid json message is received. Default: None
        :param stop_event: A `threading.Event` that will be checked periodically to
            determine whether the client should stop. Default: None
        :param pause_before_reconnect: Sleep a number of seconds between disconnecting
            and reconnecting. Supply a value that evaluates to False (0, None, False)
            to disable sleeping. Default: ``PAUSE_BEFORE_RECONNECT_S``.
        :raises ValueError: if ``connection_params`` is a dict that contains neither
            ``hostname`` nor ``hostnames``
        :raises TypeError: if ``routing_keys`` is neither a str nor a sequence
        """

        self.connection_params = self._ensure_connection_params(connection_params)
        self.exchange_name = exchange_name
        self.queue_name = queue_name
        self.exchange_type = exchange_type

        self.routing_keys: Sequence[str] = []
        if isinstance(routing_keys, str):
            self.routing_keys = [routing_keys]
        elif isinstance(routing_keys, Sequence):
            self.routing_keys = routing_keys
        elif routing_keys is not None:
            raise TypeError(
                "routing_keys must be a str or a sequence of str, got "
                f"{type(routing_keys).__name__}"
            )

        self.exclusive = exclusive
        self.single_active_consumer = single_active_consumer
        self.quorum_queue = quorum_queue
        self.prefetch_count = prefetch_count
        self.auto_ack = auto_ack
        self.reconnect_on_idle_timeout = reconnect_on_idle_timeout
        self.reconnect_callback = reconnect_callback
        self.error_callback = error_callback
        self.stop_event = stop_event
        self.pause_before_reconnect = pause_before_reconnect

    @staticmethod
    def _ensure_connection_params(
        connection_params: dict[str, Any] | pika.ConnectionParameters | Sequence[pika.ConnectionParameters]
    ) -> pika.ConnectionParameters | Sequence[pika.ConnectionParameters]:
        """
        Return ``connection_params`` unchanged unless it is a dict. A dict is
        converted with ``default_rmq_connection_params``; its optional ``port`` key is
        honoured, otherwise that function's default port is used.

        :raises ValueError: if the dict contains neither ``hostname`` nor ``hostnames``
        :raises KeyError: if the dict lacks ``username``, ``password`` or ``vhost``
        """
        if not isinstance(connection_params, dict):
            return connection_params

        if "hostname" in connection_params:
            hostname = connection_params["hostname"]
        elif "hostnames" in connection_params:
            hostname = connection_params["hostnames"]
        else:
            raise ValueError(
                "Connection parameters must contain either 'hostname' or 'hostnames'"
            )
        # only forward port when given, so the helper's own default stays the source
        # of truth
        optional: dict[str, Any] = (
            {"port": connection_params["port"]} if "port" in connection_params else {}
        )
        return default_rmq_connection_params(
            hostname,
            username=connection_params["username"],
            password=connection_params["password"],
            vhost=connection_params["vhost"],
            **optional,
        )

    def connection_str(self) -> str:
        """Human readable description of the exchange/queue/routing keys, for logs."""
        result = self.exchange_name
        if self.queue_name:
            result += f"/{self.queue_name}"
        if self.routing_keys and self.routing_keys[0]:
            result += f" with routing key(s): {','.join(self.routing_keys)}"
        return result

    def __enter__(self) -> "RabbitMQClient":
        return self

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        self.close()

    def __del__(self) -> None:
        self.close()

    def _connect(self, as_consumer: bool = True) -> tuple[pika.BlockingConnection, BlockingChannel, str | None]:
        """
        Open a new connection and set up the channel through ``setup_channel``.

        If the client was closed before, the disconnect is logged,
        ``reconnect_callback`` is invoked and the reconnect pause is applied first.

        :param as_consumer: pass ``queue_name`` to ``setup_channel``; publishers pass
            False so that no queue name is given
        :return: ``(connection, channel, queue)``, where ``channel`` and ``queue`` come
            from ``setup_channel``
        :raises RuntimeError: if a connection object is already held
        """
        if self.connection is not None:
            raise RuntimeError("Already connected")

        if self._close_time is not None:
            logger.warning(f"Detected rmq disconnect from {self.connection_str()}")
            if self.reconnect_callback is not None:
                self.reconnect_callback()

        self._pause_before_reconnect()
        logger.info(f"Connecting to rmq {self.connection_str()}")

        connection = pika.BlockingConnection(self.connection_params)
        channel, queue = setup_channel(
            connection=connection,
            exchange_name=self.exchange_name,
            exchange_type=self.exchange_type,
            queue_name=self.queue_name if as_consumer else None,
            exclusive=self.exclusive,
            single_active_consumer=self.single_active_consumer,
            routing_keys=self.routing_keys,
            prefetch_count=self.prefetch_count,
            force_quorum_queue=self.quorum_queue,
        )

        return connection, channel, queue

    def connect_consumer(self, timeout: float | None = None) -> None:
        """
        Create a channel and bind to the exchange/queue, setting the internal `consumer`
        attribute. Does nothing if a consumer already exists; an existing consumer
        keeps the timeout it was created with.

        :param timeout: an optional timeout in seconds. The consumer will
            yield None if no message is received within ``timeout`` seconds. default:
            RabbitMQClient.STOP_CHECK_FREQUENCY_S
        """
        if self.consumer is not None:
            return
        timeout = timeout if timeout is not None else self.STOP_CHECK_FREQUENCY_S

        self.connection, self.channel, self.queue_name = self._connect(as_consumer=True)

        assert self.queue_name  # should be defined by _connect
        self.consumer = self.channel.consume(
            queue=self.queue_name, auto_ack=self.auto_ack, inactivity_timeout=timeout
        )

    def connect_publisher(self) -> None:
        """Connect for publishing, unless a connection is already held."""
        if self.connection is not None:
            return
        self.connection, self.channel, _ = self._connect(as_consumer=False)

    def close(self) -> None:
        """
        Close the connection (if open) and reset all connection state so that the next
        connect starts from scratch. Records the close time, which the next
        ``_connect`` uses to report the disconnect and to pause before reconnecting.

        An ``AMQPError`` raised while closing an already broken connection is logged
        and not re-raised: the state is reset regardless, so a reconnect can follow.
        """
        try:
            if self.connection is not None and self.connection.is_open:
                self.connection.close()
        except AMQPError:
            logger.warning("Error while closing rmq connection", exc_info=True)
        finally:
            # always reset, otherwise _connect would refuse with "Already connected"
            self.consumer = None
            self.channel = None
            self.connection = None
            self._close_time = now()

    def consume_queue(self,
                      json: bool = False,
                      schema: dict[str, Any] | None = None,
                      timeout: float | None = None,
                      as_response: bool = False) -> Generator[Any, None, None]:
        """
        Yield messages from the queue until ``stop_event`` is set. The connection is
        closed when the generator finishes or is closed.

        Every yielded message is acked (a no-op with ``auto_ack``) when the generator
        is resumed after yielding it. Callers should therefore not ack messages
        themselves, including ``RMQResponse`` objects yielded with
        ``as_response=True``; doing so would ack the same delivery tag twice.

        :param json: set to True to json decode all messages before yielding them.
            This is ignored if ``as_response`` is set to True. Default False
        :param schema: An optional jsonschema to validate the message against. Requires
            the json parameter to be set to True. Messages that fail decoding or
            validation are acked, passed to ``error_callback`` and skipped.
        :param timeout: Optional timeout in seconds. See ``RabbitMQClient.get_message``.
            In case of a timeout, this function will yield ``None``
        :param as_response: yield ``RMQResponse`` objects instead of raw bytes or json
            dictionary
        :raises ValueError: (on first iteration) if ``schema`` is given without
            ``json=True``
        """
        if schema is not None and not json:
            raise ValueError("Must set json to True when supplying a json schema")
        with self:
            while self._should_continue():
                try:
                    response = self.get_message(timeout=timeout, reset_last_ts=False)
                except TimeoutError:
                    yield None
                    continue
                except CancelledError:
                    return
                if as_response:
                    yield response
                elif json:
                    try:
                        yield response.json(schema=schema)
                    except (JSONDecodeError, jsonschema.ValidationError) as e:
                        # response.json already acked the message
                        logger.exception("error parsing message")
                        self._handle_recoverable_error(e)
                        continue
                else:
                    yield response.body

                try:
                    response.ack()
                except AMQPError as e:
                    logger.exception("Error while acknowledging message")
                    self._handle_recoverable_error(e)
                    self.close()  # trigger reconnect

    def consume_forever(
        self,
        callback: Callable[[Any], None],
        json: bool = False,
        schema: dict[str, Any] | None = None,
        timeout: float | None = None,
        as_response: bool = False
    ) -> None:
        """
        Invoke ``callback`` for every item yielded by ``consume_queue`` until it
        finishes. See RabbitMQClient.consume_queue for the other parameters.

        :param callback: a function that takes a message coming from the queue as a
            single argument (``None`` on a timeout, when ``timeout`` is given)
        """
        for result in self.consume_queue(
            json=json, schema=schema, timeout=timeout, as_response=as_response
        ):
            callback(result)

    def get_message(
        self, timeout: float | None = None, reset_last_ts: bool | None = None
    ) -> RMQResponse:
        """
        Get a message from the queue, reconnecting when necessary. Without a
        ``timeout`` this may block indefinitely.

        :param timeout: Optional timeout in seconds. If given, this method raises
            ``TimeoutError`` when no message was received within the timeout; the
            connection is kept open. The value becomes the consumer's inactivity
            timeout when the consumer is (re)created, replacing the default
            ``RabbitMQClient.STOP_CHECK_FREQUENCY_S``. A large value therefore makes
            the client check ``stop_event`` less often. An already existing consumer
            keeps the timeout it was created with.
        :param reset_last_ts: whether to reset ``RabbitMQClient.last_message_ts`` at the
            beginning of this call. Passing False lets idle detection
            (``reconnect_on_idle_timeout``) accumulate across several calls that each
            end in a timeout without a message. Defaults to False if a timeout is
            given and to True otherwise.
        :return: an ``RMQResponse`` whose ``body`` is the raw message body
        :raises TimeoutError: if a given ``timeout`` expired before a message was
            received
        :raises CancelledError: once ``RabbitMQClient.stop_event`` is set
        """
        if reset_last_ts is None:
            reset_last_ts = timeout is None

        if not self.last_message_ts or reset_last_ts:
            self.last_message_ts = now()

        while self._should_continue():
            try:
                self.connect_consumer(timeout=timeout)
                try:
                    assert self.consumer is not None

                    method, properties, body = next(self.consumer)
                except StopIteration:
                    # Consumer was cancelled by broker
                    logger.warning(f"Broker closed connection {self.connection_str()}")
                    self.close()
                    continue

                if not self._should_continue():
                    break

                # body is None when the consumer's inactivity timeout expired
                if body is not None:
                    self.last_message_ts = now()
                    return RMQResponse(
                        body=body,
                        client=self,
                        delivery_tag=method.delivery_tag,
                        reply_to=properties.reply_to,
                    )

                if self._is_idle_for_too_long():
                    logger.warning(
                        f"No message received from '{self.connection_str()}' in"
                        f" {self.reconnect_on_idle_timeout} seconds, reconnecting"
                    )
                    self.last_message_ts = now()
                    self.close()
                    continue

                if timeout is not None:
                    raise TimeoutError

            except AMQPError as e:
                logger.exception(
                    f"An error occured while reading '{self.connection_str()}'"
                )
                self._handle_recoverable_error(e)
                self.close()

        raise CancelledError

    def send_message(
        self,
        body: str | bytes,
        exchange_name: str | None = None,
        exchange_type: ExchangeType | None = None,
        routing_key: str | None = None,
        properties: pika.BasicProperties | None = None,
    ) -> None:
        """
        Publish a message once, retrying on ``AMQPError`` (reconnecting in between)
        until it has been published or ``stop_event`` is set. If ``stop_event`` is set
        before publishing succeeds, the message is dropped without raising.

        :param body: the message body; a ``str`` is encoded as UTF-8
        :param exchange_name: Override exchange name. If the (name, type) pair differs
            from the client's own exchange and the name is not the default exchange
            ``""``, the exchange is declared first. Default: None
        :param exchange_type: Override exchange type. Default: None
        :param routing_key: Override routing key. Default: None (use the client's
            single routing key)
        :param properties: Additional pika.spec.BasicProperties message properties.
            Default: None
        :raises RuntimeError: if no ``routing_key`` is given and the client is not
            configured with exactly one routing key
        """
        if exchange_name is None:
            exchange_name = self.exchange_name
        if not exchange_type:
            exchange_type = self.exchange_type

        routing_keys = self.routing_keys if routing_key is None else [routing_key]
        if len(routing_keys) != 1:
            raise RuntimeError("Can only publish messages with a single routing key")

        if exchange_name != "" and (
            (exchange_name, exchange_type) != (self.exchange_name, self.exchange_type)
        ):
            self.connect_publisher()
            assert self.channel is not None
            self.channel.exchange_declare(
                exchange=exchange_name, exchange_type=exchange_type
            )

        # mypy won't allow str|bytes for basic_publish, so encode once up front
        payload = body if isinstance(body, bytes) else body.encode('utf-8')

        while self._should_continue():
            self.connect_publisher()
            assert self.channel is not None

            try:
                self.channel.basic_publish(
                    exchange=exchange_name,
                    routing_key=routing_keys[0],
                    body=payload,
                    properties=properties,
                )
            except AMQPError as e:
                # parenthesized so the exchange part is always logged
                logger.exception(
                    f"An error occured sending a message to {exchange_name}"
                    + (f" with routing key '{routing_key}'" if routing_key else "")
                )
                self._handle_recoverable_error(e)
                self.close()
            else:
                # published: stop, otherwise the same message would be re-sent
                # until stop_event is set
                return

    def _should_continue(self) -> bool:
        """True unless a ``stop_event`` was given and has been set."""
        return not (self.stop_event and self.stop_event.is_set())

    def _is_idle_for_too_long(self) -> bool:
        """True if ``reconnect_on_idle_timeout`` is set and has been exceeded since
        ``last_message_ts``."""
        idle_time = (
            now() - self.last_message_ts if self.last_message_ts is not None else 0
        )
        return (
            self.reconnect_on_idle_timeout is not None
            and idle_time > self.reconnect_on_idle_timeout
        )

    def _pause_before_reconnect(self) -> None:
        """Sleep until ``pause_before_reconnect`` seconds have passed since the last
        ``close()``; no-op if pausing is disabled or the client was never closed."""
        if not self.pause_before_reconnect or self._close_time is None:
            return
        sleep_time = self._close_time + self.pause_before_reconnect - now()
        if sleep_time > 0:
            time.sleep(sleep_time)

    def maybe_ack(self, tag: int) -> None:
        """Ack delivery ``tag`` on the current channel, unless ``auto_ack`` is set."""
        if self.auto_ack:
            return  # noop
        assert self.channel, "channel should be defined by _connect"
        self.channel.basic_ack(delivery_tag=tag)

    def _handle_recoverable_error(self, exception: Exception) -> None:
        """Forward ``exception`` to ``error_callback``, if one was given."""
        if self.error_callback:
            self.error_callback(exception)


class TimeoutError(Exception):
    """
    Raised by ``RabbitMQClient.get_message`` when it was called with a ``timeout`` and
    no message arrived before that timeout expired.

    Note: within this module this name shadows the builtin ``TimeoutError``, and it is
    not a subclass of it.
    """


class CancelledError(RuntimeError):
    """
    Raised by ``RabbitMQClient.get_message`` once ``RabbitMQClient.stop_event`` has
    been set while waiting for a message.
    """