This commit is contained in:
Slipstream 2025-04-28 23:09:32 -06:00
parent d66d935d4d
commit d26c0527d3
Signed by: slipstream
GPG Key ID: 13E498CE010AC6FD

View File

@ -17,7 +17,7 @@ from tavily import TavilyClient
import docker
import aiodocker # Use aiodocker for async operations
from asteval import Interpreter # Added for calculate tool
from langchain_core.tools import tool # Import the tool decorator
# Removed: from langchain_core.tools import tool
# Relative imports from within the gurt package and parent
from .memory import MemoryManager # Import from local memory.py
@ -36,10 +36,10 @@ from .config import (
# --- Tool Implementations ---
# Note: Most of these functions will need the 'cog' instance passed to them
# to access things like cog.bot, cog.session, cog.current_channel, cog.memory_manager etc.
# We will add 'cog' as the first parameter to each. (Wrapper in api.py handles this)
# We will add 'cog' as the first parameter to each.
@tool
async def get_recent_messages(limit: int, channel_id: Optional[str] = None, *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def get_recent_messages(cog: commands.Cog, limit: int, channel_id: Optional[str] = None) -> Dict[str, Any]:
"""
Retrieves the most recent messages from a specified Discord channel or the current channel.
@ -73,8 +73,8 @@ async def get_recent_messages(limit: int, channel_id: Optional[str] = None, *, c
except Exception as e:
return {"error": f"Error retrieving messages: {str(e)}", "timestamp": datetime.datetime.now().isoformat()}
@tool
async def search_user_messages(user_id: str, limit: int, channel_id: Optional[str] = None, *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def search_user_messages(cog: commands.Cog, user_id: str, limit: int, channel_id: Optional[str] = None) -> Dict[str, Any]:
"""
Searches recent channel history for messages sent by a specific user.
@ -118,8 +118,8 @@ async def search_user_messages(user_id: str, limit: int, channel_id: Optional[st
except Exception as e:
return {"error": f"Error searching user messages: {str(e)}", "timestamp": datetime.datetime.now().isoformat()}
@tool
async def search_messages_by_content(search_term: str, limit: int, channel_id: Optional[str] = None, *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def search_messages_by_content(cog: commands.Cog, search_term: str, limit: int, channel_id: Optional[str] = None) -> Dict[str, Any]:
"""
Searches recent channel history for messages containing specific text content (case-insensitive).
@ -158,8 +158,8 @@ async def search_messages_by_content(search_term: str, limit: int, channel_id: O
except Exception as e:
return {"error": f"Error searching messages by content: {str(e)}", "timestamp": datetime.datetime.now().isoformat()}
@tool
async def get_channel_info(channel_id: Optional[str] = None, *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def get_channel_info(cog: commands.Cog, channel_id: Optional[str] = None) -> Dict[str, Any]:
"""
Retrieves detailed information about a specified Discord channel or the current channel.
@ -197,8 +197,8 @@ async def get_channel_info(channel_id: Optional[str] = None, *, cog: commands.Co
except Exception as e:
return {"error": f"Error getting channel info: {str(e)}", "timestamp": datetime.datetime.now().isoformat()}
@tool
async def get_conversation_context(message_count: int, channel_id: Optional[str] = None, *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def get_conversation_context(cog: commands.Cog, message_count: int, channel_id: Optional[str] = None) -> Dict[str, Any]:
"""
Retrieves recent messages to provide context for the ongoing conversation in a channel.
@ -237,8 +237,8 @@ async def get_conversation_context(message_count: int, channel_id: Optional[str]
except Exception as e:
return {"error": f"Error getting conversation context: {str(e)}"}
@tool
async def get_thread_context(thread_id: str, message_count: int, *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def get_thread_context(cog: commands.Cog, thread_id: str, message_count: int) -> Dict[str, Any]:
"""
Retrieves recent messages from a specific Discord thread to provide conversation context.
@ -274,8 +274,8 @@ async def get_thread_context(thread_id: str, message_count: int, *, cog: command
except Exception as e:
return {"error": f"Error getting thread context: {str(e)}"}
@tool
async def get_user_interaction_history(user_id_1: str, limit: int, user_id_2: Optional[str] = None, *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def get_user_interaction_history(cog: commands.Cog, user_id_1: str, limit: int, user_id_2: Optional[str] = None) -> Dict[str, Any]:
"""
Retrieves the recent message history involving interactions (replies, mentions) between two users.
If user_id_2 is not provided, it defaults to interactions between user_id_1 and the bot (Gurt).
@ -323,8 +323,8 @@ async def get_user_interaction_history(user_id_1: str, limit: int, user_id_2: Op
except Exception as e:
return {"error": f"Error getting user interaction history: {str(e)}"}
@tool
async def get_conversation_summary(channel_id: Optional[str] = None, message_limit: int = 25, *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def get_conversation_summary(cog: commands.Cog, channel_id: Optional[str] = None, message_limit: int = 25) -> Dict[str, Any]:
"""
Generates and returns a concise summary of the recent conversation in a specified channel or the current channel.
Uses an internal LLM call for summarization and caches the result.
@ -408,8 +408,8 @@ async def get_conversation_summary(channel_id: Optional[str] = None, message_lim
traceback.print_exc()
return {"error": error_msg}
@tool
async def get_message_context(message_id: str, before_count: int = 5, after_count: int = 5, *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def get_message_context(cog: commands.Cog, message_id: str, before_count: int = 5, after_count: int = 5) -> Dict[str, Any]:
"""
Retrieves messages immediately before and after a specific message ID within the current channel.
@ -450,8 +450,8 @@ async def get_message_context(message_id: str, before_count: int = 5, after_coun
except Exception as e:
return {"error": f"Error getting message context: {str(e)}"}
@tool
async def web_search(query: str, search_depth: str = TAVILY_DEFAULT_SEARCH_DEPTH, max_results: int = TAVILY_DEFAULT_MAX_RESULTS, topic: str = "general", include_domains: Optional[List[str]] = None, exclude_domains: Optional[List[str]] = None, include_answer: bool = True, include_raw_content: bool = False, include_images: bool = False, *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def web_search(cog: commands.Cog, query: str, search_depth: str = TAVILY_DEFAULT_SEARCH_DEPTH, max_results: int = TAVILY_DEFAULT_MAX_RESULTS, topic: str = "general", include_domains: Optional[List[str]] = None, exclude_domains: Optional[List[str]] = None, include_answer: bool = True, include_raw_content: bool = False, include_images: bool = False) -> Dict[str, Any]:
"""
Performs a web search using the Tavily API based on the provided query and parameters.
@ -531,8 +531,8 @@ async def web_search(query: str, search_depth: str = TAVILY_DEFAULT_SEARCH_DEPTH
print(error_message)
return {"error": error_message, "timestamp": datetime.datetime.now().isoformat()}
@tool
async def remember_user_fact(user_id: str, fact: str, *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def remember_user_fact(cog: commands.Cog, user_id: str, fact: str) -> Dict[str, Any]:
"""
Stores a specific fact about a given user in the bot's long-term memory.
@ -557,8 +557,8 @@ async def remember_user_fact(user_id: str, fact: str, *, cog: commands.Cog) -> D
print(error_message); traceback.print_exc()
return {"error": error_message}
@tool
async def get_user_facts(user_id: str, *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def get_user_facts(cog: commands.Cog, user_id: str) -> Dict[str, Any]:
"""
Retrieves all stored facts associated with a specific user from the bot's long-term memory.
@ -579,8 +579,8 @@ async def get_user_facts(user_id: str, *, cog: commands.Cog) -> Dict[str, Any]:
print(error_message); traceback.print_exc()
return {"error": error_message}
@tool
async def remember_general_fact(fact: str, *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def remember_general_fact(cog: commands.Cog, fact: str) -> Dict[str, Any]:
"""
Stores a general fact (not specific to any user) in the bot's long-term memory.
@ -604,8 +604,8 @@ async def remember_general_fact(fact: str, *, cog: commands.Cog) -> Dict[str, An
print(error_message); traceback.print_exc()
return {"error": error_message}
@tool
async def get_general_facts(query: Optional[str] = None, limit: Optional[int] = 10, *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def get_general_facts(cog: commands.Cog, query: Optional[str] = None, limit: Optional[int] = 10) -> Dict[str, Any]:
"""
Retrieves general facts from the bot's long-term memory. Can optionally filter by a query string.
@ -627,8 +627,8 @@ async def get_general_facts(query: Optional[str] = None, limit: Optional[int] =
print(error_message); traceback.print_exc()
return {"error": error_message}
@tool
async def timeout_user(user_id: str, duration_minutes: int, reason: Optional[str] = None, *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def timeout_user(cog: commands.Cog, user_id: str, duration_minutes: int, reason: Optional[str] = None) -> Dict[str, Any]:
"""
Applies a timeout to a specified user within the current server (guild).
@ -669,8 +669,8 @@ async def timeout_user(user_id: str, duration_minutes: int, reason: Optional[str
except discord.HTTPException as e: print(f"API error timeout {user_id}: {e}"); return {"error": f"API error timeout {user_id}: {e}"}
except Exception as e: print(f"Unexpected error timeout {user_id}: {e}"); traceback.print_exc(); return {"error": f"Unexpected error timeout {user_id}: {str(e)}"}
@tool
async def remove_timeout(user_id: str, reason: Optional[str] = None, *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def remove_timeout(cog: commands.Cog, user_id: str, reason: Optional[str] = None) -> Dict[str, Any]:
"""
Removes an active timeout from a specified user within the current server (guild).
@ -706,8 +706,8 @@ async def remove_timeout(user_id: str, reason: Optional[str] = None, *, cog: com
except discord.HTTPException as e: print(f"API error remove timeout {user_id}: {e}"); return {"error": f"API error remove timeout {user_id}: {e}"}
except Exception as e: print(f"Unexpected error remove timeout {user_id}: {e}"); traceback.print_exc(); return {"error": f"Unexpected error remove timeout {user_id}: {str(e)}"}
@tool
def calculate(expression: str, *, cog: commands.Cog) -> str:
# @tool removed
def calculate(cog: commands.Cog, expression: str) -> str:
"""
Evaluates a mathematical expression using the asteval library. Supports common math functions. Returns the result as a string.
@ -740,8 +740,8 @@ def calculate(expression: str, *, cog: commands.Cog) -> str:
print(error_message); traceback.print_exc()
return f"Error: {error_message}" # Return error as string
@tool
async def run_python_code(code: str, *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def run_python_code(cog: commands.Cog, code: str) -> Dict[str, Any]:
"""
Executes a provided Python code snippet remotely using the Piston API.
The execution environment is sandboxed and has limitations.
@ -786,8 +786,8 @@ async def run_python_code(code: str, *, cog: commands.Cog) -> Dict[str, Any]:
except aiohttp.ClientError as e: print(f"Piston network error: {e}"); return {"error": f"Network error connecting to Piston: {str(e)}"}
except Exception as e: print(f"Unexpected Piston error: {e}"); traceback.print_exc(); return {"error": f"Unexpected error during Python execution: {str(e)}"}
@tool
async def create_poll(question: str, options: List[str], *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def create_poll(cog: commands.Cog, question: str, options: List[str]) -> Dict[str, Any]:
"""
Creates a simple poll message in the current channel with numbered reaction options.
@ -879,8 +879,8 @@ async def _check_command_safety(cog: commands.Cog, command: str) -> Dict[str, An
print(f"AI Safety Check Error: Response was {safety_response}")
return {"safe": False, "reason": error_msg}
@tool
async def run_terminal_command(command: str, *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def run_terminal_command(cog: commands.Cog, command: str) -> Dict[str, Any]:
"""
Executes a shell command within an isolated, network-disabled Docker container.
Performs an AI safety check before execution. Resource limits (CPU, memory) are applied.
@ -1001,8 +1001,8 @@ async def run_terminal_command(command: str, *, cog: commands.Cog) -> Dict[str,
if client:
await client.close()
@tool
async def get_user_id(user_name: str, *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def get_user_id(cog: commands.Cog, user_name: str) -> Dict[str, Any]:
"""
Finds the Discord User ID associated with a given username or display name.
Searches the current server's members first, then falls back to recent message authors if not in a server context.
@ -1053,7 +1053,7 @@ async def get_user_id(user_name: str, *, cog: commands.Cog) -> Dict[str, Any]:
return {"error": f"User '{user_name}' not found in this server.", "user_name": user_name}
# NOT decorating execute_internal_command as it's marked unsafe for general agent use
async def execute_internal_command(command: str, timeout_seconds: int = 60, *, cog: commands.Cog) -> Dict[str, Any]:
async def execute_internal_command(cog: commands.Cog, command: str, timeout_seconds: int = 60) -> Dict[str, Any]:
"""
Executes a shell command directly on the host machine where the bot is running.
**WARNING:** This tool is intended ONLY for internal Gurt operations (e.g., git pull, service restart)
@ -1120,8 +1120,8 @@ async def execute_internal_command(command: str, timeout_seconds: int = 60, *, c
traceback.print_exc()
return {"error": error_message, "command": command, "status": "error"}
@tool
async def extract_web_content(urls: Union[str, List[str]], extract_depth: str = "basic", include_images: bool = False, *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def extract_web_content(cog: commands.Cog, urls: Union[str, List[str]], extract_depth: str = "basic", include_images: bool = False) -> Dict[str, Any]:
"""
Extracts the main textual content and optionally images from one or more web URLs using the Tavily API.
This is useful for getting the content of a webpage without performing a full search.
@ -1165,8 +1165,8 @@ async def extract_web_content(urls: Union[str, List[str]], extract_depth: str =
print(error_message)
return {"error": error_message, "timestamp": datetime.datetime.now().isoformat()}
@tool
async def read_file_content(file_path: str, *, cog: commands.Cog) -> Dict[str, Any]:
# @tool removed
async def read_file_content(cog: commands.Cog, file_path: str) -> Dict[str, Any]:
"""
Reads the content of a specified file located on the bot's host machine.
Access is restricted to specific allowed directories and file extensions within the project
@ -1254,7 +1254,7 @@ async def read_file_content(file_path: str, *, cog: commands.Cog) -> Dict[str, A
# --- Meta Tool: Create New Tool --- (Not decorating as it's experimental/dangerous)
# WARNING: HIGHLY EXPERIMENTAL AND DANGEROUS. Allows AI to write and load code.
async def create_new_tool(tool_name: str, description: str, parameters_json: str, returns_description: str, *, cog: commands.Cog) -> Dict[str, Any]:
async def create_new_tool(cog: commands.Cog, tool_name: str, description: str, parameters_json: str, returns_description: str) -> Dict[str, Any]:
"""
**EXPERIMENTAL & DANGEROUS:** Attempts to dynamically create a new tool for Gurt.
This involves using an LLM to generate Python code for the tool's function and its