Store shard id on guild and expose shard property (#87)

This commit is contained in:
Slipstream 2025-06-15 15:19:58 -06:00 committed by GitHub
parent a93ad432b7
commit 775dce0c80
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 102 additions and 10 deletions

View File

@ -926,12 +926,13 @@ class Client:
guild.roles.append(role) guild.roles.append(role)
return role return role
def parse_guild(self, data: Dict[str, Any]) -> "Guild": def parse_guild(self, data: Dict[str, Any]) -> "Guild":
"""Parses guild data and returns a Guild object, updating cache.""" """Parses guild data and returns a :class:`Guild` object, updating cache."""
from .models import Guild from .models import Guild
guild = Guild(data, client_instance=self) shard_id = data.get("shard_id")
self._guilds.set(guild.id, guild) guild = Guild(data, client_instance=self, shard_id=shard_id)
self._guilds.set(guild.id, guild)
presences = {p["user"]["id"]: p for p in data.get("presences", [])} presences = {p["user"]["id"]: p for p in data.get("presences", [])}
voice_states = {vs["user_id"]: vs for vs in data.get("voice_states", [])} voice_states = {vs["user_id"]: vs for vs in data.get("voice_states", [])}

View File

@ -338,6 +338,8 @@ class GatewayClient:
self._client_instance._ready_event.set() self._client_instance._ready_event.set()
logger.info("Client is now marked as ready.") logger.info("Client is now marked as ready.")
if isinstance(raw_event_d_payload, dict) and self._shard_id is not None:
raw_event_d_payload["shard_id"] = self._shard_id
await self._dispatcher.dispatch(event_name, raw_event_d_payload) await self._dispatcher.dispatch(event_name, raw_event_d_payload)
elif event_name == "GUILD_MEMBERS_CHUNK": elif event_name == "GUILD_MEMBERS_CHUNK":
if isinstance(raw_event_d_payload, dict): if isinstance(raw_event_d_payload, dict):
@ -388,6 +390,8 @@ class GatewayClient:
event_data_to_dispatch = ( event_data_to_dispatch = (
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {} raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
) )
if isinstance(event_data_to_dispatch, dict) and self._shard_id is not None:
event_data_to_dispatch["shard_id"] = self._shard_id
await self._dispatcher.dispatch(event_name, event_data_to_dispatch) await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
await self._dispatcher.dispatch( await self._dispatcher.dispatch(
"SHARD_RESUME", {"shard_id": self._shard_id} "SHARD_RESUME", {"shard_id": self._shard_id}
@ -398,6 +402,8 @@ class GatewayClient:
event_data_to_dispatch = ( event_data_to_dispatch = (
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {} raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
) )
if isinstance(event_data_to_dispatch, dict) and self._shard_id is not None:
event_data_to_dispatch["shard_id"] = self._shard_id
await self._dispatcher.dispatch(event_name, event_data_to_dispatch) await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
else: else:

View File

@ -37,9 +37,10 @@ from .permissions import Permissions
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import Client # For type hinting to avoid circular imports from .client import Client # For type hinting to avoid circular imports
from .enums import OverwriteType # For PermissionOverwrite model from .enums import OverwriteType # For PermissionOverwrite model
from .ui.view import View from .ui.view import View
from .interactions import Snowflake from .interactions import Snowflake
from .typing import Typing from .typing import Typing
from .shard_manager import Shard
# Forward reference Message if it were used in type hints before its definition # Forward reference Message if it were used in type hints before its definition
# from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc. # from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc.
@ -1086,8 +1087,17 @@ class Guild:
premium_progress_bar_enabled (bool): Whether the guild has the premium progress bar enabled. premium_progress_bar_enabled (bool): Whether the guild has the premium progress bar enabled.
""" """
def __init__(self, data: Dict[str, Any], client_instance: "Client"): def __init__(
self,
data: Dict[str, Any],
client_instance: "Client",
*,
shard_id: Optional[int] = None,
):
self._client: "Client" = client_instance self._client: "Client" = client_instance
self._shard_id: Optional[int] = (
shard_id if shard_id is not None else data.get("shard_id")
)
self.id: str = data["id"] self.id: str = data["id"]
self.name: str = data["name"] self.name: str = data["name"]
self.icon: Optional[str] = data.get("icon") self.icon: Optional[str] = data.get("icon")
@ -1169,6 +1179,25 @@ class Guild:
) )
self._threads: Dict[str, "Thread"] = {} self._threads: Dict[str, "Thread"] = {}
@property
def shard_id(self) -> Optional[int]:
"""ID of the shard that received this guild, if any."""
return self._shard_id
@property
def shard(self) -> Optional["Shard"]:
"""The :class:`Shard` this guild belongs to."""
if self._shard_id is None:
return None
manager = getattr(self._client, "_shard_manager", None)
if not manager:
return None
if 0 <= self._shard_id < len(manager.shards):
return manager.shards[self._shard_id]
return None
def get_channel(self, channel_id: str) -> Optional["Channel"]: def get_channel(self, channel_id: str) -> Optional["Channel"]:
return self._channels.get(channel_id) return self._channels.get(channel_id)

View File

@ -0,0 +1,56 @@
import pytest
from unittest.mock import Mock
from disagreement.models import Guild
from disagreement.enums import (
VerificationLevel,
MessageNotificationLevel,
ExplicitContentFilterLevel,
MFALevel,
GuildNSFWLevel,
PremiumTier,
)
class DummyShard:
def __init__(self, shard_id):
self.id = shard_id
self.count = 1
self.gateway = Mock()
class DummyManager:
def __init__(self):
self.shards = [DummyShard(0)]
class DummyClient:
pass
def _guild_data():
return {
"id": "1",
"name": "g",
"owner_id": "1",
"afk_timeout": 60,
"verification_level": VerificationLevel.NONE.value,
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
"roles": [],
"emojis": [],
"features": [],
"mfa_level": MFALevel.NONE.value,
"system_channel_flags": 0,
"premium_tier": PremiumTier.NONE.value,
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
"shard_id": 0,
}
def test_guild_shard_property():
client = DummyClient()
client._shard_manager = DummyManager()
guild = Guild(_guild_data(), client_instance=client, shard_id=0)
assert guild.shard_id == 0
assert guild.shard is client._shard_manager.shards[0]