import aiosqlite
import asyncio
import os
import time
import re
import hashlib  # Used for chroma_id generation
from typing import Dict, List, Any, Optional

import chromadb
from chromadb.utils import embedding_functions
from sentence_transformers import SentenceTransformer
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
# Use a specific logger name for Wheatley's memory
logger = logging.getLogger("wheatley_memory")

# Constants (interest-related constants removed)


# --- Helper Function for Keyword Scoring (kept for potential future use; currently unused) ---
def calculate_keyword_score(text: str, context: str) -> int:
    """Calculates a simple keyword-overlap score between two strings."""
    if not context or not text:
        return 0
    context_words = set(re.findall(r"\b\w+\b", context.lower()))
    text_words = set(re.findall(r"\b\w+\b", text.lower()))
    # Ignore very common words (basic stopword list)
    stopwords = {
        "the", "a", "is", "in", "it", "of", "and", "to", "for", "on",
        "with", "that", "this", "i", "you", "me", "my", "your",
    }
    context_words -= stopwords
    text_words -= stopwords
    if not context_words:  # Avoid division by zero if context is only stopwords
        return 0
    overlap = len(context_words.intersection(text_words))
    # Normalizing by context length is possible (more overlap needed for longer context), e.g.:
    # score = overlap / (len(context_words) ** 0.5)
    score = overlap  # Simpler score for now
    return score


class MemoryManager:
    """Handles database interactions for Wheatley's memory (facts and semantic)."""

    def __init__(
        self,
        db_path: str,
        max_user_facts: int = 20,
        max_general_facts: int = 100,
        semantic_model_name: str = "all-MiniLM-L6-v2",
        chroma_path: str = "data/chroma_db_wheatley",
    ):
        self.db_path = db_path
        self.max_user_facts = max_user_facts
        self.max_general_facts = max_general_facts
        self.db_lock = asyncio.Lock()  # Lock for SQLite write operations

        # Ensure data directories exist (guard against a bare filename with no directory part)
        db_dir = os.path.dirname(self.db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        os.makedirs(chroma_path, exist_ok=True)
        logger.info(
            f"Wheatley MemoryManager initialized with db_path: {self.db_path}, chroma_path: {chroma_path}"
        )

        # --- Semantic Memory Setup ---
        self.chroma_path = chroma_path
        self.semantic_model_name = semantic_model_name
        self.chroma_client = None
        self.embedding_function = None
        self.semantic_collection = None  # For messages
        self.fact_collection = None  # For facts
        self.transformer_model = None
        # Initialize semantic components synchronously for simplicity during init
        self._initialize_semantic_memory_sync()

    def _initialize_semantic_memory_sync(self):
        """Synchronously initializes the ChromaDB client, embedding model, and collections."""
        try:
            logger.info("Initializing ChromaDB client...")
            # Use PersistentClient so data is saved to disk
            self.chroma_client = chromadb.PersistentClient(path=self.chroma_path)

            logger.info(
                f"Loading Sentence Transformer model: {self.semantic_model_name}..."
            )
            # Load the model directly
            self.transformer_model = SentenceTransformer(self.semantic_model_name)

            # Create a custom embedding function that wraps the loaded model
            class CustomEmbeddingFunction(embedding_functions.EmbeddingFunction):
                def __init__(self, model):
                    self.model = model

                def __call__(self, input: chromadb.Documents) -> chromadb.Embeddings:
                    # Ensure input is a list of strings
                    if not isinstance(input, list):
                        input = [str(input)]  # Convert a single item to a list
                    elif not all(isinstance(item, str) for item in input):
                        # Coerce all items to strings
                        input = [str(item) for item in input]
                    logger.debug(f"Generating embeddings for {len(input)} documents.")
                    embeddings = self.model.encode(
                        input, show_progress_bar=False
                    ).tolist()
                    logger.debug(f"Generated {len(embeddings)} embeddings.")
                    return embeddings

            self.embedding_function = CustomEmbeddingFunction(self.transformer_model)

            logger.info(
                "Getting/creating ChromaDB collection 'wheatley_semantic_memory'..."
            )
            # Get or create the message collection with the custom embedding function
            self.semantic_collection = self.chroma_client.get_or_create_collection(
                name="wheatley_semantic_memory",
                embedding_function=self.embedding_function,
                metadata={"hnsw:space": "cosine"},  # Use cosine distance for similarity
            )
            logger.info("ChromaDB message collection initialized successfully.")

            logger.info(
                "Getting/creating ChromaDB collection 'wheatley_fact_memory'..."
            )
            # Get or create the fact collection
            self.fact_collection = self.chroma_client.get_or_create_collection(
                name="wheatley_fact_memory",
                embedding_function=self.embedding_function,
                metadata={"hnsw:space": "cosine"},  # Use cosine distance for similarity
            )
            logger.info("ChromaDB fact collection initialized successfully.")

        except Exception as e:
            logger.error(
                f"Failed to initialize semantic memory (ChromaDB): {e}", exc_info=True
            )
            # Set all semantic components to None to indicate failure
            self.chroma_client = None
            self.transformer_model = None
            self.embedding_function = None
            self.semantic_collection = None
            self.fact_collection = None

    async def initialize_sqlite_database(self):
        """Initializes the SQLite database and creates tables if they don't exist."""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("PRAGMA journal_mode=WAL;")

            # Create user_facts table if it doesn't exist
            await db.execute(
                """
                CREATE TABLE IF NOT EXISTS user_facts (
                    user_id TEXT NOT NULL,
                    fact TEXT NOT NULL,
                    chroma_id TEXT, -- Links the fact to its ChromaDB entry
                    timestamp REAL DEFAULT (unixepoch('now')),
                    PRIMARY KEY (user_id, fact)
                );
                """
            )

            # Check if the chroma_id column exists in user_facts (migration for older DBs)
            try:
                cursor = await db.execute("PRAGMA table_info(user_facts)")
                columns = await cursor.fetchall()
                column_names = [column[1] for column in columns]
                if "chroma_id" not in column_names:
                    logger.info("Adding chroma_id column to user_facts table")
                    await db.execute("ALTER TABLE user_facts ADD COLUMN chroma_id TEXT")
            except Exception as e:
                logger.error(
                    f"Error checking/adding chroma_id column to user_facts: {e}",
                    exc_info=True,
                )

            # Create indexes
            await db.execute(
                "CREATE INDEX IF NOT EXISTS idx_user_facts_user ON user_facts (user_id);"
            )
            await db.execute(
                "CREATE INDEX IF NOT EXISTS idx_user_facts_chroma_id ON user_facts (chroma_id);"
            )
            # Create general_facts table if it doesn't exist
            await db.execute(
                """
                CREATE TABLE IF NOT EXISTS general_facts (
                    fact TEXT PRIMARY KEY NOT NULL,
                    chroma_id TEXT, -- Links the fact to its ChromaDB entry
                    timestamp REAL DEFAULT (unixepoch('now'))
                );
                """
            )

            # Check if the chroma_id column exists in general_facts (migration for older DBs)
            try:
                cursor = await db.execute("PRAGMA table_info(general_facts)")
                columns = await cursor.fetchall()
                column_names = [column[1] for column in columns]
                if "chroma_id" not in column_names:
                    logger.info("Adding chroma_id column to general_facts table")
                    await db.execute(
                        "ALTER TABLE general_facts ADD COLUMN chroma_id TEXT"
                    )
            except Exception as e:
                logger.error(
                    f"Error checking/adding chroma_id column to general_facts: {e}",
                    exc_info=True,
                )

            # Create index for general_facts
            await db.execute(
                "CREATE INDEX IF NOT EXISTS idx_general_facts_chroma_id ON general_facts (chroma_id);"
            )

            # --- Personality, interests, and goals tables removed ---

            await db.commit()
            logger.info(
                f"Wheatley SQLite database initialized/verified at {self.db_path}"
            )

    # --- SQLite Helper Methods ---
    async def _db_execute(self, sql: str, params: tuple = ()):
        async with self.db_lock:
            async with aiosqlite.connect(self.db_path) as db:
                await db.execute(sql, params)
                await db.commit()

    async def _db_fetchone(self, sql: str, params: tuple = ()) -> Optional[tuple]:
        async with aiosqlite.connect(self.db_path) as db:
            async with db.execute(sql, params) as cursor:
                return await cursor.fetchone()

    async def _db_fetchall(self, sql: str, params: tuple = ()) -> List[tuple]:
        async with aiosqlite.connect(self.db_path) as db:
            async with db.execute(sql, params) as cursor:
                return await cursor.fetchall()

    # --- User Fact Memory Methods (SQLite + Relevance) ---

    async def add_user_fact(self, user_id: str, fact: str) -> Dict[str, Any]:
        """Stores a fact about a user in the SQLite database, enforcing limits."""
        if not user_id or not fact:
            return {"error": "user_id and fact are required."}
        logger.info(f"Attempting to add user fact for {user_id}: '{fact}'")
        try:
            # Check SQLite first
            existing = await self._db_fetchone(
                "SELECT chroma_id FROM user_facts WHERE user_id = ? AND fact = ?",
                (user_id, fact),
            )
            if existing:
                logger.info(f"Fact already known for user {user_id} (SQLite).")
                return {"status": "duplicate", "user_id": user_id, "fact": fact}

            count_result = await self._db_fetchone(
                "SELECT COUNT(*) FROM user_facts WHERE user_id = ?", (user_id,)
            )
            current_count = count_result[0] if count_result else 0

            status = "added"
            deleted_chroma_id = None
            if current_count >= self.max_user_facts:
                logger.warning(
                    f"User {user_id} fact limit ({self.max_user_facts}) reached. Deleting oldest."
                )
                # Fetch the oldest fact and its chroma_id for deletion
                oldest_fact_row = await self._db_fetchone(
                    "SELECT fact, chroma_id FROM user_facts WHERE user_id = ? ORDER BY timestamp ASC LIMIT 1",
                    (user_id,),
                )
                if oldest_fact_row:
                    oldest_fact, deleted_chroma_id = oldest_fact_row
                    await self._db_execute(
                        "DELETE FROM user_facts WHERE user_id = ? AND fact = ?",
                        (user_id, oldest_fact),
                    )
                    logger.info(
                        f"Deleted oldest fact for user {user_id} from SQLite: '{oldest_fact}'"
                    )
                    status = "limit_reached"  # The limit was hit, but the new fact was still added

            # Generate chroma_id
            fact_hash = hashlib.sha1(fact.encode()).hexdigest()[:16]  # Short hash
            chroma_id = f"user-{user_id}-{fact_hash}"

            # Insert into SQLite
            await self._db_execute(
                "INSERT INTO user_facts (user_id, fact, chroma_id) VALUES (?, ?, ?)",
                (user_id, fact, chroma_id),
            )
            logger.info(f"Fact added for user {user_id} to SQLite.")

            # Add to the ChromaDB fact collection
            if self.fact_collection and self.embedding_function:
                try:
                    metadata = {
                        "user_id": user_id,
                        "type": "user",
                        "timestamp": time.time(),
                    }
                    await asyncio.to_thread(
                        self.fact_collection.add,
                        documents=[fact],
                        metadatas=[metadata],
                        ids=[chroma_id],
                    )
                    logger.info(
                        f"Fact added/updated for user {user_id} in ChromaDB (ID: {chroma_id})."
                    )

                    # Delete the oldest fact from ChromaDB if the limit was reached
                    if deleted_chroma_id:
                        logger.info(
                            f"Attempting to delete oldest fact from ChromaDB (ID: {deleted_chroma_id})."
                        )
                        await asyncio.to_thread(
                            self.fact_collection.delete, ids=[deleted_chroma_id]
                        )
                        logger.info(
                            f"Successfully deleted oldest fact from ChromaDB (ID: {deleted_chroma_id})."
                        )
                except Exception as chroma_e:
                    logger.error(
                        f"ChromaDB error adding/deleting user fact for {user_id} (ID: {chroma_id}): {chroma_e}",
                        exc_info=True,
                    )
                    # Note: the fact is still in SQLite, so ChromaDB may be inconsistent.
                    # A rollback could be considered here; for now, just log.
            else:
                logger.warning(
                    f"ChromaDB fact collection not available. Skipping embedding for user fact {user_id}."
                )

            return {"status": status, "user_id": user_id, "fact_added": fact}
        except Exception as e:
            logger.error(f"Error adding user fact for {user_id}: {e}", exc_info=True)
            return {"error": f"Database error adding user fact: {str(e)}"}

    async def get_user_facts(
        self, user_id: str, context: Optional[str] = None
    ) -> List[str]:
        """Retrieves stored facts about a user, optionally ranked by relevance to a context string."""
        if not user_id:
            logger.warning("get_user_facts called without user_id.")
            return []
        logger.info(
            f"Retrieving facts for user {user_id} (context provided: {bool(context)})"
        )
        limit = self.max_user_facts  # Use the class attribute for the limit

        try:
            if context and self.fact_collection and self.embedding_function:
                # --- Semantic Search ---
                logger.debug(
                    f"Performing semantic search for user facts (User: {user_id}, Limit: {limit})"
                )
                try:
                    # Query ChromaDB for facts relevant to the context
                    results = await asyncio.to_thread(
                        self.fact_collection.query,
                        query_texts=[context],
                        n_results=limit,
                        where={
                            # Use $and to combine multiple filter conditions
                            "$and": [{"user_id": user_id}, {"type": "user"}]
                        },
                        include=["documents"],  # Only the fact text is needed
                    )
                    logger.debug(f"ChromaDB user fact query results: {results}")

                    if results and results.get("documents") and results["documents"][0]:
                        relevant_facts = results["documents"][0]
                        logger.info(
                            f"Found {len(relevant_facts)} semantically relevant user facts for {user_id}."
                        )
                        return relevant_facts
                    else:
                        logger.info(
                            f"No semantic user facts found for {user_id} matching context."
                        )
                        return []  # Empty list if there are no semantic matches
                except Exception as chroma_e:
                    logger.error(
                        f"ChromaDB error searching user facts for {user_id}: {chroma_e}",
                        exc_info=True,
                    )
                    # Fall back to SQLite retrieval on ChromaDB error
                    logger.warning(
                        f"Falling back to SQLite retrieval for user facts {user_id} due to ChromaDB error."
                    )
                    # Fall through to the SQLite block below

            # --- SQLite Fallback / No Context ---
            # If no context was given, or ChromaDB failed/is unavailable, get the newest N facts from SQLite.
            logger.debug(
                f"Retrieving user facts from SQLite (User: {user_id}, Limit: {limit})"
            )
            rows_ordered = await self._db_fetchall(
                "SELECT fact FROM user_facts WHERE user_id = ? ORDER BY timestamp DESC LIMIT ?",
                (user_id, limit),
            )
            sqlite_facts = [row[0] for row in rows_ordered]
            logger.info(
                f"Retrieved {len(sqlite_facts)} user facts from SQLite for {user_id}."
            )
            return sqlite_facts

        except Exception as e:
            logger.error(
                f"Error retrieving user facts for {user_id}: {e}", exc_info=True
            )
            return []

    # --- General Fact Memory Methods (SQLite + Relevance) ---

    async def add_general_fact(self, fact: str) -> Dict[str, Any]:
        """Stores a general fact in the SQLite database, enforcing limits."""
        if not fact:
            return {"error": "fact is required."}
        logger.info(f"Attempting to add general fact: '{fact}'")
        try:
            # Check SQLite first
            existing = await self._db_fetchone(
                "SELECT chroma_id FROM general_facts WHERE fact = ?", (fact,)
            )
            if existing:
                logger.info(f"General fact already known (SQLite): '{fact}'")
                return {"status": "duplicate", "fact": fact}

            count_result = await self._db_fetchone(
                "SELECT COUNT(*) FROM general_facts", ()
            )
            current_count = count_result[0] if count_result else 0

            status = "added"
            deleted_chroma_id = None
            if current_count >= self.max_general_facts:
                logger.warning(
                    f"General fact limit ({self.max_general_facts}) reached. Deleting oldest."
                )
                # Fetch the oldest fact and its chroma_id for deletion
                oldest_fact_row = await self._db_fetchone(
                    "SELECT fact, chroma_id FROM general_facts ORDER BY timestamp ASC LIMIT 1",
                    (),
                )
                if oldest_fact_row:
                    oldest_fact, deleted_chroma_id = oldest_fact_row
                    await self._db_execute(
                        "DELETE FROM general_facts WHERE fact = ?", (oldest_fact,)
                    )
                    logger.info(
                        f"Deleted oldest general fact from SQLite: '{oldest_fact}'"
                    )
                    status = "limit_reached"

            # Generate chroma_id
            fact_hash = hashlib.sha1(fact.encode()).hexdigest()[:16]  # Short hash
            chroma_id = f"general-{fact_hash}"

            # Insert into SQLite
            await self._db_execute(
                "INSERT INTO general_facts (fact, chroma_id) VALUES (?, ?)",
                (fact, chroma_id),
            )
            logger.info(f"General fact added to SQLite: '{fact}'")

            # Add to the ChromaDB fact collection
            if self.fact_collection and self.embedding_function:
                try:
                    metadata = {"type": "general", "timestamp": time.time()}
                    await asyncio.to_thread(
                        self.fact_collection.add,
                        documents=[fact],
                        metadatas=[metadata],
                        ids=[chroma_id],
                    )
                    logger.info(
                        f"General fact added/updated in ChromaDB (ID: {chroma_id})."
                    )

                    # Delete the oldest fact from ChromaDB if the limit was reached
                    if deleted_chroma_id:
                        logger.info(
                            f"Attempting to delete oldest general fact from ChromaDB (ID: {deleted_chroma_id})."
                        )
                        await asyncio.to_thread(
                            self.fact_collection.delete, ids=[deleted_chroma_id]
                        )
                        logger.info(
                            f"Successfully deleted oldest general fact from ChromaDB (ID: {deleted_chroma_id})."
                        )
                except Exception as chroma_e:
                    logger.error(
                        f"ChromaDB error adding/deleting general fact (ID: {chroma_id}): {chroma_e}",
                        exc_info=True,
                    )
                    # Note: the fact is still in SQLite.
            else:
                logger.warning(
                    "ChromaDB fact collection not available. Skipping embedding for general fact."
                )
            return {"status": status, "fact_added": fact}
        except Exception as e:
            logger.error(f"Error adding general fact: {e}", exc_info=True)
            return {"error": f"Database error adding general fact: {str(e)}"}

    async def get_general_facts(
        self,
        query: Optional[str] = None,
        limit: Optional[int] = 10,
        context: Optional[str] = None,
    ) -> List[str]:
        """Retrieves stored general facts, optionally filtered by a LIKE query or ranked by semantic relevance to a context string."""
        logger.info(
            f"Retrieving general facts (query='{query}', limit={limit}, context provided: {bool(context)})"
        )
        limit = min(max(1, limit or 10), 50)  # Use the provided limit or default 10, capped at 50

        try:
            if context and self.fact_collection and self.embedding_function:
                # --- Semantic Search (prioritized when context is provided) ---
                # Note: the 'query' parameter is ignored when context drives a semantic search.
                logger.debug(
                    f"Performing semantic search for general facts (Limit: {limit})"
                )
                try:
                    results = await asyncio.to_thread(
                        self.fact_collection.query,
                        query_texts=[context],
                        n_results=limit,
                        where={"type": "general"},  # Filter by fact type
                        include=["documents"],  # Only the fact text is needed
                    )
                    logger.debug(f"ChromaDB general fact query results: {results}")

                    if results and results.get("documents") and results["documents"][0]:
                        relevant_facts = results["documents"][0]
                        logger.info(
                            f"Found {len(relevant_facts)} semantically relevant general facts."
                        )
                        return relevant_facts
                    else:
                        logger.info("No semantic general facts found matching context.")
                        return []  # Empty list if there are no semantic matches
                except Exception as chroma_e:
                    logger.error(
                        f"ChromaDB error searching general facts: {chroma_e}",
                        exc_info=True,
                    )
                    # Fall back to SQLite retrieval on ChromaDB error
                    logger.warning(
                        "Falling back to SQLite retrieval for general facts due to ChromaDB error."
                    )
                    # Fall through to the SQLite block below, honoring the original 'query' if present

            # --- SQLite Fallback / No Context / ChromaDB Error ---
            # If no context was given, or ChromaDB failed/is unavailable, get the newest N facts
            # from SQLite, applying the LIKE query if present.
            logger.debug(
                f"Retrieving general facts from SQLite (Query: '{query}', Limit: {limit})"
            )
            sql = "SELECT fact FROM general_facts"
            params = []
            if query:
                # Apply the LIKE filter only in the SQLite fallback scenario
                sql += " WHERE fact LIKE ?"
                params.append(f"%{query}%")
            sql += " ORDER BY timestamp DESC LIMIT ?"
            params.append(limit)

            rows_ordered = await self._db_fetchall(sql, tuple(params))
            sqlite_facts = [row[0] for row in rows_ordered]
            logger.info(
                f"Retrieved {len(sqlite_facts)} general facts from SQLite (Query: '{query}')."
            )
            return sqlite_facts
        except Exception as e:
            logger.error(f"Error retrieving general facts: {e}", exc_info=True)
            return []

    # --- Personality trait, interest, and goal management methods removed ---

    # --- Semantic Memory Methods (ChromaDB) ---

    async def add_message_embedding(
        self, message_id: str, text: str, metadata: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Generates an embedding and stores a message in ChromaDB."""
        if not self.semantic_collection:
            return {"error": "Semantic memory (ChromaDB) is not initialized."}
        if not text:
            return {"error": "Cannot add empty text to semantic memory."}

        logger.info(f"Adding message {message_id} to semantic memory.")
        try:
            # ChromaDB expects lists for its inputs
            await asyncio.to_thread(
                self.semantic_collection.add,
                documents=[text],
                metadatas=[metadata],
                ids=[message_id],
            )
            logger.info(f"Successfully added message {message_id} to ChromaDB.")
            return {"status": "success", "message_id": message_id}
        except Exception as e:
            logger.error(
                f"ChromaDB error adding message {message_id}: {e}", exc_info=True
            )
            return {"error": f"Semantic memory error adding message: {str(e)}"}

    async def search_semantic_memory(
        self,
        query_text: str,
        n_results: int = 5,
        filter_metadata: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Searches ChromaDB for messages semantically similar to the query text."""
        if not self.semantic_collection:
            logger.warning(
                "Search semantic memory called, but ChromaDB is not initialized."
            )
            return []
        if not query_text:
            logger.warning("Search semantic memory called with empty query text.")
            return []

        logger.info(
            f"Searching semantic memory (n_results={n_results}, filter={filter_metadata}) for query: '{query_text[:50]}...'"
        )
        try:
            # Run the query in a separate thread, as ChromaDB operations can block
            results = await asyncio.to_thread(
                self.semantic_collection.query,
                query_texts=[query_text],
                n_results=n_results,
                where=filter_metadata,  # Optional metadata filter
                include=[
                    "metadatas",
                    "documents",
                    "distances",
                ],  # Include distances for relevance
            )
            logger.debug(f"ChromaDB query results: {results}")

            # Process results into a flat list of dicts
            processed_results = []
            if results and results.get("ids") and results["ids"][0]:
                for i, doc_id in enumerate(results["ids"][0]):
                    processed_results.append(
                        {
                            "id": doc_id,
                            "document": (
                                results["documents"][0][i]
                                if results.get("documents")
                                else None
                            ),
                            "metadata": (
                                results["metadatas"][0][i]
                                if results.get("metadatas")
                                else None
                            ),
                            "distance": (
                                results["distances"][0][i]
                                if results.get("distances")
                                else None
                            ),
                        }
                    )
            logger.info(f"Found {len(processed_results)} semantic results.")
            return processed_results
        except Exception as e:
            logger.error(
                f"ChromaDB error searching memory for query '{query_text[:50]}...': {e}",
                exc_info=True,
            )
            return []

    async def delete_user_fact(
        self, user_id: str, fact_to_delete: str
    ) -> Dict[str, Any]:
        """Deletes a specific fact for a user from both SQLite and ChromaDB."""
        if not user_id or not fact_to_delete:
            return {"error": "user_id and fact_to_delete are required."}
        logger.info(f"Attempting to delete user fact for {user_id}: '{fact_to_delete}'")
        deleted_chroma_id = None
        try:
            # Check whether the fact exists and get its chroma_id
            row = await self._db_fetchone(
                "SELECT chroma_id FROM user_facts WHERE user_id = ? AND fact = ?",
                (user_id, fact_to_delete),
            )
            if not row:
                logger.warning(
                    f"Fact not found in SQLite for user {user_id}: '{fact_to_delete}'"
                )
                return {
                    "status": "not_found",
                    "user_id": user_id,
                    "fact": fact_to_delete,
                }

            deleted_chroma_id = row[0]

            # Delete from SQLite
            await self._db_execute(
                "DELETE FROM user_facts WHERE user_id = ? AND fact = ?",
                (user_id, fact_to_delete),
            )
            logger.info(
                f"Deleted fact from SQLite for user {user_id}: '{fact_to_delete}'"
            )

            # Delete from ChromaDB if a chroma_id exists
            if deleted_chroma_id and self.fact_collection:
                try:
                    logger.info(
                        f"Attempting to delete fact from ChromaDB (ID: {deleted_chroma_id})."
                    )
                    await asyncio.to_thread(
                        self.fact_collection.delete, ids=[deleted_chroma_id]
                    )
                    logger.info(
                        f"Successfully deleted fact from ChromaDB (ID: {deleted_chroma_id})."
                    )
                except Exception as chroma_e:
                    logger.error(
                        f"ChromaDB error deleting user fact ID {deleted_chroma_id}: {chroma_e}",
                        exc_info=True,
                    )
                    # Log the error, but treat the SQLite deletion as successful

            return {
                "status": "deleted",
                "user_id": user_id,
                "fact_deleted": fact_to_delete,
            }
        except Exception as e:
            logger.error(f"Error deleting user fact for {user_id}: {e}", exc_info=True)
            return {"error": f"Database error deleting user fact: {str(e)}"}

    async def delete_general_fact(self, fact_to_delete: str) -> Dict[str, Any]:
        """Deletes a specific general fact from both SQLite and ChromaDB."""
        if not fact_to_delete:
            return {"error": "fact_to_delete is required."}
        logger.info(f"Attempting to delete general fact: '{fact_to_delete}'")
        deleted_chroma_id = None
        try:
            # Check whether the fact exists and get its chroma_id
            row = await self._db_fetchone(
                "SELECT chroma_id FROM general_facts WHERE fact = ?", (fact_to_delete,)
            )
            if not row:
                logger.warning(f"General fact not found in SQLite: '{fact_to_delete}'")
                return {"status": "not_found", "fact": fact_to_delete}

            deleted_chroma_id = row[0]

            # Delete from SQLite
            await self._db_execute(
                "DELETE FROM general_facts WHERE fact = ?", (fact_to_delete,)
            )
            logger.info(f"Deleted general fact from SQLite: '{fact_to_delete}'")

            # Delete from ChromaDB if a chroma_id exists
            if deleted_chroma_id and self.fact_collection:
                try:
                    logger.info(
                        f"Attempting to delete general fact from ChromaDB (ID: {deleted_chroma_id})."
                    )
                    await asyncio.to_thread(
                        self.fact_collection.delete, ids=[deleted_chroma_id]
                    )
                    logger.info(
                        f"Successfully deleted general fact from ChromaDB (ID: {deleted_chroma_id})."
                    )
                except Exception as chroma_e:
                    logger.error(
                        f"ChromaDB error deleting general fact ID {deleted_chroma_id}: {chroma_e}",
                        exc_info=True,
                    )
                    # Log the error, but treat the SQLite deletion as successful

            return {"status": "deleted", "fact_deleted": fact_to_delete}
        except Exception as e:
            logger.error(f"Error deleting general fact: {e}", exc_info=True)
            return {"error": f"Database error deleting general fact: {str(e)}"}
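

# --- Usage sketch (illustrative; not part of the original module API) ---
# A minimal demonstration of the MemoryManager lifecycle, assuming local
# placeholder paths ("data/wheatley_memory.db" and the default chroma_path).
# Running this smoke-tests the SQLite and ChromaDB integration; note that the
# sentence-transformers model is downloaded on first use.
async def _demo():
    manager = MemoryManager(db_path="data/wheatley_memory.db")
    await manager.initialize_sqlite_database()

    # Facts are deduplicated and capped per user / globally
    print(await manager.add_user_fact("42", "Likes jazz and bad puns."))
    print(await manager.add_general_fact("The facility is still operational."))

    # With a context string, retrieval uses semantic search; without one,
    # the newest facts come back from SQLite.
    print(await manager.get_user_facts("42", context="What music does he enjoy?"))
    print(await manager.get_general_facts(query="facility"))

    # Message embeddings are stored and searched separately from facts
    await manager.add_message_embedding(
        "msg-1",
        "Honestly, the turret production line looks fine to me.",
        {"channel": "engineering"},
    )
    print(await manager.search_semantic_memory("status of turret manufacturing"))


if __name__ == "__main__":
    asyncio.run(_demo())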