Use global region for Gemini 2.5 Pro
commit 998d0a3b15
parent 35bd409d54
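This commit replaces the single shared Vertex AI client with two region-specific clients in gurt/api.py: the existing us-central1 client stays the default, and a new client pinned to PRO_MODEL_LOCATION ("global") serves the gemini-2.5-pro-preview-06-05 model. A get_genai_client_for_model() helper picks the client by model name, and every call site that previously used the module-level genai_client or a per-cog self.genai_client now resolves its client through that helper.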
@@ -35,6 +35,7 @@ except ImportError:
     # Allow cog to load but genai_client will be None

 from tavily import TavilyClient
+from gurt.api import get_genai_client_for_model

 # Define standard safety settings using google.generativeai types
 # Set all thresholds to OFF as requested for internal tools
@@ -2432,7 +2433,8 @@ class AICodeAgentCog(commands.Cog):
         # for i, item in enumerate(vertex_contents):
         #     print(f"  History {i} Role: {item.role}, Parts: {item.parts}")

-        response = await self.genai_client.aio.models.generate_content(
+        client = get_genai_client_for_model(self._ai_model)
+        response = await client.aio.models.generate_content(
             model=f"publishers/google/models/{self._ai_model}",
             contents=vertex_contents,
             config=generation_config,  # Corrected parameter name
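The same two-line substitution repeats in every cog below: resolve the client from the model name, then call generate_content on it. A minimal sketch of that pattern — get_genai_client_for_model is the real helper added later in this diff, while the function name, prompt, and generation settings are placeholder assumptions:

```python
# Sketch of the call-site pattern this commit introduces; the function name,
# prompt, and config values are placeholders, not code from the repo.
from google.genai import types

from gurt.api import get_genai_client_for_model


async def ask_model(model_name: str, prompt: str) -> str:
    # Resolve the region-appropriate client instead of a shared self.genai_client.
    client = get_genai_client_for_model(model_name)
    response = await client.aio.models.generate_content(
        model=f"publishers/google/models/{model_name}",
        contents=[types.Content(role="user", parts=[types.Part(text=prompt)])],
        config=types.GenerateContentConfig(temperature=0.7),  # assumed settings
    )
    return response.text
```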
@@ -30,6 +30,7 @@ from gurt.config import (
     PROJECT_ID,
     LOCATION,
 )  # Assuming gurt.config exists and has these
+from gurt.api import get_genai_client_for_model

 from . import aimod_config as aimod_config_module
 from .aimod_config import (
@@ -1727,7 +1728,8 @@ CRITICAL: Do NOT output anything other than the required JSON response.
             # response_mime_type="application/json",  # Consider if model supports this for forcing JSON
         )

-        response = await self.genai_client.aio.models.generate_content(
+        client = get_genai_client_for_model(model_id_to_use)
+        response = await client.aio.models.generate_content(
             model=model_path,  # Correctly formatted model path
             contents=request_contents,  # User's message with context and images
             config=final_generation_config,  # Pass the config with system_instruction
@@ -2412,7 +2414,8 @@ CRITICAL: Do NOT output anything other than the required JSON response.
         )

         try:
-            response = await self.genai_client.aio.models.generate_content(
+            client = get_genai_client_for_model(APPEAL_AI_MODEL)
+            response = await client.aio.models.generate_content(
                 model=f"publishers/google/models/{APPEAL_AI_MODEL}",
                 contents=[
                     types.Content(role="user", parts=[types.Part(text=user_prompt)])
@@ -27,6 +27,7 @@ from google.genai import types
 from google.api_core import exceptions as google_exceptions

 from gurt.config import PROJECT_ID, LOCATION
+from gurt.api import get_genai_client_for_model

 STANDARD_SAFETY_SETTINGS = [
     types.SafetySetting(
@@ -128,7 +129,8 @@ class DmbotTetoCog(commands.Cog):
         )

         try:
-            response = await self.genai_client.aio.models.generate_content(
+            client = get_genai_client_for_model(self._ai_model)
+            response = await client.aio.models.generate_content(
                 model=f"publishers/google/models/{self._ai_model}",
                 contents=contents,
                 config=generation_config,
@@ -12,6 +12,7 @@ from google.api_core import exceptions as google_exceptions

 # Import project configuration for Vertex AI
 from gurt.config import PROJECT_ID, LOCATION
+from gurt.api import get_genai_client_for_model

 from .gelbooru_watcher_base_cog import GelbooruWatcherBaseCog

@@ -171,7 +172,8 @@ class Rule34Cog(GelbooruWatcherBaseCog):  # Removed name="Rule34"
         log.debug(
             f"Rule34Cog: Sending to Vertex AI for tag transformation. Model: {self.tag_transformer_model}, Input: '{user_tags}'"
         )
-        response = await self.genai_client.aio.models.generate_content(
+        client = get_genai_client_for_model(self.tag_transformer_model)
+        response = await client.aio.models.generate_content(
             model=f"publishers/google/models/{self.tag_transformer_model}",
             contents=contents_for_api,
             config=generation_config,
@@ -27,6 +27,7 @@ from google.api_core import exceptions as google_exceptions

 # Import project configuration for Vertex AI
 from gurt.config import PROJECT_ID, LOCATION
+from gurt.api import get_genai_client_for_model

 # Define standard safety settings using google.generativeai types
 # Set all thresholds to OFF as requested
@@ -527,7 +528,8 @@ class TetoCog(commands.Cog):
         print(
             f"[TETO DEBUG] Sending to Vertex AI. Model: {self._ai_model}, Tool Config: {vertex_tools is not None}"
         )
-        response = await self.genai_client.aio.models.generate_content(
+        client = get_genai_client_for_model(self._ai_model)
+        response = await client.aio.models.generate_content(
             model=f"publishers/google/models/{self._ai_model}",  # Use simpler model path
             contents=final_contents_for_api,
             config=generation_config_with_system,  # Pass the updated config
gurt/api.py (48 changed lines)
@@ -183,6 +183,7 @@ from .config import (
     PISTON_API_KEY,
     BASELINE_PERSONALITY,
     TENOR_API_KEY,  # Import other needed configs
+    PRO_MODEL_LOCATION,
 )
 from .prompt import build_dynamic_system_prompt
 from .context import (
@@ -385,24 +386,41 @@ def _format_embeds_for_prompt(embed_content: List[Dict[str, Any]]) -> Optional[str]:
     return "\n".join(formatted_strings) if formatted_strings else None


-# --- Initialize Google Generative AI Client for Vertex AI ---
-# No explicit genai.configure(api_key=...) needed when using Vertex AI backend
+# --- Initialize Google Generative AI Clients for Vertex AI ---
+# Separate clients for different regions so models can use their preferred location
 try:
-    genai_client = genai.Client(
+    genai_client_us_central1 = genai.Client(
         vertexai=True,
         project=PROJECT_ID,
-        location=LOCATION,
+        location="us-central1",
     )
+    genai_client_global = genai.Client(
+        vertexai=True,
+        project=PROJECT_ID,
+        location=PRO_MODEL_LOCATION,
+    )
+    # Default client remains us-central1 for backward compatibility
+    genai_client = genai_client_us_central1

-    print(
-        f"Google GenAI Client initialized for Vertex AI project '{PROJECT_ID}' in location '{LOCATION}'."
-    )
+    print("Google GenAI Clients initialized for us-central1 and global regions.")
 except NameError:
-    genai_client = None
+    genai_client_us_central1 = None
+    genai_client_global = None
     print("Google GenAI SDK (genai) not imported, skipping client initialization.")
 except Exception as e:
-    genai_client = None
-    print(f"Error initializing Google GenAI Client for Vertex AI: {e}")
+    genai_client_us_central1 = None
+    genai_client_global = None
+    print(f"Error initializing Google GenAI Clients for Vertex AI: {e}")


+def get_genai_client_for_model(model_name: str):
+    """Return the appropriate genai client based on the model name."""
+    if "gemini-2.5-pro-preview-06-05" in model_name:
+        return genai_client_global
+    return genai_client_us_central1


 # --- Constants ---
 # Define standard safety settings using google.generativeai types
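A quick illustration of the routing the new helper performs; only the pinned preview string comes from the code, the flash model name is hypothetical:

```python
# Only model names containing "gemini-2.5-pro-preview-06-05" route to the
# global-region client; everything else falls back to us-central1.
from gurt.api import (
    genai_client_global,
    genai_client_us_central1,
    get_genai_client_for_model,
)

assert get_genai_client_for_model("gemini-2.5-pro-preview-06-05") is genai_client_global
assert get_genai_client_for_model("gemini-2.0-flash-001") is genai_client_us_central1  # hypothetical name
```

Because the check is a substring match against one pinned preview version, a future Pro release (or a renamed preview) will silently route to us-central1 unless the helper is updated.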
@@ -452,8 +470,9 @@ async def call_google_genai_api_with_retry(
     Raises:
         Exception: If the API call fails after all retry attempts or encounters a non-retryable error.
     """
-    if not genai_client:
-        raise Exception("Google GenAI Client (genai_client) is not initialized.")
+    client = get_genai_client_for_model(model_name)
+    if not client:
+        raise Exception("Google GenAI Client is not initialized.")

     last_exception = None
     start_time = time.monotonic()
@@ -478,7 +497,7 @@

         # Use the non-streaming async call - config now contains all settings
         # The 'model' parameter here should be the actual model name string
-        response = await genai_client.aio.models.generate_content(
+        response = await client.aio.models.generate_content(
             model=model_name,  # Use the model_name string directly
             contents=contents,
             config=generation_config,  # Pass the combined config object
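With the guard rewritten, the retry helper resolves its client from the model name at call time. The full body of call_google_genai_api_with_retry is not shown; a minimal sketch of the shape these two hunks imply, with the retry constants and backoff policy assumed:

```python
import asyncio

from gurt.api import get_genai_client_for_model

API_RETRY_ATTEMPTS = 3       # assumed; the real constants live elsewhere in gurt
API_RETRY_DELAY_SECONDS = 2  # assumed


async def call_with_retry(model_name, contents, generation_config):
    # Resolve the regional client once per call, as the patched guard does.
    client = get_genai_client_for_model(model_name)
    if not client:
        raise Exception("Google GenAI Client is not initialized.")

    last_exception = None
    for attempt in range(API_RETRY_ATTEMPTS):
        try:
            return await client.aio.models.generate_content(
                model=model_name,
                contents=contents,
                config=generation_config,
            )
        except Exception as e:
            last_exception = e
            # Simple linear backoff between attempts (assumed policy).
            await asyncio.sleep(API_RETRY_DELAY_SECONDS * (attempt + 1))
    raise last_exception
```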
@@ -1060,8 +1079,8 @@ async def get_ai_response(
     - "error": An error message string if a critical error occurred, otherwise None.
     - "fallback_initial": Optional minimal response if initial parsing failed critically (less likely with controlled generation).
     """
-    if not PROJECT_ID or not LOCATION or not genai_client:  # Check genai_client too
-        error_msg = "Google Cloud Project ID/Location not configured or GenAI Client failed to initialize."
+    if not PROJECT_ID:
+        error_msg = "Google Cloud Project ID not configured."
         print(f"Error in get_ai_response: {error_msg}")
         return {"final_response": None, "error": error_msg}

@@ -2707,7 +2726,8 @@ async def generate_image_description(
     Returns:
         The AI-generated description string, or None if an error occurs.
     """
-    if not genai_client:
+    client = get_genai_client_for_model(EMOJI_STICKER_DESCRIPTION_MODEL)
+    if not client:
         print(
             "Error in generate_image_description: Google GenAI Client not initialized."
         )
@@ -11,6 +11,8 @@ load_dotenv()
 # --- API and Keys ---
 PROJECT_ID = os.getenv("GCP_PROJECT_ID", "1079377687568")
 LOCATION = os.getenv("GCP_LOCATION", "us-central1")
+# Specific region for Gemini 2.5 Pro model
+PRO_MODEL_LOCATION = "global"
 TAVILY_API_KEY = os.getenv("TAVILY_API_KEY", "")
 TENOR_API_KEY = os.getenv("TENOR_API_KEY", "")  # Added Tenor API Key
 PISTON_API_URL = os.getenv("PISTON_API_URL")  # For run_python_code tool
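PRO_MODEL_LOCATION is pinned rather than read from the environment like the surrounding settings; an env-driven variant would look like the sketch below (GCP_PRO_MODEL_LOCATION is a hypothetical variable name, not part of this commit):

```python
import os

# Hypothetical env-driven variant; this commit hard-codes the value to "global".
PRO_MODEL_LOCATION = os.getenv("GCP_PRO_MODEL_LOCATION", "global")
```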
@@ -41,6 +41,7 @@ from .config import (
     TAVILY_DEFAULT_MAX_RESULTS,
     TAVILY_DISABLE_ADVANCED,
 )
+from gurt.api import get_genai_client_for_model

 # Assume these helpers will be moved or are accessible via cog
 # We might need to pass 'cog' to these tool functions if they rely on cog state heavily
@@ -4544,15 +4545,10 @@ async def send_tenor_gif(
     for gif_data in gif_parts:
         ai_content.append(gif_data["part"])

-    # Initialize AI client if needed
-    if not hasattr(cog, "genai_client") or not cog.genai_client:
-        cog.genai_client = genai.Client(
-            vertexai=True, project=PROJECT_ID, location=LOCATION
-        )
-
     # Generate AI selection
     try:
-        response = await cog.genai_client.aio.models.generate_content(
+        client = get_genai_client_for_model(DEFAULT_MODEL)
+        response = await client.aio.models.generate_content(
             model=DEFAULT_MODEL,
             contents=[types.Content(role="user", parts=ai_content)],
             config=types.GenerateContentConfig(
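Note the deleted lazy initialization: send_tenor_gif no longer constructs a genai.Client on the cog at call time. It now relies on the module-level clients created in gurt/api.py, keeping region selection consistent with every other call site in this commit.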