Improves asyncio loop handling and test initialization

Replaces deprecated get_event_loop() with proper running loop detection and fallback to new loop creation for better asyncio compatibility.

Fixes test suite by replacing manual Client instantiation with proper constructor calls, ensuring all internal caches and attributes are correctly initialized.

Updates cache access patterns to use new cache API methods consistently across the codebase.
This commit is contained in:
Slipstream 2025-06-11 02:25:24 -06:00
parent ed83a9da85
commit 45a5ef1fb5
Signed by: slipstream
GPG Key ID: 13E498CE010AC6FD
10 changed files with 100 additions and 53 deletions

View File

@ -109,7 +109,14 @@ class Client:
member_cache_flags if member_cache_flags is not None else MemberCacheFlags()
)
self.intents: int = intents if intents is not None else GatewayIntent.default()
self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
if loop:
self.loop: asyncio.AbstractEventLoop = loop
else:
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.application_id: Optional[Snowflake] = (
str(application_id) if application_id else None
)

View File

@ -42,7 +42,7 @@ if TYPE_CHECKING:
class GroupMixin:
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
super().__init__()
self.commands: Dict[str, "Command"] = {}
self.name: str = ""

View File

@ -10,6 +10,7 @@ from dataclasses import dataclass
from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast
from .cache import ChannelCache, MemberCache
from .caching import MemberCacheFlags
import aiohttp # pylint: disable=import-error
from .color import Color
@ -1087,7 +1088,7 @@ class Guild:
# Internal caches, populated by events or specific fetches
self._channels: ChannelCache = ChannelCache()
self._members: MemberCache = MemberCache(client_instance.member_cache_flags)
self._members: MemberCache = MemberCache(getattr(client_instance, "member_cache_flags", MemberCacheFlags()))
self._threads: Dict[str, "Thread"] = {}
def get_channel(self, channel_id: str) -> Optional["Channel"]:

View File

@ -14,23 +14,28 @@ from disagreement.enums import (
)
class DummyBot:
def __init__(self, guild: Guild):
self._guilds = {guild.id: guild}
from disagreement.client import Client
from disagreement.cache import GuildCache
def get_guild(self, gid):
return self._guilds.get(gid)
async def fetch_member(self, gid, mid):
guild = self._guilds.get(gid)
return guild.get_member(mid) if guild else None
class DummyBot(Client):
def __init__(self):
super().__init__(token="test")
self._guilds = GuildCache()
async def fetch_role(self, gid, rid):
guild = self._guilds.get(gid)
return guild.get_role(rid) if guild else None
def get_guild(self, guild_id):
return self._guilds.get(guild_id)
async def fetch_guild(self, gid):
return self._guilds.get(gid)
async def fetch_member(self, guild_id, member_id):
guild = self._guilds.get(guild_id)
return guild.get_member(member_id) if guild else None
async def fetch_role(self, guild_id, role_id):
guild = self._guilds.get(guild_id)
return guild.get_role(role_id) if guild else None
async def fetch_guild(self, guild_id):
return self._guilds.get(guild_id)
@pytest.fixture()
@ -51,7 +56,9 @@ def guild_objects():
"premium_tier": PremiumTier.NONE.value,
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
}
guild = Guild(guild_data, client_instance=None)
bot = DummyBot()
guild = Guild(guild_data, client_instance=bot)
bot._guilds.set(guild.id, guild)
member = Member(
{
@ -76,7 +83,7 @@ def guild_objects():
}
)
guild._members[member.id] = member
guild._members.set(member.id, member)
guild.roles.append(role)
return guild, member, role
@ -85,7 +92,7 @@ def guild_objects():
@pytest.fixture()
def command_context(guild_objects):
guild, member, role = guild_objects
bot = DummyBot(guild)
bot = guild._client
message_data = {
"id": "10",
"channel_id": "20",
@ -150,8 +157,9 @@ async def test_member_converter_no_guild():
"premium_tier": PremiumTier.NONE.value,
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
}
guild = Guild(guild_data, client_instance=None)
bot = DummyBot(guild)
bot = DummyBot()
guild = Guild(guild_data, client_instance=bot)
bot._guilds.set(guild.id, guild)
message_data = {
"id": "11",
"channel_id": "20",

View File

@ -14,12 +14,12 @@ from disagreement.enums import (
from disagreement.permissions import Permissions
class DummyClient:
def __init__(self):
self._guilds = {}
from disagreement.client import Client
def get_guild(self, gid):
return self._guilds.get(gid)
class DummyClient(Client):
def __init__(self):
super().__init__(token="test")
def _base_guild(client):
@ -40,7 +40,7 @@ def _base_guild(client):
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
}
guild = Guild(data, client_instance=client)
client._guilds[guild.id] = guild
client._guilds.set(guild.id, guild)
return guild
@ -52,7 +52,7 @@ def _member(guild, *roles):
}
member = Member(data, client_instance=None)
member.guild_id = guild.id
guild._members[member.id] = member
guild._members.set(member.id, member)
return member
@ -81,7 +81,7 @@ def _channel(guild, client):
"permission_overwrites": [],
}
channel = TextChannel(data, client_instance=client)
guild._channels[channel.id] = channel
guild._channels.set(channel.id, channel)
return channel

View File

@ -5,10 +5,14 @@ import pytest
from disagreement.event_dispatcher import EventDispatcher
from disagreement.cache import Cache
class DummyClient:
def __init__(self):
self.parsed = {}
self._messages = {"1": "cached"}
self._messages = Cache()
self._messages.set("1", "cached")
def parse_message(self, data):
self.parsed["message"] = True
@ -22,6 +26,11 @@ class DummyClient:
self.parsed["channel"] = True
return data
def parse_message_delete(self, data):
message = self._messages.get(data["id"])
self._messages.invalidate(data["id"])
return message
@pytest.mark.asyncio
async def test_dispatch_calls_listener():

View File

@ -73,10 +73,8 @@ async def test_delete_reaction_calls_http():
async def test_get_reactions_parses_users():
users_payload = [{"id": "1", "username": "u", "discriminator": "0001"}]
http = SimpleNamespace(get_reactions=AsyncMock(return_value=users_payload))
client = Client.__new__(Client)
client = Client(token="test")
client._http = http
client._closed = False
client._users = {}
users = await client.get_reactions("1", "2", "😀")

View File

@ -12,9 +12,8 @@ async def test_textchannel_purge_calls_bulk_delete():
request=AsyncMock(return_value=[{"id": "1"}, {"id": "2"}]),
bulk_delete_messages=AsyncMock(),
)
client = Client.__new__(Client)
client = Client(token="test")
client._http = http
client._messages = {}
channel = TextChannel({"id": "c", "type": 0}, client)
@ -33,9 +32,8 @@ async def test_textchannel_purge_before_param():
request=AsyncMock(return_value=[]),
bulk_delete_messages=AsyncMock(),
)
client = Client.__new__(Client)
client = Client(token="test")
client._http = http
client._messages = {}
channel = TextChannel({"id": "c", "type": 0}, client)

View File

@ -3,6 +3,12 @@ import pytest
from disagreement.voice_client import VoiceClient
from disagreement.audio import AudioSource
from disagreement.client import Client
class DummyVoiceClient(Client):
def __init__(self):
super().__init__(token="test")
class DummyWebSocket:
@ -58,7 +64,16 @@ async def test_voice_client_handshake():
ws = DummyWebSocket([hello, ready, session_desc])
udp = DummyUDP()
vc = VoiceClient("ws://localhost", "sess", "tok", 1, 2, ws=ws, udp=udp)
vc = VoiceClient(
client=DummyVoiceClient(),
endpoint="ws://localhost",
session_id="sess",
token="tok",
guild_id=1,
user_id=2,
ws=ws,
udp=udp,
)
await vc.connect()
vc._heartbeat_task.cancel()
@ -78,7 +93,16 @@ async def test_send_audio_frame():
]
)
udp = DummyUDP()
vc = VoiceClient("ws://localhost", "sess", "tok", 1, 2, ws=ws, udp=udp)
vc = VoiceClient(
client=DummyVoiceClient(),
endpoint="ws://localhost",
session_id="sess",
token="tok",
guild_id=1,
user_id=2,
ws=ws,
udp=udp,
)
await vc.connect()
vc._heartbeat_task.cancel()
@ -96,7 +120,16 @@ async def test_play_and_switch_sources():
]
)
udp = DummyUDP()
vc = VoiceClient("ws://localhost", "sess", "tok", 1, 2, ws=ws, udp=udp)
vc = VoiceClient(
client=DummyVoiceClient(),
endpoint="ws://localhost",
session_id="sess",
token="tok",
guild_id=1,
user_id=2,
ws=ws,
udp=udp,
)
await vc.connect()
vc._heartbeat_task.cancel()

View File

@ -94,16 +94,14 @@ async def test_client_create_webhook_returns_model():
from disagreement.models import Webhook
http = SimpleNamespace(create_webhook=AsyncMock(return_value={"id": "1"}))
client = Client.__new__(Client)
client = Client(token="test")
client._http = http
client._closed = False
client._webhooks = {}
webhook = await client.create_webhook("123", {"name": "wh"})
http.create_webhook.assert_awaited_once_with("123", {"name": "wh"})
assert isinstance(webhook, Webhook)
assert client._webhooks["1"] is webhook
assert client._webhooks.get("1") is webhook
@pytest.mark.asyncio
@ -113,16 +111,14 @@ async def test_client_edit_webhook_returns_model():
from disagreement.models import Webhook
http = SimpleNamespace(edit_webhook=AsyncMock(return_value={"id": "1"}))
client = Client.__new__(Client)
client = Client(token="test")
client._http = http
client._closed = False
client._webhooks = {}
webhook = await client.edit_webhook("1", {"name": "rename"})
http.edit_webhook.assert_awaited_once_with("1", {"name": "rename"})
assert isinstance(webhook, Webhook)
assert client._webhooks["1"] is webhook
assert client._webhooks.get("1") is webhook
@pytest.mark.asyncio
@ -131,9 +127,8 @@ async def test_client_delete_webhook_calls_http():
from disagreement.client import Client
http = SimpleNamespace(delete_webhook=AsyncMock())
client = Client.__new__(Client)
client = Client(token="test")
client._http = http
client._closed = False
await client.delete_webhook("1")
@ -181,10 +176,8 @@ async def test_webhook_send_uses_http():
}
)
)
client = Client.__new__(Client)
client = Client(token="test")
client._http = http
client._messages = {}
client._webhooks = {}
webhook = Webhook({"id": "1", "token": "tok"}, client_instance=client)