diff --git a/cogs/aimod.py b/cogs/aimod.py index aa1d706..f0bd7ef 100644 --- a/cogs/aimod.py +++ b/cogs/aimod.py @@ -15,6 +15,8 @@ import cv2 # For video processing import numpy as np # For array operations import tempfile # For temporary file operations from typing import Optional, List, Dict, Any, Tuple # For type hinting +import asyncio +import aiofiles # Google Generative AI Imports (using Vertex AI backend) from google import genai @@ -83,29 +85,25 @@ except Exception as e: print(f"Failed to load user infractions from {USER_INFRACTIONS_PATH}: {e}") USER_INFRACTIONS = {} - -def save_guild_config(): - try: - # os.makedirs(os.path.dirname(GUILD_CONFIG_PATH), exist_ok=True) # Already created by GUILD_CONFIG_DIR - # if not os.path.exists(GUILD_CONFIG_PATH): # Redundant check, file is created if not exists - # with open(GUILD_CONFIG_PATH, "w", encoding="utf-8") as f: - # json.dump({}, f) - with open(GUILD_CONFIG_PATH, "w", encoding="utf-8") as f: - json.dump(GUILD_CONFIG, f, indent=2) - except Exception as e: - print(f"Failed to save per-guild config: {e}") +CONFIG_LOCK = asyncio.Lock() -def save_user_infractions(): - try: - # os.makedirs(os.path.dirname(USER_INFRACTIONS_PATH), exist_ok=True) # Already created by GUILD_CONFIG_DIR - # if not os.path.exists(USER_INFRACTIONS_PATH): # Redundant check - # with open(USER_INFRACTIONS_PATH, "w", encoding="utf-8") as f: - # json.dump({}, f) - with open(USER_INFRACTIONS_PATH, "w", encoding="utf-8") as f: - json.dump(USER_INFRACTIONS, f, indent=2) - except Exception as e: - print(f"Failed to save user infractions: {e}") +async def save_guild_config(): + async with CONFIG_LOCK: + try: + async with aiofiles.open(GUILD_CONFIG_PATH, "w", encoding="utf-8") as f: + await f.write(json.dumps(GUILD_CONFIG, indent=2)) + except Exception as e: + print(f"Failed to save per-guild config: {e}") + + +async def save_user_infractions(): + async with CONFIG_LOCK: + try: + async with aiofiles.open(USER_INFRACTIONS_PATH, "w", encoding="utf-8") as f: + await f.write(json.dumps(USER_INFRACTIONS, indent=2)) + except Exception as e: + print(f"Failed to save user infractions: {e}") def get_guild_config(guild_id: int, key: str, default=None): @@ -115,12 +113,12 @@ def get_guild_config(guild_id: int, key: str, default=None): return default -def set_guild_config(guild_id: int, key: str, value): +async def set_guild_config(guild_id: int, key: str, value): guild_str = str(guild_id) if guild_str not in GUILD_CONFIG: GUILD_CONFIG[guild_str] = {} GUILD_CONFIG[guild_str][key] = value - save_guild_config() + await save_guild_config() def get_user_infraction_history(guild_id: int, user_id: int) -> list: @@ -129,7 +127,7 @@ def get_user_infraction_history(guild_id: int, user_id: int) -> list: return USER_INFRACTIONS.get(key, []) -def add_user_infraction( +async def add_user_infraction( guild_id: int, user_id: int, rule_violated: str, @@ -151,7 +149,7 @@ def add_user_infraction( USER_INFRACTIONS[key].append(infraction_record) # Keep only the last N infractions to prevent the file from growing too large, e.g., last 10 USER_INFRACTIONS[key] = USER_INFRACTIONS[key][-10:] - save_user_infractions() + await save_user_infractions() # Server rules to provide context to the AI @@ -241,6 +239,7 @@ class AIModerationCog(commands.Cog): self.last_ai_decisions = collections.deque( maxlen=5 ) # Store last 5 AI decisions + self.config_lock = CONFIG_LOCK # Supported image file extensions self.image_extensions = [ ".jpg", @@ -462,7 +461,7 @@ class AIModerationCog(commands.Cog): async def modset_log_channel( self, interaction: discord.Interaction, channel: discord.TextChannel ): - set_guild_config(interaction.guild.id, "MOD_LOG_CHANNEL_ID", channel.id) + await set_guild_config(interaction.guild.id, "MOD_LOG_CHANNEL_ID", channel.id) await interaction.response.send_message( f"Moderation log channel set to {channel.mention}.", ephemeral=False ) @@ -475,7 +474,9 @@ class AIModerationCog(commands.Cog): async def modset_suggestions_channel( self, interaction: discord.Interaction, channel: discord.TextChannel ): - set_guild_config(interaction.guild.id, "SUGGESTIONS_CHANNEL_ID", channel.id) + await set_guild_config( + interaction.guild.id, "SUGGESTIONS_CHANNEL_ID", channel.id + ) await interaction.response.send_message( f"Suggestions channel set to {channel.mention}.", ephemeral=False ) @@ -488,7 +489,7 @@ class AIModerationCog(commands.Cog): async def modset_moderator_role( self, interaction: discord.Interaction, role: discord.Role ): - set_guild_config(interaction.guild.id, "MODERATOR_ROLE_ID", role.id) + await set_guild_config(interaction.guild.id, "MODERATOR_ROLE_ID", role.id) await interaction.response.send_message( f"Moderator role set to {role.mention}.", ephemeral=False ) @@ -502,7 +503,7 @@ class AIModerationCog(commands.Cog): async def modset_suicidal_ping_role( self, interaction: discord.Interaction, role: discord.Role ): - set_guild_config(interaction.guild.id, "SUICIDAL_PING_ROLE_ID", role.id) + await set_guild_config(interaction.guild.id, "SUICIDAL_PING_ROLE_ID", role.id) await interaction.response.send_message( f"Suicidal content ping role set to {role.mention}.", ephemeral=False ) @@ -520,7 +521,7 @@ class AIModerationCog(commands.Cog): nsfw_channels: list[int] = get_guild_config(guild_id, "NSFW_CHANNEL_IDS", []) if channel.id not in nsfw_channels: nsfw_channels.append(channel.id) - set_guild_config(guild_id, "NSFW_CHANNEL_IDS", nsfw_channels) + await set_guild_config(guild_id, "NSFW_CHANNEL_IDS", nsfw_channels) await interaction.response.send_message( f"{channel.mention} added to NSFW channels list.", ephemeral=False ) @@ -543,7 +544,7 @@ class AIModerationCog(commands.Cog): nsfw_channels: list[int] = get_guild_config(guild_id, "NSFW_CHANNEL_IDS", []) if channel.id in nsfw_channels: nsfw_channels.remove(channel.id) - set_guild_config(guild_id, "NSFW_CHANNEL_IDS", nsfw_channels) + await set_guild_config(guild_id, "NSFW_CHANNEL_IDS", nsfw_channels) await interaction.response.send_message( f"{channel.mention} removed from NSFW channels list.", ephemeral=False ) @@ -595,7 +596,7 @@ class AIModerationCog(commands.Cog): "You must be an administrator to use this command.", ephemeral=False ) return - set_guild_config(interaction.guild.id, "ENABLED", enabled) + await set_guild_config(interaction.guild.id, "ENABLED", enabled) await interaction.response.send_message( f"Moderation is now {'enabled' if enabled else 'disabled'} for this guild.", ephemeral=False, @@ -693,7 +694,7 @@ class AIModerationCog(commands.Cog): # Clear the user's infractions USER_INFRACTIONS[key] = [] - save_user_infractions() + await save_user_infractions() await interaction.response.send_message( f"Cleared {len(infractions)} infraction(s) for {user.mention}.", @@ -726,7 +727,7 @@ class AIModerationCog(commands.Cog): # Save the model to guild configuration guild_id = interaction.guild.id - set_guild_config(guild_id, "AI_MODEL", model) + await set_guild_config(guild_id, "AI_MODEL", model) # Note: There's no global model variable to update here like OPENROUTER_MODEL. # The cog will use the guild-specific config or the DEFAULT_VERTEX_AI_MODEL. @@ -1548,7 +1549,7 @@ CRITICAL: Do NOT output anything other than the required JSON response. print( f"BANNED user {message.author} for violating rule {rule_violated}." ) - add_user_infraction( + await add_user_infraction( guild_id, user_id, rule_violated, @@ -1580,7 +1581,7 @@ CRITICAL: Do NOT output anything other than the required JSON response. print( f"KICKED user {message.author} for violating rule {rule_violated}." ) - add_user_infraction( + await add_user_infraction( guild_id, user_id, rule_violated, @@ -1631,7 +1632,7 @@ CRITICAL: Do NOT output anything other than the required JSON response. print( f"TIMED OUT user {message.author} for {duration_readable} for violating rule {rule_violated}." ) - add_user_infraction( + await add_user_infraction( guild_id, user_id, rule_violated, @@ -1684,7 +1685,7 @@ CRITICAL: Do NOT output anything other than the required JSON response. except Exception as e: print(f"Error sending warning DM to {message.author}: {e}") action_taken_message += " (Error sending warning DM)." - add_user_infraction( + await add_user_infraction( guild_id, user_id, rule_violated,