Add stage channel and instance support (#43)

This commit is contained in:
Slipstream 2025-06-10 21:08:32 -06:00 committed by GitHub
parent 3059041ba8
commit 2d6c2cb0be
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 195 additions and 5 deletions

View File

@ -335,6 +335,13 @@ class ChannelType(IntEnum):
GUILD_MEDIA = 16 # (Still in development) a channel that can only contain media
class StageInstancePrivacyLevel(IntEnum):
"""Privacy level of a stage instance."""
PUBLIC = 1
GUILD_ONLY = 2
class OverwriteType(IntEnum):
"""Type of target for a permission overwrite."""

View File

@ -23,7 +23,7 @@ from .interactions import InteractionResponsePayload
if TYPE_CHECKING:
from .client import Client
from .models import Message, Webhook, File, Invite
from .models import Message, Webhook, File, StageInstance, Invite
from .interactions import ApplicationCommand, Snowflake
# Discord API constants
@ -924,6 +924,48 @@ class HTTPClient:
"""Sends a typing indicator to the specified channel."""
await self.request("POST", f"/channels/{channel_id}/typing")
async def start_stage_instance(
self, payload: Dict[str, Any], reason: Optional[str] = None
) -> "StageInstance":
"""Starts a stage instance."""
headers = {"X-Audit-Log-Reason": reason} if reason else None
data = await self.request(
"POST", "/stage-instances", payload=payload, custom_headers=headers
)
from .models import StageInstance
return StageInstance(data)
async def edit_stage_instance(
self,
channel_id: "Snowflake",
payload: Dict[str, Any],
reason: Optional[str] = None,
) -> "StageInstance":
"""Edits an existing stage instance."""
headers = {"X-Audit-Log-Reason": reason} if reason else None
data = await self.request(
"PATCH",
f"/stage-instances/{channel_id}",
payload=payload,
custom_headers=headers,
)
from .models import StageInstance
return StageInstance(data)
async def end_stage_instance(
self, channel_id: "Snowflake", reason: Optional[str] = None
) -> None:
"""Ends a stage instance."""
headers = {"X-Audit-Log-Reason": reason} if reason else None
await self.request(
"DELETE", f"/stage-instances/{channel_id}", custom_headers=headers
)
async def get_voice_regions(self) -> List[Dict[str, Any]]:
"""Returns available voice regions."""
return await self.request("GET", "/voice/regions")

View File

@ -1160,6 +1160,85 @@ class VoiceChannel(Channel):
return f"<VoiceChannel id='{self.id}' name='{self.name}' guild_id='{self.guild_id}'>"
class StageChannel(VoiceChannel):
"""Represents a guild stage channel."""
def __repr__(self) -> str:
return f"<StageChannel id='{self.id}' name='{self.name}' guild_id='{self.guild_id}'>"
async def start_stage_instance(
self,
topic: str,
*,
privacy_level: int = 2,
reason: Optional[str] = None,
guild_scheduled_event_id: Optional[str] = None,
) -> "StageInstance":
if not hasattr(self._client, "_http"):
raise DisagreementException("Client missing HTTP for stage instance")
payload: Dict[str, Any] = {
"channel_id": self.id,
"topic": topic,
"privacy_level": privacy_level,
}
if guild_scheduled_event_id is not None:
payload["guild_scheduled_event_id"] = guild_scheduled_event_id
instance = await self._client._http.start_stage_instance(payload, reason=reason)
instance._client = self._client
return instance
async def edit_stage_instance(
self,
*,
topic: Optional[str] = None,
privacy_level: Optional[int] = None,
reason: Optional[str] = None,
) -> "StageInstance":
if not hasattr(self._client, "_http"):
raise DisagreementException("Client missing HTTP for stage instance")
payload: Dict[str, Any] = {}
if topic is not None:
payload["topic"] = topic
if privacy_level is not None:
payload["privacy_level"] = privacy_level
instance = await self._client._http.edit_stage_instance(
self.id, payload, reason=reason
)
instance._client = self._client
return instance
async def end_stage_instance(self, *, reason: Optional[str] = None) -> None:
if not hasattr(self._client, "_http"):
raise DisagreementException("Client missing HTTP for stage instance")
await self._client._http.end_stage_instance(self.id, reason=reason)
class StageInstance:
"""Represents a stage instance."""
def __init__(
self, data: Dict[str, Any], client_instance: Optional["Client"] = None
) -> None:
self._client = client_instance
self.id: str = data["id"]
self.guild_id: Optional[str] = data.get("guild_id")
self.channel_id: str = data["channel_id"]
self.topic: str = data["topic"]
self.privacy_level: int = data.get("privacy_level", 2)
self.discoverable_disabled: bool = data.get("discoverable_disabled", False)
self.guild_scheduled_event_id: Optional[str] = data.get(
"guild_scheduled_event_id"
)
def __repr__(self) -> str:
return f"<StageInstance id='{self.id}' channel_id='{self.channel_id}'>"
class CategoryChannel(Channel):
"""Represents a guild category channel."""
@ -2100,11 +2179,10 @@ def channel_factory(data: Dict[str, Any], client: "Client") -> Channel:
ChannelType.GUILD_ANNOUNCEMENT.value,
):
return TextChannel(data, client)
if channel_type in (
ChannelType.GUILD_VOICE.value,
ChannelType.GUILD_STAGE_VOICE.value,
):
if channel_type == ChannelType.GUILD_VOICE.value:
return VoiceChannel(data, client)
if channel_type == ChannelType.GUILD_STAGE_VOICE.value:
return StageChannel(data, client)
if channel_type == ChannelType.GUILD_CATEGORY.value:
return CategoryChannel(data, client)
if channel_type in (

View File

@ -0,0 +1,63 @@
import pytest
from types import SimpleNamespace
from unittest.mock import AsyncMock
from disagreement.http import HTTPClient
from disagreement.models import StageChannel, StageInstance
from disagreement.enums import ChannelType
@pytest.mark.asyncio
async def test_http_start_stage_instance_calls_request():
http = HTTPClient(token="t")
http.request = AsyncMock(return_value={"id": "1", "channel_id": "c", "topic": "t"})
payload = {"channel_id": "c", "topic": "t", "privacy_level": 2}
instance = await http.start_stage_instance(payload)
http.request.assert_called_once_with(
"POST",
"/stage-instances",
payload=payload,
custom_headers=None,
)
assert isinstance(instance, StageInstance)
@pytest.mark.asyncio
async def test_http_end_stage_instance_calls_request():
http = HTTPClient(token="t")
http.request = AsyncMock(return_value=None)
await http.end_stage_instance("c")
http.request.assert_called_once_with(
"DELETE", "/stage-instances/c", custom_headers=None
)
@pytest.mark.asyncio
async def test_stage_channel_start_and_end():
http = SimpleNamespace(
start_stage_instance=AsyncMock(
return_value=StageInstance({"id": "1", "channel_id": "c", "topic": "hi"})
),
edit_stage_instance=AsyncMock(
return_value=StageInstance({"id": "1", "channel_id": "c", "topic": "hi"})
),
end_stage_instance=AsyncMock(),
)
client = type("Client", (), {})()
client._http = http
channel_data = {
"id": "c",
"type": ChannelType.GUILD_STAGE_VOICE.value,
"guild_id": "g",
}
channel = StageChannel(channel_data, client)
instance = await channel.start_stage_instance("hi")
http.start_stage_instance.assert_awaited_once_with(
{"channel_id": "c", "topic": "hi", "privacy_level": 2}, reason=None
)
assert isinstance(instance, StageInstance)
assert instance._client is client
await channel.end_stage_instance()
http.end_stage_instance.assert_awaited_once_with("c", reason=None)