Compare commits
2 Commits
36b06c6c7a
...
43ca2dc561
Author | SHA1 | Date | |
---|---|---|---|
43ca2dc561 | |||
7595e33fd1 |
@ -22,7 +22,7 @@ from .http import HTTPClient
|
|||||||
from .gateway import GatewayClient
|
from .gateway import GatewayClient
|
||||||
from .shard_manager import ShardManager
|
from .shard_manager import ShardManager
|
||||||
from .event_dispatcher import EventDispatcher
|
from .event_dispatcher import EventDispatcher
|
||||||
from .enums import GatewayIntent, InteractionType, GatewayOpcode
|
from .enums import GatewayIntent, InteractionType, GatewayOpcode, VoiceRegion
|
||||||
from .errors import DisagreementException, AuthenticationError
|
from .errors import DisagreementException, AuthenticationError
|
||||||
from .typing import Typing
|
from .typing import Typing
|
||||||
from .ext.commands.core import CommandHandler
|
from .ext.commands.core import CommandHandler
|
||||||
@ -1235,6 +1235,20 @@ class Client:
|
|||||||
print(f"Failed to fetch channel {channel_id}: {e}")
|
print(f"Failed to fetch channel {channel_id}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def fetch_voice_regions(self) -> List[VoiceRegion]:
|
||||||
|
"""Fetches available voice regions."""
|
||||||
|
|
||||||
|
if self._closed:
|
||||||
|
raise DisagreementException("Client is closed.")
|
||||||
|
|
||||||
|
data = await self._http.get_voice_regions()
|
||||||
|
regions = []
|
||||||
|
for region in data:
|
||||||
|
region_id = region.get("id")
|
||||||
|
if region_id:
|
||||||
|
regions.append(VoiceRegion(region_id))
|
||||||
|
return regions
|
||||||
|
|
||||||
async def create_webhook(
|
async def create_webhook(
|
||||||
self, channel_id: Snowflake, payload: Dict[str, Any]
|
self, channel_id: Snowflake, payload: Dict[str, Any]
|
||||||
) -> "Webhook":
|
) -> "Webhook":
|
||||||
|
@ -278,6 +278,36 @@ class GuildFeature(str, Enum): # Changed from IntEnum to Enum
|
|||||||
return str(value)
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceRegion(str, Enum):
|
||||||
|
"""Voice region identifier."""
|
||||||
|
|
||||||
|
AMSTERDAM = "amsterdam"
|
||||||
|
BRAZIL = "brazil"
|
||||||
|
DUBAI = "dubai"
|
||||||
|
EU_CENTRAL = "eu-central"
|
||||||
|
EU_WEST = "eu-west"
|
||||||
|
EUROPE = "europe"
|
||||||
|
FRANKFURT = "frankfurt"
|
||||||
|
HONGKONG = "hongkong"
|
||||||
|
INDIA = "india"
|
||||||
|
JAPAN = "japan"
|
||||||
|
RUSSIA = "russia"
|
||||||
|
SINGAPORE = "singapore"
|
||||||
|
SOUTHAFRICA = "southafrica"
|
||||||
|
SOUTH_KOREA = "south-korea"
|
||||||
|
SYDNEY = "sydney"
|
||||||
|
US_CENTRAL = "us-central"
|
||||||
|
US_EAST = "us-east"
|
||||||
|
US_SOUTH = "us-south"
|
||||||
|
US_WEST = "us-west"
|
||||||
|
VIP_US_EAST = "vip-us-east"
|
||||||
|
VIP_US_WEST = "vip-us-west"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _missing_(cls, value): # type: ignore
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
# --- Channel Enums ---
|
# --- Channel Enums ---
|
||||||
|
|
||||||
|
|
||||||
|
@ -344,6 +344,9 @@ class GatewayClient:
|
|||||||
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
|
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
|
||||||
)
|
)
|
||||||
await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
|
await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
|
||||||
|
await self._dispatcher.dispatch(
|
||||||
|
"SHARD_RESUME", {"shard_id": self._shard_id}
|
||||||
|
)
|
||||||
elif event_name:
|
elif event_name:
|
||||||
# For other events, ensure 'd' is a dict, or pass {} if 'd' is null/missing.
|
# For other events, ensure 'd' is a dict, or pass {} if 'd' is null/missing.
|
||||||
# Models/parsers in EventDispatcher will need to handle potentially empty dicts.
|
# Models/parsers in EventDispatcher will need to handle potentially empty dicts.
|
||||||
@ -508,6 +511,10 @@ class GatewayClient:
|
|||||||
self._receive_task.cancel()
|
self._receive_task.cancel()
|
||||||
self._receive_task = self._loop.create_task(self._receive_loop())
|
self._receive_task = self._loop.create_task(self._receive_loop())
|
||||||
|
|
||||||
|
await self._dispatcher.dispatch(
|
||||||
|
"SHARD_CONNECT", {"shard_id": self._shard_id}
|
||||||
|
)
|
||||||
|
|
||||||
except aiohttp.ClientConnectorError as e:
|
except aiohttp.ClientConnectorError as e:
|
||||||
raise GatewayException(
|
raise GatewayException(
|
||||||
f"Failed to connect to Gateway (Connector Error): {e}"
|
f"Failed to connect to Gateway (Connector Error): {e}"
|
||||||
@ -559,6 +566,10 @@ class GatewayClient:
|
|||||||
self._last_sequence = None
|
self._last_sequence = None
|
||||||
self._resume_gateway_url = None # This might be re-fetched anyway
|
self._resume_gateway_url = None # This might be re-fetched anyway
|
||||||
|
|
||||||
|
await self._dispatcher.dispatch(
|
||||||
|
"SHARD_DISCONNECT", {"shard_id": self._shard_id}
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def latency(self) -> Optional[float]:
|
def latency(self) -> Optional[float]:
|
||||||
"""Returns the latency between heartbeat and ACK in seconds."""
|
"""Returns the latency between heartbeat and ACK in seconds."""
|
||||||
|
@ -912,3 +912,7 @@ class HTTPClient:
|
|||||||
async def trigger_typing(self, channel_id: str) -> None:
|
async def trigger_typing(self, channel_id: str) -> None:
|
||||||
"""Sends a typing indicator to the specified channel."""
|
"""Sends a typing indicator to the specified channel."""
|
||||||
await self.request("POST", f"/channels/{channel_id}/typing")
|
await self.request("POST", f"/channels/{channel_id}/typing")
|
||||||
|
|
||||||
|
async def get_voice_regions(self) -> List[Dict[str, Any]]:
|
||||||
|
"""Returns available voice regions."""
|
||||||
|
return await self.request("GET", "/voice/regions")
|
||||||
|
@ -98,3 +98,36 @@ Emitted when a guild role is updated. The callback receives a
|
|||||||
async def on_guild_role_update(event: disagreement.GuildRoleUpdate):
|
async def on_guild_role_update(event: disagreement.GuildRoleUpdate):
|
||||||
...
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## SHARD_CONNECT
|
||||||
|
|
||||||
|
Fired when a shard establishes its gateway connection. The callback receives a
|
||||||
|
dictionary with the shard ID.
|
||||||
|
|
||||||
|
```python
|
||||||
|
@client.event
|
||||||
|
async def on_shard_connect(info: dict):
|
||||||
|
print("shard connected", info["shard_id"])
|
||||||
|
```
|
||||||
|
|
||||||
|
## SHARD_DISCONNECT
|
||||||
|
|
||||||
|
Emitted when a shard's gateway connection is closed. The callback receives a
|
||||||
|
dictionary with the shard ID.
|
||||||
|
|
||||||
|
```python
|
||||||
|
@client.event
|
||||||
|
async def on_shard_disconnect(info: dict):
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## SHARD_RESUME
|
||||||
|
|
||||||
|
Sent when a shard successfully resumes after a reconnect. The callback receives
|
||||||
|
a dictionary with the shard ID.
|
||||||
|
|
||||||
|
```python
|
||||||
|
@client.event
|
||||||
|
async def on_shard_resume(info: dict):
|
||||||
|
...
|
||||||
|
```
|
||||||
|
@ -42,3 +42,14 @@ await vc.play(FFmpegAudioSource("other.mp3"))
|
|||||||
```
|
```
|
||||||
|
|
||||||
Call `await vc.close()` when finished.
|
Call `await vc.close()` when finished.
|
||||||
|
|
||||||
|
## Fetching Available Voice Regions
|
||||||
|
|
||||||
|
Use :meth:`Client.fetch_voice_regions` to list the voice regions that Discord
|
||||||
|
currently offers. The method returns a list of :class:`VoiceRegion` values.
|
||||||
|
|
||||||
|
```python
|
||||||
|
regions = await client.fetch_voice_regions()
|
||||||
|
for region in regions:
|
||||||
|
print(region.value)
|
||||||
|
```
|
||||||
|
@ -3,6 +3,7 @@ from unittest.mock import AsyncMock
|
|||||||
|
|
||||||
from disagreement.shard_manager import ShardManager
|
from disagreement.shard_manager import ShardManager
|
||||||
from disagreement.client import Client, AutoShardedClient
|
from disagreement.client import Client, AutoShardedClient
|
||||||
|
from disagreement.event_dispatcher import EventDispatcher
|
||||||
|
|
||||||
|
|
||||||
class DummyGateway:
|
class DummyGateway:
|
||||||
@ -10,11 +11,27 @@ class DummyGateway:
|
|||||||
self.connect = AsyncMock()
|
self.connect = AsyncMock()
|
||||||
self.close = AsyncMock()
|
self.close = AsyncMock()
|
||||||
|
|
||||||
|
dispatcher = kwargs.get("event_dispatcher")
|
||||||
|
shard_id = kwargs.get("shard_id")
|
||||||
|
|
||||||
|
async def emit_connect():
|
||||||
|
await dispatcher.dispatch("SHARD_CONNECT", {"shard_id": shard_id})
|
||||||
|
|
||||||
|
async def emit_close():
|
||||||
|
await dispatcher.dispatch("SHARD_DISCONNECT", {"shard_id": shard_id})
|
||||||
|
|
||||||
|
async def emit_resume():
|
||||||
|
await dispatcher.dispatch("SHARD_RESUME", {"shard_id": shard_id})
|
||||||
|
|
||||||
|
self.connect.side_effect = emit_connect
|
||||||
|
self.close.side_effect = emit_close
|
||||||
|
self.resume = AsyncMock(side_effect=emit_resume)
|
||||||
|
|
||||||
|
|
||||||
class DummyClient:
|
class DummyClient:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._http = object()
|
self._http = object()
|
||||||
self._event_dispatcher = object()
|
self._event_dispatcher = EventDispatcher(self)
|
||||||
self.token = "t"
|
self.token = "t"
|
||||||
self.intents = 0
|
self.intents = 0
|
||||||
self.verbose = False
|
self.verbose = False
|
||||||
@ -68,3 +85,34 @@ async def test_auto_sharded_client_fetches_count(monkeypatch):
|
|||||||
await c.connect()
|
await c.connect()
|
||||||
dummy_manager.start.assert_awaited_once()
|
dummy_manager.start.assert_awaited_once()
|
||||||
assert c.shard_count == 4
|
assert c.shard_count == 4
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_shard_events_emitted(monkeypatch):
|
||||||
|
monkeypatch.setattr("disagreement.shard_manager.GatewayClient", DummyGateway)
|
||||||
|
|
||||||
|
client = DummyClient()
|
||||||
|
manager = ShardManager(client, shard_count=1)
|
||||||
|
|
||||||
|
events: list[tuple[str, int | None]] = []
|
||||||
|
|
||||||
|
async def on_connect(info):
|
||||||
|
events.append(("connect", info.get("shard_id")))
|
||||||
|
|
||||||
|
async def on_disconnect(info):
|
||||||
|
events.append(("disconnect", info.get("shard_id")))
|
||||||
|
|
||||||
|
async def on_resume(info):
|
||||||
|
events.append(("resume", info.get("shard_id")))
|
||||||
|
|
||||||
|
client._event_dispatcher.register("SHARD_CONNECT", on_connect)
|
||||||
|
client._event_dispatcher.register("SHARD_DISCONNECT", on_disconnect)
|
||||||
|
client._event_dispatcher.register("SHARD_RESUME", on_resume)
|
||||||
|
|
||||||
|
await manager.start()
|
||||||
|
await manager.shards[0].gateway.resume()
|
||||||
|
await manager.close()
|
||||||
|
|
||||||
|
assert ("connect", 0) in events
|
||||||
|
assert ("disconnect", 0) in events
|
||||||
|
assert ("resume", 0) in events
|
||||||
|
Loading…
x
Reference in New Issue
Block a user