import pika
from pika.credentials import PlainCredentials
import os
import json
import pkg_resources
from dotenv import load_dotenv
import sys
def try_parse_json(json_text: str):
    """
    Decode ``json_text`` as JSON when it looks like a JSON object.

    Only text that begins with ``{`` is handed to the JSON decoder. Any other text, and text that the decoder
    rejects, is returned unchanged.

    Args:
        json_text (str): The raw text to interpret.

    Returns:
        The decoded JSON value when decoding succeeds, otherwise ``json_text`` itself.
    """
    looks_like_object = json_text.startswith('{')
    if looks_like_object:
        try:
            decoded = json.loads(json_text)
        except json.JSONDecodeError:
            # Not valid JSON after all: fall through and hand back the raw text.
            pass
        else:
            return decoded
    return json_text
def _collect_env_settings(prefix: str) -> dict:
    """
    Gather every environment variable whose name starts with ``prefix``.

    The prefix is stripped from the variable name and the remainder is lower-cased to form the key. Each value is
    passed through ``try_parse_json``.
    """
    offset = len(prefix)
    settings = {}
    for variable, raw_value in os.environ.items():
        if variable.startswith(prefix):
            settings[variable[offset:].lower()] = try_parse_json(raw_value)
    return settings


# Location of the ``.env`` file shipped in the ``resources`` package.
env_location = pkg_resources.resource_filename('resources', '.env')

# The file is loaded unless DOTENV_LOADED is set to '1'. This must happen before the
# RABBITMQ_* variables are read below.
if os.environ.get('DOTENV_LOADED', '0') != '1':
    load_dotenv(env_location)

# RABBITMQ_EXCHANGE_<KEY> / RABBITMQ_QUEUE_<KEY> variables, keyed by the lower-cased <KEY>.
exchanges = _collect_env_settings('RABBITMQ_EXCHANGE_')
queues = _collect_env_settings('RABBITMQ_QUEUE_')
class ConnectionManager:
    """
    This manager handles all things related to RabbitMQ. Using this class one can connect to a RabbitMQ instance and
    not have to remake this code in every other class which needs a connection to RabbitMQ.

    On construction a ``pika.BlockingConnection`` and a channel are opened and kept as ``self.connection`` and
    ``self.channel``; nothing in this class closes them. By default the connection manager connects to the
    `vhost` `/` (overridable through the ``RABBITMQ_DEFAULT_VHOST`` environment variable) on `port` `5672`.

    Args:
        username (str, optional):
            The username which is used to connect to the RabbitMQ node. If none is provided it will be
            fetched from the ``RABBITMQ_DEFAULT_USER`` environment variable.
        password (str, optional):
            The password which is used to connect to the RabbitMQ node. If none is provided it will be
            fetched from the ``RABBITMQ_DEFAULT_PASS`` environment variable.
        url (str, optional):
            The host which is used to connect to the RabbitMQ node. If none is provided it will be
            fetched from the ``RABBITMQ_URL`` environment variable, falling back to ``localhost``.
        prefetch_count (int, optional):
            The number of unacknowledged messages a worker can accept. This is a natural way of
            spreading load between workers. If ``None`` (or the default sentinel ``-2``) is provided it will be
            fetched from the ``PREFETCH_COUNT`` environment variable, falling back to ``500``.

    Note:
        The recommended value for `prefetch_count` is around 500 since this maximises performance when using both a
        single worker and multiple different workers.

        When ``sphinx`` has been imported the constructor returns immediately without connecting, so none of
        the instance attributes are set.
    """

    def __init__(self, username: str = "", password: str = "", url: str = "", prefetch_count: int = -2):
        if 'sphinx' in sys.modules:
            return  # don't load when sphinx is running

        # Empty (or None) arguments mean "take the value from the environment".
        if not username:
            username = os.environ.get("RABBITMQ_DEFAULT_USER")
        if not password:
            password = os.environ.get("RABBITMQ_DEFAULT_PASS")
        if not url:
            url = os.environ.get("RABBITMQ_URL", "localhost")

        # Always assign the attribute: ``_subscribe`` reads it, so an explicitly passed value must be stored too.
        self.prefetch_count = self._resolve_prefetch_count(prefetch_count)

        credentials = PlainCredentials(username=username, password=password)
        # Fall back to the documented default vhost instead of handing ``None`` to pika.
        virtual_host = os.environ.get("RABBITMQ_DEFAULT_VHOST", "/")
        params = pika.ConnectionParameters(host=url, port=5672, virtual_host=virtual_host, credentials=credentials,
                                           connection_attempts=10, retry_delay=10)
        self.connection = pika.BlockingConnection(parameters=params)
        self.channel = self.connection.channel()

    @staticmethod
    def _resolve_prefetch_count(prefetch_count):
        """
        Work out the prefetch count to use.

        Args:
            prefetch_count: The value passed to the constructor.

        Returns:
            int: ``prefetch_count`` itself, unless it is ``None`` or the sentinel ``-2``; in that case the
            ``PREFETCH_COUNT`` environment variable converted to ``int``, or ``500`` when that is not set.
        """
        if prefetch_count is None or prefetch_count == -2:
            return int(os.environ.get("PREFETCH_COUNT", 500))
        return prefetch_count

    def publish_to_queue(self, queue_type: str, message: str):
        """
        Sends a message to a queue.

        The message is published on the default exchange (``''``) with the resolved queue name as routing key.

        Args:
            queue_type (str): The queue to send the message to. The actual queue name is resolved by the
                ``get_queue_name`` function.
            message (str): The message to publish to the queue.

        Raises:
            KeyError: the provided `queue_type` is not present in ``queues``
        """
        queue_name = self.get_queue_name(queue_type)
        self.channel.basic_publish(exchange='', routing_key=queue_name, body=message)

    def _subscribe(self, queue_type: str, callback: callable):
        """
        Subscribes to a queue and starts consuming. When a new message is received the callback will be executed.

        Args:
            queue_type (str): The name of the queue to consume from. It is passed to pika as-is; no lookup through
                ``get_queue_name`` happens here.
            callback (function): The function to execute upon receiving a message.

        Note:
            This is a private function and should not be used directly.
            Use the ``subscribe_to_queue`` function instead
        """
        self.channel.basic_qos(prefetch_count=self.prefetch_count)
        # NOTE(review): the positional order (callback, queue) matches the pika 0.x ``basic_consume`` signature;
        # pika >= 1.0 expects (queue, on_message_callback). Confirm against the pinned pika version.
        self.channel.basic_consume(callback, queue_type)
        self.channel.start_consuming()

    def subscribe_to_queue(self, queue_type: str, callback: callable):
        """
        Subscribe to a queue. Once a message is received it will be passed to the callback which will then be executed.
        Uses the ``_subscribe`` function under the hood.

        Examples:
            >>> def callback(ch, method, properties, body):
            >>>     print('Got message {}'.format(body))
            >>> connection_manager = ConnectionManager()
            >>> # Send a message first so there is something to consume.
            >>> connection_manager.publish_to_queue('clean', 'test')
            >>> # Subscribe to the queue; the callback is executed for the message.
            >>> connection_manager.subscribe_to_queue('clean', callback)
            'Got message test'

        Args:
            queue_type: The queue to subscribe to. The actual queue name is resolved by the
                ``get_queue_name`` function.
            callback: The function to execute upon receiving a message.

        Raises:
            KeyError: the provided `queue_type` is not present in ``queues``

        Note:
            This ends in pika's ``start_consuming``, so the call is expected to block while messages are
            being consumed.
        """
        queue_name = self.get_queue_name(queue_type)
        self._subscribe(queue_name, callback)

    @staticmethod
    def get_queue_name(queue_type: str):
        """
        Returns the queue name for the provided queue_type. The queue name is fetched from the ``queues``
        variable.

        Examples:
            >>> queues = {
            >>>     'clean': "1_clean",
            >>>     'preprocess': "2_preprocess",
            >>>     'process_wordvec': json.loads('{"name":"3_process_wordvec","exchange":"3_process"}'),
            >>>     'process_location': json.loads('{"name":"3_process_location","exchange":"3_process"}'),
            >>>     'save': '4_save',
            >>> }
            >>> ConnectionManager.get_queue_name('clean')
            '1_clean'
            >>> ConnectionManager.get_queue_name('process_wordvec')
            '3_process_wordvec'
            >>> ConnectionManager.get_queue_name('non_existing_queue')
            KeyError: 'non_existing_queue'

        Args:
            queue_type (str): The queue_type you want to get the name from. Name will be resolved using the ``queues``
                variable.

        Returns:
            str: `queue_name` if successful
                If the entry in ``queues`` is a plain value, that value is returned.
                If the entry is a dict (a queue with an exchange binding), its ``'name'`` item is returned.

        Raises:
            KeyError: the provided `queue_type` is not present in ``queues``
        """
        queue = queues[queue_type]
        if isinstance(queue, dict):
            return queue['name']
        return queue

    @staticmethod
    def get_queues():
        """
        Gets all the queues in the ``queues`` variable

        Returns:
            dict: the module-level queue dict itself (not a copy).
        """
        return queues

    @staticmethod
    def get_exchanges():
        """
        Gets all the exchanges in the ``exchanges`` variable

        Returns:
            dict: the module-level exchanges dict itself (not a copy).
        """
        return exchanges

    @staticmethod
    def get_exchange_name(exchange_type):
        """
        Returns the name of the provided exchange type. Is resolved by the ``exchanges`` variable

        Examples:
            >>> exchanges = {
            >>>     'process': {'name': '3_process', 'type': 'fanout'}
            >>> }
            >>> ConnectionManager.get_exchange_name('process')
            '3_process'
            >>> ConnectionManager.get_exchange_name('non_existing_exchange')
            KeyError: 'non_existing_exchange'

        Args:
            exchange_type: The exchange_type you want to get the name from. Name will be resolved using the
                ``exchanges`` variable.

        Returns:
            str: Returns the name corresponding to the `exchange_type`

        Raises:
            KeyError: the provided `exchange_type` is not present in ``exchanges``
        """
        return exchanges[exchange_type]['name']