Add AutoShardedClient with automatic shard count (#18)
This commit is contained in:
parent
0eed122f02
commit
1c2241c9c4
@ -117,13 +117,20 @@ roles = await client.fetch_roles(guild.id)
|
|||||||
|
|
||||||
## Sharding
|
## Sharding
|
||||||
|
|
||||||
To run your bot across multiple gateway shards, pass `shard_count` when creating
|
To run your bot across multiple gateway shards, pass ``shard_count`` when creating
|
||||||
the client:
|
the client:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
client = disagreement.Client(token=BOT_TOKEN, shard_count=2)
|
client = disagreement.Client(token=BOT_TOKEN, shard_count=2)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
If you want the library to determine the recommended shard count automatically,
|
||||||
|
use ``AutoShardedClient``:
|
||||||
|
|
||||||
|
```python
|
||||||
|
client = disagreement.AutoShardedClient(token=BOT_TOKEN)
|
||||||
|
```
|
||||||
|
|
||||||
See `examples/sharded_bot.py` for a full example.
|
See `examples/sharded_bot.py` for a full example.
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
@ -16,7 +16,7 @@ __license__ = "BSD 3-Clause License"
|
|||||||
__copyright__ = "Copyright 2025 Slipstream"
|
__copyright__ = "Copyright 2025 Slipstream"
|
||||||
__version__ = "0.0.2"
|
__version__ = "0.0.2"
|
||||||
|
|
||||||
from .client import Client
|
from .client import Client, AutoShardedClient
|
||||||
from .models import Message, User, Reaction
|
from .models import Message, User, Reaction
|
||||||
from .voice_client import VoiceClient
|
from .voice_client import VoiceClient
|
||||||
from .audio import AudioSource, FFmpegAudioSource
|
from .audio import AudioSource, FFmpegAudioSource
|
||||||
|
@ -1287,3 +1287,19 @@ class Client:
|
|||||||
|
|
||||||
print(f"Unhandled exception in event listener for '{event_method}':")
|
print(f"Unhandled exception in event listener for '{event_method}':")
|
||||||
print(f"{type(exc).__name__}: {exc}")
|
print(f"{type(exc).__name__}: {exc}")
|
||||||
|
|
||||||
|
|
||||||
|
class AutoShardedClient(Client):
|
||||||
|
"""A :class:`Client` that automatically determines the shard count.
|
||||||
|
|
||||||
|
If ``shard_count`` is not provided, the client will query the Discord API
|
||||||
|
via :meth:`HTTPClient.get_gateway_bot` for the recommended shard count and
|
||||||
|
use that when connecting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def connect(self, reconnect: bool = True) -> None: # type: ignore[override]
|
||||||
|
if self.shard_count is None:
|
||||||
|
data = await self._http.get_gateway_bot()
|
||||||
|
self.shard_count = data.get("shards", 1)
|
||||||
|
|
||||||
|
await super().connect(reconnect=reconnect)
|
||||||
|
20
docs/sharding.md
Normal file
20
docs/sharding.md
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# Sharding
|
||||||
|
|
||||||
|
`disagreement` supports splitting your gateway connection across multiple shards.
|
||||||
|
Use `Client` with the `shard_count` parameter when you want to control the count
|
||||||
|
manually.
|
||||||
|
|
||||||
|
`AutoShardedClient` asks Discord for the recommended number of shards at runtime
|
||||||
|
and configures the `ShardManager` automatically.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import asyncio
|
||||||
|
import disagreement
|
||||||
|
|
||||||
|
bot = disagreement.AutoShardedClient(token="YOUR_TOKEN")
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
await bot.run()
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
|
```
|
@ -2,7 +2,7 @@ import pytest
|
|||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
from disagreement.shard_manager import ShardManager
|
from disagreement.shard_manager import ShardManager
|
||||||
from disagreement.client import Client
|
from disagreement.client import Client, AutoShardedClient
|
||||||
|
|
||||||
|
|
||||||
class DummyGateway:
|
class DummyGateway:
|
||||||
@ -50,3 +50,19 @@ async def test_client_uses_shard_manager(monkeypatch):
|
|||||||
monkeypatch.setattr(c, "wait_until_ready", AsyncMock())
|
monkeypatch.setattr(c, "wait_until_ready", AsyncMock())
|
||||||
await c.connect()
|
await c.connect()
|
||||||
dummy_manager.start.assert_awaited_once()
|
dummy_manager.start.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_sharded_client_fetches_count(monkeypatch):
|
||||||
|
class DummyHTTP:
|
||||||
|
async def get_gateway_bot(self):
|
||||||
|
return {"shards": 4}
|
||||||
|
|
||||||
|
dummy_manager = AsyncMock()
|
||||||
|
monkeypatch.setattr("disagreement.client.ShardManager", lambda c, n: dummy_manager)
|
||||||
|
c = AutoShardedClient(token="x")
|
||||||
|
c._http = DummyHTTP()
|
||||||
|
monkeypatch.setattr(c, "wait_until_ready", AsyncMock())
|
||||||
|
await c.connect()
|
||||||
|
dummy_manager.start.assert_awaited_once()
|
||||||
|
assert c.shard_count == 4
|
||||||
|
Loading…
x
Reference in New Issue
Block a user