diff --git a/cogs/ai_code_agent_cog.py b/cogs/ai_code_agent_cog.py index 30ffabd..10a5f9e 100644 --- a/cogs/ai_code_agent_cog.py +++ b/cogs/ai_code_agent_cog.py @@ -35,7 +35,7 @@ except ImportError: # Allow cog to load but genai_client will be None from tavily import TavilyClient -from gurt.api import get_genai_client_for_model +from gurt.genai_client import get_genai_client_for_model # Define standard safety settings using google.generativeai types # Set all thresholds to OFF as requested for internal tools diff --git a/cogs/aimod_cog.py b/cogs/aimod_cog.py index edec6e5..40a4bed 100644 --- a/cogs/aimod_cog.py +++ b/cogs/aimod_cog.py @@ -30,7 +30,7 @@ from gurt.config import ( PROJECT_ID, LOCATION, ) # Assuming gurt.config exists and has these -from gurt.api import get_genai_client_for_model +from gurt.genai_client import get_genai_client_for_model from . import aimod_config as aimod_config_module from .aimod_config import ( diff --git a/cogs/neru_teto_cog.py b/cogs/neru_teto_cog.py index d6ba25f..c206417 100644 --- a/cogs/neru_teto_cog.py +++ b/cogs/neru_teto_cog.py @@ -27,7 +27,7 @@ from google.genai import types from google.api_core import exceptions as google_exceptions from gurt.config import PROJECT_ID, LOCATION -from gurt.api import get_genai_client_for_model +from gurt.genai_client import get_genai_client_for_model STANDARD_SAFETY_SETTINGS = [ types.SafetySetting( diff --git a/cogs/rule34_cog.py b/cogs/rule34_cog.py index 5db3eaf..8109f0f 100644 --- a/cogs/rule34_cog.py +++ b/cogs/rule34_cog.py @@ -12,7 +12,7 @@ from google.api_core import exceptions as google_exceptions # Import project configuration for Vertex AI from gurt.config import PROJECT_ID, LOCATION -from gurt.api import get_genai_client_for_model +from gurt.genai_client import get_genai_client_for_model from .gelbooru_watcher_base_cog import GelbooruWatcherBaseCog diff --git a/cogs/teto_cog.py b/cogs/teto_cog.py index 6ba5a0d..268388a 100644 --- a/cogs/teto_cog.py +++ b/cogs/teto_cog.py @@ -27,7 +27,7 @@ from google.api_core import exceptions as google_exceptions # Import project configuration for Vertex AI from gurt.config import PROJECT_ID, LOCATION -from gurt.api import get_genai_client_for_model +from gurt.genai_client import get_genai_client_for_model # Define standard safety settings using google.generativeai types # Set all thresholds to OFF as requested diff --git a/gurt/api.py b/gurt/api.py index 3e33b07..3b46bc9 100644 --- a/gurt/api.py +++ b/gurt/api.py @@ -183,7 +183,6 @@ from .config import ( PISTON_API_KEY, BASELINE_PERSONALITY, TENOR_API_KEY, # Import other needed configs - PRO_MODEL_LOCATION, ) from .prompt import build_dynamic_system_prompt from .context import ( @@ -386,40 +385,10 @@ def _format_embeds_for_prompt(embed_content: List[Dict[str, Any]]) -> Optional[s return "\n".join(formatted_strings) if formatted_strings else None -# --- Initialize Google Generative AI Clients for Vertex AI --- -# Separate clients for different regions so models can use their preferred location -try: - genai_client_us_central1 = genai.Client( - vertexai=True, - project=PROJECT_ID, - location="us-central1", - ) - genai_client_global = genai.Client( - vertexai=True, - project=PROJECT_ID, - location=PRO_MODEL_LOCATION, - ) - # Default client remains us-central1 for backward compatibility - genai_client = genai_client_us_central1 - - print("Google GenAI Clients initialized for us-central1 and global regions.") -except NameError: - genai_client = None - genai_client_us_central1 = None - genai_client_global = None - print("Google GenAI SDK (genai) not imported, skipping client initialization.") -except Exception as e: - genai_client = None - genai_client_us_central1 = None - genai_client_global = None - print(f"Error initializing Google GenAI Clients for Vertex AI: {e}") - - -def get_genai_client_for_model(model_name: str): - """Return the appropriate genai client based on the model name.""" - if "gemini-2.5-pro-preview-06-05" in model_name: - return genai_client_global - return genai_client_us_central1 +from .genai_client import ( + genai_client, + get_genai_client_for_model, +) # --- Constants --- diff --git a/gurt/genai_client.py b/gurt/genai_client.py new file mode 100644 index 0000000..bc86a9b --- /dev/null +++ b/gurt/genai_client.py @@ -0,0 +1,41 @@ +from google import genai + +from .config import PROJECT_ID, PRO_MODEL_LOCATION + +__all__ = [ + "genai_client_us_central1", + "genai_client_global", + "genai_client", + "get_genai_client_for_model", +] + +try: + genai_client_us_central1 = genai.Client( + vertexai=True, + project=PROJECT_ID, + location="us-central1", + ) + genai_client_global = genai.Client( + vertexai=True, + project=PROJECT_ID, + location=PRO_MODEL_LOCATION, + ) + genai_client = genai_client_us_central1 + print("Google GenAI Clients initialized for us-central1 and global regions.") +except NameError: + genai_client_us_central1 = None + genai_client_global = None + genai_client = None + print("Google GenAI SDK (genai) not imported, skipping client initialization.") +except Exception as e: + genai_client_us_central1 = None + genai_client_global = None + genai_client = None + print(f"Error initializing Google GenAI Clients for Vertex AI: {e}") + + +def get_genai_client_for_model(model_name: str): + """Return the appropriate genai client based on the model name.""" + if "gemini-2.5-pro-preview-06-05" in model_name: + return genai_client_global + return genai_client_us_central1 diff --git a/gurt/tools.py b/gurt/tools.py index 535ec3f..3434548 100644 --- a/gurt/tools.py +++ b/gurt/tools.py @@ -41,7 +41,7 @@ from .config import ( TAVILY_DEFAULT_MAX_RESULTS, TAVILY_DISABLE_ADVANCED, ) -from gurt.api import get_genai_client_for_model +from gurt.genai_client import get_genai_client_for_model # Assume these helpers will be moved or are accessible via cog # We might need to pass 'cog' to these tool functions if they rely on cog state heavily