Refactor genai client logic to avoid circular imports
This commit is contained in:
parent
f1d08908c3
commit
3b024e132b
@ -35,7 +35,7 @@ except ImportError:
|
||||
# Allow cog to load but genai_client will be None
|
||||
|
||||
from tavily import TavilyClient
|
||||
from gurt.api import get_genai_client_for_model
|
||||
from gurt.genai_client import get_genai_client_for_model
|
||||
|
||||
# Define standard safety settings using google.generativeai types
|
||||
# Set all thresholds to OFF as requested for internal tools
|
||||
|
@ -30,7 +30,7 @@ from gurt.config import (
|
||||
PROJECT_ID,
|
||||
LOCATION,
|
||||
) # Assuming gurt.config exists and has these
|
||||
from gurt.api import get_genai_client_for_model
|
||||
from gurt.genai_client import get_genai_client_for_model
|
||||
|
||||
from . import aimod_config as aimod_config_module
|
||||
from .aimod_config import (
|
||||
|
@ -27,7 +27,7 @@ from google.genai import types
|
||||
from google.api_core import exceptions as google_exceptions
|
||||
|
||||
from gurt.config import PROJECT_ID, LOCATION
|
||||
from gurt.api import get_genai_client_for_model
|
||||
from gurt.genai_client import get_genai_client_for_model
|
||||
|
||||
STANDARD_SAFETY_SETTINGS = [
|
||||
types.SafetySetting(
|
||||
|
@ -12,7 +12,7 @@ from google.api_core import exceptions as google_exceptions
|
||||
|
||||
# Import project configuration for Vertex AI
|
||||
from gurt.config import PROJECT_ID, LOCATION
|
||||
from gurt.api import get_genai_client_for_model
|
||||
from gurt.genai_client import get_genai_client_for_model
|
||||
|
||||
from .gelbooru_watcher_base_cog import GelbooruWatcherBaseCog
|
||||
|
||||
|
@ -27,7 +27,7 @@ from google.api_core import exceptions as google_exceptions
|
||||
|
||||
# Import project configuration for Vertex AI
|
||||
from gurt.config import PROJECT_ID, LOCATION
|
||||
from gurt.api import get_genai_client_for_model
|
||||
from gurt.genai_client import get_genai_client_for_model
|
||||
|
||||
# Define standard safety settings using google.generativeai types
|
||||
# Set all thresholds to OFF as requested
|
||||
|
37
gurt/api.py
37
gurt/api.py
@ -183,7 +183,6 @@ from .config import (
|
||||
PISTON_API_KEY,
|
||||
BASELINE_PERSONALITY,
|
||||
TENOR_API_KEY, # Import other needed configs
|
||||
PRO_MODEL_LOCATION,
|
||||
)
|
||||
from .prompt import build_dynamic_system_prompt
|
||||
from .context import (
|
||||
@ -386,40 +385,10 @@ def _format_embeds_for_prompt(embed_content: List[Dict[str, Any]]) -> Optional[s
|
||||
return "\n".join(formatted_strings) if formatted_strings else None
|
||||
|
||||
|
||||
# --- Initialize Google Generative AI Clients for Vertex AI ---
|
||||
# Separate clients for different regions so models can use their preferred location
|
||||
try:
|
||||
genai_client_us_central1 = genai.Client(
|
||||
vertexai=True,
|
||||
project=PROJECT_ID,
|
||||
location="us-central1",
|
||||
from .genai_client import (
|
||||
genai_client,
|
||||
get_genai_client_for_model,
|
||||
)
|
||||
genai_client_global = genai.Client(
|
||||
vertexai=True,
|
||||
project=PROJECT_ID,
|
||||
location=PRO_MODEL_LOCATION,
|
||||
)
|
||||
# Default client remains us-central1 for backward compatibility
|
||||
genai_client = genai_client_us_central1
|
||||
|
||||
print("Google GenAI Clients initialized for us-central1 and global regions.")
|
||||
except NameError:
|
||||
genai_client = None
|
||||
genai_client_us_central1 = None
|
||||
genai_client_global = None
|
||||
print("Google GenAI SDK (genai) not imported, skipping client initialization.")
|
||||
except Exception as e:
|
||||
genai_client = None
|
||||
genai_client_us_central1 = None
|
||||
genai_client_global = None
|
||||
print(f"Error initializing Google GenAI Clients for Vertex AI: {e}")
|
||||
|
||||
|
||||
def get_genai_client_for_model(model_name: str):
|
||||
"""Return the appropriate genai client based on the model name."""
|
||||
if "gemini-2.5-pro-preview-06-05" in model_name:
|
||||
return genai_client_global
|
||||
return genai_client_us_central1
|
||||
|
||||
|
||||
# --- Constants ---
|
||||
|
41
gurt/genai_client.py
Normal file
41
gurt/genai_client.py
Normal file
@ -0,0 +1,41 @@
|
||||
from google import genai
|
||||
|
||||
from .config import PROJECT_ID, PRO_MODEL_LOCATION
|
||||
|
||||
__all__ = [
|
||||
"genai_client_us_central1",
|
||||
"genai_client_global",
|
||||
"genai_client",
|
||||
"get_genai_client_for_model",
|
||||
]
|
||||
|
||||
try:
|
||||
genai_client_us_central1 = genai.Client(
|
||||
vertexai=True,
|
||||
project=PROJECT_ID,
|
||||
location="us-central1",
|
||||
)
|
||||
genai_client_global = genai.Client(
|
||||
vertexai=True,
|
||||
project=PROJECT_ID,
|
||||
location=PRO_MODEL_LOCATION,
|
||||
)
|
||||
genai_client = genai_client_us_central1
|
||||
print("Google GenAI Clients initialized for us-central1 and global regions.")
|
||||
except NameError:
|
||||
genai_client_us_central1 = None
|
||||
genai_client_global = None
|
||||
genai_client = None
|
||||
print("Google GenAI SDK (genai) not imported, skipping client initialization.")
|
||||
except Exception as e:
|
||||
genai_client_us_central1 = None
|
||||
genai_client_global = None
|
||||
genai_client = None
|
||||
print(f"Error initializing Google GenAI Clients for Vertex AI: {e}")
|
||||
|
||||
|
||||
def get_genai_client_for_model(model_name: str):
|
||||
"""Return the appropriate genai client based on the model name."""
|
||||
if "gemini-2.5-pro-preview-06-05" in model_name:
|
||||
return genai_client_global
|
||||
return genai_client_us_central1
|
@ -41,7 +41,7 @@ from .config import (
|
||||
TAVILY_DEFAULT_MAX_RESULTS,
|
||||
TAVILY_DISABLE_ADVANCED,
|
||||
)
|
||||
from gurt.api import get_genai_client_for_model
|
||||
from gurt.genai_client import get_genai_client_for_model
|
||||
|
||||
# Assume these helpers will be moved or are accessible via cog
|
||||
# We might need to pass 'cog' to these tool functions if they rely on cog state heavily
|
||||
|
Loading…
x
Reference in New Issue
Block a user