From 998d0a3b15444ed67eb227b59da24f8eca45bece Mon Sep 17 00:00:00 2001 From: Codex Date: Sat, 7 Jun 2025 05:55:13 +0000 Subject: [PATCH] Use global region for gemini 2.5 pro --- cogs/ai_code_agent_cog.py | 4 +++- cogs/aimod_cog.py | 7 ++++-- cogs/neru_teto_cog.py | 4 +++- cogs/rule34_cog.py | 4 +++- cogs/teto_cog.py | 4 +++- gurt/api.py | 48 +++++++++++++++++++++++++++------------ gurt/config.py | 2 ++ gurt/tools.py | 10 +++----- 8 files changed, 56 insertions(+), 27 deletions(-) diff --git a/cogs/ai_code_agent_cog.py b/cogs/ai_code_agent_cog.py index a0d58ab..30ffabd 100644 --- a/cogs/ai_code_agent_cog.py +++ b/cogs/ai_code_agent_cog.py @@ -35,6 +35,7 @@ except ImportError: # Allow cog to load but genai_client will be None from tavily import TavilyClient +from gurt.api import get_genai_client_for_model # Define standard safety settings using google.generativeai types # Set all thresholds to OFF as requested for internal tools @@ -2432,7 +2433,8 @@ class AICodeAgentCog(commands.Cog): # for i, item in enumerate(vertex_contents): # print(f" History {i} Role: {item.role}, Parts: {item.parts}") - response = await self.genai_client.aio.models.generate_content( + client = get_genai_client_for_model(self._ai_model) + response = await client.aio.models.generate_content( model=f"publishers/google/models/{self._ai_model}", contents=vertex_contents, config=generation_config, # Corrected parameter name diff --git a/cogs/aimod_cog.py b/cogs/aimod_cog.py index b14e4f1..edec6e5 100644 --- a/cogs/aimod_cog.py +++ b/cogs/aimod_cog.py @@ -30,6 +30,7 @@ from gurt.config import ( PROJECT_ID, LOCATION, ) # Assuming gurt.config exists and has these +from gurt.api import get_genai_client_for_model from . import aimod_config as aimod_config_module from .aimod_config import ( @@ -1727,7 +1728,8 @@ CRITICAL: Do NOT output anything other than the required JSON response. # response_mime_type="application/json", # Consider if model supports this for forcing JSON ) - response = await self.genai_client.aio.models.generate_content( + client = get_genai_client_for_model(model_id_to_use) + response = await client.aio.models.generate_content( model=model_path, # Correctly formatted model path contents=request_contents, # User's message with context and images config=final_generation_config, # Pass the config with system_instruction @@ -2412,7 +2414,8 @@ CRITICAL: Do NOT output anything other than the required JSON response. ) try: - response = await self.genai_client.aio.models.generate_content( + client = get_genai_client_for_model(APPEAL_AI_MODEL) + response = await client.aio.models.generate_content( model=f"publishers/google/models/{APPEAL_AI_MODEL}", contents=[ types.Content(role="user", parts=[types.Part(text=user_prompt)]) diff --git a/cogs/neru_teto_cog.py b/cogs/neru_teto_cog.py index d994fc1..d6ba25f 100644 --- a/cogs/neru_teto_cog.py +++ b/cogs/neru_teto_cog.py @@ -27,6 +27,7 @@ from google.genai import types from google.api_core import exceptions as google_exceptions from gurt.config import PROJECT_ID, LOCATION +from gurt.api import get_genai_client_for_model STANDARD_SAFETY_SETTINGS = [ types.SafetySetting( @@ -128,7 +129,8 @@ class DmbotTetoCog(commands.Cog): ) try: - response = await self.genai_client.aio.models.generate_content( + client = get_genai_client_for_model(self._ai_model) + response = await client.aio.models.generate_content( model=f"publishers/google/models/{self._ai_model}", contents=contents, config=generation_config, diff --git a/cogs/rule34_cog.py b/cogs/rule34_cog.py index 78a1c31..5db3eaf 100644 --- a/cogs/rule34_cog.py +++ b/cogs/rule34_cog.py @@ -12,6 +12,7 @@ from google.api_core import exceptions as google_exceptions # Import project configuration for Vertex AI from gurt.config import PROJECT_ID, LOCATION +from gurt.api import get_genai_client_for_model from .gelbooru_watcher_base_cog import GelbooruWatcherBaseCog @@ -171,7 +172,8 @@ class Rule34Cog(GelbooruWatcherBaseCog): # Removed name="Rule34" log.debug( f"Rule34Cog: Sending to Vertex AI for tag transformation. Model: {self.tag_transformer_model}, Input: '{user_tags}'" ) - response = await self.genai_client.aio.models.generate_content( + client = get_genai_client_for_model(self.tag_transformer_model) + response = await client.aio.models.generate_content( model=f"publishers/google/models/{self.tag_transformer_model}", contents=contents_for_api, config=generation_config, diff --git a/cogs/teto_cog.py b/cogs/teto_cog.py index 46bc9c6..6ba5a0d 100644 --- a/cogs/teto_cog.py +++ b/cogs/teto_cog.py @@ -27,6 +27,7 @@ from google.api_core import exceptions as google_exceptions # Import project configuration for Vertex AI from gurt.config import PROJECT_ID, LOCATION +from gurt.api import get_genai_client_for_model # Define standard safety settings using google.generativeai types # Set all thresholds to OFF as requested @@ -527,7 +528,8 @@ class TetoCog(commands.Cog): print( f"[TETO DEBUG] Sending to Vertex AI. Model: {self._ai_model}, Tool Config: {vertex_tools is not None}" ) - response = await self.genai_client.aio.models.generate_content( + client = get_genai_client_for_model(self._ai_model) + response = await client.aio.models.generate_content( model=f"publishers/google/models/{self._ai_model}", # Use simpler model path contents=final_contents_for_api, config=generation_config_with_system, # Pass the updated config diff --git a/gurt/api.py b/gurt/api.py index cea93d7..3e33b07 100644 --- a/gurt/api.py +++ b/gurt/api.py @@ -183,6 +183,7 @@ from .config import ( PISTON_API_KEY, BASELINE_PERSONALITY, TENOR_API_KEY, # Import other needed configs + PRO_MODEL_LOCATION, ) from .prompt import build_dynamic_system_prompt from .context import ( @@ -385,24 +386,41 @@ def _format_embeds_for_prompt(embed_content: List[Dict[str, Any]]) -> Optional[s return "\n".join(formatted_strings) if formatted_strings else None -# --- Initialize Google Generative AI Client for Vertex AI --- -# No explicit genai.configure(api_key=...) needed when using Vertex AI backend +# --- Initialize Google Generative AI Clients for Vertex AI --- +# Separate clients for different regions so models can use their preferred location try: - genai_client = genai.Client( + genai_client_us_central1 = genai.Client( vertexai=True, project=PROJECT_ID, - location=LOCATION, + location="us-central1", ) + genai_client_global = genai.Client( + vertexai=True, + project=PROJECT_ID, + location=PRO_MODEL_LOCATION, + ) + # Default client remains us-central1 for backward compatibility + genai_client = genai_client_us_central1 - print( - f"Google GenAI Client initialized for Vertex AI project '{PROJECT_ID}' in location '{LOCATION}'." - ) + print("Google GenAI Clients initialized for us-central1 and global regions.") except NameError: genai_client = None + genai_client_us_central1 = None + genai_client_global = None print("Google GenAI SDK (genai) not imported, skipping client initialization.") except Exception as e: genai_client = None - print(f"Error initializing Google GenAI Client for Vertex AI: {e}") + genai_client_us_central1 = None + genai_client_global = None + print(f"Error initializing Google GenAI Clients for Vertex AI: {e}") + + +def get_genai_client_for_model(model_name: str): + """Return the appropriate genai client based on the model name.""" + if "gemini-2.5-pro-preview-06-05" in model_name: + return genai_client_global + return genai_client_us_central1 + # --- Constants --- # Define standard safety settings using google.generativeai types @@ -452,8 +470,9 @@ async def call_google_genai_api_with_retry( Raises: Exception: If the API call fails after all retry attempts or encounters a non-retryable error. """ - if not genai_client: - raise Exception("Google GenAI Client (genai_client) is not initialized.") + client = get_genai_client_for_model(model_name) + if not client: + raise Exception("Google GenAI Client is not initialized.") last_exception = None start_time = time.monotonic() @@ -478,7 +497,7 @@ async def call_google_genai_api_with_retry( # Use the non-streaming async call - config now contains all settings # The 'model' parameter here should be the actual model name string - response = await genai_client.aio.models.generate_content( + response = await client.aio.models.generate_content( model=model_name, # Use the model_name string directly contents=contents, config=generation_config, # Pass the combined config object @@ -1060,8 +1079,8 @@ async def get_ai_response( - "error": An error message string if a critical error occurred, otherwise None. - "fallback_initial": Optional minimal response if initial parsing failed critically (less likely with controlled generation). """ - if not PROJECT_ID or not LOCATION or not genai_client: # Check genai_client too - error_msg = "Google Cloud Project ID/Location not configured or GenAI Client failed to initialize." + if not PROJECT_ID: + error_msg = "Google Cloud Project ID not configured." print(f"Error in get_ai_response: {error_msg}") return {"final_response": None, "error": error_msg} @@ -2707,7 +2726,8 @@ async def generate_image_description( Returns: The AI-generated description string, or None if an error occurs. """ - if not genai_client: + client = get_genai_client_for_model(EMOJI_STICKER_DESCRIPTION_MODEL) + if not client: print( "Error in generate_image_description: Google GenAI Client not initialized." ) diff --git a/gurt/config.py b/gurt/config.py index 2083315..e09900a 100644 --- a/gurt/config.py +++ b/gurt/config.py @@ -11,6 +11,8 @@ load_dotenv() # --- API and Keys --- PROJECT_ID = os.getenv("GCP_PROJECT_ID", "1079377687568") LOCATION = os.getenv("GCP_LOCATION", "us-central1") +# Specific region for Gemini 2.5 Pro model +PRO_MODEL_LOCATION = "global" TAVILY_API_KEY = os.getenv("TAVILY_API_KEY", "") TENOR_API_KEY = os.getenv("TENOR_API_KEY", "") # Added Tenor API Key PISTON_API_URL = os.getenv("PISTON_API_URL") # For run_python_code tool diff --git a/gurt/tools.py b/gurt/tools.py index 8e62d4d..535ec3f 100644 --- a/gurt/tools.py +++ b/gurt/tools.py @@ -41,6 +41,7 @@ from .config import ( TAVILY_DEFAULT_MAX_RESULTS, TAVILY_DISABLE_ADVANCED, ) +from gurt.api import get_genai_client_for_model # Assume these helpers will be moved or are accessible via cog # We might need to pass 'cog' to these tool functions if they rely on cog state heavily @@ -4544,15 +4545,10 @@ async def send_tenor_gif( for gif_data in gif_parts: ai_content.append(gif_data["part"]) - # Initialize AI client if needed - if not hasattr(cog, "genai_client") or not cog.genai_client: - cog.genai_client = genai.Client( - vertexai=True, project=PROJECT_ID, location=LOCATION - ) - # Generate AI selection try: - response = await cog.genai_client.aio.models.generate_content( + client = get_genai_client_for_model(DEFAULT_MODEL) + response = await client.aio.models.generate_content( model=DEFAULT_MODEL, contents=[types.Content(role="user", parts=ai_content)], config=types.GenerateContentConfig(