This commit is contained in:
Slipstream 2025-04-28 23:02:45 -06:00
parent 38856d2798
commit b10f11ce51
Signed by: slipstream
GPG Key ID: 13E498CE010AC6FD
2 changed files with 66 additions and 49 deletions

View File

@ -232,46 +232,36 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name:
])
# --- 2. Prepare Tools ---
# Create wrappers for tools to include the 'cog' instance
# Collect decorated tool functions from the TOOL_MAPPING
# The @tool decorator handles schema generation and description.
# We still need to bind the 'cog' instance to each tool call.
prepared_tools = []
for tool_name, tool_func in TOOL_MAPPING.items():
# Check if the function is actually decorated (optional, but good practice)
# The @tool decorator adds attributes like .name, .description, .args_schema
if not hasattr(tool_func, 'is_lc_tool') and not hasattr(tool_func, 'name'):
# Skip functions not decorated with @tool, or internal/experimental ones
if tool_name not in ["create_new_tool", "execute_internal_command", "_check_command_safety"]:
logger.warning(f"Tool function '{tool_name}' in TOOL_MAPPING is missing the @tool decorator. Skipping.")
continue
try:
# Define a wrapper function that captures cog and calls the original tool
# This wrapper will have the correct signature for Langchain to inspect
def create_wrapper(func_to_wrap, current_cog):
# Define the actual wrapper that Langchain will call
# It accepts the arguments Langchain extracts based on the original function's signature
def tool_wrapper(*args, **kwargs):
# Call the original tool function with cog as the first argument
# Pass through the arguments received from Langchain
return func_to_wrap(current_cog, *args, **kwargs)
# Copy metadata for Langchain schema generation from the original function
functools.update_wrapper(tool_wrapper, func_to_wrap)
# Ensure the wrapper has the correct name for the tool mapping/introspection
# Note: Langchain might use the name attribute if wrapped in LangchainTool later
tool_wrapper.__name__ = func_to_wrap.__name__
return tool_wrapper
# Create a partial function that includes the 'cog' instance
# Langchain agents can often handle partial functions directly.
# The agent will call this partial, which in turn calls the original tool with 'cog'.
# The @tool decorator ensures Langchain gets the correct signature from the *original* func.
tool_with_cog = functools.partial(tool_func, cog)
# Create the specific wrapper for this tool, capturing the current cog
wrapped_tool_func = create_wrapper(tool_func, cog)
# Copy essential attributes from the original decorated tool function to the partial
# This helps ensure the agent framework recognizes it correctly.
functools.update_wrapper(tool_with_cog, tool_func)
# Pass the wrapped function directly to the agent's tool list.
# LangchainAgent should be able to handle functions directly.
# We explicitly provide the name in the LangchainTool wrapper below for clarity if needed.
# For now, let's try passing the function itself, relying on Langchain's introspection.
# If issues persist, wrap in LangchainTool:
prepared_tools.append(LangchainTool(
name=tool_name, # Use the key from TOOL_MAPPING as the definitive name
func=wrapped_tool_func, # Pass the wrapper function
description=tool_func.__doc__ or f"Executes the {tool_name} tool.", # Use original docstring
# LangChain should infer args_schema from the *original* tool_func's type hints
# because functools.update_wrapper copies the signature.
))
# Simpler alternative (try if the above fails):
# prepared_tools.append(wrapped_tool_func)
# Add the partial function (which now includes cog) to the list
prepared_tools.append(tool_with_cog)
logger.debug(f"Prepared tool '{tool_name}' with cog instance bound.")
except Exception as tool_prep_e:
logger.error(f"Error preparing tool '{tool_name}': {tool_prep_e}", exc_info=True)
logger.error(f"Error preparing tool '{tool_name}' with functools.partial: {tool_prep_e}", exc_info=True)
# Optionally skip this tool
# --- 3. Prepare Chat History Factory ---

View File

@ -17,6 +17,7 @@ from tavily import TavilyClient
import docker
import aiodocker # Use aiodocker for async operations
from asteval import Interpreter # Added for calculate tool
from langchain_core.tools import tool # Import the tool decorator
# Relative imports from within the gurt package and parent
from .memory import MemoryManager # Import from local memory.py
@ -37,6 +38,7 @@ from .config import (
# to access things like cog.bot, cog.session, cog.current_channel, cog.memory_manager etc.
# We will add 'cog' as the first parameter to each.
@tool
async def get_recent_messages(cog: commands.Cog, limit: int, channel_id: Optional[str] = None) -> Dict[str, Any]:
"""
Retrieves the most recent messages from a specified Discord channel or the current channel.
@ -71,6 +73,7 @@ async def get_recent_messages(cog: commands.Cog, limit: int, channel_id: Optiona
except Exception as e:
return {"error": f"Error retrieving messages: {str(e)}", "timestamp": datetime.datetime.now().isoformat()}
@tool
async def search_user_messages(cog: commands.Cog, user_id: str, limit: int, channel_id: Optional[str] = None) -> Dict[str, Any]:
"""
Searches recent channel history for messages sent by a specific user.
@ -115,6 +118,7 @@ async def search_user_messages(cog: commands.Cog, user_id: str, limit: int, chan
except Exception as e:
return {"error": f"Error searching user messages: {str(e)}", "timestamp": datetime.datetime.now().isoformat()}
@tool
async def search_messages_by_content(cog: commands.Cog, search_term: str, limit: int, channel_id: Optional[str] = None) -> Dict[str, Any]:
"""
Searches recent channel history for messages containing specific text content (case-insensitive).
@ -154,6 +158,7 @@ async def search_messages_by_content(cog: commands.Cog, search_term: str, limit:
except Exception as e:
return {"error": f"Error searching messages by content: {str(e)}", "timestamp": datetime.datetime.now().isoformat()}
@tool
async def get_channel_info(cog: commands.Cog, channel_id: Optional[str] = None) -> Dict[str, Any]:
"""
Retrieves detailed information about a specified Discord channel or the current channel.
@ -192,6 +197,7 @@ async def get_channel_info(cog: commands.Cog, channel_id: Optional[str] = None)
except Exception as e:
return {"error": f"Error getting channel info: {str(e)}", "timestamp": datetime.datetime.now().isoformat()}
@tool
async def get_conversation_context(cog: commands.Cog, message_count: int, channel_id: Optional[str] = None) -> Dict[str, Any]:
"""
Retrieves recent messages to provide context for the ongoing conversation in a channel.
@ -231,6 +237,7 @@ async def get_conversation_context(cog: commands.Cog, message_count: int, channe
except Exception as e:
return {"error": f"Error getting conversation context: {str(e)}"}
@tool
async def get_thread_context(cog: commands.Cog, thread_id: str, message_count: int) -> Dict[str, Any]:
"""
Retrieves recent messages from a specific Discord thread to provide conversation context.
@ -267,6 +274,7 @@ async def get_thread_context(cog: commands.Cog, thread_id: str, message_count: i
except Exception as e:
return {"error": f"Error getting thread context: {str(e)}"}
@tool
async def get_user_interaction_history(cog: commands.Cog, user_id_1: str, limit: int, user_id_2: Optional[str] = None) -> Dict[str, Any]:
"""
Retrieves the recent message history involving interactions (replies, mentions) between two users.
@ -315,6 +323,7 @@ async def get_user_interaction_history(cog: commands.Cog, user_id_1: str, limit:
except Exception as e:
return {"error": f"Error getting user interaction history: {str(e)}"}
@tool
async def get_conversation_summary(cog: commands.Cog, channel_id: Optional[str] = None, message_limit: int = 25) -> Dict[str, Any]:
"""
Generates and returns a concise summary of the recent conversation in a specified channel or the current channel.
@ -399,6 +408,7 @@ async def get_conversation_summary(cog: commands.Cog, channel_id: Optional[str]
traceback.print_exc()
return {"error": error_msg}
@tool
async def get_message_context(cog: commands.Cog, message_id: str, before_count: int = 5, after_count: int = 5) -> Dict[str, Any]:
"""
Retrieves messages immediately before and after a specific message ID within the current channel.
@ -440,6 +450,7 @@ async def get_message_context(cog: commands.Cog, message_id: str, before_count:
except Exception as e:
return {"error": f"Error getting message context: {str(e)}"}
@tool
async def web_search(cog: commands.Cog, query: str, search_depth: str = TAVILY_DEFAULT_SEARCH_DEPTH, max_results: int = TAVILY_DEFAULT_MAX_RESULTS, topic: str = "general", include_domains: Optional[List[str]] = None, exclude_domains: Optional[List[str]] = None, include_answer: bool = True, include_raw_content: bool = False, include_images: bool = False) -> Dict[str, Any]:
"""
Performs a web search using the Tavily API based on the provided query and parameters.
@ -520,6 +531,7 @@ async def web_search(cog: commands.Cog, query: str, search_depth: str = TAVILY_D
print(error_message)
return {"error": error_message, "timestamp": datetime.datetime.now().isoformat()}
@tool
async def remember_user_fact(cog: commands.Cog, user_id: str, fact: str) -> Dict[str, Any]:
"""
Stores a specific fact about a given user in the bot's long-term memory.
@ -545,6 +557,7 @@ async def remember_user_fact(cog: commands.Cog, user_id: str, fact: str) -> Dict
print(error_message); traceback.print_exc()
return {"error": error_message}
@tool
async def get_user_facts(cog: commands.Cog, user_id: str) -> Dict[str, Any]:
"""
Retrieves all stored facts associated with a specific user from the bot's long-term memory.
@ -566,6 +579,7 @@ async def get_user_facts(cog: commands.Cog, user_id: str) -> Dict[str, Any]:
print(error_message); traceback.print_exc()
return {"error": error_message}
@tool
async def remember_general_fact(cog: commands.Cog, fact: str) -> Dict[str, Any]:
"""
Stores a general fact (not specific to any user) in the bot's long-term memory.
@ -590,6 +604,7 @@ async def remember_general_fact(cog: commands.Cog, fact: str) -> Dict[str, Any]:
print(error_message); traceback.print_exc()
return {"error": error_message}
@tool
async def get_general_facts(cog: commands.Cog, query: Optional[str] = None, limit: Optional[int] = 10) -> Dict[str, Any]:
"""
Retrieves general facts from the bot's long-term memory. Can optionally filter by a query string.
@ -612,6 +627,7 @@ async def get_general_facts(cog: commands.Cog, query: Optional[str] = None, limi
print(error_message); traceback.print_exc()
return {"error": error_message}
@tool
async def timeout_user(cog: commands.Cog, user_id: str, duration_minutes: int, reason: Optional[str] = None) -> Dict[str, Any]:
"""
Applies a timeout to a specified user within the current server (guild).
@ -653,6 +669,7 @@ async def timeout_user(cog: commands.Cog, user_id: str, duration_minutes: int, r
except discord.HTTPException as e: print(f"API error timeout {user_id}: {e}"); return {"error": f"API error timeout {user_id}: {e}"}
except Exception as e: print(f"Unexpected error timeout {user_id}: {e}"); traceback.print_exc(); return {"error": f"Unexpected error timeout {user_id}: {str(e)}"}
@tool
async def remove_timeout(cog: commands.Cog, user_id: str, reason: Optional[str] = None) -> Dict[str, Any]:
"""
Removes an active timeout from a specified user within the current server (guild).
@ -689,16 +706,21 @@ async def remove_timeout(cog: commands.Cog, user_id: str, reason: Optional[str]
except discord.HTTPException as e: print(f"API error remove timeout {user_id}: {e}"); return {"error": f"API error remove timeout {user_id}: {e}"}
except Exception as e: print(f"Unexpected error remove timeout {user_id}: {e}"); traceback.print_exc(); return {"error": f"Unexpected error remove timeout {user_id}: {str(e)}"}
def calculate(cog: commands.Cog, expression: str) -> Dict[str, Any]:
@tool
def calculate(cog: commands.Cog, expression: str) -> str:
"""
Evaluates a mathematical expression using the asteval library. Supports common math functions.
Evaluates a mathematical expression using the asteval library. Supports common math functions. Returns the result as a string.
Args:
cog: The GurtCog instance (automatically passed).
expression: The mathematical expression string to evaluate (e.g., "2 * (pi + 1)").
Args:
cog: The GurtCog instance (automatically passed).
expression: The mathematical expression string to evaluate (e.g., "2 * (pi + 1)").
Returns:
A dictionary containing the original expression, the calculated result string, and status 'success', or an error dictionary.
The calculated result as a string, or an error message string if calculation fails.
"""
print(f"Calculating expression: {expression}")
aeval = Interpreter()
@ -708,20 +730,17 @@ def calculate(cog: commands.Cog, expression: str) -> Dict[str, Any]:
error_details = '; '.join(err.get_error() for err in aeval.error)
error_message = f"Calculation error: {error_details}"
print(error_message)
return {"error": error_message, "expression": expression}
if isinstance(result, (int, float, complex)): result_str = str(result)
else: result_str = repr(result) # Fallback
return f"Error: {error_message}" # Return error as string
result_str = str(result) # Convert result to string
print(f"Calculation result: {result_str}")
# Return only the result string on success, as expected by some agent frameworks
return result_str
return result_str # Return result as string
except Exception as e:
error_message = f"Unexpected error during calculation: {str(e)}"
print(error_message); traceback.print_exc()
# Return error message as string on failure
return {"error": error_message, "expression": expression}
return f"Error: {error_message}" # Return error as string
@tool
async def run_python_code(cog: commands.Cog, code: str) -> Dict[str, Any]:
"""
Executes a provided Python code snippet remotely using the Piston API.
@ -767,6 +786,7 @@ async def run_python_code(cog: commands.Cog, code: str) -> Dict[str, Any]:
except aiohttp.ClientError as e: print(f"Piston network error: {e}"); return {"error": f"Network error connecting to Piston: {str(e)}"}
except Exception as e: print(f"Unexpected Piston error: {e}"); traceback.print_exc(); return {"error": f"Unexpected error during Python execution: {str(e)}"}
@tool
async def create_poll(cog: commands.Cog, question: str, options: List[str]) -> Dict[str, Any]:
"""
Creates a simple poll message in the current channel with numbered reaction options.
@ -801,7 +821,7 @@ async def create_poll(cog: commands.Cog, question: str, options: List[str]) -> D
except discord.HTTPException as e: print(f"Poll API error: {e}"); return {"error": f"API error creating poll: {e}"}
except Exception as e: print(f"Poll unexpected error: {e}"); traceback.print_exc(); return {"error": f"Unexpected error creating poll: {str(e)}"}
# Helper function to convert memory string (e.g., "128m") to bytes
# Helper function to convert memory string (e.g., "128m") to bytes (Not a tool)
def parse_mem_limit(mem_limit_str: str) -> Optional[int]:
if not mem_limit_str: return None
mem_limit_str = mem_limit_str.lower()
@ -859,6 +879,7 @@ async def _check_command_safety(cog: commands.Cog, command: str) -> Dict[str, An
print(f"AI Safety Check Error: Response was {safety_response}")
return {"safe": False, "reason": error_msg}
@tool
async def run_terminal_command(cog: commands.Cog, command: str) -> Dict[str, Any]:
"""
Executes a shell command within an isolated, network-disabled Docker container.
@ -979,6 +1000,8 @@ async def run_terminal_command(cog: commands.Cog, command: str) -> Dict[str, Any
# Ensure the client connection is closed
if client:
await client.close()
@tool
async def get_user_id(cog: commands.Cog, user_name: str) -> Dict[str, Any]:
"""
Finds the Discord User ID associated with a given username or display name.
@ -1029,7 +1052,7 @@ async def get_user_id(cog: commands.Cog, user_name: str) -> Dict[str, Any]:
print(f"User '{user_name}' not found in guild '{guild.name}'.")
return {"error": f"User '{user_name}' not found in this server.", "user_name": user_name}
# NOT decorating execute_internal_command as it's marked unsafe for general agent use
async def execute_internal_command(cog: commands.Cog, command: str, timeout_seconds: int = 60) -> Dict[str, Any]:
"""
Executes a shell command directly on the host machine where the bot is running.
@ -1097,6 +1120,7 @@ async def execute_internal_command(cog: commands.Cog, command: str, timeout_seco
traceback.print_exc()
return {"error": error_message, "command": command, "status": "error"}
@tool
async def extract_web_content(cog: commands.Cog, urls: Union[str, List[str]], extract_depth: str = "basic", include_images: bool = False) -> Dict[str, Any]:
"""
Extracts the main textual content and optionally images from one or more web URLs using the Tavily API.
@ -1141,6 +1165,7 @@ async def extract_web_content(cog: commands.Cog, urls: Union[str, List[str]], ex
print(error_message)
return {"error": error_message, "timestamp": datetime.datetime.now().isoformat()}
@tool
async def read_file_content(cog: commands.Cog, file_path: str) -> Dict[str, Any]:
"""
Reads the content of a specified file located on the bot's host machine.
@ -1227,7 +1252,7 @@ async def read_file_content(cog: commands.Cog, file_path: str) -> Dict[str, Any]
traceback.print_exc()
return {"error": error_message, "file_path": file_path}
# --- Meta Tool: Create New Tool ---
# --- Meta Tool: Create New Tool --- (Not decorating as it's experimental/dangerous)
# WARNING: HIGHLY EXPERIMENTAL AND DANGEROUS. Allows AI to write and load code.
async def create_new_tool(cog: commands.Cog, tool_name: str, description: str, parameters_json: str, returns_description: str) -> Dict[str, Any]:
"""
@ -1473,6 +1498,8 @@ async def create_new_tool(cog: commands.Cog, tool_name: str, description: str, p
# --- Tool Mapping ---
# This dictionary maps tool names (used in the AI prompt) to their implementation functions.
# The agent should discover tools via the @tool decorator, but this mapping might still be used elsewhere.
# Keep it updated, but the primary mechanism for the agent is the decorator.
TOOL_MAPPING = {
"get_recent_messages": get_recent_messages,
"search_user_messages": search_user_messages,
@ -1497,7 +1524,7 @@ TOOL_MAPPING = {
"remove_timeout": remove_timeout,
"extract_web_content": extract_web_content,
"read_file_content": read_file_content,
"create_new_tool": create_new_tool, # Added the meta-tool
"execute_internal_command": execute_internal_command, # Added internal command execution
"create_new_tool": create_new_tool, # Meta-tool (not decorated)
"execute_internal_command": execute_internal_command, # Internal command execution (not decorated)
"get_user_id": get_user_id # Added user ID lookup tool
}