feat: Implement command filtering and session management in AICodeAgentCog
This commit is contained in:
parent
17c806eedd
commit
32b8fc3839
@ -9,6 +9,7 @@ import base64
|
||||
import datetime # For snapshot naming
|
||||
import random # For snapshot naming
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from collections import defaultdict # Added for agent_shell_sessions
|
||||
|
||||
# Google Generative AI Imports (using Vertex AI backend)
|
||||
from google import genai
|
||||
@ -36,9 +37,32 @@ STANDARD_SAFETY_SETTINGS = [
|
||||
google_genai_types.SafetySetting(category=google_genai_types.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold="BLOCK_NONE"),
|
||||
]
|
||||
|
||||
# --- Constants for command filtering, mirroring shell_command_cog.py ---
|
||||
# (Currently empty as they are commented out in the reference cog)
|
||||
# Substrings that must not appear anywhere in a command (matched case-insensitively).
BANNED_COMMANDS_AGENT: List[str] = []
# Regular expressions that a command must not match (passed to re.search as-is).
BANNED_PATTERNS_AGENT: List[str] = []


def is_command_allowed_agent(command: str) -> Tuple[bool, Optional[str]]:
    """Check whether a shell command passes the agent's ban lists.

    The command is rejected when it contains any term from
    ``BANNED_COMMANDS_AGENT`` (case-insensitive substring match) or when it
    matches any regular expression from ``BANNED_PATTERNS_AGENT``.

    Args:
        command: The raw command line about to be executed.

    Returns:
        ``(True, None)`` when the command is allowed, otherwise
        ``(False, reason)`` where ``reason`` names the offending term or
        pattern.
    """
    # Lower-case once; both sides are folded so that a banned term written
    # with upper-case letters still matches.
    lowered_command = command.lower()
    for banned in BANNED_COMMANDS_AGENT:
        if banned.lower() in lowered_command:  # Simple substring check
            return False, f"Command contains banned term: `{banned}`"

    # Patterns are applied to the original (un-lowered) command.
    for pattern in BANNED_PATTERNS_AGENT:
        if re.search(pattern, command):
            return False, f"Command matches banned pattern: `{pattern}`"

    return True, None
|
||||
# --- End of command filtering constants and function ---
|
||||
|
||||
# Identity string in "Name <email>" form. NOTE(review): presumably passed as the
# author of commits the agent creates -- confirm against the call sites.
COMMIT_AUTHOR = "AI Coding Agent Cog <me@slipstreamm.dev>"
|
||||
|
||||
AGENT_SYSTEM_PROMPT = """You are an expert AI Coding Agent. Your primary function is to assist the user (bot owner) by directly modifying the codebase of this Discord bot project or performing related tasks. You operate by understanding user requests and then generating specific inline "tool calls" in your responses when you need to interact with the file system, execute commands, or search the web.
|
||||
AGENT_SYSTEM_PROMPT = """You are an expert AI Coding Agent. Your primary function is to assist the user (bot owner) by directly modifying the codebase of this Discord bot project or performing related tasks. This bot uses discord.py. Cogs placed in the 'cogs' folder are automatically loaded by the bot's main script, so you typically do not need to modify `main.py` to load new cogs you create in that directory. You operate by understanding user requests and then generating specific inline "tool calls" in your responses when you need to interact with the file system, execute commands, or search the web.
|
||||
|
||||
**Inline Tool Call Syntax:**
|
||||
When you need to use a tool, your response should *only* contain the tool call block, formatted exactly as specified below. The system will parse this, execute the tool, and then feed the output back to you in a subsequent message prefixed with "ToolResponse:".
|
||||
@ -129,6 +153,10 @@ class AICodeAgentCog(commands.Cog):
|
||||
self.bot = bot
|
||||
self.genai_client = None
|
||||
self.agent_conversations: Dict[int, List[google_genai_types.Content]] = {} # User ID to conversation history
|
||||
self.agent_shell_sessions = defaultdict(lambda: {
|
||||
'cwd': os.getcwd(),
|
||||
'env': os.environ.copy()
|
||||
})
|
||||
|
||||
# Initialize Google GenAI Client for Vertex AI
|
||||
if PROJECT_ID and LOCATION:
|
||||
@ -497,7 +525,8 @@ class AICodeAgentCog(commands.Cog):
|
||||
if exec_command_match:
|
||||
tool_executed = True
|
||||
command_str = exec_command_match.group(1).strip()
|
||||
tool_output = await self._execute_tool_execute_command(command_str)
|
||||
user_id = ctx.author.id # Get user_id from context
|
||||
tool_output = await self._execute_tool_execute_command(command_str, user_id)
|
||||
return "TOOL_OUTPUT", f"ToolResponse: ExecuteCommand\nCommand: {command_str}\n---\n{tool_output}"
|
||||
|
||||
# --- ListFiles ---
|
||||
@ -602,39 +631,111 @@ class AICodeAgentCog(commands.Cog):
|
||||
except Exception as e:
|
||||
return f"Error applying diff to '{path}': {type(e).__name__} - {e}"
|
||||
|
||||
async def _execute_tool_execute_command(self, command: str, user_id: int) -> str:
    """Run a shell command inside the caller's persistent shell session.

    The session (working directory and environment) is looked up in
    ``self.agent_shell_sessions`` by ``user_id``. The command is checked with
    ``is_command_allowed_agent`` and then executed through the shell in a
    worker thread so the event loop is not blocked. A successful plain
    ``cd <dir>`` updates the session's working directory for later calls;
    compound commands such as ``cd x && make`` are not tracked.

    Args:
        command: Shell command line to execute.
        user_id: Key identifying the shell session to use and update.

    Returns:
        A formatted report: an optional timeout notice, fenced STDOUT and
        STDERR sections (each truncated to 1900 characters), the exit code
        when it is non-zero, or a success note when there was no output.
        Rejected commands and unexpected errors are reported as message
        strings rather than raised.
    """
    session = self.agent_shell_sessions[user_id]
    cwd = session['cwd']
    env = session['env']
    print(f"AICodeAgentCog: Attempting _execute_tool_execute_command for user_id {user_id}: '{command}' in CWD: '{cwd}'")

    # Mirroring shell_command_cog.py's command allowance check
    allowed, reason = is_command_allowed_agent(command)
    if not allowed:
        return f"⛔ Command not allowed: {reason}"

    # Mirroring shell_command_cog.py's settings
    timeout_seconds = 30.0
    max_output_length = 1900
    # Upper bound on how long we wait for the pipes to close after killing a
    # timed-out command.
    kill_drain_seconds = 5.0

    def run_agent_subprocess_sync(cmd_str, current_cwd, current_env, cmd_timeout_secs):
        """Blocking worker; returns (stdout, stderr, returncode, timed_out)."""
        try:
            proc = subprocess.Popen(
                cmd_str,
                shell=True,
                cwd=current_cwd,
                env=current_env,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            try:
                stdout, stderr = proc.communicate(timeout=cmd_timeout_secs)
                return (stdout, stderr, proc.returncode, False)  # stdout, stderr, rc, timed_out
            except subprocess.TimeoutExpired:
                proc.kill()
                # With shell=True only the shell itself is killed; a child it
                # spawned may keep the pipes open, so an unbounded
                # communicate() here could block far beyond the timeout.
                try:
                    stdout, stderr = proc.communicate(timeout=kill_drain_seconds)
                except subprocess.TimeoutExpired as drain_exc:
                    # Give up waiting; keep whatever partial output is available.
                    stdout = drain_exc.stdout or b""
                    stderr = drain_exc.stderr or b""
                return (stdout, stderr, -1, True)  # Using -1 for timeout rc, as in shell_command_cog
        except Exception as e:
            # Capture other exceptions during Popen or initial communicate
            return (b"", str(e).encode('utf-8', errors='replace'), -2, False)  # -2 for other errors

    try:
        # Execute the synchronous subprocess logic in a separate thread
        stdout_bytes, stderr_bytes, returncode, timed_out = await asyncio.to_thread(
            run_agent_subprocess_sync, command, cwd, env, timeout_seconds
        )

        # Update session working directory if a plain 'cd' command succeeded.
        # Each command runs in a fresh shell, so the directory change has to
        # be replayed here for it to affect later commands.
        stripped_command = command.strip()
        if returncode == 0 and (stripped_command == 'cd' or stripped_command.startswith('cd ')):
            new_dir_arg_str = stripped_command[len("cd"):].strip()
            potential_new_cwd = None
            if not new_dir_arg_str:
                # Bare 'cd' goes to the home directory in POSIX shells; on
                # Windows (cmd.exe) it only prints the current directory.
                if os.name != 'nt':
                    potential_new_cwd = os.path.expanduser('~')
            elif new_dir_arg_str == '~' or new_dir_arg_str == '$HOME':
                potential_new_cwd = os.path.expanduser('~')
            elif new_dir_arg_str == '-':
                # 'cd -' (previous directory) is hard to track reliably without more state,
                # so we won't update cwd for it, similar to shell_command_cog's limitations.
                print(f"AICodeAgentCog: 'cd -' used by user_id {user_id}. CWD tracking will not update for this command.")
            else:
                # For 'cd <path>'
                temp_arg = new_dir_arg_str
                is_quoted = (
                    len(temp_arg) >= 2
                    and temp_arg[0] == temp_arg[-1]
                    and temp_arg[0] in ('"', "'")
                )
                if is_quoted:
                    # Remove the surrounding quotes
                    temp_arg = temp_arg[1:-1]
                elif temp_arg.startswith('~'):
                    # The shell expands an unquoted leading '~'; do the same.
                    temp_arg = os.path.expanduser(temp_arg)

                if os.path.isabs(temp_arg):
                    potential_new_cwd = temp_arg
                else:
                    potential_new_cwd = os.path.abspath(os.path.join(cwd, temp_arg))

            if potential_new_cwd and os.path.isdir(potential_new_cwd):
                session['cwd'] = potential_new_cwd
                print(f"AICodeAgentCog: Updated CWD for user_id {user_id} to: {session['cwd']}")
            elif new_dir_arg_str and potential_new_cwd:
                print(f"AICodeAgentCog: 'cd' command for user_id {user_id} seemed to succeed (rc=0), but CWD tracking logic could not confirm new path '{potential_new_cwd}' or it's not a directory. CWD remains '{session['cwd']}'. Command: '{command}'.")

        # Format Output identically to shell_command_cog.py's _execute_local_command
        result_parts = []
        stdout_str = stdout_bytes.decode('utf-8', errors='replace').strip()
        stderr_str = stderr_bytes.decode('utf-8', errors='replace').strip()

        if timed_out:
            result_parts.append(f"⏱️ Command timed out after {timeout_seconds} seconds.")

        if stdout_str:
            if len(stdout_str) > max_output_length:
                stdout_str = stdout_str[:max_output_length] + "... (output truncated)"
            result_parts.append(f"📤 **STDOUT:**\n```\n{stdout_str}\n```")

        if stderr_str:
            if len(stderr_str) > max_output_length:
                stderr_str = stderr_str[:max_output_length] + "... (output truncated)"
            result_parts.append(f"⚠️ **STDERR:**\n```\n{stderr_str}\n```")

        if returncode != 0 and not timed_out:  # Don't add exit code if it was a timeout
            result_parts.append(f"❌ **Exit Code:** {returncode}")
        elif not result_parts:  # No stdout, no stderr, not timed out, and successful
            result_parts.append("✅ Command executed successfully (no output).")

        return "\n".join(result_parts)

    except Exception as e:
        # General exception during subprocess handling
        return f"Exception executing command '{command}': {type(e).__name__} - {e}"
|
||||
|
||||
async def _execute_tool_list_files(self, path: str, recursive: bool) -> str:
|
||||
@ -651,17 +752,24 @@ class AICodeAgentCog(commands.Cog):
|
||||
return f"Error: Path '{path}' is not a directory."
|
||||
|
||||
file_list = []
|
||||
excluded_dirs = {"__pycache__", ".git", ".vscode", ".idea", "node_modules", "venv", ".env", "terminal_images"}
|
||||
if recursive:
|
||||
for root, dirs, files in os.walk(path):
|
||||
for root, dirs, files in os.walk(path, topdown=True):
|
||||
# Exclude specified directories from further traversal
|
||||
dirs[:] = [d for d in dirs if d not in excluded_dirs]
|
||||
|
||||
for name in files:
|
||||
file_list.append(os.path.join(root, name))
|
||||
for name in dirs:
|
||||
# Add filtered directories to the list
|
||||
for name in dirs: # These are already filtered dirs
|
||||
file_list.append(os.path.join(root, name) + os.sep) # Indicate dirs
|
||||
else:
|
||||
else: # Non-recursive case
|
||||
for item in os.listdir(path):
|
||||
if item in excluded_dirs: # Check if the item itself is an excluded directory name
|
||||
continue
|
||||
full_item_path = os.path.join(path, item)
|
||||
if os.path.isdir(full_item_path):
|
||||
file_list.append(item + os.sep)
|
||||
file_list.append(item + os.sep) # Indicate dirs
|
||||
else:
|
||||
file_list.append(item)
|
||||
return "\n".join(file_list) if file_list else "No files or directories found."
|
||||
|
Loading…
x
Reference in New Issue
Block a user