From e6f03abf6db726813b478cd7f756d0c8db529bb1 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Fri, 30 May 2025 13:08:23 -0600 Subject: [PATCH] feat: Add dynamic AI model switching commands and update model configuration --- gurt/api.py | 16 ++++++-------- gurt/commands.py | 56 +++++++++++++++++++++++++++++++++++++++++++++++- gurt/config.py | 19 +++++++++++----- 3 files changed, 76 insertions(+), 15 deletions(-) diff --git a/gurt/api.py b/gurt/api.py index 91ac508..e6e667f 100644 --- a/gurt/api.py +++ b/gurt/api.py @@ -791,13 +791,11 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: print(f"Error in get_ai_response: {error_msg}") return {"final_response": None, "error": error_msg} - # Determine the model for the final response generation (custom tuned) - # Use the override if provided, otherwise default to the custom tuned endpoint - final_response_model = model_name or CUSTOM_TUNED_MODEL_ENDPOINT - # Model for tool checking will be DEFAULT_MODEL - tool_check_model = DEFAULT_MODEL - print(f"Using model for final response: {final_response_model}") - print(f"Using model for tool checks: {tool_check_model}") + # Determine the model for all generation steps. + # Use the model_name override if provided to get_ai_response, otherwise use the cog's current default_model. + active_model = model_name or cog.default_model + + print(f"Using active model for all generation steps: {active_model}") channel_id = message.channel.id user_id = message.author.id @@ -1216,7 +1214,7 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: current_response_obj = await call_google_genai_api_with_retry( cog=cog, - model_name=tool_check_model, # Use DEFAULT_MODEL for tool checks + model_name=active_model, # Use the dynamically set model for tool checks contents=contents, generation_config=current_gen_config, # Pass the combined config request_desc=f"Tool Check {tool_calls_made + 1} for message {message.id}", @@ -1407,7 +1405,7 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: # Make the final call *without* tools enabled (handled by config) final_json_response_obj = await call_google_genai_api_with_retry( cog=cog, - model_name=final_response_model, # Use the CUSTOM_TUNED_MODEL for final response + model_name=active_model, # Use the active model for final JSON response contents=contents, # Pass the accumulated history generation_config=generation_config_final_json, # Use combined JSON config request_desc=f"Final JSON Generation (dedicated call) for message {message.id}", diff --git a/gurt/commands.py b/gurt/commands.py index d34ade7..2bf12e6 100644 --- a/gurt/commands.py +++ b/gurt/commands.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Optional, Dict, Any, List, Tuple # Add more ty if TYPE_CHECKING: from .cog import GurtCog # For type hinting - from .config import MOOD_OPTIONS, IGNORED_CHANNEL_IDS, update_ignored_channels_file, TENOR_API_KEY # Import for choices and ignored channels + from .config import MOOD_OPTIONS, IGNORED_CHANNEL_IDS, update_ignored_channels_file, TENOR_API_KEY, AVAILABLE_AI_MODELS # Import for choices and ignored channels from .emojis import EmojiManager # Import EmojiManager # --- Helper Function for Embeds --- @@ -781,6 +781,60 @@ def setup_commands(cog: 'GurtCog'): command_functions.append(gurtresetpersonality) + # --- Gurt Model Command (Owner Only) --- + @cog.bot.tree.command(name="gurtmodel", description="Change Gurt's active AI model dynamically. (Owner only)") + @app_commands.describe(model="The AI model to switch to.") + @app_commands.choices(model=[ + app_commands.Choice(name=friendly_name, value=model_id) + for model_id, friendly_name in AVAILABLE_AI_MODELS.items() + ]) + async def gurtmodel(interaction: discord.Interaction, model: app_commands.Choice[str]): + """Handles the /gurtmodel command.""" + if interaction.user.id != cog.bot.owner_id: + await interaction.response.send_message("⛔ Only the bot owner can change Gurt's AI model.", ephemeral=True) + return + + await interaction.response.defer(ephemeral=True) + try: + new_model_id = model.value + new_model_friendly_name = model.name + + # Update the cog's default model + cog.default_model = new_model_id + + # Optionally, update the config file if you want this change to persist across restarts + # This would require a function in config.py to update DEFAULT_MODEL in the .env or a separate config file + # For now, we'll just update the runtime attribute. + # If persistence is desired, you'd add something like: + # await cog.config_manager.set_default_model(new_model_id) # Assuming a config_manager exists + + await interaction.followup.send(f"✅ Gurt's AI model has been changed to: **{new_model_friendly_name}** (`{new_model_id}`).", ephemeral=True) + except Exception as e: + print(f"Error in /gurtmodel command: {e}") + import traceback + traceback.print_exc() + await interaction.followup.send("❌ An error occurred while changing Gurt's AI model.", ephemeral=True) + + command_functions.append(gurtmodel) + + # --- Gurt Get Model Command --- + @cog.bot.tree.command(name="gurtgetmodel", description="Display Gurt's currently active AI model.") + async def gurtgetmodel(interaction: discord.Interaction): + """Handles the /gurtgetmodel command.""" + await interaction.response.defer(ephemeral=True) + try: + current_model_id = cog.default_model + # Try to get the friendly name from AVAILABLE_AI_MODELS + friendly_name = AVAILABLE_AI_MODELS.get(current_model_id, current_model_id) # Fallback to ID if not found + + await interaction.followup.send(f"Gurt is currently using AI model: **{friendly_name}** (`{current_model_id}`).", ephemeral=True) + except Exception as e: + print(f"Error in /gurtgetmodel command: {e}") + import traceback + traceback.print_exc() + await interaction.followup.send("❌ An error occurred while fetching Gurt's current AI model.", ephemeral=True) + + command_functions.append(gurtgetmodel) # Get command names safely - Command objects don't have __name__ attribute command_names = [] diff --git a/gurt/config.py b/gurt/config.py index d8231ba..50ab89d 100644 --- a/gurt/config.py +++ b/gurt/config.py @@ -22,11 +22,20 @@ TAVILY_DEFAULT_MAX_RESULTS = int(os.getenv("TAVILY_DEFAULT_MAX_RESULTS", 5)) TAVILY_DISABLE_ADVANCED = os.getenv("TAVILY_DISABLE_ADVANCED", "false").lower() == "true" # For cost control # --- Model Configuration --- -DEFAULT_MODEL = os.getenv("GURT_DEFAULT_MODEL", "gemini-2.5-flash-preview-05-20") -FALLBACK_MODEL = os.getenv("GURT_FALLBACK_MODEL", "gemini-2.5-flash-preview-05-20") -CUSTOM_TUNED_MODEL_ENDPOINT = os.getenv("GURT_CUSTOM_TUNED_MODEL", "gemini-2.5-flash-preview-05-20") -SAFETY_CHECK_MODEL = os.getenv("GURT_SAFETY_CHECK_MODEL", "gemini-2.5-flash-preview-05-20") # Use a Vertex AI model for safety checks -EMOJI_STICKER_DESCRIPTION_MODEL = "gemini-2.0-flash-001" # Hardcoded for emoji/sticker image descriptions +DEFAULT_MODEL = os.getenv("GURT_DEFAULT_MODEL", "google/gemini-2.5-flash-preview-05-20") +FALLBACK_MODEL = os.getenv("GURT_FALLBACK_MODEL", "google/gemini-2.5-flash-preview-05-20") +CUSTOM_TUNED_MODEL_ENDPOINT = os.getenv("GURT_CUSTOM_TUNED_MODEL", "google/gemini-2.5-flash-preview-05-20") +SAFETY_CHECK_MODEL = os.getenv("GURT_SAFETY_CHECK_MODEL", "google/gemini-2.5-flash-preview-05-20") # Use a Vertex AI model for safety checks +EMOJI_STICKER_DESCRIPTION_MODEL = "google/gemini-2.0-flash-001" # Hardcoded for emoji/sticker image descriptions + +# Available AI Models for dynamic switching +AVAILABLE_AI_MODELS = { + "google/gemini-2.5-flash-preview-05-20": "Gemini 2.5 Flash Preview", + "google/gemini-2.5-pro-preview-05-06": "Gemini 2.5 Pro Preview", + "claude-sonnet-4@20250514": "Claude Sonnet 4", + "llama-4-maverick-17b-128e-instruct-maas": "Llama 4 Maverick", + "google/gemini-2.0-flash-001": "Gemini 2.0 Flash" +} # --- Database Paths --- DB_PATH = os.getenv("GURT_DB_PATH", "data/gurt_memory.db")