feat: Add pagination test for SafebooruCog to validate _fetch_posts_logic functionality
This commit is contained in:
parent
2d6aa1dc79
commit
777e206d07
@ -31,7 +31,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
# The base class itself doesn't need a Cog 'name' in the same way.
|
||||
# commands.Cog.__init__(self, bot) # This might be needed if Cog's __init__ does setup
|
||||
# Let's rely on the derived class's super() call to handle Cog's __init__ properly.
|
||||
|
||||
|
||||
self.cog_name = cog_name
|
||||
self.api_base_url = api_base_url
|
||||
self.default_tags = default_tags
|
||||
@ -47,7 +47,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
self.subscriptions_data = self._load_subscriptions()
|
||||
self.pending_requests_data = self._load_pending_requests()
|
||||
self.session: typing.Optional[aiohttp.ClientSession] = None
|
||||
|
||||
|
||||
if bot.is_ready():
|
||||
asyncio.create_task(self.initialize_cog_async())
|
||||
else:
|
||||
@ -70,7 +70,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
|
||||
async def cog_load(self):
|
||||
log.info(f"{self.cog_name}Cog loaded.")
|
||||
if self.session is None or self.session.closed:
|
||||
if self.session is None or self.session.closed:
|
||||
self.session = aiohttp.ClientSession()
|
||||
log.info(f"aiohttp ClientSession (re)created during cog_load for {self.cog_name}Cog.")
|
||||
if not self.check_new_posts.is_running():
|
||||
@ -111,7 +111,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
log.error(f"Failed to load {self.cog_name} subscriptions file ({self.subscriptions_file}): {e}")
|
||||
return {}
|
||||
return {}
|
||||
|
||||
def _save_subscriptions(self):
|
||||
try:
|
||||
@ -148,7 +148,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
if sub.get("channel_id") == str(channel.id) and sub.get("webhook_url"):
|
||||
try:
|
||||
webhook = discord.Webhook.from_url(sub["webhook_url"], session=self.session)
|
||||
await webhook.fetch()
|
||||
await webhook.fetch()
|
||||
if webhook.channel_id == channel.id:
|
||||
log.debug(f"Reusing existing webhook {webhook.id} for channel {channel.id} ({self.cog_name})")
|
||||
return sub["webhook_url"]
|
||||
@ -179,7 +179,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
avatar_bytes = await self.bot.user.display_avatar.read()
|
||||
except Exception as e:
|
||||
log.warning(f"Could not read bot avatar for webhook creation ({self.cog_name}): {e}")
|
||||
|
||||
|
||||
new_webhook = await channel.create_webhook(name=webhook_name, avatar=avatar_bytes, reason=f"{self.cog_name} Tag Watcher")
|
||||
log.info(f"Created new webhook {new_webhook.id} ('{new_webhook.name}') in channel {channel.id} ({self.cog_name})")
|
||||
return new_webhook.url
|
||||
@ -194,7 +194,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
if not self.session or self.session.closed:
|
||||
self.session = aiohttp.ClientSession()
|
||||
log.info(f"Recreated aiohttp.ClientSession in _send_via_webhook for {self.cog_name}")
|
||||
|
||||
|
||||
try:
|
||||
webhook = discord.Webhook.from_url(webhook_url, session=self.session)
|
||||
target_thread_obj = None
|
||||
@ -203,7 +203,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
target_thread_obj = discord.Object(id=int(thread_id))
|
||||
except ValueError:
|
||||
log.error(f"Invalid thread_id format: {thread_id} for webhook {webhook_url[:30]} ({self.cog_name}). Sending to main channel.")
|
||||
|
||||
|
||||
await webhook.send(
|
||||
content=content,
|
||||
username=f"{self.bot.user.name} {self.cog_name} Watcher" if self.bot.user else f"{self.cog_name} Watcher",
|
||||
@ -223,20 +223,31 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
async def _fetch_posts_logic(self, interaction_or_ctx: typing.Union[discord.Interaction, commands.Context, str], tags: str, pid_override: typing.Optional[int] = None, limit_override: typing.Optional[int] = None, hidden: bool = False) -> typing.Union[str, tuple[str, list], list]:
|
||||
all_results = []
|
||||
current_pid = pid_override if pid_override is not None else 0
|
||||
limit = limit_override if limit_override is not None else 100000
|
||||
# API has a hard limit of 1000 results per request, so we'll use that as our per-page limit
|
||||
per_page_limit = 1000
|
||||
# If limit_override is provided, use it, otherwise default to 3000 (3 pages of results)
|
||||
total_limit = limit_override if limit_override is not None else 100000
|
||||
|
||||
if not isinstance(interaction_or_ctx, str) and interaction_or_ctx:
|
||||
# For internal calls with specific pid/limit, use those exact values
|
||||
if pid_override is not None or limit_override is not None:
|
||||
use_pagination = False
|
||||
api_limit = limit_override if limit_override is not None else per_page_limit
|
||||
else:
|
||||
use_pagination = True
|
||||
api_limit = per_page_limit
|
||||
|
||||
if not isinstance(interaction_or_ctx, str) and interaction_or_ctx:
|
||||
if self.is_nsfw_site:
|
||||
is_nsfw_channel = False
|
||||
channel = interaction_or_ctx.channel
|
||||
if isinstance(channel, discord.TextChannel) and channel.is_nsfw():
|
||||
is_nsfw_channel = True
|
||||
elif isinstance(channel, discord.DMChannel):
|
||||
elif isinstance(channel, discord.DMChannel):
|
||||
is_nsfw_channel = True
|
||||
|
||||
|
||||
# For Gelbooru-like APIs, 'rating:safe', 'rating:general', 'rating:questionable' might be SFW-ish
|
||||
# We'll stick to 'rating:safe' for simplicity as it was in Rule34Cog
|
||||
allow_in_non_nsfw = 'rating:safe' in tags.lower()
|
||||
allow_in_non_nsfw = 'rating:safe' in tags.lower()
|
||||
|
||||
if not is_nsfw_channel and not allow_in_non_nsfw:
|
||||
return f'This command for {self.cog_name} can only be used in age-restricted (NSFW) channels, DMs, or with the `rating:safe` tag.'
|
||||
@ -248,13 +259,16 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
elif hasattr(interaction_or_ctx, 'reply'): # Prefix command
|
||||
await interaction_or_ctx.reply(f"Fetching data from {self.cog_name}, please wait...")
|
||||
|
||||
|
||||
if pid_override is None and limit_override is None:
|
||||
# Check cache first if not using specific pagination
|
||||
if not use_pagination:
|
||||
# Skip cache for internal calls with specific pid/limit
|
||||
pass
|
||||
else:
|
||||
cache_key = tags.lower().strip()
|
||||
if cache_key in self.cache_data:
|
||||
cached_entry = self.cache_data[cache_key]
|
||||
cache_timestamp = cached_entry.get("timestamp", 0)
|
||||
if time.time() - cache_timestamp < 86400:
|
||||
if time.time() - cache_timestamp < 86400:
|
||||
all_results = cached_entry.get("results", [])
|
||||
if all_results:
|
||||
random_result = random.choice(all_results)
|
||||
@ -264,40 +278,96 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
self.session = aiohttp.ClientSession()
|
||||
log.info(f"Recreated aiohttp.ClientSession in _fetch_posts_logic for {self.cog_name}")
|
||||
|
||||
all_results = []
|
||||
api_params = {
|
||||
"page": "dapi", "s": "post", "q": "index",
|
||||
"limit": limit, "pid": current_pid, "tags": tags, "json": 1
|
||||
}
|
||||
all_results = []
|
||||
|
||||
try:
|
||||
async with self.session.get(self.api_base_url, params=api_params) as response:
|
||||
if response.status == 200:
|
||||
try:
|
||||
data = await response.json()
|
||||
except aiohttp.ContentTypeError:
|
||||
log.warning(f"{self.cog_name} API returned non-JSON for tags: {tags}, pid: {current_pid}, params: {api_params}")
|
||||
data = None
|
||||
|
||||
if data and isinstance(data, list):
|
||||
all_results.extend(data)
|
||||
elif isinstance(data, list) and len(data) == 0: pass
|
||||
else:
|
||||
log.warning(f"Unexpected API response format from {self.cog_name} (not list or empty list): {data} for tags: {tags}, pid: {current_pid}, params: {api_params}")
|
||||
if pid_override is not None or limit_override is not None:
|
||||
return f"Unexpected API response format from {self.cog_name}: {response.status}"
|
||||
else:
|
||||
log.error(f"Failed to fetch {self.cog_name} data. HTTP Status: {response.status} for tags: {tags}, pid: {current_pid}, params: {api_params}")
|
||||
return f"Failed to fetch data from {self.cog_name}. HTTP Status: {response.status}"
|
||||
except aiohttp.ClientError as e:
|
||||
log.error(f"aiohttp.ClientError in _fetch_posts_logic for {self.cog_name} tags {tags}: {e}")
|
||||
return f"Network error fetching data from {self.cog_name}: {e}"
|
||||
except Exception as e:
|
||||
log.exception(f"Unexpected error in _fetch_posts_logic API call for {self.cog_name} tags {tags}: {e}")
|
||||
return f"An unexpected error occurred during {self.cog_name} API call: {e}"
|
||||
# If using pagination, we'll make multiple requests
|
||||
if use_pagination:
|
||||
max_pages = (total_limit + per_page_limit - 1) // per_page_limit
|
||||
for page in range(max_pages):
|
||||
# Stop if we've reached our total limit or if we got fewer results than the per-page limit
|
||||
if len(all_results) >= total_limit or (page > 0 and len(all_results) % per_page_limit != 0):
|
||||
break
|
||||
|
||||
if pid_override is not None or limit_override is not None:
|
||||
return all_results
|
||||
api_params = {
|
||||
"page": "dapi", "s": "post", "q": "index",
|
||||
"limit": per_page_limit, "pid": page, "tags": tags, "json": 1
|
||||
}
|
||||
|
||||
try:
|
||||
async with self.session.get(self.api_base_url, params=api_params) as response:
|
||||
if response.status == 200:
|
||||
try:
|
||||
data = await response.json()
|
||||
except aiohttp.ContentTypeError:
|
||||
log.warning(f"{self.cog_name} API returned non-JSON for tags: {tags}, pid: {page}, params: {api_params}")
|
||||
data = None
|
||||
|
||||
if data and isinstance(data, list):
|
||||
# If we got fewer results than requested, we've reached the end
|
||||
all_results.extend(data)
|
||||
if len(data) < per_page_limit:
|
||||
break
|
||||
elif isinstance(data, list) and len(data) == 0:
|
||||
# Empty page, no more results
|
||||
break
|
||||
else:
|
||||
log.warning(f"Unexpected API response format from {self.cog_name} (not list or empty list): {data} for tags: {tags}, pid: {page}, params: {api_params}")
|
||||
break
|
||||
else:
|
||||
log.error(f"Failed to fetch {self.cog_name} data. HTTP Status: {response.status} for tags: {tags}, pid: {page}, params: {api_params}")
|
||||
if page == 0: # Only return error if first page fails
|
||||
return f"Failed to fetch data from {self.cog_name}. HTTP Status: {response.status}"
|
||||
break
|
||||
except aiohttp.ClientError as e:
|
||||
log.error(f"aiohttp.ClientError in _fetch_posts_logic for {self.cog_name} tags {tags}: {e}")
|
||||
if page == 0: # Only return error if first page fails
|
||||
return f"Network error fetching data from {self.cog_name}: {e}"
|
||||
break
|
||||
except Exception as e:
|
||||
log.exception(f"Unexpected error in _fetch_posts_logic API call for {self.cog_name} tags {tags}: {e}")
|
||||
if page == 0: # Only return error if first page fails
|
||||
return f"An unexpected error occurred during {self.cog_name} API call: {e}"
|
||||
break
|
||||
|
||||
# Limit to the total we want
|
||||
if len(all_results) > total_limit:
|
||||
all_results = all_results[:total_limit]
|
||||
break
|
||||
else:
|
||||
# Single request with specific pid/limit
|
||||
api_params = {
|
||||
"page": "dapi", "s": "post", "q": "index",
|
||||
"limit": api_limit, "pid": current_pid, "tags": tags, "json": 1
|
||||
}
|
||||
|
||||
try:
|
||||
async with self.session.get(self.api_base_url, params=api_params) as response:
|
||||
if response.status == 200:
|
||||
try:
|
||||
data = await response.json()
|
||||
except aiohttp.ContentTypeError:
|
||||
log.warning(f"{self.cog_name} API returned non-JSON for tags: {tags}, pid: {current_pid}, params: {api_params}")
|
||||
data = None
|
||||
|
||||
if data and isinstance(data, list):
|
||||
all_results.extend(data)
|
||||
elif isinstance(data, list) and len(data) == 0: pass
|
||||
else:
|
||||
log.warning(f"Unexpected API response format from {self.cog_name} (not list or empty list): {data} for tags: {tags}, pid: {current_pid}, params: {api_params}")
|
||||
if pid_override is not None or limit_override is not None:
|
||||
return f"Unexpected API response format from {self.cog_name}: {response.status}"
|
||||
else:
|
||||
log.error(f"Failed to fetch {self.cog_name} data. HTTP Status: {response.status} for tags: {tags}, pid: {current_pid}, params: {api_params}")
|
||||
return f"Failed to fetch data from {self.cog_name}. HTTP Status: {response.status}"
|
||||
except aiohttp.ClientError as e:
|
||||
log.error(f"aiohttp.ClientError in _fetch_posts_logic for {self.cog_name} tags {tags}: {e}")
|
||||
return f"Network error fetching data from {self.cog_name}: {e}"
|
||||
except Exception as e:
|
||||
log.exception(f"Unexpected error in _fetch_posts_logic API call for {self.cog_name} tags {tags}: {e}")
|
||||
return f"An unexpected error occurred during {self.cog_name} API call: {e}"
|
||||
|
||||
if pid_override is not None or limit_override is not None:
|
||||
return all_results
|
||||
|
||||
if all_results:
|
||||
cache_key = tags.lower().strip()
|
||||
@ -344,7 +414,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
content = f"Result 1/{len(self.all_results)}:\n{result['file_url']}"
|
||||
view = self.cog.BrowseView(self.cog, self.tags, self.all_results, self.hidden, self.current_index)
|
||||
await interaction.response.edit_message(content=content, view=view)
|
||||
|
||||
|
||||
@discord.ui.button(label="Pin", style=discord.ButtonStyle.danger)
|
||||
async def pin_message(self, interaction: discord.Interaction, button: Button):
|
||||
if interaction.message:
|
||||
@ -387,7 +457,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
async def last(self, interaction: discord.Interaction, button: Button):
|
||||
self.current_index = len(self.all_results) - 1
|
||||
await self._update_message(interaction)
|
||||
|
||||
|
||||
@discord.ui.button(label="Go To", style=discord.ButtonStyle.primary, row=1)
|
||||
async def goto(self, interaction: discord.Interaction, button: Button):
|
||||
modal = self.cog.GoToModal(len(self.all_results))
|
||||
@ -438,7 +508,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
|
||||
async def _prefix_command_logic(self, ctx: commands.Context, tags: str):
|
||||
# Loading message is handled by _fetch_posts_logic if ctx is passed
|
||||
response = await self._fetch_posts_logic(ctx, tags)
|
||||
response = await self._fetch_posts_logic(ctx, tags)
|
||||
|
||||
if isinstance(response, tuple):
|
||||
content, all_results = response
|
||||
@ -476,7 +546,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
|
||||
async def _slash_command_logic(self, interaction: discord.Interaction, tags: str, hidden: bool):
|
||||
response = await self._fetch_posts_logic(interaction, tags, hidden=hidden)
|
||||
|
||||
|
||||
if isinstance(response, tuple):
|
||||
content, all_results = response
|
||||
view = self.GelbooruButtons(self, tags, all_results, hidden)
|
||||
@ -488,7 +558,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
ephemeral_error = hidden
|
||||
if self.is_nsfw_site and response.startswith(f'This command for {self.cog_name} can only be used'):
|
||||
ephemeral_error = True # Always make NSFW warnings ephemeral if possible
|
||||
|
||||
|
||||
if not interaction.response.is_done():
|
||||
await interaction.response.send_message(response, ephemeral=ephemeral_error)
|
||||
else:
|
||||
@ -499,7 +569,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
|
||||
async def _browse_slash_command_logic(self, interaction: discord.Interaction, tags: str, hidden: bool):
|
||||
response = await self._fetch_posts_logic(interaction, tags, hidden=hidden)
|
||||
|
||||
|
||||
if isinstance(response, tuple):
|
||||
_, all_results = response
|
||||
if not all_results:
|
||||
@ -507,8 +577,8 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
if not interaction.response.is_done(): await interaction.response.send_message(content, ephemeral=hidden)
|
||||
else: await interaction.followup.send(content, ephemeral=hidden)
|
||||
return
|
||||
|
||||
result = all_results[0]
|
||||
|
||||
result = all_results[0]
|
||||
content = f"Result 1/{len(all_results)}:\n{result['file_url']}"
|
||||
view = self.BrowseView(self, tags, all_results, hidden, current_index=0)
|
||||
if interaction.response.is_done(): await interaction.followup.send(content, view=view, ephemeral=hidden)
|
||||
@ -536,15 +606,15 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
|
||||
tags = sub.get("tags")
|
||||
webhook_url = sub.get("webhook_url")
|
||||
last_known_post_id = sub.get("last_known_post_id", 0)
|
||||
last_known_post_id = sub.get("last_known_post_id", 0)
|
||||
|
||||
if not tags or not webhook_url:
|
||||
log.warning(f"Subscription for {self.cog_name} (guild {guild_id_str}, sub {sub.get('subscription_id')}) missing tags/webhook.")
|
||||
continue
|
||||
|
||||
|
||||
fetched_posts_response = await self._fetch_posts_logic("internal_task_call", tags, pid_override=0, limit_override=100)
|
||||
|
||||
if isinstance(fetched_posts_response, str):
|
||||
if isinstance(fetched_posts_response, str):
|
||||
log.error(f"Error fetching {self.cog_name} posts for sub {sub.get('subscription_id')} (tags: {tags}): {fetched_posts_response}")
|
||||
continue
|
||||
if not fetched_posts_response: continue
|
||||
@ -556,7 +626,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
if not isinstance(post_data, dict) or "id" not in post_data or "file_url" not in post_data:
|
||||
log.warning(f"Malformed {self.cog_name} post data for tags {tags}: {post_data}")
|
||||
continue
|
||||
post_id = int(post_data["id"])
|
||||
post_id = int(post_data["id"])
|
||||
if post_id > last_known_post_id: new_posts_to_send.append(post_data)
|
||||
if post_id > current_max_id_this_batch: current_max_id_this_batch = post_id
|
||||
|
||||
@ -567,9 +637,9 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
for new_post in new_posts_to_send:
|
||||
post_id = int(new_post["id"])
|
||||
message_content = f"New {self.cog_name} post for tags `{tags}`:\n{new_post['file_url']}"
|
||||
|
||||
|
||||
target_thread_id: typing.Optional[str] = sub.get("target_post_id") or sub.get("thread_id")
|
||||
|
||||
|
||||
send_success = await self._send_via_webhook(webhook_url, message_content, thread_id=target_thread_id)
|
||||
if send_success:
|
||||
latest_sent_id_for_this_sub = post_id
|
||||
@ -579,12 +649,12 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
if original_sub_entry.get("subscription_id") == sub.get("subscription_id"):
|
||||
original_sub_entry["last_known_post_id"] = latest_sent_id_for_this_sub
|
||||
needs_save = True
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
log.error(f"Failed to send new {self.cog_name} post via webhook for sub {sub.get('subscription_id')}.")
|
||||
break
|
||||
|
||||
break
|
||||
|
||||
if not new_posts_to_send and current_max_id_this_batch > last_known_post_id:
|
||||
original_guild_subs = self.subscriptions_data.get(guild_id_str)
|
||||
if original_guild_subs:
|
||||
@ -603,12 +673,12 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
async def before_check_new_posts(self):
|
||||
await self.bot.wait_until_ready()
|
||||
log.info(f"{self.cog_name}Cog: `check_new_posts` loop is waiting for bot readiness...")
|
||||
if self.session is None or self.session.closed:
|
||||
if self.session is None or self.session.closed:
|
||||
self.session = aiohttp.ClientSession()
|
||||
log.info(f"aiohttp ClientSession created before {self.cog_name} check_new_posts loop.")
|
||||
|
||||
async def _create_new_subscription(self, guild_id: int, user_id: int, tags: str,
|
||||
target_channel: typing.Union[discord.TextChannel, discord.ForumChannel],
|
||||
async def _create_new_subscription(self, guild_id: int, user_id: int, tags: str,
|
||||
target_channel: typing.Union[discord.TextChannel, discord.ForumChannel],
|
||||
requested_thread_target: typing.Optional[str] = None,
|
||||
requested_post_title: typing.Optional[str] = None,
|
||||
is_request_approval: bool = False,
|
||||
@ -631,22 +701,22 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
if not found_thread and hasattr(target_channel, 'threads'):
|
||||
for t in target_channel.threads:
|
||||
if t.name.lower() == requested_thread_target.lower(): found_thread = t; break
|
||||
|
||||
|
||||
if found_thread:
|
||||
actual_target_thread_id = str(found_thread.id)
|
||||
actual_target_thread_mention = found_thread.mention
|
||||
else: return f"❌ Could not find thread `{requested_thread_target}` in {target_channel.mention} for {self.cog_name}."
|
||||
|
||||
|
||||
elif isinstance(target_channel, discord.ForumChannel):
|
||||
forum_post_initial_message = f"✨ **New {self.cog_name} Watch Initialized!** ✨\nNow monitoring tags: `{tags}`"
|
||||
if is_request_approval and requester_mention: forum_post_initial_message += f"\n_Requested by: {requester_mention}_"
|
||||
|
||||
|
||||
try:
|
||||
if not target_channel.permissions_for(target_channel.guild.me).create_public_threads:
|
||||
return f"❌ I lack permission to create posts in forum {target_channel.mention} for {self.cog_name}."
|
||||
|
||||
new_forum_post = await target_channel.create_thread(name=actual_post_title, content=forum_post_initial_message, reason=f"{self.cog_name}Watch: {tags}")
|
||||
actual_target_thread_id = str(new_forum_post.thread.id)
|
||||
actual_target_thread_id = str(new_forum_post.thread.id)
|
||||
actual_target_thread_mention = new_forum_post.thread.mention
|
||||
log.info(f"Created {self.cog_name} forum post {new_forum_post.thread.id} for tags '{tags}' in forum {target_channel.id}")
|
||||
except discord.HTTPException as e: return f"❌ Failed to create {self.cog_name} forum post in {target_channel.mention}: {e}"
|
||||
@ -666,19 +736,19 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
|
||||
guild_id_str = str(guild_id)
|
||||
subscription_id = str(uuid.uuid4())
|
||||
|
||||
|
||||
new_sub_data = {
|
||||
"subscription_id": subscription_id, "tags": tags.strip(), "webhook_url": webhook_url,
|
||||
"last_known_post_id": last_known_post_id, "added_by_user_id": str(user_id),
|
||||
"subscription_id": subscription_id, "tags": tags.strip(), "webhook_url": webhook_url,
|
||||
"last_known_post_id": last_known_post_id, "added_by_user_id": str(user_id),
|
||||
"added_timestamp": discord.utils.utcnow().isoformat()
|
||||
}
|
||||
if isinstance(target_channel, discord.ForumChannel):
|
||||
new_sub_data.update({"forum_channel_id": str(target_channel.id), "target_post_id": actual_target_thread_id, "post_title": actual_post_title})
|
||||
else:
|
||||
else:
|
||||
new_sub_data.update({"channel_id": str(target_channel.id), "thread_id": actual_target_thread_id})
|
||||
|
||||
if guild_id_str not in self.subscriptions_data: self.subscriptions_data[guild_id_str] = []
|
||||
|
||||
|
||||
for existing_sub in self.subscriptions_data[guild_id_str]: # Duplicate check
|
||||
is_dup_tags = existing_sub.get("tags") == new_sub_data["tags"]
|
||||
is_dup_forum = isinstance(target_channel, discord.ForumChannel) and \
|
||||
@ -696,7 +766,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
target_desc = f"in {target_channel.mention}"
|
||||
if isinstance(target_channel, discord.ForumChannel) and actual_target_thread_mention: target_desc = f"in forum post {actual_target_thread_mention} within {target_channel.mention}"
|
||||
elif actual_target_thread_mention: target_desc += f" (thread: {actual_target_thread_mention})"
|
||||
|
||||
|
||||
log.info(f"{self.cog_name} subscription added: Guild {guild_id_str}, Tags '{tags}', Target {target_desc}, SubID {subscription_id}")
|
||||
return (f"✅ Watching {self.cog_name} for new posts with tags `{tags}` {target_desc}.\n"
|
||||
f"Initial latest post ID: {last_known_post_id}. Sub ID: `{subscription_id}`.")
|
||||
@ -743,7 +813,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
self._save_pending_requests()
|
||||
|
||||
log.info(f"New {self.cog_name} watch request: Guild {guild_id_str}, Requester {interaction.user}, Tags '{tags}', ReqID {request_id}")
|
||||
|
||||
|
||||
await interaction.followup.send(
|
||||
f"✅ Your {self.cog_name} watch request for tags `{tags}` in forum {forum_channel.mention} (title: \"{actual_post_title}\") submitted.\n"
|
||||
f"Request ID: `{request_id}`. Awaiting moderator approval."
|
||||
@ -775,7 +845,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
f" **Proposed Title:** \"{req.get('requested_post_title')}\"\n"
|
||||
f" **Requested:** {time_fmt}\n---"
|
||||
)
|
||||
|
||||
|
||||
embed.description = "\n".join(desc_parts)[:4096] # Limit description length
|
||||
await interaction.response.send_message(embed=embed, ephemeral=True)
|
||||
|
||||
@ -789,7 +859,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
for i, r in enumerate(self.pending_requests_data.get(guild_id_str, [])):
|
||||
if r.get("request_id") == request_id and r.get("status") == "pending":
|
||||
req_to_approve, req_idx = r, i; break
|
||||
|
||||
|
||||
if not req_to_approve:
|
||||
await interaction.followup.send(f"❌ Pending {self.cog_name} request ID `{request_id}` not found.", ephemeral=True)
|
||||
return
|
||||
@ -830,7 +900,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
for i, r in enumerate(self.pending_requests_data.get(guild_id_str, [])):
|
||||
if r.get("request_id") == request_id and r.get("status") == "pending":
|
||||
req_to_reject, req_idx = r, i; break
|
||||
|
||||
|
||||
if not req_to_reject:
|
||||
await interaction.followup.send(f"❌ Pending {self.cog_name} request ID `{request_id}` not found.", ephemeral=True)
|
||||
return
|
||||
@ -867,9 +937,9 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
elif sub.get('channel_id'):
|
||||
target_location = f"<#{sub.get('channel_id')}>"
|
||||
if sub.get('thread_id'): target_location += f" (Thread <#{sub.get('thread_id')}>)"
|
||||
|
||||
|
||||
desc_parts.append(f"**ID:** `{sub.get('subscription_id')}`\n **Tags:** `{sub.get('tags')}`\n **Target:** {target_location}\n **Last Sent ID:** `{sub.get('last_known_post_id', 'N/A')}`\n---")
|
||||
|
||||
|
||||
embed.description = "\n".join(desc_parts)[:4096]
|
||||
await interaction.response.send_message(embed=embed, ephemeral=True)
|
||||
|
||||
@ -888,7 +958,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet
|
||||
found = True
|
||||
log.info(f"Removing {self.cog_name} watch: Guild {guild_id_str}, Sub ID {subscription_id}")
|
||||
else: new_subs_list.append(sub_entry)
|
||||
|
||||
|
||||
if not found:
|
||||
await interaction.response.send_message(f"❌ {self.cog_name} subscription ID `{subscription_id}` not found.", ephemeral=True)
|
||||
return
|
||||
|
48
test_pagination.py
Normal file
48
test_pagination.py
Normal file
@ -0,0 +1,48 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
log = logging.getLogger("test_pagination")
|
||||
|
||||
# Add the current directory to the path so we can import the cogs
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
from cogs.safebooru_cog import SafebooruCog
|
||||
from discord.ext import commands
|
||||
import discord
|
||||
|
||||
async def test_pagination():
|
||||
# Create a mock bot with intents
|
||||
intents = discord.Intents.default()
|
||||
bot = commands.Bot(command_prefix='!', intents=intents)
|
||||
|
||||
# Initialize the cog
|
||||
cog = SafebooruCog(bot)
|
||||
|
||||
# Test the pagination for a specific tag
|
||||
tag = "kasane_teto"
|
||||
log.info(f"Testing pagination for tag: {tag}")
|
||||
|
||||
# Call the _fetch_posts_logic method
|
||||
results = await cog._fetch_posts_logic("test", tag)
|
||||
|
||||
# Check the results
|
||||
if isinstance(results, tuple):
|
||||
log.info(f"Found {len(results[1])} results")
|
||||
# Print the first few results
|
||||
for i, result in enumerate(results[1][:5]):
|
||||
log.info(f"Result {i+1}: {result.get('id')} - {result.get('file_url')}")
|
||||
else:
|
||||
log.error(f"Error: {results}")
|
||||
|
||||
# Clean up
|
||||
if hasattr(cog, 'session') and cog.session and not cog.session.closed:
|
||||
await cog.session.close()
|
||||
log.info("Closed aiohttp session")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the test
|
||||
asyncio.run(test_pagination())
|
Loading…
x
Reference in New Issue
Block a user