From 7f0b03fa69d056c300fe604dda5e7aeadf6590a8 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Tue, 10 Jun 2025 15:45:25 -0600 Subject: [PATCH] Add file upload support (#13) --- disagreement/client.py | 3 +++ disagreement/components.py | 4 +-- disagreement/http.py | 52 ++++++++++++++++++++++++++++++++++---- disagreement/models.py | 10 +++++++- tests/test_send_files.py | 37 ++++++++++++++++++++++++++- 5 files changed, 97 insertions(+), 9 deletions(-) diff --git a/disagreement/client.py b/disagreement/client.py index 87d35f3..8e0e523 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -860,6 +860,7 @@ class Client: allowed_mentions: Optional[Dict[str, Any]] = None, message_reference: Optional[Dict[str, Any]] = None, attachments: Optional[List[Any]] = None, + files: Optional[List[Any]] = None, flags: Optional[int] = None, view: Optional["View"] = None, ) -> "Message": @@ -877,6 +878,7 @@ class Client: allowed_mentions (Optional[Dict[str, Any]]): Allowed mentions for the message. message_reference (Optional[Dict[str, Any]]): Message reference for replying. attachments (Optional[List[Any]]): Attachments to include with the message. + files (Optional[List[Any]]): Files to upload with the message. flags (Optional[int]): Message flags. view (Optional[View]): A view to send with the message. @@ -929,6 +931,7 @@ class Client: allowed_mentions=allowed_mentions, message_reference=message_reference, attachments=attachments, + files=files, flags=flags, ) diff --git a/disagreement/components.py b/disagreement/components.py index 11da89a..ae57012 100644 --- a/disagreement/components.py +++ b/disagreement/components.py @@ -18,7 +18,7 @@ from .models import ( Thumbnail, MediaGallery, MediaGalleryItem, - File, + FileComponent, Separator, Container, UnfurledMediaItem, @@ -140,7 +140,7 @@ def component_factory( ) if ctype == ComponentType.FILE: - return File( + return FileComponent( file=UnfurledMediaItem(**data["file"]), spoiler=data.get("spoiler", False), id=data.get("id"), diff --git a/disagreement/http.py b/disagreement/http.py index caf375f..06edaf5 100644 --- a/disagreement/http.py +++ b/disagreement/http.py @@ -20,7 +20,7 @@ from . import __version__ # For User-Agent if TYPE_CHECKING: from .client import Client - from .models import Message, Webhook + from .models import Message, Webhook, File from .interactions import ApplicationCommand, InteractionResponsePayload, Snowflake # Discord API constants @@ -60,7 +60,9 @@ class HTTPClient: self, method: str, endpoint: str, - payload: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + payload: Optional[ + Union[Dict[str, Any], List[Dict[str, Any]], aiohttp.FormData] + ] = None, params: Optional[Dict[str, Any]] = None, is_json: bool = True, use_auth_header: bool = True, @@ -205,6 +207,7 @@ class HTTPClient: allowed_mentions: Optional[dict] = None, message_reference: Optional[Dict[str, Any]] = None, attachments: Optional[List[Any]] = None, + files: Optional[List[Any]] = None, flags: Optional[int] = None, ) -> Dict[str, Any]: """Sends a message to a channel. @@ -213,6 +216,8 @@ class HTTPClient: ---------- attachments: A list of attachment payloads to include with the message. + files: + A list of :class:`File` objects containing binary data to upload. Returns ------- @@ -230,10 +235,28 @@ class HTTPClient: payload["components"] = components if allowed_mentions: payload["allowed_mentions"] = allowed_mentions + all_files: List["File"] = [] if attachments is not None: - payload["attachments"] = [ - a.to_dict() if hasattr(a, "to_dict") else a for a in attachments - ] + payload["attachments"] = [] + for a in attachments: + if hasattr(a, "data") and hasattr(a, "filename"): + idx = len(all_files) + all_files.append(a) + payload["attachments"].append({"id": idx, "filename": a.filename}) + else: + payload["attachments"].append( + a.to_dict() if hasattr(a, "to_dict") else a + ) + if files is not None: + for f in files: + if hasattr(f, "data") and hasattr(f, "filename"): + idx = len(all_files) + all_files.append(f) + if "attachments" not in payload: + payload["attachments"] = [] + payload["attachments"].append({"id": idx, "filename": f.filename}) + else: + raise TypeError("files must be File objects") if flags: payload["flags"] = flags if message_reference: @@ -242,6 +265,25 @@ class HTTPClient: if not payload: raise ValueError("Message must have content, embeds, or components.") + if all_files: + form = aiohttp.FormData() + form.add_field( + "payload_json", json.dumps(payload), content_type="application/json" + ) + for idx, f in enumerate(all_files): + form.add_field( + f"files[{idx}]", + f.data, + filename=f.filename, + content_type="application/octet-stream", + ) + return await self.request( + "POST", + f"/channels/{channel_id}/messages", + payload=form, + is_json=False, + ) + return await self.request( "POST", f"/channels/{channel_id}/messages", payload=payload ) diff --git a/disagreement/models.py b/disagreement/models.py index e91e5d5..a9d92fb 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -379,6 +379,14 @@ class Attachment: return payload +class File: + """Represents a file to be uploaded.""" + + def __init__(self, filename: str, data: bytes): + self.filename = filename + self.data = data + + class AllowedMentions: """Represents allowed mentions for a message or interaction response.""" @@ -1561,7 +1569,7 @@ class MediaGallery(Component): return payload -class File(Component): +class FileComponent(Component): """Represents a file component.""" def __init__( diff --git a/tests/test_send_files.py b/tests/test_send_files.py index cd796d8..5e64434 100644 --- a/tests/test_send_files.py +++ b/tests/test_send_files.py @@ -3,7 +3,7 @@ from unittest.mock import AsyncMock from disagreement.client import Client from disagreement.http import HTTPClient -from disagreement.models import Attachment +from disagreement.models import Attachment, File @pytest.mark.asyncio @@ -26,6 +26,23 @@ async def test_http_send_message_includes_attachments(): ) +@pytest.mark.asyncio +async def test_http_send_message_with_files_uses_formdata(): + http = HTTPClient(token="t") + http.request = AsyncMock( + return_value={ + "id": "1", + "channel_id": "c", + "author": {"id": "2", "username": "u", "discriminator": "0001"}, + "content": "hi", + "timestamp": "t", + } + ) + await http.send_message("c", "hi", files=[File("f.txt", b"data")]) + args, kwargs = http.request.call_args + assert kwargs["is_json"] is False + + @pytest.mark.asyncio async def test_client_send_message_passes_attachments(): client = Client(token="t") @@ -44,3 +61,21 @@ async def test_client_send_message_passes_attachments(): kwargs = client._http.send_message.call_args.kwargs assert kwargs["attachments"] == [{"id": "1"}] assert isinstance(msg.attachments[0], Attachment) + + +@pytest.mark.asyncio +async def test_client_send_message_passes_files(): + client = Client(token="t") + client._http.send_message = AsyncMock( + return_value={ + "id": "1", + "channel_id": "c", + "author": {"id": "2", "username": "u", "discriminator": "0001"}, + "content": "hi", + "timestamp": "t", + } + ) + await client.send_message("c", "hi", files=[File("f.txt", b"data")]) + client._http.send_message.assert_awaited_once() + kwargs = client._http.send_message.call_args.kwargs + assert kwargs["files"][0].filename == "f.txt"