diff --git a/disagreement/http.py b/disagreement/http.py index ab50062..4b7b46f 100644 --- a/disagreement/http.py +++ b/disagreement/http.py @@ -854,13 +854,13 @@ class HTTPClient: use_auth_header=False, ) - async def get_user(self, user_id: "Snowflake") -> Dict[str, Any]: - """Fetches a user object for a given user ID.""" - return await self.request("GET", f"/users/{user_id}") - - async def get_current_user_guilds(self) -> List[Dict[str, Any]]: - """Returns the guilds the current user is in.""" - return await self.request("GET", "/users/@me/guilds") + async def get_user(self, user_id: "Snowflake") -> Dict[str, Any]: + """Fetches a user object for a given user ID.""" + return await self.request("GET", f"/users/{user_id}") + + async def get_current_user_guilds(self) -> List[Dict[str, Any]]: + """Returns the guilds the current user is in.""" + return await self.request("GET", "/users/@me/guilds") async def get_guild_member( self, guild_id: "Snowflake", user_id: "Snowflake" @@ -917,6 +917,29 @@ class HTTPClient: custom_headers=headers, ) + async def get_guild_prune_count(self, guild_id: "Snowflake", *, days: int) -> int: + """Returns the number of members that would be pruned.""" + + data = await self.request( + "GET", + f"/guilds/{guild_id}/prune", + params={"days": days}, + ) + return int(data.get("pruned", 0)) + + async def begin_guild_prune( + self, guild_id: "Snowflake", *, days: int, compute_count: bool = True + ) -> int: + """Begins a prune operation for the guild and returns the count.""" + + payload = {"days": days, "compute_prune_count": compute_count} + data = await self.request( + "POST", + f"/guilds/{guild_id}/prune", + payload=payload, + ) + return int(data.get("pruned", 0)) + async def get_guild_roles(self, guild_id: "Snowflake") -> List[Dict[str, Any]]: """Returns a list of role objects for the guild.""" return await self.request("GET", f"/guilds/{guild_id}/roles") @@ -1376,11 +1399,11 @@ class HTTPClient: """Joins the current user to a thread.""" await self.request("PUT", f"/channels/{channel_id}/thread-members/@me") - async def leave_thread(self, channel_id: "Snowflake") -> None: - """Removes the current user from a thread.""" - await self.request("DELETE", f"/channels/{channel_id}/thread-members/@me") - - async def create_dm(self, recipient_id: "Snowflake") -> Dict[str, Any]: - """Creates (or opens) a DM channel with the given user.""" - payload = {"recipient_id": str(recipient_id)} - return await self.request("POST", "/users/@me/channels", payload=payload) + async def leave_thread(self, channel_id: "Snowflake") -> None: + """Removes the current user from a thread.""" + await self.request("DELETE", f"/channels/{channel_id}/thread-members/@me") + + async def create_dm(self, recipient_id: "Snowflake") -> Dict[str, Any]: + """Creates (or opens) a DM channel with the given user.""" + payload = {"recipient_id": str(recipient_id)} + return await self.request("POST", "/users/@me/channels", payload=payload) diff --git a/disagreement/models.py b/disagreement/models.py index 6c046e8..fdd3f31 100644 --- a/disagreement/models.py +++ b/disagreement/models.py @@ -1344,10 +1344,30 @@ class Guild: ) member_data = await asyncio.wait_for(future, timeout=60.0) return [Member(m, self._client) for m in member_data] - except asyncio.TimeoutError: - if nonce in self._client._gateway._member_chunk_requests: - del self._client._gateway._member_chunk_requests[nonce] - raise + except asyncio.TimeoutError: + if nonce in self._client._gateway._member_chunk_requests: + del self._client._gateway._member_chunk_requests[nonce] + raise + + async def prune_members(self, days: int, *, compute_count: bool = True) -> int: + """|coro| Remove inactive members from the guild. + + Parameters + ---------- + days: int + Number of days of inactivity required to be pruned. + compute_count: bool + Whether to return the number of members pruned. + + Returns + ------- + int + The number of members pruned. + """ + + return await self._client._http.begin_guild_prune( + self.id, days=days, compute_count=compute_count + ) class Channel: diff --git a/tests/test_guild_prune.py b/tests/test_guild_prune.py new file mode 100644 index 0000000..8eaab3e --- /dev/null +++ b/tests/test_guild_prune.py @@ -0,0 +1,64 @@ +import pytest +from types import SimpleNamespace +from unittest.mock import AsyncMock + +from disagreement.http import HTTPClient +from disagreement.client import Client +from disagreement.enums import ( + VerificationLevel, + MessageNotificationLevel, + ExplicitContentFilterLevel, + MFALevel, + GuildNSFWLevel, + PremiumTier, +) +from disagreement.models import Guild + + +@pytest.mark.asyncio +async def test_http_get_guild_prune_count_calls_request(): + http = HTTPClient(token="t") + http.request = AsyncMock(return_value={"pruned": 3}) + count = await http.get_guild_prune_count("1", days=7) + http.request.assert_called_once_with("GET", f"/guilds/1/prune", params={"days": 7}) + assert count == 3 + + +@pytest.mark.asyncio +async def test_http_begin_guild_prune_calls_request(): + http = HTTPClient(token="t") + http.request = AsyncMock(return_value={"pruned": 2}) + count = await http.begin_guild_prune("1", days=1, compute_count=True) + http.request.assert_called_once_with( + "POST", + f"/guilds/1/prune", + payload={"days": 1, "compute_prune_count": True}, + ) + assert count == 2 + + +@pytest.mark.asyncio +async def test_guild_prune_members_calls_http(): + http = SimpleNamespace(begin_guild_prune=AsyncMock(return_value=1)) + client = Client(token="t") + client._http = http + guild_data = { + "id": "1", + "name": "g", + "owner_id": "1", + "afk_timeout": 60, + "verification_level": VerificationLevel.NONE.value, + "default_message_notifications": MessageNotificationLevel.ALL_MESSAGES.value, + "explicit_content_filter": ExplicitContentFilterLevel.DISABLED.value, + "roles": [], + "emojis": [], + "features": [], + "mfa_level": MFALevel.NONE.value, + "system_channel_flags": 0, + "premium_tier": PremiumTier.NONE.value, + "nsfw_level": GuildNSFWLevel.DEFAULT.value, + } + guild = Guild(guild_data, client_instance=client) + count = await guild.prune_members(2) + http.begin_guild_prune.assert_awaited_once_with("1", days=2, compute_count=True) + assert count == 1