Add file upload support (#13)
This commit is contained in:
parent
06f972851b
commit
7f0b03fa69
@ -860,6 +860,7 @@ class Client:
|
||||
allowed_mentions: Optional[Dict[str, Any]] = None,
|
||||
message_reference: Optional[Dict[str, Any]] = None,
|
||||
attachments: Optional[List[Any]] = None,
|
||||
files: Optional[List[Any]] = None,
|
||||
flags: Optional[int] = None,
|
||||
view: Optional["View"] = None,
|
||||
) -> "Message":
|
||||
@ -877,6 +878,7 @@ class Client:
|
||||
allowed_mentions (Optional[Dict[str, Any]]): Allowed mentions for the message.
|
||||
message_reference (Optional[Dict[str, Any]]): Message reference for replying.
|
||||
attachments (Optional[List[Any]]): Attachments to include with the message.
|
||||
files (Optional[List[Any]]): Files to upload with the message.
|
||||
flags (Optional[int]): Message flags.
|
||||
view (Optional[View]): A view to send with the message.
|
||||
|
||||
@ -929,6 +931,7 @@ class Client:
|
||||
allowed_mentions=allowed_mentions,
|
||||
message_reference=message_reference,
|
||||
attachments=attachments,
|
||||
files=files,
|
||||
flags=flags,
|
||||
)
|
||||
|
||||
|
@ -18,7 +18,7 @@ from .models import (
|
||||
Thumbnail,
|
||||
MediaGallery,
|
||||
MediaGalleryItem,
|
||||
File,
|
||||
FileComponent,
|
||||
Separator,
|
||||
Container,
|
||||
UnfurledMediaItem,
|
||||
@ -140,7 +140,7 @@ def component_factory(
|
||||
)
|
||||
|
||||
if ctype == ComponentType.FILE:
|
||||
return File(
|
||||
return FileComponent(
|
||||
file=UnfurledMediaItem(**data["file"]),
|
||||
spoiler=data.get("spoiler", False),
|
||||
id=data.get("id"),
|
||||
|
@ -20,7 +20,7 @@ from . import __version__ # For User-Agent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import Client
|
||||
from .models import Message, Webhook
|
||||
from .models import Message, Webhook, File
|
||||
from .interactions import ApplicationCommand, InteractionResponsePayload, Snowflake
|
||||
|
||||
# Discord API constants
|
||||
@ -60,7 +60,9 @@ class HTTPClient:
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
payload: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||
payload: Optional[
|
||||
Union[Dict[str, Any], List[Dict[str, Any]], aiohttp.FormData]
|
||||
] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
is_json: bool = True,
|
||||
use_auth_header: bool = True,
|
||||
@ -205,6 +207,7 @@ class HTTPClient:
|
||||
allowed_mentions: Optional[dict] = None,
|
||||
message_reference: Optional[Dict[str, Any]] = None,
|
||||
attachments: Optional[List[Any]] = None,
|
||||
files: Optional[List[Any]] = None,
|
||||
flags: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Sends a message to a channel.
|
||||
@ -213,6 +216,8 @@ class HTTPClient:
|
||||
----------
|
||||
attachments:
|
||||
A list of attachment payloads to include with the message.
|
||||
files:
|
||||
A list of :class:`File` objects containing binary data to upload.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@ -230,10 +235,28 @@ class HTTPClient:
|
||||
payload["components"] = components
|
||||
if allowed_mentions:
|
||||
payload["allowed_mentions"] = allowed_mentions
|
||||
all_files: List["File"] = []
|
||||
if attachments is not None:
|
||||
payload["attachments"] = [
|
||||
a.to_dict() if hasattr(a, "to_dict") else a for a in attachments
|
||||
]
|
||||
payload["attachments"] = []
|
||||
for a in attachments:
|
||||
if hasattr(a, "data") and hasattr(a, "filename"):
|
||||
idx = len(all_files)
|
||||
all_files.append(a)
|
||||
payload["attachments"].append({"id": idx, "filename": a.filename})
|
||||
else:
|
||||
payload["attachments"].append(
|
||||
a.to_dict() if hasattr(a, "to_dict") else a
|
||||
)
|
||||
if files is not None:
|
||||
for f in files:
|
||||
if hasattr(f, "data") and hasattr(f, "filename"):
|
||||
idx = len(all_files)
|
||||
all_files.append(f)
|
||||
if "attachments" not in payload:
|
||||
payload["attachments"] = []
|
||||
payload["attachments"].append({"id": idx, "filename": f.filename})
|
||||
else:
|
||||
raise TypeError("files must be File objects")
|
||||
if flags:
|
||||
payload["flags"] = flags
|
||||
if message_reference:
|
||||
@ -242,6 +265,25 @@ class HTTPClient:
|
||||
if not payload:
|
||||
raise ValueError("Message must have content, embeds, or components.")
|
||||
|
||||
if all_files:
|
||||
form = aiohttp.FormData()
|
||||
form.add_field(
|
||||
"payload_json", json.dumps(payload), content_type="application/json"
|
||||
)
|
||||
for idx, f in enumerate(all_files):
|
||||
form.add_field(
|
||||
f"files[{idx}]",
|
||||
f.data,
|
||||
filename=f.filename,
|
||||
content_type="application/octet-stream",
|
||||
)
|
||||
return await self.request(
|
||||
"POST",
|
||||
f"/channels/{channel_id}/messages",
|
||||
payload=form,
|
||||
is_json=False,
|
||||
)
|
||||
|
||||
return await self.request(
|
||||
"POST", f"/channels/{channel_id}/messages", payload=payload
|
||||
)
|
||||
|
@ -379,6 +379,14 @@ class Attachment:
|
||||
return payload
|
||||
|
||||
|
||||
class File:
|
||||
"""Represents a file to be uploaded."""
|
||||
|
||||
def __init__(self, filename: str, data: bytes):
|
||||
self.filename = filename
|
||||
self.data = data
|
||||
|
||||
|
||||
class AllowedMentions:
|
||||
"""Represents allowed mentions for a message or interaction response."""
|
||||
|
||||
@ -1561,7 +1569,7 @@ class MediaGallery(Component):
|
||||
return payload
|
||||
|
||||
|
||||
class File(Component):
|
||||
class FileComponent(Component):
|
||||
"""Represents a file component."""
|
||||
|
||||
def __init__(
|
||||
|
@ -3,7 +3,7 @@ from unittest.mock import AsyncMock
|
||||
|
||||
from disagreement.client import Client
|
||||
from disagreement.http import HTTPClient
|
||||
from disagreement.models import Attachment
|
||||
from disagreement.models import Attachment, File
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -26,6 +26,23 @@ async def test_http_send_message_includes_attachments():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_send_message_with_files_uses_formdata():
|
||||
http = HTTPClient(token="t")
|
||||
http.request = AsyncMock(
|
||||
return_value={
|
||||
"id": "1",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||
"content": "hi",
|
||||
"timestamp": "t",
|
||||
}
|
||||
)
|
||||
await http.send_message("c", "hi", files=[File("f.txt", b"data")])
|
||||
args, kwargs = http.request.call_args
|
||||
assert kwargs["is_json"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_send_message_passes_attachments():
|
||||
client = Client(token="t")
|
||||
@ -44,3 +61,21 @@ async def test_client_send_message_passes_attachments():
|
||||
kwargs = client._http.send_message.call_args.kwargs
|
||||
assert kwargs["attachments"] == [{"id": "1"}]
|
||||
assert isinstance(msg.attachments[0], Attachment)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_send_message_passes_files():
|
||||
client = Client(token="t")
|
||||
client._http.send_message = AsyncMock(
|
||||
return_value={
|
||||
"id": "1",
|
||||
"channel_id": "c",
|
||||
"author": {"id": "2", "username": "u", "discriminator": "0001"},
|
||||
"content": "hi",
|
||||
"timestamp": "t",
|
||||
}
|
||||
)
|
||||
await client.send_message("c", "hi", files=[File("f.txt", b"data")])
|
||||
client._http.send_message.assert_awaited_once()
|
||||
kwargs = client._http.send_message.call_args.kwargs
|
||||
assert kwargs["files"][0].filename == "f.txt"
|
||||
|
Loading…
x
Reference in New Issue
Block a user