diff --git a/gurt/api.py b/gurt/api.py index 5b533c1..d5048ee 100644 --- a/gurt/api.py +++ b/gurt/api.py @@ -232,46 +232,36 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: ]) # --- 2. Prepare Tools --- - # Create wrappers for tools to include the 'cog' instance + # Collect decorated tool functions from the TOOL_MAPPING + # The @tool decorator handles schema generation and description. + # We still need to bind the 'cog' instance to each tool call. prepared_tools = [] for tool_name, tool_func in TOOL_MAPPING.items(): + # Check if the function is actually decorated (optional, but good practice) + # The @tool decorator adds attributes like .name, .description, .args_schema + if not hasattr(tool_func, 'is_lc_tool') and not hasattr(tool_func, 'name'): + # Skip functions not decorated with @tool, or internal/experimental ones + if tool_name not in ["create_new_tool", "execute_internal_command", "_check_command_safety"]: + logger.warning(f"Tool function '{tool_name}' in TOOL_MAPPING is missing the @tool decorator. Skipping.") + continue + try: - # Define a wrapper function that captures cog and calls the original tool - # This wrapper will have the correct signature for Langchain to inspect - def create_wrapper(func_to_wrap, current_cog): - # Define the actual wrapper that Langchain will call - # It accepts the arguments Langchain extracts based on the original function's signature - def tool_wrapper(*args, **kwargs): - # Call the original tool function with cog as the first argument - # Pass through the arguments received from Langchain - return func_to_wrap(current_cog, *args, **kwargs) - # Copy metadata for Langchain schema generation from the original function - functools.update_wrapper(tool_wrapper, func_to_wrap) - # Ensure the wrapper has the correct name for the tool mapping/introspection - # Note: Langchain might use the name attribute if wrapped in LangchainTool later - tool_wrapper.__name__ = func_to_wrap.__name__ - return tool_wrapper + # Create a partial function that includes the 'cog' instance + # Langchain agents can often handle partial functions directly. + # The agent will call this partial, which in turn calls the original tool with 'cog'. + # The @tool decorator ensures Langchain gets the correct signature from the *original* func. + tool_with_cog = functools.partial(tool_func, cog) - # Create the specific wrapper for this tool, capturing the current cog - wrapped_tool_func = create_wrapper(tool_func, cog) + # Copy essential attributes from the original decorated tool function to the partial + # This helps ensure the agent framework recognizes it correctly. + functools.update_wrapper(tool_with_cog, tool_func) - # Pass the wrapped function directly to the agent's tool list. - # LangchainAgent should be able to handle functions directly. - # We explicitly provide the name in the LangchainTool wrapper below for clarity if needed. - # For now, let's try passing the function itself, relying on Langchain's introspection. - # If issues persist, wrap in LangchainTool: - prepared_tools.append(LangchainTool( - name=tool_name, # Use the key from TOOL_MAPPING as the definitive name - func=wrapped_tool_func, # Pass the wrapper function - description=tool_func.__doc__ or f"Executes the {tool_name} tool.", # Use original docstring - # LangChain should infer args_schema from the *original* tool_func's type hints - # because functools.update_wrapper copies the signature. - )) - # Simpler alternative (try if the above fails): - # prepared_tools.append(wrapped_tool_func) + # Add the partial function (which now includes cog) to the list + prepared_tools.append(tool_with_cog) + logger.debug(f"Prepared tool '{tool_name}' with cog instance bound.") except Exception as tool_prep_e: - logger.error(f"Error preparing tool '{tool_name}': {tool_prep_e}", exc_info=True) + logger.error(f"Error preparing tool '{tool_name}' with functools.partial: {tool_prep_e}", exc_info=True) # Optionally skip this tool # --- 3. Prepare Chat History Factory --- diff --git a/gurt/tools.py b/gurt/tools.py index 53ee594..ee86409 100644 --- a/gurt/tools.py +++ b/gurt/tools.py @@ -17,6 +17,7 @@ from tavily import TavilyClient import docker import aiodocker # Use aiodocker for async operations from asteval import Interpreter # Added for calculate tool +from langchain_core.tools import tool # Import the tool decorator # Relative imports from within the gurt package and parent from .memory import MemoryManager # Import from local memory.py @@ -37,6 +38,7 @@ from .config import ( # to access things like cog.bot, cog.session, cog.current_channel, cog.memory_manager etc. # We will add 'cog' as the first parameter to each. +@tool async def get_recent_messages(cog: commands.Cog, limit: int, channel_id: Optional[str] = None) -> Dict[str, Any]: """ Retrieves the most recent messages from a specified Discord channel or the current channel. @@ -71,6 +73,7 @@ async def get_recent_messages(cog: commands.Cog, limit: int, channel_id: Optiona except Exception as e: return {"error": f"Error retrieving messages: {str(e)}", "timestamp": datetime.datetime.now().isoformat()} +@tool async def search_user_messages(cog: commands.Cog, user_id: str, limit: int, channel_id: Optional[str] = None) -> Dict[str, Any]: """ Searches recent channel history for messages sent by a specific user. @@ -115,6 +118,7 @@ async def search_user_messages(cog: commands.Cog, user_id: str, limit: int, chan except Exception as e: return {"error": f"Error searching user messages: {str(e)}", "timestamp": datetime.datetime.now().isoformat()} +@tool async def search_messages_by_content(cog: commands.Cog, search_term: str, limit: int, channel_id: Optional[str] = None) -> Dict[str, Any]: """ Searches recent channel history for messages containing specific text content (case-insensitive). @@ -154,6 +158,7 @@ async def search_messages_by_content(cog: commands.Cog, search_term: str, limit: except Exception as e: return {"error": f"Error searching messages by content: {str(e)}", "timestamp": datetime.datetime.now().isoformat()} +@tool async def get_channel_info(cog: commands.Cog, channel_id: Optional[str] = None) -> Dict[str, Any]: """ Retrieves detailed information about a specified Discord channel or the current channel. @@ -192,6 +197,7 @@ async def get_channel_info(cog: commands.Cog, channel_id: Optional[str] = None) except Exception as e: return {"error": f"Error getting channel info: {str(e)}", "timestamp": datetime.datetime.now().isoformat()} +@tool async def get_conversation_context(cog: commands.Cog, message_count: int, channel_id: Optional[str] = None) -> Dict[str, Any]: """ Retrieves recent messages to provide context for the ongoing conversation in a channel. @@ -231,6 +237,7 @@ async def get_conversation_context(cog: commands.Cog, message_count: int, channe except Exception as e: return {"error": f"Error getting conversation context: {str(e)}"} +@tool async def get_thread_context(cog: commands.Cog, thread_id: str, message_count: int) -> Dict[str, Any]: """ Retrieves recent messages from a specific Discord thread to provide conversation context. @@ -267,6 +274,7 @@ async def get_thread_context(cog: commands.Cog, thread_id: str, message_count: i except Exception as e: return {"error": f"Error getting thread context: {str(e)}"} +@tool async def get_user_interaction_history(cog: commands.Cog, user_id_1: str, limit: int, user_id_2: Optional[str] = None) -> Dict[str, Any]: """ Retrieves the recent message history involving interactions (replies, mentions) between two users. @@ -315,6 +323,7 @@ async def get_user_interaction_history(cog: commands.Cog, user_id_1: str, limit: except Exception as e: return {"error": f"Error getting user interaction history: {str(e)}"} +@tool async def get_conversation_summary(cog: commands.Cog, channel_id: Optional[str] = None, message_limit: int = 25) -> Dict[str, Any]: """ Generates and returns a concise summary of the recent conversation in a specified channel or the current channel. @@ -399,6 +408,7 @@ async def get_conversation_summary(cog: commands.Cog, channel_id: Optional[str] traceback.print_exc() return {"error": error_msg} +@tool async def get_message_context(cog: commands.Cog, message_id: str, before_count: int = 5, after_count: int = 5) -> Dict[str, Any]: """ Retrieves messages immediately before and after a specific message ID within the current channel. @@ -440,6 +450,7 @@ async def get_message_context(cog: commands.Cog, message_id: str, before_count: except Exception as e: return {"error": f"Error getting message context: {str(e)}"} +@tool async def web_search(cog: commands.Cog, query: str, search_depth: str = TAVILY_DEFAULT_SEARCH_DEPTH, max_results: int = TAVILY_DEFAULT_MAX_RESULTS, topic: str = "general", include_domains: Optional[List[str]] = None, exclude_domains: Optional[List[str]] = None, include_answer: bool = True, include_raw_content: bool = False, include_images: bool = False) -> Dict[str, Any]: """ Performs a web search using the Tavily API based on the provided query and parameters. @@ -520,6 +531,7 @@ async def web_search(cog: commands.Cog, query: str, search_depth: str = TAVILY_D print(error_message) return {"error": error_message, "timestamp": datetime.datetime.now().isoformat()} +@tool async def remember_user_fact(cog: commands.Cog, user_id: str, fact: str) -> Dict[str, Any]: """ Stores a specific fact about a given user in the bot's long-term memory. @@ -545,6 +557,7 @@ async def remember_user_fact(cog: commands.Cog, user_id: str, fact: str) -> Dict print(error_message); traceback.print_exc() return {"error": error_message} +@tool async def get_user_facts(cog: commands.Cog, user_id: str) -> Dict[str, Any]: """ Retrieves all stored facts associated with a specific user from the bot's long-term memory. @@ -566,6 +579,7 @@ async def get_user_facts(cog: commands.Cog, user_id: str) -> Dict[str, Any]: print(error_message); traceback.print_exc() return {"error": error_message} +@tool async def remember_general_fact(cog: commands.Cog, fact: str) -> Dict[str, Any]: """ Stores a general fact (not specific to any user) in the bot's long-term memory. @@ -590,6 +604,7 @@ async def remember_general_fact(cog: commands.Cog, fact: str) -> Dict[str, Any]: print(error_message); traceback.print_exc() return {"error": error_message} +@tool async def get_general_facts(cog: commands.Cog, query: Optional[str] = None, limit: Optional[int] = 10) -> Dict[str, Any]: """ Retrieves general facts from the bot's long-term memory. Can optionally filter by a query string. @@ -612,6 +627,7 @@ async def get_general_facts(cog: commands.Cog, query: Optional[str] = None, limi print(error_message); traceback.print_exc() return {"error": error_message} +@tool async def timeout_user(cog: commands.Cog, user_id: str, duration_minutes: int, reason: Optional[str] = None) -> Dict[str, Any]: """ Applies a timeout to a specified user within the current server (guild). @@ -653,6 +669,7 @@ async def timeout_user(cog: commands.Cog, user_id: str, duration_minutes: int, r except discord.HTTPException as e: print(f"API error timeout {user_id}: {e}"); return {"error": f"API error timeout {user_id}: {e}"} except Exception as e: print(f"Unexpected error timeout {user_id}: {e}"); traceback.print_exc(); return {"error": f"Unexpected error timeout {user_id}: {str(e)}"} +@tool async def remove_timeout(cog: commands.Cog, user_id: str, reason: Optional[str] = None) -> Dict[str, Any]: """ Removes an active timeout from a specified user within the current server (guild). @@ -689,16 +706,21 @@ async def remove_timeout(cog: commands.Cog, user_id: str, reason: Optional[str] except discord.HTTPException as e: print(f"API error remove timeout {user_id}: {e}"); return {"error": f"API error remove timeout {user_id}: {e}"} except Exception as e: print(f"Unexpected error remove timeout {user_id}: {e}"); traceback.print_exc(); return {"error": f"Unexpected error remove timeout {user_id}: {str(e)}"} -def calculate(cog: commands.Cog, expression: str) -> Dict[str, Any]: +@tool +def calculate(cog: commands.Cog, expression: str) -> str: """ - Evaluates a mathematical expression using the asteval library. Supports common math functions. + Evaluates a mathematical expression using the asteval library. Supports common math functions. Returns the result as a string. + + Args: + cog: The GurtCog instance (automatically passed). + expression: The mathematical expression string to evaluate (e.g., "2 * (pi + 1)"). Args: cog: The GurtCog instance (automatically passed). expression: The mathematical expression string to evaluate (e.g., "2 * (pi + 1)"). Returns: - A dictionary containing the original expression, the calculated result string, and status 'success', or an error dictionary. + The calculated result as a string, or an error message string if calculation fails. """ print(f"Calculating expression: {expression}") aeval = Interpreter() @@ -708,20 +730,17 @@ def calculate(cog: commands.Cog, expression: str) -> Dict[str, Any]: error_details = '; '.join(err.get_error() for err in aeval.error) error_message = f"Calculation error: {error_details}" print(error_message) - return {"error": error_message, "expression": expression} - - if isinstance(result, (int, float, complex)): result_str = str(result) - else: result_str = repr(result) # Fallback + return f"Error: {error_message}" # Return error as string + result_str = str(result) # Convert result to string print(f"Calculation result: {result_str}") - # Return only the result string on success, as expected by some agent frameworks - return result_str + return result_str # Return result as string except Exception as e: error_message = f"Unexpected error during calculation: {str(e)}" print(error_message); traceback.print_exc() - # Return error message as string on failure - return {"error": error_message, "expression": expression} + return f"Error: {error_message}" # Return error as string +@tool async def run_python_code(cog: commands.Cog, code: str) -> Dict[str, Any]: """ Executes a provided Python code snippet remotely using the Piston API. @@ -767,6 +786,7 @@ async def run_python_code(cog: commands.Cog, code: str) -> Dict[str, Any]: except aiohttp.ClientError as e: print(f"Piston network error: {e}"); return {"error": f"Network error connecting to Piston: {str(e)}"} except Exception as e: print(f"Unexpected Piston error: {e}"); traceback.print_exc(); return {"error": f"Unexpected error during Python execution: {str(e)}"} +@tool async def create_poll(cog: commands.Cog, question: str, options: List[str]) -> Dict[str, Any]: """ Creates a simple poll message in the current channel with numbered reaction options. @@ -801,7 +821,7 @@ async def create_poll(cog: commands.Cog, question: str, options: List[str]) -> D except discord.HTTPException as e: print(f"Poll API error: {e}"); return {"error": f"API error creating poll: {e}"} except Exception as e: print(f"Poll unexpected error: {e}"); traceback.print_exc(); return {"error": f"Unexpected error creating poll: {str(e)}"} -# Helper function to convert memory string (e.g., "128m") to bytes +# Helper function to convert memory string (e.g., "128m") to bytes (Not a tool) def parse_mem_limit(mem_limit_str: str) -> Optional[int]: if not mem_limit_str: return None mem_limit_str = mem_limit_str.lower() @@ -859,6 +879,7 @@ async def _check_command_safety(cog: commands.Cog, command: str) -> Dict[str, An print(f"AI Safety Check Error: Response was {safety_response}") return {"safe": False, "reason": error_msg} +@tool async def run_terminal_command(cog: commands.Cog, command: str) -> Dict[str, Any]: """ Executes a shell command within an isolated, network-disabled Docker container. @@ -979,6 +1000,8 @@ async def run_terminal_command(cog: commands.Cog, command: str) -> Dict[str, Any # Ensure the client connection is closed if client: await client.close() + +@tool async def get_user_id(cog: commands.Cog, user_name: str) -> Dict[str, Any]: """ Finds the Discord User ID associated with a given username or display name. @@ -1029,7 +1052,7 @@ async def get_user_id(cog: commands.Cog, user_name: str) -> Dict[str, Any]: print(f"User '{user_name}' not found in guild '{guild.name}'.") return {"error": f"User '{user_name}' not found in this server.", "user_name": user_name} - +# NOT decorating execute_internal_command as it's marked unsafe for general agent use async def execute_internal_command(cog: commands.Cog, command: str, timeout_seconds: int = 60) -> Dict[str, Any]: """ Executes a shell command directly on the host machine where the bot is running. @@ -1097,6 +1120,7 @@ async def execute_internal_command(cog: commands.Cog, command: str, timeout_seco traceback.print_exc() return {"error": error_message, "command": command, "status": "error"} +@tool async def extract_web_content(cog: commands.Cog, urls: Union[str, List[str]], extract_depth: str = "basic", include_images: bool = False) -> Dict[str, Any]: """ Extracts the main textual content and optionally images from one or more web URLs using the Tavily API. @@ -1141,6 +1165,7 @@ async def extract_web_content(cog: commands.Cog, urls: Union[str, List[str]], ex print(error_message) return {"error": error_message, "timestamp": datetime.datetime.now().isoformat()} +@tool async def read_file_content(cog: commands.Cog, file_path: str) -> Dict[str, Any]: """ Reads the content of a specified file located on the bot's host machine. @@ -1227,7 +1252,7 @@ async def read_file_content(cog: commands.Cog, file_path: str) -> Dict[str, Any] traceback.print_exc() return {"error": error_message, "file_path": file_path} -# --- Meta Tool: Create New Tool --- +# --- Meta Tool: Create New Tool --- (Not decorating as it's experimental/dangerous) # WARNING: HIGHLY EXPERIMENTAL AND DANGEROUS. Allows AI to write and load code. async def create_new_tool(cog: commands.Cog, tool_name: str, description: str, parameters_json: str, returns_description: str) -> Dict[str, Any]: """ @@ -1473,6 +1498,8 @@ async def create_new_tool(cog: commands.Cog, tool_name: str, description: str, p # --- Tool Mapping --- # This dictionary maps tool names (used in the AI prompt) to their implementation functions. +# The agent should discover tools via the @tool decorator, but this mapping might still be used elsewhere. +# Keep it updated, but the primary mechanism for the agent is the decorator. TOOL_MAPPING = { "get_recent_messages": get_recent_messages, "search_user_messages": search_user_messages, @@ -1497,7 +1524,7 @@ TOOL_MAPPING = { "remove_timeout": remove_timeout, "extract_web_content": extract_web_content, "read_file_content": read_file_content, - "create_new_tool": create_new_tool, # Added the meta-tool - "execute_internal_command": execute_internal_command, # Added internal command execution + "create_new_tool": create_new_tool, # Meta-tool (not decorated) + "execute_internal_command": execute_internal_command, # Internal command execution (not decorated) "get_user_id": get_user_id # Added user ID lookup tool }