Add file upload support (#13)

This commit is contained in:
Slipstream 2025-06-10 15:45:25 -06:00 committed by GitHub
parent 06f972851b
commit 7f0b03fa69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 97 additions and 9 deletions

View File

@ -860,6 +860,7 @@ class Client:
allowed_mentions: Optional[Dict[str, Any]] = None,
message_reference: Optional[Dict[str, Any]] = None,
attachments: Optional[List[Any]] = None,
files: Optional[List[Any]] = None,
flags: Optional[int] = None,
view: Optional["View"] = None,
) -> "Message":
@ -877,6 +878,7 @@ class Client:
allowed_mentions (Optional[Dict[str, Any]]): Allowed mentions for the message.
message_reference (Optional[Dict[str, Any]]): Message reference for replying.
attachments (Optional[List[Any]]): Attachments to include with the message.
files (Optional[List[Any]]): Files to upload with the message.
flags (Optional[int]): Message flags.
view (Optional[View]): A view to send with the message.
@ -929,6 +931,7 @@ class Client:
allowed_mentions=allowed_mentions,
message_reference=message_reference,
attachments=attachments,
files=files,
flags=flags,
)

View File

@ -18,7 +18,7 @@ from .models import (
Thumbnail,
MediaGallery,
MediaGalleryItem,
File,
FileComponent,
Separator,
Container,
UnfurledMediaItem,
@ -140,7 +140,7 @@ def component_factory(
)
if ctype == ComponentType.FILE:
return File(
return FileComponent(
file=UnfurledMediaItem(**data["file"]),
spoiler=data.get("spoiler", False),
id=data.get("id"),

View File

@ -20,7 +20,7 @@ from . import __version__ # For User-Agent
if TYPE_CHECKING:
from .client import Client
from .models import Message, Webhook
from .models import Message, Webhook, File
from .interactions import ApplicationCommand, InteractionResponsePayload, Snowflake
# Discord API constants
@ -60,7 +60,9 @@ class HTTPClient:
self,
method: str,
endpoint: str,
payload: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
payload: Optional[
Union[Dict[str, Any], List[Dict[str, Any]], aiohttp.FormData]
] = None,
params: Optional[Dict[str, Any]] = None,
is_json: bool = True,
use_auth_header: bool = True,
@ -205,6 +207,7 @@ class HTTPClient:
allowed_mentions: Optional[dict] = None,
message_reference: Optional[Dict[str, Any]] = None,
attachments: Optional[List[Any]] = None,
files: Optional[List[Any]] = None,
flags: Optional[int] = None,
) -> Dict[str, Any]:
"""Sends a message to a channel.
@ -213,6 +216,8 @@ class HTTPClient:
----------
attachments:
A list of attachment payloads to include with the message.
files:
A list of :class:`File` objects containing binary data to upload.
Returns
-------
@ -230,10 +235,28 @@ class HTTPClient:
payload["components"] = components
if allowed_mentions:
payload["allowed_mentions"] = allowed_mentions
all_files: List["File"] = []
if attachments is not None:
payload["attachments"] = [
a.to_dict() if hasattr(a, "to_dict") else a for a in attachments
]
payload["attachments"] = []
for a in attachments:
if hasattr(a, "data") and hasattr(a, "filename"):
idx = len(all_files)
all_files.append(a)
payload["attachments"].append({"id": idx, "filename": a.filename})
else:
payload["attachments"].append(
a.to_dict() if hasattr(a, "to_dict") else a
)
if files is not None:
for f in files:
if hasattr(f, "data") and hasattr(f, "filename"):
idx = len(all_files)
all_files.append(f)
if "attachments" not in payload:
payload["attachments"] = []
payload["attachments"].append({"id": idx, "filename": f.filename})
else:
raise TypeError("files must be File objects")
if flags:
payload["flags"] = flags
if message_reference:
@ -242,6 +265,25 @@ class HTTPClient:
if not payload:
raise ValueError("Message must have content, embeds, or components.")
if all_files:
form = aiohttp.FormData()
form.add_field(
"payload_json", json.dumps(payload), content_type="application/json"
)
for idx, f in enumerate(all_files):
form.add_field(
f"files[{idx}]",
f.data,
filename=f.filename,
content_type="application/octet-stream",
)
return await self.request(
"POST",
f"/channels/{channel_id}/messages",
payload=form,
is_json=False,
)
return await self.request(
"POST", f"/channels/{channel_id}/messages", payload=payload
)

View File

@ -379,6 +379,14 @@ class Attachment:
return payload
class File:
"""Represents a file to be uploaded."""
def __init__(self, filename: str, data: bytes):
self.filename = filename
self.data = data
class AllowedMentions:
"""Represents allowed mentions for a message or interaction response."""
@ -1561,7 +1569,7 @@ class MediaGallery(Component):
return payload
class File(Component):
class FileComponent(Component):
"""Represents a file component."""
def __init__(

View File

@ -3,7 +3,7 @@ from unittest.mock import AsyncMock
from disagreement.client import Client
from disagreement.http import HTTPClient
from disagreement.models import Attachment
from disagreement.models import Attachment, File
@pytest.mark.asyncio
@ -26,6 +26,23 @@ async def test_http_send_message_includes_attachments():
)
@pytest.mark.asyncio
async def test_http_send_message_with_files_uses_formdata():
http = HTTPClient(token="t")
http.request = AsyncMock(
return_value={
"id": "1",
"channel_id": "c",
"author": {"id": "2", "username": "u", "discriminator": "0001"},
"content": "hi",
"timestamp": "t",
}
)
await http.send_message("c", "hi", files=[File("f.txt", b"data")])
args, kwargs = http.request.call_args
assert kwargs["is_json"] is False
@pytest.mark.asyncio
async def test_client_send_message_passes_attachments():
client = Client(token="t")
@ -44,3 +61,21 @@ async def test_client_send_message_passes_attachments():
kwargs = client._http.send_message.call_args.kwargs
assert kwargs["attachments"] == [{"id": "1"}]
assert isinstance(msg.attachments[0], Attachment)
@pytest.mark.asyncio
async def test_client_send_message_passes_files():
client = Client(token="t")
client._http.send_message = AsyncMock(
return_value={
"id": "1",
"channel_id": "c",
"author": {"id": "2", "username": "u", "discriminator": "0001"},
"content": "hi",
"timestamp": "t",
}
)
await client.send_message("c", "hi", files=[File("f.txt", b"data")])
client._http.send_message.assert_awaited_once()
kwargs = client._http.send_message.call_args.kwargs
assert kwargs["files"][0].filename == "f.txt"