Refactor genai client logic to avoid circular imports

This commit is contained in:
Codex 2025-06-07 06:06:35 +00:00 committed by Slipstream
parent f1d08908c3
commit 3b024e132b
Signed by: slipstream
GPG Key ID: 13E498CE010AC6FD
8 changed files with 51 additions and 41 deletions

View File

@ -35,7 +35,7 @@ except ImportError:
# Allow cog to load but genai_client will be None
from tavily import TavilyClient
from gurt.api import get_genai_client_for_model
from gurt.genai_client import get_genai_client_for_model
# Define standard safety settings using google.generativeai types
# Set all thresholds to OFF as requested for internal tools

View File

@ -30,7 +30,7 @@ from gurt.config import (
PROJECT_ID,
LOCATION,
) # Assuming gurt.config exists and has these
from gurt.api import get_genai_client_for_model
from gurt.genai_client import get_genai_client_for_model
from . import aimod_config as aimod_config_module
from .aimod_config import (

View File

@ -27,7 +27,7 @@ from google.genai import types
from google.api_core import exceptions as google_exceptions
from gurt.config import PROJECT_ID, LOCATION
from gurt.api import get_genai_client_for_model
from gurt.genai_client import get_genai_client_for_model
STANDARD_SAFETY_SETTINGS = [
types.SafetySetting(

View File

@ -12,7 +12,7 @@ from google.api_core import exceptions as google_exceptions
# Import project configuration for Vertex AI
from gurt.config import PROJECT_ID, LOCATION
from gurt.api import get_genai_client_for_model
from gurt.genai_client import get_genai_client_for_model
from .gelbooru_watcher_base_cog import GelbooruWatcherBaseCog

View File

@ -27,7 +27,7 @@ from google.api_core import exceptions as google_exceptions
# Import project configuration for Vertex AI
from gurt.config import PROJECT_ID, LOCATION
from gurt.api import get_genai_client_for_model
from gurt.genai_client import get_genai_client_for_model
# Define standard safety settings using google.generativeai types
# Set all thresholds to OFF as requested

View File

@ -183,7 +183,6 @@ from .config import (
PISTON_API_KEY,
BASELINE_PERSONALITY,
TENOR_API_KEY, # Import other needed configs
PRO_MODEL_LOCATION,
)
from .prompt import build_dynamic_system_prompt
from .context import (
@ -386,40 +385,10 @@ def _format_embeds_for_prompt(embed_content: List[Dict[str, Any]]) -> Optional[s
return "\n".join(formatted_strings) if formatted_strings else None
# --- Initialize Google Generative AI Clients for Vertex AI ---
# Separate clients for different regions so models can use their preferred location
try:
genai_client_us_central1 = genai.Client(
vertexai=True,
project=PROJECT_ID,
location="us-central1",
from .genai_client import (
genai_client,
get_genai_client_for_model,
)
genai_client_global = genai.Client(
vertexai=True,
project=PROJECT_ID,
location=PRO_MODEL_LOCATION,
)
# Default client remains us-central1 for backward compatibility
genai_client = genai_client_us_central1
print("Google GenAI Clients initialized for us-central1 and global regions.")
except NameError:
genai_client = None
genai_client_us_central1 = None
genai_client_global = None
print("Google GenAI SDK (genai) not imported, skipping client initialization.")
except Exception as e:
genai_client = None
genai_client_us_central1 = None
genai_client_global = None
print(f"Error initializing Google GenAI Clients for Vertex AI: {e}")
def get_genai_client_for_model(model_name: str):
"""Return the appropriate genai client based on the model name."""
if "gemini-2.5-pro-preview-06-05" in model_name:
return genai_client_global
return genai_client_us_central1
# --- Constants ---

41
gurt/genai_client.py Normal file
View File

@ -0,0 +1,41 @@
from google import genai
from .config import PROJECT_ID, PRO_MODEL_LOCATION
__all__ = [
"genai_client_us_central1",
"genai_client_global",
"genai_client",
"get_genai_client_for_model",
]
try:
genai_client_us_central1 = genai.Client(
vertexai=True,
project=PROJECT_ID,
location="us-central1",
)
genai_client_global = genai.Client(
vertexai=True,
project=PROJECT_ID,
location=PRO_MODEL_LOCATION,
)
genai_client = genai_client_us_central1
print("Google GenAI Clients initialized for us-central1 and global regions.")
except NameError:
genai_client_us_central1 = None
genai_client_global = None
genai_client = None
print("Google GenAI SDK (genai) not imported, skipping client initialization.")
except Exception as e:
genai_client_us_central1 = None
genai_client_global = None
genai_client = None
print(f"Error initializing Google GenAI Clients for Vertex AI: {e}")
def get_genai_client_for_model(model_name: str):
"""Return the appropriate genai client based on the model name."""
if "gemini-2.5-pro-preview-06-05" in model_name:
return genai_client_global
return genai_client_us_central1

View File

@ -41,7 +41,7 @@ from .config import (
TAVILY_DEFAULT_MAX_RESULTS,
TAVILY_DISABLE_ADVANCED,
)
from gurt.api import get_genai_client_for_model
from gurt.genai_client import get_genai_client_for_model
# Assume these helpers will be moved or are accessible via cog
# We might need to pass 'cog' to these tool functions if they rely on cog state heavily