Add HashableById mixin and tests (#118)

This commit is contained in:
Slipstream 2025-06-15 20:39:16 -06:00 committed by GitHub
parent 4b3b6aeb45
commit 17751d3b09
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 102 additions and 4 deletions

View File

@ -63,7 +63,19 @@ if TYPE_CHECKING:
from .components import component_factory
class User:
class HashableById:
"""Mixin providing equality and hashing based on the ``id`` attribute."""
id: str
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.id == other.id # type: ignore[attr-defined]
def __hash__(self) -> int: # pragma: no cover - trivial
return hash(self.id)
class User(HashableById):
"""Represents a Discord User."""
def __init__(self, data: dict, client_instance: Optional["Client"] = None) -> None:
@ -125,7 +137,7 @@ class User:
return await target_client.send_dm(self.id, content=content, **kwargs)
class Message:
class Message(HashableById):
"""Represents a message sent in a channel on Discord.
Attributes:
@ -1228,7 +1240,7 @@ class PermissionOverwrite:
return f"<PermissionOverwrite id='{self.id}' type='{self.type.name if hasattr(self.type, 'name') else self._type_val}' allow='{self.allow}' deny='{self.deny}'>"
class Guild:
class Guild(HashableById):
"""Represents a Discord Guild (Server).
Attributes:
@ -1649,7 +1661,7 @@ class Guild:
return cast("CategoryChannel", self._client.parse_channel(data))
class Channel:
class Channel(HashableById):
"""Base class for Discord channels."""
def __init__(self, data: Dict[str, Any], client_instance: "Client"):

View File

@ -0,0 +1,86 @@
import types
from disagreement.models import User, Guild, Channel, Message
from disagreement.enums import (
VerificationLevel,
MessageNotificationLevel,
ExplicitContentFilterLevel,
MFALevel,
GuildNSFWLevel,
PremiumTier,
ChannelType,
)
def _guild_data(gid="1"):
return {
"id": gid,
"name": "g",
"owner_id": gid,
"afk_timeout": 60,
"verification_level": VerificationLevel.NONE.value,
"default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value,
"explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value,
"roles": [],
"emojis": [],
"features": [],
"mfa_level": MFALevel.NONE.value,
"system_channel_flags": 0,
"premium_tier": PremiumTier.NONE.value,
"nsfw_level": GuildNSFWLevel.DEFAULT.value,
}
def _user(uid="1"):
return User({"id": uid, "username": "u", "discriminator": "0001"})
def _message(mid="1"):
data = {
"id": mid,
"channel_id": "c",
"author": {"id": "2", "username": "u", "discriminator": "0001"},
"content": "hi",
"timestamp": "t",
}
return Message(data, client_instance=types.SimpleNamespace())
def _channel(cid="1"):
data = {"id": cid, "type": ChannelType.GUILD_TEXT.value}
return Channel(data, client_instance=types.SimpleNamespace())
def test_user_hash_and_eq():
a = _user()
b = _user()
c = _user("2")
assert a == b
assert hash(a) == hash(b)
assert a != c
def test_guild_hash_and_eq():
a = Guild(_guild_data(), client_instance=types.SimpleNamespace())
b = Guild(_guild_data(), client_instance=types.SimpleNamespace())
c = Guild(_guild_data("2"), client_instance=types.SimpleNamespace())
assert a == b
assert hash(a) == hash(b)
assert a != c
def test_channel_hash_and_eq():
a = _channel()
b = _channel()
c = _channel("2")
assert a == b
assert hash(a) == hash(b)
assert a != c
def test_message_hash_and_eq():
a = _message()
b = _message()
c = _message("2")
assert a == b
assert hash(a) == hash(b)
assert a != c