Extend File to handle streams (#85)

This commit is contained in:
Slipstream 2025-06-15 15:20:04 -06:00 committed by GitHub
parent 9f2fc0857b
commit c811e2b578
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 82 additions and 7 deletions

View File

@ -3,10 +3,22 @@ Data models for Discord objects.
""" """
import asyncio import asyncio
import io
import json import json
import os
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast from typing import (
Any,
AsyncIterator,
Dict,
List,
Optional,
TYPE_CHECKING,
Union,
cast,
IO,
)
from .cache import ChannelCache, MemberCache from .cache import ChannelCache, MemberCache
from .caching import MemberCacheFlags from .caching import MemberCacheFlags
@ -636,11 +648,45 @@ class Attachment:
class File: class File:
"""Represents a file to be uploaded.""" """Represents a file to be uploaded.
def __init__(self, filename: str, data: bytes): Parameters
self.filename = filename ----------
self.data = data fp:
A file path, file-like object, or bytes-like object containing the
data to upload.
filename:
Optional name of the file. If not provided and ``fp`` is a path or has
a ``name`` attribute, the name will be inferred.
spoiler:
When ``True`` the filename will be prefixed with ``"SPOILER_"``.
"""
def __init__(
self,
fp: Union[str, bytes, os.PathLike[Any], IO[bytes]],
*,
filename: Optional[str] = None,
spoiler: bool = False,
) -> None:
if isinstance(fp, (str, os.PathLike)):
self.data = open(fp, "rb")
inferred = os.path.basename(fp)
elif isinstance(fp, bytes):
self.data = io.BytesIO(fp)
inferred = None
else:
self.data = fp
inferred = getattr(fp, "name", None)
name = filename or inferred
if name is None:
raise ValueError("filename could not be inferred")
if spoiler and not name.startswith("SPOILER_"):
name = f"SPOILER_{name}"
self.filename = name
self.spoiler = spoiler
class AllowedMentions: class AllowedMentions:

View File

@ -1,3 +1,4 @@
import io
import pytest import pytest
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
@ -38,7 +39,9 @@ async def test_http_send_message_with_files_uses_formdata():
"timestamp": "t", "timestamp": "t",
} }
) )
await http.send_message("c", "hi", files=[File("f.txt", b"data")]) await http.send_message(
"c", "hi", files=[File(io.BytesIO(b"data"), filename="f.txt")]
)
args, kwargs = http.request.call_args args, kwargs = http.request.call_args
assert kwargs["is_json"] is False assert kwargs["is_json"] is False
@ -75,7 +78,33 @@ async def test_client_send_message_passes_files():
"timestamp": "t", "timestamp": "t",
} }
) )
await client.send_message("c", "hi", files=[File("f.txt", b"data")]) await client.send_message(
"c", "hi", files=[File(io.BytesIO(b"data"), filename="f.txt")]
)
client._http.send_message.assert_awaited_once() client._http.send_message.assert_awaited_once()
kwargs = client._http.send_message.call_args.kwargs kwargs = client._http.send_message.call_args.kwargs
assert kwargs["files"][0].filename == "f.txt" assert kwargs["files"][0].filename == "f.txt"
@pytest.mark.asyncio
async def test_file_from_path(tmp_path):
file_path = tmp_path / "path.txt"
file_path.write_bytes(b"ok")
http = HTTPClient(token="t")
http.request = AsyncMock(
return_value={
"id": "1",
"channel_id": "c",
"author": {"id": "2", "username": "u", "discriminator": "0001"},
"content": "hi",
"timestamp": "t",
}
)
await http.send_message("c", "hi", files=[File(file_path)])
_, kwargs = http.request.call_args
assert kwargs["is_json"] is False
def test_file_spoiler():
f = File(io.BytesIO(b"d"), filename="a.txt", spoiler=True)
assert f.filename == "SPOILER_a.txt"