Improves asyncio loop handling and test initialization
Replaces deprecated get_event_loop() with proper running loop detection and fallback to new loop creation for better asyncio compatibility. Fixes test suite by replacing manual Client instantiation with proper constructor calls, ensuring all internal caches and attributes are correctly initialized. Updates cache access patterns to use new cache API methods consistently across the codebase.
This commit is contained in:
parent
ed83a9da85
commit
45a5ef1fb5
@ -109,7 +109,14 @@ class Client:
|
|||||||
member_cache_flags if member_cache_flags is not None else MemberCacheFlags()
|
member_cache_flags if member_cache_flags is not None else MemberCacheFlags()
|
||||||
)
|
)
|
||||||
self.intents: int = intents if intents is not None else GatewayIntent.default()
|
self.intents: int = intents if intents is not None else GatewayIntent.default()
|
||||||
self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
|
if loop:
|
||||||
|
self.loop: asyncio.AbstractEventLoop = loop
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
self.loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
self.loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(self.loop)
|
||||||
self.application_id: Optional[Snowflake] = (
|
self.application_id: Optional[Snowflake] = (
|
||||||
str(application_id) if application_id else None
|
str(application_id) if application_id else None
|
||||||
)
|
)
|
||||||
|
@ -42,7 +42,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
class GroupMixin:
|
class GroupMixin:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__()
|
||||||
self.commands: Dict[str, "Command"] = {}
|
self.commands: Dict[str, "Command"] = {}
|
||||||
self.name: str = ""
|
self.name: str = ""
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ from dataclasses import dataclass
|
|||||||
from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast
|
from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast
|
||||||
|
|
||||||
from .cache import ChannelCache, MemberCache
|
from .cache import ChannelCache, MemberCache
|
||||||
|
from .caching import MemberCacheFlags
|
||||||
|
|
||||||
import aiohttp # pylint: disable=import-error
|
import aiohttp # pylint: disable=import-error
|
||||||
from .color import Color
|
from .color import Color
|
||||||
@ -1087,7 +1088,7 @@ class Guild:
|
|||||||
|
|
||||||
# Internal caches, populated by events or specific fetches
|
# Internal caches, populated by events or specific fetches
|
||||||
self._channels: ChannelCache = ChannelCache()
|
self._channels: ChannelCache = ChannelCache()
|
||||||
self._members: MemberCache = MemberCache(client_instance.member_cache_flags)
|
self._members: MemberCache = MemberCache(getattr(client_instance, "member_cache_flags", MemberCacheFlags()))
|
||||||
self._threads: Dict[str, "Thread"] = {}
|
self._threads: Dict[str, "Thread"] = {}
|
||||||
|
|
||||||
def get_channel(self, channel_id: str) -> Optional["Channel"]:
|
def get_channel(self, channel_id: str) -> Optional["Channel"]:
|
||||||
|
@ -14,23 +14,28 @@ from disagreement.enums import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DummyBot:
|
from disagreement.client import Client
|
||||||
def __init__(self, guild: Guild):
|
from disagreement.cache import GuildCache
|
||||||
self._guilds = {guild.id: guild}
|
|
||||||
|
|
||||||
def get_guild(self, gid):
|
|
||||||
return self._guilds.get(gid)
|
|
||||||
|
|
||||||
async def fetch_member(self, gid, mid):
|
class DummyBot(Client):
|
||||||
guild = self._guilds.get(gid)
|
def __init__(self):
|
||||||
return guild.get_member(mid) if guild else None
|
super().__init__(token="test")
|
||||||
|
self._guilds = GuildCache()
|
||||||
|
|
||||||
async def fetch_role(self, gid, rid):
|
def get_guild(self, guild_id):
|
||||||
guild = self._guilds.get(gid)
|
return self._guilds.get(guild_id)
|
||||||
return guild.get_role(rid) if guild else None
|
|
||||||
|
|
||||||
async def fetch_guild(self, gid):
|
async def fetch_member(self, guild_id, member_id):
|
||||||
return self._guilds.get(gid)
|
guild = self._guilds.get(guild_id)
|
||||||
|
return guild.get_member(member_id) if guild else None
|
||||||
|
|
||||||
|
async def fetch_role(self, guild_id, role_id):
|
||||||
|
guild = self._guilds.get(guild_id)
|
||||||
|
return guild.get_role(role_id) if guild else None
|
||||||
|
|
||||||
|
async def fetch_guild(self, guild_id):
|
||||||
|
return self._guilds.get(guild_id)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
@ -51,7 +56,9 @@ def guild_objects():
|
|||||||
"premium_tier": PremiumTier.NONE.value,
|
"premium_tier": PremiumTier.NONE.value,
|
||||||
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||||
}
|
}
|
||||||
guild = Guild(guild_data, client_instance=None)
|
bot = DummyBot()
|
||||||
|
guild = Guild(guild_data, client_instance=bot)
|
||||||
|
bot._guilds.set(guild.id, guild)
|
||||||
|
|
||||||
member = Member(
|
member = Member(
|
||||||
{
|
{
|
||||||
@ -76,7 +83,7 @@ def guild_objects():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
guild._members[member.id] = member
|
guild._members.set(member.id, member)
|
||||||
guild.roles.append(role)
|
guild.roles.append(role)
|
||||||
|
|
||||||
return guild, member, role
|
return guild, member, role
|
||||||
@ -85,7 +92,7 @@ def guild_objects():
|
|||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def command_context(guild_objects):
|
def command_context(guild_objects):
|
||||||
guild, member, role = guild_objects
|
guild, member, role = guild_objects
|
||||||
bot = DummyBot(guild)
|
bot = guild._client
|
||||||
message_data = {
|
message_data = {
|
||||||
"id": "10",
|
"id": "10",
|
||||||
"channel_id": "20",
|
"channel_id": "20",
|
||||||
@ -150,8 +157,9 @@ async def test_member_converter_no_guild():
|
|||||||
"premium_tier": PremiumTier.NONE.value,
|
"premium_tier": PremiumTier.NONE.value,
|
||||||
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||||
}
|
}
|
||||||
guild = Guild(guild_data, client_instance=None)
|
bot = DummyBot()
|
||||||
bot = DummyBot(guild)
|
guild = Guild(guild_data, client_instance=bot)
|
||||||
|
bot._guilds.set(guild.id, guild)
|
||||||
message_data = {
|
message_data = {
|
||||||
"id": "11",
|
"id": "11",
|
||||||
"channel_id": "20",
|
"channel_id": "20",
|
||||||
|
@ -14,12 +14,12 @@ from disagreement.enums import (
|
|||||||
from disagreement.permissions import Permissions
|
from disagreement.permissions import Permissions
|
||||||
|
|
||||||
|
|
||||||
class DummyClient:
|
from disagreement.client import Client
|
||||||
def __init__(self):
|
|
||||||
self._guilds = {}
|
|
||||||
|
|
||||||
def get_guild(self, gid):
|
|
||||||
return self._guilds.get(gid)
|
class DummyClient(Client):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(token="test")
|
||||||
|
|
||||||
|
|
||||||
def _base_guild(client):
|
def _base_guild(client):
|
||||||
@ -40,7 +40,7 @@ def _base_guild(client):
|
|||||||
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
|
||||||
}
|
}
|
||||||
guild = Guild(data, client_instance=client)
|
guild = Guild(data, client_instance=client)
|
||||||
client._guilds[guild.id] = guild
|
client._guilds.set(guild.id, guild)
|
||||||
return guild
|
return guild
|
||||||
|
|
||||||
|
|
||||||
@ -52,7 +52,7 @@ def _member(guild, *roles):
|
|||||||
}
|
}
|
||||||
member = Member(data, client_instance=None)
|
member = Member(data, client_instance=None)
|
||||||
member.guild_id = guild.id
|
member.guild_id = guild.id
|
||||||
guild._members[member.id] = member
|
guild._members.set(member.id, member)
|
||||||
return member
|
return member
|
||||||
|
|
||||||
|
|
||||||
@ -81,7 +81,7 @@ def _channel(guild, client):
|
|||||||
"permission_overwrites": [],
|
"permission_overwrites": [],
|
||||||
}
|
}
|
||||||
channel = TextChannel(data, client_instance=client)
|
channel = TextChannel(data, client_instance=client)
|
||||||
guild._channels[channel.id] = channel
|
guild._channels.set(channel.id, channel)
|
||||||
return channel
|
return channel
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,10 +5,14 @@ import pytest
|
|||||||
from disagreement.event_dispatcher import EventDispatcher
|
from disagreement.event_dispatcher import EventDispatcher
|
||||||
|
|
||||||
|
|
||||||
|
from disagreement.cache import Cache
|
||||||
|
|
||||||
|
|
||||||
class DummyClient:
|
class DummyClient:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.parsed = {}
|
self.parsed = {}
|
||||||
self._messages = {"1": "cached"}
|
self._messages = Cache()
|
||||||
|
self._messages.set("1", "cached")
|
||||||
|
|
||||||
def parse_message(self, data):
|
def parse_message(self, data):
|
||||||
self.parsed["message"] = True
|
self.parsed["message"] = True
|
||||||
@ -22,6 +26,11 @@ class DummyClient:
|
|||||||
self.parsed["channel"] = True
|
self.parsed["channel"] = True
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
def parse_message_delete(self, data):
|
||||||
|
message = self._messages.get(data["id"])
|
||||||
|
self._messages.invalidate(data["id"])
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_dispatch_calls_listener():
|
async def test_dispatch_calls_listener():
|
||||||
|
@ -73,10 +73,8 @@ async def test_delete_reaction_calls_http():
|
|||||||
async def test_get_reactions_parses_users():
|
async def test_get_reactions_parses_users():
|
||||||
users_payload = [{"id": "1", "username": "u", "discriminator": "0001"}]
|
users_payload = [{"id": "1", "username": "u", "discriminator": "0001"}]
|
||||||
http = SimpleNamespace(get_reactions=AsyncMock(return_value=users_payload))
|
http = SimpleNamespace(get_reactions=AsyncMock(return_value=users_payload))
|
||||||
client = Client.__new__(Client)
|
client = Client(token="test")
|
||||||
client._http = http
|
client._http = http
|
||||||
client._closed = False
|
|
||||||
client._users = {}
|
|
||||||
|
|
||||||
users = await client.get_reactions("1", "2", "😀")
|
users = await client.get_reactions("1", "2", "😀")
|
||||||
|
|
||||||
|
@ -12,9 +12,8 @@ async def test_textchannel_purge_calls_bulk_delete():
|
|||||||
request=AsyncMock(return_value=[{"id": "1"}, {"id": "2"}]),
|
request=AsyncMock(return_value=[{"id": "1"}, {"id": "2"}]),
|
||||||
bulk_delete_messages=AsyncMock(),
|
bulk_delete_messages=AsyncMock(),
|
||||||
)
|
)
|
||||||
client = Client.__new__(Client)
|
client = Client(token="test")
|
||||||
client._http = http
|
client._http = http
|
||||||
client._messages = {}
|
|
||||||
|
|
||||||
channel = TextChannel({"id": "c", "type": 0}, client)
|
channel = TextChannel({"id": "c", "type": 0}, client)
|
||||||
|
|
||||||
@ -33,9 +32,8 @@ async def test_textchannel_purge_before_param():
|
|||||||
request=AsyncMock(return_value=[]),
|
request=AsyncMock(return_value=[]),
|
||||||
bulk_delete_messages=AsyncMock(),
|
bulk_delete_messages=AsyncMock(),
|
||||||
)
|
)
|
||||||
client = Client.__new__(Client)
|
client = Client(token="test")
|
||||||
client._http = http
|
client._http = http
|
||||||
client._messages = {}
|
|
||||||
|
|
||||||
channel = TextChannel({"id": "c", "type": 0}, client)
|
channel = TextChannel({"id": "c", "type": 0}, client)
|
||||||
|
|
||||||
|
@ -3,6 +3,12 @@ import pytest
|
|||||||
|
|
||||||
from disagreement.voice_client import VoiceClient
|
from disagreement.voice_client import VoiceClient
|
||||||
from disagreement.audio import AudioSource
|
from disagreement.audio import AudioSource
|
||||||
|
from disagreement.client import Client
|
||||||
|
|
||||||
|
|
||||||
|
class DummyVoiceClient(Client):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(token="test")
|
||||||
|
|
||||||
|
|
||||||
class DummyWebSocket:
|
class DummyWebSocket:
|
||||||
@ -58,7 +64,16 @@ async def test_voice_client_handshake():
|
|||||||
ws = DummyWebSocket([hello, ready, session_desc])
|
ws = DummyWebSocket([hello, ready, session_desc])
|
||||||
udp = DummyUDP()
|
udp = DummyUDP()
|
||||||
|
|
||||||
vc = VoiceClient("ws://localhost", "sess", "tok", 1, 2, ws=ws, udp=udp)
|
vc = VoiceClient(
|
||||||
|
client=DummyVoiceClient(),
|
||||||
|
endpoint="ws://localhost",
|
||||||
|
session_id="sess",
|
||||||
|
token="tok",
|
||||||
|
guild_id=1,
|
||||||
|
user_id=2,
|
||||||
|
ws=ws,
|
||||||
|
udp=udp,
|
||||||
|
)
|
||||||
await vc.connect()
|
await vc.connect()
|
||||||
vc._heartbeat_task.cancel()
|
vc._heartbeat_task.cancel()
|
||||||
|
|
||||||
@ -78,7 +93,16 @@ async def test_send_audio_frame():
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
udp = DummyUDP()
|
udp = DummyUDP()
|
||||||
vc = VoiceClient("ws://localhost", "sess", "tok", 1, 2, ws=ws, udp=udp)
|
vc = VoiceClient(
|
||||||
|
client=DummyVoiceClient(),
|
||||||
|
endpoint="ws://localhost",
|
||||||
|
session_id="sess",
|
||||||
|
token="tok",
|
||||||
|
guild_id=1,
|
||||||
|
user_id=2,
|
||||||
|
ws=ws,
|
||||||
|
udp=udp,
|
||||||
|
)
|
||||||
await vc.connect()
|
await vc.connect()
|
||||||
vc._heartbeat_task.cancel()
|
vc._heartbeat_task.cancel()
|
||||||
|
|
||||||
@ -96,7 +120,16 @@ async def test_play_and_switch_sources():
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
udp = DummyUDP()
|
udp = DummyUDP()
|
||||||
vc = VoiceClient("ws://localhost", "sess", "tok", 1, 2, ws=ws, udp=udp)
|
vc = VoiceClient(
|
||||||
|
client=DummyVoiceClient(),
|
||||||
|
endpoint="ws://localhost",
|
||||||
|
session_id="sess",
|
||||||
|
token="tok",
|
||||||
|
guild_id=1,
|
||||||
|
user_id=2,
|
||||||
|
ws=ws,
|
||||||
|
udp=udp,
|
||||||
|
)
|
||||||
await vc.connect()
|
await vc.connect()
|
||||||
vc._heartbeat_task.cancel()
|
vc._heartbeat_task.cancel()
|
||||||
|
|
||||||
|
@ -94,16 +94,14 @@ async def test_client_create_webhook_returns_model():
|
|||||||
from disagreement.models import Webhook
|
from disagreement.models import Webhook
|
||||||
|
|
||||||
http = SimpleNamespace(create_webhook=AsyncMock(return_value={"id": "1"}))
|
http = SimpleNamespace(create_webhook=AsyncMock(return_value={"id": "1"}))
|
||||||
client = Client.__new__(Client)
|
client = Client(token="test")
|
||||||
client._http = http
|
client._http = http
|
||||||
client._closed = False
|
|
||||||
client._webhooks = {}
|
|
||||||
|
|
||||||
webhook = await client.create_webhook("123", {"name": "wh"})
|
webhook = await client.create_webhook("123", {"name": "wh"})
|
||||||
|
|
||||||
http.create_webhook.assert_awaited_once_with("123", {"name": "wh"})
|
http.create_webhook.assert_awaited_once_with("123", {"name": "wh"})
|
||||||
assert isinstance(webhook, Webhook)
|
assert isinstance(webhook, Webhook)
|
||||||
assert client._webhooks["1"] is webhook
|
assert client._webhooks.get("1") is webhook
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -113,16 +111,14 @@ async def test_client_edit_webhook_returns_model():
|
|||||||
from disagreement.models import Webhook
|
from disagreement.models import Webhook
|
||||||
|
|
||||||
http = SimpleNamespace(edit_webhook=AsyncMock(return_value={"id": "1"}))
|
http = SimpleNamespace(edit_webhook=AsyncMock(return_value={"id": "1"}))
|
||||||
client = Client.__new__(Client)
|
client = Client(token="test")
|
||||||
client._http = http
|
client._http = http
|
||||||
client._closed = False
|
|
||||||
client._webhooks = {}
|
|
||||||
|
|
||||||
webhook = await client.edit_webhook("1", {"name": "rename"})
|
webhook = await client.edit_webhook("1", {"name": "rename"})
|
||||||
|
|
||||||
http.edit_webhook.assert_awaited_once_with("1", {"name": "rename"})
|
http.edit_webhook.assert_awaited_once_with("1", {"name": "rename"})
|
||||||
assert isinstance(webhook, Webhook)
|
assert isinstance(webhook, Webhook)
|
||||||
assert client._webhooks["1"] is webhook
|
assert client._webhooks.get("1") is webhook
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -131,9 +127,8 @@ async def test_client_delete_webhook_calls_http():
|
|||||||
from disagreement.client import Client
|
from disagreement.client import Client
|
||||||
|
|
||||||
http = SimpleNamespace(delete_webhook=AsyncMock())
|
http = SimpleNamespace(delete_webhook=AsyncMock())
|
||||||
client = Client.__new__(Client)
|
client = Client(token="test")
|
||||||
client._http = http
|
client._http = http
|
||||||
client._closed = False
|
|
||||||
|
|
||||||
await client.delete_webhook("1")
|
await client.delete_webhook("1")
|
||||||
|
|
||||||
@ -181,10 +176,8 @@ async def test_webhook_send_uses_http():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
client = Client.__new__(Client)
|
client = Client(token="test")
|
||||||
client._http = http
|
client._http = http
|
||||||
client._messages = {}
|
|
||||||
client._webhooks = {}
|
|
||||||
|
|
||||||
webhook = Webhook({"id": "1", "token": "tok"}, client_instance=client)
|
webhook = Webhook({"id": "1", "token": "tok"}, client_instance=client)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user