From a57c91067ebb85e77cda510b78c2bbca9dc7926c Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Sat, 26 Aug 2017 18:02:01 +0300 Subject: [PATCH] Optimize state filter. --- aiogram/dispatcher/__init__.py | 47 +++++++++++++++++++++++++++++++--- aiogram/dispatcher/filters.py | 13 +++++++--- aiogram/dispatcher/webhook.py | 15 +++++++++-- 3 files changed, 66 insertions(+), 9 deletions(-) diff --git a/aiogram/dispatcher/__init__.py b/aiogram/dispatcher/__init__.py index 0230484d..3933a04a 100644 --- a/aiogram/dispatcher/__init__.py +++ b/aiogram/dispatcher/__init__.py @@ -3,16 +3,20 @@ import functools import logging import typing -from .filters import CommandsFilter, RegexpFilter, ContentTypeFilter, generate_default_filters +from .filters import CommandsFilter, ContentTypeFilter, RegexpFilter, USER_STATE, generate_default_filters from .handler import Handler -from .storage import DisabledStorage, BaseStorage, FSMContext +from .storage import BaseStorage, DisabledStorage, FSMContext from .webhook import BaseResponse from ..bot import Bot from ..types.message import ContentType -from ..utils.exceptions import TelegramAPIError, NetworkError +from ..utils import context +from ..utils.exceptions import NetworkError, TelegramAPIError log = logging.getLogger(__name__) +MODE = 'MODE' +LONG_POOLING = 'long-pooling' + class Dispatcher: """ @@ -79,7 +83,7 @@ class Dispatcher: """ tasks = [] for update in updates: - tasks.append(self.updates_handler.notify(update)) + tasks.append(self.process_update(update)) return await asyncio.gather(*tasks) async def process_update(self, update): @@ -90,23 +94,56 @@ class Dispatcher: :return: """ self.last_update_id = update.update_id + has_context = context.check_configured() if update.message: + if has_context: + state = self.storage.get_state(chat=update.message.chat.id, + user=update.message.from_user.id) + context.set_value(USER_STATE, await state) return await self.message_handlers.notify(update.message) if update.edited_message: + if has_context: + state = self.storage.get_state(chat=update.edited_message.chat.id, + user=update.edited_message.from_user.id) + context.set_value(USER_STATE, await state) return await self.edited_message_handlers.notify(update.edited_message) if update.channel_post: + if has_context: + state = self.storage.get_state(chat=update.message.chat.id, + user=update.message.from_user.id) + context.set_value(USER_STATE, await state) return await self.channel_post_handlers.notify(update.channel_post) if update.edited_channel_post: + if has_context: + state = self.storage.get_state(chat=update.edited_channel_post.chat.id, + user=update.edited_channel_post.from_user.id) + context.set_value(USER_STATE, await state) return await self.edited_channel_post_handlers.notify(update.edited_channel_post) if update.inline_query: + if has_context: + state = self.storage.get_state(user=update.inline_query.from_user.id) + context.set_value(USER_STATE, await state) return await self.inline_query_handlers.notify(update.inline_query) if update.chosen_inline_result: + if has_context: + state = self.storage.get_state(user=update.chosen_inline_result.from_user.id) + context.set_value(USER_STATE, await state) return await self.chosen_inline_result_handlers.notify(update.chosen_inline_result) if update.callback_query: + if has_context: + state = self.storage.get_state(chat=update.callback_query.message.chat.id, + user=update.callback_query.from_user.id) + context.set_value(USER_STATE, await state) return await self.callback_query_handlers.notify(update.callback_query) if update.shipping_query: + if has_context: + state = self.storage.get_state(user=update.shipping_query.from_user.id) + context.set_value(USER_STATE, await state) return await self.shipping_query_handlers.notify(update.shipping_query) if update.pre_checkout_query: + if has_context: + state = self.storage.get_state(user=update.pre_checkout_query.from_user.id) + context.set_value(USER_STATE, await state) return await self.pre_checkout_query_handlers.notify(update.pre_checkout_query) async def start_pooling(self, timeout=20, relax=0.1, limit=None): @@ -121,6 +158,7 @@ class Dispatcher: if self._pooling: raise RuntimeError('Pooling already started') log.info('Start pooling.') + context.set_value(MODE, LONG_POOLING) self._pooling = True offset = None @@ -730,6 +768,7 @@ class Dispatcher: :param func: :return: """ + def process_response(task): response = task.result() self.loop.create_task(response.execute_response(self.bot)) diff --git a/aiogram/dispatcher/filters.py b/aiogram/dispatcher/filters.py index d62f5310..d4e114a4 100644 --- a/aiogram/dispatcher/filters.py +++ b/aiogram/dispatcher/filters.py @@ -1,8 +1,11 @@ import inspect import re +from aiogram.utils import context from ..utils.helper import Helper, HelperMode, Item +USER_STATE = 'USER_STATE' + async def check_filter(filter_, args, kwargs): if not callable(filter_): @@ -102,10 +105,14 @@ class StateFilter(AsyncFilter): if self.state == '*': return True - chat, user = self.get_target(obj) + if context.check_value(USER_STATE): + context_state = context.get_value(USER_STATE) + return self.state == context_state + else: + chat, user = self.get_target(obj) - if chat or user: - return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state + if chat or user: + return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state return False diff --git a/aiogram/dispatcher/webhook.py b/aiogram/dispatcher/webhook.py index 64f4b566..a52ae618 100644 --- a/aiogram/dispatcher/webhook.py +++ b/aiogram/dispatcher/webhook.py @@ -3,13 +3,14 @@ import asyncio.tasks import datetime import functools import typing -from typing import Union, Dict, Optional +from typing import Dict, Optional, Union from aiohttp import web from .. import types from ..bot import api -from ..bot.base import Integer, String, Boolean, Float +from ..bot.base import Boolean, Float, Integer, String +from ..utils import context from ..utils import json from ..utils.deprecated import warn_deprecated as warn from ..utils.exceptions import TimeoutWarning @@ -20,6 +21,10 @@ BOT_DISPATCHER_KEY = 'BOT_DISPATCHER' RESPONSE_TIMEOUT = 55 +WEBHOOK = 'webhook' +WEBHOOK_CONNECTION = 'WEBHOOK_CONNECTION' +WEBHOOK_REQUEST = 'WEBHOOK_REQUEST' + class WebhookRequestHandler(web.View): """ @@ -71,6 +76,11 @@ class WebhookRequestHandler(web.View): :return: :class:`aiohttp.web.Response` """ + + context.update_state({'CALLER': WEBHOOK, + WEBHOOK_CONNECTION: True, + WEBHOOK_REQUEST: self.request}) + dispatcher = self.get_dispatcher() update = await self.parse_update(dispatcher.bot) @@ -113,6 +123,7 @@ class WebhookRequestHandler(web.View): if fut.done(): return fut.result() else: + context.set_value(WEBHOOK_CONNECTION, False) fut.remove_done_callback(cb) fut.add_done_callback(self.respond_via_request) finally: