feat: Add pagination test for SafebooruCog to validate _fetch_posts_logic functionality

Slipstream 2025-05-18 09:31:05 -06:00
parent 2d6aa1dc79
commit 777e206d07
Signed by: slipstream
GPG Key ID: 13E498CE010AC6FD
2 changed files with 204 additions and 86 deletions


@@ -31,7 +31,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
# The base class itself doesn't need a Cog 'name' in the same way.
# commands.Cog.__init__(self, bot) # This might be needed if Cog's __init__ does setup
# Let's rely on the derived class's super() call to handle Cog's __init__ properly.
self.cog_name = cog_name
self.api_base_url = api_base_url
self.default_tags = default_tags
@@ -47,7 +47,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
self.subscriptions_data = self._load_subscriptions()
self.pending_requests_data = self._load_pending_requests()
self.session: typing.Optional[aiohttp.ClientSession] = None
if bot.is_ready():
asyncio.create_task(self.initialize_cog_async())
else:
@@ -70,7 +70,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
async def cog_load(self):
log.info(f"{self.cog_name}Cog loaded.")
if self.session is None or self.session.closed:
self.session = aiohttp.ClientSession()
log.info(f"aiohttp ClientSession (re)created during cog_load for {self.cog_name}Cog.")
if not self.check_new_posts.is_running():
@@ -111,7 +111,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
return json.load(f)
except Exception as e:
log.error(f"Failed to load {self.cog_name} subscriptions file ({self.subscriptions_file}): {e}")
return {}
def _save_subscriptions(self):
try:
@@ -148,7 +148,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
if sub.get("channel_id") == str(channel.id) and sub.get("webhook_url"):
try:
webhook = discord.Webhook.from_url(sub["webhook_url"], session=self.session)
await webhook.fetch()
if webhook.channel_id == channel.id:
log.debug(f"Reusing existing webhook {webhook.id} for channel {channel.id} ({self.cog_name})")
return sub["webhook_url"]
@@ -179,7 +179,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
avatar_bytes = await self.bot.user.display_avatar.read()
except Exception as e:
log.warning(f"Could not read bot avatar for webhook creation ({self.cog_name}): {e}")
new_webhook = await channel.create_webhook(name=webhook_name, avatar=avatar_bytes, reason=f"{self.cog_name} Tag Watcher")
log.info(f"Created new webhook {new_webhook.id} ('{new_webhook.name}') in channel {channel.id} ({self.cog_name})")
return new_webhook.url
@@ -194,7 +194,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
if not self.session or self.session.closed:
self.session = aiohttp.ClientSession()
log.info(f"Recreated aiohttp.ClientSession in _send_via_webhook for {self.cog_name}")
try:
webhook = discord.Webhook.from_url(webhook_url, session=self.session)
target_thread_obj = None
@@ -203,7 +203,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
target_thread_obj = discord.Object(id=int(thread_id))
except ValueError:
log.error(f"Invalid thread_id format: {thread_id} for webhook {webhook_url[:30]} ({self.cog_name}). Sending to main channel.")
await webhook.send(
content=content,
username=f"{self.bot.user.name} {self.cog_name} Watcher" if self.bot.user else f"{self.cog_name} Watcher",
@@ -223,20 +223,31 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
async def _fetch_posts_logic(self, interaction_or_ctx: typing.Union[discord.Interaction, commands.Context, str], tags: str, pid_override: typing.Optional[int] = None, limit_override: typing.Optional[int] = None, hidden: bool = False) -> typing.Union[str, tuple[str, list], list]:
    all_results = []
    current_pid = pid_override if pid_override is not None else 0
-   limit = limit_override if limit_override is not None else 100000
+   # API has a hard limit of 1000 results per request, so we'll use that as our per-page limit
+   per_page_limit = 1000
+   # If limit_override is provided, use it; otherwise default to 100000
+   total_limit = limit_override if limit_override is not None else 100000
+   # For internal calls with specific pid/limit, use those exact values
+   if pid_override is not None or limit_override is not None:
+       use_pagination = False
+       api_limit = limit_override if limit_override is not None else per_page_limit
+   else:
+       use_pagination = True
+       api_limit = per_page_limit
    if not isinstance(interaction_or_ctx, str) and interaction_or_ctx:
        if self.is_nsfw_site:
            is_nsfw_channel = False
            channel = interaction_or_ctx.channel
            if isinstance(channel, discord.TextChannel) and channel.is_nsfw():
                is_nsfw_channel = True
            elif isinstance(channel, discord.DMChannel):
                is_nsfw_channel = True
            # For Gelbooru-like APIs, 'rating:safe', 'rating:general', 'rating:questionable' might be SFW-ish
            # We'll stick to 'rating:safe' for simplicity as it was in Rule34Cog
            allow_in_non_nsfw = 'rating:safe' in tags.lower()
            if not is_nsfw_channel and not allow_in_non_nsfw:
                return f'This command for {self.cog_name} can only be used in age-restricted (NSFW) channels, DMs, or with the `rating:safe` tag.'
@@ -248,13 +259,16 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
elif hasattr(interaction_or_ctx, 'reply'): # Prefix command
    await interaction_or_ctx.reply(f"Fetching data from {self.cog_name}, please wait...")
-if pid_override is None and limit_override is None:
+# Check cache first if not using specific pagination
+if not use_pagination:
+    # Skip cache for internal calls with specific pid/limit
+    pass
+else:
    cache_key = tags.lower().strip()
    if cache_key in self.cache_data:
        cached_entry = self.cache_data[cache_key]
        cache_timestamp = cached_entry.get("timestamp", 0)
        if time.time() - cache_timestamp < 86400:
            all_results = cached_entry.get("results", [])
            if all_results:
                random_result = random.choice(all_results)
@@ -264,40 +278,96 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
self.session = aiohttp.ClientSession()
log.info(f"Recreated aiohttp.ClientSession in _fetch_posts_logic for {self.cog_name}")
all_results = []
-api_params = {
-    "page": "dapi", "s": "post", "q": "index",
-    "limit": limit, "pid": current_pid, "tags": tags, "json": 1
-}
-try:
-    async with self.session.get(self.api_base_url, params=api_params) as response:
-        if response.status == 200:
-            try:
-                data = await response.json()
-            except aiohttp.ContentTypeError:
-                log.warning(f"{self.cog_name} API returned non-JSON for tags: {tags}, pid: {current_pid}, params: {api_params}")
-                data = None
-            if data and isinstance(data, list):
-                all_results.extend(data)
-            elif isinstance(data, list) and len(data) == 0: pass
-            else:
-                log.warning(f"Unexpected API response format from {self.cog_name} (not list or empty list): {data} for tags: {tags}, pid: {current_pid}, params: {api_params}")
-                if pid_override is not None or limit_override is not None:
-                    return f"Unexpected API response format from {self.cog_name}: {response.status}"
-        else:
-            log.error(f"Failed to fetch {self.cog_name} data. HTTP Status: {response.status} for tags: {tags}, pid: {current_pid}, params: {api_params}")
-            return f"Failed to fetch data from {self.cog_name}. HTTP Status: {response.status}"
-except aiohttp.ClientError as e:
-    log.error(f"aiohttp.ClientError in _fetch_posts_logic for {self.cog_name} tags {tags}: {e}")
-    return f"Network error fetching data from {self.cog_name}: {e}"
-except Exception as e:
-    log.exception(f"Unexpected error in _fetch_posts_logic API call for {self.cog_name} tags {tags}: {e}")
-    return f"An unexpected error occurred during {self.cog_name} API call: {e}"
+# If using pagination, we'll make multiple requests
+if use_pagination:
+    max_pages = (total_limit + per_page_limit - 1) // per_page_limit
+    for page in range(max_pages):
+        # Stop if we've reached our total limit or if we got fewer results than the per-page limit
+        if len(all_results) >= total_limit or (page > 0 and len(all_results) % per_page_limit != 0):
+            break
+        api_params = {
+            "page": "dapi", "s": "post", "q": "index",
+            "limit": per_page_limit, "pid": page, "tags": tags, "json": 1
+        }
+        try:
+            async with self.session.get(self.api_base_url, params=api_params) as response:
+                if response.status == 200:
+                    try:
+                        data = await response.json()
+                    except aiohttp.ContentTypeError:
+                        log.warning(f"{self.cog_name} API returned non-JSON for tags: {tags}, pid: {page}, params: {api_params}")
+                        data = None
+                    if data and isinstance(data, list):
+                        # If we got fewer results than requested, we've reached the end
+                        all_results.extend(data)
+                        if len(data) < per_page_limit:
+                            break
+                    elif isinstance(data, list) and len(data) == 0:
+                        # Empty page, no more results
+                        break
+                    else:
+                        log.warning(f"Unexpected API response format from {self.cog_name} (not list or empty list): {data} for tags: {tags}, pid: {page}, params: {api_params}")
+                        break
+                else:
+                    log.error(f"Failed to fetch {self.cog_name} data. HTTP Status: {response.status} for tags: {tags}, pid: {page}, params: {api_params}")
+                    if page == 0:  # Only return error if first page fails
+                        return f"Failed to fetch data from {self.cog_name}. HTTP Status: {response.status}"
+                    break
+        except aiohttp.ClientError as e:
+            log.error(f"aiohttp.ClientError in _fetch_posts_logic for {self.cog_name} tags {tags}: {e}")
+            if page == 0:  # Only return error if first page fails
+                return f"Network error fetching data from {self.cog_name}: {e}"
+            break
+        except Exception as e:
+            log.exception(f"Unexpected error in _fetch_posts_logic API call for {self.cog_name} tags {tags}: {e}")
+            if page == 0:  # Only return error if first page fails
+                return f"An unexpected error occurred during {self.cog_name} API call: {e}"
+            break
+        # Limit to the total we want
+        if len(all_results) > total_limit:
+            all_results = all_results[:total_limit]
+            break
+else:
+    # Single request with specific pid/limit
+    api_params = {
+        "page": "dapi", "s": "post", "q": "index",
+        "limit": api_limit, "pid": current_pid, "tags": tags, "json": 1
+    }
+    try:
+        async with self.session.get(self.api_base_url, params=api_params) as response:
+            if response.status == 200:
+                try:
+                    data = await response.json()
+                except aiohttp.ContentTypeError:
+                    log.warning(f"{self.cog_name} API returned non-JSON for tags: {tags}, pid: {current_pid}, params: {api_params}")
+                    data = None
+                if data and isinstance(data, list):
+                    all_results.extend(data)
+                elif isinstance(data, list) and len(data) == 0: pass
+                else:
+                    log.warning(f"Unexpected API response format from {self.cog_name} (not list or empty list): {data} for tags: {tags}, pid: {current_pid}, params: {api_params}")
+                    if pid_override is not None or limit_override is not None:
+                        return f"Unexpected API response format from {self.cog_name}: {response.status}"
+            else:
+                log.error(f"Failed to fetch {self.cog_name} data. HTTP Status: {response.status} for tags: {tags}, pid: {current_pid}, params: {api_params}")
+                return f"Failed to fetch data from {self.cog_name}. HTTP Status: {response.status}"
+    except aiohttp.ClientError as e:
+        log.error(f"aiohttp.ClientError in _fetch_posts_logic for {self.cog_name} tags {tags}: {e}")
+        return f"Network error fetching data from {self.cog_name}: {e}"
+    except Exception as e:
+        log.exception(f"Unexpected error in _fetch_posts_logic API call for {self.cog_name} tags {tags}: {e}")
+        return f"An unexpected error occurred during {self.cog_name} API call: {e}"
if pid_override is not None or limit_override is not None:
    return all_results
if all_results:
    cache_key = tags.lower().strip()
@@ -344,7 +414,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
content = f"Result 1/{len(self.all_results)}:\n{result['file_url']}"
view = self.cog.BrowseView(self.cog, self.tags, self.all_results, self.hidden, self.current_index)
await interaction.response.edit_message(content=content, view=view)
@discord.ui.button(label="Pin", style=discord.ButtonStyle.danger)
async def pin_message(self, interaction: discord.Interaction, button: Button):
if interaction.message:
@@ -387,7 +457,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
async def last(self, interaction: discord.Interaction, button: Button):
self.current_index = len(self.all_results) - 1
await self._update_message(interaction)
@discord.ui.button(label="Go To", style=discord.ButtonStyle.primary, row=1)
async def goto(self, interaction: discord.Interaction, button: Button):
modal = self.cog.GoToModal(len(self.all_results))
@@ -438,7 +508,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
async def _prefix_command_logic(self, ctx: commands.Context, tags: str):
# Loading message is handled by _fetch_posts_logic if ctx is passed
response = await self._fetch_posts_logic(ctx, tags)
if isinstance(response, tuple):
content, all_results = response
@@ -476,7 +546,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
async def _slash_command_logic(self, interaction: discord.Interaction, tags: str, hidden: bool):
response = await self._fetch_posts_logic(interaction, tags, hidden=hidden)
if isinstance(response, tuple):
content, all_results = response
view = self.GelbooruButtons(self, tags, all_results, hidden)
@@ -488,7 +558,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
ephemeral_error = hidden
if self.is_nsfw_site and response.startswith(f'This command for {self.cog_name} can only be used'):
ephemeral_error = True # Always make NSFW warnings ephemeral if possible
if not interaction.response.is_done():
await interaction.response.send_message(response, ephemeral=ephemeral_error)
else:
@@ -499,7 +569,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
async def _browse_slash_command_logic(self, interaction: discord.Interaction, tags: str, hidden: bool):
response = await self._fetch_posts_logic(interaction, tags, hidden=hidden)
if isinstance(response, tuple):
_, all_results = response
if not all_results:
@@ -507,8 +577,8 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
if not interaction.response.is_done(): await interaction.response.send_message(content, ephemeral=hidden)
else: await interaction.followup.send(content, ephemeral=hidden)
return
result = all_results[0]
content = f"Result 1/{len(all_results)}:\n{result['file_url']}"
view = self.BrowseView(self, tags, all_results, hidden, current_index=0)
if interaction.response.is_done(): await interaction.followup.send(content, view=view, ephemeral=hidden)
@@ -536,15 +606,15 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
tags = sub.get("tags")
webhook_url = sub.get("webhook_url")
last_known_post_id = sub.get("last_known_post_id", 0)
if not tags or not webhook_url:
log.warning(f"Subscription for {self.cog_name} (guild {guild_id_str}, sub {sub.get('subscription_id')}) missing tags/webhook.")
continue
fetched_posts_response = await self._fetch_posts_logic("internal_task_call", tags, pid_override=0, limit_override=100)
if isinstance(fetched_posts_response, str):
log.error(f"Error fetching {self.cog_name} posts for sub {sub.get('subscription_id')} (tags: {tags}): {fetched_posts_response}")
continue
if not fetched_posts_response: continue
@@ -556,7 +626,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
if not isinstance(post_data, dict) or "id" not in post_data or "file_url" not in post_data:
log.warning(f"Malformed {self.cog_name} post data for tags {tags}: {post_data}")
continue
post_id = int(post_data["id"])
if post_id > last_known_post_id: new_posts_to_send.append(post_data)
if post_id > current_max_id_this_batch: current_max_id_this_batch = post_id
@@ -567,9 +637,9 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
for new_post in new_posts_to_send:
post_id = int(new_post["id"])
message_content = f"New {self.cog_name} post for tags `{tags}`:\n{new_post['file_url']}"
target_thread_id: typing.Optional[str] = sub.get("target_post_id") or sub.get("thread_id")
send_success = await self._send_via_webhook(webhook_url, message_content, thread_id=target_thread_id)
if send_success:
latest_sent_id_for_this_sub = post_id
@@ -579,12 +649,12 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
if original_sub_entry.get("subscription_id") == sub.get("subscription_id"):
original_sub_entry["last_known_post_id"] = latest_sent_id_for_this_sub
needs_save = True
break
await asyncio.sleep(1)
else:
log.error(f"Failed to send new {self.cog_name} post via webhook for sub {sub.get('subscription_id')}.")
break
if not new_posts_to_send and current_max_id_this_batch > last_known_post_id:
original_guild_subs = self.subscriptions_data.get(guild_id_str)
if original_guild_subs:
@@ -603,12 +673,12 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
async def before_check_new_posts(self):
await self.bot.wait_until_ready()
log.info(f"{self.cog_name}Cog: `check_new_posts` loop is waiting for bot readiness...")
if self.session is None or self.session.closed:
self.session = aiohttp.ClientSession()
log.info(f"aiohttp ClientSession created before {self.cog_name} check_new_posts loop.")
async def _create_new_subscription(self, guild_id: int, user_id: int, tags: str,
target_channel: typing.Union[discord.TextChannel, discord.ForumChannel],
requested_thread_target: typing.Optional[str] = None,
requested_post_title: typing.Optional[str] = None,
is_request_approval: bool = False,
@@ -631,22 +701,22 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
if not found_thread and hasattr(target_channel, 'threads'):
for t in target_channel.threads:
if t.name.lower() == requested_thread_target.lower(): found_thread = t; break
if found_thread:
actual_target_thread_id = str(found_thread.id)
actual_target_thread_mention = found_thread.mention
else: return f"❌ Could not find thread `{requested_thread_target}` in {target_channel.mention} for {self.cog_name}."
elif isinstance(target_channel, discord.ForumChannel):
forum_post_initial_message = f"✨ **New {self.cog_name} Watch Initialized!** ✨\nNow monitoring tags: `{tags}`"
if is_request_approval and requester_mention: forum_post_initial_message += f"\n_Requested by: {requester_mention}_"
try:
if not target_channel.permissions_for(target_channel.guild.me).create_public_threads:
return f"❌ I lack permission to create posts in forum {target_channel.mention} for {self.cog_name}."
new_forum_post = await target_channel.create_thread(name=actual_post_title, content=forum_post_initial_message, reason=f"{self.cog_name}Watch: {tags}")
actual_target_thread_id = str(new_forum_post.thread.id)
actual_target_thread_mention = new_forum_post.thread.mention
log.info(f"Created {self.cog_name} forum post {new_forum_post.thread.id} for tags '{tags}' in forum {target_channel.id}")
except discord.HTTPException as e: return f"❌ Failed to create {self.cog_name} forum post in {target_channel.mention}: {e}"
@@ -666,19 +736,19 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
guild_id_str = str(guild_id)
subscription_id = str(uuid.uuid4())
new_sub_data = {
"subscription_id": subscription_id, "tags": tags.strip(), "webhook_url": webhook_url,
"last_known_post_id": last_known_post_id, "added_by_user_id": str(user_id),
"added_timestamp": discord.utils.utcnow().isoformat()
}
if isinstance(target_channel, discord.ForumChannel):
new_sub_data.update({"forum_channel_id": str(target_channel.id), "target_post_id": actual_target_thread_id, "post_title": actual_post_title})
else:
new_sub_data.update({"channel_id": str(target_channel.id), "thread_id": actual_target_thread_id})
if guild_id_str not in self.subscriptions_data: self.subscriptions_data[guild_id_str] = []
for existing_sub in self.subscriptions_data[guild_id_str]: # Duplicate check
is_dup_tags = existing_sub.get("tags") == new_sub_data["tags"]
is_dup_forum = isinstance(target_channel, discord.ForumChannel) and \
@@ -696,7 +766,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
target_desc = f"in {target_channel.mention}"
if isinstance(target_channel, discord.ForumChannel) and actual_target_thread_mention: target_desc = f"in forum post {actual_target_thread_mention} within {target_channel.mention}"
elif actual_target_thread_mention: target_desc += f" (thread: {actual_target_thread_mention})"
log.info(f"{self.cog_name} subscription added: Guild {guild_id_str}, Tags '{tags}', Target {target_desc}, SubID {subscription_id}")
return (f"✅ Watching {self.cog_name} for new posts with tags `{tags}` {target_desc}.\n"
f"Initial latest post ID: {last_known_post_id}. Sub ID: `{subscription_id}`.")
@@ -743,7 +813,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
self._save_pending_requests()
log.info(f"New {self.cog_name} watch request: Guild {guild_id_str}, Requester {interaction.user}, Tags '{tags}', ReqID {request_id}")
await interaction.followup.send(
f"✅ Your {self.cog_name} watch request for tags `{tags}` in forum {forum_channel.mention} (title: \"{actual_post_title}\") submitted.\n"
f"Request ID: `{request_id}`. Awaiting moderator approval."
@@ -775,7 +845,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
f" **Proposed Title:** \"{req.get('requested_post_title')}\"\n"
f" **Requested:** {time_fmt}\n---"
)
embed.description = "\n".join(desc_parts)[:4096] # Limit description length
await interaction.response.send_message(embed=embed, ephemeral=True)
@@ -789,7 +859,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
for i, r in enumerate(self.pending_requests_data.get(guild_id_str, [])):
if r.get("request_id") == request_id and r.get("status") == "pending":
req_to_approve, req_idx = r, i; break
if not req_to_approve:
await interaction.followup.send(f"❌ Pending {self.cog_name} request ID `{request_id}` not found.", ephemeral=True)
return
@@ -830,7 +900,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
for i, r in enumerate(self.pending_requests_data.get(guild_id_str, [])):
if r.get("request_id") == request_id and r.get("status") == "pending":
req_to_reject, req_idx = r, i; break
if not req_to_reject:
await interaction.followup.send(f"❌ Pending {self.cog_name} request ID `{request_id}` not found.", ephemeral=True)
return
@@ -867,9 +937,9 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
elif sub.get('channel_id'):
target_location = f"<#{sub.get('channel_id')}>"
if sub.get('thread_id'): target_location += f" (Thread <#{sub.get('thread_id')}>)"
desc_parts.append(f"**ID:** `{sub.get('subscription_id')}`\n **Tags:** `{sub.get('tags')}`\n **Target:** {target_location}\n **Last Sent ID:** `{sub.get('last_known_post_id', 'N/A')}`\n---")
embed.description = "\n".join(desc_parts)[:4096]
await interaction.response.send_message(embed=embed, ephemeral=True)
@@ -888,7 +958,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
found = True
log.info(f"Removing {self.cog_name} watch: Guild {guild_id_str}, Sub ID {subscription_id}")
else: new_subs_list.append(sub_entry)
if not found:
await interaction.response.send_message(f"{self.cog_name} subscription ID `{subscription_id}` not found.", ephemeral=True)
return
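
For readers skimming the hunk above, the paging strategy it introduces can be sketched in isolation as follows. This is a minimal illustration, not the cog's actual method: fetch_page is a hypothetical stand-in for the Gelbooru-style dapi request issued with pid=page and limit=per_page_limit, and the limits mirror the constants in the diff (1000 per page, 100000 total).

# Minimal sketch of the paging loop added in this commit (not the cog's real code).
# fetch_page is a hypothetical coroutine standing in for the HTTP call.
import asyncio
import typing


async def fetch_all_pages(
    fetch_page: typing.Callable[[int], typing.Awaitable[list]],
    per_page_limit: int = 1000,
    total_limit: int = 100000,
) -> list:
    results: list = []
    max_pages = (total_limit + per_page_limit - 1) // per_page_limit  # ceiling division
    for page in range(max_pages):
        data = await fetch_page(page)
        if not data:
            break  # empty page: no more results
        results.extend(data)
        if len(data) < per_page_limit:
            break  # short page: the API has run out of results
    return results[:total_limit]


async def _demo() -> None:
    # Fake API serving 2500 posts in pages of at most 1000.
    posts = [{"id": i} for i in range(2500)]

    async def fake_page(page: int) -> list:
        return posts[page * 1000:(page + 1) * 1000]

    fetched = await fetch_all_pages(fake_page)
    assert len(fetched) == 2500


if __name__ == "__main__":
    asyncio.run(_demo())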

test_pagination.py (new file, 48 lines)

@@ -0,0 +1,48 @@
import asyncio
import logging
import sys
import os

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
log = logging.getLogger("test_pagination")

# Add the current directory to the path so we can import the cogs
sys.path.append(os.getcwd())

from cogs.safebooru_cog import SafebooruCog
from discord.ext import commands
import discord


async def test_pagination():
    # Create a mock bot with intents
    intents = discord.Intents.default()
    bot = commands.Bot(command_prefix='!', intents=intents)

    # Initialize the cog
    cog = SafebooruCog(bot)

    # Test the pagination for a specific tag
    tag = "kasane_teto"
    log.info(f"Testing pagination for tag: {tag}")

    # Call the _fetch_posts_logic method
    results = await cog._fetch_posts_logic("test", tag)

    # Check the results
    if isinstance(results, tuple):
        log.info(f"Found {len(results[1])} results")
        # Print the first few results
        for i, result in enumerate(results[1][:5]):
            log.info(f"Result {i+1}: {result.get('id')} - {result.get('file_url')}")
    else:
        log.error(f"Error: {results}")

    # Clean up
    if hasattr(cog, 'session') and cog.session and not cog.session.closed:
        await cog.session.close()
        log.info("Closed aiohttp session")


if __name__ == "__main__":
    # Run the test
    asyncio.run(test_pagination())
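
As written, the script above exercises the live Safebooru API (so it needs network access) and only logs what it finds rather than asserting anything. A pytest-style variant is sketched below purely as an illustration; it assumes pytest and pytest-asyncio are installed, the repository root is on the import path, and the chosen tag returns at least one post.

# Hypothetical pytest variant of the same check; still hits the live API.
import discord
import pytest
from discord.ext import commands

from cogs.safebooru_cog import SafebooruCog


@pytest.mark.asyncio
async def test_fetch_posts_logic_returns_results():
    bot = commands.Bot(command_prefix="!", intents=discord.Intents.default())
    cog = SafebooruCog(bot)
    try:
        # A plain string caller skips the Discord-specific branches, so no
        # interaction or context object is needed here.
        response = await cog._fetch_posts_logic("test", "kasane_teto")
        assert isinstance(response, tuple), f"Unexpected response: {response!r}"
        _, results = response
        assert results, "expected at least one post for the tag"
    finally:
        if cog.session and not cog.session.closed:
            await cog.session.close()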