diff --git a/disagreement/__init__.py b/disagreement/__init__.py index bfd7d3b..b6252c1 100644 --- a/disagreement/__init__.py +++ b/disagreement/__init__.py @@ -51,7 +51,7 @@ from .errors import ( NotFound, ) from .color import Color -from .utils import utcnow, message_pager +from .utils import escape_markdown, escape_mentions, message_pager, utcnow from .enums import ( GatewayIntent, GatewayOpcode, @@ -150,6 +150,8 @@ __all__ = [ "NotFound", "Color", "utcnow", + "escape_markdown", + "escape_mentions", "message_pager", "GatewayIntent", "GatewayOpcode", diff --git a/disagreement/utils.py b/disagreement/utils.py index c9a5fe3..ab167e0 100644 --- a/disagreement/utils.py +++ b/disagreement/utils.py @@ -4,6 +4,7 @@ from __future__ import annotations from datetime import datetime, timezone from typing import Any, AsyncIterator, Dict, Optional, TYPE_CHECKING +import re if TYPE_CHECKING: # pragma: no cover - for type hinting only from .models import Message, TextChannel @@ -110,3 +111,39 @@ class Paginator: if self._current: pages.append(self._current) return pages + + +def escape_markdown(text: str) -> str: + """Escape Discord markdown formatting in ``text``. + + Parameters + ---------- + text: + The text to escape. + + Returns + ------- + str + The escaped text with Discord formatting characters preceded by a + backslash. + """ + + return re.sub(r"([\\*_~`>|])", r"\\\1", text) + + +def escape_mentions(text: str) -> str: + """Escape Discord mentions in ``text``. + + Parameters + ---------- + text: + The text in which to escape mentions. + + Returns + ------- + str + The text with ``@`` characters replaced by ``@\u200b`` to prevent + unintended mentions. + """ + + return text.replace("@", "@\u200b") diff --git a/tests/test_utils.py b/tests/test_utils.py index cac52d4..4022506 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,8 +1,23 @@ from datetime import timezone -from disagreement.utils import utcnow +from disagreement.utils import escape_markdown, escape_mentions, utcnow def test_utcnow_timezone(): now = utcnow() assert now.tzinfo == timezone.utc + + +def test_escape_markdown(): + text = "**bold** _under_ ~strike~ `code` > quote | pipe" + escaped = escape_markdown(text) + assert ( + escaped + == "\\*\\*bold\\*\\* \\_under\\_ \\~strike\\~ \\`code\\` \\> quote \\| pipe" + ) + + +def test_escape_mentions(): + text = "Hello @everyone and <@123>!" + escaped = escape_mentions(text) + assert escaped == "Hello @\u200beveryone and <@\u200b123>!"