# Source code for djangofloor.tasks

"""Define Celery tasks and functions for calling signals
=====================================================

This module is automatically imported by Celery.
Use these functions for:

  * setting websocket channels allowed for a given :class:`django.http.response.HttpResponse`,
  * calling signals, with a full function (:meth:`djangofloor.tasks.call`) and a
    shortcut (:meth:`djangofloor.tasks.scall`)

"""

import json
import logging
import os
import uuid
import warnings

from celery import shared_task
from django.conf import settings
from django.utils.lru_cache import lru_cache
from django.utils.module_loading import import_string
from redis import StrictRedis, ConnectionPool

from djangofloor.decorators import (
    REGISTERED_SIGNALS,
    SignalConnection,
    REGISTERED_FUNCTIONS,
    FunctionConnection,
    DynamicQueueName,
)
from djangofloor.utils import import_module, RemovedInDjangoFloor200Warning
from djangofloor.wsgi.exceptions import NoWindowKeyException
from djangofloor.wsgi.window_info import WindowInfo

__author__ = "Matthieu Gallet"
logger = logging.getLogger("djangofloor.signals")


class Constant:
    """Named sentinel object whose ``str()`` and ``repr()`` are both its name.

    Used below for the special values of the ``to`` argument (``SERVER``,
    ``WINDOW``, ...), so that they print readably in logs and on stdout.
    """

    def __init__(self, name):
        # label returned by both str() and repr()
        self.name = name

    def __str__(self):
        return self.name

    # repr() deliberately shows the bare name, exactly like str()
    __repr__ = __str__
# special values for the "to" argument
SERVER = Constant("SERVER")
SESSION = Constant("SESSION")
WINDOW = Constant("WINDOW")
USER = Constant("USER")
BROADCAST = Constant("BROADCAST")

# JSON encoder class used for signal payloads, and callable mapping
# (window_info, topic) to a topic string (or None); both given as dotted paths.
_signal_encoder = import_string(settings.WEBSOCKET_SIGNAL_ENCODER)
_topic_serializer = import_string(settings.WEBSOCKET_TOPIC_SERIALIZER)

# Components of the Redis URL: the port (":<port>") and the credentials
# (":<password>@") are left out when they are not configured.
# NOTE(review): the password is inserted verbatim (not URL-quoted), so characters
# such as "@", "/" or "#" in it would corrupt the URL - confirm before using them.
__values = dict(
    host=settings.WEBSOCKET_REDIS_HOST,
    port=(":%s" % settings.WEBSOCKET_REDIS_PORT)
    if settings.WEBSOCKET_REDIS_PORT
    else "",
    db=settings.WEBSOCKET_REDIS_DB,
    password=(":%s@" % settings.WEBSOCKET_REDIS_PASSWORD)
    if settings.WEBSOCKET_REDIS_PASSWORD
    else "",
)
# single pool shared by every connection returned by get_websocket_redis_connection()
redis_connection_pool = ConnectionPool.from_url(
    "redis://%(password)s%(host)s%(port)s/%(db)s" % __values
)
def get_websocket_redis_connection():
    """Return a :class:`redis.StrictRedis` client for the websocket Redis server.

    A new client object is built on each call, but its connections are taken
    from the module-level ``redis_connection_pool``.
    """
    pool = redis_connection_pool
    client = StrictRedis(connection_pool=pool)
    return client
def set_websocket_topics(request, *topics):
    """Use it in a Django view for setting websocket topics.

    Any signal sent to one of these topics will be received by the client.
    Nothing is done when ``settings.USE_CELERY`` is false.

    Besides the given topics (``SERVER`` is ignored, it is not a websocket topic),
    the window and broadcast topics are always subscribed, plus the user topic
    when the user is authenticated. The serialized topics are stored as a Redis
    list under ``settings.WEBSOCKET_REDIS_PREFIX`` + the window key, expiring
    after ``settings.WEBSOCKET_REDIS_EXPIRE``.

    :param request: :class:`django.http.request.HttpRequest`
    :param topics: list of topics that will be subscribed by the websocket (can be any Python object).
    :raises NoWindowKeyException: if ``request`` has no ``window_key`` attribute
      (i.e. the DjangoFloor middleware is not installed)
    """
    if not settings.USE_CELERY:
        return
    # noinspection PyTypeChecker
    if not hasattr(request, "window_key"):
        raise NoWindowKeyException(
            "You should use the DjangoFloorMiddleware middleware"
        )
    token = request.window_key
    request.has_websocket_topics = True
    prefix = settings.WEBSOCKET_REDIS_PREFIX
    window_info = WindowInfo.from_request(request)
    topic_strings = {
        _topic_serializer(window_info, x) for x in topics if x is not SERVER
    }
    # noinspection PyUnresolvedReferences,PyTypeChecker
    if getattr(window_info, "user", None) and window_info.user.is_authenticated:
        topic_strings.add(_topic_serializer(window_info, USER))
    topic_strings.add(_topic_serializer(window_info, WINDOW))
    topic_strings.add(_topic_serializer(window_info, BROADCAST))

    redis_key = "%s%s" % (prefix, token)
    connection = get_websocket_redis_connection()
    # Replace the list in a single MULTI/EXEC transaction: with separate commands,
    # a concurrent reader could observe the key deleted but not yet refilled.
    pipe = connection.pipeline()
    pipe.delete(redis_key)
    for topic in topic_strings:
        # the serializer may return None for topics that cannot be expressed
        if topic is not None:
            pipe.rpush(redis_key, prefix + topic)
    pipe.expire(redis_key, settings.WEBSOCKET_REDIS_EXPIRE)
    pipe.execute()
def scall(window_info, signal_name, to=None, **kwargs):
    """Shortcut for :meth:`djangofloor.tasks.call` taking the signal arguments as keyword arguments.

    Since this function consumes them, ``window_info``, ``signal_name`` and ``to``
    cannot be used as argument names by your signal.
    The two following calls are strictly equivalent:

    .. code-block:: python

      from djangofloor.tasks import call, scall, WINDOW, SERVER

      def my_python_view(request):
          scall(request, 'my.signal.name', to=[WINDOW, SERVER], arg1=12, arg2='Hello')
          call(request, 'my.signal.name', to=[WINDOW, SERVER], kwargs={'arg1': 12, 'arg2': 'Hello'})

    """
    # every extra keyword argument is forwarded as an argument of the signal
    signal_kwargs = kwargs
    return _call_signal(window_info, signal_name, to, signal_kwargs)
# noinspection PyIncorrectDocstring
def call(
    window_info,
    signal_name,
    to=None,
    kwargs=None,
    countdown=None,
    expires=None,
    eta=None,
):
    """Call a DjangoFloor signal.

    :param window_info: either a :class:`django.http.request.HttpRequest` or
      a :class:`djangofloor.wsgi.window_info.WindowInfo`
    :param signal_name: name of the called signal (:class:`str`)
    :param to: :class:`list` of the topics that should receive the signal
      (a single special constant such as ``SERVER`` is also accepted; defaults to ``USER``)
    :param kwargs: dict with all arguments of your signal. Will be encoded to JSON with
       `settings.WEBSOCKET_SIGNAL_ENCODER` and decoded with `settings.WEBSOCKET_SIGNAL_DECODER`.
    :param countdown: check the Celery doc (in a nutshell: number of seconds before executing the signal)
    :param expires: check the Celery doc (in a nutshell: if this signal is not executed before this number of seconds,
       it is cancelled)
    :param eta: check the Celery doc (in a nutshell: datetime of running this signal)
    """
    # Celery scheduling options, forwarded untouched
    scheduling = {"countdown": countdown, "expires": expires, "eta": eta}
    return _call_signal(
        window_info, signal_name, to=to, kwargs=kwargs, from_client=False, **scheduling
    )
def _call_signal(
    window_info,
    signal_name,
    to=None,
    kwargs=None,
    countdown=None,
    expires=None,
    eta=None,
    from_client=False,
):
    """Actually call a DF signal, dispatching it to its destinations.

    * when any Celery scheduling option (``countdown``, ``expires``, ``eta``) is set,
      everything goes through Celery tasks: one per queue of the signal's server
      handlers (only when ``SERVER`` is a destination), plus one on
      ``settings.CELERY_DEFAULT_QUEUE`` that publishes the websocket messages;
    * otherwise, server handlers are run through Celery tasks (one per queue) and
      websocket messages are published immediately.

    :param window_info: HTTP request or :class:`WindowInfo` (converted to a WindowInfo)
    :param signal_name: name of the signal
    :param to: list of topics, or a single special constant; defaults to ``[USER]``
    :param kwargs: arguments of the signal (``{}`` when None)
    :param from_client: whether the call originates from a client (checked against
      each handler's ``is_allowed_to`` by the Celery task)
    """
    import_signals_and_functions()
    # ensure that we always have a true WindowInfo object
    window_info = WindowInfo.from_request(window_info)
    if kwargs is None:
        kwargs = {}
    # a single special constant is accepted in place of a list of topics
    for k in (SERVER, WINDOW, USER, BROADCAST):
        if to is k:
            to = [k]
    if to is None:
        to = [USER]
    serialized_client_topics = []
    to_server = False
    logger.debug('received signal "%s" to %r', signal_name, to)
    for topic in to:
        if topic is SERVER:
            if signal_name not in REGISTERED_SIGNALS:
                logger.debug('Signal "%s" is unknown by the server.', signal_name)
            to_server = True
        else:
            serialized_topic = _topic_serializer(window_info, topic)
            if serialized_topic is not None:
                serialized_client_topics.append(serialized_topic)
    celery_kwargs = {}
    if expires:
        celery_kwargs["expires"] = expires
    if eta:
        celery_kwargs["eta"] = eta
    if countdown:
        celery_kwargs["countdown"] = countdown
    # one Celery task per queue used by the server-side handlers of this signal
    queues = {
        x.get_queue(window_info, kwargs)
        for x in REGISTERED_SIGNALS.get(signal_name, [])
    }
    window_info_as_dict = None
    if window_info:
        window_info_as_dict = window_info.to_dict()
    if celery_kwargs:
        # Handler queues are only targeted when the server is a destination:
        # otherwise those tasks would be no-ops that still fill the queues.
        target_queues = set(queues) if to_server else set()
        if serialized_client_topics:
            # the default queue also publishes the (delayed) websocket messages
            target_queues.add(settings.CELERY_DEFAULT_QUEUE)
        for queue in target_queues:
            topics = (
                serialized_client_topics
                if queue == settings.CELERY_DEFAULT_QUEUE
                else []
            )
            _server_signal_call.apply_async(
                [
                    signal_name,
                    window_info_as_dict,
                    kwargs,
                    from_client,
                    topics,
                    to_server,
                    queue,
                ],
                queue=queue,
                **celery_kwargs
            )
    else:
        if to_server:
            for queue in queues:
                _server_signal_call.apply_async(
                    [
                        signal_name,
                        window_info_as_dict,
                        kwargs,
                        from_client,
                        [],
                        to_server,
                        queue,
                    ],
                    queue=queue,
                )
        if serialized_client_topics:
            # the same id is used for every topic of this dispatch
            # (presumably lets clients spot duplicates - verify on the JS side)
            signal_id = str(uuid.uuid4())
            for topic in serialized_client_topics:
                _call_ws_signal(signal_name, signal_id, topic, kwargs)


def _call_ws_signal(signal_name, signal_id, serialized_topic, kwargs):
    """Publish one signal message on the Redis channel of a serialized websocket topic.

    The message is JSON (encoded with ``_signal_encoder``) holding the signal name,
    its arguments (``opts``) and ``signal_id``; the channel is
    ``settings.WEBSOCKET_REDIS_PREFIX`` followed by ``serialized_topic``.
    """
    connection = get_websocket_redis_connection()
    serialized_message = json.dumps(
        {"signal": signal_name, "opts": kwargs, "signal_id": signal_id},
        cls=_signal_encoder,
    )
    topic = settings.WEBSOCKET_REDIS_PREFIX + serialized_topic
    logger.debug("send message to topic %r", topic)
    connection.publish(topic, serialized_message.encode("utf-8"))


def _return_ws_function_result(window_info, result_id, result, exception=None):
    """Publish the result of a client-called function on the window's topic.

    ``exception`` is sent as its string representation (or None). Nothing is
    published when the window topic serializes to a falsy value.
    """
    json_msg = {
        "result_id": result_id,
        "result": result,
        "exception": str(exception) if exception else None,
    }
    serialized_message = json.dumps(json_msg, cls=_signal_encoder)
    serialized_topic = _topic_serializer(window_info, WINDOW)
    if serialized_topic:
        connection = get_websocket_redis_connection()
        topic = settings.WEBSOCKET_REDIS_PREFIX + serialized_topic
        logger.debug("send function result to topic %r", topic)
        connection.publish(topic, serialized_message.encode("utf-8"))
@lru_cache()
def import_signals_and_functions():
    """Import all `signals.py`, `forms.py` and `functions.py` modules of the installed apps.

    Importing these modules registers the signals and websocket functions they
    declare (in ``REGISTERED_SIGNALS`` / ``REGISTERED_FUNCTIONS``). Thanks to
    ``lru_cache``, the imports are only attempted on the first call.

    A submodule that does not exist is silently skipped. An :class:`ImportError`
    raised while the submodule does exist (as ``<name>.py`` or as a ``<name>/``
    package) is logged, as is any other exception, so one broken app does not
    prevent the other apps from being scanned.
    """
    for app in settings.INSTALLED_APPS:
        package_dir = None
        try:
            mod = import_module(app)
        except ImportError:
            mod = None
        except Exception as e:
            mod = None
            logger.exception(e)
        # namespace packages have no usable __file__: no directory to inspect
        module_file = getattr(mod, "__file__", None)
        if module_file:
            package_dir = os.path.dirname(module_file)
        for module_name in ("signals", "forms", "functions"):
            try:
                import_module("%s.%s" % (app, module_name))
            except ImportError as e:
                # only report the error when the submodule actually exists: a
                # missing submodule is the normal case for most apps
                if package_dir:
                    module_path = os.path.join(package_dir, module_name)
                    if os.path.isfile(module_path + ".py") or os.path.isfile(
                        os.path.join(module_path, "__init__.py")
                    ):
                        logger.exception(e)
            except Exception as e:
                logger.exception(e)
    logger.debug(
        "Found signals: %s",
        ", ".join("%s (%d)" % (k, len(v)) for (k, v) in REGISTERED_SIGNALS.items()),
    )
    logger.debug(
        "Found functions: %s", ", ".join(str(k) for k in REGISTERED_FUNCTIONS)
    )
@shared_task(serializer="json")
def _server_signal_call(
    signal_name,
    window_info_dict,
    kwargs=None,
    from_client=False,
    serialized_client_topics=None,
    to_server=False,
    queue=None,
):
    """Celery task: publish a signal to websocket topics and/or run its server handlers.

    :param signal_name: name of the signal
    :param window_info_dict: serialized window info (see ``WindowInfo.to_dict``)
    :param kwargs: arguments of the signal (``{}`` when None)
    :param from_client: when true, handlers whose ``is_allowed_to`` refuses the call are skipped
    :param serialized_client_topics: already-serialized websocket topics to publish to
    :param to_server: whether the server-side handlers must be run
    :param queue: only handlers whose queue is this one are run (one task is sent per
      queue); ``None`` disables this filter, so that callers that do not dispatch
      per queue (the deprecated tasks) still run the handlers
    """
    logger.info(
        'Signal "%s" called on queue "%s" to topics %s (from client?: %s, to server?: %s)',
        signal_name,
        queue,
        serialized_client_topics,
        from_client,
        to_server,
    )
    try:
        if kwargs is None:
            kwargs = {}
        if serialized_client_topics:
            signal_id = str(uuid.uuid4())
            for topic in serialized_client_topics:
                _call_ws_signal(signal_name, signal_id, topic, kwargs)
        window_info = WindowInfo.from_dict(window_info_dict)
        import_signals_and_functions()
        if not to_server or signal_name not in REGISTERED_SIGNALS:
            return
        for connection in REGISTERED_SIGNALS[signal_name]:
            assert isinstance(connection, SignalConnection)
            if queue is not None and connection.get_queue(window_info, kwargs) != queue:
                # handled by the task sent to its own queue
                continue
            if from_client and not connection.is_allowed_to(
                connection, window_info, kwargs
            ):
                continue
            new_kwargs = connection.check(kwargs)
            if new_kwargs is None:
                continue
            result = connection(window_info, **new_kwargs)
            # TODO remove the following part
            if isinstance(result, list):
                warnings.warn(
                    "signals should not return list anymore.",
                    RemovedInDjangoFloor200Warning,
                )
                for data in result:
                    call(
                        window_info,
                        data["signal"],
                        to=[WINDOW, SERVER],
                        kwargs=data["options"],
                    )
    except Exception as e:
        logger.exception(e)


@shared_task(serializer="json")
def _server_function_call(function_name, window_info_dict, result_id, kwargs=None):
    """Celery task: run a websocket function and send its result back to the window.

    Any exception (unknown function, unauthorized call, error in the function) is
    logged and its text is sent back instead of a result. Nothing is sent back when
    the window info could not be rebuilt.
    """
    logger.info("Function %s called from client.", function_name)
    error, result, window_info = None, None, None
    try:
        if kwargs is None:
            kwargs = {}
        window_info = WindowInfo.from_dict(window_info_dict)
        import_signals_and_functions()
        connection = REGISTERED_FUNCTIONS[function_name]
        assert isinstance(connection, FunctionConnection)
        if not connection.is_allowed_to(connection, window_info, kwargs):
            raise ValueError("Unauthorized function call %s" % connection.path)
        kwargs = connection.check(kwargs)
        if kwargs is not None:
            result = connection(window_info, **kwargs)
    # noinspection PyBroadException
    except Exception as e:
        logger.exception(e)
        # Python unbinds ``e`` at the end of this clause: keep our own reference,
        # otherwise the call below would raise UnboundLocalError.
        error = e
        result = None
    if window_info:
        _return_ws_function_result(window_info, result_id, result, exception=error)


# TODO remove the following functions
def import_signals():
    """.. deprecated:: 1.0
       do not use it; call :func:`import_signals_and_functions` instead."""
    message = (
        "djangofloor.tasks.import_signals() has been replaced by "
        "djangofloor.tasks.import_signals_and_functions()"
    )
    warnings.warn(message, RemovedInDjangoFloor200Warning)
    return import_signals_and_functions()
@shared_task(serializer="json")
def signal_task(signal_name, request_dict, from_client, kwargs):
    """.. deprecated:: 1.0 do not use it

    Runs the server-side handlers of ``signal_name`` synchronously.
    """
    warnings.warn(
        "djangofloor.tasks.signal_task is deprecated.", RemovedInDjangoFloor200Warning
    )
    return _server_signal_call(
        signal_name,
        request_dict,
        kwargs=kwargs,
        from_client=from_client,
        to_server=True,
    )


@shared_task(serializer="json")
def delayed_task(signal_name, request_dict, sharing, from_client, kwargs):
    """.. deprecated:: 1.0 do not use it

    Publishes ``signal_name`` to the websocket topics derived from ``sharing`` and
    runs its server-side handlers.
    """
    warnings.warn(
        "djangofloor.tasks.delayed_task is deprecated.", RemovedInDjangoFloor200Warning
    )
    # call the new name directly: import_signals() would emit a second warning
    import_signals_and_functions()
    window_info = WindowInfo.from_dict(request_dict)
    # noinspection PyProtectedMember
    from djangofloor.df_ws4redis import _sharing_to_topics

    # _server_signal_call expects already-serialized topic strings. _sharing_to_topics
    # presumably returns raw topics (df_call feeds its result to call()'s ``to``) -
    # verify. SERVER is not a websocket topic (to_server=True covers it) and used to
    # break the "prefix + topic" concatenation.
    serialized_topics = []
    for topic in _sharing_to_topics(window_info, sharing):
        if topic is SERVER:
            continue
        serialized_topic = _topic_serializer(window_info, topic)
        if serialized_topic is not None:
            serialized_topics.append(serialized_topic)
    return _server_signal_call(
        signal_name,
        request_dict,
        kwargs=kwargs,
        from_client=from_client,
        serialized_client_topics=serialized_topics,
        to_server=True,
    )
def df_call(
    signal_name,
    request,
    sharing=None,
    from_client=False,
    kwargs=None,
    countdown=None,
    expires=None,
    eta=None,
):
    """.. deprecated:: 1.0, do not use it

    Legacy entry point forwarding to :func:`call`: the signal is sent to the
    topics derived from ``sharing`` and to the server. ``from_client`` is ignored.
    """
    # noinspection PyUnusedLocal
    from_client = from_client
    warnings.warn(
        "djangofloor.tasks.df_call is deprecated.", RemovedInDjangoFloor200Warning
    )
    # noinspection PyProtectedMember
    from djangofloor.df_ws4redis import _sharing_to_topics

    to = _sharing_to_topics(request, sharing) + [SERVER]
    # call() takes the request / window info FIRST, then the signal name
    # (the previous code passed them in the wrong order)
    call(
        request,
        signal_name,
        to=to,
        kwargs=kwargs,
        countdown=countdown,
        expires=expires,
        eta=eta,
    )
def get_expected_queues():
    """Return the set of Celery queue names used by registered signals and functions.

    Empty when ``settings.USE_CELERY`` is false. A :class:`DynamicQueueName`
    contributes all of its available queues; other callable queues are skipped;
    any other queue value is added as-is.
    """
    expected_queues = set()
    if not settings.USE_CELERY:
        return expected_queues
    import_signals_and_functions()
    # functions map to a single connection, signals to a list of connections
    connections = list(REGISTERED_FUNCTIONS.values())
    for signal_connections in REGISTERED_SIGNALS.values():
        connections.extend(signal_connections)
    for connection in connections:
        queue = connection.queue
        if isinstance(queue, DynamicQueueName):
            expected_queues.update(queue.get_available_queues())
        elif not callable(queue):
            expected_queues.add(queue)
    return expected_queues