Improves asyncio loop handling and test initialization

Replaces deprecated get_event_loop() with proper running loop detection and fallback to new loop creation for better asyncio compatibility.

Fixes test suite by replacing manual Client instantiation with proper constructor calls, ensuring all internal caches and attributes are correctly initialized.

Updates cache access patterns to use new cache API methods consistently across the codebase.
This commit is contained in:
Slipstream 2025-06-11 02:25:24 -06:00
parent ed83a9da85
commit 45a5ef1fb5
Signed by: slipstream
GPG Key ID: 13E498CE010AC6FD
10 changed files with 100 additions and 53 deletions

View File

@ -109,7 +109,14 @@ class Client:
member_cache_flags if member_cache_flags is not None else MemberCacheFlags() member_cache_flags if member_cache_flags is not None else MemberCacheFlags()
) )
self.intents: int = intents if intents is not None else GatewayIntent.default() self.intents: int = intents if intents is not None else GatewayIntent.default()
self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() if loop:
self.loop: asyncio.AbstractEventLoop = loop
else:
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.application_id: Optional[Snowflake] = ( self.application_id: Optional[Snowflake] = (
str(application_id) if application_id else None str(application_id) if application_id else None
) )

View File

@ -42,7 +42,7 @@ if TYPE_CHECKING:
class GroupMixin: class GroupMixin:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__()
self.commands: Dict[str, "Command"] = {} self.commands: Dict[str, "Command"] = {}
self.name: str = "" self.name: str = ""

View File

@ -10,6 +10,7 @@ from dataclasses import dataclass
from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast
from .cache import ChannelCache, MemberCache from .cache import ChannelCache, MemberCache
from .caching import MemberCacheFlags
import aiohttp # pylint: disable=import-error import aiohttp # pylint: disable=import-error
from .color import Color from .color import Color
@ -1087,7 +1088,7 @@ class Guild:
# Internal caches, populated by events or specific fetches # Internal caches, populated by events or specific fetches
self._channels: ChannelCache = ChannelCache() self._channels: ChannelCache = ChannelCache()
self._members: MemberCache = MemberCache(client_instance.member_cache_flags) self._members: MemberCache = MemberCache(getattr(client_instance, "member_cache_flags", MemberCacheFlags()))
self._threads: Dict[str, "Thread"] = {} self._threads: Dict[str, "Thread"] = {}
def get_channel(self, channel_id: str) -> Optional["Channel"]: def get_channel(self, channel_id: str) -> Optional["Channel"]:

View File

@ -14,23 +14,28 @@ from disagreement.enums import (
) )
class DummyBot: from disagreement.client import Client
def __init__(self, guild: Guild): from disagreement.cache import GuildCache
self._guilds = {guild.id: guild}
def get_guild(self, gid):
return self._guilds.get(gid)
async def fetch_member(self, gid, mid): class DummyBot(Client):
guild = self._guilds.get(gid) def __init__(self):
return guild.get_member(mid) if guild else None super().__init__(token="test")
self._guilds = GuildCache()
async def fetch_role(self, gid, rid): def get_guild(self, guild_id):
guild = self._guilds.get(gid) return self._guilds.get(guild_id)
return guild.get_role(rid) if guild else None
async def fetch_guild(self, gid): async def fetch_member(self, guild_id, member_id):
return self._guilds.get(gid) guild = self._guilds.get(guild_id)
return guild.get_member(member_id) if guild else None
async def fetch_role(self, guild_id, role_id):
guild = self._guilds.get(guild_id)
return guild.get_role(role_id) if guild else None
async def fetch_guild(self, guild_id):
return self._guilds.get(guild_id)
@pytest.fixture() @pytest.fixture()
@ -51,7 +56,9 @@ def guild_objects():
"premium_tier": PremiumTier.NONE.value, "premium_tier": PremiumTier.NONE.value,
"nsfw_level": GuildNSFWLevel.DEFAULT.value, "nsfw_level": GuildNSFWLevel.DEFAULT.value,
} }
guild = Guild(guild_data, client_instance=None) bot = DummyBot()
guild = Guild(guild_data, client_instance=bot)
bot._guilds.set(guild.id, guild)
member = Member( member = Member(
{ {
@ -76,7 +83,7 @@ def guild_objects():
} }
) )
guild._members[member.id] = member guild._members.set(member.id, member)
guild.roles.append(role) guild.roles.append(role)
return guild, member, role return guild, member, role
@ -85,7 +92,7 @@ def guild_objects():
@pytest.fixture() @pytest.fixture()
def command_context(guild_objects): def command_context(guild_objects):
guild, member, role = guild_objects guild, member, role = guild_objects
bot = DummyBot(guild) bot = guild._client
message_data = { message_data = {
"id": "10", "id": "10",
"channel_id": "20", "channel_id": "20",
@ -150,8 +157,9 @@ async def test_member_converter_no_guild():
"premium_tier": PremiumTier.NONE.value, "premium_tier": PremiumTier.NONE.value,
"nsfw_level": GuildNSFWLevel.DEFAULT.value, "nsfw_level": GuildNSFWLevel.DEFAULT.value,
} }
guild = Guild(guild_data, client_instance=None) bot = DummyBot()
bot = DummyBot(guild) guild = Guild(guild_data, client_instance=bot)
bot._guilds.set(guild.id, guild)
message_data = { message_data = {
"id": "11", "id": "11",
"channel_id": "20", "channel_id": "20",

View File

@ -14,12 +14,12 @@ from disagreement.enums import (
from disagreement.permissions import Permissions from disagreement.permissions import Permissions
class DummyClient: from disagreement.client import Client
def __init__(self):
self._guilds = {}
def get_guild(self, gid):
return self._guilds.get(gid) class DummyClient(Client):
def __init__(self):
super().__init__(token="test")
def _base_guild(client): def _base_guild(client):
@ -40,7 +40,7 @@ def _base_guild(client):
"nsfw_level": GuildNSFWLevel.DEFAULT.value, "nsfw_level": GuildNSFWLevel.DEFAULT.value,
} }
guild = Guild(data, client_instance=client) guild = Guild(data, client_instance=client)
client._guilds[guild.id] = guild client._guilds.set(guild.id, guild)
return guild return guild
@ -52,7 +52,7 @@ def _member(guild, *roles):
} }
member = Member(data, client_instance=None) member = Member(data, client_instance=None)
member.guild_id = guild.id member.guild_id = guild.id
guild._members[member.id] = member guild._members.set(member.id, member)
return member return member
@ -81,7 +81,7 @@ def _channel(guild, client):
"permission_overwrites": [], "permission_overwrites": [],
} }
channel = TextChannel(data, client_instance=client) channel = TextChannel(data, client_instance=client)
guild._channels[channel.id] = channel guild._channels.set(channel.id, channel)
return channel return channel

View File

@ -5,10 +5,14 @@ import pytest
from disagreement.event_dispatcher import EventDispatcher from disagreement.event_dispatcher import EventDispatcher
from disagreement.cache import Cache
class DummyClient: class DummyClient:
def __init__(self): def __init__(self):
self.parsed = {} self.parsed = {}
self._messages = {"1": "cached"} self._messages = Cache()
self._messages.set("1", "cached")
def parse_message(self, data): def parse_message(self, data):
self.parsed["message"] = True self.parsed["message"] = True
@ -22,6 +26,11 @@ class DummyClient:
self.parsed["channel"] = True self.parsed["channel"] = True
return data return data
def parse_message_delete(self, data):
message = self._messages.get(data["id"])
self._messages.invalidate(data["id"])
return message
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_dispatch_calls_listener(): async def test_dispatch_calls_listener():

View File

@ -73,10 +73,8 @@ async def test_delete_reaction_calls_http():
async def test_get_reactions_parses_users(): async def test_get_reactions_parses_users():
users_payload = [{"id": "1", "username": "u", "discriminator": "0001"}] users_payload = [{"id": "1", "username": "u", "discriminator": "0001"}]
http = SimpleNamespace(get_reactions=AsyncMock(return_value=users_payload)) http = SimpleNamespace(get_reactions=AsyncMock(return_value=users_payload))
client = Client.__new__(Client) client = Client(token="test")
client._http = http client._http = http
client._closed = False
client._users = {}
users = await client.get_reactions("1", "2", "😀") users = await client.get_reactions("1", "2", "😀")

View File

@ -12,9 +12,8 @@ async def test_textchannel_purge_calls_bulk_delete():
request=AsyncMock(return_value=[{"id": "1"}, {"id": "2"}]), request=AsyncMock(return_value=[{"id": "1"}, {"id": "2"}]),
bulk_delete_messages=AsyncMock(), bulk_delete_messages=AsyncMock(),
) )
client = Client.__new__(Client) client = Client(token="test")
client._http = http client._http = http
client._messages = {}
channel = TextChannel({"id": "c", "type": 0}, client) channel = TextChannel({"id": "c", "type": 0}, client)
@ -33,9 +32,8 @@ async def test_textchannel_purge_before_param():
request=AsyncMock(return_value=[]), request=AsyncMock(return_value=[]),
bulk_delete_messages=AsyncMock(), bulk_delete_messages=AsyncMock(),
) )
client = Client.__new__(Client) client = Client(token="test")
client._http = http client._http = http
client._messages = {}
channel = TextChannel({"id": "c", "type": 0}, client) channel = TextChannel({"id": "c", "type": 0}, client)

View File

@ -3,6 +3,12 @@ import pytest
from disagreement.voice_client import VoiceClient from disagreement.voice_client import VoiceClient
from disagreement.audio import AudioSource from disagreement.audio import AudioSource
from disagreement.client import Client
class DummyVoiceClient(Client):
def __init__(self):
super().__init__(token="test")
class DummyWebSocket: class DummyWebSocket:
@ -58,7 +64,16 @@ async def test_voice_client_handshake():
ws = DummyWebSocket([hello, ready, session_desc]) ws = DummyWebSocket([hello, ready, session_desc])
udp = DummyUDP() udp = DummyUDP()
vc = VoiceClient("ws://localhost", "sess", "tok", 1, 2, ws=ws, udp=udp) vc = VoiceClient(
client=DummyVoiceClient(),
endpoint="ws://localhost",
session_id="sess",
token="tok",
guild_id=1,
user_id=2,
ws=ws,
udp=udp,
)
await vc.connect() await vc.connect()
vc._heartbeat_task.cancel() vc._heartbeat_task.cancel()
@ -78,7 +93,16 @@ async def test_send_audio_frame():
] ]
) )
udp = DummyUDP() udp = DummyUDP()
vc = VoiceClient("ws://localhost", "sess", "tok", 1, 2, ws=ws, udp=udp) vc = VoiceClient(
client=DummyVoiceClient(),
endpoint="ws://localhost",
session_id="sess",
token="tok",
guild_id=1,
user_id=2,
ws=ws,
udp=udp,
)
await vc.connect() await vc.connect()
vc._heartbeat_task.cancel() vc._heartbeat_task.cancel()
@ -96,7 +120,16 @@ async def test_play_and_switch_sources():
] ]
) )
udp = DummyUDP() udp = DummyUDP()
vc = VoiceClient("ws://localhost", "sess", "tok", 1, 2, ws=ws, udp=udp) vc = VoiceClient(
client=DummyVoiceClient(),
endpoint="ws://localhost",
session_id="sess",
token="tok",
guild_id=1,
user_id=2,
ws=ws,
udp=udp,
)
await vc.connect() await vc.connect()
vc._heartbeat_task.cancel() vc._heartbeat_task.cancel()

View File

@ -94,16 +94,14 @@ async def test_client_create_webhook_returns_model():
from disagreement.models import Webhook from disagreement.models import Webhook
http = SimpleNamespace(create_webhook=AsyncMock(return_value={"id": "1"})) http = SimpleNamespace(create_webhook=AsyncMock(return_value={"id": "1"}))
client = Client.__new__(Client) client = Client(token="test")
client._http = http client._http = http
client._closed = False
client._webhooks = {}
webhook = await client.create_webhook("123", {"name": "wh"}) webhook = await client.create_webhook("123", {"name": "wh"})
http.create_webhook.assert_awaited_once_with("123", {"name": "wh"}) http.create_webhook.assert_awaited_once_with("123", {"name": "wh"})
assert isinstance(webhook, Webhook) assert isinstance(webhook, Webhook)
assert client._webhooks["1"] is webhook assert client._webhooks.get("1") is webhook
@pytest.mark.asyncio @pytest.mark.asyncio
@ -113,16 +111,14 @@ async def test_client_edit_webhook_returns_model():
from disagreement.models import Webhook from disagreement.models import Webhook
http = SimpleNamespace(edit_webhook=AsyncMock(return_value={"id": "1"})) http = SimpleNamespace(edit_webhook=AsyncMock(return_value={"id": "1"}))
client = Client.__new__(Client) client = Client(token="test")
client._http = http client._http = http
client._closed = False
client._webhooks = {}
webhook = await client.edit_webhook("1", {"name": "rename"}) webhook = await client.edit_webhook("1", {"name": "rename"})
http.edit_webhook.assert_awaited_once_with("1", {"name": "rename"}) http.edit_webhook.assert_awaited_once_with("1", {"name": "rename"})
assert isinstance(webhook, Webhook) assert isinstance(webhook, Webhook)
assert client._webhooks["1"] is webhook assert client._webhooks.get("1") is webhook
@pytest.mark.asyncio @pytest.mark.asyncio
@ -131,9 +127,8 @@ async def test_client_delete_webhook_calls_http():
from disagreement.client import Client from disagreement.client import Client
http = SimpleNamespace(delete_webhook=AsyncMock()) http = SimpleNamespace(delete_webhook=AsyncMock())
client = Client.__new__(Client) client = Client(token="test")
client._http = http client._http = http
client._closed = False
await client.delete_webhook("1") await client.delete_webhook("1")
@ -181,10 +176,8 @@ async def test_webhook_send_uses_http():
} }
) )
) )
client = Client.__new__(Client) client = Client(token="test")
client._http = http client._http = http
client._messages = {}
client._webhooks = {}
webhook = Webhook({"id": "1", "token": "tok"}, client_instance=client) webhook = Webhook({"id": "1", "token": "tok"}, client_instance=client)