Refactor genai client logic to avoid circular imports
This commit is contained in:
parent
f1d08908c3
commit
3b024e132b
@ -35,7 +35,7 @@ except ImportError:
|
|||||||
# Allow cog to load but genai_client will be None
|
# Allow cog to load but genai_client will be None
|
||||||
|
|
||||||
from tavily import TavilyClient
|
from tavily import TavilyClient
|
||||||
from gurt.api import get_genai_client_for_model
|
from gurt.genai_client import get_genai_client_for_model
|
||||||
|
|
||||||
# Define standard safety settings using google.generativeai types
|
# Define standard safety settings using google.generativeai types
|
||||||
# Set all thresholds to OFF as requested for internal tools
|
# Set all thresholds to OFF as requested for internal tools
|
||||||
|
@ -30,7 +30,7 @@ from gurt.config import (
|
|||||||
PROJECT_ID,
|
PROJECT_ID,
|
||||||
LOCATION,
|
LOCATION,
|
||||||
) # Assuming gurt.config exists and has these
|
) # Assuming gurt.config exists and has these
|
||||||
from gurt.api import get_genai_client_for_model
|
from gurt.genai_client import get_genai_client_for_model
|
||||||
|
|
||||||
from . import aimod_config as aimod_config_module
|
from . import aimod_config as aimod_config_module
|
||||||
from .aimod_config import (
|
from .aimod_config import (
|
||||||
|
@ -27,7 +27,7 @@ from google.genai import types
|
|||||||
from google.api_core import exceptions as google_exceptions
|
from google.api_core import exceptions as google_exceptions
|
||||||
|
|
||||||
from gurt.config import PROJECT_ID, LOCATION
|
from gurt.config import PROJECT_ID, LOCATION
|
||||||
from gurt.api import get_genai_client_for_model
|
from gurt.genai_client import get_genai_client_for_model
|
||||||
|
|
||||||
STANDARD_SAFETY_SETTINGS = [
|
STANDARD_SAFETY_SETTINGS = [
|
||||||
types.SafetySetting(
|
types.SafetySetting(
|
||||||
|
@ -12,7 +12,7 @@ from google.api_core import exceptions as google_exceptions
|
|||||||
|
|
||||||
# Import project configuration for Vertex AI
|
# Import project configuration for Vertex AI
|
||||||
from gurt.config import PROJECT_ID, LOCATION
|
from gurt.config import PROJECT_ID, LOCATION
|
||||||
from gurt.api import get_genai_client_for_model
|
from gurt.genai_client import get_genai_client_for_model
|
||||||
|
|
||||||
from .gelbooru_watcher_base_cog import GelbooruWatcherBaseCog
|
from .gelbooru_watcher_base_cog import GelbooruWatcherBaseCog
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ from google.api_core import exceptions as google_exceptions
|
|||||||
|
|
||||||
# Import project configuration for Vertex AI
|
# Import project configuration for Vertex AI
|
||||||
from gurt.config import PROJECT_ID, LOCATION
|
from gurt.config import PROJECT_ID, LOCATION
|
||||||
from gurt.api import get_genai_client_for_model
|
from gurt.genai_client import get_genai_client_for_model
|
||||||
|
|
||||||
# Define standard safety settings using google.generativeai types
|
# Define standard safety settings using google.generativeai types
|
||||||
# Set all thresholds to OFF as requested
|
# Set all thresholds to OFF as requested
|
||||||
|
37
gurt/api.py
37
gurt/api.py
@ -183,7 +183,6 @@ from .config import (
|
|||||||
PISTON_API_KEY,
|
PISTON_API_KEY,
|
||||||
BASELINE_PERSONALITY,
|
BASELINE_PERSONALITY,
|
||||||
TENOR_API_KEY, # Import other needed configs
|
TENOR_API_KEY, # Import other needed configs
|
||||||
PRO_MODEL_LOCATION,
|
|
||||||
)
|
)
|
||||||
from .prompt import build_dynamic_system_prompt
|
from .prompt import build_dynamic_system_prompt
|
||||||
from .context import (
|
from .context import (
|
||||||
@ -386,40 +385,10 @@ def _format_embeds_for_prompt(embed_content: List[Dict[str, Any]]) -> Optional[s
|
|||||||
return "\n".join(formatted_strings) if formatted_strings else None
|
return "\n".join(formatted_strings) if formatted_strings else None
|
||||||
|
|
||||||
|
|
||||||
# --- Initialize Google Generative AI Clients for Vertex AI ---
|
from .genai_client import (
|
||||||
# Separate clients for different regions so models can use their preferred location
|
genai_client,
|
||||||
try:
|
get_genai_client_for_model,
|
||||||
genai_client_us_central1 = genai.Client(
|
|
||||||
vertexai=True,
|
|
||||||
project=PROJECT_ID,
|
|
||||||
location="us-central1",
|
|
||||||
)
|
)
|
||||||
genai_client_global = genai.Client(
|
|
||||||
vertexai=True,
|
|
||||||
project=PROJECT_ID,
|
|
||||||
location=PRO_MODEL_LOCATION,
|
|
||||||
)
|
|
||||||
# Default client remains us-central1 for backward compatibility
|
|
||||||
genai_client = genai_client_us_central1
|
|
||||||
|
|
||||||
print("Google GenAI Clients initialized for us-central1 and global regions.")
|
|
||||||
except NameError:
|
|
||||||
genai_client = None
|
|
||||||
genai_client_us_central1 = None
|
|
||||||
genai_client_global = None
|
|
||||||
print("Google GenAI SDK (genai) not imported, skipping client initialization.")
|
|
||||||
except Exception as e:
|
|
||||||
genai_client = None
|
|
||||||
genai_client_us_central1 = None
|
|
||||||
genai_client_global = None
|
|
||||||
print(f"Error initializing Google GenAI Clients for Vertex AI: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def get_genai_client_for_model(model_name: str):
|
|
||||||
"""Return the appropriate genai client based on the model name."""
|
|
||||||
if "gemini-2.5-pro-preview-06-05" in model_name:
|
|
||||||
return genai_client_global
|
|
||||||
return genai_client_us_central1
|
|
||||||
|
|
||||||
|
|
||||||
# --- Constants ---
|
# --- Constants ---
|
||||||
|
41
gurt/genai_client.py
Normal file
41
gurt/genai_client.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
from google import genai
|
||||||
|
|
||||||
|
from .config import PROJECT_ID, PRO_MODEL_LOCATION
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"genai_client_us_central1",
|
||||||
|
"genai_client_global",
|
||||||
|
"genai_client",
|
||||||
|
"get_genai_client_for_model",
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
genai_client_us_central1 = genai.Client(
|
||||||
|
vertexai=True,
|
||||||
|
project=PROJECT_ID,
|
||||||
|
location="us-central1",
|
||||||
|
)
|
||||||
|
genai_client_global = genai.Client(
|
||||||
|
vertexai=True,
|
||||||
|
project=PROJECT_ID,
|
||||||
|
location=PRO_MODEL_LOCATION,
|
||||||
|
)
|
||||||
|
genai_client = genai_client_us_central1
|
||||||
|
print("Google GenAI Clients initialized for us-central1 and global regions.")
|
||||||
|
except NameError:
|
||||||
|
genai_client_us_central1 = None
|
||||||
|
genai_client_global = None
|
||||||
|
genai_client = None
|
||||||
|
print("Google GenAI SDK (genai) not imported, skipping client initialization.")
|
||||||
|
except Exception as e:
|
||||||
|
genai_client_us_central1 = None
|
||||||
|
genai_client_global = None
|
||||||
|
genai_client = None
|
||||||
|
print(f"Error initializing Google GenAI Clients for Vertex AI: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_genai_client_for_model(model_name: str):
|
||||||
|
"""Return the appropriate genai client based on the model name."""
|
||||||
|
if "gemini-2.5-pro-preview-06-05" in model_name:
|
||||||
|
return genai_client_global
|
||||||
|
return genai_client_us_central1
|
@ -41,7 +41,7 @@ from .config import (
|
|||||||
TAVILY_DEFAULT_MAX_RESULTS,
|
TAVILY_DEFAULT_MAX_RESULTS,
|
||||||
TAVILY_DISABLE_ADVANCED,
|
TAVILY_DISABLE_ADVANCED,
|
||||||
)
|
)
|
||||||
from gurt.api import get_genai_client_for_model
|
from gurt.genai_client import get_genai_client_for_model
|
||||||
|
|
||||||
# Assume these helpers will be moved or are accessible via cog
|
# Assume these helpers will be moved or are accessible via cog
|
||||||
# We might need to pass 'cog' to these tool functions if they rely on cog state heavily
|
# We might need to pass 'cog' to these tool functions if they rely on cog state heavily
|
||||||
|
Loading…
x
Reference in New Issue
Block a user