Add stage channel and instance support (#43)
This commit is contained in:
parent
3059041ba8
commit
2d6c2cb0be
@ -335,6 +335,13 @@ class ChannelType(IntEnum):
|
||||
GUILD_MEDIA = 16 # (Still in development) a channel that can only contain media
|
||||
|
||||
|
||||
class StageInstancePrivacyLevel(IntEnum):
|
||||
"""Privacy level of a stage instance."""
|
||||
|
||||
PUBLIC = 1
|
||||
GUILD_ONLY = 2
|
||||
|
||||
|
||||
class OverwriteType(IntEnum):
|
||||
"""Type of target for a permission overwrite."""
|
||||
|
||||
|
@ -23,7 +23,7 @@ from .interactions import InteractionResponsePayload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import Client
|
||||
from .models import Message, Webhook, File, Invite
|
||||
from .models import Message, Webhook, File, StageInstance, Invite
|
||||
from .interactions import ApplicationCommand, Snowflake
|
||||
|
||||
# Discord API constants
|
||||
@ -924,6 +924,48 @@ class HTTPClient:
|
||||
"""Sends a typing indicator to the specified channel."""
|
||||
await self.request("POST", f"/channels/{channel_id}/typing")
|
||||
|
||||
async def start_stage_instance(
|
||||
self, payload: Dict[str, Any], reason: Optional[str] = None
|
||||
) -> "StageInstance":
|
||||
"""Starts a stage instance."""
|
||||
|
||||
headers = {"X-Audit-Log-Reason": reason} if reason else None
|
||||
data = await self.request(
|
||||
"POST", "/stage-instances", payload=payload, custom_headers=headers
|
||||
)
|
||||
from .models import StageInstance
|
||||
|
||||
return StageInstance(data)
|
||||
|
||||
async def edit_stage_instance(
|
||||
self,
|
||||
channel_id: "Snowflake",
|
||||
payload: Dict[str, Any],
|
||||
reason: Optional[str] = None,
|
||||
) -> "StageInstance":
|
||||
"""Edits an existing stage instance."""
|
||||
|
||||
headers = {"X-Audit-Log-Reason": reason} if reason else None
|
||||
data = await self.request(
|
||||
"PATCH",
|
||||
f"/stage-instances/{channel_id}",
|
||||
payload=payload,
|
||||
custom_headers=headers,
|
||||
)
|
||||
from .models import StageInstance
|
||||
|
||||
return StageInstance(data)
|
||||
|
||||
async def end_stage_instance(
|
||||
self, channel_id: "Snowflake", reason: Optional[str] = None
|
||||
) -> None:
|
||||
"""Ends a stage instance."""
|
||||
|
||||
headers = {"X-Audit-Log-Reason": reason} if reason else None
|
||||
await self.request(
|
||||
"DELETE", f"/stage-instances/{channel_id}", custom_headers=headers
|
||||
)
|
||||
|
||||
async def get_voice_regions(self) -> List[Dict[str, Any]]:
|
||||
"""Returns available voice regions."""
|
||||
return await self.request("GET", "/voice/regions")
|
||||
|
@ -1160,6 +1160,85 @@ class VoiceChannel(Channel):
|
||||
return f"<VoiceChannel id='{self.id}' name='{self.name}' guild_id='{self.guild_id}'>"
|
||||
|
||||
|
||||
class StageChannel(VoiceChannel):
|
||||
"""Represents a guild stage channel."""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<StageChannel id='{self.id}' name='{self.name}' guild_id='{self.guild_id}'>"
|
||||
|
||||
async def start_stage_instance(
|
||||
self,
|
||||
topic: str,
|
||||
*,
|
||||
privacy_level: int = 2,
|
||||
reason: Optional[str] = None,
|
||||
guild_scheduled_event_id: Optional[str] = None,
|
||||
) -> "StageInstance":
|
||||
if not hasattr(self._client, "_http"):
|
||||
raise DisagreementException("Client missing HTTP for stage instance")
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"channel_id": self.id,
|
||||
"topic": topic,
|
||||
"privacy_level": privacy_level,
|
||||
}
|
||||
if guild_scheduled_event_id is not None:
|
||||
payload["guild_scheduled_event_id"] = guild_scheduled_event_id
|
||||
|
||||
instance = await self._client._http.start_stage_instance(payload, reason=reason)
|
||||
instance._client = self._client
|
||||
return instance
|
||||
|
||||
async def edit_stage_instance(
|
||||
self,
|
||||
*,
|
||||
topic: Optional[str] = None,
|
||||
privacy_level: Optional[int] = None,
|
||||
reason: Optional[str] = None,
|
||||
) -> "StageInstance":
|
||||
if not hasattr(self._client, "_http"):
|
||||
raise DisagreementException("Client missing HTTP for stage instance")
|
||||
|
||||
payload: Dict[str, Any] = {}
|
||||
if topic is not None:
|
||||
payload["topic"] = topic
|
||||
if privacy_level is not None:
|
||||
payload["privacy_level"] = privacy_level
|
||||
|
||||
instance = await self._client._http.edit_stage_instance(
|
||||
self.id, payload, reason=reason
|
||||
)
|
||||
instance._client = self._client
|
||||
return instance
|
||||
|
||||
async def end_stage_instance(self, *, reason: Optional[str] = None) -> None:
|
||||
if not hasattr(self._client, "_http"):
|
||||
raise DisagreementException("Client missing HTTP for stage instance")
|
||||
|
||||
await self._client._http.end_stage_instance(self.id, reason=reason)
|
||||
|
||||
|
||||
class StageInstance:
|
||||
"""Represents a stage instance."""
|
||||
|
||||
def __init__(
|
||||
self, data: Dict[str, Any], client_instance: Optional["Client"] = None
|
||||
) -> None:
|
||||
self._client = client_instance
|
||||
self.id: str = data["id"]
|
||||
self.guild_id: Optional[str] = data.get("guild_id")
|
||||
self.channel_id: str = data["channel_id"]
|
||||
self.topic: str = data["topic"]
|
||||
self.privacy_level: int = data.get("privacy_level", 2)
|
||||
self.discoverable_disabled: bool = data.get("discoverable_disabled", False)
|
||||
self.guild_scheduled_event_id: Optional[str] = data.get(
|
||||
"guild_scheduled_event_id"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<StageInstance id='{self.id}' channel_id='{self.channel_id}'>"
|
||||
|
||||
|
||||
class CategoryChannel(Channel):
|
||||
"""Represents a guild category channel."""
|
||||
|
||||
@ -2100,11 +2179,10 @@ def channel_factory(data: Dict[str, Any], client: "Client") -> Channel:
|
||||
ChannelType.GUILD_ANNOUNCEMENT.value,
|
||||
):
|
||||
return TextChannel(data, client)
|
||||
if channel_type in (
|
||||
ChannelType.GUILD_VOICE.value,
|
||||
ChannelType.GUILD_STAGE_VOICE.value,
|
||||
):
|
||||
if channel_type == ChannelType.GUILD_VOICE.value:
|
||||
return VoiceChannel(data, client)
|
||||
if channel_type == ChannelType.GUILD_STAGE_VOICE.value:
|
||||
return StageChannel(data, client)
|
||||
if channel_type == ChannelType.GUILD_CATEGORY.value:
|
||||
return CategoryChannel(data, client)
|
||||
if channel_type in (
|
||||
|
63
tests/test_stage_instance.py
Normal file
63
tests/test_stage_instance.py
Normal file
@ -0,0 +1,63 @@
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from disagreement.http import HTTPClient
|
||||
from disagreement.models import StageChannel, StageInstance
|
||||
from disagreement.enums import ChannelType
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_start_stage_instance_calls_request():
|
||||
http = HTTPClient(token="t")
|
||||
http.request = AsyncMock(return_value={"id": "1", "channel_id": "c", "topic": "t"})
|
||||
payload = {"channel_id": "c", "topic": "t", "privacy_level": 2}
|
||||
instance = await http.start_stage_instance(payload)
|
||||
http.request.assert_called_once_with(
|
||||
"POST",
|
||||
"/stage-instances",
|
||||
payload=payload,
|
||||
custom_headers=None,
|
||||
)
|
||||
assert isinstance(instance, StageInstance)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_end_stage_instance_calls_request():
|
||||
http = HTTPClient(token="t")
|
||||
http.request = AsyncMock(return_value=None)
|
||||
await http.end_stage_instance("c")
|
||||
http.request.assert_called_once_with(
|
||||
"DELETE", "/stage-instances/c", custom_headers=None
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stage_channel_start_and_end():
|
||||
http = SimpleNamespace(
|
||||
start_stage_instance=AsyncMock(
|
||||
return_value=StageInstance({"id": "1", "channel_id": "c", "topic": "hi"})
|
||||
),
|
||||
edit_stage_instance=AsyncMock(
|
||||
return_value=StageInstance({"id": "1", "channel_id": "c", "topic": "hi"})
|
||||
),
|
||||
end_stage_instance=AsyncMock(),
|
||||
)
|
||||
client = type("Client", (), {})()
|
||||
client._http = http
|
||||
channel_data = {
|
||||
"id": "c",
|
||||
"type": ChannelType.GUILD_STAGE_VOICE.value,
|
||||
"guild_id": "g",
|
||||
}
|
||||
channel = StageChannel(channel_data, client)
|
||||
|
||||
instance = await channel.start_stage_instance("hi")
|
||||
http.start_stage_instance.assert_awaited_once_with(
|
||||
{"channel_id": "c", "topic": "hi", "privacy_level": 2}, reason=None
|
||||
)
|
||||
assert isinstance(instance, StageInstance)
|
||||
assert instance._client is client
|
||||
|
||||
await channel.end_stage_instance()
|
||||
http.end_stage_instance.assert_awaited_once_with("c", reason=None)
|
Loading…
x
Reference in New Issue
Block a user