From c811e2b578e67104942ae1dea6d5136a19994156 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Sun, 15 Jun 2025 15:20:04 -0600 Subject: [PATCH] Extend File to handle streams (#85) --- disagreement/models.py | 56 ++++++++++++++++++++++++++++++++++++---- tests/test_send_files.py | 33 +++++++++++++++++++++-- 2 files changed, 82 insertions(+), 7 deletions(-) diff --git a/disagreement/models.py b/disagreement/models.py index 3cade7c..c3d05d3 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -3,10 +3,22 @@ Data models for Discord objects. """ import asyncio +import io import json +import os import re from dataclasses import dataclass -from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast +from typing import ( + Any, + AsyncIterator, + Dict, + List, + Optional, + TYPE_CHECKING, + Union, + cast, + IO, +) from .cache import ChannelCache, MemberCache from .caching import MemberCacheFlags @@ -636,11 +648,45 @@ class Attachment: class File: - """Represents a file to be uploaded.""" + """Represents a file to be uploaded. - def __init__(self, filename: str, data: bytes): - self.filename = filename - self.data = data + Parameters + ---------- + fp: + A file path, file-like object, or bytes-like object containing the + data to upload. + filename: + Optional name of the file. If not provided and ``fp`` is a path or has + a ``name`` attribute, the name will be inferred. + spoiler: + When ``True`` the filename will be prefixed with ``"SPOILER_"``. + """ + + def __init__( + self, + fp: Union[str, bytes, os.PathLike[Any], IO[bytes]], + *, + filename: Optional[str] = None, + spoiler: bool = False, + ) -> None: + if isinstance(fp, (str, os.PathLike)): + self.data = open(fp, "rb") + inferred = os.path.basename(fp) + elif isinstance(fp, bytes): + self.data = io.BytesIO(fp) + inferred = None + else: + self.data = fp + inferred = getattr(fp, "name", None) + + name = filename or inferred + if name is None: + raise ValueError("filename could not be inferred") + + if spoiler and not name.startswith("SPOILER_"): + name = f"SPOILER_{name}" + self.filename = name + self.spoiler = spoiler class AllowedMentions: diff --git a/tests/test_send_files.py b/tests/test_send_files.py index 5e64434..4843fa6 100644 --- a/tests/test_send_files.py +++ b/tests/test_send_files.py @@ -1,3 +1,4 @@ +import io import pytest from unittest.mock import AsyncMock @@ -38,7 +39,9 @@ async def test_http_send_message_with_files_uses_formdata(): "timestamp": "t", } ) - await http.send_message("c", "hi", files=[File("f.txt", b"data")]) + await http.send_message( + "c", "hi", files=[File(io.BytesIO(b"data"), filename="f.txt")] + ) args, kwargs = http.request.call_args assert kwargs["is_json"] is False @@ -75,7 +78,33 @@ async def test_client_send_message_passes_files(): "timestamp": "t", } ) - await client.send_message("c", "hi", files=[File("f.txt", b"data")]) + await client.send_message( + "c", "hi", files=[File(io.BytesIO(b"data"), filename="f.txt")] + ) client._http.send_message.assert_awaited_once() kwargs = client._http.send_message.call_args.kwargs assert kwargs["files"][0].filename == "f.txt" + + +@pytest.mark.asyncio +async def test_file_from_path(tmp_path): + file_path = tmp_path / "path.txt" + file_path.write_bytes(b"ok") + http = HTTPClient(token="t") + http.request = AsyncMock( + return_value={ + "id": "1", + "channel_id": "c", + "author": {"id": "2", "username": "u", "discriminator": "0001"}, + "content": "hi", + "timestamp": "t", + } + ) + await http.send_message("c", "hi", files=[File(file_path)]) + _, kwargs = http.request.call_args + assert kwargs["is_json"] is False + + +def test_file_spoiler(): + f = File(io.BytesIO(b"d"), filename="a.txt", spoiler=True) + assert f.filename == "SPOILER_a.txt"