Add get and find helpers (#103)
This commit is contained in:
parent
3437050f0e
commit
f5f8f6908c
@ -51,7 +51,15 @@ from .errors import (
|
||||
NotFound,
|
||||
)
|
||||
from .color import Color
|
||||
from .utils import escape_markdown, escape_mentions, message_pager, utcnow, snowflake_time
|
||||
from .utils import (
|
||||
utcnow,
|
||||
message_pager,
|
||||
get,
|
||||
find,
|
||||
escape_markdown,
|
||||
escape_mentions,
|
||||
snowflake_time,
|
||||
)
|
||||
from .enums import (
|
||||
GatewayIntent,
|
||||
GatewayOpcode,
|
||||
@ -153,6 +161,8 @@ __all__ = [
|
||||
"escape_markdown",
|
||||
"escape_mentions",
|
||||
"message_pager",
|
||||
"get",
|
||||
"find",
|
||||
"snowflake_time",
|
||||
"GatewayIntent",
|
||||
"GatewayOpcode",
|
||||
|
@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, AsyncIterator, Dict, Optional, TYPE_CHECKING
|
||||
from typing import Any, AsyncIterator, Dict, Iterable, Optional, TYPE_CHECKING, Callable
|
||||
import re
|
||||
|
||||
# Discord epoch in milliseconds (2015-01-01T00:00:00Z)
|
||||
@ -18,20 +18,23 @@ def utcnow() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def find(predicate: Callable[[Any], bool], iterable: Iterable[Any]) -> Optional[Any]:
|
||||
"""Return the first element in ``iterable`` matching the ``predicate``."""
|
||||
for element in iterable:
|
||||
if predicate(element):
|
||||
return element
|
||||
return None
|
||||
|
||||
|
||||
def get(iterable: Iterable[Any], **attrs: Any) -> Optional[Any]:
|
||||
"""Return the first element with matching attributes."""
|
||||
def predicate(elem: Any) -> bool:
|
||||
return all(getattr(elem, attr, None) == value for attr, value in attrs.items())
|
||||
return find(predicate, iterable)
|
||||
|
||||
|
||||
def snowflake_time(snowflake: int) -> datetime:
|
||||
"""Return the creation time of a Discord snowflake.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
snowflake:
|
||||
The snowflake ID to convert.
|
||||
|
||||
Returns
|
||||
-------
|
||||
datetime
|
||||
The UTC timestamp embedded in the snowflake.
|
||||
"""
|
||||
|
||||
"""Return the creation time of a Discord snowflake."""
|
||||
timestamp_ms = (snowflake >> 22) + DISCORD_EPOCH
|
||||
return datetime.fromtimestamp(timestamp_ms / 1000, tz=timezone.utc)
|
||||
|
||||
@ -43,32 +46,11 @@ async def message_pager(
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
) -> AsyncIterator["Message"]:
|
||||
"""Asynchronously paginate a channel's messages.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
channel:
|
||||
The :class:`TextChannel` to fetch messages from.
|
||||
limit:
|
||||
The maximum number of messages to yield. ``None`` fetches until no
|
||||
more messages are returned.
|
||||
before:
|
||||
Fetch messages with IDs less than this snowflake.
|
||||
after:
|
||||
Fetch messages with IDs greater than this snowflake.
|
||||
|
||||
Yields
|
||||
------
|
||||
Message
|
||||
Messages in the channel, oldest first.
|
||||
"""
|
||||
|
||||
"""Asynchronously paginate a channel's messages."""
|
||||
remaining = limit
|
||||
last_id = before
|
||||
while remaining is None or remaining > 0:
|
||||
fetch_limit = 100
|
||||
if remaining is not None:
|
||||
fetch_limit = min(fetch_limit, remaining)
|
||||
fetch_limit = min(100, remaining) if remaining is not None else 100
|
||||
|
||||
params: Dict[str, Any] = {"limit": fetch_limit}
|
||||
if last_id is not None:
|
||||
@ -135,36 +117,10 @@ class Paginator:
|
||||
|
||||
|
||||
def escape_markdown(text: str) -> str:
|
||||
"""Escape Discord markdown formatting in ``text``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text:
|
||||
The text to escape.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The escaped text with Discord formatting characters preceded by a
|
||||
backslash.
|
||||
"""
|
||||
|
||||
"""Escape Discord markdown formatting in ``text``."""
|
||||
return re.sub(r"([\\*_~`>|])", r"\\\1", text)
|
||||
|
||||
|
||||
def escape_mentions(text: str) -> str:
|
||||
"""Escape Discord mentions in ``text``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text:
|
||||
The text in which to escape mentions.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The text with ``@`` characters replaced by ``@\u200b`` to prevent
|
||||
unintended mentions.
|
||||
"""
|
||||
|
||||
"""Escape Discord mentions in ``text``."""
|
||||
return text.replace("@", "@\u200b")
|
||||
|
@ -1,10 +1,13 @@
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
|
||||
from disagreement.utils import (
|
||||
escape_markdown,
|
||||
escape_mentions,
|
||||
utcnow,
|
||||
snowflake_time
|
||||
snowflake_time,
|
||||
find,
|
||||
get
|
||||
)
|
||||
|
||||
|
||||
@ -13,6 +16,22 @@ def test_utcnow_timezone():
|
||||
assert now.tzinfo == timezone.utc
|
||||
|
||||
|
||||
def test_find_returns_matching_element():
|
||||
seq = [1, 2, 3]
|
||||
assert find(lambda x: x > 1, seq) == 2
|
||||
assert find(lambda x: x > 3, seq) is None
|
||||
|
||||
|
||||
def test_get_matches_attributes():
|
||||
items = [
|
||||
SimpleNamespace(id=1, name="a"),
|
||||
SimpleNamespace(id=2, name="b"),
|
||||
]
|
||||
assert get(items, id=2) is items[1]
|
||||
assert get(items, id=1, name="a") is items[0]
|
||||
assert get(items, name="c") is None
|
||||
|
||||
|
||||
def test_snowflake_time():
|
||||
dt = datetime(2020, 1, 1, tzinfo=timezone.utc)
|
||||
ms = int(dt.timestamp() * 1000) - 1420070400000
|
||||
|
Loading…
x
Reference in New Issue
Block a user