Compare commits

...

2 Commits

Author SHA1 Message Date
43ca2dc561
Add voice region support (#44)
Squash and merge PR #44
2025-06-10 20:50:27 -06:00
7595e33fd1
Emit shard lifecycle events (#45)
Squash and merge PR #45
2025-06-10 20:50:25 -06:00
7 changed files with 153 additions and 2 deletions

View File

@ -22,7 +22,7 @@ from .http import HTTPClient
from .gateway import GatewayClient from .gateway import GatewayClient
from .shard_manager import ShardManager from .shard_manager import ShardManager
from .event_dispatcher import EventDispatcher from .event_dispatcher import EventDispatcher
from .enums import GatewayIntent, InteractionType, GatewayOpcode from .enums import GatewayIntent, InteractionType, GatewayOpcode, VoiceRegion
from .errors import DisagreementException, AuthenticationError from .errors import DisagreementException, AuthenticationError
from .typing import Typing from .typing import Typing
from .ext.commands.core import CommandHandler from .ext.commands.core import CommandHandler
@ -1235,6 +1235,20 @@ class Client:
print(f"Failed to fetch channel {channel_id}: {e}") print(f"Failed to fetch channel {channel_id}: {e}")
return None return None
async def fetch_voice_regions(self) -> List[VoiceRegion]:
"""Fetches available voice regions."""
if self._closed:
raise DisagreementException("Client is closed.")
data = await self._http.get_voice_regions()
regions = []
for region in data:
region_id = region.get("id")
if region_id:
regions.append(VoiceRegion(region_id))
return regions
async def create_webhook( async def create_webhook(
self, channel_id: Snowflake, payload: Dict[str, Any] self, channel_id: Snowflake, payload: Dict[str, Any]
) -> "Webhook": ) -> "Webhook":

View File

@ -278,6 +278,36 @@ class GuildFeature(str, Enum): # Changed from IntEnum to Enum
return str(value) return str(value)
class VoiceRegion(str, Enum):
"""Voice region identifier."""
AMSTERDAM = "amsterdam"
BRAZIL = "brazil"
DUBAI = "dubai"
EU_CENTRAL = "eu-central"
EU_WEST = "eu-west"
EUROPE = "europe"
FRANKFURT = "frankfurt"
HONGKONG = "hongkong"
INDIA = "india"
JAPAN = "japan"
RUSSIA = "russia"
SINGAPORE = "singapore"
SOUTHAFRICA = "southafrica"
SOUTH_KOREA = "south-korea"
SYDNEY = "sydney"
US_CENTRAL = "us-central"
US_EAST = "us-east"
US_SOUTH = "us-south"
US_WEST = "us-west"
VIP_US_EAST = "vip-us-east"
VIP_US_WEST = "vip-us-west"
@classmethod
def _missing_(cls, value): # type: ignore
return str(value)
# --- Channel Enums --- # --- Channel Enums ---

View File

@ -344,6 +344,9 @@ class GatewayClient:
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {} raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
) )
await self._dispatcher.dispatch(event_name, event_data_to_dispatch) await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
await self._dispatcher.dispatch(
"SHARD_RESUME", {"shard_id": self._shard_id}
)
elif event_name: elif event_name:
# For other events, ensure 'd' is a dict, or pass {} if 'd' is null/missing. # For other events, ensure 'd' is a dict, or pass {} if 'd' is null/missing.
# Models/parsers in EventDispatcher will need to handle potentially empty dicts. # Models/parsers in EventDispatcher will need to handle potentially empty dicts.
@ -508,6 +511,10 @@ class GatewayClient:
self._receive_task.cancel() self._receive_task.cancel()
self._receive_task = self._loop.create_task(self._receive_loop()) self._receive_task = self._loop.create_task(self._receive_loop())
await self._dispatcher.dispatch(
"SHARD_CONNECT", {"shard_id": self._shard_id}
)
except aiohttp.ClientConnectorError as e: except aiohttp.ClientConnectorError as e:
raise GatewayException( raise GatewayException(
f"Failed to connect to Gateway (Connector Error): {e}" f"Failed to connect to Gateway (Connector Error): {e}"
@ -559,6 +566,10 @@ class GatewayClient:
self._last_sequence = None self._last_sequence = None
self._resume_gateway_url = None # This might be re-fetched anyway self._resume_gateway_url = None # This might be re-fetched anyway
await self._dispatcher.dispatch(
"SHARD_DISCONNECT", {"shard_id": self._shard_id}
)
@property @property
def latency(self) -> Optional[float]: def latency(self) -> Optional[float]:
"""Returns the latency between heartbeat and ACK in seconds.""" """Returns the latency between heartbeat and ACK in seconds."""

View File

@ -912,3 +912,7 @@ class HTTPClient:
async def trigger_typing(self, channel_id: str) -> None: async def trigger_typing(self, channel_id: str) -> None:
"""Sends a typing indicator to the specified channel.""" """Sends a typing indicator to the specified channel."""
await self.request("POST", f"/channels/{channel_id}/typing") await self.request("POST", f"/channels/{channel_id}/typing")
async def get_voice_regions(self) -> List[Dict[str, Any]]:
"""Returns available voice regions."""
return await self.request("GET", "/voice/regions")

View File

@ -98,3 +98,36 @@ Emitted when a guild role is updated. The callback receives a
async def on_guild_role_update(event: disagreement.GuildRoleUpdate): async def on_guild_role_update(event: disagreement.GuildRoleUpdate):
... ...
``` ```
## SHARD_CONNECT
Fired when a shard establishes its gateway connection. The callback receives a
dictionary with the shard ID.
```python
@client.event
async def on_shard_connect(info: dict):
print("shard connected", info["shard_id"])
```
## SHARD_DISCONNECT
Emitted when a shard's gateway connection is closed. The callback receives a
dictionary with the shard ID.
```python
@client.event
async def on_shard_disconnect(info: dict):
...
```
## SHARD_RESUME
Sent when a shard successfully resumes after a reconnect. The callback receives
a dictionary with the shard ID.
```python
@client.event
async def on_shard_resume(info: dict):
...
```

View File

@ -42,3 +42,14 @@ await vc.play(FFmpegAudioSource("other.mp3"))
``` ```
Call `await vc.close()` when finished. Call `await vc.close()` when finished.
## Fetching Available Voice Regions
Use :meth:`Client.fetch_voice_regions` to list the voice regions that Discord
currently offers. The method returns a list of :class:`VoiceRegion` values.
```python
regions = await client.fetch_voice_regions()
for region in regions:
print(region.value)
```

View File

@ -3,6 +3,7 @@ from unittest.mock import AsyncMock
from disagreement.shard_manager import ShardManager from disagreement.shard_manager import ShardManager
from disagreement.client import Client, AutoShardedClient from disagreement.client import Client, AutoShardedClient
from disagreement.event_dispatcher import EventDispatcher
class DummyGateway: class DummyGateway:
@ -10,11 +11,27 @@ class DummyGateway:
self.connect = AsyncMock() self.connect = AsyncMock()
self.close = AsyncMock() self.close = AsyncMock()
dispatcher = kwargs.get("event_dispatcher")
shard_id = kwargs.get("shard_id")
async def emit_connect():
await dispatcher.dispatch("SHARD_CONNECT", {"shard_id": shard_id})
async def emit_close():
await dispatcher.dispatch("SHARD_DISCONNECT", {"shard_id": shard_id})
async def emit_resume():
await dispatcher.dispatch("SHARD_RESUME", {"shard_id": shard_id})
self.connect.side_effect = emit_connect
self.close.side_effect = emit_close
self.resume = AsyncMock(side_effect=emit_resume)
class DummyClient: class DummyClient:
def __init__(self): def __init__(self):
self._http = object() self._http = object()
self._event_dispatcher = object() self._event_dispatcher = EventDispatcher(self)
self.token = "t" self.token = "t"
self.intents = 0 self.intents = 0
self.verbose = False self.verbose = False
@ -68,3 +85,34 @@ async def test_auto_sharded_client_fetches_count(monkeypatch):
await c.connect() await c.connect()
dummy_manager.start.assert_awaited_once() dummy_manager.start.assert_awaited_once()
assert c.shard_count == 4 assert c.shard_count == 4
@pytest.mark.asyncio
async def test_shard_events_emitted(monkeypatch):
monkeypatch.setattr("disagreement.shard_manager.GatewayClient", DummyGateway)
client = DummyClient()
manager = ShardManager(client, shard_count=1)
events: list[tuple[str, int | None]] = []
async def on_connect(info):
events.append(("connect", info.get("shard_id")))
async def on_disconnect(info):
events.append(("disconnect", info.get("shard_id")))
async def on_resume(info):
events.append(("resume", info.get("shard_id")))
client._event_dispatcher.register("SHARD_CONNECT", on_connect)
client._event_dispatcher.register("SHARD_DISCONNECT", on_disconnect)
client._event_dispatcher.register("SHARD_RESUME", on_resume)
await manager.start()
await manager.shards[0].gateway.resume()
await manager.close()
assert ("connect", 0) in events
assert ("disconnect", 0) in events
assert ("resume", 0) in events