From 45a5ef1fb540539864cb9bff23a70dec30158dac Mon Sep 17 00:00:00 2001 From: Slipstream Date: Wed, 11 Jun 2025 02:25:24 -0600 Subject: [PATCH] Improves asyncio loop handling and test initialization Replaces deprecated get_event_loop() with proper running loop detection and fallback to new loop creation for better asyncio compatibility. Fixes test suite by replacing manual Client instantiation with proper constructor calls, ensuring all internal caches and attributes are correctly initialized. Updates cache access patterns to use new cache API methods consistently across the codebase. --- disagreement/client.py | 9 +++++- disagreement/ext/commands/core.py | 2 +- disagreement/models.py | 3 +- tests/test_additional_converters.py | 44 +++++++++++++++++------------ tests/test_channel_permissions.py | 16 +++++------ tests/test_event_dispatcher.py | 11 +++++++- tests/test_http_reactions.py | 4 +-- tests/test_textchannel_purge.py | 6 ++-- tests/test_voice_client.py | 39 +++++++++++++++++++++++-- tests/test_webhooks.py | 19 ++++--------- 10 files changed, 100 insertions(+), 53 deletions(-) diff --git a/disagreement/client.py b/disagreement/client.py index 23919eb..6ed4462 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -109,7 +109,14 @@ class Client: member_cache_flags if member_cache_flags is not None else MemberCacheFlags() ) self.intents: int = intents if intents is not None else GatewayIntent.default() - self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() + if loop: + self.loop: asyncio.AbstractEventLoop = loop + else: + try: + self.loop = asyncio.get_running_loop() + except RuntimeError: + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) self.application_id: Optional[Snowflake] = ( str(application_id) if application_id else None ) diff --git a/disagreement/ext/commands/core.py b/disagreement/ext/commands/core.py index b9f48c2..5a98997 100644 --- a/disagreement/ext/commands/core.py +++ b/disagreement/ext/commands/core.py @@ -42,7 +42,7 @@ if TYPE_CHECKING: class GroupMixin: def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + super().__init__() self.commands: Dict[str, "Command"] = {} self.name: str = "" diff --git a/disagreement/models.py b/disagreement/models.py index aa18424..9323415 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -10,6 +10,7 @@ from dataclasses import dataclass from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast from .cache import ChannelCache, MemberCache +from .caching import MemberCacheFlags import aiohttp # pylint: disable=import-error from .color import Color @@ -1087,7 +1088,7 @@ class Guild: # Internal caches, populated by events or specific fetches self._channels: ChannelCache = ChannelCache() - self._members: MemberCache = MemberCache(client_instance.member_cache_flags) + self._members: MemberCache = MemberCache(getattr(client_instance, "member_cache_flags", MemberCacheFlags())) self._threads: Dict[str, "Thread"] = {} def get_channel(self, channel_id: str) -> Optional["Channel"]: diff --git a/tests/test_additional_converters.py b/tests/test_additional_converters.py index 88f4975..40da104 100644 --- a/tests/test_additional_converters.py +++ b/tests/test_additional_converters.py @@ -14,23 +14,28 @@ from disagreement.enums import ( ) -class DummyBot: - def __init__(self, guild: Guild): - self._guilds = {guild.id: guild} +from disagreement.client import Client +from disagreement.cache import GuildCache - def get_guild(self, gid): - return self._guilds.get(gid) - async def fetch_member(self, gid, mid): - guild = self._guilds.get(gid) - return guild.get_member(mid) if guild else None +class DummyBot(Client): + def __init__(self): + super().__init__(token="test") + self._guilds = GuildCache() - async def fetch_role(self, gid, rid): - guild = self._guilds.get(gid) - return guild.get_role(rid) if guild else None + def get_guild(self, guild_id): + return self._guilds.get(guild_id) - async def fetch_guild(self, gid): - return self._guilds.get(gid) + async def fetch_member(self, guild_id, member_id): + guild = self._guilds.get(guild_id) + return guild.get_member(member_id) if guild else None + + async def fetch_role(self, guild_id, role_id): + guild = self._guilds.get(guild_id) + return guild.get_role(role_id) if guild else None + + async def fetch_guild(self, guild_id): + return self._guilds.get(guild_id) @pytest.fixture() @@ -51,7 +56,9 @@ def guild_objects(): "premium_tier": PremiumTier.NONE.value, "nsfw_level": GuildNSFWLevel.DEFAULT.value, } - guild = Guild(guild_data, client_instance=None) + bot = DummyBot() + guild = Guild(guild_data, client_instance=bot) + bot._guilds.set(guild.id, guild) member = Member( { @@ -76,7 +83,7 @@ def guild_objects(): } ) - guild._members[member.id] = member + guild._members.set(member.id, member) guild.roles.append(role) return guild, member, role @@ -85,7 +92,7 @@ def guild_objects(): @pytest.fixture() def command_context(guild_objects): guild, member, role = guild_objects - bot = DummyBot(guild) + bot = guild._client message_data = { "id": "10", "channel_id": "20", @@ -150,8 +157,9 @@ async def test_member_converter_no_guild(): "premium_tier": PremiumTier.NONE.value, "nsfw_level": GuildNSFWLevel.DEFAULT.value, } - guild = Guild(guild_data, client_instance=None) - bot = DummyBot(guild) + bot = DummyBot() + guild = Guild(guild_data, client_instance=bot) + bot._guilds.set(guild.id, guild) message_data = { "id": "11", "channel_id": "20", diff --git a/tests/test_channel_permissions.py b/tests/test_channel_permissions.py index 0d1024a..2ee9186 100644 --- a/tests/test_channel_permissions.py +++ b/tests/test_channel_permissions.py @@ -14,12 +14,12 @@ from disagreement.enums import ( from disagreement.permissions import Permissions -class DummyClient: - def __init__(self): - self._guilds = {} +from disagreement.client import Client - def get_guild(self, gid): - return self._guilds.get(gid) + +class DummyClient(Client): + def __init__(self): + super().__init__(token="test") def _base_guild(client): @@ -40,7 +40,7 @@ def _base_guild(client): "nsfw_level": GuildNSFWLevel.DEFAULT.value, } guild = Guild(data, client_instance=client) - client._guilds[guild.id] = guild + client._guilds.set(guild.id, guild) return guild @@ -52,7 +52,7 @@ def _member(guild, *roles): } member = Member(data, client_instance=None) member.guild_id = guild.id - guild._members[member.id] = member + guild._members.set(member.id, member) return member @@ -81,7 +81,7 @@ def _channel(guild, client): "permission_overwrites": [], } channel = TextChannel(data, client_instance=client) - guild._channels[channel.id] = channel + guild._channels.set(channel.id, channel) return channel diff --git a/tests/test_event_dispatcher.py b/tests/test_event_dispatcher.py index 8558345..289881a 100644 --- a/tests/test_event_dispatcher.py +++ b/tests/test_event_dispatcher.py @@ -5,10 +5,14 @@ import pytest from disagreement.event_dispatcher import EventDispatcher +from disagreement.cache import Cache + + class DummyClient: def __init__(self): self.parsed = {} - self._messages = {"1": "cached"} + self._messages = Cache() + self._messages.set("1", "cached") def parse_message(self, data): self.parsed["message"] = True @@ -22,6 +26,11 @@ class DummyClient: self.parsed["channel"] = True return data + def parse_message_delete(self, data): + message = self._messages.get(data["id"]) + self._messages.invalidate(data["id"]) + return message + @pytest.mark.asyncio async def test_dispatch_calls_listener(): diff --git a/tests/test_http_reactions.py b/tests/test_http_reactions.py index 02b43a0..730062a 100644 --- a/tests/test_http_reactions.py +++ b/tests/test_http_reactions.py @@ -73,10 +73,8 @@ async def test_delete_reaction_calls_http(): async def test_get_reactions_parses_users(): users_payload = [{"id": "1", "username": "u", "discriminator": "0001"}] http = SimpleNamespace(get_reactions=AsyncMock(return_value=users_payload)) - client = Client.__new__(Client) + client = Client(token="test") client._http = http - client._closed = False - client._users = {} users = await client.get_reactions("1", "2", "😀") diff --git a/tests/test_textchannel_purge.py b/tests/test_textchannel_purge.py index 1be072a..172cc76 100644 --- a/tests/test_textchannel_purge.py +++ b/tests/test_textchannel_purge.py @@ -12,9 +12,8 @@ async def test_textchannel_purge_calls_bulk_delete(): request=AsyncMock(return_value=[{"id": "1"}, {"id": "2"}]), bulk_delete_messages=AsyncMock(), ) - client = Client.__new__(Client) + client = Client(token="test") client._http = http - client._messages = {} channel = TextChannel({"id": "c", "type": 0}, client) @@ -33,9 +32,8 @@ async def test_textchannel_purge_before_param(): request=AsyncMock(return_value=[]), bulk_delete_messages=AsyncMock(), ) - client = Client.__new__(Client) + client = Client(token="test") client._http = http - client._messages = {} channel = TextChannel({"id": "c", "type": 0}, client) diff --git a/tests/test_voice_client.py b/tests/test_voice_client.py index fae0289..da052bc 100644 --- a/tests/test_voice_client.py +++ b/tests/test_voice_client.py @@ -3,6 +3,12 @@ import pytest from disagreement.voice_client import VoiceClient from disagreement.audio import AudioSource +from disagreement.client import Client + + +class DummyVoiceClient(Client): + def __init__(self): + super().__init__(token="test") class DummyWebSocket: @@ -58,7 +64,16 @@ async def test_voice_client_handshake(): ws = DummyWebSocket([hello, ready, session_desc]) udp = DummyUDP() - vc = VoiceClient("ws://localhost", "sess", "tok", 1, 2, ws=ws, udp=udp) + vc = VoiceClient( + client=DummyVoiceClient(), + endpoint="ws://localhost", + session_id="sess", + token="tok", + guild_id=1, + user_id=2, + ws=ws, + udp=udp, + ) await vc.connect() vc._heartbeat_task.cancel() @@ -78,7 +93,16 @@ async def test_send_audio_frame(): ] ) udp = DummyUDP() - vc = VoiceClient("ws://localhost", "sess", "tok", 1, 2, ws=ws, udp=udp) + vc = VoiceClient( + client=DummyVoiceClient(), + endpoint="ws://localhost", + session_id="sess", + token="tok", + guild_id=1, + user_id=2, + ws=ws, + udp=udp, + ) await vc.connect() vc._heartbeat_task.cancel() @@ -96,7 +120,16 @@ async def test_play_and_switch_sources(): ] ) udp = DummyUDP() - vc = VoiceClient("ws://localhost", "sess", "tok", 1, 2, ws=ws, udp=udp) + vc = VoiceClient( + client=DummyVoiceClient(), + endpoint="ws://localhost", + session_id="sess", + token="tok", + guild_id=1, + user_id=2, + ws=ws, + udp=udp, + ) await vc.connect() vc._heartbeat_task.cancel() diff --git a/tests/test_webhooks.py b/tests/test_webhooks.py index 129f658..fa82803 100644 --- a/tests/test_webhooks.py +++ b/tests/test_webhooks.py @@ -94,16 +94,14 @@ async def test_client_create_webhook_returns_model(): from disagreement.models import Webhook http = SimpleNamespace(create_webhook=AsyncMock(return_value={"id": "1"})) - client = Client.__new__(Client) + client = Client(token="test") client._http = http - client._closed = False - client._webhooks = {} webhook = await client.create_webhook("123", {"name": "wh"}) http.create_webhook.assert_awaited_once_with("123", {"name": "wh"}) assert isinstance(webhook, Webhook) - assert client._webhooks["1"] is webhook + assert client._webhooks.get("1") is webhook @pytest.mark.asyncio @@ -113,16 +111,14 @@ async def test_client_edit_webhook_returns_model(): from disagreement.models import Webhook http = SimpleNamespace(edit_webhook=AsyncMock(return_value={"id": "1"})) - client = Client.__new__(Client) + client = Client(token="test") client._http = http - client._closed = False - client._webhooks = {} webhook = await client.edit_webhook("1", {"name": "rename"}) http.edit_webhook.assert_awaited_once_with("1", {"name": "rename"}) assert isinstance(webhook, Webhook) - assert client._webhooks["1"] is webhook + assert client._webhooks.get("1") is webhook @pytest.mark.asyncio @@ -131,9 +127,8 @@ async def test_client_delete_webhook_calls_http(): from disagreement.client import Client http = SimpleNamespace(delete_webhook=AsyncMock()) - client = Client.__new__(Client) + client = Client(token="test") client._http = http - client._closed = False await client.delete_webhook("1") @@ -181,10 +176,8 @@ async def test_webhook_send_uses_http(): } ) ) - client = Client.__new__(Client) + client = Client(token="test") client._http = http - client._messages = {} - client._webhooks = {} webhook = Webhook({"id": "1", "token": "tok"}, client_instance=client)