# exchange.py 21.38 KiB
"""
copied from: dashboard-v3-python v0.226, dashboard.messaging.exchange
updated typehints to satisfy mypy, and linted with changes suggested by ruff
"""
import json
import logging
import threading
import time
from collections.abc import Callable, Generator, Iterator, Sequence
from json import JSONDecodeError
from typing import Any
import jsonschema
import pika
import pika.channel
from pika.adapters.blocking_connection import BlockingChannel
from pika.exceptions import AMQPError
from pika.exchange_type import ExchangeType
from .queue import setup_channel
logger = logging.getLogger(__name__)
def now() -> float:
    """
    Return the current reading of the monotonic clock, in seconds.

    The value is only meaningful when compared with other ``now()`` readings.
    """
    monotonic_seconds = time.monotonic()
    return monotonic_seconds
def default_rmq_connection_params(
    hostname: str | Sequence[str],
    username: str,
    password: str,
    vhost: str,
    port: int = 5672,
) -> Sequence[pika.ConnectionParameters]:
    """
    Create one ``pika.ConnectionParameters`` object per cluster host.

    All entries share the same port, vhost and plain credentials; only the host
    differs between them.

    :param hostname: a single hostname or a sequence of hostnames
    :param username: username (same for all hostnames)
    :param password: password (same for all hostnames)
    :param vhost: connection vhost
    :param port: port (default value = 5672, same for all hostnames)
    :return: list of pika.ConnectionParameters, in the order of the hostnames
    """
    hostnames: Sequence[str]
    if isinstance(hostname, str):
        # a bare string is a single host, not a sequence of characters
        hostnames = [hostname]
    else:
        hostnames = hostname
    logger.debug(f"rmq hostnames: {hostnames}")

    params: list[pika.ConnectionParameters] = []
    for host in hostnames:
        credentials = pika.PlainCredentials(username, password)
        params.append(
            pika.ConnectionParameters(
                host=host,
                port=port,
                virtual_host=vhost,
                credentials=credentials,
            )
        )
    return params
class RMQResponse:
def __init__(
self,
body: bytes,
client: "RabbitMQClient",
delivery_tag: int = 0,
reply_to: str | None = None
):
self.body = body
self.client = client
self.delivery_tag = delivery_tag
self.reply_to = reply_to
def ack(self) -> None:
# TODO: Handle AMQP error when acking fails
self.client.maybe_ack(self.delivery_tag)
def json(self, schema: dict[str, Any] | None = None) -> Any:
"""try and load the response body as a json object.
:param schema: Optional. A jsonschema dictionary to validate the json response
if the json parsing or validation fails, this method will ensure that the
message is still acked
"""
try:
result = json.loads(self.body)
if schema is not None:
jsonschema.validate(result, schema)
except (JSONDecodeError, jsonschema.ValidationError):
self.ack() # Maybe we should actually nack the message here
raise
return result
def reply(self, body: str | bytes) -> None:
self.client.send_message(
body=body,
exchange_name="",
exchange_type=ExchangeType.direct,
routing_key=self.reply_to,
)
class RabbitMQClient:
"""
A RabbitMQClient can connect to RabbitMQ and persistently maintain that connection.
It will reconnect on connection failures. It has two modes of operation for
consuming a queue. The first is by explicitly requesting messages through the
``get_message`` method:
.. code-block:: python
consumer = RabbitMQClient(cp, 'exchange', ExchangeType.fanout, queue_name='queue')
while True:
response = consumer.get_message()
... # handle message
response.ack()
or through its consume_forever mode:
.. code-block:: python
def mycallback(message: dict):
pass
consumer = RabbitMQClient(cp, 'exchange', ExchangeType.fanout, queue_name='queue')
consumer.consume_forever(mycallback, json=True)
The consumer is not thread safe and only a single message may be handled at a time
For sending messages (producing), you can use the `send_message()` method. This
method will ensure that a message is sent, by reconnecting to rabbitmq in case of
a failure.
"""
# Default inactivity timeout (seconds) handed to ``channel.consume`` by
# ``connect_consumer`` when the caller does not supply a timeout.
STOP_CHECK_FREQUENCY_S = 5
# Default for the ``pause_before_reconnect`` constructor argument (seconds).
PAUSE_BEFORE_RECONNECT_S = 5
# Default for the ``prefetch_count`` constructor argument.
DEFAULT_PREFETCH_COUNT = 50
# Connection state: populated by ``connect_consumer`` / ``connect_publisher`` and
# reset to None by ``close``.
connection: pika.BlockingConnection | None = None
channel: BlockingChannel | None = None
# Iterator returned by ``channel.consume``; only set by ``connect_consumer``.
consumer: Iterator[Any] | None = None
# ``now()`` reading taken when a message was last received or when the idle
# timer was last reset; None until ``get_message`` is first called.
last_message_ts: float | None = None
# ``now()`` reading taken by the most recent ``close()``; None until the first
# close. A non-None value makes ``_connect`` treat the next connect as a reconnect.
_close_time: float | None = None
def __init__(
    self,
    connection_params: pika.ConnectionParameters | Sequence[pika.ConnectionParameters],
    exchange_name: str,
    exchange_type: ExchangeType,
    queue_name: str | None = None,
    routing_keys: str | Sequence[str] | None = None,
    exclusive: bool = False,
    single_active_consumer: bool = False,
    quorum_queue: bool = False,
    prefetch_count: int = DEFAULT_PREFETCH_COUNT,
    auto_ack: bool = False,
    reconnect_on_idle_timeout: float | None = None,
    reconnect_callback: Callable[[], None] | None = None,
    error_callback: Callable[[Exception], None] | None = None,
    stop_event: threading.Event | None = None,
    pause_before_reconnect: float | None = PAUSE_BEFORE_RECONNECT_S,
):
    """
    Store the configuration of the client. No connection is opened here.

    :param connection_params: rmq connection parameters. A dict is also accepted
        and converted by ``_ensure_connection_params``
    :param exchange_name: rmq exchange name
    :param exchange_type: rmq exchange type, either 'direct' or 'fanout'
    :param queue_name: rmq queue name. Default: None
    :param routing_keys: Either a single string or a sequence of strings that will
        be used as routing key for 'direct' exchanges. Default: None
    :param exclusive: is this an exclusive queue? Exclusive queues are destroyed on
        disconnect. Default: False
    :param single_active_consumer: The queue can have a maximum of 1 active
        consumer. If multiple consumers are connecting to the same queue, they do
        not receive any messages until the current active consumer disconnects for
        some reason and they become the active consumer. Default: False.
    :param quorum_queue: forwarded to ``setup_channel`` as ``force_quorum_queue``.
        Default: False
    :param prefetch_count: Limit the number of messages to prefetch from the broker.
        This is highly recommended. Default: ``DEFAULT_PREFETCH_COUNT``
    :param auto_ack: automatically acknowledge messages. Default: False.
    :param reconnect_on_idle_timeout: Maximum idle time between subsequent
        messages before triggering a reconnect. Default: None
    :param reconnect_callback: Optional callback that will be called when
        reconnecting. Default: None
    :param error_callback: Optional callback that will be called when an AMQPError
        is raised or an invalid json message is received. Default: None
    :param stop_event: A `threading.Event` that will be checked periodically to
        determine whether the consumer should stop. Default: None
    :param pause_before_reconnect: Sleep a number of seconds between disconnecting
        and reconnecting. Supply a value that evaluates to False (0, None, False) to
        disable sleeping. Default: ``PAUSE_BEFORE_RECONNECT_S``.
    """
    self.connection_params = self._ensure_connection_params(connection_params)
    self.exchange_name = exchange_name
    self.exchange_type = exchange_type
    self.queue_name = queue_name

    # Always store the routing keys as a sequence, whatever form they came in.
    self.routing_keys: Sequence[str] = []  # for mypy
    if isinstance(routing_keys, str):
        self.routing_keys = [routing_keys]
    elif routing_keys is None:
        self.routing_keys = []
    elif isinstance(routing_keys, Sequence):
        self.routing_keys = routing_keys
    else:
        raise AssertionError("impossible - only here for mypy clarity")

    # queue declaration options
    self.exclusive = exclusive
    self.single_active_consumer = single_active_consumer
    self.quorum_queue = quorum_queue

    # consuming behaviour
    self.prefetch_count = prefetch_count
    self.auto_ack = auto_ack
    self.stop_event = stop_event

    # reconnect behaviour and hooks
    self.reconnect_on_idle_timeout = reconnect_on_idle_timeout
    self.pause_before_reconnect = pause_before_reconnect
    self.reconnect_callback = reconnect_callback
    self.error_callback = error_callback
@staticmethod
def _ensure_connection_params(
    connection_params: dict[str, Any] | pika.ConnectionParameters | Sequence[pika.ConnectionParameters]
) -> pika.ConnectionParameters | Sequence[pika.ConnectionParameters]:
    """
    Turn a plain dict into pika connection parameters; pass anything else through.

    A dict must contain ``username``, ``password`` and ``vhost``, plus either
    ``hostname`` or ``hostnames`` (``hostname`` wins when both are present).

    :param connection_params: a configuration dict, or ready-made pika parameters
    :return: the pika parameters to connect with
    :raises ValueError: if a dict contains neither ``hostname`` nor ``hostnames``
    """
    if isinstance(connection_params, dict):
        for host_key in ("hostname", "hostnames"):
            if host_key in connection_params:
                hosts = connection_params[host_key]
                break
        else:
            raise ValueError(
                "Connection parameters must contain either 'hostname' or 'hostnames'"
            )
        return default_rmq_connection_params(
            hosts,
            username=connection_params["username"],
            password=connection_params["password"],
            vhost=connection_params["vhost"],
        )
    return connection_params
def connection_str(self) -> str:
    """
    Describe what this client is bound to, for use in log messages.

    :return: the exchange name, followed by ``/<queue>`` when a queue name is set
        and by the routing keys when the first routing key is non-empty
    """
    pieces = [self.exchange_name]
    if self.queue_name:
        pieces.append(f"/{self.queue_name}")
    keys = self.routing_keys
    if keys and keys[0]:
        pieces.append(f" with routing key(s): {','.join(keys)}")
    return "".join(pieces)
def __enter__(self) -> "RabbitMQClient":
    """Enter the ``with`` block; returns the client itself without connecting."""
    return self
def __exit__(self, *args: Any, **kwargs: Any) -> None:
    """
    Close the client when the ``with`` block is left.

    Returns None, so an exception raised inside the block is not suppressed.
    """
    self.close()
def __del__(self) -> None:
    """Close the client when the instance is finalized."""
    self.close()
def _connect(self, as_consumer: bool = True) -> tuple[pika.BlockingConnection, BlockingChannel, str | None]:
    """
    Open a new connection and set up a channel on it.

    When the client has been closed before, this counts as a reconnect: a warning
    is logged, ``reconnect_callback`` is invoked (if set) and the configured
    pause is honoured before connecting.

    :param as_consumer: when False, no queue name is passed to ``setup_channel``
        (publisher mode). Default: True
    :return: tuple of (connection, channel, queue name as returned by
        ``setup_channel``)
    :raises RuntimeError: if the client already holds a connection
    """
    if self.connection is not None:
        raise RuntimeError("Already connected")
    if self._close_time is not None:
        logger.warning(f"Detected rmq disconnect from {self.connection_str()}")
        if self.reconnect_callback is not None:
            self.reconnect_callback()
        self._pause_before_reconnect()
    logger.info(f"Connecting to rmq {self.connection_str()}")
    connection = pika.BlockingConnection(self.connection_params)
    try:
        channel, queue = setup_channel(
            connection=connection,
            exchange_name=self.exchange_name,
            exchange_type=self.exchange_type,
            queue_name=self.queue_name if as_consumer else None,
            exclusive=self.exclusive,
            single_active_consumer=self.single_active_consumer,
            routing_keys=self.routing_keys,
            prefetch_count=self.prefetch_count,
            force_quorum_queue=self.quorum_queue,
        )
    except BaseException:
        # The connection is not stored on the instance yet, so nobody else can
        # close it. Without this, every failed setup leaks an open connection,
        # and callers retry in a loop.
        try:
            if connection.is_open:
                connection.close()
        except AMQPError:
            logger.debug(
                "Error while closing connection after failed channel setup",
                exc_info=True,
            )
        raise
    return connection, channel, queue
def connect_consumer(self, timeout: float | None = None) -> None:
"""
Create a channel and bind to the exchange/queue, setting the internal `consumer`
attribute
:param timeout: an optional timeout in seconds. The consumer will
yield None if no message is received within ``timeout`` seconds. default:
RabbitMQClient.STOP_CHECK_FREQUENCY_S
"""
if self.consumer is not None:
return
timeout = timeout if timeout is not None else self.STOP_CHECK_FREQUENCY_S
self.connection, self.channel, self.queue_name = self._connect(as_consumer=True)
assert self.queue_name # should be defined by _connect
self.consumer = self.channel.consume(
queue=self.queue_name, auto_ack=self.auto_ack, inactivity_timeout=timeout
)
def connect_publisher(self) -> None:
    """
    Connect in publisher mode unless a connection already exists.

    Stores the connection and channel returned by ``_connect``; the queue name
    it returns is ignored.
    """
    if self.connection is None:
        connection, channel, _queue = self._connect(as_consumer=False)
        self.connection = connection
        self.channel = channel
def close(self) -> None:
    """
    Close the connection (if it is open) and reset the connection state.

    Closing is best effort: an ``AMQPError`` raised while closing is logged and
    not propagated. ``close`` is called from error handlers to force a
    reconnect, where a failing close on an already broken connection is
    expected. Whatever happens, ``consumer``, ``channel`` and ``connection`` are
    reset and the close time is recorded, so the next connect starts from a
    clean state.
    """
    try:
        if self.connection and self.connection.is_open:
            self.connection.close()
    except AMQPError:
        logger.warning("Error while closing rmq connection", exc_info=True)
    finally:
        self.consumer = None
        self.channel = None
        self.connection = None
        self._close_time = now()
def consume_queue(self,
json: bool = False,
schema: dict[str, Any] | None = None,
timeout: float | None = None,
as_response: bool = False) -> Generator[Any]:
"""
:param json: set to True to json decode all messages before invoking the
callback. This is ignored if ``as_response`` is set to True. Default False
:param schema: An optional jsonschema to validate the message against. Requires
the json parameter to be set to True
:param timeout: Optional timeout in seconds. See ``RabbitMQClient.get_message``.
In case of a timeout, this function will yield ``None``
:param as_response: yield ``RMQResponse`` objects instead of raw bytes or json
dictionary
"""
if schema is not None and not json:
raise ValueError("Must set json to True when supplying a json schema")
with self:
while self._should_continue():
try:
response = self.get_message(timeout=timeout, reset_last_ts=False)
except TimeoutError:
yield None
continue
except CancelledError:
return
if as_response:
yield response
elif json:
try:
yield response.json(schema=schema)
except (JSONDecodeError, jsonschema.ValidationError) as e:
logger.exception("error parsing message")
self._handle_recoverable_error(e)
continue
else:
yield response.body
try:
response.ack()
except AMQPError as e:
logger.exception("Error while acknowledging message")
self._handle_recoverable_error(e)
self.close() # trigger reconnect
def consume_forever(
self,
callback: Callable[[Any], None],
json: bool = False,
schema: dict[str, Any] | None = None,
timeout: float | None = None,
as_response: bool = False
) -> None:
"""
See RabbitMQClient.consume_queue for other other paramters
:param callback: a function that takes a message coming from the queue as a
single argument
"""
for result in self.consume_queue(
json=json, schema=schema, timeout=timeout, as_response=as_response
):
callback(result)
def get_message(
    self, timeout: float | None = None, reset_last_ts: bool | None = None
) -> RMQResponse:
    """
    Wait for the next message on the queue, reconnecting when necessary. This
    may block indefinitely.

    :param timeout: Optional timeout in seconds. If given, this method will raise
        `TimeoutError` if the timeout expired before a message was received. The
        connection will be kept open. The value is passed to
        ``connect_consumer``, so it only takes effect when the consumer is
        (re)created. Warning: setting a large timeout may make the program less
        responsive, since it will override the default timeout
        (``RabbitMQClient.STOP_CHECK_FREQUENCY_S``)
    :param reset_last_ts: whether or not to reset the last time a message was
        received (RabbitMQClient.last_message_ts) at the beginning of this
        function. Passing False lets the idle detection span several calls that
        each end in a timeout. When left at None, the timestamp is reset only if
        no timeout is given.
    :return: an ``RMQResponse`` holding the raw body, the delivery tag and the
        ``reply_to`` property of the message
    :raises TimeoutError: If a given custom timeout was reached while waiting for
        messages
    :raises CancelledError: After ``RabbitMQClient.stop_event`` was set
    """
    should_reset = (timeout is None) if reset_last_ts is None else reset_last_ts
    if should_reset or not self.last_message_ts:
        self.last_message_ts = now()

    while self._should_continue():
        try:
            self.connect_consumer(timeout=timeout)
            try:
                assert self.consumer is not None
                frame, props, payload = next(self.consumer)
            except StopIteration:
                # Consumer was cancelled by broker
                logger.warning(f"Broker closed connection {self.connection_str()}")
                self.close()
                continue

            if not self._should_continue():
                break

            if payload is None:
                # The consumer's inactivity timeout expired without a message.
                if self._is_idle_for_too_long():
                    logger.warning(
                        f"No message received from '{self.connection_str()}' in"
                        f" {self.reconnect_on_idle_timeout} seconds, reconnecting"
                    )
                    self.last_message_ts = now()
                    self.close()
                    continue
                if timeout is not None:
                    raise TimeoutError
                continue

            self.last_message_ts = now()
            return RMQResponse(
                body=payload,
                client=self,
                delivery_tag=frame.delivery_tag,
                reply_to=props.reply_to,
            )
        except AMQPError as exc:
            logger.exception(
                f"An error occured while reading '{self.connection_str()}'"
            )
            self._handle_recoverable_error(exc)
            self.close()

    raise CancelledError
def send_message(
    self,
    body: str | bytes,
    exchange_name: str | None = None,
    exchange_type: ExchangeType | None = None,
    routing_key: str | None = None,
    properties: pika.BasicProperties | None = None,
) -> None:
    """
    Send a message to the exchange, retrying on AMQPError to ensure delivery.

    The message is published once. If publishing raises an ``AMQPError``, the
    error is reported, the connection is closed and the publish is retried on a
    fresh connection. If ``stop_event`` is set, the method returns without
    (re)trying to publish.

    :param body: the message body (str or bytes); a str is encoded as UTF-8
    :param exchange_name: Override exchange name. Default: None
    :param exchange_type: Override exchange type. Default: None
    :param routing_key: Override routing key. Default: None
    :param properties: Additional pika.spec.BasicProperties message properties.
        Default: None
    :raises RuntimeError: if no override routing key is given and the client is
        not configured with exactly one routing key
    """
    if exchange_name is None:
        exchange_name = self.exchange_name
    if not exchange_type:
        exchange_type = self.exchange_type
    routing_keys = self.routing_keys if routing_key is None else [routing_key]
    if len(routing_keys) != 1:
        raise RuntimeError("Can only publish messages with a single routing key")
    target_key = routing_keys[0]
    # mypy won't allow str|bytes as the publish body, so encode once up front
    payload = body if isinstance(body, bytes) else body.encode('utf-8')

    if exchange_name != "" and (
        (exchange_name, exchange_type) != (self.exchange_name, self.exchange_type)
    ):
        # Publishing to a named exchange other than the configured one: declare
        # it first.
        self.connect_publisher()
        assert self.channel is not None
        self.channel.exchange_declare(
            exchange=exchange_name, exchange_type=exchange_type
        )

    while self._should_continue():
        self.connect_publisher()
        assert self.channel is not None
        try:
            self.channel.basic_publish(
                exchange=exchange_name,
                routing_key=target_key,
                body=payload,
                properties=properties,
            )
        except AMQPError as e:
            logger.exception(
                f"An error occured sending a message to {exchange_name}"
                f" with routing key '{target_key}'"
            )
            self._handle_recoverable_error(e)
            self.close()
        else:
            # Published successfully. Without this return the loop would
            # publish the same message over and over.
            return
def _should_continue(self) -> bool:
return not (self.stop_event and self.stop_event.is_set())
def _is_idle_for_too_long(self) -> bool:
idle_time = (
now() - self.last_message_ts if self.last_message_ts is not None else 0
)
return (
self.reconnect_on_idle_timeout is not None
and idle_time > self.reconnect_on_idle_timeout
)
def _pause_before_reconnect(self) -> None:
if not self.pause_before_reconnect or self._close_time is None:
return
sleep_time = self._close_time + self.pause_before_reconnect - now()
if sleep_time > 0:
time.sleep(sleep_time)
def maybe_ack(self, tag: int) -> None:
    """
    Acknowledge the delivery ``tag`` on the channel, unless ``auto_ack`` is on.

    :param tag: delivery tag of the message to acknowledge
    """
    if not self.auto_ack:
        assert self.channel, "channel should be defined by _connect"
        self.channel.basic_ack(delivery_tag=tag)
def _handle_recoverable_error(self, exception: Exception) -> None:
if self.error_callback:
self.error_callback(exception)
class TimeoutError(Exception):
    """
    Raised by ``RabbitMQClient.get_message`` when it was called with a
    ``timeout`` and that timeout occurred while waiting for a message.

    Note that within this module the name shadows the builtin ``TimeoutError``.
    """
class CancelledError(RuntimeError):
    """
    Raised by ``RabbitMQClient.get_message`` when ``RabbitMQClient.stop_event``
    was set while waiting for a message.
    """