diff --git a/README.md b/README.md index 7021d83..9fafb37 100644 --- a/README.md +++ b/README.md @@ -117,13 +117,20 @@ roles = await client.fetch_roles(guild.id) ## Sharding -To run your bot across multiple gateway shards, pass `shard_count` when creating +To run your bot across multiple gateway shards, pass ``shard_count`` when creating the client: ```python client = disagreement.Client(token=BOT_TOKEN, shard_count=2) ``` +If you want the library to determine the recommended shard count automatically, +use ``AutoShardedClient``: + +```python +client = disagreement.AutoShardedClient(token=BOT_TOKEN) +``` + See `examples/sharded_bot.py` for a full example. ## Contributing diff --git a/disagreement/__init__.py b/disagreement/__init__.py index f0eca2b..86f385c 100644 --- a/disagreement/__init__.py +++ b/disagreement/__init__.py @@ -16,7 +16,7 @@ __license__ = "BSD 3-Clause License" __copyright__ = "Copyright 2025 Slipstream" __version__ = "0.0.2" -from .client import Client +from .client import Client, AutoShardedClient from .models import Message, User, Reaction from .voice_client import VoiceClient from .audio import AudioSource, FFmpegAudioSource diff --git a/disagreement/client.py b/disagreement/client.py index b3284fd..87d35f3 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -1287,3 +1287,19 @@ class Client: print(f"Unhandled exception in event listener for '{event_method}':") print(f"{type(exc).__name__}: {exc}") + + +class AutoShardedClient(Client): + """A :class:`Client` that automatically determines the shard count. + + If ``shard_count`` is not provided, the client will query the Discord API + via :meth:`HTTPClient.get_gateway_bot` for the recommended shard count and + use that when connecting. + """ + + async def connect(self, reconnect: bool = True) -> None: # type: ignore[override] + if self.shard_count is None: + data = await self._http.get_gateway_bot() + self.shard_count = data.get("shards", 1) + + await super().connect(reconnect=reconnect) diff --git a/docs/sharding.md b/docs/sharding.md new file mode 100644 index 0000000..aafd933 --- /dev/null +++ b/docs/sharding.md @@ -0,0 +1,20 @@ +# Sharding + +`disagreement` supports splitting your gateway connection across multiple shards. +Use `Client` with the `shard_count` parameter when you want to control the count +manually. + +`AutoShardedClient` asks Discord for the recommended number of shards at runtime +and configures the `ShardManager` automatically. + +```python +import asyncio +import disagreement + +bot = disagreement.AutoShardedClient(token="YOUR_TOKEN") + +async def main(): + await bot.run() + +asyncio.run(main()) +``` diff --git a/tests/test_sharding.py b/tests/test_sharding.py index d69d209..c33ab49 100644 --- a/tests/test_sharding.py +++ b/tests/test_sharding.py @@ -2,7 +2,7 @@ import pytest from unittest.mock import AsyncMock from disagreement.shard_manager import ShardManager -from disagreement.client import Client +from disagreement.client import Client, AutoShardedClient class DummyGateway: @@ -50,3 +50,19 @@ async def test_client_uses_shard_manager(monkeypatch): monkeypatch.setattr(c, "wait_until_ready", AsyncMock()) await c.connect() dummy_manager.start.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_auto_sharded_client_fetches_count(monkeypatch): + class DummyHTTP: + async def get_gateway_bot(self): + return {"shards": 4} + + dummy_manager = AsyncMock() + monkeypatch.setattr("disagreement.client.ShardManager", lambda c, n: dummy_manager) + c = AutoShardedClient(token="x") + c._http = DummyHTTP() + monkeypatch.setattr(c, "wait_until_ready", AsyncMock()) + await c.connect() + dummy_manager.start.assert_awaited_once() + assert c.shard_count == 4