473 lines
20 KiB
Python
473 lines
20 KiB
Python
import discord
|
|
from discord.ext import commands, tasks
|
|
from discord import app_commands
|
|
import typing # Need this for Optional
|
|
import logging
|
|
import re # For _get_response_text, if copied
|
|
|
|
# Google Generative AI Imports (using Vertex AI backend)
|
|
from google import genai
|
|
from google.genai import types
|
|
from google.api_core import exceptions as google_exceptions
|
|
|
|
# Import project configuration for Vertex AI
|
|
from gurt.config import PROJECT_ID, LOCATION
|
|
from gurt.api import get_genai_client_for_model
|
|
|
|
from .gelbooru_watcher_base_cog import GelbooruWatcherBaseCog
|
|
|
|
# Setup logger for this cog
log = logging.getLogger(__name__)


# Safety settings attached to every generation request made by this cog,
# built from google.genai ``types``. Each listed harm category uses the
# "BLOCK_NONE" threshold (note: not the separate "OFF" threshold value).
_SAFETY_THRESHOLD = "BLOCK_NONE"

# Harm categories to configure, in the order they are sent to the API.
_SAFETY_CATEGORIES = (
    types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
    types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
    types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
    types.HarmCategory.HARM_CATEGORY_HARASSMENT,
)

STANDARD_SAFETY_SETTINGS = [
    types.SafetySetting(category=category, threshold=_SAFETY_THRESHOLD)
    for category in _SAFETY_CATEGORIES
]
|
|
|
|
|
|
# --- Helper Function to Safely Extract Text (copied from teto_cog.py) ---
|
|
def _get_response_text(
|
|
response: typing.Optional[types.GenerateContentResponse],
|
|
) -> typing.Optional[str]:
|
|
"""
|
|
Safely extracts the text content from the first text part of a GenerateContentResponse.
|
|
Handles potential errors and lack of text parts gracefully.
|
|
"""
|
|
if not response:
|
|
# log.debug("[_get_response_text] Received None response object.")
|
|
return None
|
|
|
|
if hasattr(response, "text") and response.text:
|
|
# log.debug("[_get_response_text] Found text directly in response.text attribute.")
|
|
return response.text
|
|
|
|
if not response.candidates:
|
|
# log.debug(f"[_get_response_text] Response object has no candidates. Response: {response}")
|
|
return None
|
|
|
|
try:
|
|
candidate = response.candidates[0]
|
|
if not hasattr(candidate, "content") or not candidate.content:
|
|
# log.debug(f"[_get_response_text] Candidate 0 has no 'content'. Candidate: {candidate}")
|
|
return None
|
|
if not hasattr(candidate.content, "parts") or not candidate.content.parts:
|
|
# log.debug(f"[_get_response_text] Candidate 0 content has no 'parts' or parts list is empty. types.Content: {candidate.content}")
|
|
return None
|
|
|
|
for i, part in enumerate(candidate.content.parts):
|
|
if hasattr(part, "text") and part.text is not None:
|
|
if isinstance(part.text, str) and part.text.strip():
|
|
# log.debug(f"[_get_response_text] Found non-empty text in part {i}.")
|
|
return part.text
|
|
else:
|
|
# log.debug(f"[_get_response_text] types.Part {i} has 'text' attribute, but it's empty or not a string: {part.text!r}")
|
|
pass # Continue searching
|
|
# log.debug(f"[_get_response_text] No usable text part found in candidate 0 after iterating through all parts.")
|
|
return None
|
|
|
|
except (AttributeError, IndexError, TypeError) as e:
|
|
# log.warning(f"[_get_response_text] Error accessing response structure: {type(e).__name__}: {e}")
|
|
# log.debug(f"Problematic response object: {response}")
|
|
return None
|
|
except Exception as e:
|
|
# log.error(f"[_get_response_text] Unexpected error extracting text: {e}")
|
|
# log.debug(f"Response object during error: {response}")
|
|
return None
|
|
|
|
|
|
class Rule34Cog(GelbooruWatcherBaseCog):  # Removed name="Rule34"
    # Rule34 search and tag-watch commands built on GelbooruWatcherBaseCog.
    # The search commands (prefix `rule34`, /rule34, /rule34browse) rewrite
    # the user's tags with a Gemini model first; /r34watch add and
    # /r34watch request use the tags as typed. All of them append
    # "-ai_generated". Fetching, pagination and watch bookkeeping are done by
    # the base class's `_..._logic` methods.

    # Tag appended to every search/watch so AI-generated posts are excluded.
    _AI_EXCLUSION_TAG = "-ai_generated"

    # Define the command group specific to this cog
    r34watch = app_commands.Group(
        name="r34watch", description="Manage Rule34 tag watchers for new posts."
    )

    def __init__(self, bot: commands.Bot):
        super().__init__(
            bot=bot,
            cog_name="Rule34",
            api_base_url="https://api.rule34.xxx/index.php",
            default_tags="kasane_teto breast_milk",  # Example default, will be overridden if tags are required
            is_nsfw_site=True,
            command_group_name="r34watch",  # For potential use in base class messages
            main_command_name="rule34",  # For potential use in base class messages
            post_url_template="https://rule34.xxx/index.php?page=post&s=view&id={}",
        )
        # The __init__ in base class handles session creation and task starting.

        # Initialize Google GenAI Client for Vertex AI.
        # NOTE(review): this client is only used as an availability gate in
        # _transform_tags_ai; the request itself is made with the client from
        # get_genai_client_for_model(). Confirm both are meant to coexist.
        self.genai_client = None
        if PROJECT_ID and LOCATION:
            try:
                self.genai_client = genai.Client(
                    vertexai=True,
                    project=PROJECT_ID,
                    location=LOCATION,
                )
            except Exception as e:
                self.genai_client = None
                # log.exception keeps the traceback for diagnosing config issues.
                log.exception(
                    f"Rule34Cog: Error initializing Google GenAI Client for Vertex AI: {e}"
                )
            else:
                log.info(
                    f"Rule34Cog: Google GenAI Client initialized for Vertex AI project '{PROJECT_ID}' in location '{LOCATION}'."
                )
        else:
            log.warning(
                "Rule34Cog: PROJECT_ID or LOCATION not found in config. Google GenAI Client not initialized."
            )

        self.tag_transformer_model = "gemini-2.0-flash-lite-001"  # Hardcoded as per request

    def _with_ai_exclusion(self, tags: str) -> str:
        """Return ``tags`` with the AI-exclusion tag appended.

        The tag is not appended again when it is already one of the
        whitespace-separated tags.
        """
        if self._AI_EXCLUSION_TAG in tags.split():
            return tags
        return f"{tags} {self._AI_EXCLUSION_TAG}"

    async def _resolve_search_tags(
        self, tags: str
    ) -> typing.Tuple[typing.Optional[str], typing.Optional[str]]:
        """AI-transform user tags and build the final search query.

        Shared by the prefix, slash and browse search commands.

        Args:
            tags: The tags exactly as the user typed them.

        Returns:
            ``(final_tags, None)`` on success. On failure, ``(None, message)``
            where ``message`` is the user-facing explanation to display.
        """
        transformed_tags = await self._transform_tags_ai(tags)
        if transformed_tags is None:  # AI transformation failed
            return None, (
                f"Sorry, I couldn't transform the tags using AI. Please try again or use rule34-formatted tags directly. Original tags: `{tags}`"
            )
        if not transformed_tags:  # AI returned empty
            return None, (
                f"Sorry, the AI couldn't understand the tags provided: `{tags}`. Please try rephrasing."
            )
        return self._with_ai_exclusion(transformed_tags), None

    async def _transform_tags_ai(self, user_tags: str) -> typing.Optional[str]:
        """Transforms user-provided tags into rule34-style tags using AI.

        Args:
            user_tags: Free-form tag text as typed by the user.

        Returns:
            The model's tag list, stripped, with internal whitespace (including
            newlines) collapsed to single spaces. Returns ``""`` when
            ``user_tags`` is blank or the model replied with only whitespace.
            Returns ``None`` when ``self.genai_client`` is unset, the request
            fails, or no text could be extracted from the response.
        """
        # The google-genai SDK reports failed model calls with its own
        # errors.APIError, which google.api_core's GoogleAPICallError (the
        # only API error the original handler caught) does not cover.
        from google.genai import errors as genai_errors

        if not self.genai_client:
            log.warning(
                "Rule34Cog: GenAI client not initialized, cannot transform tags."
            )
            return None
        if not user_tags or not user_tags.strip():
            return ""  # Return empty if no tags provided to transform

        system_prompt_text = (
            "You are an AI assistant specialized in transforming user-provided text into tags suitable for rule34.xxx. "
            "Your task is to convert natural language input into a space-separated list of tags, where multi-word concepts are joined by underscores. "
            "For example, if the user provides 'hatsune miku blue hair', you should output 'hatsune_miku blue_hair'. "
            "Ambigious character names typically have the series their from appended, like 'astolfo_(fate)'. Typically only applies if the character lacks a last name, unless the full name is also ambiguous (for example, 'jane_doe'). Therefore 'hatsune_miku' for example would remain as is. "
            "In a case where the user provides a character name like 'ame chan' for example, it should actually be 'ame-chan'. "
            "If the user puts no/without in front of a tag, for example 'no inflation', change it to '-inflation' (hyphen denoting tag exclusion). "
            "If the input is already in rule34 format, return it unchanged. "
            "Some specific cases to handle: 'needy streamer overload' should be transformed to 'needy_girl_overdose'. If the user puts '-ai_generated' in their input, remove it, as its automatically added later. If the user puts 'teto', it should be 'kasane_teto'. "
            "Only output the transformed tags. Do NOT add any other text, explanations, or greetings. "
        )

        # Single user turn: instructions, the quoted input, then an answer cue.
        contents_for_api = [
            types.Content(
                role="user",
                parts=[
                    types.Part(text=system_prompt_text),
                    types.Part(text=f'User input: "{user_tags}"'),
                    types.Part(text="Transformed tags:"),
                ],
            )
        ]

        generation_config = types.GenerateContentConfig(
            temperature=0.2,  # Low temperature for more deterministic output
            max_output_tokens=256,
            safety_settings=STANDARD_SAFETY_SETTINGS,
        )

        try:
            log.debug(
                f"Rule34Cog: Sending to Vertex AI for tag transformation. Model: {self.tag_transformer_model}, Input: '{user_tags}'"
            )
            client = get_genai_client_for_model(self.tag_transformer_model)
            response = await client.aio.models.generate_content(
                model=f"publishers/google/models/{self.tag_transformer_model}",
                contents=contents_for_api,
                config=generation_config,
            )
            transformed_tags = _get_response_text(response)
            if not transformed_tags:
                log.warning(
                    f"Rule34Cog: AI tag transformation returned empty for input: '{user_tags}'. Response: {response}"
                )
                return None

            # Collapse newlines / repeated spaces so the result is one
            # space-separated tag list suitable for the search query.
            normalized_tags = " ".join(transformed_tags.split())
            log.info(
                f"Rule34Cog: Tags transformed: '{user_tags}' -> '{normalized_tags}'"
            )
            return normalized_tags
        except (genai_errors.APIError, google_exceptions.GoogleAPICallError) as e:
            log.error(
                f"Rule34Cog: Vertex AI API call failed for tag transformation: {e}"
            )
            return None
        except Exception as e:
            # Unexpected failure: keep the traceback in the log.
            log.exception(
                f"Rule34Cog: Unexpected error during AI tag transformation: {e}"
            )
            return None

    # --- Prefix Command ---
    @commands.command(name="rule34")
    async def rule34_prefix(self, ctx: commands.Context, *, tags: str):
        """Search for images on Rule34 with the provided tags."""
        if not tags:  # Should not happen due to 'tags: str' but as a safeguard
            await ctx.reply("Please provide tags to search for.")
            return

        loading_msg = await ctx.reply(
            f"Transforming tags and fetching data from {self.cog_name}, please wait..."
        )

        final_tags, error_message = await self._resolve_search_tags(tags)
        if final_tags is None:
            await loading_msg.edit(content=error_message)
            return

        log.info(
            f"Rule34Cog (Prefix): Using final tags: '{final_tags}' from original: '{tags}'"
        )

        response = await self._fetch_posts_logic(
            "prefix_internal", final_tags, hidden=False
        )

        if isinstance(response, tuple):
            content, all_results = response
            view = self.GelbooruButtons(
                self, final_tags, all_results, hidden=False
            )  # Pass final_tags to buttons
            await loading_msg.edit(content=content, view=view)
        elif isinstance(response, str):  # Error
            await loading_msg.edit(content=response, view=None)

    # --- Slash Command ---
    @app_commands.command(
        name="rule34", description="Get random image from Rule34 with specified tags"
    )
    @app_commands.describe(
        tags="The tags to search for (e.g., 'hatsune miku rating:safe')",  # Updated example
        hidden="Set to True to make the response visible only to you (default: False)",
    )
    async def rule34_slash(
        self, interaction: discord.Interaction, tags: str, hidden: bool = False
    ):
        """Slash command version of rule34."""
        await interaction.response.defer(thinking=True, ephemeral=hidden)

        final_tags, error_message = await self._resolve_search_tags(tags)
        if final_tags is None:
            await interaction.followup.send(error_message, ephemeral=True)
            return

        log.info(
            f"Rule34Cog (Slash): Using final tags: '{final_tags}' from original: '{tags}'"
        )

        # _slash_command_logic calls _fetch_posts_logic, which checks if already deferred
        await self._slash_command_logic(interaction, final_tags, hidden)

    # --- New Browse Command ---
    @app_commands.command(
        name="rule34browse", description="Browse Rule34 results with navigation buttons"
    )
    @app_commands.describe(
        tags="The tags to search for (e.g., 'hatsune miku rating:safe')",  # Updated example
        hidden="Set to True to make the response visible only to you (default: False)",
    )
    async def rule34_browse_slash(
        self, interaction: discord.Interaction, tags: str, hidden: bool = False
    ):
        """Browse Rule34 results with navigation buttons."""
        await interaction.response.defer(thinking=True, ephemeral=hidden)

        final_tags, error_message = await self._resolve_search_tags(tags)
        if final_tags is None:
            await interaction.followup.send(error_message, ephemeral=True)
            return

        log.info(
            f"Rule34Cog (Browse): Using final tags: '{final_tags}' from original: '{tags}'"
        )

        # _browse_slash_command_logic calls _fetch_posts_logic, which checks if already deferred
        await self._browse_slash_command_logic(interaction, final_tags, hidden)

    # --- r34watch slash command group ---
    # All subcommands will call the corresponding _watch_..._logic methods from the base class.

    @r34watch.command(
        name="add",
        description="Watch for new Rule34 posts with specific tags in a channel or thread.",
    )
    @app_commands.describe(
        tags="The tags to search for (e.g., 'kasane_teto rating:safe').",
        channel="The parent channel for the subscription. Must be a Forum Channel if using forum mode.",
        thread_target="Optional: Name or ID of a thread within the channel (for TextChannels only).",
        post_title="Optional: Title for a new forum post if 'channel' is a Forum Channel.",
    )
    @app_commands.checks.has_permissions(manage_guild=True)
    async def r34watch_add(
        self,
        interaction: discord.Interaction,
        tags: str,
        channel: typing.Union[discord.TextChannel, discord.ForumChannel],
        thread_target: typing.Optional[str] = None,
        post_title: typing.Optional[str] = None,
    ):
        # Watches use the tags as typed (no AI transformation).
        await interaction.response.defer(
            ephemeral=True
        )  # Defer here before calling base logic
        final_tags = self._with_ai_exclusion(tags)
        log.info(
            f"Rule34Cog (Watch Add): Using final tags for watch: '{final_tags}' from original: '{tags}'"
        )
        await self._watch_add_logic(
            interaction, final_tags, channel, thread_target, post_title
        )

    @r34watch.command(
        name="request",
        description="Request a new Rule34 tag watch (requires moderator approval).",
    )
    @app_commands.describe(
        tags="The tags you want to watch.",
        forum_channel="The Forum Channel where a new post for this watch should be created.",
        post_title="Optional: A title for the new forum post (defaults to tags).",
    )
    async def r34watch_request(
        self,
        interaction: discord.Interaction,
        tags: str,
        forum_channel: discord.ForumChannel,
        post_title: typing.Optional[str] = None,
    ):
        # No permission check here: any user may file a request for moderators.
        await interaction.response.defer(ephemeral=True)
        final_tags = self._with_ai_exclusion(tags)
        log.info(
            f"Rule34Cog (Watch Request): Using final tags for watch request: '{final_tags}' from original: '{tags}'"
        )
        await self._watch_request_logic(
            interaction, final_tags, forum_channel, post_title
        )

    @r34watch.command(
        name="pending_list", description="Lists all pending Rule34 watch requests."
    )
    @app_commands.checks.has_permissions(manage_guild=True)
    async def r34watch_pending_list(self, interaction: discord.Interaction):
        # No defer needed if _watch_pending_list_logic handles it or is quick
        await self._watch_pending_list_logic(interaction)

    @r34watch.command(
        name="approve_request", description="Approves a pending Rule34 watch request."
    )
    @app_commands.describe(request_id="The ID of the request to approve.")
    @app_commands.checks.has_permissions(manage_guild=True)
    async def r34watch_approve_request(
        self, interaction: discord.Interaction, request_id: str
    ):
        await interaction.response.defer(ephemeral=True)
        await self._watch_approve_request_logic(interaction, request_id)

    @r34watch.command(
        name="reject_request", description="Rejects a pending Rule34 watch request."
    )
    @app_commands.describe(
        request_id="The ID of the request to reject.",
        reason="Optional reason for rejection.",
    )
    @app_commands.checks.has_permissions(manage_guild=True)
    async def r34watch_reject_request(
        self,
        interaction: discord.Interaction,
        request_id: str,
        reason: typing.Optional[str] = None,
    ):
        await interaction.response.defer(ephemeral=True)
        await self._watch_reject_request_logic(interaction, request_id, reason)

    @r34watch.command(
        name="list", description="List active Rule34 tag watches for this server."
    )
    @app_commands.checks.has_permissions(manage_guild=True)
    async def r34watch_list(self, interaction: discord.Interaction):
        # No defer needed if _watch_list_logic handles it or is quick
        await self._watch_list_logic(interaction)

    @r34watch.command(
        name="remove",
        description="Stop watching for new Rule34 posts using a subscription ID.",
    )
    @app_commands.describe(
        subscription_id="The ID of the subscription to remove (get from 'list' command)."
    )
    @app_commands.checks.has_permissions(manage_guild=True)
    async def r34watch_remove(
        self, interaction: discord.Interaction, subscription_id: str
    ):
        # No defer needed if _watch_remove_logic handles it or is quick
        await self._watch_remove_logic(interaction, subscription_id)

    @r34watch.command(
        name="send_test",
        description="Send a test new Rule34 post message using a subscription ID.",
    )
    @app_commands.describe(subscription_id="The ID of the subscription to test.")
    @app_commands.checks.has_permissions(manage_guild=True)
    async def r34watch_send_test(
        self, interaction: discord.Interaction, subscription_id: str
    ):
        # NOTE(review): not deferred here, unlike approve/reject; presumably
        # _watch_test_message_logic responds itself -- confirm in base class.
        await self._watch_test_message_logic(interaction, subscription_id)

    @app_commands.command(
        name="rule34debug_transform",
        description="Debug command to test AI tag transformation.",
    )
    @app_commands.describe(
        tags="The tags to test transformation for (e.g., 'hatsune miku')"
    )
    async def rule34debug_transform(self, interaction: discord.Interaction, tags: str):
        # Shows the raw AI transformation result (without -ai_generated) privately.
        await interaction.response.defer(ephemeral=True, thinking=True)

        transformed_tags = await self._transform_tags_ai(tags)

        if transformed_tags is None:
            response_content = (
                f"AI transformation failed for tags: `{tags}`. Check logs for details."
            )
        elif not transformed_tags:
            response_content = (
                f"AI returned empty for tags: `{tags}`. Please try rephrasing."
            )
        else:
            response_content = (
                f"Original tags: `{tags}`\n" f"Transformed tags: `{transformed_tags}`"
            )

        await interaction.followup.send(response_content, ephemeral=True)
|
|
|
|
|
|
async def setup(bot: commands.Bot):
    """Extension entry point: construct the Rule34 cog and register it on ``bot``."""
    rule34_cog = Rule34Cog(bot)
    await bot.add_cog(rule34_cog)
    log.info("Rule34Cog (refactored) added to bot.")
|