323 lines
17 KiB
Python
323 lines
17 KiB
Python
import logging
import re # For _get_response_text, if copied
import typing # Need this for Optional

import discord
from discord import app_commands
from discord.ext import commands, tasks

# Google Generative AI Imports (using Vertex AI backend)
from google import genai
from google.api_core import exceptions as google_exceptions
from google.genai import errors as genai_errors
from google.genai import types

# Import project configuration for Vertex AI
from gurt.config import PROJECT_ID, LOCATION

from .gelbooru_watcher_base_cog import GelbooruWatcherBaseCog
|
|
|
|
# Module-level logger shared by every function and command in this cog.
log = logging.getLogger(__name__)

# Safety settings attached to the GenerateContentConfig that
# Rule34Cog._transform_tags_ai sends to Gemini (google.genai ``types``).
# Every listed harm category uses the "BLOCK_NONE" threshold.
# NOTE(review): an older comment here said thresholds were "OFF"; the API also
# defines a separate "OFF" threshold — confirm which one was actually intended.
STANDARD_SAFETY_SETTINGS = [
    types.SafetySetting(category=harm_category, threshold="BLOCK_NONE")
    for harm_category in (
        types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        types.HarmCategory.HARM_CATEGORY_HARASSMENT,
    )
]
|
|
|
|
# --- Helper Function to Safely Extract Text (copied from teto_cog.py) ---
|
|
def _get_response_text(response: typing.Optional[types.GenerateContentResponse]) -> typing.Optional[str]:
|
|
"""
|
|
Safely extracts the text content from the first text part of a GenerateContentResponse.
|
|
Handles potential errors and lack of text parts gracefully.
|
|
"""
|
|
if not response:
|
|
# log.debug("[_get_response_text] Received None response object.")
|
|
return None
|
|
|
|
if hasattr(response, 'text') and response.text:
|
|
# log.debug("[_get_response_text] Found text directly in response.text attribute.")
|
|
return response.text
|
|
|
|
if not response.candidates:
|
|
# log.debug(f"[_get_response_text] Response object has no candidates. Response: {response}")
|
|
return None
|
|
|
|
try:
|
|
candidate = response.candidates[0]
|
|
if not hasattr(candidate, 'content') or not candidate.content:
|
|
# log.debug(f"[_get_response_text] Candidate 0 has no 'content'. Candidate: {candidate}")
|
|
return None
|
|
if not hasattr(candidate.content, 'parts') or not candidate.content.parts:
|
|
# log.debug(f"[_get_response_text] Candidate 0 content has no 'parts' or parts list is empty. types.Content: {candidate.content}")
|
|
return None
|
|
|
|
for i, part in enumerate(candidate.content.parts):
|
|
if hasattr(part, 'text') and part.text is not None:
|
|
if isinstance(part.text, str) and part.text.strip():
|
|
# log.debug(f"[_get_response_text] Found non-empty text in part {i}.")
|
|
return part.text
|
|
else:
|
|
# log.debug(f"[_get_response_text] types.Part {i} has 'text' attribute, but it's empty or not a string: {part.text!r}")
|
|
pass # Continue searching
|
|
# log.debug(f"[_get_response_text] No usable text part found in candidate 0 after iterating through all parts.")
|
|
return None
|
|
|
|
except (AttributeError, IndexError, TypeError) as e:
|
|
# log.warning(f"[_get_response_text] Error accessing response structure: {type(e).__name__}: {e}")
|
|
# log.debug(f"Problematic response object: {response}")
|
|
return None
|
|
except Exception as e:
|
|
# log.error(f"[_get_response_text] Unexpected error extracting text: {e}")
|
|
# log.debug(f"Response object during error: {response}")
|
|
return None
|
|
|
|
class Rule34Cog(GelbooruWatcherBaseCog):  # Removed name="Rule34"
    """Rule34 search and tag-watch commands built on ``GelbooruWatcherBaseCog``.

    Search commands (prefix ``rule34``, slash ``/rule34`` and ``/rule34browse``)
    first rewrite the user's free-form input into rule34.xxx tag syntax with a
    Gemini model (Vertex AI backend), then append ``-ai_generated`` and hand the
    query to the base class's fetch/browse logic. Watch commands append
    ``-ai_generated`` to the user's tags as given (no AI rewrite) and delegate
    to the base class's ``_watch_*_logic`` methods.
    """

    # Slash-command group for the watcher subcommands defined below.
    r34watch = app_commands.Group(name="r34watch", description="Manage Rule34 tag watchers for new posts.")

    # Exclusion tag appended to every search and watch query.
    _EXCLUDE_AI_TAG = "-ai_generated"
    # Maximum length of user- or model-supplied text echoed back in a reply,
    # keeping messages well under Discord's 2000-character limit.
    _MAX_ECHOED_TAGS_LEN = 500

    def __init__(self, bot: commands.Bot):
        super().__init__(
            bot=bot,
            cog_name="Rule34",
            api_base_url="https://api.rule34.xxx/index.php",
            default_tags="kasane_teto breast_milk",  # Example default, will be overridden if tags are required
            is_nsfw_site=True,
            command_group_name="r34watch",  # For potential use in base class messages
            main_command_name="rule34",  # For potential use in base class messages
        )
        # The __init__ in base class handles session creation and task starting.

        # Gemini client used by _transform_tags_ai. It stays None when the
        # config is missing or construction fails; the transformer then
        # returns None instead of raising.
        self.genai_client = None
        if PROJECT_ID and LOCATION:
            try:
                self.genai_client = genai.Client(
                    vertexai=True,
                    project=PROJECT_ID,
                    location=LOCATION,
                )
                log.info(f"Rule34Cog: Google GenAI Client initialized for Vertex AI project '{PROJECT_ID}' in location '{LOCATION}'.")
            except Exception as e:
                self.genai_client = None
                log.exception(f"Rule34Cog: Error initializing Google GenAI Client for Vertex AI: {e}")
        else:
            log.warning("Rule34Cog: PROJECT_ID or LOCATION not found in config. Google GenAI Client not initialized.")

        self.tag_transformer_model = "gemini-2.0-flash-lite-001"  # Hardcoded as per request

    # --- Helpers ---

    @classmethod
    def _append_exclusions(cls, tags: str) -> str:
        """Return *tags* with the ``-ai_generated`` exclusion tag appended."""
        return f"{tags} {cls._EXCLUDE_AI_TAG}"

    @classmethod
    def _display_tags(cls, text: str) -> str:
        """Make user/model-supplied *text* safe to echo inside an inline-code span.

        Backticks are replaced so the text cannot close the code span early,
        mentions are neutralised with ``discord.utils.escape_mentions`` so
        echoing input such as ``@everyone`` cannot ping anyone, and the result
        is truncated to ``_MAX_ECHOED_TAGS_LEN`` characters.
        """
        cleaned = discord.utils.escape_mentions(text.replace("`", "'"))
        if len(cleaned) > cls._MAX_ECHOED_TAGS_LEN:
            cleaned = cleaned[: cls._MAX_ECHOED_TAGS_LEN - 1] + "…"
        return cleaned

    @staticmethod
    def _normalize_ai_tags(text: str) -> str:
        """Turn raw model output into one line of single-space-separated tags.

        Surrounding whitespace and backticks are stripped (the model may wrap
        its answer in a code span or fence), and internal whitespace,
        including newlines, is collapsed to single spaces.
        """
        return " ".join(text.strip().strip("`").split())

    async def _transform_tags_ai(self, user_tags: str) -> typing.Optional[str]:
        """Transforms user-provided tags into rule34-style tags using AI.

        Args:
            user_tags: Free-form search text from the user, e.g. "hatsune miku blue hair".

        Returns:
            The model's answer as one line of space-separated tags on success;
            "" when *user_tags* is empty or whitespace-only; None when the client
            is not initialized, the API call fails, or the model returns no
            usable text.
        """
        if not self.genai_client:
            log.warning("Rule34Cog: GenAI client not initialized, cannot transform tags.")
            return None
        if not user_tags or not user_tags.strip():
            return ""  # Nothing to transform

        system_prompt_text = (
            "You are an AI assistant specialized in transforming user-provided text into tags suitable for rule34.xxx. "
            "Your task is to convert natural language input into a space-separated list of tags, where multi-word concepts are joined by underscores. "
            "For example, if the user provides 'hatsune miku blue hair', you should output 'hatsune_miku blue_hair'. "
            "Only output the transformed tags. Do not add any other text, explanations, or greetings."
        )

        prompt_parts = [
            types.Part(text=system_prompt_text),
            types.Part(text=f"User input: \"{user_tags}\""),
            types.Part(text="Transformed tags:"),
        ]
        contents_for_api = [types.Content(role="user", parts=prompt_parts)]

        generation_config = types.GenerateContentConfig(
            temperature=0.2,  # Low temperature for more deterministic output
            max_output_tokens=256,
            safety_settings=STANDARD_SAFETY_SETTINGS,
        )

        try:
            log.debug(f"Rule34Cog: Sending to Vertex AI for tag transformation. Model: {self.tag_transformer_model}, Input: '{user_tags}'")
            response = await self.genai_client.aio.models.generate_content(
                model=f"publishers/google/models/{self.tag_transformer_model}",
                contents=contents_for_api,
                config=generation_config,
            )
            # Response parsing stays inside the try so an unexpected response
            # shape is reported as a failed transformation, not a command crash.
            raw_text = _get_response_text(response)
            transformed_tags = self._normalize_ai_tags(raw_text) if raw_text else ""
            if not transformed_tags:
                log.warning(f"Rule34Cog: AI tag transformation returned empty for input: '{user_tags}'. Response: {response}")
                return None
            log.info(f"Rule34Cog: Tags transformed: '{user_tags}' -> '{transformed_tags}'")
            return transformed_tags
        # google-genai raises its own errors.APIError hierarchy; the api_core
        # exception is kept in the tuple as well.
        except (genai_errors.APIError, google_exceptions.GoogleAPICallError) as e:
            log.error(f"Rule34Cog: Vertex AI API call failed for tag transformation: {e}")
            return None
        except Exception as e:
            log.exception(f"Rule34Cog: Unexpected error during AI tag transformation: {e}")
            return None

    async def _prepare_search_tags(self, tags: str) -> typing.Tuple[typing.Optional[str], typing.Optional[str]]:
        """AI-transform *tags* and build the final search query.

        Returns:
            ``(final_tags, None)`` on success, or ``(None, error_message)`` where
            *error_message* is a user-facing reply with the original input
            echoed safely.
        """
        transformed_tags = await self._transform_tags_ai(tags)
        shown_tags = self._display_tags(tags)
        if transformed_tags is None:  # AI transformation failed
            return None, f"Sorry, I couldn't transform the tags using AI. Please try again or use rule34-formatted tags directly. Original tags: `{shown_tags}`"
        if not transformed_tags:  # AI returned empty
            return None, f"Sorry, the AI couldn't understand the tags provided: `{shown_tags}`. Please try rephrasing."
        return self._append_exclusions(transformed_tags), None

    # --- Prefix Command ---
    @commands.command(name="rule34")
    async def rule34_prefix(self, ctx: commands.Context, *, tags: str):
        """Search for images on Rule34 with the provided tags."""
        if not tags:  # Should not happen due to 'tags: str' but as a safeguard
            await ctx.reply("Please provide tags to search for.")
            return

        loading_msg = await ctx.reply(f"Transforming tags and fetching data from {self.cog_name}, please wait...")

        final_tags, error_message = await self._prepare_search_tags(tags)
        if error_message:
            await loading_msg.edit(content=error_message)
            return
        log.info(f"Rule34Cog (Prefix): Using final tags: '{final_tags}' from original: '{tags}'")

        response = await self._fetch_posts_logic("prefix_internal", final_tags, hidden=False)

        if isinstance(response, tuple):
            content, all_results = response
            view = self.GelbooruButtons(self, final_tags, all_results, hidden=False)  # Pass final_tags to buttons
            await loading_msg.edit(content=content, view=view)
        elif isinstance(response, str):  # Error
            await loading_msg.edit(content=response, view=None)
        else:
            # Never leave the "please wait" message in place.
            log.error(f"Rule34Cog (Prefix): Unexpected response type from _fetch_posts_logic: {type(response).__name__}")
            await loading_msg.edit(content=f"Sorry, something went wrong while fetching data from {self.cog_name}.", view=None)

    # --- Slash Command ---
    @app_commands.command(name="rule34", description="Get random image from Rule34 with specified tags")
    @app_commands.describe(
        tags="The tags to search for (e.g., 'hatsune miku rating:safe')",  # Updated example
        hidden="Set to True to make the response visible only to you (default: False)"
    )
    async def rule34_slash(self, interaction: discord.Interaction, tags: str, hidden: bool = False):
        """Slash command version of rule34."""
        # Defer first: the AI round-trip can exceed the interaction response window.
        await interaction.response.defer(thinking=True, ephemeral=hidden)

        final_tags, error_message = await self._prepare_search_tags(tags)
        if error_message:
            await interaction.followup.send(error_message, ephemeral=True)
            return
        log.info(f"Rule34Cog (Slash): Using final tags: '{final_tags}' from original: '{tags}'")

        # Already deferred above, so the base logic is told not to defer again.
        await self._slash_command_logic(interaction, final_tags, hidden, already_deferred=True)

    # --- Browse Command ---
    @app_commands.command(name="rule34browse", description="Browse Rule34 results with navigation buttons")
    @app_commands.describe(
        tags="The tags to search for (e.g., 'hatsune miku rating:safe')",  # Updated example
        hidden="Set to True to make the response visible only to you (default: False)"
    )
    async def rule34_browse_slash(self, interaction: discord.Interaction, tags: str, hidden: bool = False):
        """Browse Rule34 results with navigation buttons."""
        await interaction.response.defer(thinking=True, ephemeral=hidden)

        final_tags, error_message = await self._prepare_search_tags(tags)
        if error_message:
            await interaction.followup.send(error_message, ephemeral=True)
            return
        log.info(f"Rule34Cog (Browse): Using final tags: '{final_tags}' from original: '{tags}'")

        # Already deferred above, so the base logic is told not to defer again.
        await self._browse_slash_command_logic(interaction, final_tags, hidden, already_deferred=True)

    # --- r34watch slash command group ---
    # All subcommands call the corresponding _watch_..._logic methods from the base class.

    @r34watch.command(name="add", description="Watch for new Rule34 posts with specific tags in a channel or thread.")
    @app_commands.describe(
        tags="The tags to search for (e.g., 'kasane_teto rating:safe').",
        channel="The parent channel for the subscription. Must be a Forum Channel if using forum mode.",
        thread_target="Optional: Name or ID of a thread within the channel (for TextChannels only).",
        post_title="Optional: Title for a new forum post if 'channel' is a Forum Channel."
    )
    @app_commands.checks.has_permissions(manage_guild=True)
    async def r34watch_add(self, interaction: discord.Interaction, tags: str, channel: typing.Union[discord.TextChannel, discord.ForumChannel], thread_target: typing.Optional[str] = None, post_title: typing.Optional[str] = None):
        # Watch tags are used as given (no AI transform); only the exclusion is appended.
        await interaction.response.defer(ephemeral=True)  # Defer here before calling base logic
        final_tags = self._append_exclusions(tags)
        log.info(f"Rule34Cog (Watch Add): Using final tags for watch: '{final_tags}' from original: '{tags}'")
        await self._watch_add_logic(interaction, final_tags, channel, thread_target, post_title)

    @r34watch.command(name="request", description="Request a new Rule34 tag watch (requires moderator approval).")
    @app_commands.describe(
        tags="The tags you want to watch.",
        forum_channel="The Forum Channel where a new post for this watch should be created.",
        post_title="Optional: A title for the new forum post (defaults to tags)."
    )
    async def r34watch_request(self, interaction: discord.Interaction, tags: str, forum_channel: discord.ForumChannel, post_title: typing.Optional[str] = None):
        await interaction.response.defer(ephemeral=True)
        final_tags = self._append_exclusions(tags)
        log.info(f"Rule34Cog (Watch Request): Using final tags for watch request: '{final_tags}' from original: '{tags}'")
        await self._watch_request_logic(interaction, final_tags, forum_channel, post_title)

    @r34watch.command(name="pending_list", description="Lists all pending Rule34 watch requests.")
    @app_commands.checks.has_permissions(manage_guild=True)
    async def r34watch_pending_list(self, interaction: discord.Interaction):
        # Not deferred here; presumably the base logic responds directly.
        await self._watch_pending_list_logic(interaction)

    @r34watch.command(name="approve_request", description="Approves a pending Rule34 watch request.")
    @app_commands.describe(request_id="The ID of the request to approve.")
    @app_commands.checks.has_permissions(manage_guild=True)
    async def r34watch_approve_request(self, interaction: discord.Interaction, request_id: str):
        await interaction.response.defer(ephemeral=True)
        await self._watch_approve_request_logic(interaction, request_id)

    @r34watch.command(name="reject_request", description="Rejects a pending Rule34 watch request.")
    @app_commands.describe(request_id="The ID of the request to reject.", reason="Optional reason for rejection.")
    @app_commands.checks.has_permissions(manage_guild=True)
    async def r34watch_reject_request(self, interaction: discord.Interaction, request_id: str, reason: typing.Optional[str] = None):
        await interaction.response.defer(ephemeral=True)
        await self._watch_reject_request_logic(interaction, request_id, reason)

    @r34watch.command(name="list", description="List active Rule34 tag watches for this server.")
    @app_commands.checks.has_permissions(manage_guild=True)
    async def r34watch_list(self, interaction: discord.Interaction):
        # Not deferred here; presumably the base logic responds directly.
        await self._watch_list_logic(interaction)

    @r34watch.command(name="remove", description="Stop watching for new Rule34 posts using a subscription ID.")
    @app_commands.describe(subscription_id="The ID of the subscription to remove (get from 'list' command).")
    @app_commands.checks.has_permissions(manage_guild=True)
    async def r34watch_remove(self, interaction: discord.Interaction, subscription_id: str):
        # Not deferred here; presumably the base logic responds directly.
        await self._watch_remove_logic(interaction, subscription_id)

    @app_commands.command(name="rule34debug_transform", description="Debug command to test AI tag transformation.")
    @app_commands.describe(tags="The tags to test transformation for (e.g., 'hatsune miku')")
    async def rule34debug_transform(self, interaction: discord.Interaction, tags: str):
        # Shows the raw AI transformation result privately, without searching.
        await interaction.response.defer(ephemeral=True, thinking=True)

        transformed_tags = await self._transform_tags_ai(tags)
        shown_tags = self._display_tags(tags)

        if transformed_tags is None:
            response_content = f"AI transformation failed for tags: `{shown_tags}`. Check logs for details."
        elif not transformed_tags:
            response_content = f"AI returned empty for tags: `{shown_tags}`. Please try rephrasing."
        else:
            response_content = (
                f"Original tags: `{shown_tags}`\n"
                f"Transformed tags: `{self._display_tags(transformed_tags)}`"
            )

        await interaction.followup.send(response_content, ephemeral=True)
|
|
|
|
async def setup(bot: commands.Bot):
    """Register a Rule34Cog on *bot*; this is the module's discord.py extension ``setup`` hook."""
    rule34_cog = Rule34Cog(bot)
    await bot.add_cog(rule34_cog)
    log.info("Rule34Cog (refactored) added to bot.")
|