Add AutoShardedClient with automatic shard count (#18)
This commit is contained in:
parent
0eed122f02
commit
1c2241c9c4
@ -117,13 +117,20 @@ roles = await client.fetch_roles(guild.id)
|
||||
|
||||
## Sharding
|
||||
|
||||
To run your bot across multiple gateway shards, pass `shard_count` when creating
|
||||
To run your bot across multiple gateway shards, pass ``shard_count`` when creating
|
||||
the client:
|
||||
|
||||
```python
|
||||
client = disagreement.Client(token=BOT_TOKEN, shard_count=2)
|
||||
```
|
||||
|
||||
If you want the library to determine the recommended shard count automatically,
|
||||
use ``AutoShardedClient``:
|
||||
|
||||
```python
|
||||
client = disagreement.AutoShardedClient(token=BOT_TOKEN)
|
||||
```
|
||||
|
||||
See `examples/sharded_bot.py` for a full example.
|
||||
|
||||
## Contributing
|
||||
|
@ -16,7 +16,7 @@ __license__ = "BSD 3-Clause License"
|
||||
__copyright__ = "Copyright 2025 Slipstream"
|
||||
__version__ = "0.0.2"
|
||||
|
||||
from .client import Client
|
||||
from .client import Client, AutoShardedClient
|
||||
from .models import Message, User, Reaction
|
||||
from .voice_client import VoiceClient
|
||||
from .audio import AudioSource, FFmpegAudioSource
|
||||
|
@ -1287,3 +1287,19 @@ class Client:
|
||||
|
||||
print(f"Unhandled exception in event listener for '{event_method}':")
|
||||
print(f"{type(exc).__name__}: {exc}")
|
||||
|
||||
|
||||
class AutoShardedClient(Client):
|
||||
"""A :class:`Client` that automatically determines the shard count.
|
||||
|
||||
If ``shard_count`` is not provided, the client will query the Discord API
|
||||
via :meth:`HTTPClient.get_gateway_bot` for the recommended shard count and
|
||||
use that when connecting.
|
||||
"""
|
||||
|
||||
async def connect(self, reconnect: bool = True) -> None: # type: ignore[override]
|
||||
if self.shard_count is None:
|
||||
data = await self._http.get_gateway_bot()
|
||||
self.shard_count = data.get("shards", 1)
|
||||
|
||||
await super().connect(reconnect=reconnect)
|
||||
|
20
docs/sharding.md
Normal file
20
docs/sharding.md
Normal file
@ -0,0 +1,20 @@
|
||||
# Sharding
|
||||
|
||||
`disagreement` supports splitting your gateway connection across multiple shards.
|
||||
Use `Client` with the `shard_count` parameter when you want to control the count
|
||||
manually.
|
||||
|
||||
`AutoShardedClient` asks Discord for the recommended number of shards at runtime
|
||||
and configures the `ShardManager` automatically.
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
import disagreement
|
||||
|
||||
bot = disagreement.AutoShardedClient(token="YOUR_TOKEN")
|
||||
|
||||
async def main():
|
||||
await bot.run()
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
@ -2,7 +2,7 @@ import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from disagreement.shard_manager import ShardManager
|
||||
from disagreement.client import Client
|
||||
from disagreement.client import Client, AutoShardedClient
|
||||
|
||||
|
||||
class DummyGateway:
|
||||
@ -50,3 +50,19 @@ async def test_client_uses_shard_manager(monkeypatch):
|
||||
monkeypatch.setattr(c, "wait_until_ready", AsyncMock())
|
||||
await c.connect()
|
||||
dummy_manager.start.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_sharded_client_fetches_count(monkeypatch):
|
||||
class DummyHTTP:
|
||||
async def get_gateway_bot(self):
|
||||
return {"shards": 4}
|
||||
|
||||
dummy_manager = AsyncMock()
|
||||
monkeypatch.setattr("disagreement.client.ShardManager", lambda c, n: dummy_manager)
|
||||
c = AutoShardedClient(token="x")
|
||||
c._http = DummyHTTP()
|
||||
monkeypatch.setattr(c, "wait_until_ready", AsyncMock())
|
||||
await c.connect()
|
||||
dummy_manager.start.assert_awaited_once()
|
||||
assert c.shard_count == 4
|
||||
|
Loading…
x
Reference in New Issue
Block a user