disagreement/tests/test_context.py
2025-06-09 22:25:14 -06:00

200 lines
6.9 KiB
Python

import asyncio
from unittest.mock import AsyncMock
import pytest
from disagreement.ext.app_commands.context import AppCommandContext
from disagreement.interactions import Interaction
from disagreement.enums import InteractionType, MessageFlags
from disagreement.models import (
Embed,
ActionRow,
Button,
Container,
TextDisplay,
)
from disagreement.enums import ButtonStyle, ComponentType
class DummyHTTP:
    """Stand-in HTTP client whose interaction endpoints are ``AsyncMock`` objects.

    ``create_interaction_response`` has no configured return value; the
    follow-up and edit endpoints each resolve to ``{"id": "123"}``.
    """

    def __init__(self) -> None:
        self.create_interaction_response = AsyncMock()
        # Every endpoint gets its own mock and its own freshly built dict so
        # that nothing is shared between the stubs.
        for endpoint in (
            "create_followup_message",
            "edit_original_interaction_response",
            "edit_followup_message",
        ):
            setattr(self, endpoint, AsyncMock(return_value={"id": "123"}))
class DummyBot:
    """Minimal bot double exposing only what these tests touch.

    Holds a ``DummyHTTP`` instance, a fixed application id and two initially
    empty lookup tables for guilds and channels.
    """

    def __init__(self) -> None:
        self.application_id = "app123"
        self._guilds = {}
        self._channels = {}
        self._http = DummyHTTP()

    def get_guild(self, gid):
        """Return the guild stored under ``gid``, or ``None`` when absent."""
        guilds = self._guilds
        return guilds.get(gid)

    async def fetch_channel(self, cid):
        """Return the channel stored under ``cid``, or ``None`` when absent."""
        channels = self._channels
        return channels.get(cid)
from disagreement.ext.commands.core import CommandContext, Command
from disagreement.enums import MessageFlags, ButtonStyle, ComponentType
from disagreement.models import ActionRow, Button, Container, TextDisplay
@pytest.mark.asyncio
async def test_sends_extra_payload(dummy_bot, interaction):
    """Check that ``send`` passes its optional fields into the response payload.

    The first ``send`` must call ``create_interaction_response`` exactly once
    with tts, components, allowed mentions, attachments and the ephemeral flag
    present in the serialized data. A subsequent ``send_followup`` must call
    ``create_followup_message`` exactly once.
    """
    ctx = AppCommandContext(dummy_bot, interaction)
    http = dummy_bot._http

    click_button = Button(style=ButtonStyle.PRIMARY, label="Click", custom_id="a")
    await ctx.send(
        content="hi",
        tts=True,
        components=[ActionRow([click_button])],
        allowed_mentions={"parse": []},
        files=[{"id": 1, "filename": "f.txt"}],
        ephemeral=True,
    )

    http.create_interaction_response.assert_called_once()
    response = http.create_interaction_response.call_args.kwargs["payload"]
    sent = response.data.to_dict()
    assert sent["tts"] is True
    assert sent["components"]
    assert sent["allowed_mentions"] == {"parse": []}
    assert sent["attachments"] == [{"id": 1, "filename": "f.txt"}]
    assert sent["flags"] == MessageFlags.EPHEMERAL.value

    await ctx.send_followup(content="again")
    http.create_followup_message.assert_called_once()
@pytest.mark.asyncio
async def test_second_send_is_followup(dummy_bot, interaction):
    """One ``send`` plus one ``send_followup`` hit each endpoint exactly once."""
    ctx = AppCommandContext(dummy_bot, interaction)
    http = dummy_bot._http

    await ctx.send(content="first")
    await ctx.send_followup(content="second")

    initial_calls = http.create_interaction_response.call_count
    followup_calls = http.create_followup_message.call_count
    assert initial_calls == 1
    assert followup_calls == 1
@pytest.mark.asyncio
async def test_edit_with_components_and_attachments(dummy_bot, interaction):
    """``edit`` forwards components and attachments to the edit-original endpoint."""
    ctx = AppCommandContext(dummy_bot, interaction)
    await ctx.send(content="orig")

    button = Button(style=ButtonStyle.PRIMARY, label="B", custom_id="b")
    await ctx.edit(
        content="new",
        components=[ActionRow([button])],
        attachments=[{"id": 1}],
    )

    edit_mock = dummy_bot._http.edit_original_interaction_response
    edit_mock.assert_called_once()
    # Here the payload is indexed directly, without any serialization step.
    edited = edit_mock.call_args.kwargs["payload"]
    assert edited["components"]
    assert edited["attachments"] == [{"id": 1}]
@pytest.mark.asyncio
async def test_send_with_flags(dummy_bot, interaction):
    """An explicit ``flags`` value passed to ``send`` appears in the payload."""
    expected_flags = MessageFlags.IS_COMPONENTS_V2.value
    ctx = AppCommandContext(dummy_bot, interaction)

    await ctx.send(content="hi", flags=expected_flags)

    call = dummy_bot._http.create_interaction_response.call_args
    sent = call.kwargs["payload"].data.to_dict()
    assert sent["flags"] == MessageFlags.IS_COMPONENTS_V2.value
@pytest.mark.asyncio
async def test_send_container_component(dummy_bot, interaction):
    """A ``Container`` sent as a component is serialized with the container type."""
    ctx = AppCommandContext(dummy_bot, interaction)
    text = TextDisplay(content="hi")
    box = Container(components=[text])

    await ctx.send(
        components=[box],
        flags=MessageFlags.IS_COMPONENTS_V2.value,
    )

    call = dummy_bot._http.create_interaction_response.call_args
    sent = call.kwargs["payload"].data.to_dict()
    first_component = sent["components"][0]
    assert first_component["type"] == ComponentType.CONTAINER.value
    assert sent["flags"] == MessageFlags.IS_COMPONENTS_V2.value
@pytest.mark.asyncio
async def test_command_context_edit(command_bot, message):
    """``CommandContext.edit`` calls ``edit_message`` with channel id, message id and body."""

    # No-op callback; name and signature are kept exactly as the command sees them.
    async def dummy(ctx):
        pass

    ctx = CommandContext(
        message=message,
        bot=command_bot,
        prefix="!",
        command=Command(dummy),
        invoked_with="dummy",
    )

    await ctx.edit(message.id, content="new")

    edit_mock = command_bot._http.edit_message
    edit_mock.assert_called_once()
    positional = edit_mock.call_args.args
    assert positional[0] == message.channel_id
    assert positional[1] == message.id
    assert positional[2]["content"] == "new"
@pytest.mark.asyncio
async def test_send_http_error_propagates(dummy_bot, interaction):
    """A ``RuntimeError`` raised by the HTTP layer surfaces from ``send``."""
    ctx = AppCommandContext(dummy_bot, interaction)
    failing_endpoint = dummy_bot._http.create_interaction_response
    failing_endpoint.side_effect = RuntimeError("boom")

    with pytest.raises(RuntimeError):
        await ctx.send(content="hi")
@pytest.mark.asyncio
async def test_concurrent_send_only_initial_once(dummy_bot, interaction):
    """Gathering one ``send`` and 49 ``send_followup`` calls hits each endpoint as expected."""
    ctx = AppCommandContext(dummy_bot, interaction)

    async def dispatch(index: int):
        # Only index 0 issues the initial response; the rest are follow-ups.
        sender = ctx.send if index == 0 else ctx.send_followup
        await sender(content=str(index))

    pending = [dispatch(index) for index in range(50)]
    await asyncio.gather(*pending)

    http = dummy_bot._http
    assert http.create_interaction_response.call_count == 1
    assert http.create_followup_message.call_count == 49
@pytest.mark.asyncio
async def test_send_with_flags_2():
    """Fixture-free variant: ``flags`` given to ``send`` appears in the payload."""
    bot = DummyBot()
    raw_interaction = {
        "id": "1",
        "application_id": bot.application_id,
        "type": InteractionType.APPLICATION_COMMAND.value,
        "token": "tok",
        "version": 1,
    }
    interaction = Interaction(raw_interaction, client_instance=bot)
    ctx = AppCommandContext(bot, interaction)

    await ctx.send(content="hi", flags=MessageFlags.IS_COMPONENTS_V2.value)

    call = bot._http.create_interaction_response.call_args
    sent = call.kwargs["payload"].data.to_dict()
    assert sent["flags"] == MessageFlags.IS_COMPONENTS_V2.value
@pytest.mark.asyncio
async def test_send_container_component_2():
    """Fixture-free variant: a ``Container`` component is serialized with the container type."""
    bot = DummyBot()
    raw_interaction = {
        "id": "1",
        "application_id": bot.application_id,
        "type": InteractionType.APPLICATION_COMMAND.value,
        "token": "tok",
        "version": 1,
    }
    interaction = Interaction(raw_interaction, client_instance=bot)
    ctx = AppCommandContext(bot, interaction)

    text = TextDisplay(content="hi")
    box = Container(components=[text])
    await ctx.send(components=[box], flags=MessageFlags.IS_COMPONENTS_V2.value)

    call = bot._http.create_interaction_response.call_args
    sent = call.kwargs["payload"].data.to_dict()
    first_component = sent["components"][0]
    assert first_component["type"] == ComponentType.CONTAINER.value
    assert sent["flags"] == MessageFlags.IS_COMPONENTS_V2.value