import discord
from discord.ext import commands
import asyncio
import os
import json
import aiohttp
import random
import time
from collections import defaultdict, deque
from typing import Dict, List, Any, Optional, Tuple, Set, Union
# Third-party imports needed by the Cog itself or its direct methods
from dotenv import load_dotenv
from tavily import TavilyClient # Needed for tavily_client init
# --- Relative Imports from Wheatley Package ---
from .config import (
PROJECT_ID, LOCATION, TAVILY_API_KEY, DEFAULT_MODEL, FALLBACK_MODEL, # Use GCP config
DB_PATH, CHROMA_PATH, SEMANTIC_MODEL_NAME, MAX_USER_FACTS, MAX_GENERAL_FACTS,
# Removed Mood/Personality/Interest/Learning/Goal configs
CHANNEL_TOPIC_CACHE_TTL, CONTEXT_WINDOW_SIZE,
API_TIMEOUT, SUMMARY_API_TIMEOUT, API_RETRY_ATTEMPTS, API_RETRY_DELAY,
PROACTIVE_LULL_THRESHOLD, PROACTIVE_BOT_SILENCE_THRESHOLD, PROACTIVE_LULL_CHANCE,
PROACTIVE_TOPIC_RELEVANCE_THRESHOLD, PROACTIVE_TOPIC_CHANCE,
# Removed Relationship/Sentiment/Interest proactive configs
TOPIC_UPDATE_INTERVAL, SENTIMENT_UPDATE_INTERVAL,
RESPONSE_SCHEMA, TOOLS # Import necessary configs
)
# Import functions/classes from other modules
from .memory import MemoryManager # Import from local memory.py
from .background import background_processing_task # Keep background task for potential future use (e.g., cache cleanup)
from .commands import setup_commands # Import the setup helper
from .listeners import on_ready_listener, on_message_listener, on_reaction_add_listener, on_reaction_remove_listener # Import listener functions
from . import config as WheatleyConfig # Import config module for get_wheatley_stats
# Load environment variables (might be loaded globally in main bot script too)
load_dotenv()
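# Illustrative .env entries this cog relies on via .config (names assumed from the
# imports above; values are placeholders, not real credentials):
#   PROJECT_ID=my-gcp-project
#   LOCATION=us-central1
#   TAVILY_API_KEY=tvly-xxxxxxxx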
class WheatleyCog(commands.Cog, name="Wheatley"):
"""A cog implementing the Wheatley bot persona, backed by the Google Vertex AI API."""
def __init__(self, bot):
self.bot = bot
# GCP Project/Location are used by vertexai.init() in api.py
self.tavily_api_key = TAVILY_API_KEY # Use imported config
self.session: Optional[aiohttp.ClientSession] = None # Keep for other potential HTTP requests (e.g., Piston)
self.tavily_client = TavilyClient(api_key=self.tavily_api_key) if self.tavily_api_key else None
self.default_model = DEFAULT_MODEL # Use imported config
self.fallback_model = FALLBACK_MODEL # Use imported config
# Removed MOOD_OPTIONS
self.current_channel: Optional[Union[discord.TextChannel, discord.Thread, discord.DMChannel]] = None # Type hint current channel
# Instantiate MemoryManager
self.memory_manager = MemoryManager(
db_path=DB_PATH,
max_user_facts=MAX_USER_FACTS,
max_general_facts=MAX_GENERAL_FACTS,
chroma_path=CHROMA_PATH,
semantic_model_name=SEMANTIC_MODEL_NAME
)
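# MemoryManager wraps two stores (as used later in this cog): an SQLite database
# for user/general facts and ChromaDB collections for semantic recall.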
# --- State Variables (Simplified for Wheatley) ---
# Removed mood, personality evolution, interest tracking, learning state
self.needs_json_reminder = False # Flag to remind AI about JSON format
# Topic tracking (Kept for context)
self.active_topics = defaultdict(lambda: {
"topics": [], "last_update": time.time(), "topic_history": [],
"user_topic_interests": defaultdict(list) # Kept for potential future analysis, not proactive triggers
})
# Conversation tracking / Caches
self.conversation_history = defaultdict(lambda: deque(maxlen=100))
self.thread_history = defaultdict(lambda: deque(maxlen=50))
self.user_conversation_mapping = defaultdict(set)
self.channel_activity = defaultdict(lambda: 0.0) # Use float for timestamp
self.conversation_topics = defaultdict(str) # Simplified topic tracking
self.user_relationships = defaultdict(dict) # Kept for potential context/analysis
self.conversation_summaries: Dict[int, Dict[str, Any]] = {} # Store dict with summary and timestamp
self.channel_topics_cache: Dict[int, Dict[str, Any]] = {} # Store dict with topic and timestamp
self.message_cache = {
'by_channel': defaultdict(lambda: deque(maxlen=CONTEXT_WINDOW_SIZE)), # Use config
'by_user': defaultdict(lambda: deque(maxlen=50)),
'by_thread': defaultdict(lambda: deque(maxlen=50)),
'global_recent': deque(maxlen=200),
'mentioned': deque(maxlen=50),
'replied_to': defaultdict(lambda: deque(maxlen=20))
}
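# All caches above are bounded deques, so memory use stays flat over long uptimes;
# 'by_channel' uses CONTEXT_WINDOW_SIZE from config, presumably because it feeds
# the per-channel prompt context.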
self.active_conversations = {} # Kept for basic tracking
self.bot_last_spoke = defaultdict(float)
self.message_reply_map = {}
# Enhanced sentiment tracking (Kept for context/analysis)
self.conversation_sentiment = defaultdict(lambda: {
"overall": "neutral", "intensity": 0.5, "recent_trend": "stable",
"user_sentiments": {}, "last_update": time.time()
})
# Removed self.sentiment_update_interval as it was only used in analysis
# Reaction tracking
self.wheatley_message_reactions = defaultdict(lambda: {"positive": 0, "negative": 0, "topic": None, "timestamp": 0.0})
# Background task handle (Kept for potential future tasks like cache cleanup)
self.background_task: Optional[asyncio.Task] = None
self.last_stats_push = time.time() # Timestamp for last stats push
# Removed evolution, reflection, goal timestamps
# --- Stats Tracking ---
self.api_stats = defaultdict(lambda: {"success": 0, "failure": 0, "retries": 0, "total_time": 0.0, "count": 0}) # Keyed by model name
self.tool_stats = defaultdict(lambda: {"success": 0, "failure": 0, "total_time": 0.0, "count": 0}) # Keyed by tool name
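# Illustrative update pattern expected from the API/tool wrappers (assumed, not
# shown in this file); total_time accumulates seconds and is converted to ms in
# get_wheatley_stats():
#   self.api_stats[model_name]["success"] += 1
#   self.api_stats[model_name]["count"] += 1
#   self.api_stats[model_name]["total_time"] += elapsed_seconds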
# --- Setup Commands and Listeners ---
# Add commands defined in commands.py
self.command_functions = setup_commands(self)
# Store command names for reference - safely handle Command objects
self.registered_commands = []
for func in self.command_functions:
# For app commands, use the name attribute directly
if hasattr(func, "name"):
self.registered_commands.append(func.name)
# For regular functions, use __name__
elif hasattr(func, "__name__"):
self.registered_commands.append(func.__name__)
else:
self.registered_commands.append(str(func))
print(f"WheatleyCog initialized with commands: {self.registered_commands}") # Updated print
async def cog_load(self):
"""Create the aiohttp session, initialize the database, register listeners, and start the background task."""
self.session = aiohttp.ClientSession()
print("WheatleyCog: aiohttp session created")
# Initialize DB via MemoryManager
await self.memory_manager.initialize_sqlite_database()
# Removed loading of baseline personality and interests
# Vertex AI initialization happens in api.py using PROJECT_ID and LOCATION from config
print(f"WheatleyCog: Using default model: {self.default_model}")
if not self.tavily_api_key:
print("WARNING: Tavily API key not configured (TAVILY_API_KEY). Web search disabled.")
# Add listeners to the bot instance
# IMPORTANT: Don't override on_member_join or on_member_remove events
# Check if the bot already has event listeners for member join/leave
has_member_join = 'on_member_join' in self.bot.extra_events
has_member_remove = 'on_member_remove' in self.bot.extra_events
print(f"WheatleyCog: Bot already has event listeners - on_member_join: {has_member_join}, on_member_remove: {has_member_remove}")
@self.bot.event
async def on_ready():
await on_ready_listener(self)
@self.bot.event
async def on_message(message):
# Ensure commands are processed if using command prefix
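# Assumes command_prefix is a str or tuple of str; with a callable prefix,
# str.startswith would raise, and awaiting self.bot.get_prefix(message) would
# be the safer check.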
if message.content.startswith(self.bot.command_prefix):
await self.bot.process_commands(message)
# Always run the message listener for potential AI responses/tracking
await on_message_listener(self, message)
@self.bot.event
async def on_reaction_add(reaction, user):
await on_reaction_add_listener(self, reaction, user)
@self.bot.event
async def on_reaction_remove(reaction, user):
await on_reaction_remove_listener(self, reaction, user)
print("WheatleyCog: Listeners added.") # Updated print
# Commands will be synced in on_ready
print("WheatleyCog: Commands will be synced when the bot is ready.") # Updated print
# Start background task (kept for potential future use)
if self.background_task is None or self.background_task.done():
self.background_task = asyncio.create_task(background_processing_task(self))
print("WheatleyCog: Started background processing task.") # Updated print
else:
print("WheatleyCog: Background processing task already running.") # Updated print
async def cog_unload(self):
"""Close the aiohttp session and cancel the background task."""
if self.session and not self.session.closed:
await self.session.close()
print("WheatleyCog: aiohttp session closed")
if self.background_task and not self.background_task.done():
self.background_task.cancel()
print("WheatleyCog: Cancelled background processing task.")
print("WheatleyCog: Listeners will be removed when bot is closed.")
print("WheatleyCog unloaded.")
# --- Helper methods that might remain in the cog ---
# _update_relationship kept for potential context/analysis use
def _update_relationship(self, user_id_1: str, user_id_2: str, change: float):
"""Updates the relationship score between two users."""
if user_id_1 > user_id_2: user_id_1, user_id_2 = user_id_2, user_id_1
if user_id_1 not in self.user_relationships: self.user_relationships[user_id_1] = {}
current_score = self.user_relationships[user_id_1].get(user_id_2, 0.0)
new_score = max(0.0, min(current_score + change, 100.0)) # Clamp 0-100
self.user_relationships[user_id_1][user_id_2] = new_score
# print(f"Updated relationship {user_id_1}-{user_id_2}: {current_score:.1f} -> {new_score:.1f} ({change:+.1f})") # Debug log
async def get_wheatley_stats(self) -> Dict[str, Any]:
"""Collects various internal stats for Wheatley."""
stats = {"config": {}, "runtime": {}, "memory": {}, "api_stats": {}, "tool_stats": {}}
# --- Config (Simplified) ---
stats["config"]["default_model"] = WheatleyConfig.DEFAULT_MODEL
stats["config"]["fallback_model"] = WheatleyConfig.FALLBACK_MODEL
stats["config"]["safety_check_model"] = WheatleyConfig.SAFETY_CHECK_MODEL
stats["config"]["db_path"] = WheatleyConfig.DB_PATH
stats["config"]["chroma_path"] = WheatleyConfig.CHROMA_PATH
stats["config"]["semantic_model_name"] = WheatleyConfig.SEMANTIC_MODEL_NAME
stats["config"]["max_user_facts"] = WheatleyConfig.MAX_USER_FACTS
stats["config"]["max_general_facts"] = WheatleyConfig.MAX_GENERAL_FACTS
stats["config"]["context_window_size"] = WheatleyConfig.CONTEXT_WINDOW_SIZE
stats["config"]["api_timeout"] = WheatleyConfig.API_TIMEOUT
stats["config"]["summary_api_timeout"] = WheatleyConfig.SUMMARY_API_TIMEOUT
stats["config"]["proactive_lull_threshold"] = WheatleyConfig.PROACTIVE_LULL_THRESHOLD
stats["config"]["proactive_bot_silence_threshold"] = WheatleyConfig.PROACTIVE_BOT_SILENCE_THRESHOLD
stats["config"]["topic_update_interval"] = WheatleyConfig.TOPIC_UPDATE_INTERVAL
stats["config"]["sentiment_update_interval"] = WheatleyConfig.SENTIMENT_UPDATE_INTERVAL
stats["config"]["docker_command_timeout"] = WheatleyConfig.DOCKER_COMMAND_TIMEOUT
stats["config"]["project_id_set"] = bool(WheatleyConfig.PROJECT_ID != "your-gcp-project-id")
stats["config"]["location_set"] = bool(WheatleyConfig.LOCATION != "us-central1")
stats["config"]["tavily_api_key_set"] = bool(WheatleyConfig.TAVILY_API_KEY)
stats["config"]["piston_api_url_set"] = bool(WheatleyConfig.PISTON_API_URL)
# --- Runtime (Simplified) ---
# Removed mood, evolution
stats["runtime"]["needs_json_reminder"] = self.needs_json_reminder
stats["runtime"]["background_task_running"] = bool(self.background_task and not self.background_task.done())
stats["runtime"]["active_topics_channels"] = len(self.active_topics)
stats["runtime"]["conversation_history_channels"] = len(self.conversation_history)
stats["runtime"]["thread_history_threads"] = len(self.thread_history)
stats["runtime"]["user_conversation_mappings"] = len(self.user_conversation_mapping)
stats["runtime"]["channel_activity_tracked"] = len(self.channel_activity)
stats["runtime"]["conversation_topics_tracked"] = len(self.conversation_topics) # Simplified topic tracking
stats["runtime"]["user_relationships_pairs"] = sum(len(v) for v in self.user_relationships.values())
stats["runtime"]["conversation_summaries_cached"] = len(self.conversation_summaries)
stats["runtime"]["channel_topics_cached"] = len(self.channel_topics_cache)
stats["runtime"]["message_cache_global_count"] = len(self.message_cache['global_recent'])
stats["runtime"]["message_cache_mentioned_count"] = len(self.message_cache['mentioned'])
stats["runtime"]["active_conversations_count"] = len(self.active_conversations)
stats["runtime"]["bot_last_spoke_channels"] = len(self.bot_last_spoke)
stats["runtime"]["message_reply_map_size"] = len(self.message_reply_map)
stats["runtime"]["conversation_sentiment_channels"] = len(self.conversation_sentiment)
# Removed Gurt participation topics
stats["runtime"]["wheatley_message_reactions_tracked"] = len(self.wheatley_message_reactions) # Renamed
# --- Memory (Simplified) ---
try:
# Removed Personality, Interests
user_fact_count = await self.memory_manager._db_fetchone("SELECT COUNT(*) FROM user_facts")
general_fact_count = await self.memory_manager._db_fetchone("SELECT COUNT(*) FROM general_facts")
stats["memory"]["user_facts_count"] = user_fact_count[0] if user_fact_count else 0
stats["memory"]["general_facts_count"] = general_fact_count[0] if general_fact_count else 0
# ChromaDB Stats
stats["memory"]["chromadb_message_collection_count"] = await asyncio.to_thread(self.memory_manager.semantic_collection.count) if self.memory_manager.semantic_collection else "N/A"
stats["memory"]["chromadb_fact_collection_count"] = await asyncio.to_thread(self.memory_manager.fact_collection.count) if self.memory_manager.fact_collection else "N/A"
except Exception as e:
stats["memory"]["error"] = f"Failed to retrieve memory stats: {e}"
# --- API & Tool Stats ---
stats["api_stats"] = dict(self.api_stats)
stats["tool_stats"] = dict(self.tool_stats)
# Calculate average times
for model, data in stats["api_stats"].items():
if data["count"] > 0: data["average_time_ms"] = round((data["total_time"] / data["count"]) * 1000, 2)
else: data["average_time_ms"] = 0
for tool, data in stats["tool_stats"].items():
if data["count"] > 0: data["average_time_ms"] = round((data["total_time"] / data["count"]) * 1000, 2)
else: data["average_time_ms"] = 0
return stats
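# Illustrative (abridged, hypothetical values) shape of the returned dict:
#   {"config": {"default_model": "...", ...},
#    "runtime": {"background_task_running": True, ...},
#    "memory": {"user_facts_count": 12, "general_facts_count": 34, ...},
#    "api_stats": {"<model>": {"success": 3, "failure": 0, "retries": 1,
#                              "total_time": 2.4, "count": 3, "average_time_ms": 800.0}},
#    "tool_stats": {"<tool>": {...}}}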
async def sync_commands(self):
"""Manually sync commands with Discord."""
try:
print("WheatleyCog: Manually syncing commands with Discord...") # Updated print
synced = await self.bot.tree.sync()
print(f"WheatleyCog: Synced {len(synced)} command(s)") # Updated print
# List the synced commands
wheatley_commands = [cmd.name for cmd in self.bot.tree.get_commands() if cmd.name.startswith("wheatley")] # Updated prefix
print(f"WheatleyCog: Available Wheatley commands: {', '.join(wheatley_commands)}") # Updated print
return synced, wheatley_commands
except Exception as e:
print(f"WheatleyCog: Failed to sync commands: {e}") # Updated print
import traceback
traceback.print_exc()
return [], []
# Setup function for loading the cog
async def setup(bot):
"""Add the WheatleyCog to the bot.""" # Updated docstring
await bot.add_cog(WheatleyCog(bot)) # Use renamed class
print("WheatleyCog setup complete.") # Updated print