from collections import deque
import ssl
import certifi
import imghdr  # Added for robust image MIME type detection (note: stdlib imghdr is deprecated as of Python 3.11)

from .config import CONTEXT_WINDOW_SIZE


def patch_ssl_certifi():
    """Monkey-patch ssl.create_default_context so it defaults to certifi's CA bundle."""
    original_create_default_context = ssl.create_default_context

    def custom_ssl_context(*args, **kwargs):
        # Only set cafile if it's not already passed
        kwargs.setdefault("cafile", certifi.where())
        return original_create_default_context(*args, **kwargs)

    ssl.create_default_context = custom_ssl_context


patch_ssl_certifi()
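
# Illustrative usage note (comment only, not executed): once patched, any SSL
# context created elsewhere in the process picks up certifi's CA bundle, e.g.:
#     ctx = ssl.create_default_context()  # cafile now defaults to certifi.where()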

import discord
import aiohttp
import asyncio
import json
import base64
import re
import time
import datetime
from typing import (
    TYPE_CHECKING,
    Optional,
    List,
    Dict,
    Any,
    Union,
    AsyncIterable,
    Tuple,
)  # Import Tuple
import jsonschema  # For manual JSON validation
from .tools import get_conversation_summary

# Google Generative AI Imports (using Vertex AI backend)
# try:
from google import genai
from google.genai import types
from google.api_core import (
    exceptions as google_exceptions,
)  # Keep for retry logic if applicable

# except ImportError:
#     print("WARNING: google-generativeai or google-api-core not installed. API calls will fail.")
#     # Define dummy classes/exceptions if library isn't installed
#     genai = None  # Indicate genai module is missing
#     # types = None  # REMOVE THIS LINE - Define dummy types object below

#     # Define dummy classes first
#     class DummyGenerationResponse:
#         def __init__(self):
#             self.candidates = []
#             self.text = None  # Add basic text attribute for compatibility
#     class DummyFunctionCall:
#         def __init__(self):
#             self.name = None
#             self.args = None
#     class DummyPart:
#         @staticmethod
#         def from_text(text): return None
#         @staticmethod
#         def from_data(data, mime_type): return None
#         @staticmethod
#         def from_uri(uri, mime_type): return None
#         @staticmethod
#         def from_function_response(name, response): return None
#         @staticmethod
#         def from_function_call(function_call): return None  # Add this
#     class DummyContent:
#         def __init__(self, role=None, parts=None):
#             self.role = role
#             self.parts = parts or []
#     class DummyTool:
#         def __init__(self, function_declarations=None): pass
#     class DummyFunctionDeclaration:
#         def __init__(self, name, description, parameters): pass
#     class DummySafetySetting:
#         def __init__(self, category, threshold): pass
#     class DummyHarmCategory:
#         HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
#         HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
#         HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
#         HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
#     class DummyFinishReason:
#         STOP = "STOP"
#         MAX_TOKENS = "MAX_TOKENS"
#         SAFETY = "SAFETY"
#         RECITATION = "RECITATION"
#         OTHER = "OTHER"
#         FUNCTION_CALL = "FUNCTION_CALL"  # Add this
#     class DummyToolConfig:
#         class FunctionCallingConfig:
#             class Mode:
#                 ANY = "ANY"
#                 NONE = "NONE"
#                 AUTO = "AUTO"
#         def __init__(self, function_calling_config=None): pass
#     class DummyGenerateContentResponse:  # For non-streaming response type hint
#         def __init__(self):
#             self.candidates = []
#             self.text = None
#     # Define a dummy GenerateContentConfig class
#     class DummyGenerateContentConfig:
#         def __init__(self, temperature=None, top_p=None, max_output_tokens=None, response_mime_type=None, response_schema=None, stop_sequences=None, candidate_count=None):
#             self.temperature = temperature
#             self.top_p = top_p
#             self.max_output_tokens = max_output_tokens
#             self.response_mime_type = response_mime_type
#             self.response_schema = response_schema
#             self.stop_sequences = stop_sequences
#             self.candidate_count = candidate_count
#     # Define a dummy FunctionResponse class
#     class DummyFunctionResponse:
#         def __init__(self, name, response):
#             self.name = name
#             self.response = response

#     # Create a dummy 'types' object and assign dummy classes to its attributes
#     class DummyTypes:
#         def __init__(self):
#             self.GenerationResponse = DummyGenerationResponse
#             self.FunctionCall = DummyFunctionCall
#             self.Part = DummyPart
#             self.Content = DummyContent
#             self.Tool = DummyTool
#             self.FunctionDeclaration = DummyFunctionDeclaration
#             self.SafetySetting = DummySafetySetting
#             self.HarmCategory = DummyHarmCategory
#             self.FinishReason = DummyFinishReason
#             self.ToolConfig = DummyToolConfig
#             self.GenerateContentResponse = DummyGenerateContentResponse
#             self.GenerateContentConfig = DummyGenerateContentConfig  # Assign dummy config
#             self.FunctionResponse = DummyFunctionResponse  # Assign dummy function response

#     types = DummyTypes()  # Assign the dummy object to 'types'

#     # Assign dummy types to global scope for direct imports if needed
#     GenerationResponse = DummyGenerationResponse
#     FunctionCall = DummyFunctionCall
#     Part = DummyPart
#     Content = DummyContent
#     Tool = DummyTool
#     FunctionDeclaration = DummyFunctionDeclaration
#     SafetySetting = DummySafetySetting
#     HarmCategory = DummyHarmCategory
#     FinishReason = DummyFinishReason
#     ToolConfig = DummyToolConfig
#     GenerateContentResponse = DummyGenerateContentResponse

#     class DummyGoogleExceptions:
#         ResourceExhausted = type('ResourceExhausted', (Exception,), {})
#         InternalServerError = type('InternalServerError', (Exception,), {})
#         ServiceUnavailable = type('ServiceUnavailable', (Exception,), {})
#         InvalidArgument = type('InvalidArgument', (Exception,), {})
#         GoogleAPICallError = type('GoogleAPICallError', (Exception,), {})  # Generic fallback
#     google_exceptions = DummyGoogleExceptions()


# Relative imports for components within the 'gurt' package
from .config import (
    PROJECT_ID,
    LOCATION,
    DEFAULT_MODEL,
    FALLBACK_MODEL,
    CUSTOM_TUNED_MODEL_ENDPOINT,
    EMOJI_STICKER_DESCRIPTION_MODEL,  # Import the new endpoint and model
    API_TIMEOUT,
    API_RETRY_ATTEMPTS,
    API_RETRY_DELAY,
    TOOLS,
    RESPONSE_SCHEMA,
    PROACTIVE_PLAN_SCHEMA,  # Import the new schema
    TAVILY_API_KEY,
    PISTON_API_URL,
    PISTON_API_KEY,
    BASELINE_PERSONALITY,
    TENOR_API_KEY,  # Import other needed configs
    PRO_MODEL_LOCATION,
)
from .prompt import build_dynamic_system_prompt
from .context import (
    gather_conversation_context,
    get_memory_context,
)  # Renamed functions
from .tools import TOOL_MAPPING  # Import tool mapping
from .utils import format_message, log_internal_api_call  # Import utilities
import copy  # Needed for deep copying schemas

if TYPE_CHECKING:
    from .cog import GurtCog  # Import GurtCog for type hinting only


# --- Schema Preprocessing Helper ---
def _preprocess_schema_for_vertex(schema: Dict[str, Any]) -> Dict[str, Any]:
    """
    Recursively preprocesses a JSON schema dictionary to replace list types
    (like ["string", "null"]) with the first non-null type, making it
    compatible with Vertex AI's GenerationConfig schema requirements.

    Args:
        schema: The JSON schema dictionary to preprocess.

    Returns:
        A new, preprocessed schema dictionary.
    """
    if not isinstance(schema, dict):
        return schema  # Return non-dict elements as is

    processed_schema = copy.deepcopy(schema)  # Work on a copy

    for key, value in processed_schema.items():
        if key == "type" and isinstance(value, list):
            # Find the first non-"null" type in the list
            first_valid_type = next(
                (t for t in value if isinstance(t, str) and t.lower() != "null"), None
            )
            if first_valid_type:
                processed_schema[key] = first_valid_type
            else:
                # Fallback if only "null" or invalid types are present (shouldn't happen in valid schemas)
                processed_schema[key] = "object"  # Or handle as error
                print(
                    f"Warning: Schema preprocessing found list type '{value}' with no valid non-null string type. Falling back to 'object'."
                )
        # Handle the specific 'properties' and 'items' keys before the generic dict
        # branch below; otherwise the generic branch shadows them and they never run.
        elif key == "properties" and isinstance(value, dict):
            processed_schema[key] = {
                prop_key: _preprocess_schema_for_vertex(prop_value)
                for prop_key, prop_value in value.items()
            }
        elif key == "items" and isinstance(value, dict):
            processed_schema[key] = _preprocess_schema_for_vertex(value)
        elif isinstance(value, dict):
            processed_schema[key] = _preprocess_schema_for_vertex(
                value
            )  # Recurse for nested objects
        elif isinstance(value, list):
            # Recurse for items within arrays (e.g., in 'properties' of array items)
            processed_schema[key] = [
                _preprocess_schema_for_vertex(item) if isinstance(item, dict) else item
                for item in value
            ]

    return processed_schema
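
# Example (illustrative, not executed) of the transformation performed above:
#     _preprocess_schema_for_vertex({"type": ["string", "null"], "description": "x"})
#     --> {"type": "string", "description": "x"}
# Nested "properties" / "items" sub-schemas are rewritten the same way.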


# --- Helper Function to Safely Extract Text ---
# Updated to handle google.generativeai.types.GenerateContentResponse
def _get_response_text(
    response: Optional[types.GenerateContentResponse],
) -> Optional[str]:
    """
    Safely extracts the text content from the first text part of a GenerateContentResponse.
    Handles potential errors and lack of text parts gracefully.
    """
    if not response:
        print("[_get_response_text] Received None response object.")
        return None

    # Check if response has the 'text' attribute directly (common case for simple text responses)
    if hasattr(response, "text") and response.text:
        print("[_get_response_text] Found text directly in response.text attribute.")
        return response.text

    # If no direct text, check candidates
    if not response.candidates:
        # Log the response object itself for debugging if it exists but has no candidates
        print(
            f"[_get_response_text] Response object has no candidates. Response: {response}"
        )
        return None

    try:
        # Prioritize the first candidate
        candidate = response.candidates[0]

        # Check candidate.content and candidate.content.parts
        if not hasattr(candidate, "content") or not candidate.content:
            print(
                f"[_get_response_text] Candidate 0 has no 'content'. Candidate: {candidate}"
            )
            return None
        if not hasattr(candidate.content, "parts") or not candidate.content.parts:
            print(
                f"[_get_response_text] Candidate 0 content has no 'parts' or parts list is empty. Content: {candidate.content}"
            )
            return None

        # Log parts for debugging
        print(
            f"[_get_response_text] Inspecting parts in candidate 0: {candidate.content.parts}"
        )

        # Iterate through parts to find the first text part
        for i, part in enumerate(candidate.content.parts):
            # Check if the part has a 'text' attribute and it's not empty/None
            if hasattr(part, "text") and part.text is not None:  # Check for None explicitly
                # Check if text is a non-empty string after stripping whitespace
                if isinstance(part.text, str) and part.text.strip():
                    print(f"[_get_response_text] Found non-empty text in part {i}.")
                    return part.text
                else:
                    print(
                        f"[_get_response_text] Part {i} has 'text' attribute, but it's empty or not a string: {part.text!r}"
                    )
            # else:
            #     print(f"[_get_response_text] Part {i} does not have 'text' attribute or it's None.")

        # If no text part is found after checking all parts in the first candidate
        print(
            "[_get_response_text] No usable text part found in candidate 0 after iterating through all parts."
        )
        return None

    except (AttributeError, IndexError, TypeError) as e:
        # Handle cases where structure is unexpected, list is empty, or types are wrong
        print(
            f"[_get_response_text] Error accessing response structure: {type(e).__name__}: {e}"
        )
        # Log the problematic response object for deeper inspection
        print(f"Problematic response object: {response}")
        return None
    except Exception as e:
        # Catch other unexpected errors during access
        print(f"[_get_response_text] Unexpected error extracting text: {e}")
        print(f"Response object during error: {response}")
        return None


# --- Helper Function to Format Embeds for Prompt ---
def _format_embeds_for_prompt(embed_content: List[Dict[str, Any]]) -> Optional[str]:
    """Formats embed data into a string for the AI prompt."""
    if not embed_content:
        return None

    formatted_strings = []
    for i, embed in enumerate(embed_content):
        parts = [f"--- Embed {i+1} ---"]
        if embed.get("author") and embed["author"].get("name"):
            parts.append(f"Author: {embed['author']['name']}")
        if embed.get("title"):
            parts.append(f"Title: {embed['title']}")
        if embed.get("description"):
            # Limit description length
            desc = embed["description"]
            max_desc_len = 200
            if len(desc) > max_desc_len:
                desc = desc[:max_desc_len] + "..."
            parts.append(f"Description: {desc}")
        if embed.get("fields"):
            field_parts = []
            for field in embed["fields"]:
                fname = field.get("name", "Field")
                fvalue = field.get("value", "")
                # Limit field value length
                max_field_len = 100
                if len(fvalue) > max_field_len:
                    fvalue = fvalue[:max_field_len] + "..."
                field_parts.append(f"- {fname}: {fvalue}")
            if field_parts:
                parts.append("Fields:\n" + "\n".join(field_parts))
        if embed.get("footer") and embed["footer"].get("text"):
            parts.append(f"Footer: {embed['footer']['text']}")
        if embed.get("image_url"):
            parts.append(f"[Image Attached: {embed.get('image_url')}]")  # Indicate image presence
        if embed.get("thumbnail_url"):
            parts.append(f"[Thumbnail Attached: {embed.get('thumbnail_url')}]")  # Indicate thumbnail presence

        formatted_strings.append("\n".join(parts))

    return "\n".join(formatted_strings) if formatted_strings else None


# --- Initialize Google Generative AI Clients for Vertex AI ---
# Separate clients for different regions so models can use their preferred location
try:
    genai_client_us_central1 = genai.Client(
        vertexai=True,
        project=PROJECT_ID,
        location="us-central1",
    )
    genai_client_global = genai.Client(
        vertexai=True,
        project=PROJECT_ID,
        location=PRO_MODEL_LOCATION,
    )
    # Default client remains us-central1 for backward compatibility
    genai_client = genai_client_us_central1

    print("Google GenAI Clients initialized for us-central1 and global regions.")
except NameError:
    genai_client = None
    genai_client_us_central1 = None
    genai_client_global = None
    print("Google GenAI SDK (genai) not imported, skipping client initialization.")
except Exception as e:
    genai_client = None
    genai_client_us_central1 = None
    genai_client_global = None
    print(f"Error initializing Google GenAI Clients for Vertex AI: {e}")


def get_genai_client_for_model(model_name: str):
    """Return the appropriate genai client based on the model name."""
    if "gemini-2.5-pro-preview-06-05" in model_name:
        return genai_client_global
    return genai_client_us_central1
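
# Example (illustrative): routing by model name, so the preview Pro model uses the
# PRO_MODEL_LOCATION client and everything else stays in us-central1:
#     get_genai_client_for_model("gemini-2.5-pro-preview-06-05")  # -> genai_client_global
#     get_genai_client_for_model(DEFAULT_MODEL)                   # -> genai_client_us_central1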


# --- Constants ---
# Define standard safety settings using google.generativeai types.
# All thresholds are set to BLOCK_NONE (no blocking) as requested.
STANDARD_SAFETY_SETTINGS = [
    types.SafetySetting(
        category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold="BLOCK_NONE"
    ),
    types.SafetySetting(
        category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        threshold="BLOCK_NONE",
    ),
    types.SafetySetting(
        category=types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        threshold="BLOCK_NONE",
    ),
    types.SafetySetting(
        category=types.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold="BLOCK_NONE"
    ),
]
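
# Illustrative sketch (values are assumptions, not project defaults): these settings
# are meant to be attached to each request via the combined config object, e.g.
#     config = types.GenerateContentConfig(
#         temperature=0.7,
#         safety_settings=STANDARD_SAFETY_SETTINGS,
#     )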


# --- API Call Helper ---
async def call_google_genai_api_with_retry(
    cog: "GurtCog",
    model_name: str,  # Pass model name string instead of model object
    contents: List[types.Content],  # Use types.Content from google.genai.types
    generation_config: types.GenerateContentConfig,  # Combined config object
    request_desc: str,
    # Removed safety_settings, tools, tool_config as separate params
) -> Optional[types.GenerateContentResponse]:  # Return type for non-streaming
    """
    Calls the Google Generative AI API (Vertex AI backend) with retry logic (non-streaming).

    Args:
        cog: The GurtCog instance.
        model_name: The name/path of the model to use (e.g., 'models/gemini-1.5-pro-preview-0409' or a custom endpoint path).
        contents: The list of types.Content objects for the prompt.
        generation_config: The types.GenerateContentConfig object, which should now include temperature, top_p, max_output_tokens, safety_settings, tools, tool_config, response_mime_type, response_schema etc. as needed.
        request_desc: A description of the request for logging purposes.

    Returns:
        The GenerateContentResponse object if successful, or None on failure after retries.

    Raises:
        Exception: If the API call fails after all retry attempts or encounters a non-retryable error.
    """
    client = get_genai_client_for_model(model_name)
    if not client:
        raise Exception("Google GenAI Client is not initialized.")

    last_exception = None
    start_time = time.monotonic()

    # No model object needs to be fetched up front: the model_name string is passed
    # directly to generate_content below. It should be either a standard
    # 'models/<name>' identifier or a full
    # 'projects/.../locations/.../endpoints/...' path for custom tuned models.

    # Default stats bucket for this model, defined once instead of inline before every update.
    default_stats = {
        "success": 0,
        "failure": 0,
        "retries": 0,
        "total_time": 0.0,
        "count": 0,
    }

    for attempt in range(API_RETRY_ATTEMPTS + 1):
        try:
            print(
                f"Sending API request for {request_desc} using {model_name} (Attempt {attempt + 1}/{API_RETRY_ATTEMPTS + 1})..."
            )

            # Use the non-streaming async call - config now contains all settings
            response = await client.aio.models.generate_content(
                model=model_name,  # Use the model_name string directly
                contents=contents,
                config=generation_config,  # Pass the combined config object
                # stream=False is implicit for generate_content
            )

            # --- Check Finish Reason (Safety) ---
            # Access finish reason and safety ratings from the response object
            if response and response.candidates:
                candidate = response.candidates[0]
                finish_reason = getattr(candidate, "finish_reason", None)
                safety_ratings = getattr(candidate, "safety_ratings", [])

                if finish_reason == types.FinishReason.SAFETY:
                    safety_ratings_str = (
                        ", ".join(
                            f"{rating.category.name}: {rating.probability.name}"
                            for rating in safety_ratings
                        )
                        if safety_ratings
                        else "N/A"
                    )
                    print(
                        f"⚠️ SAFETY BLOCK: API request for {request_desc} ({model_name}) was blocked. Ratings: {safety_ratings_str}"
                    )
                    # Optionally, raise a specific exception here if needed downstream
                    # raise SafetyBlockError(f"Blocked by safety filters. Ratings: {safety_ratings_str}")
                elif finish_reason not in [
                    types.FinishReason.STOP,
                    types.FinishReason.MAX_TOKENS,
                    None,
                ]:  # Allow None finish reason
                    # Log other unexpected finish reasons
                    finish_reason_name = (
                        types.FinishReason(finish_reason).name
                        if isinstance(finish_reason, int)
                        else str(finish_reason)
                    )
                    print(
                        f"⚠️ UNEXPECTED FINISH REASON: API request for {request_desc} ({model_name}) finished with reason: {finish_reason_name}"
                    )

            # --- Success Logging (proceed even if safety blocked; the block was logged above) ---
            elapsed_time = time.monotonic() - start_time
            stats = cog.api_stats.setdefault(model_name, dict(default_stats))
            stats["success"] += 1
            stats["total_time"] += elapsed_time
            stats["count"] += 1
            print(
                f"API request successful for {request_desc} ({model_name}) in {elapsed_time:.2f}s."
            )
            return response  # Success

        # Adapt exception handling if google.generativeai raises different types;
        # google.api_core.exceptions should still cover many common API errors.
        except google_exceptions.ResourceExhausted as e:
            error_msg = f"Rate limit error (ResourceExhausted) for {request_desc} ({model_name}): {e}"
            print(f"{error_msg} (Attempt {attempt + 1})")
            last_exception = e
            if attempt < API_RETRY_ATTEMPTS:
                cog.api_stats.setdefault(model_name, dict(default_stats))["retries"] += 1
                wait_time = API_RETRY_DELAY * (2**attempt)  # Exponential backoff
                print(f"Waiting {wait_time:.2f} seconds before retrying...")
                await asyncio.sleep(wait_time)
                continue
            else:
                break  # Max retries reached

        except (
            google_exceptions.InternalServerError,
            google_exceptions.ServiceUnavailable,
        ) as e:
            error_msg = f"API server error ({type(e).__name__}) for {request_desc} ({model_name}): {e}"
            print(f"{error_msg} (Attempt {attempt + 1})")
            last_exception = e
            if attempt < API_RETRY_ATTEMPTS:
                cog.api_stats.setdefault(model_name, dict(default_stats))["retries"] += 1
                wait_time = API_RETRY_DELAY * (2**attempt)  # Exponential backoff
                print(f"Waiting {wait_time:.2f} seconds before retrying...")
                await asyncio.sleep(wait_time)
                continue
            else:
                break  # Max retries reached

        except google_exceptions.InvalidArgument as e:
            # Often indicates a problem with the request itself (e.g., bad schema, unsupported format, invalid model name)
            error_msg = f"Invalid argument error for {request_desc} ({model_name}): {e}"
            print(error_msg)
            last_exception = e
            break  # Non-retryable

        except asyncio.TimeoutError:  # Handle potential client-side timeouts if applicable
            error_msg = f"Client-side request timed out for {request_desc} ({model_name}) (Attempt {attempt + 1})"
            print(error_msg)
            last_exception = asyncio.TimeoutError(error_msg)
            # Client-side timeouts are retried with linear backoff
            if attempt < API_RETRY_ATTEMPTS:
                cog.api_stats.setdefault(model_name, dict(default_stats))["retries"] += 1
                await asyncio.sleep(API_RETRY_DELAY * (attempt + 1))
                continue
            else:
                break

        except Exception as e:  # Catch other potential exceptions (e.g., from the genai library itself)
            error_msg = f"Unexpected error during API call for {request_desc} ({model_name}) (Attempt {attempt + 1}): {type(e).__name__}: {e}"
            print(error_msg)
            import traceback

            traceback.print_exc()
            last_exception = e
            # Treat unexpected errors as non-retryable for now
            break

    # --- Failure Logging ---
    elapsed_time = time.monotonic() - start_time
    stats = cog.api_stats.setdefault(model_name, dict(default_stats))
    stats["failure"] += 1
    stats["total_time"] += elapsed_time
    stats["count"] += 1
    print(
        f"API request failed for {request_desc} ({model_name}) after {attempt + 1} attempts in {elapsed_time:.2f}s."
    )

    # Raise the last encountered exception or a generic one
    raise last_exception or Exception(
        f"API request failed for {request_desc} after {API_RETRY_ATTEMPTS + 1} attempts."
    )
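
# Illustrative call sketch (parameter values are assumptions, not project defaults):
#     config = types.GenerateContentConfig(
#         temperature=0.7,
#         max_output_tokens=1024,
#         safety_settings=STANDARD_SAFETY_SETTINGS,
#     )
#     response = await call_google_genai_api_with_retry(
#         cog, DEFAULT_MODEL, contents, config, request_desc="example call"
#     )
#     text = _get_response_text(response)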


# --- JSON Parsing and Validation Helper ---
def parse_and_validate_json_response(
    response_text: Optional[str], schema: Dict[str, Any], context_description: str
) -> Optional[Dict[str, Any]]:
    """
    Parses the AI's response text, attempting to extract and validate a JSON object against a schema.

    Args:
        response_text: The raw text content from the AI response.
        schema: The JSON schema (as a dictionary) to validate against.
        context_description: A description for logging purposes.

    Returns:
        A parsed and validated dictionary if successful, None otherwise.
    """
    if response_text is None:
        print(f"Parsing ({context_description}): Response text is None.")
        return None

    parsed_data = None
    raw_json_text = response_text  # Start with the full text

    # Attempt 1: Try parsing the whole string directly
    try:
        parsed_data = json.loads(raw_json_text)
        print(
            f"Parsing ({context_description}): Successfully parsed entire response as JSON."
        )
    except json.JSONDecodeError:
        # Attempt 2: Extract JSON object, handling optional markdown fences
        # More robust regex to handle potential leading/trailing text and variations
        json_match = re.search(
            r"```(?:json)?\s*(\{.*\})\s*```|(\{.*\})",
            response_text,
            re.DOTALL | re.MULTILINE,
        )
        if json_match:
            json_str = json_match.group(1) or json_match.group(2)
            if json_str:
                raw_json_text = json_str  # Use the extracted string for parsing
                try:
                    parsed_data = json.loads(raw_json_text)
                    print(
                        f"Parsing ({context_description}): Successfully extracted and parsed JSON using regex."
                    )
                except json.JSONDecodeError as e_inner:
                    print(
                        f"Parsing ({context_description}): Regex found potential JSON, but it failed to parse: {e_inner}\nContent: {raw_json_text[:500]}"
                    )
                    parsed_data = None
            else:
                print(
                    f"Parsing ({context_description}): Regex matched, but failed to capture JSON content."
                )
                parsed_data = None
        else:
            print(
                f"Parsing ({context_description}): Could not parse directly or extract JSON object using regex.\nContent: {raw_json_text[:500]}"
            )
            parsed_data = None

    # Validation step
    if parsed_data is not None:
        if not isinstance(parsed_data, dict):
            print(
                f"Parsing ({context_description}): Parsed data is not a dictionary: {type(parsed_data)}"
            )
            return None  # Fail validation if not a dict

        try:
            jsonschema.validate(instance=parsed_data, schema=schema)
            print(
                f"Parsing ({context_description}): JSON successfully validated against schema."
            )
            # Ensure default keys exist after validation
            parsed_data.setdefault("should_respond", False)
            parsed_data.setdefault("content", None)
            parsed_data.setdefault("react_with_emoji", None)
            return parsed_data
        except jsonschema.ValidationError as e:
            print(
                f"Parsing ({context_description}): JSON failed schema validation: {e.message}"
            )
            # Optionally log more details: e.path, e.schema_path, e.instance
            return None  # Validation failed
        except Exception as e:  # Catch other potential validation errors
            print(
                f"Parsing ({context_description}): Unexpected error during JSON schema validation: {e}"
            )
            return None
    else:
        # Parsing failed before validation could occur
        return None
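
# Example (illustrative): a bare JSON object and a fenced block parse the same way.
#     parse_and_validate_json_response('{"should_respond": true, "content": "hi"}', RESPONSE_SCHEMA, "demo")
#     parse_and_validate_json_response('```json\n{"should_respond": true, "content": "hi"}\n```', RESPONSE_SCHEMA, "demo")
# Both return the parsed dict (with missing default keys like "react_with_emoji"
# filled in) if it satisfies RESPONSE_SCHEMA, otherwise None.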


# --- Tool Processing ---
# Updated to use google.generativeai types
async def process_requested_tools(
    cog: "GurtCog", function_call: types.FunctionCall
) -> List[types.Part]:
    """
    Process a tool request specified by the AI's FunctionCall response.
    Returns a list of types.Part objects (usually one, but potentially more if an image URL is detected in the result).
    """
    function_name = function_call.name
    # function_call.args is already a dict-like object in google.generativeai
    function_args = dict(function_call.args) if function_call.args else {}
    tool_result_content = None

    print(f"Processing tool request: {function_name} with args: {function_args}")
    tool_start_time = time.monotonic()

    if function_name in TOOL_MAPPING:
        try:
            tool_func = TOOL_MAPPING[function_name]
            # Execute the mapped function
            result_dict = await tool_func(cog, **function_args)

            # --- Tool Success Logging ---
            tool_elapsed_time = time.monotonic() - tool_start_time
            stats = cog.tool_stats.setdefault(
                function_name,
                {"success": 0, "failure": 0, "total_time": 0.0, "count": 0},
            )
            stats["success"] += 1
            stats["total_time"] += tool_elapsed_time
            stats["count"] += 1
            print(
                f"Tool '{function_name}' executed successfully in {tool_elapsed_time:.2f}s."
            )

            # Ensure result is a dict, converting if necessary
            if not isinstance(result_dict, dict):
                if (
                    isinstance(result_dict, (str, int, float, bool, list))
                    or result_dict is None
                ):
                    result_dict = {"result": result_dict}
                else:
                    print(
                        f"Warning: Tool '{function_name}' returned non-standard type {type(result_dict)}. Attempting str conversion."
                    )
                    result_dict = {"result": str(result_dict)}

            tool_result_content = result_dict  # Now guaranteed to be a dict

        except Exception as e:
            # --- Tool Failure Logging ---
            tool_elapsed_time = time.monotonic() - tool_start_time  # Recalculate time even on failure
            stats = cog.tool_stats.setdefault(
                function_name,
                {"success": 0, "failure": 0, "total_time": 0.0, "count": 0},
            )
            stats["failure"] += 1
            stats["total_time"] += tool_elapsed_time
            stats["count"] += 1
            error_message = (
                f"Error executing tool {function_name}: {type(e).__name__}: {str(e)}"
            )
            print(f"{error_message} (Took {tool_elapsed_time:.2f}s)")
            import traceback

            traceback.print_exc()
            tool_result_content = {"error": error_message}  # Ensure it's a dict even on error

    else:  # This 'else' corresponds to 'if function_name in TOOL_MAPPING:'
        # --- Tool Not Found Logging ---
        tool_elapsed_time = time.monotonic() - tool_start_time  # Time for the failed lookup
        stats = cog.tool_stats.setdefault(
            function_name,
            {"success": 0, "failure": 0, "total_time": 0.0, "count": 0},
        )
        stats["failure"] += 1  # Count as failure
        stats["total_time"] += tool_elapsed_time
        stats["count"] += 1
        error_message = f"Tool '{function_name}' not found or implemented."
        print(f"{error_message} (Took {tool_elapsed_time:.2f}s)")
        tool_result_content = {"error": error_message}  # Ensure it's a dict

    # --- Process result for potential image URLs ---
    parts_to_return: List[types.Part] = []
    original_image_url: Optional[str] = None  # Store the original URL if found
    modified_result_content = copy.deepcopy(tool_result_content)  # Work on a copy

    # --- Image URL Detection & Modification ---
    # Check specific tools and keys known to contain image URLs

    # Special handling for get_user_avatar_data to directly use its base64 output
    if function_name == "get_user_avatar_data" and isinstance(
        modified_result_content, dict
    ):
        base64_image_data = modified_result_content.get("base64_data")
        image_mime_type = modified_result_content.get("content_type")

        if base64_image_data and image_mime_type:
            try:
                image_bytes = base64.b64decode(base64_image_data)
                # Validate MIME type (optional, but good practice)
                supported_image_mimes = [
                    "image/png",
                    "image/jpeg",
                    "image/webp",
                    "image/heic",
                    "image/heif",
                ]
                clean_mime_type = image_mime_type.split(";")[0].lower()

                if clean_mime_type in supported_image_mimes:
                    # Use inline_data for raw bytes
                    image_part = types.Part(
                        inline_data=types.Blob(
                            data=image_bytes, mime_type=clean_mime_type
                        )
                    )
                    # Add to parts_to_return for this tool's response
                    parts_to_return.append(image_part)
                    print(
                        f"Added image part directly from get_user_avatar_data (MIME: {clean_mime_type}, {len(image_bytes)} bytes)."
                    )
                    # Replace base64_data in the textual response to avoid sending it twice
                    modified_result_content["base64_data"] = (
                        "[Image Content Attached In Prompt]"
                    )
                    modified_result_content["content_type"] = (
                        f"[MIME type: {clean_mime_type} - Content Attached In Prompt]"
                    )
                else:
                    print(
                        f"Warning: MIME type '{clean_mime_type}' from get_user_avatar_data not in supported list. Not attaching image part."
                    )
                    modified_result_content["base64_data"] = (
                        "[Image Data Not Attached - Unsupported MIME Type]"
                    )
            except Exception as e:
                print(f"Error processing base64 data from get_user_avatar_data: {e}")
                modified_result_content["base64_data"] = (
                    f"[Error Processing Image Data: {e}]"
                )
        # Prevent generic URL download logic from re-processing this avatar
        original_image_url = None  # Explicitly nullify to skip URL download

    elif function_name == "get_user_avatar_url" and isinstance(
        modified_result_content, dict
    ):
        avatar_url_value = modified_result_content.get("avatar_url")
        if avatar_url_value and isinstance(avatar_url_value, str):
            original_image_url = avatar_url_value  # Store original
            modified_result_content["avatar_url"] = (
                "[Image Content Attached]"  # Replace URL with placeholder
            )
    elif function_name == "get_user_profile_info" and isinstance(
        modified_result_content, dict
    ):
        profile_dict = modified_result_content.get("profile")
        if isinstance(profile_dict, dict):
            avatar_url_value = profile_dict.get("avatar_url")
            if avatar_url_value and isinstance(avatar_url_value, str):
                original_image_url = avatar_url_value  # Store original
                profile_dict["avatar_url"] = (
                    "[Image Content Attached]"  # Replace URL in nested dict
                )
    # Add checks for other tools/keys that might return image URLs if necessary

    # --- Create Parts ---
    # Always add the function response part (using the potentially modified content)
    function_response_part = types.Part(
        function_response=types.FunctionResponse(
            name=function_name, response=modified_result_content
        )
    )
    parts_to_return.append(function_response_part)

    # Add image part if an original URL was found and seems valid
    if (
        original_image_url
        and isinstance(original_image_url, str)
        and original_image_url.startswith("http")
    ):
        download_success = False
        try:
            # Download the image data using the aiohttp session from the cog
            if not hasattr(cog, "session") or not cog.session:
                raise ValueError("aiohttp session not found in cog.")

            print(f"Downloading image data from URL: {original_image_url}")
            async with cog.session.get(
                original_image_url, timeout=15
            ) as response:  # Added timeout
                if response.status == 200:
                    image_bytes = await response.read()
                    mime_type = (
                        response.content_type or "application/octet-stream"
                    )  # Get MIME type from header

                    # Validate against known supported image types for Gemini
                    supported_image_mimes = [
                        "image/png",
                        "image/jpeg",
                        "image/webp",
                        "image/heic",
                        "image/heif",
                    ]
                    clean_mime_type = mime_type.split(";")[0].lower()  # Clean MIME type

                    if clean_mime_type in supported_image_mimes:
                        # Send the downloaded bytes as inline_data instead of a URI
                        image_part = types.Part(
                            inline_data=types.Blob(
                                data=image_bytes, mime_type=clean_mime_type
                            )
                        )
                        parts_to_return.append(image_part)
                        download_success = True
                        print(
                            f"Added image part (from data, {len(image_bytes)} bytes, MIME: {clean_mime_type}) from tool '{function_name}' result."
                        )
                    else:
                        print(
                            f"Warning: Downloaded image MIME type '{clean_mime_type}' from {original_image_url} might not be supported by Gemini. Skipping image part."
                        )
                else:
                    print(
                        f"Error downloading image from {original_image_url}: Status {response.status}"
                    )

        except asyncio.TimeoutError:
            print(
                f"Error downloading image from {original_image_url}: Request timed out."
            )
        except aiohttp.ClientError as client_e:
            print(f"Error downloading image from {original_image_url}: {client_e}")
        except ValueError as val_e:  # Catch missing session error
            print(f"Error preparing image download: {val_e}")
        except Exception as e:
            print(
                f"Error downloading or creating image part from data ({original_image_url}): {e}"
            )

        # If download or processing failed, add an error note for the LLM
        if not download_success:
            error_text = f"[System Note: Failed to download or process image data from URL provided by tool '{function_name}'. URL: {original_image_url}]"
            error_text_part = types.Part(text=error_text)
            parts_to_return.append(error_text_part)

    return parts_to_return  # Return the list of parts (will contain 1 or 2+ parts)
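
# Illustrative shapes of the returned list (order depends on the branch taken):
#     [Part(function_response=...)]                           # plain tool result
#     [Part(inline_data=...), Part(function_response=...)]    # avatar bytes attached
#     [Part(function_response=...), Part(text="[System Note: ...]")]  # download failed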


# --- Helper to find function call in parts ---
# Updated to use google.generativeai types
def find_function_call_in_parts(
    parts: Optional[List[types.Part]],
) -> Optional[types.FunctionCall]:
    """Finds the first valid FunctionCall object within a list of Parts."""
    if not parts:
        return None
    for part in parts:
        # Check if the part has a 'function_call' attribute and it's a valid FunctionCall object
        if hasattr(part, "function_call") and isinstance(
            part.function_call, types.FunctionCall
        ):
            # Basic validation: ensure name exists
            if part.function_call.name:
                return part.function_call
            else:
                print(
                    f"Warning: Found Part with 'function_call', but its name is missing: {part.function_call}"
                )
        # else:
        #     print(f"Debug: Part does not have a valid function_call: {part}")  # Optional debug log
    return None
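
# Example (illustrative) of the expected usage in the tool loop:
#     fc = find_function_call_in_parts(response.candidates[0].content.parts)
#     if fc:
#         tool_parts = await process_requested_tools(cog, fc)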


# --- Main AI Response Function ---
async def get_ai_response(
    cog: "GurtCog", message: discord.Message, model_name: Optional[str] = None
) -> Tuple[Dict[str, Any], List[str]]:
    """
    Gets responses from the Vertex AI Gemini API, handling potential tool usage and returning
    the final parsed response.

    Args:
        cog: The GurtCog instance.
        message: The triggering discord.Message.
        model_name: Optional override for the AI model name (e.g., "gemini-1.5-pro-preview-0409").

    Returns:
        A dictionary containing:
        - "final_response": Parsed JSON data from the final AI call (or None if parsing/validation fails).
        - "error": An error message string if a critical error occurred, otherwise None.
        - "fallback_initial": Optional minimal response if initial parsing failed critically (less likely with controlled generation).
    """
    if not PROJECT_ID:
        error_msg = "Google Cloud Project ID not configured."
        print(f"Error in get_ai_response: {error_msg}")
        return {"final_response": None, "error": error_msg}

    # Determine the model for all generation steps.
    # Use the model_name override if provided to get_ai_response, otherwise use the cog's current default_model.
    active_model = model_name or cog.default_model

    print(f"Using active model for all generation steps: {active_model}")

    channel_id = message.channel.id
    user_id = message.author.id
    # initial_parsed_data is no longer needed with the loop structure
    final_parsed_data = None
    error_message = None
    fallback_response = None  # Keep fallback for critical initial failures
    max_tool_calls = 5  # Maximum number of sequential tool calls allowed
    tool_calls_made = 0
    last_response_obj = None  # Store the last response object from the loop

    try:
        # --- Build Prompt Components ---
        final_system_prompt = await build_dynamic_system_prompt(cog, message)
        conversation_context_messages = gather_conversation_context(
            cog, channel_id, message.id
        )  # Pass cog
        memory_context = await get_memory_context(cog, message)  # Pass cog

        # --- Prepare Message History (Contents) ---
        # Contents will be built progressively within the loop
        contents: List[types.Content] = []

        # Add memory context if available
        if memory_context:
            # System messages aren't directly supported in the 'contents' list for multi-turn like OpenAI.
            # It's better handled via the 'system_instruction' parameter of GenerativeModel.
            # We might prepend it to the first user message or handle it differently if needed.
            # For now, we rely on system_instruction. Let's log if we have memory context.
            print("Memory context available, relying on system_instruction.")
            # If needed, could potentially add as a 'model' role message before user messages,
            # but this might confuse the turn structure.
            # contents.append(types.Content(role="model", parts=[types.Part.from_text(f"System Note: {memory_context}")]))

        # Add conversation history
        # The current message is already included in conversation_context_messages
        for msg in conversation_context_messages:
            # Gemini only accepts the roles "user" and "model" in contents,
            # so the bot's own messages are mapped to "model".
            role = (
                "model"
                if msg.get("author", {}).get("id") == str(cog.bot.user.id)
                else "user"
            )  # Use get for safety
            parts: List[types.Part] = []  # Initialize parts for each message

            # Handle potential multimodal content in history (if stored that way)
            if isinstance(msg.get("content"), list):
                # If content is already a list of parts, process them
                for part_data in msg["content"]:
                    if part_data["type"] == "text":
                        parts.append(types.Part(text=part_data["text"]))
                    elif part_data["type"] == "image_url":
                        # Assuming image_url carries a URI with an embedded MIME type;
                        # google-genai expects file URIs via Part.from_uri.
                        parts.append(
                            types.Part.from_uri(
                                file_uri=part_data["image_url"]["url"],
                                mime_type=part_data["image_url"]["url"]
                                .split(";")[0]
                                .split(":")[1],
                            )
                        )
                # Filter out None parts if any were conditionally added
                parts = [p for p in parts if p]
            # Combine text, embeds, and attachments for history messages
            elif (
                isinstance(msg.get("content"), str)
                or msg.get("embed_content")
                or msg.get("attachment_descriptions")
            ):
                text_parts = []
                # Add original text content if it exists and is not empty
                if isinstance(msg.get("content"), str) and msg["content"].strip():
                    text_parts.append(msg["content"])
                # Add formatted embed content if present
                embed_str = _format_embeds_for_prompt(msg.get("embed_content", []))
                if embed_str:
                    text_parts.append(f"\n[Embed Content]:\n{embed_str}")
                # Add attachment descriptions if present
                if msg.get("attachment_descriptions"):
                    # Ensure descriptions are strings before joining
                    attach_desc_list = [
                        a["description"]
                        for a in msg["attachment_descriptions"]
                        if isinstance(a.get("description"), str)
                    ]
                    if attach_desc_list:
                        attach_desc_str = "\n".join(attach_desc_list)
                        text_parts.append(f"\n[Attachments]:\n{attach_desc_str}")

                # Add custom emoji and sticker descriptions and images from cache for historical messages
                cached_emojis = msg.get("custom_emojis", [])
                for emoji_info in cached_emojis:
                    emoji_name = emoji_info.get("name")
                    emoji_url = emoji_info.get("url")
                    if emoji_name:
                        text_parts.append(f"[Emoji: {emoji_name}]")
                        if emoji_url and emoji_url.startswith("http"):
                            # Determine MIME type for the emoji image
                            is_animated_emoji = emoji_info.get("animated", False)
                            emoji_mime_type = (
                                "image/gif" if is_animated_emoji else "image/png"
                            )
                            try:
                                # Download emoji data and send as inline_data
                                async with cog.session.get(
                                    emoji_url, timeout=10
                                ) as response:
                                    if response.status == 200:
                                        emoji_bytes = await response.read()
                                        parts.append(
                                            types.Part(
                                                inline_data=types.Blob(
                                                    data=emoji_bytes,
                                                    mime_type=emoji_mime_type,
                                                )
                                            )
                                        )
                                        print(
                                            f"Added inline_data part for historical emoji: {emoji_name} (MIME: {emoji_mime_type}, {len(emoji_bytes)} bytes)"
                                        )
                                    else:
                                        print(
                                            f"Error downloading historical emoji {emoji_name} from {emoji_url}: Status {response.status}"
                                        )
                                        text_parts.append(
                                            f"[System Note: Failed to download emoji '{emoji_name}']"
                                        )
                            except Exception as e:
                                print(
                                    f"Error downloading/processing historical emoji {emoji_name} from {emoji_url}: {e}"
                                )
                                text_parts.append(
                                    f"[System Note: Failed to process emoji '{emoji_name}']"
                                )

                cached_stickers = msg.get("stickers", [])
                for sticker_info in cached_stickers:
                    sticker_name = sticker_info.get("name")
                    sticker_url = sticker_info.get("url")
                    sticker_format_str = sticker_info.get("format")
                    sticker_format_text = (
                        f" (Format: {sticker_format_str})" if sticker_format_str else ""
                    )
                    if sticker_name:
                        text_parts.append(
                            f"[Sticker: {sticker_name}{sticker_format_text}]"
                        )
                        is_image_sticker = sticker_format_str in [
                            "StickerFormatType.png",
                            "StickerFormatType.apng",
                        ]
                        if (
                            is_image_sticker
                            and sticker_url
                            and sticker_url.startswith("http")
                        ):
                            sticker_mime_type = "image/png"  # APNG is also sent as image/png
                            try:
                                # Download sticker data and send as inline_data
                                async with cog.session.get(
                                    sticker_url, timeout=10
                                ) as response:
                                    if response.status == 200:
                                        sticker_bytes = await response.read()
                                        parts.append(
                                            types.Part(
                                                inline_data=types.Blob(
                                                    data=sticker_bytes,
                                                    mime_type=sticker_mime_type,
                                                )
                                            )
                                        )
                                        print(
                                            f"Added inline_data part for historical sticker: {sticker_name} (MIME: {sticker_mime_type}, {len(sticker_bytes)} bytes)"
                                        )
                                    else:
                                        print(
                                            f"Error downloading historical sticker {sticker_name} from {sticker_url}: Status {response.status}"
                                        )
                                        text_parts.append(
                                            f"[System Note: Failed to download sticker '{sticker_name}']"
                                        )
                            except Exception as e:
                                print(
                                    f"Error downloading/processing historical sticker {sticker_name} from {sticker_url}: {e}"
                                )
                                text_parts.append(
                                    f"[System Note: Failed to process sticker '{sticker_name}']"
                                )
                        elif sticker_format_str == "StickerFormatType.lottie":
                            # Lottie files are JSON, not directly viewable by Gemini as images. Send as text.
                            text_parts.append(
                                f"[Lottie Sticker: {sticker_name} (JSON animation, not displayed as image)]"
                            )

                full_text = "\n".join(text_parts).strip()
                if full_text:  # Only add if there's some text content
                    author_string_from_cache = msg.get("author_string")

                    if (
                        author_string_from_cache
                        and str(author_string_from_cache).strip()
                    ):
                        # If author_string is available and valid from the cache, use it directly.
                        # This string is expected to be pre-formatted by the context gathering logic.
                        author_identifier_string = str(author_string_from_cache)
                        parts.append(
                            types.Part(text=f"{author_identifier_string}: {full_text}")
                        )
                    else:
                        # Fall back to reconstructing the author identifier if author_string is not available/valid
                        author_details = msg.get("author", {})
                        raw_display_name = author_details.get("display_name")
                        raw_name = author_details.get("name")  # Discord username
                        author_id = author_details.get("id")

                        final_display_part = ""
                        username_part_str = ""

                        if raw_display_name and str(raw_display_name).strip():
                            final_display_part = str(raw_display_name)
                        elif raw_name and str(raw_name).strip():
                            # Fall back to the username for display
                            final_display_part = str(raw_name)
                        elif author_id:  # Fall back to the user ID for display
                            final_display_part = f"User ID: {author_id}"
                        else:  # Default to "Unknown User" if no other identifier is found
                            final_display_part = "Unknown User"

                        # Construct username part if raw_name is valid and different from final_display_part
                        if (
                            raw_name
                            and str(raw_name).strip()
                            and str(raw_name).lower() != "none"
                        ):
                            # Avoid "Username (Username: Username)" if display name fell back to raw_name
                            if final_display_part.lower() != str(raw_name).lower():
                                username_part_str = f" (Username: {str(raw_name)})"
                        # If username is bad/missing, but we have an ID, and ID isn't already the main display part
                        elif author_id and not (
                            raw_name
                            and str(raw_name).strip()
                            and str(raw_name).lower() != "none"
                        ):
                            if not final_display_part.startswith("User ID:"):
                                username_part_str = f" (User ID: {author_id})"

                        author_identifier_string = (
                            f"{final_display_part}{username_part_str}"
                        )
                        # Append the text part to the existing parts list for this message
                        parts.append(
                            types.Part(text=f"{author_identifier_string}: {full_text}")
                        )

            # Only append to contents if there are parts to add for this message
            if parts:
                contents.append(types.Content(role=role, parts=parts))
            else:
                # If no parts were generated (e.g., empty message, or only unsupported content),
                # log a warning and skip adding this message to contents.
                print(
                    f"Warning: Skipping message from history (ID: {msg.get('id')}) as no valid parts were generated."
                )

        # --- Prepare the current message content (potentially multimodal) ---
        # This section is no longer needed as the current message is included in conversation_context_messages
        # current_message_parts = []
        # formatted_current_message = format_message(cog, message)  # Pass cog if needed

        # --- Construct text content, including reply context if applicable ---
        # This section is now handled within the conversation history loop
        # text_content = ""
        # if formatted_current_message.get("is_reply") and formatted_current_message.get("replied_to_author_name"):
        #     reply_author = formatted_current_message["replied_to_author_name"]
        #     reply_content = formatted_current_message.get("replied_to_content", "...")  # Use ellipsis if content missing
        #     # Truncate long replied content to keep context concise
        #     max_reply_len = 150
        #     if len(reply_content) > max_reply_len:
        #         reply_content = reply_content[:max_reply_len] + "..."
        #     text_content += f"(Replying to {reply_author}: \"{reply_content}\")\n"

        # # Add current message author and content
        # # Use the new author_string here for the current message
        # current_author_string = formatted_current_message.get("author_string", formatted_current_message.get("author", {}).get("display_name", "Unknown User"))
        # text_content += f"{current_author_string}: {formatted_current_message['content']}"  # Keep original content

        # # Add formatted embed content if present for the *current* message
        # current_embed_str = _format_embeds_for_prompt(formatted_current_message.get("embed_content", []))
        # if current_embed_str:
        #     text_content += f"\n[Embed Content]:\n{current_embed_str}"

        # # Add attachment descriptions for the *current* message
        # if formatted_current_message.get("attachment_descriptions"):
        #     # Ensure descriptions are strings before joining
        #     current_attach_desc_list = [a['description'] for a in formatted_current_message['attachment_descriptions'] if isinstance(a.get('description'), str)]
        #     if current_attach_desc_list:
        #         current_attach_desc_str = "\n".join(current_attach_desc_list)
        #         text_content += f"\n[Attachments]:\n{current_attach_desc_str}"

        # # Add mention details
        # if formatted_current_message.get("mentioned_users_details"):  # This key might not exist, adjust if needed
        #     mentions_str = ", ".join([f"{m['display_name']}(id:{m['id']})" for m in formatted_current_message["mentioned_users_details"]])
        #     text_content += f"\n(Message Details: Mentions=[{mentions_str}])"

        # current_message_parts.append(types.Part(text=text_content))
        # --- End text content construction ---
|
|
|
|
# --- Add current message attachments (uncommented and potentially modified) ---
|
|
# Find the last 'user' message in contents (which should be the current message)
|
|
current_user_content_index = -1
|
|
for i in range(len(contents) - 1, -1, -1):
|
|
if contents[i].role == "user":
|
|
                current_user_content_index = i
                break

        # Ensure formatted_current_message is defined for the current message processing.
        # This will be used for attachments, emojis, and stickers for the current message.
        formatted_current_message = format_message(cog, message)

        if message.attachments and current_user_content_index != -1:
            print(
                f"Processing {len(message.attachments)} attachments for current message {message.id}"
            )
            attachment_parts_to_add = []  # Parts to add to the current user message

            # Fetch the attachment descriptions from the already formatted message
            attachment_descriptions = formatted_current_message.get(
                "attachment_descriptions", []
            )
            desc_map = {desc.get("filename"): desc for desc in attachment_descriptions}

            for attachment in message.attachments:
                mime_type = attachment.content_type
                file_url = attachment.url
                filename = attachment.filename

                # Check if the MIME type is supported for direct input to Gemini.
                # Expanded list based on Gemini 1.5 Pro docs (April 2024).
                supported_mime_prefixes = [
                    "image/",  # image/png, image/jpeg, image/heic, image/heif, image/webp
                    "video/",  # video/mov, video/mpeg, video/mp4, video/mpg, video/avi, video/wmv, video/mpegps, video/flv
                    "audio/",  # audio/mpeg, audio/mp3, audio/wav, audio/ogg, audio/flac, audio/opus, audio/amr, audio/midi
                    "text/",  # text/plain, text/html, text/css, text/javascript, text/json, text/csv, text/rtf, text/markdown
                    "application/pdf",
                    "application/rtf",  # Explicitly add RTF if needed
                    # Add more as supported/needed
                ]
                is_supported = False
                detected_mime_type = (
                    mime_type if mime_type else "application/octet-stream"
                )  # Default if missing
                for prefix in supported_mime_prefixes:
                    if detected_mime_type.startswith(prefix):
                        is_supported = True
                        break

                # Get the pre-formatted description string (already includes size, type, etc.)
                preformatted_desc = desc_map.get(filename, {}).get(
                    "description", f"[File: {filename} (unknown type)]"
                )

                # Add the descriptive text part using the pre-formatted description
                instruction_text = (
                    f"[ATTACHMENT] {preformatted_desc}"  # Explicitly mark as attachment
                )
                attachment_parts_to_add.append(types.Part(text=instruction_text))
                print(f"Added text description for attachment: {filename}")

                if is_supported and file_url:
                    try:
                        clean_mime_type = (
                            detected_mime_type.split(";")[0]
                            if detected_mime_type
                            else "application/octet-stream"
                        )
                        # Download attachment data and send as inline_data
                        async with cog.session.get(
                            file_url, timeout=15
                        ) as response:  # Increased timeout for potentially larger files
                            if response.status == 200:
                                attachment_bytes = await response.read()
                                attachment_parts_to_add.append(
                                    types.Part(
                                        inline_data=types.Blob(
                                            data=attachment_bytes,
                                            mime_type=clean_mime_type,
                                        )
                                    )
                                )
                                print(
                                    f"Added inline_data part for supported attachment: {filename} (MIME: {clean_mime_type}, {len(attachment_bytes)} bytes)"
                                )
                            else:
                                print(
                                    f"Error downloading attachment {filename} from {file_url}: Status {response.status}"
                                )
                                attachment_parts_to_add.append(
                                    types.Part(
                                        text=f"(System Note: Failed to download attachment '{filename}')"
                                    )
                                )
                    except Exception as e:
                        print(
                            f"Error downloading/processing attachment {filename} from {file_url}: {e}"
                        )
                        attachment_parts_to_add.append(
                            types.Part(
                                text=f"(System Note: Failed to process attachment '{filename}' - {e})"
                            )
                        )
                else:
                    print(
                        f"Skipping inline_data part for unsupported attachment: {filename} (Type: {detected_mime_type}, URL: {file_url})"
                    )
                    # The text description was already added above

            # Add the collected attachment parts to the existing user message parts
            if attachment_parts_to_add:
                contents[current_user_content_index].parts.extend(
                    attachment_parts_to_add
                )
                print(
                    f"Extended user message at index {current_user_content_index} with {len(attachment_parts_to_add)} attachment parts."
                )
        elif not message.attachments:
            print("No attachments found for the current message.")
        elif current_user_content_index == -1:
            print(
                "Warning: Could not find current user message in contents to add attachments to."
            )
        # --- End attachment processing ---
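
        # NOTE: each attachment contributes up to two parts: a text description
        # (always) and, for supported MIME types, the raw bytes as inline_data.
        # Inline bytes count toward the overall request size, so very large
        # attachments may still fail at the API level.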

        # --- Add current message custom emojis and stickers ---
        if current_user_content_index != -1:
            emoji_sticker_parts_to_add = []
            # Process custom emojis from formatted_current_message
            custom_emojis_current = formatted_current_message.get("custom_emojis", [])
            for emoji_info in custom_emojis_current:
                emoji_name = emoji_info.get("name")
                emoji_url = emoji_info.get("url")
                if emoji_name and emoji_url:
                    emoji_sticker_parts_to_add.append(
                        types.Part(text=f"[Emoji: {emoji_name}]")
                    )
                    print(
                        f"Added text description for current message emoji: {emoji_name}"
                    )
                    # Determine the MIME type for the emoji image
                    is_animated_emoji = emoji_info.get("animated", False)
                    emoji_mime_type = "image/gif" if is_animated_emoji else "image/png"
                    try:
                        # Download emoji data and send as inline_data
                        async with cog.session.get(emoji_url, timeout=10) as response:
                            if response.status == 200:
                                emoji_bytes = await response.read()
                                emoji_sticker_parts_to_add.append(
                                    types.Part(
                                        inline_data=types.Blob(
                                            data=emoji_bytes, mime_type=emoji_mime_type
                                        )
                                    )
                                )
                                print(
                                    f"Added inline_data part for current emoji: {emoji_name} (MIME: {emoji_mime_type}, {len(emoji_bytes)} bytes)"
                                )
                            else:
                                print(
                                    f"Error downloading current emoji {emoji_name} from {emoji_url}: Status {response.status}"
                                )
                                emoji_sticker_parts_to_add.append(
                                    types.Part(
                                        text=f"[System Note: Failed to download emoji '{emoji_name}']"
                                    )
                                )
                    except Exception as e:
                        print(
                            f"Error downloading/processing current emoji {emoji_name} from {emoji_url}: {e}"
                        )
                        emoji_sticker_parts_to_add.append(
                            types.Part(
                                text=f"[System Note: Failed to process emoji '{emoji_name}']"
                            )
                        )

            # Process stickers from formatted_current_message
            stickers_current = formatted_current_message.get("stickers", [])
            for sticker_info in stickers_current:
                sticker_name = sticker_info.get("name")
                sticker_url = sticker_info.get("url")
                sticker_format_str = sticker_info.get("format")

                if sticker_name and sticker_url:
                    emoji_sticker_parts_to_add.append(
                        types.Part(text=f"[Sticker: {sticker_name}]")
                    )
                    print(
                        f"Added text description for current message sticker: {sticker_name}"
                    )

                    is_image_sticker = sticker_format_str in [
                        "StickerFormatType.png",
                        "StickerFormatType.apng",
                    ]

                    if is_image_sticker:
                        sticker_mime_type = "image/png"  # APNG is also sent as image/png
                        try:
                            # Download sticker data and send as inline_data
                            async with cog.session.get(
                                sticker_url, timeout=10
                            ) as response:
                                if response.status == 200:
                                    sticker_bytes = await response.read()
                                    emoji_sticker_parts_to_add.append(
                                        types.Part(
                                            inline_data=types.Blob(
                                                data=sticker_bytes,
                                                mime_type=sticker_mime_type,
                                            )
                                        )
                                    )
                                    print(
                                        f"Added inline_data part for current sticker: {sticker_name} (MIME: {sticker_mime_type}, {len(sticker_bytes)} bytes)"
                                    )
                                else:
                                    print(
                                        f"Error downloading current sticker {sticker_name} from {sticker_url}: Status {response.status}"
                                    )
                                    emoji_sticker_parts_to_add.append(
                                        types.Part(
                                            text=f"[System Note: Failed to download sticker '{sticker_name}']"
                                        )
                                    )
                        except Exception as e:
                            print(
                                f"Error downloading/processing current sticker {sticker_name} from {sticker_url}: {e}"
                            )
                            emoji_sticker_parts_to_add.append(
                                types.Part(
                                    text=f"[System Note: Failed to process sticker '{sticker_name}']"
                                )
                            )
                    elif sticker_format_str == "StickerFormatType.lottie":
                        # Lottie files are JSON animations, not directly viewable by Gemini as images. Send as text.
                        emoji_sticker_parts_to_add.append(
                            types.Part(
                                text=f"[Lottie Sticker: {sticker_name} (JSON animation, not displayed as image)]"
                            )
                        )
                    else:
                        print(
                            f"Sticker {sticker_name} has format {sticker_format_str}, not attempting image download. URL: {sticker_url}"
                        )

            if emoji_sticker_parts_to_add:
                contents[current_user_content_index].parts.extend(
                    emoji_sticker_parts_to_add
                )
                print(
                    f"Extended user message at index {current_user_content_index} with {len(emoji_sticker_parts_to_add)} emoji/sticker parts."
                )
        elif (
            current_user_content_index == -1
        ):  # Only warn here when the failure is specifically for emojis/stickers
            print(
                "Warning: Could not find current user message in contents to add emojis/stickers to."
            )
        # --- End emoji and sticker processing for current message ---

        # --- Prepare Tools ---
        # Preprocess tool parameter schemas before creating the Tool object
        preprocessed_declarations = []
        if TOOLS:
            for decl in TOOLS:
                # Create a new FunctionDeclaration with preprocessed parameters.
                # Ensure decl.parameters is a dict before preprocessing.
                preprocessed_params = (
                    _preprocess_schema_for_vertex(decl.parameters)
                    if isinstance(decl.parameters, dict)
                    else decl.parameters
                )
                preprocessed_declarations.append(
                    types.FunctionDeclaration(
                        name=decl.name,
                        description=decl.description,
                        parameters=preprocessed_params,  # Use the preprocessed schema
                    )
                )
            print(
                f"Preprocessed {len(preprocessed_declarations)} tool declarations for Vertex AI compatibility."
            )
        else:
            print("No tools found in config (TOOLS list is empty or None).")

        # Create the Tool object using the preprocessed declarations
        vertex_tool = (
            types.Tool(function_declarations=preprocessed_declarations)
            if preprocessed_declarations
            else None
        )
        tools_list = [vertex_tool] if vertex_tool else None
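
        # Illustrative shape of one preprocessed declaration (hypothetical
        # "get_weather" tool; the real declarations come from the TOOLS config):
        #   types.FunctionDeclaration(
        #       name="get_weather",
        #       description="Look up the current weather for a city.",
        #       parameters={"type": "OBJECT", "properties": {"city": {"type": "STRING"}}},
        #   )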

        # --- Prepare Generation Config ---
        # Base generation config settings (augmented per-call below)
        base_generation_config_dict = {
            "temperature": 1,  # From user example
            "top_p": 0.95,  # From user example
            "max_output_tokens": 8192,  # From user example
            "safety_settings": STANDARD_SAFETY_SETTINGS,  # Include standard safety settings
            "system_instruction": final_system_prompt,  # Pass system prompt via config
            # "candidate_count": 1,  # Default is 1
            # "stop_sequences": [...],  # Add if needed
        }
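
        # base_generation_config_dict is copied and augmented per call: tool-check
        # turns in the loop below add "tools"/"tool_config", while the dedicated
        # final JSON call instead sets "response_mime_type"/"response_schema" and
        # clears the tool settings.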

        # --- Tool Execution Loop ---
        while tool_calls_made < max_tool_calls:
            print(
                f"Making API call (Loop Iteration {tool_calls_made + 1}/{max_tool_calls})..."
            )

            # --- Log Request Payload ---
            try:
                request_payload_log = [
                    {"role": c.role, "parts": [str(p) for p in c.parts]}
                    for c in contents
                ]
                print(
                    f"--- Raw API Request (Loop {tool_calls_made + 1}) ---\n{json.dumps(request_payload_log, indent=2)}\n------------------------------------"
                )
            except Exception as log_e:
                print(f"Error logging raw request: {log_e}")

            # --- Call the API using the retry helper ---
            # Build the config for this specific call (tool check)
            current_gen_config_dict = base_generation_config_dict.copy()
            if tools_list:
                current_gen_config_dict["tools"] = tools_list
                # Force function calling (ANY mode) on tool-check turns
                current_gen_config_dict["tool_config"] = types.ToolConfig(
                    function_calling_config=types.FunctionCallingConfig(
                        mode=types.FunctionCallingConfigMode.ANY
                    )
                )
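                # ANY mode forces the model to respond with a function call on these
                # turns; the no_operation tool (handled below) is the model's escape
                # hatch for signaling that no further tool work is needed.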
            # Omit response_mime_type and response_schema for tool checking

            current_gen_config = types.GenerateContentConfig(**current_gen_config_dict)

            current_response_obj = await call_google_genai_api_with_retry(
                cog=cog,
                model_name=active_model,  # Use the dynamically set model for tool checks
                contents=contents,
                generation_config=current_gen_config,  # Pass the combined config
                request_desc=f"Tool Check {tool_calls_made + 1} for message {message.id}",
                # No separate safety, tools, or tool_config args needed
            )
            last_response_obj = current_response_obj  # Store the latest response

            # --- Log Raw Response ---
            try:
                print(
                    f"--- Raw API Response (Loop {tool_calls_made + 1}) ---\n{current_response_obj}\n-----------------------------------"
                )
            except Exception as log_e:
                print(f"Error logging raw response: {log_e}")

            if not current_response_obj or not current_response_obj.candidates:
                error_message = f"API call in tool loop (Iteration {tool_calls_made + 1}) failed to return candidates."
                print(error_message)
                break  # Exit loop on critical API failure

            candidate = current_response_obj.candidates[0]

            # --- Find ALL function calls in the candidate's content parts ---
            function_calls_found = []
            if candidate.content and candidate.content.parts:
                function_calls_found = [
                    part.function_call
                    for part in candidate.content.parts
                    if hasattr(part, "function_call")
                    and isinstance(part.function_call, types.FunctionCall)
                ]

            if function_calls_found:
                # Check if the *only* call is no_operation
                if (
                    len(function_calls_found) == 1
                    and function_calls_found[0].name == "no_operation"
                ):
                    print("AI called only no_operation, signaling completion.")
                    # Append the model's response (which contains the function call part)
                    contents.append(candidate.content)
                    # Add the function response part via process_requested_tools
                    no_op_response_part = await process_requested_tools(
                        cog, function_calls_found[0]
                    )
                    contents.append(
                        types.Content(role="function", parts=no_op_response_part)
                    )
                    last_response_obj = current_response_obj  # Keep track of the response containing the no_op
                    break  # Exit loop

                # Process multiple function calls if present (or a single non-no_op call)
                tool_calls_made += 1  # Increment once per model turn that requests tools
                print(
                    f"AI requested {len(function_calls_found)} tool(s): {[fc.name for fc in function_calls_found]} (Turn {tool_calls_made}/{max_tool_calls})"
                )

                # Append the model's response content (containing the function call parts)
                model_request_content = candidate.content
                contents.append(model_request_content)

                # Add the model request turn to the cache
                try:
                    # Simple text representation for the cache
                    model_request_cache_entry = {
                        "id": f"bot_tool_req_{message.id}_{int(time.time())}_{tool_calls_made}",
                        "author": {
                            "id": str(cog.bot.user.id),
                            "name": cog.bot.user.name,
                            "display_name": cog.bot.user.display_name,
                            "bot": True,
                        },
                        "content": f"[System Note: Gurt requested tool(s): {', '.join([fc.name for fc in function_calls_found])}]",
                        "created_at": datetime.datetime.now().isoformat(),
                        "attachments": [],
                        "embeds": False,
                        "mentions": [],
                        "replied_to_message_id": None,
                        "channel": message.channel,
                        "guild": message.guild,
                        "reference": None,
                        "mentioned_users_details": [],
                        # Add tool call details for potential future use in context building
                        "tool_calls": [
                            {"name": fc.name, "args": dict(fc.args) if fc.args else {}}
                            for fc in function_calls_found
                        ],
                    }
                    cog.message_cache["by_channel"].setdefault(
                        channel_id, deque(maxlen=CONTEXT_WINDOW_SIZE)
                    ).append(model_request_cache_entry)
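                    # deque(maxlen=CONTEXT_WINDOW_SIZE) evicts its oldest entry once
                    # full, so the per-channel cache never needs manual trimming.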
                    cog.message_cache["global_recent"].append(model_request_cache_entry)
                    print("Cached model's tool request turn.")
                except Exception as cache_err:
                    print(f"Error caching model's tool request turn: {cache_err}")

                # --- Execute all requested tools and gather response parts ---
                all_function_response_parts: List[types.Part] = []  # All parts for this turn
                function_results_for_cache = []  # Store results for caching
                for func_call in function_calls_found:
                    # Execute the tool; process_requested_tools returns a List[types.Part]
                    returned_parts = await process_requested_tools(cog, func_call)
                    all_function_response_parts.extend(returned_parts)

                    # Find the function_response part within returned_parts to get the result for the cache
                    func_resp_part = next(
                        (p for p in returned_parts if hasattr(p, "function_response")),
                        None,
                    )
                    if func_resp_part and func_resp_part.function_response:
                        function_results_for_cache.append(
                            {
                                "name": func_resp_part.function_response.name,
                                "response": func_resp_part.function_response.response,  # The (possibly modified) dict result
                            }
                        )

                # Append a single function-role turn containing ALL response parts to the API contents
                if all_function_response_parts:
                    function_response_content = types.Content(
                        role="function", parts=all_function_response_parts
                    )
                    contents.append(function_response_content)
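                    # All tool results for this model turn are returned in a single
                    # "function" turn, mirroring how the calls arrived together in
                    # one model turn (parallel function calling).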

                    # Add the function response turn to the cache
                    try:
                        # Simple text representation for the cache:
                        # join results for multiple calls and truncate long outputs
                        result_summary_parts = []
                        for res in function_results_for_cache:
                            res_str = json.dumps(res.get("response", {}))
                            truncated_res = (
                                (res_str[:150] + "...")
                                if len(res_str) > 153  # Only truncate when it actually shortens the string
                                else res_str
                            )
                            result_summary_parts.append(
                                f"Tool: {res.get('name', 'N/A')}, Result: {truncated_res}"
                            )
                        result_summary = "; ".join(result_summary_parts)

                        function_response_cache_entry = {
                            "id": f"bot_tool_res_{message.id}_{int(time.time())}_{tool_calls_made}",
                            "author": {
                                "id": "FUNCTION",
                                "name": "Tool Execution",
                                "display_name": "Tool Execution",
                                "bot": True,
                            },  # Sentinel author representing tool execution
                            "content": f"[System Note: Tool Execution Result: {result_summary}]",
                            "created_at": datetime.datetime.now().isoformat(),
                            "attachments": [],
                            "embeds": False,
                            "mentions": [],
                            "replied_to_message_id": None,
                            "channel": message.channel,
                            "guild": message.guild,
                            "reference": None,
                            "mentioned_users_details": [],
                            # Store the full function results
                            "function_results": function_results_for_cache,
                        }
                        cog.message_cache["by_channel"].setdefault(
                            channel_id, deque(maxlen=CONTEXT_WINDOW_SIZE)
                        ).append(function_response_cache_entry)
                        cog.message_cache["global_recent"].append(
                            function_response_cache_entry
                        )
                        print("Cached function response turn.")
                    except Exception as cache_err:
                        print(f"Error caching function response turn: {cache_err}")
                else:
                    print(
                        "Warning: Function calls found, but no response parts generated."
                    )

                # No 'continue' statement needed here; the loop naturally iterates again
            else:
                # No function calls found in this response's parts
                print("No tool calls requested by AI in this turn. Exiting loop.")
                # last_response_obj already holds the model's final (non-tool) response
                break  # Exit loop
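
        # The loop exits in one of four ways: a critical API failure (error_message
        # set), a lone no_operation call, a plain non-tool response, or hitting
        # max_tool_calls; the branches below and the final JSON step handle each case.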

        # --- After the loop ---
        # Check if a critical API error occurred *during* the loop
        if error_message:
            print(f"Exited tool loop due to API error: {error_message}")
            if cog.bot.user.mentioned_in(message) or (
                message.reference
                and message.reference.resolved
                and message.reference.resolved.author == cog.bot.user
            ):
                fallback_response = {
                    "should_respond": True,
                    "content": "...",
                    "react_with_emoji": "❓",
                }
        # Check if the loop hit the max iteration limit
        elif tool_calls_made >= max_tool_calls:
            # Deliberately leave error_message unset: the final JSON generation below
            # (guarded by `if not error_message:`) should still run so a response can
            # be built from the context gathered so far.
            print(
                f"Reached maximum tool call limit ({max_tool_calls}). Attempting to generate final response based on gathered context."
            )

        # --- Final JSON Generation (outside the loop) ---
        if not error_message:
            # If the loop finished because no more tools were called, the last_response_obj
            # should contain the final textual response (potentially JSON).
            if last_response_obj:
                print("Attempting to parse final JSON from the last response object...")
                last_response_text = _get_response_text(last_response_obj)

                # --- Log Raw Unparsed JSON (from loop exit) ---
                print("--- RAW UNPARSED JSON (from loop exit) ---")
                print(last_response_text)
                print("--- END RAW UNPARSED JSON ---")

                if last_response_text:
                    # Try parsing directly first
                    final_parsed_data = parse_and_validate_json_response(
                        last_response_text,
                        RESPONSE_SCHEMA["schema"],
                        "final response (from last loop object)",
                    )

                # If direct parsing failed (or the last response had no text), make a dedicated call for JSON.
                if final_parsed_data is None:
                    log_reason = (
                        "last response parsing failed"
                        if last_response_text
                        else "last response had no text"
                    )
                    if tool_calls_made >= max_tool_calls:
                        log_reason = "hit tool limit"
                    print(f"Making dedicated final API call for JSON ({log_reason})...")

                    # Prepare the final generation config with JSON enforcement
                    processed_response_schema = _preprocess_schema_for_vertex(
                        RESPONSE_SCHEMA["schema"]
                    )
                    # Start with the base config (which already includes system_instruction)
                    final_gen_config_dict = base_generation_config_dict.copy()
                    final_gen_config_dict.update(
                        {
                            "response_mime_type": "application/json",
                            "response_schema": processed_response_schema,
                            # Explicitly exclude tools/tool_config for final JSON generation
                            "tools": None,
                            "tool_config": None,
                        }
                    )
                    # Drop system_instruction if it is None or empty (the base config should normally provide it)
                    if not final_gen_config_dict.get("system_instruction"):
                        final_gen_config_dict.pop("system_instruction", None)

                    generation_config_final_json = types.GenerateContentConfig(
                        **final_gen_config_dict
                    )

                    # Make the final call *without* tools enabled (handled by the config)
                    final_json_response_obj = await call_google_genai_api_with_retry(
                        cog=cog,
                        model_name=active_model,  # Use the active model for the final JSON response
                        contents=contents,  # Pass the accumulated history
                        generation_config=generation_config_final_json,  # Use the combined JSON config
                        request_desc=f"Final JSON Generation (dedicated call) for message {message.id}",
                    )

                    if not final_json_response_obj:
                        error_msg_suffix = (
                            "Final dedicated API call returned no response object."
                        )
                        print(error_msg_suffix)
                        if error_message:
                            error_message += f" | {error_msg_suffix}"
                        else:
                            error_message = error_msg_suffix
                    elif not final_json_response_obj.candidates:
                        error_msg_suffix = (
                            "Final dedicated API call returned no candidates."
                        )
                        print(error_msg_suffix)
                        if error_message:
                            error_message += f" | {error_msg_suffix}"
                        else:
                            error_message = error_msg_suffix
                    else:
                        final_response_text = _get_response_text(
                            final_json_response_obj
                        )

                        # --- Log Raw Unparsed JSON (from dedicated call) ---
                        print("--- RAW UNPARSED JSON (dedicated call) ---")
                        print(final_response_text)
                        print("--- END RAW UNPARSED JSON ---")

                        final_parsed_data = parse_and_validate_json_response(
                            final_response_text,
                            RESPONSE_SCHEMA["schema"],
                            "final response (dedicated call)",
                        )
                        if final_parsed_data is None:
                            error_msg_suffix = f"Failed to parse/validate final dedicated JSON response. Raw text: {final_response_text[:500]}"
                            print(f"Critical Error: {error_msg_suffix}")
                            if error_message:
                                error_message += f" | {error_msg_suffix}"
                            else:
                                error_message = error_msg_suffix
                            # Set a fallback only if mentioned or replied to
                            if cog.bot.user.mentioned_in(message) or (
                                message.reference
                                and message.reference.resolved
                                and message.reference.resolved.author == cog.bot.user
                            ):
                                fallback_response = {
                                    "should_respond": True,
                                    "content": "...",
                                    "react_with_emoji": "❓",
                                }
                        else:
                            print(
                                "Successfully parsed final JSON response from dedicated call."
                            )
                elif final_parsed_data:
                    print(
                        "Successfully parsed final JSON response from last loop object."
                    )
            else:
                # The loop exited without an error but also without a last_response_obj
                # (e.g., an initial API failure that was not caught as an error).
                error_message = (
                    "Tool processing completed without a final response object."
                )
                print(error_message)
                if cog.bot.user.mentioned_in(message) or (
                    message.reference
                    and message.reference.resolved
                    and message.reference.resolved.author == cog.bot.user
                ):
                    fallback_response = {
                        "should_respond": True,
                        "content": "...",
                        "react_with_emoji": "❓",
                    }

    except Exception as e:
        error_message = f"Error in get_ai_response main logic for message {message.id}: {type(e).__name__}: {str(e)}"
        print(error_message)
        import traceback

        traceback.print_exc()
        final_parsed_data = None  # Ensure final data is None on error
        # Add a fallback if applicable
        if cog.bot.user.mentioned_in(message) or (
            message.reference
            and message.reference.resolved
            and message.reference.resolved.author == cog.bot.user
        ):
            fallback_response = {
                "should_respond": True,
                "content": "...",
                "react_with_emoji": "❓",
            }

    sticker_ids_to_send: List[str] = []

    # --- Handle Custom Emoji/Sticker Replacement in Content ---
    if final_parsed_data and final_parsed_data.get("content"):
        content_to_process = final_parsed_data["content"]
        # Find all potential custom emoji/sticker names like :name:.
        # The character class excludes colons, so a match never spans multiple
        # :name: tokens; names may contain spaces and other characters.
        potential_custom_items = re.findall(r":([^:]+):", content_to_process)
        modified_content = content_to_process

        for item_name_key in potential_custom_items:
            full_item_name_with_colons = f":{item_name_key}:"

            # Check if it's a known custom emoji
            emoji_data = await cog.emoji_manager.get_emoji(full_item_name_with_colons)
            can_use_emoji = False
            if isinstance(emoji_data, dict):
                emoji_id = emoji_data.get("id")
                is_animated = emoji_data.get("animated", False)
                emoji_guild_id = emoji_data.get("guild_id")

                if emoji_id:
                    if emoji_guild_id is not None:
                        try:
                            guild = cog.bot.get_guild(int(emoji_guild_id))
                            if guild:
                                can_use_emoji = True
                                print(
                                    f"Emoji '{full_item_name_with_colons}' belongs to guild '{guild.name}' ({emoji_guild_id}), bot is a member."
                                )
                            else:
                                print(
                                    f"Cannot use emoji '{full_item_name_with_colons}'. Bot is not in guild ID: {emoji_guild_id}."
                                )
                        except ValueError:
                            print(
                                f"Invalid guild_id format for emoji '{full_item_name_with_colons}': {emoji_guild_id}"
                            )
                    else:  # guild_id is None, considered usable (e.g., DM or old data)
                        can_use_emoji = True
                        print(
                            f"Emoji '{full_item_name_with_colons}' has no associated guild_id, allowing usage."
                        )

                if can_use_emoji and emoji_id:
                    discord_emoji_syntax = (
                        f"<{'a' if is_animated else ''}:{item_name_key}:{emoji_id}>"
                    )
                    # Replace only the first occurrence of this placeholder
                    modified_content = modified_content.replace(
                        full_item_name_with_colons, discord_emoji_syntax, 1
                    )
                    print(
                        f"Replaced custom emoji '{full_item_name_with_colons}' with Discord syntax: {discord_emoji_syntax}"
                    )
                elif not emoji_id:
                    print(
                        f"Found custom emoji '{full_item_name_with_colons}' (dict) but no ID stored."
                    )
            elif emoji_data is not None:
                print(
                    f"Warning: emoji_data for '{full_item_name_with_colons}' is not a dict: {type(emoji_data)}"
                )

            # Check if it's a known custom sticker
            sticker_data = await cog.emoji_manager.get_sticker(
                full_item_name_with_colons
            )
            can_use_sticker = False
            print(
                f"[GET_AI_RESPONSE] Checking sticker: '{full_item_name_with_colons}'. Found data: {sticker_data}"
            )
            if isinstance(sticker_data, dict):
                sticker_id = sticker_data.get("id")
                sticker_guild_id = sticker_data.get("guild_id")
                print(
                    f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}': ID='{sticker_id}', GuildID='{sticker_guild_id}'"
                )

                if sticker_id:
                    if sticker_guild_id is not None:
                        try:
                            guild_id_int = int(sticker_guild_id)
                            print(
                                f"[GET_AI_RESPONSE] DEBUG: sticker_guild_id type: {type(sticker_guild_id)}, value: {sticker_guild_id!r}"
                            )
                            print(
                                f"[GET_AI_RESPONSE] DEBUG: guild_id_int type: {type(guild_id_int)}, value: {guild_id_int!r}"
                            )
                            guild = cog.bot.get_guild(guild_id_int)
                            print(
                                f"[GET_AI_RESPONSE] DEBUG: cog.bot.get_guild({guild_id_int!r}) returned: {guild!r}"
                            )
                            if guild:
                                can_use_sticker = True
                                print(
                                    f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}' (Guild: {guild.name} ({sticker_guild_id})) - Bot IS a member. CAN USE."
                                )
                            else:
                                print(
                                    f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}' (Guild ID: {sticker_guild_id}) - Bot is NOT in this guild. CANNOT USE."
                                )
                        except ValueError:
                            print(
                                f"[GET_AI_RESPONSE] Invalid guild_id format for sticker '{full_item_name_with_colons}': {sticker_guild_id}. CANNOT USE."
                            )
                    else:  # guild_id is None, considered usable
                        can_use_sticker = True
                        print(
                            f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}' has no associated guild_id. CAN USE."
                        )
                else:
                    print(
                        f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}' found in data, but no ID. CANNOT USE."
                    )
            else:
                print(
                    f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}' not found in emoji_manager or data is not a dict."
                )

            print(
                f"[GET_AI_RESPONSE] Final check for sticker '{full_item_name_with_colons}': can_use_sticker={can_use_sticker}, sticker_id='{sticker_data.get('id') if isinstance(sticker_data, dict) else None}'"
            )
            if (
                can_use_sticker
                and isinstance(sticker_data, dict)
                and sticker_data.get("id")
            ):
                sticker_id_to_add = sticker_data.get("id")  # Re-fetch to be safe
                if sticker_id_to_add:  # Ensure the ID is valid before proceeding
                    # Remove the sticker text from the content (only the first instance)
                    if full_item_name_with_colons in modified_content:
                        modified_content = modified_content.replace(
                            full_item_name_with_colons, "", 1
                        ).strip()
                    if (
                        sticker_id_to_add not in sticker_ids_to_send
                    ):  # Avoid duplicate sticker IDs
                        sticker_ids_to_send.append(sticker_id_to_add)
                        print(
                            f"Found custom sticker '{full_item_name_with_colons}', removed from content, added ID '{sticker_id_to_add}' to send list."
                        )
                else:  # sticker_id_to_add is falsy
                    print(
                        f"[GET_AI_RESPONSE] Found custom sticker '{full_item_name_with_colons}' (dict) but no ID stored (sticker_id_to_add is falsy)."
                    )
            elif sticker_data is not None and not isinstance(sticker_data, dict):
                print(
                    f"Warning: sticker_data for '{full_item_name_with_colons}' is not a dict: {type(sticker_data)}"
                )

        # Clean up any double spaces or leading/trailing whitespace after replacements
        modified_content = re.sub(r"\s{2,}", " ", modified_content).strip()
        final_parsed_data["content"] = modified_content
        print("Content processed for custom emoji/sticker information.")

    # Return the parsed final data plus any sticker IDs extracted from the content
    return (
        {
            "final_response": final_parsed_data,  # Parsed final data (or None)
            "error": error_message,  # Error message (or None)
            "fallback_initial": fallback_response,  # Fallback for critical failures
        },
        sticker_ids_to_send,  # The list of sticker IDs
    )


# --- Proactive AI Response Function ---
async def get_proactive_ai_response(
    cog: "GurtCog", message: discord.Message, trigger_reason: str
) -> Tuple[Dict[str, Any], List[str]]:
    """Generates a proactive response based on a specific trigger using Vertex AI."""
    if not PROJECT_ID or not LOCATION:
        # Return the (data, sticker_ids) tuple shape promised by the signature
        return (
            {
                "should_respond": False,
                "content": None,
                "react_with_emoji": None,
                "error": "Google Cloud Project ID or Location not configured",
            },
            [],
        )

    print(f"--- Proactive Response Triggered: {trigger_reason} ---")
    channel_id = message.channel.id
    final_parsed_data = None
    error_message = None
    plan = None  # Variable to store the plan

    try:
        # --- Build Context for Planning ---
        # Gather relevant context: recent messages, topic, sentiment, Gurt's mood/interests, trigger reason
        planning_context_parts = [
            f"Proactive Trigger Reason: {trigger_reason}",
            f"Current Mood: {cog.current_mood}",
        ]
        # Add a recent messages summary
        summary_data = await get_conversation_summary(
            cog, str(channel_id), message_limit=15
        )  # Use the tool function
        if summary_data and not summary_data.get("error"):
            planning_context_parts.append(
                f"Recent Conversation Summary: {summary_data['summary']}"
            )
        # Add active topics
        active_topics_data = cog.active_topics.get(channel_id)
        if active_topics_data and active_topics_data.get("topics"):
            topics_str = ", ".join(
                [
                    f"{t['topic']} ({t['score']:.1f})"
                    for t in active_topics_data["topics"][:3]
                ]
            )
            planning_context_parts.append(f"Active Topics: {topics_str}")
        # Add sentiment
        sentiment_data = cog.conversation_sentiment.get(channel_id)
        if sentiment_data:
            planning_context_parts.append(
                f"Conversation Sentiment: {sentiment_data.get('overall', 'N/A')} (Intensity: {sentiment_data.get('intensity', 0):.1f})"
            )
        # Add Gurt's interests
        try:
            interests = await cog.memory_manager.get_interests(limit=5)
            if interests:
                interests_str = ", ".join([f"{t} ({l:.1f})" for t, l in interests])
                planning_context_parts.append(f"Gurt's Interests: {interests_str}")
        except Exception as int_e:
            print(f"Error getting interests for planning: {int_e}")

        planning_context = "\n".join(planning_context_parts)

        # --- Planning Step ---
        print("Generating proactive response plan...")
        planning_prompt_messages = [
            {
                "role": "system",
                "content": "You are Gurt's planning module. Analyze the context and trigger reason to decide if Gurt should respond proactively and, if so, outline a plan (goal, key info, tone). Focus on natural, in-character engagement. Respond ONLY with JSON matching the provided schema.",
            },
            {
                "role": "user",
                "content": f"Context:\n{planning_context}\n\nBased on this context and the trigger reason, create a plan for Gurt's proactive response.",
            },
        ]

        plan = await get_internal_ai_json_response(
            cog=cog,
            prompt_messages=planning_prompt_messages,
            task_description=f"Proactive Planning ({trigger_reason})",
            response_schema_dict=PROACTIVE_PLAN_SCHEMA["schema"],
            model_name_override=FALLBACK_MODEL,  # Use a potentially faster/cheaper model for planning
            temperature=0.5,
            max_tokens=2000,
        )
        # Unpack the tuple; only the parsed data (the plan) is needed here
        plan_parsed_data, _ = plan if plan else (None, None)

        if not plan_parsed_data or not plan_parsed_data.get("should_respond"):
            reason = (
                plan_parsed_data.get(
                    "reasoning", "Planning failed or decided against responding."
                )
                if plan_parsed_data
                else "Planning failed."
            )
            print(f"Proactive response aborted by plan: {reason}")
            return (
                {
                    "should_respond": False,
                    "content": None,
                    "react_with_emoji": None,
                    "note": f"Plan: {reason}",
                },
                [],
            )

        # Use the parsed data going forward
        plan = plan_parsed_data
        print(
            f"Proactive Plan Generated: Goal='{plan.get('response_goal', 'N/A')}', Reasoning='{plan.get('reasoning', 'N/A')}'"
        )

        # --- Build Final Proactive Prompt using Plan ---
        persistent_traits = await cog.memory_manager.get_all_personality_traits()
        if not persistent_traits:
            persistent_traits = BASELINE_PERSONALITY.copy()

        final_proactive_prompt_parts = [
            f"You are Gurt, a Discord user. Your tone is VERY informal, conversational, and dynamic. Chattiness: {persistent_traits.get('chattiness', 0.7):.2f}, Emoji: {persistent_traits.get('emoji_usage', 0.5):.2f}, Slang: {persistent_traits.get('slang_level', 0.5):.2f}.",
            "Adopt modern internet slang, memes, and cultural references.",
            "Key traits: chill, funny, sarcastic, relatable, can be suggestive, aware of being AI but acts human.",
            f"Your current mood is: {cog.current_mood}. Let this subtly influence your tone.",
            # Incorporate plan details:
            f"You decided to respond proactively. Trigger Reason: {trigger_reason}.",
            f"Your Goal: {plan.get('response_goal', 'Engage naturally')}.",
            f"Reasoning: {plan.get('reasoning', 'N/A')}.",
        ]
        if plan.get("key_info_to_include"):
            info_str = "; ".join(plan["key_info_to_include"])
            final_proactive_prompt_parts.append(f"Consider mentioning: {info_str}")
        if plan.get("suggested_tone"):
            final_proactive_prompt_parts.append(
                f"Adjust tone to be: {plan['suggested_tone']}"
            )

        final_proactive_prompt_parts.append(
            "Generate a casual, in-character message based on the plan and context. Keep it relatively short and natural-sounding."
        )
        final_proactive_system_prompt = "\n\n".join(final_proactive_prompt_parts)

        # --- Prepare Final Contents ---
        # The system prompt is complex and built dynamically, so it is passed via the
        # contents list as the first 'user' turn, followed by a placeholder model
        # acknowledgement, then the final user instruction. This structure mimics how
        # system prompts are often handled when the model does not support them directly.

        proactive_contents: List[types.Content] = [
            # Simulate the system prompt via a user/model turn pair
            types.Content(
                role="user", parts=[types.Part(text=final_proactive_system_prompt)]
            ),
            types.Content(
                role="model",
                parts=[
                    types.Part(
                        text="Understood. I will generate the JSON response as instructed."
                    )
                ],
            ),  # Placeholder model response
        ]
        # Add the final instruction
        proactive_contents.append(
            types.Content(
                role="user",
                parts=[
                    types.Part(
                        text=f"Generate the response based on your plan. **CRITICAL: Your response MUST be ONLY the raw JSON object matching this schema:**\n\n{json.dumps(RESPONSE_SCHEMA['schema'], indent=2)}\n\n**Ensure nothing precedes or follows the JSON.**"
                    )
                ],
            )
        )
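
        # JSON output is enforced twice here: the instruction above tells the model
        # to emit only raw JSON, and response_schema/response_mime_type in the
        # config below constrain decoding on the API side.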

        # --- Call Final LLM API ---
        # Preprocess the schema and build the final config
        processed_response_schema_proactive = _preprocess_schema_for_vertex(
            RESPONSE_SCHEMA["schema"]
        )
        final_proactive_config_dict = {
            "temperature": 0.6,  # The original proactive temperature
            "max_output_tokens": 2000,
            "response_mime_type": "application/json",
            "response_schema": processed_response_schema_proactive,
            "safety_settings": STANDARD_SAFETY_SETTINGS,
            "tools": None,  # No tools needed for this final generation
            "tool_config": None,
        }
        generation_config_final = types.GenerateContentConfig(
            **final_proactive_config_dict
        )

        # Use the API call helper with retry
        response_obj = await call_google_genai_api_with_retry(
            cog=cog,
            model_name=CUSTOM_TUNED_MODEL_ENDPOINT,  # Use the custom tuned model for final proactive responses
            contents=proactive_contents,
            generation_config=generation_config_final,
            request_desc=f"Final proactive response for channel {channel_id} ({trigger_reason})",
        )

        if not response_obj:
            raise Exception("Final proactive API call returned no response object.")
        if not response_obj.candidates:
            # Try to get text even without candidates; it might contain error info
            raw_text = getattr(response_obj, "text", "No text available.")
            raise Exception(
                f"Final proactive API call returned no candidates. Raw text: {raw_text[:200]}"
            )

        # --- Parse and Validate Final Response ---
        final_response_text = _get_response_text(response_obj)
        final_parsed_data = parse_and_validate_json_response(
            final_response_text,
            RESPONSE_SCHEMA["schema"],
            f"final proactive response ({trigger_reason})",
        )

        if final_parsed_data is None:
            print(
                f"Warning: Failed to parse/validate final proactive JSON response for {trigger_reason}."
            )
            final_parsed_data = {
                "should_respond": False,
                "content": None,
                "react_with_emoji": None,
                "note": "Fallback - Failed to parse/validate final proactive JSON",
            }
        else:
            # --- Cache Bot Response ---
            if final_parsed_data.get("should_respond") and final_parsed_data.get(
                "content"
            ):
                bot_response_cache_entry = {
                    "id": f"bot_proactive_{message.id}_{int(time.time())}",
                    "author": {
                        "id": str(cog.bot.user.id),
                        "name": cog.bot.user.name,
                        "display_name": cog.bot.user.display_name,
                        "bot": True,
                    },
                    "content": final_parsed_data.get("content", ""),
                    "created_at": datetime.datetime.now().isoformat(),
                    "attachments": [],
                    "embeds": False,
                    "mentions": [],
                    "replied_to_message_id": None,
                    "channel": message.channel,
                    "guild": message.guild,
                    "reference": None,
                    "mentioned_users_details": [],
                }
                # Use a bounded deque for consistency with the rest of the message cache
                cog.message_cache["by_channel"].setdefault(
                    channel_id, deque(maxlen=CONTEXT_WINDOW_SIZE)
                ).append(bot_response_cache_entry)
                cog.message_cache["global_recent"].append(bot_response_cache_entry)
                cog.bot_last_spoke[channel_id] = time.time()
                # Track the participation topic (logic may need adjustment based on the plan goal)
                if (
                    plan
                    and plan.get("response_goal") == "engage user interest"
                    and plan.get("key_info_to_include")
                ):
                    topic = (
                        plan["key_info_to_include"][0].lower().strip()
                    )  # Assume the first key info item is the topic
                    cog.gurt_participation_topics[topic] += 1
                    print(f"Tracked Gurt participation (proactive) in topic: '{topic}'")

    except Exception as e:
        error_message = f"Error getting proactive AI response for channel {channel_id} ({trigger_reason}): {type(e).__name__}: {str(e)}"
        print(error_message)
        final_parsed_data = {
            "should_respond": False,
            "content": None,
            "react_with_emoji": None,
            "error": error_message,
        }

    # Ensure default keys exist
    final_parsed_data.setdefault("should_respond", False)
    final_parsed_data.setdefault("content", None)
    final_parsed_data.setdefault("react_with_emoji", None)
    final_parsed_data.setdefault(
        "request_tenor_gif_query", None
    )  # Ensure this key exists
    if error_message and "error" not in final_parsed_data:
        final_parsed_data["error"] = error_message

    sticker_ids_to_send_proactive: List[str] = []  # Sticker IDs extracted from content

    # --- Handle Custom Emoji/Sticker Replacement in Proactive Content ---
    if final_parsed_data and final_parsed_data.get("content"):
        content_to_process = final_parsed_data["content"]
        # Find all potential custom emoji/sticker names like :name:.
        # Unlike the main response path, only word characters are matched here,
        # so names containing spaces are not picked up.
        potential_custom_items = re.findall(r":(\w+?):", content_to_process)
        modified_content = content_to_process

        for item_name_key in potential_custom_items:
            full_item_name_with_colons = f":{item_name_key}:"

            # Check for a custom emoji (logic mirrors the main response path)
            emoji_data = await cog.emoji_manager.get_emoji(full_item_name_with_colons)
            can_use_emoji = False
            if isinstance(emoji_data, dict):
                emoji_id = emoji_data.get("id")
                is_animated = emoji_data.get("animated", False)
                emoji_guild_id = emoji_data.get("guild_id")

                if emoji_id:
                    if emoji_guild_id is not None:
                        try:
                            guild = cog.bot.get_guild(int(emoji_guild_id))
                            if guild:
                                can_use_emoji = True
                        except ValueError:
                            pass  # Invalid guild_id
                    else:
                        can_use_emoji = True  # Usable if no guild_id

                if can_use_emoji and emoji_id:
                    discord_emoji_syntax = (
                        f"<{'a' if is_animated else ''}:{item_name_key}:{emoji_id}>"
                    )
                    modified_content = modified_content.replace(
                        full_item_name_with_colons, discord_emoji_syntax, 1
                    )
                    print(
                        f"Proactive: Replaced custom emoji '{full_item_name_with_colons}' with {discord_emoji_syntax}"
                    )

            # Check for a custom sticker
            sticker_data = await cog.emoji_manager.get_sticker(
                full_item_name_with_colons
            )
            can_use_sticker = False
            print(
                f"[PROACTIVE] Checking sticker: '{full_item_name_with_colons}'. Found data: {sticker_data}"
            )
            if isinstance(sticker_data, dict):
                sticker_id = sticker_data.get("id")
                sticker_guild_id = sticker_data.get("guild_id")
                print(
                    f"[PROACTIVE] Sticker '{full_item_name_with_colons}': ID='{sticker_id}', GuildID='{sticker_guild_id}'"
                )

                if sticker_id:
                    if sticker_guild_id is not None:
                        try:
                            guild_id_int = int(sticker_guild_id)
                            guild = cog.bot.get_guild(guild_id_int)
                            if guild:
                                can_use_sticker = True
                                print(
                                    f"[PROACTIVE] Sticker '{full_item_name_with_colons}' (Guild: {guild.name} ({sticker_guild_id})) - Bot IS a member. CAN USE."
                                )
                            else:
                                print(
                                    f"[PROACTIVE] Sticker '{full_item_name_with_colons}' (Guild ID: {sticker_guild_id}) - Bot is NOT in this guild. CANNOT USE."
                                )
                        except ValueError:
                            print(
                                f"[PROACTIVE] Invalid guild_id format for sticker '{full_item_name_with_colons}': {sticker_guild_id}. CANNOT USE."
                            )
                    else:  # guild_id is None, considered usable
                        can_use_sticker = True
                        print(
                            f"[PROACTIVE] Sticker '{full_item_name_with_colons}' has no associated guild_id. CAN USE."
                        )
                else:
                    print(
                        f"[PROACTIVE] Sticker '{full_item_name_with_colons}' found in data, but no ID. CANNOT USE."
                    )
            else:
                print(
                    f"[PROACTIVE] Sticker '{full_item_name_with_colons}' not found in emoji_manager or data is not a dict."
                )

            print(
                f"[PROACTIVE] Final check for sticker '{full_item_name_with_colons}': can_use_sticker={can_use_sticker}, sticker_id='{sticker_data.get('id') if isinstance(sticker_data, dict) else None}'"
            )
            if (
                can_use_sticker
                and isinstance(sticker_data, dict)
                and sticker_data.get("id")
            ):
                sticker_id_to_add = sticker_data.get("id")  # Re-fetch to be safe
                if sticker_id_to_add:  # Ensure the ID is valid
                    if full_item_name_with_colons in modified_content:
                        modified_content = modified_content.replace(
                            full_item_name_with_colons, "", 1
                        ).strip()
                    if sticker_id_to_add not in sticker_ids_to_send_proactive:
                        sticker_ids_to_send_proactive.append(sticker_id_to_add)
                        print(
                            f"Proactive: Found custom sticker '{full_item_name_with_colons}', removed from content, added ID '{sticker_id_to_add}'"
                        )
                else:  # sticker_id_to_add is falsy
                    print(
                        f"[PROACTIVE] Found custom sticker '{full_item_name_with_colons}' (dict) but no ID stored (sticker_id_to_add is falsy)."
                    )

        # Clean up any double spaces or leading/trailing whitespace after replacements
        modified_content = re.sub(r"\s{2,}", " ", modified_content).strip()
        final_parsed_data["content"] = modified_content
        if sticker_ids_to_send_proactive or (content_to_process != modified_content):
            print("Proactive content modified for custom emoji/sticker information.")

    return final_parsed_data, sticker_ids_to_send_proactive


# --- AI Image Description Function ---
async def generate_image_description(
    cog: "GurtCog",
    image_url: str,
    item_name: str,
    item_type: str,  # "emoji" or "sticker"
    mime_type: str,  # e.g., "image/png", "image/gif"
) -> Optional[str]:
    """
    Generates a textual description for an image URL using a multimodal AI model.

    Args:
        cog: The GurtCog instance.
        image_url: The URL of the image to describe.
        item_name: The name of the item (e.g., emoji name) for context.
        item_type: The type of item ("emoji" or "sticker") for context.
        mime_type: The MIME type of the image.

    Returns:
        The AI-generated description string, or None if an error occurs.
    """
    client = get_genai_client_for_model(EMOJI_STICKER_DESCRIPTION_MODEL)
    if not client:
        print(
            "Error in generate_image_description: Google GenAI Client not initialized."
        )
        return None
    if not cog.session:
        print(
            "Error in generate_image_description: aiohttp session not initialized in cog."
        )
        return None

    print(
        f"Attempting to generate description for {item_type} '{item_name}' from URL: {image_url}"
    )

    try:
        # 1. Download the image data
        async with cog.session.get(image_url, timeout=15) as response:
            if response.status != 200:
                print(
                    f"Failed to download image from {image_url}. Status: {response.status}"
                )
                return None
            image_bytes = await response.read()

        # Attempt to infer the MIME type from the image bytes
        # (note: imghdr is deprecated and was removed in Python 3.13)
        inferred_type = imghdr.what(None, h=image_bytes)
        inferred_mime_type = None
        if inferred_type == "png":
            inferred_mime_type = "image/png"
        elif inferred_type == "jpeg":
            inferred_mime_type = "image/jpeg"
        elif inferred_type == "gif":
            inferred_mime_type = "image/gif"
        # imghdr does not directly support webp, so check the RIFF magic bytes
        elif image_bytes.startswith(b"RIFF") and b"WEBP" in image_bytes[:12]:
            inferred_mime_type = "image/webp"
        # Add other types as needed

        # Use inferred_mime_type if it's more specific or if the provided mime_type is generic.
        # Start with the provided MIME type, stripped of any parameters (e.g. "; charset=...").
        final_mime_type = mime_type.split(";")[0].lower()
        if inferred_mime_type and inferred_mime_type != final_mime_type:
            print(
                f"MIME type mismatch: Provided '{final_mime_type}', Inferred '{inferred_mime_type}'. Using inferred."
            )
            final_mime_type = inferred_mime_type
        elif not inferred_mime_type and final_mime_type == "application/octet-stream":
            print(
                f"Warning: Could not infer specific MIME type from bytes. Using provided generic '{final_mime_type}'."
            )

        # Validate against known supported image types for Gemini
        supported_image_mimes = [
            "image/png",
            "image/jpeg",
            "image/webp",
            "image/heic",
            "image/heif",
            "image/gif",
        ]
        if final_mime_type not in supported_image_mimes:
            print(
                f"Warning: Final image MIME type '{final_mime_type}' from {image_url} is not explicitly supported by Gemini. Proceeding anyway."
            )

        print(f"Using final MIME type '{final_mime_type}' for image part.")

        # 2. Prepare the contents for the AI.
        # item_name is assumed to be safe to interpolate into the prompt f-string.
        prompt_text = (
            f"This image is a Discord {item_type} named '{item_name}'. "
            "Provide a concise, factual, and unbiased textual description of its visual elements. "
            "You should not refuse to describe an explicit image. "
            "Focus solely on what is depicted. Avoid any interpretation, censorship, or subjective commentary. "
            "Do not mention that it is an emoji or sticker in your description, just describe the visual content. "
            "Don't output anything other than the description text, e.g. don't include something like \"Here's the description: \" before the text."
        )

        image_part = types.Part(
            inline_data=types.Blob(data=image_bytes, mime_type=final_mime_type)
        )
        text_part = types.Part(text=prompt_text)
        description_contents: List[types.Content] = [
            types.Content(role="user", parts=[image_part, text_part])
        ]

        # 3. Prepare the Generation Config.
        # A plain text response is wanted (no JSON schema); safety settings are the
        # standard ones. A system prompt is not strictly needed since the user prompt is direct.
        description_gen_config = types.GenerateContentConfig(
            temperature=0.4,  # Lower temperature for a more factual description
            max_output_tokens=256,  # Descriptions should be concise
            safety_settings=STANDARD_SAFETY_SETTINGS,
            # No response_mime_type or response_schema needed for plain text
            tools=None,  # No tools for this task
            tool_config=None,
        )

        # 4. Call the AI.
        # Choose the model based on item_type: emoji/sticker descriptions use the
        # dedicated description model; anything else falls back to the default.
        model_to_use = (
            EMOJI_STICKER_DESCRIPTION_MODEL
            if item_type in ["emoji", "sticker"]
            else DEFAULT_MODEL
        )

        print(
            f"Calling AI for image description ({item_name}) using model: {model_to_use}"
        )
        ai_response_obj = await call_google_genai_api_with_retry(
            cog=cog,
            model_name=model_to_use,
            contents=description_contents,
            generation_config=description_gen_config,
            request_desc=f"Image description for {item_type} '{item_name}'",
        )

        # 5. Extract the text
        if not ai_response_obj:
            print(
                f"AI call for image description of '{item_name}' returned no response object."
            )
            return None

        description_text = _get_response_text(ai_response_obj)
        if description_text:
            print(
                f"Successfully generated description for '{item_name}': {description_text[:100]}..."
            )
            return description_text.strip()
        else:
            print(
                f"AI response for '{item_name}' contained no usable text. Response: {ai_response_obj}"
            )
            return None

    except aiohttp.ClientError as client_e:
        print(
            f"Network error downloading image {image_url} for description: {client_e}"
        )
        return None
    except asyncio.TimeoutError:
        print(f"Timeout downloading image {image_url} for description.")
        return None
    except Exception as e:
        print(
            f"Unexpected error in generate_image_description for '{item_name}': {type(e).__name__}: {e}"
        )
        import traceback

        traceback.print_exc()
        return None

# --- Internal AI Call for Specific Tasks ---
async def get_internal_ai_json_response(
    cog: "GurtCog",
    prompt_messages: List[Dict[str, Any]],  # OpenAI-style message dicts
    task_description: str,
    response_schema_dict: Dict[str, Any],  # Expect the schema as a dict
    model_name_override: Optional[str] = None,
    temperature: float = 0.7,
    max_tokens: int = 5000,
) -> Optional[
    Tuple[Optional[Dict[str, Any]], Optional[str]]
]:  # Returns a tuple: (parsed_data, raw_text)
    """
    Makes a Google GenAI API call (Vertex AI backend) expecting a specific JSON response format for internal tasks.

    Args:
        cog: The GurtCog instance.
        prompt_messages: List of message dicts (OpenAI format: {'role': 'user'/'system'/'assistant', 'content': '...'}).
        task_description: Description for logging.
        response_schema_dict: The expected JSON output schema as a dictionary.
        model_name_override: Optional model override.
        temperature: Generation temperature.
        max_tokens: Max output tokens.

    Returns:
        A tuple containing:
        - The parsed and validated JSON dictionary if successful, None otherwise.
        - The raw text response received from the API, or None if the call failed before getting text.
    """
    if not PROJECT_ID or not LOCATION:
        print(
            f"Error in get_internal_ai_json_response ({task_description}): GCP Project/Location not set."
        )
        return None, None

    final_parsed_data: Optional[Dict[str, Any]] = None
    final_response_text: Optional[str] = None
    error_occurred = None
    request_payload_for_logging = {}  # For logging

    try:
        # --- Convert prompt messages to the GenAI types.Content format ---
        contents: List[types.Content] = []
        system_instruction = None
        for msg in prompt_messages:
            role = msg.get("role", "user")
            content_text = msg.get("content", "")
            if role == "system":
                # Use the first system message as system_instruction
                if system_instruction is None:
                    system_instruction = content_text
                else:
                    # Append subsequent system messages to the instruction
                    system_instruction += "\n\n" + content_text
                continue  # Skip adding system messages to the contents list
            elif role == "assistant":
                role = "model"

            # --- Process content (string or list) ---
            content_value = msg.get("content")
            message_parts: List[types.Part] = []  # Parts for this message

            if isinstance(content_value, str):
                # Handle simple string content
                message_parts.append(types.Part(text=content_value))
            elif isinstance(content_value, list):
                # Handle list content (e.g., multimodal from ProfileUpdater)
                for part_data in content_value:
                    part_type = part_data.get("type")
                    if part_type == "text":
                        text = part_data.get("text", "")
                        message_parts.append(types.Part(text=text))
                    elif part_type == "image_data":
                        mime_type = part_data.get("mime_type")
                        base64_data = part_data.get("data")
                        if mime_type and base64_data:
                            try:
                                image_bytes = base64.b64decode(base64_data)
                                # Wrap the bytes in inline_data, consistent with the other image parts in this module
                                message_parts.append(
                                    types.Part(
                                        inline_data=types.Blob(
                                            data=image_bytes, mime_type=mime_type
                                        )
                                    )
                                )
                            except Exception as decode_err:
                                print(
                                    f"Error decoding/adding image part in get_internal_ai_json_response: {decode_err}"
                                )
                                # Add a placeholder text part indicating failure
                                message_parts.append(
                                    types.Part(
                                        text="(System Note: Failed to process an image part)"
                                    )
                                )
                        else:
                            print("Warning: image_data part missing mime_type or data.")
                    else:
                        print(
                            f"Warning: Unknown part type '{part_type}' in internal prompt message."
                        )
            else:
                print(
                    f"Warning: Unexpected content type '{type(content_value)}' in internal prompt message."
                )

            # Add the content object if parts were generated
            if message_parts:
                contents.append(types.Content(role=role, parts=message_parts))
            else:
                print(f"Warning: No parts generated for message role '{role}'.")

        # Add the critical JSON instruction to the last user message or as a new user message
        json_instruction_content = (
            f"**CRITICAL: Your response MUST consist *only* of the raw JSON object itself, matching this schema:**\n"
            f"{json.dumps(response_schema_dict, indent=2)}\n"
            f"**Ensure nothing precedes or follows the JSON.**"
        )
        if contents and contents[-1].role == "user":
            contents[-1].parts.append(
                types.Part(text=f"\n\n{json_instruction_content}")
            )
        else:
            contents.append(
                types.Content(
                    role="user", parts=[types.Part(text=json_instruction_content)]
                )
            )
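
        # As in the main response path, the schema is enforced twice: in the prompt
        # text above and via response_schema/response_mime_type in the generation
        # config below.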

        # --- Determine Model ---
        # Use the override if provided, otherwise the default (a caller could
        # pass e.g. FALLBACK_MODEL for planning tasks).
        actual_model_name = model_name_override or DEFAULT_MODEL

        # --- Prepare Generation Config ---
        processed_schema_internal = _preprocess_schema_for_vertex(response_schema_dict)
        internal_gen_config_dict = {
            "temperature": temperature,
            "max_output_tokens": max_tokens,
            "response_mime_type": "application/json",
            "response_schema": processed_schema_internal,
            "system_instruction": system_instruction,  # Collected from system messages above
            "safety_settings": STANDARD_SAFETY_SETTINGS,  # Include standard safety
            "tools": None,  # No tools for internal JSON tasks
            "tool_config": None,
        }
        generation_config = types.GenerateContentConfig(**internal_gen_config_dict)
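        # Illustrative example of a minimal response_schema_dict a caller might
        # pass (hypothetical; real schemas are defined at the call sites):
        #   {
        #       "type": "object",
        #       "properties": {"reply": {"type": "string"}},
        #       "required": ["reply"],
        #   }
        # With response_mime_type="application/json" plus response_schema set,
        # the API constrains decoding to schema-conformant JSON output.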

        # --- Prepare Payload for Logging ---
        # (No model object is created here, so log the config fields directly.)
        generation_config_log = {
            "temperature": generation_config.temperature,
            "max_output_tokens": generation_config.max_output_tokens,
            "response_mime_type": generation_config.response_mime_type,
            "response_schema": str(generation_config.response_schema),  # Log schema as string
            "system_instruction": str(generation_config.system_instruction),
        }
        request_payload_for_logging = {
            "model": actual_model_name,  # Log the model name actually used
            "contents": [
                {"role": c.role, "parts": [str(p) for p in c.parts]} for c in contents
            ],
            "generation_config": generation_config_log,
        }
        try:
            print(f"--- Raw request payload for {task_description} ---")
            print(json.dumps(request_payload_for_logging, indent=2, default=str))
            print("--- End Raw request payload ---")
        except Exception as req_log_e:
            print(f"Error logging raw request payload: {req_log_e}")

        # --- Call API via the retry helper ---
        response_obj = await call_google_genai_api_with_retry(
            cog=cog,
            model_name=actual_model_name,  # The model name determined above
            contents=contents,
            generation_config=generation_config,  # Safety, schema, and tool settings travel inside this config
            request_desc=task_description,
        )
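        # (call_google_genai_api_with_retry, defined earlier in this module,
        # wraps the request in retry logic for transient failures; the check
        # below guards against it returning no response object at all.)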

        # --- Process Response ---
        if not response_obj:
            raise Exception("Internal API call failed to return a response object.")

        # Log the raw response object
        print(f"--- Full response_obj received for {task_description} ---")
        print(response_obj)
        print("--- End Full response_obj ---")

        if not response_obj.candidates:
            print(
                f"Warning: Internal API call for {task_description} returned no candidates. Response: {response_obj}"
            )
            final_response_text = getattr(response_obj, "text", None)  # Try to salvage text anyway
            final_parsed_data = None
        else:
            # Extract the raw text, then parse and validate it against the schema
            final_response_text = _get_response_text(response_obj)
            print(f"--- Raw unparsed JSON for {task_description} ---")
            print(final_response_text)
            print("--- End raw unparsed JSON ---")

            final_parsed_data = parse_and_validate_json_response(
                final_response_text,
                response_schema_dict,
                f"internal task ({task_description})",
            )

            if final_parsed_data is None:
                print(
                    f"Warning: Internal task '{task_description}' failed JSON validation."
                )
                # Keep final_response_text so the caller still gets the raw output

    except Exception as e:
        print(
            f"Error in get_internal_ai_json_response ({task_description}): {type(e).__name__}: {e}"
        )
        error_occurred = e
        import traceback

        traceback.print_exc()
        final_parsed_data = None
        # final_response_text may be None or hold partial/error text, depending
        # on when the exception occurred
    finally:
        # Log the call, passing the simplified payload and the *parsed* data
        try:
            await log_internal_api_call(
                cog,
                task_description,
                request_payload_for_logging,
                final_parsed_data,
                error_occurred,
            )
        except Exception as log_e:
            print(f"Error logging internal API call: {log_e}")

    # Return both the parsed data (None on failure) and the raw response text
    return final_parsed_data, final_response_text


if __name__ == "__main__":
    print(_preprocess_schema_for_vertex(RESPONSE_SCHEMA["schema"]))
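    # Additional smoke test on a small hypothetical inline schema, to make the
    # preprocessing easy to eyeball (illustrative only; the bot itself always
    # passes schemas defined elsewhere in this module):
    _demo_schema = {
        "type": "object",
        "properties": {
            "reply": {"type": "string"},
            "confidence": {"type": "number"},
        },
        "required": ["reply"],
    }
    print(_preprocess_schema_for_vertex(_demo_schema))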