bjuh
This commit is contained in:
parent
026ac3f20a
commit
625d2e4faf
1158
gurt/api.py
1158
gurt/api.py
File diff suppressed because it is too large
Load Diff
119
gurt/memory.py
119
gurt/memory.py
@ -15,5 +15,120 @@ spec.loader.exec_module(gurt_memory)
|
||||
# Import the MemoryManager class from the loaded module
|
||||
MemoryManager = gurt_memory.MemoryManager
|
||||
|
||||
# Re-export the MemoryManager class
|
||||
__all__ = ['MemoryManager']
|
||||
import logging
|
||||
from typing import List, Sequence
|
||||
|
||||
# LangChain imports for Chat History
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
|
||||
|
||||
# Relative imports
|
||||
from .config import CONTEXT_WINDOW_SIZE # Import context window size
|
||||
|
||||
# Configure logging if not already done elsewhere
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- LangChain Chat History Implementation ---
|
||||
|
||||
class GurtMessageCacheHistory(BaseChatMessageHistory):
|
||||
"""Chat message history that reads from and potentially writes to GurtCog's message cache."""
|
||||
|
||||
def __init__(self, cog: 'GurtCog', channel_id: int):
|
||||
from .cog import GurtCog # Local import for type check
|
||||
if not isinstance(cog, GurtCog):
|
||||
raise TypeError("GurtMessageCacheHistory requires a GurtCog instance.")
|
||||
self.cog = cog
|
||||
self.channel_id = channel_id
|
||||
self.key = f"channel:{channel_id}" # Example key structure
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
"""Retrieve messages from the cache and format them."""
|
||||
# Access the cache via the cog instance
|
||||
cached_messages_data = list(self.cog.message_cache['by_channel'].get(self.channel_id, []))
|
||||
|
||||
# Apply context window limit
|
||||
items: List[BaseMessage] = []
|
||||
# Take the last N messages based on CONTEXT_WINDOW_SIZE
|
||||
relevant_messages_data = cached_messages_data[-CONTEXT_WINDOW_SIZE:]
|
||||
|
||||
for msg_data in relevant_messages_data:
|
||||
role = "ai" if msg_data['author']['id'] == str(self.cog.bot.user.id) else "human"
|
||||
# Reconstruct content similar to gather_conversation_context
|
||||
content_parts = []
|
||||
author_name = msg_data['author']['display_name']
|
||||
|
||||
if msg_data.get("is_reply"):
|
||||
reply_author = msg_data.get('replied_to_author_name', 'Unknown User')
|
||||
reply_snippet = msg_data.get('replied_to_content_snippet')
|
||||
reply_snippet_short = '...'
|
||||
if isinstance(reply_snippet, str):
|
||||
reply_snippet_short = (reply_snippet[:25] + '...') if len(reply_snippet) > 28 else reply_snippet
|
||||
content_parts.append(f"{author_name} (replying to {reply_author} '{reply_snippet_short}'):")
|
||||
else:
|
||||
content_parts.append(f"{author_name}:")
|
||||
|
||||
if msg_data.get('content'):
|
||||
content_parts.append(msg_data['content'])
|
||||
|
||||
attachments = msg_data.get("attachment_descriptions", [])
|
||||
if attachments:
|
||||
attachment_str = " ".join([att['description'] for att in attachments])
|
||||
content_parts.append(f"[Attachments: {attachment_str}]") # Clearly label attachments
|
||||
|
||||
content = " ".join(content_parts).strip()
|
||||
|
||||
if role == "human":
|
||||
items.append(HumanMessage(content=content))
|
||||
elif role == "ai":
|
||||
items.append(AIMessage(content=content))
|
||||
else:
|
||||
# Handle other roles if necessary, or raise an error
|
||||
logger.warning(f"Unhandled message role '{role}' in GurtMessageCacheHistory for channel {self.channel_id}")
|
||||
|
||||
return items
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""
|
||||
Add a message to the history.
|
||||
|
||||
Note: This implementation assumes the GurtCog's message listeners
|
||||
are already populating the cache. This method might just log
|
||||
or could potentially duplicate additions if not careful.
|
||||
For now, we make it a no-op and rely on the cog's caching.
|
||||
"""
|
||||
logger.debug(f"GurtMessageCacheHistory.add_message called for channel {self.channel_id}, but is currently a no-op. Cache is populated by GurtCog listeners.")
|
||||
# If we needed to write back:
|
||||
# self._add_message_to_cache(message)
|
||||
pass
|
||||
|
||||
# Optional: Implement add_user_message, add_ai_message if needed
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear history from the cache for this channel."""
|
||||
logger.warning(f"GurtMessageCacheHistory.clear() called for channel {self.channel_id}. Clearing cache entry.")
|
||||
if self.channel_id in self.cog.message_cache['by_channel']:
|
||||
del self.cog.message_cache['by_channel'][self.channel_id]
|
||||
# Potentially clear other related caches if necessary
|
||||
|
||||
# Factory function for LangchainAgent
|
||||
def get_gurt_session_history(session_id: str, cog: 'GurtCog') -> BaseChatMessageHistory:
|
||||
"""
|
||||
Factory function to get a chat history instance for a given session ID.
|
||||
The session_id is expected to be the Discord channel ID.
|
||||
"""
|
||||
try:
|
||||
channel_id = int(session_id)
|
||||
return GurtMessageCacheHistory(cog=cog, channel_id=channel_id)
|
||||
except ValueError:
|
||||
logger.error(f"Invalid session_id for Gurt chat history: '{session_id}'. Expected integer channel ID.")
|
||||
# Return an in-memory history as a fallback? Or raise error?
|
||||
# from langchain_community.chat_message_histories import ChatMessageHistory
|
||||
# return ChatMessageHistory() # Fallback to basic in-memory
|
||||
raise ValueError(f"Invalid session_id: {session_id}")
|
||||
except TypeError as e:
|
||||
logger.error(f"TypeError creating GurtMessageCacheHistory: {e}. Ensure 'cog' is passed correctly.")
|
||||
raise
|
||||
|
||||
# Re-export the MemoryManager class AND the history components
|
||||
__all__ = ['MemoryManager', 'GurtMessageCacheHistory', 'get_gurt_session_history']
|
||||
|
Loading…
x
Reference in New Issue
Block a user