# Source code for smug.connection_manager

import pika
from pika.credentials import PlainCredentials
import os
import pkg_resources
from dotenv import load_dotenv

# Maps each stage key to the name of the queue used for that stage.
# Every name can be overridden through the environment variable shown next
# to it; the second argument is the default used when the variable is unset.
queues = {
    'format': os.environ.get("FORMATTING_QUEUE_NAME", "1_format"),
    'clean': os.environ.get("CLEANING_QUEUE_NAME", "2_clean"),
    'preprocess': os.environ.get("PREPROCESSING_QUEUE_NAME", "3_preprocess"),
    # The sibling entries all read a "*_QUEUE_NAME" variable, so that spelling
    # is honoured first. "PROCESSING" is the variable this entry originally
    # read and is kept as a fallback so existing deployments keep working.
    'process': os.environ.get(
        "PROCESSING_QUEUE_NAME",
        os.environ.get("PROCESSING", "4_process"),
    ),
    'save': os.environ.get("SAVE_QUEUE_NAME", "5_save"),
}


class ConnectionManager:
    """Wrap one blocking pika connection and a single channel opened on it.

    The instance exposes helpers to publish to, and consume from, the queues
    listed in the module-level ``queues`` mapping, as well as arbitrary
    routing keys on an exchange.
    """

    def __init__(self):
        """Load the ``.env`` file and connect to the broker.

        Environment variables read:

        * ``RABBITMQ_DEFAULT_USER`` / ``RABBITMQ_DEFAULT_PASS`` -- credentials.
        * ``RABBITMQ_URL`` -- broker host, defaults to ``localhost``.
        * ``PREFETCH_COUNT`` -- value passed to ``basic_qos`` before
          consuming, defaults to ``500``.

        The port (5672) and the virtual host (``smug``) are fixed.
        """
        # NOTE(review): the module-level ``queues`` mapping is evaluated at
        # import time, i.e. before this load_dotenv() call runs, so queue
        # names defined only in the .env file would not be picked up there --
        # confirm whether that is intended.
        env_location = pkg_resources.resource_filename('resources', '.env')
        load_dotenv(env_location)

        username = os.environ.get("RABBITMQ_DEFAULT_USER")
        password = os.environ.get("RABBITMQ_DEFAULT_PASS")
        url = os.environ.get("RABBITMQ_URL", "localhost")

        credentials = PlainCredentials(username=username, password=password)
        params = pika.ConnectionParameters(
            host=url,
            port=5672,
            virtual_host="smug",
            credentials=credentials,
        )

        self.connection = pika.BlockingConnection(parameters=params)
        self.channel = self.connection.channel()
        self.prefetch_count = int(os.environ.get("PREFETCH_COUNT", 500))

    def publish_to_queue(self, queue_type, message):
        """Publish ``message`` to the queue registered for ``queue_type``.

        :param queue_type: key of the module-level ``queues`` mapping.
        :param message: body handed to ``basic_publish``.
        :raises KeyError: if ``queue_type`` is not a key of ``queues``.
        """
        channel_name = self.get_queue_name(queue_type)
        # An empty exchange name with the queue name as routing key is how
        # the message is addressed to that queue.
        self.channel.basic_publish(
            exchange='',
            routing_key=channel_name,
            body=message,
        )

    def publish_to_exchange(self, routing_key, message, exchange='amq.direct'):
        """Publish ``message`` to ``exchange`` under ``routing_key``.

        :param routing_key: routing key passed to ``basic_publish``.
        :param message: body handed to ``basic_publish``.
        :param exchange: exchange name, ``amq.direct`` by default.
        """
        self.channel.basic_publish(
            exchange=exchange,
            routing_key=routing_key,
            body=message,
        )

    def _subscribe(self, queue_name, callback):
        """Register ``callback`` on ``queue_name`` and start consuming.

        Applies ``self.prefetch_count`` through ``basic_qos`` first, then
        hands control to ``channel.start_consuming()``.
        """
        self.channel.basic_qos(prefetch_count=self.prefetch_count)
        # NOTE(review): the positional (callback, queue) order looks like the
        # pre-1.0 pika signature -- confirm against the installed pika version.
        self.channel.basic_consume(callback, queue_name)
        self.channel.start_consuming()

    def subscribe_to_queue(self, queue_name, callback):
        """Consume from the queue registered under the key ``queue_name``.

        :param queue_name: key of the module-level ``queues`` mapping (not
            the broker-side queue name itself).
        :param callback: consumer callback passed to ``basic_consume``.
        :raises KeyError: if ``queue_name`` is not a key of ``queues``.
        """
        queue_name = self.get_queue_name(queue_name)
        self._subscribe(queue_name, callback)

    def subscribe_to_routing_key(self, routing_key, callback,
                                 exchange='amq.direct'):
        """Consume messages published to ``exchange`` under ``routing_key``.

        Declares an exclusive queue, binds it to ``exchange`` with
        ``routing_key`` and starts consuming from it.

        :param routing_key: routing key used for the binding.
        :param callback: consumer callback passed to ``basic_consume``.
        :param exchange: exchange to bind to. Defaults to ``amq.direct``,
            matching the default of :meth:`publish_to_exchange`.
        """
        # No queue name is passed; the name to use is read back from the
        # declare result.
        result = self.channel.queue_declare(exclusive=True)
        queue_name = result.method.queue
        self.channel.queue_bind(queue_name, exchange, routing_key=routing_key)
        self._subscribe(queue_name, callback)

    @staticmethod
    def get_queue_name(channel_type):
        """Return the queue name registered for ``channel_type``.

        :raises KeyError: if ``channel_type`` is not a key of ``queues``.
        """
        return queues[channel_type]

    @staticmethod
    def get_queue_names():
        """Return a view of all queue names in the ``queues`` mapping."""
        return queues.values()