From 775dce0c800d44f03a7a80fbbbfeaf90e36ed96e Mon Sep 17 00:00:00 2001 From: Slipstream Date: Sun, 15 Jun 2025 15:19:58 -0600 Subject: [PATCH] Store shard id on guild and expose shard property (#87) --- disagreement/client.py | 13 +++---- disagreement/gateway.py | 6 ++++ disagreement/models.py | 37 +++++++++++++++++--- tests/test_guild_shard_property.py | 56 ++++++++++++++++++++++++++++++ 4 files changed, 102 insertions(+), 10 deletions(-) create mode 100644 tests/test_guild_shard_property.py diff --git a/disagreement/client.py b/disagreement/client.py index 49c5db8..fc58f54 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -926,12 +926,13 @@ class Client: guild.roles.append(role) return role - def parse_guild(self, data: Dict[str, Any]) -> "Guild": - """Parses guild data and returns a Guild object, updating cache.""" - from .models import Guild - - guild = Guild(data, client_instance=self) - self._guilds.set(guild.id, guild) + def parse_guild(self, data: Dict[str, Any]) -> "Guild": + """Parses guild data and returns a :class:`Guild` object, updating cache.""" + from .models import Guild + + shard_id = data.get("shard_id") + guild = Guild(data, client_instance=self, shard_id=shard_id) + self._guilds.set(guild.id, guild) presences = {p["user"]["id"]: p for p in data.get("presences", [])} voice_states = {vs["user_id"]: vs for vs in data.get("voice_states", [])} diff --git a/disagreement/gateway.py b/disagreement/gateway.py index a39ed58..10bb87e 100644 --- a/disagreement/gateway.py +++ b/disagreement/gateway.py @@ -338,6 +338,8 @@ class GatewayClient: self._client_instance._ready_event.set() logger.info("Client is now marked as ready.") + if isinstance(raw_event_d_payload, dict) and self._shard_id is not None: + raw_event_d_payload["shard_id"] = self._shard_id await self._dispatcher.dispatch(event_name, raw_event_d_payload) elif event_name == "GUILD_MEMBERS_CHUNK": if isinstance(raw_event_d_payload, dict): @@ -388,6 +390,8 @@ class GatewayClient: event_data_to_dispatch = ( raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {} ) + if isinstance(event_data_to_dispatch, dict) and self._shard_id is not None: + event_data_to_dispatch["shard_id"] = self._shard_id await self._dispatcher.dispatch(event_name, event_data_to_dispatch) await self._dispatcher.dispatch( "SHARD_RESUME", {"shard_id": self._shard_id} @@ -398,6 +402,8 @@ class GatewayClient: event_data_to_dispatch = ( raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {} ) + if isinstance(event_data_to_dispatch, dict) and self._shard_id is not None: + event_data_to_dispatch["shard_id"] = self._shard_id await self._dispatcher.dispatch(event_name, event_data_to_dispatch) else: diff --git a/disagreement/models.py b/disagreement/models.py index 5ac3e4a..3cade7c 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -37,9 +37,10 @@ from .permissions import Permissions if TYPE_CHECKING: from .client import Client # For type hinting to avoid circular imports from .enums import OverwriteType # For PermissionOverwrite model - from .ui.view import View - from .interactions import Snowflake - from .typing import Typing + from .ui.view import View + from .interactions import Snowflake + from .typing import Typing + from .shard_manager import Shard # Forward reference Message if it were used in type hints before its definition # from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc. @@ -1086,8 +1087,17 @@ class Guild: premium_progress_bar_enabled (bool): Whether the guild has the premium progress bar enabled. """ - def __init__(self, data: Dict[str, Any], client_instance: "Client"): + def __init__( + self, + data: Dict[str, Any], + client_instance: "Client", + *, + shard_id: Optional[int] = None, + ): self._client: "Client" = client_instance + self._shard_id: Optional[int] = ( + shard_id if shard_id is not None else data.get("shard_id") + ) self.id: str = data["id"] self.name: str = data["name"] self.icon: Optional[str] = data.get("icon") @@ -1169,6 +1179,25 @@ class Guild: ) self._threads: Dict[str, "Thread"] = {} + @property + def shard_id(self) -> Optional[int]: + """ID of the shard that received this guild, if any.""" + + return self._shard_id + + @property + def shard(self) -> Optional["Shard"]: + """The :class:`Shard` this guild belongs to.""" + + if self._shard_id is None: + return None + manager = getattr(self._client, "_shard_manager", None) + if not manager: + return None + if 0 <= self._shard_id < len(manager.shards): + return manager.shards[self._shard_id] + return None + def get_channel(self, channel_id: str) -> Optional["Channel"]: return self._channels.get(channel_id) diff --git a/tests/test_guild_shard_property.py b/tests/test_guild_shard_property.py new file mode 100644 index 0000000..8a44886 --- /dev/null +++ b/tests/test_guild_shard_property.py @@ -0,0 +1,56 @@ +import pytest +from unittest.mock import Mock + +from disagreement.models import Guild +from disagreement.enums import ( + VerificationLevel, + MessageNotificationLevel, + ExplicitContentFilterLevel, + MFALevel, + GuildNSFWLevel, + PremiumTier, +) + + +class DummyShard: + def __init__(self, shard_id): + self.id = shard_id + self.count = 1 + self.gateway = Mock() + + +class DummyManager: + def __init__(self): + self.shards = [DummyShard(0)] + + +class DummyClient: + pass + + +def _guild_data(): + return { + "id": "1", + "name": "g", + "owner_id": "1", + "afk_timeout": 60, + "verification_level": VerificationLevel.NONE.value, + "default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value, + "explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value, + "roles": [], + "emojis": [], + "features": [], + "mfa_level": MFALevel.NONE.value, + "system_channel_flags": 0, + "premium_tier": PremiumTier.NONE.value, + "nsfw_level": GuildNSFWLevel.DEFAULT.value, + "shard_id": 0, + } + + +def test_guild_shard_property(): + client = DummyClient() + client._shard_manager = DummyManager() + guild = Guild(_guild_data(), client_instance=client, shard_id=0) + assert guild.shard_id == 0 + assert guild.shard is client._shard_manager.shards[0]