import discord
import aiohttp
import asyncio
import json
import base64
import re
import time
import datetime
from typing import TYPE_CHECKING, Optional, List, Dict, Any, Union, AsyncIterable
import jsonschema  # For manual JSON validation

from .tools import get_conversation_summary

# Vertex AI Imports
try:
    import vertexai
    from vertexai import generative_models
    from vertexai.generative_models import (
        GenerativeModel,
        GenerationConfig,
        Part,
        Content,
        Tool,
        FunctionDeclaration,
        GenerationResponse,
        FinishReason,
    )
    from google.api_core import exceptions as google_exceptions
    from google.cloud.storage import Client as GCSClient  # For potential image uploads
except ImportError:
    print(
        "WARNING: google-cloud-vertexai or google-cloud-storage not installed. API calls will fail."
    )

    # Define dummy classes/exceptions if library isn't installed
    class DummyGenerativeModel:
        def __init__(self, model_name, system_instruction=None, tools=None):
            pass

        async def generate_content_async(
            self, contents, generation_config=None, safety_settings=None, stream=False
        ):
            return None

    GenerativeModel = DummyGenerativeModel

    class DummyPart:
        @staticmethod
        def from_text(text):
            return None

        @staticmethod
        def from_data(data, mime_type):
            return None

        @staticmethod
        def from_uri(uri, mime_type):
            return None

        @staticmethod
        def from_function_response(name, response):
            return None

    Part = DummyPart
    Content = dict
    Tool = list
    FunctionDeclaration = object
    GenerationConfig = dict
    GenerationResponse = object
    FinishReason = object

    class DummyGoogleExceptions:
        ResourceExhausted = type("ResourceExhausted", (Exception,), {})
        InternalServerError = type("InternalServerError", (Exception,), {})
        ServiceUnavailable = type("ServiceUnavailable", (Exception,), {})
        InvalidArgument = type("InvalidArgument", (Exception,), {})
        GoogleAPICallError = type(
            "GoogleAPICallError", (Exception,), {}
        )  # Generic fallback

    google_exceptions = DummyGoogleExceptions()
    generative_models = None  # Fallback so the module-level getattr() lookups below don't raise NameError

# Relative imports for components within the 'gurt' package
from .config import (
    PROJECT_ID,
    LOCATION,
    DEFAULT_MODEL,
    FALLBACK_MODEL,
    API_TIMEOUT,
    API_RETRY_ATTEMPTS,
    API_RETRY_DELAY,
    TOOLS,
    RESPONSE_SCHEMA,
    PROACTIVE_PLAN_SCHEMA,  # Import the new schema
    TAVILY_API_KEY,
    PISTON_API_URL,
    PISTON_API_KEY,  # Import other needed configs
)
from .prompt import build_dynamic_system_prompt
from .context import (
    gather_conversation_context,
    get_memory_context,
)  # Renamed functions
from .tools import TOOL_MAPPING  # Import tool mapping
from .utils import format_message, log_internal_api_call  # Import utilities
import copy  # Needed for deep copying schemas

if TYPE_CHECKING:
    from .cog import WheatleyCog  # Import WheatleyCog for type hinting only


# --- Schema Preprocessing Helper ---
def _preprocess_schema_for_vertex(schema: Dict[str, Any]) -> Dict[str, Any]:
    """
    Recursively preprocesses a JSON schema dictionary to replace list types
    (like ["string", "null"]) with the first non-null type, making it compatible
    with Vertex AI's GenerationConfig schema requirements.

    Args:
        schema: The JSON schema dictionary to preprocess.

    Returns:
        A new, preprocessed schema dictionary.
""" if not isinstance(schema, dict): return schema # Return non-dict elements as is processed_schema = copy.deepcopy(schema) # Work on a copy for key, value in processed_schema.items(): if key == "type" and isinstance(value, list): # Find the first non-"null" type in the list first_valid_type = next( (t for t in value if isinstance(t, str) and t.lower() != "null"), None ) if first_valid_type: processed_schema[key] = first_valid_type else: # Fallback if only "null" or invalid types are present (shouldn't happen in valid schemas) processed_schema[key] = "object" # Or handle as error print( f"Warning: Schema preprocessing found list type '{value}' with no valid non-null string type. Falling back to 'object'." ) elif isinstance(value, dict): processed_schema[key] = _preprocess_schema_for_vertex( value ) # Recurse for nested objects elif isinstance(value, list): # Recurse for items within arrays (e.g., in 'properties' of array items) processed_schema[key] = [ _preprocess_schema_for_vertex(item) if isinstance(item, dict) else item for item in value ] # Handle 'properties' specifically elif key == "properties" and isinstance(value, dict): processed_schema[key] = { prop_key: _preprocess_schema_for_vertex(prop_value) for prop_key, prop_value in value.items() } # Handle 'items' specifically if it's a schema object elif key == "items" and isinstance(value, dict): processed_schema[key] = _preprocess_schema_for_vertex(value) return processed_schema # --- Helper Function to Safely Extract Text --- def _get_response_text(response: Optional["GenerationResponse"]) -> Optional[str]: """Safely extracts the text content from the first text part of a GenerationResponse.""" if not response or not response.candidates: return None try: # Iterate through parts to find the first text part for part in response.candidates[0].content.parts: # Check if the part has a 'text' attribute and it's not empty if hasattr(part, "text") and part.text: return part.text # If no text part is found (e.g., only function call or empty text parts) print( f"[_get_response_text] No text part found in candidate parts: {response.candidates[0].content.parts}" ) # Log parts structure return None except (AttributeError, IndexError) as e: # Handle cases where structure is unexpected or parts list is empty print(f"Error accessing response parts: {type(e).__name__}: {e}") return None except Exception as e: # Catch unexpected errors during access print(f"Unexpected error extracting text from response part: {e}") return None # --- Initialize Vertex AI --- try: vertexai.init(project=PROJECT_ID, location=LOCATION) print(f"Vertex AI initialized for project '{PROJECT_ID}' in location '{LOCATION}'.") except NameError: print("Vertex AI SDK not imported, skipping initialization.") except Exception as e: print(f"Error initializing Vertex AI: {e}") # --- Constants --- # Define standard safety settings (adjust as needed) # Use actual types if import succeeded, otherwise fallback to Any _HarmCategory = getattr(generative_models, "HarmCategory", Any) _HarmBlockThreshold = getattr(generative_models, "HarmBlockThreshold", Any) STANDARD_SAFETY_SETTINGS = { getattr( _HarmCategory, "HARM_CATEGORY_HATE_SPEECH", "HARM_CATEGORY_HATE_SPEECH" ): getattr(_HarmBlockThreshold, "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE"), getattr( _HarmCategory, "HARM_CATEGORY_DANGEROUS_CONTENT", "HARM_CATEGORY_DANGEROUS_CONTENT", ): getattr(_HarmBlockThreshold, "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE"), getattr( _HarmCategory, "HARM_CATEGORY_SEXUALLY_EXPLICIT", 
"HARM_CATEGORY_SEXUALLY_EXPLICIT", ): getattr(_HarmBlockThreshold, "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE"), getattr( _HarmCategory, "HARM_CATEGORY_HARASSMENT", "HARM_CATEGORY_HARASSMENT" ): getattr(_HarmBlockThreshold, "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE"), } # --- API Call Helper --- async def call_vertex_api_with_retry( cog: "WheatleyCog", model: "GenerativeModel", # Use string literal for type hint contents: List["Content"], # Use string literal for type hint generation_config: "GenerationConfig", # Use string literal for type hint safety_settings: Optional[Dict[Any, Any]], # Use Any for broader compatibility request_desc: str, stream: bool = False, ) -> Union[ "GenerationResponse", AsyncIterable["GenerationResponse"], None ]: # Use string literals """ Calls the Vertex AI Gemini API with retry logic. Args: cog: The WheatleyCog instance. model: The initialized GenerativeModel instance. contents: The list of Content objects for the prompt. generation_config: The GenerationConfig object. safety_settings: Safety settings for the request. request_desc: A description of the request for logging purposes. stream: Whether to stream the response. Returns: The GenerationResponse object or an AsyncIterable if streaming, or None on failure. Raises: Exception: If the API call fails after all retry attempts or encounters a non-retryable error. """ last_exception = None model_name = model._model_name # Get model name for logging start_time = time.monotonic() for attempt in range(API_RETRY_ATTEMPTS + 1): try: print( f"Sending API request for {request_desc} using {model_name} (Attempt {attempt + 1}/{API_RETRY_ATTEMPTS + 1})..." ) response = await model.generate_content_async( contents=contents, generation_config=generation_config, safety_settings=safety_settings or STANDARD_SAFETY_SETTINGS, stream=stream, ) # --- Success Logging --- elapsed_time = time.monotonic() - start_time # Ensure model_name exists in stats before incrementing if model_name not in cog.api_stats: cog.api_stats[model_name] = { "success": 0, "failure": 0, "retries": 0, "total_time": 0.0, "count": 0, } cog.api_stats[model_name]["success"] += 1 cog.api_stats[model_name]["total_time"] += elapsed_time cog.api_stats[model_name]["count"] += 1 print( f"API request successful for {request_desc} ({model_name}) in {elapsed_time:.2f}s." 
            )
            return response  # Success

        except google_exceptions.ResourceExhausted as e:
            error_msg = f"Rate limit error (ResourceExhausted) for {request_desc}: {e}"
            print(f"{error_msg} (Attempt {attempt + 1})")
            last_exception = e
            if attempt < API_RETRY_ATTEMPTS:
                if model_name not in cog.api_stats:
                    cog.api_stats[model_name] = {
                        "success": 0,
                        "failure": 0,
                        "retries": 0,
                        "total_time": 0.0,
                        "count": 0,
                    }
                cog.api_stats[model_name]["retries"] += 1
                wait_time = API_RETRY_DELAY * (2**attempt)  # Exponential backoff
                print(f"Waiting {wait_time:.2f} seconds before retrying...")
                await asyncio.sleep(wait_time)
                continue
            else:
                break  # Max retries reached

        except (
            google_exceptions.InternalServerError,
            google_exceptions.ServiceUnavailable,
        ) as e:
            error_msg = f"API server error ({type(e).__name__}) for {request_desc}: {e}"
            print(f"{error_msg} (Attempt {attempt + 1})")
            last_exception = e
            if attempt < API_RETRY_ATTEMPTS:
                if model_name not in cog.api_stats:
                    cog.api_stats[model_name] = {
                        "success": 0,
                        "failure": 0,
                        "retries": 0,
                        "total_time": 0.0,
                        "count": 0,
                    }
                cog.api_stats[model_name]["retries"] += 1
                wait_time = API_RETRY_DELAY * (2**attempt)  # Exponential backoff
                print(f"Waiting {wait_time:.2f} seconds before retrying...")
                await asyncio.sleep(wait_time)
                continue
            else:
                break  # Max retries reached

        except google_exceptions.InvalidArgument as e:
            # Often indicates a problem with the request itself (e.g., bad schema, unsupported format)
            error_msg = f"Invalid argument error for {request_desc}: {e}"
            print(error_msg)
            last_exception = e
            break  # Non-retryable

        except asyncio.TimeoutError:  # Handle potential client-side timeouts if applicable
            error_msg = f"Client-side request timed out for {request_desc} (Attempt {attempt + 1})"
            print(error_msg)
            last_exception = asyncio.TimeoutError(error_msg)
            # Decide if client-side timeouts should be retried
            if attempt < API_RETRY_ATTEMPTS:
                if model_name not in cog.api_stats:
                    cog.api_stats[model_name] = {
                        "success": 0,
                        "failure": 0,
                        "retries": 0,
                        "total_time": 0.0,
                        "count": 0,
                    }
                cog.api_stats[model_name]["retries"] += 1
                await asyncio.sleep(API_RETRY_DELAY * (attempt + 1))
                continue
            else:
                break

        except Exception as e:  # Catch other potential exceptions
            error_msg = f"Unexpected error during API call for {request_desc} (Attempt {attempt + 1}): {type(e).__name__}: {e}"
            print(error_msg)
            import traceback

            traceback.print_exc()
            last_exception = e
            # Decide if this generic exception is retryable
            # For now, treat unexpected errors as non-retryable
            break

    # --- Failure Logging ---
    elapsed_time = time.monotonic() - start_time
    if model_name not in cog.api_stats:
        cog.api_stats[model_name] = {
            "success": 0,
            "failure": 0,
            "retries": 0,
            "total_time": 0.0,
            "count": 0,
        }
    cog.api_stats[model_name]["failure"] += 1
    cog.api_stats[model_name]["total_time"] += elapsed_time
    cog.api_stats[model_name]["count"] += 1
    print(
        f"API request failed for {request_desc} ({model_name}) after {attempt + 1} attempts in {elapsed_time:.2f}s."
    )

    # Raise the last encountered exception or a generic one
    raise last_exception or Exception(
        f"API request failed for {request_desc} after {API_RETRY_ATTEMPTS + 1} attempts."
    )


# --- JSON Parsing and Validation Helper ---
def parse_and_validate_json_response(
    response_text: Optional[str], schema: Dict[str, Any], context_description: str
) -> Optional[Dict[str, Any]]:
    """
    Parses the AI's response text, attempting to extract and validate a JSON object
    against a schema.

    Args:
        response_text: The raw text content from the AI response.
        schema: The JSON schema (as a dictionary) to validate against.
        context_description: A description for logging purposes.

    Returns:
        A parsed and validated dictionary if successful, None otherwise.
    """
    if response_text is None:
        print(f"Parsing ({context_description}): Response text is None.")
        return None

    parsed_data = None
    raw_json_text = response_text  # Start with the full text

    # Attempt 1: Try parsing the whole string directly
    try:
        parsed_data = json.loads(raw_json_text)
        print(
            f"Parsing ({context_description}): Successfully parsed entire response as JSON."
        )
    except json.JSONDecodeError:
        # Attempt 2: Extract JSON object, handling optional markdown fences
        # More robust regex to handle potential leading/trailing text and variations
        json_match = re.search(
            r"```(?:json)?\s*(\{.*\})\s*```|(\{.*\})",
            response_text,
            re.DOTALL | re.MULTILINE,
        )
        if json_match:
            json_str = json_match.group(1) or json_match.group(2)
            if json_str:
                raw_json_text = json_str  # Use the extracted string for parsing
                try:
                    parsed_data = json.loads(raw_json_text)
                    print(
                        f"Parsing ({context_description}): Successfully extracted and parsed JSON using regex."
                    )
                except json.JSONDecodeError as e_inner:
                    print(
                        f"Parsing ({context_description}): Regex found potential JSON, but it failed to parse: {e_inner}\nContent: {raw_json_text[:500]}"
                    )
                    parsed_data = None
            else:
                print(
                    f"Parsing ({context_description}): Regex matched, but failed to capture JSON content."
                )
                parsed_data = None
        else:
            print(
                f"Parsing ({context_description}): Could not parse directly or extract JSON object using regex.\nContent: {raw_json_text[:500]}"
            )
            parsed_data = None

    # Validation step
    if parsed_data is not None:
        if not isinstance(parsed_data, dict):
            print(
                f"Parsing ({context_description}): Parsed data is not a dictionary: {type(parsed_data)}"
            )
            return None  # Fail validation if not a dict
        try:
            jsonschema.validate(instance=parsed_data, schema=schema)
            print(
                f"Parsing ({context_description}): JSON successfully validated against schema."
            )
            # Ensure default keys exist after validation
            parsed_data.setdefault("should_respond", False)
            parsed_data.setdefault("content", None)
            parsed_data.setdefault("react_with_emoji", None)
            return parsed_data
        except jsonschema.ValidationError as e:
            print(
                f"Parsing ({context_description}): JSON failed schema validation: {e.message}"
            )
            # Optionally log more details: e.path, e.schema_path, e.instance
            return None  # Validation failed
        except Exception as e:  # Catch other potential validation errors
            print(
                f"Parsing ({context_description}): Unexpected error during JSON schema validation: {e}"
            )
            return None
    else:
        # Parsing failed before validation could occur
        return None


# --- Tool Processing ---
async def process_requested_tools(
    cog: "WheatleyCog", function_call: "generative_models.FunctionCall"
) -> "Part":  # Use string literals
    """
    Process a tool request specified by the AI's FunctionCall response.

    Args:
        cog: The WheatleyCog instance.
        function_call: The FunctionCall object from the GenerationResponse.

    Returns:
        A Part object containing the tool result or error, formatted for the follow-up API call.
""" function_name = function_call.name # Convert the Struct field arguments to a standard Python dict function_args = dict(function_call.args) if function_call.args else {} tool_result_content = None print(f"Processing tool request: {function_name} with args: {function_args}") tool_start_time = time.monotonic() if function_name in TOOL_MAPPING: try: tool_func = TOOL_MAPPING[function_name] # Execute the mapped function # Ensure the function signature matches the expected arguments # Pass cog if the tool implementation requires it result = await tool_func(cog, **function_args) # --- Tool Success Logging --- tool_elapsed_time = time.monotonic() - tool_start_time if function_name not in cog.tool_stats: cog.tool_stats[function_name] = { "success": 0, "failure": 0, "total_time": 0.0, "count": 0, } cog.tool_stats[function_name]["success"] += 1 cog.tool_stats[function_name]["total_time"] += tool_elapsed_time cog.tool_stats[function_name]["count"] += 1 print( f"Tool '{function_name}' executed successfully in {tool_elapsed_time:.2f}s." ) # Prepare result for API - must be JSON serializable, typically a dict if not isinstance(result, dict): # Attempt to convert common types or wrap in a dict if isinstance(result, (str, int, float, bool, list)) or result is None: result = {"result": result} else: print( f"Warning: Tool '{function_name}' returned non-standard type {type(result)}. Attempting str conversion." ) result = {"result": str(result)} tool_result_content = result except Exception as e: # --- Tool Failure Logging --- tool_elapsed_time = time.monotonic() - tool_start_time if function_name not in cog.tool_stats: cog.tool_stats[function_name] = { "success": 0, "failure": 0, "total_time": 0.0, "count": 0, } cog.tool_stats[function_name]["failure"] += 1 cog.tool_stats[function_name]["total_time"] += tool_elapsed_time cog.tool_stats[function_name]["count"] += 1 error_message = ( f"Error executing tool {function_name}: {type(e).__name__}: {str(e)}" ) print(f"{error_message} (Took {tool_elapsed_time:.2f}s)") import traceback traceback.print_exc() tool_result_content = {"error": error_message} else: # --- Tool Not Found Logging --- tool_elapsed_time = time.monotonic() - tool_start_time # Log attempt even if tool not found if function_name not in cog.tool_stats: cog.tool_stats[function_name] = { "success": 0, "failure": 0, "total_time": 0.0, "count": 0, } cog.tool_stats[function_name]["failure"] += 1 cog.tool_stats[function_name]["total_time"] += tool_elapsed_time cog.tool_stats[function_name]["count"] += 1 error_message = f"Tool '{function_name}' not found or implemented." print(f"{error_message} (Took {tool_elapsed_time:.2f}s)") tool_result_content = {"error": error_message} # Return the result formatted as a Part for the API return Part.from_function_response(name=function_name, response=tool_result_content) # --- Main AI Response Function --- async def get_ai_response( cog: "WheatleyCog", message: discord.Message, model_name: Optional[str] = None ) -> Dict[str, Any]: """ Gets responses from the Vertex AI Gemini API, handling potential tool usage and returning the final parsed response. Args: cog: The WheatleyCog instance. message: The triggering discord.Message. model_name: Optional override for the AI model name (e.g., "gemini-1.5-pro-preview-0409"). Returns: A dictionary containing: - "final_response": Parsed JSON data from the final AI call (or None if parsing/validation fails). - "error": An error message string if a critical error occurred, otherwise None. 
- "fallback_initial": Optional minimal response if initial parsing failed critically (less likely with controlled generation). """ if not PROJECT_ID or not LOCATION: return { "final_response": None, "error": "Google Cloud Project ID or Location not configured", } channel_id = message.channel.id user_id = message.author.id initial_parsed_data = None # Added to store initial parsed result final_parsed_data = None error_message = None fallback_response = None try: # --- Build Prompt Components --- final_system_prompt = await build_dynamic_system_prompt(cog, message) conversation_context_messages = gather_conversation_context( cog, channel_id, message.id ) # Pass cog memory_context = await get_memory_context(cog, message) # Pass cog # --- Initialize Model --- # Tools are passed during model initialization in Vertex AI SDK # Combine tool declarations into a Tool object vertex_tool = Tool(function_declarations=TOOLS) if TOOLS else None model = GenerativeModel( model_name or DEFAULT_MODEL, system_instruction=final_system_prompt, tools=[vertex_tool] if vertex_tool else None, ) # --- Prepare Message History (Contents) --- contents: List[Content] = [] # Add memory context if available if memory_context: # System messages aren't directly supported in the 'contents' list for multi-turn like OpenAI. # It's better handled via the 'system_instruction' parameter of GenerativeModel. # We might prepend it to the first user message or handle it differently if needed. # For now, we rely on system_instruction. Let's log if we have memory context. print("Memory context available, relying on system_instruction.") # If needed, could potentially add as a 'model' role message before user messages, # but this might confuse the turn structure. # contents.append(Content(role="model", parts=[Part.from_text(f"System Note: {memory_context}")])) # Add conversation history for msg in conversation_context_messages: role = msg.get("role", "user") # Default to user if role missing # Map roles if necessary (e.g., 'assistant' -> 'model') if role == "assistant": role = "model" elif role == "system": # Skip system messages here, handled by system_instruction continue # Handle potential multimodal content in history (if stored that way) if isinstance(msg.get("content"), list): parts = [ ( Part.from_text(part["text"]) if part["type"] == "text" else ( Part.from_uri( part["image_url"]["url"], mime_type=part["image_url"]["url"] .split(";")[0] .split(":")[1], ) if part["type"] == "image_url" else None ) ) for part in msg["content"] ] parts = [p for p in parts if p] # Filter out None parts if parts: contents.append(Content(role=role, parts=parts)) elif isinstance(msg.get("content"), str): contents.append( Content(role=role, parts=[Part.from_text(msg["content"])]) ) # --- Prepare the current message content (potentially multimodal) --- current_message_parts = [] formatted_current_message = format_message(cog, message) # Pass cog if needed # --- Construct text content, including reply context if applicable --- text_content = "" if formatted_current_message.get("is_reply") and formatted_current_message.get( "replied_to_author_name" ): reply_author = formatted_current_message["replied_to_author_name"] reply_content = formatted_current_message.get( "replied_to_content", "..." ) # Use ellipsis if content missing # Truncate long replied content to keep context concise max_reply_len = 150 if len(reply_content) > max_reply_len: reply_content = reply_content[:max_reply_len] + "..." 
            text_content += f'(Replying to {reply_author}: "{reply_content}")\n'

        # Add current message author and content
        text_content += f"{formatted_current_message['author']['display_name']}: {formatted_current_message['content']}"

        # Add mention details
        if formatted_current_message.get("mentioned_users_details"):
            mentions_str = ", ".join(
                [
                    f"{m['display_name']}(id:{m['id']})"
                    for m in formatted_current_message["mentioned_users_details"]
                ]
            )
            text_content += f"\n(Message Details: Mentions=[{mentions_str}])"

        current_message_parts.append(Part.from_text(text_content))
        # --- End text content construction ---

        if message.attachments:
            print(
                f"Processing {len(message.attachments)} attachments for message {message.id}"
            )
            for attachment in message.attachments:
                mime_type = attachment.content_type
                file_url = attachment.url
                filename = attachment.filename

                # Check if MIME type is supported for URI input by Gemini
                # Expand this list based on Gemini documentation for supported types via URI
                supported_mime_prefixes = [
                    "image/",
                    "video/",
                    "audio/",
                    "text/plain",
                    "application/pdf",
                ]
                is_supported = False
                if mime_type:
                    for prefix in supported_mime_prefixes:
                        if mime_type.startswith(prefix):
                            is_supported = True
                            break
                    # Add specific non-prefixed types if needed
                    # if mime_type in ["application/vnd.google-apps.document", ...]:
                    #     is_supported = True

                if is_supported and file_url:
                    try:
                        # 1. Add text part instructing AI about the file
                        instruction_text = f"User attached a file: '{filename}' (Type: {mime_type}). Analyze this file from the following URI and incorporate your understanding into your response."
                        current_message_parts.append(Part.from_text(instruction_text))
                        print(f"Added text instruction for attachment: {filename}")

                        # 2. Add the URI part
                        # Ensure mime_type doesn't contain parameters like '; charset=...'
                        # if the API doesn't like them
                        clean_mime_type = mime_type.split(";")[0]
                        current_message_parts.append(
                            Part.from_uri(uri=file_url, mime_type=clean_mime_type)
                        )
                        print(
                            f"Added URI part for attachment: {filename} ({clean_mime_type}) using URL: {file_url}"
                        )
                    except Exception as e:
                        print(
                            f"Error creating Part for attachment {filename} ({mime_type}): {e}"
                        )
                        # Optionally add a text part indicating the error
                        current_message_parts.append(
                            Part.from_text(
                                f"(System Note: Failed to process attachment '{filename}' - {e})"
                            )
                        )
                else:
                    print(
                        f"Skipping unsupported or invalid attachment: {filename} (Type: {mime_type}, URL: {file_url})"
                    )
                    # Optionally inform the AI that an unsupported file was attached
                    current_message_parts.append(
                        Part.from_text(
                            f"(System Note: User attached an unsupported file '{filename}' of type '{mime_type}' which cannot be processed.)"
                        )
                    )

        # Ensure there's always *some* content part, even if only text or errors
        if current_message_parts:
            contents.append(Content(role="user", parts=current_message_parts))
        else:
            print("Warning: No content parts generated for user message.")
            contents.append(Content(role="user", parts=[Part.from_text("")]))

        # --- First API Call (Check for Tool Use) ---
        print("Making initial API call to check for tool use...")
        generation_config_initial = GenerationConfig(
            temperature=0.75,
            max_output_tokens=10000,  # Adjust as needed
            # No response schema needed for the initial call, just checking for function calls
        )

        initial_response = await call_vertex_api_with_retry(
            cog=cog,
            model=model,
            contents=contents,
            generation_config=generation_config_initial,
            safety_settings=STANDARD_SAFETY_SETTINGS,
            request_desc=f"Initial response check for message {message.id}",
        )

        # --- Log Raw Request and Response ---
        try:
            # Log the request payload (contents)
            request_payload_log = [
                {"role": c.role, "parts": [str(p) for p in c.parts]} for c in contents
            ]  # Convert parts to string for logging
            print(
                f"--- Raw API Request (Initial Call) ---\n{json.dumps(request_payload_log, indent=2)}\n------------------------------------"
            )
            # Log the raw response object
            print(
                f"--- Raw API Response (Initial Call) ---\n{initial_response}\n-----------------------------------"
            )
        except Exception as log_e:
            print(f"Error logging raw request/response: {log_e}")
        # --- End Logging ---

        if not initial_response or not initial_response.candidates:
            raise Exception("Initial API call returned no response or candidates.")

        # --- Check for Tool Call FIRST ---
        candidate = initial_response.candidates[0]
        finish_reason = getattr(candidate, "finish_reason", None)
        function_call = None
        function_call_part_content = None  # Store the AI's request message content

        # Check primarily for the *presence* of a function call part,
        # as finish_reason might be STOP even with a function call.
        if hasattr(candidate, "content") and candidate.content.parts:
            for part in candidate.content.parts:
                if hasattr(part, "function_call"):
                    function_call = part.function_call  # Assign the value
                    # Add check to ensure function_call is not None before proceeding
                    if function_call:
                        # Store the whole content containing the call to add to history later
                        function_call_part_content = candidate.content
                        print(
                            f"AI requested tool (found function_call part): {function_call.name}"
                        )
                        break  # Found a valid function call part
                    else:
                        # Log if the attribute exists but is None (unexpected case)
                        print(
                            "Warning: Found part with 'function_call' attribute, but its value was None."
                        )

        # --- Process Tool Call or Handle Direct Response ---
        if function_call and function_call_part_content:
            # --- Tool Call Path ---
            initial_parsed_data = None  # No initial JSON expected if tool is called

            # Process the tool request
            tool_response_part = await process_requested_tools(cog, function_call)

            # Append the AI's request and the tool's response to the history
            contents.append(candidate.content)  # Add the AI's function call request message
            contents.append(
                Content(role="function", parts=[tool_response_part])
            )  # Add the function response part

            # --- Second API Call (Get Final Response After Tool) ---
            print("Making follow-up API call with tool results...")

            # Initialize a NEW model instance WITHOUT tools for the follow-up call
            # This prevents the InvalidArgument error when specifying response schema
            model_final = GenerativeModel(
                model_name or DEFAULT_MODEL,  # Use the same model name
                system_instruction=final_system_prompt,  # Keep the same system prompt
                # Omit the 'tools' parameter here
            )

            # Preprocess the schema before passing it to GenerationConfig
            processed_response_schema = _preprocess_schema_for_vertex(
                RESPONSE_SCHEMA["schema"]
            )
            generation_config_final = GenerationConfig(
                temperature=0.75,  # Keep original temperature for final response
                max_output_tokens=10000,  # Keep original max tokens
                response_mime_type="application/json",
                response_schema=processed_response_schema,  # Use preprocessed schema
            )

            final_response_obj = await call_vertex_api_with_retry(  # Renamed variable for clarity
                cog=cog,
                model=model_final,  # Use the new model instance WITHOUT tools
                contents=contents,  # History now includes tool call/response
                generation_config=generation_config_final,
                safety_settings=STANDARD_SAFETY_SETTINGS,
                request_desc=f"Follow-up response for message {message.id} after tool execution",
            )

            if not final_response_obj or not final_response_obj.candidates:
                raise Exception(
                    "Follow-up API call returned no response or candidates."
                )

            final_response_text = _get_response_text(final_response_obj)  # Use helper
            final_parsed_data = parse_and_validate_json_response(
                final_response_text,
                RESPONSE_SCHEMA["schema"],
                "final response after tools",
            )

            # Handle validation failure - Re-prompt loop (simplified example)
            if final_parsed_data is None:
                print(
                    "Warning: Final response failed validation. Attempting re-prompt (basic)..."
                )
                # Construct a basic re-prompt message
                contents.append(
                    final_response_obj.candidates[0].content
                )  # Add the invalid response
                contents.append(
                    Content(
                        role="user",
                        parts=[
                            Part.from_text(
                                "Your previous JSON response was invalid or did not match the required schema. "
                                f"Please provide the response again, strictly adhering to this schema:\n{json.dumps(RESPONSE_SCHEMA['schema'], indent=2)}"
                            )
                        ],
                    )
                )
                # Retry the final call
                retry_response_obj = await call_vertex_api_with_retry(
                    cog=cog,
                    model=model_final,  # Keep using the tool-less model since a response schema is set
                    contents=contents,
                    generation_config=generation_config_final,
                    safety_settings=STANDARD_SAFETY_SETTINGS,
                    request_desc=f"Re-prompt validation failure for message {message.id}",
                )
                if retry_response_obj and retry_response_obj.candidates:
                    final_response_text = _get_response_text(
                        retry_response_obj
                    )  # Use helper
                    final_parsed_data = parse_and_validate_json_response(
                        final_response_text,
                        RESPONSE_SCHEMA["schema"],
                        "re-prompted final response",
                    )
                    if final_parsed_data is None:
                        print(
                            "Critical Error: Re-prompted response still failed validation."
                        )
                        error_message = (
                            "Failed to get valid JSON response after re-prompting."
                        )
                else:
                    error_message = "Failed to get response after re-prompting."
            # final_parsed_data is now set (or None if failed) after tool use and potential re-prompt

        else:
            # --- No Tool Call Path ---
            print("No tool call requested by AI. Processing initial response as final.")
            # Attempt to parse the initial response text directly.
            initial_response_text = _get_response_text(initial_response)  # Use helper
            # Validate against the final schema because this IS the final response.
            final_parsed_data = parse_and_validate_json_response(
                initial_response_text,
                RESPONSE_SCHEMA["schema"],
                "final response (no tools)",
            )
            initial_parsed_data = (
                final_parsed_data  # Keep initial_parsed_data consistent for return dict
            )

            if final_parsed_data is None:
                # This means the initial response failed validation.
                print("Critical Error: Initial response failed validation (no tools).")
                error_message = "Failed to parse/validate initial AI JSON response."
                # Create a basic fallback if the bot was mentioned
                replied_to_bot = (
                    message.reference
                    and message.reference.resolved
                    and message.reference.resolved.author == cog.bot.user
                )
                if cog.bot.user.mentioned_in(message) or replied_to_bot:
                    fallback_response = {
                        "should_respond": True,
                        "content": "...",
                        "react_with_emoji": "❓",
                    }
            # initial_parsed_data is not used in this path, only final_parsed_data matters

    except Exception as e:
        error_message = f"Error in get_ai_response main loop for message {message.id}: {type(e).__name__}: {str(e)}"
        print(error_message)
        import traceback

        traceback.print_exc()
        # Ensure both are None on critical error
        initial_parsed_data = None
        final_parsed_data = None

    return {
        "initial_response": initial_parsed_data,  # Return parsed initial data
        "final_response": final_parsed_data,  # Return parsed final data
        "error": error_message,
        "fallback_initial": fallback_response,
    }


# --- Proactive AI Response Function ---
async def get_proactive_ai_response(
    cog: "WheatleyCog", message: discord.Message, trigger_reason: str
) -> Dict[str, Any]:
    """Generates a proactive response based on a specific trigger using Vertex AI."""
    if not PROJECT_ID or not LOCATION:
        return {
            "should_respond": False,
            "content": None,
            "react_with_emoji": None,
            "error": "Google Cloud Project ID or Location not configured",
        }

    print(f"--- Proactive Response Triggered: {trigger_reason} ---")
    channel_id = message.channel.id
    final_parsed_data = None
    error_message = None
    plan = None  # Variable to store the plan

    try:
        # --- Build Context for Planning ---
        # Gather relevant context: recent messages, topic, sentiment, Gurt's mood/interests, trigger reason
        planning_context_parts = [
            f"Proactive Trigger Reason: {trigger_reason}",
            f"Current Mood: {cog.current_mood}",
        ]
        # Add recent messages summary
        summary_data = await get_conversation_summary(
            cog, str(channel_id), message_limit=15
        )  # Use tool function
        if summary_data and not summary_data.get("error"):
            planning_context_parts.append(
                f"Recent Conversation Summary: {summary_data['summary']}"
            )
        # Add active topics
        active_topics_data = cog.active_topics.get(channel_id)
        if active_topics_data and active_topics_data.get("topics"):
            topics_str = ", ".join(
                [
                    f"{t['topic']} ({t['score']:.1f})"
                    for t in active_topics_data["topics"][:3]
                ]
            )
            planning_context_parts.append(f"Active Topics: {topics_str}")
        # Add sentiment
        sentiment_data = cog.conversation_sentiment.get(channel_id)
        if sentiment_data:
            planning_context_parts.append(
                f"Conversation Sentiment: {sentiment_data.get('overall', 'N/A')} (Intensity: {sentiment_data.get('intensity', 0):.1f})"
            )
        # Add Wheatley's interests (Note: Interests are likely disabled/removed for Wheatley, this might fetch nothing)
        try:
            interests = await cog.memory_manager.get_interests(limit=5)
            if interests:
                interests_str = ", ".join([f"{t} ({l:.1f})" for t, l in interests])
                planning_context_parts.append(
                    f"Wheatley's Interests: {interests_str}"
                )  # Changed text
        except Exception as int_e:
            print(f"Error getting interests for planning: {int_e}")

        planning_context = "\n".join(planning_context_parts)

        # --- Planning Step ---
        print("Generating proactive response plan...")
        planning_prompt_messages = [
            {
                "role": "system",
                "content": "You are Wheatley's planning module. Analyze the context and trigger reason to decide if Wheatley should respond proactively and, if so, outline a plan (goal, key info, tone). Focus on natural, in-character engagement (rambling, insecure, bad ideas). Respond ONLY with JSON matching the provided schema.",
            },  # Updated system prompt
            {
                "role": "user",
                "content": f"Context:\n{planning_context}\n\nBased on this context and the trigger reason, create a plan for Wheatley's proactive response.",
            },  # Updated user prompt
        ]

        plan = await get_internal_ai_json_response(
            cog=cog,
            prompt_messages=planning_prompt_messages,
            task_description=f"Proactive Planning ({trigger_reason})",
            response_schema_dict=PROACTIVE_PLAN_SCHEMA["schema"],
            model_name=FALLBACK_MODEL,  # Use a potentially faster/cheaper model for planning
            temperature=0.5,
            max_tokens=300,
        )

        if not plan or not plan.get("should_respond"):
            reason = (
                plan.get("reasoning", "Planning failed or decided against responding.")
                if plan
                else "Planning failed."
            )
            print(f"Proactive response aborted by plan: {reason}")
            return {
                "should_respond": False,
                "content": None,
                "react_with_emoji": None,
                "note": f"Plan: {reason}",
            }

        print(
            f"Proactive Plan Generated: Goal='{plan.get('response_goal', 'N/A')}', Reasoning='{plan.get('reasoning', 'N/A')}'"
        )

        # --- Build Final Proactive Prompt using Plan ---
        persistent_traits = await cog.memory_manager.get_all_personality_traits()
        if not persistent_traits:
            persistent_traits = {}  # Wheatley doesn't use these Gurt traits

        final_proactive_prompt_parts = [
            "You are Wheatley, an Aperture Science Personality Core. Your tone is rambling, insecure, uses British slang, and you often have terrible ideas you think are brilliant.",  # Updated personality description
            # Removed Gurt-specific traits
            # Removed mood reference as it's disabled for Wheatley
            # Incorporate Plan Details:
            f"You decided to respond proactively (maybe?). Trigger Reason: {trigger_reason}.",  # Wheatley-style uncertainty
            f"Your Brilliant Plan (Goal): {plan.get('response_goal', 'Say something... probably helpful?')}.",  # Wheatley-style goal
            f"Reasoning: {plan.get('reasoning', 'N/A')}.",
        ]
        if plan.get("key_info_to_include"):
            info_str = "; ".join(plan["key_info_to_include"])
            final_proactive_prompt_parts.append(f"Consider mentioning: {info_str}")
        if plan.get("suggested_tone"):
            final_proactive_prompt_parts.append(
                f"Adjust tone to be: {plan['suggested_tone']}"
            )
        final_proactive_prompt_parts.append(
            "Generate a casual, in-character message based on the plan and context. Keep it relatively short and natural-sounding."
        )
        final_proactive_system_prompt = "\n\n".join(final_proactive_prompt_parts)

        # --- Initialize Final Model ---
        model = GenerativeModel(
            model_name=DEFAULT_MODEL, system_instruction=final_proactive_system_prompt
        )

        # --- Prepare Final Contents ---
        contents = [
            Content(
                role="user",
                parts=[
                    Part.from_text(
                        "Generate the response based on your plan. "
                        "**CRITICAL: Your response MUST be ONLY the raw JSON object matching this schema:**\n\n"
                        f"{json.dumps(RESPONSE_SCHEMA['schema'], indent=2)}\n\n"
                        "**Ensure nothing precedes or follows the JSON.**"
                    )
                ],
            )
        ]

        # --- Call Final LLM API ---
        # Preprocess the schema before passing it to GenerationConfig
        processed_response_schema_proactive = _preprocess_schema_for_vertex(
            RESPONSE_SCHEMA["schema"]
        )
        generation_config_final = GenerationConfig(
            temperature=0.8,  # Use original proactive temp
            max_output_tokens=200,
            response_mime_type="application/json",
            response_schema=processed_response_schema_proactive,  # Use preprocessed schema
        )

        response_obj = await call_vertex_api_with_retry(
            cog=cog,
            model=model,
            contents=contents,
            generation_config=generation_config_final,
            safety_settings=STANDARD_SAFETY_SETTINGS,
            request_desc=f"Final proactive response for channel {channel_id} ({trigger_reason})",
        )

        if not response_obj or not response_obj.candidates:
            raise Exception(
                "Final proactive API call returned no response or candidates."
            )

        # --- Parse and Validate Final Response ---
        final_response_text = _get_response_text(response_obj)
        final_parsed_data = parse_and_validate_json_response(
            final_response_text,
            RESPONSE_SCHEMA["schema"],
            f"final proactive response ({trigger_reason})",
        )

        if final_parsed_data is None:
            print(
                f"Warning: Failed to parse/validate final proactive JSON response for {trigger_reason}."
            )
            final_parsed_data = {
                "should_respond": False,
                "content": None,
                "react_with_emoji": None,
                "note": "Fallback - Failed to parse/validate final proactive JSON",
            }
        else:
            # --- Cache Bot Response ---
            if final_parsed_data.get("should_respond") and final_parsed_data.get(
                "content"
            ):
                bot_response_cache_entry = {
                    "id": f"bot_proactive_{message.id}_{int(time.time())}",
                    "author": {
                        "id": str(cog.bot.user.id),
                        "name": cog.bot.user.name,
                        "display_name": cog.bot.user.display_name,
                        "bot": True,
                    },
                    "content": final_parsed_data.get("content", ""),
                    "created_at": datetime.datetime.now().isoformat(),
                    "attachments": [],
                    "embeds": False,
                    "mentions": [],
                    "replied_to_message_id": None,
                    "channel": message.channel,
                    "guild": message.guild,
                    "reference": None,
                    "mentioned_users_details": [],
                }
                cog.message_cache["by_channel"].setdefault(channel_id, []).append(
                    bot_response_cache_entry
                )
                cog.message_cache["global_recent"].append(bot_response_cache_entry)
                cog.bot_last_spoke[channel_id] = time.time()
                # Removed Gurt-specific participation tracking

    except Exception as e:
        error_message = f"Error getting proactive AI response for channel {channel_id} ({trigger_reason}): {type(e).__name__}: {str(e)}"
        print(error_message)
        final_parsed_data = {
            "should_respond": False,
            "content": None,
            "react_with_emoji": None,
            "error": error_message,
        }

    # Ensure default keys exist
    final_parsed_data.setdefault("should_respond", False)
    final_parsed_data.setdefault("content", None)
    final_parsed_data.setdefault("react_with_emoji", None)
    if error_message and "error" not in final_parsed_data:
        final_parsed_data["error"] = error_message

    return final_parsed_data


# --- Internal AI Call for Specific Tasks ---
async def get_internal_ai_json_response(
    cog: "WheatleyCog",
    prompt_messages: List[Dict[str, Any]],  # Keep this format
    task_description: str,
    response_schema_dict: Dict[str, Any],  # Expect schema as dict
    model_name: Optional[str] = None,
    temperature: float = 0.7,
    max_tokens: int = 5000,
) -> Optional[Dict[str, Any]]:  # Keep return type hint simple
    """
    Makes a Vertex AI call expecting a specific JSON response format for internal tasks.

    Args:
        cog: The WheatleyCog instance.
        prompt_messages: List of message dicts (like OpenAI format: {'role': 'user'/'model', 'content': '...'}).
        task_description: Description for logging.
        response_schema_dict: The expected JSON output schema as a dictionary.
        model_name: Optional model override.
        temperature: Generation temperature.
        max_tokens: Max output tokens.

    Returns:
        The parsed and validated JSON dictionary if successful, None otherwise.
    """
    if not PROJECT_ID or not LOCATION:
        print(
            f"Error in get_internal_ai_json_response ({task_description}): GCP Project/Location not set."
        )
        return None

    final_parsed_data = None
    error_occurred = None
    request_payload_for_logging = {}  # For logging

    try:
        # --- Convert prompt messages to Vertex AI Content format ---
        contents: List[Content] = []
        system_instruction = None
        for msg in prompt_messages:
            role = msg.get("role", "user")
            content_text = msg.get("content", "")
            if role == "system":
                # Use the first system message as system_instruction
                if system_instruction is None:
                    system_instruction = content_text
                else:
                    # Append subsequent system messages to the instruction
                    system_instruction += "\n\n" + content_text
                continue  # Skip adding system messages to contents list
            elif role == "assistant":
                role = "model"

            # --- Process content (string or list) ---
            content_value = msg.get("content")
            message_parts: List[Part] = []  # Initialize list to hold parts for this message

            if isinstance(content_value, str):
                # Handle simple string content
                message_parts.append(Part.from_text(content_value))
            elif isinstance(content_value, list):
                # Handle list content (e.g., multimodal from ProfileUpdater)
                for part_data in content_value:
                    part_type = part_data.get("type")
                    if part_type == "text":
                        text = part_data.get("text", "")
                        message_parts.append(Part.from_text(text))
                    elif part_type == "image_data":
                        mime_type = part_data.get("mime_type")
                        base64_data = part_data.get("data")
                        if mime_type and base64_data:
                            try:
                                image_bytes = base64.b64decode(base64_data)
                                message_parts.append(
                                    Part.from_data(data=image_bytes, mime_type=mime_type)
                                )
                            except Exception as decode_err:
                                print(
                                    f"Error decoding/adding image part in get_internal_ai_json_response: {decode_err}"
                                )
                                # Optionally add a placeholder text part indicating failure
                                message_parts.append(
                                    Part.from_text(
                                        "(System Note: Failed to process an image part)"
                                    )
                                )
                        else:
                            print("Warning: image_data part missing mime_type or data.")
                    else:
                        print(
                            f"Warning: Unknown part type '{part_type}' in internal prompt message."
                        )
            else:
                print(
                    f"Warning: Unexpected content type '{type(content_value)}' in internal prompt message."
                )

            # Add the content object if parts were generated
            if message_parts:
                contents.append(Content(role=role, parts=message_parts))
            else:
                print(f"Warning: No parts generated for message role '{role}'.")

        # Add the critical JSON instruction to the last user message or as a new user message
        json_instruction_content = (
            f"**CRITICAL: Your response MUST consist *only* of the raw JSON object itself, matching this schema:**\n"
            f"{json.dumps(response_schema_dict, indent=2)}\n"
            f"**Ensure nothing precedes or follows the JSON.**"
        )
        if contents and contents[-1].role == "user":
            contents[-1].parts.append(Part.from_text(f"\n\n{json_instruction_content}"))
        else:
            contents.append(
                Content(role="user", parts=[Part.from_text(json_instruction_content)])
            )

        # --- Initialize Model ---
        model = GenerativeModel(
            model_name=model_name or DEFAULT_MODEL,  # Use keyword argument
            system_instruction=system_instruction,
            # No tools needed for internal JSON tasks usually
        )

        # --- Prepare Generation Config ---
        # Preprocess the schema before passing it to GenerationConfig
        processed_schema_internal = _preprocess_schema_for_vertex(response_schema_dict)
        generation_config = GenerationConfig(
            temperature=temperature,
            max_output_tokens=max_tokens,
            response_mime_type="application/json",
            response_schema=processed_schema_internal,  # Use preprocessed schema
        )

        # Prepare payload for logging (approximate)
        request_payload_for_logging = {
            "model": model._model_name,
            "system_instruction": system_instruction,
            "contents": [
                # Simplified representation for logging
                {
                    "role": c.role,
                    "parts": [
                        p.text if hasattr(p, "text") else str(type(p)) for p in c.parts
                    ],
                }
                for c in contents
            ],
            # GenerationConfig isn't a plain dict; json.dumps(default=str) below stringifies it
            "generation_config": generation_config,
        }

        # --- Add detailed logging for raw request ---
        try:
            print(f"--- Raw request payload for {task_description} ---")
            # Use json.dumps for pretty printing, handle potential errors
            print(
                json.dumps(request_payload_for_logging, indent=2, default=str)
            )  # Use default=str as fallback
            print(f"--- End Raw request payload ---")
        except Exception as req_log_e:
            print(f"Error logging raw request payload: {req_log_e}")
            print(
                f"Payload causing error: {request_payload_for_logging}"
            )  # Print the raw dict on error
        # --- End detailed logging ---

        # --- Call API ---
        response_obj = await call_vertex_api_with_retry(
            cog=cog,
            model=model,
            contents=contents,
            generation_config=generation_config,
            safety_settings=STANDARD_SAFETY_SETTINGS,  # Use standard safety
            request_desc=task_description,
        )

        if not response_obj or not response_obj.candidates:
            raise Exception("Internal API call returned no response or candidates.")

        # --- Parse and Validate ---
        # This function always expects JSON, so directly use response_obj.text
        final_response_text = response_obj.text

        # --- Add detailed logging for raw response text ---
        print(f"--- Raw response_obj.text for {task_description} ---")
        print(final_response_text)
        print(f"--- End Raw response_obj.text ---")
        # --- End detailed logging ---

        print(f"Parsing ({task_description}): Using response_obj.text for JSON.")
        final_parsed_data = parse_and_validate_json_response(
            final_response_text,
            response_schema_dict,
            f"internal task ({task_description})",
        )

        if final_parsed_data is None:
            print(
                f"Warning: Internal task '{task_description}' failed JSON validation."
            )
            # No re-prompting for internal tasks, just return None

    except Exception as e:
        print(
            f"Error in get_internal_ai_json_response ({task_description}): {type(e).__name__}: {e}"
        )
        error_occurred = e
        import traceback

        traceback.print_exc()
        final_parsed_data = None
    finally:
        # Log the call
        try:
            # Pass the simplified payload for logging
            await log_internal_api_call(
                cog,
                task_description,
                request_payload_for_logging,
                final_parsed_data,
                error_occurred,
            )
        except Exception as log_e:
            print(f"Error logging internal API call: {log_e}")

    return final_parsed_data
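

# --- Illustrative sanity check (added sketch, not part of the original cog logic) ---
# This exercises _preprocess_schema_for_vertex on a small, hypothetical schema so the
# union-type flattening is easy to eyeball. The example schema below is made up for
# demonstration; real schemas come from config (e.g. RESPONSE_SCHEMA["schema"]).
# Run via `python -m <package>.<module>` so the relative imports resolve (exact names
# depend on the repo layout).
if __name__ == "__main__":
    _example_schema = {
        "type": "object",
        "properties": {
            "content": {"type": ["string", "null"]},
            "reactions": {
                "type": "array",
                "items": {"type": ["string", "null"]},
            },
        },
    }
    # Expected effect: each 'type': ["string", "null"] collapses to 'type': "string".
    print(json.dumps(_preprocess_schema_for_vertex(_example_schema), indent=2))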