From 17751d3b09e6ce9e8e8cd05f3f9de9a9272960f3 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Sun, 15 Jun 2025 20:39:16 -0600 Subject: [PATCH] Add HashableById mixin and tests (#118) --- disagreement/models.py | 20 +++++++-- tests/test_hashable_mixin.py | 86 ++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 4 deletions(-) create mode 100644 tests/test_hashable_mixin.py diff --git a/disagreement/models.py b/disagreement/models.py index de5d0af..1f09d7c 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -63,7 +63,19 @@ if TYPE_CHECKING: from .components import component_factory -class User: +class HashableById: + """Mixin providing equality and hashing based on the ``id`` attribute.""" + + id: str + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) and self.id == other.id # type: ignore[attr-defined] + + def __hash__(self) -> int: # pragma: no cover - trivial + return hash(self.id) + + +class User(HashableById): """Represents a Discord User.""" def __init__(self, data: dict, client_instance: Optional["Client"] = None) -> None: @@ -125,7 +137,7 @@ class User: return await target_client.send_dm(self.id, content=content, **kwargs) -class Message: +class Message(HashableById): """Represents a message sent in a channel on Discord. Attributes: @@ -1228,7 +1240,7 @@ class PermissionOverwrite: return f"" -class Guild: +class Guild(HashableById): """Represents a Discord Guild (Server). Attributes: @@ -1649,7 +1661,7 @@ class Guild: return cast("CategoryChannel", self._client.parse_channel(data)) -class Channel: +class Channel(HashableById): """Base class for Discord channels.""" def __init__(self, data: Dict[str, Any], client_instance: "Client"): diff --git a/tests/test_hashable_mixin.py b/tests/test_hashable_mixin.py new file mode 100644 index 0000000..5a6c2ad --- /dev/null +++ b/tests/test_hashable_mixin.py @@ -0,0 +1,86 @@ +import types +from disagreement.models import User, Guild, Channel, Message +from disagreement.enums import ( + VerificationLevel, + MessageNotificationLevel, + ExplicitContentFilterLevel, + MFALevel, + GuildNSFWLevel, + PremiumTier, + ChannelType, +) + + +def _guild_data(gid="1"): + return { + "id": gid, + "name": "g", + "owner_id": gid, + "afk_timeout": 60, + "verification_level": VerificationLevel.NONE.value, + "default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value, + "explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value, + "roles": [], + "emojis": [], + "features": [], + "mfa_level": MFALevel.NONE.value, + "system_channel_flags": 0, + "premium_tier": PremiumTier.NONE.value, + "nsfw_level": GuildNSFWLevel.DEFAULT.value, + } + + +def _user(uid="1"): + return User({"id": uid, "username": "u", "discriminator": "0001"}) + + +def _message(mid="1"): + data = { + "id": mid, + "channel_id": "c", + "author": {"id": "2", "username": "u", "discriminator": "0001"}, + "content": "hi", + "timestamp": "t", + } + return Message(data, client_instance=types.SimpleNamespace()) + + +def _channel(cid="1"): + data = {"id": cid, "type": ChannelType.GUILD_TEXT.value} + return Channel(data, client_instance=types.SimpleNamespace()) + + +def test_user_hash_and_eq(): + a = _user() + b = _user() + c = _user("2") + assert a == b + assert hash(a) == hash(b) + assert a != c + + +def test_guild_hash_and_eq(): + a = Guild(_guild_data(), client_instance=types.SimpleNamespace()) + b = Guild(_guild_data(), client_instance=types.SimpleNamespace()) + c = Guild(_guild_data("2"), client_instance=types.SimpleNamespace()) + assert a == b + assert hash(a) == hash(b) + assert a != c + + +def test_channel_hash_and_eq(): + a = _channel() + b = _channel() + c = _channel("2") + assert a == b + assert hash(a) == hash(b) + assert a != c + + +def test_message_hash_and_eq(): + a = _message() + b = _message() + c = _message("2") + assert a == b + assert hash(a) == hash(b) + assert a != c