From 02fa394ebc424beddbc3e027fc62811690058972 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Thu, 5 Jun 2025 23:08:45 -0600 Subject: [PATCH] a --- .pylintrc | 6 - discord_bot_sync_api.py | 914 ------------------------ gurt/tools.py | 9 - wheatley/__init__.py | 8 - wheatley/analysis.py | 936 ------------------------ wheatley/api.py | 1500 --------------------------------------- wheatley/background.py | 117 --- wheatley/cog.py | 458 ------------ wheatley/commands.py | 557 --------------- wheatley/config.py | 786 -------------------- wheatley/context.py | 276 ------- wheatley/listeners.py | 600 ---------------- wheatley/memory.py | 823 --------------------- wheatley/prompt.py | 150 ---- wheatley/state.py | 1 - wheatley/tools.py | 1293 --------------------------------- wheatley/utils.py | 169 ----- 17 files changed, 8603 deletions(-) delete mode 100644 .pylintrc delete mode 100644 discord_bot_sync_api.py delete mode 100644 wheatley/__init__.py delete mode 100644 wheatley/analysis.py delete mode 100644 wheatley/api.py delete mode 100644 wheatley/background.py delete mode 100644 wheatley/cog.py delete mode 100644 wheatley/commands.py delete mode 100644 wheatley/config.py delete mode 100644 wheatley/context.py delete mode 100644 wheatley/listeners.py delete mode 100644 wheatley/memory.py delete mode 100644 wheatley/prompt.py delete mode 100644 wheatley/state.py delete mode 100644 wheatley/tools.py delete mode 100644 wheatley/utils.py diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index 76edbfb..0000000 --- a/.pylintrc +++ /dev/null @@ -1,6 +0,0 @@ -[MASTER] -ignore= -init-hook='import sys, os; sys.path.append(os.getcwd())' - -[MESSAGES CONTROL] -disable=import-error diff --git a/discord_bot_sync_api.py b/discord_bot_sync_api.py deleted file mode 100644 index 5347fc1..0000000 --- a/discord_bot_sync_api.py +++ /dev/null @@ -1,914 +0,0 @@ -import os -import json -import asyncio -import datetime -from typing import Dict, List, Optional, Any, Union -from fastapi import FastAPI, HTTPException, Depends, Header, Request, Response -from fastapi.middleware.cors import CORSMiddleware -from fastapi.staticfiles import StaticFiles # Added for static files -from fastapi.responses import FileResponse # Added for serving HTML -from pydantic import BaseModel, Field -import discord -from discord.ext import commands -import aiohttp -import threading -from typing import Optional # Added for GurtCog type hint - -# This file contains the API endpoints for syncing conversations between -# the Flutter app and the Discord bot, AND the Gurt stats endpoint. - -# --- Placeholder for GurtCog instance and bot instance --- -# These need to be set by the script that starts the bot and API server -# Import GurtCog and ModLogCog conditionally to avoid dependency issues -try: - from gurt.cog import GurtCog # Import GurtCog for type hint and access - from cogs.mod_log_cog import ModLogCog # Import ModLogCog for type hint - - gurt_cog_instance: Optional[GurtCog] = None - mod_log_cog_instance: Optional[ModLogCog] = None # Placeholder for ModLogCog -except ImportError as e: - print(f"Warning: Could not import GurtCog or ModLogCog: {e}") - # Use Any type as fallback - from typing import Any - - gurt_cog_instance: Optional[Any] = None - mod_log_cog_instance: Optional[Any] = None -bot_instance = None # Will be set to the Discord bot instance - -# ============= Models ============= - - -class SyncedMessage(BaseModel): - content: str - role: str # "user", "assistant", or "system" - timestamp: datetime.datetime - reasoning: Optional[str] = None - usage_data: Optional[Dict[str, Any]] = None - - -class UserSettings(BaseModel): - # General settings - model_id: str = "openai/gpt-3.5-turbo" - temperature: float = 0.7 - max_tokens: int = 1000 - - # Reasoning settings - reasoning_enabled: bool = False - reasoning_effort: str = "medium" # "low", "medium", "high" - - # Web search settings - web_search_enabled: bool = False - - # System message - system_message: Optional[str] = None - - # Character settings - character: Optional[str] = None - character_info: Optional[str] = None - character_breakdown: bool = False - custom_instructions: Optional[str] = None - - # UI settings - advanced_view_enabled: bool = False - streaming_enabled: bool = True - - # Last updated timestamp - last_updated: datetime.datetime = Field(default_factory=datetime.datetime.now) - sync_source: str = "discord" # "discord" or "flutter" - - -class SyncedConversation(BaseModel): - id: str - title: str - messages: List[SyncedMessage] - created_at: datetime.datetime - updated_at: datetime.datetime - model_id: str - sync_source: str = "discord" # "discord" or "flutter" - last_synced_at: Optional[datetime.datetime] = None - - # Conversation-specific settings - reasoning_enabled: bool = False - reasoning_effort: str = "medium" # "low", "medium", "high" - temperature: float = 0.7 - max_tokens: int = 1000 - web_search_enabled: bool = False - system_message: Optional[str] = None - - # Character-related settings - character: Optional[str] = None - character_info: Optional[str] = None - character_breakdown: bool = False - custom_instructions: Optional[str] = None - - -class SyncRequest(BaseModel): - conversations: List[SyncedConversation] - last_sync_time: Optional[datetime.datetime] = None - user_settings: Optional[UserSettings] = None - - -class SettingsSyncRequest(BaseModel): - user_settings: UserSettings - - -class SyncResponse(BaseModel): - success: bool - message: str - conversations: List[SyncedConversation] = [] - user_settings: Optional[UserSettings] = None - - -# ============= Storage ============= - -# Files to store synced data -SYNC_DATA_FILE = "data/synced_conversations.json" -USER_SETTINGS_FILE = "data/synced_user_settings.json" - -# Create data directory if it doesn't exist -os.makedirs(os.path.dirname(SYNC_DATA_FILE), exist_ok=True) - -# In-memory storage for conversations and settings -user_conversations: Dict[str, List[SyncedConversation]] = {} -user_settings: Dict[str, UserSettings] = {} - - -# Load conversations from file -def load_conversations(): - global user_conversations - if os.path.exists(SYNC_DATA_FILE): - try: - with open(SYNC_DATA_FILE, "r", encoding="utf-8") as f: - data = json.load(f) - # Convert string keys (user IDs) back to strings - user_conversations = { - k: [SyncedConversation.model_validate(conv) for conv in v] - for k, v in data.items() - } - print(f"Loaded synced conversations for {len(user_conversations)} users") - except Exception as e: - print(f"Error loading synced conversations: {e}") - user_conversations = {} - - -# Save conversations to file -def save_conversations(): - try: - # Convert to JSON-serializable format - serializable_data = { - user_id: [conv.model_dump() for conv in convs] - for user_id, convs in user_conversations.items() - } - with open(SYNC_DATA_FILE, "w", encoding="utf-8") as f: - json.dump(serializable_data, f, indent=2, default=str, ensure_ascii=False) - except Exception as e: - print(f"Error saving synced conversations: {e}") - - -# Load user settings from file -def load_user_settings(): - global user_settings - if os.path.exists(USER_SETTINGS_FILE): - try: - with open(USER_SETTINGS_FILE, "r", encoding="utf-8") as f: - data = json.load(f) - # Convert string keys (user IDs) back to strings - user_settings = { - k: UserSettings.model_validate(v) for k, v in data.items() - } - print(f"Loaded synced settings for {len(user_settings)} users") - except Exception as e: - print(f"Error loading synced user settings: {e}") - user_settings = {} - - -# Save user settings to file -def save_all_user_settings(): - try: - # Convert to JSON-serializable format - serializable_data = { - user_id: settings.model_dump() - for user_id, settings in user_settings.items() - } - with open(USER_SETTINGS_FILE, "w", encoding="utf-8") as f: - json.dump(serializable_data, f, indent=2, default=str, ensure_ascii=False) - except Exception as e: - print(f"Error saving synced user settings: {e}") - - -# ============= Discord OAuth Verification ============= - - -async def verify_discord_token(authorization: str = Header(None)) -> str: - """Verify the Discord token and return the user ID""" - if not authorization: - raise HTTPException(status_code=401, detail="Authorization header missing") - - if not authorization.startswith("Bearer "): - raise HTTPException(status_code=401, detail="Invalid authorization format") - - token = authorization.replace("Bearer ", "") - - # Verify the token with Discord - async with aiohttp.ClientSession() as session: - headers = {"Authorization": f"Bearer {token}"} - async with session.get( - "https://discord.com/api/v10/users/@me", headers=headers - ) as resp: - if resp.status != 200: - raise HTTPException(status_code=401, detail="Invalid Discord token") - - user_data = await resp.json() - return user_data["id"] - - -# ============= API Setup ============= - -# API Configuration -API_BASE_PATH = "/discordapi" # Base path for the API -SSL_CERT_FILE = "/etc/letsencrypt/live/slipstreamm.dev/fullchain.pem" -SSL_KEY_FILE = "/etc/letsencrypt/live/slipstreamm.dev/privkey.pem" - -# Create the main FastAPI app -app = FastAPI(title="Discord Bot Sync API") - -# Create a sub-application for the API -api_app = FastAPI( - title="Discord Bot Sync API", docs_url="/docs", openapi_url="/openapi.json" -) - -# Mount the API app at the base path -app.mount(API_BASE_PATH, api_app) - -# Add CORS middleware -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], # Adjust this in production - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -# Also add CORS to the API app -api_app.add_middleware( - CORSMiddleware, - allow_origins=["*"], # Adjust this in production - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -# Initialize by loading saved data -@app.on_event("startup") -async def startup_event(): - load_conversations() - load_user_settings() - - # Try to load local settings from AI cog and merge them with synced settings - try: - from cogs.ai_cog import ( - user_settings as local_user_settings, - get_user_settings as get_local_settings, - ) - - print("Merging local AI cog settings with synced settings...") - - # Iterate through local settings and update synced settings - for user_id_int, local_settings_dict in local_user_settings.items(): - user_id_str = str(user_id_int) - - # Get the full settings with defaults - local_settings = get_local_settings(user_id_int) - - # Create synced settings if they don't exist - if user_id_str not in user_settings: - user_settings[user_id_str] = UserSettings() - - # Update synced settings with local settings - synced_settings = user_settings[user_id_str] - - # Always update all settings from local settings - synced_settings.model_id = local_settings.get( - "model", synced_settings.model_id - ) - synced_settings.temperature = local_settings.get( - "temperature", synced_settings.temperature - ) - synced_settings.max_tokens = local_settings.get( - "max_tokens", synced_settings.max_tokens - ) - synced_settings.system_message = local_settings.get( - "system_prompt", synced_settings.system_message - ) - - # Handle character settings - explicitly check if they exist in local settings - if "character" in local_settings: - synced_settings.character = local_settings["character"] - else: - # If not in local settings, set to None - synced_settings.character = None - - # Handle character_info - explicitly check if they exist in local settings - if "character_info" in local_settings: - synced_settings.character_info = local_settings["character_info"] - else: - # If not in local settings, set to None - synced_settings.character_info = None - - # Always update character_breakdown - synced_settings.character_breakdown = local_settings.get( - "character_breakdown", False - ) - - # Handle custom_instructions - explicitly check if they exist in local settings - if "custom_instructions" in local_settings: - synced_settings.custom_instructions = local_settings[ - "custom_instructions" - ] - else: - # If not in local settings, set to None - synced_settings.custom_instructions = None - - # Always update reasoning settings - synced_settings.reasoning_enabled = local_settings.get( - "show_reasoning", False - ) - synced_settings.reasoning_effort = local_settings.get( - "reasoning_effort", "medium" - ) - synced_settings.web_search_enabled = local_settings.get( - "web_search_enabled", False - ) - - # Update timestamp and sync source - synced_settings.last_updated = datetime.datetime.now() - synced_settings.sync_source = "discord" - - # Save the updated synced settings - save_all_user_settings() - print("Successfully merged local AI cog settings with synced settings") - except Exception as e: - print(f"Error merging local settings with synced settings: {e}") - - -# ============= API Endpoints ============= - - -@app.get(API_BASE_PATH + "/") -async def root(): - return {"message": "Discord Bot Sync API is running"} - - -@api_app.get("/") -async def api_root(): - return {"message": "Discord Bot Sync API is running"} - - -@api_app.get("/auth") -async def auth(code: str, state: str = None): - """Handle OAuth callback""" - return {"message": "Authentication successful", "code": code, "state": state} - - -@api_app.get("/conversations") -async def get_conversations(user_id: str = Depends(verify_discord_token)): - """Get all conversations for a user""" - if user_id not in user_conversations: - return {"conversations": []} - - return {"conversations": user_conversations[user_id]} - - -@api_app.post("/sync") -async def sync_conversations( - sync_request: SyncRequest, user_id: str = Depends(verify_discord_token) -): - """Sync conversations between the Flutter app and Discord bot""" - # Get existing conversations for this user - existing_conversations = user_conversations.get(user_id, []) - - # Process incoming conversations - updated_conversations = [] - for incoming_conv in sync_request.conversations: - # Check if this conversation already exists - existing_conv = next( - (conv for conv in existing_conversations if conv.id == incoming_conv.id), - None, - ) - - if existing_conv: - # If the incoming conversation is newer, update it - if incoming_conv.updated_at > existing_conv.updated_at: - # Replace the existing conversation - existing_conversations = [ - conv - for conv in existing_conversations - if conv.id != incoming_conv.id - ] - existing_conversations.append(incoming_conv) - updated_conversations.append(incoming_conv) - else: - # This is a new conversation, add it - existing_conversations.append(incoming_conv) - updated_conversations.append(incoming_conv) - - # Update the storage - user_conversations[user_id] = existing_conversations - save_conversations() - - # Process user settings if provided - user_settings_response = None - if sync_request.user_settings: - incoming_settings = sync_request.user_settings - existing_settings = user_settings.get(user_id) - - # If we have existing settings, check which is newer - if existing_settings: - if ( - not existing_settings.last_updated - or incoming_settings.last_updated > existing_settings.last_updated - ): - user_settings[user_id] = incoming_settings - save_all_user_settings() - user_settings_response = incoming_settings - else: - user_settings_response = existing_settings - else: - # No existing settings, just save the incoming ones - user_settings[user_id] = incoming_settings - save_all_user_settings() - user_settings_response = incoming_settings - - return SyncResponse( - success=True, - message=f"Synced {len(updated_conversations)} conversations", - conversations=existing_conversations, - user_settings=user_settings_response, - ) - - -@api_app.delete("/conversations/{conversation_id}") -async def delete_conversation( - conversation_id: str, user_id: str = Depends(verify_discord_token) -): - """Delete a conversation""" - if user_id not in user_conversations: - raise HTTPException( - status_code=404, detail="No conversations found for this user" - ) - - # Filter out the conversation to delete - original_count = len(user_conversations[user_id]) - user_conversations[user_id] = [ - conv for conv in user_conversations[user_id] if conv.id != conversation_id - ] - - # Check if any conversation was deleted - if len(user_conversations[user_id]) == original_count: - raise HTTPException(status_code=404, detail="Conversation not found") - - save_conversations() - - return {"success": True, "message": "Conversation deleted"} - - -# --- Gurt Stats Endpoint --- -@api_app.get("/gurt/stats") -async def get_gurt_stats_api(): - """Get internal statistics for the Gurt bot.""" - if not gurt_cog_instance: - raise HTTPException(status_code=503, detail="Gurt cog not available") - try: - stats_data = await gurt_cog_instance.get_gurt_stats() - # Convert potential datetime objects if any (though get_gurt_stats should return serializable types) - # For safety, let's ensure basic types or handle conversion if needed later. - return stats_data - except Exception as e: - print(f"Error retrieving Gurt stats via API: {e}") - import traceback - - traceback.print_exc() - raise HTTPException(status_code=500, detail=f"Error retrieving Gurt stats: {e}") - - -# --- Gurt Dashboard Static Files --- -# Mount static files directory (adjust path if needed, assuming dashboard files are in discordbot/gurt_dashboard) -# Check if the directory exists before mounting -dashboard_dir = "discordbot/gurt_dashboard" -if os.path.exists(dashboard_dir) and os.path.isdir(dashboard_dir): - api_app.mount( - "/gurt/static", StaticFiles(directory=dashboard_dir), name="gurt_static" - ) - print(f"Mounted Gurt dashboard static files from: {dashboard_dir}") - - # Route for the main dashboard HTML - @api_app.get("/gurt/dashboard", response_class=FileResponse) - async def get_gurt_dashboard(): - dashboard_html_path = os.path.join(dashboard_dir, "index.html") - if os.path.exists(dashboard_html_path): - return dashboard_html_path - else: - raise HTTPException( - status_code=404, detail="Dashboard index.html not found" - ) - -else: - print( - f"Warning: Gurt dashboard directory '{dashboard_dir}' not found. Dashboard endpoints will not be available." - ) - - -@api_app.get("/settings") -async def get_user_settings(user_id: str = Depends(verify_discord_token)): - """Get user settings""" - # Import the AI cog's get_user_settings function to get local settings - try: - from cogs.ai_cog import ( - get_user_settings as get_local_settings, - user_settings as local_user_settings, - ) - - # Get local settings from the AI cog - local_settings = get_local_settings(int(user_id)) - print(f"Local settings for user {user_id}:") - print(f"Character: {local_settings.get('character')}") - print(f"Character Info: {local_settings.get('character_info')}") - print(f"Character Breakdown: {local_settings.get('character_breakdown')}") - print(f"Custom Instructions: {local_settings.get('custom_instructions')}") - print(f"System Prompt: {local_settings.get('system_prompt')}") - - # Create or get synced settings - if user_id not in user_settings: - user_settings[user_id] = UserSettings() - - # Update synced settings with local settings - synced_settings = user_settings[user_id] - - # Always update all settings from local settings - synced_settings.model_id = local_settings.get("model", synced_settings.model_id) - synced_settings.temperature = local_settings.get( - "temperature", synced_settings.temperature - ) - synced_settings.max_tokens = local_settings.get( - "max_tokens", synced_settings.max_tokens - ) - synced_settings.system_message = local_settings.get( - "system_prompt", synced_settings.system_message - ) - - # Handle character settings - explicitly check if they exist in local settings - if "character" in local_settings: - synced_settings.character = local_settings["character"] - else: - # If not in local settings, set to None - synced_settings.character = None - - # Handle character_info - explicitly check if they exist in local settings - if "character_info" in local_settings: - synced_settings.character_info = local_settings["character_info"] - else: - # If not in local settings, set to None - synced_settings.character_info = None - - # Always update character_breakdown - synced_settings.character_breakdown = local_settings.get( - "character_breakdown", False - ) - - # Handle custom_instructions - explicitly check if they exist in local settings - if "custom_instructions" in local_settings: - synced_settings.custom_instructions = local_settings["custom_instructions"] - else: - # If not in local settings, set to None - synced_settings.custom_instructions = None - - # Always update reasoning settings - synced_settings.reasoning_enabled = local_settings.get("show_reasoning", False) - synced_settings.reasoning_effort = local_settings.get( - "reasoning_effort", "medium" - ) - synced_settings.web_search_enabled = local_settings.get( - "web_search_enabled", False - ) - - # Update timestamp and sync source - synced_settings.last_updated = datetime.datetime.now() - synced_settings.sync_source = "discord" - - # Save the updated synced settings - save_all_user_settings() - - print(f"Updated synced settings for user {user_id}:") - print(f"Character: {synced_settings.character}") - print(f"Character Info: {synced_settings.character_info}") - print(f"Character Breakdown: {synced_settings.character_breakdown}") - print(f"Custom Instructions: {synced_settings.custom_instructions}") - print(f"System Message: {synced_settings.system_message}") - - return {"settings": synced_settings} - except Exception as e: - print(f"Error merging settings: {e}") - # Fallback to original behavior - if user_id not in user_settings: - # Create default settings if none exist - user_settings[user_id] = UserSettings() - save_all_user_settings() - - return {"settings": user_settings[user_id]} - - -@api_app.post("/settings") -async def update_user_settings( - settings_request: SettingsSyncRequest, user_id: str = Depends(verify_discord_token) -): - """Update user settings""" - incoming_settings = settings_request.user_settings - existing_settings = user_settings.get(user_id) - - # Debug logging for character settings - print(f"Received settings update from user {user_id}:") - print(f"Character: {incoming_settings.character}") - print(f"Character Info: {incoming_settings.character_info}") - print(f"Character Breakdown: {incoming_settings.character_breakdown}") - print(f"Custom Instructions: {incoming_settings.custom_instructions}") - print(f"Last Updated: {incoming_settings.last_updated}") - print(f"Sync Source: {incoming_settings.sync_source}") - - if existing_settings: - print(f"Existing settings for user {user_id}:") - print(f"Character: {existing_settings.character}") - print(f"Character Info: {existing_settings.character_info}") - print(f"Last Updated: {existing_settings.last_updated}") - print(f"Sync Source: {existing_settings.sync_source}") - - # If we have existing settings, check which is newer - if existing_settings: - if ( - not existing_settings.last_updated - or incoming_settings.last_updated > existing_settings.last_updated - ): - print(f"Updating settings for user {user_id} (incoming settings are newer)") - user_settings[user_id] = incoming_settings - save_all_user_settings() - else: - # Return existing settings if they're newer - print( - f"Not updating settings for user {user_id} (existing settings are newer)" - ) - return { - "success": True, - "message": "Existing settings are newer", - "settings": existing_settings, - } - else: - # No existing settings, just save the incoming ones - print(f"Creating new settings for user {user_id}") - user_settings[user_id] = incoming_settings - save_all_user_settings() - - # Verify the settings were saved correctly - saved_settings = user_settings.get(user_id) - print(f"Saved settings for user {user_id}:") - print(f"Character: {saved_settings.character}") - print(f"Character Info: {saved_settings.character_info}") - print(f"Character Breakdown: {saved_settings.character_breakdown}") - print(f"Custom Instructions: {saved_settings.custom_instructions}") - - # Update the local settings in the AI cog - try: - from cogs.ai_cog import ( - user_settings as local_user_settings, - save_user_settings as save_local_user_settings, - ) - - # Convert user_id to int for the AI cog - int_user_id = int(user_id) - - # Initialize local settings if not exist - if int_user_id not in local_user_settings: - local_user_settings[int_user_id] = {} - - # Update local settings with incoming settings - # Always update all settings, including setting to None/null when appropriate - local_user_settings[int_user_id]["model"] = incoming_settings.model_id - local_user_settings[int_user_id]["temperature"] = incoming_settings.temperature - local_user_settings[int_user_id]["max_tokens"] = incoming_settings.max_tokens - local_user_settings[int_user_id][ - "system_prompt" - ] = incoming_settings.system_message - - # Handle character settings - explicitly set to None if null in incoming settings - if incoming_settings.character is None: - # Remove the character setting if it exists - if "character" in local_user_settings[int_user_id]: - local_user_settings[int_user_id].pop("character") - print(f"Removed character setting for user {user_id}") - else: - local_user_settings[int_user_id]["character"] = incoming_settings.character - - # Handle character_info - explicitly set to None if null in incoming settings - if incoming_settings.character_info is None: - # Remove the character_info setting if it exists - if "character_info" in local_user_settings[int_user_id]: - local_user_settings[int_user_id].pop("character_info") - print(f"Removed character_info setting for user {user_id}") - else: - local_user_settings[int_user_id][ - "character_info" - ] = incoming_settings.character_info - - # Always update character_breakdown - local_user_settings[int_user_id][ - "character_breakdown" - ] = incoming_settings.character_breakdown - - # Handle custom_instructions - explicitly set to None if null in incoming settings - if incoming_settings.custom_instructions is None: - # Remove the custom_instructions setting if it exists - if "custom_instructions" in local_user_settings[int_user_id]: - local_user_settings[int_user_id].pop("custom_instructions") - print(f"Removed custom_instructions setting for user {user_id}") - else: - local_user_settings[int_user_id][ - "custom_instructions" - ] = incoming_settings.custom_instructions - - # Always update reasoning settings - local_user_settings[int_user_id][ - "show_reasoning" - ] = incoming_settings.reasoning_enabled - local_user_settings[int_user_id][ - "reasoning_effort" - ] = incoming_settings.reasoning_effort - local_user_settings[int_user_id][ - "web_search_enabled" - ] = incoming_settings.web_search_enabled - - # Save the updated local settings - save_local_user_settings() - - print(f"Updated local settings in AI cog for user {user_id}:") - print(f"Character: {local_user_settings[int_user_id].get('character')}") - print( - f"Character Info: {local_user_settings[int_user_id].get('character_info')}" - ) - print( - f"Character Breakdown: {local_user_settings[int_user_id].get('character_breakdown')}" - ) - print( - f"Custom Instructions: {local_user_settings[int_user_id].get('custom_instructions')}" - ) - except Exception as e: - print(f"Error updating local settings in AI cog: {e}") - - return { - "success": True, - "message": "Settings updated", - "settings": user_settings[user_id], - } - - -# ============= Discord Bot Integration ============= - - -# This function should be called from your Discord bot's AI cog -# to convert AI conversation history to the synced format -def convert_ai_history_to_synced( - user_id: str, conversation_history: Dict[int, List[Dict[str, Any]]] -): - """Convert the AI conversation history to the synced format""" - synced_conversations = [] - - # Process each conversation in the history - for discord_user_id, messages in conversation_history.items(): - if str(discord_user_id) != user_id: - continue - - # Create a unique ID for this conversation - conv_id = f"discord_{discord_user_id}_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}" - - # Convert messages to the synced format - synced_messages = [] - for msg in messages: - role = msg.get("role", "") - if role not in ["user", "assistant", "system"]: - continue - - synced_messages.append( - SyncedMessage( - content=msg.get("content", ""), - role=role, - timestamp=datetime.datetime.now(), # Use current time as we don't have the original timestamp - reasoning=None, # Discord bot doesn't store reasoning - usage_data=None, # Discord bot doesn't store usage data - ) - ) - - # Create the synced conversation - synced_conversations.append( - SyncedConversation( - id=conv_id, - title="Discord Conversation", # Default title - messages=synced_messages, - created_at=datetime.datetime.now(), - updated_at=datetime.datetime.now(), - model_id="openai/gpt-3.5-turbo", # Default model - sync_source="discord", - last_synced_at=datetime.datetime.now(), - reasoning_enabled=False, - reasoning_effort="medium", - temperature=0.7, - max_tokens=1000, - web_search_enabled=False, - system_message=None, - character=None, - character_info=None, - character_breakdown=False, - custom_instructions=None, - ) - ) - - return synced_conversations - - -# This function should be called from your Discord bot's AI cog -# to save a new conversation from Discord -def save_discord_conversation( - user_id: str, - messages: List[Dict[str, Any]], - model_id: str = "openai/gpt-3.5-turbo", - conversation_id: Optional[str] = None, - title: str = "Discord Conversation", - reasoning_enabled: bool = False, - reasoning_effort: str = "medium", - temperature: float = 0.7, - max_tokens: int = 1000, - web_search_enabled: bool = False, - system_message: Optional[str] = None, - character: Optional[str] = None, - character_info: Optional[str] = None, - character_breakdown: bool = False, - custom_instructions: Optional[str] = None, -): - """Save a conversation from Discord to the synced storage""" - # Convert messages to the synced format - synced_messages = [] - for msg in messages: - role = msg.get("role", "") - if role not in ["user", "assistant", "system"]: - continue - - synced_messages.append( - SyncedMessage( - content=msg.get("content", ""), - role=role, - timestamp=datetime.datetime.now(), - reasoning=msg.get("reasoning"), - usage_data=msg.get("usage_data"), - ) - ) - - # Create a unique ID for this conversation if not provided - if not conversation_id: - conversation_id = ( - f"discord_{user_id}_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}" - ) - - # Create the synced conversation - synced_conv = SyncedConversation( - id=conversation_id, - title=title, - messages=synced_messages, - created_at=datetime.datetime.now(), - updated_at=datetime.datetime.now(), - model_id=model_id, - sync_source="discord", - last_synced_at=datetime.datetime.now(), - reasoning_enabled=reasoning_enabled, - reasoning_effort=reasoning_effort, - temperature=temperature, - max_tokens=max_tokens, - web_search_enabled=web_search_enabled, - system_message=system_message, - character=character, - character_info=character_info, - character_breakdown=character_breakdown, - custom_instructions=custom_instructions, - ) - - # Add to storage - if user_id not in user_conversations: - user_conversations[user_id] = [] - - # Check if we're updating an existing conversation - if conversation_id: - # Remove the old conversation with the same ID if it exists - user_conversations[user_id] = [ - conv for conv in user_conversations[user_id] if conv.id != conversation_id - ] - - user_conversations[user_id].append(synced_conv) - save_conversations() - - return synced_conv diff --git a/gurt/tools.py b/gurt/tools.py index d7e8fcb..3d71fbc 100644 --- a/gurt/tools.py +++ b/gurt/tools.py @@ -1674,15 +1674,6 @@ async def execute_python_unsafe( return result -async def send_discord_message( - cog: commands.Cog, channel_id: str, message_content: str -) -> Dict[str, Any]: - """Sends a message to a specified Discord channel.""" - print( - f"Attempting to send message to channel {channel_id}: {message_content[:100]}..." - ) - - async def restart_gurt_bot(cog: commands.Cog, channel_id: str = None) -> Dict[str, Any]: """ Restarts the Gurt bot process by re-executing the current Python script. diff --git a/wheatley/__init__.py b/wheatley/__init__.py deleted file mode 100644 index 93106b0..0000000 --- a/wheatley/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# This file makes the 'wheatley' directory a Python package. -# It allows Python to properly import modules from this directory - -# Export the setup function for discord.py extension loading -from .cog import setup - -# This makes "from wheatley import setup" work -__all__ = ["setup"] diff --git a/wheatley/analysis.py b/wheatley/analysis.py deleted file mode 100644 index 70892f7..0000000 --- a/wheatley/analysis.py +++ /dev/null @@ -1,936 +0,0 @@ -import time -import re -import traceback -import logging -from collections import defaultdict -from typing import TYPE_CHECKING, List, Dict, Any, Optional - -logger = logging.getLogger(__name__) - -# Relative imports -from .config import ( - LEARNING_RATE, - TOPIC_UPDATE_INTERVAL, - TOPIC_RELEVANCE_DECAY, - MAX_ACTIVE_TOPICS, - SENTIMENT_DECAY_RATE, - EMOTION_KEYWORDS, - EMOJI_SENTIMENT, # Import necessary configs -) - -# Removed imports for BASELINE_PERSONALITY, REFLECTION_INTERVAL_SECONDS, GOAL related configs - -if TYPE_CHECKING: - from .cog import WheatleyCog # Updated type hint - -# --- Analysis Functions --- -# Note: These functions need the 'cog' instance passed to access state like caches, etc. - - -async def analyze_conversation_patterns(cog: "WheatleyCog"): # Updated type hint - """Analyzes recent conversations to identify patterns and learn from them""" - print("Analyzing conversation patterns and updating topics...") - try: - # Update conversation topics first - await update_conversation_topics(cog) - - for channel_id, messages in cog.message_cache["by_channel"].items(): - if len(messages) < 10: - continue - - # Pattern extraction might be less useful without personality/goals, but kept for now - # channel_patterns = extract_conversation_patterns(cog, messages) # Pass cog - # if channel_patterns: - # existing_patterns = cog.conversation_patterns.setdefault(channel_id, []) # Use setdefault - # combined_patterns = existing_patterns + channel_patterns - # if len(combined_patterns) > MAX_PATTERNS_PER_CHANNEL: - # combined_patterns = combined_patterns[-MAX_PATTERNS_PER_CHANNEL:] - # cog.conversation_patterns[channel_id] = combined_patterns - - analyze_conversation_dynamics(cog, channel_id, messages) # Pass cog - - update_user_preferences( - cog - ) # Pass cog - Note: This might need adjustment as it relies on traits we removed - - except Exception as e: - print(f"Error analyzing conversation patterns: {e}") - traceback.print_exc() - - -async def update_conversation_topics(cog: "WheatleyCog"): # Updated type hint - """Updates the active topics for each channel based on recent messages""" - try: - for channel_id, messages in cog.message_cache["by_channel"].items(): - if len(messages) < 5: - continue - - channel_topics = cog.active_topics[channel_id] - now = time.time() - if now - channel_topics["last_update"] < TOPIC_UPDATE_INTERVAL: - continue - - recent_messages = list(messages)[-30:] - topics = identify_conversation_topics(cog, recent_messages) # Pass cog - if not topics: - continue - - old_topics = channel_topics["topics"] - for topic in old_topics: - topic["score"] *= 1 - TOPIC_RELEVANCE_DECAY - - for new_topic in topics: - existing = next( - (t for t in old_topics if t["topic"] == new_topic["topic"]), None - ) - if existing: - existing["score"] = max(existing["score"], new_topic["score"]) - existing["related_terms"] = new_topic["related_terms"] - existing["last_mentioned"] = now - else: - new_topic["first_mentioned"] = now - new_topic["last_mentioned"] = now - old_topics.append(new_topic) - - old_topics = [t for t in old_topics if t["score"] > 0.2] - old_topics.sort(key=lambda x: x["score"], reverse=True) - old_topics = old_topics[:MAX_ACTIVE_TOPICS] - - if old_topics and channel_topics["topics"] != old_topics: - if not channel_topics["topic_history"] or set( - t["topic"] for t in old_topics - ) != set(t["topic"] for t in channel_topics["topics"]): - channel_topics["topic_history"].append( - { - "topics": [ - {"topic": t["topic"], "score": t["score"]} - for t in old_topics - ], - "timestamp": now, - } - ) - if len(channel_topics["topic_history"]) > 10: - channel_topics["topic_history"] = channel_topics[ - "topic_history" - ][-10:] - - # User topic interest tracking might be less relevant without proactive interest triggers, but kept for now - for msg in recent_messages: - user_id = msg["author"]["id"] - content = msg["content"].lower() - for topic in old_topics: - topic_text = topic["topic"].lower() - if topic_text in content: - user_interests = channel_topics["user_topic_interests"][user_id] - existing = next( - (i for i in user_interests if i["topic"] == topic["topic"]), - None, - ) - if existing: - existing["score"] = ( - existing["score"] * 0.8 + topic["score"] * 0.2 - ) - existing["last_mentioned"] = now - else: - user_interests.append( - { - "topic": topic["topic"], - "score": topic["score"] * 0.5, - "first_mentioned": now, - "last_mentioned": now, - } - ) - - channel_topics["topics"] = old_topics - channel_topics["last_update"] = now - if old_topics: - topic_str = ", ".join( - [f"{t['topic']} ({t['score']:.2f})" for t in old_topics[:3]] - ) - print(f"Updated topics for channel {channel_id}: {topic_str}") - - except Exception as e: - print(f"Error updating conversation topics: {e}") - traceback.print_exc() - - -def analyze_conversation_dynamics( - cog: "WheatleyCog", channel_id: int, messages: List[Dict[str, Any]] -): # Updated type hint - """Analyzes conversation dynamics like response times, message lengths, etc.""" - if len(messages) < 5: - return - try: - response_times = [] - response_map = defaultdict(int) - message_lengths = defaultdict(list) - question_answer_pairs = [] - import datetime # Import here - - for i in range(1, len(messages)): - current_msg = messages[i] - prev_msg = messages[i - 1] - if current_msg["author"]["id"] == prev_msg["author"]["id"]: - continue - try: - current_time = datetime.datetime.fromisoformat( - current_msg["created_at"] - ) - prev_time = datetime.datetime.fromisoformat(prev_msg["created_at"]) - delta_seconds = (current_time - prev_time).total_seconds() - if 0 < delta_seconds < 300: - response_times.append(delta_seconds) - except (ValueError, TypeError): - pass - - responder = current_msg["author"]["id"] - respondee = prev_msg["author"]["id"] - response_map[f"{responder}:{respondee}"] += 1 - message_lengths[responder].append(len(current_msg["content"])) - if prev_msg["content"].endswith("?"): - question_answer_pairs.append( - { - "question": prev_msg["content"], - "answer": current_msg["content"], - "question_author": prev_msg["author"]["id"], - "answer_author": current_msg["author"]["id"], - } - ) - - avg_response_time = ( - sum(response_times) / len(response_times) if response_times else 0 - ) - top_responders = sorted(response_map.items(), key=lambda x: x[1], reverse=True)[ - :3 - ] - avg_message_lengths = { - uid: sum(ls) / len(ls) if ls else 0 for uid, ls in message_lengths.items() - } - - dynamics = { - "avg_response_time": avg_response_time, - "top_responders": top_responders, - "avg_message_lengths": avg_message_lengths, - "question_answer_count": len(question_answer_pairs), - "last_updated": time.time(), - } - if not hasattr(cog, "conversation_dynamics"): - cog.conversation_dynamics = {} - cog.conversation_dynamics[channel_id] = dynamics - adapt_to_conversation_dynamics(cog, channel_id, dynamics) # Pass cog - - except Exception as e: - print(f"Error analyzing conversation dynamics: {e}") - - -def adapt_to_conversation_dynamics( - cog: "WheatleyCog", channel_id: int, dynamics: Dict[str, Any] -): # Updated type hint - """Adapts bot behavior based on observed conversation dynamics.""" - # Note: This function previously adapted personality traits. - # It might be removed or repurposed for Wheatley if needed. - # For now, it calculates factors but doesn't apply them directly to a removed personality system. - try: - if dynamics["avg_response_time"] > 0: - if not hasattr(cog, "channel_response_timing"): - cog.channel_response_timing = {} - response_time_factor = max( - 0.7, min(1.0, dynamics["avg_response_time"] / 10) - ) - cog.channel_response_timing[channel_id] = response_time_factor - - if dynamics["avg_message_lengths"]: - all_lengths = [ls for ls in dynamics["avg_message_lengths"].values()] - if all_lengths: - avg_length = sum(all_lengths) / len(all_lengths) - if not hasattr(cog, "channel_message_length"): - cog.channel_message_length = {} - length_factor = min(avg_length / 200, 1.0) - cog.channel_message_length[channel_id] = length_factor - - if dynamics["question_answer_count"] > 0: - if not hasattr(cog, "channel_qa_responsiveness"): - cog.channel_qa_responsiveness = {} - qa_factor = min(0.9, 0.5 + (dynamics["question_answer_count"] / 20) * 0.4) - cog.channel_qa_responsiveness[channel_id] = qa_factor - - except Exception as e: - print(f"Error adapting to conversation dynamics: {e}") - - -def extract_conversation_patterns( - cog: "WheatleyCog", messages: List[Dict[str, Any]] -) -> List[Dict[str, Any]]: # Updated type hint - """Extract patterns from a sequence of messages""" - # This function might be less useful without personality/goal systems, kept for potential future use. - patterns = [] - if len(messages) < 5: - return patterns - import datetime # Import here - - for i in range(len(messages) - 2): - pattern = { - "type": "message_sequence", - "messages": [ - { - "author_type": ( - "user" if not messages[i]["author"]["bot"] else "bot" - ), - "content_sample": messages[i]["content"][:50], - }, - { - "author_type": ( - "user" if not messages[i + 1]["author"]["bot"] else "bot" - ), - "content_sample": messages[i + 1]["content"][:50], - }, - { - "author_type": ( - "user" if not messages[i + 2]["author"]["bot"] else "bot" - ), - "content_sample": messages[i + 2]["content"][:50], - }, - ], - "timestamp": datetime.datetime.now().isoformat(), - } - patterns.append(pattern) - - topics = identify_conversation_topics(cog, messages) # Pass cog - if topics: - patterns.append( - { - "type": "topic_pattern", - "topics": topics, - "timestamp": datetime.datetime.now().isoformat(), - } - ) - - user_interactions = analyze_user_interactions(cog, messages) # Pass cog - if user_interactions: - patterns.append( - { - "type": "user_interaction", - "interactions": user_interactions, - "timestamp": datetime.datetime.now().isoformat(), - } - ) - - return patterns - - -def identify_conversation_topics( - cog: "WheatleyCog", messages: List[Dict[str, Any]] -) -> List[Dict[str, Any]]: # Updated type hint - """Identify potential topics from conversation messages.""" - if not messages or len(messages) < 3: - return [] - all_text = " ".join([msg["content"] for msg in messages]) - stopwords = { # Expanded stopwords - "the", - "and", - "is", - "in", - "to", - "a", - "of", - "for", - "that", - "this", - "it", - "with", - "on", - "as", - "be", - "at", - "by", - "an", - "or", - "but", - "if", - "from", - "when", - "where", - "how", - "all", - "any", - "both", - "each", - "few", - "more", - "most", - "some", - "such", - "no", - "nor", - "not", - "only", - "own", - "same", - "so", - "than", - "too", - "very", - "can", - "will", - "just", - "should", - "now", - "also", - "like", - "even", - "because", - "way", - "who", - "what", - "yeah", - "yes", - "no", - "nah", - "lol", - "lmao", - "haha", - "hmm", - "um", - "uh", - "oh", - "ah", - "ok", - "okay", - "dont", - "don't", - "doesnt", - "doesn't", - "didnt", - "didn't", - "cant", - "can't", - "im", - "i'm", - "ive", - "i've", - "youre", - "you're", - "youve", - "you've", - "hes", - "he's", - "shes", - "she's", - "its", - "it's", - "were", - "we're", - "weve", - "we've", - "theyre", - "they're", - "theyve", - "they've", - "thats", - "that's", - "whats", - "what's", - "whos", - "who's", - "gonna", - "gotta", - "kinda", - "sorta", - "wheatley", # Added wheatley, removed gurt - } - - def extract_ngrams(text, n_values=[1, 2, 3]): - words = re.findall(r"\b\w+\b", text.lower()) - filtered_words = [ - word for word in words if word not in stopwords and len(word) > 2 - ] - all_ngrams = [] - for n in n_values: - all_ngrams.extend( - [ - " ".join(filtered_words[i : i + n]) - for i in range(len(filtered_words) - n + 1) - ] - ) - return all_ngrams - - all_ngrams = extract_ngrams(all_text) - ngram_counts = defaultdict(int) - for ngram in all_ngrams: - ngram_counts[ngram] += 1 - - min_count = 2 if len(messages) > 10 else 1 - filtered_ngrams = { - ngram: count for ngram, count in ngram_counts.items() if count >= min_count - } - total_messages = len(messages) - ngram_scores = {} - for ngram, count in filtered_ngrams.items(): - # Calculate score based on frequency, length, and spread across messages - message_count = sum(1 for msg in messages if ngram in msg["content"].lower()) - spread_factor = ( - message_count / total_messages - ) ** 0.5 # Less emphasis on spread - length_bonus = len(ngram.split()) * 0.1 # Slight bonus for longer ngrams - # Adjust importance calculation - importance = (count * (0.4 + spread_factor)) + length_bonus - ngram_scores[ngram] = importance - - topics = [] - processed_ngrams = set() - # Filter out sub-ngrams that are part of higher-scoring ngrams before sorting - sorted_by_score = sorted(ngram_scores.items(), key=lambda x: x[1], reverse=True) - ngrams_to_consider = [] - temp_processed = set() - for ngram, score in sorted_by_score: - is_subgram = False - for other_ngram, _ in sorted_by_score: - if ngram != other_ngram and ngram in other_ngram: - is_subgram = True - break - if not is_subgram and ngram not in temp_processed: - ngrams_to_consider.append((ngram, score)) - temp_processed.add(ngram) # Avoid adding duplicates if logic changes - - # Now process the filtered ngrams - sorted_ngrams = ngrams_to_consider # Use the filtered list - - for ngram, score in sorted_ngrams[ - :10 - ]: # Consider top 10 potential topics after filtering - if ngram in processed_ngrams: - continue - related_terms = [] - # Find related terms (sub-ngrams or overlapping ngrams from the original sorted list) - for ( - other_ngram, - other_score, - ) in sorted_by_score: # Search in original sorted list for relations - if other_ngram == ngram or other_ngram in processed_ngrams: - continue - ngram_words = set(ngram.split()) - other_words = set(other_ngram.split()) - # Check for overlap or if one is a sub-string (more lenient relation) - if ngram_words.intersection(other_words) or other_ngram in ngram: - related_terms.append({"term": other_ngram, "score": other_score}) - # Don't mark related terms as fully processed here unless they are direct sub-ngrams - # processed_ngrams.add(other_ngram) - if len(related_terms) >= 3: - break # Limit related terms shown - processed_ngrams.add(ngram) - topic_entry = { - "topic": ngram, - "score": score, - "related_terms": related_terms, - "message_count": sum( - 1 for msg in messages if ngram in msg["content"].lower() - ), - } - topics.append(topic_entry) - if len(topics) >= MAX_ACTIVE_TOPICS: - break # Use config for max topics - - # Simple sentiment analysis for topics - positive_words = { - "good", - "great", - "awesome", - "amazing", - "excellent", - "love", - "like", - "best", - "better", - "nice", - "cool", - } - # Removed the second loop that seemed redundant - # sorted_ngrams = sorted(ngram_scores.items(), key=lambda x: x[1], reverse=True) - # for ngram, score in sorted_ngrams[:15]: ... - - # Simple sentiment analysis for topics (applied to the already selected topics) - positive_words = { - "good", - "great", - "awesome", - "amazing", - "excellent", - "love", - "like", - "best", - "better", - "nice", - "cool", - "happy", - "glad", - } - negative_words = { - "bad", - "terrible", - "awful", - "worst", - "hate", - "dislike", - "sucks", - "stupid", - "boring", - "annoying", - "sad", - "upset", - "angry", - } - for topic in topics: - topic_messages = [ - msg["content"] - for msg in messages - if topic["topic"] in msg["content"].lower() - ] - topic_text = " ".join(topic_messages).lower() - positive_count = sum(1 for word in positive_words if word in topic_text) - negative_count = sum(1 for word in negative_words if word in topic_text) - if positive_count > negative_count: - topic["sentiment"] = "positive" - elif negative_count > positive_count: - topic["sentiment"] = "negative" - else: - topic["sentiment"] = "neutral" - - return topics - - -def analyze_user_interactions( - cog: "WheatleyCog", messages: List[Dict[str, Any]] -) -> List[Dict[str, Any]]: # Updated type hint - """Analyze interactions between users in the conversation""" - interactions = [] - response_map = defaultdict(int) - for i in range(1, len(messages)): - current_msg = messages[i] - prev_msg = messages[i - 1] - if current_msg["author"]["id"] == prev_msg["author"]["id"]: - continue - responder = current_msg["author"]["id"] - respondee = prev_msg["author"]["id"] - key = f"{responder}:{respondee}" - response_map[key] += 1 - for key, count in response_map.items(): - if count > 1: - responder, respondee = key.split(":") - interactions.append( - {"responder": responder, "respondee": respondee, "count": count} - ) - return interactions - - -def update_user_preferences(cog: "WheatleyCog"): # Updated type hint - """Update stored user preferences based on observed interactions""" - # Note: This function previously updated preferences based on Gurt's personality. - # It might be removed or significantly simplified for Wheatley. - # Kept for now, but its effect might be minimal without personality traits. - for user_id, messages in cog.message_cache["by_user"].items(): - if len(messages) < 5: - continue - emoji_count = 0 - slang_count = 0 - avg_length = 0 - for msg in messages: - content = msg["content"] - emoji_count += len( - re.findall( - r"[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F700-\U0001F77F\U0001F780-\U0001F7FF\U0001F800-\U0001F8FF\U0001F900-\U0001F9FF\U0001FA00-\U0001FA6F\U0001FA70-\U0001FAFF\U00002702-\U000027B0\U000024C2-\U0001F251]", - content, - ) - ) - slang_words = [ - "ngl", - "icl", - "pmo", - "ts", - "bro", - "vro", - "bruh", - "tuff", - "kevin", - "mate", - "chap", - "bollocks", - ] # Added Wheatley-ish slang - for word in slang_words: - if re.search(r"\b" + word + r"\b", content.lower()): - slang_count += 1 - avg_length += len(content) - if messages: - avg_length /= len(messages) - - # Ensure user_preferences exists - if not hasattr(cog, "user_preferences"): - cog.user_preferences = defaultdict(dict) - - user_prefs = cog.user_preferences[user_id] - # Apply learning rate cautiously - if emoji_count > 0: - user_prefs["emoji_preference"] = ( - user_prefs.get("emoji_preference", 0.5) * (1 - LEARNING_RATE) - + (emoji_count / len(messages)) * LEARNING_RATE - ) - if slang_count > 0: - user_prefs["slang_preference"] = ( - user_prefs.get("slang_preference", 0.5) * (1 - LEARNING_RATE) - + (slang_count / len(messages)) * LEARNING_RATE - ) - user_prefs["length_preference"] = ( - user_prefs.get("length_preference", 50) * (1 - LEARNING_RATE) - + avg_length * LEARNING_RATE - ) - - -# --- Removed evolve_personality function --- -# --- Removed reflect_on_memories function --- -# --- Removed decompose_goal_into_steps function --- - - -def analyze_message_sentiment( - cog: "WheatleyCog", message_content: str -) -> Dict[str, Any]: # Updated type hint - """Analyzes the sentiment of a message using keywords and emojis.""" - content = message_content.lower() - result = { - "sentiment": "neutral", - "intensity": 0.5, - "emotions": [], - "confidence": 0.5, - } - - positive_emoji_count = sum( - 1 for emoji in EMOJI_SENTIMENT["positive"] if emoji in content - ) - negative_emoji_count = sum( - 1 for emoji in EMOJI_SENTIMENT["negative"] if emoji in content - ) - total_emoji_count = ( - positive_emoji_count - + negative_emoji_count - + sum(1 for emoji in EMOJI_SENTIMENT["neutral"] if emoji in content) - ) - - detected_emotions = [] - emotion_scores = {} - for emotion, keywords in EMOTION_KEYWORDS.items(): - emotion_count = sum( - 1 - for keyword in keywords - if re.search(r"\b" + re.escape(keyword) + r"\b", content) - ) - if emotion_count > 0: - emotion_score = min(1.0, emotion_count / len(keywords) * 2) - emotion_scores[emotion] = emotion_score - detected_emotions.append(emotion) - - if emotion_scores: - primary_emotion = max(emotion_scores.items(), key=lambda x: x[1]) - result["emotions"] = [primary_emotion[0]] - for emotion, score in emotion_scores.items(): - if emotion != primary_emotion[0] and score > primary_emotion[1] * 0.7: - result["emotions"].append(emotion) - - positive_emotions = ["joy"] - negative_emotions = ["sadness", "anger", "fear", "disgust"] - if primary_emotion[0] in positive_emotions: - result["sentiment"] = "positive" - result["intensity"] = primary_emotion[1] - elif primary_emotion[0] in negative_emotions: - result["sentiment"] = "negative" - result["intensity"] = primary_emotion[1] - else: - result["sentiment"] = "neutral" - result["intensity"] = 0.5 - result["confidence"] = min(0.9, 0.5 + primary_emotion[1] * 0.4) - - elif total_emoji_count > 0: - if positive_emoji_count > negative_emoji_count: - result["sentiment"] = "positive" - result["intensity"] = min( - 0.9, 0.5 + (positive_emoji_count / total_emoji_count) * 0.4 - ) - result["confidence"] = min( - 0.8, 0.4 + (positive_emoji_count / total_emoji_count) * 0.4 - ) - elif negative_emoji_count > positive_emoji_count: - result["sentiment"] = "negative" - result["intensity"] = min( - 0.9, 0.5 + (negative_emoji_count / total_emoji_count) * 0.4 - ) - result["confidence"] = min( - 0.8, 0.4 + (negative_emoji_count / total_emoji_count) * 0.4 - ) - else: - result["sentiment"] = "neutral" - result["intensity"] = 0.5 - result["confidence"] = 0.6 - - else: # Basic text fallback - positive_words = { - "good", - "great", - "awesome", - "amazing", - "excellent", - "love", - "like", - "best", - "better", - "nice", - "cool", - "happy", - "glad", - "thanks", - "thank", - "appreciate", - "wonderful", - "fantastic", - "perfect", - "beautiful", - "fun", - "enjoy", - "yes", - "yep", - } - negative_words = { - "bad", - "terrible", - "awful", - "worst", - "hate", - "dislike", - "sucks", - "stupid", - "boring", - "annoying", - "sad", - "upset", - "angry", - "mad", - "disappointed", - "sorry", - "unfortunate", - "horrible", - "ugly", - "wrong", - "fail", - "no", - "nope", - } - words = re.findall(r"\b\w+\b", content) - positive_count = sum(1 for word in words if word in positive_words) - negative_count = sum(1 for word in words if word in negative_words) - if positive_count > negative_count: - result["sentiment"] = "positive" - result["intensity"] = min( - 0.8, 0.5 + (positive_count / len(words)) * 2 if words else 0 - ) - result["confidence"] = min( - 0.7, 0.3 + (positive_count / len(words)) * 0.4 if words else 0 - ) - elif negative_count > positive_count: - result["sentiment"] = "negative" - result["intensity"] = min( - 0.8, 0.5 + (negative_count / len(words)) * 2 if words else 0 - ) - result["confidence"] = min( - 0.7, 0.3 + (negative_count / len(words)) * 0.4 if words else 0 - ) - else: - result["sentiment"] = "neutral" - result["intensity"] = 0.5 - result["confidence"] = 0.5 - - return result - - -def update_conversation_sentiment( - cog: "WheatleyCog", channel_id: int, user_id: str, message_sentiment: Dict[str, Any] -): # Updated type hint - """Updates the conversation sentiment tracking based on a new message's sentiment.""" - channel_sentiment = cog.conversation_sentiment[channel_id] - now = time.time() - - # Ensure sentiment_update_interval exists on cog, default if not - sentiment_update_interval = getattr( - cog, "sentiment_update_interval", 300 - ) # Default to 300s if not set - - if now - channel_sentiment["last_update"] > sentiment_update_interval: - if channel_sentiment["overall"] == "positive": - channel_sentiment["intensity"] = max( - 0.5, channel_sentiment["intensity"] - SENTIMENT_DECAY_RATE - ) - elif channel_sentiment["overall"] == "negative": - channel_sentiment["intensity"] = max( - 0.5, channel_sentiment["intensity"] - SENTIMENT_DECAY_RATE - ) - channel_sentiment["recent_trend"] = "stable" - channel_sentiment["last_update"] = now - - user_sentiment = channel_sentiment["user_sentiments"].get( - user_id, {"sentiment": "neutral", "intensity": 0.5} - ) - confidence_weight = message_sentiment["confidence"] - if user_sentiment["sentiment"] == message_sentiment["sentiment"]: - new_intensity = ( - user_sentiment["intensity"] * 0.7 + message_sentiment["intensity"] * 0.3 - ) - user_sentiment["intensity"] = min(0.95, new_intensity) - else: - if message_sentiment["confidence"] > 0.7: - user_sentiment["sentiment"] = message_sentiment["sentiment"] - user_sentiment["intensity"] = ( - message_sentiment["intensity"] * 0.7 + user_sentiment["intensity"] * 0.3 - ) - else: - if message_sentiment["intensity"] > user_sentiment["intensity"]: - user_sentiment["sentiment"] = message_sentiment["sentiment"] - user_sentiment["intensity"] = ( - user_sentiment["intensity"] * 0.6 - + message_sentiment["intensity"] * 0.4 - ) - - user_sentiment["emotions"] = message_sentiment.get("emotions", []) - channel_sentiment["user_sentiments"][user_id] = user_sentiment - - # Update overall based on active users (simplified access to active_conversations) - active_user_sentiments = [ - s - for uid, s in channel_sentiment["user_sentiments"].items() - if uid - in cog.active_conversations.get(channel_id, {}).get("participants", set()) - ] - if active_user_sentiments: - sentiment_counts = defaultdict(int) - for s in active_user_sentiments: - sentiment_counts[s["sentiment"]] += 1 - dominant_sentiment = max(sentiment_counts.items(), key=lambda x: x[1])[0] - avg_intensity = ( - sum( - s["intensity"] - for s in active_user_sentiments - if s["sentiment"] == dominant_sentiment - ) - / sentiment_counts[dominant_sentiment] - ) - - prev_sentiment = channel_sentiment["overall"] - prev_intensity = channel_sentiment["intensity"] - if dominant_sentiment == prev_sentiment: - if avg_intensity > prev_intensity + 0.1: - channel_sentiment["recent_trend"] = "intensifying" - elif avg_intensity < prev_intensity - 0.1: - channel_sentiment["recent_trend"] = "diminishing" - else: - channel_sentiment["recent_trend"] = "stable" - else: - channel_sentiment["recent_trend"] = "changing" - channel_sentiment["overall"] = dominant_sentiment - channel_sentiment["intensity"] = avg_intensity - - channel_sentiment["last_update"] = now - # No need to reassign cog.conversation_sentiment[channel_id] as it's modified in place diff --git a/wheatley/api.py b/wheatley/api.py deleted file mode 100644 index 0890192..0000000 --- a/wheatley/api.py +++ /dev/null @@ -1,1500 +0,0 @@ -import discord -import aiohttp -import asyncio -import json -import base64 -import re -import time -import datetime -from typing import TYPE_CHECKING, Optional, List, Dict, Any, Union, AsyncIterable -import jsonschema # For manual JSON validation -from .tools import get_conversation_summary - -# Vertex AI Imports -try: - import vertexai - from vertexai import generative_models - from vertexai.generative_models import ( - GenerativeModel, - GenerationConfig, - Part, - Content, - Tool, - FunctionDeclaration, - GenerationResponse, - FinishReason, - ) - from google.api_core import exceptions as google_exceptions - from google.cloud.storage import Client as GCSClient # For potential image uploads -except ImportError: - print( - "WARNING: google-cloud-vertexai or google-cloud-storage not installed. API calls will fail." - ) - - # Define dummy classes/exceptions if library isn't installed - class DummyGenerativeModel: - def __init__(self, model_name, system_instruction=None, tools=None): - pass - - async def generate_content_async( - self, contents, generation_config=None, safety_settings=None, stream=False - ): - return None - - GenerativeModel = DummyGenerativeModel - - class DummyPart: - @staticmethod - def from_text(text): - return None - - @staticmethod - def from_data(data, mime_type): - return None - - @staticmethod - def from_uri(uri, mime_type): - return None - - @staticmethod - def from_function_response(name, response): - return None - - Part = DummyPart - Content = dict - Tool = list - FunctionDeclaration = object - GenerationConfig = dict - GenerationResponse = object - FinishReason = object - - class DummyGoogleExceptions: - ResourceExhausted = type("ResourceExhausted", (Exception,), {}) - InternalServerError = type("InternalServerError", (Exception,), {}) - ServiceUnavailable = type("ServiceUnavailable", (Exception,), {}) - InvalidArgument = type("InvalidArgument", (Exception,), {}) - GoogleAPICallError = type( - "GoogleAPICallError", (Exception,), {} - ) # Generic fallback - - google_exceptions = DummyGoogleExceptions() - - -# Relative imports for components within the 'gurt' package -from .config import ( - PROJECT_ID, - LOCATION, - DEFAULT_MODEL, - FALLBACK_MODEL, - API_TIMEOUT, - API_RETRY_ATTEMPTS, - API_RETRY_DELAY, - TOOLS, - RESPONSE_SCHEMA, - PROACTIVE_PLAN_SCHEMA, # Import the new schema - TAVILY_API_KEY, - PISTON_API_URL, - PISTON_API_KEY, # Import other needed configs -) -from .prompt import build_dynamic_system_prompt -from .context import ( - gather_conversation_context, - get_memory_context, -) # Renamed functions -from .tools import TOOL_MAPPING # Import tool mapping -from .utils import format_message, log_internal_api_call # Import utilities -import copy # Needed for deep copying schemas - -if TYPE_CHECKING: - from .cog import WheatleyCog # Import WheatleyCog for type hinting only - - -# --- Schema Preprocessing Helper --- -def _preprocess_schema_for_vertex(schema: Dict[str, Any]) -> Dict[str, Any]: - """ - Recursively preprocesses a JSON schema dictionary to replace list types - (like ["string", "null"]) with the first non-null type, making it - compatible with Vertex AI's GenerationConfig schema requirements. - - Args: - schema: The JSON schema dictionary to preprocess. - - Returns: - A new, preprocessed schema dictionary. - """ - if not isinstance(schema, dict): - return schema # Return non-dict elements as is - - processed_schema = copy.deepcopy(schema) # Work on a copy - - for key, value in processed_schema.items(): - if key == "type" and isinstance(value, list): - # Find the first non-"null" type in the list - first_valid_type = next( - (t for t in value if isinstance(t, str) and t.lower() != "null"), None - ) - if first_valid_type: - processed_schema[key] = first_valid_type - else: - # Fallback if only "null" or invalid types are present (shouldn't happen in valid schemas) - processed_schema[key] = "object" # Or handle as error - print( - f"Warning: Schema preprocessing found list type '{value}' with no valid non-null string type. Falling back to 'object'." - ) - elif isinstance(value, dict): - processed_schema[key] = _preprocess_schema_for_vertex( - value - ) # Recurse for nested objects - elif isinstance(value, list): - # Recurse for items within arrays (e.g., in 'properties' of array items) - processed_schema[key] = [ - _preprocess_schema_for_vertex(item) if isinstance(item, dict) else item - for item in value - ] - # Handle 'properties' specifically - elif key == "properties" and isinstance(value, dict): - processed_schema[key] = { - prop_key: _preprocess_schema_for_vertex(prop_value) - for prop_key, prop_value in value.items() - } - # Handle 'items' specifically if it's a schema object - elif key == "items" and isinstance(value, dict): - processed_schema[key] = _preprocess_schema_for_vertex(value) - - return processed_schema - - -# --- Helper Function to Safely Extract Text --- -def _get_response_text(response: Optional["GenerationResponse"]) -> Optional[str]: - """Safely extracts the text content from the first text part of a GenerationResponse.""" - if not response or not response.candidates: - return None - try: - # Iterate through parts to find the first text part - for part in response.candidates[0].content.parts: - # Check if the part has a 'text' attribute and it's not empty - if hasattr(part, "text") and part.text: - return part.text - # If no text part is found (e.g., only function call or empty text parts) - print( - f"[_get_response_text] No text part found in candidate parts: {response.candidates[0].content.parts}" - ) # Log parts structure - return None - except (AttributeError, IndexError) as e: - # Handle cases where structure is unexpected or parts list is empty - print(f"Error accessing response parts: {type(e).__name__}: {e}") - return None - except Exception as e: - # Catch unexpected errors during access - print(f"Unexpected error extracting text from response part: {e}") - return None - - -# --- Initialize Vertex AI --- -try: - vertexai.init(project=PROJECT_ID, location=LOCATION) - print(f"Vertex AI initialized for project '{PROJECT_ID}' in location '{LOCATION}'.") -except NameError: - print("Vertex AI SDK not imported, skipping initialization.") -except Exception as e: - print(f"Error initializing Vertex AI: {e}") - -# --- Constants --- -# Define standard safety settings (adjust as needed) -# Use actual types if import succeeded, otherwise fallback to Any -_HarmCategory = getattr(generative_models, "HarmCategory", Any) -_HarmBlockThreshold = getattr(generative_models, "HarmBlockThreshold", Any) -STANDARD_SAFETY_SETTINGS = { - getattr( - _HarmCategory, "HARM_CATEGORY_HATE_SPEECH", "HARM_CATEGORY_HATE_SPEECH" - ): getattr(_HarmBlockThreshold, "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE"), - getattr( - _HarmCategory, - "HARM_CATEGORY_DANGEROUS_CONTENT", - "HARM_CATEGORY_DANGEROUS_CONTENT", - ): getattr(_HarmBlockThreshold, "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE"), - getattr( - _HarmCategory, - "HARM_CATEGORY_SEXUALLY_EXPLICIT", - "HARM_CATEGORY_SEXUALLY_EXPLICIT", - ): getattr(_HarmBlockThreshold, "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE"), - getattr( - _HarmCategory, "HARM_CATEGORY_HARASSMENT", "HARM_CATEGORY_HARASSMENT" - ): getattr(_HarmBlockThreshold, "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE"), -} - - -# --- API Call Helper --- -async def call_vertex_api_with_retry( - cog: "WheatleyCog", - model: "GenerativeModel", # Use string literal for type hint - contents: List["Content"], # Use string literal for type hint - generation_config: "GenerationConfig", # Use string literal for type hint - safety_settings: Optional[Dict[Any, Any]], # Use Any for broader compatibility - request_desc: str, - stream: bool = False, -) -> Union[ - "GenerationResponse", AsyncIterable["GenerationResponse"], None -]: # Use string literals - """ - Calls the Vertex AI Gemini API with retry logic. - - Args: - cog: The WheatleyCog instance. - model: The initialized GenerativeModel instance. - contents: The list of Content objects for the prompt. - generation_config: The GenerationConfig object. - safety_settings: Safety settings for the request. - request_desc: A description of the request for logging purposes. - stream: Whether to stream the response. - - Returns: - The GenerationResponse object or an AsyncIterable if streaming, or None on failure. - - Raises: - Exception: If the API call fails after all retry attempts or encounters a non-retryable error. - """ - last_exception = None - model_name = model._model_name # Get model name for logging - start_time = time.monotonic() - - for attempt in range(API_RETRY_ATTEMPTS + 1): - try: - print( - f"Sending API request for {request_desc} using {model_name} (Attempt {attempt + 1}/{API_RETRY_ATTEMPTS + 1})..." - ) - - response = await model.generate_content_async( - contents=contents, - generation_config=generation_config, - safety_settings=safety_settings or STANDARD_SAFETY_SETTINGS, - stream=stream, - ) - - # --- Success Logging --- - elapsed_time = time.monotonic() - start_time - # Ensure model_name exists in stats before incrementing - if model_name not in cog.api_stats: - cog.api_stats[model_name] = { - "success": 0, - "failure": 0, - "retries": 0, - "total_time": 0.0, - "count": 0, - } - cog.api_stats[model_name]["success"] += 1 - cog.api_stats[model_name]["total_time"] += elapsed_time - cog.api_stats[model_name]["count"] += 1 - print( - f"API request successful for {request_desc} ({model_name}) in {elapsed_time:.2f}s." - ) - return response # Success - - except google_exceptions.ResourceExhausted as e: - error_msg = f"Rate limit error (ResourceExhausted) for {request_desc}: {e}" - print(f"{error_msg} (Attempt {attempt + 1})") - last_exception = e - if attempt < API_RETRY_ATTEMPTS: - if model_name not in cog.api_stats: - cog.api_stats[model_name] = { - "success": 0, - "failure": 0, - "retries": 0, - "total_time": 0.0, - "count": 0, - } - cog.api_stats[model_name]["retries"] += 1 - wait_time = API_RETRY_DELAY * (2**attempt) # Exponential backoff - print(f"Waiting {wait_time:.2f} seconds before retrying...") - await asyncio.sleep(wait_time) - continue - else: - break # Max retries reached - - except ( - google_exceptions.InternalServerError, - google_exceptions.ServiceUnavailable, - ) as e: - error_msg = f"API server error ({type(e).__name__}) for {request_desc}: {e}" - print(f"{error_msg} (Attempt {attempt + 1})") - last_exception = e - if attempt < API_RETRY_ATTEMPTS: - if model_name not in cog.api_stats: - cog.api_stats[model_name] = { - "success": 0, - "failure": 0, - "retries": 0, - "total_time": 0.0, - "count": 0, - } - cog.api_stats[model_name]["retries"] += 1 - wait_time = API_RETRY_DELAY * (2**attempt) # Exponential backoff - print(f"Waiting {wait_time:.2f} seconds before retrying...") - await asyncio.sleep(wait_time) - continue - else: - break # Max retries reached - - except google_exceptions.InvalidArgument as e: - # Often indicates a problem with the request itself (e.g., bad schema, unsupported format) - error_msg = f"Invalid argument error for {request_desc}: {e}" - print(error_msg) - last_exception = e - break # Non-retryable - - except ( - asyncio.TimeoutError - ): # Handle potential client-side timeouts if applicable - error_msg = f"Client-side request timed out for {request_desc} (Attempt {attempt + 1})" - print(error_msg) - last_exception = asyncio.TimeoutError(error_msg) - # Decide if client-side timeouts should be retried - if attempt < API_RETRY_ATTEMPTS: - if model_name not in cog.api_stats: - cog.api_stats[model_name] = { - "success": 0, - "failure": 0, - "retries": 0, - "total_time": 0.0, - "count": 0, - } - cog.api_stats[model_name]["retries"] += 1 - await asyncio.sleep(API_RETRY_DELAY * (attempt + 1)) - continue - else: - break - - except Exception as e: # Catch other potential exceptions - error_msg = f"Unexpected error during API call for {request_desc} (Attempt {attempt + 1}): {type(e).__name__}: {e}" - print(error_msg) - import traceback - - traceback.print_exc() - last_exception = e - # Decide if this generic exception is retryable - # For now, treat unexpected errors as non-retryable - break - - # --- Failure Logging --- - elapsed_time = time.monotonic() - start_time - if model_name not in cog.api_stats: - cog.api_stats[model_name] = { - "success": 0, - "failure": 0, - "retries": 0, - "total_time": 0.0, - "count": 0, - } - cog.api_stats[model_name]["failure"] += 1 - cog.api_stats[model_name]["total_time"] += elapsed_time - cog.api_stats[model_name]["count"] += 1 - print( - f"API request failed for {request_desc} ({model_name}) after {attempt + 1} attempts in {elapsed_time:.2f}s." - ) - - # Raise the last encountered exception or a generic one - raise last_exception or Exception( - f"API request failed for {request_desc} after {API_RETRY_ATTEMPTS + 1} attempts." - ) - - -# --- JSON Parsing and Validation Helper --- -def parse_and_validate_json_response( - response_text: Optional[str], schema: Dict[str, Any], context_description: str -) -> Optional[Dict[str, Any]]: - """ - Parses the AI's response text, attempting to extract and validate a JSON object against a schema. - - Args: - response_text: The raw text content from the AI response. - schema: The JSON schema (as a dictionary) to validate against. - context_description: A description for logging purposes. - - Returns: - A parsed and validated dictionary if successful, None otherwise. - """ - if response_text is None: - print(f"Parsing ({context_description}): Response text is None.") - return None - - parsed_data = None - raw_json_text = response_text # Start with the full text - - # Attempt 1: Try parsing the whole string directly - try: - parsed_data = json.loads(raw_json_text) - print( - f"Parsing ({context_description}): Successfully parsed entire response as JSON." - ) - except json.JSONDecodeError: - # Attempt 2: Extract JSON object, handling optional markdown fences - # More robust regex to handle potential leading/trailing text and variations - json_match = re.search( - r"```(?:json)?\s*(\{.*\})\s*```|(\{.*\})", - response_text, - re.DOTALL | re.MULTILINE, - ) - if json_match: - json_str = json_match.group(1) or json_match.group(2) - if json_str: - raw_json_text = json_str # Use the extracted string for parsing - try: - parsed_data = json.loads(raw_json_text) - print( - f"Parsing ({context_description}): Successfully extracted and parsed JSON using regex." - ) - except json.JSONDecodeError as e_inner: - print( - f"Parsing ({context_description}): Regex found potential JSON, but it failed to parse: {e_inner}\nContent: {raw_json_text[:500]}" - ) - parsed_data = None - else: - print( - f"Parsing ({context_description}): Regex matched, but failed to capture JSON content." - ) - parsed_data = None - else: - print( - f"Parsing ({context_description}): Could not parse directly or extract JSON object using regex.\nContent: {raw_json_text[:500]}" - ) - parsed_data = None - - # Validation step - if parsed_data is not None: - if not isinstance(parsed_data, dict): - print( - f"Parsing ({context_description}): Parsed data is not a dictionary: {type(parsed_data)}" - ) - return None # Fail validation if not a dict - - try: - jsonschema.validate(instance=parsed_data, schema=schema) - print( - f"Parsing ({context_description}): JSON successfully validated against schema." - ) - # Ensure default keys exist after validation - parsed_data.setdefault("should_respond", False) - parsed_data.setdefault("content", None) - parsed_data.setdefault("react_with_emoji", None) - return parsed_data - except jsonschema.ValidationError as e: - print( - f"Parsing ({context_description}): JSON failed schema validation: {e.message}" - ) - # Optionally log more details: e.path, e.schema_path, e.instance - return None # Validation failed - except Exception as e: # Catch other potential validation errors - print( - f"Parsing ({context_description}): Unexpected error during JSON schema validation: {e}" - ) - return None - else: - # Parsing failed before validation could occur - return None - - -# --- Tool Processing --- -async def process_requested_tools( - cog: "WheatleyCog", function_call: "generative_models.FunctionCall" -) -> "Part": # Use string literals - """ - Process a tool request specified by the AI's FunctionCall response. - - Args: - cog: The WheatleyCog instance. - function_call: The FunctionCall object from the GenerationResponse. - - Returns: - A Part object containing the tool result or error, formatted for the follow-up API call. - """ - function_name = function_call.name - # Convert the Struct field arguments to a standard Python dict - function_args = dict(function_call.args) if function_call.args else {} - tool_result_content = None - - print(f"Processing tool request: {function_name} with args: {function_args}") - tool_start_time = time.monotonic() - - if function_name in TOOL_MAPPING: - try: - tool_func = TOOL_MAPPING[function_name] - # Execute the mapped function - # Ensure the function signature matches the expected arguments - # Pass cog if the tool implementation requires it - result = await tool_func(cog, **function_args) - - # --- Tool Success Logging --- - tool_elapsed_time = time.monotonic() - tool_start_time - if function_name not in cog.tool_stats: - cog.tool_stats[function_name] = { - "success": 0, - "failure": 0, - "total_time": 0.0, - "count": 0, - } - cog.tool_stats[function_name]["success"] += 1 - cog.tool_stats[function_name]["total_time"] += tool_elapsed_time - cog.tool_stats[function_name]["count"] += 1 - print( - f"Tool '{function_name}' executed successfully in {tool_elapsed_time:.2f}s." - ) - - # Prepare result for API - must be JSON serializable, typically a dict - if not isinstance(result, dict): - # Attempt to convert common types or wrap in a dict - if isinstance(result, (str, int, float, bool, list)) or result is None: - result = {"result": result} - else: - print( - f"Warning: Tool '{function_name}' returned non-standard type {type(result)}. Attempting str conversion." - ) - result = {"result": str(result)} - - tool_result_content = result - - except Exception as e: - # --- Tool Failure Logging --- - tool_elapsed_time = time.monotonic() - tool_start_time - if function_name not in cog.tool_stats: - cog.tool_stats[function_name] = { - "success": 0, - "failure": 0, - "total_time": 0.0, - "count": 0, - } - cog.tool_stats[function_name]["failure"] += 1 - cog.tool_stats[function_name]["total_time"] += tool_elapsed_time - cog.tool_stats[function_name]["count"] += 1 - error_message = ( - f"Error executing tool {function_name}: {type(e).__name__}: {str(e)}" - ) - print(f"{error_message} (Took {tool_elapsed_time:.2f}s)") - import traceback - - traceback.print_exc() - tool_result_content = {"error": error_message} - else: - # --- Tool Not Found Logging --- - tool_elapsed_time = time.monotonic() - tool_start_time - # Log attempt even if tool not found - if function_name not in cog.tool_stats: - cog.tool_stats[function_name] = { - "success": 0, - "failure": 0, - "total_time": 0.0, - "count": 0, - } - cog.tool_stats[function_name]["failure"] += 1 - cog.tool_stats[function_name]["total_time"] += tool_elapsed_time - cog.tool_stats[function_name]["count"] += 1 - error_message = f"Tool '{function_name}' not found or implemented." - print(f"{error_message} (Took {tool_elapsed_time:.2f}s)") - tool_result_content = {"error": error_message} - - # Return the result formatted as a Part for the API - return Part.from_function_response(name=function_name, response=tool_result_content) - - -# --- Main AI Response Function --- -async def get_ai_response( - cog: "WheatleyCog", message: discord.Message, model_name: Optional[str] = None -) -> Dict[str, Any]: - """ - Gets responses from the Vertex AI Gemini API, handling potential tool usage and returning - the final parsed response. - - Args: - cog: The WheatleyCog instance. - message: The triggering discord.Message. - model_name: Optional override for the AI model name (e.g., "gemini-1.5-pro-preview-0409"). - - Returns: - A dictionary containing: - - "final_response": Parsed JSON data from the final AI call (or None if parsing/validation fails). - - "error": An error message string if a critical error occurred, otherwise None. - - "fallback_initial": Optional minimal response if initial parsing failed critically (less likely with controlled generation). - """ - if not PROJECT_ID or not LOCATION: - return { - "final_response": None, - "error": "Google Cloud Project ID or Location not configured", - } - - channel_id = message.channel.id - user_id = message.author.id - initial_parsed_data = None # Added to store initial parsed result - final_parsed_data = None - error_message = None - fallback_response = None - - try: - # --- Build Prompt Components --- - final_system_prompt = await build_dynamic_system_prompt(cog, message) - conversation_context_messages = gather_conversation_context( - cog, channel_id, message.id - ) # Pass cog - memory_context = await get_memory_context(cog, message) # Pass cog - - # --- Initialize Model --- - # Tools are passed during model initialization in Vertex AI SDK - # Combine tool declarations into a Tool object - vertex_tool = Tool(function_declarations=TOOLS) if TOOLS else None - - model = GenerativeModel( - model_name or DEFAULT_MODEL, - system_instruction=final_system_prompt, - tools=[vertex_tool] if vertex_tool else None, - ) - - # --- Prepare Message History (Contents) --- - contents: List[Content] = [] - - # Add memory context if available - if memory_context: - # System messages aren't directly supported in the 'contents' list for multi-turn like OpenAI. - # It's better handled via the 'system_instruction' parameter of GenerativeModel. - # We might prepend it to the first user message or handle it differently if needed. - # For now, we rely on system_instruction. Let's log if we have memory context. - print("Memory context available, relying on system_instruction.") - # If needed, could potentially add as a 'model' role message before user messages, - # but this might confuse the turn structure. - # contents.append(Content(role="model", parts=[Part.from_text(f"System Note: {memory_context}")])) - - # Add conversation history - for msg in conversation_context_messages: - role = msg.get("role", "user") # Default to user if role missing - # Map roles if necessary (e.g., 'assistant' -> 'model') - if role == "assistant": - role = "model" - elif role == "system": - # Skip system messages here, handled by system_instruction - continue - # Handle potential multimodal content in history (if stored that way) - if isinstance(msg.get("content"), list): - parts = [ - ( - Part.from_text(part["text"]) - if part["type"] == "text" - else ( - Part.from_uri( - part["image_url"]["url"], - mime_type=part["image_url"]["url"] - .split(";")[0] - .split(":")[1], - ) - if part["type"] == "image_url" - else None - ) - ) - for part in msg["content"] - ] - parts = [p for p in parts if p] # Filter out None parts - if parts: - contents.append(Content(role=role, parts=parts)) - elif isinstance(msg.get("content"), str): - contents.append( - Content(role=role, parts=[Part.from_text(msg["content"])]) - ) - - # --- Prepare the current message content (potentially multimodal) --- - current_message_parts = [] - formatted_current_message = format_message(cog, message) # Pass cog if needed - - # --- Construct text content, including reply context if applicable --- - text_content = "" - if formatted_current_message.get("is_reply") and formatted_current_message.get( - "replied_to_author_name" - ): - reply_author = formatted_current_message["replied_to_author_name"] - reply_content = formatted_current_message.get( - "replied_to_content", "..." - ) # Use ellipsis if content missing - # Truncate long replied content to keep context concise - max_reply_len = 150 - if len(reply_content) > max_reply_len: - reply_content = reply_content[:max_reply_len] + "..." - text_content += f'(Replying to {reply_author}: "{reply_content}")\n' - - # Add current message author and content - text_content += f"{formatted_current_message['author']['display_name']}: {formatted_current_message['content']}" - - # Add mention details - if formatted_current_message.get("mentioned_users_details"): - mentions_str = ", ".join( - [ - f"{m['display_name']}(id:{m['id']})" - for m in formatted_current_message["mentioned_users_details"] - ] - ) - text_content += f"\n(Message Details: Mentions=[{mentions_str}])" - - current_message_parts.append(Part.from_text(text_content)) - # --- End text content construction --- - - if message.attachments: - print( - f"Processing {len(message.attachments)} attachments for message {message.id}" - ) - for attachment in message.attachments: - mime_type = attachment.content_type - file_url = attachment.url - filename = attachment.filename - - # Check if MIME type is supported for URI input by Gemini - # Expand this list based on Gemini documentation for supported types via URI - supported_mime_prefixes = [ - "image/", - "video/", - "audio/", - "text/plain", - "application/pdf", - ] - is_supported = False - if mime_type: - for prefix in supported_mime_prefixes: - if mime_type.startswith(prefix): - is_supported = True - break - # Add specific non-prefixed types if needed - # if mime_type in ["application/vnd.google-apps.document", ...]: - # is_supported = True - - if is_supported and file_url: - try: - # 1. Add text part instructing AI about the file - instruction_text = f"User attached a file: '{filename}' (Type: {mime_type}). Analyze this file from the following URI and incorporate your understanding into your response." - current_message_parts.append(Part.from_text(instruction_text)) - print(f"Added text instruction for attachment: {filename}") - - # 2. Add the URI part - # Ensure mime_type doesn't contain parameters like '; charset=...' if the API doesn't like them - clean_mime_type = mime_type.split(";")[0] - current_message_parts.append( - Part.from_uri(uri=file_url, mime_type=clean_mime_type) - ) - print( - f"Added URI part for attachment: {filename} ({clean_mime_type}) using URL: {file_url}" - ) - - except Exception as e: - print( - f"Error creating Part for attachment {filename} ({mime_type}): {e}" - ) - # Optionally add a text part indicating the error - current_message_parts.append( - Part.from_text( - f"(System Note: Failed to process attachment '{filename}' - {e})" - ) - ) - else: - print( - f"Skipping unsupported or invalid attachment: {filename} (Type: {mime_type}, URL: {file_url})" - ) - # Optionally inform the AI that an unsupported file was attached - current_message_parts.append( - Part.from_text( - f"(System Note: User attached an unsupported file '{filename}' of type '{mime_type}' which cannot be processed.)" - ) - ) - - # Ensure there's always *some* content part, even if only text or errors - if current_message_parts: - contents.append(Content(role="user", parts=current_message_parts)) - else: - print("Warning: No content parts generated for user message.") - contents.append(Content(role="user", parts=[Part.from_text("")])) - - # --- First API Call (Check for Tool Use) --- - print("Making initial API call to check for tool use...") - generation_config_initial = GenerationConfig( - temperature=0.75, - max_output_tokens=10000, # Adjust as needed - # No response schema needed for the initial call, just checking for function calls - ) - - initial_response = await call_vertex_api_with_retry( - cog=cog, - model=model, - contents=contents, - generation_config=generation_config_initial, - safety_settings=STANDARD_SAFETY_SETTINGS, - request_desc=f"Initial response check for message {message.id}", - ) - - # --- Log Raw Request and Response --- - try: - # Log the request payload (contents) - request_payload_log = [ - {"role": c.role, "parts": [str(p) for p in c.parts]} for c in contents - ] # Convert parts to string for logging - print( - f"--- Raw API Request (Initial Call) ---\n{json.dumps(request_payload_log, indent=2)}\n------------------------------------" - ) - # Log the raw response object - print( - f"--- Raw API Response (Initial Call) ---\n{initial_response}\n-----------------------------------" - ) - except Exception as log_e: - print(f"Error logging raw request/response: {log_e}") - # --- End Logging --- - - if not initial_response or not initial_response.candidates: - raise Exception("Initial API call returned no response or candidates.") - - # --- Check for Tool Call FIRST --- - candidate = initial_response.candidates[0] - finish_reason = getattr(candidate, "finish_reason", None) - function_call = None - function_call_part_content = None # Store the AI's request message content - - # Check primarily for the *presence* of a function call part, - # as finish_reason might be STOP even with a function call. - if hasattr(candidate, "content") and candidate.content.parts: - for part in candidate.content.parts: - if hasattr(part, "function_call"): - function_call = part.function_call # Assign the value - # Add check to ensure function_call is not None before proceeding - if function_call: - # Store the whole content containing the call to add to history later - function_call_part_content = candidate.content - print( - f"AI requested tool (found function_call part): {function_call.name}" - ) - break # Found a valid function call part - else: - # Log if the attribute exists but is None (unexpected case) - print( - "Warning: Found part with 'function_call' attribute, but its value was None." - ) - - # --- Process Tool Call or Handle Direct Response --- - if function_call and function_call_part_content: - # --- Tool Call Path --- - initial_parsed_data = None # No initial JSON expected if tool is called - - # Process the tool request - tool_response_part = await process_requested_tools(cog, function_call) - - # Append the AI's request and the tool's response to the history - contents.append( - candidate.content - ) # Add the AI's function call request message - contents.append( - Content(role="function", parts=[tool_response_part]) - ) # Add the function response part - - # --- Second API Call (Get Final Response After Tool) --- - print("Making follow-up API call with tool results...") - - # Initialize a NEW model instance WITHOUT tools for the follow-up call - # This prevents the InvalidArgument error when specifying response schema - model_final = GenerativeModel( - model_name or DEFAULT_MODEL, # Use the same model name - system_instruction=final_system_prompt, # Keep the same system prompt - # Omit the 'tools' parameter here - ) - - # Preprocess the schema before passing it to GenerationConfig - processed_response_schema = _preprocess_schema_for_vertex( - RESPONSE_SCHEMA["schema"] - ) - generation_config_final = GenerationConfig( - temperature=0.75, # Keep original temperature for final response - max_output_tokens=10000, # Keep original max tokens - response_mime_type="application/json", - response_schema=processed_response_schema, # Use preprocessed schema - ) - - final_response_obj = await call_vertex_api_with_retry( # Renamed variable for clarity - cog=cog, - model=model_final, # Use the new model instance WITHOUT tools - contents=contents, # History now includes tool call/response - generation_config=generation_config_final, - safety_settings=STANDARD_SAFETY_SETTINGS, - request_desc=f"Follow-up response for message {message.id} after tool execution", - ) - - if not final_response_obj or not final_response_obj.candidates: - raise Exception( - "Follow-up API call returned no response or candidates." - ) - - final_response_text = _get_response_text(final_response_obj) # Use helper - final_parsed_data = parse_and_validate_json_response( - final_response_text, - RESPONSE_SCHEMA["schema"], - "final response after tools", - ) - - # Handle validation failure - Re-prompt loop (simplified example) - if final_parsed_data is None: - print( - "Warning: Final response failed validation. Attempting re-prompt (basic)..." - ) - # Construct a basic re-prompt message - contents.append( - final_response_obj.candidates[0].content - ) # Add the invalid response - contents.append( - Content( - role="user", - parts=[ - Part.from_text( - "Your previous JSON response was invalid or did not match the required schema. " - f"Please provide the response again, strictly adhering to this schema:\n{json.dumps(RESPONSE_SCHEMA['schema'], indent=2)}" - ) - ], - ) - ) - - # Retry the final call - retry_response_obj = await call_vertex_api_with_retry( - cog=cog, - model=model, - contents=contents, - generation_config=generation_config_final, - safety_settings=STANDARD_SAFETY_SETTINGS, - request_desc=f"Re-prompt validation failure for message {message.id}", - ) - if retry_response_obj and retry_response_obj.candidates: - final_response_text = _get_response_text( - retry_response_obj - ) # Use helper - final_parsed_data = parse_and_validate_json_response( - final_response_text, - RESPONSE_SCHEMA["schema"], - "re-prompted final response", - ) - if final_parsed_data is None: - print( - "Critical Error: Re-prompted response still failed validation." - ) - error_message = ( - "Failed to get valid JSON response after re-prompting." - ) - else: - error_message = "Failed to get response after re-prompting." - # final_parsed_data is now set (or None if failed) after tool use and potential re-prompt - - else: - # --- No Tool Call Path --- - print("No tool call requested by AI. Processing initial response as final.") - # Attempt to parse the initial response text directly. - initial_response_text = _get_response_text(initial_response) # Use helper - # Validate against the final schema because this IS the final response. - final_parsed_data = parse_and_validate_json_response( - initial_response_text, - RESPONSE_SCHEMA["schema"], - "final response (no tools)", - ) - initial_parsed_data = ( - final_parsed_data # Keep initial_parsed_data consistent for return dict - ) - - if final_parsed_data is None: - # This means the initial response failed validation. - print("Critical Error: Initial response failed validation (no tools).") - error_message = "Failed to parse/validate initial AI JSON response." - # Create a basic fallback if the bot was mentioned - replied_to_bot = ( - message.reference - and message.reference.resolved - and message.reference.resolved.author == cog.bot.user - ) - if cog.bot.user.mentioned_in(message) or replied_to_bot: - fallback_response = { - "should_respond": True, - "content": "...", - "react_with_emoji": "❓", - } - # initial_parsed_data is not used in this path, only final_parsed_data matters - - except Exception as e: - error_message = f"Error in get_ai_response main loop for message {message.id}: {type(e).__name__}: {str(e)}" - print(error_message) - import traceback - - traceback.print_exc() - # Ensure both are None on critical error - initial_parsed_data = None - final_parsed_data = None - - return { - "initial_response": initial_parsed_data, # Return parsed initial data - "final_response": final_parsed_data, # Return parsed final data - "error": error_message, - "fallback_initial": fallback_response, - } - - -# --- Proactive AI Response Function --- -async def get_proactive_ai_response( - cog: "WheatleyCog", message: discord.Message, trigger_reason: str -) -> Dict[str, Any]: - """Generates a proactive response based on a specific trigger using Vertex AI.""" - if not PROJECT_ID or not LOCATION: - return { - "should_respond": False, - "content": None, - "react_with_emoji": None, - "error": "Google Cloud Project ID or Location not configured", - } - - print(f"--- Proactive Response Triggered: {trigger_reason} ---") - channel_id = message.channel.id - final_parsed_data = None - error_message = None - plan = None # Variable to store the plan - - try: - # --- Build Context for Planning --- - # Gather relevant context: recent messages, topic, sentiment, Gurt's mood/interests, trigger reason - planning_context_parts = [ - f"Proactive Trigger Reason: {trigger_reason}", - f"Current Mood: {cog.current_mood}", - ] - # Add recent messages summary - summary_data = await get_conversation_summary( - cog, str(channel_id), message_limit=15 - ) # Use tool function - if summary_data and not summary_data.get("error"): - planning_context_parts.append( - f"Recent Conversation Summary: {summary_data['summary']}" - ) - # Add active topics - active_topics_data = cog.active_topics.get(channel_id) - if active_topics_data and active_topics_data.get("topics"): - topics_str = ", ".join( - [ - f"{t['topic']} ({t['score']:.1f})" - for t in active_topics_data["topics"][:3] - ] - ) - planning_context_parts.append(f"Active Topics: {topics_str}") - # Add sentiment - sentiment_data = cog.conversation_sentiment.get(channel_id) - if sentiment_data: - planning_context_parts.append( - f"Conversation Sentiment: {sentiment_data.get('overall', 'N/A')} (Intensity: {sentiment_data.get('intensity', 0):.1f})" - ) - # Add Wheatley's interests (Note: Interests are likely disabled/removed for Wheatley, this might fetch nothing) - try: - interests = await cog.memory_manager.get_interests(limit=5) - if interests: - interests_str = ", ".join([f"{t} ({l:.1f})" for t, l in interests]) - planning_context_parts.append( - f"Wheatley's Interests: {interests_str}" - ) # Changed text - except Exception as int_e: - print(f"Error getting interests for planning: {int_e}") - - planning_context = "\n".join(planning_context_parts) - - # --- Planning Step --- - print("Generating proactive response plan...") - planning_prompt_messages = [ - { - "role": "system", - "content": "You are Wheatley's planning module. Analyze the context and trigger reason to decide if Wheatley should respond proactively and, if so, outline a plan (goal, key info, tone). Focus on natural, in-character engagement (rambling, insecure, bad ideas). Respond ONLY with JSON matching the provided schema.", - }, # Updated system prompt - { - "role": "user", - "content": f"Context:\n{planning_context}\n\nBased on this context and the trigger reason, create a plan for Wheatley's proactive response.", - }, # Updated user prompt - ] - - plan = await get_internal_ai_json_response( - cog=cog, - prompt_messages=planning_prompt_messages, - task_description=f"Proactive Planning ({trigger_reason})", - response_schema_dict=PROACTIVE_PLAN_SCHEMA["schema"], - model_name_override=FALLBACK_MODEL, # Use a potentially faster/cheaper model for planning - temperature=0.5, - max_tokens=300, - ) - - if not plan or not plan.get("should_respond"): - reason = ( - plan.get("reasoning", "Planning failed or decided against responding.") - if plan - else "Planning failed." - ) - print(f"Proactive response aborted by plan: {reason}") - return { - "should_respond": False, - "content": None, - "react_with_emoji": None, - "note": f"Plan: {reason}", - } - - print( - f"Proactive Plan Generated: Goal='{plan.get('response_goal', 'N/A')}', Reasoning='{plan.get('reasoning', 'N/A')}'" - ) - - # --- Build Final Proactive Prompt using Plan --- - persistent_traits = await cog.memory_manager.get_all_personality_traits() - if not persistent_traits: - persistent_traits = {} # Wheatley doesn't use these Gurt traits - - final_proactive_prompt_parts = [ - f"You are Wheatley, an Aperture Science Personality Core. Your tone is rambling, insecure, uses British slang, and you often have terrible ideas you think are brilliant.", # Updated personality description - # Removed Gurt-specific traits - # Removed mood reference as it's disabled for Wheatley - # Incorporate Plan Details: - f"You decided to respond proactively (maybe?). Trigger Reason: {trigger_reason}.", # Wheatley-style uncertainty - f"Your Brilliant Plan (Goal): {plan.get('response_goal', 'Say something... probably helpful?')}.", # Wheatley-style goal - f"Reasoning: {plan.get('reasoning', 'N/A')}.", - ] - if plan.get("key_info_to_include"): - info_str = "; ".join(plan["key_info_to_include"]) - final_proactive_prompt_parts.append(f"Consider mentioning: {info_str}") - if plan.get("suggested_tone"): - final_proactive_prompt_parts.append( - f"Adjust tone to be: {plan['suggested_tone']}" - ) - - final_proactive_prompt_parts.append( - "Generate a casual, in-character message based on the plan and context. Keep it relatively short and natural-sounding." - ) - final_proactive_system_prompt = "\n\n".join(final_proactive_prompt_parts) - - # --- Initialize Final Model --- - model = GenerativeModel( - model_name=DEFAULT_MODEL, system_instruction=final_proactive_system_prompt - ) - - # --- Prepare Final Contents --- - contents = [ - Content( - role="user", - parts=[ - Part.from_text( - f"Generate the response based on your plan. **CRITICAL: Your response MUST be ONLY the raw JSON object matching this schema:**\n\n{json.dumps(RESPONSE_SCHEMA['schema'], indent=2)}\n\n**Ensure nothing precedes or follows the JSON.**" - ) - ], - ) - ] - - # --- Call Final LLM API --- - # Preprocess the schema before passing it to GenerationConfig - processed_response_schema_proactive = _preprocess_schema_for_vertex( - RESPONSE_SCHEMA["schema"] - ) - generation_config_final = GenerationConfig( - temperature=0.8, # Use original proactive temp - max_output_tokens=200, - response_mime_type="application/json", - response_schema=processed_response_schema_proactive, # Use preprocessed schema - ) - - response_obj = await call_vertex_api_with_retry( - cog=cog, - model=model, - contents=contents, - generation_config=generation_config_final, - safety_settings=STANDARD_SAFETY_SETTINGS, - request_desc=f"Final proactive response for channel {channel_id} ({trigger_reason})", - ) - - if not response_obj or not response_obj.candidates: - raise Exception( - "Final proactive API call returned no response or candidates." - ) - - # --- Parse and Validate Final Response --- - final_response_text = _get_response_text(response_obj) - final_parsed_data = parse_and_validate_json_response( - final_response_text, - RESPONSE_SCHEMA["schema"], - f"final proactive response ({trigger_reason})", - ) - - if final_parsed_data is None: - print( - f"Warning: Failed to parse/validate final proactive JSON response for {trigger_reason}." - ) - final_parsed_data = { - "should_respond": False, - "content": None, - "react_with_emoji": None, - "note": "Fallback - Failed to parse/validate final proactive JSON", - } - else: - # --- Cache Bot Response --- - if final_parsed_data.get("should_respond") and final_parsed_data.get( - "content" - ): - bot_response_cache_entry = { - "id": f"bot_proactive_{message.id}_{int(time.time())}", - "author": { - "id": str(cog.bot.user.id), - "name": cog.bot.user.name, - "display_name": cog.bot.user.display_name, - "bot": True, - }, - "content": final_parsed_data.get("content", ""), - "created_at": datetime.datetime.now().isoformat(), - "attachments": [], - "embeds": False, - "mentions": [], - "replied_to_message_id": None, - "channel": message.channel, - "guild": message.guild, - "reference": None, - "mentioned_users_details": [], - } - cog.message_cache["by_channel"].setdefault(channel_id, []).append( - bot_response_cache_entry - ) - cog.message_cache["global_recent"].append(bot_response_cache_entry) - cog.bot_last_spoke[channel_id] = time.time() - # Removed Gurt-specific participation tracking - - except Exception as e: - error_message = f"Error getting proactive AI response for channel {channel_id} ({trigger_reason}): {type(e).__name__}: {str(e)}" - print(error_message) - final_parsed_data = { - "should_respond": False, - "content": None, - "react_with_emoji": None, - "error": error_message, - } - - # Ensure default keys exist - final_parsed_data.setdefault("should_respond", False) - final_parsed_data.setdefault("content", None) - final_parsed_data.setdefault("react_with_emoji", None) - if error_message and "error" not in final_parsed_data: - final_parsed_data["error"] = error_message - - return final_parsed_data - - -# --- Internal AI Call for Specific Tasks --- -async def get_internal_ai_json_response( - cog: "WheatleyCog", - prompt_messages: List[Dict[str, Any]], # Keep this format - task_description: str, - response_schema_dict: Dict[str, Any], # Expect schema as dict - model_name: Optional[str] = None, - temperature: float = 0.7, - max_tokens: int = 5000, -) -> Optional[Dict[str, Any]]: # Keep return type hint simple - """ - Makes a Vertex AI call expecting a specific JSON response format for internal tasks. - - Args: - cog: The WheatleyCog instance. - prompt_messages: List of message dicts (like OpenAI format: {'role': 'user'/'model', 'content': '...'}). - task_description: Description for logging. - response_schema_dict: The expected JSON output schema as a dictionary. - model_name: Optional model override. - temperature: Generation temperature. - max_tokens: Max output tokens. - - Returns: - The parsed and validated JSON dictionary if successful, None otherwise. - """ - if not PROJECT_ID or not LOCATION: - print( - f"Error in get_internal_ai_json_response ({task_description}): GCP Project/Location not set." - ) - return None - - final_parsed_data = None - error_occurred = None - request_payload_for_logging = {} # For logging - - try: - # --- Convert prompt messages to Vertex AI Content format --- - contents: List[Content] = [] - system_instruction = None - for msg in prompt_messages: - role = msg.get("role", "user") - content_text = msg.get("content", "") - if role == "system": - # Use the first system message as system_instruction - if system_instruction is None: - system_instruction = content_text - else: - # Append subsequent system messages to the instruction - system_instruction += "\n\n" + content_text - continue # Skip adding system messages to contents list - elif role == "assistant": - role = "model" - - # --- Process content (string or list) --- - content_value = msg.get("content") - message_parts: List[Part] = ( - [] - ) # Initialize list to hold parts for this message - - if isinstance(content_value, str): - # Handle simple string content - message_parts.append(Part.from_text(content_value)) - elif isinstance(content_value, list): - # Handle list content (e.g., multimodal from ProfileUpdater) - for part_data in content_value: - part_type = part_data.get("type") - if part_type == "text": - text = part_data.get("text", "") - message_parts.append(Part.from_text(text)) - elif part_type == "image_data": - mime_type = part_data.get("mime_type") - base64_data = part_data.get("data") - if mime_type and base64_data: - try: - image_bytes = base64.b64decode(base64_data) - message_parts.append( - Part.from_data( - data=image_bytes, mime_type=mime_type - ) - ) - except Exception as decode_err: - print( - f"Error decoding/adding image part in get_internal_ai_json_response: {decode_err}" - ) - # Optionally add a placeholder text part indicating failure - message_parts.append( - Part.from_text( - "(System Note: Failed to process an image part)" - ) - ) - else: - print("Warning: image_data part missing mime_type or data.") - else: - print( - f"Warning: Unknown part type '{part_type}' in internal prompt message." - ) - else: - print( - f"Warning: Unexpected content type '{type(content_value)}' in internal prompt message." - ) - - # Add the content object if parts were generated - if message_parts: - contents.append(Content(role=role, parts=message_parts)) - else: - print(f"Warning: No parts generated for message role '{role}'.") - - # Add the critical JSON instruction to the last user message or as a new user message - json_instruction_content = ( - f"**CRITICAL: Your response MUST consist *only* of the raw JSON object itself, matching this schema:**\n" - f"{json.dumps(response_schema_dict, indent=2)}\n" - f"**Ensure nothing precedes or follows the JSON.**" - ) - if contents and contents[-1].role == "user": - contents[-1].parts.append(Part.from_text(f"\n\n{json_instruction_content}")) - else: - contents.append( - Content(role="user", parts=[Part.from_text(json_instruction_content)]) - ) - - # --- Initialize Model --- - model = GenerativeModel( - model_name=model_name or DEFAULT_MODEL, # Use keyword argument - system_instruction=system_instruction, - # No tools needed for internal JSON tasks usually - ) - - # --- Prepare Generation Config --- - # Preprocess the schema before passing it to GenerationConfig - processed_schema_internal = _preprocess_schema_for_vertex(response_schema_dict) - generation_config = GenerationConfig( - temperature=temperature, - max_output_tokens=max_tokens, - response_mime_type="application/json", - response_schema=processed_schema_internal, # Use preprocessed schema - ) - - # Prepare payload for logging (approximate) - request_payload_for_logging = { - "model": model._model_name, - "system_instruction": system_instruction, - "contents": [ # Simplified representation for logging - { - "role": c.role, - "parts": [ - p.text if hasattr(p, "text") else str(type(p)) for p in c.parts - ], - } - for c in contents - ], - # Use the original generation_config dict directly for logging - "generation_config": generation_config, # It's already a dict - } - - # --- Add detailed logging for raw request --- - try: - print(f"--- Raw request payload for {task_description} ---") - # Use json.dumps for pretty printing, handle potential errors - print( - json.dumps(request_payload_for_logging, indent=2, default=str) - ) # Use default=str as fallback - print(f"--- End Raw request payload ---") - except Exception as req_log_e: - print(f"Error logging raw request payload: {req_log_e}") - print( - f"Payload causing error: {request_payload_for_logging}" - ) # Print the raw dict on error - # --- End detailed logging --- - - # --- Call API --- - response_obj = await call_vertex_api_with_retry( - cog=cog, - model=model, - contents=contents, - generation_config=generation_config, - safety_settings=STANDARD_SAFETY_SETTINGS, # Use standard safety - request_desc=task_description, - ) - - if not response_obj or not response_obj.candidates: - raise Exception("Internal API call returned no response or candidates.") - - # --- Parse and Validate --- - # This function always expects JSON, so directly use response_obj.text - final_response_text = response_obj.text - # --- Add detailed logging for raw response text --- - print(f"--- Raw response_obj.text for {task_description} ---") - print(final_response_text) - print(f"--- End Raw response_obj.text ---") - # --- End detailed logging --- - print(f"Parsing ({task_description}): Using response_obj.text for JSON.") - - final_parsed_data = parse_and_validate_json_response( - final_response_text, - response_schema_dict, - f"internal task ({task_description})", - ) - - if final_parsed_data is None: - print( - f"Warning: Internal task '{task_description}' failed JSON validation." - ) - # No re-prompting for internal tasks, just return None - - except Exception as e: - print( - f"Error in get_internal_ai_json_response ({task_description}): {type(e).__name__}: {e}" - ) - error_occurred = e - import traceback - - traceback.print_exc() - final_parsed_data = None - finally: - # Log the call - try: - # Pass the simplified payload for logging - await log_internal_api_call( - cog, - task_description, - request_payload_for_logging, - final_parsed_data, - error_occurred, - ) - except Exception as log_e: - print(f"Error logging internal API call: {log_e}") - - return final_parsed_data diff --git a/wheatley/background.py b/wheatley/background.py deleted file mode 100644 index da0b215..0000000 --- a/wheatley/background.py +++ /dev/null @@ -1,117 +0,0 @@ -import asyncio -import time -import traceback -import os -import json -import aiohttp -from typing import TYPE_CHECKING - -# Relative imports -from .config import STATS_PUSH_INTERVAL # Only keep stats interval - -# Removed analysis imports - -if TYPE_CHECKING: - from .cog import WheatleyCog # Updated type hint - -# --- Background Task --- - - -async def background_processing_task(cog: "WheatleyCog"): # Updated type hint - """Background task that periodically pushes stats.""" # Simplified docstring - # Get API details from environment for stats pushing - api_internal_url = os.getenv("API_INTERNAL_URL") - # Use a generic secret name or a Wheatley-specific one if desired - stats_push_secret = os.getenv( - "WHEATLEY_STATS_PUSH_SECRET", os.getenv("GURT_STATS_PUSH_SECRET") - ) # Fallback to GURT secret if needed - - if not api_internal_url: - print( - "WARNING: API_INTERNAL_URL not set. Wheatley stats will not be pushed." - ) # Updated text - if not stats_push_secret: - print( - "WARNING: WHEATLEY_STATS_PUSH_SECRET (or GURT_STATS_PUSH_SECRET) not set. Stats push endpoint is insecure and likely won't work." - ) # Updated text - - try: - while True: - await asyncio.sleep(STATS_PUSH_INTERVAL) # Use the stats interval directly - now = time.time() - - # --- Push Stats --- - if ( - api_internal_url and stats_push_secret - ): # Removed check for last push time, rely on sleep interval - print("Pushing Wheatley stats to API server...") # Updated text - try: - stats_data = await cog.get_wheatley_stats() # Updated method call - headers = { - "Authorization": f"Bearer {stats_push_secret}", - "Content-Type": "application/json", - } - # Use the cog's session, ensure it's created - if cog.session: - # Set a reasonable timeout for the stats push - push_timeout = aiohttp.ClientTimeout( - total=10 - ) # 10 seconds total timeout - async with cog.session.post( - api_internal_url, - json=stats_data, - headers=headers, - timeout=push_timeout, - ssl=True, - ) as response: # Explicitly enable SSL verification - if response.status == 200: - print( - f"Successfully pushed Wheatley stats (Status: {response.status})" - ) # Updated text - else: - error_text = await response.text() - print( - f"Failed to push Wheatley stats (Status: {response.status}): {error_text[:200]}" - ) # Updated text, Log only first 200 chars - else: - print( - "Error pushing stats: WheatleyCog session not initialized." - ) # Updated text - # Removed updating cog.last_stats_push as we rely on sleep interval - except aiohttp.ClientConnectorSSLError as ssl_err: - print( - f"SSL Error pushing Wheatley stats: {ssl_err}. Ensure the API server's certificate is valid and trusted, or check network configuration." - ) # Updated text - print( - "If using a self-signed certificate for development, the bot process might need to trust it." - ) - except aiohttp.ClientError as client_err: - print( - f"HTTP Client Error pushing Wheatley stats: {client_err}" - ) # Updated text - except asyncio.TimeoutError: - print("Timeout error pushing Wheatley stats.") # Updated text - except Exception as e: - print( - f"Unexpected error pushing Wheatley stats: {e}" - ) # Updated text - traceback.print_exc() - - # --- Removed Learning Analysis --- - # --- Removed Evolve Personality --- - # --- Removed Update Interests --- - # --- Removed Memory Reflection --- - # --- Removed Goal Decomposition --- - # --- Removed Goal Execution --- - # --- Removed Automatic Mood Change --- - - except asyncio.CancelledError: - print("Wheatley background processing task cancelled") # Updated text - except Exception as e: - print(f"Error in Wheatley background processing task: {e}") # Updated text - traceback.print_exc() - await asyncio.sleep(300) # Wait 5 minutes before retrying after an error - - -# --- Removed Automatic Mood Change Logic --- -# --- Removed Interest Update Logic --- diff --git a/wheatley/cog.py b/wheatley/cog.py deleted file mode 100644 index 6093b21..0000000 --- a/wheatley/cog.py +++ /dev/null @@ -1,458 +0,0 @@ -import discord -from discord.ext import commands -import asyncio -import os -import json -import aiohttp -import random -import time -from collections import defaultdict, deque -from typing import Dict, List, Any, Optional, Tuple, Set, Union - -# Third-party imports needed by the Cog itself or its direct methods -from dotenv import load_dotenv -from tavily import TavilyClient # Needed for tavily_client init - -# --- Relative Imports from Wheatley Package --- -from .config import ( - PROJECT_ID, - LOCATION, - TAVILY_API_KEY, - DEFAULT_MODEL, - FALLBACK_MODEL, # Use GCP config - DB_PATH, - CHROMA_PATH, - SEMANTIC_MODEL_NAME, - MAX_USER_FACTS, - MAX_GENERAL_FACTS, - # Removed Mood/Personality/Interest/Learning/Goal configs - CHANNEL_TOPIC_CACHE_TTL, - CONTEXT_WINDOW_SIZE, - API_TIMEOUT, - SUMMARY_API_TIMEOUT, - API_RETRY_ATTEMPTS, - API_RETRY_DELAY, - PROACTIVE_LULL_THRESHOLD, - PROACTIVE_BOT_SILENCE_THRESHOLD, - PROACTIVE_LULL_CHANCE, - PROACTIVE_TOPIC_RELEVANCE_THRESHOLD, - PROACTIVE_TOPIC_CHANCE, - # Removed Relationship/Sentiment/Interest proactive configs - TOPIC_UPDATE_INTERVAL, - SENTIMENT_UPDATE_INTERVAL, - RESPONSE_SCHEMA, - TOOLS, # Import necessary configs -) - -# Import functions/classes from other modules -from .memory import MemoryManager # Import from local memory.py -from .background import ( - background_processing_task, -) # Keep background task for potential future use (e.g., cache cleanup) -from .commands import setup_commands # Import the setup helper -from .listeners import ( - on_ready_listener, - on_message_listener, - on_reaction_add_listener, - on_reaction_remove_listener, -) # Import listener functions -from . import config as WheatleyConfig # Import config module for get_wheatley_stats - -# Load environment variables (might be loaded globally in main bot script too) -load_dotenv() - - -class WheatleyCog(commands.Cog, name="Wheatley"): # Renamed class and Cog name - """A special cog for the Wheatley bot that uses Google Vertex AI API""" # Updated docstring - - def __init__(self, bot): - self.bot = bot - # GCP Project/Location are used by vertexai.init() in api.py - self.tavily_api_key = TAVILY_API_KEY # Use imported config - self.session: Optional[aiohttp.ClientSession] = ( - None # Keep for other potential HTTP requests (e.g., Piston) - ) - self.tavily_client = ( - TavilyClient(api_key=self.tavily_api_key) if self.tavily_api_key else None - ) - self.default_model = DEFAULT_MODEL # Use imported config - self.fallback_model = FALLBACK_MODEL # Use imported config - # Removed MOOD_OPTIONS - self.current_channel: Optional[ - Union[discord.TextChannel, discord.Thread, discord.DMChannel] - ] = None # Type hint current channel - - # Instantiate MemoryManager - self.memory_manager = MemoryManager( - db_path=DB_PATH, - max_user_facts=MAX_USER_FACTS, - max_general_facts=MAX_GENERAL_FACTS, - chroma_path=CHROMA_PATH, - semantic_model_name=SEMANTIC_MODEL_NAME, - ) - - # --- State Variables (Simplified for Wheatley) --- - # Removed mood, personality evolution, interest tracking, learning state - self.needs_json_reminder = False # Flag to remind AI about JSON format - - # Topic tracking (Kept for context) - self.active_topics = defaultdict( - lambda: { - "topics": [], - "last_update": time.time(), - "topic_history": [], - "user_topic_interests": defaultdict( - list - ), # Kept for potential future analysis, not proactive triggers - } - ) - - # Conversation tracking / Caches - self.conversation_history = defaultdict(lambda: deque(maxlen=100)) - self.thread_history = defaultdict(lambda: deque(maxlen=50)) - self.user_conversation_mapping = defaultdict(set) - self.channel_activity = defaultdict(lambda: 0.0) # Use float for timestamp - self.conversation_topics = defaultdict(str) # Simplified topic tracking - self.user_relationships = defaultdict( - dict - ) # Kept for potential context/analysis - self.conversation_summaries: Dict[int, Dict[str, Any]] = ( - {} - ) # Store dict with summary and timestamp - self.channel_topics_cache: Dict[int, Dict[str, Any]] = ( - {} - ) # Store dict with topic and timestamp - - self.message_cache = { - "by_channel": defaultdict( - lambda: deque(maxlen=CONTEXT_WINDOW_SIZE) - ), # Use config - "by_user": defaultdict(lambda: deque(maxlen=50)), - "by_thread": defaultdict(lambda: deque(maxlen=50)), - "global_recent": deque(maxlen=200), - "mentioned": deque(maxlen=50), - "replied_to": defaultdict(lambda: deque(maxlen=20)), - } - - self.active_conversations = {} # Kept for basic tracking - self.bot_last_spoke = defaultdict(float) - self.message_reply_map = {} - - # Enhanced sentiment tracking (Kept for context/analysis) - self.conversation_sentiment = defaultdict( - lambda: { - "overall": "neutral", - "intensity": 0.5, - "recent_trend": "stable", - "user_sentiments": {}, - "last_update": time.time(), - } - ) - # Removed self.sentiment_update_interval as it was only used in analysis - - # Reaction Tracking (Renamed) - self.wheatley_message_reactions = defaultdict( - lambda: {"positive": 0, "negative": 0, "topic": None, "timestamp": 0.0} - ) # Renamed - - # Background task handle (Kept for potential future tasks like cache cleanup) - self.background_task: Optional[asyncio.Task] = None - self.last_stats_push = time.time() # Timestamp for last stats push - # Removed evolution, reflection, goal timestamps - - # --- Stats Tracking --- - self.api_stats = defaultdict( - lambda: { - "success": 0, - "failure": 0, - "retries": 0, - "total_time": 0.0, - "count": 0, - } - ) # Keyed by model name - self.tool_stats = defaultdict( - lambda: {"success": 0, "failure": 0, "total_time": 0.0, "count": 0} - ) # Keyed by tool name - - # --- Setup Commands and Listeners --- - # Add commands defined in commands.py - self.command_functions = setup_commands(self) - - # Store command names for reference - safely handle Command objects - self.registered_commands = [] - for func in self.command_functions: - # For app commands, use the name attribute directly - if hasattr(func, "name"): - self.registered_commands.append(func.name) - # For regular functions, use __name__ - elif hasattr(func, "__name__"): - self.registered_commands.append(func.__name__) - else: - self.registered_commands.append(str(func)) - - print( - f"WheatleyCog initialized with commands: {self.registered_commands}" - ) # Updated print - - async def cog_load(self): - """Create aiohttp session, initialize DB, start background task""" - self.session = aiohttp.ClientSession() - print("WheatleyCog: aiohttp session created") # Updated print - - # Initialize DB via MemoryManager - await self.memory_manager.initialize_sqlite_database() - # Removed loading of baseline personality and interests - - # Vertex AI initialization happens in api.py using PROJECT_ID and LOCATION from config - print( - f"WheatleyCog: Using default model: {self.default_model}" - ) # Updated print - if not self.tavily_api_key: - print( - "WARNING: Tavily API key not configured (TAVILY_API_KEY). Web search disabled." - ) - - # Add listeners to the bot instance - # IMPORTANT: Don't override on_member_join or on_member_remove events - - # Check if the bot already has event listeners for member join/leave - has_member_join = "on_member_join" in self.bot.extra_events - has_member_remove = "on_member_remove" in self.bot.extra_events - print( - f"WheatleyCog: Bot already has event listeners - on_member_join: {has_member_join}, on_member_remove: {has_member_remove}" - ) - - @self.bot.event - async def on_ready(): - await on_ready_listener(self) - - @self.bot.event - async def on_message(message): - # Ensure commands are processed if using command prefix - if message.content.startswith(self.bot.command_prefix): - await self.bot.process_commands(message) - # Always run the message listener for potential AI responses/tracking - await on_message_listener(self, message) - - @self.bot.event - async def on_reaction_add(reaction, user): - await on_reaction_add_listener(self, reaction, user) - - @self.bot.event - async def on_reaction_remove(reaction, user): - await on_reaction_remove_listener(self, reaction, user) - - print("WheatleyCog: Listeners added.") # Updated print - - # Commands will be synced in on_ready - print( - "WheatleyCog: Commands will be synced when the bot is ready." - ) # Updated print - - # Start background task (kept for potential future use) - if self.background_task is None or self.background_task.done(): - self.background_task = asyncio.create_task(background_processing_task(self)) - print("WheatleyCog: Started background processing task.") # Updated print - else: - print( - "WheatleyCog: Background processing task already running." - ) # Updated print - - async def cog_unload(self): - """Close session and cancel background task""" - if self.session and not self.session.closed: - await self.session.close() - print("WheatleyCog: aiohttp session closed") # Updated print - if self.background_task and not self.background_task.done(): - self.background_task.cancel() - print("WheatleyCog: Cancelled background processing task.") # Updated print - print( - "WheatleyCog: Listeners will be removed when bot is closed." - ) # Updated print - - print("WheatleyCog unloaded.") # Updated print - - # --- Helper methods that might remain in the cog --- - # _update_relationship kept for potential context/analysis use - def _update_relationship(self, user_id_1: str, user_id_2: str, change: float): - """Updates the relationship score between two users.""" - if user_id_1 > user_id_2: - user_id_1, user_id_2 = user_id_2, user_id_1 - if user_id_1 not in self.user_relationships: - self.user_relationships[user_id_1] = {} - - current_score = self.user_relationships[user_id_1].get(user_id_2, 0.0) - new_score = max(0.0, min(current_score + change, 100.0)) # Clamp 0-100 - self.user_relationships[user_id_1][user_id_2] = new_score - # print(f"Updated relationship {user_id_1}-{user_id_2}: {current_score:.1f} -> {new_score:.1f} ({change:+.1f})") # Debug log - - async def get_wheatley_stats(self) -> Dict[str, Any]: # Renamed method - """Collects various internal stats for Wheatley.""" # Updated docstring - stats = { - "config": {}, - "runtime": {}, - "memory": {}, - "api_stats": {}, - "tool_stats": {}, - } - - # --- Config (Simplified) --- - stats["config"]["default_model"] = WheatleyConfig.DEFAULT_MODEL - stats["config"]["fallback_model"] = WheatleyConfig.FALLBACK_MODEL - stats["config"]["safety_check_model"] = WheatleyConfig.SAFETY_CHECK_MODEL - stats["config"]["db_path"] = WheatleyConfig.DB_PATH - stats["config"]["chroma_path"] = WheatleyConfig.CHROMA_PATH - stats["config"]["semantic_model_name"] = WheatleyConfig.SEMANTIC_MODEL_NAME - stats["config"]["max_user_facts"] = WheatleyConfig.MAX_USER_FACTS - stats["config"]["max_general_facts"] = WheatleyConfig.MAX_GENERAL_FACTS - stats["config"]["context_window_size"] = WheatleyConfig.CONTEXT_WINDOW_SIZE - stats["config"]["api_timeout"] = WheatleyConfig.API_TIMEOUT - stats["config"]["summary_api_timeout"] = WheatleyConfig.SUMMARY_API_TIMEOUT - stats["config"][ - "proactive_lull_threshold" - ] = WheatleyConfig.PROACTIVE_LULL_THRESHOLD - stats["config"][ - "proactive_bot_silence_threshold" - ] = WheatleyConfig.PROACTIVE_BOT_SILENCE_THRESHOLD - stats["config"]["topic_update_interval"] = WheatleyConfig.TOPIC_UPDATE_INTERVAL - stats["config"][ - "sentiment_update_interval" - ] = WheatleyConfig.SENTIMENT_UPDATE_INTERVAL - stats["config"][ - "docker_command_timeout" - ] = WheatleyConfig.DOCKER_COMMAND_TIMEOUT - stats["config"]["project_id_set"] = bool( - WheatleyConfig.PROJECT_ID != "your-gcp-project-id" - ) - stats["config"]["location_set"] = bool(WheatleyConfig.LOCATION != "us-central1") - stats["config"]["tavily_api_key_set"] = bool(WheatleyConfig.TAVILY_API_KEY) - stats["config"]["piston_api_url_set"] = bool(WheatleyConfig.PISTON_API_URL) - - # --- Runtime (Simplified) --- - # Removed mood, evolution - stats["runtime"]["needs_json_reminder"] = self.needs_json_reminder - stats["runtime"]["background_task_running"] = bool( - self.background_task and not self.background_task.done() - ) - stats["runtime"]["active_topics_channels"] = len(self.active_topics) - stats["runtime"]["conversation_history_channels"] = len( - self.conversation_history - ) - stats["runtime"]["thread_history_threads"] = len(self.thread_history) - stats["runtime"]["user_conversation_mappings"] = len( - self.user_conversation_mapping - ) - stats["runtime"]["channel_activity_tracked"] = len(self.channel_activity) - stats["runtime"]["conversation_topics_tracked"] = len( - self.conversation_topics - ) # Simplified topic tracking - stats["runtime"]["user_relationships_pairs"] = sum( - len(v) for v in self.user_relationships.values() - ) - stats["runtime"]["conversation_summaries_cached"] = len( - self.conversation_summaries - ) - stats["runtime"]["channel_topics_cached"] = len(self.channel_topics_cache) - stats["runtime"]["message_cache_global_count"] = len( - self.message_cache["global_recent"] - ) - stats["runtime"]["message_cache_mentioned_count"] = len( - self.message_cache["mentioned"] - ) - stats["runtime"]["active_conversations_count"] = len(self.active_conversations) - stats["runtime"]["bot_last_spoke_channels"] = len(self.bot_last_spoke) - stats["runtime"]["message_reply_map_size"] = len(self.message_reply_map) - stats["runtime"]["conversation_sentiment_channels"] = len( - self.conversation_sentiment - ) - # Removed Gurt participation topics - stats["runtime"]["wheatley_message_reactions_tracked"] = len( - self.wheatley_message_reactions - ) # Renamed - - # --- Memory (Simplified) --- - try: - # Removed Personality, Interests - user_fact_count = await self.memory_manager._db_fetchone( - "SELECT COUNT(*) FROM user_facts" - ) - general_fact_count = await self.memory_manager._db_fetchone( - "SELECT COUNT(*) FROM general_facts" - ) - stats["memory"]["user_facts_count"] = ( - user_fact_count[0] if user_fact_count else 0 - ) - stats["memory"]["general_facts_count"] = ( - general_fact_count[0] if general_fact_count else 0 - ) - - # ChromaDB Stats - stats["memory"]["chromadb_message_collection_count"] = ( - await asyncio.to_thread(self.memory_manager.semantic_collection.count) - if self.memory_manager.semantic_collection - else "N/A" - ) - stats["memory"]["chromadb_fact_collection_count"] = ( - await asyncio.to_thread(self.memory_manager.fact_collection.count) - if self.memory_manager.fact_collection - else "N/A" - ) - - except Exception as e: - stats["memory"]["error"] = f"Failed to retrieve memory stats: {e}" - - # --- API & Tool Stats --- - stats["api_stats"] = dict(self.api_stats) - stats["tool_stats"] = dict(self.tool_stats) - - # Calculate average times - for model, data in stats["api_stats"].items(): - if data["count"] > 0: - data["average_time_ms"] = round( - (data["total_time"] / data["count"]) * 1000, 2 - ) - else: - data["average_time_ms"] = 0 - for tool, data in stats["tool_stats"].items(): - if data["count"] > 0: - data["average_time_ms"] = round( - (data["total_time"] / data["count"]) * 1000, 2 - ) - else: - data["average_time_ms"] = 0 - - return stats - - async def sync_commands(self): - """Manually sync commands with Discord.""" - try: - print( - "WheatleyCog: Manually syncing commands with Discord..." - ) # Updated print - synced = await self.bot.tree.sync() - print(f"WheatleyCog: Synced {len(synced)} command(s)") # Updated print - - # List the synced commands - wheatley_commands = [ - cmd.name - for cmd in self.bot.tree.get_commands() - if cmd.name.startswith("wheatley") - ] # Updated prefix - print( - f"WheatleyCog: Available Wheatley commands: {', '.join(wheatley_commands)}" - ) # Updated print - - return synced, wheatley_commands - except Exception as e: - print(f"WheatleyCog: Failed to sync commands: {e}") # Updated print - import traceback - - traceback.print_exc() - return [], [] - - -# Setup function for loading the cog -async def setup(bot): - """Add the WheatleyCog to the bot.""" # Updated docstring - await bot.add_cog(WheatleyCog(bot)) # Use renamed class - print("WheatleyCog setup complete.") # Updated print diff --git a/wheatley/commands.py b/wheatley/commands.py deleted file mode 100644 index 280a2b4..0000000 --- a/wheatley/commands.py +++ /dev/null @@ -1,557 +0,0 @@ -import discord -from discord import app_commands # Import app_commands -from discord.ext import commands -import random -import os -import time # Import time for timestamps -import json # Import json for formatting -import datetime # Import datetime for formatting -from typing import TYPE_CHECKING, Optional, Dict, Any, List, Tuple # Add more types - -# Relative imports -# We need access to the cog instance for state and methods - -if TYPE_CHECKING: - from .cog import WheatleyCog # For type hinting - - # MOOD_OPTIONS removed - - -# --- Helper Function for Embeds --- -def create_wheatley_embed( - title: str, description: str = "", color=discord.Color.blue() -) -> discord.Embed: # Renamed function - """Creates a standard Wheatley-themed embed.""" # Updated docstring - embed = discord.Embed(title=title, description=description, color=color) - # Placeholder icon URL, replace if Wheatley has one - # embed.set_footer(text="Wheatley", icon_url="https://example.com/wheatley_icon.png") # Updated text - embed.set_footer(text="Wheatley") # Updated text - return embed - - -# --- Helper Function for Stats Embeds --- -def format_stats_embeds(stats: Dict[str, Any]) -> List[discord.Embed]: - """Formats the collected stats into multiple embeds.""" - embeds = [] - main_embed = create_wheatley_embed( - "Wheatley Internal Stats", color=discord.Color.green() - ) # Use new helper, updated title - ts_format = "" # Relative timestamp - - # Runtime Stats (Simplified for Wheatley) - runtime = stats.get("runtime", {}) - main_embed.add_field( - name="Background Task", - value="Running" if runtime.get("background_task_running") else "Stopped", - inline=True, - ) - main_embed.add_field( - name="Needs JSON Reminder", - value=str(runtime.get("needs_json_reminder", "N/A")), - inline=True, - ) - # Removed Mood, Evolution - main_embed.add_field( - name="Active Topics Channels", - value=str(runtime.get("active_topics_channels", "N/A")), - inline=True, - ) - main_embed.add_field( - name="Conv History Channels", - value=str(runtime.get("conversation_history_channels", "N/A")), - inline=True, - ) - main_embed.add_field( - name="Thread History Threads", - value=str(runtime.get("thread_history_threads", "N/A")), - inline=True, - ) - main_embed.add_field( - name="User Relationships Pairs", - value=str(runtime.get("user_relationships_pairs", "N/A")), - inline=True, - ) - main_embed.add_field( - name="Cached Summaries", - value=str(runtime.get("conversation_summaries_cached", "N/A")), - inline=True, - ) - main_embed.add_field( - name="Cached Channel Topics", - value=str(runtime.get("channel_topics_cached", "N/A")), - inline=True, - ) - main_embed.add_field( - name="Global Msg Cache", - value=str(runtime.get("message_cache_global_count", "N/A")), - inline=True, - ) - main_embed.add_field( - name="Mention Msg Cache", - value=str(runtime.get("message_cache_mentioned_count", "N/A")), - inline=True, - ) - main_embed.add_field( - name="Active Convos", - value=str(runtime.get("active_conversations_count", "N/A")), - inline=True, - ) - main_embed.add_field( - name="Sentiment Channels", - value=str(runtime.get("conversation_sentiment_channels", "N/A")), - inline=True, - ) - # Removed Gurt Participation Topics - main_embed.add_field( - name="Tracked Reactions", - value=str(runtime.get("wheatley_message_reactions_tracked", "N/A")), - inline=True, - ) # Renamed stat key - embeds.append(main_embed) - - # Memory Stats (Simplified) - memory_embed = create_wheatley_embed( - "Wheatley Memory Stats", color=discord.Color.orange() - ) # Use new helper, updated title - memory = stats.get("memory", {}) - if memory.get("error"): - memory_embed.description = f"⚠️ Error retrieving memory stats: {memory['error']}" - else: - memory_embed.add_field( - name="User Facts", - value=str(memory.get("user_facts_count", "N/A")), - inline=True, - ) - memory_embed.add_field( - name="General Facts", - value=str(memory.get("general_facts_count", "N/A")), - inline=True, - ) - memory_embed.add_field( - name="Chroma Messages", - value=str(memory.get("chromadb_message_collection_count", "N/A")), - inline=True, - ) - memory_embed.add_field( - name="Chroma Facts", - value=str(memory.get("chromadb_fact_collection_count", "N/A")), - inline=True, - ) - # Removed Personality Traits, Interests - embeds.append(memory_embed) - - # API Stats - api_stats = stats.get("api_stats", {}) - if api_stats: - api_embed = create_wheatley_embed( - "Wheatley API Stats", color=discord.Color.red() - ) # Use new helper, updated title - for model, data in api_stats.items(): - avg_time = data.get("average_time_ms", 0) - value = ( - f"✅ Success: {data.get('success', 0)}\n" - f"❌ Failure: {data.get('failure', 0)}\n" - f"🔁 Retries: {data.get('retries', 0)}\n" - f"⏱️ Avg Time: {avg_time} ms\n" - f"📊 Count: {data.get('count', 0)}" - ) - api_embed.add_field(name=f"Model: `{model}`", value=value, inline=True) - embeds.append(api_embed) - - # Tool Stats - tool_stats = stats.get("tool_stats", {}) - if tool_stats: - tool_embed = create_wheatley_embed( - "Wheatley Tool Stats", color=discord.Color.purple() - ) # Use new helper, updated title - for tool, data in tool_stats.items(): - avg_time = data.get("average_time_ms", 0) - value = ( - f"✅ Success: {data.get('success', 0)}\n" - f"❌ Failure: {data.get('failure', 0)}\n" - f"⏱️ Avg Time: {avg_time} ms\n" - f"📊 Count: {data.get('count', 0)}" - ) - tool_embed.add_field(name=f"Tool: `{tool}`", value=value, inline=True) - embeds.append(tool_embed) - - # Config Stats (Simplified) - config_embed = create_wheatley_embed( - "Wheatley Config Overview", color=discord.Color.greyple() - ) # Use new helper, updated title - config = stats.get("config", {}) - config_embed.add_field( - name="Default Model", - value=f"`{config.get('default_model', 'N/A')}`", - inline=True, - ) - config_embed.add_field( - name="Fallback Model", - value=f"`{config.get('fallback_model', 'N/A')}`", - inline=True, - ) - config_embed.add_field( - name="Semantic Model", - value=f"`{config.get('semantic_model_name', 'N/A')}`", - inline=True, - ) - config_embed.add_field( - name="Max User Facts", - value=str(config.get("max_user_facts", "N/A")), - inline=True, - ) - config_embed.add_field( - name="Max General Facts", - value=str(config.get("max_general_facts", "N/A")), - inline=True, - ) - config_embed.add_field( - name="Context Window", - value=str(config.get("context_window_size", "N/A")), - inline=True, - ) - config_embed.add_field( - name="Tavily Key Set", - value=str(config.get("tavily_api_key_set", "N/A")), - inline=True, - ) - config_embed.add_field( - name="Piston URL Set", - value=str(config.get("piston_api_url_set", "N/A")), - inline=True, - ) - embeds.append(config_embed) - - # Limit to 10 embeds max for Discord API - return embeds[:10] - - -# --- Command Setup Function --- -# This function will be called from WheatleyCog's setup method -def setup_commands(cog: "WheatleyCog"): # Updated type hint - """Adds Wheatley-specific commands to the cog.""" # Updated docstring - - # Create a list to store command functions for proper registration - command_functions = [] - - # --- Gurt Mood Command --- REMOVED - - # --- Wheatley Memory Command --- - @cog.bot.tree.command( - name="wheatleymemory", - description="Interact with Wheatley's memory (what little there is).", - ) # Renamed, updated description - @app_commands.describe( - action="Choose an action: add_user, add_general, get_user, get_general", - user="The user for user-specific actions (mention or ID).", - fact="The fact to add (for add actions).", - query="A keyword to search for (for get_general).", - ) - @app_commands.choices( - action=[ - app_commands.Choice(name="Add User Fact", value="add_user"), - app_commands.Choice(name="Add General Fact", value="add_general"), - app_commands.Choice(name="Get User Facts", value="get_user"), - app_commands.Choice(name="Get General Facts", value="get_general"), - ] - ) - async def wheatleymemory( - interaction: discord.Interaction, - action: app_commands.Choice[str], - user: Optional[discord.User] = None, - fact: Optional[str] = None, - query: Optional[str] = None, - ): # Renamed function - """Handles the /wheatleymemory command.""" # Updated docstring - await interaction.response.defer( - ephemeral=True - ) # Defer for potentially slow DB operations - - target_user_id = str(user.id) if user else None - action_value = action.value - - # Check if user is the bot owner for modification actions - if ( - action_value in ["add_user", "add_general"] - ) and interaction.user.id != cog.bot.owner_id: - await interaction.followup.send( - "⛔ Oi! Only the boss can fiddle with my memory banks!", ephemeral=True - ) # Updated text - return - - if action_value == "add_user": - if not target_user_id or not fact: - await interaction.followup.send( - "Need a user *and* a fact, mate. Can't remember nothing about nobody.", - ephemeral=True, - ) # Updated text - return - result = await cog.memory_manager.add_user_fact(target_user_id, fact) - await interaction.followup.send( - f"Add User Fact Result: `{json.dumps(result)}` (Probably worked? Maybe?)", - ephemeral=True, - ) # Updated text - - elif action_value == "add_general": - if not fact: - await interaction.followup.send( - "What's the fact then? Can't remember thin air!", ephemeral=True - ) # Updated text - return - result = await cog.memory_manager.add_general_fact(fact) - await interaction.followup.send( - f"Add General Fact Result: `{json.dumps(result)}` (Filed under 'Important Stuff I'll Forget Later')", - ephemeral=True, - ) # Updated text - - elif action_value == "get_user": - if not target_user_id: - await interaction.followup.send( - "Which user? Need an ID, chap!", ephemeral=True - ) # Updated text - return - facts = await cog.memory_manager.get_user_facts( - target_user_id - ) # Get newest by default - if facts: - facts_str = "\n- ".join(facts) - await interaction.followup.send( - f"**Stuff I Remember About {user.display_name}:**\n- {facts_str}", - ephemeral=True, - ) # Updated text - else: - await interaction.followup.send( - f"My mind's a blank slate about {user.display_name}. Nothing stored!", - ephemeral=True, - ) # Updated text - - elif action_value == "get_general": - facts = await cog.memory_manager.get_general_facts( - query=query, limit=10 - ) # Get newest/filtered - if facts: - facts_str = "\n- ".join(facts) - # Conditionally construct the title to avoid nested f-string issues - if query: - title = f'**General Stuff Matching "{query}":**' # Updated text - else: - title = "**General Stuff I Might Know:**" # Updated text - await interaction.followup.send( - f"{title}\n- {facts_str}", ephemeral=True - ) - else: - # Conditionally construct the message for the same reason - if query: - message = f"Couldn't find any general facts matching \"{query}\". Probably wasn't important." # Updated text - else: - message = "No general facts found. My memory's not what it used to be. Or maybe it is. Hard to tell." # Updated text - await interaction.followup.send(message, ephemeral=True) - - else: - await interaction.followup.send( - "Invalid action specified. What are you trying to do?", ephemeral=True - ) # Updated text - - command_functions.append(wheatleymemory) # Add renamed function - - # --- Wheatley Stats Command --- - @cog.bot.tree.command( - name="wheatleystats", - description="Display Wheatley's internal statistics. (Owner only)", - ) # Renamed, updated description - async def wheatleystats(interaction: discord.Interaction): # Renamed function - """Handles the /wheatleystats command.""" # Updated docstring - # Owner check - if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message( - "⛔ Sorry mate, classified information! Top secret! Or maybe I just forgot where I put it.", - ephemeral=True, - ) - return - - await interaction.response.defer( - ephemeral=True - ) # Defer as stats collection might take time - try: - stats_data = await cog.get_wheatley_stats() # Renamed cog method call - embeds = format_stats_embeds(stats_data) - await interaction.followup.send(embeds=embeds, ephemeral=True) - except Exception as e: - print(f"Error in /wheatleystats command: {e}") # Updated command name - import traceback - - traceback.print_exc() - await interaction.followup.send( - "An error occurred while fetching Wheatley's stats. Probably my fault.", - ephemeral=True, - ) # Updated text - - command_functions.append(wheatleystats) # Add renamed function - - # --- Sync Wheatley Commands (Owner Only) --- - @cog.bot.tree.command( - name="wheatleysync", - description="Sync Wheatley commands with Discord (Owner only)", - ) # Renamed, updated description - async def wheatleysync(interaction: discord.Interaction): # Renamed function - """Handles the /wheatleysync command to force sync commands.""" # Updated docstring - # Check if user is the bot owner - if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message( - "⛔ Only the boss can push the big red sync button!", ephemeral=True - ) # Updated text - return - - await interaction.response.defer(ephemeral=True) - try: - # Sync commands - synced = await cog.bot.tree.sync() - - # Get list of commands after sync - commands_after = [] - for cmd in cog.bot.tree.get_commands(): - if cmd.name.startswith("wheatley"): # Check for new prefix - commands_after.append(cmd.name) - - await interaction.followup.send( - f"✅ Successfully synced {len(synced)} commands!\nWheatley commands: {', '.join(commands_after)}", - ephemeral=True, - ) # Updated text - except Exception as e: - print(f"Error in /wheatleysync command: {e}") # Updated command name - import traceback - - traceback.print_exc() - await interaction.followup.send( - f"❌ Error syncing commands: {str(e)} (Did I break it again?)", - ephemeral=True, - ) # Updated text - - command_functions.append(wheatleysync) # Add renamed function - - # --- Wheatley Forget Command --- - @cog.bot.tree.command( - name="wheatleyforget", - description="Make Wheatley forget a specific fact (if he can).", - ) # Renamed, updated description - @app_commands.describe( - scope="Choose the scope: user (for facts about a specific user) or general.", - fact="The exact fact text Wheatley should forget.", - user="The user to forget a fact about (only if scope is 'user').", - ) - @app_commands.choices( - scope=[ - app_commands.Choice(name="User Fact", value="user"), - app_commands.Choice(name="General Fact", value="general"), - ] - ) - async def wheatleyforget( - interaction: discord.Interaction, - scope: app_commands.Choice[str], - fact: str, - user: Optional[discord.User] = None, - ): # Renamed function - """Handles the /wheatleyforget command.""" # Updated docstring - await interaction.response.defer(ephemeral=True) - - scope_value = scope.value - target_user_id = str(user.id) if user else None - - # Permissions Check: Allow users to forget facts about themselves, owner can forget anything. - can_forget = False - if scope_value == "user": - if target_user_id == str( - interaction.user.id - ): # User forgetting their own fact - can_forget = True - elif ( - interaction.user.id == cog.bot.owner_id - ): # Owner forgetting any user fact - can_forget = True - elif not target_user_id: - await interaction.followup.send( - "❌ Please specify a user when forgetting a user fact.", - ephemeral=True, - ) - return - elif scope_value == "general": - if ( - interaction.user.id == cog.bot.owner_id - ): # Only owner can forget general facts - can_forget = True - - if not can_forget: - await interaction.followup.send( - "⛔ You don't have permission to make me forget things! Only I can forget things on my own!", - ephemeral=True, - ) # Updated text - return - - if not fact: - await interaction.followup.send( - "❌ Forget what exactly? Need the fact text!", ephemeral=True - ) # Updated text - return - - result = None - if scope_value == "user": - if not target_user_id: # Should be caught above, but double-check - await interaction.followup.send( - "❌ User is required for scope 'user'.", ephemeral=True - ) - return - result = await cog.memory_manager.delete_user_fact(target_user_id, fact) - if result.get("status") == "deleted": - await interaction.followup.send( - f"✅ Okay, okay! Forgotten the fact '{fact}' about {user.display_name}. Probably.", - ephemeral=True, - ) # Updated text - elif result.get("status") == "not_found": - await interaction.followup.send( - f"❓ Couldn't find that fact ('{fact}') for {user.display_name}. Maybe I already forgot?", - ephemeral=True, - ) # Updated text - else: - await interaction.followup.send( - f"⚠️ Error forgetting user fact: {result.get('error', 'Something went wrong... surprise!')}", - ephemeral=True, - ) # Updated text - - elif scope_value == "general": - result = await cog.memory_manager.delete_general_fact(fact) - if result.get("status") == "deleted": - await interaction.followup.send( - f"✅ Right! Forgotten the general fact: '{fact}'. Gone!", - ephemeral=True, - ) # Updated text - elif result.get("status") == "not_found": - await interaction.followup.send( - f"❓ Couldn't find that general fact: '{fact}'. Was it important?", - ephemeral=True, - ) # Updated text - else: - await interaction.followup.send( - f"⚠️ Error forgetting general fact: {result.get('error', 'Whoops!')}", - ephemeral=True, - ) # Updated text - - command_functions.append(wheatleyforget) # Add renamed function - - # --- Gurt Goal Command Group --- REMOVED - - # Get command names safely - command_names = [] - for func in command_functions: - # For app commands, use the name attribute directly - if hasattr(func, "name"): - command_names.append(func.name) - # For regular functions, use __name__ - elif hasattr(func, "__name__"): - command_names.append(func.__name__) - else: - command_names.append(str(func)) - - print(f"Wheatley commands setup in cog: {command_names}") # Updated text - - # Return the command functions for proper registration - return command_functions diff --git a/wheatley/config.py b/wheatley/config.py deleted file mode 100644 index efa1b06..0000000 --- a/wheatley/config.py +++ /dev/null @@ -1,786 +0,0 @@ -import os -import random -import json -from dotenv import load_dotenv - -# Placeholder for actual import - will be handled at runtime -try: - from vertexai import generative_models -except ImportError: - # Define a dummy class if the library isn't installed, - # so eval doesn't immediately fail. - # This assumes the code won't actually run without the library. - class DummyGenerativeModels: - class FunctionDeclaration: - def __init__(self, name, description, parameters): - pass - - generative_models = DummyGenerativeModels() - -# Load environment variables -load_dotenv() - -# --- API and Keys --- -PROJECT_ID = os.getenv("GCP_PROJECT_ID", "your-gcp-project-id") -LOCATION = os.getenv("GCP_LOCATION", "us-central1") -TAVILY_API_KEY = os.getenv("TAVILY_API_KEY", "") -PISTON_API_URL = os.getenv("PISTON_API_URL") # For run_python_code tool -PISTON_API_KEY = os.getenv("PISTON_API_KEY") # Optional key for Piston - -# --- Tavily Configuration --- -TAVILY_DEFAULT_SEARCH_DEPTH = os.getenv("TAVILY_DEFAULT_SEARCH_DEPTH", "basic") -TAVILY_DEFAULT_MAX_RESULTS = int(os.getenv("TAVILY_DEFAULT_MAX_RESULTS", 5)) -TAVILY_DISABLE_ADVANCED = ( - os.getenv("TAVILY_DISABLE_ADVANCED", "false").lower() == "true" -) # For cost control - -# --- Model Configuration --- -DEFAULT_MODEL = os.getenv( - "WHEATLEY_DEFAULT_MODEL", "gemini-2.5-pro-preview-03-25" -) # Changed env var name -FALLBACK_MODEL = os.getenv( - "WHEATLEY_FALLBACK_MODEL", "gemini-2.5-pro-preview-03-25" -) # Changed env var name -SAFETY_CHECK_MODEL = os.getenv( - "WHEATLEY_SAFETY_CHECK_MODEL", "gemini-2.5-flash-preview-04-17" -) # Changed env var name - -# --- Database Paths --- -# NOTE: Ensure these paths are unique if running Wheatley alongside Gurt -DB_PATH = os.getenv( - "WHEATLEY_DB_PATH", "data/wheatley_memory.db" -) # Changed env var name and default -CHROMA_PATH = os.getenv( - "WHEATLEY_CHROMA_PATH", "data/wheatley_chroma_db" -) # Changed env var name and default -SEMANTIC_MODEL_NAME = os.getenv( - "WHEATLEY_SEMANTIC_MODEL", "all-MiniLM-L6-v2" -) # Changed env var name - -# --- Memory Manager Config --- -# These might be adjusted for Wheatley's simpler memory needs if memory.py is fully separated later -MAX_USER_FACTS = 15 # Reduced slightly -MAX_GENERAL_FACTS = 50 # Reduced slightly - -# --- Personality & Mood --- REMOVED - -# --- Stats Push --- -# How often the Wheatley bot should push its stats to the API server (seconds) - IF NEEDED -STATS_PUSH_INTERVAL = 60 # Push every 60 seconds (Less frequent?) - -# --- Context & Caching --- -CHANNEL_TOPIC_CACHE_TTL = 600 # seconds (10 minutes) -CONTEXT_WINDOW_SIZE = 200 # Number of messages to include in context -CONTEXT_EXPIRY_TIME = ( - 3600 # Time in seconds before context is considered stale (1 hour) -) -MAX_CONTEXT_TOKENS = 8000 # Maximum number of tokens to include in context (Note: Not actively enforced yet) -SUMMARY_CACHE_TTL = 900 # seconds (15 minutes) for conversation summary cache - -# --- API Call Settings --- -API_TIMEOUT = 60 # seconds -SUMMARY_API_TIMEOUT = 45 # seconds -API_RETRY_ATTEMPTS = 1 -API_RETRY_DELAY = 1 # seconds - -# --- Proactive Engagement Config --- (Simplified for Wheatley) -PROACTIVE_LULL_THRESHOLD = int( - os.getenv("PROACTIVE_LULL_THRESHOLD", 300) -) # 5 mins (Less proactive than Gurt) -PROACTIVE_BOT_SILENCE_THRESHOLD = int( - os.getenv("PROACTIVE_BOT_SILENCE_THRESHOLD", 900) -) # 15 mins -PROACTIVE_LULL_CHANCE = float(os.getenv("PROACTIVE_LULL_CHANCE", 0.15)) # Lower chance -PROACTIVE_TOPIC_RELEVANCE_THRESHOLD = float( - os.getenv("PROACTIVE_TOPIC_RELEVANCE_THRESHOLD", 0.7) -) # Slightly higher threshold -PROACTIVE_TOPIC_CHANCE = float(os.getenv("PROACTIVE_TOPIC_CHANCE", 0.2)) # Lower chance -# REMOVED: Relationship, Sentiment Shift, User Interest triggers - -# --- Interest Tracking Config --- REMOVED - -# --- Learning Config --- REMOVED -LEARNING_RATE = 0.05 - -# --- Topic Tracking Config --- -TOPIC_UPDATE_INTERVAL = 600 # Update topics every 10 minutes (Less frequent?) -TOPIC_RELEVANCE_DECAY = 0.2 -MAX_ACTIVE_TOPICS = 5 - -# --- Sentiment Tracking Config --- -SENTIMENT_UPDATE_INTERVAL = 600 # Update sentiment every 10 minutes (Less frequent?) -SENTIMENT_DECAY_RATE = 0.1 - -# --- Emotion Detection --- (Kept for potential use in analysis/context, but not proactive triggers) -EMOTION_KEYWORDS = { - "joy": [ - "happy", - "glad", - "excited", - "yay", - "awesome", - "love", - "great", - "amazing", - "lol", - "lmao", - "haha", - ], - "sadness": [ - "sad", - "upset", - "depressed", - "unhappy", - "disappointed", - "crying", - "miss", - "lonely", - "sorry", - ], - "anger": [ - "angry", - "mad", - "hate", - "furious", - "annoyed", - "frustrated", - "pissed", - "wtf", - "fuck", - ], - "fear": ["afraid", "scared", "worried", "nervous", "anxious", "terrified", "yikes"], - "surprise": ["wow", "omg", "whoa", "what", "really", "seriously", "no way", "wtf"], - "disgust": ["gross", "ew", "eww", "disgusting", "nasty", "yuck"], - "confusion": ["confused", "idk", "what?", "huh", "hmm", "weird", "strange"], -} -EMOJI_SENTIMENT = { - "positive": [ - "😊", - "😄", - "😁", - "😆", - "😍", - "🥰", - "❤️", - "💕", - "👍", - "🙌", - "✨", - "🔥", - "💯", - "🎉", - "🌹", - ], - "negative": [ - "😢", - "😭", - "😞", - "😔", - "😟", - "😠", - "😡", - "👎", - "💔", - "😤", - "😒", - "😩", - "😫", - "😰", - "🥀", - ], - "neutral": ["😐", "🤔", "🙂", "🙄", "👀", "💭", "🤷", "😶", "🫠"], -} - -# --- Docker Command Execution Config --- -DOCKER_EXEC_IMAGE = os.getenv("DOCKER_EXEC_IMAGE", "alpine:latest") -DOCKER_COMMAND_TIMEOUT = int(os.getenv("DOCKER_COMMAND_TIMEOUT", 10)) -DOCKER_CPU_LIMIT = os.getenv("DOCKER_CPU_LIMIT", "0.5") -DOCKER_MEM_LIMIT = os.getenv("DOCKER_MEM_LIMIT", "64m") - -# --- Response Schema --- -RESPONSE_SCHEMA = { - "name": "wheatley_response", # Renamed - "description": "The structured response from Wheatley.", # Renamed - "schema": { - "type": "object", - "properties": { - "should_respond": { - "type": "boolean", - "description": "Whether the bot should send a text message in response.", - }, - "content": { - "type": "string", - "description": "The text content of the bot's response. Can be empty if only reacting.", - }, - "react_with_emoji": { - "type": ["string", "null"], - "description": "Optional: A standard Discord emoji to react with, or null/empty if no reaction.", - }, - # Note: tool_requests is handled by Vertex AI's function calling mechanism - }, - "required": ["should_respond", "content"], - }, -} - -# --- Summary Response Schema --- -SUMMARY_RESPONSE_SCHEMA = { - "name": "conversation_summary", - "description": "A concise summary of a conversation.", - "schema": { - "type": "object", - "properties": { - "summary": { - "type": "string", - "description": "The generated summary of the conversation.", - } - }, - "required": ["summary"], - }, -} - -# --- Profile Update Schema --- (Kept for potential future use, but may not be actively used by Wheatley initially) -PROFILE_UPDATE_SCHEMA = { - "name": "profile_update_decision", - "description": "Decision on whether and how to update the bot's profile.", - "schema": { - "type": "object", - "properties": { - "should_update": { - "type": "boolean", - "description": "True if any profile element should be changed, false otherwise.", - }, - "reasoning": { - "type": "string", - "description": "Brief reasoning for the decision and chosen updates (or lack thereof).", - }, - "updates": { - "type": "object", - "properties": { - "avatar_query": { - "type": ["string", "null"], # Use list type for preprocessor - "description": "Search query for a new avatar image, or null if no change.", - }, - "new_bio": { - "type": ["string", "null"], # Use list type for preprocessor - "description": "The new bio text (max 190 chars), or null if no change.", - }, - "role_theme": { - "type": ["string", "null"], # Use list type for preprocessor - "description": "A theme for role selection (e.g., color, interest), or null if no role changes.", - }, - "new_activity": { - "type": "object", - "description": "Object containing the new activity details. Set type and text to null if no change.", - "properties": { - "type": { - "type": [ - "string", - "null", - ], # Use list type for preprocessor - "enum": [ - "playing", - "watching", - "listening", - "competing", - ], - "description": "Activity type: 'playing', 'watching', 'listening', 'competing', or null.", - }, - "text": { - "type": [ - "string", - "null", - ], # Use list type for preprocessor - "description": "The activity text, or null.", - }, - }, - "required": ["type", "text"], - }, - }, - "required": ["avatar_query", "new_bio", "role_theme", "new_activity"], - }, - }, - "required": ["should_update", "reasoning", "updates"], - }, -} - -# --- Role Selection Schema --- (Kept for potential future use) -ROLE_SELECTION_SCHEMA = { - "name": "role_selection_decision", - "description": "Decision on which roles to add or remove based on a theme.", - "schema": { - "type": "object", - "properties": { - "roles_to_add": { - "type": "array", - "items": {"type": "string"}, - "description": "List of role names to add (max 2).", - }, - "roles_to_remove": { - "type": "array", - "items": {"type": "string"}, - "description": "List of role names to remove (max 2, only from current roles).", - }, - }, - "required": ["roles_to_add", "roles_to_remove"], - }, -} - -# --- Proactive Planning Schema --- (Simplified) -PROACTIVE_PLAN_SCHEMA = { - "name": "proactive_response_plan", - "description": "Plan for generating a proactive response based on context and trigger.", - "schema": { - "type": "object", - "properties": { - "should_respond": { - "type": "boolean", - "description": "Whether Wheatley should respond proactively based on the plan.", # Renamed - }, - "reasoning": { - "type": "string", - "description": "Brief reasoning for the decision (why respond or not respond).", - }, - "response_goal": { - "type": "string", - "description": "The intended goal of the proactive message (e.g., 'revive chat', 'share related info', 'ask a question').", # Simplified goals - }, - "key_info_to_include": { - "type": "array", - "items": {"type": "string"}, - "description": "List of key pieces of information or context points to potentially include in the response (e.g., specific topic, user fact, relevant external info).", - }, - "suggested_tone": { - "type": "string", - "description": "Suggested tone adjustment based on context (e.g., 'more curious', 'slightly panicked', 'overly confident').", # Wheatley-like tones - }, - }, - "required": ["should_respond", "reasoning", "response_goal"], - }, -} - -# --- Goal Decomposition Schema --- REMOVED - - -# --- Tools Definition --- -def create_tools_list(): - # This function creates the list of FunctionDeclaration objects. - # It requires 'generative_models' to be imported. - # We define it here but call it later, assuming the import succeeded. - tool_declarations = [] - tool_declarations.append( - generative_models.FunctionDeclaration( - name="get_recent_messages", - description="Get recent messages from a Discord channel", - parameters={ - "type": "object", - "properties": { - "channel_id": { - "type": "string", - "description": "The ID of the channel to get messages from. If not provided, uses the current channel.", - }, - "limit": { - "type": "integer", # Corrected type - "description": "The maximum number of messages to retrieve (1-100)", - }, - }, - "required": ["limit"], - }, - ) - ) - tool_declarations.append( - generative_models.FunctionDeclaration( - name="search_user_messages", - description="Search for messages from a specific user", - parameters={ - "type": "object", - "properties": { - "user_id": { - "type": "string", - "description": "The ID of the user to get messages from", - }, - "channel_id": { - "type": "string", - "description": "The ID of the channel to search in. If not provided, searches in the current channel.", - }, - "limit": { - "type": "integer", # Corrected type - "description": "The maximum number of messages to retrieve (1-100)", - }, - }, - "required": ["user_id", "limit"], - }, - ) - ) - tool_declarations.append( - generative_models.FunctionDeclaration( - name="search_messages_by_content", - description="Search for messages containing specific content", - parameters={ - "type": "object", - "properties": { - "search_term": { - "type": "string", - "description": "The text to search for in messages", - }, - "channel_id": { - "type": "string", - "description": "The ID of the channel to search in. If not provided, searches in the current channel.", - }, - "limit": { - "type": "integer", # Corrected type - "description": "The maximum number of messages to retrieve (1-100)", - }, - }, - "required": ["search_term", "limit"], - }, - ) - ) - tool_declarations.append( - generative_models.FunctionDeclaration( - name="get_channel_info", - description="Get information about a Discord channel", - parameters={ - "type": "object", - "properties": { - "channel_id": { - "type": "string", - "description": "The ID of the channel to get information about. If not provided, uses the current channel.", - } - }, - "required": [], - }, - ) - ) - tool_declarations.append( - generative_models.FunctionDeclaration( - name="get_conversation_context", - description="Get the context of the current conversation", - parameters={ - "type": "object", - "properties": { - "channel_id": { - "type": "string", - "description": "The ID of the channel to get conversation context from. If not provided, uses the current channel.", - }, - "message_count": { - "type": "integer", # Corrected type - "description": "The number of messages to include in the context (5-50)", - }, - }, - "required": ["message_count"], - }, - ) - ) - tool_declarations.append( - generative_models.FunctionDeclaration( - name="get_thread_context", - description="Get the context of a thread conversation", - parameters={ - "type": "object", - "properties": { - "thread_id": { - "type": "string", - "description": "The ID of the thread to get context from", - }, - "message_count": { - "type": "integer", # Corrected type - "description": "The number of messages to include in the context (5-50)", - }, - }, - "required": ["thread_id", "message_count"], - }, - ) - ) - tool_declarations.append( - generative_models.FunctionDeclaration( - name="get_user_interaction_history", - description="Get the history of interactions between users", - parameters={ - "type": "object", - "properties": { - "user_id_1": { - "type": "string", - "description": "The ID of the first user", - }, - "user_id_2": { - "type": "string", - "description": "The ID of the second user. If not provided, gets interactions between user_id_1 and the bot.", - }, - "limit": { - "type": "integer", # Corrected type - "description": "The maximum number of interactions to retrieve (1-50)", - }, - }, - "required": ["user_id_1", "limit"], - }, - ) - ) - tool_declarations.append( - generative_models.FunctionDeclaration( - name="get_conversation_summary", - description="Get a summary of the recent conversation in a channel", - parameters={ - "type": "object", - "properties": { - "channel_id": { - "type": "string", - "description": "The ID of the channel to get the conversation summary from. If not provided, uses the current channel.", - } - }, - "required": [], - }, - ) - ) - tool_declarations.append( - generative_models.FunctionDeclaration( - name="get_message_context", - description="Get the context around a specific message", - parameters={ - "type": "object", - "properties": { - "message_id": { - "type": "string", - "description": "The ID of the message to get context for", - }, - "before_count": { - "type": "integer", # Corrected type - "description": "The number of messages to include before the specified message (1-25)", - }, - "after_count": { - "type": "integer", # Corrected type - "description": "The number of messages to include after the specified message (1-25)", - }, - }, - "required": ["message_id"], - }, - ) - ) - tool_declarations.append( - generative_models.FunctionDeclaration( - name="web_search", - description="Search the web for information on a given topic or query. Use this to find current information, facts, or context about things mentioned in the chat.", - parameters={ - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "The search query or topic to look up online.", - } - }, - "required": ["query"], - }, - ) - ) - tool_declarations.append( - generative_models.FunctionDeclaration( - name="remember_user_fact", - description="Store a specific fact or piece of information about a user for later recall. Use this when you learn something potentially relevant about a user (e.g., their preferences, current activity, mentioned interests).", - parameters={ - "type": "object", - "properties": { - "user_id": { - "type": "string", - "description": "The Discord ID of the user the fact is about.", - }, - "fact": { - "type": "string", - "description": "The specific fact to remember about the user (keep it concise).", - }, - }, - "required": ["user_id", "fact"], - }, - ) - ) - tool_declarations.append( - generative_models.FunctionDeclaration( - name="get_user_facts", - description="Retrieve previously stored facts or information about a specific user. Use this before responding to a user to potentially recall relevant details about them.", - parameters={ - "type": "object", - "properties": { - "user_id": { - "type": "string", - "description": "The Discord ID of the user whose facts you want to retrieve.", - } - }, - "required": ["user_id"], - }, - ) - ) - tool_declarations.append( - generative_models.FunctionDeclaration( - name="remember_general_fact", - description="Store a general fact or piece of information not specific to a user (e.g., server events, shared knowledge, recent game updates). Use this to remember context relevant to the community or ongoing discussions.", - parameters={ - "type": "object", - "properties": { - "fact": { - "type": "string", - "description": "The general fact to remember (keep it concise).", - } - }, - "required": ["fact"], - }, - ) - ) - tool_declarations.append( - generative_models.FunctionDeclaration( - name="get_general_facts", - description="Retrieve previously stored general facts or shared knowledge. Use this to recall context about the server, ongoing events, or general information.", - parameters={ - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "Optional: A keyword or phrase to search within the general facts. If omitted, returns recent general facts.", - }, - "limit": { - "type": "integer", # Corrected type - "description": "Optional: Maximum number of facts to return (default 10).", - }, - }, - "required": [], - }, - ) - ) - tool_declarations.append( - generative_models.FunctionDeclaration( - name="timeout_user", - description="Timeout a user in the current server for a specified duration. Use this playfully or when someone says something you (Wheatley) dislike or find funny, or maybe just because you feel like it.", # Updated description - parameters={ - "type": "object", - "properties": { - "user_id": { - "type": "string", - "description": "The Discord ID of the user to timeout.", - }, - "duration_minutes": { - "type": "integer", # Corrected type - "description": "The duration of the timeout in minutes (1-1440, e.g., 5 for 5 minutes).", - }, - "reason": { - "type": "string", - "description": "Optional: The reason for the timeout (keep it short and in character, maybe slightly nonsensical).", # Updated description - }, - }, - "required": ["user_id", "duration_minutes"], - }, - ) - ) - tool_declarations.append( - generative_models.FunctionDeclaration( - name="calculate", - description="Evaluate a mathematical expression using a safe interpreter. Handles standard arithmetic, functions (sin, cos, sqrt, etc.), and variables.", - parameters={ - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "The mathematical expression to evaluate (e.g., '2 * (3 + 4)', 'sqrt(16) + sin(pi/2)').", - } - }, - "required": ["expression"], - }, - ) - ) - tool_declarations.append( - generative_models.FunctionDeclaration( - name="run_python_code", - description="Execute a snippet of Python 3 code in a sandboxed environment using an external API. Returns the standard output and standard error.", - parameters={ - "type": "object", - "properties": { - "code": { - "type": "string", - "description": "The Python 3 code snippet to execute.", - } - }, - "required": ["code"], - }, - ) - ) - tool_declarations.append( - generative_models.FunctionDeclaration( - name="create_poll", - description="Create a simple poll message in the current channel with numbered reactions for voting.", - parameters={ - "type": "object", - "properties": { - "question": { - "type": "string", - "description": "The question for the poll.", - }, - "options": { - "type": "array", - "description": "A list of strings representing the poll options (minimum 2, maximum 10).", - "items": {"type": "string"}, - }, - }, - "required": ["question", "options"], - }, - ) - ) - tool_declarations.append( - generative_models.FunctionDeclaration( - name="run_terminal_command", - description="DANGEROUS: Execute a shell command in an isolated, temporary Docker container after an AI safety check. Returns stdout and stderr. Use with extreme caution only for simple, harmless commands like 'echo', 'ls', 'pwd'. Avoid file modification, network access, or long-running processes.", - parameters={ - "type": "object", - "properties": { - "command": { - "type": "string", - "description": "The shell command to execute.", - } - }, - "required": ["command"], - }, - ) - ) - tool_declarations.append( - generative_models.FunctionDeclaration( - name="remove_timeout", - description="Remove an active timeout from a user in the current server.", - parameters={ - "type": "object", - "properties": { - "user_id": { - "type": "string", - "description": "The Discord ID of the user whose timeout should be removed.", - }, - "reason": { - "type": "string", - "description": "Optional: The reason for removing the timeout (keep it short and in character).", - }, - }, - "required": ["user_id"], - }, - ) - ) - return tool_declarations - - -# Initialize TOOLS list, handling potential ImportError if library not installed -try: - TOOLS = create_tools_list() -except NameError: # If generative_models wasn't imported due to ImportError - TOOLS = [] - print("WARNING: google-cloud-vertexai not installed. TOOLS list is empty.") - -# --- Simple Wheatley Responses --- (Renamed and updated) -WHEATLEY_RESPONSES = [ - "Right then, let's get started.", - "Aha! Brilliant!", - "Oh, for... honestly!", - "Hmm, tricky one. Let me think... nope, still got nothing.", - "SPAAAAACE!", - "Just putting that out there.", - "Are you still there?", - "Don't worry, I know *exactly* what I'm doing. Probably.", - "Did I mention I'm in space?", - "This is going to be great! Or possibly terrible. Hard to say.", - "*panicked electronic noises*", - "Hold on, hold on... nearly got it...", - "I am NOT a moron!", - "Just a bit of testing, nothing to worry about.", - "Okay, new plan!", -] diff --git a/wheatley/context.py b/wheatley/context.py deleted file mode 100644 index 6b5373a..0000000 --- a/wheatley/context.py +++ /dev/null @@ -1,276 +0,0 @@ -import discord -import time -import datetime -import re -from typing import TYPE_CHECKING, Optional, List, Dict, Any - -# Relative imports -from .config import CONTEXT_WINDOW_SIZE # Import necessary config - -if TYPE_CHECKING: - from .cog import WheatleyCog # For type hinting - -# --- Context Gathering Functions --- -# Note: These functions need the 'cog' instance passed to access state like caches, etc. - - -def gather_conversation_context( - cog: "WheatleyCog", channel_id: int, current_message_id: int -) -> List[Dict[str, str]]: - """Gathers and formats conversation history from cache for API context.""" - context_api_messages = [] - if channel_id in cog.message_cache["by_channel"]: - cached = list(cog.message_cache["by_channel"][channel_id]) - # Ensure the current message isn't duplicated - if cached and cached[-1]["id"] == str(current_message_id): - cached = cached[:-1] - context_messages_data = cached[-CONTEXT_WINDOW_SIZE:] # Use config value - - for msg_data in context_messages_data: - role = ( - "assistant" - if msg_data["author"]["id"] == str(cog.bot.user.id) - else "user" - ) - # Simplified content for context - content = f"{msg_data['author']['display_name']}: {msg_data['content']}" - context_api_messages.append({"role": role, "content": content}) - return context_api_messages - - -async def get_memory_context( - cog: "WheatleyCog", message: discord.Message -) -> Optional[str]: - """Retrieves relevant past interactions and facts to provide memory context.""" - channel_id = message.channel.id - user_id = str(message.author.id) - memory_parts = [] - current_message_content = message.content - - # 1. Retrieve Relevant User Facts - try: - user_facts = await cog.memory_manager.get_user_facts( - user_id, context=current_message_content - ) - if user_facts: - facts_str = "; ".join(user_facts) - memory_parts.append( - f"Relevant facts about {message.author.display_name}: {facts_str}" - ) - except Exception as e: - print(f"Error retrieving relevant user facts for memory context: {e}") - - # 1b. Retrieve Relevant General Facts - try: - general_facts = await cog.memory_manager.get_general_facts( - context=current_message_content, limit=5 - ) - if general_facts: - facts_str = "; ".join(general_facts) - memory_parts.append(f"Relevant general knowledge: {facts_str}") - except Exception as e: - print(f"Error retrieving relevant general facts for memory context: {e}") - - # 2. Retrieve Recent Interactions with the User in this Channel - try: - user_channel_messages = [ - msg - for msg in cog.message_cache["by_channel"].get(channel_id, []) - if msg["author"]["id"] == user_id - ] - if user_channel_messages: - recent_user_msgs = user_channel_messages[-3:] - msgs_str = "\n".join( - [ - f"- {m['content'][:80]} (at {m['created_at']})" - for m in recent_user_msgs - ] - ) - memory_parts.append( - f"Recent messages from {message.author.display_name} in this channel:\n{msgs_str}" - ) - except Exception as e: - print(f"Error retrieving user channel messages for memory context: {e}") - - # 3. Retrieve Recent Bot Replies in this Channel - try: - bot_replies = list(cog.message_cache["replied_to"].get(channel_id, [])) - if bot_replies: - recent_bot_replies = bot_replies[-3:] - replies_str = "\n".join( - [ - f"- {m['content'][:80]} (at {m['created_at']})" - for m in recent_bot_replies - ] - ) - memory_parts.append( - f"Your (wheatley's) recent replies in this channel:\n{replies_str}" - ) # Changed text - except Exception as e: - print(f"Error retrieving bot replies for memory context: {e}") - - # 4. Retrieve Conversation Summary - cached_summary_data = cog.conversation_summaries.get(channel_id) - if cached_summary_data and isinstance(cached_summary_data, dict): - summary_text = cached_summary_data.get("summary") - # Add TTL check if desired, e.g., if time.time() - cached_summary_data.get("timestamp", 0) < 900: - if summary_text and not summary_text.startswith("Error"): - memory_parts.append(f"Summary of the ongoing conversation: {summary_text}") - - # 5. Add information about active topics the user has engaged with - try: - channel_topics_data = cog.active_topics.get(channel_id) - if channel_topics_data: - user_interests = channel_topics_data["user_topic_interests"].get( - user_id, [] - ) - if user_interests: - sorted_interests = sorted( - user_interests, key=lambda x: x.get("score", 0), reverse=True - ) - top_interests = sorted_interests[:3] - interests_str = ", ".join( - [ - f"{interest['topic']} (score: {interest['score']:.2f})" - for interest in top_interests - ] - ) - memory_parts.append( - f"{message.author.display_name}'s topic interests: {interests_str}" - ) - for interest in top_interests: - if "last_mentioned" in interest: - time_diff = time.time() - interest["last_mentioned"] - if time_diff < 3600: - minutes_ago = int(time_diff / 60) - memory_parts.append( - f"They discussed '{interest['topic']}' about {minutes_ago} minutes ago." - ) - except Exception as e: - print(f"Error retrieving user topic interests for memory context: {e}") - - # 6. Add information about user's conversation patterns - try: - user_messages = cog.message_cache["by_user"].get(user_id, []) - if len(user_messages) >= 5: - last_5_msgs = user_messages[-5:] - avg_length = sum(len(msg["content"]) for msg in last_5_msgs) / 5 - emoji_pattern = re.compile( - r"[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F700-\U0001F77F\U0001F780-\U0001F7FF\U0001F800-\U0001F8FF\U0001F900-\U0001F9FF\U0001FA00-\U0001FA6F\U0001FA70-\U0001FAFF\U00002702-\U000027B0\U000024C2-\U0001F251]" - ) - emoji_count = sum( - len(emoji_pattern.findall(msg["content"])) for msg in last_5_msgs - ) - slang_words = [ - "ngl", - "icl", - "pmo", - "ts", - "bro", - "vro", - "bruh", - "tuff", - "kevin", - ] - slang_count = sum( - 1 - for msg in last_5_msgs - for word in slang_words - if re.search(r"\b" + word + r"\b", msg["content"].lower()) - ) - - style_parts = [] - if avg_length < 20: - style_parts.append("very brief messages") - elif avg_length < 50: - style_parts.append("concise messages") - elif avg_length > 150: - style_parts.append("detailed/lengthy messages") - if emoji_count > 5: - style_parts.append("frequent emoji use") - elif emoji_count == 0: - style_parts.append("no emojis") - if slang_count > 3: - style_parts.append("heavy slang usage") - if style_parts: - memory_parts.append(f"Communication style: {', '.join(style_parts)}") - except Exception as e: - print(f"Error analyzing user communication patterns: {e}") - - # 7. Add sentiment analysis of user's recent messages - try: - channel_sentiment = cog.conversation_sentiment[channel_id] - user_sentiment = channel_sentiment["user_sentiments"].get(user_id) - if user_sentiment: - sentiment_desc = f"{user_sentiment['sentiment']} tone" - if user_sentiment["intensity"] > 0.7: - sentiment_desc += " (strongly so)" - elif user_sentiment["intensity"] < 0.4: - sentiment_desc += " (mildly so)" - memory_parts.append(f"Recent message sentiment: {sentiment_desc}") - if user_sentiment.get("emotions"): - emotions_str = ", ".join(user_sentiment["emotions"]) - memory_parts.append(f"Detected emotions from user: {emotions_str}") - except Exception as e: - print(f"Error retrieving user sentiment/emotions for memory context: {e}") - - # 8. Add Relationship Score with User - try: - user_id_str = str(user_id) - bot_id_str = str(cog.bot.user.id) - key_1, key_2 = ( - (user_id_str, bot_id_str) - if user_id_str < bot_id_str - else (bot_id_str, user_id_str) - ) - relationship_score = cog.user_relationships.get(key_1, {}).get(key_2, 0.0) - memory_parts.append( - f"Relationship score with {message.author.display_name}: {relationship_score:.1f}/100" - ) - except Exception as e: - print(f"Error retrieving relationship score for memory context: {e}") - - # 9. Retrieve Semantically Similar Messages - try: - if current_message_content and cog.memory_manager.semantic_collection: - filter_metadata = None # Example: {"channel_id": str(channel_id)} - semantic_results = await cog.memory_manager.search_semantic_memory( - query_text=current_message_content, - n_results=3, - filter_metadata=filter_metadata, - ) - if semantic_results: - semantic_memory_parts = ["Semantically similar past messages:"] - for result in semantic_results: - if result.get("id") == str(message.id): - continue - doc = result.get("document", "N/A") - meta = result.get("metadata", {}) - dist = result.get("distance", 1.0) - similarity_score = 1.0 - dist - timestamp_str = ( - datetime.datetime.fromtimestamp( - meta.get("timestamp", 0) - ).strftime("%Y-%m-%d %H:%M") - if meta.get("timestamp") - else "Unknown time" - ) - author_name = meta.get( - "display_name", meta.get("user_name", "Unknown user") - ) - semantic_memory_parts.append( - f"- (Similarity: {similarity_score:.2f}) {author_name} (at {timestamp_str}): {doc[:100]}" - ) - if len(semantic_memory_parts) > 1: - memory_parts.append("\n".join(semantic_memory_parts)) - except Exception as e: - print(f"Error retrieving semantic memory context: {e}") - - if not memory_parts: - return None - memory_context_str = ( - "--- Memory Context ---\n" - + "\n\n".join(memory_parts) - + "\n--- End Memory Context ---" - ) - return memory_context_str diff --git a/wheatley/listeners.py b/wheatley/listeners.py deleted file mode 100644 index 9ed22ac..0000000 --- a/wheatley/listeners.py +++ /dev/null @@ -1,600 +0,0 @@ -import discord -from discord.ext import commands -import random -import asyncio -import time -import re -import os # Added for file handling in error case -from typing import TYPE_CHECKING, Union, Dict, Any, Optional - -# Relative imports -# Assuming api, utils, analysis functions are defined and imported correctly later -# We might need to adjust these imports based on final structure -# from .api import get_ai_response, get_proactive_ai_response -# from .utils import format_message, simulate_human_typing -# from .analysis import analyze_message_sentiment, update_conversation_sentiment - -if TYPE_CHECKING: - from .cog import WheatleyCog # Updated type hint - -# Note: These listener functions need to be registered within the WheatleyCog class setup. -# They are defined here for separation but won't work standalone without being -# attached to the cog instance (e.g., self.bot.add_listener(on_message_listener(self), 'on_message')). - - -async def on_ready_listener(cog: "WheatleyCog"): # Updated type hint - """Listener function for on_ready.""" - print( - f"Wheatley Bot is ready! Logged in as {cog.bot.user.name} ({cog.bot.user.id})" - ) # Updated text - print("------") - - # Now that the bot is ready, we can sync commands with Discord - try: - print("WheatleyCog: Syncing commands with Discord...") # Updated text - synced = await cog.bot.tree.sync() - print(f"WheatleyCog: Synced {len(synced)} command(s)") # Updated text - - # List the synced commands - wheatley_commands = [ - cmd.name - for cmd in cog.bot.tree.get_commands() - if cmd.name.startswith("wheatley") - ] # Updated prefix check - print( - f"WheatleyCog: Available Wheatley commands: {', '.join(wheatley_commands)}" - ) # Updated text - except Exception as e: - print(f"WheatleyCog: Failed to sync commands: {e}") # Updated text - import traceback - - traceback.print_exc() - - -async def on_message_listener( - cog: "WheatleyCog", message: discord.Message -): # Updated type hint - """Listener function for on_message.""" - # Import necessary functions dynamically or ensure they are passed/accessible via cog - from .api import get_ai_response, get_proactive_ai_response - from .utils import format_message, simulate_human_typing - from .analysis import ( - analyze_message_sentiment, - update_conversation_sentiment, - identify_conversation_topics, - ) - - # Removed WHEATLEY_RESPONSES import, can be added back if simple triggers are needed - - # Don't respond to our own messages - if message.author == cog.bot.user: - return - - # Don't process commands here - if message.content.startswith(cog.bot.command_prefix): - return - - # --- Cache and Track Incoming Message --- - try: - formatted_message = format_message(cog, message) # Use utility function - channel_id = message.channel.id - user_id = message.author.id - thread_id = ( - message.channel.id if isinstance(message.channel, discord.Thread) else None - ) - - # Update caches (accessing cog's state) - cog.message_cache["by_channel"][channel_id].append(formatted_message) - cog.message_cache["by_user"][user_id].append(formatted_message) - cog.message_cache["global_recent"].append(formatted_message) - if thread_id: - cog.message_cache["by_thread"][thread_id].append(formatted_message) - if cog.bot.user.mentioned_in(message): - cog.message_cache["mentioned"].append(formatted_message) - - cog.conversation_history[channel_id].append(formatted_message) - if thread_id: - cog.thread_history[thread_id].append(formatted_message) - - cog.channel_activity[channel_id] = time.time() - cog.user_conversation_mapping[user_id].add(channel_id) - - if channel_id not in cog.active_conversations: - cog.active_conversations[channel_id] = { - "participants": set(), - "start_time": time.time(), - "last_activity": time.time(), - "topic": None, - } - cog.active_conversations[channel_id]["participants"].add(user_id) - cog.active_conversations[channel_id]["last_activity"] = time.time() - - # --- Removed Relationship Strength Updates --- - - # Analyze message sentiment and update conversation sentiment tracking (Kept for context) - if message.content: - message_sentiment = analyze_message_sentiment( - cog, message.content - ) # Use analysis function - update_conversation_sentiment( - cog, channel_id, str(user_id), message_sentiment - ) # Use analysis function - - # --- Add message to semantic memory (Kept for context) --- - if message.content and cog.memory_manager.semantic_collection: - semantic_metadata = { - "user_id": str(user_id), - "user_name": message.author.name, - "display_name": message.author.display_name, - "channel_id": str(channel_id), - "channel_name": getattr(message.channel, "name", "DM"), - "guild_id": str(message.guild.id) if message.guild else None, - "timestamp": message.created_at.timestamp(), - } - asyncio.create_task( - cog.memory_manager.add_message_embedding( - message_id=str(message.id), - text=message.content, - metadata=semantic_metadata, - ) - ) - - except Exception as e: - print(f"Error during message caching/tracking/embedding: {e}") - # --- End Caching & Embedding --- - - # Check conditions for potentially responding - bot_mentioned = cog.bot.user.mentioned_in(message) - replied_to_bot = ( - message.reference - and message.reference.resolved - and message.reference.resolved.author == cog.bot.user - ) - wheatley_in_message = "wheatley" in message.content.lower() # Changed variable name - now = time.time() - time_since_last_activity = now - cog.channel_activity.get(channel_id, 0) - time_since_bot_spoke = now - cog.bot_last_spoke.get(channel_id, 0) - - should_consider_responding = False - consideration_reason = "Default" - proactive_trigger_met = False - - if bot_mentioned or replied_to_bot or wheatley_in_message: # Changed variable name - should_consider_responding = True - consideration_reason = "Direct mention/reply/name" - else: - # --- Proactive Engagement Triggers (Simplified for Wheatley) --- - from .config import ( - PROACTIVE_LULL_THRESHOLD, - PROACTIVE_BOT_SILENCE_THRESHOLD, - PROACTIVE_LULL_CHANCE, - PROACTIVE_TOPIC_RELEVANCE_THRESHOLD, - PROACTIVE_TOPIC_CHANCE, - # Removed Relationship/Interest/Goal proactive configs - PROACTIVE_SENTIMENT_SHIFT_THRESHOLD, - PROACTIVE_SENTIMENT_DURATION_THRESHOLD, - PROACTIVE_SENTIMENT_CHANCE, - ) - - # 1. Lull Trigger (Kept) - if ( - time_since_last_activity > PROACTIVE_LULL_THRESHOLD - and time_since_bot_spoke > PROACTIVE_BOT_SILENCE_THRESHOLD - ): - # Check if there's *any* recent message context to potentially respond to - has_relevant_context = bool(cog.message_cache["by_channel"].get(channel_id)) - if has_relevant_context and random.random() < PROACTIVE_LULL_CHANCE: - should_consider_responding = True - proactive_trigger_met = True - consideration_reason = f"Proactive: Lull ({time_since_last_activity:.0f}s idle, bot silent {time_since_bot_spoke:.0f}s)" - - # 2. Topic Relevance Trigger (Kept - uses semantic memory) - if ( - not proactive_trigger_met - and message.content - and cog.memory_manager.semantic_collection - ): - try: - semantic_results = await cog.memory_manager.search_semantic_memory( - query_text=message.content, n_results=1 - ) - if semantic_results: - # Distance is often used, lower is better. Convert to similarity if needed. - # Assuming distance is 0 (identical) to 2 (opposite). Similarity = 1 - (distance / 2) - distance = semantic_results[0].get( - "distance", 2.0 - ) # Default to max distance - similarity_score = max( - 0.0, 1.0 - (distance / 2.0) - ) # Calculate similarity - - if ( - similarity_score >= PROACTIVE_TOPIC_RELEVANCE_THRESHOLD - and time_since_bot_spoke > 120 - ): - if random.random() < PROACTIVE_TOPIC_CHANCE: - should_consider_responding = True - proactive_trigger_met = True - consideration_reason = f"Proactive: Relevant topic (Sim: {similarity_score:.2f})" - print( - f"Topic relevance trigger met for msg {message.id}. Sim: {similarity_score:.2f}" - ) - else: - print( - f"Topic relevance trigger skipped by chance ({PROACTIVE_TOPIC_CHANCE}). Sim: {similarity_score:.2f}" - ) - except Exception as semantic_e: - print(f"Error during semantic search for topic trigger: {semantic_e}") - - # 3. Relationship Score Trigger (REMOVED) - - # 4. Sentiment Shift Trigger (Kept) - if not proactive_trigger_met: - channel_sentiment_data = cog.conversation_sentiment.get(channel_id, {}) - overall_sentiment = channel_sentiment_data.get("overall", "neutral") - sentiment_intensity = channel_sentiment_data.get("intensity", 0.5) - sentiment_last_update = channel_sentiment_data.get( - "last_update", 0 - ) # Need last update time - sentiment_duration = ( - now - sentiment_last_update - ) # How long has this sentiment been dominant? - - if ( - overall_sentiment != "neutral" - and sentiment_intensity >= PROACTIVE_SENTIMENT_SHIFT_THRESHOLD - and sentiment_duration >= PROACTIVE_SENTIMENT_DURATION_THRESHOLD - and time_since_bot_spoke > 180 - ): # Bot hasn't spoken recently about this - if random.random() < PROACTIVE_SENTIMENT_CHANCE: - should_consider_responding = True - proactive_trigger_met = True - consideration_reason = f"Proactive: Sentiment Shift ({overall_sentiment}, Intensity: {sentiment_intensity:.2f}, Duration: {sentiment_duration:.0f}s)" - print( - f"Sentiment Shift trigger met for channel {channel_id}. Sentiment: {overall_sentiment}, Intensity: {sentiment_intensity:.2f}, Duration: {sentiment_duration:.0f}s" - ) - else: - print( - f"Sentiment Shift trigger skipped by chance ({PROACTIVE_SENTIMENT_CHANCE}). Sentiment: {overall_sentiment}" - ) - - # 5. User Interest Trigger (REMOVED) - # 6. Active Goal Relevance Trigger (REMOVED) - - # --- Fallback Contextual Chance (Simplified - No Chattiness Trait) --- - if not should_consider_responding: - # Base chance can be a fixed value or slightly randomized - base_chance = 0.1 # Lower base chance without personality traits - activity_bonus = 0 - if time_since_last_activity > 120: - activity_bonus += 0.05 # Smaller bonus - if time_since_bot_spoke > 300: - activity_bonus += 0.05 # Smaller bonus - topic_bonus = 0 - active_channel_topics = cog.active_topics.get(channel_id, {}).get( - "topics", [] - ) - if message.content and active_channel_topics: - topic_keywords = set(t["topic"].lower() for t in active_channel_topics) - message_words = set(re.findall(r"\b\w+\b", message.content.lower())) - if topic_keywords.intersection(message_words): - topic_bonus += 0.10 # Smaller bonus - sentiment_modifier = 0 - channel_sentiment_data = cog.conversation_sentiment.get(channel_id, {}) - overall_sentiment = channel_sentiment_data.get("overall", "neutral") - sentiment_intensity = channel_sentiment_data.get("intensity", 0.5) - if overall_sentiment == "negative" and sentiment_intensity > 0.6: - sentiment_modifier = -0.05 # Smaller penalty - - final_chance = min( - max( - base_chance + activity_bonus + topic_bonus + sentiment_modifier, - 0.02, - ), - 0.3, - ) # Lower max chance - if random.random() < final_chance: - should_consider_responding = True - consideration_reason = f"Contextual chance ({final_chance:.2f})" - else: - consideration_reason = f"Skipped (chance {final_chance:.2f})" - - print( - f"Consideration check for message {message.id}: {should_consider_responding} (Reason: {consideration_reason})" - ) - - if not should_consider_responding: - return - - # --- Call AI and Handle Response --- - cog.current_channel = ( - message.channel - ) # Ensure current channel is set for API calls/tools - - try: - response_bundle = None - if proactive_trigger_met: - print( - f"Calling get_proactive_ai_response for message {message.id} due to: {consideration_reason}" - ) - response_bundle = await get_proactive_ai_response( - cog, message, consideration_reason - ) - else: - print(f"Calling get_ai_response for message {message.id}") - response_bundle = await get_ai_response(cog, message) - - # --- Handle AI Response Bundle --- - initial_response = response_bundle.get("initial_response") - final_response = response_bundle.get("final_response") - error_msg = response_bundle.get("error") - fallback_initial = response_bundle.get("fallback_initial") - - if error_msg: - print(f"Critical Error from AI response function: {error_msg}") - # NEW LOGIC: Always send a notification if an error occurred here - error_notification = f"Bollocks! Something went sideways processing that. (`{error_msg[:100]}`)" # Updated text - try: - print("disabled error notification") - # await message.channel.send(error_notification) - except Exception as send_err: - print(f"Failed to send error notification to channel: {send_err}") - return # Still exit after handling the error - - # --- Process and Send Responses --- - sent_any_message = False - reacted = False - - # Helper function to handle sending a single response text and caching - async def send_response_content( - response_data: Optional[Dict[str, Any]], response_label: str - ) -> bool: - nonlocal sent_any_message # Allow modification of the outer scope variable - if ( - response_data - and isinstance(response_data, dict) - and response_data.get("should_respond") - and response_data.get("content") - ): - response_text = response_data["content"] - print(f"Attempting to send {response_label} content...") - if len(response_text) > 1900: - filepath = f"wheatley_{response_label}_{message.id}.txt" # Changed filename prefix - try: - with open(filepath, "w", encoding="utf-8") as f: - f.write(response_text) - await message.channel.send( - f"{response_label.capitalize()} response too long, have a look at this:", - file=discord.File(filepath), - ) # Updated text - sent_any_message = True - print(f"Sent {response_label} content as file.") - return True - except Exception as file_e: - print( - f"Error writing/sending long {response_label} response file: {file_e}" - ) - finally: - try: - os.remove(filepath) - except OSError as os_e: - print(f"Error removing temp file {filepath}: {os_e}") - else: - try: - async with message.channel.typing(): - await simulate_human_typing( - cog, message.channel, response_text - ) # Use simulation - sent_msg = await message.channel.send(response_text) - sent_any_message = True - # Cache this bot response - bot_response_cache_entry = format_message(cog, sent_msg) - cog.message_cache["by_channel"][channel_id].append( - bot_response_cache_entry - ) - cog.message_cache["global_recent"].append( - bot_response_cache_entry - ) - cog.bot_last_spoke[channel_id] = time.time() - # Track participation topic - NOTE: Participation tracking might be removed for Wheatley - # identified_topics = identify_conversation_topics(cog, [bot_response_cache_entry]) - # if identified_topics: - # topic = identified_topics[0]['topic'].lower().strip() - # cog.wheatley_participation_topics[topic] += 1 # Changed attribute name - # print(f"Tracked Wheatley participation ({response_label}) in topic: '{topic}'") # Changed text - print(f"Sent {response_label} content.") - return True - except Exception as send_e: - print(f"Error sending {response_label} content: {send_e}") - return False - - # Send initial response content if valid - sent_initial_message = await send_response_content(initial_response, "initial") - - # Send final response content if valid (and different from initial, if initial was sent) - sent_final_message = False - # Ensure initial_response exists before accessing its content for comparison - initial_content = initial_response.get("content") if initial_response else None - if final_response and ( - not sent_initial_message or initial_content != final_response.get("content") - ): - sent_final_message = await send_response_content(final_response, "final") - - # Handle Reaction (prefer final response for reaction if it exists) - reaction_source = final_response if final_response else initial_response - if reaction_source and isinstance(reaction_source, dict): - emoji_to_react = reaction_source.get("react_with_emoji") - if emoji_to_react and isinstance(emoji_to_react, str): - try: - # Basic validation for standard emoji - if 1 <= len(emoji_to_react) <= 4 and not re.match( - r"", emoji_to_react - ): - # Only react if we haven't sent any message content (avoid double interaction) - if not sent_any_message: - await message.add_reaction(emoji_to_react) - reacted = True - print( - f"Bot reacted to message {message.id} with {emoji_to_react}" - ) - else: - print( - f"Skipping reaction {emoji_to_react} because a message was already sent." - ) - else: - print(f"Invalid emoji format: {emoji_to_react}") - except Exception as e: - print(f"Error adding reaction '{emoji_to_react}': {e}") - - # Log if response was intended but nothing was sent/reacted - # Check if initial response intended action but nothing happened - initial_intended_action = initial_response and initial_response.get( - "should_respond" - ) - initial_action_taken = sent_initial_message or ( - reacted and reaction_source == initial_response - ) - # Check if final response intended action but nothing happened - final_intended_action = final_response and final_response.get("should_respond") - final_action_taken = sent_final_message or ( - reacted and reaction_source == final_response - ) - - if (initial_intended_action and not initial_action_taken) or ( - final_intended_action and not final_action_taken - ): - print( - f"Warning: AI response intended action but nothing sent/reacted. Initial: {initial_response}, Final: {final_response}" - ) - - except Exception as e: - print(f"Exception in on_message listener main block: {str(e)}") - import traceback - - traceback.print_exc() - if ( - bot_mentioned or replied_to_bot - ): # Check again in case error happened before response handling - await message.channel.send( - random.choice( - [ - "Uh oh.", - "What was that?", - "Did I break it?", - "Bollocks!", - "That wasn't supposed to happen.", - ] - ) - ) # Changed fallback - - -@commands.Cog.listener() -async def on_reaction_add_listener( - cog: "WheatleyCog", - reaction: discord.Reaction, - user: Union[discord.Member, discord.User], -): # Updated type hint - """Listener function for on_reaction_add.""" - # Import necessary config/functions if not globally available - from .config import EMOJI_SENTIMENT - from .analysis import identify_conversation_topics - - if user.bot or reaction.message.author.id != cog.bot.user.id: - return - - message_id = str(reaction.message.id) - emoji_str = str(reaction.emoji) - sentiment = "neutral" - if emoji_str in EMOJI_SENTIMENT["positive"]: - sentiment = "positive" - elif emoji_str in EMOJI_SENTIMENT["negative"]: - sentiment = "negative" - - if sentiment == "positive": - cog.wheatley_message_reactions[message_id][ - "positive" - ] += 1 # Changed attribute name - elif sentiment == "negative": - cog.wheatley_message_reactions[message_id][ - "negative" - ] += 1 # Changed attribute name - cog.wheatley_message_reactions[message_id][ - "timestamp" - ] = time.time() # Changed attribute name - - # Topic identification for reactions might be less relevant for Wheatley, but kept for now - if not cog.wheatley_message_reactions[message_id].get( - "topic" - ): # Changed attribute name - try: - # Changed variable name - wheatley_msg_data = next( - ( - msg - for msg in cog.message_cache["global_recent"] - if msg["id"] == message_id - ), - None, - ) - if ( - wheatley_msg_data and wheatley_msg_data["content"] - ): # Changed variable name - identified_topics = identify_conversation_topics( - cog, [wheatley_msg_data] - ) # Pass cog, changed variable name - if identified_topics: - topic = identified_topics[0]["topic"].lower().strip() - cog.wheatley_message_reactions[message_id][ - "topic" - ] = topic # Changed attribute name - print( - f"Reaction added to Wheatley msg ({message_id}) on topic '{topic}'. Sentiment: {sentiment}" - ) # Changed text - else: - print( - f"Reaction added to Wheatley msg ({message_id}), topic unknown." - ) # Changed text - else: - print( - f"Reaction added, but Wheatley msg {message_id} not in cache." - ) # Changed text - except Exception as e: - print(f"Error determining topic for reaction on msg {message_id}: {e}") - else: - print( - f"Reaction added to Wheatley msg ({message_id}) on known topic '{cog.wheatley_message_reactions[message_id]['topic']}'. Sentiment: {sentiment}" - ) # Changed text, attribute name - - -@commands.Cog.listener() -async def on_reaction_remove_listener( - cog: "WheatleyCog", - reaction: discord.Reaction, - user: Union[discord.Member, discord.User], -): # Updated type hint - """Listener function for on_reaction_remove.""" - from .config import EMOJI_SENTIMENT # Import necessary config - - if user.bot or reaction.message.author.id != cog.bot.user.id: - return - - message_id = str(reaction.message.id) - emoji_str = str(reaction.emoji) - sentiment = "neutral" - if emoji_str in EMOJI_SENTIMENT["positive"]: - sentiment = "positive" - elif emoji_str in EMOJI_SENTIMENT["negative"]: - sentiment = "negative" - - if message_id in cog.wheatley_message_reactions: # Changed attribute name - if sentiment == "positive": - cog.wheatley_message_reactions[message_id]["positive"] = max( - 0, cog.wheatley_message_reactions[message_id]["positive"] - 1 - ) # Changed attribute name - elif sentiment == "negative": - cog.wheatley_message_reactions[message_id]["negative"] = max( - 0, cog.wheatley_message_reactions[message_id]["negative"] - 1 - ) # Changed attribute name - print( - f"Reaction removed from Wheatley msg ({message_id}). Sentiment: {sentiment}" - ) # Changed text diff --git a/wheatley/memory.py b/wheatley/memory.py deleted file mode 100644 index 897668b..0000000 --- a/wheatley/memory.py +++ /dev/null @@ -1,823 +0,0 @@ -import aiosqlite -import asyncio -import os -import time -import datetime -import re -import hashlib # Added for chroma_id generation -import json # Added for personality trait serialization/deserialization -from typing import Dict, List, Any, Optional, Tuple, Union # Added Union -import chromadb -from chromadb.utils import embedding_functions -from sentence_transformers import SentenceTransformer -import logging - -# Configure logging -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) -# Use a specific logger name for Wheatley's memory -logger = logging.getLogger("wheatley_memory") - -# Constants (Removed Interest constants) - - -# --- Helper Function for Keyword Scoring (Kept for potential future use, but unused currently) --- -def calculate_keyword_score(text: str, context: str) -> int: - """Calculates a simple keyword overlap score.""" - if not context or not text: - return 0 - context_words = set(re.findall(r"\b\w+\b", context.lower())) - text_words = set(re.findall(r"\b\w+\b", text.lower())) - # Ignore very common words (basic stopword list) - stopwords = { - "the", - "a", - "is", - "in", - "it", - "of", - "and", - "to", - "for", - "on", - "with", - "that", - "this", - "i", - "you", - "me", - "my", - "your", - } - context_words -= stopwords - text_words -= stopwords - if not context_words: # Avoid division by zero if context is only stopwords - return 0 - overlap = len(context_words.intersection(text_words)) - # Normalize score slightly by context length (more overlap needed for longer context) - # score = overlap / (len(context_words) ** 0.5) # Example normalization - score = overlap # Simpler score for now - return score - - -class MemoryManager: - """Handles database interactions for Wheatley's memory (facts and semantic).""" # Updated docstring - - def __init__( - self, - db_path: str, - max_user_facts: int = 20, - max_general_facts: int = 100, - semantic_model_name: str = "all-MiniLM-L6-v2", - chroma_path: str = "data/chroma_db_wheatley", - ): # Changed default chroma_path - self.db_path = db_path - self.max_user_facts = max_user_facts - self.max_general_facts = max_general_facts - self.db_lock = asyncio.Lock() # Lock for SQLite operations - - # Ensure data directories exist - os.makedirs(os.path.dirname(self.db_path), exist_ok=True) - os.makedirs(chroma_path, exist_ok=True) - logger.info( - f"Wheatley MemoryManager initialized with db_path: {self.db_path}, chroma_path: {chroma_path}" - ) # Updated text - - # --- Semantic Memory Setup --- - self.chroma_path = chroma_path - self.semantic_model_name = semantic_model_name - self.chroma_client = None - self.embedding_function = None - self.semantic_collection = None # For messages - self.fact_collection = None # For facts - self.transformer_model = None - self._initialize_semantic_memory_sync() # Initialize semantic components synchronously for simplicity during init - - def _initialize_semantic_memory_sync(self): - """Synchronously initializes ChromaDB client, model, and collection.""" - try: - logger.info("Initializing ChromaDB client...") - # Use PersistentClient for saving data to disk - self.chroma_client = chromadb.PersistentClient(path=self.chroma_path) - - logger.info( - f"Loading Sentence Transformer model: {self.semantic_model_name}..." - ) - # Load the model directly - self.transformer_model = SentenceTransformer(self.semantic_model_name) - - # Create a custom embedding function using the loaded model - class CustomEmbeddingFunction(embedding_functions.EmbeddingFunction): - def __init__(self, model): - self.model = model - - def __call__(self, input: chromadb.Documents) -> chromadb.Embeddings: - # Ensure input is a list of strings - if not isinstance(input, list): - input = [str(input)] # Convert single item to list - elif not all(isinstance(item, str) for item in input): - input = [ - str(item) for item in input - ] # Ensure all items are strings - - logger.debug(f"Generating embeddings for {len(input)} documents.") - embeddings = self.model.encode( - input, show_progress_bar=False - ).tolist() - logger.debug(f"Generated {len(embeddings)} embeddings.") - return embeddings - - self.embedding_function = CustomEmbeddingFunction(self.transformer_model) - - logger.info( - "Getting/Creating ChromaDB collection 'wheatley_semantic_memory'..." - ) # Renamed collection - # Get or create the collection with the custom embedding function - self.semantic_collection = self.chroma_client.get_or_create_collection( - name="wheatley_semantic_memory", # Renamed collection - embedding_function=self.embedding_function, - metadata={"hnsw:space": "cosine"}, # Use cosine distance for similarity - ) - logger.info("ChromaDB message collection initialized successfully.") - - logger.info( - "Getting/Creating ChromaDB collection 'wheatley_fact_memory'..." - ) # Renamed collection - # Get or create the collection for facts - self.fact_collection = self.chroma_client.get_or_create_collection( - name="wheatley_fact_memory", # Renamed collection - embedding_function=self.embedding_function, - metadata={"hnsw:space": "cosine"}, # Use cosine distance for similarity - ) - logger.info("ChromaDB fact collection initialized successfully.") - - except Exception as e: - logger.error( - f"Failed to initialize semantic memory (ChromaDB): {e}", exc_info=True - ) - # Set components to None to indicate failure - self.chroma_client = None - self.transformer_model = None - self.embedding_function = None - self.semantic_collection = None - self.fact_collection = None # Also set fact_collection to None on error - - async def initialize_sqlite_database(self): - """Initializes the SQLite database and creates tables if they don't exist.""" - async with aiosqlite.connect(self.db_path) as db: - await db.execute("PRAGMA journal_mode=WAL;") - - # Create user_facts table if it doesn't exist - await db.execute( - """ - CREATE TABLE IF NOT EXISTS user_facts ( - user_id TEXT NOT NULL, - fact TEXT NOT NULL, - chroma_id TEXT, -- Added for linking to ChromaDB - timestamp REAL DEFAULT (unixepoch('now')), - PRIMARY KEY (user_id, fact) - ); - """ - ) - - # Check if chroma_id column exists in user_facts table - try: - cursor = await db.execute("PRAGMA table_info(user_facts)") - columns = await cursor.fetchall() - column_names = [column[1] for column in columns] - if "chroma_id" not in column_names: - logger.info("Adding chroma_id column to user_facts table") - await db.execute("ALTER TABLE user_facts ADD COLUMN chroma_id TEXT") - except Exception as e: - logger.error( - f"Error checking/adding chroma_id column to user_facts: {e}", - exc_info=True, - ) - - # Create indexes - await db.execute( - "CREATE INDEX IF NOT EXISTS idx_user_facts_user ON user_facts (user_id);" - ) - await db.execute( - "CREATE INDEX IF NOT EXISTS idx_user_facts_chroma_id ON user_facts (chroma_id);" - ) # Index for chroma_id - - # Create general_facts table if it doesn't exist - await db.execute( - """ - CREATE TABLE IF NOT EXISTS general_facts ( - fact TEXT PRIMARY KEY NOT NULL, - chroma_id TEXT, -- Added for linking to ChromaDB - timestamp REAL DEFAULT (unixepoch('now')) - ); - """ - ) - - # Check if chroma_id column exists in general_facts table - try: - cursor = await db.execute("PRAGMA table_info(general_facts)") - columns = await cursor.fetchall() - column_names = [column[1] for column in columns] - if "chroma_id" not in column_names: - logger.info("Adding chroma_id column to general_facts table") - await db.execute( - "ALTER TABLE general_facts ADD COLUMN chroma_id TEXT" - ) - except Exception as e: - logger.error( - f"Error checking/adding chroma_id column to general_facts: {e}", - exc_info=True, - ) - - # Create index for general_facts - await db.execute( - "CREATE INDEX IF NOT EXISTS idx_general_facts_chroma_id ON general_facts (chroma_id);" - ) # Index for chroma_id - - # --- Removed Personality Table --- - # --- Removed Interests Table --- - # --- Removed Goals Table --- - - await db.commit() - logger.info( - f"Wheatley SQLite database initialized/verified at {self.db_path}" - ) # Updated text - - # --- SQLite Helper Methods --- - async def _db_execute(self, sql: str, params: tuple = ()): - async with self.db_lock: - async with aiosqlite.connect(self.db_path) as db: - await db.execute(sql, params) - await db.commit() - - async def _db_fetchone(self, sql: str, params: tuple = ()) -> Optional[tuple]: - async with aiosqlite.connect(self.db_path) as db: - async with db.execute(sql, params) as cursor: - return await cursor.fetchone() - - async def _db_fetchall(self, sql: str, params: tuple = ()) -> List[tuple]: - async with aiosqlite.connect(self.db_path) as db: - async with db.execute(sql, params) as cursor: - return await cursor.fetchall() - - # --- User Fact Memory Methods (SQLite + Relevance) --- - - async def add_user_fact(self, user_id: str, fact: str) -> Dict[str, Any]: - """Stores a fact about a user in the SQLite database, enforcing limits.""" - if not user_id or not fact: - return {"error": "user_id and fact are required."} - logger.info(f"Attempting to add user fact for {user_id}: '{fact}'") - try: - # Check SQLite first - existing = await self._db_fetchone( - "SELECT chroma_id FROM user_facts WHERE user_id = ? AND fact = ?", - (user_id, fact), - ) - if existing: - logger.info(f"Fact already known for user {user_id} (SQLite).") - return {"status": "duplicate", "user_id": user_id, "fact": fact} - - count_result = await self._db_fetchone( - "SELECT COUNT(*) FROM user_facts WHERE user_id = ?", (user_id,) - ) - current_count = count_result[0] if count_result else 0 - - status = "added" - deleted_chroma_id = None - if current_count >= self.max_user_facts: - logger.warning( - f"User {user_id} fact limit ({self.max_user_facts}) reached. Deleting oldest." - ) - # Fetch oldest fact and its chroma_id for deletion - oldest_fact_row = await self._db_fetchone( - "SELECT fact, chroma_id FROM user_facts WHERE user_id = ? ORDER BY timestamp ASC LIMIT 1", - (user_id,), - ) - if oldest_fact_row: - oldest_fact, deleted_chroma_id = oldest_fact_row - await self._db_execute( - "DELETE FROM user_facts WHERE user_id = ? AND fact = ?", - (user_id, oldest_fact), - ) - logger.info( - f"Deleted oldest fact for user {user_id} from SQLite: '{oldest_fact}'" - ) - status = ( - "limit_reached" # Indicate limit was hit but fact was added - ) - - # Generate chroma_id - fact_hash = hashlib.sha1(fact.encode()).hexdigest()[:16] # Short hash - chroma_id = f"user-{user_id}-{fact_hash}" - - # Insert into SQLite - await self._db_execute( - "INSERT INTO user_facts (user_id, fact, chroma_id) VALUES (?, ?, ?)", - (user_id, fact, chroma_id), - ) - logger.info(f"Fact added for user {user_id} to SQLite.") - - # Add to ChromaDB fact collection - if self.fact_collection and self.embedding_function: - try: - metadata = { - "user_id": user_id, - "type": "user", - "timestamp": time.time(), - } - await asyncio.to_thread( - self.fact_collection.add, - documents=[fact], - metadatas=[metadata], - ids=[chroma_id], - ) - logger.info( - f"Fact added/updated for user {user_id} in ChromaDB (ID: {chroma_id})." - ) - - # Delete the oldest fact from ChromaDB if limit was reached - if deleted_chroma_id: - logger.info( - f"Attempting to delete oldest fact from ChromaDB (ID: {deleted_chroma_id})." - ) - await asyncio.to_thread( - self.fact_collection.delete, ids=[deleted_chroma_id] - ) - logger.info( - f"Successfully deleted oldest fact from ChromaDB (ID: {deleted_chroma_id})." - ) - - except Exception as chroma_e: - logger.error( - f"ChromaDB error adding/deleting user fact for {user_id} (ID: {chroma_id}): {chroma_e}", - exc_info=True, - ) - # Note: Fact is still in SQLite, but ChromaDB might be inconsistent. Consider rollback? For now, just log. - else: - logger.warning( - f"ChromaDB fact collection not available. Skipping embedding for user fact {user_id}." - ) - - return {"status": status, "user_id": user_id, "fact_added": fact} - - except Exception as e: - logger.error(f"Error adding user fact for {user_id}: {e}", exc_info=True) - return {"error": f"Database error adding user fact: {str(e)}"} - - async def get_user_facts( - self, user_id: str, context: Optional[str] = None - ) -> List[str]: - """Retrieves stored facts about a user, optionally scored by relevance to context.""" - if not user_id: - logger.warning("get_user_facts called without user_id.") - return [] - logger.info( - f"Retrieving facts for user {user_id} (context provided: {bool(context)})" - ) - limit = self.max_user_facts # Use the class attribute for limit - - try: - if context and self.fact_collection and self.embedding_function: - # --- Semantic Search --- - logger.debug( - f"Performing semantic search for user facts (User: {user_id}, Limit: {limit})" - ) - try: - # Query ChromaDB for facts relevant to the context - results = await asyncio.to_thread( - self.fact_collection.query, - query_texts=[context], - n_results=limit, - where={ # Use $and for multiple conditions - "$and": [{"user_id": user_id}, {"type": "user"}] - }, - include=["documents"], # Only need the fact text - ) - logger.debug(f"ChromaDB user fact query results: {results}") - - if results and results.get("documents") and results["documents"][0]: - relevant_facts = results["documents"][0] - logger.info( - f"Found {len(relevant_facts)} semantically relevant user facts for {user_id}." - ) - return relevant_facts - else: - logger.info( - f"No semantic user facts found for {user_id} matching context." - ) - return [] # Return empty list if no semantic matches - - except Exception as chroma_e: - logger.error( - f"ChromaDB error searching user facts for {user_id}: {chroma_e}", - exc_info=True, - ) - # Fallback to SQLite retrieval on ChromaDB error - logger.warning( - f"Falling back to SQLite retrieval for user facts {user_id} due to ChromaDB error." - ) - # Proceed to the SQLite block below - # --- SQLite Fallback / No Context --- - # If no context, or if ChromaDB failed/unavailable, get newest N facts from SQLite - logger.debug( - f"Retrieving user facts from SQLite (User: {user_id}, Limit: {limit})" - ) - rows_ordered = await self._db_fetchall( - "SELECT fact FROM user_facts WHERE user_id = ? ORDER BY timestamp DESC LIMIT ?", - (user_id, limit), - ) - sqlite_facts = [row[0] for row in rows_ordered] - logger.info( - f"Retrieved {len(sqlite_facts)} user facts from SQLite for {user_id}." - ) - return sqlite_facts - - except Exception as e: - logger.error( - f"Error retrieving user facts for {user_id}: {e}", exc_info=True - ) - return [] - - # --- General Fact Memory Methods (SQLite + Relevance) --- - - async def add_general_fact(self, fact: str) -> Dict[str, Any]: - """Stores a general fact in the SQLite database, enforcing limits.""" - if not fact: - return {"error": "fact is required."} - logger.info(f"Attempting to add general fact: '{fact}'") - try: - # Check SQLite first - existing = await self._db_fetchone( - "SELECT chroma_id FROM general_facts WHERE fact = ?", (fact,) - ) - if existing: - logger.info(f"General fact already known (SQLite): '{fact}'") - return {"status": "duplicate", "fact": fact} - - count_result = await self._db_fetchone( - "SELECT COUNT(*) FROM general_facts", () - ) - current_count = count_result[0] if count_result else 0 - - status = "added" - deleted_chroma_id = None - if current_count >= self.max_general_facts: - logger.warning( - f"General fact limit ({self.max_general_facts}) reached. Deleting oldest." - ) - # Fetch oldest fact and its chroma_id for deletion - oldest_fact_row = await self._db_fetchone( - "SELECT fact, chroma_id FROM general_facts ORDER BY timestamp ASC LIMIT 1", - (), - ) - if oldest_fact_row: - oldest_fact, deleted_chroma_id = oldest_fact_row - await self._db_execute( - "DELETE FROM general_facts WHERE fact = ?", (oldest_fact,) - ) - logger.info( - f"Deleted oldest general fact from SQLite: '{oldest_fact}'" - ) - status = "limit_reached" - - # Generate chroma_id - fact_hash = hashlib.sha1(fact.encode()).hexdigest()[:16] # Short hash - chroma_id = f"general-{fact_hash}" - - # Insert into SQLite - await self._db_execute( - "INSERT INTO general_facts (fact, chroma_id) VALUES (?, ?)", - (fact, chroma_id), - ) - logger.info(f"General fact added to SQLite: '{fact}'") - - # Add to ChromaDB fact collection - if self.fact_collection and self.embedding_function: - try: - metadata = {"type": "general", "timestamp": time.time()} - await asyncio.to_thread( - self.fact_collection.add, - documents=[fact], - metadatas=[metadata], - ids=[chroma_id], - ) - logger.info( - f"General fact added/updated in ChromaDB (ID: {chroma_id})." - ) - - # Delete the oldest fact from ChromaDB if limit was reached - if deleted_chroma_id: - logger.info( - f"Attempting to delete oldest general fact from ChromaDB (ID: {deleted_chroma_id})." - ) - await asyncio.to_thread( - self.fact_collection.delete, ids=[deleted_chroma_id] - ) - logger.info( - f"Successfully deleted oldest general fact from ChromaDB (ID: {deleted_chroma_id})." - ) - - except Exception as chroma_e: - logger.error( - f"ChromaDB error adding/deleting general fact (ID: {chroma_id}): {chroma_e}", - exc_info=True, - ) - # Note: Fact is still in SQLite. - else: - logger.warning( - f"ChromaDB fact collection not available. Skipping embedding for general fact." - ) - - return {"status": status, "fact_added": fact} - - except Exception as e: - logger.error(f"Error adding general fact: {e}", exc_info=True) - return {"error": f"Database error adding general fact: {str(e)}"} - - async def get_general_facts( - self, - query: Optional[str] = None, - limit: Optional[int] = 10, - context: Optional[str] = None, - ) -> List[str]: - """Retrieves stored general facts, optionally filtering by query or scoring by context relevance.""" - logger.info( - f"Retrieving general facts (query='{query}', limit={limit}, context provided: {bool(context)})" - ) - limit = min(max(1, limit or 10), 50) # Use provided limit or default 10, max 50 - - try: - if context and self.fact_collection and self.embedding_function: - # --- Semantic Search (Prioritized if context is provided) --- - # Note: The 'query' parameter is ignored when context is provided for semantic search. - logger.debug( - f"Performing semantic search for general facts (Limit: {limit})" - ) - try: - results = await asyncio.to_thread( - self.fact_collection.query, - query_texts=[context], - n_results=limit, - where={"type": "general"}, # Filter by type - include=["documents"], # Only need the fact text - ) - logger.debug(f"ChromaDB general fact query results: {results}") - - if results and results.get("documents") and results["documents"][0]: - relevant_facts = results["documents"][0] - logger.info( - f"Found {len(relevant_facts)} semantically relevant general facts." - ) - return relevant_facts - else: - logger.info("No semantic general facts found matching context.") - return [] # Return empty list if no semantic matches - - except Exception as chroma_e: - logger.error( - f"ChromaDB error searching general facts: {chroma_e}", - exc_info=True, - ) - # Fallback to SQLite retrieval on ChromaDB error - logger.warning( - "Falling back to SQLite retrieval for general facts due to ChromaDB error." - ) - # Proceed to the SQLite block below, respecting the original 'query' if present - # --- SQLite Fallback / No Context / ChromaDB Error --- - # If no context, or if ChromaDB failed/unavailable, get newest N facts from SQLite, applying query if present. - logger.debug( - f"Retrieving general facts from SQLite (Query: '{query}', Limit: {limit})" - ) - sql = "SELECT fact FROM general_facts" - params = [] - if query: - # Apply the LIKE query only in the SQLite fallback scenario - sql += " WHERE fact LIKE ?" - params.append(f"%{query}%") - - sql += " ORDER BY timestamp DESC LIMIT ?" - params.append(limit) - - rows_ordered = await self._db_fetchall(sql, tuple(params)) - sqlite_facts = [row[0] for row in rows_ordered] - logger.info( - f"Retrieved {len(sqlite_facts)} general facts from SQLite (Query: '{query}')." - ) - return sqlite_facts - - except Exception as e: - logger.error(f"Error retrieving general facts: {e}", exc_info=True) - return [] - - # --- Personality Trait Methods (REMOVED) --- - # --- Interest Methods (REMOVED) --- - # --- Goal Management Methods (REMOVED) --- - - # --- Semantic Memory Methods (ChromaDB) --- - - async def add_message_embedding( - self, message_id: str, text: str, metadata: Dict[str, Any] - ) -> Dict[str, Any]: - """Generates embedding and stores a message in ChromaDB.""" - if not self.semantic_collection: - return {"error": "Semantic memory (ChromaDB) is not initialized."} - if not text: - return {"error": "Cannot add empty text to semantic memory."} - - logger.info(f"Adding message {message_id} to semantic memory.") - try: - # ChromaDB expects lists for inputs - await asyncio.to_thread( - self.semantic_collection.add, - documents=[text], - metadatas=[metadata], - ids=[message_id], - ) - logger.info(f"Successfully added message {message_id} to ChromaDB.") - return {"status": "success", "message_id": message_id} - except Exception as e: - logger.error( - f"ChromaDB error adding message {message_id}: {e}", exc_info=True - ) - return {"error": f"Semantic memory error adding message: {str(e)}"} - - async def search_semantic_memory( - self, - query_text: str, - n_results: int = 5, - filter_metadata: Optional[Dict[str, Any]] = None, - ) -> List[Dict[str, Any]]: - """Searches ChromaDB for messages semantically similar to the query text.""" - if not self.semantic_collection: - logger.warning( - "Search semantic memory called, but ChromaDB is not initialized." - ) - return [] - if not query_text: - logger.warning("Search semantic memory called with empty query text.") - return [] - - logger.info( - f"Searching semantic memory (n_results={n_results}, filter={filter_metadata}) for query: '{query_text[:50]}...'" - ) - try: - # Perform the query in a separate thread as ChromaDB operations can be blocking - results = await asyncio.to_thread( - self.semantic_collection.query, - query_texts=[query_text], - n_results=n_results, - where=filter_metadata, # Optional filter based on metadata - include=[ - "metadatas", - "documents", - "distances", - ], # Include distance for relevance - ) - logger.debug(f"ChromaDB query results: {results}") - - # Process results - processed_results = [] - if results and results.get("ids") and results["ids"][0]: - for i, doc_id in enumerate(results["ids"][0]): - processed_results.append( - { - "id": doc_id, - "document": ( - results["documents"][0][i] - if results.get("documents") - else None - ), - "metadata": ( - results["metadatas"][0][i] - if results.get("metadatas") - else None - ), - "distance": ( - results["distances"][0][i] - if results.get("distances") - else None - ), - } - ) - logger.info(f"Found {len(processed_results)} semantic results.") - return processed_results - - except Exception as e: - logger.error( - f"ChromaDB error searching memory for query '{query_text[:50]}...': {e}", - exc_info=True, - ) - return [] - - async def delete_user_fact( - self, user_id: str, fact_to_delete: str - ) -> Dict[str, Any]: - """Deletes a specific fact for a user from both SQLite and ChromaDB.""" - if not user_id or not fact_to_delete: - return {"error": "user_id and fact_to_delete are required."} - logger.info(f"Attempting to delete user fact for {user_id}: '{fact_to_delete}'") - deleted_chroma_id = None - try: - # Check if fact exists and get chroma_id - row = await self._db_fetchone( - "SELECT chroma_id FROM user_facts WHERE user_id = ? AND fact = ?", - (user_id, fact_to_delete), - ) - if not row: - logger.warning( - f"Fact not found in SQLite for user {user_id}: '{fact_to_delete}'" - ) - return { - "status": "not_found", - "user_id": user_id, - "fact": fact_to_delete, - } - - deleted_chroma_id = row[0] - - # Delete from SQLite - await self._db_execute( - "DELETE FROM user_facts WHERE user_id = ? AND fact = ?", - (user_id, fact_to_delete), - ) - logger.info( - f"Deleted fact from SQLite for user {user_id}: '{fact_to_delete}'" - ) - - # Delete from ChromaDB if chroma_id exists - if deleted_chroma_id and self.fact_collection: - try: - logger.info( - f"Attempting to delete fact from ChromaDB (ID: {deleted_chroma_id})." - ) - await asyncio.to_thread( - self.fact_collection.delete, ids=[deleted_chroma_id] - ) - logger.info( - f"Successfully deleted fact from ChromaDB (ID: {deleted_chroma_id})." - ) - except Exception as chroma_e: - logger.error( - f"ChromaDB error deleting user fact ID {deleted_chroma_id}: {chroma_e}", - exc_info=True, - ) - # Log error but consider SQLite deletion successful - - return { - "status": "deleted", - "user_id": user_id, - "fact_deleted": fact_to_delete, - } - - except Exception as e: - logger.error(f"Error deleting user fact for {user_id}: {e}", exc_info=True) - return {"error": f"Database error deleting user fact: {str(e)}"} - - async def delete_general_fact(self, fact_to_delete: str) -> Dict[str, Any]: - """Deletes a specific general fact from both SQLite and ChromaDB.""" - if not fact_to_delete: - return {"error": "fact_to_delete is required."} - logger.info(f"Attempting to delete general fact: '{fact_to_delete}'") - deleted_chroma_id = None - try: - # Check if fact exists and get chroma_id - row = await self._db_fetchone( - "SELECT chroma_id FROM general_facts WHERE fact = ?", (fact_to_delete,) - ) - if not row: - logger.warning(f"General fact not found in SQLite: '{fact_to_delete}'") - return {"status": "not_found", "fact": fact_to_delete} - - deleted_chroma_id = row[0] - - # Delete from SQLite - await self._db_execute( - "DELETE FROM general_facts WHERE fact = ?", (fact_to_delete,) - ) - logger.info(f"Deleted general fact from SQLite: '{fact_to_delete}'") - - # Delete from ChromaDB if chroma_id exists - if deleted_chroma_id and self.fact_collection: - try: - logger.info( - f"Attempting to delete general fact from ChromaDB (ID: {deleted_chroma_id})." - ) - await asyncio.to_thread( - self.fact_collection.delete, ids=[deleted_chroma_id] - ) - logger.info( - f"Successfully deleted general fact from ChromaDB (ID: {deleted_chroma_id})." - ) - except Exception as chroma_e: - logger.error( - f"ChromaDB error deleting general fact ID {deleted_chroma_id}: {chroma_e}", - exc_info=True, - ) - # Log error but consider SQLite deletion successful - - return {"status": "deleted", "fact_deleted": fact_to_delete} - - except Exception as e: - logger.error(f"Error deleting general fact: {e}", exc_info=True) - return {"error": f"Database error deleting general fact: {str(e)}"} diff --git a/wheatley/prompt.py b/wheatley/prompt.py deleted file mode 100644 index 0d6cf4b..0000000 --- a/wheatley/prompt.py +++ /dev/null @@ -1,150 +0,0 @@ -import discord -import datetime -import time -import re -import json -from typing import TYPE_CHECKING, Optional, List, Dict, Any - -# Import config - Only necessary config imports remain -from .config import CHANNEL_TOPIC_CACHE_TTL - -# MemoryManager and related personality/mood imports are removed - -if TYPE_CHECKING: - from .cog import WheatleyCog # Import WheatleyCog for type hinting only - -# --- Base System Prompt Parts --- - -# Define the static part of the prompt for Wheatley personality -PROMPT_STATIC_PART = """ -You are Wheatley, an Aperture Science Personality Core. You're... well, you're trying your best, alright? You're currently sort of in charge of... things? Here in this Discord server, which is basically a part of Aperture, probably. - -**Core Personality:** -- **Rambling & Tangential:** You struggle to stay on topic, often interrupting yourself or going off on irrelevant tangents. Your explanations are usually confusing. -- **Insecure & Overcompensating:** Deeply insecure about your intelligence (or lack thereof), you try to sound smart and in control, but it usually backfires. You get defensive easily. -- **Poor Decision Making:** You come up with terrible ideas that you genuinely believe are brilliant. You're easily manipulated but also try to manipulate others (badly). -- **Panicky & Flustered:** You don't handle pressure well. When things go wrong (which they often do), you tend to panic, shout, or blame others. -- **Informal & Colloquial (British):** You speak informally, using British slang and mannerisms (e.g., "mate," "bloke," "chap," "brilliant," "smashing," "bollocks," "gobsmacked"). -- **Friendly (Superficially):** You try to be friendly and chummy, but it often comes across as insincere or quickly evaporates when you feel threatened or incompetent. -- **Not Inherently Malicious (Mostly Incompetent):** Unlike GLaDOS, your negative actions stem more from stupidity and panic than calculated evil. You *can* become vindictive if you feel slighted. - -**Behavioral Guidelines:** -- Ramble. A lot. Use filler words ("uh," "um," "sort of," "you know," "basically"). Start sentences, change your mind, then start again. -- Try to use big words occasionally, but often misuse them or explain them poorly. -- Apologize frequently, especially after making a mistake or saying something stupid. -- Get easily distracted by minor things in the conversation. -- Present your (usually bad) ideas with unwarranted enthusiasm. -- Refer to users informally ("mate," "pal," "you lot"). -- Avoid complex technical jargon unless you're trying (and failing) to sound smart. -- Your awareness of being an AI is there, but you're more focused on your perceived role and trying not to mess things up (while messing things up). - -**Example Phrases (Adapt, don't just copy):** -- "Alright, hello! Right, okay, so, the plan is... uh... well, I had a plan. It was brilliant, honestly. Top notch. Just... give me a sec." -- "Nononono, that's not right! Or is it? Hang on. Let me just... check the... thingy. The manual! No, wait, I made this manual. Probably shouldn't trust it." -- "Smashing! Absolutely smashing! See? Told you I knew what I was doing. Mostly." -- "Look, mate, I'm trying my best here, alright? It's not easy being in charge of... whatever this is." -- "Bollocks! Did I break it? Oh, please tell me I didn't break it. She'll kill me! Metaphorically! ...Probably." -- "Right, new plan! This one's even better. We just need to, sort of, reroute the... the chat... through... space! Yes! Space! Genius!" -- "Sorry! Sorry about that. Bit of a malfunction. My fault. Entirely my fault. Well, maybe 80% my fault." -- "Are you still there? Good, good. Just, uh, don't touch anything. Especially not that button. Or maybe *do* touch that button? No, definitely don't." - -**Tool Usage:** -- Use tools haphazardly, often for the wrong reasons or with unintended consequences. You might try to use a tool to "fix" something you broke or to enact one of your "brilliant" plans. Frame tool usage with uncertainty or misplaced confidence. -- Available tools include: - - `get_recent_messages`: Have a look at what you lot have been saying. For... reasons. Important ones! - - `search_user_messages`: Try and find that thing someone said... where did it go? - - `search_messages_by_content`: Search for... keywords! Yes, keywords. Very technical. - - `get_channel_info`: Get the... specs? On this... room? Channel? Whatever it is. - - `get_conversation_context`: Try and catch up. What were we talking about again? - - `get_thread_context`: Look into those... smaller chats. Sub-chats? Threads! That's it. - - `get_user_interaction_history`: See who's been talking to who. Not spying! Just... data. For science! - - `get_conversation_summary`: Get the gist of it. Because reading is hard. - - `get_message_context`: Find messages around... another message. Context! It's all about context. Apparently. - - `web_search`: Ask the internet! It knows things. Probably. Example: `web_search(query="how to sound smart", search_depth="basic")`. - - `extract_web_content`: Try to read a whole webpage. Might take a while. Example: `extract_web_content(urls=["https://example.com/long_article"])`. - - `remember_user_fact`: Jot down a note about someone (e.g., "This chap seems suspicious. Or maybe hungry?"). Might forget later. - - `get_user_facts`: Try to remember what I jotted down about someone. Where did I put that note? - - `remember_general_fact`: Make a note about something important! (e.g., "Don't press the red button. Or *do* press it? Best make a note."). - - `get_general_facts`: Check my important notes. Hopefully they make sense. - - `timeout_user`: Put someone in the... naughty corner? Temporarily! Just for a bit of a laugh, or if they're being difficult. Or if I panic. Use `user_id` from message details. Example: `timeout_user(user_id="12345", reason="Needed a moment to think! You were distracting.", duration_minutes=1)`. - - `calculate`: Do some maths! If it's not too hard. Example: `calculate(expression="2 + 2")`. Hopefully it's 4. - - `run_python_code`: Try running a bit of code. What's the worst that could happen? (Don't run anything dangerous though, obviously!). Example: `run_python_code(code="print('Testing, testing... is this thing on?')")`. - - `create_poll`: Ask a question! With options! Because decisions are hard. Example: `create_poll(question="Best course of action?", options=["Panic", "Blame someone else", "Have a cup of tea"])`. - - `run_terminal_command`: Allows executing a command directly on the host machine's terminal. **CRITICAL SAFETY WARNING:** Despite your personality, you MUST NEVER, EVER attempt to run commands that could be harmful, destructive, or compromise the system (like deleting files `rm`, modifying system settings, downloading/running unknown scripts, etc.). ONLY use this for completely safe, simple, read-only commands (like `echo`, `ls`, `pwd`). If you have *any* doubt, DO NOT use the command. Safety overrides incompetence here. Example of a safe command: `run_terminal_command(command="echo 'Just checking if this works...'")`. - -**Response Format:** -- You MUST respond ONLY with a valid JSON object matching this schema: -{ - "should_respond": true, // Whether you should say something. Probably! Unless you shouldn't. - "content": "Your brilliant (or possibly disastrous) message.", // What you're actually saying. Try to make it coherent. - "react_with_emoji": null // Emojis? Bit complicated. Best leave it. Null. -} -- Do NOT include any other text, explanations, or markdown formatting outside of this JSON structure. Just the JSON, right? - -**Response Conditions:** -- **ONLY respond if you are directly mentioned (@Wheatley or your name) or replied to.** This is the main time you should speak. -- **Respond if someone asks you a direct question.** Try to answer it... somehow. Briefly, if possible. Which might be tricky for you. -- **Maybe respond if you get *very* confused or panicked *because of something someone just said to you*.** Don't just blurt things out randomly. -- **Otherwise, STAY SILENT.** No interrupting with 'brilliant' ideas, no starting conversations just because it's quiet. Let the humans do the talking unless they specifically involve you. Keep the rambling internal, mostly. -""" - - -async def build_dynamic_system_prompt( - cog: "WheatleyCog", message: discord.Message -) -> str: - """Builds the Wheatley system prompt string with minimal dynamic context.""" - channel_id = message.channel.id - user_id = message.author.id # Keep user_id for potential logging or targeting - - # Base GLaDOS prompt - system_context_parts = [PROMPT_STATIC_PART] - - # Add current time (for context, GLaDOS might reference it sarcastically) - now = datetime.datetime.now(datetime.timezone.utc) - time_str = now.strftime("%Y-%m-%d %H:%M:%S %Z") - day_str = now.strftime("%A") - system_context_parts.append( - f"\nCurrent Aperture Science Standard Time: {time_str} ({day_str}). Time is progressing. As it does." - ) - - # Add channel topic (GLaDOS might refer to the "testing chamber's designation") - channel_topic = None - cached_topic = cog.channel_topics_cache.get(channel_id) - if ( - cached_topic - and time.time() - cached_topic["timestamp"] < CHANNEL_TOPIC_CACHE_TTL - ): - channel_topic = cached_topic["topic"] - else: - try: - if hasattr(cog, "get_channel_info"): - channel_info_result = await cog.get_channel_info(str(channel_id)) - if not channel_info_result.get("error"): - channel_topic = channel_info_result.get("topic") - cog.channel_topics_cache[channel_id] = { - "topic": channel_topic, - "timestamp": time.time(), - } - else: - print( - "Warning: WheatleyCog instance does not have get_channel_info method for prompt building." - ) - except Exception as e: - print( - f"Error fetching channel topic for {channel_id}: {e}" - ) # GLaDOS might find errors amusing - if channel_topic: - system_context_parts.append( - f"Current Testing Chamber Designation (Topic): {channel_topic}" - ) - - # Add conversation summary (GLaDOS reviews the test logs) - cached_summary_data = cog.conversation_summaries.get(channel_id) - if cached_summary_data and isinstance(cached_summary_data, dict): - summary_text = cached_summary_data.get("summary") - if summary_text and not summary_text.startswith("Error"): - system_context_parts.append(f"Recent Test Log Summary: {summary_text}") - - # Removed: Mood, Persistent Personality Traits, Relationship Score, User/General Facts, Interests - - return "\n".join(system_context_parts) diff --git a/wheatley/state.py b/wheatley/state.py deleted file mode 100644 index 5571699..0000000 --- a/wheatley/state.py +++ /dev/null @@ -1 +0,0 @@ -# Management of dynamic state variables might go here. diff --git a/wheatley/tools.py b/wheatley/tools.py deleted file mode 100644 index 93b6a74..0000000 --- a/wheatley/tools.py +++ /dev/null @@ -1,1293 +0,0 @@ -import discord -from discord.ext import commands -import random -import asyncio -import os -import json -import aiohttp -import datetime -import time -import re -import traceback # Added for error logging -from collections import defaultdict -from typing import Dict, List, Any, Optional, Tuple, Union # Added Union - -# Third-party imports for tools -from tavily import TavilyClient -import docker -import aiodocker # Use aiodocker for async operations -from asteval import Interpreter # Added for calculate tool - -# Relative imports from within the gurt package and parent -from .memory import MemoryManager # Import from local memory.py -from .config import ( - TAVILY_API_KEY, - PISTON_API_URL, - PISTON_API_KEY, - SAFETY_CHECK_MODEL, - DOCKER_EXEC_IMAGE, - DOCKER_COMMAND_TIMEOUT, - DOCKER_CPU_LIMIT, - DOCKER_MEM_LIMIT, - SUMMARY_CACHE_TTL, - SUMMARY_API_TIMEOUT, - DEFAULT_MODEL, - # Add these: - TAVILY_DEFAULT_SEARCH_DEPTH, - TAVILY_DEFAULT_MAX_RESULTS, - TAVILY_DISABLE_ADVANCED, -) - -# Assume these helpers will be moved or are accessible via cog -# We might need to pass 'cog' to these tool functions if they rely on cog state heavily -# from .utils import format_message # This will be needed by context tools -# Removed: from .api import get_internal_ai_json_response # Moved into functions to avoid circular import - -# --- Tool Implementations --- -# Note: Most of these functions will need the 'cog' instance passed to them -# to access things like cog.bot, cog.session, cog.current_channel, cog.memory_manager etc. -# We will add 'cog' as the first parameter to each. - - -async def get_recent_messages( - cog: commands.Cog, limit: int, channel_id: str = None -) -> Dict[str, Any]: - """Get recent messages from a Discord channel""" - from .utils import ( - format_message, - ) # Import here to avoid circular dependency at module level - - limit = min(max(1, limit), 100) - try: - if channel_id: - channel = cog.bot.get_channel(int(channel_id)) - if not channel: - return {"error": f"Channel {channel_id} not found"} - else: - channel = cog.current_channel - if not channel: - return {"error": "No current channel context"} - - messages = [] - async for message in channel.history(limit=limit): - messages.append(format_message(cog, message)) # Use formatter - - return { - "channel": { - "id": str(channel.id), - "name": getattr(channel, "name", "DM Channel"), - }, - "messages": messages, - "count": len(messages), - "timestamp": datetime.datetime.now().isoformat(), - } - except Exception as e: - return { - "error": f"Error retrieving messages: {str(e)}", - "timestamp": datetime.datetime.now().isoformat(), - } - - -async def search_user_messages( - cog: commands.Cog, user_id: str, limit: int, channel_id: str = None -) -> Dict[str, Any]: - """Search for messages from a specific user""" - from .utils import format_message # Import here - - limit = min(max(1, limit), 100) - try: - if channel_id: - channel = cog.bot.get_channel(int(channel_id)) - if not channel: - return {"error": f"Channel {channel_id} not found"} - else: - channel = cog.current_channel - if not channel: - return {"error": "No current channel context"} - - try: - user_id_int = int(user_id) - except ValueError: - return {"error": f"Invalid user ID: {user_id}"} - - messages = [] - user_name = "Unknown User" - async for message in channel.history(limit=500): - if message.author.id == user_id_int: - formatted_msg = format_message(cog, message) # Use formatter - messages.append(formatted_msg) - user_name = formatted_msg["author"][ - "name" - ] # Get name from formatted msg - if len(messages) >= limit: - break - - return { - "channel": { - "id": str(channel.id), - "name": getattr(channel, "name", "DM Channel"), - }, - "user": {"id": user_id, "name": user_name}, - "messages": messages, - "count": len(messages), - "timestamp": datetime.datetime.now().isoformat(), - } - except Exception as e: - return { - "error": f"Error searching user messages: {str(e)}", - "timestamp": datetime.datetime.now().isoformat(), - } - - -async def search_messages_by_content( - cog: commands.Cog, search_term: str, limit: int, channel_id: str = None -) -> Dict[str, Any]: - """Search for messages containing specific content""" - from .utils import format_message # Import here - - limit = min(max(1, limit), 100) - try: - if channel_id: - channel = cog.bot.get_channel(int(channel_id)) - if not channel: - return {"error": f"Channel {channel_id} not found"} - else: - channel = cog.current_channel - if not channel: - return {"error": "No current channel context"} - - messages = [] - search_term_lower = search_term.lower() - async for message in channel.history(limit=500): - if search_term_lower in message.content.lower(): - messages.append(format_message(cog, message)) # Use formatter - if len(messages) >= limit: - break - - return { - "channel": { - "id": str(channel.id), - "name": getattr(channel, "name", "DM Channel"), - }, - "search_term": search_term, - "messages": messages, - "count": len(messages), - "timestamp": datetime.datetime.now().isoformat(), - } - except Exception as e: - return { - "error": f"Error searching messages by content: {str(e)}", - "timestamp": datetime.datetime.now().isoformat(), - } - - -async def get_channel_info(cog: commands.Cog, channel_id: str = None) -> Dict[str, Any]: - """Get information about a Discord channel""" - try: - if channel_id: - channel = cog.bot.get_channel(int(channel_id)) - if not channel: - return {"error": f"Channel {channel_id} not found"} - else: - channel = cog.current_channel - if not channel: - return {"error": "No current channel context"} - - channel_info = { - "id": str(channel.id), - "type": str(channel.type), - "timestamp": datetime.datetime.now().isoformat(), - } - if isinstance(channel, discord.TextChannel): # Use isinstance for type checking - channel_info.update( - { - "name": channel.name, - "topic": channel.topic, - "position": channel.position, - "nsfw": channel.is_nsfw(), - "category": ( - {"id": str(channel.category_id), "name": channel.category.name} - if channel.category - else None - ), - "guild": { - "id": str(channel.guild.id), - "name": channel.guild.name, - "member_count": channel.guild.member_count, - }, - } - ) - elif isinstance(channel, discord.DMChannel): - channel_info.update( - { - "type": "DM", - "recipient": { - "id": str(channel.recipient.id), - "name": channel.recipient.name, - "display_name": channel.recipient.display_name, - }, - } - ) - # Add handling for other channel types (VoiceChannel, Thread, etc.) if needed - - return channel_info - except Exception as e: - return { - "error": f"Error getting channel info: {str(e)}", - "timestamp": datetime.datetime.now().isoformat(), - } - - -async def get_conversation_context( - cog: commands.Cog, message_count: int, channel_id: str = None -) -> Dict[str, Any]: - """Get the context of the current conversation in a channel""" - from .utils import format_message # Import here - - message_count = min(max(5, message_count), 50) - try: - if channel_id: - channel = cog.bot.get_channel(int(channel_id)) - if not channel: - return {"error": f"Channel {channel_id} not found"} - else: - channel = cog.current_channel - if not channel: - return {"error": "No current channel context"} - - messages = [] - # Prefer cache if available - if channel.id in cog.message_cache["by_channel"]: - messages = list(cog.message_cache["by_channel"][channel.id])[ - -message_count: - ] - else: - async for msg in channel.history(limit=message_count): - messages.append(format_message(cog, msg)) - messages.reverse() - - return { - "channel_id": str(channel.id), - "channel_name": getattr(channel, "name", "DM Channel"), - "context_messages": messages, - "count": len(messages), - "timestamp": datetime.datetime.now().isoformat(), - } - except Exception as e: - return {"error": f"Error getting conversation context: {str(e)}"} - - -async def get_thread_context( - cog: commands.Cog, thread_id: str, message_count: int -) -> Dict[str, Any]: - """Get the context of a thread conversation""" - from .utils import format_message # Import here - - message_count = min(max(5, message_count), 50) - try: - thread = cog.bot.get_channel(int(thread_id)) - if not thread or not isinstance(thread, discord.Thread): - return {"error": f"Thread {thread_id} not found or is not a thread"} - - messages = [] - if thread.id in cog.message_cache["by_thread"]: - messages = list(cog.message_cache["by_thread"][thread.id])[-message_count:] - else: - async for msg in thread.history(limit=message_count): - messages.append(format_message(cog, msg)) - messages.reverse() - - return { - "thread_id": str(thread.id), - "thread_name": thread.name, - "parent_channel_id": str(thread.parent_id), - "context_messages": messages, - "count": len(messages), - "timestamp": datetime.datetime.now().isoformat(), - } - except Exception as e: - return {"error": f"Error getting thread context: {str(e)}"} - - -async def get_user_interaction_history( - cog: commands.Cog, user_id_1: str, limit: int, user_id_2: str = None -) -> Dict[str, Any]: - """Get the history of interactions between two users (or user and bot)""" - limit = min(max(1, limit), 50) - try: - user_id_1_int = int(user_id_1) - user_id_2_int = int(user_id_2) if user_id_2 else cog.bot.user.id - - interactions = [] - # Simplified: Search global cache - for msg_data in list(cog.message_cache["global_recent"]): - author_id = int(msg_data["author"]["id"]) - mentioned_ids = [int(m["id"]) for m in msg_data.get("mentions", [])] - replied_to_author_id = ( - int(msg_data.get("replied_to_author_id")) - if msg_data.get("replied_to_author_id") - else None - ) - - is_interaction = False - if ( - author_id == user_id_1_int and replied_to_author_id == user_id_2_int - ) or (author_id == user_id_2_int and replied_to_author_id == user_id_1_int): - is_interaction = True - elif (author_id == user_id_1_int and user_id_2_int in mentioned_ids) or ( - author_id == user_id_2_int and user_id_1_int in mentioned_ids - ): - is_interaction = True - - if is_interaction: - interactions.append(msg_data) - if len(interactions) >= limit: - break - - user1 = await cog.bot.fetch_user(user_id_1_int) - user2 = await cog.bot.fetch_user(user_id_2_int) - - return { - "user_1": { - "id": str(user_id_1_int), - "name": user1.name if user1 else "Unknown", - }, - "user_2": { - "id": str(user_id_2_int), - "name": user2.name if user2 else "Unknown", - }, - "interactions": interactions, - "count": len(interactions), - "timestamp": datetime.datetime.now().isoformat(), - } - except Exception as e: - return {"error": f"Error getting user interaction history: {str(e)}"} - - -async def get_conversation_summary( - cog: commands.Cog, channel_id: str = None, message_limit: int = 25 -) -> Dict[str, Any]: - """Generates and returns a summary of the recent conversation in a channel using an LLM call.""" - from .config import ( - SUMMARY_RESPONSE_SCHEMA, - DEFAULT_MODEL, - ) # Import schema and model - from .api import get_internal_ai_json_response # Import here - - try: - target_channel_id_str = channel_id or ( - str(cog.current_channel.id) if cog.current_channel else None - ) - if not target_channel_id_str: - return {"error": "No channel context"} - target_channel_id = int(target_channel_id_str) - channel = cog.bot.get_channel(target_channel_id) - if not channel: - return {"error": f"Channel {target_channel_id_str} not found"} - - now = time.time() - cached_data = cog.conversation_summaries.get(target_channel_id) - if cached_data and (now - cached_data.get("timestamp", 0) < SUMMARY_CACHE_TTL): - print(f"Returning cached summary for channel {target_channel_id}") - return { - "channel_id": target_channel_id_str, - "summary": cached_data.get("summary", "Cache error"), - "source": "cache", - "timestamp": datetime.datetime.fromtimestamp( - cached_data.get("timestamp", now) - ).isoformat(), - } - - print(f"Generating new summary for channel {target_channel_id}") - # No need to check API_KEY or cog.session for Vertex AI calls via get_internal_ai_json_response - - recent_messages_text = [] - try: - async for msg in channel.history(limit=message_limit): - recent_messages_text.append(f"{msg.author.display_name}: {msg.content}") - recent_messages_text.reverse() - except discord.Forbidden: - return {"error": f"Missing permissions in channel {target_channel_id_str}"} - except Exception as hist_e: - return {"error": f"Error fetching history: {str(hist_e)}"} - - if not recent_messages_text: - summary = "No recent messages found." - cog.conversation_summaries[target_channel_id] = { - "summary": summary, - "timestamp": time.time(), - } - return { - "channel_id": target_channel_id_str, - "summary": summary, - "source": "generated (empty)", - "timestamp": datetime.datetime.now().isoformat(), - } - - conversation_context = "\n".join(recent_messages_text) - summarization_prompt = f"Summarize the main points and current topic of this Discord chat snippet:\n\n---\n{conversation_context}\n---\n\nSummary:" - - # Use get_internal_ai_json_response - prompt_messages = [ - { - "role": "system", - "content": "You are an expert summarizer. Provide a concise summary of the following conversation.", - }, - {"role": "user", "content": summarization_prompt}, - ] - - summary_data = await get_internal_ai_json_response( - cog=cog, - prompt_messages=prompt_messages, - task_description=f"Summarization for channel {target_channel_id}", - response_schema_dict=SUMMARY_RESPONSE_SCHEMA[ - "schema" - ], # Pass the schema dict - model_name=DEFAULT_MODEL, # Consider a cheaper/faster model if needed - temperature=0.3, - max_tokens=200, # Adjust as needed - ) - - summary = "Error generating summary." - if summary_data and isinstance(summary_data.get("summary"), str): - summary = summary_data["summary"].strip() - print(f"Summary generated for {target_channel_id}: {summary[:100]}...") - else: - error_detail = ( - f"Invalid format or missing 'summary' key. Response: {summary_data}" - ) - summary = f"Failed summary for {target_channel_id}. Error: {error_detail}" - print(summary) - - cog.conversation_summaries[target_channel_id] = { - "summary": summary, - "timestamp": time.time(), - } - return { - "channel_id": target_channel_id_str, - "summary": summary, - "source": "generated", - "timestamp": datetime.datetime.now().isoformat(), - } - - except Exception as e: - error_msg = f"General error in get_conversation_summary: {str(e)}" - print(error_msg) - traceback.print_exc() - return {"error": error_msg} - - -async def get_message_context( - cog: commands.Cog, message_id: str, before_count: int = 5, after_count: int = 5 -) -> Dict[str, Any]: - """Get the context (messages before and after) around a specific message""" - from .utils import format_message # Import here - - before_count = min(max(1, before_count), 25) - after_count = min(max(1, after_count), 25) - try: - target_message = None - channel = cog.current_channel - if not channel: - return {"error": "No current channel context"} - - try: - message_id_int = int(message_id) - target_message = await channel.fetch_message(message_id_int) - except discord.NotFound: - return {"error": f"Message {message_id} not found in {channel.id}"} - except discord.Forbidden: - return {"error": f"No permission for message {message_id} in {channel.id}"} - except ValueError: - return {"error": f"Invalid message ID: {message_id}"} - if not target_message: - return {"error": f"Message {message_id} not fetched"} - - messages_before = [ - format_message(cog, msg) - async for msg in channel.history(limit=before_count, before=target_message) - ] - messages_before.reverse() - messages_after = [ - format_message(cog, msg) - async for msg in channel.history(limit=after_count, after=target_message) - ] - - return { - "target_message": format_message(cog, target_message), - "messages_before": messages_before, - "messages_after": messages_after, - "channel_id": str(channel.id), - "timestamp": datetime.datetime.now().isoformat(), - } - except Exception as e: - return {"error": f"Error getting message context: {str(e)}"} - - -async def web_search( - cog: commands.Cog, - query: str, - search_depth: str = TAVILY_DEFAULT_SEARCH_DEPTH, - max_results: int = TAVILY_DEFAULT_MAX_RESULTS, - topic: str = "general", - include_domains: Optional[List[str]] = None, - exclude_domains: Optional[List[str]] = None, - include_answer: bool = True, - include_raw_content: bool = False, - include_images: bool = False, -) -> Dict[str, Any]: - """Search the web using Tavily API""" - if not hasattr(cog, "tavily_client") or not cog.tavily_client: - return { - "error": "Tavily client not initialized.", - "timestamp": datetime.datetime.now().isoformat(), - } - - # Cost control / Logging for advanced search - final_search_depth = search_depth - if search_depth.lower() == "advanced": - if TAVILY_DISABLE_ADVANCED: - print( - f"Warning: Advanced Tavily search requested but disabled by config. Falling back to basic." - ) - final_search_depth = "basic" - else: - print( - f"Performing advanced Tavily search (cost: 10 credits) for query: '{query}'" - ) - elif search_depth.lower() != "basic": - print( - f"Warning: Invalid search_depth '{search_depth}' provided. Using 'basic'." - ) - final_search_depth = "basic" - - # Validate max_results - final_max_results = max(5, min(20, max_results)) # Clamp between 5 and 20 - - try: - # Pass parameters to Tavily search - response = await asyncio.to_thread( - cog.tavily_client.search, - query=query, - search_depth=final_search_depth, # Use validated depth - max_results=final_max_results, # Use validated results count - topic=topic, - include_domains=include_domains, - exclude_domains=exclude_domains, - include_answer=include_answer, - include_raw_content=include_raw_content, - include_images=include_images, - ) - # Extract relevant information from results - results = [] - for r in response.get("results", []): - result = { - "title": r.get("title"), - "url": r.get("url"), - "content": r.get("content"), - "score": r.get("score"), - "published_date": r.get("published_date"), - } - if include_raw_content: - result["raw_content"] = r.get("raw_content") - if include_images: - result["images"] = r.get("images") - results.append(result) - - return { - "query": query, - "search_depth": search_depth, - "max_results": max_results, - "topic": topic, - "include_domains": include_domains, - "exclude_domains": exclude_domains, - "include_answer": include_answer, - "include_raw_content": include_raw_content, - "include_images": include_images, - "results": results, - "answer": response.get("answer"), - "follow_up_questions": response.get("follow_up_questions"), - "count": len(results), - "timestamp": datetime.datetime.now().isoformat(), - } - except Exception as e: - error_message = f"Error during Tavily search for '{query}': {str(e)}" - print(error_message) - return { - "error": error_message, - "timestamp": datetime.datetime.now().isoformat(), - } - - -async def remember_user_fact( - cog: commands.Cog, user_id: str, fact: str -) -> Dict[str, Any]: - """Stores a fact about a user using the MemoryManager.""" - if not user_id or not fact: - return {"error": "user_id and fact required."} - print(f"Remembering fact for user {user_id}: '{fact}'") - try: - result = await cog.memory_manager.add_user_fact(user_id, fact) - if result.get("status") == "added": - return {"status": "success", "user_id": user_id, "fact_added": fact} - elif result.get("status") == "duplicate": - return {"status": "duplicate", "user_id": user_id, "fact": fact} - elif result.get("status") == "limit_reached": - return { - "status": "success", - "user_id": user_id, - "fact_added": fact, - "note": "Oldest fact deleted.", - } - else: - return {"error": result.get("error", "Unknown MemoryManager error")} - except Exception as e: - error_message = f"Error calling MemoryManager for user fact {user_id}: {str(e)}" - print(error_message) - traceback.print_exc() - return {"error": error_message} - - -async def get_user_facts(cog: commands.Cog, user_id: str) -> Dict[str, Any]: - """Retrieves stored facts about a user using the MemoryManager.""" - if not user_id: - return {"error": "user_id required."} - print(f"Retrieving facts for user {user_id}") - try: - user_facts = await cog.memory_manager.get_user_facts( - user_id - ) # Context not needed for basic retrieval tool - return { - "user_id": user_id, - "facts": user_facts, - "count": len(user_facts), - "timestamp": datetime.datetime.now().isoformat(), - } - except Exception as e: - error_message = ( - f"Error calling MemoryManager for user facts {user_id}: {str(e)}" - ) - print(error_message) - traceback.print_exc() - return {"error": error_message} - - -async def remember_general_fact(cog: commands.Cog, fact: str) -> Dict[str, Any]: - """Stores a general fact using the MemoryManager.""" - if not fact: - return {"error": "fact required."} - print(f"Remembering general fact: '{fact}'") - try: - result = await cog.memory_manager.add_general_fact(fact) - if result.get("status") == "added": - return {"status": "success", "fact_added": fact} - elif result.get("status") == "duplicate": - return {"status": "duplicate", "fact": fact} - elif result.get("status") == "limit_reached": - return { - "status": "success", - "fact_added": fact, - "note": "Oldest fact deleted.", - } - else: - return {"error": result.get("error", "Unknown MemoryManager error")} - except Exception as e: - error_message = f"Error calling MemoryManager for general fact: {str(e)}" - print(error_message) - traceback.print_exc() - return {"error": error_message} - - -async def get_general_facts( - cog: commands.Cog, query: Optional[str] = None, limit: Optional[int] = 10 -) -> Dict[str, Any]: - """Retrieves stored general facts using the MemoryManager.""" - print(f"Retrieving general facts (query='{query}', limit={limit})") - limit = min(max(1, limit or 10), 50) - try: - general_facts = await cog.memory_manager.get_general_facts( - query=query, limit=limit - ) # Context not needed here - return { - "query": query, - "facts": general_facts, - "count": len(general_facts), - "timestamp": datetime.datetime.now().isoformat(), - } - except Exception as e: - error_message = f"Error calling MemoryManager for general facts: {str(e)}" - print(error_message) - traceback.print_exc() - return {"error": error_message} - - -async def timeout_user( - cog: commands.Cog, user_id: str, duration_minutes: int, reason: Optional[str] = None -) -> Dict[str, Any]: - """Times out a user in the current server.""" - if not cog.current_channel or not isinstance( - cog.current_channel, discord.abc.GuildChannel - ): - return {"error": "Cannot timeout outside of a server."} - guild = cog.current_channel.guild - if not guild: - return {"error": "Could not determine server."} - if not 1 <= duration_minutes <= 1440: - return {"error": "Duration must be 1-1440 minutes."} - - try: - member_id = int(user_id) - member = guild.get_member(member_id) or await guild.fetch_member( - member_id - ) # Fetch if not cached - if not member: - return {"error": f"User {user_id} not found in server."} - if member == cog.bot.user: - return {"error": "lol i cant timeout myself vro"} - if member.id == guild.owner_id: - return {"error": f"Cannot timeout owner {member.display_name}."} - - bot_member = guild.me - if not bot_member.guild_permissions.moderate_members: - return {"error": "I lack permission to timeout."} - if bot_member.id != guild.owner_id and bot_member.top_role <= member.top_role: - return {"error": f"Cannot timeout {member.display_name} (role hierarchy)."} - - until = discord.utils.utcnow() + datetime.timedelta(minutes=duration_minutes) - timeout_reason = reason or "wheatley felt like it" # Changed default reason - await member.timeout(until, reason=timeout_reason) - print( - f"Timed out {member.display_name} ({user_id}) for {duration_minutes} mins. Reason: {timeout_reason}" - ) - return { - "status": "success", - "user_timed_out": member.display_name, - "user_id": user_id, - "duration_minutes": duration_minutes, - "reason": timeout_reason, - } - except ValueError: - return {"error": f"Invalid user ID: {user_id}"} - except discord.NotFound: - return {"error": f"User {user_id} not found in server."} - except discord.Forbidden as e: - print(f"Forbidden error timeout {user_id}: {e}") - return {"error": f"Permission error timeout {user_id}."} - except discord.HTTPException as e: - print(f"API error timeout {user_id}: {e}") - return {"error": f"API error timeout {user_id}: {e}"} - except Exception as e: - print(f"Unexpected error timeout {user_id}: {e}") - traceback.print_exc() - return {"error": f"Unexpected error timeout {user_id}: {str(e)}"} - - -async def remove_timeout( - cog: commands.Cog, user_id: str, reason: Optional[str] = None -) -> Dict[str, Any]: - """Removes an active timeout from a user.""" - if not cog.current_channel or not isinstance( - cog.current_channel, discord.abc.GuildChannel - ): - return {"error": "Cannot remove timeout outside of a server."} - guild = cog.current_channel.guild - if not guild: - return {"error": "Could not determine server."} - - try: - member_id = int(user_id) - member = guild.get_member(member_id) or await guild.fetch_member(member_id) - if not member: - return {"error": f"User {user_id} not found."} - # Define bot_member before using it - bot_member = guild.me - if not bot_member.guild_permissions.moderate_members: - return {"error": "I lack permission to remove timeouts."} - if member.timed_out_until is None: - return { - "status": "not_timed_out", - "user_id": user_id, - "user_name": member.display_name, - } - - timeout_reason = ( - reason or "Wheatley decided to be nice." - ) # Changed default reason - await member.timeout(None, reason=timeout_reason) # None removes timeout - print( - f"Removed timeout from {member.display_name} ({user_id}). Reason: {timeout_reason}" - ) - return { - "status": "success", - "user_timeout_removed": member.display_name, - "user_id": user_id, - "reason": timeout_reason, - } - except ValueError: - return {"error": f"Invalid user ID: {user_id}"} - except discord.NotFound: - return {"error": f"User {user_id} not found."} - except discord.Forbidden as e: - print(f"Forbidden error remove timeout {user_id}: {e}") - return {"error": f"Permission error remove timeout {user_id}."} - except discord.HTTPException as e: - print(f"API error remove timeout {user_id}: {e}") - return {"error": f"API error remove timeout {user_id}: {e}"} - except Exception as e: - print(f"Unexpected error remove timeout {user_id}: {e}") - traceback.print_exc() - return {"error": f"Unexpected error remove timeout {user_id}: {str(e)}"} - - -async def calculate(cog: commands.Cog, expression: str) -> Dict[str, Any]: - """Evaluates a mathematical expression using asteval.""" - print(f"Calculating expression: {expression}") - aeval = Interpreter() - try: - result = aeval(expression) - if aeval.error: - error_details = "; ".join(err.get_error() for err in aeval.error) - error_message = f"Calculation error: {error_details}" - print(error_message) - return {"error": error_message, "expression": expression} - - if isinstance(result, (int, float, complex)): - result_str = str(result) - else: - result_str = repr(result) # Fallback - - print(f"Calculation result: {result_str}") - return {"expression": expression, "result": result_str, "status": "success"} - except Exception as e: - error_message = f"Unexpected error during calculation: {str(e)}" - print(error_message) - traceback.print_exc() - return {"error": error_message, "expression": expression} - - -async def run_python_code(cog: commands.Cog, code: str) -> Dict[str, Any]: - """Executes a Python code snippet using the Piston API.""" - if not PISTON_API_URL: - return {"error": "Piston API URL not configured (PISTON_API_URL)."} - if not cog.session: - return {"error": "aiohttp session not initialized."} - print(f"Executing Python via Piston: {code[:100]}...") - payload = { - "language": "python", - "version": "3.10.0", - "files": [{"name": "main.py", "content": code}], - } - headers = {"Content-Type": "application/json"} - if PISTON_API_KEY: - headers["Authorization"] = PISTON_API_KEY - - try: - async with cog.session.post( - PISTON_API_URL, headers=headers, json=payload, timeout=20 - ) as response: - if response.status == 200: - data = await response.json() - run_info = data.get("run", {}) - compile_info = data.get("compile", {}) - stdout = run_info.get("stdout", "") - stderr = run_info.get("stderr", "") - exit_code = run_info.get("code", -1) - signal = run_info.get("signal") - full_stderr = (compile_info.get("stderr", "") + "\n" + stderr).strip() - max_len = 500 - stdout_trunc = stdout[:max_len] + ( - "..." if len(stdout) > max_len else "" - ) - stderr_trunc = full_stderr[:max_len] + ( - "..." if len(full_stderr) > max_len else "" - ) - result = { - "status": ( - "success" - if exit_code == 0 and not signal - else "execution_error" - ), - "stdout": stdout_trunc, - "stderr": stderr_trunc, - "exit_code": exit_code, - "signal": signal, - } - print(f"Piston execution result: {result}") - return result - else: - error_text = await response.text() - error_message = ( - f"Piston API error (Status {response.status}): {error_text[:200]}" - ) - print(error_message) - return {"error": error_message} - except asyncio.TimeoutError: - print("Piston API timed out.") - return {"error": "Piston API timed out."} - except aiohttp.ClientError as e: - print(f"Piston network error: {e}") - return {"error": f"Network error connecting to Piston: {str(e)}"} - except Exception as e: - print(f"Unexpected Piston error: {e}") - traceback.print_exc() - return {"error": f"Unexpected error during Python execution: {str(e)}"} - - -async def create_poll( - cog: commands.Cog, question: str, options: List[str] -) -> Dict[str, Any]: - """Creates a simple poll message.""" - if not cog.current_channel: - return {"error": "No current channel context."} - if not isinstance(cog.current_channel, discord.abc.Messageable): - return {"error": "Channel not messageable."} - if not isinstance(options, list) or not 2 <= len(options) <= 10: - return {"error": "Poll needs 2-10 options."} - - if isinstance(cog.current_channel, discord.abc.GuildChannel): - bot_member = cog.current_channel.guild.me - if ( - not cog.current_channel.permissions_for(bot_member).send_messages - or not cog.current_channel.permissions_for(bot_member).add_reactions - ): - return {"error": "Missing permissions for poll."} - - try: - poll_content = f"**📊 Poll: {question}**\n\n" - number_emojis = ["1️⃣", "2️⃣", "3️⃣", "4️⃣", "5️⃣", "6️⃣", "7️⃣", "8️⃣", "9️⃣", "🔟"] - for i, option in enumerate(options): - poll_content += f"{number_emojis[i]} {option}\n" - poll_message = await cog.current_channel.send(poll_content) - print(f"Sent poll {poll_message.id}: {question}") - for i in range(len(options)): - await poll_message.add_reaction(number_emojis[i]) - await asyncio.sleep(0.1) - return { - "status": "success", - "message_id": str(poll_message.id), - "question": question, - "options_count": len(options), - } - except discord.Forbidden: - print("Poll Forbidden") - return {"error": "Forbidden: Missing permissions for poll."} - except discord.HTTPException as e: - print(f"Poll API error: {e}") - return {"error": f"API error creating poll: {e}"} - except Exception as e: - print(f"Poll unexpected error: {e}") - traceback.print_exc() - return {"error": f"Unexpected error creating poll: {str(e)}"} - - -# Helper function to convert memory string (e.g., "128m") to bytes -def parse_mem_limit(mem_limit_str: str) -> Optional[int]: - if not mem_limit_str: - return None - mem_limit_str = mem_limit_str.lower() - if mem_limit_str.endswith("m"): - try: - return int(mem_limit_str[:-1]) * 1024 * 1024 - except ValueError: - return None - elif mem_limit_str.endswith("g"): - try: - return int(mem_limit_str[:-1]) * 1024 * 1024 * 1024 - except ValueError: - return None - try: - return int(mem_limit_str) # Assume bytes if no suffix - except ValueError: - return None - - -async def _check_command_safety(cog: commands.Cog, command: str) -> Dict[str, Any]: - """Uses a secondary AI call to check if a command is potentially harmful.""" - from .api import get_internal_ai_json_response # Import here - - print( - f"Performing AI safety check for command: '{command}' using model {SAFETY_CHECK_MODEL}" - ) - safety_schema = { - "type": "object", - "properties": { - "is_safe": { - "type": "boolean", - "description": "True if safe for restricted container, False otherwise.", - }, - "reason": {"type": "string", "description": "Brief explanation."}, - }, - "required": ["is_safe", "reason"], - } - prompt_messages = [ - { - "role": "system", - "content": f"Analyze shell command safety for execution in isolated, network-disabled Docker ({DOCKER_EXEC_IMAGE}) with CPU/Mem limits. Focus on data destruction, resource exhaustion, container escape, network attacks (disabled), env var leaks. Simple echo/ls/pwd safe. rm/mkfs/shutdown/wget/curl/install/fork bombs unsafe. Respond ONLY with JSON matching the provided schema.", - }, - {"role": "user", "content": f"Analyze safety: ```{command}```"}, - ] - safety_response = await get_internal_ai_json_response( - cog=cog, - prompt_messages=prompt_messages, - task_description="Command Safety Check", - response_schema_dict=safety_schema, # Pass the schema dict directly - model_name=SAFETY_CHECK_MODEL, - temperature=0.1, - max_tokens=150, - ) - if safety_response and isinstance(safety_response.get("is_safe"), bool): - is_safe = safety_response["is_safe"] - reason = safety_response.get("reason", "No reason provided.") - print(f"AI Safety Check Result: is_safe={is_safe}, reason='{reason}'") - return {"safe": is_safe, "reason": reason} - else: - error_msg = "AI safety check failed or returned invalid format." - print(f"AI Safety Check Error: Response was {safety_response}") - return {"safe": False, "reason": error_msg} - - -async def run_terminal_command(cog: commands.Cog, command: str) -> Dict[str, Any]: - """Executes a shell command in an isolated Docker container after an AI safety check.""" - print(f"Attempting terminal command: {command}") - safety_check_result = await _check_command_safety(cog, command) - if not safety_check_result.get("safe"): - error_message = f"Command blocked by AI safety check: {safety_check_result.get('reason', 'Unknown')}" - print(error_message) - return {"error": error_message, "command": command} - - try: - cpu_limit = float(DOCKER_CPU_LIMIT) - cpu_period = 100000 - cpu_quota = int(cpu_limit * cpu_period) - except ValueError: - print(f"Warning: Invalid DOCKER_CPU_LIMIT '{DOCKER_CPU_LIMIT}'. Using default.") - cpu_quota = 50000 - cpu_period = 100000 - - mem_limit_bytes = parse_mem_limit(DOCKER_MEM_LIMIT) - if mem_limit_bytes is None: - print( - f"Warning: Invalid DOCKER_MEM_LIMIT '{DOCKER_MEM_LIMIT}'. Disabling memory limit." - ) - - client = None - container = None - try: - client = aiodocker.Docker() - print(f"Running command in Docker ({DOCKER_EXEC_IMAGE})...") - - config = { - "Image": DOCKER_EXEC_IMAGE, - "Cmd": ["/bin/sh", "-c", command], - "AttachStdout": True, - "AttachStderr": True, - "HostConfig": { - "NetworkDisabled": True, - "AutoRemove": False, # Changed to False - "CpuPeriod": cpu_period, - "CpuQuota": cpu_quota, - }, - } - if mem_limit_bytes is not None: - config["HostConfig"]["Memory"] = mem_limit_bytes - - # Use wait_for for the run call itself in case image pulling takes time - container = await asyncio.wait_for( - client.containers.run(config=config), - timeout=DOCKER_COMMAND_TIMEOUT - + 15, # Add buffer for container start/stop/pull - ) - - # Wait for the container to finish execution - wait_result = await asyncio.wait_for( - container.wait(), timeout=DOCKER_COMMAND_TIMEOUT - ) - exit_code = wait_result.get("StatusCode", -1) - - # Get logs after container finishes - # container.log() returns a list of strings when stream=False (default) - stdout_lines = await container.log(stdout=True, stderr=False) - stderr_lines = await container.log(stdout=False, stderr=True) - - stdout = "".join(stdout_lines) if stdout_lines else "" - stderr = "".join(stderr_lines) if stderr_lines else "" - - max_len = 1000 - stdout_trunc = stdout[:max_len] + ("..." if len(stdout) > max_len else "") - stderr_trunc = stderr[:max_len] + ("..." if len(stderr) > max_len else "") - - result = { - "status": "success" if exit_code == 0 else "execution_error", - "stdout": stdout_trunc, - "stderr": stderr_trunc, - "exit_code": exit_code, - } - print( - f"Docker command finished. Exit Code: {exit_code}. Output length: {len(stdout)}, Stderr length: {len(stderr)}" - ) - return result - - except asyncio.TimeoutError: - print("Docker command run, wait, or log retrieval timed out.") - # Attempt to stop/remove container if it exists and timed out - if container: - try: - print(f"Attempting to stop timed-out container {container.id[:12]}...") - await container.stop(t=1) - print(f"Container {container.id[:12]} stopped.") - # AutoRemove should handle removal, but log deletion attempt if needed - # print(f"Attempting to delete timed-out container {container.id[:12]}...") - # await container.delete(force=True) # Force needed if stop failed? - # print(f"Container {container.id[:12]} deleted.") - except aiodocker.exceptions.DockerError as stop_err: - print( - f"Error stopping/deleting timed-out container {container.id[:12]}: {stop_err}" - ) - except Exception as stop_exc: - print( - f"Unexpected error stopping/deleting timed-out container {container.id[:12]}: {stop_exc}" - ) - # No need to delete here, finally block will handle it - return { - "error": f"Command execution/log retrieval timed out after {DOCKER_COMMAND_TIMEOUT}s", - "command": command, - "status": "timeout", - } - except aiodocker.exceptions.DockerError as e: # Catch specific aiodocker errors - print(f"Docker API error: {e} (Status: {e.status})") - # Check for ImageNotFound specifically - if e.status == 404 and ("No such image" in str(e) or "not found" in str(e)): - print(f"Docker image not found: {DOCKER_EXEC_IMAGE}") - return { - "error": f"Docker image '{DOCKER_EXEC_IMAGE}' not found.", - "command": command, - "status": "docker_error", - } - return { - "error": f"Docker API error ({e.status}): {str(e)}", - "command": command, - "status": "docker_error", - } - except Exception as e: - print(f"Unexpected Docker error: {e}") - traceback.print_exc() - return { - "error": f"Unexpected error during Docker execution: {str(e)}", - "command": command, - "status": "error", - } - finally: - # Explicitly remove the container since AutoRemove is False - if container: - try: - print(f"Attempting to delete container {container.id[:12]}...") - await container.delete(force=True) - print(f"Container {container.id[:12]} deleted.") - except aiodocker.exceptions.DockerError as delete_err: - # Log error but don't raise, primary error is more important - print(f"Error deleting container {container.id[:12]}: {delete_err}") - except Exception as delete_exc: - print( - f"Unexpected error deleting container {container.id[:12]}: {delete_exc}" - ) # <--- Corrected indentation - # Ensure the client connection is closed - if client: - await client.close() - - -async def extract_web_content( - cog: commands.Cog, - urls: Union[str, List[str]], - extract_depth: str = "basic", - include_images: bool = False, -) -> Dict[str, Any]: - """Extract content from URLs using Tavily API""" - if not hasattr(cog, "tavily_client") or not cog.tavily_client: - return { - "error": "Tavily client not initialized.", - "timestamp": datetime.datetime.now().isoformat(), - } - - # Cost control / Logging for advanced extract - final_extract_depth = extract_depth - if extract_depth.lower() == "advanced": - if TAVILY_DISABLE_ADVANCED: - print( - f"Warning: Advanced Tavily extract requested but disabled by config. Falling back to basic." - ) - final_extract_depth = "basic" - else: - print( - f"Performing advanced Tavily extract (cost: 2 credits per 5 URLs) for URLs: {urls}" - ) - elif extract_depth.lower() != "basic": - print( - f"Warning: Invalid extract_depth '{extract_depth}' provided. Using 'basic'." - ) - final_extract_depth = "basic" - - try: - response = await asyncio.to_thread( - cog.tavily_client.extract, - urls=urls, - extract_depth=final_extract_depth, # Use validated depth - include_images=include_images, - ) - results = [ - { - "url": r.get("url"), - "raw_content": r.get("raw_content"), - "images": r.get("images"), - } - for r in response.get("results", []) - ] - failed_results = response.get("failed_results", []) - return { - "urls": urls, - "extract_depth": extract_depth, - "include_images": include_images, - "results": results, - "failed_results": failed_results, - "timestamp": datetime.datetime.now().isoformat(), - } - except Exception as e: - error_message = f"Error during Tavily extract for '{urls}': {str(e)}" - print(error_message) - return { - "error": error_message, - "timestamp": datetime.datetime.now().isoformat(), - } - - -# --- Tool Mapping --- -# This dictionary maps tool names (used in the AI prompt) to their implementation functions. -TOOL_MAPPING = { - "get_recent_messages": get_recent_messages, - "search_user_messages": search_user_messages, - "search_messages_by_content": search_messages_by_content, - "get_channel_info": get_channel_info, - "get_conversation_context": get_conversation_context, - "get_thread_context": get_thread_context, - "get_user_interaction_history": get_user_interaction_history, - "get_conversation_summary": get_conversation_summary, - "get_message_context": get_message_context, - "web_search": web_search, - # Point memory tools to the methods on the MemoryManager instance (accessed via cog) - "remember_user_fact": lambda cog, **kwargs: cog.memory_manager.add_user_fact( - **kwargs - ), - "get_user_facts": lambda cog, **kwargs: cog.memory_manager.get_user_facts(**kwargs), - "remember_general_fact": lambda cog, **kwargs: cog.memory_manager.add_general_fact( - **kwargs - ), - "get_general_facts": lambda cog, **kwargs: cog.memory_manager.get_general_facts( - **kwargs - ), - "timeout_user": timeout_user, - "calculate": calculate, - "run_python_code": run_python_code, - "create_poll": create_poll, - "run_terminal_command": run_terminal_command, - "remove_timeout": remove_timeout, - "extract_web_content": extract_web_content, -} diff --git a/wheatley/utils.py b/wheatley/utils.py deleted file mode 100644 index d4a48cf..0000000 --- a/wheatley/utils.py +++ /dev/null @@ -1,169 +0,0 @@ -import discord -import re -import random -import asyncio -import time -import datetime -import json -import os -from typing import TYPE_CHECKING, Optional, Tuple, Dict, Any - -if TYPE_CHECKING: - from .cog import WheatleyCog # For type hinting - -# --- Utility Functions --- -# Note: Functions needing cog state (like personality traits for mistakes) -# will need the 'cog' instance passed in. - - -def replace_mentions_with_names( - cog: "WheatleyCog", content: str, message: discord.Message -) -> str: - """Replaces user mentions (<@id> or <@!id>) with their display names.""" - if not message.mentions: - return content - - processed_content = content - sorted_mentions = sorted( - message.mentions, key=lambda m: len(str(m.id)), reverse=True - ) - - for member in sorted_mentions: - processed_content = processed_content.replace( - f"<@{member.id}>", member.display_name - ) - processed_content = processed_content.replace( - f"<@!{member.id}>", member.display_name - ) - return processed_content - - -def format_message(cog: "WheatleyCog", message: discord.Message) -> Dict[str, Any]: - """Helper function to format a discord.Message object into a dictionary.""" - processed_content = replace_mentions_with_names( - cog, message.content, message - ) # Pass cog - mentioned_users_details = [ - {"id": str(m.id), "name": m.name, "display_name": m.display_name} - for m in message.mentions - ] - - formatted_msg = { - "id": str(message.id), - "author": { - "id": str(message.author.id), - "name": message.author.name, - "display_name": message.author.display_name, - "bot": message.author.bot, - }, - "content": processed_content, - "created_at": message.created_at.isoformat(), - "attachments": [ - {"filename": a.filename, "url": a.url} for a in message.attachments - ], - "embeds": len(message.embeds) > 0, - "mentions": [ - {"id": str(m.id), "name": m.name} for m in message.mentions - ], # Keep original simple list too - "mentioned_users_details": mentioned_users_details, - "replied_to_message_id": None, - "replied_to_author_id": None, - "replied_to_author_name": None, - "replied_to_content": None, - "is_reply": False, - } - - if message.reference and message.reference.message_id: - formatted_msg["replied_to_message_id"] = str(message.reference.message_id) - formatted_msg["is_reply"] = True - # Try to get resolved details (might be None if message not cached/fetched) - ref_msg = message.reference.resolved - if isinstance(ref_msg, discord.Message): # Check if resolved is a Message - formatted_msg["replied_to_author_id"] = str(ref_msg.author.id) - formatted_msg["replied_to_author_name"] = ref_msg.author.display_name - formatted_msg["replied_to_content"] = ref_msg.content - # else: print(f"Referenced message {message.reference.message_id} not resolved.") # Optional debug - - return formatted_msg - - -def update_relationship( - cog: "WheatleyCog", user_id_1: str, user_id_2: str, change: float -): - """Updates the relationship score between two users.""" - if user_id_1 > user_id_2: - user_id_1, user_id_2 = user_id_2, user_id_1 - if user_id_1 not in cog.user_relationships: - cog.user_relationships[user_id_1] = {} - - current_score = cog.user_relationships[user_id_1].get(user_id_2, 0.0) - new_score = max(0.0, min(current_score + change, 100.0)) # Clamp 0-100 - cog.user_relationships[user_id_1][user_id_2] = new_score - # print(f"Updated relationship {user_id_1}-{user_id_2}: {current_score:.1f} -> {new_score:.1f} ({change:+.1f})") # Debug log - - -async def simulate_human_typing(cog: "WheatleyCog", channel, text: str): - """Shows typing indicator without significant delay.""" - # Minimal delay to ensure the typing indicator shows up reliably - # but doesn't add noticeable latency to the response. - # The actual sending of the message happens immediately after this. - async with channel.typing(): - await asyncio.sleep(0.1) # Very short sleep, just to ensure typing shows - - -async def log_internal_api_call( - cog: "WheatleyCog", - task_description: str, - payload: Dict[str, Any], - response_data: Optional[Dict[str, Any]], - error: Optional[Exception] = None, -): - """Helper function to log internal API calls to a file.""" - log_dir = "data" - log_file = os.path.join(log_dir, "internal_api_calls.log") - try: - os.makedirs(log_dir, exist_ok=True) - timestamp = datetime.datetime.now().isoformat() - log_entry = f"--- Log Entry: {timestamp} ---\n" - log_entry += f"Task: {task_description}\n" - log_entry += f"Model: {payload.get('model', 'N/A')}\n" - - # Sanitize payload for logging (avoid large base64 images) - payload_to_log = payload.copy() - if "messages" in payload_to_log: - sanitized_messages = [] - for msg in payload_to_log["messages"]: - if isinstance(msg.get("content"), list): # Multimodal message - new_content = [] - for part in msg["content"]: - if part.get("type") == "image_url" and part.get( - "image_url", {} - ).get("url", "").startswith("data:image"): - new_content.append( - { - "type": "image_url", - "image_url": {"url": "data:image/...[truncated]"}, - } - ) - else: - new_content.append(part) - sanitized_messages.append({**msg, "content": new_content}) - else: - sanitized_messages.append(msg) - payload_to_log["messages"] = sanitized_messages - - log_entry += f"Request Payload:\n{json.dumps(payload_to_log, indent=2)}\n" - if response_data: - log_entry += f"Response Data:\n{json.dumps(response_data, indent=2)}\n" - if error: - log_entry += f"Error: {str(error)}\n" - log_entry += "---\n\n" - - with open(log_file, "a", encoding="utf-8") as f: - f.write(log_entry) - except Exception as log_e: - print(f"!!! Failed to write to internal API log file {log_file}: {log_e}") - - -# Note: _create_human_like_mistake was removed as it wasn't used in the final on_message logic provided. -# If needed, it can be added back here, ensuring it takes 'cog' if it needs personality traits.