feat: Add dynamic AI model switching commands and update model configuration

This commit is contained in:
Slipstream 2025-05-30 13:08:23 -06:00
parent e0c1b98182
commit e6f03abf6d
Signed by: slipstream
GPG Key ID: 13E498CE010AC6FD
3 changed files with 76 additions and 15 deletions

View File

@@ -791,13 +791,11 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name:
print(f"Error in get_ai_response: {error_msg}")
return {"final_response": None, "error": error_msg}
# Determine the model for the final response generation (custom tuned)
# Use the override if provided, otherwise default to the custom tuned endpoint
final_response_model = model_name or CUSTOM_TUNED_MODEL_ENDPOINT
# Model for tool checking will be DEFAULT_MODEL
tool_check_model = DEFAULT_MODEL
print(f"Using model for final response: {final_response_model}")
print(f"Using model for tool checks: {tool_check_model}")
# Determine the model for all generation steps.
# Use the model_name override if provided to get_ai_response, otherwise use the cog's current default_model.
active_model = model_name or cog.default_model
print(f"Using active model for all generation steps: {active_model}")
channel_id = message.channel.id
user_id = message.author.id
@@ -1216,7 +1214,7 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name:
current_response_obj = await call_google_genai_api_with_retry(
cog=cog,
model_name=tool_check_model, # Use DEFAULT_MODEL for tool checks
model_name=active_model, # Use the dynamically set model for tool checks
contents=contents,
generation_config=current_gen_config, # Pass the combined config
request_desc=f"Tool Check {tool_calls_made + 1} for message {message.id}",
@@ -1407,7 +1405,7 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name:
# Make the final call *without* tools enabled (handled by config)
final_json_response_obj = await call_google_genai_api_with_retry(
cog=cog,
model_name=final_response_model, # Use the CUSTOM_TUNED_MODEL for final response
model_name=active_model, # Use the active model for final JSON response
contents=contents, # Pass the accumulated history
generation_config=generation_config_final_json, # Use combined JSON config
request_desc=f"Final JSON Generation (dedicated call) for message {message.id}",

View File

@@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Optional, Dict, Any, List, Tuple # Add more ty
if TYPE_CHECKING:
from .cog import GurtCog # For type hinting
from .config import MOOD_OPTIONS, IGNORED_CHANNEL_IDS, update_ignored_channels_file, TENOR_API_KEY # Import for choices and ignored channels
from .config import MOOD_OPTIONS, IGNORED_CHANNEL_IDS, update_ignored_channels_file, TENOR_API_KEY, AVAILABLE_AI_MODELS # Import for choices and ignored channels
from .emojis import EmojiManager # Import EmojiManager
# --- Helper Function for Embeds ---
@@ -781,6 +781,60 @@ def setup_commands(cog: 'GurtCog'):
command_functions.append(gurtresetpersonality)
# --- Gurt Model Command (Owner Only) ---
@cog.bot.tree.command(name="gurtmodel", description="Change Gurt's active AI model dynamically. (Owner only)")
@app_commands.describe(model="The AI model to switch to.")
@app_commands.choices(model=[
app_commands.Choice(name=friendly_name, value=model_id)
for model_id, friendly_name in AVAILABLE_AI_MODELS.items()
])
async def gurtmodel(interaction: discord.Interaction, model: app_commands.Choice[str]):
"""Handles the /gurtmodel command."""
if interaction.user.id != cog.bot.owner_id:
await interaction.response.send_message("⛔ Only the bot owner can change Gurt's AI model.", ephemeral=True)
return
await interaction.response.defer(ephemeral=True)
try:
new_model_id = model.value
new_model_friendly_name = model.name
# Update the cog's default model
cog.default_model = new_model_id
# Optionally, update the config file if you want this change to persist across restarts
# This would require a function in config.py to update DEFAULT_MODEL in the .env or a separate config file
# For now, we'll just update the runtime attribute.
# If persistence is desired, you'd add something like:
# await cog.config_manager.set_default_model(new_model_id) # Assuming a config_manager exists
await interaction.followup.send(f"✅ Gurt's AI model has been changed to: **{new_model_friendly_name}** (`{new_model_id}`).", ephemeral=True)
except Exception as e:
print(f"Error in /gurtmodel command: {e}")
import traceback
traceback.print_exc()
await interaction.followup.send("❌ An error occurred while changing Gurt's AI model.", ephemeral=True)
command_functions.append(gurtmodel)
# --- Gurt Get Model Command ---
@cog.bot.tree.command(name="gurtgetmodel", description="Display Gurt's currently active AI model.")
async def gurtgetmodel(interaction: discord.Interaction):
    """Handles the /gurtgetmodel command."""
    await interaction.response.defer(ephemeral=True)
    try:
        model_id = cog.default_model
        # Prefer the human-readable name; fall back to the raw ID when the
        # model is not listed in AVAILABLE_AI_MODELS.
        display_name = AVAILABLE_AI_MODELS.get(model_id, model_id)
        reply = f"Gurt is currently using AI model: **{display_name}** (`{model_id}`)."
        await interaction.followup.send(reply, ephemeral=True)
    except Exception as e:
        print(f"Error in /gurtgetmodel command: {e}")
        import traceback
        traceback.print_exc()
        await interaction.followup.send("❌ An error occurred while fetching Gurt's current AI model.", ephemeral=True)
command_functions.append(gurtgetmodel)
# Get command names safely - Command objects don't have __name__ attribute
command_names = []

View File

@@ -22,11 +22,20 @@ TAVILY_DEFAULT_MAX_RESULTS = int(os.getenv("TAVILY_DEFAULT_MAX_RESULTS", 5))
TAVILY_DISABLE_ADVANCED = os.getenv("TAVILY_DISABLE_ADVANCED", "false").lower() == "true" # For cost control
# --- Model Configuration ---
# NOTE(review): the five assignments immediately below appear twice — this
# first group (un-prefixed model IDs) looks like the pre-change lines retained
# from a diff rendering. As plain Python, the later "google/"-prefixed
# assignments win; confirm against the actual source file.
DEFAULT_MODEL = os.getenv("GURT_DEFAULT_MODEL", "gemini-2.5-flash-preview-05-20")
FALLBACK_MODEL = os.getenv("GURT_FALLBACK_MODEL", "gemini-2.5-flash-preview-05-20")
CUSTOM_TUNED_MODEL_ENDPOINT = os.getenv("GURT_CUSTOM_TUNED_MODEL", "gemini-2.5-flash-preview-05-20")
SAFETY_CHECK_MODEL = os.getenv("GURT_SAFETY_CHECK_MODEL", "gemini-2.5-flash-preview-05-20") # Use a Vertex AI model for safety checks
EMOJI_STICKER_DESCRIPTION_MODEL = "gemini-2.0-flash-001" # Hardcoded for emoji/sticker image descriptions
# Effective model configuration: each constant can be overridden via its
# corresponding GURT_* environment variable; the string after the env-var name
# is the default used when the variable is unset.
DEFAULT_MODEL = os.getenv("GURT_DEFAULT_MODEL", "google/gemini-2.5-flash-preview-05-20")
FALLBACK_MODEL = os.getenv("GURT_FALLBACK_MODEL", "google/gemini-2.5-flash-preview-05-20")
CUSTOM_TUNED_MODEL_ENDPOINT = os.getenv("GURT_CUSTOM_TUNED_MODEL", "google/gemini-2.5-flash-preview-05-20")
SAFETY_CHECK_MODEL = os.getenv("GURT_SAFETY_CHECK_MODEL", "google/gemini-2.5-flash-preview-05-20") # Use a Vertex AI model for safety checks
EMOJI_STICKER_DESCRIPTION_MODEL = "google/gemini-2.0-flash-001" # Hardcoded for emoji/sticker image descriptions
# Available AI Models for dynamic switching
# Maps model ID (the value passed to the API and stored in cog.default_model)
# -> human-friendly display name shown in the /gurtmodel command choices.
AVAILABLE_AI_MODELS = {
"google/gemini-2.5-flash-preview-05-20": "Gemini 2.5 Flash Preview",
"google/gemini-2.5-pro-preview-05-06": "Gemini 2.5 Pro Preview",
"claude-sonnet-4@20250514": "Claude Sonnet 4",
"llama-4-maverick-17b-128e-instruct-maas": "Llama 4 Maverick",
"google/gemini-2.0-flash-001": "Gemini 2.0 Flash"
}
# --- Database Paths ---
DB_PATH = os.getenv("GURT_DB_PATH", "data/gurt_memory.db")