1606 lines
68 KiB
Python
1606 lines
68 KiB
Python
import os
|
|
import discord
|
|
from discord.ext import commands, tasks
|
|
from discord import app_commands
|
|
from discord import ui
|
|
import random
|
|
import aiohttp
|
|
import time
|
|
import json
|
|
import typing # Need this for Optional
|
|
import uuid # For subscription IDs
|
|
import asyncio
|
|
import logging # For logging
|
|
from datetime import datetime # For parsing ISO format timestamps
|
|
import abc # For Abstract Base Class
|
|
|
|
# Setup logger for this cog
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
# Combined metaclass to resolve conflicts between CogMeta and ABCMeta
|
|
class GelbooruWatcherMeta(commands.CogMeta, abc.ABCMeta):
    """Metaclass that merges ``commands.CogMeta`` with ``abc.ABCMeta``.

    A class that is both a cog and an abstract base class needs a single
    metaclass deriving from both, otherwise Python reports a metaclass
    conflict. No behaviour is added here.
    """
|
|
|
|
|
|
class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMeta):
|
|
def __init__(
    self,
    bot: commands.Bot,
    cog_name: str,
    api_base_url: str,
    default_tags: str,
    is_nsfw_site: bool,
    command_group_name: str,
    main_command_name: str,
    post_url_template: str,
):
    """Store the site configuration, load persisted state and schedule start-up.

    Args:
        bot: The bot instance this cog is attached to.
        cog_name: Site name; used in log messages and, lower-cased, as the
            prefix of the JSON state files.
        api_base_url: URL queried by ``_fetch_posts_logic``.
        default_tags: Stored on the instance for use by other methods.
        is_nsfw_site: When true, ``_fetch_posts_logic`` applies its NSFW
            channel gate.
        command_group_name: Stored on the instance for use by other methods.
        main_command_name: Stored on the instance for use by other methods.
        post_url_template: ``str.format`` template that receives a post id.
    """
    self.bot = bot
    # NOTE: commands.Cog.__init__ is deliberately not invoked here. Per the
    # original author's notes, the cog ``name`` is supplied by the concrete
    # subclasses and handled by the cog metaclass.

    self.cog_name = cog_name
    self.api_base_url = api_base_url
    self.default_tags = default_tags
    self.is_nsfw_site = is_nsfw_site
    self.command_group_name = command_group_name
    self.main_command_name = main_command_name
    self.post_url_template = post_url_template

    file_prefix = self.cog_name.lower()
    self.cache_file = f"{file_prefix}_cache.json"
    self.subscriptions_file = f"{file_prefix}_subscriptions.json"
    self.pending_requests_file = f"{file_prefix}_pending_requests.json"

    self.cache_data = self._load_cache()
    self.subscriptions_data = self._load_subscriptions()
    self.pending_requests_data = self._load_pending_requests()
    self.session: typing.Optional[aiohttp.ClientSession] = None

    if bot.is_ready():
        startup = self.initialize_cog_async()
    else:
        startup = self.start_task_when_ready()
    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so an unreferenced task may be garbage-collected before it runs
    # to completion.
    self._startup_task = asyncio.create_task(startup)
|
|
|
|
async def initialize_cog_async(self):
    """Create the HTTP session if needed and start the new-post checker."""
    log.info(f"Initializing {self.cog_name}Cog...")

    needs_session = self.session is None or self.session.closed
    if needs_session:
        self.session = aiohttp.ClientSession()
        log.info(f"aiohttp ClientSession created for {self.cog_name}Cog.")

    # Nothing more to do when the loop is already running.
    if self.check_new_posts.is_running():
        return
    self.check_new_posts.start()
    log.info(f"{self.cog_name} new post checker task started.")
|
|
|
|
async def start_task_when_ready(self):
    """Wait for the bot to become ready, then run the async initialization."""
    bot = self.bot
    await bot.wait_until_ready()
    await self.initialize_cog_async()
|
|
|
|
async def cog_load(self):
    """Cog load hook: ensure a live HTTP session and start the checker loop.

    The loop is only started here when the bot is already ready; otherwise
    a warning is logged and nothing else happens in this method.
    """
    log.info(f"{self.cog_name}Cog loaded.")

    session_unusable = self.session is None or self.session.closed
    if session_unusable:
        self.session = aiohttp.ClientSession()
        log.info(
            f"aiohttp ClientSession (re)created during cog_load for {self.cog_name}Cog."
        )

    if self.check_new_posts.is_running():
        return

    if not self.bot.is_ready():
        log.warning(
            f"{self.cog_name}Cog loaded but bot not ready, task start deferred."
        )
        return

    self.check_new_posts.start()
    log.info(f"{self.cog_name} new post checker task started from cog_load.")
|
|
|
|
async def cog_unload(self):
    """Stop the checker loop and close the HTTP session on unload."""
    self.check_new_posts.cancel()
    log.info(f"{self.cog_name} new post checker task stopped.")

    http_session = self.session
    if not http_session or http_session.closed:
        return
    await http_session.close()
    log.info(f"aiohttp ClientSession closed for {self.cog_name}Cog.")
|
|
|
|
def _load_cache(self):
|
|
if os.path.exists(self.cache_file):
|
|
try:
|
|
with open(self.cache_file, "r") as f:
|
|
return json.load(f)
|
|
except Exception as e:
|
|
log.error(
|
|
f"Failed to load {self.cog_name} cache file ({self.cache_file}): {e}"
|
|
)
|
|
return {}
|
|
|
|
def _save_cache(self):
|
|
try:
|
|
with open(self.cache_file, "w") as f:
|
|
json.dump(self.cache_data, f, indent=4)
|
|
except Exception as e:
|
|
log.error(
|
|
f"Failed to save {self.cog_name} cache file ({self.cache_file}): {e}"
|
|
)
|
|
|
|
def _load_subscriptions(self):
|
|
if os.path.exists(self.subscriptions_file):
|
|
try:
|
|
with open(self.subscriptions_file, "r") as f:
|
|
return json.load(f)
|
|
except Exception as e:
|
|
log.error(
|
|
f"Failed to load {self.cog_name} subscriptions file ({self.subscriptions_file}): {e}"
|
|
)
|
|
return {}
|
|
|
|
def _save_subscriptions(self):
    """Persist ``self.subscriptions_data`` to ``self.subscriptions_file``.

    The data is written to a sibling temporary file and then moved over
    the target, so a failure during serialization cannot leave a
    truncated subscriptions file behind. Errors are logged, not raised.
    """
    tmp_path = f"{self.subscriptions_file}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self.subscriptions_data, f, indent=4)
        os.replace(tmp_path, self.subscriptions_file)
        log.debug(
            f"Saved {self.cog_name} subscriptions to {self.subscriptions_file}"
        )
    except Exception as e:
        log.error(
            f"Failed to save {self.cog_name} subscriptions file ({self.subscriptions_file}): {e}"
        )
        # Best-effort cleanup of the partial temp file; it may not exist.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
|
|
|
|
def _load_pending_requests(self):
|
|
if os.path.exists(self.pending_requests_file):
|
|
try:
|
|
with open(self.pending_requests_file, "r") as f:
|
|
return json.load(f)
|
|
except Exception as e:
|
|
log.error(
|
|
f"Failed to load {self.cog_name} pending requests file ({self.pending_requests_file}): {e}"
|
|
)
|
|
return {}
|
|
|
|
def _save_pending_requests(self):
    """Persist ``self.pending_requests_data`` to ``self.pending_requests_file``.

    The data is written to a sibling temporary file and then moved over
    the target, so a failure during serialization cannot leave a
    truncated pending-requests file behind. Errors are logged, not raised.
    """
    tmp_path = f"{self.pending_requests_file}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self.pending_requests_data, f, indent=4)
        os.replace(tmp_path, self.pending_requests_file)
        log.debug(
            f"Saved {self.cog_name} pending requests to {self.pending_requests_file}"
        )
    except Exception as e:
        log.error(
            f"Failed to save {self.cog_name} pending requests file ({self.pending_requests_file}): {e}"
        )
        # Best-effort cleanup of the partial temp file; it may not exist.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
|
|
|
|
async def _get_or_create_webhook(
    self, channel: typing.Union[discord.TextChannel, discord.ForumChannel]
) -> typing.Optional[str]:
    """Return a webhook URL usable for ``channel``, creating one if needed.

    Lookup order:
      1. a webhook URL already stored in a subscription for this channel,
         if it still resolves to a webhook in this channel;
      2. any webhook in the channel whose ``user`` is the bot;
      3. a newly created webhook (requires Manage Webhooks).

    Args:
        channel: Channel the webhook must belong to.

    Returns:
        The webhook URL, or ``None`` when permissions are missing or the
        webhook could not be created.
    """
    if not self.session or self.session.closed:
        self.session = aiohttp.ClientSession()
        log.info(
            f"Recreated aiohttp.ClientSession in _get_or_create_webhook for {self.cog_name}"
        )

    guild_subs = self.subscriptions_data.get(str(channel.guild.id), [])
    for sub in guild_subs:
        if sub.get("channel_id") == str(channel.id) and sub.get("webhook_url"):
            try:
                partial_webhook = discord.Webhook.from_url(
                    sub["webhook_url"], session=self.session
                )
                # fetch() returns the full webhook rather than updating the
                # partial one, so the returned object must be the one whose
                # channel_id is compared.
                webhook = await partial_webhook.fetch()
                if webhook.channel_id == channel.id:
                    log.debug(
                        f"Reusing existing webhook {webhook.id} for channel {channel.id} ({self.cog_name})"
                    )
                    return sub["webhook_url"]
            except (discord.NotFound, ValueError, discord.HTTPException):
                log.warning(
                    f"Found stored webhook URL for channel {channel.id} ({self.cog_name}) but it's invalid."
                )

    try:
        webhooks = await channel.webhooks()
        for wh in webhooks:
            if wh.user == self.bot.user:
                log.debug(
                    f"Found existing bot-owned webhook {wh.id} ('{wh.name}') in channel {channel.id} ({self.cog_name})"
                )
                return wh.url
    except discord.Forbidden:
        log.warning(
            f"Missing 'Manage Webhooks' permission in channel {channel.id} ({self.cog_name})."
        )
        return None
    except discord.HTTPException as e:
        # Listing failed for another reason; still attempt creation below.
        log.error(
            f"HTTP error listing webhooks for channel {channel.id} ({self.cog_name}): {e}"
        )

    if not channel.permissions_for(channel.guild.me).manage_webhooks:
        log.warning(
            f"Missing 'Manage Webhooks' permission in channel {channel.id} ({self.cog_name}). Cannot create webhook."
        )
        return None

    try:
        bot_user = self.bot.user
        # Same fallback name as _send_via_webhook when the bot user is
        # not available.
        webhook_name = (
            f"{bot_user.name} {self.cog_name} Watcher"
            if bot_user
            else f"{self.cog_name} Watcher"
        )
        avatar_bytes = None
        if bot_user and bot_user.display_avatar:
            try:
                avatar_bytes = await bot_user.display_avatar.read()
            except Exception as e:
                # The avatar is optional; create the webhook without it.
                log.warning(
                    f"Could not read bot avatar for webhook creation ({self.cog_name}): {e}"
                )

        new_webhook = await channel.create_webhook(
            name=webhook_name,
            avatar=avatar_bytes,
            reason=f"{self.cog_name} Tag Watcher",
        )
        log.info(
            f"Created new webhook {new_webhook.id} ('{new_webhook.name}') in channel {channel.id} ({self.cog_name})"
        )
        return new_webhook.url
    except discord.HTTPException as e:
        log.error(
            f"Failed to create webhook in {channel.mention} ({self.cog_name}): {e}"
        )
        return None
    except Exception:
        log.exception(
            f"Unexpected error creating webhook in {channel.mention} ({self.cog_name})"
        )
        return None
|
|
|
|
async def _send_via_webhook(
    self,
    webhook_url: str,
    content: str = "",
    thread_id: typing.Optional[str] = None,
    view: typing.Optional[discord.ui.View] = None,
):
    """Send a message through the webhook at ``webhook_url``.

    Args:
        webhook_url: Full webhook URL.
        content: Message text.
        thread_id: Optional thread id (as a string) to post into. An id
            that is not an integer is logged and ignored, so the message
            goes to the webhook's own channel.
        view: Optional view to attach.

    Returns:
        ``True`` when the send call completed, ``False`` on any error
        (all errors are logged, none are raised).
    """
    if not self.session or self.session.closed:
        self.session = aiohttp.ClientSession()
        log.info(
            f"Recreated aiohttp.ClientSession in _send_via_webhook for {self.cog_name}"
        )

    try:
        webhook = discord.Webhook.from_url(
            webhook_url, session=self.session, client=self.bot
        )

        bot_user = self.bot.user
        send_kwargs: typing.Dict[str, typing.Any] = {
            "content": content,
            "username": (
                f"{bot_user.name} {self.cog_name} Watcher"
                if bot_user
                else f"{self.cog_name} Watcher"
            ),
            "avatar_url": (
                bot_user.display_avatar.url
                if bot_user and bot_user.display_avatar
                else None
            ),
        }

        # ``thread`` and ``view`` are only passed when actually set: the
        # library distinguishes "argument omitted" from an explicit None
        # for these parameters, and None is not a valid thread or view.
        if thread_id:
            try:
                send_kwargs["thread"] = discord.Object(id=int(thread_id))
            except ValueError:
                log.error(
                    f"Invalid thread_id format: {thread_id} for webhook {webhook_url[:30]} ({self.cog_name}). Sending to main channel."
                )
        if view is not None:
            send_kwargs["view"] = view

        await webhook.send(**send_kwargs)
        log.debug(
            f"Sent message via webhook to {webhook_url[:30]}... (Thread: {thread_id if thread_id else 'None'}) ({self.cog_name})"
        )
        return True
    except ValueError:
        log.error(
            f"Invalid webhook URL format: {webhook_url[:30]}... ({self.cog_name})"
        )
    except discord.NotFound:
        log.error(f"Webhook not found: {webhook_url[:30]}... ({self.cog_name})")
    except discord.Forbidden:
        log.error(
            f"Forbidden to send to webhook: {webhook_url[:30]}... ({self.cog_name})"
        )
    except discord.HTTPException as e:
        log.error(
            f"HTTP error sending to webhook {webhook_url[:30]}... ({self.cog_name}): {e}"
        )
    except aiohttp.ClientError as e:
        log.error(
            f"aiohttp client error sending to webhook {webhook_url[:30]}... ({self.cog_name}): {e}"
        )
    except Exception as e:
        log.exception(
            f"Unexpected error sending to webhook {webhook_url[:30]}... ({self.cog_name}): {e}"
        )
    return False
|
|
|
|
async def _fetch_posts_logic(
    self,
    interaction_or_ctx: typing.Union[discord.Interaction, commands.Context, str],
    tags: str,
    pid_override: typing.Optional[int] = None,
    limit_override: typing.Optional[int] = None,
    hidden: bool = False,
) -> typing.Union[str, tuple[str, list], list]:
    """Fetch posts for ``tags`` from the site API, using the cache when possible.

    Two modes:
      * User-facing (no ``pid_override``/``limit_override``): applies the
        NSFW gate, acknowledges the interaction/context, serves from the
        24-hour cache when available, otherwise paginates through the API
        and stores the results in the cache.
      * Internal (either override given): one request with exactly the
        given page id and limit; the cache is neither read nor written.

    Args:
        interaction_or_ctx: Interaction or context of the invoking command,
            or any string for internal calls (skips the NSFW gate and the
            acknowledgement).
        tags: Tag query passed to the API.
        pid_override: Page id for a single internal request.
        limit_override: Result limit for a single internal request.
        hidden: Whether the interaction is deferred ephemerally.

    Returns:
        * ``str``: an error or "no results" message.
        * ``list``: raw results, for internal calls.
        * ``(message, results)``: for user-facing calls, where ``message``
          is ``"<post url>\\nfile url"`` of one randomly chosen result.
    """
    # The API caps a single request at 1000 results; that is the page size.
    per_page_limit = 1000
    # Without an explicit limit, collect up to 100000 results (100 pages).
    total_limit = limit_override if limit_override is not None else 100000
    current_pid = pid_override if pid_override is not None else 0

    # Any override means "one exact request"; otherwise paginate.
    use_pagination = pid_override is None and limit_override is None
    api_limit = limit_override if limit_override is not None else per_page_limit

    def describe_random(results: list) -> str:
        """Format one random result the same way for cached and fresh data."""
        pick = random.choice(results)
        post_url = self.post_url_template.format(pick["id"])
        return f"<{post_url}>\n{pick['file_url']}"

    if not isinstance(interaction_or_ctx, str) and interaction_or_ctx:
        if self.is_nsfw_site:
            channel = interaction_or_ctx.channel
            if isinstance(channel, discord.DMChannel):
                is_nsfw_channel = True
            else:
                # Threads, voice and forum channels also expose is_nsfw();
                # channel types without it are treated as not NSFW.
                nsfw_probe = getattr(channel, "is_nsfw", None)
                is_nsfw_channel = (
                    bool(nsfw_probe()) if callable(nsfw_probe) else False
                )

            # For Gelbooru-like APIs other ratings might be SFW-ish, but
            # only 'rating:safe' is accepted outside NSFW channels.
            allow_in_non_nsfw = "rating:safe" in tags.lower()

            if not is_nsfw_channel and not allow_in_non_nsfw:
                return f"This command for {self.cog_name} can only be used in age-restricted (NSFW) channels, DMs, or with the `rating:safe` tag."

        is_interaction = not isinstance(interaction_or_ctx, commands.Context)
        if is_interaction:
            if not interaction_or_ctx.response.is_done():
                await interaction_or_ctx.response.defer(ephemeral=hidden)
        elif hasattr(interaction_or_ctx, "reply"):  # Prefix command
            await interaction_or_ctx.reply(
                f"Fetching data from {self.cog_name}, please wait..."
            )

    cache_key = tags.lower().strip()

    # Internal calls with a specific pid/limit bypass the cache.
    if use_pagination and cache_key in self.cache_data:
        cached_entry = self.cache_data[cache_key]
        cache_timestamp = cached_entry.get("timestamp", 0)
        if time.time() - cache_timestamp < 86400:
            cached_results = cached_entry.get("results", [])
            if cached_results:
                return (describe_random(cached_results), cached_results)

    if not self.session or self.session.closed:
        self.session = aiohttp.ClientSession()
        log.info(
            f"Recreated aiohttp.ClientSession in _fetch_posts_logic for {self.cog_name}"
        )

    all_results = []

    if use_pagination:
        max_pages = (total_limit + per_page_limit - 1) // per_page_limit
        for page in range(max_pages):
            # Stop once the total limit is reached, or when the previous
            # page came back short (the result count is no longer a
            # multiple of the page size).
            if len(all_results) >= total_limit or (
                page > 0 and len(all_results) % per_page_limit != 0
            ):
                break

            api_params = {
                "page": "dapi",
                "s": "post",
                "q": "index",
                "limit": per_page_limit,
                "pid": page,
                "tags": tags,
                "json": 1,
            }

            try:
                async with self.session.get(
                    self.api_base_url, params=api_params
                ) as response:
                    if response.status == 200:
                        try:
                            data = await response.json()
                        except aiohttp.ContentTypeError:
                            log.warning(
                                f"{self.cog_name} API returned non-JSON for tags: {tags}, pid: {page}, params: {api_params}"
                            )
                            data = None

                        if data and isinstance(data, list):
                            all_results.extend(data)
                            # A short page means the end was reached.
                            if len(data) < per_page_limit:
                                break
                        elif isinstance(data, list) and len(data) == 0:
                            # Empty page, no more results.
                            break
                        else:
                            log.warning(
                                f"Unexpected API response format from {self.cog_name} (not list or empty list): {data} for tags: {tags}, pid: {page}, params: {api_params}"
                            )
                            break
                    else:
                        log.error(
                            f"Failed to fetch {self.cog_name} data. HTTP Status: {response.status} for tags: {tags}, pid: {page}, params: {api_params}"
                        )
                        # Only the first page failing is reported; later
                        # failures keep what was already collected.
                        if page == 0:
                            return f"Failed to fetch data from {self.cog_name}. HTTP Status: {response.status}"
                        break
            except aiohttp.ClientError as e:
                log.error(
                    f"aiohttp.ClientError in _fetch_posts_logic for {self.cog_name} tags {tags}: {e}"
                )
                if page == 0:
                    return f"Network error fetching data from {self.cog_name}: {e}"
                break
            except Exception as e:
                log.exception(
                    f"Unexpected error in _fetch_posts_logic API call for {self.cog_name} tags {tags}: {e}"
                )
                if page == 0:
                    return f"An unexpected error occurred during {self.cog_name} API call: {e}"
                break

            # Trim any overshoot past the requested total.
            if len(all_results) > total_limit:
                all_results = all_results[:total_limit]
                break
    else:
        # Single request with the specific pid/limit.
        api_params = {
            "page": "dapi",
            "s": "post",
            "q": "index",
            "limit": api_limit,
            "pid": current_pid,
            "tags": tags,
            "json": 1,
        }

        try:
            async with self.session.get(
                self.api_base_url, params=api_params
            ) as response:
                if response.status == 200:
                    try:
                        data = await response.json()
                    except aiohttp.ContentTypeError:
                        log.warning(
                            f"{self.cog_name} API returned non-JSON for tags: {tags}, pid: {current_pid}, params: {api_params}"
                        )
                        data = None

                    if data and isinstance(data, list):
                        all_results.extend(data)
                    elif isinstance(data, list) and len(data) == 0:
                        pass
                    else:
                        log.warning(
                            f"Unexpected API response format from {self.cog_name} (not list or empty list): {data} for tags: {tags}, pid: {current_pid}, params: {api_params}"
                        )
                        return f"Unexpected API response format from {self.cog_name}: {response.status}"
                else:
                    log.error(
                        f"Failed to fetch {self.cog_name} data. HTTP Status: {response.status} for tags: {tags}, pid: {current_pid}, params: {api_params}"
                    )
                    return f"Failed to fetch data from {self.cog_name}. HTTP Status: {response.status}"
        except aiohttp.ClientError as e:
            log.error(
                f"aiohttp.ClientError in _fetch_posts_logic for {self.cog_name} tags {tags}: {e}"
            )
            return f"Network error fetching data from {self.cog_name}: {e}"
        except Exception as e:
            log.exception(
                f"Unexpected error in _fetch_posts_logic API call for {self.cog_name} tags {tags}: {e}"
            )
            return (
                f"An unexpected error occurred during {self.cog_name} API call: {e}"
            )

        # Internal callers get the raw list, uncached.
        return all_results

    if not all_results:
        return f"No results found from {self.cog_name} for the given tags."

    self.cache_data[cache_key] = {
        "timestamp": int(time.time()),
        "results": all_results,
    }
    self._save_cache()

    return (describe_random(all_results), all_results)
|
|
|
|
class GelbooruButtons(ui.LayoutView):
    """Main result view: one displayed post plus a row of action buttons.

    Buttons: pick a new random result in place, post a random result as a
    new message, switch to the browse view, and pin the message.
    """

    container = ui.Container()
    buttons = ui.ActionRow()

    def __init__(
        self,
        cog: "GelbooruWatcherBaseCog",
        tags: str,
        all_results: list,
        hidden: bool = False,
    ):
        """Create the view and show a random result when any are available.

        Args:
            cog: Owning cog; supplies ``cog_name``, ``post_url_template``
                and the sibling view classes.
            tags: Tag query the results belong to (shown to the user).
            all_results: Result dicts; each is read for ``file_url``/``id``.
            hidden: Whether messages spawned from this view are ephemeral.
        """
        super().__init__(timeout=300)
        self.cog = cog
        self.tags = tags
        self.all_results = all_results
        self.hidden = hidden
        self.current_index = 0

        if self.all_results:
            self._update_container(random.choice(self.all_results))

    def _update_container(self, result: dict):
        """Rebuild the container to display ``result``."""
        self.container.clear_items()
        gallery = ui.MediaGallery()
        gallery.add_item(media=result["file_url"])
        self.container.add_item(gallery)
        self.container.add_item(
            ui.TextDisplay(f"{self.cog.cog_name} result for tags `{self.tags}`:")
        )
        post_url = self.cog.post_url_template.format(result["id"])
        self.container.add_item(ui.TextDisplay(post_url))

    @buttons.button(label="New Random", style=discord.ButtonStyle.primary)
    async def new_random(
        self, interaction: discord.Interaction, button: discord.ui.Button
    ):
        """Replace the displayed post with another random result."""
        # random.choice raises on an empty sequence; answer instead.
        if not self.all_results:
            await interaction.response.send_message(
                "No results available.", ephemeral=True
            )
            return
        self._update_container(random.choice(self.all_results))
        await interaction.response.edit_message(content="", view=self)

    @buttons.button(
        label="Random In New Message", style=discord.ButtonStyle.success
    )
    async def new_message(
        self, interaction: discord.Interaction, button: discord.ui.Button
    ):
        """Send a random result as a separate message with its own controls."""
        if not self.all_results:
            await interaction.response.send_message(
                "No results available.", ephemeral=True
            )
            return
        # The constructor already selects a random result to display.
        new_view = self.cog.GelbooruButtons(
            self.cog, self.tags, self.all_results, self.hidden
        )
        await interaction.response.send_message(
            content="", view=new_view, ephemeral=self.hidden
        )

    @buttons.button(label="Browse Results", style=discord.ButtonStyle.secondary)
    async def browse_results(
        self, interaction: discord.Interaction, button: discord.ui.Button
    ):
        """Swap this view for the sequential browse view, starting at the first result."""
        if not self.all_results:
            await interaction.response.send_message(
                "No results to browse.", ephemeral=True
            )
            return
        self.current_index = 0
        view = self.cog.BrowseView(
            self.cog, self.tags, self.all_results, self.hidden, self.current_index
        )
        await interaction.response.edit_message(content="", view=view)

    @buttons.button(label="Pin", style=discord.ButtonStyle.danger)
    async def pin_message(
        self, interaction: discord.Interaction, button: discord.ui.Button
    ):
        """Pin the message carrying this view and report the outcome."""
        if not interaction.message:
            # Always answer the interaction so it does not show as failed.
            await interaction.response.send_message(
                "There is no message to pin.", ephemeral=True
            )
            return

        # Only the pin call is guarded, so the interaction is answered
        # exactly once whatever the outcome.
        try:
            await interaction.message.pin()
        except discord.Forbidden:
            outcome = "I don't have permission to pin messages."
        except discord.HTTPException as e:
            outcome = f"Failed to pin: {e}"
        else:
            outcome = "Message pinned successfully!"
        await interaction.response.send_message(outcome, ephemeral=True)
|
|
|
|
class BrowseView(ui.LayoutView):
    """Sequential browser over a result list with navigation buttons.

    Navigation wraps around at both ends. "Go To" opens a modal asking
    for a 1-based position; "Back to Main Controls" returns to
    ``GelbooruButtons``.
    """

    container = ui.Container()
    nav_row = ui.ActionRow()
    extra_row = ui.ActionRow()

    def __init__(
        self,
        cog: "GelbooruWatcherBaseCog",
        tags: str,
        all_results: list,
        hidden: bool = False,
        current_index: int = 0,
    ):
        """Create the view and display the result at ``current_index``.

        Args:
            cog: Owning cog; supplies ``post_url_template`` and the sibling
                view/modal classes.
            tags: Tag query the results belong to (shown to the user).
            all_results: Result dicts; each is read for ``file_url``/``id``.
            hidden: Passed on to ``GelbooruButtons`` when going back.
            current_index: 0-based index of the result shown first.
        """
        super().__init__(timeout=300)
        self.cog = cog
        self.tags = tags
        self.all_results = all_results
        self.hidden = hidden
        self.current_index = current_index

        if self.all_results:
            self._refresh_container()

    def _refresh_container(self):
        """Rebuild the container for the result at ``self.current_index``."""
        self.container.clear_items()
        result = self.all_results[self.current_index]
        gallery = ui.MediaGallery()
        gallery.add_item(media=result["file_url"])
        self.container.add_item(gallery)
        idx_label = f"Result {self.current_index + 1}/{len(self.all_results)} for tags `{self.tags}`:"
        self.container.add_item(ui.TextDisplay(idx_label))
        post_url = self.cog.post_url_template.format(result["id"])
        self.container.add_item(ui.TextDisplay(post_url))

    async def _update_message(self, interaction: discord.Interaction):
        """Refresh the container and edit the message in response to ``interaction``."""
        self._refresh_container()
        await interaction.response.edit_message(content="", view=self)

    @nav_row.button(label="First", style=discord.ButtonStyle.secondary, emoji="⏪")
    async def first(
        self, interaction: discord.Interaction, button: discord.ui.Button
    ):
        """Jump to the first result."""
        self.current_index = 0
        await self._update_message(interaction)

    @nav_row.button(
        label="Previous", style=discord.ButtonStyle.secondary, emoji="◀️"
    )
    async def previous(
        self, interaction: discord.Interaction, button: discord.ui.Button
    ):
        """Step back one result, wrapping to the last."""
        self.current_index = (self.current_index - 1 + len(self.all_results)) % len(
            self.all_results
        )
        await self._update_message(interaction)

    @nav_row.button(label="Next", style=discord.ButtonStyle.primary, emoji="▶️")
    async def next_result(
        self, interaction: discord.Interaction, button: discord.ui.Button
    ):
        """Step forward one result, wrapping to the first."""
        self.current_index = (self.current_index + 1) % len(self.all_results)
        await self._update_message(interaction)

    @nav_row.button(label="Last", style=discord.ButtonStyle.secondary, emoji="⏩")
    async def last(
        self, interaction: discord.Interaction, button: discord.ui.Button
    ):
        """Jump to the last result."""
        self.current_index = len(self.all_results) - 1
        await self._update_message(interaction)

    @extra_row.button(label="Go To", style=discord.ButtonStyle.primary)
    async def goto(
        self, interaction: discord.Interaction, button: discord.ui.Button
    ):
        """Ask for a position via a modal and display that result."""
        modal = self.cog.GoToModal(len(self.all_results))
        await interaction.response.send_modal(modal)
        await modal.wait()
        if modal.value is not None:
            self.current_index = modal.value - 1
            # The container must be rebuilt for the new index; otherwise
            # the edit below re-sends the previously displayed result.
            self._refresh_container()
            await interaction.followup.edit_message(
                interaction.message.id, content="", view=self
            )

    @extra_row.button(
        label="Back to Main Controls", style=discord.ButtonStyle.danger
    )
    async def back_to_main(
        self, interaction: discord.Interaction, button: discord.ui.Button
    ):
        """Return to the main controls, showing a random result."""
        if not self.all_results:  # Should not happen if browse was initiated
            await interaction.response.edit_message(
                content="No results available.", view=None
            )
            return
        # The GelbooruButtons constructor already selects a random result.
        view = self.cog.GelbooruButtons(
            self.cog, self.tags, self.all_results, self.hidden
        )
        await interaction.response.edit_message(content="", view=view)
|
|
|
|
class GoToModal(discord.ui.Modal):
    """Modal asking for a 1-based page number.

    After a valid submission ``value`` holds the number; it stays ``None``
    otherwise.
    """

    def __init__(self, max_pages: int):
        """Build the modal with one text input limited to ``max_pages`` digits."""
        super().__init__(title="Go To Page")
        self.value = None
        self.max_pages = max_pages

        digits_allowed = len(str(max_pages))
        self.page_num = discord.ui.TextInput(
            label=f"Page Number (1-{max_pages})",
            placeholder=f"Enter a number between 1 and {max_pages}",
            min_length=1,
            max_length=digits_allowed,
        )
        self.add_item(self.page_num)

    async def on_submit(self, interaction: discord.Interaction):
        """Validate the entry; store it and defer, or explain what was wrong."""
        try:
            requested = int(self.page_num.value)
            out_of_range = not (1 <= requested <= self.max_pages)
            if out_of_range:
                await interaction.response.send_message(
                    f"Please enter a number between 1 and {self.max_pages}",
                    ephemeral=True,
                )
            else:
                self.value = requested
                # Only acknowledge here; BrowseView.goto performs the edit.
                await interaction.response.defer()
        except ValueError:
            await interaction.response.send_message(
                "Please enter a valid number", ephemeral=True
            )
|
|
|
|
def _build_new_post_view(
    self, tags: str, file_url: str, post_id: int
) -> ui.LayoutView:
    """Build the non-expiring layout view announcing a new post.

    The view holds a single container with the media, a heading naming
    the tags, and the post URL built from ``self.post_url_template``.
    """
    layout = ui.LayoutView(timeout=None)
    box = ui.Container()
    layout.add_item(box)

    media = ui.MediaGallery()
    media.add_item(media=file_url)

    heading = ui.TextDisplay(f"New {self.cog_name} post for tags `{tags}`:")
    link = ui.TextDisplay(self.post_url_template.format(post_id))

    for child in (media, heading, link):
        box.add_item(child)

    return layout
|
|
|
|
async def _prefix_command_logic(self, ctx: commands.Context, tags: str):
    """Run a prefix-command search and post the result.

    ``_fetch_posts_logic`` sends the "Fetching data" reply itself but does
    not return that message, so the outcome is sent as a new message.

    The previous implementation tried to edit the message referenced by
    ``ctx.message.reference``. That reference is the message the invoking
    user replied to, not the bot's loading reply, so an unrelated bot
    message could be overwritten; that path has been removed.

    Args:
        ctx: Invocation context of the prefix command.
        tags: Tag query to search for.
    """
    response = await self._fetch_posts_logic(ctx, tags)

    if isinstance(response, tuple):
        _, all_results = response
        view = self.GelbooruButtons(self, tags, all_results, hidden=False)
        await ctx.send("", view=view)
    elif isinstance(response, str):  # Error or "no results" message
        await ctx.send(response)
|
|
|
|
async def _slash_command_logic(
    self, interaction: discord.Interaction, tags: str, hidden: bool
):
    """Run a slash-command search and answer with the main controls view.

    Error strings from ``_fetch_posts_logic`` are sent as text; the
    NSFW-channel warning is always ephemeral.
    """
    outcome = await self._fetch_posts_logic(interaction, tags, hidden=hidden)

    if isinstance(outcome, tuple):
        _, all_results = outcome
        controls = self.GelbooruButtons(self, tags, all_results, hidden)
        if not interaction.response.is_done():
            await interaction.response.send_message(
                content="", view=controls, ephemeral=hidden
            )
        else:
            await interaction.followup.send(
                content="", view=controls, ephemeral=hidden
            )
        return

    if not isinstance(outcome, str):
        return

    # Error path. NSFW warnings are always made ephemeral if possible.
    nsfw_prefix = f"This command for {self.cog_name} can only be used"
    is_nsfw_warning = self.is_nsfw_site and outcome.startswith(nsfw_prefix)
    ephemeral_error = True if is_nsfw_warning else hidden

    if interaction.response.is_done():
        try:
            await interaction.followup.send(outcome, ephemeral=ephemeral_error)
        except discord.HTTPException as e:
            log.error(
                f"{self.cog_name} slash command: Failed to send error followup for tags '{tags}': {e}"
            )
    else:
        await interaction.response.send_message(
            outcome, ephemeral=ephemeral_error
        )
|
|
|
|
async def _browse_slash_command_logic(
    self, interaction: discord.Interaction, tags: str, hidden: bool
):
    """Run a slash-command search and answer with the browse view.

    Error strings from ``_fetch_posts_logic`` are sent as text; the
    NSFW-channel warning is always ephemeral. A failure to deliver the
    error follow-up is logged rather than raised, matching
    ``_slash_command_logic``.

    Args:
        interaction: The slash-command interaction.
        tags: Tag query to search for.
        hidden: Whether responses are ephemeral.
    """
    response = await self._fetch_posts_logic(interaction, tags, hidden=hidden)

    if isinstance(response, tuple):
        _, all_results = response
        if not all_results:
            # Defensive: an empty result list has nothing to browse.
            content = f"No results found from {self.cog_name} for the given tags."
            if not interaction.response.is_done():
                await interaction.response.send_message(content, ephemeral=hidden)
            else:
                await interaction.followup.send(content, ephemeral=hidden)
            return

        view = self.BrowseView(self, tags, all_results, hidden, current_index=0)
        if interaction.response.is_done():
            await interaction.followup.send(content="", view=view, ephemeral=hidden)
        else:
            await interaction.response.send_message(
                content="", view=view, ephemeral=hidden
            )
    elif isinstance(response, str):  # Error
        ephemeral_error = hidden
        if self.is_nsfw_site and response.startswith(
            f"This command for {self.cog_name} can only be used"
        ):
            ephemeral_error = True

        if not interaction.response.is_done():
            await interaction.response.send_message(
                response, ephemeral=ephemeral_error
            )
        else:
            try:
                await interaction.followup.send(response, ephemeral=ephemeral_error)
            except discord.HTTPException as e:
                log.error(
                    f"{self.cog_name} browse slash command: Failed to send error followup for tags '{tags}': {e}"
                )
|
|
|
|
@tasks.loop(minutes=10)
async def check_new_posts(self):
    """Poll every stored subscription and forward unseen posts to its webhook.

    Scheduled by ``tasks.loop`` every 10 minutes. For each subscription the
    first page (up to 100 posts) is fetched through
    ``self._fetch_posts_logic``; posts whose numeric ID is greater than the
    subscription's ``last_known_post_id`` are sent oldest-first through
    ``self._send_via_webhook``. After each successful send the live
    subscription entry is advanced to that post's ID, so a failure part-way
    through resumes from the right place on the next run. Subscriptions are
    saved once at the end, and only if something changed.

    Malformed records (non-dict entries, missing tags/webhook, non-numeric
    IDs) are logged and skipped instead of raising, so one bad record does
    not abort the whole pass.
    """
    log.debug(f"Running {self.cog_name} new post check...")
    if not self.subscriptions_data:
        return

    # Iterate over a JSON round-trip copy: this coroutine awaits network
    # calls, and commands may edit the live data in the meantime.
    snapshot = json.loads(json.dumps(self.subscriptions_data))
    needs_save = False

    for guild_id_str, subs_list in snapshot.items():
        if not isinstance(subs_list, list):
            continue
        for sub in subs_list:
            if not isinstance(sub, dict):
                continue

            sub_id = sub.get("subscription_id")
            tags = sub.get("tags")
            webhook_url = sub.get("webhook_url")

            if not tags or not webhook_url:
                log.warning(
                    f"Subscription for {self.cog_name} (guild {guild_id_str}, sub {sub_id}) missing tags/webhook."
                )
                continue

            # A missing or null value means "nothing seen yet" (0).
            try:
                last_known_post_id = int(sub.get("last_known_post_id") or 0)
            except (TypeError, ValueError):
                log.warning(
                    f"Subscription for {self.cog_name} (guild {guild_id_str}, sub {sub_id}) has a non-numeric last_known_post_id; skipping."
                )
                continue

            fetched_posts_response = await self._fetch_posts_logic(
                "internal_task_call", tags, pid_override=0, limit_override=100
            )

            # A string result is an error description from the fetch helper.
            if isinstance(fetched_posts_response, str):
                log.error(
                    f"Error fetching {self.cog_name} posts for sub {sub_id} (tags: {tags}): {fetched_posts_response}"
                )
                continue
            if not fetched_posts_response:
                continue

            # Collect (numeric id, post) pairs so the ID is parsed only once.
            new_posts = []
            for post_data in fetched_posts_response:
                post_id = None
                if (
                    isinstance(post_data, dict)
                    and "id" in post_data
                    and "file_url" in post_data
                ):
                    try:
                        post_id = int(post_data["id"])
                    except (TypeError, ValueError):
                        post_id = None
                if post_id is None:
                    log.warning(
                        f"Malformed {self.cog_name} post data for tags {tags}: {post_data}"
                    )
                    continue
                if post_id > last_known_post_id:
                    new_posts.append((post_id, post_data))

            if not new_posts:
                continue

            # Deliver in ascending ID order so the stored marker only moves forward.
            new_posts.sort(key=lambda pair: pair[0])
            log.info(
                f"Found {len(new_posts)} new {self.cog_name} post(s) for tags '{tags}', sub {sub_id}"
            )

            # Forum subscriptions store the thread under "target_post_id",
            # text-channel ones under "thread_id".
            target_thread_id: typing.Optional[str] = sub.get(
                "target_post_id"
            ) or sub.get("thread_id")

            for post_id, post_data in new_posts:
                view = self._build_new_post_view(tags, post_data["file_url"], post_id)
                send_success = await self._send_via_webhook(
                    webhook_url,
                    content="",
                    thread_id=target_thread_id,
                    view=view,
                )
                if not send_success:
                    log.error(
                        f"Failed to send new {self.cog_name} post via webhook for sub {sub_id}."
                    )
                    # Stop here; the remaining posts are retried next run.
                    break

                # Record progress on the live entry, not the snapshot. The
                # subscription may have been removed meanwhile, in which
                # case nothing is updated.
                for live_sub in self.subscriptions_data.get(guild_id_str) or []:
                    if live_sub.get("subscription_id") == sub_id:
                        live_sub["last_known_post_id"] = post_id
                        needs_save = True
                        break

                # Pause between webhook posts.
                await asyncio.sleep(1)

    if needs_save:
        self._save_subscriptions()
    log.debug(f"Finished {self.cog_name} new post check.")
|
|
|
|
@check_new_posts.before_loop
async def before_check_new_posts(self):
    """Prepare for the polling loop: wait for the bot, then ensure a session.

    Registered as the ``before_loop`` hook of ``check_new_posts``. Blocks
    until ``self.bot.wait_until_ready()`` returns, then creates a fresh
    ``aiohttp.ClientSession`` on ``self.session`` if there is none or the
    existing one is closed.
    """
    # Announce the wait before it starts, so the log reflects the real order
    # of events.
    log.info(
        f"{self.cog_name}Cog: `check_new_posts` loop is waiting for bot readiness..."
    )
    await self.bot.wait_until_ready()

    if self.session is None or self.session.closed:
        self.session = aiohttp.ClientSession()
        log.info(
            f"aiohttp ClientSession created before {self.cog_name} check_new_posts loop."
        )
|
|
|
|
async def _create_new_subscription(
    self,
    guild_id: int,
    user_id: int,
    tags: str,
    target_channel: typing.Union[discord.TextChannel, discord.ForumChannel],
    requested_thread_target: typing.Optional[str] = None,
    requested_post_title: typing.Optional[str] = None,
    is_request_approval: bool = False,
    requester_mention: typing.Optional[str] = None,
) -> str:
    """Create, register and persist a tag-watch subscription.

    Steps, in order:
      1. Resolve the delivery target. For a text channel with
         ``requested_thread_target``, the thread is looked up by numeric ID
         first and then by case-insensitive name among the channel's
         threads. For a forum channel, a new forum post is created.
      2. Obtain a webhook URL via ``self._get_or_create_webhook``.
      3. Fetch the newest matching post to seed ``last_known_post_id``
         (stays 0 if the fetch fails or the post data is malformed).
      4. Reject the request if an entry with the same tags and the same
         location already exists for the guild.
      5. Append the entry to ``self.subscriptions_data`` and save.

    Args:
        guild_id: Guild the subscription belongs to.
        user_id: User recorded as the one who added it.
        tags: Tag query to watch; stored stripped of surrounding whitespace.
        target_channel: Text or forum channel to deliver into.
        requested_thread_target: Thread ID or name (text channels only).
        requested_post_title: Title for the forum post; a default built from
            the cog name and the first 50 characters of ``tags`` is used
            when omitted.
        is_request_approval: True when created by approving a user request.
        requester_mention: Mention added to the forum post's first message
            when ``is_request_approval`` is set.

    Returns:
        A user-facing status message. Success messages start with "✅";
        failures start with "❌", "⚠️" or "Error:".
    """
    actual_target_thread_id: typing.Optional[str] = None
    actual_target_thread_mention: str = ""
    actual_post_title = (
        requested_post_title or f"{self.cog_name} Watch: {tags[:50]}"
    )

    if isinstance(target_channel, discord.TextChannel):
        if requested_thread_target:
            found_thread: typing.Optional[discord.Thread] = None
            try:
                if not target_channel.guild:
                    return "Error: Guild context not found."
                thread_as_obj = await target_channel.guild.fetch_channel(
                    int(requested_thread_target)
                )
                # Only accept a thread that actually lives in this channel.
                if (
                    isinstance(thread_as_obj, discord.Thread)
                    and thread_as_obj.parent_id == target_channel.id
                ):
                    found_thread = thread_as_obj
            except (ValueError, discord.NotFound, discord.Forbidden):
                # Not a numeric ID, or not fetchable: fall through to the
                # name-based lookup below.
                pass
            except Exception as e:
                log.error(
                    f"Error fetching thread by ID '{requested_thread_target}' for {self.cog_name}: {e}"
                )

            if not found_thread and hasattr(target_channel, "threads"):
                wanted_name = requested_thread_target.lower()
                found_thread = next(
                    (
                        t
                        for t in target_channel.threads
                        if t.name.lower() == wanted_name
                    ),
                    None,
                )

            if found_thread:
                actual_target_thread_id = str(found_thread.id)
                actual_target_thread_mention = found_thread.mention
            else:
                return f"❌ Could not find thread `{requested_thread_target}` in {target_channel.mention} for {self.cog_name}."

    elif isinstance(target_channel, discord.ForumChannel):
        forum_post_initial_message = f"✨ **New {self.cog_name} Watch Initialized!** ✨\nNow monitoring tags: `{tags}`"
        if is_request_approval and requester_mention:
            forum_post_initial_message += f"\n_Requested by: {requester_mention}_"

        # NOTE(review): the forum post is created before the webhook is
        # obtained, so a webhook failure below leaves this post behind.
        try:
            if not target_channel.permissions_for(
                target_channel.guild.me
            ).create_public_threads:
                return f"❌ I lack permission to create posts in forum {target_channel.mention} for {self.cog_name}."

            new_forum_post = await target_channel.create_thread(
                name=actual_post_title,
                content=forum_post_initial_message,
                reason=f"{self.cog_name}Watch: {tags}",
            )
            actual_target_thread_id = str(new_forum_post.thread.id)
            actual_target_thread_mention = new_forum_post.thread.mention
            log.info(
                f"Created {self.cog_name} forum post {new_forum_post.thread.id} for tags '{tags}' in forum {target_channel.id}"
            )
        except discord.HTTPException as e:
            return f"❌ Failed to create {self.cog_name} forum post in {target_channel.mention}: {e}"
        except Exception:
            log.exception(
                f"Unexpected error creating {self.cog_name} forum post for tags '{tags}'"
            )
            return f"❌ Unexpected error creating {self.cog_name} forum post."

    webhook_url = await self._get_or_create_webhook(target_channel)
    if not webhook_url:
        return f"❌ Failed to get/create webhook for {target_channel.mention} ({self.cog_name}). Check permissions."

    # Seed the marker with the newest existing post so only later posts are
    # announced. If this stays 0, the polling task treats everything it
    # fetches on its next run as new.
    initial_posts = await self._fetch_posts_logic(
        "internal_initial_fetch", tags, pid_override=0, limit_override=1
    )
    last_known_post_id = 0
    if isinstance(initial_posts, list) and initial_posts:
        newest_post = initial_posts[0]
        parsed_id: typing.Optional[int] = None
        if isinstance(newest_post, dict) and "id" in newest_post:
            try:
                parsed_id = int(newest_post["id"])
            except (TypeError, ValueError):
                # A non-numeric ID is malformed data, not a crash.
                parsed_id = None
        if parsed_id is None:
            log.warning(
                f"Malformed {self.cog_name} post data for initial fetch (tags: '{tags}'): {newest_post}"
            )
        else:
            last_known_post_id = parsed_id
    elif isinstance(initial_posts, str):
        log.error(
            f"{self.cog_name} API error on initial fetch (tags: '{tags}'): {initial_posts}"
        )

    guild_id_str = str(guild_id)
    subscription_id = str(uuid.uuid4())

    new_sub_data = {
        "subscription_id": subscription_id,
        "tags": tags.strip(),
        "webhook_url": webhook_url,
        "last_known_post_id": last_known_post_id,
        "added_by_user_id": str(user_id),
        "added_timestamp": discord.utils.utcnow().isoformat(),
    }
    is_forum_target = isinstance(target_channel, discord.ForumChannel)
    if is_forum_target:
        new_sub_data.update(
            {
                "forum_channel_id": str(target_channel.id),
                "target_post_id": actual_target_thread_id,
                "post_title": actual_post_title,
            }
        )
    else:
        new_sub_data.update(
            {
                "channel_id": str(target_channel.id),
                "thread_id": actual_target_thread_id,
            }
        )

    guild_subs = self.subscriptions_data.setdefault(guild_id_str, [])

    # Duplicate check. Kept directly before the append, with no await in
    # between, so two concurrent commands cannot both pass it.
    for existing_sub in guild_subs:
        is_dup_tags = existing_sub.get("tags") == new_sub_data["tags"]
        is_dup_forum = (
            is_forum_target
            and existing_sub.get("forum_channel_id")
            == new_sub_data.get("forum_channel_id")
            and existing_sub.get("target_post_id")
            == new_sub_data.get("target_post_id")
        )
        is_dup_text_chan = (
            isinstance(target_channel, discord.TextChannel)
            and existing_sub.get("channel_id") == new_sub_data.get("channel_id")
            and existing_sub.get("thread_id") == new_sub_data.get("thread_id")
        )
        if is_dup_tags and (is_dup_forum or is_dup_text_chan):
            return f"⚠️ A {self.cog_name} subscription for these tags in this location already exists (ID: `{existing_sub.get('subscription_id')}`)."

    guild_subs.append(new_sub_data)
    self._save_subscriptions()

    target_desc = f"in {target_channel.mention}"
    if is_forum_target and actual_target_thread_mention:
        target_desc = f"in forum post {actual_target_thread_mention} within {target_channel.mention}"
    elif actual_target_thread_mention:
        target_desc += f" (thread: {actual_target_thread_mention})"

    log.info(
        f"{self.cog_name} subscription added: Guild {guild_id_str}, Tags '{tags}', Target {target_desc}, SubID {subscription_id}"
    )
    return (
        f"✅ Watching {self.cog_name} for new posts with tags `{tags}` {target_desc}.\n"
        f"Initial latest post ID: {last_known_post_id}. Sub ID: `{subscription_id}`."
    )
|
|
|
|
# --- Watch command group logic methods ---
|
|
# These will be called by the specific cog's command handlers
|
|
|
|
async def _watch_add_logic(
    self,
    interaction: discord.Interaction,
    tags: str,
    channel: typing.Union[discord.TextChannel, discord.ForumChannel],
    thread_target: typing.Optional[str],
    post_title: typing.Optional[str],
):
    """Validate a "watch add" request and create the subscription.

    Rejects the request when guild/user context is missing, when
    ``post_title`` is given for a text channel, or when ``thread_target``
    is given for a forum channel. Otherwise the work is handed to
    ``self._create_new_subscription`` and its status message is relayed.

    All replies are ephemeral follow-ups, so the calling command has
    presumably deferred the interaction already (confirm at the call site).
    """
    problem: typing.Optional[str] = None
    if not (interaction.guild_id and interaction.user):
        # Permission checks should already prevent this; kept as a safeguard.
        problem = "Command error: Missing guild or user context."
    elif post_title and isinstance(channel, discord.TextChannel):
        problem = "`post_title` is only for Forum Channels."
    elif thread_target and isinstance(channel, discord.ForumChannel):
        problem = "`thread_target` is only for Text Channels."

    if problem is not None:
        await interaction.followup.send(problem, ephemeral=True)
        return

    outcome = await self._create_new_subscription(
        guild_id=interaction.guild_id,
        user_id=interaction.user.id,
        tags=tags,
        target_channel=channel,
        requested_thread_target=thread_target,
        requested_post_title=post_title,
    )
    await interaction.followup.send(outcome, ephemeral=True)
|
|
|
|
async def _watch_request_logic(
    self,
    interaction: discord.Interaction,
    tags: str,
    forum_channel: discord.ForumChannel,
    post_title: typing.Optional[str],
):
    """Record a user's request for a new forum tag watch, pending approval.

    Builds a request record with status ``"pending"``, appends it to the
    guild's list in ``self.pending_requests_data``, saves the pending
    requests and confirms to the requester with the generated request ID.

    Args:
        interaction: Interaction of the requesting user; answered via
            follow-up messages.
        tags: Tag query to watch; stored stripped of surrounding whitespace.
        forum_channel: Forum in which the watch post would be created.
        post_title: Optional title; a default built from the cog name and
            the first 50 characters of ``tags`` is used when omitted.
    """
    if not (interaction.guild_id and interaction.user):
        await interaction.followup.send(
            "Command error: Missing guild or user context.", ephemeral=True
        )
        return

    guild_key = str(interaction.guild_id)
    request_id = str(uuid.uuid4())
    title = post_title or f"{self.cog_name} Watch: {tags[:50]}"

    # The moderator fields stay empty until the request is approved or rejected.
    record = {
        "request_id": request_id,
        "requester_id": str(interaction.user.id),
        "requester_name": str(interaction.user),
        "requested_tags": tags.strip(),
        "target_forum_channel_id": str(forum_channel.id),
        "requested_post_title": title,
        "status": "pending",
        "request_timestamp": discord.utils.utcnow().isoformat(),
        "moderator_id": None,
        "moderation_timestamp": None,
    }

    self.pending_requests_data.setdefault(guild_key, []).append(record)
    self._save_pending_requests()

    log.info(
        f"New {self.cog_name} watch request: Guild {guild_key}, Requester {interaction.user}, Tags '{tags}', ReqID {request_id}"
    )

    confirmation = (
        f'✅ Your {self.cog_name} watch request for tags `{tags}` in forum {forum_channel.mention} (title: "{title}") submitted.\n'
        f"Request ID: `{request_id}`. Awaiting moderator approval."
    )
    await interaction.followup.send(confirmation)
|
|
|
|
async def _watch_pending_list_logic(self, interaction: discord.Interaction):
    """Show the guild's pending watch requests in an ephemeral embed.

    Reads ``self.pending_requests_data`` for the interaction's guild, keeps
    the entries whose status is ``"pending"`` and lists ID, requester,
    tags, target forum, proposed title and request time for each. The
    joined description is cut to 4096 characters. A request whose stored
    timestamp is missing or cannot be parsed is still listed, with
    "Unknown time" in place of the time.
    """
    if (
        not interaction.guild_id or not interaction.guild
    ):  # The guild object is needed for its name in the embed title.
        await interaction.response.send_message(
            "Command error: Missing guild context.", ephemeral=True
        )
        return

    guild_id_str = str(interaction.guild_id)
    pending_reqs = [
        req
        for req in self.pending_requests_data.get(guild_id_str, [])
        if req.get("status") == "pending"
    ]

    if not pending_reqs:
        await interaction.response.send_message(
            f"No pending {self.cog_name} watch requests.", ephemeral=True
        )
        return

    desc_parts = []
    for req in pending_reqs:
        forum_mention = f"<#{req.get('target_forum_channel_id', 'Unknown')}>"

        timestamp_str = req.get("request_timestamp")
        time_fmt = "Unknown time"
        if timestamp_str:
            try:
                time_fmt = discord.utils.format_dt(
                    datetime.fromisoformat(timestamp_str), style="R"
                )
            except (TypeError, ValueError):
                # One corrupt record must not break the whole listing.
                log.warning(
                    f"Unparseable request_timestamp {timestamp_str!r} on {self.cog_name} request {req.get('request_id')}"
                )

        desc_parts.append(
            f"**ID:** `{req.get('request_id')}`\n"
            f" **Requester:** {req.get('requester_name', 'Unknown')} (`{req.get('requester_id')}`)\n"
            f" **Tags:** `{req.get('requested_tags')}`\n"
            f" **Target Forum:** {forum_mention}\n"
            f" **Proposed Title:** \"{req.get('requested_post_title')}\"\n"
            f" **Requested:** {time_fmt}\n---"
        )

    embed = discord.Embed(
        title=f"Pending {self.cog_name} Watch Requests for {interaction.guild.name}",
        color=discord.Color.orange(),
    )
    embed.description = "\n".join(desc_parts)[:4096]  # Limit description length
    await interaction.response.send_message(embed=embed, ephemeral=True)
|
|
|
|
async def _watch_approve_request_logic(
    self, interaction: discord.Interaction, request_id: str
):
    """Approve a pending watch request and create its subscription.

    Looks up the pending request by ID in the interaction's guild, fetches
    the stored target forum channel and calls
    ``self._create_new_subscription`` on behalf of the requester. If the
    returned message starts with "✅", the request is marked ``"approved"``
    with the moderator's ID and a timestamp, the pending requests are
    saved, the moderator is told, and a direct message to the requester is
    attempted (a failure there is only logged). Otherwise the request is
    left pending and the creation message is reported to the moderator.

    All replies are ephemeral follow-ups.
    """
    # The guild object is needed both for fetch_channel and for its name.
    if not (interaction.guild_id and interaction.user and interaction.guild):
        await interaction.followup.send(
            "Command error: Missing context.", ephemeral=True
        )
        return

    guild_key = str(interaction.guild_id)
    located = next(
        (
            (position, candidate)
            for position, candidate in enumerate(
                self.pending_requests_data.get(guild_key, [])
            )
            if candidate.get("request_id") == request_id
            and candidate.get("status") == "pending"
        ),
        None,
    )
    if located is None:
        await interaction.followup.send(
            f"❌ Pending {self.cog_name} request ID `{request_id}` not found.",
            ephemeral=True,
        )
        return
    position, request = located

    # Any failure while resolving the channel, or a channel of the wrong
    # type, is reported the same way.
    try:
        fetched_channel = await interaction.guild.fetch_channel(
            int(request["target_forum_channel_id"])
        )
    except Exception:
        fetched_channel = None
    if not isinstance(fetched_channel, discord.ForumChannel):
        await interaction.followup.send(
            f"❌ Target forum channel for {self.cog_name} request `{request_id}` not found/invalid.",
            ephemeral=True,
        )
        return

    creation_response = await self._create_new_subscription(
        guild_id=interaction.guild_id,
        user_id=int(request["requester_id"]),
        tags=request["requested_tags"],
        target_channel=fetched_channel,
        requested_post_title=request["requested_post_title"],
        is_request_approval=True,
        requester_mention=f"<@{request['requester_id']}>",
    )

    if not creation_response.startswith("✅"):
        await interaction.followup.send(
            f"❌ Failed to approve {self.cog_name} request `{request_id}`. Sub creation failed: {creation_response}",
            ephemeral=True,
        )
        return

    request.update(
        status="approved",
        moderator_id=str(interaction.user.id),
        moderation_timestamp=discord.utils.utcnow().isoformat(),
    )
    self.pending_requests_data[guild_key][position] = request
    self._save_pending_requests()
    await interaction.followup.send(
        f"✅ {self.cog_name} Request ID `{request_id}` approved. {creation_response}",
        ephemeral=True,
    )

    # Best-effort notification; the approval stands even if the DM fails.
    try:
        requester = await self.bot.fetch_user(int(request["requester_id"]))
        await requester.send(
            f"🎉 Your {self.cog_name} watch request (`{request_id}`) for tags `{request['requested_tags']}` in server `{interaction.guild.name}` was **approved** by {interaction.user.mention}!\nDetails: {creation_response}"
        )
    except Exception as e:
        log.error(
            f"Failed to notify {self.cog_name} requester {request['requester_id']} of approval: {e}"
        )
|
|
|
|
async def _watch_reject_request_logic(
    self,
    interaction: discord.Interaction,
    request_id: str,
    reason: typing.Optional[str],
):
    """Reject a pending watch request and notify the requester.

    Finds the pending request by ID in the interaction's guild, marks it
    ``"rejected"`` with the moderator's ID, a timestamp and the optional
    reason, saves the pending requests and confirms to the moderator. A
    direct message to the requester is then attempted; a failure there is
    only logged.

    All replies are ephemeral follow-ups.
    """
    # The guild object is needed for its name in the DM text.
    if not (interaction.guild_id and interaction.user and interaction.guild):
        await interaction.followup.send(
            "Command error: Missing context.", ephemeral=True
        )
        return

    guild_key = str(interaction.guild_id)
    for position, request in enumerate(
        self.pending_requests_data.get(guild_key, [])
    ):
        if (
            request.get("request_id") == request_id
            and request.get("status") == "pending"
        ):
            break
    else:
        # Loop finished without a match (or there were no requests at all).
        await interaction.followup.send(
            f"❌ Pending {self.cog_name} request ID `{request_id}` not found.",
            ephemeral=True,
        )
        return

    request.update(
        status="rejected",
        moderator_id=str(interaction.user.id),
        moderation_timestamp=discord.utils.utcnow().isoformat(),
        rejection_reason=reason,
    )
    self.pending_requests_data[guild_key][position] = request
    self._save_pending_requests()
    await interaction.followup.send(
        f"🗑️ {self.cog_name} Request ID `{request_id}` rejected.", ephemeral=True
    )

    # Best-effort notification; the rejection stands even if the DM fails.
    try:
        requester = await self.bot.fetch_user(int(request["requester_id"]))
        notice = f"😥 Your {self.cog_name} watch request (`{request_id}`) for tags `{request['requested_tags']}` in server `{interaction.guild.name}` was **rejected** by {interaction.user.mention}."
        if reason:
            notice = notice + f"\nReason: {reason}"
        await requester.send(notice)
    except Exception as e:
        log.error(
            f"Failed to notify {self.cog_name} requester {request['requester_id']} of rejection: {e}"
        )
|
|
|
|
async def _watch_list_logic(self, interaction: discord.Interaction):
    """Show the guild's active tag watches in an ephemeral embed.

    Each subscription in ``self.subscriptions_data`` for the interaction's
    guild is listed with its ID, tags, delivery target and last sent post
    ID. The joined description is cut to 4096 characters.
    """
    if not (interaction.guild_id and interaction.guild):
        await interaction.response.send_message(
            "Command error: Missing guild context.", ephemeral=True
        )
        return

    subscriptions = self.subscriptions_data.get(str(interaction.guild_id), [])
    if not subscriptions:
        await interaction.response.send_message(
            f"No active {self.cog_name} tag watches for this server.",
            ephemeral=True,
        )
        return

    def describe_target(entry) -> str:
        # The target is rendered from stored IDs as channel mentions;
        # nothing is fetched from Discord here.
        forum_id = entry.get("forum_channel_id")
        post_id = entry.get("target_post_id")
        if forum_id and post_id:
            return f"Forum Post <#{post_id}> in Forum <#{forum_id}>"
        channel_id = entry.get("channel_id")
        if not channel_id:
            return "Unknown Target"
        thread_id = entry.get("thread_id")
        thread_note = f" (Thread <#{thread_id}>)" if thread_id else ""
        return f"<#{channel_id}>{thread_note}"

    lines = [
        f"**ID:** `{entry.get('subscription_id')}`\n"
        f" **Tags:** `{entry.get('tags')}`\n"
        f" **Target:** {describe_target(entry)}\n"
        f" **Last Sent ID:** `{entry.get('last_known_post_id', 'N/A')}`\n---"
        for entry in subscriptions
    ]

    embed = discord.Embed(
        title=f"Active {self.cog_name} Tag Watches for {interaction.guild.name}",
        color=discord.Color.blue(),
    )
    embed.description = "\n".join(lines)[:4096]
    await interaction.response.send_message(embed=embed, ephemeral=True)
|
|
|
|
async def _watch_remove_logic(
    self, interaction: discord.Interaction, subscription_id: str
):
    """Delete a subscription by ID from the interaction's guild.

    Every entry whose ``subscription_id`` matches is dropped. If nothing
    matches, the user is told and nothing changes. If the guild is left
    with no subscriptions, its key is removed from
    ``self.subscriptions_data``. The change is saved and confirmed with an
    ephemeral message.
    """
    if not interaction.guild_id:
        await interaction.response.send_message(
            "Command error: Missing guild context.", ephemeral=True
        )
        return

    guild_key = str(interaction.guild_id)
    existing = self.subscriptions_data.get(guild_key, [])
    doomed = [
        entry for entry in existing if entry.get("subscription_id") == subscription_id
    ]

    # One log line per matching entry.
    for _ in doomed:
        log.info(
            f"Removing {self.cog_name} watch: Guild {guild_key}, Sub ID {subscription_id}"
        )

    if not doomed:
        await interaction.response.send_message(
            f"❌ {self.cog_name} subscription ID `{subscription_id}` not found.",
            ephemeral=True,
        )
        return

    # If several entries share the ID, the last one is the one described.
    removed_info = f"tags `{doomed[-1].get('tags')}`"
    survivors = [
        entry for entry in existing if entry.get("subscription_id") != subscription_id
    ]

    if survivors:
        self.subscriptions_data[guild_key] = survivors
    else:
        del self.subscriptions_data[guild_key]
    self._save_subscriptions()

    await interaction.response.send_message(
        f"✅ Removed {self.cog_name} watch for {removed_info} (ID: `{subscription_id}`).",
        ephemeral=True,
    )
|
|
|
|
async def _watch_test_message_logic(
    self, interaction: discord.Interaction, subscription_id: str
):
    """Send a test new-post message for a given subscription.

    Finds the subscription in the interaction's guild, fetches one post
    for its tags through ``self._fetch_posts_logic``, builds the usual
    new-post view and sends it through the subscription's webhook (into
    the stored thread or forum post, if any). The result is reported to
    the user ephemerally.

    The interaction is deferred before the network calls, unless it has
    already been acknowledged, so the reply does not depend on the fetch
    and webhook send finishing inside Discord's initial-response deadline.
    Replies use a follow-up whenever the interaction is already
    acknowledged.
    """

    async def reply(message: str) -> None:
        # Only one initial response is allowed; use a follow-up after that.
        if interaction.response.is_done():
            await interaction.followup.send(message, ephemeral=True)
        else:
            await interaction.response.send_message(message, ephemeral=True)

    if not interaction.guild_id:
        await reply("Command error: Missing guild context.")
        return

    guild_id_str = str(interaction.guild_id)
    target_sub = next(
        (
            sub_entry
            for sub_entry in self.subscriptions_data.get(guild_id_str, [])
            if sub_entry.get("subscription_id") == subscription_id
        ),
        None,
    )
    if not target_sub:
        await reply(
            f"❌ {self.cog_name} subscription ID `{subscription_id}` not found."
        )
        return

    tags = target_sub.get("tags")
    webhook_url = target_sub.get("webhook_url")
    if not tags or not webhook_url:
        await reply(
            f"❌ Subscription `{subscription_id}` is missing tags or webhook information."
        )
        return

    # Network work follows: acknowledge now so the interaction stays valid.
    if not interaction.response.is_done():
        await interaction.response.defer(ephemeral=True)

    fetched = await self._fetch_posts_logic(
        "internal_test_msg", tags, pid_override=0, limit_override=1
    )
    file_url = None
    post_id = None
    if isinstance(fetched, list) and fetched:
        first = fetched[0]
        if isinstance(first, dict):
            file_url = first.get("file_url")
            if "id" in first:
                try:
                    post_id = int(first["id"])
                except (TypeError, ValueError):
                    # A non-numeric ID counts as a failed fetch below.
                    post_id = None

    if not file_url or post_id is None:
        await reply(f"❌ Failed to fetch a post for tags `{tags}`.")
        return

    view = self._build_new_post_view(tags, file_url, post_id)
    # Forum subscriptions store the thread under "target_post_id",
    # text-channel ones under "thread_id".
    thread_id = target_sub.get("target_post_id") or target_sub.get("thread_id")
    send_success = await self._send_via_webhook(
        webhook_url, content="", thread_id=thread_id, view=view
    )
    if send_success:
        await reply(
            f"✅ Test {self.cog_name} post sent for subscription `{subscription_id}`."
        )
    else:
        await reply(
            f"❌ Failed to send test {self.cog_name} post for subscription `{subscription_id}`."
        )
|