Fix #665, add separate parametrs for saving to directory and file (#677)

* close #665

* add backward compatibility

* improve doc, codestyle

* warning text update

* use tmpdir fixture in tests
This commit is contained in:
darksidecat 2021-09-06 00:05:52 +03:00 committed by GitHub
parent 82b1b1ab03
commit 358ecc7821
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 171 additions and 11 deletions

View file

@ -1,5 +1,9 @@
import os
import pathlib
from io import IOBase
from typing import Union, Optional
from aiogram.utils.deprecated import warn_deprecated
class Downloadable:
@ -7,32 +11,86 @@ class Downloadable:
Mixin for files
"""
async def download(self, destination=None, timeout=30, chunk_size=65536, seek=True, make_dirs=True):
async def download(
self,
destination=None,
timeout=30,
chunk_size=65536,
seek=True,
make_dirs=True,
*,
destination_dir: Optional[Union[str, pathlib.Path]] = None,
destination_file: Optional[Union[str, pathlib.Path, IOBase]] = None
):
"""
Download file
:param destination: filename or instance of :class:`io.IOBase`. For e. g. :class:`io.BytesIO`
At most one of these parameters can be used: :param destination_dir:, :param destination_file:
:param destination: deprecated, use :param destination_dir: or :param destination_file: instead
:param timeout: Integer
:param chunk_size: Integer
:param seek: Boolean - go to start of file when downloading is finished.
:param make_dirs: Make dirs if not exist
:param destination_dir: directory for saving files
:param destination_file: path to the file or instance of :class:`io.IOBase`. For e. g. :class:`io.BytesIO`
:return: destination
"""
if destination:
warn_deprecated(
"destination parameter is deprecated, please use destination_dir or destination_file."
)
if destination_dir and destination_file:
raise ValueError(
"Use only one of the parameters: destination_dir or destination_file."
)
file, destination = await self._prepare_destination(
destination,
destination_dir,
destination_file,
make_dirs
)
return await self.bot.download_file(
file_path=file.file_path,
destination=destination,
timeout=timeout,
chunk_size=chunk_size,
seek=seek,
)
async def _prepare_destination(self, dest, destination_dir, destination_file, make_dirs):
file = await self.get_file()
is_path = True
if destination is None:
if not(any((dest, destination_dir, destination_file))):
destination = file.file_path
elif isinstance(destination, (str, pathlib.Path)) and os.path.isdir(destination):
destination = os.path.join(destination, file.file_path)
else:
is_path = False
if is_path and make_dirs:
elif dest: # backward compatibility
if isinstance(dest, IOBase):
return file, dest
if isinstance(dest, (str, pathlib.Path)) and os.path.isdir(dest):
destination = os.path.join(dest, file.file_path)
else:
destination = dest
elif destination_dir:
if isinstance(destination_dir, (str, pathlib.Path)):
destination = os.path.join(destination_dir, file.file_path)
else:
raise TypeError("destination_dir must be str or pathlib.Path")
else:
if isinstance(destination_file, IOBase):
return file, destination_file
elif isinstance(destination_file, (str, pathlib.Path)):
destination = destination_file
else:
raise TypeError("destination_file must be str, pathlib.Path or io.IOBase type")
if make_dirs and os.path.dirname(destination):
os.makedirs(os.path.dirname(destination), exist_ok=True)
return await self.bot.download_file(file_path=file.file_path, destination=destination, timeout=timeout,
chunk_size=chunk_size, seek=seek)
return file, destination
async def get_file(self):
"""

102
tests/types/test_mixins.py Normal file
View file

@ -0,0 +1,102 @@
import os
from io import BytesIO
from pathlib import Path
import pytest
from aiogram import Bot
from aiogram.types import File
from aiogram.types.mixins import Downloadable
from tests import TOKEN
from tests.types.dataset import FILE
pytestmark = pytest.mark.asyncio
@pytest.fixture(name='bot')
async def bot_fixture():
""" Bot fixture """
_bot = Bot(TOKEN)
yield _bot
await _bot.session.close()
@pytest.fixture
def tmppath(tmpdir, request):
os.chdir(tmpdir)
yield Path(tmpdir)
os.chdir(request.config.invocation_dir)
@pytest.fixture
def downloadable(bot):
async def get_file():
return File(**FILE)
downloadable = Downloadable()
downloadable.get_file = get_file
downloadable.bot = bot
return downloadable
class TestDownloadable:
async def test_download_make_dirs_false_nodir(self, tmppath, downloadable):
with pytest.raises(FileNotFoundError):
await downloadable.download(make_dirs=False)
async def test_download_make_dirs_false_mkdir(self, tmppath, downloadable):
os.mkdir('voice')
await downloadable.download(make_dirs=False)
assert os.path.isfile(tmppath.joinpath(FILE["file_path"]))
async def test_download_make_dirs_true(self, tmppath, downloadable):
await downloadable.download(make_dirs=True)
assert os.path.isfile(tmppath.joinpath(FILE["file_path"]))
async def test_download_deprecation_warning(self, tmppath, downloadable):
with pytest.deprecated_call():
await downloadable.download("test.file")
async def test_download_destination(self, tmppath, downloadable):
with pytest.deprecated_call():
await downloadable.download("test.file")
assert os.path.isfile(tmppath.joinpath('test.file'))
async def test_download_destination_dir_exist(self, tmppath, downloadable):
os.mkdir("test_folder")
with pytest.deprecated_call():
await downloadable.download("test_folder")
assert os.path.isfile(tmppath.joinpath('test_folder', FILE["file_path"]))
async def test_download_destination_with_dir(self, tmppath, downloadable):
with pytest.deprecated_call():
await downloadable.download(os.path.join('dir_name', 'file_name'))
assert os.path.isfile(tmppath.joinpath('dir_name', 'file_name'))
async def test_download_destination_io_bytes(self, tmppath, downloadable):
file = BytesIO()
with pytest.deprecated_call():
await downloadable.download(file)
assert len(file.read()) != 0
async def test_download_raise_value_error(self, tmppath, downloadable):
with pytest.raises(ValueError):
await downloadable.download(destination_dir="a", destination_file="b")
async def test_download_destination_dir(self, tmppath, downloadable):
await downloadable.download(destination_dir='test_dir')
assert os.path.isfile(tmppath.joinpath('test_dir', FILE["file_path"]))
async def test_download_destination_file(self, tmppath, downloadable):
await downloadable.download(destination_file='file_name')
assert os.path.isfile(tmppath.joinpath('file_name'))
async def test_download_destination_file_with_dir(self, tmppath, downloadable):
await downloadable.download(destination_file=os.path.join('dir_name', 'file_name'))
assert os.path.isfile(tmppath.joinpath('dir_name', 'file_name'))
async def test_download_io_bytes(self, tmppath, downloadable):
file = BytesIO()
await downloadable.download(destination_file=file)
assert len(file.read()) != 0