Store shard id on guild and expose shard property (#87)
This commit is contained in:
parent
a93ad432b7
commit
775dce0c80
@ -926,12 +926,13 @@ class Client:
|
||||
guild.roles.append(role)
|
||||
return role
|
||||
|
||||
def parse_guild(self, data: Dict[str, Any]) -> "Guild":
|
||||
"""Parses guild data and returns a Guild object, updating cache."""
|
||||
from .models import Guild
|
||||
|
||||
guild = Guild(data, client_instance=self)
|
||||
self._guilds.set(guild.id, guild)
|
||||
def parse_guild(self, data: Dict[str, Any]) -> "Guild":
|
||||
"""Parses guild data and returns a :class:`Guild` object, updating cache."""
|
||||
from .models import Guild
|
||||
|
||||
shard_id = data.get("shard_id")
|
||||
guild = Guild(data, client_instance=self, shard_id=shard_id)
|
||||
self._guilds.set(guild.id, guild)
|
||||
|
||||
presences = {p["user"]["id"]: p for p in data.get("presences", [])}
|
||||
voice_states = {vs["user_id"]: vs for vs in data.get("voice_states", [])}
|
||||
|
@ -338,6 +338,8 @@ class GatewayClient:
|
||||
self._client_instance._ready_event.set()
|
||||
logger.info("Client is now marked as ready.")
|
||||
|
||||
if isinstance(raw_event_d_payload, dict) and self._shard_id is not None:
|
||||
raw_event_d_payload["shard_id"] = self._shard_id
|
||||
await self._dispatcher.dispatch(event_name, raw_event_d_payload)
|
||||
elif event_name == "GUILD_MEMBERS_CHUNK":
|
||||
if isinstance(raw_event_d_payload, dict):
|
||||
@ -388,6 +390,8 @@ class GatewayClient:
|
||||
event_data_to_dispatch = (
|
||||
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
|
||||
)
|
||||
if isinstance(event_data_to_dispatch, dict) and self._shard_id is not None:
|
||||
event_data_to_dispatch["shard_id"] = self._shard_id
|
||||
await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
|
||||
await self._dispatcher.dispatch(
|
||||
"SHARD_RESUME", {"shard_id": self._shard_id}
|
||||
@ -398,6 +402,8 @@ class GatewayClient:
|
||||
event_data_to_dispatch = (
|
||||
raw_event_d_payload if isinstance(raw_event_d_payload, dict) else {}
|
||||
)
|
||||
if isinstance(event_data_to_dispatch, dict) and self._shard_id is not None:
|
||||
event_data_to_dispatch["shard_id"] = self._shard_id
|
||||
|
||||
await self._dispatcher.dispatch(event_name, event_data_to_dispatch)
|
||||
else:
|
||||
|
@ -37,9 +37,10 @@ from .permissions import Permissions
|
||||
if TYPE_CHECKING:
|
||||
from .client import Client # For type hinting to avoid circular imports
|
||||
from .enums import OverwriteType # For PermissionOverwrite model
|
||||
from .ui.view import View
|
||||
from .interactions import Snowflake
|
||||
from .typing import Typing
|
||||
from .ui.view import View
|
||||
from .interactions import Snowflake
|
||||
from .typing import Typing
|
||||
from .shard_manager import Shard
|
||||
|
||||
# Forward reference Message if it were used in type hints before its definition
|
||||
# from .models import Message # Not needed as Message is defined before its use in TextChannel.send etc.
|
||||
@ -1086,8 +1087,17 @@ class Guild:
|
||||
premium_progress_bar_enabled (bool): Whether the guild has the premium progress bar enabled.
|
||||
"""
|
||||
|
||||
def __init__(self, data: Dict[str, Any], client_instance: "Client"):
|
||||
def __init__(
|
||||
self,
|
||||
data: Dict[str, Any],
|
||||
client_instance: "Client",
|
||||
*,
|
||||
shard_id: Optional[int] = None,
|
||||
):
|
||||
self._client: "Client" = client_instance
|
||||
self._shard_id: Optional[int] = (
|
||||
shard_id if shard_id is not None else data.get("shard_id")
|
||||
)
|
||||
self.id: str = data["id"]
|
||||
self.name: str = data["name"]
|
||||
self.icon: Optional[str] = data.get("icon")
|
||||
@ -1169,6 +1179,25 @@ class Guild:
|
||||
)
|
||||
self._threads: Dict[str, "Thread"] = {}
|
||||
|
||||
@property
|
||||
def shard_id(self) -> Optional[int]:
|
||||
"""ID of the shard that received this guild, if any."""
|
||||
|
||||
return self._shard_id
|
||||
|
||||
@property
|
||||
def shard(self) -> Optional["Shard"]:
|
||||
"""The :class:`Shard` this guild belongs to."""
|
||||
|
||||
if self._shard_id is None:
|
||||
return None
|
||||
manager = getattr(self._client, "_shard_manager", None)
|
||||
if not manager:
|
||||
return None
|
||||
if 0 <= self._shard_id < len(manager.shards):
|
||||
return manager.shards[self._shard_id]
|
||||
return None
|
||||
|
||||
def get_channel(self, channel_id: str) -> Optional["Channel"]:
|
||||
return self._channels.get(channel_id)
|
||||
|
||||
|
56
tests/test_guild_shard_property.py
Normal file
56
tests/test_guild_shard_property.py
Normal file
@ -0,0 +1,56 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
|
||||
from disagreement.models import Guild
|
||||
from disagreement.enums import (
|
||||
VerificationLevel,
|
||||
MessageNotificationLevel,
|
||||
ExplicitContentFilterLevel,
|
||||
MFALevel,
|
||||
GuildNSFWLevel,
|
||||
PremiumTier,
|
||||
)
|
||||
|
||||
|
||||
class DummyShard:
|
||||
def __init__(self, shard_id):
|
||||
self.id = shard_id
|
||||
self.count = 1
|
||||
self.gateway = Mock()
|
||||
|
||||
|
||||
class DummyManager:
|
||||
def __init__(self):
|
||||
self.shards = [DummyShard(0)]
|
||||
|
||||
|
||||
class DummyClient:
|
||||
pass
|
||||
|
||||
|
||||
def _guild_data():
|
||||
return {
|
||||
"id": "1",
|
||||
"name": "g",
|
||||
"owner_id": "1",
|
||||
"afk_timeout": 60,
|
||||
"verification_level": VerificationLevel.NONE.value,
|
||||
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
|
||||
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
|
||||
"roles": [],
|
||||
"emojis": [],
|
||||
"features": [],
|
||||
"mfa_level": MFALevel.NONE.value,
|
||||
"system_channel_flags": 0,
|
||||
"premium_tier": PremiumTier.NONE.value,
|
||||
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||
"shard_id": 0,
|
||||
}
|
||||
|
||||
|
||||
def test_guild_shard_property():
|
||||
client = DummyClient()
|
||||
client._shard_manager = DummyManager()
|
||||
guild = Guild(_guild_data(), client_instance=client, shard_id=0)
|
||||
assert guild.shard_id == 0
|
||||
assert guild.shard is client._shard_manager.shards[0]
|
Loading…
x
Reference in New Issue
Block a user