Add get and find helpers (#103)

This commit is contained in:
Slipstream 2025-06-15 18:55:17 -06:00 committed by GitHub
parent 3437050f0e
commit f5f8f6908c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 52 additions and 67 deletions

View File

@ -51,7 +51,15 @@ from .errors import (
NotFound,
)
from .color import Color
from .utils import escape_markdown, escape_mentions, message_pager, utcnow, snowflake_time
from .utils import (
utcnow,
message_pager,
get,
find,
escape_markdown,
escape_mentions,
snowflake_time,
)
from .enums import (
GatewayIntent,
GatewayOpcode,
@ -153,6 +161,8 @@ __all__ = [
"escape_markdown",
"escape_mentions",
"message_pager",
"get",
"find",
"snowflake_time",
"GatewayIntent",
"GatewayOpcode",

View File

@ -3,7 +3,7 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any, AsyncIterator, Dict, Optional, TYPE_CHECKING
from typing import Any, AsyncIterator, Dict, Iterable, Optional, TYPE_CHECKING, Callable
import re
# Discord epoch in milliseconds (2015-01-01T00:00:00Z)
@ -18,20 +18,23 @@ def utcnow() -> datetime:
return datetime.now(timezone.utc)
def find(predicate: Callable[[Any], bool], iterable: Iterable[Any]) -> Optional[Any]:
"""Return the first element in ``iterable`` matching the ``predicate``."""
for element in iterable:
if predicate(element):
return element
return None
def get(iterable: Iterable[Any], **attrs: Any) -> Optional[Any]:
"""Return the first element with matching attributes."""
def predicate(elem: Any) -> bool:
return all(getattr(elem, attr, None) == value for attr, value in attrs.items())
return find(predicate, iterable)
def snowflake_time(snowflake: int) -> datetime:
"""Return the creation time of a Discord snowflake.
Parameters
----------
snowflake:
The snowflake ID to convert.
Returns
-------
datetime
The UTC timestamp embedded in the snowflake.
"""
"""Return the creation time of a Discord snowflake."""
timestamp_ms = (snowflake >> 22) + DISCORD_EPOCH
return datetime.fromtimestamp(timestamp_ms / 1000, tz=timezone.utc)
@ -43,32 +46,11 @@ async def message_pager(
before: Optional[str] = None,
after: Optional[str] = None,
) -> AsyncIterator["Message"]:
"""Asynchronously paginate a channel's messages.
Parameters
----------
channel:
The :class:`TextChannel` to fetch messages from.
limit:
The maximum number of messages to yield. ``None`` fetches until no
more messages are returned.
before:
Fetch messages with IDs less than this snowflake.
after:
Fetch messages with IDs greater than this snowflake.
Yields
------
Message
Messages in the channel, oldest first.
"""
"""Asynchronously paginate a channel's messages."""
remaining = limit
last_id = before
while remaining is None or remaining > 0:
fetch_limit = 100
if remaining is not None:
fetch_limit = min(fetch_limit, remaining)
fetch_limit = min(100, remaining) if remaining is not None else 100
params: Dict[str, Any] = {"limit": fetch_limit}
if last_id is not None:
@ -135,36 +117,10 @@ class Paginator:
def escape_markdown(text: str) -> str:
"""Escape Discord markdown formatting in ``text``.
Parameters
----------
text:
The text to escape.
Returns
-------
str
The escaped text with Discord formatting characters preceded by a
backslash.
"""
"""Escape Discord markdown formatting in ``text``."""
return re.sub(r"([\\*_~`>|])", r"\\\1", text)
def escape_mentions(text: str) -> str:
"""Escape Discord mentions in ``text``.
Parameters
----------
text:
The text in which to escape mentions.
Returns
-------
str
The text with ``@`` characters replaced by ``@\u200b`` to prevent
unintended mentions.
"""
"""Escape Discord mentions in ``text``."""
return text.replace("@", "@\u200b")

View File

@ -1,10 +1,13 @@
from datetime import datetime, timezone
from types import SimpleNamespace
from disagreement.utils import (
escape_markdown,
escape_mentions,
utcnow,
snowflake_time
snowflake_time,
find,
get
)
@ -13,6 +16,22 @@ def test_utcnow_timezone():
assert now.tzinfo == timezone.utc
def test_find_returns_matching_element():
seq = [1, 2, 3]
assert find(lambda x: x > 1, seq) == 2
assert find(lambda x: x > 3, seq) is None
def test_get_matches_attributes():
items = [
SimpleNamespace(id=1, name="a"),
SimpleNamespace(id=2, name="b"),
]
assert get(items, id=2) is items[1]
assert get(items, id=1, name="a") is items[0]
assert get(items, name="c") is None
def test_snowflake_time():
dt = datetime(2020, 1, 1, tzinfo=timezone.utc)
ms = int(dt.timestamp() * 1000) - 1420070400000