feat: Add dynamic AI model switching commands and update model configuration
This commit is contained in:
parent
e0c1b98182
commit
e6f03abf6d
16
gurt/api.py
16
gurt/api.py
@ -791,13 +791,11 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name:
|
|||||||
print(f"Error in get_ai_response: {error_msg}")
|
print(f"Error in get_ai_response: {error_msg}")
|
||||||
return {"final_response": None, "error": error_msg}
|
return {"final_response": None, "error": error_msg}
|
||||||
|
|
||||||
# Determine the model for the final response generation (custom tuned)
|
# Determine the model for all generation steps.
|
||||||
# Use the override if provided, otherwise default to the custom tuned endpoint
|
# Use the model_name override if provided to get_ai_response, otherwise use the cog's current default_model.
|
||||||
final_response_model = model_name or CUSTOM_TUNED_MODEL_ENDPOINT
|
active_model = model_name or cog.default_model
|
||||||
# Model for tool checking will be DEFAULT_MODEL
|
|
||||||
tool_check_model = DEFAULT_MODEL
|
print(f"Using active model for all generation steps: {active_model}")
|
||||||
print(f"Using model for final response: {final_response_model}")
|
|
||||||
print(f"Using model for tool checks: {tool_check_model}")
|
|
||||||
|
|
||||||
channel_id = message.channel.id
|
channel_id = message.channel.id
|
||||||
user_id = message.author.id
|
user_id = message.author.id
|
||||||
@ -1216,7 +1214,7 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name:
|
|||||||
|
|
||||||
current_response_obj = await call_google_genai_api_with_retry(
|
current_response_obj = await call_google_genai_api_with_retry(
|
||||||
cog=cog,
|
cog=cog,
|
||||||
model_name=tool_check_model, # Use DEFAULT_MODEL for tool checks
|
model_name=active_model, # Use the dynamically set model for tool checks
|
||||||
contents=contents,
|
contents=contents,
|
||||||
generation_config=current_gen_config, # Pass the combined config
|
generation_config=current_gen_config, # Pass the combined config
|
||||||
request_desc=f"Tool Check {tool_calls_made + 1} for message {message.id}",
|
request_desc=f"Tool Check {tool_calls_made + 1} for message {message.id}",
|
||||||
@ -1407,7 +1405,7 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name:
|
|||||||
# Make the final call *without* tools enabled (handled by config)
|
# Make the final call *without* tools enabled (handled by config)
|
||||||
final_json_response_obj = await call_google_genai_api_with_retry(
|
final_json_response_obj = await call_google_genai_api_with_retry(
|
||||||
cog=cog,
|
cog=cog,
|
||||||
model_name=final_response_model, # Use the CUSTOM_TUNED_MODEL for final response
|
model_name=active_model, # Use the active model for final JSON response
|
||||||
contents=contents, # Pass the accumulated history
|
contents=contents, # Pass the accumulated history
|
||||||
generation_config=generation_config_final_json, # Use combined JSON config
|
generation_config=generation_config_final_json, # Use combined JSON config
|
||||||
request_desc=f"Final JSON Generation (dedicated call) for message {message.id}",
|
request_desc=f"Final JSON Generation (dedicated call) for message {message.id}",
|
||||||
|
@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Optional, Dict, Any, List, Tuple # Add more ty
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .cog import GurtCog # For type hinting
|
from .cog import GurtCog # For type hinting
|
||||||
from .config import MOOD_OPTIONS, IGNORED_CHANNEL_IDS, update_ignored_channels_file, TENOR_API_KEY # Import for choices and ignored channels
|
from .config import MOOD_OPTIONS, IGNORED_CHANNEL_IDS, update_ignored_channels_file, TENOR_API_KEY, AVAILABLE_AI_MODELS # Import for choices and ignored channels
|
||||||
from .emojis import EmojiManager # Import EmojiManager
|
from .emojis import EmojiManager # Import EmojiManager
|
||||||
|
|
||||||
# --- Helper Function for Embeds ---
|
# --- Helper Function for Embeds ---
|
||||||
@ -781,6 +781,60 @@ def setup_commands(cog: 'GurtCog'):
|
|||||||
|
|
||||||
command_functions.append(gurtresetpersonality)
|
command_functions.append(gurtresetpersonality)
|
||||||
|
|
||||||
|
# --- Gurt Model Command (Owner Only) ---
|
||||||
|
@cog.bot.tree.command(name="gurtmodel", description="Change Gurt's active AI model dynamically. (Owner only)")
|
||||||
|
@app_commands.describe(model="The AI model to switch to.")
|
||||||
|
@app_commands.choices(model=[
|
||||||
|
app_commands.Choice(name=friendly_name, value=model_id)
|
||||||
|
for model_id, friendly_name in AVAILABLE_AI_MODELS.items()
|
||||||
|
])
|
||||||
|
async def gurtmodel(interaction: discord.Interaction, model: app_commands.Choice[str]):
|
||||||
|
"""Handles the /gurtmodel command."""
|
||||||
|
if interaction.user.id != cog.bot.owner_id:
|
||||||
|
await interaction.response.send_message("⛔ Only the bot owner can change Gurt's AI model.", ephemeral=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
await interaction.response.defer(ephemeral=True)
|
||||||
|
try:
|
||||||
|
new_model_id = model.value
|
||||||
|
new_model_friendly_name = model.name
|
||||||
|
|
||||||
|
# Update the cog's default model
|
||||||
|
cog.default_model = new_model_id
|
||||||
|
|
||||||
|
# Optionally, update the config file if you want this change to persist across restarts
|
||||||
|
# This would require a function in config.py to update DEFAULT_MODEL in the .env or a separate config file
|
||||||
|
# For now, we'll just update the runtime attribute.
|
||||||
|
# If persistence is desired, you'd add something like:
|
||||||
|
# await cog.config_manager.set_default_model(new_model_id) # Assuming a config_manager exists
|
||||||
|
|
||||||
|
await interaction.followup.send(f"✅ Gurt's AI model has been changed to: **{new_model_friendly_name}** (`{new_model_id}`).", ephemeral=True)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error in /gurtmodel command: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
await interaction.followup.send("❌ An error occurred while changing Gurt's AI model.", ephemeral=True)
|
||||||
|
|
||||||
|
command_functions.append(gurtmodel)
|
||||||
|
|
||||||
|
# --- Gurt Get Model Command ---
|
||||||
|
@cog.bot.tree.command(name="gurtgetmodel", description="Display Gurt's currently active AI model.")
|
||||||
|
async def gurtgetmodel(interaction: discord.Interaction):
|
||||||
|
"""Handles the /gurtgetmodel command."""
|
||||||
|
await interaction.response.defer(ephemeral=True)
|
||||||
|
try:
|
||||||
|
current_model_id = cog.default_model
|
||||||
|
# Try to get the friendly name from AVAILABLE_AI_MODELS
|
||||||
|
friendly_name = AVAILABLE_AI_MODELS.get(current_model_id, current_model_id) # Fallback to ID if not found
|
||||||
|
|
||||||
|
await interaction.followup.send(f"Gurt is currently using AI model: **{friendly_name}** (`{current_model_id}`).", ephemeral=True)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error in /gurtgetmodel command: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
await interaction.followup.send("❌ An error occurred while fetching Gurt's current AI model.", ephemeral=True)
|
||||||
|
|
||||||
|
command_functions.append(gurtgetmodel)
|
||||||
|
|
||||||
# Get command names safely - Command objects don't have __name__ attribute
|
# Get command names safely - Command objects don't have __name__ attribute
|
||||||
command_names = []
|
command_names = []
|
||||||
|
@ -22,11 +22,20 @@ TAVILY_DEFAULT_MAX_RESULTS = int(os.getenv("TAVILY_DEFAULT_MAX_RESULTS", 5))
|
|||||||
TAVILY_DISABLE_ADVANCED = os.getenv("TAVILY_DISABLE_ADVANCED", "false").lower() == "true" # For cost control
|
TAVILY_DISABLE_ADVANCED = os.getenv("TAVILY_DISABLE_ADVANCED", "false").lower() == "true" # For cost control
|
||||||
|
|
||||||
# --- Model Configuration ---
|
# --- Model Configuration ---
|
||||||
DEFAULT_MODEL = os.getenv("GURT_DEFAULT_MODEL", "gemini-2.5-flash-preview-05-20")
|
DEFAULT_MODEL = os.getenv("GURT_DEFAULT_MODEL", "google/gemini-2.5-flash-preview-05-20")
|
||||||
FALLBACK_MODEL = os.getenv("GURT_FALLBACK_MODEL", "gemini-2.5-flash-preview-05-20")
|
FALLBACK_MODEL = os.getenv("GURT_FALLBACK_MODEL", "google/gemini-2.5-flash-preview-05-20")
|
||||||
CUSTOM_TUNED_MODEL_ENDPOINT = os.getenv("GURT_CUSTOM_TUNED_MODEL", "gemini-2.5-flash-preview-05-20")
|
CUSTOM_TUNED_MODEL_ENDPOINT = os.getenv("GURT_CUSTOM_TUNED_MODEL", "google/gemini-2.5-flash-preview-05-20")
|
||||||
SAFETY_CHECK_MODEL = os.getenv("GURT_SAFETY_CHECK_MODEL", "gemini-2.5-flash-preview-05-20") # Use a Vertex AI model for safety checks
|
SAFETY_CHECK_MODEL = os.getenv("GURT_SAFETY_CHECK_MODEL", "google/gemini-2.5-flash-preview-05-20") # Use a Vertex AI model for safety checks
|
||||||
EMOJI_STICKER_DESCRIPTION_MODEL = "gemini-2.0-flash-001" # Hardcoded for emoji/sticker image descriptions
|
EMOJI_STICKER_DESCRIPTION_MODEL = "google/gemini-2.0-flash-001" # Hardcoded for emoji/sticker image descriptions
|
||||||
|
|
||||||
|
# Available AI Models for dynamic switching
|
||||||
|
AVAILABLE_AI_MODELS = {
|
||||||
|
"google/gemini-2.5-flash-preview-05-20": "Gemini 2.5 Flash Preview",
|
||||||
|
"google/gemini-2.5-pro-preview-05-06": "Gemini 2.5 Pro Preview",
|
||||||
|
"claude-sonnet-4@20250514": "Claude Sonnet 4",
|
||||||
|
"llama-4-maverick-17b-128e-instruct-maas": "Llama 4 Maverick",
|
||||||
|
"google/gemini-2.0-flash-001": "Gemini 2.0 Flash"
|
||||||
|
}
|
||||||
|
|
||||||
# --- Database Paths ---
|
# --- Database Paths ---
|
||||||
DB_PATH = os.getenv("GURT_DB_PATH", "data/gurt_memory.db")
|
DB_PATH = os.getenv("GURT_DB_PATH", "data/gurt_memory.db")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user