diff --git a/disagreement/client.py b/disagreement/client.py index 23919eb..6ed4462 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -109,7 +109,14 @@ class Client: member_cache_flags if member_cache_flags is not None else MemberCacheFlags() ) self.intents: int = intents if intents is not None else GatewayIntent.default() - self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() + if loop: + self.loop: asyncio.AbstractEventLoop = loop + else: + try: + self.loop = asyncio.get_running_loop() + except RuntimeError: + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) self.application_id: Optional[Snowflake] = ( str(application_id) if application_id else None ) diff --git a/disagreement/ext/commands/core.py b/disagreement/ext/commands/core.py index b9f48c2..5a98997 100644 --- a/disagreement/ext/commands/core.py +++ b/disagreement/ext/commands/core.py @@ -42,7 +42,7 @@ if TYPE_CHECKING: class GroupMixin: def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + super().__init__() self.commands: Dict[str, "Command"] = {} self.name: str = "" diff --git a/disagreement/models.py b/disagreement/models.py index aa18424..9323415 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -10,6 +10,7 @@ from dataclasses import dataclass from typing import Any, AsyncIterator, Dict, List, Optional, TYPE_CHECKING, Union, cast from .cache import ChannelCache, MemberCache +from .caching import MemberCacheFlags import aiohttp # pylint: disable=import-error from .color import Color @@ -1087,7 +1088,7 @@ class Guild: # Internal caches, populated by events or specific fetches self._channels: ChannelCache = ChannelCache() - self._members: MemberCache = MemberCache(client_instance.member_cache_flags) + self._members: MemberCache = MemberCache(getattr(client_instance, "member_cache_flags", MemberCacheFlags())) self._threads: Dict[str, "Thread"] = {} def get_channel(self, channel_id: str) -> Optional["Channel"]: diff --git a/tests/test_additional_converters.py b/tests/test_additional_converters.py index 88f4975..40da104 100644 --- a/tests/test_additional_converters.py +++ b/tests/test_additional_converters.py @@ -14,23 +14,28 @@ from disagreement.enums import ( ) -class DummyBot: - def __init__(self, guild: Guild): - self._guilds = {guild.id: guild} +from disagreement.client import Client +from disagreement.cache import GuildCache - def get_guild(self, gid): - return self._guilds.get(gid) - async def fetch_member(self, gid, mid): - guild = self._guilds.get(gid) - return guild.get_member(mid) if guild else None +class DummyBot(Client): + def __init__(self): + super().__init__(token="test") + self._guilds = GuildCache() - async def fetch_role(self, gid, rid): - guild = self._guilds.get(gid) - return guild.get_role(rid) if guild else None + def get_guild(self, guild_id): + return self._guilds.get(guild_id) - async def fetch_guild(self, gid): - return self._guilds.get(gid) + async def fetch_member(self, guild_id, member_id): + guild = self._guilds.get(guild_id) + return guild.get_member(member_id) if guild else None + + async def fetch_role(self, guild_id, role_id): + guild = self._guilds.get(guild_id) + return guild.get_role(role_id) if guild else None + + async def fetch_guild(self, guild_id): + return self._guilds.get(guild_id) @pytest.fixture() @@ -51,7 +56,9 @@ def guild_objects(): "premium_tier": PremiumTier.NONE.value, "nsfw_level": GuildNSFWLevel.DEFAULT.value, } - guild = Guild(guild_data, client_instance=None) + bot = DummyBot() + guild = Guild(guild_data, client_instance=bot) + bot._guilds.set(guild.id, guild) member = Member( { @@ -76,7 +83,7 @@ def guild_objects(): } ) - guild._members[member.id] = member + guild._members.set(member.id, member) guild.roles.append(role) return guild, member, role @@ -85,7 +92,7 @@ def guild_objects(): @pytest.fixture() def command_context(guild_objects): guild, member, role = guild_objects - bot = DummyBot(guild) + bot = guild._client message_data = { "id": "10", "channel_id": "20", @@ -150,8 +157,9 @@ async def test_member_converter_no_guild(): "premium_tier": PremiumTier.NONE.value, "nsfw_level": GuildNSFWLevel.DEFAULT.value, } - guild = Guild(guild_data, client_instance=None) - bot = DummyBot(guild) + bot = DummyBot() + guild = Guild(guild_data, client_instance=bot) + bot._guilds.set(guild.id, guild) message_data = { "id": "11", "channel_id": "20", diff --git a/tests/test_channel_permissions.py b/tests/test_channel_permissions.py index 0d1024a..2ee9186 100644 --- a/tests/test_channel_permissions.py +++ b/tests/test_channel_permissions.py @@ -14,12 +14,12 @@ from disagreement.enums import ( from disagreement.permissions import Permissions -class DummyClient: - def __init__(self): - self._guilds = {} +from disagreement.client import Client - def get_guild(self, gid): - return self._guilds.get(gid) + +class DummyClient(Client): + def __init__(self): + super().__init__(token="test") def _base_guild(client): @@ -40,7 +40,7 @@ def _base_guild(client): "nsfw_level": GuildNSFWLevel.DEFAULT.value, } guild = Guild(data, client_instance=client) - client._guilds[guild.id] = guild + client._guilds.set(guild.id, guild) return guild @@ -52,7 +52,7 @@ def _member(guild, *roles): } member = Member(data, client_instance=None) member.guild_id = guild.id - guild._members[member.id] = member + guild._members.set(member.id, member) return member @@ -81,7 +81,7 @@ def _channel(guild, client): "permission_overwrites": [], } channel = TextChannel(data, client_instance=client) - guild._channels[channel.id] = channel + guild._channels.set(channel.id, channel) return channel diff --git a/tests/test_event_dispatcher.py b/tests/test_event_dispatcher.py index 8558345..289881a 100644 --- a/tests/test_event_dispatcher.py +++ b/tests/test_event_dispatcher.py @@ -5,10 +5,14 @@ import pytest from disagreement.event_dispatcher import EventDispatcher +from disagreement.cache import Cache + + class DummyClient: def __init__(self): self.parsed = {} - self._messages = {"1": "cached"} + self._messages = Cache() + self._messages.set("1", "cached") def parse_message(self, data): self.parsed["message"] = True @@ -22,6 +26,11 @@ class DummyClient: self.parsed["channel"] = True return data + def parse_message_delete(self, data): + message = self._messages.get(data["id"]) + self._messages.invalidate(data["id"]) + return message + @pytest.mark.asyncio async def test_dispatch_calls_listener(): diff --git a/tests/test_http_reactions.py b/tests/test_http_reactions.py index 02b43a0..730062a 100644 --- a/tests/test_http_reactions.py +++ b/tests/test_http_reactions.py @@ -73,10 +73,8 @@ async def test_delete_reaction_calls_http(): async def test_get_reactions_parses_users(): users_payload = [{"id": "1", "username": "u", "discriminator": "0001"}] http = SimpleNamespace(get_reactions=AsyncMock(return_value=users_payload)) - client = Client.__new__(Client) + client = Client(token="test") client._http = http - client._closed = False - client._users = {} users = await client.get_reactions("1", "2", "😀") diff --git a/tests/test_textchannel_purge.py b/tests/test_textchannel_purge.py index 1be072a..172cc76 100644 --- a/tests/test_textchannel_purge.py +++ b/tests/test_textchannel_purge.py @@ -12,9 +12,8 @@ async def test_textchannel_purge_calls_bulk_delete(): request=AsyncMock(return_value=[{"id": "1"}, {"id": "2"}]), bulk_delete_messages=AsyncMock(), ) - client = Client.__new__(Client) + client = Client(token="test") client._http = http - client._messages = {} channel = TextChannel({"id": "c", "type": 0}, client) @@ -33,9 +32,8 @@ async def test_textchannel_purge_before_param(): request=AsyncMock(return_value=[]), bulk_delete_messages=AsyncMock(), ) - client = Client.__new__(Client) + client = Client(token="test") client._http = http - client._messages = {} channel = TextChannel({"id": "c", "type": 0}, client) diff --git a/tests/test_voice_client.py b/tests/test_voice_client.py index fae0289..da052bc 100644 --- a/tests/test_voice_client.py +++ b/tests/test_voice_client.py @@ -3,6 +3,12 @@ import pytest from disagreement.voice_client import VoiceClient from disagreement.audio import AudioSource +from disagreement.client import Client + + +class DummyVoiceClient(Client): + def __init__(self): + super().__init__(token="test") class DummyWebSocket: @@ -58,7 +64,16 @@ async def test_voice_client_handshake(): ws = DummyWebSocket([hello, ready, session_desc]) udp = DummyUDP() - vc = VoiceClient("ws://localhost", "sess", "tok", 1, 2, ws=ws, udp=udp) + vc = VoiceClient( + client=DummyVoiceClient(), + endpoint="ws://localhost", + session_id="sess", + token="tok", + guild_id=1, + user_id=2, + ws=ws, + udp=udp, + ) await vc.connect() vc._heartbeat_task.cancel() @@ -78,7 +93,16 @@ async def test_send_audio_frame(): ] ) udp = DummyUDP() - vc = VoiceClient("ws://localhost", "sess", "tok", 1, 2, ws=ws, udp=udp) + vc = VoiceClient( + client=DummyVoiceClient(), + endpoint="ws://localhost", + session_id="sess", + token="tok", + guild_id=1, + user_id=2, + ws=ws, + udp=udp, + ) await vc.connect() vc._heartbeat_task.cancel() @@ -96,7 +120,16 @@ async def test_play_and_switch_sources(): ] ) udp = DummyUDP() - vc = VoiceClient("ws://localhost", "sess", "tok", 1, 2, ws=ws, udp=udp) + vc = VoiceClient( + client=DummyVoiceClient(), + endpoint="ws://localhost", + session_id="sess", + token="tok", + guild_id=1, + user_id=2, + ws=ws, + udp=udp, + ) await vc.connect() vc._heartbeat_task.cancel() diff --git a/tests/test_webhooks.py b/tests/test_webhooks.py index 129f658..fa82803 100644 --- a/tests/test_webhooks.py +++ b/tests/test_webhooks.py @@ -94,16 +94,14 @@ async def test_client_create_webhook_returns_model(): from disagreement.models import Webhook http = SimpleNamespace(create_webhook=AsyncMock(return_value={"id": "1"})) - client = Client.__new__(Client) + client = Client(token="test") client._http = http - client._closed = False - client._webhooks = {} webhook = await client.create_webhook("123", {"name": "wh"}) http.create_webhook.assert_awaited_once_with("123", {"name": "wh"}) assert isinstance(webhook, Webhook) - assert client._webhooks["1"] is webhook + assert client._webhooks.get("1") is webhook @pytest.mark.asyncio @@ -113,16 +111,14 @@ async def test_client_edit_webhook_returns_model(): from disagreement.models import Webhook http = SimpleNamespace(edit_webhook=AsyncMock(return_value={"id": "1"})) - client = Client.__new__(Client) + client = Client(token="test") client._http = http - client._closed = False - client._webhooks = {} webhook = await client.edit_webhook("1", {"name": "rename"}) http.edit_webhook.assert_awaited_once_with("1", {"name": "rename"}) assert isinstance(webhook, Webhook) - assert client._webhooks["1"] is webhook + assert client._webhooks.get("1") is webhook @pytest.mark.asyncio @@ -131,9 +127,8 @@ async def test_client_delete_webhook_calls_http(): from disagreement.client import Client http = SimpleNamespace(delete_webhook=AsyncMock()) - client = Client.__new__(Client) + client = Client(token="test") client._http = http - client._closed = False await client.delete_webhook("1") @@ -181,10 +176,8 @@ async def test_webhook_send_uses_http(): } ) ) - client = Client.__new__(Client) + client = Client(token="test") client._http = http - client._messages = {} - client._webhooks = {} webhook = Webhook({"id": "1", "token": "tok"}, client_instance=client)