import copy
import functools
import importlib
import pickle
import pika
from polaris.rabbitmq.config import (
RABBITMQ_HOST,
RABBITMQ_PORT,
RABBITMQ_VIRTUAL_HOST,
RABBITMQ_USERNAME,
RABBITMQ_PASSWORD,
)
class JobWorker(object):
    """
    A worker class for parallel experiments.

    You can start the worker like below.

    `polaris-worker --exp-key this_is_test`

    And if you want to run this worker in a multi-node environment,
    you have to add the `--mpi` flag.

    The worker publishes an empty message on the ``request_<exp_key>``
    queue to ask for work, consumes pickled job contexts from the
    ``job_<exp_key>`` queue, runs the function named in each context and
    publishes the pickled result to the ``reply_to`` queue of the job
    message. With MPI enabled only rank 0 talks to RabbitMQ; it
    broadcasts each job context to the other ranks, which run the same
    function.
    """

    def __init__(self, args, logger=None, debug=False):
        """
        Set up MPI (optional), the RabbitMQ channel and the logger.

        :param args: object exposing ``exp_key``, ``mpi`` and ``run_once``
            attributes (presumably the parsed command-line arguments).
        :param logger: logger to use; when ``None`` a stream logger named
            after this module is configured.
        :param debug: when no ``logger`` is given, selects DEBUG instead
            of INFO as the level of the default logger.
        """
        self.exp_key = args.exp_key
        self.use_mpi = args.mpi
        self.debug = debug
        self.run_once = args.run_once

        if self.use_mpi:
            # Imported lazily so mpi4py is only required with --mpi.
            from mpi4py import MPI
            self.comm = MPI.COMM_WORLD
            self.rank = self.comm.Get_rank()

        # Only the single-process worker / MPI rank 0 owns a connection.
        if self._is_broker_rank():
            self._connect()

        if logger is None:
            logger = self._build_default_logger(self.debug)
        self.logger = logger

    def _is_broker_rank(self):
        """Return True when this process is the one talking to RabbitMQ."""
        # ``self.rank`` only exists when MPI is enabled, hence the order.
        return (not self.use_mpi) or self.rank == 0

    def _connect(self):
        """Open the RabbitMQ connection, declare queues, register consumer."""
        self.job_queue_name = f'job_{self.exp_key}'
        self.request_queue_name = f'request_{self.exp_key}'

        # NOTE: ``heartbeat_interval`` and the callback-first
        # ``basic_consume`` below are the pika < 1.0 spellings; pika 1.x
        # renamed/reordered them. Kept as-is to match the pinned version.
        conn_kwargs = {
            'host': RABBITMQ_HOST,
            'port': RABBITMQ_PORT,
            'virtual_host': RABBITMQ_VIRTUAL_HOST,
            'heartbeat_interval': 0,
        }
        # Credentials are passed only when both parts are configured;
        # otherwise pika's defaults apply.
        if RABBITMQ_USERNAME and RABBITMQ_PASSWORD:
            conn_kwargs['credentials'] = pika.PlainCredentials(
                RABBITMQ_USERNAME, RABBITMQ_PASSWORD)

        self.connection = pika.BlockingConnection(
            pika.ConnectionParameters(**conn_kwargs))
        self.channel = self.connection.channel()
        self.channel.queue_declare(queue=self.request_queue_name)
        self.channel.queue_declare(queue=self.job_queue_name)
        # At most one unacknowledged job is delivered to this worker.
        self.channel.basic_qos(prefetch_count=1)
        self.channel.basic_consume(
            self.on_request, queue=self.job_queue_name)

    @staticmethod
    def _build_default_logger(debug):
        """Return this module's logger configured with a stream handler."""
        import logging

        logger = logging.getLogger(__name__)
        logger.setLevel(logging.DEBUG if debug else logging.INFO)
        # getLogger() returns the same object on every call, so adding a
        # handler unconditionally would duplicate every log line each
        # time another worker is constructed in the same process.
        if not logger.handlers:
            stream = logging.StreamHandler()
            stream.setFormatter(logging.Formatter(
                '%(asctime)s:%(lineno)d:%(levelname)s:%(message)s'))
            logger.addHandler(stream)
        return logger

    def start(self):
        """
        Run the worker loop until interrupted (or after one job).

        The broker-facing process requests a job and blocks in
        ``start_consuming``; the other MPI ranks wait for a broadcast job
        context and run it. ``KeyboardInterrupt`` triggers shutdown.
        """
        try:
            if self._is_broker_rank():
                self.logger.info('Waiting for new job...')
                self.request_job()
                self.channel.start_consuming()
            else:
                while True:
                    # Blocks until rank 0 broadcasts the next context.
                    ctx = self.comm.bcast(None, root=0)
                    self.run(ctx)
                    if self.run_once:
                        break
        except KeyboardInterrupt:
            # NOTE(review): the shutdown steps below are assumed to belong
            # to this handler; confirm against the upstream source.
            self.logger.info('Stop current worker...')
            if self._is_broker_rank():
                self.connection.close()
            if self.use_mpi:
                from mpi4py import MPI
                MPI.Finalize()

    def on_request(self, ch, method, props, body):
        """
        Handle one job message: run it and publish the pickled result.

        :param ch: channel passed by pika (unused; ``self.channel`` is
            used instead).
        :param method: delivery metadata; its ``delivery_tag`` is acked.
        :param props: message properties; ``reply_to`` names the queue the
            result is published to.
        :param body: pickled job context (see :meth:`run` for the keys).
        """
        # SECURITY: pickle.loads can execute arbitrary code; this is only
        # safe if every publisher on the job queue is trusted.
        ctx = pickle.loads(body)
        self.logger.info(ctx)

        if self.use_mpi:
            # Hand the context to the other ranks waiting in start().
            ctx = self.comm.bcast(ctx, root=0)

        exp_result = self.run(ctx)
        exp_payload = {
            'exp_result': exp_result,
            'params': ctx['params'],
            'exp_info': ctx['exp_info'],
        }
        self.channel.basic_publish(
            exchange='',
            routing_key=props.reply_to,
            body=pickle.dumps(exp_payload)
        )
        # Ack only after the result was published.
        self.channel.basic_ack(delivery_tag=method.delivery_tag)

        if self.run_once:
            self.connection.close()
        else:
            self.request_job()

    def request_job(self):
        """Publish an empty message on the request queue to ask for a job."""
        self.channel.basic_publish(
            exchange='',
            routing_key=self.request_queue_name,
            body=''
        )

    def run(self, ctx):
        """
        Import and call the experiment function described by ``ctx``.

        :param ctx: mapping with keys ``params``, ``exp_info``,
            ``fn_module`` (importable module name), ``fn_name`` (attribute
            of that module) and ``args`` (iterable of extra positional
            arguments, or ``None``).
        :return: whatever the experiment function returns.
        """
        params = ctx['params']
        exp_info = ctx['exp_info']
        args = ctx['args']

        # Shallow copy so the function cannot rebind/replace entries of
        # the params object that is later sent back with the result.
        fn_params = copy.copy(params)

        # The module and function names come from the job message, so the
        # worker runs whatever importable callable the sender names.
        module = importlib.import_module(ctx['fn_module'])
        fn = getattr(module, ctx['fn_name'])

        if args is not None:
            return fn(fn_params, exp_info, *args)
        return fn(fn_params, exp_info)