Slipstream 2025-04-29 17:29:20 -06:00
parent f2003bfef7
commit 6248577f89
Signed by: slipstream
GPG Key ID: 13E498CE010AC6FD


@@ -291,11 +291,9 @@ async def call_google_genai_api_with_retry(
cog: 'GurtCog',
model_name: str, # Pass model name string instead of model object
contents: List[types.Content], # Use the types.Content type from google.genai.types
generation_config: types.GenerateContentConfig, # Use specific type
safety_settings: Optional[List[types.SafetySetting]], # Use specific type
generation_config: types.GenerateContentConfig, # Combined config object
request_desc: str,
tools: Optional[List[types.Tool]] = None, # Pass tools list if needed
tool_config: Optional[types.ToolConfig] = None
# Removed safety_settings, tools, tool_config as separate params
) -> Optional[types.GenerateContentResponse]: # Return type for non-streaming
"""
Calls the Google Generative AI API (Vertex AI backend) with retry logic (non-streaming).
@@ -304,11 +302,8 @@ async def call_google_genai_api_with_retry(
cog: The GurtCog instance.
model_name: The name/path of the model to use (e.g., 'models/gemini-1.5-pro-preview-0409' or custom endpoint path).
contents: The list of types.Content objects for the prompt.
generation_config: The types.GenerateContentConfig object.
safety_settings: Safety settings for the request (List[types.SafetySetting]).
generation_config: The types.GenerateContentConfig object, which now carries all per-request settings as needed (temperature, top_p, max_output_tokens, safety_settings, tools, tool_config, response_mime_type, response_schema, etc.).
request_desc: A description of the request for logging purposes.
tools: Optional list of Tool objects for function calling.
tool_config: Optional ToolConfig object.
Returns:
The GenerateContentResponse object if successful, or None on failure after retries.
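
For reference, a minimal sketch of the new calling convention after this refactor, assuming the surrounding module context (cog, contents, tools_list, and STANDARD_SAFETY_SETTINGS as defined in this file); the model path and request_desc are illustrative placeholders:

from google.genai import types

# Everything that used to be separate arguments now rides in one config object.
combined_config = types.GenerateContentConfig(
    temperature=0.85,
    top_p=0.95,
    max_output_tokens=8192,
    safety_settings=STANDARD_SAFETY_SETTINGS,  # formerly a separate parameter
    tools=tools_list,                          # formerly a separate parameter
    tool_config=types.ToolConfig(
        function_calling_config=types.FunctionCallingConfig(
            mode=types.FunctionCallingConfigMode.ANY
        )
    ),
)

# Inside an async function:
response = await call_google_genai_api_with_retry(
    cog=cog,
    model_name="models/gemini-1.5-pro-preview-0409",  # example model path
    contents=contents,
    generation_config=combined_config,
    request_desc="example request",
)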
@@ -337,14 +332,11 @@ async def call_google_genai_api_with_retry(
try:
print(f"Sending API request for {request_desc} using {model_name} (Attempt {attempt + 1}/{API_RETRY_ATTEMPTS + 1})...")
# Use the non-streaming async call
# Use the non-streaming async call - config now contains all settings
response = await genai_client.models.generate_content(
contents=contents,
config=generation_config,
safety_settings=safety_settings or STANDARD_SAFETY_SETTINGS,
tools=tools,
tool_config=tool_config,
# stream=False is implicit for generate_content_async
config=generation_config, # Pass the combined config object
# stream=False is implicit for generate_content
)
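
For context, this is the pattern the helper wraps in the google-genai SDK. A minimal, self-contained sketch under stated assumptions: project and location are placeholders, and in the standalone SDK the async surface lives under client.aio:

import asyncio
from google import genai
from google.genai import types

async def main() -> None:
    # Placeholders: substitute a real project and location.
    client = genai.Client(vertexai=True, project="my-project", location="us-central1")
    config = types.GenerateContentConfig(
        temperature=0.85,
        max_output_tokens=8192,
        safety_settings=[
            types.SafetySetting(
                category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                threshold=types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            )
        ],
    )
    # The async call; safety settings, tools, schema, etc. all travel in config.
    response = await client.aio.models.generate_content(
        model="gemini-1.5-pro-preview-0409",
        contents="Hello",
        config=config,
    )
    print(response.text)

asyncio.run(main())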
# --- Check Finish Reason (Safety) ---
@@ -800,24 +792,15 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name:
tools_list = [vertex_tool] if vertex_tool else None
# --- Prepare Generation Config ---
# Use settings from user example and config.py
# Note: response_modalities and speech_config from user example are not standard genai config
generation_config = types.GenerateContentConfig(
temperature=0.85, # From user example
top_p=0.95, # From user example
max_output_tokens=8192, # From user example
# response_mime_type="application/json", # Set this later for the final JSON call
# response_schema=... # Set this later for the final JSON call
# stop_sequences=... # Add if needed
# Base generation config settings (will be augmented later)
base_generation_config_dict = {
"temperature": 0.85, # From user example
"top_p": 0.95, # From user example
"max_output_tokens": 8192, # From user example
"safety_settings": STANDARD_SAFETY_SETTINGS, # Include standard safety settings
# candidate_count=1 # Default is 1
)
# Tool config for the loop (allow any tool call)
tool_config_any = types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(
mode=types.FunctionCallingConfigMode.ANY # Use ANY to allow model to call tools
)
) if vertex_tool else None
# stop_sequences=... # Add if needed
}
# --- Tool Execution Loop ---
while tool_calls_made < max_tool_calls:
@@ -832,23 +815,27 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name:
print(f"Error logging raw request/response: {log_e}")
# --- Call API using the new helper ---
# Use a temporary config for tool checking (no JSON enforcement yet)
current_gen_config = types.GenerateContentConfig(
temperature=generation_config.temperature, # Use base temp or specific tool temp
top_p=generation_config.top_p,
max_output_tokens=generation_config.max_output_tokens,
# Omit response_mime_type and response_schema here
)
# Build the config for this specific call (tool check)
current_gen_config_dict = base_generation_config_dict.copy()
if tools_list:
current_gen_config_dict["tools"] = tools_list
# Define tool_config here if needed, e.g., for ANY mode
current_gen_config_dict["tool_config"] = types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(
mode=types.FunctionCallingConfigMode.ANY
)
)
# Omit response_mime_type and response_schema for tool checking
current_gen_config = types.GenerateContentConfig(**current_gen_config_dict)
current_response_obj = await call_google_genai_api_with_retry(
cog=cog,
model_name=target_model_name, # Pass the model name string
contents=contents,
generation_config=current_gen_config, # Use temp config
safety_settings=STANDARD_SAFETY_SETTINGS, # Use the new list format
generation_config=current_gen_config, # Pass the combined config
request_desc=f"Tool Check {tool_calls_made + 1} for message {message.id}",
tools=tools_list, # Pass the Tool object list
tool_config=tool_config_any
# No separate safety, tools, tool_config args needed
)
last_response_obj = current_response_obj # Store the latest response
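
Downstream of this call, the loop presumably inspects the response for requested function calls; a sketch of that pattern against the google-genai response object (the dispatch step is illustrative, not the file's actual code):

# Illustrative: check the tool-check response for function calls.
candidate = current_response_obj.candidates[0]
for part in candidate.content.parts:
    if part.function_call:
        name = part.function_call.name        # which tool the model wants
        args = dict(part.function_call.args)  # its arguments as a plain dict
        # ...dispatch to the matching tool implementation here...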
@@ -946,25 +933,26 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name:
print(f"Making dedicated final API call for JSON ({log_reason})...")
# Prepare the final generation config with JSON enforcement
processed_response_schema = _preprocess_schema_for_vertex(RESPONSE_SCHEMA['schema']) # Keep using this helper for now
generation_config_final_json = types.GenerateContentConfig(
temperature=generation_config.temperature, # Use original temp
top_p=generation_config.top_p,
max_output_tokens=generation_config.max_output_tokens,
response_mime_type="application/json",
response_schema=processed_response_schema # Pass the schema here
)
processed_response_schema = _preprocess_schema_for_vertex(RESPONSE_SCHEMA['schema'])
final_gen_config_dict = base_generation_config_dict.copy() # Start with base
final_gen_config_dict.update({
"response_mime_type": "application/json",
"response_schema": processed_response_schema,
# Explicitly exclude tools/tool_config for final JSON generation
"tools": None,
"tool_config": None
})
generation_config_final_json = types.GenerateContentConfig(**final_gen_config_dict)
# Make the final call *without* tools enabled
# Make the final call *without* tools enabled (handled by config)
final_json_response_obj = await call_google_genai_api_with_retry(
cog=cog,
model_name=target_model_name, # Use the target model
contents=contents, # Pass the accumulated history
generation_config=generation_config_final_json, # Use JSON config
safety_settings=STANDARD_SAFETY_SETTINGS,
generation_config=generation_config_final_json, # Use combined JSON config
request_desc=f"Final JSON Generation (dedicated call) for message {message.id}",
tools=None, # Explicitly disable tools for final JSON generation
tool_config=None
# No separate safety, tools, tool_config args needed
)
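
Once this dedicated call returns, the structured output presumably gets parsed from the response text; a hedged sketch of that step:

import json

# Illustrative: with response_mime_type="application/json" and a response_schema,
# the model's output should arrive as parseable JSON in response.text.
if final_json_response_obj and final_json_response_obj.text:
    try:
        parsed = json.loads(final_json_response_obj.text)
    except json.JSONDecodeError as e:
        print(f"Final JSON response was not valid JSON: {e}")
        parsed = None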
if not final_json_response_obj:
@@ -1137,25 +1125,28 @@ async def get_proactive_ai_response(cog: 'GurtCog', message: discord.Message, tr
# --- Call Final LLM API ---
# Preprocess the schema
# Preprocess the schema and build the final config
processed_response_schema_proactive = _preprocess_schema_for_vertex(RESPONSE_SCHEMA['schema'])
generation_config_final = types.GenerateContentConfig(
temperature=0.6, # Use original proactive temp
max_output_tokens=2000,
response_mime_type="application/json",
response_schema=processed_response_schema_proactive # Use preprocessed schema
)
final_proactive_config_dict = {
"temperature": 0.6, # Use original proactive temp
"max_output_tokens": 2000,
"response_mime_type": "application/json",
"response_schema": processed_response_schema_proactive,
"safety_settings": STANDARD_SAFETY_SETTINGS,
"tools": None, # No tools needed for this final generation
"tool_config": None
}
generation_config_final = types.GenerateContentConfig(**final_proactive_config_dict)
# Use the new API call helper
response_obj = await call_google_genai_api_with_retry(
cog=cog,
model_name=DEFAULT_MODEL, # Use the default model for proactive responses
contents=proactive_contents, # Pass the constructed contents
generation_config=generation_config_final,
safety_settings=STANDARD_SAFETY_SETTINGS,
generation_config=generation_config_final, # Pass combined config
request_desc=f"Final proactive response for channel {channel_id} ({trigger_reason})",
tools=None, # No tools needed for this final generation
tool_config=None
# No separate safety, tools, tool_config args needed
)
if not response_obj:
@@ -1320,12 +1311,16 @@ async def get_internal_ai_json_response(
# --- Prepare Generation Config ---
processed_schema_internal = _preprocess_schema_for_vertex(response_schema_dict)
generation_config = types.GenerateContentConfig(
temperature=temperature,
max_output_tokens=max_tokens,
response_mime_type="application/json",
response_schema=processed_schema_internal # Use preprocessed schema
)
internal_gen_config_dict = {
"temperature": temperature,
"max_output_tokens": max_tokens,
"response_mime_type": "application/json",
"response_schema": processed_schema_internal,
"safety_settings": STANDARD_SAFETY_SETTINGS, # Include standard safety
"tools": None, # No tools for internal JSON tasks
"tool_config": None
}
generation_config = types.GenerateContentConfig(**internal_gen_config_dict)
# --- Prepare Payload for Logging ---
# (Logging needs adjustment as model object isn't created here)
@@ -1355,11 +1350,9 @@ async def get_internal_ai_json_response(
cog=cog,
model_name=actual_model_name, # Pass the determined model name
contents=contents,
generation_config=generation_config,
safety_settings=STANDARD_SAFETY_SETTINGS, # Use standard safety
generation_config=generation_config, # Pass combined config
request_desc=task_description,
tools=None, # No tools for internal JSON tasks
tool_config=None
# No separate safety, tools, tool_config args needed
)
# --- Process Response ---