From f5f8f6908cd4ce39e9e90912e4d8c41506e3e425 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Sun, 15 Jun 2025 18:55:17 -0600 Subject: [PATCH] Add get and find helpers (#103) --- disagreement/__init__.py | 12 +++++- disagreement/utils.py | 86 ++++++++++------------------------------ tests/test_utils.py | 21 +++++++++- 3 files changed, 52 insertions(+), 67 deletions(-) diff --git a/disagreement/__init__.py b/disagreement/__init__.py index 4fcac77..ca28c17 100644 --- a/disagreement/__init__.py +++ b/disagreement/__init__.py @@ -51,7 +51,15 @@ from .errors import ( NotFound, ) from .color import Color -from .utils import escape_markdown, escape_mentions, message_pager, utcnow, snowflake_time +from .utils import ( + utcnow, + message_pager, + get, + find, + escape_markdown, + escape_mentions, + snowflake_time, +) from .enums import ( GatewayIntent, GatewayOpcode, @@ -153,6 +161,8 @@ __all__ = [ "escape_markdown", "escape_mentions", "message_pager", + "get", + "find", "snowflake_time", "GatewayIntent", "GatewayOpcode", diff --git a/disagreement/utils.py b/disagreement/utils.py index 6648086..f797814 100644 --- a/disagreement/utils.py +++ b/disagreement/utils.py @@ -3,7 +3,7 @@ from __future__ import annotations from datetime import datetime, timezone -from typing import Any, AsyncIterator, Dict, Optional, TYPE_CHECKING +from typing import Any, AsyncIterator, Dict, Iterable, Optional, TYPE_CHECKING, Callable import re # Discord epoch in milliseconds (2015-01-01T00:00:00Z) @@ -18,20 +18,23 @@ def utcnow() -> datetime: return datetime.now(timezone.utc) +def find(predicate: Callable[[Any], bool], iterable: Iterable[Any]) -> Optional[Any]: + """Return the first element in ``iterable`` matching the ``predicate``.""" + for element in iterable: + if predicate(element): + return element + return None + + +def get(iterable: Iterable[Any], **attrs: Any) -> Optional[Any]: + """Return the first element with matching attributes.""" + def predicate(elem: Any) -> bool: + return all(getattr(elem, attr, None) == value for attr, value in attrs.items()) + return find(predicate, iterable) + + def snowflake_time(snowflake: int) -> datetime: - """Return the creation time of a Discord snowflake. - - Parameters - ---------- - snowflake: - The snowflake ID to convert. - - Returns - ------- - datetime - The UTC timestamp embedded in the snowflake. - """ - + """Return the creation time of a Discord snowflake.""" timestamp_ms = (snowflake >> 22) + DISCORD_EPOCH return datetime.fromtimestamp(timestamp_ms / 1000, tz=timezone.utc) @@ -43,32 +46,11 @@ async def message_pager( before: Optional[str] = None, after: Optional[str] = None, ) -> AsyncIterator["Message"]: - """Asynchronously paginate a channel's messages. - - Parameters - ---------- - channel: - The :class:`TextChannel` to fetch messages from. - limit: - The maximum number of messages to yield. ``None`` fetches until no - more messages are returned. - before: - Fetch messages with IDs less than this snowflake. - after: - Fetch messages with IDs greater than this snowflake. - - Yields - ------ - Message - Messages in the channel, oldest first. - """ - + """Asynchronously paginate a channel's messages.""" remaining = limit last_id = before while remaining is None or remaining > 0: - fetch_limit = 100 - if remaining is not None: - fetch_limit = min(fetch_limit, remaining) + fetch_limit = min(100, remaining) if remaining is not None else 100 params: Dict[str, Any] = {"limit": fetch_limit} if last_id is not None: @@ -135,36 +117,10 @@ class Paginator: def escape_markdown(text: str) -> str: - """Escape Discord markdown formatting in ``text``. - - Parameters - ---------- - text: - The text to escape. - - Returns - ------- - str - The escaped text with Discord formatting characters preceded by a - backslash. - """ - + """Escape Discord markdown formatting in ``text``.""" return re.sub(r"([\\*_~`>|])", r"\\\1", text) def escape_mentions(text: str) -> str: - """Escape Discord mentions in ``text``. - - Parameters - ---------- - text: - The text in which to escape mentions. - - Returns - ------- - str - The text with ``@`` characters replaced by ``@\u200b`` to prevent - unintended mentions. - """ - + """Escape Discord mentions in ``text``.""" return text.replace("@", "@\u200b") diff --git a/tests/test_utils.py b/tests/test_utils.py index bc9c077..82dcb83 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,10 +1,13 @@ from datetime import datetime, timezone +from types import SimpleNamespace from disagreement.utils import ( escape_markdown, escape_mentions, utcnow, - snowflake_time + snowflake_time, + find, + get ) @@ -13,6 +16,22 @@ def test_utcnow_timezone(): assert now.tzinfo == timezone.utc +def test_find_returns_matching_element(): + seq = [1, 2, 3] + assert find(lambda x: x > 1, seq) == 2 + assert find(lambda x: x > 3, seq) is None + + +def test_get_matches_attributes(): + items = [ + SimpleNamespace(id=1, name="a"), + SimpleNamespace(id=2, name="b"), + ] + assert get(items, id=2) is items[1] + assert get(items, id=1, name="a") is items[0] + assert get(items, name="c") is None + + def test_snowflake_time(): dt = datetime(2020, 1, 1, tzinfo=timezone.utc) ms = int(dt.timestamp() * 1000) - 1420070400000