diff --git a/disagreement/__init__.py b/disagreement/__init__.py index 25307b9..327cb86 100644 --- a/disagreement/__init__.py +++ b/disagreement/__init__.py @@ -17,7 +17,7 @@ __copyright__ = "Copyright 2025 Slipstream" __version__ = "0.0.2" from .client import Client -from .models import Message, User +from .models import Message, User, Reaction from .voice_client import VoiceClient from .typing import Typing from .errors import ( diff --git a/disagreement/client.py b/disagreement/client.py index df58c8a..f2da1d4 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -922,6 +922,16 @@ class Client: await self._http.create_reaction(channel_id, message_id, emoji) + user_id = getattr(getattr(self, "user", None), "id", None) + payload = { + "user_id": user_id, + "channel_id": channel_id, + "message_id": message_id, + "emoji": {"name": emoji, "id": None}, + } + if hasattr(self, "_event_dispatcher"): + await self._event_dispatcher.dispatch("MESSAGE_REACTION_ADD", payload) + async def delete_reaction( self, channel_id: str, message_id: str, emoji: str ) -> None: @@ -932,6 +942,16 @@ class Client: await self._http.delete_reaction(channel_id, message_id, emoji) + user_id = getattr(getattr(self, "user", None), "id", None) + payload = { + "user_id": user_id, + "channel_id": channel_id, + "message_id": message_id, + "emoji": {"name": emoji, "id": None}, + } + if hasattr(self, "_event_dispatcher"): + await self._event_dispatcher.dispatch("MESSAGE_REACTION_REMOVE", payload) + async def get_reactions( self, channel_id: str, message_id: str, emoji: str ) -> List["User"]: diff --git a/disagreement/event_dispatcher.py b/disagreement/event_dispatcher.py index 7938d59..0cfc779 100644 --- a/disagreement/event_dispatcher.py +++ b/disagreement/event_dispatcher.py @@ -52,6 +52,8 @@ class EventDispatcher: "CHANNEL_CREATE": self._parse_channel_create, "PRESENCE_UPDATE": self._parse_presence_update, "TYPING_START": self._parse_typing_start, + "MESSAGE_REACTION_ADD": self._parse_message_reaction, + "MESSAGE_REACTION_REMOVE": self._parse_message_reaction, } def _parse_message_create(self, data: Dict[str, Any]) -> Message: @@ -88,6 +90,13 @@ class EventDispatcher: return TypingStart(data, client_instance=self._client) + def _parse_message_reaction(self, data: Dict[str, Any]): + """Parses raw reaction data into a Reaction object.""" + + from .models import Reaction + + return Reaction(data, client_instance=self._client) + # Potentially add _parse_user for events that directly provide a full user object # def _parse_user_update(self, data: Dict[str, Any]) -> User: # return User(data=data) diff --git a/disagreement/models.py b/disagreement/models.py index 698d134..2e14ca0 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -1614,6 +1614,27 @@ class TypingStart: return f"" +class Reaction: + """Represents a message reaction event.""" + + def __init__( + self, data: Dict[str, Any], client_instance: Optional["Client"] = None + ): + self._client = client_instance + self.user_id: str = data["user_id"] + self.channel_id: str = data["channel_id"] + self.message_id: str = data["message_id"] + self.guild_id: Optional[str] = data.get("guild_id") + self.member: Optional[Member] = ( + Member(data["member"], client_instance) if data.get("member") else None + ) + self.emoji: Dict[str, Any] = data.get("emoji", {}) + + def __repr__(self) -> str: + emoji_value = self.emoji.get("name") or self.emoji.get("id") + return f"" + + def channel_factory(data: Dict[str, Any], client: "Client") -> Channel: """Create a channel object from raw API data.""" channel_type = data.get("type") diff --git a/tests/test_http_reactions.py b/tests/test_http_reactions.py index 101875f..e1c0914 100644 --- a/tests/test_http_reactions.py +++ b/tests/test_http_reactions.py @@ -4,7 +4,7 @@ from unittest.mock import AsyncMock from disagreement.client import Client from disagreement.errors import DisagreementException -from disagreement.models import User +from disagreement.models import User, Reaction @pytest.mark.asyncio @@ -55,3 +55,37 @@ async def test_get_reactions_parses_users(): http.get_reactions.assert_called_once_with("1", "2", "😀") assert isinstance(users[0], User) + + +@pytest.mark.asyncio +async def test_create_reaction_dispatches_event(monkeypatch): + http = SimpleNamespace(create_reaction=AsyncMock()) + client = Client(token="t") + client._http = http + events = {} + + async def on_add(reaction): + events["add"] = reaction + + client._event_dispatcher.register("MESSAGE_REACTION_ADD", on_add) + + await client.create_reaction("1", "2", "😀") + + assert isinstance(events.get("add"), Reaction) + + +@pytest.mark.asyncio +async def test_delete_reaction_dispatches_event(monkeypatch): + http = SimpleNamespace(delete_reaction=AsyncMock()) + client = Client(token="t") + client._http = http + events = {} + + async def on_remove(reaction): + events["remove"] = reaction + + client._event_dispatcher.register("MESSAGE_REACTION_REMOVE", on_remove) + + await client.delete_reaction("1", "2", "😀") + + assert isinstance(events.get("remove"), Reaction) diff --git a/tests/test_reactions.py b/tests/test_reactions.py index b73ea37..c40dac9 100644 --- a/tests/test_reactions.py +++ b/tests/test_reactions.py @@ -1,19 +1,19 @@ import pytest from disagreement.event_dispatcher import EventDispatcher +from disagreement.models import Reaction @pytest.mark.asyncio async def test_reaction_payload(): - # This test now checks the raw payload dictionary, as the Reaction model is removed. data = { "user_id": "1", "channel_id": "2", "message_id": "3", "emoji": {"name": "😀", "id": None}, } - # The "reaction" is just the data dictionary itself. - assert data["user_id"] == "1" - assert data["emoji"]["name"] == "😀" + reaction = Reaction(data) + assert reaction.user_id == "1" + assert reaction.emoji["name"] == "😀" @pytest.mark.asyncio @@ -21,7 +21,7 @@ async def test_dispatch_reaction_event(dummy_client): dispatcher = EventDispatcher(dummy_client) captured = [] - async def listener(payload: dict): + async def listener(payload: Reaction): captured.append(payload) # The event name is now MESSAGE_REACTION_ADD as per the original test setup. @@ -35,4 +35,4 @@ async def test_dispatch_reaction_event(dummy_client): } await dispatcher.dispatch("MESSAGE_REACTION_ADD", payload) assert len(captured) == 1 - assert isinstance(captured[0], dict) + assert isinstance(captured[0], Reaction)