diff --git a/disagreement/http.py b/disagreement/http.py index b67068a..32afb83 100644 --- a/disagreement/http.py +++ b/disagreement/http.py @@ -734,17 +734,37 @@ class HTTPClient: return await self.request("GET", f"/channels/{channel_id}/invites") - async def create_invite( - self, channel_id: "Snowflake", payload: Dict[str, Any] - ) -> "Invite": - """Creates an invite for a channel.""" + async def create_invite( + self, channel_id: "Snowflake", payload: Dict[str, Any] + ) -> "Invite": + """Creates an invite for a channel.""" data = await self.request( "POST", f"/channels/{channel_id}/invites", payload=payload ) - from .models import Invite - - return Invite.from_dict(data) + from .models import Invite + + return Invite.from_dict(data) + + async def create_channel_invite( + self, + channel_id: "Snowflake", + payload: Dict[str, Any], + *, + reason: Optional[str] = None, + ) -> "Invite": + """Creates an invite for a channel with an optional audit log reason.""" + + headers = {"X-Audit-Log-Reason": reason} if reason else None + data = await self.request( + "POST", + f"/channels/{channel_id}/invites", + payload=payload, + custom_headers=headers, + ) + from .models import Invite + + return Invite.from_dict(data) async def delete_invite(self, code: str) -> None: """Deletes an invite by code.""" diff --git a/disagreement/models.py b/disagreement/models.py index d7c83e6..d983696 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -1663,11 +1663,11 @@ class TextChannel(Channel, Messageable): messages_data = await self._client._http.get_pinned_messages(self.id) return [self._client.parse_message(m) for m in messages_data] - async def create_thread( - self, - name: str, - *, - type: ChannelType = ChannelType.PUBLIC_THREAD, + async def create_thread( + self, + name: str, + *, + type: ChannelType = ChannelType.PUBLIC_THREAD, auto_archive_duration: Optional[AutoArchiveDuration] = None, invitable: Optional[bool] = None, rate_limit_per_user: Optional[int] = None, @@ -1710,8 +1710,33 @@ class TextChannel(Channel, Messageable): if rate_limit_per_user is not None: payload["rate_limit_per_user"] = rate_limit_per_user - data = await self._client._http.start_thread_without_message(self.id, payload) - return cast("Thread", self._client.parse_channel(data)) + data = await self._client._http.start_thread_without_message(self.id, payload) + return cast("Thread", self._client.parse_channel(data)) + + async def create_invite( + self, + *, + max_age: Optional[int] = None, + max_uses: Optional[int] = None, + temporary: Optional[bool] = None, + unique: Optional[bool] = None, + reason: Optional[str] = None, + ) -> "Invite": + """|coro| Create an invite to this channel.""" + + payload: Dict[str, Any] = {} + if max_age is not None: + payload["max_age"] = max_age + if max_uses is not None: + payload["max_uses"] = max_uses + if temporary is not None: + payload["temporary"] = temporary + if unique is not None: + payload["unique"] = unique + + return await self._client._http.create_channel_invite( + self.id, payload, reason=reason + ) class VoiceChannel(Channel): diff --git a/tests/test_invites.py b/tests/test_invites.py new file mode 100644 index 0000000..4b4eb39 --- /dev/null +++ b/tests/test_invites.py @@ -0,0 +1,39 @@ +import pytest +from types import SimpleNamespace +from unittest.mock import AsyncMock + +from disagreement.client import Client +from disagreement.http import HTTPClient +from disagreement.models import TextChannel, Invite + + +@pytest.mark.asyncio +async def test_create_channel_invite_calls_request_and_returns_model(): + http = HTTPClient(token="t") + http.request = AsyncMock(return_value={"code": "abc"}) + invite = await http.create_channel_invite("123", {"max_age": 60}, reason="r") + + http.request.assert_called_once_with( + "POST", + "/channels/123/invites", + payload={"max_age": 60}, + custom_headers={"X-Audit-Log-Reason": "r"}, + ) + assert isinstance(invite, Invite) + + +@pytest.mark.asyncio +async def test_textchannel_create_invite_uses_http(): + http = SimpleNamespace( + create_channel_invite=AsyncMock(return_value=Invite.from_dict({"code": "a"})) + ) + client = Client(token="t") + client._http = http + + channel = TextChannel({"id": "c", "type": 0}, client) + invite = await channel.create_invite(max_age=30, reason="why") + + http.create_channel_invite.assert_awaited_once_with( + "c", {"max_age": 30}, reason="why" + ) + assert isinstance(invite, Invite)