feat: Add pagination test for SafebooruCog to validate _fetch_posts_logic functionality

Slipstream 2025-05-18 09:31:05 -06:00
parent 2d6aa1dc79
commit 777e206d07
Signed by: slipstream
GPG Key ID: 13E498CE010AC6FD
2 changed files with 204 additions and 86 deletions


@@ -223,7 +223,18 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
async def _fetch_posts_logic(self, interaction_or_ctx: typing.Union[discord.Interaction, commands.Context, str], tags: str, pid_override: typing.Optional[int] = None, limit_override: typing.Optional[int] = None, hidden: bool = False) -> typing.Union[str, tuple[str, list], list]:
all_results = []
current_pid = pid_override if pid_override is not None else 0
limit = limit_override if limit_override is not None else 100000
# API has a hard limit of 1000 results per request, so we'll use that as our per-page limit
per_page_limit = 1000
# If limit_override is provided, use it, otherwise default to 100000 (100 pages of results)
total_limit = limit_override if limit_override is not None else 100000
# For internal calls with specific pid/limit, use those exact values
if pid_override is not None or limit_override is not None:
use_pagination = False
api_limit = limit_override if limit_override is not None else per_page_limit
else:
use_pagination = True
api_limit = per_page_limit
if not isinstance(interaction_or_ctx, str) and interaction_or_ctx:
if self.is_nsfw_site:
@@ -248,8 +259,11 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
elif hasattr(interaction_or_ctx, 'reply'): # Prefix command
await interaction_or_ctx.reply(f"Fetching data from {self.cog_name}, please wait...")
if pid_override is None and limit_override is None:
# Check cache first if not using specific pagination
if not use_pagination:
# Skip cache for internal calls with specific pid/limit
pass
else:
cache_key = tags.lower().strip()
if cache_key in self.cache_data:
cached_entry = self.cache_data[cache_key]
@@ -265,9 +279,65 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
log.info(f"Recreated aiohttp.ClientSession in _fetch_posts_logic for {self.cog_name}")
all_results = []
# If using pagination, we'll make multiple requests
if use_pagination:
max_pages = (total_limit + per_page_limit - 1) // per_page_limit
for page in range(max_pages):
# Stop if we've reached our total limit or if we got fewer results than the per-page limit
if len(all_results) >= total_limit or (page > 0 and len(all_results) % per_page_limit != 0):
break
api_params = {
"page": "dapi", "s": "post", "q": "index",
"limit": limit, "pid": current_pid, "tags": tags, "json": 1
"limit": per_page_limit, "pid": page, "tags": tags, "json": 1
}
try:
async with self.session.get(self.api_base_url, params=api_params) as response:
if response.status == 200:
try:
data = await response.json()
except aiohttp.ContentTypeError:
log.warning(f"{self.cog_name} API returned non-JSON for tags: {tags}, pid: {page}, params: {api_params}")
data = None
if data and isinstance(data, list):
all_results.extend(data)
# If we got fewer results than requested, we've reached the end
if len(data) < per_page_limit:
break
elif isinstance(data, list) and len(data) == 0:
# Empty page, no more results
break
else:
log.warning(f"Unexpected API response format from {self.cog_name} (not list or empty list): {data} for tags: {tags}, pid: {page}, params: {api_params}")
break
else:
log.error(f"Failed to fetch {self.cog_name} data. HTTP Status: {response.status} for tags: {tags}, pid: {page}, params: {api_params}")
if page == 0: # Only return error if first page fails
return f"Failed to fetch data from {self.cog_name}. HTTP Status: {response.status}"
break
except aiohttp.ClientError as e:
log.error(f"aiohttp.ClientError in _fetch_posts_logic for {self.cog_name} tags {tags}: {e}")
if page == 0: # Only return error if first page fails
return f"Network error fetching data from {self.cog_name}: {e}"
break
except Exception as e:
log.exception(f"Unexpected error in _fetch_posts_logic API call for {self.cog_name} tags {tags}: {e}")
if page == 0: # Only return error if first page fails
return f"An unexpected error occurred during {self.cog_name} API call: {e}"
break
# Limit to the total we want
if len(all_results) > total_limit:
all_results = all_results[:total_limit]
break
else:
# Single request with specific pid/limit
api_params = {
"page": "dapi", "s": "post", "q": "index",
"limit": api_limit, "pid": current_pid, "tags": tags, "json": 1
}
try:
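
For illustration, here is a minimal standalone sketch (not part of the commit) of the pagination arithmetic the new code relies on: ceiling division gives the maximum number of pages, and the loop walks pid = 0, 1, 2, ... until a page comes back shorter than per_page_limit. The plan_pages/paginate names are hypothetical; the 1000 and 100000 values mirror per_page_limit and the default total_limit above.

def plan_pages(total_limit: int = 100_000, per_page_limit: int = 1000) -> int:
    # Ceiling division: the maximum number of requests that could ever be needed.
    return (total_limit + per_page_limit - 1) // per_page_limit

def paginate(fetch_page, total_limit: int = 100_000, per_page_limit: int = 1000) -> list:
    # fetch_page(pid) returns the list of posts for that pid (at most per_page_limit).
    results = []
    for pid in range(plan_pages(total_limit, per_page_limit)):
        page = fetch_page(pid)
        results.extend(page)
        if len(page) < per_page_limit:  # a short page means no more results upstream
            break
    return results[:total_limit]  # clamp to the requested total

# Example: an API with 2,345 matching posts is drained in 3 requests (1000 + 1000 + 345).
assert len(paginate(lambda pid: [0] * min(1000, max(0, 2345 - pid * 1000)))) == 2345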

test_pagination.py Normal file

@@ -0,0 +1,48 @@
import asyncio
import logging
import sys
import os
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
log = logging.getLogger("test_pagination")
# Add the current directory to the path so we can import the cogs
sys.path.append(os.getcwd())
from cogs.safebooru_cog import SafebooruCog
from discord.ext import commands
import discord

async def test_pagination():
    # Create a mock bot with intents
    intents = discord.Intents.default()
    bot = commands.Bot(command_prefix='!', intents=intents)

    # Initialize the cog
    cog = SafebooruCog(bot)

    # Test the pagination for a specific tag
    tag = "kasane_teto"
    log.info(f"Testing pagination for tag: {tag}")

    # Call the _fetch_posts_logic method
    results = await cog._fetch_posts_logic("test", tag)

    # Check the results
    if isinstance(results, tuple):
        log.info(f"Found {len(results[1])} results")
        # Print the first few results
        for i, result in enumerate(results[1][:5]):
            log.info(f"Result {i+1}: {result.get('id')} - {result.get('file_url')}")
    else:
        log.error(f"Error: {results}")

    # Clean up
    if hasattr(cog, 'session') and cog.session and not cog.session.closed:
        await cog.session.close()
        log.info("Closed aiohttp session")


if __name__ == "__main__":
    # Run the test
    asyncio.run(test_pagination())
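
The script above exercises the live Safebooru API, so its output depends on network access and on how many posts currently match the tag. Below is a hedged sketch of an offline variant: it swaps the cog's session for a fake object that serves canned pages, so the pagination loop can be checked deterministically. FakeResponse and FakeSession are hypothetical helpers; the sketch assumes _fetch_posts_logic reads pages through self.session.get(...) (as the diff above shows) and, for a string context, returns the aggregated post list as the second element of a tuple (as the script above assumes).

import asyncio
import discord
from discord.ext import commands
from cogs.safebooru_cog import SafebooruCog

class FakeResponse:
    # Implements only the pieces _fetch_posts_logic appears to use:
    # async context manager, .status, and .json().
    def __init__(self, payload, status=200):
        self._payload, self.status = payload, status
    async def json(self):
        return self._payload
    async def __aenter__(self):
        return self
    async def __aexit__(self, *exc):
        return False

class FakeSession:
    closed = False
    def __init__(self, pages):
        self._pages = pages  # one list of fake posts per pid
    def get(self, url, params=None):
        pid = (params or {}).get("pid", 0)
        return FakeResponse(self._pages[pid] if pid < len(self._pages) else [])
    async def close(self):
        self.closed = True

async def test_pagination_offline():
    bot = commands.Bot(command_prefix="!", intents=discord.Intents.default())
    cog = SafebooruCog(bot)
    # Two full pages of 1000 fake posts followed by a short page of 5; if the
    # pagination loop aggregates correctly, 2005 posts should come back.
    # Note: any real session the cog created in __init__ is simply replaced here.
    cog.session = FakeSession(
        [[{"id": i, "file_url": f"https://example.invalid/{i}.png"} for i in range(n)]
         for n in (1000, 1000, 5)]
    )
    results = await cog._fetch_posts_logic("test", "kasane_teto")
    if isinstance(results, tuple):
        print(f"Aggregated {len(results[1])} results across pages")
    else:
        print(f"Unexpected result: {results}")

if __name__ == "__main__":
    asyncio.run(test_pagination_offline())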