1057 lines
59 KiB
Python
1057 lines
59 KiB
Python
import os
|
|
import discord
|
|
from discord.ext import commands, tasks
|
|
from discord import app_commands
|
|
from discord.ui import Button, View
|
|
from discord import ui
|
|
from discord.components import MediaGalleryItem
|
|
ui.MediaGalleryItem = MediaGalleryItem
|
|
import random
|
|
import aiohttp
|
|
import time
|
|
import json
|
|
import typing # Need this for Optional
|
|
import uuid # For subscription IDs
|
|
import asyncio
|
|
import logging # For logging
|
|
from datetime import datetime # For parsing ISO format timestamps
|
|
import abc # For Abstract Base Class
|
|
|
|
# Logger for this module; used by the watcher base cog and its helpers below.
log = logging.getLogger(__name__)


class GelbooruWatcherMeta(commands.CogMeta, abc.ABCMeta):
    """Metaclass that merges discord.py's ``CogMeta`` with ``abc.ABCMeta``.

    ``GelbooruWatcherBaseCog`` derives from both ``commands.Cog`` and
    ``abc.ABC``. Their metaclasses differ, so Python needs one metaclass that
    inherits from both in order to build the class without a metaclass
    conflict.
    """
|
|
|
|
class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMeta):
|
|
def __init__(self, bot: commands.Bot, cog_name: str, api_base_url: str, default_tags: str, is_nsfw_site: bool, command_group_name: str, main_command_name: str):
    """Store per-site configuration, load persisted state and schedule async setup.

    Args:
        bot: The bot this cog belongs to.
        cog_name: Site display name. It is used in log/user messages and, lower-cased,
            as the prefix of the ``*_cache.json``, ``*_subscriptions.json`` and
            ``*_pending_requests.json`` state files.
        api_base_url: Endpoint that receives the ``page=dapi&s=post&q=index`` queries.
        default_tags: Stored on the instance for use by the concrete cog.
        is_nsfw_site: When True, interactive searches are restricted to NSFW
            channels/DMs unless the tags include ``rating:safe``.
        command_group_name: Stored on the instance for use by the concrete cog.
        main_command_name: Stored on the instance for use by the concrete cog.
    """
    self.bot = bot
    # No explicit commands.Cog.__init__ call. As noted by the original author, the
    # cog's ``name`` is supplied by the concrete subclasses and handled by CogMeta.

    self.cog_name = cog_name
    self.api_base_url = api_base_url
    self.default_tags = default_tags
    self.is_nsfw_site = is_nsfw_site
    self.command_group_name = command_group_name
    self.main_command_name = main_command_name

    file_prefix = self.cog_name.lower()
    self.cache_file = f"{file_prefix}_cache.json"
    self.subscriptions_file = f"{file_prefix}_subscriptions.json"
    self.pending_requests_file = f"{file_prefix}_pending_requests.json"

    self.cache_data = self._load_cache()
    self.subscriptions_data = self._load_subscriptions()
    self.pending_requests_data = self._load_pending_requests()
    self.session: typing.Optional[aiohttp.ClientSession] = None

    # Finish initialisation asynchronously: right away if the bot is already
    # ready, otherwise once wait_until_ready() returns.
    setup_coro = self.initialize_cog_async() if bot.is_ready() else self.start_task_when_ready()
    # asyncio keeps only weak references to tasks, so hold a strong one here
    # (otherwise the task can be garbage-collected mid-flight). cog_unload also
    # uses it to cancel a setup that is still pending.
    self._init_task: asyncio.Task = asyncio.create_task(setup_coro)

async def initialize_cog_async(self):
    """Asynchronous part of cog initialization: ensure the HTTP session and start polling."""
    log.info(f"Initializing {self.cog_name}Cog...")
    if self.session is None or self.session.closed:
        self.session = aiohttp.ClientSession()
        log.info(f"aiohttp ClientSession created for {self.cog_name}Cog.")
    if not self.check_new_posts.is_running():
        self.check_new_posts.start()
        log.info(f"{self.cog_name} new post checker task started.")

async def start_task_when_ready(self):
    """Wait until the bot is ready, then run initialize_cog_async()."""
    await self.bot.wait_until_ready()
    await self.initialize_cog_async()

async def cog_load(self):
    """discord.py load hook: ensure the HTTP session and, if the bot is ready, start polling."""
    log.info(f"{self.cog_name}Cog loaded.")
    if self.session is None or self.session.closed:
        self.session = aiohttp.ClientSession()
        log.info(f"aiohttp ClientSession (re)created during cog_load for {self.cog_name}Cog.")
    if not self.check_new_posts.is_running():
        if self.bot.is_ready():
            self.check_new_posts.start()
            log.info(f"{self.cog_name} new post checker task started from cog_load.")
        else:
            # start_task_when_ready() (scheduled in __init__) starts it later.
            log.warning(f"{self.cog_name}Cog loaded but bot not ready, task start deferred.")

async def cog_unload(self):
    """Clean up resources when the cog is unloaded."""
    # If the bot never became ready, start_task_when_ready() is still waiting.
    # Cancel it so it cannot start the loop or open a session after unload.
    init_task = getattr(self, "_init_task", None)
    if init_task is not None and not init_task.done():
        init_task.cancel()
    self.check_new_posts.cancel()
    log.info(f"{self.cog_name} new post checker task stopped.")
    if self.session and not self.session.closed:
        await self.session.close()
        log.info(f"aiohttp ClientSession closed for {self.cog_name}Cog.")
|
|
|
|
def _read_json_state(self, path: str, description: str):
    """Return the parsed JSON in *path*, or ``{}`` if the file is missing or unreadable.

    Failures are logged as ``Failed to load <cog> <description> file (<path>): <error>``.
    """
    if not os.path.exists(path):
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception as e:
        log.error(f"Failed to load {self.cog_name} {description} file ({path}): {e}")
        return {}

def _write_json_state(self, path: str, data, description: str, *, log_success: bool = True) -> None:
    """Serialise *data* as indented JSON and atomically replace *path* with it.

    The data is written to ``<path>.tmp`` first and then moved over *path* with
    ``os.replace``. A crash or serialisation error mid-write therefore leaves
    the previous file intact. This matters because the loaders fall back to
    ``{}`` on an unreadable file, and the next save would then overwrite
    (i.e. lose) everything. Errors are logged, never raised.
    """
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=4)
        os.replace(tmp_path, path)
    except Exception as e:
        log.error(f"Failed to save {self.cog_name} {description} file ({path}): {e}")
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # temp file may not exist (e.g. open() itself failed)
        return
    if log_success:
        log.debug(f"Saved {self.cog_name} {description} to {path}")

def _load_cache(self):
    """Load the tag-search cache (``{tags: {"timestamp", "results"}}``) from disk."""
    return self._read_json_state(self.cache_file, "cache")

def _save_cache(self):
    """Persist ``self.cache_data`` to the cache file."""
    self._write_json_state(self.cache_file, self.cache_data, "cache", log_success=False)

def _load_subscriptions(self):
    """Load ``{guild_id: [subscription, ...]}`` from the subscriptions file."""
    return self._read_json_state(self.subscriptions_file, "subscriptions")

def _save_subscriptions(self):
    """Persist ``self.subscriptions_data`` to the subscriptions file."""
    self._write_json_state(self.subscriptions_file, self.subscriptions_data, "subscriptions")

def _load_pending_requests(self):
    """Load pending subscription requests from disk."""
    return self._read_json_state(self.pending_requests_file, "pending requests")

def _save_pending_requests(self):
    """Persist ``self.pending_requests_data`` to the pending requests file."""
    self._write_json_state(self.pending_requests_file, self.pending_requests_data, "pending requests")
|
|
|
|
async def _get_or_create_webhook(self, channel: typing.Union[discord.TextChannel, discord.ForumChannel]) -> typing.Optional[str]:
    """Return a usable webhook URL for *channel*, reusing an existing one when possible.

    Lookup order:
      1. A ``webhook_url`` stored on a subscription for this channel, accepted
         only if it can be fetched and still points at this channel.
      2. Any webhook in the channel owned by the bot user (and listed with its token).
      3. A new webhook, which requires the Manage Webhooks permission.

    Returns:
        The webhook URL, or None if listing was forbidden, the permission is
        missing, or creation failed (all logged).
    """
    if not self.session or self.session.closed:
        self.session = aiohttp.ClientSession()
        log.info(f"Recreated aiohttp.ClientSession in _get_or_create_webhook for {self.cog_name}")

    # 1. Reuse a webhook URL already stored on one of this guild's subscriptions.
    channel_id_str = str(channel.id)
    for sub in self.subscriptions_data.get(str(channel.guild.id), []):
        stored_url = sub.get("webhook_url")
        if sub.get("channel_id") != channel_id_str or not stored_url:
            continue
        try:
            webhook = discord.Webhook.from_url(stored_url, session=self.session)
            await webhook.fetch()
        except (discord.NotFound, ValueError, discord.HTTPException):
            log.warning(f"Found stored webhook URL for channel {channel.id} ({self.cog_name}) but it's invalid.")
            continue
        if webhook.channel_id == channel.id:
            log.debug(f"Reusing existing webhook {webhook.id} for channel {channel.id} ({self.cog_name})")
            return stored_url

    # 2. Reuse a webhook the bot already owns in this channel.
    try:
        channel_webhooks = await channel.webhooks()
    except discord.Forbidden:
        log.warning(f"Missing 'Manage Webhooks' permission in channel {channel.id} ({self.cog_name}).")
        return None
    except discord.HTTPException as e:
        log.error(f"HTTP error listing webhooks for channel {channel.id} ({self.cog_name}): {e}")
        channel_webhooks = []
    for wh in channel_webhooks:
        # The URL embeds the token; a webhook listed without one cannot be sent to.
        if wh.user == self.bot.user and wh.token:
            log.debug(f"Found existing bot-owned webhook {wh.id} ('{wh.name}') in channel {channel.id} ({self.cog_name})")
            return wh.url

    # 3. Create a new webhook.
    if not channel.permissions_for(channel.guild.me).manage_webhooks:
        log.warning(f"Missing 'Manage Webhooks' permission in channel {channel.id} ({self.cog_name}). Cannot create webhook.")
        return None

    bot_user = self.bot.user
    # Same naming scheme as _send_via_webhook, including the no-user fallback.
    webhook_name = f"{bot_user.name} {self.cog_name} Watcher" if bot_user else f"{self.cog_name} Watcher"
    try:
        avatar_bytes = None
        if bot_user and bot_user.display_avatar:
            try:
                avatar_bytes = await bot_user.display_avatar.read()
            except Exception as e:
                # Best effort: fall back to Discord's default webhook avatar.
                log.warning(f"Could not read bot avatar for webhook creation ({self.cog_name}): {e}")

        new_webhook = await channel.create_webhook(name=webhook_name, avatar=avatar_bytes, reason=f"{self.cog_name} Tag Watcher")
        log.info(f"Created new webhook {new_webhook.id} ('{new_webhook.name}') in channel {channel.id} ({self.cog_name})")
        return new_webhook.url
    except discord.HTTPException as e:
        log.error(f"Failed to create webhook in {channel.mention} ({self.cog_name}): {e}")
        return None
    except Exception:
        log.exception(f"Unexpected error creating webhook in {channel.mention} ({self.cog_name})")
        return None
|
|
|
|
async def _send_via_webhook(
    self,
    webhook_url: str,
    content: str = "",
    thread_id: typing.Optional[str] = None,
    view: typing.Optional[discord.ui.View] = None,
):
    """Post a message through *webhook_url* under the bot's name and avatar.

    Args:
        webhook_url: Full webhook URL (id and token).
        content: Message text. An empty string is left out of the request.
        thread_id: Optional id of a thread/forum post to post into. A value that
            is not an integer is logged and the message goes to the channel itself.
        view: Optional components to attach.

    Returns:
        True if the send call completed, otherwise False. Errors are logged, never raised.
    """
    if not self.session or self.session.closed:
        self.session = aiohttp.ClientSession()
        log.info(f"Recreated aiohttp.ClientSession in _send_via_webhook for {self.cog_name}")

    try:
        webhook = discord.Webhook.from_url(
            webhook_url, session=self.session, client=self.bot
        )
        bot_user = self.bot.user
        send_kwargs: dict = {
            "username": f"{bot_user.name} {self.cog_name} Watcher" if bot_user else f"{self.cog_name} Watcher",
        }
        if bot_user and bot_user.display_avatar:
            send_kwargs["avatar_url"] = bot_user.display_avatar.url

        # Only pass optional arguments that are actually set. Webhook.send treats an
        # explicitly passed `thread` as an object with `.id` and type-checks `view`,
        # so passing None for either made every such send fail.
        if content:
            send_kwargs["content"] = content
        if thread_id:
            try:
                send_kwargs["thread"] = discord.Object(id=int(thread_id))
            except ValueError:
                log.error(f"Invalid thread_id format: {thread_id} for webhook {webhook_url[:30]} ({self.cog_name}). Sending to main channel.")
        if view is not None:
            send_kwargs["view"] = view

        await webhook.send(**send_kwargs)
        log.debug(f"Sent message via webhook to {webhook_url[:30]}... (Thread: {thread_id if thread_id else 'None'}) ({self.cog_name})")
        return True
    except ValueError:
        log.error(f"Invalid webhook URL format: {webhook_url[:30]}... ({self.cog_name})")
    except discord.NotFound:
        log.error(f"Webhook not found: {webhook_url[:30]}... ({self.cog_name})")
    except discord.Forbidden:
        log.error(f"Forbidden to send to webhook: {webhook_url[:30]}... ({self.cog_name})")
    except discord.HTTPException as e:
        log.error(f"HTTP error sending to webhook {webhook_url[:30]}... ({self.cog_name}): {e}")
    except aiohttp.ClientError as e:
        log.error(f"aiohttp client error sending to webhook {webhook_url[:30]}... ({self.cog_name}): {e}")
    except Exception as e:
        log.exception(f"Unexpected error sending to webhook {webhook_url[:30]}... ({self.cog_name}): {e}")
    return False
|
|
|
|
@staticmethod
def _channel_allows_nsfw(channel) -> bool:
    """Return True if NSFW results may be shown in *channel*.

    DMs are always allowed. Any other channel type that exposes ``is_nsfw()``
    is allowed when it returns True. In discord.py this includes text, voice
    and forum channels, and threads, which report their parent's setting.
    """
    if isinstance(channel, discord.DMChannel):
        return True
    is_nsfw = getattr(channel, "is_nsfw", None)
    return bool(callable(is_nsfw) and is_nsfw())

@staticmethod
def _extract_post_list(data) -> typing.Optional[list]:
    """Normalise a decoded API payload to a list of posts, or None if unrecognised.

    A plain list is returned as-is. A dict with an ``"@attributes"`` key is
    treated as the wrapped form ``{"@attributes": {...}, "post": [...]}``, in
    which a missing ``"post"`` key means no results.
    NOTE(review): this wrapped form is what gelbooru.com is believed to return
    for ``json=1``; confirm for the configured sites.
    """
    if isinstance(data, list):
        return data
    if isinstance(data, dict) and "@attributes" in data:
        posts = data.get("post", [])
        if isinstance(posts, dict):
            posts = [posts]
        if isinstance(posts, list):
            return posts
    return None

async def _request_posts_page(self, tags: str, pid: int, limit: int) -> tuple[typing.Optional[list], typing.Optional[str], bool]:
    """Perform one ``dapi`` post-index request and log any failure.

    Returns ``(posts, error_message, bad_format)``:
      * ``(list, None, False)``: success; the list may be empty.
      * ``(None, message, True)``: HTTP 200 but the payload was not a post list
        (including non-JSON bodies).
      * ``(None, message, False)``: non-200 status, aiohttp client error, or
        any other exception.
    Assumes ``self.session`` is open.
    """
    api_params = {
        "page": "dapi", "s": "post", "q": "index",
        "limit": limit, "pid": pid, "tags": tags, "json": 1
    }
    try:
        async with self.session.get(self.api_base_url, params=api_params) as response:
            if response.status != 200:
                log.error(f"Failed to fetch {self.cog_name} data. HTTP Status: {response.status} for tags: {tags}, pid: {pid}, params: {api_params}")
                return None, f"Failed to fetch data from {self.cog_name}. HTTP Status: {response.status}", False
            try:
                data = await response.json()
            except aiohttp.ContentTypeError:
                log.warning(f"{self.cog_name} API returned non-JSON for tags: {tags}, pid: {pid}, params: {api_params}")
                data = None
            posts = self._extract_post_list(data)
            if posts is None:
                log.warning(f"Unexpected API response format from {self.cog_name} (not list or empty list): {data} for tags: {tags}, pid: {pid}, params: {api_params}")
                return None, f"Unexpected API response format from {self.cog_name}: {response.status}", True
            return posts, None, False
    except aiohttp.ClientError as e:
        log.error(f"aiohttp.ClientError in _fetch_posts_logic for {self.cog_name} tags {tags}: {e}")
        return None, f"Network error fetching data from {self.cog_name}: {e}", False
    except Exception as e:
        log.exception(f"Unexpected error in _fetch_posts_logic API call for {self.cog_name} tags {tags}: {e}")
        return None, f"An unexpected error occurred during {self.cog_name} API call: {e}", False

async def _fetch_posts_logic(self, interaction_or_ctx: typing.Union[discord.Interaction, commands.Context, str], tags: str, pid_override: typing.Optional[int] = None, limit_override: typing.Optional[int] = None, hidden: bool = False) -> typing.Union[str, tuple[str, list], list]:
    """Search the site's post index for *tags*.

    Two modes:
      * Interactive (``pid_override`` and ``limit_override`` both None): pages
        through results 1000 posts per request, up to 100000 posts. It uses
        and refreshes a 24-hour cache keyed by the lower-cased, stripped tag
        string. Returns ``(file_url_of_a_random_post, posts)``, or a message
        string on error / no results.
      * Internal (either override given): one uncached request with exactly
        that ``pid``/``limit``. Returns the possibly empty list of posts, or an
        error message string.

    When *interaction_or_ctx* is an Interaction or Context (not a str):
    on NSFW sites the request is refused outside NSFW-capable channels/DMs
    unless the tags contain ``rating:safe``. Interactions are deferred
    (ephemeral if *hidden*), and prefix contexts get a "Fetching data..." reply.
    """
    # Either override switches to a single uncached request (used by the polling task).
    use_pagination = pid_override is None and limit_override is None
    current_pid = pid_override if pid_override is not None else 0
    # Posts per request; 1000 is treated as the API's per-request maximum.
    per_page_limit = 1000
    # Cap on posts gathered when paginating (default 100000, i.e. up to 100 pages).
    total_limit = limit_override if limit_override is not None else 100000
    api_limit = limit_override if limit_override is not None else per_page_limit

    if interaction_or_ctx and not isinstance(interaction_or_ctx, str):
        if (
            self.is_nsfw_site
            and not self._channel_allows_nsfw(interaction_or_ctx.channel)
            # 'rating:safe' is the accepted way to search outside NSFW channels.
            and 'rating:safe' not in tags.lower()
        ):
            return f'This command for {self.cog_name} can only be used in age-restricted (NSFW) channels, DMs, or with the `rating:safe` tag.'

        if not isinstance(interaction_or_ctx, commands.Context):
            # Slash command: defer so the interaction stays valid during the API calls.
            if not interaction_or_ctx.response.is_done():
                await interaction_or_ctx.response.defer(ephemeral=hidden)
        elif hasattr(interaction_or_ctx, 'reply'):  # Prefix command
            await interaction_or_ctx.reply(f"Fetching data from {self.cog_name}, please wait...")

    cache_key = tags.lower().strip()
    if use_pagination:
        cached_entry = self.cache_data.get(cache_key)
        # Cached searches are reused for 24 hours (86400 s).
        if cached_entry is not None and time.time() - cached_entry.get("timestamp", 0) < 86400:
            cached_results = cached_entry.get("results", [])
            if cached_results:
                random_result = random.choice(cached_results)
                return (f"{random_result['file_url']}", cached_results)

    if not self.session or self.session.closed:
        self.session = aiohttp.ClientSession()
        log.info(f"Recreated aiohttp.ClientSession in _fetch_posts_logic for {self.cog_name}")

    if not use_pagination:
        posts, error, _ = await self._request_posts_page(tags, current_pid, api_limit)
        return error if posts is None else posts

    all_results: list = []
    max_pages = (total_limit + per_page_limit - 1) // per_page_limit
    for page in range(max_pages):
        # Stop once the cap is reached, or if the results are no longer a whole
        # number of full pages (a previous page was short).
        if len(all_results) >= total_limit or (page > 0 and len(all_results) % per_page_limit != 0):
            break
        posts, error, bad_format = await self._request_posts_page(tags, page, per_page_limit)
        if posts is None:
            # Only an HTTP/network failure on the first page is reported to the
            # caller. Later failures and unrecognised payloads just end pagination
            # with whatever was collected.
            if page == 0 and not bad_format:
                return error
            break
        all_results.extend(posts)
        if len(posts) < per_page_limit:
            # Short (or empty) page: no further results.
            break
    del all_results[total_limit:]

    if not all_results:
        return f"No results found from {self.cog_name} for the given tags."

    self.cache_data[cache_key] = {
        "timestamp": int(time.time()),
        "results": all_results
    }
    self._save_cache()

    random_result = random.choice(all_results)
    return (f"{random_result['file_url']}", all_results)
|
|
|
|
class GelbooruButtons(View):
    """Controls shown under a random search result.

    Buttons: reroll in place, reroll into a new message, switch to paged
    browsing, and pin the message. The view times out after 60 seconds.
    """

    def __init__(self, cog: 'GelbooruWatcherBaseCog', tags: str, all_results: list, hidden: bool = False):
        super().__init__(timeout=60)
        self.cog, self.tags = cog, tags
        self.all_results = all_results
        self.hidden = hidden
        # Browse position; "Browse Results" resets it to the first result.
        self.current_index = 0

    @discord.ui.button(label="New Random", style=discord.ButtonStyle.primary)
    async def new_random(self, interaction: discord.Interaction, button: Button):
        """Replace this message's content with another random result."""
        picked = random.choice(self.all_results)
        await interaction.response.edit_message(content=f"{picked['file_url']}", view=self)

    @discord.ui.button(label="Random In New Message", style=discord.ButtonStyle.success)
    async def new_message(self, interaction: discord.Interaction, button: Button):
        """Send another random result as a new message carrying these same controls."""
        picked = random.choice(self.all_results)
        await interaction.response.send_message(f"{picked['file_url']}", view=self, ephemeral=self.hidden)

    @discord.ui.button(label="Browse Results", style=discord.ButtonStyle.secondary)
    async def browse_results(self, interaction: discord.Interaction, button: Button):
        """Switch this message to the paged BrowseView, starting at the first result."""
        results = self.all_results
        if not results:
            await interaction.response.send_message("No results to browse.", ephemeral=True)
            return
        self.current_index = 0
        first = results[self.current_index]
        browser = self.cog.BrowseView(self.cog, self.tags, results, self.hidden, self.current_index)
        await interaction.response.edit_message(content=f"Result 1/{len(results)}:\n{first['file_url']}", view=browser)

    @discord.ui.button(label="Pin", style=discord.ButtonStyle.danger)
    async def pin_message(self, interaction: discord.Interaction, button: Button):
        """Pin the message these controls are attached to and report the outcome ephemerally."""
        target = interaction.message
        if not target:
            return
        try:
            await target.pin()
            await interaction.response.send_message("Message pinned successfully!", ephemeral=True)
        except discord.Forbidden:
            await interaction.response.send_message("I don't have permission to pin messages.", ephemeral=True)
        except discord.HTTPException as e:
            await interaction.response.send_message(f"Failed to pin: {e}", ephemeral=True)
|
|
|
|
class BrowseView(View):
    """Paged viewer over a list of results.

    Offers first/previous/next/last navigation (wrapping at the ends for
    previous/next), a jump-to-page modal, and a button that returns to the
    random-result controls. The view times out after 60 seconds.
    """

    def __init__(self, cog: 'GelbooruWatcherBaseCog', tags: str, all_results: list, hidden: bool = False, current_index: int = 0):
        super().__init__(timeout=60)
        self.cog = cog
        self.tags = tags
        self.all_results = all_results
        self.hidden = hidden
        self.current_index = current_index  # 0-based position in all_results

    async def _update_message(self, interaction: discord.Interaction):
        """Edit the interaction's message to show the result at current_index."""
        result = self.all_results[self.current_index]
        content = f"Result {self.current_index + 1}/{len(self.all_results)}:\n{result['file_url']}"
        await interaction.response.edit_message(content=content, view=self)

    @discord.ui.button(label="First", style=discord.ButtonStyle.secondary, emoji="⏪")
    async def first(self, interaction: discord.Interaction, button: Button):
        self.current_index = 0
        await self._update_message(interaction)

    @discord.ui.button(label="Previous", style=discord.ButtonStyle.secondary, emoji="◀️")
    async def previous(self, interaction: discord.Interaction, button: Button):
        self.current_index = (self.current_index - 1 + len(self.all_results)) % len(self.all_results)
        await self._update_message(interaction)

    @discord.ui.button(label="Next", style=discord.ButtonStyle.primary, emoji="▶️")
    async def next_result(self, interaction: discord.Interaction, button: Button):  # not `next`, to avoid shadowing the builtin
        self.current_index = (self.current_index + 1) % len(self.all_results)
        await self._update_message(interaction)

    @discord.ui.button(label="Last", style=discord.ButtonStyle.secondary, emoji="⏩")
    async def last(self, interaction: discord.Interaction, button: Button):
        self.current_index = len(self.all_results) - 1
        await self._update_message(interaction)

    @discord.ui.button(label="Go To", style=discord.ButtonStyle.primary, row=1)
    async def goto(self, interaction: discord.Interaction, button: Button):
        """Ask for a page number; the modal applies the jump when submitted."""
        # The modal edits this message from its own submit interaction. The old
        # approach awaited modal.wait() here, but the modal has no timeout and
        # Discord sends nothing when a modal is dismissed, so that wait could
        # hang forever.
        modal = self.cog.GoToModal(len(self.all_results), browse_view=self)
        await interaction.response.send_modal(modal)

    @discord.ui.button(label="Back to Main Controls", style=discord.ButtonStyle.danger, row=1)
    async def back_to_main(self, interaction: discord.Interaction, button: Button):
        """Swap back to GelbooruButtons, showing a random result from the same list."""
        if not self.all_results:  # Should not happen if browse was initiated
            await interaction.response.edit_message(content="No results available.", view=None)
            return
        random_result = random.choice(self.all_results)
        content = f"{random_result['file_url']}"
        view = self.cog.GelbooruButtons(self.cog, self.tags, self.all_results, self.hidden)
        await interaction.response.edit_message(content=content, view=view)


class GoToModal(discord.ui.Modal):
    """Modal asking for a 1-based page number in ``1..max_pages``.

    If *browse_view* is given, a valid submission moves that view to the page
    and edits its message in place. This is done with the modal-submit
    interaction's ``response.edit_message``, which is valid because the modal
    was opened from a component. Without a view, a valid submission is only
    deferred. In both cases the chosen number is stored in ``self.value``
    (None until a valid submit). Invalid input gets an ephemeral error reply.
    """

    def __init__(self, max_pages: int, browse_view: typing.Optional["BrowseView"] = None):
        super().__init__(title="Go To Page")
        self.value = None
        self.max_pages = max_pages
        self.browse_view = browse_view
        self.page_num = discord.ui.TextInput(
            label=f"Page Number (1-{max_pages})",
            placeholder=f"Enter a number between 1 and {max_pages}",
            min_length=1,
            max_length=len(str(max_pages))
        )
        self.add_item(self.page_num)

    async def on_submit(self, interaction: discord.Interaction):
        try:
            num = int(self.page_num.value)
        except ValueError:
            await interaction.response.send_message("Please enter a valid number",ephemeral=True)
            return
        if not 1 <= num <= self.max_pages:
            await interaction.response.send_message(f"Please enter a number between 1 and {self.max_pages}",ephemeral=True)
            return

        self.value = num
        view = self.browse_view
        if view is None:
            await interaction.response.defer()
            return
        view.current_index = num - 1
        await view._update_message(interaction)
|
|
|
def _build_new_post_view(self, tags: str, file_url: str) -> ui.LayoutView:
    """Build the components-v2 layout posted for one new subscribed post.

    The layout is a single container holding, in order:
      * a heading naming the site and the subscribed tags,
      * a media gallery showing *file_url*,
      * the raw URL as text, kept as a plain link.

    ``timeout=None`` is used because the layout has no interactive components.
    """
    view = ui.LayoutView(timeout=None)
    container = ui.Container()
    container.add_item(ui.TextDisplay(f"New {self.cog_name} post for tags `{tags}`:"))
    # The previous ui.Section(accessory=ui.Media) referenced a name discord.ui does
    # not define (and passed a class rather than an item instance). Building the
    # view raised, so new-post notifications never went out. A media gallery
    # displays the post's file directly.
    container.add_item(ui.MediaGallery(MediaGalleryItem(file_url)))
    container.add_item(ui.TextDisplay(file_url))
    view.add_item(container)
    return view
|
|
|
|
async def _prefix_command_logic(self, ctx: commands.Context, tags: str):
    """Prefix-command handler body: search *tags* and post a random result with controls.

    ``_fetch_posts_logic`` already replies "Fetching data..." to the invoking
    message (and enforces the NSFW check), but it does not return that reply.
    The result or error is therefore posted as a new message in the channel.
    """
    response = await self._fetch_posts_logic(ctx, tags)

    if isinstance(response, tuple):
        content, all_results = response
        view = self.GelbooruButtons(self, tags, all_results, hidden=False)  # Prefix commands are not hidden
        # Do not try to locate a message to edit through ctx.message.reference:
        # that is the message the *user* replied to, not the bot's loading reply,
        # so editing it could overwrite an unrelated earlier bot message.
        await ctx.send(content, view=view)
    elif isinstance(response, str):  # Error / NSFW refusal / no results
        await ctx.send(response)
|
|
|
|
|
|
async def _slash_command_logic(self, interaction: discord.Interaction, tags: str, hidden: bool):
    """Slash-command handler body: post a random result with controls, or the error text.

    Uses the followup webhook when the interaction was already acknowledged
    (normally deferred by _fetch_posts_logic), otherwise the initial response.
    NSFW refusals on NSFW sites are always sent ephemerally.
    """
    outcome = await self._fetch_posts_logic(interaction, tags, hidden=hidden)

    if isinstance(outcome, tuple):
        url_text, results = outcome
        controls = self.GelbooruButtons(self, tags, results, hidden)
        if not interaction.response.is_done():
            await interaction.response.send_message(url_text, view=controls, ephemeral=hidden)
        else:
            await interaction.followup.send(url_text, view=controls, ephemeral=hidden)
        return

    if not isinstance(outcome, str):
        return

    refusal = self.is_nsfw_site and outcome.startswith(f'This command for {self.cog_name} can only be used')
    ephemeral_error = True if refusal else hidden

    if interaction.response.is_done():
        try:
            await interaction.followup.send(outcome, ephemeral=ephemeral_error)
        except discord.HTTPException as e:
            log.error(f"{self.cog_name} slash command: Failed to send error followup for tags '{tags}': {e}")
    else:
        await interaction.response.send_message(outcome, ephemeral=ephemeral_error)
|
|
|
|
async def _browse_slash_command_logic(self, interaction: discord.Interaction, tags: str, hidden: bool):
    """Slash-command handler body: open the paged BrowseView at the first result, or send the error.

    NSFW refusals on NSFW sites are always sent ephemerally.
    """
    outcome = await self._fetch_posts_logic(interaction, tags, hidden=hidden)

    async def reply(text, **extra):
        # Initial response if still available, otherwise the followup webhook.
        if interaction.response.is_done():
            await interaction.followup.send(text, **extra)
        else:
            await interaction.response.send_message(text, **extra)

    if isinstance(outcome, tuple):
        _, results = outcome
        if not results:
            await reply(f"No results found from {self.cog_name} for the given tags.", ephemeral=hidden)
            return
        first = results[0]
        browser = self.BrowseView(self, tags, results, hidden, current_index=0)
        await reply(f"Result 1/{len(results)}:\n{first['file_url']}", view=browser, ephemeral=hidden)
    elif isinstance(outcome, str):
        refusal = self.is_nsfw_site and outcome.startswith(f'This command for {self.cog_name} can only be used')
        await reply(outcome, ephemeral=True if refusal else hidden)
|
|
|
|
@tasks.loop(minutes=10)
async def check_new_posts(self):
    """Every 10 minutes, deliver posts newer than each subscription's ``last_known_post_id``.

    Iterates over a deep copy of ``subscriptions_data``, so commands that
    change subscriptions while this awaits network I/O cannot break the
    iteration. Progress is written back to the live entries, matched by
    ``subscription_id``. A failure in one subscription is logged and does not
    stop the others: an unhandled exception would otherwise end the
    ``tasks.loop``. Any recorded progress is saved even if the run is
    interrupted.
    """
    log.debug(f"Running {self.cog_name} new post check...")
    if not self.subscriptions_data:
        return

    current_subscriptions = json.loads(json.dumps(self.subscriptions_data))
    needs_save = False
    try:
        for guild_id_str, subs_list in current_subscriptions.items():
            if not isinstance(subs_list, list):
                continue
            for sub in subs_list:
                if not isinstance(sub, dict):
                    continue
                try:
                    if await self._deliver_new_posts_for_subscription(guild_id_str, sub):
                        needs_save = True
                except Exception:
                    # Some posts may already have been sent and recorded before the
                    # failure, so save conservatively.
                    needs_save = True
                    log.exception(f"Unexpected error checking {self.cog_name} subscription {sub.get('subscription_id')} in guild {guild_id_str}.")
    finally:
        if needs_save:
            self._save_subscriptions()
    log.debug(f"Finished {self.cog_name} new post check.")

async def _deliver_new_posts_for_subscription(self, guild_id_str: str, sub: dict) -> bool:
    """Fetch the newest 100 posts for *sub* and send the unseen ones, oldest first.

    Sending stops at the first failure, so the remaining posts are retried on
    the next run. Returns True if a live subscription entry was updated.
    """
    sub_id = sub.get("subscription_id")
    tags = sub.get("tags")
    webhook_url = sub.get("webhook_url")
    if not tags or not webhook_url:
        log.warning(f"Subscription for {self.cog_name} (guild {guild_id_str}, sub {sub_id}) missing tags/webhook.")
        return False
    last_known_post_id = int(sub.get("last_known_post_id") or 0)

    fetched = await self._fetch_posts_logic("internal_task_call", tags, pid_override=0, limit_override=100)
    if isinstance(fetched, str):
        log.error(f"Error fetching {self.cog_name} posts for sub {sub_id} (tags: {tags}): {fetched}")
        return False

    new_posts = []
    for post_data in fetched or []:
        if not isinstance(post_data, dict) or "id" not in post_data or "file_url" not in post_data:
            log.warning(f"Malformed {self.cog_name} post data for tags {tags}: {post_data}")
            continue
        if int(post_data["id"]) > last_known_post_id:
            new_posts.append(post_data)
    if not new_posts:
        return False

    new_posts.sort(key=lambda p: int(p["id"]))
    log.info(f"Found {len(new_posts)} new {self.cog_name} post(s) for tags '{tags}', sub {sub_id}")
    target_thread_id: typing.Optional[str] = sub.get("target_post_id") or sub.get("thread_id")

    updated = False
    for new_post in new_posts:
        view = self._build_new_post_view(tags, new_post["file_url"])
        send_success = await self._send_via_webhook(
            webhook_url,
            content="",
            thread_id=target_thread_id,
            view=view,
        )
        if not send_success:
            log.error(f"Failed to send new {self.cog_name} post via webhook for sub {sub_id}.")
            break
        if self._set_last_known_post_id(guild_id_str, sub_id, int(new_post["id"])):
            updated = True
        await asyncio.sleep(1)  # pace consecutive webhook sends
    return updated

def _set_last_known_post_id(self, guild_id_str: str, subscription_id, post_id: int) -> bool:
    """Store *post_id* on the live subscription entry whose id is *subscription_id*.

    Returns True if a matching entry exists in ``self.subscriptions_data``.
    If the subscription was removed in the meantime, nothing is written.
    """
    for entry in self.subscriptions_data.get(guild_id_str) or []:
        if entry.get("subscription_id") == subscription_id:
            entry["last_known_post_id"] = post_id
            return True
    return False
|
|
|
|
@check_new_posts.before_loop
async def before_check_new_posts(self):
    """Prepare the ``check_new_posts`` task loop before its first iteration.

    Waits until the bot reports it is ready, then makes sure ``self.session``
    holds an open ``aiohttp.ClientSession``. A new session is created when none
    exists or the previous one has been closed.
    """
    # Logged before the wait so the message describes what is actually happening.
    log.info(f"{self.cog_name}Cog: `check_new_posts` loop is waiting for bot readiness...")
    await self.bot.wait_until_ready()

    if self.session is None or self.session.closed:
        self.session = aiohttp.ClientSession()
        log.info(f"aiohttp ClientSession created before {self.cog_name} check_new_posts loop.")
|
|
|
|
async def _create_new_subscription(self, guild_id: int, user_id: int, tags: str,
                                   target_channel: typing.Union[discord.TextChannel, discord.ForumChannel],
                                   requested_thread_target: typing.Optional[str] = None,
                                   requested_post_title: typing.Optional[str] = None,
                                   is_request_approval: bool = False,
                                   requester_mention: typing.Optional[str] = None) -> str:
    """Create, persist and describe a new tag-watch subscription.

    Text channel: new posts go to the channel itself or, when
    ``requested_thread_target`` is given, to an existing thread of that
    channel. The thread is matched by ID first, then by case-insensitive name.

    Forum channel: a new forum post is created to host the feed.

    Duplicates (same stripped tags in the same location) are rejected before
    any side effect, so a rejected request leaves no orphaned forum post or
    webhook. For forums the location is the forum channel, because every
    forum subscription gets a freshly created post.

    Args:
        guild_id: Guild owning the subscription.
        user_id: Stored as ``added_by_user_id``.
        tags: Tag query to watch (stored stripped).
        target_channel: Text or forum channel that receives new posts.
        requested_thread_target: Thread ID or name (text channels only).
        requested_post_title: Forum post title; defaults to
            ``"<cog_name> Watch: <first 50 chars of tags>"``.
        is_request_approval: Whether this comes from an approved user request.
        requester_mention: Appended to the forum post's first message when
            ``is_request_approval`` is true.

    Returns:
        A user-facing status message. Success messages start with ``"✅"``,
        which ``_watch_approve_request_logic`` checks.
    """
    guild_id_str = str(guild_id)
    clean_tags = tags.strip()
    is_forum = isinstance(target_channel, discord.ForumChannel)
    actual_target_thread_id: typing.Optional[str] = None
    actual_target_thread_mention: str = ""
    actual_post_title = requested_post_title or f"{self.cog_name} Watch: {tags[:50]}"

    # 1. Resolve an optional thread inside a text channel (read-only lookups).
    if isinstance(target_channel, discord.TextChannel) and requested_thread_target:
        found_thread: typing.Optional[discord.Thread] = None
        try:
            if not target_channel.guild:
                return "Error: Guild context not found."
            thread_as_obj = await target_channel.guild.fetch_channel(int(requested_thread_target))
            if isinstance(thread_as_obj, discord.Thread) and thread_as_obj.parent_id == target_channel.id:
                found_thread = thread_as_obj
        except (ValueError, discord.NotFound, discord.Forbidden):
            pass  # Not a numeric ID, or not visible to the bot: fall back to a name match.
        except Exception as e:
            log.error(f"Error fetching thread by ID '{requested_thread_target}' for {self.cog_name}: {e}")

        if not found_thread and hasattr(target_channel, 'threads'):
            # NOTE(review): `threads` presumably lists only threads the bot has cached.
            wanted_name = requested_thread_target.lower()
            found_thread = next((t for t in target_channel.threads if t.name.lower() == wanted_name), None)

        if not found_thread:
            return f"❌ Could not find thread `{requested_thread_target}` in {target_channel.mention} for {self.cog_name}."
        actual_target_thread_id = str(found_thread.id)
        actual_target_thread_mention = found_thread.mention

    # 2. Duplicate check, before anything is created.
    for existing_sub in self.subscriptions_data.get(guild_id_str, []):
        if existing_sub.get("tags") != clean_tags:
            continue
        if is_forum:
            same_location = existing_sub.get("forum_channel_id") == str(target_channel.id)
        else:
            same_location = (existing_sub.get("channel_id") == str(target_channel.id)
                             and existing_sub.get("thread_id") == actual_target_thread_id)
        if same_location:
            return (f"⚠️ A {self.cog_name} subscription for these tags in this location already exists "
                    f"(ID: `{existing_sub.get('subscription_id')}`).")

    # 3. Forum channels: create the post that will host the feed.
    if is_forum:
        forum_post_initial_message = f"✨ **New {self.cog_name} Watch Initialized!** ✨\nNow monitoring tags: `{tags}`"
        if is_request_approval and requester_mention:
            forum_post_initial_message += f"\n_Requested by: {requester_mention}_"
        try:
            if not target_channel.permissions_for(target_channel.guild.me).create_public_threads:
                return f"❌ I lack permission to create posts in forum {target_channel.mention} for {self.cog_name}."
            new_forum_post = await target_channel.create_thread(
                name=actual_post_title, content=forum_post_initial_message, reason=f"{self.cog_name}Watch: {tags}"
            )
            actual_target_thread_id = str(new_forum_post.thread.id)
            actual_target_thread_mention = new_forum_post.thread.mention
            log.info(f"Created {self.cog_name} forum post {new_forum_post.thread.id} for tags '{tags}' in forum {target_channel.id}")
        except discord.HTTPException as e:
            return f"❌ Failed to create {self.cog_name} forum post in {target_channel.mention}: {e}"
        except Exception:
            log.exception(f"Unexpected error creating {self.cog_name} forum post for tags '{tags}'")
            return f"❌ Unexpected error creating {self.cog_name} forum post."

    webhook_url = await self._get_or_create_webhook(target_channel)
    if not webhook_url:
        return f"❌ Failed to get/create webhook for {target_channel.mention} ({self.cog_name}). Check permissions."

    # 4. Seed last_known_post_id from a one-post fetch; it stays 0 if the fetch fails.
    initial_posts = await self._fetch_posts_logic("internal_initial_fetch", tags, pid_override=0, limit_override=1)
    last_known_post_id = 0
    if isinstance(initial_posts, list) and initial_posts:
        first_post = initial_posts[0]
        try:
            last_known_post_id = int(first_post["id"])
        except (KeyError, TypeError, ValueError):
            # Non-dict entry, missing "id", or non-numeric id.
            log.warning(f"Malformed {self.cog_name} post data for initial fetch (tags: '{tags}'): {first_post}")
    elif isinstance(initial_posts, str):
        log.error(f"{self.cog_name} API error on initial fetch (tags: '{tags}'): {initial_posts}")

    # 5. Persist the subscription.
    subscription_id = str(uuid.uuid4())
    new_sub_data = {
        "subscription_id": subscription_id, "tags": clean_tags, "webhook_url": webhook_url,
        "last_known_post_id": last_known_post_id, "added_by_user_id": str(user_id),
        "added_timestamp": discord.utils.utcnow().isoformat(),
    }
    if is_forum:
        new_sub_data.update({"forum_channel_id": str(target_channel.id), "target_post_id": actual_target_thread_id, "post_title": actual_post_title})
    else:
        new_sub_data.update({"channel_id": str(target_channel.id), "thread_id": actual_target_thread_id})

    self.subscriptions_data.setdefault(guild_id_str, []).append(new_sub_data)
    self._save_subscriptions()

    target_desc = f"in {target_channel.mention}"
    if is_forum and actual_target_thread_mention:
        target_desc = f"in forum post {actual_target_thread_mention} within {target_channel.mention}"
    elif actual_target_thread_mention:
        target_desc += f" (thread: {actual_target_thread_mention})"

    log.info(f"{self.cog_name} subscription added: Guild {guild_id_str}, Tags '{tags}', Target {target_desc}, SubID {subscription_id}")
    return (f"✅ Watching {self.cog_name} for new posts with tags `{tags}` {target_desc}.\n"
            f"Initial latest post ID: {last_known_post_id}. Sub ID: `{subscription_id}`.")
|
|
|
|
# --- Watch command group logic methods ---
|
|
# These will be called by the specific cog's command handlers
|
|
|
|
async def _watch_add_logic(self, interaction: discord.Interaction, tags: str, channel: typing.Union[discord.TextChannel, discord.ForumChannel], thread_target: typing.Optional[str], post_title: typing.Optional[str]):
    """Back the "watch add" subcommand shared by the concrete cogs.

    Every reply is sent ephemerally through ``interaction.followup``, so the
    caller presumably deferred the interaction first. Checks that each optional
    argument fits the channel type: ``post_title`` is for forums only and
    ``thread_target`` is for text channels only. Then delegates to
    ``_create_new_subscription`` and relays its status message.
    """
    async def reply(text: str) -> None:
        await interaction.followup.send(text, ephemeral=True)

    # Normally guaranteed by the command's guild-only permissions.
    if not (interaction.guild_id and interaction.user):
        await reply("Command error: Missing guild or user context.")
        return

    misuse_message = None
    if post_title and isinstance(channel, discord.TextChannel):
        misuse_message = "`post_title` is only for Forum Channels."
    elif thread_target and isinstance(channel, discord.ForumChannel):
        misuse_message = "`thread_target` is only for Text Channels."
    if misuse_message is not None:
        await reply(misuse_message)
        return

    outcome = await self._create_new_subscription(
        guild_id=interaction.guild_id,
        user_id=interaction.user.id,
        tags=tags,
        target_channel=channel,
        requested_thread_target=thread_target,
        requested_post_title=post_title,
    )
    await reply(outcome)
|
|
|
|
async def _watch_request_logic(self, interaction: discord.Interaction, tags: str, forum_channel: discord.ForumChannel, post_title: typing.Optional[str]):
    """Queue a user's forum-watch request for moderator review.

    Builds a ``"pending"`` request record containing:

    * the requester's ID and name,
    * the stripped tags,
    * the target forum ID,
    * the proposed title,
    * a UTC timestamp.

    The record is appended under the guild's key in ``pending_requests_data``
    and persisted with ``_save_pending_requests``. The confirmation followup is
    not ephemeral; only the context-error reply is.
    """
    if not (interaction.guild_id and interaction.user):
        await interaction.followup.send("Command error: Missing guild or user context.", ephemeral=True)
        return

    requester = interaction.user
    guild_key = str(interaction.guild_id)
    new_request_id = str(uuid.uuid4())
    proposed_title = post_title if post_title else f"{self.cog_name} Watch: {tags[:50]}"

    record = dict(
        request_id=new_request_id,
        requester_id=str(requester.id),
        requester_name=str(requester),
        requested_tags=tags.strip(),
        target_forum_channel_id=str(forum_channel.id),
        requested_post_title=proposed_title,
        status="pending",
        request_timestamp=discord.utils.utcnow().isoformat(),
        moderator_id=None,
        moderation_timestamp=None,
    )
    self.pending_requests_data.setdefault(guild_key, []).append(record)
    self._save_pending_requests()

    log.info(f"New {self.cog_name} watch request: Guild {guild_key}, Requester {requester}, Tags '{tags}', ReqID {new_request_id}")

    confirmation = (
        f"✅ Your {self.cog_name} watch request for tags `{tags}` in forum {forum_channel.mention} (title: \"{proposed_title}\") submitted.\n"
        f"Request ID: `{new_request_id}`. Awaiting moderator approval."
    )
    await interaction.followup.send(confirmation)
|
|
|
|
async def _watch_pending_list_logic(self, interaction: discord.Interaction):
    """Show this guild's pending watch requests in an ephemeral embed.

    Replies with ``interaction.response.send_message``, so the caller must not
    have responded to or deferred the interaction yet.

    A missing or malformed ``request_timestamp`` is shown as
    ``'Unknown time'`` instead of aborting the command. Entries are added
    whole until Discord's 4096-character description limit is near. Any
    entries that do not fit are summarised in a trailing "…and N more" line,
    so no entry is cut mid-markup.
    """
    if not interaction.guild_id or not interaction.guild:  # guild object is needed for its name
        await interaction.response.send_message("Command error: Missing guild context.", ephemeral=True)
        return

    guild_id_str = str(interaction.guild_id)
    pending_reqs = [req for req in self.pending_requests_data.get(guild_id_str, []) if req.get("status") == "pending"]

    if not pending_reqs:
        await interaction.response.send_message(f"No pending {self.cog_name} watch requests.", ephemeral=True)
        return

    embed = discord.Embed(title=f"Pending {self.cog_name} Watch Requests for {interaction.guild.name}", color=discord.Color.orange())
    desc_parts = []
    for req in pending_reqs:
        forum_mention = f"<#{req.get('target_forum_channel_id', 'Unknown')}>"
        time_fmt = 'Unknown time'
        timestamp_str = req.get('request_timestamp')
        if timestamp_str:
            try:
                time_fmt = discord.utils.format_dt(datetime.fromisoformat(timestamp_str), style='R')
            except (TypeError, ValueError):
                # Persisted data may be malformed or hand-edited; one bad entry must not break the listing.
                log.warning(f"Malformed request_timestamp {timestamp_str!r} on {self.cog_name} request {req.get('request_id')}")
        desc_parts.append(
            f"**ID:** `{req.get('request_id')}`\n"
            f" **Requester:** {req.get('requester_name', 'Unknown')} (`{req.get('requester_id')}`)\n"
            f" **Tags:** `{req.get('requested_tags')}`\n"
            f" **Target Forum:** {forum_mention}\n"
            f" **Proposed Title:** \"{req.get('requested_post_title')}\"\n"
            f" **Requested:** {time_fmt}\n---"
        )

    # Add whole entries while they fit, keeping room for the overflow note.
    max_desc_len = 4096
    overflow_room = 32
    description = ""
    shown = 0
    for part in desc_parts:
        separator = "\n" if description else ""
        if len(description) + len(separator) + len(part) > max_desc_len - overflow_room:
            break
        description += separator + part
        shown += 1
    if shown == 0:
        # A single entry longer than the limit: fall back to a hard cut.
        description = desc_parts[0][: max_desc_len - overflow_room]
        shown = 1
    hidden = len(desc_parts) - shown
    if hidden:
        description += f"\n…and {hidden} more."

    embed.description = description
    await interaction.response.send_message(embed=embed, ephemeral=True)
|
|
|
|
async def _watch_approve_request_logic(self, interaction: discord.Interaction, request_id: str):
    """Approve a pending watch request and create the requested forum subscription.

    Replies go through ``interaction.followup``, so the caller presumably
    deferred the interaction. Steps:

    1. Re-fetch the request's forum channel.
    2. Create the subscription with ``_create_new_subscription``; its success
       messages start with ``"✅"``.
    3. Only on success, mark the request ``approved`` and persist it. On
       failure the request stays ``pending``.
    4. DM the requester on a best-effort basis; DM failures are only logged.
    """
    if not interaction.guild_id or not interaction.user or not interaction.guild:  # guild needed for fetch_channel and its name
        await interaction.followup.send("Command error: Missing context.", ephemeral=True)
        return

    guild_id_str = str(interaction.guild_id)
    req_to_approve = next(
        (r for r in self.pending_requests_data.get(guild_id_str, [])
         if r.get("request_id") == request_id and r.get("status") == "pending"),
        None,
    )
    if not req_to_approve:
        await interaction.followup.send(f"❌ Pending {self.cog_name} request ID `{request_id}` not found.", ephemeral=True)
        return

    target_forum_channel = None
    try:
        target_forum_channel = await interaction.guild.fetch_channel(int(req_to_approve["target_forum_channel_id"]))
    except Exception as e:
        # Deleted channel, missing access, or a corrupt stored ID: all end in the same user-facing error.
        log.warning(f"Could not fetch target forum for {self.cog_name} request {request_id}: {e}")
    if not isinstance(target_forum_channel, discord.ForumChannel):
        await interaction.followup.send(f"❌ Target forum channel for {self.cog_name} request `{request_id}` not found/invalid.", ephemeral=True)
        return

    creation_response = await self._create_new_subscription(
        guild_id=interaction.guild_id, user_id=int(req_to_approve["requester_id"]), tags=req_to_approve["requested_tags"],
        target_channel=target_forum_channel, requested_post_title=req_to_approve["requested_post_title"],
        is_request_approval=True, requester_mention=f"<@{req_to_approve['requester_id']}>"
    )

    if not creation_response.startswith("✅"):
        await interaction.followup.send(f"❌ Failed to approve {self.cog_name} request `{request_id}`. Sub creation failed: {creation_response}", ephemeral=True)
        return

    # req_to_approve is the dict stored in pending_requests_data, so an in-place
    # update is enough. Writing it back by a list index captured before the awaits
    # above could overwrite a different request if the list changed meanwhile.
    req_to_approve.update({"status": "approved", "moderator_id": str(interaction.user.id), "moderation_timestamp": discord.utils.utcnow().isoformat()})
    self._save_pending_requests()
    await interaction.followup.send(f"✅ {self.cog_name} Request ID `{request_id}` approved. {creation_response}", ephemeral=True)

    try:
        requester = await self.bot.fetch_user(int(req_to_approve["requester_id"]))
        await requester.send(f"🎉 Your {self.cog_name} watch request (`{request_id}`) for tags `{req_to_approve['requested_tags']}` in server `{interaction.guild.name}` was **approved** by {interaction.user.mention}!\nDetails: {creation_response}")
    except Exception as e:
        log.error(f"Failed to notify {self.cog_name} requester {req_to_approve['requester_id']} of approval: {e}")
|
|
|
|
async def _watch_reject_request_logic(self, interaction: discord.Interaction, request_id: str, reason: typing.Optional[str]):
    """Reject a pending watch request and tell the requester.

    Replies go through ``interaction.followup`` (ephemeral). The matching
    ``"pending"`` record is updated in place with:

    * status ``"rejected"``,
    * the moderator's ID,
    * a UTC timestamp,
    * the optional reason.

    The change is persisted with ``_save_pending_requests``. The requester is
    then DMed; DM failures are only logged.
    """
    if not (interaction.guild_id and interaction.user and interaction.guild):  # guild needed for its name
        await interaction.followup.send("Command error: Missing context.", ephemeral=True)
        return

    candidates = self.pending_requests_data.get(str(interaction.guild_id), [])
    request = next(
        (entry for entry in candidates
         if entry.get("request_id") == request_id and entry.get("status") == "pending"),
        None,
    )
    if request is None:
        await interaction.followup.send(f"❌ Pending {self.cog_name} request ID `{request_id}` not found.", ephemeral=True)
        return

    # `request` is the very dict held in pending_requests_data; mutating it updates the store.
    request.update({
        "status": "rejected",
        "moderator_id": str(interaction.user.id),
        "moderation_timestamp": discord.utils.utcnow().isoformat(),
        "rejection_reason": reason,
    })
    self._save_pending_requests()
    await interaction.followup.send(f"🗑️ {self.cog_name} Request ID `{request_id}` rejected.", ephemeral=True)

    try:
        requester = await self.bot.fetch_user(int(request["requester_id"]))
        notice = f"😥 Your {self.cog_name} watch request (`{request_id}`) for tags `{request['requested_tags']}` in server `{interaction.guild.name}` was **rejected** by {interaction.user.mention}."
        if reason:
            notice += f"\nReason: {reason}"
        await requester.send(notice)
    except Exception as e:
        log.error(f"Failed to notify {self.cog_name} requester {request['requester_id']} of rejection: {e}")
|
|
|
|
async def _watch_list_logic(self, interaction: discord.Interaction):
    """List this guild's active subscriptions in an ephemeral embed.

    Replies with ``interaction.response.send_message``, so the caller must not
    have responded to or deferred the interaction yet. Each entry shows:

    * the subscription ID,
    * the tags,
    * the target location (rendered as raw channel mentions; no API lookups),
    * the last sent post ID.

    Entries are added whole until Discord's 4096-character description limit
    is near. The rest are summarised as "…and N more" rather than being cut
    mid-entry.
    """
    if not interaction.guild_id or not interaction.guild:
        await interaction.response.send_message("Command error: Missing guild context.", ephemeral=True)
        return

    guild_id_str = str(interaction.guild_id)
    guild_subs = self.subscriptions_data.get(guild_id_str, [])
    if not guild_subs:
        await interaction.response.send_message(f"No active {self.cog_name} tag watches for this server.", ephemeral=True)
        return

    embed = discord.Embed(title=f"Active {self.cog_name} Tag Watches for {interaction.guild.name}", color=discord.Color.blue())
    desc_parts = []
    for sub in guild_subs:
        target_location = "Unknown Target"
        if sub.get('forum_channel_id') and sub.get('target_post_id'):
            target_location = f"Forum Post <#{sub.get('target_post_id')}> in Forum <#{sub.get('forum_channel_id')}>"
        elif sub.get('channel_id'):
            target_location = f"<#{sub.get('channel_id')}>"
            if sub.get('thread_id'):
                target_location += f" (Thread <#{sub.get('thread_id')}>)"

        desc_parts.append(f"**ID:** `{sub.get('subscription_id')}`\n **Tags:** `{sub.get('tags')}`\n **Target:** {target_location}\n **Last Sent ID:** `{sub.get('last_known_post_id', 'N/A')}`\n---")

    # Add whole entries while they fit, keeping room for the overflow note.
    max_desc_len = 4096
    overflow_room = 32
    description = ""
    shown = 0
    for part in desc_parts:
        separator = "\n" if description else ""
        if len(description) + len(separator) + len(part) > max_desc_len - overflow_room:
            break
        description += separator + part
        shown += 1
    if shown == 0:
        # A single entry longer than the limit: fall back to a hard cut.
        description = desc_parts[0][: max_desc_len - overflow_room]
        shown = 1
    hidden = len(desc_parts) - shown
    if hidden:
        description += f"\n…and {hidden} more."

    embed.description = description
    await interaction.response.send_message(embed=embed, ephemeral=True)
|
|
|
|
async def _watch_remove_logic(self, interaction: discord.Interaction, subscription_id: str):
    """Delete a subscription from this guild by its ID and confirm ephemerally.

    Every entry whose ``subscription_id`` matches is dropped. If no entries
    remain for the guild, the guild's key is removed from
    ``subscriptions_data`` entirely. Changes are persisted with
    ``_save_subscriptions``. Replies use ``interaction.response.send_message``.
    """
    if not interaction.guild_id:
        await interaction.response.send_message("Command error: Missing guild context.", ephemeral=True)
        return

    guild_key = str(interaction.guild_id)
    current_subs = self.subscriptions_data.get(guild_key, [])
    matches = [entry for entry in current_subs if entry.get("subscription_id") == subscription_id]

    if not matches:
        await interaction.response.send_message(f"❌ {self.cog_name} subscription ID `{subscription_id}` not found.", ephemeral=True)
        return

    for _ in matches:
        log.info(f"Removing {self.cog_name} watch: Guild {guild_key}, Sub ID {subscription_id}")

    remaining = [entry for entry in current_subs if entry.get("subscription_id") != subscription_id]
    if remaining:
        self.subscriptions_data[guild_key] = remaining
    else:
        del self.subscriptions_data[guild_key]
    self._save_subscriptions()

    # If several entries share the ID, the last one supplies the description.
    removed_info = f"tags `{matches[-1].get('tags')}`"
    await interaction.response.send_message(f"✅ Removed {self.cog_name} watch for {removed_info} (ID: `{subscription_id}`).", ephemeral=True)
|
|
|
|
async def _watch_test_message_logic(self, interaction: discord.Interaction, subscription_id: str):
    """Send a test new-post message for a given subscription.

    Looks the subscription up in this guild and fetches one post for its tags
    (``pid_override=0, limit_override=1``). The post is then delivered through
    the subscription's webhook, building the view with ``_build_new_post_view``
    and targeting ``target_post_id`` or ``thread_id``, as ``check_new_posts``
    does for real posts.

    The interaction is deferred (ephemerally) before the API fetch and the
    webhook send. Both are network round-trips, and Discord expects an initial
    interaction response within 3 seconds. Replies are routed to
    ``interaction.response`` or ``interaction.followup`` depending on whether
    the interaction has already been responded to or deferred.
    """
    async def reply(text: str) -> None:
        if interaction.response.is_done():
            await interaction.followup.send(text, ephemeral=True)
        else:
            await interaction.response.send_message(text, ephemeral=True)

    if not interaction.guild_id:
        await reply("Command error: Missing guild context.")
        return

    guild_subs = self.subscriptions_data.get(str(interaction.guild_id), [])
    target_sub = next((s for s in guild_subs if s.get("subscription_id") == subscription_id), None)
    if target_sub is None:
        await reply(f"❌ {self.cog_name} subscription ID `{subscription_id}` not found.")
        return

    tags = target_sub.get("tags")
    webhook_url = target_sub.get("webhook_url")
    if not tags or not webhook_url:
        await reply(f"❌ Subscription `{subscription_id}` is missing tags or webhook information.")
        return

    # Acknowledge now; the remaining work involves network I/O of unbounded duration.
    if not interaction.response.is_done():
        await interaction.response.defer(ephemeral=True)

    fetched = await self._fetch_posts_logic("internal_test_msg", tags, pid_override=0, limit_override=1)
    file_url = None
    if isinstance(fetched, list) and fetched and isinstance(fetched[0], dict):
        file_url = fetched[0].get("file_url")

    if not file_url:
        await reply(f"❌ Failed to fetch a post for tags `{tags}`.")
        return

    view = self._build_new_post_view(tags, file_url)
    thread_id = target_sub.get("target_post_id") or target_sub.get("thread_id")
    send_success = await self._send_via_webhook(webhook_url, content="", thread_id=thread_id, view=view)

    if send_success:
        await reply(f"✅ Test {self.cog_name} post sent for subscription `{subscription_id}`.")
    else:
        await reply(f"❌ Failed to send test {self.cog_name} post for subscription `{subscription_id}`.")
|