import discord
import aiohttp
import asyncio
import json
import base64
import copy  # Needed for deep copying schemas
import re
import time
import datetime
import traceback
from typing import TYPE_CHECKING, Optional, List, Dict, Any, Union, AsyncIterable

import jsonschema  # For manual JSON validation

# Vertex AI Imports
try:
    import vertexai
    from vertexai import generative_models
    from vertexai.generative_models import (
        GenerativeModel, GenerationConfig, Part, Content, Tool, FunctionDeclaration,
        GenerationResponse, FinishReason
    )
    from google.api_core import exceptions as google_exceptions
    from google.cloud.storage import Client as GCSClient  # For potential image uploads
except ImportError:
    print("WARNING: google-cloud-aiplatform (Vertex AI SDK) or google-cloud-storage not installed. API calls will fail.")

    # Define dummy classes/exceptions so the module stays importable without the SDK
    class DummyGenerativeModel:
        def __init__(self, model_name, system_instruction=None, tools=None): pass
        async def generate_content_async(self, contents, generation_config=None, safety_settings=None, stream=False): return None
    GenerativeModel = DummyGenerativeModel

    class DummyPart:
        @staticmethod
        def from_text(text): return None
        @staticmethod
        def from_data(data, mime_type): return None
        @staticmethod
        def from_uri(uri, mime_type): return None
        @staticmethod
        def from_function_response(name, response): return None
    Part = DummyPart

    Content = dict
    Tool = list
    FunctionDeclaration = object
    GenerationConfig = dict
    GenerationResponse = object
    FinishReason = object
    generative_models = object()  # Placeholder so module-level getattr() calls below don't raise NameError

    class DummyGoogleExceptions:
        ResourceExhausted = type('ResourceExhausted', (Exception,), {})
        InternalServerError = type('InternalServerError', (Exception,), {})
        ServiceUnavailable = type('ServiceUnavailable', (Exception,), {})
        InvalidArgument = type('InvalidArgument', (Exception,), {})
        GoogleAPICallError = type('GoogleAPICallError', (Exception,), {})  # Generic fallback
    google_exceptions = DummyGoogleExceptions()

# Relative imports for components within the 'gurt' package
from .config import (
    PROJECT_ID, LOCATION, DEFAULT_MODEL, FALLBACK_MODEL,
    API_TIMEOUT, API_RETRY_ATTEMPTS, API_RETRY_DELAY, TOOLS, RESPONSE_SCHEMA,
    PROACTIVE_PLAN_SCHEMA,  # Schema for the proactive planning step
    TAVILY_API_KEY, PISTON_API_URL, PISTON_API_KEY  # Other needed configs
)
from .prompt import build_dynamic_system_prompt
from .context import gather_conversation_context, get_memory_context  # Renamed functions
from .tools import TOOL_MAPPING, get_conversation_summary
from .utils import format_message, log_internal_api_call  # Import utilities

if TYPE_CHECKING:
    from .cog import WheatleyCog  # Import WheatleyCog for type hinting only

# --- Schema Preprocessing Helper ---
def _preprocess_schema_for_vertex(schema: Dict[str, Any]) -> Dict[str, Any]:
    """
    Recursively preprocesses a JSON schema dictionary to replace list types
    (like ["string", "null"]) with the first non-null type, making it
    compatible with Vertex AI's GenerationConfig schema requirements.

    Args:
        schema: The JSON schema dictionary to preprocess.

    Returns:
        A new, preprocessed schema dictionary.
    """
    if not isinstance(schema, dict):
        return schema  # Return non-dict elements as is

    processed_schema = copy.deepcopy(schema)  # Work on a copy
    for key, value in processed_schema.items():
        if key == "type" and isinstance(value, list):
            # Find the first non-"null" type in the list
            first_valid_type = next((t for t in value if isinstance(t, str) and t.lower() != "null"), None)
            if first_valid_type:
                processed_schema[key] = first_valid_type
            else:
                # Fallback if only "null" or invalid types are present (shouldn't happen in valid schemas)
                processed_schema[key] = "object"
                print(f"Warning: Schema preprocessing found list type '{value}' with no valid non-null string type. Falling back to 'object'.")
        elif isinstance(value, dict):
            # Recurse for nested schema objects (this also covers 'properties' and 'items')
            processed_schema[key] = _preprocess_schema_for_vertex(value)
        elif isinstance(value, list):
            # Recurse for schema dicts nested within lists (e.g., items within arrays)
            processed_schema[key] = [_preprocess_schema_for_vertex(item) if isinstance(item, dict) else item for item in value]
    return processed_schema
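
# A minimal usage sketch (hypothetical schema, not one of the project's real schemas):
#
#   raw = {"type": "object", "properties": {"content": {"type": ["string", "null"]}}}
#   _preprocess_schema_for_vertex(raw)
#   # -> {"type": "object", "properties": {"content": {"type": "string"}}}
#
# Vertex AI's GenerationConfig rejects JSON-Schema union types like ["string", "null"],
# so only the first non-null entry is kept.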

# --- Helper Function to Safely Extract Text ---
def _get_response_text(response: Optional['GenerationResponse']) -> Optional[str]:
    """Safely extracts the text content from the first text part of a GenerationResponse."""
    if not response or not response.candidates:
        return None
    try:
        # Iterate through parts to find the first text part
        for part in response.candidates[0].content.parts:
            # Check if the part has a 'text' attribute and it's not empty
            if hasattr(part, 'text') and part.text:
                return part.text
        # No text part found (e.g., only a function call, or empty text parts)
        print(f"[_get_response_text] No text part found in candidate parts: {response.candidates[0].content.parts}")
        return None
    except (AttributeError, IndexError) as e:
        # The response structure was unexpected or the parts list was empty
        print(f"Error accessing response parts: {type(e).__name__}: {e}")
        return None
    except Exception as e:
        # Catch unexpected errors during access
        print(f"Unexpected error extracting text from response part: {e}")
        return None

# --- Initialize Vertex AI ---
try:
    vertexai.init(project=PROJECT_ID, location=LOCATION)
    print(f"Vertex AI initialized for project '{PROJECT_ID}' in location '{LOCATION}'.")
except NameError:
    print("Vertex AI SDK not imported, skipping initialization.")
except Exception as e:
    print(f"Error initializing Vertex AI: {e}")

# --- Constants ---
# Define standard safety settings (adjust as needed).
# Use the real enum types if the import succeeded; otherwise fall back to their string names.
_HarmCategory = getattr(generative_models, 'HarmCategory', Any)
_HarmBlockThreshold = getattr(generative_models, 'HarmBlockThreshold', Any)
STANDARD_SAFETY_SETTINGS = {
    getattr(_HarmCategory, 'HARM_CATEGORY_HATE_SPEECH', 'HARM_CATEGORY_HATE_SPEECH'): getattr(_HarmBlockThreshold, 'BLOCK_MEDIUM_AND_ABOVE', 'BLOCK_MEDIUM_AND_ABOVE'),
    getattr(_HarmCategory, 'HARM_CATEGORY_DANGEROUS_CONTENT', 'HARM_CATEGORY_DANGEROUS_CONTENT'): getattr(_HarmBlockThreshold, 'BLOCK_MEDIUM_AND_ABOVE', 'BLOCK_MEDIUM_AND_ABOVE'),
    getattr(_HarmCategory, 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'HARM_CATEGORY_SEXUALLY_EXPLICIT'): getattr(_HarmBlockThreshold, 'BLOCK_MEDIUM_AND_ABOVE', 'BLOCK_MEDIUM_AND_ABOVE'),
    getattr(_HarmCategory, 'HARM_CATEGORY_HARASSMENT', 'HARM_CATEGORY_HARASSMENT'): getattr(_HarmBlockThreshold, 'BLOCK_MEDIUM_AND_ABOVE', 'BLOCK_MEDIUM_AND_ABOVE'),
}
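# Net effect: all four harm categories above are blocked at MEDIUM severity and up,
# and the getattr() fallbacks keep this module importable when the SDK is absent.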

# --- API Call Helper ---
async def call_vertex_api_with_retry(
    cog: 'WheatleyCog',
    model: 'GenerativeModel',
    contents: List['Content'],
    generation_config: 'GenerationConfig',
    safety_settings: Optional[Dict[Any, Any]],  # Use Any for broader compatibility
    request_desc: str,
    stream: bool = False
) -> Union['GenerationResponse', AsyncIterable['GenerationResponse'], None]:
    """
    Calls the Vertex AI Gemini API with retry logic.

    Args:
        cog: The WheatleyCog instance.
        model: The initialized GenerativeModel instance.
        contents: The list of Content objects for the prompt.
        generation_config: The GenerationConfig object.
        safety_settings: Safety settings for the request.
        request_desc: A description of the request for logging purposes.
        stream: Whether to stream the response.

    Returns:
        The GenerationResponse object, or an AsyncIterable if streaming.

    Raises:
        Exception: If the API call fails after all retry attempts or encounters a non-retryable error.
    """
    last_exception = None
    model_name = model._model_name  # Get model name for logging
    start_time = time.monotonic()
    # Default stats bucket for this model, created on first use
    empty_stats = {'success': 0, 'failure': 0, 'retries': 0, 'total_time': 0.0, 'count': 0}

    for attempt in range(API_RETRY_ATTEMPTS + 1):
        try:
            print(f"Sending API request for {request_desc} using {model_name} (Attempt {attempt + 1}/{API_RETRY_ATTEMPTS + 1})...")
            response = await model.generate_content_async(
                contents=contents,
                generation_config=generation_config,
                safety_settings=safety_settings or STANDARD_SAFETY_SETTINGS,
                stream=stream
            )

            # --- Success Logging ---
            elapsed_time = time.monotonic() - start_time
            stats = cog.api_stats.setdefault(model_name, dict(empty_stats))
            stats['success'] += 1
            stats['total_time'] += elapsed_time
            stats['count'] += 1
            print(f"API request successful for {request_desc} ({model_name}) in {elapsed_time:.2f}s.")
            return response  # Success

        except google_exceptions.ResourceExhausted as e:
            error_msg = f"Rate limit error (ResourceExhausted) for {request_desc}: {e}"
            print(f"{error_msg} (Attempt {attempt + 1})")
            last_exception = e
            if attempt < API_RETRY_ATTEMPTS:
                cog.api_stats.setdefault(model_name, dict(empty_stats))['retries'] += 1
                wait_time = API_RETRY_DELAY * (2 ** attempt)  # Exponential backoff
                print(f"Waiting {wait_time:.2f} seconds before retrying...")
                await asyncio.sleep(wait_time)
                continue
            else:
                break  # Max retries reached

        except (google_exceptions.InternalServerError, google_exceptions.ServiceUnavailable) as e:
            error_msg = f"API server error ({type(e).__name__}) for {request_desc}: {e}"
            print(f"{error_msg} (Attempt {attempt + 1})")
            last_exception = e
            if attempt < API_RETRY_ATTEMPTS:
                cog.api_stats.setdefault(model_name, dict(empty_stats))['retries'] += 1
                wait_time = API_RETRY_DELAY * (2 ** attempt)  # Exponential backoff
                print(f"Waiting {wait_time:.2f} seconds before retrying...")
                await asyncio.sleep(wait_time)
                continue
            else:
                break  # Max retries reached

        except google_exceptions.InvalidArgument as e:
            # Usually indicates a problem with the request itself (e.g., bad schema, unsupported format)
            error_msg = f"Invalid argument error for {request_desc}: {e}"
            print(error_msg)
            last_exception = e
            break  # Non-retryable

        except asyncio.TimeoutError:
            # Client-side timeout; retried with linear backoff
            error_msg = f"Client-side request timed out for {request_desc} (Attempt {attempt + 1})"
            print(error_msg)
            last_exception = asyncio.TimeoutError(error_msg)
            if attempt < API_RETRY_ATTEMPTS:
                cog.api_stats.setdefault(model_name, dict(empty_stats))['retries'] += 1
                await asyncio.sleep(API_RETRY_DELAY * (attempt + 1))
                continue
            else:
                break

        except Exception as e:
            # Treat unexpected errors as non-retryable
            error_msg = f"Unexpected error during API call for {request_desc} (Attempt {attempt + 1}): {type(e).__name__}: {e}"
            print(error_msg)
            traceback.print_exc()
            last_exception = e
            break

    # --- Failure Logging ---
    elapsed_time = time.monotonic() - start_time
    stats = cog.api_stats.setdefault(model_name, dict(empty_stats))
    stats['failure'] += 1
    stats['total_time'] += elapsed_time
    stats['count'] += 1
    print(f"API request failed for {request_desc} ({model_name}) after {attempt + 1} attempts in {elapsed_time:.2f}s.")

    # Raise the last encountered exception or a generic one
    raise last_exception or Exception(f"API request failed for {request_desc} after {API_RETRY_ATTEMPTS + 1} attempts.")
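
# Retry timing sketch (illustrative values; the real ones come from .config): with
# API_RETRY_DELAY = 1.0 and API_RETRY_ATTEMPTS = 3, retryable server/rate-limit errors
# wait 1s, 2s, then 4s (API_RETRY_DELAY * 2**attempt) between attempts, while
# client-side timeouts back off linearly (1s, 2s, 3s).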

# --- JSON Parsing and Validation Helper ---
def parse_and_validate_json_response(
    response_text: Optional[str],
    schema: Dict[str, Any],
    context_description: str
) -> Optional[Dict[str, Any]]:
    """
    Parses the AI's response text, attempting to extract and validate a JSON object against a schema.

    Args:
        response_text: The raw text content from the AI response.
        schema: The JSON schema (as a dictionary) to validate against.
        context_description: A description for logging purposes.

    Returns:
        A parsed and validated dictionary if successful, None otherwise.
    """
    if response_text is None:
        print(f"Parsing ({context_description}): Response text is None.")
        return None

    parsed_data = None
    raw_json_text = response_text  # Start with the full text

    # Attempt 1: Try parsing the whole string directly
    try:
        parsed_data = json.loads(raw_json_text)
        print(f"Parsing ({context_description}): Successfully parsed entire response as JSON.")
    except json.JSONDecodeError:
        # Attempt 2: Extract a JSON object, handling optional markdown fences
        # and potential leading/trailing text
        json_match = re.search(r'```(?:json)?\s*(\{.*\})\s*```|(\{.*\})', response_text, re.DOTALL | re.MULTILINE)
        if json_match:
            json_str = json_match.group(1) or json_match.group(2)
            if json_str:
                raw_json_text = json_str  # Use the extracted string for parsing
                try:
                    parsed_data = json.loads(raw_json_text)
                    print(f"Parsing ({context_description}): Successfully extracted and parsed JSON using regex.")
                except json.JSONDecodeError as e_inner:
                    print(f"Parsing ({context_description}): Regex found potential JSON, but it failed to parse: {e_inner}\nContent: {raw_json_text[:500]}")
                    parsed_data = None
            else:
                print(f"Parsing ({context_description}): Regex matched, but failed to capture JSON content.")
                parsed_data = None
        else:
            print(f"Parsing ({context_description}): Could not parse directly or extract JSON object using regex.\nContent: {raw_json_text[:500]}")
            parsed_data = None

    # Validation step
    if parsed_data is None:
        # Parsing failed before validation could occur
        return None
    if not isinstance(parsed_data, dict):
        print(f"Parsing ({context_description}): Parsed data is not a dictionary: {type(parsed_data)}")
        return None  # Fail validation if not a dict
    try:
        jsonschema.validate(instance=parsed_data, schema=schema)
        print(f"Parsing ({context_description}): JSON successfully validated against schema.")
        # Ensure the main response-schema keys exist after validation.
        # Note: these defaults are also added to internal task responses that omit them.
        parsed_data.setdefault("should_respond", False)
        parsed_data.setdefault("content", None)
        parsed_data.setdefault("react_with_emoji", None)
        return parsed_data
    except jsonschema.ValidationError as e:
        print(f"Parsing ({context_description}): JSON failed schema validation: {e.message}")
        # Optionally log more details: e.path, e.schema_path, e.instance
        return None  # Validation failed
    except Exception as e:
        print(f"Parsing ({context_description}): Unexpected error during JSON schema validation: {e}")
        return None
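
# Minimal usage sketch (hypothetical schema and input, for illustration only):
#
#   demo_schema = {"type": "object",
#                  "properties": {"should_respond": {"type": "boolean"}},
#                  "required": ["should_respond"]}
#   parse_and_validate_json_response('```json\n{"should_respond": true}\n```',
#                                    demo_schema, "demo")
#   # -> {"should_respond": True, "content": None, "react_with_emoji": None}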

# --- Tool Processing ---
async def process_requested_tools(cog: 'WheatleyCog', function_call: 'generative_models.FunctionCall') -> 'Part':
    """
    Process a tool request specified by the AI's FunctionCall response.

    Args:
        cog: The WheatleyCog instance.
        function_call: The FunctionCall object from the GenerationResponse.

    Returns:
        A Part object containing the tool result or error, formatted for the follow-up API call.
    """
    function_name = function_call.name
    # Convert the Struct field arguments to a standard Python dict
    function_args = dict(function_call.args) if function_call.args else {}
    tool_result_content = None

    print(f"Processing tool request: {function_name} with args: {function_args}")
    tool_start_time = time.monotonic()
    empty_stats = {'success': 0, 'failure': 0, 'total_time': 0.0, 'count': 0}

    if function_name in TOOL_MAPPING:
        try:
            tool_func = TOOL_MAPPING[function_name]
            # Execute the mapped function; tool implementations take the cog plus
            # the keyword arguments supplied by the model
            result = await tool_func(cog, **function_args)

            # --- Tool Success Logging ---
            tool_elapsed_time = time.monotonic() - tool_start_time
            stats = cog.tool_stats.setdefault(function_name, dict(empty_stats))
            stats['success'] += 1
            stats['total_time'] += tool_elapsed_time
            stats['count'] += 1
            print(f"Tool '{function_name}' executed successfully in {tool_elapsed_time:.2f}s.")

            # Prepare the result for the API - must be JSON serializable, typically a dict
            if not isinstance(result, dict):
                # Wrap common types in a dict; fall back to str() for anything else
                if isinstance(result, (str, int, float, bool, list)) or result is None:
                    result = {"result": result}
                else:
                    print(f"Warning: Tool '{function_name}' returned non-standard type {type(result)}. Attempting str conversion.")
                    result = {"result": str(result)}
            tool_result_content = result
        except Exception as e:
            # --- Tool Failure Logging ---
            tool_elapsed_time = time.monotonic() - tool_start_time
            stats = cog.tool_stats.setdefault(function_name, dict(empty_stats))
            stats['failure'] += 1
            stats['total_time'] += tool_elapsed_time
            stats['count'] += 1
            error_message = f"Error executing tool {function_name}: {type(e).__name__}: {str(e)}"
            print(f"{error_message} (Took {tool_elapsed_time:.2f}s)")
            traceback.print_exc()
            tool_result_content = {"error": error_message}
    else:
        # --- Tool Not Found Logging ---
        tool_elapsed_time = time.monotonic() - tool_start_time
        # Log the attempt even if the tool was not found
        stats = cog.tool_stats.setdefault(function_name, dict(empty_stats))
        stats['failure'] += 1
        stats['total_time'] += tool_elapsed_time
        stats['count'] += 1
        error_message = f"Tool '{function_name}' not found or implemented."
        print(f"{error_message} (Took {tool_elapsed_time:.2f}s)")
        tool_result_content = {"error": error_message}

    # Return the result formatted as a Part for the API
    return Part.from_function_response(name=function_name, response=tool_result_content)
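
# Shape implied by the dispatch above (sketch; the signature is an assumption, the
# real tools live in .tools):
#
#   async def get_conversation_summary(cog, channel_id: str, message_limit: int = 15) -> dict:
#       ...
#   TOOL_MAPPING = {"get_conversation_summary": get_conversation_summary}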

# --- Main AI Response Function ---
async def get_ai_response(cog: 'WheatleyCog', message: discord.Message, model_name: Optional[str] = None) -> Dict[str, Any]:
    """
    Gets responses from the Vertex AI Gemini API, handling potential tool usage and returning
    the final parsed response.

    Args:
        cog: The WheatleyCog instance.
        message: The triggering discord.Message.
        model_name: Optional override for the AI model name (e.g., "gemini-1.5-pro-preview-0409").

    Returns:
        A dictionary containing:
        - "initial_response": Parsed JSON data from the initial call (None when a tool was used).
        - "final_response": Parsed JSON data from the final AI call (or None if parsing/validation fails).
        - "error": An error message string if a critical error occurred, otherwise None.
        - "fallback_initial": Optional minimal response if initial parsing failed critically.
    """
    if not PROJECT_ID or not LOCATION:
        return {"final_response": None, "error": "Google Cloud Project ID or Location not configured"}

    channel_id = message.channel.id
    user_id = message.author.id
    initial_parsed_data = None  # Stores the initial parsed result
    final_parsed_data = None
    error_message = None
    fallback_response = None

    try:
        # --- Build Prompt Components ---
        final_system_prompt = await build_dynamic_system_prompt(cog, message)
        conversation_context_messages = gather_conversation_context(cog, channel_id, message.id)
        memory_context = await get_memory_context(cog, message)

        # --- Initialize Model ---
        # Tools are passed during model initialization in the Vertex AI SDK;
        # combine the tool declarations into a single Tool object
        vertex_tool = Tool(function_declarations=TOOLS) if TOOLS else None
        model = GenerativeModel(
            model_name or DEFAULT_MODEL,
            system_instruction=final_system_prompt,
            tools=[vertex_tool] if vertex_tool else None
        )

        # --- Prepare Message History (Contents) ---
        contents: List[Content] = []

        # Add memory context if available.
        # System messages aren't directly supported in the 'contents' list for multi-turn
        # (unlike OpenAI); they are handled via the model's 'system_instruction' instead.
        if memory_context:
            print("Memory context available, relying on system_instruction.")
            # If needed, it could be added as a 'model' role message before the user
            # messages, but that might confuse the turn structure:
            # contents.append(Content(role="model", parts=[Part.from_text(f"System Note: {memory_context}")]))

        # Add conversation history
        for msg in conversation_context_messages:
            role = msg.get("role", "user")  # Default to user if role missing
            # Map roles where necessary (e.g., 'assistant' -> 'model')
            if role == "assistant":
                role = "model"
            elif role == "system":
                # Skip system messages here; handled by system_instruction
                continue
            # Handle potential multimodal content in history (if stored that way)
            if isinstance(msg.get("content"), list):
                parts = []
                for part in msg["content"]:
                    if part["type"] == "text":
                        parts.append(Part.from_text(part["text"]))
                    elif part["type"] == "image_url":
                        # Derive the MIME type from the data-URL prefix (e.g., "data:image/png;base64,...")
                        url = part["image_url"]["url"]
                        parts.append(Part.from_uri(url, mime_type=url.split(";")[0].split(":")[1]))
                if parts:
                    contents.append(Content(role=role, parts=parts))
            elif isinstance(msg.get("content"), str):
                contents.append(Content(role=role, parts=[Part.from_text(msg["content"])]))
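
        # At this point `contents` holds alternating turns, e.g. (illustrative):
        #   [Content(role="user", parts=[Part.from_text("Alice: hi")]),
        #    Content(role="model", parts=[Part.from_text("...")])]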

        # --- Prepare the current message content (potentially multimodal) ---
        current_message_parts = []
        formatted_current_message = format_message(cog, message)

        # --- Construct text content, including reply context if applicable ---
        text_content = ""
        if formatted_current_message.get("is_reply") and formatted_current_message.get("replied_to_author_name"):
            reply_author = formatted_current_message["replied_to_author_name"]
            reply_content = formatted_current_message.get("replied_to_content", "...")  # Ellipsis if content missing
            # Truncate long replied content to keep context concise
            max_reply_len = 150
            if len(reply_content) > max_reply_len:
                reply_content = reply_content[:max_reply_len] + "..."
            text_content += f"(Replying to {reply_author}: \"{reply_content}\")\n"

        # Add current message author and content
        text_content += f"{formatted_current_message['author']['display_name']}: {formatted_current_message['content']}"

        # Add mention details
        if formatted_current_message.get("mentioned_users_details"):
            mentions_str = ", ".join([f"{m['display_name']}(id:{m['id']})" for m in formatted_current_message["mentioned_users_details"]])
            text_content += f"\n(Message Details: Mentions=[{mentions_str}])"

        current_message_parts.append(Part.from_text(text_content))
        # --- End text content construction ---

        if message.attachments:
            print(f"Processing {len(message.attachments)} attachments for message {message.id}")
            for attachment in message.attachments:
                mime_type = attachment.content_type
                file_url = attachment.url
                filename = attachment.filename

                # Check whether the MIME type is supported for URI input by Gemini.
                # Expand this list based on the Gemini documentation for types supported via URI.
                supported_mime_prefixes = ["image/", "video/", "audio/", "text/plain", "application/pdf"]
                is_supported = False
                if mime_type:
                    for prefix in supported_mime_prefixes:
                        if mime_type.startswith(prefix):
                            is_supported = True
                            break
                    # Add specific non-prefixed types here if needed, e.g.:
                    # if mime_type in ["application/vnd.google-apps.document", ...]: is_supported = True

                if is_supported and file_url:
                    try:
                        # 1. Add a text part instructing the AI about the file
                        instruction_text = f"User attached a file: '{filename}' (Type: {mime_type}). Analyze this file from the following URI and incorporate your understanding into your response."
                        current_message_parts.append(Part.from_text(instruction_text))
                        print(f"Added text instruction for attachment: {filename}")

                        # 2. Add the URI part, stripping MIME parameters like '; charset=...'
                        clean_mime_type = mime_type.split(';')[0]
                        current_message_parts.append(Part.from_uri(uri=file_url, mime_type=clean_mime_type))
                        print(f"Added URI part for attachment: {filename} ({clean_mime_type}) using URL: {file_url}")
                    except Exception as e:
                        print(f"Error creating Part for attachment {filename} ({mime_type}): {e}")
                        # Add a text part indicating the error
                        current_message_parts.append(Part.from_text(f"(System Note: Failed to process attachment '{filename}' - {e})"))
                else:
                    print(f"Skipping unsupported or invalid attachment: {filename} (Type: {mime_type}, URL: {file_url})")
                    # Inform the AI that an unsupported file was attached
                    current_message_parts.append(Part.from_text(f"(System Note: User attached an unsupported file '{filename}' of type '{mime_type}' which cannot be processed.)"))

        # Ensure there's always *some* content part, even if only text or errors
        if current_message_parts:
            contents.append(Content(role="user", parts=current_message_parts))
        else:
            print("Warning: No content parts generated for user message.")
            contents.append(Content(role="user", parts=[Part.from_text("")]))

        # --- First API Call (Check for Tool Use) ---
        print("Making initial API call to check for tool use...")
        generation_config_initial = GenerationConfig(
            temperature=0.75,
            max_output_tokens=10000,  # Adjust as needed
            # No response schema for the initial call; just checking for function calls
        )
        initial_response = await call_vertex_api_with_retry(
            cog=cog,
            model=model,
            contents=contents,
            generation_config=generation_config_initial,
            safety_settings=STANDARD_SAFETY_SETTINGS,
            request_desc=f"Initial response check for message {message.id}"
        )

        # --- Log Raw Request and Response ---
        try:
            # Log the request payload (parts converted to strings for readability)
            request_payload_log = [{"role": c.role, "parts": [str(p) for p in c.parts]} for c in contents]
            print(f"--- Raw API Request (Initial Call) ---\n{json.dumps(request_payload_log, indent=2)}\n------------------------------------")
            # Log the raw response object
            print(f"--- Raw API Response (Initial Call) ---\n{initial_response}\n-----------------------------------")
        except Exception as log_e:
            print(f"Error logging raw request/response: {log_e}")
        # --- End Logging ---

        if not initial_response or not initial_response.candidates:
            raise Exception("Initial API call returned no response or candidates.")

        # --- Check for Tool Call FIRST ---
        candidate = initial_response.candidates[0]
        finish_reason = getattr(candidate, 'finish_reason', None)
        function_call = None
        function_call_part_content = None  # Stores the AI's request message content

        # Check primarily for the *presence* of a function call part,
        # as finish_reason can be STOP even when a function call is present.
        if hasattr(candidate, 'content') and candidate.content.parts:
            for part in candidate.content.parts:
                if hasattr(part, 'function_call'):
                    function_call = part.function_call
                    if function_call:
                        # Keep the whole content containing the call, to add to history later
                        function_call_part_content = candidate.content
                        print(f"AI requested tool (found function_call part): {function_call.name}")
                        break  # Found a valid function call part
                    else:
                        # The attribute exists but is None (unexpected case)
                        print("Warning: Found part with 'function_call' attribute, but its value was None.")

        # --- Process Tool Call or Handle Direct Response ---
        if function_call and function_call_part_content:
            # --- Tool Call Path ---
            initial_parsed_data = None  # No initial JSON is expected when a tool is called
            # Process the tool request
            tool_response_part = await process_requested_tools(cog, function_call)

            # Append the AI's request and the tool's response to the history
            contents.append(candidate.content)  # The AI's function call request message
            contents.append(Content(role="function", parts=[tool_response_part]))  # The function response part

            # --- Second API Call (Get Final Response After Tool) ---
            print("Making follow-up API call with tool results...")
            # Initialize a NEW model instance WITHOUT tools for the follow-up call.
            # This prevents the InvalidArgument error raised when a response schema
            # is specified on a tools-enabled model.
            model_final = GenerativeModel(
                model_name or DEFAULT_MODEL,  # Same model name
                system_instruction=final_system_prompt  # Same system prompt; 'tools' omitted
            )

            # Preprocess the schema before passing it to GenerationConfig
            processed_response_schema = _preprocess_schema_for_vertex(RESPONSE_SCHEMA['schema'])
            generation_config_final = GenerationConfig(
                temperature=0.75,  # Keep the original temperature for the final response
                max_output_tokens=10000,  # Keep the original max tokens
                response_mime_type="application/json",
                response_schema=processed_response_schema  # Use the preprocessed schema
            )
            final_response_obj = await call_vertex_api_with_retry(
                cog=cog,
                model=model_final,  # The new model instance WITHOUT tools
                contents=contents,  # History now includes the tool call/response
                generation_config=generation_config_final,
                safety_settings=STANDARD_SAFETY_SETTINGS,
                request_desc=f"Follow-up response for message {message.id} after tool execution"
            )
            if not final_response_obj or not final_response_obj.candidates:
                raise Exception("Follow-up API call returned no response or candidates.")

            final_response_text = _get_response_text(final_response_obj)
            final_parsed_data = parse_and_validate_json_response(
                final_response_text, RESPONSE_SCHEMA['schema'], "final response after tools"
            )

            # Handle validation failure with a single re-prompt attempt
            if final_parsed_data is None:
                print("Warning: Final response failed validation. Attempting re-prompt (basic)...")
                # Construct a basic re-prompt message
                contents.append(final_response_obj.candidates[0].content)  # Add the invalid response
                contents.append(Content(role="user", parts=[Part.from_text(
                    "Your previous JSON response was invalid or did not match the required schema. "
                    f"Please provide the response again, strictly adhering to this schema:\n{json.dumps(RESPONSE_SCHEMA['schema'], indent=2)}"
                )]))
                # Retry the final call (again on the tools-free model, since the
                # response schema is still set in the generation config)
                retry_response_obj = await call_vertex_api_with_retry(
                    cog=cog, model=model_final, contents=contents,
                    generation_config=generation_config_final, safety_settings=STANDARD_SAFETY_SETTINGS,
                    request_desc=f"Re-prompt validation failure for message {message.id}"
                )
                if retry_response_obj and retry_response_obj.candidates:
                    final_response_text = _get_response_text(retry_response_obj)
                    final_parsed_data = parse_and_validate_json_response(
                        final_response_text, RESPONSE_SCHEMA['schema'], "re-prompted final response"
                    )
                    if final_parsed_data is None:
                        print("Critical Error: Re-prompted response still failed validation.")
                        error_message = "Failed to get valid JSON response after re-prompting."
                else:
                    error_message = "Failed to get response after re-prompting."
            # final_parsed_data is now set (or None if failed) after tool use and potential re-prompt
# final_parsed_data is now set (or None if failed) after tool use and potential re-prompt
else:
# --- No Tool Call Path ---
print("No tool call requested by AI. Processing initial response as final.")
# Attempt to parse the initial response text directly.
initial_response_text = _get_response_text(initial_response) # Use helper
# Validate against the final schema because this IS the final response.
final_parsed_data = parse_and_validate_json_response(
initial_response_text, RESPONSE_SCHEMA['schema'], "final response (no tools)"
)
initial_parsed_data = final_parsed_data # Keep initial_parsed_data consistent for return dict
if final_parsed_data is None:
# This means the initial response failed validation.
print("Critical Error: Initial response failed validation (no tools).")
error_message = "Failed to parse/validate initial AI JSON response."
# Create a basic fallback if the bot was mentioned
replied_to_bot = message.reference and message.reference.resolved and message.reference.resolved.author == cog.bot.user
if cog.bot.user.mentioned_in(message) or replied_to_bot:
fallback_response = {"should_respond": True, "content": "...", "react_with_emoji": ""}
# initial_parsed_data is not used in this path, only final_parsed_data matters
except Exception as e:
error_message = f"Error in get_ai_response main loop for message {message.id}: {type(e).__name__}: {str(e)}"
print(error_message)
import traceback
traceback.print_exc()
# Ensure both are None on critical error
initial_parsed_data = None
final_parsed_data = None
return {
"initial_response": initial_parsed_data, # Return parsed initial data
"final_response": final_parsed_data, # Return parsed final data
"error": error_message,
"fallback_initial": fallback_response
}
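
# Call-site sketch (hypothetical; the cog's message handler is the real caller):
#
#   result = await get_ai_response(cog, message)
#   if result["error"]:
#       ...  # log and/or use result["fallback_initial"]
#   else:
#       data = result["final_response"]
#       if data and data.get("should_respond") and data.get("content"):
#           await message.channel.send(data["content"])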

# --- Proactive AI Response Function ---
async def get_proactive_ai_response(cog: 'WheatleyCog', message: discord.Message, trigger_reason: str) -> Dict[str, Any]:
    """Generates a proactive response based on a specific trigger using Vertex AI."""
    if not PROJECT_ID or not LOCATION:
        return {"should_respond": False, "content": None, "react_with_emoji": None, "error": "Google Cloud Project ID or Location not configured"}

    print(f"--- Proactive Response Triggered: {trigger_reason} ---")
    channel_id = message.channel.id
    final_parsed_data = None
    error_message = None
    plan = None  # Stores the generated plan

    try:
        # --- Build Context for Planning ---
        # Gather relevant context: trigger reason, mood, recent messages, topics, sentiment, interests
        planning_context_parts = [
            f"Proactive Trigger Reason: {trigger_reason}",
            f"Current Mood: {cog.current_mood}",
        ]
        # Add recent messages summary
        summary_data = await get_conversation_summary(cog, str(channel_id), message_limit=15)  # Use tool function
        if summary_data and not summary_data.get("error"):
            planning_context_parts.append(f"Recent Conversation Summary: {summary_data['summary']}")
        # Add active topics
        active_topics_data = cog.active_topics.get(channel_id)
        if active_topics_data and active_topics_data.get("topics"):
            topics_str = ", ".join([f"{t['topic']} ({t['score']:.1f})" for t in active_topics_data["topics"][:3]])
            planning_context_parts.append(f"Active Topics: {topics_str}")
        # Add sentiment
        sentiment_data = cog.conversation_sentiment.get(channel_id)
        if sentiment_data:
            planning_context_parts.append(f"Conversation Sentiment: {sentiment_data.get('overall', 'N/A')} (Intensity: {sentiment_data.get('intensity', 0):.1f})")
        # Add Wheatley's interests (likely disabled/removed for Wheatley, so this may fetch nothing)
        try:
            interests = await cog.memory_manager.get_interests(limit=5)
            if interests:
                interests_str = ", ".join([f"{t} ({l:.1f})" for t, l in interests])
                planning_context_parts.append(f"Wheatley's Interests: {interests_str}")
        except Exception as int_e:
            print(f"Error getting interests for planning: {int_e}")
        planning_context = "\n".join(planning_context_parts)

        # --- Planning Step ---
        print("Generating proactive response plan...")
        planning_prompt_messages = [
            {"role": "system", "content": "You are Wheatley's planning module. Analyze the context and trigger reason to decide if Wheatley should respond proactively and, if so, outline a plan (goal, key info, tone). Focus on natural, in-character engagement (rambling, insecure, bad ideas). Respond ONLY with JSON matching the provided schema."},
            {"role": "user", "content": f"Context:\n{planning_context}\n\nBased on this context and the trigger reason, create a plan for Wheatley's proactive response."}
        ]
        plan = await get_internal_ai_json_response(
            cog=cog,
            prompt_messages=planning_prompt_messages,
            task_description=f"Proactive Planning ({trigger_reason})",
            response_schema_dict=PROACTIVE_PLAN_SCHEMA['schema'],
            model_name=FALLBACK_MODEL,  # A potentially faster/cheaper model for planning
            temperature=0.5,
            max_tokens=300
        )

        if not plan or not plan.get("should_respond"):
            reason = plan.get('reasoning', 'Planning failed or decided against responding.') if plan else 'Planning failed.'
            print(f"Proactive response aborted by plan: {reason}")
            return {"should_respond": False, "content": None, "react_with_emoji": None, "note": f"Plan: {reason}"}

        print(f"Proactive Plan Generated: Goal='{plan.get('response_goal', 'N/A')}', Reasoning='{plan.get('reasoning', 'N/A')}'")

        # --- Build Final Proactive Prompt using Plan ---
        persistent_traits = await cog.memory_manager.get_all_personality_traits()
        if not persistent_traits:
            persistent_traits = {}  # Wheatley doesn't use these Gurt traits

        final_proactive_prompt_parts = [
            "You are Wheatley, an Aperture Science Personality Core. Your tone is rambling, insecure, uses British slang, and you often have terrible ideas you think are brilliant.",
            # Gurt-specific traits and mood references are omitted for Wheatley.
            # Incorporate plan details:
            f"You decided to respond proactively (maybe?). Trigger Reason: {trigger_reason}.",
            f"Your Brilliant Plan (Goal): {plan.get('response_goal', 'Say something... probably helpful?')}.",
            f"Reasoning: {plan.get('reasoning', 'N/A')}.",
        ]
        if plan.get('key_info_to_include'):
            info_str = "; ".join(plan['key_info_to_include'])
            final_proactive_prompt_parts.append(f"Consider mentioning: {info_str}")
        if plan.get('suggested_tone'):
            final_proactive_prompt_parts.append(f"Adjust tone to be: {plan['suggested_tone']}")
        final_proactive_prompt_parts.append("Generate a casual, in-character message based on the plan and context. Keep it relatively short and natural-sounding.")
        final_proactive_system_prompt = "\n\n".join(final_proactive_prompt_parts)

        # --- Initialize Final Model ---
        model = GenerativeModel(
            model_name=DEFAULT_MODEL,
            system_instruction=final_proactive_system_prompt
        )

        # --- Prepare Final Contents ---
        contents = [
            Content(role="user", parts=[Part.from_text(
                f"Generate the response based on your plan. **CRITICAL: Your response MUST be ONLY the raw JSON object matching this schema:**\n\n{json.dumps(RESPONSE_SCHEMA['schema'], indent=2)}\n\n**Ensure nothing precedes or follows the JSON.**"
            )])
        ]

        # --- Call Final LLM API ---
        # Preprocess the schema before passing it to GenerationConfig
        processed_response_schema_proactive = _preprocess_schema_for_vertex(RESPONSE_SCHEMA['schema'])
        generation_config_final = GenerationConfig(
            temperature=0.8,  # Original proactive temperature
            max_output_tokens=200,
            response_mime_type="application/json",
            response_schema=processed_response_schema_proactive  # Use the preprocessed schema
        )
        response_obj = await call_vertex_api_with_retry(
            cog=cog,
            model=model,
            contents=contents,
            generation_config=generation_config_final,
            safety_settings=STANDARD_SAFETY_SETTINGS,
            request_desc=f"Final proactive response for channel {channel_id} ({trigger_reason})"
        )
        if not response_obj or not response_obj.candidates:
            raise Exception("Final proactive API call returned no response or candidates.")

        # --- Parse and Validate Final Response ---
        final_response_text = _get_response_text(response_obj)
        final_parsed_data = parse_and_validate_json_response(
            final_response_text, RESPONSE_SCHEMA['schema'], f"final proactive response ({trigger_reason})"
        )

        if final_parsed_data is None:
            print(f"Warning: Failed to parse/validate final proactive JSON response for {trigger_reason}.")
            final_parsed_data = {"should_respond": False, "content": None, "react_with_emoji": None, "note": "Fallback - Failed to parse/validate final proactive JSON"}
        else:
            # --- Cache Bot Response ---
            if final_parsed_data.get("should_respond") and final_parsed_data.get("content"):
                bot_response_cache_entry = {
                    "id": f"bot_proactive_{message.id}_{int(time.time())}",
                    "author": {"id": str(cog.bot.user.id), "name": cog.bot.user.name, "display_name": cog.bot.user.display_name, "bot": True},
                    "content": final_parsed_data.get("content", ""),
                    "created_at": datetime.datetime.now().isoformat(),
                    "attachments": [], "embeds": False, "mentions": [], "replied_to_message_id": None,
                    "channel": message.channel, "guild": message.guild, "reference": None, "mentioned_users_details": []
                }
                cog.message_cache['by_channel'].setdefault(channel_id, []).append(bot_response_cache_entry)
                cog.message_cache['global_recent'].append(bot_response_cache_entry)
                cog.bot_last_spoke[channel_id] = time.time()
                # Gurt-specific participation tracking removed

    except Exception as e:
        error_message = f"Error getting proactive AI response for channel {channel_id} ({trigger_reason}): {type(e).__name__}: {str(e)}"
        print(error_message)
        final_parsed_data = {"should_respond": False, "content": None, "react_with_emoji": None, "error": error_message}

    # Ensure default keys exist
    final_parsed_data.setdefault("should_respond", False)
    final_parsed_data.setdefault("content", None)
    final_parsed_data.setdefault("react_with_emoji", None)
    if error_message and "error" not in final_parsed_data:
        final_parsed_data["error"] = error_message

    return final_parsed_data

# --- Internal AI Call for Specific Tasks ---
async def get_internal_ai_json_response(
    cog: 'WheatleyCog',
    prompt_messages: List[Dict[str, Any]],  # OpenAI-style message dicts
    task_description: str,
    response_schema_dict: Dict[str, Any],  # Expect schema as dict
    model_name: Optional[str] = None,
    temperature: float = 0.7,
    max_tokens: int = 5000,
) -> Optional[Dict[str, Any]]:
    """
    Makes a Vertex AI call expecting a specific JSON response format for internal tasks.

    Args:
        cog: The WheatleyCog instance.
        prompt_messages: List of message dicts in OpenAI format
            ({'role': 'system'/'user'/'assistant', 'content': '...'}); 'assistant'
            is mapped to 'model' and 'system' messages become the system instruction.
        task_description: Description for logging.
        response_schema_dict: The expected JSON output schema as a dictionary.
        model_name: Optional model override.
        temperature: Generation temperature.
        max_tokens: Max output tokens.

    Returns:
        The parsed and validated JSON dictionary if successful, None otherwise.
    """
    if not PROJECT_ID or not LOCATION:
        print(f"Error in get_internal_ai_json_response ({task_description}): GCP Project/Location not set.")
        return None

    final_parsed_data = None
    error_occurred = None
    request_payload_for_logging = {}  # For logging

    try:
        # --- Convert prompt messages to Vertex AI Content format ---
        contents: List[Content] = []
        system_instruction = None
        for msg in prompt_messages:
            role = msg.get("role", "user")
            content_text = msg.get("content", "")
            if role == "system":
                # Use the first system message as system_instruction,
                # appending any subsequent ones to it
                if system_instruction is None:
                    system_instruction = content_text
                else:
                    system_instruction += "\n\n" + content_text
                continue  # System messages are not added to the contents list
            elif role == "assistant":
                role = "model"

            # --- Process content (string or list) ---
            content_value = msg.get("content")
            message_parts: List[Part] = []  # Parts for this message

            if isinstance(content_value, str):
                # Simple string content
                message_parts.append(Part.from_text(content_value))
            elif isinstance(content_value, list):
                # List content (e.g., multimodal from ProfileUpdater)
                for part_data in content_value:
                    part_type = part_data.get("type")
                    if part_type == "text":
                        text = part_data.get("text", "")
                        message_parts.append(Part.from_text(text))
                    elif part_type == "image_data":
                        mime_type = part_data.get("mime_type")
                        base64_data = part_data.get("data")
                        if mime_type and base64_data:
                            try:
                                image_bytes = base64.b64decode(base64_data)
                                message_parts.append(Part.from_data(data=image_bytes, mime_type=mime_type))
                            except Exception as decode_err:
                                print(f"Error decoding/adding image part in get_internal_ai_json_response: {decode_err}")
                                # Add a placeholder text part indicating the failure
                                message_parts.append(Part.from_text("(System Note: Failed to process an image part)"))
                        else:
                            print("Warning: image_data part missing mime_type or data.")
                    else:
                        print(f"Warning: Unknown part type '{part_type}' in internal prompt message.")
            else:
                print(f"Warning: Unexpected content type '{type(content_value)}' in internal prompt message.")

            # Add the content object if parts were generated
            if message_parts:
                contents.append(Content(role=role, parts=message_parts))
            else:
                print(f"Warning: No parts generated for message role '{role}'.")

        # Add the critical JSON instruction to the last user message, or as a new user message
        json_instruction_content = (
            f"**CRITICAL: Your response MUST consist *only* of the raw JSON object itself, matching this schema:**\n"
            f"{json.dumps(response_schema_dict, indent=2)}\n"
            f"**Ensure nothing precedes or follows the JSON.**"
        )
        if contents and contents[-1].role == "user":
            contents[-1].parts.append(Part.from_text(f"\n\n{json_instruction_content}"))
        else:
            contents.append(Content(role="user", parts=[Part.from_text(json_instruction_content)]))

        # --- Initialize Model ---
        # No tools are needed for internal JSON tasks
        model = GenerativeModel(
            model_name=model_name or DEFAULT_MODEL,
            system_instruction=system_instruction
        )

        # --- Prepare Generation Config ---
        # Preprocess the schema before passing it to GenerationConfig
        processed_schema_internal = _preprocess_schema_for_vertex(response_schema_dict)
        generation_config = GenerationConfig(
            temperature=temperature,
            max_output_tokens=max_tokens,
            response_mime_type="application/json",
            response_schema=processed_schema_internal  # Use the preprocessed schema
        )

        # Prepare an approximate payload for logging
        request_payload_for_logging = {
            "model": model._model_name,
            "system_instruction": system_instruction,
            "contents": [  # Simplified representation for logging
                {"role": c.role, "parts": [p.text if hasattr(p, 'text') else str(type(p)) for p in c.parts]}
                for c in contents
            ],
            # GenerationConfig object; json.dumps below falls back to str() for it
            "generation_config": generation_config
        }

        # --- Detailed logging of the raw request ---
        try:
            print(f"--- Raw request payload for {task_description} ---")
            # Pretty-print; default=str handles non-serializable objects like GenerationConfig
            print(json.dumps(request_payload_for_logging, indent=2, default=str))
            print("--- End Raw request payload ---")
        except Exception as req_log_e:
            print(f"Error logging raw request payload: {req_log_e}")
            print(f"Payload causing error: {request_payload_for_logging}")

        # --- Call API ---
        response_obj = await call_vertex_api_with_retry(
            cog=cog,
            model=model,
            contents=contents,
            generation_config=generation_config,
            safety_settings=STANDARD_SAFETY_SETTINGS,
            request_desc=task_description
        )
        if not response_obj or not response_obj.candidates:
            raise Exception("Internal API call returned no response or candidates.")

        # --- Parse and Validate ---
        # This function always expects JSON, so use response_obj.text directly
        final_response_text = response_obj.text
        print(f"--- Raw response_obj.text for {task_description} ---")
        print(final_response_text)
        print("--- End Raw response_obj.text ---")

        print(f"Parsing ({task_description}): Using response_obj.text for JSON.")
        final_parsed_data = parse_and_validate_json_response(
            final_response_text, response_schema_dict, f"internal task ({task_description})"
        )

        if final_parsed_data is None:
            print(f"Warning: Internal task '{task_description}' failed JSON validation.")
            # No re-prompting for internal tasks; just return None

    except Exception as e:
        print(f"Error in get_internal_ai_json_response ({task_description}): {type(e).__name__}: {e}")
        error_occurred = e
        traceback.print_exc()
        final_parsed_data = None
    finally:
        # Log the call with the simplified payload
        try:
            await log_internal_api_call(cog, task_description, request_payload_for_logging, final_parsed_data, error_occurred)
        except Exception as log_e:
            print(f"Error logging internal API call: {log_e}")

    return final_parsed_data
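
# Usage sketch (hypothetical prompt and schema, mirroring the proactive-planning
# call above; returns the parsed dict on success, None on any failure):
#
#   result = await get_internal_ai_json_response(
#       cog=cog,
#       prompt_messages=[{"role": "system", "content": "You are a classifier..."},
#                        {"role": "user", "content": "Classify this message..."}],
#       task_description="Demo classification",
#       response_schema_dict={"type": "object",
#                             "properties": {"label": {"type": "string"}},
#                             "required": ["label"]},
#       temperature=0.2,
#       max_tokens=100,
#   )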