From 2d6c2cb0bed99ffb554a2407dd857fcfc7bbf863 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Tue, 10 Jun 2025 21:08:32 -0600 Subject: [PATCH] Add stage channel and instance support (#43) --- disagreement/enums.py | 7 +++ disagreement/http.py | 44 +++++++++++++++++- disagreement/models.py | 86 ++++++++++++++++++++++++++++++++++-- tests/test_stage_instance.py | 63 ++++++++++++++++++++++++++ 4 files changed, 195 insertions(+), 5 deletions(-) create mode 100644 tests/test_stage_instance.py diff --git a/disagreement/enums.py b/disagreement/enums.py index 70f06f4..f77620f 100644 --- a/disagreement/enums.py +++ b/disagreement/enums.py @@ -335,6 +335,13 @@ class ChannelType(IntEnum): GUILD_MEDIA = 16 # (Still in development) a channel that can only contain media +class StageInstancePrivacyLevel(IntEnum): + """Privacy level of a stage instance.""" + + PUBLIC = 1 + GUILD_ONLY = 2 + + class OverwriteType(IntEnum): """Type of target for a permission overwrite.""" diff --git a/disagreement/http.py b/disagreement/http.py index 6b1ac4d..62fa089 100644 --- a/disagreement/http.py +++ b/disagreement/http.py @@ -23,7 +23,7 @@ from .interactions import InteractionResponsePayload if TYPE_CHECKING: from .client import Client - from .models import Message, Webhook, File, Invite + from .models import Message, Webhook, File, StageInstance, Invite from .interactions import ApplicationCommand, Snowflake # Discord API constants @@ -924,6 +924,48 @@ class HTTPClient: """Sends a typing indicator to the specified channel.""" await self.request("POST", f"/channels/{channel_id}/typing") + async def start_stage_instance( + self, payload: Dict[str, Any], reason: Optional[str] = None + ) -> "StageInstance": + """Starts a stage instance.""" + + headers = {"X-Audit-Log-Reason": reason} if reason else None + data = await self.request( + "POST", "/stage-instances", payload=payload, custom_headers=headers + ) + from .models import StageInstance + + return StageInstance(data) + + async def edit_stage_instance( + self, + channel_id: "Snowflake", + payload: Dict[str, Any], + reason: Optional[str] = None, + ) -> "StageInstance": + """Edits an existing stage instance.""" + + headers = {"X-Audit-Log-Reason": reason} if reason else None + data = await self.request( + "PATCH", + f"/stage-instances/{channel_id}", + payload=payload, + custom_headers=headers, + ) + from .models import StageInstance + + return StageInstance(data) + + async def end_stage_instance( + self, channel_id: "Snowflake", reason: Optional[str] = None + ) -> None: + """Ends a stage instance.""" + + headers = {"X-Audit-Log-Reason": reason} if reason else None + await self.request( + "DELETE", f"/stage-instances/{channel_id}", custom_headers=headers + ) + async def get_voice_regions(self) -> List[Dict[str, Any]]: """Returns available voice regions.""" return await self.request("GET", "/voice/regions") diff --git a/disagreement/models.py b/disagreement/models.py index cf4ba33..a52cd02 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -1160,6 +1160,85 @@ class VoiceChannel(Channel): return f"" +class StageChannel(VoiceChannel): + """Represents a guild stage channel.""" + + def __repr__(self) -> str: + return f"" + + async def start_stage_instance( + self, + topic: str, + *, + privacy_level: int = 2, + reason: Optional[str] = None, + guild_scheduled_event_id: Optional[str] = None, + ) -> "StageInstance": + if not hasattr(self._client, "_http"): + raise DisagreementException("Client missing HTTP for stage instance") + + payload: Dict[str, Any] = { + "channel_id": self.id, + "topic": topic, + "privacy_level": privacy_level, + } + if guild_scheduled_event_id is not None: + payload["guild_scheduled_event_id"] = guild_scheduled_event_id + + instance = await self._client._http.start_stage_instance(payload, reason=reason) + instance._client = self._client + return instance + + async def edit_stage_instance( + self, + *, + topic: Optional[str] = None, + privacy_level: Optional[int] = None, + reason: Optional[str] = None, + ) -> "StageInstance": + if not hasattr(self._client, "_http"): + raise DisagreementException("Client missing HTTP for stage instance") + + payload: Dict[str, Any] = {} + if topic is not None: + payload["topic"] = topic + if privacy_level is not None: + payload["privacy_level"] = privacy_level + + instance = await self._client._http.edit_stage_instance( + self.id, payload, reason=reason + ) + instance._client = self._client + return instance + + async def end_stage_instance(self, *, reason: Optional[str] = None) -> None: + if not hasattr(self._client, "_http"): + raise DisagreementException("Client missing HTTP for stage instance") + + await self._client._http.end_stage_instance(self.id, reason=reason) + + +class StageInstance: + """Represents a stage instance.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ) -> None: + self._client = client_instance + self.id: str = data["id"] + self.guild_id: Optional[str] = data.get("guild_id") + self.channel_id: str = data["channel_id"] + self.topic: str = data["topic"] + self.privacy_level: int = data.get("privacy_level", 2) + self.discoverable_disabled: bool = data.get("discoverable_disabled", False) + self.guild_scheduled_event_id: Optional[str] = data.get( + "guild_scheduled_event_id" + ) + + def __repr__(self) -> str: + return f"" + + class CategoryChannel(Channel): """Represents a guild category channel.""" @@ -2100,11 +2179,10 @@ def channel_factory(data: Dict[str, Any], client: "Client") -> Channel: ChannelType.GUILD_ANNOUNCEMENT.value, ): return TextChannel(data, client) - if channel_type in ( - ChannelType.GUILD_VOICE.value, - ChannelType.GUILD_STAGE_VOICE.value, - ): + if channel_type == ChannelType.GUILD_VOICE.value: return VoiceChannel(data, client) + if channel_type == ChannelType.GUILD_STAGE_VOICE.value: + return StageChannel(data, client) if channel_type == ChannelType.GUILD_CATEGORY.value: return CategoryChannel(data, client) if channel_type in ( diff --git a/tests/test_stage_instance.py b/tests/test_stage_instance.py new file mode 100644 index 0000000..e2ea7b3 --- /dev/null +++ b/tests/test_stage_instance.py @@ -0,0 +1,63 @@ +import pytest +from types import SimpleNamespace +from unittest.mock import AsyncMock + +from disagreement.http import HTTPClient +from disagreement.models import StageChannel, StageInstance +from disagreement.enums import ChannelType + + +@pytest.mark.asyncio +async def test_http_start_stage_instance_calls_request(): + http = HTTPClient(token="t") + http.request = AsyncMock(return_value={"id": "1", "channel_id": "c", "topic": "t"}) + payload = {"channel_id": "c", "topic": "t", "privacy_level": 2} + instance = await http.start_stage_instance(payload) + http.request.assert_called_once_with( + "POST", + "/stage-instances", + payload=payload, + custom_headers=None, + ) + assert isinstance(instance, StageInstance) + + +@pytest.mark.asyncio +async def test_http_end_stage_instance_calls_request(): + http = HTTPClient(token="t") + http.request = AsyncMock(return_value=None) + await http.end_stage_instance("c") + http.request.assert_called_once_with( + "DELETE", "/stage-instances/c", custom_headers=None + ) + + +@pytest.mark.asyncio +async def test_stage_channel_start_and_end(): + http = SimpleNamespace( + start_stage_instance=AsyncMock( + return_value=StageInstance({"id": "1", "channel_id": "c", "topic": "hi"}) + ), + edit_stage_instance=AsyncMock( + return_value=StageInstance({"id": "1", "channel_id": "c", "topic": "hi"}) + ), + end_stage_instance=AsyncMock(), + ) + client = type("Client", (), {})() + client._http = http + channel_data = { + "id": "c", + "type": ChannelType.GUILD_STAGE_VOICE.value, + "guild_id": "g", + } + channel = StageChannel(channel_data, client) + + instance = await channel.start_stage_instance("hi") + http.start_stage_instance.assert_awaited_once_with( + {"channel_id": "c", "topic": "hi", "privacy_level": 2}, reason=None + ) + assert isinstance(instance, StageInstance) + assert instance._client is client + + await channel.end_stage_instance() + http.end_stage_instance.assert_awaited_once_with("c", reason=None)