Skip to content
This repository has been archived by the owner on Nov 17, 2018. It is now read-only.

Add callbacks for sent task ack, sent task and reshape API as retrieving result from AsyncResult.get() (fix #38) #43

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 32 additions & 9 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,55 @@ tornado-celery is a non-blocking Celery client for Tornado web framework
Usage
-----

Calling Celery tasks from Tornado RequestHandler: ::
Calling Celery tasks(has return value) from Tornado RequestHandler: ::

from tornado import gen, web
import tcelery, tasks

tcelery.setup_nonblocking_producer()

class AsyncHandler(web.RequestHandler):
@asynchronous
@web.asynchronous
def get(self):
tasks.echo.apply_async(args=['Hello world!'], callback=self.on_result)
tasks.echo.apply_async(args=['Hello world!'], callback=self.on_async_result)

def on_result(self, response):
self.write(str(response.result))
def on_async_result(self, async_result):
async_result.get(callback=self.on_actual_result)

def on_actual_result(self, result):
self.write(str(result))
self.finish()

Calling tasks with generator-based interface: ::
with generator-based interface: ::

class GenAsyncHandler(web.RequestHandler):
@asynchronous
@web.asynchronous
@gen.coroutine
def get(self):
response = yield gen.Task(tasks.sleep.apply_async, args=[3])
self.write(str(response.result))
async_result = yield gen.Task(tasks.sleep.apply_async, args=[3])
result = yield gen.Task(async_result.get)
self.write(str(result))
self.finish()

Calling Celery tasks(no return value) from Tornado RequestHandler: ::

@web.asynchronous
def get(self):
tasks.echo.apply_async(args=['Hello world!'], callback=self.on_async_result)

def on_async_result(self, async_result):
self.write("task sent") # ack-ed if BROKER_TRANSPORT_OPTIONS: {'confirm_publish': True}
self.finish()

with generator-based interface: ::

@web.asynchronous
@gen.coroutine
def get(self):
yield gen.Task(tasks.sleep.apply_async, args=[3])
self.write("task sent") # ack-ed if BROKER_TRANSPORT_OPTIONS: {'confirm_publish': True}
self.finish()

**NOTE:** Currently callbacks only work with AMQP and Redis backends.
To use the Redis backend, you must install `tornado-redis
<https://github.com/leporo/tornado-redis>`_.
Expand Down
10 changes: 9 additions & 1 deletion tcelery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ def connect():
options = celery_app.conf.get('CELERYT_PIKA_OPTIONS', {})
producer_cls.conn_pool.connect(broker_url,
options=options,
callback=on_ready)
callback=on_ready,
confirm_delivery=_get_confirm_publish_conf(celery_app.conf))

io_loop.add_callback(connect)

def _get_confirm_publish_conf(conf):
broker_transport_options = conf.get('BROKER_TRANSPORT_OPTIONS', {})
if (broker_transport_options and
broker_transport_options.get('confirm_publish') is True):
return True
return False
65 changes: 62 additions & 3 deletions tcelery/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,20 @@

from tornado import ioloop

LOGGER = logging.getLogger(__name__)

class Connection(object):

content_type = 'application/x-python-serialize'

def __init__(self, io_loop=None):
def __init__(self, io_loop=None, confirm_delivery=False):
self.channel = None
self.connection = None
self.url = None
self.io_loop = io_loop or ioloop.IOLoop.instance()
self.confirm_delivery = confirm_delivery
if self.confirm_delivery:
self.confirm_delivery_handler = ConfirmDeliveryHandler()

def connect(self, url=None, options=None, callback=None):
if url is not None:
Expand Down Expand Up @@ -61,9 +65,17 @@ def on_connect(self, callback, connection):

def on_channel_open(self, callback, channel):
self.channel = channel
if self.confirm_delivery:
self.init_confirm_delivery()
if callback:
callback()

def init_confirm_delivery(self):
self.channel.confirm_delivery(callback=self.confirm_delivery_handler.on_delivery_confirmation,
nowait=True)
self.confirm_delivery_handler.reset_message_seq()
self.confirm_delivery_handler.reset_coroutine_callbacks()

def on_exchange_declare(self, frame):
pass

Expand Down Expand Up @@ -118,10 +130,10 @@ def __init__(self, limit, io_loop=None):
self._connection = None
self.io_loop = io_loop

def connect(self, broker_url, options=None, callback=None):
def connect(self, broker_url, options=None, callback=None, confirm_delivery=False):
self._on_ready = callback
for _ in range(self._limit):
conn = Connection(io_loop=self.io_loop)
conn = Connection(io_loop=self.io_loop, confirm_delivery=confirm_delivery)
conn.connect(broker_url, options=options,
callback=partial(self._on_connect, conn))

Expand All @@ -135,3 +147,50 @@ def _on_connect(self, connection):
def connection(self):
assert self._connection is not None
return next(self._connection)

class ConfirmDeliveryHandler(object):

def __init__(self):
self._message_seq = 0
self._acked = 0
self._nacked = 0
self._unknown_ack = 0
self.coroutine_callbacks = {}

def on_delivery_confirmation(self, method_frame):
"""Invoked by pika when RabbitMQ responds to a Basic.Publish RPC
command, passing in either a Basic.Ack or Basic.Nack frame with
the delivery tag of the message that was published. The delivery tag
is an integer counter indicating the message number that was sent
on the channel via Basic.Publish. After Basic.Ack is received, it
will call corresponding callback based on delivery tag number.

:param pika.frame.Method method_frame: Basic.Ack or Basic.Nack frame

"""
confirmation_type = method_frame.method.NAME.split('.')[1].lower()
delivery_tag = method_frame.method.delivery_tag
message = ('Received %s for delivery tag: %i' %
(confirmation_type,
delivery_tag))
LOGGER.debug(message)

if confirmation_type == 'ack':
self._acked += 1
elif confirmation_type == 'nack':
self._nacked += 1
else:
self._unknown_ack += 1
coroutine_callback = self.coroutine_callbacks.pop(delivery_tag)
if coroutine_callback:
coroutine_callback(None)

def reset_message_seq(self):
self._message_seq = 0

def reset_coroutine_callbacks(self):
self.coroutine_callbacks.clear()

def add_callback(self, callback):
self._message_seq += 1
self.coroutine_callbacks[self._message_seq] = callback
18 changes: 11 additions & 7 deletions tcelery/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,20 +230,24 @@ def post(self, taskname):
partial(self.on_time, task_id))

task.apply_async(args=args, kwargs=kwargs, task_id=task_id,
callback=partial(self.on_complete, htimeout),
callback=partial(self.on_async_result, htimeout),
**options)

def on_complete(self, htimeout, result):
def on_async_result(self, htimeout, async_result):
self._result = async_result
async_result.get(callback=partial(self.on_actual_result, htimeout))

def on_actual_result(self, htimeout, result):
if self._finished:
return
if htimeout:
ioloop.IOLoop.instance().remove_timeout(htimeout)
response = {'task-id': result.task_id, 'state': result.state}
if result.successful():
response['result'] = result.result
response = {'task-id': self._result.task_id, 'state': self._result.state}
if self._result.successful():
response['result'] = result
else:
response['traceback'] = result.traceback
response['error'] = repr(result.result)
response['traceback'] = self._result.traceback
response['error'] = repr( self._result.result)
self.write(response)
self.finish()

Expand Down
34 changes: 18 additions & 16 deletions tcelery/producer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import absolute_import

import sys
from functools import partial

from datetime import timedelta
from kombu import serialization
from kombu.utils import cached_property
from celery.app.amqp import TaskProducer
from celery.backends.amqp import AMQPBackend
from celery.backends.redis import RedisBackend
from celery.backends.base import DisabledBackend
from celery.utils import timeutils

from .result import AsyncResult
Expand All @@ -20,7 +21,6 @@

is_py3k = sys.version_info >= (3, 0)


class AMQPConsumer(object):
def __init__(self, producer):
self.producer = producer
Expand Down Expand Up @@ -68,10 +68,6 @@ def publish(self, body, routing_key=None, delivery_mode=None,

if callback and not callable(callback):
raise ValueError('callback should be callable')
if callback and not isinstance(self.app.backend,
(AMQPBackend, RedisBackend)):
raise NotImplementedError(
'callback can be used only with AMQP or Redis backends')

body, content_type, content_encoding = self._prepare(
body, serializer, content_type, content_encoding,
Expand All @@ -94,10 +90,16 @@ def publish(self, body, routing_key=None, delivery_mode=None,
exchange=exchange, declare=declare)

if callback:
self.consumer.wait_for(task_id,
partial(self.on_result, task_id, callback),
expires=self.prepare_expires(type=int),
persistent=self.app.conf.CELERY_RESULT_PERSISTENT)
async_result = self.result_cls(task_id=task_id,
result=result,
producer=self)
if conn.confirm_delivery:
conn.confirm_delivery_handler.add_callback(lambda result:
callback(async_result))

else:
callback(async_result)

return result

@cached_property
Expand All @@ -117,12 +119,6 @@ def decode(self, payload):
content_type=self.content_type,
content_encoding=self.content_encoding)

def on_result(self, task_id, callback, reply):
reply = self.decode(reply)
reply['task_id'] = task_id
result = self.result_cls(**reply)
callback(result)

def prepare_expires(self, value=None, type=None):
if value is None:
value = self.app.conf.CELERY_TASK_RESULT_EXPIRES
Expand All @@ -132,5 +128,11 @@ def prepare_expires(self, value=None, type=None):
return type(value * 1000)
return value

def fail_if_backend_not_supported(self):
if not isinstance(self.app.backend,
(AMQPBackend, RedisBackend, DisabledBackend)):
raise NotImplementedError(
'result retrieval can be used only with AMQP or Redis backends')

def __repr__(self):
return '<NonBlockingTaskProducer: {0.channel}>'.format(self)
26 changes: 24 additions & 2 deletions tcelery/result.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from __future__ import absolute_import
from __future__ import with_statement

from functools import partial

import celery


class AsyncResult(celery.result.AsyncResult):
def __init__(self, task_id, status=None, traceback=None, result=None,
**kwargs):
def __init__(self, task_id, status=None, traceback=None,
result=None, producer=None, **kwargs):
super(AsyncResult, self).__init__(task_id)
self._status = status
self._traceback = traceback
self._result = result
self._producer = producer

@property
def status(self):
Expand All @@ -27,3 +30,22 @@ def traceback(self):
@property
def result(self):
return self._result or super(AsyncResult, self).result

def get(self, callback=None):
self._producer.fail_if_backend_not_supported()
self._producer.consumer.wait_for(self.task_id,
partial(self.on_result, callback),
expires=self._producer.prepare_expires(type=int),
persistent=self._producer.app.conf.CELERY_RESULT_PERSISTENT)

def on_result(self, callback, reply):
reply = self._producer.decode(reply)
self._status = reply.get('status')
self._traceback = reply.get('traceback')
self._result = reply.get('result')
if callback:
callback(self._result)

def _get_task_meta(self):
self._producer.fail_if_backend_not_supported()
return super(AsyncResult, self)._get_task_meta()