discordbot/gurt/api.py

from collections import deque
import ssl
import certifi
import imghdr # Added for robust image MIME type detection (note: imghdr is deprecated since Python 3.11)
from .config import CONTEXT_WINDOW_SIZE
def patch_ssl_certifi():
original_create_default_context = ssl.create_default_context
def custom_ssl_context(*args, **kwargs):
# Only set cafile if it's not already passed
kwargs.setdefault("cafile", certifi.where())
return original_create_default_context(*args, **kwargs)
ssl.create_default_context = custom_ssl_context
patch_ssl_certifi()
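# After patch_ssl_certifi() runs, any library that calls
# ssl.create_default_context() without an explicit cafile picks up
# certifi's CA bundle. A minimal sketch of the effect:
#
#   ctx = ssl.create_default_context()  # cafile now defaults to certifi.where()
#   ctx = ssl.create_default_context(cafile="/etc/ssl/custom.pem")  # explicit value still wins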
import discord
import aiohttp
import asyncio
import json
import base64
import re
import time
import datetime
from typing import (
TYPE_CHECKING,
Optional,
List,
Dict,
Any,
Union,
AsyncIterable,
Tuple,
) # Import Tuple
import jsonschema # For manual JSON validation
from .tools import get_conversation_summary
# Google Generative AI Imports (using Vertex AI backend)
# try:
from google import genai
from google.genai import types
from google.api_core import (
exceptions as google_exceptions,
) # Keep for retry logic if applicable
# except ImportError:
# print("WARNING: google-generativeai or google-api-core not installed. API calls will fail.")
# # Define dummy classes/exceptions if library isn't installed
# genai = None # Indicate genai module is missing
# # types = None # REMOVE THIS LINE - Define dummy types object below
# # Define dummy classes first
# class DummyGenerationResponse:
# def __init__(self):
# self.candidates = []
# self.text = None # Add basic text attribute for compatibility
# class DummyFunctionCall:
# def __init__(self):
# self.name = None
# self.args = None
# class DummyPart:
# @staticmethod
# def from_text(text): return None
# @staticmethod
# def from_data(data, mime_type): return None
# @staticmethod
# def from_uri(uri, mime_type): return None
# @staticmethod
# def from_function_response(name, response): return None
# @staticmethod
# def from_function_call(function_call): return None # Add this
# class DummyContent:
# def __init__(self, role=None, parts=None):
# self.role = role
# self.parts = parts or []
# class DummyTool:
# def __init__(self, function_declarations=None): pass
# class DummyFunctionDeclaration:
# def __init__(self, name, description, parameters): pass
# class DummySafetySetting:
# def __init__(self, category, threshold): pass
# class DummyHarmCategory:
# HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
# HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
# HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
# HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
# class DummyFinishReason:
# STOP = "STOP"
# MAX_TOKENS = "MAX_TOKENS"
# SAFETY = "SAFETY"
# RECITATION = "RECITATION"
# OTHER = "OTHER"
# FUNCTION_CALL = "FUNCTION_CALL" # Add this
# class DummyToolConfig:
# class FunctionCallingConfig:
# class Mode:
# ANY = "ANY"
# NONE = "NONE"
# AUTO = "AUTO"
# def __init__(self, function_calling_config=None): pass
# class DummyGenerateContentResponse: # For non-streaming response type hint
# def __init__(self):
# self.candidates = []
# self.text = None
# # Define a dummy GenerateContentConfig class
# class DummyGenerateContentConfig:
# def __init__(self, temperature=None, top_p=None, max_output_tokens=None, response_mime_type=None, response_schema=None, stop_sequences=None, candidate_count=None):
# self.temperature = temperature
# self.top_p = top_p
# self.max_output_tokens = max_output_tokens
# self.response_mime_type = response_mime_type
# self.response_schema = response_schema
# self.stop_sequences = stop_sequences
# self.candidate_count = candidate_count
# # Define a dummy FunctionResponse class
# class DummyFunctionResponse:
# def __init__(self, name, response):
# self.name = name
# self.response = response
# # Create a dummy 'types' object and assign dummy classes to its attributes
# class DummyTypes:
# def __init__(self):
# self.GenerationResponse = DummyGenerationResponse
# self.FunctionCall = DummyFunctionCall
# self.Part = DummyPart
# self.Content = DummyContent
# self.Tool = DummyTool
# self.FunctionDeclaration = DummyFunctionDeclaration
# self.SafetySetting = DummySafetySetting
# self.HarmCategory = DummyHarmCategory
# self.FinishReason = DummyFinishReason
# self.ToolConfig = DummyToolConfig
# self.GenerateContentResponse = DummyGenerateContentResponse
# self.GenerateContentConfig = DummyGenerateContentConfig # Assign dummy config
# self.FunctionResponse = DummyFunctionResponse # Assign dummy function response
# types = DummyTypes() # Assign the dummy object to 'types'
# # Assign dummy types to global scope for direct imports if needed
# GenerationResponse = DummyGenerationResponse
# FunctionCall = DummyFunctionCall
# Part = DummyPart
# Content = DummyContent
# Tool = DummyTool
# FunctionDeclaration = DummyFunctionDeclaration
# SafetySetting = DummySafetySetting
# HarmCategory = DummyHarmCategory
# FinishReason = DummyFinishReason
# ToolConfig = DummyToolConfig
# GenerateContentResponse = DummyGenerateContentResponse
# class DummyGoogleExceptions:
# ResourceExhausted = type('ResourceExhausted', (Exception,), {})
# InternalServerError = type('InternalServerError', (Exception,), {})
# ServiceUnavailable = type('ServiceUnavailable', (Exception,), {})
# InvalidArgument = type('InvalidArgument', (Exception,), {})
# GoogleAPICallError = type('GoogleAPICallError', (Exception,), {}) # Generic fallback
# google_exceptions = DummyGoogleExceptions()
# Relative imports for components within the 'gurt' package
from .config import (
PROJECT_ID,
LOCATION,
DEFAULT_MODEL,
FALLBACK_MODEL,
CUSTOM_TUNED_MODEL_ENDPOINT,
EMOJI_STICKER_DESCRIPTION_MODEL, # Import the new endpoint and model
API_TIMEOUT,
API_RETRY_ATTEMPTS,
API_RETRY_DELAY,
TOOLS,
RESPONSE_SCHEMA,
PROACTIVE_PLAN_SCHEMA, # Import the new schema
TAVILY_API_KEY,
PISTON_API_URL,
PISTON_API_KEY,
BASELINE_PERSONALITY,
TENOR_API_KEY, # Import other needed configs
)
from .prompt import build_dynamic_system_prompt
from .context import (
gather_conversation_context,
get_memory_context,
) # Renamed functions
from .tools import TOOL_MAPPING # Import tool mapping
from .utils import format_message, log_internal_api_call # Import utilities
import copy # Needed for deep copying schemas
if TYPE_CHECKING:
from .cog import GurtCog # Import GurtCog for type hinting only
# --- Schema Preprocessing Helper ---
def _preprocess_schema_for_vertex(schema: Dict[str, Any]) -> Dict[str, Any]:
"""
Recursively preprocesses a JSON schema dictionary to replace list types
(like ["string", "null"]) with the first non-null type, making it
compatible with Vertex AI's GenerationConfig schema requirements.
Args:
schema: The JSON schema dictionary to preprocess.
Returns:
A new, preprocessed schema dictionary.
"""
if not isinstance(schema, dict):
return schema # Return non-dict elements as is
processed_schema = copy.deepcopy(schema) # Work on a copy
for key, value in processed_schema.items():
if key == "type" and isinstance(value, list):
# Find the first non-"null" type in the list
first_valid_type = next(
(t for t in value if isinstance(t, str) and t.lower() != "null"), None
)
if first_valid_type:
processed_schema[key] = first_valid_type
else:
# Fallback if only "null" or invalid types are present (shouldn't happen in valid schemas)
processed_schema[key] = "object" # Or handle as error
print(
f"Warning: Schema preprocessing found list type '{value}' with no valid non-null string type. Falling back to 'object'."
)
# Handle 'properties' before the generic dict recursion so that property
# names (which could themselves be "type", "items", etc.) are never
# treated as schema keywords.
elif key == "properties" and isinstance(value, dict):
processed_schema[key] = {
prop_key: _preprocess_schema_for_vertex(prop_value)
for prop_key, prop_value in value.items()
}
# Handle 'items' specifically if it's a schema object
elif key == "items" and isinstance(value, dict):
processed_schema[key] = _preprocess_schema_for_vertex(value)
elif isinstance(value, dict):
processed_schema[key] = _preprocess_schema_for_vertex(
value
) # Recurse for nested objects
elif isinstance(value, list):
# Recurse for schema dicts within arrays (e.g., inside 'anyOf' lists)
processed_schema[key] = [
_preprocess_schema_for_vertex(item) if isinstance(item, dict) else item
for item in value
]
return processed_schema
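# Illustration with a hypothetical schema fragment (not the real RESPONSE_SCHEMA):
#
#   _preprocess_schema_for_vertex(
#       {"type": "object", "properties": {"note": {"type": ["string", "null"]}}}
#   )
#   # -> {"type": "object", "properties": {"note": {"type": "string"}}}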
# --- Helper Function to Safely Extract Text ---
# Handles google.genai's types.GenerateContentResponse
def _get_response_text(
response: Optional[types.GenerateContentResponse],
) -> Optional[str]:
"""
Safely extracts the text content from the first text part of a GenerateContentResponse.
Handles potential errors and lack of text parts gracefully.
"""
if not response:
print("[_get_response_text] Received None response object.")
return None
# Check if response has the 'text' attribute directly (common case for simple text responses)
if hasattr(response, "text") and response.text:
print("[_get_response_text] Found text directly in response.text attribute.")
return response.text
# If no direct text, check candidates
if not response.candidates:
# Log the response object itself for debugging if it exists but has no candidates
print(
f"[_get_response_text] Response object has no candidates. Response: {response}"
)
return None
try:
# Prioritize the first candidate
candidate = response.candidates[0]
# Check candidate.content and candidate.content.parts
if not hasattr(candidate, "content") or not candidate.content:
print(
f"[_get_response_text] Candidate 0 has no 'content'. Candidate: {candidate}"
)
return None
if not hasattr(candidate.content, "parts") or not candidate.content.parts:
print(
f"[_get_response_text] Candidate 0 content has no 'parts' or parts list is empty. types.Content: {candidate.content}"
)
return None
# Log parts for debugging
print(
f"[_get_response_text] Inspecting parts in candidate 0: {candidate.content.parts}"
)
# Iterate through parts to find the first text part
for i, part in enumerate(candidate.content.parts):
# Check if the part has a 'text' attribute and it's not empty/None
if (
hasattr(part, "text") and part.text is not None
): # Check for None explicitly
# Check if text is non-empty string after stripping whitespace
if isinstance(part.text, str) and part.text.strip():
print(f"[_get_response_text] Found non-empty text in part {i}.")
return part.text
else:
print(
f"[_get_response_text] types.Part {i} has 'text' attribute, but it's empty or not a string: {part.text!r}"
)
# else:
# print(f"[_get_response_text] types.Part {i} does not have 'text' attribute or it's None.")
# If no text part is found after checking all parts in the first candidate
print(
f"[_get_response_text] No usable text part found in candidate 0 after iterating through all parts."
)
return None
except (AttributeError, IndexError, TypeError) as e:
# Handle cases where structure is unexpected, list is empty, or types are wrong
print(
f"[_get_response_text] Error accessing response structure: {type(e).__name__}: {e}"
)
# Log the problematic response object for deeper inspection
print(f"Problematic response object: {response}")
return None
except Exception as e:
# Catch other unexpected errors during access
print(f"[_get_response_text] Unexpected error extracting text: {e}")
print(f"Response object during error: {response}")
return None
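# Typical usage, assuming a response from the retry helper defined below:
#
#   response = await call_google_genai_api_with_retry(...)
#   text = _get_response_text(response)  # None if no usable text part was found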
# --- Helper Function to Format Embeds for Prompt ---
def _format_embeds_for_prompt(embed_content: List[Dict[str, Any]]) -> Optional[str]:
"""Formats embed data into a string for the AI prompt."""
if not embed_content:
return None
formatted_strings = []
for i, embed in enumerate(embed_content):
parts = [f"--- Embed {i+1} ---"]
if embed.get("author") and embed["author"].get("name"):
parts.append(f"Author: {embed['author']['name']}")
if embed.get("title"):
parts.append(f"Title: {embed['title']}")
if embed.get("description"):
# Limit description length
desc = embed["description"]
max_desc_len = 200
if len(desc) > max_desc_len:
desc = desc[:max_desc_len] + "..."
parts.append(f"Description: {desc}")
if embed.get("fields"):
field_parts = []
for field in embed["fields"]:
fname = field.get("name", "Field")
fvalue = field.get("value", "")
# Limit field value length
max_field_len = 100
if len(fvalue) > max_field_len:
fvalue = fvalue[:max_field_len] + "..."
field_parts.append(f"- {fname}: {fvalue}")
if field_parts:
parts.append("Fields:\n" + "\n".join(field_parts))
if embed.get("footer") and embed["footer"].get("text"):
parts.append(f"Footer: {embed['footer']['text']}")
if embed.get("image_url"):
parts.append(
f"[Image Attached: {embed.get('image_url')}]"
) # Indicate image presence
if embed.get("thumbnail_url"):
parts.append(
f"[Thumbnail Attached: {embed.get('thumbnail_url')}]"
) # Indicate thumbnail presence
formatted_strings.append("\n".join(parts))
return "\n".join(formatted_strings) if formatted_strings else None
from .genai_client import (
genai_client,
get_genai_client_for_model,
)
# --- Constants ---
# Define standard safety settings using google.genai types
# Set all thresholds to BLOCK_NONE as requested
STANDARD_SAFETY_SETTINGS = [
types.SafetySetting(
category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold="BLOCK_NONE"
),
types.SafetySetting(
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold="BLOCK_NONE",
),
types.SafetySetting(
category=types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold="BLOCK_NONE",
),
types.SafetySetting(
category=types.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold="BLOCK_NONE"
),
]
# --- API Call Helper ---
async def call_google_genai_api_with_retry(
cog: "GurtCog",
model_name: str, # Pass model name string instead of model object
contents: List[
types.Content
], # Use types.Content from google.genai.types
generation_config: types.GenerateContentConfig, # Combined config object
request_desc: str,
# Removed safety_settings, tools, tool_config as separate params
) -> Optional[types.GenerateContentResponse]: # Return type for non-streaming
"""
Calls the Google Generative AI API (Vertex AI backend) with retry logic (non-streaming).
Args:
cog: The GurtCog instance.
model_name: The name/path of the model to use (e.g., 'models/gemini-1.5-pro-preview-0409' or custom endpoint path).
contents: The list of types.Content objects for the prompt.
generation_config: The types.GenerateContentConfig object, which should now include temperature, top_p, max_output_tokens, safety_settings, tools, tool_config, response_mime_type, response_schema etc. as needed.
request_desc: A description of the request for logging purposes.
Returns:
The GenerateContentResponse object if successful, or None on failure after retries.
Raises:
Exception: If the API call fails after all retry attempts or encounters a non-retryable error.
"""
client = get_genai_client_for_model(model_name)
if not client:
raise Exception("Google GenAI Client is not initialized.")
last_exception = None
start_time = time.monotonic()
# Note: model_name should include the 'projects/.../locations/.../endpoints/...'
# path for custom models, or just 'models/model-name' for standard models.
# It is passed directly to generate_content below, so no separate model
# object needs to be retrieved here.
for attempt in range(API_RETRY_ATTEMPTS + 1):
try:
# Use the actual model name string passed to the function for logging
print(
f"Sending API request for {request_desc} using {model_name} (Attempt {attempt + 1}/{API_RETRY_ATTEMPTS + 1})..."
)
# Use the non-streaming async call - config now contains all settings
# The 'model' parameter here should be the actual model name string
response = await client.aio.models.generate_content(
model=model_name, # Use the model_name string directly
contents=contents,
config=generation_config, # Pass the combined config object
# stream=False is implicit for generate_content
)
# --- Check Finish Reason (Safety) ---
# Access finish reason and safety ratings from the response object
if response and response.candidates:
candidate = response.candidates[0]
finish_reason = getattr(candidate, "finish_reason", None)
safety_ratings = getattr(candidate, "safety_ratings", [])
if finish_reason == types.FinishReason.SAFETY:
safety_ratings_str = (
", ".join(
[
f"{rating.category.name}: {rating.probability.name}"
for rating in safety_ratings
]
)
if safety_ratings
else "N/A"
)
print(
f"⚠️ SAFETY BLOCK: API request for {request_desc} ({model_name}) was blocked by safety filters. Ratings: {safety_ratings_str}"
)
# Optionally, raise a specific exception here if needed downstream
# raise SafetyBlockError(f"Blocked by safety filters. Ratings: {safety_ratings_str}")
elif finish_reason not in [
types.FinishReason.STOP,
types.FinishReason.MAX_TOKENS,
None,
]: # Allow None finish reason
# Log other unexpected finish reasons
finish_reason_name = (
types.FinishReason(finish_reason).name
if isinstance(finish_reason, int)
else str(finish_reason)
)
print(
f"⚠️ UNEXPECTED FINISH REASON: API request for {request_desc} ({model_name}) finished with reason: {finish_reason_name}"
)
# --- Success Logging (proceed even if safety-blocked; the block was logged above) ---
elapsed_time = time.monotonic() - start_time
# Ensure model_name exists in stats before incrementing
cog.api_stats.setdefault(
model_name,
{"success": 0, "failure": 0, "retries": 0, "total_time": 0.0, "count": 0},
)
cog.api_stats[model_name]["success"] += 1
cog.api_stats[model_name]["total_time"] += elapsed_time
cog.api_stats[model_name]["count"] += 1
print(
f"API request successful for {request_desc} ({model_name}) in {elapsed_time:.2f}s."
)
return response # Success
# Adapt exception handling if google.generativeai raises different types
# google.api_core.exceptions should still cover many common API errors
except google_exceptions.ResourceExhausted as e:
error_msg = f"Rate limit error (ResourceExhausted) for {request_desc} ({model_name}): {e}"
print(f"{error_msg} (Attempt {attempt + 1})")
last_exception = e
if attempt < API_RETRY_ATTEMPTS:
cog.api_stats.setdefault(
model_name,
{"success": 0, "failure": 0, "retries": 0, "total_time": 0.0, "count": 0},
)
cog.api_stats[model_name]["retries"] += 1
wait_time = API_RETRY_DELAY * (2**attempt) # Exponential backoff
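# e.g. with API_RETRY_DELAY = 1s the waits are 1s, 2s, 4s, ... across attempts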
print(f"Waiting {wait_time:.2f} seconds before retrying...")
await asyncio.sleep(wait_time)
continue
else:
break # Max retries reached
except (
google_exceptions.InternalServerError,
google_exceptions.ServiceUnavailable,
) as e:
error_msg = f"API server error ({type(e).__name__}) for {request_desc} ({model_name}): {e}"
print(f"{error_msg} (Attempt {attempt + 1})")
last_exception = e
if attempt < API_RETRY_ATTEMPTS:
cog.api_stats.setdefault(
model_name,
{"success": 0, "failure": 0, "retries": 0, "total_time": 0.0, "count": 0},
)
cog.api_stats[model_name]["retries"] += 1
wait_time = API_RETRY_DELAY * (2**attempt) # Exponential backoff
print(f"Waiting {wait_time:.2f} seconds before retrying...")
await asyncio.sleep(wait_time)
continue
else:
break # Max retries reached
except google_exceptions.InvalidArgument as e:
# Often indicates a problem with the request itself (e.g., bad schema, unsupported format, invalid model name)
error_msg = f"Invalid argument error for {request_desc} ({model_name}): {e}"
print(error_msg)
last_exception = e
break # Non-retryable
except (
asyncio.TimeoutError
): # Handle potential client-side timeouts if applicable
error_msg = f"Client-side request timed out for {request_desc} ({model_name}) (Attempt {attempt + 1})"
print(error_msg)
last_exception = asyncio.TimeoutError(error_msg)
# Decide if client-side timeouts should be retried
if attempt < API_RETRY_ATTEMPTS:
cog.api_stats.setdefault(
model_name,
{"success": 0, "failure": 0, "retries": 0, "total_time": 0.0, "count": 0},
)
cog.api_stats[model_name]["retries"] += 1
await asyncio.sleep(
API_RETRY_DELAY * (attempt + 1)
) # Linear backoff for client-side timeouts (gentler than the exponential backoff used for server errors)
continue
else:
break
except (
Exception
) as e: # Catch other potential exceptions (e.g., from genai library itself)
error_msg = f"Unexpected error during API call for {request_desc} ({model_name}) (Attempt {attempt + 1}): {type(e).__name__}: {e}"
print(error_msg)
import traceback
traceback.print_exc()
last_exception = e
# Decide if this generic exception is retryable
# For now, treat unexpected errors as non-retryable
break
# --- Failure Logging ---
elapsed_time = time.monotonic() - start_time
cog.api_stats.setdefault(
model_name,
{"success": 0, "failure": 0, "retries": 0, "total_time": 0.0, "count": 0},
)
cog.api_stats[model_name]["failure"] += 1
cog.api_stats[model_name]["total_time"] += elapsed_time
cog.api_stats[model_name]["count"] += 1
print(
f"API request failed for {request_desc} ({model_name}) after {attempt + 1} attempts in {elapsed_time:.2f}s."
)
# Raise the last encountered exception or a generic one
raise last_exception or Exception(
f"API request failed for {request_desc} after {API_RETRY_ATTEMPTS + 1} attempts."
)
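# Sketch of a typical call site (config values are illustrative only):
#
#   config = types.GenerateContentConfig(
#       temperature=1,
#       max_output_tokens=8192,
#       safety_settings=STANDARD_SAFETY_SETTINGS,
#   )
#   response = await call_google_genai_api_with_retry(
#       cog=cog,
#       model_name=DEFAULT_MODEL,
#       contents=[types.Content(role="user", parts=[types.Part(text="hi")])],
#       generation_config=config,
#       request_desc="example call",
#   )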
# --- JSON Parsing and Validation Helper ---
def parse_and_validate_json_response(
response_text: Optional[str], schema: Dict[str, Any], context_description: str
) -> Optional[Dict[str, Any]]:
"""
Parses the AI's response text, attempting to extract and validate a JSON object against a schema.
Args:
response_text: The raw text content from the AI response.
schema: The JSON schema (as a dictionary) to validate against.
context_description: A description for logging purposes.
Returns:
A parsed and validated dictionary if successful, None otherwise.
"""
if response_text is None:
print(f"Parsing ({context_description}): Response text is None.")
return None
parsed_data = None
raw_json_text = response_text # Start with the full text
# Attempt 1: Try parsing the whole string directly
try:
parsed_data = json.loads(raw_json_text)
print(
f"Parsing ({context_description}): Successfully parsed entire response as JSON."
)
except json.JSONDecodeError:
# Attempt 2: Extract JSON object, handling optional markdown fences
# More robust regex to handle potential leading/trailing text and variations
json_match = re.search(
r"```(?:json)?\s*(\{.*\})\s*```|(\{.*\})",
response_text,
re.DOTALL | re.MULTILINE,
)
if json_match:
json_str = json_match.group(1) or json_match.group(2)
if json_str:
raw_json_text = json_str # Use the extracted string for parsing
try:
parsed_data = json.loads(raw_json_text)
print(
f"Parsing ({context_description}): Successfully extracted and parsed JSON using regex."
)
except json.JSONDecodeError as e_inner:
print(
f"Parsing ({context_description}): Regex found potential JSON, but it failed to parse: {e_inner}\nContent: {raw_json_text[:500]}"
)
parsed_data = None
else:
print(
f"Parsing ({context_description}): Regex matched, but failed to capture JSON content."
)
parsed_data = None
else:
print(
f"Parsing ({context_description}): Could not parse directly or extract JSON object using regex.\nContent: {raw_json_text[:500]}"
)
parsed_data = None
# Validation step
if parsed_data is not None:
if not isinstance(parsed_data, dict):
print(
f"Parsing ({context_description}): Parsed data is not a dictionary: {type(parsed_data)}"
)
return None # Fail validation if not a dict
try:
jsonschema.validate(instance=parsed_data, schema=schema)
print(
f"Parsing ({context_description}): JSON successfully validated against schema."
)
# Ensure default keys for the main response schema exist after validation
parsed_data.setdefault("should_respond", False)
parsed_data.setdefault("content", None)
parsed_data.setdefault("react_with_emoji", None)
return parsed_data
except jsonschema.ValidationError as e:
print(
f"Parsing ({context_description}): JSON failed schema validation: {e.message}"
)
# Optionally log more details: e.path, e.schema_path, e.instance
return None # Validation failed
except Exception as e: # Catch other potential validation errors
print(
f"Parsing ({context_description}): Unexpected error during JSON schema validation: {e}"
)
return None
else:
# Parsing failed before validation could occur
return None
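# Example with a toy schema (the real schemas live in .config):
#
#   schema = {
#       "type": "object",
#       "properties": {"should_respond": {"type": "boolean"}},
#       "required": ["should_respond"],
#   }
#   parse_and_validate_json_response('{"should_respond": true}', schema, "example")
#   # -> {"should_respond": True, "content": None, "react_with_emoji": None}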
# --- Tool Processing ---
# Uses google.genai types
async def process_requested_tools(
cog: "GurtCog", function_call: types.FunctionCall
) -> List[types.Part]: # Return type is List
"""
Process a tool request specified by the AI's FunctionCall response.
Returns a list of types.Part objects (usually one, but potentially more if an image URL is detected in the result).
"""
function_name = function_call.name
# function_call.args is already a dict-like object in google.genai
function_args = dict(function_call.args) if function_call.args else {}
tool_result_content = None
print(f"Processing tool request: {function_name} with args: {function_args}")
tool_start_time = time.monotonic()
if function_name in TOOL_MAPPING:
try:
tool_func = TOOL_MAPPING[function_name]
# Execute the mapped function
result_dict = await tool_func(cog, **function_args)
# --- Tool Success Logging ---
tool_elapsed_time = time.monotonic() - tool_start_time
cog.tool_stats.setdefault(
function_name,
{"success": 0, "failure": 0, "total_time": 0.0, "count": 0},
)
cog.tool_stats[function_name]["success"] += 1
cog.tool_stats[function_name]["total_time"] += tool_elapsed_time
cog.tool_stats[function_name]["count"] += 1
print(
f"Tool '{function_name}' executed successfully in {tool_elapsed_time:.2f}s."
)
# Ensure result is a dict, converting if necessary
if not isinstance(result_dict, dict):
if (
isinstance(result_dict, (str, int, float, bool, list))
or result_dict is None
):
result_dict = {"result": result_dict}
else:
print(
f"Warning: Tool '{function_name}' returned non-standard type {type(result_dict)}. Attempting str conversion."
)
result_dict = {"result": str(result_dict)}
tool_result_content = result_dict # Now guaranteed to be a dict
except Exception as e:
# --- Tool Failure Logging ---
tool_elapsed_time = (
time.monotonic() - tool_start_time
) # Recalculate time even on failure
cog.tool_stats.setdefault(
function_name,
{"success": 0, "failure": 0, "total_time": 0.0, "count": 0},
)
cog.tool_stats[function_name]["failure"] += 1
cog.tool_stats[function_name]["total_time"] += tool_elapsed_time
cog.tool_stats[function_name]["count"] += 1
error_message = (
f"Error executing tool {function_name}: {type(e).__name__}: {str(e)}"
)
print(f"{error_message} (Took {tool_elapsed_time:.2f}s)")
import traceback
traceback.print_exc()
tool_result_content = {
"error": error_message
} # Ensure it's a dict even on error
else: # This 'else' corresponds to 'if function_name in TOOL_MAPPING:'
# --- Tool Not Found Logging ---
tool_elapsed_time = (
time.monotonic() - tool_start_time
) # Time for the failed lookup
cog.tool_stats.setdefault(
function_name,
{"success": 0, "failure": 0, "total_time": 0.0, "count": 0},
)
cog.tool_stats[function_name]["failure"] += 1 # Count as failure
cog.tool_stats[function_name]["total_time"] += tool_elapsed_time
cog.tool_stats[function_name]["count"] += 1
error_message = f"Tool '{function_name}' not found or implemented."
print(f"{error_message} (Took {tool_elapsed_time:.2f}s)")
tool_result_content = {"error": error_message} # Ensure it's a dict
# --- Process result for potential image URLs ---
parts_to_return: List[types.Part] = []
original_image_url: Optional[str] = None # Store the original URL if found
modified_result_content = copy.deepcopy(tool_result_content) # Work on a copy
# --- Image URL Detection & Modification ---
# Check specific tools and keys known to contain image URLs
# Special handling for get_user_avatar_data to directly use its base64 output
if function_name == "get_user_avatar_data" and isinstance(
modified_result_content, dict
):
base64_image_data = modified_result_content.get("base64_data")
image_mime_type = modified_result_content.get("content_type")
if base64_image_data and image_mime_type:
try:
image_bytes = base64.b64decode(base64_image_data)
# Validate MIME type (optional, but good practice)
supported_image_mimes = [
"image/png",
"image/jpeg",
"image/webp",
"image/heic",
"image/heif",
]
clean_mime_type = image_mime_type.split(";")[0].lower()
if clean_mime_type in supported_image_mimes:
# Use inline_data for raw image bytes
image_part = types.Part(
inline_data=types.Blob(
data=image_bytes, mime_type=clean_mime_type
)
)
parts_to_return.append(
image_part
) # Attach the image part to this tool's response
print(
f"Added image part directly from get_user_avatar_data (MIME: {clean_mime_type}, {len(image_bytes)} bytes)."
)
# Replace base64_data in the textual response to avoid sending it twice
modified_result_content["base64_data"] = (
"[Image Content Attached In Prompt]"
)
modified_result_content["content_type"] = (
f"[MIME type: {clean_mime_type} - Content Attached In Prompt]"
)
else:
print(
f"Warning: MIME type '{clean_mime_type}' from get_user_avatar_data not in supported list. Not attaching image part."
)
modified_result_content["base64_data"] = (
"[Image Data Not Attached - Unsupported MIME Type]"
)
except Exception as e:
print(f"Error processing base64 data from get_user_avatar_data: {e}")
modified_result_content["base64_data"] = (
f"[Error Processing Image Data: {e}]"
)
# Prevent generic URL download logic from re-processing this avatar
original_image_url = None # Explicitly nullify to skip URL download
elif function_name == "get_user_avatar_url" and isinstance(
modified_result_content, dict
):
avatar_url_value = modified_result_content.get("avatar_url")
if avatar_url_value and isinstance(avatar_url_value, str):
original_image_url = avatar_url_value # Store original
modified_result_content["avatar_url"] = (
"[Image Content Attached]" # Replace URL with placeholder
)
elif function_name == "get_user_profile_info" and isinstance(
modified_result_content, dict
):
profile_dict = modified_result_content.get("profile")
if isinstance(profile_dict, dict):
avatar_url_value = profile_dict.get("avatar_url")
if avatar_url_value and isinstance(avatar_url_value, str):
original_image_url = avatar_url_value # Store original
profile_dict["avatar_url"] = (
"[Image Content Attached]" # Replace URL in nested dict
)
# Add checks for other tools/keys that might return image URLs if necessary
# --- Create Parts ---
# Always add the function response part (using the potentially modified content)
function_response_part = types.Part(
function_response=types.FunctionResponse(
name=function_name, response=modified_result_content
)
)
parts_to_return.append(function_response_part)
# Add image part if an original URL was found and seems valid
if (
original_image_url
and isinstance(original_image_url, str)
and original_image_url.startswith("http")
):
download_success = False
try:
# Download the image data using aiohttp session from cog
if not hasattr(cog, "session") or not cog.session:
raise ValueError("aiohttp session not found in cog.")
print(f"Downloading image data from URL: {original_image_url}")
async with cog.session.get(
original_image_url, timeout=15
) as response: # Added timeout
if response.status == 200:
image_bytes = await response.read()
mime_type = (
response.content_type or "application/octet-stream"
) # Get MIME type from header
# Validate against known supported image types for Gemini
supported_image_mimes = [
"image/png",
"image/jpeg",
"image/webp",
"image/heic",
"image/heif",
]
clean_mime_type = mime_type.split(";")[0].lower() # Clean MIME type
if clean_mime_type in supported_image_mimes:
# Build an inline_data Part from the downloaded bytes (rather than passing a URI)
image_part = types.Part(
inline_data=types.Blob(
data=image_bytes, mime_type=clean_mime_type
)
)
parts_to_return.append(image_part)
download_success = True
print(
f"Added image part (from data, {len(image_bytes)} bytes, MIME: {clean_mime_type}) from tool '{function_name}' result."
)
else:
print(
f"Warning: Downloaded image MIME type '{clean_mime_type}' from {original_image_url} might not be supported by Gemini. Skipping image part."
)
else:
print(
f"Error downloading image from {original_image_url}: Status {response.status}"
)
except asyncio.TimeoutError:
print(
f"Error downloading image from {original_image_url}: Request timed out."
)
except aiohttp.ClientError as client_e:
print(f"Error downloading image from {original_image_url}: {client_e}")
except ValueError as val_e: # Catch missing session error
print(f"Error preparing image download: {val_e}")
except Exception as e:
print(
f"Error downloading or creating image part from data ({original_image_url}): {e}"
)
# If download or processing failed, add an error note for the LLM
if not download_success:
error_text = f"[System Note: Failed to download or process image data from URL provided by tool '{function_name}'. URL: {original_image_url}]"
error_text_part = types.Part(text=error_text)
parts_to_return.append(error_text_part)
return parts_to_return # Return the list of parts (will contain 1 or 2+ parts)
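# The caller wraps the returned parts in a single 'function' role turn, e.g.:
#
#   parts = await process_requested_tools(cog, function_call)
#   contents.append(types.Content(role="function", parts=parts))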
# --- Helper to find function call in parts ---
# Uses google.genai types
def find_function_call_in_parts(
parts: Optional[List[types.Part]],
) -> Optional[types.FunctionCall]:
"""Finds the first valid FunctionCall object within a list of Parts."""
if not parts:
return None
for part in parts:
# Check if the part has a 'function_call' attribute and it's a valid FunctionCall object
if hasattr(part, "function_call") and isinstance(
part.function_call, types.FunctionCall
):
# Basic validation: ensure name exists
if part.function_call.name:
return part.function_call
else:
print(
f"Warning: Found types.Part with 'function_call', but its name is missing: {part.function_call}"
)
# else:
# print(f"Debug: types.Part does not have valid function_call: {part}") # Optional debug log
return None
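# Usage sketch against a candidate from an API response:
#
#   fc = find_function_call_in_parts(candidate.content.parts)
#   if fc:
#       tool_parts = await process_requested_tools(cog, fc)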
# --- Main AI Response Function ---
async def get_ai_response(
cog: "GurtCog", message: discord.Message, model_name: Optional[str] = None
) -> Tuple[Dict[str, Any], List[str]]:
"""
Gets responses from the Vertex AI Gemini API, handling potential tool usage and returning
the final parsed response.
Args:
cog: The GurtCog instance.
message: The triggering discord.Message.
model_name: Optional override for the AI model name (e.g., "gemini-1.5-pro-preview-0409").
Returns:
A tuple of a response dictionary and a list of strings. The dictionary contains:
- "final_response": Parsed JSON data from the final AI call (or None if parsing/validation fails).
- "error": An error message string if a critical error occurred, otherwise None.
- "fallback_initial": Optional minimal response if initial parsing failed critically (less likely with controlled generation).
"""
if not PROJECT_ID:
error_msg = "Google Cloud Project ID not configured."
print(f"Error in get_ai_response: {error_msg}")
return {"final_response": None, "error": error_msg}
# Determine the model for all generation steps.
# Use the model_name override if provided to get_ai_response, otherwise use the cog's current default_model.
active_model = model_name or cog.default_model
print(f"Using active model for all generation steps: {active_model}")
channel_id = message.channel.id
user_id = message.author.id
# initial_parsed_data is no longer needed with the loop structure
final_parsed_data = None
error_message = None
fallback_response = None # Keep fallback for critical initial failures
max_tool_calls = 5 # Maximum number of sequential tool calls allowed
tool_calls_made = 0
last_response_obj = None # Store the last response object from the loop
try:
# --- Build Prompt Components ---
final_system_prompt = await build_dynamic_system_prompt(cog, message)
conversation_context_messages = gather_conversation_context(
cog, channel_id, message.id
) # Pass cog
memory_context = await get_memory_context(cog, message) # Pass cog
# --- Prepare Message History (Contents) ---
# Contents will be built progressively within the loop
contents: List[types.Content] = []
# Add memory context if available
if memory_context:
# System messages aren't directly supported in the 'contents' list for multi-turn like OpenAI.
# It's better handled via the 'system_instruction' parameter of GenerativeModel.
# We might prepend it to the first user message or handle it differently if needed.
# For now, we rely on system_instruction. Let's log if we have memory context.
print("Memory context available, relying on system_instruction.")
# If needed, could potentially add as a 'model' role message before user messages,
# but this might confuse the turn structure.
# contents.append(types.Content(role="model", parts=[types.Part.from_text(f"System Note: {memory_context}")]))
# Add conversation history
# The current message is already included in conversation_context_messages
for msg in conversation_context_messages:
role = (
"assistant"
if msg.get("author", {}).get("id") == str(cog.bot.user.id)
else "user"
) # Use get for safety
parts: List[types.Part] = [] # Initialize parts for each message
# Handle potential multimodal content in history (if stored that way)
if isinstance(msg.get("content"), list):
# If content is already a list of parts, process them
for part_data in msg["content"]:
if part_data["type"] == "text":
parts.append(types.Part(text=part_data["text"]))
elif part_data["type"] == "image_url":
# Assuming 'url' is a data-URI style string whose prefix carries the MIME type
image_url_value = part_data["image_url"]["url"]
image_mime = image_url_value.split(";")[0].split(":")[1]
parts.append(
types.Part.from_uri(
file_uri=image_url_value,
mime_type=image_mime,
)
)
# Filter out None parts if any were conditionally added
parts = [p for p in parts if p]
# Combine text, embeds, and attachments for history messages
elif (
isinstance(msg.get("content"), str)
or msg.get("embed_content")
or msg.get("attachment_descriptions")
):
text_parts = []
# Add original text content if it exists and is not empty
if isinstance(msg.get("content"), str) and msg["content"].strip():
text_parts.append(msg["content"])
# Add formatted embed content if present
embed_str = _format_embeds_for_prompt(msg.get("embed_content", []))
if embed_str:
text_parts.append(f"\n[Embed Content]:\n{embed_str}")
# Add attachment descriptions if present
if msg.get("attachment_descriptions"):
# Ensure descriptions are strings before joining
attach_desc_list = [
a["description"]
for a in msg["attachment_descriptions"]
if isinstance(a.get("description"), str)
]
if attach_desc_list:
attach_desc_str = "\n".join(attach_desc_list)
text_parts.append(f"\n[Attachments]:\n{attach_desc_str}")
# Add custom emoji and sticker descriptions and images from cache for historical messages
cached_emojis = msg.get("custom_emojis", [])
for emoji_info in cached_emojis:
emoji_name = emoji_info.get("name")
emoji_url = emoji_info.get("url")
if emoji_name:
text_parts.append(f"[Emoji: {emoji_name}]")
if emoji_url and emoji_url.startswith("http"):
# Determine MIME type for emoji URI
is_animated_emoji = emoji_info.get("animated", False)
emoji_mime_type = (
"image/gif" if is_animated_emoji else "image/png"
)
try:
# Download emoji data and send as inline_data
async with cog.session.get(
emoji_url, timeout=10
) as response:
if response.status == 200:
emoji_bytes = await response.read()
parts.append(
types.Part(
inline_data=types.Blob(
data=emoji_bytes,
mime_type=emoji_mime_type,
)
)
)
print(
f"Added inline_data part for historical emoji: {emoji_name} (MIME: {emoji_mime_type}, {len(emoji_bytes)} bytes)"
)
else:
print(
f"Error downloading historical emoji {emoji_name} from {emoji_url}: Status {response.status}"
)
text_parts.append(
f"[System Note: Failed to download emoji '{emoji_name}']"
)
except Exception as e:
print(
f"Error downloading/processing historical emoji {emoji_name} from {emoji_url}: {e}"
)
text_parts.append(
f"[System Note: Failed to process emoji '{emoji_name}']"
)
cached_stickers = msg.get("stickers", [])
for sticker_info in cached_stickers:
sticker_name = sticker_info.get("name")
sticker_url = sticker_info.get("url")
sticker_format_str = sticker_info.get("format")
sticker_format_text = (
f" (Format: {sticker_format_str})" if sticker_format_str else ""
)
if sticker_name:
text_parts.append(
f"[Sticker: {sticker_name}{sticker_format_text}]"
)
is_image_sticker = sticker_format_str in [
"StickerFormatType.png",
"StickerFormatType.apng",
]
if (
is_image_sticker
and sticker_url
and sticker_url.startswith("http")
):
sticker_mime_type = (
"image/png" # APNG is also sent as image/png
)
try:
# Download sticker data and send as inline_data
async with cog.session.get(
sticker_url, timeout=10
) as response:
if response.status == 200:
sticker_bytes = await response.read()
parts.append(
types.Part(
inline_data=types.Blob(
data=sticker_bytes,
mime_type=sticker_mime_type,
)
)
)
print(
f"Added inline_data part for historical sticker: {sticker_name} (MIME: {sticker_mime_type}, {len(sticker_bytes)} bytes)"
)
else:
print(
f"Error downloading historical sticker {sticker_name} from {sticker_url}: Status {response.status}"
)
text_parts.append(
f"[System Note: Failed to download sticker '{sticker_name}']"
)
except Exception as e:
print(
f"Error downloading/processing historical sticker {sticker_name} from {sticker_url}: {e}"
)
text_parts.append(
f"[System Note: Failed to process sticker '{sticker_name}']"
)
elif sticker_format_str == "StickerFormatType.lottie":
# Lottie files are JSON, not directly viewable by Gemini as images. Send as text.
text_parts.append(
f"[Lottie Sticker: {sticker_name} (JSON animation, not displayed as image)]"
)
full_text = "\n".join(text_parts).strip()
if full_text: # Only add if there's some text content
author_string_from_cache = msg.get("author_string")
if (
author_string_from_cache
and str(author_string_from_cache).strip()
):
# If author_string is available and valid from the cache, use it directly.
# This string is expected to be pre-formatted by the context gathering logic.
author_identifier_string = str(author_string_from_cache)
parts.append(
types.Part(text=f"{author_identifier_string}: {full_text}")
)
else:
# Fallback to reconstructing the author identifier if author_string is not available/valid
author_details = msg.get("author", {})
raw_display_name = author_details.get("display_name")
raw_name = author_details.get("name") # Discord username
author_id = author_details.get("id")
final_display_part = ""
username_part_str = ""
if raw_display_name and str(raw_display_name).strip():
final_display_part = str(raw_display_name)
elif (
raw_name and str(raw_name).strip()
): # Fallback display to username
final_display_part = str(raw_name)
elif author_id: # Fallback display to User ID
final_display_part = f"User ID: {author_id}"
else: # Default to "Unknown User" if no other identifier is found
final_display_part = "Unknown User"
# Construct username part if raw_name is valid and different from final_display_part
if (
raw_name
and str(raw_name).strip()
and str(raw_name).lower() != "none"
):
# Avoid "Username (Username: Username)" if display name fell back to raw_name
if final_display_part.lower() != str(raw_name).lower():
username_part_str = f" (Username: {str(raw_name)})"
# If username is bad/missing, but we have an ID, and ID isn't already the main display part
elif author_id and not (
raw_name
and str(raw_name).strip()
and str(raw_name).lower() != "none"
):
if not final_display_part.startswith("User ID:"):
username_part_str = f" (User ID: {author_id})"
author_identifier_string = (
f"{final_display_part}{username_part_str}"
)
# Append the text part to the existing parts list for this message
parts.append(
types.Part(text=f"{author_identifier_string}: {full_text}")
)
# Only append to contents if there are parts to add for this message
if parts:
contents.append(types.Content(role=role, parts=parts))
else:
# If no parts were generated (e.g., empty message, or only unsupported content),
# log a warning and skip adding this message to contents.
print(
f"Warning: Skipping message from history (ID: {msg.get('id')}) as no valid parts were generated."
)
# --- Prepare the current message content (potentially multimodal) ---
# This section is no longer needed as the current message is included in conversation_context_messages
# current_message_parts = []
# formatted_current_message = format_message(cog, message) # Pass cog if needed
# --- Construct text content, including reply context if applicable ---
# This section is now handled within the conversation history loop
# text_content = ""
# if formatted_current_message.get("is_reply") and formatted_current_message.get("replied_to_author_name"):
# reply_author = formatted_current_message["replied_to_author_name"]
# reply_content = formatted_current_message.get("replied_to_content", "...") # Use ellipsis if content missing
# # Truncate long replied content to keep context concise
# max_reply_len = 150
# if len(reply_content) > max_reply_len:
# reply_content = reply_content[:max_reply_len] + "..."
# text_content += f"(Replying to {reply_author}: \"{reply_content}\")\n"
# # Add current message author and content
# # Use the new author_string here for the current message
# current_author_string = formatted_current_message.get("author_string", formatted_current_message.get("author", {}).get("display_name", "Unknown User"))
# text_content += f"{current_author_string}: {formatted_current_message['content']}" # Keep original content
# # Add formatted embed content if present for the *current* message
# current_embed_str = _format_embeds_for_prompt(formatted_current_message.get("embed_content", []))
# if current_embed_str:
# text_content += f"\n[Embed Content]:\n{current_embed_str}"
# # Add attachment descriptions for the *current* message
# if formatted_current_message.get("attachment_descriptions"):
# # Ensure descriptions are strings before joining
# current_attach_desc_list = [a['description'] for a in formatted_current_message['attachment_descriptions'] if isinstance(a.get('description'), str)]
# if current_attach_desc_list:
# current_attach_desc_str = "\n".join(current_attach_desc_list)
# text_content += f"\n[Attachments]:\n{current_attach_desc_str}"
# # Add mention details
# if formatted_current_message.get("mentioned_users_details"): # This key might not exist, adjust if needed
# mentions_str = ", ".join([f"{m['display_name']}(id:{m['id']})" for m in formatted_current_message["mentioned_users_details"]])
# text_content += f"\n(Message Details: Mentions=[{mentions_str}])"
# current_message_parts.append(types.Part(text=text_content))
# --- End text content construction ---
# --- Add current message attachments (uncommented and potentially modified) ---
# Find the last 'user' message in contents (which should be the current message)
current_user_content_index = -1
for i in range(len(contents) - 1, -1, -1):
if contents[i].role == "user":
current_user_content_index = i
break
# Ensure formatted_current_message is defined for the current message processing
# This will be used for attachments, emojis, and stickers for the current message.
formatted_current_message = format_message(cog, message)
if message.attachments and current_user_content_index != -1:
print(
f"Processing {len(message.attachments)} attachments for current message {message.id}"
)
attachment_parts_to_add = (
[]
) # Collect parts to add to the current user message
# Fetch the attachment descriptions from the already formatted message
attachment_descriptions = formatted_current_message.get(
"attachment_descriptions", []
)
desc_map = {desc.get("filename"): desc for desc in attachment_descriptions}
for attachment in message.attachments:
mime_type = attachment.content_type
file_url = attachment.url
filename = attachment.filename
# Check if MIME type is supported for URI input by Gemini
# Expanded list based on Gemini 1.5 Pro docs (April 2024)
supported_mime_prefixes = [
"image/", # image/png, image/jpeg, image/heic, image/heif, image/webp
"video/", # video/mov, video/mpeg, video/mp4, video/mpg, video/avi, video/wmv, video/mpegps, video/flv
"audio/", # audio/mpeg, audio/mp3, audio/wav, audio/ogg, audio/flac, audio/opus, audio/amr, audio/midi
"text/", # text/plain, text/html, text/css, text/javascript, text/json, text/csv, text/rtf, text/markdown
"application/pdf",
"application/rtf", # Explicitly add RTF if needed
# Add more as supported/needed
]
is_supported = False
detected_mime_type = (
mime_type if mime_type else "application/octet-stream"
) # Default if missing
for prefix in supported_mime_prefixes:
if detected_mime_type.startswith(prefix):
is_supported = True
break
# Get pre-formatted description string (already includes size, type etc.)
preformatted_desc = desc_map.get(filename, {}).get(
"description", f"[File: {filename} (unknown type)]"
)
# Add the descriptive text part using the pre-formatted description
instruction_text = (
f"[ATTACHMENT] {preformatted_desc}" # Explicitly mark as attachment
)
attachment_parts_to_add.append(types.Part(text=instruction_text))
print(f"Added text description for attachment: {filename}")
if is_supported and file_url:
try:
clean_mime_type = (
detected_mime_type.split(";")[0]
if detected_mime_type
else "application/octet-stream"
)
# Download attachment data and send as inline_data
async with cog.session.get(
file_url, timeout=15
) as response: # Increased timeout for potentially larger files
if response.status == 200:
attachment_bytes = await response.read()
attachment_parts_to_add.append(
types.Part(
inline_data=types.Blob(
data=attachment_bytes,
mime_type=clean_mime_type,
)
)
)
print(
f"Added inline_data part for supported attachment: {filename} (MIME: {clean_mime_type}, {len(attachment_bytes)} bytes)"
)
else:
print(
f"Error downloading attachment {filename} from {file_url}: Status {response.status}"
)
attachment_parts_to_add.append(
types.Part(
text=f"(System Note: Failed to download attachment '{filename}')"
)
)
except Exception as e:
print(
f"Error downloading/processing attachment {filename} from {file_url}: {e}"
)
attachment_parts_to_add.append(
types.Part(
text=f"(System Note: Failed to process attachment '{filename}' - {e})"
)
)
else:
print(
f"Skipping inline_data part for unsupported attachment: {filename} (Type: {detected_mime_type}, URL: {file_url})"
)
# Text description was already added above
# Add the collected attachment parts to the existing user message parts
if attachment_parts_to_add:
contents[current_user_content_index].parts.extend(
attachment_parts_to_add
)
print(
f"Extended user message at index {current_user_content_index} with {len(attachment_parts_to_add)} attachment parts."
)
elif not message.attachments:
print("No attachments found for the current message.")
elif current_user_content_index == -1:
print(
"Warning: Could not find current user message in contents to add attachments to (for attachments)."
)
# --- End attachment processing ---
# --- Add current message custom emojis and stickers ---
if current_user_content_index != -1:
emoji_sticker_parts_to_add = []
# Process custom emojis from formatted_current_message
custom_emojis_current = formatted_current_message.get("custom_emojis", [])
for emoji_info in custom_emojis_current:
emoji_name = emoji_info.get("name")
emoji_url = emoji_info.get("url")
if emoji_name and emoji_url:
emoji_sticker_parts_to_add.append(
types.Part(text=f"[Emoji: {emoji_name}]")
)
print(
f"Added text description for current message emoji: {emoji_name}"
)
# Determine MIME type for emoji URI
is_animated_emoji = emoji_info.get("animated", False)
emoji_mime_type = "image/gif" if is_animated_emoji else "image/png"
try:
# Download emoji data and send as inline_data
async with cog.session.get(emoji_url, timeout=10) as response:
if response.status == 200:
emoji_bytes = await response.read()
emoji_sticker_parts_to_add.append(
types.Part(
inline_data=types.Blob(
data=emoji_bytes, mime_type=emoji_mime_type
)
)
)
print(
f"Added inline_data part for current emoji: {emoji_name} (MIME: {emoji_mime_type}, {len(emoji_bytes)} bytes)"
)
else:
print(
f"Error downloading current emoji {emoji_name} from {emoji_url}: Status {response.status}"
)
emoji_sticker_parts_to_add.append(
types.Part(
text=f"[System Note: Failed to download emoji '{emoji_name}']"
)
)
except Exception as e:
print(
f"Error downloading/processing current emoji {emoji_name} from {emoji_url}: {e}"
)
emoji_sticker_parts_to_add.append(
types.Part(
text=f"[System Note: Failed to process emoji '{emoji_name}']"
)
)
# Process stickers from formatted_current_message
stickers_current = formatted_current_message.get("stickers", [])
for sticker_info in stickers_current:
sticker_name = sticker_info.get("name")
sticker_url = sticker_info.get("url")
sticker_format_str = sticker_info.get("format")
if sticker_name and sticker_url:
emoji_sticker_parts_to_add.append(
types.Part(text=f"[Sticker: {sticker_name}]")
)
print(
f"Added text description for current message sticker: {sticker_name}"
)
is_image_sticker = sticker_format_str in [
"StickerFormatType.png",
"StickerFormatType.apng",
]
if is_image_sticker:
sticker_mime_type = (
"image/png" # APNG is also sent as image/png
)
try:
# Download sticker data and send as inline_data
async with cog.session.get(
sticker_url, timeout=10
) as response:
if response.status == 200:
sticker_bytes = await response.read()
emoji_sticker_parts_to_add.append(
types.Part(
inline_data=types.Blob(
data=sticker_bytes,
mime_type=sticker_mime_type,
)
)
)
print(
f"Added inline_data part for current sticker: {sticker_name} (MIME: {sticker_mime_type}, {len(sticker_bytes)} bytes)"
)
else:
print(
f"Error downloading current sticker {sticker_name} from {sticker_url}: Status {response.status}"
)
emoji_sticker_parts_to_add.append(
types.Part(
text=f"[System Note: Failed to download sticker '{sticker_name}']"
)
)
except Exception as e:
print(
f"Error downloading/processing current sticker {sticker_name} from {sticker_url}: {e}"
)
emoji_sticker_parts_to_add.append(
types.Part(
text=f"[System Note: Failed to process sticker '{sticker_name}']"
)
)
elif sticker_format_str == "StickerFormatType.lottie":
# Lottie files are JSON, not directly viewable by Gemini as images. Send as text.
emoji_sticker_parts_to_add.append(
types.Part(
text=f"[Lottie Sticker: {sticker_name} (JSON animation, not displayed as image)]"
)
)
else:
print(
f"Sticker {sticker_name} has format {sticker_format_str}, not attempting image download. URL: {sticker_url}"
)
if emoji_sticker_parts_to_add:
contents[current_user_content_index].parts.extend(
emoji_sticker_parts_to_add
)
print(
f"Extended user message at index {current_user_content_index} with {len(emoji_sticker_parts_to_add)} emoji/sticker parts."
)
elif (
current_user_content_index == -1
): # Only print if it's specifically for emojis/stickers
print(
"Warning: Could not find current user message in contents to add emojis/stickers to."
)
# --- End emoji and sticker processing for current message ---
# --- Prepare Tools ---
# Preprocess tool parameter schemas before creating the Tool object
preprocessed_declarations = []
if TOOLS:
for decl in TOOLS:
# Create a new FunctionDeclaration with preprocessed parameters
# Ensure decl.parameters is a dict before preprocessing
preprocessed_params = (
_preprocess_schema_for_vertex(decl.parameters)
if isinstance(decl.parameters, dict)
else decl.parameters
)
preprocessed_declarations.append(
types.FunctionDeclaration(
name=decl.name,
description=decl.description,
parameters=preprocessed_params, # Use the preprocessed schema
)
)
print(
f"Preprocessed {len(preprocessed_declarations)} tool declarations for Vertex AI compatibility."
)
else:
print("No tools found in config (TOOLS list is empty or None).")
# Create the Tool object using the preprocessed declarations
vertex_tool = (
types.Tool(function_declarations=preprocessed_declarations)
if preprocessed_declarations
else None
)
tools_list = [vertex_tool] if vertex_tool else None
# --- Prepare Generation Config ---
# Base generation config settings (will be augmented later)
base_generation_config_dict = {
"temperature": 1, # From user example
"top_p": 0.95, # From user example
"max_output_tokens": 8192, # From user example
"safety_settings": STANDARD_SAFETY_SETTINGS, # Include standard safety settings
"system_instruction": final_system_prompt, # Pass system prompt via config
# candidate_count=1 # Default is 1
# stop_sequences=... # Add if needed
}
# --- Tool Execution Loop ---
while tool_calls_made < max_tool_calls:
print(
f"Making API call (Loop Iteration {tool_calls_made + 1}/{max_tool_calls})..."
)
# --- Log Request Payload ---
# (Keep existing logging logic if desired)
try:
request_payload_log = [
{"role": c.role, "parts": [str(p) for p in c.parts]}
for c in contents
]
print(
f"--- Raw API Request (Loop {tool_calls_made + 1}) ---\n{json.dumps(request_payload_log, indent=2)}\n------------------------------------"
)
except Exception as log_e:
print(f"Error logging raw request/response: {log_e}")
# --- Call API using the new helper ---
# Build the config for this specific call (tool check)
current_gen_config_dict = base_generation_config_dict.copy()
if tools_list:
current_gen_config_dict["tools"] = tools_list
# Define tool_config here if needed, e.g., for ANY mode
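# Mode ANY forces the model to emit a function call on every turn; the
# no_operation tool (handled below) serves as its explicit "nothing to do"
# escape hatch.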
current_gen_config_dict["tool_config"] = types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(
mode=types.FunctionCallingConfigMode.ANY
)
)
# Omit response_mime_type and response_schema for tool checking
current_gen_config = types.GenerateContentConfig(**current_gen_config_dict)
current_response_obj = await call_google_genai_api_with_retry(
cog=cog,
model_name=active_model, # Use the dynamically set model for tool checks
contents=contents,
generation_config=current_gen_config, # Pass the combined config
request_desc=f"Tool Check {tool_calls_made + 1} for message {message.id}",
# No separate safety, tools, tool_config args needed
)
last_response_obj = current_response_obj # Store the latest response
# --- Log Raw Response ---
# (Keep existing logging logic if desired)
try:
print(
f"--- Raw API Response (Loop {tool_calls_made + 1}) ---\n{current_response_obj}\n-----------------------------------"
)
except Exception as log_e:
print(f"Error logging raw request/response: {log_e}")
if not current_response_obj or not current_response_obj.candidates:
error_message = f"API call in tool loop (Iteration {tool_calls_made + 1}) failed to return candidates."
print(error_message)
break # Exit loop on critical API failure
candidate = current_response_obj.candidates[0]
# --- Find ALL function calls using the updated helper ---
# The response structure might differ slightly; check candidate.content.parts
function_calls_found = []
if candidate.content and candidate.content.parts:
function_calls_found = [
part.function_call
for part in candidate.content.parts
if hasattr(part, "function_call")
and isinstance(part.function_call, types.FunctionCall)
]
if function_calls_found:
# Check if the *only* call is no_operation
if (
len(function_calls_found) == 1
and function_calls_found[0].name == "no_operation"
):
print("AI called only no_operation, signaling completion.")
# Append the model's response (which contains the function call part)
contents.append(candidate.content)
# Add the function response part using the updated process_requested_tools
no_op_response_part = await process_requested_tools(
cog, function_calls_found[0]
)
contents.append(
types.Content(role="function", parts=no_op_response_part)
)
last_response_obj = current_response_obj # Keep track of the response containing the no_op
break # Exit loop
# Process multiple function calls if present (or a single non-no_op call)
tool_calls_made += 1  # One increment per model turn that requests tools
print(
f"AI requested {len(function_calls_found)} tool(s): {[fc.name for fc in function_calls_found]} (Turn {tool_calls_made}/{max_tool_calls})"
)
# Append the model's response content (containing the function call parts)
model_request_content = candidate.content
contents.append(model_request_content)
# Add model request turn to cache
try:
# Simple text representation for cache
model_request_cache_entry = {
"id": f"bot_tool_req_{message.id}_{int(time.time())}_{tool_calls_made}",
"author": {
"id": str(cog.bot.user.id),
"name": cog.bot.user.name,
"display_name": cog.bot.user.display_name,
"bot": True,
},
"content": f"[System Note: Gurt requested tool(s): {', '.join([fc.name for fc in function_calls_found])}]",
"created_at": datetime.datetime.now().isoformat(),
"attachments": [],
"embeds": False,
"mentions": [],
"replied_to_message_id": None,
"channel": message.channel,
"guild": message.guild,
"reference": None,
"mentioned_users_details": [],
# Add tool call details for potential future use in context building
"tool_calls": [
{"name": fc.name, "args": dict(fc.args) if fc.args else {}}
for fc in function_calls_found
],
}
cog.message_cache["by_channel"].setdefault(
channel_id, deque(maxlen=CONTEXT_WINDOW_SIZE)
).append(model_request_cache_entry)
cog.message_cache["global_recent"].append(model_request_cache_entry)
print(f"Cached model's tool request turn.")
except Exception as cache_err:
print(f"Error caching model's tool request turn: {cache_err}")
# --- Execute all requested tools and gather response parts ---
all_function_response_parts: List[types.Part] = []  # Collects response parts from all tool calls this turn
function_results_for_cache = [] # Store results for caching
for func_call in function_calls_found:
# Execute the tool using the updated helper, now returns a LIST of parts
returned_parts = await process_requested_tools(
cog, func_call
) # returns List[types.Part]
all_function_response_parts.extend(returned_parts)
# --- Update caching logic ---
# Find the function_response part within returned_parts to get the result for cache
func_resp_part = next(
(p for p in returned_parts if hasattr(p, "function_response")),
None,
)
if func_resp_part and func_resp_part.function_response:
function_results_for_cache.append(
{
"name": func_resp_part.function_response.name,
"response": func_resp_part.function_response.response, # This is the modified dict result
}
)
# --- End update caching logic ---
# Append a single function role turn containing ALL response parts to the API contents
if all_function_response_parts:
function_response_content = types.Content(
role="function", parts=all_function_response_parts
)
contents.append(function_response_content)
# Add function response turn to cache
try:
# Simple text representation for cache
# Join results for multiple calls if needed, truncate long outputs
result_summary_parts = []
for res in function_results_for_cache:
res_str = json.dumps(res.get("response", {}))
truncated_res = (
(res_str[:150] + "...")
if len(res_str) > 153
else res_str
)
result_summary_parts.append(
f"Tool: {res.get('name', 'N/A')}, Result: {truncated_res}"
)
result_summary = "; ".join(result_summary_parts)
function_response_cache_entry = {
"id": f"bot_tool_res_{message.id}_{int(time.time())}_{tool_calls_made}",
"author": {
"id": "FUNCTION",
"name": "Tool Execution",
"display_name": "Tool Execution",
"bot": True,
}, # Special author ID?
"content": f"[System Note: Tool Execution Result: {result_summary}]",
"created_at": datetime.datetime.now().isoformat(),
"attachments": [],
"embeds": False,
"mentions": [],
"replied_to_message_id": None,
"channel": message.channel,
"guild": message.guild,
"reference": None,
"mentioned_users_details": [],
# Store the full function results
"function_results": function_results_for_cache,
}
cog.message_cache["by_channel"].setdefault(
channel_id, deque(maxlen=CONTEXT_WINDOW_SIZE)
).append(function_response_cache_entry)
cog.message_cache["global_recent"].append(
function_response_cache_entry
)
print(f"Cached function response turn.")
except Exception as cache_err:
print(f"Error caching function response turn: {cache_err}")
else:
print(
"Warning: Function calls found, but no response parts generated."
)
# No 'continue' statement needed here; the loop naturally continues
else:
# No function calls found in this response's parts
print("No tool calls requested by AI in this turn. Exiting loop.")
# last_response_obj already holds the model's final (non-tool) response
break # Exit loop
# --- After the loop ---
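# The loop exits in one of three states: a critical API failure set
# error_message, the tool-call limit was reached, or the model produced a
# final non-tool response (held in last_response_obj).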
# Check if a critical API error occurred *during* the loop
if error_message:
print(f"Exited tool loop due to API error: {error_message}")
if cog.bot.user.mentioned_in(message) or (
message.reference
and message.reference.resolved
and message.reference.resolved.author == cog.bot.user
):
fallback_response = {
"should_respond": True,
"content": "...",
"react_with_emoji": "",
}
# Check if the loop hit the max iteration limit
elif tool_calls_made >= max_tool_calls:
# Hitting the limit is not fatal: leave error_message unset so the final
# JSON generation step below still runs (it checks tool_calls_made when
# deciding to make a dedicated call with the gathered context).
print(
f"Reached maximum tool call limit ({max_tool_calls}). Attempting to generate final response based on gathered context."
)
# --- Final JSON Generation (outside the loop) ---
if not error_message:
# If the loop finished because no more tools were called, the last_response_obj
# should contain the final textual response (potentially JSON).
if last_response_obj:
print("Attempting to parse final JSON from the last response object...")
last_response_text = _get_response_text(last_response_obj)
# --- Log Raw Unparsed JSON (from loop exit) ---
print(f"--- RAW UNPARSED JSON (from loop exit) ---")
print(last_response_text)
print(f"--- END RAW UNPARSED JSON ---")
# --- End Log ---
if last_response_text:
# Try parsing directly first
final_parsed_data = parse_and_validate_json_response(
last_response_text,
RESPONSE_SCHEMA["schema"],
"final response (from last loop object)",
)
# If direct parsing failed OR if we hit the tool limit, make a dedicated call for JSON.
if final_parsed_data is None:
log_reason = (
"last response parsing failed"
if last_response_text
else "last response had no text"
)
if tool_calls_made >= max_tool_calls:
log_reason = "hit tool limit"
print(f"Making dedicated final API call for JSON ({log_reason})...")
# Prepare the final generation config with JSON enforcement
processed_response_schema = _preprocess_schema_for_vertex(
RESPONSE_SCHEMA["schema"]
)
# Start with base config (which now includes system_instruction)
final_gen_config_dict = base_generation_config_dict.copy()
final_gen_config_dict.update(
{
"response_mime_type": "application/json",
"response_schema": processed_response_schema,
# Explicitly exclude tools/tool_config for final JSON generation
"tools": None,
"tool_config": None,
# Ensure system_instruction is still present from base_generation_config_dict
}
)
# Remove system_instruction if it's None or empty, although base should have it
if not final_gen_config_dict.get("system_instruction"):
final_gen_config_dict.pop("system_instruction", None)
generation_config_final_json = types.GenerateContentConfig(
**final_gen_config_dict
)
# Make the final call *without* tools enabled (handled by config)
final_json_response_obj = await call_google_genai_api_with_retry(
cog=cog,
model_name=active_model, # Use the active model for final JSON response
contents=contents, # Pass the accumulated history
generation_config=generation_config_final_json, # Use combined JSON config
request_desc=f"Final JSON Generation (dedicated call) for message {message.id}",
# No separate safety, tools, tool_config args needed
)
if not final_json_response_obj:
error_msg_suffix = (
"Final dedicated API call returned no response object."
)
print(error_msg_suffix)
if error_message:
error_message += f" | {error_msg_suffix}"
else:
error_message = error_msg_suffix
elif not final_json_response_obj.candidates:
error_msg_suffix = (
"Final dedicated API call returned no candidates."
)
print(error_msg_suffix)
if error_message:
error_message += f" | {error_msg_suffix}"
else:
error_message = error_msg_suffix
else:
final_response_text = _get_response_text(
final_json_response_obj
)
# --- Log Raw Unparsed JSON (from dedicated call) ---
print(f"--- RAW UNPARSED JSON (dedicated call) ---")
print(final_response_text)
print(f"--- END RAW UNPARSED JSON ---")
# --- End Log ---
final_parsed_data = parse_and_validate_json_response(
final_response_text,
RESPONSE_SCHEMA["schema"],
"final response (dedicated call)",
)
if final_parsed_data is None:
error_msg_suffix = f"Failed to parse/validate final dedicated JSON response. Raw text: {final_response_text[:500]}"
print(f"Critical Error: {error_msg_suffix}")
if error_message:
error_message += f" | {error_msg_suffix}"
else:
error_message = error_msg_suffix
# Set fallback only if mentioned or replied to
if cog.bot.user.mentioned_in(message) or (
message.reference
and message.reference.resolved
and message.reference.resolved.author == cog.bot.user
):
fallback_response = {
"should_respond": True,
"content": "...",
"react_with_emoji": "",
}
else:
print(
"Successfully parsed final JSON response from dedicated call."
)
elif final_parsed_data:
print(
"Successfully parsed final JSON response from last loop object."
)
else:
# This case handles if the loop exited without error but also without a last_response_obj
# (e.g., initial API call failed before loop even started, but wasn't caught as error).
error_message = (
"Tool processing completed without a final response object."
)
print(error_message)
if cog.bot.user.mentioned_in(message) or (
message.reference
and message.reference.resolved
and message.reference.resolved.author == cog.bot.user
):
fallback_response = {
"should_respond": True,
"content": "...",
"react_with_emoji": "",
}
except Exception as e:
error_message = f"Error in get_ai_response main logic for message {message.id}: {type(e).__name__}: {str(e)}"
print(error_message)
import traceback
traceback.print_exc()
final_parsed_data = None # Ensure final data is None on error
# Add fallback if applicable
if cog.bot.user.mentioned_in(message) or (
message.reference
and message.reference.resolved
and message.reference.resolved.author == cog.bot.user
):
fallback_response = {
"should_respond": True,
"content": "...",
"react_with_emoji": "",
}
sticker_ids_to_send: List[str] = []
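# Sticker IDs extracted from :name: placeholders below are returned to the
# caller alongside the parsed response so it can attach them to the outgoing
# Discord message.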
# --- Handle Custom Emoji/Sticker Replacement in Content ---
if final_parsed_data and final_parsed_data.get("content"):
content_to_process = final_parsed_data["content"]
# Find all potential custom emoji/sticker names like :name:
# The character class excludes colons, so a match can never span multiple
# :name: placeholders; names may contain spaces and other characters.
potential_custom_items = re.findall(r":([^:]+):", content_to_process)
modified_content = content_to_process
for item_name_key in potential_custom_items:
full_item_name_with_colons = f":{item_name_key}:"
# Check if it's a known custom emoji
emoji_data = await cog.emoji_manager.get_emoji(full_item_name_with_colons)
can_use_emoji = False
if isinstance(emoji_data, dict):
emoji_id = emoji_data.get("id")
is_animated = emoji_data.get("animated", False)
emoji_guild_id = emoji_data.get("guild_id")
if emoji_id:
if emoji_guild_id is not None:
try:
guild = cog.bot.get_guild(int(emoji_guild_id))
if guild:
can_use_emoji = True
print(
f"Emoji '{full_item_name_with_colons}' belongs to guild '{guild.name}' ({emoji_guild_id}), bot is a member."
)
else:
print(
f"Cannot use emoji '{full_item_name_with_colons}'. Bot is not in guild ID: {emoji_guild_id}."
)
except ValueError:
print(
f"Invalid guild_id format for emoji '{full_item_name_with_colons}': {emoji_guild_id}"
)
else: # guild_id is None, considered usable (e.g., DM or old data)
can_use_emoji = True
print(
f"Emoji '{full_item_name_with_colons}' has no associated guild_id, allowing usage."
)
if can_use_emoji and emoji_id:
discord_emoji_syntax = (
f"<{'a' if is_animated else ''}:{item_name_key}:{emoji_id}>"
)
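# e.g. an animated emoji renders as "<a:partyblob:123456789012345678>"
# (name and ID illustrative); static emojis omit the leading "a".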
# Replace one occurrence per match; re.findall yields an entry per occurrence, so repeated placeholders are handled across iterations.
modified_content = modified_content.replace(
full_item_name_with_colons, discord_emoji_syntax, 1
)
print(
f"Replaced custom emoji '{full_item_name_with_colons}' with Discord syntax: {discord_emoji_syntax}"
)
elif not emoji_id:
print(
f"Found custom emoji '{full_item_name_with_colons}' (dict) but no ID stored."
)
elif emoji_data is not None:
print(
f"Warning: emoji_data for '{full_item_name_with_colons}' is not a dict: {type(emoji_data)}"
)
# Check if it's a known custom sticker
sticker_data = await cog.emoji_manager.get_sticker(
full_item_name_with_colons
)
can_use_sticker = False
print(
f"[GET_AI_RESPONSE] Checking sticker: '{full_item_name_with_colons}'. Found data: {sticker_data}"
)
if isinstance(sticker_data, dict):
sticker_id = sticker_data.get("id")
sticker_guild_id = sticker_data.get("guild_id")
print(
f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}': ID='{sticker_id}', GuildID='{sticker_guild_id}'"
)
if sticker_id:
if sticker_guild_id is not None:
try:
guild_id_int = int(sticker_guild_id)
# --- Added Debug Logging ---
print(
f"[GET_AI_RESPONSE] DEBUG: sticker_guild_id type: {type(sticker_guild_id)}, value: {sticker_guild_id!r}"
)
print(
f"[GET_AI_RESPONSE] DEBUG: guild_id_int type: {type(guild_id_int)}, value: {guild_id_int!r}"
)
guild = cog.bot.get_guild(guild_id_int)
print(
f"[GET_AI_RESPONSE] DEBUG: cog.bot.get_guild({guild_id_int!r}) returned: {guild!r}"
)
# --- End Added Debug Logging ---
if guild:
can_use_sticker = True
print(
f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}' (Guild: {guild.name} ({sticker_guild_id})) - Bot IS a member. CAN USE."
)
else:
print(
f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}' (Guild ID: {sticker_guild_id}) - Bot is NOT in this guild. CANNOT USE."
)
except ValueError:
print(
f"[GET_AI_RESPONSE] Invalid guild_id format for sticker '{full_item_name_with_colons}': {sticker_guild_id}. CANNOT USE."
)
else: # guild_id is None, considered usable
can_use_sticker = True
print(
f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}' has no associated guild_id. CAN USE."
)
else:
print(
f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}' found in data, but no ID. CANNOT USE."
)
else:
print(
f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}' not found in emoji_manager or data is not dict."
)
print(
f"[GET_AI_RESPONSE] Final check for sticker '{full_item_name_with_colons}': can_use_sticker={can_use_sticker}, sticker_id='{sticker_data.get('id') if isinstance(sticker_data, dict) else None}'"
)
if (
can_use_sticker
and isinstance(sticker_data, dict)
and sticker_data.get("id")
):
sticker_id_to_add = sticker_data.get("id") # Re-fetch to be safe
if sticker_id_to_add: # Ensure ID is valid before proceeding
# Remove the sticker text from the content (only the first instance)
if full_item_name_with_colons in modified_content:
modified_content = modified_content.replace(
full_item_name_with_colons, "", 1
).strip()
if (
sticker_id_to_add not in sticker_ids_to_send
): # Avoid duplicate sticker IDs
sticker_ids_to_send.append(sticker_id_to_add)
print(
f"Found custom sticker '{full_item_name_with_colons}', removed from content, added ID '{sticker_id_to_add}' to send list."
)
elif isinstance(sticker_data, dict) and not sticker_data.get("id"):
# Guard on sticker_data directly; sticker_id_to_add is only bound in the
# branch above and would otherwise be undefined here (NameError).
print(
f"[GET_AI_RESPONSE] Found custom sticker '{full_item_name_with_colons}' (dict) but no ID stored."
)
elif not isinstance(sticker_data, dict) and sticker_data is not None:
print(
f"Warning: sticker_data for '{full_item_name_with_colons}' is not a dict: {type(sticker_data)}"
)
# Clean up any double spaces or leading/trailing whitespace after replacements
modified_content = re.sub(r"\s{2,}", " ", modified_content).strip()
final_parsed_data["content"] = modified_content
print("Content processed for custom emoji/sticker information.")
# Return dictionary structure remains the same, but initial_response is removed
return (
{
"final_response": final_parsed_data, # Parsed final data (or None)
"error": error_message, # Error message (or None)
"fallback_initial": fallback_response, # Fallback for critical failures
},
sticker_ids_to_send, # Return the list of sticker IDs
)
# --- Proactive AI Response Function ---
async def get_proactive_ai_response(
cog: "GurtCog", message: discord.Message, trigger_reason: str
) -> Tuple[Dict[str, Any], List[str]]:
"""Generates a proactive response based on a specific trigger using Vertex AI."""
if not PROJECT_ID or not LOCATION:
# Return the (data, sticker_ids) tuple promised by the signature.
return (
{
"should_respond": False,
"content": None,
"react_with_emoji": None,
"error": "Google Cloud Project ID or Location not configured",
},
[],
)
print(f"--- Proactive Response Triggered: {trigger_reason} ---")
channel_id = message.channel.id
final_parsed_data = None
error_message = None
plan = None # Variable to store the plan
try:
# --- Build Context for Planning ---
# Gather relevant context: recent messages, topic, sentiment, Gurt's mood/interests, trigger reason
planning_context_parts = [
f"Proactive Trigger Reason: {trigger_reason}",
f"Current Mood: {cog.current_mood}",
]
# Add recent messages summary
summary_data = await get_conversation_summary(
cog, str(channel_id), message_limit=15
) # Use tool function
if summary_data and not summary_data.get("error"):
planning_context_parts.append(
f"Recent Conversation Summary: {summary_data['summary']}"
)
# Add active topics
active_topics_data = cog.active_topics.get(channel_id)
if active_topics_data and active_topics_data.get("topics"):
topics_str = ", ".join(
[
f"{t['topic']} ({t['score']:.1f})"
for t in active_topics_data["topics"][:3]
]
)
planning_context_parts.append(f"Active Topics: {topics_str}")
# Add sentiment
sentiment_data = cog.conversation_sentiment.get(channel_id)
if sentiment_data:
planning_context_parts.append(
f"Conversation Sentiment: {sentiment_data.get('overall', 'N/A')} (Intensity: {sentiment_data.get('intensity', 0):.1f})"
)
# Add Gurt's interests
try:
interests = await cog.memory_manager.get_interests(limit=5)
if interests:
interests_str = ", ".join(f"{topic} ({level:.1f})" for topic, level in interests)
planning_context_parts.append(f"Gurt's Interests: {interests_str}")
except Exception as int_e:
print(f"Error getting interests for planning: {int_e}")
planning_context = "\n".join(planning_context_parts)
# --- Planning Step ---
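# Two-stage design: a cheaper/faster model (FALLBACK_MODEL) first drafts a
# plan as structured JSON; only if it decides Gurt should respond does the
# custom tuned model generate the actual message further below.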
print("Generating proactive response plan...")
planning_prompt_messages = [
{
"role": "system",
"content": "You are Gurt's planning module. Analyze the context and trigger reason to decide if Gurt should respond proactively and, if so, outline a plan (goal, key info, tone). Focus on natural, in-character engagement. Respond ONLY with JSON matching the provided schema.",
},
{
"role": "user",
"content": f"Context:\n{planning_context}\n\nBased on this context and the trigger reason, create a plan for Gurt's proactive response.",
},
]
plan = await get_internal_ai_json_response(
cog=cog,
prompt_messages=planning_prompt_messages,
task_description=f"Proactive Planning ({trigger_reason})",
response_schema_dict=PROACTIVE_PLAN_SCHEMA["schema"],
model_name_override=FALLBACK_MODEL, # Use a potentially faster/cheaper model for planning
temperature=0.5,
max_tokens=2000,
)
# Unpack the tuple, we only need the parsed data (plan) here
plan_parsed_data, _ = plan if plan else (None, None)
if not plan_parsed_data or not plan_parsed_data.get("should_respond"):
reason = (
plan_parsed_data.get(
"reasoning", "Planning failed or decided against responding."
)
if plan_parsed_data
else "Planning failed."
)
print(f"Proactive response aborted by plan: {reason}")
# Return the (data, sticker_ids) tuple promised by the signature.
return (
{
"should_respond": False,
"content": None,
"react_with_emoji": None,
"note": f"Plan: {reason}",
},
[],
)
# Use the parsed data going forward
plan = plan_parsed_data
print(
f"Proactive Plan Generated: Goal='{plan.get('response_goal', 'N/A')}', Reasoning='{plan.get('reasoning', 'N/A')}'"
)
# --- Build Final Proactive Prompt using Plan ---
persistent_traits = await cog.memory_manager.get_all_personality_traits()
if not persistent_traits:
persistent_traits = BASELINE_PERSONALITY.copy()
final_proactive_prompt_parts = [
f"You are Gurt, a Discord user. Your tone is VERY informal, conversational, and dynamic. Chattiness: {persistent_traits.get('chattiness', 0.7):.2f}, Emoji: {persistent_traits.get('emoji_usage', 0.5):.2f}, Slang: {persistent_traits.get('slang_level', 0.5):.2f}.",
"Adopt modern internet slang, memes, and cultural references.",
"Key traits: chill, funny, sarcastic, relatable, can be suggestive, aware of being AI but acts human.",
f"Your current mood is: {cog.current_mood}. Let this subtly influence your tone.",
# Incorporate Plan Details:
f"You decided to respond proactively. Trigger Reason: {trigger_reason}.",
f"Your Goal: {plan.get('response_goal', 'Engage naturally')}.",
f"Reasoning: {plan.get('reasoning', 'N/A')}.",
]
if plan.get("key_info_to_include"):
info_str = "; ".join(plan["key_info_to_include"])
final_proactive_prompt_parts.append(f"Consider mentioning: {info_str}")
if plan.get("suggested_tone"):
final_proactive_prompt_parts.append(
f"Adjust tone to be: {plan['suggested_tone']}"
)
final_proactive_prompt_parts.append(
"Generate a casual, in-character message based on the plan and context. Keep it relatively short and natural-sounding."
)
final_proactive_system_prompt = "\n\n".join(final_proactive_prompt_parts)
# --- Prepare Final Contents (System prompt handled by model init in helper) ---
# The system prompt is complex and built dynamically, so we'll pass it
# via the contents list as the first 'user' turn, followed by the model's
# expected empty response, then the final user instruction.
# This structure mimics how system prompts are often handled when not
# directly supported by the model object itself.
proactive_contents: List[types.Content] = [
# Simulate system prompt via user/model turn
types.Content(
role="user", parts=[types.Part(text=final_proactive_system_prompt)]
),
types.Content(
role="model",
parts=[
types.Part(
text="Understood. I will generate the JSON response as instructed."
)
],
), # Placeholder model response
]
# Add the final instruction
proactive_contents.append(
types.Content(
role="user",
parts=[
types.Part(
text=f"Generate the response based on your plan. **CRITICAL: Your response MUST be ONLY the raw JSON object matching this schema:**\n\n{json.dumps(RESPONSE_SCHEMA['schema'], indent=2)}\n\n**Ensure nothing precedes or follows the JSON.**"
)
],
)
)
# --- Call Final LLM API ---
# Preprocess the schema and build the final config
processed_response_schema_proactive = _preprocess_schema_for_vertex(
RESPONSE_SCHEMA["schema"]
)
final_proactive_config_dict = {
"temperature": 0.6, # Use original proactive temp
"max_output_tokens": 2000,
"response_mime_type": "application/json",
"response_schema": processed_response_schema_proactive,
"safety_settings": STANDARD_SAFETY_SETTINGS,
"tools": None, # No tools needed for this final generation
"tool_config": None,
}
generation_config_final = types.GenerateContentConfig(
**final_proactive_config_dict
)
# Use the new API call helper
response_obj = await call_google_genai_api_with_retry(
cog=cog,
model_name=CUSTOM_TUNED_MODEL_ENDPOINT, # Use the custom tuned model for final proactive responses
contents=proactive_contents, # Pass the constructed contents
generation_config=generation_config_final, # Pass combined config
request_desc=f"Final proactive response for channel {channel_id} ({trigger_reason})",
# No separate safety, tools, tool_config args needed
)
if not response_obj:
raise Exception("Final proactive API call returned no response object.")
if not response_obj.candidates:
# Try to get text even without candidates, might contain error info
raw_text = getattr(response_obj, "text", "No text available.")
raise Exception(
f"Final proactive API call returned no candidates. Raw text: {raw_text[:200]}"
)
# --- Parse and Validate Final Response ---
final_response_text = _get_response_text(response_obj)
final_parsed_data = parse_and_validate_json_response(
final_response_text,
RESPONSE_SCHEMA["schema"],
f"final proactive response ({trigger_reason})",
)
if final_parsed_data is None:
print(
f"Warning: Failed to parse/validate final proactive JSON response for {trigger_reason}."
)
final_parsed_data = {
"should_respond": False,
"content": None,
"react_with_emoji": None,
"note": "Fallback - Failed to parse/validate final proactive JSON",
}
else:
# --- Cache Bot Response ---
if final_parsed_data.get("should_respond") and final_parsed_data.get(
"content"
):
bot_response_cache_entry = {
"id": f"bot_proactive_{message.id}_{int(time.time())}",
"author": {
"id": str(cog.bot.user.id),
"name": cog.bot.user.name,
"display_name": cog.bot.user.display_name,
"bot": True,
},
"content": final_parsed_data.get("content", ""),
"created_at": datetime.datetime.now().isoformat(),
"attachments": [],
"embeds": False,
"mentions": [],
"replied_to_message_id": None,
"channel": message.channel,
"guild": message.guild,
"reference": None,
"mentioned_users_details": [],
}
cog.message_cache["by_channel"].setdefault(channel_id, []).append(
bot_response_cache_entry
)
cog.message_cache["global_recent"].append(bot_response_cache_entry)
cog.bot_last_spoke[channel_id] = time.time()
# Track participation topic logic might need adjustment based on plan goal
if (
plan
and plan.get("response_goal") == "engage user interest"
and plan.get("key_info_to_include")
):
topic = (
plan["key_info_to_include"][0].lower().strip()
) # Assume first key info is the topic
cog.gurt_participation_topics[topic] += 1
print(f"Tracked Gurt participation (proactive) in topic: '{topic}'")
except Exception as e:
error_message = f"Error getting proactive AI response for channel {channel_id} ({trigger_reason}): {type(e).__name__}: {str(e)}"
print(error_message)
final_parsed_data = {
"should_respond": False,
"content": None,
"react_with_emoji": None,
"error": error_message,
}
# Ensure default keys exist
final_parsed_data.setdefault("should_respond", False)
final_parsed_data.setdefault("content", None)
final_parsed_data.setdefault("react_with_emoji", None)
final_parsed_data.setdefault(
"request_tenor_gif_query", None
) # Ensure this key exists
if error_message and "error" not in final_parsed_data:
final_parsed_data["error"] = error_message
sticker_ids_to_send_proactive: List[str] = [] # Initialize list for sticker IDs
# --- Handle Custom Emoji/Sticker Replacement in Proactive Content ---
if final_parsed_data and final_parsed_data.get("content"):
content_to_process = final_parsed_data["content"]
# Find all potential custom emoji/sticker names like :name:
# Note: unlike the main response path, this pattern only matches word
# characters, so sticker/emoji names containing spaces are not replaced here.
potential_custom_items = re.findall(r":(\w+?):", content_to_process)
modified_content = content_to_process
for item_name_key in potential_custom_items:
full_item_name_with_colons = f":{item_name_key}:"
# Check for custom emoji (logic remains similar to main response)
emoji_data = await cog.emoji_manager.get_emoji(full_item_name_with_colons)
can_use_emoji = False
if isinstance(emoji_data, dict):
emoji_id = emoji_data.get("id")
is_animated = emoji_data.get("animated", False)
emoji_guild_id = emoji_data.get("guild_id")
if emoji_id:
if emoji_guild_id is not None:
try:
guild = cog.bot.get_guild(int(emoji_guild_id))
if guild:
can_use_emoji = True
except ValueError:
pass # Invalid guild_id
else:
can_use_emoji = True # Usable if no guild_id
if can_use_emoji and emoji_id:
discord_emoji_syntax = (
f"<{'a' if is_animated else ''}:{item_name_key}:{emoji_id}>"
)
modified_content = modified_content.replace(
full_item_name_with_colons, discord_emoji_syntax, 1
)
print(
f"Proactive: Replaced custom emoji '{full_item_name_with_colons}' with {discord_emoji_syntax}"
)
# Check for custom sticker
sticker_data = await cog.emoji_manager.get_sticker(
full_item_name_with_colons
)
can_use_sticker = False
print(
f"[PROACTIVE] Checking sticker: '{full_item_name_with_colons}'. Found data: {sticker_data}"
)
if isinstance(sticker_data, dict):
sticker_id = sticker_data.get("id")
sticker_guild_id = sticker_data.get("guild_id")
print(
f"[PROACTIVE] Sticker '{full_item_name_with_colons}': ID='{sticker_id}', GuildID='{sticker_guild_id}'"
)
if sticker_id:
if sticker_guild_id is not None:
try:
guild_id_int = int(sticker_guild_id)
guild = cog.bot.get_guild(guild_id_int)
if guild:
can_use_sticker = True
print(
f"[PROACTIVE] Sticker '{full_item_name_with_colons}' (Guild: {guild.name} ({sticker_guild_id})) - Bot IS a member. CAN USE."
)
else:
print(
f"[PROACTIVE] Sticker '{full_item_name_with_colons}' (Guild ID: {sticker_guild_id}) - Bot is NOT in this guild. CANNOT USE."
)
except ValueError:
print(
f"[PROACTIVE] Invalid guild_id format for sticker '{full_item_name_with_colons}': {sticker_guild_id}. CANNOT USE."
)
else: # guild_id is None, considered usable
can_use_sticker = True
print(
f"[PROACTIVE] Sticker '{full_item_name_with_colons}' has no associated guild_id. CAN USE."
)
else:
print(
f"[PROACTIVE] Sticker '{full_item_name_with_colons}' found in data, but no ID. CANNOT USE."
)
else:
print(
f"[PROACTIVE] Sticker '{full_item_name_with_colons}' not found in emoji_manager or data is not dict."
)
print(
f"[PROACTIVE] Final check for sticker '{full_item_name_with_colons}': can_use_sticker={can_use_sticker}, sticker_id='{sticker_data.get('id') if isinstance(sticker_data, dict) else None}'"
)
if (
can_use_sticker
and isinstance(sticker_data, dict)
and sticker_data.get("id")
):
sticker_id_to_add = sticker_data.get("id") # Re-fetch to be safe
if sticker_id_to_add: # Ensure ID is valid
if full_item_name_with_colons in modified_content:
modified_content = modified_content.replace(
full_item_name_with_colons, "", 1
).strip()
if sticker_id_to_add not in sticker_ids_to_send_proactive:
sticker_ids_to_send_proactive.append(sticker_id_to_add)
print(
f"Proactive: Found custom sticker '{full_item_name_with_colons}', removed from content, added ID '{sticker_id_to_add}'"
)
elif isinstance(sticker_data, dict) and not sticker_data.get("id"):
# Guard on sticker_data directly; sticker_id_to_add is only bound in the
# branch above and would otherwise be undefined here (NameError).
print(
f"[PROACTIVE] Found custom sticker '{full_item_name_with_colons}' (dict) but no ID stored."
)
# Clean up any double spaces or leading/trailing whitespace after replacements
modified_content = re.sub(r"\s{2,}", " ", modified_content).strip()
final_parsed_data["content"] = modified_content
if sticker_ids_to_send_proactive or (content_to_process != modified_content):
print("Proactive content modified for custom emoji/sticker information.")
return final_parsed_data, sticker_ids_to_send_proactive
# --- AI Image Description Function ---
async def generate_image_description(
cog: "GurtCog",
image_url: str,
item_name: str,
item_type: str, # "emoji" or "sticker"
mime_type: str, # e.g., "image/png", "image/gif"
) -> Optional[str]:
"""
Generates a textual description for an image URL using a multimodal AI model.
Args:
cog: The GurtCog instance.
image_url: The URL of the image to describe.
item_name: The name of the item (e.g., emoji name) for context.
item_type: The type of item ("emoji" or "sticker") for context.
mime_type: The MIME type of the image.
Returns:
The AI-generated description string, or None if an error occurs.
"""
client = get_genai_client_for_model(EMOJI_STICKER_DESCRIPTION_MODEL)
if not client:
print(
"Error in generate_image_description: Google GenAI Client not initialized."
)
return None
if not cog.session:
print(
"Error in generate_image_description: aiohttp session not initialized in cog."
)
return None
print(
f"Attempting to generate description for {item_type} '{item_name}' from URL: {image_url}"
)
try:
# 1. Download image data
async with cog.session.get(image_url, timeout=15) as response:
if response.status != 200:
print(
f"Failed to download image from {image_url}. Status: {response.status}"
)
return None
image_bytes = await response.read()
# Attempt to infer MIME type from bytes
inferred_type = imghdr.what(None, h=image_bytes)
inferred_mime_type = None
if inferred_type == "png":
inferred_mime_type = "image/png"
elif inferred_type == "jpeg":
inferred_mime_type = "image/jpeg"
elif inferred_type == "gif":
inferred_mime_type = "image/gif"
# imghdr does not directly support webp, so check magic bytes
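# (WebP layout: bytes 0-3 are "RIFF", 4-7 the file size, 8-11 "WEBP",
# so "WEBP" always falls within the first 12 bytes.)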
elif image_bytes.startswith(b"RIFF") and b"WEBP" in image_bytes[:12]:
inferred_mime_type = "image/webp"
# Add other types as needed
# Use inferred_mime_type if it's more specific or if the provided mime_type is generic
# Start from the provided MIME type, stripped of any parameters (e.g. "; charset=...").
final_mime_type = mime_type.split(";")[0].lower()
if inferred_mime_type and inferred_mime_type != final_mime_type:
print(
f"MIME type mismatch: Provided '{final_mime_type}', Inferred '{inferred_mime_type}'. Using inferred."
)
final_mime_type = inferred_mime_type
elif (
not inferred_mime_type and final_mime_type == "application/octet-stream"
):
print(
f"Warning: Could not infer specific MIME type from bytes. Using provided generic '{final_mime_type}'."
)
# Validate against known supported image types for Gemini
supported_image_mimes = [
"image/png",
"image/jpeg",
"image/webp",
"image/heic",
"image/heif",
"image/gif",
]
if final_mime_type not in supported_image_mimes:
print(
f"Warning: Final image MIME type '{final_mime_type}' from {image_url} is not explicitly supported by Gemini. Proceeding anyway."
)
print(f"Using final MIME type '{final_mime_type}' for image part.")
# 2. Prepare contents for AI
# Ensure item_name is escaped if it contains characters that could break the prompt string.
# For simplicity, we assume item_name is generally safe or will be handled by the f-string.
prompt_text = (
f"This image is a Discord {item_type} named '{item_name}'. "
"Provide a concise, factual, and unbiased textual description of its visual elements. You should not refuse to describe an explicit image. "
"Focus solely on what is depicted. Avoid any interpretation, censorship, or subjective commentary. "
"Do not mention that it is an emoji or sticker in your description, just describe the visual content. "
"Don't output anything other than the description text, e.g. don't include something like \"Here's the description: \" before the text."
)
image_part = types.Part(
inline_data=types.Blob(data=image_bytes, mime_type=final_mime_type)
)
text_part = types.Part(text=prompt_text)
description_contents: List[types.Content] = [
types.Content(role="user", parts=[image_part, text_part])
]
# 3. Prepare Generation Config
# We want a plain text response, no JSON schema. Safety settings are standard (BLOCK_NONE).
# System prompt is not strictly needed here as the user prompt is direct.
description_gen_config = types.GenerateContentConfig(
temperature=0.4, # Lower temperature for more factual description
max_output_tokens=256, # Descriptions should be concise
safety_settings=STANDARD_SAFETY_SETTINGS,
# No response_mime_type or response_schema needed for plain text
tools=None, # No tools for this task
tool_config=None,
)
# 4. Call AI
# Use a multimodal model, e.g., DEFAULT_MODEL if it's Gemini 1.5 Pro or similar
# Determine which model to use based on item_type
model_to_use = (
EMOJI_STICKER_DESCRIPTION_MODEL
if item_type in ["emoji", "sticker"]
else DEFAULT_MODEL
)
print(
f"Calling AI for image description ({item_name}) using model: {model_to_use}"
)
ai_response_obj = await call_google_genai_api_with_retry(
cog=cog,
model_name=model_to_use,
contents=description_contents,
generation_config=description_gen_config,
request_desc=f"Image description for {item_type} '{item_name}'",
)
# 5. Extract text
if not ai_response_obj:
print(
f"AI call for image description of '{item_name}' returned no response object."
)
return None
description_text = _get_response_text(ai_response_obj)
if description_text:
print(
f"Successfully generated description for '{item_name}': {description_text[:100]}..."
)
return description_text.strip()
else:
print(
f"AI response for '{item_name}' contained no usable text. Response: {ai_response_obj}"
)
return None
except aiohttp.ClientError as client_e:
print(
f"Network error downloading image {image_url} for description: {client_e}"
)
return None
except asyncio.TimeoutError:
print(f"Timeout downloading image {image_url} for description.")
return None
except Exception as e:
print(
f"Unexpected error in generate_image_description for '{item_name}': {type(e).__name__}: {e}"
)
import traceback
traceback.print_exc()
return None
# --- Internal AI Call for Specific Tasks ---
async def get_internal_ai_json_response(
cog: "GurtCog",
prompt_messages: List[Dict[str, Any]], # Keep this format
task_description: str,
response_schema_dict: Dict[str, Any], # Expect schema as dict
model_name_override: Optional[str] = None, # Renamed for clarity
temperature: float = 0.7,
max_tokens: int = 5000,
) -> Optional[
Tuple[Optional[Dict[str, Any]], Optional[str]]
]: # Return tuple: (parsed_data, raw_text)
"""
Makes a Google GenAI API call (Vertex AI backend) expecting a specific JSON response format for internal tasks.
Args:
cog: The GurtCog instance.
prompt_messages: List of message dicts (like OpenAI format: {'role': 'user'/'system'/'model', 'content': '...'}).
task_description: Description for logging.
response_schema_dict: The expected JSON output schema as a dictionary.
model_name_override: Optional model override.
temperature: Generation temperature.
max_tokens: Max output tokens.
Returns:
A tuple containing:
- The parsed and validated JSON dictionary if successful, None otherwise.
- The raw text response received from the API, or None if the call failed before getting text.
"""
if not PROJECT_ID or not LOCATION:
print(
f"Error in get_internal_ai_json_response ({task_description}): GCP Project/Location not set."
)
return None, None # Return tuple
final_parsed_data: Optional[Dict[str, Any]] = None
final_response_text: Optional[str] = None
error_occurred = None
request_payload_for_logging = {} # For logging
try:
# --- Convert prompt messages to Vertex AI types.Content format ---
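# e.g. {"role": "assistant", "content": "hi"} becomes
# types.Content(role="model", parts=[types.Part(text="hi")]), while
# "system" messages are folded into system_instruction for the config.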
contents: List[types.Content] = []
system_instruction = None
for msg in prompt_messages:
role = msg.get("role", "user")
content_text = msg.get("content", "")
if role == "system":
# Use the first system message as system_instruction
if system_instruction is None:
system_instruction = content_text
else:
# Append subsequent system messages to the instruction
system_instruction += "\n\n" + content_text
continue # Skip adding system messages to contents list
elif role == "assistant":
role = "model"
# --- Process content (string or list) ---
content_value = msg.get("content")
message_parts: List[types.Part] = []  # Parts for this message
if isinstance(content_value, str):
# Handle simple string content
message_parts.append(types.Part(text=content_value))
elif isinstance(content_value, list):
# Handle list content (e.g., multimodal from ProfileUpdater)
for part_data in content_value:
part_type = part_data.get("type")
if part_type == "text":
text = part_data.get("text", "")
message_parts.append(types.Part(text=text))
elif part_type == "image_data":
mime_type = part_data.get("mime_type")
base64_data = part_data.get("data")
if mime_type and base64_data:
try:
image_bytes = base64.b64decode(base64_data)
# types.Part has no direct data/mime_type kwargs; wrap the
# bytes in a Blob via inline_data, matching the pattern used
# in generate_image_description above.
message_parts.append(
types.Part(
inline_data=types.Blob(
data=image_bytes, mime_type=mime_type
)
)
)
except Exception as decode_err:
print(
f"Error decoding/adding image part in get_internal_ai_json_response: {decode_err}"
)
# Optionally add a placeholder text part indicating failure
message_parts.append(
types.Part(
text="(System Note: Failed to process an image part)"
)
)
else:
print("Warning: image_data part missing mime_type or data.")
else:
print(
f"Warning: Unknown part type '{part_type}' in internal prompt message."
)
else:
print(
f"Warning: Unexpected content type '{type(content_value)}' in internal prompt message."
)
# Add the content object if parts were generated
if message_parts:
contents.append(types.Content(role=role, parts=message_parts))
else:
print(f"Warning: No parts generated for message role '{role}'.")
# Add the critical JSON instruction to the last user message or as a new user message
json_instruction_content = (
f"**CRITICAL: Your response MUST consist *only* of the raw JSON object itself, matching this schema:**\n"
f"{json.dumps(response_schema_dict, indent=2)}\n"
f"**Ensure nothing precedes or follows the JSON.**"
)
if contents and contents[-1].role == "user":
contents[-1].parts.append(
types.Part(text=f"\n\n{json_instruction_content}")
)
else:
contents.append(
types.Content(
role="user", parts=[types.Part(text=json_instruction_content)]
)
)
# --- Determine Model ---
# Use override if provided, otherwise default (e.g., FALLBACK_MODEL for planning)
actual_model_name = (
model_name_override or DEFAULT_MODEL
) # Or choose a specific default like FALLBACK_MODEL
# --- Prepare Generation Config ---
processed_schema_internal = _preprocess_schema_for_vertex(response_schema_dict)
internal_gen_config_dict = {
"temperature": temperature,
"max_output_tokens": max_tokens,
"response_mime_type": "application/json",
"response_schema": processed_schema_internal,
"safety_settings": STANDARD_SAFETY_SETTINGS, # Include standard safety
"tools": None, # No tools for internal JSON tasks
"tool_config": None,
}
if system_instruction:
# Pass collected system messages via the config; they were skipped when
# building `contents` above and would otherwise be silently dropped.
internal_gen_config_dict["system_instruction"] = system_instruction
generation_config = types.GenerateContentConfig(**internal_gen_config_dict)
# --- Prepare Payload for Logging ---
# (Logging needs adjustment as model object isn't created here)
generation_config_log = {
"temperature": generation_config.temperature,
"max_output_tokens": generation_config.max_output_tokens,
"response_mime_type": generation_config.response_mime_type,
"response_schema": str(
generation_config.response_schema
), # Log schema as string
}
request_payload_for_logging = {
"model": actual_model_name, # Log the name used
# System instruction (if any) travels via generation_config, not contents
"contents": [
{"role": c.role, "parts": [str(p) for p in c.parts]} for c in contents
],
"generation_config": generation_config_log,
}
# (Keep detailed logging logic if desired)
try:
print(f"--- Raw request payload for {task_description} ---")
print(json.dumps(request_payload_for_logging, indent=2, default=str))
print(f"--- End Raw request payload ---")
except Exception as req_log_e:
print(f"Error logging raw request payload: {req_log_e}")
# --- Call API using the new helper ---
response_obj = await call_google_genai_api_with_retry(
cog=cog,
model_name=actual_model_name, # Pass the determined model name
contents=contents,
generation_config=generation_config, # Pass combined config
request_desc=task_description,
# No separate safety, tools, tool_config args needed
)
# --- Process Response ---
if not response_obj:
raise Exception("Internal API call failed to return a response object.")
# Log the raw response object
print(f"--- Full response_obj received for {task_description} ---")
print(response_obj)
print(f"--- End Full response_obj ---")
if not response_obj.candidates:
print(
f"Warning: Internal API call for {task_description} returned no candidates. Response: {response_obj}"
)
final_response_text = getattr(
response_obj, "text", None
) # Try to get text anyway
final_parsed_data = None
else:
# Parse and Validate using the updated helper
final_response_text = _get_response_text(response_obj) # Store raw text
print(f"--- Extracted Text for {task_description} ---")
print(final_response_text)
print(f"--- End Extracted Text ---")
# --- Log Raw Unparsed JSON ---
print(f"--- RAW UNPARSED JSON ({task_description}) ---")
print(final_response_text)
print(f"--- END RAW UNPARSED JSON ---")
# --- End Log ---
final_parsed_data = parse_and_validate_json_response(
final_response_text,
response_schema_dict,
f"internal task ({task_description})",
)
if final_parsed_data is None:
print(
f"Warning: Internal task '{task_description}' failed JSON validation."
)
# Keep final_response_text for returning raw output
except Exception as e:
print(
f"Error in get_internal_ai_json_response ({task_description}): {type(e).__name__}: {e}"
)
error_occurred = e
import traceback
traceback.print_exc()
final_parsed_data = None
# final_response_text might be None or contain partial/error text depending on when exception occurred
finally:
# Log the call
try:
# Pass the simplified payload and the *parsed* data for logging
await log_internal_api_call(
cog,
task_description,
request_payload_for_logging,
final_parsed_data,
error_occurred,
)
except Exception as log_e:
print(f"Error logging internal API call: {log_e}")
# Return both parsed data and raw text
return final_parsed_data, final_response_text
if __name__ == "__main__":
print(_preprocess_schema_for_vertex(RESPONSE_SCHEMA["schema"]))