diff --git a/aiogram/contrib/middlewares/fsm.py b/aiogram/contrib/middlewares/fsm.py new file mode 100644 index 00000000..e3550a34 --- /dev/null +++ b/aiogram/contrib/middlewares/fsm.py @@ -0,0 +1,80 @@ +import copy +import weakref + +from aiogram.dispatcher.middlewares import LifetimeControllerMiddleware +from aiogram.dispatcher.storage import FSMContext + + +class FSMMiddleware(LifetimeControllerMiddleware): + skip_patterns = ['error', 'update'] + + def __init__(self): + super(FSMMiddleware, self).__init__() + self._proxies = weakref.WeakKeyDictionary() + + async def pre_process(self, obj, data, *args): + proxy = await FSMSStorageProxy.create(self.manager.dispatcher.current_state()) + data['state_data'] = proxy + + async def post_process(self, obj, data, *args): + proxy = data.get('state_data', None) + if isinstance(proxy, FSMSStorageProxy): + await proxy.save() + + +class FSMSStorageProxy(dict): + def __init__(self, fsm_context: FSMContext): + super(FSMSStorageProxy, self).__init__() + self.fsm_context = fsm_context + self._copy = {} + self._data = {} + self._state = None + self._is_dirty = False + + @classmethod + async def create(cls, fsm_context: FSMContext): + """ + :param fsm_context: + :return: + """ + proxy = cls(fsm_context) + await proxy.load() + return proxy + + async def load(self): + self.clear() + self._state = await self.fsm_context.get_state() + self.update(await self.fsm_context.get_data()) + self._copy = copy.deepcopy(self) + self._is_dirty = False + + @property + def state(self): + return self._state + + @state.setter + def state(self, value): + self._state = value + self._is_dirty = True + + @state.deleter + def state(self): + self._state = None + self._is_dirty = True + + async def save(self, force=False): + if self._copy != self or force: + await self.fsm_context.set_data(data=self) + if self._is_dirty or force: + await self.fsm_context.set_state(self.state) + self._is_dirty = False + self._copy = copy.deepcopy(self) + + def __str__(self): + s = super(FSMSStorageProxy, self).__str__() + readable_state = f"'{self.state}'" if self.state else "''" + return f"<{self.__class__.__name__}(state={readable_state}, data={s})>" + + def clear(self): + del self.state + return super(FSMSStorageProxy, self).clear() diff --git a/aiogram/dispatcher/middlewares.py b/aiogram/dispatcher/middlewares.py index 4de9d61f..dba3db4c 100644 --- a/aiogram/dispatcher/middlewares.py +++ b/aiogram/dispatcher/middlewares.py @@ -101,3 +101,28 @@ class BaseMiddleware: if not handler: return None await handler(*args) + + +class LifetimeControllerMiddleware(BaseMiddleware): + # TODO: Rename class + + skip_patterns = None + + async def pre_process(self, obj, data, *args): + pass + + async def post_process(self, obj, data, *args): + pass + + async def trigger(self, action, args): + if self.skip_patterns is not None and any(item in action for item in self.skip_patterns): + return False + + obj, *args, data = args + if action.startswith('pre_process_'): + await self.pre_process(obj, data, *args) + elif action.startswith('post_process_'): + await self.post_process(obj, data, *args) + else: + return False + return True diff --git a/examples/finite_state_machine_example_2.py b/examples/finite_state_machine_example_2.py new file mode 100644 index 00000000..5a2996bd --- /dev/null +++ b/examples/finite_state_machine_example_2.py @@ -0,0 +1,126 @@ +""" +This example is equals with 'finite_state_machine_example.py' but with FSM Middleware + +Note that FSM Middleware implements the more simple methods for working with storage. + +With that middleware all data from storage will be loaded before event will be processed +and data will be stored after processing the event. +""" +import asyncio + +import aiogram.utils.markdown as md +from aiogram import Bot, Dispatcher, types +from aiogram.contrib.fsm_storage.memory import MemoryStorage +from aiogram.contrib.middlewares.fsm import FSMMiddleware, FSMSStorageProxy +from aiogram.dispatcher.filters.state import State, StatesGroup +from aiogram.utils import executor + +API_TOKEN = 'BOT TOKEN HERE' + +loop = asyncio.get_event_loop() + +bot = Bot(token=API_TOKEN, loop=loop) + +# For example use simple MemoryStorage for Dispatcher. +storage = MemoryStorage() +dp = Dispatcher(bot, storage=storage) +dp.middleware.setup(FSMMiddleware()) + + +# States +class Form(StatesGroup): + name = State() # Will be represented in storage as 'Form:name' + age = State() # Will be represented in storage as 'Form:age' + gender = State() # Will be represented in storage as 'Form:gender' + + +@dp.message_handler(commands=['start']) +async def cmd_start(message: types.Message): + """ + Conversation's entry point + """ + # Set state + await Form.first() + + await message.reply("Hi there! What's your name?") + + +# You can use state '*' if you need to handle all states +@dp.message_handler(state='*', commands=['cancel']) +@dp.message_handler(lambda message: message.text.lower() == 'cancel', state='*') +async def cancel_handler(message: types.Message, state_data: FSMSStorageProxy): + """ + Allow user to cancel any action + """ + if state_data.state is None: + return + + # Cancel state and inform user about it + del state_data.state + # And remove keyboard (just in case) + await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove()) + + +@dp.message_handler(state=Form.name) +async def process_name(message: types.Message, state_data: FSMSStorageProxy): + """ + Process user name + """ + state_data.state = Form.age + state_data['name'] = message.text + + await message.reply("How old are you?") + + +# Check age. Age gotta be digit +@dp.message_handler(lambda message: not message.text.isdigit(), state=Form.age) +async def failed_process_age(message: types.Message): + """ + If age is invalid + """ + return await message.reply("Age gotta be a number.\nHow old are you? (digits only)") + + +@dp.message_handler(lambda message: message.text.isdigit(), state=Form.age) +async def process_age(message: types.Message, state_data: FSMSStorageProxy): + # Update state and data + state_data.state = Form.gender + state_data['age'] = int(message.text) + + # Configure ReplyKeyboardMarkup + markup = types.ReplyKeyboardMarkup(resize_keyboard=True, selective=True) + markup.add("Male", "Female") + markup.add("Other") + + await message.reply("What is your gender?", reply_markup=markup) + + +@dp.message_handler(lambda message: message.text not in ["Male", "Female", "Other"], state=Form.gender) +async def failed_process_gender(message: types.Message): + """ + In this example gender has to be one of: Male, Female, Other. + """ + return await message.reply("Bad gender name. Choose you gender from keyboard.") + + +@dp.message_handler(state=Form.gender) +async def process_gender(message: types.Message, state_data: FSMSStorageProxy): + state_data['gender'] = message.text + + # Remove keyboard + markup = types.ReplyKeyboardRemove() + + # And send message + await bot.send_message(message.chat.id, md.text( + md.text('Hi! Nice to meet you,', md.bold(state_data['name'])), + md.text('Age:', state_data['age']), + md.text('Gender:', state_data['gender']), + sep='\n'), reply_markup=markup, parse_mode=types.ParseMode.MARKDOWN) + + # Finish conversation + # WARNING! This method will destroy all data in storage for current user! + state_data.clear() + + +if __name__ == '__main__': + executor.start_polling(dp, loop=loop, skip_updates=True)