# Source code for polaris.rabbitmq.worker

import copy
import functools
import importlib
import threading
import pickle

import pika

from polaris.rabbitmq.config import (
    RABBITMQ_HOST,
    RABBITMQ_PORT,
    RABBITMQ_VIRTUAL_HOST,
    RABBITMQ_USERNAME,
    RABBITMQ_PASSWORD,
)


class JobWorker(object):
    """A worker class for parallel experiments.

    You can start the worker like below.

        `polaris-worker --exp-key this_is_test`

    And if you want to run this worker on a multi-node environment,
    you have to add the `--mpi` flag.

    Message flow, as implemented below: the root process (the only process
    without MPI, or rank 0 with MPI) publishes an empty message to the
    ``request_<exp_key>`` queue to ask for work. It consumes pickled job
    contexts from ``job_<exp_key>`` and runs the requested function. It then
    publishes the pickled result to the delivery's ``reply_to`` queue. With
    MPI, rank 0 broadcasts each job context to the other ranks, which run
    the same function.
    """

    def __init__(self, args, logger=None, debug=False):
        """Set up MPI (optional), the RabbitMQ connection and the logger.

        Args:
            args: parsed CLI arguments; ``args.exp_key`` and ``args.mpi``
                are read.
            logger: logger to use. When None, the module logger is used and
                configured with a stream handler.
            debug: when True and ``logger`` is None, log at DEBUG level
                instead of INFO.
        """
        self.exp_key = args.exp_key
        self.use_mpi = args.mpi
        self.debug = debug

        if self.use_mpi:
            from mpi4py import MPI
            self.comm = MPI.COMM_WORLD
            self.rank = self.comm.Get_rank()

        # Only the root process talks to RabbitMQ; the other MPI ranks get
        # their jobs through ``comm.bcast`` (see ``start``/``run_thread``).
        if self._is_root():
            self.job_queue_name = f'job_{self.exp_key}'
            self.request_queue_name = f'request_{self.exp_key}'

            self.connection = pika.BlockingConnection(
                self._connection_params())
            self.channel = self.connection.channel()
            self.channel.queue_declare(queue=self.request_queue_name)
            self.channel.queue_declare(queue=self.job_queue_name)
            # At most one unacknowledged job is delivered to this worker.
            self.channel.basic_qos(prefetch_count=1)
            self._register_consumer()

        self.logger = self._init_logger(logger)

    def _is_root(self):
        """Return True for the process that talks to RabbitMQ."""
        return (not self.use_mpi) or self.rank == 0

    @staticmethod
    def _connection_params():
        """Build pika connection parameters from polaris.rabbitmq.config.

        Credentials are passed only when both username and password are
        configured; otherwise pika's default credentials apply.
        """
        params = dict(
            host=RABBITMQ_HOST,
            port=RABBITMQ_PORT,
            virtual_host=RABBITMQ_VIRTUAL_HOST,
        )
        if RABBITMQ_USERNAME and RABBITMQ_PASSWORD:
            params['credentials'] = pika.PlainCredentials(
                RABBITMQ_USERNAME, RABBITMQ_PASSWORD)
        return pika.ConnectionParameters(**params)

    def _register_consumer(self):
        """Register ``on_request`` as the consumer of the job queue.

        pika >= 1.0 expects ``basic_consume(queue, on_message_callback)``.
        pika 0.12/0.13 expected ``basic_consume(consumer_callback, queue)``.
        Try the 1.x keywords first and fall back to the 0.x call form.
        """
        try:
            self.channel.basic_consume(
                queue=self.job_queue_name,
                on_message_callback=self.on_request)
        except TypeError:
            # pika < 1.0: no ``on_message_callback`` keyword.
            self.channel.basic_consume(
                self.on_request, queue=self.job_queue_name)

    def _init_logger(self, logger):
        """Return ``logger``, or a configured module logger when it is None."""
        if logger is not None:
            return logger

        import logging
        module_logger = logging.getLogger(__name__)
        module_logger.setLevel(logging.DEBUG if self.debug else logging.INFO)
        # getLogger returns the same object for the same name. Attach the
        # handler only once, so that creating several workers in one process
        # does not print every log line several times.
        if not module_logger.handlers:
            stream = logging.StreamHandler()
            stream.setFormatter(logging.Formatter(
                '%(asctime)s:%(lineno)d:%(levelname)s:%(message)s'))
            module_logger.addHandler(stream)
        return module_logger

    def start(self):
        """Run the worker loop.

        The root process requests a first job and then consumes jobs from
        RabbitMQ. The other MPI ranks loop forever, running each job context
        broadcast by rank 0. Ctrl-C (KeyboardInterrupt) stops the worker.
        """
        try:
            if self._is_root():
                self.logger.info('Waiting for new job...')
                self.request_job()
                self.channel.start_consuming()
            else:
                while True:
                    ctx = self.comm.bcast(None, root=0)
                    self.run(ctx)
        except KeyboardInterrupt:
            self.logger.info('Stop current worker...')
        finally:
            # Close the connection on every exit path, not only on Ctrl-C.
            if self._is_root() and self.connection.is_open:
                self.connection.close()

        # Reached after a normal return or a KeyboardInterrupt; any other
        # exception has already propagated past this point.
        if self.use_mpi:
            from mpi4py import MPI
            MPI.Finalize()

    def ack_message(self, reply_to, delivery_tag, exp_payload):
        """Publish a job result, request the next job and ack this one.

        Scheduled from the job thread through ``add_callback_threadsafe``,
        so these channel operations run on the connection's thread.
        """
        if not self.channel.is_open:
            # Nothing can be published or acked on a closed channel.
            self.logger.warning(
                'Channel closed; dropping result for delivery %s',
                delivery_tag)
            return
        self.channel.basic_publish(
            exchange='',
            routing_key=reply_to,
            body=pickle.dumps(exp_payload)
        )
        self.request_job()
        self.channel.basic_ack(delivery_tag=delivery_tag)

    def run_thread(self, reply_to, delivery_tag, body):
        """Run one job (off the connection thread) and schedule the reply."""
        # SECURITY: pickle.loads can execute arbitrary code embedded in the
        # payload. Only consume from a trusted broker and queue.
        ctx = pickle.loads(body)
        self.logger.info(ctx)
        if self.use_mpi:
            # Hand the same job to the other ranks waiting in ``start``.
            self.comm.bcast(ctx, root=0)
        exp_result = self.run(ctx)
        exp_payload = {
            'exp_result': exp_result,
            'params': ctx['params'],
            'exp_info': ctx['exp_info'],
        }
        # The pika connection is not touched from this thread directly; the
        # reply is queued onto the connection's own thread instead.
        self.connection.add_callback_threadsafe(functools.partial(
            self.ack_message, reply_to, delivery_tag, exp_payload))

    def on_request(self, ch, method, props, body):
        """pika consumer callback: start the job in a new thread and return.

        Presumably done so the connection keeps servicing I/O (e.g.
        heartbeats) while a long job runs.
        """
        worker = threading.Thread(
            target=self.run_thread,
            args=(props.reply_to, method.delivery_tag, body))
        worker.start()

    def request_job(self):
        """Ask for a new job by publishing an empty message to the request queue."""
        self.channel.basic_publish(
            exchange='',
            routing_key=self.request_queue_name,
            body=''
        )

    def run(self, ctx):
        """Import and call the experiment function described by ``ctx``.

        Keys read from ``ctx``: 'params', 'exp_info', 'fn_module', 'fn_name'
        and 'args'. The call is ``fn(params_copy, exp_info, *args)``, or
        ``fn(params_copy, exp_info)`` when 'args' is None.

        Returns:
            Whatever the experiment function returns.
        """
        params = ctx['params']
        exp_info = ctx['exp_info']
        module_name = ctx['fn_module']
        fn_name = ctx['fn_name']
        args = ctx['args']

        # Shallow copy, so that in-place top-level changes made by the
        # function do not leak into ctx['params']. run_thread sends that
        # object back in the reply payload.
        fn_params = copy.copy(params)
        fn = getattr(importlib.import_module(module_name), fn_name)
        extra_args = () if args is None else args
        return fn(fn_params, exp_info, *extra_args)