fasfd
parent f2003bfef7
commit 6248577f89

gurt/api.py (145 lines changed)
@@ -291,11 +291,9 @@ async def call_google_genai_api_with_retry(
     cog: 'GurtCog',
     model_name: str, # Pass model name string instead of model object
     contents: List[types.Content], # Use types.Content type from google.generativeai.types
-    generation_config: types.GenerateContentConfig, # Use specific type
-    safety_settings: Optional[List[types.SafetySetting]], # Use specific type
+    generation_config: types.GenerateContentConfig, # Combined config object
     request_desc: str,
-    tools: Optional[List[types.Tool]] = None, # Pass tools list if needed
-    tool_config: Optional[types.ToolConfig] = None
+    # Removed safety_settings, tools, tool_config as separate params
 ) -> Optional[types.GenerateContentResponse]: # Return type for non-streaming
     """
     Calls the Google Generative AI API (Vertex AI backend) with retry logic (non-streaming).
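The signature change above folds three per-call parameters into the single config object. As a rough before/after sketch of a call site (variable names here are illustrative, not from the repo):

```python
from google.genai import types

# Before: settings scattered across keyword arguments.
# response = await call_google_genai_api_with_retry(
#     cog=cog, model_name=name, contents=contents,
#     generation_config=cfg, safety_settings=safety,
#     request_desc="...", tools=tools, tool_config=tc,
# )

# After: one GenerateContentConfig carries everything.
cfg = types.GenerateContentConfig(
    temperature=0.85,
    top_p=0.95,
    max_output_tokens=8192,
    safety_settings=[],  # e.g. STANDARD_SAFETY_SETTINGS from gurt/api.py
    tools=None,          # or a list of types.Tool
    tool_config=None,    # or a types.ToolConfig
)
# response = await call_google_genai_api_with_retry(
#     cog=cog, model_name=name, contents=contents,
#     generation_config=cfg, request_desc="...",
# )
```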
@@ -304,11 +302,8 @@ async def call_google_genai_api_with_retry(
         cog: The GurtCog instance.
         model_name: The name/path of the model to use (e.g., 'models/gemini-1.5-pro-preview-0409' or custom endpoint path).
         contents: The list of types.Content objects for the prompt.
-        generation_config: The types.GenerateContentConfig object.
-        safety_settings: Safety settings for the request (List[types.SafetySetting]).
+        generation_config: The types.GenerateContentConfig object, which should now include temperature, top_p, max_output_tokens, safety_settings, tools, tool_config, response_mime_type, response_schema etc. as needed.
         request_desc: A description of the request for logging purposes.
-        tools: Optional list of Tool objects for function calling.
-        tool_config: Optional ToolConfig object.
 
     Returns:
         The GenerateContentResponse object if successful, or None on failure after retries.
@@ -337,14 +332,11 @@ async def call_google_genai_api_with_retry(
         try:
             print(f"Sending API request for {request_desc} using {model_name} (Attempt {attempt + 1}/{API_RETRY_ATTEMPTS + 1})...")
 
-            # Use the non-streaming async call
+            # Use the non-streaming async call - config now contains all settings
             response = await genai_client.models.generate_content(
                 contents=contents,
-                config=generation_config,
-                safety_settings=safety_settings or STANDARD_SAFETY_SETTINGS,
-                tools=tools,
-                tool_config=tool_config,
-                # stream=False is implicit for generate_content_async
+                config=generation_config, # Pass the combined config object
+                # stream=False is implicit for generate_content
             )
 
             # --- Check Finish Reason (Safety) ---
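For reference, a minimal sketch of the consolidated call shape in the google-genai SDK; the client construction, project, and model name below are assumptions for illustration (the hunk above only shows the `contents` and `config` arguments):

```python
from google import genai
from google.genai import types

# Client pointed at the Vertex AI backend (project/location are placeholders).
client = genai.Client(vertexai=True, project="my-project", location="us-central1")

config = types.GenerateContentConfig(
    temperature=0.85,
    max_output_tokens=8192,
    # safety_settings, tools, tool_config, response_mime_type, response_schema
    # all live inside this one object now.
)

response = client.models.generate_content(
    model="gemini-2.0-flash",  # placeholder model name
    contents="Say hi",
    config=config,             # the single combined config argument
)
print(response.text)
```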
@@ -800,24 +792,15 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name:
     tools_list = [vertex_tool] if vertex_tool else None
 
     # --- Prepare Generation Config ---
-    # Use settings from user example and config.py
-    # Note: response_modalities and speech_config from user example are not standard genai config
-    generation_config = types.GenerateContentConfig(
-        temperature=0.85, # From user example
-        top_p=0.95, # From user example
-        max_output_tokens=8192, # From user example
-        # response_mime_type="application/json", # Set this later for the final JSON call
-        # response_schema=... # Set this later for the final JSON call
-        # stop_sequences=... # Add if needed
+    # Base generation config settings (will be augmented later)
+    base_generation_config_dict = {
+        "temperature": 0.85, # From user example
+        "top_p": 0.95, # From user example
+        "max_output_tokens": 8192, # From user example
+        "safety_settings": STANDARD_SAFETY_SETTINGS, # Include standard safety settings
         # candidate_count=1 # Default is 1
-    )
-
-    # Tool config for the loop (allow any tool call)
-    tool_config_any = types.ToolConfig(
-        function_calling_config=types.FunctionCallingConfig(
-            mode=types.FunctionCallingConfigMode.ANY # Use ANY to allow model to call tools
-        )
-    ) if vertex_tool else None
+        # stop_sequences=... # Add if needed
+    }
 
     # --- Tool Execution Loop ---
     while tool_calls_made < max_tool_calls:
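The dict above is the template that the later hunks copy and specialize per call. A sketch of the idiom, with illustrative values standing in for the repo's own names:

```python
from google.genai import types

base = {
    "temperature": 0.85,
    "top_p": 0.95,
    "max_output_tokens": 8192,
    "safety_settings": [],  # stand-in for STANDARD_SAFETY_SETTINGS
}

# Variant 1: tool-checking call - add tools, leave output free-form.
tool_check = base.copy()
tool_check["tools"] = None  # would be the tools_list built above

# Variant 2: final JSON call - drop tools, enforce structured output.
final_json = base.copy()
final_json.update({
    "response_mime_type": "application/json",
    "response_schema": {"type": "OBJECT", "properties": {}},  # illustrative schema
    "tools": None,
    "tool_config": None,
})

config = types.GenerateContentConfig(**final_json)
```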
@@ -832,23 +815,27 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name:
                 print(f"Error logging raw request/response: {log_e}")
 
             # --- Call API using the new helper ---
-            # Use a temporary config for tool checking (no JSON enforcement yet)
-            current_gen_config = types.GenerateContentConfig(
-                temperature=generation_config.temperature, # Use base temp or specific tool temp
-                top_p=generation_config.top_p,
-                max_output_tokens=generation_config.max_output_tokens,
-                # Omit response_mime_type and response_schema here
-            )
+            # Build the config for this specific call (tool check)
+            current_gen_config_dict = base_generation_config_dict.copy()
+            if tools_list:
+                current_gen_config_dict["tools"] = tools_list
+                # Define tool_config here if needed, e.g., for ANY mode
+                current_gen_config_dict["tool_config"] = types.ToolConfig(
+                    function_calling_config=types.FunctionCallingConfig(
+                        mode=types.FunctionCallingConfigMode.ANY
+                    )
+                )
+            # Omit response_mime_type and response_schema for tool checking
+
+            current_gen_config = types.GenerateContentConfig(**current_gen_config_dict)
 
             current_response_obj = await call_google_genai_api_with_retry(
                 cog=cog,
                 model_name=target_model_name, # Pass the model name string
                 contents=contents,
-                generation_config=current_gen_config, # Use temp config
-                safety_settings=STANDARD_SAFETY_SETTINGS, # Use the new list format
+                generation_config=current_gen_config, # Pass the combined config
                 request_desc=f"Tool Check {tool_calls_made + 1} for message {message.id}",
-                tools=tools_list, # Pass the Tool object list
-                tool_config=tool_config_any
+                # No separate safety, tools, tool_config args needed
             )
             last_response_obj = current_response_obj # Store the latest response
 
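A note on the ANY mode used above: in the google-genai SDK it does not merely allow tool calls, it forces the model to return a function call. A small sketch of the three modes:

```python
from google.genai import types

# AUTO - the model decides whether to call a tool (the default behavior)
# ANY  - the model must respond with a function call
# NONE - tool calling is disabled even when tools are supplied
tool_config = types.ToolConfig(
    function_calling_config=types.FunctionCallingConfig(
        mode=types.FunctionCallingConfigMode.ANY,
        # allowed_function_names=["some_tool"],  # optional allow-list
    )
)
```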
@@ -946,25 +933,26 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name:
             print(f"Making dedicated final API call for JSON ({log_reason})...")
 
             # Prepare the final generation config with JSON enforcement
-            processed_response_schema = _preprocess_schema_for_vertex(RESPONSE_SCHEMA['schema']) # Keep using this helper for now
-            generation_config_final_json = types.GenerateContentConfig(
-                temperature=generation_config.temperature, # Use original temp
-                top_p=generation_config.top_p,
-                max_output_tokens=generation_config.max_output_tokens,
-                response_mime_type="application/json",
-                response_schema=processed_response_schema # Pass the schema here
-            )
+            processed_response_schema = _preprocess_schema_for_vertex(RESPONSE_SCHEMA['schema'])
+            final_gen_config_dict = base_generation_config_dict.copy() # Start with base
+            final_gen_config_dict.update({
+                "response_mime_type": "application/json",
+                "response_schema": processed_response_schema,
+                # Explicitly exclude tools/tool_config for final JSON generation
+                "tools": None,
+                "tool_config": None
+            })
+            generation_config_final_json = types.GenerateContentConfig(**final_gen_config_dict)
 
-            # Make the final call *without* tools enabled
+
+            # Make the final call *without* tools enabled (handled by config)
             final_json_response_obj = await call_google_genai_api_with_retry(
                 cog=cog,
                 model_name=target_model_name, # Use the target model
                 contents=contents, # Pass the accumulated history
-                generation_config=generation_config_final_json, # Use JSON config
-                safety_settings=STANDARD_SAFETY_SETTINGS,
+                generation_config=generation_config_final_json, # Use combined JSON config
                 request_desc=f"Final JSON Generation (dedicated call) for message {message.id}",
-                tools=None, # Explicitly disable tools for final JSON generation
-                tool_config=None
+                # No separate safety, tools, tool_config args needed
             )
 
             if not final_json_response_obj:
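The pattern above — set response_mime_type and response_schema, null out tools — is how the final pass forces valid JSON. A self-contained sketch with a made-up schema (RESPONSE_SCHEMA and _preprocess_schema_for_vertex are this repo's own helpers and are not reproduced here):

```python
from google.genai import types

schema = {  # illustrative schema, not the repo's RESPONSE_SCHEMA
    "type": "OBJECT",
    "properties": {
        "should_respond": {"type": "BOOLEAN"},
        "content": {"type": "STRING"},
    },
    "required": ["should_respond", "content"],
}

json_config = types.GenerateContentConfig(
    temperature=0.85,
    response_mime_type="application/json",
    response_schema=schema,
    tools=None,        # tools and schema-enforced JSON are kept mutually exclusive here
    tool_config=None,
)
```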
@@ -1137,25 +1125,28 @@ async def get_proactive_ai_response(cog: 'GurtCog', message: discord.Message, tr
 
 
     # --- Call Final LLM API ---
-    # Preprocess the schema
+    # Preprocess the schema and build the final config
     processed_response_schema_proactive = _preprocess_schema_for_vertex(RESPONSE_SCHEMA['schema'])
-    generation_config_final = types.GenerateContentConfig(
-        temperature=0.6, # Use original proactive temp
-        max_output_tokens=2000,
-        response_mime_type="application/json",
-        response_schema=processed_response_schema_proactive # Use preprocessed schema
-    )
+    final_proactive_config_dict = {
+        "temperature": 0.6, # Use original proactive temp
+        "max_output_tokens": 2000,
+        "response_mime_type": "application/json",
+        "response_schema": processed_response_schema_proactive,
+        "safety_settings": STANDARD_SAFETY_SETTINGS,
+        "tools": None, # No tools needed for this final generation
+        "tool_config": None
+    }
+    generation_config_final = types.GenerateContentConfig(**final_proactive_config_dict)
+
 
     # Use the new API call helper
     response_obj = await call_google_genai_api_with_retry(
         cog=cog,
         model_name=DEFAULT_MODEL, # Use the default model for proactive responses
         contents=proactive_contents, # Pass the constructed contents
-        generation_config=generation_config_final,
-        safety_settings=STANDARD_SAFETY_SETTINGS,
+        generation_config=generation_config_final, # Pass combined config
         request_desc=f"Final proactive response for channel {channel_id} ({trigger_reason})",
-        tools=None, # No tools needed for this final generation
-        tool_config=None
+        # No separate safety, tools, tool_config args needed
     )
 
     if not response_obj:
@@ -1320,12 +1311,16 @@ async def get_internal_ai_json_response(
 
     # --- Prepare Generation Config ---
     processed_schema_internal = _preprocess_schema_for_vertex(response_schema_dict)
-    generation_config = types.GenerateContentConfig(
-        temperature=temperature,
-        max_output_tokens=max_tokens,
-        response_mime_type="application/json",
-        response_schema=processed_schema_internal # Use preprocessed schema
-    )
+    internal_gen_config_dict = {
+        "temperature": temperature,
+        "max_output_tokens": max_tokens,
+        "response_mime_type": "application/json",
+        "response_schema": processed_schema_internal,
+        "safety_settings": STANDARD_SAFETY_SETTINGS, # Include standard safety
+        "tools": None, # No tools for internal JSON tasks
+        "tool_config": None
+    }
+    generation_config = types.GenerateContentConfig(**internal_gen_config_dict)
 
     # --- Prepare Payload for Logging ---
     # (Logging needs adjustment as model object isn't created here)
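This is now the third near-identical JSON config block in the file (final response, proactive response, internal tasks). A hypothetical helper — not part of this commit — could collapse them:

```python
from google.genai import types

def build_json_config(temperature: float, max_tokens: int, schema: dict,
                      safety_settings: list) -> types.GenerateContentConfig:
    """Hypothetical: shared builder for the JSON-enforcing config variants."""
    return types.GenerateContentConfig(
        temperature=temperature,
        max_output_tokens=max_tokens,
        response_mime_type="application/json",
        response_schema=schema,
        safety_settings=safety_settings,
        tools=None,
        tool_config=None,
    )
```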
@@ -1355,11 +1350,9 @@ async def get_internal_ai_json_response(
         cog=cog,
         model_name=actual_model_name, # Pass the determined model name
         contents=contents,
-        generation_config=generation_config,
-        safety_settings=STANDARD_SAFETY_SETTINGS, # Use standard safety
+        generation_config=generation_config, # Pass combined config
         request_desc=task_description,
-        tools=None, # No tools for internal JSON tasks
-        tool_config=None
+        # No separate safety, tools, tool_config args needed
     )
 
     # --- Process Response ---