Add get and find helpers (#103)
This commit is contained in:
parent
3437050f0e
commit
f5f8f6908c
@ -51,7 +51,15 @@ from .errors import (
|
|||||||
NotFound,
|
NotFound,
|
||||||
)
|
)
|
||||||
from .color import Color
|
from .color import Color
|
||||||
from .utils import escape_markdown, escape_mentions, message_pager, utcnow, snowflake_time
|
from .utils import (
|
||||||
|
utcnow,
|
||||||
|
message_pager,
|
||||||
|
get,
|
||||||
|
find,
|
||||||
|
escape_markdown,
|
||||||
|
escape_mentions,
|
||||||
|
snowflake_time,
|
||||||
|
)
|
||||||
from .enums import (
|
from .enums import (
|
||||||
GatewayIntent,
|
GatewayIntent,
|
||||||
GatewayOpcode,
|
GatewayOpcode,
|
||||||
@ -153,6 +161,8 @@ __all__ = [
|
|||||||
"escape_markdown",
|
"escape_markdown",
|
||||||
"escape_mentions",
|
"escape_mentions",
|
||||||
"message_pager",
|
"message_pager",
|
||||||
|
"get",
|
||||||
|
"find",
|
||||||
"snowflake_time",
|
"snowflake_time",
|
||||||
"GatewayIntent",
|
"GatewayIntent",
|
||||||
"GatewayOpcode",
|
"GatewayOpcode",
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, AsyncIterator, Dict, Optional, TYPE_CHECKING
|
from typing import Any, AsyncIterator, Dict, Iterable, Optional, TYPE_CHECKING, Callable
|
||||||
import re
|
import re
|
||||||
|
|
||||||
# Discord epoch in milliseconds (2015-01-01T00:00:00Z)
|
# Discord epoch in milliseconds (2015-01-01T00:00:00Z)
|
||||||
@ -18,20 +18,23 @@ def utcnow() -> datetime:
|
|||||||
return datetime.now(timezone.utc)
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
def find(predicate: Callable[[Any], bool], iterable: Iterable[Any]) -> Optional[Any]:
|
||||||
|
"""Return the first element in ``iterable`` matching the ``predicate``."""
|
||||||
|
for element in iterable:
|
||||||
|
if predicate(element):
|
||||||
|
return element
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get(iterable: Iterable[Any], **attrs: Any) -> Optional[Any]:
|
||||||
|
"""Return the first element with matching attributes."""
|
||||||
|
def predicate(elem: Any) -> bool:
|
||||||
|
return all(getattr(elem, attr, None) == value for attr, value in attrs.items())
|
||||||
|
return find(predicate, iterable)
|
||||||
|
|
||||||
|
|
||||||
def snowflake_time(snowflake: int) -> datetime:
|
def snowflake_time(snowflake: int) -> datetime:
|
||||||
"""Return the creation time of a Discord snowflake.
|
"""Return the creation time of a Discord snowflake."""
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
snowflake:
|
|
||||||
The snowflake ID to convert.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
datetime
|
|
||||||
The UTC timestamp embedded in the snowflake.
|
|
||||||
"""
|
|
||||||
|
|
||||||
timestamp_ms = (snowflake >> 22) + DISCORD_EPOCH
|
timestamp_ms = (snowflake >> 22) + DISCORD_EPOCH
|
||||||
return datetime.fromtimestamp(timestamp_ms / 1000, tz=timezone.utc)
|
return datetime.fromtimestamp(timestamp_ms / 1000, tz=timezone.utc)
|
||||||
|
|
||||||
@ -43,32 +46,11 @@ async def message_pager(
|
|||||||
before: Optional[str] = None,
|
before: Optional[str] = None,
|
||||||
after: Optional[str] = None,
|
after: Optional[str] = None,
|
||||||
) -> AsyncIterator["Message"]:
|
) -> AsyncIterator["Message"]:
|
||||||
"""Asynchronously paginate a channel's messages.
|
"""Asynchronously paginate a channel's messages."""
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
channel:
|
|
||||||
The :class:`TextChannel` to fetch messages from.
|
|
||||||
limit:
|
|
||||||
The maximum number of messages to yield. ``None`` fetches until no
|
|
||||||
more messages are returned.
|
|
||||||
before:
|
|
||||||
Fetch messages with IDs less than this snowflake.
|
|
||||||
after:
|
|
||||||
Fetch messages with IDs greater than this snowflake.
|
|
||||||
|
|
||||||
Yields
|
|
||||||
------
|
|
||||||
Message
|
|
||||||
Messages in the channel, oldest first.
|
|
||||||
"""
|
|
||||||
|
|
||||||
remaining = limit
|
remaining = limit
|
||||||
last_id = before
|
last_id = before
|
||||||
while remaining is None or remaining > 0:
|
while remaining is None or remaining > 0:
|
||||||
fetch_limit = 100
|
fetch_limit = min(100, remaining) if remaining is not None else 100
|
||||||
if remaining is not None:
|
|
||||||
fetch_limit = min(fetch_limit, remaining)
|
|
||||||
|
|
||||||
params: Dict[str, Any] = {"limit": fetch_limit}
|
params: Dict[str, Any] = {"limit": fetch_limit}
|
||||||
if last_id is not None:
|
if last_id is not None:
|
||||||
@ -135,36 +117,10 @@ class Paginator:
|
|||||||
|
|
||||||
|
|
||||||
def escape_markdown(text: str) -> str:
|
def escape_markdown(text: str) -> str:
|
||||||
"""Escape Discord markdown formatting in ``text``.
|
"""Escape Discord markdown formatting in ``text``."""
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
text:
|
|
||||||
The text to escape.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
str
|
|
||||||
The escaped text with Discord formatting characters preceded by a
|
|
||||||
backslash.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return re.sub(r"([\\*_~`>|])", r"\\\1", text)
|
return re.sub(r"([\\*_~`>|])", r"\\\1", text)
|
||||||
|
|
||||||
|
|
||||||
def escape_mentions(text: str) -> str:
|
def escape_mentions(text: str) -> str:
|
||||||
"""Escape Discord mentions in ``text``.
|
"""Escape Discord mentions in ``text``."""
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
text:
|
|
||||||
The text in which to escape mentions.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
str
|
|
||||||
The text with ``@`` characters replaced by ``@\u200b`` to prevent
|
|
||||||
unintended mentions.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return text.replace("@", "@\u200b")
|
return text.replace("@", "@\u200b")
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
from disagreement.utils import (
|
from disagreement.utils import (
|
||||||
escape_markdown,
|
escape_markdown,
|
||||||
escape_mentions,
|
escape_mentions,
|
||||||
utcnow,
|
utcnow,
|
||||||
snowflake_time
|
snowflake_time,
|
||||||
|
find,
|
||||||
|
get
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -13,6 +16,22 @@ def test_utcnow_timezone():
|
|||||||
assert now.tzinfo == timezone.utc
|
assert now.tzinfo == timezone.utc
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_returns_matching_element():
|
||||||
|
seq = [1, 2, 3]
|
||||||
|
assert find(lambda x: x > 1, seq) == 2
|
||||||
|
assert find(lambda x: x > 3, seq) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_matches_attributes():
|
||||||
|
items = [
|
||||||
|
SimpleNamespace(id=1, name="a"),
|
||||||
|
SimpleNamespace(id=2, name="b"),
|
||||||
|
]
|
||||||
|
assert get(items, id=2) is items[1]
|
||||||
|
assert get(items, id=1, name="a") is items[0]
|
||||||
|
assert get(items, name="c") is None
|
||||||
|
|
||||||
|
|
||||||
def test_snowflake_time():
|
def test_snowflake_time():
|
||||||
dt = datetime(2020, 1, 1, tzinfo=timezone.utc)
|
dt = datetime(2020, 1, 1, tzinfo=timezone.utc)
|
||||||
ms = int(dt.timestamp() * 1000) - 1420070400000
|
ms = int(dt.timestamp() * 1000) - 1420070400000
|
||||||
|
Loading…
x
Reference in New Issue
Block a user