Add AutoShardedClient with automatic shard count (#18)

This commit is contained in:
Slipstream 2025-06-10 15:45:01 -06:00 committed by GitHub
parent 0eed122f02
commit 1c2241c9c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 62 additions and 3 deletions

View File

@ -117,13 +117,20 @@ roles = await client.fetch_roles(guild.id)
## Sharding ## Sharding
To run your bot across multiple gateway shards, pass `shard_count` when creating To run your bot across multiple gateway shards, pass ``shard_count`` when creating
the client: the client:
```python ```python
client = disagreement.Client(token=BOT_TOKEN, shard_count=2) client = disagreement.Client(token=BOT_TOKEN, shard_count=2)
``` ```
If you want the library to determine the recommended shard count automatically,
use ``AutoShardedClient``:
```python
client = disagreement.AutoShardedClient(token=BOT_TOKEN)
```
See `examples/sharded_bot.py` for a full example. See `examples/sharded_bot.py` for a full example.
## Contributing ## Contributing

View File

@ -16,7 +16,7 @@ __license__ = "BSD 3-Clause License"
__copyright__ = "Copyright 2025 Slipstream" __copyright__ = "Copyright 2025 Slipstream"
__version__ = "0.0.2" __version__ = "0.0.2"
from .client import Client from .client import Client, AutoShardedClient
from .models import Message, User, Reaction from .models import Message, User, Reaction
from .voice_client import VoiceClient from .voice_client import VoiceClient
from .audio import AudioSource, FFmpegAudioSource from .audio import AudioSource, FFmpegAudioSource

View File

@ -1287,3 +1287,19 @@ class Client:
print(f"Unhandled exception in event listener for '{event_method}':") print(f"Unhandled exception in event listener for '{event_method}':")
print(f"{type(exc).__name__}: {exc}") print(f"{type(exc).__name__}: {exc}")
class AutoShardedClient(Client):
"""A :class:`Client` that automatically determines the shard count.
If ``shard_count`` is not provided, the client will query the Discord API
via :meth:`HTTPClient.get_gateway_bot` for the recommended shard count and
use that when connecting.
"""
async def connect(self, reconnect: bool = True) -> None: # type: ignore[override]
if self.shard_count is None:
data = await self._http.get_gateway_bot()
self.shard_count = data.get("shards", 1)
await super().connect(reconnect=reconnect)

20
docs/sharding.md Normal file
View File

@ -0,0 +1,20 @@
# Sharding
`disagreement` supports splitting your gateway connection across multiple shards.
Use `Client` with the `shard_count` parameter when you want to control the count
manually.
`AutoShardedClient` asks Discord for the recommended number of shards at runtime
and configures the `ShardManager` automatically.
```python
import asyncio
import disagreement
bot = disagreement.AutoShardedClient(token="YOUR_TOKEN")
async def main():
await bot.run()
asyncio.run(main())
```

View File

@ -2,7 +2,7 @@ import pytest
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
from disagreement.shard_manager import ShardManager from disagreement.shard_manager import ShardManager
from disagreement.client import Client from disagreement.client import Client, AutoShardedClient
class DummyGateway: class DummyGateway:
@ -50,3 +50,19 @@ async def test_client_uses_shard_manager(monkeypatch):
monkeypatch.setattr(c, "wait_until_ready", AsyncMock()) monkeypatch.setattr(c, "wait_until_ready", AsyncMock())
await c.connect() await c.connect()
dummy_manager.start.assert_awaited_once() dummy_manager.start.assert_awaited_once()
@pytest.mark.asyncio
async def test_auto_sharded_client_fetches_count(monkeypatch):
class DummyHTTP:
async def get_gateway_bot(self):
return {"shards": 4}
dummy_manager = AsyncMock()
monkeypatch.setattr("disagreement.client.ShardManager", lambda c, n: dummy_manager)
c = AutoShardedClient(token="x")
c._http = DummyHTTP()
monkeypatch.setattr(c, "wait_until_ready", AsyncMock())
await c.connect()
dummy_manager.start.assert_awaited_once()
assert c.shard_count == 4