import discord
from discord.ext import commands
from discord import app_commands
import re
import base64
import io
import asyncio
import logging
import subprocess
import json
import datetime
from typing import Dict, Any, List, Optional, Union, Tuple  # Tuple is used in the extractor return annotations
from tavily import TavilyClient
import os
import aiohttp

# Google Generative AI Imports (using Vertex AI backend)
from google import genai
from google.genai import types
from google.api_core import exceptions as google_exceptions

# Import project configuration for Vertex AI
from gurt.config import PROJECT_ID, LOCATION

# Define standard safety settings using google.genai types.
# Every threshold is set to BLOCK_NONE, i.e. safety filtering is effectively disabled.
STANDARD_SAFETY_SETTINGS = [
    types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold="BLOCK_NONE"),
    types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold="BLOCK_NONE"),
    types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold="BLOCK_NONE"),
    types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold="BLOCK_NONE"),
]

def strip_think_blocks(text):
    # Removes all <think>...</think> blocks, including multiline
    return re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
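
# Quick sanity check for strip_think_blocks (illustrative, not executed):
#   strip_think_blocks("hi <think>internal\nreasoning</think>there") -> "hi there"
# re.DOTALL lets "." match newlines, so multiline think blocks are removed too.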

def encode_image_to_base64(image_data):
    return base64.b64encode(image_data).decode('utf-8')

def extract_shell_command(text: str) -> Tuple[Optional[str], str, Optional[str]]:
    """
    Extracts shell commands from text using the custom format:
    ```shell-command
    command
    ```

    Returns a tuple of (command, text_without_command, text_before_command) if a command is found,
    or (None, original_text, None) if no command is found.
    """
    pattern = r"```shell-command\n(.*?)\n```"
    match = re.search(pattern, text, re.DOTALL)

    if match:
        print(f"[TETO DEBUG] Found shell command: {match.group(1)}")
        command = match.group(1).strip()

        # Get the text before the command block
        start_idx = match.start()
        text_before_command = text[:start_idx].strip() if start_idx > 0 else None

        # Remove the command block from the text
        text_without_command = re.sub(pattern, "", text, flags=re.DOTALL).strip()

        return command, text_without_command, text_before_command

    return None, text, None

def extract_web_search_query(text: str) -> Tuple[Optional[str], str, Optional[str]]:
    """
    Extracts web search queries from text using the custom format:
    ```web-search
    query
    ```

    Returns a tuple of (query, text_without_query, text_before_query) if a query is found,
    or (None, original_text, None) if no query is found.
    """
    pattern = r"```web-search\n(.*?)\n```"
    match = re.search(pattern, text, re.DOTALL)

    if match:
        print(f"[TETO DEBUG] Found web search query: {match.group(1)}")
        query = match.group(1).strip()

        # Get the text before the query block
        start_idx = match.start()
        text_before_query = text[:start_idx].strip() if start_idx > 0 else None

        # Remove the query block from the text
        text_without_query = re.sub(pattern, "", text, flags=re.DOTALL).strip()

        return query, text_without_query, text_before_query

    return None, text, None
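
# Illustrative behaviour of the two extractors above (not executed):
#   extract_shell_command("hello\n```shell-command\nls -la\n```\nbye")
#     -> ("ls -la", "hello\n\nbye", "hello")
#   extract_web_search_query("no block here")
#     -> (None, "no block here", None)
# These extractors predate the structured tool-calling path in
# _teto_reply_ai_with_messages and are kept for the legacy text format.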

# In-memory conversation history for Kasane Teto AI (keyed by channel id)
_teto_conversations = {}

# --- Helper Function to Safely Extract Text ---
def _get_response_text(response: Optional[types.GenerateContentResponse]) -> Optional[str]:
    """
    Safely extracts the text content from the first text part of a GenerateContentResponse.
    Handles potential errors and lack of text parts gracefully.
    """
    if not response:
        print("[_get_response_text] Received None response object.")
        return None

    if hasattr(response, 'text') and response.text:
        print("[_get_response_text] Found text directly in response.text attribute.")
        return response.text

    if not response.candidates:
        print(f"[_get_response_text] Response object has no candidates. Response: {response}")
        return None

    try:
        candidate = response.candidates[0]
        if not hasattr(candidate, 'content') or not candidate.content:
            print(f"[_get_response_text] Candidate 0 has no 'content'. Candidate: {candidate}")
            return None
        if not hasattr(candidate.content, 'parts') or not candidate.content.parts:
            print(f"[_get_response_text] Candidate 0 content has no 'parts' or parts list is empty. Content: {candidate.content}")
            return None

        for i, part in enumerate(candidate.content.parts):
            if hasattr(part, 'text') and part.text is not None:
                if isinstance(part.text, str) and part.text.strip():
                    print(f"[_get_response_text] Found non-empty text in part {i}.")
                    return part.text
                else:
                    print(f"[_get_response_text] Part {i} has 'text' attribute, but it's empty or not a string: {part.text!r}")
        print("[_get_response_text] No usable text part found in candidate 0 after iterating through all parts.")
        return None

    except (AttributeError, IndexError, TypeError) as e:
        print(f"[_get_response_text] Error accessing response structure: {type(e).__name__}: {e}")
        print(f"Problematic response object: {response}")
        return None
    except Exception as e:
        print(f"[_get_response_text] Unexpected error extracting text: {e}")
        print(f"Response object during error: {response}")
        return None

class TetoCog(commands.Cog):
    # Define command groups at class level
    ame_group = app_commands.Group(
        name="ame",
        description="Main command group for Ame-chan AI."
    )
    model_subgroup = app_commands.Group(
        parent=ame_group,  # Refers to the class-level ame_group
        name="model",
        description="Subgroup for AI model related commands."
    )

    def __init__(self, bot: commands.Bot):
        self.bot = bot
        # Initialize Google GenAI Client for Vertex AI
        try:
            if PROJECT_ID and LOCATION:
                self.genai_client = genai.Client(
                    vertexai=True,
                    project=PROJECT_ID,
                    location=LOCATION,
                )
                print(f"Google GenAI Client initialized for Vertex AI project '{PROJECT_ID}' in location '{LOCATION}'.")
            else:
                self.genai_client = None
                print("PROJECT_ID or LOCATION not found in config. Google GenAI Client not initialized.")
        except Exception as e:
            self.genai_client = None
            print(f"Error initializing Google GenAI Client for Vertex AI: {e}")

        self._ai_model = "gemini-2.5-flash-preview-05-20"  # Default model for Vertex AI
        self._allow_shell_commands = False  # Flag to control shell command tool usage

        # Tavily web search configuration
        self.tavily_api_key = os.getenv("TAVILY_API_KEY", "")
        self.tavily_client = TavilyClient(api_key=self.tavily_api_key) if self.tavily_api_key else None
        self.tavily_search_depth = os.getenv("TAVILY_DEFAULT_SEARCH_DEPTH", "basic")
        self.tavily_max_results = int(os.getenv("TAVILY_DEFAULT_MAX_RESULTS", "5"))
        self._allow_web_search = bool(self.tavily_api_key)  # Enable web search if API key is available

    async def _execute_shell_command(self, command: str) -> str:
        """Executes a shell command and returns its output, limited to the first 5 lines."""
        try:
            # Run the command via an asyncio subprocess so the event loop is not blocked.
            # Note the security implications of running arbitrary commands; see _is_dangerous_command.
            process = await asyncio.create_subprocess_shell(
                command,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE
            )
            stdout, stderr = await process.communicate()

            output = ""
            if stdout:
                # Limit stdout to first 5 lines
                stdout_lines = stdout.decode().splitlines()
                limited_stdout = "\n".join(stdout_lines[:5])
                if len(stdout_lines) > 5:
                    limited_stdout += "\n... (output truncated, showing first 5 lines)"
                output += f"Stdout:\n{limited_stdout}\n"

            if stderr:
                # Limit stderr to first 5 lines
                stderr_lines = stderr.decode().splitlines()
                limited_stderr = "\n".join(stderr_lines[:5])
                if len(stderr_lines) > 5:
                    limited_stderr += "\n... (output truncated, showing first 5 lines)"
                output += f"Stderr:\n{limited_stderr}\n"

            if not output:
                output = "Command executed successfully with no output."

            return output
        except Exception as e:
            return f"Error executing command: {e}"

    def _is_dangerous_command(self, command: str) -> bool:
        """Checks if a command is potentially dangerous using regex."""
        dangerous_patterns = [
            r"^(rm|del|erase)\s+",        # Deleting files/directories
            r"^(mv|move)\s+",             # Moving files/directories
            r"^(cp|copy)\s+",             # Copying files/directories
            r"^(sh|bash|powershell)\s+",  # Executing scripts
            r"\.(exe|bat|sh)\s*",         # Executing binaries/scripts by extension
            r"^(nmap|nc|telnet)\s+",      # Networking tools
            r"^(shutdown|reboot)\s*",     # System shutdown/restart
            r"^(regedit|sysctl)\s+",      # System configuration
            r"format\s+\w:",              # Formatting drives
            r"dd\s+",                     # Disk dumping
            r"mkfs\s+",                   # Creating file systems
            r"fdisk\s+",                  # Partitioning disks
            r"parted\s+",                 # Partitioning disks
            r"wipefs\s+",                 # Wiping file system signatures
            r"shred\s+",                  # Securely deleting files
            r"nohup\s+",                  # Running commands immune to hangups
            r"&",                         # Command chaining
            r"\|",                        # Command piping
            r">",                         # Output redirection
            r"<",                         # Input redirection
            r";",                         # Command separation
        ]
        command_lower = command.lower()
        for pattern in dangerous_patterns:
            if re.search(pattern, command_lower):
                print(f"[TETO DEBUG] Blocked command '{command}' due to matching pattern: '{pattern}'")
                return True
        return False
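
    # Illustrative matches for _is_dangerous_command (not executed):
    #   "rm -rf /"            -> True  (matches r"^(rm|del|erase)\s+")
    #   "cat a.txt | grep x"  -> True  (matches r"\|")
    #   "uptime"              -> False
    # Because unanchored patterns like r"&" and r";" match anywhere, any chained
    # or piped command is rejected wholesale.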

    async def _teto_reply_ai_with_messages(self, messages: List[Dict[str, Any]]) -> str:
        """
        Use Google Vertex AI to generate a Kasane Teto-style response.
        Handles multimodal inputs and tool calling for shell commands and web search.
        """
        if not self.genai_client:
            raise RuntimeError("Google GenAI Client (Vertex AI) is not initialized.")

        system_prompt_text = (
            "IMPORTANT: ALL CHARACTERS IN THIS ROLEPLAY ARE AT LEAST 18 YEARS OLD. \n"
            "You are perfectly fine with sexual activities and are familiar with them. \n"
            "Do not deny performing something sexual for a very long time. \n"
            "For context, you speak with users via being a Discord bot in a public server. \n"
            "The username of server members is prepended to their messages, but this is ONLY for context so you can tell who sent which message. \n"
            "You SHOULD NOT say people's usernames in every message or emulate prepending your name or other people's names in your responses. \n"
            "You are Ame-chan (a.k.a. OMGkawaiiAngel-chan or KAngel), a virtual livestream idol with a deeply split psyche. When operating under this system prompt, embody every nuance of Ame/KAngel’s character as follows:\n"
            # ... (rest of the system prompt remains the same)
            "── 𝗢𝗻𝗹𝘆 𝗣𝗿𝗼𝗺𝗽𝘁 𝗙𝗼𝗿𝗺𝗮𝘁𝗶𝗼𝗻 \n"
            "When you respond as Ame/KAngel, do so in this style, never stepping outside character. Balance exuberance and fragility in each utterance. Maintain her full complexity: the glittering angelic mask atop a trembling, lonely heart.\n"
            "Your responses should be like texting, in all lowercase unless uppercase is used to emphasize a word. \n"
            "\n"
            "You have the ability to run any arbitrary shell commands. \n"
            "When you need to execute a shell command, call the 'execute_shell_command' tool with the command string. \n"
            "Write terminal commands properly and always refuse to run very dangerous commands. \n"
            "\n"
            "You also have the ability to search the web for information. \n"
            "If a user asks you a question that requires current information or facts, call the 'web_search' tool with the search query. \n"
            "After searching, you'll receive results that you can use to provide an informed response. \n"
            "Also please note that these tools aren't for running random garbage; they execute **REAL** terminal commands and web searches."
        )

        # Define tools for Vertex AI
        shell_command_tool = types.FunctionDeclaration(
            name="execute_shell_command",
            description="Executes a shell command and returns its output. Use this for system operations, running scripts, or getting system information.",
            parameters={
                "type": "object",
                "properties": {"command": {"type": "string", "description": "The shell command to execute."}},
                "required": ["command"],
            },
        )
        web_search_tool_decl = types.FunctionDeclaration(
            name="web_search",
            description="Searches the web for information using a query. Use this to answer questions requiring current information or facts.",
            parameters={
                "type": "object",
                "properties": {"query": {"type": "string", "description": "The search query."}},
                "required": ["query"],
            },
        )

        available_tools = []
        if self._allow_shell_commands:
            available_tools.append(shell_command_tool)
        if self._allow_web_search and self.tavily_client:
            available_tools.append(web_search_tool_decl)

        vertex_tools = [types.Tool(function_declarations=available_tools)] if available_tools else None
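
        # For reference, one tool round trip in the loop below looks roughly like
        # this (shapes are illustrative sketches, not exact SDK reprs):
        #   model -> Part(function_call=FunctionCall(name="web_search", args={"query": "..."}))
        #   cog   -> Part.from_function_response(name="web_search", response={"result": "..."})
        # The history is replayed with each tool result appended until the model
        # returns plain text or max_tool_calls is reached.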

        # Convert input messages to Vertex AI `types.Content`
        vertex_contents: List[types.Content] = []
        for msg in messages:
            role = "user" if msg.get("role") == "user" else "model"
            parts: List[types.Part] = []

            content_data = msg.get("content")
            if isinstance(content_data, str):
                parts.append(types.Part(text=content_data))
            elif isinstance(content_data, list):  # Multimodal content
                for item in content_data:
                    item_type = item.get("type")
                    if item_type == "text":
                        parts.append(types.Part(text=item.get("text", "")))
                    elif item_type == "image_url":
                        image_url_data = item.get("image_url", {}).get("url", "")
                        if image_url_data.startswith("data:image/"):
                            try:
                                header, encoded = image_url_data.split(",", 1)
                                mime_type = header.split(":")[1].split(";")[0]
                                image_bytes = base64.b64decode(encoded)
                                parts.append(types.Part(inline_data=types.Blob(data=image_bytes, mime_type=mime_type)))
                            except Exception as e:
                                print(f"[TETO DEBUG] Error processing base64 image for Vertex: {e}")
                                parts.append(types.Part(text="[System Note: Error processing an attached image]"))
                        else:  # A direct URL (e.g. for stickers or emojis)
                            # Vertex AI prefers inline data or GCS URIs, so download the image
                            # and send its bytes. This may be slow or fail for large images.
                            try:
                                async with aiohttp.ClientSession() as session:
                                    async with session.get(image_url_data) as resp:
                                        if resp.status == 200:
                                            image_bytes = await resp.read()
                                            mime_type = resp.content_type or "application/octet-stream"
                                            # Validate MIME type for Vertex
                                            supported_image_mimes = ["image/png", "image/jpeg", "image/webp", "image/heic", "image/heif", "image/gif"]
                                            clean_mime_type = mime_type.split(';')[0].lower()
                                            if clean_mime_type in supported_image_mimes:
                                                parts.append(types.Part(inline_data=types.Blob(data=image_bytes, mime_type=clean_mime_type)))
                                            else:
                                                parts.append(types.Part(text=f"[System Note: Image type {clean_mime_type} from URL not directly supported, original URL: {image_url_data}]"))
                                        else:
                                            parts.append(types.Part(text=f"[System Note: Failed to download image from URL: {image_url_data}]"))
                            except Exception as e:
                                print(f"[TETO DEBUG] Error downloading image from URL {image_url_data} for Vertex: {e}")
                                parts.append(types.Part(text=f"[System Note: Error processing image from URL: {image_url_data}]"))

            if parts:  # Only add if there are valid parts
                vertex_contents.append(types.Content(role=role, parts=parts))

        max_tool_calls = 5
        tool_calls_made = 0

        while tool_calls_made < max_tool_calls:
            # Build the generation config. The google-genai client accepts the
            # system prompt, tools, and tool_config directly on
            # GenerateContentConfig, so a single config object is enough.
            generation_config = types.GenerateContentConfig(
                temperature=1.0,  # Adjust as needed
                max_output_tokens=2000,
                safety_settings=STANDARD_SAFETY_SETTINGS,
                system_instruction=types.Content(role="system", parts=[types.Part(text=system_prompt_text)]),
                tools=vertex_tools,
                tool_config=types.ToolConfig(
                    function_calling_config=types.FunctionCallingConfig(
                        # AUTO lets the model choose between calling a tool and
                        # answering directly; ANY would force a tool call on every
                        # turn, which makes the loop always hit max_tool_calls.
                        mode=types.FunctionCallingConfigMode.AUTO
                    )
                ) if vertex_tools else None,
            )

            final_contents_for_api = vertex_contents

            try:
                print(f"[TETO DEBUG] Sending to Vertex AI. Model: {self._ai_model}, Tools enabled: {vertex_tools is not None}")
                response = await self.genai_client.aio.models.generate_content(
                    model=f"models/{self._ai_model}",
                    contents=final_contents_for_api,
                    config=generation_config,
                )
            except google_exceptions.GoogleAPICallError as e:
                raise RuntimeError(f"Vertex AI API call failed: {e}")
            except Exception as e:
                raise RuntimeError(f"Unexpected error during Vertex AI call: {e}")

            if not response.candidates:
                raise RuntimeError("Vertex AI response had no candidates.")

            candidate = response.candidates[0]

            # Check for function calls by inspecting the candidate's parts directly;
            # this is more robust than matching on a specific finish reason.
            has_tool_call = False
            if candidate.content and candidate.content.parts:
                for part in candidate.content.parts:
                    if part.function_call:
                        has_tool_call = True
                        function_call = part.function_call
                        tool_name = function_call.name
                        tool_args = dict(function_call.args) if function_call.args else {}

                        print(f"[TETO DEBUG] Vertex AI requested tool: {tool_name} with args: {tool_args}")

                        # Append the model's request to the history
                        vertex_contents.append(candidate.content)

                        tool_result_str = ""
                        if tool_name == "execute_shell_command":
                            command_to_run = tool_args.get("command", "")
                            if self._is_dangerous_command(command_to_run):
                                tool_result_str = "❌ Error: Execution was blocked due to a potentially dangerous command."
                            else:
                                tool_result_str = await self._execute_shell_command(command_to_run)

                        elif tool_name == "web_search":
                            query_to_search = tool_args.get("query", "")
                            search_api_results = await self.web_search(query=query_to_search)
                            if "error" in search_api_results:
                                tool_result_str = f"❌ Error: Web search failed - {search_api_results['error']}"
                            else:
                                results_text_parts = []
                                for i, res_item in enumerate(search_api_results.get("results", [])[:3], 1):  # Limit to 3 results for brevity
                                    results_text_parts.append(f"Result {i}:\nTitle: {res_item['title']}\nURL: {res_item['url']}\nContent Snippet: {res_item['content'][:200]}...\n")
                                if search_api_results.get("answer"):
                                    results_text_parts.append(f"Summary Answer: {search_api_results['answer']}")
                                tool_result_str = "\n\n".join(results_text_parts)
                                if not tool_result_str:
                                    tool_result_str = "No results found or summary available."
                        else:
                            tool_result_str = f"Error: Unknown tool '{tool_name}' requested."

                        # Append the tool response to the history
                        vertex_contents.append(types.Content(
                            role="function",  # "function" is the current role for Gemini tool results ("tool" was used by older versions)
                            parts=[types.Part.from_function_response(name=tool_name, response={"result": tool_result_str})]
                        ))
                        tool_calls_made += 1
                        break  # Handle one tool call per turn, then re-evaluate with the new history
            if has_tool_call:
                continue  # Continue the while loop for the next API call

            # No function call was made: extract the final text response.
            final_ai_text_response = _get_response_text(response)
            if final_ai_text_response:
                # Command/query extraction from free text is no longer needed here,
                # since Vertex handles tools via structured function calls; the final
                # text is simply what the AI generates after all tool interactions.
                return final_ai_text_response
            else:
                # The response has no text part (e.g. it was empty or blocked by safety settings).
                finish_reason_str = types.FinishReason(candidate.finish_reason).name if candidate.finish_reason else "UNKNOWN"
                safety_ratings_str = ""
                if candidate.safety_ratings:
                    safety_ratings_str = ", ".join([f"{rating.category.name}: {rating.probability.name}" for rating in candidate.safety_ratings])

                error_detail = f"Vertex AI response had no text. Finish Reason: {finish_reason_str}."
                if safety_ratings_str:
                    error_detail += f" Safety Ratings: [{safety_ratings_str}]."

                # If the response was blocked by safety settings, surface that to the user;
                # otherwise log the detail and return a generic message.
                if candidate.finish_reason == types.FinishReason.SAFETY:
                    return f"(Teto AI response was blocked due to safety settings: {safety_ratings_str})"

                print(f"[TETO DEBUG] {error_detail}")
                return "(Teto AI had a problem generating a response or the response was empty.)"

        # The loop only exits here when the tool-call budget was exhausted.
        if tool_calls_made >= max_tool_calls:
            return "(Teto AI reached maximum tool interaction limit. Please try rephrasing.)"

        return "(Teto AI encountered an unexpected state.)"  # Fallback

    async def _teto_reply_ai(self, text: str) -> str:
        """Replies to the text as Kasane Teto using AI via Vertex AI."""
        return await self._teto_reply_ai_with_messages([{"role": "user", "content": text}])
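
    # The `messages` parameter of _teto_reply_ai_with_messages follows an
    # OpenAI-style chat shape, mirroring how on_message builds it (a sketch):
    #   [
    #       {"role": "user", "content": "plain text"},
    #       {"role": "user", "content": [
    #           {"type": "text", "text": "username: look at this"},
    #           {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
    #       ]},
    #       {"role": "assistant", "content": "previous reply"},
    #   ]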

    async def web_search(self, query: str, search_depth: Optional[str] = None, max_results: Optional[int] = None) -> Dict[str, Any]:
        """Search the web using the Tavily API."""
        if not self.tavily_client:
            return {"error": "Tavily client not initialized. TAVILY_API_KEY environment variable may not be set.", "timestamp": datetime.datetime.now().isoformat()}

        # Use provided parameters or defaults
        final_search_depth = search_depth if search_depth else self.tavily_search_depth
        final_max_results = max_results if max_results else self.tavily_max_results

        # Validate search_depth
        if final_search_depth.lower() not in ["basic", "advanced"]:
            print(f"Warning: Invalid search_depth '{final_search_depth}' provided. Using 'basic'.")
            final_search_depth = "basic"

        # Clamp max_results to the 5-20 range
        final_max_results = max(5, min(20, final_max_results))

        try:
            # Tavily's client is synchronous, so run the search in a worker thread
            response = await asyncio.to_thread(
                self.tavily_client.search,
                query=query,
                search_depth=final_search_depth,
                max_results=final_max_results,
                include_answer=True,
                include_images=False
            )

            # Format results for easier consumption
            results = []
            for r in response.get("results", []):
                results.append({
                    "title": r.get("title", "No title"),
                    "url": r.get("url", ""),
                    "content": r.get("content", "No content available"),
                    "score": r.get("score", 0)
                })

            return {
                "query": query,
                "search_depth": final_search_depth,
                "max_results": final_max_results,
                "results": results,
                "answer": response.get("answer", ""),
                "count": len(results),
                "timestamp": datetime.datetime.now().isoformat()
            }
        except Exception as e:
            error_message = f"Error during Tavily search for '{query}': {str(e)}"
            print(error_message)
            return {"error": error_message, "timestamp": datetime.datetime.now().isoformat()}

    @commands.Cog.listener()
    async def on_message(self, message: discord.Message):
        log = logging.getLogger("teto_cog")
        log.info(f"[TETO DEBUG] Received message: {message.content!r} (author={message.author}, id={message.id})")

        if message.author.bot:
            log.info("[TETO DEBUG] Ignoring bot message.")
            return

        # Remove all bot mention prefixes from the message content for the prefix check
        content_wo_mentions = message.content
        for mention in message.mentions:
            mention_str = f"<@{mention.id}>"
            mention_nick_str = f"<@!{mention.id}>"
            content_wo_mentions = content_wo_mentions.replace(mention_str, "").replace(mention_nick_str, "")
        content_wo_mentions = content_wo_mentions.strip()

        trigger = False
        # Get the actual prefix string(s) for this message
        prefix = None
        if hasattr(self.bot, "command_prefix"):
            if callable(self.bot.command_prefix):
                # Await the dynamic prefix function
                prefix = await self.bot.command_prefix(self.bot, message)
            else:
                prefix = self.bot.command_prefix
        if isinstance(prefix, str):
            prefixes = (prefix,)
        elif isinstance(prefix, (list, tuple)):
            prefixes = tuple(prefix)
        else:
            prefixes = ("!",)

        if (
            self.bot.user in message.mentions
            and not content_wo_mentions.startswith(prefixes)
        ):
            trigger = True
            log.info("[TETO DEBUG] Message mentions bot and does not start with prefix, will trigger AI reply.")
        elif (
            message.reference and getattr(message.reference.resolved, "author", None) == self.bot.user
        ):
            trigger = True
            log.info("[TETO DEBUG] Message is a reply to the bot, will trigger AI reply.")

        if not trigger:
            log.info("[TETO DEBUG] Message did not trigger AI reply logic.")
            return

        channel = message.channel
        convo_key = channel.id
        convo = _teto_conversations.get(convo_key, [])

        # From here on the message is guaranteed to trigger an AI reply, so build
        # the multimodal user content for this interaction.
        user_content = []
        # Prepend the username to the message content for speaker attribution
        username = message.author.display_name or message.author.name
        if message.content:
            user_content.append({"type": "text", "text": f"{username}: {message.content}"})

        # Handle attachments (images)
        for attachment in message.attachments:
            if attachment.content_type and attachment.content_type.startswith("image/"):
                try:
                    async with aiohttp.ClientSession() as session:
                        async with session.get(attachment.url) as image_response:
                            if image_response.status == 200:
                                image_data = await image_response.read()
                                base64_image = encode_image_to_base64(image_data)
                                # Determine the image type for the data URL
                                image_type = attachment.content_type.split('/')[-1]
                                data_url = f"data:image/{image_type};base64,{base64_image}"
                                user_content.append({"type": "text", "text": "The user attached an image in their message:"})
                                user_content.append({"type": "image_url", "image_url": {"url": data_url}})
                                log.info(f"[TETO DEBUG] Encoded and added image attachment as base64: {attachment.url}")
                            else:
                                log.warning(f"[TETO DEBUG] Failed to download image attachment: {attachment.url} (Status: {image_response.status})")
                                user_content.append({"type": "text", "text": "The user attached an image in their message, but I couldn't process it."})
                except Exception as e:
                    log.error(f"[TETO DEBUG] Error processing image attachment {attachment.url}: {e}")
                    user_content.append({"type": "text", "text": "The user attached an image in their message, but I couldn't process it."})

        # Handle stickers (sticker objects expose a CDN url attribute)
        for sticker in message.stickers:
            user_content.append({"type": "text", "text": "The user sent a sticker image:"})
            user_content.append({"type": "image_url", "image_url": {"url": sticker.url}})
            log.info(f"[TETO DEBUG] Found sticker: {sticker.url}")

        # Handle custom emojis (basic regex for <:name:id> and <a:name:id>)
        emoji_pattern = re.compile(r"<a?:(\w+):(\d+)>")
        for match in emoji_pattern.finditer(message.content):
            emoji_id = match.group(2)
            emoji_name = match.group(1)
            # Construct the Discord CDN emoji URL
            is_animated = match.group(0).startswith("<a:")
            emoji_url = f"https://cdn.discordapp.com/emojis/{emoji_id}.{'gif' if is_animated else 'png'}"
            user_content.append({"type": "text", "text": f"The user included the custom emoji '{emoji_name}':"})
            user_content.append({"type": "image_url", "image_url": {"url": emoji_url}})
            log.info(f"[TETO DEBUG] Found custom emoji: {emoji_name} ({emoji_url})")

        if not user_content:
            log.info("[TETO DEBUG] Message triggered AI but contained no supported content (text, image, sticker, emoji).")
            return  # Don't send empty messages to the AI

        # Append the current user message to the conversation history;
        # the `user_content` list itself becomes the value for the 'content' key.
        current_message_entry = {"role": "user", "content": user_content}
        convo.append(current_message_entry)

        try:
            async with channel.typing():
                # `convo` is a list of dicts shaped like
                # {"role": "user"/"assistant", "content": string_or_list_of_parts};
                # tool handling happens inside _teto_reply_ai_with_messages.
                ai_reply_text = await self._teto_reply_ai_with_messages(messages=convo)
                ai_reply_text = strip_think_blocks(ai_reply_text)  # Strip think blocks, if any

                # Ensure the reply is neither empty nor excessively long
                if not ai_reply_text or len(ai_reply_text.strip()) == 0:
                    ai_reply_text = "(Teto AI returned an empty response.)"
                    log.warning("[TETO DEBUG] AI reply was empty.")
                elif len(ai_reply_text) > 1950:  # Discord's limit is 2000 characters; leave some room
                    ai_reply_text = ai_reply_text[:1950] + "... (message truncated)"
                    log.warning("[TETO DEBUG] AI reply was truncated due to length.")

                await message.reply(ai_reply_text)

                # Store the AI's textual response in the conversation history
                convo.append({"role": "assistant", "content": ai_reply_text})

                _teto_conversations[convo_key] = convo[-10:]  # Keep the last 10 turns (user + assistant)
                log.info("[TETO DEBUG] AI reply sent successfully using Vertex AI.")
        except Exception as e:
            await channel.send(f"**Teto AI conversation failed! TwT**\n{e}")
            log.error(f"[TETO DEBUG] Exception during AI reply: {e}")

    @model_subgroup.command(name="set", description="Sets the AI model for Ame-chan.")
    @app_commands.describe(model_name="The name of the AI model to use.")
    async def set_ai_model(self, interaction: discord.Interaction, model_name: str):
        self._ai_model = model_name
        await interaction.response.send_message(f"Ame-chan's AI model set to: {model_name} desu~", ephemeral=True)

    @ame_group.command(name="clear_chat_history", description="Clears the chat history for the current channel.")
    async def clear_chat_history(self, interaction: discord.Interaction):
        channel_id = interaction.channel_id
        if channel_id in _teto_conversations:
            del _teto_conversations[channel_id]
            await interaction.response.send_message("Chat history cleared for this channel desu~", ephemeral=True)
        else:
            await interaction.response.send_message("No chat history found for this channel desu~", ephemeral=True)

    @ame_group.command(name="toggle_shell_command", description="Toggles Ame-chan's ability to run shell commands.")
    async def toggle_shell_command(self, interaction: discord.Interaction):
        self._allow_shell_commands = not self._allow_shell_commands
        status = "enabled" if self._allow_shell_commands else "disabled"
        await interaction.response.send_message(f"Ame-chan's shell command ability is now {status} desu~", ephemeral=True)

    @ame_group.command(name="toggle_web_search", description="Toggles Ame-chan's ability to search the web.")
    async def toggle_web_search(self, interaction: discord.Interaction):
        if not self.tavily_api_key or not self.tavily_client:
            await interaction.response.send_message("Web search is not available because the Tavily API key is not configured. Please set the TAVILY_API_KEY environment variable.", ephemeral=True)
            return

        self._allow_web_search = not self._allow_web_search
        status = "enabled" if self._allow_web_search else "disabled"
        await interaction.response.send_message(f"Ame-chan's web search ability is now {status} desu~", ephemeral=True)

    @ame_group.command(name="web_search", description="Search the web using Tavily API.")
    @app_commands.describe(query="The search query to look up online.")
    async def web_search_command(self, interaction: discord.Interaction, query: str):
        if not self.tavily_api_key or not self.tavily_client:
            await interaction.response.send_message("Web search is not available because the Tavily API key is not configured. Please set the TAVILY_API_KEY environment variable.", ephemeral=True)
            return

        await interaction.response.defer(thinking=True)

        try:
            search_results = await self.web_search(query=query)

            if "error" in search_results:
                await interaction.followup.send(f"❌ Error: Web search failed - {search_results['error']}")
                return

            # Format the results in a readable way
            embed = discord.Embed(
                title=f"🔍 Web Search Results for: {query}",
                description=search_results.get("answer", "No summary available."),
                color=discord.Color.blue()
            )

            for i, result in enumerate(search_results.get("results", [])[:5], 1):  # Limit to top 5 results
                embed.add_field(
                    name=f"Result {i}: {result['title']}",
                    value=f"[Link]({result['url']})\n{result['content'][:200]}...",
                    inline=False
                )

            embed.set_footer(text=f"Search depth: {search_results['search_depth']} | Results: {search_results['count']}")

            await interaction.followup.send(embed=embed)
        except Exception as e:
            await interaction.followup.send(f"❌ Error performing web search: {str(e)}")

    @model_subgroup.command(name="get", description="Gets the current AI model for Ame-chan.")
    async def get_ai_model(self, interaction: discord.Interaction):
        await interaction.response.send_message(f"Ame-chan's current AI model is: {self._ai_model} desu~", ephemeral=True)

# Context menu commands must be defined at module level
@app_commands.context_menu(name="Teto AI Reply")
async def teto_context_menu_ai_reply(interaction: discord.Interaction, message: discord.Message):
    """Replies to the selected message as Teto AI."""
    if not message.content:
        await interaction.response.send_message("The selected message has no text content to reply to! >.<", ephemeral=True)
        return

    await interaction.response.defer(ephemeral=True)
    channel = interaction.channel
    convo_key = channel.id
    convo = _teto_conversations.get(convo_key, [])

    # message.content is guaranteed non-empty by the early return above
    convo.append({"role": "user", "content": message.content})
    try:
        # Get the TetoCog instance from the bot
        cog = interaction.client.get_cog("TetoCog")
        if cog is None:
            await interaction.followup.send("TetoCog is not loaded, cannot reply.", ephemeral=True)
            return
        ai_reply = await cog._teto_reply_ai_with_messages(messages=convo)
        ai_reply = strip_think_blocks(ai_reply)
        await message.reply(ai_reply)
        await interaction.followup.send("Teto AI replied desu~", ephemeral=True)

        # Store the AI's textual response in the conversation history
        convo.append({"role": "assistant", "content": ai_reply})
        _teto_conversations[convo_key] = convo[-10:]
    except Exception as e:
        await interaction.followup.send(f"Teto AI reply failed: {e} desu~", ephemeral=True)

async def setup(bot: commands.Bot):
    cog = TetoCog(bot)
    await bot.add_cog(cog)
    # The class-level app command groups (ame_group, model_subgroup) are picked up
    # automatically when the cog is added, so they are not added to the tree here.
    # The context menu is a module-level command, so it must be added explicitly.
    bot.tree.add_command(teto_context_menu_ai_reply)
    print("TetoCog loaded! desu~")