aa
This commit is contained in:
parent
38856d2798
commit
b10f11ce51
56
gurt/api.py
56
gurt/api.py
@ -232,46 +232,36 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name:
|
||||
])
|
||||
|
||||
# --- 2. Prepare Tools ---
|
||||
# Create wrappers for tools to include the 'cog' instance
|
||||
# Collect decorated tool functions from the TOOL_MAPPING
|
||||
# The @tool decorator handles schema generation and description.
|
||||
# We still need to bind the 'cog' instance to each tool call.
|
||||
prepared_tools = []
|
||||
for tool_name, tool_func in TOOL_MAPPING.items():
|
||||
# Check if the function is actually decorated (optional, but good practice)
|
||||
# The @tool decorator adds attributes like .name, .description, .args_schema
|
||||
if not hasattr(tool_func, 'is_lc_tool') and not hasattr(tool_func, 'name'):
|
||||
# Skip functions not decorated with @tool, or internal/experimental ones
|
||||
if tool_name not in ["create_new_tool", "execute_internal_command", "_check_command_safety"]:
|
||||
logger.warning(f"Tool function '{tool_name}' in TOOL_MAPPING is missing the @tool decorator. Skipping.")
|
||||
continue
|
||||
|
||||
try:
|
||||
# Define a wrapper function that captures cog and calls the original tool
|
||||
# This wrapper will have the correct signature for Langchain to inspect
|
||||
def create_wrapper(func_to_wrap, current_cog):
|
||||
# Define the actual wrapper that Langchain will call
|
||||
# It accepts the arguments Langchain extracts based on the original function's signature
|
||||
def tool_wrapper(*args, **kwargs):
|
||||
# Call the original tool function with cog as the first argument
|
||||
# Pass through the arguments received from Langchain
|
||||
return func_to_wrap(current_cog, *args, **kwargs)
|
||||
# Copy metadata for Langchain schema generation from the original function
|
||||
functools.update_wrapper(tool_wrapper, func_to_wrap)
|
||||
# Ensure the wrapper has the correct name for the tool mapping/introspection
|
||||
# Note: Langchain might use the name attribute if wrapped in LangchainTool later
|
||||
tool_wrapper.__name__ = func_to_wrap.__name__
|
||||
return tool_wrapper
|
||||
# Create a partial function that includes the 'cog' instance
|
||||
# Langchain agents can often handle partial functions directly.
|
||||
# The agent will call this partial, which in turn calls the original tool with 'cog'.
|
||||
# The @tool decorator ensures Langchain gets the correct signature from the *original* func.
|
||||
tool_with_cog = functools.partial(tool_func, cog)
|
||||
|
||||
# Create the specific wrapper for this tool, capturing the current cog
|
||||
wrapped_tool_func = create_wrapper(tool_func, cog)
|
||||
# Copy essential attributes from the original decorated tool function to the partial
|
||||
# This helps ensure the agent framework recognizes it correctly.
|
||||
functools.update_wrapper(tool_with_cog, tool_func)
|
||||
|
||||
# Pass the wrapped function directly to the agent's tool list.
|
||||
# LangchainAgent should be able to handle functions directly.
|
||||
# We explicitly provide the name in the LangchainTool wrapper below for clarity if needed.
|
||||
# For now, let's try passing the function itself, relying on Langchain's introspection.
|
||||
# If issues persist, wrap in LangchainTool:
|
||||
prepared_tools.append(LangchainTool(
|
||||
name=tool_name, # Use the key from TOOL_MAPPING as the definitive name
|
||||
func=wrapped_tool_func, # Pass the wrapper function
|
||||
description=tool_func.__doc__ or f"Executes the {tool_name} tool.", # Use original docstring
|
||||
# LangChain should infer args_schema from the *original* tool_func's type hints
|
||||
# because functools.update_wrapper copies the signature.
|
||||
))
|
||||
# Simpler alternative (try if the above fails):
|
||||
# prepared_tools.append(wrapped_tool_func)
|
||||
# Add the partial function (which now includes cog) to the list
|
||||
prepared_tools.append(tool_with_cog)
|
||||
logger.debug(f"Prepared tool '{tool_name}' with cog instance bound.")
|
||||
|
||||
except Exception as tool_prep_e:
|
||||
logger.error(f"Error preparing tool '{tool_name}': {tool_prep_e}", exc_info=True)
|
||||
logger.error(f"Error preparing tool '{tool_name}' with functools.partial: {tool_prep_e}", exc_info=True)
|
||||
# Optionally skip this tool
|
||||
|
||||
# --- 3. Prepare Chat History Factory ---
|
||||
|
@ -17,6 +17,7 @@ from tavily import TavilyClient
|
||||
import docker
|
||||
import aiodocker # Use aiodocker for async operations
|
||||
from asteval import Interpreter # Added for calculate tool
|
||||
from langchain_core.tools import tool # Import the tool decorator
|
||||
|
||||
# Relative imports from within the gurt package and parent
|
||||
from .memory import MemoryManager # Import from local memory.py
|
||||
@ -37,6 +38,7 @@ from .config import (
|
||||
# to access things like cog.bot, cog.session, cog.current_channel, cog.memory_manager etc.
|
||||
# We will add 'cog' as the first parameter to each.
|
||||
|
||||
@tool
|
||||
async def get_recent_messages(cog: commands.Cog, limit: int, channel_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Retrieves the most recent messages from a specified Discord channel or the current channel.
|
||||
@ -71,6 +73,7 @@ async def get_recent_messages(cog: commands.Cog, limit: int, channel_id: Optiona
|
||||
except Exception as e:
|
||||
return {"error": f"Error retrieving messages: {str(e)}", "timestamp": datetime.datetime.now().isoformat()}
|
||||
|
||||
@tool
|
||||
async def search_user_messages(cog: commands.Cog, user_id: str, limit: int, channel_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Searches recent channel history for messages sent by a specific user.
|
||||
@ -115,6 +118,7 @@ async def search_user_messages(cog: commands.Cog, user_id: str, limit: int, chan
|
||||
except Exception as e:
|
||||
return {"error": f"Error searching user messages: {str(e)}", "timestamp": datetime.datetime.now().isoformat()}
|
||||
|
||||
@tool
|
||||
async def search_messages_by_content(cog: commands.Cog, search_term: str, limit: int, channel_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Searches recent channel history for messages containing specific text content (case-insensitive).
|
||||
@ -154,6 +158,7 @@ async def search_messages_by_content(cog: commands.Cog, search_term: str, limit:
|
||||
except Exception as e:
|
||||
return {"error": f"Error searching messages by content: {str(e)}", "timestamp": datetime.datetime.now().isoformat()}
|
||||
|
||||
@tool
|
||||
async def get_channel_info(cog: commands.Cog, channel_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Retrieves detailed information about a specified Discord channel or the current channel.
|
||||
@ -192,6 +197,7 @@ async def get_channel_info(cog: commands.Cog, channel_id: Optional[str] = None)
|
||||
except Exception as e:
|
||||
return {"error": f"Error getting channel info: {str(e)}", "timestamp": datetime.datetime.now().isoformat()}
|
||||
|
||||
@tool
|
||||
async def get_conversation_context(cog: commands.Cog, message_count: int, channel_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Retrieves recent messages to provide context for the ongoing conversation in a channel.
|
||||
@ -231,6 +237,7 @@ async def get_conversation_context(cog: commands.Cog, message_count: int, channe
|
||||
except Exception as e:
|
||||
return {"error": f"Error getting conversation context: {str(e)}"}
|
||||
|
||||
@tool
|
||||
async def get_thread_context(cog: commands.Cog, thread_id: str, message_count: int) -> Dict[str, Any]:
|
||||
"""
|
||||
Retrieves recent messages from a specific Discord thread to provide conversation context.
|
||||
@ -267,6 +274,7 @@ async def get_thread_context(cog: commands.Cog, thread_id: str, message_count: i
|
||||
except Exception as e:
|
||||
return {"error": f"Error getting thread context: {str(e)}"}
|
||||
|
||||
@tool
|
||||
async def get_user_interaction_history(cog: commands.Cog, user_id_1: str, limit: int, user_id_2: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Retrieves the recent message history involving interactions (replies, mentions) between two users.
|
||||
@ -315,6 +323,7 @@ async def get_user_interaction_history(cog: commands.Cog, user_id_1: str, limit:
|
||||
except Exception as e:
|
||||
return {"error": f"Error getting user interaction history: {str(e)}"}
|
||||
|
||||
@tool
|
||||
async def get_conversation_summary(cog: commands.Cog, channel_id: Optional[str] = None, message_limit: int = 25) -> Dict[str, Any]:
|
||||
"""
|
||||
Generates and returns a concise summary of the recent conversation in a specified channel or the current channel.
|
||||
@ -399,6 +408,7 @@ async def get_conversation_summary(cog: commands.Cog, channel_id: Optional[str]
|
||||
traceback.print_exc()
|
||||
return {"error": error_msg}
|
||||
|
||||
@tool
|
||||
async def get_message_context(cog: commands.Cog, message_id: str, before_count: int = 5, after_count: int = 5) -> Dict[str, Any]:
|
||||
"""
|
||||
Retrieves messages immediately before and after a specific message ID within the current channel.
|
||||
@ -440,6 +450,7 @@ async def get_message_context(cog: commands.Cog, message_id: str, before_count:
|
||||
except Exception as e:
|
||||
return {"error": f"Error getting message context: {str(e)}"}
|
||||
|
||||
@tool
|
||||
async def web_search(cog: commands.Cog, query: str, search_depth: str = TAVILY_DEFAULT_SEARCH_DEPTH, max_results: int = TAVILY_DEFAULT_MAX_RESULTS, topic: str = "general", include_domains: Optional[List[str]] = None, exclude_domains: Optional[List[str]] = None, include_answer: bool = True, include_raw_content: bool = False, include_images: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Performs a web search using the Tavily API based on the provided query and parameters.
|
||||
@ -520,6 +531,7 @@ async def web_search(cog: commands.Cog, query: str, search_depth: str = TAVILY_D
|
||||
print(error_message)
|
||||
return {"error": error_message, "timestamp": datetime.datetime.now().isoformat()}
|
||||
|
||||
@tool
|
||||
async def remember_user_fact(cog: commands.Cog, user_id: str, fact: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Stores a specific fact about a given user in the bot's long-term memory.
|
||||
@ -545,6 +557,7 @@ async def remember_user_fact(cog: commands.Cog, user_id: str, fact: str) -> Dict
|
||||
print(error_message); traceback.print_exc()
|
||||
return {"error": error_message}
|
||||
|
||||
@tool
|
||||
async def get_user_facts(cog: commands.Cog, user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Retrieves all stored facts associated with a specific user from the bot's long-term memory.
|
||||
@ -566,6 +579,7 @@ async def get_user_facts(cog: commands.Cog, user_id: str) -> Dict[str, Any]:
|
||||
print(error_message); traceback.print_exc()
|
||||
return {"error": error_message}
|
||||
|
||||
@tool
|
||||
async def remember_general_fact(cog: commands.Cog, fact: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Stores a general fact (not specific to any user) in the bot's long-term memory.
|
||||
@ -590,6 +604,7 @@ async def remember_general_fact(cog: commands.Cog, fact: str) -> Dict[str, Any]:
|
||||
print(error_message); traceback.print_exc()
|
||||
return {"error": error_message}
|
||||
|
||||
@tool
|
||||
async def get_general_facts(cog: commands.Cog, query: Optional[str] = None, limit: Optional[int] = 10) -> Dict[str, Any]:
|
||||
"""
|
||||
Retrieves general facts from the bot's long-term memory. Can optionally filter by a query string.
|
||||
@ -612,6 +627,7 @@ async def get_general_facts(cog: commands.Cog, query: Optional[str] = None, limi
|
||||
print(error_message); traceback.print_exc()
|
||||
return {"error": error_message}
|
||||
|
||||
@tool
|
||||
async def timeout_user(cog: commands.Cog, user_id: str, duration_minutes: int, reason: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Applies a timeout to a specified user within the current server (guild).
|
||||
@ -653,6 +669,7 @@ async def timeout_user(cog: commands.Cog, user_id: str, duration_minutes: int, r
|
||||
except discord.HTTPException as e: print(f"API error timeout {user_id}: {e}"); return {"error": f"API error timeout {user_id}: {e}"}
|
||||
except Exception as e: print(f"Unexpected error timeout {user_id}: {e}"); traceback.print_exc(); return {"error": f"Unexpected error timeout {user_id}: {str(e)}"}
|
||||
|
||||
@tool
|
||||
async def remove_timeout(cog: commands.Cog, user_id: str, reason: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Removes an active timeout from a specified user within the current server (guild).
|
||||
@ -689,16 +706,21 @@ async def remove_timeout(cog: commands.Cog, user_id: str, reason: Optional[str]
|
||||
except discord.HTTPException as e: print(f"API error remove timeout {user_id}: {e}"); return {"error": f"API error remove timeout {user_id}: {e}"}
|
||||
except Exception as e: print(f"Unexpected error remove timeout {user_id}: {e}"); traceback.print_exc(); return {"error": f"Unexpected error remove timeout {user_id}: {str(e)}"}
|
||||
|
||||
def calculate(cog: commands.Cog, expression: str) -> Dict[str, Any]:
|
||||
@tool
|
||||
def calculate(cog: commands.Cog, expression: str) -> str:
|
||||
"""
|
||||
Evaluates a mathematical expression using the asteval library. Supports common math functions.
|
||||
Evaluates a mathematical expression using the asteval library. Supports common math functions. Returns the result as a string.
|
||||
|
||||
Args:
|
||||
cog: The GurtCog instance (automatically passed).
|
||||
expression: The mathematical expression string to evaluate (e.g., "2 * (pi + 1)").
|
||||
|
||||
Args:
|
||||
cog: The GurtCog instance (automatically passed).
|
||||
expression: The mathematical expression string to evaluate (e.g., "2 * (pi + 1)").
|
||||
|
||||
Returns:
|
||||
A dictionary containing the original expression, the calculated result string, and status 'success', or an error dictionary.
|
||||
The calculated result as a string, or an error message string if calculation fails.
|
||||
"""
|
||||
print(f"Calculating expression: {expression}")
|
||||
aeval = Interpreter()
|
||||
@ -708,20 +730,17 @@ def calculate(cog: commands.Cog, expression: str) -> Dict[str, Any]:
|
||||
error_details = '; '.join(err.get_error() for err in aeval.error)
|
||||
error_message = f"Calculation error: {error_details}"
|
||||
print(error_message)
|
||||
return {"error": error_message, "expression": expression}
|
||||
|
||||
if isinstance(result, (int, float, complex)): result_str = str(result)
|
||||
else: result_str = repr(result) # Fallback
|
||||
return f"Error: {error_message}" # Return error as string
|
||||
|
||||
result_str = str(result) # Convert result to string
|
||||
print(f"Calculation result: {result_str}")
|
||||
# Return only the result string on success, as expected by some agent frameworks
|
||||
return result_str
|
||||
return result_str # Return result as string
|
||||
except Exception as e:
|
||||
error_message = f"Unexpected error during calculation: {str(e)}"
|
||||
print(error_message); traceback.print_exc()
|
||||
# Return error message as string on failure
|
||||
return {"error": error_message, "expression": expression}
|
||||
return f"Error: {error_message}" # Return error as string
|
||||
|
||||
@tool
|
||||
async def run_python_code(cog: commands.Cog, code: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Executes a provided Python code snippet remotely using the Piston API.
|
||||
@ -767,6 +786,7 @@ async def run_python_code(cog: commands.Cog, code: str) -> Dict[str, Any]:
|
||||
except aiohttp.ClientError as e: print(f"Piston network error: {e}"); return {"error": f"Network error connecting to Piston: {str(e)}"}
|
||||
except Exception as e: print(f"Unexpected Piston error: {e}"); traceback.print_exc(); return {"error": f"Unexpected error during Python execution: {str(e)}"}
|
||||
|
||||
@tool
|
||||
async def create_poll(cog: commands.Cog, question: str, options: List[str]) -> Dict[str, Any]:
|
||||
"""
|
||||
Creates a simple poll message in the current channel with numbered reaction options.
|
||||
@ -801,7 +821,7 @@ async def create_poll(cog: commands.Cog, question: str, options: List[str]) -> D
|
||||
except discord.HTTPException as e: print(f"Poll API error: {e}"); return {"error": f"API error creating poll: {e}"}
|
||||
except Exception as e: print(f"Poll unexpected error: {e}"); traceback.print_exc(); return {"error": f"Unexpected error creating poll: {str(e)}"}
|
||||
|
||||
# Helper function to convert memory string (e.g., "128m") to bytes
|
||||
# Helper function to convert memory string (e.g., "128m") to bytes (Not a tool)
|
||||
def parse_mem_limit(mem_limit_str: str) -> Optional[int]:
|
||||
if not mem_limit_str: return None
|
||||
mem_limit_str = mem_limit_str.lower()
|
||||
@ -859,6 +879,7 @@ async def _check_command_safety(cog: commands.Cog, command: str) -> Dict[str, An
|
||||
print(f"AI Safety Check Error: Response was {safety_response}")
|
||||
return {"safe": False, "reason": error_msg}
|
||||
|
||||
@tool
|
||||
async def run_terminal_command(cog: commands.Cog, command: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Executes a shell command within an isolated, network-disabled Docker container.
|
||||
@ -979,6 +1000,8 @@ async def run_terminal_command(cog: commands.Cog, command: str) -> Dict[str, Any
|
||||
# Ensure the client connection is closed
|
||||
if client:
|
||||
await client.close()
|
||||
|
||||
@tool
|
||||
async def get_user_id(cog: commands.Cog, user_name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Finds the Discord User ID associated with a given username or display name.
|
||||
@ -1029,7 +1052,7 @@ async def get_user_id(cog: commands.Cog, user_name: str) -> Dict[str, Any]:
|
||||
print(f"User '{user_name}' not found in guild '{guild.name}'.")
|
||||
return {"error": f"User '{user_name}' not found in this server.", "user_name": user_name}
|
||||
|
||||
|
||||
# NOT decorating execute_internal_command as it's marked unsafe for general agent use
|
||||
async def execute_internal_command(cog: commands.Cog, command: str, timeout_seconds: int = 60) -> Dict[str, Any]:
|
||||
"""
|
||||
Executes a shell command directly on the host machine where the bot is running.
|
||||
@ -1097,6 +1120,7 @@ async def execute_internal_command(cog: commands.Cog, command: str, timeout_seco
|
||||
traceback.print_exc()
|
||||
return {"error": error_message, "command": command, "status": "error"}
|
||||
|
||||
@tool
|
||||
async def extract_web_content(cog: commands.Cog, urls: Union[str, List[str]], extract_depth: str = "basic", include_images: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Extracts the main textual content and optionally images from one or more web URLs using the Tavily API.
|
||||
@ -1141,6 +1165,7 @@ async def extract_web_content(cog: commands.Cog, urls: Union[str, List[str]], ex
|
||||
print(error_message)
|
||||
return {"error": error_message, "timestamp": datetime.datetime.now().isoformat()}
|
||||
|
||||
@tool
|
||||
async def read_file_content(cog: commands.Cog, file_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Reads the content of a specified file located on the bot's host machine.
|
||||
@ -1227,7 +1252,7 @@ async def read_file_content(cog: commands.Cog, file_path: str) -> Dict[str, Any]
|
||||
traceback.print_exc()
|
||||
return {"error": error_message, "file_path": file_path}
|
||||
|
||||
# --- Meta Tool: Create New Tool ---
|
||||
# --- Meta Tool: Create New Tool --- (Not decorating as it's experimental/dangerous)
|
||||
# WARNING: HIGHLY EXPERIMENTAL AND DANGEROUS. Allows AI to write and load code.
|
||||
async def create_new_tool(cog: commands.Cog, tool_name: str, description: str, parameters_json: str, returns_description: str) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -1473,6 +1498,8 @@ async def create_new_tool(cog: commands.Cog, tool_name: str, description: str, p
|
||||
|
||||
# --- Tool Mapping ---
|
||||
# This dictionary maps tool names (used in the AI prompt) to their implementation functions.
|
||||
# The agent should discover tools via the @tool decorator, but this mapping might still be used elsewhere.
|
||||
# Keep it updated, but the primary mechanism for the agent is the decorator.
|
||||
TOOL_MAPPING = {
|
||||
"get_recent_messages": get_recent_messages,
|
||||
"search_user_messages": search_user_messages,
|
||||
@ -1497,7 +1524,7 @@ TOOL_MAPPING = {
|
||||
"remove_timeout": remove_timeout,
|
||||
"extract_web_content": extract_web_content,
|
||||
"read_file_content": read_file_content,
|
||||
"create_new_tool": create_new_tool, # Added the meta-tool
|
||||
"execute_internal_command": execute_internal_command, # Added internal command execution
|
||||
"create_new_tool": create_new_tool, # Meta-tool (not decorated)
|
||||
"execute_internal_command": execute_internal_command, # Internal command execution (not decorated)
|
||||
"get_user_id": get_user_id # Added user ID lookup tool
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user