feat: Add pagination test for SafebooruCog to validate _fetch_posts_logic functionality
parent 2d6aa1dc79
commit 777e206d07
@@ -223,7 +223,18 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
    async def _fetch_posts_logic(self, interaction_or_ctx: typing.Union[discord.Interaction, commands.Context, str], tags: str, pid_override: typing.Optional[int] = None, limit_override: typing.Optional[int] = None, hidden: bool = False) -> typing.Union[str, tuple[str, list], list]:
        all_results = []
        current_pid = pid_override if pid_override is not None else 0
        limit = limit_override if limit_override is not None else 100000
        # The API has a hard limit of 1000 results per request, so we'll use that as our per-page limit
        per_page_limit = 1000
        # If limit_override is provided, use it; otherwise default to 100000
        total_limit = limit_override if limit_override is not None else 100000

        # For internal calls with a specific pid/limit, use those exact values
        if pid_override is not None or limit_override is not None:
            use_pagination = False
            api_limit = limit_override if limit_override is not None else per_page_limit
        else:
            use_pagination = True
            api_limit = per_page_limit

        if not isinstance(interaction_or_ctx, str) and interaction_or_ctx:
            if self.is_nsfw_site:
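Note: the branch above cleanly separates two call modes. A minimal standalone sketch of that selection (select_mode is a hypothetical helper, not part of the cog; the constants mirror the locals above):

def select_mode(pid_override=None, limit_override=None, per_page_limit=1000):
    # Internal calls that pass an explicit pid/limit get one exact request;
    # everything else paginates in per_page_limit chunks.
    if pid_override is not None or limit_override is not None:
        return False, (limit_override if limit_override is not None else per_page_limit)
    return True, per_page_limit

assert select_mode() == (True, 1000)                                   # command path: paginate
assert select_mode(pid_override=2, limit_override=50) == (False, 50)   # internal call: exact request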
@@ -248,8 +259,11 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
            elif hasattr(interaction_or_ctx, 'reply'):  # Prefix command
                await interaction_or_ctx.reply(f"Fetching data from {self.cog_name}, please wait...")

        if pid_override is None and limit_override is None:
            # Check cache first if not using specific pagination
            if not use_pagination:
                # Skip cache for internal calls with a specific pid/limit
                pass
            else:
                cache_key = tags.lower().strip()
                if cache_key in self.cache_data:
                    cached_entry = self.cache_data[cache_key]
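Note: because the cache key is normalized with lower().strip(), lookups are insensitive to case and stray whitespace. A quick sketch (a plain dict stands in for self.cache_data here):

cache_data = {}
cache_data["kasane_teto"] = {"results": []}  # stored under the normalized key

for raw in ("Kasane_Teto", "  kasane_teto  "):
    assert raw.lower().strip() in cache_data  # both spellings hit the same entry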
@@ -265,9 +279,65 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
            log.info(f"Recreated aiohttp.ClientSession in _fetch_posts_logic for {self.cog_name}")

        all_results = []

        # If using pagination, we'll make multiple requests
        if use_pagination:
            max_pages = (total_limit + per_page_limit - 1) // per_page_limit
            for page in range(max_pages):
                # Stop if we've reached our total limit or if we got fewer results than the per-page limit
                if len(all_results) >= total_limit or (page > 0 and len(all_results) % per_page_limit != 0):
                    break

                api_params = {
                    "page": "dapi", "s": "post", "q": "index",
-                   "limit": limit, "pid": current_pid, "tags": tags, "json": 1
+                   "limit": per_page_limit, "pid": page, "tags": tags, "json": 1
                }

                try:
                    async with self.session.get(self.api_base_url, params=api_params) as response:
                        if response.status == 200:
                            try:
                                data = await response.json()
                            except aiohttp.ContentTypeError:
                                log.warning(f"{self.cog_name} API returned non-JSON for tags: {tags}, pid: {page}, params: {api_params}")
                                data = None

                            if data and isinstance(data, list):
                                all_results.extend(data)
                                # If we got fewer results than requested, we've reached the end
                                if len(data) < per_page_limit:
                                    break
                            elif isinstance(data, list) and len(data) == 0:
                                # Empty page, no more results
                                break
                            else:
                                log.warning(f"Unexpected API response format from {self.cog_name} (not list or empty list): {data} for tags: {tags}, pid: {page}, params: {api_params}")
                                break
                        else:
                            log.error(f"Failed to fetch {self.cog_name} data. HTTP Status: {response.status} for tags: {tags}, pid: {page}, params: {api_params}")
                            if page == 0:  # Only return an error if the first page fails
                                return f"Failed to fetch data from {self.cog_name}. HTTP Status: {response.status}"
                            break
                except aiohttp.ClientError as e:
                    log.error(f"aiohttp.ClientError in _fetch_posts_logic for {self.cog_name} tags {tags}: {e}")
                    if page == 0:  # Only return an error if the first page fails
                        return f"Network error fetching data from {self.cog_name}: {e}"
                    break
                except Exception as e:
                    log.exception(f"Unexpected error in _fetch_posts_logic API call for {self.cog_name} tags {tags}: {e}")
                    if page == 0:  # Only return an error if the first page fails
                        return f"An unexpected error occurred during {self.cog_name} API call: {e}"
                    break

                # Limit to the total we want
                if len(all_results) > total_limit:
                    all_results = all_results[:total_limit]
                    break
        else:
            # Single request with a specific pid/limit
            api_params = {
                "page": "dapi", "s": "post", "q": "index",
                "limit": api_limit, "pid": current_pid, "tags": tags, "json": 1
            }

            try:
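Note: for intuition on the loop bounds in the paginated branch above, max_pages is a ceiling division over total_limit, and a short page ends the walk early. A small worked sketch using the same constants:

per_page_limit = 1000
total_limit = 100000

max_pages = (total_limit + per_page_limit - 1) // per_page_limit
assert max_pages == 100  # ceil(100000 / 1000)

# Simulated page sizes: two full pages, then a short page stops the loop.
fetched = 0
for page_size in (1000, 1000, 737):
    fetched += page_size
    if page_size < per_page_limit:
        break
assert fetched == 2737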
test_pagination.py (new file, 48 lines)
@@ -0,0 +1,48 @@
import asyncio
import logging
import sys
import os

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
log = logging.getLogger("test_pagination")

# Add the current directory to the path so we can import the cogs
sys.path.append(os.getcwd())

from cogs.safebooru_cog import SafebooruCog
from discord.ext import commands
import discord

async def test_pagination():
    # Create a mock bot with intents
    intents = discord.Intents.default()
    bot = commands.Bot(command_prefix='!', intents=intents)

    # Initialize the cog
    cog = SafebooruCog(bot)

    # Test the pagination for a specific tag
    tag = "kasane_teto"
    log.info(f"Testing pagination for tag: {tag}")

    # Call the _fetch_posts_logic method
    results = await cog._fetch_posts_logic("test", tag)

    # Check the results
    if isinstance(results, tuple):
        log.info(f"Found {len(results[1])} results")
        # Print the first few results
        for i, result in enumerate(results[1][:5]):
            log.info(f"Result {i+1}: {result.get('id')} - {result.get('file_url')}")
    else:
        log.error(f"Error: {results}")

    # Clean up
    if hasattr(cog, 'session') and cog.session and not cog.session.closed:
        await cog.session.close()
        log.info("Closed aiohttp session")

if __name__ == "__main__":
    # Run the test
    asyncio.run(test_pagination())
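Note: since the script appends os.getcwd() to sys.path before importing cogs.safebooru_cog, it presumably has to be run from the repository root (the directory containing the cogs/ package), e.g.:

python test_pagination.py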