From fd13c35afdf9de7e8f8cd273563d85662934a484 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Thu, 29 May 2025 11:41:36 -0600 Subject: [PATCH] feat: Use specific model for emoji/sticker descriptions Introduces `EMOJI_STICKER_DESCRIPTION_MODEL` to `config.py` and uses it in `generate_image_description` for "emoji" and "sticker" item types. This ensures a dedicated model is used for these specific image description tasks, improving accuracy or efficiency. --- gurt/api.py | 7 +++---- gurt/config.py | 1 + 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/gurt/api.py b/gurt/api.py index aef49b1..952e236 100644 --- a/gurt/api.py +++ b/gurt/api.py @@ -151,7 +151,7 @@ from google.api_core import exceptions as google_exceptions # Keep for retry log # Relative imports for components within the 'gurt' package from .config import ( - PROJECT_ID, LOCATION, DEFAULT_MODEL, FALLBACK_MODEL, CUSTOM_TUNED_MODEL_ENDPOINT, # Import the new endpoint + PROJECT_ID, LOCATION, DEFAULT_MODEL, FALLBACK_MODEL, CUSTOM_TUNED_MODEL_ENDPOINT, EMOJI_STICKER_DESCRIPTION_MODEL, # Import the new endpoint and model API_TIMEOUT, API_RETRY_ATTEMPTS, API_RETRY_DELAY, TOOLS, RESPONSE_SCHEMA, PROACTIVE_PLAN_SCHEMA, # Import the new schema TAVILY_API_KEY, PISTON_API_URL, PISTON_API_KEY, BASELINE_PERSONALITY, TENOR_API_KEY # Import other needed configs @@ -1923,9 +1923,8 @@ async def generate_image_description( # 4. Call AI # Use a multimodal model, e.g., DEFAULT_MODEL if it's Gemini 1.5 Pro or similar - # If DEFAULT_MODEL is tuned for JSON, we might need to specify a base multimodal model here. - # For now, assume DEFAULT_MODEL can handle this. - model_to_use = DEFAULT_MODEL # Or specify a known multimodal model like "models/gemini-1.5-pro-preview-0409" + # Determine which model to use based on item_type + model_to_use = EMOJI_STICKER_DESCRIPTION_MODEL if item_type in ["emoji", "sticker"] else DEFAULT_MODEL print(f"Calling AI for image description ({item_name}) using model: {model_to_use}") ai_response_obj = await call_google_genai_api_with_retry( diff --git a/gurt/config.py b/gurt/config.py index 49fe9ce..c0e5cf4 100644 --- a/gurt/config.py +++ b/gurt/config.py @@ -26,6 +26,7 @@ DEFAULT_MODEL = os.getenv("GURT_DEFAULT_MODEL", "gemini-2.5-flash-preview-05-20" FALLBACK_MODEL = os.getenv("GURT_FALLBACK_MODEL", "gemini-2.5-flash-preview-05-20") CUSTOM_TUNED_MODEL_ENDPOINT = os.getenv("GURT_CUSTOM_TUNED_MODEL", "gemini-2.5-flash-preview-05-20") SAFETY_CHECK_MODEL = os.getenv("GURT_SAFETY_CHECK_MODEL", "gemini-2.5-flash-preview-05-20") # Use a Vertex AI model for safety checks +EMOJI_STICKER_DESCRIPTION_MODEL = "gemini-2.0-flash-001" # Hardcoded for emoji/sticker image descriptions # --- Database Paths --- DB_PATH = os.getenv("GURT_DB_PATH", "data/gurt_memory.db")