Cache serialized LangChain messages in GurtMessageCacheHistory and stop double-caching bot responses in the listener
parent e2ea584e06
commit 38856d2798
@@ -296,7 +296,7 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name:
         chat_history=chat_history_factory, # Pass the factory function
         prompt=prompt_template, # Pass the constructed prompt template
         model_kwargs=model_kwargs,
-        verbose=True, # Enable for debugging agent steps
+        # verbose=True, # Enable for debugging agent steps
         # handle_parsing_errors=True, # Let the agent try to recover from parsing errors
         # max_iterations=10, # Limit tool execution loops
     )
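Reviewer note: `chat_history=chat_history_factory` follows LangChain's history-factory pattern, where a callable maps a session (here presumably the Discord channel) to a `BaseChatMessageHistory` instance; this commit's `GurtMessageCacheHistory` in `gurt/memory.py` below plays that role. A minimal sketch of the pattern, assuming a `session_id`-style signature (the wrapper's exact contract isn't shown in this diff) and substituting the stock `InMemoryChatMessageHistory`:

```python
# Hypothetical sketch; the real factory returns GurtMessageCacheHistory
# objects backed by the cog's shared message cache.
from langchain_core.chat_history import InMemoryChatMessageHistory

_histories: dict[str, InMemoryChatMessageHistory] = {}

def chat_history_factory(session_id: str) -> InMemoryChatMessageHistory:
    # Lazily create one history object per channel/session id.
    return _histories.setdefault(session_id, InMemoryChatMessageHistory())

history = chat_history_factory("channel-123")
history.add_user_message("hi gurt")
print([m.content for m in history.messages])  # ['hi gurt']
```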
@@ -469,13 +469,16 @@ async def on_message_listener(cog: 'GurtCog', message: discord.Message):
             # Send message with reference if applicable
             sent_msg = await original_message.channel.send(response_text, reference=message_reference, mention_author=False) # mention_author=False is usually preferred for bots
             sent_any_message = True
-            # Cache this bot response
-            bot_response_cache_entry = format_message(cog, sent_msg) # Pass cog
-            cog.message_cache['by_channel'][channel_id].append(bot_response_cache_entry)
-            cog.message_cache['global_recent'].append(bot_response_cache_entry)
-            cog.bot_last_spoke[channel_id] = time.time()
-            # Track participation topic
-            identified_topics = identify_conversation_topics(cog, [bot_response_cache_entry]) # Pass cog
+            # Cache this bot response - NOTE: Commented out as LangchainAgent should handle history via add_message
+            # bot_response_cache_entry = format_message(cog, sent_msg) # Pass cog
+            # cog.message_cache['by_channel'][channel_id].append(bot_response_cache_entry)
+            # cog.message_cache['global_recent'].append(bot_response_cache_entry)
+            cog.bot_last_spoke[channel_id] = time.time() # Keep track of when bot last spoke
+            # Track participation topic - Requires the sent message content. Let's get it directly.
+            # We need the content to identify topics. Since we don't cache the formatted message anymore,
+            # let's create a minimal dict for topic identification.
+            bot_response_for_topic = {"content": sent_msg.content, "author": {"id": str(cog.bot.user.id)}}
+            identified_topics = identify_conversation_topics(cog, [bot_response_for_topic]) # Pass cog
             if identified_topics:
                 topic = identified_topics[0]['topic'].lower().strip()
                 cog.gurt_participation_topics[topic] += 1
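A note on the last context line: `cog.gurt_participation_topics[topic] += 1` avoids a `KeyError` only if that attribute is a `collections.Counter` or `defaultdict(int)`; its initialization isn't shown in this diff, so the `Counter` below is an assumption, and the toy topic extractor merely stands in for the repo's `identify_conversation_topics` (which also takes the cog):

```python
import collections

gurt_participation_topics = collections.Counter()  # assumed Counter/defaultdict(int)

# Shape of the minimal dict built in this hunk: only 'content' plus an
# author id is needed for topic identification and attribution.
bot_response_for_topic = {"content": "talking about cache bugs", "author": {"id": "1234"}}

def identify_conversation_topics(messages):
    # Toy stand-in: treat longer words as candidate topics.
    return [{"topic": w} for w in messages[0]["content"].split() if len(w) > 4]

identified_topics = identify_conversation_topics([bot_response_for_topic])
if identified_topics:
    topic = identified_topics[0]["topic"].lower().strip()
    gurt_participation_topics[topic] += 1  # no KeyError on first sight of a topic

print(gurt_participation_topics)  # Counter({'talking': 1})
```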
gurt/memory.py (158 changed lines)
@@ -2,7 +2,8 @@
 # Use a direct import path that doesn't rely on package structure
 import os
 import importlib.util
-from typing import TYPE_CHECKING # Import TYPE_CHECKING
+from typing import TYPE_CHECKING, List, Sequence, Dict, Any # Import TYPE_CHECKING and other types
+import collections # Import collections for deque
 
 if TYPE_CHECKING:
     from .cog import GurtCog # Use relative import for type hinting
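The new `collections` import backs the bounded per-channel deque used later in this file. One behavior worth keeping in mind when reviewing the `maxlen=CONTEXT_WINDOW_SIZE * 2` choice: a bounded deque silently evicts from the opposite end on `append`, which is exactly what lets the cache double as a sliding context window:

```python
import collections

d = collections.deque(maxlen=3)
for i in range(5):
    d.append(i)
print(d)  # deque([2, 3, 4], maxlen=3) -- oldest entries are evicted automatically
```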
@@ -24,7 +25,10 @@ from typing import List, Sequence
 
 # LangChain imports for Chat History
 from langchain_core.chat_history import BaseChatMessageHistory
-from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
+# Import specific message types needed
+from langchain_core.messages import (
+    BaseMessage, AIMessage, HumanMessage, SystemMessage, ToolMessage
+)
 
 # Relative imports
 from .config import CONTEXT_WINDOW_SIZE # Import context window size
@@ -50,72 +54,124 @@ class GurtMessageCacheHistory(BaseChatMessageHistory):
 
     @property
     def messages(self) -> List[BaseMessage]: # type: ignore
-        """Retrieve messages from the cache and format them."""
+        """Retrieve messages from the cache and reconstruct LangChain messages."""
         # Access the cache via the cog instance
-        cached_messages_data = list(self.cog.message_cache['by_channel'].get(self.channel_id, []))
+        # Ensure the cache is initialized as a deque
+        channel_cache = self.cog.message_cache['by_channel'].setdefault(
+            self.channel_id, collections.deque(maxlen=CONTEXT_WINDOW_SIZE * 2) # Use a larger maxlen for safety?
+        )
+        cached_messages_data = list(channel_cache) # Get a list copy
 
-        # Apply context window limit
         items: List[BaseMessage] = []
-        # Take the last N messages based on CONTEXT_WINDOW_SIZE
-        relevant_messages_data = cached_messages_data[-CONTEXT_WINDOW_SIZE:]
+        # Apply context window limit (consider if the limit should apply differently to LC messages vs formatted)
+        # For now, apply simple limit to the combined list
+        relevant_messages_data = cached_messages_data[-(CONTEXT_WINDOW_SIZE * 2):] # Use the potentially larger limit
 
         for msg_data in relevant_messages_data:
-            role = "ai" if msg_data['author']['id'] == str(self.cog.bot.user.id) else "human"
-            # Reconstruct content similar to gather_conversation_context
-            content_parts = []
-            author_name = msg_data['author']['display_name']
-
-            if msg_data.get("is_reply"):
-                reply_author = msg_data.get('replied_to_author_name', 'Unknown User')
-                reply_snippet = msg_data.get('replied_to_content_snippet')
-                reply_snippet_short = '...'
-                if isinstance(reply_snippet, str):
-                    reply_snippet_short = (reply_snippet[:25] + '...') if len(reply_snippet) > 28 else reply_snippet
-                content_parts.append(f"{author_name} (replying to {reply_author} '{reply_snippet_short}'):")
+            if isinstance(msg_data, dict) and msg_data.get('_is_lc_message_'):
+                # Reconstruct LangChain message from serialized dict
+                lc_type = msg_data.get('lc_type')
+                content = msg_data.get('content', '')
+                additional_kwargs = msg_data.get('additional_kwargs', {})
+                tool_calls = msg_data.get('tool_calls') # For AIMessage
+                tool_call_id = msg_data.get('tool_call_id') # For ToolMessage
+
+                try:
+                    if lc_type == 'HumanMessage':
+                        items.append(HumanMessage(content=content, additional_kwargs=additional_kwargs))
+                    elif lc_type == 'AIMessage':
+                        # Reconstruct AIMessage, potentially with tool_calls
+                        ai_msg = AIMessage(content=content, additional_kwargs=additional_kwargs)
+                        if tool_calls:
+                            # Ensure tool_calls are in the correct format if needed (e.g., list of dicts)
+                            # Assuming they were stored correctly from message.dict()
+                            ai_msg.tool_calls = tool_calls
+                        items.append(ai_msg)
+                    elif lc_type == 'ToolMessage':
+                        # ToolMessage needs content and tool_call_id
+                        if tool_call_id:
+                            items.append(ToolMessage(content=content, tool_call_id=tool_call_id, additional_kwargs=additional_kwargs))
+                        else:
+                            logger.warning(f"Skipping ToolMessage reconstruction, missing tool_call_id: {msg_data}")
+                    elif lc_type == 'SystemMessage': # Should not happen via add_message, but handle defensively
+                        items.append(SystemMessage(content=content, additional_kwargs=additional_kwargs))
+                    # Add other types if needed (FunctionMessage?)
+                    else:
+                        logger.warning(f"Unhandled LangChain message type '{lc_type}' during reconstruction.")
+
+                except Exception as recon_e:
+                    logger.error(f"Error reconstructing LangChain message type '{lc_type}': {recon_e}\nData: {msg_data}", exc_info=True)
+
+            elif isinstance(msg_data, dict) and not msg_data.get('_is_lc_message_'):
+                # Existing logic for reconstructing from formatted user/bot messages
+                # This assumes the agent doesn't add Human/AI messages that overlap with these
+                role = "ai" if msg_data.get('author', {}).get('id') == str(self.cog.bot.user.id) else "human"
+                # Reconstruct content similar to original logic (simplified)
+                content_parts = []
+                author_name = msg_data.get('author', {}).get('display_name', 'Unknown')
+
+                # Basic content reconstruction
+                content = msg_data.get('content', '')
+                attachments = msg_data.get("attachment_descriptions", [])
+                if attachments:
+                    attachment_str = " ".join([att['description'] for att in attachments])
+                    content += f" [Attachments: {attachment_str}]" # Append attachment info
+
+                # Combine author and content for the LangChain message
+                # NOTE: This might differ from how the agent expects input if it relies on raw content.
+                # Consider if just the content string is better here.
+                # Let's stick to the previous format for now.
+                full_content = f"{author_name}: {content}"
+
+                if role == "human":
+                    items.append(HumanMessage(content=full_content))
+                elif role == "ai":
+                    # This should only be the *final* AI response text, without tool calls
+                    items.append(AIMessage(content=full_content))
+                else:
+                    logger.warning(f"Unhandled message role '{role}' in GurtMessageCacheHistory (formatted msg) for channel {self.channel_id}")
             else:
-                content_parts.append(f"{author_name}:")
-
-            if msg_data.get('content'):
-                content_parts.append(msg_data['content'])
-
-            attachments = msg_data.get("attachment_descriptions", [])
-            if attachments:
-                attachment_str = " ".join([att['description'] for att in attachments])
-                content_parts.append(f"[Attachments: {attachment_str}]") # Clearly label attachments
-
-            content = " ".join(content_parts).strip()
-
-            if role == "human":
-                items.append(HumanMessage(content=content))
-            elif role == "ai":
-                items.append(AIMessage(content=content))
-            else:
-                # Handle other roles if necessary, or raise an error
-                logger.warning(f"Unhandled message role '{role}' in GurtMessageCacheHistory for channel {self.channel_id}")
+                logger.warning(f"Skipping unrecognized item in message cache: {type(msg_data)}")
 
         return items
 
     def add_message(self, message: BaseMessage) -> None:
-        """
-        Add a message to the history.
-
-        Note: This implementation assumes the GurtCog's message listeners
-        are already populating the cache. This method might just log
-        or could potentially duplicate additions if not careful.
-        For now, we make it a no-op and rely on the cog's caching.
-        """
-        logger.debug(f"GurtMessageCacheHistory.add_message called for channel {self.channel_id}, but is currently a no-op. Cache is populated by GurtCog listeners.")
-        # If we needed to write back:
-        # self._add_message_to_cache(message)
-        pass
-
-    # Optional: Implement add_user_message, add_ai_message if needed
+        """Add a LangChain BaseMessage to the history cache."""
+        try:
+            # Serialize the message object to a dictionary using pydantic's dict()
+            message_dict = message.dict()
+            # Explicitly store the LangChain class name for reconstruction
+            message_dict['lc_type'] = message.__class__.__name__
+            # Add our flag to distinguish it during retrieval
+            message_dict['_is_lc_message_'] = True
+
+            # Ensure tool_calls and tool_call_id are preserved if they exist
+            # (message.dict() should handle this, but double-check if issues arise)
+            # Example explicit checks (might be redundant):
+            # if isinstance(message, AIMessage) and hasattr(message, 'tool_calls') and message.tool_calls:
+            #     message_dict['tool_calls'] = message.tool_calls
+            # elif isinstance(message, ToolMessage) and hasattr(message, 'tool_call_id'):
+            #     message_dict['tool_call_id'] = message.tool_call_id
+
+            # Access the cache via the cog instance, ensuring it's a deque
+            channel_cache = self.cog.message_cache['by_channel'].setdefault(
+                self.channel_id, collections.deque(maxlen=CONTEXT_WINDOW_SIZE * 2) # Use consistent maxlen
+            )
+            channel_cache.append(message_dict)
+            logger.debug(f"Added LangChain message ({message.__class__.__name__}) to cache for channel {self.channel_id}")
+
+        except Exception as e:
+            logger.error(f"Error adding LangChain message to cache for channel {self.channel_id}: {e}", exc_info=True)
+
+    # Optional: Implement add_user_message, add_ai_message if needed (BaseChatMessageHistory provides defaults)
 
     def clear(self) -> None:
         """Clear history from the cache for this channel."""
-        logger.warning(f"GurtMessageCacheHistory.clear() called for channel {self.channel_id}. Clearing cache entry.")
+        logger.warning(f"GurtMessageCacheHistory.clear() called for channel {self.channel_id}. Clearing cache deque.")
         if self.channel_id in self.cog.message_cache['by_channel']:
-            del self.cog.message_cache['by_channel'][self.channel_id]
+            # Clear the deque instead of deleting the key, to keep the deque object
+            self.cog.message_cache['by_channel'][self.channel_id].clear()
         # Potentially clear other related caches if necessary
 
 # Factory function for LangchainAgent
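Taken together, `add_message` and the `messages` property form a serialize/reconstruct round trip: LangChain messages go into the shared cache as plain dicts tagged with `lc_type` and `_is_lc_message_`, and come back out as typed objects, tool calls included. A standalone sketch of that round trip (the tagged-dict keys come from this diff; the deque, the `get_weather` tool call, and the helper names are illustrative, and `langchain_core` is assumed to be installed):

```python
import collections

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

cache = collections.deque(maxlen=40)  # stands in for cog.message_cache['by_channel'][channel_id]

def serialize(message):
    d = message.dict()  # pydantic serialization, as in add_message()
    d['lc_type'] = message.__class__.__name__
    d['_is_lc_message_'] = True
    return d

def reconstruct(d):
    lc_type, content = d.get('lc_type'), d.get('content', '')
    kwargs = d.get('additional_kwargs', {})
    if lc_type == 'HumanMessage':
        return HumanMessage(content=content, additional_kwargs=kwargs)
    if lc_type == 'AIMessage':
        msg = AIMessage(content=content, additional_kwargs=kwargs)
        if d.get('tool_calls'):
            msg.tool_calls = d['tool_calls']  # restore tool calls, as the property does
        return msg
    if lc_type == 'ToolMessage':
        return ToolMessage(content=content, tool_call_id=d['tool_call_id'])
    return None

cache.append(serialize(HumanMessage(content="what's the weather?")))
cache.append(serialize(AIMessage(content="", tool_calls=[
    {"name": "get_weather", "args": {"city": "Berlin"}, "id": "call_1"},
])))
cache.append(serialize(ToolMessage(content="12C, cloudy", tool_call_id="call_1")))

rebuilt = [reconstruct(d) for d in cache]
assert rebuilt[1].tool_calls[0]["id"] == "call_1"
assert rebuilt[2].tool_call_id == "call_1"
```

Because reconstruction restores `tool_calls` and `tool_call_id`, the agent can replay full tool-use turns from history, which the old formatted-string cache entries could not represent.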