Add stage channel and instance support (#43)

This commit is contained in:
Slipstream 2025-06-10 21:08:32 -06:00 committed by GitHub
parent 3059041ba8
commit 2d6c2cb0be
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 195 additions and 5 deletions

View File

@ -335,6 +335,13 @@ class ChannelType(IntEnum):
GUILD_MEDIA = 16 # (Still in development) a channel that can only contain media GUILD_MEDIA = 16 # (Still in development) a channel that can only contain media
class StageInstancePrivacyLevel(IntEnum):
"""Privacy level of a stage instance."""
PUBLIC = 1
GUILD_ONLY = 2
class OverwriteType(IntEnum): class OverwriteType(IntEnum):
"""Type of target for a permission overwrite.""" """Type of target for a permission overwrite."""

View File

@ -23,7 +23,7 @@ from .interactions import InteractionResponsePayload
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import Client from .client import Client
from .models import Message, Webhook, File, Invite from .models import Message, Webhook, File, StageInstance, Invite
from .interactions import ApplicationCommand, Snowflake from .interactions import ApplicationCommand, Snowflake
# Discord API constants # Discord API constants
@ -924,6 +924,48 @@ class HTTPClient:
"""Sends a typing indicator to the specified channel.""" """Sends a typing indicator to the specified channel."""
await self.request("POST", f"/channels/{channel_id}/typing") await self.request("POST", f"/channels/{channel_id}/typing")
async def start_stage_instance(
self, payload: Dict[str, Any], reason: Optional[str] = None
) -> "StageInstance":
"""Starts a stage instance."""
headers = {"X-Audit-Log-Reason": reason} if reason else None
data = await self.request(
"POST", "/stage-instances", payload=payload, custom_headers=headers
)
from .models import StageInstance
return StageInstance(data)
async def edit_stage_instance(
self,
channel_id: "Snowflake",
payload: Dict[str, Any],
reason: Optional[str] = None,
) -> "StageInstance":
"""Edits an existing stage instance."""
headers = {"X-Audit-Log-Reason": reason} if reason else None
data = await self.request(
"PATCH",
f"/stage-instances/{channel_id}",
payload=payload,
custom_headers=headers,
)
from .models import StageInstance
return StageInstance(data)
async def end_stage_instance(
self, channel_id: "Snowflake", reason: Optional[str] = None
) -> None:
"""Ends a stage instance."""
headers = {"X-Audit-Log-Reason": reason} if reason else None
await self.request(
"DELETE", f"/stage-instances/{channel_id}", custom_headers=headers
)
async def get_voice_regions(self) -> List[Dict[str, Any]]: async def get_voice_regions(self) -> List[Dict[str, Any]]:
"""Returns available voice regions.""" """Returns available voice regions."""
return await self.request("GET", "/voice/regions") return await self.request("GET", "/voice/regions")

View File

@ -1160,6 +1160,85 @@ class VoiceChannel(Channel):
return f"<VoiceChannel id='{self.id}' name='{self.name}' guild_id='{self.guild_id}'>" return f"<VoiceChannel id='{self.id}' name='{self.name}' guild_id='{self.guild_id}'>"
class StageChannel(VoiceChannel):
"""Represents a guild stage channel."""
def __repr__(self) -> str:
return f"<StageChannel id='{self.id}' name='{self.name}' guild_id='{self.guild_id}'>"
async def start_stage_instance(
self,
topic: str,
*,
privacy_level: int = 2,
reason: Optional[str] = None,
guild_scheduled_event_id: Optional[str] = None,
) -> "StageInstance":
if not hasattr(self._client, "_http"):
raise DisagreementException("Client missing HTTP for stage instance")
payload: Dict[str, Any] = {
"channel_id": self.id,
"topic": topic,
"privacy_level": privacy_level,
}
if guild_scheduled_event_id is not None:
payload["guild_scheduled_event_id"] = guild_scheduled_event_id
instance = await self._client._http.start_stage_instance(payload, reason=reason)
instance._client = self._client
return instance
async def edit_stage_instance(
self,
*,
topic: Optional[str] = None,
privacy_level: Optional[int] = None,
reason: Optional[str] = None,
) -> "StageInstance":
if not hasattr(self._client, "_http"):
raise DisagreementException("Client missing HTTP for stage instance")
payload: Dict[str, Any] = {}
if topic is not None:
payload["topic"] = topic
if privacy_level is not None:
payload["privacy_level"] = privacy_level
instance = await self._client._http.edit_stage_instance(
self.id, payload, reason=reason
)
instance._client = self._client
return instance
async def end_stage_instance(self, *, reason: Optional[str] = None) -> None:
if not hasattr(self._client, "_http"):
raise DisagreementException("Client missing HTTP for stage instance")
await self._client._http.end_stage_instance(self.id, reason=reason)
class StageInstance:
"""Represents a stage instance."""
def __init__(
self, data: Dict[str, Any], client_instance: Optional["Client"] = None
) -> None:
self._client = client_instance
self.id: str = data["id"]
self.guild_id: Optional[str] = data.get("guild_id")
self.channel_id: str = data["channel_id"]
self.topic: str = data["topic"]
self.privacy_level: int = data.get("privacy_level", 2)
self.discoverable_disabled: bool = data.get("discoverable_disabled", False)
self.guild_scheduled_event_id: Optional[str] = data.get(
"guild_scheduled_event_id"
)
def __repr__(self) -> str:
return f"<StageInstance id='{self.id}' channel_id='{self.channel_id}'>"
class CategoryChannel(Channel): class CategoryChannel(Channel):
"""Represents a guild category channel.""" """Represents a guild category channel."""
@ -2100,11 +2179,10 @@ def channel_factory(data: Dict[str, Any], client: "Client") -> Channel:
ChannelType.GUILD_ANNOUNCEMENT.value, ChannelType.GUILD_ANNOUNCEMENT.value,
): ):
return TextChannel(data, client) return TextChannel(data, client)
if channel_type in ( if channel_type == ChannelType.GUILD_VOICE.value:
ChannelType.GUILD_VOICE.value,
ChannelType.GUILD_STAGE_VOICE.value,
):
return VoiceChannel(data, client) return VoiceChannel(data, client)
if channel_type == ChannelType.GUILD_STAGE_VOICE.value:
return StageChannel(data, client)
if channel_type == ChannelType.GUILD_CATEGORY.value: if channel_type == ChannelType.GUILD_CATEGORY.value:
return CategoryChannel(data, client) return CategoryChannel(data, client)
if channel_type in ( if channel_type in (

View File

@ -0,0 +1,63 @@
import pytest
from types import SimpleNamespace
from unittest.mock import AsyncMock
from disagreement.http import HTTPClient
from disagreement.models import StageChannel, StageInstance
from disagreement.enums import ChannelType
@pytest.mark.asyncio
async def test_http_start_stage_instance_calls_request():
http = HTTPClient(token="t")
http.request = AsyncMock(return_value={"id": "1", "channel_id": "c", "topic": "t"})
payload = {"channel_id": "c", "topic": "t", "privacy_level": 2}
instance = await http.start_stage_instance(payload)
http.request.assert_called_once_with(
"POST",
"/stage-instances",
payload=payload,
custom_headers=None,
)
assert isinstance(instance, StageInstance)
@pytest.mark.asyncio
async def test_http_end_stage_instance_calls_request():
http = HTTPClient(token="t")
http.request = AsyncMock(return_value=None)
await http.end_stage_instance("c")
http.request.assert_called_once_with(
"DELETE", "/stage-instances/c", custom_headers=None
)
@pytest.mark.asyncio
async def test_stage_channel_start_and_end():
http = SimpleNamespace(
start_stage_instance=AsyncMock(
return_value=StageInstance({"id": "1", "channel_id": "c", "topic": "hi"})
),
edit_stage_instance=AsyncMock(
return_value=StageInstance({"id": "1", "channel_id": "c", "topic": "hi"})
),
end_stage_instance=AsyncMock(),
)
client = type("Client", (), {})()
client._http = http
channel_data = {
"id": "c",
"type": ChannelType.GUILD_STAGE_VOICE.value,
"guild_id": "g",
}
channel = StageChannel(channel_data, client)
instance = await channel.start_stage_instance("hi")
http.start_stage_instance.assert_awaited_once_with(
{"channel_id": "c", "topic": "hi", "privacy_level": 2}, reason=None
)
assert isinstance(instance, StageInstance)
assert instance._client is client
await channel.end_stage_instance()
http.end_stage_instance.assert_awaited_once_with("c", reason=None)