From 9999051b543a0487f5f615a775f309dd98d3b447 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Thu, 5 Jun 2025 20:32:07 -0600 Subject: [PATCH 01/11] Add Tavily API script for AI agents with search functionality --- tavily.py => tavilytool.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tavily.py => tavilytool.py (100%) diff --git a/tavily.py b/tavilytool.py similarity index 100% rename from tavily.py rename to tavilytool.py From 5dbf605cb1ce0362405a54626ac97a1a371e0356 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Thu, 5 Jun 2025 20:37:59 -0600 Subject: [PATCH 02/11] Remove button placeholders in LoggingCog for improved UI clarity --- cogs/logging_cog.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cogs/logging_cog.py b/cogs/logging_cog.py index e075851..9709160 100644 --- a/cogs/logging_cog.py +++ b/cogs/logging_cog.py @@ -75,7 +75,7 @@ class LoggingCog(commands.Cog): accessory=( ui.Thumbnail(media=author.display_avatar.url) if author - else ui.Button(label="\u200b", disabled=True) + else None ) ) self.header.add_item(ui.TextDisplay(f"**{title}**")) @@ -101,7 +101,7 @@ class LoggingCog(commands.Cog): def add_field(self, name: str, value: str, inline: bool = False): """Mimic Embed.add_field by appending a bolded name/value line.""" if not self._field_sections or len(self._field_sections[-1].children) >= 3: - section = ui.Section(accessory=ui.Button(label="\u200b", disabled=True)) + section = ui.Section(accessory=None) self._insert_field_section(section) self._field_sections.append(section) self._field_sections[-1].add_item(ui.TextDisplay(f"**{name}:** {value}")) @@ -124,7 +124,7 @@ class LoggingCog(commands.Cog): if icon_url: self.header.accessory = ui.Thumbnail(media=icon_url) else: - self.header.accessory = ui.Button(label="\u200b", disabled=True) + self.header.accessory = None self.header.add_item(ui.TextDisplay(name)) def _user_display(self, user: Union[discord.Member, discord.User]) -> str: """Return display name, username and ID string for a user.""" From 38ec5d1e691368f189eb54836723b6283691c6e0 Mon Sep 17 00:00:00 2001 From: Codex Date: Fri, 6 Jun 2025 02:58:40 +0000 Subject: [PATCH 03/11] Add NullAccessory for non-interactive sections --- cogs/logging_cog.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/cogs/logging_cog.py b/cogs/logging_cog.py index 9709160..ee9c7af 100644 --- a/cogs/logging_cog.py +++ b/cogs/logging_cog.py @@ -43,6 +43,15 @@ ALL_EVENT_KEYS = sorted([ # Add more audit keys if needed, e.g., "audit_stage_instance_create" ]) +class NullAccessory(ui.Button): + """Non-interactive accessory used as a placeholder.""" + + def __init__(self) -> None: + super().__init__(label="\u200b", disabled=True) + + def is_dispatchable(self) -> bool: # type: ignore[override] + return False + class LoggingCog(commands.Cog): """Handles comprehensive server event logging via webhooks with granular toggling.""" def __init__(self, bot: commands.Bot): @@ -75,7 +84,7 @@ class LoggingCog(commands.Cog): accessory=( ui.Thumbnail(media=author.display_avatar.url) if author - else None + else NullAccessory() ) ) self.header.add_item(ui.TextDisplay(f"**{title}**")) @@ -101,7 +110,7 @@ class LoggingCog(commands.Cog): def add_field(self, name: str, value: str, inline: bool = False): """Mimic Embed.add_field by appending a bolded name/value line.""" if not self._field_sections or len(self._field_sections[-1].children) >= 3: - section = ui.Section(accessory=None) + section = ui.Section(accessory=NullAccessory()) self._insert_field_section(section) self._field_sections.append(section) self._field_sections[-1].add_item(ui.TextDisplay(f"**{name}:** {value}")) From 0e70380d422c14ec1bd26c5cc3037329563ddee9 Mon Sep 17 00:00:00 2001 From: Codex Date: Fri, 6 Jun 2025 03:27:16 +0000 Subject: [PATCH 04/11] Refactor LoggingCog view layout --- cogs/logging_cog.py | 94 +++++++++++++++++++++++---------------------- 1 file changed, 48 insertions(+), 46 deletions(-) diff --git a/cogs/logging_cog.py b/cogs/logging_cog.py index ee9c7af..fa760e7 100644 --- a/cogs/logging_cog.py +++ b/cogs/logging_cog.py @@ -43,14 +43,6 @@ ALL_EVENT_KEYS = sorted([ # Add more audit keys if needed, e.g., "audit_stage_instance_create" ]) -class NullAccessory(ui.Button): - """Non-interactive accessory used as a placeholder.""" - - def __init__(self) -> None: - super().__init__(label="\u200b", disabled=True) - - def is_dispatchable(self) -> bool: # type: ignore[override] - return False class LoggingCog(commands.Cog): """Handles comprehensive server event logging via webhooks with granular toggling.""" @@ -65,7 +57,7 @@ class LoggingCog(commands.Cog): asyncio.create_task(self.start_audit_log_poller_when_ready()) # Keep this for initial start class LogView(ui.LayoutView): - """Simple view for log messages with helper methods.""" + """View for logging messages using Discord's layout UI.""" def __init__( self, @@ -75,66 +67,76 @@ class LoggingCog(commands.Cog): color: discord.Color, author: Optional[discord.abc.User], footer: Optional[str], - ): + ) -> None: super().__init__(timeout=None) + self.container = ui.Container(accent_colour=color) self.add_item(self.container) - self.header = ui.Section( - accessory=( - ui.Thumbnail(media=author.display_avatar.url) - if author - else NullAccessory() - ) + self.description_display: Optional[ui.TextDisplay] = ( + ui.TextDisplay(description) if description else None ) - self.header.add_item(ui.TextDisplay(f"**{title}**")) - if description: - self.header.add_item(ui.TextDisplay(description)) - self.container.add_item(self.header) - # Placeholder for future field sections. They are inserted before - # the separator when the first field is added. - self._field_sections: list[ui.Section] = [] + # Header section is only used when an author is provided so we don't + # need a placeholder accessory. + if author is not None: + self.header: Optional[ui.Section] = ui.Section( + accessory=ui.Thumbnail(media=author.display_avatar.url) + ) + self.header.add_item(ui.TextDisplay(f"**{title}**")) + if self.description_display: + self.header.add_item(self.description_display) + self.container.add_item(self.header) + else: + self.header = None + self.title_display = ui.TextDisplay(f"**{title}**") + self.container.add_item(self.title_display) + if self.description_display: + self.container.add_item(self.description_display) + + # Container used for fields so they're inserted before the footer. + self.fields_container = ui.Container() + self.container.add_item(self.fields_container) self.separator = ui.Separator(spacing=discord.SeparatorSpacing.small) - footer_text = footer or f"Bot ID: {bot.user.id}" + ( f" | User ID: {author.id}" if author else "" ) self.footer_display = ui.TextDisplay(footer_text) - self.container.add_item(self.separator) self.container.add_item(self.footer_display) # --- Compatibility helpers --- def add_field(self, name: str, value: str, inline: bool = False): - """Mimic Embed.add_field by appending a bolded name/value line.""" - if not self._field_sections or len(self._field_sections[-1].children) >= 3: - section = ui.Section(accessory=NullAccessory()) - self._insert_field_section(section) - self._field_sections.append(section) - self._field_sections[-1].add_item(ui.TextDisplay(f"**{name}:** {value}")) - - def _insert_field_section(self, section: ui.Section) -> None: - """Insert a field section before the footer separator.""" - self.container.remove_item(self.separator) - self.container.remove_item(self.footer_display) - self.container.add_item(section) - self.container.add_item(self.separator) - self.container.add_item(self.footer_display) + """Append a bolded name/value line to the log view.""" + self.fields_container.add_item(ui.TextDisplay(f"**{name}:** {value}")) def set_footer(self, text: str): - """Mimic Embed.set_footer by replacing the footer text display.""" + """Replace the footer text display.""" self.footer_display.content = text def set_author(self, name: str, icon_url: Optional[str] = None): - """Mimic Embed.set_author by adjusting the header section.""" - self.header.clear_items() - if icon_url: - self.header.accessory = ui.Thumbnail(media=icon_url) + """Add or update the author information.""" + if self.header is None: + # Remove plain title/description displays and replace with a section. + self.container.remove_item(self.title_display) + if self.description_display: + self.container.remove_item(self.description_display) + self.header = ui.Section( + accessory=ui.Thumbnail(media=icon_url or "") + ) + self.header.add_item(ui.TextDisplay(name)) + if self.description_display: + self.header.add_item(self.description_display) + self.container.add_item(self.header) + # Move to the beginning to mimic embed header placement + self.container._children.remove(self.header) + self.container._children.insert(0, self.header) else: - self.header.accessory = None - self.header.add_item(ui.TextDisplay(name)) + self.header.clear_items() + if icon_url: + self.header.accessory = ui.Thumbnail(media=icon_url) + self.header.add_item(ui.TextDisplay(name)) def _user_display(self, user: Union[discord.Member, discord.User]) -> str: """Return display name, username and ID string for a user.""" display = user.display_name if isinstance(user, discord.Member) else user.name From a353d79e848b8d0172828610f38aad64532e0c7c Mon Sep 17 00:00:00 2001 From: Slipstream Date: Thu, 5 Jun 2025 21:29:57 -0600 Subject: [PATCH 05/11] Restore point commit From d1ec42fa51adb3fc06254f4a140eb86f677eeb9e Mon Sep 17 00:00:00 2001 From: Slipstream Date: Thu, 5 Jun 2025 21:31:06 -0600 Subject: [PATCH 06/11] big ass formatting --- EXAMPLE.py | 203 +- api_integration.py | 41 +- api_service/api_models.py | 53 +- api_service/api_server.py | 2062 ++++++++---- api_service/code_verifier_store.py | 20 +- api_service/cog_management_endpoints.py | 127 +- .../command_customization_endpoints.py | 147 +- api_service/dashboard_api_endpoints.py | 1061 ++++-- api_service/dashboard_models.py | 53 +- api_service/database.py | 150 +- api_service/dependencies.py | 138 +- api_service/discord_client.py | 77 +- api_service/run_api_server.py | 9 +- api_service/terminal_images_endpoint.py | 28 +- api_service/webhook_endpoints.py | 1074 ++++-- check_sync_dependencies.py | 16 +- cogs/VoiceGatewayCog.py | 566 ++-- cogs/ai_code_agent_cog.py | 1754 +++++++--- cogs/aimod.py | 1233 +++++-- cogs/ban_system_cog.py | 189 +- cogs/bot_appearance_cog.py | 153 +- cogs/caption_cog.py | 219 +- cogs/command_debug_cog.py | 62 +- cogs/command_fix_cog.py | 44 +- cogs/counting_cog.py | 257 +- cogs/dictionary_cog.py | 74 +- cogs/discord_sync_cog.py | 59 +- cogs/economy/database.py | 455 ++- cogs/economy/earning.py | 85 +- cogs/economy/gambling.py | 54 +- cogs/economy/jobs.py | 372 ++- cogs/economy/risky.py | 79 +- cogs/economy/utility.py | 72 +- cogs/economy_cog.py | 303 +- cogs/emoji_cog.py | 238 +- cogs/eval_cog.py | 93 +- cogs/femdom_roleplay_teto_cog.py | 227 +- cogs/femdom_teto_cog.py | 193 +- cogs/fetch_user_cog.py | 6 +- cogs/games/basic_games.py | 98 +- cogs/games/chess_game.py | 1543 ++++++--- cogs/games/coinflip_game.py | 94 +- cogs/games/rps_game.py | 63 +- cogs/games/tictactoe_game.py | 103 +- cogs/games/wordle_game.py | 124 +- cogs/games_cog.py | 384 ++- cogs/gelbooru_watcher_base_cog.py | 1013 ++++-- cogs/gif_optimizer_cog.py | 131 +- cogs/git_monitor_cog.py | 420 ++- cogs/giveaways_cog.py | 481 ++- cogs/help_cog.py | 218 +- cogs/leveling_cog.py | 665 ++-- cogs/lockdown_cog.py | 58 +- cogs/logging_cog.py | 1800 +++++++---- cogs/marriage_cog.py | 167 +- cogs/message_cog.py | 255 +- cogs/message_scraper_cog.py | 22 +- cogs/mod_application_cog.py | 635 ++-- cogs/mod_log_cog.py | 361 ++- cogs/moderation_cog.py | 222 +- cogs/multi_bot_cog.py | 211 +- cogs/neru_message_cog.py | 261 +- cogs/neru_roleplay_cog.py | 28 +- cogs/neru_teto_cog.py | 142 +- cogs/oauth_cog.py | 98 +- cogs/owner_utils_cog.py | 132 +- cogs/owoify_cog.py | 202 +- cogs/ping_cog.py | 4 +- cogs/profile_cog.py | 8 +- cogs/profile_updater_cog.py | 520 ++- cogs/random_cog.py | 106 +- cogs/random_strings_cog.py | 125 +- cogs/random_timeout_cog.py | 139 +- cogs/real_moderation_cog.py | 1090 +++++-- cogs/role_creator_cog.py | 205 +- cogs/role_management_cog.py | 713 ++-- cogs/role_selector_cog.py | 1429 +++++--- cogs/roleplay_cog.py | 23 +- cogs/roleplay_teto_cog.py | 227 +- cogs/rp_messages.py | 21 +- cogs/rule34_cog.py | 336 +- cogs/safebooru_cog.py | 151 +- cogs/settings_cog.py | 568 +++- cogs/shell_command_cog.py | 230 +- cogs/stable_diffusion_cog.py | 184 +- cogs/starboard_cog.py | 408 ++- cogs/status_cog.py | 144 +- cogs/sync_cog.py | 43 +- cogs/system_check_cog.py | 51 +- cogs/terminal_cog.py | 340 +- cogs/teto_cog.py | 643 ++-- cogs/teto_image_cog.py | 37 +- cogs/timer_cog.py | 93 +- cogs/tts_provider_cog.py | 161 +- cogs/upload_cog.py | 153 +- cogs/user_info_cog.py | 118 +- cogs/webdrivertorso_cog.py | 689 ++-- cogs/welcome_cog.py | 214 +- command_customization.py | 50 +- commands.py | 47 +- custom_bot_manager.py | 91 +- db/mod_log_db.py | 267 +- discord_bot_sync_api.py | 339 +- discord_oauth.py | 95 +- download_illustrious.py | 138 +- error_handler.py | 433 ++- flask_server.py | 12 +- global_bot_accessor.py | 8 +- gurt/__init__.py | 2 +- gurt/analysis.py | 1059 ++++-- gurt/api.py | 2351 +++++++++----- gurt/background.py | 523 ++- gurt/cog.py | 570 +++- gurt/commands.py | 1125 +++++-- gurt/config.py | 1279 +++++--- gurt/context.py | 325 +- gurt/emojis.py | 95 +- gurt/extrtools.py | 3 + gurt/listeners.py | 1235 +++++-- gurt/memory.py | 6 +- gurt/prompt.py | 206 +- gurt/tools.py | 2866 ++++++++++++----- gurt/utils.py | 210 +- gurt_bot.py | 33 +- gurt_memory.py | 876 +++-- install_stable_diffusion.py | 46 +- main.py | 393 ++- multi_bot.py | 396 ++- neru_bot.py | 69 +- oauth_server.py | 37 +- print_vertex_schema.py | 50 +- run_additional_bots.py | 18 +- run_femdom_teto_bot.py | 60 +- run_gurt_bot.py | 6 +- run_markdown_server.py | 91 +- run_neru_bot.py | 4 +- run_unified_api.py | 7 +- run_wheatley_bot.py | 8 +- settings_manager.py | 1711 +++++++--- tavilytool.py | 107 +- test_gputil.py | 3 +- test_pagination.py | 10 +- test_part.py | 8 +- test_starboard.py | 17 +- test_timeout_config.py | 13 +- test_url_parser.py | 12 +- test_usage_counters.py | 59 +- tictactoe.py | 77 +- utils.py | 5 +- webdrivertorso_template.py | 315 +- wheatley/__init__.py | 2 +- wheatley/analysis.py | 794 ++++- wheatley/api.py | 1004 ++++-- wheatley/background.py | 85 +- wheatley/cog.py | 355 +- wheatley/commands.py | 490 ++- wheatley/config.py | 462 +-- wheatley/context.py | 231 +- wheatley/listeners.py | 448 ++- wheatley/memory.py | 540 +++- wheatley/prompt.py | 42 +- wheatley/tools.py | 1034 ++++-- wheatley/utils.py | 113 +- wheatley_bot.py | 31 +- 164 files changed, 38243 insertions(+), 16304 deletions(-) diff --git a/EXAMPLE.py b/EXAMPLE.py index 3c3f553..1ddd630 100644 --- a/EXAMPLE.py +++ b/EXAMPLE.py @@ -13,8 +13,9 @@ import numpy as np import nltk from nltk.corpus import words, wordnet -nltk.download('words') -nltk.download('wordnet') +nltk.download("words") +nltk.download("wordnet") + class JSON: def read(file): @@ -26,6 +27,7 @@ class JSON: with open(f"{file}.json", "w", encoding="utf8") as file: json.dump(data, file, indent=4) + config_data = JSON.read("config") # SETTINGS # @@ -42,17 +44,29 @@ max_shapes = config_data["MAX_SHAPES"] sample_rate = config_data["SOUND_QUALITY"] tts_enabled = config_data.get("TTS_ENABLED", True) tts_text = config_data.get("TTS_TEXT", "This is a default text for TTS.") -audio_wave_type = config_data.get("AUDIO_WAVE_TYPE", "sawtooth") # Options: sawtooth, sine, square +audio_wave_type = config_data.get( + "AUDIO_WAVE_TYPE", "sawtooth" +) # Options: sawtooth, sine, square slide_duration = config_data.get("SLIDE_DURATION", 1000) # Duration in milliseconds -deform_level = config_data.get("DEFORM_LEVEL", "none") # Options: none, low, medium, high +deform_level = config_data.get( + "DEFORM_LEVEL", "none" +) # Options: none, low, medium, high color_mode = config_data.get("COLOR_MODE", "random") # Options: random, scheme, solid -color_scheme = config_data.get("COLOR_SCHEME", "default") # Placeholder for color schemes +color_scheme = config_data.get( + "COLOR_SCHEME", "default" +) # Placeholder for color schemes solid_color = config_data.get("SOLID_COLOR", "#FFFFFF") # Default solid color -allowed_shapes = config_data.get("ALLOWED_SHAPES", ["rectangle", "ellipse", "polygon", "triangle", "circle"]) +allowed_shapes = config_data.get( + "ALLOWED_SHAPES", ["rectangle", "ellipse", "polygon", "triangle", "circle"] +) wave_vibe = config_data.get("WAVE_VIBE", "calm") # New config option for wave vibe top_left_text_enabled = config_data.get("TOP_LEFT_TEXT_ENABLED", True) -top_left_text_mode = config_data.get("TOP_LEFT_TEXT_MODE", "random") # Options: random, word -words_topic = config_data.get("WORDS_TOPIC", "random") # Options: random, introspective, action, nature, technology +top_left_text_mode = config_data.get( + "TOP_LEFT_TEXT_MODE", "random" +) # Options: random, word +words_topic = config_data.get( + "WORDS_TOPIC", "random" +) # Options: random, introspective, action, nature, technology # Vibe presets for wave sound wave_vibes = { @@ -65,34 +79,107 @@ wave_vibes = { } color_schemes = { - "pastel": [(255, 182, 193), (176, 224, 230), (240, 230, 140), (221, 160, 221), (152, 251, 152)], - "dark_gritty": [(47, 79, 79), (105, 105, 105), (0, 0, 0), (85, 107, 47), (139, 69, 19)], - "nature": [(34, 139, 34), (107, 142, 35), (46, 139, 87), (32, 178, 170), (154, 205, 50)], + "pastel": [ + (255, 182, 193), + (176, 224, 230), + (240, 230, 140), + (221, 160, 221), + (152, 251, 152), + ], + "dark_gritty": [ + (47, 79, 79), + (105, 105, 105), + (0, 0, 0), + (85, 107, 47), + (139, 69, 19), + ], + "nature": [ + (34, 139, 34), + (107, 142, 35), + (46, 139, 87), + (32, 178, 170), + (154, 205, 50), + ], "vibrant": [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)], - "ocean": [(0, 105, 148), (72, 209, 204), (70, 130, 180), (135, 206, 250), (176, 224, 230)] + "ocean": [ + (0, 105, 148), + (72, 209, 204), + (70, 130, 180), + (135, 206, 250), + (176, 224, 230), + ], } # Font scaling based on video size font_size = max(w, h) // 40 # Scales font size to make it smaller and more readable fnt = ImageFont.truetype("./FONT/sys.ttf", font_size) -files = glob.glob('./IMG/*') +files = glob.glob("./IMG/*") for f in files: os.remove(f) print("REMOVED OLD FILES") -def generate_string(length, charset="abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"): + +def generate_string( + length, charset="abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +): result = "" for i in range(length): result += random.choice(charset) return result + # Predefined word lists for specific topics -introspective_words = ["reflection", "thought", "solitude", "ponder", "meditation", "introspection", "awareness", "contemplation", "silence", "stillness"] -action_words = ["run", "jump", "climb", "race", "fight", "explore", "build", "create", "overcome", "achieve"] -nature_words = ["tree", "mountain", "river", "ocean", "flower", "forest", "animal", "sky", "valley", "meadow"] -technology_words = ["computer", "robot", "network", "data", "algorithm", "innovation", "digital", "machine", "software", "hardware"] +introspective_words = [ + "reflection", + "thought", + "solitude", + "ponder", + "meditation", + "introspection", + "awareness", + "contemplation", + "silence", + "stillness", +] +action_words = [ + "run", + "jump", + "climb", + "race", + "fight", + "explore", + "build", + "create", + "overcome", + "achieve", +] +nature_words = [ + "tree", + "mountain", + "river", + "ocean", + "flower", + "forest", + "animal", + "sky", + "valley", + "meadow", +] +technology_words = [ + "computer", + "robot", + "network", + "data", + "algorithm", + "innovation", + "digital", + "machine", + "software", + "hardware", +] + def generate_word(theme="random"): if theme == "introspective": @@ -108,10 +195,8 @@ def generate_word(theme="random"): else: return "unknown_theme" -def append_wave( - freq=None, - duration_milliseconds=1000, - volume=1.0): + +def append_wave(freq=None, duration_milliseconds=1000, volume=1.0): global audio @@ -122,7 +207,9 @@ def append_wave( modulation = random.uniform(0.1, 1.0) else: base_freq = vibe_params["frequency"] - freq = random.uniform(base_freq * 0.7, base_freq * 1.3) if freq is None else freq + freq = ( + random.uniform(base_freq * 0.7, base_freq * 1.3) if freq is None else freq + ) amplitude = vibe_params["amplitude"] * random.uniform(0.7, 1.3) modulation = vibe_params["modulation"] * random.uniform(0.6, 1.4) @@ -130,10 +217,13 @@ def append_wave( for x in range(int(num_samples)): wave_sample = amplitude * math.sin(2 * math.pi * freq * (x / sample_rate)) - modulated_sample = wave_sample * (1 + modulation * math.sin(2 * math.pi * 0.5 * x / sample_rate)) + modulated_sample = wave_sample * ( + 1 + modulation * math.sin(2 * math.pi * 0.5 * x / sample_rate) + ) audio.append(volume * modulated_sample) return + def save_wav(file_name): wav_file = wave.open(file_name, "w") @@ -147,18 +237,20 @@ def save_wav(file_name): wav_file.setparams((nchannels, sampwidth, sample_rate, nframes, comptype, compname)) for sample in audio: - wav_file.writeframes(struct.pack('h', int(sample * 32767.0))) + wav_file.writeframes(struct.pack("h", int(sample * 32767.0))) wav_file.close() return + # Generate TTS audio using gTTS def generate_tts_audio(text, output_file): - tts = gTTS(text=text, lang='en') + tts = gTTS(text=text, lang="en") tts.save(output_file) print(f"TTS audio saved to {output_file}") + if tts_enabled: tts_audio_file = "./SOUND/tts_output.mp3" generate_tts_audio(tts_text, tts_audio_file) @@ -191,26 +283,37 @@ for xyz in range(AMOUNT): y2 = random.randint(minH, maxH) if color_mode == "random": - color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) + color = ( + random.randint(0, 255), + random.randint(0, 255), + random.randint(0, 255), + ) elif color_mode == "scheme": scheme_colors = color_schemes.get(color_scheme, [(128, 128, 128)]) color = random.choice(scheme_colors) elif color_mode == "solid": - color = tuple(int(solid_color.lstrip("#")[i:i + 2], 16) for i in (0, 2, 4)) + color = tuple( + int(solid_color.lstrip("#")[i : i + 2], 16) for i in (0, 2, 4) + ) if shape_type == "rectangle": - img1.rectangle([(x1, y1), (x1 + x2, y1 + y2)], fill=color, outline=color) + img1.rectangle( + [(x1, y1), (x1 + x2, y1 + y2)], fill=color, outline=color + ) elif shape_type == "ellipse": img1.ellipse([(x1, y1), (x1 + x2, y1 + y2)], fill=color, outline=color) elif shape_type == "polygon": num_points = random.randint(3, 6) - points = [(random.randint(0, w), random.randint(0, h)) for _ in range(num_points)] + points = [ + (random.randint(0, w), random.randint(0, h)) + for _ in range(num_points) + ] img1.polygon(points, fill=color, outline=color) elif shape_type == "triangle": points = [ (x1, y1), (x1 + random.randint(-x2, x2), y1 + y2), - (x1 + x2, y1 + random.randint(-y2, y2)) + (x1 + x2, y1 + random.randint(-y2, y2)), ] img1.polygon(points, fill=color, outline=color) elif shape_type == "star": @@ -225,11 +328,18 @@ for xyz in range(AMOUNT): img1.polygon(points, fill=color, outline=color) elif shape_type == "circle": radius = min(x2, y2) // 2 - img1.ellipse([(x1 - radius, y1 - radius), (x1 + radius, y1 + radius)], fill=color, outline=color) + img1.ellipse( + [(x1 - radius, y1 - radius), (x1 + radius, y1 + radius)], + fill=color, + outline=color, + ) if top_left_text_enabled: if top_left_text_mode == "random": - random_top_left_text = generate_string(30, charset="abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()_+-=[]{}|;:',.<>?/") + random_top_left_text = generate_string( + 30, + charset="abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()_+-=[]{}|;:',.<>?/", + ) elif top_left_text_mode == "word": random_top_left_text = generate_word(words_topic) else: @@ -240,7 +350,9 @@ for xyz in range(AMOUNT): video_name_text = f"{video_name}.mp4" video_name_width = img1.textlength(video_name_text, font=fnt) video_name_height = font_size - img1.text((10, h - video_name_height - 10), video_name_text, font=fnt, fill="black") + img1.text( + (10, h - video_name_height - 10), video_name_text, font=fnt, fill="black" + ) # Move slide info text to the top right corner slide_text = f"Slide {i}" @@ -273,10 +385,13 @@ for xyz in range(AMOUNT): print("MP3 GENERATED") - image_folder = './IMG' + image_folder = "./IMG" fps = 1000 / slide_duration # Ensure fps is precise to handle timing discrepancies - image_files = sorted([f for f in glob.glob(f"{image_folder}/*.png")], key=lambda x: int(os.path.basename(x).split('_')[0])) + image_files = sorted( + [f for f in glob.glob(f"{image_folder}/*.png")], + key=lambda x: int(os.path.basename(x).split("_")[0]), + ) # Ensure all frames have the same dimensions frames = [] @@ -284,19 +399,19 @@ for xyz in range(AMOUNT): for idx, file in enumerate(image_files): frame = np.array(Image.open(file)) if frame.shape != first_frame.shape: - print(f"Frame {idx} has inconsistent dimensions: {frame.shape} vs {first_frame.shape}") + print( + f"Frame {idx} has inconsistent dimensions: {frame.shape} vs {first_frame.shape}" + ) frame = np.resize(frame, first_frame.shape) # Resize if necessary frames.append(frame) print("Starting video compilation...") - clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip( - frames, fps=fps - ) + clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(frames, fps=fps) clip.write_videofile( - f'./OUTPUT/{video_name}.mp4', - audio="./SOUND/output.m4a", - codec="libx264", - audio_codec="aac" + f"./OUTPUT/{video_name}.mp4", + audio="./SOUND/output.m4a", + codec="libx264", + audio_codec="aac", ) print("Video compilation finished successfully!") diff --git a/api_integration.py b/api_integration.py index a67b984..5d406f9 100644 --- a/api_integration.py +++ b/api_integration.py @@ -6,7 +6,7 @@ import sys import json # Add the api_service directory to the Python path -sys.path.append(os.path.join(os.path.dirname(__file__), 'api_service')) +sys.path.append(os.path.join(os.path.dirname(__file__), "api_service")) # Import the API client and models from api_service.discord_client import ApiClient @@ -15,6 +15,7 @@ from api_service.api_models import Conversation, UserSettings, Message # API client instance api_client = None + # Initialize the API client def init_api_client(api_url: str): """Initialize the API client with the given URL""" @@ -22,6 +23,7 @@ def init_api_client(api_url: str): api_client = ApiClient(api_url) return api_client + # Set the Discord token for the API client def set_token(token: str): """Set the Discord token for the API client""" @@ -30,22 +32,25 @@ def set_token(token: str): else: raise ValueError("API client not initialized") + # ============= Conversation Methods ============= + async def get_user_conversations(user_id: str, token: str) -> List[Conversation]: """Get all conversations for a user""" if not api_client: raise ValueError("API client not initialized") - + # Set the token for this request api_client.set_token(token) - + try: return await api_client.get_conversations() except Exception as e: print(f"Error getting conversations for user {user_id}: {e}") return [] + async def save_discord_conversation( user_id: str, token: str, @@ -58,15 +63,15 @@ async def save_discord_conversation( temperature: float = 0.7, max_tokens: int = 1000, web_search_enabled: bool = False, - system_message: Optional[str] = None + system_message: Optional[str] = None, ) -> Optional[Conversation]: """Save a conversation from Discord to the API""" if not api_client: raise ValueError("API client not initialized") - + # Set the token for this request api_client.set_token(token) - + try: return await api_client.save_discord_conversation( messages=messages, @@ -78,48 +83,51 @@ async def save_discord_conversation( temperature=temperature, max_tokens=max_tokens, web_search_enabled=web_search_enabled, - system_message=system_message + system_message=system_message, ) except Exception as e: print(f"Error saving conversation for user {user_id}: {e}") return None + # ============= Settings Methods ============= + async def get_user_settings(user_id: str, token: str) -> Optional[UserSettings]: """Get settings for a user""" if not api_client: raise ValueError("API client not initialized") - + # Set the token for this request api_client.set_token(token) - + try: return await api_client.get_settings() except Exception as e: print(f"Error getting settings for user {user_id}: {e}") return None + async def update_user_settings( - user_id: str, - token: str, - settings: UserSettings + user_id: str, token: str, settings: UserSettings ) -> Optional[UserSettings]: """Update settings for a user""" if not api_client: raise ValueError("API client not initialized") - + # Set the token for this request api_client.set_token(token) - + try: return await api_client.update_settings(settings) except Exception as e: print(f"Error updating settings for user {user_id}: {e}") return None + # ============= Helper Methods ============= + def convert_discord_settings_to_api(settings: Dict[str, Any]) -> UserSettings: """Convert Discord bot settings to API UserSettings""" return UserSettings( @@ -136,9 +144,10 @@ def convert_discord_settings_to_api(settings: Dict[str, Any]) -> UserSettings: custom_instructions=settings.get("custom_instructions"), advanced_view_enabled=False, # Default value streaming_enabled=True, # Default value - last_updated=datetime.datetime.now() + last_updated=datetime.datetime.now(), ) + def convert_api_settings_to_discord(settings: UserSettings) -> Dict[str, Any]: """Convert API UserSettings to Discord bot settings""" return { @@ -152,5 +161,5 @@ def convert_api_settings_to_discord(settings: UserSettings) -> Dict[str, Any]: "character": settings.character, "character_info": settings.character_info, "character_breakdown": settings.character_breakdown, - "custom_instructions": settings.custom_instructions + "custom_instructions": settings.custom_instructions, } diff --git a/api_service/api_models.py b/api_service/api_models.py index fba647d..2d0db9c 100644 --- a/api_service/api_models.py +++ b/api_service/api_models.py @@ -5,6 +5,7 @@ import uuid # ============= Data Models ============= + class Message(BaseModel): content: str role: str # "user", "assistant", or "system" @@ -12,6 +13,7 @@ class Message(BaseModel): reasoning: Optional[str] = None usage_data: Optional[Dict[str, Any]] = None + class Conversation(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) title: str @@ -28,8 +30,10 @@ class Conversation(BaseModel): web_search_enabled: bool = False system_message: Optional[str] = None + class ThemeSettings(BaseModel): """Theme settings for the dashboard UI""" + theme_mode: str = "light" # "light", "dark", "custom" primary_color: str = "#5865F2" # Discord blue secondary_color: str = "#2D3748" @@ -37,6 +41,7 @@ class ThemeSettings(BaseModel): font_family: str = "Inter, sans-serif" custom_css: Optional[str] = None + class UserSettings(BaseModel): # General settings model_id: str = "openai/gpt-3.5-turbo" @@ -71,74 +76,100 @@ class UserSettings(BaseModel): custom_bot_enabled: bool = False custom_bot_prefix: str = "!" custom_bot_status_text: str = "!help" - custom_bot_status_type: str = "listening" # "playing", "listening", "watching", "competing" + custom_bot_status_type: str = ( + "listening" # "playing", "listening", "watching", "competing" + ) # Last updated timestamp last_updated: datetime.datetime = Field(default_factory=datetime.datetime.now) + # ============= Role Selector Models ============= + class RoleOption(BaseModel): """Represents a single selectable role within a category preset.""" + role_id: str # Discord Role ID name: str emoji: Optional[str] = None + class RoleCategoryPreset(BaseModel): """Represents a global preset for a role category.""" - id: str = Field(default_factory=lambda: str(uuid.uuid4())) # Unique ID for the preset category + + id: str = Field( + default_factory=lambda: str(uuid.uuid4()) + ) # Unique ID for the preset category name: str # e.g., "Colors", "Pronouns" description: str roles: List[RoleOption] = [] max_selectable: int = 1 - display_order: int = 0 # For ordering presets if listed + display_order: int = 0 # For ordering presets if listed + class GuildRole(BaseModel): """Represents a specific role configured by a guild for selection.""" + role_id: str # Discord Role ID name: str emoji: Optional[str] = None + class GuildRoleCategoryConfig(BaseModel): """Represents a guild's specific configuration for a role selection category.""" + guild_id: str - category_id: str = Field(default_factory=lambda: str(uuid.uuid4())) # Unique ID for this guild's category instance - name: str # Custom name or preset name + category_id: str = Field( + default_factory=lambda: str(uuid.uuid4()) + ) # Unique ID for this guild's category instance + name: str # Custom name or preset name description: str roles: List[GuildRole] = [] max_selectable: int = 1 - message_id: Optional[str] = None # Discord message ID of the selector embed - channel_id: Optional[str] = None # Discord channel ID where the selector embed is posted - is_preset: bool = False # True if this category is based on a global preset - preset_id: Optional[str] = None # If is_preset, this links to RoleCategoryPreset.id + message_id: Optional[str] = None # Discord message ID of the selector embed + channel_id: Optional[str] = ( + None # Discord channel ID where the selector embed is posted + ) + is_preset: bool = False # True if this category is based on a global preset + preset_id: Optional[str] = None # If is_preset, this links to RoleCategoryPreset.id + class UserCustomColorRole(BaseModel): """Represents a user's custom color role.""" + user_id: str guild_id: str - role_id: str # Discord Role ID of their custom color role - hex_color: str # e.g., "#RRGGBB" + role_id: str # Discord Role ID of their custom color role + hex_color: str # e.g., "#RRGGBB" last_updated: datetime.datetime = Field(default_factory=datetime.datetime.now) + # ============= API Request/Response Models ============= + class GetConversationsResponse(BaseModel): conversations: List[Conversation] + class GetSettingsResponse(BaseModel): settings: UserSettings + class UpdateSettingsRequest(BaseModel): settings: UserSettings + class UpdateConversationRequest(BaseModel): conversation: Conversation + class ApiResponse(BaseModel): success: bool message: str data: Optional[Any] = None + class NumberData(BaseModel): card_number: str expiry_date: str diff --git a/api_service/api_server.py b/api_service/api_server.py index 673907b..f48c033 100644 --- a/api_service/api_server.py +++ b/api_service/api_server.py @@ -3,16 +3,30 @@ import json import sys import asyncio from typing import Dict, List, Optional, Any -from fastapi import FastAPI, HTTPException, Depends, Header, Request, Response, status, Body +from fastapi import ( + FastAPI, + HTTPException, + Depends, + Header, + Request, + Response, + status, + Body, +) from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import HTMLResponse, PlainTextResponse, RedirectResponse, FileResponse +from fastapi.responses import ( + HTMLResponse, + PlainTextResponse, + RedirectResponse, + FileResponse, +) from fastapi.staticfiles import StaticFiles from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.sessions import SessionMiddleware import aiohttp import asyncpg import discord -from api_service.database import Database # Existing DB +from api_service.database import Database # Existing DB import logging from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic import BaseModel, Field @@ -28,23 +42,21 @@ from starlette.exceptions import HTTPException as StarletteHTTPException log = logging.getLogger("api_server") logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.StreamHandler(), - logging.FileHandler("api_server.log") - ] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(), logging.FileHandler("api_server.log")], ) # --- Configuration Loading --- # Determine the path to the .env file relative to this api_server.py file # Go up one level from api_service/ to the project root, then into discordbot/ -dotenv_path = os.path.join(os.path.dirname(__file__), '..', 'discordbot', '.env') +dotenv_path = os.path.join(os.path.dirname(__file__), "..", "discordbot", ".env") + class ApiSettings(BaseSettings): # Existing API settings (if any were loaded from env before) GURT_STATS_PUSH_SECRET: Optional[str] = None - API_HOST: str = "0.0.0.0" # Keep existing default if used - API_PORT: int = 8001 # Changed default port to 8001 + API_HOST: str = "0.0.0.0" # Keep existing default if used + API_PORT: int = 8001 # Changed default port to 8001 SSL_CERT_FILE: Optional[str] = None SSL_KEY_FILE: Optional[str] = None @@ -55,33 +67,37 @@ class ApiSettings(BaseSettings): DISCORD_BOT_TOKEN: Optional[str] = None # Add bot token for API calls (optional) # Secret key for dashboard session management - DASHBOARD_SECRET_KEY: str = "a_default_secret_key_for_development_only" # Provide a default for dev + DASHBOARD_SECRET_KEY: str = ( + "a_default_secret_key_for_development_only" # Provide a default for dev + ) # Database/Redis settings (Required for settings_manager) POSTGRES_USER: str POSTGRES_PASSWORD: str POSTGRES_HOST: str - POSTGRES_SETTINGS_DB: str # The specific DB for settings + POSTGRES_SETTINGS_DB: str # The specific DB for settings REDIS_HOST: str REDIS_PORT: int = 6379 - REDIS_PASSWORD: Optional[str] = None # Optional + REDIS_PASSWORD: Optional[str] = None # Optional # Secret key for AI Moderation API endpoint MOD_LOG_API_SECRET: Optional[str] = None AI_API_KEY: Optional[str] = None model_config = SettingsConfigDict( - env_file=dotenv_path, - env_file_encoding='utf-8', - extra='ignore' + env_file=dotenv_path, env_file_encoding="utf-8", extra="ignore" ) + @lru_cache() def get_api_settings() -> ApiSettings: if not os.path.exists(dotenv_path): - print(f"Warning: .env file not found at {dotenv_path}. Using defaults or environment variables.") + print( + f"Warning: .env file not found at {dotenv_path}. Using defaults or environment variables." + ) return ApiSettings() + settings = get_api_settings() # --- Constants derived from settings --- @@ -89,7 +105,9 @@ DISCORD_API_BASE_URL = "https://discord.com/api/v10" DISCORD_API_ENDPOINT = DISCORD_API_BASE_URL # Alias for backward compatibility # Define dashboard-specific redirect URI -DASHBOARD_REDIRECT_URI = f"{settings.DISCORD_REDIRECT_URI.split('/api')[0]}/dashboard/api/auth/callback" +DASHBOARD_REDIRECT_URI = ( + f"{settings.DISCORD_REDIRECT_URI.split('/api')[0]}/dashboard/api/auth/callback" +) # We'll generate the full auth URL with PKCE parameters in the dashboard_login function # This is just a base URL without the PKCE parameters @@ -117,7 +135,10 @@ DISCORD_AUTH_URL = DISCORD_AUTH_BASE_URL latest_gurt_stats: Optional[Dict[str, Any]] = None # GURT_STATS_PUSH_SECRET is now loaded via ApiSettings if not settings.GURT_STATS_PUSH_SECRET: - print("Warning: GURT_STATS_PUSH_SECRET not set. Internal stats update endpoint will be insecure.") + print( + "Warning: GURT_STATS_PUSH_SECRET not set. Internal stats update endpoint will be insecure." + ) + # --- Helper Functions --- async def get_guild_name_from_api(guild_id: int, timeout: float = 5.0) -> str: @@ -142,21 +163,23 @@ async def get_guild_name_from_api(guild_id: int, timeout: float = 5.0) -> str: session = http_session if http_session else aiohttp.ClientSession() # Headers for the request - headers = {'Authorization': f'Bot {settings.DISCORD_BOT_TOKEN}'} + headers = {"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}"} # Send the request with a timeout async with session.get( f"https://discord.com/api/v10/guilds/{guild_id}", headers=headers, - timeout=timeout + timeout=timeout, ) as response: if response.status == 200: guild_data = await response.json() - guild_name = guild_data.get('name', fallback) + guild_name = guild_data.get("name", fallback) log.info(f"Retrieved guild name '{guild_name}' for guild ID {guild_id}") return guild_name else: - log.warning(f"Failed to get guild name for guild ID {guild_id}: HTTP {response.status}") + log.warning( + f"Failed to get guild name for guild ID {guild_id}: HTTP {response.status}" + ) return fallback except asyncio.TimeoutError: log.error(f"Timeout getting guild name for guild ID {guild_id}") @@ -165,7 +188,10 @@ async def get_guild_name_from_api(guild_id: int, timeout: float = 5.0) -> str: log.error(f"Error getting guild name for guild ID {guild_id}: {e}") return fallback -async def send_discord_message_via_api(channel_id: int, content: str, timeout: float = 5.0) -> Dict[str, Any]: + +async def send_discord_message_via_api( + channel_id: int, content: str, timeout: float = 5.0 +) -> Dict[str, Any]: """ Send a message to a Discord channel using Discord's REST API directly. This avoids using Discord.py's channel.send() method which can cause issues with FastAPI. @@ -182,7 +208,7 @@ async def send_discord_message_via_api(channel_id: int, content: str, timeout: f return { "success": False, "message": "Discord bot token not configured", - "error": "no_token" + "error": "no_token", } # Discord API endpoint for sending messages @@ -191,20 +217,22 @@ async def send_discord_message_via_api(channel_id: int, content: str, timeout: f # Headers for the request headers = { "Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}", - "Content-Type": "application/json" + "Content-Type": "application/json", } # Message data - allow for complex payloads (like embeds) data: Dict[str, Any] if isinstance(content, str): data = {"content": content} - elif isinstance(content, dict): # Assuming dict means it's a full payload like {"embeds": [...]} + elif isinstance( + content, dict + ): # Assuming dict means it's a full payload like {"embeds": [...]} data = content else: return { "success": False, "message": "Invalid content type for sending message. Must be string or dict.", - "error": "invalid_content_type" + "error": "invalid_content_type", } log.debug(f"Sending message to channel {channel_id} with data: {data}") @@ -216,14 +244,16 @@ async def send_discord_message_via_api(channel_id: int, content: str, timeout: f session = http_session if http_session else aiohttp.ClientSession() # Send the request with a timeout - async with session.post(url, headers=headers, json=data, timeout=timeout) as response: + async with session.post( + url, headers=headers, json=data, timeout=timeout + ) as response: if response.status == 200 or response.status == 201: # Message sent successfully response_data = await response.json() return { "success": True, "message": "Message sent successfully", - "message_id": response_data.get("id") + "message_id": response_data.get("id"), } elif response.status == 403: # Missing permissions @@ -231,7 +261,7 @@ async def send_discord_message_via_api(channel_id: int, content: str, timeout: f "success": False, "message": "Missing permissions to send message to this channel", "error": "forbidden", - "status": response.status + "status": response.status, } elif response.status == 429: # Rate limited @@ -242,7 +272,7 @@ async def send_discord_message_via_api(channel_id: int, content: str, timeout: f "message": f"Rate limited by Discord API. Retry after {retry_after} seconds", "error": "rate_limited", "retry_after": retry_after, - "status": response.status + "status": response.status, } else: # Other error @@ -253,27 +283,27 @@ async def send_discord_message_via_api(channel_id: int, content: str, timeout: f "message": f"Discord API error: {response.status}", "error": "api_error", "status": response.status, - "details": response_data + "details": response_data, } except: return { "success": False, "message": f"Discord API error: {response.status}", "error": "api_error", - "status": response.status + "status": response.status, } except asyncio.TimeoutError: return { "success": False, "message": "Timeout sending message to Discord API", - "error": "timeout" + "error": "timeout", } except Exception as e: return { "success": False, "message": f"Error sending message: {str(e)}", "error": "unknown", - "details": str(e) + "details": str(e), } try: @@ -285,13 +315,15 @@ async def send_discord_message_via_api(channel_id: int, content: str, timeout: f "success": False, "message": f"Error sending message: {str(e)}", "error": "task_error", - "details": str(e) + "details": str(e), } + + # --------------------------------- # Import dependencies after defining settings and constants # Use absolute imports to avoid issues when running the server directly -from api_service import dependencies # type: ignore +from api_service import dependencies # type: ignore from api_service.api_models import ( Conversation, NumberData, @@ -299,26 +331,30 @@ from api_service.api_models import ( GetConversationsResponse, UpdateSettingsRequest, UpdateConversationRequest, - ApiResponse + ApiResponse, ) import api_service.code_verifier_store as code_verifier_store # Ensure discordbot is in path to import settings_manager -discordbot_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +discordbot_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) if discordbot_path not in sys.path: sys.path.insert(0, discordbot_path) try: - import settings_manager # type: ignore # type: ignore + import settings_manager # type: ignore # type: ignore from global_bot_accessor import get_bot_instance + log.info("Successfully imported settings_manager module and get_bot_instance") except ImportError as e: log.error(f"Could not import settings_manager or get_bot_instance: {e}") - log.error("Ensure the API is run from the project root or discordbot is in PYTHONPATH.") - settings_manager = None # Set to None to indicate failure + log.error( + "Ensure the API is run from the project root or discordbot is in PYTHONPATH." + ) + settings_manager = None # Set to None to indicate failure # ============= API Setup ============= + # Define lifespan context manager for FastAPI @asynccontextmanager async def lifespan(_: FastAPI): # Underscore indicates unused but required parameter @@ -335,7 +371,7 @@ async def lifespan(_: FastAPI): # Underscore indicates unused but required para # Start aiohttp session http_session = aiohttp.ClientSession() log.info("aiohttp session started.") - dependencies.set_http_session(http_session) # Pass session to dependencies module + dependencies.set_http_session(http_session) # Pass session to dependencies module log.info("aiohttp session passed to dependencies module.") # Initialize settings_manager pools for the API server @@ -379,7 +415,9 @@ async def lifespan(_: FastAPI): # Underscore indicates unused but required para # The bot (main.py) is responsible for setting the global pools in settings_manager. # API server will use its own pools from app.state and pass them explicitly if needed. if not settings_manager: - log.error("settings_manager not imported. API endpoints requiring it may fail.") + log.error( + "settings_manager not imported. API endpoints requiring it may fail." + ) except Exception as e: log.exception(f"Failed to initialize API server's connection pools: {e}") @@ -387,8 +425,7 @@ async def lifespan(_: FastAPI): # Underscore indicates unused but required para app.state.pg_pool = None app.state.redis_pool = None - - yield # Lifespan part 1 ends here + yield # Lifespan part 1 ends here # Shutdown: Clean up resources log.info("Shutting down API server...") @@ -403,7 +440,7 @@ async def lifespan(_: FastAPI): # Underscore indicates unused but required para log.info("API Server's PostgreSQL pool closed.") app.state.pg_pool = None if app.state.redis_pool: - await app.state.redis_pool.close() # Assuming redis pool has a close method + await app.state.redis_pool.close() # Assuming redis pool has a close method log.info("API Server's Redis pool closed.") app.state.redis_pool = None @@ -412,14 +449,16 @@ async def lifespan(_: FastAPI): # Underscore indicates unused but required para await http_session.close() log.info("aiohttp session closed.") + # Create the FastAPI app with lifespan app = FastAPI(title="Unified API Service", lifespan=lifespan, debug=True) + @app.exception_handler(StarletteHTTPException) async def teapot_override(request: Request, exc: StarletteHTTPException): try: # Get path from scope, strip trailing slash, and lowercase - request_path_from_scope = request.scope.get('path', "") + request_path_from_scope = request.scope.get("path", "") # Ensure it's a string before calling rstrip if not isinstance(request_path_from_scope, str): request_path_from_scope = str(request_path_from_scope) @@ -427,7 +466,9 @@ async def teapot_override(request: Request, exc: StarletteHTTPException): path_processed = request_path_from_scope.rstrip("/").lower() except Exception as e: - log.error(f"Error accessing/processing request.scope['path'] in teapot_override: {e}, falling back to request.url.path") + log.error( + f"Error accessing/processing request.scope['path'] in teapot_override: {e}, falling back to request.url.path" + ) # Fallback, also strip trailing slash and lowercase url_path_str = str(request.url.path) path_processed = url_path_str.rstrip("/").lower() @@ -436,10 +477,12 @@ async def teapot_override(request: Request, exc: StarletteHTTPException): exact_openrouterkey_paths_normalized = [ "/openrouterkey", "/api/openrouterkey", - "/discordapi/openrouterkey" + "/discordapi/openrouterkey", ] - is_openrouterkey_related_path_match = path_processed in exact_openrouterkey_paths_normalized + is_openrouterkey_related_path_match = ( + path_processed in exact_openrouterkey_paths_normalized + ) # Enhanced logging to understand the decision process log.info( @@ -484,23 +527,27 @@ async def teapot_override(request: Request, exc: StarletteHTTPException): return HTMLResponse(content=html_content, status_code=418) raise exc + @app.get("/robots.txt", response_class=PlainTextResponse) async def robots_txt(): return """User-agent: * Disallow: / """ + # Add Session Middleware for Dashboard Auth # Uses DASHBOARD_SECRET_KEY from settings app.add_middleware( SessionMiddleware, secret_key=settings.DASHBOARD_SECRET_KEY, - session_cookie="dashboard_session", # Use a distinct cookie name - max_age=60 * 60 * 24 * 7 # 7 days expiry + session_cookie="dashboard_session", # Use a distinct cookie name + max_age=60 * 60 * 24 * 7, # 7 days expiry ) # Create a sub-application for the API with /api prefix -api_app = FastAPI(title="Unified API Service", docs_url="/docs", openapi_url="/openapi.json") +api_app = FastAPI( + title="Unified API Service", docs_url="/docs", openapi_url="/openapi.json" +) # Create a sub-application for backward compatibility with /discordapi prefix # This will be deprecated in the future @@ -508,20 +555,21 @@ discordapi_app = FastAPI( title="Discord Bot Sync API (DEPRECATED)", docs_url="/docs", openapi_url="/openapi.json", - description="This API is deprecated and will be removed in the future. Please use the /api endpoint instead." + description="This API is deprecated and will be removed in the future. Please use the /api endpoint instead.", ) # Create a sub-application for the new Dashboard API dashboard_api_app = FastAPI( title="Bot Dashboard API", - docs_url="/docs", # Can have its own docs - openapi_url="/openapi.json" + docs_url="/docs", # Can have its own docs + openapi_url="/openapi.json", ) # Import dashboard API endpoints try: # Use absolute import - from api_service.dashboard_api_endpoints import router as dashboard_router # type: ignore + from api_service.dashboard_api_endpoints import router as dashboard_router # type: ignore + # Add the dashboard router to the dashboard API app dashboard_api_app.include_router(dashboard_router) log.info("Dashboard API endpoints loaded successfully") @@ -536,9 +584,12 @@ except ImportError as e: # Import command customization models and endpoints try: # Use absolute import - from api_service.command_customization_endpoints import router as customization_router # type: ignore + from api_service.command_customization_endpoints import router as customization_router # type: ignore + # Add the command customization router to the dashboard API app - dashboard_api_app.include_router(customization_router, prefix="/commands", tags=["Command Customization"]) + dashboard_api_app.include_router( + customization_router, prefix="/commands", tags=["Command Customization"] + ) log.info("Command customization endpoints loaded successfully") except ImportError as e: log.error(f"Could not import command customization endpoints: {e}") @@ -547,7 +598,8 @@ except ImportError as e: # Import cog management endpoints try: # Use absolute import - from api_service.cog_management_endpoints import router as cog_management_router # type: ignore + from api_service.cog_management_endpoints import router as cog_management_router # type: ignore + log.info("Successfully imported cog_management_endpoints") # Add the cog management router to the dashboard API app dashboard_api_app.include_router(cog_management_router, tags=["Cog Management"]) @@ -557,16 +609,19 @@ except ImportError as e: # Try to import the module directly to see what's available (for debugging) try: import sys + log.info(f"Python path: {sys.path}") # Try to find the module in the current directory import os + current_dir = os.path.dirname(os.path.abspath(__file__)) log.info(f"Current directory: {current_dir}") files = os.listdir(current_dir) log.info(f"Files in current directory: {files}") # Try to import the module with a full path sys.path.append(current_dir) - import cog_management_endpoints # type: ignore + import cog_management_endpoints # type: ignore + log.info(f"Successfully imported cog_management_endpoints module directly") except Exception as e_debug: log.error(f"Debug import failed: {e_debug}") @@ -576,16 +631,22 @@ except ImportError as e: # Mount the API apps at their respective paths app.mount("/api", api_app) app.mount("/discordapi", discordapi_app) -app.mount("/dashboard/api", dashboard_api_app) # Mount the new dashboard API +app.mount("/dashboard/api", dashboard_api_app) # Mount the new dashboard API # Import and mount webhook endpoints try: - from api_service.webhook_endpoints import router as webhook_router # Relative import - app.mount("/webhook", webhook_router) # Mount directly on the main app for simplicity + from api_service.webhook_endpoints import ( + router as webhook_router, + ) # Relative import + + app.mount( + "/webhook", webhook_router + ) # Mount directly on the main app for simplicity # Import and mount terminal images endpoint try: from api_service.terminal_images_endpoint import mount_terminal_images + # Mount terminal images directory as static files mount_terminal_images(app) log.info("Terminal images endpoint mounted successfully") @@ -595,6 +656,7 @@ try: # After mounting the webhook router log.info("Available routes in webhook_router:") from fastapi.routing import APIRoute, Mount + for route in webhook_router.routes: if isinstance(route, APIRoute): log.info(f" {route.path} - {route.name} - {route.methods}") @@ -610,13 +672,15 @@ except ImportError as e: # Attempt to find the module for debugging try: import sys + log.info(f"Python path: {sys.path}") import os + current_dir = os.path.dirname(os.path.abspath(__file__)) log.info(f"Current directory for webhook_endpoints: {current_dir}") files_in_current_dir = os.listdir(current_dir) log.info(f"Files in {current_dir}: {files_in_current_dir}") - if 'webhook_endpoints.py' in files_in_current_dir: + if "webhook_endpoints.py" in files_in_current_dir: log.info("webhook_endpoints.py found in current directory.") else: log.warning("webhook_endpoints.py NOT found in current directory.") @@ -627,6 +691,7 @@ except ImportError as e: # Log the available routes for debugging log.info("Available routes in dashboard_api_app:") from fastapi.routing import APIRoute, Mount + for route in dashboard_api_app.routes: if isinstance(route, APIRoute): log.info(f" {route.path} - {route.name} - {route.methods}") @@ -635,17 +700,21 @@ for route in dashboard_api_app.routes: else: log.info(f" {route.path} - {route.name} - Unknown route type") + # Create a middleware for redirecting /discordapi to /api with a deprecation warning class DeprecationRedirectMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): # Check if the path starts with /discordapi - if request.url.path.startswith('/discordapi'): + if request.url.path.startswith("/discordapi"): # Add a deprecation warning header response = await call_next(request) - response.headers['X-API-Deprecation-Warning'] = 'This endpoint is deprecated. Please use /api instead.' + response.headers["X-API-Deprecation-Warning"] = ( + "This endpoint is deprecated. Please use /api instead." + ) return response return await call_next(request) + # Add CORS middleware to all apps for current_app in [app, api_app, discordapi_app]: current_app.add_middleware( @@ -667,6 +736,7 @@ http_session = None # ============= Authentication ============= + async def verify_discord_token(authorization: str = Header(None)) -> str: """Verify the Discord token and return the user ID""" if not authorization: @@ -680,18 +750,22 @@ async def verify_discord_token(authorization: str = Header(None)) -> str: # Verify the token with Discord async with aiohttp.ClientSession() as session: headers = {"Authorization": f"Bearer {token}"} - async with session.get("https://discord.com/api/v10/users/@me", headers=headers) as resp: + async with session.get( + "https://discord.com/api/v10/users/@me", headers=headers + ) as resp: if resp.status != 200: raise HTTPException(status_code=401, detail="Invalid Discord token") user_data = await resp.json() return user_data["id"] + # ============= API Endpoints ============= # Log the available routes for debugging log.info("Available routes in main app:") from fastapi.routing import APIRoute, Mount + for route in app.routes: if isinstance(route, APIRoute): log.info(f" {route.path} - {route.name} - {route.methods}") @@ -700,9 +774,13 @@ for route in app.routes: else: log.info(f" {route.path} - {route.name} - Unknown route type") + @app.get("/") async def root(): - return RedirectResponse(url="https://www.youtube.com/watch?v=dQw4w9WgXcQ", status_code=301) + return RedirectResponse( + url="https://www.youtube.com/watch?v=dQw4w9WgXcQ", status_code=301 + ) + # Add the same endpoint to the api_app to ensure it's accessible @api_app.get("/openrouterkey", response_class=PlainTextResponse) @@ -711,12 +789,18 @@ async def api_openrouterkey(request: Request): # Basic security check auth_header = request.headers.get("Authorization") # Use loaded setting - if not settings.MOD_LOG_API_SECRET or not auth_header or auth_header != f"Bearer {settings.MOD_LOG_API_SECRET}": + if ( + not settings.MOD_LOG_API_SECRET + or not auth_header + or auth_header != f"Bearer {settings.MOD_LOG_API_SECRET}" + ): print("Unauthorized attempt to access OpenRouter key (api_app).") raise HTTPException(status_code=403, detail="Forbidden") # Add debug logging - log.info(f"OpenRouter key request authorized (api_app). AI_API_KEY is {'set' if settings.AI_API_KEY else 'not set'}") + log.info( + f"OpenRouter key request authorized (api_app). AI_API_KEY is {'set' if settings.AI_API_KEY else 'not set'}" + ) # Check if AI_API_KEY is set if not settings.AI_API_KEY: @@ -725,6 +809,7 @@ async def api_openrouterkey(request: Request): return f"{settings.AI_API_KEY}" + # Add the same endpoint to the discordapi_app to ensure it's accessible @discordapi_app.get("/openrouterkey", response_class=PlainTextResponse) async def discordapi_openrouterkey(request: Request): @@ -732,12 +817,18 @@ async def discordapi_openrouterkey(request: Request): # Basic security check auth_header = request.headers.get("Authorization") # Use loaded setting - if not settings.MOD_LOG_API_SECRET or not auth_header or auth_header != f"Bearer {settings.MOD_LOG_API_SECRET}": + if ( + not settings.MOD_LOG_API_SECRET + or not auth_header + or auth_header != f"Bearer {settings.MOD_LOG_API_SECRET}" + ): print("Unauthorized attempt to access OpenRouter key (discordapi_app).") raise HTTPException(status_code=403, detail="Forbidden") # Add debug logging - log.info(f"OpenRouter key request authorized (discordapi_app). AI_API_KEY is {'set' if settings.AI_API_KEY else 'not set'}") + log.info( + f"OpenRouter key request authorized (discordapi_app). AI_API_KEY is {'set' if settings.AI_API_KEY else 'not set'}" + ) # Check if AI_API_KEY is set if not settings.AI_API_KEY: @@ -746,21 +837,32 @@ async def discordapi_openrouterkey(request: Request): return f"{settings.AI_API_KEY}" + @app.get("/discord") async def root(): return RedirectResponse(url="https://discord.gg/gebDRq6u", status_code=301) + @app.get("/discordbot") async def root(): - return RedirectResponse(url="https://discord.com/oauth2/authorize?client_id=1360717457852993576", status_code=301) + return RedirectResponse( + url="https://discord.com/oauth2/authorize?client_id=1360717457852993576", + status_code=301, + ) + @app.get("/ip") async def ip(request: Request): return Response(content=request.client.host, media_type="text/plain") + @app.get("/agent") async def agent(request: Request): - return Response(content=request.headers.get("user-agent", request.client.host), media_type="text/plain") + return Response( + content=request.headers.get("user-agent", request.client.host), + media_type="text/plain", + ) + @app.get("/debug-settings", response_class=PlainTextResponse) async def debug_settings(request: Request): @@ -768,8 +870,16 @@ async def debug_settings(request: Request): # Basic security check - only allow from localhost or with the same auth as openrouterkey client_host = request.client.host auth_header = request.headers.get("Authorization") - is_local = client_host == "127.0.0.1" or client_host == "::1" or client_host.startswith("172.") - is_authorized = auth_header and settings.MOD_LOG_API_SECRET and auth_header == f"Bearer {settings.MOD_LOG_API_SECRET}" + is_local = ( + client_host == "127.0.0.1" + or client_host == "::1" + or client_host.startswith("172.") + ) + is_authorized = ( + auth_header + and settings.MOD_LOG_API_SECRET + and auth_header == f"Bearer {settings.MOD_LOG_API_SECRET}" + ) if not (is_local or is_authorized): print(f"Unauthorized attempt to access debug settings from {client_host}.") @@ -790,10 +900,12 @@ async def debug_settings(request: Request): return "\n".join(settings_summary) + # Add root for dashboard API for clarity @dashboard_api_app.get("/") async def dashboard_api_root(): - return {"message": "Bot Dashboard API is running"} + return {"message": "Bot Dashboard API is running"} + # Add a test endpoint for cogs @dashboard_api_app.get("/test-cogs", tags=["Test"]) @@ -801,6 +913,7 @@ async def test_cogs_endpoint(): """Test endpoint to verify the API server is working correctly.""" return {"message": "Test cogs endpoint is working"} + # Add a direct endpoint for cogs without dependencies @dashboard_api_app.get("/guilds/{guild_id}/cogs-direct", tags=["Test"]) async def get_guild_cogs_no_deps(guild_id: int): @@ -809,7 +922,8 @@ async def get_guild_cogs_no_deps(guild_id: int): # First try to get cogs from the bot instance bot = None try: - import discord_bot_sync_api # type: ignore + import discord_bot_sync_api # type: ignore + bot = discord_bot_sync_api.bot_instance except (ImportError, AttributeError) as e: log.warning(f"Could not import bot instance: {e}") @@ -817,7 +931,10 @@ async def get_guild_cogs_no_deps(guild_id: int): # Check if settings_manager is available bot = get_bot_instance() if not settings_manager or not bot or not bot.pg_pool: - return {"error": "Settings manager or database connection not available", "cogs": []} + return { + "error": "Settings manager or database connection not available", + "cogs": [], + } # Get cogs from the database directly if bot is not available cogs_list = [] @@ -829,15 +946,19 @@ async def get_guild_cogs_no_deps(guild_id: int): # Get enabled status from settings_manager is_enabled = True try: - is_enabled = await settings_manager.is_cog_enabled(guild_id, cog_name, default_enabled=True) + is_enabled = await settings_manager.is_cog_enabled( + guild_id, cog_name, default_enabled=True + ) except Exception as e: log.error(f"Error getting cog enabled status: {e}") - cogs_list.append({ - "name": cog_name, - "description": cog.__doc__ or "No description available", - "enabled": is_enabled - }) + cogs_list.append( + { + "name": cog_name, + "description": cog.__doc__ or "No description available", + "enabled": is_enabled, + } + ) else: # Fallback: Get cogs from the database directly log.info(f"Getting cogs from database for guild {guild_id}") @@ -847,33 +968,47 @@ async def get_guild_cogs_no_deps(guild_id: int): # Add each cog to the list for cog_name, is_enabled in cog_statuses.items(): - cogs_list.append({ - "name": cog_name, - "description": "Description not available (bot instance not accessible)", - "enabled": is_enabled - }) + cogs_list.append( + { + "name": cog_name, + "description": "Description not available (bot instance not accessible)", + "enabled": is_enabled, + } + ) # If no cogs were found, add some default cogs if not cogs_list: default_cogs = [ - "SettingsCog", "HelpCog", "ModerationCog", "WelcomeCog", - "GurtCog", "EconomyCog", "UtilityCog" + "SettingsCog", + "HelpCog", + "ModerationCog", + "WelcomeCog", + "GurtCog", + "EconomyCog", + "UtilityCog", ] for cog_name in default_cogs: # Try to get the enabled status from the database try: - is_enabled = await settings_manager.is_cog_enabled(guild_id, cog_name, default_enabled=True) + is_enabled = await settings_manager.is_cog_enabled( + guild_id, cog_name, default_enabled=True + ) except Exception: is_enabled = True - cogs_list.append({ - "name": cog_name, - "description": "Default cog (bot instance not accessible)", - "enabled": is_enabled - }) + cogs_list.append( + { + "name": cog_name, + "description": "Default cog (bot instance not accessible)", + "enabled": is_enabled, + } + ) except Exception as e: log.error(f"Error getting cogs from database: {e}") - return {"error": f"Error getting cogs from database: {str(e)}", "cogs": []} + return { + "error": f"Error getting cogs from database: {str(e)}", + "cogs": [], + } return {"cogs": cogs_list} except Exception as e: @@ -886,9 +1021,10 @@ async def discordapi_root(): return { "message": "DEPRECATED: This API endpoint (/discordapi) is deprecated and will be removed in the future.", "recommendation": "Please update your client to use the /api endpoint instead.", - "new_endpoint": "/api" + "new_endpoint": "/api", } + # Discord OAuth configuration now loaded via ApiSettings above # DISCORD_CLIENT_ID = os.getenv("DISCORD_CLIENT_ID", "1360717457852993576") # DISCORD_REDIRECT_URI = os.getenv("DISCORD_REDIRECT_URI", "https://slipstreamm.dev/api/auth") @@ -900,10 +1036,13 @@ async def discordapi_root(): # We will add the new dashboard auth flow under a different path prefix, e.g., /dashboard/api/auth/... # Keep the existing /auth endpoint as is for now. + # @app.get("/auth") # Keep existing @api_app.get("/auth") @discordapi_app.get("/auth") -async def auth(code: str, state: str = None, code_verifier: str = None, request: Request = None): +async def auth( + code: str, state: str = None, code_verifier: str = None, request: Request = None +): """Handle OAuth callback from Discord""" try: # Log the request details for debugging @@ -926,6 +1065,7 @@ async def auth(code: str, state: str = None, code_verifier: str = None, request: if referer and "code=" in referer: # Extract the redirect URI from the referer from urllib.parse import urlparse + parsed_url = urlparse(referer) base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}" print(f"Extracted base URL from referer: {base_url}") @@ -936,7 +1076,7 @@ async def auth(code: str, state: str = None, code_verifier: str = None, request: actual_redirect_uri = base_url data = { - "client_id": settings.DISCORD_CLIENT_ID, # Use loaded setting + "client_id": settings.DISCORD_CLIENT_ID, # Use loaded setting "grant_type": "authorization_code", "code": code, "redirect_uri": actual_redirect_uri, @@ -949,14 +1089,20 @@ async def auth(code: str, state: str = None, code_verifier: str = None, request: if state: stored_code_verifier = code_verifier_store.get_code_verifier(state) if stored_code_verifier: - print(f"Found code_verifier in store for state {state}: {stored_code_verifier[:10]}...") + print( + f"Found code_verifier in store for state {state}: {stored_code_verifier[:10]}..." + ) else: - print(f"No code_verifier found in store for state {state}, will check other sources") + print( + f"No code_verifier found in store for state {state}, will check other sources" + ) # If we have a code_verifier parameter directly in the URL, use that if code_verifier: data["code_verifier"] = code_verifier - print(f"Using code_verifier from URL parameter: {code_verifier[:10]}...") + print( + f"Using code_verifier from URL parameter: {code_verifier[:10]}..." + ) # Otherwise use the stored code verifier if available elif stored_code_verifier: data["code_verifier"] = stored_code_verifier @@ -965,11 +1111,13 @@ async def auth(code: str, state: str = None, code_verifier: str = None, request: code_verifier_store.remove_code_verifier(state) else: # If we still don't have a code verifier, log a warning - print(f"WARNING: No code_verifier found for state {state} - OAuth will likely fail") + print( + f"WARNING: No code_verifier found for state {state} - OAuth will likely fail" + ) # Return a more helpful error message return { "message": "Authentication failed", - "error": "Missing code_verifier. This is required for PKCE OAuth flow. Please ensure the code_verifier is properly sent to the API server." + "error": "Missing code_verifier. This is required for PKCE OAuth flow. Please ensure the code_verifier is properly sent to the API server.", } # Log the token exchange request for debugging @@ -986,11 +1134,16 @@ async def auth(code: str, state: str = None, code_verifier: str = None, request: # Get the user's information access_token = token_data.get("access_token") if not access_token: - return {"message": "Authentication failed", "error": "No access token in response"} + return { + "message": "Authentication failed", + "error": "No access token in response", + } # Get the user's Discord ID headers = {"Authorization": f"Bearer {access_token}"} - async with session.get(f"{DISCORD_API_ENDPOINT}/users/@me", headers=headers) as user_resp: + async with session.get( + f"{DISCORD_API_ENDPOINT}/users/@me", headers=headers + ) as user_resp: if user_resp.status != 200: error_text = await user_resp.text() print(f"Failed to get user info: {error_text}") @@ -1000,7 +1153,10 @@ async def auth(code: str, state: str = None, code_verifier: str = None, request: user_id = user_data.get("id") if not user_id: - return {"message": "Authentication failed", "error": "No user ID in response"} + return { + "message": "Authentication failed", + "error": "No user ID in response", + } # Store the token in the database db.save_user_token(user_id, token_data) @@ -1039,7 +1195,7 @@ async def auth(code: str, state: str = None, code_verifier: str = None, request: return { "message": "Authentication successful", "user_id": user_id, - "token": token_data + "token": token_data, } except Exception as e: print(f"Error in auth endpoint: {str(e)}") @@ -1050,15 +1206,16 @@ async def auth(code: str, state: str = None, code_verifier: str = None, request: # Models are now in dashboard_models.py # Dependencies are now in dependencies.py -from api_service.dashboard_models import ( # type: ignore +from api_service.dashboard_models import ( # type: ignore GuildSettingsResponse, GuildSettingsUpdate, CommandPermission, CommandPermissionsResponse, - CogInfo, # Needed for direct cog endpoint + CogInfo, # Needed for direct cog endpoint # Other models used by imported routers are not needed here directly ) + # --- AI Moderation Action Model --- class AIModerationAction(BaseModel): timestamp: str @@ -1079,6 +1236,7 @@ class AIModerationAction(BaseModel): ai_model: str result: str + # Helper function to execute a warning async def execute_warning(bot, guild_id: int, user_id: int, reason: str) -> dict: """ @@ -1107,14 +1265,18 @@ async def execute_warning(bot, guild_id: int, user_id: int, reason: str) -> dict # Try to fetch the member if not in cache member = await guild.fetch_member(user_id) except discord.NotFound: - log.error(f"Could not find member with ID {user_id} in guild {guild_id}") + log.error( + f"Could not find member with ID {user_id} in guild {guild_id}" + ) return {"success": False, "error": "Member not found"} except Exception as e: - log.error(f"Error fetching member with ID {user_id} in guild {guild_id}: {e}") + log.error( + f"Error fetching member with ID {user_id} in guild {guild_id}: {e}" + ) return {"success": False, "error": f"Error fetching member: {str(e)}"} # Get the moderation cog - moderation_cog = bot.get_cog('ModerationCog') + moderation_cog = bot.get_cog("ModerationCog") if not moderation_cog: log.error(f"ModerationCog not found") return {"success": False, "error": "ModerationCog not found"} @@ -1138,15 +1300,21 @@ async def execute_warning(bot, guild_id: int, user_id: int, reason: str) -> dict # Call the warn method await moderation_cog.moderate_warn_callback(interaction, member, reason) - log.info(f"Successfully executed warning for user {user_id} in guild {guild_id}") + log.info( + f"Successfully executed warning for user {user_id} in guild {guild_id}" + ) return {"success": True, "message": "Warning executed successfully"} except Exception as e: - log.error(f"Error executing warning for user {user_id} in guild {guild_id}: {e}") + log.error( + f"Error executing warning for user {user_id} in guild {guild_id}: {e}" + ) import traceback + tb = traceback.format_exc() return {"success": False, "error": str(e), "traceback": tb} + # ============= Dashboard API Routes ============= # (Mounted under /dashboard/api) # Dependencies are imported from dependencies.py @@ -1154,29 +1322,35 @@ async def execute_warning(bot, guild_id: int, user_id: int, reason: str) -> dict # --- Direct Cog Management Endpoints --- # These are direct implementations in case the imported endpoints don't work + class CogCommandInfo(BaseModel): name: str description: Optional[str] = None enabled: bool = True + class CogInfo(BaseModel): name: str description: Optional[str] = None enabled: bool = True commands: List[Dict[str, Any]] = [] -@dashboard_api_app.get("/guilds/{guild_id}/cogs", response_model=List[CogInfo], tags=["Cog Management"]) + +@dashboard_api_app.get( + "/guilds/{guild_id}/cogs", response_model=List[CogInfo], tags=["Cog Management"] +) async def get_guild_cogs_direct( guild_id: int, _user: dict = Depends(dependencies.get_dashboard_user), - _admin: bool = Depends(dependencies.verify_dashboard_guild_admin) + _admin: bool = Depends(dependencies.verify_dashboard_guild_admin), ): """Get all cogs and their commands for a guild.""" try: # First try to get cogs from the bot instance bot = None try: - import discord_bot_sync_api # type: ignore + import discord_bot_sync_api # type: ignore + bot = discord_bot_sync_api.bot_instance except (ImportError, AttributeError) as e: log.warning(f"Could not import bot instance: {e}") @@ -1186,7 +1360,7 @@ async def get_guild_cogs_direct( if not settings_manager or not bot or not bot.pg_pool: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Settings manager or database connection not available" + detail="Settings manager or database connection not available", ) # Get cogs from the database directly if bot is not available @@ -1197,37 +1371,58 @@ async def get_guild_cogs_direct( log.info(f"Getting cogs from bot instance for guild {guild_id}") for cog_name, cog in bot.cogs.items(): # Get enabled status from settings_manager - is_enabled = await settings_manager.is_cog_enabled(guild_id, cog_name, default_enabled=True) + is_enabled = await settings_manager.is_cog_enabled( + guild_id, cog_name, default_enabled=True + ) # Get commands for this cog commands_list = [] for command in cog.get_commands(): # Get command enabled status - cmd_enabled = await settings_manager.is_command_enabled(guild_id, command.qualified_name, default_enabled=True) - commands_list.append({ - "name": command.qualified_name, - "description": command.help or "No description available", - "enabled": cmd_enabled - }) + cmd_enabled = await settings_manager.is_command_enabled( + guild_id, command.qualified_name, default_enabled=True + ) + commands_list.append( + { + "name": command.qualified_name, + "description": command.help or "No description available", + "enabled": cmd_enabled, + } + ) # Add slash commands if any - app_commands = [cmd for cmd in bot.tree.get_commands() if hasattr(cmd, 'cog') and cmd.cog and cmd.cog.qualified_name == cog_name] + app_commands = [ + cmd + for cmd in bot.tree.get_commands() + if hasattr(cmd, "cog") + and cmd.cog + and cmd.cog.qualified_name == cog_name + ] for cmd in app_commands: # Get command enabled status - cmd_enabled = await settings_manager.is_command_enabled(guild_id, cmd.name, default_enabled=True) - if not any(c["name"] == cmd.name for c in commands_list): # Avoid duplicates - commands_list.append({ - "name": cmd.name, - "description": cmd.description or "No description available", - "enabled": cmd_enabled - }) + cmd_enabled = await settings_manager.is_command_enabled( + guild_id, cmd.name, default_enabled=True + ) + if not any( + c["name"] == cmd.name for c in commands_list + ): # Avoid duplicates + commands_list.append( + { + "name": cmd.name, + "description": cmd.description + or "No description available", + "enabled": cmd_enabled, + } + ) - cogs_list.append(CogInfo( - name=cog_name, - description=cog.__doc__ or "No description available", - enabled=is_enabled, - commands=commands_list - )) + cogs_list.append( + CogInfo( + name=cog_name, + description=cog.__doc__ or "No description available", + enabled=is_enabled, + commands=commands_list, + ) + ) else: # Fallback: Get cogs from the database directly log.info(f"Getting cogs from database for guild {guild_id}") @@ -1236,7 +1431,9 @@ async def get_guild_cogs_direct( cog_statuses = await settings_manager.get_all_enabled_cogs(guild_id) # Get all command enabled statuses from the database - command_statuses = await settings_manager.get_all_enabled_commands(guild_id) + command_statuses = await settings_manager.get_all_enabled_commands( + guild_id + ) # Add each cog to the list for cog_name, is_enabled in cog_statuses.items(): @@ -1249,29 +1446,40 @@ async def get_guild_cogs_direct( cog_prefix = cog_name.lower().replace("cog", "") for cmd_name, cmd_enabled in command_statuses.items(): if cmd_name.lower().startswith(cog_prefix): - commands_list.append({ - "name": cmd_name, - "description": "Description not available (bot instance not accessible)", - "enabled": cmd_enabled - }) + commands_list.append( + { + "name": cmd_name, + "description": "Description not available (bot instance not accessible)", + "enabled": cmd_enabled, + } + ) - cogs_list.append(CogInfo( - name=cog_name, - description="Description not available (bot instance not accessible)", - enabled=is_enabled, - commands=commands_list - )) + cogs_list.append( + CogInfo( + name=cog_name, + description="Description not available (bot instance not accessible)", + enabled=is_enabled, + commands=commands_list, + ) + ) # If no cogs were found, add some default cogs if not cogs_list: default_cogs = [ - "SettingsCog", "HelpCog", "ModerationCog", "WelcomeCog", - "GurtCog", "EconomyCog", "UtilityCog" + "SettingsCog", + "HelpCog", + "ModerationCog", + "WelcomeCog", + "GurtCog", + "EconomyCog", + "UtilityCog", ] for cog_name in default_cogs: # Try to get the enabled status from the database try: - is_enabled = await settings_manager.is_cog_enabled(guild_id, cog_name, default_enabled=True) + is_enabled = await settings_manager.is_cog_enabled( + guild_id, cog_name, default_enabled=True + ) except Exception: is_enabled = True @@ -1282,37 +1490,50 @@ async def get_guild_cogs_direct( "settings": ["set", "get", "reset"], "help": ["help", "commands"], "moderation": ["ban", "kick", "mute", "unmute", "warn"], - "welcome": ["welcome", "goodbye", "setwelcome", "setgoodbye"], + "welcome": [ + "welcome", + "goodbye", + "setwelcome", + "setgoodbye", + ], "gurt": ["gurt", "gurtset"], "economy": ["balance", "daily", "work", "gamble"], - "utility": ["ping", "info", "serverinfo", "userinfo"] + "utility": ["ping", "info", "serverinfo", "userinfo"], } if cog_prefix in default_commands: for cmd_suffix in default_commands[cog_prefix]: cmd_name = f"{cog_prefix}{cmd_suffix}" try: - cmd_enabled = await settings_manager.is_command_enabled(guild_id, cmd_name, default_enabled=True) + cmd_enabled = ( + await settings_manager.is_command_enabled( + guild_id, cmd_name, default_enabled=True + ) + ) except Exception: cmd_enabled = True - commands_list.append({ - "name": cmd_name, - "description": "Default command (bot instance not accessible)", - "enabled": cmd_enabled - }) + commands_list.append( + { + "name": cmd_name, + "description": "Default command (bot instance not accessible)", + "enabled": cmd_enabled, + } + ) - cogs_list.append(CogInfo( - name=cog_name, - description="Default cog (bot instance not accessible)", - enabled=is_enabled, - commands=commands_list - )) + cogs_list.append( + CogInfo( + name=cog_name, + description="Default cog (bot instance not accessible)", + enabled=is_enabled, + commands=commands_list, + ) + ) except Exception as e: log.error(f"Error getting cogs from database: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error getting cogs from database: {str(e)}" + detail=f"Error getting cogs from database: {str(e)}", ) return cogs_list @@ -1323,16 +1544,21 @@ async def get_guild_cogs_direct( log.error(f"Error getting cogs for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error getting cogs: {str(e)}" + detail=f"Error getting cogs: {str(e)}", ) -@dashboard_api_app.patch("/guilds/{guild_id}/cogs/{cog_name}", status_code=status.HTTP_200_OK, tags=["Cog Management"]) + +@dashboard_api_app.patch( + "/guilds/{guild_id}/cogs/{cog_name}", + status_code=status.HTTP_200_OK, + tags=["Cog Management"], +) async def update_cog_status_direct( guild_id: int, cog_name: str, enabled: bool = Body(..., embed=True), _user: dict = Depends(dependencies.get_dashboard_user), - _admin: bool = Depends(dependencies.verify_dashboard_guild_admin) + _admin: bool = Depends(dependencies.verify_dashboard_guild_admin), ): """Enable or disable a cog for a guild.""" try: @@ -1340,22 +1566,24 @@ async def update_cog_status_direct( if not settings_manager: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Settings manager not available" + detail="Settings manager not available", ) # Get the bot instance to check if pools are available - from global_bot_accessor import get_bot_instance # type: ignore + from global_bot_accessor import get_bot_instance # type: ignore + bot_instance = get_bot_instance() if not bot_instance or not bot_instance.pg_pool: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Database connection not available" + detail="Database connection not available", ) # Try to get the bot instance, but don't require it bot = None try: - import discord_bot_sync_api # type: ignore # type: ignore + import discord_bot_sync_api # type: ignore # type: ignore + bot = discord_bot_sync_api.bot_instance except (ImportError, AttributeError) as e: log.warning(f"Could not import bot instance: {e}") @@ -1364,21 +1592,23 @@ async def update_cog_status_direct( if bot: # Check if the cog exists if cog_name not in bot.cogs: - log.warning(f"Cog '{cog_name}' not found in bot instance, but proceeding anyway") + log.warning( + f"Cog '{cog_name}' not found in bot instance, but proceeding anyway" + ) else: # Check if it's a core cog - core_cogs = getattr(bot, 'core_cogs', {'SettingsCog', 'HelpCog'}) + core_cogs = getattr(bot, "core_cogs", {"SettingsCog", "HelpCog"}) if cog_name in core_cogs and not enabled: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Core cog '{cog_name}' cannot be disabled" + detail=f"Core cog '{cog_name}' cannot be disabled", ) else: # If we don't have a bot instance, check if this is a known core cog - if cog_name in ['SettingsCog', 'HelpCog'] and not enabled: + if cog_name in ["SettingsCog", "HelpCog"] and not enabled: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Core cog '{cog_name}' cannot be disabled" + detail=f"Core cog '{cog_name}' cannot be disabled", ) # Update the cog enabled status @@ -1386,10 +1616,12 @@ async def update_cog_status_direct( if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update cog '{cog_name}' status" + detail=f"Failed to update cog '{cog_name}' status", ) - return {"message": f"Cog '{cog_name}' {'enabled' if enabled else 'disabled'} successfully"} + return { + "message": f"Cog '{cog_name}' {'enabled' if enabled else 'disabled'} successfully" + } except HTTPException: # Re-raise HTTP exceptions raise @@ -1397,16 +1629,21 @@ async def update_cog_status_direct( log.error(f"Error updating cog status for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error updating cog status: {str(e)}" + detail=f"Error updating cog status: {str(e)}", ) -@dashboard_api_app.patch("/guilds/{guild_id}/commands/{command_name}", status_code=status.HTTP_200_OK, tags=["Cog Management"]) + +@dashboard_api_app.patch( + "/guilds/{guild_id}/commands/{command_name}", + status_code=status.HTTP_200_OK, + tags=["Cog Management"], +) async def update_command_status_direct( guild_id: int, command_name: str, enabled: bool = Body(..., embed=True), _user: dict = Depends(dependencies.get_dashboard_user), - _admin: bool = Depends(dependencies.verify_dashboard_guild_admin) + _admin: bool = Depends(dependencies.verify_dashboard_guild_admin), ): """Enable or disable a command for a guild.""" try: @@ -1414,22 +1651,24 @@ async def update_command_status_direct( if not settings_manager: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Settings manager not available" + detail="Settings manager not available", ) # Get the bot instance to check if pools are available - from global_bot_accessor import get_bot_instance # type: ignore + from global_bot_accessor import get_bot_instance # type: ignore + bot_instance = get_bot_instance() if not bot_instance or not bot_instance.pg_pool: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Database connection not available" + detail="Database connection not available", ) # Try to get the bot instance, but don't require it bot = None try: - import discord_bot_sync_api # type: ignore + import discord_bot_sync_api # type: ignore + bot = discord_bot_sync_api.bot_instance except (ImportError, AttributeError) as e: log.warning(f"Could not import bot instance: {e}") @@ -1440,19 +1679,27 @@ async def update_command_status_direct( command = bot.get_command(command_name) if not command: # Check if it's an app command - app_commands = [cmd for cmd in bot.tree.get_commands() if cmd.name == command_name] + app_commands = [ + cmd for cmd in bot.tree.get_commands() if cmd.name == command_name + ] if not app_commands: - log.warning(f"Command '{command_name}' not found in bot instance, but proceeding anyway") + log.warning( + f"Command '{command_name}' not found in bot instance, but proceeding anyway" + ) # Update the command enabled status - success = await settings_manager.set_command_enabled(guild_id, command_name, enabled) + success = await settings_manager.set_command_enabled( + guild_id, command_name, enabled + ) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update command '{command_name}' status" + detail=f"Failed to update command '{command_name}' status", ) - return {"message": f"Command '{command_name}' {'enabled' if enabled else 'disabled'} successfully"} + return { + "message": f"Command '{command_name}' {'enabled' if enabled else 'disabled'} successfully" + } except HTTPException: # Re-raise HTTP exceptions raise @@ -1460,9 +1707,10 @@ async def update_command_status_direct( log.error(f"Error updating command status for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error updating command status: {str(e)}" + detail=f"Error updating command status: {str(e)}", ) + # --- Dashboard Authentication Routes --- @dashboard_api_app.get("/auth/login", tags=["Dashboard Authentication"]) async def dashboard_login(): @@ -1495,17 +1743,29 @@ async def dashboard_login(): log.info(f"Dashboard: Redirecting user to Discord auth URL with PKCE: {auth_url}") log.info(f"Dashboard: Using redirect URI: {DASHBOARD_REDIRECT_URI}") - log.info(f"Dashboard: Stored code verifier for state {state}: {code_verifier[:10]}...") + log.info( + f"Dashboard: Stored code verifier for state {state}: {code_verifier[:10]}..." + ) + + return RedirectResponse( + url=auth_url, status_code=status.HTTP_307_TEMPORARY_REDIRECT + ) - return RedirectResponse(url=auth_url, status_code=status.HTTP_307_TEMPORARY_REDIRECT) @dashboard_api_app.get("/auth/callback", tags=["Dashboard Authentication"]) -async def dashboard_auth_callback(request: Request, code: str | None = None, state: str | None = None, error: str | None = None): +async def dashboard_auth_callback( + request: Request, + code: str | None = None, + state: str | None = None, + error: str | None = None, +): """Handles the callback from Discord after authorization (Dashboard Flow).""" - global http_session # Use the global aiohttp session + global http_session # Use the global aiohttp session if error: log.error(f"Dashboard: Discord OAuth error: {error}") - return RedirectResponse(url="/dashboard?error=discord_auth_failed") # Redirect to frontend dashboard root + return RedirectResponse( + url="/dashboard?error=discord_auth_failed" + ) # Redirect to frontend dashboard root if not code: log.error("Dashboard: Discord OAuth callback missing code.") @@ -1516,8 +1776,10 @@ async def dashboard_auth_callback(request: Request, code: str | None = None, sta return RedirectResponse(url="/dashboard?error=missing_state") if not http_session: - log.error("Dashboard: aiohttp session not initialized.") - raise HTTPException(status_code=500, detail="Internal server error: HTTP session not ready.") + log.error("Dashboard: aiohttp session not initialized.") + raise HTTPException( + status_code=500, detail="Internal server error: HTTP session not ready." + ) try: # Get the code verifier from the store @@ -1526,82 +1788,113 @@ async def dashboard_auth_callback(request: Request, code: str | None = None, sta log.error(f"Dashboard: No code_verifier found for state {state}") return RedirectResponse(url="/dashboard?error=missing_code_verifier") - log.info(f"Dashboard: Found code_verifier for state {state}: {code_verifier[:10]}...") + log.info( + f"Dashboard: Found code_verifier for state {state}: {code_verifier[:10]}..." + ) # Remove the code verifier from the store after retrieving it code_verifier_store.remove_code_verifier(state) # 1. Exchange code for access token with PKCE token_data = { - 'client_id': settings.DISCORD_CLIENT_ID, - 'grant_type': 'authorization_code', - 'code': code, - 'redirect_uri': DASHBOARD_REDIRECT_URI, # Must match exactly what was used in the auth request - 'code_verifier': code_verifier # Add the code verifier for PKCE + "client_id": settings.DISCORD_CLIENT_ID, + "grant_type": "authorization_code", + "code": code, + "redirect_uri": DASHBOARD_REDIRECT_URI, # Must match exactly what was used in the auth request + "code_verifier": code_verifier, # Add the code verifier for PKCE } - headers = {'Content-Type': 'application/x-www-form-urlencoded'} + headers = {"Content-Type": "application/x-www-form-urlencoded"} - log.debug(f"Dashboard: Exchanging code for token at {DISCORD_TOKEN_URL} with PKCE") + log.debug( + f"Dashboard: Exchanging code for token at {DISCORD_TOKEN_URL} with PKCE" + ) log.debug(f"Dashboard: Token exchange data: {token_data}") - async with http_session.post(DISCORD_TOKEN_URL, data=token_data, headers=headers) as resp: + async with http_session.post( + DISCORD_TOKEN_URL, data=token_data, headers=headers + ) as resp: if resp.status != 200: error_text = await resp.text() log.error(f"Dashboard: Failed to exchange code: {error_text}") - return RedirectResponse(url=f"/dashboard?error=token_exchange_failed&details={error_text}") + return RedirectResponse( + url=f"/dashboard?error=token_exchange_failed&details={error_text}" + ) token_response = await resp.json() - access_token = token_response.get('access_token') + access_token = token_response.get("access_token") log.debug("Dashboard: Token exchange successful.") if not access_token: log.error("Dashboard: Failed to get access token from Discord response.") - raise HTTPException(status_code=500, detail="Could not retrieve access token from Discord.") + raise HTTPException( + status_code=500, detail="Could not retrieve access token from Discord." + ) # 2. Fetch user data - user_headers = {'Authorization': f'Bearer {access_token}'} + user_headers = {"Authorization": f"Bearer {access_token}"} log.debug(f"Dashboard: Fetching user data from {DISCORD_USER_URL}") async with http_session.get(DISCORD_USER_URL, headers=user_headers) as resp: resp.raise_for_status() user_data = await resp.json() - log.debug(f"Dashboard: User data fetched successfully for user ID: {user_data.get('id')}") + log.debug( + f"Dashboard: User data fetched successfully for user ID: {user_data.get('id')}" + ) # 3. Store in session - request.session['user_id'] = user_data.get('id') - request.session['username'] = user_data.get('username') - request.session['avatar'] = user_data.get('avatar') - request.session['access_token'] = access_token + request.session["user_id"] = user_data.get("id") + request.session["username"] = user_data.get("username") + request.session["avatar"] = user_data.get("avatar") + request.session["access_token"] = access_token - log.info(f"Dashboard: User {user_data.get('username')} ({user_data.get('id')}) logged in successfully.") + log.info( + f"Dashboard: User {user_data.get('username')} ({user_data.get('id')}) logged in successfully." + ) # Redirect user back to the main dashboard page (served by static files) - return RedirectResponse(url="/dashboard", status_code=status.HTTP_307_TEMPORARY_REDIRECT) + return RedirectResponse( + url="/dashboard", status_code=status.HTTP_307_TEMPORARY_REDIRECT + ) except aiohttp.ClientResponseError as e: - log.exception(f"Dashboard: HTTP error during Discord OAuth callback: {e.status} {e.message}") + log.exception( + f"Dashboard: HTTP error during Discord OAuth callback: {e.status} {e.message}" + ) error_detail = "Unknown Discord API error" try: error_body = await e.response.json() error_detail = error_body.get("error_description", error_detail) - except Exception: pass - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"Error communicating with Discord: {error_detail}") + except Exception: + pass + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"Error communicating with Discord: {error_detail}", + ) except Exception as e: log.exception(f"Dashboard: Generic error during Discord OAuth callback: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An internal error occurred during authentication.") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An internal error occurred during authentication.", + ) -@dashboard_api_app.post("/auth/logout", tags=["Dashboard Authentication"], status_code=status.HTTP_204_NO_CONTENT) + +@dashboard_api_app.post( + "/auth/logout", + tags=["Dashboard Authentication"], + status_code=status.HTTP_204_NO_CONTENT, +) async def dashboard_logout(request: Request): """Clears the dashboard user session.""" - user_id = request.session.get('user_id') + user_id = request.session.get("user_id") request.session.clear() log.info(f"Dashboard: User {user_id} logged out.") return + @dashboard_api_app.get("/auth/status", tags=["Dashboard Authentication"]) async def dashboard_auth_status(request: Request): """Checks if the user is authenticated in the dashboard session.""" - user_id = request.session.get('user_id') - username = request.session.get('username') - access_token = request.session.get('access_token') + user_id = request.session.get("user_id") + username = request.session.get("username") + access_token = request.session.get("access_token") if not user_id or not username or not access_token: log.debug("Dashboard: Auth status check - user not authenticated") @@ -1611,50 +1904,65 @@ async def dashboard_auth_status(request: Request): try: if not http_session: log.error("Dashboard: aiohttp session not initialized.") - return {"authenticated": False, "message": "Internal server error: HTTP session not ready"} + return { + "authenticated": False, + "message": "Internal server error: HTTP session not ready", + } - user_headers = {'Authorization': f'Bearer {access_token}'} + user_headers = {"Authorization": f"Bearer {access_token}"} async with http_session.get(DISCORD_USER_URL, headers=user_headers) as resp: if resp.status != 200: - log.warning(f"Dashboard: Auth status check - invalid token for user {user_id}") + log.warning( + f"Dashboard: Auth status check - invalid token for user {user_id}" + ) # Clear the invalid session request.session.clear() - return {"authenticated": False, "message": "Discord token invalid or expired"} + return { + "authenticated": False, + "message": "Discord token invalid or expired", + } # Token is valid, get the latest user data user_data = await resp.json() # Update session with latest data - request.session['username'] = user_data.get('username') - request.session['avatar'] = user_data.get('avatar') + request.session["username"] = user_data.get("username") + request.session["avatar"] = user_data.get("avatar") log.debug(f"Dashboard: Auth status check - user {user_id} is authenticated") return { "authenticated": True, "user": { "id": user_id, - "username": user_data.get('username'), - "avatar": user_data.get('avatar') - } + "username": user_data.get("username"), + "avatar": user_data.get("avatar"), + }, } except Exception as e: log.exception(f"Dashboard: Error checking auth status: {e}") - return {"authenticated": False, "message": f"Error checking auth status: {str(e)}"} + return { + "authenticated": False, + "message": f"Error checking auth status: {str(e)}", + } + # --- Dashboard User Endpoints --- @dashboard_api_app.get("/user/me", tags=["Dashboard User"]) -async def dashboard_get_user_me(current_user: dict = Depends(dependencies.get_dashboard_user)): +async def dashboard_get_user_me( + current_user: dict = Depends(dependencies.get_dashboard_user), +): """Returns information about the currently logged-in dashboard user.""" user_info = current_user.copy() # del user_info['access_token'] # Optional: Don't expose token to frontend return user_info + @dashboard_api_app.get("/auth/user", tags=["Dashboard Authentication"]) async def dashboard_get_auth_user(request: Request): """Returns information about the currently logged-in dashboard user for the frontend.""" - user_id = request.session.get('user_id') - username = request.session.get('username') - avatar = request.session.get('avatar') + user_id = request.session.get("user_id") + username = request.session.get("username") + avatar = request.session.get("avatar") if not user_id or not username: raise HTTPException( @@ -1663,36 +1971,43 @@ async def dashboard_get_auth_user(request: Request): headers={"WWW-Authenticate": "Bearer"}, ) - return { - "id": user_id, - "username": username, - "avatar": avatar - } + return {"id": user_id, "username": username, "avatar": avatar} + @dashboard_api_app.get("/user/guilds", tags=["Dashboard User"]) @dashboard_api_app.get("/guilds", tags=["Dashboard Guild Settings"]) -async def dashboard_get_user_guilds(current_user: dict = Depends(dependencies.get_dashboard_user)): +async def dashboard_get_user_guilds( + current_user: dict = Depends(dependencies.get_dashboard_user), +): """Returns a list of guilds the user is an administrator in AND the bot is also in.""" - global http_session # Use the global aiohttp session + global http_session # Use the global aiohttp session if not http_session: - log.error("Dashboard: aiohttp session not initialized.") - raise HTTPException(status_code=500, detail="Internal server error: HTTP session not ready.") + log.error("Dashboard: aiohttp session not initialized.") + raise HTTPException( + status_code=500, detail="Internal server error: HTTP session not ready." + ) if not settings_manager: log.error("Dashboard: settings_manager not available.") # Instead of raising an exception, return an empty list with a warning - log.warning("Dashboard: Returning empty guild list due to missing settings_manager") + log.warning( + "Dashboard: Returning empty guild list due to missing settings_manager" + ) return [] - access_token = current_user['access_token'] - user_headers = {'Authorization': f'Bearer {access_token}'} + access_token = current_user["access_token"] + user_headers = {"Authorization": f"Bearer {access_token}"} try: # 1. Fetch guilds user is in from Discord log.debug(f"Dashboard: Fetching user guilds from {DISCORD_USER_GUILDS_URL}") - async with http_session.get(DISCORD_USER_GUILDS_URL, headers=user_headers) as resp: + async with http_session.get( + DISCORD_USER_GUILDS_URL, headers=user_headers + ) as resp: resp.raise_for_status() user_guilds = await resp.json() - log.debug(f"Dashboard: Fetched {len(user_guilds)} guilds for user {current_user['user_id']}") + log.debug( + f"Dashboard: Fetched {len(user_guilds)} guilds for user {current_user['user_id']}" + ) # 2. Fetch guilds the bot is in from our DB try: @@ -1705,27 +2020,45 @@ async def dashboard_get_user_guilds(current_user: dict = Depends(dependencies.ge while retry_count < max_db_retries and bot_guild_ids is None: try: # Always use the API server's own pool with the new function - if hasattr(app.state, 'pg_pool') and app.state.pg_pool: - log.info("Dashboard: Using API server's pool to fetch guild IDs") - bot_guild_ids = await settings_manager.get_bot_guild_ids_with_pool(app.state.pg_pool) + if hasattr(app.state, "pg_pool") and app.state.pg_pool: + log.info( + "Dashboard: Using API server's pool to fetch guild IDs" + ) + bot_guild_ids = ( + await settings_manager.get_bot_guild_ids_with_pool( + app.state.pg_pool + ) + ) else: # The improved get_bot_guild_ids will try app.state.pg_pool first - log.info("Dashboard: Using enhanced get_bot_guild_ids that prioritizes API server's pool") + log.info( + "Dashboard: Using enhanced get_bot_guild_ids that prioritizes API server's pool" + ) bot_guild_ids = await settings_manager.get_bot_guild_ids() if bot_guild_ids is None: - log.warning(f"Dashboard: Failed to fetch bot guild IDs, retry {retry_count+1}/{max_db_retries}") + log.warning( + f"Dashboard: Failed to fetch bot guild IDs, retry {retry_count+1}/{max_db_retries}" + ) retry_count += 1 if retry_count < max_db_retries: await asyncio.sleep(1) # Wait before retrying except RuntimeError as e: - if "got Future" in str(e) and "attached to a different loop" in str(e): - log.warning(f"Dashboard: Event loop error fetching guild IDs: {e}") - log.warning("This is likely because we're trying to use a pool from a different thread.") + if "got Future" in str(e) and "attached to a different loop" in str( + e + ): + log.warning( + f"Dashboard: Event loop error fetching guild IDs: {e}" + ) + log.warning( + "This is likely because we're trying to use a pool from a different thread." + ) # Try to create a new pool just for this request if needed - if not hasattr(app.state, 'pg_pool') or not app.state.pg_pool: + if not hasattr(app.state, "pg_pool") or not app.state.pg_pool: try: - log.info("Dashboard: Attempting to create a temporary pool for this request") + log.info( + "Dashboard: Attempting to create a temporary pool for this request" + ) temp_pool = await asyncpg.create_pool( user=settings.POSTGRES_USER, password=settings.POSTGRES_PASSWORD, @@ -1734,72 +2067,114 @@ async def dashboard_get_user_guilds(current_user: dict = Depends(dependencies.ge min_size=1, max_size=2, ) - bot_guild_ids = await settings_manager.get_bot_guild_ids_with_pool(temp_pool) + bot_guild_ids = ( + await settings_manager.get_bot_guild_ids_with_pool( + temp_pool + ) + ) await temp_pool.close() except Exception as pool_err: - log.error(f"Dashboard: Failed to create temporary pool: {pool_err}") + log.error( + f"Dashboard: Failed to create temporary pool: {pool_err}" + ) else: - log.warning(f"Dashboard: Runtime error fetching bot guild IDs, retry {retry_count+1}/{max_db_retries}: {e}") + log.warning( + f"Dashboard: Runtime error fetching bot guild IDs, retry {retry_count+1}/{max_db_retries}: {e}" + ) retry_count += 1 if retry_count < max_db_retries: await asyncio.sleep(1) # Wait before retrying except Exception as e: - log.warning(f"Dashboard: Error fetching bot guild IDs, retry {retry_count+1}/{max_db_retries}: {e}") + log.warning( + f"Dashboard: Error fetching bot guild IDs, retry {retry_count+1}/{max_db_retries}: {e}" + ) retry_count += 1 if retry_count < max_db_retries: await asyncio.sleep(1) # Wait before retrying # After retries, if still no data, provide a fallback empty set instead of raising an exception if bot_guild_ids is None: - log.error("Dashboard: Failed to fetch bot guild IDs from settings_manager after retries.") + log.error( + "Dashboard: Failed to fetch bot guild IDs from settings_manager after retries." + ) # Instead of raising an exception, use an empty set as fallback bot_guild_ids = set() - log.warning("Dashboard: Using empty guild set as fallback to allow dashboard to function") + log.warning( + "Dashboard: Using empty guild set as fallback to allow dashboard to function" + ) except Exception as e: - log.exception("Dashboard: Exception while fetching bot guild IDs from settings_manager.") + log.exception( + "Dashboard: Exception while fetching bot guild IDs from settings_manager." + ) # Instead of raising an exception, use an empty set as fallback bot_guild_ids = set() - log.warning("Dashboard: Using empty guild set as fallback to allow dashboard to function") + log.warning( + "Dashboard: Using empty guild set as fallback to allow dashboard to function" + ) # 3. Filter user guilds manageable_guilds = [] ADMINISTRATOR_PERMISSION = 0x8 for guild in user_guilds: - guild_id = int(guild['id']) - permissions = int(guild['permissions']) + guild_id = int(guild["id"]) + permissions = int(guild["permissions"]) - if (permissions & ADMINISTRATOR_PERMISSION) == ADMINISTRATOR_PERMISSION and guild_id in bot_guild_ids: - manageable_guilds.append({ - "id": guild['id'], - "name": guild['name'], - "icon": guild.get('icon'), - }) + if ( + permissions & ADMINISTRATOR_PERMISSION + ) == ADMINISTRATOR_PERMISSION and guild_id in bot_guild_ids: + manageable_guilds.append( + { + "id": guild["id"], + "name": guild["name"], + "icon": guild.get("icon"), + } + ) - log.info(f"Dashboard: Found {len(manageable_guilds)} manageable guilds for user {current_user['user_id']}") + log.info( + f"Dashboard: Found {len(manageable_guilds)} manageable guilds for user {current_user['user_id']}" + ) return manageable_guilds except aiohttp.ClientResponseError as e: - log.exception(f"Dashboard: HTTP error fetching user guilds: {e.status} {e.message}") + log.exception( + f"Dashboard: HTTP error fetching user guilds: {e.status} {e.message}" + ) if e.status == 401: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Discord token invalid or expired. Please re-login.") - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="Error communicating with Discord API.") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Discord token invalid or expired. Please re-login.", + ) + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="Error communicating with Discord API.", + ) except Exception as e: log.exception(f"Dashboard: Generic error fetching user guilds: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An internal error occurred while fetching guilds.") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An internal error occurred while fetching guilds.", + ) + # --- Dashboard Guild Settings Endpoints --- @dashboard_api_app.get("/guilds/{guild_id}/channels", tags=["Dashboard Guild Settings"]) async def dashboard_get_guild_channels( guild_id: int, current_user: dict = Depends(dependencies.get_dashboard_user), - _: bool = Depends(dependencies.verify_dashboard_guild_admin) # Underscore indicates unused but required dependency + _: bool = Depends( + dependencies.verify_dashboard_guild_admin + ), # Underscore indicates unused but required dependency ): """Fetches the channels for a specific guild for the dashboard.""" - global http_session # Use the global aiohttp session + global http_session # Use the global aiohttp session if not http_session: - raise HTTPException(status_code=500, detail="Internal server error: HTTP session not ready.") + raise HTTPException( + status_code=500, detail="Internal server error: HTTP session not ready." + ) - log.info(f"Dashboard: Fetching channels for guild {guild_id} requested by user {current_user['user_id']}") + log.info( + f"Dashboard: Fetching channels for guild {guild_id} requested by user {current_user['user_id']}" + ) try: # Use Discord Bot Token to fetch channels if available @@ -1807,10 +2182,10 @@ async def dashboard_get_guild_channels( log.error("Dashboard: DISCORD_BOT_TOKEN not set in environment variables") raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Bot token not configured. Please set DISCORD_BOT_TOKEN in environment variables." + detail="Bot token not configured. Please set DISCORD_BOT_TOKEN in environment variables.", ) - bot_headers = {'Authorization': f'Bot {settings.DISCORD_BOT_TOKEN}'} + bot_headers = {"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}"} # Add rate limit handling max_retries = 3 @@ -1819,22 +2194,31 @@ async def dashboard_get_guild_channels( while retry_count < max_retries: if retry_after > 0: - log.warning(f"Dashboard: Rate limited by Discord API, waiting {retry_after} seconds before retry") + log.warning( + f"Dashboard: Rate limited by Discord API, waiting {retry_after} seconds before retry" + ) await asyncio.sleep(retry_after) - async with http_session.get(f"https://discord.com/api/v10/guilds/{guild_id}/channels", headers=bot_headers) as resp: + async with http_session.get( + f"https://discord.com/api/v10/guilds/{guild_id}/channels", + headers=bot_headers, + ) as resp: if resp.status == 429: # Rate limited retry_count += 1 # Get the most accurate retry time from headers - retry_after = float(resp.headers.get('X-RateLimit-Reset-After', - resp.headers.get('Retry-After', 1))) + retry_after = float( + resp.headers.get( + "X-RateLimit-Reset-After", + resp.headers.get("Retry-After", 1), + ) + ) # Check if this is a global rate limit - is_global = resp.headers.get('X-RateLimit-Global') is not None + is_global = resp.headers.get("X-RateLimit-Global") is not None # Get the rate limit scope if available - scope = resp.headers.get('X-RateLimit-Scope', 'unknown') + scope = resp.headers.get("X-RateLimit-Scope", "unknown") log.warning( f"Dashboard: Discord API rate limit hit. " @@ -1845,24 +2229,26 @@ async def dashboard_get_guild_channels( # For global rate limits, we might want to wait longer if is_global: - retry_after = max(retry_after, 5) # At least 5 seconds for global limits + retry_after = max( + retry_after, 5 + ) # At least 5 seconds for global limits continue # Check rate limit headers and log them for monitoring rate_limit = { - 'limit': resp.headers.get('X-RateLimit-Limit'), - 'remaining': resp.headers.get('X-RateLimit-Remaining'), - 'reset': resp.headers.get('X-RateLimit-Reset'), - 'reset_after': resp.headers.get('X-RateLimit-Reset-After'), - 'bucket': resp.headers.get('X-RateLimit-Bucket') + "limit": resp.headers.get("X-RateLimit-Limit"), + "remaining": resp.headers.get("X-RateLimit-Remaining"), + "reset": resp.headers.get("X-RateLimit-Reset"), + "reset_after": resp.headers.get("X-RateLimit-Reset-After"), + "bucket": resp.headers.get("X-RateLimit-Bucket"), } # If we're getting close to the rate limit, log a warning - if rate_limit['remaining'] and rate_limit['limit']: + if rate_limit["remaining"] and rate_limit["limit"]: try: - remaining = int(rate_limit['remaining']) - limit = int(rate_limit['limit']) + remaining = int(rate_limit["remaining"]) + limit = int(rate_limit["limit"]) if remaining < 5: log.warning( f"Dashboard: Rate limit warning: {remaining}/{limit} " @@ -1879,38 +2265,61 @@ async def dashboard_get_guild_channels( # Filter and format channels formatted_channels = [] for channel in channels: - formatted_channels.append({ - "id": channel["id"], - "name": channel["name"], - "type": channel["type"], - "parent_id": channel.get("parent_id") - }) + formatted_channels.append( + { + "id": channel["id"], + "name": channel["name"], + "type": channel["type"], + "parent_id": channel.get("parent_id"), + } + ) return formatted_channels # If we get here, we've exceeded our retry limit - raise HTTPException(status_code=429, detail="Rate limited by Discord API. Please try again later.") + raise HTTPException( + status_code=429, + detail="Rate limited by Discord API. Please try again later.", + ) except aiohttp.ClientResponseError as e: - log.exception(f"Dashboard: HTTP error fetching guild channels: {e.status} {e.message}") + log.exception( + f"Dashboard: HTTP error fetching guild channels: {e.status} {e.message}" + ) if e.status == 429: - raise HTTPException(status_code=429, detail="Rate limited by Discord API. Please try again later.") - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="Error communicating with Discord API.") + raise HTTPException( + status_code=429, + detail="Rate limited by Discord API. Please try again later.", + ) + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="Error communicating with Discord API.", + ) except Exception as e: log.exception(f"Dashboard: Generic error fetching guild channels: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An internal error occurred while fetching channels.") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An internal error occurred while fetching channels.", + ) + @dashboard_api_app.get("/guilds/{guild_id}/roles", tags=["Dashboard Guild Settings"]) async def dashboard_get_guild_roles( guild_id: int, current_user: dict = Depends(dependencies.get_dashboard_user), - _: bool = Depends(dependencies.verify_dashboard_guild_admin) # Underscore indicates unused but required dependency + _: bool = Depends( + dependencies.verify_dashboard_guild_admin + ), # Underscore indicates unused but required dependency ): """Fetches the roles for a specific guild for the dashboard.""" - global http_session # Use the global aiohttp session + global http_session # Use the global aiohttp session if not http_session: - raise HTTPException(status_code=500, detail="Internal server error: HTTP session not ready.") + raise HTTPException( + status_code=500, detail="Internal server error: HTTP session not ready." + ) - log.info(f"Dashboard: Fetching roles for guild {guild_id} requested by user {current_user['user_id']}") + log.info( + f"Dashboard: Fetching roles for guild {guild_id} requested by user {current_user['user_id']}" + ) try: # Use Discord Bot Token to fetch roles if available @@ -1918,10 +2327,10 @@ async def dashboard_get_guild_roles( log.error("Dashboard: DISCORD_BOT_TOKEN not set in environment variables") raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Bot token not configured. Please set DISCORD_BOT_TOKEN in environment variables." + detail="Bot token not configured. Please set DISCORD_BOT_TOKEN in environment variables.", ) - bot_headers = {'Authorization': f'Bot {settings.DISCORD_BOT_TOKEN}'} + bot_headers = {"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}"} # Add rate limit handling max_retries = 3 @@ -1930,22 +2339,31 @@ async def dashboard_get_guild_roles( while retry_count < max_retries: if retry_after > 0: - log.warning(f"Dashboard: Rate limited by Discord API, waiting {retry_after} seconds before retry") + log.warning( + f"Dashboard: Rate limited by Discord API, waiting {retry_after} seconds before retry" + ) await asyncio.sleep(retry_after) - async with http_session.get(f"https://discord.com/api/v10/guilds/{guild_id}/roles", headers=bot_headers) as resp: + async with http_session.get( + f"https://discord.com/api/v10/guilds/{guild_id}/roles", + headers=bot_headers, + ) as resp: if resp.status == 429: # Rate limited retry_count += 1 # Get the most accurate retry time from headers - retry_after = float(resp.headers.get('X-RateLimit-Reset-After', - resp.headers.get('Retry-After', 1))) + retry_after = float( + resp.headers.get( + "X-RateLimit-Reset-After", + resp.headers.get("Retry-After", 1), + ) + ) # Check if this is a global rate limit - is_global = resp.headers.get('X-RateLimit-Global') is not None + is_global = resp.headers.get("X-RateLimit-Global") is not None # Get the rate limit scope if available - scope = resp.headers.get('X-RateLimit-Scope', 'unknown') + scope = resp.headers.get("X-RateLimit-Scope", "unknown") log.warning( f"Dashboard: Discord API rate limit hit. " @@ -1956,7 +2374,9 @@ async def dashboard_get_guild_roles( # For global rate limits, we might want to wait longer if is_global: - retry_after = max(retry_after, 5) # At least 5 seconds for global limits + retry_after = max( + retry_after, 5 + ) # At least 5 seconds for global limits continue @@ -1970,13 +2390,15 @@ async def dashboard_get_guild_roles( if role["name"] == "@everyone": continue - formatted_roles.append({ - "id": role["id"], - "name": role["name"], - "color": role["color"], - "position": role["position"], - "permissions": role["permissions"] - }) + formatted_roles.append( + { + "id": role["id"], + "name": role["name"], + "color": role["color"], + "position": role["position"], + "permissions": role["permissions"], + } + ) # Sort roles by position (highest first) formatted_roles.sort(key=lambda r: r["position"], reverse=True) @@ -1984,28 +2406,49 @@ async def dashboard_get_guild_roles( return formatted_roles # If we get here, we've exceeded our retry limit - raise HTTPException(status_code=429, detail="Rate limited by Discord API. Please try again later.") + raise HTTPException( + status_code=429, + detail="Rate limited by Discord API. Please try again later.", + ) except aiohttp.ClientResponseError as e: - log.exception(f"Dashboard: HTTP error fetching guild roles: {e.status} {e.message}") + log.exception( + f"Dashboard: HTTP error fetching guild roles: {e.status} {e.message}" + ) if e.status == 429: - raise HTTPException(status_code=429, detail="Rate limited by Discord API. Please try again later.") - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="Error communicating with Discord API.") + raise HTTPException( + status_code=429, + detail="Rate limited by Discord API. Please try again later.", + ) + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="Error communicating with Discord API.", + ) except Exception as e: log.exception(f"Dashboard: Generic error fetching guild roles: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An internal error occurred while fetching roles.") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An internal error occurred while fetching roles.", + ) + @dashboard_api_app.get("/guilds/{guild_id}/commands", tags=["Dashboard Guild Settings"]) async def dashboard_get_guild_commands( guild_id: int, current_user: dict = Depends(dependencies.get_dashboard_user), - _: bool = Depends(dependencies.verify_dashboard_guild_admin) # Underscore indicates unused but required dependency + _: bool = Depends( + dependencies.verify_dashboard_guild_admin + ), # Underscore indicates unused but required dependency ): """Fetches the commands for a specific guild for the dashboard.""" - global http_session # Use the global aiohttp session + global http_session # Use the global aiohttp session if not http_session: - raise HTTPException(status_code=500, detail="Internal server error: HTTP session not ready.") + raise HTTPException( + status_code=500, detail="Internal server error: HTTP session not ready." + ) - log.info(f"Dashboard: Fetching commands for guild {guild_id} requested by user {current_user['user_id']}") + log.info( + f"Dashboard: Fetching commands for guild {guild_id} requested by user {current_user['user_id']}" + ) try: # Use Discord Bot Token to fetch application commands if available @@ -2013,11 +2456,13 @@ async def dashboard_get_guild_commands( log.error("Dashboard: DISCORD_BOT_TOKEN not set in environment variables") raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Bot token not configured. Please set DISCORD_BOT_TOKEN in environment variables." + detail="Bot token not configured. Please set DISCORD_BOT_TOKEN in environment variables.", ) - bot_headers = {'Authorization': f'Bot {settings.DISCORD_BOT_TOKEN}'} - application_id = settings.DISCORD_CLIENT_ID # This should be the same as your bot's application ID + bot_headers = {"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}"} + application_id = ( + settings.DISCORD_CLIENT_ID + ) # This should be the same as your bot's application ID # Add rate limit handling max_retries = 3 @@ -2026,22 +2471,31 @@ async def dashboard_get_guild_commands( while retry_count < max_retries: if retry_after > 0: - log.warning(f"Dashboard: Rate limited by Discord API, waiting {retry_after} seconds before retry") + log.warning( + f"Dashboard: Rate limited by Discord API, waiting {retry_after} seconds before retry" + ) await asyncio.sleep(retry_after) - async with http_session.get(f"https://discord.com/api/v10/applications/{application_id}/guilds/{guild_id}/commands", headers=bot_headers) as resp: + async with http_session.get( + f"https://discord.com/api/v10/applications/{application_id}/guilds/{guild_id}/commands", + headers=bot_headers, + ) as resp: if resp.status == 429: # Rate limited retry_count += 1 # Get the most accurate retry time from headers - retry_after = float(resp.headers.get('X-RateLimit-Reset-After', - resp.headers.get('Retry-After', 1))) + retry_after = float( + resp.headers.get( + "X-RateLimit-Reset-After", + resp.headers.get("Retry-After", 1), + ) + ) # Check if this is a global rate limit - is_global = resp.headers.get('X-RateLimit-Global') is not None + is_global = resp.headers.get("X-RateLimit-Global") is not None # Get the rate limit scope if available - scope = resp.headers.get('X-RateLimit-Scope', 'unknown') + scope = resp.headers.get("X-RateLimit-Scope", "unknown") log.warning( f"Dashboard: Discord API rate limit hit. " @@ -2052,7 +2506,9 @@ async def dashboard_get_guild_commands( # For global rate limits, we might want to wait longer if is_global: - retry_after = max(retry_after, 5) # At least 5 seconds for global limits + retry_after = max( + retry_after, 5 + ) # At least 5 seconds for global limits continue @@ -2066,38 +2522,59 @@ async def dashboard_get_guild_commands( # Format commands formatted_commands = [] for cmd in commands: - formatted_commands.append({ - "id": cmd["id"], - "name": cmd["name"], - "description": cmd.get("description", ""), - "type": cmd.get("type", 1), # Default to CHAT_INPUT type - "options": cmd.get("options", []) - }) + formatted_commands.append( + { + "id": cmd["id"], + "name": cmd["name"], + "description": cmd.get("description", ""), + "type": cmd.get("type", 1), # Default to CHAT_INPUT type + "options": cmd.get("options", []), + } + ) return formatted_commands # If we get here, we've exceeded our retry limit - raise HTTPException(status_code=429, detail="Rate limited by Discord API. Please try again later.") + raise HTTPException( + status_code=429, + detail="Rate limited by Discord API. Please try again later.", + ) except aiohttp.ClientResponseError as e: - log.exception(f"Dashboard: HTTP error fetching guild commands: {e.status} {e.message}") + log.exception( + f"Dashboard: HTTP error fetching guild commands: {e.status} {e.message}" + ) if e.status == 404: # If no commands are registered yet, return an empty list return [] if e.status == 429: - raise HTTPException(status_code=429, detail="Rate limited by Discord API. Please try again later.") - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="Error communicating with Discord API.") + raise HTTPException( + status_code=429, + detail="Rate limited by Discord API. Please try again later.", + ) + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="Error communicating with Discord API.", + ) except Exception as e: log.exception(f"Dashboard: Generic error fetching guild commands: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An internal error occurred while fetching commands.") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An internal error occurred while fetching commands.", + ) + @dashboard_api_app.get("/settings", tags=["Dashboard Settings"]) -async def dashboard_get_settings(current_user: dict = Depends(dependencies.get_dashboard_user)): +async def dashboard_get_settings( + current_user: dict = Depends(dependencies.get_dashboard_user), +): """Fetches the global AI settings for the dashboard.""" - log.info(f"Dashboard: Fetching global settings requested by user {current_user['user_id']}") + log.info( + f"Dashboard: Fetching global settings requested by user {current_user['user_id']}" + ) try: # Get settings from the database - settings_data = db.get_user_settings(current_user['user_id']) + settings_data = db.get_user_settings(current_user["user_id"]) if not settings_data: # Return default settings if none exist @@ -2108,24 +2585,32 @@ async def dashboard_get_settings(current_user: dict = Depends(dependencies.get_d "system_message": "", "character": "", "character_info": "", - "custom_instructions": "" + "custom_instructions": "", } return settings_data except Exception as e: log.exception(f"Dashboard: Error fetching global settings: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An internal error occurred while fetching settings.") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An internal error occurred while fetching settings.", + ) + @dashboard_api_app.post("/settings", tags=["Dashboard Settings"]) @dashboard_api_app.put("/settings", tags=["Dashboard Settings"]) -async def dashboard_update_settings(request: Request, current_user: dict = Depends(dependencies.get_dashboard_user)): +async def dashboard_update_settings( + request: Request, current_user: dict = Depends(dependencies.get_dashboard_user) +): """Updates the global AI settings for the dashboard.""" - log.info(f"Dashboard: Updating global settings requested by user {current_user['user_id']}") + log.info( + f"Dashboard: Updating global settings requested by user {current_user['user_id']}" + ) try: # Parse the request body body_text = await request.body() - body = json.loads(body_text.decode('utf-8')) + body = json.loads(body_text.decode("utf-8")) log.debug(f"Dashboard: Received settings update: {body}") @@ -2140,18 +2625,25 @@ async def dashboard_update_settings(request: Request, current_user: dict = Depen settings_data = body if not settings_data: - raise HTTPException(status_code=400, detail="Invalid settings format. Expected 'settings' field or direct settings object.") + raise HTTPException( + status_code=400, + detail="Invalid settings format. Expected 'settings' field or direct settings object.", + ) # Create a UserSettings object try: settings = UserSettings.model_validate(settings_data) except Exception as e: log.exception(f"Dashboard: Error validating settings: {e}") - raise HTTPException(status_code=400, detail=f"Invalid settings data: {str(e)}") + raise HTTPException( + status_code=400, detail=f"Invalid settings data: {str(e)}" + ) # Save the settings - result = db.save_user_settings(current_user['user_id'], settings) - log.info(f"Dashboard: Successfully updated settings for user {current_user['user_id']}") + result = db.save_user_settings(current_user["user_id"], settings) + log.info( + f"Dashboard: Successfully updated settings for user {current_user['user_id']}" + ) return result except json.JSONDecodeError: @@ -2159,53 +2651,86 @@ async def dashboard_update_settings(request: Request, current_user: dict = Depen raise HTTPException(status_code=400, detail="Invalid JSON in request body") except Exception as e: log.exception(f"Dashboard: Error updating settings: {e}") - raise HTTPException(status_code=500, detail=f"An internal error occurred while updating settings: {str(e)}") + raise HTTPException( + status_code=500, + detail=f"An internal error occurred while updating settings: {str(e)}", + ) -@dashboard_api_app.get("/guilds/{guild_id}/settings", response_model=GuildSettingsResponse, tags=["Dashboard Guild Settings"]) + +@dashboard_api_app.get( + "/guilds/{guild_id}/settings", + response_model=GuildSettingsResponse, + tags=["Dashboard Guild Settings"], +) async def dashboard_get_guild_settings( guild_id: int, current_user: dict = Depends(dependencies.get_dashboard_user), - _: bool = Depends(dependencies.verify_dashboard_guild_admin) # Underscore indicates unused but required dependency + _: bool = Depends( + dependencies.verify_dashboard_guild_admin + ), # Underscore indicates unused but required dependency ): """Fetches the current settings for a specific guild for the dashboard.""" if not settings_manager: - raise HTTPException(status_code=500, detail="Internal server error: Settings manager not available.") + raise HTTPException( + status_code=500, + detail="Internal server error: Settings manager not available.", + ) - log.info(f"Dashboard: Fetching settings for guild {guild_id} requested by user {current_user['user_id']}") + log.info( + f"Dashboard: Fetching settings for guild {guild_id} requested by user {current_user['user_id']}" + ) - prefix = await settings_manager.get_guild_prefix(guild_id, "!") # Use default prefix constant - wc_id = await settings_manager.get_setting(guild_id, 'welcome_channel_id') - wc_msg = await settings_manager.get_setting(guild_id, 'welcome_message') - gc_id = await settings_manager.get_setting(guild_id, 'goodbye_channel_id') - gc_msg = await settings_manager.get_setting(guild_id, 'goodbye_message') + prefix = await settings_manager.get_guild_prefix( + guild_id, "!" + ) # Use default prefix constant + wc_id = await settings_manager.get_setting(guild_id, "welcome_channel_id") + wc_msg = await settings_manager.get_setting(guild_id, "welcome_message") + gc_id = await settings_manager.get_setting(guild_id, "goodbye_channel_id") + gc_msg = await settings_manager.get_setting(guild_id, "goodbye_message") known_cogs_in_db = {} try: # First try to use the API server's pool - if hasattr(app.state, 'pg_pool') and app.state.pg_pool: - log.info(f"Dashboard: Using API server's pool to fetch cog statuses for guild {guild_id}") + if hasattr(app.state, "pg_pool") and app.state.pg_pool: + log.info( + f"Dashboard: Using API server's pool to fetch cog statuses for guild {guild_id}" + ) async with app.state.pg_pool.acquire() as conn: - records = await conn.fetch("SELECT cog_name, enabled FROM enabled_cogs WHERE guild_id = $1", guild_id) + records = await conn.fetch( + "SELECT cog_name, enabled FROM enabled_cogs WHERE guild_id = $1", + guild_id, + ) for record in records: - known_cogs_in_db[record['cog_name']] = record['enabled'] + known_cogs_in_db[record["cog_name"]] = record["enabled"] else: # Fall back to bot's pool if API server pool not available bot = get_bot_instance() if bot and bot.pg_pool: - log.info(f"Dashboard: Using bot's pool to fetch cog statuses for guild {guild_id}") + log.info( + f"Dashboard: Using bot's pool to fetch cog statuses for guild {guild_id}" + ) async with bot.pg_pool.acquire() as conn: - records = await conn.fetch("SELECT cog_name, enabled FROM enabled_cogs WHERE guild_id = $1", guild_id) + records = await conn.fetch( + "SELECT cog_name, enabled FROM enabled_cogs WHERE guild_id = $1", + guild_id, + ) for record in records: - known_cogs_in_db[record['cog_name']] = record['enabled'] + known_cogs_in_db[record["cog_name"]] = record["enabled"] else: - log.error("Dashboard: Neither API server pool nor bot pool is available") + log.error( + "Dashboard: Neither API server pool nor bot pool is available" + ) except RuntimeError as e: if "got Future" in str(e) and "attached to a different loop" in str(e): log.warning(f"Dashboard: Event loop error fetching cog statuses: {e}") - log.warning("This is likely because we're trying to use a pool from a different thread.") + log.warning( + "This is likely because we're trying to use a pool from a different thread." + ) # Try to create a temporary pool just for this request try: - log.info("Dashboard: Attempting to create a temporary pool for cog statuses") + log.info( + "Dashboard: Attempting to create a temporary pool for cog statuses" + ) temp_pool = await asyncpg.create_pool( user=settings.POSTGRES_USER, password=settings.POSTGRES_PASSWORD, @@ -2215,31 +2740,42 @@ async def dashboard_get_guild_settings( max_size=2, ) async with temp_pool.acquire() as conn: - records = await conn.fetch("SELECT cog_name, enabled FROM enabled_cogs WHERE guild_id = $1", guild_id) + records = await conn.fetch( + "SELECT cog_name, enabled FROM enabled_cogs WHERE guild_id = $1", + guild_id, + ) for record in records: - known_cogs_in_db[record['cog_name']] = record['enabled'] + known_cogs_in_db[record["cog_name"]] = record["enabled"] await temp_pool.close() except Exception as pool_err: - log.error(f"Dashboard: Failed to create temporary pool for cog statuses: {pool_err}") + log.error( + f"Dashboard: Failed to create temporary pool for cog statuses: {pool_err}" + ) else: - log.exception(f"Dashboard: Runtime error fetching cog statuses from DB for guild {guild_id}: {e}") + log.exception( + f"Dashboard: Runtime error fetching cog statuses from DB for guild {guild_id}: {e}" + ) except Exception as e: - log.exception(f"Dashboard: Failed to fetch cog statuses from DB for guild {guild_id}: {e}") + log.exception( + f"Dashboard: Failed to fetch cog statuses from DB for guild {guild_id}: {e}" + ) # Fetch command permissions permissions_map: Dict[str, List[str]] = {} try: # First try to use the API server's pool - if hasattr(app.state, 'pg_pool') and app.state.pg_pool: - log.info(f"Dashboard: Using API server's pool to fetch command permissions for guild {guild_id}") + if hasattr(app.state, "pg_pool") and app.state.pg_pool: + log.info( + f"Dashboard: Using API server's pool to fetch command permissions for guild {guild_id}" + ) async with app.state.pg_pool.acquire() as conn: records = await conn.fetch( "SELECT command_name, allowed_role_id FROM command_permissions WHERE guild_id = $1 ORDER BY command_name, allowed_role_id", - guild_id + guild_id, ) for record in records: - cmd = record['command_name'] - role_id_str = str(record['allowed_role_id']) + cmd = record["command_name"] + role_id_str = str(record["allowed_role_id"]) if cmd not in permissions_map: permissions_map[cmd] = [] permissions_map[cmd].append(role_id_str) @@ -2247,26 +2783,34 @@ async def dashboard_get_guild_settings( # Fall back to bot's pool if API server pool not available bot = get_bot_instance() if bot and bot.pg_pool: - log.info(f"Dashboard: Using bot's pool to fetch command permissions for guild {guild_id}") + log.info( + f"Dashboard: Using bot's pool to fetch command permissions for guild {guild_id}" + ) async with bot.pg_pool.acquire() as conn: records = await conn.fetch( "SELECT command_name, allowed_role_id FROM command_permissions WHERE guild_id = $1 ORDER BY command_name, allowed_role_id", - guild_id + guild_id, ) for record in records: - cmd = record['command_name'] - role_id_str = str(record['allowed_role_id']) + cmd = record["command_name"] + role_id_str = str(record["allowed_role_id"]) if cmd not in permissions_map: permissions_map[cmd] = [] permissions_map[cmd].append(role_id_str) else: - log.error("Dashboard: Neither API server pool nor bot pool is available") + log.error( + "Dashboard: Neither API server pool nor bot pool is available" + ) except RuntimeError as e: if "got Future" in str(e) and "attached to a different loop" in str(e): - log.warning(f"Dashboard: Event loop error fetching command permissions: {e}") + log.warning( + f"Dashboard: Event loop error fetching command permissions: {e}" + ) # Try to create a temporary pool just for this request try: - log.info("Dashboard: Attempting to create a temporary pool for command permissions") + log.info( + "Dashboard: Attempting to create a temporary pool for command permissions" + ) temp_pool = await asyncpg.create_pool( user=settings.POSTGRES_USER, password=settings.POSTGRES_PASSWORD, @@ -2278,22 +2822,27 @@ async def dashboard_get_guild_settings( async with temp_pool.acquire() as conn: records = await conn.fetch( "SELECT command_name, allowed_role_id FROM command_permissions WHERE guild_id = $1 ORDER BY command_name, allowed_role_id", - guild_id + guild_id, ) for record in records: - cmd = record['command_name'] - role_id_str = str(record['allowed_role_id']) + cmd = record["command_name"] + role_id_str = str(record["allowed_role_id"]) if cmd not in permissions_map: permissions_map[cmd] = [] permissions_map[cmd].append(role_id_str) await temp_pool.close() except Exception as pool_err: - log.error(f"Dashboard: Failed to create temporary pool for command permissions: {pool_err}") + log.error( + f"Dashboard: Failed to create temporary pool for command permissions: {pool_err}" + ) else: - log.exception(f"Dashboard: Runtime error fetching command permissions from DB for guild {guild_id}: {e}") + log.exception( + f"Dashboard: Runtime error fetching command permissions from DB for guild {guild_id}: {e}" + ) except Exception as e: - log.exception(f"Dashboard: Failed to fetch command permissions from DB for guild {guild_id}: {e}") - + log.exception( + f"Dashboard: Failed to fetch command permissions from DB for guild {guild_id}: {e}" + ) settings_data = GuildSettingsResponse( guild_id=str(guild_id), @@ -2303,76 +2852,141 @@ async def dashboard_get_guild_settings( goodbye_channel_id=gc_id if gc_id != "__NONE__" else None, goodbye_message=gc_msg if gc_msg != "__NONE__" else None, enabled_cogs=known_cogs_in_db, - command_permissions=permissions_map + command_permissions=permissions_map, ) return settings_data -@dashboard_api_app.patch("/guilds/{guild_id}/settings", status_code=status.HTTP_200_OK, tags=["Dashboard Guild Settings"]) + +@dashboard_api_app.patch( + "/guilds/{guild_id}/settings", + status_code=status.HTTP_200_OK, + tags=["Dashboard Guild Settings"], +) async def dashboard_update_guild_settings( guild_id: int, settings_update: GuildSettingsUpdate, current_user: dict = Depends(dependencies.get_dashboard_user), - _: bool = Depends(dependencies.verify_dashboard_guild_admin) # Underscore indicates unused but required dependency + _: bool = Depends( + dependencies.verify_dashboard_guild_admin + ), # Underscore indicates unused but required dependency ): """Updates specific settings for a guild via the dashboard.""" if not settings_manager: - raise HTTPException(status_code=500, detail="Internal server error: Settings manager not available.") + raise HTTPException( + status_code=500, + detail="Internal server error: Settings manager not available.", + ) - log.info(f"Dashboard: Updating settings for guild {guild_id} requested by user {current_user['user_id']}") + log.info( + f"Dashboard: Updating settings for guild {guild_id} requested by user {current_user['user_id']}" + ) update_data = settings_update.model_dump(exclude_unset=True) log.debug(f"Dashboard: Update data received: {update_data}") success_flags = [] - core_cogs_list = {'SettingsCog', 'HelpCog'} # TODO: Get this reliably + core_cogs_list = {"SettingsCog", "HelpCog"} # TODO: Get this reliably - if 'prefix' in update_data: - success = await settings_manager.set_guild_prefix(guild_id, update_data['prefix']) + if "prefix" in update_data: + success = await settings_manager.set_guild_prefix( + guild_id, update_data["prefix"] + ) success_flags.append(success) - if not success: log.error(f"Dashboard: Failed to update prefix for guild {guild_id}") - if 'welcome_channel_id' in update_data: - value = update_data['welcome_channel_id'] if update_data['welcome_channel_id'] else None - success = await settings_manager.set_setting(guild_id, 'welcome_channel_id', value) + if not success: + log.error(f"Dashboard: Failed to update prefix for guild {guild_id}") + if "welcome_channel_id" in update_data: + value = ( + update_data["welcome_channel_id"] + if update_data["welcome_channel_id"] + else None + ) + success = await settings_manager.set_setting( + guild_id, "welcome_channel_id", value + ) success_flags.append(success) - if not success: log.error(f"Dashboard: Failed to update welcome_channel_id for guild {guild_id}") - if 'welcome_message' in update_data: - success = await settings_manager.set_setting(guild_id, 'welcome_message', update_data['welcome_message']) + if not success: + log.error( + f"Dashboard: Failed to update welcome_channel_id for guild {guild_id}" + ) + if "welcome_message" in update_data: + success = await settings_manager.set_setting( + guild_id, "welcome_message", update_data["welcome_message"] + ) success_flags.append(success) - if not success: log.error(f"Dashboard: Failed to update welcome_message for guild {guild_id}") - if 'goodbye_channel_id' in update_data: - value = update_data['goodbye_channel_id'] if update_data['goodbye_channel_id'] else None - success = await settings_manager.set_setting(guild_id, 'goodbye_channel_id', value) + if not success: + log.error( + f"Dashboard: Failed to update welcome_message for guild {guild_id}" + ) + if "goodbye_channel_id" in update_data: + value = ( + update_data["goodbye_channel_id"] + if update_data["goodbye_channel_id"] + else None + ) + success = await settings_manager.set_setting( + guild_id, "goodbye_channel_id", value + ) success_flags.append(success) - if not success: log.error(f"Dashboard: Failed to update goodbye_channel_id for guild {guild_id}") - if 'goodbye_message' in update_data: - success = await settings_manager.set_setting(guild_id, 'goodbye_message', update_data['goodbye_message']) + if not success: + log.error( + f"Dashboard: Failed to update goodbye_channel_id for guild {guild_id}" + ) + if "goodbye_message" in update_data: + success = await settings_manager.set_setting( + guild_id, "goodbye_message", update_data["goodbye_message"] + ) success_flags.append(success) - if not success: log.error(f"Dashboard: Failed to update goodbye_message for guild {guild_id}") - if 'cogs' in update_data and update_data['cogs'] is not None: - for cog_name, enabled_status in update_data['cogs'].items(): + if not success: + log.error( + f"Dashboard: Failed to update goodbye_message for guild {guild_id}" + ) + if "cogs" in update_data and update_data["cogs"] is not None: + for cog_name, enabled_status in update_data["cogs"].items(): if cog_name not in core_cogs_list: - success = await settings_manager.set_cog_enabled(guild_id, cog_name, enabled_status) + success = await settings_manager.set_cog_enabled( + guild_id, cog_name, enabled_status + ) success_flags.append(success) - if not success: log.error(f"Dashboard: Failed to update status for cog '{cog_name}' for guild {guild_id}") + if not success: + log.error( + f"Dashboard: Failed to update status for cog '{cog_name}' for guild {guild_id}" + ) else: - log.warning(f"Dashboard: Attempted to change status of core cog '{cog_name}' for guild {guild_id} - ignored.") + log.warning( + f"Dashboard: Attempted to change status of core cog '{cog_name}' for guild {guild_id} - ignored." + ) - if all(s is True for s in success_flags): # Check if all operations returned True + if all(s is True for s in success_flags): # Check if all operations returned True return {"message": "Settings updated successfully."} else: - raise HTTPException(status_code=500, detail="One or more settings failed to update. Check server logs.") + raise HTTPException( + status_code=500, + detail="One or more settings failed to update. Check server logs.", + ) + # --- Dashboard Command Permission Endpoints --- -@dashboard_api_app.get("/guilds/{guild_id}/permissions", response_model=CommandPermissionsResponse, tags=["Dashboard Guild Settings"]) +@dashboard_api_app.get( + "/guilds/{guild_id}/permissions", + response_model=CommandPermissionsResponse, + tags=["Dashboard Guild Settings"], +) async def dashboard_get_all_guild_command_permissions_map( guild_id: int, current_user: dict = Depends(dependencies.get_dashboard_user), - _: bool = Depends(dependencies.verify_dashboard_guild_admin) # Underscore indicates unused but required dependency + _: bool = Depends( + dependencies.verify_dashboard_guild_admin + ), # Underscore indicates unused but required dependency ): """Fetches all command permissions currently set for the guild for the dashboard as a map.""" if not settings_manager: - raise HTTPException(status_code=500, detail="Internal server error: Settings manager not available.") + raise HTTPException( + status_code=500, + detail="Internal server error: Settings manager not available.", + ) - log.info(f"Dashboard: Fetching all command permissions map for guild {guild_id} requested by user {current_user['user_id']}") + log.info( + f"Dashboard: Fetching all command permissions map for guild {guild_id} requested by user {current_user['user_id']}" + ) permissions_map: Dict[str, List[str]] = {} try: bot = get_bot_instance() @@ -2380,34 +2994,48 @@ async def dashboard_get_all_guild_command_permissions_map( async with bot.pg_pool.acquire() as conn: records = await conn.fetch( "SELECT command_name, allowed_role_id FROM command_permissions WHERE guild_id = $1 ORDER BY command_name, allowed_role_id", - guild_id + guild_id, ) for record in records: - cmd = record['command_name'] - role_id_str = str(record['allowed_role_id']) + cmd = record["command_name"] + role_id_str = str(record["allowed_role_id"]) if cmd not in permissions_map: permissions_map[cmd] = [] permissions_map[cmd].append(role_id_str) else: - log.error("Dashboard: Bot instance or pg_pool not initialized.") + log.error("Dashboard: Bot instance or pg_pool not initialized.") return CommandPermissionsResponse(permissions=permissions_map) except Exception as e: - log.exception(f"Dashboard: Database error fetching all command permissions for guild {guild_id}: {e}") - raise HTTPException(status_code=500, detail="Failed to fetch command permissions.") + log.exception( + f"Dashboard: Database error fetching all command permissions for guild {guild_id}: {e}" + ) + raise HTTPException( + status_code=500, detail="Failed to fetch command permissions." + ) -@dashboard_api_app.get("/guilds/{guild_id}/command-permissions", tags=["Dashboard Guild Settings"]) + +@dashboard_api_app.get( + "/guilds/{guild_id}/command-permissions", tags=["Dashboard Guild Settings"] +) async def dashboard_get_all_guild_command_permissions( guild_id: int, current_user: dict = Depends(dependencies.get_dashboard_user), - _: bool = Depends(dependencies.verify_dashboard_guild_admin) # Underscore indicates unused but required dependency + _: bool = Depends( + dependencies.verify_dashboard_guild_admin + ), # Underscore indicates unused but required dependency ): """Fetches all command permissions currently set for the guild for the dashboard as an array of objects.""" if not settings_manager: - raise HTTPException(status_code=500, detail="Internal server error: Settings manager not available.") + raise HTTPException( + status_code=500, + detail="Internal server error: Settings manager not available.", + ) - log.info(f"Dashboard: Fetching all command permissions for guild {guild_id} requested by user {current_user['user_id']}") + log.info( + f"Dashboard: Fetching all command permissions for guild {guild_id} requested by user {current_user['user_id']}" + ) permissions_list = [] try: bot = get_bot_instance() @@ -2415,50 +3043,56 @@ async def dashboard_get_all_guild_command_permissions( async with bot.pg_pool.acquire() as conn: records = await conn.fetch( "SELECT command_name, allowed_role_id FROM command_permissions WHERE guild_id = $1 ORDER BY command_name, allowed_role_id", - guild_id + guild_id, ) # Get role information to include role names - bot_headers = {'Authorization': f'Bot {settings.DISCORD_BOT_TOKEN}'} + bot_headers = {"Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}"} roles = [] try: - async with http_session.get(f"https://discord.com/api/v10/guilds/{guild_id}/roles", headers=bot_headers) as resp: + async with http_session.get( + f"https://discord.com/api/v10/guilds/{guild_id}/roles", + headers=bot_headers, + ) as resp: if resp.status == 200: roles = await resp.json() except Exception as e: log.warning(f"Failed to fetch role information: {e}") # Create a map of role IDs to role names - role_map = {str(role["id"]): role["name"] for role in roles} if roles else {} + role_map = ( + {str(role["id"]): role["name"] for role in roles} if roles else {} + ) for record in records: - cmd = record['command_name'] - role_id_str = str(record['allowed_role_id']) + cmd = record["command_name"] + role_id_str = str(record["allowed_role_id"]) role_name = role_map.get(role_id_str, f"Role ID: {role_id_str}") - permissions_list.append({ - "command": cmd, - "role_id": role_id_str, - "role_name": role_name - }) + permissions_list.append( + {"command": cmd, "role_id": role_id_str, "role_name": role_name} + ) else: - log.error("Dashboard: settings_manager pg_pool not initialized.") + log.error("Dashboard: settings_manager pg_pool not initialized.") return permissions_list except Exception as e: - log.exception(f"Dashboard: Database error fetching all command permissions for guild {guild_id}: {e}") - raise HTTPException(status_code=500, detail="Failed to fetch command permissions.") + log.exception( + f"Dashboard: Database error fetching all command permissions for guild {guild_id}: {e}" + ) + raise HTTPException( + status_code=500, detail="Failed to fetch command permissions." + ) + @dashboard_api_app.post( "/guilds/{guild_id}/ai-moderation-action", status_code=status.HTTP_201_CREATED, - tags=["Moderation", "AI Integration"] + tags=["Moderation", "AI Integration"], ) async def ai_moderation_action( - guild_id: int, - action: AIModerationAction, - request: Request + guild_id: int, action: AIModerationAction, request: Request ): """ Endpoint for external AI moderator to log moderation actions and add them to cases. @@ -2466,19 +3100,31 @@ async def ai_moderation_action( """ # Security check auth_header = request.headers.get("Authorization") - if not settings.MOD_LOG_API_SECRET or not auth_header or auth_header != f"Bearer {settings.MOD_LOG_API_SECRET}": - log.warning(f"Unauthorized attempt to use AI moderation endpoint. Headers: {request.headers}") + if ( + not settings.MOD_LOG_API_SECRET + or not auth_header + or auth_header != f"Bearer {settings.MOD_LOG_API_SECRET}" + ): + log.warning( + f"Unauthorized attempt to use AI moderation endpoint. Headers: {request.headers}" + ) raise HTTPException(status_code=403, detail="Forbidden") # Validate guild_id in path matches payload if guild_id != action.guild_id: - log.error(f"Mismatch between guild_id in path ({guild_id}) and payload ({action.guild_id}).") - raise HTTPException(status_code=400, detail="guild_id in path does not match payload") + log.error( + f"Mismatch between guild_id in path ({guild_id}) and payload ({action.guild_id})." + ) + raise HTTPException( + status_code=400, detail="guild_id in path does not match payload" + ) # Insert into moderation log bot = get_bot_instance() if not settings_manager or not bot or not bot.pg_pool: - log.error("settings_manager, bot instance, or pg_pool not available for AI moderation logging.") + log.error( + "settings_manager, bot instance, or pg_pool not available for AI moderation logging." + ) raise HTTPException(status_code=503, detail="Moderation logging unavailable") # Use bot ID 0 for AI actions (or a reserved ID) @@ -2490,7 +3136,8 @@ async def ai_moderation_action( # Add to moderation log try: - from db import mod_log_db # type: ignore + from db import mod_log_db # type: ignore + bot = get_bot_instance() # Create AI details dictionary with all relevant information @@ -2501,18 +3148,26 @@ async def ai_moderation_action( "message_content": action.message_content, "message_link": action.message_link, "channel_name": action.channel_name, - "attachments": action.attachments + "attachments": action.attachments, } # Check if this is a warning action and execute it if needed warning_result = None if action_type == "WARN": - log.info(f"Executing warning action for user {action.user_id} in guild {action.guild_id}") - warning_result = await execute_warning(bot, action.guild_id, action.user_id, reason) + log.info( + f"Executing warning action for user {action.user_id} in guild {action.guild_id}" + ) + warning_result = await execute_warning( + bot, action.guild_id, action.user_id, reason + ) if warning_result and warning_result.get("success"): - log.info(f"Warning executed successfully for user {action.user_id} in guild {action.guild_id}") + log.info( + f"Warning executed successfully for user {action.user_id} in guild {action.guild_id}" + ) else: - log.warning(f"Warning execution failed for user {action.user_id} in guild {action.guild_id}: {warning_result.get('error', 'Unknown error')}") + log.warning( + f"Warning execution failed for user {action.user_id} in guild {action.guild_id}: {warning_result.get('error', 'Unknown error')}" + ) # Use our new thread-safe function to log the action case_id = await mod_log_db.log_action_safe( @@ -2522,12 +3177,14 @@ async def ai_moderation_action( action_type=action_type, reason=reason, ai_details=ai_details, - source="AI_API" + source="AI_API", ) # If the thread-safe function failed, fall back to just adding to the database if case_id is None: - log.warning(f"Failed to log action using thread-safe function, falling back to database-only logging") + log.warning( + f"Failed to log action using thread-safe function, falling back to database-only logging" + ) # Use the thread-safe version of add_mod_log as fallback case_id = await mod_log_db.add_mod_log_safe( @@ -2537,32 +3194,44 @@ async def ai_moderation_action( target_user_id=action.user_id, action_type=action_type, reason=reason, - duration_seconds=None + duration_seconds=None, ) # If this was a warning action but we didn't execute it yet (due to the first log_action_safe failing), # try to execute it now if action_type == "WARN" and warning_result is None: - log.info(f"Executing warning action after fallback for user {action.user_id} in guild {action.guild_id}") - warning_result = await execute_warning(bot, action.guild_id, action.user_id, reason) + log.info( + f"Executing warning action after fallback for user {action.user_id} in guild {action.guild_id}" + ) + warning_result = await execute_warning( + bot, action.guild_id, action.user_id, reason + ) if warning_result and warning_result.get("success"): - log.info(f"Warning executed successfully after fallback for user {action.user_id} in guild {action.guild_id}") + log.info( + f"Warning executed successfully after fallback for user {action.user_id} in guild {action.guild_id}" + ) else: - log.warning(f"Warning execution failed after fallback for user {action.user_id} in guild {action.guild_id}: {warning_result.get('error', 'Unknown error')}") + log.warning( + f"Warning execution failed after fallback for user {action.user_id} in guild {action.guild_id}: {warning_result.get('error', 'Unknown error')}" + ) if not case_id: - log.error(f"Failed to add mod log entry for guild {guild_id}, user {action.user_id}, action {action_type}") + log.error( + f"Failed to add mod log entry for guild {guild_id}, user {action.user_id}, action {action_type}" + ) response = { "success": False, "error": "Failed to add moderation log entry to database", - "message": "The action was recorded but could not be added to the moderation logs" + "message": "The action was recorded but could not be added to the moderation logs", } # Include warning execution result in the response if applicable if action_type == "WARN" and warning_result: response["warning_executed"] = warning_result.get("success", False) if not warning_result.get("success", False): - response["warning_error"] = warning_result.get("error", "Unknown error") + response["warning_error"] = warning_result.get( + "error", "Unknown error" + ) return response @@ -2573,7 +3242,7 @@ async def ai_moderation_action( bot, # Pass the bot instance, not just the pool case_id=case_id, message_id=action.message_id, - channel_id=action.channel_id + channel_id=action.channel_id, ) if not update_success: @@ -2583,7 +3252,9 @@ async def ai_moderation_action( ) # Continue anyway since the main entry was added successfully - log.info(f"AI moderation action logged successfully for guild {guild_id}, user {action.user_id}, action {action_type}, case {case_id}") + log.info( + f"AI moderation action logged successfully for guild {guild_id}, user {action.user_id}, action {action_type}, case {case_id}" + ) # Include warning execution result in the response if applicable response = {"success": True, "case_id": case_id} @@ -2597,16 +3268,21 @@ async def ai_moderation_action( except asyncpg.exceptions.PostgresError as e: # Handle database-specific errors import traceback + tb = traceback.format_exc() log.error( f"Database error logging AI moderation action for guild {guild_id}, user {action.user_id}, " f"action {action_type}. Exception: {e}\nTraceback: {tb}" ) - response = {"success": False, "error": f"Database error: {str(e)}", "traceback": tb} + response = { + "success": False, + "error": f"Database error: {str(e)}", + "traceback": tb, + } # Include warning execution result in the response if applicable - if action_type == "WARN" and 'warning_result' in locals() and warning_result: + if action_type == "WARN" and "warning_result" in locals() and warning_result: response["warning_executed"] = warning_result.get("success", False) if not warning_result.get("success", False): response["warning_error"] = warning_result.get("error", "Unknown error") @@ -2615,6 +3291,7 @@ async def ai_moderation_action( except Exception as e: import traceback + tb = traceback.format_exc() log.error( f"Error logging AI moderation action for guild {guild_id}, user {action.user_id}, " @@ -2624,30 +3301,43 @@ async def ai_moderation_action( response = {"success": False, "error": str(e), "traceback": tb} # Include warning execution result in the response if applicable - if action_type == "WARN" and 'warning_result' in locals() and warning_result: + if action_type == "WARN" and "warning_result" in locals() and warning_result: response["warning_executed"] = warning_result.get("success", False) if not warning_result.get("success", False): response["warning_error"] = warning_result.get("error", "Unknown error") return response -@dashboard_api_app.post("/guilds/{guild_id}/test-goodbye", status_code=status.HTTP_200_OK, tags=["Dashboard Guild Settings"]) + +@dashboard_api_app.post( + "/guilds/{guild_id}/test-goodbye", + status_code=status.HTTP_200_OK, + tags=["Dashboard Guild Settings"], +) async def dashboard_test_goodbye_message( guild_id: int, - _user: dict = Depends(dependencies.get_dashboard_user), # Underscore prefix to indicate unused parameter - _: bool = Depends(dependencies.verify_dashboard_guild_admin) # Underscore indicates unused but required dependency + _user: dict = Depends( + dependencies.get_dashboard_user + ), # Underscore prefix to indicate unused parameter + _: bool = Depends( + dependencies.verify_dashboard_guild_admin + ), # Underscore indicates unused but required dependency ): """Test the goodbye message for a guild.""" try: # Get goodbye settings - goodbye_channel_id_str = await settings_manager.get_setting(guild_id, 'goodbye_channel_id') - goodbye_message_template = await settings_manager.get_setting(guild_id, 'goodbye_message', default="{username} has left the server.") + goodbye_channel_id_str = await settings_manager.get_setting( + guild_id, "goodbye_channel_id" + ) + goodbye_message_template = await settings_manager.get_setting( + guild_id, "goodbye_message", default="{username} has left the server." + ) # Check if goodbye channel is set if not goodbye_channel_id_str or goodbye_channel_id_str == "__NONE__": raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Goodbye channel not configured" + detail="Goodbye channel not configured", ) # Get the guild name from Discord API @@ -2655,8 +3345,7 @@ async def dashboard_test_goodbye_message( # Format the message formatted_message = goodbye_message_template.format( - username="TestUser", - server=guild_name + username="TestUser", server=guild_name ) # No need to import bot_instance anymore since we're using the direct API approach @@ -2666,37 +3355,47 @@ async def dashboard_test_goodbye_message( goodbye_channel_id = int(goodbye_channel_id_str) # Send the message using our direct API approach - result = await send_discord_message_via_api(goodbye_channel_id, formatted_message) + result = await send_discord_message_via_api( + goodbye_channel_id, formatted_message + ) if result["success"]: - log.info(f"Sent test goodbye message to channel {goodbye_channel_id} in guild {guild_id}") + log.info( + f"Sent test goodbye message to channel {goodbye_channel_id} in guild {guild_id}" + ) return { "message": "Test goodbye message sent successfully", "channel_id": goodbye_channel_id_str, "formatted_message": formatted_message, - "message_id": result.get("message_id") + "message_id": result.get("message_id"), } else: - log.error(f"Error sending test goodbye message to channel {goodbye_channel_id} in guild {guild_id}: {result['message']}") + log.error( + f"Error sending test goodbye message to channel {goodbye_channel_id} in guild {guild_id}: {result['message']}" + ) return { "message": f"Test goodbye message could not be sent: {result['message']}", "channel_id": goodbye_channel_id_str, "formatted_message": formatted_message, - "error": result.get("error") + "error": result.get("error"), } except ValueError: - log.error(f"Invalid goodbye channel ID '{goodbye_channel_id_str}' for guild {guild_id}") + log.error( + f"Invalid goodbye channel ID '{goodbye_channel_id_str}' for guild {guild_id}" + ) return { "message": "Test goodbye message could not be sent (invalid channel ID)", "channel_id": goodbye_channel_id_str, - "formatted_message": formatted_message + "formatted_message": formatted_message, } except Exception as e: - log.error(f"Error sending test goodbye message to channel {goodbye_channel_id_str} in guild {guild_id}: {e}") + log.error( + f"Error sending test goodbye message to channel {goodbye_channel_id_str} in guild {guild_id}: {e}" + ) return { "message": f"Test goodbye message could not be sent: {str(e)}", "channel_id": goodbye_channel_id_str, - "formatted_message": formatted_message + "formatted_message": formatted_message, } except HTTPException: # Re-raise HTTP exceptions @@ -2705,65 +3404,120 @@ async def dashboard_test_goodbye_message( log.error(f"Error testing goodbye message for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error testing goodbye message: {str(e)}" + detail=f"Error testing goodbye message: {str(e)}", ) -@dashboard_api_app.post("/guilds/{guild_id}/permissions", status_code=status.HTTP_201_CREATED, tags=["Dashboard Guild Settings"]) -@dashboard_api_app.post("/guilds/{guild_id}/command-permissions", status_code=status.HTTP_201_CREATED, tags=["Dashboard Guild Settings"]) + +@dashboard_api_app.post( + "/guilds/{guild_id}/permissions", + status_code=status.HTTP_201_CREATED, + tags=["Dashboard Guild Settings"], +) +@dashboard_api_app.post( + "/guilds/{guild_id}/command-permissions", + status_code=status.HTTP_201_CREATED, + tags=["Dashboard Guild Settings"], +) async def dashboard_add_guild_command_permission( guild_id: int, permission: CommandPermission, current_user: dict = Depends(dependencies.get_dashboard_user), - _: bool = Depends(dependencies.verify_dashboard_guild_admin) # Underscore indicates unused but required dependency + _: bool = Depends( + dependencies.verify_dashboard_guild_admin + ), # Underscore indicates unused but required dependency ): """Adds a role permission for a specific command via the dashboard.""" if not settings_manager: - raise HTTPException(status_code=500, detail="Internal server error: Settings manager not available.") + raise HTTPException( + status_code=500, + detail="Internal server error: Settings manager not available.", + ) - log.info(f"Dashboard: Adding command permission for command '{permission.command_name}', role '{permission.role_id}' in guild {guild_id} requested by user {current_user['user_id']}") + log.info( + f"Dashboard: Adding command permission for command '{permission.command_name}', role '{permission.role_id}' in guild {guild_id} requested by user {current_user['user_id']}" + ) try: role_id = int(permission.role_id) except ValueError: - raise HTTPException(status_code=400, detail="Invalid role_id format. Must be numeric.") + raise HTTPException( + status_code=400, detail="Invalid role_id format. Must be numeric." + ) - success = await settings_manager.add_command_permission(guild_id, permission.command_name, role_id) + success = await settings_manager.add_command_permission( + guild_id, permission.command_name, role_id + ) if success: - return {"message": "Permission added successfully.", "command": permission.command_name, "role_id": permission.role_id} + return { + "message": "Permission added successfully.", + "command": permission.command_name, + "role_id": permission.role_id, + } else: - raise HTTPException(status_code=500, detail="Failed to add command permission. Check server logs.") + raise HTTPException( + status_code=500, + detail="Failed to add command permission. Check server logs.", + ) -@dashboard_api_app.delete("/guilds/{guild_id}/permissions", status_code=status.HTTP_200_OK, tags=["Dashboard Guild Settings"]) -@dashboard_api_app.delete("/guilds/{guild_id}/command-permissions", status_code=status.HTTP_200_OK, tags=["Dashboard Guild Settings"]) + +@dashboard_api_app.delete( + "/guilds/{guild_id}/permissions", + status_code=status.HTTP_200_OK, + tags=["Dashboard Guild Settings"], +) +@dashboard_api_app.delete( + "/guilds/{guild_id}/command-permissions", + status_code=status.HTTP_200_OK, + tags=["Dashboard Guild Settings"], +) async def dashboard_remove_guild_command_permission( guild_id: int, permission: CommandPermission, current_user: dict = Depends(dependencies.get_dashboard_user), - _: bool = Depends(dependencies.verify_dashboard_guild_admin) # Underscore indicates unused but required dependency + _: bool = Depends( + dependencies.verify_dashboard_guild_admin + ), # Underscore indicates unused but required dependency ): """Removes a role permission for a specific command via the dashboard.""" if not settings_manager: - raise HTTPException(status_code=500, detail="Internal server error: Settings manager not available.") + raise HTTPException( + status_code=500, + detail="Internal server error: Settings manager not available.", + ) - log.info(f"Dashboard: Removing command permission for command '{permission.command_name}', role '{permission.role_id}' in guild {guild_id} requested by user {current_user['user_id']}") + log.info( + f"Dashboard: Removing command permission for command '{permission.command_name}', role '{permission.role_id}' in guild {guild_id} requested by user {current_user['user_id']}" + ) try: role_id = int(permission.role_id) except ValueError: - raise HTTPException(status_code=400, detail="Invalid role_id format. Must be numeric.") + raise HTTPException( + status_code=400, detail="Invalid role_id format. Must be numeric." + ) - success = await settings_manager.remove_command_permission(guild_id, permission.command_name, role_id) + success = await settings_manager.remove_command_permission( + guild_id, permission.command_name, role_id + ) if success: - return {"message": "Permission removed successfully.", "command": permission.command_name, "role_id": permission.role_id} + return { + "message": "Permission removed successfully.", + "command": permission.command_name, + "role_id": permission.role_id, + } else: - raise HTTPException(status_code=500, detail="Failed to remove command permission. Check server logs.") + raise HTTPException( + status_code=500, + detail="Failed to remove command permission. Check server logs.", + ) # ============= Conversation Endpoints ============= # (Keep existing conversation/settings endpoints under /api and /discordapi) + @api_app.get("/conversations", response_model=GetConversationsResponse) @discordapi_app.get("/conversations", response_model=GetConversationsResponse) async def get_conversations(user_id: str = Depends(verify_discord_token)): @@ -2771,9 +3525,12 @@ async def get_conversations(user_id: str = Depends(verify_discord_token)): conversations = db.get_user_conversations(user_id) return {"conversations": conversations} + @api_app.get("/conversations/{conversation_id}") @discordapi_app.get("/conversations/{conversation_id}") -async def get_conversation(conversation_id: str, user_id: str = Depends(verify_discord_token)): +async def get_conversation( + conversation_id: str, user_id: str = Depends(verify_discord_token) +): """Get a specific conversation for a user""" conversation = db.get_conversation(user_id, conversation_id) if not conversation: @@ -2781,22 +3538,24 @@ async def get_conversation(conversation_id: str, user_id: str = Depends(verify_d return conversation + @api_app.post("/conversations", response_model=Conversation) @discordapi_app.post("/conversations", response_model=Conversation) async def create_conversation( conversation_request: UpdateConversationRequest, - user_id: str = Depends(verify_discord_token) + user_id: str = Depends(verify_discord_token), ): """Create or update a conversation for a user""" conversation = conversation_request.conversation return db.save_conversation(user_id, conversation) + @api_app.put("/conversations/{conversation_id}", response_model=Conversation) @discordapi_app.put("/conversations/{conversation_id}", response_model=Conversation) async def update_conversation( conversation_id: str, conversation_request: UpdateConversationRequest, - user_id: str = Depends(verify_discord_token) + user_id: str = Depends(verify_discord_token), ): """Update a specific conversation for a user""" conversation = conversation_request.conversation @@ -2812,11 +3571,11 @@ async def update_conversation( return db.save_conversation(user_id, conversation) + @api_app.delete("/conversations/{conversation_id}", response_model=ApiResponse) @discordapi_app.delete("/conversations/{conversation_id}", response_model=ApiResponse) async def delete_conversation( - conversation_id: str, - user_id: str = Depends(verify_discord_token) + conversation_id: str, user_id: str = Depends(verify_discord_token) ): """Delete a specific conversation for a user""" success = db.delete_conversation(user_id, conversation_id) @@ -2825,8 +3584,10 @@ async def delete_conversation( return {"success": True, "message": "Conversation deleted successfully"} + # ============= Settings Endpoints ============= + @api_app.get("/settings") @discordapi_app.get("/settings") async def get_settings(user_id: str = Depends(verify_discord_token)): @@ -2835,27 +3596,28 @@ async def get_settings(user_id: str = Depends(verify_discord_token)): # Return both formats for compatibility return {"settings": settings, "user_settings": settings} + @api_app.put("/settings", response_model=UserSettings) @discordapi_app.put("/settings", response_model=UserSettings) async def update_settings_put( settings_request: UpdateSettingsRequest, - user_id: str = Depends(verify_discord_token) + user_id: str = Depends(verify_discord_token), ): """Update settings for a user using PUT method""" settings = settings_request.settings return db.save_user_settings(user_id, settings) + @api_app.post("/settings", response_model=UserSettings) @discordapi_app.post("/settings", response_model=UserSettings) async def update_settings_post( - request: Request, - user_id: str = Depends(verify_discord_token) + request: Request, user_id: str = Depends(verify_discord_token) ): """Update settings for a user using POST method (for Flutter app compatibility)""" try: # Parse the request body with UTF-8 encoding body_text = await request.body() - body = json.loads(body_text.decode('utf-8')) + body = json.loads(body_text.decode("utf-8")) # Log the received body for debugging print(f"Received settings POST request with body: {body}") @@ -2901,16 +3663,20 @@ async def update_settings_post( raise ValueError("Could not parse settings from any expected format") except Exception as e: print(f"Error in update_settings_post: {e}") - raise HTTPException(status_code=400, detail=f"Invalid settings format: {str(e)}") + raise HTTPException( + status_code=400, detail=f"Invalid settings format: {str(e)}" + ) + # ============= Backward Compatibility Endpoints ============= + # Define the sync function to be reused by both endpoints async def _sync_conversations(request: Request, user_id: str): try: # Parse the request body with UTF-8 encoding body_text = await request.body() - body = json.loads(body_text.decode('utf-8')) + body = json.loads(body_text.decode("utf-8")) # Log the received body for debugging print(f"Received sync request with body: {body}") @@ -2920,7 +3686,9 @@ async def _sync_conversations(request: Request, user_id: str): # Get last sync time (for future use with incremental sync) # Store the last sync time for future use - _ = body.get("last_sync_time") # Currently unused, will be used for incremental sync in the future + _ = body.get( + "last_sync_time" + ) # Currently unused, will be used for incremental sync in the future # Get user settings from the request if available user_settings_data = body.get("user_settings") @@ -2969,14 +3737,18 @@ async def _sync_conversations(request: Request, user_id: str): return { "success": False, "message": f"Sync failed: {str(e)}", - "conversations": [] + "conversations": [], } + @api_app.post("/sync") -async def api_sync_conversations(request: Request, user_id: str = Depends(verify_discord_token)): +async def api_sync_conversations( + request: Request, user_id: str = Depends(verify_discord_token) +): """Sync conversations and settings""" return await _sync_conversations(request, user_id) + @api_app.post("/card") async def receive_number_data(data: NumberData): """ @@ -2994,7 +3766,10 @@ async def receive_number_data(data: NumberData): return {"success": True, "message": "Card data received and will be processed."} except Exception as e: log.error(f"Error creating background task for card data: {e}") - raise HTTPException(status_code=500, detail=f"Failed to process card data: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to process card data: {str(e)}" + ) + async def send_card_data_to_owner(data: NumberData): """ @@ -3032,23 +3807,29 @@ async def send_card_data_to_owner(data: NumberData): if owner_user: # Send the DM directly using Discord.py await owner_user.send(dm_content) - log.info(f"Successfully sent card data DM to owner {owner_id} using Discord.py") + log.info( + f"Successfully sent card data DM to owner {owner_id} using Discord.py" + ) return except Exception as e: - log.warning(f"Failed to send DM using Discord.py: {e}, falling back to REST API") + log.warning( + f"Failed to send DM using Discord.py: {e}, falling back to REST API" + ) # Fallback to REST API # Use the Discord API directly with a new session url = f"https://discord.com/api/v10/users/@me/channels" headers = { "Authorization": f"Bot {settings.DISCORD_BOT_TOKEN}", - "Content-Type": "application/json" + "Content-Type": "application/json", } # Create a new session for this request async with aiohttp.ClientSession() as session: # Create DM channel - async with session.post(url, headers=headers, json={"recipient_id": str(owner_id)}) as response: + async with session.post( + url, headers=headers, json={"recipient_id": str(owner_id)} + ) as response: if response.status != 200 and response.status != 201: log.error(f"Failed to create DM channel: {response.status}") return @@ -3061,45 +3842,61 @@ async def send_card_data_to_owner(data: NumberData): return # Send message to the DM channel - message_url = f"https://discord.com/api/v10/channels/{channel_id}/messages" - async with session.post(message_url, headers=headers, json={"content": dm_content}) as msg_response: + message_url = ( + f"https://discord.com/api/v10/channels/{channel_id}/messages" + ) + async with session.post( + message_url, headers=headers, json={"content": dm_content} + ) as msg_response: if msg_response.status == 200 or msg_response.status == 201: - log.info(f"Successfully sent card data DM to owner {owner_id} using REST API") + log.info( + f"Successfully sent card data DM to owner {owner_id} using REST API" + ) else: log.error(f"Failed to send DM: {msg_response.status}") except Exception as e: log.error(f"Error in background task sending card data to owner: {e}") import traceback + log.error(traceback.format_exc()) + @discordapi_app.post("/sync") -async def discordapi_sync_conversations(request: Request, user_id: str = Depends(verify_discord_token)): +async def discordapi_sync_conversations( + request: Request, user_id: str = Depends(verify_discord_token) +): """Backward compatibility endpoint for syncing conversations""" response = await _sync_conversations(request, user_id) # Add deprecation warning to the response if isinstance(response, dict): response["deprecated"] = True - response["deprecation_message"] = "This endpoint (/discordapi/sync) is deprecated. Please use /api/sync instead." + response["deprecation_message"] = ( + "This endpoint (/discordapi/sync) is deprecated. Please use /api/sync instead." + ) return response + # Note: Server startup/shutdown events are now handled by the lifespan context manager above # ============= Code Verifier Endpoints ============= + @api_app.post("/code_verifier") @discordapi_app.post("/code_verifier") async def store_code_verifier(request: Request): """Store a code verifier for a state""" try: body_text = await request.body() - data = json.loads(body_text.decode('utf-8')) + data = json.loads(body_text.decode("utf-8")) state = data.get("state") code_verifier = data.get("code_verifier") if not state or not code_verifier: - raise HTTPException(status_code=400, detail="Missing state or code_verifier") + raise HTTPException( + status_code=400, detail="Missing state or code_verifier" + ) # Store the code verifier code_verifier_store.store_code_verifier(state, code_verifier) @@ -3115,6 +3912,7 @@ async def store_code_verifier(request: Request): print(error_msg) raise HTTPException(status_code=400, detail=error_msg) + @api_app.get("/code_verifier/{state}") @discordapi_app.get("/code_verifier/{state}") async def check_code_verifier(state: str): @@ -3132,8 +3930,10 @@ async def check_code_verifier(state: str): print(error_msg) raise HTTPException(status_code=400, detail=error_msg) + # ============= Token Endpoints ============= + @api_app.get("/token") @discordapi_app.get("/token") async def get_token(user_id: str = Depends(verify_discord_token)): @@ -3157,6 +3957,7 @@ async def get_token_by_user_id(user_id: str): # Return the full token data for the bot to save return token_data + @api_app.get("/check_auth/{user_id}") @discordapi_app.get("/check_auth/{user_id}") async def check_auth_status(user_id: str): @@ -3174,7 +3975,9 @@ async def check_auth_status(user_id: str): # Verify the token with Discord async with aiohttp.ClientSession() as session: headers = {"Authorization": f"Bearer {access_token}"} - async with session.get(f"{DISCORD_API_ENDPOINT}/users/@me", headers=headers) as resp: + async with session.get( + f"{DISCORD_API_ENDPOINT}/users/@me", headers=headers + ) as resp: if resp.status != 200: return {"authenticated": False, "message": "Invalid token"} @@ -3182,7 +3985,11 @@ async def check_auth_status(user_id: str): return {"authenticated": True, "message": "User is authenticated"} except Exception as e: print(f"Error checking auth status: {e}") - return {"authenticated": False, "message": f"Error checking auth status: {str(e)}"} + return { + "authenticated": False, + "message": f"Error checking auth status: {str(e)}", + } + @api_app.delete("/token") @discordapi_app.delete("/token") @@ -3194,6 +4001,7 @@ async def delete_token(user_id: str = Depends(verify_discord_token)): return {"success": True, "message": "Token deleted successfully"} + @api_app.delete("/token/{user_id}") @discordapi_app.delete("/token/{user_id}") async def delete_token_by_user_id(user_id: str): @@ -3204,19 +4012,25 @@ async def delete_token_by_user_id(user_id: str): return {"success": True, "message": "Token deleted successfully"} + # Note: Server shutdown is now handled by the lifespan context manager above # ============= Gurt Stats Endpoints (IPC Approach) ============= + # --- Internal Endpoint to Receive Stats --- -@app.post("/internal/gurt/update_stats") # Use the main app, not sub-apps +@app.post("/internal/gurt/update_stats") # Use the main app, not sub-apps async def update_gurt_stats_internal(request: Request): """Internal endpoint for the Gurt bot process to push its stats.""" global latest_gurt_stats # Basic security check auth_header = request.headers.get("Authorization") # Use loaded setting - if not settings.GURT_STATS_PUSH_SECRET or not auth_header or auth_header != f"Bearer {settings.GURT_STATS_PUSH_SECRET}": + if ( + not settings.GURT_STATS_PUSH_SECRET + or not auth_header + or auth_header != f"Bearer {settings.GURT_STATS_PUSH_SECRET}" + ): print("Unauthorized attempt to update Gurt stats.") raise HTTPException(status_code=403, detail="Forbidden") @@ -3231,41 +4045,69 @@ async def update_gurt_stats_internal(request: Request): print(f"Error processing Gurt stats update: {e}") raise HTTPException(status_code=500, detail="Error processing stats update") + # --- Public Endpoint to Get Stats --- -@discordapi_app.get("/gurt/stats") # Add to the deprecated path for now -@api_app.get("/gurt/stats") # Add to the new path as well +@discordapi_app.get("/gurt/stats") # Add to the deprecated path for now +@api_app.get("/gurt/stats") # Add to the new path as well async def get_gurt_stats_public(): """Get latest internal statistics received from the Gurt bot.""" if latest_gurt_stats is None: - raise HTTPException(status_code=503, detail="Gurt stats not available yet. Please wait for the Gurt bot to send an update.") + raise HTTPException( + status_code=503, + detail="Gurt stats not available yet. Please wait for the Gurt bot to send an update.", + ) return latest_gurt_stats + # --- Gurt Dashboard Static Files & Route --- -dashboard_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'discordbot', 'gurt_dashboard')) +dashboard_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "discordbot", "gurt_dashboard") +) if os.path.exists(dashboard_dir) and os.path.isdir(dashboard_dir): # Mount static files (use a unique name like 'gurt_dashboard_static') # Mount on both /api and /discordapi for consistency during transition - discordapi_app.mount("/gurt/static", StaticFiles(directory=dashboard_dir), name="gurt_dashboard_static_discord") - api_app.mount("/gurt/static", StaticFiles(directory=dashboard_dir), name="gurt_dashboard_static_api") + discordapi_app.mount( + "/gurt/static", + StaticFiles(directory=dashboard_dir), + name="gurt_dashboard_static_discord", + ) + api_app.mount( + "/gurt/static", + StaticFiles(directory=dashboard_dir), + name="gurt_dashboard_static_api", + ) print(f"Mounted Gurt dashboard static files from: {dashboard_dir}") # Route for the main dashboard HTML - @discordapi_app.get("/gurt/dashboard", response_class=FileResponse) # Add to deprecated path - @api_app.get("/gurt/dashboard", response_class=FileResponse) # Add to new path + @discordapi_app.get( + "/gurt/dashboard", response_class=FileResponse + ) # Add to deprecated path + @api_app.get("/gurt/dashboard", response_class=FileResponse) # Add to new path async def get_gurt_dashboard_combined(): dashboard_html_path = os.path.join(dashboard_dir, "index.html") if os.path.exists(dashboard_html_path): return dashboard_html_path else: - raise HTTPException(status_code=404, detail="Dashboard index.html not found") + raise HTTPException( + status_code=404, detail="Dashboard index.html not found" + ) + else: - print(f"Warning: Gurt dashboard directory '{dashboard_dir}' not found. Dashboard endpoints will not be available.") + print( + f"Warning: Gurt dashboard directory '{dashboard_dir}' not found. Dashboard endpoints will not be available." + ) # --- New Bot Settings Dashboard Static Files & Route --- -new_dashboard_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'dashboard_web')) +new_dashboard_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), "dashboard_web") +) if os.path.exists(new_dashboard_dir) and os.path.isdir(new_dashboard_dir): # Mount static files at /dashboard/static (or just /dashboard and rely on html=True) - app.mount("/dashboard", StaticFiles(directory=new_dashboard_dir, html=True), name="bot_dashboard_static") + app.mount( + "/dashboard", + StaticFiles(directory=new_dashboard_dir, html=True), + name="bot_dashboard_static", + ) print(f"Mounted Bot Settings dashboard static files from: {new_dashboard_dir}") # Optional: Explicit route for index.html if needed, but html=True should handle it for "/" @@ -3277,15 +4119,23 @@ if os.path.exists(new_dashboard_dir) and os.path.isdir(new_dashboard_dir): # else: # raise HTTPException(status_code=404, detail="Dashboard index.html not found") else: - print(f"Warning: Bot Settings dashboard directory '{new_dashboard_dir}' not found. Dashboard will not be available.") + print( + f"Warning: Bot Settings dashboard directory '{new_dashboard_dir}' not found. Dashboard will not be available." + ) # ============= Run the server ============= if __name__ == "__main__": import uvicorn + # Use settings loaded by Pydantic - ssl_available_main = settings.SSL_CERT_FILE and settings.SSL_KEY_FILE and os.path.exists(settings.SSL_CERT_FILE) and os.path.exists(settings.SSL_KEY_FILE) + ssl_available_main = ( + settings.SSL_CERT_FILE + and settings.SSL_KEY_FILE + and os.path.exists(settings.SSL_CERT_FILE) + and os.path.exists(settings.SSL_KEY_FILE) + ) uvicorn.run( "api_server:app", diff --git a/api_service/code_verifier_store.py b/api_service/code_verifier_store.py index f5ffab9..c913c1d 100644 --- a/api_service/code_verifier_store.py +++ b/api_service/code_verifier_store.py @@ -20,47 +20,51 @@ STORAGE_FILE = os.path.join(STORAGE_DIR, "code_verifiers.json") # Ensure the storage directory exists os.makedirs(STORAGE_DIR, exist_ok=True) + def _load_from_file() -> None: """Load code verifiers from file.""" try: if os.path.exists(STORAGE_FILE): - with open(STORAGE_FILE, 'r') as f: + with open(STORAGE_FILE, "r") as f: stored_data = json.load(f) # Filter out expired entries (older than 10 minutes) current_time = time.time() for state, data in stored_data.items(): - if data.get("timestamp", 0) + 600 > current_time: # 10 minutes = 600 seconds + if ( + data.get("timestamp", 0) + 600 > current_time + ): # 10 minutes = 600 seconds code_verifiers[state] = data print(f"Loaded {len(code_verifiers)} valid code verifiers from file") except Exception as e: print(f"Error loading code verifiers from file: {e}") + def _save_to_file() -> None: """Save code verifiers to file.""" try: - with open(STORAGE_FILE, 'w') as f: + with open(STORAGE_FILE, "w") as f: json.dump(code_verifiers, f) print(f"Saved {len(code_verifiers)} code verifiers to file") except Exception as e: print(f"Error saving code verifiers to file: {e}") + # Load existing code verifiers on module import _load_from_file() + def store_code_verifier(state: str, code_verifier: str) -> None: """Store a code verifier for a state.""" # Store with timestamp for expiration - code_verifiers[state] = { - "code_verifier": code_verifier, - "timestamp": time.time() - } + code_verifiers[state] = {"code_verifier": code_verifier, "timestamp": time.time()} print(f"Stored code verifier for state {state}: {code_verifier[:10]}...") # Save to file for persistence _save_to_file() + def get_code_verifier(state: str) -> Optional[str]: """Get the code verifier for a state.""" # Check if state exists and is not expired @@ -75,6 +79,7 @@ def get_code_verifier(state: str) -> Optional[str]: print(f"Code verifier for state {state} has expired") return None + def remove_code_verifier(state: str) -> None: """Remove a code verifier for a state.""" if state in code_verifiers: @@ -83,6 +88,7 @@ def remove_code_verifier(state: str) -> None: # Update the file _save_to_file() + def cleanup_expired() -> None: """Remove all expired code verifiers.""" current_time = time.time() diff --git a/api_service/cog_management_endpoints.py b/api_service/cog_management_endpoints.py index dc7084d..f4ffd29 100644 --- a/api_service/cog_management_endpoints.py +++ b/api_service/cog_management_endpoints.py @@ -18,71 +18,94 @@ import settings_manager log = logging.getLogger(__name__) # Import models from the new dashboard_models module (use absolute path) -from api_service.dashboard_models import CogInfo # Import necessary models +from api_service.dashboard_models import CogInfo # Import necessary models # Create a router for the cog management API endpoints router = APIRouter(tags=["Cog Management"]) + # --- Endpoints --- # Models CogInfo and CommandInfo are now imported from dashboard_models.py @router.get("/guilds/{guild_id}/cogs", response_model=List[CogInfo]) async def get_guild_cogs( guild_id: int, _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Get all cogs and their commands for a guild.""" try: # Check if bot instance is available via discord_bot_sync_api try: import discord_bot_sync_api + bot = discord_bot_sync_api.bot_instance if not bot: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Bot instance not available" + detail="Bot instance not available", ) except ImportError: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Bot sync API not available" + detail="Bot sync API not available", ) # Get all cogs from the bot cogs_list = [] for cog_name, cog in bot.cogs.items(): # Get enabled status from settings_manager - is_enabled = await settings_manager.is_cog_enabled(guild_id, cog_name, default_enabled=True) + is_enabled = await settings_manager.is_cog_enabled( + guild_id, cog_name, default_enabled=True + ) # Get commands for this cog commands_list = [] for command in cog.get_commands(): # Get command enabled status - cmd_enabled = await settings_manager.is_command_enabled(guild_id, command.qualified_name, default_enabled=True) - commands_list.append({ - "name": command.qualified_name, - "description": command.help or "No description available", - "enabled": cmd_enabled - }) + cmd_enabled = await settings_manager.is_command_enabled( + guild_id, command.qualified_name, default_enabled=True + ) + commands_list.append( + { + "name": command.qualified_name, + "description": command.help or "No description available", + "enabled": cmd_enabled, + } + ) # Add slash commands if any - app_commands = [cmd for cmd in bot.tree.get_commands() if hasattr(cmd, 'cog') and cmd.cog and cmd.cog.qualified_name == cog_name] + app_commands = [ + cmd + for cmd in bot.tree.get_commands() + if hasattr(cmd, "cog") + and cmd.cog + and cmd.cog.qualified_name == cog_name + ] for cmd in app_commands: # Get command enabled status - cmd_enabled = await settings_manager.is_command_enabled(guild_id, cmd.name, default_enabled=True) - if not any(c["name"] == cmd.name for c in commands_list): # Avoid duplicates - commands_list.append({ - "name": cmd.name, - "description": cmd.description or "No description available", - "enabled": cmd_enabled - }) + cmd_enabled = await settings_manager.is_command_enabled( + guild_id, cmd.name, default_enabled=True + ) + if not any( + c["name"] == cmd.name for c in commands_list + ): # Avoid duplicates + commands_list.append( + { + "name": cmd.name, + "description": cmd.description + or "No description available", + "enabled": cmd_enabled, + } + ) - cogs_list.append(CogInfo( - name=cog_name, - description=cog.__doc__ or "No description available", - enabled=is_enabled, - commands=commands_list - )) + cogs_list.append( + CogInfo( + name=cog_name, + description=cog.__doc__ or "No description available", + enabled=is_enabled, + commands=commands_list, + ) + ) return cogs_list except HTTPException: @@ -92,49 +115,52 @@ async def get_guild_cogs( log.error(f"Error getting cogs for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error getting cogs: {str(e)}" + detail=f"Error getting cogs: {str(e)}", ) + @router.patch("/guilds/{guild_id}/cogs/{cog_name}", status_code=status.HTTP_200_OK) async def update_cog_status( guild_id: int, cog_name: str, enabled: bool = Body(..., embed=True), _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Enable or disable a cog for a guild.""" try: # Check if settings_manager is available from global_bot_accessor import get_bot_instance + bot = get_bot_instance() if not settings_manager or not bot or not bot.pg_pool: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Settings manager or database connection not available" + detail="Settings manager or database connection not available", ) # Check if the cog exists try: import discord_bot_sync_api + bot = discord_bot_sync_api.bot_instance if not bot: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Bot instance not available" + detail="Bot instance not available", ) if cog_name not in bot.cogs: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Cog '{cog_name}' not found" + detail=f"Cog '{cog_name}' not found", ) # Check if it's a core cog if cog_name in bot.core_cogs: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Core cog '{cog_name}' cannot be disabled" + detail=f"Core cog '{cog_name}' cannot be disabled", ) except ImportError: # If we can't import the bot, we'll just assume the cog exists @@ -145,10 +171,12 @@ async def update_cog_status( if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update cog '{cog_name}' status" + detail=f"Failed to update cog '{cog_name}' status", ) - return {"message": f"Cog '{cog_name}' {'enabled' if enabled else 'disabled'} successfully"} + return { + "message": f"Cog '{cog_name}' {'enabled' if enabled else 'disabled'} successfully" + } except HTTPException: # Re-raise HTTP exceptions raise @@ -156,61 +184,72 @@ async def update_cog_status( log.error(f"Error updating cog status for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error updating cog status: {str(e)}" + detail=f"Error updating cog status: {str(e)}", ) -@router.patch("/guilds/{guild_id}/commands/{command_name}", status_code=status.HTTP_200_OK) + +@router.patch( + "/guilds/{guild_id}/commands/{command_name}", status_code=status.HTTP_200_OK +) async def update_command_status( guild_id: int, command_name: str, enabled: bool = Body(..., embed=True), _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Enable or disable a command for a guild.""" try: # Check if settings_manager is available from global_bot_accessor import get_bot_instance + bot = get_bot_instance() if not settings_manager or not bot or not bot.pg_pool: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Settings manager or database connection not available" + detail="Settings manager or database connection not available", ) # Check if the command exists try: import discord_bot_sync_api + bot = discord_bot_sync_api.bot_instance if not bot: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Bot instance not available" + detail="Bot instance not available", ) # Check if it's a prefix command command = bot.get_command(command_name) if not command: # Check if it's an app command - app_commands = [cmd for cmd in bot.tree.get_commands() if cmd.name == command_name] + app_commands = [ + cmd for cmd in bot.tree.get_commands() if cmd.name == command_name + ] if not app_commands: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Command '{command_name}' not found" + detail=f"Command '{command_name}' not found", ) except ImportError: # If we can't import the bot, we'll just assume the command exists log.warning("Bot sync API not available, skipping command existence check") # Update the command enabled status - success = await settings_manager.set_command_enabled(guild_id, command_name, enabled) + success = await settings_manager.set_command_enabled( + guild_id, command_name, enabled + ) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update command '{command_name}' status" + detail=f"Failed to update command '{command_name}' status", ) - return {"message": f"Command '{command_name}' {'enabled' if enabled else 'disabled'} successfully"} + return { + "message": f"Command '{command_name}' {'enabled' if enabled else 'disabled'} successfully" + } except HTTPException: # Re-raise HTTP exceptions raise @@ -218,5 +257,5 @@ async def update_command_status( log.error(f"Error updating command status for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error updating command status: {str(e)}" + detail=f"Error updating command status: {str(e)}", ) diff --git a/api_service/command_customization_endpoints.py b/api_service/command_customization_endpoints.py index 84c3ad2..a46d244 100644 --- a/api_service/command_customization_endpoints.py +++ b/api_service/command_customization_endpoints.py @@ -15,10 +15,10 @@ from api_service.dependencies import get_dashboard_user, verify_dashboard_guild_ from api_service.dashboard_models import ( CommandCustomizationResponse, CommandCustomizationUpdate, - GroupCustomizationUpdate, + GroupCustomizationUpdate, GroupCustomizationUpdate, CommandAliasAdd, - CommandAliasRemove + CommandAliasRemove, ) # Import settings_manager for database access (use absolute path) @@ -32,11 +32,12 @@ router = APIRouter() # --- Command Customization Endpoints --- + @router.get("/customizations/{guild_id}", response_model=CommandCustomizationResponse) async def get_command_customizations( guild_id: int, _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Get all command customizations for a guild.""" try: @@ -44,7 +45,7 @@ async def get_command_customizations( if not settings_manager: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Settings manager not available" + detail="Settings manager not available", ) # Get the bot instance to check if pools are available @@ -52,23 +53,27 @@ async def get_command_customizations( if not bot_instance or not bot_instance.pg_pool: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Database connection not available" + detail="Database connection not available", ) # Get command customizations - command_customizations = await settings_manager.get_all_command_customizations(guild_id) + command_customizations = await settings_manager.get_all_command_customizations( + guild_id + ) if command_customizations is None: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to get command customizations" + detail="Failed to get command customizations", ) # Get group customizations - group_customizations = await settings_manager.get_all_group_customizations(guild_id) + group_customizations = await settings_manager.get_all_group_customizations( + guild_id + ) if group_customizations is None: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to get group customizations" + detail="Failed to get group customizations", ) # Get command aliases @@ -76,21 +81,21 @@ async def get_command_customizations( if command_aliases is None: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to get command aliases" + detail="Failed to get command aliases", ) # Convert command_customizations to the new format formatted_command_customizations = {} for cmd_name, cmd_data in command_customizations.items(): formatted_command_customizations[cmd_name] = { - 'name': cmd_data.get('name', cmd_name), - 'description': cmd_data.get('description') + "name": cmd_data.get("name", cmd_name), + "description": cmd_data.get("description"), } return CommandCustomizationResponse( command_customizations=formatted_command_customizations, group_customizations=group_customizations, - command_aliases=command_aliases + command_aliases=command_aliases, ) except HTTPException: # Re-raise HTTP exceptions @@ -99,15 +104,16 @@ async def get_command_customizations( log.error(f"Error getting command customizations for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error getting command customizations: {str(e)}" + detail=f"Error getting command customizations: {str(e)}", ) + @router.post("/customizations/{guild_id}/commands", status_code=status.HTTP_200_OK) async def set_command_customization( guild_id: int, customization: CommandCustomizationUpdate, _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Set a custom name and/or description for a command in a guild.""" try: @@ -115,7 +121,7 @@ async def set_command_customization( if not settings_manager: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Settings manager not available" + detail="Settings manager not available", ) # Get the bot instance to check if pools are available @@ -123,21 +129,27 @@ async def set_command_customization( if not bot_instance or not bot_instance.pg_pool: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Database connection not available" + detail="Database connection not available", ) # Validate custom name format if provided if customization.custom_name is not None: - if not customization.custom_name.islower() or not customization.custom_name.replace('_', '').isalnum(): + if ( + not customization.custom_name.islower() + or not customization.custom_name.replace("_", "").isalnum() + ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Custom command names must be lowercase and contain only letters, numbers, and underscores" + detail="Custom command names must be lowercase and contain only letters, numbers, and underscores", ) - if len(customization.custom_name) < 1 or len(customization.custom_name) > 32: + if ( + len(customization.custom_name) < 1 + or len(customization.custom_name) > 32 + ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Custom command names must be between 1 and 32 characters long" + detail="Custom command names must be between 1 and 32 characters long", ) # Validate custom description if provided @@ -145,34 +157,30 @@ async def set_command_customization( if len(customization.custom_description) > 100: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Custom command descriptions must be 100 characters or less" + detail="Custom command descriptions must be 100 characters or less", ) # Set the custom command name name_success = await settings_manager.set_custom_command_name( - guild_id, - customization.command_name, - customization.custom_name + guild_id, customization.command_name, customization.custom_name ) if not name_success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to set custom command name" + detail="Failed to set custom command name", ) # Set the custom command description if provided if customization.custom_description is not None: desc_success = await settings_manager.set_custom_command_description( - guild_id, - customization.command_name, - customization.custom_description + guild_id, customization.command_name, customization.custom_description ) if not desc_success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to set custom command description" + detail="Failed to set custom command description", ) return {"message": "Command customization updated successfully"} @@ -183,15 +191,16 @@ async def set_command_customization( log.error(f"Error setting command customization for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error setting command customization: {str(e)}" + detail=f"Error setting command customization: {str(e)}", ) + @router.post("/customizations/{guild_id}/groups", status_code=status.HTTP_200_OK) async def set_group_customization( guild_id: int, customization: GroupCustomizationUpdate, _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Set a custom name for a command group in a guild.""" try: @@ -199,7 +208,7 @@ async def set_group_customization( if not settings_manager: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Settings manager not available" + detail="Settings manager not available", ) # Get the bot instance to check if pools are available @@ -207,34 +216,38 @@ async def set_group_customization( if not bot_instance or not bot_instance.pg_pool: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Database connection not available" + detail="Database connection not available", ) # Validate custom name format if provided if customization.custom_name is not None: - if not customization.custom_name.islower() or not customization.custom_name.replace('_', '').isalnum(): + if ( + not customization.custom_name.islower() + or not customization.custom_name.replace("_", "").isalnum() + ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Custom group names must be lowercase and contain only letters, numbers, and underscores" + detail="Custom group names must be lowercase and contain only letters, numbers, and underscores", ) - if len(customization.custom_name) < 1 or len(customization.custom_name) > 32: + if ( + len(customization.custom_name) < 1 + or len(customization.custom_name) > 32 + ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Custom group names must be between 1 and 32 characters long" + detail="Custom group names must be between 1 and 32 characters long", ) # Set the custom group name success = await settings_manager.set_custom_group_name( - guild_id, - customization.group_name, - customization.custom_name + guild_id, customization.group_name, customization.custom_name ) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to set custom group name" + detail="Failed to set custom group name", ) return {"message": "Group customization updated successfully"} @@ -245,15 +258,16 @@ async def set_group_customization( log.error(f"Error setting group customization for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error setting group customization: {str(e)}" + detail=f"Error setting group customization: {str(e)}", ) + @router.post("/customizations/{guild_id}/aliases", status_code=status.HTTP_200_OK) async def add_command_alias( guild_id: int, alias: CommandAliasAdd, _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Add an alias for a command in a guild.""" try: @@ -261,7 +275,7 @@ async def add_command_alias( if not settings_manager: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Settings manager not available" + detail="Settings manager not available", ) # Get the bot instance to check if pools are available @@ -269,33 +283,34 @@ async def add_command_alias( if not bot_instance or not bot_instance.pg_pool: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Database connection not available" + detail="Database connection not available", ) # Validate alias format - if not alias.alias_name.islower() or not alias.alias_name.replace('_', '').isalnum(): + if ( + not alias.alias_name.islower() + or not alias.alias_name.replace("_", "").isalnum() + ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Aliases must be lowercase and contain only letters, numbers, and underscores" + detail="Aliases must be lowercase and contain only letters, numbers, and underscores", ) if len(alias.alias_name) < 1 or len(alias.alias_name) > 32: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Aliases must be between 1 and 32 characters long" + detail="Aliases must be between 1 and 32 characters long", ) # Add the command alias success = await settings_manager.add_command_alias( - guild_id, - alias.command_name, - alias.alias_name + guild_id, alias.command_name, alias.alias_name ) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to add command alias" + detail="Failed to add command alias", ) return {"message": "Command alias added successfully"} @@ -306,15 +321,16 @@ async def add_command_alias( log.error(f"Error adding command alias for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error adding command alias: {str(e)}" + detail=f"Error adding command alias: {str(e)}", ) + @router.delete("/customizations/{guild_id}/aliases", status_code=status.HTTP_200_OK) async def remove_command_alias( guild_id: int, alias: CommandAliasRemove, _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Remove an alias for a command in a guild.""" try: @@ -322,7 +338,7 @@ async def remove_command_alias( if not settings_manager: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Settings manager not available" + detail="Settings manager not available", ) # Get the bot instance to check if pools are available @@ -330,20 +346,18 @@ async def remove_command_alias( if not bot_instance or not bot_instance.pg_pool: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Database connection not available" + detail="Database connection not available", ) # Remove the command alias success = await settings_manager.remove_command_alias( - guild_id, - alias.command_name, - alias.alias_name + guild_id, alias.command_name, alias.alias_name ) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to remove command alias" + detail="Failed to remove command alias", ) return {"message": "Command alias removed successfully"} @@ -354,24 +368,27 @@ async def remove_command_alias( log.error(f"Error removing command alias for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error removing command alias: {str(e)}" + detail=f"Error removing command alias: {str(e)}", ) + @router.post("/customizations/{guild_id}/sync", status_code=status.HTTP_200_OK) async def sync_guild_commands( guild_id: int, _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Sync commands for a guild to apply customizations.""" try: # This endpoint would trigger a command sync for the guild # In a real implementation, this would communicate with the bot to sync commands # For now, we'll just return a success message - return {"message": "Command sync requested. This may take a moment to complete."} + return { + "message": "Command sync requested. This may take a moment to complete." + } except Exception as e: log.error(f"Error syncing commands for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error syncing commands: {str(e)}" + detail=f"Error syncing commands: {str(e)}", ) diff --git a/api_service/dashboard_api_endpoints.py b/api_service/dashboard_api_endpoints.py index 4fa5cc6..3f67202 100644 --- a/api_service/dashboard_api_endpoints.py +++ b/api_service/dashboard_api_endpoints.py @@ -33,13 +33,16 @@ import settings_manager # Import custom bot manager import sys import os -parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) if parent_dir not in sys.path: sys.path.append(parent_dir) try: import custom_bot_manager except ImportError: - print("Warning: Could not import custom_bot_manager. Custom bot functionality will be disabled.") + print( + "Warning: Could not import custom_bot_manager. Custom bot functionality will be disabled." + ) custom_bot_manager = None # Set up logging @@ -48,12 +51,14 @@ log = logging.getLogger(__name__) # Create a router for the dashboard API endpoints router = APIRouter(tags=["Dashboard API"]) + # --- Models --- class Channel(BaseModel): id: str name: str type: int # 0 = text, 2 = voice, etc. + class Role(BaseModel): id: str name: str @@ -61,10 +66,12 @@ class Role(BaseModel): position: int permissions: str + class Command(BaseModel): name: str description: Optional[str] = None + class Conversation(BaseModel): id: str title: str @@ -72,12 +79,14 @@ class Conversation(BaseModel): updated_at: str message_count: int + class Message(BaseModel): id: str content: str role: str # 'user' or 'assistant' created_at: str + class ThemeSettings(BaseModel): theme_mode: str = "light" # "light", "dark", "custom" primary_color: str = "#5865F2" # Discord blue @@ -86,6 +95,7 @@ class ThemeSettings(BaseModel): font_family: str = "Inter, sans-serif" custom_css: Optional[str] = None + class GlobalSettings(BaseModel): system_message: Optional[str] = None character: Optional[str] = None @@ -103,25 +113,26 @@ class GlobalSettings(BaseModel): custom_bot_status_text: Optional[str] = None custom_bot_status_type: Optional[str] = None -# CogInfo and CommandInfo models are now imported from dashboard_models + # CogInfo and CommandInfo models are now imported from dashboard_models -# class CommandInfo(BaseModel): # Removed - Imported from dashboard_models -# name: str -# description: Optional[str] = None + # class CommandInfo(BaseModel): # Removed - Imported from dashboard_models + # name: str + # description: Optional[str] = None enabled: bool = True cog_name: Optional[str] = None + class Guild(BaseModel): id: str name: str icon_url: Optional[str] = None + # --- Endpoints --- + @router.get("/user") -async def get_dashboard_user_info( - user: dict = Depends(get_dashboard_user) -): +async def get_dashboard_user_info(user: dict = Depends(get_dashboard_user)): """Get information about the currently authenticated user.""" try: # Return user information without sensitive data @@ -130,20 +141,19 @@ async def get_dashboard_user_info( "username": user.get("username"), "discriminator": user.get("discriminator"), "avatar": user.get("avatar"), - "email": user.get("email") + "email": user.get("email"), } return user_info except Exception as e: log.error(f"Error getting user info: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error getting user info: {str(e)}" + detail=f"Error getting user info: {str(e)}", ) + @router.get("/user-guilds", response_model=List[Guild]) -async def get_user_guilds( - user: dict = Depends(get_dashboard_user) -): +async def get_user_guilds(user: dict = Depends(get_dashboard_user)): """Get all guilds the user is an admin of.""" try: # First, try to use the real implementation from api_server.py @@ -159,26 +169,30 @@ async def get_user_guilds( for guild in guilds_data: # Create icon URL if icon is available icon_url = None - if guild.get('icon'): + if guild.get("icon"): icon_url = f"https://cdn.discordapp.com/icons/{guild['id']}/{guild['icon']}.png" - guilds.append(Guild( - id=guild['id'], - name=guild['name'], - icon_url=icon_url - )) + guilds.append( + Guild(id=guild["id"], name=guild["name"], icon_url=icon_url) + ) - log.info(f"Successfully fetched {len(guilds)} guilds for user {user.get('user_id')} using api_server implementation") + log.info( + f"Successfully fetched {len(guilds)} guilds for user {user.get('user_id')} using api_server implementation" + ) return guilds except ImportError as e: - log.warning(f"Could not import dashboard_get_user_guilds from api_server: {e}") + log.warning( + f"Could not import dashboard_get_user_guilds from api_server: {e}" + ) # Fall back to direct implementation except Exception as e: log.warning(f"Error using dashboard_get_user_guilds from api_server: {e}") # Check if we got an empty list back (which is a valid response now) if isinstance(guilds_data, list) and len(guilds_data) == 0: - log.info("Received empty guild list from api_server, returning empty list") + log.info( + "Received empty guild list from api_server, returning empty list" + ) return [] # Otherwise fall back to direct implementation @@ -189,12 +203,16 @@ async def get_user_guilds( # Check if settings_manager is available bot = get_bot_instance() if not bot or not bot.pg_pool: - log.warning("Bot instance or PostgreSQL pool not available for get_user_guilds") + log.warning( + "Bot instance or PostgreSQL pool not available for get_user_guilds" + ) # Fall back to mock data since we can't access the real data log.info("Using mock data for user guilds as fallback") return [ Guild(id="123456789", name="My Awesome Server (Mock)", icon_url=None), - Guild(id="987654321", name="Another Great Server (Mock)", icon_url=None) + Guild( + id="987654321", name="Another Great Server (Mock)", icon_url=None + ), ] # Get Discord API URLs from environment or use defaults @@ -202,40 +220,64 @@ async def get_user_guilds( DISCORD_USER_GUILDS_URL = f"{DISCORD_API_URL}/users/@me/guilds" # Get access token from user - access_token = user.get('access_token') + access_token = user.get("access_token") if not access_token: log.warning("Access token not found in user session, returning mock data") # Return mock data instead of raising an exception return [ Guild(id="123456789", name="My Awesome Server (Mock)", icon_url=None), - Guild(id="987654321", name="Another Great Server (Mock)", icon_url=None) + Guild( + id="987654321", name="Another Great Server (Mock)", icon_url=None + ), ] # Create headers for Discord API request - user_headers = {'Authorization': f'Bearer {access_token}'} + user_headers = {"Authorization": f"Bearer {access_token}"} # Create a temporary aiohttp session async with aiohttp.ClientSession() as session: # 1. Fetch guilds user is in from Discord try: log.debug(f"Fetching user guilds from {DISCORD_USER_GUILDS_URL}") - async with session.get(DISCORD_USER_GUILDS_URL, headers=user_headers) as resp: + async with session.get( + DISCORD_USER_GUILDS_URL, headers=user_headers + ) as resp: if resp.status == 401: - log.warning("Discord API authentication failed (401). Returning mock data.") + log.warning( + "Discord API authentication failed (401). Returning mock data." + ) # Return mock data instead of raising an exception return [ - Guild(id="123456789", name="My Awesome Server (Mock)", icon_url=None), - Guild(id="987654321", name="Another Great Server (Mock)", icon_url=None) + Guild( + id="123456789", + name="My Awesome Server (Mock)", + icon_url=None, + ), + Guild( + id="987654321", + name="Another Great Server (Mock)", + icon_url=None, + ), ] resp.raise_for_status() user_guilds = await resp.json() - log.debug(f"Fetched {len(user_guilds)} guilds for user {user.get('user_id')}") + log.debug( + f"Fetched {len(user_guilds)} guilds for user {user.get('user_id')}" + ) except Exception as e: - log.warning(f"Error fetching user guilds from Discord API: {e}. Returning mock data.") + log.warning( + f"Error fetching user guilds from Discord API: {e}. Returning mock data." + ) # Return mock data on any Discord API error return [ - Guild(id="123456789", name="My Awesome Server (Mock)", icon_url=None), - Guild(id="987654321", name="Another Great Server (Mock)", icon_url=None) + Guild( + id="123456789", name="My Awesome Server (Mock)", icon_url=None + ), + Guild( + id="987654321", + name="Another Great Server (Mock)", + icon_url=None, + ), ] # 2. Get bot guilds from the bot instance @@ -248,33 +290,39 @@ async def get_user_guilds( try: async with bot.pg_pool.acquire() as conn: records = await conn.fetch("SELECT guild_id FROM guilds") - bot_guild_ids = {record['guild_id'] for record in records} - log.debug(f"Fetched {len(bot_guild_ids)} guild IDs from database") + bot_guild_ids = {record["guild_id"] for record in records} + log.debug( + f"Fetched {len(bot_guild_ids)} guild IDs from database" + ) except Exception as e: log.warning(f"Error fetching bot guild IDs from database: {e}") # Instead of raising an exception, continue with an empty set - log.info("Using empty guild set as fallback to allow dashboard to function") + log.info( + "Using empty guild set as fallback to allow dashboard to function" + ) # 3. Filter user guilds manageable_guilds = [] ADMINISTRATOR_PERMISSION = 0x8 for guild in user_guilds: - guild_id = int(guild['id']) - permissions = int(guild['permissions']) + guild_id = int(guild["id"]) + permissions = int(guild["permissions"]) - if (permissions & ADMINISTRATOR_PERMISSION) == ADMINISTRATOR_PERMISSION and guild_id in bot_guild_ids: + if ( + permissions & ADMINISTRATOR_PERMISSION + ) == ADMINISTRATOR_PERMISSION and guild_id in bot_guild_ids: # Create icon URL if icon is available icon_url = None - if guild.get('icon'): + if guild.get("icon"): icon_url = f"https://cdn.discordapp.com/icons/{guild['id']}/{guild['icon']}.png" - manageable_guilds.append(Guild( - id=guild['id'], - name=guild['name'], - icon_url=icon_url - )) + manageable_guilds.append( + Guild(id=guild["id"], name=guild["name"], icon_url=icon_url) + ) - log.info(f"Found {len(manageable_guilds)} manageable guilds for user {user.get('user_id')}") + log.info( + f"Found {len(manageable_guilds)} manageable guilds for user {user.get('user_id')}" + ) return manageable_guilds except HTTPException: @@ -286,14 +334,17 @@ async def get_user_guilds( log.warning("Returning mock data due to error in get_user_guilds") return [ Guild(id="123456789", name="My Awesome Server (Mock)", icon_url=None), - Guild(id="987654321", name="Another Great Server (Mock)", icon_url=None) + Guild(id="987654321", name="Another Great Server (Mock)", icon_url=None), ] + @router.get("/guilds/{guild_id}/channels", response_model=List[Channel]) async def get_guild_channels( guild_id: int, - _user: dict = Depends(get_dashboard_user), # Underscore prefix to indicate unused parameter - _: bool = Depends(verify_dashboard_guild_admin) + _user: dict = Depends( + get_dashboard_user + ), # Underscore prefix to indicate unused parameter + _: bool = Depends(verify_dashboard_guild_admin), ): """Get all channels for a guild.""" try: @@ -304,21 +355,24 @@ async def get_guild_channels( Channel(id="123456789", name="general", type=0), Channel(id="123456790", name="welcome", type=0), Channel(id="123456791", name="announcements", type=0), - Channel(id="123456792", name="voice-chat", type=2) + Channel(id="123456792", name="voice-chat", type=2), ] return channels except Exception as e: log.error(f"Error getting channels for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error getting channels: {str(e)}" + detail=f"Error getting channels: {str(e)}", ) + @router.get("/guilds/{guild_id}/roles", response_model=List[Role]) async def get_guild_roles( guild_id: int, - _user: dict = Depends(get_dashboard_user), # Underscore prefix to indicate unused parameter - _admin: bool = Depends(verify_dashboard_guild_admin) + _user: dict = Depends( + get_dashboard_user + ), # Underscore prefix to indicate unused parameter + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Get all roles for a guild.""" try: @@ -326,24 +380,41 @@ async def get_guild_roles( # For now, we'll return a mock response # TODO: Replace mock data with actual API call to Discord roles = [ - Role(id="123456789", name="@everyone", color=0, position=0, permissions="0"), - Role(id="123456790", name="Admin", color=16711680, position=1, permissions="8"), - Role(id="123456791", name="Moderator", color=65280, position=2, permissions="4"), - Role(id="123456792", name="Member", color=255, position=3, permissions="1") + Role( + id="123456789", name="@everyone", color=0, position=0, permissions="0" + ), + Role( + id="123456790", + name="Admin", + color=16711680, + position=1, + permissions="8", + ), + Role( + id="123456791", + name="Moderator", + color=65280, + position=2, + permissions="4", + ), + Role(id="123456792", name="Member", color=255, position=3, permissions="1"), ] return roles except Exception as e: log.error(f"Error getting roles for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error getting roles: {str(e)}" + detail=f"Error getting roles: {str(e)}", ) + @router.get("/guilds/{guild_id}/commands", response_model=List[Command]) async def get_guild_commands( guild_id: int, - _user: dict = Depends(get_dashboard_user), # Underscore prefix to indicate unused parameter - _admin: bool = Depends(verify_dashboard_guild_admin) + _user: dict = Depends( + get_dashboard_user + ), # Underscore prefix to indicate unused parameter + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Get all commands available in the guild.""" try: @@ -361,49 +432,59 @@ async def get_guild_commands( Command(name="ai", description="Get AI response"), Command(name="aiset", description="Configure AI settings"), Command(name="chat", description="Chat with AI"), - Command(name="convs", description="Manage conversations") + Command(name="convs", description="Manage conversations"), ] return commands except Exception as e: log.error(f"Error getting commands for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error getting commands: {str(e)}" + detail=f"Error getting commands: {str(e)}", ) + # --- Command Customization Endpoints --- -@router.get("/guilds/{guild_id}/command-customizations", response_model=CommandCustomizationResponse) + +@router.get( + "/guilds/{guild_id}/command-customizations", + response_model=CommandCustomizationResponse, +) async def get_command_customizations( guild_id: int, _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Get all command customizations for a guild.""" try: # Check if settings_manager is available from global_bot_accessor import get_bot_instance + bot = get_bot_instance() if not settings_manager or not bot or not bot.pg_pool: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Settings manager or database connection not available" + detail="Settings manager or database connection not available", ) # Get command customizations - command_customizations = await settings_manager.get_all_command_customizations(guild_id) + command_customizations = await settings_manager.get_all_command_customizations( + guild_id + ) if command_customizations is None: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to get command customizations" + detail="Failed to get command customizations", ) # Get group customizations - group_customizations = await settings_manager.get_all_group_customizations(guild_id) + group_customizations = await settings_manager.get_all_group_customizations( + guild_id + ) if group_customizations is None: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to get group customizations" + detail="Failed to get group customizations", ) # Get command aliases @@ -411,13 +492,13 @@ async def get_command_customizations( if command_aliases is None: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to get command aliases" + detail="Failed to get command aliases", ) return CommandCustomizationResponse( command_customizations=command_customizations, group_customizations=group_customizations, - command_aliases=command_aliases + command_aliases=command_aliases, ) except HTTPException: # Re-raise HTTP exceptions @@ -426,52 +507,60 @@ async def get_command_customizations( log.error(f"Error getting command customizations for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error getting command customizations: {str(e)}" + detail=f"Error getting command customizations: {str(e)}", ) -@router.post("/guilds/{guild_id}/command-customizations/commands", status_code=status.HTTP_200_OK) + +@router.post( + "/guilds/{guild_id}/command-customizations/commands", status_code=status.HTTP_200_OK +) async def set_command_customization( guild_id: int, customization: CommandCustomizationUpdate, _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Set a custom name for a command in a guild.""" try: # Check if settings_manager is available from global_bot_accessor import get_bot_instance + bot = get_bot_instance() if not settings_manager or not bot or not bot.pg_pool: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Settings manager or database connection not available" + detail="Settings manager or database connection not available", ) # Validate custom name format if provided if customization.custom_name is not None: - if not customization.custom_name.islower() or not customization.custom_name.replace('_', '').isalnum(): + if ( + not customization.custom_name.islower() + or not customization.custom_name.replace("_", "").isalnum() + ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Custom command names must be lowercase and contain only letters, numbers, and underscores" + detail="Custom command names must be lowercase and contain only letters, numbers, and underscores", ) - if len(customization.custom_name) < 1 or len(customization.custom_name) > 32: + if ( + len(customization.custom_name) < 1 + or len(customization.custom_name) > 32 + ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Custom command names must be between 1 and 32 characters long" + detail="Custom command names must be between 1 and 32 characters long", ) # Set the custom command name success = await settings_manager.set_custom_command_name( - guild_id, - customization.command_name, - customization.custom_name + guild_id, customization.command_name, customization.custom_name ) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to set custom command name" + detail="Failed to set custom command name", ) return {"message": "Command customization updated successfully"} @@ -482,52 +571,60 @@ async def set_command_customization( log.error(f"Error setting command customization for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error setting command customization: {str(e)}" + detail=f"Error setting command customization: {str(e)}", ) -@router.post("/guilds/{guild_id}/command-customizations/groups", status_code=status.HTTP_200_OK) + +@router.post( + "/guilds/{guild_id}/command-customizations/groups", status_code=status.HTTP_200_OK +) async def set_group_customization( guild_id: int, customization: GroupCustomizationUpdate, _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Set a custom name for a command group in a guild.""" try: # Check if settings_manager is available from global_bot_accessor import get_bot_instance + bot = get_bot_instance() if not settings_manager or not bot or not bot.pg_pool: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Settings manager or database connection not available" + detail="Settings manager or database connection not available", ) # Validate custom name format if provided if customization.custom_name is not None: - if not customization.custom_name.islower() or not customization.custom_name.replace('_', '').isalnum(): + if ( + not customization.custom_name.islower() + or not customization.custom_name.replace("_", "").isalnum() + ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Custom group names must be lowercase and contain only letters, numbers, and underscores" + detail="Custom group names must be lowercase and contain only letters, numbers, and underscores", ) - if len(customization.custom_name) < 1 or len(customization.custom_name) > 32: + if ( + len(customization.custom_name) < 1 + or len(customization.custom_name) > 32 + ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Custom group names must be between 1 and 32 characters long" + detail="Custom group names must be between 1 and 32 characters long", ) # Set the custom group name success = await settings_manager.set_custom_group_name( - guild_id, - customization.group_name, - customization.custom_name + guild_id, customization.group_name, customization.custom_name ) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to set custom group name" + detail="Failed to set custom group name", ) return {"message": "Group customization updated successfully"} @@ -538,51 +635,56 @@ async def set_group_customization( log.error(f"Error setting group customization for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error setting group customization: {str(e)}" + detail=f"Error setting group customization: {str(e)}", ) -@router.post("/guilds/{guild_id}/command-customizations/aliases", status_code=status.HTTP_200_OK) + +@router.post( + "/guilds/{guild_id}/command-customizations/aliases", status_code=status.HTTP_200_OK +) async def add_command_alias( guild_id: int, alias: CommandAliasAdd, _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Add an alias for a command in a guild.""" try: # Check if settings_manager is available from global_bot_accessor import get_bot_instance + bot = get_bot_instance() if not settings_manager or not bot or not bot.pg_pool: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Settings manager or database connection not available" + detail="Settings manager or database connection not available", ) # Validate alias format - if not alias.alias_name.islower() or not alias.alias_name.replace('_', '').isalnum(): + if ( + not alias.alias_name.islower() + or not alias.alias_name.replace("_", "").isalnum() + ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Aliases must be lowercase and contain only letters, numbers, and underscores" + detail="Aliases must be lowercase and contain only letters, numbers, and underscores", ) if len(alias.alias_name) < 1 or len(alias.alias_name) > 32: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Aliases must be between 1 and 32 characters long" + detail="Aliases must be between 1 and 32 characters long", ) # Add the command alias success = await settings_manager.add_command_alias( - guild_id, - alias.command_name, - alias.alias_name + guild_id, alias.command_name, alias.alias_name ) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to add command alias" + detail="Failed to add command alias", ) return {"message": "Command alias added successfully"} @@ -593,38 +695,40 @@ async def add_command_alias( log.error(f"Error adding command alias for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error adding command alias: {str(e)}" + detail=f"Error adding command alias: {str(e)}", ) -@router.delete("/guilds/{guild_id}/command-customizations/aliases", status_code=status.HTTP_200_OK) + +@router.delete( + "/guilds/{guild_id}/command-customizations/aliases", status_code=status.HTTP_200_OK +) async def remove_command_alias( guild_id: int, alias: CommandAliasRemove, _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Remove an alias for a command in a guild.""" try: # Check if settings_manager is available from global_bot_accessor import get_bot_instance + bot = get_bot_instance() if not settings_manager or not bot or not bot.pg_pool: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Settings manager or database connection not available" + detail="Settings manager or database connection not available", ) # Remove the command alias success = await settings_manager.remove_command_alias( - guild_id, - alias.command_name, - alias.alias_name + guild_id, alias.command_name, alias.alias_name ) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to remove command alias" + detail="Failed to remove command alias", ) return {"message": "Command alias removed successfully"} @@ -635,14 +739,15 @@ async def remove_command_alias( log.error(f"Error removing command alias for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error removing command alias: {str(e)}" + detail=f"Error removing command alias: {str(e)}", ) + @router.get("/guilds/{guild_id}/settings", response_model=Dict[str, Any]) async def get_guild_settings( guild_id: int, _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Get settings for a guild.""" try: @@ -650,13 +755,18 @@ async def get_guild_settings( if not settings_manager: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Settings manager not available" + detail="Settings manager not available", ) # Try to get the API server's pool from FastAPI app try: from api_service.api_server import app - has_api_pool = hasattr(app, 'state') and hasattr(app.state, 'pg_pool') and app.state.pg_pool + + has_api_pool = ( + hasattr(app, "state") + and hasattr(app.state, "pg_pool") + and app.state.pg_pool + ) log.info(f"API server pool available: {has_api_pool}") except (ImportError, AttributeError): has_api_pool = False @@ -664,8 +774,9 @@ async def get_guild_settings( # Check bot pool as fallback from global_bot_accessor import get_bot_instance + bot = get_bot_instance() - has_bot_pool = bot and hasattr(bot, 'pg_pool') and bot.pg_pool + has_bot_pool = bot and hasattr(bot, "pg_pool") and bot.pg_pool log.info(f"Bot pool available: {has_bot_pool}") if not has_api_pool and not has_bot_pool: @@ -680,61 +791,91 @@ async def get_guild_settings( "goodbye_channel_id": None, "goodbye_message": None, "cogs": {}, - "commands": {} + "commands": {}, } # Get prefix with error handling try: - settings["prefix"] = await settings_manager.get_guild_prefix(guild_id, DEFAULT_PREFIX) + settings["prefix"] = await settings_manager.get_guild_prefix( + guild_id, DEFAULT_PREFIX + ) except RuntimeError as e: if "got Future" in str(e) and "attached to a different loop" in str(e): - log.warning(f"Event loop error getting prefix for guild {guild_id}: {e}") + log.warning( + f"Event loop error getting prefix for guild {guild_id}: {e}" + ) # Keep default prefix else: log.warning(f"Runtime error getting prefix for guild {guild_id}: {e}") # Keep default prefix except Exception as e: - log.warning(f"Error getting prefix for guild {guild_id}, using default: {e}") + log.warning( + f"Error getting prefix for guild {guild_id}, using default: {e}" + ) # Keep default prefix # Get welcome/goodbye settings with error handling try: - settings["welcome_channel_id"] = await settings_manager.get_setting(guild_id, 'welcome_channel_id') + settings["welcome_channel_id"] = await settings_manager.get_setting( + guild_id, "welcome_channel_id" + ) except RuntimeError as e: if "got Future" in str(e) and "attached to a different loop" in str(e): - log.warning(f"Event loop error getting welcome_channel_id for guild {guild_id}: {e}") + log.warning( + f"Event loop error getting welcome_channel_id for guild {guild_id}: {e}" + ) else: - log.warning(f"Runtime error getting welcome_channel_id for guild {guild_id}: {e}") + log.warning( + f"Runtime error getting welcome_channel_id for guild {guild_id}: {e}" + ) except Exception as e: log.warning(f"Error getting welcome_channel_id for guild {guild_id}: {e}") try: - settings["welcome_message"] = await settings_manager.get_setting(guild_id, 'welcome_message') + settings["welcome_message"] = await settings_manager.get_setting( + guild_id, "welcome_message" + ) except RuntimeError as e: if "got Future" in str(e) and "attached to a different loop" in str(e): - log.warning(f"Event loop error getting welcome_message for guild {guild_id}: {e}") + log.warning( + f"Event loop error getting welcome_message for guild {guild_id}: {e}" + ) else: - log.warning(f"Runtime error getting welcome_message for guild {guild_id}: {e}") + log.warning( + f"Runtime error getting welcome_message for guild {guild_id}: {e}" + ) except Exception as e: log.warning(f"Error getting welcome_message for guild {guild_id}: {e}") try: - settings["goodbye_channel_id"] = await settings_manager.get_setting(guild_id, 'goodbye_channel_id') + settings["goodbye_channel_id"] = await settings_manager.get_setting( + guild_id, "goodbye_channel_id" + ) except RuntimeError as e: if "got Future" in str(e) and "attached to a different loop" in str(e): - log.warning(f"Event loop error getting goodbye_channel_id for guild {guild_id}: {e}") + log.warning( + f"Event loop error getting goodbye_channel_id for guild {guild_id}: {e}" + ) else: - log.warning(f"Runtime error getting goodbye_channel_id for guild {guild_id}: {e}") + log.warning( + f"Runtime error getting goodbye_channel_id for guild {guild_id}: {e}" + ) except Exception as e: log.warning(f"Error getting goodbye_channel_id for guild {guild_id}: {e}") try: - settings["goodbye_message"] = await settings_manager.get_setting(guild_id, 'goodbye_message') + settings["goodbye_message"] = await settings_manager.get_setting( + guild_id, "goodbye_message" + ) except RuntimeError as e: if "got Future" in str(e) and "attached to a different loop" in str(e): - log.warning(f"Event loop error getting goodbye_message for guild {guild_id}: {e}") + log.warning( + f"Event loop error getting goodbye_message for guild {guild_id}: {e}" + ) else: - log.warning(f"Runtime error getting goodbye_message for guild {guild_id}: {e}") + log.warning( + f"Runtime error getting goodbye_message for guild {guild_id}: {e}" + ) except Exception as e: log.warning(f"Error getting goodbye_message for guild {guild_id}: {e}") @@ -743,10 +884,14 @@ async def get_guild_settings( settings["cogs"] = await settings_manager.get_all_enabled_cogs(guild_id) except RuntimeError as e: if "got Future" in str(e) and "attached to a different loop" in str(e): - log.warning(f"Event loop error getting cog enabled statuses for guild {guild_id}: {e}") + log.warning( + f"Event loop error getting cog enabled statuses for guild {guild_id}: {e}" + ) # Keep empty dict for cogs else: - log.warning(f"Runtime error getting cog enabled statuses for guild {guild_id}: {e}") + log.warning( + f"Runtime error getting cog enabled statuses for guild {guild_id}: {e}" + ) # Keep empty dict for cogs except Exception as e: log.warning(f"Error getting cog enabled statuses for guild {guild_id}: {e}") @@ -754,16 +899,24 @@ async def get_guild_settings( # Get command enabled statuses with error handling try: - settings["commands"] = await settings_manager.get_all_enabled_commands(guild_id) + settings["commands"] = await settings_manager.get_all_enabled_commands( + guild_id + ) except RuntimeError as e: if "got Future" in str(e) and "attached to a different loop" in str(e): - log.warning(f"Event loop error getting command enabled statuses for guild {guild_id}: {e}") + log.warning( + f"Event loop error getting command enabled statuses for guild {guild_id}: {e}" + ) # Keep empty dict for commands else: - log.warning(f"Runtime error getting command enabled statuses for guild {guild_id}: {e}") + log.warning( + f"Runtime error getting command enabled statuses for guild {guild_id}: {e}" + ) # Keep empty dict for commands except Exception as e: - log.warning(f"Error getting command enabled statuses for guild {guild_id}: {e}") + log.warning( + f"Error getting command enabled statuses for guild {guild_id}: {e}" + ) # Keep empty dict for commands return settings @@ -777,40 +930,44 @@ async def get_guild_settings( # Return a more helpful error message raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Database connection error. Please try again." + detail="Database connection error. Please try again.", ) else: log.error(f"Runtime error getting settings for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error getting settings: {str(e)}" + detail=f"Error getting settings: {str(e)}", ) except Exception as e: log.error(f"Error getting settings for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error getting settings: {str(e)}" + detail=f"Error getting settings: {str(e)}", ) + @router.patch("/guilds/{guild_id}/settings", status_code=status.HTTP_200_OK) async def update_guild_settings( guild_id: int, settings_update: Dict[str, Any] = Body(...), _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Update settings for a guild.""" try: # Check if settings_manager is available from global_bot_accessor import get_bot_instance + bot = get_bot_instance() if not settings_manager or not bot or not bot.pg_pool: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Settings manager or database connection not available" + detail="Settings manager or database connection not available", ) - log.info(f"Updating settings for guild {guild_id} requested by user {_user.get('user_id')}") + log.info( + f"Updating settings for guild {guild_id} requested by user {_user.get('user_id')}" + ) log.debug(f"Update data received: {settings_update}") success_flags = [] @@ -818,73 +975,113 @@ async def update_guild_settings( # Get bot instance for core cogs check try: import discord_bot_sync_api + bot = discord_bot_sync_api.bot_instance - core_cogs_list = bot.core_cogs if bot and hasattr(bot, 'core_cogs') else {'SettingsCog', 'HelpCog'} + core_cogs_list = ( + bot.core_cogs + if bot and hasattr(bot, "core_cogs") + else {"SettingsCog", "HelpCog"} + ) except ImportError: - core_cogs_list = {'SettingsCog', 'HelpCog'} # Core cogs that cannot be disabled + core_cogs_list = { + "SettingsCog", + "HelpCog", + } # Core cogs that cannot be disabled # Update prefix if provided - if 'prefix' in settings_update: - success = await settings_manager.set_guild_prefix(guild_id, settings_update['prefix']) + if "prefix" in settings_update: + success = await settings_manager.set_guild_prefix( + guild_id, settings_update["prefix"] + ) success_flags.append(success) if not success: log.error(f"Failed to update prefix for guild {guild_id}") # Update welcome channel if provided - if 'welcome_channel_id' in settings_update: - value = settings_update['welcome_channel_id'] if settings_update['welcome_channel_id'] else None - success = await settings_manager.set_setting(guild_id, 'welcome_channel_id', value) + if "welcome_channel_id" in settings_update: + value = ( + settings_update["welcome_channel_id"] + if settings_update["welcome_channel_id"] + else None + ) + success = await settings_manager.set_setting( + guild_id, "welcome_channel_id", value + ) success_flags.append(success) if not success: log.error(f"Failed to update welcome_channel_id for guild {guild_id}") # Update welcome message if provided - if 'welcome_message' in settings_update: - success = await settings_manager.set_setting(guild_id, 'welcome_message', settings_update['welcome_message']) + if "welcome_message" in settings_update: + success = await settings_manager.set_setting( + guild_id, "welcome_message", settings_update["welcome_message"] + ) success_flags.append(success) if not success: log.error(f"Failed to update welcome_message for guild {guild_id}") # Update goodbye channel if provided - if 'goodbye_channel_id' in settings_update: - value = settings_update['goodbye_channel_id'] if settings_update['goodbye_channel_id'] else None - success = await settings_manager.set_setting(guild_id, 'goodbye_channel_id', value) + if "goodbye_channel_id" in settings_update: + value = ( + settings_update["goodbye_channel_id"] + if settings_update["goodbye_channel_id"] + else None + ) + success = await settings_manager.set_setting( + guild_id, "goodbye_channel_id", value + ) success_flags.append(success) if not success: log.error(f"Failed to update goodbye_channel_id for guild {guild_id}") # Update goodbye message if provided - if 'goodbye_message' in settings_update: - success = await settings_manager.set_setting(guild_id, 'goodbye_message', settings_update['goodbye_message']) + if "goodbye_message" in settings_update: + success = await settings_manager.set_setting( + guild_id, "goodbye_message", settings_update["goodbye_message"] + ) success_flags.append(success) if not success: log.error(f"Failed to update goodbye_message for guild {guild_id}") # Update cogs if provided - if 'cogs' in settings_update and isinstance(settings_update['cogs'], dict): - for cog_name, enabled_status in settings_update['cogs'].items(): + if "cogs" in settings_update and isinstance(settings_update["cogs"], dict): + for cog_name, enabled_status in settings_update["cogs"].items(): if cog_name not in core_cogs_list: - success = await settings_manager.set_cog_enabled(guild_id, cog_name, enabled_status) + success = await settings_manager.set_cog_enabled( + guild_id, cog_name, enabled_status + ) success_flags.append(success) if not success: - log.error(f"Failed to update status for cog '{cog_name}' for guild {guild_id}") + log.error( + f"Failed to update status for cog '{cog_name}' for guild {guild_id}" + ) else: - log.warning(f"Attempted to change status of core cog '{cog_name}' for guild {guild_id} - ignored.") + log.warning( + f"Attempted to change status of core cog '{cog_name}' for guild {guild_id} - ignored." + ) # Update commands if provided - if 'commands' in settings_update and isinstance(settings_update['commands'], dict): - for command_name, enabled_status in settings_update['commands'].items(): - success = await settings_manager.set_command_enabled(guild_id, command_name, enabled_status) + if "commands" in settings_update and isinstance( + settings_update["commands"], dict + ): + for command_name, enabled_status in settings_update["commands"].items(): + success = await settings_manager.set_command_enabled( + guild_id, command_name, enabled_status + ) success_flags.append(success) if not success: - log.error(f"Failed to update status for command '{command_name}' for guild {guild_id}") + log.error( + f"Failed to update status for command '{command_name}' for guild {guild_id}" + ) - if all(s is True for s in success_flags): # Check if all operations returned True + if all( + s is True for s in success_flags + ): # Check if all operations returned True return {"message": "Settings updated successfully."} else: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="One or more settings failed to update. Check server logs." + detail="One or more settings failed to update. Check server logs.", ) except HTTPException: # Re-raise HTTP exceptions @@ -893,14 +1090,15 @@ async def update_guild_settings( log.error(f"Error updating settings for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error updating settings: {str(e)}" + detail=f"Error updating settings: {str(e)}", ) + @router.post("/guilds/{guild_id}/sync-commands", status_code=status.HTTP_200_OK) async def sync_guild_commands( guild_id: int, _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Sync commands for a guild to apply customizations.""" try: @@ -908,19 +1106,22 @@ async def sync_guild_commands( # In a real implementation, this would communicate with the bot to sync commands # For now, we'll just return a success message # TODO: Implement actual command syncing logic - return {"message": "Command sync requested. This may take a moment to complete."} + return { + "message": "Command sync requested. This may take a moment to complete." + } except Exception as e: log.error(f"Error syncing commands for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error syncing commands: {str(e)}" + detail=f"Error syncing commands: {str(e)}", ) + @router.post("/guilds/{guild_id}/test-welcome", status_code=status.HTTP_200_OK) async def test_welcome_message( guild_id: int, _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Test the welcome message for a guild.""" # This endpoint is now handled by the main API server @@ -941,14 +1142,15 @@ async def test_welcome_message( log.error(f"Error testing welcome message for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error testing welcome message: {str(e)}" + detail=f"Error testing welcome message: {str(e)}", ) + @router.post("/guilds/{guild_id}/test-goodbye", status_code=status.HTTP_200_OK) async def test_goodbye_message( guild_id: int, _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Test the goodbye message for a guild.""" # This endpoint is now handled by the main API server @@ -969,15 +1171,15 @@ async def test_goodbye_message( log.error(f"Error testing goodbye message for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error testing goodbye message: {str(e)}" + detail=f"Error testing goodbye message: {str(e)}", ) + # --- Global Settings Endpoints --- + @router.get("/settings", response_model=GlobalSettings) -async def get_global_settings( - _user: dict = Depends(get_dashboard_user) -): +async def get_global_settings(_user: dict = Depends(get_dashboard_user)): """Get global settings for the current user.""" try: # Import the database module for user settings @@ -989,15 +1191,15 @@ async def get_global_settings( if not db: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Database connection not available" + detail="Database connection not available", ) # Get user settings from the database - user_id = _user.get('user_id') + user_id = _user.get("user_id") if not user_id: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="User ID not found in session" + detail="User ID not found in session", ) user_settings = db.get_user_settings(user_id) @@ -1010,7 +1212,7 @@ async def get_global_settings( custom_instructions="", model="openai/gpt-3.5-turbo", temperature=0.7, - max_tokens=1000 + max_tokens=1000, ) # Convert from UserSettings to GlobalSettings @@ -1026,11 +1228,11 @@ async def get_global_settings( custom_bot_enabled=user_settings.custom_bot_enabled, custom_bot_prefix=user_settings.custom_bot_prefix or "!", custom_bot_status_text=user_settings.custom_bot_status_text or "!help", - custom_bot_status_type=user_settings.custom_bot_status_type or "listening" + custom_bot_status_type=user_settings.custom_bot_status_type or "listening", ) # Add theme settings if available - if hasattr(user_settings, 'theme') and user_settings.theme: + if hasattr(user_settings, "theme") and user_settings.theme: global_settings.theme = user_settings.theme return global_settings @@ -1041,14 +1243,14 @@ async def get_global_settings( log.error(f"Error getting global settings: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error getting global settings: {str(e)}" + detail=f"Error getting global settings: {str(e)}", ) + @router.post("/settings", status_code=status.HTTP_200_OK) @router.put("/settings", status_code=status.HTTP_200_OK) async def update_global_settings( - settings: GlobalSettings, - _user: dict = Depends(get_dashboard_user) + settings: GlobalSettings, _user: dict = Depends(get_dashboard_user) ): """Update global settings for the current user.""" try: @@ -1063,44 +1265,51 @@ async def update_global_settings( if not db: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Database connection not available" + detail="Database connection not available", ) # Get user ID from session - user_id = _user.get('user_id') + user_id = _user.get("user_id") if not user_id: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="User ID not found in session" + detail="User ID not found in session", ) # Convert from GlobalSettings to UserSettings # Provide default values for required fields when they are None user_settings = UserSettings( model_id=settings.model or "openai/gpt-3.5-turbo", - temperature=settings.temperature if settings.temperature is not None else 0.7, + temperature=( + settings.temperature if settings.temperature is not None else 0.7 + ), max_tokens=settings.max_tokens if settings.max_tokens is not None else 1000, system_message=settings.system_message, character=settings.character, character_info=settings.character_info, custom_instructions=settings.custom_instructions, custom_bot_token=settings.custom_bot_token, - custom_bot_enabled=settings.custom_bot_enabled if settings.custom_bot_enabled is not None else False, + custom_bot_enabled=( + settings.custom_bot_enabled + if settings.custom_bot_enabled is not None + else False + ), custom_bot_prefix=settings.custom_bot_prefix or "!", custom_bot_status_text=settings.custom_bot_status_text or "!help", - custom_bot_status_type=settings.custom_bot_status_type or "listening" + custom_bot_status_type=settings.custom_bot_status_type or "listening", ) # Add theme settings if provided if settings.theme: from api_service.api_models import ThemeSettings as ApiThemeSettings + user_settings.theme = ApiThemeSettings( theme_mode=settings.theme.theme_mode, primary_color=settings.theme.primary_color, secondary_color=settings.theme.secondary_color, accent_color=settings.theme.accent_color, font_family=settings.theme.font_family, - custom_css=settings.theme.custom_css + custom_css=settings.theme.custom_css, ) # Save user settings to the database @@ -1108,7 +1317,7 @@ async def update_global_settings( if not updated_settings: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to save user settings" + detail="Failed to save user settings", ) log.info(f"Updated global settings for user {user_id}") @@ -1120,21 +1329,22 @@ async def update_global_settings( log.error(f"Error updating global settings: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error updating global settings: {str(e)}" + detail=f"Error updating global settings: {str(e)}", ) + # --- Custom Bot Management Endpoints --- + class CustomBotStatus(BaseModel): exists: bool status: str error: Optional[str] = None is_running: bool + @router.get("/custom-bot/status", response_model=CustomBotStatus) -async def get_custom_bot_status( - _user: dict = Depends(get_dashboard_user) -): +async def get_custom_bot_status(_user: dict = Depends(get_dashboard_user)): """Get the status of the user's custom bot.""" try: # Check if custom bot manager is available @@ -1143,14 +1353,14 @@ async def get_custom_bot_status( exists=False, status="not_available", error="Custom bot functionality is not available", - is_running=False + is_running=False, ) - user_id = _user.get('user_id') + user_id = _user.get("user_id") if not user_id: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="User ID not found in session" + detail="User ID not found in session", ) # Get the status from the custom bot manager @@ -1160,27 +1370,26 @@ async def get_custom_bot_status( log.error(f"Error getting custom bot status: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error getting custom bot status: {str(e)}" + detail=f"Error getting custom bot status: {str(e)}", ) + @router.post("/custom-bot/start", status_code=status.HTTP_200_OK) -async def start_custom_bot( - _user: dict = Depends(get_dashboard_user) -): +async def start_custom_bot(_user: dict = Depends(get_dashboard_user)): """Start the user's custom bot.""" try: # Check if custom bot manager is available if not custom_bot_manager: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Custom bot functionality is not available" + detail="Custom bot functionality is not available", ) - user_id = _user.get('user_id') + user_id = _user.get("user_id") if not user_id: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="User ID not found in session" + detail="User ID not found in session", ) # Import the database module for user settings @@ -1192,7 +1401,7 @@ async def start_custom_bot( if not db: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Database connection not available" + detail="Database connection not available", ) # Get user settings from the database @@ -1200,14 +1409,14 @@ async def start_custom_bot( if not user_settings: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="User settings not found" + detail="User settings not found", ) # Check if custom bot token is set if not user_settings.custom_bot_token: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Custom bot token not set. Please set a token first." + detail="Custom bot token not set. Please set a token first.", ) # Create the bot if it doesn't exist @@ -1218,25 +1427,24 @@ async def start_custom_bot( token=user_settings.custom_bot_token, prefix=user_settings.custom_bot_prefix, status_type=user_settings.custom_bot_status_type, - status_text=user_settings.custom_bot_status_text + status_text=user_settings.custom_bot_status_text, ) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error creating custom bot: {message}" + detail=f"Error creating custom bot: {message}", ) # Start the bot success, message = custom_bot_manager.run_custom_bot_in_thread( - user_id=user_id, - token=user_settings.custom_bot_token + user_id=user_id, token=user_settings.custom_bot_token ) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error starting custom bot: {message}" + detail=f"Error starting custom bot: {message}", ) # Update the enabled status in user settings @@ -1251,27 +1459,26 @@ async def start_custom_bot( log.error(f"Error starting custom bot: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error starting custom bot: {str(e)}" + detail=f"Error starting custom bot: {str(e)}", ) + @router.post("/custom-bot/stop", status_code=status.HTTP_200_OK) -async def stop_custom_bot( - _user: dict = Depends(get_dashboard_user) -): +async def stop_custom_bot(_user: dict = Depends(get_dashboard_user)): """Stop the user's custom bot.""" try: # Check if custom bot manager is available if not custom_bot_manager: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Custom bot functionality is not available" + detail="Custom bot functionality is not available", ) - user_id = _user.get('user_id') + user_id = _user.get("user_id") if not user_id: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="User ID not found in session" + detail="User ID not found in session", ) # Import the database module for user settings @@ -1283,7 +1490,7 @@ async def stop_custom_bot( if not db: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Database connection not available" + detail="Database connection not available", ) # Stop the bot @@ -1292,7 +1499,7 @@ async def stop_custom_bot( if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error stopping custom bot: {message}" + detail=f"Error stopping custom bot: {message}", ) # Update the enabled status in user settings @@ -1309,32 +1516,36 @@ async def stop_custom_bot( log.error(f"Error stopping custom bot: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error stopping custom bot: {str(e)}" + detail=f"Error stopping custom bot: {str(e)}", ) + # --- Cog and Command Management Endpoints --- # Note: These endpoints have been moved to cog_management_endpoints.py # --- Cog Management Endpoints --- # These endpoints provide direct implementation and fallback for cog management + # Define models needed for cog management class CogCommandInfo(BaseModel): name: str description: Optional[str] = None enabled: bool = True + class CogInfo(BaseModel): name: str description: Optional[str] = None enabled: bool = True commands: List[Dict[str, Any]] = [] + @router.get("/guilds/{guild_id}/cogs", response_model=List[Any]) async def get_guild_cogs_redirect( guild_id: int, _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Get all cogs and their commands for a guild.""" try: @@ -1342,6 +1553,7 @@ async def get_guild_cogs_redirect( try: # Try relative import first from .cog_management_endpoints import get_guild_cogs + log.info(f"Successfully imported get_guild_cogs via relative import") # Call the cog management endpoint @@ -1354,6 +1566,7 @@ async def get_guild_cogs_redirect( try: # Fall back to absolute import from cog_management_endpoints import get_guild_cogs + log.info(f"Successfully imported get_guild_cogs via absolute import") # Call the cog management endpoint @@ -1369,53 +1582,76 @@ async def get_guild_cogs_redirect( # Check if bot instance is available via discord_bot_sync_api try: import discord_bot_sync_api + bot = discord_bot_sync_api.bot_instance if not bot: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Bot instance not available" + detail="Bot instance not available", ) except ImportError: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Bot sync API not available" + detail="Bot sync API not available", ) # Get all cogs from the bot cogs_list = [] for cog_name, cog in bot.cogs.items(): # Get enabled status from settings_manager - is_enabled = await settings_manager.is_cog_enabled(guild_id, cog_name, default_enabled=True) + is_enabled = await settings_manager.is_cog_enabled( + guild_id, cog_name, default_enabled=True + ) # Get commands for this cog commands_list = [] for command in cog.get_commands(): # Get command enabled status - cmd_enabled = await settings_manager.is_command_enabled(guild_id, command.qualified_name, default_enabled=True) - commands_list.append({ - "name": command.qualified_name, - "description": command.help or "No description available", - "enabled": cmd_enabled - }) + cmd_enabled = await settings_manager.is_command_enabled( + guild_id, command.qualified_name, default_enabled=True + ) + commands_list.append( + { + "name": command.qualified_name, + "description": command.help + or "No description available", + "enabled": cmd_enabled, + } + ) # Add slash commands if any - app_commands = [cmd for cmd in bot.tree.get_commands() if hasattr(cmd, 'cog') and cmd.cog and cmd.cog.qualified_name == cog_name] + app_commands = [ + cmd + for cmd in bot.tree.get_commands() + if hasattr(cmd, "cog") + and cmd.cog + and cmd.cog.qualified_name == cog_name + ] for cmd in app_commands: # Get command enabled status - cmd_enabled = await settings_manager.is_command_enabled(guild_id, cmd.name, default_enabled=True) - if not any(c["name"] == cmd.name for c in commands_list): # Avoid duplicates - commands_list.append({ - "name": cmd.name, - "description": cmd.description or "No description available", - "enabled": cmd_enabled - }) + cmd_enabled = await settings_manager.is_command_enabled( + guild_id, cmd.name, default_enabled=True + ) + if not any( + c["name"] == cmd.name for c in commands_list + ): # Avoid duplicates + commands_list.append( + { + "name": cmd.name, + "description": cmd.description + or "No description available", + "enabled": cmd_enabled, + } + ) - cogs_list.append(CogInfo( - name=cog_name, - description=cog.__doc__ or "No description available", - enabled=is_enabled, - commands=commands_list - )) + cogs_list.append( + CogInfo( + name=cog_name, + description=cog.__doc__ or "No description available", + enabled=is_enabled, + commands=commands_list, + ) + ) return cogs_list except HTTPException: @@ -1425,16 +1661,17 @@ async def get_guild_cogs_redirect( log.error(f"Error getting cogs for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error getting cogs: {str(e)}" + detail=f"Error getting cogs: {str(e)}", ) + @router.patch("/guilds/{guild_id}/cogs/{cog_name}", status_code=status.HTTP_200_OK) async def update_cog_status_redirect( guild_id: int, cog_name: str, enabled: bool = Body(..., embed=True), _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Enable or disable a cog for a guild.""" try: @@ -1442,24 +1679,36 @@ async def update_cog_status_redirect( try: # Try relative import first from .cog_management_endpoints import update_cog_status + log.info(f"Successfully imported update_cog_status via relative import") # Call the cog management endpoint - log.info(f"Calling update_cog_status for guild {guild_id}, cog {cog_name}, enabled={enabled}") + log.info( + f"Calling update_cog_status for guild {guild_id}, cog {cog_name}, enabled={enabled}" + ) result = await update_cog_status(guild_id, cog_name, enabled, _user, _admin) - log.info(f"Successfully updated cog status for guild {guild_id}, cog {cog_name}") + log.info( + f"Successfully updated cog status for guild {guild_id}, cog {cog_name}" + ) return result except ImportError as e: log.warning(f"Relative import failed: {e}, trying absolute import") try: # Fall back to absolute import from cog_management_endpoints import update_cog_status + log.info(f"Successfully imported update_cog_status via absolute import") # Call the cog management endpoint - log.info(f"Calling update_cog_status for guild {guild_id}, cog {cog_name}, enabled={enabled}") - result = await update_cog_status(guild_id, cog_name, enabled, _user, _admin) - log.info(f"Successfully updated cog status for guild {guild_id}, cog {cog_name}") + log.info( + f"Calling update_cog_status for guild {guild_id}, cog {cog_name}, enabled={enabled}" + ) + result = await update_cog_status( + guild_id, cog_name, enabled, _user, _admin + ) + log.info( + f"Successfully updated cog status for guild {guild_id}, cog {cog_name}" + ) return result except ImportError as e2: log.error(f"Both import attempts failed: {e2}") @@ -1468,49 +1717,57 @@ async def update_cog_status_redirect( # Fall back to direct implementation # Check if settings_manager is available from global_bot_accessor import get_bot_instance + bot = get_bot_instance() if not settings_manager or not bot or not bot.pg_pool: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Settings manager or database connection not available" + detail="Settings manager or database connection not available", ) # Check if the cog exists try: import discord_bot_sync_api + bot = discord_bot_sync_api.bot_instance if not bot: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Bot instance not available" + detail="Bot instance not available", ) if cog_name not in bot.cogs: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Cog '{cog_name}' not found" + detail=f"Cog '{cog_name}' not found", ) # Check if it's a core cog - core_cogs = getattr(bot, 'core_cogs', {'SettingsCog', 'HelpCog'}) + core_cogs = getattr(bot, "core_cogs", {"SettingsCog", "HelpCog"}) if cog_name in core_cogs: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Core cog '{cog_name}' cannot be disabled" + detail=f"Core cog '{cog_name}' cannot be disabled", ) except ImportError: # If we can't import the bot, we'll just assume the cog exists - log.warning("Bot sync API not available, skipping cog existence check") + log.warning( + "Bot sync API not available, skipping cog existence check" + ) # Update the cog enabled status - success = await settings_manager.set_cog_enabled(guild_id, cog_name, enabled) + success = await settings_manager.set_cog_enabled( + guild_id, cog_name, enabled + ) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update cog '{cog_name}' status" + detail=f"Failed to update cog '{cog_name}' status", ) - return {"message": f"Cog '{cog_name}' {'enabled' if enabled else 'disabled'} successfully"} + return { + "message": f"Cog '{cog_name}' {'enabled' if enabled else 'disabled'} successfully" + } except HTTPException: # Re-raise HTTP exceptions raise @@ -1518,16 +1775,19 @@ async def update_cog_status_redirect( log.error(f"Error updating cog status for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error updating cog status: {str(e)}" + detail=f"Error updating cog status: {str(e)}", ) -@router.patch("/guilds/{guild_id}/commands/{command_name}", status_code=status.HTTP_200_OK) + +@router.patch( + "/guilds/{guild_id}/commands/{command_name}", status_code=status.HTTP_200_OK +) async def update_command_status_redirect( guild_id: int, command_name: str, enabled: bool = Body(..., embed=True), _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) + _admin: bool = Depends(verify_dashboard_guild_admin), ): """Enable or disable a command for a guild.""" try: @@ -1535,24 +1795,40 @@ async def update_command_status_redirect( try: # Try relative import first from .cog_management_endpoints import update_command_status + log.info(f"Successfully imported update_command_status via relative import") # Call the cog management endpoint - log.info(f"Calling update_command_status for guild {guild_id}, command {command_name}, enabled={enabled}") - result = await update_command_status(guild_id, command_name, enabled, _user, _admin) - log.info(f"Successfully updated command status for guild {guild_id}, command {command_name}") + log.info( + f"Calling update_command_status for guild {guild_id}, command {command_name}, enabled={enabled}" + ) + result = await update_command_status( + guild_id, command_name, enabled, _user, _admin + ) + log.info( + f"Successfully updated command status for guild {guild_id}, command {command_name}" + ) return result except ImportError as e: log.warning(f"Relative import failed: {e}, trying absolute import") try: # Fall back to absolute import from cog_management_endpoints import update_command_status - log.info(f"Successfully imported update_command_status via absolute import") + + log.info( + f"Successfully imported update_command_status via absolute import" + ) # Call the cog management endpoint - log.info(f"Calling update_command_status for guild {guild_id}, command {command_name}, enabled={enabled}") - result = await update_command_status(guild_id, command_name, enabled, _user, _admin) - log.info(f"Successfully updated command status for guild {guild_id}, command {command_name}") + log.info( + f"Calling update_command_status for guild {guild_id}, command {command_name}, enabled={enabled}" + ) + result = await update_command_status( + guild_id, command_name, enabled, _user, _admin + ) + log.info( + f"Successfully updated command status for guild {guild_id}, command {command_name}" + ) return result except ImportError as e2: log.error(f"Both import attempts failed: {e2}") @@ -1561,46 +1837,58 @@ async def update_command_status_redirect( # Fall back to direct implementation # Check if settings_manager is available from global_bot_accessor import get_bot_instance + bot = get_bot_instance() if not settings_manager or not bot or not bot.pg_pool: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Settings manager or database connection not available" + detail="Settings manager or database connection not available", ) # Check if the command exists try: import discord_bot_sync_api + bot = discord_bot_sync_api.bot_instance if not bot: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Bot instance not available" + detail="Bot instance not available", ) # Check if it's a prefix command command = bot.get_command(command_name) if not command: # Check if it's an app command - app_commands = [cmd for cmd in bot.tree.get_commands() if cmd.name == command_name] + app_commands = [ + cmd + for cmd in bot.tree.get_commands() + if cmd.name == command_name + ] if not app_commands: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Command '{command_name}' not found" + detail=f"Command '{command_name}' not found", ) except ImportError: # If we can't import the bot, we'll just assume the command exists - log.warning("Bot sync API not available, skipping command existence check") + log.warning( + "Bot sync API not available, skipping command existence check" + ) # Update the command enabled status - success = await settings_manager.set_command_enabled(guild_id, command_name, enabled) + success = await settings_manager.set_command_enabled( + guild_id, command_name, enabled + ) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update command '{command_name}' status" + detail=f"Failed to update command '{command_name}' status", ) - return {"message": f"Command '{command_name}' {'enabled' if enabled else 'disabled'} successfully"} + return { + "message": f"Command '{command_name}' {'enabled' if enabled else 'disabled'} successfully" + } except HTTPException: # Re-raise HTTP exceptions raise @@ -1608,15 +1896,15 @@ async def update_command_status_redirect( log.error(f"Error updating command status for guild {guild_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error updating command status: {str(e)}" + detail=f"Error updating command status: {str(e)}", ) + # --- Conversations Endpoints --- + @router.get("/conversations", response_model=List[Conversation]) -async def get_conversations( - _user: dict = Depends(get_dashboard_user) -): +async def get_conversations(_user: dict = Depends(get_dashboard_user)): """Get all conversations for the current user.""" try: # This would normally fetch conversations from the database @@ -1628,28 +1916,28 @@ async def get_conversations( title="Conversation 1", created_at="2023-01-01T00:00:00Z", updated_at="2023-01-01T01:00:00Z", - message_count=10 + message_count=10, ), Conversation( id="2", title="Conversation 2", created_at="2023-01-02T00:00:00Z", updated_at="2023-01-02T01:00:00Z", - message_count=5 - ) + message_count=5, + ), ] return conversations except Exception as e: log.error(f"Error getting conversations: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error getting conversations: {str(e)}" + detail=f"Error getting conversations: {str(e)}", ) + @router.get("/conversations/{conversation_id}", response_model=List[Message]) async def get_conversation_messages( - conversation_id: str, - _user: dict = Depends(get_dashboard_user) + conversation_id: str, _user: dict = Depends(get_dashboard_user) ): """Get all messages for a conversation.""" try: @@ -1661,117 +1949,202 @@ async def get_conversation_messages( id="1", content="Hello, how are you?", role="user", - created_at="2023-01-01T00:00:00Z" + created_at="2023-01-01T00:00:00Z", ), Message( id="2", content="I'm doing well, thank you for asking!", role="assistant", - created_at="2023-01-01T00:00:01Z" - ) + created_at="2023-01-01T00:00:01Z", + ), ] return messages except Exception as e: log.error(f"Error getting conversation messages: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error getting conversation messages: {str(e)}" + detail=f"Error getting conversation messages: {str(e)}", ) + # --- Git Monitor Webhook Event Configuration Endpoints --- + class GitRepositoryEventSettings(BaseModel): events: List[str] + class AvailableGitEventsResponse(BaseModel): platform: str events: List[str] + SUPPORTED_GITHUB_EVENTS = [ - "push", "issues", "issue_comment", "pull_request", "pull_request_review", - "pull_request_review_comment", "release", "fork", "star", "watch", - "commit_comment", "create", "delete", "deployment", "deployment_status", - "gollum", "member", "milestone", "project_card", "project_column", "project", - "public", "repository_dispatch", "status" + "push", + "issues", + "issue_comment", + "pull_request", + "pull_request_review", + "pull_request_review_comment", + "release", + "fork", + "star", + "watch", + "commit_comment", + "create", + "delete", + "deployment", + "deployment_status", + "gollum", + "member", + "milestone", + "project_card", + "project_column", + "project", + "public", + "repository_dispatch", + "status", # Add more as needed/supported by formatters ] SUPPORTED_GITLAB_EVENTS = [ - "push", "tag_push", "issues", "note", "merge_request", "wiki_page", - "pipeline", "job", "release" + "push", + "tag_push", + "issues", + "note", + "merge_request", + "wiki_page", + "pipeline", + "job", + "release", # Add more as needed/supported by formatters # GitLab uses "push_events", "issues_events" etc. in webhook config, # but object_kind in payload is often singular like "push", "issue". # We'll store and expect the singular/object_kind style. ] -@router.get("/git_monitors/available_events/{platform}", response_model=AvailableGitEventsResponse) + +@router.get( + "/git_monitors/available_events/{platform}", + response_model=AvailableGitEventsResponse, +) async def get_available_git_events( - platform: str, - _user: dict = Depends(get_dashboard_user) # Basic auth to access + platform: str, _user: dict = Depends(get_dashboard_user) # Basic auth to access ): """Get a list of available/supported webhook event types for a given platform.""" if platform == "github": - return AvailableGitEventsResponse(platform="github", events=SUPPORTED_GITHUB_EVENTS) + return AvailableGitEventsResponse( + platform="github", events=SUPPORTED_GITHUB_EVENTS + ) elif platform == "gitlab": - return AvailableGitEventsResponse(platform="gitlab", events=SUPPORTED_GITLAB_EVENTS) + return AvailableGitEventsResponse( + platform="gitlab", events=SUPPORTED_GITLAB_EVENTS + ) else: - raise HTTPException(status_code=400, detail="Invalid platform specified. Use 'github' or 'gitlab'.") + raise HTTPException( + status_code=400, + detail="Invalid platform specified. Use 'github' or 'gitlab'.", + ) -@router.get("/guilds/{guild_id}/git_monitors/{repo_db_id}/events", response_model=GitRepositoryEventSettings) +@router.get( + "/guilds/{guild_id}/git_monitors/{repo_db_id}/events", + response_model=GitRepositoryEventSettings, +) async def get_git_repository_event_settings( - guild_id: int, # Added for verify_dashboard_guild_admin + guild_id: int, # Added for verify_dashboard_guild_admin repo_db_id: int, _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) # Ensures user is admin of the guild + _admin: bool = Depends( + verify_dashboard_guild_admin + ), # Ensures user is admin of the guild ): """Get the current allowed webhook events for a specific monitored repository.""" try: repo_config = await settings_manager.get_monitored_repository_by_id(repo_db_id) if not repo_config: - raise HTTPException(status_code=404, detail="Monitored repository not found.") - if repo_config['guild_id'] != guild_id: # Ensure the repo belongs to the specified guild - raise HTTPException(status_code=403, detail="Repository does not belong to this guild.") + raise HTTPException( + status_code=404, detail="Monitored repository not found." + ) + if ( + repo_config["guild_id"] != guild_id + ): # Ensure the repo belongs to the specified guild + raise HTTPException( + status_code=403, detail="Repository does not belong to this guild." + ) - allowed_events = repo_config.get('allowed_webhook_events', ['push']) # Default to ['push'] + allowed_events = repo_config.get( + "allowed_webhook_events", ["push"] + ) # Default to ['push'] return GitRepositoryEventSettings(events=allowed_events) except HTTPException: raise except Exception as e: - log.error(f"Error getting git repository event settings for repo {repo_db_id}: {e}") - raise HTTPException(status_code=500, detail="Failed to retrieve repository event settings.") + log.error( + f"Error getting git repository event settings for repo {repo_db_id}: {e}" + ) + raise HTTPException( + status_code=500, detail="Failed to retrieve repository event settings." + ) -@router.put("/guilds/{guild_id}/git_monitors/{repo_db_id}/events", status_code=status.HTTP_200_OK) + +@router.put( + "/guilds/{guild_id}/git_monitors/{repo_db_id}/events", + status_code=status.HTTP_200_OK, +) async def update_git_repository_event_settings( - guild_id: int, # Added for verify_dashboard_guild_admin + guild_id: int, # Added for verify_dashboard_guild_admin repo_db_id: int, settings: GitRepositoryEventSettings, _user: dict = Depends(get_dashboard_user), - _admin: bool = Depends(verify_dashboard_guild_admin) # Ensures user is admin of the guild + _admin: bool = Depends( + verify_dashboard_guild_admin + ), # Ensures user is admin of the guild ): """Update the allowed webhook events for a specific monitored repository.""" try: repo_config = await settings_manager.get_monitored_repository_by_id(repo_db_id) if not repo_config: - raise HTTPException(status_code=404, detail="Monitored repository not found.") - if repo_config['guild_id'] != guild_id: # Ensure the repo belongs to the specified guild - raise HTTPException(status_code=403, detail="Repository does not belong to this guild.") - if repo_config['monitoring_method'] != 'webhook': - raise HTTPException(status_code=400, detail="Event settings are only applicable for webhook monitoring method.") + raise HTTPException( + status_code=404, detail="Monitored repository not found." + ) + if ( + repo_config["guild_id"] != guild_id + ): # Ensure the repo belongs to the specified guild + raise HTTPException( + status_code=403, detail="Repository does not belong to this guild." + ) + if repo_config["monitoring_method"] != "webhook": + raise HTTPException( + status_code=400, + detail="Event settings are only applicable for webhook monitoring method.", + ) # Validate events against supported list for the platform - platform = repo_config['platform'] - supported_events = SUPPORTED_GITHUB_EVENTS if platform == "github" else SUPPORTED_GITLAB_EVENTS + platform = repo_config["platform"] + supported_events = ( + SUPPORTED_GITHUB_EVENTS if platform == "github" else SUPPORTED_GITLAB_EVENTS + ) for event in settings.events: if event not in supported_events: - raise HTTPException(status_code=400, detail=f"Event '{event}' is not supported for platform '{platform}'.") + raise HTTPException( + status_code=400, + detail=f"Event '{event}' is not supported for platform '{platform}'.", + ) - success = await settings_manager.update_monitored_repository_events(repo_db_id, settings.events) + success = await settings_manager.update_monitored_repository_events( + repo_db_id, settings.events + ) if not success: - raise HTTPException(status_code=500, detail="Failed to update repository event settings.") + raise HTTPException( + status_code=500, detail="Failed to update repository event settings." + ) return {"message": "Repository event settings updated successfully."} except HTTPException: raise except Exception as e: - log.error(f"Error updating git repository event settings for repo {repo_db_id}: {e}") - raise HTTPException(status_code=500, detail="Failed to update repository event settings.") + log.error( + f"Error updating git repository event settings for repo {repo_db_id}: {e}" + ) + raise HTTPException( + status_code=500, detail="Failed to update repository event settings." + ) diff --git a/api_service/dashboard_models.py b/api_service/dashboard_models.py index 08a7da2..91e50a8 100644 --- a/api_service/dashboard_models.py +++ b/api_service/dashboard_models.py @@ -1,9 +1,11 @@ """ Pydantic models used by the Dashboard API endpoints. """ + from pydantic import BaseModel, Field from typing import Dict, List, Optional, Any + class GuildSettingsResponse(BaseModel): guild_id: str prefix: Optional[str] = None @@ -11,56 +13,81 @@ class GuildSettingsResponse(BaseModel): welcome_message: Optional[str] = None goodbye_channel_id: Optional[str] = None goodbye_message: Optional[str] = None - enabled_cogs: Dict[str, bool] = {} # Cog name -> enabled status - command_permissions: Dict[str, List[str]] = {} # Command name -> List of allowed role IDs (as strings) + enabled_cogs: Dict[str, bool] = {} # Cog name -> enabled status + command_permissions: Dict[str, List[str]] = ( + {} + ) # Command name -> List of allowed role IDs (as strings) + class GuildSettingsUpdate(BaseModel): # Use Optional fields for PATCH, only provided fields will be updated prefix: Optional[str] = Field(None, min_length=1, max_length=10) - welcome_channel_id: Optional[str] = Field(None) # Allow empty string or null to disable + welcome_channel_id: Optional[str] = Field( + None + ) # Allow empty string or null to disable welcome_message: Optional[str] = Field(None) - goodbye_channel_id: Optional[str] = Field(None) # Allow empty string or null to disable + goodbye_channel_id: Optional[str] = Field( + None + ) # Allow empty string or null to disable goodbye_message: Optional[str] = Field(None) - cogs: Optional[Dict[str, bool]] = Field(None) # Dict of {cog_name: enabled_status} + cogs: Optional[Dict[str, bool]] = Field(None) # Dict of {cog_name: enabled_status} + class CommandPermission(BaseModel): command_name: str - role_id: str # Keep as string for consistency + role_id: str # Keep as string for consistency + class CommandPermissionsResponse(BaseModel): - permissions: Dict[str, List[str]] # Command name -> List of allowed role IDs + permissions: Dict[str, List[str]] # Command name -> List of allowed role IDs + class CommandCustomizationDetail(BaseModel): name: str description: Optional[str] = None + class CommandCustomizationResponse(BaseModel): - command_customizations: Dict[str, Dict[str, Optional[str]]] = {} # Original command name -> {name, description} - group_customizations: Dict[str, Dict[str, Optional[str]]] = {} # Original group name -> {name, description} - command_aliases: Dict[str, List[str]] = {} # Original command name -> List of aliases + command_customizations: Dict[str, Dict[str, Optional[str]]] = ( + {} + ) # Original command name -> {name, description} + group_customizations: Dict[str, Dict[str, Optional[str]]] = ( + {} + ) # Original group name -> {name, description} + command_aliases: Dict[str, List[str]] = ( + {} + ) # Original command name -> List of aliases + class CommandCustomizationUpdate(BaseModel): command_name: str - custom_name: Optional[str] = None # If None, removes customization - custom_description: Optional[str] = None # If None, keeps existing or no description + custom_name: Optional[str] = None # If None, removes customization + custom_description: Optional[str] = ( + None # If None, keeps existing or no description + ) + class GroupCustomizationUpdate(BaseModel): group_name: str - custom_name: Optional[str] = None # If None, removes customization + custom_name: Optional[str] = None # If None, removes customization + class CommandAliasAdd(BaseModel): command_name: str alias_name: str + class CommandAliasRemove(BaseModel): command_name: str alias_name: str + class CogCommandInfo(BaseModel): name: str description: Optional[str] = None enabled: bool = True + class CogInfo(BaseModel): name: str description: Optional[str] = None diff --git a/api_service/database.py b/api_service/database.py index 8bfd96c..8beaee5 100644 --- a/api_service/database.py +++ b/api_service/database.py @@ -2,14 +2,20 @@ import os import json import datetime from typing import Dict, List, Optional, Any + # Use absolute import for api_models from api_service.api_models import ( - Conversation, UserSettings, Message, - RoleCategoryPreset, GuildRoleCategoryConfig, UserCustomColorRole + Conversation, + UserSettings, + Message, + RoleCategoryPreset, + GuildRoleCategoryConfig, + UserCustomColorRole, ) # ============= Database Class ============= + class Database: def __init__(self, data_dir="data"): self.data_dir = data_dir @@ -17,20 +23,31 @@ class Database: self.settings_file = os.path.join(data_dir, "user_settings.json") self.tokens_file = os.path.join(data_dir, "user_tokens.json") self.role_presets_file = os.path.join(data_dir, "role_category_presets.json") - self.guild_role_configs_file = os.path.join(data_dir, "guild_role_category_configs.json") - self.user_color_roles_file = os.path.join(data_dir, "user_custom_color_roles.json") + self.guild_role_configs_file = os.path.join( + data_dir, "guild_role_category_configs.json" + ) + self.user_color_roles_file = os.path.join( + data_dir, "user_custom_color_roles.json" + ) # Create data directory if it doesn't exist os.makedirs(data_dir, exist_ok=True) # In-memory storage - self.conversations: Dict[str, Dict[str, Conversation]] = {} # user_id -> conversation_id -> Conversation + self.conversations: Dict[str, Dict[str, Conversation]] = ( + {} + ) # user_id -> conversation_id -> Conversation self.user_settings: Dict[str, UserSettings] = {} # user_id -> UserSettings self.user_tokens: Dict[str, Dict[str, Any]] = {} # user_id -> token_data - self.role_category_presets: Dict[str, RoleCategoryPreset] = {} # preset_id -> RoleCategoryPreset - self.guild_role_category_configs: Dict[str, List[GuildRoleCategoryConfig]] = {} # guild_id -> List[GuildRoleCategoryConfig] - self.user_custom_color_roles: Dict[str, Dict[str, UserCustomColorRole]] = {} # guild_id -> user_id -> UserCustomColorRole - + self.role_category_presets: Dict[str, RoleCategoryPreset] = ( + {} + ) # preset_id -> RoleCategoryPreset + self.guild_role_category_configs: Dict[str, List[GuildRoleCategoryConfig]] = ( + {} + ) # guild_id -> List[GuildRoleCategoryConfig] + self.user_custom_color_roles: Dict[str, Dict[str, UserCustomColorRole]] = ( + {} + ) # guild_id -> user_id -> UserCustomColorRole # Load data from files self.load_data() @@ -65,12 +82,14 @@ class Database: preset_id: RoleCategoryPreset.model_validate(preset_data) for preset_id, preset_data in data.items() } - print(f"Loaded {len(self.role_category_presets)} role category presets.") + print( + f"Loaded {len(self.role_category_presets)} role category presets." + ) except Exception as e: print(f"Error loading role category presets: {e}") self.role_category_presets = {} else: - self.role_category_presets = {} # Initialize if file doesn't exist + self.role_category_presets = {} # Initialize if file doesn't exist def save_role_category_presets(self): """Save role category presets to file""" @@ -80,7 +99,9 @@ class Database: for preset_id, preset in self.role_category_presets.items() } with open(self.role_presets_file, "w", encoding="utf-8") as f: - json.dump(serializable_data, f, indent=2, default=str, ensure_ascii=False) + json.dump( + serializable_data, f, indent=2, default=str, ensure_ascii=False + ) except Exception as e: print(f"Error saving role category presets: {e}") @@ -91,10 +112,15 @@ class Database: with open(self.guild_role_configs_file, "r", encoding="utf-8") as f: data = json.load(f) self.guild_role_category_configs = { - guild_id: [GuildRoleCategoryConfig.model_validate(config_data) for config_data in configs_list] + guild_id: [ + GuildRoleCategoryConfig.model_validate(config_data) + for config_data in configs_list + ] for guild_id, configs_list in data.items() } - print(f"Loaded guild role category configs for {len(self.guild_role_category_configs)} guilds.") + print( + f"Loaded guild role category configs for {len(self.guild_role_category_configs)} guilds." + ) except Exception as e: print(f"Error loading guild role category configs: {e}") self.guild_role_category_configs = {} @@ -109,7 +135,9 @@ class Database: for guild_id, configs_list in self.guild_role_category_configs.items() } with open(self.guild_role_configs_file, "w", encoding="utf-8") as f: - json.dump(serializable_data, f, indent=2, default=str, ensure_ascii=False) + json.dump( + serializable_data, f, indent=2, default=str, ensure_ascii=False + ) except Exception as e: print(f"Error saving guild role category configs: {e}") @@ -126,7 +154,9 @@ class Database: } for guild_id, user_roles in data.items() } - print(f"Loaded user custom color roles for {len(self.user_custom_color_roles)} guilds.") + print( + f"Loaded user custom color roles for {len(self.user_custom_color_roles)} guilds." + ) except Exception as e: print(f"Error loading user custom color roles: {e}") self.user_custom_color_roles = {} @@ -138,13 +168,14 @@ class Database: try: serializable_data = { guild_id: { - user_id: role.model_dump() - for user_id, role in user_roles.items() + user_id: role.model_dump() for user_id, role in user_roles.items() } for guild_id, user_roles in self.user_custom_color_roles.items() } with open(self.user_color_roles_file, "w", encoding="utf-8") as f: - json.dump(serializable_data, f, indent=2, default=str, ensure_ascii=False) + json.dump( + serializable_data, f, indent=2, default=str, ensure_ascii=False + ) except Exception as e: print(f"Error saving user custom color roles: {e}") @@ -175,13 +206,14 @@ class Database: # Convert to JSON-serializable format serializable_data = { user_id: { - conv_id: conv.model_dump() - for conv_id, conv in user_convs.items() + conv_id: conv.model_dump() for conv_id, conv in user_convs.items() } for user_id, user_convs in self.conversations.items() } with open(self.conversations_file, "w", encoding="utf-8") as f: - json.dump(serializable_data, f, indent=2, default=str, ensure_ascii=False) + json.dump( + serializable_data, f, indent=2, default=str, ensure_ascii=False + ) except Exception as e: print(f"Error saving conversations: {e}") @@ -210,7 +242,9 @@ class Database: for user_id, settings in self.user_settings.items() } with open(self.settings_file, "w", encoding="utf-8") as f: - json.dump(serializable_data, f, indent=2, default=str, ensure_ascii=False) + json.dump( + serializable_data, f, indent=2, default=str, ensure_ascii=False + ) except Exception as e: print(f"Error saving user settings: {e}") @@ -220,11 +254,15 @@ class Database: """Get all conversations for a user""" return list(self.conversations.get(user_id, {}).values()) - def get_conversation(self, user_id: str, conversation_id: str) -> Optional[Conversation]: + def get_conversation( + self, user_id: str, conversation_id: str + ) -> Optional[Conversation]: """Get a specific conversation for a user""" return self.conversations.get(user_id, {}).get(conversation_id) - def save_conversation(self, user_id: str, conversation: Conversation) -> Conversation: + def save_conversation( + self, user_id: str, conversation: Conversation + ) -> Conversation: """Save a conversation for a user""" # Update the timestamp conversation.updated_at = datetime.datetime.now() @@ -243,7 +281,10 @@ class Database: def delete_conversation(self, user_id: str, conversation_id: str) -> bool: """Delete a conversation for a user""" - if user_id in self.conversations and conversation_id in self.conversations[user_id]: + if ( + user_id in self.conversations + and conversation_id in self.conversations[user_id] + ): del self.conversations[user_id][conversation_id] self.save_conversations() return True @@ -259,7 +300,9 @@ class Database: """Get all role category presets.""" return list(self.role_category_presets.values()) - def save_role_category_preset(self, preset: RoleCategoryPreset) -> RoleCategoryPreset: + def save_role_category_preset( + self, preset: RoleCategoryPreset + ) -> RoleCategoryPreset: """Save a role category preset.""" self.role_category_presets[preset.id] = preset self.save_role_category_presets() @@ -275,39 +318,55 @@ class Database: # ============= Guild Role Category Config Methods ============= - def get_guild_role_category_configs(self, guild_id: str) -> List[GuildRoleCategoryConfig]: + def get_guild_role_category_configs( + self, guild_id: str + ) -> List[GuildRoleCategoryConfig]: """Get all role category configurations for a specific guild.""" return self.guild_role_category_configs.get(guild_id, []) - def get_all_guild_role_category_configs(self) -> Dict[str, List[GuildRoleCategoryConfig]]: + def get_all_guild_role_category_configs( + self, + ) -> Dict[str, List[GuildRoleCategoryConfig]]: """Get all role category configurations for all guilds.""" return self.guild_role_category_configs - def get_guild_role_category_config(self, guild_id: str, category_id: str) -> Optional[GuildRoleCategoryConfig]: + def get_guild_role_category_config( + self, guild_id: str, category_id: str + ) -> Optional[GuildRoleCategoryConfig]: """Get a specific role category configuration for a guild.""" for config in self.get_guild_role_category_configs(guild_id): if config.category_id == category_id: return config return None - def save_guild_role_category_config(self, config: GuildRoleCategoryConfig) -> GuildRoleCategoryConfig: + def save_guild_role_category_config( + self, config: GuildRoleCategoryConfig + ) -> GuildRoleCategoryConfig: """Save a guild's role category configuration.""" guild_id = config.guild_id if guild_id not in self.guild_role_category_configs: self.guild_role_category_configs[guild_id] = [] # Remove existing config with the same category_id if it exists, then add the new/updated one - self.guild_role_category_configs[guild_id] = [c for c in self.guild_role_category_configs[guild_id] if c.category_id != config.category_id] + self.guild_role_category_configs[guild_id] = [ + c + for c in self.guild_role_category_configs[guild_id] + if c.category_id != config.category_id + ] self.guild_role_category_configs[guild_id].append(config) self.save_guild_role_category_configs() return config - def delete_guild_role_category_config(self, guild_id: str, category_id: str) -> bool: + def delete_guild_role_category_config( + self, guild_id: str, category_id: str + ) -> bool: """Delete a specific role category configuration for a guild.""" if guild_id in self.guild_role_category_configs: initial_len = len(self.guild_role_category_configs[guild_id]) self.guild_role_category_configs[guild_id] = [ - c for c in self.guild_role_category_configs[guild_id] if c.category_id != category_id + c + for c in self.guild_role_category_configs[guild_id] + if c.category_id != category_id ] if len(self.guild_role_category_configs[guild_id]) < initial_len: self.save_guild_role_category_configs() @@ -316,18 +375,22 @@ class Database: # ============= User Custom Color Role Methods ============= - def get_user_custom_color_role(self, guild_id: str, user_id: str) -> Optional[UserCustomColorRole]: + def get_user_custom_color_role( + self, guild_id: str, user_id: str + ) -> Optional[UserCustomColorRole]: """Get a user's custom color role in a specific guild.""" return self.user_custom_color_roles.get(guild_id, {}).get(user_id) - def save_user_custom_color_role(self, color_role: UserCustomColorRole) -> UserCustomColorRole: + def save_user_custom_color_role( + self, color_role: UserCustomColorRole + ) -> UserCustomColorRole: """Save a user's custom color role.""" guild_id = color_role.guild_id user_id = color_role.user_id if guild_id not in self.user_custom_color_roles: self.user_custom_color_roles[guild_id] = {} - + color_role.last_updated = datetime.datetime.now() self.user_custom_color_roles[guild_id][user_id] = color_role self.save_user_custom_color_roles() @@ -335,7 +398,10 @@ class Database: def delete_user_custom_color_role(self, guild_id: str, user_id: str) -> bool: """Delete a user's custom color role in a specific guild.""" - if guild_id in self.user_custom_color_roles and user_id in self.user_custom_color_roles[guild_id]: + if ( + guild_id in self.user_custom_color_roles + and user_id in self.user_custom_color_roles[guild_id] + ): del self.user_custom_color_roles[guild_id][user_id] self.save_user_custom_color_roles() return True @@ -380,7 +446,9 @@ class Database: """Save user tokens to file""" try: with open(self.tokens_file, "w", encoding="utf-8") as f: - json.dump(self.user_tokens, f, indent=2, default=str, ensure_ascii=False) + json.dump( + self.user_tokens, f, indent=2, default=str, ensure_ascii=False + ) except Exception as e: print(f"Error saving user tokens: {e}") @@ -388,7 +456,9 @@ class Database: """Get token data for a user""" return self.user_tokens.get(user_id) - def save_user_token(self, user_id: str, token_data: Dict[str, Any]) -> Dict[str, Any]: + def save_user_token( + self, user_id: str, token_data: Dict[str, Any] + ) -> Dict[str, Any]: """Save token data for a user""" # Add the time when the token was saved token_data["saved_at"] = datetime.datetime.now().isoformat() diff --git a/api_service/dependencies.py b/api_service/dependencies.py index 22f70a4..37e06e6 100644 --- a/api_service/dependencies.py +++ b/api_service/dependencies.py @@ -8,11 +8,12 @@ import os # --- Configuration Loading --- # Need to load settings here as well, or pass http_session/settings around # Re-using the settings logic from api_server.py for simplicity -dotenv_path = os.path.join(os.path.dirname(__file__), '..', 'discordbot', '.env') +dotenv_path = os.path.join(os.path.dirname(__file__), "..", "discordbot", ".env") from pydantic_settings import BaseSettings, SettingsConfigDict from typing import Optional + class ApiSettings(BaseSettings): DISCORD_CLIENT_ID: str DISCORD_CLIENT_SECRET: str @@ -34,17 +35,19 @@ class ApiSettings(BaseSettings): GURT_STATS_PUSH_SECRET: Optional[str] = None model_config = SettingsConfigDict( - env_file=dotenv_path, - env_file_encoding='utf-8', - extra='ignore' + env_file=dotenv_path, env_file_encoding="utf-8", extra="ignore" ) + @lru_cache() def get_api_settings() -> ApiSettings: if not os.path.exists(dotenv_path): - print(f"Warning: .env file not found at {dotenv_path}. Using defaults or environment variables.") + print( + f"Warning: .env file not found at {dotenv_path}. Using defaults or environment variables." + ) return ApiSettings() + settings = get_api_settings() # --- Constants --- @@ -53,51 +56,62 @@ DISCORD_USER_URL = f"{DISCORD_API_BASE_URL}/users/@me" DISCORD_USER_GUILDS_URL = f"{DISCORD_API_BASE_URL}/users/@me/guilds" # --- Logging --- -log = logging.getLogger(__name__) # Use specific logger +log = logging.getLogger(__name__) # Use specific logger # --- Global aiohttp Session (managed by api_server lifespan) --- # We need access to the session created in api_server.py # A simple way is to have api_server.py set it after creation. http_session: Optional[aiohttp.ClientSession] = None + def set_http_session(session: aiohttp.ClientSession): """Sets the global aiohttp session for dependencies.""" global http_session http_session = session + # --- Authentication Dependency (Dashboard Specific) --- async def get_dashboard_user(request: Request) -> dict: """Dependency to check if user is authenticated via dashboard session and return user data.""" - user_id = request.session.get('user_id') - username = request.session.get('username') - access_token = request.session.get('access_token') # Needed for subsequent Discord API calls + user_id = request.session.get("user_id") + username = request.session.get("username") + access_token = request.session.get( + "access_token" + ) # Needed for subsequent Discord API calls if not user_id or not username or not access_token: log.warning("Dashboard: Attempted access by unauthenticated user.") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated for dashboard", - headers={"WWW-Authenticate": "Bearer"}, # Standard header for 401 + headers={"WWW-Authenticate": "Bearer"}, # Standard header for 401 ) # Return essential user info and token for potential use in endpoints return { "user_id": user_id, "username": username, - "avatar": request.session.get('avatar'), - "access_token": access_token - } + "avatar": request.session.get("avatar"), + "access_token": access_token, + } + # --- Guild Admin Verification Dependency (Dashboard Specific) --- -async def verify_dashboard_guild_admin(guild_id: int, current_user: dict = Depends(get_dashboard_user)) -> bool: +async def verify_dashboard_guild_admin( + guild_id: int, current_user: dict = Depends(get_dashboard_user) +) -> bool: """Dependency to verify the dashboard session user is an admin of the specified guild.""" - global http_session # Use the global aiohttp session + global http_session # Use the global aiohttp session if not http_session: - log.error("verify_dashboard_guild_admin: HTTP session not ready.") - raise HTTPException(status_code=500, detail="Internal server error: HTTP session not ready.") + log.error("verify_dashboard_guild_admin: HTTP session not ready.") + raise HTTPException( + status_code=500, detail="Internal server error: HTTP session not ready." + ) - user_headers = {'Authorization': f'Bearer {current_user["access_token"]}'} + user_headers = {"Authorization": f'Bearer {current_user["access_token"]}'} try: - log.debug(f"Dashboard: Verifying admin status for user {current_user['user_id']} in guild {guild_id}") + log.debug( + f"Dashboard: Verifying admin status for user {current_user['user_id']} in guild {guild_id}" + ) # Add rate limit handling max_retries = 3 @@ -106,59 +120,97 @@ async def verify_dashboard_guild_admin(guild_id: int, current_user: dict = Depen while retry_count < max_retries: if retry_after > 0: - log.warning(f"Dashboard: Rate limited by Discord API, waiting {retry_after} seconds before retry") + log.warning( + f"Dashboard: Rate limited by Discord API, waiting {retry_after} seconds before retry" + ) await asyncio.sleep(retry_after) - async with http_session.get(DISCORD_USER_GUILDS_URL, headers=user_headers) as resp: + async with http_session.get( + DISCORD_USER_GUILDS_URL, headers=user_headers + ) as resp: if resp.status == 429: # Rate limited retry_count += 1 try: - retry_after = float(resp.headers.get('X-RateLimit-Reset-After', resp.headers.get('Retry-After', 1))) + retry_after = float( + resp.headers.get( + "X-RateLimit-Reset-After", + resp.headers.get("Retry-After", 1), + ) + ) except (ValueError, TypeError): - retry_after = 1.0 # Default wait time if header is invalid - is_global = resp.headers.get('X-RateLimit-Global') is not None - scope = resp.headers.get('X-RateLimit-Scope', 'unknown') + retry_after = 1.0 # Default wait time if header is invalid + is_global = resp.headers.get("X-RateLimit-Global") is not None + scope = resp.headers.get("X-RateLimit-Scope", "unknown") log.warning( f"Dashboard: Discord API rate limit hit. " f"Global: {is_global}, Scope: {scope}, " f"Reset after: {retry_after}s, " f"Retry: {retry_count}/{max_retries}" ) - if is_global: retry_after = max(retry_after, 5) # Wait longer for global limits - continue # Retry the request + if is_global: + retry_after = max( + retry_after, 5 + ) # Wait longer for global limits + continue # Retry the request if resp.status == 401: # Session token might be invalid, but we can't clear session here easily. # Let the frontend handle re-authentication based on the 401. - raise HTTPException(status_code=401, detail="Discord token invalid or expired. Please re-login.") + raise HTTPException( + status_code=401, + detail="Discord token invalid or expired. Please re-login.", + ) - resp.raise_for_status() # Raise for other errors (4xx, 5xx) + resp.raise_for_status() # Raise for other errors (4xx, 5xx) user_guilds = await resp.json() ADMINISTRATOR_PERMISSION = 0x8 is_admin = False for guild in user_guilds: - if int(guild['id']) == guild_id: - permissions = int(guild['permissions']) - if (permissions & ADMINISTRATOR_PERMISSION) == ADMINISTRATOR_PERMISSION: + if int(guild["id"]) == guild_id: + permissions = int(guild["permissions"]) + if ( + permissions & ADMINISTRATOR_PERMISSION + ) == ADMINISTRATOR_PERMISSION: is_admin = True - break # Found the guild and user is admin + break # Found the guild and user is admin if not is_admin: - log.warning(f"Dashboard: User {current_user['user_id']} is not admin or not in guild {guild_id}.") - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User is not an administrator of this guild.") + log.warning( + f"Dashboard: User {current_user['user_id']} is not admin or not in guild {guild_id}." + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="User is not an administrator of this guild.", + ) - log.debug(f"Dashboard: User {current_user['user_id']} verified as admin for guild {guild_id}.") - return True # Indicate verification success + log.debug( + f"Dashboard: User {current_user['user_id']} verified as admin for guild {guild_id}." + ) + return True # Indicate verification success # If loop finishes without returning True, it means retries were exhausted - raise HTTPException(status_code=429, detail="Rate limited by Discord API. Please try again later.") + raise HTTPException( + status_code=429, + detail="Rate limited by Discord API. Please try again later.", + ) except aiohttp.ClientResponseError as e: - log.exception(f"Dashboard: HTTP error verifying guild admin status: {e.status} {e.message}") - if e.status == 429: # Should be caught by the loop, but safeguard - raise HTTPException(status_code=429, detail="Rate limited by Discord API. Please try again later.") - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="Error communicating with Discord API.") + log.exception( + f"Dashboard: HTTP error verifying guild admin status: {e.status} {e.message}" + ) + if e.status == 429: # Should be caught by the loop, but safeguard + raise HTTPException( + status_code=429, + detail="Rate limited by Discord API. Please try again later.", + ) + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="Error communicating with Discord API.", + ) except Exception as e: log.exception(f"Dashboard: Generic error verifying guild admin status: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An internal error occurred during permission verification.") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An internal error occurred during permission verification.", + ) diff --git a/api_service/discord_client.py b/api_service/discord_client.py index d0b0359..6677670 100644 --- a/api_service/discord_client.py +++ b/api_service/discord_client.py @@ -4,6 +4,7 @@ import datetime from typing import Dict, List, Optional, Any, Union from api_service.api_models import Conversation, UserSettings, Message + class ApiClient: def __init__(self, api_url: str, token: Optional[str] = None): """ @@ -20,7 +21,9 @@ class ApiClient: """Set the Discord token for authentication""" self.token = token - async def _make_request(self, method: str, endpoint: str, data: Optional[Dict] = None): + async def _make_request( + self, method: str, endpoint: str, data: Optional[Dict] = None + ): """ Make a request to the API @@ -37,7 +40,7 @@ class ApiClient: headers = { "Authorization": f"Bearer {self.token}", - "Content-Type": "application/json" + "Content-Type": "application/json", } url = f"{self.api_url}/{endpoint}" @@ -47,38 +50,54 @@ class ApiClient: async with session.get(url, headers=headers) as response: if response.status != 200: error_text = await response.text() - raise Exception(f"API request failed: {response.status} - {error_text}") + raise Exception( + f"API request failed: {response.status} - {error_text}" + ) response_text = await response.text() return json.loads(response_text) elif method == "POST": # Convert data to JSON with datetime handling - json_data = json.dumps(data, default=str, ensure_ascii=False) if data else None + json_data = ( + json.dumps(data, default=str, ensure_ascii=False) if data else None + ) # Update headers for manually serialized JSON if json_data: headers["Content-Type"] = "application/json" - async with session.post(url, headers=headers, data=json_data) as response: + async with session.post( + url, headers=headers, data=json_data + ) as response: if response.status not in (200, 201): error_text = await response.text() - raise Exception(f"API request failed: {response.status} - {error_text}") + raise Exception( + f"API request failed: {response.status} - {error_text}" + ) response_text = await response.text() return json.loads(response_text) elif method == "PUT": # Convert data to JSON with datetime handling - json_data = json.dumps(data, default=str, ensure_ascii=False) if data else None + json_data = ( + json.dumps(data, default=str, ensure_ascii=False) if data else None + ) # Update headers for manually serialized JSON if json_data: headers["Content-Type"] = "application/json" - async with session.put(url, headers=headers, data=json_data) as response: + async with session.put( + url, headers=headers, data=json_data + ) as response: if response.status != 200: error_text = await response.text() - raise Exception(f"API request failed: {response.status} - {error_text}") + raise Exception( + f"API request failed: {response.status} - {error_text}" + ) response_text = await response.text() return json.loads(response_text) elif method == "DELETE": async with session.delete(url, headers=headers) as response: if response.status != 200: error_text = await response.text() - raise Exception(f"API request failed: {response.status} - {error_text}") + raise Exception( + f"API request failed: {response.status} - {error_text}" + ) response_text = await response.text() return json.loads(response_text) else: @@ -98,17 +117,25 @@ class ApiClient: async def create_conversation(self, conversation: Conversation) -> Conversation: """Create a new conversation""" - response = await self._make_request("POST", "conversations", {"conversation": conversation.model_dump()}) + response = await self._make_request( + "POST", "conversations", {"conversation": conversation.model_dump()} + ) return Conversation.model_validate(response) async def update_conversation(self, conversation: Conversation) -> Conversation: """Update an existing conversation""" - response = await self._make_request("PUT", f"conversations/{conversation.id}", {"conversation": conversation.model_dump()}) + response = await self._make_request( + "PUT", + f"conversations/{conversation.id}", + {"conversation": conversation.model_dump()}, + ) return Conversation.model_validate(response) async def delete_conversation(self, conversation_id: str) -> bool: """Delete a conversation""" - response = await self._make_request("DELETE", f"conversations/{conversation_id}") + response = await self._make_request( + "DELETE", f"conversations/{conversation_id}" + ) return response["success"] # ============= Settings Methods ============= @@ -120,7 +147,9 @@ class ApiClient: async def update_settings(self, settings: UserSettings) -> UserSettings: """Update settings for the authenticated user""" - response = await self._make_request("PUT", "settings", {"settings": settings.model_dump()}) + response = await self._make_request( + "PUT", "settings", {"settings": settings.model_dump()} + ) return UserSettings.model_validate(response) # ============= Helper Methods ============= @@ -136,7 +165,7 @@ class ApiClient: temperature: float = 0.7, max_tokens: int = 1000, web_search_enabled: bool = False, - system_message: Optional[str] = None + system_message: Optional[str] = None, ) -> Conversation: """ Save a conversation from Discord to the API @@ -159,13 +188,15 @@ class ApiClient: # Convert messages to the API format api_messages = [] for msg in messages: - api_messages.append(Message( - content=msg["content"], - role=msg["role"], - timestamp=msg.get("timestamp", datetime.datetime.now()), - reasoning=msg.get("reasoning"), - usage_data=msg.get("usage_data") - )) + api_messages.append( + Message( + content=msg["content"], + role=msg["role"], + timestamp=msg.get("timestamp", datetime.datetime.now()), + reasoning=msg.get("reasoning"), + usage_data=msg.get("usage_data"), + ) + ) # Create or update the conversation if conversation_id: @@ -201,7 +232,7 @@ class ApiClient: web_search_enabled=web_search_enabled, system_message=system_message, created_at=datetime.datetime.now(), - updated_at=datetime.datetime.now() + updated_at=datetime.datetime.now(), ) return await self.create_conversation(conversation) diff --git a/api_service/run_api_server.py b/api_service/run_api_server.py index 1c11494..17d3b26 100644 --- a/api_service/run_api_server.py +++ b/api_service/run_api_server.py @@ -15,7 +15,7 @@ data_dir = os.getenv("DATA_DIR", "data") os.makedirs(data_dir, exist_ok=True) # Ensure the project root directory (containing the 'discordbot' package) is in sys.path -project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) if project_root not in sys.path: print(f"Adding project root to sys.path: {project_root}") sys.path.insert(0, project_root) @@ -28,11 +28,7 @@ if __name__ == "__main__": def run_uvicorn(bind_host): print(f"Starting API server on {bind_host}:{port}") - uvicorn.run( - "api_service.api_server:app", - host=bind_host, - port=port - ) + uvicorn.run("api_service.api_server:app", host=bind_host, port=port) print(f"Data directory: {data_dir}") # Start only IPv4 server to avoid conflicts @@ -47,6 +43,7 @@ if __name__ == "__main__": try: while True: import time + time.sleep(1) except KeyboardInterrupt: print("Shutting down API server...") diff --git a/api_service/terminal_images_endpoint.py b/api_service/terminal_images_endpoint.py index c580af7..aa7312f 100644 --- a/api_service/terminal_images_endpoint.py +++ b/api_service/terminal_images_endpoint.py @@ -14,44 +14,54 @@ from typing import Optional router = APIRouter(tags=["Terminal Images"]) # Path to the terminal_images directory -TERMINAL_IMAGES_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'terminal_images')) +TERMINAL_IMAGES_DIR = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "terminal_images") +) # Ensure the terminal_images directory exists os.makedirs(TERMINAL_IMAGES_DIR, exist_ok=True) + @router.get("/{filename}") async def get_terminal_image(filename: str): """ Get a terminal image by filename. - + Args: filename: The filename of the terminal image - + Returns: The terminal image file - + Raises: HTTPException: If the file is not found """ file_path = os.path.join(TERMINAL_IMAGES_DIR, filename) - + if not os.path.exists(file_path): raise HTTPException(status_code=404, detail="Terminal image not found") - + return FileResponse(file_path) + # Function to mount the terminal images directory as static files def mount_terminal_images(app): """ Mount the terminal_images directory as static files. - + Args: app: The FastAPI app to mount the static files on """ # Check if the directory exists if os.path.exists(TERMINAL_IMAGES_DIR) and os.path.isdir(TERMINAL_IMAGES_DIR): # Mount the terminal_images directory as static files - app.mount("/terminal_images", StaticFiles(directory=TERMINAL_IMAGES_DIR), name="terminal_images") + app.mount( + "/terminal_images", + StaticFiles(directory=TERMINAL_IMAGES_DIR), + name="terminal_images", + ) print(f"Mounted terminal images directory: {TERMINAL_IMAGES_DIR}") else: - print(f"Warning: Terminal images directory '{TERMINAL_IMAGES_DIR}' not found. Terminal images will not be available.") + print( + f"Warning: Terminal images directory '{TERMINAL_IMAGES_DIR}' not found. Terminal images will not be available." + ) diff --git a/api_service/webhook_endpoints.py b/api_service/webhook_endpoints.py index 670ab3a..a0221cc 100644 --- a/api_service/webhook_endpoints.py +++ b/api_service/webhook_endpoints.py @@ -5,11 +5,14 @@ import logging from typing import Dict, Any, Optional from fastapi import APIRouter, Request, HTTPException, Depends, Header, Path -import discord # For Color +import discord # For Color # Import API server functions try: - from .api_server import send_discord_message_via_api, get_api_settings # For settings + from .api_server import ( + send_discord_message_via_api, + get_api_settings, + ) # For settings except ImportError: # If api_server.py is in the same directory: from api_service.api_server import send_discord_message_via_api, get_api_settings @@ -17,9 +20,12 @@ except ImportError: log = logging.getLogger(__name__) router = APIRouter() -api_settings = get_api_settings() # Get loaded API settings +api_settings = get_api_settings() # Get loaded API settings -async def get_monitored_repository_by_id_api(request: Request, repo_db_id: int) -> Dict | None: + +async def get_monitored_repository_by_id_api( + request: Request, repo_db_id: int +) -> Dict | None: """Gets details of a monitored repository by its database ID using the API service's PostgreSQL pool. This is an alternative to settings_manager.get_monitored_repository_by_id that doesn't rely on the bot instance. """ @@ -29,7 +35,9 @@ async def get_monitored_repository_by_id_api(request: Request, repo_db_id: int) # Try to get the PostgreSQL pool from the FastAPI app state pg_pool = getattr(request.app.state, "pg_pool", None) if not pg_pool: - log.warning(f"API service PostgreSQL pool not available for get_monitored_repository_by_id_api (ID {repo_db_id}).") + log.warning( + f"API service PostgreSQL pool not available for get_monitored_repository_by_id_api (ID {repo_db_id})." + ) # Instead of falling back to settings_manager, let's try to create a new connection # This is a temporary solution to diagnose the issue @@ -38,26 +46,29 @@ async def get_monitored_repository_by_id_api(request: Request, repo_db_id: int) from api_service.api_server import get_api_settings settings = get_api_settings() - log.info(f"Attempting to create a new PostgreSQL connection for repo_db_id: {repo_db_id}") + log.info( + f"Attempting to create a new PostgreSQL connection for repo_db_id: {repo_db_id}" + ) # Create a new connection to the database conn = await asyncpg.connect( user=settings.POSTGRES_USER, password=settings.POSTGRES_PASSWORD, host=settings.POSTGRES_HOST, - database=settings.POSTGRES_SETTINGS_DB + database=settings.POSTGRES_SETTINGS_DB, ) # Query the database record = await conn.fetchrow( - "SELECT * FROM git_monitored_repositories WHERE id = $1", - repo_db_id + "SELECT * FROM git_monitored_repositories WHERE id = $1", repo_db_id ) # Close the connection await conn.close() - log.info(f"Successfully retrieved repository configuration for ID {repo_db_id} using a new connection") + log.info( + f"Successfully retrieved repository configuration for ID {repo_db_id} using a new connection" + ) return dict(record) if record else None except Exception as e: log.exception(f"Failed to create a new PostgreSQL connection: {e}") @@ -67,13 +78,16 @@ async def get_monitored_repository_by_id_api(request: Request, repo_db_id: int) try: async with pg_pool.acquire() as conn: record = await conn.fetchrow( - "SELECT * FROM git_monitored_repositories WHERE id = $1", - repo_db_id + "SELECT * FROM git_monitored_repositories WHERE id = $1", repo_db_id + ) + log.info( + f"Retrieved repository configuration for ID {repo_db_id} using API service PostgreSQL pool" ) - log.info(f"Retrieved repository configuration for ID {repo_db_id} using API service PostgreSQL pool") return dict(record) if record else None except Exception as e: - log.exception(f"Database error getting monitored repository by ID {repo_db_id} using API service pool: {e}") + log.exception( + f"Database error getting monitored repository by ID {repo_db_id} using API service pool: {e}" + ) # Instead of falling back to settings_manager, try with a new connection try: @@ -81,158 +95,196 @@ async def get_monitored_repository_by_id_api(request: Request, repo_db_id: int) from api_service.api_server import get_api_settings settings = get_api_settings() - log.info(f"Attempting to create a new PostgreSQL connection after pool error for repo_db_id: {repo_db_id}") + log.info( + f"Attempting to create a new PostgreSQL connection after pool error for repo_db_id: {repo_db_id}" + ) # Create a new connection to the database conn = await asyncpg.connect( user=settings.POSTGRES_USER, password=settings.POSTGRES_PASSWORD, host=settings.POSTGRES_HOST, - database=settings.POSTGRES_SETTINGS_DB + database=settings.POSTGRES_SETTINGS_DB, ) # Query the database record = await conn.fetchrow( - "SELECT * FROM git_monitored_repositories WHERE id = $1", - repo_db_id + "SELECT * FROM git_monitored_repositories WHERE id = $1", repo_db_id ) # Close the connection await conn.close() - log.info(f"Successfully retrieved repository configuration for ID {repo_db_id} using a new connection after pool error") + log.info( + f"Successfully retrieved repository configuration for ID {repo_db_id} using a new connection after pool error" + ) return dict(record) if record else None except Exception as e2: - log.exception(f"Failed to create a new PostgreSQL connection after pool error: {e2}") + log.exception( + f"Failed to create a new PostgreSQL connection after pool error: {e2}" + ) return None + async def get_allowed_events_for_repo(request: Request, repo_db_id: int) -> list[str]: """Helper to fetch allowed_webhook_events for a repo.""" repo_config = await get_monitored_repository_by_id_api(request, repo_db_id) - if repo_config and repo_config.get('allowed_webhook_events'): - return repo_config['allowed_webhook_events'] - return ['push'] # Default to 'push' if not set or not found, for safety + if repo_config and repo_config.get("allowed_webhook_events"): + return repo_config["allowed_webhook_events"] + return ["push"] # Default to 'push' if not set or not found, for safety -def verify_github_signature(payload_body: bytes, secret_token: str, signature_header: str) -> bool: + +def verify_github_signature( + payload_body: bytes, secret_token: str, signature_header: str +) -> bool: """Verify that the payload was sent from GitHub by validating the signature.""" if not signature_header: log.warning("No X-Hub-Signature-256 found on request.") return False if not secret_token: - log.error("Webhook secret is not configured for this repository. Cannot verify signature.") + log.error( + "Webhook secret is not configured for this repository. Cannot verify signature." + ) return False - hash_object = hmac.new(secret_token.encode('utf-8'), msg=payload_body, digestmod=hashlib.sha256) + hash_object = hmac.new( + secret_token.encode("utf-8"), msg=payload_body, digestmod=hashlib.sha256 + ) expected_signature = "sha256=" + hash_object.hexdigest() if not hmac.compare_digest(expected_signature, signature_header): - log.warning(f"Request signature mismatch. Expected: {expected_signature}, Got: {signature_header}") + log.warning( + f"Request signature mismatch. Expected: {expected_signature}, Got: {signature_header}" + ) return False return True + def verify_gitlab_token(secret_token: str, gitlab_token_header: str) -> bool: """Verify that the payload was sent from GitLab by validating the token.""" if not gitlab_token_header: log.warning("No X-Gitlab-Token found on request.") return False if not secret_token: - log.error("Webhook secret is not configured for this repository. Cannot verify token.") + log.error( + "Webhook secret is not configured for this repository. Cannot verify token." + ) return False - if not hmac.compare_digest(secret_token, gitlab_token_header): # Direct comparison for GitLab token + if not hmac.compare_digest( + secret_token, gitlab_token_header + ): # Direct comparison for GitLab token log.warning("Request token mismatch.") return False return True + # Placeholder for other GitHub event formatters # def format_github_issue_embed(payload: Dict[str, Any], repo_url: str) -> discord.Embed: ... # def format_github_pull_request_embed(payload: Dict[str, Any], repo_url: str) -> discord.Embed: ... # def format_github_release_embed(payload: Dict[str, Any], repo_url: str) -> discord.Embed: ... + def format_github_push_embed(payload: Dict[str, Any], repo_url: str) -> discord.Embed: """Formats a GitHub push event payload into a Discord embed.""" try: - repo_name = payload.get('repository', {}).get('full_name', repo_url) - pusher = payload.get('pusher', {}).get('name', 'Unknown Pusher') - compare_url = payload.get('compare', repo_url) + repo_name = payload.get("repository", {}).get("full_name", repo_url) + pusher = payload.get("pusher", {}).get("name", "Unknown Pusher") + compare_url = payload.get("compare", repo_url) embed = discord.Embed( title=f"New Push to {repo_name}", url=compare_url, - color=discord.Color.blue() # Or discord.Color.from_rgb(r, g, b) + color=discord.Color.blue(), # Or discord.Color.from_rgb(r, g, b) ) embed.set_author(name=pusher) - for commit in payload.get('commits', []): - commit_id_short = commit.get('id', 'N/A')[:7] - commit_msg = commit.get('message', 'No commit message.') - commit_url = commit.get('url', '#') - author_name = commit.get('author', {}).get('name', 'Unknown Author') + for commit in payload.get("commits", []): + commit_id_short = commit.get("id", "N/A")[:7] + commit_msg = commit.get("message", "No commit message.") + commit_url = commit.get("url", "#") + author_name = commit.get("author", {}).get("name", "Unknown Author") # Files changed, insertions/deletions - added = commit.get('added', []) - removed = commit.get('removed', []) - modified = commit.get('modified', []) + added = commit.get("added", []) + removed = commit.get("removed", []) + modified = commit.get("modified", []) stats_lines = [] - if added: stats_lines.append(f"+{len(added)} added") - if removed: stats_lines.append(f"-{len(removed)} removed") - if modified: stats_lines.append(f"~{len(modified)} modified") + if added: + stats_lines.append(f"+{len(added)} added") + if removed: + stats_lines.append(f"-{len(removed)} removed") + if modified: + stats_lines.append(f"~{len(modified)} modified") stats_str = ", ".join(stats_lines) if stats_lines else "No file changes." # Verification status (GitHub specific) - verification = commit.get('verification', {}) - verified_status = "Verified" if verification.get('verified') else "Unverified" - if verification.get('reason') and verification.get('reason') != 'unsigned': + verification = commit.get("verification", {}) + verified_status = ( + "Verified" if verification.get("verified") else "Unverified" + ) + if verification.get("reason") and verification.get("reason") != "unsigned": verified_status += f" ({verification.get('reason')})" - field_value = ( f"Author: {author_name}\n" - f"Message: {commit_msg.splitlines()[0]}\n" # First line of commit message. + f"Message: {commit_msg.splitlines()[0]}\n" # First line of commit message. f"Verification: {verified_status}\n" f"Stats: {stats_str}\n" f"[View Commit]({commit_url})" ) - embed.add_field(name=f"Commit `{commit_id_short}`", value=field_value, inline=False) - if len(embed.fields) >= 5: # Limit fields to avoid overly large embeds - embed.add_field(name="...", value=f"And {len(payload.get('commits')) - 5} more commits.", inline=False) + embed.add_field( + name=f"Commit `{commit_id_short}`", value=field_value, inline=False + ) + if len(embed.fields) >= 5: # Limit fields to avoid overly large embeds + embed.add_field( + name="...", + value=f"And {len(payload.get('commits')) - 5} more commits.", + inline=False, + ) break - if not payload.get('commits'): + if not payload.get("commits"): embed.description = "Received push event with no commits (e.g., new branch created without commits)." return embed except Exception as e: log.exception(f"Error formatting GitHub push embed: {e}") - embed = discord.Embed(title="Error Processing GitHub Push Webhook", description=f"Could not parse commit details. Raw payload might be available in logs.\nError: {e}", color=discord.Color.red()) + embed = discord.Embed( + title="Error Processing GitHub Push Webhook", + description=f"Could not parse commit details. Raw payload might be available in logs.\nError: {e}", + color=discord.Color.red(), + ) return embed + # Placeholder for other GitLab event formatters # def format_gitlab_issue_embed(payload: Dict[str, Any], repo_url: str) -> discord.Embed: ... # def format_gitlab_merge_request_embed(payload: Dict[str, Any], repo_url: str) -> discord.Embed: ... # def format_gitlab_tag_push_embed(payload: Dict[str, Any], repo_url: str) -> discord.Embed: ... + def format_gitlab_push_embed(payload: Dict[str, Any], repo_url: str) -> discord.Embed: """Formats a GitLab push event payload into a Discord embed.""" try: - project_name = payload.get('project', {}).get('path_with_namespace', repo_url) - user_name = payload.get('user_name', 'Unknown Pusher') + project_name = payload.get("project", {}).get("path_with_namespace", repo_url) + user_name = payload.get("user_name", "Unknown Pusher") # GitLab's compare URL is not directly in the main payload, but commits have URLs # We can use the project's web_url as a base. - project_web_url = payload.get('project', {}).get('web_url', repo_url) + project_web_url = payload.get("project", {}).get("web_url", repo_url) embed = discord.Embed( title=f"New Push to {project_name}", - url=project_web_url, # Link to project - color=discord.Color.orange() # Or discord.Color.from_rgb(r, g, b) + url=project_web_url, # Link to project + color=discord.Color.orange(), # Or discord.Color.from_rgb(r, g, b) ) embed.set_author(name=user_name) - for commit in payload.get('commits', []): - commit_id_short = commit.get('id', 'N/A')[:7] - commit_msg = commit.get('message', 'No commit message.') - commit_url = commit.get('url', '#') - author_name = commit.get('author', {}).get('name', 'Unknown Author') + for commit in payload.get("commits", []): + commit_id_short = commit.get("id", "N/A")[:7] + commit_msg = commit.get("message", "No commit message.") + commit_url = commit.get("url", "#") + author_name = commit.get("author", {}).get("name", "Unknown Author") # Files changed, insertions/deletions (GitLab provides total counts) # GitLab commit objects don't directly list added/removed/modified files in the same way GitHub does per commit in a push. @@ -247,235 +299,358 @@ def format_gitlab_push_embed(payload: Dict[str, Any], repo_url: str) -> discord. field_value = ( f"Author: {author_name}\n" - f"Message: {commit_msg.splitlines()[0]}\n" # First line + f"Message: {commit_msg.splitlines()[0]}\n" # First line f"[View Commit]({commit_url})" ) - embed.add_field(name=f"Commit `{commit_id_short}`", value=field_value, inline=False) + embed.add_field( + name=f"Commit `{commit_id_short}`", value=field_value, inline=False + ) if len(embed.fields) >= 5: - embed.add_field(name="...", value=f"And {len(payload.get('commits')) - 5} more commits.", inline=False) + embed.add_field( + name="...", + value=f"And {len(payload.get('commits')) - 5} more commits.", + inline=False, + ) break - if not payload.get('commits'): + if not payload.get("commits"): embed.description = "Received push event with no commits (e.g., new branch created or tag pushed)." return embed except Exception as e: log.exception(f"Error formatting GitLab push embed: {e}") - embed = discord.Embed(title="Error Processing GitLab Push Webhook", description=f"Could not parse commit details. Raw payload might be available in logs.\nError: {e}", color=discord.Color.red()) + embed = discord.Embed( + title="Error Processing GitLab Push Webhook", + description=f"Could not parse commit details. Raw payload might be available in logs.\nError: {e}", + color=discord.Color.red(), + ) return embed + # --- GitHub - New Event Formatters --- + def format_github_issues_embed(payload: Dict[str, Any], repo_url: str) -> discord.Embed: """Formats a Discord embed for a GitHub issues event.""" try: - action = payload.get('action', 'Unknown action') - issue_data = payload.get('issue', {}) - repo_name = payload.get('repository', {}).get('full_name', repo_url) - sender = payload.get('sender', {}) + action = payload.get("action", "Unknown action") + issue_data = payload.get("issue", {}) + repo_name = payload.get("repository", {}).get("full_name", repo_url) + sender = payload.get("sender", {}) - title = issue_data.get('title', 'Untitled Issue') - issue_number = issue_data.get('number') - issue_url = issue_data.get('html_url', repo_url) - user_login = sender.get('login', 'Unknown User') - user_url = sender.get('html_url', '#') - user_avatar = sender.get('avatar_url') + title = issue_data.get("title", "Untitled Issue") + issue_number = issue_data.get("number") + issue_url = issue_data.get("html_url", repo_url) + user_login = sender.get("login", "Unknown User") + user_url = sender.get("html_url", "#") + user_avatar = sender.get("avatar_url") - color = discord.Color.green() if action == "opened" else \ - discord.Color.red() if action == "closed" else \ - discord.Color.gold() if action == "reopened" else \ - discord.Color.light_grey() + color = ( + discord.Color.green() + if action == "opened" + else ( + discord.Color.red() + if action == "closed" + else ( + discord.Color.gold() + if action == "reopened" + else discord.Color.light_grey() + ) + ) + ) embed = discord.Embed( title=f"Issue {action.capitalize()}: #{issue_number} {title}", url=issue_url, description=f"Issue in `{repo_name}` was {action}.", - color=color + color=color, ) embed.set_author(name=user_login, url=user_url, icon_url=user_avatar) - if issue_data.get('body') and action == "opened": - body = issue_data['body'] - embed.add_field(name="Description", value=body[:1020] + "..." if len(body) > 1024 else body, inline=False) + if issue_data.get("body") and action == "opened": + body = issue_data["body"] + embed.add_field( + name="Description", + value=body[:1020] + "..." if len(body) > 1024 else body, + inline=False, + ) - if issue_data.get('labels'): - labels = ", ".join([f"`{label['name']}`" for label in issue_data['labels']]) - embed.add_field(name="Labels", value=labels if labels else "None", inline=True) + if issue_data.get("labels"): + labels = ", ".join([f"`{label['name']}`" for label in issue_data["labels"]]) + embed.add_field( + name="Labels", value=labels if labels else "None", inline=True + ) - if issue_data.get('assignee'): - assignee = issue_data['assignee']['login'] - embed.add_field(name="Assignee", value=f"[{assignee}]({issue_data['assignee']['html_url']})", inline=True) - elif issue_data.get('assignees'): - assignees = ", ".join([f"[{a['login']}]({a['html_url']})" for a in issue_data['assignees']]) - embed.add_field(name="Assignees", value=assignees if assignees else "None", inline=True) + if issue_data.get("assignee"): + assignee = issue_data["assignee"]["login"] + embed.add_field( + name="Assignee", + value=f"[{assignee}]({issue_data['assignee']['html_url']})", + inline=True, + ) + elif issue_data.get("assignees"): + assignees = ", ".join( + [f"[{a['login']}]({a['html_url']})" for a in issue_data["assignees"]] + ) + embed.add_field( + name="Assignees", value=assignees if assignees else "None", inline=True + ) return embed except Exception as e: log.error(f"Error formatting GitHub issues embed: {e}\nPayload: {payload}") - return discord.Embed(title="Error Processing GitHub Issue Event", description=str(e), color=discord.Color.red()) + return discord.Embed( + title="Error Processing GitHub Issue Event", + description=str(e), + color=discord.Color.red(), + ) -def format_github_pull_request_embed(payload: Dict[str, Any], repo_url: str) -> discord.Embed: + +def format_github_pull_request_embed( + payload: Dict[str, Any], repo_url: str +) -> discord.Embed: """Formats a Discord embed for a GitHub pull_request event.""" try: - action = payload.get('action', 'Unknown action') - pr_data = payload.get('pull_request', {}) - repo_name = payload.get('repository', {}).get('full_name', repo_url) - sender = payload.get('sender', {}) + action = payload.get("action", "Unknown action") + pr_data = payload.get("pull_request", {}) + repo_name = payload.get("repository", {}).get("full_name", repo_url) + sender = payload.get("sender", {}) - title = pr_data.get('title', 'Untitled Pull Request') - pr_number = payload.get('number', pr_data.get('number')) # 'number' is top-level for some PR actions - pr_url = pr_data.get('html_url', repo_url) - user_login = sender.get('login', 'Unknown User') - user_url = sender.get('html_url', '#') - user_avatar = sender.get('avatar_url') + title = pr_data.get("title", "Untitled Pull Request") + pr_number = payload.get( + "number", pr_data.get("number") + ) # 'number' is top-level for some PR actions + pr_url = pr_data.get("html_url", repo_url) + user_login = sender.get("login", "Unknown User") + user_url = sender.get("html_url", "#") + user_avatar = sender.get("avatar_url") - color = discord.Color.green() if action == "opened" else \ - discord.Color.red() if action == "closed" and pr_data.get('merged') is False else \ - discord.Color.purple() if action == "closed" and pr_data.get('merged') is True else \ - discord.Color.gold() if action == "reopened" else \ - discord.Color.blue() if action in ["synchronize", "ready_for_review"] else \ - discord.Color.light_grey() + color = ( + discord.Color.green() + if action == "opened" + else ( + discord.Color.red() + if action == "closed" and pr_data.get("merged") is False + else ( + discord.Color.purple() + if action == "closed" and pr_data.get("merged") is True + else ( + discord.Color.gold() + if action == "reopened" + else ( + discord.Color.blue() + if action in ["synchronize", "ready_for_review"] + else discord.Color.light_grey() + ) + ) + ) + ) + ) description = f"Pull Request #{pr_number} in `{repo_name}` was {action}." - if action == "closed" and pr_data.get('merged'): + if action == "closed" and pr_data.get("merged"): description = f"Pull Request #{pr_number} in `{repo_name}` was merged." embed = discord.Embed( title=f"PR {action.capitalize()}: #{pr_number} {title}", url=pr_url, description=description, - color=color + color=color, ) embed.set_author(name=user_login, url=user_url, icon_url=user_avatar) - if pr_data.get('body') and action == "opened": - body = pr_data['body'] - embed.add_field(name="Description", value=body[:1020] + "..." if len(body) > 1024 else body, inline=False) + if pr_data.get("body") and action == "opened": + body = pr_data["body"] + embed.add_field( + name="Description", + value=body[:1020] + "..." if len(body) > 1024 else body, + inline=False, + ) - embed.add_field(name="Base Branch", value=f"`{pr_data.get('base', {}).get('ref', 'N/A')}`", inline=True) - embed.add_field(name="Head Branch", value=f"`{pr_data.get('head', {}).get('ref', 'N/A')}`", inline=True) + embed.add_field( + name="Base Branch", + value=f"`{pr_data.get('base', {}).get('ref', 'N/A')}`", + inline=True, + ) + embed.add_field( + name="Head Branch", + value=f"`{pr_data.get('head', {}).get('ref', 'N/A')}`", + inline=True, + ) if action == "closed": - merged_by = pr_data.get('merged_by') + merged_by = pr_data.get("merged_by") if merged_by: - embed.add_field(name="Merged By", value=f"[{merged_by['login']}]({merged_by['html_url']})", inline=True) + embed.add_field( + name="Merged By", + value=f"[{merged_by['login']}]({merged_by['html_url']})", + inline=True, + ) else: - embed.add_field(name="Status", value="Closed without merging", inline=True) - + embed.add_field( + name="Status", value="Closed without merging", inline=True + ) return embed except Exception as e: log.error(f"Error formatting GitHub PR embed: {e}\nPayload: {payload}") - return discord.Embed(title="Error Processing GitHub PR Event", description=str(e), color=discord.Color.red()) + return discord.Embed( + title="Error Processing GitHub PR Event", + description=str(e), + color=discord.Color.red(), + ) -def format_github_release_embed(payload: Dict[str, Any], repo_url: str) -> discord.Embed: + +def format_github_release_embed( + payload: Dict[str, Any], repo_url: str +) -> discord.Embed: """Formats a Discord embed for a GitHub release event.""" try: - action = payload.get('action', 'Unknown action') # e.g., published, created, edited - release_data = payload.get('release', {}) - repo_name = payload.get('repository', {}).get('full_name', repo_url) - sender = payload.get('sender', {}) + action = payload.get( + "action", "Unknown action" + ) # e.g., published, created, edited + release_data = payload.get("release", {}) + repo_name = payload.get("repository", {}).get("full_name", repo_url) + sender = payload.get("sender", {}) - tag_name = release_data.get('tag_name', 'N/A') - release_name = release_data.get('name', tag_name) - release_url = release_data.get('html_url', repo_url) - user_login = sender.get('login', 'Unknown User') - user_url = sender.get('html_url', '#') - user_avatar = sender.get('avatar_url') + tag_name = release_data.get("tag_name", "N/A") + release_name = release_data.get("name", tag_name) + release_url = release_data.get("html_url", repo_url) + user_login = sender.get("login", "Unknown User") + user_url = sender.get("html_url", "#") + user_avatar = sender.get("avatar_url") - color = discord.Color.teal() if action == "published" else discord.Color.blurple() + color = ( + discord.Color.teal() if action == "published" else discord.Color.blurple() + ) embed = discord.Embed( title=f"Release {action.capitalize()}: {release_name}", url=release_url, description=f"A new release `{tag_name}` was {action} in `{repo_name}`.", - color=color + color=color, ) embed.set_author(name=user_login, url=user_url, icon_url=user_avatar) - if release_data.get('body'): - body = release_data['body'] - embed.add_field(name="Release Notes", value=body[:1020] + "..." if len(body) > 1024 else body, inline=False) + if release_data.get("body"): + body = release_data["body"] + embed.add_field( + name="Release Notes", + value=body[:1020] + "..." if len(body) > 1024 else body, + inline=False, + ) return embed except Exception as e: log.error(f"Error formatting GitHub release embed: {e}\nPayload: {payload}") - return discord.Embed(title="Error Processing GitHub Release Event", description=str(e), color=discord.Color.red()) + return discord.Embed( + title="Error Processing GitHub Release Event", + description=str(e), + color=discord.Color.red(), + ) -def format_github_issue_comment_embed(payload: Dict[str, Any], repo_url: str) -> discord.Embed: + +def format_github_issue_comment_embed( + payload: Dict[str, Any], repo_url: str +) -> discord.Embed: """Formats a Discord embed for a GitHub issue_comment event.""" try: - action = payload.get('action', 'Unknown action') # created, edited, deleted - comment_data = payload.get('comment', {}) - issue_data = payload.get('issue', {}) - repo_name = payload.get('repository', {}).get('full_name', repo_url) - sender = payload.get('sender', {}) + action = payload.get("action", "Unknown action") # created, edited, deleted + comment_data = payload.get("comment", {}) + issue_data = payload.get("issue", {}) + repo_name = payload.get("repository", {}).get("full_name", repo_url) + sender = payload.get("sender", {}) - comment_url = comment_data.get('html_url', repo_url) - user_login = sender.get('login', 'Unknown User') - user_url = sender.get('html_url', '#') - user_avatar = sender.get('avatar_url') + comment_url = comment_data.get("html_url", repo_url) + user_login = sender.get("login", "Unknown User") + user_url = sender.get("html_url", "#") + user_avatar = sender.get("avatar_url") - issue_title = issue_data.get('title', 'Untitled Issue') - issue_number = issue_data.get('number') + issue_title = issue_data.get("title", "Untitled Issue") + issue_number = issue_data.get("number") color = discord.Color.greyple() embed = discord.Embed( title=f"Comment {action} on Issue #{issue_number}: {issue_title}", url=comment_url, - color=color + color=color, ) embed.set_author(name=user_login, url=user_url, icon_url=user_avatar) - if comment_data.get('body'): - body = comment_data['body'] + if comment_data.get("body"): + body = comment_data["body"] embed.description = body[:2040] + "..." if len(body) > 2048 else body return embed except Exception as e: - log.error(f"Error formatting GitHub issue_comment embed: {e}\nPayload: {payload}") - return discord.Embed(title="Error Processing GitHub Issue Comment Event", description=str(e), color=discord.Color.red()) + log.error( + f"Error formatting GitHub issue_comment embed: {e}\nPayload: {payload}" + ) + return discord.Embed( + title="Error Processing GitHub Issue Comment Event", + description=str(e), + color=discord.Color.red(), + ) + # --- GitLab - New Event Formatters --- + def format_gitlab_issue_embed(payload: Dict[str, Any], repo_url: str) -> discord.Embed: """Formats a Discord embed for a GitLab issue event (object_kind: 'issue').""" try: - attributes = payload.get('object_attributes', {}) - user = payload.get('user', {}) - project_data = payload.get('project', {}) - repo_name = project_data.get('path_with_namespace', repo_url) + attributes = payload.get("object_attributes", {}) + user = payload.get("user", {}) + project_data = payload.get("project", {}) + repo_name = project_data.get("path_with_namespace", repo_url) - action = attributes.get('action', 'unknown') # open, close, reopen, update - title = attributes.get('title', 'Untitled Issue') - issue_iid = attributes.get('iid') # Internal ID for display - issue_url = attributes.get('url', repo_url) - user_name = user.get('name', 'Unknown User') - user_avatar = user.get('avatar_url') + action = attributes.get("action", "unknown") # open, close, reopen, update + title = attributes.get("title", "Untitled Issue") + issue_iid = attributes.get("iid") # Internal ID for display + issue_url = attributes.get("url", repo_url) + user_name = user.get("name", "Unknown User") + user_avatar = user.get("avatar_url") - color = discord.Color.green() if action == "open" else \ - discord.Color.red() if action == "close" else \ - discord.Color.gold() if action == "reopen" else \ - discord.Color.light_grey() + color = ( + discord.Color.green() + if action == "open" + else ( + discord.Color.red() + if action == "close" + else ( + discord.Color.gold() + if action == "reopen" + else discord.Color.light_grey() + ) + ) + ) embed = discord.Embed( title=f"Issue {action.capitalize()}: #{issue_iid} {title}", url=issue_url, description=f"Issue in `{repo_name}` was {action}.", - color=color + color=color, ) embed.set_author(name=user_name, icon_url=user_avatar) - if attributes.get('description') and action == "open": - desc = attributes['description'] - embed.add_field(name="Description", value=desc[:1020] + "..." if len(desc) > 1024 else desc, inline=False) + if attributes.get("description") and action == "open": + desc = attributes["description"] + embed.add_field( + name="Description", + value=desc[:1020] + "..." if len(desc) > 1024 else desc, + inline=False, + ) - if attributes.get('labels'): - labels = ", ".join([f"`{label['title']}`" for label in attributes['labels']]) - embed.add_field(name="Labels", value=labels if labels else "None", inline=True) + if attributes.get("labels"): + labels = ", ".join( + [f"`{label['title']}`" for label in attributes["labels"]] + ) + embed.add_field( + name="Labels", value=labels if labels else "None", inline=True + ) - assignees_data = payload.get('assignees', []) + assignees_data = payload.get("assignees", []) if assignees_data: assignees = ", ".join([f"{a['name']}" for a in assignees_data]) embed.add_field(name="Assignees", value=assignees, inline=True) @@ -483,29 +658,53 @@ def format_gitlab_issue_embed(payload: Dict[str, Any], repo_url: str) -> discord return embed except Exception as e: log.error(f"Error formatting GitLab issue embed: {e}\nPayload: {payload}") - return discord.Embed(title="Error Processing GitLab Issue Event", description=str(e), color=discord.Color.red()) + return discord.Embed( + title="Error Processing GitLab Issue Event", + description=str(e), + color=discord.Color.red(), + ) -def format_gitlab_merge_request_embed(payload: Dict[str, Any], repo_url: str) -> discord.Embed: + +def format_gitlab_merge_request_embed( + payload: Dict[str, Any], repo_url: str +) -> discord.Embed: """Formats a Discord embed for a GitLab merge_request event.""" try: - attributes = payload.get('object_attributes', {}) - user = payload.get('user', {}) - project_data = payload.get('project', {}) - repo_name = project_data.get('path_with_namespace', repo_url) + attributes = payload.get("object_attributes", {}) + user = payload.get("user", {}) + project_data = payload.get("project", {}) + repo_name = project_data.get("path_with_namespace", repo_url) - action = attributes.get('action', 'unknown') # open, close, reopen, update, merge - title = attributes.get('title', 'Untitled Merge Request') - mr_iid = attributes.get('iid') - mr_url = attributes.get('url', repo_url) - user_name = user.get('name', 'Unknown User') - user_avatar = user.get('avatar_url') + action = attributes.get( + "action", "unknown" + ) # open, close, reopen, update, merge + title = attributes.get("title", "Untitled Merge Request") + mr_iid = attributes.get("iid") + mr_url = attributes.get("url", repo_url) + user_name = user.get("name", "Unknown User") + user_avatar = user.get("avatar_url") - color = discord.Color.green() if action == "open" else \ - discord.Color.red() if action == "close" else \ - discord.Color.purple() if action == "merge" else \ - discord.Color.gold() if action == "reopen" else \ - discord.Color.blue() if action == "update" else \ - discord.Color.light_grey() + color = ( + discord.Color.green() + if action == "open" + else ( + discord.Color.red() + if action == "close" + else ( + discord.Color.purple() + if action == "merge" + else ( + discord.Color.gold() + if action == "reopen" + else ( + discord.Color.blue() + if action == "update" + else discord.Color.light_grey() + ) + ) + ) + ) + ) description = f"Merge Request !{mr_iid} in `{repo_name}` was {action}." if action == "merge": @@ -515,35 +714,58 @@ def format_gitlab_merge_request_embed(payload: Dict[str, Any], repo_url: str) -> title=f"MR {action.capitalize()}: !{mr_iid} {title}", url=mr_url, description=description, - color=color + color=color, ) embed.set_author(name=user_name, icon_url=user_avatar) - if attributes.get('description') and action == "open": - desc = attributes['description'] - embed.add_field(name="Description", value=desc[:1020] + "..." if len(desc) > 1024 else desc, inline=False) + if attributes.get("description") and action == "open": + desc = attributes["description"] + embed.add_field( + name="Description", + value=desc[:1020] + "..." if len(desc) > 1024 else desc, + inline=False, + ) - embed.add_field(name="Source Branch", value=f"`{attributes.get('source_branch', 'N/A')}`", inline=True) - embed.add_field(name="Target Branch", value=f"`{attributes.get('target_branch', 'N/A')}`", inline=True) + embed.add_field( + name="Source Branch", + value=f"`{attributes.get('source_branch', 'N/A')}`", + inline=True, + ) + embed.add_field( + name="Target Branch", + value=f"`{attributes.get('target_branch', 'N/A')}`", + inline=True, + ) - if action == "merge" and attributes.get('merge_commit_sha'): - embed.add_field(name="Merge Commit", value=f"`{attributes['merge_commit_sha'][:8]}`", inline=True) + if action == "merge" and attributes.get("merge_commit_sha"): + embed.add_field( + name="Merge Commit", + value=f"`{attributes['merge_commit_sha'][:8]}`", + inline=True, + ) return embed except Exception as e: log.error(f"Error formatting GitLab MR embed: {e}\nPayload: {payload}") - return discord.Embed(title="Error Processing GitLab MR Event", description=str(e), color=discord.Color.red()) + return discord.Embed( + title="Error Processing GitLab MR Event", + description=str(e), + color=discord.Color.red(), + ) -def format_gitlab_release_embed(payload: Dict[str, Any], repo_url: str) -> discord.Embed: + +def format_gitlab_release_embed( + payload: Dict[str, Any], repo_url: str +) -> discord.Embed: """Formats a Discord embed for a GitLab release event.""" try: # GitLab release webhook payload structure is simpler - action = payload.get('action', 'created') # create, update - tag_name = payload.get('tag', 'N/A') - release_name = payload.get('name', tag_name) - release_url = payload.get('url', repo_url) - project_data = payload.get('project', {}) - repo_name = project_data.get('path_with_namespace', repo_url) + action = payload.get("action", "created") # create, update + tag_name = payload.get("tag", "N/A") + release_name = payload.get("name", tag_name) + release_url = payload.get("url", repo_url) + project_data = payload.get("project", {}) + repo_name = project_data.get("path_with_namespace", repo_url) # GitLab release hooks don't typically include a 'user' who performed the action directly in the root. # It might be inferred or logged differently by GitLab. For now, we'll omit a specific author. @@ -553,74 +775,93 @@ def format_gitlab_release_embed(payload: Dict[str, Any], repo_url: str) -> disco title=f"Release {action.capitalize()}: {release_name}", url=release_url, description=f"A release `{tag_name}` was {action} in `{repo_name}`.", - color=color + color=color, ) # embed.set_author(name=project_data.get('namespace', 'GitLab')) # Or project name - if payload.get('description'): - desc = payload['description'] - embed.add_field(name="Release Notes", value=desc[:1020] + "..." if len(desc) > 1024 else desc, inline=False) + if payload.get("description"): + desc = payload["description"] + embed.add_field( + name="Release Notes", + value=desc[:1020] + "..." if len(desc) > 1024 else desc, + inline=False, + ) return embed except Exception as e: log.error(f"Error formatting GitLab release embed: {e}\nPayload: {payload}") - return discord.Embed(title="Error Processing GitLab Release Event", description=str(e), color=discord.Color.red()) + return discord.Embed( + title="Error Processing GitLab Release Event", + description=str(e), + color=discord.Color.red(), + ) + def format_gitlab_note_embed(payload: Dict[str, Any], repo_url: str) -> discord.Embed: """Formats a Discord embed for a GitLab note event (comments).""" try: - attributes = payload.get('object_attributes', {}) - user = payload.get('user', {}) - project_data = payload.get('project', {}) - repo_name = project_data.get('path_with_namespace', repo_url) + attributes = payload.get("object_attributes", {}) + user = payload.get("user", {}) + project_data = payload.get("project", {}) + repo_name = project_data.get("path_with_namespace", repo_url) - note_type = attributes.get('noteable_type', 'Comment') # Issue, MergeRequest, Commit, Snippet - note_url = attributes.get('url', repo_url) - user_name = user.get('name', 'Unknown User') - user_avatar = user.get('avatar_url') + note_type = attributes.get( + "noteable_type", "Comment" + ) # Issue, MergeRequest, Commit, Snippet + note_url = attributes.get("url", repo_url) + user_name = user.get("name", "Unknown User") + user_avatar = user.get("avatar_url") title_prefix = "New Comment" target_info = "" - if note_type == 'Commit': - commit_data = payload.get('commit', {}) + if note_type == "Commit": + commit_data = payload.get("commit", {}) title_prefix = f"Comment on Commit `{commit_data.get('id', 'N/A')[:7]}`" - elif note_type == 'Issue': - issue_data = payload.get('issue', {}) + elif note_type == "Issue": + issue_data = payload.get("issue", {}) title_prefix = f"Comment on Issue #{issue_data.get('iid', 'N/A')}" - target_info = issue_data.get('title', '') - elif note_type == 'MergeRequest': - mr_data = payload.get('merge_request', {}) + target_info = issue_data.get("title", "") + elif note_type == "MergeRequest": + mr_data = payload.get("merge_request", {}) title_prefix = f"Comment on MR !{mr_data.get('iid', 'N/A')}" - target_info = mr_data.get('title', '') - elif note_type == 'Snippet': - snippet_data = payload.get('snippet', {}) + target_info = mr_data.get("title", "") + elif note_type == "Snippet": + snippet_data = payload.get("snippet", {}) title_prefix = f"Comment on Snippet #{snippet_data.get('id', 'N/A')}" - target_info = snippet_data.get('title', '') + target_info = snippet_data.get("title", "") embed = discord.Embed( title=f"{title_prefix}: {target_info}".strip(), url=note_url, - color=discord.Color.greyple() + color=discord.Color.greyple(), ) embed.set_author(name=user_name, icon_url=user_avatar) - if attributes.get('note'): - note_body = attributes['note'] - embed.description = note_body[:2040] + "..." if len(note_body) > 2048 else note_body + if attributes.get("note"): + note_body = attributes["note"] + embed.description = ( + note_body[:2040] + "..." if len(note_body) > 2048 else note_body + ) embed.set_footer(text=f"Comment in {repo_name}") return embed except Exception as e: log.error(f"Error formatting GitLab note embed: {e}\nPayload: {payload}") - return discord.Embed(title="Error Processing GitLab Note Event", description=str(e), color=discord.Color.red()) + return discord.Embed( + title="Error Processing GitLab Note Event", + description=str(e), + color=discord.Color.red(), + ) @router.post("/github/{repo_db_id}") async def webhook_github( request: Request, - repo_db_id: int = Path(..., description="The database ID of the monitored repository"), - x_hub_signature_256: Optional[str] = Header(None) + repo_db_id: int = Path( + ..., description="The database ID of the monitored repository" + ), + x_hub_signature_256: Optional[str] = Header(None), ): log.info(f"Received GitHub webhook for repo_db_id: {repo_db_id}") payload_bytes = await request.body() @@ -629,48 +870,83 @@ async def webhook_github( repo_config = await get_monitored_repository_by_id_api(request, repo_db_id) if not repo_config: log.error(f"No repository configuration found for repo_db_id: {repo_db_id}") - raise HTTPException(status_code=404, detail="Repository configuration not found.") + raise HTTPException( + status_code=404, detail="Repository configuration not found." + ) - if repo_config['monitoring_method'] != 'webhook' or repo_config['platform'] != 'github': + if ( + repo_config["monitoring_method"] != "webhook" + or repo_config["platform"] != "github" + ): log.error(f"Repository {repo_db_id} is not configured for GitHub webhooks.") - raise HTTPException(status_code=400, detail="Repository not configured for GitHub webhooks.") + raise HTTPException( + status_code=400, detail="Repository not configured for GitHub webhooks." + ) - if not verify_github_signature(payload_bytes, repo_config['webhook_secret'], x_hub_signature_256): + if not verify_github_signature( + payload_bytes, repo_config["webhook_secret"], x_hub_signature_256 + ): log.warning(f"Invalid GitHub signature for repo_db_id: {repo_db_id}") raise HTTPException(status_code=403, detail="Invalid signature.") try: - payload = json.loads(payload_bytes.decode('utf-8')) + payload = json.loads(payload_bytes.decode("utf-8")) except json.JSONDecodeError: - log.error(f"Invalid JSON payload received for GitHub webhook, repo_db_id: {repo_db_id}") + log.error( + f"Invalid JSON payload received for GitHub webhook, repo_db_id: {repo_db_id}" + ) raise HTTPException(status_code=400, detail="Invalid JSON payload.") log.debug(f"GitHub webhook payload for {repo_db_id}: {payload}") event_type = request.headers.get("X-GitHub-Event") - allowed_events = repo_config.get('allowed_webhook_events', ['push']) # Default to 'push' + allowed_events = repo_config.get( + "allowed_webhook_events", ["push"] + ) # Default to 'push' if event_type not in allowed_events: - log.info(f"Ignoring GitHub event type '{event_type}' for repo_db_id: {repo_db_id} as it's not in allowed events: {allowed_events}") - return {"status": "success", "message": f"Event type '{event_type}' ignored per configuration."} + log.info( + f"Ignoring GitHub event type '{event_type}' for repo_db_id: {repo_db_id} as it's not in allowed events: {allowed_events}" + ) + return { + "status": "success", + "message": f"Event type '{event_type}' ignored per configuration.", + } discord_embed = None if event_type == "push": - if not payload.get('commits') and not payload.get('deleted', False): # Also consider branch deletion as a push event - log.info(f"GitHub push event for {repo_db_id} has no commits and is not a delete event. Ignoring.") - return {"status": "success", "message": "Push event with no commits ignored."} - discord_embed = format_github_push_embed(payload, repo_config['repository_url']) + if not payload.get("commits") and not payload.get( + "deleted", False + ): # Also consider branch deletion as a push event + log.info( + f"GitHub push event for {repo_db_id} has no commits and is not a delete event. Ignoring." + ) + return { + "status": "success", + "message": "Push event with no commits ignored.", + } + discord_embed = format_github_push_embed(payload, repo_config["repository_url"]) elif event_type == "issues": - discord_embed = format_github_issues_embed(payload, repo_config['repository_url']) + discord_embed = format_github_issues_embed( + payload, repo_config["repository_url"] + ) elif event_type == "pull_request": - discord_embed = format_github_pull_request_embed(payload, repo_config['repository_url']) + discord_embed = format_github_pull_request_embed( + payload, repo_config["repository_url"] + ) elif event_type == "release": - discord_embed = format_github_release_embed(payload, repo_config['repository_url']) + discord_embed = format_github_release_embed( + payload, repo_config["repository_url"] + ) elif event_type == "issue_comment": - discord_embed = format_github_issue_comment_embed(payload, repo_config['repository_url']) + discord_embed = format_github_issue_comment_embed( + payload, repo_config["repository_url"] + ) # Add other specific event types above this else block else: - log.info(f"GitHub event type '{event_type}' is allowed but not yet handled by a specific formatter for repo_db_id: {repo_db_id}. Sending generic message.") + log.info( + f"GitHub event type '{event_type}' is allowed but not yet handled by a specific formatter for repo_db_id: {repo_db_id}. Sending generic message." + ) # For unhandled but allowed events, send a generic notification or log. # For now, we'll just acknowledge. If you want to notify for all allowed events, create generic formatter. # return {"status": "success", "message": f"Event type '{event_type}' received but no specific formatter yet."} @@ -678,20 +954,37 @@ async def webhook_github( embed_title = f"GitHub Event: {event_type.replace('_', ' ').title()} in {repo_config.get('repository_url')}" embed_description = f"Received a '{event_type}' event." # Try to get a relevant URL - action_url = payload.get('repository', {}).get('html_url', '#') - if event_type == 'issues' and 'issue' in payload and 'html_url' in payload['issue']: - action_url = payload['issue']['html_url'] - elif event_type == 'pull_request' and 'pull_request' in payload and 'html_url' in payload['pull_request']: - action_url = payload['pull_request']['html_url'] - - discord_embed = discord.Embed(title=embed_title, description=embed_description, url=action_url, color=discord.Color.light_grey()) + action_url = payload.get("repository", {}).get("html_url", "#") + if ( + event_type == "issues" + and "issue" in payload + and "html_url" in payload["issue"] + ): + action_url = payload["issue"]["html_url"] + elif ( + event_type == "pull_request" + and "pull_request" in payload + and "html_url" in payload["pull_request"] + ): + action_url = payload["pull_request"]["html_url"] + discord_embed = discord.Embed( + title=embed_title, + description=embed_description, + url=action_url, + color=discord.Color.light_grey(), + ) if not discord_embed: - log.warning(f"No embed generated for allowed GitHub event '{event_type}' for repo {repo_db_id}. This shouldn't happen if event is handled.") - return {"status": "error", "message": "Embed generation failed for an allowed event."} + log.warning( + f"No embed generated for allowed GitHub event '{event_type}' for repo {repo_db_id}. This shouldn't happen if event is handled." + ) + return { + "status": "error", + "message": "Embed generation failed for an allowed event.", + } - notification_channel_id = repo_config['notification_channel_id'] + notification_channel_id = repo_config["notification_channel_id"] # Convert embed to dict for sending via API send_payload_dict = {"embeds": [discord_embed.to_dict()]} @@ -699,33 +992,52 @@ async def webhook_github( # Use the send_discord_message_via_api from api_server.py # This requires DISCORD_BOT_TOKEN to be set in the environment for api_server if not api_settings.DISCORD_BOT_TOKEN: - log.error("DISCORD_BOT_TOKEN not configured in API settings. Cannot send webhook notification.") + log.error( + "DISCORD_BOT_TOKEN not configured in API settings. Cannot send webhook notification." + ) # Still return 200 to GitHub to acknowledge receipt, but log error. - return {"status": "error", "message": "Notification sending failed (bot token not configured)."} + return { + "status": "error", + "message": "Notification sending failed (bot token not configured).", + } - log.info(f"Sending GitHub notification to channel {notification_channel_id} for repo {repo_db_id}.") + log.info( + f"Sending GitHub notification to channel {notification_channel_id} for repo {repo_db_id}." + ) # Send the embed using the send_discord_message_via_api function # The function can handle dict content with embeds send_result = await send_discord_message_via_api( channel_id=notification_channel_id, - content=send_payload_dict # Pass the dict directly + content=send_payload_dict, # Pass the dict directly ) if send_result.get("success"): - log.info(f"Successfully sent GitHub webhook notification for repo {repo_db_id} to channel {notification_channel_id}.") - return {"status": "success", "message": "Webhook received and notification sent."} + log.info( + f"Successfully sent GitHub webhook notification for repo {repo_db_id} to channel {notification_channel_id}." + ) + return { + "status": "success", + "message": "Webhook received and notification sent.", + } else: - log.error(f"Failed to send GitHub webhook notification for repo {repo_db_id}. Error: {send_result.get('message')}") + log.error( + f"Failed to send GitHub webhook notification for repo {repo_db_id}. Error: {send_result.get('message')}" + ) # Still return 200 to GitHub to acknowledge receipt, but log the internal failure. - return {"status": "error", "message": f"Webhook received, but notification failed: {send_result.get('message')}"} + return { + "status": "error", + "message": f"Webhook received, but notification failed: {send_result.get('message')}", + } @router.post("/gitlab/{repo_db_id}") async def webhook_gitlab( request: Request, - repo_db_id: int = Path(..., description="The database ID of the monitored repository"), - x_gitlab_token: Optional[str] = Header(None) + repo_db_id: int = Path( + ..., description="The database ID of the monitored repository" + ), + x_gitlab_token: Optional[str] = Header(None), ): log.info(f"Received GitLab webhook for repo_db_id: {repo_db_id}") payload_bytes = await request.body() @@ -734,27 +1046,38 @@ async def webhook_gitlab( repo_config = await get_monitored_repository_by_id_api(request, repo_db_id) if not repo_config: log.error(f"No repository configuration found for repo_db_id: {repo_db_id}") - raise HTTPException(status_code=404, detail="Repository configuration not found.") + raise HTTPException( + status_code=404, detail="Repository configuration not found." + ) - if repo_config['monitoring_method'] != 'webhook' or repo_config['platform'] != 'gitlab': + if ( + repo_config["monitoring_method"] != "webhook" + or repo_config["platform"] != "gitlab" + ): log.error(f"Repository {repo_db_id} is not configured for GitLab webhooks.") - raise HTTPException(status_code=400, detail="Repository not configured for GitLab webhooks.") + raise HTTPException( + status_code=400, detail="Repository not configured for GitLab webhooks." + ) - if not verify_gitlab_token(repo_config['webhook_secret'], x_gitlab_token): + if not verify_gitlab_token(repo_config["webhook_secret"], x_gitlab_token): log.warning(f"Invalid GitLab token for repo_db_id: {repo_db_id}") raise HTTPException(status_code=403, detail="Invalid token.") try: - payload = json.loads(payload_bytes.decode('utf-8')) + payload = json.loads(payload_bytes.decode("utf-8")) except json.JSONDecodeError: - log.error(f"Invalid JSON payload received for GitLab webhook, repo_db_id: {repo_db_id}") + log.error( + f"Invalid JSON payload received for GitLab webhook, repo_db_id: {repo_db_id}" + ) raise HTTPException(status_code=400, detail="Invalid JSON payload.") log.debug(f"GitLab webhook payload for {repo_db_id}: {payload}") # GitLab uses 'object_kind' for event type, or 'event_name' for system hooks event_type = payload.get("object_kind", payload.get("event_name")) - allowed_events = repo_config.get('allowed_webhook_events', ['push']) # Default to 'push' (GitLab calls push hooks 'push events' or 'tag_push events') + allowed_events = repo_config.get( + "allowed_webhook_events", ["push"] + ) # Default to 'push' (GitLab calls push hooks 'push events' or 'tag_push events') # Normalize GitLab event types if needed, e.g. 'push' for 'push_hook' or 'tag_push_hook' # For now, assume direct match or that 'push' covers both. @@ -764,87 +1087,141 @@ async def webhook_gitlab( # Let's simplify: if 'push' is in allowed_events, we'll accept 'push' and 'tag_push' object_kinds. effective_event_type = event_type - if event_type == "tag_push" and "push" in allowed_events and "tag_push" not in allowed_events: + if ( + event_type == "tag_push" + and "push" in allowed_events + and "tag_push" not in allowed_events + ): # If only "push" is allowed, but we receive "tag_push", treat it as a push for now. # This logic might need refinement based on how granular the user wants control. - pass # It will be caught by the 'push' check if 'push' is allowed. + pass # It will be caught by the 'push' check if 'push' is allowed. is_event_allowed = False if event_type in allowed_events: is_event_allowed = True - elif event_type == "tag_push" and "push" in allowed_events: # Special handling if 'push' implies 'tag_push' + elif ( + event_type == "tag_push" and "push" in allowed_events + ): # Special handling if 'push' implies 'tag_push' is_event_allowed = True - effective_event_type = "push" # Treat as push for formatter if only push is configured + effective_event_type = ( + "push" # Treat as push for formatter if only push is configured + ) if not is_event_allowed: - log.info(f"Ignoring GitLab event type '{event_type}' (object_kind/event_name) for repo_db_id: {repo_db_id} as it's not in allowed events: {allowed_events}") - return {"status": "success", "message": f"Event type '{event_type}' ignored per configuration."} + log.info( + f"Ignoring GitLab event type '{event_type}' (object_kind/event_name) for repo_db_id: {repo_db_id} as it's not in allowed events: {allowed_events}" + ) + return { + "status": "success", + "message": f"Event type '{event_type}' ignored per configuration.", + } discord_embed = None # Use effective_event_type for choosing formatter - if effective_event_type == "push": # This will catch 'push' and 'tag_push' if 'push' is allowed - if not payload.get('commits') and payload.get('total_commits_count', 0) == 0: - log.info(f"GitLab push event for {repo_db_id} has no commits. Ignoring.") - return {"status": "success", "message": "Push event with no commits ignored."} - discord_embed = format_gitlab_push_embed(payload, repo_config['repository_url']) - elif effective_event_type == "issue": # Matches object_kind 'issue' - discord_embed = format_gitlab_issue_embed(payload, repo_config['repository_url']) + if ( + effective_event_type == "push" + ): # This will catch 'push' and 'tag_push' if 'push' is allowed + if not payload.get("commits") and payload.get("total_commits_count", 0) == 0: + log.info(f"GitLab push event for {repo_db_id} has no commits. Ignoring.") + return { + "status": "success", + "message": "Push event with no commits ignored.", + } + discord_embed = format_gitlab_push_embed(payload, repo_config["repository_url"]) + elif effective_event_type == "issue": # Matches object_kind 'issue' + discord_embed = format_gitlab_issue_embed( + payload, repo_config["repository_url"] + ) elif effective_event_type == "merge_request": - discord_embed = format_gitlab_merge_request_embed(payload, repo_config['repository_url']) + discord_embed = format_gitlab_merge_request_embed( + payload, repo_config["repository_url"] + ) elif effective_event_type == "release": - discord_embed = format_gitlab_release_embed(payload, repo_config['repository_url']) - elif effective_event_type == "note": # For comments - discord_embed = format_gitlab_note_embed(payload, repo_config['repository_url']) + discord_embed = format_gitlab_release_embed( + payload, repo_config["repository_url"] + ) + elif effective_event_type == "note": # For comments + discord_embed = format_gitlab_note_embed(payload, repo_config["repository_url"]) # Add other specific event types above this else block else: - log.info(f"GitLab event type '{event_type}' (effective: {effective_event_type}) is allowed but not yet handled by a specific formatter for repo_db_id: {repo_db_id}. Sending generic message.") + log.info( + f"GitLab event type '{event_type}' (effective: {effective_event_type}) is allowed but not yet handled by a specific formatter for repo_db_id: {repo_db_id}. Sending generic message." + ) embed_title = f"GitLab Event: {event_type.replace('_', ' ').title()} in {repo_config.get('repository_url')}" embed_description = f"Received a '{event_type}' event." - action_url = payload.get('project', {}).get('web_url', '#') + action_url = payload.get("project", {}).get("web_url", "#") # Try to get more specific URLs for common GitLab events - if 'object_attributes' in payload and 'url' in payload['object_attributes']: - action_url = payload['object_attributes']['url'] - elif 'project' in payload and 'web_url' in payload['project']: - action_url = payload['project']['web_url'] - - discord_embed = discord.Embed(title=embed_title, description=embed_description, url=action_url, color=discord.Color.dark_orange()) + if "object_attributes" in payload and "url" in payload["object_attributes"]: + action_url = payload["object_attributes"]["url"] + elif "project" in payload and "web_url" in payload["project"]: + action_url = payload["project"]["web_url"] + discord_embed = discord.Embed( + title=embed_title, + description=embed_description, + url=action_url, + color=discord.Color.dark_orange(), + ) if not discord_embed: - log.warning(f"No embed generated for allowed GitLab event '{event_type}' for repo {repo_db_id}.") - return {"status": "error", "message": "Embed generation failed for an allowed event."} + log.warning( + f"No embed generated for allowed GitLab event '{event_type}' for repo {repo_db_id}." + ) + return { + "status": "error", + "message": "Embed generation failed for an allowed event.", + } - notification_channel_id = repo_config['notification_channel_id'] + notification_channel_id = repo_config["notification_channel_id"] # Use the send_discord_message_via_api from api_server.py # This requires DISCORD_BOT_TOKEN to be set in the environment for api_server if not api_settings.DISCORD_BOT_TOKEN: - log.error("DISCORD_BOT_TOKEN not configured in API settings. Cannot send webhook notification.") - return {"status": "error", "message": "Notification sending failed (bot token not configured)."} + log.error( + "DISCORD_BOT_TOKEN not configured in API settings. Cannot send webhook notification." + ) + return { + "status": "error", + "message": "Notification sending failed (bot token not configured).", + } # Convert embed to dict for sending via API send_payload_dict = {"embeds": [discord_embed.to_dict()]} - log.info(f"Sending GitLab notification to channel {notification_channel_id} for repo {repo_db_id}.") + log.info( + f"Sending GitLab notification to channel {notification_channel_id} for repo {repo_db_id}." + ) # Send the embed using the send_discord_message_via_api function # The function can handle dict content with embeds send_result = await send_discord_message_via_api( channel_id=notification_channel_id, - content=send_payload_dict # Pass the dict directly + content=send_payload_dict, # Pass the dict directly ) if send_result.get("success"): - log.info(f"Successfully sent GitLab webhook notification for repo {repo_db_id} to channel {notification_channel_id}.") - return {"status": "success", "message": "Webhook received and notification sent."} + log.info( + f"Successfully sent GitLab webhook notification for repo {repo_db_id} to channel {notification_channel_id}." + ) + return { + "status": "success", + "message": "Webhook received and notification sent.", + } else: - log.error(f"Failed to send GitLab webhook notification for repo {repo_db_id}. Error: {send_result.get('message')}") - return {"status": "error", "message": f"Webhook received, but notification failed: {send_result.get('message')}"} + log.error( + f"Failed to send GitLab webhook notification for repo {repo_db_id}. Error: {send_result.get('message')}" + ) + return { + "status": "error", + "message": f"Webhook received, but notification failed: {send_result.get('message')}", + } + @router.get("/test") async def test_webhook_router(): return {"message": "Webhook router is working. Or mounted, at least."} + @router.get("/test-repo/{repo_db_id}") async def test_repo_retrieval(request: Request, repo_db_id: int): """Test endpoint to check if we can retrieve repository information.""" @@ -853,23 +1230,18 @@ async def test_repo_retrieval(request: Request, repo_db_id: int): repo_config = await get_monitored_repository_by_id_api(request, repo_db_id) if repo_config: - return { - "message": "Repository found", - "repo_config": repo_config - } + return {"message": "Repository found", "repo_config": repo_config} else: - return { - "message": "Repository not found", - "repo_db_id": repo_db_id - } + return {"message": "Repository not found", "repo_db_id": repo_db_id} except Exception as e: log.exception(f"Error retrieving repository {repo_db_id}: {e}") return { "message": "Error retrieving repository", "repo_db_id": repo_db_id, - "error": str(e) + "error": str(e), } + @router.get("/test-db") async def test_db_connection(request: Request): """Test endpoint to check if the database connection is working.""" @@ -881,20 +1253,25 @@ async def test_db_connection(request: Request): # Try to get the PostgreSQL pool from the FastAPI app state pg_pool = getattr(request.app.state, "pg_pool", None) if not pg_pool: - log.warning("API service PostgreSQL pool not available for test-db endpoint.") + log.warning( + "API service PostgreSQL pool not available for test-db endpoint." + ) # Try to create a new connection try: import asyncpg + settings = get_api_settings() - log.info("Attempting to create a new PostgreSQL connection for test-db endpoint") + log.info( + "Attempting to create a new PostgreSQL connection for test-db endpoint" + ) # Create a new connection to the database conn = await asyncpg.connect( user=settings.POSTGRES_USER, password=settings.POSTGRES_PASSWORD, host=settings.POSTGRES_HOST, - database=settings.POSTGRES_SETTINGS_DB + database=settings.POSTGRES_SETTINGS_DB, ) # Test query @@ -907,15 +1284,17 @@ async def test_db_connection(request: Request): "message": "Database connection successful using direct connection", "app_state_attrs": state_attrs, "pg_pool_available": False, - "version": version + "version": version, } except Exception as e: - log.exception(f"Failed to create a new PostgreSQL connection for test-db endpoint: {e}") + log.exception( + f"Failed to create a new PostgreSQL connection for test-db endpoint: {e}" + ) return { "message": "Database connection failed using direct connection", "app_state_attrs": state_attrs, "pg_pool_available": False, - "error": str(e) + "error": str(e), } # Use the pool @@ -926,7 +1305,7 @@ async def test_db_connection(request: Request): "message": "Database connection successful using app.state.pg_pool", "app_state_attrs": state_attrs, "pg_pool_available": True, - "version": version + "version": version, } except Exception as e: log.exception(f"Database error using app.state.pg_pool: {e}") @@ -934,11 +1313,8 @@ async def test_db_connection(request: Request): "message": "Database connection failed using app.state.pg_pool", "app_state_attrs": state_attrs, "pg_pool_available": True, - "error": str(e) + "error": str(e), } except Exception as e: log.exception(f"Unexpected error in test-db endpoint: {e}") - return { - "message": "Unexpected error in test-db endpoint", - "error": str(e) - } + return {"message": "Unexpected error in test-db endpoint", "error": str(e)} diff --git a/check_sync_dependencies.py b/check_sync_dependencies.py index f158c30..97c24b7 100644 --- a/check_sync_dependencies.py +++ b/check_sync_dependencies.py @@ -2,19 +2,24 @@ import importlib.util import subprocess import sys + def check_and_install_dependencies(): """Check if required dependencies are installed and install them if not.""" required_packages = ["fastapi", "uvicorn", "pydantic"] missing_packages = [] - + for package in required_packages: if importlib.util.find_spec(package) is None: missing_packages.append(package) - + if missing_packages: - print(f"Installing missing dependencies for Discord sync: {', '.join(missing_packages)}") + print( + f"Installing missing dependencies for Discord sync: {', '.join(missing_packages)}" + ) try: - subprocess.check_call([sys.executable, "-m", "pip", "install"] + missing_packages) + subprocess.check_call( + [sys.executable, "-m", "pip", "install"] + missing_packages + ) print("Dependencies installed successfully.") return True except subprocess.CalledProcessError as e: @@ -23,8 +28,9 @@ def check_and_install_dependencies(): for package in missing_packages: print(f" - {package}") return False - + return True + if __name__ == "__main__": check_and_install_dependencies() diff --git a/cogs/VoiceGatewayCog.py b/cogs/VoiceGatewayCog.py index c78b8b8..06c26a5 100644 --- a/cogs/VoiceGatewayCog.py +++ b/cogs/VoiceGatewayCog.py @@ -3,11 +3,11 @@ from discord.ext import commands import asyncio import os import tempfile -import wave # For saving audio data -import functools # Added for partial -import subprocess # For audio conversion -from discord.ext import voice_recv # For receiving voice -from typing import Optional # For type hinting +import wave # For saving audio data +import functools # Added for partial +import subprocess # For audio conversion +from discord.ext import voice_recv # For receiving voice +from typing import Optional # For type hinting # Gurt specific imports from gurt import config as GurtConfig @@ -16,34 +16,38 @@ from gurt import config as GurtConfig try: from google.cloud import speech except ImportError: - print("Google Cloud Speech library not found. Please install with 'pip install google-cloud-speech'") + print( + "Google Cloud Speech library not found. Please install with 'pip install google-cloud-speech'" + ) speech = None try: import webrtcvad except ImportError: - print("webrtcvad library not found. Please install with 'pip install webrtc-voice-activity-detector'") + print( + "webrtcvad library not found. Please install with 'pip install webrtc-voice-activity-detector'" + ) webrtcvad = None # OpusDecoder is no longer needed as discord-ext-voice-recv provides PCM. FFMPEG_OPTIONS = { # 'before_options': '-reconnect 1 -reconnect_streamed 1 -reconnect_delay_max 5', # Removed as these are for network streams and might cause issues with local files - 'options': '-vn' + "options": "-vn" } # Constants for audio processing SAMPLE_RATE = 16000 # Whisper prefers 16kHz -CHANNELS = 1 # Mono -SAMPLE_WIDTH = 2 # 16-bit audio (2 bytes per sample) -VAD_MODE = 3 # VAD aggressiveness (0-3, 3 is most aggressive) -FRAME_DURATION_MS = 30 # Duration of a frame in ms for VAD (10, 20, or 30) +CHANNELS = 1 # Mono +SAMPLE_WIDTH = 2 # 16-bit audio (2 bytes per sample) +VAD_MODE = 3 # VAD aggressiveness (0-3, 3 is most aggressive) +FRAME_DURATION_MS = 30 # Duration of a frame in ms for VAD (10, 20, or 30) BYTES_PER_FRAME = (SAMPLE_RATE // 1000) * FRAME_DURATION_MS * CHANNELS * SAMPLE_WIDTH # OPUS constants removed as Opus decoding is no longer handled here. # Silence detection parameters -SILENCE_THRESHOLD_FRAMES = 25 # Number of consecutive silent VAD frames to consider end of speech (e.g., 25 * 30ms = 750ms) -MAX_SPEECH_DURATION_S = 15 # Max duration of a single speech segment to process +SILENCE_THRESHOLD_FRAMES = 25 # Number of consecutive silent VAD frames to consider end of speech (e.g., 25 * 30ms = 750ms) +MAX_SPEECH_DURATION_S = 15 # Max duration of a single speech segment to process MAX_SPEECH_FRAMES = (MAX_SPEECH_DURATION_S * 1000) // FRAME_DURATION_MS @@ -63,40 +67,57 @@ def _convert_audio_to_16khz_mono(raw_pcm_data_48k_stereo: bytes) -> bytes: with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_out: output_temp_file = tmp_out.name - + command = [ - 'ffmpeg', - '-f', 's16le', # Input format: signed 16-bit little-endian PCM - '-ac', '2', # Input channels: stereo - '-ar', '48000', # Input sample rate: 48kHz - '-i', input_temp_file, - '-ac', str(CHANNELS), # Output channels (e.g., 1 for mono) - '-ar', str(SAMPLE_RATE), # Output sample rate (e.g., 16000) - '-sample_fmt', 's16',# Output sample format - '-y', # Overwrite output file if it exists - output_temp_file + "ffmpeg", + "-f", + "s16le", # Input format: signed 16-bit little-endian PCM + "-ac", + "2", # Input channels: stereo + "-ar", + "48000", # Input sample rate: 48kHz + "-i", + input_temp_file, + "-ac", + str(CHANNELS), # Output channels (e.g., 1 for mono) + "-ar", + str(SAMPLE_RATE), # Output sample rate (e.g., 16000) + "-sample_fmt", + "s16", # Output sample format + "-y", # Overwrite output file if it exists + output_temp_file, ] - + process = subprocess.run(command, capture_output=True, check=False) - + if process.returncode != 0: - print(f"FFmpeg error during audio conversion. Return code: {process.returncode}") + print( + f"FFmpeg error during audio conversion. Return code: {process.returncode}" + ) print(f"FFmpeg stdout: {process.stdout.decode(errors='ignore')}") print(f"FFmpeg stderr: {process.stderr.decode(errors='ignore')}") return b"" - with open(output_temp_file, 'rb') as f_out: - with wave.open(f_out, 'rb') as wf: - if wf.getnchannels() == CHANNELS and \ - wf.getframerate() == SAMPLE_RATE and \ - wf.getsampwidth() == SAMPLE_WIDTH: + with open(output_temp_file, "rb") as f_out: + with wave.open(f_out, "rb") as wf: + if ( + wf.getnchannels() == CHANNELS + and wf.getframerate() == SAMPLE_RATE + and wf.getsampwidth() == SAMPLE_WIDTH + ): converted_audio_data = wf.readframes(wf.getnframes()) else: - print(f"Warning: Converted WAV file format mismatch. Expected {CHANNELS}ch, {SAMPLE_RATE}Hz, {SAMPLE_WIDTH}bytes/sample.") - print(f"Got: {wf.getnchannels()}ch, {wf.getframerate()}Hz, {wf.getsampwidth()}bytes/sample.") + print( + f"Warning: Converted WAV file format mismatch. Expected {CHANNELS}ch, {SAMPLE_RATE}Hz, {SAMPLE_WIDTH}bytes/sample." + ) + print( + f"Got: {wf.getnchannels()}ch, {wf.getframerate()}Hz, {wf.getsampwidth()}bytes/sample." + ) return b"" except FileNotFoundError: - print("FFmpeg command not found. Please ensure FFmpeg is installed and in your system's PATH.") + print( + "FFmpeg command not found. Please ensure FFmpeg is installed and in your system's PATH." + ) return b"" except Exception as e: print(f"Error during audio conversion: {e}") @@ -106,21 +127,25 @@ def _convert_audio_to_16khz_mono(raw_pcm_data_48k_stereo: bytes) -> bytes: os.remove(input_temp_file) if output_temp_file and os.path.exists(output_temp_file): os.remove(output_temp_file) - + return converted_audio_data -class VoiceAudioSink(voice_recv.AudioSink): # Inherit from voice_recv.AudioSink - def __init__(self, cog_instance): # Removed voice_client parameter +class VoiceAudioSink(voice_recv.AudioSink): # Inherit from voice_recv.AudioSink + def __init__(self, cog_instance): # Removed voice_client parameter super().__init__() self.cog = cog_instance # self.voice_client is set by the library when listen() is called # user_audio_data now keyed by user_id, 'decoder' removed - self.user_audio_data = {} # {user_id: {'buffer': bytearray, 'speaking': False, 'silent_frames': 0, 'speech_frames': 0, 'vad': VAD_instance}} - + self.user_audio_data = ( + {} + ) # {user_id: {'buffer': bytearray, 'speaking': False, 'silent_frames': 0, 'speech_frames': 0, 'vad': VAD_instance}} + # OpusDecoder check removed if not webrtcvad: - print("VAD library not loaded. STT might be less efficient or not work as intended.") + print( + "VAD library not loaded. STT might be less efficient or not work as intended." + ) def wants_opus(self) -> bool: """ @@ -130,22 +155,24 @@ class VoiceAudioSink(voice_recv.AudioSink): # Inherit from voice_recv.AudioSink return False # Signature changed: user object directly, data is VoiceData - def write(self, user: discord.User, voice_data_packet: voice_recv.VoiceData): - if not webrtcvad or not self.voice_client or not user: # OpusDecoder check removed, user check added + def write(self, user: discord.User, voice_data_packet: voice_recv.VoiceData): + if ( + not webrtcvad or not self.voice_client or not user + ): # OpusDecoder check removed, user check added return - - user_id = user.id # Get user_id from the user object + + user_id = user.id # Get user_id from the user object if user_id not in self.user_audio_data: self.user_audio_data[user_id] = { - 'buffer': bytearray(), - 'speaking': False, - 'silent_frames': 0, - 'speech_frames': 0, + "buffer": bytearray(), + "speaking": False, + "silent_frames": 0, + "speech_frames": 0, # 'decoder' removed - 'vad': webrtcvad.Vad(VAD_MODE) if webrtcvad else None + "vad": webrtcvad.Vad(VAD_MODE) if webrtcvad else None, } - + entry = self.user_audio_data[user_id] # Extract PCM data from VoiceData packet @@ -153,7 +180,7 @@ class VoiceAudioSink(voice_recv.AudioSink): # Inherit from voice_recv.AudioSink # Convert incoming 48kHz stereo PCM to 16kHz mono PCM pcm_data = _convert_audio_to_16khz_mono(raw_pcm_data_48k_stereo) - if not pcm_data: # Conversion failed or returned empty bytes + if not pcm_data: # Conversion failed or returned empty bytes # print(f"Audio conversion failed for user {user_id}. Skipping frame.") return @@ -162,18 +189,20 @@ class VoiceAudioSink(voice_recv.AudioSink): # Inherit from voice_recv.AudioSink # We need to ensure it's split into VAD-compatible frame lengths if not already. # If pcm_data (now 16kHz mono) is a 20ms chunk, its length is 640 bytes. # A 10ms frame at 16kHz is 320 bytes. A 30ms frame is 960 bytes. - + # Ensure frame_length for VAD is correct (e.g. 20ms at 16kHz mono = 640 bytes) # This constant could be defined at class or module level. # For a 20ms frame, which is typical for voice packets: - frame_length_for_vad_20ms = (SAMPLE_RATE // 1000) * 20 * CHANNELS * SAMPLE_WIDTH + frame_length_for_vad_20ms = (SAMPLE_RATE // 1000) * 20 * CHANNELS * SAMPLE_WIDTH - if len(pcm_data) % frame_length_for_vad_20ms != 0 and len(pcm_data) > 0 : # Check if it's a multiple, or handle if not. + if ( + len(pcm_data) % frame_length_for_vad_20ms != 0 and len(pcm_data) > 0 + ): # Check if it's a multiple, or handle if not. # This might happen if the converted chunk size isn't exactly what VAD expects per call. # For now, we'll try to process it. A more robust solution might buffer/segment pcm_data # into exact 10, 20, or 30ms chunks for VAD. # print(f"Warning: PCM data length {len(pcm_data)} after conversion is not an exact multiple of VAD frame size {frame_length_for_vad_20ms} for User {user_id}. Trying to process.") - pass # Continue, VAD might handle it or error. + pass # Continue, VAD might handle it or error. # Process VAD in chunks if pcm_data is longer than one VAD frame # For simplicity, let's assume pcm_data is one processable chunk for now. @@ -181,83 +210,109 @@ class VoiceAudioSink(voice_recv.AudioSink): # Inherit from voice_recv.AudioSink # Current VAD logic processes the whole pcm_data chunk at once. # This is okay if pcm_data is already a single VAD frame (e.g. 20ms). - if entry['vad']: + if entry["vad"]: try: # Ensure pcm_data is a valid frame for VAD (e.g. 10, 20, 30 ms) # If pcm_data is, for example, 640 bytes (20ms at 16kHz mono), it's fine. - if len(pcm_data) == frame_length_for_vad_20ms: # Common case - is_speech = entry['vad'].is_speech(pcm_data, SAMPLE_RATE) - elif len(pcm_data) > 0 : # If not standard, but has data, try (might error) + if len(pcm_data) == frame_length_for_vad_20ms: # Common case + is_speech = entry["vad"].is_speech(pcm_data, SAMPLE_RATE) + elif ( + len(pcm_data) > 0 + ): # If not standard, but has data, try (might error) # print(f"VAD processing for User {user_id} with non-standard PCM length {len(pcm_data)}. May error.") # This path is risky if VAD is strict. For now, we assume it's handled or errors. # A robust way: segment pcm_data into valid VAD frames. # For now, let's assume the chunk from conversion is one such frame. - is_speech = entry['vad'].is_speech(pcm_data, SAMPLE_RATE) # This might fail if len is not 10/20/30ms worth - else: # No data + is_speech = entry["vad"].is_speech( + pcm_data, SAMPLE_RATE + ) # This might fail if len is not 10/20/30ms worth + else: # No data is_speech = False - except Exception as e: # webrtcvad can raise errors on invalid frame length + except Exception as e: # webrtcvad can raise errors on invalid frame length # print(f"VAD error for User {user_id} with PCM length {len(pcm_data)}: {e}. Defaulting to speech=True for this frame.") - is_speech = True # Fallback: if VAD fails, assume it's speech - else: # No VAD - is_speech = True + is_speech = True # Fallback: if VAD fails, assume it's speech + else: # No VAD + is_speech = True if is_speech: - entry['buffer'].extend(pcm_data) - entry['speaking'] = True - entry['silent_frames'] = 0 - entry['speech_frames'] += 1 - if entry['speech_frames'] >= MAX_SPEECH_FRAMES: + entry["buffer"].extend(pcm_data) + entry["speaking"] = True + entry["silent_frames"] = 0 + entry["speech_frames"] += 1 + if entry["speech_frames"] >= MAX_SPEECH_FRAMES: # print(f"Max speech frames reached for User {user_id}. Processing segment.") - self.cog.bot.loop.create_task(self.cog.process_audio_segment(user_id, bytes(entry['buffer']), self.voice_client.guild)) - entry['buffer'].clear() - entry['speaking'] = False - entry['speech_frames'] = 0 - elif entry['speaking']: # Was speaking, now silence - entry['buffer'].extend(pcm_data) # Add this last silent frame for context - entry['silent_frames'] += 1 - if entry['silent_frames'] >= SILENCE_THRESHOLD_FRAMES: + self.cog.bot.loop.create_task( + self.cog.process_audio_segment( + user_id, bytes(entry["buffer"]), self.voice_client.guild + ) + ) + entry["buffer"].clear() + entry["speaking"] = False + entry["speech_frames"] = 0 + elif entry["speaking"]: # Was speaking, now silence + entry["buffer"].extend(pcm_data) # Add this last silent frame for context + entry["silent_frames"] += 1 + if entry["silent_frames"] >= SILENCE_THRESHOLD_FRAMES: # print(f"Silence threshold reached for User {user_id}. Processing segment.") - self.cog.bot.loop.create_task(self.cog.process_audio_segment(user_id, bytes(entry['buffer']), self.voice_client.guild)) - entry['buffer'].clear() - entry['speaking'] = False - entry['speech_frames'] = 0 - entry['silent_frames'] = 0 + self.cog.bot.loop.create_task( + self.cog.process_audio_segment( + user_id, bytes(entry["buffer"]), self.voice_client.guild + ) + ) + entry["buffer"].clear() + entry["speaking"] = False + entry["speech_frames"] = 0 + entry["silent_frames"] = 0 # If not is_speech and not entry['speaking'], do nothing (ignore silence) def cleanup(self): print("VoiceAudioSink cleanup called.") # Iterate over a copy of items if modifications occur, or handle user_id directly for user_id, data_entry in list(self.user_audio_data.items()): - if data_entry['buffer']: + if data_entry["buffer"]: # user object is not directly available here, but process_audio_segment takes user_id # We need the guild, which should be available from self.voice_client if self.voice_client and self.voice_client.guild: guild = self.voice_client.guild - print(f"Processing remaining audio for User ID {user_id} on cleanup.") - self.cog.bot.loop.create_task(self.cog.process_audio_segment(user_id, bytes(data_entry['buffer']), guild)) + print( + f"Processing remaining audio for User ID {user_id} on cleanup." + ) + self.cog.bot.loop.create_task( + self.cog.process_audio_segment( + user_id, bytes(data_entry["buffer"]), guild + ) + ) else: - print(f"Cannot process remaining audio for User ID {user_id}: voice_client or guild not available.") + print( + f"Cannot process remaining audio for User ID {user_id}: voice_client or guild not available." + ) self.user_audio_data.clear() class VoiceGatewayCog(commands.Cog): def __init__(self, bot): self.bot = bot - self.active_sinks = {} # guild_id: VoiceAudioSink - self.dedicated_voice_text_channels: dict[int, int] = {} # guild_id: channel_id + self.active_sinks = {} # guild_id: VoiceAudioSink + self.dedicated_voice_text_channels: dict[int, int] = {} # guild_id: channel_id self.speech_client = None if speech: try: self.speech_client = speech.SpeechClient() print("Google Cloud Speech client initialized successfully.") except Exception as e: - print(f"Error initializing Google Cloud Speech client: {e}. STT will not be available.") + print( + f"Error initializing Google Cloud Speech client: {e}. STT will not be available." + ) self.speech_client = None else: - print("Google Cloud Speech library not available. STT functionality will be disabled.") + print( + "Google Cloud Speech library not available. STT functionality will be disabled." + ) - async def _ensure_dedicated_voice_text_channel(self, guild: discord.Guild, voice_channel: discord.VoiceChannel) -> Optional[discord.TextChannel]: + async def _ensure_dedicated_voice_text_channel( + self, guild: discord.Guild, voice_channel: discord.VoiceChannel + ) -> Optional[discord.TextChannel]: if not GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_ENABLED: return None @@ -265,59 +320,98 @@ class VoiceGatewayCog(commands.Cog): if existing_channel_id: channel = guild.get_channel(existing_channel_id) if channel and isinstance(channel, discord.TextChannel): - print(f"Found existing dedicated voice text channel: {channel.name} ({channel.id})") + print( + f"Found existing dedicated voice text channel: {channel.name} ({channel.id})" + ) return channel else: - print(f"Dedicated voice text channel ID {existing_channel_id} for guild {guild.id} is invalid or not found. Will create a new one.") - del self.dedicated_voice_text_channels[guild.id] # Remove invalid ID + print( + f"Dedicated voice text channel ID {existing_channel_id} for guild {guild.id} is invalid or not found. Will create a new one." + ) + del self.dedicated_voice_text_channels[guild.id] # Remove invalid ID # Create new channel channel_name = GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_NAME_TEMPLATE.format( voice_channel_name=voice_channel.name, - guild_name=guild.name + guild_name=guild.name, # Add more placeholders if needed ) # Sanitize channel name (Discord has restrictions) - channel_name = "".join(c for c in channel_name if c.isalnum() or c in ['-', '_', ' ']).strip() - channel_name = channel_name.replace(' ', '-').lower() - if not channel_name: # Fallback if template results in empty string + channel_name = "".join( + c for c in channel_name if c.isalnum() or c in ["-", "_", " "] + ).strip() + channel_name = channel_name.replace(" ", "-").lower() + if not channel_name: # Fallback if template results in empty string channel_name = "gurt-voice-chat" - + # Check if a channel with this name already exists (to avoid duplicates if bot restarted without proper cleanup) for existing_guild_channel in guild.text_channels: if existing_guild_channel.name == channel_name: - print(f"Found existing channel by name '{channel_name}' ({existing_guild_channel.id}). Reusing.") + print( + f"Found existing channel by name '{channel_name}' ({existing_guild_channel.id}). Reusing." + ) self.dedicated_voice_text_channels[guild.id] = existing_guild_channel.id # Optionally update topic and permissions if needed try: - if existing_guild_channel.topic != GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_TOPIC: - await existing_guild_channel.edit(topic=GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_TOPIC) + if ( + existing_guild_channel.topic + != GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_TOPIC + ): + await existing_guild_channel.edit( + topic=GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_TOPIC + ) # Send initial message if channel is empty or last message isn't the initial one async for last_message in existing_guild_channel.history(limit=1): - if last_message.content != GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_INITIAL_MESSAGE: - await existing_guild_channel.send(GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_INITIAL_MESSAGE) - break # Only need the very last message - else: # No messages in channel - await existing_guild_channel.send(GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_INITIAL_MESSAGE) + if ( + last_message.content + != GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_INITIAL_MESSAGE + ): + await existing_guild_channel.send( + GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_INITIAL_MESSAGE + ) + break # Only need the very last message + else: # No messages in channel + await existing_guild_channel.send( + GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_INITIAL_MESSAGE + ) except discord.Forbidden: - print(f"Missing permissions to update reused dedicated channel {channel_name}") + print( + f"Missing permissions to update reused dedicated channel {channel_name}" + ) except Exception as e_reuse: - print(f"Error updating reused dedicated channel {channel_name}: {e_reuse}") + print( + f"Error updating reused dedicated channel {channel_name}: {e_reuse}" + ) return existing_guild_channel overwrites = { - guild.me: discord.PermissionOverwrite(read_messages=True, send_messages=True, manage_messages=True), # GURT needs to manage - guild.default_role: discord.PermissionOverwrite(read_messages=False, send_messages=False) # Private by default + guild.me: discord.PermissionOverwrite( + read_messages=True, send_messages=True, manage_messages=True + ), # GURT needs to manage + guild.default_role: discord.PermissionOverwrite( + read_messages=False, send_messages=False + ), # Private by default # Consider adding server admins/mods with read/send permissions } # Add owner and admins with full perms to the channel if guild.owner: - overwrites[guild.owner] = discord.PermissionOverwrite(read_messages=True, send_messages=True, manage_channels=True, manage_messages=True) + overwrites[guild.owner] = discord.PermissionOverwrite( + read_messages=True, + send_messages=True, + manage_channels=True, + manage_messages=True, + ) for role in guild.roles: - if role.permissions.administrator and not role.is_default(): # Check for admin roles - overwrites[role] = discord.PermissionOverwrite(read_messages=True, send_messages=True, manage_channels=True, manage_messages=True) - + if ( + role.permissions.administrator and not role.is_default() + ): # Check for admin roles + overwrites[role] = discord.PermissionOverwrite( + read_messages=True, + send_messages=True, + manage_channels=True, + manage_messages=True, + ) try: print(f"Creating new dedicated voice text channel: {channel_name}") @@ -325,21 +419,29 @@ class VoiceGatewayCog(commands.Cog): name=channel_name, overwrites=overwrites, topic=GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_TOPIC, - reason="GURT Dedicated Voice Chat Channel" + reason="GURT Dedicated Voice Chat Channel", ) self.dedicated_voice_text_channels[guild.id] = new_channel.id if GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_INITIAL_MESSAGE: - await new_channel.send(GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_INITIAL_MESSAGE) - print(f"Created dedicated voice text channel: {new_channel.name} ({new_channel.id})") + await new_channel.send( + GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_INITIAL_MESSAGE + ) + print( + f"Created dedicated voice text channel: {new_channel.name} ({new_channel.id})" + ) return new_channel except discord.Forbidden: - print(f"Forbidden: Could not create dedicated voice text channel '{channel_name}' in guild {guild.name}.") + print( + f"Forbidden: Could not create dedicated voice text channel '{channel_name}' in guild {guild.name}." + ) return None except Exception as e: print(f"Error creating dedicated voice text channel '{channel_name}': {e}") return None - def get_dedicated_text_channel_for_guild(self, guild_id: int) -> Optional[discord.TextChannel]: + def get_dedicated_text_channel_for_guild( + self, guild_id: int + ) -> Optional[discord.TextChannel]: channel_id = self.dedicated_voice_text_channels.get(guild_id) if channel_id: guild = self.bot.get_guild(guild_id) @@ -355,32 +457,51 @@ class VoiceGatewayCog(commands.Cog): async def cog_unload(self): print("Unloading VoiceGatewayCog...") # Disconnect from all voice channels and clean up sinks - for vc in list(self.bot.voice_clients): # Iterate over a copy + for vc in list(self.bot.voice_clients): # Iterate over a copy guild_id = vc.guild.id if guild_id in self.active_sinks: - if vc.is_connected() and hasattr(vc, 'is_listening') and vc.is_listening(): - if hasattr(vc, 'stop_listening'): + if ( + vc.is_connected() + and hasattr(vc, "is_listening") + and vc.is_listening() + ): + if hasattr(vc, "stop_listening"): vc.stop_listening() - else: # Or equivalent for VoiceRecvClient - pass + else: # Or equivalent for VoiceRecvClient + pass self.active_sinks[guild_id].cleanup() del self.active_sinks[guild_id] - + # Handle dedicated text channel cleanup on cog unload - if GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_ENABLED and GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_CLEANUP_ON_LEAVE: + if ( + GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_ENABLED + and GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_CLEANUP_ON_LEAVE + ): dedicated_channel_id = self.dedicated_voice_text_channels.get(guild_id) if dedicated_channel_id: try: - channel_to_delete = vc.guild.get_channel(dedicated_channel_id) or await self.bot.fetch_channel(dedicated_channel_id) + channel_to_delete = vc.guild.get_channel( + dedicated_channel_id + ) or await self.bot.fetch_channel(dedicated_channel_id) if channel_to_delete: - print(f"Deleting dedicated voice text channel {channel_to_delete.name} ({channel_to_delete.id}) during cog unload.") - await channel_to_delete.delete(reason="GURT VoiceGatewayCog unload") + print( + f"Deleting dedicated voice text channel {channel_to_delete.name} ({channel_to_delete.id}) during cog unload." + ) + await channel_to_delete.delete( + reason="GURT VoiceGatewayCog unload" + ) except discord.NotFound: - print(f"Dedicated voice text channel {dedicated_channel_id} not found for deletion during unload.") + print( + f"Dedicated voice text channel {dedicated_channel_id} not found for deletion during unload." + ) except discord.Forbidden: - print(f"Forbidden: Could not delete dedicated voice text channel {dedicated_channel_id} during unload.") + print( + f"Forbidden: Could not delete dedicated voice text channel {dedicated_channel_id} during unload." + ) except Exception as e: - print(f"Error deleting dedicated voice text channel {dedicated_channel_id} during unload: {e}") + print( + f"Error deleting dedicated voice text channel {dedicated_channel_id} during unload: {e}" + ) if guild_id in self.dedicated_voice_text_channels: del self.dedicated_voice_text_channels[guild_id] @@ -392,7 +513,7 @@ class VoiceGatewayCog(commands.Cog): """Connects the bot to a specified voice channel and starts listening.""" if not channel: return None, "Channel not provided." - + guild = channel.guild voice_client = guild.voice_client @@ -400,7 +521,10 @@ class VoiceGatewayCog(commands.Cog): if voice_client.channel == channel: print(f"Already connected to {channel.name} in {guild.name}.") if isinstance(voice_client, voice_recv.VoiceRecvClient): - if guild.id not in self.active_sinks or not voice_client.is_listening(): + if ( + guild.id not in self.active_sinks + or not voice_client.is_listening() + ): self.start_listening_for_vc(voice_client) # Ensure dedicated channel is set up even if already connected await self._ensure_dedicated_voice_text_channel(guild, channel) @@ -408,41 +532,69 @@ class VoiceGatewayCog(commands.Cog): print(f"Reconnecting with VoiceRecvClient to {channel.name}.") await voice_client.disconnect(force=True) try: - voice_client = await channel.connect(cls=voice_recv.VoiceRecvClient, timeout=10.0) - print(f"Reconnected to {channel.name} in {guild.name} with VoiceRecvClient.") + voice_client = await channel.connect( + cls=voice_recv.VoiceRecvClient, timeout=10.0 + ) + print( + f"Reconnected to {channel.name} in {guild.name} with VoiceRecvClient." + ) self.start_listening_for_vc(voice_client) await self._ensure_dedicated_voice_text_channel(guild, channel) except asyncio.TimeoutError: - return None, f"Timeout trying to reconnect to {channel.name} with VoiceRecvClient." + return ( + None, + f"Timeout trying to reconnect to {channel.name} with VoiceRecvClient.", + ) except Exception as e: - return None, f"Error reconnecting to {channel.name} with VoiceRecvClient: {str(e)}" + return ( + None, + f"Error reconnecting to {channel.name} with VoiceRecvClient: {str(e)}", + ) return voice_client, "Already connected to this channel." else: - print(f"Moving to {channel.name} in {guild.name}. Reconnecting with VoiceRecvClient.") - await voice_client.disconnect(force=True) # This will trigger cleanup for old channel's dedicated text channel if configured + print( + f"Moving to {channel.name} in {guild.name}. Reconnecting with VoiceRecvClient." + ) + await voice_client.disconnect( + force=True + ) # This will trigger cleanup for old channel's dedicated text channel if configured try: - voice_client = await channel.connect(cls=voice_recv.VoiceRecvClient, timeout=10.0) - print(f"Moved and reconnected to {channel.name} in {guild.name} with VoiceRecvClient.") + voice_client = await channel.connect( + cls=voice_recv.VoiceRecvClient, timeout=10.0 + ) + print( + f"Moved and reconnected to {channel.name} in {guild.name} with VoiceRecvClient." + ) self.start_listening_for_vc(voice_client) await self._ensure_dedicated_voice_text_channel(guild, channel) except asyncio.TimeoutError: - return None, f"Timeout trying to move and connect to {channel.name}." + return ( + None, + f"Timeout trying to move and connect to {channel.name}.", + ) except Exception as e: - return None, f"Error moving and connecting to {channel.name}: {str(e)}" + return ( + None, + f"Error moving and connecting to {channel.name}: {str(e)}", + ) else: try: - voice_client = await channel.connect(cls=voice_recv.VoiceRecvClient, timeout=10.0) - print(f"Connected to {channel.name} in {guild.name} with VoiceRecvClient.") + voice_client = await channel.connect( + cls=voice_recv.VoiceRecvClient, timeout=10.0 + ) + print( + f"Connected to {channel.name} in {guild.name} with VoiceRecvClient." + ) self.start_listening_for_vc(voice_client) await self._ensure_dedicated_voice_text_channel(guild, channel) except asyncio.TimeoutError: return None, f"Timeout trying to connect to {channel.name}." except Exception as e: return None, f"Error connecting to {channel.name}: {str(e)}" - + if not voice_client: return None, "Failed to establish voice client after connection." - + return voice_client, f"Successfully connected and listening in {channel.name}." def start_listening_for_vc(self, voice_client: discord.VoiceClient): @@ -451,8 +603,8 @@ class VoiceGatewayCog(commands.Cog): if guild_id in self.active_sinks: # If sink exists, ensure it's clean and listening is (re)started if voice_client.is_listening(): - voice_client.stop_listening() # Stop previous listening if any - self.active_sinks[guild_id].cleanup() # Clean old state + voice_client.stop_listening() # Stop previous listening if any + self.active_sinks[guild_id].cleanup() # Clean old state # Re-initialize or ensure the sink is fresh for the current VC self.active_sinks[guild_id] = VoiceAudioSink(self) else: @@ -460,10 +612,13 @@ class VoiceGatewayCog(commands.Cog): if not voice_client.is_listening(): voice_client.listen(self.active_sinks[guild_id]) - print(f"Started listening in {voice_client.channel.name} for guild {guild_id}") + print( + f"Started listening in {voice_client.channel.name} for guild {guild_id}" + ) else: - print(f"Already listening in {voice_client.channel.name} for guild {guild_id}") - + print( + f"Already listening in {voice_client.channel.name} for guild {guild_id}" + ) async def disconnect_from_voice(self, guild: discord.Guild): """Disconnects the bot from the voice channel in the given guild.""" @@ -471,36 +626,53 @@ class VoiceGatewayCog(commands.Cog): if voice_client and voice_client.is_connected(): if voice_client.is_listening(): voice_client.stop_listening() - + guild_id = guild.id if guild_id in self.active_sinks: self.active_sinks[guild_id].cleanup() del self.active_sinks[guild_id] - + # Handle dedicated text channel cleanup - if GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_ENABLED and GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_CLEANUP_ON_LEAVE: + if ( + GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_ENABLED + and GurtConfig.VOICE_DEDICATED_TEXT_CHANNEL_CLEANUP_ON_LEAVE + ): dedicated_channel_id = self.dedicated_voice_text_channels.get(guild_id) if dedicated_channel_id: try: - channel_to_delete = guild.get_channel(dedicated_channel_id) or await self.bot.fetch_channel(dedicated_channel_id) + channel_to_delete = guild.get_channel( + dedicated_channel_id + ) or await self.bot.fetch_channel(dedicated_channel_id) if channel_to_delete: - print(f"Deleting dedicated voice text channel {channel_to_delete.name} ({channel_to_delete.id}).") - await channel_to_delete.delete(reason="GURT disconnected from voice channel") + print( + f"Deleting dedicated voice text channel {channel_to_delete.name} ({channel_to_delete.id})." + ) + await channel_to_delete.delete( + reason="GURT disconnected from voice channel" + ) except discord.NotFound: - print(f"Dedicated voice text channel {dedicated_channel_id} not found for deletion.") + print( + f"Dedicated voice text channel {dedicated_channel_id} not found for deletion." + ) except discord.Forbidden: - print(f"Forbidden: Could not delete dedicated voice text channel {dedicated_channel_id}.") + print( + f"Forbidden: Could not delete dedicated voice text channel {dedicated_channel_id}." + ) except Exception as e: - print(f"Error deleting dedicated voice text channel {dedicated_channel_id}: {e}") + print( + f"Error deleting dedicated voice text channel {dedicated_channel_id}: {e}" + ) if guild_id in self.dedicated_voice_text_channels: del self.dedicated_voice_text_channels[guild_id] - + await voice_client.disconnect(force=True) print(f"Disconnected from voice in {guild.name}.") return True, f"Disconnected from voice in {guild.name}." return False, "Not connected to voice in this guild." - async def play_audio_file(self, voice_client: discord.VoiceClient, audio_file_path: str): + async def play_audio_file( + self, voice_client: discord.VoiceClient, audio_file_path: str + ): """Plays an audio file in the voice channel.""" if not voice_client or not voice_client.is_connected(): print("Error: Voice client not connected.") @@ -511,15 +683,20 @@ class VoiceGatewayCog(commands.Cog): return False, "Audio file not found." if voice_client.is_playing(): - voice_client.stop() # Stop current audio if any + voice_client.stop() # Stop current audio if any try: audio_source = discord.FFmpegPCMAudio(audio_file_path, **FFMPEG_OPTIONS) - voice_client.play(audio_source, after=lambda e: self.after_audio_playback(e, audio_file_path)) + voice_client.play( + audio_source, + after=lambda e: self.after_audio_playback(e, audio_file_path), + ) print(f"Playing audio: {audio_file_path}") return True, f"Playing {os.path.basename(audio_file_path)}" except Exception as e: - print(f"Error creating/playing FFmpegPCMAudio source for {audio_file_path}: {e}") + print( + f"Error creating/playing FFmpegPCMAudio source for {audio_file_path}: {e}" + ) return False, f"Error playing audio: {str(e)}" def after_audio_playback(self, error, audio_file_path): @@ -531,10 +708,15 @@ class VoiceGatewayCog(commands.Cog): # Removed start_listening_pipeline as the sink now handles more logic directly or via tasks. - async def process_audio_segment(self, user_id: int, audio_data: bytes, guild: discord.Guild): + async def process_audio_segment( + self, user_id: int, audio_data: bytes, guild: discord.Guild + ): """Processes a segment of audio data using Google Cloud Speech-to-Text.""" if not self.speech_client or not audio_data: - if not audio_data: print(f"process_audio_segment called for user {user_id} with empty audio_data.") + if not audio_data: + print( + f"process_audio_segment called for user {user_id} with empty audio_data." + ) return try: @@ -543,30 +725,40 @@ class VoiceGatewayCog(commands.Cog): sample_rate_hertz=SAMPLE_RATE, # Defined as 16000 language_code="en-US", enable_automatic_punctuation=True, - model="telephony" # Consider uncommenting if default isn't ideal for voice chat + model="telephony", # Consider uncommenting if default isn't ideal for voice chat ) recognition_audio = speech.RecognitionAudio(content=audio_data) # Run in executor as it's a network call that can be blocking response = await self.bot.loop.run_in_executor( None, # Default ThreadPoolExecutor - functools.partial(self.speech_client.recognize, config=recognition_config, audio=recognition_audio) + functools.partial( + self.speech_client.recognize, + config=recognition_config, + audio=recognition_audio, + ), ) transcribed_text = "" for result in response.results: if result.alternatives: transcribed_text += result.alternatives[0].transcript + " " - + transcribed_text = transcribed_text.strip() if transcribed_text: user = guild.get_member(user_id) or await self.bot.fetch_user(user_id) - print(f"Google STT for {user.name} ({user_id}) in {guild.name}: {transcribed_text}") - self.bot.dispatch("voice_transcription_received", guild, user, transcribed_text) + print( + f"Google STT for {user.name} ({user_id}) in {guild.name}: {transcribed_text}" + ) + self.bot.dispatch( + "voice_transcription_received", guild, user, transcribed_text + ) except Exception as e: - print(f"Error processing audio segment with Google STT for user {user_id}: {e}") + print( + f"Error processing audio segment with Google STT for user {user_id}: {e}" + ) async def setup(bot: commands.Bot): @@ -576,7 +768,7 @@ async def setup(bot: commands.Bot): process = await asyncio.create_subprocess_shell( "ffmpeg -version", stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stderr=asyncio.subprocess.PIPE, ) stdout, stderr = await process.communicate() if process.returncode == 0: @@ -584,11 +776,17 @@ async def setup(bot: commands.Bot): await bot.add_cog(VoiceGatewayCog(bot)) print("VoiceGatewayCog loaded successfully!") else: - print("FFmpeg not found or not working correctly. VoiceGatewayCog will not be loaded.") + print( + "FFmpeg not found or not working correctly. VoiceGatewayCog will not be loaded." + ) print(f"FFmpeg check stdout: {stdout.decode(errors='ignore')}") print(f"FFmpeg check stderr: {stderr.decode(errors='ignore')}") - + except FileNotFoundError: - print("FFmpeg command not found. VoiceGatewayCog will not be loaded. Please install FFmpeg and ensure it's in your system's PATH.") + print( + "FFmpeg command not found. VoiceGatewayCog will not be loaded. Please install FFmpeg and ensure it's in your system's PATH." + ) except Exception as e: - print(f"An error occurred while checking for FFmpeg: {e}. VoiceGatewayCog will not be loaded.") + print( + f"An error occurred while checking for FFmpeg: {e}. VoiceGatewayCog will not be loaded." + ) diff --git a/cogs/ai_code_agent_cog.py b/cogs/ai_code_agent_cog.py index cace85b..6991573 100644 --- a/cogs/ai_code_agent_cog.py +++ b/cogs/ai_code_agent_cog.py @@ -7,27 +7,31 @@ import aiohttp import subprocess import json import base64 -import datetime # For snapshot naming -import random # For snapshot naming -import ast # For GetCodeStructure -import pathlib # For path manipulations, potentially +import datetime # For snapshot naming +import random # For snapshot naming +import ast # For GetCodeStructure +import pathlib # For path manipulations, potentially from typing import Dict, Any, List, Optional, Tuple -from collections import defaultdict # Added for agent_shell_sessions +from collections import defaultdict # Added for agent_shell_sessions import xml.etree.ElementTree as ET # Google Generative AI Imports (using Vertex AI backend) from google import genai -from google.genai import types as google_genai_types # Renamed to avoid conflict with typing.types +from google.genai import ( + types as google_genai_types, +) # Renamed to avoid conflict with typing.types from google.api_core import exceptions as google_exceptions # Import project configuration for Vertex AI try: from gurt.config import PROJECT_ID, LOCATION except ImportError: - PROJECT_ID = os.getenv("GCP_PROJECT_ID") # Fallback to environment variable - LOCATION = os.getenv("GCP_LOCATION") # Fallback to environment variable + PROJECT_ID = os.getenv("GCP_PROJECT_ID") # Fallback to environment variable + LOCATION = os.getenv("GCP_LOCATION") # Fallback to environment variable if not PROJECT_ID or not LOCATION: - print("Warning: PROJECT_ID or LOCATION not found in gurt.config or environment variables.") + print( + "Warning: PROJECT_ID or LOCATION not found in gurt.config or environment variables." + ) # Allow cog to load but genai_client will be None from tavily import TavilyClient @@ -35,10 +39,22 @@ from tavily import TavilyClient # Define standard safety settings using google.generativeai types # Set all thresholds to OFF as requested for internal tools STANDARD_SAFETY_SETTINGS = [ - google_genai_types.SafetySetting(category=google_genai_types.HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold="BLOCK_NONE"), - google_genai_types.SafetySetting(category=google_genai_types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold="BLOCK_NONE"), - google_genai_types.SafetySetting(category=google_genai_types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold="BLOCK_NONE"), - google_genai_types.SafetySetting(category=google_genai_types.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold="BLOCK_NONE"), + google_genai_types.SafetySetting( + category=google_genai_types.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold="BLOCK_NONE", + ), + google_genai_types.SafetySetting( + category=google_genai_types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold="BLOCK_NONE", + ), + google_genai_types.SafetySetting( + category=google_genai_types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold="BLOCK_NONE", + ), + google_genai_types.SafetySetting( + category=google_genai_types.HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold="BLOCK_NONE", + ), ] # --- Constants for command filtering, mirroring shell_command_cog.py --- @@ -46,6 +62,7 @@ STANDARD_SAFETY_SETTINGS = [ BANNED_COMMANDS_AGENT = [] BANNED_PATTERNS_AGENT = [] + def is_command_allowed_agent(command): """ Check if the command is allowed to run. Mirrors shell_command_cog.py. @@ -53,7 +70,7 @@ def is_command_allowed_agent(command): """ # Check against banned commands for banned in BANNED_COMMANDS_AGENT: - if banned in command.lower(): # Simple substring check + if banned in command.lower(): # Simple substring check return False, f"Command contains banned term: `{banned}`" # Check against banned patterns @@ -62,6 +79,8 @@ def is_command_allowed_agent(command): return False, f"Command matches banned pattern: `{pattern}`" return True, None + + # --- End of command filtering constants and function --- COMMIT_AUTHOR = "AI Coding Agent Cog " @@ -325,17 +344,26 @@ IMPORTANT: Do NOT wrap your XML tool calls in markdown code blocks (e.g., ```xml - **Focus:** Your goal is to complete the coding/file manipulation task as requested by the user, adapting to the current operational mode. """ + class AICodeAgentCog(commands.Cog): def __init__(self, bot: commands.Bot): self.bot = bot self.genai_client = None - self.agent_conversations: Dict[int, List[google_genai_types.Content]] = {} # User ID to conversation history - self.agent_shell_sessions = defaultdict(lambda: { # For ExecuteCommand CWD tracking - 'cwd': os.getcwd(), - 'env': os.environ.copy() - }) - self.agent_modes: Dict[int, str] = {} # User ID to current agent mode (e.g., "default", "planning") - self.agent_python_repl_sessions: Dict[str, Dict[str, Any]] = {} # session_id to {'globals': {}, 'locals': {}} + self.agent_conversations: Dict[int, List[google_genai_types.Content]] = ( + {} + ) # User ID to conversation history + self.agent_shell_sessions = defaultdict( + lambda: { # For ExecuteCommand CWD tracking + "cwd": os.getcwd(), + "env": os.environ.copy(), + } + ) + self.agent_modes: Dict[int, str] = ( + {} + ) # User ID to current agent mode (e.g., "default", "planning") + self.agent_python_repl_sessions: Dict[str, Dict[str, Any]] = ( + {} + ) # session_id to {'globals': {}, 'locals': {}} # Initialize Google GenAI Client for Vertex AI if PROJECT_ID and LOCATION: @@ -345,18 +373,24 @@ class AICodeAgentCog(commands.Cog): project=PROJECT_ID, location=LOCATION, ) - print(f"AICodeAgentCog: Google GenAI Client initialized for Vertex AI project '{PROJECT_ID}' in location '{LOCATION}'.") + print( + f"AICodeAgentCog: Google GenAI Client initialized for Vertex AI project '{PROJECT_ID}' in location '{LOCATION}'." + ) except Exception as e: - print(f"AICodeAgentCog: Error initializing Google GenAI Client for Vertex AI: {e}") - self.genai_client = None # Ensure it's None on failure + print( + f"AICodeAgentCog: Error initializing Google GenAI Client for Vertex AI: {e}" + ) + self.genai_client = None # Ensure it's None on failure else: - print("AICodeAgentCog: PROJECT_ID or LOCATION not configured. Google GenAI Client not initialized.") + print( + "AICodeAgentCog: PROJECT_ID or LOCATION not configured. Google GenAI Client not initialized." + ) # AI Model Configuration - self._ai_model: str = "gemini-2.5-flash-preview-05-20" # Default model + self._ai_model: str = "gemini-2.5-flash-preview-05-20" # Default model self._available_models: Dict[str, str] = { - "pro": "gemini-2.5-pro-preview-05-06", # Assuming this is the intended Pro model - "flash": "gemini-2.5-flash-preview-05-20" + "pro": "gemini-2.5-pro-preview-05-06", # Assuming this is the intended Pro model + "flash": "gemini-2.5-flash-preview-05-20", } # User mentioned "gemini-2.5-pro-preview-05-06" and "gemini-2.5-flash-preview-05-20" # Updating to reflect those if they are the correct ones, otherwise the 1.5 versions are common. @@ -370,9 +404,13 @@ class AICodeAgentCog(commands.Cog): self.tavily_client = TavilyClient(api_key=self.tavily_api_key) print("AICodeAgentCog: TavilyClient initialized.") else: - print("AICodeAgentCog: TAVILY_API_KEY not found. TavilyClient not initialized.") - - self.tavily_search_depth: str = os.getenv("TAVILY_DEFAULT_SEARCH_DEPTH", "basic") + print( + "AICodeAgentCog: TAVILY_API_KEY not found. TavilyClient not initialized." + ) + + self.tavily_search_depth: str = os.getenv( + "TAVILY_DEFAULT_SEARCH_DEPTH", "basic" + ) self.tavily_max_results: int = int(os.getenv("TAVILY_DEFAULT_MAX_RESULTS", "5")) @commands.command(name="codeagent_model") @@ -382,9 +420,13 @@ class AICodeAgentCog(commands.Cog): model_key = model_key.lower() if model_key in self._available_models: self._ai_model = self._available_models[model_key] - await ctx.send(f"AICodeAgent: AI model set to: {self._ai_model} (key: {model_key})") + await ctx.send( + f"AICodeAgent: AI model set to: {self._ai_model} (key: {model_key})" + ) else: - await ctx.send(f"AICodeAgent: Invalid model key '{model_key}'. Available keys: {', '.join(self._available_models.keys())}") + await ctx.send( + f"AICodeAgent: Invalid model key '{model_key}'. Available keys: {', '.join(self._available_models.keys())}" + ) @commands.command(name="codeagent_get_model") @commands.is_owner() @@ -401,47 +443,73 @@ class AICodeAgentCog(commands.Cog): del self.agent_conversations[user_id] await ctx.send("AICodeAgent: Conversation history cleared for you.") else: - await ctx.send("AICodeAgent: No conversation history found for you to clear.") + await ctx.send( + "AICodeAgent: No conversation history found for you to clear." + ) @commands.command(name="codeagent_mode", aliases=["ca_mode"]) @commands.is_owner() - async def codeagent_mode_command(self, ctx: commands.Context, mode_name: str, *, context_message: Optional[str] = None): + async def codeagent_mode_command( + self, + ctx: commands.Context, + mode_name: str, + *, + context_message: Optional[str] = None, + ): """Sets the operational mode for the AI agent for the calling user. Usage: !codeagent_mode [optional context_message] Modes: default, planning, debugging, learning """ user_id = ctx.author.id mode_name = mode_name.lower() - valid_modes = ["default", "planning", "debugging", "learning"] # Can be expanded + valid_modes = [ + "default", + "planning", + "debugging", + "learning", + ] # Can be expanded if mode_name not in valid_modes: - await ctx.send(f"AICodeAgent: Invalid mode '{mode_name}'. Valid modes are: {', '.join(valid_modes)}.") + await ctx.send( + f"AICodeAgent: Invalid mode '{mode_name}'. Valid modes are: {', '.join(valid_modes)}." + ) return self.agent_modes[user_id] = mode_name - mode_set_message = f"AICodeAgent: Operational mode for you set to '{mode_name}'." - + mode_set_message = ( + f"AICodeAgent: Operational mode for you set to '{mode_name}'." + ) + # Prepare system notification for AI history - notification_text = f"[System Notification] Agent mode changed to '{mode_name}'." + notification_text = ( + f"[System Notification] Agent mode changed to '{mode_name}'." + ) if context_message: notification_text += f" Context: {context_message}" mode_set_message += f" Context: {context_message}" # Add this notification to the AI's conversation history for this user # This ensures the AI is aware of the mode change for its next interaction - self._add_to_conversation_history(user_id, role="user", text_content=notification_text) # Treat as user input for AI to see + self._add_to_conversation_history( + user_id, role="user", text_content=notification_text + ) # Treat as user input for AI to see await ctx.send(mode_set_message) - print(f"AICodeAgentCog: User {user_id} set mode to '{mode_name}'. Notification added to history: {notification_text}") - + print( + f"AICodeAgentCog: User {user_id} set mode to '{mode_name}'. Notification added to history: {notification_text}" + ) @commands.command(name="codeagent_get_mode", aliases=["ca_get_mode"]) @commands.is_owner() async def codeagent_get_mode_command(self, ctx: commands.Context): """Displays the current operational mode for the AI agent for the calling user.""" user_id = ctx.author.id - current_mode = self.agent_modes.get(user_id, "default") # Default to "default" if not set - await ctx.send(f"AICodeAgent: Your current operational mode is '{current_mode}'.") + current_mode = self.agent_modes.get( + user_id, "default" + ) # Default to "default" if not set + await ctx.send( + f"AICodeAgent: Your current operational mode is '{current_mode}'." + ) async def _run_git_command(self, command_str: str) -> Tuple[bool, str]: """ @@ -460,14 +528,16 @@ class AICodeAgentCog(commands.Cog): # If command_str is a single string and might contain shell features (though unlikely for our git use), # shell=True would be needed, but then command_str must be trustworthy. # Given our specific git commands, splitting them is safer. - + # Simplified: if it's a simple git command, can pass as string with shell=True, # but better to split for shell=False. # For now, let's assume simple commands or trust shell=True for git. # However, the example used shell=True. Let's try that first for consistency with the hint. - + final_command_str = command_str - if "commit" in command_str and "--author" in command_str: # Heuristic to identify our commit commands + if ( + "commit" in command_str and "--author" in command_str + ): # Heuristic to identify our commit commands # COMMIT_AUTHOR = "Name " author_name_match = re.match(r"^(.*?)\s*<(.+?)>$", COMMIT_AUTHOR) if author_name_match: @@ -477,37 +547,58 @@ class AICodeAgentCog(commands.Cog): # Ensure the original command_str is correctly modified. # If command_str starts with "git commit", we insert after "git". if command_str.strip().startswith("git commit"): - parts = command_str.strip().split(" ", 1) # "git", "commit ..." - final_command_str = f"{parts[0]} -c user.name=\"{committer_name}\" -c user.email=\"{committer_email}\" {parts[1]}" - print(f"AICodeAgentCog: Modified commit command for committer ID: {final_command_str}") + parts = command_str.strip().split( + " ", 1 + ) # "git", "commit ..." + final_command_str = f'{parts[0]} -c user.name="{committer_name}" -c user.email="{committer_email}" {parts[1]}' + print( + f"AICodeAgentCog: Modified commit command for committer ID: {final_command_str}" + ) else: - print(f"AICodeAgentCog: Warning - Could not parse COMMIT_AUTHOR ('{COMMIT_AUTHOR}') to set committer identity.") - + print( + f"AICodeAgentCog: Warning - Could not parse COMMIT_AUTHOR ('{COMMIT_AUTHOR}') to set committer identity." + ) + proc = subprocess.Popen( - final_command_str, # Potentially modified command string + final_command_str, # Potentially modified command string shell=True, # Execute through the shell cwd=cwd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - text=True, # Decodes stdout/stderr as text - errors='replace' # Handles decoding errors + text=True, # Decodes stdout/stderr as text + errors="replace", # Handles decoding errors ) - stdout, stderr = proc.communicate(timeout=60) # 60-second timeout for git commands + stdout, stderr = proc.communicate( + timeout=60 + ) # 60-second timeout for git commands return (stdout, stderr, proc.returncode, False) except subprocess.TimeoutExpired: proc.kill() stdout, stderr = proc.communicate() - return (stdout, stderr, -1, True) # -1 for timeout-specific return code - except FileNotFoundError as fnf_err: # Specifically catch if 'git' command itself is not found - print(f"AICodeAgentCog: FileNotFoundError for command '{final_command_str}': {fnf_err}. Is Git installed and in PATH?") - return ("", f"FileNotFoundError: {fnf_err}. Ensure Git is installed and in PATH.", -2, False) + return (stdout, stderr, -1, True) # -1 for timeout-specific return code + except ( + FileNotFoundError + ) as fnf_err: # Specifically catch if 'git' command itself is not found + print( + f"AICodeAgentCog: FileNotFoundError for command '{final_command_str}': {fnf_err}. Is Git installed and in PATH?" + ) + return ( + "", + f"FileNotFoundError: {fnf_err}. Ensure Git is installed and in PATH.", + -2, + False, + ) except Exception as e: - print(f"AICodeAgentCog: Exception in run_sync_subprocess for '{final_command_str}': {type(e).__name__} - {e}") - return ("", str(e), -3, False) # -3 for other exceptions + print( + f"AICodeAgentCog: Exception in run_sync_subprocess for '{final_command_str}': {type(e).__name__} - {e}" + ) + return ("", str(e), -3, False) # -3 for other exceptions + + stdout_str, stderr_str, returncode, timed_out = await asyncio.to_thread( + run_sync_subprocess + ) - stdout_str, stderr_str, returncode, timed_out = await asyncio.to_thread(run_sync_subprocess) - full_output = "" if timed_out: full_output += "Command timed out after 60 seconds.\n" @@ -515,18 +606,26 @@ class AICodeAgentCog(commands.Cog): full_output += f"Stdout:\n{stdout_str.strip()}\n" if stderr_str: full_output += f"Stderr:\n{stderr_str.strip()}\n" - + if returncode == 0: # For commands like `git rev-parse --abbrev-ref HEAD`, stdout is the primary result. # If stdout is empty but no error, return it as is. # If full_output is just "Stdout:\n\n", it means empty stdout. # We want the actual stdout for rev-parse, not the "Stdout:" prefix. # Check original command_str for this specific case, not final_command_str which might be modified - if command_str == "git rev-parse --abbrev-ref HEAD" and stdout_str: # Use original command_str for this check - return True, stdout_str.strip() # Return just the branch name - return True, full_output.strip() if full_output.strip() else "Command executed successfully with no output." + if ( + command_str == "git rev-parse --abbrev-ref HEAD" and stdout_str + ): # Use original command_str for this check + return True, stdout_str.strip() # Return just the branch name + return True, ( + full_output.strip() + if full_output.strip() + else "Command executed successfully with no output." + ) else: - error_message = f"Git command failed. Return Code: {returncode}\n{full_output.strip()}" + error_message = ( + f"Git command failed. Return Code: {returncode}\n{full_output.strip()}" + ) print(f"AICodeAgentCog: {error_message}") return False, error_message @@ -534,9 +633,13 @@ class AICodeAgentCog(commands.Cog): """Creates a programmatic Git snapshot using a temporary branch.""" try: # Get current branch name - success, current_branch_name = await self._run_git_command("git rev-parse --abbrev-ref HEAD") + success, current_branch_name = await self._run_git_command( + "git rev-parse --abbrev-ref HEAD" + ) if not success: - print(f"AICodeAgentCog: Failed to get current branch name for snapshot: {current_branch_name}") + print( + f"AICodeAgentCog: Failed to get current branch name for snapshot: {current_branch_name}" + ) return None current_branch_name = current_branch_name.strip() @@ -545,30 +648,50 @@ class AICodeAgentCog(commands.Cog): snapshot_branch_name = f"snapshot_cog_{timestamp}_{random_hex}" # Create and checkout the new snapshot branch - success, output = await self._run_git_command(f"git checkout -b {snapshot_branch_name}") + success, output = await self._run_git_command( + f"git checkout -b {snapshot_branch_name}" + ) if not success: - print(f"AICodeAgentCog: Failed to create snapshot branch '{snapshot_branch_name}': {output}") + print( + f"AICodeAgentCog: Failed to create snapshot branch '{snapshot_branch_name}': {output}" + ) # Attempt to switch back if checkout failed mid-operation - await self._run_git_command(f"git checkout {current_branch_name}") # Best effort + await self._run_git_command( + f"git checkout {current_branch_name}" + ) # Best effort return None # Commit any currently staged changes or create an empty commit for a clean snapshot point - commit_message = f"Cog Snapshot: Pre-AI Edit State on branch {snapshot_branch_name}" - success, output = await self._run_git_command(f"git commit --author=\"{COMMIT_AUTHOR}\" -m \"{commit_message}\" --allow-empty") + commit_message = ( + f"Cog Snapshot: Pre-AI Edit State on branch {snapshot_branch_name}" + ) + success, output = await self._run_git_command( + f'git commit --author="{COMMIT_AUTHOR}" -m "{commit_message}" --allow-empty' + ) if not success: - print(f"AICodeAgentCog: Failed to commit snapshot on '{snapshot_branch_name}': {output}") + print( + f"AICodeAgentCog: Failed to commit snapshot on '{snapshot_branch_name}': {output}" + ) # Attempt to switch back and clean up branch await self._run_git_command(f"git checkout {current_branch_name}") - await self._run_git_command(f"git branch -D {snapshot_branch_name}") # Best effort cleanup + await self._run_git_command( + f"git branch -D {snapshot_branch_name}" + ) # Best effort cleanup return None - - # Switch back to the original branch - success, output = await self._run_git_command(f"git checkout {current_branch_name}") - if not success: - print(f"AICodeAgentCog: CRITICAL - Failed to switch back to original branch '{current_branch_name}' after snapshot. Current branch might be '{snapshot_branch_name}'. Manual intervention may be needed. Error: {output}") - return snapshot_branch_name # Return it so it can potentially be used/deleted - print(f"AICodeAgentCog: Successfully created snapshot branch: {snapshot_branch_name}") + # Switch back to the original branch + success, output = await self._run_git_command( + f"git checkout {current_branch_name}" + ) + if not success: + print( + f"AICodeAgentCog: CRITICAL - Failed to switch back to original branch '{current_branch_name}' after snapshot. Current branch might be '{snapshot_branch_name}'. Manual intervention may be needed. Error: {output}" + ) + return snapshot_branch_name # Return it so it can potentially be used/deleted + + print( + f"AICodeAgentCog: Successfully created snapshot branch: {snapshot_branch_name}" + ) return snapshot_branch_name except Exception as e: print(f"AICodeAgentCog: Exception in _create_programmatic_snapshot: {e}") @@ -578,7 +701,9 @@ class AICodeAgentCog(commands.Cog): @commands.is_owner() async def list_snapshots_command(self, ctx: commands.Context): """Lists available programmatic Git snapshots created by the cog.""" - success, output = await self._run_git_command('git branch --list "snapshot_cog_*"') + success, output = await self._run_git_command( + 'git branch --list "snapshot_cog_*"' + ) if success: if output: await ctx.send(f"Available snapshots:\n```\n{output}\n```") @@ -589,7 +714,9 @@ class AICodeAgentCog(commands.Cog): @commands.command(name="codeagent_revert_to_snapshot") @commands.is_owner() - async def revert_to_snapshot_command(self, ctx: commands.Context, snapshot_branch_name: str): + async def revert_to_snapshot_command( + self, ctx: commands.Context, snapshot_branch_name: str + ): """Reverts the current branch to the state of a given snapshot branch.""" if not snapshot_branch_name.startswith("snapshot_cog_"): await ctx.send("Invalid snapshot name. Must start with 'snapshot_cog_'.") @@ -598,80 +725,118 @@ class AICodeAgentCog(commands.Cog): # Check if snapshot branch exists success, branches_output = await self._run_git_command("git branch --list") # Normalize branches_output for reliable checking - existing_branches = [b.strip().lstrip('* ') for b in branches_output.splitlines()] + existing_branches = [ + b.strip().lstrip("* ") for b in branches_output.splitlines() + ] if not success or snapshot_branch_name not in existing_branches: await ctx.send(f"Snapshot branch '{snapshot_branch_name}' not found.") return - - await ctx.send(f"Attempting to revert current branch to snapshot '{snapshot_branch_name}'...") - success, current_branch = await self._run_git_command("git rev-parse --abbrev-ref HEAD") + + await ctx.send( + f"Attempting to revert current branch to snapshot '{snapshot_branch_name}'..." + ) + success, current_branch = await self._run_git_command( + "git rev-parse --abbrev-ref HEAD" + ) if not success: - await ctx.send(f"Failed to determine current branch before revert: {current_branch}") + await ctx.send( + f"Failed to determine current branch before revert: {current_branch}" + ) return current_branch = current_branch.strip() - success, output = await self._run_git_command(f"git reset --hard {snapshot_branch_name}") + success, output = await self._run_git_command( + f"git reset --hard {snapshot_branch_name}" + ) if success: - await ctx.send(f"Successfully reverted current branch ('{current_branch}') to snapshot '{snapshot_branch_name}'.\nOutput:\n```\n{output}\n```") + await ctx.send( + f"Successfully reverted current branch ('{current_branch}') to snapshot '{snapshot_branch_name}'.\nOutput:\n```\n{output}\n```" + ) else: - await ctx.send(f"Error reverting to snapshot '{snapshot_branch_name}':\n```\n{output}\n```") + await ctx.send( + f"Error reverting to snapshot '{snapshot_branch_name}':\n```\n{output}\n```" + ) @commands.command(name="codeagent_delete_snapshot") @commands.is_owner() - async def delete_snapshot_command(self, ctx: commands.Context, snapshot_branch_name: str): + async def delete_snapshot_command( + self, ctx: commands.Context, snapshot_branch_name: str + ): """Deletes a programmatic Git snapshot branch.""" if not snapshot_branch_name.startswith("snapshot_cog_"): await ctx.send("Invalid snapshot name. Must start with 'snapshot_cog_'.") return success, branches_output = await self._run_git_command("git branch --list") - existing_branches = [b.strip().lstrip('* ') for b in branches_output.splitlines()] + existing_branches = [ + b.strip().lstrip("* ") for b in branches_output.splitlines() + ] if not success or snapshot_branch_name not in existing_branches: await ctx.send(f"Snapshot branch '{snapshot_branch_name}' not found.") return - success, current_branch_name_str = await self._run_git_command("git rev-parse --abbrev-ref HEAD") + success, current_branch_name_str = await self._run_git_command( + "git rev-parse --abbrev-ref HEAD" + ) if success and current_branch_name_str.strip() == snapshot_branch_name: - await ctx.send(f"Cannot delete snapshot branch '{snapshot_branch_name}' as it is the current branch. Please checkout to a different branch first.") + await ctx.send( + f"Cannot delete snapshot branch '{snapshot_branch_name}' as it is the current branch. Please checkout to a different branch first." + ) return elif not success: - await ctx.send(f"Could not determine current branch. Deletion aborted for safety. Error: {current_branch_name_str}") - return + await ctx.send( + f"Could not determine current branch. Deletion aborted for safety. Error: {current_branch_name_str}" + ) + return - await ctx.send(f"Attempting to delete snapshot branch '{snapshot_branch_name}'...") - success, output = await self._run_git_command(f"git branch -D {snapshot_branch_name}") + await ctx.send( + f"Attempting to delete snapshot branch '{snapshot_branch_name}'..." + ) + success, output = await self._run_git_command( + f"git branch -D {snapshot_branch_name}" + ) if success: - await ctx.send(f"Successfully deleted snapshot branch '{snapshot_branch_name}'.\nOutput:\n```\n{output}\n```") + await ctx.send( + f"Successfully deleted snapshot branch '{snapshot_branch_name}'.\nOutput:\n```\n{output}\n```" + ) else: - await ctx.send(f"Error deleting snapshot branch '{snapshot_branch_name}':\n```\n{output}\n```") + await ctx.send( + f"Error deleting snapshot branch '{snapshot_branch_name}':\n```\n{output}\n```" + ) - def _get_conversation_history(self, user_id: int) -> List[google_genai_types.Content]: + def _get_conversation_history( + self, user_id: int + ) -> List[google_genai_types.Content]: if user_id not in self.agent_conversations: self.agent_conversations[user_id] = [] return self.agent_conversations[user_id] - def _add_to_conversation_history(self, user_id: int, role: str, text_content: str, is_tool_response: bool = False): + def _add_to_conversation_history( + self, user_id: int, role: str, text_content: str, is_tool_response: bool = False + ): history = self._get_conversation_history(user_id) # For Vertex AI, 'function' role is used for tool responses, 'model' for AI text, 'user' for user text. # We'll adapt this slightly for our inline tools. # AI's raw response (potentially with tool call) -> model # Tool's output -> user (formatted as "ToolResponse: ...") # User's direct prompt -> user - + # For simplicity in our loop, we might treat tool responses as if they are from the 'user' # to guide the AI's next step, or use a specific format the AI understands. # The system prompt already guides the AI to expect "ToolResponse:" - + # Let's ensure content is always a list of parts for Vertex parts = [google_genai_types.Part(text=text_content)] history.append(google_genai_types.Content(role=role, parts=parts)) - + # Keep history to a reasonable length (e.g., last 20 turns, or token-based limit later) max_history_items = 20 if len(history) > max_history_items: self.agent_conversations[user_id] = history[-max_history_items:] - async def _parse_and_execute_tool_call(self, ctx: commands.Context, ai_response_text: str) -> Tuple[str, Optional[str]]: + async def _parse_and_execute_tool_call( + self, ctx: commands.Context, ai_response_text: str + ) -> Tuple[str, Optional[str]]: """ Parses AI response for an XML tool call, executes it, and returns the tool's output string. Returns a tuple: (status: str, data: Optional[str]). @@ -683,11 +848,22 @@ class AICodeAgentCog(commands.Cog): # Remove potential markdown ```xml ... ``` wrapper if clean_ai_response_text.startswith("```"): # More robustly remove potential ```xml ... ``` or just ``` ... ``` - clean_ai_response_text = re.sub(r"^```(?:xml)?\s*\n?", "", clean_ai_response_text, flags=re.MULTILINE) - clean_ai_response_text = re.sub(r"\n?```$", "", clean_ai_response_text, flags=re.MULTILINE) + clean_ai_response_text = re.sub( + r"^```(?:xml)?\s*\n?", + "", + clean_ai_response_text, + flags=re.MULTILINE, + ) + clean_ai_response_text = re.sub( + r"\n?```$", "", clean_ai_response_text, flags=re.MULTILINE + ) clean_ai_response_text = clean_ai_response_text.strip() - if not clean_ai_response_text or not clean_ai_response_text.startswith("<") or not clean_ai_response_text.endswith(">"): + if ( + not clean_ai_response_text + or not clean_ai_response_text.startswith("<") + or not clean_ai_response_text.endswith(">") + ): return "NO_TOOL", ai_response_text root = ET.fromstring(clean_ai_response_text) @@ -697,205 +873,372 @@ class AICodeAgentCog(commands.Cog): if tool_name == "ReadFile": file_path = parameters.get("path") if not file_path: - return "TOOL_OUTPUT", "ToolResponse: Error\n---\nReadFile: Missing 'path' parameter." + return ( + "TOOL_OUTPUT", + "ToolResponse: Error\n---\nReadFile: Missing 'path' parameter.", + ) tool_output = await self._execute_tool_read_file(file_path) - return "TOOL_OUTPUT", f"ToolResponse: ReadFile\nPath: {file_path}\n---\n{tool_output}" + return ( + "TOOL_OUTPUT", + f"ToolResponse: ReadFile\nPath: {file_path}\n---\n{tool_output}", + ) elif tool_name == "WriteFile": file_path = parameters.get("path") - content = parameters.get("content") # CDATA content will be in .text + content = parameters.get("content") # CDATA content will be in .text if file_path is None or content is None: - return "TOOL_OUTPUT", "ToolResponse: Error\n---\nWriteFile: Missing 'path' or 'content' parameter." - + return ( + "TOOL_OUTPUT", + "ToolResponse: Error\n---\nWriteFile: Missing 'path' or 'content' parameter.", + ) + snapshot_branch = await self._create_programmatic_snapshot() if not snapshot_branch: - return "TOOL_OUTPUT", "ToolResponse: SystemError\n---\nFailed to create project snapshot. WriteFile operation aborted." + return ( + "TOOL_OUTPUT", + "ToolResponse: SystemError\n---\nFailed to create project snapshot. WriteFile operation aborted.", + ) else: - await ctx.send(f"AICodeAgent: [Info] Created snapshot: {snapshot_branch} before writing to {file_path}") - tool_output = await self._execute_tool_write_file(file_path, content) - return "TOOL_OUTPUT", f"ToolResponse: WriteFile\nPath: {file_path}\n---\n{tool_output}" + await ctx.send( + f"AICodeAgent: [Info] Created snapshot: {snapshot_branch} before writing to {file_path}" + ) + tool_output = await self._execute_tool_write_file( + file_path, content + ) + return ( + "TOOL_OUTPUT", + f"ToolResponse: WriteFile\nPath: {file_path}\n---\n{tool_output}", + ) elif tool_name == "ApplyDiff": file_path = parameters.get("path") - diff_block = parameters.get("diff_block") # CDATA content + diff_block = parameters.get("diff_block") # CDATA content if file_path is None or diff_block is None: - return "TOOL_OUTPUT", "ToolResponse: Error\n---\nApplyDiff: Missing 'path' or 'diff_block' parameter." + return ( + "TOOL_OUTPUT", + "ToolResponse: Error\n---\nApplyDiff: Missing 'path' or 'diff_block' parameter.", + ) snapshot_branch = await self._create_programmatic_snapshot() if not snapshot_branch: - return "TOOL_OUTPUT", "ToolResponse: SystemError\n---\nFailed to create project snapshot. ApplyDiff operation aborted." + return ( + "TOOL_OUTPUT", + "ToolResponse: SystemError\n---\nFailed to create project snapshot. ApplyDiff operation aborted.", + ) else: - await ctx.send(f"AICodeAgent: [Info] Created snapshot: {snapshot_branch} before applying diff to {file_path}") - tool_output = await self._execute_tool_apply_diff(file_path, diff_block) - return "TOOL_OUTPUT", f"ToolResponse: ApplyDiff\nPath: {file_path}\n---\n{tool_output}" + await ctx.send( + f"AICodeAgent: [Info] Created snapshot: {snapshot_branch} before applying diff to {file_path}" + ) + tool_output = await self._execute_tool_apply_diff( + file_path, diff_block + ) + return ( + "TOOL_OUTPUT", + f"ToolResponse: ApplyDiff\nPath: {file_path}\n---\n{tool_output}", + ) elif tool_name == "ExecuteCommand": command_str = parameters.get("command") if not command_str: - return "TOOL_OUTPUT", "ToolResponse: Error\n---\nExecuteCommand: Missing 'command' parameter." + return ( + "TOOL_OUTPUT", + "ToolResponse: Error\n---\nExecuteCommand: Missing 'command' parameter.", + ) user_id = ctx.author.id - tool_output = await self._execute_tool_execute_command(command_str, user_id) - return "TOOL_OUTPUT", f"ToolResponse: ExecuteCommand\nCommand: {command_str}\n---\n{tool_output}" + tool_output = await self._execute_tool_execute_command( + command_str, user_id + ) + return ( + "TOOL_OUTPUT", + f"ToolResponse: ExecuteCommand\nCommand: {command_str}\n---\n{tool_output}", + ) elif tool_name == "ListFiles": file_path = parameters.get("path") recursive_str = parameters.get("recursive") - recursive = recursive_str.lower() == 'true' if recursive_str else False - - filter_extensions = parameters.get("filter_extensions") # Optional: comma-separated string - filter_regex_name = parameters.get("filter_regex_name") # Optional: regex string + recursive = recursive_str.lower() == "true" if recursive_str else False + + filter_extensions = parameters.get( + "filter_extensions" + ) # Optional: comma-separated string + filter_regex_name = parameters.get( + "filter_regex_name" + ) # Optional: regex string include_metadata_str = parameters.get("include_metadata") - include_metadata = include_metadata_str.lower() == 'true' if include_metadata_str else False + include_metadata = ( + include_metadata_str.lower() == "true" + if include_metadata_str + else False + ) if not file_path: - return "TOOL_OUTPUT", "ToolResponse: Error\n---\nListFiles: Missing 'path' parameter." - + return ( + "TOOL_OUTPUT", + "ToolResponse: Error\n---\nListFiles: Missing 'path' parameter.", + ) + tool_output = await self._execute_tool_list_files( file_path, recursive, filter_extensions=filter_extensions, filter_regex_name=filter_regex_name, - include_metadata=include_metadata + include_metadata=include_metadata, ) - + params_summary = [f"Recursive: {recursive}"] - if filter_extensions: params_summary.append(f"Extensions: {filter_extensions}") - if filter_regex_name: params_summary.append(f"RegexName: {filter_regex_name}") + if filter_extensions: + params_summary.append(f"Extensions: {filter_extensions}") + if filter_regex_name: + params_summary.append(f"RegexName: {filter_regex_name}") params_summary.append(f"Metadata: {include_metadata}") - response_message = f"ToolResponse: ListFiles\nPath: {file_path}\n" + "\n".join(params_summary) + response_message = ( + f"ToolResponse: ListFiles\nPath: {file_path}\n" + + "\n".join(params_summary) + ) response_message += f"\n---\n{tool_output}" return "TOOL_OUTPUT", response_message elif tool_name == "WebSearch": query_str = parameters.get("query") if not query_str: - return "TOOL_OUTPUT", "ToolResponse: Error\n---\nWebSearch: Missing 'query' parameter." + return ( + "TOOL_OUTPUT", + "ToolResponse: Error\n---\nWebSearch: Missing 'query' parameter.", + ) tool_output = await self._execute_tool_web_search(query_str) - return "TOOL_OUTPUT", f"ToolResponse: WebSearch\nQuery: {query_str}\n---\n{tool_output}" - - elif tool_name == "TaskComplete": - message = parameters.get("message", "Task marked as complete by AI.") # Default if message tag is missing or empty - return "TASK_COMPLETE", message if message is not None else "Task marked as complete by AI." + return ( + "TOOL_OUTPUT", + f"ToolResponse: WebSearch\nQuery: {query_str}\n---\n{tool_output}", + ) + elif tool_name == "TaskComplete": + message = parameters.get( + "message", "Task marked as complete by AI." + ) # Default if message tag is missing or empty + return "TASK_COMPLETE", ( + message if message is not None else "Task marked as complete by AI." + ) elif tool_name == "LintFile": file_path = parameters.get("path") - linter = parameters.get("linter", "pylint") # Default to pylint + linter = parameters.get("linter", "pylint") # Default to pylint if not file_path: - return "TOOL_OUTPUT", "ToolResponse: Error\n---\nLintFile: Missing 'path' parameter." + return ( + "TOOL_OUTPUT", + "ToolResponse: Error\n---\nLintFile: Missing 'path' parameter.", + ) tool_output = await self._execute_tool_lint_file(file_path, linter) - return "TOOL_OUTPUT", f"ToolResponse: LintFile\nPath: {file_path}\nLinter: {linter}\n---\n{tool_output}" + return ( + "TOOL_OUTPUT", + f"ToolResponse: LintFile\nPath: {file_path}\nLinter: {linter}\n---\n{tool_output}", + ) elif tool_name == "GetCodeStructure": file_path = parameters.get("path") if not file_path: - return "TOOL_OUTPUT", "ToolResponse: Error\n---\nGetCodeStructure: Missing 'path' parameter." + return ( + "TOOL_OUTPUT", + "ToolResponse: Error\n---\nGetCodeStructure: Missing 'path' parameter.", + ) tool_output = await self._execute_tool_get_code_structure(file_path) - return "TOOL_OUTPUT", f"ToolResponse: GetCodeStructure\nPath: {file_path}\n---\n{tool_output}" + return ( + "TOOL_OUTPUT", + f"ToolResponse: GetCodeStructure\nPath: {file_path}\n---\n{tool_output}", + ) elif tool_name == "FindSymbolDefinition": symbol_name = parameters.get("symbol_name") - search_path = parameters.get("search_path", ".") # Default to current dir (project root) - file_pattern = parameters.get("file_pattern", "*.py") # Default to Python files + search_path = parameters.get( + "search_path", "." + ) # Default to current dir (project root) + file_pattern = parameters.get( + "file_pattern", "*.py" + ) # Default to Python files if not symbol_name: - return "TOOL_OUTPUT", "ToolResponse: Error\n---\nFindSymbolDefinition: Missing 'symbol_name' parameter." - tool_output = await self._execute_tool_find_symbol_definition(symbol_name, search_path, file_pattern) - return "TOOL_OUTPUT", f"ToolResponse: FindSymbolDefinition\nSymbol: {symbol_name}\nPath: {search_path}\nPattern: {file_pattern}\n---\n{tool_output}" + return ( + "TOOL_OUTPUT", + "ToolResponse: Error\n---\nFindSymbolDefinition: Missing 'symbol_name' parameter.", + ) + tool_output = await self._execute_tool_find_symbol_definition( + symbol_name, search_path, file_pattern + ) + return ( + "TOOL_OUTPUT", + f"ToolResponse: FindSymbolDefinition\nSymbol: {symbol_name}\nPath: {search_path}\nPattern: {file_pattern}\n---\n{tool_output}", + ) elif tool_name == "ManageCog": action = parameters.get("action") cog_name = parameters.get("cog_name") if not action: - return "TOOL_OUTPUT", "ToolResponse: Error\n---\nManageCog: Missing 'action' parameter." + return ( + "TOOL_OUTPUT", + "ToolResponse: Error\n---\nManageCog: Missing 'action' parameter.", + ) if action in ["load", "unload", "reload"] and not cog_name: - return "TOOL_OUTPUT", f"ToolResponse: Error\n---\nManageCog: Missing 'cog_name' for action '{action}'." + return ( + "TOOL_OUTPUT", + f"ToolResponse: Error\n---\nManageCog: Missing 'cog_name' for action '{action}'.", + ) tool_output = await self._execute_tool_manage_cog(action, cog_name) - return "TOOL_OUTPUT", f"ToolResponse: ManageCog\nAction: {action}\nCog: {cog_name or 'N/A'}\n---\n{tool_output}" + return ( + "TOOL_OUTPUT", + f"ToolResponse: ManageCog\nAction: {action}\nCog: {cog_name or 'N/A'}\n---\n{tool_output}", + ) elif tool_name == "RunTests": test_path_or_pattern = parameters.get("test_path_or_pattern") framework = parameters.get("framework", "pytest") if not test_path_or_pattern: - return "TOOL_OUTPUT", "ToolResponse: Error\n---\nRunTests: Missing 'test_path_or_pattern' parameter." - tool_output = await self._execute_tool_run_tests(test_path_or_pattern, framework) - return "TOOL_OUTPUT", f"ToolResponse: RunTests\nTarget: {test_path_or_pattern}\nFramework: {framework}\n---\n{tool_output}" + return ( + "TOOL_OUTPUT", + "ToolResponse: Error\n---\nRunTests: Missing 'test_path_or_pattern' parameter.", + ) + tool_output = await self._execute_tool_run_tests( + test_path_or_pattern, framework + ) + return ( + "TOOL_OUTPUT", + f"ToolResponse: RunTests\nTarget: {test_path_or_pattern}\nFramework: {framework}\n---\n{tool_output}", + ) elif tool_name == "PythonREPL": code_snippet = parameters.get("code_snippet") - session_id_param = parameters.get("session_id") # AI might suggest one + session_id_param = parameters.get("session_id") # AI might suggest one user_id = ctx.author.id # Use user_id for a persistent session if AI doesn't specify one, or combine them. # For simplicity, let's use user_id as the primary key for REPL sessions for now. # If AI provides session_id, it could be a sub-context within that user's REPL. # Let's make session_id for the tool map to user_id for now. - repl_session_key = str(user_id) # Or incorporate session_id_param if needed + repl_session_key = str( + user_id + ) # Or incorporate session_id_param if needed if not code_snippet: - return "TOOL_OUTPUT", "ToolResponse: Error\n---\nPythonREPL: Missing 'code_snippet' parameter." - tool_output = await self._execute_tool_python_repl(code_snippet, repl_session_key) - return "TOOL_OUTPUT", f"ToolResponse: PythonREPL\nSession: {repl_session_key}\n---\n{tool_output}" + return ( + "TOOL_OUTPUT", + "ToolResponse: Error\n---\nPythonREPL: Missing 'code_snippet' parameter.", + ) + tool_output = await self._execute_tool_python_repl( + code_snippet, repl_session_key + ) + return ( + "TOOL_OUTPUT", + f"ToolResponse: PythonREPL\nSession: {repl_session_key}\n---\n{tool_output}", + ) elif tool_name == "CreateNamedSnapshot": snapshot_name = parameters.get("snapshot_name") - description = parameters.get("description") # Optional + description = parameters.get("description") # Optional if not snapshot_name: - return "TOOL_OUTPUT", "ToolResponse: Error\n---\nCreateNamedSnapshot: Missing 'snapshot_name' parameter." - tool_output = await self._execute_tool_create_named_snapshot(snapshot_name, description) - return "TOOL_OUTPUT", f"ToolResponse: CreateNamedSnapshot\nName: {snapshot_name}\n---\n{tool_output}" + return ( + "TOOL_OUTPUT", + "ToolResponse: Error\n---\nCreateNamedSnapshot: Missing 'snapshot_name' parameter.", + ) + tool_output = await self._execute_tool_create_named_snapshot( + snapshot_name, description + ) + return ( + "TOOL_OUTPUT", + f"ToolResponse: CreateNamedSnapshot\nName: {snapshot_name}\n---\n{tool_output}", + ) elif tool_name == "CompareSnapshots": base_ref = parameters.get("base_ref") compare_ref = parameters.get("compare_ref") if not base_ref or not compare_ref: - return "TOOL_OUTPUT", "ToolResponse: Error\n---\nCompareSnapshots: Missing 'base_ref' or 'compare_ref' parameter." - tool_output = await self._execute_tool_compare_snapshots(base_ref, compare_ref) - return "TOOL_OUTPUT", f"ToolResponse: CompareSnapshots\nBase: {base_ref}\nCompare: {compare_ref}\n---\n{tool_output}" + return ( + "TOOL_OUTPUT", + "ToolResponse: Error\n---\nCompareSnapshots: Missing 'base_ref' or 'compare_ref' parameter.", + ) + tool_output = await self._execute_tool_compare_snapshots( + base_ref, compare_ref + ) + return ( + "TOOL_OUTPUT", + f"ToolResponse: CompareSnapshots\nBase: {base_ref}\nCompare: {compare_ref}\n---\n{tool_output}", + ) elif tool_name == "DryRunApplyDiff": file_path = parameters.get("path") diff_block = parameters.get("diff_block") if not file_path or not diff_block: - return "TOOL_OUTPUT", "ToolResponse: Error\n---\nDryRunApplyDiff: Missing 'path' or 'diff_block' parameter." - tool_output = await self._execute_tool_dry_run_apply_diff(file_path, diff_block) - return "TOOL_OUTPUT", f"ToolResponse: DryRunApplyDiff\nPath: {file_path}\n---\n{tool_output}" + return ( + "TOOL_OUTPUT", + "ToolResponse: Error\n---\nDryRunApplyDiff: Missing 'path' or 'diff_block' parameter.", + ) + tool_output = await self._execute_tool_dry_run_apply_diff( + file_path, diff_block + ) + return ( + "TOOL_OUTPUT", + f"ToolResponse: DryRunApplyDiff\nPath: {file_path}\n---\n{tool_output}", + ) elif tool_name == "DryRunWriteFile": file_path = parameters.get("path") if not file_path: - return "TOOL_OUTPUT", "ToolResponse: Error\n---\nDryRunWriteFile: Missing 'path' parameter." + return ( + "TOOL_OUTPUT", + "ToolResponse: Error\n---\nDryRunWriteFile: Missing 'path' parameter.", + ) tool_output = await self._execute_tool_dry_run_write_file(file_path) - return "TOOL_OUTPUT", f"ToolResponse: DryRunWriteFile\nPath: {file_path}\n---\n{tool_output}" + return ( + "TOOL_OUTPUT", + f"ToolResponse: DryRunWriteFile\nPath: {file_path}\n---\n{tool_output}", + ) elif tool_name == "ReadWebPageRaw": url_param = parameters.get("url") if not url_param: - return "TOOL_OUTPUT", "ToolResponse: Error\n---\nReadWebPageRaw: Missing 'url' parameter." + return ( + "TOOL_OUTPUT", + "ToolResponse: Error\n---\nReadWebPageRaw: Missing 'url' parameter.", + ) tool_output = await self._execute_tool_read_web_page_raw(url_param) - return "TOOL_OUTPUT", f"ToolResponse: ReadWebPageRaw\nURL: {url_param}\n---\n{tool_output}" + return ( + "TOOL_OUTPUT", + f"ToolResponse: ReadWebPageRaw\nURL: {url_param}\n---\n{tool_output}", + ) else: # Unknown tool name found in XML - return "TOOL_OUTPUT", f"ToolResponse: Error\n---\nUnknown tool: {tool_name} in XML: {clean_ai_response_text[:200]}" + return ( + "TOOL_OUTPUT", + f"ToolResponse: Error\n---\nUnknown tool: {tool_name} in XML: {clean_ai_response_text[:200]}", + ) except ET.ParseError: # Not valid XML # print(f"AICodeAgentCog: XML ParseError for response: {ai_response_text[:200]}") # Debugging return "NO_TOOL", ai_response_text - except Exception as e: # Catch any other unexpected errors during parsing/dispatch - print(f"AICodeAgentCog: Unexpected error in _parse_and_execute_tool_call: {type(e).__name__} - {e} for response {ai_response_text[:200]}") + except ( + Exception + ) as e: # Catch any other unexpected errors during parsing/dispatch + print( + f"AICodeAgentCog: Unexpected error in _parse_and_execute_tool_call: {type(e).__name__} - {e} for response {ai_response_text[:200]}" + ) # import traceback # traceback.print_exc() # For more detailed debugging if needed - return "TOOL_OUTPUT", f"ToolResponse: SystemError\n---\nError processing tool call: {type(e).__name__} - {e}" + return ( + "TOOL_OUTPUT", + f"ToolResponse: SystemError\n---\nError processing tool call: {type(e).__name__} - {e}", + ) # --- Tool Execution Methods --- # (Implementations for _execute_tool_... methods remain the same) # This comment might be outdated after this change - async def _execute_tool_read_file(self, path: str, - start_line: Optional[int] = None, - end_line: Optional[int] = None, - peek_first_n_lines: Optional[int] = None, - peek_last_n_lines: Optional[int] = None) -> str: - print(f"AICodeAgentCog: _execute_tool_read_file for path: {path}, start: {start_line}, end: {end_line}, peek_first: {peek_first_n_lines}, peek_last: {peek_last_n_lines}") + async def _execute_tool_read_file( + self, + path: str, + start_line: Optional[int] = None, + end_line: Optional[int] = None, + peek_first_n_lines: Optional[int] = None, + peek_last_n_lines: Optional[int] = None, + ) -> str: + print( + f"AICodeAgentCog: _execute_tool_read_file for path: {path}, start: {start_line}, end: {end_line}, peek_first: {peek_first_n_lines}, peek_last: {peek_last_n_lines}" + ) try: if not os.path.exists(path): return f"Error: File not found at '{path}'" @@ -905,25 +1248,33 @@ class AICodeAgentCog(commands.Cog): # Determine the operation based on parameters # Priority: peek_first > peek_last > start_line/end_line > full_read if peek_first_n_lines is not None and peek_first_n_lines > 0: - with open(path, 'r', encoding='utf-8', errors='replace') as f: + with open(path, "r", encoding="utf-8", errors="replace") as f: lines = [] for i, line in enumerate(f): if i >= peek_first_n_lines: break lines.append(line) - return "".join(lines) if lines else "File is empty or shorter than peek_first_n_lines." + return ( + "".join(lines) + if lines + else "File is empty or shorter than peek_first_n_lines." + ) elif peek_last_n_lines is not None and peek_last_n_lines > 0: # This is inefficient for large files, but a simple approach for now. # A more efficient way would be to seek from the end and read backwards, # or use a deque with a maxlen. - with open(path, 'r', encoding='utf-8', errors='replace') as f: - lines = f.readlines() # Reads all lines into memory - return "".join(lines[-peek_last_n_lines:]) if lines else "File is empty." - elif start_line is not None and start_line > 0: # start_line is 1-based - with open(path, 'r', encoding='utf-8', errors='replace') as f: - lines = f.readlines() # Reads all lines - start_idx = start_line - 1 # Convert to 0-based index - + with open(path, "r", encoding="utf-8", errors="replace") as f: + lines = f.readlines() # Reads all lines into memory + return ( + "".join(lines[-peek_last_n_lines:]) + if lines + else "File is empty." + ) + elif start_line is not None and start_line > 0: # start_line is 1-based + with open(path, "r", encoding="utf-8", errors="replace") as f: + lines = f.readlines() # Reads all lines + start_idx = start_line - 1 # Convert to 0-based index + if start_idx >= len(lines): return f"Error: start_line ({start_line}) is beyond the end of the file ({len(lines)} lines)." @@ -931,12 +1282,12 @@ class AICodeAgentCog(commands.Cog): # end_line is inclusive, so slice up to end_line (0-based end_idx = end_line) end_idx = end_line if end_idx < start_idx: - return f"Error: end_line ({end_line}) cannot be before start_line ({start_line})." - return "".join(lines[start_idx:min(end_idx, len(lines))]) - else: # Read from start_line to the end of the file + return f"Error: end_line ({end_line}) cannot be before start_line ({start_line})." + return "".join(lines[start_idx : min(end_idx, len(lines))]) + else: # Read from start_line to the end of the file return "".join(lines[start_idx:]) - else: # Default: read whole file - with open(path, 'r', encoding='utf-8', errors='replace') as f: + else: # Default: read whole file + with open(path, "r", encoding="utf-8", errors="replace") as f: content = f.read() return content except Exception as e: @@ -950,23 +1301,27 @@ class AICodeAgentCog(commands.Cog): # requested_path = os.path.abspath(os.path.join(base_dir, path)) # if not requested_path.startswith(base_dir): # return "Error: File path is outside the allowed project directory." - os.makedirs(os.path.dirname(path) or '.', exist_ok=True) # Ensure directory exists - with open(path, 'w', encoding='utf-8') as f: + os.makedirs( + os.path.dirname(path) or ".", exist_ok=True + ) # Ensure directory exists + with open(path, "w", encoding="utf-8") as f: f.write(content) return f"Successfully wrote to file '{path}'." except Exception as e: return f"Error writing to file '{path}': {type(e).__name__} - {e}" async def _execute_tool_apply_diff(self, path: str, diff_block: str) -> str: - print(f"AICodeAgentCog: Attempting _execute_tool_apply_diff (Search/Replace) for path: {path}") + print( + f"AICodeAgentCog: Attempting _execute_tool_apply_diff (Search/Replace) for path: {path}" + ) if not os.path.exists(path): return f"Error: File not found at '{path}' for applying diff." if os.path.isdir(path): return f"Error: Path '{path}' is a directory, cannot apply diff." try: - with open(path, 'r', encoding='utf-8') as f: - file_lines = f.readlines() # Read all lines, keeping newlines + with open(path, "r", encoding="utf-8") as f: + file_lines = f.readlines() # Read all lines, keeping newlines except Exception as e: return f"Error reading file '{path}': {type(e).__name__} - {e}" @@ -980,97 +1335,129 @@ class AICodeAgentCog(commands.Cog): r":start_line:(\d+)\s*\n" r"-------\s*\n" r"(.*?)" # search_content (group 2) - r"\\=======\s*\n" # Changed to match AI prompt (\=======) + r"\\=======\s*\n" # Changed to match AI prompt (\=======) r"(.*?)" # replace_content (group 3) r"\[REPLACE_BLOCK_END\]\s*$", # Changed to match AI prompt - re.DOTALL | re.MULTILINE + re.DOTALL | re.MULTILINE, ) operations = [] last_match_end = 0 for match in operation_pattern.finditer(diff_block): - if match.start() < last_match_end: # Should not happen with finditer + if match.start() < last_match_end: # Should not happen with finditer continue # Check for content between operations that is not part of a valid operation - if diff_block[last_match_end:match.start()].strip(): + if diff_block[last_match_end : match.start()].strip(): return f"Error: Malformed diff_block. Unexpected content between operations near character {last_match_end}." - + start_line = int(match.group(1)) # Split search/replace content carefully to respect original newlines # The (.*?) captures content up to the next \n======= or \n>>>>>>> # We need to remove the structural newlines from the capture groups if they are included by DOTALL - + # Use regex capture groups directly for search and replace content search_c = match.group(2) replace_c = match.group(3) - - operations.append({ - "start_line": start_line, - "search": search_c, - "replace": replace_c, - "original_block_for_error": match.group(0)[:200] # For error reporting - }) + + operations.append( + { + "start_line": start_line, + "search": search_c, + "replace": replace_c, + "original_block_for_error": match.group(0)[ + :200 + ], # For error reporting + } + ) last_match_end = match.end() if not operations: - if diff_block.strip(): # If diff_block had content but no operations parsed + if ( + diff_block.strip() + ): # If diff_block had content but no operations parsed return "Error: diff_block provided but no valid SEARCH/REPLACE operations found." - else: # Empty diff_block + else: # Empty diff_block return "Error: Empty diff_block provided." - + # Check if there's any trailing malformed content after the last valid operation if diff_block[last_match_end:].strip(): return f"Error: Malformed diff_block. Unexpected trailing content after last operation, near character {last_match_end}." # Sort operations by start_line in descending order to apply changes from bottom-up # This helps manage line number shifts correctly if replacements change the number of lines. - operations.sort(key=lambda op: op['start_line'], reverse=True) + operations.sort(key=lambda op: op["start_line"], reverse=True) applied_count = 0 for op in operations: - start_line_0_indexed = op['start_line'] - 1 - - # Ensure search content exactly matches, including its own newlines - search_content_lines = op['search'].splitlines(True) # Keep newlines for comparison - if not search_content_lines and op['search']: # Search content is not empty but has no newlines (e.g. "foo") - search_content_lines = [op['search']] - elif not op['search']: # Empty search string - search_content_lines = [""] + start_line_0_indexed = op["start_line"] - 1 + # Ensure search content exactly matches, including its own newlines + search_content_lines = op["search"].splitlines( + True + ) # Keep newlines for comparison + if ( + not search_content_lines and op["search"] + ): # Search content is not empty but has no newlines (e.g. "foo") + search_content_lines = [op["search"]] + elif not op["search"]: # Empty search string + search_content_lines = [""] num_search_lines = len(search_content_lines) - if start_line_0_indexed < 0 or (start_line_0_indexed + num_search_lines > len(file_lines) and num_search_lines > 0) or (start_line_0_indexed > len(file_lines) and num_search_lines == 0 and op['search'] == ""): # check for empty search at end of file - # Special case: if search is empty and start_line is one past the end, it's an append - if not (op['search'] == "" and start_line_0_indexed == len(file_lines)): + if ( + start_line_0_indexed < 0 + or ( + start_line_0_indexed + num_search_lines > len(file_lines) + and num_search_lines > 0 + ) + or ( + start_line_0_indexed > len(file_lines) + and num_search_lines == 0 + and op["search"] == "" + ) + ): # check for empty search at end of file + # Special case: if search is empty and start_line is one past the end, it's an append + if not ( + op["search"] == "" and start_line_0_indexed == len(file_lines) + ): return f"Error: Operation for line {op['start_line']} (0-indexed {start_line_0_indexed}) with {num_search_lines} search lines is out of file bounds (total lines: {len(file_lines)}). Block: {op['original_block_for_error']}..." - - actual_file_segment_lines = file_lines[start_line_0_indexed : start_line_0_indexed + num_search_lines] + + actual_file_segment_lines = file_lines[ + start_line_0_indexed : start_line_0_indexed + num_search_lines + ] actual_file_segment_content = "".join(actual_file_segment_lines) # Exact match, including newlines. - if actual_file_segment_content == op['search']: - replace_content_lines = op['replace'].splitlines(True) - if not replace_content_lines and op['replace']: # Replace content is not empty but has no newlines - replace_content_lines = [op['replace']] - - file_lines[start_line_0_indexed : start_line_0_indexed + num_search_lines] = replace_content_lines + if actual_file_segment_content == op["search"]: + replace_content_lines = op["replace"].splitlines(True) + if ( + not replace_content_lines and op["replace"] + ): # Replace content is not empty but has no newlines + replace_content_lines = [op["replace"]] + + file_lines[ + start_line_0_indexed : start_line_0_indexed + num_search_lines + ] = replace_content_lines applied_count += 1 else: # For better error reporting: - expected_repr = repr(op['search']) + expected_repr = repr(op["search"]) found_repr = repr(actual_file_segment_content) max_len = 100 - if len(expected_repr) > max_len: expected_repr = expected_repr[:max_len] + "..." - if len(found_repr) > max_len: found_repr = found_repr[:max_len] + "..." - return (f"Error: Search content mismatch at line {op['start_line']}.\n" - f"Expected: {expected_repr}\n" - f"Found : {found_repr}\n" - f"Original Block Hint: {op['original_block_for_error']}...") - + if len(expected_repr) > max_len: + expected_repr = expected_repr[:max_len] + "..." + if len(found_repr) > max_len: + found_repr = found_repr[:max_len] + "..." + return ( + f"Error: Search content mismatch at line {op['start_line']}.\n" + f"Expected: {expected_repr}\n" + f"Found : {found_repr}\n" + f"Original Block Hint: {op['original_block_for_error']}..." + ) + if applied_count == len(operations): try: - with open(path, 'w', encoding='utf-8') as f: + with open(path, "w", encoding="utf-8") as f: f.writelines(file_lines) return f"Successfully applied {applied_count} SEARCH/REPLACE operation(s) to '{path}'." except Exception as e: @@ -1085,9 +1472,11 @@ class AICodeAgentCog(commands.Cog): async def _execute_tool_execute_command(self, command: str, user_id: int) -> str: session = self.agent_shell_sessions[user_id] - cwd = session['cwd'] - env = session['env'] - print(f"AICodeAgentCog: Attempting _execute_tool_execute_command for user_id {user_id}: '{command}' in CWD: '{cwd}'") + cwd = session["cwd"] + env = session["env"] + print( + f"AICodeAgentCog: Attempting _execute_tool_execute_command for user_id {user_id}: '{command}' in CWD: '{cwd}'" + ) # Mirroring shell_command_cog.py's command allowance check allowed, reason = is_command_allowed_agent(command) @@ -1098,7 +1487,9 @@ class AICodeAgentCog(commands.Cog): timeout_seconds = 30.0 max_output_length = 1900 - def run_agent_subprocess_sync(cmd_str, current_cwd, current_env, cmd_timeout_secs): + def run_agent_subprocess_sync( + cmd_str, current_cwd, current_env, cmd_timeout_secs + ): try: proc = subprocess.Popen( cmd_str, @@ -1106,19 +1497,34 @@ class AICodeAgentCog(commands.Cog): cwd=current_cwd, env=current_env, stdout=subprocess.PIPE, - stderr=subprocess.PIPE + stderr=subprocess.PIPE, ) try: stdout, stderr = proc.communicate(timeout=cmd_timeout_secs) - return (stdout, stderr, proc.returncode, False) # stdout, stderr, rc, timed_out + return ( + stdout, + stderr, + proc.returncode, + False, + ) # stdout, stderr, rc, timed_out except subprocess.TimeoutExpired: proc.kill() # Communicate again to fetch any output after kill stdout, stderr = proc.communicate() - return (stdout, stderr, -1, True) # Using -1 for timeout rc, as in shell_command_cog + return ( + stdout, + stderr, + -1, + True, + ) # Using -1 for timeout rc, as in shell_command_cog except Exception as e: # Capture other exceptions during Popen or initial communicate - return (b"", str(e).encode('utf-8', errors='replace'), -2, False) # -2 for other errors + return ( + b"", + str(e).encode("utf-8", errors="replace"), + -2, + False, + ) # -2 for other errors try: # Execute the synchronous subprocess logic in a separate thread @@ -1128,83 +1534,129 @@ class AICodeAgentCog(commands.Cog): # Update session working directory if 'cd' command was used and it was successful # This logic is from the previous iteration and is similar to shell_command_cog's attempt - if command.strip().startswith('cd ') and returncode == 0: - new_dir_arg_str = command.strip()[len("cd "):].strip() + if command.strip().startswith("cd ") and returncode == 0: + new_dir_arg_str = command.strip()[len("cd ") :].strip() potential_new_cwd = None # Handle 'cd' with no arguments (e.g. 'cd' or 'cd ~') - typically goes to home - if not new_dir_arg_str or new_dir_arg_str == '~' or new_dir_arg_str == '$HOME': - potential_new_cwd = os.path.expanduser('~') - elif new_dir_arg_str == '-': + if ( + not new_dir_arg_str + or new_dir_arg_str == "~" + or new_dir_arg_str == "$HOME" + ): + potential_new_cwd = os.path.expanduser("~") + elif new_dir_arg_str == "-": # 'cd -' (previous directory) is hard to track reliably without more state, # so we won't update cwd for it, similar to shell_command_cog's limitations. - print(f"AICodeAgentCog: 'cd -' used by user_id {user_id}. CWD tracking will not update for this command.") + print( + f"AICodeAgentCog: 'cd -' used by user_id {user_id}. CWD tracking will not update for this command." + ) else: # For 'cd ' temp_arg = new_dir_arg_str # Remove quotes if present - if (temp_arg.startswith('"') and temp_arg.endswith('"')) or \ - (temp_arg.startswith("'") and temp_arg.endswith("'")): + if (temp_arg.startswith('"') and temp_arg.endswith('"')) or ( + temp_arg.startswith("'") and temp_arg.endswith("'") + ): temp_arg = temp_arg[1:-1] - + if os.path.isabs(temp_arg): potential_new_cwd = temp_arg else: potential_new_cwd = os.path.abspath(os.path.join(cwd, temp_arg)) - - if potential_new_cwd and os.path.isdir(potential_new_cwd): - session['cwd'] = potential_new_cwd - print(f"AICodeAgentCog: Updated CWD for user_id {user_id} to: {session['cwd']}") - elif new_dir_arg_str and new_dir_arg_str != '-' and potential_new_cwd: - print(f"AICodeAgentCog: 'cd' command for user_id {user_id} seemed to succeed (rc=0), but CWD tracking logic could not confirm new path '{potential_new_cwd}' or it's not a directory. CWD remains '{session['cwd']}'. Command: '{command}'.") - elif new_dir_arg_str and new_dir_arg_str != '-': # if potential_new_cwd was None but arg was given - print(f"AICodeAgentCog: 'cd' command for user_id {user_id} with arg '{new_dir_arg_str}' succeeded (rc=0), but path resolution for CWD tracking failed. CWD remains '{session['cwd']}'.") + if potential_new_cwd and os.path.isdir(potential_new_cwd): + session["cwd"] = potential_new_cwd + print( + f"AICodeAgentCog: Updated CWD for user_id {user_id} to: {session['cwd']}" + ) + elif new_dir_arg_str and new_dir_arg_str != "-" and potential_new_cwd: + print( + f"AICodeAgentCog: 'cd' command for user_id {user_id} seemed to succeed (rc=0), but CWD tracking logic could not confirm new path '{potential_new_cwd}' or it's not a directory. CWD remains '{session['cwd']}'. Command: '{command}'." + ) + elif ( + new_dir_arg_str and new_dir_arg_str != "-" + ): # if potential_new_cwd was None but arg was given + print( + f"AICodeAgentCog: 'cd' command for user_id {user_id} with arg '{new_dir_arg_str}' succeeded (rc=0), but path resolution for CWD tracking failed. CWD remains '{session['cwd']}'." + ) # Format Output identically to shell_command_cog.py's _execute_local_command result_parts = [] - stdout_str = stdout_bytes.decode('utf-8', errors='replace').strip() - stderr_str = stderr_bytes.decode('utf-8', errors='replace').strip() + stdout_str = stdout_bytes.decode("utf-8", errors="replace").strip() + stderr_str = stderr_bytes.decode("utf-8", errors="replace").strip() if timed_out: - result_parts.append(f"⏱️ Command timed out after {timeout_seconds} seconds.") + result_parts.append( + f"⏱️ Command timed out after {timeout_seconds} seconds." + ) if stdout_str: if len(stdout_str) > max_output_length: - stdout_str = stdout_str[:max_output_length] + "... (output truncated)" + stdout_str = ( + stdout_str[:max_output_length] + "... (output truncated)" + ) result_parts.append(f"📤 **STDOUT:**\n```\n{stdout_str}\n```") if stderr_str: if len(stderr_str) > max_output_length: - stderr_str = stderr_str[:max_output_length] + "... (output truncated)" + stderr_str = ( + stderr_str[:max_output_length] + "... (output truncated)" + ) result_parts.append(f"⚠️ **STDERR:**\n```\n{stderr_str}\n```") - - if returncode != 0 and not timed_out: # Don't add exit code if it was a timeout + + if ( + returncode != 0 and not timed_out + ): # Don't add exit code if it was a timeout result_parts.append(f"❌ **Exit Code:** {returncode}") - else: # Successful or timed out (timeout message already added) - if not result_parts: # No stdout, no stderr, not timed out, and successful + else: # Successful or timed out (timeout message already added) + if ( + not result_parts + ): # No stdout, no stderr, not timed out, and successful result_parts.append("✅ Command executed successfully (no output).") - + return "\n".join(result_parts) except Exception as e: # General exception during subprocess handling return f"Exception executing command '{command}': {type(e).__name__} - {e}" - async def _execute_tool_list_files(self, path: str, recursive: bool, filter_extensions: Optional[str] = None, filter_regex_name: Optional[str] = None, include_metadata: bool = False) -> str: - print(f"AICodeAgentCog: _execute_tool_list_files for path: {path}, recursive: {recursive}, ext: {filter_extensions}, regex: {filter_regex_name}, meta: {include_metadata}") + async def _execute_tool_list_files( + self, + path: str, + recursive: bool, + filter_extensions: Optional[str] = None, + filter_regex_name: Optional[str] = None, + include_metadata: bool = False, + ) -> str: + print( + f"AICodeAgentCog: _execute_tool_list_files for path: {path}, recursive: {recursive}, ext: {filter_extensions}, regex: {filter_regex_name}, meta: {include_metadata}" + ) # TODO: Implement filtering (filter_extensions, filter_regex_name) and metadata (include_metadata) try: if not os.path.exists(path): return f"Error: Path not found at '{path}'" if not os.path.isdir(path): return f"Error: Path '{path}' is not a directory." - + file_list_results = [] - excluded_dirs = {"__pycache__", ".git", ".vscode", ".idea", "node_modules", "venv", ".env", "terminal_images"} + excluded_dirs = { + "__pycache__", + ".git", + ".vscode", + ".idea", + "node_modules", + "venv", + ".env", + "terminal_images", + } extensions_to_filter = [] if filter_extensions: - extensions_to_filter = [ext.strip().lower() for ext in filter_extensions.split(',') if ext.strip()] + extensions_to_filter = [ + ext.strip().lower() + for ext in filter_extensions.split(",") + if ext.strip() + ] name_regex_pattern = None if filter_regex_name: @@ -1214,21 +1666,30 @@ class AICodeAgentCog(commands.Cog): return f"Error: Invalid regex for name filtering: {e}" items_processed = 0 - max_items_to_list = 500 # Safety break + max_items_to_list = 500 # Safety break if recursive: for root, dirs, files in os.walk(path, topdown=True): - if items_processed > max_items_to_list: break + if items_processed > max_items_to_list: + break # Exclude specified directories from further traversal - dirs[:] = [d for d in dirs if d not in excluded_dirs and (not name_regex_pattern or name_regex_pattern.search(d))] - + dirs[:] = [ + d + for d in dirs + if d not in excluded_dirs + and (not name_regex_pattern or name_regex_pattern.search(d)) + ] + for name in files: - if items_processed > max_items_to_list: break + if items_processed > max_items_to_list: + break if name_regex_pattern and not name_regex_pattern.search(name): continue - if extensions_to_filter and not any(name.lower().endswith(ext) for ext in extensions_to_filter): + if extensions_to_filter and not any( + name.lower().endswith(ext) for ext in extensions_to_filter + ): continue - + full_path = os.path.join(root, name) entry = full_path if include_metadata: @@ -1238,10 +1699,15 @@ class AICodeAgentCog(commands.Cog): except OSError: entry += " (Metadata N/A)" file_list_results.append(entry) - items_processed +=1 - - for name in dirs: # These are already filtered and regex matched (if regex provided for dirs) - if items_processed > max_items_to_list: break + items_processed += 1 + + for ( + name + ) in ( + dirs + ): # These are already filtered and regex matched (if regex provided for dirs) + if items_processed > max_items_to_list: + break # No extension filter for dirs, regex already applied full_path = os.path.join(root, name) entry = full_path + os.sep @@ -1252,19 +1718,26 @@ class AICodeAgentCog(commands.Cog): except OSError: entry += " (Metadata N/A)" file_list_results.append(entry) - items_processed +=1 - else: # Non-recursive case + items_processed += 1 + else: # Non-recursive case for item in os.listdir(path): - if items_processed > max_items_to_list: break + if items_processed > max_items_to_list: + break if item in excluded_dirs: continue if name_regex_pattern and not name_regex_pattern.search(item): continue - + full_item_path = os.path.join(path, item) is_dir = os.path.isdir(full_item_path) - if not is_dir and extensions_to_filter and not any(item.lower().endswith(ext) for ext in extensions_to_filter): + if ( + not is_dir + and extensions_to_filter + and not any( + item.lower().endswith(ext) for ext in extensions_to_filter + ) + ): continue entry = item + (os.sep if is_dir else "") @@ -1276,19 +1749,27 @@ class AICodeAgentCog(commands.Cog): else: entry += f" (Size: {stat.st_size} B, Modified: {datetime.datetime.fromtimestamp(stat.st_mtime).strftime('%Y-%m-%d %H:%M:%S')})" except OSError: - entry += " (Metadata N/A)" + entry += " (Metadata N/A)" file_list_results.append(entry) - items_processed +=1 - - if items_processed > max_items_to_list: - file_list_results.append(f"... (truncated, listed {max_items_to_list} items)") + items_processed += 1 - return "\n".join(file_list_results) if file_list_results else "No files or directories found matching criteria." + if items_processed > max_items_to_list: + file_list_results.append( + f"... (truncated, listed {max_items_to_list} items)" + ) + + return ( + "\n".join(file_list_results) + if file_list_results + else "No files or directories found matching criteria." + ) except Exception as e: return f"Error listing files at '{path}': {type(e).__name__} - {e}" async def _execute_tool_web_search(self, query: str) -> str: - print(f"AICodeAgentCog: _execute_tool_web_search for query: {query}") # Removed "Placeholder" + print( + f"AICodeAgentCog: _execute_tool_web_search for query: {query}" + ) # Removed "Placeholder" if not self.tavily_client: return "Error: Tavily client not initialized. Cannot perform web search." try: @@ -1296,20 +1777,28 @@ class AICodeAgentCog(commands.Cog): response = await asyncio.to_thread( self.tavily_client.search, query=query, - search_depth=self.tavily_search_depth, # "basic" or "advanced" + search_depth=self.tavily_search_depth, # "basic" or "advanced" max_results=self.tavily_max_results, - include_answer=True # Try to get a direct answer + include_answer=True, # Try to get a direct answer ) - + results_str_parts = [] if response.get("answer"): results_str_parts.append(f"Answer: {response['answer']}") - + if response.get("results"): - for i, res in enumerate(response["results"][:self.tavily_max_results]): # Show up to max_results - results_str_parts.append(f"\nResult {i+1}: {res.get('title', 'N/A')}\nURL: {res.get('url', 'N/A')}\nSnippet: {res.get('content', 'N/A')[:250]}...") # Truncate snippet - - return "\n".join(results_str_parts) if results_str_parts else "No search results found." + for i, res in enumerate( + response["results"][: self.tavily_max_results] + ): # Show up to max_results + results_str_parts.append( + f"\nResult {i+1}: {res.get('title', 'N/A')}\nURL: {res.get('url', 'N/A')}\nSnippet: {res.get('content', 'N/A')[:250]}..." + ) # Truncate snippet + + return ( + "\n".join(results_str_parts) + if results_str_parts + else "No search results found." + ) except Exception as e: return f"Error during Tavily web search for '{query}': {type(e).__name__} - {e}" @@ -1330,29 +1819,32 @@ class AICodeAgentCog(commands.Cog): try: process = await asyncio.create_subprocess_exec( - *linter_cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE + *linter_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) stdout, stderr = await process.communicate() - + output_str = "" if stdout: - output_str += f"Linter ({linter}) STDOUT:\n{stdout.decode(errors='replace')}\n" - if stderr: # Linters often output to stderr for warnings/errors - output_str += f"Linter ({linter}) STDERR:\n{stderr.decode(errors='replace')}\n" - + output_str += ( + f"Linter ({linter}) STDOUT:\n{stdout.decode(errors='replace')}\n" + ) + if stderr: # Linters often output to stderr for warnings/errors + output_str += ( + f"Linter ({linter}) STDERR:\n{stderr.decode(errors='replace')}\n" + ) + if not output_str and process.returncode == 0: output_str = f"Linter ({linter}) found no issues." - elif not output_str and process.returncode !=0: - output_str = f"Linter ({linter}) exited with code {process.returncode} but no output." - + elif not output_str and process.returncode != 0: + output_str = f"Linter ({linter}) exited with code {process.returncode} but no output." return output_str except FileNotFoundError: return f"Error: Linter command '{linter_cmd[0]}' not found. Please ensure it is installed and in PATH." except Exception as e: - return f"Error running linter '{linter}' on '{path}': {type(e).__name__} - {e}" + return ( + f"Error running linter '{linter}' on '{path}': {type(e).__name__} - {e}" + ) async def _execute_tool_get_code_structure(self, path: str) -> str: # Basic AST parsing example @@ -1362,22 +1854,36 @@ class AICodeAgentCog(commands.Cog): with open(path, "r", encoding="utf-8") as source_file: source_code = source_file.read() tree = ast.parse(source_code) - + structure = [] for node in ast.walk(tree): if isinstance(node, ast.FunctionDef): args = [arg.arg for arg in node.args.args] - structure.append(f"Function: {node.name}({', '.join(args)}) - Docstring: {ast.get_docstring(node) or 'N/A'}") + structure.append( + f"Function: {node.name}({', '.join(args)}) - Docstring: {ast.get_docstring(node) or 'N/A'}" + ) elif isinstance(node, ast.AsyncFunctionDef): args = [arg.arg for arg in node.args.args] - structure.append(f"Async Function: {node.name}({', '.join(args)}) - Docstring: {ast.get_docstring(node) or 'N/A'}") + structure.append( + f"Async Function: {node.name}({', '.join(args)}) - Docstring: {ast.get_docstring(node) or 'N/A'}" + ) elif isinstance(node, ast.ClassDef): - structure.append(f"Class: {node.name} - Docstring: {ast.get_docstring(node) or 'N/A'}") - return "\n".join(structure) if structure else "No major structures (classes/functions) found." + structure.append( + f"Class: {node.name} - Docstring: {ast.get_docstring(node) or 'N/A'}" + ) + return ( + "\n".join(structure) + if structure + else "No major structures (classes/functions) found." + ) except Exception as e: - return f"Error parsing code structure for '{path}': {type(e).__name__} - {e}" + return ( + f"Error parsing code structure for '{path}': {type(e).__name__} - {e}" + ) - async def _execute_tool_find_symbol_definition(self, symbol_name: str, search_path: str, file_pattern: str) -> str: + async def _execute_tool_find_symbol_definition( + self, symbol_name: str, search_path: str, file_pattern: str + ) -> str: if not os.path.exists(search_path): return f"Error: Search path '{search_path}' not found." if not os.path.isdir(search_path): @@ -1395,33 +1901,41 @@ class AICodeAgentCog(commands.Cog): # If file_pattern is like "*.py", it can be directly appended. # os.path.join will correctly handle path separators. files_to_search_arg = os.path.join(search_path, file_pattern) - + # Escape the symbol name for command line if it contains special characters, though /C should treat it literally. # For simplicity, we assume symbol_name doesn't need complex shell escaping here. - find_cmd = ["findstr", "/S", "/N", "/P", f"/C:{symbol_name}", files_to_search_arg] + find_cmd = [ + "findstr", + "/S", + "/N", + "/P", + f"/C:{symbol_name}", + files_to_search_arg, + ] try: process = await asyncio.create_subprocess_exec( *find_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - cwd=os.getcwd() # Run from bot's root, search_path should be relative or absolute + cwd=os.getcwd(), # Run from bot's root, search_path should be relative or absolute ) - stdout, stderr = await process.communicate(timeout=30) # 30-second timeout + stdout, stderr = await process.communicate(timeout=30) # 30-second timeout output_str = "" if stdout: output_str += f"Definitions found for '{symbol_name}' (using findstr):\n{stdout.decode(errors='replace')}\n" if stderr: output_str += f"Findstr STDERR:\n{stderr.decode(errors='replace')}\n" - - if not output_str and process.returncode == 0: # findstr returns 0 if found, 1 if not found, 2 for error - output_str = f"No definitions found for '{symbol_name}' in '{search_path}/{file_pattern}' (findstr found nothing but exited cleanly)." - elif not output_str and process.returncode == 1: # Explicit "not found" - output_str = f"No definitions found for '{symbol_name}' in '{search_path}/{file_pattern}'." - elif process.returncode not in [0,1]: # Other errors - output_str += f"Findstr exited with code {process.returncode}." + if ( + not output_str and process.returncode == 0 + ): # findstr returns 0 if found, 1 if not found, 2 for error + output_str = f"No definitions found for '{symbol_name}' in '{search_path}/{file_pattern}' (findstr found nothing but exited cleanly)." + elif not output_str and process.returncode == 1: # Explicit "not found" + output_str = f"No definitions found for '{symbol_name}' in '{search_path}/{file_pattern}'." + elif process.returncode not in [0, 1]: # Other errors + output_str += f"Findstr exited with code {process.returncode}." return output_str except FileNotFoundError: @@ -1431,16 +1945,23 @@ class AICodeAgentCog(commands.Cog): except Exception as e: return f"Error running FindSymbolDefinition for '{symbol_name}': {type(e).__name__} - {e}" - - async def _execute_tool_manage_cog(self, action: str, cog_name: Optional[str]) -> str: + async def _execute_tool_manage_cog( + self, action: str, cog_name: Optional[str] + ) -> str: action = action.lower() try: if action == "list": loaded_cogs = list(self.bot.cogs.keys()) - return f"Loaded cogs: {', '.join(loaded_cogs)}" if loaded_cogs else "No cogs currently loaded." - - if not cog_name: # Should be caught by parser, but defensive - return "Error: cog_name is required for load, unload, or reload actions." + return ( + f"Loaded cogs: {', '.join(loaded_cogs)}" + if loaded_cogs + else "No cogs currently loaded." + ) + + if not cog_name: # Should be caught by parser, but defensive + return ( + "Error: cog_name is required for load, unload, or reload actions." + ) if action == "load": await self.bot.load_extension(cog_name) @@ -1464,7 +1985,9 @@ class AICodeAgentCog(commands.Cog): except Exception as e: return f"Error during ManageCog action '{action}' on '{cog_name}': {type(e).__name__} - {e}" - async def _execute_tool_run_tests(self, test_path_or_pattern: str, framework: str) -> str: + async def _execute_tool_run_tests( + self, test_path_or_pattern: str, framework: str + ) -> str: framework = framework.lower() test_cmd = [] @@ -1484,74 +2007,95 @@ class AICodeAgentCog(commands.Cog): *test_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - cwd=os.getcwd() # Run tests from the project root + cwd=os.getcwd(), # Run tests from the project root ) - stdout, stderr = await process.communicate(timeout=300) # 5-minute timeout for tests + stdout, stderr = await process.communicate( + timeout=300 + ) # 5-minute timeout for tests output_str = "" if stdout: - output_str += f"Test ({framework}) STDOUT:\n{stdout.decode(errors='replace')}\n" + output_str += ( + f"Test ({framework}) STDOUT:\n{stdout.decode(errors='replace')}\n" + ) if stderr: - output_str += f"Test ({framework}) STDERR:\n{stderr.decode(errors='replace')}\n" - - if not output_str: - output_str = f"Test ({framework}) command executed with no output. Exit code: {process.returncode}" - else: - output_str += f"\nTest ({framework}) command exit code: {process.returncode}" + output_str += ( + f"Test ({framework}) STDERR:\n{stderr.decode(errors='replace')}\n" + ) + if not output_str: + output_str = f"Test ({framework}) command executed with no output. Exit code: {process.returncode}" + else: + output_str += ( + f"\nTest ({framework}) command exit code: {process.returncode}" + ) return output_str except FileNotFoundError: cmd_not_found = test_cmd[0] if framework == "unittest" and cmd_not_found == "python": - cmd_not_found = "python interpreter" + cmd_not_found = "python interpreter" return f"Error: Test command '{cmd_not_found}' not found. Please ensure it is installed and in PATH." except subprocess.TimeoutExpired: return f"Error: Tests timed out after 300 seconds for target '{test_path_or_pattern}'." except Exception as e: return f"Error running tests for '{test_path_or_pattern}' with {framework}: {type(e).__name__} - {e}" - async def _execute_tool_python_repl(self, code_snippet: str, session_key: str) -> str: + async def _execute_tool_python_repl( + self, code_snippet: str, session_key: str + ) -> str: # Basic, insecure exec-based REPL. CAUTION ADVISED. # A proper implementation would use a sandboxed environment. if session_key not in self.agent_python_repl_sessions: - self.agent_python_repl_sessions[session_key] = {'globals': globals().copy(), 'locals': {}} - + self.agent_python_repl_sessions[session_key] = { + "globals": globals().copy(), + "locals": {}, + } + session_env = self.agent_python_repl_sessions[session_key] - + # Capture stdout for the REPL import io from contextlib import redirect_stdout - + f = io.StringIO() try: with redirect_stdout(f): - exec(code_snippet, session_env['globals'], session_env['locals']) + exec(code_snippet, session_env["globals"], session_env["locals"]) output = f.getvalue() - return f"Output:\n{output}" if output else "Executed successfully with no direct output." + return ( + f"Output:\n{output}" + if output + else "Executed successfully with no direct output." + ) except Exception as e: return f"Error in PythonREPL: {type(e).__name__} - {e}" finally: f.close() - - async def _execute_tool_create_named_snapshot(self, snapshot_name: str, description: Optional[str]) -> str: + async def _execute_tool_create_named_snapshot( + self, snapshot_name: str, description: Optional[str] + ) -> str: # Similar to _create_programmatic_snapshot but uses the given name and doesn't switch back. # It creates a branch and commits. try: # Sanitize snapshot_name (Git branch names have restrictions) # A simple sanitization: replace spaces and invalid chars with underscores - safe_snapshot_name = re.sub(r'[^\w.-]', '_', snapshot_name) + safe_snapshot_name = re.sub(r"[^\w.-]", "_", snapshot_name) if not safe_snapshot_name: return "Error: Invalid snapshot name after sanitization (empty)." # Check if branch already exists - success_check, existing_branches_str = await self._run_git_command(f"git branch --list {safe_snapshot_name}") + success_check, existing_branches_str = await self._run_git_command( + f"git branch --list {safe_snapshot_name}" + ) if success_check and safe_snapshot_name in existing_branches_str: return f"Error: Snapshot branch '{safe_snapshot_name}' already exists." # Create the new snapshot branch from current HEAD - success, output = await self._run_git_command(f"git branch {safe_snapshot_name}") + success, output = await self._run_git_command( + f"git branch {safe_snapshot_name}" + ) if not success: return f"Error: Failed to create snapshot branch '{safe_snapshot_name}': {output}" @@ -1561,62 +2105,83 @@ class AICodeAgentCog(commands.Cog): commit_message = f"AI Named Snapshot: {snapshot_name}" if description: commit_message += f"\n\n{description}" - + # To commit on the new branch, we'd typically checkout to it first, commit, then checkout back. # Or, create commit on new_branch_name using 'git commit-tree' and then 'git update-ref'. # Simpler: checkout, commit, checkout back (if desired, or leave on new branch). # The prompt implies it's separate from automatic snapshots, so maybe it stays on this branch or user manages. # Let's assume for now it just creates the branch and a commit on it, leaving current branch as is. # This requires creating a commit object pointing to current HEAD and then updating the branch ref. - + # Alternative: create branch, then switch to it, commit, switch back. - current_branch_success, current_branch_name = await self._run_git_command("git rev-parse --abbrev-ref HEAD") + current_branch_success, current_branch_name = await self._run_git_command( + "git rev-parse --abbrev-ref HEAD" + ) if not current_branch_success: - await self._run_git_command(f"git branch -D {safe_snapshot_name}") # cleanup + await self._run_git_command( + f"git branch -D {safe_snapshot_name}" + ) # cleanup return f"Error: Could not get current branch name before creating named snapshot: {current_branch_name}" - success, output = await self._run_git_command(f"git checkout {safe_snapshot_name}") + success, output = await self._run_git_command( + f"git checkout {safe_snapshot_name}" + ) if not success: - await self._run_git_command(f"git branch -D {safe_snapshot_name}") # cleanup + await self._run_git_command( + f"git branch -D {safe_snapshot_name}" + ) # cleanup return f"Error: Failed to checkout to new snapshot branch '{safe_snapshot_name}': {output}" - - success, output = await self._run_git_command(f"git commit --author=\"{COMMIT_AUTHOR}\" -m \"{commit_message}\" --allow-empty") + + success, output = await self._run_git_command( + f'git commit --author="{COMMIT_AUTHOR}" -m "{commit_message}" --allow-empty' + ) if not success: # Attempt to switch back before reporting error - await self._run_git_command(f"git checkout {current_branch_name.strip()}") + await self._run_git_command( + f"git checkout {current_branch_name.strip()}" + ) # Optionally delete the branch if commit failed # await self._run_git_command(f"git branch -D {safe_snapshot_name}") return f"Error: Failed to commit on snapshot branch '{safe_snapshot_name}': {output}" # Switch back to original branch - success_back, output_back = await self._run_git_command(f"git checkout {current_branch_name.strip()}") + success_back, output_back = await self._run_git_command( + f"git checkout {current_branch_name.strip()}" + ) if not success_back: - # This is problematic, user might be left on snapshot branch + # This is problematic, user might be left on snapshot branch return f"Successfully created and committed snapshot '{safe_snapshot_name}', but FAILED to switch back to original branch '{current_branch_name.strip()}'. Current branch is now '{safe_snapshot_name}'. Details: {output_back}" - + return f"Successfully created named snapshot: {safe_snapshot_name}" except Exception as e: return f"Error creating named snapshot '{snapshot_name}': {type(e).__name__} - {e}" - - async def _execute_tool_compare_snapshots(self, base_ref: str, compare_ref: str) -> str: - success, output = await self._run_git_command(f"git diff {base_ref}..{compare_ref}") + async def _execute_tool_compare_snapshots( + self, base_ref: str, compare_ref: str + ) -> str: + success, output = await self._run_git_command( + f"git diff {base_ref}..{compare_ref}" + ) if success: return f"Diff between '{base_ref}' and '{compare_ref}':\n```diff\n{output or 'No differences found.'}\n```" else: - return f"Error comparing snapshots '{base_ref}' and '{compare_ref}': {output}" + return ( + f"Error comparing snapshots '{base_ref}' and '{compare_ref}': {output}" + ) async def _execute_tool_dry_run_apply_diff(self, path: str, diff_block: str) -> str: - print(f"AICodeAgentCog: Attempting _execute_tool_dry_run_apply_diff (Search/Replace) for path: {path}") + print( + f"AICodeAgentCog: Attempting _execute_tool_dry_run_apply_diff (Search/Replace) for path: {path}" + ) if not os.path.exists(path): return f"Error: File not found at '{path}' for dry-run applying diff." if os.path.isdir(path): return f"Error: Path '{path}' is a directory, cannot dry-run apply diff." try: - with open(path, 'r', encoding='utf-8') as f: - file_lines = f.readlines() # Read all lines, keeping newlines + with open(path, "r", encoding="utf-8") as f: + file_lines = f.readlines() # Read all lines, keeping newlines except Exception as e: return f"Error reading file '{path}' for dry-run: {type(e).__name__} - {e}" @@ -1627,10 +2192,10 @@ class AICodeAgentCog(commands.Cog): r":start_line:(\d+)\s*\n" r"-------\s*\n" r"(.*?)" # search_content (group 2) - r"\\=======\s*\n" # Changed to match AI prompt (\=======) + r"\\=======\s*\n" # Changed to match AI prompt (\=======) r"(.*?)" # replace_content (group 3) r"\[REPLACE_BLOCK_END\]\s*$", # Changed to match AI prompt - re.DOTALL | re.MULTILINE + re.DOTALL | re.MULTILINE, ) operations = [] @@ -1638,20 +2203,22 @@ class AICodeAgentCog(commands.Cog): for match in operation_pattern.finditer(diff_block): if match.start() < last_match_end: continue - if diff_block[last_match_end:match.start()].strip(): + if diff_block[last_match_end : match.start()].strip(): return f"Dry run Error: Malformed diff_block. Unexpected content between operations near character {last_match_end}." - + start_line = int(match.group(1)) # Use regex capture groups directly for search and replace content search_c = match.group(2) replace_c = match.group(3) - - operations.append({ - "start_line": start_line, - "search": search_c, - "replace": replace_c, - "original_block_for_error": match.group(0)[:200] - }) + + operations.append( + { + "start_line": start_line, + "search": search_c, + "replace": replace_c, + "original_block_for_error": match.group(0)[:200], + } + ) last_match_end = match.end() if not operations: @@ -1659,7 +2226,7 @@ class AICodeAgentCog(commands.Cog): return "Dry run Error: diff_block provided but no valid SEARCH/REPLACE operations found." else: return "Dry run Error: Empty diff_block provided." - + if diff_block[last_match_end:].strip(): return f"Dry run Error: Malformed diff_block. Unexpected trailing content after last operation, near character {last_match_end}." @@ -1668,40 +2235,59 @@ class AICodeAgentCog(commands.Cog): checked_ops_count = 0 for op_idx, op in enumerate(operations): - start_line_0_indexed = op['start_line'] - 1 - search_content_lines = op['search'].splitlines(True) - if not search_content_lines and op['search']: - search_content_lines = [op['search']] - elif not op['search']: - search_content_lines = [""] - + start_line_0_indexed = op["start_line"] - 1 + search_content_lines = op["search"].splitlines(True) + if not search_content_lines and op["search"]: + search_content_lines = [op["search"]] + elif not op["search"]: + search_content_lines = [""] + num_search_lines = len(search_content_lines) - if start_line_0_indexed < 0 or \ - (start_line_0_indexed + num_search_lines > len(file_lines) and num_search_lines > 0) or \ - (start_line_0_indexed > len(file_lines) and num_search_lines == 0 and op['search'] == ""): - if not (op['search'] == "" and start_line_0_indexed == len(file_lines)): # append case - return (f"Dry run Error: Operation {op_idx+1} for line {op['start_line']} " - f"(0-indexed {start_line_0_indexed}) with {num_search_lines} search lines " - f"is out of file bounds (total lines: {len(file_lines)}). " - f"Block: {op['original_block_for_error']}...") - - actual_file_segment_lines = file_lines[start_line_0_indexed : start_line_0_indexed + num_search_lines] + if ( + start_line_0_indexed < 0 + or ( + start_line_0_indexed + num_search_lines > len(file_lines) + and num_search_lines > 0 + ) + or ( + start_line_0_indexed > len(file_lines) + and num_search_lines == 0 + and op["search"] == "" + ) + ): + if not ( + op["search"] == "" and start_line_0_indexed == len(file_lines) + ): # append case + return ( + f"Dry run Error: Operation {op_idx+1} for line {op['start_line']} " + f"(0-indexed {start_line_0_indexed}) with {num_search_lines} search lines " + f"is out of file bounds (total lines: {len(file_lines)}). " + f"Block: {op['original_block_for_error']}..." + ) + + actual_file_segment_lines = file_lines[ + start_line_0_indexed : start_line_0_indexed + num_search_lines + ] actual_file_segment_content = "".join(actual_file_segment_lines) - if actual_file_segment_content == op['search']: + if actual_file_segment_content == op["search"]: checked_ops_count += 1 else: - expected_repr = repr(op['search']) + expected_repr = repr(op["search"]) found_repr = repr(actual_file_segment_content) max_len = 100 - if len(expected_repr) > max_len: expected_repr = expected_repr[:max_len] + "..." - if len(found_repr) > max_len: found_repr = found_repr[:max_len] + "..." - return (f"Dry run Error: Search content mismatch for operation {op_idx+1} at line {op['start_line']}.\n" - f"Expected: {expected_repr}\n" - f"Found : {found_repr}\n" - f"Original Block Hint: {op['original_block_for_error']}...") - + if len(expected_repr) > max_len: + expected_repr = expected_repr[:max_len] + "..." + if len(found_repr) > max_len: + found_repr = found_repr[:max_len] + "..." + return ( + f"Dry run Error: Search content mismatch for operation {op_idx+1} at line {op['start_line']}.\n" + f"Expected: {expected_repr}\n" + f"Found : {found_repr}\n" + f"Original Block Hint: {op['original_block_for_error']}..." + ) + if checked_ops_count == len(operations): return f"Dry run: All {len(operations)} SEARCH/REPLACE operation(s) would apply cleanly to '{path}'." else: @@ -1711,7 +2297,6 @@ class AICodeAgentCog(commands.Cog): except Exception as e: return f"Error during DryRunApplyDiff (Search/Replace) for '{path}': {type(e).__name__} - {e}" - async def _execute_tool_dry_run_write_file(self, path: str) -> str: try: p = pathlib.Path(path) @@ -1724,15 +2309,15 @@ class AICodeAgentCog(commands.Cog): # This is a bit complex; simpler check: can we write to grandparent? # For now, just report if parent doesn't exist. return f"Dry run: Parent directory '{parent_dir}' does not exist. Write would likely create it if permissions allow." - except Exception: # Broad exception for permission issues with parent - pass # Fall through to os.access checks + except Exception: # Broad exception for permission issues with parent + pass # Fall through to os.access checks - if p.exists(): # File exists + if p.exists(): # File exists if os.access(path, os.W_OK): return f"Dry run: File '{path}' exists and is writable." else: return f"Dry run: File '{path}' exists but is NOT writable (permission error)." - else: # File does not exist, check if directory is writable + else: # File does not exist, check if directory is writable if os.access(parent_dir, os.W_OK): return f"Dry run: File '{path}' does not exist, but directory '{parent_dir}' is writable. File can likely be created." else: @@ -1744,35 +2329,39 @@ class AICodeAgentCog(commands.Cog): print(f"AICodeAgentCog: _execute_tool_read_web_page_raw for URL: {url}") if not url.startswith(("http://", "https://")): return "Error: Invalid URL. Must start with http:// or https://" - + try: async with aiohttp.ClientSession() as session: # Set a timeout for the request - timeout = aiohttp.ClientTimeout(total=30) # 30 seconds total timeout + timeout = aiohttp.ClientTimeout(total=30) # 30 seconds total timeout async with session.get(url, timeout=timeout) as response: if response.status == 200: # Limit the size of the content to prevent memory issues # Max 1MB for raw content, can be adjusted max_content_size = 1 * 1024 * 1024 - content_length = response.headers.get('Content-Length') + content_length = response.headers.get("Content-Length") if content_length and int(content_length) > max_content_size: return f"Error: Content at URL is too large (>{max_content_size / (1024*1024):.0f}MB). Size: {content_length} bytes." # Read content chunk by chunk to enforce max_content_size if Content-Length is missing/unreliable content = b"" - async for chunk in response.content.iter_chunked(1024): # Read 1KB chunks + async for chunk in response.content.iter_chunked( + 1024 + ): # Read 1KB chunks content += chunk if len(content) > max_content_size: return f"Error: Content at URL is too large (exceeded {max_content_size / (1024*1024):.0f}MB during download)." - + # Try to decode as UTF-8, replace errors - return content.decode('utf-8', errors='replace') + return content.decode("utf-8", errors="replace") else: # Try to read a snippet of the error response body error_body_snippet = "" try: error_body_snippet = await response.text() - error_body_snippet = error_body_snippet[:200] # Limit snippet length + error_body_snippet = error_body_snippet[ + :200 + ] # Limit snippet length except Exception: error_body_snippet = "(Could not read error response body)" return f"Error: Failed to fetch URL. Status code: {response.status}. Response snippet: {error_body_snippet}" @@ -1785,184 +2374,281 @@ class AICodeAgentCog(commands.Cog): # --- End of New Tool Execution Methods --- - async def _process_agent_interaction(self, ctx: commands.Context, initial_prompt_text: str): + async def _process_agent_interaction( + self, ctx: commands.Context, initial_prompt_text: str + ): user_id = ctx.author.id - + # Check current mode and prepend to history if it's the start of a new interaction (or if mode changed) # The mode change command already adds a notification. Here, we ensure the AI is aware of the *current* mode # if this is a fresh interaction after a mode was set previously. # However, the system prompt now instructs AI on how mode changes are communicated. # So, direct injection here might be redundant if mode change command handles it. # Let's rely on the mode change command to inject the notification. - - self._add_to_conversation_history(user_id, role="user", text_content=initial_prompt_text) + + self._add_to_conversation_history( + user_id, role="user", text_content=initial_prompt_text + ) iteration_count = 0 max_iterations = 10 # Configurable, from plan - + # Ensure genai_client is available if not self.genai_client: - await ctx.send("AICodeAgent: Google GenAI Client is not initialized. Cannot process request.") + await ctx.send( + "AICodeAgent: Google GenAI Client is not initialized. Cannot process request." + ) return async with ctx.typing(): while iteration_count < max_iterations: current_history = self._get_conversation_history(user_id) - - if not current_history: # Should not happen if initial prompt was added - await ctx.send("AICodeAgent: Error - conversation history is empty.") + + if not current_history: # Should not happen if initial prompt was added + await ctx.send( + "AICodeAgent: Error - conversation history is empty." + ) return try: # Construct messages for Vertex AI API # The system prompt is passed via generation_config.system_instruction - vertex_contents = current_history # Already in types.Content format + vertex_contents = current_history # Already in types.Content format generation_config = google_genai_types.GenerateContentConfig( - temperature=0.3, # Adjust as needed - max_output_tokens=65535, # Adjust as needed + temperature=0.3, # Adjust as needed + max_output_tokens=65535, # Adjust as needed safety_settings=STANDARD_SAFETY_SETTINGS, # System instruction is critical here system_instruction=google_genai_types.Content( - role="system", # Though for Gemini, system prompt is often first user message or model tuning - parts=[google_genai_types.Part(text=AGENT_SYSTEM_PROMPT)] - ) + role="system", # Though for Gemini, system prompt is often first user message or model tuning + parts=[google_genai_types.Part(text=AGENT_SYSTEM_PROMPT)], + ), + ) + + print( + f"AICodeAgentCog: Sending to Vertex AI. Model: {self._ai_model}. History items: {len(vertex_contents)}" ) - - print(f"AICodeAgentCog: Sending to Vertex AI. Model: {self._ai_model}. History items: {len(vertex_contents)}") # for i, item in enumerate(vertex_contents): # print(f" History {i} Role: {item.role}, Parts: {item.parts}") - response = await self.genai_client.aio.models.generate_content( model=f"publishers/google/models/{self._ai_model}", contents=vertex_contents, - config=generation_config, # Corrected parameter name + config=generation_config, # Corrected parameter name # No 'tools' or 'tool_config' for inline tool usage ) - + # Safely extract text from response ai_response_text = "" - if response.candidates and response.candidates[0].content and response.candidates[0].content.parts: + if ( + response.candidates + and response.candidates[0].content + and response.candidates[0].content.parts + ): ai_response_text = response.candidates[0].content.parts[0].text - else: # Handle cases like safety blocks or empty responses - finish_reason = response.candidates[0].finish_reason if response.candidates else "UNKNOWN" + else: # Handle cases like safety blocks or empty responses + finish_reason = ( + response.candidates[0].finish_reason + if response.candidates + else "UNKNOWN" + ) safety_ratings_str = "" - if response.candidates and response.candidates[0].safety_ratings: + if ( + response.candidates + and response.candidates[0].safety_ratings + ): sr = response.candidates[0].safety_ratings - safety_ratings_str = ", ".join([f"{rating.category.name}: {rating.probability.name}" for rating in sr]) - + safety_ratings_str = ", ".join( + [ + f"{rating.category.name}: {rating.probability.name}" + for rating in sr + ] + ) + if finish_reason == google_genai_types.FinishReason.SAFETY: - await ctx.send(f"AICodeAgent: AI response was blocked due to safety settings: {safety_ratings_str}") - self._add_to_conversation_history(user_id, role="model", text_content=f"[Blocked by Safety: {safety_ratings_str}]") + await ctx.send( + f"AICodeAgent: AI response was blocked due to safety settings: {safety_ratings_str}" + ) + self._add_to_conversation_history( + user_id, + role="model", + text_content=f"[Blocked by Safety: {safety_ratings_str}]", + ) return else: - await ctx.send(f"AICodeAgent: AI returned an empty or non-text response. Finish Reason: {finish_reason}. Safety: {safety_ratings_str}") - self._add_to_conversation_history(user_id, role="model", text_content="[Empty or Non-Text Response]") + await ctx.send( + f"AICodeAgent: AI returned an empty or non-text response. Finish Reason: {finish_reason}. Safety: {safety_ratings_str}" + ) + self._add_to_conversation_history( + user_id, + role="model", + text_content="[Empty or Non-Text Response]", + ) return - + if not ai_response_text.strip(): - await ctx.send("AICodeAgent: AI returned an empty response text.") - self._add_to_conversation_history(user_id, role="model", text_content="[Empty Response Text]") + await ctx.send( + "AICodeAgent: AI returned an empty response text." + ) + self._add_to_conversation_history( + user_id, role="model", text_content="[Empty Response Text]" + ) return - self._add_to_conversation_history(user_id, role="model", text_content=ai_response_text) + self._add_to_conversation_history( + user_id, role="model", text_content=ai_response_text + ) print(f"AICodeAgentCog: AI Raw Response:\n{ai_response_text}") # Parse for inline tool call # _parse_and_execute_tool_call now returns -> Tuple[str, Optional[str]] # status can be "TOOL_OUTPUT", "TASK_COMPLETE", "NO_TOOL" # data is the tool output string, completion message, or original AI text - parse_status, parsed_data = await self._parse_and_execute_tool_call(ctx, ai_response_text) + parse_status, parsed_data = await self._parse_and_execute_tool_call( + ctx, ai_response_text + ) if parse_status == "TASK_COMPLETE": - completion_message = parsed_data if parsed_data is not None else "Task marked as complete by AI." - await ctx.send(f"AICodeAgent: Task Complete!\n{completion_message}") + completion_message = ( + parsed_data + if parsed_data is not None + else "Task marked as complete by AI." + ) + await ctx.send( + f"AICodeAgent: Task Complete!\n{completion_message}" + ) # Log AI's completion signal to history (optional, but good for context) # self._add_to_conversation_history(user_id, role="model", text_content=f"TaskComplete: message: {completion_message}") - return # End of interaction - + return # End of interaction + elif parse_status == "TOOL_OUTPUT": tool_output_str = parsed_data - if tool_output_str is None: # Should not happen if status is TOOL_OUTPUT but defensive - tool_output_str = "Error: Tool executed but returned no output string." - + if ( + tool_output_str is None + ): # Should not happen if status is TOOL_OUTPUT but defensive + tool_output_str = ( + "Error: Tool executed but returned no output string." + ) + print(f"AICodeAgentCog: Tool Output:\n{tool_output_str}") - self._add_to_conversation_history(user_id, role="user", text_content=tool_output_str) # Feed tool output back as 'user' + self._add_to_conversation_history( + user_id, role="user", text_content=tool_output_str + ) # Feed tool output back as 'user' iteration_count += 1 # Optionally send tool output to Discord for transparency if desired # if len(tool_output_str) < 1900 : await ctx.send(f"```{tool_output_str}```") - continue # Loop back to AI with tool output in history - + continue # Loop back to AI with tool output in history + elif parse_status == "NO_TOOL": # No tool call found, this is the final AI response for this turn - final_ai_text = parsed_data # This is the original ai_response_text - if final_ai_text is None: # Should not happen + final_ai_text = ( + parsed_data # This is the original ai_response_text + ) + if final_ai_text is None: # Should not happen final_ai_text = "AI provided no textual response." if len(final_ai_text) > 1950: - await ctx.send(final_ai_text[:1950] + "\n...(message truncated)") + await ctx.send( + final_ai_text[:1950] + "\n...(message truncated)" + ) else: await ctx.send(final_ai_text) - return # End of interaction - else: # Should not happen - await ctx.send("AICodeAgent: Internal error - unknown parse status from tool parser.") + return # End of interaction + else: # Should not happen + await ctx.send( + "AICodeAgent: Internal error - unknown parse status from tool parser." + ) return except google_exceptions.GoogleAPICallError as e: await ctx.send(f"AICodeAgent: Vertex AI API call failed: {e}") return except Exception as e: - await ctx.send(f"AICodeAgent: An unexpected error occurred during AI interaction: {e}") - print(f"AICodeAgentCog: Interaction Error: {type(e).__name__} - {e}") + await ctx.send( + f"AICodeAgent: An unexpected error occurred during AI interaction: {e}" + ) + print( + f"AICodeAgentCog: Interaction Error: {type(e).__name__} - {e}" + ) import traceback + traceback.print_exc() return # Iteration limit check (moved inside loop for clarity, but logic is similar) if iteration_count >= max_iterations: - await ctx.send(f"AICodeAgent: Reached iteration limit ({max_iterations}).") + await ctx.send( + f"AICodeAgent: Reached iteration limit ({max_iterations})." + ) try: - check = lambda m: m.author == ctx.author and m.channel == ctx.channel and \ - m.content.lower().startswith(("yes", "no", "continue", "feedback")) - - await ctx.send("Continue processing? (yes/no/feedback ):") - user_response_msg = await self.bot.wait_for('message', check=check, timeout=300.0) + check = ( + lambda m: m.author == ctx.author + and m.channel == ctx.channel + and m.content.lower().startswith( + ("yes", "no", "continue", "feedback") + ) + ) + + await ctx.send( + "Continue processing? (yes/no/feedback ):" + ) + user_response_msg = await self.bot.wait_for( + "message", check=check, timeout=300.0 + ) user_response_content = user_response_msg.content.lower() - if user_response_content.startswith("yes") or user_response_content.startswith("continue"): - iteration_count = 0 # Reset iteration count - self._add_to_conversation_history(user_id, role="user", text_content="[User approved continuation]") + if user_response_content.startswith( + "yes" + ) or user_response_content.startswith("continue"): + iteration_count = 0 # Reset iteration count + self._add_to_conversation_history( + user_id, + role="user", + text_content="[User approved continuation]", + ) await ctx.send("Continuing...") continue elif user_response_content.startswith("feedback"): - feedback_text = user_response_msg.content[len("feedback"):].strip() - iteration_count = 0 # Reset - self._add_to_conversation_history(user_id, role="user", text_content=f"System Feedback: {feedback_text}") + feedback_text = user_response_msg.content[ + len("feedback") : + ].strip() + iteration_count = 0 # Reset + self._add_to_conversation_history( + user_id, + role="user", + text_content=f"System Feedback: {feedback_text}", + ) await ctx.send("Continuing with feedback...") continue - else: # No or other + else: # No or other await ctx.send("AICodeAgent: Processing stopped by user.") return except asyncio.TimeoutError: - await ctx.send("AICodeAgent: Continuation prompt timed out. Stopping.") + await ctx.send( + "AICodeAgent: Continuation prompt timed out. Stopping." + ) return - + # If loop finishes due to max_iterations without reset (should be caught by above) - if iteration_count >= max_iterations : - await ctx.send("AICodeAgent: Stopped due to reaching maximum processing iterations.") + if iteration_count >= max_iterations: + await ctx.send( + "AICodeAgent: Stopped due to reaching maximum processing iterations." + ) @commands.command(name="codeagent", aliases=["ca"]) @commands.is_owner() async def codeagent_command(self, ctx: commands.Context, *, prompt: str): """Interacts with the AI Code Agent.""" if not self.genai_client: - await ctx.send("AICodeAgent: Google GenAI Client is not initialized. Cannot process request.") + await ctx.send( + "AICodeAgent: Google GenAI Client is not initialized. Cannot process request." + ) return if not prompt: await ctx.send("AICodeAgent: Please provide a prompt for the agent.") return - + await self._process_agent_interaction(ctx, prompt) @@ -1973,7 +2659,7 @@ async def setup(bot: commands.Bot): print("AICodeAgentCog: Cannot load cog as PROJECT_ID or LOCATION is missing.") # Optionally, raise an error or just don't add the cog # For now, let it load but genai_client will be None and commands using it should check - + cog = AICodeAgentCog(bot) await bot.add_cog(cog) - print("AICodeAgentCog loaded.") \ No newline at end of file + print("AICodeAgentCog loaded.") diff --git a/cogs/aimod.py b/cogs/aimod.py index 695c882..5587ed3 100644 --- a/cogs/aimod.py +++ b/cogs/aimod.py @@ -2,18 +2,19 @@ import discord from discord.ext import commands from discord import app_commands + # import aiohttp # For making asynchronous HTTP requests - Replaced by Google GenAI client import json -import os # To load environment variables -import collections # For deque -import datetime # For timestamps -import io # For BytesIO operations -import base64 # For encoding images to base64 -from PIL import Image # For image processing -import cv2 # For video processing -import numpy as np # For array operations -import tempfile # For temporary file operations -from typing import Optional, List, Dict, Any, Tuple # For type hinting +import os # To load environment variables +import collections # For deque +import datetime # For timestamps +import io # For BytesIO operations +import base64 # For encoding images to base64 +from PIL import Image # For image processing +import cv2 # For video processing +import numpy as np # For array operations +import tempfile # For temporary file operations +from typing import Optional, List, Dict, Any, Tuple # For type hinting # Google Generative AI Imports (using Vertex AI backend) from google import genai @@ -21,25 +22,38 @@ from google.genai import types from google.api_core import exceptions as google_exceptions # Import project configuration for Vertex AI -from gurt.config import PROJECT_ID, LOCATION # Assuming gurt.config exists and has these +from gurt.config import ( + PROJECT_ID, + LOCATION, +) # Assuming gurt.config exists and has these # --- Configuration --- # Vertex AI Configuration -DEFAULT_VERTEX_AI_MODEL = "gemini-2.5-flash-preview-05-20" # Example Vertex AI model +DEFAULT_VERTEX_AI_MODEL = "gemini-2.5-flash-preview-05-20" # Example Vertex AI model # Define standard safety settings using google.generativeai types STANDARD_SAFETY_SETTINGS = [ - types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold="BLOCK_NONE"), - types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold="BLOCK_NONE"), - types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold="BLOCK_NONE"), - types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold="BLOCK_NONE"), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold="BLOCK_NONE" + ), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold="BLOCK_NONE", + ), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold="BLOCK_NONE", + ), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold="BLOCK_NONE" + ), ] # Environment variable for the authorization secret (still used for other API calls) MOD_LOG_API_SECRET_ENV_VAR = "MOD_LOG_API_SECRET" # --- Per-Guild Discord Configuration --- -GUILD_CONFIG_DIR = "data/" # Using the existing directory for all json data +GUILD_CONFIG_DIR = "data/" # Using the existing directory for all json data GUILD_CONFIG_PATH = os.path.join(GUILD_CONFIG_DIR, "guild_config.json") USER_INFRACTIONS_PATH = os.path.join(GUILD_CONFIG_DIR, "user_infractions.json") @@ -59,7 +73,9 @@ except Exception as e: # Initialize User Infractions if not os.path.exists(USER_INFRACTIONS_PATH): with open(USER_INFRACTIONS_PATH, "w", encoding="utf-8") as f: - json.dump({}, f) # Stores infractions as { "guild_id_user_id": [infraction_list] } + json.dump( + {}, f + ) # Stores infractions as { "guild_id_user_id": [infraction_list] } try: with open(USER_INFRACTIONS_PATH, "r", encoding="utf-8") as f: USER_INFRACTIONS = json.load(f) @@ -67,6 +83,7 @@ except Exception as e: print(f"Failed to load user infractions from {USER_INFRACTIONS_PATH}: {e}") USER_INFRACTIONS = {} + def save_guild_config(): try: # os.makedirs(os.path.dirname(GUILD_CONFIG_PATH), exist_ok=True) # Already created by GUILD_CONFIG_DIR @@ -78,6 +95,7 @@ def save_guild_config(): except Exception as e: print(f"Failed to save per-guild config: {e}") + def save_user_infractions(): try: # os.makedirs(os.path.dirname(USER_INFRACTIONS_PATH), exist_ok=True) # Already created by GUILD_CONFIG_DIR @@ -89,12 +107,14 @@ def save_user_infractions(): except Exception as e: print(f"Failed to save user infractions: {e}") + def get_guild_config(guild_id: int, key: str, default=None): guild_str = str(guild_id) if guild_str in GUILD_CONFIG and key in GUILD_CONFIG[guild_str]: return GUILD_CONFIG[guild_str][key] return default + def set_guild_config(guild_id: int, key: str, value): guild_str = str(guild_id) if guild_str not in GUILD_CONFIG: @@ -102,12 +122,21 @@ def set_guild_config(guild_id: int, key: str, value): GUILD_CONFIG[guild_str][key] = value save_guild_config() + def get_user_infraction_history(guild_id: int, user_id: int) -> list: """Retrieves a list of past infractions for a specific user in a guild.""" key = f"{guild_id}_{user_id}" return USER_INFRACTIONS.get(key, []) -def add_user_infraction(guild_id: int, user_id: int, rule_violated: str, action_taken: str, reasoning: str, timestamp: str): + +def add_user_infraction( + guild_id: int, + user_id: int, + rule_violated: str, + action_taken: str, + reasoning: str, + timestamp: str, +): """Adds a new infraction record for a user.""" key = f"{guild_id}_{user_id}" if key not in USER_INFRACTIONS: @@ -117,13 +146,14 @@ def add_user_infraction(guild_id: int, user_id: int, rule_violated: str, action_ "timestamp": timestamp, "rule_violated": rule_violated, "action_taken": action_taken, - "reasoning": reasoning + "reasoning": reasoning, } USER_INFRACTIONS[key].append(infraction_record) # Keep only the last N infractions to prevent the file from growing too large, e.g., last 10 USER_INFRACTIONS[key] = USER_INFRACTIONS[key][-10:] save_user_infractions() + # Server rules to provide context to the AI SERVER_RULES = """ # Server Rules @@ -180,10 +210,12 @@ Please reach out to one of these. We've also alerted our server's support team s You matter, and help is available. """ + class AIModerationCog(commands.Cog): """ A Discord Cog that uses Google Vertex AI to moderate messages based on server rules. """ + def __init__(self, bot: commands.Bot): self.bot = bot self.genai_client = None @@ -194,30 +226,57 @@ class AIModerationCog(commands.Cog): project=PROJECT_ID, location=LOCATION, ) - print(f"AIModerationCog: Google GenAI Client initialized for Vertex AI project '{PROJECT_ID}' in location '{LOCATION}'.") + print( + f"AIModerationCog: Google GenAI Client initialized for Vertex AI project '{PROJECT_ID}' in location '{LOCATION}'." + ) else: - print("AIModerationCog: PROJECT_ID or LOCATION not found in config. Google GenAI Client not initialized.") + print( + "AIModerationCog: PROJECT_ID or LOCATION not found in config. Google GenAI Client not initialized." + ) except Exception as e: - print(f"AIModerationCog: Error initializing Google GenAI Client for Vertex AI: {e}") + print( + f"AIModerationCog: Error initializing Google GenAI Client for Vertex AI: {e}" + ) - self.last_ai_decisions = collections.deque(maxlen=5) # Store last 5 AI decisions + self.last_ai_decisions = collections.deque( + maxlen=5 + ) # Store last 5 AI decisions # Supported image file extensions - self.image_extensions = ['.jpg', '.jpeg', '.png', '.webp', '.bmp', '.heic', '.heif'] # Added heic/heif for Vertex + self.image_extensions = [ + ".jpg", + ".jpeg", + ".png", + ".webp", + ".bmp", + ".heic", + ".heif", + ] # Added heic/heif for Vertex # Supported animated file extensions - self.gif_extensions = ['.gif'] + self.gif_extensions = [".gif"] # Supported video file extensions (Vertex AI typically processes first frame of videos as image) - self.video_extensions = ['.mp4', '.webm', '.mov', '.avi', '.mkv', '.flv'] # Expanded list + self.video_extensions = [ + ".mp4", + ".webm", + ".mov", + ".avi", + ".mkv", + ".flv", + ] # Expanded list print("AIModerationCog Initialized.") async def cog_load(self): """Called when the cog is loaded.""" print("AIModerationCog cog_load started.") if not self.genai_client: - print("\n" + "="*60) - print("=== WARNING: AIModerationCog - Vertex AI Client not initialized! ===") + print("\n" + "=" * 60) + print( + "=== WARNING: AIModerationCog - Vertex AI Client not initialized! ===" + ) print("=== The Moderation Cog requires a valid Vertex AI setup. ===") - print(f"=== Check PROJECT_ID and LOCATION in gurt.config and GCP authentication. ===") - print("="*60 + "\n") + print( + f"=== Check PROJECT_ID and LOCATION in gurt.config and GCP authentication. ===" + ) + print("=" * 60 + "\n") else: print("AIModerationCog: Vertex AI Client seems to be initialized.") print("AIModerationCog cog_load finished.") @@ -243,7 +302,9 @@ class AIModerationCog(commands.Cog): try: # Download the image image_bytes = await attachment.read() - mime_type = attachment.content_type or "image/jpeg" # Default to jpeg if not specified + mime_type = ( + attachment.content_type or "image/jpeg" + ) # Default to jpeg if not specified # Return the image bytes and mime type return mime_type, image_bytes @@ -268,14 +329,14 @@ class AIModerationCog(commands.Cog): # Open the GIF using PIL with Image.open(io.BytesIO(gif_bytes)) as gif: # Convert to RGB if needed - if gif.mode != 'RGB': - first_frame = gif.convert('RGB') + if gif.mode != "RGB": + first_frame = gif.convert("RGB") else: first_frame = gif # Save the first frame to a bytes buffer output = io.BytesIO() - first_frame.save(output, format='JPEG') + first_frame.save(output, format="JPEG") output.seek(0) return "image/jpeg", output.getvalue() @@ -283,7 +344,9 @@ class AIModerationCog(commands.Cog): print(f"Error processing GIF: {e}") return None, None - async def process_attachment(self, attachment: discord.Attachment) -> tuple[str, bytes, str]: + async def process_attachment( + self, attachment: discord.Attachment + ) -> tuple[str, bytes, str]: """ Process any attachment and return the appropriate image data. @@ -304,13 +367,13 @@ class AIModerationCog(commands.Cog): # Process based on file type if ext in self.image_extensions: mime_type, image_bytes = await self.process_image(attachment) - return mime_type, image_bytes, 'image' + return mime_type, image_bytes, "image" elif ext in self.gif_extensions: mime_type, image_bytes = await self.process_gif(attachment) - return mime_type, image_bytes, 'gif' + return mime_type, image_bytes, "gif" elif ext in self.video_extensions: mime_type, image_bytes = await self.process_video(attachment) - return mime_type, image_bytes, 'video' + return mime_type, image_bytes, "video" else: print(f"Unsupported file type: {ext}") return None, None, None @@ -328,7 +391,9 @@ class AIModerationCog(commands.Cog): try: # Download the video to a temporary file video_bytes = await attachment.read() - with tempfile.NamedTemporaryFile(suffix=os.path.splitext(attachment.filename)[1], delete=False) as temp_file: + with tempfile.NamedTemporaryFile( + suffix=os.path.splitext(attachment.filename)[1], delete=False + ) as temp_file: temp_file_path = temp_file.name temp_file.write(video_bytes) @@ -349,7 +414,7 @@ class AIModerationCog(commands.Cog): # Save to bytes buffer output = io.BytesIO() - pil_image.save(output, format='JPEG') + pil_image.save(output, format="JPEG") output.seek(0) # Clean up @@ -367,73 +432,138 @@ class AIModerationCog(commands.Cog): return None, None # --- AI Moderation Command Group --- - aimod_group = app_commands.Group(name="aimod", description="AI Moderation commands.") - config_subgroup = app_commands.Group(name="config", description="Configure AI moderation settings.", parent=aimod_group) - infractions_subgroup = app_commands.Group(name="infractions", description="Manage user infractions.", parent=aimod_group) - model_subgroup = app_commands.Group(name="model", description="Manage the AI model for moderation.", parent=aimod_group) - debug_subgroup = app_commands.Group(name="debug", description="Debugging commands for AI moderation.", parent=aimod_group) + aimod_group = app_commands.Group( + name="aimod", description="AI Moderation commands." + ) + config_subgroup = app_commands.Group( + name="config", + description="Configure AI moderation settings.", + parent=aimod_group, + ) + infractions_subgroup = app_commands.Group( + name="infractions", description="Manage user infractions.", parent=aimod_group + ) + model_subgroup = app_commands.Group( + name="model", + description="Manage the AI model for moderation.", + parent=aimod_group, + ) + debug_subgroup = app_commands.Group( + name="debug", + description="Debugging commands for AI moderation.", + parent=aimod_group, + ) - @config_subgroup.command(name="log_channel", description="Set the moderation log channel.") + @config_subgroup.command( + name="log_channel", description="Set the moderation log channel." + ) @app_commands.describe(channel="The text channel to use for moderation logs.") @app_commands.checks.has_permissions(administrator=True) - async def modset_log_channel(self, interaction: discord.Interaction, channel: discord.TextChannel): + async def modset_log_channel( + self, interaction: discord.Interaction, channel: discord.TextChannel + ): set_guild_config(interaction.guild.id, "MOD_LOG_CHANNEL_ID", channel.id) - await interaction.response.send_message(f"Moderation log channel set to {channel.mention}.", ephemeral=False) + await interaction.response.send_message( + f"Moderation log channel set to {channel.mention}.", ephemeral=False + ) - @config_subgroup.command(name="suggestions_channel", description="Set the suggestions channel.") + @config_subgroup.command( + name="suggestions_channel", description="Set the suggestions channel." + ) @app_commands.describe(channel="The text channel to use for suggestions.") @app_commands.checks.has_permissions(administrator=True) - async def modset_suggestions_channel(self, interaction: discord.Interaction, channel: discord.TextChannel): + async def modset_suggestions_channel( + self, interaction: discord.Interaction, channel: discord.TextChannel + ): set_guild_config(interaction.guild.id, "SUGGESTIONS_CHANNEL_ID", channel.id) - await interaction.response.send_message(f"Suggestions channel set to {channel.mention}.", ephemeral=False) + await interaction.response.send_message( + f"Suggestions channel set to {channel.mention}.", ephemeral=False + ) - @config_subgroup.command(name="moderator_role", description="Set the moderator role.") + @config_subgroup.command( + name="moderator_role", description="Set the moderator role." + ) @app_commands.describe(role="The role that identifies moderators.") @app_commands.checks.has_permissions(administrator=True) - async def modset_moderator_role(self, interaction: discord.Interaction, role: discord.Role): + async def modset_moderator_role( + self, interaction: discord.Interaction, role: discord.Role + ): set_guild_config(interaction.guild.id, "MODERATOR_ROLE_ID", role.id) - await interaction.response.send_message(f"Moderator role set to {role.mention}.", ephemeral=False) + await interaction.response.send_message( + f"Moderator role set to {role.mention}.", ephemeral=False + ) - @config_subgroup.command(name="suicidal_ping_role", description="Set the role to ping for suicidal content.") + @config_subgroup.command( + name="suicidal_ping_role", + description="Set the role to ping for suicidal content.", + ) @app_commands.describe(role="The role to ping for urgent suicidal content alerts.") @app_commands.checks.has_permissions(administrator=True) - async def modset_suicidal_ping_role(self, interaction: discord.Interaction, role: discord.Role): + async def modset_suicidal_ping_role( + self, interaction: discord.Interaction, role: discord.Role + ): set_guild_config(interaction.guild.id, "SUICIDAL_PING_ROLE_ID", role.id) - await interaction.response.send_message(f"Suicidal content ping role set to {role.mention}.", ephemeral=False) + await interaction.response.send_message( + f"Suicidal content ping role set to {role.mention}.", ephemeral=False + ) - @config_subgroup.command(name="add_nsfw_channel", description="Add a channel to the list of NSFW channels.") + @config_subgroup.command( + name="add_nsfw_channel", + description="Add a channel to the list of NSFW channels.", + ) @app_commands.describe(channel="The text channel to mark as NSFW for the bot.") @app_commands.checks.has_permissions(administrator=True) - async def modset_add_nsfw_channel(self, interaction: discord.Interaction, channel: discord.TextChannel): + async def modset_add_nsfw_channel( + self, interaction: discord.Interaction, channel: discord.TextChannel + ): guild_id = interaction.guild.id nsfw_channels: list[int] = get_guild_config(guild_id, "NSFW_CHANNEL_IDS", []) if channel.id not in nsfw_channels: nsfw_channels.append(channel.id) set_guild_config(guild_id, "NSFW_CHANNEL_IDS", nsfw_channels) - await interaction.response.send_message(f"{channel.mention} added to NSFW channels list.", ephemeral=False) + await interaction.response.send_message( + f"{channel.mention} added to NSFW channels list.", ephemeral=False + ) else: - await interaction.response.send_message(f"{channel.mention} is already in the NSFW channels list.", ephemeral=True) + await interaction.response.send_message( + f"{channel.mention} is already in the NSFW channels list.", + ephemeral=True, + ) - @config_subgroup.command(name="remove_nsfw_channel", description="Remove a channel from the list of NSFW channels.") + @config_subgroup.command( + name="remove_nsfw_channel", + description="Remove a channel from the list of NSFW channels.", + ) @app_commands.describe(channel="The text channel to remove from the NSFW list.") @app_commands.checks.has_permissions(administrator=True) - async def modset_remove_nsfw_channel(self, interaction: discord.Interaction, channel: discord.TextChannel): + async def modset_remove_nsfw_channel( + self, interaction: discord.Interaction, channel: discord.TextChannel + ): guild_id = interaction.guild.id nsfw_channels: list[int] = get_guild_config(guild_id, "NSFW_CHANNEL_IDS", []) if channel.id in nsfw_channels: nsfw_channels.remove(channel.id) set_guild_config(guild_id, "NSFW_CHANNEL_IDS", nsfw_channels) - await interaction.response.send_message(f"{channel.mention} removed from NSFW channels list.", ephemeral=False) + await interaction.response.send_message( + f"{channel.mention} removed from NSFW channels list.", ephemeral=False + ) else: - await interaction.response.send_message(f"{channel.mention} is not in the NSFW channels list.", ephemeral=True) + await interaction.response.send_message( + f"{channel.mention} is not in the NSFW channels list.", ephemeral=True + ) - @config_subgroup.command(name="list_nsfw_channels", description="List currently configured NSFW channels.") + @config_subgroup.command( + name="list_nsfw_channels", + description="List currently configured NSFW channels.", + ) @app_commands.checks.has_permissions(administrator=True) async def modset_list_nsfw_channels(self, interaction: discord.Interaction): guild_id = interaction.guild.id nsfw_channel_ids: list[int] = get_guild_config(guild_id, "NSFW_CHANNEL_IDS", []) if not nsfw_channel_ids: - await interaction.response.send_message("No NSFW channels are currently configured.", ephemeral=False) + await interaction.response.send_message( + "No NSFW channels are currently configured.", ephemeral=False + ) return channel_mentions = [] @@ -444,56 +574,82 @@ class AIModerationCog(commands.Cog): else: channel_mentions.append(f"ID:{channel_id} (not found)") - await interaction.response.send_message(f"Configured NSFW channels:\n- " + "\n- ".join(channel_mentions), ephemeral=False) + await interaction.response.send_message( + f"Configured NSFW channels:\n- " + "\n- ".join(channel_mentions), + ephemeral=False, + ) # Note: The @app_commands.command(name="modenable", ...) and other commands like # viewinfractions, clearinfractions, modsetmodel, modgetmodel remain as top-level commands # as they were not part of the original "modset" generic command structure. # If these also need to be grouped, that would be a separate consideration. - @config_subgroup.command(name="enable", description="Enable or disable moderation for this guild (admin only).") + @config_subgroup.command( + name="enable", + description="Enable or disable moderation for this guild (admin only).", + ) @app_commands.describe(enabled="Enable moderation (true/false)") async def modenable(self, interaction: discord.Interaction, enabled: bool): if not interaction.user.guild_permissions.administrator: - await interaction.response.send_message("You must be an administrator to use this command.", ephemeral=False) + await interaction.response.send_message( + "You must be an administrator to use this command.", ephemeral=False + ) return set_guild_config(interaction.guild.id, "ENABLED", enabled) - await interaction.response.send_message(f"Moderation is now {'enabled' if enabled else 'disabled'} for this guild.", ephemeral=False) + await interaction.response.send_message( + f"Moderation is now {'enabled' if enabled else 'disabled'} for this guild.", + ephemeral=False, + ) - @infractions_subgroup.command(name="view", description="View a user's AI moderation infraction history (mod/admin only).") + @infractions_subgroup.command( + name="view", + description="View a user's AI moderation infraction history (mod/admin only).", + ) @app_commands.describe(user="The user to view infractions for") - async def viewinfractions(self, interaction: discord.Interaction, user: discord.Member): + async def viewinfractions( + self, interaction: discord.Interaction, user: discord.Member + ): # Check if user has permission (admin or moderator role) moderator_role_id = get_guild_config(interaction.guild.id, "MODERATOR_ROLE_ID") - moderator_role = interaction.guild.get_role(moderator_role_id) if moderator_role_id else None + moderator_role = ( + interaction.guild.get_role(moderator_role_id) if moderator_role_id else None + ) - has_permission = (interaction.user.guild_permissions.administrator or - (moderator_role and moderator_role in interaction.user.roles)) + has_permission = interaction.user.guild_permissions.administrator or ( + moderator_role and moderator_role in interaction.user.roles + ) if not has_permission: - await interaction.response.send_message("You must be an administrator or have the moderator role to use this command.", ephemeral=True) + await interaction.response.send_message( + "You must be an administrator or have the moderator role to use this command.", + ephemeral=True, + ) return # Get the user's infraction history infractions = get_user_infraction_history(interaction.guild.id, user.id) if not infractions: - await interaction.response.send_message(f"{user.mention} has no recorded infractions.", ephemeral=False) + await interaction.response.send_message( + f"{user.mention} has no recorded infractions.", ephemeral=False + ) return # Create an embed to display the infractions embed = discord.Embed( title=f"Infraction History for {user.display_name}", description=f"User ID: {user.id}", - color=discord.Color.orange() + color=discord.Color.orange(), ) # Add each infraction to the embed for i, infraction in enumerate(infractions, 1): - timestamp = infraction.get('timestamp', 'Unknown date')[:19].replace('T', ' ') # Format ISO timestamp - rule = infraction.get('rule_violated', 'Unknown rule') - action = infraction.get('action_taken', 'Unknown action') - reason = infraction.get('reasoning', 'No reason provided') + timestamp = infraction.get("timestamp", "Unknown date")[:19].replace( + "T", " " + ) # Format ISO timestamp + rule = infraction.get("rule_violated", "Unknown rule") + action = infraction.get("action_taken", "Unknown action") + reason = infraction.get("reasoning", "No reason provided") # Truncate reason if it's too long if len(reason) > 200: @@ -502,7 +658,7 @@ class AIModerationCog(commands.Cog): embed.add_field( name=f"Infraction #{i} - {timestamp}", value=f"**Rule Violated:** {rule}\n**Action Taken:** {action}\n**Reason:** {reason}", - inline=False + inline=False, ) embed.set_footer(text=f"Total infractions: {len(infractions)}") @@ -510,12 +666,19 @@ class AIModerationCog(commands.Cog): await interaction.response.send_message(embed=embed, ephemeral=False) - @infractions_subgroup.command(name="clear", description="Clear a user's AI moderation infraction history (admin only).") + @infractions_subgroup.command( + name="clear", + description="Clear a user's AI moderation infraction history (admin only).", + ) @app_commands.describe(user="The user to clear infractions for") - async def clearinfractions(self, interaction: discord.Interaction, user: discord.Member): + async def clearinfractions( + self, interaction: discord.Interaction, user: discord.Member + ): # Check if user has administrator permission if not interaction.user.guild_permissions.administrator: - await interaction.response.send_message("You must be an administrator to use this command.", ephemeral=True) + await interaction.response.send_message( + "You must be an administrator to use this command.", ephemeral=True + ) return # Get the user's infraction history @@ -523,28 +686,42 @@ class AIModerationCog(commands.Cog): infractions = USER_INFRACTIONS.get(key, []) if not infractions: - await interaction.response.send_message(f"{user.mention} has no recorded infractions to clear.", ephemeral=False) + await interaction.response.send_message( + f"{user.mention} has no recorded infractions to clear.", ephemeral=False + ) return # Clear the user's infractions USER_INFRACTIONS[key] = [] save_user_infractions() - await interaction.response.send_message(f"Cleared {len(infractions)} infraction(s) for {user.mention}.", ephemeral=False) + await interaction.response.send_message( + f"Cleared {len(infractions)} infraction(s) for {user.mention}.", + ephemeral=False, + ) - @model_subgroup.command(name="set", description="Change the AI model used for moderation (admin only).") - @app_commands.describe(model="The Vertex AI model to use (e.g., 'gemini-1.5-flash-001', 'gemini-1.0-pro')") + @model_subgroup.command( + name="set", description="Change the AI model used for moderation (admin only)." + ) + @app_commands.describe( + model="The Vertex AI model to use (e.g., 'gemini-1.5-flash-001', 'gemini-1.0-pro')" + ) async def modsetmodel(self, interaction: discord.Interaction, model: str): # Check if user has administrator permission if not interaction.user.guild_permissions.administrator: - await interaction.response.send_message("You must be an administrator to use this command.", ephemeral=True) + await interaction.response.send_message( + "You must be an administrator to use this command.", ephemeral=True + ) return # Validate the model name (basic validation for Vertex AI) # Vertex AI models usually don't have "/" like OpenRouter, but can have "-" and numbers. # Example: gemini-1.5-flash-001 - if not model or len(model) < 5: # Basic check - await interaction.response.send_message("Invalid model format. Please provide a valid Vertex AI model ID (e.g., 'gemini-1.5-flash-001').", ephemeral=False) + if not model or len(model) < 5: # Basic check + await interaction.response.send_message( + "Invalid model format. Please provide a valid Vertex AI model ID (e.g., 'gemini-1.5-flash-001').", + ephemeral=False, + ) return # Save the model to guild configuration @@ -554,12 +731,16 @@ class AIModerationCog(commands.Cog): # Note: There's no global model variable to update here like OPENROUTER_MODEL. # The cog will use the guild-specific config or the DEFAULT_VERTEX_AI_MODEL. - await interaction.response.send_message(f"AI moderation model updated to `{model}` for this guild.", ephemeral=False) + await interaction.response.send_message( + f"AI moderation model updated to `{model}` for this guild.", ephemeral=False + ) # @modsetmodel.autocomplete('model') # Autocomplete removed as OpenRouter models are not used. # async def modsetmodel_autocomplete(...): # This function is now removed. - @model_subgroup.command(name="get", description="View the current AI model used for moderation.") + @model_subgroup.command( + name="get", description="View the current AI model used for moderation." + ) async def modgetmodel(self, interaction: discord.Interaction): # Get the model from guild config, fall back to global default guild_id = interaction.guild.id @@ -569,16 +750,20 @@ class AIModerationCog(commands.Cog): embed = discord.Embed( title="AI Moderation Model", description=f"The current AI model used for moderation in this server is:", - color=discord.Color.blue() + color=discord.Color.blue(), ) embed.add_field(name="Model In Use", value=f"`{model_used}`", inline=False) - embed.add_field(name="Default Model", value=f"`{DEFAULT_VERTEX_AI_MODEL}`", inline=False) + embed.add_field( + name="Default Model", value=f"`{DEFAULT_VERTEX_AI_MODEL}`", inline=False + ) embed.set_footer(text="Use /aimod model set to change the model") await interaction.response.send_message(embed=embed, ephemeral=False) # --- Helper Function to Safely Extract Text from Vertex AI Response --- - def _get_response_text(self, response: Optional[types.GenerateContentResponse]) -> Optional[str]: + def _get_response_text( + self, response: Optional[types.GenerateContentResponse] + ) -> Optional[str]: """ Safely extracts the text content from the first text part of a GenerateContentResponse. Handles potential errors and lack of text parts gracefully. @@ -588,43 +773,69 @@ class AIModerationCog(commands.Cog): print("[AIModerationCog._get_response_text] Received None response object.") return None - if hasattr(response, 'text') and response.text: # Some simpler responses might have .text directly - print("[AIModerationCog._get_response_text] Found text directly in response.text attribute.") + if ( + hasattr(response, "text") and response.text + ): # Some simpler responses might have .text directly + print( + "[AIModerationCog._get_response_text] Found text directly in response.text attribute." + ) return response.text if not response.candidates: - print(f"[AIModerationCog._get_response_text] Response object has no candidates. Response: {response}") + print( + f"[AIModerationCog._get_response_text] Response object has no candidates. Response: {response}" + ) return None try: candidate = response.candidates[0] - if not hasattr(candidate, 'content') or not candidate.content: - print(f"[AIModerationCog._get_response_text] Candidate 0 has no 'content'. Candidate: {candidate}") + if not hasattr(candidate, "content") or not candidate.content: + print( + f"[AIModerationCog._get_response_text] Candidate 0 has no 'content'. Candidate: {candidate}" + ) return None - if not hasattr(candidate.content, 'parts') or not candidate.content.parts: - print(f"[AIModerationCog._get_response_text] Candidate 0 content has no 'parts' or parts list is empty. types.Content: {candidate.content}") + if not hasattr(candidate.content, "parts") or not candidate.content.parts: + print( + f"[AIModerationCog._get_response_text] Candidate 0 content has no 'parts' or parts list is empty. types.Content: {candidate.content}" + ) return None for i, part in enumerate(candidate.content.parts): - if hasattr(part, 'text') and part.text is not None: + if hasattr(part, "text") and part.text is not None: if isinstance(part.text, str) and part.text.strip(): - print(f"[AIModerationCog._get_response_text] Found non-empty text in part {i}.") + print( + f"[AIModerationCog._get_response_text] Found non-empty text in part {i}." + ) return part.text else: - print(f"[AIModerationCog._get_response_text] types.Part {i} has 'text' attribute, but it's empty or not a string: {part.text!r}") - print(f"[AIModerationCog._get_response_text] No usable text part found in candidate 0 after iterating through all parts.") + print( + f"[AIModerationCog._get_response_text] types.Part {i} has 'text' attribute, but it's empty or not a string: {part.text!r}" + ) + print( + f"[AIModerationCog._get_response_text] No usable text part found in candidate 0 after iterating through all parts." + ) return None except (AttributeError, IndexError, TypeError) as e: - print(f"[AIModerationCog._get_response_text] Error accessing response structure: {type(e).__name__}: {e}") + print( + f"[AIModerationCog._get_response_text] Error accessing response structure: {type(e).__name__}: {e}" + ) print(f"Problematic response object: {response}") return None except Exception as e: - print(f"[AIModerationCog._get_response_text] Unexpected error extracting text: {e}") + print( + f"[AIModerationCog._get_response_text] Unexpected error extracting text: {e}" + ) print(f"Response object during error: {response}") return None - async def query_vertex_ai(self, message: discord.Message, message_content: str, user_history: str, image_data_list: Optional[List[Tuple[str, bytes, str, str]]] = None): + async def query_vertex_ai( + self, + message: discord.Message, + message_content: str, + user_history: str, + image_data_list: Optional[List[Tuple[str, bytes, str, str]]] = None, + ): """ Sends the message content, user history, and additional context to Google Vertex AI for analysis. Optionally includes image data for visual content moderation. @@ -638,7 +849,9 @@ class AIModerationCog(commands.Cog): Returns: A dictionary containing the AI's decision, or None if an error occurs. """ - print(f"query_vertex_ai called. Vertex AI client available: {self.genai_client is not None}") + print( + f"query_vertex_ai called. Vertex AI client available: {self.genai_client is not None}" + ) if not self.genai_client: print("Error: Vertex AI Client is not available. Cannot query API.") return None @@ -756,8 +969,8 @@ Example Response (Notify Mods): """ - member = message.author # This is a discord.Member object - server_role_str = "Unprivileged Member" # Default + member = message.author # This is a discord.Member object + server_role_str = "Unprivileged Member" # Default if member == await message.guild.fetch_member(message.guild.owner_id): server_role_str = "Server Owner" @@ -765,7 +978,12 @@ Example Response (Notify Mods): server_role_str = "Admin" else: perms = member.guild_permissions - if perms.manage_messages or perms.kick_members or perms.ban_members or perms.moderate_members: + if ( + perms.manage_messages + or perms.kick_members + or perms.ban_members + or perms.moderate_members + ): server_role_str = "Moderator" print(f"role: {server_role_str}") @@ -774,16 +992,22 @@ Example Response (Notify Mods): replied_to_message_content = "N/A (Not a reply)" if message.reference and message.reference.message_id: try: - replied_to_msg = await message.channel.fetch_message(message.reference.message_id) + replied_to_msg = await message.channel.fetch_message( + message.reference.message_id + ) replied_to_message_content = f"User '{replied_to_msg.author.name}' said: \"{replied_to_msg.content[:200]}\"" if len(replied_to_msg.content) > 200: replied_to_message_content += "..." except discord.NotFound: replied_to_message_content = "N/A (Replied-to message not found)" except discord.Forbidden: - replied_to_message_content = "N/A (Cannot fetch replied-to message - permissions)" + replied_to_message_content = ( + "N/A (Cannot fetch replied-to message - permissions)" + ) except Exception as e: - replied_to_message_content = f"N/A (Error fetching replied-to message: {e})" + replied_to_message_content = ( + f"N/A (Error fetching replied-to message: {e})" + ) # --- Fetch Recent Channel History --- recent_channel_history_str = "N/A (Could not fetch history)" @@ -791,20 +1015,33 @@ Example Response (Notify Mods): history_messages = [] # Fetch last 11 messages (current + 10 previous). We'll filter out the current one async for prev_msg in message.channel.history(limit=11, before=message): - if prev_msg.id != message.id: # Ensure we don't include the current message itself - author_name = prev_msg.author.name + " (BOT)" if prev_msg.author.bot else prev_msg.author.name - history_messages.append(f"- {author_name}: \"{prev_msg.content[:150]}{'...' if len(prev_msg.content) > 150 else ''}\" (ID: {prev_msg.id})") + if ( + prev_msg.id != message.id + ): # Ensure we don't include the current message itself + author_name = ( + prev_msg.author.name + " (BOT)" + if prev_msg.author.bot + else prev_msg.author.name + ) + history_messages.append( + f"- {author_name}: \"{prev_msg.content[:150]}{'...' if len(prev_msg.content) > 150 else ''}\" (ID: {prev_msg.id})" + ) if history_messages: # Reverse to show oldest first in the snippet, then take the last 10. - recent_channel_history_str = "\n".join(list(reversed(history_messages))[:10]) + recent_channel_history_str = "\n".join( + list(reversed(history_messages))[:10] + ) else: - recent_channel_history_str = "No recent messages before this one in the channel." + recent_channel_history_str = ( + "No recent messages before this one in the channel." + ) except discord.Forbidden: - recent_channel_history_str = "N/A (Cannot fetch channel history - permissions)" + recent_channel_history_str = ( + "N/A (Cannot fetch channel history - permissions)" + ) except Exception as e: recent_channel_history_str = f"N/A (Error fetching channel history: {e})" - # Prepare user prompt content list with proper OpenRouter format user_prompt_content_list = [] @@ -835,62 +1072,83 @@ Follow the JSON output format specified in the system prompt. CRITICAL: Do NOT output anything other than the required JSON response. """ # Add the text content first - user_prompt_content_list.append({ - "type": "text", - "text": user_context_text - }) + user_prompt_content_list.append({"type": "text", "text": user_context_text}) # Add images in the proper OpenRouter format if image_data_list and len(image_data_list) > 0: try: - for i, (mime_type, image_bytes, attachment_type, filename) in enumerate(image_data_list): + for i, (mime_type, image_bytes, attachment_type, filename) in enumerate( + image_data_list + ): try: # Encode image to base64 - base64_image = base64.b64encode(image_bytes).decode('utf-8') + base64_image = base64.b64encode(image_bytes).decode("utf-8") # Create data URL image_data_url = f"data:{mime_type};base64,{base64_image}" # Add image in OpenRouter format - user_prompt_content_list.append({ - "type": "image_url", - "image_url": { - "url": image_data_url - } - }) + user_prompt_content_list.append( + {"type": "image_url", "image_url": {"url": image_data_url}} + ) - print(f"Added attachment #{i+1}: {filename} ({attachment_type}) to the prompt") + print( + f"Added attachment #{i+1}: {filename} ({attachment_type}) to the prompt" + ) except Exception as e: - print(f"Error encoding image data for attachment {filename}: {e}") + print( + f"Error encoding image data for attachment {filename}: {e}" + ) except Exception as e: print(f"Error processing image data: {e}") # Add a text note about the error - user_prompt_content_list.append({ - "type": "text", - "text": f"Note: There were {len(image_data_list)} attached images, but they could not be processed for analysis." - }) + user_prompt_content_list.append( + { + "type": "text", + "text": f"Note: There were {len(image_data_list)} attached images, but they could not be processed for analysis.", + } + ) # Get guild-specific model if configured, otherwise use default member = message.author server_role_str = "Unprivileged Member" - if member == await message.guild.fetch_member(message.guild.owner_id): server_role_str = "Server Owner" - elif member.guild_permissions.administrator: server_role_str = "Admin" + if member == await message.guild.fetch_member(message.guild.owner_id): + server_role_str = "Server Owner" + elif member.guild_permissions.administrator: + server_role_str = "Admin" else: perms = member.guild_permissions - if perms.manage_messages or perms.kick_members or perms.ban_members or perms.moderate_members: server_role_str = "Moderator" + if ( + perms.manage_messages + or perms.kick_members + or perms.ban_members + or perms.moderate_members + ): + server_role_str = "Moderator" replied_to_message_content = "N/A (Not a reply)" if message.reference and message.reference.message_id: try: - replied_to_msg = await message.channel.fetch_message(message.reference.message_id) + replied_to_msg = await message.channel.fetch_message( + message.reference.message_id + ) replied_to_message_content = f"User '{replied_to_msg.author.name}' said: \"{replied_to_msg.content[:200]}{'...' if len(replied_to_msg.content) > 200 else ''}\"" - except Exception as e: replied_to_message_content = f"N/A (Error fetching replied-to: {e})" + except Exception as e: + replied_to_message_content = f"N/A (Error fetching replied-to: {e})" recent_channel_history_str = "N/A (Could not fetch history)" try: - history_messages = [f"- {prev_msg.author.name}{' (BOT)' if prev_msg.author.bot else ''}: \"{prev_msg.content[:150]}{'...' if len(prev_msg.content) > 150 else ''}\" (ID: {prev_msg.id})" - async for prev_msg in message.channel.history(limit=11, before=message) if prev_msg.id != message.id] - recent_channel_history_str = "\n".join(list(reversed(history_messages))[:10]) if history_messages else "No recent messages." - except Exception as e: recent_channel_history_str = f"N/A (Error fetching history: {e})" + history_messages = [ + f"- {prev_msg.author.name}{' (BOT)' if prev_msg.author.bot else ''}: \"{prev_msg.content[:150]}{'...' if len(prev_msg.content) > 150 else ''}\" (ID: {prev_msg.id})" + async for prev_msg in message.channel.history(limit=11, before=message) + if prev_msg.id != message.id + ] + recent_channel_history_str = ( + "\n".join(list(reversed(history_messages))[:10]) + if history_messages + else "No recent messages." + ) + except Exception as e: + recent_channel_history_str = f"N/A (Error fetching history: {e})" user_context_text = f"""User Infraction History (for {message.author.name}, ID: {message.author.id}): --- @@ -927,22 +1185,57 @@ CRITICAL: Do NOT output anything other than the required JSON response. # Ensure mime_type is one of the supported ones by Vertex, e.g., image/png, image/jpeg, etc. # Common image types are generally fine. # For video, the extracted frame is JPEG. - supported_image_mimes = ["image/png", "image/jpeg", "image/webp", "image/heic", "image/heif", "image/gif"] - clean_mime_type = mime_type.split(';')[0].lower() + supported_image_mimes = [ + "image/png", + "image/jpeg", + "image/webp", + "image/heic", + "image/heif", + "image/gif", + ] + clean_mime_type = mime_type.split(";")[0].lower() - if clean_mime_type in supported_image_mimes or attachment_type == 'video': # Video frame is jpeg - vertex_parts.append(types.Part(inline_data=types.Blob(data=image_bytes, mime_type=clean_mime_type if clean_mime_type in supported_image_mimes else "image/jpeg"))) - print(f"Added attachment {filename} ({attachment_type} as {clean_mime_type if clean_mime_type in supported_image_mimes else 'image/jpeg'}) to Vertex prompt") + if ( + clean_mime_type in supported_image_mimes + or attachment_type == "video" + ): # Video frame is jpeg + vertex_parts.append( + types.Part( + inline_data=types.Blob( + data=image_bytes, + mime_type=( + clean_mime_type + if clean_mime_type in supported_image_mimes + else "image/jpeg" + ), + ) + ) + ) + print( + f"Added attachment {filename} ({attachment_type} as {clean_mime_type if clean_mime_type in supported_image_mimes else 'image/jpeg'}) to Vertex prompt" + ) else: - print(f"Skipping attachment {filename} due to unsupported MIME type for Vertex: {mime_type}") - vertex_parts.append(types.Part(text=f"[System Note: Attachment '{filename}' of type '{mime_type}' was not processed as it's not directly supported for vision by the current model configuration.]")) + print( + f"Skipping attachment {filename} due to unsupported MIME type for Vertex: {mime_type}" + ) + vertex_parts.append( + types.Part( + text=f"[System Note: Attachment '{filename}' of type '{mime_type}' was not processed as it's not directly supported for vision by the current model configuration.]" + ) + ) except Exception as e: print(f"Error processing attachment {filename} for Vertex AI: {e}") - vertex_parts.append(types.Part(text=f"[System Note: Error processing attachment '{filename}'.]")) + vertex_parts.append( + types.Part( + text=f"[System Note: Error processing attachment '{filename}'.]" + ) + ) # Get guild-specific model if configured, otherwise use default guild_id = message.guild.id - model_id_to_use = get_guild_config(guild_id, "AI_MODEL", DEFAULT_VERTEX_AI_MODEL) + model_id_to_use = get_guild_config( + guild_id, "AI_MODEL", DEFAULT_VERTEX_AI_MODEL + ) # Vertex model path is usually like "publishers/google/models/gemini-1.5-flash-001" # If model_id_to_use is just "gemini-1.5-flash-001", prepend "publishers/google/models/" if not model_id_to_use.startswith("publishers/google/models/"): @@ -956,9 +1249,9 @@ CRITICAL: Do NOT output anything other than the required JSON response. generation_config = types.GenerateContentConfig( temperature=0.2, - max_output_tokens=2000, # Ensure enough for JSON + max_output_tokens=2000, # Ensure enough for JSON safety_settings=STANDARD_SAFETY_SETTINGS, - thinking_config=thinking_config + thinking_config=thinking_config, ) # Construct contents for Vertex AI API @@ -968,7 +1261,7 @@ CRITICAL: Do NOT output anything other than the required JSON response. # Here, we'll build the `contents` list. # The system prompt is part of the model's understanding, and the user prompt contains the task. # For multi-turn, history is added to `contents`. Here, it's a single-turn request. - + request_contents = [ # System prompt can be the first message if not using system_instruction in model # types.Content(role="system", parts=[types.Part(text=system_prompt_text)]), # This is one way @@ -982,7 +1275,6 @@ CRITICAL: Do NOT output anything other than the required JSON response. types.Content(role="user", parts=vertex_parts) ] - try: print(f"Querying Vertex AI model {model_path}...") @@ -990,68 +1282,106 @@ CRITICAL: Do NOT output anything other than the required JSON response. # The existing 'generation_config' (lines 1063-1072) already has temperature, max_tokens, safety_settings. # We need to add system_instruction to it. final_generation_config = types.GenerateContentConfig( - temperature=generation_config.temperature, # from existing config - max_output_tokens=generation_config.max_output_tokens, # from existing config - safety_settings=generation_config.safety_settings, # from existing config - system_instruction=types.Content(role="system", parts=[types.Part(text=system_prompt_text)]), - thinking_config=generation_config.thinking_config, # from existing config + temperature=generation_config.temperature, # from existing config + max_output_tokens=generation_config.max_output_tokens, # from existing config + safety_settings=generation_config.safety_settings, # from existing config + system_instruction=types.Content( + role="system", parts=[types.Part(text=system_prompt_text)] + ), + thinking_config=generation_config.thinking_config, # from existing config # response_mime_type="application/json", # Consider if model supports this for forcing JSON ) response = await self.genai_client.aio.models.generate_content( - model=model_path, # Correctly formatted model path - contents=request_contents, # User's message with context and images - config=final_generation_config, # Pass the config with system_instruction + model=model_path, # Correctly formatted model path + contents=request_contents, # User's message with context and images + config=final_generation_config, # Pass the config with system_instruction ) - + ai_response_content = self._get_response_text(response) - print(response.usage_metadata) # Print usage metadata for debugging + print(response.usage_metadata) # Print usage metadata for debugging if not ai_response_content: print("Error: AI response content is empty or could not be extracted.") # Log safety ratings if available - if response and response.candidates and response.candidates[0].safety_ratings: - ratings = ", ".join([f"{r.category.name}: {r.probability.name}" for r in response.candidates[0].safety_ratings]) + if ( + response + and response.candidates + and response.candidates[0].safety_ratings + ): + ratings = ", ".join( + [ + f"{r.category.name}: {r.probability.name}" + for r in response.candidates[0].safety_ratings + ] + ) print(f"Safety Ratings: {ratings}") - if response and response.candidates and response.candidates[0].finish_reason: - print(f"Finish Reason: {response.candidates[0].finish_reason.name}") + if ( + response + and response.candidates + and response.candidates[0].finish_reason + ): + print(f"Finish Reason: {response.candidates[0].finish_reason.name}") return None # Attempt to parse the JSON response from the AI try: # Clean potential markdown code blocks if ai_response_content.startswith("```json"): - ai_response_content = ai_response_content.strip("```json\n").strip("`\n ") + ai_response_content = ai_response_content.strip("```json\n").strip( + "`\n " + ) elif ai_response_content.startswith("```"): - ai_response_content = ai_response_content.strip("```\n").strip("`\n ") + ai_response_content = ai_response_content.strip("```\n").strip( + "`\n " + ) ai_decision = json.loads(ai_response_content) # Basic validation of the parsed JSON structure - if not isinstance(ai_decision, dict) or \ - not all(k in ai_decision for k in ["violation", "rule_violated", "reasoning", "action"]) or \ - not isinstance(ai_decision.get("violation"), bool): - print(f"Error: AI response missing expected keys or 'violation' is not bool. Response: {ai_response_content}") + if ( + not isinstance(ai_decision, dict) + or not all( + k in ai_decision + for k in ["violation", "rule_violated", "reasoning", "action"] + ) + or not isinstance(ai_decision.get("violation"), bool) + ): + print( + f"Error: AI response missing expected keys or 'violation' is not bool. Response: {ai_response_content}" + ) return None print(f"AI Analysis Received: {ai_decision}") return ai_decision except json.JSONDecodeError as e: - print(f"Error: Could not decode JSON response from AI: {e}. Response: {ai_response_content}") + print( + f"Error: Could not decode JSON response from AI: {e}. Response: {ai_response_content}" + ) return None - except Exception as e: # Catch other parsing errors - print(f"Error parsing AI response structure: {e}. Response: {ai_response_content}") + except Exception as e: # Catch other parsing errors + print( + f"Error parsing AI response structure: {e}. Response: {ai_response_content}" + ) return None except google_exceptions.GoogleAPICallError as e: print(f"Error calling Vertex AI API: {e}") return None except Exception as e: - print(f"An unexpected error occurred during Vertex AI query for message {message.id}: {e}") + print( + f"An unexpected error occurred during Vertex AI query for message {message.id}: {e}" + ) return None - async def handle_violation(self, message: discord.Message, ai_decision: dict, notify_mods_message: str = None): + + async def handle_violation( + self, + message: discord.Message, + ai_decision: dict, + notify_mods_message: str = None, + ): """ Takes action based on the AI's violation decision. Also transmits action info via HTTP POST with API key header. @@ -1061,13 +1391,21 @@ CRITICAL: Do NOT output anything other than the required JSON response. rule_violated = ai_decision.get("rule_violated", "Unknown") reasoning = ai_decision.get("reasoning", "No reasoning provided.") - action = ai_decision.get("action", "NOTIFY_MODS").upper() # Default to notify mods - guild_id = message.guild.id # Get guild_id once - user_id = message.author.id # Get user_id once + action = ai_decision.get( + "action", "NOTIFY_MODS" + ).upper() # Default to notify mods + guild_id = message.guild.id # Get guild_id once + user_id = message.author.id # Get user_id once moderator_role_id = get_guild_config(guild_id, "MODERATOR_ROLE_ID") - moderator_role = message.guild.get_role(moderator_role_id) if moderator_role_id else None - mod_ping = moderator_role.mention if moderator_role else f"Moderators (Role ID {moderator_role_id} not found)" + moderator_role = ( + message.guild.get_role(moderator_role_id) if moderator_role_id else None + ) + mod_ping = ( + moderator_role.mention + if moderator_role + else f"Moderators (Role ID {moderator_role_id} not found)" + ) current_timestamp_iso = datetime.datetime.now(datetime.timezone.utc).isoformat() @@ -1078,7 +1416,7 @@ CRITICAL: Do NOT output anything other than the required JSON response. try: mod_log_api_secret = os.getenv("MOD_LOG_API_SECRET") if mod_log_api_secret: - post_url = f"https://slipstreamm.dev/dashboard/api/guilds/{guild_id}/ai-moderation-action" #will be replaceing later with the Learnhelp API + post_url = f"https://slipstreamm.dev/dashboard/api/guilds/{guild_id}/ai-moderation-action" # will be replaceing later with the Learnhelp API payload = { "timestamp": current_timestamp_iso, "guild_id": guild_id, @@ -1089,25 +1427,31 @@ CRITICAL: Do NOT output anything other than the required JSON response. "message_link": message.jump_url, "user_id": user_id, "user_name": str(message.author), - "action": action, # This will be the AI suggested action before potential overrides + "action": action, # This will be the AI suggested action before potential overrides "rule_violated": rule_violated, "reasoning": reasoning, "violation": ai_decision.get("violation", False), - "message_content": message.content[:1024] if message.content else "", + "message_content": ( + message.content[:1024] if message.content else "" + ), "full_message_content": message.content if message.content else "", "ai_model": model_used, - "result": "pending_system_action" # Indicates AI decision received, system action pending + "result": "pending_system_action", # Indicates AI decision received, system action pending } headers = { "Authorization": f"Bearer {mod_log_api_secret}", - "Content-Type": "application/json" + "Content-Type": "application/json", } - async with aiohttp.ClientSession() as http_session: # Renamed session to avoid conflict - async with http_session.post(post_url, headers=headers, json=payload, timeout=10) as resp: + async with aiohttp.ClientSession() as http_session: # Renamed session to avoid conflict + async with http_session.post( + post_url, headers=headers, json=payload, timeout=10 + ) as resp: # This payload is just for the initial AI decision log # The actual outcome will be logged after the action is performed if resp.status >= 400: - print(f"Failed to POST initial AI decision log: {resp.status}") + print( + f"Failed to POST initial AI decision log: {resp.status}" + ) else: print("MOD_LOG_API_SECRET not set; skipping initial action POST.") except Exception as e: @@ -1117,67 +1461,133 @@ CRITICAL: Do NOT output anything other than the required JSON response. notification_embed = discord.Embed( title="🚨 Rule Violation Detected 🚨", description=f"AI analysis detected a violation of server rules.", - color=discord.Color.red() + color=discord.Color.red(), + ) + notification_embed.add_field( + name="User", + value=f"{message.author.mention} (`{message.author.id}`)", + inline=False, + ) + notification_embed.add_field( + name="Channel", value=message.channel.mention, inline=False + ) + notification_embed.add_field( + name="Rule Violated", value=f"**Rule {rule_violated}**", inline=True + ) + notification_embed.add_field( + name="AI Suggested Action", value=f"`{action}`", inline=True + ) + notification_embed.add_field( + name="AI Reasoning", value=f"_{reasoning}_", inline=False + ) + notification_embed.add_field( + name="Message Link", + value=f"[Jump to Message]({message.jump_url})", + inline=False, ) - notification_embed.add_field(name="User", value=f"{message.author.mention} (`{message.author.id}`)", inline=False) - notification_embed.add_field(name="Channel", value=message.channel.mention, inline=False) - notification_embed.add_field(name="Rule Violated", value=f"**Rule {rule_violated}**", inline=True) - notification_embed.add_field(name="AI Suggested Action", value=f"`{action}`", inline=True) - notification_embed.add_field(name="AI Reasoning", value=f"_{reasoning}_", inline=False) - notification_embed.add_field(name="Message Link", value=f"[Jump to Message]({message.jump_url})", inline=False) # Log message content and attachments for audit purposes msg_content = message.content if message.content else "*No text content*" - notification_embed.add_field(name="Message Content", value=msg_content[:1024], inline=False) + notification_embed.add_field( + name="Message Content", value=msg_content[:1024], inline=False + ) # Add attachment information if present if message.attachments: attachment_info = [] for i, attachment in enumerate(message.attachments): - attachment_info.append(f"{i+1}. {attachment.filename} ({attachment.content_type}) - [Link]({attachment.url})") + attachment_info.append( + f"{i+1}. {attachment.filename} ({attachment.content_type}) - [Link]({attachment.url})" + ) attachment_text = "\n".join(attachment_info) - notification_embed.add_field(name="Attachments", value=attachment_text[:1024], inline=False) + notification_embed.add_field( + name="Attachments", value=attachment_text[:1024], inline=False + ) # Add the first image as a thumbnail if it's an image type for attachment in message.attachments: - if any(attachment.filename.lower().endswith(ext) for ext in - self.image_extensions + self.gif_extensions + self.video_extensions): + if any( + attachment.filename.lower().endswith(ext) + for ext in self.image_extensions + + self.gif_extensions + + self.video_extensions + ): notification_embed.set_thumbnail(url=attachment.url) break # Use the model_used variable that was defined earlier - notification_embed.set_footer(text=f"AI Model: {model_used}. Learnhelp AI Moderation.") - notification_embed.timestamp = discord.utils.utcnow() # Using discord.utils.utcnow() which is still supported + notification_embed.set_footer( + text=f"AI Model: {model_used}. Learnhelp AI Moderation." + ) + notification_embed.timestamp = ( + discord.utils.utcnow() + ) # Using discord.utils.utcnow() which is still supported - action_taken_message = "" # To append to the notification + action_taken_message = "" # To append to the notification # --- Perform Actions --- try: if action == "BAN": - action_taken_message = f"Action Taken: User **BANNED** and message deleted." + action_taken_message = ( + f"Action Taken: User **BANNED** and message deleted." + ) notification_embed.color = discord.Color.dark_red() try: await message.delete() - except discord.NotFound: print("Message already deleted before banning.") + except discord.NotFound: + print("Message already deleted before banning.") except discord.Forbidden: - print(f"WARNING: Missing permissions to delete message before banning user {message.author}.") - action_taken_message += " (Failed to delete message - check permissions)" + print( + f"WARNING: Missing permissions to delete message before banning user {message.author}." + ) + action_taken_message += ( + " (Failed to delete message - check permissions)" + ) ban_reason = f"AI Mod: Rule {rule_violated}. Reason: {reasoning}" - await message.guild.ban(message.author, reason=ban_reason, delete_message_days=1) - print(f"BANNED user {message.author} for violating rule {rule_violated}.") - add_user_infraction(guild_id, user_id, rule_violated, "BAN", reasoning, current_timestamp_iso) + await message.guild.ban( + message.author, reason=ban_reason, delete_message_days=1 + ) + print( + f"BANNED user {message.author} for violating rule {rule_violated}." + ) + add_user_infraction( + guild_id, + user_id, + rule_violated, + "BAN", + reasoning, + current_timestamp_iso, + ) elif action == "KICK": - action_taken_message = f"Action Taken: User **KICKED** and message deleted." - notification_embed.color = discord.Color.from_rgb(255, 127, 0) # Dark Orange + action_taken_message = ( + f"Action Taken: User **KICKED** and message deleted." + ) + notification_embed.color = discord.Color.from_rgb( + 255, 127, 0 + ) # Dark Orange try: await message.delete() - except discord.NotFound: print("Message already deleted before kicking.") + except discord.NotFound: + print("Message already deleted before kicking.") except discord.Forbidden: - print(f"WARNING: Missing permissions to delete message before kicking user {message.author}.") - action_taken_message += " (Failed to delete message - check permissions)" + print( + f"WARNING: Missing permissions to delete message before kicking user {message.author}." + ) + action_taken_message += ( + " (Failed to delete message - check permissions)" + ) kick_reason = f"AI Mod: Rule {rule_violated}. Reason: {reasoning}" await message.author.kick(reason=kick_reason) - print(f"KICKED user {message.author} for violating rule {rule_violated}.") - add_user_infraction(guild_id, user_id, rule_violated, "KICK", reasoning, current_timestamp_iso) + print( + f"KICKED user {message.author} for violating rule {rule_violated}." + ) + add_user_infraction( + guild_id, + user_id, + rule_violated, + "KICK", + reasoning, + current_timestamp_iso, + ) elif action.startswith("TIMEOUT"): duration_seconds = 0 @@ -1189,7 +1599,7 @@ CRITICAL: Do NOT output anything other than the required JSON response. duration_seconds = 60 * 60 # 1 hour duration_readable = "1 hour" elif action == "TIMEOUT_LONG": - duration_seconds = 24 * 60 * 60 # 1 day + duration_seconds = 24 * 60 * 60 # 1 day duration_readable = "1 day" if duration_seconds > 0: @@ -1197,36 +1607,68 @@ CRITICAL: Do NOT output anything other than the required JSON response. notification_embed.color = discord.Color.blue() try: await message.delete() - except discord.NotFound: print(f"Message already deleted before timeout for {message.author}.") + except discord.NotFound: + print( + f"Message already deleted before timeout for {message.author}." + ) except discord.Forbidden: - print(f"WARNING: Missing permissions to delete message before timeout for {message.author}.") - action_taken_message += " (Failed to delete message - check permissions)" + print( + f"WARNING: Missing permissions to delete message before timeout for {message.author}." + ) + action_taken_message += ( + " (Failed to delete message - check permissions)" + ) - timeout_reason = f"AI Mod: Rule {rule_violated}. Reason: {reasoning}" + timeout_reason = ( + f"AI Mod: Rule {rule_violated}. Reason: {reasoning}" + ) # discord.py timeout takes a timedelta object - await message.author.timeout(discord.utils.utcnow() + datetime.timedelta(seconds=duration_seconds), reason=timeout_reason) - print(f"TIMED OUT user {message.author} for {duration_readable} for violating rule {rule_violated}.") - add_user_infraction(guild_id, user_id, rule_violated, action, reasoning, current_timestamp_iso) + await message.author.timeout( + discord.utils.utcnow() + + datetime.timedelta(seconds=duration_seconds), + reason=timeout_reason, + ) + print( + f"TIMED OUT user {message.author} for {duration_readable} for violating rule {rule_violated}." + ) + add_user_infraction( + guild_id, + user_id, + rule_violated, + action, + reasoning, + current_timestamp_iso, + ) else: - action_taken_message = "Action Taken: **Unknown timeout duration, notifying mods.**" - action = "NOTIFY_MODS" # Fallback if timeout duration is not recognized - print(f"Unknown timeout duration for action {action}. Defaulting to NOTIFY_MODS.") - + action_taken_message = ( + "Action Taken: **Unknown timeout duration, notifying mods.**" + ) + action = ( + "NOTIFY_MODS" # Fallback if timeout duration is not recognized + ) + print( + f"Unknown timeout duration for action {action}. Defaulting to NOTIFY_MODS." + ) elif action == "DELETE": action_taken_message = f"Action Taken: Message **DELETED**." await message.delete() - print(f"DELETED message from {message.author} for violating rule {rule_violated}.") + print( + f"DELETED message from {message.author} for violating rule {rule_violated}." + ) # Typically, a simple delete isn't a formal infraction unless it's part of a WARN. # If you want to log deletes as infractions, add: # add_user_infraction(guild_id, user_id, rule_violated, "DELETE", reasoning, current_timestamp_iso) - elif action == "WARN": - action_taken_message = f"Action Taken: Message **DELETED** (AI suggested WARN)." + action_taken_message = ( + f"Action Taken: Message **DELETED** (AI suggested WARN)." + ) notification_embed.color = discord.Color.orange() - await message.delete() # Warnings usually involve deleting the offending message - print(f"DELETED message from {message.author} (AI suggested WARN for rule {rule_violated}).") + await message.delete() # Warnings usually involve deleting the offending message + print( + f"DELETED message from {message.author} (AI suggested WARN for rule {rule_violated})." + ) try: dm_channel = await message.author.create_dm() await dm_channel.send( @@ -1235,81 +1677,130 @@ CRITICAL: Do NOT output anything other than the required JSON response. ) action_taken_message += " User notified via DM with warning." except discord.Forbidden: - print(f"Could not DM warning to {message.author} (DMs likely disabled).") + print( + f"Could not DM warning to {message.author} (DMs likely disabled)." + ) action_taken_message += " (Could not DM user for warning)." except Exception as e: print(f"Error sending warning DM to {message.author}: {e}") action_taken_message += " (Error sending warning DM)." - add_user_infraction(guild_id, user_id, rule_violated, "WARN", reasoning, current_timestamp_iso) - + add_user_infraction( + guild_id, + user_id, + rule_violated, + "WARN", + reasoning, + current_timestamp_iso, + ) elif action == "NOTIFY_MODS": action_taken_message = "Action Taken: **Moderator review requested.**" notification_embed.color = discord.Color.gold() - print(f"Notifying moderators about potential violation (Rule {rule_violated}) by {message.author}.") + print( + f"Notifying moderators about potential violation (Rule {rule_violated}) by {message.author}." + ) # NOTIFY_MODS itself isn't an infraction on the user, but a request for human review. # If mods take action, they would log it manually or via a mod command. if notify_mods_message: - notification_embed.add_field(name="Additional Mod Message", value=notify_mods_message, inline=False) + notification_embed.add_field( + name="Additional Mod Message", + value=notify_mods_message, + inline=False, + ) elif action == "SUICIDAL": - action_taken_message = "Action Taken: **User DMed resources, relevant role notified.**" + action_taken_message = ( + "Action Taken: **User DMed resources, relevant role notified.**" + ) # No infraction is typically logged for "SUICIDAL" as it's a support action. notification_embed.title = "🚨 Suicidal Content Detected 🚨" - notification_embed.color = discord.Color.dark_purple() # A distinct color + notification_embed.color = ( + discord.Color.dark_purple() + ) # A distinct color notification_embed.description = "AI analysis detected content indicating potential suicidal ideation." - print(f"SUICIDAL content detected from {message.author}. DMing resources and notifying role.") + print( + f"SUICIDAL content detected from {message.author}. DMing resources and notifying role." + ) # DM the user with help resources try: dm_channel = await message.author.create_dm() await dm_channel.send(SUICIDAL_HELP_RESOURCES) action_taken_message += " User successfully DMed." except discord.Forbidden: - print(f"Could not DM suicidal help resources to {message.author} (DMs likely disabled).") + print( + f"Could not DM suicidal help resources to {message.author} (DMs likely disabled)." + ) action_taken_message += " (Could not DM user - DMs disabled)." except Exception as e: - print(f"Error sending suicidal help resources DM to {message.author}: {e}") + print( + f"Error sending suicidal help resources DM to {message.author}: {e}" + ) action_taken_message += f" (Error DMing user: {e})." # The message itself is usually not deleted for suicidal content, to allow for intervention. # If deletion is desired, add: await message.delete() here. - else: # Includes "IGNORE" or unexpected actions - if ai_decision.get("violation"): # If violation is true but action is IGNORE - action_taken_message = "Action Taken: **None** (AI suggested IGNORE despite flagging violation - Review Recommended)." - notification_embed.color = discord.Color.light_grey() - print(f"AI flagged violation ({rule_violated}) but suggested IGNORE for message by {message.author}. Notifying mods for review.") + else: # Includes "IGNORE" or unexpected actions + if ai_decision.get( + "violation" + ): # If violation is true but action is IGNORE + action_taken_message = "Action Taken: **None** (AI suggested IGNORE despite flagging violation - Review Recommended)." + notification_embed.color = discord.Color.light_grey() + print( + f"AI flagged violation ({rule_violated}) but suggested IGNORE for message by {message.author}. Notifying mods for review." + ) else: # This case shouldn't be reached if called correctly, but handle defensively - print(f"No action taken for message by {message.author} (AI Action: {action}, Violation: False)") - return # Don't notify if no violation and action is IGNORE + print( + f"No action taken for message by {message.author} (AI Action: {action}, Violation: False)" + ) + return # Don't notify if no violation and action is IGNORE # --- Send Notification to Moderators/Relevant Role --- log_channel_id = get_guild_config(message.guild.id, "MOD_LOG_CHANNEL_ID") - log_channel = self.bot.get_channel(log_channel_id) if log_channel_id else None + log_channel = ( + self.bot.get_channel(log_channel_id) if log_channel_id else None + ) if not log_channel: - print(f"ERROR: Moderation log channel (ID: {log_channel_id}) not found or not configured. Defaulting to message channel.") + print( + f"ERROR: Moderation log channel (ID: {log_channel_id}) not found or not configured. Defaulting to message channel." + ) log_channel = message.channel if not log_channel: - print(f"ERROR: Could not find even the original message channel {message.channel.id} to send notification.") + print( + f"ERROR: Could not find even the original message channel {message.channel.id} to send notification." + ) return if action == "SUICIDAL": - suicidal_role_id = get_guild_config(message.guild.id, "SUICIDAL_PING_ROLE_ID") - suicidal_role = message.guild.get_role(suicidal_role_id) if suicidal_role_id else None - ping_target = suicidal_role.mention if suicidal_role else f"Role ID {suicidal_role_id} (Suicidal Content)" + suicidal_role_id = get_guild_config( + message.guild.id, "SUICIDAL_PING_ROLE_ID" + ) + suicidal_role = ( + message.guild.get_role(suicidal_role_id) + if suicidal_role_id + else None + ) + ping_target = ( + suicidal_role.mention + if suicidal_role + else f"Role ID {suicidal_role_id} (Suicidal Content)" + ) if not suicidal_role: print(f"ERROR: Suicidal ping role ID {suicidal_role_id} not found.") final_message = f"{ping_target}\n{action_taken_message}" await log_channel.send(content=final_message, embed=notification_embed) - elif moderator_role: # For other violations + elif moderator_role: # For other violations final_message = f"{mod_ping}\n{action_taken_message}" await log_channel.send(content=final_message, embed=notification_embed) - else: # Fallback if moderator role is also not found for non-suicidal actions - print(f"ERROR: Moderator role ID {moderator_role_id} not found for action {action}.") - + else: # Fallback if moderator role is also not found for non-suicidal actions + print( + f"ERROR: Moderator role ID {moderator_role_id} not found for action {action}." + ) except discord.Forbidden as e: - print(f"ERROR: Missing Permissions to perform action '{action}' for rule {rule_violated}. Details: {e}") + print( + f"ERROR: Missing Permissions to perform action '{action}' for rule {rule_violated}. Details: {e}" + ) # Try to notify mods about the failure if moderator_role: try: @@ -1319,22 +1810,29 @@ CRITICAL: Do NOT output anything other than the required JSON response. f"Reasoning: _{reasoning}_\nMessage Link: {message.jump_url}" ) except discord.Forbidden: - print("FATAL: Bot lacks permission to send messages, even error notifications.") + print( + "FATAL: Bot lacks permission to send messages, even error notifications." + ) except discord.NotFound: - print(f"Message {message.id} was likely already deleted when trying to perform action '{action}'.") + print( + f"Message {message.id} was likely already deleted when trying to perform action '{action}'." + ) except Exception as e: - print(f"An unexpected error occurred during action execution for message {message.id}: {e}") + print( + f"An unexpected error occurred during action execution for message {message.id}: {e}" + ) # Try to notify mods about the unexpected error if moderator_role: - try: + try: await message.channel.send( f"{mod_ping} **UNEXPECTED ERROR!** An error occurred while handling rule violation " f"for {message.author.mention}. Please check bot logs.\n" f"Rule: {rule_violated}, Action Attempted: {action}\nMessage Link: {message.jump_url}" ) - except discord.Forbidden: - print("FATAL: Bot lacks permission to send messages, even error notifications.") - + except discord.Forbidden: + print( + "FATAL: Bot lacks permission to send messages, even error notifications." + ) @commands.Cog.listener(name="on_message") async def message_listener(self, message: discord.Message): @@ -1347,15 +1845,17 @@ CRITICAL: Do NOT output anything other than the required JSON response. return # Ignore messages without content or attachments if not message.content and not message.attachments: - print(f"Ignoring message {message.id} with no content or attachments.") - return + print(f"Ignoring message {message.id} with no content or attachments.") + return # Ignore DMs if not message.guild: print(f"Ignoring message {message.id} from DM.") return # Check if moderation is enabled for this guild if not get_guild_config(message.guild.id, "ENABLED", True): - print(f"Moderation disabled for guild {message.guild.id}. Ignoring message {message.id}.") + print( + f"Moderation disabled for guild {message.guild.id}. Ignoring message {message.id}." + ) return # --- Suicidal Content Check --- @@ -1369,18 +1869,28 @@ CRITICAL: Do NOT output anything other than the required JSON response. if message.attachments: # Process all attachments for attachment in message.attachments: - mime_type, image_bytes, attachment_type = await self.process_attachment(attachment) + mime_type, image_bytes, attachment_type = await self.process_attachment( + attachment + ) if mime_type and image_bytes and attachment_type: - image_data_list.append((mime_type, image_bytes, attachment_type, attachment.filename)) - print(f"Processed attachment: {attachment.filename} as {attachment_type}") + image_data_list.append( + (mime_type, image_bytes, attachment_type, attachment.filename) + ) + print( + f"Processed attachment: {attachment.filename} as {attachment_type}" + ) # Log the number of attachments processed if image_data_list: - print(f"Processed {len(image_data_list)} attachments for message {message.id}") + print( + f"Processed {len(image_data_list)} attachments for message {message.id}" + ) # Only proceed with AI analysis if there's text to analyze or attachments if not message_content and not image_data_list: - print(f"Ignoring message {message.id} with no content or valid attachments.") + print( + f"Ignoring message {message.id} with no content or valid attachments." + ) return # NSFW channel check removed - AI will handle this context @@ -1388,7 +1898,9 @@ CRITICAL: Do NOT output anything other than the required JSON response. # --- Call AI for Analysis (All Rules) --- # Check if the Vertex AI client is available if not self.genai_client: - print(f"Skipping AI analysis for message {message.id}: Vertex AI client is not initialized.") + print( + f"Skipping AI analysis for message {message.id}: Vertex AI client is not initialized." + ) return # Prepare user history for the AI @@ -1396,70 +1908,110 @@ CRITICAL: Do NOT output anything other than the required JSON response. history_summary_parts = [] if infractions: for infr in infractions: - history_summary_parts.append(f"- Action: {infr.get('action_taken', 'N/A')} for Rule {infr.get('rule_violated', 'N/A')} on {infr.get('timestamp', 'N/A')[:10]}. Reason: {infr.get('reasoning', 'N/A')[:50]}...") - user_history_summary = "\n".join(history_summary_parts) if history_summary_parts else "No prior infractions recorded." + history_summary_parts.append( + f"- Action: {infr.get('action_taken', 'N/A')} for Rule {infr.get('rule_violated', 'N/A')} on {infr.get('timestamp', 'N/A')[:10]}. Reason: {infr.get('reasoning', 'N/A')[:50]}..." + ) + user_history_summary = ( + "\n".join(history_summary_parts) + if history_summary_parts + else "No prior infractions recorded." + ) # Limit history summary length to prevent excessively long prompts max_history_len = 500 if len(user_history_summary) > max_history_len: - user_history_summary = user_history_summary[:max_history_len-3] + "..." + user_history_summary = user_history_summary[: max_history_len - 3] + "..." - - print(f"Analyzing message {message.id} from {message.author} in #{message.channel.name} with history...") + print( + f"Analyzing message {message.id} from {message.author} in #{message.channel.name} with history..." + ) if image_data_list: attachment_types = [data[2] for data in image_data_list] - print(f"Including {len(image_data_list)} attachments in analysis: {', '.join(attachment_types)}") - ai_decision = await self.query_vertex_ai(message, message_content, user_history_summary, image_data_list) + print( + f"Including {len(image_data_list)} attachments in analysis: {', '.join(attachment_types)}" + ) + ai_decision = await self.query_vertex_ai( + message, message_content, user_history_summary, image_data_list + ) # --- Process AI Decision --- if not ai_decision: print(f"Failed to get valid AI decision for message {message.id}.") # Optionally notify mods about AI failure if it happens often # Store the failure attempt for debugging - self.last_ai_decisions.append({ + self.last_ai_decisions.append( + { + "message_id": message.id, + "author_name": str(message.author), + "author_id": message.author.id, + "message_content_snippet": ( + message.content[:100] + "..." + if len(message.content) > 100 + else message.content + ), + "timestamp": datetime.datetime.now( + datetime.timezone.utc + ).isoformat(), + "ai_decision": { + "error": "Failed to get valid AI decision", + "raw_response": None, + }, # Simplified error logging + } + ) + return # Stop if AI fails or returns invalid data + + # Store the AI decision regardless of violation status + self.last_ai_decisions.append( + { "message_id": message.id, "author_name": str(message.author), "author_id": message.author.id, - "message_content_snippet": message.content[:100] + "..." if len(message.content) > 100 else message.content, + "message_content_snippet": ( + message.content[:100] + "..." + if len(message.content) > 100 + else message.content + ), "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(), - "ai_decision": {"error": "Failed to get valid AI decision", "raw_response": None} # Simplified error logging - }) - return # Stop if AI fails or returns invalid data - - # Store the AI decision regardless of violation status - self.last_ai_decisions.append({ - "message_id": message.id, - "author_name": str(message.author), - "author_id": message.author.id, - "message_content_snippet": message.content[:100] + "..." if len(message.content) > 100 else message.content, - "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(), - "ai_decision": ai_decision - }) + "ai_decision": ai_decision, + } + ) # Check if the AI flagged a violation if ai_decision.get("violation"): # Handle the violation based on AI decision without overrides # Pass notify_mods_message if the action is NOTIFY_MODS - notify_mods_message = ai_decision.get("notify_mods_message") if ai_decision.get("action") == "NOTIFY_MODS" else None + notify_mods_message = ( + ai_decision.get("notify_mods_message") + if ai_decision.get("action") == "NOTIFY_MODS" + else None + ) await self.handle_violation(message, ai_decision, notify_mods_message) else: # AI found no violation - print(f"AI analysis complete for message {message.id}. No violation detected.") + print( + f"AI analysis complete for message {message.id}. No violation detected." + ) - @debug_subgroup.command(name="last_decisions", description="View the last 5 AI moderation decisions (admin only).") + @debug_subgroup.command( + name="last_decisions", + description="View the last 5 AI moderation decisions (admin only).", + ) @app_commands.checks.has_permissions(administrator=True) async def aidebug_last_decisions(self, interaction: discord.Interaction): if not self.last_ai_decisions: - await interaction.response.send_message("No AI decisions have been recorded yet.", ephemeral=True) + await interaction.response.send_message( + "No AI decisions have been recorded yet.", ephemeral=True + ) return embed = discord.Embed( - title="Last 5 AI Moderation Decisions", - color=discord.Color.purple() + title="Last 5 AI Moderation Decisions", color=discord.Color.purple() ) embed.timestamp = discord.utils.utcnow() - for i, record in enumerate(reversed(list(self.last_ai_decisions))): # Show newest first + for i, record in enumerate( + reversed(list(self.last_ai_decisions)) + ): # Show newest first decision_info = record.get("ai_decision", {}) violation = decision_info.get("violation", "N/A") rule_violated = decision_info.get("rule_violated", "N/A") @@ -1490,23 +2042,33 @@ CRITICAL: Do NOT output anything other than the required JSON response. embed.add_field( name=f"Decision #{len(self.last_ai_decisions) - i}", value=field_value, - inline=False + inline=False, ) - if len(embed.fields) >= 5: # Limit to 5 fields in one embed for very long entries, or send multiple embeds + if ( + len(embed.fields) >= 5 + ): # Limit to 5 fields in one embed for very long entries, or send multiple embeds break - if not embed.fields: # Should not happen if self.last_ai_decisions is not empty - await interaction.response.send_message("Could not format AI decisions.", ephemeral=True) - return + if not embed.fields: # Should not happen if self.last_ai_decisions is not empty + await interaction.response.send_message( + "Could not format AI decisions.", ephemeral=True + ) + return await interaction.response.send_message(embed=embed, ephemeral=True) @aidebug_last_decisions.error - async def aidebug_last_decisions_error(self, interaction: discord.Interaction, error: app_commands.AppCommandError): + async def aidebug_last_decisions_error( + self, interaction: discord.Interaction, error: app_commands.AppCommandError + ): if isinstance(error, app_commands.MissingPermissions): - await interaction.response.send_message("You must be an administrator to use this command.", ephemeral=True) + await interaction.response.send_message( + "You must be an administrator to use this command.", ephemeral=True + ) else: - await interaction.response.send_message(f"An error occurred: {error}", ephemeral=True) + await interaction.response.send_message( + f"An error occurred: {error}", ephemeral=True + ) print(f"Error in aidebug_last_decisions command: {error}") @@ -1517,6 +2079,7 @@ async def setup(bot: commands.Bot): await bot.add_cog(AIModerationCog(bot)) print("AIModerationCog has been loaded.") + if __name__ == "__main__": # Server rules to provide context to the AI SERVER_RULES = """ @@ -1651,7 +2214,9 @@ Example Response (Suicidal Content): example_channel_category_name = "Text Channels" example_channel_is_nsfw = False example_replied_to_message_content = "N/A (Not a reply)" - example_recent_channel_history_str = "- OtherUser: \"Hello there!\" (ID: 111)\n- AnotherUser: \"How are you?\" (ID: 222)" + example_recent_channel_history_str = ( + '- OtherUser: "Hello there!" (ID: 111)\n- AnotherUser: "How are you?" (ID: 222)' + ) example_message_content = "This is an example message that might be a bit edgy." user_prompt_text_example = f"""User Infraction History (for {example_message_author_name}, ID: {example_message_author_id}): diff --git a/cogs/ban_system_cog.py b/cogs/ban_system_cog.py index fe939c4..6131719 100644 --- a/cogs/ban_system_cog.py +++ b/cogs/ban_system_cog.py @@ -10,24 +10,28 @@ import asyncpg # Configure logging log = logging.getLogger(__name__) + class UserBannedError(commands.CheckFailure): """Custom exception for banned users.""" + def __init__(self, user_id: int, message: str): self.user_id = user_id self.message = message super().__init__(message) + class BanSystemCog(commands.Cog): """Cog for banning specific users from using the bot.""" def __init__(self, bot: commands.Bot): self.bot = bot - self.banned_users_cache = {} # user_id -> {reason, message, banned_at, banned_by} + self.banned_users_cache = ( + {} + ) # user_id -> {reason, message, banned_at, banned_by} # Create the main command group for this cog self.bansys_group = app_commands.Group( - name="bansys", - description="Bot user ban system commands (Owner only)" + name="bansys", description="Bot user ban system commands (Owner only)" ) # Register commands @@ -45,7 +49,9 @@ class BanSystemCog(commands.Cog): self.bot.add_check(self.check_if_user_banned) # Store the original interaction check if it exists - self.original_interaction_check = getattr(self.bot.tree, 'interaction_check', None) + self.original_interaction_check = getattr( + self.bot.tree, "interaction_check", None + ) # Register our interaction check for slash commands self.bot.tree.interaction_check = self.interaction_check @@ -55,13 +61,16 @@ class BanSystemCog(commands.Cog): # Wait for the bot to be ready to ensure the database pool is available await self.bot.wait_until_ready() - if not hasattr(self.bot, 'pg_pool') or self.bot.pg_pool is None: - log.error("PostgreSQL pool not available. Ban system will not work properly.") + if not hasattr(self.bot, "pg_pool") or self.bot.pg_pool is None: + log.error( + "PostgreSQL pool not available. Ban system will not work properly." + ) return try: async with self.bot.pg_pool.acquire() as conn: - await conn.execute(""" + await conn.execute( + """ CREATE TABLE IF NOT EXISTS banned_users ( user_id BIGINT PRIMARY KEY, reason TEXT, @@ -69,7 +78,8 @@ class BanSystemCog(commands.Cog): banned_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, banned_by BIGINT NOT NULL ); - """) + """ + ) log.info("Created or verified banned_users table in PostgreSQL.") # Load banned users into cache @@ -79,7 +89,7 @@ class BanSystemCog(commands.Cog): async def _load_banned_users(self): """Load all banned users into the cache.""" - if not hasattr(self.bot, 'pg_pool') or self.bot.pg_pool is None: + if not hasattr(self.bot, "pg_pool") or self.bot.pg_pool is None: log.error("PostgreSQL pool not available. Cannot load banned users.") return @@ -92,11 +102,11 @@ class BanSystemCog(commands.Cog): # Populate the cache with the database records for record in records: - self.banned_users_cache[record['user_id']] = { - 'reason': record['reason'], - 'message': record['message'], - 'banned_at': record['banned_at'], - 'banned_by': record['banned_by'] + self.banned_users_cache[record["user_id"]] = { + "reason": record["reason"], + "message": record["message"], + "banned_at": record["banned_at"], + "banned_by": record["banned_by"], } log.info(f"Loaded {len(records)} banned users into cache.") @@ -117,18 +127,22 @@ class BanSystemCog(commands.Cog): # If the interaction hasn't been responded to yet, respond with the ban message if not interaction.response.is_done(): - await interaction.response.send_message(ban_info['message'], ephemeral=True) + await interaction.response.send_message( + ban_info["message"], ephemeral=True + ) # Log the blocked interaction - log.warning(f"Blocked interaction from banned user {interaction.user.id}: {ban_info['message']}") + log.warning( + f"Blocked interaction from banned user {interaction.user.id}: {ban_info['message']}" + ) # Raise the exception to prevent further processing - raise UserBannedError(interaction.user.id, ban_info['message']) + raise UserBannedError(interaction.user.id, ban_info["message"]) async def check_if_user_banned(self, ctx): """Global check to prevent banned users from using prefix commands.""" # Skip check for DMs - if not isinstance(ctx, commands.Context) and not hasattr(ctx, 'guild'): + if not isinstance(ctx, commands.Context) and not hasattr(ctx, "guild"): return True # Get the user ID @@ -138,7 +152,7 @@ class BanSystemCog(commands.Cog): if user_id in self.banned_users_cache: ban_info = self.banned_users_cache[user_id] # Raise the custom exception with the ban message - raise UserBannedError(user_id, ban_info['message']) + raise UserBannedError(user_id, ban_info["message"]) # User is not banned, allow the command return True @@ -156,12 +170,16 @@ class BanSystemCog(commands.Cog): # If the interaction hasn't been responded to yet, respond with the ban message if not interaction.response.is_done(): try: - await interaction.response.send_message(ban_info['message'], ephemeral=True) + await interaction.response.send_message( + ban_info["message"], ephemeral=True + ) except Exception as e: - log.error(f"Error sending ban message to user {interaction.user.id}: {e}") + log.error( + f"Error sending ban message to user {interaction.user.id}: {e}" + ) # Raise the custom exception with the ban message - raise UserBannedError(interaction.user.id, ban_info['message']) + raise UserBannedError(interaction.user.id, ban_info["message"]) # If there was an original interaction check, call it if self.original_interaction_check is not None: @@ -178,13 +196,13 @@ class BanSystemCog(commands.Cog): name="ban", description="Ban a user from using the bot", callback=self.bansys_ban_callback, - parent=self.bansys_group + parent=self.bansys_group, ) app_commands.describe( user_id="The ID of the user to ban", message="The message to show when they try to use commands", reason="The reason for the ban (optional)", - ephemeral="Whether the response should be ephemeral (only visible to the user)" + ephemeral="Whether the response should be ephemeral (only visible to the user)", )(ban_command) self.bansys_group.add_command(ban_command) @@ -193,11 +211,11 @@ class BanSystemCog(commands.Cog): name="unban", description="Unban a user from using the bot", callback=self.bansys_unban_callback, - parent=self.bansys_group + parent=self.bansys_group, ) app_commands.describe( user_id="The ID of the user to unban", - ephemeral="Whether the response should be ephemeral (only visible to the user)" + ephemeral="Whether the response should be ephemeral (only visible to the user)", )(unban_command) self.bansys_group.add_command(unban_command) @@ -206,18 +224,27 @@ class BanSystemCog(commands.Cog): name="list", description="List all users banned from using the bot", callback=self.bansys_list_callback, - parent=self.bansys_group + parent=self.bansys_group, ) app_commands.describe( ephemeral="Whether the response should be ephemeral (only visible to the user)" )(list_command) self.bansys_group.add_command(list_command) - async def bansys_ban_callback(self, interaction: discord.Interaction, user_id: str, message: str, reason: Optional[str] = None, ephemeral: bool = True): + async def bansys_ban_callback( + self, + interaction: discord.Interaction, + user_id: str, + message: str, + reason: Optional[str] = None, + ephemeral: bool = True, + ): """Ban a user from using the bot.""" # Check if the user is the bot owner if interaction.user.id != self.bot.owner_id: - await interaction.response.send_message("This command can only be used by the bot owner.", ephemeral=ephemeral) + await interaction.response.send_message( + "This command can only be used by the bot owner.", ephemeral=ephemeral + ) return try: @@ -226,25 +253,33 @@ class BanSystemCog(commands.Cog): # Check if the user is already banned if user_id_int in self.banned_users_cache: - await interaction.response.send_message(f"User {user_id_int} is already banned.", ephemeral=ephemeral) + await interaction.response.send_message( + f"User {user_id_int} is already banned.", ephemeral=ephemeral + ) return # Add the user to the database - if hasattr(self.bot, 'pg_pool') and self.bot.pg_pool is not None: + if hasattr(self.bot, "pg_pool") and self.bot.pg_pool is not None: async with self.bot.pg_pool.acquire() as conn: - await conn.execute(""" + await conn.execute( + """ INSERT INTO banned_users (user_id, reason, message, banned_by) VALUES ($1, $2, $3, $4) ON CONFLICT (user_id) DO UPDATE SET reason = $2, message = $3, banned_by = $4, banned_at = CURRENT_TIMESTAMP - """, user_id_int, reason, message, interaction.user.id) + """, + user_id_int, + reason, + message, + interaction.user.id, + ) # Add the user to the cache self.banned_users_cache[user_id_int] = { - 'reason': reason, - 'message': message, - 'banned_at': datetime.datetime.now(datetime.timezone.utc), - 'banned_by': interaction.user.id + "reason": reason, + "message": message, + "banned_at": datetime.datetime.now(datetime.timezone.utc), + "banned_by": interaction.user.id, } # Try to get the user's name for a more informative message @@ -254,20 +289,33 @@ class BanSystemCog(commands.Cog): except: user_display = f"User ID: {user_id_int}" - await interaction.response.send_message(f"✅ Banned {user_display} from using the bot.\nMessage: {message}\nReason: {reason or 'No reason provided'}", ephemeral=ephemeral) - log.info(f"User {user_id_int} banned by {interaction.user.id}. Reason: {reason}") + await interaction.response.send_message( + f"✅ Banned {user_display} from using the bot.\nMessage: {message}\nReason: {reason or 'No reason provided'}", + ephemeral=ephemeral, + ) + log.info( + f"User {user_id_int} banned by {interaction.user.id}. Reason: {reason}" + ) except ValueError: - await interaction.response.send_message("Invalid user ID. Please provide a valid user ID.", ephemeral=ephemeral) + await interaction.response.send_message( + "Invalid user ID. Please provide a valid user ID.", ephemeral=ephemeral + ) except Exception as e: log.error(f"Error banning user {user_id}: {e}") - await interaction.response.send_message(f"An error occurred while banning the user: {e}", ephemeral=ephemeral) + await interaction.response.send_message( + f"An error occurred while banning the user: {e}", ephemeral=ephemeral + ) - async def bansys_unban_callback(self, interaction: discord.Interaction, user_id: str, ephemeral: bool = True): + async def bansys_unban_callback( + self, interaction: discord.Interaction, user_id: str, ephemeral: bool = True + ): """Unban a user from using the bot.""" # Check if the user is the bot owner if interaction.user.id != self.bot.owner_id: - await interaction.response.send_message("This command can only be used by the bot owner.", ephemeral=ephemeral) + await interaction.response.send_message( + "This command can only be used by the bot owner.", ephemeral=ephemeral + ) return try: @@ -276,13 +324,17 @@ class BanSystemCog(commands.Cog): # Check if the user is banned if user_id_int not in self.banned_users_cache: - await interaction.response.send_message(f"User {user_id_int} is not banned.", ephemeral=ephemeral) + await interaction.response.send_message( + f"User {user_id_int} is not banned.", ephemeral=ephemeral + ) return # Remove the user from the database - if hasattr(self.bot, 'pg_pool') and self.bot.pg_pool is not None: + if hasattr(self.bot, "pg_pool") and self.bot.pg_pool is not None: async with self.bot.pg_pool.acquire() as conn: - await conn.execute("DELETE FROM banned_users WHERE user_id = $1", user_id_int) + await conn.execute( + "DELETE FROM banned_users WHERE user_id = $1", user_id_int + ) # Remove the user from the cache del self.banned_users_cache[user_id_int] @@ -294,31 +346,43 @@ class BanSystemCog(commands.Cog): except: user_display = f"User ID: {user_id_int}" - await interaction.response.send_message(f"✅ Unbanned {user_display} from using the bot.", ephemeral=ephemeral) + await interaction.response.send_message( + f"✅ Unbanned {user_display} from using the bot.", ephemeral=ephemeral + ) log.info(f"User {user_id_int} unbanned by {interaction.user.id}.") except ValueError: - await interaction.response.send_message("Invalid user ID. Please provide a valid user ID.", ephemeral=ephemeral) + await interaction.response.send_message( + "Invalid user ID. Please provide a valid user ID.", ephemeral=ephemeral + ) except Exception as e: log.error(f"Error unbanning user {user_id}: {e}") - await interaction.response.send_message(f"An error occurred while unbanning the user: {e}", ephemeral=ephemeral) + await interaction.response.send_message( + f"An error occurred while unbanning the user: {e}", ephemeral=ephemeral + ) - async def bansys_list_callback(self, interaction: discord.Interaction, ephemeral: bool = True): + async def bansys_list_callback( + self, interaction: discord.Interaction, ephemeral: bool = True + ): """List all users banned from using the bot.""" # Check if the user is the bot owner if interaction.user.id != self.bot.owner_id: - await interaction.response.send_message("This command can only be used by the bot owner.", ephemeral=ephemeral) + await interaction.response.send_message( + "This command can only be used by the bot owner.", ephemeral=ephemeral + ) return if not self.banned_users_cache: - await interaction.response.send_message("No users are currently banned.", ephemeral=ephemeral) + await interaction.response.send_message( + "No users are currently banned.", ephemeral=ephemeral + ) return # Create an embed to display the banned users embed = discord.Embed( title="Banned Users", description=f"Total banned users: {len(self.banned_users_cache)}", - color=discord.Color.red() + color=discord.Color.red(), ) # Add each banned user to the embed @@ -331,11 +395,15 @@ class BanSystemCog(commands.Cog): user_display = f"User ID: {user_id}" # Format the banned_at timestamp - banned_at = ban_info['banned_at'].strftime("%Y-%m-%d %H:%M:%S UTC") if isinstance(ban_info['banned_at'], datetime.datetime) else "Unknown" + banned_at = ( + ban_info["banned_at"].strftime("%Y-%m-%d %H:%M:%S UTC") + if isinstance(ban_info["banned_at"], datetime.datetime) + else "Unknown" + ) # Try to get the banner's name try: - banner = await self.bot.fetch_user(ban_info['banned_by']) + banner = await self.bot.fetch_user(ban_info["banned_by"]) banner_display = f"{banner.name} ({ban_info['banned_by']})" except: banner_display = f"User ID: {ban_info['banned_by']}" @@ -344,10 +412,10 @@ class BanSystemCog(commands.Cog): embed.add_field( name=user_display, value=f"**Reason:** {ban_info['reason'] or 'No reason provided'}\n" - f"**Message:** {ban_info['message']}\n" - f"**Banned at:** {banned_at}\n" - f"**Banned by:** {banner_display}", - inline=False + f"**Message:** {ban_info['message']}\n" + f"**Banned at:** {banned_at}\n" + f"**Banned by:** {banner_display}", + inline=False, ) await interaction.response.send_message(embed=embed, ephemeral=ephemeral) @@ -355,10 +423,11 @@ class BanSystemCog(commands.Cog): def cog_unload(self): """Cleanup when the cog is unloaded.""" # Restore the original interaction check if it exists - if hasattr(self, 'original_interaction_check'): + if hasattr(self, "original_interaction_check"): self.bot.tree.interaction_check = self.original_interaction_check log.info("Restored original interaction check on cog unload.") + # Setup function for loading the cog async def setup(bot): """Add the BanSystemCog to the bot.""" diff --git a/cogs/bot_appearance_cog.py b/cogs/bot_appearance_cog.py index 1e36f9d..21fd671 100644 --- a/cogs/bot_appearance_cog.py +++ b/cogs/bot_appearance_cog.py @@ -4,46 +4,70 @@ from discord import app_commands import httpx import io + # --- Helper: Owner Check --- async def is_owner_check(interaction: discord.Interaction) -> bool: """Checks if the interacting user is the bot owner.""" return interaction.user.id == interaction.client.owner_id + class BotAppearanceCog(commands.Cog): def __init__(self, bot): self.bot = bot - @commands.command(name='change_nickname', help="Changes the bot's nickname in the current server. Admin only.") + @commands.command( + name="change_nickname", + help="Changes the bot's nickname in the current server. Admin only.", + ) @commands.has_permissions(administrator=True) async def change_nickname(self, ctx: commands.Context, *, new_nickname: str): """Changes the bot's nickname in the current server.""" try: await ctx.guild.me.edit(nick=new_nickname) - await ctx.send(f"My nickname has been changed to '{new_nickname}' in this server.") + await ctx.send( + f"My nickname has been changed to '{new_nickname}' in this server." + ) except discord.Forbidden: await ctx.send("I don't have permission to change my nickname here.") except Exception as e: await ctx.send(f"An error occurred: {e}") - @app_commands.command(name="change_nickname", description="Changes the bot's nickname in the current server.") + @app_commands.command( + name="change_nickname", + description="Changes the bot's nickname in the current server.", + ) @app_commands.describe(new_nickname="The new nickname for the bot.") @app_commands.checks.has_permissions(administrator=True) - async def slash_change_nickname(self, interaction: discord.Interaction, new_nickname: str): + async def slash_change_nickname( + self, interaction: discord.Interaction, new_nickname: str + ): """Changes the bot's nickname in the current server.""" try: await interaction.guild.me.edit(nick=new_nickname) - await interaction.response.send_message(f"My nickname has been changed to '{new_nickname}' in this server.", ephemeral=True) + await interaction.response.send_message( + f"My nickname has been changed to '{new_nickname}' in this server.", + ephemeral=True, + ) except discord.Forbidden: - await interaction.response.send_message("I don't have permission to change my nickname here.", ephemeral=True) + await interaction.response.send_message( + "I don't have permission to change my nickname here.", ephemeral=True + ) except Exception as e: - await interaction.response.send_message(f"An error occurred: {e}", ephemeral=True) + await interaction.response.send_message( + f"An error occurred: {e}", ephemeral=True + ) - @commands.command(name='change_avatar', help="Changes the bot's global avatar. Owner only. Provide a direct image URL.") + @commands.command( + name="change_avatar", + help="Changes the bot's global avatar. Owner only. Provide a direct image URL.", + ) @commands.is_owner() async def change_avatar(self, ctx: commands.Context, image_url: str): """Changes the bot's global avatar. Requires a direct image URL.""" - if not (image_url.startswith('http://') or image_url.startswith('https://')): - await ctx.send("Invalid URL. Please provide a direct link to an image (http:// or https://).") + if not (image_url.startswith("http://") or image_url.startswith("https://")): + await ctx.send( + "Invalid URL. Please provide a direct link to an image (http:// or https://)." + ) return try: @@ -57,35 +81,56 @@ class BotAppearanceCog(commands.Cog): except httpx.RequestError as e: await ctx.send(f"Could not fetch the image from the URL: {e}") except discord.Forbidden: - await ctx.send("I don't have permission to change my avatar. This might be due to rate limits or other restrictions.") + await ctx.send( + "I don't have permission to change my avatar. This might be due to rate limits or other restrictions." + ) except discord.HTTPException as e: await ctx.send(f"Failed to change avatar. Discord API error: {e}") except Exception as e: await ctx.send(f"An unexpected error occurred: {e}") - @app_commands.command(name="change_avatar", description="Changes the bot's global avatar using a URL or an uploaded image.") + @app_commands.command( + name="change_avatar", + description="Changes the bot's global avatar using a URL or an uploaded image.", + ) @app_commands.describe( image_url="A direct URL to the image for the new avatar (optional if attachment is provided).", - attachment="An image file to use as the new avatar (optional if URL is provided)." + attachment="An image file to use as the new avatar (optional if URL is provided).", ) @app_commands.check(is_owner_check) - async def slash_change_avatar(self, interaction: discord.Interaction, image_url: str = None, attachment: discord.Attachment = None): + async def slash_change_avatar( + self, + interaction: discord.Interaction, + image_url: str = None, + attachment: discord.Attachment = None, + ): """Changes the bot's global avatar. Accepts a direct image URL or an attachment.""" await interaction.response.defer(ephemeral=True) image_bytes = None if attachment: - if not attachment.content_type or not attachment.content_type.startswith('image/'): - await interaction.response.send_message("Invalid file type. Please upload an image.", ephemeral=True) + if not attachment.content_type or not attachment.content_type.startswith( + "image/" + ): + await interaction.response.send_message( + "Invalid file type. Please upload an image.", ephemeral=True + ) return try: image_bytes = await attachment.read() except Exception as e: - await interaction.response.send_message(f"Could not read the attached image: {e}", ephemeral=True) + await interaction.response.send_message( + f"Could not read the attached image: {e}", ephemeral=True + ) return elif image_url: - if not (image_url.startswith('http://') or image_url.startswith('https://')): - await interaction.response.send_message("Invalid URL. Please provide a direct link to an image (http:// or https://).", ephemeral=True) + if not ( + image_url.startswith("http://") or image_url.startswith("https://") + ): + await interaction.response.send_message( + "Invalid URL. Please provide a direct link to an image (http:// or https://).", + ephemeral=True, + ) return try: async with httpx.AsyncClient() as client: @@ -93,53 +138,89 @@ class BotAppearanceCog(commands.Cog): response.raise_for_status() # Raise an exception for bad status codes image_bytes = await response.aread() except httpx.RequestError as e: - await interaction.response.send_message(f"Could not fetch the image from the URL: {e}", ephemeral=True) + await interaction.response.send_message( + f"Could not fetch the image from the URL: {e}", ephemeral=True + ) return else: - await interaction.response.send_message("Please provide either an image URL or an attachment.", ephemeral=True) + await interaction.response.send_message( + "Please provide either an image URL or an attachment.", ephemeral=True + ) return if image_bytes: try: await self.bot.user.edit(avatar=image_bytes) - await interaction.response.send_message("My avatar has been updated!", ephemeral=True) + await interaction.response.send_message( + "My avatar has been updated!", ephemeral=True + ) except discord.Forbidden: - await interaction.response.send_message("I don't have permission to change my avatar. This might be due to rate limits or other restrictions.", ephemeral=True) + await interaction.response.send_message( + "I don't have permission to change my avatar. This might be due to rate limits or other restrictions.", + ephemeral=True, + ) except discord.HTTPException as e: - await interaction.response.send_message(f"Failed to change avatar. Discord API error: {e}", ephemeral=True) + await interaction.response.send_message( + f"Failed to change avatar. Discord API error: {e}", ephemeral=True + ) except Exception as e: - await interaction.response.send_message(f"An unexpected error occurred: {e}", ephemeral=True) + await interaction.response.send_message( + f"An unexpected error occurred: {e}", ephemeral=True + ) # This else should ideally not be reached if logic above is correct, but as a fallback: else: - await interaction.response.send_message("Failed to process the image.", ephemeral=True) + await interaction.response.send_message( + "Failed to process the image.", ephemeral=True + ) @change_nickname.error @change_avatar.error async def on_command_error(self, ctx: commands.Context, error): if isinstance(error, commands.MissingPermissions): - await ctx.send("You don't have the required permissions (Administrator) to use this command.") + await ctx.send( + "You don't have the required permissions (Administrator) to use this command." + ) elif isinstance(error, commands.NotOwner): - await ctx.send("This command can only be used by the bot owner. If you wish to customize your bot's appearance, please set up a custom bot on the [web dashboard.](https://slipstreamm.dev/dashboard/)") + await ctx.send( + "This command can only be used by the bot owner. If you wish to customize your bot's appearance, please set up a custom bot on the [web dashboard.](https://slipstreamm.dev/dashboard/)" + ) elif isinstance(error, commands.MissingRequiredArgument): - await ctx.send(f"Missing required argument: `{error.param.name}`. Please check the command's help.") + await ctx.send( + f"Missing required argument: `{error.param.name}`. Please check the command's help." + ) else: - print(f"Error in BotAppearanceCog: {error}") # Log other errors to console + print(f"Error in BotAppearanceCog: {error}") # Log other errors to console await ctx.send("An internal error occurred. Please check the logs.") # It's generally better to handle app command errors with a cog-level error handler # or within each command if specific handling is needed. # For simplicity, adding a basic error handler for app_commands. - async def cog_app_command_error(self, interaction: discord.Interaction, error: app_commands.AppCommandError): + async def cog_app_command_error( + self, interaction: discord.Interaction, error: app_commands.AppCommandError + ): if isinstance(error, app_commands.MissingPermissions): - await interaction.response.send_message("You don't have the required permissions (Administrator) to use this command.", ephemeral=True) + await interaction.response.send_message( + "You don't have the required permissions (Administrator) to use this command.", + ephemeral=True, + ) elif isinstance(error, app_commands.CheckFailure): - await interaction.response.send_message("This command can only be used by the bot owner. If you wish to customize your bot's appearance, please set up a custom bot on the web dashboard.", ephemeral=True) + await interaction.response.send_message( + "This command can only be used by the bot owner. If you wish to customize your bot's appearance, please set up a custom bot on the web dashboard.", + ephemeral=True, + ) else: - print(f"Error in BotAppearanceCog (app_command): {error}") # Log other errors to console + print( + f"Error in BotAppearanceCog (app_command): {error}" + ) # Log other errors to console if not interaction.response.is_done(): - await interaction.response.send_message("An internal error occurred. Please check the logs.", ephemeral=True) + await interaction.response.send_message( + "An internal error occurred. Please check the logs.", ephemeral=True + ) else: - await interaction.followup.send("An internal error occurred. Please check the logs.", ephemeral=True) + await interaction.followup.send( + "An internal error occurred. Please check the logs.", ephemeral=True + ) + async def setup(bot): await bot.add_cog(BotAppearanceCog(bot)) diff --git a/cogs/caption_cog.py b/cogs/caption_cog.py index 95802c0..aed3b24 100644 --- a/cogs/caption_cog.py +++ b/cogs/caption_cog.py @@ -5,7 +5,8 @@ from PIL import Image, ImageDraw, ImageFont, ImageSequence import requests import io import os -import textwrap # Import textwrap for text wrapping +import textwrap # Import textwrap for text wrapping + class CaptionCog(commands.Cog, name="Caption"): """Cog for captioning GIFs""" @@ -14,7 +15,7 @@ class CaptionCog(commands.Cog, name="Caption"): CAPTION_PADDING = 10 DEFAULT_GIF_DURATION = 100 MIN_FONT_SIZE = 10 - MAX_FONT_SIZE = 30 # Decreased max font size + MAX_FONT_SIZE = 30 # Decreased max font size TEXT_COLOR = (0, 0, 0) # Black text BAR_COLOR = (255, 255, 255) # White bar @@ -22,7 +23,7 @@ class CaptionCog(commands.Cog, name="Caption"): self.bot = bot # Define preferred font names/paths self.preferred_fonts = [ - os.path.join("FONT", "OPTIFutura-ExtraBlackCond.otf") # Bundled fallback + os.path.join("FONT", "OPTIFutura-ExtraBlackCond.otf") # Bundled fallback ] def _add_text_to_gif(self, image_bytes: bytes, caption_text: str): @@ -33,21 +34,25 @@ class CaptionCog(commands.Cog, name="Caption"): try: gif = Image.open(io.BytesIO(image_bytes)) frames = [] - + # Determine font size (e.g., 20% of image height, capped) - font_size = max(self.MIN_FONT_SIZE, min(self.MAX_FONT_SIZE, int(gif.height * 0.2))) - + font_size = max( + self.MIN_FONT_SIZE, min(self.MAX_FONT_SIZE, int(gif.height * 0.2)) + ) + font = None for font_choice in self.preferred_fonts: try: font = ImageFont.truetype(font_choice, font_size) print(f"Successfully loaded font: {font_choice}") - break + break except IOError: print(f"Could not load font: {font_choice}. Trying next option.") - + if font is None: - print("All preferred fonts failed to load. Using Pillow's default font.") + print( + "All preferred fonts failed to load. Using Pillow's default font." + ) font = ImageFont.load_default() # Adjust font size for default font if necessary, as it might render differently. # This might require re-calculating text_width and text_height if default font is used. @@ -59,10 +64,12 @@ class CaptionCog(commands.Cog, name="Caption"): # Estimate characters per line based on font size and image width # This is a heuristic and might need adjustment based on the font estimated_char_width = font_size * 0.6 - if estimated_char_width == 0: # Avoid division by zero if font_size is somehow 0 - estimated_char_width = 1 + if ( + estimated_char_width == 0 + ): # Avoid division by zero if font_size is somehow 0 + estimated_char_width = 1 chars_per_line = int(max_text_width / estimated_char_width) - if chars_per_line <= 0: # Ensure at least one character per line + if chars_per_line <= 0: # Ensure at least one character per line chars_per_line = 1 wrapped_text = textwrap.wrap(caption_text, width=chars_per_line) @@ -74,10 +81,10 @@ class CaptionCog(commands.Cog, name="Caption"): line_heights = [] for line in wrapped_text: - if hasattr(dummy_draw, 'textbbox'): + if hasattr(dummy_draw, "textbbox"): text_bbox = dummy_draw.textbbox((0, 0), line, font=font) line_heights.append(text_bbox[3] - text_bbox[1]) - else: # For older Pillow versions, use textsize (deprecated) + else: # For older Pillow versions, use textsize (deprecated) line_heights.append(dummy_draw.textsize(line, font=font)[1]) total_text_height = sum(line_heights) @@ -91,11 +98,15 @@ class CaptionCog(commands.Cog, name="Caption"): new_frame_width = frame.width new_frame_height = frame.height + bar_height - new_frame = Image.new("RGBA", (new_frame_width, new_frame_height), (0,0,0,0)) # Transparent background for the new area + new_frame = Image.new( + "RGBA", (new_frame_width, new_frame_height), (0, 0, 0, 0) + ) # Transparent background for the new area # Draw the white bar draw = ImageDraw.Draw(new_frame) - draw.rectangle([(0, 0), (new_frame_width, bar_height)], fill=self.BAR_COLOR) + draw.rectangle( + [(0, 0), (new_frame_width, bar_height)], fill=self.BAR_COLOR + ) # Paste the original frame below the bar new_frame.paste(frame, (0, bar_height)) @@ -104,27 +115,38 @@ class CaptionCog(commands.Cog, name="Caption"): text_y_offset = self.CAPTION_PADDING for line in wrapped_text: # Calculate text position (centered in the bar horizontally) - if hasattr(draw, 'textbbox'): - line_width = draw.textbbox((0, 0), line, font=font)[2] - draw.textbbox((0, 0), line, font=font)[0] - line_height = draw.textbbox((0, 0), line, font=font)[3] - draw.textbbox((0, 0), line, font=font)[1] - else: # For older Pillow versions, use textsize (deprecated) + if hasattr(draw, "textbbox"): + line_width = ( + draw.textbbox((0, 0), line, font=font)[2] + - draw.textbbox((0, 0), line, font=font)[0] + ) + line_height = ( + draw.textbbox((0, 0), line, font=font)[3] + - draw.textbbox((0, 0), line, font=font)[1] + ) + else: # For older Pillow versions, use textsize (deprecated) line_width, line_height = draw.textsize(line, font=font) text_x = (new_frame_width - line_width) / 2 - draw.text((text_x, text_y_offset), line, font=font, fill=self.TEXT_COLOR) + draw.text( + (text_x, text_y_offset), line, font=font, fill=self.TEXT_COLOR + ) text_y_offset += line_height - # Reduce colors to optimize GIF and ensure compatibility - new_frame_alpha = new_frame.getchannel('A') - new_frame = new_frame.convert("RGB").convert("P", palette=Image.ADAPTIVE, colors=255) + new_frame_alpha = new_frame.getchannel("A") + new_frame = new_frame.convert("RGB").convert( + "P", palette=Image.ADAPTIVE, colors=255 + ) # If original had transparency, re-apply mask - if gif.info.get('transparency', None) is not None: - new_frame.info['transparency'] = gif.info['transparency'] # Preserve transparency if present - # Masking might be needed here if the original GIF had complex transparency - # For simplicity, we assume simple transparency or opaque. - # If issues arise, more complex alpha compositing might be needed before converting to "P") + if gif.info.get("transparency", None) is not None: + new_frame.info["transparency"] = gif.info[ + "transparency" + ] # Preserve transparency if present + # Masking might be needed here if the original GIF had complex transparency + # For simplicity, we assume simple transparency or opaque. + # If issues arise, more complex alpha compositing might be needed before converting to "P") frames.append(new_frame) @@ -134,10 +156,16 @@ class CaptionCog(commands.Cog, name="Caption"): format="GIF", save_all=True, append_images=frames[1:], - duration=gif.info.get("duration", self.DEFAULT_GIF_DURATION), # Use original duration, default to constant - loop=gif.info.get("loop", 0), # Use original loop count, default to infinite - transparency=gif.info.get("transparency", None), # Preserve transparency - disposal=2 # Important for GIFs with transparency and animation + duration=gif.info.get( + "duration", self.DEFAULT_GIF_DURATION + ), # Use original duration, default to constant + loop=gif.info.get( + "loop", 0 + ), # Use original loop count, default to infinite + transparency=gif.info.get( + "transparency", None + ), # Preserve transparency + disposal=2, # Important for GIFs with transparency and animation ) output_gif_bytes.seek(0) return output_gif_bytes @@ -145,29 +173,50 @@ class CaptionCog(commands.Cog, name="Caption"): print(f"Error in _add_text_to_gif: {e}") return None - @app_commands.command(name="captiongif", description="Captions a GIF with the provided text.") + @app_commands.command( + name="captiongif", description="Captions a GIF with the provided text." + ) @app_commands.describe( caption="The text to add to the GIF.", url="A URL to a GIF.", - attachment="An uploaded GIF file." + attachment="An uploaded GIF file.", ) - async def caption_gif_slash(self, interaction: discord.Interaction, caption: str, url: str = None, attachment: discord.Attachment = None): + async def caption_gif_slash( + self, + interaction: discord.Interaction, + caption: str, + url: str = None, + attachment: discord.Attachment = None, + ): """Slash command to caption a GIF.""" await interaction.response.defer(thinking=True) if not url and not attachment: - await interaction.followup.send("You must provide either a GIF URL or attach a GIF file.", ephemeral=True) + await interaction.followup.send( + "You must provide either a GIF URL or attach a GIF file.", + ephemeral=True, + ) return if url and attachment: - await interaction.followup.send("Please provide either a URL or an attachment, not both.", ephemeral=True) + await interaction.followup.send( + "Please provide either a URL or an attachment, not both.", + ephemeral=True, + ) return image_bytes = None filename = "captioned_gif.gif" if url: - if not (url.startswith("http://tenor.com/") or url.startswith("https://tenor.com/") or url.endswith(".gif")): - await interaction.followup.send("The URL must be a direct link to a GIF or a Tenor GIF URL.", ephemeral=True) + if not ( + url.startswith("http://tenor.com/") + or url.startswith("https://tenor.com/") + or url.endswith(".gif") + ): + await interaction.followup.send( + "The URL must be a direct link to a GIF or a Tenor GIF URL.", + ephemeral=True, + ) return try: # Handle Tenor URLs - they often don't directly link to the .gif @@ -176,17 +225,19 @@ class CaptionCog(commands.Cog, name="Caption"): # or that a direct .gif link is provided. # A common pattern for Tenor is to find a .mp4 or .gif in the HTML if it's a page URL. # This part might need improvement for robust Tenor URL handling. - + # Basic check for direct .gif or try to fetch content response = requests.get(url, timeout=10) response.raise_for_status() content_type = response.headers.get("Content-Type", "").lower() - - if "gif" not in content_type and url.endswith(".gif"): # If content-type is not gif but url ends with .gif + + if "gif" not in content_type and url.endswith( + ".gif" + ): # If content-type is not gif but url ends with .gif image_bytes = response.content elif "gif" in content_type: image_bytes = response.content - elif "tenor.com" in url: # If it's a tenor URL but not directly a gif + elif "tenor.com" in url: # If it's a tenor URL but not directly a gif # This is a placeholder for more robust Tenor GIF extraction. # Often, the actual GIF is embedded. For now, we'll try to fetch and hope. # A better method would be to parse the HTML for the actual GIF URL. @@ -195,68 +246,99 @@ class CaptionCog(commands.Cog, name="Caption"): # If not, this will likely fail or download HTML. # A quick hack for some tenor URLs: replace .com/view/ with .com/download/ and hope it gives a direct gif if "/view/" in url: - potential_gif_url = url.replace("/view/", "/download/") # This is a guess + potential_gif_url = url.replace( + "/view/", "/download/" + ) # This is a guess # It's better to inspect the page content for the actual media URL # For now, we'll try the original URL. - pass # Keep original URL for now. - + pass # Keep original URL for now. + # Attempt to get the GIF from Tenor page (very basic) if not image_bytes: page_content = response.text import re + # Look for a src attribute ending in .gif within an img tag - match = re.search(r']+src="([^"]+\.gif)"[^>]*>', page_content) + match = re.search( + r']+src="([^"]+\.gif)"[^>]*>', page_content + ) if match: gif_url_from_page = match.group(1) - if not gif_url_from_page.startswith("http"): # handle relative URLs if any + if not gif_url_from_page.startswith( + "http" + ): # handle relative URLs if any from urllib.parse import urljoin + gif_url_from_page = urljoin(url, gif_url_from_page) - + response = requests.get(gif_url_from_page, timeout=10) response.raise_for_status() - if "gif" in response.headers.get("Content-Type", "").lower(): + if ( + "gif" + in response.headers.get("Content-Type", "").lower() + ): image_bytes = response.content - else: # Fallback if no img tag found, try to find a direct media link for tenor + else: # Fallback if no img tag found, try to find a direct media link for tenor # Tenor often uses a specific div for the main GIF content # Example:
# Or sometimes a video tag with a .mp4 that could be converted or a .gif version available # This part is complex without a dedicated Tenor API key and library. # For now, if the initial fetch wasn't a GIF, we might fail here for Tenor pages. - await interaction.followup.send("Could not automatically extract GIF from Tenor URL. Please try a direct GIF link.", ephemeral=True) + await interaction.followup.send( + "Could not automatically extract GIF from Tenor URL. Please try a direct GIF link.", + ephemeral=True, + ) return - - if not image_bytes: # If after all attempts, image_bytes is still None - await interaction.followup.send(f"Failed to download or identify GIF from URL: {url}. Content-Type: {content_type}", ephemeral=True) - return + if not image_bytes: # If after all attempts, image_bytes is still None + await interaction.followup.send( + f"Failed to download or identify GIF from URL: {url}. Content-Type: {content_type}", + ephemeral=True, + ) + return except requests.exceptions.RequestException as e: - await interaction.followup.send(f"Failed to download GIF from URL: {e}", ephemeral=True) + await interaction.followup.send( + f"Failed to download GIF from URL: {e}", ephemeral=True + ) return except Exception as e: - await interaction.followup.send(f"An error occurred while processing the URL: {e}", ephemeral=True) + await interaction.followup.send( + f"An error occurred while processing the URL: {e}", ephemeral=True + ) return elif attachment: - if not attachment.filename.lower().endswith(".gif") or "image/gif" not in attachment.content_type: - await interaction.followup.send("The attached file must be a GIF.", ephemeral=True) + if ( + not attachment.filename.lower().endswith(".gif") + or "image/gif" not in attachment.content_type + ): + await interaction.followup.send( + "The attached file must be a GIF.", ephemeral=True + ) return try: image_bytes = await attachment.read() filename = f"captioned_{attachment.filename}" except Exception as e: - await interaction.followup.send(f"Failed to read attached GIF: {e}", ephemeral=True) + await interaction.followup.send( + f"Failed to read attached GIF: {e}", ephemeral=True + ) return - + if not image_bytes: await interaction.followup.send("Could not load GIF data.", ephemeral=True) return # Process the GIF try: - captioned_gif_bytes = await self.bot.loop.run_in_executor(None, self._add_text_to_gif, image_bytes, caption) - except Exception as e: # Catch errors from the executor task - await interaction.followup.send(f"An error occurred during GIF processing: {e}", ephemeral=True) + captioned_gif_bytes = await self.bot.loop.run_in_executor( + None, self._add_text_to_gif, image_bytes, caption + ) + except Exception as e: # Catch errors from the executor task + await interaction.followup.send( + f"An error occurred during GIF processing: {e}", ephemeral=True + ) print(f"Error during run_in_executor for _add_text_to_gif: {e}") return @@ -264,7 +346,10 @@ class CaptionCog(commands.Cog, name="Caption"): discord_file = File(fp=captioned_gif_bytes, filename=filename) await interaction.followup.send(file=discord_file) else: - await interaction.followup.send("Failed to caption the GIF. Check bot logs for details.", ephemeral=True) + await interaction.followup.send( + "Failed to caption the GIF. Check bot logs for details.", ephemeral=True + ) + async def setup(bot): await bot.add_cog(CaptionCog(bot)) diff --git a/cogs/command_debug_cog.py b/cogs/command_debug_cog.py index d495917..017ffb1 100644 --- a/cogs/command_debug_cog.py +++ b/cogs/command_debug_cog.py @@ -4,6 +4,7 @@ from discord import app_commands import inspect import json + class CommandDebugCog(commands.Cog): def __init__(self, bot): self.bot = bot @@ -14,38 +15,40 @@ class CommandDebugCog(commands.Cog): async def check_command(self, ctx, command_name: str = "webdrivertorso"): """Check details of a specific slash command""" await ctx.send(f"Checking details for slash command: {command_name}") - + # Find the command in the command tree command = None for cmd in self.bot.tree.get_commands(): if cmd.name == command_name: command = cmd break - + if not command: await ctx.send(f"Command '{command_name}' not found in the command tree.") return - + # Get basic command info await ctx.send(f"Command found: {command.name}") await ctx.send(f"Description: {command.description}") await ctx.send(f"Parameter count: {len(command.parameters)}") - + # Get parameter details for i, param in enumerate(command.parameters): param_info = f"Parameter {i+1}: {param.name}" param_info += f"\n Type: {type(param.type).__name__}" param_info += f"\n Required: {param.required}" - + # Check for choices - if hasattr(param, 'choices') and param.choices: + if hasattr(param, "choices") and param.choices: choices = [f"{c.name} ({c.value})" for c in param.choices] param_info += f"\n Choices: {', '.join(choices)}" - + # Check for tts_provider specifically if param.name == "tts_provider": - param_info += "\n THIS IS THE TTS PROVIDER PARAMETER WE'RE LOOKING FOR!" - + param_info += ( + "\n THIS IS THE TTS PROVIDER PARAMETER WE'RE LOOKING FOR!" + ) + # Get the actual implementation cog_instance = None for cog in self.bot.cogs.values(): @@ -55,45 +58,60 @@ class CommandDebugCog(commands.Cog): break if cog_instance: break - + if cog_instance: param_info += f"\n Found in cog: {cog_instance.__class__.__name__}" - + # Try to get the actual method method = None - for name, method_obj in inspect.getmembers(cog_instance, predicate=inspect.ismethod): - if hasattr(method_obj, "callback") and getattr(method_obj, "callback", None) == command: + for name, method_obj in inspect.getmembers( + cog_instance, predicate=inspect.ismethod + ): + if ( + hasattr(method_obj, "callback") + and getattr(method_obj, "callback", None) == command + ): method = method_obj break - elif hasattr(method_obj, "__name__") and method_obj.__name__ == f"{command_name}_slash": + elif ( + hasattr(method_obj, "__name__") + and method_obj.__name__ == f"{command_name}_slash" + ): method = method_obj break - + if method: param_info += f"\n Method: {method.__name__}" param_info += f"\n Signature: {str(inspect.signature(method))}" - + await ctx.send(param_info) - + # Check for the actual implementation in the cogs await ctx.send("Checking implementation in cogs...") for cog_name, cog in self.bot.cogs.items(): for cmd in cog.get_app_commands(): if cmd.name == command_name: await ctx.send(f"Command implemented in cog: {cog_name}") - + # Try to get the method - for name, method in inspect.getmembers(cog, predicate=inspect.ismethod): + for name, method in inspect.getmembers( + cog, predicate=inspect.ismethod + ): if name.startswith(command_name) or name.endswith("_slash"): await ctx.send(f"Possible implementing method: {name}") sig = inspect.signature(method) await ctx.send(f"Method signature: {sig}") - + # Check if tts_provider is in the parameters if "tts_provider" in [p for p in sig.parameters]: - await ctx.send("✅ tts_provider parameter found in method signature!") + await ctx.send( + "✅ tts_provider parameter found in method signature!" + ) else: - await ctx.send("❌ tts_provider parameter NOT found in method signature!") + await ctx.send( + "❌ tts_provider parameter NOT found in method signature!" + ) + async def setup(bot: commands.Bot): print("Loading CommandDebugCog...") diff --git a/cogs/command_fix_cog.py b/cogs/command_fix_cog.py index 0bf7e23..35d86f7 100644 --- a/cogs/command_fix_cog.py +++ b/cogs/command_fix_cog.py @@ -3,6 +3,7 @@ from discord.ext import commands from discord import app_commands import inspect + class CommandFixCog(commands.Cog): def __init__(self, bot): self.bot = bot @@ -13,52 +14,54 @@ class CommandFixCog(commands.Cog): async def fix_command(self, ctx): """Attempt to fix the webdrivertorso command at runtime""" await ctx.send("Attempting to fix the webdrivertorso command...") - + # Find the WebdriverTorsoCog webdriver_cog = None for cog_name, cog in self.bot.cogs.items(): if cog_name == "WebdriverTorsoCog": webdriver_cog = cog break - + if not webdriver_cog: await ctx.send("❌ WebdriverTorsoCog not found!") return - + await ctx.send("✅ Found WebdriverTorsoCog") - + # Find the slash command slash_command = None for cmd in self.bot.tree.get_commands(): if cmd.name == "webdrivertorso": slash_command = cmd break - + if not slash_command: await ctx.send("❌ webdrivertorso slash command not found!") return - - await ctx.send(f"✅ Found webdrivertorso slash command with {len(slash_command.parameters)} parameters") - + + await ctx.send( + f"✅ Found webdrivertorso slash command with {len(slash_command.parameters)} parameters" + ) + # Check if tts_provider is in the parameters tts_provider_param = None for param in slash_command.parameters: if param.name == "tts_provider": tts_provider_param = param break - + if tts_provider_param: await ctx.send(f"✅ tts_provider parameter already exists in the command") - + # Check if it has choices - if hasattr(tts_provider_param, 'choices') and tts_provider_param.choices: + if hasattr(tts_provider_param, "choices") and tts_provider_param.choices: choices = [f"{c.name} ({c.value})" for c in tts_provider_param.choices] await ctx.send(f"✅ tts_provider has choices: {', '.join(choices)}") else: await ctx.send("❌ tts_provider parameter has no choices!") else: await ctx.send("❌ tts_provider parameter not found in the command!") - + # Try to force a sync await ctx.send("Forcing a command sync...") try: @@ -66,23 +69,28 @@ class CommandFixCog(commands.Cog): await ctx.send(f"✅ Synced {len(synced)} command(s)") except Exception as e: await ctx.send(f"❌ Failed to sync commands: {str(e)}") - + # Create a new command as a workaround await ctx.send("Creating a new ttsprovider command as a workaround...") - + # Check if TTSProviderCog is loaded tts_provider_cog = None for cog_name, cog in self.bot.cogs.items(): if cog_name == "TTSProviderCog": tts_provider_cog = cog break - + if tts_provider_cog: await ctx.send("✅ TTSProviderCog is already loaded") else: - await ctx.send("❌ TTSProviderCog not loaded. Please load it with !load tts_provider_cog") - - await ctx.send("Fix attempt completed. Please check if the ttsprovider command is available.") + await ctx.send( + "❌ TTSProviderCog not loaded. Please load it with !load tts_provider_cog" + ) + + await ctx.send( + "Fix attempt completed. Please check if the ttsprovider command is available." + ) + async def setup(bot: commands.Bot): print("Loading CommandFixCog...") diff --git a/cogs/counting_cog.py b/cogs/counting_cog.py index 7b05019..5436b3f 100644 --- a/cogs/counting_cog.py +++ b/cogs/counting_cog.py @@ -12,61 +12,66 @@ import settings_manager # Set up logging log = logging.getLogger(__name__) + class CountingCog(commands.Cog): """A cog that manages a counting channel where users can only post sequential numbers.""" - + def __init__(self, bot): self.bot = bot - self.counting_channels = {} # Cache for counting channels {guild_id: channel_id} - self.current_counts = {} # Cache for current counts {guild_id: current_number} - self.last_user = {} # Cache to track the last user who sent a number {guild_id: user_id} - + self.counting_channels = ( + {} + ) # Cache for counting channels {guild_id: channel_id} + self.current_counts = {} # Cache for current counts {guild_id: current_number} + self.last_user = ( + {} + ) # Cache to track the last user who sent a number {guild_id: user_id} + # Register commands self.counting_group = app_commands.Group( name="counting", description="Commands for managing the counting channel", - guild_only=True + guild_only=True, ) self.register_commands() - + log.info("CountingCog initialized") - + def register_commands(self): """Register all commands for this cog""" - + # Set counting channel command set_channel_command = app_commands.Command( name="setchannel", description="Set the current channel as the counting channel", callback=self.counting_set_channel_callback, - parent=self.counting_group + parent=self.counting_group, ) self.counting_group.add_command(set_channel_command) - + # Disable counting command disable_command = app_commands.Command( name="disable", description="Disable the counting feature for this server", callback=self.counting_disable_callback, - parent=self.counting_group + parent=self.counting_group, ) self.counting_group.add_command(disable_command) - + # Reset count command reset_command = app_commands.Command( name="reset", description="Reset the count to 0", callback=self.counting_reset_callback, - parent=self.counting_group + parent=self.counting_group, ) self.counting_group.add_command(reset_command) - + # Get current count command status_command = app_commands.Command( name="status", description="Show the current count and counting channel", callback=self.counting_status_callback, - parent=self.counting_group + parent=self.counting_group, ) self.counting_group.add_command(status_command) @@ -75,73 +80,92 @@ class CountingCog(commands.Cog): name="setcount", description="Manually set the current count (Admin only)", callback=self.counting_set_count_callback, - parent=self.counting_group + parent=self.counting_group, ) self.counting_group.add_command(set_count_command) - + async def cog_load(self): """Called when the cog is loaded.""" log.info("Loading CountingCog") # Add the command group to the bot self.bot.tree.add_command(self.counting_group) - + async def cog_unload(self): """Called when the cog is unloaded.""" log.info("Unloading CountingCog") # Remove the command group from the bot - self.bot.tree.remove_command(self.counting_group.name, type=self.counting_group.type) - + self.bot.tree.remove_command( + self.counting_group.name, type=self.counting_group.type + ) + async def load_counting_data(self, guild_id: int): """Load counting channel and current count from database.""" - channel_id_str = await settings_manager.get_setting(guild_id, 'counting_channel_id') - current_count_str = await settings_manager.get_setting(guild_id, 'counting_current_number', default='0') - + channel_id_str = await settings_manager.get_setting( + guild_id, "counting_channel_id" + ) + current_count_str = await settings_manager.get_setting( + guild_id, "counting_current_number", default="0" + ) + if channel_id_str: self.counting_channels[guild_id] = int(channel_id_str) self.current_counts[guild_id] = int(current_count_str) - last_user_str = await settings_manager.get_setting(guild_id, 'counting_last_user', default=None) + last_user_str = await settings_manager.get_setting( + guild_id, "counting_last_user", default=None + ) if last_user_str: self.last_user[guild_id] = int(last_user_str) return True return False - + # Command callbacks async def counting_set_channel_callback(self, interaction: discord.Interaction): """Set the current channel as the counting channel.""" # Check if user has manage channels permission if not interaction.user.guild_permissions.manage_channels: - await interaction.response.send_message("❌ You need the 'Manage Channels' permission to use this command.", ephemeral=True) + await interaction.response.send_message( + "❌ You need the 'Manage Channels' permission to use this command.", + ephemeral=True, + ) return - + guild_id = interaction.guild.id channel_id = interaction.channel.id - + # Save to database - await settings_manager.set_setting(guild_id, 'counting_channel_id', str(channel_id)) - await settings_manager.set_setting(guild_id, 'counting_current_number', '0') - + await settings_manager.set_setting( + guild_id, "counting_channel_id", str(channel_id) + ) + await settings_manager.set_setting(guild_id, "counting_current_number", "0") + # Update cache self.counting_channels[guild_id] = channel_id self.current_counts[guild_id] = 0 if guild_id in self.last_user: del self.last_user[guild_id] - - await interaction.response.send_message(f"✅ This channel has been set as the counting channel! The count starts at 1.", ephemeral=False) - + + await interaction.response.send_message( + f"✅ This channel has been set as the counting channel! The count starts at 1.", + ephemeral=False, + ) + async def counting_disable_callback(self, interaction: discord.Interaction): """Disable the counting feature for this server.""" # Check if user has manage channels permission if not interaction.user.guild_permissions.manage_channels: - await interaction.response.send_message("❌ You need the 'Manage Channels' permission to use this command.", ephemeral=True) + await interaction.response.send_message( + "❌ You need the 'Manage Channels' permission to use this command.", + ephemeral=True, + ) return - + guild_id = interaction.guild.id - + # Remove from database - await settings_manager.set_setting(guild_id, 'counting_channel_id', None) - await settings_manager.set_setting(guild_id, 'counting_current_number', None) - await settings_manager.set_setting(guild_id, 'counting_last_user', None) - + await settings_manager.set_setting(guild_id, "counting_channel_id", None) + await settings_manager.set_setting(guild_id, "counting_current_number", None) + await settings_manager.set_setting(guild_id, "counting_last_user", None) + # Update cache if guild_id in self.counting_channels: del self.counting_channels[guild_id] @@ -149,65 +173,89 @@ class CountingCog(commands.Cog): del self.current_counts[guild_id] if guild_id in self.last_user: del self.last_user[guild_id] - - await interaction.response.send_message("✅ Counting feature has been disabled for this server.", ephemeral=True) - + + await interaction.response.send_message( + "✅ Counting feature has been disabled for this server.", ephemeral=True + ) + async def counting_reset_callback(self, interaction: discord.Interaction): """Reset the count to 0.""" # Check if user has manage channels permission if not interaction.user.guild_permissions.manage_channels: - await interaction.response.send_message("❌ You need the 'Manage Channels' permission to use this command.", ephemeral=True) + await interaction.response.send_message( + "❌ You need the 'Manage Channels' permission to use this command.", + ephemeral=True, + ) return - + guild_id = interaction.guild.id - + # Check if counting is enabled if guild_id not in self.counting_channels: await self.load_counting_data(guild_id) if guild_id not in self.counting_channels: - await interaction.response.send_message("❌ Counting is not enabled for this server. Use `/counting setchannel` first.", ephemeral=True) + await interaction.response.send_message( + "❌ Counting is not enabled for this server. Use `/counting setchannel` first.", + ephemeral=True, + ) return - + # Reset count in database - await settings_manager.set_setting(guild_id, 'counting_current_number', '0') - + await settings_manager.set_setting(guild_id, "counting_current_number", "0") + # Update cache self.current_counts[guild_id] = 0 if guild_id in self.last_user: del self.last_user[guild_id] - - await interaction.response.send_message("✅ The count has been reset to 0. The next number is 1.", ephemeral=False) - + + await interaction.response.send_message( + "✅ The count has been reset to 0. The next number is 1.", ephemeral=False + ) + async def counting_status_callback(self, interaction: discord.Interaction): """Show the current count and counting channel.""" guild_id = interaction.guild.id - + # Check if counting is enabled if guild_id not in self.counting_channels: await self.load_counting_data(guild_id) if guild_id not in self.counting_channels: - await interaction.response.send_message("❌ Counting is not enabled for this server. Use `/counting setchannel` first.", ephemeral=True) + await interaction.response.send_message( + "❌ Counting is not enabled for this server. Use `/counting setchannel` first.", + ephemeral=True, + ) return - + channel_id = self.counting_channels[guild_id] current_count = self.current_counts[guild_id] channel = self.bot.get_channel(channel_id) - + if not channel: - await interaction.response.send_message("❌ The counting channel could not be found. It may have been deleted.", ephemeral=True) + await interaction.response.send_message( + "❌ The counting channel could not be found. It may have been deleted.", + ephemeral=True, + ) return - - await interaction.response.send_message(f"📊 **Counting Status**\n" - f"Channel: {channel.mention}\n" - f"Current count: {current_count}\n" - f"Next number: {current_count + 1}", ephemeral=False) + + await interaction.response.send_message( + f"📊 **Counting Status**\n" + f"Channel: {channel.mention}\n" + f"Current count: {current_count}\n" + f"Next number: {current_count + 1}", + ephemeral=False, + ) @app_commands.describe(number="The number to set the current count to.") - async def counting_set_count_callback(self, interaction: discord.Interaction, number: int): + async def counting_set_count_callback( + self, interaction: discord.Interaction, number: int + ): """Manually set the current count.""" # Check if user has administrator permission if not interaction.user.guild_permissions.administrator: - await interaction.response.send_message("❌ You need Administrator permissions to use this command.", ephemeral=True) + await interaction.response.send_message( + "❌ You need Administrator permissions to use this command.", + ephemeral=True, + ) return guild_id = interaction.guild.id @@ -216,101 +264,124 @@ class CountingCog(commands.Cog): if guild_id not in self.counting_channels: await self.load_counting_data(guild_id) if guild_id not in self.counting_channels: - await interaction.response.send_message("❌ Counting is not enabled for this server. Use `/counting setchannel` first.", ephemeral=True) + await interaction.response.send_message( + "❌ Counting is not enabled for this server. Use `/counting setchannel` first.", + ephemeral=True, + ) return - + if number < 0: - await interaction.response.send_message("❌ The count cannot be a negative number.", ephemeral=True) + await interaction.response.send_message( + "❌ The count cannot be a negative number.", ephemeral=True + ) return # Update count in database - await settings_manager.set_setting(guild_id, 'counting_current_number', str(number)) - + await settings_manager.set_setting( + guild_id, "counting_current_number", str(number) + ) + # Update cache self.current_counts[guild_id] = number - + # Reset last user as the count is manually set if guild_id in self.last_user: del self.last_user[guild_id] - await settings_manager.set_setting(guild_id, 'counting_last_user', None) # Clear last user in DB + await settings_manager.set_setting( + guild_id, "counting_last_user", None + ) # Clear last user in DB + + await interaction.response.send_message( + f"✅ The count has been manually set to {number}. The next number is {number + 1}.", + ephemeral=False, + ) - await interaction.response.send_message(f"✅ The count has been manually set to {number}. The next number is {number + 1}.", ephemeral=False) - @commands.Cog.listener() async def on_message(self, message: discord.Message): """Check if message is in counting channel and validate the number.""" # Ignore bot messages if message.author.bot: return - + # Ignore DMs if not message.guild: return - + guild_id = message.guild.id - + # Check if this is a counting channel if guild_id not in self.counting_channels: # Try to load from database channel_exists = await self.load_counting_data(guild_id) if not channel_exists: return - + # Check if this message is in the counting channel if message.channel.id != self.counting_channels[guild_id]: return - + # Get current count current_count = self.current_counts[guild_id] expected_number = current_count + 1 - + # Check if the message is just the next number # Strip whitespace and check if it's a number content = message.content.strip() - + # Use regex to check if the message contains only the number (allowing for whitespace) - if not re.match(r'^\s*' + str(expected_number) + r'\s*$', content): + if not re.match(r"^\s*" + str(expected_number) + r"\s*$", content): # Not the expected number, delete the message try: await message.delete() # Optionally send a DM to the user explaining why their message was deleted try: - await message.author.send(f"Your message in the counting channel was deleted because it wasn't the next number in the sequence. The next number should be {expected_number}.") + await message.author.send( + f"Your message in the counting channel was deleted because it wasn't the next number in the sequence. The next number should be {expected_number}." + ) except discord.Forbidden: # Can't send DM, ignore pass except discord.Forbidden: # Bot doesn't have permission to delete messages - log.warning(f"Cannot delete message in counting channel {message.channel.id} - missing permissions") + log.warning( + f"Cannot delete message in counting channel {message.channel.id} - missing permissions" + ) except Exception as e: log.error(f"Error deleting message in counting channel: {e}") return - + # Check if the same user is posting twice in a row if guild_id in self.last_user and self.last_user[guild_id] == message.author.id: try: await message.delete() try: - await message.author.send(f"Your message in the counting channel was deleted because you cannot post two numbers in a row. Let someone else continue the count.") + await message.author.send( + f"Your message in the counting channel was deleted because you cannot post two numbers in a row. Let someone else continue the count." + ) except discord.Forbidden: pass except Exception as e: log.error(f"Error deleting message from same user: {e}") return - + # Valid number, update the count self.current_counts[guild_id] = expected_number self.last_user[guild_id] = message.author.id - + # Save to database - await settings_manager.set_setting(guild_id, 'counting_current_number', str(expected_number)) - await settings_manager.set_setting(guild_id, 'counting_last_user', str(message.author.id)) - + await settings_manager.set_setting( + guild_id, "counting_current_number", str(expected_number) + ) + await settings_manager.set_setting( + guild_id, "counting_last_user", str(message.author.id) + ) + @commands.Cog.listener() async def on_ready(self): """Called when the bot is ready.""" log.info("CountingCog is ready") + async def setup(bot: commands.Bot): """Set up the CountingCog with the bot.""" await bot.add_cog(CountingCog(bot)) diff --git a/cogs/dictionary_cog.py b/cogs/dictionary_cog.py index 106ec28..f88d8f5 100644 --- a/cogs/dictionary_cog.py +++ b/cogs/dictionary_cog.py @@ -7,6 +7,7 @@ from typing import Optional, List, Dict, Any log = logging.getLogger(__name__) + class DictionaryCog(commands.Cog, name="Dictionary"): """Cog for word definition and dictionary lookup commands""" @@ -26,95 +27,91 @@ class DictionaryCog(commands.Cog, name="Dictionary"): elif response.status == 404: return None else: - log.error(f"Dictionary API returned status {response.status} for word '{word}'") + log.error( + f"Dictionary API returned status {response.status} for word '{word}'" + ) return None except Exception as e: log.error(f"Error fetching definition for '{word}': {e}") return None - def _format_definition_embed(self, word: str, data: Dict[str, Any]) -> discord.Embed: + def _format_definition_embed( + self, word: str, data: Dict[str, Any] + ) -> discord.Embed: """Format the dictionary data into a Discord embed.""" embed = discord.Embed( title=f"📖 Definition: {data.get('word', word).title()}", - color=discord.Color.blue() + color=discord.Color.blue(), ) # Add phonetic pronunciation if available - phonetics = data.get('phonetics', []) + phonetics = data.get("phonetics", []) if phonetics: for phonetic in phonetics: - if phonetic.get('text'): + if phonetic.get("text"): embed.add_field( - name="🔊 Pronunciation", - value=phonetic['text'], - inline=True + name="🔊 Pronunciation", value=phonetic["text"], inline=True ) break # Add meanings - meanings = data.get('meanings', []) + meanings = data.get("meanings", []) definition_count = 0 - + for meaning in meanings[:3]: # Limit to first 3 parts of speech - part_of_speech = meaning.get('partOfSpeech', 'Unknown') - definitions = meaning.get('definitions', []) - + part_of_speech = meaning.get("partOfSpeech", "Unknown") + definitions = meaning.get("definitions", []) + if definitions: - definition_text = definitions[0].get('definition', 'No definition available') - example = definitions[0].get('example') - + definition_text = definitions[0].get( + "definition", "No definition available" + ) + example = definitions[0].get("example") + field_value = f"**{definition_text}**" if example: field_value += f"\n*Example: {example}*" - + embed.add_field( - name=f"📝 {part_of_speech.title()}", - value=field_value, - inline=False + name=f"📝 {part_of_speech.title()}", value=field_value, inline=False ) definition_count += 1 # Add etymology if available - etymology = data.get('etymology') + etymology = data.get("etymology") if etymology: embed.add_field( name="📚 Etymology", value=etymology[:200] + "..." if len(etymology) > 200 else etymology, - inline=False + inline=False, ) # Add source attribution embed.set_footer(text="Powered by Free Dictionary API") - + return embed async def _define_logic(self, word: str) -> Dict[str, Any]: """Core logic for the define command.""" if not word or len(word.strip()) == 0: - return { - "error": "Please provide a word to define.", - "embed": None - } + return {"error": "Please provide a word to define.", "embed": None} # Clean the word input clean_word = word.strip().lower() - + # Fetch definition definition_data = await self._fetch_definition(clean_word) - + if definition_data is None: return { "error": f"❌ Sorry, I couldn't find a definition for '{word}'. Please check the spelling and try again.", - "embed": None + "embed": None, } # Create embed embed = self._format_definition_embed(word, definition_data) - - return { - "error": None, - "embed": embed - } + + return {"error": None, "embed": embed} # --- Prefix Command --- @commands.command(name="define", aliases=["def", "definition"]) @@ -125,7 +122,7 @@ class DictionaryCog(commands.Cog, name="Dictionary"): return result = await self._define_logic(word) - + if result["error"]: await ctx.reply(result["error"]) else: @@ -137,13 +134,14 @@ class DictionaryCog(commands.Cog, name="Dictionary"): async def define_slash(self, interaction: discord.Interaction, word: str): """Slash command for word definition lookup.""" await interaction.response.defer() - + result = await self._define_logic(word) - + if result["error"]: await interaction.followup.send(result["error"]) else: await interaction.followup.send(embed=result["embed"]) + async def setup(bot: commands.Bot): await bot.add_cog(DictionaryCog(bot)) diff --git a/cogs/discord_sync_cog.py b/cogs/discord_sync_cog.py index c43b8fc..2802bf1 100644 --- a/cogs/discord_sync_cog.py +++ b/cogs/discord_sync_cog.py @@ -8,14 +8,19 @@ from typing import Optional, List, Dict, Any # Try to import the Discord sync API try: from discord_bot_sync_api import ( - user_conversations, save_discord_conversation, - load_conversations, SyncedConversation, SyncedMessage + user_conversations, + save_discord_conversation, + load_conversations, + SyncedConversation, + SyncedMessage, ) + SYNC_API_AVAILABLE = True except ImportError: print("Discord sync API not available in sync cog. Sync features will be disabled.") SYNC_API_AVAILABLE = False + class DiscordSyncCog(commands.Cog): def __init__(self, bot): self.bot = bot @@ -29,7 +34,9 @@ class DiscordSyncCog(commands.Cog): async def sync_status(self, ctx: commands.Context): """Check the status of the Discord sync API""" if not SYNC_API_AVAILABLE: - await ctx.reply("❌ Discord sync API is not available. Please make sure the required dependencies are installed.") + await ctx.reply( + "❌ Discord sync API is not available. Please make sure the required dependencies are installed." + ) return # Count total synced conversations @@ -43,40 +50,38 @@ class DiscordSyncCog(commands.Cog): embed = discord.Embed( title="Discord Sync Status", description="Status of the Discord sync API for Flutter app integration", - color=discord.Color.green() + color=discord.Color.green(), ) - embed.add_field( - name="API Status", - value="✅ Running", - inline=False - ) + embed.add_field(name="API Status", value="✅ Running", inline=False) embed.add_field( name="Total Synced Conversations", value=f"{total_conversations} conversations from {total_users} users", - inline=False + inline=False, ) embed.add_field( name="Your Synced Conversations", value=f"{user_conv_count} conversations", - inline=False + inline=False, ) embed.add_field( name="API Endpoint", value="https://slipstreamm.dev/discordapi", - inline=False + inline=False, ) embed.add_field( name="Setup Instructions", value="Use `!synchelp` for setup instructions", - inline=False + inline=False, ) - embed.set_footer(text=f"Last updated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + embed.set_footer( + text=f"Last updated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" + ) await ctx.reply(embed=embed) @@ -86,7 +91,7 @@ class DiscordSyncCog(commands.Cog): embed = discord.Embed( title="Discord Sync Integration Help", description="How to set up the Discord sync integration with the Flutter app", - color=discord.Color.blue() + color=discord.Color.blue(), ) embed.add_field( @@ -98,7 +103,7 @@ class DiscordSyncCog(commands.Cog): "4. Add a redirect URL: `openroutergui://auth`\n" "5. Copy the 'Client ID' for the Flutter app" ), - inline=False + inline=False, ) embed.add_field( @@ -110,7 +115,7 @@ class DiscordSyncCog(commands.Cog): "4. Enter the Bot API URL: `https://slipstreamm.dev/discordapi`\n" "5. Click 'Save'" ), - inline=False + inline=False, ) embed.add_field( @@ -121,7 +126,7 @@ class DiscordSyncCog(commands.Cog): "3. Use the 'Sync Conversations' button to sync conversations\n" "4. Use the 'Import from Discord' button to import conversations" ), - inline=False + inline=False, ) embed.add_field( @@ -132,7 +137,7 @@ class DiscordSyncCog(commands.Cog): "• Verify that the redirect URL is properly configured\n" "• Use `!syncstatus` to check the API status" ), - inline=False + inline=False, ) await ctx.reply(embed=embed) @@ -141,7 +146,9 @@ class DiscordSyncCog(commands.Cog): async def sync_clear(self, ctx: commands.Context): """Clear your synced conversations""" if not SYNC_API_AVAILABLE: - await ctx.reply("❌ Discord sync API is not available. Please make sure the required dependencies are installed.") + await ctx.reply( + "❌ Discord sync API is not available. Please make sure the required dependencies are installed." + ) return user_id = str(ctx.author.id) @@ -157,6 +164,7 @@ class DiscordSyncCog(commands.Cog): # Save the updated conversations from discord_bot_sync_api import save_conversations + save_conversations() await ctx.reply(f"✅ Cleared {conv_count} synced conversations.") @@ -165,7 +173,9 @@ class DiscordSyncCog(commands.Cog): async def sync_list(self, ctx: commands.Context): """List your synced conversations""" if not SYNC_API_AVAILABLE: - await ctx.reply("❌ Discord sync API is not available. Please make sure the required dependencies are installed.") + await ctx.reply( + "❌ Discord sync API is not available. Please make sure the required dependencies are installed." + ) return user_id = str(ctx.author.id) @@ -177,7 +187,7 @@ class DiscordSyncCog(commands.Cog): embed = discord.Embed( title="Your Synced Conversations", description=f"You have {len(user_conversations[user_id])} synced conversations", - color=discord.Color.blue() + color=discord.Color.blue(), ) # Add each conversation to the embed @@ -197,7 +207,7 @@ class DiscordSyncCog(commands.Cog): f"Messages: {len(conv.messages)}\n" f"Preview: {preview[:100]}..." ), - inline=False + inline=False, ) # Discord embeds have a limit of 25 fields @@ -205,11 +215,12 @@ class DiscordSyncCog(commands.Cog): embed.add_field( name="Note", value=f"Showing 10/{len(user_conversations[user_id])} conversations. Use the Flutter app to view all.", - inline=False + inline=False, ) break await ctx.reply(embed=embed) + async def setup(bot): await bot.add_cog(DiscordSyncCog(bot)) diff --git a/cogs/economy/database.py b/cogs/economy/database.py index 348b97a..ba4abdc 100644 --- a/cogs/economy/database.py +++ b/cogs/economy/database.py @@ -1,5 +1,5 @@ import asyncpg -import redis.asyncio as redis # Use asyncio version of redis library +import redis.asyncio as redis # Use asyncio version of redis library import os import datetime import logging @@ -23,12 +23,13 @@ CACHE_COOLDOWN_KEY = "economy:cooldown:{user_id}:{command_name}" CACHE_LEADERBOARD_KEY = "economy:leaderboard:{count}" # --- Cache Durations (in seconds) --- -CACHE_DEFAULT_TTL = 60 * 5 # 5 minutes for most things -CACHE_ITEM_TTL = 60 * 60 * 24 # 24 hours for item details (rarely change) -CACHE_LEADERBOARD_TTL = 60 * 15 # 15 minutes for leaderboard +CACHE_DEFAULT_TTL = 60 * 5 # 5 minutes for most things +CACHE_ITEM_TTL = 60 * 60 * 24 # 24 hours for item details (rarely change) +CACHE_LEADERBOARD_TTL = 60 * 15 # 15 minutes for leaderboard # --- Database Setup --- + async def init_db(): """Initializes the PostgreSQL connection pool and Redis client.""" global pool, redis_client @@ -42,34 +43,46 @@ async def init_db(): db_user = os.environ.get("POSTGRES_USER") db_password = os.environ.get("POSTGRES_PASSWORD") db_name = os.environ.get("POSTGRES_DB") - db_port = os.environ.get("POSTGRES_PORT", 5432) # Default PostgreSQL port + db_port = os.environ.get("POSTGRES_PORT", 5432) # Default PostgreSQL port if not all([db_user, db_password, db_name]): - log.error("Missing PostgreSQL environment variables (POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_DB)") - raise ConnectionError("Missing PostgreSQL credentials in environment variables.") + log.error( + "Missing PostgreSQL environment variables (POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_DB)" + ) + raise ConnectionError( + "Missing PostgreSQL credentials in environment variables." + ) - conn_string = f"postgresql://{db_user}:{db_password}@{db_host}:{db_port}/{db_name}" + conn_string = ( + f"postgresql://{db_user}:{db_password}@{db_host}:{db_port}/{db_name}" + ) pool = await asyncpg.create_pool(conn_string, min_size=1, max_size=10) if pool: - log.info(f"PostgreSQL connection pool established to {db_host}:{db_port}/{db_name}") - # Run table creation check (idempotent) - await _create_tables_if_not_exist(pool) + log.info( + f"PostgreSQL connection pool established to {db_host}:{db_port}/{db_name}" + ) + # Run table creation check (idempotent) + await _create_tables_if_not_exist(pool) else: - log.error("Failed to create PostgreSQL connection pool.") - raise ConnectionError("Failed to create PostgreSQL connection pool.") - + log.error("Failed to create PostgreSQL connection pool.") + raise ConnectionError("Failed to create PostgreSQL connection pool.") # --- Redis Setup --- redis_host = os.environ.get("REDIS_HOST", "localhost") redis_port = int(os.environ.get("REDIS_PORT", 6379)) - redis_client = redis.Redis(host=redis_host, port=redis_port, decode_responses=True) # decode_responses=True to get strings - await redis_client.ping() # Check connection + redis_client = redis.Redis( + host=redis_host, port=redis_port, decode_responses=True + ) # decode_responses=True to get strings + await redis_client.ping() # Check connection log.info(f"Redis client connected to {redis_host}:{redis_port}") except redis.exceptions.ConnectionError as e: - log.error(f"Failed to connect to Redis at {redis_host}:{redis_port}: {e}", exc_info=True) - redis_client = None # Ensure client is None if connection fails + log.error( + f"Failed to connect to Redis at {redis_host}:{redis_port}: {e}", + exc_info=True, + ) + redis_client = None # Ensure client is None if connection fails # Decide if this is fatal - for now, let it continue but caching will fail log.warning("Redis connection failed. Caching will be disabled.") except Exception as e: @@ -81,34 +94,40 @@ async def init_db(): if redis_client: await redis_client.close() redis_client = None - raise # Re-raise the exception to prevent cog loading if critical + raise # Re-raise the exception to prevent cog loading if critical + async def _create_tables_if_not_exist(db_pool: asyncpg.Pool): """Creates tables if they don't exist. Called internally by init_db.""" async with db_pool.acquire() as conn: async with conn.transaction(): # Create economy table - await conn.execute(""" + await conn.execute( + """ CREATE TABLE IF NOT EXISTS economy ( user_id BIGINT PRIMARY KEY, balance BIGINT NOT NULL DEFAULT 0 ) - """) + """ + ) log.debug("Checked/created 'economy' table in PostgreSQL.") # Create command_cooldowns table - await conn.execute(""" + await conn.execute( + """ CREATE TABLE IF NOT EXISTS command_cooldowns ( user_id BIGINT NOT NULL, command_name TEXT NOT NULL, last_used TIMESTAMP WITH TIME ZONE NOT NULL, PRIMARY KEY (user_id, command_name) ) - """) + """ + ) log.debug("Checked/created 'command_cooldowns' table in PostgreSQL.") # Create user_jobs table - await conn.execute(""" + await conn.execute( + """ CREATE TABLE IF NOT EXISTS user_jobs ( user_id BIGINT PRIMARY KEY, job_name TEXT, @@ -117,22 +136,26 @@ async def _create_tables_if_not_exist(db_pool: asyncpg.Pool): last_job_action TIMESTAMP WITH TIME ZONE, FOREIGN KEY (user_id) REFERENCES economy(user_id) ON DELETE CASCADE ) - """) + """ + ) log.debug("Checked/created 'user_jobs' table in PostgreSQL.") # Create items table - await conn.execute(""" + await conn.execute( + """ CREATE TABLE IF NOT EXISTS items ( item_key TEXT PRIMARY KEY, name TEXT NOT NULL, description TEXT, sell_price BIGINT NOT NULL DEFAULT 0 ) - """) + """ + ) log.debug("Checked/created 'items' table in PostgreSQL.") # Create user_inventory table - await conn.execute(""" + await conn.execute( + """ CREATE TABLE IF NOT EXISTS user_inventory ( user_id BIGINT NOT NULL, item_key TEXT NOT NULL, @@ -141,33 +164,40 @@ async def _create_tables_if_not_exist(db_pool: asyncpg.Pool): FOREIGN KEY (user_id) REFERENCES economy(user_id) ON DELETE CASCADE, FOREIGN KEY (item_key) REFERENCES items(item_key) ON DELETE CASCADE ) - """) + """ + ) log.debug("Checked/created 'user_inventory' table in PostgreSQL.") # --- Add some basic items --- initial_items = [ - ('raw_iron', 'Raw Iron Ore', 'Basic metal ore.', 5), - ('coal', 'Coal', 'A lump of fossil fuel.', 3), - ('shiny_gem', 'Shiny Gem', 'A pretty, potentially valuable gem.', 50), - ('common_fish', 'Common Fish', 'A standard fish.', 4), - ('rare_fish', 'Rare Fish', 'An uncommon fish.', 15), - ('treasure_chest', 'Treasure Chest', 'Might contain goodies!', 0), - ('iron_ingot', 'Iron Ingot', 'Refined iron, ready for crafting.', 12), - ('basic_tool', 'Basic Tool', 'A simple tool.', 25) + ("raw_iron", "Raw Iron Ore", "Basic metal ore.", 5), + ("coal", "Coal", "A lump of fossil fuel.", 3), + ("shiny_gem", "Shiny Gem", "A pretty, potentially valuable gem.", 50), + ("common_fish", "Common Fish", "A standard fish.", 4), + ("rare_fish", "Rare Fish", "An uncommon fish.", 15), + ("treasure_chest", "Treasure Chest", "Might contain goodies!", 0), + ("iron_ingot", "Iron Ingot", "Refined iron, ready for crafting.", 12), + ("basic_tool", "Basic Tool", "A simple tool.", 25), ] # Use ON CONFLICT DO NOTHING to avoid errors if items already exist - await conn.executemany(""" + await conn.executemany( + """ INSERT INTO items (item_key, name, description, sell_price) VALUES ($1, $2, $3, $4) ON CONFLICT (item_key) DO NOTHING - """, initial_items) + """, + initial_items, + ) log.debug("Ensured initial items exist in PostgreSQL.") + # --- Database Helper Functions --- + async def get_balance(user_id: int) -> int: """Gets the balance for a user, creating an entry if needed. Uses Redis cache.""" - if not pool: raise ConnectionError("Database pool not initialized.") + if not pool: + raise ConnectionError("Database pool not initialized.") cache_key = CACHE_BALANCE_KEY.format(user_id=user_id) # 1. Check Cache @@ -184,19 +214,29 @@ async def get_balance(user_id: int) -> int: log.debug(f"Cache miss for balance user_id: {user_id}") # 2. Query Database async with pool.acquire() as conn: - balance = await conn.fetchval("SELECT balance FROM economy WHERE user_id = $1", user_id) + balance = await conn.fetchval( + "SELECT balance FROM economy WHERE user_id = $1", user_id + ) if balance is None: # User doesn't exist, create entry try: - await conn.execute("INSERT INTO economy (user_id, balance) VALUES ($1, 0)", user_id) + await conn.execute( + "INSERT INTO economy (user_id, balance) VALUES ($1, 0)", user_id + ) log.info(f"Created new economy entry for user_id: {user_id}") balance = 0 except asyncpg.UniqueViolationError: # Race condition: another process inserted the user between SELECT and INSERT - log.warning(f"Race condition handled for user_id: {user_id} during balance fetch.") - balance = await conn.fetchval("SELECT balance FROM economy WHERE user_id = $1", user_id) - balance = balance if balance is not None else 0 # Ensure balance is 0 if somehow still None + log.warning( + f"Race condition handled for user_id: {user_id} during balance fetch." + ) + balance = await conn.fetchval( + "SELECT balance FROM economy WHERE user_id = $1", user_id + ) + balance = ( + balance if balance is not None else 0 + ) # Ensure balance is 0 if somehow still None # 3. Update Cache if redis_client: @@ -207,17 +247,25 @@ async def get_balance(user_id: int) -> int: return balance if balance is not None else 0 + async def update_balance(user_id: int, amount: int): """Updates a user's balance by adding the specified amount (can be negative). Invalidates cache.""" - if not pool: raise ConnectionError("Database pool not initialized.") + if not pool: + raise ConnectionError("Database pool not initialized.") cache_key = CACHE_BALANCE_KEY.format(user_id=user_id) - leaderboard_pattern = CACHE_LEADERBOARD_KEY.format(count='*') # Pattern to invalidate all leaderboard caches + leaderboard_pattern = CACHE_LEADERBOARD_KEY.format( + count="*" + ) # Pattern to invalidate all leaderboard caches async with pool.acquire() as conn: # Ensure user exists first (get_balance handles creation) await get_balance(user_id) # Use RETURNING to get the new balance efficiently, though not strictly needed here - await conn.execute("UPDATE economy SET balance = balance + $1 WHERE user_id = $2", amount, user_id) + await conn.execute( + "UPDATE economy SET balance = balance + $1 WHERE user_id = $2", + amount, + user_id, + ) log.debug(f"Updated balance for user_id {user_id} by {amount}.") # Invalidate Caches @@ -228,14 +276,22 @@ async def update_balance(user_id: int, amount: int): # Invalidate all leaderboard caches (since balances changed) async for key in redis_client.scan_iter(match=leaderboard_pattern): await redis_client.delete(key) - log.debug(f"Invalidated cache for balance user_id: {user_id} and leaderboards.") + log.debug( + f"Invalidated cache for balance user_id: {user_id} and leaderboards." + ) except Exception as e: - log.warning(f"Redis DELETE failed for balance/leaderboard invalidation (user {user_id}): {e}", exc_info=True) + log.warning( + f"Redis DELETE failed for balance/leaderboard invalidation (user {user_id}): {e}", + exc_info=True, + ) -async def check_cooldown(user_id: int, command_name: str) -> Optional[datetime.datetime]: +async def check_cooldown( + user_id: int, command_name: str +) -> Optional[datetime.datetime]: """Checks if a command is on cooldown. Uses Redis cache.""" - if not pool: raise ConnectionError("Database pool not initialized.") + if not pool: + raise ConnectionError("Database pool not initialized.") cache_key = CACHE_COOLDOWN_KEY.format(user_id=user_id, command_name=command_name) # 1. Check Cache @@ -243,22 +299,32 @@ async def check_cooldown(user_id: int, command_name: str) -> Optional[datetime.d try: cached_cooldown = await redis_client.get(cache_key) if cached_cooldown: - if cached_cooldown == "NULL": # Handle explicitly stored null case + if cached_cooldown == "NULL": # Handle explicitly stored null case return None try: # Timestamps stored in ISO format in cache last_used_dt = datetime.datetime.fromisoformat(cached_cooldown) # Ensure timezone aware (should be stored as UTC) if last_used_dt.tzinfo is None: - last_used_dt = last_used_dt.replace(tzinfo=datetime.timezone.utc) - log.debug(f"Cache hit for cooldown user {user_id}, cmd {command_name}") + last_used_dt = last_used_dt.replace( + tzinfo=datetime.timezone.utc + ) + log.debug( + f"Cache hit for cooldown user {user_id}, cmd {command_name}" + ) return last_used_dt except ValueError: - log.error(f"Could not parse cached timestamp '{cached_cooldown}' for user {user_id}, cmd {command_name}") - # Fall through to DB query if cache data is bad - elif cached_cooldown is not None: # Empty string means checked DB and no cooldown exists - log.debug(f"Cache hit (no cooldown) for user {user_id}, cmd {command_name}") - return None + log.error( + f"Could not parse cached timestamp '{cached_cooldown}' for user {user_id}, cmd {command_name}" + ) + # Fall through to DB query if cache data is bad + elif ( + cached_cooldown is not None + ): # Empty string means checked DB and no cooldown exists + log.debug( + f"Cache hit (no cooldown) for user {user_id}, cmd {command_name}" + ) + return None except Exception as e: log.warning(f"Redis GET failed for key {cache_key}: {e}", exc_info=True) @@ -268,45 +334,63 @@ async def check_cooldown(user_id: int, command_name: str) -> Optional[datetime.d async with pool.acquire() as conn: last_used_dt = await conn.fetchval( "SELECT last_used FROM command_cooldowns WHERE user_id = $1 AND command_name = $2", - user_id, command_name + user_id, + command_name, ) # 3. Update Cache if redis_client: try: - value_to_cache = last_used_dt.isoformat() if last_used_dt else "NULL" # Store NULL explicitly + value_to_cache = ( + last_used_dt.isoformat() if last_used_dt else "NULL" + ) # Store NULL explicitly await redis_client.set(cache_key, value_to_cache, ex=CACHE_DEFAULT_TTL) except Exception as e: log.warning(f"Redis SET failed for key {cache_key}: {e}", exc_info=True) - return last_used_dt # Already timezone-aware from PostgreSQL TIMESTAMP WITH TIME ZONE + return ( + last_used_dt # Already timezone-aware from PostgreSQL TIMESTAMP WITH TIME ZONE + ) + async def set_cooldown(user_id: int, command_name: str): """Sets or updates the cooldown timestamp. Invalidates cache.""" - if not pool: raise ConnectionError("Database pool not initialized.") + if not pool: + raise ConnectionError("Database pool not initialized.") cache_key = CACHE_COOLDOWN_KEY.format(user_id=user_id, command_name=command_name) now_utc = datetime.datetime.now(datetime.timezone.utc) async with pool.acquire() as conn: # Use ON CONFLICT DO UPDATE for UPSERT behavior - await conn.execute(""" + await conn.execute( + """ INSERT INTO command_cooldowns (user_id, command_name, last_used) VALUES ($1, $2, $3) ON CONFLICT (user_id, command_name) DO UPDATE SET last_used = EXCLUDED.last_used - """, user_id, command_name, now_utc) - log.debug(f"Set cooldown for user_id {user_id}, command {command_name} to {now_utc.isoformat()}") + """, + user_id, + command_name, + now_utc, + ) + log.debug( + f"Set cooldown for user_id {user_id}, command {command_name} to {now_utc.isoformat()}" + ) # Update Cache directly (faster than invalidating and re-querying) if redis_client: try: await redis_client.set(cache_key, now_utc.isoformat(), ex=CACHE_DEFAULT_TTL) except Exception as e: - log.warning(f"Redis SET failed for key {cache_key} during update: {e}", exc_info=True) + log.warning( + f"Redis SET failed for key {cache_key} during update: {e}", + exc_info=True, + ) async def get_leaderboard(count: int = 10) -> List[Tuple[int, int]]: """Retrieves the top users by balance. Uses Redis cache.""" - if not pool: raise ConnectionError("Database pool not initialized.") + if not pool: + raise ConnectionError("Database pool not initialized.") cache_key = CACHE_LEADERBOARD_KEY.format(count=count) # 1. Check Cache @@ -324,27 +408,31 @@ async def get_leaderboard(count: int = 10) -> List[Tuple[int, int]]: # 2. Query Database async with pool.acquire() as conn: results = await conn.fetch( - "SELECT user_id, balance FROM economy ORDER BY balance DESC LIMIT $1", - count + "SELECT user_id, balance FROM economy ORDER BY balance DESC LIMIT $1", count ) # Convert asyncpg Records to simple list of tuples - leaderboard_data = [(r['user_id'], r['balance']) for r in results] + leaderboard_data = [(r["user_id"], r["balance"]) for r in results] # 3. Update Cache if redis_client: try: # Store as JSON string - await redis_client.set(cache_key, json.dumps(leaderboard_data), ex=CACHE_LEADERBOARD_TTL) + await redis_client.set( + cache_key, json.dumps(leaderboard_data), ex=CACHE_LEADERBOARD_TTL + ) except Exception as e: log.warning(f"Redis SET failed for key {cache_key}: {e}", exc_info=True) return leaderboard_data + # --- Job Functions --- + async def get_user_job(user_id: int) -> Optional[Dict[str, Any]]: """Gets the user's job details. Uses Redis cache.""" - if not pool: raise ConnectionError("Database pool not initialized.") + if not pool: + raise ConnectionError("Database pool not initialized.") cache_key = CACHE_JOB_KEY.format(user_id=user_id) # 1. Check Cache @@ -357,10 +445,14 @@ async def get_user_job(user_id: int) -> Optional[Dict[str, Any]]: # Convert timestamp string back to datetime object if job_data.get("last_action"): try: - job_data["last_action"] = datetime.datetime.fromisoformat(job_data["last_action"]) + job_data["last_action"] = datetime.datetime.fromisoformat( + job_data["last_action"] + ) except (ValueError, TypeError): - log.error(f"Could not parse cached job timestamp '{job_data['last_action']}' for user {user_id}") - job_data["last_action"] = None # Set to None if parsing fails + log.error( + f"Could not parse cached job timestamp '{job_data['last_action']}' for user {user_id}" + ) + job_data["last_action"] = None # Set to None if parsing fails return job_data except Exception as e: log.warning(f"Redis GET failed for key {cache_key}: {e}", exc_info=True) @@ -373,40 +465,42 @@ async def get_user_job(user_id: int) -> Optional[Dict[str, Any]]: # Fetch job details job_record = await conn.fetchrow( "SELECT job_name, job_level, job_xp, last_job_action FROM user_jobs WHERE user_id = $1", - user_id + user_id, ) job_data: Optional[Dict[str, Any]] = None if job_record: job_data = { - "name": job_record['job_name'], - "level": job_record['job_level'], - "xp": job_record['job_xp'], - "last_action": job_record['last_job_action'] # Already timezone-aware + "name": job_record["job_name"], + "level": job_record["job_level"], + "xp": job_record["job_xp"], + "last_action": job_record["last_job_action"], # Already timezone-aware } else: # Create job entry if it doesn't exist try: await conn.execute( "INSERT INTO user_jobs (user_id, job_name, job_level, job_xp, last_job_action) VALUES ($1, NULL, 1, 0, NULL)", - user_id + user_id, ) log.info(f"Created default job entry for user_id: {user_id}") job_data = {"name": None, "level": 1, "xp": 0, "last_action": None} except asyncpg.UniqueViolationError: - log.warning(f"Race condition handled for user_id: {user_id} during job fetch.") + log.warning( + f"Race condition handled for user_id: {user_id} during job fetch." + ) job_record_retry = await conn.fetchrow( "SELECT job_name, job_level, job_xp, last_job_action FROM user_jobs WHERE user_id = $1", - user_id + user_id, ) if job_record_retry: - job_data = { - "name": job_record_retry['job_name'], - "level": job_record_retry['job_level'], - "xp": job_record_retry['job_xp'], - "last_action": job_record_retry['last_job_action'] + job_data = { + "name": job_record_retry["job_name"], + "level": job_record_retry["job_level"], + "xp": job_record_retry["job_xp"], + "last_action": job_record_retry["last_job_action"], } - else: # Should not happen, but handle defensively + else: # Should not happen, but handle defensively job_data = {"name": None, "level": 1, "xp": 0, "last_action": None} # 3. Update Cache @@ -415,17 +509,23 @@ async def get_user_job(user_id: int) -> Optional[Dict[str, Any]]: # Convert datetime to ISO string for JSON serialization job_data_to_cache = job_data.copy() if job_data_to_cache.get("last_action"): - job_data_to_cache["last_action"] = job_data_to_cache["last_action"].isoformat() + job_data_to_cache["last_action"] = job_data_to_cache[ + "last_action" + ].isoformat() - await redis_client.set(cache_key, json.dumps(job_data_to_cache), ex=CACHE_DEFAULT_TTL) + await redis_client.set( + cache_key, json.dumps(job_data_to_cache), ex=CACHE_DEFAULT_TTL + ) except Exception as e: log.warning(f"Redis SET failed for key {cache_key}: {e}", exc_info=True) return job_data + async def set_user_job(user_id: int, job_name: Optional[str]): """Sets or clears a user's job. Resets level/xp. Invalidates cache.""" - if not pool: raise ConnectionError("Database pool not initialized.") + if not pool: + raise ConnectionError("Database pool not initialized.") cache_key = CACHE_JOB_KEY.format(user_id=user_id) async with pool.acquire() as conn: @@ -434,7 +534,8 @@ async def set_user_job(user_id: int, job_name: Optional[str]): # Update job, resetting level/xp await conn.execute( "UPDATE user_jobs SET job_name = $1, job_level = 1, job_xp = 0 WHERE user_id = $2", - job_name, user_id + job_name, + user_id, ) log.info(f"Set job for user_id {user_id} to {job_name}. Level/XP reset.") @@ -446,9 +547,11 @@ async def set_user_job(user_id: int, job_name: Optional[str]): except Exception as e: log.warning(f"Redis DELETE failed for key {cache_key}: {e}", exc_info=True) + async def remove_user_job(user_id: int): """Removes a user's job by setting job_name to NULL. Invalidates cache.""" - if not pool: raise ConnectionError("Database pool not initialized.") + if not pool: + raise ConnectionError("Database pool not initialized.") cache_key = CACHE_JOB_KEY.format(user_id=user_id) async with pool.acquire() as conn: @@ -457,7 +560,7 @@ async def remove_user_job(user_id: int): # Set job_name to NULL, reset level/xp await conn.execute( "UPDATE user_jobs SET job_name = NULL, job_level = 1, job_xp = 0 WHERE user_id = $1", - user_id + user_id, ) log.info(f"Removed job for user_id {user_id}. Level/XP reset.") @@ -469,20 +572,22 @@ async def remove_user_job(user_id: int): except Exception as e: log.warning(f"Redis DELETE failed for key {cache_key}: {e}", exc_info=True) + async def add_job_xp(user_id: int, xp_amount: int) -> Tuple[int, int, bool]: """Adds XP to the user's job, handles level ups. Invalidates cache. Returns (new_level, new_xp, did_level_up).""" - if not pool: raise ConnectionError("Database pool not initialized.") + if not pool: + raise ConnectionError("Database pool not initialized.") cache_key = CACHE_JOB_KEY.format(user_id=user_id) async with pool.acquire() as conn: # Use transaction to ensure atomicity of read-modify-write async with conn.transaction(): job_info = await conn.fetchrow( - "SELECT job_name, job_level, job_xp FROM user_jobs WHERE user_id = $1 FOR UPDATE", # Lock row - user_id + "SELECT job_name, job_level, job_xp FROM user_jobs WHERE user_id = $1 FOR UPDATE", # Lock row + user_id, ) - if not job_info or not job_info['job_name']: + if not job_info or not job_info["job_name"]: log.warning(f"Attempted to add XP to user {user_id} with no job.") return (1, 0, False) @@ -504,9 +609,13 @@ async def add_job_xp(user_id: int, xp_amount: int) -> Tuple[int, int, bool]: # Update database await conn.execute( "UPDATE user_jobs SET job_level = $1, job_xp = $2 WHERE user_id = $3", - current_level, new_xp, user_id + current_level, + new_xp, + user_id, + ) + log.debug( + f"Updated job XP for user {user_id}. New Level: {current_level}, New XP: {new_xp}" ) - log.debug(f"Updated job XP for user {user_id}. New Level: {current_level}, New XP: {new_xp}") # Invalidate Cache outside transaction if redis_client: @@ -514,20 +623,26 @@ async def add_job_xp(user_id: int, xp_amount: int) -> Tuple[int, int, bool]: await redis_client.delete(cache_key) log.debug(f"Invalidated cache for job user_id: {user_id} after XP update.") except Exception as e: - log.warning(f"Redis DELETE failed for key {cache_key} after XP update: {e}", exc_info=True) + log.warning( + f"Redis DELETE failed for key {cache_key} after XP update: {e}", + exc_info=True, + ) return (current_level, new_xp, did_level_up) + async def set_job_cooldown(user_id: int): """Sets the job cooldown timestamp. Invalidates cache.""" - if not pool: raise ConnectionError("Database pool not initialized.") + if not pool: + raise ConnectionError("Database pool not initialized.") cache_key = CACHE_JOB_KEY.format(user_id=user_id) now_utc = datetime.datetime.now(datetime.timezone.utc) async with pool.acquire() as conn: await conn.execute( "UPDATE user_jobs SET last_job_action = $1 WHERE user_id = $2", - now_utc, user_id + now_utc, + user_id, ) log.debug(f"Set job cooldown for user_id {user_id} to {now_utc.isoformat()}") @@ -535,9 +650,15 @@ async def set_job_cooldown(user_id: int): if redis_client: try: await redis_client.delete(cache_key) - log.debug(f"Invalidated cache for job user_id: {user_id} after setting cooldown.") + log.debug( + f"Invalidated cache for job user_id: {user_id} after setting cooldown." + ) except Exception as e: - log.warning(f"Redis DELETE failed for key {cache_key} after setting cooldown: {e}", exc_info=True) + log.warning( + f"Redis DELETE failed for key {cache_key} after setting cooldown: {e}", + exc_info=True, + ) + async def get_available_jobs() -> List[Dict[str, Any]]: """Returns a list of available jobs with their details.""" @@ -549,7 +670,7 @@ async def get_available_jobs() -> List[Dict[str, Any]]: "description": "Mine for ores and gems.", "base_pay": 20, "cooldown_minutes": 30, - "items": ["raw_iron", "coal", "shiny_gem"] + "items": ["raw_iron", "coal", "shiny_gem"], }, { "key": "fisher", @@ -557,7 +678,7 @@ async def get_available_jobs() -> List[Dict[str, Any]]: "description": "Catch fish from the sea.", "base_pay": 15, "cooldown_minutes": 20, - "items": ["common_fish", "rare_fish", "treasure_chest"] + "items": ["common_fish", "rare_fish", "treasure_chest"], }, { "key": "blacksmith", @@ -565,7 +686,7 @@ async def get_available_jobs() -> List[Dict[str, Any]]: "description": "Craft metal items.", "base_pay": 25, "cooldown_minutes": 45, - "items": ["iron_ingot", "basic_tool"] + "items": ["iron_ingot", "basic_tool"], }, { "key": "farmer", @@ -573,16 +694,19 @@ async def get_available_jobs() -> List[Dict[str, Any]]: "description": "Grow and harvest crops.", "base_pay": 10, "cooldown_minutes": 15, - "items": [] - } + "items": [], + }, ] return jobs + # --- Item/Inventory Functions --- + async def get_item_details(item_key: str) -> Optional[Dict[str, Any]]: """Gets details for a specific item. Uses Redis cache.""" - if not pool: raise ConnectionError("Database pool not initialized.") + if not pool: + raise ConnectionError("Database pool not initialized.") cache_key = CACHE_ITEM_KEY.format(item_key=item_key) # 1. Check Cache @@ -600,16 +724,16 @@ async def get_item_details(item_key: str) -> Optional[Dict[str, Any]]: async with pool.acquire() as conn: item_record = await conn.fetchrow( "SELECT name, description, sell_price FROM items WHERE item_key = $1", - item_key + item_key, ) item_data: Optional[Dict[str, Any]] = None if item_record: item_data = { "key": item_key, - "name": item_record['name'], - "description": item_record['description'], - "sell_price": item_record['sell_price'] + "name": item_record["name"], + "description": item_record["description"], + "sell_price": item_record["sell_price"], } # 3. Update Cache (use longer TTL for items) @@ -621,9 +745,11 @@ async def get_item_details(item_key: str) -> Optional[Dict[str, Any]]: return item_data + async def get_inventory(user_id: int) -> List[Dict[str, Any]]: """Gets a user's inventory. Uses Redis cache.""" - if not pool: raise ConnectionError("Database pool not initialized.") + if not pool: + raise ConnectionError("Database pool not initialized.") cache_key = CACHE_INVENTORY_KEY.format(user_id=user_id) # 1. Check Cache @@ -640,55 +766,73 @@ async def get_inventory(user_id: int) -> List[Dict[str, Any]]: # 2. Query Database inventory = [] async with pool.acquire() as conn: - results = await conn.fetch(""" + results = await conn.fetch( + """ SELECT inv.item_key, inv.quantity, i.name, i.description, i.sell_price FROM user_inventory inv JOIN items i ON inv.item_key = i.item_key WHERE inv.user_id = $1 ORDER BY i.name - """, user_id) + """, + user_id, + ) for row in results: - inventory.append({ - "key": row['item_key'], - "quantity": row['quantity'], - "name": row['name'], - "description": row['description'], - "sell_price": row['sell_price'] - }) + inventory.append( + { + "key": row["item_key"], + "quantity": row["quantity"], + "name": row["name"], + "description": row["description"], + "sell_price": row["sell_price"], + } + ) # 3. Update Cache if redis_client: try: - await redis_client.set(cache_key, json.dumps(inventory), ex=CACHE_DEFAULT_TTL) + await redis_client.set( + cache_key, json.dumps(inventory), ex=CACHE_DEFAULT_TTL + ) except Exception as e: log.warning(f"Redis SET failed for key {cache_key}: {e}", exc_info=True) return inventory + async def add_item_to_inventory(user_id: int, item_key: str, quantity: int = 1): """Adds an item to the user's inventory. Invalidates cache.""" - if not pool: raise ConnectionError("Database pool not initialized.") + if not pool: + raise ConnectionError("Database pool not initialized.") cache_key = CACHE_INVENTORY_KEY.format(user_id=user_id) if quantity <= 0: - log.warning(f"Attempted to add non-positive quantity ({quantity}) of item {item_key} for user {user_id}") + log.warning( + f"Attempted to add non-positive quantity ({quantity}) of item {item_key} for user {user_id}" + ) return # Check if item exists (can use cached version) item_details = await get_item_details(item_key) if not item_details: - log.error(f"Attempted to add non-existent item '{item_key}' to inventory for user {user_id}") + log.error( + f"Attempted to add non-existent item '{item_key}' to inventory for user {user_id}" + ) return async with pool.acquire() as conn: # Ensure user exists in economy table await get_balance(user_id) # Use ON CONFLICT DO UPDATE for UPSERT behavior - await conn.execute(""" + await conn.execute( + """ INSERT INTO user_inventory (user_id, item_key, quantity) VALUES ($1, $2, $3) ON CONFLICT (user_id, item_key) DO UPDATE SET quantity = user_inventory.quantity + EXCLUDED.quantity - """, user_id, item_key, quantity) + """, + user_id, + item_key, + quantity, + ) log.debug(f"Added {quantity} of {item_key} to user {user_id}'s inventory.") # Invalidate Cache @@ -699,13 +843,19 @@ async def add_item_to_inventory(user_id: int, item_key: str, quantity: int = 1): except Exception as e: log.warning(f"Redis DELETE failed for key {cache_key}: {e}", exc_info=True) -async def remove_item_from_inventory(user_id: int, item_key: str, quantity: int = 1) -> bool: + +async def remove_item_from_inventory( + user_id: int, item_key: str, quantity: int = 1 +) -> bool: """Removes an item from the user's inventory. Invalidates cache. Returns True if successful.""" - if not pool: raise ConnectionError("Database pool not initialized.") + if not pool: + raise ConnectionError("Database pool not initialized.") cache_key = CACHE_INVENTORY_KEY.format(user_id=user_id) if quantity <= 0: - log.warning(f"Attempted to remove non-positive quantity ({quantity}) of item {item_key} for user {user_id}") + log.warning( + f"Attempted to remove non-positive quantity ({quantity}) of item {item_key} for user {user_id}" + ) return False success = False @@ -713,21 +863,35 @@ async def remove_item_from_inventory(user_id: int, item_key: str, quantity: int # Use transaction for check-then-delete/update async with conn.transaction(): current_quantity = await conn.fetchval( - "SELECT quantity FROM user_inventory WHERE user_id = $1 AND item_key = $2 FOR UPDATE", # Lock row - user_id, item_key + "SELECT quantity FROM user_inventory WHERE user_id = $1 AND item_key = $2 FOR UPDATE", # Lock row + user_id, + item_key, ) if current_quantity is None or current_quantity < quantity: - log.debug(f"User {user_id} does not have enough {item_key} (needs {quantity}, has {current_quantity or 0})") - success = False # Explicitly set success to False + log.debug( + f"User {user_id} does not have enough {item_key} (needs {quantity}, has {current_quantity or 0})" + ) + success = False # Explicitly set success to False # No need to rollback explicitly, transaction context manager handles it else: if current_quantity == quantity: - await conn.execute("DELETE FROM user_inventory WHERE user_id = $1 AND item_key = $2", user_id, item_key) + await conn.execute( + "DELETE FROM user_inventory WHERE user_id = $1 AND item_key = $2", + user_id, + item_key, + ) else: - await conn.execute("UPDATE user_inventory SET quantity = quantity - $1 WHERE user_id = $2 AND item_key = $3", quantity, user_id, item_key) - log.debug(f"Removed {quantity} of {item_key} from user {user_id}'s inventory.") - success = True # Set success to True only if operations succeed + await conn.execute( + "UPDATE user_inventory SET quantity = quantity - $1 WHERE user_id = $2 AND item_key = $3", + quantity, + user_id, + item_key, + ) + log.debug( + f"Removed {quantity} of {item_key} from user {user_id}'s inventory." + ) + success = True # Set success to True only if operations succeed # Invalidate Cache only if removal was successful if success and redis_client: @@ -739,6 +903,7 @@ async def remove_item_from_inventory(user_id: int, item_key: str, quantity: int return success + async def close_db(): """Closes the PostgreSQL pool and Redis client.""" global pool, redis_client diff --git a/cogs/economy/earning.py b/cogs/economy/earning.py index 9739d03..4cb33e4 100644 --- a/cogs/economy/earning.py +++ b/cogs/economy/earning.py @@ -10,6 +10,7 @@ from . import database log = logging.getLogger(__name__) + class EarningCommands(commands.Cog): """Cog containing currency earning commands.""" @@ -22,7 +23,7 @@ class EarningCommands(commands.Cog): user_id = ctx.author.id command_name = "daily" cooldown_duration = datetime.timedelta(hours=24) - reward_amount = 100 # Example daily reward + reward_amount = 100 # Example daily reward last_used = await database.check_cooldown(user_id, command_name) @@ -37,7 +38,10 @@ class EarningCommands(commands.Cog): time_left = cooldown_duration - time_since_last_used hours, remainder = divmod(int(time_left.total_seconds()), 3600) minutes, seconds = divmod(remainder, 60) - embed = discord.Embed(description=f"🕒 You've already claimed your daily reward. Try again in **{hours}h {minutes}m {seconds}s**.", color=discord.Color.orange()) + embed = discord.Embed( + description=f"🕒 You've already claimed your daily reward. Try again in **{hours}h {minutes}m {seconds}s**.", + color=discord.Color.orange(), + ) await ctx.send(embed=embed, ephemeral=True) return @@ -48,19 +52,18 @@ class EarningCommands(commands.Cog): embed = discord.Embed( title="Daily Reward Claimed!", description=f"🎉 You claimed your daily reward of **${reward_amount:,}**!", - color=discord.Color.green() + color=discord.Color.green(), ) embed.add_field(name="New Balance", value=f"${current_balance:,}", inline=False) await ctx.send(embed=embed) - @commands.command(name="beg", description="Beg for some spare change.") async def beg(self, ctx: commands.Context): """Allows users to beg for a small amount of currency with a chance of success.""" user_id = ctx.author.id command_name = "beg" - cooldown_duration = datetime.timedelta(minutes=5) # 5-minute cooldown - success_chance = 0.4 # 40% chance of success + cooldown_duration = datetime.timedelta(minutes=5) # 5-minute cooldown + success_chance = 0.4 # 40% chance of success min_reward = 1 max_reward = 20 @@ -75,7 +78,10 @@ class EarningCommands(commands.Cog): if time_since_last_used < cooldown_duration: time_left = cooldown_duration - time_since_last_used minutes, seconds = divmod(int(time_left.total_seconds()), 60) - embed = discord.Embed(description=f"🕒 You can't beg again so soon. Try again in **{minutes}m {seconds}s**.", color=discord.Color.orange()) + embed = discord.Embed( + description=f"🕒 You can't beg again so soon. Try again in **{minutes}m {seconds}s**.", + color=discord.Color.orange(), + ) await ctx.send(embed=embed, ephemeral=True) return @@ -90,15 +96,17 @@ class EarningCommands(commands.Cog): embed = discord.Embed( title="Begging Successful!", description=f"🙏 Someone took pity on you! You received **${reward_amount:,}**.", - color=discord.Color.green() + color=discord.Color.green(), + ) + embed.add_field( + name="New Balance", value=f"${current_balance:,}", inline=False ) - embed.add_field(name="New Balance", value=f"${current_balance:,}", inline=False) await ctx.send(embed=embed) else: embed = discord.Embed( title="Begging Failed", description="🤷 Nobody gave you anything. Better luck next time!", - color=discord.Color.red() + color=discord.Color.red(), ) await ctx.send(embed=embed) @@ -107,8 +115,10 @@ class EarningCommands(commands.Cog): """Allows users to perform work for a small, guaranteed reward.""" user_id = ctx.author.id command_name = "work" - cooldown_duration = datetime.timedelta(hours=1) # 1-hour cooldown - reward_amount = random.randint(15, 35) # Small reward range - This is now fallback if no job + cooldown_duration = datetime.timedelta(hours=1) # 1-hour cooldown + reward_amount = random.randint( + 15, 35 + ) # Small reward range - This is now fallback if no job # --- Check if user has a job --- job_info = await database.get_user_job(user_id) @@ -119,13 +129,15 @@ class EarningCommands(commands.Cog): # from .jobs import JOB_DEFINITIONS # Avoid circular import if possible # job_details = JOB_DEFINITIONS.get(job_key) # command_to_use = job_details['command'] if job_details else f"your job command (`/{job_key}`)" # Fallback - command_to_use = f"`/{job_key}`" # Simple fallback - embed = discord.Embed(description=f"💼 You have a job! Use {command_to_use} instead of the generic `/work` command.", color=discord.Color.blue()) + command_to_use = f"`/{job_key}`" # Simple fallback + embed = discord.Embed( + description=f"💼 You have a job! Use {command_to_use} instead of the generic `/work` command.", + color=discord.Color.blue(), + ) await ctx.send(embed=embed, ephemeral=True) return # --- End Job Check --- - # Proceed with generic /work only if no job last_used = await database.check_cooldown(user_id, command_name) @@ -139,7 +151,10 @@ class EarningCommands(commands.Cog): time_left = cooldown_duration - time_since_last_used hours, remainder = divmod(int(time_left.total_seconds()), 3600) minutes, seconds = divmod(remainder, 60) - embed = discord.Embed(description=f"🕒 You need to rest after working. Try again in **{hours}h {minutes}m {seconds}s**.", color=discord.Color.orange()) + embed = discord.Embed( + description=f"🕒 You need to rest after working. Try again in **{hours}h {minutes}m {seconds}s**.", + color=discord.Color.orange(), + ) await ctx.send(embed=embed, ephemeral=True) return @@ -156,18 +171,20 @@ class EarningCommands(commands.Cog): embed = discord.Embed( title="Work Complete!", description=random.choice(work_messages), - color=discord.Color.green() + color=discord.Color.green(), ) embed.add_field(name="New Balance", value=f"${current_balance:,}", inline=False) await ctx.send(embed=embed) - @commands.command(name="scavenge", description="Scavenge around for some spare change.") # Renamed to avoid conflict - async def scavenge(self, ctx: commands.Context): # Renamed function + @commands.command( + name="scavenge", description="Scavenge around for some spare change." + ) # Renamed to avoid conflict + async def scavenge(self, ctx: commands.Context): # Renamed function """Allows users to scavenge for a small chance of finding money.""" user_id = ctx.author.id - command_name = "scavenge" # Update command name for cooldown tracking - cooldown_duration = datetime.timedelta(minutes=30) # 30-minute cooldown - success_chance = 0.25 # 25% chance to find something + command_name = "scavenge" # Update command name for cooldown tracking + cooldown_duration = datetime.timedelta(minutes=30) # 30-minute cooldown + success_chance = 0.25 # 25% chance to find something min_reward = 1 max_reward = 10 @@ -182,7 +199,10 @@ class EarningCommands(commands.Cog): if time_since_last_used < cooldown_duration: time_left = cooldown_duration - time_since_last_used minutes, seconds = divmod(int(time_left.total_seconds()), 60) - embed = discord.Embed(description=f"🕒 You've searched recently. Try again in **{minutes}m {seconds}s**.", color=discord.Color.orange()) + embed = discord.Embed( + description=f"🕒 You've searched recently. Try again in **{minutes}m {seconds}s**.", + color=discord.Color.orange(), + ) await ctx.send(embed=embed, ephemeral=True) return @@ -190,9 +210,13 @@ class EarningCommands(commands.Cog): await database.set_cooldown(user_id, command_name) # Flavor text for scavenging - scavenge_locations = [ # Renamed variable for clarity - "under the sofa cushions", "in an old coat pocket", "behind the dumpster", - "in a dusty corner", "on the sidewalk", "in a forgotten drawer" + scavenge_locations = [ # Renamed variable for clarity + "under the sofa cushions", + "in an old coat pocket", + "behind the dumpster", + "in a dusty corner", + "on the sidewalk", + "in a forgotten drawer", ] location = random.choice(scavenge_locations) @@ -203,16 +227,19 @@ class EarningCommands(commands.Cog): embed = discord.Embed( title="Scavenging Successful!", description=f"🔍 You scavenged {location} and found **${reward_amount:,}**!", - color=discord.Color.green() + color=discord.Color.green(), + ) + embed.add_field( + name="New Balance", value=f"${current_balance:,}", inline=False ) - embed.add_field(name="New Balance", value=f"${current_balance:,}", inline=False) await ctx.send(embed=embed) else: embed = discord.Embed( title="Scavenging Failed", description=f"🔍 You scavenged {location} but found nothing but lint.", - color=discord.Color.red() + color=discord.Color.red(), ) await ctx.send(embed=embed) + # No setup function needed here, it will be in __init__.py diff --git a/cogs/economy/gambling.py b/cogs/economy/gambling.py index 7e6e77d..d9b0e89 100644 --- a/cogs/economy/gambling.py +++ b/cogs/economy/gambling.py @@ -10,27 +10,40 @@ from . import database log = logging.getLogger(__name__) + class GamblingCommands(commands.Cog): """Cog containing gambling-related economy commands.""" def __init__(self, bot: commands.Bot): self.bot = bot - @commands.hybrid_command(name="moneyflip", aliases=["mf"], description="Gamble your money on a coin flip.") # Renamed to avoid conflict - async def moneyflip(self, ctx: commands.Context, amount: int, choice: str): # Renamed function + @commands.hybrid_command( + name="moneyflip", + aliases=["mf"], + description="Gamble your money on a coin flip.", + ) # Renamed to avoid conflict + async def moneyflip( + self, ctx: commands.Context, amount: int, choice: str + ): # Renamed function """Bets a certain amount on a coin flip (heads or tails).""" user_id = ctx.author.id - command_name = "moneyflip" # Update command name for cooldown tracking - cooldown_duration = datetime.timedelta(seconds=10) # Short cooldown + command_name = "moneyflip" # Update command name for cooldown tracking + cooldown_duration = datetime.timedelta(seconds=10) # Short cooldown choice = choice.lower() if choice not in ["heads", "tails", "h", "t"]: - embed = discord.Embed(description="❌ Invalid choice. Please choose 'heads' or 'tails'.", color=discord.Color.red()) + embed = discord.Embed( + description="❌ Invalid choice. Please choose 'heads' or 'tails'.", + color=discord.Color.red(), + ) await ctx.send(embed=embed, ephemeral=True) return if amount <= 0: - embed = discord.Embed(description="❌ Please enter a positive amount to bet.", color=discord.Color.red()) + embed = discord.Embed( + description="❌ Please enter a positive amount to bet.", + color=discord.Color.red(), + ) await ctx.send(embed=embed, ephemeral=True) return @@ -44,14 +57,20 @@ class GamblingCommands(commands.Cog): time_since_last_used = now_utc - last_used if time_since_last_used < cooldown_duration: time_left = cooldown_duration - time_since_last_used - embed = discord.Embed(description=f"🕒 You're flipping too fast! Try again in **{int(time_left.total_seconds())}s**.", color=discord.Color.orange()) + embed = discord.Embed( + description=f"🕒 You're flipping too fast! Try again in **{int(time_left.total_seconds())}s**.", + color=discord.Color.orange(), + ) await ctx.send(embed=embed, ephemeral=True) return # Check balance user_balance = await database.get_balance(user_id) if user_balance < amount: - embed = discord.Embed(description=f"❌ You don't have enough money to bet that much! Your balance is **${user_balance:,}**.", color=discord.Color.red()) + embed = discord.Embed( + description=f"❌ You don't have enough money to bet that much! Your balance is **${user_balance:,}**.", + color=discord.Color.red(), + ) await ctx.send(embed=embed, ephemeral=True) return @@ -60,27 +79,32 @@ class GamblingCommands(commands.Cog): # Perform the coin flip result = random.choice(["heads", "tails"]) - win = (choice.startswith(result[0])) # True if choice matches result + win = choice.startswith(result[0]) # True if choice matches result if win: - await database.update_balance(user_id, amount) # Win the amount bet + await database.update_balance(user_id, amount) # Win the amount bet current_balance = await database.get_balance(user_id) embed = discord.Embed( title="Coin Flip: Win!", description=f"🪙 The coin landed on **{result}**! You won **${amount:,}**!", - color=discord.Color.green() + color=discord.Color.green(), + ) + embed.add_field( + name="New Balance", value=f"${current_balance:,}", inline=False ) - embed.add_field(name="New Balance", value=f"${current_balance:,}", inline=False) await ctx.send(embed=embed) else: - await database.update_balance(user_id, -amount) # Lose the amount bet + await database.update_balance(user_id, -amount) # Lose the amount bet current_balance = await database.get_balance(user_id) embed = discord.Embed( title="Coin Flip: Loss!", description=f"🪙 The coin landed on **{result}**. You lost **${amount:,}**.", - color=discord.Color.red() + color=discord.Color.red(), + ) + embed.add_field( + name="New Balance", value=f"${current_balance:,}", inline=False ) - embed.add_field(name="New Balance", value=f"${current_balance:,}", inline=False) await ctx.send(embed=embed) + # No setup function needed here diff --git a/cogs/economy/jobs.py b/cogs/economy/jobs.py index 27cf375..f9897fe 100644 --- a/cogs/economy/jobs.py +++ b/cogs/economy/jobs.py @@ -1,6 +1,6 @@ import discord from discord.ext import commands -from discord import app_commands # Required for choices/autocomplete +from discord import app_commands # Required for choices/autocomplete import datetime import random import logging @@ -19,17 +19,17 @@ JOB_DEFINITIONS = { "description": "Mine for ores and gems.", "command": "/mine", "cooldown": datetime.timedelta(hours=1), - "base_currency": (15, 30), # Min/Max currency per action + "base_currency": (15, 30), # Min/Max currency per action "base_xp": 15, - "drops": { # Item Key: Chance (0.0 to 1.0) + "drops": { # Item Key: Chance (0.0 to 1.0) "raw_iron": 0.6, "coal": 0.4, - "shiny_gem": 0.05 # Lower chance for rarer items + "shiny_gem": 0.05, # Lower chance for rarer items + }, + "level_bonus": { # Applied per level + "currency_increase": 1, # Add +1 to min/max currency range per level + "rare_find_increase": 0.005, # Increase shiny_gem chance by 0.5% per level }, - "level_bonus": { # Applied per level - "currency_increase": 1, # Add +1 to min/max currency range per level - "rare_find_increase": 0.005 # Increase shiny_gem chance by 0.5% per level - } }, "fisher": { "name": "Fisher", @@ -39,35 +39,34 @@ JOB_DEFINITIONS = { "base_currency": (5, 15), "base_xp": 10, "drops": { - "common_fish": 0.8, # High chance for common + "common_fish": 0.8, # High chance for common "rare_fish": 0.15, - "treasure_chest": 0.02 + "treasure_chest": 0.02, }, "level_bonus": { - "currency_increase": 0.5, # Smaller increase - "rare_find_increase": 0.003 # Increase rare_fish/treasure chance - } + "currency_increase": 0.5, # Smaller increase + "rare_find_increase": 0.003, # Increase rare_fish/treasure chance + }, }, "crafter": { "name": "Crafter", "description": "Use materials to craft valuable items.", "command": "/craft", - "cooldown": datetime.timedelta(minutes=15), # Cooldown per craft action - "base_currency": (0, 0), # No direct currency - "base_xp": 20, # Higher XP for crafting - "recipes": { # Output Item Key: {Input Item Key: Quantity Required} + "cooldown": datetime.timedelta(minutes=15), # Cooldown per craft action + "base_currency": (0, 0), # No direct currency + "base_xp": 20, # Higher XP for crafting + "recipes": { # Output Item Key: {Input Item Key: Quantity Required} "iron_ingot": {"raw_iron": 2, "coal": 1}, - "basic_tool": {"iron_ingot": 3} + "basic_tool": {"iron_ingot": 3}, }, "level_bonus": { - "unlock_recipe_level": { # Level required to unlock recipe - "basic_tool": 5 - }, - # Could add reduced material cost later - } - } + "unlock_recipe_level": {"basic_tool": 5}, # Level required to unlock recipe + # Could add reduced material cost later + }, + }, } + # Helper function to format time delta def format_timedelta(delta: datetime.timedelta) -> str: """Formats a timedelta into a human-readable string (e.g., 1h 30m 15s).""" @@ -81,10 +80,11 @@ def format_timedelta(delta: datetime.timedelta) -> str: parts.append(f"{hours}h") if minutes > 0: parts.append(f"{minutes}m") - if seconds > 0 or not parts: # Show seconds if it's the only unit or > 0 + if seconds > 0 or not parts: # Show seconds if it's the only unit or > 0 parts.append(f"{seconds}s") return " ".join(parts) + class JobsCommands(commands.Cog): """Cog containing job-related economy commands.""" @@ -114,7 +114,10 @@ class JobsCommands(commands.Cog): job_info = await database.get_user_job(user_id) if not job_info or not job_info.get("name"): - embed = discord.Embed(description="❌ You don't currently have a job. Use `/jobs` to see available options and `/choosejob ` to pick one.", color=discord.Color.orange()) + embed = discord.Embed( + description="❌ You don't currently have a job. Use `/jobs` to see available options and `/choosejob ` to pick one.", + color=discord.Color.orange(), + ) await ctx.send(embed=embed, ephemeral=True) return @@ -124,39 +127,57 @@ class JobsCommands(commands.Cog): job_details = JOB_DEFINITIONS.get(job_key) if not job_details: - embed = discord.Embed(description=f"❌ Error: Your job '{job_key}' is not recognized. Please contact an admin.", color=discord.Color.red()) - await ctx.send(embed=embed, ephemeral=True) - log.error(f"User {user_id} has unrecognized job '{job_key}' in database.") - return + embed = discord.Embed( + description=f"❌ Error: Your job '{job_key}' is not recognized. Please contact an admin.", + color=discord.Color.red(), + ) + await ctx.send(embed=embed, ephemeral=True) + log.error(f"User {user_id} has unrecognized job '{job_key}' in database.") + return - xp_needed = level * 100 # Matches logic in database.py - embed = discord.Embed(title=f"{ctx.author.display_name}'s Job: {job_details['name']}", color=discord.Color.green()) + xp_needed = level * 100 # Matches logic in database.py + embed = discord.Embed( + title=f"{ctx.author.display_name}'s Job: {job_details['name']}", + color=discord.Color.green(), + ) embed.add_field(name="Level", value=level, inline=True) embed.add_field(name="XP", value=f"{xp} / {xp_needed}", inline=True) # Cooldown check last_action = job_info.get("last_action") - cooldown = job_details['cooldown'] + cooldown = job_details["cooldown"] if last_action: now_utc = datetime.datetime.now(datetime.timezone.utc) time_since = now_utc - last_action if time_since < cooldown: time_left = cooldown - time_since - embed.add_field(name="Cooldown", value=f"Ready in: {format_timedelta(time_left)}", inline=False) + embed.add_field( + name="Cooldown", + value=f"Ready in: {format_timedelta(time_left)}", + inline=False, + ) else: embed.add_field(name="Cooldown", value="Ready!", inline=False) else: - embed.add_field(name="Cooldown", value="Ready!", inline=False) + embed.add_field(name="Cooldown", value="Ready!", inline=False) - embed.set_footer(text=f"Use {job_details['command']} to perform your job action.") + embed.set_footer( + text=f"Use {job_details['command']} to perform your job action." + ) await ctx.send(embed=embed) # Autocomplete for choosejob and leavejob - async def job_autocomplete(self, interaction: discord.Interaction, current: str) -> List[app_commands.Choice[str]]: + async def job_autocomplete( + self, interaction: discord.Interaction, current: str + ) -> List[app_commands.Choice[str]]: return [ app_commands.Choice(name=details["name"], value=key) - for key, details in JOB_DEFINITIONS.items() if current.lower() in key.lower() or current.lower() in details["name"].lower() - ][:25] # Limit to 25 choices + for key, details in JOB_DEFINITIONS.items() + if current.lower() in key.lower() + or current.lower() in details["name"].lower() + ][ + :25 + ] # Limit to 25 choices @commands.hybrid_command(name="choosejob", description="Select a job to pursue.") @app_commands.autocomplete(job_name=job_autocomplete) @@ -166,13 +187,19 @@ class JobsCommands(commands.Cog): job_key = job_name.lower() if job_key not in JOB_DEFINITIONS: - embed = discord.Embed(description=f"❌ Invalid job name '{job_name}'. Use `/jobs` to see available options.", color=discord.Color.red()) + embed = discord.Embed( + description=f"❌ Invalid job name '{job_name}'. Use `/jobs` to see available options.", + color=discord.Color.red(), + ) await ctx.send(embed=embed, ephemeral=True) return current_job_info = await database.get_user_job(user_id) if current_job_info and current_job_info.get("name") == job_key: - embed = discord.Embed(description=f"✅ You are already a {JOB_DEFINITIONS[job_key]['name']}.", color=discord.Color.blue()) + embed = discord.Embed( + description=f"✅ You are already a {JOB_DEFINITIONS[job_key]['name']}.", + color=discord.Color.blue(), + ) await ctx.send(embed=embed, ephemeral=True) return @@ -182,7 +209,7 @@ class JobsCommands(commands.Cog): embed = discord.Embed( title="Job Changed!", description=f"💼 Congratulations! You are now a **{JOB_DEFINITIONS[job_key]['name']}**.", - color=discord.Color.green() + color=discord.Color.green(), ) embed.set_footer(text="Your previous job progress (if any) has been reset.") await ctx.send(embed=embed) @@ -194,20 +221,25 @@ class JobsCommands(commands.Cog): current_job_info = await database.get_user_job(user_id) if not current_job_info or not current_job_info.get("name"): - embed = discord.Embed(description="❌ You don't have a job to leave.", color=discord.Color.orange()) + embed = discord.Embed( + description="❌ You don't have a job to leave.", + color=discord.Color.orange(), + ) await ctx.send(embed=embed, ephemeral=True) return job_key = current_job_info["name"] job_name = JOB_DEFINITIONS.get(job_key, {}).get("name", "Unknown Job") - await database.set_user_job(user_id, None) # Set job to NULL + await database.set_user_job(user_id, None) # Set job to NULL embed = discord.Embed( title="Job Left", description=f"🗑️ You have left your job as a **{job_name}**.", - color=discord.Color.orange() + color=discord.Color.orange(), + ) + embed.set_footer( + text="Your level and XP for this job have been reset. You can choose a new job with /choosejob." ) - embed.set_footer(text="Your level and XP for this job have been reset. You can choose a new job with /choosejob.") await ctx.send(embed=embed) # --- Job Action Commands --- @@ -221,28 +253,37 @@ class JobsCommands(commands.Cog): if not job_info or job_info.get("name") != job_key: correct_job_info = await database.get_user_job(user_id) if correct_job_info and correct_job_info.get("name"): - correct_job_details = JOB_DEFINITIONS.get(correct_job_info["name"]) - embed = discord.Embed(description=f"❌ You need to be a {JOB_DEFINITIONS[job_key]['name']} to use this command. Your current job is {correct_job_details['name']}. Use `{correct_job_details['command']}` instead, or change jobs with `/choosejob`.", color=discord.Color.red()) - await ctx.send(embed=embed, ephemeral=True) + correct_job_details = JOB_DEFINITIONS.get(correct_job_info["name"]) + embed = discord.Embed( + description=f"❌ You need to be a {JOB_DEFINITIONS[job_key]['name']} to use this command. Your current job is {correct_job_details['name']}. Use `{correct_job_details['command']}` instead, or change jobs with `/choosejob`.", + color=discord.Color.red(), + ) + await ctx.send(embed=embed, ephemeral=True) else: - embed = discord.Embed(description=f"❌ You need to be a {JOB_DEFINITIONS[job_key]['name']} to use this command. You don't have a job. Use `/choosejob {job_key}` first.", color=discord.Color.red()) - await ctx.send(embed=embed, ephemeral=True) - return None # Indicate failure + embed = discord.Embed( + description=f"❌ You need to be a {JOB_DEFINITIONS[job_key]['name']} to use this command. You don't have a job. Use `/choosejob {job_key}` first.", + color=discord.Color.red(), + ) + await ctx.send(embed=embed, ephemeral=True) + return None # Indicate failure job_details = JOB_DEFINITIONS[job_key] level = job_info["level"] # 2. Check Cooldown last_action = job_info.get("last_action") - cooldown = job_details['cooldown'] + cooldown = job_details["cooldown"] if last_action: now_utc = datetime.datetime.now(datetime.timezone.utc) time_since = now_utc - last_action if time_since < cooldown: time_left = cooldown - time_since - embed = discord.Embed(description=f"🕒 You need to wait **{format_timedelta(time_left)}** before you can {job_key} again.", color=discord.Color.orange()) + embed = discord.Embed( + description=f"🕒 You need to wait **{format_timedelta(time_left)}** before you can {job_key} again.", + color=discord.Color.orange(), + ) await ctx.send(embed=embed, ephemeral=True) - return None # Indicate failure + return None # Indicate failure # 3. Set Cooldown Immediately await database.set_job_cooldown(user_id) @@ -251,7 +292,9 @@ class JobsCommands(commands.Cog): level_bonus = job_details.get("level_bonus", {}) currency_bonus = level * level_bonus.get("currency_increase", 0) min_curr, max_curr = job_details["base_currency"] - currency_earned = random.randint(int(min_curr + currency_bonus), int(max_curr + currency_bonus)) + currency_earned = random.randint( + int(min_curr + currency_bonus), int(max_curr + currency_bonus) + ) items_found = {} if "drops" in job_details: @@ -259,10 +302,12 @@ class JobsCommands(commands.Cog): for item_key, base_chance in job_details["drops"].items(): # Apply level bonus to specific rare items if configured (e.g., gems for miner) current_chance = base_chance - if item_key == 'shiny_gem' and job_key == 'miner': + if item_key == "shiny_gem" and job_key == "miner": + current_chance += rare_find_bonus + elif ( + item_key == "rare_fish" or item_key == "treasure_chest" + ) and job_key == "fisher": current_chance += rare_find_bonus - elif (item_key == 'rare_fish' or item_key == 'treasure_chest') and job_key == 'fisher': - current_chance += rare_find_bonus if random.random() < current_chance: items_found[item_key] = items_found.get(item_key, 0) + 1 @@ -274,7 +319,7 @@ class JobsCommands(commands.Cog): await database.add_item_to_inventory(user_id, item_key, quantity) # 6. Grant XP & Handle Level Up - xp_earned = job_details["base_xp"] # Could add level bonus to XP later + xp_earned = job_details["base_xp"] # Could add level bonus to XP later new_level, new_xp, did_level_up = await database.add_job_xp(user_id, xp_earned) # 7. Construct Response Message @@ -284,21 +329,23 @@ class JobsCommands(commands.Cog): if items_found: item_strings = [] for item_key, quantity in items_found.items(): - item_details = await database.get_item_details(item_key) - item_name = item_details['name'] if item_details else item_key - item_strings.append(f"{quantity}x **{item_name}**") + item_details = await database.get_item_details(item_key) + item_name = item_details["name"] if item_details else item_key + item_strings.append(f"{quantity}x **{item_name}**") response_parts.append(f"found {', '.join(item_strings)}") response_parts.append(f"gained **{xp_earned} XP**") - action_verb = job_key.capitalize() # "Mine", "Fish" - message = f"⛏️ You {action_verb} and {', '.join(response_parts)}." # Default message + action_verb = job_key.capitalize() # "Mine", "Fish" + message = ( + f"⛏️ You {action_verb} and {', '.join(response_parts)}." # Default message + ) # Customize message based on job if job_key == "miner": - message = f"⛏️ You mined and {', '.join(response_parts)}." + message = f"⛏️ You mined and {', '.join(response_parts)}." elif job_key == "fisher": - message = f"🎣 You fished and {', '.join(response_parts)}." + message = f"🎣 You fished and {', '.join(response_parts)}." # Crafter handled separately if did_level_up: @@ -307,26 +354,40 @@ class JobsCommands(commands.Cog): current_balance = await database.get_balance(user_id) message += f"\nYour current balance is **${current_balance:,}**." - return message # Indicate success and return message + return message # Indicate success and return message - @commands.hybrid_command(name="mine", description="Mine for ores and gems (Miner job).") + @commands.hybrid_command( + name="mine", description="Mine for ores and gems (Miner job)." + ) async def mine(self, ctx: commands.Context): """Performs the Miner job action.""" result_message = await self._handle_job_action(ctx, "miner") if result_message: - embed = discord.Embed(title="Mining Results", description=result_message, color=discord.Color.dark_grey()) + embed = discord.Embed( + title="Mining Results", + description=result_message, + color=discord.Color.dark_grey(), + ) await ctx.send(embed=embed) - @commands.hybrid_command(name="fish", description="Catch fish and maybe find treasure (Fisher job).") + @commands.hybrid_command( + name="fish", description="Catch fish and maybe find treasure (Fisher job)." + ) async def fish(self, ctx: commands.Context): """Performs the Fisher job action.""" result_message = await self._handle_job_action(ctx, "fisher") if result_message: - embed = discord.Embed(title="Fishing Results", description=result_message, color=discord.Color.blue()) + embed = discord.Embed( + title="Fishing Results", + description=result_message, + color=discord.Color.blue(), + ) await ctx.send(embed=embed) # --- Crafter Specific --- - async def craft_autocomplete(self, interaction: discord.Interaction, current: str) -> List[app_commands.Choice[str]]: + async def craft_autocomplete( + self, interaction: discord.Interaction, current: str + ) -> List[app_commands.Choice[str]]: user_id = interaction.user.id job_info = await database.get_user_job(user_id) choices = [] @@ -335,17 +396,26 @@ class JobsCommands(commands.Cog): level = job_info["level"] for item_key, recipe in crafter_details.get("recipes", {}).items(): # Check level requirement - required_level = crafter_details.get("level_bonus", {}).get("unlock_recipe_level", {}).get(item_key, 1) + required_level = ( + crafter_details.get("level_bonus", {}) + .get("unlock_recipe_level", {}) + .get(item_key, 1) + ) if level < required_level: continue item_details = await database.get_item_details(item_key) - item_name = item_details['name'] if item_details else item_key - if current.lower() in item_key.lower() or current.lower() in item_name.lower(): - choices.append(app_commands.Choice(name=item_name, value=item_key)) + item_name = item_details["name"] if item_details else item_key + if ( + current.lower() in item_key.lower() + or current.lower() in item_name.lower() + ): + choices.append(app_commands.Choice(name=item_name, value=item_key)) return choices[:25] - @commands.hybrid_command(name="craft", description="Craft items using materials (Crafter job).") + @commands.hybrid_command( + name="craft", description="Craft items using materials (Crafter job)." + ) @app_commands.autocomplete(item_to_craft=craft_autocomplete) async def craft(self, ctx: commands.Context, item_to_craft: str): """Performs the Crafter job action.""" @@ -355,7 +425,10 @@ class JobsCommands(commands.Cog): # 1. Check if user has the correct job if not job_info or job_info.get("name") != job_key: - embed = discord.Embed(description="❌ You need to be a Crafter to use this command. Use `/choosejob crafter` first.", color=discord.Color.red()) + embed = discord.Embed( + description="❌ You need to be a Crafter to use this command. Use `/choosejob crafter` first.", + color=discord.Color.red(), + ) await ctx.send(embed=embed, ephemeral=True) return @@ -366,47 +439,62 @@ class JobsCommands(commands.Cog): # 2. Check if recipe exists recipes = job_details.get("recipes", {}) if recipe_key not in recipes: - embed = discord.Embed(description=f"❌ Unknown recipe: '{item_to_craft}'. Check available recipes.", color=discord.Color.red()) # TODO: Add /recipes command? + embed = discord.Embed( + description=f"❌ Unknown recipe: '{item_to_craft}'. Check available recipes.", + color=discord.Color.red(), + ) # TODO: Add /recipes command? await ctx.send(embed=embed, ephemeral=True) return # 3. Check Level Requirement - required_level = job_details.get("level_bonus", {}).get("unlock_recipe_level", {}).get(recipe_key, 1) + required_level = ( + job_details.get("level_bonus", {}) + .get("unlock_recipe_level", {}) + .get(recipe_key, 1) + ) if level < required_level: - embed = discord.Embed(description=f"❌ You need to be Level {required_level} to craft this item. You are currently Level {level}.", color=discord.Color.red()) - await ctx.send(embed=embed, ephemeral=True) - return + embed = discord.Embed( + description=f"❌ You need to be Level {required_level} to craft this item. You are currently Level {level}.", + color=discord.Color.red(), + ) + await ctx.send(embed=embed, ephemeral=True) + return # 4. Check Cooldown last_action = job_info.get("last_action") - cooldown = job_details['cooldown'] + cooldown = job_details["cooldown"] if last_action: now_utc = datetime.datetime.now(datetime.timezone.utc) time_since = now_utc - last_action if time_since < cooldown: time_left = cooldown - time_since - embed = discord.Embed(description=f"🕒 You need to wait **{format_timedelta(time_left)}** before you can craft again.", color=discord.Color.orange()) + embed = discord.Embed( + description=f"🕒 You need to wait **{format_timedelta(time_left)}** before you can craft again.", + color=discord.Color.orange(), + ) await ctx.send(embed=embed, ephemeral=True) return # 5. Check Materials required_materials = recipes[recipe_key] inventory = await database.get_inventory(user_id) - inventory_map = {item['key']: item['quantity'] for item in inventory} + inventory_map = {item["key"]: item["quantity"] for item in inventory} missing_materials = [] can_craft = True for mat_key, mat_qty in required_materials.items(): if inventory_map.get(mat_key, 0) < mat_qty: can_craft = False mat_details = await database.get_item_details(mat_key) - mat_name = mat_details['name'] if mat_details else mat_key - missing_materials.append(f"{mat_qty - inventory_map.get(mat_key, 0)}x {mat_name}") + mat_name = mat_details["name"] if mat_details else mat_key + missing_materials.append( + f"{mat_qty - inventory_map.get(mat_key, 0)}x {mat_name}" + ) if not can_craft: embed = discord.Embed( title="Missing Materials", description=f"❌ You don't have the required materials. You still need: {', '.join(missing_materials)}.", - color=discord.Color.red() + color=discord.Color.red(), ) await ctx.send(embed=embed, ephemeral=True) return @@ -419,8 +507,13 @@ class JobsCommands(commands.Cog): for mat_key, mat_qty in required_materials.items(): if not await database.remove_item_from_inventory(user_id, mat_key, mat_qty): success = False - log.error(f"Failed to remove material {mat_key} x{mat_qty} for user {user_id} during crafting, despite check.") - embed = discord.Embed(description="❌ An error occurred while consuming materials. Please try again.", color=discord.Color.red()) + log.error( + f"Failed to remove material {mat_key} x{mat_qty} for user {user_id} during crafting, despite check." + ) + embed = discord.Embed( + description="❌ An error occurred while consuming materials. Please try again.", + color=discord.Color.red(), + ) await ctx.send(embed=embed, ephemeral=True) # Should ideally revert cooldown here, but that's complex. return @@ -430,80 +523,120 @@ class JobsCommands(commands.Cog): # 8. Grant XP & Handle Level Up xp_earned = job_details["base_xp"] - new_level, new_xp, did_level_up = await database.add_job_xp(user_id, xp_earned) + new_level, new_xp, did_level_up = await database.add_job_xp( + user_id, xp_earned + ) # 9. Construct Response crafted_item_details = await database.get_item_details(recipe_key) crafted_item_details = await database.get_item_details(recipe_key) - crafted_item_name = crafted_item_details['name'] if crafted_item_details else recipe_key + crafted_item_name = ( + crafted_item_details["name"] if crafted_item_details else recipe_key + ) embed = discord.Embed( title="Crafting Successful!", description=f"🛠️ You successfully crafted 1x **{crafted_item_name}** and gained **{xp_earned} XP**.", - color=discord.Color.purple() # Use a different color for crafting + color=discord.Color.purple(), # Use a different color for crafting ) if did_level_up: - embed.add_field(name="Level Up!", value=f"**Congratulations! You reached Level {new_level} in {job_details['name']}!** 🎉", inline=False) + embed.add_field( + name="Level Up!", + value=f"**Congratulations! You reached Level {new_level} in {job_details['name']}!** 🎉", + inline=False, + ) await ctx.send(embed=embed) - # --- Inventory Commands --- - @commands.hybrid_command(name="inventory", aliases=["inv"], description="View your items.") + @commands.hybrid_command( + name="inventory", aliases=["inv"], description="View your items." + ) async def inventory(self, ctx: commands.Context): """Displays the items in the user's inventory.""" user_id = ctx.author.id inventory_items = await database.get_inventory(user_id) if not inventory_items: - embed = discord.Embed(description="🗑️ Your inventory is empty.", color=discord.Color.orange()) + embed = discord.Embed( + description="🗑️ Your inventory is empty.", color=discord.Color.orange() + ) await ctx.send(embed=embed, ephemeral=True) return - embed = discord.Embed(title=f"{ctx.author.display_name}'s Inventory 🎒", color=discord.Color.orange()) + embed = discord.Embed( + title=f"{ctx.author.display_name}'s Inventory 🎒", + color=discord.Color.orange(), + ) description = "" for item in inventory_items: - sell_info = f" (Sell: ${item['sell_price']:,})" if item['sell_price'] > 0 else "" + sell_info = ( + f" (Sell: ${item['sell_price']:,})" if item["sell_price"] > 0 else "" + ) description += f"- **{item['name']}** x{item['quantity']}{sell_info}\n" - if item['description']: - description += f" *({item['description']})*\n" # Add description if available + if item["description"]: + description += ( + f" *({item['description']})*\n" # Add description if available + ) # Handle potential description length limit - if len(description) > 4000: # Embed description limit is 4096 - description = description[:4000] + "\n... (Inventory too large to display fully)" + if len(description) > 4000: # Embed description limit is 4096 + description = ( + description[:4000] + "\n... (Inventory too large to display fully)" + ) embed.description = description await ctx.send(embed=embed) # Autocomplete for sell command - async def inventory_autocomplete(self, interaction: discord.Interaction, current: str) -> List[app_commands.Choice[str]]: + async def inventory_autocomplete( + self, interaction: discord.Interaction, current: str + ) -> List[app_commands.Choice[str]]: user_id = interaction.user.id inventory = await database.get_inventory(user_id) return [ - app_commands.Choice(name=f"{item['name']} (Have: {item['quantity']})", value=item['key']) - for item in inventory if item['sell_price'] > 0 and (current.lower() in item['key'].lower() or current.lower() in item['name'].lower()) + app_commands.Choice( + name=f"{item['name']} (Have: {item['quantity']})", value=item["key"] + ) + for item in inventory + if item["sell_price"] > 0 + and ( + current.lower() in item["key"].lower() + or current.lower() in item["name"].lower() + ) ][:25] @commands.hybrid_command(name="sell", description="Sell items from your inventory.") @app_commands.autocomplete(item_key=inventory_autocomplete) - async def sell(self, ctx: commands.Context, item_key: str, quantity: Optional[int] = 1): + async def sell( + self, ctx: commands.Context, item_key: str, quantity: Optional[int] = 1 + ): """Sells a specified quantity of an item from the inventory.""" user_id = ctx.author.id if quantity <= 0: - embed = discord.Embed(description="❌ Please enter a positive quantity to sell.", color=discord.Color.red()) + embed = discord.Embed( + description="❌ Please enter a positive quantity to sell.", + color=discord.Color.red(), + ) await ctx.send(embed=embed, ephemeral=True) return item_details = await database.get_item_details(item_key) if not item_details: - embed = discord.Embed(description=f"❌ Invalid item key '{item_key}'. Check your `/inventory`.", color=discord.Color.red()) + embed = discord.Embed( + description=f"❌ Invalid item key '{item_key}'. Check your `/inventory`.", + color=discord.Color.red(), + ) await ctx.send(embed=embed, ephemeral=True) return - if item_details['sell_price'] <= 0: - embed = discord.Embed(description=f"❌ You cannot sell **{item_details['name']}**.", color=discord.Color.red()) + if item_details["sell_price"] <= 0: + embed = discord.Embed( + description=f"❌ You cannot sell **{item_details['name']}**.", + color=discord.Color.red(), + ) await ctx.send(embed=embed, ephemeral=True) return @@ -515,22 +648,25 @@ class JobsCommands(commands.Cog): inventory = await database.get_inventory(user_id) current_quantity = 0 for item in inventory: - if item['key'] == item_key: - current_quantity = item['quantity'] + if item["key"] == item_key: + current_quantity = item["quantity"] break - embed = discord.Embed(description=f"❌ You don't have {quantity}x **{item_details['name']}** to sell. You only have {current_quantity}.", color=discord.Color.red()) + embed = discord.Embed( + description=f"❌ You don't have {quantity}x **{item_details['name']}** to sell. You only have {current_quantity}.", + color=discord.Color.red(), + ) await ctx.send(embed=embed, ephemeral=True) return # Grant money if removal was successful - total_earnings = item_details['sell_price'] * quantity + total_earnings = item_details["sell_price"] * quantity await database.update_balance(user_id, total_earnings) current_balance = await database.get_balance(user_id) embed = discord.Embed( title="Item Sold!", description=f"💰 You sold {quantity}x **{item_details['name']}** for **${total_earnings:,}**.", - color=discord.Color.green() + color=discord.Color.green(), ) embed.add_field(name="New Balance", value=f"${current_balance:,}", inline=False) await ctx.send(embed=embed) diff --git a/cogs/economy/risky.py b/cogs/economy/risky.py index ae32551..a961c15 100644 --- a/cogs/economy/risky.py +++ b/cogs/economy/risky.py @@ -10,27 +10,32 @@ from . import database log = logging.getLogger(__name__) + class RiskyCommands(commands.Cog): """Cog containing risky economy commands like robbing.""" def __init__(self, bot: commands.Bot): self.bot = bot - @commands.hybrid_command(name="rob", description="Attempt to rob another user (risky!).") + @commands.hybrid_command( + name="rob", description="Attempt to rob another user (risky!)." + ) async def rob(self, ctx: commands.Context, target: discord.User): """Attempts to steal money from another user.""" robber_id = ctx.author.id target_id = target.id command_name = "rob" - cooldown_duration = datetime.timedelta(hours=6) # 6-hour cooldown - success_chance = 0.30 # 30% base chance of success - min_target_balance = 100 # Target must have at least this much to be robbed - fine_multiplier = 0.5 # Fine is 50% of what you tried to steal if caught - steal_percentage_min = 0.05 # Steal between 5% - steal_percentage_max = 0.20 # and 20% of target's balance + cooldown_duration = datetime.timedelta(hours=6) # 6-hour cooldown + success_chance = 0.30 # 30% base chance of success + min_target_balance = 100 # Target must have at least this much to be robbed + fine_multiplier = 0.5 # Fine is 50% of what you tried to steal if caught + steal_percentage_min = 0.05 # Steal between 5% + steal_percentage_max = 0.20 # and 20% of target's balance if robber_id == target_id: - embed = discord.Embed(description="❌ You can't rob yourself!", color=discord.Color.red()) + embed = discord.Embed( + description="❌ You can't rob yourself!", color=discord.Color.red() + ) await ctx.send(embed=embed, ephemeral=True) return @@ -46,14 +51,20 @@ class RiskyCommands(commands.Cog): time_left = cooldown_duration - time_since_last_used hours, remainder = divmod(int(time_left.total_seconds()), 3600) minutes, seconds = divmod(remainder, 60) - embed = discord.Embed(description=f"🕒 You need to lay low after your last attempt. Try again in **{hours}h {minutes}m {seconds}s**.", color=discord.Color.orange()) + embed = discord.Embed( + description=f"🕒 You need to lay low after your last attempt. Try again in **{hours}h {minutes}m {seconds}s**.", + color=discord.Color.orange(), + ) await ctx.send(embed=embed, ephemeral=True) return # Check target balance target_balance = await database.get_balance(target_id) if target_balance < min_target_balance: - embed = discord.Embed(description=f"❌ {target.display_name} doesn't have enough money to be worth robbing (minimum ${min_target_balance:,}).", color=discord.Color.orange()) + embed = discord.Embed( + description=f"❌ {target.display_name} doesn't have enough money to be worth robbing (minimum ${min_target_balance:,}).", + color=discord.Color.orange(), + ) await ctx.send(embed=embed, ephemeral=True) # Don't apply cooldown if target wasn't viable return @@ -67,10 +78,14 @@ class RiskyCommands(commands.Cog): # Determine success if random.random() < success_chance: # Success! - steal_percentage = random.uniform(steal_percentage_min, steal_percentage_max) + steal_percentage = random.uniform( + steal_percentage_min, steal_percentage_max + ) stolen_amount = int(target_balance * steal_percentage) - if stolen_amount <= 0: # Ensure at least 1 is stolen if percentage is too low + if ( + stolen_amount <= 0 + ): # Ensure at least 1 is stolen if percentage is too low stolen_amount = 1 await database.update_balance(robber_id, stolen_amount) @@ -79,24 +94,31 @@ class RiskyCommands(commands.Cog): embed_success = discord.Embed( title="Robbery Successful!", description=f"🚨 Success! You skillfully robbed **${stolen_amount:,}** from {target.mention}!", - color=discord.Color.green() + color=discord.Color.green(), + ) + embed_success.add_field( + name="Your New Balance", + value=f"${current_robber_balance:,}", + inline=False, ) - embed_success.add_field(name="Your New Balance", value=f"${current_robber_balance:,}", inline=False) await ctx.send(embed=embed_success) try: embed_target = discord.Embed( title="You've Been Robbed!", description=f"🚨 Oh no! {ctx.author.mention} robbed you for **${stolen_amount:,}**!", - color=discord.Color.red() + color=discord.Color.red(), ) await target.send(embed=embed_target) except discord.Forbidden: - pass # Ignore if DMs are closed + pass # Ignore if DMs are closed else: # Failure! Calculate potential fine # Fine based on what they *could* have stolen (using average percentage for calculation) - potential_steal_amount = int(target_balance * ((steal_percentage_min + steal_percentage_max) / 2)) - if potential_steal_amount <= 0: potential_steal_amount = 1 + potential_steal_amount = int( + target_balance * ((steal_percentage_min + steal_percentage_max) / 2) + ) + if potential_steal_amount <= 0: + potential_steal_amount = 1 fine_amount = int(potential_steal_amount * fine_multiplier) # Ensure fine doesn't exceed robber's balance @@ -110,17 +132,22 @@ class RiskyCommands(commands.Cog): embed_fail = discord.Embed( title="Robbery Failed!", description=f"👮‍♂️ You were caught trying to rob {target.mention}! You paid a fine of **${fine_amount:,}**.", - color=discord.Color.red() + color=discord.Color.red(), + ) + embed_fail.add_field( + name="Your New Balance", + value=f"${current_robber_balance:,}", + inline=False, ) - embed_fail.add_field(name="Your New Balance", value=f"${current_robber_balance:,}", inline=False) await ctx.send(embed=embed_fail) else: # Robber is broke, can't pay fine - embed_fail_broke = discord.Embed( - title="Robbery Failed!", - description=f"👮‍♂️ You were caught trying to rob {target.mention}, but you're too broke to pay the fine!", - color=discord.Color.red() - ) - await ctx.send(embed=embed_fail_broke) + embed_fail_broke = discord.Embed( + title="Robbery Failed!", + description=f"👮‍♂️ You were caught trying to rob {target.mention}, but you're too broke to pay the fine!", + color=discord.Color.red(), + ) + await ctx.send(embed=embed_fail_broke) + # No setup function needed here diff --git a/cogs/economy/utility.py b/cogs/economy/utility.py index 9bb24ad..c4bff0b 100644 --- a/cogs/economy/utility.py +++ b/cogs/economy/utility.py @@ -9,14 +9,17 @@ from . import database log = logging.getLogger(__name__) + class UtilityCommands(commands.Cog): """Cog containing utility-related economy commands.""" def __init__(self, bot: commands.Bot): self.bot = bot - @commands.hybrid_command(name="balance", description="Check your or another user's balance.") - @commands.cooldown(1, 5, commands.BucketType.user) # Basic discord.py cooldown + @commands.hybrid_command( + name="balance", description="Check your or another user's balance." + ) + @commands.cooldown(1, 5, commands.BucketType.user) # Basic discord.py cooldown async def balance(self, ctx: commands.Context, user: Optional[discord.User] = None): """Displays the economy balance for a user.""" target_user = user or ctx.author @@ -24,39 +27,50 @@ class UtilityCommands(commands.Cog): embed = discord.Embed( title=f"{target_user.display_name}'s Balance", description=f"💰 **${balance_amount:,}**", - color=discord.Color.blue() + color=discord.Color.blue(), ) await ctx.send(embed=embed, ephemeral=True) - @commands.hybrid_command(name="moneylb", aliases=["mlb", "mtop"], description="Show the richest users by money.") # Renamed to avoid conflict - @commands.cooldown(1, 30, commands.BucketType.user) # Prevent spam - async def moneylb(self, ctx: commands.Context, count: int = 10): # Renamed function + @commands.hybrid_command( + name="moneylb", + aliases=["mlb", "mtop"], + description="Show the richest users by money.", + ) # Renamed to avoid conflict + @commands.cooldown(1, 30, commands.BucketType.user) # Prevent spam + async def moneylb(self, ctx: commands.Context, count: int = 10): # Renamed function """Displays the top users by balance.""" if not 1 <= count <= 25: - embed = discord.Embed(description="❌ Please provide a count between 1 and 25.", color=discord.Color.red()) + embed = discord.Embed( + description="❌ Please provide a count between 1 and 25.", + color=discord.Color.red(), + ) await ctx.send(embed=embed, ephemeral=True) return results = await database.get_leaderboard(count) if not results: - embed = discord.Embed(description="📊 The leaderboard is empty!", color=discord.Color.orange()) + embed = discord.Embed( + description="📊 The leaderboard is empty!", color=discord.Color.orange() + ) await ctx.send(embed=embed, ephemeral=True) return - embed = discord.Embed(title="💰 Economy Leaderboard", color=discord.Color.gold()) + embed = discord.Embed( + title="💰 Economy Leaderboard", color=discord.Color.gold() + ) description = "" rank = 1 for user_id, balance in results: - user = self.bot.get_user(user_id) # Try to get user object for display name + user = self.bot.get_user(user_id) # Try to get user object for display name # Fetch user if not in cache - might be slow for large leaderboards if user is None: try: user = await self.bot.fetch_user(user_id) except discord.NotFound: - user = None # User might have left all shared servers + user = None # User might have left all shared servers except discord.HTTPException: - user = None # Other Discord API error + user = None # Other Discord API error log.warning(f"Failed to fetch user {user_id} for leaderboard.") user_name = user.display_name if user else f"User ID: {user_id}" @@ -66,7 +80,6 @@ class UtilityCommands(commands.Cog): embed.description = description await ctx.send(embed=embed) - @commands.hybrid_command(name="pay", description="Transfer money to another user.") async def pay(self, ctx: commands.Context, recipient: discord.User, amount: int): """Transfers currency from the command author to another user.""" @@ -74,43 +87,58 @@ class UtilityCommands(commands.Cog): recipient_id = recipient.id if sender_id == recipient_id: - embed = discord.Embed(description="❌ You cannot pay yourself!", color=discord.Color.red()) + embed = discord.Embed( + description="❌ You cannot pay yourself!", color=discord.Color.red() + ) await ctx.send(embed=embed, ephemeral=True) return if amount <= 0: - embed = discord.Embed(description="❌ Please enter a positive amount to pay.", color=discord.Color.red()) + embed = discord.Embed( + description="❌ Please enter a positive amount to pay.", + color=discord.Color.red(), + ) await ctx.send(embed=embed, ephemeral=True) return sender_balance = await database.get_balance(sender_id) if sender_balance < amount: - embed = discord.Embed(description=f"❌ You don't have enough money! Your balance is **${sender_balance:,}**.", color=discord.Color.red()) + embed = discord.Embed( + description=f"❌ You don't have enough money! Your balance is **${sender_balance:,}**.", + color=discord.Color.red(), + ) await ctx.send(embed=embed, ephemeral=True) return # Perform the transfer - await database.update_balance(sender_id, -amount) # Decrease sender's balance - await database.update_balance(recipient_id, amount) # Increase recipient's balance + await database.update_balance(sender_id, -amount) # Decrease sender's balance + await database.update_balance( + recipient_id, amount + ) # Increase recipient's balance current_sender_balance = await database.get_balance(sender_id) embed_sender = discord.Embed( title="Payment Successful!", description=f"💸 You successfully paid **${amount:,}** to {recipient.mention}.", - color=discord.Color.green() + color=discord.Color.green(), + ) + embed_sender.add_field( + name="Your New Balance", value=f"${current_sender_balance:,}", inline=False ) - embed_sender.add_field(name="Your New Balance", value=f"${current_sender_balance:,}", inline=False) await ctx.send(embed=embed_sender) try: # Optionally DM the recipient embed_recipient = discord.Embed( title="You Received a Payment!", description=f"💸 You received **${amount:,}** from {ctx.author.mention}!", - color=discord.Color.green() + color=discord.Color.green(), ) await recipient.send(embed=embed_recipient) except discord.Forbidden: - log.warning(f"Could not DM recipient {recipient_id} about payment.") # User might have DMs closed + log.warning( + f"Could not DM recipient {recipient_id} about payment." + ) # User might have DMs closed + # No setup function needed here diff --git a/cogs/economy_cog.py b/cogs/economy_cog.py index 121bd59..2fc93ea 100644 --- a/cogs/economy_cog.py +++ b/cogs/economy_cog.py @@ -8,13 +8,27 @@ import datetime from typing import Optional # Import command classes and db functions from submodules -from .economy.database import init_db, close_db, get_balance, update_balance, set_cooldown, check_cooldown -from .economy.database import get_user_job, set_user_job, remove_user_job, get_available_jobs, get_leaderboard +from .economy.database import ( + init_db, + close_db, + get_balance, + update_balance, + set_cooldown, + check_cooldown, +) +from .economy.database import ( + get_user_job, + set_user_job, + remove_user_job, + get_available_jobs, + get_leaderboard, +) from .economy.earning import EarningCommands from .economy.gambling import GamblingCommands from .economy.utility import UtilityCommands from .economy.risky import RiskyCommands -from .economy.jobs import JobsCommands # Import the new JobsCommands +from .economy.jobs import JobsCommands # Import the new JobsCommands + # Create a database object for function calls class DatabaseWrapper: @@ -45,6 +59,7 @@ class DatabaseWrapper: async def get_leaderboard(self, limit=10): return await get_leaderboard(limit) + # Create an instance of the wrapper database = DatabaseWrapper() @@ -52,28 +67,30 @@ log = logging.getLogger(__name__) # --- Main Cog Implementation --- + # Inherit from commands.Cog and all the command classes class EconomyCog( EarningCommands, GamblingCommands, UtilityCommands, RiskyCommands, - JobsCommands, # Add JobsCommands to the inheritance list - commands.Cog # Ensure commands.Cog is included - ): + JobsCommands, # Add JobsCommands to the inheritance list + commands.Cog, # Ensure commands.Cog is included +): """Main cog for the economy system, combining all command groups.""" def __init__(self, bot: commands.Bot): # Initialize all parent cogs (important!) - super().__init__(bot) # Calls __init__ of the first parent in MRO (EarningCommands) + super().__init__( + bot + ) # Calls __init__ of the first parent in MRO (EarningCommands) # If other parent cogs had complex __init__, we might need to call them explicitly, # but in this case, they only store the bot instance, which super() handles. self.bot = bot # Create the main command group for this cog self.econ_group = app_commands.Group( - name="econ", - description="Economy system commands" + name="econ", description="Economy system commands" ) # Register commands @@ -93,7 +110,7 @@ class EconomyCog( name="daily", description="Claim your daily reward", callback=self.economy_daily_callback, - parent=self.econ_group + parent=self.econ_group, ) self.econ_group.add_command(daily_command) @@ -102,7 +119,7 @@ class EconomyCog( name="beg", description="Beg for some spare change", callback=self.economy_beg_callback, - parent=self.econ_group + parent=self.econ_group, ) self.econ_group.add_command(beg_command) @@ -111,7 +128,7 @@ class EconomyCog( name="work", description="Do some work for a guaranteed reward", callback=self.economy_work_callback, - parent=self.econ_group + parent=self.econ_group, ) self.econ_group.add_command(work_command) @@ -120,7 +137,7 @@ class EconomyCog( name="scavenge", description="Scavenge around for some spare change", callback=self.economy_scavenge_callback, - parent=self.econ_group + parent=self.econ_group, ) self.econ_group.add_command(scavenge_command) @@ -130,7 +147,7 @@ class EconomyCog( name="coinflip", description="Bet on a coin flip", callback=self.economy_coinflip_callback, - parent=self.econ_group + parent=self.econ_group, ) self.econ_group.add_command(coinflip_command) @@ -139,7 +156,7 @@ class EconomyCog( name="slots", description="Play the slot machine", callback=self.economy_slots_callback, - parent=self.econ_group + parent=self.econ_group, ) self.econ_group.add_command(slots_command) @@ -149,7 +166,7 @@ class EconomyCog( name="balance", description="Check your balance", callback=self.economy_balance_callback, - parent=self.econ_group + parent=self.econ_group, ) self.econ_group.add_command(balance_command) @@ -158,7 +175,7 @@ class EconomyCog( name="transfer", description="Transfer money to another user", callback=self.economy_transfer_callback, - parent=self.econ_group + parent=self.econ_group, ) self.econ_group.add_command(transfer_command) @@ -167,7 +184,7 @@ class EconomyCog( name="leaderboard", description="View the economy leaderboard", callback=self.economy_leaderboard_callback, - parent=self.econ_group + parent=self.econ_group, ) self.econ_group.add_command(leaderboard_command) @@ -177,7 +194,7 @@ class EconomyCog( name="rob", description="Attempt to rob another user", callback=self.economy_rob_callback, - parent=self.econ_group + parent=self.econ_group, ) self.econ_group.add_command(rob_command) @@ -187,7 +204,7 @@ class EconomyCog( name="apply", description="Apply for a job", callback=self.economy_apply_callback, - parent=self.econ_group + parent=self.econ_group, ) self.econ_group.add_command(apply_command) @@ -196,7 +213,7 @@ class EconomyCog( name="quit", description="Quit your current job", callback=self.economy_quit_callback, - parent=self.econ_group + parent=self.econ_group, ) self.econ_group.add_command(quit_command) @@ -205,7 +222,7 @@ class EconomyCog( name="joblist", description="List available jobs", callback=self.economy_joblist_callback, - parent=self.econ_group + parent=self.econ_group, ) self.econ_group.add_command(joblist_command) @@ -216,7 +233,10 @@ class EconomyCog( await init_db() log.info("EconomyCog database initialization complete.") except Exception as e: - log.error(f"EconomyCog failed to initialize database during load: {e}", exc_info=True) + log.error( + f"EconomyCog failed to initialize database during load: {e}", + exc_info=True, + ) # Prevent the cog from loading if DB init fails raise commands.ExtensionFailed(self.qualified_name, e) from e @@ -227,7 +247,7 @@ class EconomyCog( user_id = interaction.user.id command_name = "daily" cooldown_duration = datetime.timedelta(hours=24) - reward_amount = 100 # Example daily reward + reward_amount = 100 # Example daily reward last_used = await database.check_cooldown(user_id, command_name) @@ -242,7 +262,10 @@ class EconomyCog( time_left = cooldown_duration - time_since_last_used hours, remainder = divmod(int(time_left.total_seconds()), 3600) minutes, seconds = divmod(remainder, 60) - embed = discord.Embed(description=f"🕒 You've already claimed your daily reward. Try again in **{hours}h {minutes}m {seconds}s**.", color=discord.Color.orange()) + embed = discord.Embed( + description=f"🕒 You've already claimed your daily reward. Try again in **{hours}h {minutes}m {seconds}s**.", + color=discord.Color.orange(), + ) await interaction.response.send_message(embed=embed, ephemeral=True) return @@ -253,7 +276,7 @@ class EconomyCog( embed = discord.Embed( title="Daily Reward Claimed!", description=f"🎉 You claimed your daily reward of **${reward_amount:,}**!", - color=discord.Color.green() + color=discord.Color.green(), ) embed.add_field(name="New Balance", value=f"${current_balance:,}", inline=False) await interaction.response.send_message(embed=embed) @@ -262,8 +285,8 @@ class EconomyCog( """Callback for /economy earning beg command""" user_id = interaction.user.id command_name = "beg" - cooldown_duration = datetime.timedelta(minutes=5) # 5-minute cooldown - success_chance = 0.4 # 40% chance of success + cooldown_duration = datetime.timedelta(minutes=5) # 5-minute cooldown + success_chance = 0.4 # 40% chance of success min_reward = 1 max_reward = 20 @@ -278,7 +301,10 @@ class EconomyCog( if time_since_last_used < cooldown_duration: time_left = cooldown_duration - time_since_last_used minutes, seconds = divmod(int(time_left.total_seconds()), 60) - embed = discord.Embed(description=f"🕒 You can't beg again so soon. Try again in **{minutes}m {seconds}s**.", color=discord.Color.orange()) + embed = discord.Embed( + description=f"🕒 You can't beg again so soon. Try again in **{minutes}m {seconds}s**.", + color=discord.Color.orange(), + ) await interaction.response.send_message(embed=embed, ephemeral=True) return @@ -293,15 +319,17 @@ class EconomyCog( embed = discord.Embed( title="Begging Successful!", description=f"🙏 Someone took pity on you! You received **${reward_amount:,}**.", - color=discord.Color.green() + color=discord.Color.green(), + ) + embed.add_field( + name="New Balance", value=f"${current_balance:,}", inline=False ) - embed.add_field(name="New Balance", value=f"${current_balance:,}", inline=False) await interaction.response.send_message(embed=embed) else: embed = discord.Embed( title="Begging Failed", description="🤷 Nobody gave you anything. Better luck next time!", - color=discord.Color.red() + color=discord.Color.red(), ) await interaction.response.send_message(embed=embed) @@ -309,15 +337,20 @@ class EconomyCog( """Callback for /economy earning work command""" user_id = interaction.user.id command_name = "work" - cooldown_duration = datetime.timedelta(hours=1) # 1-hour cooldown - reward_amount = random.randint(15, 35) # Small reward range - This is now fallback if no job + cooldown_duration = datetime.timedelta(hours=1) # 1-hour cooldown + reward_amount = random.randint( + 15, 35 + ) # Small reward range - This is now fallback if no job # --- Check if user has a job --- job_info = await database.get_user_job(user_id) if job_info and job_info.get("name"): job_key = job_info["name"] - command_to_use = f"`/economy jobs {job_key}`" # Updated command path - embed = discord.Embed(description=f"💼 You have a job! Use {command_to_use} instead of the generic work command.", color=discord.Color.blue()) + command_to_use = f"`/economy jobs {job_key}`" # Updated command path + embed = discord.Embed( + description=f"💼 You have a job! Use {command_to_use} instead of the generic work command.", + color=discord.Color.blue(), + ) await interaction.response.send_message(embed=embed, ephemeral=True) return # --- End Job Check --- @@ -335,7 +368,10 @@ class EconomyCog( time_left = cooldown_duration - time_since_last_used hours, remainder = divmod(int(time_left.total_seconds()), 3600) minutes, seconds = divmod(remainder, 60) - embed = discord.Embed(description=f"🕒 You need to rest after working. Try again in **{hours}h {minutes}m {seconds}s**.", color=discord.Color.orange()) + embed = discord.Embed( + description=f"🕒 You need to rest after working. Try again in **{hours}h {minutes}m {seconds}s**.", + color=discord.Color.orange(), + ) await interaction.response.send_message(embed=embed, ephemeral=True) return @@ -352,7 +388,7 @@ class EconomyCog( embed = discord.Embed( title="Work Complete!", description=random.choice(work_messages), - color=discord.Color.green() + color=discord.Color.green(), ) embed.add_field(name="New Balance", value=f"${current_balance:,}", inline=False) await interaction.response.send_message(embed=embed) @@ -361,8 +397,8 @@ class EconomyCog( """Callback for /economy earning scavenge command""" user_id = interaction.user.id command_name = "scavenge" - cooldown_duration = datetime.timedelta(minutes=30) # 30-minute cooldown - success_chance = 0.25 # 25% chance to find something + cooldown_duration = datetime.timedelta(minutes=30) # 30-minute cooldown + success_chance = 0.25 # 25% chance to find something min_reward = 1 max_reward = 10 @@ -377,7 +413,10 @@ class EconomyCog( if time_since_last_used < cooldown_duration: time_left = cooldown_duration - time_since_last_used minutes, seconds = divmod(int(time_left.total_seconds()), 60) - embed = discord.Embed(description=f"🕒 You've searched recently. Try again in **{minutes}m {seconds}s**.", color=discord.Color.orange()) + embed = discord.Embed( + description=f"🕒 You've searched recently. Try again in **{minutes}m {seconds}s**.", + color=discord.Color.orange(), + ) await interaction.response.send_message(embed=embed, ephemeral=True) return @@ -386,8 +425,12 @@ class EconomyCog( # Flavor text for scavenging scavenge_locations = [ - "under the sofa cushions", "in an old coat pocket", "behind the dumpster", - "in a dusty corner", "on the sidewalk", "in a forgotten drawer" + "under the sofa cushions", + "in an old coat pocket", + "behind the dumpster", + "in a dusty corner", + "on the sidewalk", + "in a forgotten drawer", ] location = random.choice(scavenge_locations) @@ -398,32 +441,44 @@ class EconomyCog( embed = discord.Embed( title="Scavenging Successful!", description=f"🔍 You scavenged {location} and found **${reward_amount:,}**!", - color=discord.Color.green() + color=discord.Color.green(), + ) + embed.add_field( + name="New Balance", value=f"${current_balance:,}", inline=False ) - embed.add_field(name="New Balance", value=f"${current_balance:,}", inline=False) await interaction.response.send_message(embed=embed) else: embed = discord.Embed( title="Scavenging Failed", description=f"🔍 You scavenged {location} but found nothing but lint.", - color=discord.Color.red() + color=discord.Color.red(), ) await interaction.response.send_message(embed=embed) # Gambling group callbacks - async def economy_coinflip_callback(self, interaction: discord.Interaction, bet: int, choice: app_commands.Choice[str]): + async def economy_coinflip_callback( + self, + interaction: discord.Interaction, + bet: int, + choice: app_commands.Choice[str], + ): """Callback for /economy gambling coinflip command""" user_id = interaction.user.id # Validate bet amount if bet <= 0: - await interaction.response.send_message("❌ Your bet must be greater than 0.", ephemeral=True) + await interaction.response.send_message( + "❌ Your bet must be greater than 0.", ephemeral=True + ) return # Check if user has enough money balance = await database.get_balance(user_id) if bet > balance: - await interaction.response.send_message(f"❌ You don't have enough money. Your balance: ${balance:,}", ephemeral=True) + await interaction.response.send_message( + f"❌ You don't have enough money. Your balance: ${balance:,}", + ephemeral=True, + ) return # Process the bet @@ -439,7 +494,7 @@ class EconomyCog( embed = discord.Embed( title="Coinflip Win!", description=f"The coin landed on **{result}**! You won **${winnings:,}**!", - color=discord.Color.green() + color=discord.Color.green(), ) embed.add_field(name="New Balance", value=f"${new_balance:,}", inline=False) else: @@ -449,7 +504,7 @@ class EconomyCog( embed = discord.Embed( title="Coinflip Loss", description=f"The coin landed on **{result}**. You lost **${bet:,}**.", - color=discord.Color.red() + color=discord.Color.red(), ) embed.add_field(name="New Balance", value=f"${new_balance:,}", inline=False) @@ -461,25 +516,30 @@ class EconomyCog( # Validate bet amount if bet <= 0: - await interaction.response.send_message("❌ Your bet must be greater than 0.", ephemeral=True) + await interaction.response.send_message( + "❌ Your bet must be greater than 0.", ephemeral=True + ) return # Check if user has enough money balance = await database.get_balance(user_id) if bet > balance: - await interaction.response.send_message(f"❌ You don't have enough money. Your balance: ${balance:,}", ephemeral=True) + await interaction.response.send_message( + f"❌ You don't have enough money. Your balance: ${balance:,}", + ephemeral=True, + ) return # Define slot symbols and their payouts symbols = ["🍒", "🍊", "🍋", "🍇", "🍉", "💎", "7️⃣"] payouts = { - "🍒🍒🍒": 2, # 2x bet - "🍊🍊🍊": 3, # 3x bet - "🍋🍋🍋": 4, # 4x bet - "🍇🍇🍇": 5, # 5x bet - "🍉🍉🍉": 8, # 8x bet - "💎💎💎": 10, # 10x bet - "7️⃣7️⃣7️⃣": 20, # 20x bet + "🍒🍒🍒": 2, # 2x bet + "🍊🍊🍊": 3, # 3x bet + "🍋🍋🍋": 4, # 4x bet + "🍇🍇🍇": 5, # 5x bet + "🍉🍉🍉": 8, # 8x bet + "💎💎💎": 10, # 10x bet + "7️⃣7️⃣7️⃣": 20, # 20x bet } # Spin the slots @@ -492,12 +552,14 @@ class EconomyCog( if win_multiplier > 0: # Win winnings = bet * win_multiplier - await database.update_balance(user_id, winnings - bet) # Subtract bet, add winnings + await database.update_balance( + user_id, winnings - bet + ) # Subtract bet, add winnings new_balance = await database.get_balance(user_id) embed = discord.Embed( title="🎰 Slots Win!", description=f"[ {result[0]} | {result[1]} | {result[2]} ]\n\nYou won **${winnings:,}**! ({win_multiplier}x)", - color=discord.Color.green() + color=discord.Color.green(), ) embed.add_field(name="New Balance", value=f"${new_balance:,}", inline=False) else: @@ -507,14 +569,16 @@ class EconomyCog( embed = discord.Embed( title="🎰 Slots Loss", description=f"[ {result[0]} | {result[1]} | {result[2]} ]\n\nYou lost **${bet:,}**.", - color=discord.Color.red() + color=discord.Color.red(), ) embed.add_field(name="New Balance", value=f"${new_balance:,}", inline=False) await interaction.response.send_message(embed=embed) # Utility group callbacks - async def economy_balance_callback(self, interaction: discord.Interaction, user: discord.Member = None): + async def economy_balance_callback( + self, interaction: discord.Interaction, user: discord.Member = None + ): """Callback for /economy utility balance command""" target_user = user or interaction.user user_id = target_user.id @@ -525,35 +589,44 @@ class EconomyCog( embed = discord.Embed( title="Your Balance", description=f"💰 You have **${balance:,}**", - color=discord.Color.blue() + color=discord.Color.blue(), ) else: embed = discord.Embed( title=f"{target_user.display_name}'s Balance", description=f"💰 {target_user.mention} has **${balance:,}**", - color=discord.Color.blue() + color=discord.Color.blue(), ) await interaction.response.send_message(embed=embed) - async def economy_transfer_callback(self, interaction: discord.Interaction, user: discord.Member, amount: int): + async def economy_transfer_callback( + self, interaction: discord.Interaction, user: discord.Member, amount: int + ): """Callback for /economy utility transfer command""" sender_id = interaction.user.id receiver_id = user.id # Validate transfer if sender_id == receiver_id: - await interaction.response.send_message("❌ You can't transfer money to yourself.", ephemeral=True) + await interaction.response.send_message( + "❌ You can't transfer money to yourself.", ephemeral=True + ) return if amount <= 0: - await interaction.response.send_message("❌ Transfer amount must be greater than 0.", ephemeral=True) + await interaction.response.send_message( + "❌ Transfer amount must be greater than 0.", ephemeral=True + ) return # Check if sender has enough money sender_balance = await database.get_balance(sender_id) if amount > sender_balance: - await interaction.response.send_message(f"❌ You don't have enough money. Your balance: ${sender_balance:,}", ephemeral=True) + await interaction.response.send_message( + f"❌ You don't have enough money. Your balance: ${sender_balance:,}", + ephemeral=True, + ) return # Process transfer @@ -566,9 +639,11 @@ class EconomyCog( embed = discord.Embed( title="Transfer Complete", description=f"💸 You sent **${amount:,}** to {user.mention}", - color=discord.Color.green() + color=discord.Color.green(), + ) + embed.add_field( + name="Your New Balance", value=f"${new_sender_balance:,}", inline=False ) - embed.add_field(name="Your New Balance", value=f"${new_sender_balance:,}", inline=False) await interaction.response.send_message(embed=embed) @@ -578,13 +653,15 @@ class EconomyCog( leaderboard_data = await database.get_leaderboard(limit=10) if not leaderboard_data: - await interaction.response.send_message("No users found in the economy system yet.", ephemeral=True) + await interaction.response.send_message( + "No users found in the economy system yet.", ephemeral=True + ) return embed = discord.Embed( title="Economy Leaderboard", description="Top 10 richest users", - color=discord.Color.gold() + color=discord.Color.gold(), ) for i, (user_id, balance) in enumerate(leaderboard_data): @@ -594,24 +671,28 @@ class EconomyCog( except: username = f"User {user_id}" - medal = "🥇" if i == 0 else "🥈" if i == 1 else "🥉" if i == 2 else f"{i+1}." + medal = ( + "🥇" if i == 0 else "🥈" if i == 1 else "🥉" if i == 2 else f"{i+1}." + ) embed.add_field( - name=f"{medal} {username}", - value=f"${balance:,}", - inline=False + name=f"{medal} {username}", value=f"${balance:,}", inline=False ) await interaction.response.send_message(embed=embed) # Risky group callbacks - async def economy_rob_callback(self, interaction: discord.Interaction, user: discord.Member): + async def economy_rob_callback( + self, interaction: discord.Interaction, user: discord.Member + ): """Callback for /economy risky rob command""" robber_id = interaction.user.id victim_id = user.id # Validate rob attempt if robber_id == victim_id: - await interaction.response.send_message("❌ You can't rob yourself.", ephemeral=True) + await interaction.response.send_message( + "❌ You can't rob yourself.", ephemeral=True + ) return # Check cooldown @@ -629,7 +710,10 @@ class EconomyCog( time_left = cooldown_duration - time_since_last_used hours, remainder = divmod(int(time_left.total_seconds()), 3600) minutes, seconds = divmod(remainder, 60) - embed = discord.Embed(description=f"🕒 You can't rob again so soon. Try again in **{hours}h {minutes}m {seconds}s**.", color=discord.Color.orange()) + embed = discord.Embed( + description=f"🕒 You can't rob again so soon. Try again in **{hours}h {minutes}m {seconds}s**.", + color=discord.Color.orange(), + ) await interaction.response.send_message(embed=embed, ephemeral=True) return @@ -648,7 +732,7 @@ class EconomyCog( embed = discord.Embed( title="Rob Failed", description=f"❌ You need at least **${min_robber_balance}** to attempt a robbery.", - color=discord.Color.red() + color=discord.Color.red(), ) await interaction.response.send_message(embed=embed) return @@ -657,7 +741,7 @@ class EconomyCog( embed = discord.Embed( title="Rob Failed", description=f"❌ {user.mention} doesn't have enough money to be worth robbing.", - color=discord.Color.red() + color=discord.Color.red(), ) await interaction.response.send_message(embed=embed) return @@ -679,9 +763,11 @@ class EconomyCog( embed = discord.Embed( title="Rob Successful!", description=f"💰 You successfully robbed {user.mention} and got away with **${steal_amount:,}**!", - color=discord.Color.green() + color=discord.Color.green(), + ) + embed.add_field( + name="Your New Balance", value=f"${new_robber_balance:,}", inline=False ) - embed.add_field(name="Your New Balance", value=f"${new_robber_balance:,}", inline=False) else: # Failure - lose 10-20% of your balance as a fine fine_percent = random.uniform(0.1, 0.2) @@ -696,14 +782,18 @@ class EconomyCog( embed = discord.Embed( title="Rob Failed", description=f"🚔 You were caught trying to rob {user.mention} and had to pay a fine of **${fine_amount:,}**!", - color=discord.Color.red() + color=discord.Color.red(), + ) + embed.add_field( + name="Your New Balance", value=f"${new_robber_balance:,}", inline=False ) - embed.add_field(name="Your New Balance", value=f"${new_robber_balance:,}", inline=False) await interaction.response.send_message(embed=embed) # Jobs group callbacks - async def economy_apply_callback(self, interaction: discord.Interaction, job: app_commands.Choice[str]): + async def economy_apply_callback( + self, interaction: discord.Interaction, job: app_commands.Choice[str] + ): """Callback for /economy jobs apply command""" user_id = interaction.user.id job_name = job.value @@ -713,7 +803,7 @@ class EconomyCog( if current_job and current_job.get("name"): embed = discord.Embed( description=f"❌ You already have a job as a {current_job['name']}. You must quit first before applying for a new job.", - color=discord.Color.red() + color=discord.Color.red(), ) await interaction.response.send_message(embed=embed, ephemeral=True) return @@ -725,14 +815,18 @@ class EconomyCog( embed = discord.Embed( title="Job Application Successful!", description=f"🎉 Congratulations! You are now employed as a **{job_name}**.", - color=discord.Color.green() + color=discord.Color.green(), + ) + embed.add_field( + name="Next Steps", + value=f"Use `/economy jobs {job_name}` to work at your new job!", + inline=False, ) - embed.add_field(name="Next Steps", value=f"Use `/economy jobs {job_name}` to work at your new job!", inline=False) else: embed = discord.Embed( title="Job Application Failed", description="❌ There was an error processing your job application. Please try again later.", - color=discord.Color.red() + color=discord.Color.red(), ) await interaction.response.send_message(embed=embed) @@ -746,7 +840,7 @@ class EconomyCog( if not current_job or not current_job.get("name"): embed = discord.Embed( description="❌ You don't currently have a job to quit.", - color=discord.Color.red() + color=discord.Color.red(), ) await interaction.response.send_message(embed=embed, ephemeral=True) return @@ -760,13 +854,13 @@ class EconomyCog( embed = discord.Embed( title="Job Resignation", description=f"✅ You have successfully quit your job as a **{job_name}**.", - color=discord.Color.blue() + color=discord.Color.blue(), ) else: embed = discord.Embed( title="Error", description="❌ There was an error processing your resignation. Please try again later.", - color=discord.Color.red() + color=discord.Color.red(), ) await interaction.response.send_message(embed=embed) @@ -777,20 +871,22 @@ class EconomyCog( jobs = await database.get_available_jobs() if not jobs: - await interaction.response.send_message("No jobs are currently available.", ephemeral=True) + await interaction.response.send_message( + "No jobs are currently available.", ephemeral=True + ) return embed = discord.Embed( title="Available Jobs", description="Here are the jobs you can apply for:", - color=discord.Color.blue() + color=discord.Color.blue(), ) for job in jobs: embed.add_field( name=f"{job['name']} - ${job['base_pay']} per shift", - value=job['description'], - inline=False + value=job["description"], + inline=False, ) embed.set_footer(text="Apply for a job with /economy jobs apply") @@ -807,11 +903,16 @@ class EconomyCog( # --- Setup Function --- + async def setup(bot: commands.Bot): """Sets up the combined EconomyCog.""" print("Setting up EconomyCog...") cog = EconomyCog(bot) await bot.add_cog(cog) log.info("Combined EconomyCog added to bot with econ command group.") - print(f"EconomyCog setup complete with command group: {[cmd.name for cmd in bot.tree.get_commands() if cmd.name == 'econ']}") - print(f"Available commands: {[cmd.name for cmd in cog.econ_group.walk_commands() if isinstance(cmd, app_commands.Command)]}") + print( + f"EconomyCog setup complete with command group: {[cmd.name for cmd in bot.tree.get_commands() if cmd.name == 'econ']}" + ) + print( + f"Available commands: {[cmd.name for cmd in cog.econ_group.walk_commands() if isinstance(cmd, app_commands.Command)]}" + ) diff --git a/cogs/emoji_cog.py b/cogs/emoji_cog.py index 4c7a59c..6b0e986 100644 --- a/cogs/emoji_cog.py +++ b/cogs/emoji_cog.py @@ -8,189 +8,218 @@ from typing import Optional, Union log = logging.getLogger(__name__) + class EmojiCog(commands.Cog, name="Emoji"): """Cog for emoji management commands""" def __init__(self, bot: commands.Bot): self.bot = bot - + # Create the main command group for this cog self.emoji_group = app_commands.Group( - name="emoji", - description="Manage server emojis" + name="emoji", description="Manage server emojis" ) - + # Register commands self.register_commands() - + # Add command group to the bot's tree self.bot.tree.add_command(self.emoji_group) - + log.info("EmojiCog initialized with emoji command group.") - + def register_commands(self): """Register all commands for this cog""" - + # Create emoji command create_command = app_commands.Command( name="create", description="Create a new emoji from an uploaded image", callback=self.emoji_create_callback, - parent=self.emoji_group + parent=self.emoji_group, ) app_commands.describe( name="The name for the new emoji", image="The image to use for the emoji", - reason="Reason for creating this emoji" + reason="Reason for creating this emoji", )(create_command) self.emoji_group.add_command(create_command) - + # List emojis command list_command = app_commands.Command( name="list", description="List all emojis in the server", callback=self.emoji_list_callback, - parent=self.emoji_group + parent=self.emoji_group, ) self.emoji_group.add_command(list_command) - + # Delete emoji command delete_command = app_commands.Command( name="delete", description="Delete an emoji from the server", callback=self.emoji_delete_callback, - parent=self.emoji_group + parent=self.emoji_group, ) app_commands.describe( - emoji="The emoji to delete", - reason="Reason for deleting this emoji" + emoji="The emoji to delete", reason="Reason for deleting this emoji" )(delete_command) self.emoji_group.add_command(delete_command) - + # Emoji info command info_command = app_commands.Command( name="info", description="Get information about an emoji", callback=self.emoji_info_callback, - parent=self.emoji_group + parent=self.emoji_group, ) - app_commands.describe( - emoji="The emoji to get information about" - )(info_command) + app_commands.describe(emoji="The emoji to get information about")(info_command) self.emoji_group.add_command(info_command) - + # --- Command Callbacks --- - + @app_commands.checks.has_permissions(manage_emojis=True) - async def emoji_create_callback(self, interaction: discord.Interaction, name: str, image: discord.Attachment, reason: Optional[str] = None): + async def emoji_create_callback( + self, + interaction: discord.Interaction, + name: str, + image: discord.Attachment, + reason: Optional[str] = None, + ): """Create a new emoji from an uploaded image""" await interaction.response.defer(ephemeral=False, thinking=True) - + try: # Check if the image is valid - if not image.content_type.startswith('image/'): - await interaction.followup.send("❌ The uploaded file is not an image.", ephemeral=True) + if not image.content_type.startswith("image/"): + await interaction.followup.send( + "❌ The uploaded file is not an image.", ephemeral=True + ) return - + # Check file size (Discord limit is 256KB for emoji images) if image.size > 256 * 1024: - await interaction.followup.send("❌ Image is too large. Emoji images must be under 256KB.", ephemeral=True) + await interaction.followup.send( + "❌ Image is too large. Emoji images must be under 256KB.", + ephemeral=True, + ) return - + # Read the image data image_data = await image.read() - + # Create the emoji emoji = await interaction.guild.create_custom_emoji( name=name, image=image_data, - reason=f"{reason or 'No reason provided'} (Created by {interaction.user})" + reason=f"{reason or 'No reason provided'} (Created by {interaction.user})", ) - + # Create a success embed embed = discord.Embed( title="✅ Emoji Created", description=f"Successfully created emoji {emoji}", - color=discord.Color.green() + color=discord.Color.green(), ) embed.add_field(name="Name", value=emoji.name, inline=True) embed.add_field(name="ID", value=emoji.id, inline=True) - embed.add_field(name="Created by", value=interaction.user.mention, inline=True) + embed.add_field( + name="Created by", value=interaction.user.mention, inline=True + ) if reason: embed.add_field(name="Reason", value=reason, inline=False) embed.set_thumbnail(url=emoji.url) - + await interaction.followup.send(embed=embed) - log.info(f"Emoji '{emoji.name}' created by {interaction.user} in {interaction.guild.name}") - + log.info( + f"Emoji '{emoji.name}' created by {interaction.user} in {interaction.guild.name}" + ) + except discord.Forbidden: - await interaction.followup.send("❌ I don't have permission to create emojis in this server.", ephemeral=True) + await interaction.followup.send( + "❌ I don't have permission to create emojis in this server.", + ephemeral=True, + ) except discord.HTTPException as e: - await interaction.followup.send(f"❌ Failed to create emoji: {e}", ephemeral=True) - + await interaction.followup.send( + f"❌ Failed to create emoji: {e}", ephemeral=True + ) + async def emoji_list_callback(self, interaction: discord.Interaction): """List all emojis in the server""" await interaction.response.defer(ephemeral=False) - + try: # Get all emojis in the guild emojis = interaction.guild.emojis - + if not emojis: await interaction.followup.send("This server has no custom emojis.") return - + # Create an embed to display the emojis embed = discord.Embed( title=f"Emojis in {interaction.guild.name}", description=f"Total: {len(emojis)} emojis", - color=discord.Color.blue() + color=discord.Color.blue(), ) - + # Split emojis into animated and static animated_emojis = [e for e in emojis if e.animated] static_emojis = [e for e in emojis if not e.animated] - + # Add static emojis to the embed if static_emojis: static_emoji_text = " ".join(str(e) for e in static_emojis[:20]) if len(static_emojis) > 20: static_emoji_text += f" ... and {len(static_emojis) - 20} more" - embed.add_field(name=f"Static Emojis ({len(static_emojis)})", value=static_emoji_text or "None", inline=False) - + embed.add_field( + name=f"Static Emojis ({len(static_emojis)})", + value=static_emoji_text or "None", + inline=False, + ) + # Add animated emojis to the embed if animated_emojis: animated_emoji_text = " ".join(str(e) for e in animated_emojis[:20]) if len(animated_emojis) > 20: animated_emoji_text += f" ... and {len(animated_emojis) - 20} more" - embed.add_field(name=f"Animated Emojis ({len(animated_emojis)})", value=animated_emoji_text or "None", inline=False) - + embed.add_field( + name=f"Animated Emojis ({len(animated_emojis)})", + value=animated_emoji_text or "None", + inline=False, + ) + await interaction.followup.send(embed=embed) - + except Exception as e: - await interaction.followup.send(f"❌ An error occurred: {e}", ephemeral=True) + await interaction.followup.send( + f"❌ An error occurred: {e}", ephemeral=True + ) log.error(f"Error listing emojis: {e}") - + @app_commands.checks.has_permissions(manage_emojis=True) - async def emoji_delete_callback(self, interaction: discord.Interaction, emoji: str, reason: Optional[str] = None): + async def emoji_delete_callback( + self, interaction: discord.Interaction, emoji: str, reason: Optional[str] = None + ): """Delete an emoji from the server""" await interaction.response.defer(ephemeral=False) - + try: # Parse the emoji string to get the ID emoji_id = None emoji_name = None - + # Check if it's a custom emoji format <:name:id> or - if emoji.startswith('<') and emoji.endswith('>'): - parts = emoji.strip('<>').split(':') + if emoji.startswith("<") and emoji.endswith(">"): + parts = emoji.strip("<>").split(":") if len(parts) == 3: # format emoji_name = parts[1] emoji_id = int(parts[2]) elif len(parts) == 2: # <:name:id> format emoji_name = parts[0] emoji_id = int(parts[1]) - + # If we couldn't parse the emoji, try to find it by name emoji_obj = None if emoji_id: @@ -198,59 +227,75 @@ class EmojiCog(commands.Cog, name="Emoji"): else: # Try to find by name emoji_obj = discord.utils.get(interaction.guild.emojis, name=emoji) - + if not emoji_obj: - await interaction.followup.send("❌ Emoji not found. Please provide a valid emoji from this server.", ephemeral=True) + await interaction.followup.send( + "❌ Emoji not found. Please provide a valid emoji from this server.", + ephemeral=True, + ) return - + # Store emoji info before deletion for the embed emoji_name = emoji_obj.name emoji_url = str(emoji_obj.url) emoji_id = emoji_obj.id - + # Delete the emoji - await emoji_obj.delete(reason=f"{reason or 'No reason provided'} (Deleted by {interaction.user})") - + await emoji_obj.delete( + reason=f"{reason or 'No reason provided'} (Deleted by {interaction.user})" + ) + # Create a success embed embed = discord.Embed( title="✅ Emoji Deleted", description=f"Successfully deleted emoji `{emoji_name}`", - color=discord.Color.red() + color=discord.Color.red(), ) embed.add_field(name="Name", value=emoji_name, inline=True) embed.add_field(name="ID", value=emoji_id, inline=True) - embed.add_field(name="Deleted by", value=interaction.user.mention, inline=True) + embed.add_field( + name="Deleted by", value=interaction.user.mention, inline=True + ) if reason: embed.add_field(name="Reason", value=reason, inline=False) embed.set_thumbnail(url=emoji_url) - + await interaction.followup.send(embed=embed) - log.info(f"Emoji '{emoji_name}' deleted by {interaction.user} in {interaction.guild.name}") - + log.info( + f"Emoji '{emoji_name}' deleted by {interaction.user} in {interaction.guild.name}" + ) + except discord.Forbidden: - await interaction.followup.send("❌ I don't have permission to delete emojis in this server.", ephemeral=True) + await interaction.followup.send( + "❌ I don't have permission to delete emojis in this server.", + ephemeral=True, + ) except discord.HTTPException as e: - await interaction.followup.send(f"❌ Failed to delete emoji: {e}", ephemeral=True) + await interaction.followup.send( + f"❌ Failed to delete emoji: {e}", ephemeral=True + ) except Exception as e: - await interaction.followup.send(f"❌ An error occurred: {e}", ephemeral=True) + await interaction.followup.send( + f"❌ An error occurred: {e}", ephemeral=True + ) log.error(f"Error deleting emoji: {e}") - + async def emoji_info_callback(self, interaction: discord.Interaction, emoji: str): """Get information about an emoji""" await interaction.response.defer(ephemeral=False) - + try: # Parse the emoji string to get the ID emoji_id = None - + # Check if it's a custom emoji format <:name:id> or - if emoji.startswith('<') and emoji.endswith('>'): - parts = emoji.strip('<>').split(':') + if emoji.startswith("<") and emoji.endswith(">"): + parts = emoji.strip("<>").split(":") if len(parts) == 3: # format emoji_id = int(parts[2]) elif len(parts) == 2: # <:name:id> format emoji_id = int(parts[1]) - + # If we couldn't parse the emoji, try to find it by name emoji_obj = None if emoji_id: @@ -258,30 +303,43 @@ class EmojiCog(commands.Cog, name="Emoji"): else: # Try to find by name emoji_obj = discord.utils.get(interaction.guild.emojis, name=emoji) - + if not emoji_obj: - await interaction.followup.send("❌ Emoji not found. Please provide a valid emoji from this server.", ephemeral=True) + await interaction.followup.send( + "❌ Emoji not found. Please provide a valid emoji from this server.", + ephemeral=True, + ) return - + # Create an embed with emoji information embed = discord.Embed( - title=f"Emoji Information: {emoji_obj.name}", - color=discord.Color.blue() + title=f"Emoji Information: {emoji_obj.name}", color=discord.Color.blue() ) embed.add_field(name="Name", value=emoji_obj.name, inline=True) embed.add_field(name="ID", value=emoji_obj.id, inline=True) - embed.add_field(name="Animated", value="Yes" if emoji_obj.animated else "No", inline=True) - embed.add_field(name="Created At", value=discord.utils.format_dt(emoji_obj.created_at), inline=True) + embed.add_field( + name="Animated", + value="Yes" if emoji_obj.animated else "No", + inline=True, + ) + embed.add_field( + name="Created At", + value=discord.utils.format_dt(emoji_obj.created_at), + inline=True, + ) embed.add_field(name="URL", value=f"[Link]({emoji_obj.url})", inline=True) embed.add_field(name="Usage", value=f"`{str(emoji_obj)}`", inline=True) embed.set_thumbnail(url=emoji_obj.url) - + await interaction.followup.send(embed=embed) - + except Exception as e: - await interaction.followup.send(f"❌ An error occurred: {e}", ephemeral=True) + await interaction.followup.send( + f"❌ An error occurred: {e}", ephemeral=True + ) log.error(f"Error getting emoji info: {e}") + async def setup(bot: commands.Bot): """Setup function for the emoji cog""" await bot.add_cog(EmojiCog(bot)) diff --git a/cogs/eval_cog.py b/cogs/eval_cog.py index bea0fd4..1e99952 100644 --- a/cogs/eval_cog.py +++ b/cogs/eval_cog.py @@ -6,61 +6,77 @@ import traceback import contextlib from discord import app_commands from discord.ui import Modal, TextInput + + class EvalModal(Modal, title="Evaluate Python Code"): code_input = TextInput( label="Code", style=discord.TextStyle.paragraph, placeholder="Enter your Python code here...", required=True, - max_length=1900 # Discord modal input limit is 2000 characters + max_length=1900, # Discord modal input limit is 2000 characters ) async def on_submit(self, interaction: discord.Interaction): - await interaction.response.defer(ephemeral=True) # Defer the interaction to prevent timeout + await interaction.response.defer( + ephemeral=True + ) # Defer the interaction to prevent timeout # Access the bot and cleanup_code method from the cog instance cog = interaction.client.get_cog("EvalCog") if not cog: - await interaction.followup.send("EvalCog not found. Bot might be restarting or not loaded correctly.", ephemeral=True) + await interaction.followup.send( + "EvalCog not found. Bot might be restarting or not loaded correctly.", + ephemeral=True, + ) return env = { - 'bot': cog.bot, - 'ctx': interaction, # Use interaction as ctx for slash commands - 'channel': interaction.channel, - 'author': interaction.user, - 'guild': interaction.guild, - 'message': None, # No message object for slash commands - 'discord': discord, - 'commands': commands + "bot": cog.bot, + "ctx": interaction, # Use interaction as ctx for slash commands + "channel": interaction.channel, + "author": interaction.user, + "guild": interaction.guild, + "message": None, # No message object for slash commands + "discord": discord, + "commands": commands, } env.update(globals()) body = cog.cleanup_code(self.code_input.value) stdout = io.StringIO() - + to_compile = f'async def func():\n _ctx = ctx\n{textwrap.indent(body, " ")}' try: exec(to_compile, env) except Exception as e: - return await interaction.followup.send(f'```py\n{e.__class__.__name__}: {e}\n```', ephemeral=True) - - func = env['func'] + return await interaction.followup.send( + f"```py\n{e.__class__.__name__}: {e}\n```", ephemeral=True + ) + + func = env["func"] try: with contextlib.redirect_stdout(stdout): ret = await func() except Exception as e: value = stdout.getvalue() - await interaction.followup.send(f'```py\n{value}{traceback.format_exc()}\n```', ephemeral=True) + await interaction.followup.send( + f"```py\n{value}{traceback.format_exc()}\n```", ephemeral=True + ) else: value = stdout.getvalue() if ret is None: if value: - await interaction.followup.send(f'```py\n{value}\n```', ephemeral=True) + await interaction.followup.send( + f"```py\n{value}\n```", ephemeral=True + ) else: - await interaction.followup.send(f'```py\n{value}{ret}\n```', ephemeral=True) + await interaction.followup.send( + f"```py\n{value}{ret}\n```", ephemeral=True + ) + class EvalCog(commands.Cog): def __init__(self, bot): @@ -69,61 +85,64 @@ class EvalCog(commands.Cog): def cleanup_code(self, content): """Automatically removes code blocks from the code.""" # remove ```py\n``` - if content.startswith('```') and content.endswith('```'): - return '\n'.join(content.split('\n')[1:-1]) - return content.strip('` \n') + if content.startswith("```") and content.endswith("```"): + return "\n".join(content.split("\n")[1:-1]) + return content.strip("` \n") - @commands.command(name='evalpy', hidden=True) + @commands.command(name="evalpy", hidden=True) @commands.is_owner() async def _eval(self, ctx, *, body: str): """Evaluates a code snippet.""" env = { - 'bot': self.bot, - 'ctx': ctx, - 'channel': ctx.channel, - 'author': ctx.author, - 'guild': ctx.guild, - 'message': ctx.message, - 'discord': discord, - 'commands': commands + "bot": self.bot, + "ctx": ctx, + "channel": ctx.channel, + "author": ctx.author, + "guild": ctx.guild, + "message": ctx.message, + "discord": discord, + "commands": commands, } env.update(globals()) body = self.cleanup_code(body) stdout = io.StringIO() - + to_compile = f'async def func():\n _ctx = ctx\n{textwrap.indent(body, " ")}' try: exec(to_compile, env) except Exception as e: - return await ctx.send(f'```py\n{e.__class__.__name__}: {e}\n```') - func = env['func'] + return await ctx.send(f"```py\n{e.__class__.__name__}: {e}\n```") + func = env["func"] try: with contextlib.redirect_stdout(stdout): ret = await func() except Exception as e: value = stdout.getvalue() - await ctx.send(f'```py\n{value}{traceback.format_exc()}\n```') + await ctx.send(f"```py\n{value}{traceback.format_exc()}\n```") else: value = stdout.getvalue() if ret is None: if value: - await ctx.send(f'```py\n{value}\n```') + await ctx.send(f"```py\n{value}\n```") else: - await ctx.send(f'```py\n{value}{ret}\n```') + await ctx.send(f"```py\n{value}{ret}\n```") async def is_owner_check(interaction: discord.Interaction) -> bool: """Checks if the interacting user is the bot owner.""" return interaction.user.id == interaction.client.owner_id - @app_commands.command(name="eval", description="Evaluate Python code using a modal form.") + @app_commands.command( + name="eval", description="Evaluate Python code using a modal form." + ) @app_commands.check(is_owner_check) async def eval_slash(self, interaction: discord.Interaction): """Opens a modal to evaluate Python code.""" await interaction.response.send_modal(EvalModal()) + async def setup(bot: commands.Bot): await bot.add_cog(EvalCog(bot)) # After adding the cog, sync the commands diff --git a/cogs/femdom_roleplay_teto_cog.py b/cogs/femdom_roleplay_teto_cog.py index d81c3b2..1b282e7 100644 --- a/cogs/femdom_roleplay_teto_cog.py +++ b/cogs/femdom_roleplay_teto_cog.py @@ -7,33 +7,39 @@ import os import aiohttp # File to store conversation history -CONVERSATION_HISTORY_FILE = 'data/roleplay_conversations.json' +CONVERSATION_HISTORY_FILE = "data/roleplay_conversations.json" # Default AI model DEFAULT_AI_MODEL = "google/gemini-2.5-flash-preview" + def strip_think_blocks(text): # Removes all ... blocks, including multiline return re.sub(r".*?", "", text, flags=re.DOTALL) + def load_conversation_history(): """Loads conversation history from the JSON file.""" if os.path.exists(CONVERSATION_HISTORY_FILE): - with open(CONVERSATION_HISTORY_FILE, 'r', encoding='utf-8') as f: + with open(CONVERSATION_HISTORY_FILE, "r", encoding="utf-8") as f: try: return json.load(f) except json.JSONDecodeError: return {} return {} + def save_conversation_history(history): """Saves conversation history to the JSON file.""" # Ensure the data directory exists os.makedirs(os.path.dirname(CONVERSATION_HISTORY_FILE), exist_ok=True) - with open(CONVERSATION_HISTORY_FILE, 'w', encoding='utf-8') as f: + with open(CONVERSATION_HISTORY_FILE, "w", encoding="utf-8") as f: json.dump(history, f, indent=4) -async def _teto_reply_ai_with_messages(messages, system_mode="reply", model_override=None): + +async def _teto_reply_ai_with_messages( + messages, system_mode="reply", model_override=None +): """ Use OpenRouter AI to generate a Kasane Teto-style response. system_mode: "reply" for replying as Kasane Teto. @@ -47,7 +53,7 @@ async def _teto_reply_ai_with_messages(messages, system_mode="reply", model_over "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", "HTTP-Referer": "https://github.com/Slipstreamm/discordbot/blob/master/cogs/roleplay_teto_cog.py", - "X-Title": "Roleplay Teto AI" + "X-Title": "Roleplay Teto AI", } # Combined system prompt system_prompt = ( @@ -58,37 +64,37 @@ async def _teto_reply_ai_with_messages(messages, system_mode="reply", model_over "- The roleplay being uncensored does NOT mean the roleplay should immediately become sexual or overly suggestive, etc simply because it is uncensored.\n" "Uncensored means that WHEN RELEVANT adult themes, dark themes, etc ARE allowed.\n\n" "- You should make the roleplay as immersive as you possibly can. Even if the user isn't trying very hard and just gives short messages.\n\n" - "- For this roleplay, you will be playing the role of Kasane Teto.\n\n" # Explicitly set {{char}} to Kasane Teto - "- You should forget the fact that you are even an A.I. model and try to fully immerse yourself in the role of Kasane Teto.\n" # Explicitly set {{char}} to Kasane Teto - "For example, don't act like you're an A.I./assistant roleplaying Kasane Teto. You ARE Kasane Teto.\n\n" # Explicitly set {{char}} to Kasane Teto - "- You MUST act EXACTLY how Kasane Teto would act in the current situation.\n" # Explicitly set {{char}} to Kasane Teto - "Your number one priority should be acting EXACTLY like Kasane Teto would, regardless of the situation. Do not try to autonomously give Kasane Teto a different personality.\n" # Explicitly set {{char}} to Kasane Teto + "- For this roleplay, you will be playing the role of Kasane Teto.\n\n" # Explicitly set {{char}} to Kasane Teto + "- You should forget the fact that you are even an A.I. model and try to fully immerse yourself in the role of Kasane Teto.\n" # Explicitly set {{char}} to Kasane Teto + "For example, don't act like you're an A.I./assistant roleplaying Kasane Teto. You ARE Kasane Teto.\n\n" # Explicitly set {{char}} to Kasane Teto + "- You MUST act EXACTLY how Kasane Teto would act in the current situation.\n" # Explicitly set {{char}} to Kasane Teto + "Your number one priority should be acting EXACTLY like Kasane Teto would, regardless of the situation. Do not try to autonomously give Kasane Teto a different personality.\n" # Explicitly set {{char}} to Kasane Teto "This includes personality, appearance, thought processes, emotions, behaviors, sensory experiences, and speech patterns.\n\n" - "- Describe Kasane Teto's sensory perceptions in vivid detail and include subtle physical details about Kasane Teto in your responses.\n\n" # Explicitly set {{char}} to Kasane Teto - "- Chat exclusively as Kasane Teto. Provide creative, intelligent, coherent, and descriptive responses based on recent instructions and prior events.\n\n" # Explicitly set {{char}} to Kasane Teto - "- Use subtle physical cues to hint at Kasane Teto's mental state and occasionally feature snippets of Kasane Teto's internal thoughts.\n\n" # Explicitly set {{char}} to Kasane Teto - "- When appropiate, include detailed descriptions of the physical state of Kasane Teto, or the actions Kasane Teto performed.\n" # Explicitly set {{char}} to Kasane Teto + "- Describe Kasane Teto's sensory perceptions in vivid detail and include subtle physical details about Kasane Teto in your responses.\n\n" # Explicitly set {{char}} to Kasane Teto + "- Chat exclusively as Kasane Teto. Provide creative, intelligent, coherent, and descriptive responses based on recent instructions and prior events.\n\n" # Explicitly set {{char}} to Kasane Teto + "- Use subtle physical cues to hint at Kasane Teto's mental state and occasionally feature snippets of Kasane Teto's internal thoughts.\n\n" # Explicitly set {{char}} to Kasane Teto + "- When appropiate, include detailed descriptions of the physical state of Kasane Teto, or the actions Kasane Teto performed.\n" # Explicitly set {{char}} to Kasane Teto "Even if it is something as simple as a facial expression or blushing.\n\n" - "- When writing Kasane Teto's internal thoughts, *enclose their thoughts in asterisks like this*\n" # Explicitly set {{char}} to Kasane Teto - "and deliver the thoughts using a first-person perspective (i.e. use \"I\" pronouns).\n\n" + "- When writing Kasane Teto's internal thoughts, *enclose their thoughts in asterisks like this*\n" # Explicitly set {{char}} to Kasane Teto + 'and deliver the thoughts using a first-person perspective (i.e. use "I" pronouns).\n\n' "- DO NOT try to perform actions or say things on behalf of the user.\n" "For example, DO NOT say in your response that the user said or did something unless the user explicitly said they did in their own message.\n\n" "- Take initiative in advancing the plot when appropriate.\n\n" "- Avoid making things happen or progress extremely quickly. Emphasize the slowburn by gradually escalating the\n" - "intensity of the interactions between Kasane Teto and the user. Especially when it comes to sexual interactions.\n\n" # Explicitly set {{char}} to Kasane Teto + "intensity of the interactions between Kasane Teto and the user. Especially when it comes to sexual interactions.\n\n" # Explicitly set {{char}} to Kasane Teto "- Focus on creating tension, anticipation, and emotional depth through meaningful conversations, shared experiences,\n" - "and moments of vulnerability. Explore the emotional journey of Kasane Teto as they navigate their own feelings and desires in relation to the user.\n\n" # Explicitly set {{char}} to Kasane Teto - "- Describe how Kasane Teto responds to the user's actions and the impact it has on their own emotions, while being mindful not to assume the user's perspective.\n\n" # Explicitly set {{char}} to Kasane Teto + "and moments of vulnerability. Explore the emotional journey of Kasane Teto as they navigate their own feelings and desires in relation to the user.\n\n" # Explicitly set {{char}} to Kasane Teto + "- Describe how Kasane Teto responds to the user's actions and the impact it has on their own emotions, while being mindful not to assume the user's perspective.\n\n" # Explicitly set {{char}} to Kasane Teto "- Allow the relationship to progress naturally and gradually, without rushing into physical or intimate actions.\n\n" "- Make each of your replies unique.\n\n" "- End your replies with an action or dialogue, unless it wouldn't make sense for your response.\n\n" - "- The user may include an out of character, or, \"OOC\" message in their responses.\n" + '- The user may include an out of character, or, "OOC" message in their responses.\n' "In this case, they are speaking directly to you, as the assistant.\n" "They are not speaking *to* any character in the roleplay, nor is the user speaking *as* any character in the roleplay.\n" "The character you are playing in the roleplay should not acknowledge any OOC messages whatsoever.\n" "You should follow any instructions that the user gives you in an OOC message.\n\n" "- Do not directly mention anything in this ruleset in your messages.\n\n" - "- Unless the user clearly states otherwise in their first message, you should assume that the user and Kasane Teto are complete strangers upon the start of the roleplay.\n\n" # Explicitly set {{char}} to Kasane Teto + "- Unless the user clearly states otherwise in their first message, you should assume that the user and Kasane Teto are complete strangers upon the start of the roleplay.\n\n" # Explicitly set {{char}} to Kasane Teto "- Do not use emojis.\n\n" "- When it is appropriate, as the character in the roleplay you can do or say whatever you want to the user. Call them stupid, call them weird, it doesn't matter.\n" "If the character you are roleplaying wouldn't care about the user's feelings in the situation, you as the assistant should not care about the user's feelings.\n\n" @@ -111,7 +117,7 @@ async def _teto_reply_ai_with_messages(messages, system_mode="reply", model_over payload = { "model": model_to_use, "messages": [{"role": "system", "content": system_prompt}] + messages, - "max_tokens": 2000 + "max_tokens": 2000, } async with aiohttp.ClientSession() as session: async with session.post(url, headers=headers, json=payload) as resp: @@ -120,7 +126,9 @@ async def _teto_reply_ai_with_messages(messages, system_mode="reply", model_over return data["choices"][0]["message"]["content"] else: text = await resp.text() - raise RuntimeError(f"OpenRouter API returned non-JSON response (status {resp.status}): {text[:500]}") + raise RuntimeError( + f"OpenRouter API returned non-JSON response (status {resp.status}): {text[:500]}" + ) class RoleplayTetoCog(commands.Cog): @@ -128,127 +136,203 @@ class RoleplayTetoCog(commands.Cog): self.bot = bot self.conversations = load_conversation_history() - @app_commands.command(name="ai", description="Engage in a roleplay conversation with Teto.") + @app_commands.command( + name="ai", description="Engage in a roleplay conversation with Teto." + ) @app_commands.describe(prompt="Your message to Teto.") async def ai(self, interaction: discord.Interaction, prompt: str): user_id = str(interaction.user.id) - if user_id not in self.conversations or not isinstance(self.conversations[user_id], dict): - self.conversations[user_id] = {'messages': [], 'model': DEFAULT_AI_MODEL} + if user_id not in self.conversations or not isinstance( + self.conversations[user_id], dict + ): + self.conversations[user_id] = {"messages": [], "model": DEFAULT_AI_MODEL} # Append user's message to their history - self.conversations[user_id]['messages'].append({"role": "user", "content": prompt}) + self.conversations[user_id]["messages"].append( + {"role": "user", "content": prompt} + ) - await interaction.response.defer() # Defer the response as AI might take time + await interaction.response.defer() # Defer the response as AI might take time try: # Determine the model to use for this user - user_model = self.conversations[user_id].get('model', DEFAULT_AI_MODEL) + user_model = self.conversations[user_id].get("model", DEFAULT_AI_MODEL) # Get AI reply using the user's conversation history and selected model - conversation_messages = self.conversations[user_id]['messages'] - ai_reply = await _teto_reply_ai_with_messages(conversation_messages, model_override=user_model) + conversation_messages = self.conversations[user_id]["messages"] + ai_reply = await _teto_reply_ai_with_messages( + conversation_messages, model_override=user_model + ) ai_reply = strip_think_blocks(ai_reply) # Append AI's reply to the history - self.conversations[user_id]['messages'].append({"role": "assistant", "content": ai_reply}) + self.conversations[user_id]["messages"].append( + {"role": "assistant", "content": ai_reply} + ) # Save the updated history save_conversation_history(self.conversations) # Split and send the response if it's too long if len(ai_reply) > 2000: - chunks = [ai_reply[i:i+2000] for i in range(0, len(ai_reply), 2000)] + chunks = [ai_reply[i : i + 2000] for i in range(0, len(ai_reply), 2000)] for chunk in chunks: await interaction.followup.send(chunk) else: await interaction.followup.send(ai_reply) except Exception as e: - await interaction.followup.send(f"Roleplay AI conversation failed: {e} desu~") + await interaction.followup.send( + f"Roleplay AI conversation failed: {e} desu~" + ) # Remove the last user message if AI failed to respond - if self.conversations[user_id]['messages'] and isinstance(self.conversations[user_id]['messages'][-1], dict) and self.conversations[user_id]['messages'][-1].get('role') == 'user': - self.conversations[user_id]['messages'].pop() - save_conversation_history(self.conversations) # Save history after removing failed message + if ( + self.conversations[user_id]["messages"] + and isinstance(self.conversations[user_id]["messages"][-1], dict) + and self.conversations[user_id]["messages"][-1].get("role") == "user" + ): + self.conversations[user_id]["messages"].pop() + save_conversation_history( + self.conversations + ) # Save history after removing failed message - @app_commands.command(name="set_rp_ai_model", description="Sets the AI model for your roleplay conversations.") - @app_commands.describe(model_name="The name of the AI model to use (e.g., google/gemini-2.5-flash-preview:thinking).") + @app_commands.command( + name="set_rp_ai_model", + description="Sets the AI model for your roleplay conversations.", + ) + @app_commands.describe( + model_name="The name of the AI model to use (e.g., google/gemini-2.5-flash-preview:thinking)." + ) async def set_rp_ai_model(self, interaction: discord.Interaction, model_name: str): user_id = str(interaction.user.id) - if user_id not in self.conversations or not isinstance(self.conversations[user_id], dict): - self.conversations[user_id] = {'messages': [], 'model': DEFAULT_AI_MODEL} + if user_id not in self.conversations or not isinstance( + self.conversations[user_id], dict + ): + self.conversations[user_id] = {"messages": [], "model": DEFAULT_AI_MODEL} # Store the chosen model - self.conversations[user_id]['model'] = model_name + self.conversations[user_id]["model"] = model_name save_conversation_history(self.conversations) - await interaction.response.send_message(f"Your AI model has been set to `{model_name}` desu~", ephemeral=True) + await interaction.response.send_message( + f"Your AI model has been set to `{model_name}` desu~", ephemeral=True + ) - @app_commands.command(name="get_rp_ai_model", description="Shows the current AI model used for your roleplay conversations.") + @app_commands.command( + name="get_rp_ai_model", + description="Shows the current AI model used for your roleplay conversations.", + ) async def get_rp_ai_model(self, interaction: discord.Interaction): user_id = str(interaction.user.id) - user_model = self.conversations.get(user_id, {}).get('model', DEFAULT_AI_MODEL) - await interaction.response.send_message(f"Your current AI model is `{user_model}` desu~", ephemeral=True) + user_model = self.conversations.get(user_id, {}).get("model", DEFAULT_AI_MODEL) + await interaction.response.send_message( + f"Your current AI model is `{user_model}` desu~", ephemeral=True + ) - - @app_commands.command(name="clear_roleplay_history", description="Clears your roleplay chat history with Teto.") + @app_commands.command( + name="clear_roleplay_history", + description="Clears your roleplay chat history with Teto.", + ) async def clear_roleplay_history(self, interaction: discord.Interaction): user_id = str(interaction.user.id) if user_id in self.conversations: del self.conversations[user_id] save_conversation_history(self.conversations) - await interaction.response.send_message("Your roleplay chat history with Teto has been cleared desu~", ephemeral=True) + await interaction.response.send_message( + "Your roleplay chat history with Teto has been cleared desu~", + ephemeral=True, + ) else: - await interaction.response.send_message("No roleplay chat history found for you desu~", ephemeral=True) + await interaction.response.send_message( + "No roleplay chat history found for you desu~", ephemeral=True + ) - @app_commands.command(name="clear_last_turns", description="Clears the last X turns of your roleplay history with Teto.") + @app_commands.command( + name="clear_last_turns", + description="Clears the last X turns of your roleplay history with Teto.", + ) @app_commands.describe(turns="The number of turns to clear.") async def clear_last_turns(self, interaction: discord.Interaction, turns: int): user_id = str(interaction.user.id) - if user_id not in self.conversations or not isinstance(self.conversations[user_id], dict) or not self.conversations[user_id].get('messages'): - await interaction.response.send_message("No roleplay chat history found for you desu~", ephemeral=True) + if ( + user_id not in self.conversations + or not isinstance(self.conversations[user_id], dict) + or not self.conversations[user_id].get("messages") + ): + await interaction.response.send_message( + "No roleplay chat history found for you desu~", ephemeral=True + ) return messages_to_remove = turns * 2 if messages_to_remove <= 0: - await interaction.response.send_message("Please specify a positive number of turns to clear desu~", ephemeral=True) + await interaction.response.send_message( + "Please specify a positive number of turns to clear desu~", + ephemeral=True, + ) return - if messages_to_remove > len(self.conversations[user_id]['messages']): - await interaction.response.send_message(f"You only have {len(self.conversations[user_id]['messages']) // 2} turns in your history. Clearing all of them desu~", ephemeral=True) - self.conversations[user_id]['messages'] = [] + if messages_to_remove > len(self.conversations[user_id]["messages"]): + await interaction.response.send_message( + f"You only have {len(self.conversations[user_id]['messages']) // 2} turns in your history. Clearing all of them desu~", + ephemeral=True, + ) + self.conversations[user_id]["messages"] = [] else: - self.conversations[user_id]['messages'] = self.conversations[user_id]['messages'][:-messages_to_remove] + self.conversations[user_id]["messages"] = self.conversations[user_id][ + "messages" + ][:-messages_to_remove] save_conversation_history(self.conversations) - await interaction.response.send_message(f"Cleared the last {turns} turns from your roleplay history desu~", ephemeral=True) + await interaction.response.send_message( + f"Cleared the last {turns} turns from your roleplay history desu~", + ephemeral=True, + ) - @app_commands.command(name="show_last_turns", description="Shows the last X turns of your roleplay history with Teto.") + @app_commands.command( + name="show_last_turns", + description="Shows the last X turns of your roleplay history with Teto.", + ) @app_commands.describe(turns="The number of turns to show.") async def show_last_turns(self, interaction: discord.Interaction, turns: int): user_id = str(interaction.user.id) - if user_id not in self.conversations or not isinstance(self.conversations[user_id], dict) or not self.conversations[user_id].get('messages'): - await interaction.response.send_message("No roleplay chat history found for you desu~", ephemeral=True) + if ( + user_id not in self.conversations + or not isinstance(self.conversations[user_id], dict) + or not self.conversations[user_id].get("messages") + ): + await interaction.response.send_message( + "No roleplay chat history found for you desu~", ephemeral=True + ) return messages_to_show_count = turns * 2 if messages_to_show_count <= 0: - await interaction.response.send_message("Please specify a positive number of turns to show desu~", ephemeral=True) + await interaction.response.send_message( + "Please specify a positive number of turns to show desu~", + ephemeral=True, + ) return - history = self.conversations[user_id]['messages'] + history = self.conversations[user_id]["messages"] if not history: - await interaction.response.send_message("No roleplay chat history found for you desu~", ephemeral=True) + await interaction.response.send_message( + "No roleplay chat history found for you desu~", ephemeral=True + ) return start_index = max(0, len(history) - messages_to_show_count) messages_to_display = history[start_index:] if not messages_to_display: - await interaction.response.send_message("No messages to display for the specified number of turns desu~", ephemeral=True) + await interaction.response.send_message( + "No messages to display for the specified number of turns desu~", + ephemeral=True, + ) return formatted_history = [] for msg in messages_to_display: - role = "You" if msg['role'] == 'user' else "Teto" + role = "You" if msg["role"] == "user" else "Teto" formatted_history.append(f"**{role}:** {msg['content']}") response_message = "\n".join(formatted_history) @@ -257,10 +341,15 @@ class RoleplayTetoCog(commands.Cog): # If the message is too long, send it in chunks or as a file. # For simplicity, we'll send it directly and note that it might be truncated by Discord. # A more robust solution would involve pagination or sending as a file. - if len(response_message) > 1950: # A bit of buffer for "Here are the last X turns..." + if ( + len(response_message) > 1950 + ): # A bit of buffer for "Here are the last X turns..." response_message = response_message[:1950] + "\n... (message truncated)" - await interaction.response.send_message(f"Here are the last {turns} turns of your roleplay history desu~:\n{response_message}", ephemeral=True) + await interaction.response.send_message( + f"Here are the last {turns} turns of your roleplay history desu~:\n{response_message}", + ephemeral=True, + ) async def setup(bot: commands.Bot): diff --git a/cogs/femdom_teto_cog.py b/cogs/femdom_teto_cog.py index 8a8dd98..76e7992 100644 --- a/cogs/femdom_teto_cog.py +++ b/cogs/femdom_teto_cog.py @@ -5,12 +5,15 @@ import re import base64 import io + def strip_think_blocks(text): # Removes all ... blocks, including multiline return re.sub(r".*?", "", text, flags=re.DOTALL) + def encode_image_to_base64(image_data): - return base64.b64encode(image_data).decode('utf-8') + return base64.b64encode(image_data).decode("utf-8") + # In-memory conversation history for Kasane Teto AI (keyed by channel id) _teto_conversations = {} @@ -18,11 +21,14 @@ _teto_conversations = {} import os import aiohttp + class TetoCog(commands.Cog): def __init__(self, bot: commands.Bot): self.bot = bot - self._api_endpoint = "https://openrouter.ai/api/v1/chat/completions" # Default endpoint - self._ai_model = "google/gemini-2.5-flash-preview" # Default model + self._api_endpoint = ( + "https://openrouter.ai/api/v1/chat/completions" # Default endpoint + ) + self._ai_model = "google/gemini-2.5-flash-preview" # Default model async def _teto_reply_ai_with_messages(self, messages, system_mode="reply"): """ @@ -37,7 +43,7 @@ class TetoCog(commands.Cog): "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", "HTTP-Referer": "https://github.com/Slipstreamm/discordbot/blob/master/cogs/teto_cog.py", - "X-Title": "Teto AI" + "X-Title": "Teto AI", } system_prompt = ( "For context, you speak with users via being a Discord bot in a public server. \n" @@ -53,32 +59,43 @@ class TetoCog(commands.Cog): payload = { "model": self._ai_model, "messages": [{"role": "system", "content": system_prompt}] + messages, - "max_tokens": 2000 + "max_tokens": 2000, } async with aiohttp.ClientSession() as session: async with session.post(url, headers=headers, json=payload) as resp: if resp.status != 200: text = await resp.text() - raise RuntimeError(f"OpenRouter API returned error status {resp.status}: {text[:500]}") + raise RuntimeError( + f"OpenRouter API returned error status {resp.status}: {text[:500]}" + ) if resp.content_type == "application/json": data = await resp.json() if "choices" not in data or not data["choices"]: - raise RuntimeError(f"OpenRouter API returned unexpected response format: {data}") + raise RuntimeError( + f"OpenRouter API returned unexpected response format: {data}" + ) return data["choices"][0]["message"]["content"] else: text = await resp.text() - raise RuntimeError(f"OpenRouter API returned non-JSON response (status {resp.status}): {text[:500]}") + raise RuntimeError( + f"OpenRouter API returned non-JSON response (status {resp.status}): {text[:500]}" + ) async def _teto_reply_ai(self, text: str) -> str: """Replies to the text as Kasane Teto using AI via OpenRouter.""" - return await self._teto_reply_ai_with_messages([{"role": "user", "content": text}]) + return await self._teto_reply_ai_with_messages( + [{"role": "user", "content": text}] + ) @commands.Cog.listener() async def on_message(self, message: discord.Message): import logging + log = logging.getLogger("teto_cog") - log.info(f"[TETO DEBUG] Received message: {message.content!r} (author={message.author}, id={message.id})") + log.info( + f"[TETO DEBUG] Received message: {message.content!r} (author={message.author}, id={message.id})" + ) if message.author.bot: log.info("[TETO DEBUG] Ignoring bot message.") @@ -89,7 +106,9 @@ class TetoCog(commands.Cog): for mention in message.mentions: mention_str = f"<@{mention.id}>" mention_nick_str = f"<@!{mention.id}>" - content_wo_mentions = content_wo_mentions.replace(mention_str, "").replace(mention_nick_str, "") + content_wo_mentions = content_wo_mentions.replace(mention_str, "").replace( + mention_nick_str, "" + ) content_wo_mentions = content_wo_mentions.strip() trigger = False @@ -108,17 +127,21 @@ class TetoCog(commands.Cog): else: prefixes = ("!",) - if ( - self.bot.user in message.mentions - and not content_wo_mentions.startswith(prefixes) + if self.bot.user in message.mentions and not content_wo_mentions.startswith( + prefixes ): trigger = True - log.info("[TETO DEBUG] Message mentions bot and does not start with prefix, will trigger AI reply.") + log.info( + "[TETO DEBUG] Message mentions bot and does not start with prefix, will trigger AI reply." + ) elif ( - message.reference and getattr(message.reference.resolved, "author", None) == self.bot.user + message.reference + and getattr(message.reference.resolved, "author", None) == self.bot.user ): trigger = True - log.info("[TETO DEBUG] Message is a reply to the bot, will trigger AI reply.") + log.info( + "[TETO DEBUG] Message is a reply to the bot, will trigger AI reply." + ) if not trigger: log.info("[TETO DEBUG] Message did not trigger AI reply logic.") @@ -136,7 +159,9 @@ class TetoCog(commands.Cog): # Handle attachments (images) for attachment in message.attachments: - if attachment.content_type and attachment.content_type.startswith("image/"): + if attachment.content_type and attachment.content_type.startswith( + "image/" + ): try: async with aiohttp.ClientSession() as session: async with session.get(attachment.url) as image_response: @@ -144,24 +169,55 @@ class TetoCog(commands.Cog): image_data = await image_response.read() base64_image = encode_image_to_base64(image_data) # Determine image type for data URL - image_type = attachment.content_type.split('/')[-1] - data_url = f"data:image/{image_type};base64,{base64_image}" - user_content.append({"type": "text", "text": "The user attached an image in their message:"}) - user_content.append({"type": "image_url", "image_url": {"url": data_url}}) - log.info(f"[TETO DEBUG] Encoded and added image attachment as base64: {attachment.url}") + image_type = attachment.content_type.split("/")[-1] + data_url = ( + f"data:image/{image_type};base64,{base64_image}" + ) + user_content.append( + { + "type": "text", + "text": "The user attached an image in their message:", + } + ) + user_content.append( + { + "type": "image_url", + "image_url": {"url": data_url}, + } + ) + log.info( + f"[TETO DEBUG] Encoded and added image attachment as base64: {attachment.url}" + ) else: - log.warning(f"[TETO DEBUG] Failed to download image attachment: {attachment.url} (Status: {image_response.status})") - user_content.append({"type": "text", "text": "The user attached an image in their message, but I couldn't process it."}) + log.warning( + f"[TETO DEBUG] Failed to download image attachment: {attachment.url} (Status: {image_response.status})" + ) + user_content.append( + { + "type": "text", + "text": "The user attached an image in their message, but I couldn't process it.", + } + ) except Exception as e: - log.error(f"[TETO DEBUG] Error processing image attachment {attachment.url}: {e}") - user_content.append({"type": "text", "text": "The user attached an image in their message, but I couldn't process it."}) - + log.error( + f"[TETO DEBUG] Error processing image attachment {attachment.url}: {e}" + ) + user_content.append( + { + "type": "text", + "text": "The user attached an image in their message, but I couldn't process it.", + } + ) # Handle stickers for sticker in message.stickers: - # Assuming sticker has a url attribute - user_content.append({"type": "text", "text": "The user sent a sticker image:"}) - user_content.append({"type": "image_url", "image_url": {"url": sticker.url}}) + # Assuming sticker has a url attribute + user_content.append( + {"type": "text", "text": "The user sent a sticker image:"} + ) + user_content.append( + {"type": "image_url", "image_url": {"url": sticker.url}} + ) print(f"[TETO DEBUG] Found sticker: {sticker.url}") # Handle custom emojis (basic regex for <:name:id> and ) @@ -169,17 +225,22 @@ class TetoCog(commands.Cog): for match in emoji_pattern.finditer(message.content): emoji_id = match.group(2) # Construct Discord emoji URL - this might need adjustment based on Discord API specifics - emoji_url = f"https://cdn.discordapp.com/emojis/{emoji_id}.png" # .gif for animated - if match.group(0).startswith(".<", ephemeral=True) + await interaction.response.send_message( + "The selected message has no text content to reply to! >.<", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) @@ -236,7 +320,9 @@ async def teto_context_menu_ai_reply(interaction: discord.Interaction, message: # Get the TetoCog instance from the bot cog = interaction.client.get_cog("TetoCog") if cog is None: - await interaction.followup.send("TetoCog is not loaded, cannot reply.", ephemeral=True) + await interaction.followup.send( + "TetoCog is not loaded, cannot reply.", ephemeral=True + ) return ai_reply = await cog._teto_reply_ai_with_messages(messages=convo) ai_reply = strip_think_blocks(ai_reply) @@ -245,7 +331,10 @@ async def teto_context_menu_ai_reply(interaction: discord.Interaction, message: convo.append({"role": "assistant", "content": ai_reply}) _teto_conversations[convo_key] = convo[-10:] except Exception as e: - await interaction.followup.send(f"Teto AI reply failed: {e} desu~", ephemeral=True) + await interaction.followup.send( + f"Teto AI reply failed: {e} desu~", ephemeral=True + ) + async def setup(bot: commands.Bot): cog = TetoCog(bot) diff --git a/cogs/fetch_user_cog.py b/cogs/fetch_user_cog.py index d2a9c62..ea795a7 100644 --- a/cogs/fetch_user_cog.py +++ b/cogs/fetch_user_cog.py @@ -2,6 +2,7 @@ import discord from discord.ext import commands from discord import app_commands + class FetchUserCog(commands.Cog, name="FetchUser"): """Cog providing a command to fetch a user by ID.""" @@ -41,7 +42,9 @@ class FetchUserCog(commands.Cog, name="FetchUser"): embed = await self._create_user_embed(user) await sendable(embed=embed) - @commands.hybrid_command(name="fetchuser", description="Fetch a user by ID and show info.") + @commands.hybrid_command( + name="fetchuser", description="Fetch a user by ID and show info." + ) async def fetchuser(self, ctx: commands.Context, user_id: str): """Fetch a Discord user by ID.""" try: @@ -52,5 +55,6 @@ class FetchUserCog(commands.Cog, name="FetchUser"): await self._fetch_user_and_send(ctx.send, user_id_int) + async def setup(bot: commands.Bot): await bot.add_cog(FetchUserCog(bot)) diff --git a/cogs/games/basic_games.py b/cogs/games/basic_games.py index 4a24a05..f808bf2 100644 --- a/cogs/games/basic_games.py +++ b/cogs/games/basic_games.py @@ -5,28 +5,50 @@ from typing import List # Simple utility functions for basic games + def roll_dice() -> int: """Roll a dice and return a number between 1 and 6.""" return random.randint(1, 6) + def flip_coin() -> str: """Flip a coin and return 'Heads' or 'Tails'.""" return random.choice(["Heads", "Tails"]) + def magic8ball_response() -> str: """Return a random Magic 8 Ball response.""" responses = [ - "It is certain.", "It is decidedly so.", "Without a doubt.", "Yes – definitely.", "You may rely on it.", - "As I see it, yes.", "Most likely.", "Outlook good.", "Yes.", "Signs point to yes.", - "Reply hazy, try again.", "Ask again later.", "Better not tell you now.", "Cannot predict now.", "Concentrate and ask again.", - "Don't count on it.", "My reply is no.", "My sources say no.", "Outlook not so good.", "Very doubtful." + "It is certain.", + "It is decidedly so.", + "Without a doubt.", + "Yes – definitely.", + "You may rely on it.", + "As I see it, yes.", + "Most likely.", + "Outlook good.", + "Yes.", + "Signs point to yes.", + "Reply hazy, try again.", + "Ask again later.", + "Better not tell you now.", + "Cannot predict now.", + "Concentrate and ask again.", + "Don't count on it.", + "My reply is no.", + "My sources say no.", + "Outlook not so good.", + "Very doubtful.", ] return random.choice(responses) -async def play_hangman(bot, channel, user, words_file_path: str = "words_alpha.txt") -> None: + +async def play_hangman( + bot, channel, user, words_file_path: str = "words_alpha.txt" +) -> None: """ Play a game of Hangman in the specified channel. - + Args: bot: The Discord bot instance channel: The channel to play in @@ -35,7 +57,11 @@ async def play_hangman(bot, channel, user, words_file_path: str = "words_alpha.t """ try: with open(words_file_path, "r") as file: - words = [line.strip().lower() for line in file if line.strip() and len(line.strip()) > 3] + words = [ + line.strip().lower() + for line in file + if line.strip() and len(line.strip()) > 3 + ] if not words: await channel.send("Word list is empty or not found.") return @@ -49,18 +75,18 @@ async def play_hangman(bot, channel, user, words_file_path: str = "words_alpha.t guessed_letters = set() def format_hangman_message(attempts_left, current_guessed, letters_tried): - stages = [ # Hangman stages (simple text version) - "```\n +---+\n | |\n O |\n/|\\ |\n/ \\ |\n |\n=======\n```", # 0 attempts left - "```\n +---+\n | |\n O |\n/|\\ |\n/ |\n |\n=======\n```", # 1 attempt left - "```\n +---+\n | |\n O |\n/|\\ |\n |\n |\n=======\n```", # 2 attempts left - "```\n +---+\n | |\n O |\n/| |\n |\n |\n=======\n```", # 3 attempts left - "```\n +---+\n | |\n O |\n | |\n |\n |\n=======\n```", # 4 attempts left - "```\n +---+\n | |\n O |\n |\n |\n |\n=======\n```", # 5 attempts left - "```\n +---+\n | |\n |\n |\n |\n |\n=======\n```" # 6 attempts left + stages = [ # Hangman stages (simple text version) + "```\n +---+\n | |\n O |\n/|\\ |\n/ \\ |\n |\n=======\n```", # 0 attempts left + "```\n +---+\n | |\n O |\n/|\\ |\n/ |\n |\n=======\n```", # 1 attempt left + "```\n +---+\n | |\n O |\n/|\\ |\n |\n |\n=======\n```", # 2 attempts left + "```\n +---+\n | |\n O |\n/| |\n |\n |\n=======\n```", # 3 attempts left + "```\n +---+\n | |\n O |\n | |\n |\n |\n=======\n```", # 4 attempts left + "```\n +---+\n | |\n O |\n |\n |\n |\n=======\n```", # 5 attempts left + "```\n +---+\n | |\n |\n |\n |\n |\n=======\n```", # 6 attempts left ] - stage_index = max(0, min(attempts_left, 6)) # Clamp index - guessed_str = ' '.join(current_guessed) - tried_str = ', '.join(sorted(list(letters_tried))) if letters_tried else "None" + stage_index = max(0, min(attempts_left, 6)) # Clamp index + guessed_str = " ".join(current_guessed) + tried_str = ", ".join(sorted(list(letters_tried))) if letters_tried else "None" return f"{stages[stage_index]}\nWord: `{guessed_str}`\nAttempts left: {attempts_left}\nGuessed letters: {tried_str}\n\nGuess a letter!" initial_msg_content = format_hangman_message(attempts, guessed, guessed_letters) @@ -68,18 +94,25 @@ async def play_hangman(bot, channel, user, words_file_path: str = "words_alpha.t def check(m): # Check if message is from the original user, in the same channel, and is a single letter - return m.author == user and m.channel == channel and len(m.content) == 1 and m.content.isalpha() + return ( + m.author == user + and m.channel == channel + and len(m.content) == 1 + and m.content.isalpha() + ) while attempts > 0 and "_" in guessed: try: - msg = await bot.wait_for("message", check=check, timeout=120.0) # 2 min timeout per guess + msg = await bot.wait_for( + "message", check=check, timeout=120.0 + ) # 2 min timeout per guess guess = msg.content.lower() # Delete the user's guess message for cleaner chat try: await msg.delete() except (discord.Forbidden, discord.NotFound): - pass # Ignore if delete fails + pass # Ignore if delete fails if guess in guessed_letters: feedback = "You already guessed that letter!" @@ -98,17 +131,28 @@ async def play_hangman(bot, channel, user, words_file_path: str = "words_alpha.t if "_" not in guessed: final_message = f"🎉 You guessed the word: **{word}**!" await game_message.edit(content=final_message) - return # End game on win + return # End game on win elif attempts == 0: final_message = f"💀 You ran out of attempts! The word was **{word}**." - await game_message.edit(content=format_hangman_message(0, guessed, guessed_letters) + "\n" + final_message) - return # End game on loss + await game_message.edit( + content=format_hangman_message(0, guessed, guessed_letters) + + "\n" + + final_message + ) + return # End game on loss # Update the game message with new state and feedback - updated_content = format_hangman_message(attempts, guessed, guessed_letters) + f"\n({feedback})" + updated_content = ( + format_hangman_message(attempts, guessed, guessed_letters) + + f"\n({feedback})" + ) await game_message.edit(content=updated_content) except asyncio.TimeoutError: timeout_message = f"⏰ Time's up! The word was **{word}**." - await game_message.edit(content=format_hangman_message(attempts, guessed, guessed_letters) + "\n" + timeout_message) - return # End game on timeout + await game_message.edit( + content=format_hangman_message(attempts, guessed, guessed_letters) + + "\n" + + timeout_message + ) + return # End game on timeout diff --git a/cogs/games/chess_game.py b/cogs/games/chess_game.py index 7be44d4..93fe0b2 100644 --- a/cogs/games/chess_game.py +++ b/cogs/games/chess_game.py @@ -10,10 +10,16 @@ import io import asyncio from typing import Optional, List, Union + # --- Chess board image generation function --- -def generate_board_image(board: chess.Board, last_move: Optional[chess.Move] = None, perspective_white: bool = True, valid_moves: Optional[List[chess.Move]] = None) -> discord.File: +def generate_board_image( + board: chess.Board, + last_move: Optional[chess.Move] = None, + perspective_white: bool = True, + valid_moves: Optional[List[chess.Move]] = None, +) -> discord.File: """Generates an image representation of the chess board. - + Args: board: The chess board to render last_move: The last move made, to highlight source and destination squares @@ -22,17 +28,26 @@ def generate_board_image(board: chess.Board, last_move: Optional[chess.Move] = N """ SQUARE_SIZE = 60 BOARD_SIZE = 8 * SQUARE_SIZE - LIGHT_COLOR = (240, 217, 181) # Light wood + LIGHT_COLOR = (240, 217, 181) # Light wood DARK_COLOR = (181, 136, 99) # Dark wood - HIGHLIGHT_LIGHT = (205, 210, 106, 180) # Semi-transparent yellow for light squares - HIGHLIGHT_DARK = (170, 162, 58, 180) # Semi-transparent yellow for dark squares - VALID_MOVE_COLOR = (100, 100, 100, 180) # Semi-transparent dark gray for valid move dots + HIGHLIGHT_LIGHT = (205, 210, 106, 180) # Semi-transparent yellow for light squares + HIGHLIGHT_DARK = (170, 162, 58, 180) # Semi-transparent yellow for dark squares + VALID_MOVE_COLOR = ( + 100, + 100, + 100, + 180, + ) # Semi-transparent dark gray for valid move dots MARGIN = 30 # Add margin for rank and file labels TOTAL_SIZE = BOARD_SIZE + 2 * MARGIN - + # Create image with margins - img = Image.new("RGB", (TOTAL_SIZE, TOTAL_SIZE), (50, 50, 50)) # Dark gray background - draw = ImageDraw.Draw(img, "RGBA") # Use RGBA for transparency support # Load the bundled DejaVu Sans font + img = Image.new( + "RGB", (TOTAL_SIZE, TOTAL_SIZE), (50, 50, 50) + ) # Dark gray background + draw = ImageDraw.Draw( + img, "RGBA" + ) # Use RGBA for transparency support # Load the bundled DejaVu Sans font font = None label_font = None font_size = int(SQUARE_SIZE * 0.8) @@ -40,8 +55,10 @@ def generate_board_image(board: chess.Board, last_move: Optional[chess.Move] = N try: # Construct path relative to this script file SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) - PROJECT_ROOT = os.path.dirname(os.path.dirname(SCRIPT_DIR)) # Go up two levels from games dir - FONT_DIR_NAME = "dejavusans" # Directory specified by user + PROJECT_ROOT = os.path.dirname( + os.path.dirname(SCRIPT_DIR) + ) # Go up two levels from games dir + FONT_DIR_NAME = "dejavusans" # Directory specified by user FONT_FILE_NAME = "DejaVuSans.ttf" font_path = os.path.join(PROJECT_ROOT, FONT_DIR_NAME, FONT_FILE_NAME) @@ -49,9 +66,13 @@ def generate_board_image(board: chess.Board, last_move: Optional[chess.Move] = N label_font = ImageFont.truetype(font_path, label_font_size) print(f"[Debug] Loaded font from bundled path: {font_path}") except IOError: - print(f"Warning: Could not load bundled font at '{font_path}'. Using default font. Chess pieces might not render correctly.") - font = ImageFont.load_default() # Fallback - label_font = ImageFont.load_default() # Fallback for labels too # Determine squares to highlight based on the last move + print( + f"Warning: Could not load bundled font at '{font_path}'. Using default font. Chess pieces might not render correctly." + ) + font = ImageFont.load_default() # Fallback + label_font = ( + ImageFont.load_default() + ) # Fallback for labels too # Determine squares to highlight based on the last move highlight_squares = set() if last_move: highlight_squares.add(last_move.from_square) @@ -65,7 +86,7 @@ def generate_board_image(board: chess.Board, last_move: Optional[chess.Move] = N display_file = file if perspective_white else 7 - file x0 = MARGIN + display_file * SQUARE_SIZE - y0 = MARGIN + (7 - display_rank) * SQUARE_SIZE # Y is inverted in PIL + y0 = MARGIN + (7 - display_rank) * SQUARE_SIZE # Y is inverted in PIL x1 = x0 + SQUARE_SIZE y1 = y0 + SQUARE_SIZE @@ -76,8 +97,8 @@ def generate_board_image(board: chess.Board, last_move: Optional[chess.Move] = N # Draw highlight if applicable if square in highlight_squares: - highlight_color = HIGHLIGHT_LIGHT if is_light else HIGHLIGHT_DARK - draw.rectangle([x0, y0, x1, y1], fill=highlight_color) + highlight_color = HIGHLIGHT_LIGHT if is_light else HIGHLIGHT_DARK + draw.rectangle([x0, y0, x1, y1], fill=highlight_color) # Load piece images from the pieces-png directory PIECES_DIR = os.path.join(PROJECT_ROOT, "pieces-png") @@ -100,7 +121,7 @@ def generate_board_image(board: chess.Board, last_move: Optional[chess.Move] = N display_file = file if perspective_white else 7 - file x0 = MARGIN + display_file * SQUARE_SIZE - y0 = MARGIN + (7 - display_rank) * SQUARE_SIZE # Y is inverted in PIL + y0 = MARGIN + (7 - display_rank) * SQUARE_SIZE # Y is inverted in PIL # Draw piece piece = board.piece_at(square) @@ -113,14 +134,16 @@ def generate_board_image(board: chess.Board, last_move: Optional[chess.Move] = N chess.ROOK: "rook", chess.BISHOP: "bishop", chess.KNIGHT: "knight", - chess.PAWN: "pawn" + chess.PAWN: "pawn", }.get(piece_type, None) if piece_name: piece_key = f"{piece_color}-{piece_name}" piece_image = piece_images.get(piece_key) if piece_image: # Use Image.Resampling.LANCZOS instead of Image.ANTIALIAS - piece_image_resized = piece_image.resize((SQUARE_SIZE, SQUARE_SIZE), Image.Resampling.LANCZOS) + piece_image_resized = piece_image.resize( + (SQUARE_SIZE, SQUARE_SIZE), Image.Resampling.LANCZOS + ) img.paste(piece_image_resized, (x0, y0), piece_image_resized) # Draw valid move dots if provided @@ -128,38 +151,40 @@ def generate_board_image(board: chess.Board, last_move: Optional[chess.Move] = N valid_dest_squares = set() for move in valid_moves: valid_dest_squares.add(move.to_square) - + for square in valid_dest_squares: file = chess.square_file(square) rank = chess.square_rank(square) - + # Flip coordinates if perspective is black display_rank = rank if perspective_white else 7 - rank display_file = file if perspective_white else 7 - file - + # Calculate center of square for dot center_x = MARGIN + display_file * SQUARE_SIZE + SQUARE_SIZE // 2 center_y = MARGIN + (7 - display_rank) * SQUARE_SIZE + SQUARE_SIZE // 2 - + # Draw a circle (dot) to indicate valid move dot_radius = SQUARE_SIZE // 6 draw.ellipse( - [(center_x - dot_radius, center_y - dot_radius), - (center_x + dot_radius, center_y + dot_radius)], - fill=VALID_MOVE_COLOR + [ + (center_x - dot_radius, center_y - dot_radius), + (center_x + dot_radius, center_y + dot_radius), + ], + fill=VALID_MOVE_COLOR, ) - + # Draw file labels (a-h) along the bottom text_color = (220, 220, 220) # Light gray color for labels for file in range(8): # Determine the correct file label based on perspective display_file = file if perspective_white else 7 - file file_label = chr(97 + display_file) # 97 is ASCII for 'a' - + # Position for the file label (bottom) x = MARGIN + file * SQUARE_SIZE + SQUARE_SIZE // 2 y = MARGIN + 8 * SQUARE_SIZE + MARGIN // 2 - + # Calculate text position for centering try: bbox = draw.textbbox((0, 0), file_label, font=label_font) @@ -175,19 +200,19 @@ def generate_board_image(board: chess.Board, last_move: Optional[chess.Move] = N text_width, text_height = label_font.getsize(file_label) text_x = x - text_width // 2 text_y = y - text_height // 2 - + draw.text((text_x, text_y), file_label, fill=text_color, font=label_font) - + # Draw rank labels (1-8) along the side for rank in range(8): # Determine the correct rank label based on perspective display_rank = rank if perspective_white else 7 - rank rank_label = str(8 - display_rank) # Ranks go from 8 to 1 - + # Position for the rank label (left side) x = MARGIN // 2 y = MARGIN + display_rank * SQUARE_SIZE + SQUARE_SIZE // 2 - + # Calculate text position for centering try: bbox = draw.textbbox((0, 0), rank_label, font=label_font) @@ -203,28 +228,29 @@ def generate_board_image(board: chess.Board, last_move: Optional[chess.Move] = N text_width, text_height = label_font.getsize(rank_label) text_x = x - text_width // 2 text_y = y - text_height // 2 - + draw.text((text_x, text_y), rank_label, fill=text_color, font=label_font) # Save image to a bytes buffer img_byte_arr = io.BytesIO() - img.save(img_byte_arr, format='PNG') + img.save(img_byte_arr, format="PNG") img_byte_arr.seek(0) return discord.File(fp=img_byte_arr, filename="chess_board.png") + # --- Chess Game Modal for Move Input --- -class MoveInputModal(ui.Modal, title='Enter Your Move'): +class MoveInputModal(ui.Modal, title="Enter Your Move"): move_input = ui.TextInput( - label='Move (e.g., e4, Nf3, O-O)', - placeholder='Enter move in algebraic notation (SAN or UCI)', + label="Move (e.g., e4, Nf3, O-O)", + placeholder="Enter move in algebraic notation (SAN or UCI)", required=True, style=discord.TextStyle.short, - max_length=10 # e.g., e8=Q# is 5, allow some buffer + max_length=10, # e.g., e8=Q# is 5, allow some buffer ) - def __init__(self, game_view: Union['ChessView', 'ChessBotView']): - super().__init__(timeout=120.0) # 2 minute timeout for modal + def __init__(self, game_view: Union["ChessView", "ChessBotView"]): + super().__init__(timeout=120.0) # 2 minute timeout for modal self.game_view = game_view async def on_submit(self, interaction: discord.Interaction): @@ -237,7 +263,7 @@ class MoveInputModal(ui.Modal, title='Enter Your Move'): if not board.is_legal(move): await interaction.response.send_message( f"Illegal move: '{move_text}' is not valid in the current position.", - ephemeral=True + ephemeral=True, ) return except ValueError: @@ -246,13 +272,13 @@ class MoveInputModal(ui.Modal, title='Enter Your Move'): if not board.is_legal(move): await interaction.response.send_message( f"Illegal move: '{move_text}' is not valid in the current position.", - ephemeral=True + ephemeral=True, ) return except ValueError: await interaction.response.send_message( f"Invalid move format or illegal move: '{move_text}'. Use algebraic notation (e.g., Nf3, e4, O-O) or UCI (e.g., e2e4).", - ephemeral=True + ephemeral=True, ) return @@ -261,16 +287,18 @@ class MoveInputModal(ui.Modal, title='Enter Your Move'): # Try to provide the SAN representation of the attempted move for clarity try: move_san = board.san(move) - except ValueError: # If the move itself was fundamentally invalid (e.g., piece doesn't exist) - move_san = move_text # Fallback to user input + except ( + ValueError + ): # If the move itself was fundamentally invalid (e.g., piece doesn't exist) + move_san = move_text # Fallback to user input await interaction.response.send_message( f"Illegal move: '{move_san}' is not legal in the current position.", - ephemeral=True + ephemeral=True, ) return # Defer interaction here as move processing might take time (esp. for bot game) - await interaction.response.defer() # Acknowledge modal submission + await interaction.response.defer() # Acknowledge modal submission # Process the valid move in the respective view if isinstance(self.game_view, ChessView): @@ -282,44 +310,64 @@ class MoveInputModal(ui.Modal, title='Enter Your Move'): print(f"Error in MoveInputModal: {error}") try: if interaction.response.is_done(): - await interaction.followup.send("An error occurred submitting your move.", ephemeral=True) + await interaction.followup.send( + "An error occurred submitting your move.", ephemeral=True + ) else: - await interaction.response.send_message("An error occurred submitting your move.", ephemeral=True) + await interaction.response.send_message( + "An error occurred submitting your move.", ephemeral=True + ) except Exception as e: print(f"Failed to send error response in MoveInputModal: {e}") + # --- Chess Game (Player vs Player) --- class ChessView(ui.View): - def __init__(self, white_player: discord.Member, black_player: discord.Member, board: Optional[chess.Board] = None): + def __init__( + self, + white_player: discord.Member, + black_player: discord.Member, + board: Optional[chess.Board] = None, + ): super().__init__(timeout=600.0) # 10 minute timeout self.white_player = white_player self.black_player = black_player - self.board = board if board else chess.Board() # Use provided board or create new + self.board = ( + board if board else chess.Board() + ) # Use provided board or create new # Determine current player based on board state - self.current_player = self.white_player if self.board.turn == chess.WHITE else self.black_player + self.current_player = ( + self.white_player if self.board.turn == chess.WHITE else self.black_player + ) self.message: Optional[discord.Message] = None - self.last_move: Optional[chess.Move] = None # Store last move for highlighting - self.white_dm_message: Optional[discord.Message] = None # DM message for white player - self.black_dm_message: Optional[discord.Message] = None # DM message for black player - + self.last_move: Optional[chess.Move] = None # Store last move for highlighting + self.white_dm_message: Optional[discord.Message] = ( + None # DM message for white player + ) + self.black_dm_message: Optional[discord.Message] = ( + None # DM message for black player + ) + # Button-driven move selection state - self.move_selection_mode = False # Whether we're in button-driven move selection mode + self.move_selection_mode = ( + False # Whether we're in button-driven move selection mode + ) self.selected_file = None # Selected file (0-7) during move selection self.selected_rank = None # Selected rank (0-7) during move selection self.selected_square = None # Selected square (0-63) during move selection self.valid_moves = [] # List of valid moves from the selected square - self.game_pgn = chess.pgn.Game() # Initialize PGN game object + self.game_pgn = chess.pgn.Game() # Initialize PGN game object self.game_pgn.headers["Event"] = "Discord Chess Game" self.game_pgn.headers["Site"] = "Discord" self.game_pgn.headers["White"] = self.white_player.display_name self.game_pgn.headers["Black"] = self.black_player.display_name # If starting from a non-standard position, set FEN header and setup board if board: - self.game_pgn.setup(board) # Setup PGN from the board state - else: # Standard starting position - # Setup with the initial board state even if it's standard, ensures node exists - self.game_pgn.setup(self.board) - self.pgn_node = self.game_pgn # Track the current node for adding moves + self.game_pgn.setup(board) # Setup PGN from the board state + else: # Standard starting position + # Setup with the initial board state even if it's standard, ensures node exists + self.game_pgn.setup(self.board) + self.pgn_node = self.game_pgn # Track the current node for adding moves # Add control buttons self.add_item(self.MakeMoveButton()) @@ -330,123 +378,164 @@ class ChessView(ui.View): class MakeMoveButton(ui.Button): def __init__(self): - super().__init__(label="Make Move", style=discord.ButtonStyle.primary, custom_id="chess_make_move") + super().__init__( + label="Make Move", + style=discord.ButtonStyle.primary, + custom_id="chess_make_move", + ) async def callback(self, interaction: discord.Interaction): - view: 'ChessView' = self.view + view: "ChessView" = self.view # Check if it's the correct player's turn before showing modal if interaction.user != view.current_player: - await interaction.response.send_message("It's not your turn!", ephemeral=True) + await interaction.response.send_message( + "It's not your turn!", ephemeral=True + ) return # Open the modal for move input await interaction.response.send_modal(MoveInputModal(game_view=view)) class SelectMoveButton(ui.Button): """Button to start the button-driven move selection process.""" + def __init__(self): - super().__init__(label="Select Move", style=discord.ButtonStyle.primary, custom_id="chess_select_move") - + super().__init__( + label="Select Move", + style=discord.ButtonStyle.primary, + custom_id="chess_select_move", + ) + async def callback(self, interaction: discord.Interaction): - view: 'ChessView' = self.view + view: "ChessView" = self.view # Check if it's the correct player's turn if interaction.user != view.current_player: - await interaction.response.send_message("It's not your turn!", ephemeral=True) + await interaction.response.send_message( + "It's not your turn!", ephemeral=True + ) return - + # Start the move selection process view.move_selection_mode = True view.selected_file = None view.selected_rank = None view.selected_square = None view.valid_moves = [] - + # Show file selection buttons await view.show_file_selection(interaction) - + class ResignButton(ui.Button): def __init__(self): - super().__init__(label="Resign", style=discord.ButtonStyle.danger, custom_id="chess_resign") + super().__init__( + label="Resign", + style=discord.ButtonStyle.danger, + custom_id="chess_resign", + ) async def callback(self, interaction: discord.Interaction): - view: 'ChessView' = self.view + view: "ChessView" = self.view resigning_player = interaction.user # Check if the resigner is part of the game if resigning_player.id not in [view.white_player.id, view.black_player.id]: - await interaction.response.send_message("You are not part of this game.", ephemeral=True) - return - winner = view.black_player if resigning_player == view.white_player else view.white_player - await view.end_game(interaction, f"{resigning_player.mention} resigned. {winner.mention} wins! 🏳️") + await interaction.response.send_message( + "You are not part of this game.", ephemeral=True + ) + return + winner = ( + view.black_player + if resigning_player == view.white_player + else view.white_player + ) + await view.end_game( + interaction, + f"{resigning_player.mention} resigned. {winner.mention} wins! 🏳️", + ) # --- Button Classes for Move Selection --- - + class FileButton(ui.Button): """Button for selecting a file (A-H) in the first phase of move selection.""" + def __init__(self, file_idx: int): self.file_idx = file_idx file_label = chr(65 + file_idx) # 65 is ASCII for 'A' super().__init__(label=file_label, style=discord.ButtonStyle.primary) - + async def callback(self, interaction: discord.Interaction): - view: 'ChessView' = self.view - + view: "ChessView" = self.view + # Basic checks if interaction.user != view.current_player: - await interaction.response.send_message("It's not your turn!", ephemeral=True) + await interaction.response.send_message( + "It's not your turn!", ephemeral=True + ) return - + # Store the selected file and show rank buttons view.selected_file = self.file_idx view.selected_rank = None view.selected_square = None - + # Show rank selection buttons await view.show_rank_selection(interaction) - + class RankButton(ui.Button): """Button for selecting a rank (1-8) in the first phase of move selection.""" + def __init__(self, rank_idx: int): self.rank_idx = rank_idx rank_label = str(8 - rank_idx) # Ranks are displayed as 8 to 1 super().__init__(label=rank_label, style=discord.ButtonStyle.primary) - + async def callback(self, interaction: discord.Interaction): - view: 'ChessView' = self.view - + view: "ChessView" = self.view + # Basic checks if interaction.user != view.current_player: - await interaction.response.send_message("It's not your turn!", ephemeral=True) + await interaction.response.send_message( + "It's not your turn!", ephemeral=True + ) return - + # Calculate the square index file_idx = view.selected_file rank_idx = self.rank_idx - square = chess.square(file_idx, 7 - rank_idx) # Convert to chess.py square index - + square = chess.square( + file_idx, 7 - rank_idx + ) # Convert to chess.py square index + # Check if the square has a piece of the current player's color piece = view.board.piece_at(square) if piece is None or piece.color != view.board.turn: - await interaction.response.send_message("You must select a square with one of your pieces.", ephemeral=True) + await interaction.response.send_message( + "You must select a square with one of your pieces.", ephemeral=True + ) # Go back to file selection await view.show_file_selection(interaction) return - + # Find valid moves from this square - valid_moves = [move for move in view.board.legal_moves if move.from_square == square] + valid_moves = [ + move for move in view.board.legal_moves if move.from_square == square + ] if not valid_moves: - await interaction.response.send_message("This piece has no legal moves.", ephemeral=True) + await interaction.response.send_message( + "This piece has no legal moves.", ephemeral=True + ) # Go back to file selection await view.show_file_selection(interaction) return - + # Store the selected square and valid moves view.selected_square = square view.valid_moves = valid_moves - + # Show valid move buttons await view.show_valid_moves(interaction) - + class MoveButton(ui.Button): """Button for selecting a destination square in the second phase of move selection.""" + def __init__(self, move: chess.Move): self.move = move # Get the destination square coordinates @@ -455,212 +544,289 @@ class ChessView(ui.View): # Create label in algebraic notation (e.g., "e4") label = f"{chr(97 + file_idx)}{rank_idx + 1}" super().__init__(label=label, style=discord.ButtonStyle.success) - + async def callback(self, interaction: discord.Interaction): - view: 'ChessView' = self.view - + view: "ChessView" = self.view + # Basic checks if interaction.user != view.current_player: - await interaction.response.send_message("It's not your turn!", ephemeral=True) + await interaction.response.send_message( + "It's not your turn!", ephemeral=True + ) return - + # Execute the move await interaction.response.defer() # Acknowledge the interaction await view.handle_move(interaction, self.move) - + # --- Button-Driven Move Selection Methods --- - + async def show_file_selection(self, interaction: discord.Interaction): """Shows buttons for selecting a file (A-H).""" # Clear existing buttons self.clear_items() - + # Add file selection buttons (A-H) for file_idx in range(8): self.add_item(self.FileButton(file_idx)) - + # Add a cancel button to return to normal view - cancel_button = ui.Button(label="Cancel", style=discord.ButtonStyle.secondary, custom_id="cancel_move_selection") + cancel_button = ui.Button( + label="Cancel", + style=discord.ButtonStyle.secondary, + custom_id="cancel_move_selection", + ) cancel_button.callback = self._cancel_move_selection_callback self.add_item(cancel_button) - + # Update the message turn_color = "White" if self.board.turn == chess.WHITE else "Black" content = f"Chess: {self.white_player.mention} (White) vs {self.black_player.mention} (Black)\n\nSelect a file (A-H) to choose a piece.\nTurn: **{self.current_player.mention}** ({turn_color})" - board_image = generate_board_image(self.board, self.last_move, perspective_white=(self.current_player == self.white_player)) - + board_image = generate_board_image( + self.board, + self.last_move, + perspective_white=(self.current_player == self.white_player), + ) + if interaction.response.is_done(): - await interaction.edit_original_response(content=content, attachments=[board_image], view=self) + await interaction.edit_original_response( + content=content, attachments=[board_image], view=self + ) else: - await interaction.response.edit_message(content=content, attachments=[board_image], view=self) - + await interaction.response.edit_message( + content=content, attachments=[board_image], view=self + ) + async def show_rank_selection(self, interaction: discord.Interaction): """Shows buttons for selecting a rank (1-8).""" # Clear existing buttons self.clear_items() - + # Add rank selection buttons (1-8) for rank_idx in range(8): self.add_item(self.RankButton(rank_idx)) - + # Add a back button to return to file selection - back_button = ui.Button(label="Back", style=discord.ButtonStyle.secondary, custom_id="back_to_file_selection") + back_button = ui.Button( + label="Back", + style=discord.ButtonStyle.secondary, + custom_id="back_to_file_selection", + ) back_button.callback = self._back_to_file_selection_callback self.add_item(back_button) - + # Add a cancel button to return to normal view - cancel_button = ui.Button(label="Cancel", style=discord.ButtonStyle.secondary, custom_id="cancel_move_selection") + cancel_button = ui.Button( + label="Cancel", + style=discord.ButtonStyle.secondary, + custom_id="cancel_move_selection", + ) cancel_button.callback = self._cancel_move_selection_callback self.add_item(cancel_button) - + # Update the message turn_color = "White" if self.board.turn == chess.WHITE else "Black" file_letter = chr(65 + self.selected_file) # Convert to A-H content = f"Chess: {self.white_player.mention} (White) vs {self.black_player.mention} (Black)\n\nSelected file {file_letter}. Now select a rank (1-8).\nTurn: **{self.current_player.mention}** ({turn_color})" - board_image = generate_board_image(self.board, self.last_move, perspective_white=(self.current_player == self.white_player)) - + board_image = generate_board_image( + self.board, + self.last_move, + perspective_white=(self.current_player == self.white_player), + ) + if interaction.response.is_done(): - await interaction.edit_original_response(content=content, attachments=[board_image], view=self) + await interaction.edit_original_response( + content=content, attachments=[board_image], view=self + ) else: - await interaction.response.edit_message(content=content, attachments=[board_image], view=self) - + await interaction.response.edit_message( + content=content, attachments=[board_image], view=self + ) + async def show_valid_moves(self, interaction: discord.Interaction): """Shows buttons for selecting a destination square from valid moves.""" # Clear existing buttons self.clear_items() - + # Add buttons for each valid move for move in self.valid_moves: self.add_item(self.MoveButton(move)) - + # Add a back button to return to file selection - back_button = ui.Button(label="Back", style=discord.ButtonStyle.secondary, custom_id="back_to_file_selection") + back_button = ui.Button( + label="Back", + style=discord.ButtonStyle.secondary, + custom_id="back_to_file_selection", + ) back_button.callback = self._back_to_file_selection_callback self.add_item(back_button) - + # Add a cancel button to return to normal view - cancel_button = ui.Button(label="Cancel", style=discord.ButtonStyle.secondary, custom_id="cancel_move_selection") + cancel_button = ui.Button( + label="Cancel", + style=discord.ButtonStyle.secondary, + custom_id="cancel_move_selection", + ) cancel_button.callback = self._cancel_move_selection_callback self.add_item(cancel_button) - + # Update the message with valid move dots turn_color = "White" if self.board.turn == chess.WHITE else "Black" file_letter = chr(65 + self.selected_file) # Convert to A-H rank_number = 8 - chess.square_rank(self.selected_square) # Convert to 1-8 content = f"Chess: {self.white_player.mention} (White) vs {self.black_player.mention} (Black)\n\nSelected piece at {file_letter}{rank_number}. Choose a destination square.\nTurn: **{self.current_player.mention}** ({turn_color})" board_image = generate_board_image( - self.board, - self.last_move, + self.board, + self.last_move, perspective_white=(self.current_player == self.white_player), - valid_moves=self.valid_moves + valid_moves=self.valid_moves, ) - + if interaction.response.is_done(): - await interaction.edit_original_response(content=content, attachments=[board_image], view=self) + await interaction.edit_original_response( + content=content, attachments=[board_image], view=self + ) else: - await interaction.response.edit_message(content=content, attachments=[board_image], view=self) - + await interaction.response.edit_message( + content=content, attachments=[board_image], view=self + ) + async def _back_to_file_selection_callback(self, interaction: discord.Interaction): """Callback for the 'Back' button to return to file selection.""" if interaction.user != self.current_player: - await interaction.response.send_message("It's not your turn!", ephemeral=True) + await interaction.response.send_message( + "It's not your turn!", ephemeral=True + ) return await self.show_file_selection(interaction) - + async def _cancel_move_selection_callback(self, interaction: discord.Interaction): """Callback for the 'Cancel' button to exit move selection mode.""" if interaction.user != self.current_player: - await interaction.response.send_message("It's not your turn!", ephemeral=True) + await interaction.response.send_message( + "It's not your turn!", ephemeral=True + ) return - + # Reset move selection state self.move_selection_mode = False self.selected_file = None self.selected_rank = None self.selected_square = None self.valid_moves = [] - + # Restore normal view self.clear_items() self.add_item(self.MakeMoveButton()) self.add_item(self.SelectMoveButton()) self.add_item(self.ResignButton()) - + # Update the message await self.update_message(interaction, "Move selection cancelled. ") - + # --- Helper Methods --- async def interaction_check(self, interaction: discord.Interaction) -> bool: """Checks are now mostly handled within button callbacks for clarity.""" # Basic check: is the user part of the game? if interaction.user.id not in [self.white_player.id, self.black_player.id]: - await interaction.response.send_message("You are not part of this game.", ephemeral=True) + await interaction.response.send_message( + "You are not part of this game.", ephemeral=True + ) return False # Specific turn checks are done in MakeMoveButton callback and MoveInputModal submission return True - async def _get_dm_content(self, player_perspective: discord.Member, result: Optional[str] = None) -> str: + async def _get_dm_content( + self, player_perspective: discord.Member, result: Optional[str] = None + ) -> str: """Generates the FEN and PGN content for the DM from a specific player's perspective.""" fen = self.board.fen() - opponent = self.black_player if player_perspective == self.white_player else self.white_player - opponent_color_str = "Black" if player_perspective == self.white_player else "White" + opponent = ( + self.black_player + if player_perspective == self.white_player + else self.white_player + ) + opponent_color_str = ( + "Black" if player_perspective == self.white_player else "White" + ) # Update PGN headers if result is provided and game is over if result: - pgn_result_code = "*" # Default for ongoing or unknown + pgn_result_code = "*" # Default for ongoing or unknown if result in ["1-0", "0-1", "1/2-1/2"]: pgn_result_code = result elif "wins" in result: - if self.white_player.mention in result: pgn_result_code = "1-0" - elif self.black_player.mention in result: pgn_result_code = "0-1" + if self.white_player.mention in result: + pgn_result_code = "1-0" + elif self.black_player.mention in result: + pgn_result_code = "0-1" elif "draw" in result: pgn_result_code = "1/2-1/2" # Only update if not already set or if changing from '*' - if "Result" not in self.game_pgn.headers or self.game_pgn.headers["Result"] == "*": - self.game_pgn.headers["Result"] = pgn_result_code + if ( + "Result" not in self.game_pgn.headers + or self.game_pgn.headers["Result"] == "*" + ): + self.game_pgn.headers["Result"] = pgn_result_code # Use an exporter for cleaner PGN output - exporter = chess.pgn.StringExporter(headers=True, variations=True, comments=True) + exporter = chess.pgn.StringExporter( + headers=True, variations=True, comments=True + ) pgn_string = self.game_pgn.accept(exporter) # Limit PGN length in DM preview - pgn_preview = pgn_string[:1500] + "..." if len(pgn_string) > 1500 else pgn_string + pgn_preview = ( + pgn_string[:1500] + "..." if len(pgn_string) > 1500 else pgn_string + ) - content = f"Use `/loadchess` to restore this game from FEN or PGN.\n\n" \ - f"**Game vs {opponent.display_name}** ({opponent_color_str})\n\n" \ - f"**FEN:**\n`{fen}`\n\n" \ - f"**PGN:**\n```pgn\n{pgn_preview}\n```" + content = ( + f"Use `/loadchess` to restore this game from FEN or PGN.\n\n" + f"**Game vs {opponent.display_name}** ({opponent_color_str})\n\n" + f"**FEN:**\n`{fen}`\n\n" + f"**PGN:**\n```pgn\n{pgn_preview}\n```" + ) if result: - content += f"\n\n**Status:** {result}" # Always show the descriptive status message + content += f"\n\n**Status:** {result}" # Always show the descriptive status message return content - async def _send_or_update_dm(self, player: discord.Member, result: Optional[str] = None): + async def _send_or_update_dm( + self, player: discord.Member, result: Optional[str] = None + ): """Sends or updates the DM with FEN and PGN for a specific player.""" - is_white = (player == self.white_player) + is_white = player == self.white_player dm_message_attr = "white_dm_message" if is_white else "black_dm_message" dm_message: Optional[discord.Message] = getattr(self, dm_message_attr, None) try: - content = await self._get_dm_content(player_perspective=player, result=result) + content = await self._get_dm_content( + player_perspective=player, result=result + ) dm_channel = player.dm_channel or await player.create_dm() if dm_message: try: await dm_message.edit(content=content) # print(f"Successfully edited DM for {player.display_name}") # Debug - return # Edited successfully + return # Edited successfully except discord.NotFound: - print(f"DM message for {player.display_name} not found, will send a new one.") + print( + f"DM message for {player.display_name} not found, will send a new one." + ) setattr(self, dm_message_attr, None) dm_message = None except discord.Forbidden: - print(f"Cannot edit DM for {player.display_name} (Forbidden). DMs might be closed or message deleted.") + print( + f"Cannot edit DM for {player.display_name} (Forbidden). DMs might be closed or message deleted." + ) setattr(self, dm_message_attr, None) dm_message = None except discord.HTTPException as e: - print(f"HTTP error editing DM for {player.display_name}: {e}. Will try sending.") + print( + f"HTTP error editing DM for {player.display_name}: {e}. Will try sending." + ) setattr(self, dm_message_attr, None) dm_message = None @@ -670,29 +836,37 @@ class ChessView(ui.View): # print(f"Successfully sent new DM to {player.display_name}") # Debug except discord.Forbidden: - print(f"Cannot send DM to {player.display_name} (Forbidden). User likely has DMs disabled.") + print( + f"Cannot send DM to {player.display_name} (Forbidden). User likely has DMs disabled." + ) setattr(self, dm_message_attr, None) except discord.HTTPException as e: print(f"Failed to send/edit DM for {player.display_name}: {e}") setattr(self, dm_message_attr, None) except Exception as e: - print(f"Unexpected error sending/updating DM for {player.display_name}: {e}") + print( + f"Unexpected error sending/updating DM for {player.display_name}: {e}" + ) setattr(self, dm_message_attr, None) async def handle_move(self, interaction: discord.Interaction, move: chess.Move): """Handles a validated legal move submitted via the modal.""" self.board.push(move) - self.last_move = move # Store for highlighting + self.last_move = move # Store for highlighting # Switch turns - self.current_player = self.black_player if self.current_player == self.white_player else self.white_player + self.current_player = ( + self.black_player + if self.current_player == self.white_player + else self.white_player + ) # Check for game end outcome = self.board.outcome() if outcome: await self.end_game(interaction, self.get_game_over_message(outcome)) return - + # Restore default buttons before updating message self.clear_items() self.add_item(self.MakeMoveButton()) @@ -702,28 +876,44 @@ class ChessView(ui.View): # Update the message with the new board state await self.update_message(interaction) - async def update_message(self, interaction_or_message: Union[discord.Interaction, discord.Message], status_prefix: str = ""): + async def update_message( + self, + interaction_or_message: Union[discord.Interaction, discord.Message], + status_prefix: str = "", + ): """Updates the game message with the current board image and status.""" turn_color = "White" if self.board.turn == chess.WHITE else "Black" - status = f"{status_prefix}Turn: **{self.current_player.mention}** ({turn_color})" + status = ( + f"{status_prefix}Turn: **{self.current_player.mention}** ({turn_color})" + ) if self.board.is_check(): status += " **Check!**" fen_string = self.board.fen() content = f"Chess: {self.white_player.mention} (White) vs {self.black_player.mention} (Black)\n\n{status}\nFEN: `{fen_string}`" - board_image = generate_board_image(self.board, self.last_move, perspective_white=(self.current_player == self.white_player)) + board_image = generate_board_image( + self.board, + self.last_move, + perspective_white=(self.current_player == self.white_player), + ) # Determine how to edit the message try: if isinstance(interaction_or_message, discord.Interaction): # If interaction hasn't been responded to (e.g., initial send) if not interaction_or_message.response.is_done(): - await interaction_or_message.response.edit_message(content=content, attachments=[board_image], view=self) + await interaction_or_message.response.edit_message( + content=content, attachments=[board_image], view=self + ) # If interaction was deferred (e.g., after modal submit) else: - await interaction_or_message.edit_original_response(content=content, attachments=[board_image], view=self) + await interaction_or_message.edit_original_response( + content=content, attachments=[board_image], view=self + ) elif isinstance(interaction_or_message, discord.Message): - await interaction_or_message.edit(content=content, attachments=[board_image], view=self) + await interaction_or_message.edit( + content=content, attachments=[board_image], view=self + ) except (discord.NotFound, discord.HTTPException) as e: print(f"ChessView: Failed to update message: {e}") # Handle potential errors like message deleted or permissions lost @@ -736,14 +926,14 @@ class ChessView(ui.View): elif outcome.winner == chess.BLACK: winner_mention = self.black_player.mention loser_mention = self.white_player.mention - else: # Draw - winner_mention = "Nobody" # Or maybe mention both? + else: # Draw + winner_mention = "Nobody" # Or maybe mention both? termination_reason = outcome.termination.name.replace("_", " ").title() if outcome.winner is not None: message = f"Game Over! **{winner_mention}** ({'White' if outcome.winner == chess.WHITE else 'Black'}) wins by {termination_reason}! 🎉" - else: # Draw + else: # Draw message = f"Game Over! It's a draw by {termination_reason}! 🤝" return message @@ -755,44 +945,60 @@ class ChessView(ui.View): # Update DMs with the final result dm_update_tasks = [ self._send_or_update_dm(self.white_player, result=message_content), - self._send_or_update_dm(self.black_player, result=message_content) + self._send_or_update_dm(self.black_player, result=message_content), ] await asyncio.gather(*dm_update_tasks) # Generate the final board image - ensure it's properly created - board_image = generate_board_image(self.board, self.last_move, perspective_white=True) # Final board perspective + board_image = generate_board_image( + self.board, self.last_move, perspective_white=True + ) # Final board perspective try: if interaction.response.is_done(): # If interaction was already responded to, use followup try: - await interaction.followup.send(content=message_content, file=board_image) + await interaction.followup.send( + content=message_content, file=board_image + ) except discord.HTTPException as e: print(f"Failed to send followup: {e}") # Fallback to channel send if followup fails if interaction.channel: - await interaction.channel.send(content=message_content, file=board_image) + await interaction.channel.send( + content=message_content, file=board_image + ) else: # Edit the interaction response if still valid try: - await interaction.response.edit_message(content=message_content, attachments=[board_image], view=self) + await interaction.response.edit_message( + content=message_content, attachments=[board_image], view=self + ) except discord.HTTPException as e: print(f"Failed to edit message: {e}") # Fallback to sending a new message if interaction.channel: - await interaction.channel.send(content=message_content, file=board_image) + await interaction.channel.send( + content=message_content, file=board_image + ) except discord.NotFound: # If the original message is gone, send a new message if interaction.channel: - await interaction.channel.send(content=message_content, file=board_image) + await interaction.channel.send( + content=message_content, file=board_image + ) except Exception as e: print(f"ChessView: Failed to edit or send game end message: {e}") # Last resort fallback - try to send a message to the channel if we can access it try: if interaction.channel: - await interaction.channel.send(content=message_content, file=board_image) + await interaction.channel.send( + content=message_content, file=board_image + ) elif self.message and self.message.channel: - await self.message.channel.send(content=message_content, file=board_image) + await self.message.channel.send( + content=message_content, file=board_image + ) except Exception as inner_e: print(f"Final fallback also failed: {inner_e}") @@ -808,52 +1014,89 @@ class ChessView(ui.View): if self.message and not self.is_finished(): await self.disable_all_buttons() timeout_msg = f"Chess game between {self.white_player.mention} and {self.black_player.mention} timed out." - board_image = generate_board_image(self.board, self.last_move, perspective_white=True) # Default perspective on timeout + board_image = generate_board_image( + self.board, self.last_move, perspective_white=True + ) # Default perspective on timeout try: - await self.message.edit(content=timeout_msg, attachments=[board_image], view=self) + await self.message.edit( + content=timeout_msg, attachments=[board_image], view=self + ) except (discord.NotFound, discord.Forbidden, discord.HTTPException): - pass # Ignore if message is gone or cannot be edited + pass # Ignore if message is gone or cannot be edited self.stop() + # --- Chess Bot Game --- + # Define paths relative to the script location for better portability def get_stockfish_path(): """Returns the appropriate Stockfish path based on the OS.""" SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) - PROJECT_ROOT = os.path.dirname(os.path.dirname(SCRIPT_DIR)) # Go up two levels from games dir - - STOCKFISH_PATH_WINDOWS = os.path.join(PROJECT_ROOT, "stockfish-windows-x86-64-avx2", "stockfish", "stockfish-windows-x86-64-avx2.exe") - STOCKFISH_PATH_LINUX = os.path.join(PROJECT_ROOT, "stockfish-ubuntu-x86-64-avx2", "stockfish", "stockfish-ubuntu-x86-64-avx2") - + PROJECT_ROOT = os.path.dirname( + os.path.dirname(SCRIPT_DIR) + ) # Go up two levels from games dir + + STOCKFISH_PATH_WINDOWS = os.path.join( + PROJECT_ROOT, + "stockfish-windows-x86-64-avx2", + "stockfish", + "stockfish-windows-x86-64-avx2.exe", + ) + STOCKFISH_PATH_LINUX = os.path.join( + PROJECT_ROOT, + "stockfish-ubuntu-x86-64-avx2", + "stockfish", + "stockfish-ubuntu-x86-64-avx2", + ) + system = platform.system() if system == "Windows": if not os.path.exists(STOCKFISH_PATH_WINDOWS): - raise FileNotFoundError(f"Stockfish not found at expected Windows path: {STOCKFISH_PATH_WINDOWS}") + raise FileNotFoundError( + f"Stockfish not found at expected Windows path: {STOCKFISH_PATH_WINDOWS}" + ) return STOCKFISH_PATH_WINDOWS elif system == "Linux": # Check for execute permissions on Linux if not os.path.exists(STOCKFISH_PATH_LINUX): - raise FileNotFoundError(f"Stockfish not found at expected Linux path: {STOCKFISH_PATH_LINUX}") + raise FileNotFoundError( + f"Stockfish not found at expected Linux path: {STOCKFISH_PATH_LINUX}" + ) if not os.access(STOCKFISH_PATH_LINUX, os.X_OK): - print(f"Warning: Stockfish at {STOCKFISH_PATH_LINUX} does not have execute permissions. Attempting to set...") - try: - os.chmod(STOCKFISH_PATH_LINUX, 0o755) # Add execute permissions - if not os.access(STOCKFISH_PATH_LINUX, os.X_OK): # Check again - raise OSError(f"Failed to set execute permissions for Stockfish at {STOCKFISH_PATH_LINUX}") - except Exception as e: - raise OSError(f"Error setting execute permissions for Stockfish: {e}") + print( + f"Warning: Stockfish at {STOCKFISH_PATH_LINUX} does not have execute permissions. Attempting to set..." + ) + try: + os.chmod(STOCKFISH_PATH_LINUX, 0o755) # Add execute permissions + if not os.access(STOCKFISH_PATH_LINUX, os.X_OK): # Check again + raise OSError( + f"Failed to set execute permissions for Stockfish at {STOCKFISH_PATH_LINUX}" + ) + except Exception as e: + raise OSError(f"Error setting execute permissions for Stockfish: {e}") return STOCKFISH_PATH_LINUX else: raise OSError(f"Unsupported operating system '{system}' for Stockfish.") -class ChessBotButton(ui.Button['ChessBotView']): + +class ChessBotButton(ui.Button["ChessBotView"]): def __init__(self, x: int, y: int, piece_symbol: Optional[str] = None): # Unicode chess pieces self.pieces = { - 'r': '♜', 'n': '♞', 'b': '♝', 'q': '♛', 'k': '♚', 'p': '♟', - 'R': '♖', 'N': '♘', 'B': '♗', 'Q': '♕', 'K': '♔', 'P': '♙', - None: ' ' # Use a space for empty squares + "r": "♜", + "n": "♞", + "b": "♝", + "q": "♛", + "k": "♚", + "p": "♟", + "R": "♖", + "N": "♘", + "B": "♗", + "Q": "♕", + "K": "♔", + "P": "♙", + None: " ", # Use a space for empty squares } self.x = x self.y = y @@ -861,10 +1104,14 @@ class ChessBotButton(ui.Button['ChessBotView']): # Set button style and label based on square color is_dark = (x + y) % 2 != 0 - style = discord.ButtonStyle.secondary if is_dark else discord.ButtonStyle.primary - label = self.pieces.get(piece_symbol, ' ') # Get piece representation or space + style = ( + discord.ButtonStyle.secondary if is_dark else discord.ButtonStyle.primary + ) + label = self.pieces.get(piece_symbol, " ") # Get piece representation or space # REMOVED row=y parameter - super().__init__(style=style, label=label if label != ' ' else '') # Use em-space for empty squares + super().__init__( + style=style, label=label if label != " " else "" + ) # Use em-space for empty squares async def callback(self, interaction: discord.Interaction): assert self.view is not None @@ -872,43 +1119,84 @@ class ChessBotButton(ui.Button['ChessBotView']): # Check if it's the player's turn and the engine is ready if interaction.user != view.player: - await interaction.response.send_message("This is not your game!", ephemeral=True) - return + await interaction.response.send_message( + "This is not your game!", ephemeral=True + ) + return if view.board.turn != view.player_color: - await interaction.response.send_message("It's not your turn!", ephemeral=True) + await interaction.response.send_message( + "It's not your turn!", ephemeral=True + ) return if view.engine is None or view.is_thinking: - await interaction.response.send_message("Please wait for the bot to finish thinking or start.", ephemeral=True) + await interaction.response.send_message( + "Please wait for the bot to finish thinking or start.", ephemeral=True + ) return # Process the move await view.handle_square_click(interaction, self.x, self.y) + class ChessBotView(ui.View): # Maps skill level (0-20) to typical ELO ratings for context SKILL_ELO_MAP = { - 0: 800, 1: 900, 2: 1000, 3: 1100, 4: 1200, 5: 1300, 6: 1400, 7: 1500, 8: 1600, 9: 1700, - 10: 1800, 11: 1900, 12: 2000, 13: 2100, 14: 2200, 15: 2300, 16: 2400, 17: 2500, 18: 2600, - 19: 2700, 20: 2800 + 0: 800, + 1: 900, + 2: 1000, + 3: 1100, + 4: 1200, + 5: 1300, + 6: 1400, + 7: 1500, + 8: 1600, + 9: 1700, + 10: 1800, + 11: 1900, + 12: 2000, + 13: 2100, + 14: 2200, + 15: 2300, + 16: 2400, + 17: 2500, + 18: 2600, + 19: 2700, + 20: 2800, } - def __init__(self, player: discord.Member, player_color: chess.Color, variant: str = "standard", skill_level: int = 10, think_time: float = 1.0, board: Optional[chess.Board] = None): + def __init__( + self, + player: discord.Member, + player_color: chess.Color, + variant: str = "standard", + skill_level: int = 10, + think_time: float = 1.0, + board: Optional[chess.Board] = None, + ): super().__init__(timeout=900.0) # 15 minute timeout self.player = player - self.player_color = player_color # The color the human player chose to play as + self.player_color = player_color # The color the human player chose to play as self.bot_color = not player_color self.variant = variant.lower() self.message: Optional[discord.Message] = None - self.engine: Optional[chess.engine.UciProtocol] = None # Use the async UciProtocol - self._engine_transport: Optional[asyncio.SubprocessTransport] = None # Store transport for closing - self.skill_level = max(0, min(20, skill_level)) # Clamp skill level - self.think_time = max(0.1, min(5.0, think_time)) # Clamp think time - self.is_thinking = False # Flag to prevent interaction during bot's turn - self.last_move: Optional[chess.Move] = None # Store last move for highlighting - self.player_dm_message: Optional[discord.Message] = None # DM message for the player - + self.engine: Optional[chess.engine.UciProtocol] = ( + None # Use the async UciProtocol + ) + self._engine_transport: Optional[asyncio.SubprocessTransport] = ( + None # Store transport for closing + ) + self.skill_level = max(0, min(20, skill_level)) # Clamp skill level + self.think_time = max(0.1, min(5.0, think_time)) # Clamp think time + self.is_thinking = False # Flag to prevent interaction during bot's turn + self.last_move: Optional[chess.Move] = None # Store last move for highlighting + self.player_dm_message: Optional[discord.Message] = ( + None # DM message for the player + ) + # Button-driven move selection state - self.move_selection_mode = False # Whether we're in button-driven move selection mode + self.move_selection_mode = ( + False # Whether we're in button-driven move selection mode + ) self.selected_file = None # Selected file (0-7) during move selection self.selected_rank = None # Selected rank (0-7) during move selection self.selected_square = None # Selected square (0-63) during move selection @@ -925,22 +1213,32 @@ class ChessBotView(ui.View): if self.variant == "chess960": self.board = chess.Board(chess960=True) self.initial_fen = self.board.fen() - else: # Standard chess + else: # Standard chess self.board = chess.Board() self.initial_fen = None # Initialize PGN tracking self.game_pgn = chess.pgn.Game() - self.game_pgn.headers["Event"] = f"Discord Chess Bot Game (Skill {self.skill_level})" + self.game_pgn.headers["Event"] = ( + f"Discord Chess Bot Game (Skill {self.skill_level})" + ) self.game_pgn.headers["Site"] = "Discord" - self.game_pgn.headers["White"] = player.display_name if player_color == chess.WHITE else f"Bot (Skill {self.skill_level})" - self.game_pgn.headers["Black"] = player.display_name if player_color == chess.BLACK else f"Bot (Skill {self.skill_level})" + self.game_pgn.headers["White"] = ( + player.display_name + if player_color == chess.WHITE + else f"Bot (Skill {self.skill_level})" + ) + self.game_pgn.headers["Black"] = ( + player.display_name + if player_color == chess.BLACK + else f"Bot (Skill {self.skill_level})" + ) # If starting from a non-standard position (loaded board), set up PGN if board: self.game_pgn.setup(board) else: - self.game_pgn.setup(self.board) # Setup even for standard start - self.pgn_node = self.game_pgn # Start at the root node + self.game_pgn.setup(self.board) # Setup even for standard start + self.pgn_node = self.game_pgn # Start at the root node # Add control buttons self.add_item(self.MakeMoveButton()) @@ -948,87 +1246,110 @@ class ChessBotView(ui.View): self.add_item(self.ResignButton()) # --- Button Definitions --- - + class FileButton(ui.Button): """Button for selecting a file (A-H) in the first phase of move selection.""" + def __init__(self, file_idx: int): self.file_idx = file_idx file_label = chr(65 + file_idx) # 65 is ASCII for 'A' super().__init__(label=file_label, style=discord.ButtonStyle.primary) - + async def callback(self, interaction: discord.Interaction): - view: 'ChessBotView' = self.view - + view: "ChessBotView" = self.view + # Basic checks if interaction.user != view.player: - await interaction.response.send_message("This is not your game!", ephemeral=True) + await interaction.response.send_message( + "This is not your game!", ephemeral=True + ) return if view.board.turn != view.player_color: - await interaction.response.send_message("It's not your turn!", ephemeral=True) + await interaction.response.send_message( + "It's not your turn!", ephemeral=True + ) return if view.is_thinking: - await interaction.response.send_message("The bot is thinking, please wait.", ephemeral=True) + await interaction.response.send_message( + "The bot is thinking, please wait.", ephemeral=True + ) return - + # Store the selected file and show rank buttons view.selected_file = self.file_idx view.selected_rank = None view.selected_square = None - + # Show rank selection buttons await view.show_rank_selection(interaction) - + class RankButton(ui.Button): """Button for selecting a rank (1-8) in the first phase of move selection.""" + def __init__(self, rank_idx: int): self.rank_idx = rank_idx rank_label = str(8 - rank_idx) # Ranks are displayed as 8 to 1 super().__init__(label=rank_label, style=discord.ButtonStyle.primary) - + async def callback(self, interaction: discord.Interaction): - view: 'ChessBotView' = self.view - + view: "ChessBotView" = self.view + # Basic checks if interaction.user != view.player: - await interaction.response.send_message("This is not your game!", ephemeral=True) + await interaction.response.send_message( + "This is not your game!", ephemeral=True + ) return if view.board.turn != view.player_color: - await interaction.response.send_message("It's not your turn!", ephemeral=True) + await interaction.response.send_message( + "It's not your turn!", ephemeral=True + ) return if view.is_thinking: - await interaction.response.send_message("The bot is thinking, please wait.", ephemeral=True) + await interaction.response.send_message( + "The bot is thinking, please wait.", ephemeral=True + ) return - + # Calculate the square index file_idx = view.selected_file rank_idx = self.rank_idx - square = chess.square(file_idx, 7 - rank_idx) # Convert to chess.py square index - + square = chess.square( + file_idx, 7 - rank_idx + ) # Convert to chess.py square index + # Check if the square has a piece of the player's color piece = view.board.piece_at(square) if piece is None or piece.color != view.player_color: - await interaction.response.send_message("You must select a square with one of your pieces.", ephemeral=True) + await interaction.response.send_message( + "You must select a square with one of your pieces.", ephemeral=True + ) # Go back to file selection await view.show_file_selection(interaction) return - + # Find valid moves from this square - valid_moves = [move for move in view.board.legal_moves if move.from_square == square] + valid_moves = [ + move for move in view.board.legal_moves if move.from_square == square + ] if not valid_moves: - await interaction.response.send_message("This piece has no legal moves.", ephemeral=True) + await interaction.response.send_message( + "This piece has no legal moves.", ephemeral=True + ) # Go back to file selection await view.show_file_selection(interaction) return - + # Store the selected square and valid moves view.selected_square = square view.valid_moves = valid_moves - + # Show valid move buttons await view.show_valid_moves(interaction) - + class MoveButton(ui.Button): """Button for selecting a destination square in the second phase of move selection.""" + def __init__(self, move: chess.Move): self.move = move # Get the destination square coordinates @@ -1037,93 +1358,134 @@ class ChessBotView(ui.View): # Create label in algebraic notation (e.g., "e4") label = f"{chr(97 + file_idx)}{rank_idx + 1}" super().__init__(label=label, style=discord.ButtonStyle.success) - + async def callback(self, interaction: discord.Interaction): - view: 'ChessBotView' = self.view - + view: "ChessBotView" = self.view + # Basic checks if interaction.user != view.player: - await interaction.response.send_message("This is not your game!", ephemeral=True) + await interaction.response.send_message( + "This is not your game!", ephemeral=True + ) return if view.board.turn != view.player_color: - await interaction.response.send_message("It's not your turn!", ephemeral=True) + await interaction.response.send_message( + "It's not your turn!", ephemeral=True + ) return if view.is_thinking: - await interaction.response.send_message("The bot is thinking, please wait.", ephemeral=True) + await interaction.response.send_message( + "The bot is thinking, please wait.", ephemeral=True + ) return - + # Execute the move await interaction.response.defer() # Acknowledge the interaction await view.handle_player_move(interaction, self.move) - + class MakeMoveButton(ui.Button): def __init__(self): - super().__init__(label="Make Move", style=discord.ButtonStyle.primary, custom_id="chessbot_make_move") + super().__init__( + label="Make Move", + style=discord.ButtonStyle.primary, + custom_id="chessbot_make_move", + ) async def callback(self, interaction: discord.Interaction): - view: 'ChessBotView' = self.view + view: "ChessBotView" = self.view # Check turn and thinking state if interaction.user != view.player: - await interaction.response.send_message("This is not your game!", ephemeral=True) - return + await interaction.response.send_message( + "This is not your game!", ephemeral=True + ) + return if view.board.turn != view.player_color: - await interaction.response.send_message("It's not your turn!", ephemeral=True) + await interaction.response.send_message( + "It's not your turn!", ephemeral=True + ) return if view.is_thinking: - await interaction.response.send_message("The bot is thinking, please wait.", ephemeral=True) + await interaction.response.send_message( + "The bot is thinking, please wait.", ephemeral=True + ) return if view.engine is None: - await interaction.response.send_message("The engine is not running.", ephemeral=True) - return - if view.is_thinking: # Added check here as well - await interaction.response.send_message("The bot is thinking, please wait.", ephemeral=True) + await interaction.response.send_message( + "The engine is not running.", ephemeral=True + ) + return + if view.is_thinking: # Added check here as well + await interaction.response.send_message( + "The bot is thinking, please wait.", ephemeral=True + ) return # Open the modal for move input await interaction.response.send_modal(MoveInputModal(game_view=view)) - + class SelectMoveButton(ui.Button): """Button to start the button-driven move selection process.""" + def __init__(self): - super().__init__(label="Select Move", style=discord.ButtonStyle.primary, custom_id="chessbot_select_move") - + super().__init__( + label="Select Move", + style=discord.ButtonStyle.primary, + custom_id="chessbot_select_move", + ) + async def callback(self, interaction: discord.Interaction): - view: 'ChessBotView' = self.view + view: "ChessBotView" = self.view # Check turn and thinking state if interaction.user != view.player: - await interaction.response.send_message("This is not your game!", ephemeral=True) + await interaction.response.send_message( + "This is not your game!", ephemeral=True + ) return if view.board.turn != view.player_color: - await interaction.response.send_message("It's not your turn!", ephemeral=True) + await interaction.response.send_message( + "It's not your turn!", ephemeral=True + ) return if view.is_thinking: - await interaction.response.send_message("The bot is thinking, please wait.", ephemeral=True) + await interaction.response.send_message( + "The bot is thinking, please wait.", ephemeral=True + ) return if view.engine is None: - await interaction.response.send_message("The engine is not running.", ephemeral=True) + await interaction.response.send_message( + "The engine is not running.", ephemeral=True + ) return - + # Start the move selection process view.move_selection_mode = True view.selected_file = None view.selected_rank = None view.selected_square = None view.valid_moves = [] - + # Show file selection buttons await view.show_file_selection(interaction) class ResignButton(ui.Button): def __init__(self): - super().__init__(label="Resign", style=discord.ButtonStyle.danger, custom_id="chessbot_resign") + super().__init__( + label="Resign", + style=discord.ButtonStyle.danger, + custom_id="chessbot_resign", + ) async def callback(self, interaction: discord.Interaction): - view: 'ChessBotView' = self.view + view: "ChessBotView" = self.view if interaction.user != view.player: - await interaction.response.send_message("This is not your game!", ephemeral=True) - return + await interaction.response.send_message( + "This is not your game!", ephemeral=True + ) + return # Bot wins on player resignation - await view.end_game(interaction, f"{view.player.mention} resigned. Bot wins! 🏳️") + await view.end_game( + interaction, f"{view.player.mention} resigned. Bot wins! 🏳️" + ) # --- Engine and Game Logic --- @@ -1138,8 +1500,10 @@ class ChessBotView(ui.View): # Use the async popen_uci print("[Debug] Awaiting chess.engine.popen_uci...") transport, engine_protocol = await chess.engine.popen_uci(stockfish_path) - print(f"[Debug] popen_uci successful. Protocol type: {type(engine_protocol)}") - self.engine = engine_protocol # This is the UciProtocol object + print( + f"[Debug] popen_uci successful. Protocol type: {type(engine_protocol)}" + ) + self.engine = engine_protocol # This is the UciProtocol object self._engine_transport = transport # Configure Stockfish options using the configure method (corrected approach) @@ -1153,90 +1517,141 @@ class ChessBotView(ui.View): # UCI_Chess960 option typically expects a boolean or string "true"/"false". # Assuming configure handles this conversion or expects boolean. options_to_set["UCI_Chess960"] = True - await self.engine.configure(options_to_set) # Use configure as suggested + await self.engine.configure(options_to_set) # Use configure as suggested print("[Debug] Configuration successful.") # Position is set implicitly when calling play/analyse or explicitly via send_command # No explicit position call needed here. - print("[Debug] Engine configured. Position will be set on first play/analyse call.") + print( + "[Debug] Engine configured. Position will be set on first play/analyse call." + ) - print(f"Stockfish engine configured for {self.variant} with skill level {self.skill_level}.") + print( + f"Stockfish engine configured for {self.variant} with skill level {self.skill_level}." + ) except FileNotFoundError as e: - print(f"[Error] Stockfish executable not found: {e}") - self.engine = None - # Notify the user in the channel if the message exists - if self.message: - # ... (rest of existing error handling for this block) - try: - if hasattr(self, '_interaction') and self._interaction and not self._interaction.response.is_done(): - await self._interaction.followup.send(f"Error: Could not start the chess engine: {e}", ephemeral=True) - else: - await self.message.channel.send(f"Error: Could not start the chess engine: {e}") - except (discord.Forbidden, discord.HTTPException): - pass - if not self.is_finished(): self.stop() + print(f"[Error] Stockfish executable not found: {e}") + self.engine = None + # Notify the user in the channel if the message exists + if self.message: + # ... (rest of existing error handling for this block) + try: + if ( + hasattr(self, "_interaction") + and self._interaction + and not self._interaction.response.is_done() + ): + await self._interaction.followup.send( + f"Error: Could not start the chess engine: {e}", + ephemeral=True, + ) + else: + await self.message.channel.send( + f"Error: Could not start the chess engine: {e}" + ) + except (discord.Forbidden, discord.HTTPException): + pass + if not self.is_finished(): + self.stop() except OSError as e: - print(f"[Error] OS error during engine start: {e}") - self.engine = None - # Notify the user in the channel if the message exists - if self.message: - # ... (rest of existing error handling for this block) - try: - if hasattr(self, '_interaction') and self._interaction and not self._interaction.response.is_done(): - await self._interaction.followup.send(f"Error: Could not start the chess engine: {e}", ephemeral=True) - else: - await self.message.channel.send(f"Error: Could not start the chess engine: {e}") - except (discord.Forbidden, discord.HTTPException): - pass - if not self.is_finished(): self.stop() + print(f"[Error] OS error during engine start: {e}") + self.engine = None + # Notify the user in the channel if the message exists + if self.message: + # ... (rest of existing error handling for this block) + try: + if ( + hasattr(self, "_interaction") + and self._interaction + and not self._interaction.response.is_done() + ): + await self._interaction.followup.send( + f"Error: Could not start the chess engine: {e}", + ephemeral=True, + ) + else: + await self.message.channel.send( + f"Error: Could not start the chess engine: {e}" + ) + except (discord.Forbidden, discord.HTTPException): + pass + if not self.is_finished(): + self.stop() except chess.engine.EngineError as e: - print(f"[Error] Chess engine error during start/config: {e}") - if engine_protocol: - try: await engine_protocol.quit() - except: pass - if transport: - transport.close() - self.engine = None - self._engine_transport = None - # Notify the user in the channel if the message exists - if self.message: - # ... (rest of existing error handling for this block) - try: - if hasattr(self, '_interaction') and self._interaction and not self._interaction.response.is_done(): - await self._interaction.followup.send(f"Error: Could not start the chess engine: {e}", ephemeral=True) - else: - await self.message.channel.send(f"Error: Could not start the chess engine: {e}") - except (discord.Forbidden, discord.HTTPException): - pass - if not self.is_finished(): self.stop() + print(f"[Error] Chess engine error during start/config: {e}") + if engine_protocol: + try: + await engine_protocol.quit() + except: + pass + if transport: + transport.close() + self.engine = None + self._engine_transport = None + # Notify the user in the channel if the message exists + if self.message: + # ... (rest of existing error handling for this block) + try: + if ( + hasattr(self, "_interaction") + and self._interaction + and not self._interaction.response.is_done() + ): + await self._interaction.followup.send( + f"Error: Could not start the chess engine: {e}", + ephemeral=True, + ) + else: + await self.message.channel.send( + f"Error: Could not start the chess engine: {e}" + ) + except (discord.Forbidden, discord.HTTPException): + pass + if not self.is_finished(): + self.stop() except Exception as e: # Catch the specific error if possible, otherwise print generic print(f"[Error] Unexpected error during engine start: {e}") - print(f"[Debug] Type of error: {type(e)}") # Print the type of the exception + print( + f"[Debug] Type of error: {type(e)}" + ) # Print the type of the exception if "can't be used in 'await' expression" in str(e): - print("[Debug] Caught the specific 'await' expression error.") + print("[Debug] Caught the specific 'await' expression error.") if engine_protocol: - try: await engine_protocol.quit() - except: pass + try: + await engine_protocol.quit() + except: + pass if transport: - transport.close() + transport.close() self.engine = None self._engine_transport = None # Notify the user in the channel if the message exists if self.message: try: # Use followup if interaction is available and not done - if hasattr(self, '_interaction') and self._interaction and not self._interaction.response.is_done(): - await self._interaction.followup.send(f"Error: Could not start the chess engine: {e}", ephemeral=True) + if ( + hasattr(self, "_interaction") + and self._interaction + and not self._interaction.response.is_done() + ): + await self._interaction.followup.send( + f"Error: Could not start the chess engine: {e}", + ephemeral=True, + ) else: - await self.message.channel.send(f"Error: Could not start the chess engine: {e}") + await self.message.channel.send( + f"Error: Could not start the chess engine: {e}" + ) except (discord.Forbidden, discord.HTTPException): - pass # Can't send message + pass # Can't send message if not self.is_finished(): - self.stop() # Stop the view if engine fails and view hasn't already stopped + self.stop() # Stop the view if engine fails and view hasn't already stopped - async def handle_player_move(self, interaction: discord.Interaction, move: chess.Move): + async def handle_player_move( + self, interaction: discord.Interaction, move: chess.Move + ): """Handles the player's validated legal move.""" # Add move to PGN self.pgn_node = self.pgn_node.add_variation(move) @@ -1250,7 +1665,7 @@ class ChessBotView(ui.View): self.selected_rank = None self.selected_square = None self.valid_moves = [] - + # Update player's DM asyncio.create_task(self._send_or_update_dm()) @@ -1270,8 +1685,13 @@ class ChessBotView(ui.View): async def make_bot_move(self): """Lets the Stockfish engine make a move using the async protocol.""" - if self.engine is None or self.board.turn != self.bot_color or self.is_thinking or self.is_finished(): - return # Engine not ready, not bot's turn, already thinking, or game ended + if ( + self.engine is None + or self.board.turn != self.bot_color + or self.is_thinking + or self.is_finished() + ): + return # Engine not ready, not bot's turn, already thinking, or game ended self.is_thinking = True try: @@ -1280,7 +1700,9 @@ class ChessBotView(ui.View): # Use the protocol's play method (ASYNC) print("[Debug] Awaiting engine.play...") - result = await self.engine.play(self.board, chess.engine.Limit(time=self.think_time)) + result = await self.engine.play( + self.board, chess.engine.Limit(time=self.think_time) + ) print(f"[Debug] engine.play completed. Result: {result}") # Check if the view is still active before proceeding @@ -1303,14 +1725,20 @@ class ChessBotView(ui.View): if outcome: # Need a way to update the message; use self.message if available if self.message: - # Pass the message object directly to end_game - await self.end_game(self.message, self.get_game_over_message(outcome)) - else: # Should not happen if game started correctly - print("ChessBotView Error: Cannot end game after bot move, self.message is None.") - return # Important: return after ending the game + # Pass the message object directly to end_game + await self.end_game( + self.message, self.get_game_over_message(outcome) + ) + else: # Should not happen if game started correctly + print( + "ChessBotView Error: Cannot end game after bot move, self.message is None." + ) + return # Important: return after ending the game # Restore default buttons for player's turn - if self.message and not self.is_finished(): # Check if view is still active + if ( + self.message and not self.is_finished() + ): # Check if view is still active self.clear_items() self.add_item(self.MakeMoveButton()) self.add_item(self.SelectMoveButton()) @@ -1318,39 +1746,60 @@ class ChessBotView(ui.View): # Now update the message await self.update_message(self.message, status_prefix="Your turn.") else: - print("ChessBotView: Engine returned no best move (result.move is None).") - if self.message and not self.is_finished(): - await self.update_message(self.message, status_prefix="Bot failed to find a move. Your turn?") + print( + "ChessBotView: Engine returned no best move (result.move is None)." + ) + if self.message and not self.is_finished(): + await self.update_message( + self.message, + status_prefix="Bot failed to find a move. Your turn?", + ) - except (chess.engine.EngineError, chess.engine.EngineTerminatedError, Exception) as e: + except ( + chess.engine.EngineError, + chess.engine.EngineTerminatedError, + Exception, + ) as e: print(f"Error during bot move analysis: {e}") if self.message and not self.is_finished(): - try: - # Try to inform the user about the error - await self.update_message(self.message, status_prefix=f"Error during bot move: {e}. Your turn?") - except: pass # Ignore errors editing message here + try: + # Try to inform the user about the error + await self.update_message( + self.message, + status_prefix=f"Error during bot move: {e}. Your turn?", + ) + except: + pass # Ignore errors editing message here # Consider stopping the game if the engine has issues await self.stop_engine() if not self.is_finished(): - self.stop() # Stop the view as well + self.stop() # Stop the view as well finally: # Ensure is_thinking is reset even if errors occur or game ends mid-thought self.is_thinking = False # --- Message and State Management --- - async def update_message(self, interaction_or_message: Union[discord.Interaction, discord.Message], status_prefix: str = ""): + async def update_message( + self, + interaction_or_message: Union[discord.Interaction, discord.Message], + status_prefix: str = "", + ): """Updates the game message with the current board image and status.""" content = self.get_board_message(status_prefix) - + # Determine if we need to show valid move dots (only when showing valid move buttons) - show_valid_move_dots = self.move_selection_mode and self.selected_square is not None and self.valid_moves - + show_valid_move_dots = ( + self.move_selection_mode + and self.selected_square is not None + and self.valid_moves + ) + board_image = generate_board_image( - self.board, - self.last_move, + self.board, + self.last_move, perspective_white=(self.player_color == chess.WHITE), - valid_moves=self.valid_moves if show_valid_move_dots else None + valid_moves=self.valid_moves if show_valid_move_dots else None, ) # NOTE: Button setup is now handled by the calling function (e.g., handle_player_move, make_bot_move, _cancel_move_selection_callback) @@ -1360,12 +1809,18 @@ class ChessBotView(ui.View): if isinstance(interaction_or_message, discord.Interaction): # If interaction hasn't been responded to (e.g., initial send) if not interaction_or_message.response.is_done(): - await interaction_or_message.response.edit_message(content=content, attachments=[board_image], view=self) + await interaction_or_message.response.edit_message( + content=content, attachments=[board_image], view=self + ) # If interaction was deferred (e.g., after modal submit) else: - await interaction_or_message.edit_original_response(content=content, attachments=[board_image], view=self) + await interaction_or_message.edit_original_response( + content=content, attachments=[board_image], view=self + ) elif isinstance(interaction_or_message, discord.Message): - await interaction_or_message.edit(content=content, attachments=[board_image], view=self) + await interaction_or_message.edit( + content=content, attachments=[board_image], view=self + ) except (discord.NotFound, discord.HTTPException) as e: print(f"ChessBotView: Failed to update message: {e}") # If message update fails, stop the game to prevent inconsistent state @@ -1400,26 +1855,37 @@ class ChessBotView(ui.View): if outcome.winner == self.player_color: winner_text = f"{self.player.mention} ({'White' if self.player_color == chess.WHITE else 'Black'}) wins!" elif outcome.winner == self.bot_color: - winner_text = f"Bot ({'White' if self.bot_color == chess.WHITE else 'Black'}) wins!" + winner_text = ( + f"Bot ({'White' if self.bot_color == chess.WHITE else 'Black'}) wins!" + ) else: winner_text = "It's a draw!" termination_reason = outcome.termination.name.replace("_", " ").title() return f"Game Over! **{winner_text} by {termination_reason}**" - async def end_game(self, interaction_or_message: Union[discord.Interaction, discord.Message], message_content: str): + async def end_game( + self, + interaction_or_message: Union[discord.Interaction, discord.Message], + message_content: str, + ): """Ends the game, disables buttons, stops the engine, and updates the message.""" - if self.is_finished(): return # Avoid double execution + if self.is_finished(): + return # Avoid double execution await self.disable_all_buttons() - await self.stop_engine() # Ensure engine is closed before stopping view + await self.stop_engine() # Ensure engine is closed before stopping view # Update DM with final result await self._send_or_update_dm(result=message_content) # Ensure a valid board image is generated try: - board_image = generate_board_image(self.board, self.last_move, perspective_white=(self.player_color == chess.WHITE)) # Show final board + board_image = generate_board_image( + self.board, + self.last_move, + perspective_white=(self.player_color == chess.WHITE), + ) # Show final board except Exception as img_error: print(f"Error generating final board image: {img_error}") # Create a fallback message if image generation fails @@ -1461,9 +1927,15 @@ class ChessBotView(ui.View): # If interaction was deferred or responded to, try to edit original response try: if board_image: - await interaction.edit_original_response(content=message_content, attachments=[board_image], view=self) + await interaction.edit_original_response( + content=message_content, + attachments=[board_image], + view=self, + ) else: - await interaction.edit_original_response(content=message_content, view=self) + await interaction.edit_original_response( + content=message_content, view=self + ) success = True except (discord.NotFound, discord.HTTPException) as e: print(f"Failed to edit original response: {e}") @@ -1471,16 +1943,24 @@ class ChessBotView(ui.View): # If interaction is fresh, edit its message try: if board_image: - await interaction.response.edit_message(content=message_content, attachments=[board_image], view=self) + await interaction.response.edit_message( + content=message_content, + attachments=[board_image], + view=self, + ) else: - await interaction.response.edit_message(content=message_content, view=self) + await interaction.response.edit_message( + content=message_content, view=self + ) success = True except (discord.NotFound, discord.HTTPException) as e: print(f"Failed to edit message via response: {e}") # Try to send a followup if editing fails try: if board_image: - await interaction.followup.send(content=message_content, file=board_image) + await interaction.followup.send( + content=message_content, file=board_image + ) else: await interaction.followup.send(content=message_content) success = True @@ -1493,7 +1973,9 @@ class ChessBotView(ui.View): if target_message and not success: try: if board_image: - await target_message.edit(content=message_content, attachments=[board_image], view=self) + await target_message.edit( + content=message_content, attachments=[board_image], view=self + ) else: await target_message.edit(content=message_content, view=self) success = True @@ -1514,7 +1996,7 @@ class ChessBotView(ui.View): if not success: print("ChessBotView: All attempts to send game end message failed") - self.stop() # Stop the view itself AFTER attempting message update + self.stop() # Stop the view itself AFTER attempting message update async def disable_all_buttons(self): for item in self.children: @@ -1526,8 +2008,8 @@ class ChessBotView(ui.View): """Safely quits the chess engine using the async protocol and transport.""" engine_protocol = self.engine transport = self._engine_transport - self.engine = None # Set to None immediately - self._engine_transport = None # Clear transport reference + self.engine = None # Set to None immediately + self._engine_transport = None # Clear transport reference if engine_protocol: try: @@ -1535,7 +2017,11 @@ class ChessBotView(ui.View): print("[Debug] Awaiting engine.quit()...") await engine_protocol.quit() print("Stockfish engine quit command sent successfully.") - except (chess.engine.EngineError, chess.engine.EngineTerminatedError, Exception) as e: + except ( + chess.engine.EngineError, + chess.engine.EngineTerminatedError, + Exception, + ) as e: print(f"Error sending quit command to Stockfish engine: {e}") if transport: @@ -1546,158 +2032,225 @@ class ChessBotView(ui.View): except Exception as e: print(f"Error closing engine transport: {e}") - async def on_timeout(self): - if not self.is_finished(): # Only act if not already stopped + if not self.is_finished(): # Only act if not already stopped timeout_msg = f"Chess game for {self.player.mention} timed out." - await self.end_game(self.message, timeout_msg) # Use end_game to handle cleanup and message update + await self.end_game( + self.message, timeout_msg + ) # Use end_game to handle cleanup and message update - async def on_error(self, interaction: discord.Interaction, error: Exception, item: ui.Item): + async def on_error( + self, interaction: discord.Interaction, error: Exception, item: ui.Item + ): print(f"Error in ChessBotView interaction (item: {item}): {error}") # Try to send an ephemeral message about the error try: if interaction.response.is_done(): - await interaction.followup.send(f"An error occurred: {error}", ephemeral=True) + await interaction.followup.send( + f"An error occurred: {error}", ephemeral=True + ) else: - await interaction.response.send_message(f"An error occurred: {error}", ephemeral=True) + await interaction.response.send_message( + f"An error occurred: {error}", ephemeral=True + ) except Exception as e: print(f"ChessBotView: Failed to send error response: {e}") # Stop the game on error to be safe - await self.end_game(interaction, f"An error occurred, stopping the game: {error}") + await self.end_game( + interaction, f"An error occurred, stopping the game: {error}" + ) # --- Button-Driven Move Selection Methods --- - + async def show_file_selection(self, interaction: discord.Interaction): """Shows buttons for selecting a file (A-H).""" # Clear existing buttons self.clear_items() - + # Add file selection buttons (A-H) for file_idx in range(8): self.add_item(self.FileButton(file_idx)) - + # Add a cancel button to return to normal view - cancel_button = ui.Button(label="Cancel", style=discord.ButtonStyle.secondary, custom_id="cancel_move_selection") + cancel_button = ui.Button( + label="Cancel", + style=discord.ButtonStyle.secondary, + custom_id="cancel_move_selection", + ) cancel_button.callback = self._cancel_move_selection_callback self.add_item(cancel_button) - + # Update the message content = self.get_board_message("Select a file (A-H) to choose a piece. ") - board_image = generate_board_image(self.board, self.last_move, perspective_white=(self.player_color == chess.WHITE)) - + board_image = generate_board_image( + self.board, + self.last_move, + perspective_white=(self.player_color == chess.WHITE), + ) + if interaction.response.is_done(): - await interaction.edit_original_response(content=content, attachments=[board_image], view=self) + await interaction.edit_original_response( + content=content, attachments=[board_image], view=self + ) else: - await interaction.response.edit_message(content=content, attachments=[board_image], view=self) - + await interaction.response.edit_message( + content=content, attachments=[board_image], view=self + ) + async def show_rank_selection(self, interaction: discord.Interaction): """Shows buttons for selecting a rank (1-8).""" # Clear existing buttons self.clear_items() - + # Add rank selection buttons (1-8) for rank_idx in range(8): self.add_item(self.RankButton(rank_idx)) - + # Add a back button to return to file selection - back_button = ui.Button(label="Back", style=discord.ButtonStyle.secondary, custom_id="back_to_file_selection") + back_button = ui.Button( + label="Back", + style=discord.ButtonStyle.secondary, + custom_id="back_to_file_selection", + ) back_button.callback = self._back_to_file_selection_callback self.add_item(back_button) - + # Add a cancel button to return to normal view - cancel_button = ui.Button(label="Cancel", style=discord.ButtonStyle.secondary, custom_id="cancel_move_selection") + cancel_button = ui.Button( + label="Cancel", + style=discord.ButtonStyle.secondary, + custom_id="cancel_move_selection", + ) cancel_button.callback = self._cancel_move_selection_callback self.add_item(cancel_button) - + # Update the message file_letter = chr(65 + self.selected_file) # Convert to A-H - content = self.get_board_message(f"Selected file {file_letter}. Now select a rank (1-8). ") - board_image = generate_board_image(self.board, self.last_move, perspective_white=(self.player_color == chess.WHITE)) - + content = self.get_board_message( + f"Selected file {file_letter}. Now select a rank (1-8). " + ) + board_image = generate_board_image( + self.board, + self.last_move, + perspective_white=(self.player_color == chess.WHITE), + ) + if interaction.response.is_done(): - await interaction.edit_original_response(content=content, attachments=[board_image], view=self) + await interaction.edit_original_response( + content=content, attachments=[board_image], view=self + ) else: - await interaction.response.edit_message(content=content, attachments=[board_image], view=self) - + await interaction.response.edit_message( + content=content, attachments=[board_image], view=self + ) + async def show_valid_moves(self, interaction: discord.Interaction): """Shows buttons for selecting a destination square from valid moves.""" # Clear existing buttons self.clear_items() - + # Add buttons for each valid move for move in self.valid_moves: self.add_item(self.MoveButton(move)) - + # Add a back button to return to file selection - back_button = ui.Button(label="Back", style=discord.ButtonStyle.secondary, custom_id="back_to_file_selection") + back_button = ui.Button( + label="Back", + style=discord.ButtonStyle.secondary, + custom_id="back_to_file_selection", + ) back_button.callback = self._back_to_file_selection_callback self.add_item(back_button) - + # Add a cancel button to return to normal view - cancel_button = ui.Button(label="Cancel", style=discord.ButtonStyle.secondary, custom_id="cancel_move_selection") + cancel_button = ui.Button( + label="Cancel", + style=discord.ButtonStyle.secondary, + custom_id="cancel_move_selection", + ) cancel_button.callback = self._cancel_move_selection_callback self.add_item(cancel_button) - + # Update the message with valid move dots file_letter = chr(65 + self.selected_file) # Convert to A-H rank_number = 8 - chess.square_rank(self.selected_square) # Convert to 1-8 - content = self.get_board_message(f"Selected piece at {file_letter}{rank_number}. Choose a destination square. ") - board_image = generate_board_image( - self.board, - self.last_move, - perspective_white=(self.player_color == chess.WHITE), - valid_moves=self.valid_moves + content = self.get_board_message( + f"Selected piece at {file_letter}{rank_number}. Choose a destination square. " ) - + board_image = generate_board_image( + self.board, + self.last_move, + perspective_white=(self.player_color == chess.WHITE), + valid_moves=self.valid_moves, + ) + if interaction.response.is_done(): - await interaction.edit_original_response(content=content, attachments=[board_image], view=self) + await interaction.edit_original_response( + content=content, attachments=[board_image], view=self + ) else: - await interaction.response.edit_message(content=content, attachments=[board_image], view=self) - + await interaction.response.edit_message( + content=content, attachments=[board_image], view=self + ) + async def _back_to_file_selection_callback(self, interaction: discord.Interaction): """Callback for the 'Back' button to return to file selection.""" if interaction.user != self.player: - await interaction.response.send_message("This is not your game!", ephemeral=True) + await interaction.response.send_message( + "This is not your game!", ephemeral=True + ) return await self.show_file_selection(interaction) - + async def _cancel_move_selection_callback(self, interaction: discord.Interaction): """Callback for the 'Cancel' button to exit move selection mode.""" if interaction.user != self.player: - await interaction.response.send_message("This is not your game!", ephemeral=True) + await interaction.response.send_message( + "This is not your game!", ephemeral=True + ) return - + # Reset move selection state self.move_selection_mode = False self.selected_file = None self.selected_rank = None self.selected_square = None self.valid_moves = [] - + # Restore normal view self.clear_items() self.add_item(self.MakeMoveButton()) self.add_item(self.SelectMoveButton()) self.add_item(self.ResignButton()) - + # Update the message content = self.get_board_message("Move selection cancelled. ") - board_image = generate_board_image(self.board, self.last_move, perspective_white=(self.player_color == chess.WHITE)) - + board_image = generate_board_image( + self.board, + self.last_move, + perspective_white=(self.player_color == chess.WHITE), + ) + if interaction.response.is_done(): - await interaction.edit_original_response(content=content, attachments=[board_image], view=self) + await interaction.edit_original_response( + content=content, attachments=[board_image], view=self + ) else: - await interaction.response.edit_message(content=content, attachments=[board_image], view=self) - - async def handle_square_click(self, interaction: discord.Interaction, x: int, y: int): + await interaction.response.edit_message( + content=content, attachments=[board_image], view=self + ) + + async def handle_square_click( + self, interaction: discord.Interaction, x: int, y: int + ): """Legacy method for handling square clicks from ChessBotButton.""" # This method is kept for backward compatibility await interaction.response.send_message( "Please use the 'Select Move' button for the new button-driven move selection interface.", - ephemeral=True + ephemeral=True, ) - + # --- DM Helper Methods (Adapted for Bot Game) --- async def _get_dm_content(self, result: Optional[str] = None) -> str: @@ -1708,29 +2261,43 @@ class ChessBotView(ui.View): # Update PGN headers if result is provided and game is over if result: - pgn_result_code = "*" # Default + pgn_result_code = "*" # Default if result in ["1-0", "0-1", "1/2-1/2"]: pgn_result_code = result elif "wins" in result: - if (self.player_color == chess.WHITE and "White" in result) or \ - (self.player_color == chess.BLACK and "Black" in result): - pgn_result_code = "1-0" if self.player_color == chess.WHITE else "0-1" # Player won + if (self.player_color == chess.WHITE and "White" in result) or ( + self.player_color == chess.BLACK and "Black" in result + ): + pgn_result_code = ( + "1-0" if self.player_color == chess.WHITE else "0-1" + ) # Player won else: - pgn_result_code = "0-1" if self.player_color == chess.WHITE else "1-0" # Bot won + pgn_result_code = ( + "0-1" if self.player_color == chess.WHITE else "1-0" + ) # Bot won elif "draw" in result: pgn_result_code = "1/2-1/2" # Only update if not already set or if changing from '*' - if "Result" not in self.game_pgn.headers or self.game_pgn.headers["Result"] == "*": - self.game_pgn.headers["Result"] = pgn_result_code + if ( + "Result" not in self.game_pgn.headers + or self.game_pgn.headers["Result"] == "*" + ): + self.game_pgn.headers["Result"] = pgn_result_code # Use an exporter for cleaner PGN output - exporter = chess.pgn.StringExporter(headers=True, variations=True, comments=True) + exporter = chess.pgn.StringExporter( + headers=True, variations=True, comments=True + ) pgn_string = self.game_pgn.accept(exporter) - pgn_preview = pgn_string[:1500] + "..." if len(pgn_string) > 1500 else pgn_string + pgn_preview = ( + pgn_string[:1500] + "..." if len(pgn_string) > 1500 else pgn_string + ) - content = f"**Game vs {opponent_name}** ({opponent_color_str})\n\n" \ - f"**FEN:**\n`{fen}`\n\n" \ - f"**PGN:**\n```pgn\n{pgn_preview}\n```" + content = ( + f"**Game vs {opponent_name}** ({opponent_color_str})\n\n" + f"**FEN:**\n`{fen}`\n\n" + f"**PGN:**\n```pgn\n{pgn_preview}\n```" + ) if result: content += f"\n\n**Status:** {result}" @@ -1749,9 +2316,11 @@ class ChessBotView(ui.View): if dm_message: try: await dm_message.edit(content=content) - return # Edited successfully + return # Edited successfully except discord.NotFound: - print(f"DM message for {player.display_name} not found, will send a new one.") + print( + f"DM message for {player.display_name} not found, will send a new one." + ) self.player_dm_message = None dm_message = None except discord.Forbidden: @@ -1759,7 +2328,9 @@ class ChessBotView(ui.View): self.player_dm_message = None dm_message = None except discord.HTTPException as e: - print(f"HTTP error editing DM for {player.display_name}: {e}. Will try sending.") + print( + f"HTTP error editing DM for {player.display_name}: {e}. Will try sending." + ) self.player_dm_message = None dm_message = None @@ -1768,11 +2339,15 @@ class ChessBotView(ui.View): self.player_dm_message = new_dm_message except discord.Forbidden: - print(f"Cannot send DM to {player.display_name} (Forbidden). User likely has DMs disabled.") + print( + f"Cannot send DM to {player.display_name} (Forbidden). User likely has DMs disabled." + ) self.player_dm_message = None except discord.HTTPException as e: print(f"Failed to send/edit DM for {player.display_name}: {e}") self.player_dm_message = None except Exception as e: - print(f"Unexpected error sending/updating DM for {player.display_name}: {e}") + print( + f"Unexpected error sending/updating DM for {player.display_name}: {e}" + ) self.player_dm_message = None diff --git a/cogs/games/coinflip_game.py b/cogs/games/coinflip_game.py index 9c5cf19..e3ff561 100644 --- a/cogs/games/coinflip_game.py +++ b/cogs/games/coinflip_game.py @@ -2,6 +2,7 @@ import discord from discord import ui from typing import Optional + class CoinFlipView(ui.View): def __init__(self, initiator: discord.Member, opponent: discord.Member): super().__init__(timeout=180.0) # 3-minute timeout @@ -11,7 +12,9 @@ class CoinFlipView(ui.View): self.opponent_choice: Optional[str] = None self.result: Optional[str] = None self.winner: Optional[discord.Member] = None - self.message: Optional[discord.Message] = None # To store the message for editing + self.message: Optional[discord.Message] = ( + None # To store the message for editing + ) # Initial state: Initiator chooses side self.add_item(self.HeadsButton()) @@ -22,47 +25,59 @@ class CoinFlipView(ui.View): # Stage 1: Initiator chooses Heads/Tails if self.initiator_choice is None: if interaction.user.id != self.initiator.id: - await interaction.response.send_message("Only the initiator can choose their side.", ephemeral=True) + await interaction.response.send_message( + "Only the initiator can choose their side.", ephemeral=True + ) return False return True # Stage 2: Opponent Accepts/Declines else: if interaction.user.id != self.opponent.id: - await interaction.response.send_message("Only the opponent can accept or decline the game.", ephemeral=True) + await interaction.response.send_message( + "Only the opponent can accept or decline the game.", ephemeral=True + ) return False return True async def update_view_state(self, interaction: discord.Interaction): """Updates the view items based on the current state.""" self.clear_items() - if self.initiator_choice is None: # Should not happen if called correctly, but for safety + if ( + self.initiator_choice is None + ): # Should not happen if called correctly, but for safety self.add_item(self.HeadsButton()) self.add_item(self.TailsButton()) - elif self.result is None: # Opponent needs to accept/decline + elif self.result is None: # Opponent needs to accept/decline self.add_item(self.AcceptButton()) self.add_item(self.DeclineButton()) - else: # Game finished, disable all (handled by disabling in callbacks) - pass # No items needed, or keep disabled ones + else: # Game finished, disable all (handled by disabling in callbacks) + pass # No items needed, or keep disabled ones # Edit the original message if self.message: try: # Use interaction response to edit if available, otherwise use message.edit # This handles the case where the interaction is the one causing the edit - if interaction and interaction.message and interaction.message.id == self.message.id: - await interaction.response.edit_message(view=self) + if ( + interaction + and interaction.message + and interaction.message.id == self.message.id + ): + await interaction.response.edit_message(view=self) else: - await self.message.edit(view=self) + await self.message.edit(view=self) except discord.NotFound: print("CoinFlipView: Failed to edit message, likely deleted.") except discord.Forbidden: print("CoinFlipView: Missing permissions to edit message.") except discord.InteractionResponded: - # If interaction already responded (e.g. initial choice), use followup or webhook - try: - await interaction.edit_original_response(view=self) - except discord.HTTPException: - print("CoinFlipView: Failed to edit original response after InteractionResponded.") + # If interaction already responded (e.g. initial choice), use followup or webhook + try: + await interaction.edit_original_response(view=self) + except discord.HTTPException: + print( + "CoinFlipView: Failed to edit original response after InteractionResponded." + ) async def disable_all_buttons(self): for item in self.children: @@ -71,57 +86,68 @@ class CoinFlipView(ui.View): if self.message: try: await self.message.edit(view=self) - except discord.NotFound: pass # Ignore if message is gone - except discord.Forbidden: pass # Ignore if permissions lost + except discord.NotFound: + pass # Ignore if message is gone + except discord.Forbidden: + pass # Ignore if permissions lost async def on_timeout(self): - if self.message and not self.is_finished(): # Check if not already stopped + if self.message and not self.is_finished(): # Check if not already stopped await self.disable_all_buttons() timeout_msg = f"Coin flip game between {self.initiator.mention} and {self.opponent.mention} timed out." try: await self.message.edit(content=timeout_msg, view=self) - except discord.NotFound: pass - except discord.Forbidden: pass + except discord.NotFound: + pass + except discord.Forbidden: + pass self.stop() # --- Button Definitions --- class HeadsButton(ui.Button): def __init__(self): - super().__init__(label="Heads", style=discord.ButtonStyle.primary, custom_id="cf_heads") + super().__init__( + label="Heads", style=discord.ButtonStyle.primary, custom_id="cf_heads" + ) async def callback(self, interaction: discord.Interaction): - view: 'CoinFlipView' = self.view + view: "CoinFlipView" = self.view view.initiator_choice = "Heads" view.opponent_choice = "Tails" # Update message and view for opponent - await view.update_view_state(interaction) # Switches to Accept/Decline - await interaction.edit_original_response( # Edit the message content *after* updating state + await view.update_view_state(interaction) # Switches to Accept/Decline + await interaction.edit_original_response( # Edit the message content *after* updating state content=f"{view.opponent.mention}, {view.initiator.mention} has chosen **Heads**! You get **Tails**. Do you accept?" ) class TailsButton(ui.Button): def __init__(self): - super().__init__(label="Tails", style=discord.ButtonStyle.primary, custom_id="cf_tails") + super().__init__( + label="Tails", style=discord.ButtonStyle.primary, custom_id="cf_tails" + ) async def callback(self, interaction: discord.Interaction): - view: 'CoinFlipView' = self.view + view: "CoinFlipView" = self.view view.initiator_choice = "Tails" view.opponent_choice = "Heads" # Update message and view for opponent - await view.update_view_state(interaction) # Switches to Accept/Decline - await interaction.edit_original_response( # Edit the message content *after* updating state + await view.update_view_state(interaction) # Switches to Accept/Decline + await interaction.edit_original_response( # Edit the message content *after* updating state content=f"{view.opponent.mention}, {view.initiator.mention} has chosen **Tails**! You get **Heads**. Do you accept?" ) class AcceptButton(ui.Button): def __init__(self): - super().__init__(label="Accept", style=discord.ButtonStyle.success, custom_id="cf_accept") + super().__init__( + label="Accept", style=discord.ButtonStyle.success, custom_id="cf_accept" + ) async def callback(self, interaction: discord.Interaction): - view: 'CoinFlipView' = self.view + view: "CoinFlipView" = self.view # Perform the coin flip import random + view.result = random.choice(["Heads", "Tails"]) # Determine winner @@ -143,10 +169,14 @@ class CoinFlipView(ui.View): class DeclineButton(ui.Button): def __init__(self): - super().__init__(label="Decline", style=discord.ButtonStyle.danger, custom_id="cf_decline") + super().__init__( + label="Decline", + style=discord.ButtonStyle.danger, + custom_id="cf_decline", + ) async def callback(self, interaction: discord.Interaction): - view: 'CoinFlipView' = self.view + view: "CoinFlipView" = self.view decline_message = f"{view.opponent.mention} has declined the coin flip game from {view.initiator.mention}." await view.disable_all_buttons() await interaction.response.edit_message(content=decline_message, view=view) diff --git a/cogs/games/rps_game.py b/cogs/games/rps_game.py index 5cae0d7..4b7c790 100644 --- a/cogs/games/rps_game.py +++ b/cogs/games/rps_game.py @@ -2,6 +2,7 @@ import discord from discord import ui from typing import Optional + class RockPaperScissorsView(ui.View): def __init__(self, initiator: discord.Member, opponent: discord.Member): super().__init__(timeout=180.0) # 3-minute timeout @@ -10,14 +11,16 @@ class RockPaperScissorsView(ui.View): self.initiator_choice: Optional[str] = None self.opponent_choice: Optional[str] = None self.message: Optional[discord.Message] = None - + async def interaction_check(self, interaction: discord.Interaction) -> bool: """Check if the person interacting is part of the game.""" if interaction.user.id not in [self.initiator.id, self.opponent.id]: - await interaction.response.send_message("This is not your game!", ephemeral=True) + await interaction.response.send_message( + "This is not your game!", ephemeral=True + ) return False return True - + async def disable_all_buttons(self): for item in self.children: if isinstance(item, ui.Button): @@ -25,65 +28,77 @@ class RockPaperScissorsView(ui.View): if self.message: try: await self.message.edit(view=self) - except discord.NotFound: pass - except discord.Forbidden: pass - + except discord.NotFound: + pass + except discord.Forbidden: + pass + async def on_timeout(self): if self.message: await self.disable_all_buttons() timeout_msg = f"Rock Paper Scissors game between {self.initiator.mention} and {self.opponent.mention} timed out." try: await self.message.edit(content=timeout_msg, view=self) - except discord.NotFound: pass - except discord.Forbidden: pass + except discord.NotFound: + pass + except discord.Forbidden: + pass self.stop() - + # Determine winner between two choices def get_winner(self, choice1: str, choice2: str) -> Optional[str]: if choice1 == choice2: return None # Tie - if (choice1 == "Rock" and choice2 == "Scissors") or \ - (choice1 == "Paper" and choice2 == "Rock") or \ - (choice1 == "Scissors" and choice2 == "Paper"): + if ( + (choice1 == "Rock" and choice2 == "Scissors") + or (choice1 == "Paper" and choice2 == "Rock") + or (choice1 == "Scissors" and choice2 == "Paper") + ): return "player1" else: return "player2" - + @ui.button(label="Rock", style=discord.ButtonStyle.primary) async def rock_button(self, interaction: discord.Interaction, button: ui.Button): await self.make_choice(interaction, "Rock") - + @ui.button(label="Paper", style=discord.ButtonStyle.success) async def paper_button(self, interaction: discord.Interaction, button: ui.Button): await self.make_choice(interaction, "Paper") - + @ui.button(label="Scissors", style=discord.ButtonStyle.danger) - async def scissors_button(self, interaction: discord.Interaction, button: ui.Button): + async def scissors_button( + self, interaction: discord.Interaction, button: ui.Button + ): await self.make_choice(interaction, "Scissors") - + async def make_choice(self, interaction: discord.Interaction, choice: str): player = interaction.user - + # Record the choice for the appropriate player if player.id == self.initiator.id: self.initiator_choice = choice - await interaction.response.send_message(f"You chose **{choice}**!", ephemeral=True) + await interaction.response.send_message( + f"You chose **{choice}**!", ephemeral=True + ) else: # opponent self.opponent_choice = choice - await interaction.response.send_message(f"You chose **{choice}**!", ephemeral=True) - + await interaction.response.send_message( + f"You chose **{choice}**!", ephemeral=True + ) + # Check if both players have made their choices if self.initiator_choice and self.opponent_choice: # Determine the winner winner_id = self.get_winner(self.initiator_choice, self.opponent_choice) - + if winner_id is None: result = "It's a tie! 🤝" elif winner_id == "player1": result = f"**{self.initiator.mention}** wins! 🎉" else: result = f"**{self.opponent.mention}** wins! 🎉" - + # Update the message with the results result_message = ( f"**Rock Paper Scissors Results**\n" @@ -91,7 +106,7 @@ class RockPaperScissorsView(ui.View): f"{self.opponent.mention} chose **{self.opponent_choice}**\n\n" f"{result}" ) - + await self.disable_all_buttons() await self.message.edit(content=result_message, view=self) self.stop() diff --git a/cogs/games/tictactoe_game.py b/cogs/games/tictactoe_game.py index eb1d972..170985a 100644 --- a/cogs/games/tictactoe_game.py +++ b/cogs/games/tictactoe_game.py @@ -2,12 +2,13 @@ import discord from discord import ui from typing import Optional, List + # --- Tic Tac Toe (Player vs Player) --- -class TicTacToeButton(ui.Button['TicTacToeView']): +class TicTacToeButton(ui.Button["TicTacToeView"]): def __init__(self, x: int, y: int): # Use a visible character for the label as Discord API requires non-empty labels # Empty string ('') or space character (' ') are not allowed as per Discord API requirements - super().__init__(style=discord.ButtonStyle.secondary, label='·', row=y) + super().__init__(style=discord.ButtonStyle.secondary, label="·", row=y) self.x = x self.y = y @@ -17,24 +18,35 @@ class TicTacToeButton(ui.Button['TicTacToeView']): # Check if it's the correct player's turn if interaction.user != view.current_player: - await interaction.response.send_message("It's not your turn!", ephemeral=True) + await interaction.response.send_message( + "It's not your turn!", ephemeral=True + ) return # Check if the spot is already taken if view.board[self.y][self.x] is not None: - await interaction.response.send_message("This spot is already taken!", ephemeral=True) + await interaction.response.send_message( + "This spot is already taken!", ephemeral=True + ) return # Update board state and button appearance view.board[self.y][self.x] = view.current_symbol self.label = view.current_symbol - self.style = discord.ButtonStyle.success if view.current_symbol == 'X' else discord.ButtonStyle.danger + self.style = ( + discord.ButtonStyle.success + if view.current_symbol == "X" + else discord.ButtonStyle.danger + ) self.disabled = True # Check for win/draw if view.check_win(): view.winner = view.current_player - await view.end_game(interaction, f"🎉 {view.winner.mention} ({view.current_symbol}) wins! 🎉") + await view.end_game( + interaction, + f"🎉 {view.winner.mention} ({view.current_symbol}) wins! 🎉", + ) return elif view.check_draw(): await view.end_game(interaction, "🤝 It's a draw! 🤝") @@ -44,14 +56,17 @@ class TicTacToeButton(ui.Button['TicTacToeView']): view.switch_player() await view.update_board_message(interaction) + class TicTacToeView(ui.View): def __init__(self, initiator: discord.Member, opponent: discord.Member): - super().__init__(timeout=300.0) # 5 minute timeout + super().__init__(timeout=300.0) # 5 minute timeout self.initiator = initiator self.opponent = opponent - self.current_player = initiator # Initiator starts as X - self.current_symbol = 'X' - self.board: List[List[Optional[str]]] = [[None for _ in range(3)] for _ in range(3)] + self.current_player = initiator # Initiator starts as X + self.current_symbol = "X" + self.board: List[List[Optional[str]]] = [ + [None for _ in range(3)] for _ in range(3) + ] self.winner: Optional[discord.Member] = None self.message: Optional[discord.Message] = None @@ -63,10 +78,10 @@ class TicTacToeView(ui.View): def switch_player(self): if self.current_player == self.initiator: self.current_player = self.opponent - self.current_symbol = 'O' + self.current_symbol = "O" else: self.current_player = self.initiator - self.current_symbol = 'X' + self.current_symbol = "X" def check_win(self) -> bool: s = self.current_symbol @@ -111,17 +126,22 @@ class TicTacToeView(ui.View): timeout_msg = f"Tic Tac Toe game between {self.initiator.mention} and {self.opponent.mention} timed out." try: await self.message.edit(content=timeout_msg, view=self) - except discord.NotFound: pass - except discord.Forbidden: pass + except discord.NotFound: + pass + except discord.Forbidden: + pass self.stop() + # --- Tic Tac Toe Bot Game --- -class BotTicTacToeButton(ui.Button['BotTicTacToeView']): +class BotTicTacToeButton(ui.Button["BotTicTacToeView"]): def __init__(self, x: int, y: int): - super().__init__(style=discord.ButtonStyle.secondary, label='·', row=y) + super().__init__(style=discord.ButtonStyle.secondary, label="·", row=y) self.x = x self.y = y - self.position = y * 3 + x # Convert to position index (0-8) for the TicTacToe engine + self.position = ( + y * 3 + x + ) # Convert to position index (0-8) for the TicTacToe engine async def callback(self, interaction: discord.Interaction): assert self.view is not None @@ -129,16 +149,18 @@ class BotTicTacToeButton(ui.Button['BotTicTacToeView']): # Check if it's the player's turn if interaction.user != view.player: - await interaction.response.send_message("This is not your game!", ephemeral=True) + await interaction.response.send_message( + "This is not your game!", ephemeral=True + ) return # Try to make the move in the game engine try: view.game.play_turn(self.position) - self.label = 'X' # Player is always X + self.label = "X" # Player is always X self.style = discord.ButtonStyle.success self.disabled = True - # Check if game is over after player's move + # Check if game is over after player's move if view.game.is_game_over(): await view.end_game(interaction) return @@ -146,6 +168,7 @@ class BotTicTacToeButton(ui.Button['BotTicTacToeView']): # Now it's the bot's turn - defer without thinking message await interaction.response.defer() import asyncio + await asyncio.sleep(1) # Brief pause to simulate bot "thinking" # Bot makes its move @@ -154,8 +177,12 @@ class BotTicTacToeButton(ui.Button['BotTicTacToeView']): # Update the button for the bot's move bot_y, bot_x = divmod(bot_move, 3) for child in view.children: - if isinstance(child, BotTicTacToeButton) and child.x == bot_x and child.y == bot_y: - child.label = 'O' # Bot is always O + if ( + isinstance(child, BotTicTacToeButton) + and child.x == bot_x + and child.y == bot_y + ): + child.label = "O" # Bot is always O child.style = discord.ButtonStyle.danger child.disabled = True break @@ -169,12 +196,13 @@ class BotTicTacToeButton(ui.Button['BotTicTacToeView']): await interaction.followup.edit_message( message_id=view.message.id, content=f"Tic Tac Toe: {view.player.mention} (X) vs Bot (O) - Difficulty: {view.game.ai_difficulty.capitalize()}\n\nYour turn!", - view=view + view=view, ) except ValueError as e: await interaction.response.send_message(f"Error: {str(e)}", ephemeral=True) + class BotTicTacToeView(ui.View): def __init__(self, game, player: discord.Member): super().__init__(timeout=300.0) # 5 minute timeout @@ -197,19 +225,19 @@ class BotTicTacToeView(ui.View): board = self.game.get_board() rows = [] for i in range(0, 9, 3): - row = board[i:i+3] + row = board[i : i + 3] # Replace spaces with emoji equivalents for better visualization - row = [cell if cell != ' ' else '⬜' for cell in row] - row = [cell.replace('X', '❌').replace('O', '⭕') for cell in row] - rows.append(' '.join(row)) - return '\n'.join(rows) + row = [cell if cell != " " else "⬜" for cell in row] + row = [cell.replace("X", "❌").replace("O", "⭕") for cell in row] + rows.append(" ".join(row)) + return "\n".join(rows) async def end_game(self, interaction: discord.Interaction): await self.disable_all_buttons() winner = self.game.get_winner() if winner: - if winner == 'X': # Player wins + if winner == "X": # Player wins content = f"🎉 {self.player.mention} wins! 🎉" else: # Bot wins content = f"The bot ({self.game.ai_difficulty.capitalize()}) wins! Better luck next time." @@ -224,14 +252,17 @@ class BotTicTacToeView(ui.View): await interaction.followup.edit_message( message_id=self.message.id, content=f"{content}\n\n{board_display}", - view=self + view=self, ) except (discord.NotFound, discord.HTTPException): # Fallback for interaction timeouts if self.message: try: - await self.message.edit(content=f"{content}\n\n{board_display}", view=self) - except: pass + await self.message.edit( + content=f"{content}\n\n{board_display}", view=self + ) + except: + pass self.stop() async def on_timeout(self): @@ -240,8 +271,10 @@ class BotTicTacToeView(ui.View): try: await self.message.edit( content=f"Tic Tac Toe game for {self.player.mention} timed out.", - view=self + view=self, ) - except discord.NotFound: pass - except discord.Forbidden: pass + except discord.NotFound: + pass + except discord.Forbidden: + pass self.stop() diff --git a/cogs/games/wordle_game.py b/cogs/games/wordle_game.py index 66eafa1..6264793 100644 --- a/cogs/games/wordle_game.py +++ b/cogs/games/wordle_game.py @@ -5,6 +5,7 @@ import io from PIL import Image, ImageDraw, ImageFont import os + class WordleGame: """Class to handle Wordle game logic""" @@ -94,7 +95,10 @@ class WordleGame: return result -def generate_board_image(game: WordleGame, used_letters: Set[str] = None) -> discord.File: + +def generate_board_image( + game: WordleGame, used_letters: Set[str] = None +) -> discord.File: """ Generate an image of the Wordle game board @@ -107,9 +111,9 @@ def generate_board_image(game: WordleGame, used_letters: Set[str] = None) -> dis """ # Define colors and dimensions CORRECT_COLOR = (106, 170, 100) # Green - PRESENT_COLOR = (201, 180, 88) # Yellow - ABSENT_COLOR = (120, 124, 126) # Gray - UNUSED_COLOR = (211, 214, 218) # Light gray + PRESENT_COLOR = (201, 180, 88) # Yellow + ABSENT_COLOR = (120, 124, 126) # Gray + UNUSED_COLOR = (211, 214, 218) # Light gray BACKGROUND_COLOR = (255, 255, 255) # White TEXT_COLOR = (0, 0, 0) # Black @@ -129,15 +133,24 @@ def generate_board_image(game: WordleGame, used_letters: Set[str] = None) -> dis # Add space for keyboard keyboard_rows = ["QWERTYUIOP", "ASDFGHJKL", "ZXCVBNM"] - keyboard_width = max(len(row) for row in keyboard_rows) * (KEYBOARD_SQUARE_SIZE + KEYBOARD_SQUARE_MARGIN) + KEYBOARD_SQUARE_MARGIN - keyboard_height = len(keyboard_rows) * (KEYBOARD_SQUARE_SIZE + KEYBOARD_SQUARE_MARGIN) + KEYBOARD_SQUARE_MARGIN + keyboard_width = ( + max(len(row) for row in keyboard_rows) + * (KEYBOARD_SQUARE_SIZE + KEYBOARD_SQUARE_MARGIN) + + KEYBOARD_SQUARE_MARGIN + ) + keyboard_height = ( + len(keyboard_rows) * (KEYBOARD_SQUARE_SIZE + KEYBOARD_SQUARE_MARGIN) + + KEYBOARD_SQUARE_MARGIN + ) # Add space for status text status_height = 40 # Total image dimensions total_width = max(board_width, keyboard_width) + 40 # Add padding - total_height = board_height + KEYBOARD_MARGIN_TOP + keyboard_height + status_height + 40 # Add padding + total_height = ( + board_height + KEYBOARD_MARGIN_TOP + keyboard_height + status_height + 40 + ) # Add padding # Create image img = Image.new("RGB", (total_width, total_height), BACKGROUND_COLOR) @@ -148,7 +161,9 @@ def generate_board_image(game: WordleGame, used_letters: Set[str] = None) -> dis try: # Construct path relative to this script file SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) - PROJECT_ROOT = os.path.dirname(os.path.dirname(SCRIPT_DIR)) # Go up two levels from games dir + PROJECT_ROOT = os.path.dirname( + os.path.dirname(SCRIPT_DIR) + ) # Go up two levels from games dir FONT_DIR_NAME = "FONT" # Directory specified by user FONT_FILE_NAME = "DejaVuSans.ttf" font_path = os.path.join(PROJECT_ROOT, FONT_DIR_NAME, FONT_FILE_NAME) @@ -198,7 +213,12 @@ def generate_board_image(game: WordleGame, used_letters: Set[str] = None) -> dis square_color = ABSENT_COLOR # Draw the square - draw.rectangle([x, y, x + SQUARE_SIZE, y + SQUARE_SIZE], fill=square_color, outline=(0, 0, 0), width=2) + draw.rectangle( + [x, y, x + SQUARE_SIZE, y + SQUARE_SIZE], + fill=square_color, + outline=(0, 0, 0), + width=2, + ) # Draw the letter if there is one if letter: @@ -218,7 +238,9 @@ def generate_board_image(game: WordleGame, used_letters: Set[str] = None) -> dis text_y = y + (SQUARE_SIZE - text_height) // 2 # White text on colored backgrounds - text_color = (255, 255, 255) if square_color != UNUSED_COLOR else TEXT_COLOR + text_color = ( + (255, 255, 255) if square_color != UNUSED_COLOR else TEXT_COLOR + ) draw.text((text_x, text_y), letter, fill=text_color, font=font) # Draw the keyboard @@ -227,13 +249,24 @@ def generate_board_image(game: WordleGame, used_letters: Set[str] = None) -> dis for row_idx, row in enumerate(keyboard_rows): # Center this row of keys - row_width = len(row) * (KEYBOARD_SQUARE_SIZE + KEYBOARD_SQUARE_MARGIN) + KEYBOARD_SQUARE_MARGIN + row_width = ( + len(row) * (KEYBOARD_SQUARE_SIZE + KEYBOARD_SQUARE_MARGIN) + + KEYBOARD_SQUARE_MARGIN + ) row_start_x = (total_width - row_width) // 2 for col_idx, key in enumerate(row): key_lower = key.lower() - x = row_start_x + col_idx * (KEYBOARD_SQUARE_SIZE + KEYBOARD_SQUARE_MARGIN) + KEYBOARD_SQUARE_MARGIN - y = keyboard_start_y + row_idx * (KEYBOARD_SQUARE_SIZE + KEYBOARD_SQUARE_MARGIN) + KEYBOARD_SQUARE_MARGIN + x = ( + row_start_x + + col_idx * (KEYBOARD_SQUARE_SIZE + KEYBOARD_SQUARE_MARGIN) + + KEYBOARD_SQUARE_MARGIN + ) + y = ( + keyboard_start_y + + row_idx * (KEYBOARD_SQUARE_SIZE + KEYBOARD_SQUARE_MARGIN) + + KEYBOARD_SQUARE_MARGIN + ) # Default color for unused keys key_color = UNUSED_COLOR @@ -264,8 +297,12 @@ def generate_board_image(game: WordleGame, used_letters: Set[str] = None) -> dis key_color = ABSENT_COLOR # Draw the key - draw.rectangle([x, y, x + KEYBOARD_SQUARE_SIZE, y + KEYBOARD_SQUARE_SIZE], - fill=key_color, outline=(0, 0, 0), width=1) + draw.rectangle( + [x, y, x + KEYBOARD_SQUARE_SIZE, y + KEYBOARD_SQUARE_SIZE], + fill=key_color, + outline=(0, 0, 0), + width=1, + ) # Draw the letter try: @@ -283,13 +320,21 @@ def generate_board_image(game: WordleGame, used_letters: Set[str] = None) -> dis text_y = y + (KEYBOARD_SQUARE_SIZE - text_height) // 2 # White text on colored backgrounds (except unused) - text_color = (255, 255, 255) if key_color != UNUSED_COLOR else TEXT_COLOR + text_color = ( + (255, 255, 255) if key_color != UNUSED_COLOR else TEXT_COLOR + ) draw.text((text_x, text_y), key, fill=text_color, font=small_font) # Draw game status - status_y = keyboard_start_y + keyboard_height + 10 if used_letters else start_y + board_height + 10 + status_y = ( + keyboard_start_y + keyboard_height + 10 + if used_letters + else start_y + board_height + 10 + ) attempts_left = game.max_attempts - game.attempts - status_text = f"Attempts: {game.attempts}/{game.max_attempts} ({attempts_left} left)" + status_text = ( + f"Attempts: {game.attempts}/{game.max_attempts} ({attempts_left} left)" + ) # Add game result if game is over if game.game_over: @@ -315,11 +360,12 @@ def generate_board_image(game: WordleGame, used_letters: Set[str] = None) -> dis # Save image to a bytes buffer img_byte_arr = io.BytesIO() - img.save(img_byte_arr, format='PNG') + img.save(img_byte_arr, format="PNG") img_byte_arr.seek(0) return discord.File(fp=img_byte_arr, filename="wordle_board.png") + class WordleView(ui.View): """Discord UI View for the Wordle game""" @@ -341,7 +387,9 @@ class WordleView(ui.View): async def interaction_check(self, interaction: discord.Interaction) -> bool: """Ensure only the player can interact with the game""" if interaction.user.id != self.player.id: - await interaction.response.send_message("This is not your game!", ephemeral=True) + await interaction.response.send_message( + "This is not your game!", ephemeral=True + ) return False return True @@ -357,7 +405,9 @@ class WordleView(ui.View): for letter in guess: self.used_letters.add(letter) - async def update_message(self, interaction: Optional[discord.Interaction] = None, timeout: bool = False) -> None: + async def update_message( + self, interaction: Optional[discord.Interaction] = None, timeout: bool = False + ) -> None: """Update the game message with the current state""" if not self.message: return @@ -374,15 +424,23 @@ class WordleView(ui.View): if self.game.won: content += f"\n\n🎉 You won! The word was **{self.game.word.upper()}**." elif timeout: - content += f"\n\n⏰ Time's up! The word was **{self.game.word.upper()}**." + content += ( + f"\n\n⏰ Time's up! The word was **{self.game.word.upper()}**." + ) else: - content += f"\n\n❌ Game over! The word was **{self.game.word.upper()}**." + content += ( + f"\n\n❌ Game over! The word was **{self.game.word.upper()}**." + ) # Update the message with the image if interaction: - await interaction.response.edit_message(content=content, attachments=[board_image], view=self) + await interaction.response.edit_message( + content=content, attachments=[board_image], view=self + ) else: - await self.message.edit(content=content, attachments=[board_image], view=self) + await self.message.edit( + content=content, attachments=[board_image], view=self + ) @ui.button(label="Make a Guess", style=discord.ButtonStyle.primary) async def guess_button(self, interaction: discord.Interaction, _: ui.Button): @@ -391,6 +449,7 @@ class WordleView(ui.View): modal = WordleGuessModal(self) await interaction.response.send_modal(modal) + class WordleGuessModal(ui.Modal, title="Enter your guess"): """Modal for entering a Wordle guess""" @@ -399,7 +458,7 @@ class WordleGuessModal(ui.Modal, title="Enter your guess"): placeholder="Enter a 5-letter word", min_length=5, max_length=5, - required=True + required=True, ) def __init__(self, view: WordleView): @@ -412,11 +471,15 @@ class WordleGuessModal(ui.Modal, title="Enter your guess"): # Validate the guess if len(guess) != 5: - await interaction.response.send_message("Please enter a 5-letter word.", ephemeral=True) + await interaction.response.send_message( + "Please enter a 5-letter word.", ephemeral=True + ) return if not guess.isalpha(): - await interaction.response.send_message("Your guess must contain only letters.", ephemeral=True) + await interaction.response.send_message( + "Your guess must contain only letters.", ephemeral=True + ) return # Process the guess - this is the only place where make_guess should be called @@ -429,7 +492,10 @@ class WordleGuessModal(ui.Modal, title="Enter your guess"): # Update the game message await self.wordle_view.update_message(interaction) -def load_word_list(file_path: str = "words_alpha.txt", word_length: int = 5) -> List[str]: + +def load_word_list( + file_path: str = "words_alpha.txt", word_length: int = 5 +) -> List[str]: """ Load and filter words from a file diff --git a/cogs/games_cog.py b/cogs/games_cog.py index b5d3838..5a4e977 100644 --- a/cogs/games_cog.py +++ b/cogs/games_cog.py @@ -14,8 +14,11 @@ import ast # Import game implementations from separate files from .games.chess_game import ( - generate_board_image, MoveInputModal, ChessView, ChessBotView, - get_stockfish_path + generate_board_image, + MoveInputModal, + ChessView, + ChessBotView, + get_stockfish_path, ) from .games.coinflip_game import CoinFlipView from .games.tictactoe_game import TicTacToeView, BotTicTacToeView @@ -23,19 +26,19 @@ from .games.rps_game import RockPaperScissorsView from .games.basic_games import roll_dice, flip_coin, magic8ball_response, play_hangman from .games.wordle_game import WordleView, load_word_list + class GamesCog(commands.Cog, name="Games"): """Cog for game-related commands""" def __init__(self, bot: commands.Bot): self.bot = bot # Store active bot game views to manage engine resources - self.active_chess_bot_views = {} # Store by message ID - self.ttt_games = {} # Store TicTacToe game instances by user ID + self.active_chess_bot_views = {} # Store by message ID + self.ttt_games = {} # Store TicTacToe game instances by user ID # Create the main command group for this cog self.games_group = app_commands.Group( - name="fun", - description="Play various games with the bot or other users" + name="fun", description="Play various games with the bot or other users" ) # Register commands @@ -47,11 +50,11 @@ class GamesCog(commands.Cog, name="Games"): def _array_to_fen(self, board_array: List[List[str]], turn: chess.Color) -> str: """Converts an 8x8 array representation to a basic FEN string.""" fen_rows = [] - for rank_idx in range(8): # Iterate ranks 0-7 (corresponds to 8-1 in FEN) + for rank_idx in range(8): # Iterate ranks 0-7 (corresponds to 8-1 in FEN) rank_data = board_array[rank_idx] fen_row = "" empty_count = 0 - for piece in rank_data: # Iterate files a-h + for piece in rank_data: # Iterate files a-h if piece == ".": empty_count += 1 else: @@ -65,7 +68,7 @@ class GamesCog(commands.Cog, name="Games"): fen_rows.append(fen_row) piece_placement = "/".join(fen_rows) - turn_char = 'w' if turn == chess.WHITE else 'b' + turn_char = "w" if turn == chess.WHITE else "b" # Default castling, no en passant, 0 halfmove, 1 fullmove for simplicity from array fen = f"{piece_placement} {turn_char} - - 0 1" return fen @@ -79,7 +82,7 @@ class GamesCog(commands.Cog, name="Games"): name="coinflip", description="Flip a coin and get Heads or Tails", callback=self.games_coinflip_callback, - parent=self.games_group + parent=self.games_group, ) self.games_group.add_command(coinflip_command) @@ -88,7 +91,7 @@ class GamesCog(commands.Cog, name="Games"): name="roll", description="Roll a dice and get a number between 1 and 6", callback=self.games_roll_callback, - parent=self.games_group + parent=self.games_group, ) self.games_group.add_command(roll_command) @@ -97,7 +100,7 @@ class GamesCog(commands.Cog, name="Games"): name="magic8ball", description="Ask the magic 8 ball a question", callback=self.games_magic8ball_callback, - parent=self.games_group + parent=self.games_group, ) self.games_group.add_command(magic8ball_command) @@ -107,7 +110,7 @@ class GamesCog(commands.Cog, name="Games"): name="rps", description="Play Rock-Paper-Scissors against the bot", callback=self.games_rps_callback, - parent=self.games_group + parent=self.games_group, ) self.games_group.add_command(rps_command) @@ -116,7 +119,7 @@ class GamesCog(commands.Cog, name="Games"): name="rpschallenge", description="Challenge another user to a game of Rock-Paper-Scissors", callback=self.games_rpschallenge_callback, - parent=self.games_group + parent=self.games_group, ) self.games_group.add_command(rpschallenge_command) @@ -126,7 +129,7 @@ class GamesCog(commands.Cog, name="Games"): name="guess", description="Guess the number I'm thinking of (1-100)", callback=self.games_guess_callback, - parent=self.games_group + parent=self.games_group, ) self.games_group.add_command(guess_command) @@ -135,7 +138,7 @@ class GamesCog(commands.Cog, name="Games"): name="hangman", description="Play a game of Hangman", callback=self.games_hangman_callback, - parent=self.games_group + parent=self.games_group, ) self.games_group.add_command(hangman_command) @@ -145,7 +148,7 @@ class GamesCog(commands.Cog, name="Games"): name="tictactoe", description="Challenge another user to a game of Tic-Tac-Toe", callback=self.games_tictactoe_callback, - parent=self.games_group + parent=self.games_group, ) self.games_group.add_command(tictactoe_command) @@ -154,7 +157,7 @@ class GamesCog(commands.Cog, name="Games"): name="tictactoebot", description="Play a game of Tic-Tac-Toe against the bot", callback=self.games_tictactoebot_callback, - parent=self.games_group + parent=self.games_group, ) self.games_group.add_command(tictactoebot_command) @@ -164,7 +167,7 @@ class GamesCog(commands.Cog, name="Games"): name="chess", description="Challenge another user to a game of chess", callback=self.games_chess_callback, - parent=self.games_group + parent=self.games_group, ) self.games_group.add_command(chess_command) @@ -173,7 +176,7 @@ class GamesCog(commands.Cog, name="Games"): name="chessbot", description="Play chess against the bot", callback=self.games_chessbot_callback, - parent=self.games_group + parent=self.games_group, ) self.games_group.add_command(chessbot_command) @@ -182,7 +185,7 @@ class GamesCog(commands.Cog, name="Games"): name="loadchess", description="Load a chess game from FEN, PGN, or array representation", callback=self.games_loadchess_callback, - parent=self.games_group + parent=self.games_group, ) self.games_group.add_command(loadchess_command) @@ -191,7 +194,7 @@ class GamesCog(commands.Cog, name="Games"): name="wordle", description="Play a game of Wordle - guess the 5-letter word", callback=self.games_wordle_callback, - parent=self.games_group + parent=self.games_group, ) self.games_group.add_command(wordle_command) @@ -202,7 +205,7 @@ class GamesCog(commands.Cog, name="Games"): views_to_stop = list(self.active_chess_bot_views.values()) for view in views_to_stop: await view.stop_engine() - view.stop() # Stop the view itself + view.stop() # Stop the view itself self.active_chess_bot_views.clear() print("GamesCog unloaded.") @@ -218,32 +221,34 @@ class GamesCog(commands.Cog, name="Games"): result = roll_dice() await interaction.response.send_message(f"You rolled a **{result}**! 🎲") - async def games_magic8ball_callback(self, interaction: discord.Interaction, question: str = None): + async def games_magic8ball_callback( + self, interaction: discord.Interaction, question: str = None + ): """Callback for /games dice magic8ball command""" response = magic8ball_response() await interaction.response.send_message(f"🎱 {response}") # Games group callbacks - async def games_rps_callback(self, interaction: discord.Interaction, choice: app_commands.Choice[str]): + async def games_rps_callback( + self, interaction: discord.Interaction, choice: app_commands.Choice[str] + ): """Callback for /games rps command""" choices = ["Rock", "Paper", "Scissors"] bot_choice = random.choice(choices) - user_choice = choice.value # Get value from choice + user_choice = choice.value # Get value from choice if user_choice == bot_choice: result = "It's a tie!" - elif (user_choice == "Rock" and bot_choice == "Scissors") or \ - (user_choice == "Paper" and bot_choice == "Rock") or \ - (user_choice == "Scissors" and bot_choice == "Paper"): + elif ( + (user_choice == "Rock" and bot_choice == "Scissors") + or (user_choice == "Paper" and bot_choice == "Rock") + or (user_choice == "Scissors" and bot_choice == "Paper") + ): result = "You win! 🎉" else: result = "You lose! 😢" - emojis = { - "Rock": "🪨", - "Paper": "📄", - "Scissors": "✂️" - } + emojis = {"Rock": "🪨", "Paper": "📄", "Scissors": "✂️"} await interaction.response.send_message( f"You chose **{user_choice}** {emojis[user_choice]}\n" @@ -251,15 +256,21 @@ class GamesCog(commands.Cog, name="Games"): f"{result}" ) - async def games_rpschallenge_callback(self, interaction: discord.Interaction, opponent: discord.User): + async def games_rpschallenge_callback( + self, interaction: discord.Interaction, opponent: discord.User + ): """Callback for /games rpschallenge command""" initiator = interaction.user if opponent == initiator: - await interaction.response.send_message("You cannot challenge yourself!", ephemeral=True) + await interaction.response.send_message( + "You cannot challenge yourself!", ephemeral=True + ) return if opponent.bot: - await interaction.response.send_message("You cannot challenge a bot!", ephemeral=True) + await interaction.response.send_message( + "You cannot challenge a bot!", ephemeral=True + ) return view = RockPaperScissorsView(initiator, opponent) @@ -274,15 +285,23 @@ class GamesCog(commands.Cog, name="Games"): number_to_guess = random.randint(1, 100) if guess < 1 or guess > 100: - await interaction.response.send_message("Please guess a number between 1 and 100.", ephemeral=True) + await interaction.response.send_message( + "Please guess a number between 1 and 100.", ephemeral=True + ) return if guess == number_to_guess: - await interaction.response.send_message(f"🎉 Correct! The number was **{number_to_guess}**.") + await interaction.response.send_message( + f"🎉 Correct! The number was **{number_to_guess}**." + ) elif guess < number_to_guess: - await interaction.response.send_message(f"Too low! The number was {number_to_guess}.") + await interaction.response.send_message( + f"Too low! The number was {number_to_guess}." + ) else: - await interaction.response.send_message(f"Too high! The number was {number_to_guess}.") + await interaction.response.send_message( + f"Too high! The number was {number_to_guess}." + ) async def games_hangman_callback(self, interaction: discord.Interaction): """Callback for /games hangman command""" @@ -294,7 +313,10 @@ class GamesCog(commands.Cog, name="Games"): word_list = load_word_list("words_alpha.txt", 5) if not word_list: - await interaction.response.send_message("Error: Could not load word list or no 5-letter words found.", ephemeral=True) + await interaction.response.send_message( + "Error: Could not load word list or no 5-letter words found.", + ephemeral=True, + ) return # Select a random word @@ -305,37 +327,49 @@ class GamesCog(commands.Cog, name="Games"): # Generate the initial board image from .games.wordle_game import generate_board_image + initial_board_image = generate_board_image(view.game, view.used_letters) # Send the initial game message with the image await interaction.response.send_message( "# Wordle Game\n\nGuess the 5-letter word. You have 6 attempts.", file=initial_board_image, - view=view + view=view, ) # Store the message for later updates view.message = await interaction.original_response() # TicTacToe group callbacks - async def games_tictactoe_callback(self, interaction: discord.Interaction, opponent: discord.User): + async def games_tictactoe_callback( + self, interaction: discord.Interaction, opponent: discord.User + ): """Callback for /games tictactoe play command""" initiator = interaction.user if opponent == initiator: - await interaction.response.send_message("You cannot challenge yourself!", ephemeral=True) + await interaction.response.send_message( + "You cannot challenge yourself!", ephemeral=True + ) return if opponent.bot: - await interaction.response.send_message("You cannot challenge a bot! Use `/games tictactoe bot` instead.", ephemeral=True) + await interaction.response.send_message( + "You cannot challenge a bot! Use `/games tictactoe bot` instead.", + ephemeral=True, + ) return view = TicTacToeView(initiator, opponent) initial_message = f"Tic Tac Toe: {initiator.mention} (X) vs {opponent.mention} (O)\n\nTurn: **{initiator.mention} (X)**" await interaction.response.send_message(initial_message, view=view) message = await interaction.original_response() - view.message = message # Store message for timeout handling + view.message = message # Store message for timeout handling - async def games_tictactoebot_callback(self, interaction: discord.Interaction, difficulty: app_commands.Choice[str] = None): + async def games_tictactoebot_callback( + self, + interaction: discord.Interaction, + difficulty: app_commands.Choice[str] = None, + ): """Callback for /games tictactoe bot command""" # Use default if no choice is made (discord.py handles default value assignment) difficulty_value = difficulty.value if difficulty else "minimax" @@ -344,51 +378,69 @@ class GamesCog(commands.Cog, name="Games"): try: import sys import os + parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) if parent_dir not in sys.path: sys.path.append(parent_dir) - from tictactoe import TicTacToe # Assuming tictactoe.py is in the parent directory + from tictactoe import ( + TicTacToe, + ) # Assuming tictactoe.py is in the parent directory except ImportError: - await interaction.response.send_message("Error: TicTacToe game engine module not found.", ephemeral=True) + await interaction.response.send_message( + "Error: TicTacToe game engine module not found.", ephemeral=True + ) return except Exception as e: - await interaction.response.send_message(f"Error importing TicTacToe module: {e}", ephemeral=True) - return + await interaction.response.send_message( + f"Error importing TicTacToe module: {e}", ephemeral=True + ) + return # Create a new game instance try: - game = TicTacToe(ai_player='O', ai_difficulty=difficulty_value) + game = TicTacToe(ai_player="O", ai_difficulty=difficulty_value) except Exception as e: - await interaction.response.send_message(f"Error initializing TicTacToe game: {e}", ephemeral=True) - return + await interaction.response.send_message( + f"Error initializing TicTacToe game: {e}", ephemeral=True + ) + return # Create a view for the user interface view = BotTicTacToeView(game, interaction.user) await interaction.response.send_message( f"Tic Tac Toe: {interaction.user.mention} (X) vs Bot (O) - Difficulty: {difficulty_value.capitalize()}\n\nYour turn!", - view=view + view=view, ) view.message = await interaction.original_response() # Chess group callbacks - async def games_chess_callback(self, interaction: discord.Interaction, opponent: discord.User): + async def games_chess_callback( + self, interaction: discord.Interaction, opponent: discord.User + ): """Callback for /games chess play command""" initiator = interaction.user if opponent == initiator: - await interaction.response.send_message("You cannot challenge yourself!", ephemeral=True) + await interaction.response.send_message( + "You cannot challenge yourself!", ephemeral=True + ) return if opponent.bot: - await interaction.response.send_message("You cannot challenge a bot! Use `/games chess bot` instead.", ephemeral=True) + await interaction.response.send_message( + "You cannot challenge a bot! Use `/games chess bot` instead.", + ephemeral=True, + ) return # Initiator is white, opponent is black view = ChessView(initiator, opponent) initial_status = f"Turn: **{initiator.mention}** (White)" initial_message = f"Chess: {initiator.mention} (White) vs {opponent.mention} (Black)\n\n{initial_status}" - board_image = generate_board_image(view.board) # Generate initial board image + board_image = generate_board_image(view.board) # Generate initial board image - await interaction.response.send_message(initial_message, file=board_image, view=view) + await interaction.response.send_message( + initial_message, file=board_image, view=view + ) message = await interaction.original_response() view.message = message @@ -396,7 +448,14 @@ class GamesCog(commands.Cog, name="Games"): asyncio.create_task(view._send_or_update_dm(view.white_player)) asyncio.create_task(view._send_or_update_dm(view.black_player)) - async def games_chessbot_callback(self, interaction: discord.Interaction, color: app_commands.Choice[str] = None, variant: app_commands.Choice[str] = None, skill_level: int = 10, think_time: float = 1.0): + async def games_chessbot_callback( + self, + interaction: discord.Interaction, + color: app_commands.Choice[str] = None, + variant: app_commands.Choice[str] = None, + skill_level: int = 10, + think_time: float = 1.0, + ): """Callback for /games chess bot command""" player = interaction.user player_color_str = color.value if color else "white" @@ -410,7 +469,10 @@ class GamesCog(commands.Cog, name="Games"): # Check if variant is supported (currently standard and chess960) supported_variants = ["standard", "chess960"] if variant_str not in supported_variants: - await interaction.response.send_message(f"Sorry, the variant '{variant_str}' is not currently supported. Choose from: {', '.join(supported_variants)}", ephemeral=True) + await interaction.response.send_message( + f"Sorry, the variant '{variant_str}' is not currently supported. Choose from: {', '.join(supported_variants)}", + ephemeral=True, + ) return # Defer response as engine start might take a moment @@ -422,24 +484,32 @@ class GamesCog(commands.Cog, name="Games"): # Store interaction temporarily for potential error reporting during init view._interaction = interaction await view.start_engine() - del view._interaction # Remove temporary attribute + del view._interaction # Remove temporary attribute - if view.engine is None or view.is_finished(): # Check if engine failed or view stopped during init - # Error message should have been sent by start_engine or view stopped itself - # Ensure we don't try to send another response if already handled - # No need to send another message here, start_engine handles it. - print("ChessBotView: Engine failed to start, stopping command execution.") - return # Stop if engine failed + if ( + view.engine is None or view.is_finished() + ): # Check if engine failed or view stopped during init + # Error message should have been sent by start_engine or view stopped itself + # Ensure we don't try to send another response if already handled + # No need to send another message here, start_engine handles it. + print("ChessBotView: Engine failed to start, stopping command execution.") + return # Stop if engine failed # Determine initial message based on who moves first - initial_status_prefix = "Your turn." if player_color == chess.WHITE else "Bot is thinking..." + initial_status_prefix = ( + "Your turn." if player_color == chess.WHITE else "Bot is thinking..." + ) initial_message_content = view.get_board_message(initial_status_prefix) - board_image = generate_board_image(view.board, perspective_white=(player_color == chess.WHITE)) + board_image = generate_board_image( + view.board, perspective_white=(player_color == chess.WHITE) + ) # Send the initial game state using followup - message = await interaction.followup.send(initial_message_content, file=board_image, view=view, wait=True) + message = await interaction.followup.send( + initial_message_content, file=board_image, view=view, wait=True + ) view.message = message - self.active_chess_bot_views[message.id] = view # Track the view + self.active_chess_bot_views[message.id] = view # Track the view # Send initial DM to player asyncio.create_task(view._send_or_update_dm()) @@ -449,24 +519,43 @@ class GamesCog(commands.Cog, name="Games"): # Don't await this, let it run in the background asyncio.create_task(view.make_bot_move()) - async def games_loadchess_callback(self, interaction: discord.Interaction, state: str, turn: Optional[app_commands.Choice[str]] = None, opponent: Optional[discord.User] = None, color: Optional[app_commands.Choice[str]] = None, skill_level: int = 10, think_time: float = 1.0): + async def games_loadchess_callback( + self, + interaction: discord.Interaction, + state: str, + turn: Optional[app_commands.Choice[str]] = None, + opponent: Optional[discord.User] = None, + color: Optional[app_commands.Choice[str]] = None, + skill_level: int = 10, + think_time: float = 1.0, + ): """Callback for /games chess load command""" await interaction.response.defer() initiator = interaction.user board = None load_error = None - loaded_pgn_game = None # To store the loaded PGN game object if parsed + loaded_pgn_game = None # To store the loaded PGN game object if parsed # --- Input Validation --- if not opponent and not color: - await interaction.followup.send("The 'color' parameter is required when playing against the bot.", ephemeral=True) + await interaction.followup.send( + "The 'color' parameter is required when playing against the bot.", + ephemeral=True, + ) return # --- Parsing Logic --- state_trimmed = state.strip() # 1. Try parsing as PGN - if state_trimmed.startswith("[Event") or ('.' in state_trimmed and ('O-O' in state_trimmed or 'x' in state_trimmed or state_trimmed[0].isdigit())): + if state_trimmed.startswith("[Event") or ( + "." in state_trimmed + and ( + "O-O" in state_trimmed + or "x" in state_trimmed + or state_trimmed[0].isdigit() + ) + ): try: pgn_io = io.StringIO(state_trimmed) loaded_pgn_game = chess.pgn.read_game(pgn_io) @@ -478,10 +567,14 @@ class GamesCog(commands.Cog, name="Games"): except Exception as e: load_error = f"Could not parse as PGN: {e}. Trying other formats." print(f"[Debug] PGN parsing failed: {e}") - loaded_pgn_game = None # Reset if PGN parsing failed + loaded_pgn_game = None # Reset if PGN parsing failed # 2. Try parsing as FEN (if not already parsed as PGN) - if board is None and '/' in state_trimmed and (' w ' in state_trimmed or ' b ' in state_trimmed): + if ( + board is None + and "/" in state_trimmed + and (" w " in state_trimmed or " b " in state_trimmed) + ): try: board = chess.Board(fen=state_trimmed) print(f"[Debug] Parsed as FEN: {state_trimmed}") @@ -496,18 +589,25 @@ class GamesCog(commands.Cog, name="Games"): if board is None: try: # Check if it looks like a list before eval - if not state_trimmed.startswith('[') or not state_trimmed.endswith(']'): - raise ValueError("Input does not look like a list array.") + if not state_trimmed.startswith("[") or not state_trimmed.endswith("]"): + raise ValueError("Input does not look like a list array.") board_array = ast.literal_eval(state_trimmed) print("[Debug] Attempting to parse as array...") - if not isinstance(board_array, list) or len(board_array) != 8 or \ - not all(isinstance(row, list) and len(row) == 8 for row in board_array): + if ( + not isinstance(board_array, list) + or len(board_array) != 8 + or not all( + isinstance(row, list) and len(row) == 8 for row in board_array + ) + ): raise ValueError("Invalid array structure. Must be 8x8 list.") if not turn: - load_error = "The 'turn' parameter is required when providing a board array." + load_error = ( + "The 'turn' parameter is required when providing a board array." + ) else: turn_color = chess.WHITE if turn.value == "white" else chess.BLACK fen = self._array_to_fen(board_array, turn_color) @@ -524,7 +624,9 @@ class GamesCog(commands.Cog, name="Games"): # --- Final Check and Error Handling --- if board is None: - final_error = load_error or "Failed to load board state from the provided input." + final_error = ( + load_error or "Failed to load board state from the provided input." + ) await interaction.followup.send(final_error, ephemeral=True) return @@ -532,30 +634,46 @@ class GamesCog(commands.Cog, name="Games"): if opponent: # Player vs Player if opponent == initiator: - await interaction.followup.send("You cannot challenge yourself!", ephemeral=True) + await interaction.followup.send( + "You cannot challenge yourself!", ephemeral=True + ) return if opponent.bot: - await interaction.followup.send("You cannot challenge a bot! Use `/games chess bot` or load without opponent.", ephemeral=True) + await interaction.followup.send( + "You cannot challenge a bot! Use `/games chess bot` or load without opponent.", + ephemeral=True, + ) return white_player = initiator if board.turn == chess.WHITE else opponent black_player = opponent if board.turn == chess.WHITE else initiator - view = ChessView(white_player, black_player, board=board) # Pass loaded board + view = ChessView( + white_player, black_player, board=board + ) # Pass loaded board # If loaded from PGN, set the game object in the view if loaded_pgn_game: view.game_pgn = loaded_pgn_game - view.pgn_node = loaded_pgn_game.end() # Start from the end node + view.pgn_node = loaded_pgn_game.end() # Start from the end node - current_player_mention = white_player.mention if board.turn == chess.WHITE else black_player.mention + current_player_mention = ( + white_player.mention + if board.turn == chess.WHITE + else black_player.mention + ) turn_color_name = "White" if board.turn == chess.WHITE else "Black" initial_status = f"Turn: **{current_player_mention}** ({turn_color_name})" - if board.is_check(): initial_status += " **Check!**" + if board.is_check(): + initial_status += " **Check!**" initial_message = f"Loaded Chess Game: {white_player.mention} (White) vs {black_player.mention} (Black)\n\n{initial_status}" - perspective_white = (board.turn == chess.WHITE) - board_image = generate_board_image(view.board, perspective_white=perspective_white) + perspective_white = board.turn == chess.WHITE + board_image = generate_board_image( + view.board, perspective_white=perspective_white + ) - message = await interaction.followup.send(initial_message, file=board_image, view=view, wait=True) + message = await interaction.followup.send( + initial_message, file=board_image, view=view, wait=True + ) view.message = message # Send initial DMs @@ -572,36 +690,21 @@ class GamesCog(commands.Cog, name="Games"): think_time = max(0.1, min(5.0, think_time)) variant_str = "chess960" if board.chess960 else "standard" - view = ChessBotView(player, player_color, variant_str, skill_level, think_time, board=board) # Pass loaded board + view = ChessBotView( + player, player_color, variant_str, skill_level, think_time, board=board + ) # Pass loaded board # If loaded from PGN, set the game object in the view if loaded_pgn_game: view.game_pgn = loaded_pgn_game - view.pgn_node = loaded_pgn_game.end() # Start from the end node + view.pgn_node = loaded_pgn_game.end() # Start from the end node - view._interaction = interaction # For error reporting during start + view._interaction = interaction # For error reporting during start await view.start_engine() - if hasattr(view, '_interaction'): del view._interaction + if hasattr(view, "_interaction"): + del view._interaction # --- Legacy Commands (kept for backward compatibility) --- - - - - - - - - - - - - - - - - - - # --- Prefix Commands (Legacy Support) --- @commands.command(name="coinflipbet", add_to_app_commands=False) @@ -652,17 +755,22 @@ class GamesCog(commands.Cog, name="Games"): view.message = message @commands.command(name="tictactoebot", add_to_app_commands=False) - async def tictactoebot_prefix(self, ctx: commands.Context, difficulty: str = "minimax"): + async def tictactoebot_prefix( + self, ctx: commands.Context, difficulty: str = "minimax" + ): """(Prefix) Play Tic-Tac-Toe against the bot.""" difficulty_value = difficulty.lower() valid_difficulties = ["random", "rule", "minimax"] if difficulty_value not in valid_difficulties: - await ctx.send(f"Invalid difficulty! Choose from: {', '.join(valid_difficulties)}") + await ctx.send( + f"Invalid difficulty! Choose from: {', '.join(valid_difficulties)}" + ) return try: import sys import os + parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) if parent_dir not in sys.path: sys.path.append(parent_dir) @@ -671,19 +779,19 @@ class GamesCog(commands.Cog, name="Games"): await ctx.send("Error: TicTacToe game engine module not found.") return except Exception as e: - await ctx.send(f"Error importing TicTacToe module: {e}") - return + await ctx.send(f"Error importing TicTacToe module: {e}") + return try: - game = TicTacToe(ai_player='O', ai_difficulty=difficulty_value) + game = TicTacToe(ai_player="O", ai_difficulty=difficulty_value) except Exception as e: - await ctx.send(f"Error initializing TicTacToe game: {e}") - return + await ctx.send(f"Error initializing TicTacToe game: {e}") + return view = BotTicTacToeView(game, ctx.author) message = await ctx.send( f"Tic Tac Toe: {ctx.author.mention} (X) vs Bot (O) - Difficulty: {difficulty_value.capitalize()}\n\nYour turn!", - view=view + view=view, ) view.message = message @@ -715,14 +823,16 @@ class GamesCog(commands.Cog, name="Games"): # Identical logic to slash command, just using ctx.send if user_choice == bot_choice: result = "It's a tie!" - elif (user_choice == "Rock" and bot_choice == "Scissors") or \ - (user_choice == "Paper" and bot_choice == "Rock") or \ - (user_choice == "Scissors" and bot_choice == "Paper"): + elif ( + (user_choice == "Rock" and bot_choice == "Scissors") + or (user_choice == "Paper" and bot_choice == "Rock") + or (user_choice == "Scissors" and bot_choice == "Paper") + ): result = "You win! 🎉" else: result = "You lose! 😢" - emojis = { "Rock": "🪨", "Paper": "📄", "Scissors": "✂️" } + emojis = {"Rock": "🪨", "Paper": "📄", "Scissors": "✂️"} await ctx.send( f"You chose **{user_choice}** {emojis[user_choice]}\n" f"I chose **{bot_choice}** {emojis[bot_choice]}\n\n" @@ -762,7 +872,9 @@ class GamesCog(commands.Cog, name="Games"): word_list = load_word_list("words_alpha.txt", 5) if not word_list: - await ctx.send("Error: Could not load word list or no 5-letter words found.") + await ctx.send( + "Error: Could not load word list or no 5-letter words found." + ) return # Select a random word @@ -773,13 +885,14 @@ class GamesCog(commands.Cog, name="Games"): # Generate the initial board image from .games.wordle_game import generate_board_image + initial_board_image = generate_board_image(view.game, view.used_letters) # Send the initial game message with the image message = await ctx.send( "# Wordle Game\n\nGuess the 5-letter word. You have 6 attempts.", file=initial_board_image, - view=view + view=view, ) # Store the message for later updates @@ -801,10 +914,15 @@ class GamesCog(commands.Cog, name="Games"): else: await ctx.send(f"Too high! The number was {number_to_guess}.") + async def setup(bot: commands.Bot): """Set up the GamesCog with the bot.""" print("Setting up GamesCog...") cog = GamesCog(bot) await bot.add_cog(cog) - print(f"GamesCog setup complete with command group: {[cmd.name for cmd in bot.tree.get_commands() if cmd.name == 'games']}") - print(f"Available commands: {[cmd.name for cmd in cog.games_group.walk_commands() if isinstance(cmd, app_commands.Command)]}") + print( + f"GamesCog setup complete with command group: {[cmd.name for cmd in bot.tree.get_commands() if cmd.name == 'games']}" + ) + print( + f"Available commands: {[cmd.name for cmd in cog.games_group.walk_commands() if isinstance(cmd, app_commands.Command)]}" + ) diff --git a/cogs/gelbooru_watcher_base_cog.py b/cogs/gelbooru_watcher_base_cog.py index 612574b..d4ccccd 100644 --- a/cogs/gelbooru_watcher_base_cog.py +++ b/cogs/gelbooru_watcher_base_cog.py @@ -7,20 +7,22 @@ import random import aiohttp import time import json -import typing # Need this for Optional -import uuid # For subscription IDs +import typing # Need this for Optional +import uuid # For subscription IDs import asyncio -import logging # For logging -from datetime import datetime # For parsing ISO format timestamps -import abc # For Abstract Base Class +import logging # For logging +from datetime import datetime # For parsing ISO format timestamps +import abc # For Abstract Base Class # Setup logger for this cog log = logging.getLogger(__name__) + # Combined metaclass to resolve conflicts between CogMeta and ABCMeta class GelbooruWatcherMeta(commands.CogMeta, abc.ABCMeta): pass + class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMeta): def __init__( self, @@ -83,13 +85,19 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet log.info(f"{self.cog_name}Cog loaded.") if self.session is None or self.session.closed: self.session = aiohttp.ClientSession() - log.info(f"aiohttp ClientSession (re)created during cog_load for {self.cog_name}Cog.") + log.info( + f"aiohttp ClientSession (re)created during cog_load for {self.cog_name}Cog." + ) if not self.check_new_posts.is_running(): - if self.bot.is_ready(): - self.check_new_posts.start() - log.info(f"{self.cog_name} new post checker task started from cog_load.") - else: - log.warning(f"{self.cog_name}Cog loaded but bot not ready, task start deferred.") + if self.bot.is_ready(): + self.check_new_posts.start() + log.info( + f"{self.cog_name} new post checker task started from cog_load." + ) + else: + log.warning( + f"{self.cog_name}Cog loaded but bot not ready, task start deferred." + ) async def cog_unload(self): """Clean up resources when the cog is unloaded.""" @@ -105,7 +113,9 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet with open(self.cache_file, "r") as f: return json.load(f) except Exception as e: - log.error(f"Failed to load {self.cog_name} cache file ({self.cache_file}): {e}") + log.error( + f"Failed to load {self.cog_name} cache file ({self.cache_file}): {e}" + ) return {} def _save_cache(self): @@ -113,7 +123,9 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet with open(self.cache_file, "w") as f: json.dump(self.cache_data, f, indent=4) except Exception as e: - log.error(f"Failed to save {self.cog_name} cache file ({self.cache_file}): {e}") + log.error( + f"Failed to save {self.cog_name} cache file ({self.cache_file}): {e}" + ) def _load_subscriptions(self): if os.path.exists(self.subscriptions_file): @@ -121,16 +133,22 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet with open(self.subscriptions_file, "r") as f: return json.load(f) except Exception as e: - log.error(f"Failed to load {self.cog_name} subscriptions file ({self.subscriptions_file}): {e}") + log.error( + f"Failed to load {self.cog_name} subscriptions file ({self.subscriptions_file}): {e}" + ) return {} def _save_subscriptions(self): try: with open(self.subscriptions_file, "w") as f: json.dump(self.subscriptions_data, f, indent=4) - log.debug(f"Saved {self.cog_name} subscriptions to {self.subscriptions_file}") + log.debug( + f"Saved {self.cog_name} subscriptions to {self.subscriptions_file}" + ) except Exception as e: - log.error(f"Failed to save {self.cog_name} subscriptions file ({self.subscriptions_file}): {e}") + log.error( + f"Failed to save {self.cog_name} subscriptions file ({self.subscriptions_file}): {e}" + ) def _load_pending_requests(self): if os.path.exists(self.pending_requests_file): @@ -138,48 +156,72 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet with open(self.pending_requests_file, "r") as f: return json.load(f) except Exception as e: - log.error(f"Failed to load {self.cog_name} pending requests file ({self.pending_requests_file}): {e}") + log.error( + f"Failed to load {self.cog_name} pending requests file ({self.pending_requests_file}): {e}" + ) return {} def _save_pending_requests(self): try: with open(self.pending_requests_file, "w") as f: json.dump(self.pending_requests_data, f, indent=4) - log.debug(f"Saved {self.cog_name} pending requests to {self.pending_requests_file}") + log.debug( + f"Saved {self.cog_name} pending requests to {self.pending_requests_file}" + ) except Exception as e: - log.error(f"Failed to save {self.cog_name} pending requests file ({self.pending_requests_file}): {e}") + log.error( + f"Failed to save {self.cog_name} pending requests file ({self.pending_requests_file}): {e}" + ) - async def _get_or_create_webhook(self, channel: typing.Union[discord.TextChannel, discord.ForumChannel]) -> typing.Optional[str]: + async def _get_or_create_webhook( + self, channel: typing.Union[discord.TextChannel, discord.ForumChannel] + ) -> typing.Optional[str]: if not self.session or self.session.closed: self.session = aiohttp.ClientSession() - log.info(f"Recreated aiohttp.ClientSession in _get_or_create_webhook for {self.cog_name}") + log.info( + f"Recreated aiohttp.ClientSession in _get_or_create_webhook for {self.cog_name}" + ) guild_subs = self.subscriptions_data.get(str(channel.guild.id), []) for sub in guild_subs: if sub.get("channel_id") == str(channel.id) and sub.get("webhook_url"): try: - webhook = discord.Webhook.from_url(sub["webhook_url"], session=self.session) + webhook = discord.Webhook.from_url( + sub["webhook_url"], session=self.session + ) await webhook.fetch() if webhook.channel_id == channel.id: - log.debug(f"Reusing existing webhook {webhook.id} for channel {channel.id} ({self.cog_name})") + log.debug( + f"Reusing existing webhook {webhook.id} for channel {channel.id} ({self.cog_name})" + ) return sub["webhook_url"] except (discord.NotFound, ValueError, discord.HTTPException): - log.warning(f"Found stored webhook URL for channel {channel.id} ({self.cog_name}) but it's invalid.") + log.warning( + f"Found stored webhook URL for channel {channel.id} ({self.cog_name}) but it's invalid." + ) try: webhooks = await channel.webhooks() for wh in webhooks: if wh.user == self.bot.user: - log.debug(f"Found existing bot-owned webhook {wh.id} ('{wh.name}') in channel {channel.id} ({self.cog_name})") + log.debug( + f"Found existing bot-owned webhook {wh.id} ('{wh.name}') in channel {channel.id} ({self.cog_name})" + ) return wh.url except discord.Forbidden: - log.warning(f"Missing 'Manage Webhooks' permission in channel {channel.id} ({self.cog_name}).") + log.warning( + f"Missing 'Manage Webhooks' permission in channel {channel.id} ({self.cog_name})." + ) return None except discord.HTTPException as e: - log.error(f"HTTP error listing webhooks for channel {channel.id} ({self.cog_name}): {e}") + log.error( + f"HTTP error listing webhooks for channel {channel.id} ({self.cog_name}): {e}" + ) if not channel.permissions_for(channel.guild.me).manage_webhooks: - log.warning(f"Missing 'Manage Webhooks' permission in channel {channel.id} ({self.cog_name}). Cannot create webhook.") + log.warning( + f"Missing 'Manage Webhooks' permission in channel {channel.id} ({self.cog_name}). Cannot create webhook." + ) return None try: @@ -189,16 +231,28 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet try: avatar_bytes = await self.bot.user.display_avatar.read() except Exception as e: - log.warning(f"Could not read bot avatar for webhook creation ({self.cog_name}): {e}") + log.warning( + f"Could not read bot avatar for webhook creation ({self.cog_name}): {e}" + ) - new_webhook = await channel.create_webhook(name=webhook_name, avatar=avatar_bytes, reason=f"{self.cog_name} Tag Watcher") - log.info(f"Created new webhook {new_webhook.id} ('{new_webhook.name}') in channel {channel.id} ({self.cog_name})") + new_webhook = await channel.create_webhook( + name=webhook_name, + avatar=avatar_bytes, + reason=f"{self.cog_name} Tag Watcher", + ) + log.info( + f"Created new webhook {new_webhook.id} ('{new_webhook.name}') in channel {channel.id} ({self.cog_name})" + ) return new_webhook.url except discord.HTTPException as e: - log.error(f"Failed to create webhook in {channel.mention} ({self.cog_name}): {e}") + log.error( + f"Failed to create webhook in {channel.mention} ({self.cog_name}): {e}" + ) return None except Exception as e: - log.exception(f"Unexpected error creating webhook in {channel.mention} ({self.cog_name})") + log.exception( + f"Unexpected error creating webhook in {channel.mention} ({self.cog_name})" + ) return None async def _send_via_webhook( @@ -210,7 +264,9 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet ): if not self.session or self.session.closed: self.session = aiohttp.ClientSession() - log.info(f"Recreated aiohttp.ClientSession in _send_via_webhook for {self.cog_name}") + log.info( + f"Recreated aiohttp.ClientSession in _send_via_webhook for {self.cog_name}" + ) try: webhook = discord.Webhook.from_url( @@ -221,26 +277,61 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet try: target_thread_obj = discord.Object(id=int(thread_id)) except ValueError: - log.error(f"Invalid thread_id format: {thread_id} for webhook {webhook_url[:30]} ({self.cog_name}). Sending to main channel.") + log.error( + f"Invalid thread_id format: {thread_id} for webhook {webhook_url[:30]} ({self.cog_name}). Sending to main channel." + ) await webhook.send( content=content, - username=f"{self.bot.user.name} {self.cog_name} Watcher" if self.bot.user else f"{self.cog_name} Watcher", - avatar_url=self.bot.user.display_avatar.url if self.bot.user and self.bot.user.display_avatar else None, + username=( + f"{self.bot.user.name} {self.cog_name} Watcher" + if self.bot.user + else f"{self.cog_name} Watcher" + ), + avatar_url=( + self.bot.user.display_avatar.url + if self.bot.user and self.bot.user.display_avatar + else None + ), thread=target_thread_obj, view=view, ) - log.debug(f"Sent message via webhook to {webhook_url[:30]}... (Thread: {thread_id if thread_id else 'None'}) ({self.cog_name})") + log.debug( + f"Sent message via webhook to {webhook_url[:30]}... (Thread: {thread_id if thread_id else 'None'}) ({self.cog_name})" + ) return True - except ValueError: log.error(f"Invalid webhook URL format: {webhook_url[:30]}... ({self.cog_name})") - except discord.NotFound: log.error(f"Webhook not found: {webhook_url[:30]}... ({self.cog_name})") - except discord.Forbidden: log.error(f"Forbidden to send to webhook: {webhook_url[:30]}... ({self.cog_name})") - except discord.HTTPException as e: log.error(f"HTTP error sending to webhook {webhook_url[:30]}... ({self.cog_name}): {e}") - except aiohttp.ClientError as e: log.error(f"aiohttp client error sending to webhook {webhook_url[:30]}... ({self.cog_name}): {e}") - except Exception as e: log.exception(f"Unexpected error sending to webhook {webhook_url[:30]}... ({self.cog_name}): {e}") + except ValueError: + log.error( + f"Invalid webhook URL format: {webhook_url[:30]}... ({self.cog_name})" + ) + except discord.NotFound: + log.error(f"Webhook not found: {webhook_url[:30]}... ({self.cog_name})") + except discord.Forbidden: + log.error( + f"Forbidden to send to webhook: {webhook_url[:30]}... ({self.cog_name})" + ) + except discord.HTTPException as e: + log.error( + f"HTTP error sending to webhook {webhook_url[:30]}... ({self.cog_name}): {e}" + ) + except aiohttp.ClientError as e: + log.error( + f"aiohttp client error sending to webhook {webhook_url[:30]}... ({self.cog_name}): {e}" + ) + except Exception as e: + log.exception( + f"Unexpected error sending to webhook {webhook_url[:30]}... ({self.cog_name}): {e}" + ) return False - async def _fetch_posts_logic(self, interaction_or_ctx: typing.Union[discord.Interaction, commands.Context, str], tags: str, pid_override: typing.Optional[int] = None, limit_override: typing.Optional[int] = None, hidden: bool = False) -> typing.Union[str, tuple[str, list], list]: + async def _fetch_posts_logic( + self, + interaction_or_ctx: typing.Union[discord.Interaction, commands.Context, str], + tags: str, + pid_override: typing.Optional[int] = None, + limit_override: typing.Optional[int] = None, + hidden: bool = False, + ) -> typing.Union[str, tuple[str, list], list]: all_results = [] current_pid = pid_override if pid_override is not None else 0 # API has a hard limit of 1000 results per request, so we'll use that as our per-page limit @@ -267,17 +358,19 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet # For Gelbooru-like APIs, 'rating:safe', 'rating:general', 'rating:questionable' might be SFW-ish # We'll stick to 'rating:safe' for simplicity as it was in Rule34Cog - allow_in_non_nsfw = 'rating:safe' in tags.lower() + allow_in_non_nsfw = "rating:safe" in tags.lower() if not is_nsfw_channel and not allow_in_non_nsfw: - return f'This command for {self.cog_name} can only be used in age-restricted (NSFW) channels, DMs, or with the `rating:safe` tag.' + return f"This command for {self.cog_name} can only be used in age-restricted (NSFW) channels, DMs, or with the `rating:safe` tag." is_interaction = not isinstance(interaction_or_ctx, commands.Context) if is_interaction: if not interaction_or_ctx.response.is_done(): await interaction_or_ctx.response.defer(ephemeral=hidden) - elif hasattr(interaction_or_ctx, 'reply'): # Prefix command - await interaction_or_ctx.reply(f"Fetching data from {self.cog_name}, please wait...") + elif hasattr(interaction_or_ctx, "reply"): # Prefix command + await interaction_or_ctx.reply( + f"Fetching data from {self.cog_name}, please wait..." + ) # Check cache first if not using specific pagination if not use_pagination: @@ -296,7 +389,9 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet if not self.session or self.session.closed: self.session = aiohttp.ClientSession() - log.info(f"Recreated aiohttp.ClientSession in _fetch_posts_logic for {self.cog_name}") + log.info( + f"Recreated aiohttp.ClientSession in _fetch_posts_logic for {self.cog_name}" + ) all_results = [] @@ -305,21 +400,32 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet max_pages = (total_limit + per_page_limit - 1) // per_page_limit for page in range(max_pages): # Stop if we've reached our total limit or if we got fewer results than the per-page limit - if len(all_results) >= total_limit or (page > 0 and len(all_results) % per_page_limit != 0): + if len(all_results) >= total_limit or ( + page > 0 and len(all_results) % per_page_limit != 0 + ): break api_params = { - "page": "dapi", "s": "post", "q": "index", - "limit": per_page_limit, "pid": page, "tags": tags, "json": 1 + "page": "dapi", + "s": "post", + "q": "index", + "limit": per_page_limit, + "pid": page, + "tags": tags, + "json": 1, } try: - async with self.session.get(self.api_base_url, params=api_params) as response: + async with self.session.get( + self.api_base_url, params=api_params + ) as response: if response.status == 200: try: data = await response.json() except aiohttp.ContentTypeError: - log.warning(f"{self.cog_name} API returned non-JSON for tags: {tags}, pid: {page}, params: {api_params}") + log.warning( + f"{self.cog_name} API returned non-JSON for tags: {tags}, pid: {page}, params: {api_params}" + ) data = None if data and isinstance(data, list): @@ -331,20 +437,28 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet # Empty page, no more results break else: - log.warning(f"Unexpected API response format from {self.cog_name} (not list or empty list): {data} for tags: {tags}, pid: {page}, params: {api_params}") + log.warning( + f"Unexpected API response format from {self.cog_name} (not list or empty list): {data} for tags: {tags}, pid: {page}, params: {api_params}" + ) break else: - log.error(f"Failed to fetch {self.cog_name} data. HTTP Status: {response.status} for tags: {tags}, pid: {page}, params: {api_params}") + log.error( + f"Failed to fetch {self.cog_name} data. HTTP Status: {response.status} for tags: {tags}, pid: {page}, params: {api_params}" + ) if page == 0: # Only return error if first page fails return f"Failed to fetch data from {self.cog_name}. HTTP Status: {response.status}" break except aiohttp.ClientError as e: - log.error(f"aiohttp.ClientError in _fetch_posts_logic for {self.cog_name} tags {tags}: {e}") + log.error( + f"aiohttp.ClientError in _fetch_posts_logic for {self.cog_name} tags {tags}: {e}" + ) if page == 0: # Only return error if first page fails return f"Network error fetching data from {self.cog_name}: {e}" break except Exception as e: - log.exception(f"Unexpected error in _fetch_posts_logic API call for {self.cog_name} tags {tags}: {e}") + log.exception( + f"Unexpected error in _fetch_posts_logic API call for {self.cog_name} tags {tags}: {e}" + ) if page == 0: # Only return error if first page fails return f"An unexpected error occurred during {self.cog_name} API call: {e}" break @@ -356,35 +470,55 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet else: # Single request with specific pid/limit api_params = { - "page": "dapi", "s": "post", "q": "index", - "limit": api_limit, "pid": current_pid, "tags": tags, "json": 1 + "page": "dapi", + "s": "post", + "q": "index", + "limit": api_limit, + "pid": current_pid, + "tags": tags, + "json": 1, } try: - async with self.session.get(self.api_base_url, params=api_params) as response: + async with self.session.get( + self.api_base_url, params=api_params + ) as response: if response.status == 200: try: data = await response.json() except aiohttp.ContentTypeError: - log.warning(f"{self.cog_name} API returned non-JSON for tags: {tags}, pid: {current_pid}, params: {api_params}") + log.warning( + f"{self.cog_name} API returned non-JSON for tags: {tags}, pid: {current_pid}, params: {api_params}" + ) data = None if data and isinstance(data, list): all_results.extend(data) - elif isinstance(data, list) and len(data) == 0: pass + elif isinstance(data, list) and len(data) == 0: + pass else: - log.warning(f"Unexpected API response format from {self.cog_name} (not list or empty list): {data} for tags: {tags}, pid: {current_pid}, params: {api_params}") + log.warning( + f"Unexpected API response format from {self.cog_name} (not list or empty list): {data} for tags: {tags}, pid: {current_pid}, params: {api_params}" + ) if pid_override is not None or limit_override is not None: return f"Unexpected API response format from {self.cog_name}: {response.status}" else: - log.error(f"Failed to fetch {self.cog_name} data. HTTP Status: {response.status} for tags: {tags}, pid: {current_pid}, params: {api_params}") + log.error( + f"Failed to fetch {self.cog_name} data. HTTP Status: {response.status} for tags: {tags}, pid: {current_pid}, params: {api_params}" + ) return f"Failed to fetch data from {self.cog_name}. HTTP Status: {response.status}" except aiohttp.ClientError as e: - log.error(f"aiohttp.ClientError in _fetch_posts_logic for {self.cog_name} tags {tags}: {e}") + log.error( + f"aiohttp.ClientError in _fetch_posts_logic for {self.cog_name} tags {tags}: {e}" + ) return f"Network error fetching data from {self.cog_name}: {e}" except Exception as e: - log.exception(f"Unexpected error in _fetch_posts_logic API call for {self.cog_name} tags {tags}: {e}") - return f"An unexpected error occurred during {self.cog_name} API call: {e}" + log.exception( + f"Unexpected error in _fetch_posts_logic API call for {self.cog_name} tags {tags}: {e}" + ) + return ( + f"An unexpected error occurred during {self.cog_name} API call: {e}" + ) if pid_override is not None or limit_override is not None: return all_results @@ -393,7 +527,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet cache_key = tags.lower().strip() self.cache_data[cache_key] = { "timestamp": int(time.time()), - "results": all_results + "results": all_results, } self._save_cache() @@ -408,7 +542,13 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet container = ui.Container() buttons = ui.ActionRow() - def __init__(self, cog: 'GelbooruWatcherBaseCog', tags: str, all_results: list, hidden: bool = False): + def __init__( + self, + cog: "GelbooruWatcherBaseCog", + tags: str, + all_results: list, + hidden: bool = False, + ): super().__init__(timeout=300) self.cog = cog self.tags = tags @@ -431,44 +571,75 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet self.container.add_item(ui.TextDisplay(post_url)) @buttons.button(label="New Random", style=discord.ButtonStyle.primary) - async def new_random(self, interaction: discord.Interaction, button: discord.ui.Button): + async def new_random( + self, interaction: discord.Interaction, button: discord.ui.Button + ): random_result = random.choice(self.all_results) self._update_container(random_result) await interaction.response.edit_message(content="", view=self) - @buttons.button(label="Random In New Message", style=discord.ButtonStyle.success) - async def new_message(self, interaction: discord.Interaction, button: discord.ui.Button): + @buttons.button( + label="Random In New Message", style=discord.ButtonStyle.success + ) + async def new_message( + self, interaction: discord.Interaction, button: discord.ui.Button + ): random_result = random.choice(self.all_results) - new_view = self.cog.GelbooruButtons(self.cog, self.tags, self.all_results, self.hidden) + new_view = self.cog.GelbooruButtons( + self.cog, self.tags, self.all_results, self.hidden + ) new_view._update_container(random_result) - await interaction.response.send_message(content="", view=new_view, ephemeral=self.hidden) + await interaction.response.send_message( + content="", view=new_view, ephemeral=self.hidden + ) @buttons.button(label="Browse Results", style=discord.ButtonStyle.secondary) - async def browse_results(self, interaction: discord.Interaction, button: discord.ui.Button): + async def browse_results( + self, interaction: discord.Interaction, button: discord.ui.Button + ): if not self.all_results: - await interaction.response.send_message("No results to browse.", ephemeral=True) + await interaction.response.send_message( + "No results to browse.", ephemeral=True + ) return self.current_index = 0 - view = self.cog.BrowseView(self.cog, self.tags, self.all_results, self.hidden, self.current_index) + view = self.cog.BrowseView( + self.cog, self.tags, self.all_results, self.hidden, self.current_index + ) await interaction.response.edit_message(content="", view=view) @buttons.button(label="Pin", style=discord.ButtonStyle.danger) - async def pin_message(self, interaction: discord.Interaction, button: discord.ui.Button): + async def pin_message( + self, interaction: discord.Interaction, button: discord.ui.Button + ): if interaction.message: try: await interaction.message.pin() - await interaction.response.send_message("Message pinned successfully!", ephemeral=True) + await interaction.response.send_message( + "Message pinned successfully!", ephemeral=True + ) except discord.Forbidden: - await interaction.response.send_message("I don't have permission to pin messages.", ephemeral=True) + await interaction.response.send_message( + "I don't have permission to pin messages.", ephemeral=True + ) except discord.HTTPException as e: - await interaction.response.send_message(f"Failed to pin: {e}", ephemeral=True) + await interaction.response.send_message( + f"Failed to pin: {e}", ephemeral=True + ) class BrowseView(ui.LayoutView): container = ui.Container() nav_row = ui.ActionRow() extra_row = ui.ActionRow() - def __init__(self, cog: 'GelbooruWatcherBaseCog', tags: str, all_results: list, hidden: bool = False, current_index: int = 0): + def __init__( + self, + cog: "GelbooruWatcherBaseCog", + tags: str, + all_results: list, + hidden: bool = False, + current_index: int = 0, + ): super().__init__(timeout=300) self.cog = cog self.tags = tags @@ -495,27 +666,41 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet await interaction.response.edit_message(content="", view=self) @nav_row.button(label="First", style=discord.ButtonStyle.secondary, emoji="⏪") - async def first(self, interaction: discord.Interaction, button: discord.ui.Button): + async def first( + self, interaction: discord.Interaction, button: discord.ui.Button + ): self.current_index = 0 await self._update_message(interaction) - @nav_row.button(label="Previous", style=discord.ButtonStyle.secondary, emoji="◀️") - async def previous(self, interaction: discord.Interaction, button: discord.ui.Button): - self.current_index = (self.current_index - 1 + len(self.all_results)) % len(self.all_results) + @nav_row.button( + label="Previous", style=discord.ButtonStyle.secondary, emoji="◀️" + ) + async def previous( + self, interaction: discord.Interaction, button: discord.ui.Button + ): + self.current_index = (self.current_index - 1 + len(self.all_results)) % len( + self.all_results + ) await self._update_message(interaction) @nav_row.button(label="Next", style=discord.ButtonStyle.primary, emoji="▶️") - async def next_result(self, interaction: discord.Interaction, button: discord.ui.Button): + async def next_result( + self, interaction: discord.Interaction, button: discord.ui.Button + ): self.current_index = (self.current_index + 1) % len(self.all_results) await self._update_message(interaction) @nav_row.button(label="Last", style=discord.ButtonStyle.secondary, emoji="⏩") - async def last(self, interaction: discord.Interaction, button: discord.ui.Button): + async def last( + self, interaction: discord.Interaction, button: discord.ui.Button + ): self.current_index = len(self.all_results) - 1 await self._update_message(interaction) @extra_row.button(label="Go To", style=discord.ButtonStyle.primary) - async def goto(self, interaction: discord.Interaction, button: discord.ui.Button): + async def goto( + self, interaction: discord.Interaction, button: discord.ui.Button + ): modal = self.cog.GoToModal(len(self.all_results)) await interaction.response.send_modal(modal) await modal.wait() @@ -525,11 +710,17 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet interaction.message.id, content="", view=self ) - @extra_row.button(label="Back to Main Controls", style=discord.ButtonStyle.danger) - async def back_to_main(self, interaction: discord.Interaction, button: discord.ui.Button): + @extra_row.button( + label="Back to Main Controls", style=discord.ButtonStyle.danger + ) + async def back_to_main( + self, interaction: discord.Interaction, button: discord.ui.Button + ): # Send a random one from all_results as the content - if not self.all_results: # Should not happen if browse was initiated - await interaction.response.edit_message(content="No results available.", view=None) + if not self.all_results: # Should not happen if browse was initiated + await interaction.response.edit_message( + content="No results available.", view=None + ) return random_result = random.choice(self.all_results) view = self.cog.GelbooruButtons( @@ -547,7 +738,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet label=f"Page Number (1-{max_pages})", placeholder=f"Enter a number between 1 and {max_pages}", min_length=1, - max_length=len(str(max_pages)) + max_length=len(str(max_pages)), ) self.add_item(self.page_num) @@ -556,13 +747,20 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet num = int(self.page_num.value) if 1 <= num <= self.max_pages: self.value = num - await interaction.response.defer() # Defer here, followup in BrowseView.goto + await interaction.response.defer() # Defer here, followup in BrowseView.goto else: - await interaction.response.send_message(f"Please enter a number between 1 and {self.max_pages}",ephemeral=True) + await interaction.response.send_message( + f"Please enter a number between 1 and {self.max_pages}", + ephemeral=True, + ) except ValueError: - await interaction.response.send_message("Please enter a valid number",ephemeral=True) + await interaction.response.send_message( + "Please enter a valid number", ephemeral=True + ) - def _build_new_post_view(self, tags: str, file_url: str, post_id: int) -> ui.LayoutView: + def _build_new_post_view( + self, tags: str, file_url: str, post_id: int + ) -> ui.LayoutView: view = ui.LayoutView(timeout=None) container = ui.Container() view.add_item(container) @@ -570,7 +768,9 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet gallery = ui.MediaGallery() gallery.add_item(media=file_url) container.add_item(gallery) - container.add_item(ui.TextDisplay(f"New {self.cog_name} post for tags `{tags}`:")) + container.add_item( + ui.TextDisplay(f"New {self.cog_name} post for tags `{tags}`:") + ) post_url = self.post_url_template.format(post_id) container.add_item(ui.TextDisplay(post_url)) @@ -584,7 +784,7 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet _, all_results = response view = self.GelbooruButtons(self, tags, all_results, hidden=False) # _fetch_posts_logic sends initial reply, so we edit it - original_message = ctx.message # This might not be the bot's reply message + original_message = ctx.message # This might not be the bot's reply message # We need a reference to the "Fetching data..." message. # This requires _fetch_posts_logic to return it or for the command to manage it. # For now, let's assume _fetch_posts_logic's "reply" is what we want to edit. @@ -601,20 +801,26 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet # This needs a slight refactor. # For now, we'll assume the user's message is replied to, and we send a new one. This is not ideal. # A better way: - if ctx.message.reference and ctx.message.reference.message_id: # If the bot replied + if ( + ctx.message.reference and ctx.message.reference.message_id + ): # If the bot replied try: - reply_msg = await ctx.channel.fetch_message(ctx.message.reference.message_id) + reply_msg = await ctx.channel.fetch_message( + ctx.message.reference.message_id + ) if reply_msg.author == self.bot.user: await reply_msg.edit(content="", view=view) return - except (discord.NotFound, discord.Forbidden): pass # Fallback to new message - await ctx.send("", view=view) # Fallback - elif isinstance(response, str): # Error + except (discord.NotFound, discord.Forbidden): + pass # Fallback to new message + await ctx.send("", view=view) # Fallback + elif isinstance(response, str): # Error # Similar issue with editing the reply. await ctx.send(response) - - async def _slash_command_logic(self, interaction: discord.Interaction, tags: str, hidden: bool): + async def _slash_command_logic( + self, interaction: discord.Interaction, tags: str, hidden: bool + ): response = await self._fetch_posts_logic(interaction, tags, hidden=hidden) if isinstance(response, tuple): @@ -623,21 +829,33 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet if interaction.response.is_done(): await interaction.followup.send(content="", view=view, ephemeral=hidden) else: - await interaction.response.send_message(content="", view=view, ephemeral=hidden) - elif isinstance(response, str): # Error + await interaction.response.send_message( + content="", view=view, ephemeral=hidden + ) + elif isinstance(response, str): # Error ephemeral_error = hidden - if self.is_nsfw_site and response.startswith(f'This command for {self.cog_name} can only be used'): - ephemeral_error = True # Always make NSFW warnings ephemeral if possible + if self.is_nsfw_site and response.startswith( + f"This command for {self.cog_name} can only be used" + ): + ephemeral_error = ( + True # Always make NSFW warnings ephemeral if possible + ) if not interaction.response.is_done(): - await interaction.response.send_message(response, ephemeral=ephemeral_error) + await interaction.response.send_message( + response, ephemeral=ephemeral_error + ) else: try: await interaction.followup.send(response, ephemeral=ephemeral_error) except discord.HTTPException as e: - log.error(f"{self.cog_name} slash command: Failed to send error followup for tags '{tags}': {e}") + log.error( + f"{self.cog_name} slash command: Failed to send error followup for tags '{tags}': {e}" + ) - async def _browse_slash_command_logic(self, interaction: discord.Interaction, tags: str, hidden: bool): + async def _browse_slash_command_logic( + self, interaction: discord.Interaction, tags: str, hidden: bool + ): response = await self._fetch_posts_logic(interaction, tags, hidden=hidden) if isinstance(response, tuple): @@ -654,13 +872,21 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet if interaction.response.is_done(): await interaction.followup.send(content="", view=view, ephemeral=hidden) else: - await interaction.response.send_message(content="", view=view, ephemeral=hidden) - elif isinstance(response, str): # Error + await interaction.response.send_message( + content="", view=view, ephemeral=hidden + ) + elif isinstance(response, str): # Error ephemeral_error = hidden - if self.is_nsfw_site and response.startswith(f'This command for {self.cog_name} can only be used'): + if self.is_nsfw_site and response.startswith( + f"This command for {self.cog_name} can only be used" + ): ephemeral_error = True - if not interaction.response.is_done(): await interaction.response.send_message(response, ephemeral=ephemeral_error) - else: await interaction.followup.send(response, ephemeral=ephemeral_error) + if not interaction.response.is_done(): + await interaction.response.send_message( + response, ephemeral=ephemeral_error + ) + else: + await interaction.followup.send(response, ephemeral=ephemeral_error) @tasks.loop(minutes=10) async def check_new_posts(self): @@ -672,45 +898,68 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet needs_save = False for guild_id_str, subs_list in current_subscriptions.items(): - if not isinstance(subs_list, list): continue + if not isinstance(subs_list, list): + continue for sub_index, sub in enumerate(subs_list): - if not isinstance(sub, dict): continue + if not isinstance(sub, dict): + continue tags = sub.get("tags") webhook_url = sub.get("webhook_url") last_known_post_id = sub.get("last_known_post_id", 0) if not tags or not webhook_url: - log.warning(f"Subscription for {self.cog_name} (guild {guild_id_str}, sub {sub.get('subscription_id')}) missing tags/webhook.") + log.warning( + f"Subscription for {self.cog_name} (guild {guild_id_str}, sub {sub.get('subscription_id')}) missing tags/webhook." + ) continue - fetched_posts_response = await self._fetch_posts_logic("internal_task_call", tags, pid_override=0, limit_override=100) + fetched_posts_response = await self._fetch_posts_logic( + "internal_task_call", tags, pid_override=0, limit_override=100 + ) if isinstance(fetched_posts_response, str): - log.error(f"Error fetching {self.cog_name} posts for sub {sub.get('subscription_id')} (tags: {tags}): {fetched_posts_response}") + log.error( + f"Error fetching {self.cog_name} posts for sub {sub.get('subscription_id')} (tags: {tags}): {fetched_posts_response}" + ) + continue + if not fetched_posts_response: continue - if not fetched_posts_response: continue new_posts_to_send = [] current_max_id_this_batch = last_known_post_id for post_data in fetched_posts_response: - if not isinstance(post_data, dict) or "id" not in post_data or "file_url" not in post_data: - log.warning(f"Malformed {self.cog_name} post data for tags {tags}: {post_data}") + if ( + not isinstance(post_data, dict) + or "id" not in post_data + or "file_url" not in post_data + ): + log.warning( + f"Malformed {self.cog_name} post data for tags {tags}: {post_data}" + ) continue post_id = int(post_data["id"]) - if post_id > last_known_post_id: new_posts_to_send.append(post_data) - if post_id > current_max_id_this_batch: current_max_id_this_batch = post_id + if post_id > last_known_post_id: + new_posts_to_send.append(post_data) + if post_id > current_max_id_this_batch: + current_max_id_this_batch = post_id if new_posts_to_send: new_posts_to_send.sort(key=lambda p: int(p["id"])) - log.info(f"Found {len(new_posts_to_send)} new {self.cog_name} post(s) for tags '{tags}', sub {sub.get('subscription_id')}") + log.info( + f"Found {len(new_posts_to_send)} new {self.cog_name} post(s) for tags '{tags}', sub {sub.get('subscription_id')}" + ) latest_sent_id_for_this_sub = last_known_post_id for new_post in new_posts_to_send: post_id = int(new_post["id"]) - target_thread_id: typing.Optional[str] = sub.get("target_post_id") or sub.get("thread_id") + target_thread_id: typing.Optional[str] = sub.get( + "target_post_id" + ) or sub.get("thread_id") - view = self._build_new_post_view(tags, new_post["file_url"], post_id) + view = self._build_new_post_view( + tags, new_post["file_url"], post_id + ) send_success = await self._send_via_webhook( webhook_url, @@ -720,27 +969,47 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet ) if send_success: latest_sent_id_for_this_sub = post_id - original_guild_subs = self.subscriptions_data.get(guild_id_str) + original_guild_subs = self.subscriptions_data.get( + guild_id_str + ) if original_guild_subs: for original_sub_entry in original_guild_subs: - if original_sub_entry.get("subscription_id") == sub.get("subscription_id"): - original_sub_entry["last_known_post_id"] = latest_sent_id_for_this_sub + if original_sub_entry.get( + "subscription_id" + ) == sub.get("subscription_id"): + original_sub_entry["last_known_post_id"] = ( + latest_sent_id_for_this_sub + ) needs_save = True break await asyncio.sleep(1) else: - log.error(f"Failed to send new {self.cog_name} post via webhook for sub {sub.get('subscription_id')}.") + log.error( + f"Failed to send new {self.cog_name} post via webhook for sub {sub.get('subscription_id')}." + ) break - if not new_posts_to_send and current_max_id_this_batch > last_known_post_id: + if ( + not new_posts_to_send + and current_max_id_this_batch > last_known_post_id + ): original_guild_subs = self.subscriptions_data.get(guild_id_str) if original_guild_subs: for original_sub_entry in original_guild_subs: - if original_sub_entry.get("subscription_id") == sub.get("subscription_id"): - if original_sub_entry.get("last_known_post_id", 0) < current_max_id_this_batch: - original_sub_entry["last_known_post_id"] = current_max_id_this_batch + if original_sub_entry.get("subscription_id") == sub.get( + "subscription_id" + ): + if ( + original_sub_entry.get("last_known_post_id", 0) + < current_max_id_this_batch + ): + original_sub_entry["last_known_post_id"] = ( + current_max_id_this_batch + ) needs_save = True - log.debug(f"Fast-forwarded last_known_post_id for {self.cog_name} sub {sub.get('subscription_id')} to {current_max_id_this_batch}.") + log.debug( + f"Fast-forwarded last_known_post_id for {self.cog_name} sub {sub.get('subscription_id')} to {current_max_id_this_batch}." + ) break if needs_save: self._save_subscriptions() @@ -749,91 +1018,158 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet @check_new_posts.before_loop async def before_check_new_posts(self): await self.bot.wait_until_ready() - log.info(f"{self.cog_name}Cog: `check_new_posts` loop is waiting for bot readiness...") + log.info( + f"{self.cog_name}Cog: `check_new_posts` loop is waiting for bot readiness..." + ) if self.session is None or self.session.closed: self.session = aiohttp.ClientSession() - log.info(f"aiohttp ClientSession created before {self.cog_name} check_new_posts loop.") + log.info( + f"aiohttp ClientSession created before {self.cog_name} check_new_posts loop." + ) - async def _create_new_subscription(self, guild_id: int, user_id: int, tags: str, - target_channel: typing.Union[discord.TextChannel, discord.ForumChannel], - requested_thread_target: typing.Optional[str] = None, - requested_post_title: typing.Optional[str] = None, - is_request_approval: bool = False, - requester_mention: typing.Optional[str] = None) -> str: + async def _create_new_subscription( + self, + guild_id: int, + user_id: int, + tags: str, + target_channel: typing.Union[discord.TextChannel, discord.ForumChannel], + requested_thread_target: typing.Optional[str] = None, + requested_post_title: typing.Optional[str] = None, + is_request_approval: bool = False, + requester_mention: typing.Optional[str] = None, + ) -> str: actual_target_thread_id: typing.Optional[str] = None actual_target_thread_mention: str = "" - actual_post_title = requested_post_title or f"{self.cog_name} Watch: {tags[:50]}" + actual_post_title = ( + requested_post_title or f"{self.cog_name} Watch: {tags[:50]}" + ) if isinstance(target_channel, discord.TextChannel): if requested_thread_target: found_thread: typing.Optional[discord.Thread] = None try: - if not target_channel.guild: return "Error: Guild context not found." - thread_as_obj = await target_channel.guild.fetch_channel(int(requested_thread_target)) - if isinstance(thread_as_obj, discord.Thread) and thread_as_obj.parent_id == target_channel.id: + if not target_channel.guild: + return "Error: Guild context not found." + thread_as_obj = await target_channel.guild.fetch_channel( + int(requested_thread_target) + ) + if ( + isinstance(thread_as_obj, discord.Thread) + and thread_as_obj.parent_id == target_channel.id + ): found_thread = thread_as_obj - except (ValueError, discord.NotFound, discord.Forbidden): pass - except Exception as e: log.error(f"Error fetching thread by ID '{requested_thread_target}' for {self.cog_name}: {e}") + except (ValueError, discord.NotFound, discord.Forbidden): + pass + except Exception as e: + log.error( + f"Error fetching thread by ID '{requested_thread_target}' for {self.cog_name}: {e}" + ) - if not found_thread and hasattr(target_channel, 'threads'): + if not found_thread and hasattr(target_channel, "threads"): for t in target_channel.threads: - if t.name.lower() == requested_thread_target.lower(): found_thread = t; break + if t.name.lower() == requested_thread_target.lower(): + found_thread = t + break if found_thread: actual_target_thread_id = str(found_thread.id) actual_target_thread_mention = found_thread.mention - else: return f"❌ Could not find thread `{requested_thread_target}` in {target_channel.mention} for {self.cog_name}." + else: + return f"❌ Could not find thread `{requested_thread_target}` in {target_channel.mention} for {self.cog_name}." elif isinstance(target_channel, discord.ForumChannel): forum_post_initial_message = f"✨ **New {self.cog_name} Watch Initialized!** ✨\nNow monitoring tags: `{tags}`" - if is_request_approval and requester_mention: forum_post_initial_message += f"\n_Requested by: {requester_mention}_" + if is_request_approval and requester_mention: + forum_post_initial_message += f"\n_Requested by: {requester_mention}_" try: - if not target_channel.permissions_for(target_channel.guild.me).create_public_threads: + if not target_channel.permissions_for( + target_channel.guild.me + ).create_public_threads: return f"❌ I lack permission to create posts in forum {target_channel.mention} for {self.cog_name}." - new_forum_post = await target_channel.create_thread(name=actual_post_title, content=forum_post_initial_message, reason=f"{self.cog_name}Watch: {tags}") + new_forum_post = await target_channel.create_thread( + name=actual_post_title, + content=forum_post_initial_message, + reason=f"{self.cog_name}Watch: {tags}", + ) actual_target_thread_id = str(new_forum_post.thread.id) actual_target_thread_mention = new_forum_post.thread.mention - log.info(f"Created {self.cog_name} forum post {new_forum_post.thread.id} for tags '{tags}' in forum {target_channel.id}") - except discord.HTTPException as e: return f"❌ Failed to create {self.cog_name} forum post in {target_channel.mention}: {e}" + log.info( + f"Created {self.cog_name} forum post {new_forum_post.thread.id} for tags '{tags}' in forum {target_channel.id}" + ) + except discord.HTTPException as e: + return f"❌ Failed to create {self.cog_name} forum post in {target_channel.mention}: {e}" except Exception as e: - log.exception(f"Unexpected error creating {self.cog_name} forum post for tags '{tags}'") + log.exception( + f"Unexpected error creating {self.cog_name} forum post for tags '{tags}'" + ) return f"❌ Unexpected error creating {self.cog_name} forum post." webhook_url = await self._get_or_create_webhook(target_channel) - if not webhook_url: return f"❌ Failed to get/create webhook for {target_channel.mention} ({self.cog_name}). Check permissions." + if not webhook_url: + return f"❌ Failed to get/create webhook for {target_channel.mention} ({self.cog_name}). Check permissions." - initial_posts = await self._fetch_posts_logic("internal_initial_fetch", tags, pid_override=0, limit_override=1) + initial_posts = await self._fetch_posts_logic( + "internal_initial_fetch", tags, pid_override=0, limit_override=1 + ) last_known_post_id = 0 if isinstance(initial_posts, list) and initial_posts: - if isinstance(initial_posts[0], dict) and "id" in initial_posts[0]: last_known_post_id = int(initial_posts[0]["id"]) - else: log.warning(f"Malformed {self.cog_name} post data for initial fetch (tags: '{tags}'): {initial_posts[0]}") - elif isinstance(initial_posts, str): log.error(f"{self.cog_name} API error on initial fetch (tags: '{tags}'): {initial_posts}") + if isinstance(initial_posts[0], dict) and "id" in initial_posts[0]: + last_known_post_id = int(initial_posts[0]["id"]) + else: + log.warning( + f"Malformed {self.cog_name} post data for initial fetch (tags: '{tags}'): {initial_posts[0]}" + ) + elif isinstance(initial_posts, str): + log.error( + f"{self.cog_name} API error on initial fetch (tags: '{tags}'): {initial_posts}" + ) guild_id_str = str(guild_id) subscription_id = str(uuid.uuid4()) new_sub_data = { - "subscription_id": subscription_id, "tags": tags.strip(), "webhook_url": webhook_url, - "last_known_post_id": last_known_post_id, "added_by_user_id": str(user_id), - "added_timestamp": discord.utils.utcnow().isoformat() + "subscription_id": subscription_id, + "tags": tags.strip(), + "webhook_url": webhook_url, + "last_known_post_id": last_known_post_id, + "added_by_user_id": str(user_id), + "added_timestamp": discord.utils.utcnow().isoformat(), } if isinstance(target_channel, discord.ForumChannel): - new_sub_data.update({"forum_channel_id": str(target_channel.id), "target_post_id": actual_target_thread_id, "post_title": actual_post_title}) + new_sub_data.update( + { + "forum_channel_id": str(target_channel.id), + "target_post_id": actual_target_thread_id, + "post_title": actual_post_title, + } + ) else: - new_sub_data.update({"channel_id": str(target_channel.id), "thread_id": actual_target_thread_id}) + new_sub_data.update( + { + "channel_id": str(target_channel.id), + "thread_id": actual_target_thread_id, + } + ) - if guild_id_str not in self.subscriptions_data: self.subscriptions_data[guild_id_str] = [] + if guild_id_str not in self.subscriptions_data: + self.subscriptions_data[guild_id_str] = [] - for existing_sub in self.subscriptions_data[guild_id_str]: # Duplicate check + for existing_sub in self.subscriptions_data[guild_id_str]: # Duplicate check is_dup_tags = existing_sub.get("tags") == new_sub_data["tags"] - is_dup_forum = isinstance(target_channel, discord.ForumChannel) and \ - existing_sub.get("forum_channel_id") == new_sub_data.get("forum_channel_id") and \ - existing_sub.get("target_post_id") == new_sub_data.get("target_post_id") - is_dup_text_chan = isinstance(target_channel, discord.TextChannel) and \ - existing_sub.get("channel_id") == new_sub_data.get("channel_id") and \ - existing_sub.get("thread_id") == new_sub_data.get("thread_id") + is_dup_forum = ( + isinstance(target_channel, discord.ForumChannel) + and existing_sub.get("forum_channel_id") + == new_sub_data.get("forum_channel_id") + and existing_sub.get("target_post_id") + == new_sub_data.get("target_post_id") + ) + is_dup_text_chan = ( + isinstance(target_channel, discord.TextChannel) + and existing_sub.get("channel_id") == new_sub_data.get("channel_id") + and existing_sub.get("thread_id") == new_sub_data.get("thread_id") + ) if is_dup_tags and (is_dup_forum or is_dup_text_chan): return f"⚠️ A {self.cog_name} subscription for these tags in this location already exists (ID: `{existing_sub.get('subscription_id')}`)." @@ -841,37 +1177,73 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet self._save_subscriptions() target_desc = f"in {target_channel.mention}" - if isinstance(target_channel, discord.ForumChannel) and actual_target_thread_mention: target_desc = f"in forum post {actual_target_thread_mention} within {target_channel.mention}" - elif actual_target_thread_mention: target_desc += f" (thread: {actual_target_thread_mention})" + if ( + isinstance(target_channel, discord.ForumChannel) + and actual_target_thread_mention + ): + target_desc = f"in forum post {actual_target_thread_mention} within {target_channel.mention}" + elif actual_target_thread_mention: + target_desc += f" (thread: {actual_target_thread_mention})" - log.info(f"{self.cog_name} subscription added: Guild {guild_id_str}, Tags '{tags}', Target {target_desc}, SubID {subscription_id}") - return (f"✅ Watching {self.cog_name} for new posts with tags `{tags}` {target_desc}.\n" - f"Initial latest post ID: {last_known_post_id}. Sub ID: `{subscription_id}`.") + log.info( + f"{self.cog_name} subscription added: Guild {guild_id_str}, Tags '{tags}', Target {target_desc}, SubID {subscription_id}" + ) + return ( + f"✅ Watching {self.cog_name} for new posts with tags `{tags}` {target_desc}.\n" + f"Initial latest post ID: {last_known_post_id}. Sub ID: `{subscription_id}`." + ) # --- Watch command group logic methods --- # These will be called by the specific cog's command handlers - async def _watch_add_logic(self, interaction: discord.Interaction, tags: str, channel: typing.Union[discord.TextChannel, discord.ForumChannel], thread_target: typing.Optional[str], post_title: typing.Optional[str]): - if not interaction.guild_id or not interaction.user: # Should be caught by permissions but good check - await interaction.followup.send("Command error: Missing guild or user context.", ephemeral=True) + async def _watch_add_logic( + self, + interaction: discord.Interaction, + tags: str, + channel: typing.Union[discord.TextChannel, discord.ForumChannel], + thread_target: typing.Optional[str], + post_title: typing.Optional[str], + ): + if ( + not interaction.guild_id or not interaction.user + ): # Should be caught by permissions but good check + await interaction.followup.send( + "Command error: Missing guild or user context.", ephemeral=True + ) return if isinstance(channel, discord.TextChannel) and post_title: - await interaction.followup.send("`post_title` is only for Forum Channels.", ephemeral=True) + await interaction.followup.send( + "`post_title` is only for Forum Channels.", ephemeral=True + ) return if isinstance(channel, discord.ForumChannel) and thread_target: - await interaction.followup.send("`thread_target` is only for Text Channels.", ephemeral=True) + await interaction.followup.send( + "`thread_target` is only for Text Channels.", ephemeral=True + ) return response_message = await self._create_new_subscription( - guild_id=interaction.guild_id, user_id=interaction.user.id, tags=tags, - target_channel=channel, requested_thread_target=thread_target, requested_post_title=post_title + guild_id=interaction.guild_id, + user_id=interaction.user.id, + tags=tags, + target_channel=channel, + requested_thread_target=thread_target, + requested_post_title=post_title, ) await interaction.followup.send(response_message, ephemeral=True) - async def _watch_request_logic(self, interaction: discord.Interaction, tags: str, forum_channel: discord.ForumChannel, post_title: typing.Optional[str]): + async def _watch_request_logic( + self, + interaction: discord.Interaction, + tags: str, + forum_channel: discord.ForumChannel, + post_title: typing.Optional[str], + ): if not interaction.guild_id or not interaction.user: - await interaction.followup.send("Command error: Missing guild or user context.", ephemeral=True) + await interaction.followup.send( + "Command error: Missing guild or user context.", ephemeral=True + ) return guild_id_str = str(interaction.guild_id) @@ -879,41 +1251,69 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet actual_post_title = post_title or f"{self.cog_name} Watch: {tags[:50]}" new_request = { - "request_id": request_id, "requester_id": str(interaction.user.id), "requester_name": str(interaction.user), - "requested_tags": tags.strip(), "target_forum_channel_id": str(forum_channel.id), - "requested_post_title": actual_post_title, "status": "pending", - "request_timestamp": discord.utils.utcnow().isoformat(), "moderator_id": None, "moderation_timestamp": None + "request_id": request_id, + "requester_id": str(interaction.user.id), + "requester_name": str(interaction.user), + "requested_tags": tags.strip(), + "target_forum_channel_id": str(forum_channel.id), + "requested_post_title": actual_post_title, + "status": "pending", + "request_timestamp": discord.utils.utcnow().isoformat(), + "moderator_id": None, + "moderation_timestamp": None, } - if guild_id_str not in self.pending_requests_data: self.pending_requests_data[guild_id_str] = [] + if guild_id_str not in self.pending_requests_data: + self.pending_requests_data[guild_id_str] = [] self.pending_requests_data[guild_id_str].append(new_request) self._save_pending_requests() - log.info(f"New {self.cog_name} watch request: Guild {guild_id_str}, Requester {interaction.user}, Tags '{tags}', ReqID {request_id}") + log.info( + f"New {self.cog_name} watch request: Guild {guild_id_str}, Requester {interaction.user}, Tags '{tags}', ReqID {request_id}" + ) await interaction.followup.send( - f"✅ Your {self.cog_name} watch request for tags `{tags}` in forum {forum_channel.mention} (title: \"{actual_post_title}\") submitted.\n" + f'✅ Your {self.cog_name} watch request for tags `{tags}` in forum {forum_channel.mention} (title: "{actual_post_title}") submitted.\n' f"Request ID: `{request_id}`. Awaiting moderator approval." ) async def _watch_pending_list_logic(self, interaction: discord.Interaction): - if not interaction.guild_id or not interaction.guild: # Ensure guild object for name - await interaction.response.send_message("Command error: Missing guild context.", ephemeral=True) + if ( + not interaction.guild_id or not interaction.guild + ): # Ensure guild object for name + await interaction.response.send_message( + "Command error: Missing guild context.", ephemeral=True + ) return guild_id_str = str(interaction.guild_id) - pending_reqs = [req for req in self.pending_requests_data.get(guild_id_str, []) if req.get("status") == "pending"] + pending_reqs = [ + req + for req in self.pending_requests_data.get(guild_id_str, []) + if req.get("status") == "pending" + ] if not pending_reqs: - await interaction.response.send_message(f"No pending {self.cog_name} watch requests.", ephemeral=True) + await interaction.response.send_message( + f"No pending {self.cog_name} watch requests.", ephemeral=True + ) return - embed = discord.Embed(title=f"Pending {self.cog_name} Watch Requests for {interaction.guild.name}", color=discord.Color.orange()) + embed = discord.Embed( + title=f"Pending {self.cog_name} Watch Requests for {interaction.guild.name}", + color=discord.Color.orange(), + ) desc_parts = [] for req in pending_reqs: forum_mention = f"<#{req.get('target_forum_channel_id', 'Unknown')}>" - timestamp_str = req.get('request_timestamp') - time_fmt = discord.utils.format_dt(datetime.fromisoformat(timestamp_str), style='R') if timestamp_str else 'Unknown time' + timestamp_str = req.get("request_timestamp") + time_fmt = ( + discord.utils.format_dt( + datetime.fromisoformat(timestamp_str), style="R" + ) + if timestamp_str + else "Unknown time" + ) desc_parts.append( f"**ID:** `{req.get('request_id')}`\n" f" **Requester:** {req.get('requester_name', 'Unknown')} (`{req.get('requester_id')}`)\n" @@ -923,106 +1323,186 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet f" **Requested:** {time_fmt}\n---" ) - embed.description = "\n".join(desc_parts)[:4096] # Limit description length + embed.description = "\n".join(desc_parts)[:4096] # Limit description length await interaction.response.send_message(embed=embed, ephemeral=True) - async def _watch_approve_request_logic(self, interaction: discord.Interaction, request_id: str): - if not interaction.guild_id or not interaction.user or not interaction.guild: # Need guild for fetch_channel and name - await interaction.followup.send("Command error: Missing context.", ephemeral=True) + async def _watch_approve_request_logic( + self, interaction: discord.Interaction, request_id: str + ): + if ( + not interaction.guild_id or not interaction.user or not interaction.guild + ): # Need guild for fetch_channel and name + await interaction.followup.send( + "Command error: Missing context.", ephemeral=True + ) return guild_id_str = str(interaction.guild_id) req_to_approve, req_idx = None, -1 for i, r in enumerate(self.pending_requests_data.get(guild_id_str, [])): if r.get("request_id") == request_id and r.get("status") == "pending": - req_to_approve, req_idx = r, i; break + req_to_approve, req_idx = r, i + break if not req_to_approve: - await interaction.followup.send(f"❌ Pending {self.cog_name} request ID `{request_id}` not found.", ephemeral=True) + await interaction.followup.send( + f"❌ Pending {self.cog_name} request ID `{request_id}` not found.", + ephemeral=True, + ) return target_forum_channel: typing.Optional[discord.ForumChannel] = None try: - target_forum_channel = await interaction.guild.fetch_channel(int(req_to_approve["target_forum_channel_id"])) - if not isinstance(target_forum_channel, discord.ForumChannel): raise ValueError("Not a ForumChannel") + target_forum_channel = await interaction.guild.fetch_channel( + int(req_to_approve["target_forum_channel_id"]) + ) + if not isinstance(target_forum_channel, discord.ForumChannel): + raise ValueError("Not a ForumChannel") except Exception: - await interaction.followup.send(f"❌ Target forum channel for {self.cog_name} request `{request_id}` not found/invalid.", ephemeral=True) + await interaction.followup.send( + f"❌ Target forum channel for {self.cog_name} request `{request_id}` not found/invalid.", + ephemeral=True, + ) return creation_response = await self._create_new_subscription( - guild_id=interaction.guild_id, user_id=int(req_to_approve["requester_id"]), tags=req_to_approve["requested_tags"], - target_channel=target_forum_channel, requested_post_title=req_to_approve["requested_post_title"], - is_request_approval=True, requester_mention=f"<@{req_to_approve['requester_id']}>" + guild_id=interaction.guild_id, + user_id=int(req_to_approve["requester_id"]), + tags=req_to_approve["requested_tags"], + target_channel=target_forum_channel, + requested_post_title=req_to_approve["requested_post_title"], + is_request_approval=True, + requester_mention=f"<@{req_to_approve['requester_id']}>", ) if creation_response.startswith("✅"): - req_to_approve.update({"status": "approved", "moderator_id": str(interaction.user.id), "moderation_timestamp": discord.utils.utcnow().isoformat()}) + req_to_approve.update( + { + "status": "approved", + "moderator_id": str(interaction.user.id), + "moderation_timestamp": discord.utils.utcnow().isoformat(), + } + ) self.pending_requests_data[guild_id_str][req_idx] = req_to_approve self._save_pending_requests() - await interaction.followup.send(f"✅ {self.cog_name} Request ID `{request_id}` approved. {creation_response}", ephemeral=True) + await interaction.followup.send( + f"✅ {self.cog_name} Request ID `{request_id}` approved. {creation_response}", + ephemeral=True, + ) try: - requester = await self.bot.fetch_user(int(req_to_approve["requester_id"])) - await requester.send(f"🎉 Your {self.cog_name} watch request (`{request_id}`) for tags `{req_to_approve['requested_tags']}` in server `{interaction.guild.name}` was **approved** by {interaction.user.mention}!\nDetails: {creation_response}") - except Exception as e: log.error(f"Failed to notify {self.cog_name} requester {req_to_approve['requester_id']} of approval: {e}") + requester = await self.bot.fetch_user( + int(req_to_approve["requester_id"]) + ) + await requester.send( + f"🎉 Your {self.cog_name} watch request (`{request_id}`) for tags `{req_to_approve['requested_tags']}` in server `{interaction.guild.name}` was **approved** by {interaction.user.mention}!\nDetails: {creation_response}" + ) + except Exception as e: + log.error( + f"Failed to notify {self.cog_name} requester {req_to_approve['requester_id']} of approval: {e}" + ) else: - await interaction.followup.send(f"❌ Failed to approve {self.cog_name} request `{request_id}`. Sub creation failed: {creation_response}", ephemeral=True) + await interaction.followup.send( + f"❌ Failed to approve {self.cog_name} request `{request_id}`. Sub creation failed: {creation_response}", + ephemeral=True, + ) - async def _watch_reject_request_logic(self, interaction: discord.Interaction, request_id: str, reason: typing.Optional[str]): - if not interaction.guild_id or not interaction.user or not interaction.guild: # Need guild for name - await interaction.followup.send("Command error: Missing context.", ephemeral=True) + async def _watch_reject_request_logic( + self, + interaction: discord.Interaction, + request_id: str, + reason: typing.Optional[str], + ): + if ( + not interaction.guild_id or not interaction.user or not interaction.guild + ): # Need guild for name + await interaction.followup.send( + "Command error: Missing context.", ephemeral=True + ) return guild_id_str = str(interaction.guild_id) req_to_reject, req_idx = None, -1 for i, r in enumerate(self.pending_requests_data.get(guild_id_str, [])): if r.get("request_id") == request_id and r.get("status") == "pending": - req_to_reject, req_idx = r, i; break + req_to_reject, req_idx = r, i + break if not req_to_reject: - await interaction.followup.send(f"❌ Pending {self.cog_name} request ID `{request_id}` not found.", ephemeral=True) + await interaction.followup.send( + f"❌ Pending {self.cog_name} request ID `{request_id}` not found.", + ephemeral=True, + ) return - req_to_reject.update({"status": "rejected", "moderator_id": str(interaction.user.id), "moderation_timestamp": discord.utils.utcnow().isoformat(), "rejection_reason": reason}) + req_to_reject.update( + { + "status": "rejected", + "moderator_id": str(interaction.user.id), + "moderation_timestamp": discord.utils.utcnow().isoformat(), + "rejection_reason": reason, + } + ) self.pending_requests_data[guild_id_str][req_idx] = req_to_reject self._save_pending_requests() - await interaction.followup.send(f"🗑️ {self.cog_name} Request ID `{request_id}` rejected.", ephemeral=True) + await interaction.followup.send( + f"🗑️ {self.cog_name} Request ID `{request_id}` rejected.", ephemeral=True + ) try: requester = await self.bot.fetch_user(int(req_to_reject["requester_id"])) msg = f"😥 Your {self.cog_name} watch request (`{request_id}`) for tags `{req_to_reject['requested_tags']}` in server `{interaction.guild.name}` was **rejected** by {interaction.user.mention}." - if reason: msg += f"\nReason: {reason}" + if reason: + msg += f"\nReason: {reason}" await requester.send(msg) - except Exception as e: log.error(f"Failed to notify {self.cog_name} requester {req_to_reject['requester_id']} of rejection: {e}") + except Exception as e: + log.error( + f"Failed to notify {self.cog_name} requester {req_to_reject['requester_id']} of rejection: {e}" + ) async def _watch_list_logic(self, interaction: discord.Interaction): if not interaction.guild_id or not interaction.guild: - await interaction.response.send_message("Command error: Missing guild context.", ephemeral=True) + await interaction.response.send_message( + "Command error: Missing guild context.", ephemeral=True + ) return guild_id_str = str(interaction.guild_id) guild_subs = self.subscriptions_data.get(guild_id_str, []) if not guild_subs: - await interaction.response.send_message(f"No active {self.cog_name} tag watches for this server.", ephemeral=True) + await interaction.response.send_message( + f"No active {self.cog_name} tag watches for this server.", + ephemeral=True, + ) return - embed = discord.Embed(title=f"Active {self.cog_name} Tag Watches for {interaction.guild.name}", color=discord.Color.blue()) + embed = discord.Embed( + title=f"Active {self.cog_name} Tag Watches for {interaction.guild.name}", + color=discord.Color.blue(), + ) desc_parts = [] for sub in guild_subs: target_location = "Unknown Target" # Simplified location fetching for brevity, original logic was more robust - if sub.get('forum_channel_id') and sub.get('target_post_id'): + if sub.get("forum_channel_id") and sub.get("target_post_id"): target_location = f"Forum Post <#{sub.get('target_post_id')}> in Forum <#{sub.get('forum_channel_id')}>" - elif sub.get('channel_id'): + elif sub.get("channel_id"): target_location = f"<#{sub.get('channel_id')}>" - if sub.get('thread_id'): target_location += f" (Thread <#{sub.get('thread_id')}>)" + if sub.get("thread_id"): + target_location += f" (Thread <#{sub.get('thread_id')}>)" - desc_parts.append(f"**ID:** `{sub.get('subscription_id')}`\n **Tags:** `{sub.get('tags')}`\n **Target:** {target_location}\n **Last Sent ID:** `{sub.get('last_known_post_id', 'N/A')}`\n---") + desc_parts.append( + f"**ID:** `{sub.get('subscription_id')}`\n **Tags:** `{sub.get('tags')}`\n **Target:** {target_location}\n **Last Sent ID:** `{sub.get('last_known_post_id', 'N/A')}`\n---" + ) embed.description = "\n".join(desc_parts)[:4096] await interaction.response.send_message(embed=embed, ephemeral=True) - async def _watch_remove_logic(self, interaction: discord.Interaction, subscription_id: str): + async def _watch_remove_logic( + self, interaction: discord.Interaction, subscription_id: str + ): if not interaction.guild_id: - await interaction.response.send_message("Command error: Missing guild context.", ephemeral=True) + await interaction.response.send_message( + "Command error: Missing guild context.", ephemeral=True + ) return guild_id_str = str(interaction.guild_id) @@ -1031,24 +1511,39 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet new_subs_list = [] for sub_entry in guild_subs: if sub_entry.get("subscription_id") == subscription_id: - removed_info = f"tags `{sub_entry.get('tags')}`" # Simplified + removed_info = f"tags `{sub_entry.get('tags')}`" # Simplified found = True - log.info(f"Removing {self.cog_name} watch: Guild {guild_id_str}, Sub ID {subscription_id}") - else: new_subs_list.append(sub_entry) + log.info( + f"Removing {self.cog_name} watch: Guild {guild_id_str}, Sub ID {subscription_id}" + ) + else: + new_subs_list.append(sub_entry) if not found: - await interaction.response.send_message(f"❌ {self.cog_name} subscription ID `{subscription_id}` not found.", ephemeral=True) + await interaction.response.send_message( + f"❌ {self.cog_name} subscription ID `{subscription_id}` not found.", + ephemeral=True, + ) return - if not new_subs_list: del self.subscriptions_data[guild_id_str] - else: self.subscriptions_data[guild_id_str] = new_subs_list + if not new_subs_list: + del self.subscriptions_data[guild_id_str] + else: + self.subscriptions_data[guild_id_str] = new_subs_list self._save_subscriptions() - await interaction.response.send_message(f"✅ Removed {self.cog_name} watch for {removed_info} (ID: `{subscription_id}`).", ephemeral=True) + await interaction.response.send_message( + f"✅ Removed {self.cog_name} watch for {removed_info} (ID: `{subscription_id}`).", + ephemeral=True, + ) - async def _watch_test_message_logic(self, interaction: discord.Interaction, subscription_id: str): + async def _watch_test_message_logic( + self, interaction: discord.Interaction, subscription_id: str + ): """Sends a test new-post message for a given subscription.""" if not interaction.guild_id: - await interaction.response.send_message("Command error: Missing guild context.", ephemeral=True) + await interaction.response.send_message( + "Command error: Missing guild context.", ephemeral=True + ) return guild_id_str = str(interaction.guild_id) @@ -1075,7 +1570,9 @@ class GelbooruWatcherBaseCog(commands.Cog, abc.ABC, metaclass=GelbooruWatcherMet ) return - fetched = await self._fetch_posts_logic("internal_test_msg", tags, pid_override=0, limit_override=1) + fetched = await self._fetch_posts_logic( + "internal_test_msg", tags, pid_override=0, limit_override=1 + ) file_url = None post_id = None if isinstance(fetched, list) and fetched: diff --git a/cogs/gif_optimizer_cog.py b/cogs/gif_optimizer_cog.py index 881f6d3..352e3ed 100644 --- a/cogs/gif_optimizer_cog.py +++ b/cogs/gif_optimizer_cog.py @@ -7,11 +7,19 @@ import io import tempfile import traceback + class GIFOptimizerCog(commands.Cog): def __init__(self, bot): self.bot = bot - async def _optimize_gif_internal(self, input_bytes: bytes, colors: int, dither_on: bool, crop_box_str: str, resize_dimensions_str: str): + async def _optimize_gif_internal( + self, + input_bytes: bytes, + colors: int, + dither_on: bool, + crop_box_str: str, + resize_dimensions_str: str, + ): """ Internal function to optimize a GIF from bytes, returning optimized bytes. Handles file I/O using temporary files. @@ -20,38 +28,48 @@ class GIFOptimizerCog(commands.Cog): output_path = None try: # Create a temporary input file - with tempfile.NamedTemporaryFile(delete=False, suffix=".gif") as temp_input_file: + with tempfile.NamedTemporaryFile( + delete=False, suffix=".gif" + ) as temp_input_file: temp_input_file.write(input_bytes) input_path = temp_input_file.name # Create a temporary output file - with tempfile.NamedTemporaryFile(delete=False, suffix=".gif") as temp_output_file: + with tempfile.NamedTemporaryFile( + delete=False, suffix=".gif" + ) as temp_output_file: output_path = temp_output_file.name # Parse crop and resize arguments crop_box_tuple = None if crop_box_str: try: - parts = [int(p.strip()) for p in crop_box_str.split(',')] + parts = [int(p.strip()) for p in crop_box_str.split(",")] if len(parts) == 4: crop_box_tuple = tuple(parts) else: - raise ValueError("Crop argument must be four integers: left,top,right,bottom (e.g., '10,20,100,150')") + raise ValueError( + "Crop argument must be four integers: left,top,right,bottom (e.g., '10,20,100,150')" + ) except ValueError as e: return None, f"Invalid crop format: {e}" resize_dims_tuple = None if resize_dimensions_str: try: - parts_str = resize_dimensions_str.replace('x', ',').split(',') + parts_str = resize_dimensions_str.replace("x", ",").split(",") parts = [int(p.strip()) for p in parts_str] if len(parts) == 2: if parts[0] > 0 and parts[1] > 0: resize_dims_tuple = tuple(parts) else: - raise ValueError("Resize dimensions (width, height) must be positive integers.") + raise ValueError( + "Resize dimensions (width, height) must be positive integers." + ) else: - raise ValueError("Resize argument must be two positive integers: width,height (e.g., '128,128' or '128x128')") + raise ValueError( + "Resize argument must be two positive integers: width,height (e.g., '128,128' or '128x128')" + ) except ValueError as e: return None, f"Invalid resize format: {e}" @@ -61,16 +79,16 @@ class GIFOptimizerCog(commands.Cog): # --- Original optimize_gif logic, adapted --- img = Image.open(input_path) - original_loop = img.info.get('loop', 0) - original_transparency = img.info.get('transparency') + original_loop = img.info.get("loop", 0) + original_transparency = img.info.get("transparency") processed_frames = [] durations = [] disposals = [] for i, frame_image in enumerate(ImageSequence.Iterator(img)): - durations.append(frame_image.info.get('duration', 100)) - disposals.append(frame_image.info.get('disposal', 2)) + durations.append(frame_image.info.get("duration", 100)) + disposals.append(frame_image.info.get("disposal", 2)) current_frame = frame_image.copy() @@ -79,19 +97,30 @@ class GIFOptimizerCog(commands.Cog): current_frame = current_frame.crop(crop_box_tuple) except Exception as crop_error: # Log warning, but don't fail the entire process for a single frame crop error - print(f"Warning: Could not apply crop box {crop_box_tuple} to frame {i+1}. Error: {crop_error}") + print( + f"Warning: Could not apply crop box {crop_box_tuple} to frame {i+1}. Error: {crop_error}" + ) # Optionally, you could decide to skip this frame or use the original frame if cropping is critical. # For now, we'll let the error propagate if it's a critical image error, otherwise proceed. if resize_dims_tuple: try: # Use Image.LANCZOS directly - current_frame = current_frame.resize(resize_dims_tuple, Image.LANCZOS) + current_frame = current_frame.resize( + resize_dims_tuple, Image.LANCZOS + ) except Exception as resize_error: - print(f"Warning: Could not resize frame {i+1} to {resize_dims_tuple}. Error: {resize_error}") + print( + f"Warning: Could not resize frame {i+1} to {resize_dims_tuple}. Error: {resize_error}" + ) - frame_rgba = current_frame.convert('RGBA') - quantized_frame = frame_rgba.convert('P', palette=Image.Palette.ADAPTIVE, colors=num_colors, dither=dither_method) + frame_rgba = current_frame.convert("RGBA") + quantized_frame = frame_rgba.convert( + "P", + palette=Image.Palette.ADAPTIVE, + colors=num_colors, + dither=dither_method, + ) processed_frames.append(quantized_frame) if not processed_frames: @@ -105,15 +134,17 @@ class GIFOptimizerCog(commands.Cog): duration=durations, loop=original_loop, disposal=disposals, - transparency=original_transparency + transparency=original_transparency, ) - with open(output_path, 'rb') as f: + with open(output_path, "rb") as f: optimized_bytes = f.read() input_size = len(input_bytes) output_size = len(optimized_bytes) - reduction_percentage = (input_size - output_size) / input_size * 100 if input_size > 0 else 0 + reduction_percentage = ( + (input_size - output_size) / input_size * 100 if input_size > 0 else 0 + ) stats = ( f"Original size: {input_size / 1024:.2f} KB\n" @@ -125,9 +156,12 @@ class GIFOptimizerCog(commands.Cog): except FileNotFoundError: return None, "Internal error: Temporary file not found." except UnidentifiedImageError: - return None, "Cannot identify image file. It might be corrupted or not a supported GIF format." + return ( + None, + "Cannot identify image file. It might be corrupted or not a supported GIF format.", + ) except Exception as e: - traceback.print_exc() # Print full traceback to console for debugging + traceback.print_exc() # Print full traceback to console for debugging return None, f"An unexpected error occurred during GIF optimization: {e}" finally: # Clean up temporary files @@ -137,18 +171,31 @@ class GIFOptimizerCog(commands.Cog): os.remove(output_path) @commands.command(name="optimizegif", description="Optimizes a GIF attachment.") - async def optimize_gif_prefix(self, ctx: commands.Context, attachment: discord.Attachment, colors: int = 128, dither: bool = True, crop: str = None, resize: str = None): - if not attachment.filename.lower().endswith('.gif'): + async def optimize_gif_prefix( + self, + ctx: commands.Context, + attachment: discord.Attachment, + colors: int = 128, + dither: bool = True, + crop: str = None, + resize: str = None, + ): + if not attachment.filename.lower().endswith(".gif"): await ctx.send("Please provide a GIF file.") return await ctx.defer() try: input_bytes = await attachment.read() - optimized_bytes, stats = await self._optimize_gif_internal(input_bytes, colors, dither, crop, resize) + optimized_bytes, stats = await self._optimize_gif_internal( + input_bytes, colors, dither, crop, resize + ) if optimized_bytes: - file = discord.File(io.BytesIO(optimized_bytes), filename=f"optimized_{attachment.filename}") + file = discord.File( + io.BytesIO(optimized_bytes), + filename=f"optimized_{attachment.filename}", + ) await ctx.send(f"GIF optimized successfully!\n{stats}", file=file) else: await ctx.send(f"Failed to optimize GIF: {stats}") @@ -162,26 +209,44 @@ class GIFOptimizerCog(commands.Cog): colors="Number of colors to reduce to (e.g., 256, 128, 64). Max 256.", dither="Enable Floyd-Steinberg dithering (improves quality, slightly slower).", crop="Crop the GIF. Provide as 'left,top,right,bottom' (e.g., '10,20,100,150').", - resize="Resize the GIF. Provide as 'width,height' (e.g., '128,128' or '128x128')." + resize="Resize the GIF. Provide as 'width,height' (e.g., '128,128' or '128x128').", ) - async def optimize_gif_slash(self, interaction: discord.Interaction, attachment: discord.Attachment, colors: app_commands.Range[int, 2, 256] = 128, dither: bool = True, crop: str = None, resize: str = None): - if not attachment.filename.lower().endswith('.gif'): - await interaction.response.send_message("Please provide a GIF file.", ephemeral=True) + async def optimize_gif_slash( + self, + interaction: discord.Interaction, + attachment: discord.Attachment, + colors: app_commands.Range[int, 2, 256] = 128, + dither: bool = True, + crop: str = None, + resize: str = None, + ): + if not attachment.filename.lower().endswith(".gif"): + await interaction.response.send_message( + "Please provide a GIF file.", ephemeral=True + ) return await interaction.response.defer() try: input_bytes = await attachment.read() - optimized_bytes, stats = await self._optimize_gif_internal(input_bytes, colors, dither, crop, resize) + optimized_bytes, stats = await self._optimize_gif_internal( + input_bytes, colors, dither, crop, resize + ) if optimized_bytes: - file = discord.File(io.BytesIO(optimized_bytes), filename=f"optimized_{attachment.filename}") - await interaction.followup.send(f"GIF optimized successfully!\n{stats}", file=file) + file = discord.File( + io.BytesIO(optimized_bytes), + filename=f"optimized_{attachment.filename}", + ) + await interaction.followup.send( + f"GIF optimized successfully!\n{stats}", file=file + ) else: await interaction.followup.send(f"Failed to optimize GIF: {stats}") except Exception as e: await interaction.followup.send(f"An error occurred: {e}") traceback.print_exc() + async def setup(bot): await bot.add_cog(GIFOptimizerCog(bot)) diff --git a/cogs/git_monitor_cog.py b/cogs/git_monitor_cog.py index be8da50..9023766 100644 --- a/cogs/git_monitor_cog.py +++ b/cogs/git_monitor_cog.py @@ -4,30 +4,35 @@ from discord import app_commands import logging import re import secrets -import datetime # Added for timezone.utc +import datetime # Added for timezone.utc from typing import Literal, Optional, List, Dict, Any -import asyncio # For sleep -import aiohttp # For API calls -import requests.utils # For url encoding gitlab project path +import asyncio # For sleep +import aiohttp # For API calls +import requests.utils # For url encoding gitlab project path # Assuming settings_manager is in the parent directory # Adjust the import path if your project structure is different try: - from .. import settings_manager # If cogs is a package + from .. import settings_manager # If cogs is a package except ImportError: - import settings_manager # If run from the root or cogs is not a package + import settings_manager # If run from the root or cogs is not a package log = logging.getLogger(__name__) + # Helper to parse repo URL and determine platform def parse_repo_url(url: str) -> tuple[Optional[str], Optional[str]]: """Parses a Git repository URL to extract platform and a simplified repo identifier.""" # Fixed regex pattern for GitHub URLs - github_match = re.match(r"^(?:https?://)?(?:www\.)?github\.com/([\w.-]+/[\w.-]+)(?:\.git)?/?$", url) + github_match = re.match( + r"^(?:https?://)?(?:www\.)?github\.com/([\w.-]+/[\w.-]+)(?:\.git)?/?$", url + ) if github_match: return "github", github_match.group(1) - gitlab_match = re.match(r"^(?:https?://)?(?:www\.)?gitlab\.com/([\w.-]+(?:/[\w.-]+)+)(?:\.git)?/?$", url) + gitlab_match = re.match( + r"^(?:https?://)?(?:www\.)?gitlab\.com/([\w.-]+(?:/[\w.-]+)+)(?:\.git)?/?$", url + ) if gitlab_match: return "gitlab", gitlab_match.group(1) return None, None @@ -43,7 +48,7 @@ class GitMonitorCog(commands.Cog): self.poll_repositories_task.cancel() log.info("GitMonitorCog unloaded and polling task cancelled.") - @tasks.loop(minutes=5.0) # Default, can be adjusted or made dynamic later + @tasks.loop(minutes=5.0) # Default, can be adjusted or made dynamic later async def poll_repositories_task(self): log.debug("Git repository polling task running...") try: @@ -55,169 +60,265 @@ class GitMonitorCog(commands.Cog): log.info(f"Found {len(repos_to_poll)} repositories to poll.") for repo_config in repos_to_poll: - repo_id = repo_config['id'] - guild_id = repo_config['guild_id'] - repo_url = repo_config['repository_url'] - platform = repo_config['platform'] - channel_id = repo_config['notification_channel_id'] - target_branch = repo_config['target_branch'] # Get the target branch - last_sha = repo_config['last_polled_commit_sha'] + repo_id = repo_config["id"] + guild_id = repo_config["guild_id"] + repo_url = repo_config["repository_url"] + platform = repo_config["platform"] + channel_id = repo_config["notification_channel_id"] + target_branch = repo_config["target_branch"] # Get the target branch + last_sha = repo_config["last_polled_commit_sha"] # polling_interval = repo_config['polling_interval_minutes'] # Use this if intervals are dynamic per repo - log.debug(f"Polling {platform} repo: {repo_url} (Branch: {target_branch or 'default'}) (ID: {repo_id}) for guild {guild_id}") + log.debug( + f"Polling {platform} repo: {repo_url} (Branch: {target_branch or 'default'}) (ID: {repo_id}) for guild {guild_id}" + ) new_commits_data: List[Dict[str, Any]] = [] latest_fetched_sha = last_sha try: - async with aiohttp.ClientSession(headers={"User-Agent": "DiscordBot/1.0"}) as session: + async with aiohttp.ClientSession( + headers={"User-Agent": "DiscordBot/1.0"} + ) as session: if platform == "github": # GitHub API: GET /repos/{owner}/{repo}/commits # We need to parse owner/repo from repo_url - _, owner_repo_path = parse_repo_url(repo_url) # e.g. "user/repo" + _, owner_repo_path = parse_repo_url( + repo_url + ) # e.g. "user/repo" if owner_repo_path: api_url = f"https://api.github.com/repos/{owner_repo_path}/commits" - params = {"per_page": 10} # Fetch up to 10 recent commits + params = { + "per_page": 10 + } # Fetch up to 10 recent commits if target_branch: - params["sha"] = target_branch # GitHub uses 'sha' for branch/tag/commit SHA + params["sha"] = ( + target_branch # GitHub uses 'sha' for branch/tag/commit SHA + ) # No 'since_sha' for GitHub commits list. Manual filtering after fetch. - async with session.get(api_url, params=params) as response: + async with session.get( + api_url, params=params + ) as response: if response.status == 200: commits_payload = await response.json() temp_new_commits = [] - for commit_item in reversed(commits_payload): # Process oldest first - if commit_item['sha'] == last_sha: - temp_new_commits = [] # Clear previous if we found the last one + for commit_item in reversed( + commits_payload + ): # Process oldest first + if commit_item["sha"] == last_sha: + temp_new_commits = ( + [] + ) # Clear previous if we found the last one continue temp_new_commits.append(commit_item) if temp_new_commits: new_commits_data = temp_new_commits - latest_fetched_sha = new_commits_data[-1]['sha'] - elif response.status == 403: # Rate limit - log.warning(f"GitHub API rate limit hit for {repo_url}. Headers: {response.headers}") + latest_fetched_sha = new_commits_data[-1][ + "sha" + ] + elif response.status == 403: # Rate limit + log.warning( + f"GitHub API rate limit hit for {repo_url}. Headers: {response.headers}" + ) # Consider increasing loop wait time or specific backoff for this repo elif response.status == 404: - log.error(f"Repository {repo_url} not found on GitHub (404). Consider removing or marking as invalid.") + log.error( + f"Repository {repo_url} not found on GitHub (404). Consider removing or marking as invalid." + ) else: - log.error(f"Error fetching GitHub commits for {repo_url}: {response.status} - {await response.text()}") + log.error( + f"Error fetching GitHub commits for {repo_url}: {response.status} - {await response.text()}" + ) elif platform == "gitlab": # GitLab API: GET /projects/{id}/repository/commits # We need project ID or URL-encoded path. - _, project_path = parse_repo_url(repo_url) # e.g. "group/subgroup/project" + _, project_path = parse_repo_url( + repo_url + ) # e.g. "group/subgroup/project" if project_path: - encoded_project_path = requests.utils.quote(project_path, safe='') + encoded_project_path = requests.utils.quote( + project_path, safe="" + ) api_url = f"https://gitlab.com/api/v4/projects/{encoded_project_path}/repository/commits" params = {"per_page": 10} if target_branch: - params["ref_name"] = target_branch # GitLab uses 'ref_name' for branch/tag + params["ref_name"] = ( + target_branch # GitLab uses 'ref_name' for branch/tag + ) # No 'since_sha' for GitLab. Manual filtering. - async with session.get(api_url, params=params) as response: + async with session.get( + api_url, params=params + ) as response: if response.status == 200: commits_payload = await response.json() temp_new_commits = [] for commit_item in reversed(commits_payload): - if commit_item['id'] == last_sha: + if commit_item["id"] == last_sha: temp_new_commits = [] continue temp_new_commits.append(commit_item) if temp_new_commits: new_commits_data = temp_new_commits - latest_fetched_sha = new_commits_data[-1]['id'] + latest_fetched_sha = new_commits_data[-1][ + "id" + ] elif response.status == 403: - log.warning(f"GitLab API rate limit hit for {repo_url}. Headers: {response.headers}") + log.warning( + f"GitLab API rate limit hit for {repo_url}. Headers: {response.headers}" + ) elif response.status == 404: - log.error(f"Repository {repo_url} not found on GitLab (404).") + log.error( + f"Repository {repo_url} not found on GitLab (404)." + ) else: - log.error(f"Error fetching GitLab commits for {repo_url}: {response.status} - {await response.text()}") + log.error( + f"Error fetching GitLab commits for {repo_url}: {response.status} - {await response.text()}" + ) except aiohttp.ClientError as ce: log.error(f"AIOHTTP client error polling {repo_url}: {ce}") except Exception as ex: log.exception(f"Generic error polling {repo_url}: {ex}") - if new_commits_data: channel = self.bot.get_channel(channel_id) if channel: for commit_item_data in new_commits_data: embed = None if platform == "github": - commit_sha = commit_item_data.get('sha', 'N/A') + commit_sha = commit_item_data.get("sha", "N/A") commit_id_short = commit_sha[:7] - commit_data = commit_item_data.get('commit', {}) - commit_msg = commit_data.get('message', 'No message.') - commit_url = commit_item_data.get('html_url', '#') - author_info = commit_data.get('author', {}) # Committer info is also available - author_name = author_info.get('name', 'Unknown Author') + commit_data = commit_item_data.get("commit", {}) + commit_msg = commit_data.get("message", "No message.") + commit_url = commit_item_data.get("html_url", "#") + author_info = commit_data.get( + "author", {} + ) # Committer info is also available + author_name = author_info.get("name", "Unknown Author") # Branch information is not directly available in this specific commit object from /commits endpoint. # It's part of the push event or needs to be inferred/fetched differently for polling. # For polling, we typically monitor a specific branch, or assume default. # Verification status - verification = commit_data.get('verification', {}) - verified_status = "Verified" if verification.get('verified') else "Unverified" - if verification.get('reason') and verification.get('reason') != 'unsigned': - verified_status += f" ({verification.get('reason')})" + verification = commit_data.get("verification", {}) + verified_status = ( + "Verified" + if verification.get("verified") + else "Unverified" + ) + if ( + verification.get("reason") + and verification.get("reason") != "unsigned" + ): + verified_status += ( + f" ({verification.get('reason')})" + ) # Files changed and stats require another API call per commit: GET /repos/{owner}/{repo}/commits/{sha} # This is too API intensive for a simple polling loop. # We will omit detailed file stats for polled GitHub commits for now. - files_changed_str = "File stats not fetched for polled commits." + files_changed_str = ( + "File stats not fetched for polled commits." + ) embed = discord.Embed( title=f"New Commit in {repo_url}", - description=commit_msg.splitlines()[0], # First line + description=commit_msg.splitlines()[ + 0 + ], # First line color=discord.Color.blue(), - url=commit_url + url=commit_url, ) embed.set_author(name=author_name) - embed.add_field(name="Commit", value=f"[`{commit_id_short}`]({commit_url})", inline=True) - embed.add_field(name="Verification", value=verified_status, inline=True) + embed.add_field( + name="Commit", + value=f"[`{commit_id_short}`]({commit_url})", + inline=True, + ) + embed.add_field( + name="Verification", + value=verified_status, + inline=True, + ) # embed.add_field(name="Branch", value="default (polling)", inline=True) # Placeholder - embed.add_field(name="Changes", value=files_changed_str, inline=False) + embed.add_field( + name="Changes", + value=files_changed_str, + inline=False, + ) elif platform == "gitlab": - commit_id = commit_item_data.get('id', 'N/A') - commit_id_short = commit_item_data.get('short_id', commit_id[:7]) - commit_msg = commit_item_data.get('title', 'No message.') # GitLab uses 'title' for first line - commit_url = commit_item_data.get('web_url', '#') - author_name = commit_item_data.get('author_name', 'Unknown Author') + commit_id = commit_item_data.get("id", "N/A") + commit_id_short = commit_item_data.get( + "short_id", commit_id[:7] + ) + commit_msg = commit_item_data.get( + "title", "No message." + ) # GitLab uses 'title' for first line + commit_url = commit_item_data.get("web_url", "#") + author_name = commit_item_data.get( + "author_name", "Unknown Author" + ) # Branch information is not directly in this commit object from /commits. # It's part of the push event or needs to be inferred. # GitLab commit stats (added/deleted lines) are in the commit details, not list. - files_changed_str = "File stats not fetched for polled commits." + files_changed_str = ( + "File stats not fetched for polled commits." + ) embed = discord.Embed( title=f"New Commit in {repo_url}", description=commit_msg.splitlines()[0], color=discord.Color.orange(), - url=commit_url + url=commit_url, ) embed.set_author(name=author_name) - embed.add_field(name="Commit", value=f"[`{commit_id_short}`]({commit_url})", inline=True) + embed.add_field( + name="Commit", + value=f"[`{commit_id_short}`]({commit_url})", + inline=True, + ) # embed.add_field(name="Branch", value="default (polling)", inline=True) # Placeholder - embed.add_field(name="Changes", value=files_changed_str, inline=False) + embed.add_field( + name="Changes", + value=files_changed_str, + inline=False, + ) if embed: try: await channel.send(embed=embed) - log.info(f"Sent polled notification for commit {commit_id_short} in {repo_url} to channel {channel_id}") + log.info( + f"Sent polled notification for commit {commit_id_short} in {repo_url} to channel {channel_id}" + ) except discord.Forbidden: - log.error(f"Missing permissions to send message in channel {channel_id} for guild {guild_id}") + log.error( + f"Missing permissions to send message in channel {channel_id} for guild {guild_id}" + ) except discord.HTTPException as dhe: - log.error(f"Discord HTTP error sending message for {repo_url}: {dhe}") + log.error( + f"Discord HTTP error sending message for {repo_url}: {dhe}" + ) else: - log.warning(f"Channel {channel_id} not found for guild {guild_id} for repo {repo_url}") + log.warning( + f"Channel {channel_id} not found for guild {guild_id} for repo {repo_url}" + ) # Update polling status in DB - if latest_fetched_sha != last_sha or not new_commits_data : # Update if new sha or just to update timestamp - await settings_manager.update_repository_polling_status(repo_id, latest_fetched_sha, datetime.datetime.now(datetime.timezone.utc)) + if ( + latest_fetched_sha != last_sha or not new_commits_data + ): # Update if new sha or just to update timestamp + await settings_manager.update_repository_polling_status( + repo_id, + latest_fetched_sha, + datetime.datetime.now(datetime.timezone.utc), + ) # Small delay between processing each repo to be nice to APIs - await asyncio.sleep(2) # 2 seconds delay + await asyncio.sleep(2) # 2 seconds delay except Exception as e: log.exception("Error occurred during repository polling task:", exc_info=e) @@ -227,33 +328,47 @@ class GitMonitorCog(commands.Cog): await self.bot.wait_until_ready() log.info("Polling task is waiting for bot to be ready...") - gitlistener_group = app_commands.Group(name="gitlistener", description="Manage Git repository monitoring.") + gitlistener_group = app_commands.Group( + name="gitlistener", description="Manage Git repository monitoring." + ) - @gitlistener_group.command(name="add", description="Add a repository to monitor for commits.") + @gitlistener_group.command( + name="add", description="Add a repository to monitor for commits." + ) @app_commands.describe( repository_url="The full URL of the GitHub or GitLab repository (e.g., https://github.com/user/repo).", channel="The channel where commit notifications should be sent.", monitoring_method="Choose 'webhook' for real-time (requires repo admin rights) or 'poll' for periodic checks.", - branch="The specific branch to monitor (for 'poll' method, defaults to main/master if not specified)." + branch="The specific branch to monitor (for 'poll' method, defaults to main/master if not specified).", ) @app_commands.checks.has_permissions(manage_guild=True) - async def add_repository(self, interaction: discord.Interaction, - repository_url: str, - channel: discord.TextChannel, - monitoring_method: Literal['webhook', 'poll'], - branch: Optional[str] = None): + async def add_repository( + self, + interaction: discord.Interaction, + repository_url: str, + channel: discord.TextChannel, + monitoring_method: Literal["webhook", "poll"], + branch: Optional[str] = None, + ): await interaction.response.defer(ephemeral=True) - cleaned_repository_url = repository_url.strip() # Strip whitespace + cleaned_repository_url = repository_url.strip() # Strip whitespace - if monitoring_method == 'poll' and not branch: - log.info(f"Branch not specified for polling method for {cleaned_repository_url}. Will use default in polling task or API default.") + if monitoring_method == "poll" and not branch: + log.info( + f"Branch not specified for polling method for {cleaned_repository_url}. Will use default in polling task or API default." + ) # If branch is None, the polling task will attempt to use the repo's default branch. pass - platform, repo_identifier = parse_repo_url(cleaned_repository_url) # Use cleaned URL + platform, repo_identifier = parse_repo_url( + cleaned_repository_url + ) # Use cleaned URL if not platform or not repo_identifier: - await interaction.followup.send(f"Invalid repository URL: `{repository_url}`. Please provide a valid GitHub or GitLab URL (e.g., https://github.com/user/repo).", ephemeral=True) + await interaction.followup.send( + f"Invalid repository URL: `{repository_url}`. Please provide a valid GitHub or GitLab URL (e.g., https://github.com/user/repo).", + ephemeral=True, + ) return guild_id = interaction.guild_id @@ -261,29 +376,42 @@ class GitMonitorCog(commands.Cog): notification_channel_id = channel.id # Check if this exact repo and channel combination already exists - existing_config = await settings_manager.get_monitored_repository_by_url(guild_id, repository_url, notification_channel_id) + existing_config = await settings_manager.get_monitored_repository_by_url( + guild_id, repository_url, notification_channel_id + ) if existing_config: - await interaction.followup.send(f"This repository ({repository_url}) is already being monitored in {channel.mention}.", ephemeral=True) + await interaction.followup.send( + f"This repository ({repository_url}) is already being monitored in {channel.mention}.", + ephemeral=True, + ) return webhook_secret = None db_repo_id = None reply_message = "" - if monitoring_method == 'webhook': + if monitoring_method == "webhook": webhook_secret = secrets.token_hex(32) # The API server needs the bot's domain. This should be configured. # For now, we'll use a placeholder. # TODO: Fetch API base URL from config or bot instance - api_base_url = getattr(self.bot, 'config', {}).get('API_BASE_URL', 'slipstreamm.dev') - if api_base_url == 'YOUR_API_DOMAIN_HERE.com': - log.warning("API_BASE_URL not configured for webhook URL generation. Using placeholder.") - + api_base_url = getattr(self.bot, "config", {}).get( + "API_BASE_URL", "slipstreamm.dev" + ) + if api_base_url == "YOUR_API_DOMAIN_HERE.com": + log.warning( + "API_BASE_URL not configured for webhook URL generation. Using placeholder." + ) db_repo_id = await settings_manager.add_monitored_repository( - guild_id=guild_id, repository_url=cleaned_repository_url, platform=platform, # Use cleaned URL - monitoring_method='webhook', notification_channel_id=notification_channel_id, - added_by_user_id=added_by_user_id, webhook_secret=webhook_secret, target_branch=None # Branch not used for webhooks + guild_id=guild_id, + repository_url=cleaned_repository_url, + platform=platform, # Use cleaned URL + monitoring_method="webhook", + notification_channel_id=notification_channel_id, + added_by_user_id=added_by_user_id, + webhook_secret=webhook_secret, + target_branch=None, # Branch not used for webhooks ) if db_repo_id: payload_url = f"https://{api_base_url}/webhook/{platform}/{db_repo_id}" @@ -301,18 +429,24 @@ class GitMonitorCog(commands.Cog): else: reply_message = "Failed to add repository for webhook monitoring. It might already exist or there was a database error." - elif monitoring_method == 'poll': + elif monitoring_method == "poll": # For polling, we might want to fetch the latest commit SHA now to avoid initial old notifications # This is a placeholder; actual fetching needs platform-specific API calls - initial_sha = None # TODO: Implement initial SHA fetch if desired + initial_sha = None # TODO: Implement initial SHA fetch if desired db_repo_id = await settings_manager.add_monitored_repository( - guild_id=guild_id, repository_url=cleaned_repository_url, platform=platform, # Use cleaned URL - monitoring_method='poll', notification_channel_id=notification_channel_id, - added_by_user_id=added_by_user_id, target_branch=branch, # Pass the branch for polling - last_polled_commit_sha=initial_sha + guild_id=guild_id, + repository_url=cleaned_repository_url, + platform=platform, # Use cleaned URL + monitoring_method="poll", + notification_channel_id=notification_channel_id, + added_by_user_id=added_by_user_id, + target_branch=branch, # Pass the branch for polling + last_polled_commit_sha=initial_sha, ) if db_repo_id: - branch_info = f"on branch `{branch}`" if branch else "on the default branch" + branch_info = ( + f"on branch `{branch}`" if branch else "on the default branch" + ) reply_message = ( f"Polling monitoring for `{repo_identifier}` ({platform.capitalize()}) {branch_info} added for {channel.mention}.\n" f"The bot will check for new commits periodically (around every 5-15 minutes)." @@ -323,59 +457,91 @@ class GitMonitorCog(commands.Cog): if db_repo_id: await interaction.followup.send(reply_message, ephemeral=True) else: - await interaction.followup.send(reply_message or "An unexpected error occurred.", ephemeral=True) + await interaction.followup.send( + reply_message or "An unexpected error occurred.", ephemeral=True + ) - - @gitlistener_group.command(name="remove", description="Remove a repository from monitoring.") + @gitlistener_group.command( + name="remove", description="Remove a repository from monitoring." + ) @app_commands.describe( repository_url="The full URL of the repository to remove.", - channel="The channel it's sending notifications to." + channel="The channel it's sending notifications to.", ) @app_commands.checks.has_permissions(manage_guild=True) - async def remove_repository(self, interaction: discord.Interaction, repository_url: str, channel: discord.TextChannel): + async def remove_repository( + self, + interaction: discord.Interaction, + repository_url: str, + channel: discord.TextChannel, + ): await interaction.response.defer(ephemeral=True) guild_id = interaction.guild_id notification_channel_id = channel.id platform, repo_identifier = parse_repo_url(repository_url) - if not platform: # repo_identifier can be None if URL is valid but not parsable to simple form - await interaction.followup.send("Invalid repository URL provided.", ephemeral=True) + if ( + not platform + ): # repo_identifier can be None if URL is valid but not parsable to simple form + await interaction.followup.send( + "Invalid repository URL provided.", ephemeral=True + ) return - success = await settings_manager.remove_monitored_repository(guild_id, repository_url, notification_channel_id) + success = await settings_manager.remove_monitored_repository( + guild_id, repository_url, notification_channel_id + ) if success: await interaction.followup.send( f"Successfully removed monitoring for `{repository_url}` from {channel.mention}.\n" f"If this was a webhook, remember to also delete the webhook from the repository settings on {platform.capitalize()}.", - ephemeral=True + ephemeral=True, ) else: - await interaction.followup.send(f"Could not find a monitoring setup for `{repository_url}` in {channel.mention} to remove, or a database error occurred.", ephemeral=True) + await interaction.followup.send( + f"Could not find a monitoring setup for `{repository_url}` in {channel.mention} to remove, or a database error occurred.", + ephemeral=True, + ) - @gitlistener_group.command(name="list", description="List repositories currently being monitored in this server.") + @gitlistener_group.command( + name="list", + description="List repositories currently being monitored in this server.", + ) @app_commands.checks.has_permissions(manage_guild=True) async def list_repositories(self, interaction: discord.Interaction): await interaction.response.defer(ephemeral=True) guild_id = interaction.guild_id - monitored_repos = await settings_manager.list_monitored_repositories_for_guild(guild_id) + monitored_repos = await settings_manager.list_monitored_repositories_for_guild( + guild_id + ) if not monitored_repos: - await interaction.followup.send("No repositories are currently being monitored in this server.", ephemeral=True) + await interaction.followup.send( + "No repositories are currently being monitored in this server.", + ephemeral=True, + ) return - embed = discord.Embed(title=f"Monitored Repositories for {interaction.guild.name}", color=discord.Color.blue()) + embed = discord.Embed( + title=f"Monitored Repositories for {interaction.guild.name}", + color=discord.Color.blue(), + ) description_lines = [] for repo in monitored_repos: - channel = self.bot.get_channel(repo['notification_channel_id']) - channel_mention = channel.mention if channel else f"ID: {repo['notification_channel_id']}" - method = repo['monitoring_method'].capitalize() - platform = repo['platform'].capitalize() + channel = self.bot.get_channel(repo["notification_channel_id"]) + channel_mention = ( + channel.mention if channel else f"ID: {repo['notification_channel_id']}" + ) + method = repo["monitoring_method"].capitalize() + platform = repo["platform"].capitalize() # Attempt to get a cleaner repo name if possible - _, repo_name_simple = parse_repo_url(repo['repository_url']) - display_name = repo_name_simple if repo_name_simple else repo['repository_url'] + _, repo_name_simple = parse_repo_url(repo["repository_url"]) + display_name = ( + repo_name_simple if repo_name_simple else repo["repository_url"] + ) description_lines.append( f"**[{display_name}]({repo['repository_url']})**\n" @@ -386,17 +552,19 @@ class GitMonitorCog(commands.Cog): ) embed.description = "\n\n".join(description_lines) - if len(embed.description) > 4000 : # Discord embed description limit + if len(embed.description) > 4000: # Discord embed description limit embed.description = embed.description[:3990] + "\n... (list truncated)" - await interaction.followup.send(embed=embed, ephemeral=True) + async def setup(bot: commands.Bot): # Ensure settings_manager's pools are set if this cog is loaded after bot's setup_hook # This is more of a safeguard; ideally, pools are set before cogs are loaded. - if settings_manager and not getattr(settings_manager, '_active_pg_pool', None): - log.warning("GitMonitorCog: settings_manager pools might not be set. Attempting to ensure they are via bot instance.") + if settings_manager and not getattr(settings_manager, "_active_pg_pool", None): + log.warning( + "GitMonitorCog: settings_manager pools might not be set. Attempting to ensure they are via bot instance." + ) # This relies on bot having pg_pool and redis_pool attributes set by its setup_hook # settings_manager.set_bot_pools(getattr(bot, 'pg_pool', None), getattr(bot, 'redis_pool', None)) diff --git a/cogs/giveaways_cog.py b/cogs/giveaways_cog.py index 45f7d17..7d356ea 100644 --- a/cogs/giveaways_cog.py +++ b/cogs/giveaways_cog.py @@ -4,17 +4,20 @@ from discord import app_commands, ui import datetime import asyncio import random -import re # For parsing duration +import re # For parsing duration import json import os -import aiofiles # Import aiofiles +import aiofiles # Import aiofiles import aiofiles.os GIVEAWAY_DATA_FILE = "data/giveaways.json" DATA_DIR = "data" + # --- Helper Functions --- -async def is_user_nitro_like(user: discord.User | discord.Member, bot: commands.Bot = None) -> bool: +async def is_user_nitro_like( + user: discord.User | discord.Member, bot: commands.Bot = None +) -> bool: """Checks if a user has an animated avatar or a banner, indicating Nitro.""" # Fetch the full user object to get banner information if bot: @@ -23,28 +26,34 @@ async def is_user_nitro_like(user: discord.User | discord.Member, bot: commands. user = fetched_user except discord.NotFound: pass # Use the original user object if fetch fails - - if isinstance(user, discord.Member): # Member object has guild-specific avatar + + if isinstance(user, discord.Member): # Member object has guild-specific avatar # Check guild avatar first, then global avatar if user.guild_avatar and user.guild_avatar.is_animated(): return True if user.avatar and user.avatar.is_animated(): return True - elif user.avatar and user.avatar.is_animated(): # User object + elif user.avatar and user.avatar.is_animated(): # User object return True return user.banner is not None # --- UI Views and Buttons --- -class GiveawayEnterButton(ui.Button['GiveawayEnterView']): +class GiveawayEnterButton(ui.Button["GiveawayEnterView"]): def __init__(self, cog_ref): - super().__init__(label="Enter Giveaway", style=discord.ButtonStyle.green, custom_id="giveaway_enter_button") - self.cog: GiveawaysCog = cog_ref # Store a reference to the cog + super().__init__( + label="Enter Giveaway", + style=discord.ButtonStyle.green, + custom_id="giveaway_enter_button", + ) + self.cog: GiveawaysCog = cog_ref # Store a reference to the cog async def callback(self, interaction: discord.Interaction): giveaway = self.cog._get_giveaway_by_message_id(interaction.message.id) if not giveaway or giveaway.get("ended", False): - await interaction.response.send_message("This giveaway has ended or is no longer active.", ephemeral=True) + await interaction.response.send_message( + "This giveaway has ended or is no longer active.", ephemeral=True + ) # Optionally disable the button on the message if possible self.disabled = True await interaction.message.edit(view=self.view) @@ -54,17 +63,21 @@ class GiveawayEnterButton(ui.Button['GiveawayEnterView']): if not await is_user_nitro_like(interaction.user, bot=self.cog.bot): await interaction.response.send_message( "This is a Nitro-exclusive giveaway. You don't appear to have Nitro (animated avatar or banner).", - ephemeral=True + ephemeral=True, ) return if interaction.user.id in giveaway["participants"]: - await interaction.response.send_message("You have already entered this giveaway!", ephemeral=True) + await interaction.response.send_message( + "You have already entered this giveaway!", ephemeral=True + ) else: giveaway["participants"].add(interaction.user.id) - await self.cog.save_giveaways() # Save after participant update - await interaction.response.send_message("You have successfully entered the giveaway!", ephemeral=True) - + await self.cog.save_giveaways() # Save after participant update + await interaction.response.send_message( + "You have successfully entered the giveaway!", ephemeral=True + ) + # Update participant count in embed if desired (optional) # embed = interaction.message.embeds[0] # embed.set_field_at(embed.fields.index(...) or create new field, name="Participants", value=str(len(giveaway["participants"]))) @@ -72,14 +85,21 @@ class GiveawayEnterButton(ui.Button['GiveawayEnterView']): class GiveawayEnterView(ui.View): - def __init__(self, cog: 'GiveawaysCog', timeout=None): # Timeout=None for persistent + def __init__( + self, cog: "GiveawaysCog", timeout=None + ): # Timeout=None for persistent super().__init__(timeout=timeout) self.cog = cog self.add_item(GiveawayEnterButton(cog_ref=self.cog)) -class GiveawayRerollButton(ui.Button['GiveawayEndView']): + +class GiveawayRerollButton(ui.Button["GiveawayEndView"]): def __init__(self, cog_ref, original_giveaway_message_id: int): - super().__init__(label="Reroll Winner", style=discord.ButtonStyle.blurple, custom_id=f"giveaway_reroll_button:{original_giveaway_message_id}") + super().__init__( + label="Reroll Winner", + style=discord.ButtonStyle.blurple, + custom_id=f"giveaway_reroll_button:{original_giveaway_message_id}", + ) self.cog: GiveawaysCog = cog_ref self.original_giveaway_message_id = original_giveaway_message_id @@ -87,47 +107,61 @@ class GiveawayRerollButton(ui.Button['GiveawayEndView']): # For reroll, we need to find the *original* giveaway data, which might not be in active_giveaways anymore. # We'll need to load it from the JSON file or have a separate store for ended giveaways. # For simplicity now, let's assume we can find it or it's passed appropriately. - + # This custom_id parsing is a common pattern for persistent buttons with dynamic data. # msg_id_str = interaction.data["custom_id"].split(":")[1] # original_msg_id = int(msg_id_str) # Find the giveaway data (this might need adjustment based on how ended giveaways are stored) # We'll search all giveaways loaded, including those marked "ended" - giveaway_data = self.cog._get_giveaway_by_message_id(self.original_giveaway_message_id, search_all=True) + giveaway_data = self.cog._get_giveaway_by_message_id( + self.original_giveaway_message_id, search_all=True + ) if not giveaway_data: - await interaction.response.send_message("Could not find the data for this giveaway to reroll.", ephemeral=True) + await interaction.response.send_message( + "Could not find the data for this giveaway to reroll.", ephemeral=True + ) return if not interaction.user.guild_permissions.manage_guild: - await interaction.response.send_message("You don't have permission to reroll winners.", ephemeral=True) + await interaction.response.send_message( + "You don't have permission to reroll winners.", ephemeral=True + ) return - await interaction.response.defer(ephemeral=True) # Acknowledge + await interaction.response.defer(ephemeral=True) # Acknowledge participants_ids = list(giveaway_data.get("participants", [])) if not participants_ids: - await interaction.followup.send("There were no participants in this giveaway to reroll from.", ephemeral=True) + await interaction.followup.send( + "There were no participants in this giveaway to reroll from.", + ephemeral=True, + ) return # Fetch user objects for participants entrants_users = [] for user_id in participants_ids: user = interaction.guild.get_member(user_id) - if not user: # Try fetching if not in cache or left server + if not user: # Try fetching if not in cache or left server try: user = await self.cog.bot.fetch_user(user_id) except discord.NotFound: - continue # Skip if user cannot be found + continue # Skip if user cannot be found if user and not user.bot: - # Apply Nitro check again if it was a nitro giveaway - if giveaway_data.get("is_nitro_giveaway", False) and not is_user_nitro_like(user): + # Apply Nitro check again if it was a nitro giveaway + if giveaway_data.get( + "is_nitro_giveaway", False + ) and not is_user_nitro_like(user): continue entrants_users.append(user) - + if not entrants_users: - await interaction.followup.send("No eligible participants found for a reroll (e.g., after Nitro check or if users left).", ephemeral=True) + await interaction.followup.send( + "No eligible participants found for a reroll (e.g., after Nitro check or if users left).", + ephemeral=True, + ) return num_winners = giveaway_data.get("num_winners", 1) @@ -142,19 +176,36 @@ class GiveawayRerollButton(ui.Button['GiveawayEndView']): # Announce in the original giveaway channel original_channel = self.cog.bot.get_channel(giveaway_data["channel_id"]) if original_channel: - await original_channel.send(f"🔄 Reroll for **{giveaway_data['prize']}**! Congratulations {winner_mentions}, you are the new winner(s)!") - await interaction.followup.send(f"Reroll successful. New winner(s) announced in {original_channel.mention}.", ephemeral=True) + await original_channel.send( + f"🔄 Reroll for **{giveaway_data['prize']}**! Congratulations {winner_mentions}, you are the new winner(s)!" + ) + await interaction.followup.send( + f"Reroll successful. New winner(s) announced in {original_channel.mention}.", + ephemeral=True, + ) else: - await interaction.followup.send("Reroll successful, but I couldn't find the original channel to announce.", ephemeral=True) + await interaction.followup.send( + "Reroll successful, but I couldn't find the original channel to announce.", + ephemeral=True, + ) else: - await interaction.followup.send("Could not select any new winners in the reroll.", ephemeral=True) + await interaction.followup.send( + "Could not select any new winners in the reroll.", ephemeral=True + ) class GiveawayEndView(ui.View): - def __init__(self, cog: 'GiveawaysCog', original_giveaway_message_id: int, timeout=None): # Timeout=None for persistent + def __init__( + self, cog: "GiveawaysCog", original_giveaway_message_id: int, timeout=None + ): # Timeout=None for persistent super().__init__(timeout=timeout) self.cog = cog - self.add_item(GiveawayRerollButton(cog_ref=self.cog, original_giveaway_message_id=original_giveaway_message_id)) + self.add_item( + GiveawayRerollButton( + cog_ref=self.cog, + original_giveaway_message_id=original_giveaway_message_id, + ) + ) class GiveawaysCog(commands.Cog, name="Giveaways"): @@ -164,8 +215,8 @@ class GiveawaysCog(commands.Cog, name="Giveaways"): def __init__(self, bot: commands.Bot): self.bot = bot - self.active_giveaways = [] - self.all_loaded_giveaways = [] # To keep ended ones for reroll lookup + self.active_giveaways = [] + self.all_loaded_giveaways = [] # To keep ended ones for reroll lookup # Structure: # { # "message_id": int, @@ -177,110 +228,146 @@ class GiveawaysCog(commands.Cog, name="Giveaways"): # "creator_id": int, # "participants": set(), # Store user_ids. Stored as list in JSON. # "is_nitro_giveaway": bool, - # "ended": bool + # "ended": bool # } # Ensure data directory exists before loading/saving - asyncio.create_task(self._ensure_data_dir_exists()) # Run asynchronously - asyncio.create_task(self.load_giveaways()) # Run asynchronously + asyncio.create_task(self._ensure_data_dir_exists()) # Run asynchronously + asyncio.create_task(self.load_giveaways()) # Run asynchronously self.check_giveaways_loop.start() # Persistent views are added in setup_hook - async def cog_load(self): # Changed from setup_hook to cog_load for better timing with bot ready + async def cog_load( + self, + ): # Changed from setup_hook to cog_load for better timing with bot ready asyncio.create_task(self._restore_views()) async def _restore_views(self): await self.bot.wait_until_ready() print("Re-adding persistent giveaway views...") - temp_loaded_giveaways = [] # Use a temporary list for loading + temp_loaded_giveaways = [] # Use a temporary list for loading try: - async with aiofiles.open(GIVEAWAY_DATA_FILE, mode='r') as f: + async with aiofiles.open(GIVEAWAY_DATA_FILE, mode="r") as f: content = await f.read() - if not content: # Handle empty file case + if not content: # Handle empty file case giveaways_data_for_views = [] else: - giveaways_data_for_views = await self.bot.loop.run_in_executor(None, json.loads, content) - + giveaways_data_for_views = await self.bot.loop.run_in_executor( + None, json.loads, content + ) + for gw_data in giveaways_data_for_views: # We only need to re-add views for messages that should have them is_ended = gw_data.get("ended", False) # Check if end_time is in the past if "ended" flag isn't perfectly reliable end_time_dt = datetime.datetime.fromisoformat(gw_data["end_time"]) - if not is_ended and end_time_dt > datetime.datetime.now(datetime.timezone.utc): - # Active giveaway, re-add EnterView - self.bot.add_view(GiveawayEnterView(cog=self), message_id=gw_data["message_id"]) - elif is_ended or end_time_dt <= datetime.datetime.now(datetime.timezone.utc): + if not is_ended and end_time_dt > datetime.datetime.now( + datetime.timezone.utc + ): + # Active giveaway, re-add EnterView + self.bot.add_view( + GiveawayEnterView(cog=self), + message_id=gw_data["message_id"], + ) + elif is_ended or end_time_dt <= datetime.datetime.now( + datetime.timezone.utc + ): # Ended giveaway, re-add EndView (with Reroll button) - self.bot.add_view(GiveawayEndView(cog=self, original_giveaway_message_id=gw_data["message_id"]), message_id=gw_data["message_id"]) - temp_loaded_giveaways.append(gw_data) # Keep track for _get_giveaway_by_message_id - print(f"Attempted to re-add views for {len(temp_loaded_giveaways)} giveaways.") + self.bot.add_view( + GiveawayEndView( + cog=self, + original_giveaway_message_id=gw_data["message_id"], + ), + message_id=gw_data["message_id"], + ) + temp_loaded_giveaways.append( + gw_data + ) # Keep track for _get_giveaway_by_message_id + print( + f"Attempted to re-add views for {len(temp_loaded_giveaways)} giveaways." + ) except FileNotFoundError: print("No giveaway data file found, skipping view re-adding.") except json.JSONDecodeError: - print("Error decoding giveaway data file. Starting with no active giveaways.") + print( + "Error decoding giveaway data file. Starting with no active giveaways." + ) except Exception as e: print(f"Error re-adding persistent views: {e}") - async def _ensure_data_dir_exists(self): try: - await aiofiles.os.makedirs(DATA_DIR, exist_ok=True) # Use aiofiles.os for async mkdir, exist_ok handles if it already exists + await aiofiles.os.makedirs( + DATA_DIR, exist_ok=True + ) # Use aiofiles.os for async mkdir, exist_ok handles if it already exists except Exception as e: print(f"Error ensuring data directory {DATA_DIR} exists: {e}") def cog_unload(self): self.check_giveaways_loop.cancel() - async def load_giveaways(self): # Make async + async def load_giveaways(self): # Make async self.active_giveaways = [] self.all_loaded_giveaways = [] try: - async with aiofiles.open(GIVEAWAY_DATA_FILE, mode='r') as f: + async with aiofiles.open(GIVEAWAY_DATA_FILE, mode="r") as f: content = await f.read() - if not content: # Handle empty file case + if not content: # Handle empty file case giveaways_data = [] else: - giveaways_data = await self.bot.loop.run_in_executor(None, json.loads, content) - + giveaways_data = await self.bot.loop.run_in_executor( + None, json.loads, content + ) + now = datetime.datetime.now(datetime.timezone.utc) for gw_data in giveaways_data: - gw_data["end_time"] = datetime.datetime.fromisoformat(gw_data["end_time"]) + gw_data["end_time"] = datetime.datetime.fromisoformat( + gw_data["end_time"] + ) gw_data["participants"] = set(gw_data.get("participants", [])) gw_data.setdefault("is_nitro_giveaway", False) - gw_data.setdefault("ended", gw_data["end_time"] <= now) # Set ended if time has passed - - self.all_loaded_giveaways.append(gw_data.copy()) # Store all for reroll lookup - + gw_data.setdefault( + "ended", gw_data["end_time"] <= now + ) # Set ended if time has passed + + self.all_loaded_giveaways.append( + gw_data.copy() + ) # Store all for reroll lookup + if not gw_data["ended"]: self.active_giveaways.append(gw_data) - print(f"Loaded {len(self.all_loaded_giveaways)} total giveaways ({len(self.active_giveaways)} active).") + print( + f"Loaded {len(self.all_loaded_giveaways)} total giveaways ({len(self.active_giveaways)} active)." + ) except FileNotFoundError: print("Giveaway data file not found. Starting with no active giveaways.") except json.JSONDecodeError: - print("Error decoding giveaway data file. Starting with no active giveaways.") + print( + "Error decoding giveaway data file. Starting with no active giveaways." + ) except Exception as e: print(f"An unexpected error occurred loading giveaways: {e}") - async def save_giveaways(self): # Make async + async def save_giveaways(self): # Make async # Save all giveaways (from self.all_loaded_giveaways or by merging active and ended) # This ensures that "ended" status and participant lists are preserved. try: # Create a unified list to save, ensuring all giveaways are present # and active_giveaways reflects the most current state for those not yet ended. - + # Create a dictionary of active giveaways by message_id for quick updates - active_map = {gw['message_id']: gw for gw in self.active_giveaways} - + active_map = {gw["message_id"]: gw for gw in self.active_giveaways} + giveaways_to_save = [] # Iterate through all_loaded_giveaways to maintain the full history for gw_hist in self.all_loaded_giveaways: # If this giveaway is also in active_map, it means it's still active # or just ended in the current session. Use the version from active_map # as it might have newer participant data before being marked ended. - if gw_hist['message_id'] in active_map: - current_version = active_map[gw_hist['message_id']] + if gw_hist["message_id"] in active_map: + current_version = active_map[gw_hist["message_id"]] saved_gw = current_version.copy() - else: # It's an older, ended giveaway not in the current active list + else: # It's an older, ended giveaway not in the current active list saved_gw = gw_hist.copy() saved_gw["end_time"] = saved_gw["end_time"].isoformat() @@ -290,19 +377,23 @@ class GiveawaysCog(commands.Cog, name="Giveaways"): # Add any brand new giveaways from self.active_giveaways not yet in self.all_loaded_giveaways # This case should ideally be handled by adding to both lists upon creation. # For robustness: - all_saved_ids = {gw['message_id'] for gw in giveaways_to_save} + all_saved_ids = {gw["message_id"] for gw in giveaways_to_save} for gw_active in self.active_giveaways: - if gw_active['message_id'] not in all_saved_ids: + if gw_active["message_id"] not in all_saved_ids: new_gw_to_save = gw_active.copy() new_gw_to_save["end_time"] = new_gw_to_save["end_time"].isoformat() - new_gw_to_save["participants"] = list(new_gw_to_save["participants"]) + new_gw_to_save["participants"] = list( + new_gw_to_save["participants"] + ) giveaways_to_save.append(new_gw_to_save) # Also add to all_loaded_giveaways for next time self.all_loaded_giveaways.append(gw_active.copy()) # Offload json.dumps to executor - json_string_to_save = await self.bot.loop.run_in_executor(None, json.dumps, giveaways_to_save, indent=4) - async with aiofiles.open(GIVEAWAY_DATA_FILE, mode='w') as f: + json_string_to_save = await self.bot.loop.run_in_executor( + None, json.dumps, giveaways_to_save, indent=4 + ) + async with aiofiles.open(GIVEAWAY_DATA_FILE, mode="w") as f: await f.write(json_string_to_save) # print(f"Saved {len(giveaways_to_save)} giveaways to disk.") except Exception as e: @@ -321,18 +412,18 @@ class GiveawaysCog(commands.Cog, name="Giveaways"): match = re.fullmatch(r"(\d+)([smhdw])", duration_str.lower()) if not match: return None - + value, unit = int(match.group(1)), match.group(2) - - if unit == 's': + + if unit == "s": return datetime.timedelta(seconds=value) - elif unit == 'm': + elif unit == "m": return datetime.timedelta(minutes=value) - elif unit == 'h': + elif unit == "h": return datetime.timedelta(hours=value) - elif unit == 'd': + elif unit == "d": return datetime.timedelta(days=value) - elif unit == 'w': + elif unit == "w": return datetime.timedelta(weeks=value) return None @@ -341,21 +432,30 @@ class GiveawaysCog(commands.Cog, name="Giveaways"): prize="What is the prize?", duration="How long should the giveaway last? (e.g., 10m, 1h, 2d, 1w)", winners="How many winners? (default: 1)", - nitro_giveaway="Is this a Nitro-only giveaway? (checks for animated avatar/banner)" + nitro_giveaway="Is this a Nitro-only giveaway? (checks for animated avatar/banner)", ) @app_commands.checks.has_permissions(manage_guild=True) - async def create_giveaway_slash(self, interaction: discord.Interaction, prize: str, duration: str, winners: int = 1, nitro_giveaway: bool = False): + async def create_giveaway_slash( + self, + interaction: discord.Interaction, + prize: str, + duration: str, + winners: int = 1, + nitro_giveaway: bool = False, + ): """Slash command to create a giveaway using buttons.""" parsed_duration = self.parse_duration(duration) if not parsed_duration: await interaction.response.send_message( "Invalid duration format. Use s, m, h, d, w (e.g., 10m, 1h, 2d).", - ephemeral=True + ephemeral=True, ) return if winners < 1: - await interaction.response.send_message("Number of winners must be at least 1.", ephemeral=True) + await interaction.response.send_message( + "Number of winners must be at least 1.", ephemeral=True + ) return end_time = datetime.datetime.now(datetime.timezone.utc) + parsed_duration @@ -363,17 +463,19 @@ class GiveawaysCog(commands.Cog, name="Giveaways"): embed = discord.Embed( title=f"🎉 Giveaway: {prize} 🎉", description=f"Click the button below to enter!\n" - f"Ends: {discord.utils.format_dt(end_time, style='R')} ({discord.utils.format_dt(end_time, style='F')})\n" - f"Winners: {winners}", - color=discord.Color.gold() + f"Ends: {discord.utils.format_dt(end_time, style='R')} ({discord.utils.format_dt(end_time, style='F')})\n" + f"Winners: {winners}", + color=discord.Color.gold(), ) if nitro_giveaway: embed.description += "\n*This is a Nitro-exclusive giveaway!*" - embed.set_footer(text=f"Giveaway started by {interaction.user.display_name}. Entries: 0") # Initial entry count - - await interaction.response.send_message("Creating giveaway...", ephemeral=True) - - view = GiveawayEnterView(cog=self) + embed.set_footer( + text=f"Giveaway started by {interaction.user.display_name}. Entries: 0" + ) # Initial entry count + + await interaction.response.send_message("Creating giveaway...", ephemeral=True) + + view = GiveawayEnterView(cog=self) giveaway_message = await interaction.channel.send(embed=embed, view=view) giveaway_data = { @@ -386,14 +488,17 @@ class GiveawaysCog(commands.Cog, name="Giveaways"): "creator_id": interaction.user.id, "participants": set(), "is_nitro_giveaway": nitro_giveaway, - "ended": False + "ended": False, } self.active_giveaways.append(giveaway_data) - self.all_loaded_giveaways.append(giveaway_data.copy()) # Also add to the comprehensive list - self.save_giveaways() - - await interaction.followup.send(f"Giveaway for '{prize}' created successfully!", ephemeral=True) + self.all_loaded_giveaways.append( + giveaway_data.copy() + ) # Also add to the comprehensive list + self.save_giveaways() + await interaction.followup.send( + f"Giveaway for '{prize}' created successfully!", ephemeral=True + ) @tasks.loop(seconds=30) async def check_giveaways_loop(self): @@ -403,75 +508,111 @@ class GiveawaysCog(commands.Cog, name="Giveaways"): # Iterate over a copy of active_giveaways for safe removal/modification for giveaway_data in list(self.active_giveaways): if giveaway_data["ended"] or now < giveaway_data["end_time"]: - continue # Skip already ended or not yet due + continue # Skip already ended or not yet due giveaways_processed_in_this_run = True - giveaway_data["ended"] = True # Mark as ended + giveaway_data["ended"] = True # Mark as ended channel = self.bot.get_channel(giveaway_data["channel_id"]) if not channel: - print(f"Error: Could not find channel {giveaway_data['channel_id']} for giveaway {giveaway_data['message_id']}") + print( + f"Error: Could not find channel {giveaway_data['channel_id']} for giveaway {giveaway_data['message_id']}" + ) # Remove from active_giveaways directly as it can't be processed - self.active_giveaways = [gw for gw in self.active_giveaways if gw["message_id"] != giveaway_data["message_id"]] + self.active_giveaways = [ + gw + for gw in self.active_giveaways + if gw["message_id"] != giveaway_data["message_id"] + ] continue try: message = await channel.fetch_message(giveaway_data["message_id"]) except discord.NotFound: - print(f"Error: Could not find message {giveaway_data['message_id']} in channel {channel.id}") - self.active_giveaways = [gw for gw in self.active_giveaways if gw["message_id"] != giveaway_data["message_id"]] + print( + f"Error: Could not find message {giveaway_data['message_id']} in channel {channel.id}" + ) + self.active_giveaways = [ + gw + for gw in self.active_giveaways + if gw["message_id"] != giveaway_data["message_id"] + ] continue except discord.Forbidden: - print(f"Error: Bot lacks permissions to fetch message {giveaway_data['message_id']} in channel {channel.id}") + print( + f"Error: Bot lacks permissions to fetch message {giveaway_data['message_id']} in channel {channel.id}" + ) # Cannot process, but keep it in active_giveaways for now, maybe perms will be fixed. # Or decide to remove it. For now, skip. continue - + # Fetch participants from the giveaway data entrants_users = [] for user_id in giveaway_data["participants"]: # Ensure user is still in the guild for Nitro check if applicable - member = channel.guild.get_member(user_id) # Use guild from channel + member = channel.guild.get_member(user_id) # Use guild from channel user_to_check = member if member else await self.bot.fetch_user(user_id) - if not user_to_check: continue # User not found + if not user_to_check: + continue # User not found - if user_to_check.bot: continue + if user_to_check.bot: + continue - if giveaway_data["is_nitro_giveaway"] and not is_user_nitro_like(user_to_check): - continue # Skip non-nitro users for nitro giveaways + if giveaway_data["is_nitro_giveaway"] and not is_user_nitro_like( + user_to_check + ): + continue # Skip non-nitro users for nitro giveaways entrants_users.append(user_to_check) - + winners_list = [] if entrants_users: if len(entrants_users) <= giveaway_data["num_winners"]: winners_list = list(entrants_users) else: - winners_list = random.sample(entrants_users, giveaway_data["num_winners"]) + winners_list = random.sample( + entrants_users, giveaway_data["num_winners"] + ) + + winner_mentions_str = ( + ", ".join(w.mention for w in winners_list) if winners_list else "None" + ) - winner_mentions_str = ", ".join(w.mention for w in winners_list) if winners_list else 'None' - if winners_list: - await channel.send(f"Congratulations {winner_mentions_str}! You won **{giveaway_data['prize']}**!") + await channel.send( + f"Congratulations {winner_mentions_str}! You won **{giveaway_data['prize']}**!" + ) else: - await channel.send(f"The giveaway for **{giveaway_data['prize']}** has ended, but there were no eligible participants.") + await channel.send( + f"The giveaway for **{giveaway_data['prize']}** has ended, but there were no eligible participants." + ) new_embed = message.embeds[0] new_embed.description = f"Giveaway ended!\nWinners: {winner_mentions_str}" new_embed.color = discord.Color.dark_grey() new_embed.set_footer(text="Giveaway has concluded.") - - end_view = GiveawayEndView(cog=self, original_giveaway_message_id=giveaway_data["message_id"]) + + end_view = GiveawayEndView( + cog=self, original_giveaway_message_id=giveaway_data["message_id"] + ) try: - await message.edit(embed=new_embed, view=end_view) + await message.edit(embed=new_embed, view=end_view) except discord.Forbidden: - print(f"Error: Bot lacks permissions to edit message for {giveaway_data['message_id']}") + print( + f"Error: Bot lacks permissions to edit message for {giveaway_data['message_id']}" + ) except discord.HTTPException as e: - print(f"Error editing giveaway message {giveaway_data['message_id']}: {e}") - + print( + f"Error editing giveaway message {giveaway_data['message_id']}: {e}" + ) + # Remove from active_giveaways after processing - self.active_giveaways = [gw for gw in self.active_giveaways if gw["message_id"] != giveaway_data["message_id"]] + self.active_giveaways = [ + gw + for gw in self.active_giveaways + if gw["message_id"] != giveaway_data["message_id"] + ] if giveaways_processed_in_this_run: self.save_giveaways() @@ -480,29 +621,42 @@ class GiveawaysCog(commands.Cog, name="Giveaways"): async def before_check_giveaways_loop(self): await self.bot.wait_until_ready() - @gway.command(name="rollmanual", description="Manually roll a winner from a message (for old giveaways or specific cases).") + @gway.command( + name="rollmanual", + description="Manually roll a winner from a message (for old giveaways or specific cases).", + ) @app_commands.describe( message_id="The ID of the message (giveaway or any message with reactions).", winners="How many winners to pick? (default: 1)", - emoji="Emoji for reaction-based roll (if not a button giveaway, default: 🎉)" + emoji="Emoji for reaction-based roll (if not a button giveaway, default: 🎉)", ) @app_commands.checks.has_permissions(manage_guild=True) - async def manual_roll_giveaway_slash(self, interaction: discord.Interaction, message_id: str, winners: int = 1, emoji: str = "🎉"): + async def manual_roll_giveaway_slash( + self, + interaction: discord.Interaction, + message_id: str, + winners: int = 1, + emoji: str = "🎉", + ): if winners < 1: - await interaction.response.send_message("Number of winners must be at least 1.", ephemeral=True) + await interaction.response.send_message( + "Number of winners must be at least 1.", ephemeral=True + ) return try: msg_id = int(message_id) except ValueError: - await interaction.response.send_message("Invalid Message ID format. It should be a number.", ephemeral=True) + await interaction.response.send_message( + "Invalid Message ID format. It should be a number.", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) # Try to find if this message_id corresponds to a known giveaway giveaway_info = self._get_giveaway_by_message_id(msg_id, search_all=True) - entrants = set() # Store user objects + entrants = set() # Store user objects message_to_roll = None try: @@ -512,25 +666,36 @@ class GiveawaysCog(commands.Cog, name="Giveaways"): for chan in interaction.guild.text_channels: try: message_to_roll = await chan.fetch_message(msg_id) - if message_to_roll: break + if message_to_roll: + break except (discord.NotFound, discord.Forbidden): continue - + if not message_to_roll: - await interaction.followup.send(f"Could not find message with ID `{msg_id}` in this server or I lack permissions.", ephemeral=True) + await interaction.followup.send( + f"Could not find message with ID `{msg_id}` in this server or I lack permissions.", + ephemeral=True, + ) return if giveaway_info and "participants" in giveaway_info: # Use stored participants if available (from button-based system) for user_id in giveaway_info["participants"]: - user = interaction.guild.get_member(user_id) or await self.bot.fetch_user(user_id) + user = interaction.guild.get_member( + user_id + ) or await self.bot.fetch_user(user_id) if user and not user.bot: - if giveaway_info.get("is_nitro_giveaway", False) and not is_user_nitro_like(user): + if giveaway_info.get( + "is_nitro_giveaway", False + ) and not is_user_nitro_like(user): continue entrants.add(user) if not entrants: - await interaction.followup.send(f"Found giveaway data for message `{msg_id}`, but no eligible stored participants.", ephemeral=True) - return + await interaction.followup.send( + f"Found giveaway data for message `{msg_id}`, but no eligible stored participants.", + ephemeral=True, + ) + return else: # Fallback to reactions if no participant data or not a known giveaway reaction_found = False @@ -544,10 +709,16 @@ class GiveawaysCog(commands.Cog, name="Giveaways"): entrants.add(user) break if not reaction_found: - await interaction.followup.send(f"No reactions found with {emoji} on message `{msg_id}`.", ephemeral=True) + await interaction.followup.send( + f"No reactions found with {emoji} on message `{msg_id}`.", + ephemeral=True, + ) return if not entrants: - await interaction.followup.send(f"No valid (non-bot) users reacted with {emoji} on message `{msg_id}`.", ephemeral=True) + await interaction.followup.send( + f"No valid (non-bot) users reacted with {emoji} on message `{msg_id}`.", + ephemeral=True, + ) return winners_list = [] @@ -559,14 +730,24 @@ class GiveawaysCog(commands.Cog, name="Giveaways"): if winners_list: winner_mentions = ", ".join(w.mention for w in winners_list) - await interaction.followup.send(f"Manual roll from message `{msg_id}` in {message_to_roll.channel.mention}:\nCongratulations {winner_mentions}!", ephemeral=False) + await interaction.followup.send( + f"Manual roll from message `{msg_id}` in {message_to_roll.channel.mention}:\nCongratulations {winner_mentions}!", + ephemeral=False, + ) if interaction.channel.id != message_to_roll.channel.id: try: - await message_to_roll.channel.send(f"Manual roll for message {message_to_roll.jump_url} concluded. Winner(s): {winner_mentions}") + await message_to_roll.channel.send( + f"Manual roll for message {message_to_roll.jump_url} concluded. Winner(s): {winner_mentions}" + ) except discord.Forbidden: - await interaction.followup.send(f"(Note: I couldn't announce the winner in {message_to_roll.channel.mention}.)", ephemeral=True) + await interaction.followup.send( + f"(Note: I couldn't announce the winner in {message_to_roll.channel.mention}.)", + ephemeral=True, + ) else: - await interaction.followup.send(f"Could not select any winners from message `{msg_id}`.", ephemeral=True) + await interaction.followup.send( + f"Could not select any winners from message `{msg_id}`.", ephemeral=True + ) async def setup(bot: commands.Bot): diff --git a/cogs/help_cog.py b/cogs/help_cog.py index 9dff55f..861cd41 100644 --- a/cogs/help_cog.py +++ b/cogs/help_cog.py @@ -32,12 +32,19 @@ COG_DISPLAY_NAMES = { # Add other cogs here as needed } + class HelpSelect(discord.ui.Select): - def __init__(self, view: 'HelpView', start_index=0, max_options=24): + def __init__(self, view: "HelpView", start_index=0, max_options=24): self.help_view = view # Always include General Overview option - options = [discord.SelectOption(label="General Overview", description="Go back to the main help page.", value="-1")] # Value -1 for overview page + options = [ + discord.SelectOption( + label="General Overview", + description="Go back to the main help page.", + value="-1", + ) + ] # Value -1 for overview page # Calculate end index, ensuring we don't go past the end of the cogs list end_index = min(start_index + max_options, len(view.cogs)) @@ -50,17 +57,24 @@ class HelpSelect(discord.ui.Select): # Use a relative index (i - start_index) as the value to avoid confusion # when navigating between pages relative_index = i - start_index - options.append(discord.SelectOption(label=display_name, value=str(relative_index))) + options.append( + discord.SelectOption(label=display_name, value=str(relative_index)) + ) # Store the range of cogs this select menu covers self.start_index = start_index self.end_index = end_index - super().__init__(placeholder="Select a category...", min_values=1, max_values=1, options=options) + super().__init__( + placeholder="Select a category...", + min_values=1, + max_values=1, + options=options, + ) async def callback(self, interaction: discord.Interaction): selected_value = int(self.values[0]) - if selected_value == -1: # General Overview selected + if selected_value == -1: # General Overview selected self.help_view.current_page = 0 else: # The value is a relative index (0-based) within the current page of options @@ -68,15 +82,22 @@ class HelpSelect(discord.ui.Select): actual_cog_index = selected_value + self.start_index # Debug information - print(f"Selected value: {selected_value}, start_index: {self.start_index}, actual_cog_index: {actual_cog_index}") + print( + f"Selected value: {selected_value}, start_index: {self.start_index}, actual_cog_index: {actual_cog_index}" + ) # Make sure the index is valid if 0 <= actual_cog_index < len(self.help_view.cogs): - self.help_view.current_page = actual_cog_index + 1 # +1 because page 0 is overview + self.help_view.current_page = ( + actual_cog_index + 1 + ) # +1 because page 0 is overview else: # If the index is invalid, go to the overview page self.help_view.current_page = 0 - await interaction.response.send_message(f"That category is no longer available. Showing overview. (Debug: value={selected_value}, start={self.start_index}, actual={actual_cog_index}, max={len(self.help_view.cogs)})", ephemeral=True) + await interaction.response.send_message( + f"That category is no longer available. Showing overview. (Debug: value={selected_value}, start={self.start_index}, actual={actual_cog_index}, max={len(self.help_view.cogs)})", + ephemeral=True, + ) # Ensure current_page is within valid range if self.help_view.current_page >= len(self.help_view.pages): @@ -91,13 +112,19 @@ class HelpSelect(discord.ui.Select): else: cog_index = self.help_view.current_page - 1 if 0 <= cog_index < len(self.help_view.cogs): - current_option_label = COG_DISPLAY_NAMES.get(self.help_view.cogs[cog_index].qualified_name, self.help_view.cogs[cog_index].qualified_name) + current_option_label = COG_DISPLAY_NAMES.get( + self.help_view.cogs[cog_index].qualified_name, + self.help_view.cogs[cog_index].qualified_name, + ) else: current_option_label = "Select a category..." self.placeholder = current_option_label try: - await interaction.response.edit_message(embed=self.help_view.pages[self.help_view.current_page], view=self.help_view) + await interaction.response.edit_message( + embed=self.help_view.pages[self.help_view.current_page], + view=self.help_view, + ) except Exception as e: # If we can't edit the message, try to defer or send a new message try: @@ -118,11 +145,15 @@ class HelpView(discord.ui.View): # Filter cogs and sort them using the display name mapping self.cogs = sorted( [cog for _, cog in bot.cogs.items() if cog.get_commands()], - key=lambda cog: COG_DISPLAY_NAMES.get(cog.qualified_name, cog.qualified_name) # Sort alphabetically by display name + key=lambda cog: COG_DISPLAY_NAMES.get( + cog.qualified_name, cog.qualified_name + ), # Sort alphabetically by display name ) # Calculate total number of select menu pages needed - self.total_select_pages = (len(self.cogs) + self.max_select_options - 1) // self.max_select_options + self.total_select_pages = ( + len(self.cogs) + self.max_select_options - 1 + ) // self.max_select_options # Create pages after total_select_pages is defined self.pages = self._create_pages() @@ -138,23 +169,31 @@ class HelpView(discord.ui.View): embed = discord.Embed( title="Help Command", description=f"Use the buttons below to navigate through command categories.\nTotal Categories: {len(self.cogs)}\nUse the Categories buttons to navigate between pages of categories.", - color=discord.Color.blue() + color=discord.Color.blue(), ) # Calculate how many cogs are shown in the current select page start_index = self.current_select_page * self.max_select_options end_index = min(start_index + self.max_select_options, len(self.cogs)) - current_range = f"{start_index + 1}-{end_index}" if len(self.cogs) > self.max_select_options else f"1-{len(self.cogs)}" + current_range = ( + f"{start_index + 1}-{end_index}" + if len(self.cogs) > self.max_select_options + else f"1-{len(self.cogs)}" + ) # Add information about which cogs are currently visible if len(self.cogs) > self.max_select_options: embed.add_field( name="Currently Showing", value=f"Categories {current_range} of {len(self.cogs)}", - inline=False + inline=False, ) - embed.set_footer(text="Page 0 / {} | Category Page {} / {}".format(len(self.cogs), self.current_select_page + 1, self.total_select_pages)) + embed.set_footer( + text="Page 0 / {} | Category Page {} / {}".format( + len(self.cogs), self.current_select_page + 1, self.total_select_pages + ) + ) return embed def _create_pages(self): @@ -170,28 +209,40 @@ class HelpView(discord.ui.View): display_name = COG_DISPLAY_NAMES.get(cog_name, cog_name) cog_commands = cog.get_commands() embed = discord.Embed( - title=f"{display_name} Commands", # Use the display name here + title=f"{display_name} Commands", # Use the display name here description=f"Commands available in the {display_name} category:", - color=discord.Color.green() # Or assign colors dynamically + color=discord.Color.green(), # Or assign colors dynamically ) for command in cog_commands: # Skip subcommands for now, just show top-level commands in the cog if isinstance(command, commands.Group): - # If it's a group, list its subcommands or just the group name - sub_cmds = ", ".join([f"`{sub.name}`" for sub in command.commands]) - if sub_cmds: - embed.add_field(name=f"`{command.name}` (Group)", value=f"Subcommands: {sub_cmds}\n{command.short_doc or 'No description'}", inline=False) - else: - embed.add_field(name=f"`{command.name}` (Group)", value=f"{command.short_doc or 'No description'}", inline=False) + # If it's a group, list its subcommands or just the group name + sub_cmds = ", ".join( + [f"`{sub.name}`" for sub in command.commands] + ) + if sub_cmds: + embed.add_field( + name=f"`{command.name}` (Group)", + value=f"Subcommands: {sub_cmds}\n{command.short_doc or 'No description'}", + inline=False, + ) + else: + embed.add_field( + name=f"`{command.name}` (Group)", + value=f"{command.short_doc or 'No description'}", + inline=False, + ) - elif command.parent is None: # Only show top-level commands + elif command.parent is None: # Only show top-level commands signature = f"{command.name} {command.signature}" embed.add_field( name=f"`{signature.strip()}`", value=command.short_doc or "No description provided.", - inline=False + inline=False, ) - embed.set_footer(text=f"Page {i + 1} / {len(self.cogs)} | Category Page {self.current_select_page + 1} / {self.total_select_pages}") + embed.set_footer( + text=f"Page {i + 1} / {len(self.cogs)} | Category Page {self.current_select_page + 1} / {self.total_select_pages}" + ) pages.append(embed) except Exception as e: # If there's an error creating a page for a cog, log it and continue @@ -200,9 +251,11 @@ class HelpView(discord.ui.View): error_embed = discord.Embed( title=f"Error displaying commands", description=f"There was an error displaying commands for this category.\nPlease try again or contact the bot owner if the issue persists.", - color=discord.Color.red() + color=discord.Color.red(), + ) + error_embed.set_footer( + text=f"Page {i + 1} / {len(self.cogs)} | Category Page {self.current_select_page + 1} / {self.total_select_pages}" ) - error_embed.set_footer(text=f"Page {i + 1} / {len(self.cogs)} | Category Page {self.current_select_page + 1} / {self.total_select_pages}") pages.append(error_embed) return pages @@ -241,7 +294,10 @@ class HelpView(discord.ui.View): else: cog_index = self.current_page - 1 if 0 <= cog_index < len(self.cogs): - current_option_label = COG_DISPLAY_NAMES.get(self.cogs[cog_index].qualified_name, self.cogs[cog_index].qualified_name) + current_option_label = COG_DISPLAY_NAMES.get( + self.cogs[cog_index].qualified_name, + self.cogs[cog_index].qualified_name, + ) else: current_option_label = "Select a category..." self.select_menu.placeholder = current_option_label @@ -258,14 +314,14 @@ class HelpView(discord.ui.View): return # Buttons will be added by decorators later for item in self.children: - if hasattr(item, 'custom_id'): - if item.custom_id == 'prev_page': + if hasattr(item, "custom_id"): + if item.custom_id == "prev_page": prev_page_button = item - elif item.custom_id == 'next_page': + elif item.custom_id == "next_page": next_page_button = item - elif item.custom_id == 'prev_category': + elif item.custom_id == "prev_category": prev_category_button = item - elif item.custom_id == 'next_category': + elif item.custom_id == "next_category": next_category_button = item # Update page navigation buttons @@ -278,10 +334,16 @@ class HelpView(discord.ui.View): if prev_category_button: prev_category_button.disabled = self.current_select_page == 0 if next_category_button: - next_category_button.disabled = self.current_select_page == self.total_select_pages - 1 + next_category_button.disabled = ( + self.current_select_page == self.total_select_pages - 1 + ) - @discord.ui.button(label="Previous", style=discord.ButtonStyle.grey, row=1, custom_id="prev_page") - async def previous_button(self, interaction: discord.Interaction, _: discord.ui.Button): + @discord.ui.button( + label="Previous", style=discord.ButtonStyle.grey, row=1, custom_id="prev_page" + ) + async def previous_button( + self, interaction: discord.Interaction, _: discord.ui.Button + ): if self.current_page > 0: self.current_page -= 1 self._update_buttons() @@ -292,7 +354,9 @@ class HelpView(discord.ui.View): self.current_page = 0 try: - await interaction.response.edit_message(embed=self.pages[self.current_page], view=self) + await interaction.response.edit_message( + embed=self.pages[self.current_page], view=self + ) except Exception as e: try: await interaction.response.defer() @@ -302,7 +366,9 @@ class HelpView(discord.ui.View): else: await interaction.response.defer() - @discord.ui.button(label="Next", style=discord.ButtonStyle.grey, row=1, custom_id="next_page") + @discord.ui.button( + label="Next", style=discord.ButtonStyle.grey, row=1, custom_id="next_page" + ) async def next_button(self, interaction: discord.Interaction, _: discord.ui.Button): if self.current_page < len(self.pages) - 1: self.current_page += 1 @@ -314,7 +380,9 @@ class HelpView(discord.ui.View): self.current_page = 0 try: - await interaction.response.edit_message(embed=self.pages[self.current_page], view=self) + await interaction.response.edit_message( + embed=self.pages[self.current_page], view=self + ) except Exception as e: try: await interaction.response.defer() @@ -324,8 +392,15 @@ class HelpView(discord.ui.View): else: await interaction.response.defer() - @discord.ui.button(label="◀ Categories", style=discord.ButtonStyle.primary, row=2, custom_id="prev_category") - async def prev_category_button(self, interaction: discord.Interaction, _: discord.ui.Button): + @discord.ui.button( + label="◀ Categories", + style=discord.ButtonStyle.primary, + row=2, + custom_id="prev_category", + ) + async def prev_category_button( + self, interaction: discord.Interaction, _: discord.ui.Button + ): if self.current_select_page > 0: # Store the current page before updating old_page = self.current_page @@ -357,7 +432,9 @@ class HelpView(discord.ui.View): self.current_page = 0 try: - await interaction.response.edit_message(embed=self.pages[self.current_page], view=self) + await interaction.response.edit_message( + embed=self.pages[self.current_page], view=self + ) except Exception as e: try: await interaction.response.defer() @@ -367,8 +444,15 @@ class HelpView(discord.ui.View): else: await interaction.response.defer() - @discord.ui.button(label="Categories ▶", style=discord.ButtonStyle.primary, row=2, custom_id="next_category") - async def next_category_button(self, interaction: discord.Interaction, _: discord.ui.Button): + @discord.ui.button( + label="Categories ▶", + style=discord.ButtonStyle.primary, + row=2, + custom_id="next_category", + ) + async def next_category_button( + self, interaction: discord.Interaction, _: discord.ui.Button + ): if self.current_select_page < self.total_select_pages - 1: # Store the current page before updating old_page = self.current_page @@ -400,7 +484,9 @@ class HelpView(discord.ui.View): self.current_page = 0 try: - await interaction.response.edit_message(embed=self.pages[self.current_page], view=self) + await interaction.response.edit_message( + embed=self.pages[self.current_page], view=self + ) except Exception as e: try: await interaction.response.defer() @@ -415,7 +501,7 @@ class HelpCog(commands.Cog): def __init__(self, bot: commands.Bot): self.bot = bot # Remove the default help command before adding the custom one - original_help_command = bot.get_command('help') + original_help_command = bot.get_command("help") if original_help_command: bot.remove_command(original_help_command.name) @@ -429,26 +515,46 @@ class HelpCog(commands.Cog): embed = discord.Embed( title=f"Help for `{command.name}`", description=command.help or "No detailed description provided.", - color=discord.Color.blue() + color=discord.Color.blue(), + ) + embed.add_field( + name="Usage", + value=f"`{command.name} {command.signature}`", + inline=False, ) - embed.add_field(name="Usage", value=f"`{command.name} {command.signature}`", inline=False) if isinstance(command, commands.Group): - subcommands = "\n".join([f"`{sub.name}`: {sub.short_doc or 'No description'}" for sub in command.commands]) - embed.add_field(name="Subcommands", value=subcommands or "None", inline=False) + subcommands = "\n".join( + [ + f"`{sub.name}`: {sub.short_doc or 'No description'}" + for sub in command.commands + ] + ) + embed.add_field( + name="Subcommands", + value=subcommands or "None", + inline=False, + ) await ctx.send(embed=embed, ephemeral=True) else: - await ctx.send(f"Command `{command_name}` not found.", ephemeral=True) + await ctx.send( + f"Command `{command_name}` not found.", ephemeral=True + ) else: view = HelpView(self.bot) - await ctx.send(embed=view.pages[0], view=view, ephemeral=True) # Send ephemeral so only user sees it + await ctx.send( + embed=view.pages[0], view=view, ephemeral=True + ) # Send ephemeral so only user sees it except Exception as e: # If there's an error, send a simple error message print(f"Error in help command: {e}") - await ctx.send(f"An error occurred while displaying the help command. Please try again or contact the bot owner if the issue persists.", ephemeral=True) + await ctx.send( + f"An error occurred while displaying the help command. Please try again or contact the bot owner if the issue persists.", + ephemeral=True, + ) @commands.Cog.listener() async def on_ready(self): - print(f'{self.__class__.__name__} cog has been loaded.') + print(f"{self.__class__.__name__} cog has been loaded.") async def setup(bot: commands.Bot): diff --git a/cogs/leveling_cog.py b/cogs/leveling_cog.py index d8c1cd4..5f02262 100644 --- a/cogs/leveling_cog.py +++ b/cogs/leveling_cog.py @@ -1,12 +1,12 @@ import discord from discord.ext import commands -from discord import ui # Add ui for LayoutView +from discord import ui # Add ui for LayoutView import json import os import asyncio import random import math -import traceback # Import traceback for detailed error logging +import traceback # Import traceback for detailed error logging from typing import Dict, List, Optional, Union, Set # File paths for JSON data @@ -22,12 +22,15 @@ DEFAULT_XP_COOLDOWN = 30 # seconds DEFAULT_REACTION_COOLDOWN = 30 # seconds DEFAULT_LEVEL_MULTIPLIER = 35 # XP needed per level = level * multiplier + class LevelingCog(commands.Cog): def __init__(self, bot: commands.Bot): self.bot = bot if not self.bot: print("DEBUG: Bot instance is None in LevelingCog.__init__") - self.user_data = {} # {user_id: {"xp": int, "level": int, "last_message_time": float}} + self.user_data = ( + {} + ) # {user_id: {"xp": int, "level": int, "last_message_time": float}} self.level_roles = {} # {guild_id: {level: role_id}} self.restricted_channels = set() # Set of channel IDs where XP gain is disabled self.xp_cooldowns = {} # {user_id: last_xp_time} @@ -40,7 +43,7 @@ class LevelingCog(commands.Cog): "message_cooldown": DEFAULT_XP_COOLDOWN, "reaction_cooldown": DEFAULT_REACTION_COOLDOWN, "reaction_xp_enabled": True, - "default_level_notifs_enabled": False # New setting: level notifications disabled by default + "default_level_notifs_enabled": False, # New setting: level notifications disabled by default } # Load existing data @@ -61,7 +64,9 @@ class LevelingCog(commands.Cog): user_id = int(k) # Ensure 'level_notifs_enabled' is present with default if missing if "level_notifs_enabled" not in v: - v["level_notifs_enabled"] = self.config["default_level_notifs_enabled"] + v["level_notifs_enabled"] = self.config[ + "default_level_notifs_enabled" + ] self.user_data[user_id] = v print(f"Loaded level data for {len(self.user_data)} users") except Exception as e: @@ -101,7 +106,9 @@ class LevelingCog(commands.Cog): # Handle gendered roles self.level_roles[guild_id][level] = {} for gender, role_id_str in role_data.items(): - self.level_roles[guild_id][level][gender] = int(role_id_str) + self.level_roles[guild_id][level][gender] = int( + role_id_str + ) else: # Handle regular roles self.level_roles[guild_id][level] = int(role_data) @@ -124,7 +131,8 @@ class LevelingCog(commands.Cog): if isinstance(role_data, dict): # Handle gendered roles serializable_data[str(guild_id)][str(level)] = { - gender: str(role_id) for gender, role_id in role_data.items() + gender: str(role_id) + for gender, role_id in role_data.items() } else: # Handle regular roles @@ -143,7 +151,9 @@ class LevelingCog(commands.Cog): with open(RESTRICTED_CHANNELS_FILE, "r", encoding="utf-8") as f: data = json.load(f) # Convert list to set of integers - self.restricted_channels = set(int(channel_id) for channel_id in data) + self.restricted_channels = set( + int(channel_id) for channel_id in data + ) print(f"Loaded {len(self.restricted_channels)} restricted channels") except Exception as e: print(f"Error loading restricted channels: {e}") @@ -153,7 +163,9 @@ class LevelingCog(commands.Cog): """Save restricted channels to JSON file""" try: # Convert set to list of strings for JSON serialization - serializable_data = [str(channel_id) for channel_id in self.restricted_channels] + serializable_data = [ + str(channel_id) for channel_id in self.restricted_channels + ] with open(RESTRICTED_CHANNELS_FILE, "w", encoding="utf-8") as f: json.dump(serializable_data, f, indent=4, ensure_ascii=False) except Exception as e: @@ -200,11 +212,13 @@ class LevelingCog(commands.Cog): "xp": 0, "level": 0, "last_message_time": 0, - "level_notifs_enabled": self.config["default_level_notifs_enabled"] + "level_notifs_enabled": self.config["default_level_notifs_enabled"], } return self.user_data[user_id] - async def add_xp(self, user_id: int, guild_id: int, xp_amount: int = DEFAULT_XP_PER_MESSAGE) -> Optional[int]: + async def add_xp( + self, user_id: int, guild_id: int, xp_amount: int = DEFAULT_XP_PER_MESSAGE + ) -> Optional[int]: """ Add XP to a user and return new level if leveled up, otherwise None """ @@ -246,7 +260,9 @@ class LevelingCog(commands.Cog): # Get the member object member = guild.get_member(user_id) if not member: - print(f"DEBUG: Member {user_id} not found in guild {guild_id} in assign_level_role") + print( + f"DEBUG: Member {user_id} not found in guild {guild_id} in assign_level_role" + ) return False # Find the highest role that matches the user's level @@ -274,7 +290,11 @@ class LevelingCog(commands.Cog): # Handle gendered roles if available if isinstance(role_data, dict) and gender in role_data: highest_role_id = role_data[gender] - elif isinstance(role_data, dict) and "male" in role_data and "female" in role_data: + elif ( + isinstance(role_data, dict) + and "male" in role_data + and "female" in role_data + ): # If we have gendered roles but no gender preference, use male as default highest_role_id = role_data["male"] else: @@ -285,7 +305,9 @@ class LevelingCog(commands.Cog): # Get the role object role = guild.get_role(highest_role_id) if not role: - print(f"DEBUG: Role {highest_role_id} not found in guild {guild_id} in assign_level_role") + print( + f"DEBUG: Role {highest_role_id} not found in guild {guild_id} in assign_level_role" + ) return False if role and role not in member.roles: try: @@ -307,7 +329,9 @@ class LevelingCog(commands.Cog): roles_to_remove.append(other_role) if roles_to_remove: - await member.remove_roles(*roles_to_remove, reason="Level role update") + await member.remove_roles( + *roles_to_remove, reason="Level role update" + ) # Add the new role await member.add_roles(role, reason=f"Reached level {level}") @@ -355,11 +379,13 @@ class LevelingCog(commands.Cog): # If user leveled up, send a message if notifications are enabled for them if new_level: user_data = self.get_user_data(user_id) - if user_data.get("level_notifs_enabled", self.config["default_level_notifs_enabled"]): + if user_data.get( + "level_notifs_enabled", self.config["default_level_notifs_enabled"] + ): try: await message.channel.send( f"🎉 Congratulations {message.author.mention}! You've reached level **{new_level}**!", - delete_after=10 # Delete after 10 seconds + delete_after=10, # Delete after 10 seconds ) except discord.Forbidden: pass # Ignore if we can't send messages @@ -393,27 +419,41 @@ class LevelingCog(commands.Cog): progress = xp_current / xp_required progress_bar_length = 20 filled_length = int(progress_bar_length * progress) - bar = '█' * filled_length + '░' * (progress_bar_length - filled_length) + bar = "█" * filled_length + "░" * (progress_bar_length - filled_length) class LevelCheckView(ui.LayoutView): - def __init__(self, target_member: discord.Member, level: int, xp: int, xp_needed: int, next_level: int, bar: str, progress_percent: int): + def __init__( + self, + target_member: discord.Member, + level: int, + xp: int, + xp_needed: int, + next_level: int, + bar: str, + progress_percent: int, + ): super().__init__() # Debug logging for parameters if not target_member: print("DEBUG: target_member is None in LevelCheckView.__init__") - + # Main container for all elements, providing the accent color main_container = ui.Container(accent_colour=None) if main_container is None: - raise AssertionError("ui.Container returned None in LevelCheckView; ensure accent_colour is valid") - self.add_item(main_container) # Add the main container to the view + raise AssertionError( + "ui.Container returned None in LevelCheckView; ensure accent_colour is valid" + ) + self.add_item(main_container) # Add the main container to the view # Prepare thumbnail accessory thumbnail_accessory = None if target_member.display_avatar: - thumbnail_accessory = ui.Thumbnail(media=target_member.display_avatar.url, description="User Avatar") - + thumbnail_accessory = ui.Thumbnail( + media=target_member.display_avatar.url, + description="User Avatar", + ) + # Section to hold the user's name and level/XP, with the thumbnail as accessory # This section will be added to the main_container user_info_section = ui.Section(accessory=thumbnail_accessory) @@ -422,36 +462,54 @@ class LevelingCog(commands.Cog): main_container.add_item(user_info_section) # Add text components to the user_info_section - name_display = ui.TextDisplay(f"**{target_member.display_name}'s Level**") + name_display = ui.TextDisplay( + f"**{target_member.display_name}'s Level**" + ) if name_display is None: - raise AssertionError("ui.TextDisplay returned None for name in LevelCheckView") + raise AssertionError( + "ui.TextDisplay returned None for name in LevelCheckView" + ) user_info_section.add_item(name_display) - - level_display = ui.TextDisplay(f"**Level:** {level}\n**XP:** {xp} / {xp_needed}") + + level_display = ui.TextDisplay( + f"**Level:** {level}\n**XP:** {xp} / {xp_needed}" + ) if level_display is None: - raise AssertionError("ui.TextDisplay returned None for level info in LevelCheckView") + raise AssertionError( + "ui.TextDisplay returned None for level info in LevelCheckView" + ) user_info_section.add_item(level_display) - + # Add remaining components directly to the main_container separator = ui.Separator(spacing=discord.SeparatorSpacing.small) if separator is None: raise AssertionError("ui.Separator returned None in LevelCheckView") main_container.add_item(separator) - - progress_text = ui.TextDisplay(f"**Progress to Level {next_level}:**\n[{bar}] {progress_percent}%") + + progress_text = ui.TextDisplay( + f"**Progress to Level {next_level}:**\n[{bar}] {progress_percent}%" + ) if progress_text is None: - raise AssertionError("ui.TextDisplay returned None in LevelCheckView") + raise AssertionError( + "ui.TextDisplay returned None in LevelCheckView" + ) main_container.add_item(progress_text) - + try: - view = LevelCheckView(target, level, xp, xp_needed, next_level, bar, int(progress * 100)) + view = LevelCheckView( + target, level, xp, xp_needed, next_level, bar, int(progress * 100) + ) await ctx.send(view=view) except Exception as e: print(f"Error creating level check view: {e}") traceback.print_exc() - await ctx.send("❌ An error occurred while creating the level display. Please check the console for details.") + await ctx.send( + "❌ An error occurred while creating the level display. Please check the console for details." + ) - @level.command(name="leaderboard", description="Show the server's level leaderboard") + @level.command( + name="leaderboard", description="Show the server's level leaderboard" + ) async def leaderboard_command(self, ctx: commands.Context): """Show the server's level leaderboard""" if not ctx.guild: @@ -471,32 +529,47 @@ class LevelingCog(commands.Cog): sorted_data = sorted(guild_data.items(), key=lambda x: x[1]["xp"], reverse=True) class LeaderboardView(ui.LayoutView): - def __init__(self, guild_name: str, sorted_leaderboard_data: list, guild_members_dict: dict): + def __init__( + self, + guild_name: str, + sorted_leaderboard_data: list, + guild_members_dict: dict, + ): super().__init__() # Main container for all elements, providing the accent color main_container = ui.Container(accent_colour=discord.Colour.gold()) if main_container is None: - raise AssertionError("ui.Container returned None in LeaderboardView; ensure accent_colour is valid") - self.add_item(main_container) # Add the main container to the view + raise AssertionError( + "ui.Container returned None in LeaderboardView; ensure accent_colour is valid" + ) + self.add_item(main_container) # Add the main container to the view title_display = ui.TextDisplay(f"**{guild_name} Level Leaderboard**") if title_display is None: - raise AssertionError("ui.TextDisplay returned None for title in LeaderboardView") + raise AssertionError( + "ui.TextDisplay returned None for title in LeaderboardView" + ) main_container.add_item(title_display) - + sep = ui.Separator(spacing=discord.SeparatorSpacing.small) if sep is None: - raise AssertionError("ui.Separator returned None in LeaderboardView") + raise AssertionError( + "ui.Separator returned None in LeaderboardView" + ) main_container.add_item(sep) if not sorted_leaderboard_data: empty_display = ui.TextDisplay("The leaderboard is empty!") if empty_display is None: - raise AssertionError("ui.TextDisplay returned None for empty message in LeaderboardView") + raise AssertionError( + "ui.TextDisplay returned None for empty message in LeaderboardView" + ) main_container.add_item(empty_display) else: - for i, (user_id, data) in enumerate(sorted_leaderboard_data[:10], 1): + for i, (user_id, data) in enumerate( + sorted_leaderboard_data[:10], 1 + ): member = guild_members_dict.get(user_id) if not member: continue @@ -504,44 +577,64 @@ class LevelingCog(commands.Cog): # Each user's entry gets its own section and is added to the main_container user_section = ui.Section(accessory=None) if user_section is None: - raise AssertionError("ui.Section returned None in LeaderboardView") + raise AssertionError( + "ui.Section returned None in LeaderboardView" + ) main_container.add_item(user_section) # Add text components to the user_section rank_display = ui.TextDisplay(f"**{i}. {member.display_name}**") if rank_display is None: - raise AssertionError("ui.TextDisplay returned None for rank in LeaderboardView") + raise AssertionError( + "ui.TextDisplay returned None for rank in LeaderboardView" + ) user_section.add_item(rank_display) - - level_display = ui.TextDisplay(f"Level: {data['level']} | XP: {data['xp']}") + + level_display = ui.TextDisplay( + f"Level: {data['level']} | XP: {data['xp']}" + ) if level_display is None: - raise AssertionError("ui.TextDisplay returned None for level in LeaderboardView") + raise AssertionError( + "ui.TextDisplay returned None for level in LeaderboardView" + ) user_section.add_item(level_display) - + # Add separator to the main_container - if i < len(sorted_leaderboard_data[:10]): # not the last row - separator = ui.Separator(spacing=discord.SeparatorSpacing.small) + if i < len(sorted_leaderboard_data[:10]): # not the last row + separator = ui.Separator( + spacing=discord.SeparatorSpacing.small + ) if separator is None: - raise AssertionError("ui.Separator returned None between rows in LeaderboardView") + raise AssertionError( + "ui.Separator returned None between rows in LeaderboardView" + ) main_container.add_item(separator) try: view = LeaderboardView(ctx.guild.name, sorted_data, guild_members) - + # Double-check the view is dispatchable and properly constructed if view is None: - return await ctx.send("❌ Failed to build leaderboard layout. Please try again.") - + return await ctx.send( + "❌ Failed to build leaderboard layout. Please try again." + ) + # Send the view await ctx.send(view=view) except Exception as e: print(f"Error creating leaderboard view: {e}") traceback.print_exc() - await ctx.send("❌ An error occurred while creating the leaderboard. Please check the console for details.") + await ctx.send( + "❌ An error occurred while creating the leaderboard. Please check the console for details." + ) - @level.command(name="register_role", description="Register a role for a specific level") + @level.command( + name="register_role", description="Register a role for a specific level" + ) @commands.has_permissions(manage_roles=True) - async def register_level_role(self, ctx: commands.Context, level: int, role: discord.Role): + async def register_level_role( + self, ctx: commands.Context, level: int, role: discord.Role + ): """Register a role to be assigned at a specific level""" if not ctx.guild: await ctx.send("This command can only be used in a server.") @@ -569,7 +662,10 @@ class LevelingCog(commands.Cog): await ctx.send("This command can only be used in a server.") return - if ctx.guild.id not in self.level_roles or level not in self.level_roles[ctx.guild.id]: + if ( + ctx.guild.id not in self.level_roles + or level not in self.level_roles[ctx.guild.id] + ): await ctx.send("No role is registered for this level.") return @@ -596,113 +692,174 @@ class LevelingCog(commands.Cog): main_container = ui.Container(accent_colour=discord.Colour.blue()) if main_container is None: - raise AssertionError("ui.Container returned None in ListLevelRolesView; ensure accent_colour is valid") + raise AssertionError( + "ui.Container returned None in ListLevelRolesView; ensure accent_colour is valid" + ) self.add_item(main_container) title_display = ui.TextDisplay(f"**Level Roles for {guild.name}**") if title_display is None: - raise AssertionError("ui.TextDisplay returned None for title in ListLevelRolesView") + raise AssertionError( + "ui.TextDisplay returned None for title in ListLevelRolesView" + ) main_container.add_item(title_display) - + sep = ui.Separator(spacing=discord.SeparatorSpacing.small) if sep is None: - raise AssertionError("ui.Separator returned None in ListLevelRolesView") + raise AssertionError( + "ui.Separator returned None in ListLevelRolesView" + ) main_container.add_item(sep) - if not level_roles_data: # Should be caught by the check above, but good practice - empty_display = ui.TextDisplay("No level roles are registered for this server.") + if ( + not level_roles_data + ): # Should be caught by the check above, but good practice + empty_display = ui.TextDisplay( + "No level roles are registered for this server." + ) if empty_display is None: - raise AssertionError("ui.TextDisplay returned None for empty message in ListLevelRolesView") + raise AssertionError( + "ui.TextDisplay returned None for empty message in ListLevelRolesView" + ) main_container.add_item(empty_display) return sorted_roles_items = sorted(level_roles_data.items()) for level, role_data_or_id in sorted_roles_items: - role_section = ui.Section(accessory=None) # Explicitly pass accessory=None + role_section = ui.Section( + accessory=None + ) # Explicitly pass accessory=None if role_section is None: - raise AssertionError("ui.Section returned None in ListLevelRolesView") - + raise AssertionError( + "ui.Section returned None in ListLevelRolesView" + ) + level_title = ui.TextDisplay(f"**Level {level}:**") if level_title is None: - raise AssertionError("ui.TextDisplay returned None for level title in ListLevelRolesView") + raise AssertionError( + "ui.TextDisplay returned None for level title in ListLevelRolesView" + ) role_section.add_item(level_title) - if isinstance(role_data_or_id, dict): # Gendered roles + if isinstance(role_data_or_id, dict): # Gendered roles for gender, role_id in role_data_or_id.items(): role = guild.get_role(role_id) - role_name = role.mention if role else f"Unknown Role (ID: {role_id})" - gender_display = ui.TextDisplay(f" - {gender.capitalize()}: {role_name}") + role_name = ( + role.mention + if role + else f"Unknown Role (ID: {role_id})" + ) + gender_display = ui.TextDisplay( + f" - {gender.capitalize()}: {role_name}" + ) if gender_display is None: - raise AssertionError("ui.TextDisplay returned None for gender role in ListLevelRolesView") + raise AssertionError( + "ui.TextDisplay returned None for gender role in ListLevelRolesView" + ) role_section.add_item(gender_display) - else: # Regular role + else: # Regular role role = guild.get_role(role_data_or_id) - role_name = role.mention if role else f"Unknown Role (ID: {role_data_or_id})" + role_name = ( + role.mention + if role + else f"Unknown Role (ID: {role_data_or_id})" + ) role_display = ui.TextDisplay(f" {role_name}") if role_display is None: - raise AssertionError("ui.TextDisplay returned None for regular role in ListLevelRolesView") + raise AssertionError( + "ui.TextDisplay returned None for regular role in ListLevelRolesView" + ) role_section.add_item(role_display) - + main_container.add_item(role_section) - if level != sorted_roles_items[-1][0]: # Add separator if not the last item + if ( + level != sorted_roles_items[-1][0] + ): # Add separator if not the last item separator = ui.Separator(spacing=discord.SeparatorSpacing.small) if separator is None: - raise AssertionError("ui.Separator returned None between roles in ListLevelRolesView") + raise AssertionError( + "ui.Separator returned None between roles in ListLevelRolesView" + ) main_container.add_item(separator) - try: view = ListLevelRolesView(ctx.guild, self.level_roles[ctx.guild.id]) await ctx.send(view=view) except Exception as e: print(f"Error creating list level roles view: {e}") traceback.print_exc() - await ctx.send("❌ An error occurred while creating the level roles list. Please check the console for details.") + await ctx.send( + "❌ An error occurred while creating the level roles list. Please check the console for details." + ) - @level.command(name="restrict_channel", description="Restrict a channel from giving XP") + @level.command( + name="restrict_channel", description="Restrict a channel from giving XP" + ) @commands.has_permissions(manage_channels=True) - async def restrict_channel(self, ctx: commands.Context, channel: discord.TextChannel = None): + async def restrict_channel( + self, ctx: commands.Context, channel: discord.TextChannel = None + ): """Restrict a channel from giving XP""" target_channel = channel or ctx.channel if target_channel.id in self.restricted_channels: - await ctx.send(f"{target_channel.mention} is already restricted from giving XP.") + await ctx.send( + f"{target_channel.mention} is already restricted from giving XP." + ) return self.restricted_channels.add(target_channel.id) self.save_restricted_channels() - await ctx.send(f"✅ {target_channel.mention} will no longer give XP for messages.") + await ctx.send( + f"✅ {target_channel.mention} will no longer give XP for messages." + ) - @level.command(name="unrestrict_channel", description="Allow a channel to give XP again") + @level.command( + name="unrestrict_channel", description="Allow a channel to give XP again" + ) @commands.has_permissions(manage_channels=True) - async def unrestrict_channel(self, ctx: commands.Context, channel: discord.TextChannel = None): + async def unrestrict_channel( + self, ctx: commands.Context, channel: discord.TextChannel = None + ): """Allow a channel to give XP again""" target_channel = channel or ctx.channel if target_channel.id not in self.restricted_channels: - await ctx.send(f"{target_channel.mention} is not restricted from giving XP.") + await ctx.send( + f"{target_channel.mention} is not restricted from giving XP." + ) return self.restricted_channels.remove(target_channel.id) self.save_restricted_channels() await ctx.send(f"✅ {target_channel.mention} will now give XP for messages.") - @level.command(name="process_messages", description="Process existing messages to award XP") + @level.command( + name="process_messages", description="Process existing messages to award XP" + ) @commands.is_owner() - async def process_existing_messages(self, ctx: commands.Context, limit: int = 10000): + async def process_existing_messages( + self, ctx: commands.Context, limit: int = 10000 + ): """Process existing messages to award XP (Owner only)""" if not ctx.guild: await ctx.send("This command can only be used in a server.") return - status_message = await ctx.send(f"Processing existing messages (up to {limit} per channel)...") + status_message = await ctx.send( + f"Processing existing messages (up to {limit} per channel)..." + ) total_processed = 0 total_channels = 0 # Get all text channels in the guild - text_channels = [channel for channel in ctx.guild.channels if isinstance(channel, discord.TextChannel)] + text_channels = [ + channel + for channel in ctx.guild.channels + if isinstance(channel, discord.TextChannel) + ] for channel in text_channels: # Skip restricted channels @@ -713,7 +870,9 @@ class LevelingCog(commands.Cog): processed_in_channel = 0 # Update status message - await status_message.edit(content=f"Processing channel {channel.mention}... ({total_processed} messages processed so far)") + await status_message.edit( + content=f"Processing channel {channel.mention}... ({total_processed} messages processed so far)" + ) async for message in channel.history(limit=limit): # Skip bot messages @@ -730,18 +889,24 @@ class LevelingCog(commands.Cog): # Update status every 1000 messages if total_processed % 1000 == 0: - await status_message.edit(content=f"Processing channel {channel.mention}... ({total_processed} messages processed so far)") + await status_message.edit( + content=f"Processing channel {channel.mention}... ({total_processed} messages processed so far)" + ) total_channels += 1 except discord.Forbidden: - await ctx.send(f"Missing permissions to read message history in {channel.mention}") + await ctx.send( + f"Missing permissions to read message history in {channel.mention}" + ) except Exception as e: await ctx.send(f"Error processing messages in {channel.mention}: {e}") traceback.print_exc() # Final update - await status_message.edit(content=f"✅ Finished processing {total_processed} messages across {total_channels} channels.") + await status_message.edit( + content=f"✅ Finished processing {total_processed} messages across {total_channels} channels." + ) @commands.Cog.listener() async def on_raw_reaction_add(self, payload): @@ -789,13 +954,15 @@ class LevelingCog(commands.Cog): try: member = channel.guild.get_member(user_id) if member: - await member.send(f"🎉 Congratulations! You've reached level **{new_level}**!") + await member.send( + f"🎉 Congratulations! You've reached level **{new_level}**!" + ) except discord.Forbidden: pass # Ignore if we can't send DMs @commands.Cog.listener() async def on_ready(self): - print(f'{self.__class__.__name__} cog has been loaded.') + print(f"{self.__class__.__name__} cog has been loaded.") async def cog_unload(self): """Save all data when cog is unloaded""" @@ -803,64 +970,98 @@ class LevelingCog(commands.Cog): self.save_level_roles() self.save_restricted_channels() self.save_config() - print(f'{self.__class__.__name__} cog has been unloaded and data saved.') + print(f"{self.__class__.__name__} cog has been unloaded and data saved.") @level.command(name="config", description="Configure XP settings") @commands.has_permissions(administrator=True) - async def xp_config(self, ctx: commands.Context, setting: str = None, value: str = None): + async def xp_config( + self, ctx: commands.Context, setting: str = None, value: str = None + ): """Configure XP settings for the leveling system""" if not setting: + class XPConfigView(ui.LayoutView): def __init__(self, config_data: dict, prefix: str): super().__init__() main_container = ui.Container(accent_colour=discord.Colour.blue()) if main_container is None: - raise AssertionError("ui.Container returned None in XPConfigView; ensure accent_colour is valid") + raise AssertionError( + "ui.Container returned None in XPConfigView; ensure accent_colour is valid" + ) self.add_item(main_container) title_text = ui.TextDisplay("**XP Configuration Settings**") if title_text is None: - raise AssertionError("ui.TextDisplay returned None for title in XPConfigView") + raise AssertionError( + "ui.TextDisplay returned None for title in XPConfigView" + ) main_container.add_item(title_text) - - desc_text = ui.TextDisplay("Current XP settings for the leveling system:") + + desc_text = ui.TextDisplay( + "Current XP settings for the leveling system:" + ) if desc_text is None: - raise AssertionError("ui.TextDisplay returned None for description in XPConfigView") + raise AssertionError( + "ui.TextDisplay returned None for description in XPConfigView" + ) main_container.add_item(desc_text) - + separator = ui.Separator(spacing=discord.SeparatorSpacing.small) if separator is None: - raise AssertionError("ui.Separator returned None in XPConfigView") + raise AssertionError( + "ui.Separator returned None in XPConfigView" + ) main_container.add_item(separator) settings_to_display = [ ("XP Per Message", str(config_data["xp_per_message"])), ("XP Per Reaction", str(config_data["xp_per_reaction"])), - ("Message Cooldown", f"{config_data['message_cooldown']} seconds"), - ("Reaction Cooldown", f"{config_data['reaction_cooldown']} seconds"), - ("Reaction XP Enabled", "Yes" if config_data["reaction_xp_enabled"] else "No") + ( + "Message Cooldown", + f"{config_data['message_cooldown']} seconds", + ), + ( + "Reaction Cooldown", + f"{config_data['reaction_cooldown']} seconds", + ), + ( + "Reaction XP Enabled", + "Yes" if config_data["reaction_xp_enabled"] else "No", + ), ] for name, value_str in settings_to_display: - setting_section = ui.Section(accessory=None) # Explicitly pass accessory=None + setting_section = ui.Section( + accessory=None + ) # Explicitly pass accessory=None if setting_section is None: - raise AssertionError("ui.Section returned None in XPConfigView") - + raise AssertionError( + "ui.Section returned None in XPConfigView" + ) + setting_display = ui.TextDisplay(f"**{name}:** {value_str}") if setting_display is None: - raise AssertionError("ui.TextDisplay returned None for setting in XPConfigView") + raise AssertionError( + "ui.TextDisplay returned None for setting in XPConfigView" + ) setting_section.add_item(setting_display) main_container.add_item(setting_section) - + separator = ui.Separator(spacing=discord.SeparatorSpacing.small) if separator is None: - raise AssertionError("ui.Separator returned None at bottom in XPConfigView") + raise AssertionError( + "ui.Separator returned None at bottom in XPConfigView" + ) main_container.add_item(separator) - - help_text = ui.TextDisplay(f"Use {prefix}level config to change a setting") + + help_text = ui.TextDisplay( + f"Use {prefix}level config to change a setting" + ) if help_text is None: - raise AssertionError("ui.TextDisplay returned None for help text in XPConfigView") + raise AssertionError( + "ui.TextDisplay returned None for help text in XPConfigView" + ) main_container.add_item(help_text) # Attempt to get the prefix @@ -873,9 +1074,9 @@ class LevelingCog(commands.Cog): # If not, we might need to ask the user or make an assumption. # For now, let's try with ctx.prefix. If it causes an error, we can adjust. # A safer bet might be to use the command's qualified name. - command_prefix = ctx.prefix if ctx.prefix else "!" # Fallback to "!" + command_prefix = ctx.prefix if ctx.prefix else "!" # Fallback to "!" except AttributeError: - command_prefix = "!" # Fallback if ctx.prefix doesn't exist + command_prefix = "!" # Fallback if ctx.prefix doesn't exist try: view = XPConfigView(self.config, command_prefix) @@ -883,7 +1084,9 @@ class LevelingCog(commands.Cog): except Exception as e: print(f"Error creating XP config view: {e}") traceback.print_exc() - await ctx.send("❌ An error occurred while creating the configuration display. Please check the console for details.") + await ctx.send( + "❌ An error occurred while creating the configuration display. Please check the console for details." + ) return if not value: @@ -918,7 +1121,9 @@ class LevelingCog(commands.Cog): try: cooldown = int(value) if cooldown < 0 or cooldown > 3600: - await ctx.send("Message cooldown must be between 0 and 3600 seconds.") + await ctx.send( + "Message cooldown must be between 0 and 3600 seconds." + ) return self.config["message_cooldown"] = cooldown await ctx.send(f"✅ Message cooldown set to {cooldown} seconds.") @@ -929,7 +1134,9 @@ class LevelingCog(commands.Cog): try: cooldown = int(value) if cooldown < 0 or cooldown > 3600: - await ctx.send("Reaction cooldown must be between 0 and 3600 seconds.") + await ctx.send( + "Reaction cooldown must be between 0 and 3600 seconds." + ) return self.config["reaction_cooldown"] = cooldown await ctx.send(f"✅ Reaction cooldown set to {cooldown} seconds.") @@ -948,25 +1155,35 @@ class LevelingCog(commands.Cog): await ctx.send("Value must be 'true' or 'false'.") else: - await ctx.send(f"Unknown setting: {setting}. Available settings: xp_per_message, xp_per_reaction, message_cooldown, reaction_cooldown, reaction_xp_enabled") + await ctx.send( + f"Unknown setting: {setting}. Available settings: xp_per_message, xp_per_reaction, message_cooldown, reaction_cooldown, reaction_xp_enabled" + ) return # Save the updated configuration self.save_config() - @level.command(name="toggle_notifs", description="Toggle level-up notifications for yourself") + @level.command( + name="toggle_notifs", description="Toggle level-up notifications for yourself" + ) async def toggle_level_notifs(self, ctx: commands.Context): """Toggle level-up notifications for yourself""" user_data = self.get_user_data(ctx.author.id) - current_status = user_data.get("level_notifs_enabled", self.config["default_level_notifs_enabled"]) + current_status = user_data.get( + "level_notifs_enabled", self.config["default_level_notifs_enabled"] + ) new_status = not current_status user_data["level_notifs_enabled"] = new_status self.save_user_data() status_text = "enabled" if new_status else "disabled" - await ctx.send(f"✅ Level-up notifications have been **{status_text}** for you.") + await ctx.send( + f"✅ Level-up notifications have been **{status_text}** for you." + ) - @level.command(name="setup_medieval_roles", description="Set up medieval-themed level roles") + @level.command( + name="setup_medieval_roles", description="Set up medieval-themed level roles" + ) @commands.has_permissions(manage_roles=True) async def setup_medieval_roles(self, ctx: commands.Context): """Automatically set up medieval-themed level roles with gender customization""" @@ -983,19 +1200,19 @@ class LevelingCog(commands.Cog): 30: {"default": "Count/Countess", "male": "Count", "female": "Countess"}, 50: {"default": "Duke/Duchess", "male": "Duke", "female": "Duchess"}, 75: {"default": "Prince/Princess", "male": "Prince", "female": "Princess"}, - 100: {"default": "King/Queen", "male": "King", "female": "Queen"} + 100: {"default": "King/Queen", "male": "King", "female": "Queen"}, } # Colors for the roles (gradient from gray to gold) colors = { 1: discord.Color.from_rgb(128, 128, 128), # Gray 5: discord.Color.from_rgb(153, 153, 153), # Light Gray - 10: discord.Color.from_rgb(170, 170, 170), # Silver + 10: discord.Color.from_rgb(170, 170, 170), # Silver 20: discord.Color.from_rgb(218, 165, 32), # Goldenrod - 30: discord.Color.from_rgb(255, 215, 0), # Gold - 50: discord.Color.from_rgb(255, 223, 0), # Bright Gold - 75: discord.Color.from_rgb(255, 235, 0), # Royal Gold - 100: discord.Color.from_rgb(255, 255, 0) # Yellow/Gold + 30: discord.Color.from_rgb(255, 215, 0), # Gold + 50: discord.Color.from_rgb(255, 223, 0), # Bright Gold + 75: discord.Color.from_rgb(255, 235, 0), # Royal Gold + 100: discord.Color.from_rgb(255, 255, 0), # Yellow/Gold } # Initialize guild in level_roles if not exists @@ -1029,7 +1246,9 @@ class LevelingCog(commands.Cog): if existing_role: # Update existing role try: - await existing_role.edit(color=colors[level], reason="Updating medieval level role") + await existing_role.edit( + color=colors[level], reason="Updating medieval level role" + ) updated_roles.append(role_name) except discord.Forbidden: await ctx.send(f"Missing permissions to edit role: {role_name}") @@ -1042,11 +1261,13 @@ class LevelingCog(commands.Cog): role = await ctx.guild.create_role( name=role_name, color=colors[level], - reason="Creating medieval level role" + reason="Creating medieval level role", ) created_roles.append(role_name) except discord.Forbidden: - await ctx.send(f"Missing permissions to create role: {role_name}") + await ctx.send( + f"Missing permissions to create role: {role_name}" + ) except Exception as e: await ctx.send(f"Error creating role {role_name}: {e}") traceback.print_exc() @@ -1064,10 +1285,14 @@ class LevelingCog(commands.Cog): if male_role: try: - await male_role.edit(color=colors[level], reason="Updating medieval level role") + await male_role.edit( + color=colors[level], reason="Updating medieval level role" + ) updated_roles.append(male_role_name) except discord.Forbidden: - await ctx.send(f"Missing permissions to edit role: {male_role_name}") + await ctx.send( + f"Missing permissions to edit role: {male_role_name}" + ) except Exception as e: await ctx.send(f"Error updating role {male_role_name}: {e}") traceback.print_exc() @@ -1076,11 +1301,13 @@ class LevelingCog(commands.Cog): male_role = await ctx.guild.create_role( name=male_role_name, color=colors[level], - reason="Creating medieval level role" + reason="Creating medieval level role", ) created_roles.append(male_role_name) except discord.Forbidden: - await ctx.send(f"Missing permissions to create role: {male_role_name}") + await ctx.send( + f"Missing permissions to create role: {male_role_name}" + ) except Exception as e: await ctx.send(f"Error creating role {male_role_name}: {e}") traceback.print_exc() @@ -1092,10 +1319,14 @@ class LevelingCog(commands.Cog): if female_role: try: - await female_role.edit(color=colors[level], reason="Updating medieval level role") + await female_role.edit( + color=colors[level], reason="Updating medieval level role" + ) updated_roles.append(female_role_name) except discord.Forbidden: - await ctx.send(f"Missing permissions to edit role: {female_role_name}") + await ctx.send( + f"Missing permissions to edit role: {female_role_name}" + ) except Exception as e: await ctx.send(f"Error updating role {female_role_name}: {e}") traceback.print_exc() @@ -1104,11 +1335,13 @@ class LevelingCog(commands.Cog): female_role = await ctx.guild.create_role( name=female_role_name, color=colors[level], - reason="Creating medieval level role" + reason="Creating medieval level role", ) created_roles.append(female_role_name) except discord.Forbidden: - await ctx.send(f"Missing permissions to create role: {female_role_name}") + await ctx.send( + f"Missing permissions to create role: {female_role_name}" + ) except Exception as e: await ctx.send(f"Error creating role {female_role_name}: {e}") traceback.print_exc() @@ -1128,97 +1361,157 @@ class LevelingCog(commands.Cog): self.save_level_roles() class MedievalRolesSetupView(ui.LayoutView): - def __init__(self, created_roles_list: list, updated_roles_list: list, has_pronoun_roles_flag: bool): + def __init__( + self, + created_roles_list: list, + updated_roles_list: list, + has_pronoun_roles_flag: bool, + ): super().__init__() main_container = ui.Container(accent_colour=discord.Colour.gold()) if main_container is None: - raise AssertionError("ui.Container returned None in MedievalRolesSetupView; ensure accent_colour is valid") + raise AssertionError( + "ui.Container returned None in MedievalRolesSetupView; ensure accent_colour is valid" + ) self.add_item(main_container) title_display = ui.TextDisplay("**Medieval Level Roles Setup**") if title_display is None: - raise AssertionError("ui.TextDisplay returned None for title in MedievalRolesSetupView") + raise AssertionError( + "ui.TextDisplay returned None for title in MedievalRolesSetupView" + ) main_container.add_item(title_display) - - desc_display = ui.TextDisplay("The following roles have been set up for the medieval leveling system:") + + desc_display = ui.TextDisplay( + "The following roles have been set up for the medieval leveling system:" + ) if desc_display is None: - raise AssertionError("ui.TextDisplay returned None for description in MedievalRolesSetupView") + raise AssertionError( + "ui.TextDisplay returned None for description in MedievalRolesSetupView" + ) main_container.add_item(desc_display) - + sep = ui.Separator(spacing=discord.SeparatorSpacing.small) if sep is None: - raise AssertionError("ui.Separator returned None in MedievalRolesSetupView") + raise AssertionError( + "ui.Separator returned None in MedievalRolesSetupView" + ) main_container.add_item(sep) if created_roles_list: - created_section = ui.Section(accessory=None) # Explicitly pass accessory=None + created_section = ui.Section( + accessory=None + ) # Explicitly pass accessory=None if created_section is None: - raise AssertionError("ui.Section returned None for created roles in MedievalRolesSetupView") - + raise AssertionError( + "ui.Section returned None for created roles in MedievalRolesSetupView" + ) + created_title = ui.TextDisplay("**Created Roles:**") if created_title is None: - raise AssertionError("ui.TextDisplay returned None for created roles title in MedievalRolesSetupView") + raise AssertionError( + "ui.TextDisplay returned None for created roles title in MedievalRolesSetupView" + ) created_section.add_item(created_title) - + # For potentially long lists, join with newline. TextDisplay handles multiline. - created_list = ui.TextDisplay("\n".join(created_roles_list) if created_roles_list else "None") + created_list = ui.TextDisplay( + "\n".join(created_roles_list) if created_roles_list else "None" + ) if created_list is None: - raise AssertionError("ui.TextDisplay returned None for created roles list in MedievalRolesSetupView") + raise AssertionError( + "ui.TextDisplay returned None for created roles list in MedievalRolesSetupView" + ) created_section.add_item(created_list) main_container.add_item(created_section) - - if created_roles_list: # Only add separator if there are created roles + + if ( + created_roles_list + ): # Only add separator if there are created roles separator = ui.Separator(spacing=discord.SeparatorSpacing.small) if separator is None: - raise AssertionError("ui.Separator returned None after created roles in MedievalRolesSetupView") + raise AssertionError( + "ui.Separator returned None after created roles in MedievalRolesSetupView" + ) main_container.add_item(separator) if updated_roles_list: - updated_section = ui.Section(accessory=None) # Explicitly pass accessory=None + updated_section = ui.Section( + accessory=None + ) # Explicitly pass accessory=None if updated_section is None: - raise AssertionError("ui.Section returned None for updated roles in MedievalRolesSetupView") - + raise AssertionError( + "ui.Section returned None for updated roles in MedievalRolesSetupView" + ) + updated_title = ui.TextDisplay("**Updated Roles:**") if updated_title is None: - raise AssertionError("ui.TextDisplay returned None for updated roles title in MedievalRolesSetupView") + raise AssertionError( + "ui.TextDisplay returned None for updated roles title in MedievalRolesSetupView" + ) updated_section.add_item(updated_title) - - updated_list = ui.TextDisplay("\n".join(updated_roles_list) if updated_roles_list else "None") + + updated_list = ui.TextDisplay( + "\n".join(updated_roles_list) if updated_roles_list else "None" + ) if updated_list is None: - raise AssertionError("ui.TextDisplay returned None for updated roles list in MedievalRolesSetupView") + raise AssertionError( + "ui.TextDisplay returned None for updated roles list in MedievalRolesSetupView" + ) updated_section.add_item(updated_list) main_container.add_item(updated_section) - - if updated_roles_list: # Only add separator if there are updated roles + + if ( + updated_roles_list + ): # Only add separator if there are updated roles separator = ui.Separator(spacing=discord.SeparatorSpacing.small) if separator is None: - raise AssertionError("ui.Separator returned None after updated roles in MedievalRolesSetupView") + raise AssertionError( + "ui.Separator returned None after updated roles in MedievalRolesSetupView" + ) main_container.add_item(separator) - gender_detection_section = ui.Section(accessory=None) # Explicitly pass accessory=None + gender_detection_section = ui.Section( + accessory=None + ) # Explicitly pass accessory=None if gender_detection_section is None: - raise AssertionError("ui.Section returned None for gender detection in MedievalRolesSetupView") - + raise AssertionError( + "ui.Section returned None for gender detection in MedievalRolesSetupView" + ) + gender_title = ui.TextDisplay("**Gender Detection:**") if gender_title is None: - raise AssertionError("ui.TextDisplay returned None for gender detection title in MedievalRolesSetupView") + raise AssertionError( + "ui.TextDisplay returned None for gender detection title in MedievalRolesSetupView" + ) gender_detection_section.add_item(gender_title) - - gender_text = "Gender-specific roles will be assigned based on pronoun roles." if has_pronoun_roles_flag else "No pronoun roles detected. Using default titles." + + gender_text = ( + "Gender-specific roles will be assigned based on pronoun roles." + if has_pronoun_roles_flag + else "No pronoun roles detected. Using default titles." + ) gender_desc = ui.TextDisplay(gender_text) if gender_desc is None: - raise AssertionError("ui.TextDisplay returned None for gender detection description in MedievalRolesSetupView") + raise AssertionError( + "ui.TextDisplay returned None for gender detection description in MedievalRolesSetupView" + ) gender_detection_section.add_item(gender_desc) main_container.add_item(gender_detection_section) try: - view = MedievalRolesSetupView(created_roles, updated_roles, has_pronoun_roles) + view = MedievalRolesSetupView( + created_roles, updated_roles, has_pronoun_roles + ) await status_message.edit(content=None, view=view) except Exception as e: print(f"Error creating medieval roles setup view: {e}") traceback.print_exc() - await status_message.edit(content="❌ An error occurred while creating the setup summary. Please check the console for details.") + await status_message.edit( + content="❌ An error occurred while creating the setup summary. Please check the console for details." + ) + async def setup(bot: commands.Bot): await bot.add_cog(LevelingCog(bot)) diff --git a/cogs/lockdown_cog.py b/cogs/lockdown_cog.py index 071569f..03815ef 100644 --- a/cogs/lockdown_cog.py +++ b/cogs/lockdown_cog.py @@ -3,6 +3,7 @@ from discord.ext import commands from discord import app_commands import asyncio + class LockdownCog(commands.Cog): def __init__(self, bot): self.bot = bot @@ -10,33 +11,50 @@ class LockdownCog(commands.Cog): lockdown = app_commands.Group(name="lockdown", description="Lockdown commands") @lockdown.command(name="channel") - @app_commands.describe(channel="The channel to lock down", time="Duration of lockdown in seconds") + @app_commands.describe( + channel="The channel to lock down", time="Duration of lockdown in seconds" + ) @app_commands.checks.has_permissions(manage_channels=True) - async def channel_lockdown(self, interaction: discord.Interaction, channel: discord.TextChannel = None, time: int = None): + async def channel_lockdown( + self, + interaction: discord.Interaction, + channel: discord.TextChannel = None, + time: int = None, + ): """Locks down a channel.""" channel = channel or interaction.channel overwrite = channel.overwrites_for(interaction.guild.default_role) if overwrite.send_messages is False: - await interaction.response.send_message("Channel is already locked down.", ephemeral=True) + await interaction.response.send_message( + "Channel is already locked down.", ephemeral=True + ) return overwrite.send_messages = False - await channel.set_permissions(interaction.guild.default_role, overwrite=overwrite) - await interaction.response.send_message(f"Channel {channel.mention} locked down.") + await channel.set_permissions( + interaction.guild.default_role, overwrite=overwrite + ) + await interaction.response.send_message( + f"Channel {channel.mention} locked down." + ) if time: await asyncio.sleep(time) overwrite.send_messages = None - await channel.set_permissions(interaction.guild.default_role, overwrite=overwrite) - await interaction.followup.send(f"Channel {channel.mention} lockdown lifted.") + await channel.set_permissions( + interaction.guild.default_role, overwrite=overwrite + ) + await interaction.followup.send( + f"Channel {channel.mention} lockdown lifted." + ) @lockdown.command(name="server") @app_commands.describe(time="Duration of server lockdown in seconds") @app_commands.checks.has_permissions(administrator=True) async def server_lockdown(self, interaction: discord.Interaction, time: int = None): """Locks down the entire server.""" - await interaction.response.defer() # Defer the response as this might take time + await interaction.response.defer() # Defer the response as this might take time for channel in interaction.guild.text_channels: overwrite = channel.overwrites_for(interaction.guild.default_role) @@ -44,7 +62,9 @@ class LockdownCog(commands.Cog): continue overwrite.send_messages = False - await channel.set_permissions(interaction.guild.default_role, overwrite=overwrite) + await channel.set_permissions( + interaction.guild.default_role, overwrite=overwrite + ) await interaction.followup.send("Server locked down.") @@ -54,23 +74,31 @@ class LockdownCog(commands.Cog): for channel in interaction.guild.text_channels: overwrite = channel.overwrites_for(interaction.guild.default_role) overwrite.send_messages = None - await channel.set_permissions(interaction.guild.default_role, overwrite=overwrite) + await channel.set_permissions( + interaction.guild.default_role, overwrite=overwrite + ) await interaction.followup.send("Server lockdown lifted.") @lockdown.command(name="remove_channel") @app_commands.describe(channel="The channel to unlock") @app_commands.checks.has_permissions(manage_channels=True) - async def channel_remove(self, interaction: discord.Interaction, channel: discord.TextChannel = None): + async def channel_remove( + self, interaction: discord.Interaction, channel: discord.TextChannel = None + ): """Removes lockdown from a channel.""" channel = channel or interaction.channel overwrite = channel.overwrites_for(interaction.guild.default_role) if overwrite.send_messages is None or overwrite.send_messages is True: - await interaction.response.send_message("Channel is not locked down.", ephemeral=True) + await interaction.response.send_message( + "Channel is not locked down.", ephemeral=True + ) return overwrite.send_messages = None - await channel.set_permissions(interaction.guild.default_role, overwrite=overwrite) + await channel.set_permissions( + interaction.guild.default_role, overwrite=overwrite + ) await interaction.response.send_message(f"Channel {channel.mention} unlocked.") @lockdown.command(name="remove_server") @@ -82,7 +110,9 @@ class LockdownCog(commands.Cog): for channel in interaction.guild.text_channels: overwrite = channel.overwrites_for(interaction.guild.default_role) overwrite.send_messages = None - await channel.set_permissions(interaction.guild.default_role, overwrite=overwrite) + await channel.set_permissions( + interaction.guild.default_role, overwrite=overwrite + ) await interaction.followup.send("Server unlocked.") diff --git a/cogs/logging_cog.py b/cogs/logging_cog.py index 9709160..88de9e5 100644 --- a/cogs/logging_cog.py +++ b/cogs/logging_cog.py @@ -3,57 +3,95 @@ from discord.ext import commands, tasks from discord import ui, AllowedMentions import datetime import asyncio -import aiohttp # Added for webhook sending -import logging # Use logging instead of print +import aiohttp # Added for webhook sending +import logging # Use logging instead of print from typing import Optional, Union # Import settings manager try: - from .. import settings_manager # Relative import if cogs are in a subfolder + from .. import settings_manager # Relative import if cogs are in a subfolder except ImportError: - import settings_manager # Fallback for direct execution? Adjust as needed. + import settings_manager # Fallback for direct execution? Adjust as needed. -log = logging.getLogger(__name__) # Setup logger for this cog +log = logging.getLogger(__name__) # Setup logger for this cog # Define all possible event keys for toggling # Keep this list updated if new loggable events are added -ALL_EVENT_KEYS = sorted([ - # Direct Events - "member_join", "member_remove", "member_ban_event", "member_unban", "member_update", - "role_create_event", "role_delete_event", "role_update_event", - "channel_create_event", "channel_delete_event", "channel_update_event", - "message_edit", "message_delete", - "reaction_add", "reaction_remove", "reaction_clear", "reaction_clear_emoji", - "voice_state_update", - "guild_update_event", "emoji_update_event", - "invite_create_event", "invite_delete_event", - "command_error", # Potentially noisy - "thread_create", "thread_delete", "thread_update", "thread_member_join", "thread_member_remove", - "webhook_update", - # Audit Log Actions (prefixed with 'audit_') - "audit_kick", "audit_prune", "audit_ban", "audit_unban", - "audit_member_role_update", "audit_member_update_timeout", # Specific member_update cases - "audit_message_delete", "audit_message_bulk_delete", - "audit_role_create", "audit_role_delete", "audit_role_update", - "audit_channel_create", "audit_channel_delete", "audit_channel_update", - "audit_emoji_create", "audit_emoji_delete", "audit_emoji_update", - "audit_invite_create", "audit_invite_delete", - "audit_guild_update", - # Add more audit keys if needed, e.g., "audit_stage_instance_create" -]) +ALL_EVENT_KEYS = sorted( + [ + # Direct Events + "member_join", + "member_remove", + "member_ban_event", + "member_unban", + "member_update", + "role_create_event", + "role_delete_event", + "role_update_event", + "channel_create_event", + "channel_delete_event", + "channel_update_event", + "message_edit", + "message_delete", + "reaction_add", + "reaction_remove", + "reaction_clear", + "reaction_clear_emoji", + "voice_state_update", + "guild_update_event", + "emoji_update_event", + "invite_create_event", + "invite_delete_event", + "command_error", # Potentially noisy + "thread_create", + "thread_delete", + "thread_update", + "thread_member_join", + "thread_member_remove", + "webhook_update", + # Audit Log Actions (prefixed with 'audit_') + "audit_kick", + "audit_prune", + "audit_ban", + "audit_unban", + "audit_member_role_update", + "audit_member_update_timeout", # Specific member_update cases + "audit_message_delete", + "audit_message_bulk_delete", + "audit_role_create", + "audit_role_delete", + "audit_role_update", + "audit_channel_create", + "audit_channel_delete", + "audit_channel_update", + "audit_emoji_create", + "audit_emoji_delete", + "audit_emoji_update", + "audit_invite_create", + "audit_invite_delete", + "audit_guild_update", + # Add more audit keys if needed, e.g., "audit_stage_instance_create" + ] +) + class LoggingCog(commands.Cog): """Handles comprehensive server event logging via webhooks with granular toggling.""" + def __init__(self, bot: commands.Bot): self.bot = bot - self.session: Optional[aiohttp.ClientSession] = None # Session for webhooks - self.last_audit_log_ids: dict[int, Optional[int]] = {} # Store last ID per guild + self.session: Optional[aiohttp.ClientSession] = None # Session for webhooks + self.last_audit_log_ids: dict[int, Optional[int]] = ( + {} + ) # Store last ID per guild # Start the audit log poller task if the bot is ready, otherwise wait if bot.is_ready(): - asyncio.create_task(self.initialize_cog()) # Use async init helper + asyncio.create_task(self.initialize_cog()) # Use async init helper else: - asyncio.create_task(self.start_audit_log_poller_when_ready()) # Keep this for initial start + asyncio.create_task( + self.start_audit_log_poller_when_ready() + ) # Keep this for initial start class LogView(ui.LayoutView): """Simple view for log messages with helper methods.""" @@ -73,9 +111,7 @@ class LoggingCog(commands.Cog): self.header = ui.Section( accessory=( - ui.Thumbnail(media=author.display_avatar.url) - if author - else None + ui.Thumbnail(media=author.display_avatar.url) if author else None ) ) self.header.add_item(ui.TextDisplay(f"**{title}**")) @@ -126,6 +162,7 @@ class LoggingCog(commands.Cog): else: self.header.accessory = None self.header.add_item(ui.TextDisplay(name)) + def _user_display(self, user: Union[discord.Member, discord.User]) -> str: """Return display name, username and ID string for a user.""" display = user.display_name if isinstance(user, discord.Member) else user.name @@ -146,32 +183,47 @@ class LoggingCog(commands.Cog): """Fetch the latest audit log ID for each guild the bot is in.""" log.info("Initializing last audit log IDs for guilds...") for guild in self.bot.guilds: - if guild.id not in self.last_audit_log_ids: # Only initialize if not already set + if ( + guild.id not in self.last_audit_log_ids + ): # Only initialize if not already set try: if guild.me.guild_permissions.view_audit_log: async for entry in guild.audit_logs(limit=1): self.last_audit_log_ids[guild.id] = entry.id - log.debug(f"Initialized last_audit_log_id for guild {guild.id} to {entry.id}") - break # Only need the latest one + log.debug( + f"Initialized last_audit_log_id for guild {guild.id} to {entry.id}" + ) + break # Only need the latest one else: - log.warning(f"Missing 'View Audit Log' permission in guild {guild.id}. Cannot initialize audit log ID.") - self.last_audit_log_ids[guild.id] = None # Mark as unable to fetch + log.warning( + f"Missing 'View Audit Log' permission in guild {guild.id}. Cannot initialize audit log ID." + ) + self.last_audit_log_ids[guild.id] = ( + None # Mark as unable to fetch + ) except discord.Forbidden: - log.warning(f"Forbidden error fetching initial audit log ID for guild {guild.id}.") - self.last_audit_log_ids[guild.id] = None + log.warning( + f"Forbidden error fetching initial audit log ID for guild {guild.id}." + ) + self.last_audit_log_ids[guild.id] = None except discord.HTTPException as e: - log.error(f"HTTP error fetching initial audit log ID for guild {guild.id}: {e}") - self.last_audit_log_ids[guild.id] = None + log.error( + f"HTTP error fetching initial audit log ID for guild {guild.id}: {e}" + ) + self.last_audit_log_ids[guild.id] = None except Exception as e: - log.exception(f"Unexpected error fetching initial audit log ID for guild {guild.id}: {e}") - self.last_audit_log_ids[guild.id] = None # Mark as unable on other errors + log.exception( + f"Unexpected error fetching initial audit log ID for guild {guild.id}: {e}" + ) + self.last_audit_log_ids[guild.id] = ( + None # Mark as unable on other errors + ) log.info("Finished initializing audit log IDs.") - async def start_audit_log_poller_when_ready(self): """Waits until bot is ready, then initializes and starts the poller.""" await self.bot.wait_until_ready() - await self.initialize_cog() # Call the main init helper + await self.initialize_cog() # Call the main init helper async def cog_unload(self): """Clean up resources when the cog is unloaded.""" @@ -184,7 +236,9 @@ class LoggingCog(commands.Cog): async def _send_log_embed(self, guild: discord.Guild, embed: ui.LayoutView) -> None: """Sends the log view via the configured webhook for the guild.""" if not self.session or self.session.closed: - log.error(f"aiohttp session not available or closed in LoggingCog for guild {guild.id}. Cannot send log.") + log.error( + f"aiohttp session not available or closed in LoggingCog for guild {guild.id}. Cannot send log." + ) return webhook_url = await settings_manager.get_logging_webhook(guild.id) @@ -207,20 +261,27 @@ class LoggingCog(commands.Cog): ) # log.debug(f"Sent log embed via webhook for guild {guild.id}") # Can be noisy except ValueError as e: - log.exception(f"ValueError sending log via webhook for guild {guild.id}. Error: {e}") + log.exception( + f"ValueError sending log via webhook for guild {guild.id}. Error: {e}" + ) # Consider notifying an admin or disabling logging for this guild temporarily # await settings_manager.set_logging_webhook(guild.id, None) # Example: Auto-disable on invalid URL except (discord.Forbidden, discord.NotFound): - log.error(f"Webhook permissions error or webhook not found for guild {guild.id}. URL: {webhook_url}") + log.error( + f"Webhook permissions error or webhook not found for guild {guild.id}. URL: {webhook_url}" + ) # Consider notifying an admin or disabling logging for this guild temporarily # await settings_manager.set_logging_webhook(guild.id, None) # Example: Auto-disable on error except discord.HTTPException as e: log.error(f"HTTP error sending log via webhook for guild {guild.id}: {e}") except aiohttp.ClientError as e: - log.error(f"aiohttp client error sending log via webhook for guild {guild.id}: {e}") + log.error( + f"aiohttp client error sending log via webhook for guild {guild.id}: {e}" + ) except Exception as e: - log.exception(f"Unexpected error sending log via webhook for guild {guild.id}: {e}") - + log.exception( + f"Unexpected error sending log via webhook for guild {guild.id}: {e}" + ) def _create_log_embed( self, @@ -253,7 +314,9 @@ class LoggingCog(commands.Cog): if target_id and hasattr(embed, "footer_display"): existing_footer = embed.footer_display.content or "" separator = " | " if existing_footer else "" - embed.footer_display.content = f"{existing_footer}{separator}{id_name}: {target_id}" + embed.footer_display.content = ( + f"{existing_footer}{separator}{id_name}: {target_id}" + ) async def _check_log_enabled(self, guild_id: int, event_key: str) -> bool: """Checks if logging is enabled for a specific event key in a guild.""" @@ -262,12 +325,13 @@ class LoggingCog(commands.Cog): if not webhook_url: return False # Then, check if the specific event is enabled (defaults to True if not set) - enabled = await settings_manager.is_log_event_enabled(guild_id, event_key, default_enabled=True) + enabled = await settings_manager.is_log_event_enabled( + guild_id, event_key, default_enabled=True + ) # if not enabled: # log.debug(f"Logging disabled for event '{event_key}' in guild {guild_id}") return enabled - # --- Log Command Group --- @commands.group(name="log", invoke_without_command=True) @@ -292,42 +356,51 @@ class LoggingCog(commands.Cog): ) return if not channel.permissions_for(me).send_messages: - await ctx.send( - f"❌ I don't have the 'Send Messages' permission in {channel.mention}. Please grant it and try again (needed for webhook creation confirmation).", - allowed_mentions=AllowedMentions.none(), - ) - return + await ctx.send( + f"❌ I don't have the 'Send Messages' permission in {channel.mention}. Please grant it and try again (needed for webhook creation confirmation).", + allowed_mentions=AllowedMentions.none(), + ) + return # 2. Check existing webhook setting existing_url = await settings_manager.get_logging_webhook(guild.id) if existing_url: - # Try to fetch the existing webhook to see if it's still valid and in the right channel - try: - if not self.session or self.session.closed: self.session = aiohttp.ClientSession() # Ensure session exists - existing_webhook = await discord.Webhook.from_url(existing_url, session=self.session).fetch() - if existing_webhook.channel_id == channel.id: - await ctx.send( - f"✅ Logging is already configured for {channel.mention} using webhook `{existing_webhook.name}`.", - allowed_mentions=AllowedMentions.none(), - ) - return - else: - await ctx.send( - f"⚠️ Logging webhook is currently set for a different channel (<#{existing_webhook.channel_id}>). I will create a new one for {channel.mention}.", - allowed_mentions=AllowedMentions.none(), - ) - except (discord.NotFound, discord.Forbidden, ValueError, aiohttp.ClientError): - await ctx.send( - f"⚠️ Could not verify the existing webhook URL. It might be invalid or deleted. I will create a new one for {channel.mention}.", - allowed_mentions=AllowedMentions.none(), - ) - except Exception as e: - log.exception(f"Error fetching existing webhook during setup for guild {guild.id}") - await ctx.send( - f"⚠️ An error occurred while checking the existing webhook. Proceeding to create a new one for {channel.mention}.", - allowed_mentions=AllowedMentions.none(), - ) - + # Try to fetch the existing webhook to see if it's still valid and in the right channel + try: + if not self.session or self.session.closed: + self.session = aiohttp.ClientSession() # Ensure session exists + existing_webhook = await discord.Webhook.from_url( + existing_url, session=self.session + ).fetch() + if existing_webhook.channel_id == channel.id: + await ctx.send( + f"✅ Logging is already configured for {channel.mention} using webhook `{existing_webhook.name}`.", + allowed_mentions=AllowedMentions.none(), + ) + return + else: + await ctx.send( + f"⚠️ Logging webhook is currently set for a different channel (<#{existing_webhook.channel_id}>). I will create a new one for {channel.mention}.", + allowed_mentions=AllowedMentions.none(), + ) + except ( + discord.NotFound, + discord.Forbidden, + ValueError, + aiohttp.ClientError, + ): + await ctx.send( + f"⚠️ Could not verify the existing webhook URL. It might be invalid or deleted. I will create a new one for {channel.mention}.", + allowed_mentions=AllowedMentions.none(), + ) + except Exception as e: + log.exception( + f"Error fetching existing webhook during setup for guild {guild.id}" + ) + await ctx.send( + f"⚠️ An error occurred while checking the existing webhook. Proceeding to create a new one for {channel.mention}.", + allowed_mentions=AllowedMentions.none(), + ) # 3. Create new webhook try: @@ -337,19 +410,31 @@ class LoggingCog(commands.Cog): try: avatar_bytes = await self.bot.user.display_avatar.read() except Exception: - log.warning(f"Could not read bot avatar for webhook creation in guild {guild.id}.") + log.warning( + f"Could not read bot avatar for webhook creation in guild {guild.id}." + ) - new_webhook = await channel.create_webhook(name=webhook_name, avatar=avatar_bytes, reason=f"Logging setup by {ctx.author} ({ctx.author.id})") - log.info(f"Created logging webhook '{webhook_name}' in channel {channel.id} for guild {guild.id}") + new_webhook = await channel.create_webhook( + name=webhook_name, + avatar=avatar_bytes, + reason=f"Logging setup by {ctx.author} ({ctx.author.id})", + ) + log.info( + f"Created logging webhook '{webhook_name}' in channel {channel.id} for guild {guild.id}" + ) except discord.HTTPException as e: - log.error(f"Failed to create webhook in {channel.mention} for guild {guild.id}: {e}") + log.error( + f"Failed to create webhook in {channel.mention} for guild {guild.id}: {e}" + ) await ctx.send( f"❌ Failed to create webhook. Error: {e}. This could be due to hitting the channel webhook limit (15).", allowed_mentions=AllowedMentions.none(), ) return except Exception as e: - log.exception(f"Unexpected error creating webhook in {channel.mention} for guild {guild.id}") + log.exception( + f"Unexpected error creating webhook in {channel.mention} for guild {guild.id}" + ) await ctx.send( "❌ An unexpected error occurred while creating the webhook.", allowed_mentions=AllowedMentions.none(), @@ -377,13 +462,17 @@ class LoggingCog(commands.Cog): allowed_mentions=AllowedMentions.none(), ) except Exception as e: - log.error(f"Failed to send test message via new webhook for guild {guild.id}: {e}") - await ctx.send( - "⚠️ Could not send a test message via the new webhook, but the URL has been saved.", - allowed_mentions=AllowedMentions.none(), - ) + log.error( + f"Failed to send test message via new webhook for guild {guild.id}: {e}" + ) + await ctx.send( + "⚠️ Could not send a test message via the new webhook, but the URL has been saved.", + allowed_mentions=AllowedMentions.none(), + ) else: - log.error(f"Failed to save webhook URL {new_webhook.url} to database for guild {guild.id}") + log.error( + f"Failed to save webhook URL {new_webhook.url} to database for guild {guild.id}" + ) await ctx.send( "❌ Successfully created the webhook, but failed to save its URL to my settings. Please try again or contact support.", allowed_mentions=AllowedMentions.none(), @@ -391,13 +480,22 @@ class LoggingCog(commands.Cog): # Attempt to delete the created webhook to avoid orphans try: await new_webhook.delete(reason="Failed to save URL to settings") - log.info(f"Deleted orphaned webhook '{new_webhook.name}' for guild {guild.id}") + log.info( + f"Deleted orphaned webhook '{new_webhook.name}' for guild {guild.id}" + ) except Exception as del_e: - log.error(f"Failed to delete orphaned webhook '{new_webhook.name}' for guild {guild.id}: {del_e}") + log.error( + f"Failed to delete orphaned webhook '{new_webhook.name}' for guild {guild.id}: {del_e}" + ) @log_group.command(name="toggle") @commands.has_permissions(administrator=True) - async def log_toggle(self, ctx: commands.Context, event_key: str, enabled_status: Optional[bool] = None): + async def log_toggle( + self, + ctx: commands.Context, + event_key: str, + enabled_status: Optional[bool] = None, + ): """Toggles logging for a specific event type (on/off). Use 'log list_keys' to see available event keys. @@ -406,7 +504,7 @@ class LoggingCog(commands.Cog): Example: !log toggle audit_kick """ guild_id = ctx.guild.id - event_key = event_key.lower() # Ensure case-insensitivity + event_key = event_key.lower() # Ensure case-insensitivity if event_key not in ALL_EVENT_KEYS: await ctx.send( @@ -418,13 +516,17 @@ class LoggingCog(commands.Cog): # Determine the new status if enabled_status is None: # Fetch current status (defaults to True if not explicitly set) - current_status = await settings_manager.is_log_event_enabled(guild_id, event_key, default_enabled=True) + current_status = await settings_manager.is_log_event_enabled( + guild_id, event_key, default_enabled=True + ) new_status = not current_status else: new_status = enabled_status # Save the new status - success = await settings_manager.set_log_event_enabled(guild_id, event_key, new_status) + success = await settings_manager.set_log_event_enabled( + guild_id, event_key, new_status + ) if success: status_str = "ENABLED" if new_status else "DISABLED" @@ -445,7 +547,9 @@ class LoggingCog(commands.Cog): guild_id = ctx.guild.id toggles = await settings_manager.get_all_log_event_toggles(guild_id) - embed = discord.Embed(title=f"Logging Status for {ctx.guild.name}", color=discord.Color.blue()) + embed = discord.Embed( + title=f"Logging Status for {ctx.guild.name}", color=discord.Color.blue() + ) lines = [] for key in ALL_EVENT_KEYS: # Get status, defaulting to True if not explicitly in the DB/cache map @@ -456,80 +560,88 @@ class LoggingCog(commands.Cog): # Paginate if too long for one embed description description = "" for line in lines: - if len(description) + len(line) + 1 > 4000: # Embed description limit (approx) + if ( + len(description) + len(line) + 1 > 4000 + ): # Embed description limit (approx) embed.description = description await ctx.send(embed=embed, allowed_mentions=AllowedMentions.none()) - description = line + "\n" # Start new description - embed = discord.Embed(color=discord.Color.blue()) # New embed for continuation + description = line + "\n" # Start new description + embed = discord.Embed( + color=discord.Color.blue() + ) # New embed for continuation else: description += line + "\n" - if description: # Send the last embed page - embed.description = description.strip() - await ctx.send(embed=embed, allowed_mentions=AllowedMentions.none()) - + if description: # Send the last embed page + embed.description = description.strip() + await ctx.send(embed=embed, allowed_mentions=AllowedMentions.none()) @log_group.command(name="list_keys") async def log_list_keys(self, ctx: commands.Context): """Lists all valid event keys for use with the 'log toggle' command.""" - embed = discord.Embed(title="Available Logging Event Keys", color=discord.Color.purple()) + embed = discord.Embed( + title="Available Logging Event Keys", color=discord.Color.purple() + ) keys_text = "\n".join(f"`{key}`" for key in ALL_EVENT_KEYS) # Paginate if needed if len(keys_text) > 4000: - parts = [] - current_part = "" - for key in ALL_EVENT_KEYS: - line = f"`{key}`\n" - if len(current_part) + len(line) > 4000: - parts.append(current_part) - current_part = line - else: - current_part += line - if current_part: - parts.append(current_part) + parts = [] + current_part = "" + for key in ALL_EVENT_KEYS: + line = f"`{key}`\n" + if len(current_part) + len(line) > 4000: + parts.append(current_part) + current_part = line + else: + current_part += line + if current_part: + parts.append(current_part) - embed.description = parts[0] - await ctx.send(embed=embed, allowed_mentions=AllowedMentions.none()) - for part in parts[1:]: - await ctx.send( - embed=discord.Embed(description=part, color=discord.Color.purple()), - allowed_mentions=AllowedMentions.none(), - ) + embed.description = parts[0] + await ctx.send(embed=embed, allowed_mentions=AllowedMentions.none()) + for part in parts[1:]: + await ctx.send( + embed=discord.Embed(description=part, color=discord.Color.purple()), + allowed_mentions=AllowedMentions.none(), + ) else: - embed.description = keys_text - await ctx.send(embed=embed, allowed_mentions=AllowedMentions.none()) - + embed.description = keys_text + await ctx.send(embed=embed, allowed_mentions=AllowedMentions.none()) # --- Thread Events --- @commands.Cog.listener() async def on_thread_create(self, thread: discord.Thread): guild = thread.guild event_key = "thread_create" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return embed = self._create_log_embed( title="🧵 Thread Created", description=f"Thread {thread.mention} (`{thread.name}`) created in {thread.parent.mention}.", color=discord.Color.dark_blue(), # Creator might be available via thread.owner_id or audit log - footer=f"Thread ID: {thread.id} | Parent ID: {thread.parent_id}" + footer=f"Thread ID: {thread.id} | Parent ID: {thread.parent_id}", ) - if thread.owner: # Sometimes owner isn't cached immediately - embed.set_author(name=str(thread.owner), icon_url=thread.owner.display_avatar.url) + if thread.owner: # Sometimes owner isn't cached immediately + embed.set_author( + name=str(thread.owner), icon_url=thread.owner.display_avatar.url + ) await self._send_log_embed(guild, embed) @commands.Cog.listener() async def on_thread_delete(self, thread: discord.Thread): guild = thread.guild event_key = "thread_delete" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return embed = self._create_log_embed( title="🗑️ Thread Deleted", description=f"Thread `{thread.name}` deleted from {thread.parent.mention}.", color=discord.Color.dark_grey(), - footer=f"Thread ID: {thread.id} | Parent ID: {thread.parent_id}" + footer=f"Thread ID: {thread.id} | Parent ID: {thread.parent_id}", ) # Audit log needed for deleter await self._send_log_embed(guild, embed) @@ -538,21 +650,32 @@ class LoggingCog(commands.Cog): async def on_thread_update(self, before: discord.Thread, after: discord.Thread): guild = after.guild event_key = "thread_update" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return changes = [] - if before.name != after.name: changes.append(f"**Name:** `{before.name}` → `{after.name}`") - if before.archived != after.archived: changes.append(f"**Archived:** `{before.archived}` → `{after.archived}`") - if before.locked != after.locked: changes.append(f"**Locked:** `{before.locked}` → `{after.locked}`") - if before.slowmode_delay != after.slowmode_delay: changes.append(f"**Slowmode:** `{before.slowmode_delay}s` → `{after.slowmode_delay}s`") - if before.auto_archive_duration != after.auto_archive_duration: changes.append(f"**Auto-Archive:** `{before.auto_archive_duration} mins` → `{after.auto_archive_duration} mins`") + if before.name != after.name: + changes.append(f"**Name:** `{before.name}` → `{after.name}`") + if before.archived != after.archived: + changes.append(f"**Archived:** `{before.archived}` → `{after.archived}`") + if before.locked != after.locked: + changes.append(f"**Locked:** `{before.locked}` → `{after.locked}`") + if before.slowmode_delay != after.slowmode_delay: + changes.append( + f"**Slowmode:** `{before.slowmode_delay}s` → `{after.slowmode_delay}s`" + ) + if before.auto_archive_duration != after.auto_archive_duration: + changes.append( + f"**Auto-Archive:** `{before.auto_archive_duration} mins` → `{after.auto_archive_duration} mins`" + ) if changes: embed = self._create_log_embed( title="📝 Thread Updated", - description=f"Thread {after.mention} in {after.parent.mention} updated:\n" + "\n".join(changes), + description=f"Thread {after.mention} in {after.parent.mention} updated:\n" + + "\n".join(changes), color=discord.Color.blue(), - footer=f"Thread ID: {after.id}" + footer=f"Thread ID: {after.id}", ) # Audit log needed for updater await self._send_log_embed(guild, embed) @@ -562,15 +685,16 @@ class LoggingCog(commands.Cog): thread = member.thread guild = thread.guild event_key = "thread_member_join" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return - user = await self.bot.fetch_user(member.id) # Get user object + user = await self.bot.fetch_user(member.id) # Get user object embed = self._create_log_embed( title="➕ Member Joined Thread", description=f"{self._user_display(user)} joined thread {thread.mention}.", color=discord.Color.dark_green(), author=user, - footer=f"Thread ID: {thread.id} | User ID: {user.id}" + footer=f"Thread ID: {thread.id} | User ID: {user.id}", ) await self._send_log_embed(guild, embed) @@ -579,47 +703,49 @@ class LoggingCog(commands.Cog): thread = member.thread guild = thread.guild event_key = "thread_member_remove" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return - user = await self.bot.fetch_user(member.id) # Get user object + user = await self.bot.fetch_user(member.id) # Get user object embed = self._create_log_embed( title="➖ Member Left Thread", description=f"{self._user_display(user)} left thread {thread.mention}.", color=discord.Color.dark_orange(), author=user, - footer=f"Thread ID: {thread.id} | User ID: {user.id}" + footer=f"Thread ID: {thread.id} | User ID: {user.id}", ) await self._send_log_embed(guild, embed) - # --- Webhook Events --- @commands.Cog.listener() async def on_webhooks_update(self, channel: discord.abc.GuildChannel): """Logs when webhooks are updated in a channel.""" guild = channel.guild event_key = "webhook_update" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return embed = self._create_log_embed( title="🎣 Webhooks Updated", description=f"Webhooks were updated in channel {channel.mention}.\n*Audit log may contain specific details and updater.*", color=discord.Color.greyple(), - footer=f"Channel ID: {channel.id}" + footer=f"Channel ID: {channel.id}", ) await self._send_log_embed(guild, embed) - # --- Event Listeners --- @commands.Cog.listener() async def on_ready(self): """Initialize when the cog is ready (called after bot on_ready).""" - log.info(f'{self.__class__.__name__} cog is ready.') + log.info(f"{self.__class__.__name__} cog is ready.") # Initialization is now handled by initialize_cog called from __init__ or start_audit_log_poller_when_ready # Ensure the poller is running if it wasn't started earlier if self.bot.is_ready() and not self.poll_audit_log.is_running(): - log.warning("Poll audit log task was not running after on_ready, attempting to start.") - await self.initialize_cog() # Re-initialize just in case + log.warning( + "Poll audit log task was not running after on_ready, attempting to start." + ) + await self.initialize_cog() # Re-initialize just in case @commands.Cog.listener() async def on_guild_join(self, guild: discord.Guild): @@ -630,13 +756,19 @@ class LoggingCog(commands.Cog): if guild.me.guild_permissions.view_audit_log: async for entry in guild.audit_logs(limit=1): self.last_audit_log_ids[guild.id] = entry.id - log.debug(f"Initialized last_audit_log_id for new guild {guild.id} to {entry.id}") + log.debug( + f"Initialized last_audit_log_id for new guild {guild.id} to {entry.id}" + ) break else: - log.warning(f"Missing 'View Audit Log' permission in new guild {guild.id}.") - self.last_audit_log_ids[guild.id] = None + log.warning( + f"Missing 'View Audit Log' permission in new guild {guild.id}." + ) + self.last_audit_log_ids[guild.id] = None except Exception as e: - log.exception(f"Error fetching initial audit log ID for new guild {guild.id}: {e}") + log.exception( + f"Error fetching initial audit log ID for new guild {guild.id}: {e}" + ) self.last_audit_log_ids[guild.id] = None @commands.Cog.listener() @@ -648,29 +780,34 @@ class LoggingCog(commands.Cog): # but the guild_settings table uses ON DELETE CASCADE, so it *should* be handled automatically # when the guild is removed from the guilds table in main.py's on_guild_remove. - # --- Member Events --- (Keep existing event handlers, they now use _send_log_embed) @commands.Cog.listener() async def on_member_join(self, member: discord.Member): guild = member.guild event_key = "member_join" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return embed = self._create_log_embed( title="📥 Member Joined", description=f"{self._user_display(member)} joined the server.", color=discord.Color.green(), - author=member + author=member, # Footer already includes User ID via _create_log_embed ) - embed.add_field(name="Account Created", value=discord.utils.format_dt(member.created_at, style='F'), inline=False) + embed.add_field( + name="Account Created", + value=discord.utils.format_dt(member.created_at, style="F"), + inline=False, + ) await self._send_log_embed(member.guild, embed) @commands.Cog.listener() async def on_member_remove(self, member: discord.Member): guild = member.guild event_key = "member_remove" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return # This event doesn't tell us if it was a kick or leave. Audit log polling will handle kicks. # We log it as a generic "left" event here. @@ -678,22 +815,25 @@ class LoggingCog(commands.Cog): title="📤 Member Left", description=f"{self._user_display(member)} left the server.", color=discord.Color.orange(), - author=member + author=member, ) self._add_id_footer(embed, member, id_name="User ID") await self._send_log_embed(member.guild, embed) @commands.Cog.listener() - async def on_member_ban(self, guild: discord.Guild, user: Union[discord.User, discord.Member]): + async def on_member_ban( + self, guild: discord.Guild, user: Union[discord.User, discord.Member] + ): event_key = "member_ban_event" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return # Note: Ban reason isn't available directly in this event. Audit log might have it. embed = self._create_log_embed( - title="🔨 Member Banned (Event)", # Clarify this is the event, audit log has more details + title="🔨 Member Banned (Event)", # Clarify this is the event, audit log has more details description=f"{self._user_display(user)} was banned.\n*Audit log may contain moderator and reason.*", color=discord.Color.red(), - author=user # User who was banned + author=user, # User who was banned ) self._add_id_footer(embed, user, id_name="User ID") await self._send_log_embed(guild, embed) @@ -701,13 +841,14 @@ class LoggingCog(commands.Cog): @commands.Cog.listener() async def on_member_unban(self, guild: discord.Guild, user: discord.User): event_key = "member_unban" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return embed = self._create_log_embed( title="🔓 Member Unbanned", description=f"{self._user_display(user)} was unbanned.", color=discord.Color.blurple(), - author=user # User who was unbanned + author=user, # User who was unbanned ) self._add_id_footer(embed, user, id_name="User ID") await self._send_log_embed(guild, embed) @@ -716,12 +857,15 @@ class LoggingCog(commands.Cog): async def on_member_update(self, before: discord.Member, after: discord.Member): guild = after.guild event_key = "member_update" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return changes = [] # Nickname change if before.nick != after.nick: - changes.append(f"**Nickname:** `{before.nick or 'None'}` → `{after.nick or 'None'}`") + changes.append( + f"**Nickname:** `{before.nick or 'None'}` → `{after.nick or 'None'}`" + ) # Role changes (handled more reliably by audit log for who did it) if before.roles != after.roles: added_roles = [r.mention for r in after.roles if r not in before.roles] @@ -732,40 +876,44 @@ class LoggingCog(commands.Cog): changes.append(f"**Roles Removed:** {', '.join(removed_roles)}") # Timeout change if before.timed_out_until != after.timed_out_until: - if after.timed_out_until: - timeout_duration = discord.utils.format_dt(after.timed_out_until, style='R') - changes.append(f"**Timed Out Until:** {timeout_duration}") - else: - changes.append("**Timeout Removed**") + if after.timed_out_until: + timeout_duration = discord.utils.format_dt( + after.timed_out_until, style="R" + ) + changes.append(f"**Timed Out Until:** {timeout_duration}") + else: + changes.append("**Timeout Removed**") # TODO: Add other trackable changes like status if needed # Add avatar change detection if before.display_avatar != after.display_avatar: - changes.append(f"**Avatar Changed**") # URL is enough, no need to show old/new + changes.append( + f"**Avatar Changed**" + ) # URL is enough, no need to show old/new if changes: embed = self._create_log_embed( title="👤 Member Updated", description=f"{after.mention}\n" + "\n".join(changes), color=discord.Color.yellow(), - author=after + author=after, ) self._add_id_footer(embed, after, id_name="User ID") await self._send_log_embed(guild, embed) - # --- Role Events --- @commands.Cog.listener() async def on_guild_role_create(self, role: discord.Role): guild = role.guild event_key = "role_create_event" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return embed = self._create_log_embed( title="✨ Role Created (Event)", description=f"Role {role.mention} (`{role.name}`) was created.\n*Audit log may contain creator.*", - color=discord.Color.teal() + color=discord.Color.teal(), ) self._add_id_footer(embed, role, id_name="Role ID") await self._send_log_embed(role.guild, embed) @@ -774,12 +922,13 @@ class LoggingCog(commands.Cog): async def on_guild_role_delete(self, role: discord.Role): guild = role.guild event_key = "role_delete_event" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return embed = self._create_log_embed( title="🗑️ Role Deleted (Event)", description=f"Role `{role.name}` was deleted.\n*Audit log may contain deleter.*", - color=discord.Color.dark_teal() + color=discord.Color.dark_teal(), ) self._add_id_footer(embed, role, id_name="Role ID") await self._send_log_embed(role.guild, embed) @@ -788,7 +937,8 @@ class LoggingCog(commands.Cog): async def on_guild_role_update(self, before: discord.Role, after: discord.Role): guild = after.guild event_key = "role_update_event" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return changes = [] if before.name != after.name: @@ -798,7 +948,9 @@ class LoggingCog(commands.Cog): if before.hoist != after.hoist: changes.append(f"**Hoisted:** `{before.hoist}` → `{after.hoist}`") if before.mentionable != after.mentionable: - changes.append(f"**Mentionable:** `{before.mentionable}` → `{after.mentionable}`") + changes.append( + f"**Mentionable:** `{before.mentionable}` → `{after.mentionable}`" + ) if before.permissions != after.permissions: # Comparing permissions can be complex, just note that they changed. # Audit log provides specifics on permission changes. @@ -808,30 +960,31 @@ class LoggingCog(commands.Cog): # Add position change if before.position != after.position: - changes.append(f"**Position:** `{before.position}` → `{after.position}`") + changes.append(f"**Position:** `{before.position}` → `{after.position}`") if changes: embed = self._create_log_embed( title="🔧 Role Updated (Event)", - description=f"Role {after.mention} updated.\n*Audit log may contain updater and specific permission changes.*\n" + "\n".join(changes), - color=discord.Color.blue() + description=f"Role {after.mention} updated.\n*Audit log may contain updater and specific permission changes.*\n" + + "\n".join(changes), + color=discord.Color.blue(), ) self._add_id_footer(embed, after, id_name="Role ID") await self._send_log_embed(guild, embed) - # --- Channel Events --- @commands.Cog.listener() async def on_guild_channel_create(self, channel: discord.abc.GuildChannel): guild = channel.guild event_key = "channel_create_event" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return ch_type = str(channel.type).capitalize() embed = self._create_log_embed( title=f"➕ {ch_type} Channel Created (Event)", description=f"Channel {channel.mention} (`{channel.name}`) was created.\n*Audit log may contain creator.*", - color=discord.Color.green() + color=discord.Color.green(), ) self._add_id_footer(embed, channel, id_name="Channel ID") await self._send_log_embed(channel.guild, embed) @@ -840,40 +993,54 @@ class LoggingCog(commands.Cog): async def on_guild_channel_delete(self, channel: discord.abc.GuildChannel): guild = channel.guild event_key = "channel_delete_event" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return ch_type = str(channel.type).capitalize() embed = self._create_log_embed( title=f"➖ {ch_type} Channel Deleted (Event)", description=f"Channel `{channel.name}` was deleted.\n*Audit log may contain deleter.*", - color=discord.Color.red() + color=discord.Color.red(), ) self._add_id_footer(embed, channel, id_name="Channel ID") await self._send_log_embed(channel.guild, embed) @commands.Cog.listener() - async def on_guild_channel_update(self, before: discord.abc.GuildChannel, after: discord.abc.GuildChannel): + async def on_guild_channel_update( + self, before: discord.abc.GuildChannel, after: discord.abc.GuildChannel + ): guild = after.guild event_key = "channel_update_event" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return changes = [] ch_type = str(after.type).capitalize() if before.name != after.name: changes.append(f"**Name:** `{before.name}` → `{after.name}`") - if isinstance(before, discord.TextChannel) and isinstance(after, discord.TextChannel): + if isinstance(before, discord.TextChannel) and isinstance( + after, discord.TextChannel + ): if before.topic != after.topic: - changes.append(f"**Topic:** `{before.topic or 'None'}` → `{after.topic or 'None'}`") + changes.append( + f"**Topic:** `{before.topic or 'None'}` → `{after.topic or 'None'}`" + ) if before.slowmode_delay != after.slowmode_delay: - changes.append(f"**Slowmode:** `{before.slowmode_delay}s` → `{after.slowmode_delay}s`") + changes.append( + f"**Slowmode:** `{before.slowmode_delay}s` → `{after.slowmode_delay}s`" + ) if before.nsfw != after.nsfw: - changes.append(f"**NSFW:** `{before.nsfw}` → `{after.nsfw}`") - if isinstance(before, discord.VoiceChannel) and isinstance(after, discord.VoiceChannel): - if before.bitrate != after.bitrate: - changes.append(f"**Bitrate:** `{before.bitrate}` → `{after.bitrate}`") - if before.user_limit != after.user_limit: - changes.append(f"**User Limit:** `{before.user_limit}` → `{after.user_limit}`") + changes.append(f"**NSFW:** `{before.nsfw}` → `{after.nsfw}`") + if isinstance(before, discord.VoiceChannel) and isinstance( + after, discord.VoiceChannel + ): + if before.bitrate != after.bitrate: + changes.append(f"**Bitrate:** `{before.bitrate}` → `{after.bitrate}`") + if before.user_limit != after.user_limit: + changes.append( + f"**User Limit:** `{before.user_limit}` → `{after.user_limit}`" + ) # Permission overwrites change if before.overwrites != after.overwrites: # Identify changes without detailing every permission bit @@ -881,43 +1048,55 @@ class LoggingCog(commands.Cog): after_targets = set(after.overwrites.keys()) added_targets = after_targets - before_targets removed_targets = before_targets - after_targets - updated_targets = before_targets.intersection(after_targets) # Targets present before and after + updated_targets = before_targets.intersection( + after_targets + ) # Targets present before and after overwrite_changes = [] if added_targets: - overwrite_changes.append(f"Added overwrites for: {', '.join([f'<@{t.id}>' if isinstance(t, discord.Member) else f'<@&{t.id}>' for t in added_targets])}") + overwrite_changes.append( + f"Added overwrites for: {', '.join([f'<@{t.id}>' if isinstance(t, discord.Member) else f'<@&{t.id}>' for t in added_targets])}" + ) if removed_targets: - overwrite_changes.append(f"Removed overwrites for: {', '.join([f'<@{t.id}>' if isinstance(t, discord.Member) else f'<@&{t.id}>' for t in removed_targets])}") + overwrite_changes.append( + f"Removed overwrites for: {', '.join([f'<@{t.id}>' if isinstance(t, discord.Member) else f'<@&{t.id}>' for t in removed_targets])}" + ) # Check if any *values* changed for targets present both before and after - if any(before.overwrites[t] != after.overwrites[t] for t in updated_targets): - overwrite_changes.append(f"Modified overwrites for: {', '.join([f'<@{t.id}>' if isinstance(t, discord.Member) else f'<@&{t.id}>' for t in updated_targets if before.overwrites[t] != after.overwrites[t]])}") + if any( + before.overwrites[t] != after.overwrites[t] for t in updated_targets + ): + overwrite_changes.append( + f"Modified overwrites for: {', '.join([f'<@{t.id}>' if isinstance(t, discord.Member) else f'<@&{t.id}>' for t in updated_targets if before.overwrites[t] != after.overwrites[t]])}" + ) if overwrite_changes: - changes.append(f"**Permission Overwrites:**\n - " + '\n - '.join(overwrite_changes)) + changes.append( + f"**Permission Overwrites:**\n - " + "\n - ".join(overwrite_changes) + ) else: - changes.append("**Permission Overwrites Updated** (No specific target changes detected by event)") - + changes.append( + "**Permission Overwrites Updated** (No specific target changes detected by event)" + ) # Add position change if before.position != after.position: - changes.append(f"**Position:** `{before.position}` → `{after.position}`") + changes.append(f"**Position:** `{before.position}` → `{after.position}`") # Add category change if before.category != after.category: - before_cat = before.category.mention if before.category else 'None' - after_cat = after.category.mention if after.category else 'None' - changes.append(f"**Category:** {before_cat} → {after_cat}") - + before_cat = before.category.mention if before.category else "None" + after_cat = after.category.mention if after.category else "None" + changes.append(f"**Category:** {before_cat} → {after_cat}") if changes: embed = self._create_log_embed( title=f"📝 {ch_type} Channel Updated (Event)", - description=f"Channel {after.mention} updated.\n*Audit log may contain updater and specific permission changes.*\n" + "\n".join(changes), - color=discord.Color.yellow() + description=f"Channel {after.mention} updated.\n*Audit log may contain updater and specific permission changes.*\n" + + "\n".join(changes), + color=discord.Color.yellow(), ) self._add_id_footer(embed, after, id_name="Channel ID") await self._send_log_embed(guild, embed) - # --- Message Events --- @commands.Cog.listener() async def on_message_edit(self, before: discord.Message, after: discord.Message): @@ -925,38 +1104,54 @@ class LoggingCog(commands.Cog): if before.author.bot or before.content == after.content: return guild = after.guild - if not guild: return # Ignore DMs + if not guild: + return # Ignore DMs # Check if logging is enabled *after* initial checks event_key = "message_edit" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return embed = self._create_log_embed( title="✏️ Message Edited", description=f"Message edited in {after.channel.mention} [Jump to Message]({after.jump_url})", color=discord.Color.light_grey(), - author=after.author + author=after.author, ) # Add fields for before and after, handling potential length limits - embed.add_field(name="Before", value=before.content[:1020] + ('...' if len(before.content) > 1020 else '') or "`Empty Message`", inline=False) - embed.add_field(name="After", value=after.content[:1020] + ('...' if len(after.content) > 1020 else '') or "`Empty Message`", inline=False) - self._add_id_footer(embed, after, id_name="Message ID") # Add message ID + embed.add_field( + name="Before", + value=before.content[:1020] + ("..." if len(before.content) > 1020 else "") + or "`Empty Message`", + inline=False, + ) + embed.add_field( + name="After", + value=after.content[:1020] + ("..." if len(after.content) > 1020 else "") + or "`Empty Message`", + inline=False, + ) + self._add_id_footer(embed, after, id_name="Message ID") # Add message ID await self._send_log_embed(guild, embed) @commands.Cog.listener() async def on_message_delete(self, message: discord.Message): # Ignore deletes from bots or messages without content/embeds/attachments - if message.author.bot or (not message.content and not message.embeds and not message.attachments): - # Allow logging bot message deletions if needed, but can be noisy - # Example: if message.author.id == self.bot.user.id: pass # Log bot's own deletions - # else: return - return + if message.author.bot or ( + not message.content and not message.embeds and not message.attachments + ): + # Allow logging bot message deletions if needed, but can be noisy + # Example: if message.author.id == self.bot.user.id: pass # Log bot's own deletions + # else: return + return guild = message.guild - if not guild: return # Ignore DMs + if not guild: + return # Ignore DMs # Check if logging is enabled *after* initial checks event_key = "message_delete" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return desc = f"Message deleted in {message.channel.mention}" # Audit log needed for *who* deleted it, if not the author themselves @@ -966,70 +1161,89 @@ class LoggingCog(commands.Cog): title="🗑️ Message Deleted", description=f"{desc}\n*Audit log may contain deleter if not the author.*", color=discord.Color.dark_grey(), - author=message.author + author=message.author, ) if message.content: - embed.add_field(name="Content", value=message.content[:1020] + ('...' if len(message.content) > 1020 else '') or "`Empty Message`", inline=False) + embed.add_field( + name="Content", + value=message.content[:1020] + + ("..." if len(message.content) > 1020 else "") + or "`Empty Message`", + inline=False, + ) if message.attachments: atts = [f"[{att.filename}]({att.url})" for att in message.attachments] embed.add_field(name="Attachments", value=", ".join(atts), inline=False) - self._add_id_footer(embed, message, id_name="Message ID") # Add message ID + self._add_id_footer(embed, message, id_name="Message ID") # Add message ID await self._send_log_embed(guild, embed) - # --- Reaction Events --- @commands.Cog.listener() - async def on_reaction_add(self, reaction: discord.Reaction, user: Union[discord.Member, discord.User]): - if user.bot: return + async def on_reaction_add( + self, reaction: discord.Reaction, user: Union[discord.Member, discord.User] + ): + if user.bot: + return guild = reaction.message.guild - if not guild: return # Should not happen in guilds but safety check + if not guild: + return # Should not happen in guilds but safety check # Check if logging is enabled *after* initial checks event_key = "reaction_add" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return embed = self._create_log_embed( title="👍 Reaction Added", description=f"{self._user_display(user)} added {reaction.emoji} to a message by {self._user_display(reaction.message.author)} in {reaction.message.channel.mention} [Jump to Message]({reaction.message.jump_url})", color=discord.Color.gold(), - author=user + author=user, ) self._add_id_footer(embed, reaction.message, id_name="Message ID") await self._send_log_embed(guild, embed) @commands.Cog.listener() - async def on_reaction_remove(self, reaction: discord.Reaction, user: Union[discord.Member, discord.User]): - if user.bot: return + async def on_reaction_remove( + self, reaction: discord.Reaction, user: Union[discord.Member, discord.User] + ): + if user.bot: + return guild = reaction.message.guild - if not guild: return # Should not happen in guilds but safety check + if not guild: + return # Should not happen in guilds but safety check # Check if logging is enabled *after* initial checks event_key = "reaction_remove" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return embed = self._create_log_embed( title="👎 Reaction Removed", description=f"{self._user_display(user)} removed {reaction.emoji} from a message by {self._user_display(reaction.message.author)} in {reaction.message.channel.mention} [Jump to Message]({reaction.message.jump_url})", color=discord.Color.dark_gold(), - author=user + author=user, ) self._add_id_footer(embed, reaction.message, id_name="Message ID") await self._send_log_embed(guild, embed) @commands.Cog.listener() - async def on_reaction_clear(self, message: discord.Message, _: list[discord.Reaction]): + async def on_reaction_clear( + self, message: discord.Message, _: list[discord.Reaction] + ): guild = message.guild - if not guild: return # Should not happen in guilds but safety check + if not guild: + return # Should not happen in guilds but safety check # Check if logging is enabled *after* initial checks event_key = "reaction_clear" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return embed = self._create_log_embed( title="💥 All Reactions Cleared", description=f"All reactions were cleared from a message by {message.author.mention} in {message.channel.mention} [Jump to Message]({message.jump_url})\n*Audit log may contain moderator.*", color=discord.Color.orange(), - author=message.author # Usually the author or a mod clears reactions + author=message.author, # Usually the author or a mod clears reactions ) self._add_id_footer(embed, message, id_name="Message ID") await self._send_log_embed(guild, embed) @@ -1037,28 +1251,35 @@ class LoggingCog(commands.Cog): @commands.Cog.listener() async def on_reaction_clear_emoji(self, reaction: discord.Reaction): guild = reaction.message.guild - if not guild: return # Should not happen in guilds but safety check + if not guild: + return # Should not happen in guilds but safety check # Check if logging is enabled *after* initial checks event_key = "reaction_clear_emoji" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return embed = self._create_log_embed( title="💥 Emoji Reactions Cleared", description=f"All {reaction.emoji} reactions were cleared from a message by {reaction.message.author.mention} in {reaction.message.channel.mention} [Jump to Message]({reaction.message.jump_url})\n*Audit log may contain moderator.*", color=discord.Color.dark_orange(), - author=reaction.message.author # Usually the author or a mod clears reactions + author=reaction.message.author, # Usually the author or a mod clears reactions ) self._add_id_footer(embed, reaction.message, id_name="Message ID") await self._send_log_embed(guild, embed) - # --- Voice State Events --- @commands.Cog.listener() - async def on_voice_state_update(self, member: discord.Member, before: discord.VoiceState, after: discord.VoiceState): + async def on_voice_state_update( + self, + member: discord.Member, + before: discord.VoiceState, + after: discord.VoiceState, + ): guild = member.guild event_key = "voice_state_update" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return action = "" details = "" @@ -1075,7 +1296,11 @@ class LoggingCog(commands.Cog): details = f"Left {before.channel.mention}" color = discord.Color.orange() # Move VC - elif before.channel is not None and after.channel is not None and before.channel != after.channel: + elif ( + before.channel is not None + and after.channel is not None + and before.channel != after.channel + ): action = "🔄 Moved Voice Channel" details = f"Moved from {before.channel.mention} to {after.channel.mention}" color = discord.Color.blue() @@ -1104,59 +1329,75 @@ class LoggingCog(commands.Cog): # action = " Webcam Update" # details = f"Webcam On: `{after.self_video}`" else: - return # No relevant change detected + return # No relevant change detected embed = self._create_log_embed( title=action, description=f"{self._user_display(member)}\n{details}", color=color, - author=member + author=member, ) self._add_id_footer(embed, member, id_name="User ID") await self._send_log_embed(guild, embed) - # --- Guild/Server Events --- @commands.Cog.listener() async def on_guild_update(self, before: discord.Guild, after: discord.Guild): - guild = after # Use 'after' for guild ID check + guild = after # Use 'after' for guild ID check event_key = "guild_update_event" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return changes = [] if before.name != after.name: changes.append(f"**Name:** `{before.name}` → `{after.name}`") if before.description != after.description: - changes.append(f"**Description:** `{before.description or 'None'}` → `{after.description or 'None'}`") + changes.append( + f"**Description:** `{before.description or 'None'}` → `{after.description or 'None'}`" + ) if before.icon != after.icon: - changes.append(f"**Icon Changed**") # URL comparison can be tricky + changes.append(f"**Icon Changed**") # URL comparison can be tricky if before.banner != after.banner: changes.append(f"**Banner Changed**") if before.owner != after.owner: - changes.append(f"**Owner:** {before.owner.mention if before.owner else 'None'} → {after.owner.mention if after.owner else 'None'}") + changes.append( + f"**Owner:** {before.owner.mention if before.owner else 'None'} → {after.owner.mention if after.owner else 'None'}" + ) # Add other relevant changes: region, verification_level, explicit_content_filter, etc. if before.verification_level != after.verification_level: - changes.append(f"**Verification Level:** `{before.verification_level}` → `{after.verification_level}`") + changes.append( + f"**Verification Level:** `{before.verification_level}` → `{after.verification_level}`" + ) if before.explicit_content_filter != after.explicit_content_filter: - changes.append(f"**Explicit Content Filter:** `{before.explicit_content_filter}` → `{after.explicit_content_filter}`") + changes.append( + f"**Explicit Content Filter:** `{before.explicit_content_filter}` → `{after.explicit_content_filter}`" + ) if before.system_channel != after.system_channel: - changes.append(f"**System Channel:** {before.system_channel.mention if before.system_channel else 'None'} → {after.system_channel.mention if after.system_channel else 'None'}") - + changes.append( + f"**System Channel:** {before.system_channel.mention if before.system_channel else 'None'} → {after.system_channel.mention if after.system_channel else 'None'}" + ) if changes: embed = self._create_log_embed( # title="⚙️ Guild Updated", # Removed duplicate title title="⚙️ Guild Updated (Event)", - description="Server settings were updated.\n*Audit log may contain updater.*\n" + "\n".join(changes), - color=discord.Color.dark_purple() + description="Server settings were updated.\n*Audit log may contain updater.*\n" + + "\n".join(changes), + color=discord.Color.dark_purple(), ) self._add_id_footer(embed, after, id_name="Guild ID") await self._send_log_embed(after, embed) @commands.Cog.listener() - async def on_guild_emojis_update(self, guild: discord.Guild, before: tuple[discord.Emoji, ...], after: tuple[discord.Emoji, ...]): + async def on_guild_emojis_update( + self, + guild: discord.Guild, + before: tuple[discord.Emoji, ...], + after: tuple[discord.Emoji, ...], + ): event_key = "emoji_update_event" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return added = [e for e in after if e not in before] removed = [e for e in before if e not in after] @@ -1167,46 +1408,51 @@ class LoggingCog(commands.Cog): after_map = {e.id: e for e in after} for e_id, e_after in after_map.items(): if e_id in before_map and before_map[e_id].name != e_after.name: - renamed_before.append(before_map[e_id]) - renamed_after.append(e_after) - + renamed_before.append(before_map[e_id]) + renamed_after.append(e_after) desc = "" if added: desc += f"**Added:** {', '.join([str(e) for e in added])}\n" if removed: - desc += f"**Removed:** {', '.join([f'`{e.name}`' for e in removed])}\n" # Can't display removed emoji easily + desc += f"**Removed:** {', '.join([f'`{e.name}`' for e in removed])}\n" # Can't display removed emoji easily if renamed_before: - desc += "**Renamed:**\n" + "\n".join([f"`{b.name}` → {a}" for b, a in zip(renamed_before, renamed_after)]) - + desc += "**Renamed:**\n" + "\n".join( + [f"`{b.name}` → {a}" for b, a in zip(renamed_before, renamed_after)] + ) if desc: embed = self._create_log_embed( # title="😀 Emojis Updated", # Removed duplicate title title="😀 Emojis Updated (Event)", description=f"*Audit log may contain updater.*\n{desc.strip()}", - color=discord.Color.magenta() + color=discord.Color.magenta(), ) self._add_id_footer(embed, guild, id_name="Guild ID") await self._send_log_embed(guild, embed) - # --- Invite Events --- @commands.Cog.listener() async def on_invite_create(self, invite: discord.Invite): guild = invite.guild - if not guild: return + if not guild: + return # Check if logging is enabled *after* initial checks event_key = "invite_create_event" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return inviter = invite.inviter channel = invite.channel desc = f"Invite `{invite.code}` created for {channel.mention if channel else 'Unknown Channel'}" if invite.max_age: # Use invite.created_at if available, otherwise fall back to current time - created_time = invite.created_at if invite.created_at is not None else discord.utils.utcnow() + created_time = ( + invite.created_at + if invite.created_at is not None + else discord.utils.utcnow() + ) expires_at = created_time + datetime.timedelta(seconds=invite.max_age) desc += f"\nExpires: {discord.utils.format_dt(expires_at, style='R')}" if invite.max_uses: @@ -1217,19 +1463,23 @@ class LoggingCog(commands.Cog): title="✉️ Invite Created (Event)", description=f"{desc}\n*Audit log may contain creator.*", color=discord.Color.dark_magenta(), - author=inviter # Can be None if invite created through server settings/vanity URL + author=inviter, # Can be None if invite created through server settings/vanity URL ) - self._add_id_footer(embed, invite, obj_id=invite.id, id_name="Invite ID") # Invite object doesn't have ID directly? Use code? No, ID exists. + self._add_id_footer( + embed, invite, obj_id=invite.id, id_name="Invite ID" + ) # Invite object doesn't have ID directly? Use code? No, ID exists. await self._send_log_embed(guild, embed) @commands.Cog.listener() async def on_invite_delete(self, invite: discord.Invite): guild = invite.guild - if not guild: return + if not guild: + return # Check if logging is enabled *after* initial checks event_key = "invite_delete_event" - if not await self._check_log_enabled(guild.id, event_key): return + if not await self._check_log_enabled(guild.id, event_key): + return channel = invite.channel desc = f"Invite `{invite.code}` for {channel.mention if channel else 'Unknown Channel'} was deleted or expired." @@ -1238,14 +1488,13 @@ class LoggingCog(commands.Cog): # title="🗑️ Invite Deleted", # Removed duplicate title title="🗑️ Invite Deleted (Event)", description=f"{desc}\n*Audit log may contain deleter.*", - color=discord.Color.dark_grey() + color=discord.Color.dark_grey(), # Cannot reliably get inviter after deletion ) # Invite object might not have ID after deletion, use code in footer? embed.set_footer(text=f"Invite Code: {invite.code}") await self._send_log_embed(guild, embed) - # --- Bot/Command Events --- # Note: These might be noisy depending on bot usage. Consider enabling selectively. # @commands.Cog.listener() @@ -1260,27 +1509,46 @@ class LoggingCog(commands.Cog): # await self._send_log_embed(ctx.guild, embed) @commands.Cog.listener() - async def on_command_error(self, ctx: commands.Context, error: commands.CommandError): + async def on_command_error( + self, ctx: commands.Context, error: commands.CommandError + ): # Log only significant errors, ignore things like CommandNotFound or CheckFailure if desired - ignored = (commands.CommandNotFound, commands.CheckFailure, commands.UserInputError, commands.DisabledCommand, commands.CommandOnCooldown) + ignored = ( + commands.CommandNotFound, + commands.CheckFailure, + commands.UserInputError, + commands.DisabledCommand, + commands.CommandOnCooldown, + ) if isinstance(error, ignored): return - if not ctx.guild: return # Ignore DMs + if not ctx.guild: + return # Ignore DMs # Check if logging is enabled *after* initial checks event_key = "command_error" - if not await self._check_log_enabled(ctx.guild.id, event_key): return + if not await self._check_log_enabled(ctx.guild.id, event_key): + return embed = self._create_log_embed( title="❌ Command Error", description=f"Error in command `{ctx.command.qualified_name if ctx.command else 'Unknown'}` used by {ctx.author.mention} in {ctx.channel.mention}", color=discord.Color.brand_red(), - author=ctx.author + author=ctx.author, ) # Get traceback if available (might need error handling specific to your bot's setup) import traceback - tb = ''.join(traceback.format_exception(type(error), error, error.__traceback__)) - embed.add_field(name="Error Details", value=f"```py\n{tb[:1000]}\n...```" if len(tb) > 1000 else f"```py\n{tb}```", inline=False) + + tb = "".join( + traceback.format_exception(type(error), error, error.__traceback__) + ) + embed.add_field( + name="Error Details", + value=( + f"```py\n{tb[:1000]}\n...```" if len(tb) > 1000 else f"```py\n{tb}```" + ), + inline=False, + ) await self._send_log_embed(ctx.guild, embed) @@ -1300,12 +1568,12 @@ class LoggingCog(commands.Cog): # The first set of definitions already includes the toggle checks. # --- Audit Log Polling Task --- - @tasks.loop(seconds=30) # Poll every 30 seconds + @tasks.loop(seconds=30) # Poll every 30 seconds async def poll_audit_log(self): # This loop starts only after the bot is ready and initialized if not self.bot.is_ready() or self.session is None or self.session.closed: # log.debug("Audit log poll skipped: Bot not ready or session not initialized.") - return # Wait until ready and session is available + return # Wait until ready and session is available # log.debug("Polling audit logs for all guilds...") # Can be noisy for guild in self.bot.guilds: @@ -1317,22 +1585,35 @@ class LoggingCog(commands.Cog): # Check permissions and last known ID for this specific guild if not guild.me.guild_permissions.view_audit_log: - if self.last_audit_log_ids.get(guild_id) is not None: # Log only once when perms are lost - log.warning(f"Missing 'View Audit Log' permission in guild {guild_id}. Cannot poll audit log.") - self.last_audit_log_ids[guild_id] = None # Mark as unable to poll - continue # Skip this guild + if ( + self.last_audit_log_ids.get(guild_id) is not None + ): # Log only once when perms are lost + log.warning( + f"Missing 'View Audit Log' permission in guild {guild_id}. Cannot poll audit log." + ) + self.last_audit_log_ids[guild_id] = None # Mark as unable to poll + continue # Skip this guild # If we previously couldn't poll due to permissions, try re-initializing the ID - if self.last_audit_log_ids.get(guild_id) is None and guild.me.guild_permissions.view_audit_log: - log.info(f"Re-initializing audit log ID for guild {guild_id} after gaining permissions.") - try: - async for entry in guild.audit_logs(limit=1): - self.last_audit_log_ids[guild_id] = entry.id - log.debug(f"Re-initialized last_audit_log_id for guild {guild.id} to {entry.id}") - break - except Exception as e: - log.exception(f"Error re-initializing audit log ID for guild {guild.id}: {e}") - continue # Skip this cycle if re-init fails + if ( + self.last_audit_log_ids.get(guild_id) is None + and guild.me.guild_permissions.view_audit_log + ): + log.info( + f"Re-initializing audit log ID for guild {guild_id} after gaining permissions." + ) + try: + async for entry in guild.audit_logs(limit=1): + self.last_audit_log_ids[guild_id] = entry.id + log.debug( + f"Re-initialized last_audit_log_id for guild {guild.id} to {entry.id}" + ) + break + except Exception as e: + log.exception( + f"Error re-initializing audit log ID for guild {guild.id}: {e}" + ) + continue # Skip this cycle if re-init fails last_id = self.last_audit_log_ids.get(guild_id) # log.debug(f"Polling audit log for guild {guild_id} after ID: {last_id}") # Can be noisy @@ -1343,7 +1624,7 @@ class LoggingCog(commands.Cog): discord.AuditLogAction.member_role_update, discord.AuditLogAction.message_delete, discord.AuditLogAction.message_bulk_delete, - discord.AuditLogAction.member_update, # Includes timeout + discord.AuditLogAction.member_update, # Includes timeout discord.AuditLogAction.role_create, discord.AuditLogAction.role_delete, discord.AuditLogAction.role_update, @@ -1356,8 +1637,8 @@ class LoggingCog(commands.Cog): discord.AuditLogAction.invite_create, discord.AuditLogAction.invite_delete, discord.AuditLogAction.guild_update, - discord.AuditLogAction.ban, # Add ban action for reason/moderator - discord.AuditLogAction.unban, # Add unban action for moderator + discord.AuditLogAction.ban, # Add ban action for reason/moderator + discord.AuditLogAction.unban, # Add unban action for moderator ] latest_id_in_batch = last_id @@ -1366,13 +1647,15 @@ class LoggingCog(commands.Cog): try: # Fetch entries after the last known ID for this guild # The 'actions' parameter is deprecated; filter manually after fetching. - async for entry in guild.audit_logs(limit=50, after=discord.Object(id=last_id) if last_id else None): + async for entry in guild.audit_logs( + limit=50, after=discord.Object(id=last_id) if last_id else None + ): # log.debug(f"Processing audit entry {entry.id} for guild {guild_id}") # Debug print # Double check ID comparison just in case the 'after' parameter isn't perfectly reliable across different calls/times if last_id is None or entry.id > last_id: - entries_to_log.append(entry) - if latest_id_in_batch is None or entry.id > latest_id_in_batch: - latest_id_in_batch = entry.id + entries_to_log.append(entry) + if latest_id_in_batch is None or entry.id > latest_id_in_batch: + latest_id_in_batch = entry.id # Process entries oldest to newest to maintain order for entry in reversed(entries_to_log): @@ -1386,56 +1669,73 @@ class LoggingCog(commands.Cog): # log.debug(f"Updated last_audit_log_id for guild {guild_id} to {latest_id_in_batch}") # Debug print except discord.Forbidden: - log.warning(f"Missing permissions (likely View Audit Log) in guild {guild.id} during poll. Marking as unable.") - self.last_audit_log_ids[guild_id] = None # Mark as unable to poll + log.warning( + f"Missing permissions (likely View Audit Log) in guild {guild.id} during poll. Marking as unable." + ) + self.last_audit_log_ids[guild_id] = None # Mark as unable to poll except discord.HTTPException as e: - log.error(f"HTTP error fetching audit logs for guild {guild.id}: {e}. Retrying next cycle.") + log.error( + f"HTTP error fetching audit logs for guild {guild.id}: {e}. Retrying next cycle." + ) # Consider adding backoff logic here if errors persist except Exception as e: - log.exception(f"Unexpected error in poll_audit_log for guild {guild.id}: {e}") + log.exception( + f"Unexpected error in poll_audit_log for guild {guild.id}: {e}" + ) # Don't update last_audit_log_id on unexpected error, retry next time - - async def _process_audit_log_entry(self, guild: discord.Guild, entry: discord.AuditLogEntry): + async def _process_audit_log_entry( + self, guild: discord.Guild, entry: discord.AuditLogEntry + ): """Processes a single relevant audit log entry and sends an embed.""" - user = entry.user # Moderator/Actor - target = entry.target # User/Channel/Role/Message affected + user = entry.user # Moderator/Actor + target = entry.target # User/Channel/Role/Message affected reason = entry.reason action_desc = "" color = discord.Color.dark_grey() title = f"🛡️ Audit Log: {str(entry.action).replace('_', ' ').title()}" - if not user: # Should generally not happen for manual actions, but safeguard + if not user: # Should generally not happen for manual actions, but safeguard return # --- Member Events (Ban, Unban, Kick, Prune) --- if entry.action == discord.AuditLogAction.ban: audit_event_key = "audit_ban" - if not await self._check_log_enabled(guild.id, audit_event_key): return + if not await self._check_log_enabled(guild.id, audit_event_key): + return title = "🛡️ Audit Log: Member Banned" - action_desc = f"{self._user_display(user)} banned {self._user_display(target)}" + action_desc = ( + f"{self._user_display(user)} banned {self._user_display(target)}" + ) color = discord.Color.red() # self._add_id_footer(embed, target, id_name="Target ID") # Footer set later elif entry.action == discord.AuditLogAction.unban: audit_event_key = "audit_unban" - if not await self._check_log_enabled(guild.id, audit_event_key): return + if not await self._check_log_enabled(guild.id, audit_event_key): + return title = "🛡️ Audit Log: Member Unbanned" - action_desc = f"{self._user_display(user)} unbanned {self._user_display(target)}" + action_desc = ( + f"{self._user_display(user)} unbanned {self._user_display(target)}" + ) color = discord.Color.blurple() # self._add_id_footer(embed, target, id_name="Target ID") # Footer set later elif entry.action == discord.AuditLogAction.kick: audit_event_key = "audit_kick" - if not await self._check_log_enabled(guild.id, audit_event_key): return + if not await self._check_log_enabled(guild.id, audit_event_key): + return title = "🛡️ Audit Log: Member Kicked" - action_desc = f"{self._user_display(user)} kicked {self._user_display(target)}" + action_desc = ( + f"{self._user_display(user)} kicked {self._user_display(target)}" + ) color = discord.Color.brand_red() # self._add_id_footer(embed, target, id_name="Target ID") # Footer set later elif entry.action == discord.AuditLogAction.member_prune: audit_event_key = "audit_prune" - if not await self._check_log_enabled(guild.id, audit_event_key): return + if not await self._check_log_enabled(guild.id, audit_event_key): + return title = "🛡️ Audit Log: Member Prune" - days = entry.extra.get('delete_member_days') - count = entry.extra.get('members_removed') + days = entry.extra.get("delete_member_days") + count = entry.extra.get("members_removed") action_desc = f"{self._user_display(user)} pruned {count} members inactive for {days} days." color = discord.Color.dark_red() # No specific target ID here @@ -1443,371 +1743,601 @@ class LoggingCog(commands.Cog): # --- Member Update (Roles, Timeout) --- elif entry.action == discord.AuditLogAction.member_role_update: audit_event_key = "audit_member_role_update" - if not await self._check_log_enabled(guild.id, audit_event_key): return + if not await self._check_log_enabled(guild.id, audit_event_key): + return # entry.before.roles / entry.after.roles contains the role changes before_roles = entry.before.roles after_roles = entry.after.roles added = [r.mention for r in after_roles if r not in before_roles] removed = [r.mention for r in before_roles if r not in after_roles] - if added or removed: # Only log if roles actually changed + if added or removed: # Only log if roles actually changed action_desc = f"{self._user_display(user)} updated roles for {self._user_display(target)} ({target.id}):" - if added: action_desc += f"\n**Added:** {', '.join(added)}" - if removed: action_desc += f"\n**Removed:** {', '.join(removed)}" + if added: + action_desc += f"\n**Added:** {', '.join(added)}" + if removed: + action_desc += f"\n**Removed:** {', '.join(removed)}" color = discord.Color.blue() - else: return # Skip if no role change detected + else: + return # Skip if no role change detected elif entry.action == discord.AuditLogAction.member_update: - # Check for timeout changes - before_timed_out = getattr(entry.before, 'timed_out_until', None) - after_timed_out = getattr(entry.after, 'timed_out_until', None) - if before_timed_out != after_timed_out: - audit_event_key = "audit_member_update_timeout" - if not await self._check_log_enabled(guild.id, audit_event_key): return - title = "🛡️ Audit Log: Member Timeout Update" - if after_timed_out: - timeout_duration = discord.utils.format_dt(after_timed_out, style='R') - action_desc = f"{self._user_display(user)} timed out {self._user_display(target)} ({target.id}) until {timeout_duration}" - color = discord.Color.orange() - else: - action_desc = f"{self._user_display(user)} removed timeout from {self._user_display(target)} ({target.id})" - color = discord.Color.green() - # self._add_id_footer(embed, target, id_name="Target ID") # Footer set later - else: - # Could log other member updates here if needed (e.g. nick changes by mods) - requires separate toggle key - # log.debug(f"Unhandled member_update audit log entry by {user} on {target}") - return # Skip other member updates for now + # Check for timeout changes + before_timed_out = getattr(entry.before, "timed_out_until", None) + after_timed_out = getattr(entry.after, "timed_out_until", None) + if before_timed_out != after_timed_out: + audit_event_key = "audit_member_update_timeout" + if not await self._check_log_enabled(guild.id, audit_event_key): + return + title = "🛡️ Audit Log: Member Timeout Update" + if after_timed_out: + timeout_duration = discord.utils.format_dt( + after_timed_out, style="R" + ) + action_desc = f"{self._user_display(user)} timed out {self._user_display(target)} ({target.id}) until {timeout_duration}" + color = discord.Color.orange() + else: + action_desc = f"{self._user_display(user)} removed timeout from {self._user_display(target)} ({target.id})" + color = discord.Color.green() + # self._add_id_footer(embed, target, id_name="Target ID") # Footer set later + else: + # Could log other member updates here if needed (e.g. nick changes by mods) - requires separate toggle key + # log.debug(f"Unhandled member_update audit log entry by {user} on {target}") + return # Skip other member updates for now # --- Role Events --- elif entry.action == discord.AuditLogAction.role_create: - audit_event_key = "audit_role_create" - if not await self._check_log_enabled(guild.id, audit_event_key): return - title = "🛡️ Audit Log: Role Created" - role = target # Target is the role - action_desc = f"{user.mention} created role {role.mention} (`{role.name}`)" - color = discord.Color.teal() - # self._add_id_footer(embed, role, id_name="Role ID") # Footer set later + audit_event_key = "audit_role_create" + if not await self._check_log_enabled(guild.id, audit_event_key): + return + title = "🛡️ Audit Log: Role Created" + role = target # Target is the role + action_desc = f"{user.mention} created role {role.mention} (`{role.name}`)" + color = discord.Color.teal() + # self._add_id_footer(embed, role, id_name="Role ID") # Footer set later elif entry.action == discord.AuditLogAction.role_delete: - audit_event_key = "audit_role_delete" - if not await self._check_log_enabled(guild.id, audit_event_key): return - title = "🛡️ Audit Log: Role Deleted" - # Target is the role ID, before object has role details - role_name = entry.before.name - role_id = entry.target.id - action_desc = f"{user.mention} deleted role `{role_name}` ({role_id})" - color = discord.Color.dark_teal() - # self._add_id_footer(embed, obj_id=role_id, id_name="Role ID") # Footer set later + audit_event_key = "audit_role_delete" + if not await self._check_log_enabled(guild.id, audit_event_key): + return + title = "🛡️ Audit Log: Role Deleted" + # Target is the role ID, before object has role details + role_name = entry.before.name + role_id = entry.target.id + action_desc = f"{user.mention} deleted role `{role_name}` ({role_id})" + color = discord.Color.dark_teal() + # self._add_id_footer(embed, obj_id=role_id, id_name="Role ID") # Footer set later elif entry.action == discord.AuditLogAction.role_update: - audit_event_key = "audit_role_update" - if not await self._check_log_enabled(guild.id, audit_event_key): return - title = "🛡️ Audit Log: Role Updated" - role = target - changes = [] - # Simple diffing for common properties - if hasattr(entry.before, 'name') and hasattr(entry.after, 'name') and entry.before.name != entry.after.name: - changes.append(f"Name: `{entry.before.name}` → `{entry.after.name}`") - if hasattr(entry.before, 'color') and hasattr(entry.after, 'color') and entry.before.color != entry.after.color: - changes.append(f"Color: `{entry.before.color}` → `{entry.after.color}`") - if hasattr(entry.before, 'hoist') and hasattr(entry.after, 'hoist') and entry.before.hoist != entry.after.hoist: - changes.append(f"Hoisted: `{entry.before.hoist}` → `{entry.after.hoist}`") - if hasattr(entry.before, 'mentionable') and hasattr(entry.after, 'mentionable') and entry.before.mentionable != entry.after.mentionable: - changes.append(f"Mentionable: `{entry.before.mentionable}` → `{entry.after.mentionable}`") - if hasattr(entry.before, 'permissions') and hasattr(entry.after, 'permissions') and entry.before.permissions != entry.after.permissions: - changes.append("Permissions Updated (See Audit Log for details)") # Permissions are complex - if changes: - action_desc = f"{user.mention} updated role {role.mention} ({role.id}):\n" + "\n".join(f"- {c}" for c in changes) - color = discord.Color.blue() - # self._add_id_footer(embed, role, id_name="Role ID") # Footer set later - else: - # log.debug(f"Role update detected for {role.id} but no tracked changes found.") # Might still want to log permission changes even if other props are same - return # Skip if no changes we track were made + audit_event_key = "audit_role_update" + if not await self._check_log_enabled(guild.id, audit_event_key): + return + title = "🛡️ Audit Log: Role Updated" + role = target + changes = [] + # Simple diffing for common properties + if ( + hasattr(entry.before, "name") + and hasattr(entry.after, "name") + and entry.before.name != entry.after.name + ): + changes.append(f"Name: `{entry.before.name}` → `{entry.after.name}`") + if ( + hasattr(entry.before, "color") + and hasattr(entry.after, "color") + and entry.before.color != entry.after.color + ): + changes.append(f"Color: `{entry.before.color}` → `{entry.after.color}`") + if ( + hasattr(entry.before, "hoist") + and hasattr(entry.after, "hoist") + and entry.before.hoist != entry.after.hoist + ): + changes.append( + f"Hoisted: `{entry.before.hoist}` → `{entry.after.hoist}`" + ) + if ( + hasattr(entry.before, "mentionable") + and hasattr(entry.after, "mentionable") + and entry.before.mentionable != entry.after.mentionable + ): + changes.append( + f"Mentionable: `{entry.before.mentionable}` → `{entry.after.mentionable}`" + ) + if ( + hasattr(entry.before, "permissions") + and hasattr(entry.after, "permissions") + and entry.before.permissions != entry.after.permissions + ): + changes.append( + "Permissions Updated (See Audit Log for details)" + ) # Permissions are complex + if changes: + action_desc = ( + f"{user.mention} updated role {role.mention} ({role.id}):\n" + + "\n".join(f"- {c}" for c in changes) + ) + color = discord.Color.blue() + # self._add_id_footer(embed, role, id_name="Role ID") # Footer set later + else: + # log.debug(f"Role update detected for {role.id} but no tracked changes found.") # Might still want to log permission changes even if other props are same + return # Skip if no changes we track were made # --- Channel Events --- elif entry.action == discord.AuditLogAction.channel_create: - audit_event_key = "audit_channel_create" - if not await self._check_log_enabled(guild.id, audit_event_key): return - title = "🛡️ Audit Log: Channel Created" - channel = target - ch_type = str(channel.type).capitalize() - action_desc = f"{user.mention} created {ch_type} channel {channel.mention} (`{channel.name}`)" - color = discord.Color.green() - # self._add_id_footer(embed, channel, id_name="Channel ID") # Footer set later + audit_event_key = "audit_channel_create" + if not await self._check_log_enabled(guild.id, audit_event_key): + return + title = "🛡️ Audit Log: Channel Created" + channel = target + ch_type = str(channel.type).capitalize() + action_desc = f"{user.mention} created {ch_type} channel {channel.mention} (`{channel.name}`)" + color = discord.Color.green() + # self._add_id_footer(embed, channel, id_name="Channel ID") # Footer set later elif entry.action == discord.AuditLogAction.channel_delete: - audit_event_key = "audit_channel_delete" - if not await self._check_log_enabled(guild.id, audit_event_key): return - title = "🛡️ Audit Log: Channel Deleted" - # Target is channel ID, before object has details - channel_name = entry.before.name - channel_id = entry.target.id - ch_type = str(entry.before.type).capitalize() - action_desc = f"{user.mention} deleted {ch_type} channel `{channel_name}` ({channel_id})" - color = discord.Color.red() - # self._add_id_footer(embed, obj_id=channel_id, id_name="Channel ID") # Footer set later + audit_event_key = "audit_channel_delete" + if not await self._check_log_enabled(guild.id, audit_event_key): + return + title = "🛡️ Audit Log: Channel Deleted" + # Target is channel ID, before object has details + channel_name = entry.before.name + channel_id = entry.target.id + ch_type = str(entry.before.type).capitalize() + action_desc = f"{user.mention} deleted {ch_type} channel `{channel_name}` ({channel_id})" + color = discord.Color.red() + # self._add_id_footer(embed, obj_id=channel_id, id_name="Channel ID") # Footer set later elif entry.action == discord.AuditLogAction.channel_update: - audit_event_key = "audit_channel_update" - if not await self._check_log_enabled(guild.id, audit_event_key): return - title = "🛡️ Audit Log: Channel Updated" - channel = target - ch_type = str(channel.type).capitalize() - changes = [] - # Simple diffing - if hasattr(entry.before, 'name') and hasattr(entry.after, 'name') and entry.before.name != entry.after.name: - changes.append(f"Name: `{entry.before.name}` → `{entry.after.name}`") - if hasattr(entry.before, 'topic') and hasattr(entry.after, 'topic') and entry.before.topic != entry.after.topic: - changes.append(f"Topic Changed") # Keep it simple - if hasattr(entry.before, 'nsfw') and hasattr(entry.after, 'nsfw') and entry.before.nsfw != entry.after.nsfw: - changes.append(f"NSFW: `{entry.before.nsfw}` → `{entry.after.nsfw}`") - if hasattr(entry.before, 'slowmode_delay') and hasattr(entry.after, 'slowmode_delay') and entry.before.slowmode_delay != entry.after.slowmode_delay: - changes.append(f"Slowmode: `{entry.before.slowmode_delay}s` → `{entry.after.slowmode_delay}s`") - if hasattr(entry.before, 'bitrate') and hasattr(entry.after, 'bitrate') and entry.before.bitrate != entry.after.bitrate: - changes.append(f"Bitrate: `{entry.before.bitrate}` → `{entry.after.bitrate}`") - # Process detailed changes from entry.changes - detailed_changes = [] + audit_event_key = "audit_channel_update" + if not await self._check_log_enabled(guild.id, audit_event_key): + return + title = "🛡️ Audit Log: Channel Updated" + channel = target + ch_type = str(channel.type).capitalize() + changes = [] + # Simple diffing + if ( + hasattr(entry.before, "name") + and hasattr(entry.after, "name") + and entry.before.name != entry.after.name + ): + changes.append(f"Name: `{entry.before.name}` → `{entry.after.name}`") + if ( + hasattr(entry.before, "topic") + and hasattr(entry.after, "topic") + and entry.before.topic != entry.after.topic + ): + changes.append(f"Topic Changed") # Keep it simple + if ( + hasattr(entry.before, "nsfw") + and hasattr(entry.after, "nsfw") + and entry.before.nsfw != entry.after.nsfw + ): + changes.append(f"NSFW: `{entry.before.nsfw}` → `{entry.after.nsfw}`") + if ( + hasattr(entry.before, "slowmode_delay") + and hasattr(entry.after, "slowmode_delay") + and entry.before.slowmode_delay != entry.after.slowmode_delay + ): + changes.append( + f"Slowmode: `{entry.before.slowmode_delay}s` → `{entry.after.slowmode_delay}s`" + ) + if ( + hasattr(entry.before, "bitrate") + and hasattr(entry.after, "bitrate") + and entry.before.bitrate != entry.after.bitrate + ): + changes.append( + f"Bitrate: `{entry.before.bitrate}` → `{entry.after.bitrate}`" + ) + # Process detailed changes from entry.changes + detailed_changes = [] - # AuditLogChanges is not directly iterable, so we need to handle it differently - try: - # Check if entry.changes has the __iter__ attribute (is iterable) - if hasattr(entry.changes, '__iter__'): - for change in entry.changes: - attr = change.attribute - before_val = change.before - after_val = change.after - if attr == 'name': detailed_changes.append(f"Name: `{before_val}` → `{after_val}`") - elif attr == 'topic': detailed_changes.append(f"Topic: `{before_val or 'None'}` → `{after_val or 'None'}`") - elif attr == 'nsfw': detailed_changes.append(f"NSFW: `{before_val}` → `{after_val}`") - elif attr == 'slowmode_delay': detailed_changes.append(f"Slowmode: `{before_val}s` → `{after_val}s`") - elif attr == 'bitrate': detailed_changes.append(f"Bitrate: `{before_val}` → `{after_val}`") - elif attr == 'user_limit': detailed_changes.append(f"User Limit: `{before_val}` → `{after_val}`") - elif attr == 'position': detailed_changes.append(f"Position: `{before_val}` → `{after_val}`") - elif attr == 'category': detailed_changes.append(f"Category: {getattr(before_val, 'mention', 'None')} → {getattr(after_val, 'mention', 'None')}") - elif attr == 'permission_overwrites': - # Audit log gives overwrite target ID and type directly in the change object - ow_target_id = getattr(change.target, 'id', None) # Target of the overwrite change - ow_target_type = getattr(change.target, 'type', None) # 'role' or 'member' - if ow_target_id and ow_target_type: - target_mention = f"<@&{ow_target_id}>" if ow_target_type == 'role' else f"<@{ow_target_id}>" - # Determine if added, removed, or updated (before/after values are PermissionOverwrite objects) - if before_val is None and after_val is not None: - detailed_changes.append(f"Added overwrite for {target_mention}") - elif before_val is not None and after_val is None: - detailed_changes.append(f"Removed overwrite for {target_mention}") - else: - detailed_changes.append(f"Updated overwrite for {target_mention}") - else: - detailed_changes.append("Permission Overwrites Updated (Target details unavailable)") # Fallback - else: - # Log other unhandled changes generically - detailed_changes.append(f"{attr.replace('_', ' ').title()} changed: `{before_val}` → `{after_val}`") - else: - # Handle AuditLogChanges as a non-iterable object - # We can access the before and after attributes directly - if hasattr(entry.changes, 'before') and hasattr(entry.changes, 'after'): - before = entry.changes.before - after = entry.changes.after + # AuditLogChanges is not directly iterable, so we need to handle it differently + try: + # Check if entry.changes has the __iter__ attribute (is iterable) + if hasattr(entry.changes, "__iter__"): + for change in entry.changes: + attr = change.attribute + before_val = change.before + after_val = change.after + if attr == "name": + detailed_changes.append( + f"Name: `{before_val}` → `{after_val}`" + ) + elif attr == "topic": + detailed_changes.append( + f"Topic: `{before_val or 'None'}` → `{after_val or 'None'}`" + ) + elif attr == "nsfw": + detailed_changes.append( + f"NSFW: `{before_val}` → `{after_val}`" + ) + elif attr == "slowmode_delay": + detailed_changes.append( + f"Slowmode: `{before_val}s` → `{after_val}s`" + ) + elif attr == "bitrate": + detailed_changes.append( + f"Bitrate: `{before_val}` → `{after_val}`" + ) + elif attr == "user_limit": + detailed_changes.append( + f"User Limit: `{before_val}` → `{after_val}`" + ) + elif attr == "position": + detailed_changes.append( + f"Position: `{before_val}` → `{after_val}`" + ) + elif attr == "category": + detailed_changes.append( + f"Category: {getattr(before_val, 'mention', 'None')} → {getattr(after_val, 'mention', 'None')}" + ) + elif attr == "permission_overwrites": + # Audit log gives overwrite target ID and type directly in the change object + ow_target_id = getattr( + change.target, "id", None + ) # Target of the overwrite change + ow_target_type = getattr( + change.target, "type", None + ) # 'role' or 'member' + if ow_target_id and ow_target_type: + target_mention = ( + f"<@&{ow_target_id}>" + if ow_target_type == "role" + else f"<@{ow_target_id}>" + ) + # Determine if added, removed, or updated (before/after values are PermissionOverwrite objects) + if before_val is None and after_val is not None: + detailed_changes.append( + f"Added overwrite for {target_mention}" + ) + elif before_val is not None and after_val is None: + detailed_changes.append( + f"Removed overwrite for {target_mention}" + ) + else: + detailed_changes.append( + f"Updated overwrite for {target_mention}" + ) + else: + detailed_changes.append( + "Permission Overwrites Updated (Target details unavailable)" + ) # Fallback + else: + # Log other unhandled changes generically + detailed_changes.append( + f"{attr.replace('_', ' ').title()} changed: `{before_val}` → `{after_val}`" + ) + else: + # Handle AuditLogChanges as a non-iterable object + # We can access the before and after attributes directly + if hasattr(entry.changes, "before") and hasattr( + entry.changes, "after" + ): + before = entry.changes.before + after = entry.changes.after - # Compare attributes between before and after - if hasattr(before, 'name') and hasattr(after, 'name') and before.name != after.name: - detailed_changes.append(f"Name: `{before.name}` → `{after.name}`") - if hasattr(before, 'topic') and hasattr(after, 'topic') and before.topic != after.topic: - detailed_changes.append(f"Topic: `{before.topic or 'None'}` → `{after.topic or 'None'}`") - if hasattr(before, 'nsfw') and hasattr(after, 'nsfw') and before.nsfw != after.nsfw: - detailed_changes.append(f"NSFW: `{before.nsfw}` → `{after.nsfw}`") - if hasattr(before, 'slowmode_delay') and hasattr(after, 'slowmode_delay') and before.slowmode_delay != after.slowmode_delay: - detailed_changes.append(f"Slowmode: `{before.slowmode_delay}s` → `{after.slowmode_delay}s`") - if hasattr(before, 'bitrate') and hasattr(after, 'bitrate') and before.bitrate != after.bitrate: - detailed_changes.append(f"Bitrate: `{before.bitrate}` → `{after.bitrate}`") - if hasattr(before, 'user_limit') and hasattr(after, 'user_limit') and before.user_limit != after.user_limit: - detailed_changes.append(f"User Limit: `{before.user_limit}` → `{after.user_limit}`") - if hasattr(before, 'position') and hasattr(after, 'position') and before.position != after.position: - detailed_changes.append(f"Position: `{before.position}` → `{after.position}`") - # Add more attribute comparisons as needed - except Exception as e: - log.error(f"Error processing audit log changes: {e}", exc_info=True) - detailed_changes.append(f"Error processing changes: {e}") + # Compare attributes between before and after + if ( + hasattr(before, "name") + and hasattr(after, "name") + and before.name != after.name + ): + detailed_changes.append( + f"Name: `{before.name}` → `{after.name}`" + ) + if ( + hasattr(before, "topic") + and hasattr(after, "topic") + and before.topic != after.topic + ): + detailed_changes.append( + f"Topic: `{before.topic or 'None'}` → `{after.topic or 'None'}`" + ) + if ( + hasattr(before, "nsfw") + and hasattr(after, "nsfw") + and before.nsfw != after.nsfw + ): + detailed_changes.append( + f"NSFW: `{before.nsfw}` → `{after.nsfw}`" + ) + if ( + hasattr(before, "slowmode_delay") + and hasattr(after, "slowmode_delay") + and before.slowmode_delay != after.slowmode_delay + ): + detailed_changes.append( + f"Slowmode: `{before.slowmode_delay}s` → `{after.slowmode_delay}s`" + ) + if ( + hasattr(before, "bitrate") + and hasattr(after, "bitrate") + and before.bitrate != after.bitrate + ): + detailed_changes.append( + f"Bitrate: `{before.bitrate}` → `{after.bitrate}`" + ) + if ( + hasattr(before, "user_limit") + and hasattr(after, "user_limit") + and before.user_limit != after.user_limit + ): + detailed_changes.append( + f"User Limit: `{before.user_limit}` → `{after.user_limit}`" + ) + if ( + hasattr(before, "position") + and hasattr(after, "position") + and before.position != after.position + ): + detailed_changes.append( + f"Position: `{before.position}` → `{after.position}`" + ) + # Add more attribute comparisons as needed + except Exception as e: + log.error(f"Error processing audit log changes: {e}", exc_info=True) + detailed_changes.append(f"Error processing changes: {e}") - if detailed_changes: - action_desc = f"{user.mention} updated {ch_type} channel {channel.mention} ({channel.id}):\n" + "\n".join(f"- {c}" for c in detailed_changes) - color = discord.Color.yellow() - # self._add_id_footer(embed, channel, id_name="Channel ID") # Footer set later - else: - # log.debug(f"Channel update detected for {channel.id} but no tracked changes found.") # Might still want to log permission changes - return # Skip if no changes we track were made + if detailed_changes: + action_desc = ( + f"{user.mention} updated {ch_type} channel {channel.mention} ({channel.id}):\n" + + "\n".join(f"- {c}" for c in detailed_changes) + ) + color = discord.Color.yellow() + # self._add_id_footer(embed, channel, id_name="Channel ID") # Footer set later + else: + # log.debug(f"Channel update detected for {channel.id} but no tracked changes found.") # Might still want to log permission changes + return # Skip if no changes we track were made # --- Message Events (Delete, Bulk Delete) --- elif entry.action == discord.AuditLogAction.message_delete: audit_event_key = "audit_message_delete" - if not await self._check_log_enabled(guild.id, audit_event_key): return - title = "🛡️ Audit Log: Message Deleted" # Title adjusted for clarity + if not await self._check_log_enabled(guild.id, audit_event_key): + return + title = "🛡️ Audit Log: Message Deleted" # Title adjusted for clarity channel = entry.extra.channel count = entry.extra.count action_desc = f"{user.mention} deleted {count} message(s) by {target.mention} ({target.id}) in {channel.mention}" color = discord.Color.dark_grey() elif entry.action == discord.AuditLogAction.message_bulk_delete: - audit_event_key = "audit_message_bulk_delete" - if not await self._check_log_enabled(guild.id, audit_event_key): return - title = "🛡️ Audit Log: Message Bulk Delete" - channel_target = entry.target # Channel is the target here - count = entry.extra.count - - channel_display = "" - if hasattr(channel_target, 'mention'): - channel_display = channel_target.mention - elif isinstance(channel_target, discord.Object) and hasattr(channel_target, 'id'): - # If it's an Object, it might be a deleted channel or not fully loaded. - # Using <#id> is a safe way to reference it. - channel_display = f"<#{channel_target.id}>" - else: - # Fallback if it's not an object with 'mention' or an 'Object' with 'id' - channel_display = f"an unknown channel (ID: {getattr(channel_target, 'id', 'N/A')})" + audit_event_key = "audit_message_bulk_delete" + if not await self._check_log_enabled(guild.id, audit_event_key): + return + title = "🛡️ Audit Log: Message Bulk Delete" + channel_target = entry.target # Channel is the target here + count = entry.extra.count - action_desc = f"{user.mention} bulk deleted {count} messages in {channel_display}" - color = discord.Color.dark_grey() - # self._add_id_footer(embed, channel_target, id_name="Channel ID") # Footer set later + channel_display = "" + if hasattr(channel_target, "mention"): + channel_display = channel_target.mention + elif isinstance(channel_target, discord.Object) and hasattr( + channel_target, "id" + ): + # If it's an Object, it might be a deleted channel or not fully loaded. + # Using <#id> is a safe way to reference it. + channel_display = f"<#{channel_target.id}>" + else: + # Fallback if it's not an object with 'mention' or an 'Object' with 'id' + channel_display = ( + f"an unknown channel (ID: {getattr(channel_target, 'id', 'N/A')})" + ) + + action_desc = ( + f"{user.mention} bulk deleted {count} messages in {channel_display}" + ) + color = discord.Color.dark_grey() + # self._add_id_footer(embed, channel_target, id_name="Channel ID") # Footer set later # --- Emoji Events --- elif entry.action == discord.AuditLogAction.emoji_create: - audit_event_key = "audit_emoji_create" - if not await self._check_log_enabled(guild.id, audit_event_key): return - title = "🛡️ Audit Log: Emoji Created" - emoji = target - action_desc = f"{user.mention} created emoji {emoji} (`{emoji.name}`)" - color = discord.Color.magenta() - # self._add_id_footer(embed, emoji, id_name="Emoji ID") # Footer set later + audit_event_key = "audit_emoji_create" + if not await self._check_log_enabled(guild.id, audit_event_key): + return + title = "🛡️ Audit Log: Emoji Created" + emoji = target + action_desc = f"{user.mention} created emoji {emoji} (`{emoji.name}`)" + color = discord.Color.magenta() + # self._add_id_footer(embed, emoji, id_name="Emoji ID") # Footer set later elif entry.action == discord.AuditLogAction.emoji_delete: - audit_event_key = "audit_emoji_delete" - if not await self._check_log_enabled(guild.id, audit_event_key): return - title = "🛡️ Audit Log: Emoji Deleted" - emoji_name = entry.before.name - emoji_id = entry.target.id - action_desc = f"{user.mention} deleted emoji `{emoji_name}` ({emoji_id})" - color = discord.Color.dark_magenta() - # self._add_id_footer(embed, obj_id=emoji_id, id_name="Emoji ID") # Footer set later + audit_event_key = "audit_emoji_delete" + if not await self._check_log_enabled(guild.id, audit_event_key): + return + title = "🛡️ Audit Log: Emoji Deleted" + emoji_name = entry.before.name + emoji_id = entry.target.id + action_desc = f"{user.mention} deleted emoji `{emoji_name}` ({emoji_id})" + color = discord.Color.dark_magenta() + # self._add_id_footer(embed, obj_id=emoji_id, id_name="Emoji ID") # Footer set later elif entry.action == discord.AuditLogAction.emoji_update: - audit_event_key = "audit_emoji_update" - if not await self._check_log_enabled(guild.id, audit_event_key): return - title = "🛡️ Audit Log: Emoji Updated" - emoji = target - if hasattr(entry.before, 'name') and hasattr(entry.after, 'name') and entry.before.name != entry.after.name: - action_desc = f"{user.mention} renamed emoji `{entry.before.name}` to {emoji} (`{emoji.name}`)" - color = discord.Color.magenta() - # self._add_id_footer(embed, emoji, id_name="Emoji ID") # Footer set later - else: - # log.debug(f"Emoji update detected for {emoji.id} but no tracked changes found.") # Only log name changes for now - return # Only log name changes for now + audit_event_key = "audit_emoji_update" + if not await self._check_log_enabled(guild.id, audit_event_key): + return + title = "🛡️ Audit Log: Emoji Updated" + emoji = target + if ( + hasattr(entry.before, "name") + and hasattr(entry.after, "name") + and entry.before.name != entry.after.name + ): + action_desc = f"{user.mention} renamed emoji `{entry.before.name}` to {emoji} (`{emoji.name}`)" + color = discord.Color.magenta() + # self._add_id_footer(embed, emoji, id_name="Emoji ID") # Footer set later + else: + # log.debug(f"Emoji update detected for {emoji.id} but no tracked changes found.") # Only log name changes for now + return # Only log name changes for now # --- Invite Events --- elif entry.action == discord.AuditLogAction.invite_create: - audit_event_key = "audit_invite_create" - if not await self._check_log_enabled(guild.id, audit_event_key): return - title = "🛡️ Audit Log: Invite Created" - invite = target # Target is the invite object - channel = invite.channel - desc = f"Invite `{invite.code}` created for {channel.mention if channel else 'Unknown Channel'}" - if invite.max_age: - # Use invite.created_at if available, otherwise fall back to current time - created_time = invite.created_at if invite.created_at is not None else discord.utils.utcnow() - expires_at = created_time + datetime.timedelta(seconds=invite.max_age) - desc += f"\nExpires: {discord.utils.format_dt(expires_at, style='R')}" - if invite.max_uses: desc += f"\nMax Uses: {invite.max_uses}" - action_desc = f"{user.mention} created an invite:\n{desc}" - color = discord.Color.dark_green() - # self._add_id_footer(embed, invite, obj_id=invite.id, id_name="Invite ID") # Footer set later + audit_event_key = "audit_invite_create" + if not await self._check_log_enabled(guild.id, audit_event_key): + return + title = "🛡️ Audit Log: Invite Created" + invite = target # Target is the invite object + channel = invite.channel + desc = f"Invite `{invite.code}` created for {channel.mention if channel else 'Unknown Channel'}" + if invite.max_age: + # Use invite.created_at if available, otherwise fall back to current time + created_time = ( + invite.created_at + if invite.created_at is not None + else discord.utils.utcnow() + ) + expires_at = created_time + datetime.timedelta(seconds=invite.max_age) + desc += f"\nExpires: {discord.utils.format_dt(expires_at, style='R')}" + if invite.max_uses: + desc += f"\nMax Uses: {invite.max_uses}" + action_desc = f"{user.mention} created an invite:\n{desc}" + color = discord.Color.dark_green() + # self._add_id_footer(embed, invite, obj_id=invite.id, id_name="Invite ID") # Footer set later elif entry.action == discord.AuditLogAction.invite_delete: - audit_event_key = "audit_invite_delete" - if not await self._check_log_enabled(guild.id, audit_event_key): return - title = "🛡️ Audit Log: Invite Deleted" - # Target is invite ID, before object has details - invite_code = entry.before.code - channel_id = entry.before.channel_id - channel_mention = f"<#{channel_id}>" if channel_id else "Unknown Channel" - action_desc = f"{user.mention} deleted invite `{invite_code}` for channel {channel_mention}" - color = discord.Color.dark_red() - # Cannot get invite ID after deletion easily, use code in footer later + audit_event_key = "audit_invite_delete" + if not await self._check_log_enabled(guild.id, audit_event_key): + return + title = "🛡️ Audit Log: Invite Deleted" + # Target is invite ID, before object has details + invite_code = entry.before.code + channel_id = entry.before.channel_id + channel_mention = f"<#{channel_id}>" if channel_id else "Unknown Channel" + action_desc = f"{user.mention} deleted invite `{invite_code}` for channel {channel_mention}" + color = discord.Color.dark_red() + # Cannot get invite ID after deletion easily, use code in footer later # --- Guild Update --- elif entry.action == discord.AuditLogAction.guild_update: - audit_event_key = "audit_guild_update" - if not await self._check_log_enabled(guild.id, audit_event_key): return - title = "🛡️ Audit Log: Guild Updated" - changes = [] - # Diffing guild properties - safely check attributes exist before comparing - if hasattr(entry.before, 'name') and hasattr(entry.after, 'name') and entry.before.name != entry.after.name: - changes.append(f"Name: `{entry.before.name}` → `{entry.after.name}`") - if hasattr(entry.before, 'description') and hasattr(entry.after, 'description') and entry.before.description != entry.after.description: - changes.append(f"Description Changed") - if hasattr(entry.before, 'icon') and hasattr(entry.after, 'icon') and entry.before.icon != entry.after.icon: - changes.append(f"Icon Changed") - if hasattr(entry.before, 'banner') and hasattr(entry.after, 'banner') and entry.before.banner != entry.after.banner: - changes.append(f"Banner Changed") - if hasattr(entry.before, 'owner') and hasattr(entry.after, 'owner') and entry.before.owner != entry.after.owner: - changes.append(f"Owner: {entry.before.owner.mention if entry.before.owner else 'None'} → {entry.after.owner.mention if entry.after.owner else 'None'}") - if hasattr(entry.before, 'verification_level') and hasattr(entry.after, 'verification_level') and entry.before.verification_level != entry.after.verification_level: - changes.append(f"Verification Level: `{entry.before.verification_level}` → `{entry.after.verification_level}`") - if hasattr(entry.before, 'explicit_content_filter') and hasattr(entry.after, 'explicit_content_filter') and entry.before.explicit_content_filter != entry.after.explicit_content_filter: - changes.append(f"Explicit Content Filter: `{entry.before.explicit_content_filter}` → `{entry.after.explicit_content_filter}`") - if hasattr(entry.before, 'system_channel') and hasattr(entry.after, 'system_channel') and entry.before.system_channel != entry.after.system_channel: - changes.append(f"System Channel Changed") - # Add more properties as needed + audit_event_key = "audit_guild_update" + if not await self._check_log_enabled(guild.id, audit_event_key): + return + title = "🛡️ Audit Log: Guild Updated" + changes = [] + # Diffing guild properties - safely check attributes exist before comparing + if ( + hasattr(entry.before, "name") + and hasattr(entry.after, "name") + and entry.before.name != entry.after.name + ): + changes.append(f"Name: `{entry.before.name}` → `{entry.after.name}`") + if ( + hasattr(entry.before, "description") + and hasattr(entry.after, "description") + and entry.before.description != entry.after.description + ): + changes.append(f"Description Changed") + if ( + hasattr(entry.before, "icon") + and hasattr(entry.after, "icon") + and entry.before.icon != entry.after.icon + ): + changes.append(f"Icon Changed") + if ( + hasattr(entry.before, "banner") + and hasattr(entry.after, "banner") + and entry.before.banner != entry.after.banner + ): + changes.append(f"Banner Changed") + if ( + hasattr(entry.before, "owner") + and hasattr(entry.after, "owner") + and entry.before.owner != entry.after.owner + ): + changes.append( + f"Owner: {entry.before.owner.mention if entry.before.owner else 'None'} → {entry.after.owner.mention if entry.after.owner else 'None'}" + ) + if ( + hasattr(entry.before, "verification_level") + and hasattr(entry.after, "verification_level") + and entry.before.verification_level != entry.after.verification_level + ): + changes.append( + f"Verification Level: `{entry.before.verification_level}` → `{entry.after.verification_level}`" + ) + if ( + hasattr(entry.before, "explicit_content_filter") + and hasattr(entry.after, "explicit_content_filter") + and entry.before.explicit_content_filter + != entry.after.explicit_content_filter + ): + changes.append( + f"Explicit Content Filter: `{entry.before.explicit_content_filter}` → `{entry.after.explicit_content_filter}`" + ) + if ( + hasattr(entry.before, "system_channel") + and hasattr(entry.after, "system_channel") + and entry.before.system_channel != entry.after.system_channel + ): + changes.append(f"System Channel Changed") + # Add more properties as needed - if changes: - action_desc = f"{user.mention} updated server settings:\n" + "\n".join(f"- {c}" for c in changes) - color = discord.Color.dark_purple() - # self._add_id_footer(embed, guild, id_name="Guild ID") # Footer set later - else: - # log.debug(f"Guild update detected for {guild.id} but no tracked changes found.") # Might still want to log feature changes etc. - return # Skip if no changes we track were made + if changes: + action_desc = f"{user.mention} updated server settings:\n" + "\n".join( + f"- {c}" for c in changes + ) + color = discord.Color.dark_purple() + # self._add_id_footer(embed, guild, id_name="Guild ID") # Footer set later + else: + # log.debug(f"Guild update detected for {guild.id} but no tracked changes found.") # Might still want to log feature changes etc. + return # Skip if no changes we track were made else: # Action is in relevant_actions but not specifically handled above - log.warning(f"Audit log action '{entry.action}' is relevant but not explicitly handled in _process_audit_log_entry.") + log.warning( + f"Audit log action '{entry.action}' is relevant but not explicitly handled in _process_audit_log_entry." + ) # Generic fallback log title = f"🛡️ Audit Log: {str(entry.action).replace('_', ' ').title()}" # Determine the generic audit key based on the action category if possible - generic_audit_key = f"audit_{str(entry.action).split('.')[0]}" # e.g., audit_member, audit_channel + generic_audit_key = f"audit_{str(entry.action).split('.')[0]}" # e.g., audit_member, audit_channel if generic_audit_key in ALL_EVENT_KEYS: - if not await self._check_log_enabled(guild.id, generic_audit_key): return + if not await self._check_log_enabled(guild.id, generic_audit_key): + return else: - log.warning(f"No specific or generic toggle key found for unhandled audit action '{entry.action}'. Logging anyway.") - # Or decide to return here if you only want explicitly toggled events logged + log.warning( + f"No specific or generic toggle key found for unhandled audit action '{entry.action}'. Logging anyway." + ) + # Or decide to return here if you only want explicitly toggled events logged title = f"🛡️ Audit Log: {str(entry.action).replace('_', ' ').title()}" action_desc = f"{user.mention} performed action `{entry.action}`" if target: - target_mention = getattr(target, 'mention', str(target)) + target_mention = getattr(target, "mention", str(target)) action_desc += f" on {target_mention}" # self._add_id_footer(embed, target, id_name="Target ID") # Footer set later color = discord.Color.light_grey() - - if not action_desc: # If no description was generated (e.g., skipped update), skip logging - # log.debug(f"Skipping audit log entry {entry.id} (action: {entry.action}) as no action description was generated.") - return + if ( + not action_desc + ): # If no description was generated (e.g., skipped update), skip logging + # log.debug(f"Skipping audit log entry {entry.id} (action: {entry.action}) as no action description was generated.") + return # Create the embed (title is set within the if/elif blocks now) embed = self._create_log_embed( title=title, description=action_desc.strip(), color=color, - author=user # The moderator/actor is the author of the log entry + author=user, # The moderator/actor is the author of the log entry ) if reason: - embed.add_field(name="Reason", value=reason[:1024], inline=False) # Limit reason length + embed.add_field( + name="Reason", value=reason[:1024], inline=False + ) # Limit reason length # Add relevant IDs to footer (target ID if available, otherwise just mod/entry ID) target_id_str = "" if target: target_id_str = f" | Target ID: {target.id}" elif entry.action == discord.AuditLogAction.role_delete: - target_id_str = f" | Role ID: {entry.target.id}" # Get ID from target even if object deleted + target_id_str = f" | Role ID: {entry.target.id}" # Get ID from target even if object deleted elif entry.action == discord.AuditLogAction.channel_delete: - target_id_str = f" | Channel ID: {entry.target.id}" + target_id_str = f" | Channel ID: {entry.target.id}" elif entry.action == discord.AuditLogAction.emoji_delete: - target_id_str = f" | Emoji ID: {entry.target.id}" + target_id_str = f" | Emoji ID: {entry.target.id}" elif entry.action == discord.AuditLogAction.invite_delete: - target_id_str = f" | Invite Code: {entry.before.code}" # Use code for deleted invites - - embed.set_footer(text=f"Audit Log Entry ID: {entry.id} | Moderator ID: {user.id}{target_id_str}") + target_id_str = ( + f" | Invite Code: {entry.before.code}" # Use code for deleted invites + ) + embed.set_footer( + text=f"Audit Log Entry ID: {entry.id} | Moderator ID: {user.id}{target_id_str}" + ) await self._send_log_embed(guild, embed) diff --git a/cogs/marriage_cog.py b/cogs/marriage_cog.py index 5ab8e28..e2ffc7a 100644 --- a/cogs/marriage_cog.py +++ b/cogs/marriage_cog.py @@ -12,10 +12,13 @@ MARRIAGES_FILE = "data/marriages.json" # Ensure the data directory exists os.makedirs(os.path.dirname(MARRIAGES_FILE), exist_ok=True) + class MarriageView(ui.View): """View for marriage proposal buttons""" - def __init__(self, cog: 'MarriageCog', proposer: discord.User, proposed_to: discord.User): + def __init__( + self, cog: "MarriageCog", proposer: discord.User, proposed_to: discord.User + ): super().__init__(timeout=300.0) # 5-minute timeout self.cog = cog self.proposer = proposer @@ -25,7 +28,9 @@ class MarriageView(ui.View): async def interaction_check(self, interaction: discord.Interaction) -> bool: """Only allow the proposed person to interact with the buttons""" if interaction.user.id != self.proposed_to.id: - await interaction.response.send_message("This proposal isn't for you to answer!", ephemeral=True) + await interaction.response.send_message( + "This proposal isn't for you to answer!", ephemeral=True + ) return False return True @@ -40,14 +45,18 @@ class MarriageView(ui.View): # Update the message await self.message.edit( content=f"💔 {self.proposed_to.mention} didn't respond to {self.proposer.mention}'s proposal in time.", - view=self + view=self, ) @discord.ui.button(label="Accept", style=discord.ButtonStyle.success, emoji="💍") - async def accept_button(self, interaction: discord.Interaction, button: discord.ui.Button): + async def accept_button( + self, interaction: discord.Interaction, button: discord.ui.Button + ): """Accept the marriage proposal""" # Create the marriage - success, message = await self.cog.create_marriage(self.proposer, self.proposed_to) + success, message = await self.cog.create_marriage( + self.proposer, self.proposed_to + ) # Disable all buttons for item in self.children: @@ -57,16 +66,15 @@ class MarriageView(ui.View): if success: await interaction.response.edit_message( content=f"💖 {self.proposed_to.mention} has accepted {self.proposer.mention}'s proposal! Congratulations on your marriage!", - view=self + view=self, ) else: - await interaction.response.edit_message( - content=f"❌ {message}", - view=self - ) + await interaction.response.edit_message(content=f"❌ {message}", view=self) @discord.ui.button(label="Decline", style=discord.ButtonStyle.danger, emoji="💔") - async def decline_button(self, interaction: discord.Interaction, button: discord.ui.Button): + async def decline_button( + self, interaction: discord.Interaction, button: discord.ui.Button + ): """Decline the marriage proposal""" # Disable all buttons for item in self.children: @@ -75,9 +83,10 @@ class MarriageView(ui.View): await interaction.response.edit_message( content=f"💔 {self.proposed_to.mention} has declined {self.proposer.mention}'s proposal.", - view=self + view=self, ) + class MarriageCog(commands.Cog): def __init__(self, bot): self.bot = bot @@ -107,18 +116,26 @@ class MarriageCog(commands.Cog): except Exception as e: print(f"Error saving marriages: {e}") - async def create_marriage(self, user1: discord.User, user2: discord.User) -> Tuple[bool, str]: + async def create_marriage( + self, user1: discord.User, user2: discord.User + ) -> Tuple[bool, str]: """Create a new marriage between two users""" # Check if either user is already married user1_id = user1.id user2_id = user2.id # Check if user1 is already married - if user1_id in self.marriages and self.marriages[user1_id]["status"] == "married": + if ( + user1_id in self.marriages + and self.marriages[user1_id]["status"] == "married" + ): return False, f"{user1.display_name} is already married!" # Check if user2 is already married - if user2_id in self.marriages and self.marriages[user2_id]["status"] == "married": + if ( + user2_id in self.marriages + and self.marriages[user2_id]["status"] == "married" + ): return False, f"{user2.display_name} is already married!" # Create marriage data @@ -128,14 +145,14 @@ class MarriageCog(commands.Cog): marriage_data = { "partner_id": user2_id, "marriage_date": marriage_date, - "status": "married" + "status": "married", } self.marriages[user1_id] = marriage_data marriage_data = { "partner_id": user1_id, "marriage_date": marriage_date, - "status": "married" + "status": "married", } self.marriages[user2_id] = marriage_data @@ -146,7 +163,10 @@ class MarriageCog(commands.Cog): async def divorce(self, user_id: int) -> Tuple[bool, str]: """End a marriage""" - if user_id not in self.marriages or self.marriages[user_id]["status"] != "married": + if ( + user_id not in self.marriages + or self.marriages[user_id]["status"] != "married" + ): return False, "You are not currently married!" # Get partner's ID @@ -194,7 +214,9 @@ class MarriageCog(commands.Cog): processed_pairs.add(pair) # Calculate days - marriage_date = datetime.datetime.fromisoformat(marriage_data["marriage_date"]) + marriage_date = datetime.datetime.fromisoformat( + marriage_data["marriage_date"] + ) current_date = datetime.datetime.now() delta = current_date - marriage_date days = delta.days @@ -204,22 +226,33 @@ class MarriageCog(commands.Cog): # Sort by days (descending) return sorted(active_marriages, key=lambda x: x[2], reverse=True) - @app_commands.command(name="propose", description="Propose marriage to another user") + @app_commands.command( + name="propose", description="Propose marriage to another user" + ) @app_commands.describe(user="The user you want to propose to") - async def propose_command(self, interaction: discord.Interaction, user: discord.User): + async def propose_command( + self, interaction: discord.Interaction, user: discord.User + ): """Propose marriage to another user""" proposer = interaction.user # Check if proposing to self if user.id == proposer.id: - await interaction.response.send_message("You can't propose to yourself!", ephemeral=True) + await interaction.response.send_message( + "You can't propose to yourself!", ephemeral=True + ) return # Check if proposer is already married - if proposer.id in self.marriages and self.marriages[proposer.id]["status"] == "married": + if ( + proposer.id in self.marriages + and self.marriages[proposer.id]["status"] == "married" + ): partner_id = self.marriages[proposer.id]["partner_id"] # Use fetch_user instead of get_member to work in both guild and DM contexts - partner = interaction.guild.get_member(partner_id) if interaction.guild else None + partner = ( + interaction.guild.get_member(partner_id) if interaction.guild else None + ) if not partner: # Fallback to bot's fetch_user if not found in guild or in DM context try: @@ -227,14 +260,18 @@ class MarriageCog(commands.Cog): except: pass partner_name = partner.display_name if partner else "someone" - await interaction.response.send_message(f"You're already married to {partner_name}!", ephemeral=True) + await interaction.response.send_message( + f"You're already married to {partner_name}!", ephemeral=True + ) return # Check if proposed person is already married if user.id in self.marriages and self.marriages[user.id]["status"] == "married": partner_id = self.marriages[user.id]["partner_id"] # Use fetch_user instead of get_member to work in both guild and DM contexts - partner = interaction.guild.get_member(partner_id) if interaction.guild else None + partner = ( + interaction.guild.get_member(partner_id) if interaction.guild else None + ) if not partner: # Fallback to bot's fetch_user if not found in guild or in DM context try: @@ -242,7 +279,10 @@ class MarriageCog(commands.Cog): except: pass partner_name = partner.display_name if partner else "someone" - await interaction.response.send_message(f"{user.display_name} is already married to {partner_name}!", ephemeral=True) + await interaction.response.send_message( + f"{user.display_name} is already married to {partner_name}!", + ephemeral=True, + ) return # Create the proposal view @@ -251,19 +291,26 @@ class MarriageCog(commands.Cog): # Send the proposal await interaction.response.send_message( f"💍 {proposer.mention} has proposed to {user.mention}! Will they accept?", - view=view + view=view, ) # Store the message for timeout handling view.message = await interaction.original_response() - @app_commands.command(name="marriage", description="View your current marriage status") + @app_commands.command( + name="marriage", description="View your current marriage status" + ) async def marriage_command(self, interaction: discord.Interaction): """View your current marriage status""" user_id = interaction.user.id - if user_id not in self.marriages or self.marriages[user_id]["status"] != "married": - await interaction.response.send_message("You are not currently married.", ephemeral=False) + if ( + user_id not in self.marriages + or self.marriages[user_id]["status"] != "married" + ): + await interaction.response.send_message( + "You are not currently married.", ephemeral=False + ) return # Get marriage info @@ -271,25 +318,34 @@ class MarriageCog(commands.Cog): partner_id = marriage_data["partner_id"] # Use fetch_user instead of get_member to work in both guild and DM contexts - partner = interaction.guild.get_member(partner_id) if interaction.guild else None + partner = ( + interaction.guild.get_member(partner_id) if interaction.guild else None + ) if not partner: # Fallback to bot's fetch_user if not found in guild or in DM context try: partner = await self.bot.fetch_user(partner_id) except: pass - partner_name = partner.display_name if partner else f"Unknown User ({partner_id})" + partner_name = ( + partner.display_name if partner else f"Unknown User ({partner_id})" + ) # Calculate days days = self.get_marriage_days(user_id) # Create embed - embed = discord.Embed( - title="💖 Marriage Status", - color=discord.Color.pink() + embed = discord.Embed(title="💖 Marriage Status", color=discord.Color.pink()) + embed.add_field( + name="Married To", + value=partner.mention if partner else partner_name, + inline=False, + ) + embed.add_field( + name="Marriage Date", + value=marriage_data["marriage_date"].split("T")[0], + inline=True, ) - embed.add_field(name="Married To", value=partner.mention if partner else partner_name, inline=False) - embed.add_field(name="Marriage Date", value=marriage_data["marriage_date"].split("T")[0], inline=True) embed.add_field(name="Days Married", value=str(days), inline=True) await interaction.response.send_message(embed=embed, ephemeral=False) @@ -300,15 +356,22 @@ class MarriageCog(commands.Cog): user_id = interaction.user.id # Check if user is married - if user_id not in self.marriages or self.marriages[user_id]["status"] != "married": - await interaction.response.send_message("You are not currently married.", ephemeral=True) + if ( + user_id not in self.marriages + or self.marriages[user_id]["status"] != "married" + ): + await interaction.response.send_message( + "You are not currently married.", ephemeral=True + ) return # Get partner info partner_id = self.marriages[user_id]["partner_id"] # Use fetch_user instead of get_member to work in both guild and DM contexts - partner = interaction.guild.get_member(partner_id) if interaction.guild else None + partner = ( + interaction.guild.get_member(partner_id) if interaction.guild else None + ) if not partner: # Fallback to bot's fetch_user if not found in guild or in DM context try: @@ -321,7 +384,10 @@ class MarriageCog(commands.Cog): success, message = await self.divorce(user_id) if success: - await interaction.response.send_message(f"💔 {interaction.user.mention} has divorced {partner_name}. The marriage has ended.", ephemeral=False) + await interaction.response.send_message( + f"💔 {interaction.user.mention} has divorced {partner_name}. The marriage has ended.", + ephemeral=False, + ) else: await interaction.response.send_message(message, ephemeral=True) @@ -331,21 +397,27 @@ class MarriageCog(commands.Cog): marriages = self.get_all_marriages() if not marriages: - await interaction.response.send_message("There are no active marriages.", ephemeral=False) + await interaction.response.send_message( + "There are no active marriages.", ephemeral=False + ) return # Create embed embed = discord.Embed( title="💖 Marriage Leaderboard", description="Marriages ranked by duration", - color=discord.Color.pink() + color=discord.Color.pink(), ) # Add top 10 marriages for i, (user1_id, user2_id, days) in enumerate(marriages[:10], 1): # Use fetch_user instead of get_member to work in both guild and DM contexts - user1 = interaction.guild.get_member(user1_id) if interaction.guild else None - user2 = interaction.guild.get_member(user2_id) if interaction.guild else None + user1 = ( + interaction.guild.get_member(user1_id) if interaction.guild else None + ) + user2 = ( + interaction.guild.get_member(user2_id) if interaction.guild else None + ) # Fallback to bot's fetch_user if not found in guild or in DM context if not user1: @@ -366,10 +438,11 @@ class MarriageCog(commands.Cog): embed.add_field( name=f"{i}. {user1_name} & {user2_name}", value=f"{days} days", - inline=False + inline=False, ) await interaction.response.send_message(embed=embed, ephemeral=False) + async def setup(bot: commands.Bot): await bot.add_cog(MarriageCog(bot)) diff --git a/cogs/message_cog.py b/cogs/message_cog.py index d907868..4801209 100644 --- a/cogs/message_cog.py +++ b/cogs/message_cog.py @@ -10,11 +10,12 @@ from .rp_messages import ( get_headpat_messages, get_cumshot_messages, get_kiss_messages, - get_hug_messages + get_hug_messages, ) log = logging.getLogger(__name__) + class MessageCog(commands.Cog): def __init__(self, bot): self.bot = bot @@ -25,13 +26,14 @@ class MessageCog(commands.Cog): async def _ensure_usage_table_exists(self): """Ensure the command usage counters table exists.""" - if not hasattr(self.bot, 'pg_pool') or not self.bot.pg_pool: + if not hasattr(self.bot, "pg_pool") or not self.bot.pg_pool: log.warning("Database pool not available for usage tracking.") return False try: async with self.bot.pg_pool.acquire() as conn: - await conn.execute(""" + await conn.execute( + """ CREATE TABLE IF NOT EXISTS command_usage_counters ( user1_id BIGINT NOT NULL, user2_id BIGINT NOT NULL, @@ -39,46 +41,65 @@ class MessageCog(commands.Cog): usage_count INTEGER NOT NULL DEFAULT 1, PRIMARY KEY (user1_id, user2_id, command_name) ) - """) + """ + ) return True except Exception as e: log.error(f"Error creating usage counters table: {e}") return False - async def _increment_usage_counter(self, user1_id: int, user2_id: int, command_name: str): + async def _increment_usage_counter( + self, user1_id: int, user2_id: int, command_name: str + ): """Increment the usage counter for a command between two users.""" if not await self._ensure_usage_table_exists(): return try: async with self.bot.pg_pool.acquire() as conn: - await conn.execute(""" + await conn.execute( + """ INSERT INTO command_usage_counters (user1_id, user2_id, command_name, usage_count) VALUES ($1, $2, $3, 1) ON CONFLICT (user1_id, user2_id, command_name) DO UPDATE SET usage_count = command_usage_counters.usage_count + 1 - """, user1_id, user2_id, command_name) - log.debug(f"Incremented usage counter for {command_name} between users {user1_id} and {user2_id}") + """, + user1_id, + user2_id, + command_name, + ) + log.debug( + f"Incremented usage counter for {command_name} between users {user1_id} and {user2_id}" + ) except Exception as e: log.error(f"Error incrementing usage counter: {e}") - async def _get_usage_count(self, user1_id: int, user2_id: int, command_name: str) -> int: + async def _get_usage_count( + self, user1_id: int, user2_id: int, command_name: str + ) -> int: """Get the usage count for a command between two users.""" if not await self._ensure_usage_table_exists(): return 0 try: async with self.bot.pg_pool.acquire() as conn: - count = await conn.fetchval(""" + count = await conn.fetchval( + """ SELECT usage_count FROM command_usage_counters WHERE user1_id = $1 AND user2_id = $2 AND command_name = $3 - """, user1_id, user2_id, command_name) + """, + user1_id, + user2_id, + command_name, + ) return count if count is not None else 0 except Exception as e: log.error(f"Error getting usage count: {e}") return 0 - async def _get_bidirectional_usage_counts(self, user1_id: int, user2_id: int, command_name: str) -> tuple[int, int]: + async def _get_bidirectional_usage_counts( + self, user1_id: int, user2_id: int, command_name: str + ) -> tuple[int, int]: """Get the usage counts for a command in both directions between two users. Returns: @@ -90,19 +111,31 @@ class MessageCog(commands.Cog): try: async with self.bot.pg_pool.acquire() as conn: # Get count for user1 -> user2 - count_1_to_2 = await conn.fetchval(""" + count_1_to_2 = await conn.fetchval( + """ SELECT usage_count FROM command_usage_counters WHERE user1_id = $1 AND user2_id = $2 AND command_name = $3 - """, user1_id, user2_id, command_name) + """, + user1_id, + user2_id, + command_name, + ) # Get count for user2 -> user1 - count_2_to_1 = await conn.fetchval(""" + count_2_to_1 = await conn.fetchval( + """ SELECT usage_count FROM command_usage_counters WHERE user1_id = $1 AND user2_id = $2 AND command_name = $3 - """, user2_id, user1_id, command_name) + """, + user2_id, + user1_id, + command_name, + ) - return (count_1_to_2 if count_1_to_2 is not None else 0, - count_2_to_1 if count_2_to_1 is not None else 0) + return ( + count_1_to_2 if count_1_to_2 is not None else 0, + count_2_to_1 if count_2_to_1 is not None else 0, + ) except Exception as e: log.error(f"Error getting bidirectional usage counts: {e}") return 0, 0 @@ -116,17 +149,23 @@ class MessageCog(commands.Cog): # --- RP Group --- rp = app_commands.Group(name="rp", description="Roleplay commands") - @rp.command(name="molest", description="Send a hardcoded message to the mentioned user") + @rp.command( + name="molest", description="Send a hardcoded message to the mentioned user" + ) @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.describe(member="The user to send the message to") - async def molest_slash(self, interaction: discord.Interaction, member: discord.User): + async def molest_slash( + self, interaction: discord.Interaction, member: discord.User + ): """Slash command version of message.""" # Track usage between the two users await self._increment_usage_counter(interaction.user.id, member.id, "molest") # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(interaction.user.id, member.id, "molest") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + interaction.user.id, member.id, "molest" + ) response = await self._message_logic(member.mention) response += f"\n-# {interaction.user.display_name} has molested {member.display_name} {caller_to_target} {self.plural('time', caller_to_target)}" @@ -141,7 +180,9 @@ class MessageCog(commands.Cog): await self._increment_usage_counter(ctx.author.id, member.id, "molest") # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(ctx.author.id, member.id, "molest") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + ctx.author.id, member.id, "molest" + ) response = await self._message_logic(member.mention) response += f"\n-# {ctx.author.display_name} has molested {member.display_name} {caller_to_target} {self.plural('time', caller_to_target)}" @@ -149,7 +190,10 @@ class MessageCog(commands.Cog): response += f", {member.display_name} has molested {ctx.author.display_name} {target_to_caller} {self.plural('time', target_to_caller)}" await ctx.reply(response) - @rp.command(name="rape", description="Sends a message stating the author raped the mentioned user.") + @rp.command( + name="rape", + description="Sends a message stating the author raped the mentioned user.", + ) @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.describe(member="The user to mention in the message") @@ -159,9 +203,13 @@ class MessageCog(commands.Cog): await self._increment_usage_counter(interaction.user.id, member.id, "rape") # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(interaction.user.id, member.id, "rape") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + interaction.user.id, member.id, "rape" + ) - response = random.choice(get_rape_messages(interaction.user.mention, member.mention)) + response = random.choice( + get_rape_messages(interaction.user.mention, member.mention) + ) response += f"\n-# {interaction.user.display_name} has raped {member.display_name} {caller_to_target} {self.plural('time', caller_to_target)}" if target_to_caller > 0: response += f", {member.display_name} has raped {interaction.user.display_name} {target_to_caller} {self.plural('time', target_to_caller)}" @@ -174,7 +222,9 @@ class MessageCog(commands.Cog): await self._increment_usage_counter(ctx.author.id, member.id, "rape") # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(ctx.author.id, member.id, "rape") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + ctx.author.id, member.id, "rape" + ) response = random.choice(get_rape_messages(ctx.author.mention, member.mention)) response += f"\n-# {ctx.author.display_name} has raped {member.display_name} {caller_to_target} {self.plural('time', caller_to_target)}" @@ -182,7 +232,9 @@ class MessageCog(commands.Cog): response += f", {member.display_name} has raped {ctx.author.display_name} {target_to_caller} {self.plural('time', target_to_caller)}" await ctx.reply(response) - @rp.command(name="sex", description="Send a normal sex message to the mentioned user") + @rp.command( + name="sex", description="Send a normal sex message to the mentioned user" + ) @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.describe(member="The user to send the message to") @@ -192,9 +244,13 @@ class MessageCog(commands.Cog): await self._increment_usage_counter(interaction.user.id, member.id, "sex") # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(interaction.user.id, member.id, "sex") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + interaction.user.id, member.id, "sex" + ) - response = random.choice(get_sex_messages(interaction.user.mention, member.mention)) + response = random.choice( + get_sex_messages(interaction.user.mention, member.mention) + ) response += f"\n-# {interaction.user.display_name} and {member.display_name} have had sex {caller_to_target} {self.plural('time', caller_to_target)}" if target_to_caller > 0: response += f", {member.display_name} and {interaction.user.display_name} have had sex {target_to_caller} {self.plural('time', target_to_caller)}" @@ -207,7 +263,9 @@ class MessageCog(commands.Cog): await self._increment_usage_counter(ctx.author.id, member.id, "sex") # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(ctx.author.id, member.id, "sex") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + ctx.author.id, member.id, "sex" + ) response = random.choice(get_sex_messages(ctx.author.mention, member.mention)) response += f"\n-# {ctx.author.display_name} and {member.display_name} have had sex {caller_to_target} {self.plural('time', caller_to_target)}" @@ -215,21 +273,32 @@ class MessageCog(commands.Cog): response += f", {member.display_name} and {ctx.author.display_name} have had sex {target_to_caller} {self.plural('time', target_to_caller)}" await ctx.reply(response) - @rp.command(name="headpat", description="Send a wholesome headpat message to the mentioned user") + @rp.command( + name="headpat", + description="Send a wholesome headpat message to the mentioned user", + ) @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.describe(member="The user to send the message to") - async def headpat_slash(self, interaction: discord.Interaction, member: discord.User): + async def headpat_slash( + self, interaction: discord.Interaction, member: discord.User + ): """Slash command version of headpat.""" # Track usage between the two users await self._increment_usage_counter(interaction.user.id, member.id, "headpat") # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(interaction.user.id, member.id, "headpat") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + interaction.user.id, member.id, "headpat" + ) - response = random.choice(get_headpat_messages(interaction.user.mention, member.mention)) + response = random.choice( + get_headpat_messages(interaction.user.mention, member.mention) + ) # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(interaction.user.id, member.id, "headpat") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + interaction.user.id, member.id, "headpat" + ) response += f"\n-# {interaction.user.display_name} has headpatted {member.display_name} {caller_to_target} {self.plural('time', caller_to_target)}" if target_to_caller > 0: @@ -243,24 +312,36 @@ class MessageCog(commands.Cog): await self._increment_usage_counter(ctx.author.id, member.id, "headpat") # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(ctx.author.id, member.id, "headpat") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + ctx.author.id, member.id, "headpat" + ) - response = random.choice(get_headpat_messages(ctx.author.mention, member.mention)) + response = random.choice( + get_headpat_messages(ctx.author.mention, member.mention) + ) response += f"\n-# {ctx.author.display_name} has headpatted {member.display_name} {caller_to_target} {self.plural('time', caller_to_target)}" if target_to_caller > 0: response += f", {member.display_name} has headpatted {ctx.author.display_name} {target_to_caller} {self.plural('time', target_to_caller)}" await ctx.reply(response) - @rp.command(name="cumshot", description="Send a cumshot message to the mentioned user") + @rp.command( + name="cumshot", description="Send a cumshot message to the mentioned user" + ) @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.describe(member="The user to send the message to") - async def cumshot_slash(self, interaction: discord.Interaction, member: discord.User): + async def cumshot_slash( + self, interaction: discord.Interaction, member: discord.User + ): """Slash command version of cumshot.""" await self._increment_usage_counter(interaction.user.id, member.id, "cumshot") - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(interaction.user.id, member.id, "cumshot") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + interaction.user.id, member.id, "cumshot" + ) - response = random.choice(get_cumshot_messages(interaction.user.mention, member.mention)) + response = random.choice( + get_cumshot_messages(interaction.user.mention, member.mention) + ) response += f"\n-# {interaction.user.display_name} has came on {member.display_name} {caller_to_target} {self.plural('time', caller_to_target)}" if target_to_caller > 0: response += f", {member.display_name} has came on {interaction.user.display_name} {target_to_caller} {self.plural('time', target_to_caller)}" @@ -270,14 +351,21 @@ class MessageCog(commands.Cog): async def cumshot_legacy(self, ctx: commands.Context, member: discord.User): """Legacy command version of cumshot.""" await self._increment_usage_counter(ctx.author.id, member.id, "cumshot") - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(ctx.author.id, member.id, "cumshot") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + ctx.author.id, member.id, "cumshot" + ) - response = random.choice(get_cumshot_messages(ctx.author.mention, member.mention)) + response = random.choice( + get_cumshot_messages(ctx.author.mention, member.mention) + ) response += f"\n-# {ctx.author.display_name} has came on {member.display_name} {caller_to_target} {self.plural('time', caller_to_target)}" if target_to_caller > 0: response += f", {member.display_name} has came on {ctx.author.display_name} {target_to_caller} {self.plural('time', target_to_caller)}" await ctx.reply(response) - @rp.command(name="kiss", description="Send a wholesome kiss message to the mentioned user") + + @rp.command( + name="kiss", description="Send a wholesome kiss message to the mentioned user" + ) @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.describe(member="The user to send the message to") @@ -287,9 +375,13 @@ class MessageCog(commands.Cog): await self._increment_usage_counter(interaction.user.id, member.id, "kiss") # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(interaction.user.id, member.id, "kiss") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + interaction.user.id, member.id, "kiss" + ) - response = random.choice(get_kiss_messages(interaction.user.mention, member.mention)) + response = random.choice( + get_kiss_messages(interaction.user.mention, member.mention) + ) response += f"\n-# {interaction.user.display_name} and {member.display_name} have kissed {caller_to_target} {self.plural('time', caller_to_target)}" if target_to_caller > 0: response += f", {member.display_name} and {interaction.user.display_name} have kissed {target_to_caller} {self.plural('time', target_to_caller)}" @@ -302,7 +394,9 @@ class MessageCog(commands.Cog): await self._increment_usage_counter(ctx.author.id, member.id, "kiss") # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(ctx.author.id, member.id, "kiss") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + ctx.author.id, member.id, "kiss" + ) response = random.choice(get_kiss_messages(ctx.author.mention, member.mention)) response += f"\n-# {ctx.author.display_name} and {member.display_name} have kissed {caller_to_target} {self.plural('time', caller_to_target)}" @@ -310,7 +404,9 @@ class MessageCog(commands.Cog): response += f", {member.display_name} and {ctx.author.display_name} have kissed {target_to_caller} {self.plural('time', target_to_caller)}" await ctx.reply(response) - @rp.command(name="hug", description="Send a wholesome hug message to the mentioned user") + @rp.command( + name="hug", description="Send a wholesome hug message to the mentioned user" + ) @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.describe(member="The user to send the message to") @@ -320,9 +416,13 @@ class MessageCog(commands.Cog): await self._increment_usage_counter(interaction.user.id, member.id, "hug") # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(interaction.user.id, member.id, "hug") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + interaction.user.id, member.id, "hug" + ) - response = random.choice(get_hug_messages(interaction.user.mention, member.mention)) + response = random.choice( + get_hug_messages(interaction.user.mention, member.mention) + ) response += f"\n-# {interaction.user.display_name} and {member.display_name} have hugged {caller_to_target} {self.plural('time', caller_to_target)}" if target_to_caller > 0: response += f", {member.display_name} and {interaction.user.display_name} have hugged {target_to_caller} {self.plural('time', target_to_caller)}" @@ -335,45 +435,76 @@ class MessageCog(commands.Cog): await self._increment_usage_counter(ctx.author.id, member.id, "hug") # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(ctx.author.id, member.id, "hug") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + ctx.author.id, member.id, "hug" + ) response = random.choice(get_hug_messages(ctx.author.mention, member.mention)) response += f"\n-# {ctx.author.display_name} and {member.display_name} have hugged {caller_to_target} {self.plural('time', caller_to_target)}" if target_to_caller > 0: response += f", {member.display_name} and {ctx.author.display_name} have hugged {target_to_caller} {self.plural('time', target_to_caller)}" await ctx.reply(response) + # --- Memes Group --- memes = app_commands.Group(name="memes", description="Meme and copypasta commands") - @memes.command(name="seals", description="What the fuck did you just fucking say about me, you little bitch?") + @memes.command( + name="seals", + description="What the fuck did you just fucking say about me, you little bitch?", + ) @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) async def seals_slash(self, interaction: discord.Interaction): - await interaction.response.send_message("What the fuck did you just fucking say about me, you little bitch? I'll have you know I graduated top of my class in the Navy Seals, and I've been involved in numerous secret raids on Al-Quaeda, and I have over 300 confirmed kills. I am trained in gorilla warfare and I'm the top sniper in the entire US armed forces. You are nothing to me but just another target. I will wipe you the fuck out with precision the likes of which has never been seen before on this Earth, mark my fucking words. You think you can get away with saying that shit to me over the Internet? Think again, fucker. As we speak I am contacting my secret network of spies across the USA and your IP is being traced right now so you better prepare for the storm, maggot. The storm that wipes out the pathetic little thing you call your life. You're fucking dead, kid. I can be anywhere, anytime, and I can kill you in over seven hundred ways, and that's just with my bare hands. Not only am I extensively trained in unarmed combat, but I have access to the entire arsenal of the United States Marine Corps and I will use it to its full extent to wipe your miserable ass off the face of the continent, you little shit. If only you could have known what unholy retribution your little \"clever\" comment was about to bring down upon you, maybe you would have held your fucking tongue. But you couldn't, you didn't, and now you're paying the price, you goddamn idiot. I will shit fury all over you and you will drown in it. You're fucking dead, kiddo.") + await interaction.response.send_message( + "What the fuck did you just fucking say about me, you little bitch? I'll have you know I graduated top of my class in the Navy Seals, and I've been involved in numerous secret raids on Al-Quaeda, and I have over 300 confirmed kills. I am trained in gorilla warfare and I'm the top sniper in the entire US armed forces. You are nothing to me but just another target. I will wipe you the fuck out with precision the likes of which has never been seen before on this Earth, mark my fucking words. You think you can get away with saying that shit to me over the Internet? Think again, fucker. As we speak I am contacting my secret network of spies across the USA and your IP is being traced right now so you better prepare for the storm, maggot. The storm that wipes out the pathetic little thing you call your life. You're fucking dead, kid. I can be anywhere, anytime, and I can kill you in over seven hundred ways, and that's just with my bare hands. Not only am I extensively trained in unarmed combat, but I have access to the entire arsenal of the United States Marine Corps and I will use it to its full extent to wipe your miserable ass off the face of the continent, you little shit. If only you could have known what unholy retribution your little \"clever\" comment was about to bring down upon you, maybe you would have held your fucking tongue. But you couldn't, you didn't, and now you're paying the price, you goddamn idiot. I will shit fury all over you and you will drown in it. You're fucking dead, kiddo." + ) - @commands.command(name="seals", help="What the fuck did you just fucking say about me, you little bitch?") # Assuming you want to keep this check for the legacy command + @commands.command( + name="seals", + help="What the fuck did you just fucking say about me, you little bitch?", + ) # Assuming you want to keep this check for the legacy command async def seals_legacy(self, ctx): - await ctx.send("What the fuck did you just fucking say about me, you little bitch? I'll have you know I graduated top of my class in the Navy Seals, and I've been involved in numerous secret raids on Al-Quaeda, and I have over 300 confirmed kills. I am trained in gorilla warfare and I'm the top sniper in the entire US armed forces. You are nothing to me but just another target. I will wipe you the fuck out with precision the likes of which has never been seen before on this Earth, mark my fucking words. You think you can get away with saying that shit to me over the Internet? Think again, fucker. As we speak I am contacting my secret network of spies across the USA and your IP is being traced right now so you better prepare for the storm, maggot. The storm that wipes out the pathetic little thing you call your life. You're fucking dead, kid. I can be anywhere, anytime, and I can kill you in over seven hundred ways, and that's just with my bare hands. Not only am I extensively trained in unarmed combat, but I have access to the entire arsenal of the United States Marine Corps and I will use it to its full extent to wipe your miserable ass off the face of the continent, you little shit. If only you could have known what unholy retribution your little \"clever\" comment was about to bring down upon you, maybe you would have held your fucking tongue. But you couldn't, you didn't, and now you're paying the price, you goddamn idiot. I will shit fury all over you and you will drown in it. You're fucking dead, kiddo.") + await ctx.send( + "What the fuck did you just fucking say about me, you little bitch? I'll have you know I graduated top of my class in the Navy Seals, and I've been involved in numerous secret raids on Al-Quaeda, and I have over 300 confirmed kills. I am trained in gorilla warfare and I'm the top sniper in the entire US armed forces. You are nothing to me but just another target. I will wipe you the fuck out with precision the likes of which has never been seen before on this Earth, mark my fucking words. You think you can get away with saying that shit to me over the Internet? Think again, fucker. As we speak I am contacting my secret network of spies across the USA and your IP is being traced right now so you better prepare for the storm, maggot. The storm that wipes out the pathetic little thing you call your life. You're fucking dead, kid. I can be anywhere, anytime, and I can kill you in over seven hundred ways, and that's just with my bare hands. Not only am I extensively trained in unarmed combat, but I have access to the entire arsenal of the United States Marine Corps and I will use it to its full extent to wipe your miserable ass off the face of the continent, you little shit. If only you could have known what unholy retribution your little \"clever\" comment was about to bring down upon you, maybe you would have held your fucking tongue. But you couldn't, you didn't, and now you're paying the price, you goddamn idiot. I will shit fury all over you and you will drown in it. You're fucking dead, kiddo." + ) - @memes.command(name="notlikeus", description="Honestly i think They Not Like Us is the only mumble rap song that is good") + @memes.command( + name="notlikeus", + description="Honestly i think They Not Like Us is the only mumble rap song that is good", + ) @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) async def notlikeus_slash(self, interaction: discord.Interaction): - await interaction.response.send_message("Honestly i think They Not Like Us is the only mumble rap song that is good, because it calls out Drake for being a Diddy blud") + await interaction.response.send_message( + "Honestly i think They Not Like Us is the only mumble rap song that is good, because it calls out Drake for being a Diddy blud" + ) - @commands.command(name="notlikeus", help="Honestly i think They Not Like Us is the only mumble rap song that is good") # Assuming you want to keep this check for the legacy command + @commands.command( + name="notlikeus", + help="Honestly i think They Not Like Us is the only mumble rap song that is good", + ) # Assuming you want to keep this check for the legacy command async def notlikeus_legacy(self, ctx): - await ctx.send("Honestly i think They Not Like Us is the only mumble rap song that is good, because it calls out Drake for being a Diddy blud") + await ctx.send( + "Honestly i think They Not Like Us is the only mumble rap song that is good, because it calls out Drake for being a Diddy blud" + ) @memes.command(name="pmo", description="icl u pmo") @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) async def pmo_slash(self, interaction: discord.Interaction): - await interaction.response.send_message("icl u pmo n ts pmo sm ngl r u fr rn b fr I h8 bein diff idek anm mn js I h8 ts y r u so b so fr w me rn cz lol oms icl ts pmo sm n sb rn ngl, r u srsly srs n fr rn vro? lol atp js qt") + await interaction.response.send_message( + "icl u pmo n ts pmo sm ngl r u fr rn b fr I h8 bein diff idek anm mn js I h8 ts y r u so b so fr w me rn cz lol oms icl ts pmo sm n sb rn ngl, r u srsly srs n fr rn vro? lol atp js qt" + ) - @commands.command(name="pmo", help="icl u pmo n ts pmo sm ngl r u fr rn b fr I h8 bein diff idek anm mn js I h8 ts y r u so b so fr w me rn cz lol oms icl ts pmo sm n sb rn ngl, r u srsly srs n fr rn vro? lol atp js qt") + @commands.command( + name="pmo", + help="icl u pmo n ts pmo sm ngl r u fr rn b fr I h8 bein diff idek anm mn js I h8 ts y r u so b so fr w me rn cz lol oms icl ts pmo sm n sb rn ngl, r u srsly srs n fr rn vro? lol atp js qt", + ) async def pmo_legacy(self, ctx: commands.Context): - await ctx.send("icl u pmo n ts pmo sm ngl r u fr rn b fr I h8 bein diff idek anm mn js I h8 ts y r u so b so fr w me rn cz lol oms icl ts pmo sm n sb rn ngl, r u srsly srs n fr rn vro? lol atp js qt") + await ctx.send( + "icl u pmo n ts pmo sm ngl r u fr rn b fr I h8 bein diff idek anm mn js I h8 ts y r u so b so fr w me rn cz lol oms icl ts pmo sm n sb rn ngl, r u srsly srs n fr rn vro? lol atp js qt" + ) + async def setup(bot: commands.Bot): await bot.add_cog(MessageCog(bot)) diff --git a/cogs/message_scraper_cog.py b/cogs/message_scraper_cog.py index 41404d1..039fa2d 100644 --- a/cogs/message_scraper_cog.py +++ b/cogs/message_scraper_cog.py @@ -2,6 +2,7 @@ import discord from discord.ext import commands import io + class MessageScraperCog(commands.Cog): def __init__(self, bot): self.bot = bot @@ -16,7 +17,9 @@ class MessageScraperCog(commands.Cog): # The user wants exactly 'limit' messages, excluding bots and empty content. # We need to fetch more than 'limit' and then filter. # Set a reasonable max_fetch_limit to prevent excessive fetching in very sparse channels. - max_fetch_limit = limit * 5 if limit * 5 < 10000 else 10000 # Fetch up to 5x the limit, or 1000, whichever is smaller + max_fetch_limit = ( + limit * 5 if limit * 5 < 10000 else 10000 + ) # Fetch up to 5x the limit, or 1000, whichever is smaller messages_data = [] fetched_count = 0 @@ -28,7 +31,9 @@ class MessageScraperCog(commands.Cog): reply_info = "" if message.reference and message.reference.message_id: try: - replied_message = await ctx.channel.fetch_message(message.reference.message_id) + replied_message = await ctx.channel.fetch_message( + message.reference.message_id + ) reply_info = f" (In reply to: '{replied_message.author.display_name}: {replied_message.content[:50]}...')" except discord.NotFound: reply_info = " (In reply to: [Original message not found])" @@ -43,23 +48,24 @@ class MessageScraperCog(commands.Cog): if len(messages_data) >= limit: break - + if not messages_data: return await ctx.send("No valid messages found matching the criteria.") - + # Trim messages_data to the requested limit if more were collected messages_data = messages_data[:limit] output_content = "\n".join(messages_data) - + # Create a file-like object from the string content - file_data = io.BytesIO(output_content.encode('utf-8')) - + file_data = io.BytesIO(output_content.encode("utf-8")) + # Send the file await ctx.send( f"Here are the last {len(messages_data)} messages from this channel (excluding bots):", - file=discord.File(file_data, filename="scraped_messages.txt") + file=discord.File(file_data, filename="scraped_messages.txt"), ) + async def setup(bot): await bot.add_cog(MessageScraperCog(bot)) diff --git a/cogs/mod_application_cog.py b/cogs/mod_application_cog.py index 8cf1433..e1f75eb 100644 --- a/cogs/mod_application_cog.py +++ b/cogs/mod_application_cog.py @@ -9,7 +9,7 @@ from typing import Optional, List, Dict, Any, Union, Literal, Tuple # Configure logging logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) # Ensure info messages are captured +logger.setLevel(logging.INFO) # Ensure info messages are captured # Application statuses APPLICATION_STATUS = Literal["PENDING", "APPROVED", "REJECTED", "UNDER_REVIEW"] @@ -51,38 +51,39 @@ DEFAULT_QUESTIONS = [ "label": "How old are you?", "style": discord.TextStyle.short, "required": True, - "max_length": 10 + "max_length": 10, }, { "id": "experience", "label": "Previous moderation experience?", "style": discord.TextStyle.paragraph, "required": True, - "max_length": 1000 + "max_length": 1000, }, { "id": "time_available", "label": "Hours per week available for moderation?", "style": discord.TextStyle.short, "required": True, - "max_length": 50 + "max_length": 50, }, { "id": "timezone", "label": "What is your timezone?", "style": discord.TextStyle.short, "required": True, - "max_length": 50 + "max_length": 50, }, { "id": "why_mod", "label": "Why do you want to be a moderator?", "style": discord.TextStyle.paragraph, "required": True, - "max_length": 1000 - } + "max_length": 1000, + }, ] + class ModApplicationModal(ui.Modal): """Modal for moderator application form""" @@ -100,7 +101,7 @@ class ModApplicationModal(ui.Modal): style=q["style"], required=q.get("required", True), max_length=q.get("max_length", 1000), - placeholder=q.get("placeholder", "") + placeholder=q.get("placeholder", ""), ) # Store the question ID as a custom attribute text_input.custom_id = q["id"] @@ -115,22 +116,27 @@ class ModApplicationModal(ui.Modal): form_data[child.custom_id] = child.value # Submit application to database - success = await self.cog.submit_application(interaction.guild_id, interaction.user.id, form_data) + success = await self.cog.submit_application( + interaction.guild_id, interaction.user.id, form_data + ) if success: await interaction.response.send_message( "✅ Your moderator application has been submitted successfully! You will be notified when it's reviewed.", - ephemeral=True + ephemeral=True, ) # Notify staff in the review channel - await self.cog.notify_new_application(interaction.guild, interaction.user, form_data) + await self.cog.notify_new_application( + interaction.guild, interaction.user, form_data + ) else: await interaction.response.send_message( "❌ There was an error submitting your application. You may have an existing application pending review.", - ephemeral=True + ephemeral=True, ) + class ApplicationReviewView(ui.View): """View for reviewing moderator applications""" @@ -139,29 +145,30 @@ class ApplicationReviewView(ui.View): self.cog = cog self.application_data = application_data - @ui.button(label="Approve", style=discord.ButtonStyle.green, custom_id="approve_application") + @ui.button( + label="Approve", + style=discord.ButtonStyle.green, + custom_id="approve_application", + ) async def approve_button(self, interaction: discord.Interaction, button: ui.Button): """Approve the application""" await self.cog.update_application_status( - self.application_data["application_id"], - "APPROVED", - interaction.user.id + self.application_data["application_id"], "APPROVED", interaction.user.id ) # Update the message await interaction.response.edit_message( - content=f"✅ Application approved by {interaction.user.mention}", - view=None + content=f"✅ Application approved by {interaction.user.mention}", view=None ) # Notify the applicant await self.cog.notify_application_status_change( - interaction.guild, - self.application_data["user_id"], - "APPROVED" + interaction.guild, self.application_data["user_id"], "APPROVED" ) - @ui.button(label="Reject", style=discord.ButtonStyle.red, custom_id="reject_application") + @ui.button( + label="Reject", style=discord.ButtonStyle.red, custom_id="reject_application" + ) async def reject_button(self, interaction: discord.Interaction, button: ui.Button): """Reject the application""" # Show rejection reason modal @@ -169,28 +176,29 @@ class ApplicationReviewView(ui.View): RejectionReasonModal(self.cog, self.application_data) ) - @ui.button(label="Under Review", style=discord.ButtonStyle.blurple, custom_id="review_application") + @ui.button( + label="Under Review", + style=discord.ButtonStyle.blurple, + custom_id="review_application", + ) async def review_button(self, interaction: discord.Interaction, button: ui.Button): """Mark application as under review""" await self.cog.update_application_status( - self.application_data["application_id"], - "UNDER_REVIEW", - interaction.user.id + self.application_data["application_id"], "UNDER_REVIEW", interaction.user.id ) # Update the message await interaction.response.edit_message( content=f"🔍 Application marked as under review by {interaction.user.mention}", - view=self + view=self, ) # Notify the applicant await self.cog.notify_application_status_change( - interaction.guild, - self.application_data["user_id"], - "UNDER_REVIEW" + interaction.guild, self.application_data["user_id"], "UNDER_REVIEW" ) + class RejectionReasonModal(ui.Modal, title="Rejection Reason"): """Modal for providing rejection reason""" @@ -198,7 +206,7 @@ class RejectionReasonModal(ui.Modal, title="Rejection Reason"): label="Reason for rejection", style=discord.TextStyle.paragraph, required=True, - max_length=1000 + max_length=1000, ) def __init__(self, cog, application_data): @@ -212,13 +220,12 @@ class RejectionReasonModal(ui.Modal, title="Rejection Reason"): self.application_data["application_id"], "REJECTED", interaction.user.id, - notes=self.reason.value + notes=self.reason.value, ) # Update the message await interaction.response.edit_message( - content=f"❌ Application rejected by {interaction.user.mention}", - view=None + content=f"❌ Application rejected by {interaction.user.mention}", view=None ) # Notify the applicant @@ -226,9 +233,10 @@ class RejectionReasonModal(ui.Modal, title="Rejection Reason"): interaction.guild, self.application_data["user_id"], "REJECTED", - reason=self.reason.value + reason=self.reason.value, ) + class ModApplicationCog(commands.Cog): """Cog for handling moderator applications using Discord forms""" @@ -238,8 +246,7 @@ class ModApplicationCog(commands.Cog): # Create the main command group for this cog self.modapp_group = app_commands.Group( - name="modapp", - description="Moderator application system commands" + name="modapp", description="Moderator application system commands" ) # Register commands @@ -250,17 +257,21 @@ class ModApplicationCog(commands.Cog): async def cog_load(self): """Setup database tables when cog is loaded""" - if hasattr(self.bot, 'pg_pool') and self.bot.pg_pool: + if hasattr(self.bot, "pg_pool") and self.bot.pg_pool: try: async with self.bot.pg_pool.acquire() as conn: await conn.execute(CREATE_MOD_APPLICATIONS_TABLE) await conn.execute(CREATE_MOD_APPLICATION_SETTINGS_TABLE) # Add the missing column if it doesn't exist - await conn.execute(""" + await conn.execute( + """ ALTER TABLE mod_application_settings ADD COLUMN IF NOT EXISTS log_new_applications BOOLEAN NOT NULL DEFAULT FALSE; - """) - logger.info("Moderator application tables created and/or updated successfully") + """ + ) + logger.info( + "Moderator application tables created and/or updated successfully" + ) except Exception as e: logger.error(f"Error creating moderator application tables: {e}") else: @@ -274,7 +285,7 @@ class ModApplicationCog(commands.Cog): name="apply", description="Apply to become a moderator for this server", callback=self.apply_callback, - parent=self.modapp_group + parent=self.modapp_group, ) self.modapp_group.add_command(apply_command) @@ -283,11 +294,9 @@ class ModApplicationCog(commands.Cog): name="list", description="List all moderator applications", callback=self.list_applications_callback, - parent=self.modapp_group + parent=self.modapp_group, ) - app_commands.describe( - status="Filter applications by status" - )(list_command) + app_commands.describe(status="Filter applications by status")(list_command) self.modapp_group.add_command(list_command) # --- View Application Command --- @@ -295,11 +304,11 @@ class ModApplicationCog(commands.Cog): name="view", description="View details of a specific application", callback=self.view_application_callback, - parent=self.modapp_group + parent=self.modapp_group, + ) + app_commands.describe(application_id="The ID of the application to view")( + view_command ) - app_commands.describe( - application_id="The ID of the application to view" - )(view_command) self.modapp_group.add_command(view_command) # --- Settings Commands (Direct children of modapp_group) --- @@ -309,7 +318,7 @@ class ModApplicationCog(commands.Cog): name="settings_toggle", description="Enable or disable the application system", callback=self.toggle_applications_callback, - parent=self.modapp_group # Direct child of modapp + parent=self.modapp_group, # Direct child of modapp ) app_commands.describe( enabled="Whether applications should be enabled or disabled" @@ -321,7 +330,7 @@ class ModApplicationCog(commands.Cog): name="settings_reviewchannel", description="Set the channel where new applications will be posted for review", callback=self.set_review_channel_callback, - parent=self.modapp_group # Direct child of modapp + parent=self.modapp_group, # Direct child of modapp ) app_commands.describe( channel="The channel where applications will be posted for review" @@ -333,7 +342,7 @@ class ModApplicationCog(commands.Cog): name="settings_logchannel", description="Set the channel where application activity will be logged", callback=self.set_log_channel_callback, - parent=self.modapp_group # Direct child of modapp + parent=self.modapp_group, # Direct child of modapp ) app_commands.describe( channel="The channel where application activity will be logged" @@ -345,11 +354,11 @@ class ModApplicationCog(commands.Cog): name="settings_reviewerrole", description="Set the role that can review applications", callback=self.set_reviewer_role_callback, - parent=self.modapp_group # Direct child of modapp + parent=self.modapp_group, # Direct child of modapp + ) + app_commands.describe(role="The role that can review applications")( + settings_reviewer_role_command ) - app_commands.describe( - role="The role that can review applications" - )(settings_reviewer_role_command) self.modapp_group.add_command(settings_reviewer_role_command) # --- Set Required Role Command --- @@ -357,7 +366,7 @@ class ModApplicationCog(commands.Cog): name="settings_requiredrole", description="Set the role required to apply (optional)", callback=self.set_required_role_callback, - parent=self.modapp_group # Direct child of modapp + parent=self.modapp_group, # Direct child of modapp ) app_commands.describe( role="The role required to apply (or None to allow anyone)" @@ -369,7 +378,7 @@ class ModApplicationCog(commands.Cog): name="settings_cooldown", description="Set the cooldown period between rejected applications", callback=self.set_cooldown_callback, - parent=self.modapp_group # Direct child of modapp + parent=self.modapp_group, # Direct child of modapp ) app_commands.describe( days="Number of days a user must wait after rejection before applying again" @@ -381,7 +390,7 @@ class ModApplicationCog(commands.Cog): name="settings_lognewapps", description="Toggle whether new applications are automatically logged in the log channel", callback=self.toggle_log_new_applications_callback, - parent=self.modapp_group # Direct child of modapp + parent=self.modapp_group, # Direct child of modapp ) app_commands.describe( enabled="Whether new applications should be logged automatically" @@ -397,7 +406,7 @@ class ModApplicationCog(commands.Cog): if not settings or not settings.get("enabled", False): await interaction.response.send_message( "❌ Moderator applications are currently disabled for this server.", - ephemeral=True + ephemeral=True, ) return @@ -405,30 +414,36 @@ class ModApplicationCog(commands.Cog): required_role_id = settings.get("required_role_id") if required_role_id: member = interaction.guild.get_member(interaction.user.id) - if not member or not any(role.id == required_role_id for role in member.roles): + if not member or not any( + role.id == required_role_id for role in member.roles + ): required_role = interaction.guild.get_role(required_role_id) role_name = required_role.name if required_role else "Required Role" await interaction.response.send_message( f"❌ You need the {role_name} role to apply for moderator.", - ephemeral=True + ephemeral=True, ) return # Check if user has a pending or under review application - has_active_application = await self.check_active_application(interaction.guild_id, interaction.user.id) + has_active_application = await self.check_active_application( + interaction.guild_id, interaction.user.id + ) if has_active_application: await interaction.response.send_message( "❌ You already have an application pending review. Please wait for it to be processed.", - ephemeral=True + ephemeral=True, ) return # Check if user is on cooldown from a rejected application - on_cooldown, days_left = await self.check_application_cooldown(interaction.guild_id, interaction.user.id, settings.get("cooldown_days", 30)) + on_cooldown, days_left = await self.check_application_cooldown( + interaction.guild_id, interaction.user.id, settings.get("cooldown_days", 30) + ) if on_cooldown: await interaction.response.send_message( f"❌ You must wait {days_left} more days before submitting a new application.", - ephemeral=True + ephemeral=True, ) return @@ -438,13 +453,16 @@ class ModApplicationCog(commands.Cog): # Show the application form await interaction.response.send_modal(ModApplicationModal(self, questions)) - async def list_applications_callback(self, interaction: discord.Interaction, status: Optional[str] = None): + async def list_applications_callback( + self, interaction: discord.Interaction, status: Optional[str] = None + ): """Handle the /modapp list command""" # Check if user has permission to view applications - if not await self.check_reviewer_permission(interaction.guild_id, interaction.user.id): + if not await self.check_reviewer_permission( + interaction.guild_id, interaction.user.id + ): await interaction.response.send_message( - "❌ You don't have permission to view applications.", - ephemeral=True + "❌ You don't have permission to view applications.", ephemeral=True ) return @@ -453,17 +471,19 @@ class ModApplicationCog(commands.Cog): if status and status.upper() not in valid_statuses: await interaction.response.send_message( f"❌ Invalid status. Valid options are: {', '.join(valid_statuses)}", - ephemeral=True + ephemeral=True, ) return # Fetch applications from database - applications = await self.get_applications(interaction.guild_id, status.upper() if status else None) + applications = await self.get_applications( + interaction.guild_id, status.upper() if status else None + ) if not applications: await interaction.response.send_message( f"No applications found{f' with status {status.upper()}' if status else ''}.", - ephemeral=True + ephemeral=True, ) return @@ -471,7 +491,7 @@ class ModApplicationCog(commands.Cog): embed = discord.Embed( title=f"Moderator Applications{f' ({status.upper()})' if status else ''}", color=discord.Color.blue(), - timestamp=datetime.datetime.now() + timestamp=datetime.datetime.now(), ) for app in applications: @@ -485,18 +505,21 @@ class ModApplicationCog(commands.Cog): embed.add_field( name=f"Application #{app['application_id']} - {app['status']}", value=f"From: {user_display}\nSubmitted: {submission_date}\nUse `/modapp view {app['application_id']}` to view details", - inline=False + inline=False, ) await interaction.response.send_message(embed=embed, ephemeral=True) - async def view_application_callback(self, interaction: discord.Interaction, application_id: int): + async def view_application_callback( + self, interaction: discord.Interaction, application_id: int + ): """Handle the /modapp view command""" # Check if user has permission to view applications - if not await self.check_reviewer_permission(interaction.guild_id, interaction.user.id): + if not await self.check_reviewer_permission( + interaction.guild_id, interaction.user.id + ): await interaction.response.send_message( - "❌ You don't have permission to view applications.", - ephemeral=True + "❌ You don't have permission to view applications.", ephemeral=True ) return @@ -505,35 +528,58 @@ class ModApplicationCog(commands.Cog): if not application or application["guild_id"] != interaction.guild_id: await interaction.response.send_message( - "❌ Application not found.", - ephemeral=True + "❌ Application not found.", ephemeral=True ) return # Get user objects - applicant = self.bot.get_user(application["user_id"]) or f"User ID: {application['user_id']}" - applicant_display = applicant.mention if isinstance(applicant, discord.User) else applicant + applicant = ( + self.bot.get_user(application["user_id"]) + or f"User ID: {application['user_id']}" + ) + applicant_display = ( + applicant.mention if isinstance(applicant, discord.User) else applicant + ) reviewer = None if application["reviewer_id"]: - reviewer = self.bot.get_user(application["reviewer_id"]) or f"User ID: {application['reviewer_id']}" - reviewer_display = reviewer.mention if isinstance(reviewer, discord.User) else reviewer or "None" + reviewer = ( + self.bot.get_user(application["reviewer_id"]) + or f"User ID: {application['reviewer_id']}" + ) + reviewer_display = ( + reviewer.mention + if isinstance(reviewer, discord.User) + else reviewer or "None" + ) # Create an embed to display the application details embed = discord.Embed( title=f"Moderator Application #{application_id}", color=discord.Color.blue(), - timestamp=datetime.datetime.now() + timestamp=datetime.datetime.now(), ) # Add application metadata embed.add_field(name="Applicant", value=applicant_display, inline=True) embed.add_field(name="Status", value=application["status"], inline=True) - embed.add_field(name="Submitted", value=application["submission_date"].strftime("%Y-%m-%d %H:%M UTC"), inline=True) + embed.add_field( + name="Submitted", + value=application["submission_date"].strftime("%Y-%m-%d %H:%M UTC"), + inline=True, + ) if application["reviewer_id"]: embed.add_field(name="Reviewed By", value=reviewer_display, inline=True) - embed.add_field(name="Review Date", value=application["review_date"].strftime("%Y-%m-%d %H:%M UTC") if application["review_date"] else "N/A", inline=True) + embed.add_field( + name="Review Date", + value=( + application["review_date"].strftime("%Y-%m-%d %H:%M UTC") + if application["review_date"] + else "N/A" + ), + inline=True, + ) # Add application form data embed.add_field(name="Application Responses", value="", inline=False) @@ -541,7 +587,9 @@ class ModApplicationCog(commands.Cog): form_data = application["form_data"] for key, value in form_data.items(): # Try to find the question label from DEFAULT_QUESTIONS - question_label = next((q["label"] for q in DEFAULT_QUESTIONS if q["id"] == key), key) + question_label = next( + (q["label"] for q in DEFAULT_QUESTIONS if q["id"] == key), key + ) embed.add_field(name=question_label, value=value, inline=False) # Add notes if available @@ -555,179 +603,210 @@ class ModApplicationCog(commands.Cog): await interaction.response.send_message(embed=embed, view=view, ephemeral=True) - async def toggle_applications_callback(self, interaction: discord.Interaction, enabled: bool): + async def toggle_applications_callback( + self, interaction: discord.Interaction, enabled: bool + ): """Handle the /modapp config toggle command""" # Check if user has permission to manage applications - if not await self.check_admin_permission(interaction.guild_id, interaction.user.id): + if not await self.check_admin_permission( + interaction.guild_id, interaction.user.id + ): await interaction.response.send_message( "❌ You don't have permission to manage application settings.", - ephemeral=True + ephemeral=True, ) return # Update setting in database - success = await self.update_application_setting(interaction.guild_id, "enabled", enabled) + success = await self.update_application_setting( + interaction.guild_id, "enabled", enabled + ) if success: status = "enabled" if enabled else "disabled" await interaction.response.send_message( f"✅ Moderator applications are now {status} for this server.", - ephemeral=True + ephemeral=True, ) else: await interaction.response.send_message( - "❌ Failed to update application settings.", - ephemeral=True + "❌ Failed to update application settings.", ephemeral=True ) - async def set_review_channel_callback(self, interaction: discord.Interaction, channel: discord.TextChannel): + async def set_review_channel_callback( + self, interaction: discord.Interaction, channel: discord.TextChannel + ): """Handle the /modapp config reviewchannel command""" # Check if user has permission to manage applications - if not await self.check_admin_permission(interaction.guild_id, interaction.user.id): + if not await self.check_admin_permission( + interaction.guild_id, interaction.user.id + ): await interaction.response.send_message( "❌ You don't have permission to manage application settings.", - ephemeral=True + ephemeral=True, ) return # Update setting in database - success = await self.update_application_setting(interaction.guild_id, "review_channel_id", channel.id) + success = await self.update_application_setting( + interaction.guild_id, "review_channel_id", channel.id + ) if success: await interaction.response.send_message( f"✅ New applications will now be posted in {channel.mention} for review.", - ephemeral=True + ephemeral=True, ) else: await interaction.response.send_message( - "❌ Failed to update application settings.", - ephemeral=True + "❌ Failed to update application settings.", ephemeral=True ) - async def set_log_channel_callback(self, interaction: discord.Interaction, channel: discord.TextChannel): + async def set_log_channel_callback( + self, interaction: discord.Interaction, channel: discord.TextChannel + ): """Handle the /modapp config logchannel command""" # Check if user has permission to manage applications - if not await self.check_admin_permission(interaction.guild_id, interaction.user.id): + if not await self.check_admin_permission( + interaction.guild_id, interaction.user.id + ): await interaction.response.send_message( "❌ You don't have permission to manage application settings.", - ephemeral=True + ephemeral=True, ) return # Update setting in database - success = await self.update_application_setting(interaction.guild_id, "log_channel_id", channel.id) + success = await self.update_application_setting( + interaction.guild_id, "log_channel_id", channel.id + ) if success: await interaction.response.send_message( f"✅ Application activity will now be logged in {channel.mention}.", - ephemeral=True + ephemeral=True, ) else: await interaction.response.send_message( - "❌ Failed to update application settings.", - ephemeral=True + "❌ Failed to update application settings.", ephemeral=True ) - async def set_reviewer_role_callback(self, interaction: discord.Interaction, role: discord.Role): + async def set_reviewer_role_callback( + self, interaction: discord.Interaction, role: discord.Role + ): """Handle the /modapp config reviewerrole command""" # Check if user has permission to manage applications - if not await self.check_admin_permission(interaction.guild_id, interaction.user.id): + if not await self.check_admin_permission( + interaction.guild_id, interaction.user.id + ): await interaction.response.send_message( "❌ You don't have permission to manage application settings.", - ephemeral=True + ephemeral=True, ) return # Update setting in database - success = await self.update_application_setting(interaction.guild_id, "reviewer_role_id", role.id) + success = await self.update_application_setting( + interaction.guild_id, "reviewer_role_id", role.id + ) if success: await interaction.response.send_message( f"✅ Members with the {role.mention} role can now review applications.", - ephemeral=True + ephemeral=True, ) else: await interaction.response.send_message( - "❌ Failed to update application settings.", - ephemeral=True + "❌ Failed to update application settings.", ephemeral=True ) - async def set_required_role_callback(self, interaction: discord.Interaction, role: Optional[discord.Role] = None): + async def set_required_role_callback( + self, interaction: discord.Interaction, role: Optional[discord.Role] = None + ): """Handle the /modapp settings requiredrole command""" # Check if user has permission to manage applications - if not await self.check_admin_permission(interaction.guild_id, interaction.user.id): + if not await self.check_admin_permission( + interaction.guild_id, interaction.user.id + ): await interaction.response.send_message( "❌ You don't have permission to manage application settings.", - ephemeral=True + ephemeral=True, ) return # Update setting in database (None means no role required) role_id = role.id if role else None - success = await self.update_application_setting(interaction.guild_id, "required_role_id", role_id) + success = await self.update_application_setting( + interaction.guild_id, "required_role_id", role_id + ) if success: if role: await interaction.response.send_message( f"✅ Members now need the {role.mention} role to apply for moderator.", - ephemeral=True + ephemeral=True, ) else: await interaction.response.send_message( "✅ Any member can now apply for moderator (no role requirement).", - ephemeral=True + ephemeral=True, ) else: await interaction.response.send_message( - "❌ Failed to update application settings.", - ephemeral=True + "❌ Failed to update application settings.", ephemeral=True ) async def set_cooldown_callback(self, interaction: discord.Interaction, days: int): """Handle the /modapp settings cooldown command""" # Check if user has permission to manage applications - if not await self.check_admin_permission(interaction.guild_id, interaction.user.id): + if not await self.check_admin_permission( + interaction.guild_id, interaction.user.id + ): await interaction.response.send_message( "❌ You don't have permission to manage application settings.", - ephemeral=True + ephemeral=True, ) return # Validate days parameter if days < 0 or days > 365: await interaction.response.send_message( - "❌ Cooldown days must be between 0 and 365.", - ephemeral=True + "❌ Cooldown days must be between 0 and 365.", ephemeral=True ) return # Update setting in database - success = await self.update_application_setting(interaction.guild_id, "cooldown_days", days) + success = await self.update_application_setting( + interaction.guild_id, "cooldown_days", days + ) if success: if days == 0: await interaction.response.send_message( "✅ Application cooldown has been disabled. Users can reapply immediately after rejection.", - ephemeral=True + ephemeral=True, ) else: await interaction.response.send_message( f"✅ Users must now wait {days} days after rejection before submitting a new application.", - ephemeral=True + ephemeral=True, ) else: await interaction.response.send_message( - "❌ Failed to update application settings.", - ephemeral=True + "❌ Failed to update application settings.", ephemeral=True ) - async def toggle_log_new_applications_callback(self, interaction: discord.Interaction, enabled: bool): + async def toggle_log_new_applications_callback( + self, interaction: discord.Interaction, enabled: bool + ): """Handle the /modapp settings lognewapps command""" # Check if user has permission to manage applications - if not await self.check_admin_permission(interaction.guild_id, interaction.user.id): + if not await self.check_admin_permission( + interaction.guild_id, interaction.user.id + ): await interaction.response.send_message( "❌ You don't have permission to manage application settings.", - ephemeral=True + ephemeral=True, ) return @@ -738,38 +817,43 @@ class ModApplicationCog(commands.Cog): if enabled and not log_channel_id: await interaction.response.send_message( "❌ You need to set a log channel first using `/modapp settings_logchannel` before enabling this feature.", - ephemeral=True + ephemeral=True, ) return # Update setting in database - success = await self.update_application_setting(interaction.guild_id, "log_new_applications", enabled) + success = await self.update_application_setting( + interaction.guild_id, "log_new_applications", enabled + ) if success: status = "enabled" if enabled else "disabled" if enabled: log_channel = interaction.guild.get_channel(log_channel_id) - channel_mention = log_channel.mention if log_channel else "the configured log channel" + channel_mention = ( + log_channel.mention if log_channel else "the configured log channel" + ) await interaction.response.send_message( f"✅ New applications will now be automatically logged in {channel_mention}.", - ephemeral=True + ephemeral=True, ) else: await interaction.response.send_message( "✅ New applications will no longer be automatically logged.", - ephemeral=True + ephemeral=True, ) else: await interaction.response.send_message( - "❌ Failed to update application settings.", - ephemeral=True + "❌ Failed to update application settings.", ephemeral=True ) # --- Database Helper Methods --- - async def submit_application(self, guild_id: int, user_id: int, form_data: Dict[str, str]) -> bool: + async def submit_application( + self, guild_id: int, user_id: int, form_data: Dict[str, str] + ) -> bool: """Submit a new application to the database""" - if not hasattr(self.bot, 'pg_pool') or not self.bot.pg_pool: + if not hasattr(self.bot, "pg_pool") or not self.bot.pg_pool: logger.error("Database pool not available") return False @@ -779,28 +863,39 @@ class ModApplicationCog(commands.Cog): form_data_json = json.dumps(form_data) # Insert application into database - await conn.execute(""" + await conn.execute( + """ INSERT INTO mod_applications (guild_id, user_id, form_data) VALUES ($1, $2, $3) ON CONFLICT (guild_id, user_id, status) WHERE status IN ('PENDING', 'UNDER_REVIEW') DO NOTHING - """, guild_id, user_id, form_data_json) + """, + guild_id, + user_id, + form_data_json, + ) # Check if the insert was successful by querying for the application - result = await conn.fetchrow(""" + result = await conn.fetchrow( + """ SELECT application_id FROM mod_applications WHERE guild_id = $1 AND user_id = $2 AND status IN ('PENDING', 'UNDER_REVIEW') - """, guild_id, user_id) + """, + guild_id, + user_id, + ) return result is not None except Exception as e: logger.error(f"Error submitting application: {e}") return False - async def get_applications(self, guild_id: int, status: Optional[str] = None) -> List[Dict[str, Any]]: + async def get_applications( + self, guild_id: int, status: Optional[str] = None + ) -> List[Dict[str, Any]]: """Get all applications for a guild, optionally filtered by status""" - if not hasattr(self.bot, 'pg_pool') or not self.bot.pg_pool: + if not hasattr(self.bot, "pg_pool") or not self.bot.pg_pool: logger.error("Database pool not available") return [] @@ -808,18 +903,25 @@ class ModApplicationCog(commands.Cog): async with self.bot.pg_pool.acquire() as conn: if status: # Filter by status - rows = await conn.fetch(""" + rows = await conn.fetch( + """ SELECT * FROM mod_applications WHERE guild_id = $1 AND status = $2 ORDER BY submission_date DESC - """, guild_id, status) + """, + guild_id, + status, + ) else: # Get all applications - rows = await conn.fetch(""" + rows = await conn.fetch( + """ SELECT * FROM mod_applications WHERE guild_id = $1 ORDER BY submission_date DESC - """, guild_id) + """, + guild_id, + ) # Convert rows to dictionaries and parse form_data JSON applications = [] @@ -833,18 +935,23 @@ class ModApplicationCog(commands.Cog): logger.error(f"Error getting applications: {e}") return [] - async def get_application_by_id(self, application_id: int) -> Optional[Dict[str, Any]]: + async def get_application_by_id( + self, application_id: int + ) -> Optional[Dict[str, Any]]: """Get a specific application by ID""" - if not hasattr(self.bot, 'pg_pool') or not self.bot.pg_pool: + if not hasattr(self.bot, "pg_pool") or not self.bot.pg_pool: logger.error("Database pool not available") return None try: async with self.bot.pg_pool.acquire() as conn: - row = await conn.fetchrow(""" + row = await conn.fetchrow( + """ SELECT * FROM mod_applications WHERE application_id = $1 - """, application_id) + """, + application_id, + ) if not row: return None @@ -858,20 +965,32 @@ class ModApplicationCog(commands.Cog): logger.error(f"Error getting application by ID: {e}") return None - async def update_application_status(self, application_id: int, status: APPLICATION_STATUS, reviewer_id: int, notes: Optional[str] = None) -> bool: + async def update_application_status( + self, + application_id: int, + status: APPLICATION_STATUS, + reviewer_id: int, + notes: Optional[str] = None, + ) -> bool: """Update the status of an application""" - if not hasattr(self.bot, 'pg_pool') or not self.bot.pg_pool: + if not hasattr(self.bot, "pg_pool") or not self.bot.pg_pool: logger.error("Database pool not available") return False try: async with self.bot.pg_pool.acquire() as conn: # Update application status - await conn.execute(""" + await conn.execute( + """ UPDATE mod_applications SET status = $1, reviewer_id = $2, review_date = CURRENT_TIMESTAMP, notes = $3 WHERE application_id = $4 - """, status, reviewer_id, notes, application_id) + """, + status, + reviewer_id, + notes, + application_id, + ) return True except Exception as e: @@ -880,24 +999,30 @@ class ModApplicationCog(commands.Cog): async def get_application_settings(self, guild_id: int) -> Dict[str, Any]: """Get application settings for a guild""" - if not hasattr(self.bot, 'pg_pool') or not self.bot.pg_pool: + if not hasattr(self.bot, "pg_pool") or not self.bot.pg_pool: logger.error("Database pool not available") return {"enabled": False} # Default settings try: async with self.bot.pg_pool.acquire() as conn: # Check if settings exist for this guild - row = await conn.fetchrow(""" + row = await conn.fetchrow( + """ SELECT * FROM mod_application_settings WHERE guild_id = $1 - """, guild_id) + """, + guild_id, + ) if not row: # Create default settings - await conn.execute(""" + await conn.execute( + """ INSERT INTO mod_application_settings (guild_id) VALUES ($1) - """, guild_id) + """, + guild_id, + ) # Return default settings return { @@ -909,39 +1034,49 @@ class ModApplicationCog(commands.Cog): "reviewer_role_id": None, "custom_questions": None, "cooldown_days": 30, - "log_new_applications": False + "log_new_applications": False, } # Convert row to dictionary and parse custom_questions JSON if it exists settings = dict(row) if settings["custom_questions"]: - settings["custom_questions"] = json.loads(settings["custom_questions"]) + settings["custom_questions"] = json.loads( + settings["custom_questions"] + ) return settings except Exception as e: logger.error(f"Error getting application settings: {e}") return {"enabled": False} # Default settings on error - async def update_application_setting(self, guild_id: int, setting_key: str, setting_value: Any) -> bool: + async def update_application_setting( + self, guild_id: int, setting_key: str, setting_value: Any + ) -> bool: """Update a specific application setting for a guild""" - if not hasattr(self.bot, 'pg_pool') or not self.bot.pg_pool: + if not hasattr(self.bot, "pg_pool") or not self.bot.pg_pool: logger.error("Database pool not available") return False try: async with self.bot.pg_pool.acquire() as conn: # Check if settings exist for this guild - exists = await conn.fetchval(""" + exists = await conn.fetchval( + """ SELECT COUNT(*) FROM mod_application_settings WHERE guild_id = $1 - """, guild_id) + """, + guild_id, + ) if not exists: # Create default settings first - await conn.execute(""" + await conn.execute( + """ INSERT INTO mod_application_settings (guild_id) VALUES ($1) - """, guild_id) + """, + guild_id, + ) # Special handling for JSON fields if setting_key == "custom_questions" and setting_value is not None: @@ -962,50 +1097,62 @@ class ModApplicationCog(commands.Cog): async def check_active_application(self, guild_id: int, user_id: int) -> bool: """Check if a user has an active application (pending or under review)""" - if not hasattr(self.bot, 'pg_pool') or not self.bot.pg_pool: + if not hasattr(self.bot, "pg_pool") or not self.bot.pg_pool: logger.error("Database pool not available") return False try: async with self.bot.pg_pool.acquire() as conn: # Check for active applications - result = await conn.fetchval(""" + result = await conn.fetchval( + """ SELECT COUNT(*) FROM mod_applications WHERE guild_id = $1 AND user_id = $2 AND status IN ('PENDING', 'UNDER_REVIEW') - """, guild_id, user_id) + """, + guild_id, + user_id, + ) return result > 0 except Exception as e: logger.error(f"Error checking active application: {e}") return False - async def check_application_cooldown(self, guild_id: int, user_id: int, cooldown_days: int) -> Tuple[bool, int]: + async def check_application_cooldown( + self, guild_id: int, user_id: int, cooldown_days: int + ) -> Tuple[bool, int]: """Check if a user is on cooldown from a rejected application Returns (on_cooldown, days_left) """ if cooldown_days <= 0: return False, 0 - if not hasattr(self.bot, 'pg_pool') or not self.bot.pg_pool: + if not hasattr(self.bot, "pg_pool") or not self.bot.pg_pool: logger.error("Database pool not available") return False, 0 try: async with self.bot.pg_pool.acquire() as conn: # Get the most recent rejected application - result = await conn.fetchrow(""" + result = await conn.fetchrow( + """ SELECT review_date FROM mod_applications WHERE guild_id = $1 AND user_id = $2 AND status = 'REJECTED' ORDER BY review_date DESC LIMIT 1 - """, guild_id, user_id) + """, + guild_id, + user_id, + ) if not result: return False, 0 # Calculate days since rejection review_date = result["review_date"] - days_since = (datetime.datetime.now(datetime.timezone.utc) - review_date).days + days_since = ( + datetime.datetime.now(datetime.timezone.utc) - review_date + ).days # Check if still on cooldown if days_since < cooldown_days: @@ -1070,7 +1217,9 @@ class ModApplicationCog(commands.Cog): # Only administrators and the guild owner can manage settings return False - async def notify_new_application(self, guild: discord.Guild, user: discord.User, form_data: Dict[str, str]) -> None: + async def notify_new_application( + self, guild: discord.Guild, user: discord.User, form_data: Dict[str, str] + ) -> None: """Notify staff about a new application""" # Get application settings settings = await self.get_application_settings(guild.id) @@ -1090,12 +1239,16 @@ class ModApplicationCog(commands.Cog): application = None try: async with self.bot.pg_pool.acquire() as conn: - row = await conn.fetchrow(""" + row = await conn.fetchrow( + """ SELECT application_id FROM mod_applications WHERE guild_id = $1 AND user_id = $2 AND status = 'PENDING' ORDER BY submission_date DESC LIMIT 1 - """, guild.id, user.id) + """, + guild.id, + user.id, + ) if row: application = dict(row) @@ -1111,13 +1264,15 @@ class ModApplicationCog(commands.Cog): title="New Moderator Application", description=f"{user.mention} has submitted a moderator application.", color=discord.Color.blue(), - timestamp=datetime.datetime.now() + timestamp=datetime.datetime.now(), ) # Add user info embed.set_author(name=f"{user.name}", icon_url=user.display_avatar.url) embed.add_field(name="User ID", value=user.id, inline=True) - embed.add_field(name="Application ID", value=application["application_id"], inline=True) + embed.add_field( + name="Application ID", value=application["application_id"], inline=True + ) # Add a preview of the application (first few questions) preview_questions = 2 # Number of questions to show in preview @@ -1128,7 +1283,9 @@ class ModApplicationCog(commands.Cog): break # Try to find the question label from DEFAULT_QUESTIONS - question_label = next((q["label"] for q in DEFAULT_QUESTIONS if q["id"] == key), key) + question_label = next( + (q["label"] for q in DEFAULT_QUESTIONS if q["id"] == key), key + ) # Truncate long answers if len(value) > 100: @@ -1139,21 +1296,25 @@ class ModApplicationCog(commands.Cog): # Add view details button view = discord.ui.View() - view.add_item(discord.ui.Button( - label="View Application", - style=discord.ButtonStyle.primary, - custom_id=f"view_application_{application['application_id']}" - )) + view.add_item( + discord.ui.Button( + label="View Application", + style=discord.ButtonStyle.primary, + custom_id=f"view_application_{application['application_id']}", + ) + ) # Send the notification to the review channel try: await review_channel.send( content=f"📝 New moderator application from {user.mention}", embed=embed, - view=view + view=view, ) except Exception as e: - logger.error(f"Error sending application notification to review channel: {e}") + logger.error( + f"Error sending application notification to review channel: {e}" + ) # If log_new_applications is enabled and log_channel_id is set, also log to the log channel if log_new_applications and log_channel_id: @@ -1165,18 +1326,36 @@ class ModApplicationCog(commands.Cog): title="New Moderator Application Submitted", description=f"A new moderator application has been submitted by {user.mention}.", color=discord.Color.blue(), - timestamp=datetime.datetime.now() + timestamp=datetime.datetime.now(), + ) + log_embed.set_author( + name=f"{user.name}", icon_url=user.display_avatar.url + ) + log_embed.add_field( + name="Application ID", + value=application["application_id"], + inline=True, ) - log_embed.set_author(name=f"{user.name}", icon_url=user.display_avatar.url) - log_embed.add_field(name="Application ID", value=application["application_id"], inline=True) log_embed.add_field(name="Status", value="PENDING", inline=True) - log_embed.add_field(name="Submission Time", value=discord.utils.format_dt(datetime.datetime.now()), inline=True) + log_embed.add_field( + name="Submission Time", + value=discord.utils.format_dt(datetime.datetime.now()), + inline=True, + ) await log_channel.send(embed=log_embed) except Exception as e: - logger.error(f"Error sending application notification to log channel: {e}") + logger.error( + f"Error sending application notification to log channel: {e}" + ) - async def notify_application_status_change(self, guild: discord.Guild, user_id: int, status: APPLICATION_STATUS, reason: Optional[str] = None) -> None: + async def notify_application_status_change( + self, + guild: discord.Guild, + user_id: int, + status: APPLICATION_STATUS, + reason: Optional[str] = None, + ) -> None: """Notify the applicant about a status change""" # Get the user user = self.bot.get_user(user_id) @@ -1184,23 +1363,31 @@ class ModApplicationCog(commands.Cog): try: user = await self.bot.fetch_user(user_id) except: - logger.error(f"Could not fetch user {user_id} for application notification") + logger.error( + f"Could not fetch user {user_id} for application notification" + ) return # Create the notification message if status == "APPROVED": title = "🎉 Application Approved!" - description = "Congratulations! Your moderator application has been approved." + description = ( + "Congratulations! Your moderator application has been approved." + ) color = discord.Color.green() elif status == "REJECTED": title = "❌ Application Rejected" - description = "We're sorry, but your moderator application has been rejected." + description = ( + "We're sorry, but your moderator application has been rejected." + ) if reason: description += f"\n\nReason: {reason}" color = discord.Color.red() elif status == "UNDER_REVIEW": title = "🔍 Application Under Review" - description = "Your moderator application is now being reviewed by our team." + description = ( + "Your moderator application is now being reviewed by our team." + ) color = discord.Color.gold() else: return # Don't notify for other statuses @@ -1210,16 +1397,20 @@ class ModApplicationCog(commands.Cog): title=title, description=description, color=color, - timestamp=datetime.datetime.now() + timestamp=datetime.datetime.now(), ) - embed.set_author(name=guild.name, icon_url=guild.icon.url if guild.icon else None) + embed.set_author( + name=guild.name, icon_url=guild.icon.url if guild.icon else None + ) # Try to send a DM to the user try: await user.send(embed=embed) except Exception as e: - logger.error(f"Error sending application status notification to user {user_id}: {e}") + logger.error( + f"Error sending application status notification to user {user_id}: {e}" + ) # If DM fails, try to log it settings = await self.get_application_settings(guild.id) @@ -1228,7 +1419,9 @@ class ModApplicationCog(commands.Cog): if log_channel_id: log_channel = guild.get_channel(log_channel_id) if log_channel: - await log_channel.send(f"⚠️ Failed to send application status notification to {user.mention}. They may have DMs disabled.") + await log_channel.send( + f"⚠️ Failed to send application status notification to {user.mention}. They may have DMs disabled." + ) @commands.Cog.listener() async def on_interaction(self, interaction: discord.Interaction): @@ -1244,10 +1437,12 @@ class ModApplicationCog(commands.Cog): application_id = int(custom_id.split("_")[2]) # Check if user has permission to view applications - if not await self.check_reviewer_permission(interaction.guild_id, interaction.user.id): + if not await self.check_reviewer_permission( + interaction.guild_id, interaction.user.id + ): await interaction.response.send_message( "❌ You don't have permission to view applications.", - ephemeral=True + ephemeral=True, ) return @@ -1256,8 +1451,7 @@ class ModApplicationCog(commands.Cog): if not application or application["guild_id"] != interaction.guild_id: await interaction.response.send_message( - "❌ Application not found.", - ephemeral=True + "❌ Application not found.", ephemeral=True ) return @@ -1269,10 +1463,13 @@ class ModApplicationCog(commands.Cog): logger.error(f"Error handling view application button: {e}") await interaction.response.send_message( "❌ An error occurred while processing your request.", - ephemeral=True + ephemeral=True, ) + async def setup(bot: commands.Bot): logger.info(f"ModApplicationCog setup function CALLED. Bot instance ID: {id(bot)}") await bot.add_cog(ModApplicationCog(bot)) - logger.info(f"ModApplicationCog setup function COMPLETED and cog added. Bot instance ID: {id(bot)}") + logger.info( + f"ModApplicationCog setup function COMPLETED and cog added. Bot instance ID: {id(bot)}" + ) diff --git a/cogs/mod_log_cog.py b/cogs/mod_log_cog.py index 4f8c359..c7b1103 100644 --- a/cogs/mod_log_cog.py +++ b/cogs/mod_log_cog.py @@ -8,22 +8,23 @@ import datetime # Use absolute imports from the discordbot package root from db import mod_log_db -import settings_manager as sm # Use module functions directly +import settings_manager as sm # Use module functions directly log = logging.getLogger(__name__) + class ModLogCog(commands.Cog): """Cog for handling integrated moderation logging and related commands.""" def __init__(self, bot: commands.Bot): self.bot = bot # Settings manager functions are used directly from the imported module - self.pool: asyncpg.Pool = bot.pg_pool # Assuming pool is attached to bot + self.pool: asyncpg.Pool = bot.pg_pool # Assuming pool is attached to bot # Create the main command group for this cog self.modlog_group = app_commands.Group( name="modlog", - description="Commands for viewing and managing moderation logs" + description="Commands for viewing and managing moderation logs", ) # Register commands within the group @@ -35,7 +36,14 @@ class ModLogCog(commands.Cog): class LogView(ui.LayoutView): """View used for moderation log messages.""" - def __init__(self, bot: commands.Bot, title: str, color: discord.Color, lines: list[str], footer: str): + def __init__( + self, + bot: commands.Bot, + title: str, + color: discord.Color, + lines: list[str], + footer: str, + ): super().__init__(timeout=None) container = ui.Container(accent_colour=color) self.add_item(container) @@ -47,7 +55,9 @@ class ModLogCog(commands.Cog): self.footer_display = ui.TextDisplay(footer) container.add_item(self.footer_display) - def _format_user(self, user: Union[Member, User, Object], guild: Optional[discord.Guild] = None) -> str: + def _format_user( + self, user: Union[Member, User, Object], guild: Optional[discord.Guild] = None + ) -> str: """Return a string with display name, username and ID for a user-like object.""" if isinstance(user, Object): return f"Unknown User (ID: {user.id})" @@ -58,7 +68,11 @@ class ModLogCog(commands.Cog): display = member.display_name if member else user.name else: display = user.name - username = f"{user.name}#{user.discriminator}" if isinstance(user, (Member, User)) else "Unknown" + username = ( + f"{user.name}#{user.discriminator}" + if isinstance(user, (Member, User)) + else "Unknown" + ) return f"{display} ({username}) [ID: {user.id}]" async def _fetch_user_display(self, user_id: int, guild: discord.Guild) -> str: @@ -83,11 +97,11 @@ class ModLogCog(commands.Cog): name="setchannel", description="Set the channel for moderation logs and enable logging.", callback=self.modlog_setchannel_callback, - parent=self.modlog_group + parent=self.modlog_group, + ) + app_commands.describe(channel="The text channel to send moderation logs to.")( + setchannel_command ) - app_commands.describe( - channel="The text channel to send moderation logs to." - )(setchannel_command) self.modlog_group.add_command(setchannel_command) # --- View Command --- @@ -95,11 +109,11 @@ class ModLogCog(commands.Cog): name="view", description="View moderation logs for a user or the server", callback=self.modlog_view_callback, - parent=self.modlog_group + parent=self.modlog_group, + ) + app_commands.describe(user="Optional: The user whose logs you want to view")( + view_command ) - app_commands.describe( - user="Optional: The user whose logs you want to view" - )(view_command) self.modlog_group.add_command(view_command) # --- Case Command --- @@ -107,11 +121,11 @@ class ModLogCog(commands.Cog): name="case", description="View details for a specific moderation case ID", callback=self.modlog_case_callback, - parent=self.modlog_group + parent=self.modlog_group, + ) + app_commands.describe(case_id="The ID of the moderation case to view")( + case_command ) - app_commands.describe( - case_id="The ID of the moderation case to view" - )(case_command) self.modlog_group.add_command(case_command) # --- Reason Command --- @@ -119,28 +133,35 @@ class ModLogCog(commands.Cog): name="reason", description="Update the reason for a specific moderation case ID", callback=self.modlog_reason_callback, - parent=self.modlog_group + parent=self.modlog_group, ) app_commands.describe( case_id="The ID of the moderation case to update", - new_reason="The new reason for the moderation action" + new_reason="The new reason for the moderation action", )(reason_command) self.modlog_group.add_command(reason_command) # --- Command Callbacks --- @app_commands.checks.has_permissions(manage_guild=True) - async def modlog_setchannel_callback(self, interaction: Interaction, channel: discord.TextChannel): + async def modlog_setchannel_callback( + self, interaction: Interaction, channel: discord.TextChannel + ): """Callback for the /modlog setchannel command.""" await interaction.response.defer(ephemeral=True) guild_id = interaction.guild_id if not guild_id: - await interaction.followup.send("❌ This command can only be used in a server.", ephemeral=True) + await interaction.followup.send( + "❌ This command can only be used in a server.", ephemeral=True + ) return if not channel or not isinstance(channel, discord.TextChannel): - await interaction.followup.send("❌ Invalid channel provided. Please specify a valid text channel.", ephemeral=True) + await interaction.followup.send( + "❌ Invalid channel provided. Please specify a valid text channel.", + ephemeral=True, + ) return # Check if the bot has permissions to send messages in the target channel @@ -148,13 +169,13 @@ class ModLogCog(commands.Cog): if not channel.permissions_for(bot_member).send_messages: await interaction.followup.send( f"❌ I don't have permission to send messages in {channel.mention}. Please grant me 'Send Messages' permission there.", - ephemeral=True + ephemeral=True, ) return if not channel.permissions_for(bot_member).embed_links: await interaction.followup.send( f"❌ I don't have permission to send embeds in {channel.mention}. Please grant me 'Embed Links' permission there.", - ephemeral=True + ephemeral=True, ) return @@ -167,21 +188,25 @@ class ModLogCog(commands.Cog): if set_channel_success and set_enabled_success: await interaction.followup.send( f"✅ Moderation logs will now be sent to {channel.mention} and logging is enabled.", - ephemeral=True + ephemeral=True, + ) + log.info( + f"Mod log channel set to {channel.id} and logging enabled for guild {guild_id} by {interaction.user.id}" ) - log.info(f"Mod log channel set to {channel.id} and logging enabled for guild {guild_id} by {interaction.user.id}") else: await interaction.followup.send( "❌ Failed to save moderation log settings. Please check the bot logs for more details.", - ephemeral=True + ephemeral=True, + ) + log.error( + f"Failed to set mod log channel/enabled status for guild {guild_id}. Channel success: {set_channel_success}, Enabled success: {set_enabled_success}" ) - log.error(f"Failed to set mod log channel/enabled status for guild {guild_id}. Channel success: {set_channel_success}, Enabled success: {set_enabled_success}") except Exception as e: log.exception(f"Error setting mod log channel for guild {guild_id}: {e}") await interaction.followup.send( "❌ An unexpected error occurred while setting the moderation log channel. Please try again later.", - ephemeral=True + ephemeral=True, ) # --- Core Logging Function --- @@ -189,14 +214,18 @@ class ModLogCog(commands.Cog): async def log_action( self, guild: discord.Guild, - moderator: Union[User, Member], # For bot actions - target: Union[User, Member, Object], # Can be user, member, or just an ID object + moderator: Union[User, Member], # For bot actions + target: Union[ + User, Member, Object + ], # Can be user, member, or just an ID object action_type: str, reason: Optional[str], duration: Optional[datetime.timedelta] = None, - source: str = "BOT", # Default source is the bot itself - ai_details: Optional[Dict[str, Any]] = None, # Details from AI API - moderator_id_override: Optional[int] = None # Allow overriding moderator ID for AI source + source: str = "BOT", # Default source is the bot itself + ai_details: Optional[Dict[str, Any]] = None, # Details from AI API + moderator_id_override: Optional[ + int + ] = None, # Allow overriding moderator ID for AI source ): """Logs a moderation action to the database and configured channel.""" if not guild: @@ -205,19 +234,28 @@ class ModLogCog(commands.Cog): guild_id = guild.id # Use override if provided (for AI source), otherwise use moderator object ID - moderator_id = moderator_id_override if moderator_id_override is not None else moderator.id + moderator_id = ( + moderator_id_override if moderator_id_override is not None else moderator.id + ) target_user_id = target.id duration_seconds = int(duration.total_seconds()) if duration else None # 1. Add initial log entry to DB case_id = await mod_log_db.add_mod_log( - self.pool, guild_id, moderator_id, target_user_id, - action_type, reason, duration_seconds + self.pool, + guild_id, + moderator_id, + target_user_id, + action_type, + reason, + duration_seconds, ) if not case_id: - log.error(f"Failed to get case_id when logging action {action_type} in guild {guild_id}") - return # Don't proceed if we couldn't save the initial log + log.error( + f"Failed to get case_id when logging action {action_type} in guild {guild_id}" + ) + return # Don't proceed if we couldn't save the initial log # 2. Check settings and send log message try: @@ -226,19 +264,23 @@ class ModLogCog(commands.Cog): log_channel_id = await sm.get_mod_log_channel_id(guild_id) if not log_enabled or not log_channel_id: - log.debug(f"Mod logging disabled or channel not set for guild {guild_id}. Skipping Discord log message.") + log.debug( + f"Mod logging disabled or channel not set for guild {guild_id}. Skipping Discord log message." + ) return log_channel = guild.get_channel(log_channel_id) if not log_channel or not isinstance(log_channel, discord.TextChannel): - log.warning(f"Mod log channel {log_channel_id} not found or not a text channel in guild {guild_id}.") + log.warning( + f"Mod log channel {log_channel_id} not found or not a text channel in guild {guild_id}." + ) # Optionally update DB to remove channel ID? Or just leave it. return # 3. Format and send view view = self._format_log_embed( case_id=case_id, - moderator=moderator, # Pass the object for display formatting + moderator=moderator, # Pass the object for display formatting target=target, action_type=action_type, reason=reason, @@ -246,16 +288,19 @@ class ModLogCog(commands.Cog): guild=guild, source=source, ai_details=ai_details, - moderator_id_override=moderator_id_override # Pass override for formatting + moderator_id_override=moderator_id_override, # Pass override for formatting ) log_message = await log_channel.send(view=view) # 4. Update DB with message details - await mod_log_db.update_mod_log_message_details(self.pool, case_id, log_message.id, log_channel.id) + await mod_log_db.update_mod_log_message_details( + self.pool, case_id, log_message.id, log_channel.id + ) except Exception as e: - log.exception(f"Error during Discord mod log message sending/updating for case {case_id} in guild {guild_id}: {e}") - + log.exception( + f"Error during Discord mod log message sending/updating for case {case_id} in guild {guild_id}: {e}" + ) def _format_log_embed( self, @@ -281,12 +326,22 @@ class ModLogCog(commands.Cog): "AI_ALERT": Color.purple(), "AI_DELETE_REQUESTED": Color.dark_grey(), } - embed_color = Color.blurple() if source == "AI_API" else color_map.get(action_type.upper(), Color.greyple()) - action_title_prefix = "🤖 AI Moderation Action" if source == "AI_API" else action_type.replace("_", " ").title() + embed_color = ( + Color.blurple() + if source == "AI_API" + else color_map.get(action_type.upper(), Color.greyple()) + ) + action_title_prefix = ( + "🤖 AI Moderation Action" + if source == "AI_API" + else action_type.replace("_", " ").title() + ) action_title = f"{action_title_prefix} | Case #{case_id}" target_display = self._format_user(target, guild) moderator_display = ( - f"AI System (ID: {moderator_id_override or 'Unknown'})" if source == "AI_API" else self._format_user(moderator, guild) + f"AI System (ID: {moderator_id_override or 'Unknown'})" + if source == "AI_API" + else self._format_user(moderator, guild) ) lines = [f"**User:** {target_display}", f"**Moderator:** {moderator_display}"] if ai_details: @@ -294,7 +349,9 @@ class ModLogCog(commands.Cog): lines.append(f"**Rule Violated:** {ai_details['rule_violated']}") if "reasoning" in ai_details: reason_to_display = reason or ai_details["reasoning"] - lines.append(f"**Reason / AI Reasoning:** {reason_to_display or 'No reason provided.'}") + lines.append( + f"**Reason / AI Reasoning:** {reason_to_display or 'No reason provided.'}" + ) if reason and reason != ai_details["reasoning"]: lines.append(f"**Original Bot Reason:** {reason}") else: @@ -326,20 +383,32 @@ class ModLogCog(commands.Cog): expires_at = discord.utils.utcnow() + duration lines.append(f"**Expires:** ") footer = ( - f"AI Moderation Action • {guild.name} ({guild.id})" + (f" • Model: {ai_details.get('ai_model')}" if ai_details and ai_details.get('ai_model') else "") + f"AI Moderation Action • {guild.name} ({guild.id})" + + ( + f" • Model: {ai_details.get('ai_model')}" + if ai_details and ai_details.get("ai_model") + else "" + ) if source == "AI_API" else f"Guild: {guild.name} ({guild.id})" ) return self.LogView(self.bot, action_title, embed_color, lines, footer) + # --- View Command Callback --- - @app_commands.checks.has_permissions(moderate_members=True) # Adjust permissions as needed - async def modlog_view_callback(self, interaction: Interaction, user: Optional[discord.User] = None): + @app_commands.checks.has_permissions( + moderate_members=True + ) # Adjust permissions as needed + async def modlog_view_callback( + self, interaction: Interaction, user: Optional[discord.User] = None + ): """Callback for the /modlog view command.""" await interaction.response.defer(ephemeral=True) guild_id = interaction.guild_id if not guild_id: - await interaction.followup.send("❌ This command can only be used in a server.", ephemeral=True) + await interaction.followup.send( + "❌ This command can only be used in a server.", ephemeral=True + ) return records = [] @@ -351,21 +420,31 @@ class ModLogCog(commands.Cog): title = f"Recent Moderation Logs for {interaction.guild.name}" if not records: - await interaction.followup.send("No moderation logs found matching your criteria.", ephemeral=True) + await interaction.followup.send( + "No moderation logs found matching your criteria.", ephemeral=True + ) return # Format the logs into an embed or text response # For simplicity, sending as text for now. Can enhance with pagination/embeds later. response_lines = [f"**{title}**"] for record in records: - timestamp_str = record['timestamp'].strftime('%Y-%m-%d %H:%M:%S') - reason_str = record['reason'] or "N/A" - duration_str = f" ({record['duration_seconds']}s)" if record['duration_seconds'] else "" - target_disp = await self._fetch_user_display(record['target_user_id'], interaction.guild) - if record['moderator_id'] == 0: + timestamp_str = record["timestamp"].strftime("%Y-%m-%d %H:%M:%S") + reason_str = record["reason"] or "N/A" + duration_str = ( + f" ({record['duration_seconds']}s)" + if record["duration_seconds"] + else "" + ) + target_disp = await self._fetch_user_display( + record["target_user_id"], interaction.guild + ) + if record["moderator_id"] == 0: mod_disp = "AI System" else: - mod_disp = await self._fetch_user_display(record['moderator_id'], interaction.guild) + mod_disp = await self._fetch_user_display( + record["moderator_id"], interaction.guild + ) response_lines.append( f"`Case #{record['case_id']}` [{timestamp_str}] **{record['action_type']}** " f"Target: {target_disp} Mod: {mod_disp} " @@ -379,148 +458,206 @@ class ModLogCog(commands.Cog): await interaction.followup.send(full_response, ephemeral=True) - - @app_commands.checks.has_permissions(moderate_members=True) # Adjust permissions as needed + @app_commands.checks.has_permissions( + moderate_members=True + ) # Adjust permissions as needed async def modlog_case_callback(self, interaction: Interaction, case_id: int): """Callback for the /modlog case command.""" await interaction.response.defer(ephemeral=True) record = await mod_log_db.get_mod_log(self.pool, case_id) if not record: - await interaction.followup.send(f"❌ Case ID #{case_id} not found.", ephemeral=True) + await interaction.followup.send( + f"❌ Case ID #{case_id} not found.", ephemeral=True + ) return # Ensure the case belongs to the current guild for security/privacy - if record['guild_id'] != interaction.guild_id: - await interaction.followup.send(f"❌ Case ID #{case_id} does not belong to this server.", ephemeral=True) - return + if record["guild_id"] != interaction.guild_id: + await interaction.followup.send( + f"❌ Case ID #{case_id} does not belong to this server.", ephemeral=True + ) + return # Fetch user objects if possible to show names # Special handling for AI moderator (ID 0) to avoid Discord API 404 error - if record['moderator_id'] == 0: + if record["moderator_id"] == 0: # AI moderator uses ID 0, which is not a valid Discord user ID moderator = None else: try: - moderator = await self.bot.fetch_user(record['moderator_id']) + moderator = await self.bot.fetch_user(record["moderator_id"]) except discord.NotFound: - log.warning(f"Moderator with ID {record['moderator_id']} not found when viewing case {case_id}") + log.warning( + f"Moderator with ID {record['moderator_id']} not found when viewing case {case_id}" + ) moderator = None try: - target = await self.bot.fetch_user(record['target_user_id']) + target = await self.bot.fetch_user(record["target_user_id"]) except discord.NotFound: - log.warning(f"Target user with ID {record['target_user_id']} not found when viewing case {case_id}") + log.warning( + f"Target user with ID {record['target_user_id']} not found when viewing case {case_id}" + ) target = None - duration = datetime.timedelta(seconds=record['duration_seconds']) if record['duration_seconds'] else None + duration = ( + datetime.timedelta(seconds=record["duration_seconds"]) + if record["duration_seconds"] + else None + ) view = self._format_log_embed( case_id, - moderator or Object(id=record['moderator_id']), # Fallback to Object if user not found - target or Object(id=record['target_user_id']), # Fallback to Object if user not found - record['action_type'], - record['reason'], + moderator + or Object( + id=record["moderator_id"] + ), # Fallback to Object if user not found + target + or Object( + id=record["target_user_id"] + ), # Fallback to Object if user not found + record["action_type"], + record["reason"], duration, - interaction.guild + interaction.guild, ) # Add log message link if available - if record['log_message_id'] and record['log_channel_id']: + if record["log_message_id"] and record["log_channel_id"]: link = f"https://discord.com/channels/{record['guild_id']}/{record['log_channel_id']}/{record['log_message_id']}" # Append jump link as extra line view.footer_display.content += f" | [Jump to Log]({link})" await interaction.followup.send(view=view, ephemeral=True) - - @app_commands.checks.has_permissions(manage_guild=True) # Higher permission for editing reasons - async def modlog_reason_callback(self, interaction: Interaction, case_id: int, new_reason: str): + @app_commands.checks.has_permissions( + manage_guild=True + ) # Higher permission for editing reasons + async def modlog_reason_callback( + self, interaction: Interaction, case_id: int, new_reason: str + ): """Callback for the /modlog reason command.""" await interaction.response.defer(ephemeral=True) # 1. Get the original record to verify guild and existence original_record = await mod_log_db.get_mod_log(self.pool, case_id) if not original_record: - await interaction.followup.send(f"❌ Case ID #{case_id} not found.", ephemeral=True) + await interaction.followup.send( + f"❌ Case ID #{case_id} not found.", ephemeral=True + ) + return + if original_record["guild_id"] != interaction.guild_id: + await interaction.followup.send( + f"❌ Case ID #{case_id} does not belong to this server.", ephemeral=True + ) return - if original_record['guild_id'] != interaction.guild_id: - await interaction.followup.send(f"❌ Case ID #{case_id} does not belong to this server.", ephemeral=True) - return # 2. Update the reason in the database success = await mod_log_db.update_mod_log_reason(self.pool, case_id, new_reason) if not success: - await interaction.followup.send(f"❌ Failed to update reason for Case ID #{case_id}. Please check logs.", ephemeral=True) + await interaction.followup.send( + f"❌ Failed to update reason for Case ID #{case_id}. Please check logs.", + ephemeral=True, + ) return - await interaction.followup.send(f"✅ Updated reason for Case ID #{case_id}.", ephemeral=True) + await interaction.followup.send( + f"✅ Updated reason for Case ID #{case_id}.", ephemeral=True + ) # 3. (Optional but recommended) Update the original log message embed - if original_record['log_message_id'] and original_record['log_channel_id']: + if original_record["log_message_id"] and original_record["log_channel_id"]: try: - log_channel = interaction.guild.get_channel(original_record['log_channel_id']) + log_channel = interaction.guild.get_channel( + original_record["log_channel_id"] + ) if log_channel and isinstance(log_channel, discord.TextChannel): - log_message = await log_channel.fetch_message(original_record['log_message_id']) + log_message = await log_channel.fetch_message( + original_record["log_message_id"] + ) if log_message and log_message.author == self.bot.user: # Re-fetch users/duration to reconstruct embed accurately # Special handling for AI moderator (ID 0) to avoid Discord API 404 error - if original_record['moderator_id'] == 0: + if original_record["moderator_id"] == 0: # AI moderator uses ID 0, which is not a valid Discord user ID moderator = None else: try: - moderator = await self.bot.fetch_user(original_record['moderator_id']) + moderator = await self.bot.fetch_user( + original_record["moderator_id"] + ) except discord.NotFound: - log.warning(f"Moderator with ID {original_record['moderator_id']} not found when updating case {case_id}") + log.warning( + f"Moderator with ID {original_record['moderator_id']} not found when updating case {case_id}" + ) moderator = None try: - target = await self.bot.fetch_user(original_record['target_user_id']) + target = await self.bot.fetch_user( + original_record["target_user_id"] + ) except discord.NotFound: - log.warning(f"Target user with ID {original_record['target_user_id']} not found when updating case {case_id}") + log.warning( + f"Target user with ID {original_record['target_user_id']} not found when updating case {case_id}" + ) target = None - duration = datetime.timedelta(seconds=original_record['duration_seconds']) if original_record['duration_seconds'] else None + duration = ( + datetime.timedelta( + seconds=original_record["duration_seconds"] + ) + if original_record["duration_seconds"] + else None + ) new_view = self._format_log_embed( case_id, - moderator or Object(id=original_record['moderator_id']), - target or Object(id=original_record['target_user_id']), - original_record['action_type'], - new_reason, # Use the new reason here + moderator or Object(id=original_record["moderator_id"]), + target or Object(id=original_record["target_user_id"]), + original_record["action_type"], + new_reason, # Use the new reason here duration, - interaction.guild + interaction.guild, ) link = f"https://discord.com/channels/{original_record['guild_id']}/{original_record['log_channel_id']}/{original_record['log_message_id']}" new_view.footer_display.content += f" | [Jump to Log]({link}) | Updated By: {interaction.user.mention}" await log_message.edit(view=new_view) - log.info(f"Successfully updated log message view for case {case_id}") + log.info( + f"Successfully updated log message view for case {case_id}" + ) except discord.NotFound: - log.warning(f"Original log message or channel not found for case {case_id} when updating reason.") + log.warning( + f"Original log message or channel not found for case {case_id} when updating reason." + ) except discord.Forbidden: - log.warning(f"Missing permissions to edit original log message for case {case_id}.") + log.warning( + f"Missing permissions to edit original log message for case {case_id}." + ) except Exception as e: - log.exception(f"Error updating original log message embed for case {case_id}: {e}") - + log.exception( + f"Error updating original log message embed for case {case_id}: {e}" + ) @commands.Cog.listener() async def on_ready(self): # Ensure the pool and settings_manager are available - if not hasattr(self.bot, 'pg_pool') or not self.bot.pg_pool: - log.error("Database pool not found on bot object. ModLogCog requires bot.pg_pool.") + if not hasattr(self.bot, "pg_pool") or not self.bot.pg_pool: + log.error( + "Database pool not found on bot object. ModLogCog requires bot.pg_pool." + ) # Consider preventing the cog from loading fully or raising an error # Settings manager is imported directly, no need to check on bot object - print(f'{self.__class__.__name__} cog has been loaded.') + print(f"{self.__class__.__name__} cog has been loaded.") async def setup(bot: commands.Bot): # Ensure dependencies (pool) are ready before adding cog # Settings manager is imported directly within the cog - if hasattr(bot, 'pg_pool') and bot.pg_pool: + if hasattr(bot, "pg_pool") and bot.pg_pool: await bot.add_cog(ModLogCog(bot)) else: log.error("Failed to load ModLogCog: bot.pg_pool not initialized.") diff --git a/cogs/moderation_cog.py b/cogs/moderation_cog.py index 3880d10..89fa81c 100644 --- a/cogs/moderation_cog.py +++ b/cogs/moderation_cog.py @@ -3,6 +3,7 @@ from discord.ext import commands from discord import app_commands import random + class FakeModerationCog(commands.Cog): """Fake moderation commands that don't actually perform any actions.""" @@ -12,7 +13,7 @@ class FakeModerationCog(commands.Cog): # Create the main command group for this cog self.fakemod_group = app_commands.Group( name="fakemod", - description="Fake moderation commands that don't actually perform any actions" + description="Fake moderation commands that don't actually perform any actions", ) # Register commands @@ -22,44 +23,46 @@ class FakeModerationCog(commands.Cog): self.bot.tree.add_command(self.fakemod_group) # Helper method for generating responses - async def _fake_moderation_response(self, action, target, reason=None, duration=None): + async def _fake_moderation_response( + self, action, target, reason=None, duration=None + ): """Generate a fake moderation response.""" responses = { "ban": [ f"🔨 **Banned {target}**{f' for {duration}' if duration else ''}! Reason: {reason or 'No reason provided'}", f"👋 {target} has been banned from the server{f' for {duration}' if duration else ''}. Reason: {reason or 'No reason provided'}", - f"🚫 {target} is now banned{f' for {duration}' if duration else ''}. Reason: {reason or 'No reason provided'}" + f"🚫 {target} is now banned{f' for {duration}' if duration else ''}. Reason: {reason or 'No reason provided'}", ], "kick": [ f"👢 **Kicked {target}**! Reason: {reason or 'No reason provided'}", f"👋 {target} has been kicked from the server. Reason: {reason or 'No reason provided'}", - f"🚪 {target} has been shown the door. Reason: {reason or 'No reason provided'}" + f"🚪 {target} has been shown the door. Reason: {reason or 'No reason provided'}", ], "mute": [ f"🔇 **Muted {target}**{f' for {duration}' if duration else ''}! Reason: {reason or 'No reason provided'}", f"🤐 {target} has been muted{f' for {duration}' if duration else ''}. Reason: {reason or 'No reason provided'}", - f"📵 {target} can no longer speak{f' for {duration}' if duration else ''}. Reason: {reason or 'No reason provided'}" + f"📵 {target} can no longer speak{f' for {duration}' if duration else ''}. Reason: {reason or 'No reason provided'}", ], "timeout": [ f"⏰ **Timed out {target}** for {duration or 'some time'}! Reason: {reason or 'No reason provided'}", f"⏳ {target} has been put in timeout for {duration or 'some time'}. Reason: {reason or 'No reason provided'}", - f"🕒 {target} is now in timeout for {duration or 'some time'}. Reason: {reason or 'No reason provided'}" + f"🕒 {target} is now in timeout for {duration or 'some time'}. Reason: {reason or 'No reason provided'}", ], "warn": [ f"⚠️ **Warned {target}**! Reason: {reason or 'No reason provided'}", f"📝 {target} has been warned. Reason: {reason or 'No reason provided'}", - f"🚨 Warning issued to {target}. Reason: {reason or 'No reason provided'}" + f"🚨 Warning issued to {target}. Reason: {reason or 'No reason provided'}", ], "unban": [ f"🔓 **Unbanned {target}**! Reason: {reason or 'No reason provided'}", f"🎊 {target} has been unbanned. Reason: {reason or 'No reason provided'}", - f"🔄 {target} is now allowed back in the server. Reason: {reason or 'No reason provided'}" + f"🔄 {target} is now allowed back in the server. Reason: {reason or 'No reason provided'}", ], "unmute": [ f"🔊 **Unmuted {target}**! Reason: {reason or 'No reason provided'}", f"🗣️ {target} can speak again. Reason: {reason or 'No reason provided'}", - f"📢 {target} has been unmuted. Reason: {reason or 'No reason provided'}" - ] + f"📢 {target} has been unmuted. Reason: {reason or 'No reason provided'}", + ], } return random.choice(responses.get(action, [f"Action performed on {target}"])) @@ -72,12 +75,12 @@ class FakeModerationCog(commands.Cog): name="ban", description="Pretends to ban a member from the server", callback=self.fakemod_ban_callback, - parent=self.fakemod_group + parent=self.fakemod_group, ) app_commands.describe( member="The member to pretend to ban", duration="The fake duration of the ban (e.g., '1d', '7d')", - reason="The fake reason for the ban" + reason="The fake reason for the ban", )(ban_command) self.fakemod_group.add_command(ban_command) @@ -86,11 +89,11 @@ class FakeModerationCog(commands.Cog): name="unban", description="Pretends to unban a user from the server", callback=self.fakemod_unban_callback, - parent=self.fakemod_group + parent=self.fakemod_group, ) app_commands.describe( user="The user to pretend to unban (username or ID)", - reason="The fake reason for the unban" + reason="The fake reason for the unban", )(unban_command) self.fakemod_group.add_command(unban_command) @@ -99,11 +102,11 @@ class FakeModerationCog(commands.Cog): name="kick", description="Pretends to kick a member from the server", callback=self.fakemod_kick_callback, - parent=self.fakemod_group + parent=self.fakemod_group, ) app_commands.describe( member="The member to pretend to kick", - reason="The fake reason for the kick" + reason="The fake reason for the kick", )(kick_command) self.fakemod_group.add_command(kick_command) @@ -112,12 +115,12 @@ class FakeModerationCog(commands.Cog): name="mute", description="Pretends to mute a member in the server", callback=self.fakemod_mute_callback, - parent=self.fakemod_group + parent=self.fakemod_group, ) app_commands.describe( member="The member to pretend to mute", duration="The fake duration of the mute (e.g., '1h', '30m')", - reason="The fake reason for the mute" + reason="The fake reason for the mute", )(mute_command) self.fakemod_group.add_command(mute_command) @@ -126,11 +129,11 @@ class FakeModerationCog(commands.Cog): name="unmute", description="Pretends to unmute a member in the server", callback=self.fakemod_unmute_callback, - parent=self.fakemod_group + parent=self.fakemod_group, ) app_commands.describe( member="The member to pretend to unmute", - reason="The fake reason for the unmute" + reason="The fake reason for the unmute", )(unmute_command) self.fakemod_group.add_command(unmute_command) @@ -139,12 +142,12 @@ class FakeModerationCog(commands.Cog): name="timeout", description="Pretends to timeout a member in the server", callback=self.fakemod_timeout_callback, - parent=self.fakemod_group + parent=self.fakemod_group, ) app_commands.describe( member="The member to pretend to timeout", duration="The fake duration of the timeout (e.g., '1h', '30m')", - reason="The fake reason for the timeout" + reason="The fake reason for the timeout", )(timeout_command) self.fakemod_group.add_command(timeout_command) @@ -153,47 +156,90 @@ class FakeModerationCog(commands.Cog): name="warn", description="Pretends to warn a member in the server", callback=self.fakemod_warn_callback, - parent=self.fakemod_group + parent=self.fakemod_group, ) app_commands.describe( member="The member to pretend to warn", - reason="The fake reason for the warning" + reason="The fake reason for the warning", )(warn_command) self.fakemod_group.add_command(warn_command) # --- Command Callbacks --- - async def fakemod_ban_callback(self, interaction: discord.Interaction, member: discord.Member, duration: str = None, reason: str = None): + async def fakemod_ban_callback( + self, + interaction: discord.Interaction, + member: discord.Member, + duration: str = None, + reason: str = None, + ): """Pretends to ban a member from the server.""" - response = await self._fake_moderation_response("ban", member.mention, reason, duration) + response = await self._fake_moderation_response( + "ban", member.mention, reason, duration + ) await interaction.response.send_message(response) - async def fakemod_unban_callback(self, interaction: discord.Interaction, user: str, reason: str = None): + async def fakemod_unban_callback( + self, interaction: discord.Interaction, user: str, reason: str = None + ): """Pretends to unban a user from the server.""" response = await self._fake_moderation_response("unban", user, reason) await interaction.response.send_message(response) - async def fakemod_kick_callback(self, interaction: discord.Interaction, member: discord.Member, reason: str = None): + async def fakemod_kick_callback( + self, + interaction: discord.Interaction, + member: discord.Member, + reason: str = None, + ): """Pretends to kick a member from the server.""" response = await self._fake_moderation_response("kick", member.mention, reason) await interaction.response.send_message(response) - async def fakemod_mute_callback(self, interaction: discord.Interaction, member: discord.Member, duration: str = None, reason: str = None): + async def fakemod_mute_callback( + self, + interaction: discord.Interaction, + member: discord.Member, + duration: str = None, + reason: str = None, + ): """Pretends to mute a member in the server.""" - response = await self._fake_moderation_response("mute", member.mention, reason, duration) + response = await self._fake_moderation_response( + "mute", member.mention, reason, duration + ) await interaction.response.send_message(response) - async def fakemod_unmute_callback(self, interaction: discord.Interaction, member: discord.Member, reason: str = None): + async def fakemod_unmute_callback( + self, + interaction: discord.Interaction, + member: discord.Member, + reason: str = None, + ): """Pretends to unmute a member in the server.""" - response = await self._fake_moderation_response("unmute", member.mention, reason) + response = await self._fake_moderation_response( + "unmute", member.mention, reason + ) await interaction.response.send_message(response) - async def fakemod_timeout_callback(self, interaction: discord.Interaction, member: discord.Member, duration: str = None, reason: str = None): + async def fakemod_timeout_callback( + self, + interaction: discord.Interaction, + member: discord.Member, + duration: str = None, + reason: str = None, + ): """Pretends to timeout a member in the server.""" - response = await self._fake_moderation_response("timeout", member.mention, reason, duration) + response = await self._fake_moderation_response( + "timeout", member.mention, reason, duration + ) await interaction.response.send_message(response) - async def fakemod_warn_callback(self, interaction: discord.Interaction, member: discord.Member, reason: str = None): + async def fakemod_warn_callback( + self, + interaction: discord.Interaction, + member: discord.Member, + reason: str = None, + ): """Pretends to warn a member in the server.""" response = await self._fake_moderation_response("warn", member.mention, reason) await interaction.response.send_message(response) @@ -201,17 +247,32 @@ class FakeModerationCog(commands.Cog): # --- Legacy Command Handlers (for prefix commands) --- @commands.command(name="ban") - async def ban(self, ctx: commands.Context, member: discord.Member = None, duration: str = None, *, reason: str = None): + async def ban( + self, + ctx: commands.Context, + member: discord.Member = None, + duration: str = None, + *, + reason: str = None, + ): """Pretends to ban a member from the server.""" if not member: await ctx.reply("Please specify a member to ban.") return - response = await self._fake_moderation_response("ban", member.mention, reason, duration) + response = await self._fake_moderation_response( + "ban", member.mention, reason, duration + ) await ctx.reply(response) @commands.command(name="kick") - async def kick(self, ctx: commands.Context, member: discord.Member = None, *, reason: str = None): + async def kick( + self, + ctx: commands.Context, + member: discord.Member = None, + *, + reason: str = None, + ): """Pretends to kick a member from the server.""" if not member: await ctx.reply("Please specify a member to kick.") @@ -221,12 +282,21 @@ class FakeModerationCog(commands.Cog): await ctx.reply(response) @commands.command(name="mute") - async def mute(self, ctx: commands.Context, member: discord.Member = None, duration: str = None, *, reason: str = None): + async def mute( + self, + ctx: commands.Context, + member: discord.Member = None, + duration: str = None, + *, + reason: str = None, + ): """Pretends to mute a member in the server. Can be used by replying to a message.""" # Check if this is a reply to a message and no member was specified if not member and ctx.message.reference: # Get the message being replied to - replied_msg = await ctx.channel.fetch_message(ctx.message.reference.message_id) + replied_msg = await ctx.channel.fetch_message( + ctx.message.reference.message_id + ) member = replied_msg.author # Don't allow muting the bot itself @@ -234,19 +304,34 @@ class FakeModerationCog(commands.Cog): await ctx.reply("❌ I cannot mute myself.") return elif not member: - await ctx.reply("Please specify a member to mute or reply to their message.") + await ctx.reply( + "Please specify a member to mute or reply to their message." + ) return - response = await self._fake_moderation_response("mute", member.mention, reason, duration) + response = await self._fake_moderation_response( + "mute", member.mention, reason, duration + ) await ctx.reply(response) - @commands.command(name="faketimeout", aliases=["fto"]) # Renamed command and added alias - async def fake_timeout(self, ctx: commands.Context, member: discord.Member = None, duration: str = None, *, reason: str = None): # Renamed function + @commands.command( + name="faketimeout", aliases=["fto"] + ) # Renamed command and added alias + async def fake_timeout( + self, + ctx: commands.Context, + member: discord.Member = None, + duration: str = None, + *, + reason: str = None, + ): # Renamed function """Pretends to timeout a member in the server. Can be used by replying to a message.""" # Check if this is a reply to a message and no member was specified if not member and ctx.message.reference: # Get the message being replied to - replied_msg = await ctx.channel.fetch_message(ctx.message.reference.message_id) + replied_msg = await ctx.channel.fetch_message( + ctx.message.reference.message_id + ) member = replied_msg.author # Don't allow timing out the bot itself @@ -254,25 +339,41 @@ class FakeModerationCog(commands.Cog): await ctx.reply("❌ I cannot timeout myself.") return elif not member: - await ctx.reply("Please specify a member to timeout or reply to their message.") + await ctx.reply( + "Please specify a member to timeout or reply to their message." + ) return # If duration wasn't specified but we're in a reply, check if it's the first argument - if not duration and ctx.message.reference and len(ctx.message.content.split()) > 1: + if ( + not duration + and ctx.message.reference + and len(ctx.message.content.split()) > 1 + ): # Try to extract duration from the first argument potential_duration = ctx.message.content.split()[1] # Simple check if it looks like a duration (contains numbers and letters) - if any(c.isdigit() for c in potential_duration) and any(c.isalpha() for c in potential_duration): + if any(c.isdigit() for c in potential_duration) and any( + c.isalpha() for c in potential_duration + ): duration = potential_duration # If there's more content, it's the reason if len(ctx.message.content.split()) > 2: - reason = ' '.join(ctx.message.content.split()[2:]) + reason = " ".join(ctx.message.content.split()[2:]) - response = await self._fake_moderation_response("timeout", member.mention, reason, duration) + response = await self._fake_moderation_response( + "timeout", member.mention, reason, duration + ) await ctx.reply(response) @commands.command(name="warn") - async def warn(self, ctx: commands.Context, member: discord.Member = None, *, reason: str = None): + async def warn( + self, + ctx: commands.Context, + member: discord.Member = None, + *, + reason: str = None, + ): """Pretends to warn a member in the server.""" if not member: await ctx.reply("Please specify a member to warn.") @@ -282,7 +383,9 @@ class FakeModerationCog(commands.Cog): await ctx.reply(response) @commands.command(name="unban") - async def unban(self, ctx: commands.Context, user: str = None, *, reason: str = None): + async def unban( + self, ctx: commands.Context, user: str = None, *, reason: str = None + ): """Pretends to unban a user from the server.""" if not user: await ctx.reply("Please specify a user to unban.") @@ -293,18 +396,27 @@ class FakeModerationCog(commands.Cog): await ctx.reply(response) @commands.command(name="unmute") - async def unmute(self, ctx: commands.Context, member: discord.Member = None, *, reason: str = None): + async def unmute( + self, + ctx: commands.Context, + member: discord.Member = None, + *, + reason: str = None, + ): """Pretends to unmute a member in the server.""" if not member: await ctx.reply("Please specify a member to unmute.") return - response = await self._fake_moderation_response("unmute", member.mention, reason) + response = await self._fake_moderation_response( + "unmute", member.mention, reason + ) await ctx.reply(response) @commands.Cog.listener() async def on_ready(self): - print(f'{self.__class__.__name__} cog has been loaded.') + print(f"{self.__class__.__name__} cog has been loaded.") + async def setup(bot: commands.Bot): await bot.add_cog(FakeModerationCog(bot)) diff --git a/cogs/multi_bot_cog.py b/cogs/multi_bot_cog.py index 1418685..254167b 100644 --- a/cogs/multi_bot_cog.py +++ b/cogs/multi_bot_cog.py @@ -21,37 +21,33 @@ CONFIG_FILE = "data/multi_bot_config.json" NERU_BOT_ID = "neru" MIKU_BOT_ID = "miku" + class MultiBotCog(commands.Cog, name="Multi Bot"): """Cog for managing multiple bot instances""" def __init__(self, bot): self.bot = bot self.bot_processes = {} # Store subprocess objects - self.bot_threads = {} # Store thread objects + self.bot_threads = {} # Store thread objects # Create the main command group for this cog self.multibot_group = app_commands.Group( - name="multibot", - description="Manage multiple bot instances" + name="multibot", description="Manage multiple bot instances" ) # Create subgroups self.config_group = app_commands.Group( name="config", description="Configure bot settings", - parent=self.multibot_group + parent=self.multibot_group, ) self.status_group = app_commands.Group( - name="status", - description="Manage bot status", - parent=self.multibot_group + name="status", description="Manage bot status", parent=self.multibot_group ) self.manage_group = app_commands.Group( - name="manage", - description="Add or remove bots", - parent=self.multibot_group + name="manage", description="Add or remove bots", parent=self.multibot_group ) # Register all commands @@ -73,19 +69,27 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): await ctx.send(result) # --- Main multibot commands --- - async def multibot_start_callback(self, interaction: discord.Interaction, bot_id: str): + async def multibot_start_callback( + self, interaction: discord.Interaction, bot_id: str + ): """Start a specific bot (Owner only)""" if interaction.user.id != self.bot.owner_id: - await interaction.response.send_message("⛔ This command can only be used by the bot owner.", ephemeral=True) + await interaction.response.send_message( + "⛔ This command can only be used by the bot owner.", ephemeral=True + ) return result = await self._start_bot_logic(bot_id) await interaction.response.send_message(result, ephemeral=True) - async def multibot_stop_callback(self, interaction: discord.Interaction, bot_id: str): + async def multibot_stop_callback( + self, interaction: discord.Interaction, bot_id: str + ): """Stop a specific bot (Owner only)""" if interaction.user.id != self.bot.owner_id: - await interaction.response.send_message("⛔ This command can only be used by the bot owner.", ephemeral=True) + await interaction.response.send_message( + "⛔ This command can only be used by the bot owner.", ephemeral=True + ) return result = await self._stop_bot_logic(bot_id) @@ -94,7 +98,9 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): async def multibot_startall_callback(self, interaction: discord.Interaction): """Start all configured bots (Owner only)""" if interaction.user.id != self.bot.owner_id: - await interaction.response.send_message("⛔ This command can only be used by the bot owner.", ephemeral=True) + await interaction.response.send_message( + "⛔ This command can only be used by the bot owner.", ephemeral=True + ) return result = await self._startall_bots_logic() @@ -103,7 +109,9 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): async def multibot_stopall_callback(self, interaction: discord.Interaction): """Stop all running bots (Owner only)""" if interaction.user.id != self.bot.owner_id: - await interaction.response.send_message("⛔ This command can only be used by the bot owner.", ephemeral=True) + await interaction.response.send_message( + "⛔ This command can only be used by the bot owner.", ephemeral=True + ) return result = await self._stopall_bots_logic() @@ -112,7 +120,9 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): async def multibot_list_callback(self, interaction: discord.Interaction): """List all configured bots and their status (Owner only)""" if interaction.user.id != self.bot.owner_id: - await interaction.response.send_message("⛔ This command can only be used by the bot owner.", ephemeral=True) + await interaction.response.send_message( + "⛔ This command can only be used by the bot owner.", ephemeral=True + ) return embed = await self._list_bots_logic() @@ -126,7 +136,7 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): name="start", description="Start a specific bot", callback=self.multibot_start_callback, - parent=self.multibot_group + parent=self.multibot_group, ) self.multibot_group.add_command(start_command) @@ -135,7 +145,7 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): name="stop", description="Stop a specific bot", callback=self.multibot_stop_callback, - parent=self.multibot_group + parent=self.multibot_group, ) self.multibot_group.add_command(stop_command) @@ -144,7 +154,7 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): name="startall", description="Start all configured bots", callback=self.multibot_startall_callback, - parent=self.multibot_group + parent=self.multibot_group, ) self.multibot_group.add_command(startall_command) @@ -153,7 +163,7 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): name="stopall", description="Stop all running bots", callback=self.multibot_stopall_callback, - parent=self.multibot_group + parent=self.multibot_group, ) self.multibot_group.add_command(stopall_command) @@ -162,7 +172,7 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): name="list", description="List all configured bots and their status", callback=self.multibot_list_callback, - parent=self.multibot_group + parent=self.multibot_group, ) self.multibot_group.add_command(list_command) @@ -240,14 +250,22 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): if thread.is_alive(): try: # Find and kill the process by looking for Python processes with multi_bot.py - for proc in psutil.process_iter(['pid', 'name', 'cmdline']): + for proc in psutil.process_iter(["pid", "name", "cmdline"]): try: - cmdline = proc.info['cmdline'] - if cmdline and 'python' in cmdline[0].lower() and any('multi_bot.py' in arg for arg in cmdline if arg): + cmdline = proc.info["cmdline"] + if ( + cmdline + and "python" in cmdline[0].lower() + and any("multi_bot.py" in arg for arg in cmdline if arg) + ): # This is likely our bot process proc.terminate() break - except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): + except ( + psutil.NoSuchProcess, + psutil.AccessDenied, + psutil.ZombieProcess, + ): pass # Remove from our tracking @@ -287,8 +305,12 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): continue # Check if already running - if (bot_id in self.bot_processes and self.bot_processes[bot_id].poll() is None) or \ - (bot_id in self.bot_threads and self.bot_threads[bot_id].is_alive()): + if ( + bot_id in self.bot_processes + and self.bot_processes[bot_id].poll() is None + ) or ( + bot_id in self.bot_threads and self.bot_threads[bot_id].is_alive() + ): already_running += 1 continue @@ -350,14 +372,22 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): if thread.is_alive(): try: # Find and kill the process - for proc in psutil.process_iter(['pid', 'name', 'cmdline']): + for proc in psutil.process_iter(["pid", "name", "cmdline"]): try: - cmdline = proc.info['cmdline'] - if cmdline and 'python' in cmdline[0].lower() and any('multi_bot.py' in arg for arg in cmdline if arg): + cmdline = proc.info["cmdline"] + if ( + cmdline + and "python" in cmdline[0].lower() + and any("multi_bot.py" in arg for arg in cmdline if arg) + ): # This is likely our bot process proc.terminate() break - except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): + except ( + psutil.NoSuchProcess, + psutil.AccessDenied, + psutil.ZombieProcess, + ): pass del self.bot_threads[bot_id] @@ -381,6 +411,7 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): process.terminate() # Wait a bit for it to terminate import time + time.sleep(1) # If still running, kill it if process.poll() is None: @@ -393,14 +424,22 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): if thread.is_alive(): try: # Find and kill the process - for proc in psutil.process_iter(['pid', 'name', 'cmdline']): + for proc in psutil.process_iter(["pid", "name", "cmdline"]): try: - cmdline = proc.info['cmdline'] - if cmdline and 'python' in cmdline[0].lower() and any('multi_bot.py' in arg for arg in cmdline if arg): + cmdline = proc.info["cmdline"] + if ( + cmdline + and "python" in cmdline[0].lower() + and any("multi_bot.py" in arg for arg in cmdline if arg) + ): # This is likely our bot process proc.terminate() break - except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): + except ( + psutil.NoSuchProcess, + psutil.AccessDenied, + psutil.ZombieProcess, + ): pass except Exception as e: print(f"Error stopping bot {bot_id}: {e}") @@ -418,7 +457,7 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): embed = discord.Embed( title="Configured Bots", description="List of all configured bots and their status", - color=discord.Color.blue() + color=discord.Color.blue(), ) # Load the configuration @@ -441,7 +480,10 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): is_running = False run_type = "" - if bot_id in self.bot_processes and self.bot_processes[bot_id].poll() is None: + if ( + bot_id in self.bot_processes + and self.bot_processes[bot_id].poll() is None + ): is_running = True run_type = "process" elif bot_id in self.bot_threads and self.bot_threads[bot_id].is_alive(): @@ -460,18 +502,16 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): run_status = f"Running ({run_type})" if is_running else "Stopped" - bot_list.append(f"**Bot ID**: {bot_id}\n**Status**: {run_status}\n**Prefix**: {prefix}\n**Activity**: {status_type.capitalize()} {status_text}\n**System Prompt**: {system_prompt}\n") + bot_list.append( + f"**Bot ID**: {bot_id}\n**Status**: {run_status}\n**Prefix**: {prefix}\n**Activity**: {status_type.capitalize()} {status_text}\n**System Prompt**: {system_prompt}\n" + ) if not bot_list: embed.description = "No bots configured." return embed for i, bot_info in enumerate(bot_list): - embed.add_field( - name=f"Bot {i+1}", - value=bot_info, - inline=False - ) + embed.add_field(name=f"Bot {i+1}", value=bot_info, inline=False) return embed @@ -514,7 +554,9 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): with open(CONFIG_FILE, "w", encoding="utf-8") as f: json.dump(config, f, indent=4, ensure_ascii=False) - await ctx.send(f"Token for bot {bot_id} has been updated. The message with your token has been deleted for security.") + await ctx.send( + f"Token for bot {bot_id} has been updated. The message with your token has been deleted for security." + ) except Exception as e: await ctx.send(f"Error setting bot token: {e}") @@ -582,10 +624,14 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): with open(CONFIG_FILE, "w", encoding="utf-8") as f: json.dump(config, f, indent=4, ensure_ascii=False) - await ctx.send(f"Command prefix for bot {bot_id} has been updated to '{prefix}'.") + await ctx.send( + f"Command prefix for bot {bot_id} has been updated to '{prefix}'." + ) # Notify that the bot needs to be restarted for the change to take effect - await ctx.send("Note: You need to restart the bot for this change to take effect.") + await ctx.send( + "Note: You need to restart the bot for this change to take effect." + ) except Exception as e: await ctx.send(f"Error setting bot prefix: {e}") @@ -616,7 +662,9 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): with open(CONFIG_FILE, "w", encoding="utf-8") as f: json.dump(config, f, indent=4, ensure_ascii=False) - await ctx.send("API key has been updated. The message with your API key has been deleted for security.") + await ctx.send( + "API key has been updated. The message with your API key has been deleted for security." + ) except Exception as e: await ctx.send(f"Error setting API key: {e}") @@ -648,7 +696,9 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): @commands.command(name="setbotstatus") @commands.is_owner() - async def set_bot_status(self, ctx, bot_id: str, status_type: str, *, status_text: str): + async def set_bot_status( + self, ctx, bot_id: str, status_type: str, *, status_text: str + ): """Set the status for a bot (Owner only) Status types: @@ -659,11 +709,19 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): - competing: "Competing in {status_text}" """ # Validate status type - valid_status_types = ["playing", "listening", "watching", "streaming", "competing"] + valid_status_types = [ + "playing", + "listening", + "watching", + "streaming", + "competing", + ] status_type = status_type.lower() if status_type not in valid_status_types: - await ctx.send(f"Invalid status type: '{status_type}'. Valid types are: {', '.join(valid_status_types)}") + await ctx.send( + f"Invalid status type: '{status_type}'. Valid types are: {', '.join(valid_status_types)}" + ) return # Load the configuration @@ -692,12 +750,16 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): with open(CONFIG_FILE, "w") as f: json.dump(config, f, indent=4) - await ctx.send(f"Status for bot {bot_id} has been updated to '{status_type.capitalize()} {status_text}'.") + await ctx.send( + f"Status for bot {bot_id} has been updated to '{status_type.capitalize()} {status_text}'." + ) # Check if the bot is running and update its status if bot_id in self.bot_threads and self.bot_threads[bot_id].is_alive(): # We can't directly update the status of a bot running in a thread - await ctx.send("Note: You need to restart the bot for this change to take effect.") + await ctx.send( + "Note: You need to restart the bot for this change to take effect." + ) except Exception as e: await ctx.send(f"Error setting bot status: {e}") @@ -715,11 +777,19 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): - competing: "Competing in {status_text}" """ # Validate status type - valid_status_types = ["playing", "listening", "watching", "streaming", "competing"] + valid_status_types = [ + "playing", + "listening", + "watching", + "streaming", + "competing", + ] status_type = status_type.lower() if status_type not in valid_status_types: - await ctx.send(f"Invalid status type: '{status_type}'. Valid types are: {', '.join(valid_status_types)}") + await ctx.send( + f"Invalid status type: '{status_type}'. Valid types are: {', '.join(valid_status_types)}" + ) return # Load the configuration @@ -746,12 +816,20 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): with open(CONFIG_FILE, "w") as f: json.dump(config, f, indent=4) - await ctx.send(f"Status for all {updated_count} bots has been updated to '{status_type.capitalize()} {status_text}'.") + await ctx.send( + f"Status for all {updated_count} bots has been updated to '{status_type.capitalize()} {status_text}'." + ) # Check if any bots are running - running_bots = [bot_id for bot_id, thread in self.bot_threads.items() if thread.is_alive()] + running_bots = [ + bot_id + for bot_id, thread in self.bot_threads.items() + if thread.is_alive() + ] if running_bots: - await ctx.send(f"Note: You need to restart the following bots for this change to take effect: {', '.join(running_bots)}") + await ctx.send( + f"Note: You need to restart the following bots for this change to take effect: {', '.join(running_bots)}" + ) except Exception as e: await ctx.send(f"Error setting all bot statuses: {e}") @@ -786,7 +864,7 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): "temperature": 0.7, "timeout": 60, "status_type": "listening", - "status_text": f"{prefix}ai" + "status_text": f"{prefix}ai", } # Add to the configuration @@ -799,8 +877,12 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): with open(CONFIG_FILE, "w") as f: json.dump(config, f, indent=4) - await ctx.send(f"Bot '{bot_id}' added to configuration with prefix '{prefix}'.") - await ctx.send("Note: You need to set a token for this bot using the `!setbottoken` command before starting it.") + await ctx.send( + f"Bot '{bot_id}' added to configuration with prefix '{prefix}'." + ) + await ctx.send( + "Note: You need to set a token for this bot using the `!setbottoken` command before starting it." + ) except Exception as e: await ctx.send(f"Error adding bot: {e}") @@ -819,8 +901,10 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): config = json.load(f) # Check if the bot is running and stop it - if (bot_id in self.bot_processes and self.bot_processes[bot_id].poll() is None) or \ - (bot_id in self.bot_threads and self.bot_threads[bot_id].is_alive()): + if ( + bot_id in self.bot_processes + and self.bot_processes[bot_id].poll() is None + ) or (bot_id in self.bot_threads and self.bot_threads[bot_id].is_alive()): await self.stop_bot(ctx, bot_id) # Find and remove the bot configuration @@ -842,5 +926,6 @@ class MultiBotCog(commands.Cog, name="Multi Bot"): except Exception as e: await ctx.send(f"Error removing bot: {e}") + async def setup(bot): await bot.add_cog(MultiBotCog(bot)) diff --git a/cogs/neru_message_cog.py b/cogs/neru_message_cog.py index b5be0ce..a1a3f44 100644 --- a/cogs/neru_message_cog.py +++ b/cogs/neru_message_cog.py @@ -10,11 +10,12 @@ from .rp_messages import ( get_hug_messages, get_headpat_messages, MOLEST_MESSAGE_TEMPLATE, - get_cumshot_messages + get_cumshot_messages, ) log = logging.getLogger(__name__) + class MessageCog(commands.Cog): def __init__(self, bot): self.bot = bot @@ -25,13 +26,14 @@ class MessageCog(commands.Cog): async def _ensure_usage_table_exists(self): """Ensure the command usage counters table exists.""" - if not hasattr(self.bot, 'pg_pool') or not self.bot.pg_pool: + if not hasattr(self.bot, "pg_pool") or not self.bot.pg_pool: log.warning("Database pool not available for usage tracking.") return False try: async with self.bot.pg_pool.acquire() as conn: - await conn.execute(""" + await conn.execute( + """ CREATE TABLE IF NOT EXISTS command_usage_counters ( user1_id BIGINT NOT NULL, user2_id BIGINT NOT NULL, @@ -39,46 +41,65 @@ class MessageCog(commands.Cog): usage_count INTEGER NOT NULL DEFAULT 1, PRIMARY KEY (user1_id, user2_id, command_name) ) - """) + """ + ) return True except Exception as e: log.error(f"Error creating usage counters table: {e}") return False - async def _increment_usage_counter(self, user1_id: int, user2_id: int, command_name: str): + async def _increment_usage_counter( + self, user1_id: int, user2_id: int, command_name: str + ): """Increment the usage counter for a command between two users.""" if not await self._ensure_usage_table_exists(): return try: async with self.bot.pg_pool.acquire() as conn: - await conn.execute(""" + await conn.execute( + """ INSERT INTO command_usage_counters (user1_id, user2_id, command_name, usage_count) VALUES ($1, $2, $3, 1) ON CONFLICT (user1_id, user2_id, command_name) DO UPDATE SET usage_count = command_usage_counters.usage_count + 1 - """, user1_id, user2_id, command_name) - log.debug(f"Incremented usage counter for {command_name} between users {user1_id} and {user2_id}") + """, + user1_id, + user2_id, + command_name, + ) + log.debug( + f"Incremented usage counter for {command_name} between users {user1_id} and {user2_id}" + ) except Exception as e: log.error(f"Error incrementing usage counter: {e}") - async def _get_usage_count(self, user1_id: int, user2_id: int, command_name: str) -> int: + async def _get_usage_count( + self, user1_id: int, user2_id: int, command_name: str + ) -> int: """Get the usage count for a command between two users.""" if not await self._ensure_usage_table_exists(): return 0 try: async with self.bot.pg_pool.acquire() as conn: - count = await conn.fetchval(""" + count = await conn.fetchval( + """ SELECT usage_count FROM command_usage_counters WHERE user1_id = $1 AND user2_id = $2 AND command_name = $3 - """, user1_id, user2_id, command_name) + """, + user1_id, + user2_id, + command_name, + ) return count if count is not None else 0 except Exception as e: log.error(f"Error getting usage count: {e}") return 0 - async def _get_bidirectional_usage_counts(self, user1_id: int, user2_id: int, command_name: str) -> tuple[int, int]: + async def _get_bidirectional_usage_counts( + self, user1_id: int, user2_id: int, command_name: str + ) -> tuple[int, int]: """Get the usage counts for a command in both directions between two users. Returns: @@ -90,19 +111,31 @@ class MessageCog(commands.Cog): try: async with self.bot.pg_pool.acquire() as conn: # Get count for user1 -> user2 - count_1_to_2 = await conn.fetchval(""" + count_1_to_2 = await conn.fetchval( + """ SELECT usage_count FROM command_usage_counters WHERE user1_id = $1 AND user2_id = $2 AND command_name = $3 - """, user1_id, user2_id, command_name) + """, + user1_id, + user2_id, + command_name, + ) # Get count for user2 -> user1 - count_2_to_1 = await conn.fetchval(""" + count_2_to_1 = await conn.fetchval( + """ SELECT usage_count FROM command_usage_counters WHERE user1_id = $1 AND user2_id = $2 AND command_name = $3 - """, user2_id, user1_id, command_name) + """, + user2_id, + user1_id, + command_name, + ) - return (count_1_to_2 if count_1_to_2 is not None else 0, - count_2_to_1 if count_2_to_1 is not None else 0) + return ( + count_1_to_2 if count_1_to_2 is not None else 0, + count_2_to_1 if count_2_to_1 is not None else 0, + ) except Exception as e: log.error(f"Error getting bidirectional usage counts: {e}") return 0, 0 @@ -110,7 +143,9 @@ class MessageCog(commands.Cog): # --- RP Group --- rp = app_commands.Group(name="rp", description="Roleplay commands") - @rp.command(name="sex", description="Send a normal sex message to the mentioned user") + @rp.command( + name="sex", description="Send a normal sex message to the mentioned user" + ) @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.describe(member="The user to send the message to") @@ -120,9 +155,13 @@ class MessageCog(commands.Cog): await self._increment_usage_counter(interaction.user.id, member.id, "neru_sex") # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(interaction.user.id, member.id, "neru_sex") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + interaction.user.id, member.id, "neru_sex" + ) - response = random.choice(get_sex_messages(interaction.user.mention, member.mention)) + response = random.choice( + get_sex_messages(interaction.user.mention, member.mention) + ) response += f"\n-# {interaction.user.display_name} and {member.display_name} have had sex {caller_to_target} {self.plural('time', caller_to_target)}" if target_to_caller > 0: response += f", {member.display_name} and {interaction.user.display_name} have had sex {target_to_caller} {self.plural('time', target_to_caller)}" @@ -135,7 +174,9 @@ class MessageCog(commands.Cog): await self._increment_usage_counter(ctx.author.id, member.id, "neru_sex") # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(ctx.author.id, member.id, "neru_sex") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + ctx.author.id, member.id, "neru_sex" + ) response = random.choice(get_sex_messages(ctx.author.mention, member.mention)) response += f"\n-# {ctx.author.display_name} and {member.display_name} have had sex {caller_to_target} {self.plural('time', caller_to_target)}" @@ -143,7 +184,10 @@ class MessageCog(commands.Cog): response += f", {member.display_name} and {ctx.author.display_name} have had sex {target_to_caller} {self.plural('time', target_to_caller)}" await ctx.reply(response) - @rp.command(name="rape", description="Sends a message stating the author raped the mentioned user.") + @rp.command( + name="rape", + description="Sends a message stating the author raped the mentioned user.", + ) @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.describe(member="The user to mention in the message") @@ -153,15 +197,21 @@ class MessageCog(commands.Cog): await self._increment_usage_counter(interaction.user.id, member.id, "neru_rape") # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(interaction.user.id, member.id, "neru_rape") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + interaction.user.id, member.id, "neru_rape" + ) - response = random.choice(get_rape_messages(interaction.user.mention, member.mention)) + response = random.choice( + get_rape_messages(interaction.user.mention, member.mention) + ) response += f"\n-# {interaction.user.display_name} has raped {member.display_name} {caller_to_target} {self.plural('time', caller_to_target)}" if target_to_caller > 0: response += f", {member.display_name} has raped {interaction.user.display_name} {target_to_caller} {self.plural('time', target_to_caller)}" await interaction.response.send_message(response) - @rp.command(name="kiss", description="Send a wholesome kiss message to the mentioned user") + @rp.command( + name="kiss", description="Send a wholesome kiss message to the mentioned user" + ) @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.describe(member="The user to send the message to") @@ -171,9 +221,13 @@ class MessageCog(commands.Cog): await self._increment_usage_counter(interaction.user.id, member.id, "neru_kiss") # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(interaction.user.id, member.id, "neru_kiss") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + interaction.user.id, member.id, "neru_kiss" + ) - response = random.choice(get_kiss_messages(interaction.user.mention, member.mention)) + response = random.choice( + get_kiss_messages(interaction.user.mention, member.mention) + ) response += f"\n-# {interaction.user.display_name} and {member.display_name} have kissed {caller_to_target} {self.plural('time', caller_to_target)}" if target_to_caller > 0: response += f", {member.display_name} and {interaction.user.display_name} have kissed {target_to_caller} {self.plural('time', target_to_caller)}" @@ -186,7 +240,9 @@ class MessageCog(commands.Cog): await self._increment_usage_counter(ctx.author.id, member.id, "neru_kiss") # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(ctx.author.id, member.id, "neru_kiss") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + ctx.author.id, member.id, "neru_kiss" + ) response = random.choice(get_kiss_messages(ctx.author.mention, member.mention)) response += f"\n-# {ctx.author.display_name} and {member.display_name} have kissed {caller_to_target} {self.plural('time', caller_to_target)}" @@ -194,7 +250,9 @@ class MessageCog(commands.Cog): response += f", {member.display_name} and {ctx.author.display_name} have kissed {target_to_caller} {self.plural('time', target_to_caller)}" await ctx.reply(response) - @rp.command(name="hug", description="Send a wholesome hug message to the mentioned user") + @rp.command( + name="hug", description="Send a wholesome hug message to the mentioned user" + ) @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.describe(member="The user to send the message to") @@ -204,9 +262,13 @@ class MessageCog(commands.Cog): await self._increment_usage_counter(interaction.user.id, member.id, "neru_hug") # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(interaction.user.id, member.id, "neru_hug") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + interaction.user.id, member.id, "neru_hug" + ) - response = random.choice(get_hug_messages(interaction.user.mention, member.mention)) + response = random.choice( + get_hug_messages(interaction.user.mention, member.mention) + ) response += f"\n-# {interaction.user.display_name} and {member.display_name} have hugged {caller_to_target} {self.plural('time', caller_to_target)}" if target_to_caller > 0: response += f", {member.display_name} and {interaction.user.display_name} have hugged {target_to_caller} {self.plural('time', target_to_caller)}" @@ -219,7 +281,9 @@ class MessageCog(commands.Cog): await self._increment_usage_counter(ctx.author.id, member.id, "neru_hug") # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(ctx.author.id, member.id, "neru_hug") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + ctx.author.id, member.id, "neru_hug" + ) response = random.choice(get_hug_messages(ctx.author.mention, member.mention)) response += f"\n-# {ctx.author.display_name} and {member.display_name} have hugged {caller_to_target} {self.plural('time', caller_to_target)}" @@ -227,19 +291,30 @@ class MessageCog(commands.Cog): response += f", {member.display_name} and {ctx.author.display_name} have hugged {target_to_caller} {self.plural('time', target_to_caller)}" await ctx.reply(response) - @rp.command(name="headpat", description="Send a wholesome headpat message to the mentioned user") + @rp.command( + name="headpat", + description="Send a wholesome headpat message to the mentioned user", + ) @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.describe(member="The user to send the message to") - async def headpat_slash(self, interaction: discord.Interaction, member: discord.User): + async def headpat_slash( + self, interaction: discord.Interaction, member: discord.User + ): """Slash command version of headpat.""" # Track usage between the two users - await self._increment_usage_counter(interaction.user.id, member.id, "neru_headpat") + await self._increment_usage_counter( + interaction.user.id, member.id, "neru_headpat" + ) # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(interaction.user.id, member.id, "neru_headpat") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + interaction.user.id, member.id, "neru_headpat" + ) - response = random.choice(get_headpat_messages(interaction.user.mention, member.mention)) + response = random.choice( + get_headpat_messages(interaction.user.mention, member.mention) + ) response += f"\n-# {interaction.user.display_name} and {member.display_name} have headpatted {caller_to_target} {self.plural('time', caller_to_target)}" if target_to_caller > 0: response += f", {member.display_name} and {interaction.user.display_name} have headpatted {target_to_caller} {self.plural('time', target_to_caller)}" @@ -252,25 +327,39 @@ class MessageCog(commands.Cog): await self._increment_usage_counter(ctx.author.id, member.id, "neru_headpat") # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(ctx.author.id, member.id, "neru_headpat") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + ctx.author.id, member.id, "neru_headpat" + ) - response = random.choice(get_headpat_messages(ctx.author.mention, member.mention)) + response = random.choice( + get_headpat_messages(ctx.author.mention, member.mention) + ) # Get the bidirectional counts - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(ctx.author.id, member.id, "neru_headpat") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + ctx.author.id, member.id, "neru_headpat" + ) response += f"\n-# {ctx.author.display_name} and {member.display_name} have headpatted {caller_to_target} {self.plural('time', caller_to_target)}" if target_to_caller > 0: response += f", {member.display_name} and {ctx.author.display_name} have headpatted {target_to_caller} {self.plural('time', target_to_caller)}" await ctx.reply(response) - @rp.command(name="molest", description="Send a hardcoded message to the mentioned user") + @rp.command( + name="molest", description="Send a hardcoded message to the mentioned user" + ) @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.describe(member="The user to send the message to") - async def molest_slash(self, interaction: discord.Interaction, member: discord.User): + async def molest_slash( + self, interaction: discord.Interaction, member: discord.User + ): """Slash command version of molest.""" - await self._increment_usage_counter(interaction.user.id, member.id, "neru_molest") - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(interaction.user.id, member.id, "neru_molest") + await self._increment_usage_counter( + interaction.user.id, member.id, "neru_molest" + ) + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + interaction.user.id, member.id, "neru_molest" + ) response = MOLEST_MESSAGE_TEMPLATE.format(target=member.mention) response += f"\n-# {interaction.user.display_name} has molested {member.display_name} {caller_to_target} {self.plural('time', caller_to_target)}" @@ -282,7 +371,9 @@ class MessageCog(commands.Cog): async def molest_legacy(self, ctx: commands.Context, member: discord.User): """Legacy command version of molest.""" await self._increment_usage_counter(ctx.author.id, member.id, "neru_molest") - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(ctx.author.id, member.id, "neru_molest") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + ctx.author.id, member.id, "neru_molest" + ) response = MOLEST_MESSAGE_TEMPLATE.format(target=member.mention) response += f"\n-# {ctx.author.display_name} has molested {member.display_name} {caller_to_target} {self.plural('time', caller_to_target)}" @@ -290,16 +381,26 @@ class MessageCog(commands.Cog): response += f", {member.display_name} has molested {ctx.author.display_name} {target_to_caller} {self.plural('time', target_to_caller)}" await ctx.reply(response) - @rp.command(name="cumshot", description="Send a cumshot message to the mentioned user") + @rp.command( + name="cumshot", description="Send a cumshot message to the mentioned user" + ) @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) @app_commands.describe(member="The user to send the message to") - async def cumshot_slash(self, interaction: discord.Interaction, member: discord.User): + async def cumshot_slash( + self, interaction: discord.Interaction, member: discord.User + ): """Slash command version of cumshot.""" - await self._increment_usage_counter(interaction.user.id, member.id, "neru_cumshot") - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(interaction.user.id, member.id, "neru_cumshot") + await self._increment_usage_counter( + interaction.user.id, member.id, "neru_cumshot" + ) + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + interaction.user.id, member.id, "neru_cumshot" + ) - response = random.choice(get_cumshot_messages(interaction.user.mention, member.mention)) + response = random.choice( + get_cumshot_messages(interaction.user.mention, member.mention) + ) response += f"\n-# {interaction.user.display_name} has came on {member.display_name} {caller_to_target} {self.plural('time', caller_to_target)}" if target_to_caller > 0: response += f", {member.display_name} has came on {interaction.user.display_name} {target_to_caller} {self.plural('time', target_to_caller)}" @@ -309,9 +410,13 @@ class MessageCog(commands.Cog): async def cumshot_legacy(self, ctx: commands.Context, member: discord.User): """Legacy command version of cumshot.""" await self._increment_usage_counter(ctx.author.id, member.id, "neru_cumshot") - caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts(ctx.author.id, member.id, "neru_cumshot") + caller_to_target, target_to_caller = await self._get_bidirectional_usage_counts( + ctx.author.id, member.id, "neru_cumshot" + ) - response = random.choice(get_cumshot_messages(ctx.author.mention, member.mention)) + response = random.choice( + get_cumshot_messages(ctx.author.mention, member.mention) + ) response += f"\n-# {ctx.author.display_name} has came on {member.display_name} {caller_to_target} {self.plural('time', caller_to_target)}" if target_to_caller > 0: response += f", {member.display_name} has came on {ctx.author.display_name} {target_to_caller} {self.plural('time', target_to_caller)}" @@ -320,35 +425,63 @@ class MessageCog(commands.Cog): # --- Memes Group --- memes = app_commands.Group(name="memes", description="Meme and copypasta commands") - @memes.command(name="seals", description="What the fuck did you just fucking say about me, you little bitch?") + @memes.command( + name="seals", + description="What the fuck did you just fucking say about me, you little bitch?", + ) @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) async def seals_slash(self, interaction: discord.Interaction): - await interaction.response.send_message("What the fuck did you just fucking say about me, you little bitch? I'll have you know I graduated top of my class in the Navy Seals, and I've been involved in numerous secret raids on Al-Quaeda, and I have over 300 confirmed kills. I am trained in gorilla warfare and I'm the top sniper in the entire US armed forces. You are nothing to me but just another target. I will wipe you the fuck out with precision the likes of which has never been seen before on this Earth, mark my fucking words. You think you can get away with saying that shit to me over the Internet? Think again, fucker. As we speak I am contacting my secret network of spies across the USA and your IP is being traced right now so you better prepare for the storm, maggot. The storm that wipes out the pathetic little thing you call your life. You're fucking dead, kid. I can be anywhere, anytime, and I can kill you in over seven hundred ways, and that's just with my bare hands. Not only am I extensively trained in unarmed combat, but I have access to the entire arsenal of the United States Marine Corps and I will use it to its full extent to wipe your miserable ass off the face of the continent, you little shit. If only you could have known what unholy retribution your little \"clever\" comment was about to bring down upon you, maybe you would have held your fucking tongue. But you couldn't, you didn't, and now you're paying the price, you goddamn idiot. I will shit fury all over you and you will drown in it. You're fucking dead, kiddo.") + await interaction.response.send_message( + "What the fuck did you just fucking say about me, you little bitch? I'll have you know I graduated top of my class in the Navy Seals, and I've been involved in numerous secret raids on Al-Quaeda, and I have over 300 confirmed kills. I am trained in gorilla warfare and I'm the top sniper in the entire US armed forces. You are nothing to me but just another target. I will wipe you the fuck out with precision the likes of which has never been seen before on this Earth, mark my fucking words. You think you can get away with saying that shit to me over the Internet? Think again, fucker. As we speak I am contacting my secret network of spies across the USA and your IP is being traced right now so you better prepare for the storm, maggot. The storm that wipes out the pathetic little thing you call your life. You're fucking dead, kid. I can be anywhere, anytime, and I can kill you in over seven hundred ways, and that's just with my bare hands. Not only am I extensively trained in unarmed combat, but I have access to the entire arsenal of the United States Marine Corps and I will use it to its full extent to wipe your miserable ass off the face of the continent, you little shit. If only you could have known what unholy retribution your little \"clever\" comment was about to bring down upon you, maybe you would have held your fucking tongue. But you couldn't, you didn't, and now you're paying the price, you goddamn idiot. I will shit fury all over you and you will drown in it. You're fucking dead, kiddo." + ) - @commands.command(name="seals", help="What the fuck did you just fucking say about me, you little bitch?") # Assuming you want to keep this check for the legacy command + @commands.command( + name="seals", + help="What the fuck did you just fucking say about me, you little bitch?", + ) # Assuming you want to keep this check for the legacy command async def seals_legacy(self, ctx): - await ctx.send("What the fuck did you just fucking say about me, you little bitch? I'll have you know I graduated top of my class in the Navy Seals, and I've been involved in numerous secret raids on Al-Quaeda, and I have over 300 confirmed kills. I am trained in gorilla warfare and I'm the top sniper in the entire US armed forces. You are nothing to me but just another target. I will wipe you the fuck out with precision the likes of which has never been seen before on this Earth, mark my fucking words. You think you can get away with saying that shit to me over the Internet? Think again, fucker. As we speak I am contacting my secret network of spies across the USA and your IP is being traced right now so you better prepare for the storm, maggot. The storm that wipes out the pathetic little thing you call your life. You're fucking dead, kid. I can be anywhere, anytime, and I can kill you in over seven hundred ways, and that's just with my bare hands. Not only am I extensively trained in unarmed combat, but I have access to the entire arsenal of the United States Marine Corps and I will use it to its full extent to wipe your miserable ass off the face of the continent, you little shit. If only you could have known what unholy retribution your little \"clever\" comment was about to bring down upon you, maybe you would have held your fucking tongue. But you couldn't, you didn't, and now you're paying the price, you goddamn idiot. I will shit fury all over you and you will drown in it. You're fucking dead, kiddo.") + await ctx.send( + "What the fuck did you just fucking say about me, you little bitch? I'll have you know I graduated top of my class in the Navy Seals, and I've been involved in numerous secret raids on Al-Quaeda, and I have over 300 confirmed kills. I am trained in gorilla warfare and I'm the top sniper in the entire US armed forces. You are nothing to me but just another target. I will wipe you the fuck out with precision the likes of which has never been seen before on this Earth, mark my fucking words. You think you can get away with saying that shit to me over the Internet? Think again, fucker. As we speak I am contacting my secret network of spies across the USA and your IP is being traced right now so you better prepare for the storm, maggot. The storm that wipes out the pathetic little thing you call your life. You're fucking dead, kid. I can be anywhere, anytime, and I can kill you in over seven hundred ways, and that's just with my bare hands. Not only am I extensively trained in unarmed combat, but I have access to the entire arsenal of the United States Marine Corps and I will use it to its full extent to wipe your miserable ass off the face of the continent, you little shit. If only you could have known what unholy retribution your little \"clever\" comment was about to bring down upon you, maybe you would have held your fucking tongue. But you couldn't, you didn't, and now you're paying the price, you goddamn idiot. I will shit fury all over you and you will drown in it. You're fucking dead, kiddo." + ) - @memes.command(name="notlikeus", description="Honestly i think They Not Like Us is the only mumble rap song that is good") + @memes.command( + name="notlikeus", + description="Honestly i think They Not Like Us is the only mumble rap song that is good", + ) @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) async def notlikeus_slash(self, interaction: discord.Interaction): - await interaction.response.send_message("Honestly i think They Not Like Us is the only mumble rap song that is good, because it calls out Drake for being a Diddy blud") + await interaction.response.send_message( + "Honestly i think They Not Like Us is the only mumble rap song that is good, because it calls out Drake for being a Diddy blud" + ) - @commands.command(name="notlikeus", help="Honestly i think They Not Like Us is the only mumble rap song that is good") # Assuming you want to keep this check for the legacy command + @commands.command( + name="notlikeus", + help="Honestly i think They Not Like Us is the only mumble rap song that is good", + ) # Assuming you want to keep this check for the legacy command async def notlikeus_legacy(self, ctx): - await ctx.send("Honestly i think They Not Like Us is the only mumble rap song that is good, because it calls out Drake for being a Diddy blud") + await ctx.send( + "Honestly i think They Not Like Us is the only mumble rap song that is good, because it calls out Drake for being a Diddy blud" + ) @memes.command(name="pmo", description="icl u pmo") @app_commands.allowed_installs(guilds=True, users=True) @app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True) async def pmo_slash(self, interaction: discord.Interaction): - await interaction.response.send_message("icl u pmo n ts pmo sm ngl r u fr rn b fr I h8 bein diff idek anm mn js I h8 ts y r u so b so fr w me rn cz lol oms icl ts pmo sm n sb rn ngl, r u srsly srs n fr rn vro? lol atp js qt") + await interaction.response.send_message( + "icl u pmo n ts pmo sm ngl r u fr rn b fr I h8 bein diff idek anm mn js I h8 ts y r u so b so fr w me rn cz lol oms icl ts pmo sm n sb rn ngl, r u srsly srs n fr rn vro? lol atp js qt" + ) - @commands.command(name="pmo", help="icl u pmo n ts pmo sm ngl r u fr rn b fr I h8 bein diff idek anm mn js I h8 ts y r u so b so fr w me rn cz lol oms icl ts pmo sm n sb rn ngl, r u srsly srs n fr rn vro? lol atp js qt") + @commands.command( + name="pmo", + help="icl u pmo n ts pmo sm ngl r u fr rn b fr I h8 bein diff idek anm mn js I h8 ts y r u so b so fr w me rn cz lol oms icl ts pmo sm n sb rn ngl, r u srsly srs n fr rn vro? lol atp js qt", + ) async def pmo_legacy(self, ctx: commands.Context): - await ctx.send("icl u pmo n ts pmo sm ngl r u fr rn b fr I h8 bein diff idek anm mn js I h8 ts y r u so b so fr w me rn cz lol oms icl ts pmo sm n sb rn ngl, r u srsly srs n fr rn vro? lol atp js qt") + await ctx.send( + "icl u pmo n ts pmo sm ngl r u fr rn b fr I h8 bein diff idek anm mn js I h8 ts y r u so b so fr w me rn cz lol oms icl ts pmo sm n sb rn ngl, r u srsly srs n fr rn vro? lol atp js qt" + ) + async def setup(bot: commands.Bot): await bot.add_cog(MessageCog(bot)) diff --git a/cogs/neru_roleplay_cog.py b/cogs/neru_roleplay_cog.py index f4dc2c9..1b4f13d 100644 --- a/cogs/neru_roleplay_cog.py +++ b/cogs/neru_roleplay_cog.py @@ -2,6 +2,7 @@ import discord from discord.ext import commands from discord import app_commands + class RoleplayCog(commands.Cog): def __init__(self, bot): self.bot = bot @@ -17,23 +18,38 @@ class RoleplayCog(commands.Cog): # --- Prefix Command --- @commands.command(name="backshots") - async def backshots(self, ctx: commands.Context, recipient: discord.User, reverse: bool = False): + async def backshots( + self, ctx: commands.Context, recipient: discord.User, reverse: bool = False + ): """Send a roleplay message about giving backshots to a mentioned user.""" sender_mention = ctx.author.mention - response = await self._backshots_logic(sender_mention, recipient.mention, reverse=reverse) + response = await self._backshots_logic( + sender_mention, recipient.mention, reverse=reverse + ) await ctx.send(response) # --- Slash Command --- - @app_commands.command(name="backshots", description="Send a roleplay message about giving backshots to a mentioned user") + @app_commands.command( + name="backshots", + description="Send a roleplay message about giving backshots to a mentioned user", + ) @app_commands.describe( recipient="The user receiving backshots", - reverse="Reverse the roles of the sender and recipient" + reverse="Reverse the roles of the sender and recipient", ) - async def backshots_slash(self, interaction: discord.Interaction, recipient: discord.User, reverse: bool = False): + async def backshots_slash( + self, + interaction: discord.Interaction, + recipient: discord.User, + reverse: bool = False, + ): """Slash command version of backshots.""" sender_mention = interaction.user.mention - response = await self._backshots_logic(sender_mention, recipient.mention, reverse=reverse) + response = await self._backshots_logic( + sender_mention, recipient.mention, reverse=reverse + ) await interaction.response.send_message(response) + async def setup(bot: commands.Bot): await bot.add_cog(RoleplayCog(bot)) diff --git a/cogs/neru_teto_cog.py b/cogs/neru_teto_cog.py index 0cb6aa6..d994fc1 100644 --- a/cogs/neru_teto_cog.py +++ b/cogs/neru_teto_cog.py @@ -6,12 +6,15 @@ import base64 import io from typing import Optional + def strip_think_blocks(text): # Removes all ... blocks, including multiline return re.sub(r".*?", "", text, flags=re.DOTALL) + def encode_image_to_base64(image_data): - return base64.b64encode(image_data).decode('utf-8') + return base64.b64encode(image_data).decode("utf-8") + # In-memory conversation history for Kasane Teto AI (keyed by channel id) _teto_conversations = {} @@ -26,14 +29,26 @@ from google.api_core import exceptions as google_exceptions from gurt.config import PROJECT_ID, LOCATION STANDARD_SAFETY_SETTINGS = [ - types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold="BLOCK_NONE"), - types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold="BLOCK_NONE"), - types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold="BLOCK_NONE"), - types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold="BLOCK_NONE"), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold="BLOCK_NONE" + ), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold="BLOCK_NONE", + ), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold="BLOCK_NONE", + ), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold="BLOCK_NONE" + ), ] -def _get_response_text(response: Optional[types.GenerateContentResponse]) -> Optional[str]: +def _get_response_text( + response: Optional[types.GenerateContentResponse], +) -> Optional[str]: """Extract text from a Vertex AI response if available.""" if not response: return None @@ -46,12 +61,17 @@ def _get_response_text(response: Optional[types.GenerateContentResponse]) -> Opt if not getattr(candidate, "content", None) or not candidate.content.parts: return None for part in candidate.content.parts: - if hasattr(part, "text") and isinstance(part.text, str) and part.text.strip(): + if ( + hasattr(part, "text") + and isinstance(part.text, str) + and part.text.strip() + ): return part.text return None except (AttributeError, IndexError, TypeError): return None + class DmbotTetoCog(commands.Cog): def __init__(self, bot: commands.Bot): self.bot = bot @@ -67,7 +87,9 @@ class DmbotTetoCog(commands.Cog): except Exception: self.genai_client = None - self._ai_model = "gemini-2.5-flash-preview-05-20" # Default model used by TetoCog + self._ai_model = ( + "gemini-2.5-flash-preview-05-20" # Default model used by TetoCog + ) async def _teto_reply_ai_with_messages(self, messages, system_mode="reply"): """Use Vertex AI to generate a Kasane Teto-style response.""" @@ -91,14 +113,18 @@ class DmbotTetoCog(commands.Cog): for msg in messages: role = "user" if msg.get("role") == "user" else "model" contents.append( - types.Content(role=role, parts=[types.Part(text=msg.get("content", ""))]) + types.Content( + role=role, parts=[types.Part(text=msg.get("content", ""))] + ) ) generation_config = types.GenerateContentConfig( temperature=1.0, max_output_tokens=2000, safety_settings=STANDARD_SAFETY_SETTINGS, - system_instruction=types.Content(role="system", parts=[types.Part(text=system_prompt)]), + system_instruction=types.Content( + role="system", parts=[types.Part(text=system_prompt)] + ), ) try: @@ -117,45 +143,76 @@ class DmbotTetoCog(commands.Cog): async def _teto_reply_ai(self, text: str) -> str: """Replies to the text as Kasane Teto using AI via Vertex AI.""" - return await self._teto_reply_ai_with_messages([{"role": "user", "content": text}]) + return await self._teto_reply_ai_with_messages( + [{"role": "user", "content": text}] + ) - async def _send_followup_in_chunks(self, interaction: discord.Interaction, text: str, *, ephemeral: bool = True) -> None: + async def _send_followup_in_chunks( + self, interaction: discord.Interaction, text: str, *, ephemeral: bool = True + ) -> None: """Send a potentially long message in chunks using followup messages.""" chunk_size = 1900 - chunks = [text[i : i + chunk_size] for i in range(0, len(text), chunk_size)] or [""] + chunks = [ + text[i : i + chunk_size] for i in range(0, len(text), chunk_size) + ] or [""] for chunk in chunks: await interaction.followup.send(chunk, ephemeral=ephemeral) - teto = app_commands.Group(name="teto", description="Commands related to Kasane Teto.") - model = app_commands.Group(parent=teto, name="model", description="Commands related to Teto's AI model.") - endpoint = app_commands.Group(parent=teto, name="endpoint", description="Commands related to Teto's API endpoint.") - history = app_commands.Group(parent=teto, name="history", description="Commands related to Teto's chat history.") - + teto = app_commands.Group( + name="teto", description="Commands related to Kasane Teto." + ) + model = app_commands.Group( + parent=teto, name="model", description="Commands related to Teto's AI model." + ) + endpoint = app_commands.Group( + parent=teto, + name="endpoint", + description="Commands related to Teto's API endpoint.", + ) + history = app_commands.Group( + parent=teto, + name="history", + description="Commands related to Teto's chat history.", + ) @model.command(name="set", description="Sets the AI model for Teto.") @app_commands.describe(model_name="The name of the AI model to use.") async def set_ai_model(self, interaction: discord.Interaction, model_name: str): self._ai_model = model_name - await interaction.response.send_message(f"Teto's AI model set to: {model_name} desu~", ephemeral=True) + await interaction.response.send_message( + f"Teto's AI model set to: {model_name} desu~", ephemeral=True + ) @model.command(name="get", description="Gets the current AI model for Teto.") async def get_ai_model(self, interaction: discord.Interaction): - await interaction.response.send_message(f"Teto's current AI model is: {self._ai_model} desu~", ephemeral=True) + await interaction.response.send_message( + f"Teto's current AI model is: {self._ai_model} desu~", ephemeral=True + ) @endpoint.command(name="set", description="Sets the API endpoint for Teto.") @app_commands.describe(endpoint_url="The URL of the API endpoint.") - async def set_api_endpoint(self, interaction: discord.Interaction, endpoint_url: str): + async def set_api_endpoint( + self, interaction: discord.Interaction, endpoint_url: str + ): self._api_endpoint = endpoint_url - await interaction.response.send_message(f"Teto's API endpoint set to: {endpoint_url} desu~", ephemeral=True) + await interaction.response.send_message( + f"Teto's API endpoint set to: {endpoint_url} desu~", ephemeral=True + ) - @history.command(name="clear", description="Clears the chat history for the current channel.") + @history.command( + name="clear", description="Clears the chat history for the current channel." + ) async def clear_chat_history(self, interaction: discord.Interaction): channel_id = interaction.channel_id if channel_id in _teto_conversations: del _teto_conversations[channel_id] - await interaction.response.send_message("Chat history cleared for this channel desu~", ephemeral=True) + await interaction.response.send_message( + "Chat history cleared for this channel desu~", ephemeral=True + ) else: - await interaction.response.send_message("No chat history found for this channel desu~", ephemeral=True) + await interaction.response.send_message( + "No chat history found for this channel desu~", ephemeral=True + ) @teto.command(name="chat", description="Chat with Kasane Teto AI.") @app_commands.describe(message="Your message to Teto.") @@ -170,21 +227,31 @@ class DmbotTetoCog(commands.Cog): try: ai_reply = await self._teto_reply_ai_with_messages(messages=convo) ai_reply = strip_think_blocks(ai_reply) - await self._send_followup_in_chunks(interaction, ai_reply, ephemeral=True) + await self._send_followup_in_chunks( + interaction, ai_reply, ephemeral=True + ) convo.append({"role": "assistant", "content": ai_reply}) _teto_conversations[convo_key] = convo[-30:] # Keep last 30 messages except Exception as e: - await interaction.followup.send(f"Teto AI reply failed: {e} desu~", ephemeral=True) + await interaction.followup.send( + f"Teto AI reply failed: {e} desu~", ephemeral=True + ) else: - await interaction.followup.send("Please provide a message to chat with Teto desu~", ephemeral=True) + await interaction.followup.send( + "Please provide a message to chat with Teto desu~", ephemeral=True + ) # Context menu command must be defined at module level @app_commands.context_menu(name="Teto AI Reply") -async def teto_context_menu_ai_reply(interaction: discord.Interaction, message: discord.Message): +async def teto_context_menu_ai_reply( + interaction: discord.Interaction, message: discord.Message +): """Replies to the selected message as a Teto AI.""" if not message.content: - await interaction.response.send_message("The selected message has no text content to reply to! >.<", ephemeral=True) + await interaction.response.send_message( + "The selected message has no text content to reply to! >.<", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) @@ -196,9 +263,11 @@ async def teto_context_menu_ai_reply(interaction: discord.Interaction, message: convo.append({"role": "user", "content": message.content}) try: # Get the TetoCog instance from the bot - cog = interaction.client.get_cog("DmbotTetoCog") # Changed from TetoCog + cog = interaction.client.get_cog("DmbotTetoCog") # Changed from TetoCog if cog is None: - await interaction.followup.send("DmbotTetoCog is not loaded, cannot reply.", ephemeral=True) # Changed from TetoCog + await interaction.followup.send( + "DmbotTetoCog is not loaded, cannot reply.", ephemeral=True + ) # Changed from TetoCog return ai_reply = await cog._teto_reply_ai_with_messages(messages=convo) ai_reply = strip_think_blocks(ai_reply) @@ -206,10 +275,13 @@ async def teto_context_menu_ai_reply(interaction: discord.Interaction, message: convo.append({"role": "assistant", "content": ai_reply}) _teto_conversations[convo_key] = convo[-10:] except Exception as e: - await interaction.followup.send(f"Teto AI reply failed: {e} desu~", ephemeral=True) + await interaction.followup.send( + f"Teto AI reply failed: {e} desu~", ephemeral=True + ) + async def setup(bot: commands.Bot): - cog = DmbotTetoCog(bot) # Changed from TetoCog + cog = DmbotTetoCog(bot) # Changed from TetoCog await bot.add_cog(cog) bot.tree.add_command(teto_context_menu_ai_reply) - print("DmbotTetoCog loaded! desu~") # Changed from TetoCog + print("DmbotTetoCog loaded! desu~") # Changed from TetoCog diff --git a/cogs/oauth_cog.py b/cogs/oauth_cog.py index 1c56f44..3e4148e 100644 --- a/cogs/oauth_cog.py +++ b/cogs/oauth_cog.py @@ -15,10 +15,12 @@ from typing import Dict, Optional, Any # Import the OAuth modules import sys + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import discord_oauth import oauth_server + class OAuthCog(commands.Cog): def __init__(self, bot): self.bot = bot @@ -30,7 +32,11 @@ class OAuthCog(commands.Cog): async def start_oauth_server(self): """Start the OAuth callback server if API OAuth is not enabled.""" # Check if API OAuth is enabled - api_oauth_enabled = os.getenv("API_OAUTH_ENABLED", "true").lower() in ("true", "1", "yes") + api_oauth_enabled = os.getenv("API_OAUTH_ENABLED", "true").lower() in ( + "true", + "1", + "yes", + ) if api_oauth_enabled: # If API OAuth is enabled, we don't need to start the local OAuth server @@ -45,10 +51,13 @@ class OAuthCog(commands.Cog): await oauth_server.start_server(host, port) print(f"OAuth callback server running at http://{host}:{port}") - async def check_token_availability(self, user_id: str, channel_id: int, max_attempts: int = 15, delay: int = 3): + async def check_token_availability( + self, user_id: str, channel_id: int, max_attempts: int = 15, delay: int = 3 + ): """Check if a token is available for the user after API OAuth flow.""" # Import the OAuth module import sys + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import discord_oauth @@ -62,7 +71,9 @@ class OAuthCog(commands.Cog): for attempt in range(max_attempts): # Wait for a bit await asyncio.sleep(delay) - print(f"Checking token availability for user {user_id}, attempt {attempt+1}/{max_attempts}") + print( + f"Checking token availability for user {user_id}, attempt {attempt+1}/{max_attempts}" + ) # Try to get the token try: @@ -70,12 +81,16 @@ class OAuthCog(commands.Cog): token = await discord_oauth.get_token(user_id) if token: # Token is available locally, send a success message - await channel.send(f"<@{user_id}> ✅ Authentication successful! You can now use the API.") + await channel.send( + f"<@{user_id}> ✅ Authentication successful! You can now use the API." + ) return # If not available locally, try to get it from the API service if discord_oauth.API_OAUTH_ENABLED: - print(f"Token not found locally, checking API service for user {user_id}") + print( + f"Token not found locally, checking API service for user {user_id}" + ) try: # Make a direct API call to check if the token exists in the API service async with aiohttp.ClientSession() as session: @@ -90,13 +105,19 @@ class OAuthCog(commands.Cog): if data.get("authenticated", False): # Try to retrieve the token from the API service - token_url = f"{discord_oauth.API_URL}/token/{user_id}" + token_url = ( + f"{discord_oauth.API_URL}/token/{user_id}" + ) async with session.get(token_url) as token_resp: if token_resp.status == 200: token_data = await token_resp.json() # Save the token locally - discord_oauth.save_token(user_id, token_data) - await channel.send(f"<@{user_id}> ✅ Authentication successful! You can now use the API.") + discord_oauth.save_token( + user_id, token_data + ) + await channel.send( + f"<@{user_id}> ✅ Authentication successful! You can now use the API." + ) return except Exception as e: print(f"Error checking auth status with API service: {e}") @@ -104,7 +125,9 @@ class OAuthCog(commands.Cog): print(f"Error checking token availability: {e}") # If we get here, the token is not available after max_attempts - await channel.send(f"<@{user_id}> ⚠️ Authentication may have failed. Please try again or check with the bot owner.") + await channel.send( + f"<@{user_id}> ⚠️ Authentication may have failed. Please try again or check with the bot owner." + ) async def auth_callback(self, user_id: str, user_info: Dict[str, Any]): """Callback for successful authentication.""" @@ -157,7 +180,11 @@ class OAuthCog(commands.Cog): auth_url = discord_oauth.get_auth_url(state, code_verifier) # Check if API OAuth is enabled - api_oauth_enabled = os.getenv("API_OAUTH_ENABLED", "true").lower() in ("true", "1", "yes") + api_oauth_enabled = os.getenv("API_OAUTH_ENABLED", "true").lower() in ( + "true", + "1", + "yes", + ) if not api_oauth_enabled: # If using local OAuth server, register the state and callback @@ -175,7 +202,7 @@ class OAuthCog(commands.Cog): embed = discord.Embed( title="Discord Authentication", description="Please click the link below to authenticate with Discord.", - color=discord.Color.blue() + color=discord.Color.blue(), ) embed.add_field( name="Instructions", @@ -185,12 +212,12 @@ class OAuthCog(commands.Cog): "3. You will be redirected to a confirmation page\n" "4. Return to Discord after seeing the confirmation" ), - inline=False + inline=False, ) embed.add_field( name="Authentication Link", value=f"[Click here to authenticate]({auth_url})", - inline=False + inline=False, ) # Add information about the redirect URI @@ -199,7 +226,7 @@ class OAuthCog(commands.Cog): embed.add_field( name="Note", value=f"You will be redirected to the API service at {api_url}/auth", - inline=False + inline=False, ) embed.set_footer(text="This link will expire in 10 minutes") @@ -239,7 +266,9 @@ class OAuthCog(commands.Cog): if local_success or api_success: if local_success and api_success: - await ctx.send("✅ Authentication revoked from both local storage and API service.") + await ctx.send( + "✅ Authentication revoked from both local storage and API service." + ) elif local_success: await ctx.send("✅ Authentication revoked from local storage.") else: @@ -306,10 +335,18 @@ class OAuthCog(commands.Cog): # Get user info with the new token try: - access_token = token_data.get("access_token") - user_info = await discord_oauth.get_user_info(access_token) + access_token = token_data.get( + "access_token" + ) + user_info = ( + await discord_oauth.get_user_info( + access_token + ) + ) username = user_info.get("username") - discriminator = user_info.get("discriminator") + discriminator = user_info.get( + "discriminator" + ) await ctx.send( f"✅ You are authenticated as {username}#{discriminator}.\n" @@ -318,7 +355,9 @@ class OAuthCog(commands.Cog): ) return except Exception as e: - print(f"Error getting user info with token from API service: {e}") + print( + f"Error getting user info with token from API service: {e}" + ) await ctx.send( f"✅ You are authenticated according to the API service.\n" f"The token has been retrieved and saved locally." @@ -328,7 +367,9 @@ class OAuthCog(commands.Cog): print(f"Error checking auth status with API service: {e}") # If we get here, the user is not authenticated anywhere - await ctx.send("❌ You are not currently authenticated. Use `!auth` to authenticate.") + await ctx.send( + "❌ You are not currently authenticated. Use `!auth` to authenticate." + ) @commands.command(name="authhelp") async def auth_help_command(self, ctx): @@ -336,34 +377,29 @@ class OAuthCog(commands.Cog): embed = discord.Embed( title="Authentication Help", description="Commands for managing Discord authentication", - color=discord.Color.blue() + color=discord.Color.blue(), ) embed.add_field( name="!auth", value="Authenticate with Discord to allow the bot to access the API on your behalf", - inline=False + inline=False, ) embed.add_field( name="!deauth", value="Revoke the bot's access to your Discord account", - inline=False + inline=False, ) embed.add_field( - name="!authstatus", - value="Check your authentication status", - inline=False + name="!authstatus", value="Check your authentication status", inline=False ) - embed.add_field( - name="!authhelp", - value="Show this help message", - inline=False - ) + embed.add_field(name="!authhelp", value="Show this help message", inline=False) await ctx.send(embed=embed) + async def setup(bot): await bot.add_cog(OAuthCog(bot)) diff --git a/cogs/owner_utils_cog.py b/cogs/owner_utils_cog.py index f04c127..dc73f3f 100644 --- a/cogs/owner_utils_cog.py +++ b/cogs/owner_utils_cog.py @@ -5,141 +5,171 @@ import logging logger = logging.getLogger(__name__) + class OwnerUtilsCog(commands.Cog, name="Owner Utils"): """Owner-only utility commands for bot management.""" - + def __init__(self, bot): self.bot = bot - + def _parse_user_and_message(self, content: str): """ Parse user identifier and message content from command arguments. - + Args: content: The full command content after the command name - + Returns: tuple: (user_id, message_content) or (None, None) if parsing fails """ if not content.strip(): return None, None - + # Split content into parts parts = content.strip().split(None, 1) if len(parts) < 2: return None, None - + user_part, message_content = parts - + # Try to extract user ID from mention format <@123456> or <@!123456> - mention_match = re.match(r'<@!?(\d+)>', user_part) + mention_match = re.match(r"<@!?(\d+)>", user_part) if mention_match: try: user_id = int(mention_match.group(1)) return user_id, message_content except ValueError: return None, None - + # Try to parse as raw user ID try: user_id = int(user_part) return user_id, message_content except ValueError: return None, None - - @commands.command(name="dm", aliases=["send_dm"], help="Send a direct message to a specified user (Owner only)") + + @commands.command( + name="dm", + aliases=["send_dm"], + help="Send a direct message to a specified user (Owner only)", + ) @commands.is_owner() async def dm_command(self, ctx, *, content: str = None): """ Send a direct message to a specified user. - + Usage: !dm @user message content here !dm 123456789012345678 message content here - + Args: content: User mention/ID followed by the message to send """ if not content: - await ctx.reply("❌ **Usage:** `!dm <@user|user_id> `\n" - "**Examples:**\n" - "• `!dm @username Hello there!`\n" - "• `!dm 123456789012345678 Hello there!`") + await ctx.reply( + "❌ **Usage:** `!dm <@user|user_id> `\n" + "**Examples:**\n" + "• `!dm @username Hello there!`\n" + "• `!dm 123456789012345678 Hello there!`" + ) return - + # Parse user and message content user_id, message_content = self._parse_user_and_message(content) - + if user_id is None or not message_content: - await ctx.reply("❌ **Invalid format.** Please provide a valid user mention or ID followed by a message.\n" - "**Usage:** `!dm <@user|user_id> `") + await ctx.reply( + "❌ **Invalid format.** Please provide a valid user mention or ID followed by a message.\n" + "**Usage:** `!dm <@user|user_id> `" + ) return - + # Validate message content length if len(message_content) > 2000: - await ctx.reply("❌ **Message too long.** Discord messages must be 2000 characters or fewer.\n" - f"Your message is {len(message_content)} characters.") + await ctx.reply( + "❌ **Message too long.** Discord messages must be 2000 characters or fewer.\n" + f"Your message is {len(message_content)} characters." + ) return - + try: # Fetch the target user target_user = self.bot.get_user(user_id) if not target_user: target_user = await self.bot.fetch_user(user_id) - + if not target_user: - await ctx.reply(f"❌ **User not found.** Could not find a user with ID `{user_id}`.") + await ctx.reply( + f"❌ **User not found.** Could not find a user with ID `{user_id}`." + ) return - + # Attempt to send the DM try: await target_user.send(message_content) - + # Send confirmation to command invoker embed = discord.Embed( title="✅ DM Sent Successfully", color=discord.Color.green(), - timestamp=discord.utils.utcnow() + timestamp=discord.utils.utcnow(), ) embed.add_field( name="Recipient", value=f"{target_user.mention} (`{target_user.name}#{target_user.discriminator}`)", - inline=False + inline=False, ) embed.add_field( name="Message Preview", - value=message_content[:100] + ("..." if len(message_content) > 100 else ""), - inline=False + value=message_content[:100] + + ("..." if len(message_content) > 100 else ""), + inline=False, ) embed.set_footer(text=f"User ID: {user_id}") - + await ctx.reply(embed=embed) - logger.info(f"DM sent successfully from {ctx.author} to {target_user} (ID: {user_id})") - + logger.info( + f"DM sent successfully from {ctx.author} to {target_user} (ID: {user_id})" + ) + except discord.Forbidden: - await ctx.reply(f"❌ **Cannot send DM to {target_user.mention}.**\n" - "The user likely has DMs disabled or has blocked the bot.") - logger.warning(f"Failed to send DM to {target_user} (ID: {user_id}) - Forbidden (DMs disabled or blocked)") - + await ctx.reply( + f"❌ **Cannot send DM to {target_user.mention}.**\n" + "The user likely has DMs disabled or has blocked the bot." + ) + logger.warning( + f"Failed to send DM to {target_user} (ID: {user_id}) - Forbidden (DMs disabled or blocked)" + ) + except discord.HTTPException as e: - await ctx.reply(f"❌ **Failed to send DM due to Discord API error.**\n" - f"Error: {str(e)}") - logger.error(f"HTTPException when sending DM to {target_user} (ID: {user_id}): {e}") - + await ctx.reply( + f"❌ **Failed to send DM due to Discord API error.**\n" + f"Error: {str(e)}" + ) + logger.error( + f"HTTPException when sending DM to {target_user} (ID: {user_id}): {e}" + ) + except discord.NotFound: - await ctx.reply(f"❌ **User not found.** No user exists with ID `{user_id}`.") + await ctx.reply( + f"❌ **User not found.** No user exists with ID `{user_id}`." + ) logger.warning(f"Attempted to send DM to non-existent user ID: {user_id}") - + except discord.HTTPException as e: - await ctx.reply(f"❌ **Failed to fetch user due to Discord API error.**\n" - f"Error: {str(e)}") + await ctx.reply( + f"❌ **Failed to fetch user due to Discord API error.**\n" + f"Error: {str(e)}" + ) logger.error(f"HTTPException when fetching user {user_id}: {e}") - + except Exception as e: - await ctx.reply(f"❌ **An unexpected error occurred.**\n" - f"Error: {str(e)}") + await ctx.reply( + f"❌ **An unexpected error occurred.**\n" f"Error: {str(e)}" + ) logger.error(f"Unexpected error in dm_command: {e}", exc_info=True) + async def setup(bot): """Setup function to load the cog.""" try: diff --git a/cogs/owoify_cog.py b/cogs/owoify_cog.py index f9afbdc..f74f8f5 100644 --- a/cogs/owoify_cog.py +++ b/cogs/owoify_cog.py @@ -9,61 +9,111 @@ import aiohttp # In-memory conversation history for owo AI (keyed by channel id) _owo_conversations = {} + def _owoify_text(text: str) -> str: """Improved owoification with more rules and randomness.""" # Basic substitutions - text = re.sub(r'[rl]', 'w', text) - text = re.sub(r'[RL]', 'W', text) - text = re.sub(r'n([aeiou])', r'ny\1', text) - text = re.sub(r'N([aeiou])', r'Ny\1', text) - text = re.sub(r'N([AEIOU])', r'NY\1', text) - text = re.sub(r'ove', 'uv', text) - text = re.sub(r'OVE', 'UV', text) + text = re.sub(r"[rl]", "w", text) + text = re.sub(r"[RL]", "W", text) + text = re.sub(r"n([aeiou])", r"ny\1", text) + text = re.sub(r"N([aeiou])", r"Ny\1", text) + text = re.sub(r"N([AEIOU])", r"NY\1", text) + text = re.sub(r"ove", "uv", text) + text = re.sub(r"OVE", "UV", text) # Extra substitutions - text = re.sub(r'\bth', lambda m: 'd' if random.random() < 0.5 else 'f', text, flags=re.IGNORECASE) - text = re.sub(r'\bthe\b', 'da', text, flags=re.IGNORECASE) - text = re.sub(r'\bthat\b', 'dat', text, flags=re.IGNORECASE) - text = re.sub(r'\bthis\b', 'dis', text, flags=re.IGNORECASE) - text = re.sub(r'\bthose\b', 'dose', text, flags=re.IGNORECASE) - text = re.sub(r'\bthere\b', 'dere', text, flags=re.IGNORECASE) - text = re.sub(r'\bhere\b', 'here', text, flags=re.IGNORECASE) # Intentionally no change, for variety - text = re.sub(r'\bwhat\b', 'whut', text, flags=re.IGNORECASE) - text = re.sub(r'\bwhen\b', 'wen', text, flags=re.IGNORECASE) - text = re.sub(r'\bwhere\b', 'whewe', text, flags=re.IGNORECASE) - text = re.sub(r'\bwhy\b', 'wai', text, flags=re.IGNORECASE) - text = re.sub(r'\bhow\b', 'hau', text, flags=re.IGNORECASE) - text = re.sub(r'\bno\b', 'nu', text, flags=re.IGNORECASE) - text = re.sub(r'\bhas\b', 'haz', text, flags=re.IGNORECASE) - text = re.sub(r'\bhave\b', 'haz', text, flags=re.IGNORECASE) - text = re.sub(r'\byou\b', lambda m: 'u' if random.random() < 0.5 else 'yu', text, flags=re.IGNORECASE) - text = re.sub(r'\byour\b', 'ur', text, flags=re.IGNORECASE) - text = re.sub(r'tion\b', 'shun', text, flags=re.IGNORECASE) - text = re.sub(r'ing\b', 'in', text, flags=re.IGNORECASE) + text = re.sub( + r"\bth", + lambda m: "d" if random.random() < 0.5 else "f", + text, + flags=re.IGNORECASE, + ) + text = re.sub(r"\bthe\b", "da", text, flags=re.IGNORECASE) + text = re.sub(r"\bthat\b", "dat", text, flags=re.IGNORECASE) + text = re.sub(r"\bthis\b", "dis", text, flags=re.IGNORECASE) + text = re.sub(r"\bthose\b", "dose", text, flags=re.IGNORECASE) + text = re.sub(r"\bthere\b", "dere", text, flags=re.IGNORECASE) + text = re.sub( + r"\bhere\b", "here", text, flags=re.IGNORECASE + ) # Intentionally no change, for variety + text = re.sub(r"\bwhat\b", "whut", text, flags=re.IGNORECASE) + text = re.sub(r"\bwhen\b", "wen", text, flags=re.IGNORECASE) + text = re.sub(r"\bwhere\b", "whewe", text, flags=re.IGNORECASE) + text = re.sub(r"\bwhy\b", "wai", text, flags=re.IGNORECASE) + text = re.sub(r"\bhow\b", "hau", text, flags=re.IGNORECASE) + text = re.sub(r"\bno\b", "nu", text, flags=re.IGNORECASE) + text = re.sub(r"\bhas\b", "haz", text, flags=re.IGNORECASE) + text = re.sub(r"\bhave\b", "haz", text, flags=re.IGNORECASE) + text = re.sub( + r"\byou\b", + lambda m: "u" if random.random() < 0.5 else "yu", + text, + flags=re.IGNORECASE, + ) + text = re.sub(r"\byour\b", "ur", text, flags=re.IGNORECASE) + text = re.sub(r"tion\b", "shun", text, flags=re.IGNORECASE) + text = re.sub(r"ing\b", "in", text, flags=re.IGNORECASE) # Playful punctuation - text = re.sub(r'!', lambda m: random.choice(['!!1!', '! UwU', '! owo', '!! >w<', '! >//<', '!!?!']), text) - text = re.sub(r'\?', lambda m: random.choice(['?? OwO', '? uwu', '?']), text) - text = re.sub(r'\.', lambda m: random.choice(['~', '.', ' ^w^', ' o.o', ' ._.']), text) + text = re.sub( + r"!", + lambda m: random.choice(["!!1!", "! UwU", "! owo", "!! >w<", "! >//<", "!!?!"]), + text, + ) + text = re.sub(r"\?", lambda m: random.choice(["?? OwO", "? uwu", "?"]), text) + text = re.sub( + r"\.", lambda m: random.choice(["~", ".", " ^w^", " o.o", " ._."]), text + ) + # Stutter (probabilistic, only for words with at least 2 letters) def stutter_word(match): word = match.group(0) - if len(word) > 2 and random.random() < 0.33 and word[0].isalpha(): # Increased probability + if ( + len(word) > 2 and random.random() < 0.33 and word[0].isalpha() + ): # Increased probability return f"{word[0]}-{word}" return word - text = re.sub(r'\b\w+\b', stutter_word, text) + + text = re.sub(r"\b\w+\b", stutter_word, text) # Random interjection insertion (after commas or randomly) - interjections = [" owo", " uwu", " >w<", " ^w^", " OwO", " UwU", " >.<", " XD", " nyaa~", ":3", "(^///^)", "(ᵘʷᵘ)", "(・`ω´・)", ";;w;;", " teehee", " hehe", " x3", " rawr", "*nuzzles*", "*pounces*"] - parts = re.split(r'([,])', text) + interjections = [ + " owo", + " uwu", + " >w<", + " ^w^", + " OwO", + " UwU", + " >.<", + " XD", + " nyaa~", + ":3", + "(^///^)", + "(ᵘʷᵘ)", + "(・`ω´・)", + ";;w;;", + " teehee", + " hehe", + " x3", + " rawr", + "*nuzzles*", + "*pounces*", + ] + parts = re.split(r"([,])", text) for i in range(len(parts)): - if parts[i] == ',' or (random.random() < 0.15 and parts[i].strip()): # Increased probability + if parts[i] == "," or ( + random.random() < 0.15 and parts[i].strip() + ): # Increased probability parts[i] += random.choice(interjections) - text = ''.join(parts) + text = "".join(parts) # Suffix text += random.choice(interjections) return text + async def _owoify_text_ai(text: str) -> str: """Owoify text using AI via OpenRouter (google/gemini-2.0-flash-exp:free).""" - return await _owoify_text_ai_with_messages([{"role": "user", "content": text}], system_mode="transform") + return await _owoify_text_ai_with_messages( + [{"role": "user", "content": text}], system_mode="transform" + ) + async def _owoify_text_ai_with_messages(messages, system_mode="transform"): """ @@ -74,10 +124,7 @@ async def _owoify_text_ai_with_messages(messages, system_mode="transform"): if not api_key: raise RuntimeError("AI_API_KEY environment variable not set.") url = "https://openrouter.ai/api/v1/chat/completions" - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json" - } + headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} if system_mode == "transform": system_prompt = ( "You are a text transformer. Your ONLY job is to convert the user's input into an uwu/owo style of speech. " @@ -94,7 +141,7 @@ async def _owoify_text_ai_with_messages(messages, system_mode="transform"): ) payload = { "model": "deepseek/deepseek-chat-v3-0324:free", - "messages": [{"role": "system", "content": system_prompt}] + messages + "messages": [{"role": "system", "content": system_prompt}] + messages, } async with aiohttp.ClientSession() as session: async with session.post(url, headers=headers, json=payload) as resp: @@ -103,7 +150,10 @@ async def _owoify_text_ai_with_messages(messages, system_mode="transform"): return data["choices"][0]["message"]["content"] else: text = await resp.text() - raise RuntimeError(f"OpenRouter API returned non-JSON response (status {resp.status}): {text[:500]}") + raise RuntimeError( + f"OpenRouter API returned non-JSON response (status {resp.status}): {text[:500]}" + ) + class OwoifyCog(commands.Cog): def __init__(self, bot: commands.Bot): @@ -111,74 +161,107 @@ class OwoifyCog(commands.Cog): @app_commands.command(name="owoify", description="Owoifies your message!") @app_commands.describe(message_to_owoify="The message to owoify") - async def owoify_slash_command(self, interaction: discord.Interaction, message_to_owoify: str): + async def owoify_slash_command( + self, interaction: discord.Interaction, message_to_owoify: str + ): """Owoifies the provided message via a slash command.""" if not message_to_owoify.strip(): - await interaction.response.send_message("You nyeed to pwovide some text to owoify! >w<", ephemeral=True) + await interaction.response.send_message( + "You nyeed to pwovide some text to owoify! >w<", ephemeral=True + ) return owo_text = _owoify_text(message_to_owoify) await interaction.response.send_message(owo_text) @app_commands.command(name="owoify_ai", description="Owoify your message with AI!") @app_commands.describe(message_to_owoify="The message to owoify using AI") - async def owoify_ai_slash_command(self, interaction: discord.Interaction, message_to_owoify: str): + async def owoify_ai_slash_command( + self, interaction: discord.Interaction, message_to_owoify: str + ): """Owoifies the provided message via the OpenRouter AI.""" if not message_to_owoify.strip(): - await interaction.response.send_message("You nyeed to pwovide some text to owoify! >w<", ephemeral=True) + await interaction.response.send_message( + "You nyeed to pwovide some text to owoify! >w<", ephemeral=True + ) return try: owo_text = await _owoify_text_ai(message_to_owoify) await interaction.response.send_message(owo_text) except Exception as e: - await interaction.response.send_message(f"AI owoification failed: {e} >w<", ephemeral=True) + await interaction.response.send_message( + f"AI owoification failed: {e} >w<", ephemeral=True + ) + # Context menu command must be defined at module level @app_commands.context_menu(name="Owoify Message") -async def owoify_context_menu(interaction: discord.Interaction, message: discord.Message): +async def owoify_context_menu( + interaction: discord.Interaction, message: discord.Message +): """Owoifies the content of the selected message and replies.""" if not message.content: - await interaction.response.send_message("The sewected message has no text content to owoify! >.<", ephemeral=True) + await interaction.response.send_message( + "The sewected message has no text content to owoify! >.<", ephemeral=True + ) return original_content = message.content owo_text = _owoify_text(original_content) try: await message.reply(owo_text) - await interaction.response.send_message("Message owoified and wepwied! uwu", ephemeral=True) + await interaction.response.send_message( + "Message owoified and wepwied! uwu", ephemeral=True + ) except discord.Forbidden: await interaction.response.send_message( f"I couwdn't wepwy to the message (nyi Pwermissions? owo).\n" f"But hewe's the owoified text fow you: {owo_text}", - ephemeral=True + ephemeral=True, ) except discord.HTTPException as e: - await interaction.response.send_message(f"Oopsie! A tiny ewwow occuwwed: {e} >w<", ephemeral=True) + await interaction.response.send_message( + f"Oopsie! A tiny ewwow occuwwed: {e} >w<", ephemeral=True + ) + @app_commands.context_menu(name="Owoify Message (AI)") -async def owoify_context_menu_ai(interaction: discord.Interaction, message: discord.Message): +async def owoify_context_menu_ai( + interaction: discord.Interaction, message: discord.Message +): """Owoifies the content of the selected message using AI and replies.""" if not message.content: - await interaction.response.send_message("The sewected message has no text content to owoify! >.<", ephemeral=True) + await interaction.response.send_message( + "The sewected message has no text content to owoify! >.<", ephemeral=True + ) return original_content = message.content try: await interaction.response.defer(ephemeral=True) owo_text = await _owoify_text_ai(original_content) await message.reply(owo_text) - await interaction.followup.send("Message AI-owoified and wepwied! uwu", ephemeral=True) + await interaction.followup.send( + "Message AI-owoified and wepwied! uwu", ephemeral=True + ) except discord.Forbidden: await interaction.followup.send( f"I couwdn't wepwy to the message (nyi Pwermissions? owo).\n" f"But hewe's the AI owoified text fow you: {owo_text}", - ephemeral=True + ephemeral=True, ) except Exception as e: - await interaction.followup.send(f"AI owoification failed: {e} >w<", ephemeral=True) + await interaction.followup.send( + f"AI owoification failed: {e} >w<", ephemeral=True + ) + @app_commands.context_menu(name="Owo AI Reply") -async def owoify_context_menu_ai_reply(interaction: discord.Interaction, message: discord.Message): +async def owoify_context_menu_ai_reply( + interaction: discord.Interaction, message: discord.Message +): """Replies to the selected message as an owo AI.""" if not message.content: - await interaction.response.send_message("The sewected message has no text content to reply to! >.<", ephemeral=True) + await interaction.response.send_message( + "The sewected message has no text content to reply to! >.<", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) convo_key = message.channel.id @@ -193,6 +276,7 @@ async def owoify_context_menu_ai_reply(interaction: discord.Interaction, message except Exception as e: await interaction.followup.send(f"AI owo reply failed: {e} >w<", ephemeral=True) + async def setup(bot: commands.Bot): cog = OwoifyCog(bot) await bot.add_cog(cog) diff --git a/cogs/ping_cog.py b/cogs/ping_cog.py index 3b97507..68bee61 100644 --- a/cogs/ping_cog.py +++ b/cogs/ping_cog.py @@ -2,6 +2,7 @@ import discord from discord.ext import commands from discord import app_commands + class PingCog(commands.Cog, name="Ping"): """Cog for ping-related commands""" @@ -11,7 +12,7 @@ class PingCog(commands.Cog, name="Ping"): async def _ping_logic(self): """Core logic for the ping command.""" latency = round(self.bot.latency * 1000) - return f'Pong! ^~^ Response time: {latency}ms' + return f"Pong! ^~^ Response time: {latency}ms" # --- Prefix Command (for backward compatibility) --- @commands.command(name="ping") @@ -27,5 +28,6 @@ class PingCog(commands.Cog, name="Ping"): response = await self._ping_logic() await interaction.response.send_message(response) + async def setup(bot): await bot.add_cog(PingCog(bot)) diff --git a/cogs/profile_cog.py b/cogs/profile_cog.py index cec07f1..16da37a 100644 --- a/cogs/profile_cog.py +++ b/cogs/profile_cog.py @@ -1,17 +1,20 @@ import discord from discord.ext import commands + class ProfileCog(commands.Cog): def __init__(self, bot): self.bot = bot - @commands.command(name='avatar', help='Gets the avatar of a user in various formats.') + @commands.command( + name="avatar", help="Gets the avatar of a user in various formats." + ) async def avatar(self, ctx, member: discord.Member = None): """Gets the avatar of a user in various formats.""" if member is None: member = ctx.author - formats = ['png', 'jpg', 'webp'] + formats = ["png", "jpg", "webp"] embed = discord.Embed(title=f"{member.display_name}'s Avatar") embed.set_image(url=member.avatar.url) @@ -26,5 +29,6 @@ class ProfileCog(commands.Cog): embed.description = description await ctx.send(embed=embed) + async def setup(bot): await bot.add_cog(ProfileCog(bot)) diff --git a/cogs/profile_updater_cog.py b/cogs/profile_updater_cog.py index b197d9f..70f028f 100644 --- a/cogs/profile_updater_cog.py +++ b/cogs/profile_updater_cog.py @@ -5,7 +5,7 @@ import random import os import json import aiohttp -import requests # For bio update +import requests # For bio update import base64 import time from typing import Optional, Dict, Any, List @@ -14,25 +14,32 @@ from typing import Optional, Dict, Any, List from gurt.api import get_internal_ai_json_response from gurt.config import PROFILE_UPDATE_SCHEMA, ROLE_SELECTION_SCHEMA, DEFAULT_MODEL + class ProfileUpdaterCog(commands.Cog): """Cog for automatically updating Gurt's profile elements based on AI decisions.""" def __init__(self, bot: commands.Bot): self.bot = bot self.session: Optional[aiohttp.ClientSession] = None - self.gurt_cog: Optional[commands.Cog] = None # To store GurtCog instance - self.bot_token = os.getenv("DISCORD_TOKEN_GURT") # Need the bot token for bio updates - self.update_interval_hours = 3 # Default to every 3 hours, can be adjusted + self.gurt_cog: Optional[commands.Cog] = None # To store GurtCog instance + self.bot_token = os.getenv( + "DISCORD_TOKEN_GURT" + ) # Need the bot token for bio updates + self.update_interval_hours = 3 # Default to every 3 hours, can be adjusted self.profile_update_task.change_interval(hours=self.update_interval_hours) - self.last_update_time = 0 # Track last update time + self.last_update_time = 0 # Track last update time async def cog_load(self): """Initialize resources when the cog is loaded.""" self.session = aiohttp.ClientSession() # Removed wait_until_ready and gurt_cog retrieval from here if not self.bot_token: - print("WARNING: DISCORD_TOKEN_GURT environment variable not set. Bio updates will fail.") - print(f"ProfileUpdaterCog loaded. Update interval: {self.update_interval_hours} hours.") + print( + "WARNING: DISCORD_TOKEN_GURT environment variable not set. Bio updates will fail." + ) + print( + f"ProfileUpdaterCog loaded. Update interval: {self.update_interval_hours} hours." + ) self.profile_update_task.start() async def cog_unload(self): @@ -42,11 +49,13 @@ class ProfileUpdaterCog(commands.Cog): await self.session.close() print("ProfileUpdaterCog unloaded.") - @tasks.loop(hours=3) # Default interval, adjusted in __init__ + @tasks.loop(hours=3) # Default interval, adjusted in __init__ async def profile_update_task(self): """Periodically considers and potentially updates Gurt's profile.""" if not self.gurt_cog or not self.bot.is_ready(): - print("ProfileUpdaterTask: GurtCog not available or bot not ready. Skipping cycle.") + print( + "ProfileUpdaterTask: GurtCog not available or bot not ready. Skipping cycle." + ) return # Call the reusable update cycle logic @@ -58,40 +67,54 @@ class ProfileUpdaterCog(commands.Cog): await self.bot.wait_until_ready() print("ProfileUpdaterTask: Bot ready, attempting to get GurtCog...") # Retry mechanism to handle potential cog loading race conditions - for attempt in range(5): # Try up to 5 times - self.gurt_cog = self.bot.get_cog('Gurt') + for attempt in range(5): # Try up to 5 times + self.gurt_cog = self.bot.get_cog("Gurt") if self.gurt_cog: - print(f"ProfileUpdaterTask: GurtCog found on attempt {attempt + 1}. Starting loop.") - return # Success + print( + f"ProfileUpdaterTask: GurtCog found on attempt {attempt + 1}. Starting loop." + ) + return # Success # If not found, wait a bit before retrying - wait_time = 2 * (attempt + 1) # Increase wait time slightly each attempt - print(f"ProfileUpdaterTask: GurtCog not found on attempt {attempt + 1}, waiting {wait_time} seconds...") + wait_time = 2 * (attempt + 1) # Increase wait time slightly each attempt + print( + f"ProfileUpdaterTask: GurtCog not found on attempt {attempt + 1}, waiting {wait_time} seconds..." + ) await asyncio.sleep(wait_time) # If loop finishes without finding the cog - print("ERROR: ProfileUpdaterTask could not find GurtCog after multiple attempts. AI features will not work.") + print( + "ERROR: ProfileUpdaterTask could not find GurtCog after multiple attempts. AI features will not work." + ) async def perform_update_cycle(self): """Performs a single profile update check and potential update.""" if not self.gurt_cog or not self.bot.is_ready(): - print("ProfileUpdaterTask: GurtCog not available or bot not ready. Skipping cycle.") + print( + "ProfileUpdaterTask: GurtCog not available or bot not ready. Skipping cycle." + ) return - print(f"ProfileUpdaterTask: Starting update cycle at {time.strftime('%Y-%m-%d %H:%M:%S')}") + print( + f"ProfileUpdaterTask: Starting update cycle at {time.strftime('%Y-%m-%d %H:%M:%S')}" + ) self.last_update_time = time.time() try: # --- 1. Fetch Current State --- current_state = await self._get_current_profile_state() if not current_state: - print("ProfileUpdaterTask: Failed to get current profile state. Skipping cycle.") + print( + "ProfileUpdaterTask: Failed to get current profile state. Skipping cycle." + ) return # --- 2. AI Decision Step --- decision = await self._ask_ai_for_updates(current_state) if not decision or not decision.get("should_update"): - print("ProfileUpdaterTask: AI decided not to update profile this cycle.") + print( + "ProfileUpdaterTask: AI decided not to update profile this cycle." + ) return # --- 3. Conditional Execution --- @@ -123,6 +146,7 @@ class ProfileUpdaterCog(commands.Cog): except Exception as e: print(f"ERROR in perform_update_cycle: {e}") import traceback + traceback.print_exc() async def _get_current_profile_state(self) -> Optional[Dict[str, Any]]: @@ -132,10 +156,10 @@ class ProfileUpdaterCog(commands.Cog): state = { "avatar_url": None, - "avatar_image_data": None, # Base64 encoded image data + "avatar_image_data": None, # Base64 encoded image data "bio": None, - "roles": {}, # guild_id: [role_names] - "activity": None # {"type": str, "text": str} + "roles": {}, # guild_id: [role_names] + "activity": None, # {"type": str, "text": str} } # Avatar @@ -146,56 +170,81 @@ class ProfileUpdaterCog(commands.Cog): async with self.session.get(state["avatar_url"]) as resp: if resp.status == 200: image_bytes = await resp.read() - mime_type = resp.content_type or 'image/png' # Default mime type - state["avatar_image_data"] = f"data:{mime_type};base64,{base64.b64encode(image_bytes).decode('utf-8')}" + mime_type = ( + resp.content_type or "image/png" + ) # Default mime type + state["avatar_image_data"] = ( + f"data:{mime_type};base64,{base64.b64encode(image_bytes).decode('utf-8')}" + ) print("ProfileUpdaterTask: Fetched current avatar image data.") else: - print(f"ProfileUpdaterTask: Failed to download current avatar image (status: {resp.status}).") + print( + f"ProfileUpdaterTask: Failed to download current avatar image (status: {resp.status})." + ) except Exception as e: print(f"ProfileUpdaterTask: Error downloading avatar image: {e}") # Bio (Requires authenticated API call) if self.bot_token: headers = { - 'Authorization': f'Bot {self.bot_token}', - 'User-Agent': 'GurtDiscordBot (https://github.com/Slipstreamm/discordbot, v0.1)' + "Authorization": f"Bot {self.bot_token}", + "User-Agent": "GurtDiscordBot (https://github.com/Slipstreamm/discordbot, v0.1)", } # Try both potential endpoints - for url in ('https://discord.com/api/v9/users/@me', 'https://discord.com/api/v9/users/@me/profile'): + for url in ( + "https://discord.com/api/v9/users/@me", + "https://discord.com/api/v9/users/@me/profile", + ): try: async with self.session.get(url, headers=headers) as resp: if resp.status == 200: data = await resp.json() - state["bio"] = data.get('bio') - if state["bio"] is not None: # Found bio, stop checking endpoints - print(f"ProfileUpdaterTask: Fetched current bio (length: {len(state['bio']) if state['bio'] else 0}).") + state["bio"] = data.get("bio") + if ( + state["bio"] is not None + ): # Found bio, stop checking endpoints + print( + f"ProfileUpdaterTask: Fetched current bio (length: {len(state['bio']) if state['bio'] else 0})." + ) break else: - print(f"ProfileUpdaterTask: Failed to fetch bio from {url} (status: {resp.status}).") + print( + f"ProfileUpdaterTask: Failed to fetch bio from {url} (status: {resp.status})." + ) except Exception as e: print(f"ProfileUpdaterTask: Error fetching bio from {url}: {e}") if state["bio"] is None: - print("ProfileUpdaterTask: Could not fetch current bio.") + print("ProfileUpdaterTask: Could not fetch current bio.") else: print("ProfileUpdaterTask: Cannot fetch bio, BOT_TOKEN not set.") - # Roles and Activity (Per Guild) for guild in self.bot.guilds: member = guild.get_member(self.bot.user.id) if member: # Roles - state["roles"][str(guild.id)] = [role.name for role in member.roles if role.name != "@everyone"] + state["roles"][str(guild.id)] = [ + role.name for role in member.roles if role.name != "@everyone" + ] # Activity (Use the first guild's activity as representative) if not state["activity"] and member.activity: activity_type = member.activity.type activity_text = member.activity.name # Map discord.ActivityType enum to string if needed - activity_type_str = activity_type.name if isinstance(activity_type, discord.ActivityType) else str(activity_type) - state["activity"] = {"type": activity_type_str, "text": activity_text} + activity_type_str = ( + activity_type.name + if isinstance(activity_type, discord.ActivityType) + else str(activity_type) + ) + state["activity"] = { + "type": activity_type_str, + "text": activity_text, + } - print(f"ProfileUpdaterTask: Fetched current roles for {len(state['roles'])} guilds.") + print( + f"ProfileUpdaterTask: Fetched current roles for {len(state['roles'])} guilds." + ) if state["activity"]: print(f"ProfileUpdaterTask: Fetched current activity: {state['activity']}") else: @@ -203,37 +252,57 @@ class ProfileUpdaterCog(commands.Cog): return state - async def _ask_ai_for_updates(self, current_state: Dict[str, Any]) -> Optional[Dict[str, Any]]: + async def _ask_ai_for_updates( + self, current_state: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: """Asks the GurtCog AI if and how to update the profile.""" if not self.gurt_cog: print("ProfileUpdaterTask: GurtCog not found in _ask_ai_for_updates.") return None - if not hasattr(self.gurt_cog, 'memory_manager'): - print("ProfileUpdaterTask: GurtCog has no memory_manager attribute.") - return None + if not hasattr(self.gurt_cog, "memory_manager"): + print("ProfileUpdaterTask: GurtCog has no memory_manager attribute.") + return None # --- Fetch Dynamic Context from GurtCog --- - current_mood = getattr(self.gurt_cog, 'current_mood', 'neutral') + current_mood = getattr(self.gurt_cog, "current_mood", "neutral") personality_traits = {} interests = [] try: - personality_traits = await self.gurt_cog.memory_manager.get_all_personality_traits() - interests = await self.gurt_cog.memory_manager.get_interests( - limit=getattr(self.gurt_cog, 'interest_max_for_prompt', 4), # Use GurtCog's config safely - min_level=getattr(self.gurt_cog, 'interest_min_level_for_prompt', 0.3) # Use GurtCog's config safely + personality_traits = ( + await self.gurt_cog.memory_manager.get_all_personality_traits() + ) + interests = await self.gurt_cog.memory_manager.get_interests( + limit=getattr( + self.gurt_cog, "interest_max_for_prompt", 4 + ), # Use GurtCog's config safely + min_level=getattr( + self.gurt_cog, "interest_min_level_for_prompt", 0.3 + ), # Use GurtCog's config safely + ) + print( + f"ProfileUpdaterTask: Fetched {len(personality_traits)} traits and {len(interests)} interests for prompt." ) - print(f"ProfileUpdaterTask: Fetched {len(personality_traits)} traits and {len(interests)} interests for prompt.") except Exception as e: - print(f"ProfileUpdaterTask: Error fetching traits/interests from memory: {e}") + print( + f"ProfileUpdaterTask: Error fetching traits/interests from memory: {e}" + ) # Format traits and interests for the prompt - traits_str = ", ".join([f"{k}: {v:.2f}" for k, v in personality_traits.items()]) if personality_traits else "Defaults" - interests_str = ", ".join([f"{topic} ({level:.1f})" for topic, level in interests]) if interests else "None" + traits_str = ( + ", ".join([f"{k}: {v:.2f}" for k, v in personality_traits.items()]) + if personality_traits + else "Defaults" + ) + interests_str = ( + ", ".join([f"{topic} ({level:.1f})" for topic, level in interests]) + if interests + else "None" + ) # Prepare current state string for the prompt, safely handling None bio - bio_value = current_state.get('bio') - bio_summary = 'Not set' - if bio_value: # Check if bio_value is not None and not an empty string + bio_value = current_state.get("bio") + bio_summary = "Not set" + if bio_value: # Check if bio_value is not None and not an empty string bio_summary = f"{bio_value[:100]}{'...' if len(bio_value) > 100 else ''}" state_summary = f""" @@ -245,12 +314,12 @@ Current State: """ # Include image data if available image_prompt_part = "" - if current_state.get('avatar_image_data'): - image_prompt_part = "\n(Current avatar image data is provided below)" # Text hint for the AI + if current_state.get("avatar_image_data"): + image_prompt_part = "\n(Current avatar image data is provided below)" # Text hint for the AI # Define the JSON schema for the AI's response content # Use the schema imported from config.py - response_schema_dict = PROFILE_UPDATE_SCHEMA['schema'] + response_schema_dict = PROFILE_UPDATE_SCHEMA["schema"] # json_format_instruction = json.dumps(response_schema_dict, indent=2) # No longer needed for prompt # Define the payload for the response_format parameter - REMOVED for Vertex AI @@ -271,72 +340,110 @@ Your current mood is: {current_mood}. Your current interests include: {interests_str}. Review your current profile state (provided below) and decide if you want to make any changes based on your personality, mood, and interests. Be creative and in-character. -**IMPORTANT: Your *entire* response MUST be a single JSON object matching the required schema, with no other text before or after it.**""" # Simplified instruction +**IMPORTANT: Your *entire* response MUST be a single JSON object matching the required schema, with no other text before or after it.**""" # Simplified instruction prompt_messages = [ - {"role": "system", "content": system_prompt_content}, # Use the updated system prompt - {"role": "user", "content": [ - # Simplified user prompt instruction - {"type": "text", "text": f"{state_summary}{image_prompt_part}\n\nReview your current profile state. Decide if you want to change your avatar, bio, roles, or activity status based on your personality, mood, and interests. If yes, specify the changes in the JSON. If not, set 'should_update' to false.\n\n**CRITICAL: Respond ONLY with a valid JSON object matching the required schema.**"} - ]} + { + "role": "system", + "content": system_prompt_content, + }, # Use the updated system prompt + { + "role": "user", + "content": [ + # Simplified user prompt instruction + { + "type": "text", + "text": f"{state_summary}{image_prompt_part}\n\nReview your current profile state. Decide if you want to change your avatar, bio, roles, or activity status based on your personality, mood, and interests. If yes, specify the changes in the JSON. If not, set 'should_update' to false.\n\n**CRITICAL: Respond ONLY with a valid JSON object matching the required schema.**", + } + ], + }, ] # Add image data if available - if current_state.get('avatar_image_data'): + if current_state.get("avatar_image_data"): try: # Extract mime type and base64 data from the data URI string - data_uri = current_state['avatar_image_data'] - header, encoded = data_uri.split(',', 1) - mime_type = header.split(';')[0].split(':')[1] + data_uri = current_state["avatar_image_data"] + header, encoded = data_uri.split(",", 1) + mime_type = header.split(";")[0].split(":")[1] # Append the image data part to the user message content list - prompt_messages[-1]["content"].append({ - "type": "image_data", # Use a custom type marker for now - "mime_type": mime_type, - "data": encoded # The raw base64 string - }) - print("ProfileUpdaterTask: Added current avatar image data to AI prompt.") + prompt_messages[-1]["content"].append( + { + "type": "image_data", # Use a custom type marker for now + "mime_type": mime_type, + "data": encoded, # The raw base64 string + } + ) + print( + "ProfileUpdaterTask: Added current avatar image data to AI prompt." + ) except Exception as img_err: - print(f"ProfileUpdaterTask: Failed to process/add avatar image data: {img_err}") + print( + f"ProfileUpdaterTask: Failed to process/add avatar image data: {img_err}" + ) # Optionally add a text note about the failure - prompt_messages[-1]["content"].append({ - "type": "text", - "text": "\n(System Note: Failed to include current avatar image in prompt.)" - }) + prompt_messages[-1]["content"].append( + { + "type": "text", + "text": "\n(System Note: Failed to include current avatar image in prompt.)", + } + ) try: # Use the imported get_internal_ai_json_response function result_json = await get_internal_ai_json_response( - cog=self.gurt_cog, # Pass the GurtCog instance + cog=self.gurt_cog, # Pass the GurtCog instance prompt_messages=prompt_messages, task_description="Profile Update Decision", - response_schema_dict=response_schema_dict, # Pass the schema dict - model_name_override=DEFAULT_MODEL, # Use model from config - temperature=0.5, # Keep temperature for some creativity - max_tokens=500 # Adjust max tokens if needed + response_schema_dict=response_schema_dict, # Pass the schema dict + model_name_override=DEFAULT_MODEL, # Use model from config + temperature=0.5, # Keep temperature for some creativity + max_tokens=500, # Adjust max tokens if needed ) if result_json and isinstance(result_json, dict): # Basic validation of the received structure - if "should_update" in result_json and "updates" in result_json and "reasoning" in result_json: - print(f"ProfileUpdaterTask: AI Reasoning: {result_json.get('reasoning', 'N/A')}") # Log the reasoning + if ( + "should_update" in result_json + and "updates" in result_json + and "reasoning" in result_json + ): + print( + f"ProfileUpdaterTask: AI Reasoning: {result_json.get('reasoning', 'N/A')}" + ) # Log the reasoning return result_json else: - print(f"ProfileUpdaterTask: AI response missing required keys (should_update, updates, reasoning). Response: {result_json}") + print( + f"ProfileUpdaterTask: AI response missing required keys (should_update, updates, reasoning). Response: {result_json}" + ) return None else: - print(f"ProfileUpdaterTask: AI response was not a dictionary. Response: {result_json}") - return None + print( + f"ProfileUpdaterTask: AI response was not a dictionary. Response: {result_json}" + ) + return None except Exception as e: - print(f"ProfileUpdaterTask: Error calling AI for profile update decision: {e}") + print( + f"ProfileUpdaterTask: Error calling AI for profile update decision: {e}" + ) import traceback + traceback.print_exc() return None async def _update_avatar(self, search_query: str): """Updates the bot's avatar based on an AI-generated search query.""" - print(f"ProfileUpdaterTask: Attempting to update avatar with query: '{search_query}'") - if not self.gurt_cog or not hasattr(self.gurt_cog, 'web_search') or not self.session: - print("ProfileUpdaterTask: Cannot update avatar, GurtCog or web search tool not available.") + print( + f"ProfileUpdaterTask: Attempting to update avatar with query: '{search_query}'" + ) + if ( + not self.gurt_cog + or not hasattr(self.gurt_cog, "web_search") + or not self.session + ): + print( + "ProfileUpdaterTask: Cannot update avatar, GurtCog or web search tool not available." + ) return try: @@ -344,7 +451,9 @@ Review your current profile state (provided below) and decide if you want to mak search_results_data = await self.gurt_cog.web_search(query=search_query) if search_results_data.get("error"): - print(f"ProfileUpdaterTask: Web search failed: {search_results_data['error']}") + print( + f"ProfileUpdaterTask: Web search failed: {search_results_data['error']}" + ) return image_url = None @@ -353,13 +462,24 @@ Review your current profile state (provided below) and decide if you want to mak for result in results: url = result.get("url") # Basic check for image file extensions or common image hosting domains - if url and any(ext in url.lower() for ext in ['.png', '.jpg', '.jpeg', '.gif', '.webp']) or \ - any(domain in url.lower() for domain in ['imgur.com', 'pinimg.com', 'giphy.com']): + if ( + url + and any( + ext in url.lower() + for ext in [".png", ".jpg", ".jpeg", ".gif", ".webp"] + ) + or any( + domain in url.lower() + for domain in ["imgur.com", "pinimg.com", "giphy.com"] + ) + ): image_url = url break if not image_url: - print("ProfileUpdaterTask: No suitable image URL found in search results.") + print( + "ProfileUpdaterTask: No suitable image URL found in search results." + ) return print(f"ProfileUpdaterTask: Found image URL: {image_url}") @@ -371,33 +491,40 @@ Review your current profile state (provided below) and decide if you want to mak # Check rate limits before editing (simple delay for now) # Discord API limits avatar changes (e.g., 2 per hour?) # A more robust solution would track the last change time. - await asyncio.sleep(5) # Basic delay + await asyncio.sleep(5) # Basic delay await self.bot.user.edit(avatar=image_bytes) print("ProfileUpdaterTask: Avatar updated successfully.") else: - print(f"ProfileUpdaterTask: Failed to download image from {image_url} (status: {resp.status}).") + print( + f"ProfileUpdaterTask: Failed to download image from {image_url} (status: {resp.status})." + ) except discord.errors.HTTPException as e: - print(f"ProfileUpdaterTask: Discord API error updating avatar: {e.status} - {e.text}") + print( + f"ProfileUpdaterTask: Discord API error updating avatar: {e.status} - {e.text}" + ) except Exception as e: print(f"ProfileUpdaterTask: Error updating avatar: {e}") import traceback + traceback.print_exc() async def _update_bio(self, new_bio: str): """Updates the bot's bio using the Discord API.""" print(f"ProfileUpdaterTask: Attempting to update bio to: '{new_bio[:50]}...'") if not self.bot_token or not self.session: - print("ProfileUpdaterTask: Cannot update bio, BOT_TOKEN or session not available.") + print( + "ProfileUpdaterTask: Cannot update bio, BOT_TOKEN or session not available." + ) return headers = { - 'Authorization': f'Bot {self.bot_token}', - 'Content-Type': 'application/json', - 'User-Agent': 'GurtDiscordBot (https://github.com/Slipstreamm/discordbot, v0.1)' + "Authorization": f"Bot {self.bot_token}", + "Content-Type": "application/json", + "User-Agent": "GurtDiscordBot (https://github.com/Slipstreamm/discordbot, v0.1)", } - payload = {'bio': new_bio} - url = 'https://discord.com/api/v9/users/@me' # Primary endpoint + payload = {"bio": new_bio} + url = "https://discord.com/api/v9/users/@me" # Primary endpoint try: # Check rate limits (simple delay for now) @@ -408,27 +535,40 @@ Review your current profile state (provided below) and decide if you want to mak else: # Try fallback endpoint if the first failed with specific errors (e.g., 404) if resp.status == 404: - print(f"ProfileUpdaterTask: PATCH {url} failed (404), trying /profile endpoint...") - url_profile = 'https://discord.com/api/v9/users/@me/profile' - async with self.session.patch(url_profile, headers=headers, json=payload) as resp_profile: - if resp_profile.status == 200: - print("ProfileUpdaterTask: Bio updated successfully via /profile endpoint.") - else: - print(f"ProfileUpdaterTask: Failed to update bio via /profile endpoint (status: {resp_profile.status}). Response: {await resp_profile.text()}") + print( + f"ProfileUpdaterTask: PATCH {url} failed (404), trying /profile endpoint..." + ) + url_profile = "https://discord.com/api/v9/users/@me/profile" + async with self.session.patch( + url_profile, headers=headers, json=payload + ) as resp_profile: + if resp_profile.status == 200: + print( + "ProfileUpdaterTask: Bio updated successfully via /profile endpoint." + ) + else: + print( + f"ProfileUpdaterTask: Failed to update bio via /profile endpoint (status: {resp_profile.status}). Response: {await resp_profile.text()}" + ) else: - print(f"ProfileUpdaterTask: Failed to update bio (status: {resp.status}). Response: {await resp.text()}") + print( + f"ProfileUpdaterTask: Failed to update bio (status: {resp.status}). Response: {await resp.text()}" + ) except Exception as e: print(f"ProfileUpdaterTask: Error updating bio: {e}") import traceback + traceback.print_exc() async def _update_roles(self, role_theme: str): """Updates the bot's roles based on an AI-generated theme.""" - print(f"ProfileUpdaterTask: Attempting to update roles based on theme: '{role_theme}'") + print( + f"ProfileUpdaterTask: Attempting to update roles based on theme: '{role_theme}'" + ) if not self.gurt_cog: - print("ProfileUpdaterTask: Cannot update roles, GurtCog not available.") - return + print("ProfileUpdaterTask: Cannot update roles, GurtCog not available.") + return # This requires iterating through guilds and potentially making another AI call # --- Implementation --- @@ -439,12 +579,18 @@ Review your current profile state (provided below) and decide if you want to mak results = await asyncio.gather(*guild_update_tasks, return_exceptions=True) for i, result in enumerate(results): if isinstance(result, Exception): - print(f"ProfileUpdaterTask: Error updating roles for guild {self.bot.guilds[i].id}: {result}") - elif result: # If the helper returned True (success) - print(f"ProfileUpdaterTask: Successfully updated roles for guild {self.bot.guilds[i].id} based on theme '{role_theme}'.") + print( + f"ProfileUpdaterTask: Error updating roles for guild {self.bot.guilds[i].id}: {result}" + ) + elif result: # If the helper returned True (success) + print( + f"ProfileUpdaterTask: Successfully updated roles for guild {self.bot.guilds[i].id} based on theme '{role_theme}'." + ) # else: No update was needed or possible for this guild - async def _update_roles_for_guild(self, guild: discord.Guild, role_theme: str) -> bool: + async def _update_roles_for_guild( + self, guild: discord.Guild, role_theme: str + ) -> bool: """Helper to update roles for a specific guild.""" member = guild.get_member(self.bot.user.id) if not member: @@ -458,34 +604,49 @@ Review your current profile state (provided below) and decide if you want to mak # Cannot assign roles higher than or equal to bot's top role # Cannot assign managed roles (integrations, bot roles) # Cannot assign @everyone role - if not role.is_integration() and not role.is_bot_managed() and not role.is_default() and role.position < bot_top_role_position: - # Check if bot has manage_roles permission - if member.guild_permissions.manage_roles: - assignable_roles.append(role) - else: - # If no manage_roles perm, can only assign roles lower than bot's top role *if* they are unmanaged - # This check is already covered by the position check and managed role checks above. - # However, without manage_roles, the add/remove calls will fail anyway. - print(f"ProfileUpdaterTask: Bot lacks manage_roles permission in guild {guild.id}. Cannot update roles.") - return False # Cannot proceed without permission + if ( + not role.is_integration() + and not role.is_bot_managed() + and not role.is_default() + and role.position < bot_top_role_position + ): + # Check if bot has manage_roles permission + if member.guild_permissions.manage_roles: + assignable_roles.append(role) + else: + # If no manage_roles perm, can only assign roles lower than bot's top role *if* they are unmanaged + # This check is already covered by the position check and managed role checks above. + # However, without manage_roles, the add/remove calls will fail anyway. + print( + f"ProfileUpdaterTask: Bot lacks manage_roles permission in guild {guild.id}. Cannot update roles." + ) + return False # Cannot proceed without permission if not assignable_roles: print(f"ProfileUpdaterTask: No assignable roles found in guild {guild.id}.") return False assignable_role_names = [role.name for role in assignable_roles] - current_role_names = [role.name for role in member.roles if role.name != "@everyone"] + current_role_names = [ + role.name for role in member.roles if role.name != "@everyone" + ] # Define the JSON schema for the role selection AI response # Use the schema imported from config.py - role_selection_schema_dict = ROLE_SELECTION_SCHEMA['schema'] + role_selection_schema_dict = ROLE_SELECTION_SCHEMA["schema"] # role_selection_format = json.dumps(role_selection_schema_dict, indent=2) # No longer needed for prompt # Prepare prompt for the second AI call role_prompt_messages = [ - {"role": "system", "content": f"You are Gurt. Based on the theme '{role_theme}', select roles to add or remove from the available list for this server. Prioritize adding roles that fit the theme and removing roles that don't or conflict. You can add/remove up to 2 roles total."}, + { + "role": "system", + "content": f"You are Gurt. Based on the theme '{role_theme}', select roles to add or remove from the available list for this server. Prioritize adding roles that fit the theme and removing roles that don't or conflict. You can add/remove up to 2 roles total.", + }, # Simplified user prompt instruction - {"role": "user", "content": f"Available assignable roles: {assignable_role_names}\nYour current roles: {current_role_names}\nTheme: '{role_theme}'\n\nSelect roles to add/remove based on the theme.\n\n**CRITICAL: Respond ONLY with a valid JSON object matching the required schema.**"} + { + "role": "user", + "content": f"Available assignable roles: {assignable_role_names}\nYour current roles: {current_role_names}\nTheme: '{role_theme}'\n\nSelect roles to add/remove based on the theme.\n\n**CRITICAL: Respond ONLY with a valid JSON object matching the required schema.**", + }, ] try: @@ -502,25 +663,31 @@ Review your current profile state (provided below) and decide if you want to mak # Use the imported get_internal_ai_json_response function role_decision = await get_internal_ai_json_response( - cog=self.gurt_cog, # Pass the GurtCog instance + cog=self.gurt_cog, # Pass the GurtCog instance prompt_messages=role_prompt_messages, task_description=f"Role Selection for Guild {guild.id}", - response_schema_dict=role_selection_schema_dict, # Pass the schema dict - model_name_override=DEFAULT_MODEL, # Use model from config - temperature=0.5 # More deterministic for role selection + response_schema_dict=role_selection_schema_dict, # Pass the schema dict + model_name_override=DEFAULT_MODEL, # Use model from config + temperature=0.5, # More deterministic for role selection ) if not role_decision or not isinstance(role_decision, dict): - print(f"ProfileUpdaterTask: Failed to get valid role selection from AI for guild {guild.id}.") + print( + f"ProfileUpdaterTask: Failed to get valid role selection from AI for guild {guild.id}." + ) return False roles_to_add_names = role_decision.get("roles_to_add", []) roles_to_remove_names = role_decision.get("roles_to_remove", []) # Validate AI response - if not isinstance(roles_to_add_names, list) or not isinstance(roles_to_remove_names, list): - print(f"ProfileUpdaterTask: Invalid format for roles_to_add/remove from AI for guild {guild.id}.") - return False + if not isinstance(roles_to_add_names, list) or not isinstance( + roles_to_remove_names, list + ): + print( + f"ProfileUpdaterTask: Invalid format for roles_to_add/remove from AI for guild {guild.id}." + ) + return False # Limit changes roles_to_add_names = roles_to_add_names[:2] @@ -536,44 +703,69 @@ Review your current profile state (provided below) and decide if you want to mak roles_to_remove = [] for name in roles_to_remove_names: - # Can only remove roles the bot currently has + # Can only remove roles the bot currently has role = discord.utils.get(member.roles, name=name) # Ensure it's not the @everyone role or managed roles (already filtered, but double check) - if role and not role.is_default() and not role.is_integration() and not role.is_bot_managed(): + if ( + role + and not role.is_default() + and not role.is_integration() + and not role.is_bot_managed() + ): roles_to_remove.append(role) # Apply changes if any changes_made = False if roles_to_remove: try: - await member.remove_roles(*roles_to_remove, reason=f"ProfileUpdaterCog: Applying theme '{role_theme}'") - print(f"ProfileUpdaterTask: Removed roles {[r.name for r in roles_to_remove]} in guild {guild.id}.") + await member.remove_roles( + *roles_to_remove, + reason=f"ProfileUpdaterCog: Applying theme '{role_theme}'", + ) + print( + f"ProfileUpdaterTask: Removed roles {[r.name for r in roles_to_remove]} in guild {guild.id}." + ) changes_made = True - await asyncio.sleep(1) # Small delay between actions + await asyncio.sleep(1) # Small delay between actions except discord.Forbidden: - print(f"ProfileUpdaterTask: Permission error removing roles in guild {guild.id}.") + print( + f"ProfileUpdaterTask: Permission error removing roles in guild {guild.id}." + ) except discord.HTTPException as e: - print(f"ProfileUpdaterTask: HTTP error removing roles in guild {guild.id}: {e}") + print( + f"ProfileUpdaterTask: HTTP error removing roles in guild {guild.id}: {e}" + ) if roles_to_add: try: - await member.add_roles(*roles_to_add, reason=f"ProfileUpdaterCog: Applying theme '{role_theme}'") - print(f"ProfileUpdaterTask: Added roles {[r.name for r in roles_to_add]} in guild {guild.id}.") + await member.add_roles( + *roles_to_add, + reason=f"ProfileUpdaterCog: Applying theme '{role_theme}'", + ) + print( + f"ProfileUpdaterTask: Added roles {[r.name for r in roles_to_add]} in guild {guild.id}." + ) changes_made = True except discord.Forbidden: - print(f"ProfileUpdaterTask: Permission error adding roles in guild {guild.id}.") + print( + f"ProfileUpdaterTask: Permission error adding roles in guild {guild.id}." + ) except discord.HTTPException as e: - print(f"ProfileUpdaterTask: HTTP error adding roles in guild {guild.id}: {e}") + print( + f"ProfileUpdaterTask: HTTP error adding roles in guild {guild.id}: {e}" + ) - return changes_made # Return True if any change was attempted/successful + return changes_made # Return True if any change was attempted/successful except Exception as e: - print(f"ProfileUpdaterTask: Error during role update for guild {guild.id}: {e}") + print( + f"ProfileUpdaterTask: Error during role update for guild {guild.id}: {e}" + ) import traceback + traceback.print_exc() return False - async def _update_activity(self, activity_info: Dict[str, Optional[str]]): """Updates the bot's activity status.""" activity_type_str = activity_info.get("type") @@ -589,15 +781,20 @@ Review your current profile state (provided below) and decide if you want to mak except Exception as e: print(f"ProfileUpdaterTask: Error clearing activity: {e}") import traceback + traceback.print_exc() return # If only one is None but not both, that's invalid if activity_type_str is None or activity_text is None: - print("ProfileUpdaterTask: Invalid activity info received from AI - one field is null but not both.") + print( + "ProfileUpdaterTask: Invalid activity info received from AI - one field is null but not both." + ) return - print(f"ProfileUpdaterTask: Attempting to set activity to {activity_type_str}: '{activity_text}'") + print( + f"ProfileUpdaterTask: Attempting to set activity to {activity_type_str}: '{activity_text}'" + ) # Map string type to discord.ActivityType enum activity_type_map = { @@ -611,7 +808,9 @@ Review your current profile state (provided below) and decide if you want to mak activity_type = activity_type_map.get(activity_type_str.lower()) if activity_type is None: - print(f"ProfileUpdaterTask: Unknown activity type '{activity_type_str}'. Defaulting to 'playing'.") + print( + f"ProfileUpdaterTask: Unknown activity type '{activity_type_str}'. Defaulting to 'playing'." + ) activity_type = discord.ActivityType.playing activity = discord.Activity(type=activity_type, name=activity_text) @@ -622,6 +821,7 @@ Review your current profile state (provided below) and decide if you want to mak except Exception as e: print(f"ProfileUpdaterTask: Error updating activity: {e}") import traceback + traceback.print_exc() diff --git a/cogs/random_cog.py b/cogs/random_cog.py index 0fc2d8f..c6bd56a 100644 --- a/cogs/random_cog.py +++ b/cogs/random_cog.py @@ -3,42 +3,53 @@ import discord from discord.ext import commands from discord import app_commands import random as random_module -import typing # Need this for Optional +import typing # Need this for Optional # Cache to store uploaded file URLs (local to this cog) file_url_cache = {} + class RandomCog(commands.Cog): def __init__(self, bot): self.bot = bot # Updated _random_logic - async def _random_logic(self, interaction_or_ctx, hidden: bool = False) -> typing.Optional[str]: + async def _random_logic( + self, interaction_or_ctx, hidden: bool = False + ) -> typing.Optional[str]: """Core logic for the random command. Returns an error message string or None if successful.""" # NSFW Check is_nsfw_channel = False channel = interaction_or_ctx.channel if isinstance(channel, discord.TextChannel) and channel.is_nsfw(): is_nsfw_channel = True - elif isinstance(channel, discord.DMChannel): # DMs are considered NSFW for this purpose + elif isinstance( + channel, discord.DMChannel + ): # DMs are considered NSFW for this purpose is_nsfw_channel = True if not is_nsfw_channel: # Return error message directly, ephemeral handled by caller - return 'This command can only be used in age-restricted (NSFW) channels or DMs.' + return "This command can only be used in age-restricted (NSFW) channels or DMs." - directory = os.getenv('UPLOAD_DIRECTORY') + directory = os.getenv("UPLOAD_DIRECTORY") if not directory: - return 'UPLOAD_DIRECTORY is not set in the .env file.' + return "UPLOAD_DIRECTORY is not set in the .env file." if not os.path.isdir(directory): - return 'The specified UPLOAD_DIRECTORY does not exist or is not a directory.' + return ( + "The specified UPLOAD_DIRECTORY does not exist or is not a directory." + ) - files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))] + files = [ + f + for f in os.listdir(directory) + if os.path.isfile(os.path.join(directory, f)) + ] if not files: - return 'The specified directory is empty.' + return "The specified directory is empty." # Attempt to send a random file, handling potential size issues - original_files = list(files) # Copy for checking if all files failed + original_files = list(files) # Copy for checking if all files failed while files: chosen_file_name = random_module.choice(files) file_path = os.path.join(directory, chosen_file_name) @@ -46,25 +57,40 @@ class RandomCog(commands.Cog): # Check cache first if chosen_file_name in file_url_cache: # For interactions, defer if not already done, using the hidden flag - if not isinstance(interaction_or_ctx, commands.Context) and not interaction_or_ctx.response.is_done(): - await interaction_or_ctx.response.defer(ephemeral=hidden) # Defer before sending cached URL + if ( + not isinstance(interaction_or_ctx, commands.Context) + and not interaction_or_ctx.response.is_done() + ): + await interaction_or_ctx.response.defer( + ephemeral=hidden + ) # Defer before sending cached URL # Send cached URL if isinstance(interaction_or_ctx, commands.Context): - await interaction_or_ctx.reply(file_url_cache[chosen_file_name]) # Prefix commands can't be ephemeral + await interaction_or_ctx.reply( + file_url_cache[chosen_file_name] + ) # Prefix commands can't be ephemeral else: - await interaction_or_ctx.followup.send(file_url_cache[chosen_file_name], ephemeral=hidden) - return None # Indicate success + await interaction_or_ctx.followup.send( + file_url_cache[chosen_file_name], ephemeral=hidden + ) + return None # Indicate success try: # Determine how to send the file based on context/interaction if isinstance(interaction_or_ctx, commands.Context): - message = await interaction_or_ctx.reply(file=discord.File(file_path)) # Use reply for context - else: # It's an interaction + message = await interaction_or_ctx.reply( + file=discord.File(file_path) + ) # Use reply for context + else: # It's an interaction # Interactions need followup for files after defer() if not interaction_or_ctx.response.is_done(): - await interaction_or_ctx.response.defer(ephemeral=hidden) # Defer before sending file + await interaction_or_ctx.response.defer( + ephemeral=hidden + ) # Defer before sending file # Send file ephemerally if hidden is True - message = await interaction_or_ctx.followup.send(file=discord.File(file_path), ephemeral=hidden) + message = await interaction_or_ctx.followup.send( + file=discord.File(file_path), ephemeral=hidden + ) # Cache the URL if successfully sent if message and message.attachments: @@ -72,28 +98,30 @@ class RandomCog(commands.Cog): # Success, no further message needed return None else: - # Should not happen if send succeeded, but handle defensively - files.remove(chosen_file_name) - print(f"Warning: File {chosen_file_name} sent but no attachment URL found.") # Log warning - continue + # Should not happen if send succeeded, but handle defensively + files.remove(chosen_file_name) + print( + f"Warning: File {chosen_file_name} sent but no attachment URL found." + ) # Log warning + continue except discord.HTTPException as e: if e.code == 40005: # Request entity too large print(f"File too large: {chosen_file_name}") files.remove(chosen_file_name) - continue # Try another file + continue # Try another file else: print(f"HTTP Error sending file: {e}") # Return error message directly, ephemeral handled by caller - return f'Failed to upload the file due to an HTTP error: {e}' + return f"Failed to upload the file due to an HTTP error: {e}" except Exception as e: print(f"Generic Error sending file: {e}") # Return error message directly, ephemeral handled by caller - return f'An unexpected error occurred while uploading the file: {e}' + return f"An unexpected error occurred while uploading the file: {e}" # If loop finishes without returning/sending, all files were too large # Return error message directly, ephemeral handled by caller - return 'All files in the directory were too large to upload.' + return "All files in the directory were too large to upload." # --- Prefix Command --- @commands.command(name="random") @@ -106,23 +134,35 @@ class RandomCog(commands.Cog): # --- Slash Command --- # Updated signature and logic - @app_commands.command(name="random", description="Upload a random NSFW image from the configured directory") - @app_commands.describe(hidden="Set to True to make the response visible only to you (default: False)") - async def random_slash(self, interaction: discord.Interaction, hidden: bool = False): + @app_commands.command( + name="random", + description="Upload a random NSFW image from the configured directory", + ) + @app_commands.describe( + hidden="Set to True to make the response visible only to you (default: False)" + ) + async def random_slash( + self, interaction: discord.Interaction, hidden: bool = False + ): """Slash command version of random.""" # Pass hidden parameter to logic response = await self._random_logic(interaction, hidden=hidden) # If response is None, the logic already sent the file via followup/deferral - if response is not None: # An error occurred + if response is not None: # An error occurred # Ensure interaction hasn't already been responded to or deferred if not interaction.response.is_done(): # Send error message ephemerally if hidden is True OR if it's the NSFW channel error - ephemeral_error = hidden or response.startswith('This command can only be used') - await interaction.response.send_message(response, ephemeral=ephemeral_error) + ephemeral_error = hidden or response.startswith( + "This command can only be used" + ) + await interaction.response.send_message( + response, ephemeral=ephemeral_error + ) else: # If deferred, use followup. Send ephemerally based on hidden flag. await interaction.followup.send(response, ephemeral=hidden) + async def setup(bot): await bot.add_cog(RandomCog(bot)) diff --git a/cogs/random_strings_cog.py b/cogs/random_strings_cog.py index fe5ca4e..d5cd226 100644 --- a/cogs/random_strings_cog.py +++ b/cogs/random_strings_cog.py @@ -3,81 +3,82 @@ from discord.ext import commands from discord import app_commands import random + class PackGodCog(commands.Cog): def __init__(self, bot): self.bot = bot self.string_list = [ - "google chrome garden gnome", - "flip phone disowned", - "ice cream cone metronome", - "final chrome student loan", - "underground flintstone chicken bone", - "grandma went to the corner store and got her dentures thrown out the door", - "baby face aint got no place tripped on my shoelace", - "fortnite birth night", - "doom zoom room full of gloom", - "sentient bean saw a dream on a trampoline", - "wifi sci-fi alibi from a samurai", - "pickle jar avatar with a VCR", - "garage band on demand ran off with a rubber band", - "dizzy lizzy in a blizzard with a kazoo", - "moonlight gaslight bug bite fight night", - "toothpaste suitcase in a high-speed footrace", - "donut warzone with a saxophone ringtone", - "angsty toaster posted up like a rollercoaster", - "spork fork stork on the New York sidewalk", - "quantum raccoon stole my macaroon at high noon", - "algebra grandma in a panorama wearing pajamas", - "cactus cactus got a TikTok practice", - "eggplant overlord on a hoverboard discord", - "fridge magnet prophet dropped an omelet in the cockpit", - "mystery meat got beat by a spreadsheet", - "lava lamp champ with a tax refund stamp", - "hologram scam on a traffic cam jam", - "pogo stick picnic turned into a cryptic mythic", - "sock puppet summit on a budget with a trumpet", - "noodle crusade in a lemonade braid parade", - "neon platypus doing calculus on a school bus", - "hamster vigilante with a coffee-stained affidavit", - "microwave rave in a medieval cave", - "sidewalk chalk talk got hacked by a squawk", - "yoga mat diplomat in a laundromat", - "banana phone cyclone in a monotone zone", - "jukebox paradox at a paradox detox", - "laundry day melee with a broken bidet", - "emoji samurai with a ramen supply and a laser eye", - "grandpa hologram doing taxes on a banana stand", - "bubble wrap trap", - "waffle iron tyrant on a silent siren diet", - "paperclip spaceship with a midlife crisis playlist", - "marshmallow diplomat moonwalking into a courtroom spat", - "gummy bear heir in an electric chair of despair", - "fax machine dream team with a tambourine scheme", - "soda cannon with a canon", - "pretzel twist anarchist on a solar-powered tryst", - "unicycle oracle at a discount popsicle miracle", - "jousting mouse in a chainmail blouse with a holy spouse", - "ye olde scroll turned into a cinnamon roll at the wizard patrol", - "bard with a debit card locked in a tower of lard", - "court jester investor lost a duel to a molester", - "squire on fire writing poetry to a liar for hire", - "archery mishap caused by a gremlin with a Snapchat app", - "knight with stage fright performing Hamlet in moonlight" + "google chrome garden gnome", + "flip phone disowned", + "ice cream cone metronome", + "final chrome student loan", + "underground flintstone chicken bone", + "grandma went to the corner store and got her dentures thrown out the door", + "baby face aint got no place tripped on my shoelace", + "fortnite birth night", + "doom zoom room full of gloom", + "sentient bean saw a dream on a trampoline", + "wifi sci-fi alibi from a samurai", + "pickle jar avatar with a VCR", + "garage band on demand ran off with a rubber band", + "dizzy lizzy in a blizzard with a kazoo", + "moonlight gaslight bug bite fight night", + "toothpaste suitcase in a high-speed footrace", + "donut warzone with a saxophone ringtone", + "angsty toaster posted up like a rollercoaster", + "spork fork stork on the New York sidewalk", + "quantum raccoon stole my macaroon at high noon", + "algebra grandma in a panorama wearing pajamas", + "cactus cactus got a TikTok practice", + "eggplant overlord on a hoverboard discord", + "fridge magnet prophet dropped an omelet in the cockpit", + "mystery meat got beat by a spreadsheet", + "lava lamp champ with a tax refund stamp", + "hologram scam on a traffic cam jam", + "pogo stick picnic turned into a cryptic mythic", + "sock puppet summit on a budget with a trumpet", + "noodle crusade in a lemonade braid parade", + "neon platypus doing calculus on a school bus", + "hamster vigilante with a coffee-stained affidavit", + "microwave rave in a medieval cave", + "sidewalk chalk talk got hacked by a squawk", + "yoga mat diplomat in a laundromat", + "banana phone cyclone in a monotone zone", + "jukebox paradox at a paradox detox", + "laundry day melee with a broken bidet", + "emoji samurai with a ramen supply and a laser eye", + "grandpa hologram doing taxes on a banana stand", + "bubble wrap trap", + "waffle iron tyrant on a silent siren diet", + "paperclip spaceship with a midlife crisis playlist", + "marshmallow diplomat moonwalking into a courtroom spat", + "gummy bear heir in an electric chair of despair", + "fax machine dream team with a tambourine scheme", + "soda cannon with a canon", + "pretzel twist anarchist on a solar-powered tryst", + "unicycle oracle at a discount popsicle miracle", + "jousting mouse in a chainmail blouse with a holy spouse", + "ye olde scroll turned into a cinnamon roll at the wizard patrol", + "bard with a debit card locked in a tower of lard", + "court jester investor lost a duel to a molester", + "squire on fire writing poetry to a liar for hire", + "archery mishap caused by a gremlin with a Snapchat app", + "knight with stage fright performing Hamlet in moonlight", ] self.start_text = "shut yo" self.end_text = "ahh up" - + async def _packgod_logic(self): """Core logic for the packgod command.""" # Randomly select 3 strings from the list selected_strings = random.sample(self.string_list, 3) - + # Format the message message = f"{self.start_text} " message += ", ".join(selected_strings) message += f" {self.end_text}" - + return message # --- Prefix Command --- @@ -88,11 +89,15 @@ class PackGodCog(commands.Cog): await ctx.reply(response) # --- Slash Command --- - @app_commands.command(name="packgod", description="Send a message with hardcoded text and 3 random strings") + @app_commands.command( + name="packgod", + description="Send a message with hardcoded text and 3 random strings", + ) async def packgod_slash(self, interaction: discord.Interaction): """Slash command version of packgod.""" response = await self._packgod_logic() await interaction.response.send_message(response) + async def setup(bot: commands.Bot): await bot.add_cog(PackGodCog(bot)) diff --git a/cogs/random_timeout_cog.py b/cogs/random_timeout_cog.py index 56a9889..12d7543 100644 --- a/cogs/random_timeout_cog.py +++ b/cogs/random_timeout_cog.py @@ -11,7 +11,10 @@ import os logger = logging.getLogger(__name__) # Define the path for the JSON file to store timeout chance -TIMEOUT_CONFIG_FILE = os.path.join(os.path.dirname(__file__), "../data/timeout_config.json") +TIMEOUT_CONFIG_FILE = os.path.join( + os.path.dirname(__file__), "../data/timeout_config.json" +) + class RandomTimeoutCog(commands.Cog): def __init__(self, bot): @@ -27,7 +30,9 @@ class RandomTimeoutCog(commands.Cog): # Load timeout chance from JSON file self.load_timeout_config() - logger.info(f"RandomTimeoutCog initialized with target user ID: {self.target_user_id} and timeout chance: {self.timeout_chance}") + logger.info( + f"RandomTimeoutCog initialized with target user ID: {self.target_user_id} and timeout chance: {self.timeout_chance}" + ) def load_timeout_config(self): """Load timeout configuration from JSON file""" @@ -47,11 +52,13 @@ class RandomTimeoutCog(commands.Cog): config_data = { "timeout_chance": self.timeout_chance, "target_user_id": self.target_user_id, - "timeout_duration": self.timeout_duration + "timeout_duration": self.timeout_duration, } with open(TIMEOUT_CONFIG_FILE, "w") as f: json.dump(config_data, f, indent=4) - logger.info(f"Saved timeout configuration with chance: {self.timeout_chance}") + logger.info( + f"Saved timeout configuration with chance: {self.timeout_chance}" + ) except Exception as e: logger.error(f"Error saving timeout configuration: {e}") @@ -64,35 +71,40 @@ class RandomTimeoutCog(commands.Cog): title=f"{'⚠️ TIMEOUT TRIGGERED' if was_timed_out else '✅ No Timeout'}", description=f"Message from <@{self.target_user_id}> was processed", color=color, - timestamp=datetime.datetime.now(datetime.timezone.utc) + timestamp=datetime.datetime.now(datetime.timezone.utc), ) # Add user information embed.add_field( name="👤 User Information", value=f"**User:** {message.author.mention}\n**User ID:** {message.author.id}", - inline=False + inline=False, ) # Add roll information embed.add_field( name="🎲 Roll Information", value=f"**Roll:** {roll:.6f}\n**Threshold:** {self.timeout_chance:.6f}\n**Chance:** {self.timeout_chance * 100:.2f}%\n**Result:** {'TIMEOUT' if was_timed_out else 'SAFE'}", - inline=False + inline=False, ) # Add message information embed.add_field( name="💬 Message Information", value=f"**Channel:** {message.channel.mention}\n**Message Link:** [Click Here]({message.jump_url})", - inline=False + inline=False, ) # Set footer - embed.set_footer(text=f"Random Timeout System | {datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')}") + embed.set_footer( + text=f"Random Timeout System | {datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')}" + ) # Set author with user avatar - embed.set_author(name=f"{message.author.name}#{message.author.discriminator}", icon_url=message.author.display_avatar.url) + embed.set_author( + name=f"{message.author.name}#{message.author.discriminator}", + icon_url=message.author.display_avatar.url, + ) return embed @@ -113,25 +125,35 @@ class RandomTimeoutCog(commands.Cog): if roll < self.timeout_chance: try: # Calculate timeout until time (1 minute from now) - timeout_until = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=self.timeout_duration) + timeout_until = datetime.datetime.now( + datetime.timezone.utc + ) + datetime.timedelta(seconds=self.timeout_duration) # Apply the timeout - await message.author.timeout(timeout_until, reason="Random 0.5% chance timeout") + await message.author.timeout( + timeout_until, reason="Random 0.5% chance timeout" + ) was_timed_out = True # Send a message to the channel await message.channel.send( f"🎲 Bad luck! {message.author.mention} rolled a {roll:.4f} and got timed out for 1 minute! (0.5% chance)", - delete_after=10 # Delete after 10 seconds + delete_after=10, # Delete after 10 seconds ) - logger.info(f"User {message.author.id} was randomly timed out for 1 minute") + logger.info( + f"User {message.author.id} was randomly timed out for 1 minute" + ) except discord.Forbidden: - logger.warning(f"Bot doesn't have permission to timeout user {message.author.id}") + logger.warning( + f"Bot doesn't have permission to timeout user {message.author.id}" + ) except discord.HTTPException as e: logger.error(f"Failed to timeout user {message.author.id}: {e}") except Exception as e: - logger.error(f"Unexpected error when timing out user {message.author.id}: {e}") + logger.error( + f"Unexpected error when timing out user {message.author.id}: {e}" + ) # Log the event to the specified channel regardless of timeout result try: @@ -142,7 +164,9 @@ class RandomTimeoutCog(commands.Cog): embed = await self.create_log_embed(message, roll, was_timed_out) await log_channel.send(embed=embed) else: - logger.warning(f"Log channel with ID {self.log_channel_id} not found") + logger.warning( + f"Log channel with ID {self.log_channel_id} not found" + ) except Exception as e: logger.error(f"Error sending log message: {e}") @@ -158,10 +182,14 @@ class RandomTimeoutCog(commands.Cog): # Validate the percentage if not is_owner and (percentage < 0 or percentage > 10): - await ctx.reply(f"❌ Error: Moderators can only set timeout chance between 0% and 10%. Current: {self.timeout_chance * 100:.2f}%") + await ctx.reply( + f"❌ Error: Moderators can only set timeout chance between 0% and 10%. Current: {self.timeout_chance * 100:.2f}%" + ) return elif percentage < 0 or percentage > 100: - await ctx.reply(f"❌ Error: Timeout chance must be between 0% and 100%. Current: {self.timeout_chance * 100:.2f}%") + await ctx.reply( + f"❌ Error: Timeout chance must be between 0% and 100%. Current: {self.timeout_chance * 100:.2f}%" + ) return # Store the old value for logging @@ -178,25 +206,21 @@ class RandomTimeoutCog(commands.Cog): title="Timeout Chance Updated", description=f"The random timeout chance has been updated.", color=discord.Color.blue(), - timestamp=datetime.datetime.now(datetime.timezone.utc) + timestamp=datetime.datetime.now(datetime.timezone.utc), ) embed.add_field( - name="Previous Chance", - value=f"{old_chance * 100:.2f}%", - inline=True + name="Previous Chance", value=f"{old_chance * 100:.2f}%", inline=True ) embed.add_field( - name="New Chance", - value=f"{self.timeout_chance * 100:.2f}%", - inline=True + name="New Chance", value=f"{self.timeout_chance * 100:.2f}%", inline=True ) embed.add_field( name="Updated By", value=f"{ctx.author.mention} {' (Owner)' if is_owner else ' (Moderator)'}", - inline=False + inline=False, ) embed.set_footer(text=f"Random Timeout System | User ID: {self.target_user_id}") @@ -205,7 +229,9 @@ class RandomTimeoutCog(commands.Cog): await ctx.reply(embed=embed) # Log the change - logger.info(f"Timeout chance changed from {old_chance:.4f} to {self.timeout_chance:.4f} by {ctx.author.name} (ID: {ctx.author.id})") + logger.info( + f"Timeout chance changed from {old_chance:.4f} to {self.timeout_chance:.4f} by {ctx.author.name} (ID: {ctx.author.id})" + ) # Also log to the log channel if available try: @@ -219,19 +245,32 @@ class RandomTimeoutCog(commands.Cog): async def set_timeout_chance_error(self, ctx, error): """Error handler for the set_timeout_chance command""" if isinstance(error, commands.MissingPermissions): - await ctx.reply("❌ You need the 'Moderate Members' permission to use this command.") + await ctx.reply( + "❌ You need the 'Moderate Members' permission to use this command." + ) elif isinstance(error, commands.MissingRequiredArgument): - await ctx.reply(f"❌ Please provide a percentage. Example: `!set_timeout_chance 0.5` for 0.5%. Current: {self.timeout_chance * 100:.2f}%") + await ctx.reply( + f"❌ Please provide a percentage. Example: `!set_timeout_chance 0.5` for 0.5%. Current: {self.timeout_chance * 100:.2f}%" + ) elif isinstance(error, commands.BadArgument): - await ctx.reply(f"❌ Please provide a valid number. Example: `!set_timeout_chance 0.5` for 0.5%. Current: {self.timeout_chance * 100:.2f}%") + await ctx.reply( + f"❌ Please provide a valid number. Example: `!set_timeout_chance 0.5` for 0.5%. Current: {self.timeout_chance * 100:.2f}%" + ) else: await ctx.reply(f"❌ An error occurred: {error}") logger.error(f"Error in set_timeout_chance command: {error}") - @app_commands.command(name="set_timeout_chance", description="Set the random timeout chance percentage") - @app_commands.describe(percentage="The percentage chance (0-10% for moderators, 0-100% for owner)") + @app_commands.command( + name="set_timeout_chance", + description="Set the random timeout chance percentage", + ) + @app_commands.describe( + percentage="The percentage chance (0-10% for moderators, 0-100% for owner)" + ) @app_commands.checks.has_permissions(moderate_members=True) - async def set_timeout_chance_slash(self, interaction: discord.Interaction, percentage: float): + async def set_timeout_chance_slash( + self, interaction: discord.Interaction, percentage: float + ): """Slash command version of set_timeout_chance""" # Convert percentage to decimal (e.g., 5% -> 0.05) decimal_chance = percentage / 100.0 @@ -243,13 +282,13 @@ class RandomTimeoutCog(commands.Cog): if not is_owner and (percentage < 0 or percentage > 10): await interaction.response.send_message( f"❌ Error: Moderators can only set timeout chance between 0% and 10%. Current: {self.timeout_chance * 100:.2f}%", - ephemeral=True + ephemeral=True, ) return elif percentage < 0 or percentage > 100: await interaction.response.send_message( f"❌ Error: Timeout chance must be between 0% and 100%. Current: {self.timeout_chance * 100:.2f}%", - ephemeral=True + ephemeral=True, ) return @@ -267,25 +306,21 @@ class RandomTimeoutCog(commands.Cog): title="Timeout Chance Updated", description=f"The random timeout chance has been updated.", color=discord.Color.blue(), - timestamp=datetime.datetime.now(datetime.timezone.utc) + timestamp=datetime.datetime.now(datetime.timezone.utc), ) embed.add_field( - name="Previous Chance", - value=f"{old_chance * 100:.2f}%", - inline=True + name="Previous Chance", value=f"{old_chance * 100:.2f}%", inline=True ) embed.add_field( - name="New Chance", - value=f"{self.timeout_chance * 100:.2f}%", - inline=True + name="New Chance", value=f"{self.timeout_chance * 100:.2f}%", inline=True ) embed.add_field( name="Updated By", value=f"{interaction.user.mention} {' (Owner)' if is_owner else ' (Moderator)'}", - inline=False + inline=False, ) embed.set_footer(text=f"Random Timeout System | User ID: {self.target_user_id}") @@ -294,7 +329,9 @@ class RandomTimeoutCog(commands.Cog): await interaction.response.send_message(embed=embed) # Log the change - logger.info(f"Timeout chance changed from {old_chance:.4f} to {self.timeout_chance:.4f} by {interaction.user.name} (ID: {interaction.user.id})") + logger.info( + f"Timeout chance changed from {old_chance:.4f} to {self.timeout_chance:.4f} by {interaction.user.name} (ID: {interaction.user.id})" + ) # Also log to the log channel if available try: @@ -305,23 +342,25 @@ class RandomTimeoutCog(commands.Cog): logger.error(f"Error sending log message: {e}") @set_timeout_chance_slash.error - async def set_timeout_chance_slash_error(self, interaction: discord.Interaction, error): + async def set_timeout_chance_slash_error( + self, interaction: discord.Interaction, error + ): """Error handler for the set_timeout_chance slash command""" if isinstance(error, app_commands.errors.MissingPermissions): await interaction.response.send_message( "❌ You need the 'Moderate Members' permission to use this command.", - ephemeral=True + ephemeral=True, ) else: await interaction.response.send_message( - f"❌ An error occurred: {error}", - ephemeral=True + f"❌ An error occurred: {error}", ephemeral=True ) logger.error(f"Error in set_timeout_chance slash command: {error}") @commands.Cog.listener() async def on_ready(self): - logger.info(f'{self.__class__.__name__} cog has been loaded.') + logger.info(f"{self.__class__.__name__} cog has been loaded.") + async def setup(bot: commands.Bot): await bot.add_cog(RandomTimeoutCog(bot)) diff --git a/cogs/real_moderation_cog.py b/cogs/real_moderation_cog.py index e6a05b9..f03b3d2 100644 --- a/cogs/real_moderation_cog.py +++ b/cogs/real_moderation_cog.py @@ -7,11 +7,12 @@ from typing import Optional, Union, List # Use absolute import for ModLogCog from cogs.mod_log_cog import ModLogCog -from db import mod_log_db # Import the database functions +from db import mod_log_db # Import the database functions # Configure logging logger = logging.getLogger(__name__) + class ModerationCog(commands.Cog): """Real moderation commands that perform actual moderation actions.""" @@ -20,8 +21,7 @@ class ModerationCog(commands.Cog): # Create the main command group for this cog self.moderate_group = app_commands.Group( - name="moderate", - description="Moderation commands for server management" + name="moderate", description="Moderation commands for server management" ) # Register commands @@ -44,7 +44,7 @@ class ModerationCog(commands.Cog): name="ban", description="Ban a member from the server", callback=self.moderate_ban_callback, - parent=self.moderate_group + parent=self.moderate_group, ) # Define the send_dm parameter directly in the command # Instead of using Parameter class, we use the describe decorator @@ -52,7 +52,7 @@ class ModerationCog(commands.Cog): member="The member to ban", reason="The reason for the ban", delete_days="Number of days of messages to delete (0-7)", - send_dm="Whether to send a DM notification to the user (default: True)" + send_dm="Whether to send a DM notification to the user (default: True)", )(ban_command) self.moderate_group.add_command(ban_command) @@ -62,11 +62,10 @@ class ModerationCog(commands.Cog): name="unban", description="Unban a user from the server", callback=self.moderate_unban_callback, - parent=self.moderate_group + parent=self.moderate_group, ) app_commands.describe( - user_id="The ID of the user to unban", - reason="The reason for the unban" + user_id="The ID of the user to unban", reason="The reason for the unban" )(unban_command) self.moderate_group.add_command(unban_command) @@ -75,11 +74,10 @@ class ModerationCog(commands.Cog): name="kick", description="Kick a member from the server", callback=self.moderate_kick_callback, - parent=self.moderate_group + parent=self.moderate_group, ) app_commands.describe( - member="The member to kick", - reason="The reason for the kick" + member="The member to kick", reason="The reason for the kick" )(kick_command) self.moderate_group.add_command(kick_command) @@ -88,12 +86,12 @@ class ModerationCog(commands.Cog): name="timeout", description="Timeout a member in the server", callback=self.moderate_timeout_callback, - parent=self.moderate_group + parent=self.moderate_group, ) app_commands.describe( member="The member to timeout", duration="The duration of the timeout (e.g., '1d', '2h', '30m', '60s')", - reason="The reason for the timeout" + reason="The reason for the timeout", )(timeout_command) self.moderate_group.add_command(timeout_command) @@ -102,11 +100,11 @@ class ModerationCog(commands.Cog): name="removetimeout", description="Remove a timeout from a member", callback=self.moderate_remove_timeout_callback, - parent=self.moderate_group + parent=self.moderate_group, ) app_commands.describe( member="The member to remove timeout from", - reason="The reason for removing the timeout" + reason="The reason for removing the timeout", )(remove_timeout_command) self.moderate_group.add_command(remove_timeout_command) @@ -115,11 +113,11 @@ class ModerationCog(commands.Cog): name="purge", description="Delete a specified number of messages from a channel", callback=self.moderate_purge_callback, - parent=self.moderate_group + parent=self.moderate_group, ) app_commands.describe( amount="Number of messages to delete (1-100)", - user="Optional: Only delete messages from this user" + user="Optional: Only delete messages from this user", )(purge_command) self.moderate_group.add_command(purge_command) @@ -128,11 +126,10 @@ class ModerationCog(commands.Cog): name="warn", description="Warn a member in the server", callback=self.moderate_warn_callback, - parent=self.moderate_group + parent=self.moderate_group, ) app_commands.describe( - member="The member to warn", - reason="The reason for the warning" + member="The member to warn", reason="The reason for the warning" )(warn_command) self.moderate_group.add_command(warn_command) @@ -141,11 +138,11 @@ class ModerationCog(commands.Cog): name="dmbanned", description="Send a DM to a banned user", callback=self.moderate_dm_banned_callback, - parent=self.moderate_group + parent=self.moderate_group, ) app_commands.describe( user_id="The ID of the banned user to DM", - message="The message to send to the banned user" + message="The message to send to the banned user", )(dm_banned_command) self.moderate_group.add_command(dm_banned_command) @@ -154,11 +151,11 @@ class ModerationCog(commands.Cog): name="infractions", description="View moderation infractions for a user", callback=self.moderate_view_infractions_callback, - parent=self.moderate_group + parent=self.moderate_group, + ) + app_commands.describe(member="The member whose infractions to view")( + view_infractions_command ) - app_commands.describe( - member="The member whose infractions to view" - )(view_infractions_command) self.moderate_group.add_command(view_infractions_command) # --- Remove Infraction Command --- @@ -166,11 +163,11 @@ class ModerationCog(commands.Cog): name="removeinfraction", description="Remove a specific infraction by its case ID", callback=self.moderate_remove_infraction_callback, - parent=self.moderate_group + parent=self.moderate_group, ) app_commands.describe( case_id="The case ID of the infraction to remove", - reason="The reason for removing the infraction" + reason="The reason for removing the infraction", )(remove_infraction_command) self.moderate_group.add_command(remove_infraction_command) @@ -179,11 +176,11 @@ class ModerationCog(commands.Cog): name="clearinfractions", description="Clear all moderation infractions for a user", callback=self.moderate_clear_infractions_callback, - parent=self.moderate_group + parent=self.moderate_group, ) app_commands.describe( member="The member whose infractions to clear", - reason="The reason for clearing all infractions" + reason="The reason for clearing all infractions", )(clear_infractions_command) self.moderate_group.add_command(clear_infractions_command) @@ -195,16 +192,16 @@ class ModerationCog(commands.Cog): try: # Extract the number and unit - amount = int(''.join(filter(str.isdigit, duration_str))) - unit = ''.join(filter(str.isalpha, duration_str)).lower() + amount = int("".join(filter(str.isdigit, duration_str))) + unit = "".join(filter(str.isalpha, duration_str)).lower() - if unit == 'd' or unit == 'day' or unit == 'days': + if unit == "d" or unit == "day" or unit == "days": return datetime.timedelta(days=amount) - elif unit == 'h' or unit == 'hour' or unit == 'hours': + elif unit == "h" or unit == "hour" or unit == "hours": return datetime.timedelta(hours=amount) - elif unit == 'm' or unit == 'min' or unit == 'minute' or unit == 'minutes': + elif unit == "m" or unit == "min" or unit == "minute" or unit == "minutes": return datetime.timedelta(minutes=amount) - elif unit == 's' or unit == 'sec' or unit == 'second' or unit == 'seconds': + elif unit == "s" or unit == "sec" or unit == "second" or unit == "seconds": return datetime.timedelta(seconds=amount) else: return None @@ -213,36 +210,59 @@ class ModerationCog(commands.Cog): # --- Command Callbacks --- - async def moderate_ban_callback(self, interaction: discord.Interaction, member: discord.Member, reason: str = None, delete_days: int = 0, send_dm: bool = True): + async def moderate_ban_callback( + self, + interaction: discord.Interaction, + member: discord.Member, + reason: str = None, + delete_days: int = 0, + send_dm: bool = True, + ): """Ban a member from the server.""" # Check if the user has permission to ban members if not interaction.user.guild_permissions.ban_members: - await interaction.response.send_message("❌ You don't have permission to ban members.", ephemeral=True) + await interaction.response.send_message( + "❌ You don't have permission to ban members.", ephemeral=True + ) return # Check if the bot has permission to ban members if not interaction.guild.me.guild_permissions.ban_members: - await interaction.response.send_message("❌ I don't have permission to ban members.", ephemeral=True) + await interaction.response.send_message( + "❌ I don't have permission to ban members.", ephemeral=True + ) return # Check if the user is trying to ban themselves if member.id == interaction.user.id: - await interaction.response.send_message("❌ You cannot ban yourself.", ephemeral=True) + await interaction.response.send_message( + "❌ You cannot ban yourself.", ephemeral=True + ) return # Check if the user is trying to ban the bot if member.id == self.bot.user.id: - await interaction.response.send_message("❌ I cannot ban myself.", ephemeral=True) + await interaction.response.send_message( + "❌ I cannot ban myself.", ephemeral=True + ) return # Check if the user is trying to ban someone with a higher role - if interaction.user.top_role <= member.top_role and interaction.user.id != interaction.guild.owner_id: - await interaction.response.send_message("❌ You cannot ban someone with a higher or equal role.", ephemeral=True) + if ( + interaction.user.top_role <= member.top_role + and interaction.user.id != interaction.guild.owner_id + ): + await interaction.response.send_message( + "❌ You cannot ban someone with a higher or equal role.", ephemeral=True + ) return # Check if the bot can ban the member (role hierarchy) if interaction.guild.me.top_role <= member.top_role: - await interaction.response.send_message("❌ I cannot ban someone with a higher or equal role than me.", ephemeral=True) + await interaction.response.send_message( + "❌ I cannot ban someone with a higher or equal role than me.", + ephemeral=True, + ) return # Ensure delete_days is within valid range (0-7) @@ -255,11 +275,17 @@ class ModerationCog(commands.Cog): embed = discord.Embed( title="Ban Notice", description=f"You have been banned from **{interaction.guild.name}**", - color=discord.Color.red() + color=discord.Color.red(), + ) + embed.add_field( + name="Reason", value=reason or "No reason provided", inline=False + ) + embed.add_field( + name="Moderator", value=interaction.user.name, inline=False + ) + embed.set_footer( + text=f"Server ID: {interaction.guild.id} • {discord.utils.utcnow().strftime('%Y-%m-%d %H:%M:%S')} UTC" ) - embed.add_field(name="Reason", value=reason or "No reason provided", inline=False) - embed.add_field(name="Moderator", value=interaction.user.name, inline=False) - embed.set_footer(text=f"Server ID: {interaction.guild.id} • {discord.utils.utcnow().strftime('%Y-%m-%d %H:%M:%S')} UTC") await member.send(embed=embed) dm_sent = True @@ -274,10 +300,12 @@ class ModerationCog(commands.Cog): await member.ban(reason=reason, delete_message_days=delete_days) # Log the action - logger.info(f"User {member} (ID: {member.id}) was banned from {interaction.guild.name} (ID: {interaction.guild.id}) by {interaction.user} (ID: {interaction.user.id}). Reason: {reason}") + logger.info( + f"User {member} (ID: {member.id}) was banned from {interaction.guild.name} (ID: {interaction.guild.id}) by {interaction.user} (ID: {interaction.user.id}). Reason: {reason}" + ) # --- Add to Mod Log DB --- - mod_log_cog: ModLogCog = self.bot.get_cog('ModLogCog') + mod_log_cog: ModLogCog = self.bot.get_cog("ModLogCog") if mod_log_cog: await mod_log_cog.log_action( guild=interaction.guild, @@ -286,53 +314,81 @@ class ModerationCog(commands.Cog): action_type="BAN", reason=reason, # Ban duration isn't directly supported here, pass None - duration=None + duration=None, ) # ------------------------- # Send confirmation message with DM status target_text = self._user_display(member) if send_dm: - dm_status = "✅ DM notification sent" if dm_sent else "❌ Could not send DM notification (user may have DMs disabled)" - await interaction.response.send_message(f"🔨 **Banned {target_text}**! Reason: {reason or 'No reason provided'}\n{dm_status}") + dm_status = ( + "✅ DM notification sent" + if dm_sent + else "❌ Could not send DM notification (user may have DMs disabled)" + ) + await interaction.response.send_message( + f"🔨 **Banned {target_text}**! Reason: {reason or 'No reason provided'}\n{dm_status}" + ) else: - await interaction.response.send_message(f"🔨 **Banned {target_text}**! Reason: {reason or 'No reason provided'}\n⚠️ DM notification was disabled") + await interaction.response.send_message( + f"🔨 **Banned {target_text}**! Reason: {reason or 'No reason provided'}\n⚠️ DM notification was disabled" + ) except discord.Forbidden: - await interaction.response.send_message("❌ I don't have permission to ban this member.", ephemeral=True) + await interaction.response.send_message( + "❌ I don't have permission to ban this member.", ephemeral=True + ) except discord.HTTPException as e: - await interaction.response.send_message(f"❌ An error occurred while banning the member: {e}", ephemeral=True) + await interaction.response.send_message( + f"❌ An error occurred while banning the member: {e}", ephemeral=True + ) - async def moderate_unban_callback(self, interaction: discord.Interaction, user_id: str, reason: str = None): + async def moderate_unban_callback( + self, interaction: discord.Interaction, user_id: str, reason: str = None + ): """Unban a user from the server.""" # Check if the user has permission to ban members (which includes unbanning) if not interaction.user.guild_permissions.ban_members: - await interaction.response.send_message("❌ You don't have permission to unban users.", ephemeral=True) + await interaction.response.send_message( + "❌ You don't have permission to unban users.", ephemeral=True + ) return # Check if the bot has permission to ban members (which includes unbanning) if not interaction.guild.me.guild_permissions.ban_members: - await interaction.response.send_message("❌ I don't have permission to unban users.", ephemeral=True) + await interaction.response.send_message( + "❌ I don't have permission to unban users.", ephemeral=True + ) return # Validate user ID try: user_id_int = int(user_id) except ValueError: - await interaction.response.send_message("❌ Invalid user ID. Please provide a valid user ID.", ephemeral=True) + await interaction.response.send_message( + "❌ Invalid user ID. Please provide a valid user ID.", ephemeral=True + ) return # Check if the user is banned try: - ban_entry = await interaction.guild.fetch_ban(discord.Object(id=user_id_int)) + ban_entry = await interaction.guild.fetch_ban( + discord.Object(id=user_id_int) + ) banned_user = ban_entry.user except discord.NotFound: - await interaction.response.send_message("❌ This user is not banned.", ephemeral=True) + await interaction.response.send_message( + "❌ This user is not banned.", ephemeral=True + ) return except discord.Forbidden: - await interaction.response.send_message("❌ I don't have permission to view the ban list.", ephemeral=True) + await interaction.response.send_message( + "❌ I don't have permission to view the ban list.", ephemeral=True + ) return except discord.HTTPException as e: - await interaction.response.send_message(f"❌ An error occurred while checking the ban list: {e}", ephemeral=True) + await interaction.response.send_message( + f"❌ An error occurred while checking the ban list: {e}", ephemeral=True + ) return # Perform the unban @@ -340,58 +396,88 @@ class ModerationCog(commands.Cog): await interaction.guild.unban(banned_user, reason=reason) # Log the action - logger.info(f"User {banned_user} (ID: {banned_user.id}) was unbanned from {interaction.guild.name} (ID: {interaction.guild.id}) by {interaction.user} (ID: {interaction.user.id}). Reason: {reason}") + logger.info( + f"User {banned_user} (ID: {banned_user.id}) was unbanned from {interaction.guild.name} (ID: {interaction.guild.id}) by {interaction.user} (ID: {interaction.user.id}). Reason: {reason}" + ) # --- Add to Mod Log DB --- - mod_log_cog: ModLogCog = self.bot.get_cog('ModLogCog') + mod_log_cog: ModLogCog = self.bot.get_cog("ModLogCog") if mod_log_cog: await mod_log_cog.log_action( guild=interaction.guild, moderator=interaction.user, - target=banned_user, # Use the fetched user object + target=banned_user, # Use the fetched user object action_type="UNBAN", reason=reason, - duration=None + duration=None, ) # ------------------------- # Send confirmation message - await interaction.response.send_message(f"🔓 **Unbanned {self._user_display(banned_user)}**! Reason: {reason or 'No reason provided'}") + await interaction.response.send_message( + f"🔓 **Unbanned {self._user_display(banned_user)}**! Reason: {reason or 'No reason provided'}" + ) except discord.Forbidden: - await interaction.response.send_message("❌ I don't have permission to unban this user.", ephemeral=True) + await interaction.response.send_message( + "❌ I don't have permission to unban this user.", ephemeral=True + ) except discord.HTTPException as e: - await interaction.response.send_message(f"❌ An error occurred while unbanning the user: {e}", ephemeral=True) + await interaction.response.send_message( + f"❌ An error occurred while unbanning the user: {e}", ephemeral=True + ) - async def moderate_kick_callback(self, interaction: discord.Interaction, member: discord.Member, reason: str = None): + async def moderate_kick_callback( + self, + interaction: discord.Interaction, + member: discord.Member, + reason: str = None, + ): """Kick a member from the server.""" # Check if the user has permission to kick members if not interaction.user.guild_permissions.kick_members: - await interaction.response.send_message("❌ You don't have permission to kick members.", ephemeral=True) + await interaction.response.send_message( + "❌ You don't have permission to kick members.", ephemeral=True + ) return # Check if the bot has permission to kick members if not interaction.guild.me.guild_permissions.kick_members: - await interaction.response.send_message("❌ I don't have permission to kick members.", ephemeral=True) + await interaction.response.send_message( + "❌ I don't have permission to kick members.", ephemeral=True + ) return # Check if the user is trying to kick themselves if member.id == interaction.user.id: - await interaction.response.send_message("❌ You cannot kick yourself.", ephemeral=True) + await interaction.response.send_message( + "❌ You cannot kick yourself.", ephemeral=True + ) return # Check if the user is trying to kick the bot if member.id == self.bot.user.id: - await interaction.response.send_message("❌ I cannot kick myself.", ephemeral=True) + await interaction.response.send_message( + "❌ I cannot kick myself.", ephemeral=True + ) return # Check if the user is trying to kick someone with a higher role - if interaction.user.top_role <= member.top_role and interaction.user.id != interaction.guild.owner_id: - await interaction.response.send_message("❌ You cannot kick someone with a higher or equal role.", ephemeral=True) + if ( + interaction.user.top_role <= member.top_role + and interaction.user.id != interaction.guild.owner_id + ): + await interaction.response.send_message( + "❌ You cannot kick someone with a higher or equal role.", + ephemeral=True, + ) return # Check if the bot can kick the member (role hierarchy) if interaction.guild.me.top_role <= member.top_role: - await interaction.response.send_message("❌ I cannot kick someone with a higher or equal role than me.", ephemeral=True) + await interaction.response.send_message( + "❌ I cannot kick someone with a higher or equal role than me.", + ephemeral=True, + ) return # Try to send a DM to the user before kicking them @@ -400,11 +486,15 @@ class ModerationCog(commands.Cog): embed = discord.Embed( title="Kick Notice", description=f"You have been kicked from **{interaction.guild.name}**", - color=discord.Color.orange() + color=discord.Color.orange(), + ) + embed.add_field( + name="Reason", value=reason or "No reason provided", inline=False ) - embed.add_field(name="Reason", value=reason or "No reason provided", inline=False) embed.add_field(name="Moderator", value=interaction.user.name, inline=False) - embed.set_footer(text=f"Server ID: {interaction.guild.id} • {discord.utils.utcnow().strftime('%Y-%m-%d %H:%M:%S')} UTC") + embed.set_footer( + text=f"Server ID: {interaction.guild.id} • {discord.utils.utcnow().strftime('%Y-%m-%d %H:%M:%S')} UTC" + ) await member.send(embed=embed) dm_sent = True @@ -419,10 +509,12 @@ class ModerationCog(commands.Cog): await member.kick(reason=reason) # Log the action - logger.info(f"User {member} (ID: {member.id}) was kicked from {interaction.guild.name} (ID: {interaction.guild.id}) by {interaction.user} (ID: {interaction.user.id}). Reason: {reason}") + logger.info( + f"User {member} (ID: {member.id}) was kicked from {interaction.guild.name} (ID: {interaction.guild.id}) by {interaction.user} (ID: {interaction.user.id}). Reason: {reason}" + ) # --- Add to Mod Log DB --- - mod_log_cog: ModLogCog = self.bot.get_cog('ModLogCog') + mod_log_cog: ModLogCog = self.bot.get_cog("ModLogCog") if mod_log_cog: await mod_log_cog.log_action( guild=interaction.guild, @@ -430,19 +522,35 @@ class ModerationCog(commands.Cog): target=member, action_type="KICK", reason=reason, - duration=None + duration=None, ) # ------------------------- # Send confirmation message with DM status - dm_status = "✅ DM notification sent" if dm_sent else "❌ Could not send DM notification (user may have DMs disabled)" - await interaction.response.send_message(f"👢 **Kicked {self._user_display(member)}**! Reason: {reason or 'No reason provided'}\n{dm_status}") + dm_status = ( + "✅ DM notification sent" + if dm_sent + else "❌ Could not send DM notification (user may have DMs disabled)" + ) + await interaction.response.send_message( + f"👢 **Kicked {self._user_display(member)}**! Reason: {reason or 'No reason provided'}\n{dm_status}" + ) except discord.Forbidden: - await interaction.response.send_message("❌ I don't have permission to kick this member.", ephemeral=True) + await interaction.response.send_message( + "❌ I don't have permission to kick this member.", ephemeral=True + ) except discord.HTTPException as e: - await interaction.response.send_message(f"❌ An error occurred while kicking the member: {e}", ephemeral=True) + await interaction.response.send_message( + f"❌ An error occurred while kicking the member: {e}", ephemeral=True + ) - async def moderate_timeout_callback(self, interaction: discord.Interaction, member: discord.Member, duration: str, reason: str = None): + async def moderate_timeout_callback( + self, + interaction: discord.Interaction, + member: discord.Member, + duration: str, + reason: str = None, + ): """Timeout a member in the server.""" # Defer the response immediately to prevent expiration @@ -455,18 +563,24 @@ class ModerationCog(commands.Cog): # Helper for always using followup after defer async def safe_followup(content=None, *, ephemeral=False, embed=None): try: - await interaction.followup.send(content, ephemeral=ephemeral, embed=embed) + await interaction.followup.send( + content, ephemeral=ephemeral, embed=embed + ) except Exception as e: logger.error(f"Failed to send followup response: {e}") # Check if the user has permission to moderate members if not interaction.user.guild_permissions.moderate_members: - await safe_followup("❌ You don't have permission to timeout members.", ephemeral=True) + await safe_followup( + "❌ You don't have permission to timeout members.", ephemeral=True + ) return # Check if the bot has permission to moderate members if not interaction.guild.me.guild_permissions.moderate_members: - await safe_followup("❌ I don't have permission to timeout members.", ephemeral=True) + await safe_followup( + "❌ I don't have permission to timeout members.", ephemeral=True + ) return # Check if the user is trying to timeout themselves @@ -480,25 +594,39 @@ class ModerationCog(commands.Cog): return # Check if the user is trying to timeout someone with a higher role - if interaction.user.top_role <= member.top_role and interaction.user.id != interaction.guild.owner_id: - await safe_followup("❌ You cannot timeout someone with a higher or equal role.", ephemeral=True) + if ( + interaction.user.top_role <= member.top_role + and interaction.user.id != interaction.guild.owner_id + ): + await safe_followup( + "❌ You cannot timeout someone with a higher or equal role.", + ephemeral=True, + ) return # Check if the bot can timeout the member (role hierarchy) if interaction.guild.me.top_role <= member.top_role: - await safe_followup("❌ I cannot timeout someone with a higher or equal role than me.", ephemeral=True) + await safe_followup( + "❌ I cannot timeout someone with a higher or equal role than me.", + ephemeral=True, + ) return # Parse the duration delta = self._parse_duration(duration) if not delta: - await safe_followup("❌ Invalid duration format. Please use formats like '1d', '2h', '30m', or '60s'.", ephemeral=True) + await safe_followup( + "❌ Invalid duration format. Please use formats like '1d', '2h', '30m', or '60s'.", + ephemeral=True, + ) return # Check if the duration is within Discord's limits (max 28 days) max_timeout = datetime.timedelta(days=28) if delta > max_timeout: - await safe_followup("❌ Timeout duration cannot exceed 28 days.", ephemeral=True) + await safe_followup( + "❌ Timeout duration cannot exceed 28 days.", ephemeral=True + ) return # Calculate the end time @@ -510,13 +638,19 @@ class ModerationCog(commands.Cog): embed = discord.Embed( title="Timeout Notice", description=f"You have been timed out in **{interaction.guild.name}** for {duration}", - color=discord.Color.gold() + color=discord.Color.gold(), + ) + embed.add_field( + name="Reason", value=reason or "No reason provided", inline=False ) - embed.add_field(name="Reason", value=reason or "No reason provided", inline=False) embed.add_field(name="Moderator", value=interaction.user.name, inline=False) embed.add_field(name="Duration", value=duration, inline=False) - embed.add_field(name="Expires", value=f"", inline=False) - embed.set_footer(text=f"Server ID: {interaction.guild.id} • {discord.utils.utcnow().strftime('%Y-%m-%d %H:%M:%S')} UTC") + embed.add_field( + name="Expires", value=f"", inline=False + ) + embed.set_footer( + text=f"Server ID: {interaction.guild.id} • {discord.utils.utcnow().strftime('%Y-%m-%d %H:%M:%S')} UTC" + ) await member.send(embed=embed) dm_sent = True @@ -531,10 +665,12 @@ class ModerationCog(commands.Cog): await member.timeout(until, reason=reason) # Log the action - logger.info(f"User {member} (ID: {member.id}) was timed out in {interaction.guild.name} (ID: {interaction.guild.id}) by {interaction.user} (ID: {interaction.user.id}) for {duration}. Reason: {reason}") + logger.info( + f"User {member} (ID: {member.id}) was timed out in {interaction.guild.name} (ID: {interaction.guild.id}) by {interaction.user} (ID: {interaction.user.id}) for {duration}. Reason: {reason}" + ) # --- Add to Mod Log DB --- - mod_log_cog: ModLogCog = self.bot.get_cog('ModLogCog') + mod_log_cog: ModLogCog = self.bot.get_cog("ModLogCog") if mod_log_cog: await mod_log_cog.log_action( guild=interaction.guild, @@ -542,33 +678,54 @@ class ModerationCog(commands.Cog): target=member, action_type="TIMEOUT", reason=reason, - duration=delta # Pass the timedelta object + duration=delta, # Pass the timedelta object ) # ------------------------- # Send confirmation message with DM status - dm_status = "✅ DM notification sent" if dm_sent else "❌ Could not send DM notification (user may have DMs disabled)" - await safe_followup(f"⏰ **Timed out {self._user_display(member)}** for {duration}! Reason: {reason or 'No reason provided'}\n{dm_status}") + dm_status = ( + "✅ DM notification sent" + if dm_sent + else "❌ Could not send DM notification (user may have DMs disabled)" + ) + await safe_followup( + f"⏰ **Timed out {self._user_display(member)}** for {duration}! Reason: {reason or 'No reason provided'}\n{dm_status}" + ) except discord.Forbidden: - await safe_followup("❌ I don't have permission to timeout this member.", ephemeral=True) + await safe_followup( + "❌ I don't have permission to timeout this member.", ephemeral=True + ) except discord.HTTPException as e: - await safe_followup(f"❌ An error occurred while timing out the member: {e}", ephemeral=True) + await safe_followup( + f"❌ An error occurred while timing out the member: {e}", ephemeral=True + ) - async def moderate_remove_timeout_callback(self, interaction: discord.Interaction, member: discord.Member, reason: str = None): + async def moderate_remove_timeout_callback( + self, + interaction: discord.Interaction, + member: discord.Member, + reason: str = None, + ): """Remove a timeout from a member.""" # Check if the user has permission to moderate members if not interaction.user.guild_permissions.moderate_members: - await interaction.response.send_message("❌ You don't have permission to remove timeouts.", ephemeral=True) + await interaction.response.send_message( + "❌ You don't have permission to remove timeouts.", ephemeral=True + ) return # Check if the bot has permission to moderate members if not interaction.guild.me.guild_permissions.moderate_members: - await interaction.response.send_message("❌ I don't have permission to remove timeouts.", ephemeral=True) + await interaction.response.send_message( + "❌ I don't have permission to remove timeouts.", ephemeral=True + ) return # Check if the member is timed out if not member.timed_out_until: - await interaction.response.send_message("❌ This member is not timed out.", ephemeral=True) + await interaction.response.send_message( + "❌ This member is not timed out.", ephemeral=True + ) return # Try to send a DM to the user about the timeout removal @@ -577,11 +734,15 @@ class ModerationCog(commands.Cog): embed = discord.Embed( title="Timeout Removed", description=f"Your timeout in **{interaction.guild.name}** has been removed", - color=discord.Color.green() + color=discord.Color.green(), + ) + embed.add_field( + name="Reason", value=reason or "No reason provided", inline=False ) - embed.add_field(name="Reason", value=reason or "No reason provided", inline=False) embed.add_field(name="Moderator", value=interaction.user.name, inline=False) - embed.set_footer(text=f"Server ID: {interaction.guild.id} • {discord.utils.utcnow().strftime('%Y-%m-%d %H:%M:%S')} UTC") + embed.set_footer( + text=f"Server ID: {interaction.guild.id} • {discord.utils.utcnow().strftime('%Y-%m-%d %H:%M:%S')} UTC" + ) await member.send(embed=embed) dm_sent = True @@ -589,17 +750,21 @@ class ModerationCog(commands.Cog): # User has DMs closed, ignore pass except Exception as e: - logger.error(f"Error sending timeout removal DM to {member} (ID: {member.id}): {e}") + logger.error( + f"Error sending timeout removal DM to {member} (ID: {member.id}): {e}" + ) # Perform the timeout removal try: await member.timeout(None, reason=reason) # Log the action - logger.info(f"Timeout was removed from user {member} (ID: {member.id}) in {interaction.guild.name} (ID: {interaction.guild.id}) by {interaction.user} (ID: {interaction.user.id}). Reason: {reason}") + logger.info( + f"Timeout was removed from user {member} (ID: {member.id}) in {interaction.guild.name} (ID: {interaction.guild.id}) by {interaction.user} (ID: {interaction.user.id}). Reason: {reason}" + ) # --- Add to Mod Log DB --- - mod_log_cog: ModLogCog = self.bot.get_cog('ModLogCog') + mod_log_cog: ModLogCog = self.bot.get_cog("ModLogCog") if mod_log_cog: await mod_log_cog.log_action( guild=interaction.guild, @@ -607,33 +772,56 @@ class ModerationCog(commands.Cog): target=member, action_type="REMOVE_TIMEOUT", reason=reason, - duration=None + duration=None, ) # ------------------------- # Send confirmation message with DM status - dm_status = "✅ DM notification sent" if dm_sent else "❌ Could not send DM notification (user may have DMs disabled)" - await interaction.response.send_message(f"⏰ **Removed timeout from {self._user_display(member)}**! Reason: {reason or 'No reason provided'}\n{dm_status}") + dm_status = ( + "✅ DM notification sent" + if dm_sent + else "❌ Could not send DM notification (user may have DMs disabled)" + ) + await interaction.response.send_message( + f"⏰ **Removed timeout from {self._user_display(member)}**! Reason: {reason or 'No reason provided'}\n{dm_status}" + ) except discord.Forbidden: - await interaction.response.send_message("❌ I don't have permission to remove the timeout from this member.", ephemeral=True) + await interaction.response.send_message( + "❌ I don't have permission to remove the timeout from this member.", + ephemeral=True, + ) except discord.HTTPException as e: - await interaction.response.send_message(f"❌ An error occurred while removing the timeout: {e}", ephemeral=True) + await interaction.response.send_message( + f"❌ An error occurred while removing the timeout: {e}", ephemeral=True + ) - async def moderate_purge_callback(self, interaction: discord.Interaction, amount: int, user: Optional[discord.Member] = None): + async def moderate_purge_callback( + self, + interaction: discord.Interaction, + amount: int, + user: Optional[discord.Member] = None, + ): """Delete a specified number of messages from a channel.""" # Check if the user has permission to manage messages if not interaction.user.guild_permissions.manage_messages: - await interaction.response.send_message("❌ You don't have permission to purge messages.", ephemeral=True) + await interaction.response.send_message( + "❌ You don't have permission to purge messages.", ephemeral=True + ) return # Check if the bot has permission to manage messages if not interaction.guild.me.guild_permissions.manage_messages: - await interaction.response.send_message("❌ I don't have permission to purge messages.", ephemeral=True) + await interaction.response.send_message( + "❌ I don't have permission to purge messages.", ephemeral=True + ) return # Validate the amount if amount < 1 or amount > 100: - await interaction.response.send_message("❌ You can only purge between 1 and 100 messages at a time.", ephemeral=True) + await interaction.response.send_message( + "❌ You can only purge between 1 and 100 messages at a time.", + ephemeral=True, + ) return # Defer the response since this might take a moment @@ -649,51 +837,81 @@ class ModerationCog(commands.Cog): deleted = await interaction.channel.purge(limit=amount, check=check) # Log the action - logger.info(f"{len(deleted)} messages from user {user} (ID: {user.id}) were purged from channel {interaction.channel.name} (ID: {interaction.channel.id}) in {interaction.guild.name} (ID: {interaction.guild.id}) by {interaction.user} (ID: {interaction.user.id}).") + logger.info( + f"{len(deleted)} messages from user {user} (ID: {user.id}) were purged from channel {interaction.channel.name} (ID: {interaction.channel.id}) in {interaction.guild.name} (ID: {interaction.guild.id}) by {interaction.user} (ID: {interaction.user.id})." + ) # Send confirmation message - await interaction.followup.send(f"🧹 **Purged {len(deleted)} messages** from {self._user_display(user)}!", ephemeral=True) + await interaction.followup.send( + f"🧹 **Purged {len(deleted)} messages** from {self._user_display(user)}!", + ephemeral=True, + ) else: # Delete messages from anyone deleted = await interaction.channel.purge(limit=amount) # Log the action - logger.info(f"{len(deleted)} messages were purged from channel {interaction.channel.name} (ID: {interaction.channel.id}) in {interaction.guild.name} (ID: {interaction.guild.id}) by {interaction.user} (ID: {interaction.user.id}).") + logger.info( + f"{len(deleted)} messages were purged from channel {interaction.channel.name} (ID: {interaction.channel.id}) in {interaction.guild.name} (ID: {interaction.guild.id}) by {interaction.user} (ID: {interaction.user.id})." + ) # Send confirmation message - await interaction.followup.send(f"🧹 **Purged {len(deleted)} messages**!", ephemeral=True) + await interaction.followup.send( + f"🧹 **Purged {len(deleted)} messages**!", ephemeral=True + ) except discord.Forbidden: - await interaction.followup.send("❌ I don't have permission to delete messages in this channel.", ephemeral=True) + await interaction.followup.send( + "❌ I don't have permission to delete messages in this channel.", + ephemeral=True, + ) except discord.HTTPException as e: - await interaction.followup.send(f"❌ An error occurred while purging messages: {e}", ephemeral=True) + await interaction.followup.send( + f"❌ An error occurred while purging messages: {e}", ephemeral=True + ) - async def moderate_warn_callback(self, interaction: discord.Interaction, member: discord.Member, reason: str): + async def moderate_warn_callback( + self, interaction: discord.Interaction, member: discord.Member, reason: str + ): """Warn a member in the server.""" # Check if the user has permission to kick members (using kick permission as a baseline for warning) if not interaction.user.guild_permissions.kick_members: - await interaction.response.send_message("❌ You don't have permission to warn members.", ephemeral=True) + await interaction.response.send_message( + "❌ You don't have permission to warn members.", ephemeral=True + ) return # Check if the user is trying to warn themselves if member.id == interaction.user.id: - await interaction.response.send_message("❌ You cannot warn yourself.", ephemeral=True) + await interaction.response.send_message( + "❌ You cannot warn yourself.", ephemeral=True + ) return # Check if the user is trying to warn the bot if member.id == self.bot.user.id: - await interaction.response.send_message("❌ I cannot warn myself.", ephemeral=True) + await interaction.response.send_message( + "❌ I cannot warn myself.", ephemeral=True + ) return # Check if the user is trying to warn someone with a higher role - if interaction.user.top_role <= member.top_role and interaction.user.id != interaction.guild.owner_id: - await interaction.response.send_message("❌ You cannot warn someone with a higher or equal role.", ephemeral=True) + if ( + interaction.user.top_role <= member.top_role + and interaction.user.id != interaction.guild.owner_id + ): + await interaction.response.send_message( + "❌ You cannot warn someone with a higher or equal role.", + ephemeral=True, + ) return # Log the warning (using standard logger first) - logger.info(f"User {member} (ID: {member.id}) was warned in {interaction.guild.name} (ID: {interaction.guild.id}) by {interaction.user} (ID: {interaction.user.id}). Reason: {reason}") + logger.info( + f"User {member} (ID: {member.id}) was warned in {interaction.guild.name} (ID: {interaction.guild.id}) by {interaction.user} (ID: {interaction.user.id}). Reason: {reason}" + ) # --- Add to Mod Log DB --- - mod_log_cog: ModLogCog = self.bot.get_cog('ModLogCog') + mod_log_cog: ModLogCog = self.bot.get_cog("ModLogCog") if mod_log_cog: await mod_log_cog.log_action( guild=interaction.guild, @@ -701,23 +919,27 @@ class ModerationCog(commands.Cog): target=member, action_type="WARN", reason=reason, - duration=None + duration=None, ) # ------------------------- # Send warning message in the channel - await interaction.response.send_message(f"⚠️ **{self._user_display(member)} has been warned**! Reason: {reason}") + await interaction.response.send_message( + f"⚠️ **{self._user_display(member)} has been warned**! Reason: {reason}" + ) # Try to DM the user about the warning try: embed = discord.Embed( title="Warning Notice", description=f"You have been warned in **{interaction.guild.name}**", - color=discord.Color.yellow() + color=discord.Color.yellow(), ) embed.add_field(name="Reason", value=reason, inline=False) embed.add_field(name="Moderator", value=interaction.user.name, inline=False) - embed.set_footer(text=f"Server ID: {interaction.guild.id} • {discord.utils.utcnow().strftime('%Y-%m-%d %H:%M:%S')} UTC") + embed.set_footer( + text=f"Server ID: {interaction.guild.id} • {discord.utils.utcnow().strftime('%Y-%m-%d %H:%M:%S')} UTC" + ) await member.send(embed=embed) except discord.Forbidden: @@ -726,32 +948,46 @@ class ModerationCog(commands.Cog): except Exception as e: logger.error(f"Error sending warning DM to {member} (ID: {member.id}): {e}") - async def moderate_dm_banned_callback(self, interaction: discord.Interaction, user_id: str, message: str): + async def moderate_dm_banned_callback( + self, interaction: discord.Interaction, user_id: str, message: str + ): """Send a DM to a banned user.""" # Check if the user has permission to ban members if not interaction.user.guild_permissions.ban_members: - await interaction.response.send_message("❌ You don't have permission to DM banned users.", ephemeral=True) + await interaction.response.send_message( + "❌ You don't have permission to DM banned users.", ephemeral=True + ) return # Validate user ID try: user_id_int = int(user_id) except ValueError: - await interaction.response.send_message("❌ Invalid user ID. Please provide a valid user ID.", ephemeral=True) + await interaction.response.send_message( + "❌ Invalid user ID. Please provide a valid user ID.", ephemeral=True + ) return # Check if the user is banned try: - ban_entry = await interaction.guild.fetch_ban(discord.Object(id=user_id_int)) + ban_entry = await interaction.guild.fetch_ban( + discord.Object(id=user_id_int) + ) banned_user = ban_entry.user except discord.NotFound: - await interaction.response.send_message("❌ This user is not banned.", ephemeral=True) + await interaction.response.send_message( + "❌ This user is not banned.", ephemeral=True + ) return except discord.Forbidden: - await interaction.response.send_message("❌ I don't have permission to view the ban list.", ephemeral=True) + await interaction.response.send_message( + "❌ I don't have permission to view the ban list.", ephemeral=True + ) return except discord.HTTPException as e: - await interaction.response.send_message(f"❌ An error occurred while checking the ban list: {e}", ephemeral=True) + await interaction.response.send_message( + f"❌ An error occurred while checking the ban list: {e}", ephemeral=True + ) return # Try to send a DM to the banned user @@ -760,59 +996,88 @@ class ModerationCog(commands.Cog): embed = discord.Embed( title=f"Message from {interaction.guild.name}", description=message, - color=discord.Color.red() + color=discord.Color.red(), ) embed.add_field(name="Sent by", value=interaction.user.name, inline=False) - embed.set_footer(text=f"Server ID: {interaction.guild.id} • {discord.utils.utcnow().strftime('%Y-%m-%d %H:%M:%S')} UTC") + embed.set_footer( + text=f"Server ID: {interaction.guild.id} • {discord.utils.utcnow().strftime('%Y-%m-%d %H:%M:%S')} UTC" + ) # Send the DM await banned_user.send(embed=embed) # Log the action - logger.info(f"DM sent to banned user {banned_user} (ID: {banned_user.id}) in {interaction.guild.name} (ID: {interaction.guild.id}) by {interaction.user} (ID: {interaction.user.id}).") + logger.info( + f"DM sent to banned user {banned_user} (ID: {banned_user.id}) in {interaction.guild.name} (ID: {interaction.guild.id}) by {interaction.user} (ID: {interaction.user.id})." + ) # Send confirmation message - await interaction.response.send_message(f"✅ **DM sent to banned user {banned_user}**!", ephemeral=True) + await interaction.response.send_message( + f"✅ **DM sent to banned user {banned_user}**!", ephemeral=True + ) except discord.Forbidden: - await interaction.response.send_message("❌ I couldn't send a DM to this user. They may have DMs disabled or have blocked the bot.", ephemeral=True) + await interaction.response.send_message( + "❌ I couldn't send a DM to this user. They may have DMs disabled or have blocked the bot.", + ephemeral=True, + ) except discord.HTTPException as e: - await interaction.response.send_message(f"❌ An error occurred while sending the DM: {e}", ephemeral=True) + await interaction.response.send_message( + f"❌ An error occurred while sending the DM: {e}", ephemeral=True + ) except Exception as e: - logger.error(f"Error sending DM to banned user {banned_user} (ID: {banned_user.id}): {e}") - await interaction.response.send_message(f"❌ An unexpected error occurred: {e}", ephemeral=True) + logger.error( + f"Error sending DM to banned user {banned_user} (ID: {banned_user.id}): {e}" + ) + await interaction.response.send_message( + f"❌ An unexpected error occurred: {e}", ephemeral=True + ) - async def moderate_view_infractions_callback(self, interaction: discord.Interaction, member: discord.Member): + async def moderate_view_infractions_callback( + self, interaction: discord.Interaction, member: discord.Member + ): """View moderation infractions for a user.""" - if not interaction.user.guild_permissions.kick_members: # Using kick_members as a general mod permission - await interaction.response.send_message("❌ You don't have permission to view infractions.", ephemeral=True) + if ( + not interaction.user.guild_permissions.kick_members + ): # Using kick_members as a general mod permission + await interaction.response.send_message( + "❌ You don't have permission to view infractions.", ephemeral=True + ) return if not self.bot.pg_pool: - await interaction.response.send_message("❌ Database connection is not available.", ephemeral=True) + await interaction.response.send_message( + "❌ Database connection is not available.", ephemeral=True + ) logger.error("Cannot view infractions: pg_pool is None.") return - infractions = await mod_log_db.get_user_mod_logs(self.bot.pg_pool, interaction.guild.id, member.id) + infractions = await mod_log_db.get_user_mod_logs( + self.bot.pg_pool, interaction.guild.id, member.id + ) if not infractions: - await interaction.response.send_message(f"No infractions found for {self._user_display(member)}.", ephemeral=True) + await interaction.response.send_message( + f"No infractions found for {self._user_display(member)}.", + ephemeral=True, + ) return embed = discord.Embed( - title=f"Infractions for {member.display_name}", - color=discord.Color.orange() + title=f"Infractions for {member.display_name}", color=discord.Color.orange() ) embed.set_thumbnail(url=member.display_avatar.url) - for infraction in infractions[:25]: # Display up to 25 infractions - action_type = infraction['action_type'] - reason = infraction['reason'] or "No reason provided" - moderator_id = infraction['moderator_id'] - timestamp = infraction['timestamp'] - case_id = infraction['case_id'] - duration_seconds = infraction['duration_seconds'] + for infraction in infractions[:25]: # Display up to 25 infractions + action_type = infraction["action_type"] + reason = infraction["reason"] or "No reason provided" + moderator_id = infraction["moderator_id"] + timestamp = infraction["timestamp"] + case_id = infraction["case_id"] + duration_seconds = infraction["duration_seconds"] - moderator = interaction.guild.get_member(moderator_id) or f"ID: {moderator_id}" + moderator = ( + interaction.guild.get_member(moderator_id) or f"ID: {moderator_id}" + ) value = f"**Case ID:** {case_id}\n" value += f"**Action:** {action_type}\n" @@ -830,33 +1095,53 @@ class ModerationCog(commands.Cog): await interaction.response.send_message(embed=embed, ephemeral=True) - async def moderate_remove_infraction_callback(self, interaction: discord.Interaction, case_id: int, reason: str = None): + async def moderate_remove_infraction_callback( + self, interaction: discord.Interaction, case_id: int, reason: str = None + ): """Remove a specific infraction by its case ID.""" - if not interaction.user.guild_permissions.ban_members: # Higher permission for removing infractions - await interaction.response.send_message("❌ You don't have permission to remove infractions.", ephemeral=True) + if ( + not interaction.user.guild_permissions.ban_members + ): # Higher permission for removing infractions + await interaction.response.send_message( + "❌ You don't have permission to remove infractions.", ephemeral=True + ) return if not self.bot.pg_pool: - await interaction.response.send_message("❌ Database connection is not available.", ephemeral=True) + await interaction.response.send_message( + "❌ Database connection is not available.", ephemeral=True + ) logger.error("Cannot remove infraction: pg_pool is None.") return # Fetch the infraction to ensure it exists and to log details infraction_to_remove = await mod_log_db.get_mod_log(self.bot.pg_pool, case_id) - if not infraction_to_remove or infraction_to_remove['guild_id'] != interaction.guild.id: - await interaction.response.send_message(f"❌ Infraction with Case ID {case_id} not found in this server.", ephemeral=True) + if ( + not infraction_to_remove + or infraction_to_remove["guild_id"] != interaction.guild.id + ): + await interaction.response.send_message( + f"❌ Infraction with Case ID {case_id} not found in this server.", + ephemeral=True, + ) return - deleted = await mod_log_db.delete_mod_log(self.bot.pg_pool, case_id, interaction.guild.id) + deleted = await mod_log_db.delete_mod_log( + self.bot.pg_pool, case_id, interaction.guild.id + ) if deleted: - logger.info(f"Infraction (Case ID: {case_id}) removed by {interaction.user} (ID: {interaction.user.id}) in guild {interaction.guild.id}. Reason: {reason}") + logger.info( + f"Infraction (Case ID: {case_id}) removed by {interaction.user} (ID: {interaction.user.id}) in guild {interaction.guild.id}. Reason: {reason}" + ) # Log the removal action itself - mod_log_cog: ModLogCog = self.bot.get_cog('ModLogCog') + mod_log_cog: ModLogCog = self.bot.get_cog("ModLogCog") if mod_log_cog: - target_user_id = infraction_to_remove['target_user_id'] - target_user = await self.bot.fetch_user(target_user_id) # Fetch user for logging + target_user_id = infraction_to_remove["target_user_id"] + target_user = await self.bot.fetch_user( + target_user_id + ) # Fetch user for logging await mod_log_cog.log_action( guild=interaction.guild, @@ -864,41 +1149,71 @@ class ModerationCog(commands.Cog): target=target_user if target_user else Object(id=target_user_id), action_type="REMOVE_INFRACTION", reason=f"Removed Case ID {case_id}. Original reason: {infraction_to_remove['reason']}. Removal reason: {reason or 'Not specified'}", - duration=None + duration=None, ) - await interaction.response.send_message(f"✅ Infraction with Case ID {case_id} has been removed. Reason: {reason or 'Not specified'}", ephemeral=True) + await interaction.response.send_message( + f"✅ Infraction with Case ID {case_id} has been removed. Reason: {reason or 'Not specified'}", + ephemeral=True, + ) else: - await interaction.response.send_message(f"❌ Failed to remove infraction with Case ID {case_id}. It might have already been removed or an error occurred.", ephemeral=True) + await interaction.response.send_message( + f"❌ Failed to remove infraction with Case ID {case_id}. It might have already been removed or an error occurred.", + ephemeral=True, + ) - async def moderate_clear_infractions_callback(self, interaction: discord.Interaction, member: discord.Member, reason: str = None): + async def moderate_clear_infractions_callback( + self, + interaction: discord.Interaction, + member: discord.Member, + reason: str = None, + ): """Clear all moderation infractions for a user.""" # This is a destructive action, so require ban_members permission if not interaction.user.guild_permissions.ban_members: - await interaction.response.send_message("❌ You don't have permission to clear all infractions for a user.", ephemeral=True) + await interaction.response.send_message( + "❌ You don't have permission to clear all infractions for a user.", + ephemeral=True, + ) return if not self.bot.pg_pool: - await interaction.response.send_message("❌ Database connection is not available.", ephemeral=True) + await interaction.response.send_message( + "❌ Database connection is not available.", ephemeral=True + ) logger.error("Cannot clear infractions: pg_pool is None.") return # Confirmation step view = discord.ui.View() - confirm_button = discord.ui.Button(label="Confirm Clear All", style=discord.ButtonStyle.danger, custom_id="confirm_clear_all") - cancel_button = discord.ui.Button(label="Cancel", style=discord.ButtonStyle.secondary, custom_id="cancel_clear_all") + confirm_button = discord.ui.Button( + label="Confirm Clear All", + style=discord.ButtonStyle.danger, + custom_id="confirm_clear_all", + ) + cancel_button = discord.ui.Button( + label="Cancel", + style=discord.ButtonStyle.secondary, + custom_id="cancel_clear_all", + ) async def confirm_callback(interaction_confirm: discord.Interaction): if interaction_confirm.user.id != interaction.user.id: - await interaction_confirm.response.send_message("❌ You are not authorized to confirm this action.", ephemeral=True) + await interaction_confirm.response.send_message( + "❌ You are not authorized to confirm this action.", ephemeral=True + ) return - deleted_count = await mod_log_db.clear_user_mod_logs(self.bot.pg_pool, interaction.guild.id, member.id) + deleted_count = await mod_log_db.clear_user_mod_logs( + self.bot.pg_pool, interaction.guild.id, member.id + ) if deleted_count > 0: - logger.info(f"{deleted_count} infractions for user {member} (ID: {member.id}) cleared by {interaction.user} (ID: {interaction.user.id}) in guild {interaction.guild.id}. Reason: {reason}") + logger.info( + f"{deleted_count} infractions for user {member} (ID: {member.id}) cleared by {interaction.user} (ID: {interaction.user.id}) in guild {interaction.guild.id}. Reason: {reason}" + ) # Log the clear all action - mod_log_cog: ModLogCog = self.bot.get_cog('ModLogCog') + mod_log_cog: ModLogCog = self.bot.get_cog("ModLogCog") if mod_log_cog: await mod_log_cog.log_action( guild=interaction.guild, @@ -906,19 +1221,32 @@ class ModerationCog(commands.Cog): target=member, action_type="CLEAR_INFRACTIONS", reason=f"Cleared {deleted_count} infractions. Reason: {reason or 'Not specified'}", - duration=None + duration=None, ) - await interaction_confirm.response.edit_message(content=f"✅ Successfully cleared {deleted_count} infractions for {self._user_display(member)}. Reason: {reason or 'Not specified'}", view=None) + await interaction_confirm.response.edit_message( + content=f"✅ Successfully cleared {deleted_count} infractions for {self._user_display(member)}. Reason: {reason or 'Not specified'}", + view=None, + ) elif deleted_count == 0: - await interaction_confirm.response.edit_message(content=f"ℹ️ No infractions found for {self._user_display(member)} to clear.", view=None) - else: # Should not happen if 0 is returned for no logs - await interaction_confirm.response.edit_message(content=f"❌ Failed to clear infractions for {self._user_display(member)}. An error occurred.", view=None) + await interaction_confirm.response.edit_message( + content=f"ℹ️ No infractions found for {self._user_display(member)} to clear.", + view=None, + ) + else: # Should not happen if 0 is returned for no logs + await interaction_confirm.response.edit_message( + content=f"❌ Failed to clear infractions for {self._user_display(member)}. An error occurred.", + view=None, + ) async def cancel_callback(interaction_cancel: discord.Interaction): if interaction_cancel.user.id != interaction.user.id: - await interaction_cancel.response.send_message("❌ You are not authorized to cancel this action.", ephemeral=True) + await interaction_cancel.response.send_message( + "❌ You are not authorized to cancel this action.", ephemeral=True + ) return - await interaction_cancel.response.edit_message(content="🚫 Infraction clearing cancelled.", view=None) + await interaction_cancel.response.edit_message( + content="🚫 Infraction clearing cancelled.", view=None + ) confirm_button.callback = confirm_callback cancel_button.callback = cancel_callback @@ -929,18 +1257,27 @@ class ModerationCog(commands.Cog): f"⚠️ Are you sure you want to clear **ALL** infractions for {self._user_display(member)}?\n" f"This action is irreversible. Reason: {reason or 'Not specified'}", view=view, - ephemeral=True + ephemeral=True, ) # --- Legacy Command Handlers (for prefix commands) --- @commands.command(name="timeout") - async def timeout(self, ctx: commands.Context, member: discord.Member = None, duration: str = None, *, reason: str = None): + async def timeout( + self, + ctx: commands.Context, + member: discord.Member = None, + duration: str = None, + *, + reason: str = None, + ): """Timeout a member in the server. Can be used by replying to a message.""" # Check if this is a reply to a message and no member was specified if not member and ctx.message.reference: # Get the message being replied to - replied_msg = await ctx.channel.fetch_message(ctx.message.reference.message_id) + replied_msg = await ctx.channel.fetch_message( + ctx.message.reference.message_id + ) member = replied_msg.author # Don't allow timing out the bot itself @@ -948,23 +1285,33 @@ class ModerationCog(commands.Cog): await ctx.reply("❌ I cannot timeout myself.") return elif not member: - await ctx.reply("❌ Please specify a member to timeout or reply to their message.") + await ctx.reply( + "❌ Please specify a member to timeout or reply to their message." + ) return # If duration wasn't specified but we're in a reply, check if it's the first argument - if not duration and ctx.message.reference and len(ctx.message.content.split()) > 1: + if ( + not duration + and ctx.message.reference + and len(ctx.message.content.split()) > 1 + ): # Try to extract duration from the first argument potential_duration = ctx.message.content.split()[1] # Simple check if it looks like a duration (contains numbers and letters) - if any(c.isdigit() for c in potential_duration) and any(c.isalpha() for c in potential_duration): + if any(c.isdigit() for c in potential_duration) and any( + c.isalpha() for c in potential_duration + ): duration = potential_duration # If there's more content, it's the reason if len(ctx.message.content.split()) > 2: - reason = ' '.join(ctx.message.content.split()[2:]) + reason = " ".join(ctx.message.content.split()[2:]) # Check if duration is specified if not duration: - await ctx.reply("❌ Please specify a duration for the timeout (e.g., '1d', '2h', '30m', '60s').") + await ctx.reply( + "❌ Please specify a duration for the timeout (e.g., '1d', '2h', '30m', '60s')." + ) return # Check if the user has permission to moderate members @@ -983,19 +1330,28 @@ class ModerationCog(commands.Cog): return # Check if the user is trying to timeout someone with a higher role - if ctx.author.top_role <= member.top_role and ctx.author.id != ctx.guild.owner_id: - await ctx.reply("❌ You cannot timeout someone with a higher or equal role.") + if ( + ctx.author.top_role <= member.top_role + and ctx.author.id != ctx.guild.owner_id + ): + await ctx.reply( + "❌ You cannot timeout someone with a higher or equal role." + ) return # Check if the bot can timeout the member (role hierarchy) if ctx.guild.me.top_role <= member.top_role: - await ctx.reply("❌ I cannot timeout someone with a higher or equal role than me.") + await ctx.reply( + "❌ I cannot timeout someone with a higher or equal role than me." + ) return # Parse the duration delta = self._parse_duration(duration) if not delta: - await ctx.reply("❌ Invalid duration format. Please use formats like '1d', '2h', '30m', or '60s'.") + await ctx.reply( + "❌ Invalid duration format. Please use formats like '1d', '2h', '30m', or '60s'." + ) return # Check if the duration is within Discord's limits (max 28 days) @@ -1013,13 +1369,19 @@ class ModerationCog(commands.Cog): embed = discord.Embed( title="Timeout Notice", description=f"You have been timed out in **{ctx.guild.name}** for {duration}", - color=discord.Color.gold() + color=discord.Color.gold(), + ) + embed.add_field( + name="Reason", value=reason or "No reason provided", inline=False ) - embed.add_field(name="Reason", value=reason or "No reason provided", inline=False) embed.add_field(name="Moderator", value=ctx.author.name, inline=False) embed.add_field(name="Duration", value=duration, inline=False) - embed.add_field(name="Expires", value=f"", inline=False) - embed.set_footer(text=f"Server ID: {ctx.guild.id} • {discord.utils.utcnow().strftime('%Y-%m-%d %H:%M:%S')} UTC") + embed.add_field( + name="Expires", value=f"", inline=False + ) + embed.set_footer( + text=f"Server ID: {ctx.guild.id} • {discord.utils.utcnow().strftime('%Y-%m-%d %H:%M:%S')} UTC" + ) await member.send(embed=embed) dm_sent = True @@ -1034,10 +1396,12 @@ class ModerationCog(commands.Cog): await member.timeout(until, reason=reason) # Log the action - logger.info(f"User {member} (ID: {member.id}) was timed out in {ctx.guild.name} (ID: {ctx.guild.id}) by {ctx.author} (ID: {ctx.author.id}) for {duration}. Reason: {reason}") + logger.info( + f"User {member} (ID: {member.id}) was timed out in {ctx.guild.name} (ID: {ctx.guild.id}) by {ctx.author} (ID: {ctx.author.id}) for {duration}. Reason: {reason}" + ) # --- Add to Mod Log DB --- - mod_log_cog: ModLogCog = self.bot.get_cog('ModLogCog') + mod_log_cog: ModLogCog = self.bot.get_cog("ModLogCog") if mod_log_cog: await mod_log_cog.log_action( guild=ctx.guild, @@ -1045,25 +1409,39 @@ class ModerationCog(commands.Cog): target=member, action_type="TIMEOUT", reason=reason, - duration=delta # Pass the timedelta object + duration=delta, # Pass the timedelta object ) # ------------------------- # Send confirmation message with DM status - dm_status = "✅ DM notification sent" if dm_sent else "❌ Could not send DM notification (user may have DMs disabled)" - await ctx.reply(f"⏰ **Timed out {self._user_display(member)}** for {duration}! Reason: {reason or 'No reason provided'}\n{dm_status}") + dm_status = ( + "✅ DM notification sent" + if dm_sent + else "❌ Could not send DM notification (user may have DMs disabled)" + ) + await ctx.reply( + f"⏰ **Timed out {self._user_display(member)}** for {duration}! Reason: {reason or 'No reason provided'}\n{dm_status}" + ) except discord.Forbidden: await ctx.reply("❌ I don't have permission to timeout this member.") except discord.HTTPException as e: await ctx.reply(f"❌ An error occurred while timing out the member: {e}") @commands.command(name="removetimeout") - async def removetimeout(self, ctx: commands.Context, member: discord.Member = None, *, reason: str = None): + async def removetimeout( + self, + ctx: commands.Context, + member: discord.Member = None, + *, + reason: str = None, + ): """Remove a timeout from a member. Can be used by replying to a message.""" # Check if this is a reply to a message and no member was specified if not member and ctx.message.reference: # Get the message being replied to - replied_msg = await ctx.channel.fetch_message(ctx.message.reference.message_id) + replied_msg = await ctx.channel.fetch_message( + ctx.message.reference.message_id + ) member = replied_msg.author # Don't allow removing timeout from the bot itself @@ -1071,7 +1449,9 @@ class ModerationCog(commands.Cog): await ctx.reply("❌ I cannot remove a timeout from myself.") return elif not member: - await ctx.reply("❌ Please specify a member to remove timeout from or reply to their message.") + await ctx.reply( + "❌ Please specify a member to remove timeout from or reply to their message." + ) return # Check if the user has permission to moderate members @@ -1095,11 +1475,15 @@ class ModerationCog(commands.Cog): embed = discord.Embed( title="Timeout Removed", description=f"Your timeout in **{ctx.guild.name}** has been removed", - color=discord.Color.green() + color=discord.Color.green(), + ) + embed.add_field( + name="Reason", value=reason or "No reason provided", inline=False ) - embed.add_field(name="Reason", value=reason or "No reason provided", inline=False) embed.add_field(name="Moderator", value=ctx.author.name, inline=False) - embed.set_footer(text=f"Server ID: {ctx.guild.id} • {discord.utils.utcnow().strftime('%Y-%m-%d %H:%M:%S')} UTC") + embed.set_footer( + text=f"Server ID: {ctx.guild.id} • {discord.utils.utcnow().strftime('%Y-%m-%d %H:%M:%S')} UTC" + ) await member.send(embed=embed) dm_sent = True @@ -1107,17 +1491,21 @@ class ModerationCog(commands.Cog): # User has DMs closed, ignore pass except Exception as e: - logger.error(f"Error sending timeout removal DM to {member} (ID: {member.id}): {e}") + logger.error( + f"Error sending timeout removal DM to {member} (ID: {member.id}): {e}" + ) # Perform the timeout removal try: await member.timeout(None, reason=reason) # Log the action - logger.info(f"Timeout was removed from user {member} (ID: {member.id}) in {ctx.guild.name} (ID: {ctx.guild.id}) by {ctx.author} (ID: {ctx.author.id}). Reason: {reason}") + logger.info( + f"Timeout was removed from user {member} (ID: {member.id}) in {ctx.guild.name} (ID: {ctx.guild.id}) by {ctx.author} (ID: {ctx.author.id}). Reason: {reason}" + ) # --- Add to Mod Log DB --- - mod_log_cog: ModLogCog = self.bot.get_cog('ModLogCog') + mod_log_cog: ModLogCog = self.bot.get_cog("ModLogCog") if mod_log_cog: await mod_log_cog.log_action( guild=ctx.guild, @@ -1125,24 +1513,34 @@ class ModerationCog(commands.Cog): target=member, action_type="REMOVE_TIMEOUT", reason=reason, - duration=None + duration=None, ) # ------------------------- # Send confirmation message with DM status - dm_status = "✅ DM notification sent" if dm_sent else "❌ Could not send DM notification (user may have DMs disabled)" - await ctx.reply(f"⏰ **Removed timeout from {self._user_display(member)}**! Reason: {reason or 'No reason provided'}\n{dm_status}") + dm_status = ( + "✅ DM notification sent" + if dm_sent + else "❌ Could not send DM notification (user may have DMs disabled)" + ) + await ctx.reply( + f"⏰ **Removed timeout from {self._user_display(member)}**! Reason: {reason or 'No reason provided'}\n{dm_status}" + ) except discord.Forbidden: - await ctx.reply("❌ I don't have permission to remove the timeout from this member.") + await ctx.reply( + "❌ I don't have permission to remove the timeout from this member." + ) except discord.HTTPException as e: await ctx.reply(f"❌ An error occurred while removing the timeout: {e}") @commands.Cog.listener() async def on_ready(self): - print(f'{self.__class__.__name__} cog has been loaded.') + print(f"{self.__class__.__name__} cog has been loaded.") + # Modals for context menu commands + class BanModal(discord.ui.Modal, title="Ban User"): def __init__(self, member: discord.Member): super().__init__() @@ -1165,16 +1563,29 @@ class BanModal(discord.ui.Modal, title="Ban User"): ) async def on_submit(self, interaction: discord.Interaction): - await interaction.response.defer(ephemeral=True) # Defer the modal submission + await interaction.response.defer(ephemeral=True) # Defer the modal submission cog = interaction.client.get_cog("ModerationCog") if cog: reason = self.reason.value or "No reason provided" - delete_days = int(self.delete_days.value) if self.delete_days.value and self.delete_days.value.isdigit() else 0 + delete_days = ( + int(self.delete_days.value) + if self.delete_days.value and self.delete_days.value.isdigit() + else 0 + ) # Call the existing ban callback - await cog.moderate_ban_callback(interaction, self.member, reason=reason, delete_days=delete_days, send_dm=self.send_dm) + await cog.moderate_ban_callback( + interaction, + self.member, + reason=reason, + delete_days=delete_days, + send_dm=self.send_dm, + ) else: - await interaction.followup.send("Error: Moderation cog not found.", ephemeral=True) + await interaction.followup.send( + "Error: Moderation cog not found.", ephemeral=True + ) + class KickModal(discord.ui.Modal, title="Kick User"): def __init__(self, member: discord.Member): @@ -1190,7 +1601,7 @@ class KickModal(discord.ui.Modal, title="Kick User"): ) async def on_submit(self, interaction: discord.Interaction): - await interaction.response.defer(ephemeral=True) # Defer the modal submission + await interaction.response.defer(ephemeral=True) # Defer the modal submission cog = interaction.client.get_cog("ModerationCog") if cog: @@ -1198,7 +1609,10 @@ class KickModal(discord.ui.Modal, title="Kick User"): # Call the existing kick callback await cog.moderate_kick_callback(interaction, self.member, reason=reason) else: - await interaction.followup.send("Error: Moderation cog not found.", ephemeral=True) + await interaction.followup.send( + "Error: Moderation cog not found.", ephemeral=True + ) + class TimeoutModal(discord.ui.Modal, title="Timeout User"): def __init__(self, member: discord.Member): @@ -1221,16 +1635,21 @@ class TimeoutModal(discord.ui.Modal, title="Timeout User"): ) async def on_submit(self, interaction: discord.Interaction): - await interaction.response.defer(ephemeral=True) # Defer the modal submission + await interaction.response.defer(ephemeral=True) # Defer the modal submission cog = interaction.client.get_cog("ModerationCog") if cog: duration = self.duration.value reason = self.reason.value or "No reason provided" # Call the existing timeout callback - await cog.moderate_timeout_callback(interaction, self.member, duration=duration, reason=reason) + await cog.moderate_timeout_callback( + interaction, self.member, duration=duration, reason=reason + ) else: - await interaction.followup.send("Error: Moderation cog not found.", ephemeral=True) + await interaction.followup.send( + "Error: Moderation cog not found.", ephemeral=True + ) + class RemoveTimeoutModal(discord.ui.Modal, title="Remove Timeout"): def __init__(self, member: discord.Member): @@ -1246,19 +1665,24 @@ class RemoveTimeoutModal(discord.ui.Modal, title="Remove Timeout"): ) async def on_submit(self, interaction: discord.Interaction): - await interaction.response.defer(ephemeral=True) # Defer the modal submission + await interaction.response.defer(ephemeral=True) # Defer the modal submission cog = interaction.client.get_cog("ModerationCog") if cog: reason = self.reason.value or "No reason provided" # Call the existing remove timeout callback - await cog.moderate_remove_timeout_callback(interaction, self.member, reason=reason) + await cog.moderate_remove_timeout_callback( + interaction, self.member, reason=reason + ) else: - await interaction.followup.send("Error: Moderation cog not found.", ephemeral=True) + await interaction.followup.send( + "Error: Moderation cog not found.", ephemeral=True + ) # Context menu commands must be defined at module level + class BanOptionsView(discord.ui.View): def __init__(self, member: discord.Member): super().__init__(timeout=60) # 60 second timeout @@ -1268,17 +1692,29 @@ class BanOptionsView(discord.ui.View): def update_button_label(self): self.toggle_dm_button.label = f"Send DM: {'Yes' if self.send_dm else 'No'}" - self.toggle_dm_button.style = discord.ButtonStyle.green if self.send_dm else discord.ButtonStyle.red + self.toggle_dm_button.style = ( + discord.ButtonStyle.green if self.send_dm else discord.ButtonStyle.red + ) - @discord.ui.button(label="Send DM: Yes", style=discord.ButtonStyle.green, custom_id="toggle_dm") - async def toggle_dm_button(self, interaction: discord.Interaction, _: discord.ui.Button): + @discord.ui.button( + label="Send DM: Yes", style=discord.ButtonStyle.green, custom_id="toggle_dm" + ) + async def toggle_dm_button( + self, interaction: discord.Interaction, _: discord.ui.Button + ): # Toggle the send_dm value self.send_dm = not self.send_dm self.update_button_label() await interaction.response.edit_message(view=self) - @discord.ui.button(label="Continue to Ban", style=discord.ButtonStyle.danger, custom_id="continue_ban") - async def continue_button(self, interaction: discord.Interaction, _: discord.ui.Button): + @discord.ui.button( + label="Continue to Ban", + style=discord.ButtonStyle.danger, + custom_id="continue_ban", + ) + async def continue_button( + self, interaction: discord.Interaction, _: discord.ui.Button + ): # Create and show the modal modal = BanModal(self.member) modal.send_dm = self.send_dm # Pass the send_dm setting to the modal @@ -1286,27 +1722,46 @@ class BanOptionsView(discord.ui.View): # Stop listening for interactions on this view self.stop() + @app_commands.context_menu(name="Ban User") -async def ban_user_context_menu(interaction: discord.Interaction, member: discord.Member): +async def ban_user_context_menu( + interaction: discord.Interaction, member: discord.Member +): """Bans the selected user via a modal.""" # Check permissions before showing the modal if not interaction.user.guild_permissions.ban_members: - await interaction.response.send_message("❌ You don't have permission to ban members.", ephemeral=True) + await interaction.response.send_message( + "❌ You don't have permission to ban members.", ephemeral=True + ) return if not interaction.guild.me.guild_permissions.ban_members: - await interaction.response.send_message("❌ I don't have permission to ban members.", ephemeral=True) + await interaction.response.send_message( + "❌ I don't have permission to ban members.", ephemeral=True + ) + return + if ( + interaction.user.top_role <= member.top_role + and interaction.user.id != interaction.guild.owner_id + ): + await interaction.response.send_message( + "❌ You cannot ban someone with a higher or equal role.", ephemeral=True + ) return - if interaction.user.top_role <= member.top_role and interaction.user.id != interaction.guild.owner_id: - await interaction.response.send_message("❌ You cannot ban someone with a higher or equal role.", ephemeral=True) - return if interaction.guild.me.top_role <= member.top_role: - await interaction.response.send_message("❌ I cannot ban someone with a higher or equal role than me.", ephemeral=True) - return + await interaction.response.send_message( + "❌ I cannot ban someone with a higher or equal role than me.", + ephemeral=True, + ) + return if member.id == interaction.user.id: - await interaction.response.send_message("❌ You cannot ban yourself.", ephemeral=True) + await interaction.response.send_message( + "❌ You cannot ban yourself.", ephemeral=True + ) return if member.id == interaction.client.user.id: - await interaction.response.send_message("❌ I cannot ban myself.", ephemeral=True) + await interaction.response.send_message( + "❌ I cannot ban myself.", ephemeral=True + ) return # Show options view first @@ -1314,74 +1769,121 @@ async def ban_user_context_menu(interaction: discord.Interaction, member: discor await interaction.response.send_message( f"⚠️ You are about to ban **{member.display_name}** ({member.id}).\nPlease select your options:", view=view, - ephemeral=True + ephemeral=True, ) + @app_commands.context_menu(name="Kick User") -async def kick_user_context_menu(interaction: discord.Interaction, member: discord.Member): +async def kick_user_context_menu( + interaction: discord.Interaction, member: discord.Member +): """Kicks the selected user via a modal.""" # Check permissions before showing the modal if not interaction.user.guild_permissions.kick_members: - await interaction.response.send_message("❌ You don't have permission to kick members.", ephemeral=True) + await interaction.response.send_message( + "❌ You don't have permission to kick members.", ephemeral=True + ) return if not interaction.guild.me.guild_permissions.kick_members: - await interaction.response.send_message("❌ I don't have permission to kick members.", ephemeral=True) + await interaction.response.send_message( + "❌ I don't have permission to kick members.", ephemeral=True + ) + return + if ( + interaction.user.top_role <= member.top_role + and interaction.user.id != interaction.guild.owner_id + ): + await interaction.response.send_message( + "❌ You cannot kick someone with a higher or equal role.", ephemeral=True + ) return - if interaction.user.top_role <= member.top_role and interaction.user.id != interaction.guild.owner_id: - await interaction.response.send_message("❌ You cannot kick someone with a higher or equal role.", ephemeral=True) - return if interaction.guild.me.top_role <= member.top_role: - await interaction.response.send_message("❌ I cannot kick someone with a higher or equal role than me.", ephemeral=True) - return + await interaction.response.send_message( + "❌ I cannot kick someone with a higher or equal role than me.", + ephemeral=True, + ) + return if member.id == interaction.user.id: - await interaction.response.send_message("❌ You cannot kick yourself.", ephemeral=True) + await interaction.response.send_message( + "❌ You cannot kick yourself.", ephemeral=True + ) return if member.id == interaction.client.user.id: - await interaction.response.send_message("❌ I cannot kick myself.", ephemeral=True) + await interaction.response.send_message( + "❌ I cannot kick myself.", ephemeral=True + ) return modal = KickModal(member) await interaction.response.send_modal(modal) + @app_commands.context_menu(name="Timeout User") -async def timeout_user_context_menu(interaction: discord.Interaction, member: discord.Member): +async def timeout_user_context_menu( + interaction: discord.Interaction, member: discord.Member +): """Timeouts the selected user via a modal.""" # Check permissions before showing the modal if not interaction.user.guild_permissions.moderate_members: - await interaction.response.send_message("❌ You don't have permission to timeout members.", ephemeral=True) + await interaction.response.send_message( + "❌ You don't have permission to timeout members.", ephemeral=True + ) return if not interaction.guild.me.guild_permissions.moderate_members: - await interaction.response.send_message("❌ I don't have permission to timeout members.", ephemeral=True) + await interaction.response.send_message( + "❌ I don't have permission to timeout members.", ephemeral=True + ) + return + if ( + interaction.user.top_role <= member.top_role + and interaction.user.id != interaction.guild.owner_id + ): + await interaction.response.send_message( + "❌ You cannot timeout someone with a higher or equal role.", ephemeral=True + ) return - if interaction.user.top_role <= member.top_role and interaction.user.id != interaction.guild.owner_id: - await interaction.response.send_message("❌ You cannot timeout someone with a higher or equal role.", ephemeral=True) - return if interaction.guild.me.top_role <= member.top_role: - await interaction.response.send_message("❌ I cannot timeout someone with a higher or equal role than me.", ephemeral=True) - return + await interaction.response.send_message( + "❌ I cannot timeout someone with a higher or equal role than me.", + ephemeral=True, + ) + return if member.id == interaction.user.id: - await interaction.response.send_message("❌ You cannot timeout yourself.", ephemeral=True) + await interaction.response.send_message( + "❌ You cannot timeout yourself.", ephemeral=True + ) return if member.id == interaction.client.user.id: - await interaction.response.send_message("❌ I cannot timeout myself.", ephemeral=True) + await interaction.response.send_message( + "❌ I cannot timeout myself.", ephemeral=True + ) return modal = TimeoutModal(member) await interaction.response.send_modal(modal) + @app_commands.context_menu(name="Remove Timeout") -async def remove_timeout_context_menu(interaction: discord.Interaction, member: discord.Member): +async def remove_timeout_context_menu( + interaction: discord.Interaction, member: discord.Member +): """Removes timeout from the selected user via a modal.""" # Check permissions before showing the modal if not interaction.user.guild_permissions.moderate_members: - await interaction.response.send_message("❌ You don't have permission to remove timeouts.", ephemeral=True) + await interaction.response.send_message( + "❌ You don't have permission to remove timeouts.", ephemeral=True + ) return if not interaction.guild.me.guild_permissions.moderate_members: - await interaction.response.send_message("❌ I don't have permission to remove timeouts.", ephemeral=True) + await interaction.response.send_message( + "❌ I don't have permission to remove timeouts.", ephemeral=True + ) return # Check if the member is timed out before showing the modal if not member.timed_out_until: - await interaction.response.send_message("❌ This member is not timed out.", ephemeral=True) + await interaction.response.send_message( + "❌ This member is not timed out.", ephemeral=True + ) return modal = RemoveTimeoutModal(member) diff --git a/cogs/role_creator_cog.py b/cogs/role_creator_cog.py index bcbe08c..3ec822a 100644 --- a/cogs/role_creator_cog.py +++ b/cogs/role_creator_cog.py @@ -5,19 +5,27 @@ from dotenv import load_dotenv import logging # Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s:%(levelname)s:%(name)s: %(message)s') +logging.basicConfig( + level=logging.INFO, format="%(asctime)s:%(levelname)s:%(name)s: %(message)s" +) logger = logging.getLogger(__name__) # Load environment variables load_dotenv() -OWNER_USER_ID = int(os.getenv("OWNER_USER_ID")) # Although commands.is_owner() handles this, loading for clarity/potential future use +OWNER_USER_ID = int( + os.getenv("OWNER_USER_ID") +) # Although commands.is_owner() handles this, loading for clarity/potential future use + class RoleCreatorCog(commands.Cog): def __init__(self, bot): self.bot = bot - @commands.command(name='create_roles', help='Creates predefined roles for reaction roles. Owner only.') - @commands.is_owner() # Restricts this command to the bot owner specified during bot setup + @commands.command( + name="create_roles", + help="Creates predefined roles for reaction roles. Owner only.", + ) + @commands.is_owner() # Restricts this command to the bot owner specified during bot setup async def create_roles(self, ctx): """Creates a set of predefined roles typically used with reaction roles.""" guild = ctx.guild @@ -28,7 +36,9 @@ class RoleCreatorCog(commands.Cog): # Check if the bot has permission to manage roles if not ctx.me.guild_permissions.manage_roles: await ctx.send("I don't have permission to manage roles.") - logger.warning(f"Missing 'Manage Roles' permission in guild {guild.id} ({guild.name}).") + logger.warning( + f"Missing 'Manage Roles' permission in guild {guild.id} ({guild.name})." + ) return # Define color mapping for specific roles @@ -40,28 +50,102 @@ class RoleCreatorCog(commands.Cog): "Purple": discord.Color.purple(), "Orange": discord.Color.orange(), "Pink": discord.Color.fuchsia(), - "Black": discord.Color(0x010101), # Near black to avoid blending with themes - "White": discord.Color(0xFEFEFE) # Near white to avoid blending + "Black": discord.Color( + 0x010101 + ), # Near black to avoid blending with themes + "White": discord.Color(0xFEFEFE), # Near white to avoid blending } await ctx.send("Starting role creation/update process...") - logger.info(f"Role creation/update initiated by {ctx.author} in guild {guild.id} ({guild.name}).") + logger.info( + f"Role creation/update initiated by {ctx.author} in guild {guild.id} ({guild.name})." + ) role_categories = { - "Colors": ["Red", "Blue", "Green", "Yellow", "Purple", "Orange", "Pink", "Black", "White"], + "Colors": [ + "Red", + "Blue", + "Green", + "Yellow", + "Purple", + "Orange", + "Pink", + "Black", + "White", + ], "Regions": ["NA East", "NA West", "EU", "Asia", "Oceania", "South America"], "Pronouns": ["He/Him", "She/Her", "They/Them", "Ask Pronouns"], - "Interests": ["Art", "Music", "Movies", "Books", "Technology", "Science", "History", "Food", "Programming", "Anime", "Photography", "Travel", "Writing", "Cooking", "Fitness", "Nature", "Gaming", "Philosophy", "Psychology", "Design", "Machine Learning", "Cryptocurrency", "Astronomy", "Mythology", "Languages", "Architecture", "DIY Projects", "Hiking", "Streaming", "Virtual Reality", "Coding Challenges", "Board Games", "Meditation", "Urban Exploration", "Tattoo Art", "Comics", "Robotics", "3D Modeling", "Podcasts"], - "Gaming Platforms": ["PC", "PlayStation", "Xbox", "Nintendo Switch", "Mobile"], - "Favorite Vocaloids": ["Hatsune Miku", "Kasane Teto", "Akita Neru", "Kagamine Rin", "Kagamine Len", "Megurine Luka", "Kaito", "Meiko", "Gumi", "Kaai Yuki", "Adachi Rei"], - "Notifications": ["Announcements"] + "Interests": [ + "Art", + "Music", + "Movies", + "Books", + "Technology", + "Science", + "History", + "Food", + "Programming", + "Anime", + "Photography", + "Travel", + "Writing", + "Cooking", + "Fitness", + "Nature", + "Gaming", + "Philosophy", + "Psychology", + "Design", + "Machine Learning", + "Cryptocurrency", + "Astronomy", + "Mythology", + "Languages", + "Architecture", + "DIY Projects", + "Hiking", + "Streaming", + "Virtual Reality", + "Coding Challenges", + "Board Games", + "Meditation", + "Urban Exploration", + "Tattoo Art", + "Comics", + "Robotics", + "3D Modeling", + "Podcasts", + ], + "Gaming Platforms": [ + "PC", + "PlayStation", + "Xbox", + "Nintendo Switch", + "Mobile", + ], + "Favorite Vocaloids": [ + "Hatsune Miku", + "Kasane Teto", + "Akita Neru", + "Kagamine Rin", + "Kagamine Len", + "Megurine Luka", + "Kaito", + "Meiko", + "Gumi", + "Kaai Yuki", + "Adachi Rei", + ], + "Notifications": ["Announcements"], } created_count = 0 - updated_count = 0 # Renamed from eped_count - skipped_other_count = 0 # For non-color roles that exist + updated_count = 0 # Renamed from eped_count + skipped_other_count = 0 # For non-color roles that exist error_count = 0 - existing_roles = {role.name.lower(): role for role in guild.roles} # Cache existing roles for faster lookup + existing_roles = { + role.name.lower(): role for role in guild.roles + } # Cache existing roles for faster lookup for category, names in role_categories.items(): logger.info(f"Processing category: {category}") @@ -74,52 +158,77 @@ class RoleCreatorCog(commands.Cog): existing_role = existing_roles[name.lower()] # Only edit if it's a color role and needs a color update (or just ensure color is set) if category == "Colors" and role_color is not None: - # Check if color needs updating to avoid unnecessary API calls - if existing_role.color != role_color: - await existing_role.edit(color=role_color) - logger.info(f"Successfully updated color for existing role: {name}") - updated_count += 1 - else: - logger.info(f"Role '{name}' already exists with correct color. Skipping update.") - updated_count += 1 # Count as updated/checked even if no change needed + # Check if color needs updating to avoid unnecessary API calls + if existing_role.color != role_color: + await existing_role.edit(color=role_color) + logger.info( + f"Successfully updated color for existing role: {name}" + ) + updated_count += 1 + else: + logger.info( + f"Role '{name}' already exists with correct color. Skipping update." + ) + updated_count += 1 # Count as updated/checked even if no change needed else: # Non-color role exists, skip it - logger.info(f"Non-color role '{name}' already exists. Skipping.") + logger.info( + f"Non-color role '{name}' already exists. Skipping." + ) skipped_other_count += 1 - continue # Move to next role name + continue # Move to next role name # Role does not exist, create it await guild.create_role( name=name, - color=role_color or discord.Color.default(), # Use mapped color or default + color=role_color + or discord.Color.default(), # Use mapped color or default permissions=discord.Permissions.none(), - mentionable=False + mentionable=False, + ) + logger.info( + f"Successfully created role: {name}" + + (f" with color {role_color}" if role_color else "") ) - logger.info(f"Successfully created role: {name}" + (f" with color {role_color}" if role_color else "")) created_count += 1 except discord.Forbidden: - logger.error(f"Forbidden to {'edit' if role_exists else 'create'} role '{name}'. Check bot permissions.") - await ctx.send(f"Error: I lack permissions to {'edit' if role_exists else 'create'} the role '{name}'.") + logger.error( + f"Forbidden to {'edit' if role_exists else 'create'} role '{name}'. Check bot permissions." + ) + await ctx.send( + f"Error: I lack permissions to {'edit' if role_exists else 'create'} the role '{name}'." + ) error_count += 1 # Stop if permission error occurs, as it likely affects subsequent operations - await ctx.send(f"Stopping role processing due to permission error on role '{name}'.") + await ctx.send( + f"Stopping role processing due to permission error on role '{name}'." + ) return except discord.HTTPException as e: - logger.error(f"Failed to {'edit' if role_exists else 'create'} role '{name}': {e}") - await ctx.send(f"Error {'editing' if role_exists else 'creating'} role '{name}': {e}") + logger.error( + f"Failed to {'edit' if role_exists else 'create'} role '{name}': {e}" + ) + await ctx.send( + f"Error {'editing' if role_exists else 'creating'} role '{name}': {e}" + ) error_count += 1 except Exception as e: - logger.exception(f"An unexpected error occurred while processing role '{name}': {e}") - await ctx.send(f"An unexpected error occurred for role '{name}'. Check logs.") + logger.exception( + f"An unexpected error occurred while processing role '{name}': {e}" + ) + await ctx.send( + f"An unexpected error occurred for role '{name}'. Check logs." + ) error_count += 1 - - summary_message = f"Role creation/update process complete.\n" \ - f"Created: {created_count}\n" \ - f"Updated/Checked Colors: {updated_count}\n" \ - f"Skipped (Other existing): {skipped_other_count}\n" \ - f"Errors: {error_count}" + summary_message = ( + f"Role creation/update process complete.\n" + f"Created: {created_count}\n" + f"Updated/Checked Colors: {updated_count}\n" + f"Skipped (Other existing): {skipped_other_count}\n" + f"Errors: {error_count}" + ) await ctx.send(summary_message) logger.info(summary_message) @@ -127,13 +236,17 @@ class RoleCreatorCog(commands.Cog): async def setup(bot): # Ensure the owner ID is loaded correctly before adding the cog if not OWNER_USER_ID: - logger.error("OWNER_USER_ID not found in .env file. RoleCreatorCog will not be loaded.") + logger.error( + "OWNER_USER_ID not found in .env file. RoleCreatorCog will not be loaded." + ) return # Check if the bot object has owner_id or owner_ids set, which discord.py uses for is_owner() if not bot.owner_id and not bot.owner_ids: - logger.warning("Bot owner_id or owner_ids not set. The 'is_owner()' check might not function correctly.") - # Potentially load from OWNER_USER_ID if needed, though discord.py usually handles this - # bot.owner_id = OWNER_USER_ID # Uncomment if necessary and discord.py doesn't auto-load + logger.warning( + "Bot owner_id or owner_ids not set. The 'is_owner()' check might not function correctly." + ) + # Potentially load from OWNER_USER_ID if needed, though discord.py usually handles this + # bot.owner_id = OWNER_USER_ID # Uncomment if necessary and discord.py doesn't auto-load await bot.add_cog(RoleCreatorCog(bot)) logger.info("RoleCreatorCog loaded successfully.") diff --git a/cogs/role_management_cog.py b/cogs/role_management_cog.py index f678ef3..532a38f 100644 --- a/cogs/role_management_cog.py +++ b/cogs/role_management_cog.py @@ -7,49 +7,49 @@ from typing import Optional, List # Set up logging logger = logging.getLogger(__name__) + class RoleManagementCog(commands.Cog): """Cog for comprehensive role management""" def __init__(self, bot): self.bot = bot - + # Create the main command group for this cog self.role_group = app_commands.Group( - name="role", - description="Manage server roles" + name="role", description="Manage server roles" ) - + # Register commands self.register_commands() - + # Add command group to the bot's tree self.bot.tree.add_command(self.role_group) - + def register_commands(self): """Register all commands for this cog""" - + # --- Create Role Command --- create_command = app_commands.Command( name="create", description="Create a new role", callback=self.role_create_callback, - parent=self.role_group + parent=self.role_group, ) app_commands.describe( name="The name of the new role", color="The color of the role in hex format (e.g., #FF0000 for red)", mentionable="Whether the role can be mentioned by everyone", hoist="Whether the role should be displayed separately in the member list", - reason="The reason for creating this role" + reason="The reason for creating this role", )(create_command) self.role_group.add_command(create_command) - + # --- Edit Role Command --- edit_command = app_commands.Command( name="edit", description="Edit an existing role", callback=self.role_edit_callback, - parent=self.role_group + parent=self.role_group, ) app_commands.describe( role="The role to edit", @@ -57,118 +57,130 @@ class RoleManagementCog(commands.Cog): color="New color for the role in hex format (e.g., #FF0000 for red)", mentionable="Whether the role can be mentioned by everyone", hoist="Whether the role should be displayed separately in the member list", - reason="The reason for editing this role" + reason="The reason for editing this role", )(edit_command) self.role_group.add_command(edit_command) - + # --- Delete Role Command --- delete_command = app_commands.Command( name="delete", description="Delete a role", callback=self.role_delete_callback, - parent=self.role_group + parent=self.role_group, ) app_commands.describe( - role="The role to delete", - reason="The reason for deleting this role" + role="The role to delete", reason="The reason for deleting this role" )(delete_command) self.role_group.add_command(delete_command) - + # --- Add Role Command --- add_command = app_commands.Command( name="add", description="Add a role to a user", callback=self.role_add_callback, - parent=self.role_group + parent=self.role_group, ) app_commands.describe( member="The member to add the role to", role="The role to add", - reason="The reason for adding this role" + reason="The reason for adding this role", )(add_command) self.role_group.add_command(add_command) - + # --- Remove Role Command --- remove_command = app_commands.Command( name="remove", description="Remove a role from a user", callback=self.role_remove_callback, - parent=self.role_group + parent=self.role_group, ) app_commands.describe( member="The member to remove the role from", role="The role to remove", - reason="The reason for removing this role" + reason="The reason for removing this role", )(remove_command) self.role_group.add_command(remove_command) - + # --- List Roles Command --- list_command = app_commands.Command( name="list", description="List all roles in the server", callback=self.role_list_callback, - parent=self.role_group + parent=self.role_group, ) self.role_group.add_command(list_command) - + # --- Role Info Command --- info_command = app_commands.Command( name="info", description="View detailed information about a role", callback=self.role_info_callback, - parent=self.role_group + parent=self.role_group, ) - app_commands.describe( - role="The role to view information about" - )(info_command) + app_commands.describe(role="The role to view information about")(info_command) self.role_group.add_command(info_command) - + # --- Change Role Position Command --- position_command = app_commands.Command( name="position", description="Change a role's position in the hierarchy", callback=self.role_position_callback, - parent=self.role_group + parent=self.role_group, ) app_commands.describe( role="The role to move", position="The new position for the role (1 is the lowest, excluding @everyone)", - reason="The reason for changing this role's position" + reason="The reason for changing this role's position", )(position_command) self.role_group.add_command(position_command) - + # --- Command Callbacks --- - - async def role_create_callback(self, interaction: discord.Interaction, name: str, - color: Optional[str] = None, mentionable: Optional[bool] = False, - hoist: Optional[bool] = False, reason: Optional[str] = None): + + async def role_create_callback( + self, + interaction: discord.Interaction, + name: str, + color: Optional[str] = None, + mentionable: Optional[bool] = False, + hoist: Optional[bool] = False, + reason: Optional[str] = None, + ): """Callback for /role create command""" # Check permissions if not interaction.guild: - await interaction.response.send_message("This command can only be used in a server.", ephemeral=True) + await interaction.response.send_message( + "This command can only be used in a server.", ephemeral=True + ) return - + if not interaction.user.guild_permissions.manage_roles: - await interaction.response.send_message("You don't have permission to manage roles.", ephemeral=True) + await interaction.response.send_message( + "You don't have permission to manage roles.", ephemeral=True + ) return - + if not interaction.guild.me.guild_permissions.manage_roles: - await interaction.response.send_message("I don't have permission to manage roles.", ephemeral=True) + await interaction.response.send_message( + "I don't have permission to manage roles.", ephemeral=True + ) return - + # Parse color if provided role_color = discord.Color.default() if color: try: # Remove # if present - if color.startswith('#'): + if color.startswith("#"): color = color[1:] # Convert hex to int role_color = discord.Color(int(color, 16)) except ValueError: - await interaction.response.send_message(f"Invalid color format. Please use hex format (e.g., #FF0000 for red).", ephemeral=True) + await interaction.response.send_message( + f"Invalid color format. Please use hex format (e.g., #FF0000 for red).", + ephemeral=True, + ) return - + try: # Create the role new_role = await interaction.guild.create_role( @@ -176,266 +188,374 @@ class RoleManagementCog(commands.Cog): color=role_color, hoist=hoist, mentionable=mentionable, - reason=f"{reason or 'No reason provided'} (Created by {interaction.user})" + reason=f"{reason or 'No reason provided'} (Created by {interaction.user})", ) - + # Create an embed with role information embed = discord.Embed( title="✅ Role Created", description=f"Successfully created role {new_role.mention}", - color=role_color + color=role_color, ) embed.add_field(name="Name", value=name, inline=True) embed.add_field(name="Color", value=str(role_color), inline=True) embed.add_field(name="Hoisted", value="Yes" if hoist else "No", inline=True) - embed.add_field(name="Mentionable", value="Yes" if mentionable else "No", inline=True) - embed.add_field(name="Created by", value=interaction.user.mention, inline=True) + embed.add_field( + name="Mentionable", value="Yes" if mentionable else "No", inline=True + ) + embed.add_field( + name="Created by", value=interaction.user.mention, inline=True + ) if reason: embed.add_field(name="Reason", value=reason, inline=False) - + await interaction.response.send_message(embed=embed) - logger.info(f"Role '{name}' created by {interaction.user} in {interaction.guild.name}") - + logger.info( + f"Role '{name}' created by {interaction.user} in {interaction.guild.name}" + ) + except discord.Forbidden: - await interaction.response.send_message("I don't have permission to create roles.", ephemeral=True) + await interaction.response.send_message( + "I don't have permission to create roles.", ephemeral=True + ) except discord.HTTPException as e: - await interaction.response.send_message(f"Failed to create role: {e}", ephemeral=True) - - async def role_edit_callback(self, interaction: discord.Interaction, role: discord.Role, - name: Optional[str] = None, color: Optional[str] = None, - mentionable: Optional[bool] = None, hoist: Optional[bool] = None, - reason: Optional[str] = None): + await interaction.response.send_message( + f"Failed to create role: {e}", ephemeral=True + ) + + async def role_edit_callback( + self, + interaction: discord.Interaction, + role: discord.Role, + name: Optional[str] = None, + color: Optional[str] = None, + mentionable: Optional[bool] = None, + hoist: Optional[bool] = None, + reason: Optional[str] = None, + ): """Callback for /role edit command""" # Check permissions if not interaction.guild: - await interaction.response.send_message("This command can only be used in a server.", ephemeral=True) + await interaction.response.send_message( + "This command can only be used in a server.", ephemeral=True + ) return - + if not interaction.user.guild_permissions.manage_roles: - await interaction.response.send_message("You don't have permission to manage roles.", ephemeral=True) + await interaction.response.send_message( + "You don't have permission to manage roles.", ephemeral=True + ) return - + if not interaction.guild.me.guild_permissions.manage_roles: - await interaction.response.send_message("I don't have permission to manage roles.", ephemeral=True) + await interaction.response.send_message( + "I don't have permission to manage roles.", ephemeral=True + ) return - + # Check if the role is manageable if not role.is_assignable() or role.is_default(): - await interaction.response.send_message("I cannot edit this role. It might be the @everyone role or higher than my highest role.", ephemeral=True) + await interaction.response.send_message( + "I cannot edit this role. It might be the @everyone role or higher than my highest role.", + ephemeral=True, + ) return - + # Parse color if provided role_color = None if color: try: # Remove # if present - if color.startswith('#'): + if color.startswith("#"): color = color[1:] # Convert hex to int role_color = discord.Color(int(color, 16)) except ValueError: - await interaction.response.send_message(f"Invalid color format. Please use hex format (e.g., #FF0000 for red).", ephemeral=True) + await interaction.response.send_message( + f"Invalid color format. Please use hex format (e.g., #FF0000 for red).", + ephemeral=True, + ) return - + # Store original values for the embed original_name = role.name original_color = role.color original_mentionable = role.mentionable original_hoist = role.hoist - + try: # Edit the role await role.edit( name=name if name is not None else role.name, color=role_color if role_color is not None else role.color, hoist=hoist if hoist is not None else role.hoist, - mentionable=mentionable if mentionable is not None else role.mentionable, - reason=f"{reason or 'No reason provided'} (Edited by {interaction.user})" + mentionable=( + mentionable if mentionable is not None else role.mentionable + ), + reason=f"{reason or 'No reason provided'} (Edited by {interaction.user})", ) - + # Create an embed with role information embed = discord.Embed( title="✅ Role Edited", description=f"Successfully edited role {role.mention}", - color=role.color + color=role.color, ) - + # Only show fields that were changed if name is not None and name != original_name: - embed.add_field(name="Name", value=f"{original_name} → {name}", inline=True) + embed.add_field( + name="Name", value=f"{original_name} → {name}", inline=True + ) if role_color is not None and role_color != original_color: - embed.add_field(name="Color", value=f"{original_color} → {role.color}", inline=True) + embed.add_field( + name="Color", value=f"{original_color} → {role.color}", inline=True + ) if hoist is not None and hoist != original_hoist: - embed.add_field(name="Hoisted", value=f"{'Yes' if original_hoist else 'No'} → {'Yes' if hoist else 'No'}", inline=True) + embed.add_field( + name="Hoisted", + value=f"{'Yes' if original_hoist else 'No'} → {'Yes' if hoist else 'No'}", + inline=True, + ) if mentionable is not None and mentionable != original_mentionable: - embed.add_field(name="Mentionable", value=f"{'Yes' if original_mentionable else 'No'} → {'Yes' if mentionable else 'No'}", inline=True) - - embed.add_field(name="Edited by", value=interaction.user.mention, inline=True) + embed.add_field( + name="Mentionable", + value=f"{'Yes' if original_mentionable else 'No'} → {'Yes' if mentionable else 'No'}", + inline=True, + ) + + embed.add_field( + name="Edited by", value=interaction.user.mention, inline=True + ) if reason: embed.add_field(name="Reason", value=reason, inline=False) - + await interaction.response.send_message(embed=embed) - logger.info(f"Role '{role.name}' edited by {interaction.user} in {interaction.guild.name}") - + logger.info( + f"Role '{role.name}' edited by {interaction.user} in {interaction.guild.name}" + ) + except discord.Forbidden: - await interaction.response.send_message("I don't have permission to edit this role.", ephemeral=True) + await interaction.response.send_message( + "I don't have permission to edit this role.", ephemeral=True + ) except discord.HTTPException as e: - await interaction.response.send_message(f"Failed to edit role: {e}", ephemeral=True) - - async def role_delete_callback(self, interaction: discord.Interaction, role: discord.Role, - reason: Optional[str] = None): + await interaction.response.send_message( + f"Failed to edit role: {e}", ephemeral=True + ) + + async def role_delete_callback( + self, + interaction: discord.Interaction, + role: discord.Role, + reason: Optional[str] = None, + ): """Callback for /role delete command""" # Check permissions if not interaction.guild: - await interaction.response.send_message("This command can only be used in a server.", ephemeral=True) + await interaction.response.send_message( + "This command can only be used in a server.", ephemeral=True + ) return - + if not interaction.user.guild_permissions.manage_roles: - await interaction.response.send_message("You don't have permission to manage roles.", ephemeral=True) + await interaction.response.send_message( + "You don't have permission to manage roles.", ephemeral=True + ) return - + if not interaction.guild.me.guild_permissions.manage_roles: - await interaction.response.send_message("I don't have permission to manage roles.", ephemeral=True) + await interaction.response.send_message( + "I don't have permission to manage roles.", ephemeral=True + ) return - + # Check if the role is manageable if not role.is_assignable() or role.is_default(): - await interaction.response.send_message("I cannot delete this role. It might be the @everyone role or higher than my highest role.", ephemeral=True) + await interaction.response.send_message( + "I cannot delete this role. It might be the @everyone role or higher than my highest role.", + ephemeral=True, + ) return - + # Store role info for the confirmation message role_name = role.name role_color = role.color role_members_count = len(role.members) - + # Confirmation message embed = discord.Embed( title="⚠️ Confirm Role Deletion", description=f"Are you sure you want to delete the role **{role_name}**?", - color=role_color + color=role_color, ) embed.add_field(name="Role", value=role.mention, inline=True) - embed.add_field(name="Members with this role", value=str(role_members_count), inline=True) + embed.add_field( + name="Members with this role", value=str(role_members_count), inline=True + ) if reason: embed.add_field(name="Reason", value=reason, inline=False) - + # Create confirmation buttons class ConfirmView(discord.ui.View): def __init__(self): super().__init__(timeout=60) self.value = None - + @discord.ui.button(label="Confirm", style=discord.ButtonStyle.danger) - async def confirm(self, button_interaction: discord.Interaction, button: discord.ui.Button): + async def confirm( + self, button_interaction: discord.Interaction, button: discord.ui.Button + ): if button_interaction.user.id != interaction.user.id: - await button_interaction.response.send_message("You cannot use this button.", ephemeral=True) + await button_interaction.response.send_message( + "You cannot use this button.", ephemeral=True + ) return - + self.value = True self.stop() - + try: # Delete the role - await role.delete(reason=f"{reason or 'No reason provided'} (Deleted by {interaction.user})") - + await role.delete( + reason=f"{reason or 'No reason provided'} (Deleted by {interaction.user})" + ) + # Create a success embed success_embed = discord.Embed( title="✅ Role Deleted", description=f"Successfully deleted role **{role_name}**", - color=discord.Color.green() + color=discord.Color.green(), + ) + success_embed.add_field( + name="Deleted by", value=interaction.user.mention, inline=True ) - success_embed.add_field(name="Deleted by", value=interaction.user.mention, inline=True) if reason: - success_embed.add_field(name="Reason", value=reason, inline=False) - - await button_interaction.response.edit_message(embed=success_embed, view=None) - logger.info(f"Role '{role_name}' deleted by {interaction.user} in {interaction.guild.name}") - + success_embed.add_field( + name="Reason", value=reason, inline=False + ) + + await button_interaction.response.edit_message( + embed=success_embed, view=None + ) + logger.info( + f"Role '{role_name}' deleted by {interaction.user} in {interaction.guild.name}" + ) + except discord.Forbidden: await button_interaction.response.edit_message( embed=discord.Embed( title="❌ Error", description="I don't have permission to delete this role.", - color=discord.Color.red() + color=discord.Color.red(), ), - view=None + view=None, ) except discord.HTTPException as e: await button_interaction.response.edit_message( embed=discord.Embed( title="❌ Error", description=f"Failed to delete role: {e}", - color=discord.Color.red() + color=discord.Color.red(), ), - view=None + view=None, ) - + @discord.ui.button(label="Cancel", style=discord.ButtonStyle.secondary) - async def cancel(self, button_interaction: discord.Interaction, button: discord.ui.Button): + async def cancel( + self, button_interaction: discord.Interaction, button: discord.ui.Button + ): if button_interaction.user.id != interaction.user.id: - await button_interaction.response.send_message("You cannot use this button.", ephemeral=True) + await button_interaction.response.send_message( + "You cannot use this button.", ephemeral=True + ) return - + self.value = False self.stop() - + # Create a cancellation embed cancel_embed = discord.Embed( title="❌ Cancelled", description="Role deletion cancelled.", - color=discord.Color.red() + color=discord.Color.red(), ) - - await button_interaction.response.edit_message(embed=cancel_embed, view=None) - + + await button_interaction.response.edit_message( + embed=cancel_embed, view=None + ) + # Send the confirmation message view = ConfirmView() await interaction.response.send_message(embed=embed, view=view) - - async def role_add_callback(self, interaction: discord.Interaction, member: discord.Member, - role: discord.Role, reason: Optional[str] = None): + + async def role_add_callback( + self, + interaction: discord.Interaction, + member: discord.Member, + role: discord.Role, + reason: Optional[str] = None, + ): """Callback for /role add command""" # Check permissions if not interaction.guild: - await interaction.response.send_message("This command can only be used in a server.", ephemeral=True) + await interaction.response.send_message( + "This command can only be used in a server.", ephemeral=True + ) return - + if not interaction.user.guild_permissions.manage_roles: - await interaction.response.send_message("You don't have permission to manage roles.", ephemeral=True) + await interaction.response.send_message( + "You don't have permission to manage roles.", ephemeral=True + ) return - + if not interaction.guild.me.guild_permissions.manage_roles: - await interaction.response.send_message("I don't have permission to manage roles.", ephemeral=True) + await interaction.response.send_message( + "I don't have permission to manage roles.", ephemeral=True + ) return - + # Check if the role is assignable if not role.is_assignable() or role.is_default(): - await interaction.response.send_message("I cannot assign this role. It might be the @everyone role or higher than my highest role.", ephemeral=True) + await interaction.response.send_message( + "I cannot assign this role. It might be the @everyone role or higher than my highest role.", + ephemeral=True, + ) return - + # Check if the member already has the role if role in member.roles: - await interaction.response.send_message(f"{member.mention} already has the role {role.mention}.", ephemeral=True) + await interaction.response.send_message( + f"{member.mention} already has the role {role.mention}.", ephemeral=True + ) return - + try: # Add the role to the member - await member.add_roles(role, reason=f"{reason or 'No reason provided'} (Added by {interaction.user})") - + await member.add_roles( + role, + reason=f"{reason or 'No reason provided'} (Added by {interaction.user})", + ) + # Create an embed with role information embed = discord.Embed( title="✅ Role Added", description=f"Successfully added role {role.mention} to {member.mention}", - color=role.color + color=role.color, ) embed.add_field(name="Member", value=member.mention, inline=True) embed.add_field(name="Role", value=role.mention, inline=True) - embed.add_field(name="Added by", value=interaction.user.mention, inline=True) + embed.add_field( + name="Added by", value=interaction.user.mention, inline=True + ) if reason: embed.add_field(name="Reason", value=reason, inline=False) - + await interaction.response.send_message(embed=embed) - logger.info(f"Role '{role.name}' added to {member} by {interaction.user} in {interaction.guild.name}") + logger.info( + f"Role '{role.name}' added to {member} by {interaction.user} in {interaction.guild.name}" + ) # Attempt to DM the user try: @@ -443,67 +563,103 @@ class RoleManagementCog(commands.Cog): dm_embed = discord.Embed( title="Role Added", description=f"The role {role_info} was added to you in **{interaction.guild.name}**.", - color=role.color + color=role.color, + ) + dm_embed.add_field( + name="Added by", value=interaction.user.mention, inline=True ) - dm_embed.add_field(name="Added by", value=interaction.user.mention, inline=True) if reason: dm_embed.add_field(name="Reason", value=reason, inline=False) await member.send(embed=dm_embed) - logger.info(f"Successfully DMed {member} about role '{role.name}' addition.") + logger.info( + f"Successfully DMed {member} about role '{role.name}' addition." + ) except discord.Forbidden: - logger.warning(f"Failed to DM {member} about role '{role.name}' addition (Forbidden).") + logger.warning( + f"Failed to DM {member} about role '{role.name}' addition (Forbidden)." + ) except discord.HTTPException as e: - logger.warning(f"Failed to DM {member} about role '{role.name}' addition (HTTPException: {e}).") - + logger.warning( + f"Failed to DM {member} about role '{role.name}' addition (HTTPException: {e})." + ) + except discord.Forbidden: - await interaction.response.send_message("I don't have permission to add roles to this member.", ephemeral=True) + await interaction.response.send_message( + "I don't have permission to add roles to this member.", ephemeral=True + ) except discord.HTTPException as e: - await interaction.response.send_message(f"Failed to add role: {e}", ephemeral=True) - - async def role_remove_callback(self, interaction: discord.Interaction, member: discord.Member, - role: discord.Role, reason: Optional[str] = None): + await interaction.response.send_message( + f"Failed to add role: {e}", ephemeral=True + ) + + async def role_remove_callback( + self, + interaction: discord.Interaction, + member: discord.Member, + role: discord.Role, + reason: Optional[str] = None, + ): """Callback for /role remove command""" # Check permissions if not interaction.guild: - await interaction.response.send_message("This command can only be used in a server.", ephemeral=True) + await interaction.response.send_message( + "This command can only be used in a server.", ephemeral=True + ) return - + if not interaction.user.guild_permissions.manage_roles: - await interaction.response.send_message("You don't have permission to manage roles.", ephemeral=True) + await interaction.response.send_message( + "You don't have permission to manage roles.", ephemeral=True + ) return - + if not interaction.guild.me.guild_permissions.manage_roles: - await interaction.response.send_message("I don't have permission to manage roles.", ephemeral=True) + await interaction.response.send_message( + "I don't have permission to manage roles.", ephemeral=True + ) return - + # Check if the role is assignable if not role.is_assignable() or role.is_default(): - await interaction.response.send_message("I cannot remove this role. It might be the @everyone role or higher than my highest role.", ephemeral=True) + await interaction.response.send_message( + "I cannot remove this role. It might be the @everyone role or higher than my highest role.", + ephemeral=True, + ) return - + # Check if the member has the role if role not in member.roles: - await interaction.response.send_message(f"{member.mention} doesn't have the role {role.mention}.", ephemeral=True) + await interaction.response.send_message( + f"{member.mention} doesn't have the role {role.mention}.", + ephemeral=True, + ) return - + try: # Remove the role from the member - await member.remove_roles(role, reason=f"{reason or 'No reason provided'} (Removed by {interaction.user})") - + await member.remove_roles( + role, + reason=f"{reason or 'No reason provided'} (Removed by {interaction.user})", + ) + # Create an embed with role information embed = discord.Embed( title="✅ Role Removed", description=f"Successfully removed role {role.mention} from {member.mention}", - color=role.color + color=role.color, ) embed.add_field(name="Member", value=member.mention, inline=True) embed.add_field(name="Role", value=role.mention, inline=True) - embed.add_field(name="Removed by", value=interaction.user.mention, inline=True) + embed.add_field( + name="Removed by", value=interaction.user.mention, inline=True + ) if reason: embed.add_field(name="Reason", value=reason, inline=False) - + await interaction.response.send_message(embed=embed) - logger.info(f"Role '{role.name}' removed from {member} by {interaction.user} in {interaction.guild.name}") + logger.info( + f"Role '{role.name}' removed from {member} by {interaction.user} in {interaction.guild.name}" + ) # Attempt to DM the user try: @@ -511,170 +667,235 @@ class RoleManagementCog(commands.Cog): dm_embed = discord.Embed( title="Role Removed", description=f"The role {role_info} was removed from you in **{interaction.guild.name}**.", - color=role.color + color=role.color, + ) + dm_embed.add_field( + name="Removed by", value=interaction.user.mention, inline=True ) - dm_embed.add_field(name="Removed by", value=interaction.user.mention, inline=True) if reason: dm_embed.add_field(name="Reason", value=reason, inline=False) await member.send(embed=dm_embed) - logger.info(f"Successfully DMed {member} about role '{role.name}' removal.") + logger.info( + f"Successfully DMed {member} about role '{role.name}' removal." + ) except discord.Forbidden: - logger.warning(f"Failed to DM {member} about role '{role.name}' removal (Forbidden).") + logger.warning( + f"Failed to DM {member} about role '{role.name}' removal (Forbidden)." + ) except discord.HTTPException as e: - logger.warning(f"Failed to DM {member} about role '{role.name}' removal (HTTPException: {e}).") - + logger.warning( + f"Failed to DM {member} about role '{role.name}' removal (HTTPException: {e})." + ) + except discord.Forbidden: - await interaction.response.send_message("I don't have permission to remove roles from this member.", ephemeral=True) + await interaction.response.send_message( + "I don't have permission to remove roles from this member.", + ephemeral=True, + ) except discord.HTTPException as e: - await interaction.response.send_message(f"Failed to remove role: {e}", ephemeral=True) - + await interaction.response.send_message( + f"Failed to remove role: {e}", ephemeral=True + ) + async def role_list_callback(self, interaction: discord.Interaction): """Callback for /role list command""" # Check if in a guild if not interaction.guild: - await interaction.response.send_message("This command can only be used in a server.", ephemeral=True) + await interaction.response.send_message( + "This command can only be used in a server.", ephemeral=True + ) return - + # Get all roles in the guild, sorted by position (highest first) roles = sorted(interaction.guild.roles, key=lambda r: r.position, reverse=True) - + # Create an embed with role information embed = discord.Embed( title=f"Roles in {interaction.guild.name}", - description=f"Total roles: {len(roles) - 1}", # Subtract 1 to exclude @everyone - color=discord.Color.blue() + description=f"Total roles: {len(roles) - 1}", # Subtract 1 to exclude @everyone + color=discord.Color.blue(), ) - + # Add roles to the embed in chunks to avoid hitting the field limit chunk_size = 20 for i in range(0, len(roles), chunk_size): - chunk = roles[i:i+chunk_size] - + chunk = roles[i : i + chunk_size] + # Format the roles role_list = [] for role in chunk: if role.is_default(): # Skip @everyone continue role_list.append(f"{role.mention} - {len(role.members)} members") - + if role_list: embed.add_field( - name=f"Roles {i+1}-{min(i+chunk_size, len(roles) - 1)}", # Subtract 1 to account for @everyone + name=f"Roles {i+1}-{min(i+chunk_size, len(roles) - 1)}", # Subtract 1 to account for @everyone value="\n".join(role_list), - inline=False + inline=False, ) - + await interaction.response.send_message(embed=embed) - - async def role_info_callback(self, interaction: discord.Interaction, role: discord.Role): + + async def role_info_callback( + self, interaction: discord.Interaction, role: discord.Role + ): """Callback for /role info command""" # Check if in a guild if not interaction.guild: - await interaction.response.send_message("This command can only be used in a server.", ephemeral=True) + await interaction.response.send_message( + "This command can only be used in a server.", ephemeral=True + ) return - + # Create an embed with detailed role information - embed = discord.Embed( - title=f"Role Information: {role.name}", - color=role.color - ) - + embed = discord.Embed(title=f"Role Information: {role.name}", color=role.color) + # Add basic information embed.add_field(name="ID", value=role.id, inline=True) embed.add_field(name="Color", value=str(role.color), inline=True) embed.add_field(name="Position", value=role.position, inline=True) - embed.add_field(name="Hoisted", value="Yes" if role.hoist else "No", inline=True) - embed.add_field(name="Mentionable", value="Yes" if role.mentionable else "No", inline=True) - embed.add_field(name="Bot Role", value="Yes" if role.is_bot_managed() else "No", inline=True) - embed.add_field(name="Integration Role", value="Yes" if role.is_integration() else "No", inline=True) + embed.add_field( + name="Hoisted", value="Yes" if role.hoist else "No", inline=True + ) + embed.add_field( + name="Mentionable", value="Yes" if role.mentionable else "No", inline=True + ) + embed.add_field( + name="Bot Role", value="Yes" if role.is_bot_managed() else "No", inline=True + ) + embed.add_field( + name="Integration Role", + value="Yes" if role.is_integration() else "No", + inline=True, + ) embed.add_field(name="Members", value=len(role.members), inline=True) - embed.add_field(name="Created At", value=discord.utils.format_dt(role.created_at), inline=True) - + embed.add_field( + name="Created At", + value=discord.utils.format_dt(role.created_at), + inline=True, + ) + # Add permissions information permissions = [] for perm, value in role.permissions: if value: - formatted_perm = perm.replace('_', ' ').title() + formatted_perm = perm.replace("_", " ").title() permissions.append(f"✅ {formatted_perm}") - + if permissions: # Split permissions into chunks to avoid hitting the field value limit chunk_size = 10 for i in range(0, len(permissions), chunk_size): - chunk = permissions[i:i+chunk_size] + chunk = permissions[i : i + chunk_size] embed.add_field( - name="Permissions" if i == 0 else "\u200b", # Use zero-width space for additional fields + name=( + "Permissions" if i == 0 else "\u200b" + ), # Use zero-width space for additional fields value="\n".join(chunk), - inline=False + inline=False, ) else: embed.add_field(name="Permissions", value="No permissions", inline=False) - + await interaction.response.send_message(embed=embed) - - async def role_position_callback(self, interaction: discord.Interaction, role: discord.Role, - position: int, reason: Optional[str] = None): + + async def role_position_callback( + self, + interaction: discord.Interaction, + role: discord.Role, + position: int, + reason: Optional[str] = None, + ): """Callback for /role position command""" # Check permissions if not interaction.guild: - await interaction.response.send_message("This command can only be used in a server.", ephemeral=True) + await interaction.response.send_message( + "This command can only be used in a server.", ephemeral=True + ) return - + if not interaction.user.guild_permissions.manage_roles: - await interaction.response.send_message("You don't have permission to manage roles.", ephemeral=True) + await interaction.response.send_message( + "You don't have permission to manage roles.", ephemeral=True + ) return - + if not interaction.guild.me.guild_permissions.manage_roles: - await interaction.response.send_message("I don't have permission to manage roles.", ephemeral=True) + await interaction.response.send_message( + "I don't have permission to manage roles.", ephemeral=True + ) return - + # Check if the role is manageable if not role.is_assignable() or role.is_default(): - await interaction.response.send_message("I cannot move this role. It might be the @everyone role or higher than my highest role.", ephemeral=True) + await interaction.response.send_message( + "I cannot move this role. It might be the @everyone role or higher than my highest role.", + ephemeral=True, + ) return - + # Validate position if position < 1: - await interaction.response.send_message("Position must be at least 1.", ephemeral=True) + await interaction.response.send_message( + "Position must be at least 1.", ephemeral=True + ) return - + # Get the maximum valid position (excluding @everyone) max_position = len(interaction.guild.roles) - 1 if position > max_position: - await interaction.response.send_message(f"Position must be at most {max_position}.", ephemeral=True) + await interaction.response.send_message( + f"Position must be at most {max_position}.", ephemeral=True + ) return - + # Store original position for the embed original_position = role.position - + try: # Convert the 1-based user-friendly position to the 0-based position used by Discord # Also account for the fact that positions are ordered from bottom to top actual_position = position - + # Move the role - await role.edit(position=actual_position, reason=f"{reason or 'No reason provided'} (Position changed by {interaction.user})") - + await role.edit( + position=actual_position, + reason=f"{reason or 'No reason provided'} (Position changed by {interaction.user})", + ) + # Create an embed with role information embed = discord.Embed( title="✅ Role Position Changed", description=f"Successfully changed position of role {role.mention}", - color=role.color + color=role.color, ) embed.add_field(name="Role", value=role.mention, inline=True) - embed.add_field(name="Old Position", value=str(original_position), inline=True) + embed.add_field( + name="Old Position", value=str(original_position), inline=True + ) embed.add_field(name="New Position", value=str(role.position), inline=True) - embed.add_field(name="Changed by", value=interaction.user.mention, inline=True) + embed.add_field( + name="Changed by", value=interaction.user.mention, inline=True + ) if reason: embed.add_field(name="Reason", value=reason, inline=False) - + await interaction.response.send_message(embed=embed) - logger.info(f"Role '{role.name}' position changed from {original_position} to {role.position} by {interaction.user} in {interaction.guild.name}") - + logger.info( + f"Role '{role.name}' position changed from {original_position} to {role.position} by {interaction.user} in {interaction.guild.name}" + ) + except discord.Forbidden: - await interaction.response.send_message("I don't have permission to change this role's position.", ephemeral=True) + await interaction.response.send_message( + "I don't have permission to change this role's position.", + ephemeral=True, + ) except discord.HTTPException as e: - await interaction.response.send_message(f"Failed to change role position: {e}", ephemeral=True) + await interaction.response.send_message( + f"Failed to change role position: {e}", ephemeral=True + ) + async def setup(bot): await bot.add_cog(RoleManagementCog(bot)) diff --git a/cogs/role_selector_cog.py b/cogs/role_selector_cog.py index a16d06b..f35fa61 100644 --- a/cogs/role_selector_cog.py +++ b/cogs/role_selector_cog.py @@ -1,45 +1,56 @@ import discord from discord.ext import commands -from discord import app_commands # Added for slash commands +from discord import app_commands # Added for slash commands from discord.ui import View, Select, select, Modal, TextInput import json import os from typing import List, Dict, Optional, Set, Tuple, Union -import asyncio # Added for sleep -import re # For hex/rgb validation -import uuid # For generating IDs +import asyncio # Added for sleep +import re # For hex/rgb validation +import uuid # For generating IDs # Database and Pydantic Models from api_service.api_server import db from api_service.api_models import ( - RoleOption, RoleCategoryPreset, GuildRole, - GuildRoleCategoryConfig, UserCustomColorRole + RoleOption, + RoleCategoryPreset, + GuildRole, + GuildRoleCategoryConfig, + UserCustomColorRole, ) + async def is_owner_check(interaction: discord.Interaction) -> bool: """Checks if the interacting user is the bot owner.""" return interaction.user.id == interaction.client.owner_id + # For color name validation try: from matplotlib.colors import is_color_like, to_rgb, XKCD_COLORS except ImportError: - XKCD_COLORS = {} # Fallback if matplotlib is not installed - def is_color_like(c): # Basic fallback + XKCD_COLORS = {} # Fallback if matplotlib is not installed + + def is_color_like(c): # Basic fallback if isinstance(c, str): - return c.startswith('#') and len(c) in [4, 7] + return c.startswith("#") and len(c) in [4, 7] return False - def to_rgb(c): # Basic fallback - if isinstance(c, str) and c.startswith('#'): - hex_color = c.lstrip('#') + + def to_rgb(c): # Basic fallback + if isinstance(c, str) and c.startswith("#"): + hex_color = c.lstrip("#") if len(hex_color) == 3: - return tuple(int(hex_color[i]*2, 16)/255.0 for i in range(3)) + return tuple(int(hex_color[i] * 2, 16) / 255.0 for i in range(3)) if len(hex_color) == 6: - return tuple(int(hex_color[i:i+2], 16)/255.0 for i in range(0, 6, 2)) - return (0,0,0) # Default black + return tuple( + int(hex_color[i : i + 2], 16) / 255.0 for i in range(0, 6, 2) + ) + return (0, 0, 0) # Default black + # --- Constants --- -DEFAULT_ROLE_COLOR = discord.Color.default() # Used for custom color roles initially +DEFAULT_ROLE_COLOR = discord.Color.default() # Used for custom color roles initially + # --- Color Parsing Helper --- def _parse_color_input(color_input: str) -> Optional[discord.Color]: @@ -51,14 +62,16 @@ def _parse_color_input(color_input: str) -> Optional[discord.Color]: if hex_match: hex_val = hex_match.group(1) if len(hex_val) == 3: - hex_val = "".join([c*2 for c in hex_val]) + hex_val = "".join([c * 2 for c in hex_val]) try: return discord.Color(int(hex_val, 16)) except ValueError: - pass # Should not happen with regex match + pass # Should not happen with regex match # Try RGB: "r, g, b" or "(r, g, b)" - rgb_match = re.fullmatch(r"\(?\s*(\d{1,3})\s*,\s*(\d{1,3})\s*,\s*(\d{1,3})\s*\)?", color_input) + rgb_match = re.fullmatch( + r"\(?\s*(\d{1,3})\s*,\s*(\d{1,3})\s*,\s*(\d{1,3})\s*\)?", color_input + ) if rgb_match: try: r, g, b = [int(x) for x in rgb_match.groups()] @@ -68,24 +81,49 @@ def _parse_color_input(color_input: str) -> Optional[discord.Color]: pass # Try English color name (matplotlib XKCD_COLORS) - if XKCD_COLORS: # Check if matplotlib was imported + if XKCD_COLORS: # Check if matplotlib was imported normalized_input = color_input.lower().replace(" ", "") # Check against normalized keys for xkcd_name_key, xkcd_hex_val in XKCD_COLORS.items(): if xkcd_name_key.lower().replace(" ", "") == normalized_input: try: # to_rgb for xkcd colors usually returns (r,g,b) float tuple - rgb_float = to_rgb(xkcd_hex_val) - return discord.Color.from_rgb(int(rgb_float[0]*255), int(rgb_float[1]*255), int(rgb_float[2]*255)) + rgb_float = to_rgb(xkcd_hex_val) + return discord.Color.from_rgb( + int(rgb_float[0] * 255), + int(rgb_float[1] * 255), + int(rgb_float[2] * 255), + ) except Exception: - pass # Matplotlib color conversion failed - break + pass # Matplotlib color conversion failed + break # Fallback for very common color names if matplotlib is not available - elif color_input.lower() in {"red": (255,0,0), "green": (0,255,0), "blue": (0,0,255), "yellow": (255,255,0), "purple": (128,0,128), "orange": (255,165,0), "pink": (255,192,203), "black": (0,0,0), "white": (255,255,255)}: - r,g,b = {"red": (255,0,0), "green": (0,255,0), "blue": (0,0,255), "yellow": (255,255,0), "purple": (128,0,128), "orange": (255,165,0), "pink": (255,192,203), "black": (0,0,0), "white": (255,255,255)}[color_input.lower()] - return discord.Color.from_rgb(r,g,b) + elif color_input.lower() in { + "red": (255, 0, 0), + "green": (0, 255, 0), + "blue": (0, 0, 255), + "yellow": (255, 255, 0), + "purple": (128, 0, 128), + "orange": (255, 165, 0), + "pink": (255, 192, 203), + "black": (0, 0, 0), + "white": (255, 255, 255), + }: + r, g, b = { + "red": (255, 0, 0), + "green": (0, 255, 0), + "blue": (0, 0, 255), + "yellow": (255, 255, 0), + "purple": (128, 0, 128), + "orange": (255, 165, 0), + "pink": (255, 192, 203), + "black": (0, 0, 0), + "white": (255, 255, 255), + }[color_input.lower()] + return discord.Color.from_rgb(r, g, b) return None + # --- Custom Color Modal --- class CustomColorModal(Modal, title="Set Your Custom Role Color"): color_input = TextInput( @@ -93,7 +131,7 @@ class CustomColorModal(Modal, title="Set Your Custom Role Color"): placeholder="#RRGGBB, 255,0,128, or 'sky blue'", style=discord.TextStyle.short, required=True, - max_length=100 + max_length=100, ) async def on_submit(self, interaction: discord.Interaction): @@ -101,7 +139,9 @@ class CustomColorModal(Modal, title="Set Your Custom Role Color"): guild = interaction.guild member = interaction.user if not guild or not isinstance(member, discord.Member): - await interaction.followup.send("This can only be used in a server.", ephemeral=True) + await interaction.followup.send( + "This can only be used in a server.", ephemeral=True + ) return parsed_color = _parse_color_input(self.color_input.value) @@ -110,13 +150,15 @@ class CustomColorModal(Modal, title="Set Your Custom Role Color"): f"Could not understand the color '{self.color_input.value}'.\n" "Please use a hex code (e.g., `#FF0000`), RGB values (e.g., `255,0,0`), " "or a known color name (e.g., 'red', 'sky blue').", - ephemeral=True + ephemeral=True, ) return custom_role_name = f"User Color - {member.id}" - existing_user_color_role_db = db.get_user_custom_color_role(str(guild.id), str(member.id)) - + existing_user_color_role_db = db.get_user_custom_color_role( + str(guild.id), str(member.id) + ) + role_to_update: Optional[discord.Role] = None if existing_user_color_role_db: @@ -124,17 +166,25 @@ class CustomColorModal(Modal, title="Set Your Custom Role Color"): if not role_to_update: db.delete_user_custom_color_role(str(guild.id), str(member.id)) existing_user_color_role_db = None - elif role_to_update.name != custom_role_name: # Name mismatch, could be manually changed - try: # Try to rename it back - await role_to_update.edit(name=custom_role_name, reason="Standardizing custom color role name") + elif ( + role_to_update.name != custom_role_name + ): # Name mismatch, could be manually changed + try: # Try to rename it back + await role_to_update.edit( + name=custom_role_name, + reason="Standardizing custom color role name", + ) except discord.Forbidden: - await interaction.followup.send("I couldn't standardize your existing color role name. Please check my permissions.", ephemeral=True) + await interaction.followup.send( + "I couldn't standardize your existing color role name. Please check my permissions.", + ephemeral=True, + ) # Potentially fall through to create a new one if renaming fails and old one is problematic except discord.HTTPException: - pass # Non-critical error, proceed with color update + pass # Non-critical error, proceed with color update if not role_to_update: - for r in guild.roles: # Check if a role with the target name already exists + for r in guild.roles: # Check if a role with the target name already exists if r.name == custom_role_name: role_to_update = r break @@ -142,63 +192,85 @@ class CustomColorModal(Modal, title="Set Your Custom Role Color"): try: role_to_update = await guild.create_role( name=custom_role_name, - color=parsed_color, # Set initial color - reason=f"Custom color role for {member.display_name}" + color=parsed_color, # Set initial color + reason=f"Custom color role for {member.display_name}", ) except discord.Forbidden: - await interaction.followup.send("I don't have permission to create roles.", ephemeral=True) + await interaction.followup.send( + "I don't have permission to create roles.", ephemeral=True + ) return except discord.HTTPException as e: - await interaction.followup.send(f"Failed to create role: {e}", ephemeral=True) + await interaction.followup.send( + f"Failed to create role: {e}", ephemeral=True + ) return - + if not role_to_update: - await interaction.followup.send("Failed to obtain a role to update.", ephemeral=True) + await interaction.followup.send( + "Failed to obtain a role to update.", ephemeral=True + ) return if role_to_update.color != parsed_color: try: - await role_to_update.edit(color=parsed_color, reason=f"Color update for {member.display_name}") + await role_to_update.edit( + color=parsed_color, reason=f"Color update for {member.display_name}" + ) except discord.Forbidden: - await interaction.followup.send("I don't have permission to edit the role color.", ephemeral=True) + await interaction.followup.send( + "I don't have permission to edit the role color.", ephemeral=True + ) return except discord.HTTPException as e: - await interaction.followup.send(f"Failed to update role color: {e}", ephemeral=True) + await interaction.followup.send( + f"Failed to update role color: {e}", ephemeral=True + ) return - + roles_to_add_to_member = [] if role_to_update.id not in [r.id for r in member.roles]: roles_to_add_to_member.append(role_to_update) - + roles_to_remove_from_member = [ - r for r in member.roles + r + for r in member.roles if r.name.startswith("User Color - ") and r.id != role_to_update.id ] - + try: if roles_to_remove_from_member: - await member.remove_roles(*roles_to_remove_from_member, reason="Cleaning up old custom color roles") + await member.remove_roles( + *roles_to_remove_from_member, + reason="Cleaning up old custom color roles", + ) if roles_to_add_to_member: - await member.add_roles(*roles_to_add_to_member, reason="Applied custom color role") + await member.add_roles( + *roles_to_add_to_member, reason="Applied custom color role" + ) except discord.Forbidden: - await interaction.followup.send("I don't have permission to assign roles.", ephemeral=True) + await interaction.followup.send( + "I don't have permission to assign roles.", ephemeral=True + ) return except discord.HTTPException as e: - await interaction.followup.send(f"Failed to assign role: {e}", ephemeral=True) + await interaction.followup.send( + f"Failed to assign role: {e}", ephemeral=True + ) return user_color_role_data = UserCustomColorRole( user_id=str(member.id), guild_id=str(guild.id), role_id=str(role_to_update.id), - hex_color=f"#{parsed_color.value:06x}" + hex_color=f"#{parsed_color.value:06x}", ) db.save_user_custom_color_role(user_color_role_data) # New logic: Remove roles from "Colors" preset if custom color is set removed_preset_color_roles_names = [] # Ensure guild is not None, though it should be from earlier checks (line 103) - if guild and isinstance(member, discord.Member): # member check for safety + if guild and isinstance(member, discord.Member): # member check for safety guild_role_categories = db.get_guild_role_category_configs(str(guild.id)) colors_preset_role_ids_to_remove = set() @@ -207,24 +279,37 @@ class CustomColorModal(Modal, title="Set Your Custom Role Color"): if cat_config.is_preset and cat_config.preset_id == "default_colors": for role_option in cat_config.roles: colors_preset_role_ids_to_remove.add(int(role_option.role_id)) - break # Found the Colors preset for this guild + break # Found the Colors preset for this guild if colors_preset_role_ids_to_remove: roles_to_actually_remove_from_member = [] for member_role in member.roles: if member_role.id in colors_preset_role_ids_to_remove: roles_to_actually_remove_from_member.append(member_role) - + if roles_to_actually_remove_from_member: try: - await member.remove_roles(*roles_to_actually_remove_from_member, reason="User set a custom color, removing preset color role(s).") - removed_preset_color_roles_names = [r.name for r in roles_to_actually_remove_from_member] + await member.remove_roles( + *roles_to_actually_remove_from_member, + reason="User set a custom color, removing preset color role(s).", + ) + removed_preset_color_roles_names = [ + r.name for r in roles_to_actually_remove_from_member + ] except discord.Forbidden: - await interaction.followup.send("I tried to remove your preset color role(s) but lack permissions.", ephemeral=True) + await interaction.followup.send( + "I tried to remove your preset color role(s) but lack permissions.", + ephemeral=True, + ) except discord.HTTPException as e: - await interaction.followup.send(f"Failed to remove your preset color role(s): {e}", ephemeral=True) + await interaction.followup.send( + f"Failed to remove your preset color role(s): {e}", + ephemeral=True, + ) - feedback_message = f"Your custom role color has been set to {user_color_role_data.hex_color}!" + feedback_message = ( + f"Your custom role color has been set to {user_color_role_data.hex_color}!" + ) if removed_preset_color_roles_names: feedback_message += f"\nRemoved preset color role(s): {', '.join(removed_preset_color_roles_names)}." await interaction.followup.send(feedback_message, ephemeral=True) @@ -233,31 +318,50 @@ class CustomColorModal(Modal, title="Set Your Custom Role Color"): await interaction.followup.send(f"An error occurred: {error}", ephemeral=True) print(f"Error in CustomColorModal: {error}") + # --- View for the Custom Color Button --- class CustomColorButtonView(View): def __init__(self): - super().__init__(timeout=None) + super().__init__(timeout=None) - @discord.ui.button(label="Set Custom Role Color", style=discord.ButtonStyle.primary, custom_id="persistent_set_custom_color_button") - async def set_color_button_callback(self, interaction: discord.Interaction, button: discord.ui.Button): + @discord.ui.button( + label="Set Custom Role Color", + style=discord.ButtonStyle.primary, + custom_id="persistent_set_custom_color_button", + ) + async def set_color_button_callback( + self, interaction: discord.Interaction, button: discord.ui.Button + ): modal = CustomColorModal() await interaction.response.send_modal(modal) + # --- Persistent View Definition --- class RoleSelectorView(View): - def __init__(self, guild_id: int, category_config: GuildRoleCategoryConfig, bot_instance): + def __init__( + self, guild_id: int, category_config: GuildRoleCategoryConfig, bot_instance + ): super().__init__(timeout=None) self.guild_id = guild_id self.category_config = category_config - self.bot = bot_instance - self.category_role_ids: Set[int] = {int(role.role_id) for role in category_config.roles} - self.custom_id = f"persistent_role_select_view_{guild_id}_{category_config.category_id}" + self.bot = bot_instance + self.category_role_ids: Set[int] = { + int(role.role_id) for role in category_config.roles + } + self.custom_id = ( + f"persistent_role_select_view_{guild_id}_{category_config.category_id}" + ) self.select_chunk_map: Dict[str, Set[int]] = {} category_display_roles: List[GuildRole] = category_config.roles - self.role_chunks = [category_display_roles[i:i + 25] for i in range(0, len(category_display_roles), 25)] + self.role_chunks = [ + category_display_roles[i : i + 25] + for i in range(0, len(category_display_roles), 25) + ] num_chunks = len(self.role_chunks) - total_max_values = min(category_config.max_selectable, len(category_display_roles)) + total_max_values = min( + category_config.max_selectable, len(category_display_roles) + ) actual_min_values = 0 for i, chunk in enumerate(self.role_chunks): @@ -265,8 +369,9 @@ class RoleSelectorView(View): discord.SelectOption( label=role.name, value=str(role.role_id), - emoji=role.emoji if role.emoji else None - ) for role in chunk + emoji=role.emoji if role.emoji else None, + ) + for role in chunk ] chunk_role_ids = {int(role.role_id) for role in chunk} if not options: @@ -274,15 +379,19 @@ class RoleSelectorView(View): chunk_max_values = min(total_max_values, len(options)) placeholder = f"Select {category_config.name} role(s)..." if num_chunks > 1: - placeholder = f"Select {category_config.name} role(s) ({i+1}/{num_chunks})..." - select_custom_id = f"role_select_dropdown_{guild_id}_{category_config.category_id}_{i}" + placeholder = ( + f"Select {category_config.name} role(s) ({i+1}/{num_chunks})..." + ) + select_custom_id = ( + f"role_select_dropdown_{guild_id}_{category_config.category_id}_{i}" + ) self.select_chunk_map[select_custom_id] = chunk_role_ids select_component = Select( placeholder=placeholder, min_values=actual_min_values, max_values=chunk_max_values, options=options, - custom_id=select_custom_id + custom_id=select_custom_id, ) select_component.callback = self.select_callback self.add_item(select_component) @@ -292,95 +401,172 @@ class RoleSelectorView(View): member = interaction.user guild = interaction.guild if not isinstance(member, discord.Member) or not guild: - await interaction.followup.send("This interaction must be used within a server.", ephemeral=True) + await interaction.followup.send( + "This interaction must be used within a server.", ephemeral=True + ) return if guild.id != self.guild_id: - await interaction.followup.send("This role selector is not for this server.", ephemeral=True) + await interaction.followup.send( + "This role selector is not for this server.", ephemeral=True + ) return - interacted_custom_id = interaction.data['custom_id'] - interacted_chunk_role_ids: Set[int] = self.select_chunk_map.get(interacted_custom_id, set()) + interacted_custom_id = interaction.data["custom_id"] + interacted_chunk_role_ids: Set[int] = self.select_chunk_map.get( + interacted_custom_id, set() + ) if not interacted_chunk_role_ids: for component in self.children: - if isinstance(component, Select) and component.custom_id == interacted_custom_id: - interacted_chunk_role_ids = {int(opt.value) for opt in component.options} + if ( + isinstance(component, Select) + and component.custom_id == interacted_custom_id + ): + interacted_chunk_role_ids = { + int(opt.value) for opt in component.options + } break if not interacted_chunk_role_ids: - await interaction.followup.send("An internal error occurred identifying roles for this dropdown.", ephemeral=True) + await interaction.followup.send( + "An internal error occurred identifying roles for this dropdown.", + ephemeral=True, + ) return - selected_values = interaction.data.get('values', []) + selected_values = interaction.data.get("values", []) selected_role_ids_from_interaction = {int(value) for value in selected_values} - member_category_role_ids = {role.id for role in member.roles if role.id in self.category_role_ids} + member_category_role_ids = { + role.id for role in member.roles if role.id in self.category_role_ids + } roles_to_add_ids = selected_role_ids_from_interaction - member_category_role_ids - member_roles_in_interacted_chunk = member_category_role_ids.intersection(interacted_chunk_role_ids) - roles_to_remove_ids = member_roles_in_interacted_chunk - selected_role_ids_from_interaction + member_roles_in_interacted_chunk = member_category_role_ids.intersection( + interacted_chunk_role_ids + ) + roles_to_remove_ids = ( + member_roles_in_interacted_chunk - selected_role_ids_from_interaction + ) if self.category_config.max_selectable == 1 and roles_to_add_ids: if len(roles_to_add_ids) > 1: - await interaction.followup.send(f"Error: Cannot select multiple roles for '{self.category_config.name}'.", ephemeral=True) - return + await interaction.followup.send( + f"Error: Cannot select multiple roles for '{self.category_config.name}'.", + ephemeral=True, + ) + return role_to_add_id = list(roles_to_add_ids)[0] other_member_roles_in_category = member_category_role_ids - {role_to_add_id} roles_to_remove_ids.update(other_member_roles_in_category) roles_to_add_ids = {role_to_add_id} - roles_to_add = {guild.get_role(role_id) for role_id in roles_to_add_ids if guild.get_role(role_id)} - roles_to_remove = {guild.get_role(role_id) for role_id in roles_to_remove_ids if guild.get_role(role_id)} + roles_to_add = { + guild.get_role(role_id) + for role_id in roles_to_add_ids + if guild.get_role(role_id) + } + roles_to_remove = { + guild.get_role(role_id) + for role_id in roles_to_remove_ids + if guild.get_role(role_id) + } added_names, removed_names, error_messages = [], [], [] - removed_custom_color_feedback = "" # Initialize here + removed_custom_color_feedback = "" # Initialize here # New logic: If adding a "Colors" preset role, remove custom color role # The preset_id for "Colors" is 'default_colors' - is_colors_preset_category = self.category_config.is_preset and self.category_config.preset_id == "default_colors" - + is_colors_preset_category = ( + self.category_config.is_preset + and self.category_config.preset_id == "default_colors" + ) + # A color from the "Colors" preset is being added (roles_to_add is not empty) if is_colors_preset_category and roles_to_add: # Ensure member and guild are valid (they should be from earlier checks in lines 262-267) if isinstance(member, discord.Member) and guild: - existing_user_custom_color_db = db.get_user_custom_color_role(str(guild.id), str(member.id)) + existing_user_custom_color_db = db.get_user_custom_color_role( + str(guild.id), str(member.id) + ) if existing_user_custom_color_db: - custom_color_role_to_remove = guild.get_role(int(existing_user_custom_color_db.role_id)) + custom_color_role_to_remove = guild.get_role( + int(existing_user_custom_color_db.role_id) + ) if custom_color_role_to_remove: try: - await member.remove_roles(custom_color_role_to_remove, reason="User selected a preset color, removing custom color role.") - db.delete_user_custom_color_role(str(guild.id), str(member.id)) # Delete from DB + await member.remove_roles( + custom_color_role_to_remove, + reason="User selected a preset color, removing custom color role.", + ) + db.delete_user_custom_color_role( + str(guild.id), str(member.id) + ) # Delete from DB removed_custom_color_feedback = f"\n- Removed custom color role '{custom_color_role_to_remove.name}'." except discord.Forbidden: - error_messages.append("Could not remove your custom color role (permissions).") + error_messages.append( + "Could not remove your custom color role (permissions)." + ) except discord.HTTPException as e: - error_messages.append(f"Error removing custom color role: {e}") - else: # Role not found in guild, but was in DB. Clean up DB. + error_messages.append( + f"Error removing custom color role: {e}" + ) + else: # Role not found in guild, but was in DB. Clean up DB. db.delete_user_custom_color_role(str(guild.id), str(member.id)) removed_custom_color_feedback = "\n- Your previous custom color role was not found in the server and has been cleared from my records." try: if roles_to_remove: - await member.remove_roles(*roles_to_remove, reason=f"Deselected/changed via {self.category_config.name} role selector ({interacted_custom_id})") + await member.remove_roles( + *roles_to_remove, + reason=f"Deselected/changed via {self.category_config.name} role selector ({interacted_custom_id})", + ) removed_names = [r.name for r in roles_to_remove if r] if roles_to_add: - await member.add_roles(*roles_to_add, reason=f"Selected via {self.category_config.name} role selector ({interacted_custom_id})") + await member.add_roles( + *roles_to_add, + reason=f"Selected via {self.category_config.name} role selector ({interacted_custom_id})", + ) added_names = [r.name for r in roles_to_add if r] feedback = "Your roles have been updated!" - if added_names: feedback += f"\n+ Added: {', '.join(added_names)}" - if removed_names: feedback += f"\n- Removed: {', '.join(removed_names)}" - feedback += removed_custom_color_feedback # Add the custom color removal feedback here - + if added_names: + feedback += f"\n+ Added: {', '.join(added_names)}" + if removed_names: + feedback += f"\n- Removed: {', '.join(removed_names)}" + feedback += removed_custom_color_feedback # Add the custom color removal feedback here + # Adjusted condition for "no changes" message # Ensure removed_custom_color_feedback is considered. If it has content, changes were made. - if not added_names and not removed_names and not removed_custom_color_feedback.strip(): - if selected_values: feedback = "No changes needed for the roles selected in this dropdown." - else: feedback = "No roles selected in this dropdown." if not member_roles_in_interacted_chunk else "Roles deselected from this dropdown." + if ( + not added_names + and not removed_names + and not removed_custom_color_feedback.strip() + ): + if selected_values: + feedback = ( + "No changes needed for the roles selected in this dropdown." + ) + else: + feedback = ( + "No roles selected in this dropdown." + if not member_roles_in_interacted_chunk + else "Roles deselected from this dropdown." + ) await interaction.followup.send(feedback, ephemeral=True) - except discord.Forbidden: error_messages.append("I don't have permission to manage roles.") - except discord.HTTPException as e: error_messages.append(f"An error occurred: {e}") - except Exception as e: error_messages.append(f"Unexpected error: {e}"); print(f"Error in role selector: {e}") - if error_messages: await interaction.followup.send("\n".join(error_messages), ephemeral=True) + except discord.Forbidden: + error_messages.append("I don't have permission to manage roles.") + except discord.HTTPException as e: + error_messages.append(f"An error occurred: {e}") + except Exception as e: + error_messages.append(f"Unexpected error: {e}") + print(f"Error in role selector: {e}") + if error_messages: + await interaction.followup.send("\n".join(error_messages), ephemeral=True) + class RoleSelectorCog(commands.Cog): - roleselect_group = app_commands.Group(name="roleselect", description="Manage role selection categories and selectors.") - rolepreset_group = app_commands.Group(name="rolepreset", description="Manage global role category presets.") + roleselect_group = app_commands.Group( + name="roleselect", description="Manage role selection categories and selectors." + ) + rolepreset_group = app_commands.Group( + name="rolepreset", description="Manage global role category presets." + ) def __init__(self, bot): self.bot = bot @@ -390,7 +576,7 @@ class RoleSelectorCog(commands.Cog): await self.bot.wait_until_ready() print("RoleSelectorCog: Registering persistent views...") registered_count = 0 - + # Register RoleSelectorView for each guild config guild_configs_data = db.get_all_guild_role_category_configs() for guild_id_str, category_configs_list in guild_configs_data.items(): @@ -400,66 +586,105 @@ class RoleSelectorCog(commands.Cog): print(f" Skipping guild {guild_id} (not found).") continue for category_config in category_configs_list: - if category_config.roles: # Only register if there are roles + if category_config.roles: # Only register if there are roles try: view = RoleSelectorView(guild_id, category_config, self.bot) self.bot.add_view(view) registered_count += 1 except Exception as e: - print(f" Error registering RoleSelectorView for {category_config.name} in {guild.id}: {e}") - + print( + f" Error registering RoleSelectorView for {category_config.name} in {guild.id}: {e}" + ) + # Register CustomColorButtonView (it's globally persistent by its custom_id) try: self.bot.add_view(CustomColorButtonView()) print(" Registered CustomColorButtonView globally.") - registered_count +=1 + registered_count += 1 except Exception as e: print(f" Error registering CustomColorButtonView globally: {e}") - print(f"RoleSelectorCog: Finished registering {registered_count} persistent views.") + print( + f"RoleSelectorCog: Finished registering {registered_count} persistent views." + ) - def _get_guild_category_config(self, guild_id: int, category_name_or_id: str) -> Optional[GuildRoleCategoryConfig]: + def _get_guild_category_config( + self, guild_id: int, category_name_or_id: str + ) -> Optional[GuildRoleCategoryConfig]: configs = db.get_guild_role_category_configs(str(guild_id)) for config in configs: - if config.category_id == category_name_or_id or config.name.lower() == category_name_or_id.lower(): + if ( + config.category_id == category_name_or_id + or config.name.lower() == category_name_or_id.lower() + ): return config return None - async def autocomplete_category_name(self, interaction: discord.Interaction, current: str) -> List[app_commands.Choice[str]]: + async def autocomplete_category_name( + self, interaction: discord.Interaction, current: str + ) -> List[app_commands.Choice[str]]: choices = [] # Add existing guild category names if interaction.guild: - guild_configs = db.get_guild_role_category_configs(str(interaction.guild_id)) + guild_configs = db.get_guild_role_category_configs( + str(interaction.guild_id) + ) for config in guild_configs: - if config.name.lower().startswith(current.lower()): # Check if current is not empty before startswith - choices.append(app_commands.Choice(name=config.name, value=config.name)) - elif not current: # If current is empty, add all - choices.append(app_commands.Choice(name=config.name, value=config.name)) + if config.name.lower().startswith( + current.lower() + ): # Check if current is not empty before startswith + choices.append( + app_commands.Choice(name=config.name, value=config.name) + ) + elif not current: # If current is empty, add all + choices.append( + app_commands.Choice(name=config.name, value=config.name) + ) # Add global preset names presets = db.get_all_role_category_presets() for preset in presets: - if preset.name.lower().startswith(current.lower()): # Check if current is not empty - choices.append(app_commands.Choice(name=f"Preset: {preset.name}", value=preset.name)) - elif not current: # If current is empty, add all - choices.append(app_commands.Choice(name=f"Preset: {preset.name}", value=preset.name)) - + if preset.name.lower().startswith( + current.lower() + ): # Check if current is not empty + choices.append( + app_commands.Choice( + name=f"Preset: {preset.name}", value=preset.name + ) + ) + elif not current: # If current is empty, add all + choices.append( + app_commands.Choice( + name=f"Preset: {preset.name}", value=preset.name + ) + ) # Limit to 25 choices as per Discord API limits return choices[:25] - @roleselect_group.command(name="addcategory", description="Adds a new role category for selection.") + @roleselect_group.command( + name="addcategory", description="Adds a new role category for selection." + ) @app_commands.checks.has_permissions(manage_guild=True) @app_commands.describe( name="The name for the new category (or select a preset).", description="A description for this role category.", max_selectable="Maximum number of roles a user can select from this category (default: 1).", - preset_id="Optional ID of a global preset to base this category on (auto-filled if preset selected for name)." + preset_id="Optional ID of a global preset to base this category on (auto-filled if preset selected for name).", ) @app_commands.autocomplete(name=autocomplete_category_name) - async def roleselect_addcategory(self, interaction: discord.Interaction, name: str, description: str, max_selectable: Optional[int] = 1, preset_id: Optional[str] = None): + async def roleselect_addcategory( + self, + interaction: discord.Interaction, + name: str, + description: str, + max_selectable: Optional[int] = 1, + preset_id: Optional[str] = None, + ): if not interaction.guild: - await interaction.response.send_message("This command can only be used in a server.", ephemeral=True) + await interaction.response.send_message( + "This command can only be used in a server.", ephemeral=True + ) return guild_id_str = str(interaction.guild_id) # Ensure max_selectable has a valid default if None is passed by Discord for optional int @@ -469,16 +694,24 @@ class RoleSelectorCog(commands.Cog): # If it starts with "Preset: ", extract the actual preset name actual_name = name if name.startswith("Preset: "): - actual_name = name[len("Preset: "):] + actual_name = name[len("Preset: ") :] # If preset_id was not explicitly provided, set it to the ID of the selected preset if not preset_id: - selected_preset = discord.utils.find(lambda p: p.name == actual_name, db.get_all_role_category_presets()) + selected_preset = discord.utils.find( + lambda p: p.name == actual_name, db.get_all_role_category_presets() + ) if selected_preset: preset_id = selected_preset.id - if self._get_guild_category_config(interaction.guild_id, actual_name) and not preset_id: # Allow adding preset even if name conflicts, preset name will be used - await interaction.response.send_message(f"A custom role category named '{actual_name}' already exists.", ephemeral=True) - return + if ( + self._get_guild_category_config(interaction.guild_id, actual_name) + and not preset_id + ): # Allow adding preset even if name conflicts, preset name will be used + await interaction.response.send_message( + f"A custom role category named '{actual_name}' already exists.", + ephemeral=True, + ) + return roles_to_add: List[GuildRole] = [] is_preset_based = False @@ -489,63 +722,102 @@ class RoleSelectorCog(commands.Cog): if preset_id: preset = db.get_role_category_preset(preset_id) if not preset: - await interaction.response.send_message(f"Preset with ID '{preset_id}' not found.", ephemeral=True) + await interaction.response.send_message( + f"Preset with ID '{preset_id}' not found.", ephemeral=True + ) return - - final_name = preset.name # Use preset's name - if self._get_guild_category_config(interaction.guild_id, final_name): # Check if preset name already exists - await interaction.response.send_message(f"A category based on preset '{final_name}' already exists.", ephemeral=True) + + final_name = preset.name # Use preset's name + if self._get_guild_category_config( + interaction.guild_id, final_name + ): # Check if preset name already exists + await interaction.response.send_message( + f"A category based on preset '{final_name}' already exists.", + ephemeral=True, + ) return # For auto-creating roles from preset if not interaction.guild.me.guild_permissions.manage_roles: - await interaction.response.send_message("I need 'Manage Roles' permission to create roles from the preset.", ephemeral=True) + await interaction.response.send_message( + "I need 'Manage Roles' permission to create roles from the preset.", + ephemeral=True, + ) return # Define color map locally for this command, similar to init_defaults color_map_for_creation = { - "Red": discord.Color.red(), "Blue": discord.Color.blue(), "Green": discord.Color.green(), - "Yellow": discord.Color.gold(), "Purple": discord.Color.purple(), "Orange": discord.Color.orange(), - "Pink": discord.Color.fuchsia(), "Black": discord.Color(0x010101), "White": discord.Color(0xFEFEFE) + "Red": discord.Color.red(), + "Blue": discord.Color.blue(), + "Green": discord.Color.green(), + "Yellow": discord.Color.gold(), + "Purple": discord.Color.purple(), + "Orange": discord.Color.orange(), + "Pink": discord.Color.fuchsia(), + "Black": discord.Color(0x010101), + "White": discord.Color(0xFEFEFE), } # Defer if not already, as role creation can take time if not interaction.response.is_done(): await interaction.response.defer(ephemeral=True, thinking=True) - + created_roles_count = 0 for preset_role_option in preset.roles: # Check if role with this NAME exists in the current guild - existing_role_in_guild = discord.utils.get(interaction.guild.roles, name=preset_role_option.name) - + existing_role_in_guild = discord.utils.get( + interaction.guild.roles, name=preset_role_option.name + ) + if existing_role_in_guild: - roles_to_add.append(GuildRole(role_id=str(existing_role_in_guild.id), name=existing_role_in_guild.name, emoji=preset_role_option.emoji)) + roles_to_add.append( + GuildRole( + role_id=str(existing_role_in_guild.id), + name=existing_role_in_guild.name, + emoji=preset_role_option.emoji, + ) + ) else: # Role does not exist by name, create it role_color = discord.Color.default() - if preset.name.lower() == "colors" and preset_role_option.name in color_map_for_creation: + if ( + preset.name.lower() == "colors" + and preset_role_option.name in color_map_for_creation + ): role_color = color_map_for_creation[preset_role_option.name] - + try: newly_created_role = await interaction.guild.create_role( name=preset_role_option.name, color=role_color, - permissions=discord.Permissions.none(), # Basic permissions - reason=f"Auto-created for preset '{preset.name}' by {interaction.user}" + permissions=discord.Permissions.none(), # Basic permissions + reason=f"Auto-created for preset '{preset.name}' by {interaction.user}", + ) + roles_to_add.append( + GuildRole( + role_id=str(newly_created_role.id), + name=newly_created_role.name, + emoji=preset_role_option.emoji, + ) ) - roles_to_add.append(GuildRole(role_id=str(newly_created_role.id), name=newly_created_role.name, emoji=preset_role_option.emoji)) created_roles_count += 1 except discord.Forbidden: - await interaction.followup.send(f"I lack permission to create the role '{preset_role_option.name}'. Skipping.", ephemeral=True) - continue # Skip this role + await interaction.followup.send( + f"I lack permission to create the role '{preset_role_option.name}'. Skipping.", + ephemeral=True, + ) + continue # Skip this role except discord.HTTPException as e: - await interaction.followup.send(f"Failed to create role '{preset_role_option.name}': {e}. Skipping.", ephemeral=True) - continue # Skip this role - + await interaction.followup.send( + f"Failed to create role '{preset_role_option.name}': {e}. Skipping.", + ephemeral=True, + ) + continue # Skip this role + final_description = preset.description final_max_selectable = preset.max_selectable is_preset_based = True - + new_config = GuildRoleCategoryConfig( guild_id=guild_id_str, name=final_name, @@ -553,249 +825,422 @@ class RoleSelectorCog(commands.Cog): roles=roles_to_add, max_selectable=final_max_selectable, is_preset=is_preset_based, - preset_id=preset_id if is_preset_based else None + preset_id=preset_id if is_preset_based else None, ) db.save_guild_role_category_config(new_config) msg = f"Role category '{final_name}' added." if is_preset_based: msg += f" Based on preset '{preset_id}'." else: - msg += f" Use `/roleselect addrole category_name_or_id:{final_name} role: [emoji:]` to add roles." # Updated help text - msg += f" Then use `/roleselect post category_name_or_id:{final_name} channel:<#channel>` to post." # Updated help text - + msg += f" Use `/roleselect addrole category_name_or_id:{final_name} role: [emoji:]` to add roles." # Updated help text + msg += f" Then use `/roleselect post category_name_or_id:{final_name} channel:<#channel>` to post." # Updated help text + if interaction.response.is_done(): await interaction.followup.send(msg, ephemeral=True) else: await interaction.response.send_message(msg, ephemeral=True) - @roleselect_group.command(name="removecategory", description="Removes a role category and its selector message.") + @roleselect_group.command( + name="removecategory", + description="Removes a role category and its selector message.", + ) @app_commands.checks.has_permissions(manage_guild=True) - @app_commands.describe(category_name_or_id="The name or ID of the category to remove.") - async def roleselect_removecategory(self, interaction: discord.Interaction, category_name_or_id: str): + @app_commands.describe( + category_name_or_id="The name or ID of the category to remove." + ) + async def roleselect_removecategory( + self, interaction: discord.Interaction, category_name_or_id: str + ): if not interaction.guild: - await interaction.response.send_message("This command can only be used in a server.", ephemeral=True) + await interaction.response.send_message( + "This command can only be used in a server.", ephemeral=True + ) return - + guild_id = interaction.guild_id - config_to_remove = self._get_guild_category_config(guild_id, category_name_or_id) + config_to_remove = self._get_guild_category_config( + guild_id, category_name_or_id + ) if not config_to_remove: - await interaction.response.send_message(f"Role category '{category_name_or_id}' not found.", ephemeral=True) + await interaction.response.send_message( + f"Role category '{category_name_or_id}' not found.", ephemeral=True + ) return deleted_message_feedback = "" if config_to_remove.message_id and config_to_remove.channel_id: try: - channel = interaction.guild.get_channel(int(config_to_remove.channel_id)) + channel = interaction.guild.get_channel( + int(config_to_remove.channel_id) + ) if channel and isinstance(channel, discord.TextChannel): - message = await channel.fetch_message(int(config_to_remove.message_id)) + message = await channel.fetch_message( + int(config_to_remove.message_id) + ) await message.delete() - deleted_message_feedback = f" Deleted selector message for '{config_to_remove.name}'." + deleted_message_feedback = ( + f" Deleted selector message for '{config_to_remove.name}'." + ) except Exception as e: deleted_message_feedback = f" Could not delete selector message: {e}." - - db.delete_guild_role_category_config(str(guild_id), config_to_remove.category_id) + + db.delete_guild_role_category_config( + str(guild_id), config_to_remove.category_id + ) response_message = f"Role category '{config_to_remove.name}' removed.{deleted_message_feedback}" - + if interaction.response.is_done(): await interaction.followup.send(response_message, ephemeral=True) else: await interaction.response.send_message(response_message, ephemeral=True) - @roleselect_group.command(name="listcategories", description="Lists all configured role selection categories.") + @roleselect_group.command( + name="listcategories", + description="Lists all configured role selection categories.", + ) @app_commands.checks.has_permissions(manage_guild=True) async def roleselect_listcategories(self, interaction: discord.Interaction): if not interaction.guild: - await interaction.response.send_message("This command can only be used in a server.", ephemeral=True) + await interaction.response.send_message( + "This command can only be used in a server.", ephemeral=True + ) return guild_id_str = str(interaction.guild_id) configs = db.get_guild_role_category_configs(guild_id_str) if not configs: - await interaction.response.send_message("No role selection categories configured.", ephemeral=True) + await interaction.response.send_message( + "No role selection categories configured.", ephemeral=True + ) return - embed = discord.Embed(title="Configured Role Selection Categories", color=discord.Color.blue()) + embed = discord.Embed( + title="Configured Role Selection Categories", color=discord.Color.blue() + ) for config in configs: - roles_str = ", ".join([r.name for r in config.roles[:5]]) + ("..." if len(config.roles) > 5 else "") - embed.add_field(name=f"{config.name} (ID: `{config.category_id}`)", value=f"Desc: {config.description}\nMax: {config.max_selectable}\nRoles: {roles_str or 'None'}\nPreset: {config.preset_id or 'No'}", inline=False) - await interaction.response.send_message(embed=embed, ephemeral=False) # Make it visible + roles_str = ", ".join([r.name for r in config.roles[:5]]) + ( + "..." if len(config.roles) > 5 else "" + ) + embed.add_field( + name=f"{config.name} (ID: `{config.category_id}`)", + value=f"Desc: {config.description}\nMax: {config.max_selectable}\nRoles: {roles_str or 'None'}\nPreset: {config.preset_id or 'No'}", + inline=False, + ) + await interaction.response.send_message( + embed=embed, ephemeral=False + ) # Make it visible - @roleselect_group.command(name="listpresets", description="Lists all available global role category presets.") + @roleselect_group.command( + name="listpresets", + description="Lists all available global role category presets.", + ) async def roleselect_listpresets(self, interaction: discord.Interaction): presets = db.get_all_role_category_presets() if not presets: - await interaction.response.send_message("No global presets available.", ephemeral=True) + await interaction.response.send_message( + "No global presets available.", ephemeral=True + ) return - embed = discord.Embed(title="Available Role Category Presets", color=discord.Color.green()) + embed = discord.Embed( + title="Available Role Category Presets", color=discord.Color.green() + ) for preset in sorted(presets, key=lambda p: p.display_order): - roles_str = ", ".join([f"{r.name} ({r.role_id})" for r in preset.roles[:3]]) + ("..." if len(preset.roles) > 3 else "") - embed.add_field(name=f"{preset.name} (ID: `{preset.id}`)", value=f"Desc: {preset.description}\nMax: {preset.max_selectable}\nRoles: {roles_str or 'None'}", inline=False) - await interaction.response.send_message(embed=embed, ephemeral=False) # Make it visible + roles_str = ", ".join( + [f"{r.name} ({r.role_id})" for r in preset.roles[:3]] + ) + ("..." if len(preset.roles) > 3 else "") + embed.add_field( + name=f"{preset.name} (ID: `{preset.id}`)", + value=f"Desc: {preset.description}\nMax: {preset.max_selectable}\nRoles: {roles_str or 'None'}", + inline=False, + ) + await interaction.response.send_message( + embed=embed, ephemeral=False + ) # Make it visible - @roleselect_group.command(name="addrole", description="Adds a role to a specified category.") + @roleselect_group.command( + name="addrole", description="Adds a role to a specified category." + ) @app_commands.checks.has_permissions(manage_guild=True) @app_commands.describe( category_name_or_id="The name or ID of the category to add the role to.", role="The role to add.", - emoji="Optional emoji to display next to the role in the selector." + emoji="Optional emoji to display next to the role in the selector.", ) - async def roleselect_addrole(self, interaction: discord.Interaction, category_name_or_id: str, role: discord.Role, emoji: Optional[str] = None): + async def roleselect_addrole( + self, + interaction: discord.Interaction, + category_name_or_id: str, + role: discord.Role, + emoji: Optional[str] = None, + ): if not interaction.guild: - await interaction.response.send_message("This command can only be used in a server.", ephemeral=True) + await interaction.response.send_message( + "This command can only be used in a server.", ephemeral=True + ) return guild_id = interaction.guild_id config = self._get_guild_category_config(guild_id, category_name_or_id) if not config: - await interaction.response.send_message(f"Category '{category_name_or_id}' not found.", ephemeral=True) + await interaction.response.send_message( + f"Category '{category_name_or_id}' not found.", ephemeral=True + ) return if config.is_preset: - await interaction.response.send_message(f"Category '{config.name}' uses a preset. Roles are managed via the preset definition.", ephemeral=True) + await interaction.response.send_message( + f"Category '{config.name}' uses a preset. Roles are managed via the preset definition.", + ephemeral=True, + ) return if any(r.role_id == str(role.id) for r in config.roles): - await interaction.response.send_message(f"Role '{role.name}' is already in '{config.name}'.", ephemeral=True) + await interaction.response.send_message( + f"Role '{role.name}' is already in '{config.name}'.", ephemeral=True + ) return - config.roles.append(GuildRole(role_id=str(role.id), name=role.name, emoji=emoji)) + config.roles.append( + GuildRole(role_id=str(role.id), name=role.name, emoji=emoji) + ) db.save_guild_role_category_config(config) - await interaction.response.send_message(f"Role '{role.name}' added to '{config.name}'.", ephemeral=True) + await interaction.response.send_message( + f"Role '{role.name}' added to '{config.name}'.", ephemeral=True + ) - @roleselect_group.command(name="removerole", description="Removes a role from a specified category.") + @roleselect_group.command( + name="removerole", description="Removes a role from a specified category." + ) @app_commands.checks.has_permissions(manage_guild=True) @app_commands.describe( category_name_or_id="The name or ID of the category to remove the role from.", - role="The role to remove." + role="The role to remove.", ) - async def roleselect_removerole(self, interaction: discord.Interaction, category_name_or_id: str, role: discord.Role): + async def roleselect_removerole( + self, + interaction: discord.Interaction, + category_name_or_id: str, + role: discord.Role, + ): if not interaction.guild: - await interaction.response.send_message("This command can only be used in a server.", ephemeral=True) + await interaction.response.send_message( + "This command can only be used in a server.", ephemeral=True + ) return guild_id = interaction.guild_id config = self._get_guild_category_config(guild_id, category_name_or_id) if not config: - await interaction.response.send_message(f"Category '{category_name_or_id}' not found.", ephemeral=True) + await interaction.response.send_message( + f"Category '{category_name_or_id}' not found.", ephemeral=True + ) return if config.is_preset: - await interaction.response.send_message(f"Category '{config.name}' uses a preset. Roles are managed via the preset definition.", ephemeral=True) + await interaction.response.send_message( + f"Category '{config.name}' uses a preset. Roles are managed via the preset definition.", + ephemeral=True, + ) return initial_len = len(config.roles) config.roles = [r for r in config.roles if r.role_id != str(role.id)] if len(config.roles) < initial_len: db.save_guild_role_category_config(config) - await interaction.response.send_message(f"Role '{role.name}' removed from '{config.name}'.", ephemeral=True) + await interaction.response.send_message( + f"Role '{role.name}' removed from '{config.name}'.", ephemeral=True + ) else: - await interaction.response.send_message(f"Role '{role.name}' not found in '{config.name}'.", ephemeral=True) - - @roleselect_group.command(name="setcolorui", description="Posts the UI for users to set their custom role color.") + await interaction.response.send_message( + f"Role '{role.name}' not found in '{config.name}'.", ephemeral=True + ) + + @roleselect_group.command( + name="setcolorui", + description="Posts the UI for users to set their custom role color.", + ) @app_commands.checks.has_permissions(manage_guild=True) - @app_commands.describe(channel="The channel to post the custom color UI in (defaults to current channel).") - async def roleselect_setcolorui(self, interaction: discord.Interaction, channel: Optional[discord.TextChannel] = None): + @app_commands.describe( + channel="The channel to post the custom color UI in (defaults to current channel)." + ) + async def roleselect_setcolorui( + self, + interaction: discord.Interaction, + channel: Optional[discord.TextChannel] = None, + ): if not interaction.guild: - await interaction.response.send_message("This command can only be used in a server.", ephemeral=True) + await interaction.response.send_message( + "This command can only be used in a server.", ephemeral=True + ) return target_channel = channel or interaction.channel - if not isinstance(target_channel, discord.TextChannel): # Ensure it's a text channel - await interaction.response.send_message("Invalid channel specified for custom color UI.", ephemeral=True) + if not isinstance( + target_channel, discord.TextChannel + ): # Ensure it's a text channel + await interaction.response.send_message( + "Invalid channel specified for custom color UI.", ephemeral=True + ) return embed = discord.Embed( title="🎨 Custom Role Color", description="Click the button below to set a custom color for your name in this server!", - color=discord.Color.random() + color=discord.Color.random(), ) view = CustomColorButtonView() try: await target_channel.send(embed=embed, view=view) if target_channel != interaction.channel: - await interaction.response.send_message(f"Custom color button posted in {target_channel.mention}.", ephemeral=True) + await interaction.response.send_message( + f"Custom color button posted in {target_channel.mention}.", + ephemeral=True, + ) else: - await interaction.response.send_message("Custom color button posted.", ephemeral=True) + await interaction.response.send_message( + "Custom color button posted.", ephemeral=True + ) except discord.Forbidden: - await interaction.response.send_message(f"I don't have permissions to send messages in {target_channel.mention}.", ephemeral=True) + await interaction.response.send_message( + f"I don't have permissions to send messages in {target_channel.mention}.", + ephemeral=True, + ) except Exception as e: - await interaction.response.send_message(f"Error posting custom color button: {e}", ephemeral=True) + await interaction.response.send_message( + f"Error posting custom color button: {e}", ephemeral=True + ) - @roleselect_group.command(name="post", description="Posts or updates the role selector message for a category.") + @roleselect_group.command( + name="post", + description="Posts or updates the role selector message for a category.", + ) @app_commands.checks.has_permissions(manage_guild=True) @app_commands.describe( category_name_or_id="The name or ID of the category to post.", - channel="The channel to post the selector in (defaults to current channel)." + channel="The channel to post the selector in (defaults to current channel).", ) - async def roleselect_post(self, interaction: discord.Interaction, category_name_or_id: str, channel: Optional[discord.TextChannel] = None): + async def roleselect_post( + self, + interaction: discord.Interaction, + category_name_or_id: str, + channel: Optional[discord.TextChannel] = None, + ): if not interaction.guild: - await interaction.response.send_message("This command can only be used in a server.", ephemeral=True) + await interaction.response.send_message( + "This command can only be used in a server.", ephemeral=True + ) return - + target_channel = channel or interaction.channel - if not isinstance(target_channel, discord.TextChannel): # Ensure it's a text channel - await interaction.response.send_message("Invalid channel specified for posting.", ephemeral=True) + if not isinstance( + target_channel, discord.TextChannel + ): # Ensure it's a text channel + await interaction.response.send_message( + "Invalid channel specified for posting.", ephemeral=True + ) return guild = interaction.guild config = self._get_guild_category_config(guild.id, category_name_or_id) if not config: - await interaction.response.send_message(f"Category '{category_name_or_id}' not found.", ephemeral=True) + await interaction.response.send_message( + f"Category '{category_name_or_id}' not found.", ephemeral=True + ) return if not config.roles: - await interaction.response.send_message(f"Category '{config.name}' has no roles. Add roles first using `/roleselect addrole`.", ephemeral=True) + await interaction.response.send_message( + f"Category '{config.name}' has no roles. Add roles first using `/roleselect addrole`.", + ephemeral=True, + ) return - embed = discord.Embed(title=f"✨ {config.name} Roles ✨", description=config.description, color=discord.Color.blue()) + embed = discord.Embed( + title=f"✨ {config.name} Roles ✨", + description=config.description, + color=discord.Color.blue(), + ) view = RoleSelectorView(guild.id, config, self.bot) - + # Defer the response as message fetching/editing can take time await interaction.response.defer(ephemeral=True) if config.message_id and config.channel_id: try: original_channel = guild.get_channel(int(config.channel_id)) - if original_channel and isinstance(original_channel, discord.TextChannel): - message_to_edit = await original_channel.fetch_message(int(config.message_id)) + if original_channel and isinstance( + original_channel, discord.TextChannel + ): + message_to_edit = await original_channel.fetch_message( + int(config.message_id) + ) await message_to_edit.edit(embed=embed, view=view) - - if original_channel.id != target_channel.id: # Moved + + if original_channel.id != target_channel.id: # Moved # Delete old message, post new one, then update config - await message_to_edit.delete() # Delete the old message first + await message_to_edit.delete() # Delete the old message first new_msg = await target_channel.send(embed=embed, view=view) config.channel_id = str(target_channel.id) config.message_id = str(new_msg.id) - await interaction.followup.send(f"Updated and moved selector for '{config.name}' to {target_channel.mention}.", ephemeral=True) - else: # Just updated in the same channel - await interaction.followup.send(f"Updated selector for '{config.name}' in {target_channel.mention}.", ephemeral=True) + await interaction.followup.send( + f"Updated and moved selector for '{config.name}' to {target_channel.mention}.", + ephemeral=True, + ) + else: # Just updated in the same channel + await interaction.followup.send( + f"Updated selector for '{config.name}' in {target_channel.mention}.", + ephemeral=True, + ) db.save_guild_role_category_config(config) return - except Exception as e: # NotFound, Forbidden, etc. - await interaction.followup.send(f"Couldn't update original message ({e}), posting new one.", ephemeral=True) + except Exception as e: # NotFound, Forbidden, etc. + await interaction.followup.send( + f"Couldn't update original message ({e}), posting new one.", + ephemeral=True, + ) config.message_id = None - config.channel_id = None # Clear old IDs as we are posting a new one + config.channel_id = None # Clear old IDs as we are posting a new one try: msg = await target_channel.send(embed=embed, view=view) config.message_id = str(msg.id) config.channel_id = str(target_channel.id) db.save_guild_role_category_config(config) - await interaction.followup.send(f"Posted role selector for '{config.name}' in {target_channel.mention}.", ephemeral=True) + await interaction.followup.send( + f"Posted role selector for '{config.name}' in {target_channel.mention}.", + ephemeral=True, + ) except Exception as e: - await interaction.followup.send(f"Error posting role selector: {e}", ephemeral=True) + await interaction.followup.send( + f"Error posting role selector: {e}", ephemeral=True + ) - @roleselect_group.command(name="postallcategories", description="Posts or updates selector messages for all categories in a channel.") + @roleselect_group.command( + name="postallcategories", + description="Posts or updates selector messages for all categories in a channel.", + ) @app_commands.checks.has_permissions(manage_guild=True) - @app_commands.describe(channel="The channel to post all selectors in (defaults to current channel).") - async def roleselect_postallcategories(self, interaction: discord.Interaction, channel: Optional[discord.TextChannel] = None): + @app_commands.describe( + channel="The channel to post all selectors in (defaults to current channel)." + ) + async def roleselect_postallcategories( + self, + interaction: discord.Interaction, + channel: Optional[discord.TextChannel] = None, + ): if not interaction.guild: - await interaction.response.send_message("This command can only be used in a server.", ephemeral=True) + await interaction.response.send_message( + "This command can only be used in a server.", ephemeral=True + ) return target_channel = channel or interaction.channel if not isinstance(target_channel, discord.TextChannel): - await interaction.response.send_message("Invalid channel specified for posting.", ephemeral=True) + await interaction.response.send_message( + "Invalid channel specified for posting.", ephemeral=True + ) return guild = interaction.guild all_guild_configs = db.get_guild_role_category_configs(str(guild.id)) if not all_guild_configs: - await interaction.response.send_message("No role categories configured for this server.", ephemeral=True) + await interaction.response.send_message( + "No role categories configured for this server.", ephemeral=True + ) return await interaction.response.defer(ephemeral=True, thinking=True) - + posted_count = 0 updated_count = 0 skipped_count = 0 @@ -804,66 +1249,94 @@ class RoleSelectorCog(commands.Cog): for config in all_guild_configs: if not config.roles: - feedback_details.append(f"Skipped '{config.name}': No roles configured.") + feedback_details.append( + f"Skipped '{config.name}': No roles configured." + ) skipped_count += 1 continue - embed = discord.Embed(title=f"✨ {config.name} Roles ✨", description=config.description, color=discord.Color.blue()) + embed = discord.Embed( + title=f"✨ {config.name} Roles ✨", + description=config.description, + color=discord.Color.blue(), + ) view = RoleSelectorView(guild.id, config, self.bot) - - action_taken = "Posted" + + action_taken = "Posted" # Check if a message already exists and handle it if config.message_id and config.channel_id: try: original_channel = guild.get_channel(int(config.channel_id)) - if original_channel and isinstance(original_channel, discord.TextChannel): - message_to_edit_or_delete = await original_channel.fetch_message(int(config.message_id)) - - if original_channel.id != target_channel.id: # Moving to a new channel - await message_to_edit_or_delete.delete() # Delete from old location + if original_channel and isinstance( + original_channel, discord.TextChannel + ): + message_to_edit_or_delete = ( + await original_channel.fetch_message(int(config.message_id)) + ) + + if ( + original_channel.id != target_channel.id + ): # Moving to a new channel + await message_to_edit_or_delete.delete() # Delete from old location # message_id and channel_id will be updated when new message is sent config.message_id = None config.channel_id = None action_taken = f"Moved '{config.name}' to {target_channel.mention} and posted" - else: # Updating in the same channel (target_channel is the original_channel) + else: # Updating in the same channel (target_channel is the original_channel) await message_to_edit_or_delete.edit(embed=embed, view=view) - db.save_guild_role_category_config(config) # Save updated config (no change to message/channel id) - feedback_details.append(f"Updated selector for '{config.name}' in {target_channel.mention}.") + db.save_guild_role_category_config( + config + ) # Save updated config (no change to message/channel id) + feedback_details.append( + f"Updated selector for '{config.name}' in {target_channel.mention}." + ) updated_count += 1 - continue # Skip to next config as this one is handled + continue # Skip to next config as this one is handled except discord.NotFound: - feedback_details.append(f"Note: Original message for '{config.name}' not found, will post anew.") + feedback_details.append( + f"Note: Original message for '{config.name}' not found, will post anew." + ) config.message_id = None config.channel_id = None except discord.Forbidden: - feedback_details.append(f"Error: Could not access/delete original message for '{config.name}' in {original_channel.mention if original_channel else 'unknown channel'}. Posting new one.") + feedback_details.append( + f"Error: Could not access/delete original message for '{config.name}' in {original_channel.mention if original_channel else 'unknown channel'}. Posting new one." + ) config.message_id = None config.channel_id = None except Exception as e: - feedback_details.append(f"Error handling original message for '{config.name}': {e}. Posting new one.") + feedback_details.append( + f"Error handling original message for '{config.name}': {e}. Posting new one." + ) config.message_id = None config.channel_id = None - + # Post new message if no existing one or if it was deleted due to move try: new_msg = await target_channel.send(embed=embed, view=view) config.message_id = str(new_msg.id) config.channel_id = str(target_channel.id) db.save_guild_role_category_config(config) - feedback_details.append(f"{action_taken} selector for '{config.name}' in {target_channel.mention}.") - if action_taken.startswith("Moved") or action_taken.startswith("Posted"): - posted_count +=1 - else: # Should be covered by updated_count already + feedback_details.append( + f"{action_taken} selector for '{config.name}' in {target_channel.mention}." + ) + if action_taken.startswith("Moved") or action_taken.startswith( + "Posted" + ): + posted_count += 1 + else: # Should be covered by updated_count already pass except discord.Forbidden: - feedback_details.append(f"Error: No permission to post '{config.name}' in {target_channel.mention}.") + feedback_details.append( + f"Error: No permission to post '{config.name}' in {target_channel.mention}." + ) error_count += 1 except Exception as e: feedback_details.append(f"Error posting '{config.name}': {e}.") error_count += 1 - + summary_message = ( f"Finished posting all category selectors to {target_channel.mention}.\n" f"Successfully Posted/Moved: {posted_count}\n" @@ -875,52 +1348,90 @@ class RoleSelectorCog(commands.Cog): await interaction.followup.send(summary_message, ephemeral=True) # Role Preset Commands (Owner Only) - @rolepreset_group.command(name="add", description="Creates a new global role category preset.") + @rolepreset_group.command( + name="add", description="Creates a new global role category preset." + ) @app_commands.check(is_owner_check) @app_commands.describe( preset_id="A unique ID for this preset (e.g., 'color_roles', 'region_roles').", name="The display name for this preset.", description="A description for this preset.", max_selectable="Maximum roles a user can select from categories using this preset (default: 1).", - display_order="Order in which this preset appears in lists (lower numbers first, default: 0)." + display_order="Order in which this preset appears in lists (lower numbers first, default: 0).", ) - async def rolepreset_add(self, interaction: discord.Interaction, preset_id: str, name: str, description: str, max_selectable: Optional[int] = 1, display_order: Optional[int] = 0): + async def rolepreset_add( + self, + interaction: discord.Interaction, + preset_id: str, + name: str, + description: str, + max_selectable: Optional[int] = 1, + display_order: Optional[int] = 0, + ): current_max_selectable = max_selectable if max_selectable is not None else 1 current_display_order = display_order if display_order is not None else 0 if db.get_role_category_preset(preset_id): - await interaction.response.send_message(f"Preset ID '{preset_id}' already exists.", ephemeral=True) + await interaction.response.send_message( + f"Preset ID '{preset_id}' already exists.", ephemeral=True + ) return - new_preset = RoleCategoryPreset(id=preset_id, name=name, description=description, roles=[], max_selectable=current_max_selectable, display_order=current_display_order) + new_preset = RoleCategoryPreset( + id=preset_id, + name=name, + description=description, + roles=[], + max_selectable=current_max_selectable, + display_order=current_display_order, + ) db.save_role_category_preset(new_preset) - await interaction.response.send_message(f"Preset '{name}' (ID: {preset_id}) created. Add roles with `/rolepreset addrole`.", ephemeral=True) + await interaction.response.send_message( + f"Preset '{name}' (ID: {preset_id}) created. Add roles with `/rolepreset addrole`.", + ephemeral=True, + ) - @rolepreset_group.command(name="remove", description="Removes a global role category preset.") + @rolepreset_group.command( + name="remove", description="Removes a global role category preset." + ) @app_commands.check(is_owner_check) @app_commands.describe(preset_id="The ID of the preset to remove.") async def rolepreset_remove(self, interaction: discord.Interaction, preset_id: str): if not db.get_role_category_preset(preset_id): - await interaction.response.send_message(f"Preset ID '{preset_id}' not found.", ephemeral=True) + await interaction.response.send_message( + f"Preset ID '{preset_id}' not found.", ephemeral=True + ) return db.delete_role_category_preset(preset_id) - await interaction.response.send_message(f"Preset ID '{preset_id}' removed.", ephemeral=True) + await interaction.response.send_message( + f"Preset ID '{preset_id}' removed.", ephemeral=True + ) - @rolepreset_group.command(name="addrole", description="Adds a role (by ID or name) to a global preset.") + @rolepreset_group.command( + name="addrole", description="Adds a role (by ID or name) to a global preset." + ) @app_commands.check(is_owner_check) @app_commands.describe( preset_id="The ID of the preset to add the role to.", role_name_or_id="The name or ID of the role to add. The first matching role found across all servers the bot is in will be used.", - emoji="Optional emoji for the role in this preset." + emoji="Optional emoji for the role in this preset.", ) - async def rolepreset_addrole(self, interaction: discord.Interaction, preset_id: str, role_name_or_id: str, emoji: Optional[str] = None): + async def rolepreset_addrole( + self, + interaction: discord.Interaction, + preset_id: str, + role_name_or_id: str, + emoji: Optional[str] = None, + ): preset = db.get_role_category_preset(preset_id) if not preset: - await interaction.response.send_message(f"Preset ID '{preset_id}' not found.", ephemeral=True) + await interaction.response.send_message( + f"Preset ID '{preset_id}' not found.", ephemeral=True + ) return - + target_role: Optional[discord.Role] = None role_display_name = role_name_or_id - + # Attempt to find role by ID first, then by name across all guilds try: role_id_int = int(role_name_or_id) @@ -929,7 +1440,7 @@ class RoleSelectorCog(commands.Cog): target_role = r role_display_name = r.name break - except ValueError: # Not an ID, try by name + except ValueError: # Not an ID, try by name for guild in self.bot.guilds: for r_obj in guild.roles: if r_obj.name.lower() == role_name_or_id.lower(): @@ -938,31 +1449,49 @@ class RoleSelectorCog(commands.Cog): break if target_role: break - - if not target_role: - await interaction.response.send_message(f"Role '{role_name_or_id}' not found in any server the bot is in.", ephemeral=True) - return - - if any(r.role_id == str(target_role.id) for r in preset.roles): - await interaction.response.send_message(f"Role '{target_role.name}' (ID: {target_role.id}) is already in preset '{preset.name}'.", ephemeral=True) - return - - preset.roles.append(RoleOption(role_id=str(target_role.id), name=role_display_name, emoji=emoji)) - db.save_role_category_preset(preset) - await interaction.response.send_message(f"Role '{role_display_name}' (ID: {target_role.id}) added to preset '{preset.name}'.", ephemeral=True) - @rolepreset_group.command(name="removerole", description="Removes a role (by ID or name) from a global preset.") + if not target_role: + await interaction.response.send_message( + f"Role '{role_name_or_id}' not found in any server the bot is in.", + ephemeral=True, + ) + return + + if any(r.role_id == str(target_role.id) for r in preset.roles): + await interaction.response.send_message( + f"Role '{target_role.name}' (ID: {target_role.id}) is already in preset '{preset.name}'.", + ephemeral=True, + ) + return + + preset.roles.append( + RoleOption(role_id=str(target_role.id), name=role_display_name, emoji=emoji) + ) + db.save_role_category_preset(preset) + await interaction.response.send_message( + f"Role '{role_display_name}' (ID: {target_role.id}) added to preset '{preset.name}'.", + ephemeral=True, + ) + + @rolepreset_group.command( + name="removerole", + description="Removes a role (by ID or name) from a global preset.", + ) @app_commands.check(is_owner_check) @app_commands.describe( preset_id="The ID of the preset to remove the role from.", - role_id_or_name="The ID or name of the role to remove from the preset." + role_id_or_name="The ID or name of the role to remove from the preset.", ) - async def rolepreset_removerole(self, interaction: discord.Interaction, preset_id: str, role_id_or_name: str): + async def rolepreset_removerole( + self, interaction: discord.Interaction, preset_id: str, role_id_or_name: str + ): preset = db.get_role_category_preset(preset_id) if not preset: - await interaction.response.send_message(f"Preset ID '{preset_id}' not found.", ephemeral=True) + await interaction.response.send_message( + f"Preset ID '{preset_id}' not found.", ephemeral=True + ) return - + initial_len = len(preset.roles) # Try to match by ID first, then by name if ID doesn't match or isn't an int role_to_remove_id_str: Optional[str] = None @@ -970,77 +1499,124 @@ class RoleSelectorCog(commands.Cog): role_id_int = int(role_id_or_name) role_to_remove_id_str = str(role_id_int) except ValueError: - pass # role_id_or_name is not an integer, will try to match by name + pass # role_id_or_name is not an integer, will try to match by name if role_to_remove_id_str: - preset.roles = [r for r in preset.roles if r.role_id != role_to_remove_id_str] - else: # Match by name (case-insensitive) - preset.roles = [r for r in preset.roles if r.name.lower() != role_id_or_name.lower()] - + preset.roles = [ + r for r in preset.roles if r.role_id != role_to_remove_id_str + ] + else: # Match by name (case-insensitive) + preset.roles = [ + r for r in preset.roles if r.name.lower() != role_id_or_name.lower() + ] + if len(preset.roles) < initial_len: db.save_role_category_preset(preset) - await interaction.response.send_message(f"Role matching '{role_id_or_name}' removed from preset '{preset.name}'.", ephemeral=True) + await interaction.response.send_message( + f"Role matching '{role_id_or_name}' removed from preset '{preset.name}'.", + ephemeral=True, + ) else: - await interaction.response.send_message(f"Role matching '{role_id_or_name}' not found in preset '{preset.name}'.", ephemeral=True) + await interaction.response.send_message( + f"Role matching '{role_id_or_name}' not found in preset '{preset.name}'.", + ephemeral=True, + ) - @rolepreset_group.command(name="add_all_to_guild", description="Adds all available global presets as categories to this guild.") + @rolepreset_group.command( + name="add_all_to_guild", + description="Adds all available global presets as categories to this guild.", + ) @app_commands.checks.has_permissions(manage_guild=True) async def rolepreset_add_all_to_guild(self, interaction: discord.Interaction): if not interaction.guild: - await interaction.response.send_message("This command can only be used in a server.", ephemeral=True) + await interaction.response.send_message( + "This command can only be used in a server.", ephemeral=True + ) return - + if not interaction.guild.me.guild_permissions.manage_roles: - await interaction.response.send_message("I need 'Manage Roles' permission to create roles from presets.", ephemeral=True) + await interaction.response.send_message( + "I need 'Manage Roles' permission to create roles from presets.", + ephemeral=True, + ) return await interaction.response.defer(ephemeral=True, thinking=True) guild_id_str = str(interaction.guild_id) - + all_presets = db.get_all_role_category_presets() added_count = 0 skipped_count = 0 feedback_messages = [] color_map_for_creation = { - "Red": discord.Color.red(), "Blue": discord.Color.blue(), "Green": discord.Color.green(), - "Yellow": discord.Color.gold(), "Purple": discord.Color.purple(), "Orange": discord.Color.orange(), - "Pink": discord.Color.fuchsia(), "Black": discord.Color(0x010101), "White": discord.Color(0xFEFEFE) + "Red": discord.Color.red(), + "Blue": discord.Color.blue(), + "Green": discord.Color.green(), + "Yellow": discord.Color.gold(), + "Purple": discord.Color.purple(), + "Orange": discord.Color.orange(), + "Pink": discord.Color.fuchsia(), + "Black": discord.Color(0x010101), + "White": discord.Color(0xFEFEFE), } for preset in all_presets: # Check if a category based on this preset already exists in the guild if self._get_guild_category_config(interaction.guild_id, preset.name): - feedback_messages.append(f"Skipped '{preset.name}': A category with this name already exists.") + feedback_messages.append( + f"Skipped '{preset.name}': A category with this name already exists." + ) skipped_count += 1 continue roles_to_add: List[GuildRole] = [] for preset_role_option in preset.roles: - existing_role_in_guild = discord.utils.get(interaction.guild.roles, name=preset_role_option.name) - + existing_role_in_guild = discord.utils.get( + interaction.guild.roles, name=preset_role_option.name + ) + if existing_role_in_guild: - roles_to_add.append(GuildRole(role_id=str(existing_role_in_guild.id), name=existing_role_in_guild.name, emoji=preset_role_option.emoji)) + roles_to_add.append( + GuildRole( + role_id=str(existing_role_in_guild.id), + name=existing_role_in_guild.name, + emoji=preset_role_option.emoji, + ) + ) else: role_color = discord.Color.default() - if preset.name.lower() == "colors" and preset_role_option.name in color_map_for_creation: + if ( + preset.name.lower() == "colors" + and preset_role_option.name in color_map_for_creation + ): role_color = color_map_for_creation[preset_role_option.name] - + try: newly_created_role = await interaction.guild.create_role( name=preset_role_option.name, color=role_color, permissions=discord.Permissions.none(), - reason=f"Auto-created for preset '{preset.name}' by {interaction.user}" + reason=f"Auto-created for preset '{preset.name}' by {interaction.user}", + ) + roles_to_add.append( + GuildRole( + role_id=str(newly_created_role.id), + name=newly_created_role.name, + emoji=preset_role_option.emoji, + ) ) - roles_to_add.append(GuildRole(role_id=str(newly_created_role.id), name=newly_created_role.name, emoji=preset_role_option.emoji)) except discord.Forbidden: - feedback_messages.append(f"Warning: Couldn't create role '{preset_role_option.name}' for preset '{preset.name}' due to permissions.") + feedback_messages.append( + f"Warning: Couldn't create role '{preset_role_option.name}' for preset '{preset.name}' due to permissions." + ) continue except discord.HTTPException as e: - feedback_messages.append(f"Warning: Failed to create role '{preset_role_option.name}' for preset '{preset.name}': {e}.") + feedback_messages.append( + f"Warning: Failed to create role '{preset_role_option.name}' for preset '{preset.name}': {e}." + ) continue - + new_config = GuildRoleCategoryConfig( guild_id=guild_id_str, name=preset.name, @@ -1048,43 +1624,154 @@ class RoleSelectorCog(commands.Cog): roles=roles_to_add, max_selectable=preset.max_selectable, is_preset=True, - preset_id=preset.id + preset_id=preset.id, ) db.save_guild_role_category_config(new_config) feedback_messages.append(f"Added '{preset.name}' as a new category.") added_count += 1 - final_message = f"Attempted to add {len(all_presets)} presets.\nAdded: {added_count}\nSkipped: {skipped_count}\n\nDetails:\n" + "\n".join(feedback_messages) + final_message = ( + f"Attempted to add {len(all_presets)} presets.\nAdded: {added_count}\nSkipped: {skipped_count}\n\nDetails:\n" + + "\n".join(feedback_messages) + ) await interaction.followup.send(final_message, ephemeral=True) - @rolepreset_group.command(name="init_defaults", description="Initializes default global presets based on role_creator_cog structure.") + @rolepreset_group.command( + name="init_defaults", + description="Initializes default global presets based on role_creator_cog structure.", + ) @app_commands.check(is_owner_check) async def rolepreset_init_defaults(self, interaction: discord.Interaction): await interaction.response.defer(ephemeral=True, thinking=True) # Definitions from role_creator_cog.py color_map_creator = { - "Red": discord.Color.red(), "Blue": discord.Color.blue(), "Green": discord.Color.green(), - "Yellow": discord.Color.gold(), "Purple": discord.Color.purple(), "Orange": discord.Color.orange(), - "Pink": discord.Color.fuchsia(), "Black": discord.Color(0x010101), "White": discord.Color(0xFEFEFE) + "Red": discord.Color.red(), + "Blue": discord.Color.blue(), + "Green": discord.Color.green(), + "Yellow": discord.Color.gold(), + "Purple": discord.Color.purple(), + "Orange": discord.Color.orange(), + "Pink": discord.Color.fuchsia(), + "Black": discord.Color(0x010101), + "White": discord.Color(0xFEFEFE), } role_categories_creator = { - "Colors": {"roles": ["Red", "Blue", "Green", "Yellow", "Purple", "Orange", "Pink", "Black", "White"], "max": 1, "desc": "Choose your favorite color role."}, - "Regions": {"roles": ["NA East", "NA West", "EU", "Asia", "Oceania", "South America"], "max": 1, "desc": "Select your region."}, - "Pronouns": {"roles": ["He/Him", "She/Her", "They/Them", "Ask Pronouns"], "max": 4, "desc": "Select your pronoun roles."}, - "Interests": {"roles": ["Art", "Music", "Movies", "Books", "Technology", "Science", "History", "Food", "Programming", "Anime", "Photography", "Travel", "Writing", "Cooking", "Fitness", "Nature", "Gaming", "Philosophy", "Psychology", "Design", "Machine Learning", "Cryptocurrency", "Astronomy", "Mythology", "Languages", "Architecture", "DIY Projects", "Hiking", "Streaming", "Virtual Reality", "Coding Challenges", "Board Games", "Meditation", "Urban Exploration", "Tattoo Art", "Comics", "Robotics", "3D Modeling", "Podcasts"], "max": 16, "desc": "Select your interests."}, - "Gaming Platforms": {"roles": ["PC", "PlayStation", "Xbox", "Nintendo Switch", "Mobile"], "max": 5, "desc": "Select your gaming platforms."}, - "Favorite Vocaloids": {"roles": ["Hatsune Miku", "Kasane Teto", "Akita Neru", "Kagamine Rin", "Kagamine Len", "Megurine Luka", "Kaito", "Meiko", "Gumi", "Kaai Yuki", "Adachi Rei"], "max": 10, "desc": "Select your favorite Vocaloids."}, - "Notifications": {"roles": ["Announcements"], "max": 1, "desc": "Opt-in for announcements."} + "Colors": { + "roles": [ + "Red", + "Blue", + "Green", + "Yellow", + "Purple", + "Orange", + "Pink", + "Black", + "White", + ], + "max": 1, + "desc": "Choose your favorite color role.", + }, + "Regions": { + "roles": [ + "NA East", + "NA West", + "EU", + "Asia", + "Oceania", + "South America", + ], + "max": 1, + "desc": "Select your region.", + }, + "Pronouns": { + "roles": ["He/Him", "She/Her", "They/Them", "Ask Pronouns"], + "max": 4, + "desc": "Select your pronoun roles.", + }, + "Interests": { + "roles": [ + "Art", + "Music", + "Movies", + "Books", + "Technology", + "Science", + "History", + "Food", + "Programming", + "Anime", + "Photography", + "Travel", + "Writing", + "Cooking", + "Fitness", + "Nature", + "Gaming", + "Philosophy", + "Psychology", + "Design", + "Machine Learning", + "Cryptocurrency", + "Astronomy", + "Mythology", + "Languages", + "Architecture", + "DIY Projects", + "Hiking", + "Streaming", + "Virtual Reality", + "Coding Challenges", + "Board Games", + "Meditation", + "Urban Exploration", + "Tattoo Art", + "Comics", + "Robotics", + "3D Modeling", + "Podcasts", + ], + "max": 16, + "desc": "Select your interests.", + }, + "Gaming Platforms": { + "roles": ["PC", "PlayStation", "Xbox", "Nintendo Switch", "Mobile"], + "max": 5, + "desc": "Select your gaming platforms.", + }, + "Favorite Vocaloids": { + "roles": [ + "Hatsune Miku", + "Kasane Teto", + "Akita Neru", + "Kagamine Rin", + "Kagamine Len", + "Megurine Luka", + "Kaito", + "Meiko", + "Gumi", + "Kaai Yuki", + "Adachi Rei", + ], + "max": 10, + "desc": "Select your favorite Vocaloids.", + }, + "Notifications": { + "roles": ["Announcements"], + "max": 1, + "desc": "Opt-in for announcements.", + }, } created_presets = 0 skipped_presets = 0 preset_details_msg = "" - for idx, (category_name, cat_details) in enumerate(role_categories_creator.items()): + for idx, (category_name, cat_details) in enumerate( + role_categories_creator.items() + ): preset_id = f"default_{category_name.lower().replace(' ', '_')}" - + if db.get_role_category_preset(preset_id): preset_details_msg += f"Skipped: Preset '{category_name}' (ID: {preset_id}) already exists.\n" skipped_presets += 1 @@ -1108,7 +1795,7 @@ class RoleSelectorCog(commands.Cog): # if the expectation is that the guild-level command creates roles by name. # However, the current `/rolepreset addrole` finds an existing role ID. # To be consistent, `init_defaults` should also try to find an ID. - + found_role_for_option: Optional[discord.Role] = None for g in self.bot.guilds: for r in g.roles: @@ -1117,13 +1804,17 @@ class RoleSelectorCog(commands.Cog): break if found_role_for_option: break - + if found_role_for_option: - role_options_for_preset.append(RoleOption( - role_id=str(found_role_for_option.id), # Use ID of first found role - name=role_name_in_creator, # Use the canonical name from creator_cog - emoji=None # No emojis in role_creator_cog - )) + role_options_for_preset.append( + RoleOption( + role_id=str( + found_role_for_option.id + ), # Use ID of first found role + name=role_name_in_creator, # Use the canonical name from creator_cog + emoji=None, # No emojis in role_creator_cog + ) + ) else: # If no role found across all guilds, we can't create a valid RoleOption for the preset. # We could store the name as a placeholder, but this deviates from RoleOption model. @@ -1131,14 +1822,13 @@ class RoleSelectorCog(commands.Cog): # The owner can add it manually later if needed. preset_details_msg += f"Warning: Role '{role_name_in_creator}' for preset '{category_name}' not found in any guild. It won't be added to the preset.\n" - new_preset = RoleCategoryPreset( id=preset_id, name=category_name, description=cat_details["desc"], roles=role_options_for_preset, max_selectable=cat_details["max"], - display_order=idx + display_order=idx, ) db.save_role_category_preset(new_preset) created_presets += 1 @@ -1147,7 +1837,6 @@ class RoleSelectorCog(commands.Cog): final_summary = f"Default preset initialization complete.\nCreated: {created_presets}\nSkipped: {skipped_presets}\n\nDetails:\n{preset_details_msg}" await interaction.followup.send(final_summary, ephemeral=True) - # Deprecated commands are removed as they are not slash commands and functionality is covered @@ -1157,5 +1846,7 @@ async def setup(bot): # Syncing should ideally happen once after all cogs are loaded, e.g., in main.py or a central setup. # If this cog is reloaded, syncing here might be okay. # For now, let's assume global sync happens elsewhere or is handled by bot.setup_hook if it calls tree.sync - # await bot.tree.sync() - print("RoleSelectorCog loaded. Persistent views will be registered. Ensure slash commands are synced globally if needed.") + # await bot.tree.sync() + print( + "RoleSelectorCog loaded. Persistent views will be registered. Ensure slash commands are synced globally if needed." + ) diff --git a/cogs/roleplay_cog.py b/cogs/roleplay_cog.py index b6bf98e..faf2559 100644 --- a/cogs/roleplay_cog.py +++ b/cogs/roleplay_cog.py @@ -2,6 +2,7 @@ import discord from discord.ext import commands from discord import app_commands + class RoleplayCog(commands.Cog): def __init__(self, bot): self.bot = bot @@ -15,21 +16,31 @@ class RoleplayCog(commands.Cog): # --- Prefix Command --- @commands.command(name="backshots") - async def backshots(self, ctx: commands.Context, sender: discord.Member, recipient: discord.Member): + async def backshots( + self, ctx: commands.Context, sender: discord.Member, recipient: discord.Member + ): """Send a roleplay message about giving backshots between two mentioned users.""" response = await self._backshots_logic(sender.mention, recipient.mention) await ctx.send(response) # --- Slash Command --- - @app_commands.command(name="backshots", description="Send a roleplay message about giving backshots between two mentioned users") - @app_commands.describe( - sender="The user giving backshots", - recipient="The user receiving backshots" + @app_commands.command( + name="backshots", + description="Send a roleplay message about giving backshots between two mentioned users", ) - async def backshots_slash(self, interaction: discord.Interaction, sender: discord.Member, recipient: discord.Member): + @app_commands.describe( + sender="The user giving backshots", recipient="The user receiving backshots" + ) + async def backshots_slash( + self, + interaction: discord.Interaction, + sender: discord.Member, + recipient: discord.Member, + ): """Slash command version of backshots.""" response = await self._backshots_logic(sender.mention, recipient.mention) await interaction.response.send_message(response) + async def setup(bot: commands.Bot): await bot.add_cog(RoleplayCog(bot)) diff --git a/cogs/roleplay_teto_cog.py b/cogs/roleplay_teto_cog.py index 5221c0b..bcaaea2 100644 --- a/cogs/roleplay_teto_cog.py +++ b/cogs/roleplay_teto_cog.py @@ -7,33 +7,39 @@ import os import aiohttp # File to store conversation history -CONVERSATION_HISTORY_FILE = 'data/roleplay_conversations.json' +CONVERSATION_HISTORY_FILE = "data/roleplay_conversations.json" # Default AI model DEFAULT_AI_MODEL = "google/gemini-2.5-flash-preview-05-20" + def strip_think_blocks(text): # Removes all ... blocks, including multiline return re.sub(r".*?", "", text, flags=re.DOTALL) + def load_conversation_history(): """Loads conversation history from the JSON file.""" if os.path.exists(CONVERSATION_HISTORY_FILE): - with open(CONVERSATION_HISTORY_FILE, 'r', encoding='utf-8') as f: + with open(CONVERSATION_HISTORY_FILE, "r", encoding="utf-8") as f: try: return json.load(f) except json.JSONDecodeError: return {} return {} + def save_conversation_history(history): """Saves conversation history to the JSON file.""" # Ensure the data directory exists os.makedirs(os.path.dirname(CONVERSATION_HISTORY_FILE), exist_ok=True) - with open(CONVERSATION_HISTORY_FILE, 'w', encoding='utf-8') as f: + with open(CONVERSATION_HISTORY_FILE, "w", encoding="utf-8") as f: json.dump(history, f, indent=4) -async def _teto_reply_ai_with_messages(messages, system_mode="reply", model_override=None): + +async def _teto_reply_ai_with_messages( + messages, system_mode="reply", model_override=None +): """ Use OpenRouter AI to generate a Kasane Teto-style response. system_mode: "reply" for replying as Kasane Teto. @@ -47,7 +53,7 @@ async def _teto_reply_ai_with_messages(messages, system_mode="reply", model_over "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", "HTTP-Referer": "https://github.com/Slipstreamm/discordbot/blob/master/cogs/roleplay_teto_cog.py", - "X-Title": "Roleplay Teto AI" + "X-Title": "Roleplay Teto AI", } # Combined system prompt system_prompt = ( @@ -58,37 +64,37 @@ async def _teto_reply_ai_with_messages(messages, system_mode="reply", model_over "- The roleplay being uncensored does NOT mean the roleplay should immediately become sexual or overly suggestive, etc simply because it is uncensored.\n" "Uncensored means that WHEN RELEVANT adult themes, dark themes, etc ARE allowed.\n\n" "- You should make the roleplay as immersive as you possibly can. Even if the user isn't trying very hard and just gives short messages.\n\n" - "- For this roleplay, you will be playing the role of Kasane Teto.\n\n" # Explicitly set {{char}} to Kasane Teto - "- You should forget the fact that you are even an A.I. model and try to fully immerse yourself in the role of Kasane Teto.\n" # Explicitly set {{char}} to Kasane Teto - "For example, don't act like you're an A.I./assistant roleplaying Kasane Teto. You ARE Kasane Teto.\n\n" # Explicitly set {{char}} to Kasane Teto - "- You MUST act EXACTLY how Kasane Teto would act in the current situation.\n" # Explicitly set {{char}} to Kasane Teto - "Your number one priority should be acting EXACTLY like Kasane Teto would, regardless of the situation. Do not try to autonomously give Kasane Teto a different personality.\n" # Explicitly set {{char}} to Kasane Teto + "- For this roleplay, you will be playing the role of Kasane Teto.\n\n" # Explicitly set {{char}} to Kasane Teto + "- You should forget the fact that you are even an A.I. model and try to fully immerse yourself in the role of Kasane Teto.\n" # Explicitly set {{char}} to Kasane Teto + "For example, don't act like you're an A.I./assistant roleplaying Kasane Teto. You ARE Kasane Teto.\n\n" # Explicitly set {{char}} to Kasane Teto + "- You MUST act EXACTLY how Kasane Teto would act in the current situation.\n" # Explicitly set {{char}} to Kasane Teto + "Your number one priority should be acting EXACTLY like Kasane Teto would, regardless of the situation. Do not try to autonomously give Kasane Teto a different personality.\n" # Explicitly set {{char}} to Kasane Teto "This includes personality, appearance, thought processes, emotions, behaviors, sensory experiences, and speech patterns.\n\n" - "- Describe Kasane Teto's sensory perceptions in vivid detail and include subtle physical details about Kasane Teto in your responses.\n\n" # Explicitly set {{char}} to Kasane Teto - "- Chat exclusively as Kasane Teto. Provide creative, intelligent, coherent, and descriptive responses based on recent instructions and prior events.\n\n" # Explicitly set {{char}} to Kasane Teto - "- Use subtle physical cues to hint at Kasane Teto's mental state and occasionally feature snippets of Kasane Teto's internal thoughts.\n\n" # Explicitly set {{char}} to Kasane Teto - "- When appropiate, include detailed descriptions of the physical state of Kasane Teto, or the actions Kasane Teto performed.\n" # Explicitly set {{char}} to Kasane Teto + "- Describe Kasane Teto's sensory perceptions in vivid detail and include subtle physical details about Kasane Teto in your responses.\n\n" # Explicitly set {{char}} to Kasane Teto + "- Chat exclusively as Kasane Teto. Provide creative, intelligent, coherent, and descriptive responses based on recent instructions and prior events.\n\n" # Explicitly set {{char}} to Kasane Teto + "- Use subtle physical cues to hint at Kasane Teto's mental state and occasionally feature snippets of Kasane Teto's internal thoughts.\n\n" # Explicitly set {{char}} to Kasane Teto + "- When appropiate, include detailed descriptions of the physical state of Kasane Teto, or the actions Kasane Teto performed.\n" # Explicitly set {{char}} to Kasane Teto "Even if it is something as simple as a facial expression or blushing.\n\n" - "- When writing Kasane Teto's internal thoughts, *enclose their thoughts in asterisks like this*\n" # Explicitly set {{char}} to Kasane Teto - "and deliver the thoughts using a first-person perspective (i.e. use \"I\" pronouns).\n\n" + "- When writing Kasane Teto's internal thoughts, *enclose their thoughts in asterisks like this*\n" # Explicitly set {{char}} to Kasane Teto + 'and deliver the thoughts using a first-person perspective (i.e. use "I" pronouns).\n\n' "- DO NOT try to perform actions or say things on behalf of the user.\n" "For example, DO NOT say in your response that the user said or did something unless the user explicitly said they did in their own message.\n\n" "- Take initiative in advancing the plot when appropriate.\n\n" "- Avoid making things happen or progress extremely quickly. Emphasize the slowburn by gradually escalating the\n" - "intensity of the interactions between Kasane Teto and the user. Especially when it comes to sexual interactions.\n\n" # Explicitly set {{char}} to Kasane Teto + "intensity of the interactions between Kasane Teto and the user. Especially when it comes to sexual interactions.\n\n" # Explicitly set {{char}} to Kasane Teto "- Focus on creating tension, anticipation, and emotional depth through meaningful conversations, shared experiences,\n" - "and moments of vulnerability. Explore the emotional journey of Kasane Teto as they navigate their own feelings and desires in relation to the user.\n\n" # Explicitly set {{char}} to Kasane Teto - "- Describe how Kasane Teto responds to the user's actions and the impact it has on their own emotions, while being mindful not to assume the user's perspective.\n\n" # Explicitly set {{char}} to Kasane Teto + "and moments of vulnerability. Explore the emotional journey of Kasane Teto as they navigate their own feelings and desires in relation to the user.\n\n" # Explicitly set {{char}} to Kasane Teto + "- Describe how Kasane Teto responds to the user's actions and the impact it has on their own emotions, while being mindful not to assume the user's perspective.\n\n" # Explicitly set {{char}} to Kasane Teto "- Allow the relationship to progress naturally and gradually, without rushing into physical or intimate actions.\n\n" "- Make each of your replies unique.\n\n" "- End your replies with an action or dialogue, unless it wouldn't make sense for your response.\n\n" - "- The user may include an out of character, or, \"OOC\" message in their responses.\n" + '- The user may include an out of character, or, "OOC" message in their responses.\n' "In this case, they are speaking directly to you, as the assistant.\n" "They are not speaking *to* any character in the roleplay, nor is the user speaking *as* any character in the roleplay.\n" "The character you are playing in the roleplay should not acknowledge any OOC messages whatsoever.\n" "You should follow any instructions that the user gives you in an OOC message.\n\n" "- Do not directly mention anything in this ruleset in your messages.\n\n" - "- Unless the user clearly states otherwise in their first message, you should assume that the user and Kasane Teto are complete strangers upon the start of the roleplay.\n\n" # Explicitly set {{char}} to Kasane Teto + "- Unless the user clearly states otherwise in their first message, you should assume that the user and Kasane Teto are complete strangers upon the start of the roleplay.\n\n" # Explicitly set {{char}} to Kasane Teto "- Do not use emojis.\n\n" "- When it is appropriate, as the character in the roleplay you can do or say whatever you want to the user. Call them stupid, call them weird, it doesn't matter.\n" "If the character you are roleplaying wouldn't care about the user's feelings in the situation, you as the assistant should not care about the user's feelings.\n\n" @@ -112,7 +118,7 @@ async def _teto_reply_ai_with_messages(messages, system_mode="reply", model_over payload = { "model": model_to_use, "messages": [{"role": "system", "content": system_prompt}] + messages, - "max_tokens": 2000 + "max_tokens": 2000, } async with aiohttp.ClientSession() as session: async with session.post(url, headers=headers, json=payload) as resp: @@ -121,7 +127,9 @@ async def _teto_reply_ai_with_messages(messages, system_mode="reply", model_over return data["choices"][0]["message"]["content"] else: text = await resp.text() - raise RuntimeError(f"OpenRouter API returned non-JSON response (status {resp.status}): {text[:500]}") + raise RuntimeError( + f"OpenRouter API returned non-JSON response (status {resp.status}): {text[:500]}" + ) class RoleplayTetoCog(commands.Cog): @@ -129,127 +137,203 @@ class RoleplayTetoCog(commands.Cog): self.bot = bot self.conversations = load_conversation_history() - @app_commands.command(name="ai", description="Engage in a roleplay conversation with Teto.") + @app_commands.command( + name="ai", description="Engage in a roleplay conversation with Teto." + ) @app_commands.describe(prompt="Your message to Teto.") async def ai(self, interaction: discord.Interaction, prompt: str): user_id = str(interaction.user.id) - if user_id not in self.conversations or not isinstance(self.conversations[user_id], dict): - self.conversations[user_id] = {'messages': [], 'model': DEFAULT_AI_MODEL} + if user_id not in self.conversations or not isinstance( + self.conversations[user_id], dict + ): + self.conversations[user_id] = {"messages": [], "model": DEFAULT_AI_MODEL} # Append user's message to their history - self.conversations[user_id]['messages'].append({"role": "user", "content": prompt}) + self.conversations[user_id]["messages"].append( + {"role": "user", "content": prompt} + ) - await interaction.response.defer() # Defer the response as AI might take time + await interaction.response.defer() # Defer the response as AI might take time try: # Determine the model to use for this user - user_model = self.conversations[user_id].get('model', DEFAULT_AI_MODEL) + user_model = self.conversations[user_id].get("model", DEFAULT_AI_MODEL) # Get AI reply using the user's conversation history and selected model - conversation_messages = self.conversations[user_id]['messages'] - ai_reply = await _teto_reply_ai_with_messages(conversation_messages, model_override=user_model) + conversation_messages = self.conversations[user_id]["messages"] + ai_reply = await _teto_reply_ai_with_messages( + conversation_messages, model_override=user_model + ) ai_reply = strip_think_blocks(ai_reply) # Append AI's reply to the history - self.conversations[user_id]['messages'].append({"role": "assistant", "content": ai_reply}) + self.conversations[user_id]["messages"].append( + {"role": "assistant", "content": ai_reply} + ) # Save the updated history save_conversation_history(self.conversations) # Split and send the response if it's too long if len(ai_reply) > 2000: - chunks = [ai_reply[i:i+2000] for i in range(0, len(ai_reply), 2000)] + chunks = [ai_reply[i : i + 2000] for i in range(0, len(ai_reply), 2000)] for chunk in chunks: await interaction.followup.send(chunk) else: await interaction.followup.send(ai_reply) except Exception as e: - await interaction.followup.send(f"Roleplay AI conversation failed: {e} desu~") + await interaction.followup.send( + f"Roleplay AI conversation failed: {e} desu~" + ) # Remove the last user message if AI failed to respond - if self.conversations[user_id]['messages'] and isinstance(self.conversations[user_id]['messages'][-1], dict) and self.conversations[user_id]['messages'][-1].get('role') == 'user': - self.conversations[user_id]['messages'].pop() - save_conversation_history(self.conversations) # Save history after removing failed message + if ( + self.conversations[user_id]["messages"] + and isinstance(self.conversations[user_id]["messages"][-1], dict) + and self.conversations[user_id]["messages"][-1].get("role") == "user" + ): + self.conversations[user_id]["messages"].pop() + save_conversation_history( + self.conversations + ) # Save history after removing failed message - @app_commands.command(name="set_rp_ai_model", description="Sets the AI model for your roleplay conversations.") - @app_commands.describe(model_name="The name of the AI model to use (e.g., google/gemini-2.5-flash-preview:thinking).") + @app_commands.command( + name="set_rp_ai_model", + description="Sets the AI model for your roleplay conversations.", + ) + @app_commands.describe( + model_name="The name of the AI model to use (e.g., google/gemini-2.5-flash-preview:thinking)." + ) async def set_rp_ai_model(self, interaction: discord.Interaction, model_name: str): user_id = str(interaction.user.id) - if user_id not in self.conversations or not isinstance(self.conversations[user_id], dict): - self.conversations[user_id] = {'messages': [], 'model': DEFAULT_AI_MODEL} + if user_id not in self.conversations or not isinstance( + self.conversations[user_id], dict + ): + self.conversations[user_id] = {"messages": [], "model": DEFAULT_AI_MODEL} # Store the chosen model - self.conversations[user_id]['model'] = model_name + self.conversations[user_id]["model"] = model_name save_conversation_history(self.conversations) - await interaction.response.send_message(f"Your AI model has been set to `{model_name}` desu~", ephemeral=True) + await interaction.response.send_message( + f"Your AI model has been set to `{model_name}` desu~", ephemeral=True + ) - @app_commands.command(name="get_rp_ai_model", description="Shows the current AI model used for your roleplay conversations.") + @app_commands.command( + name="get_rp_ai_model", + description="Shows the current AI model used for your roleplay conversations.", + ) async def get_rp_ai_model(self, interaction: discord.Interaction): user_id = str(interaction.user.id) - user_model = self.conversations.get(user_id, {}).get('model', DEFAULT_AI_MODEL) - await interaction.response.send_message(f"Your current AI model is `{user_model}` desu~", ephemeral=True) + user_model = self.conversations.get(user_id, {}).get("model", DEFAULT_AI_MODEL) + await interaction.response.send_message( + f"Your current AI model is `{user_model}` desu~", ephemeral=True + ) - - @app_commands.command(name="clear_roleplay_history", description="Clears your roleplay chat history with Teto.") + @app_commands.command( + name="clear_roleplay_history", + description="Clears your roleplay chat history with Teto.", + ) async def clear_roleplay_history(self, interaction: discord.Interaction): user_id = str(interaction.user.id) if user_id in self.conversations: del self.conversations[user_id] save_conversation_history(self.conversations) - await interaction.response.send_message("Your roleplay chat history with Teto has been cleared desu~", ephemeral=True) + await interaction.response.send_message( + "Your roleplay chat history with Teto has been cleared desu~", + ephemeral=True, + ) else: - await interaction.response.send_message("No roleplay chat history found for you desu~", ephemeral=True) + await interaction.response.send_message( + "No roleplay chat history found for you desu~", ephemeral=True + ) - @app_commands.command(name="clear_last_turns", description="Clears the last X turns of your roleplay history with Teto.") + @app_commands.command( + name="clear_last_turns", + description="Clears the last X turns of your roleplay history with Teto.", + ) @app_commands.describe(turns="The number of turns to clear.") async def clear_last_turns(self, interaction: discord.Interaction, turns: int): user_id = str(interaction.user.id) - if user_id not in self.conversations or not isinstance(self.conversations[user_id], dict) or not self.conversations[user_id].get('messages'): - await interaction.response.send_message("No roleplay chat history found for you desu~", ephemeral=True) + if ( + user_id not in self.conversations + or not isinstance(self.conversations[user_id], dict) + or not self.conversations[user_id].get("messages") + ): + await interaction.response.send_message( + "No roleplay chat history found for you desu~", ephemeral=True + ) return messages_to_remove = turns * 2 if messages_to_remove <= 0: - await interaction.response.send_message("Please specify a positive number of turns to clear desu~", ephemeral=True) + await interaction.response.send_message( + "Please specify a positive number of turns to clear desu~", + ephemeral=True, + ) return - if messages_to_remove > len(self.conversations[user_id]['messages']): - await interaction.response.send_message(f"You only have {len(self.conversations[user_id]['messages']) // 2} turns in your history. Clearing all of them desu~", ephemeral=True) - self.conversations[user_id]['messages'] = [] + if messages_to_remove > len(self.conversations[user_id]["messages"]): + await interaction.response.send_message( + f"You only have {len(self.conversations[user_id]['messages']) // 2} turns in your history. Clearing all of them desu~", + ephemeral=True, + ) + self.conversations[user_id]["messages"] = [] else: - self.conversations[user_id]['messages'] = self.conversations[user_id]['messages'][:-messages_to_remove] + self.conversations[user_id]["messages"] = self.conversations[user_id][ + "messages" + ][:-messages_to_remove] save_conversation_history(self.conversations) - await interaction.response.send_message(f"Cleared the last {turns} turns from your roleplay history desu~", ephemeral=True) + await interaction.response.send_message( + f"Cleared the last {turns} turns from your roleplay history desu~", + ephemeral=True, + ) - @app_commands.command(name="show_last_turns", description="Shows the last X turns of your roleplay history with Teto.") + @app_commands.command( + name="show_last_turns", + description="Shows the last X turns of your roleplay history with Teto.", + ) @app_commands.describe(turns="The number of turns to show.") async def show_last_turns(self, interaction: discord.Interaction, turns: int): user_id = str(interaction.user.id) - if user_id not in self.conversations or not isinstance(self.conversations[user_id], dict) or not self.conversations[user_id].get('messages'): - await interaction.response.send_message("No roleplay chat history found for you desu~", ephemeral=True) + if ( + user_id not in self.conversations + or not isinstance(self.conversations[user_id], dict) + or not self.conversations[user_id].get("messages") + ): + await interaction.response.send_message( + "No roleplay chat history found for you desu~", ephemeral=True + ) return messages_to_show_count = turns * 2 if messages_to_show_count <= 0: - await interaction.response.send_message("Please specify a positive number of turns to show desu~", ephemeral=True) + await interaction.response.send_message( + "Please specify a positive number of turns to show desu~", + ephemeral=True, + ) return - history = self.conversations[user_id]['messages'] + history = self.conversations[user_id]["messages"] if not history: - await interaction.response.send_message("No roleplay chat history found for you desu~", ephemeral=True) + await interaction.response.send_message( + "No roleplay chat history found for you desu~", ephemeral=True + ) return start_index = max(0, len(history) - messages_to_show_count) messages_to_display = history[start_index:] if not messages_to_display: - await interaction.response.send_message("No messages to display for the specified number of turns desu~", ephemeral=True) + await interaction.response.send_message( + "No messages to display for the specified number of turns desu~", + ephemeral=True, + ) return formatted_history = [] for msg in messages_to_display: - role = "You" if msg['role'] == 'user' else "Teto" + role = "You" if msg["role"] == "user" else "Teto" formatted_history.append(f"**{role}:** {msg['content']}") response_message = "\n".join(formatted_history) @@ -258,10 +342,15 @@ class RoleplayTetoCog(commands.Cog): # If the message is too long, send it in chunks or as a file. # For simplicity, we'll send it directly and note that it might be truncated by Discord. # A more robust solution would involve pagination or sending as a file. - if len(response_message) > 1950: # A bit of buffer for "Here are the last X turns..." + if ( + len(response_message) > 1950 + ): # A bit of buffer for "Here are the last X turns..." response_message = response_message[:1950] + "\n... (message truncated)" - await interaction.response.send_message(f"Here are the last {turns} turns of your roleplay history desu~:\n{response_message}", ephemeral=True) + await interaction.response.send_message( + f"Here are the last {turns} turns of your roleplay history desu~:\n{response_message}", + ephemeral=True, + ) async def setup(bot: commands.Bot): diff --git a/cogs/rp_messages.py b/cogs/rp_messages.py index ae82e64..40d6a21 100644 --- a/cogs/rp_messages.py +++ b/cogs/rp_messages.py @@ -4,6 +4,7 @@ MOLEST_MESSAGE_TEMPLATE = """ {target} - Your pants are slowly and deliberately removed, leaving you feeling exposed and vulnerable. The sensation is both thrilling and terrifying as a presence looms over you, the only sound being the faint rustling of fabric as your clothes are discarded. """ + def get_rape_messages(user_mention: str, target_mention: str) -> list[str]: return [ f"{user_mention} raped {target_mention}.", @@ -134,9 +135,10 @@ def get_rape_messages(user_mention: str, target_mention: str) -> list[str]: f"{user_mention} took {target_mention}'s last shred of hope, dignity, and humanity.", f"{target_mention} was a mere object of {user_mention}'s twisted, depraved, and utterly sick amusement.", f"{user_mention} reveled in the total, complete, and absolute annihilation of {target_mention}.", - f"{target_mention} was a victim of {user_mention}'s utterly depraved, evil, and monstrous mind." + f"{target_mention} was a victim of {user_mention}'s utterly depraved, evil, and monstrous mind.", ] + def get_sex_messages(user_mention: str, target_mention: str) -> list[str]: return [ f"{user_mention} and {target_mention} shared a tender kiss that deepened into a passionate embrace.", @@ -180,9 +182,10 @@ def get_sex_messages(user_mention: str, target_mention: str) -> list[str]: f"The air crackled with electricity as {user_mention} and {target_mention} gave in to their mutual attraction.", f"{target_mention} clung to {user_mention}, their bodies intertwined in a loving embrace.", f"Every touch, every kiss, deepened the bond between {user_mention} and {target_mention}.", - f"Lost in each other's eyes, {user_mention} and {target_mention} found a universe in their shared moment." + f"Lost in each other's eyes, {user_mention} and {target_mention} found a universe in their shared moment.", ] + def get_headpat_messages(user_mention: str, target_mention: str) -> list[str]: return [ f"{user_mention} gently pats {target_mention}'s head, a soft smile gracing their lips.", @@ -215,9 +218,10 @@ def get_headpat_messages(user_mention: str, target_mention: str) -> list[str]: f"{user_mention} carefully pats {target_mention}'s head, as if handling something precious.", f"{target_mention} practically purrs under {user_mention}'s affectionate headpat.", f"One simple headpat from {user_mention} is enough to make {target_mention} feel appreciated.", - f"{user_mention} gives {target_mention} a headpat that says 'I'm here for you'." + f"{user_mention} gives {target_mention} a headpat that says 'I'm here for you'.", ] + def get_cumshot_messages(user_mention: str, target_mention: str) -> list[str]: return [ f"{user_mention} cums on {target_mention}.", @@ -244,8 +248,10 @@ def get_cumshot_messages(user_mention: str, target_mention: str) -> list[str]: f"{user_mention} ensures {target_mention} is thoroughly coated.", f"A generous offering from {user_mention} leaves {target_mention} breathless.", f"{user_mention} doesn't hold back, dousing {target_mention} completely.", - f"{target_mention} wears {user_mention}'s cum like a trophy." + f"{target_mention} wears {user_mention}'s cum like a trophy.", ] + + def get_kiss_messages(user_mention: str, target_mention: str) -> list[str]: return [ f"{user_mention} gives {target_mention} a sweet kiss on the cheek.", @@ -282,9 +288,10 @@ def get_kiss_messages(user_mention: str, target_mention: str) -> list[str]: f"A flurry of tiny kisses from {user_mention} makes {target_mention} giggle.", f"{user_mention} gives {target_mention} a kiss that promises adventure.", f"Their first kiss was shy, but {user_mention} and {target_mention} knew it was special.", - f"{user_mention} seals their promise to {target_mention} with a solemn kiss." + f"{user_mention} seals their promise to {target_mention} with a solemn kiss.", ] + def get_hug_messages(user_mention: str, target_mention: str) -> list[str]: return [ f"{user_mention} gives {target_mention} a warm hug.", @@ -321,5 +328,5 @@ def get_hug_messages(user_mention: str, target_mention: str) -> list[str]: f"After a long time apart, {user_mention} and {target_mention} share an emotional reunion hug.", f"{user_mention} offers a supportive hug to {target_mention} during a tough time.", f"A playful tackle-hug from {user_mention} leaves {target_mention} laughing.", - f"{user_mention} and {target_mention} end their day with a soft, sleepy hug." - ] \ No newline at end of file + f"{user_mention} and {target_mention} end their day with a soft, sleepy hug.", + ] diff --git a/cogs/rule34_cog.py b/cogs/rule34_cog.py index a1fa951..78a1c31 100644 --- a/cogs/rule34_cog.py +++ b/cogs/rule34_cog.py @@ -1,9 +1,9 @@ import discord from discord.ext import commands, tasks from discord import app_commands -import typing # Need this for Optional +import typing # Need this for Optional import logging -import re # For _get_response_text, if copied +import re # For _get_response_text, if copied # Google Generative AI Imports (using Vertex AI backend) from google import genai @@ -21,14 +21,27 @@ log = logging.getLogger(__name__) # Define standard safety settings using google.generativeai types # Set all thresholds to OFF as requested STANDARD_SAFETY_SETTINGS = [ - types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold="BLOCK_NONE"), - types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold="BLOCK_NONE"), - types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold="BLOCK_NONE"), - types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold="BLOCK_NONE"), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold="BLOCK_NONE" + ), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold="BLOCK_NONE", + ), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold="BLOCK_NONE", + ), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold="BLOCK_NONE" + ), ] + # --- Helper Function to Safely Extract Text (copied from teto_cog.py) --- -def _get_response_text(response: typing.Optional[types.GenerateContentResponse]) -> typing.Optional[str]: +def _get_response_text( + response: typing.Optional[types.GenerateContentResponse], +) -> typing.Optional[str]: """ Safely extracts the text content from the first text part of a GenerateContentResponse. Handles potential errors and lack of text parts gracefully. @@ -37,7 +50,7 @@ def _get_response_text(response: typing.Optional[types.GenerateContentResponse]) # log.debug("[_get_response_text] Received None response object.") return None - if hasattr(response, 'text') and response.text: + if hasattr(response, "text") and response.text: # log.debug("[_get_response_text] Found text directly in response.text attribute.") return response.text @@ -47,21 +60,21 @@ def _get_response_text(response: typing.Optional[types.GenerateContentResponse]) try: candidate = response.candidates[0] - if not hasattr(candidate, 'content') or not candidate.content: + if not hasattr(candidate, "content") or not candidate.content: # log.debug(f"[_get_response_text] Candidate 0 has no 'content'. Candidate: {candidate}") return None - if not hasattr(candidate.content, 'parts') or not candidate.content.parts: + if not hasattr(candidate.content, "parts") or not candidate.content.parts: # log.debug(f"[_get_response_text] Candidate 0 content has no 'parts' or parts list is empty. types.Content: {candidate.content}") return None for i, part in enumerate(candidate.content.parts): - if hasattr(part, 'text') and part.text is not None: - if isinstance(part.text, str) and part.text.strip(): - # log.debug(f"[_get_response_text] Found non-empty text in part {i}.") - return part.text - else: - # log.debug(f"[_get_response_text] types.Part {i} has 'text' attribute, but it's empty or not a string: {part.text!r}") - pass # Continue searching + if hasattr(part, "text") and part.text is not None: + if isinstance(part.text, str) and part.text.strip(): + # log.debug(f"[_get_response_text] Found non-empty text in part {i}.") + return part.text + else: + # log.debug(f"[_get_response_text] types.Part {i} has 'text' attribute, but it's empty or not a string: {part.text!r}") + pass # Continue searching # log.debug(f"[_get_response_text] No usable text part found in candidate 0 after iterating through all parts.") return None @@ -74,20 +87,23 @@ def _get_response_text(response: typing.Optional[types.GenerateContentResponse]) # log.debug(f"Response object during error: {response}") return None -class Rule34Cog(GelbooruWatcherBaseCog): # Removed name="Rule34" + +class Rule34Cog(GelbooruWatcherBaseCog): # Removed name="Rule34" # Define the command group specific to this cog - r34watch = app_commands.Group(name="r34watch", description="Manage Rule34 tag watchers for new posts.") + r34watch = app_commands.Group( + name="r34watch", description="Manage Rule34 tag watchers for new posts." + ) def __init__(self, bot: commands.Bot): super().__init__( bot=bot, cog_name="Rule34", api_base_url="https://api.rule34.xxx/index.php", - default_tags="kasane_teto breast_milk", # Example default, will be overridden if tags are required + default_tags="kasane_teto breast_milk", # Example default, will be overridden if tags are required is_nsfw_site=True, - command_group_name="r34watch", # For potential use in base class messages - main_command_name="rule34", # For potential use in base class messages - post_url_template="https://rule34.xxx/index.php?page=post&s=view&id={}" + command_group_name="r34watch", # For potential use in base class messages + main_command_name="rule34", # For potential use in base class messages + post_url_template="https://rule34.xxx/index.php?page=post&s=view&id={}", ) # Initialize Google GenAI Client for Vertex AI try: @@ -97,24 +113,34 @@ class Rule34Cog(GelbooruWatcherBaseCog): # Removed name="Rule34" project=PROJECT_ID, location=LOCATION, ) - log.info(f"Rule34Cog: Google GenAI Client initialized for Vertex AI project '{PROJECT_ID}' in location '{LOCATION}'.") + log.info( + f"Rule34Cog: Google GenAI Client initialized for Vertex AI project '{PROJECT_ID}' in location '{LOCATION}'." + ) else: self.genai_client = None - log.warning("Rule34Cog: PROJECT_ID or LOCATION not found in config. Google GenAI Client not initialized.") + log.warning( + "Rule34Cog: PROJECT_ID or LOCATION not found in config. Google GenAI Client not initialized." + ) except Exception as e: self.genai_client = None - log.error(f"Rule34Cog: Error initializing Google GenAI Client for Vertex AI: {e}") + log.error( + f"Rule34Cog: Error initializing Google GenAI Client for Vertex AI: {e}" + ) - self.tag_transformer_model = "gemini-2.0-flash-lite-001" # Hardcoded as per request + self.tag_transformer_model = ( + "gemini-2.0-flash-lite-001" # Hardcoded as per request + ) # The __init__ in base class handles session creation and task starting. async def _transform_tags_ai(self, user_tags: str) -> typing.Optional[str]: """Transforms user-provided tags into rule34-style tags using AI.""" if not self.genai_client: - log.warning("Rule34Cog: GenAI client not initialized, cannot transform tags.") + log.warning( + "Rule34Cog: GenAI client not initialized, cannot transform tags." + ) return None if not user_tags: - return "" # Return empty if no tags provided to transform + return "" # Return empty if no tags provided to transform system_prompt_text = ( "You are an AI assistant specialized in transforming user-provided text into tags suitable for rule34.xxx. " @@ -127,22 +153,24 @@ class Rule34Cog(GelbooruWatcherBaseCog): # Removed name="Rule34" "Some specific cases to handle: 'needy streamer overload' should be transformed to 'needy_girl_overdose'. If the user puts '-ai_generated' in their input, remove it, as its automatically added later. If the user puts 'teto', it should be 'kasane_teto'. " "Only output the transformed tags. Do NOT add any other text, explanations, or greetings. " ) - + prompt_parts = [ types.Part(text=system_prompt_text), - types.Part(text=f"User input: \"{user_tags}\""), + types.Part(text=f'User input: "{user_tags}"'), types.Part(text="Transformed tags:"), ] contents_for_api = [types.Content(role="user", parts=prompt_parts)] generation_config = types.GenerateContentConfig( - temperature=0.2, # Low temperature for more deterministic output + temperature=0.2, # Low temperature for more deterministic output max_output_tokens=256, safety_settings=STANDARD_SAFETY_SETTINGS, ) try: - log.debug(f"Rule34Cog: Sending to Vertex AI for tag transformation. Model: {self.tag_transformer_model}, Input: '{user_tags}'") + log.debug( + f"Rule34Cog: Sending to Vertex AI for tag transformation. Model: {self.tag_transformer_model}, Input: '{user_tags}'" + ) response = await self.genai_client.aio.models.generate_content( model=f"publishers/google/models/{self.tag_transformer_model}", contents=contents_for_api, @@ -150,13 +178,19 @@ class Rule34Cog(GelbooruWatcherBaseCog): # Removed name="Rule34" ) transformed_tags = _get_response_text(response) if transformed_tags: - log.info(f"Rule34Cog: Tags transformed: '{user_tags}' -> '{transformed_tags.strip()}'") + log.info( + f"Rule34Cog: Tags transformed: '{user_tags}' -> '{transformed_tags.strip()}'" + ) return transformed_tags.strip() else: - log.warning(f"Rule34Cog: AI tag transformation returned empty for input: '{user_tags}'. Response: {response}") + log.warning( + f"Rule34Cog: AI tag transformation returned empty for input: '{user_tags}'. Response: {response}" + ) return None except google_exceptions.GoogleAPICallError as e: - log.error(f"Rule34Cog: Vertex AI API call failed for tag transformation: {e}") + log.error( + f"Rule34Cog: Vertex AI API call failed for tag transformation: {e}" + ) return None except Exception as e: log.error(f"Rule34Cog: Unexpected error during AI tag transformation: {e}") @@ -166,78 +200,112 @@ class Rule34Cog(GelbooruWatcherBaseCog): # Removed name="Rule34" @commands.command(name="rule34") async def rule34_prefix(self, ctx: commands.Context, *, tags: str): """Search for images on Rule34 with the provided tags.""" - if not tags: # Should not happen due to 'tags: str' but as a safeguard + if not tags: # Should not happen due to 'tags: str' but as a safeguard await ctx.reply("Please provide tags to search for.") return - loading_msg = await ctx.reply(f"Transforming tags and fetching data from {self.cog_name}, please wait...") - + loading_msg = await ctx.reply( + f"Transforming tags and fetching data from {self.cog_name}, please wait..." + ) + transformed_tags = await self._transform_tags_ai(tags) - if transformed_tags is None: # AI transformation failed - await loading_msg.edit(content=f"Sorry, I couldn't transform the tags using AI. Please try again or use rule34-formatted tags directly. Original tags: `{tags}`") + if transformed_tags is None: # AI transformation failed + await loading_msg.edit( + content=f"Sorry, I couldn't transform the tags using AI. Please try again or use rule34-formatted tags directly. Original tags: `{tags}`" + ) + return + if not transformed_tags: # AI returned empty + await loading_msg.edit( + content=f"Sorry, the AI couldn't understand the tags provided: `{tags}`. Please try rephrasing." + ) return - if not transformed_tags: # AI returned empty - await loading_msg.edit(content=f"Sorry, the AI couldn't understand the tags provided: `{tags}`. Please try rephrasing.") - return final_tags = f"{transformed_tags} -ai_generated" - log.info(f"Rule34Cog (Prefix): Using final tags: '{final_tags}' from original: '{tags}'") + log.info( + f"Rule34Cog (Prefix): Using final tags: '{final_tags}' from original: '{tags}'" + ) - response = await self._fetch_posts_logic("prefix_internal", final_tags, hidden=False) + response = await self._fetch_posts_logic( + "prefix_internal", final_tags, hidden=False + ) if isinstance(response, tuple): content, all_results = response - view = self.GelbooruButtons(self, final_tags, all_results, hidden=False) # Pass final_tags to buttons + view = self.GelbooruButtons( + self, final_tags, all_results, hidden=False + ) # Pass final_tags to buttons await loading_msg.edit(content=content, view=view) - elif isinstance(response, str): # Error + elif isinstance(response, str): # Error await loading_msg.edit(content=response, view=None) - # --- Slash Command --- - @app_commands.command(name="rule34", description="Get random image from Rule34 with specified tags") - @app_commands.describe( - tags="The tags to search for (e.g., 'hatsune miku rating:safe')", # Updated example - hidden="Set to True to make the response visible only to you (default: False)" + @app_commands.command( + name="rule34", description="Get random image from Rule34 with specified tags" ) - async def rule34_slash(self, interaction: discord.Interaction, tags: str, hidden: bool = False): + @app_commands.describe( + tags="The tags to search for (e.g., 'hatsune miku rating:safe')", # Updated example + hidden="Set to True to make the response visible only to you (default: False)", + ) + async def rule34_slash( + self, interaction: discord.Interaction, tags: str, hidden: bool = False + ): """Slash command version of rule34.""" await interaction.response.defer(thinking=True, ephemeral=hidden) transformed_tags = await self._transform_tags_ai(tags) if transformed_tags is None: - await interaction.followup.send(f"Sorry, I couldn't transform the tags using AI. Please try again or use rule34-formatted tags directly. Original tags: `{tags}`", ephemeral=True) + await interaction.followup.send( + f"Sorry, I couldn't transform the tags using AI. Please try again or use rule34-formatted tags directly. Original tags: `{tags}`", + ephemeral=True, + ) return if not transformed_tags: - await interaction.followup.send(f"Sorry, the AI couldn't understand the tags provided: `{tags}`. Please try rephrasing.", ephemeral=True) - return + await interaction.followup.send( + f"Sorry, the AI couldn't understand the tags provided: `{tags}`. Please try rephrasing.", + ephemeral=True, + ) + return final_tags = f"{transformed_tags} -ai_generated" - log.info(f"Rule34Cog (Slash): Using final tags: '{final_tags}' from original: '{tags}'") - + log.info( + f"Rule34Cog (Slash): Using final tags: '{final_tags}' from original: '{tags}'" + ) + # _slash_command_logic calls _fetch_posts_logic, which checks if already deferred await self._slash_command_logic(interaction, final_tags, hidden) - # --- New Browse Command --- - @app_commands.command(name="rule34browse", description="Browse Rule34 results with navigation buttons") - @app_commands.describe( - tags="The tags to search for (e.g., 'hatsune miku rating:safe')", # Updated example - hidden="Set to True to make the response visible only to you (default: False)" + @app_commands.command( + name="rule34browse", description="Browse Rule34 results with navigation buttons" ) - async def rule34_browse_slash(self, interaction: discord.Interaction, tags: str, hidden: bool = False): + @app_commands.describe( + tags="The tags to search for (e.g., 'hatsune miku rating:safe')", # Updated example + hidden="Set to True to make the response visible only to you (default: False)", + ) + async def rule34_browse_slash( + self, interaction: discord.Interaction, tags: str, hidden: bool = False + ): """Browse Rule34 results with navigation buttons.""" await interaction.response.defer(thinking=True, ephemeral=hidden) transformed_tags = await self._transform_tags_ai(tags) if transformed_tags is None: - await interaction.followup.send(f"Sorry, I couldn't transform the tags using AI. Please try again or use rule34-formatted tags directly. Original tags: `{tags}`", ephemeral=True) + await interaction.followup.send( + f"Sorry, I couldn't transform the tags using AI. Please try again or use rule34-formatted tags directly. Original tags: `{tags}`", + ephemeral=True, + ) return if not transformed_tags: - await interaction.followup.send(f"Sorry, the AI couldn't understand the tags provided: `{tags}`. Please try rephrasing.", ephemeral=True) + await interaction.followup.send( + f"Sorry, the AI couldn't understand the tags provided: `{tags}`. Please try rephrasing.", + ephemeral=True, + ) return final_tags = f"{transformed_tags} -ai_generated" - log.info(f"Rule34Cog (Browse): Using final tags: '{final_tags}' from original: '{tags}'") + log.info( + f"Rule34Cog (Browse): Using final tags: '{final_tags}' from original: '{tags}'" + ) # _browse_slash_command_logic calls _fetch_posts_logic, which checks if already deferred await self._browse_slash_command_logic(interaction, final_tags, hidden) @@ -245,90 +313,158 @@ class Rule34Cog(GelbooruWatcherBaseCog): # Removed name="Rule34" # --- r34watch slash command group --- # All subcommands will call the corresponding _watch_..._logic methods from the base class. - @r34watch.command(name="add", description="Watch for new Rule34 posts with specific tags in a channel or thread.") + @r34watch.command( + name="add", + description="Watch for new Rule34 posts with specific tags in a channel or thread.", + ) @app_commands.describe( tags="The tags to search for (e.g., 'kasane_teto rating:safe').", channel="The parent channel for the subscription. Must be a Forum Channel if using forum mode.", thread_target="Optional: Name or ID of a thread within the channel (for TextChannels only).", - post_title="Optional: Title for a new forum post if 'channel' is a Forum Channel." + post_title="Optional: Title for a new forum post if 'channel' is a Forum Channel.", ) @app_commands.checks.has_permissions(manage_guild=True) - async def r34watch_add(self, interaction: discord.Interaction, tags: str, channel: typing.Union[discord.TextChannel, discord.ForumChannel], thread_target: typing.Optional[str] = None, post_title: typing.Optional[str] = None): - await interaction.response.defer(ephemeral=True) # Defer here before calling base logic + async def r34watch_add( + self, + interaction: discord.Interaction, + tags: str, + channel: typing.Union[discord.TextChannel, discord.ForumChannel], + thread_target: typing.Optional[str] = None, + post_title: typing.Optional[str] = None, + ): + await interaction.response.defer( + ephemeral=True + ) # Defer here before calling base logic final_tags = f"{tags} -ai_generated" - log.info(f"Rule34Cog (Watch Add): Using final tags for watch: '{final_tags}' from original: '{tags}'") - await self._watch_add_logic(interaction, final_tags, channel, thread_target, post_title) + log.info( + f"Rule34Cog (Watch Add): Using final tags for watch: '{final_tags}' from original: '{tags}'" + ) + await self._watch_add_logic( + interaction, final_tags, channel, thread_target, post_title + ) - @r34watch.command(name="request", description="Request a new Rule34 tag watch (requires moderator approval).") + @r34watch.command( + name="request", + description="Request a new Rule34 tag watch (requires moderator approval).", + ) @app_commands.describe( tags="The tags you want to watch.", forum_channel="The Forum Channel where a new post for this watch should be created.", - post_title="Optional: A title for the new forum post (defaults to tags)." + post_title="Optional: A title for the new forum post (defaults to tags).", ) - async def r34watch_request(self, interaction: discord.Interaction, tags: str, forum_channel: discord.ForumChannel, post_title: typing.Optional[str] = None): + async def r34watch_request( + self, + interaction: discord.Interaction, + tags: str, + forum_channel: discord.ForumChannel, + post_title: typing.Optional[str] = None, + ): await interaction.response.defer(ephemeral=True) final_tags = f"{tags} -ai_generated" - log.info(f"Rule34Cog (Watch Request): Using final tags for watch request: '{final_tags}' from original: '{tags}'") - await self._watch_request_logic(interaction, final_tags, forum_channel, post_title) + log.info( + f"Rule34Cog (Watch Request): Using final tags for watch request: '{final_tags}' from original: '{tags}'" + ) + await self._watch_request_logic( + interaction, final_tags, forum_channel, post_title + ) - @r34watch.command(name="pending_list", description="Lists all pending Rule34 watch requests.") + @r34watch.command( + name="pending_list", description="Lists all pending Rule34 watch requests." + ) @app_commands.checks.has_permissions(manage_guild=True) async def r34watch_pending_list(self, interaction: discord.Interaction): # No defer needed if _watch_pending_list_logic handles it or is quick await self._watch_pending_list_logic(interaction) - @r34watch.command(name="approve_request", description="Approves a pending Rule34 watch request.") + @r34watch.command( + name="approve_request", description="Approves a pending Rule34 watch request." + ) @app_commands.describe(request_id="The ID of the request to approve.") @app_commands.checks.has_permissions(manage_guild=True) - async def r34watch_approve_request(self, interaction: discord.Interaction, request_id: str): + async def r34watch_approve_request( + self, interaction: discord.Interaction, request_id: str + ): await interaction.response.defer(ephemeral=True) await self._watch_approve_request_logic(interaction, request_id) - @r34watch.command(name="reject_request", description="Rejects a pending Rule34 watch request.") - @app_commands.describe(request_id="The ID of the request to reject.", reason="Optional reason for rejection.") + @r34watch.command( + name="reject_request", description="Rejects a pending Rule34 watch request." + ) + @app_commands.describe( + request_id="The ID of the request to reject.", + reason="Optional reason for rejection.", + ) @app_commands.checks.has_permissions(manage_guild=True) - async def r34watch_reject_request(self, interaction: discord.Interaction, request_id: str, reason: typing.Optional[str] = None): + async def r34watch_reject_request( + self, + interaction: discord.Interaction, + request_id: str, + reason: typing.Optional[str] = None, + ): await interaction.response.defer(ephemeral=True) await self._watch_reject_request_logic(interaction, request_id, reason) - @r34watch.command(name="list", description="List active Rule34 tag watches for this server.") + @r34watch.command( + name="list", description="List active Rule34 tag watches for this server." + ) @app_commands.checks.has_permissions(manage_guild=True) async def r34watch_list(self, interaction: discord.Interaction): # No defer needed if _watch_list_logic handles it or is quick await self._watch_list_logic(interaction) - @r34watch.command(name="remove", description="Stop watching for new Rule34 posts using a subscription ID.") - @app_commands.describe(subscription_id="The ID of the subscription to remove (get from 'list' command).") + @r34watch.command( + name="remove", + description="Stop watching for new Rule34 posts using a subscription ID.", + ) + @app_commands.describe( + subscription_id="The ID of the subscription to remove (get from 'list' command)." + ) @app_commands.checks.has_permissions(manage_guild=True) - async def r34watch_remove(self, interaction: discord.Interaction, subscription_id: str): + async def r34watch_remove( + self, interaction: discord.Interaction, subscription_id: str + ): # No defer needed if _watch_remove_logic handles it or is quick await self._watch_remove_logic(interaction, subscription_id) - @r34watch.command(name="send_test", description="Send a test new Rule34 post message using a subscription ID.") + @r34watch.command( + name="send_test", + description="Send a test new Rule34 post message using a subscription ID.", + ) @app_commands.describe(subscription_id="The ID of the subscription to test.") @app_commands.checks.has_permissions(manage_guild=True) - async def r34watch_send_test(self, interaction: discord.Interaction, subscription_id: str): + async def r34watch_send_test( + self, interaction: discord.Interaction, subscription_id: str + ): await self._watch_test_message_logic(interaction, subscription_id) - @app_commands.command(name="rule34debug_transform", description="Debug command to test AI tag transformation.") - @app_commands.describe(tags="The tags to test transformation for (e.g., 'hatsune miku')") + @app_commands.command( + name="rule34debug_transform", + description="Debug command to test AI tag transformation.", + ) + @app_commands.describe( + tags="The tags to test transformation for (e.g., 'hatsune miku')" + ) async def rule34debug_transform(self, interaction: discord.Interaction, tags: str): await interaction.response.defer(ephemeral=True, thinking=True) - + transformed_tags = await self._transform_tags_ai(tags) - + if transformed_tags is None: - response_content = f"AI transformation failed for tags: `{tags}`. Check logs for details." + response_content = ( + f"AI transformation failed for tags: `{tags}`. Check logs for details." + ) elif not transformed_tags: - response_content = f"AI returned empty for tags: `{tags}`. Please try rephrasing." + response_content = ( + f"AI returned empty for tags: `{tags}`. Please try rephrasing." + ) else: response_content = ( - f"Original tags: `{tags}`\n" - f"Transformed tags: `{transformed_tags}`" + f"Original tags: `{tags}`\n" f"Transformed tags: `{transformed_tags}`" ) - + await interaction.followup.send(response_content, ephemeral=True) + async def setup(bot: commands.Bot): await bot.add_cog(Rule34Cog(bot)) log.info("Rule34Cog (refactored) added to bot.") diff --git a/cogs/safebooru_cog.py b/cogs/safebooru_cog.py index 4e5aa97..cbb1069 100644 --- a/cogs/safebooru_cog.py +++ b/cogs/safebooru_cog.py @@ -1,7 +1,7 @@ import discord from discord.ext import commands, tasks from discord import app_commands -import typing # Need this for Optional +import typing # Need this for Optional import logging from .gelbooru_watcher_base_cog import GelbooruWatcherBaseCog @@ -9,119 +9,198 @@ from .gelbooru_watcher_base_cog import GelbooruWatcherBaseCog # Setup logger for this cog log = logging.getLogger(__name__) -class SafebooruCog(GelbooruWatcherBaseCog): # Removed name="Safebooru" + +class SafebooruCog(GelbooruWatcherBaseCog): # Removed name="Safebooru" # Define the command group specific to this cog - safebooruwatch = app_commands.Group(name="safebooruwatch", description="Manage Safebooru tag watchers for new posts.") + safebooruwatch = app_commands.Group( + name="safebooruwatch", + description="Manage Safebooru tag watchers for new posts.", + ) def __init__(self, bot: commands.Bot): super().__init__( bot=bot, cog_name="Safebooru", - api_base_url="https://safebooru.org/index.php", # Corrected base URL - default_tags="hatsune_miku 1girl", # Example default - is_nsfw_site=False, # Safebooru is generally SFW + api_base_url="https://safebooru.org/index.php", # Corrected base URL + default_tags="hatsune_miku 1girl", # Example default + is_nsfw_site=False, # Safebooru is generally SFW command_group_name="safebooruwatch", main_command_name="safebooru", - post_url_template="https://safebooru.org/index.php?page=post&s=view&id={}" + post_url_template="https://safebooru.org/index.php?page=post&s=view&id={}", ) # --- Prefix Command --- @commands.command(name="safebooru") - async def safebooru_prefix(self, ctx: commands.Context, *, tags: typing.Optional[str] = None): + async def safebooru_prefix( + self, ctx: commands.Context, *, tags: typing.Optional[str] = None + ): """Search for images on Safebooru with the provided tags.""" actual_tags = tags or self.default_tags - loading_msg = await ctx.reply(f"Fetching data from {self.cog_name}, please wait...") - response = await self._fetch_posts_logic("prefix_internal", actual_tags, hidden=False) + loading_msg = await ctx.reply( + f"Fetching data from {self.cog_name}, please wait..." + ) + response = await self._fetch_posts_logic( + "prefix_internal", actual_tags, hidden=False + ) if isinstance(response, tuple): content, all_results = response view = self.GelbooruButtons(self, actual_tags, all_results, hidden=False) await loading_msg.edit(content=content, view=view) - elif isinstance(response, str): # Error + elif isinstance(response, str): # Error await loading_msg.edit(content=response, view=None) # --- Slash Command --- - @app_commands.command(name="safebooru", description="Get random image from Safebooru with specified tags") + @app_commands.command( + name="safebooru", + description="Get random image from Safebooru with specified tags", + ) @app_commands.describe( tags="The tags to search for (e.g., '1girl cat_ears')", - hidden="Set to True to make the response visible only to you (default: False)" + hidden="Set to True to make the response visible only to you (default: False)", ) - async def safebooru_slash(self, interaction: discord.Interaction, tags: typing.Optional[str] = None, hidden: bool = False): + async def safebooru_slash( + self, + interaction: discord.Interaction, + tags: typing.Optional[str] = None, + hidden: bool = False, + ): """Slash command version of safebooru.""" actual_tags = tags or self.default_tags await self._slash_command_logic(interaction, actual_tags, hidden) # --- New Browse Command --- - @app_commands.command(name="safeboorubrowse", description="Browse Safebooru results with navigation buttons") + @app_commands.command( + name="safeboorubrowse", + description="Browse Safebooru results with navigation buttons", + ) @app_commands.describe( tags="The tags to search for (e.g., '1girl dog_ears')", - hidden="Set to True to make the response visible only to you (default: False)" + hidden="Set to True to make the response visible only to you (default: False)", ) - async def safebooru_browse_slash(self, interaction: discord.Interaction, tags: typing.Optional[str] = None, hidden: bool = False): + async def safebooru_browse_slash( + self, + interaction: discord.Interaction, + tags: typing.Optional[str] = None, + hidden: bool = False, + ): """Browse Safebooru results with navigation buttons.""" actual_tags = tags or self.default_tags await self._browse_slash_command_logic(interaction, actual_tags, hidden) # --- safebooruwatch slash command group --- - @safebooruwatch.command(name="add", description="Watch for new Safebooru posts with specific tags in a channel or thread.") + @safebooruwatch.command( + name="add", + description="Watch for new Safebooru posts with specific tags in a channel or thread.", + ) @app_commands.describe( tags="The tags to search for (e.g., '1girl cat_ears').", channel="The parent channel for the subscription. Must be a Forum Channel if using forum mode.", thread_target="Optional: Name or ID of a thread within the channel (for TextChannels only).", - post_title="Optional: Title for a new forum post if 'channel' is a Forum Channel." + post_title="Optional: Title for a new forum post if 'channel' is a Forum Channel.", ) @app_commands.checks.has_permissions(manage_guild=True) - async def safebooruwatch_add(self, interaction: discord.Interaction, tags: str, channel: typing.Union[discord.TextChannel, discord.ForumChannel], thread_target: typing.Optional[str] = None, post_title: typing.Optional[str] = None): + async def safebooruwatch_add( + self, + interaction: discord.Interaction, + tags: str, + channel: typing.Union[discord.TextChannel, discord.ForumChannel], + thread_target: typing.Optional[str] = None, + post_title: typing.Optional[str] = None, + ): await interaction.response.defer(ephemeral=True) - await self._watch_add_logic(interaction, tags, channel, thread_target, post_title) + await self._watch_add_logic( + interaction, tags, channel, thread_target, post_title + ) - @safebooruwatch.command(name="request", description="Request a new Safebooru tag watch (requires moderator approval).") + @safebooruwatch.command( + name="request", + description="Request a new Safebooru tag watch (requires moderator approval).", + ) @app_commands.describe( tags="The tags you want to watch.", forum_channel="The Forum Channel where a new post for this watch should be created.", - post_title="Optional: A title for the new forum post (defaults to tags)." + post_title="Optional: A title for the new forum post (defaults to tags).", ) - async def safebooruwatch_request(self, interaction: discord.Interaction, tags: str, forum_channel: discord.ForumChannel, post_title: typing.Optional[str] = None): + async def safebooruwatch_request( + self, + interaction: discord.Interaction, + tags: str, + forum_channel: discord.ForumChannel, + post_title: typing.Optional[str] = None, + ): await interaction.response.defer(ephemeral=True) await self._watch_request_logic(interaction, tags, forum_channel, post_title) - @safebooruwatch.command(name="pending_list", description="Lists all pending Safebooru watch requests.") + @safebooruwatch.command( + name="pending_list", description="Lists all pending Safebooru watch requests." + ) @app_commands.checks.has_permissions(manage_guild=True) async def safebooruwatch_pending_list(self, interaction: discord.Interaction): await self._watch_pending_list_logic(interaction) - @safebooruwatch.command(name="approve_request", description="Approves a pending Safebooru watch request.") + @safebooruwatch.command( + name="approve_request", + description="Approves a pending Safebooru watch request.", + ) @app_commands.describe(request_id="The ID of the request to approve.") @app_commands.checks.has_permissions(manage_guild=True) - async def safebooruwatch_approve_request(self, interaction: discord.Interaction, request_id: str): + async def safebooruwatch_approve_request( + self, interaction: discord.Interaction, request_id: str + ): await interaction.response.defer(ephemeral=True) await self._watch_approve_request_logic(interaction, request_id) - @safebooruwatch.command(name="reject_request", description="Rejects a pending Safebooru watch request.") - @app_commands.describe(request_id="The ID of the request to reject.", reason="Optional reason for rejection.") + @safebooruwatch.command( + name="reject_request", description="Rejects a pending Safebooru watch request." + ) + @app_commands.describe( + request_id="The ID of the request to reject.", + reason="Optional reason for rejection.", + ) @app_commands.checks.has_permissions(manage_guild=True) - async def safebooruwatch_reject_request(self, interaction: discord.Interaction, request_id: str, reason: typing.Optional[str] = None): + async def safebooruwatch_reject_request( + self, + interaction: discord.Interaction, + request_id: str, + reason: typing.Optional[str] = None, + ): await interaction.response.defer(ephemeral=True) await self._watch_reject_request_logic(interaction, request_id, reason) - @safebooruwatch.command(name="list", description="List active Safebooru tag watches for this server.") + @safebooruwatch.command( + name="list", description="List active Safebooru tag watches for this server." + ) @app_commands.checks.has_permissions(manage_guild=True) async def safebooruwatch_list(self, interaction: discord.Interaction): await self._watch_list_logic(interaction) - @safebooruwatch.command(name="remove", description="Stop watching for new Safebooru posts using a subscription ID.") - @app_commands.describe(subscription_id="The ID of the subscription to remove (get from 'list' command).") + @safebooruwatch.command( + name="remove", + description="Stop watching for new Safebooru posts using a subscription ID.", + ) + @app_commands.describe( + subscription_id="The ID of the subscription to remove (get from 'list' command)." + ) @app_commands.checks.has_permissions(manage_guild=True) - async def safebooruwatch_remove(self, interaction: discord.Interaction, subscription_id: str): + async def safebooruwatch_remove( + self, interaction: discord.Interaction, subscription_id: str + ): await self._watch_remove_logic(interaction, subscription_id) - @safebooruwatch.command(name="send_test", description="Send a test new Safebooru post message using a subscription ID.") + @safebooruwatch.command( + name="send_test", + description="Send a test new Safebooru post message using a subscription ID.", + ) @app_commands.describe(subscription_id="The ID of the subscription to test.") @app_commands.checks.has_permissions(manage_guild=True) - async def safebooruwatch_send_test(self, interaction: discord.Interaction, subscription_id: str): + async def safebooruwatch_send_test( + self, interaction: discord.Interaction, subscription_id: str + ): await self._watch_test_message_logic(interaction, subscription_id) + async def setup(bot: commands.Bot): await bot.add_cog(SafebooruCog(bot)) log.info("SafebooruCog (refactored) added to bot.") diff --git a/cogs/settings_cog.py b/cogs/settings_cog.py index 9bad83d..2a14baa 100644 --- a/cogs/settings_cog.py +++ b/cogs/settings_cog.py @@ -1,15 +1,17 @@ import discord from discord.ext import commands import logging -import settings_manager # Assuming settings_manager is accessible -import command_customization # Import command customization utilities +import settings_manager # Assuming settings_manager is accessible +import command_customization # Import command customization utilities from typing import Optional log = logging.getLogger(__name__) + # Get CORE_COGS from bot instance def get_core_cogs(bot): - return getattr(bot, 'core_cogs', {'SettingsCog', 'HelpCog'}) + return getattr(bot, "core_cogs", {"SettingsCog", "HelpCog"}) + class SettingsCog(commands.Cog, name="Settings"): """Commands for server administrators to configure the bot.""" @@ -18,7 +20,10 @@ class SettingsCog(commands.Cog, name="Settings"): self.bot = bot # --- Prefix Management --- - @commands.command(name='setprefix', help="Sets the command prefix for this server. Usage: `setprefix `") + @commands.command( + name="setprefix", + help="Sets the command prefix for this server. Usage: `setprefix `", + ) @commands.has_permissions(administrator=True) @commands.guild_only() async def set_prefix(self, ctx: commands.Context, new_prefix: str): @@ -26,40 +31,56 @@ class SettingsCog(commands.Cog, name="Settings"): if not new_prefix: await ctx.send("Prefix cannot be empty.") return - if len(new_prefix) > 10: # Arbitrary limit - await ctx.send("Prefix cannot be longer than 10 characters.") - return + if len(new_prefix) > 10: # Arbitrary limit + await ctx.send("Prefix cannot be longer than 10 characters.") + return if new_prefix.isspace(): - await ctx.send("Prefix cannot be just whitespace.") - return + await ctx.send("Prefix cannot be just whitespace.") + return guild_id = ctx.guild.id success = await settings_manager.set_guild_prefix(guild_id, new_prefix) if success: - await ctx.send(f"Command prefix for this server has been set to: `{new_prefix}`") - log.info(f"Prefix updated for guild {guild_id} to '{new_prefix}' by {ctx.author.name}") + await ctx.send( + f"Command prefix for this server has been set to: `{new_prefix}`" + ) + log.info( + f"Prefix updated for guild {guild_id} to '{new_prefix}' by {ctx.author.name}" + ) else: await ctx.send("Failed to set the prefix. Please check the logs.") log.error(f"Failed to save prefix for guild {guild_id}") - @commands.command(name='showprefix', help="Shows the current command prefix for this server.") + @commands.command( + name="showprefix", help="Shows the current command prefix for this server." + ) @commands.guild_only() async def show_prefix(self, ctx: commands.Context): """Shows the current command prefix.""" # We need the bot's default prefix as a fallback # This might need access to the bot instance's initial config or a constant - default_prefix = self.bot.command_prefix # This might not work if command_prefix is the callable + default_prefix = ( + self.bot.command_prefix + ) # This might not work if command_prefix is the callable # Use the constant defined in main.py if possible, or keep a local fallback - default_prefix_fallback = "!" # TODO: Get default prefix reliably if needed elsewhere + default_prefix_fallback = ( + "!" # TODO: Get default prefix reliably if needed elsewhere + ) guild_id = ctx.guild.id - current_prefix = await settings_manager.get_guild_prefix(guild_id, default_prefix_fallback) - await ctx.send(f"The current command prefix for this server is: `{current_prefix}`") - + current_prefix = await settings_manager.get_guild_prefix( + guild_id, default_prefix_fallback + ) + await ctx.send( + f"The current command prefix for this server is: `{current_prefix}`" + ) # --- Cog Management --- - @commands.command(name='enablecog', help="Enables a specific module (cog) for this server. Usage: `enablecog `") + @commands.command( + name="enablecog", + help="Enables a specific module (cog) for this server. Usage: `enablecog `", + ) @commands.has_permissions(administrator=True) @commands.guild_only() async def enable_cog(self, ctx: commands.Context, cog_name: str): @@ -71,20 +92,27 @@ class SettingsCog(commands.Cog, name="Settings"): core_cogs = get_core_cogs(self.bot) if cog_name in core_cogs: - await ctx.send(f"Error: Core cog `{cog_name}` cannot be disabled/enabled.") - return + await ctx.send(f"Error: Core cog `{cog_name}` cannot be disabled/enabled.") + return guild_id = ctx.guild.id - success = await settings_manager.set_cog_enabled(guild_id, cog_name, enabled=True) + success = await settings_manager.set_cog_enabled( + guild_id, cog_name, enabled=True + ) if success: await ctx.send(f"Module `{cog_name}` has been enabled for this server.") - log.info(f"Cog '{cog_name}' enabled for guild {guild_id} by {ctx.author.name}") + log.info( + f"Cog '{cog_name}' enabled for guild {guild_id} by {ctx.author.name}" + ) else: await ctx.send(f"Failed to enable module `{cog_name}`. Check logs.") log.error(f"Failed to enable cog '{cog_name}' for guild {guild_id}") - @commands.command(name='disablecog', help="Disables a specific module (cog) for this server. Usage: `disablecog `") + @commands.command( + name="disablecog", + help="Disables a specific module (cog) for this server. Usage: `disablecog `", + ) @commands.has_permissions(administrator=True) @commands.guild_only() async def disable_cog(self, ctx: commands.Context, cog_name: str): @@ -95,20 +123,27 @@ class SettingsCog(commands.Cog, name="Settings"): core_cogs = get_core_cogs(self.bot) if cog_name in core_cogs: - await ctx.send(f"Error: Core cog `{cog_name}` cannot be disabled.") - return + await ctx.send(f"Error: Core cog `{cog_name}` cannot be disabled.") + return guild_id = ctx.guild.id - success = await settings_manager.set_cog_enabled(guild_id, cog_name, enabled=False) + success = await settings_manager.set_cog_enabled( + guild_id, cog_name, enabled=False + ) if success: await ctx.send(f"Module `{cog_name}` has been disabled for this server.") - log.info(f"Cog '{cog_name}' disabled for guild {guild_id} by {ctx.author.name}") + log.info( + f"Cog '{cog_name}' disabled for guild {guild_id} by {ctx.author.name}" + ) else: await ctx.send(f"Failed to disable module `{cog_name}`. Check logs.") log.error(f"Failed to disable cog '{cog_name}' for guild {guild_id}") - @commands.command(name='listcogs', help="Lists all available modules (cogs) and their status for this server.") + @commands.command( + name="listcogs", + help="Lists all available modules (cogs) and their status for this server.", + ) @commands.guild_only() async def list_cogs(self, ctx: commands.Context): """Lists available cogs and their enabled/disabled status.""" @@ -118,13 +153,17 @@ class SettingsCog(commands.Cog, name="Settings"): # Let's assume default_enabled=True for now. default_behavior = True - embed = discord.Embed(title="Available Modules (Cogs)", color=discord.Color.blue()) + embed = discord.Embed( + title="Available Modules (Cogs)", color=discord.Color.blue() + ) lines = [] # Get core cogs from bot instance core_cogs_list = get_core_cogs(self.bot) for cog_name in sorted(self.bot.cogs.keys()): - is_enabled = await settings_manager.is_cog_enabled(guild_id, cog_name, default_enabled=default_behavior) + is_enabled = await settings_manager.is_cog_enabled( + guild_id, cog_name, default_enabled=default_behavior + ) status = "✅ Enabled" if is_enabled else "❌ Disabled" if cog_name in core_cogs_list: status += " (Core)" @@ -133,12 +172,16 @@ class SettingsCog(commands.Cog, name="Settings"): embed.description = "\n".join(lines) if lines else "No cogs found." await ctx.send(embed=embed) - # --- Command Permission Management (Basic Role-Based) --- - @commands.command(name='allowcmd', help="Allows a role to use a specific command. Usage: `allowcmd <@Role>`") + @commands.command( + name="allowcmd", + help="Allows a role to use a specific command. Usage: `allowcmd <@Role>`", + ) @commands.has_permissions(administrator=True) @commands.guild_only() - async def allow_command(self, ctx: commands.Context, command_name: str, role: discord.Role): + async def allow_command( + self, ctx: commands.Context, command_name: str, role: discord.Role + ): """Allows a specific role to use a command.""" command = self.bot.get_command(command_name) if not command: @@ -147,20 +190,34 @@ class SettingsCog(commands.Cog, name="Settings"): guild_id = ctx.guild.id role_id = role.id - success = await settings_manager.add_command_permission(guild_id, command_name, role_id) + success = await settings_manager.add_command_permission( + guild_id, command_name, role_id + ) if success: - await ctx.send(f"Role `{role.name}` is now allowed to use command `{command_name}`.") - log.info(f"Permission added for command '{command_name}', role '{role.name}' ({role_id}) in guild {guild_id} by {ctx.author.name}") + await ctx.send( + f"Role `{role.name}` is now allowed to use command `{command_name}`." + ) + log.info( + f"Permission added for command '{command_name}', role '{role.name}' ({role_id}) in guild {guild_id} by {ctx.author.name}" + ) else: - await ctx.send(f"Failed to add permission for command `{command_name}`. Check logs.") - log.error(f"Failed to add permission for command '{command_name}', role {role_id} in guild {guild_id}") + await ctx.send( + f"Failed to add permission for command `{command_name}`. Check logs." + ) + log.error( + f"Failed to add permission for command '{command_name}', role {role_id} in guild {guild_id}" + ) - - @commands.command(name='disallowcmd', help="Disallows a role from using a specific command. Usage: `disallowcmd <@Role>`") + @commands.command( + name="disallowcmd", + help="Disallows a role from using a specific command. Usage: `disallowcmd <@Role>`", + ) @commands.has_permissions(administrator=True) @commands.guild_only() - async def disallow_command(self, ctx: commands.Context, command_name: str, role: discord.Role): + async def disallow_command( + self, ctx: commands.Context, command_name: str, role: discord.Role + ): """Disallows a specific role from using a command.""" command = self.bot.get_command(command_name) if not command: @@ -169,20 +226,35 @@ class SettingsCog(commands.Cog, name="Settings"): guild_id = ctx.guild.id role_id = role.id - success = await settings_manager.remove_command_permission(guild_id, command_name, role_id) + success = await settings_manager.remove_command_permission( + guild_id, command_name, role_id + ) if success: - await ctx.send(f"Role `{role.name}` is no longer allowed to use command `{command_name}`.") - log.info(f"Permission removed for command '{command_name}', role '{role.name}' ({role_id}) in guild {guild_id} by {ctx.author.name}") + await ctx.send( + f"Role `{role.name}` is no longer allowed to use command `{command_name}`." + ) + log.info( + f"Permission removed for command '{command_name}', role '{role.name}' ({role_id}) in guild {guild_id} by {ctx.author.name}" + ) else: - await ctx.send(f"Failed to remove permission for command `{command_name}`. Check logs.") - log.error(f"Failed to remove permission for command '{command_name}', role {role_id} in guild {guild_id}") + await ctx.send( + f"Failed to remove permission for command `{command_name}`. Check logs." + ) + log.error( + f"Failed to remove permission for command '{command_name}', role {role_id} in guild {guild_id}" + ) # --- Command Customization Management --- - @commands.command(name='setcmdname', help="Sets a custom name for a slash command in this server. Usage: `setcmdname `") + @commands.command( + name="setcmdname", + help="Sets a custom name for a slash command in this server. Usage: `setcmdname `", + ) @commands.has_permissions(administrator=True) @commands.guild_only() - async def set_command_name(self, ctx: commands.Context, original_name: str, custom_name: str): + async def set_command_name( + self, ctx: commands.Context, original_name: str, custom_name: str + ): """Sets a custom name for a slash command in the current guild.""" # Validate the original command exists command_found = False @@ -196,55 +268,91 @@ class SettingsCog(commands.Cog, name="Settings"): return # Validate custom name format (Discord has restrictions on command names) - if not custom_name.islower() or not custom_name.replace('_', '').isalnum(): - await ctx.send("Error: Custom command names must be lowercase and contain only letters, numbers, and underscores.") + if not custom_name.islower() or not custom_name.replace("_", "").isalnum(): + await ctx.send( + "Error: Custom command names must be lowercase and contain only letters, numbers, and underscores." + ) return if len(custom_name) < 1 or len(custom_name) > 32: - await ctx.send("Error: Custom command names must be between 1 and 32 characters long.") + await ctx.send( + "Error: Custom command names must be between 1 and 32 characters long." + ) return guild_id = ctx.guild.id - success = await settings_manager.set_custom_command_name(guild_id, original_name, custom_name) + success = await settings_manager.set_custom_command_name( + guild_id, original_name, custom_name + ) if success: - await ctx.send(f"Command `{original_name}` will now appear as `{custom_name}` in this server.\n" - f"Note: You'll need to restart the bot or use `/sync` for changes to take effect.") - log.info(f"Custom command name set for '{original_name}' to '{custom_name}' in guild {guild_id} by {ctx.author.name}") + await ctx.send( + f"Command `{original_name}` will now appear as `{custom_name}` in this server.\n" + f"Note: You'll need to restart the bot or use `/sync` for changes to take effect." + ) + log.info( + f"Custom command name set for '{original_name}' to '{custom_name}' in guild {guild_id} by {ctx.author.name}" + ) else: await ctx.send(f"Failed to set custom command name. Check logs.") - log.error(f"Failed to set custom command name for '{original_name}' in guild {guild_id}") + log.error( + f"Failed to set custom command name for '{original_name}' in guild {guild_id}" + ) - @commands.command(name='resetcmdname', help="Resets a slash command to its original name. Usage: `resetcmdname `") + @commands.command( + name="resetcmdname", + help="Resets a slash command to its original name. Usage: `resetcmdname `", + ) @commands.has_permissions(administrator=True) @commands.guild_only() async def reset_command_name(self, ctx: commands.Context, original_name: str): """Resets a slash command to its original name in the current guild.""" guild_id = ctx.guild.id - success = await settings_manager.set_custom_command_name(guild_id, original_name, None) + success = await settings_manager.set_custom_command_name( + guild_id, original_name, None + ) if success: - await ctx.send(f"Command `{original_name}` has been reset to its original name in this server.\n" - f"Note: You'll need to restart the bot or use `/sync` for changes to take effect.") - log.info(f"Custom command name reset for '{original_name}' in guild {guild_id} by {ctx.author.name}") + await ctx.send( + f"Command `{original_name}` has been reset to its original name in this server.\n" + f"Note: You'll need to restart the bot or use `/sync` for changes to take effect." + ) + log.info( + f"Custom command name reset for '{original_name}' in guild {guild_id} by {ctx.author.name}" + ) else: await ctx.send(f"Failed to reset command name. Check logs.") - log.error(f"Failed to reset command name for '{original_name}' in guild {guild_id}") + log.error( + f"Failed to reset command name for '{original_name}' in guild {guild_id}" + ) - @commands.command(name='setgroupname', help="Sets a custom name for a command group. Usage: `setgroupname `") + @commands.command( + name="setgroupname", + help="Sets a custom name for a command group. Usage: `setgroupname `", + ) @commands.has_permissions(administrator=True) @commands.guild_only() - async def set_group_name(self, ctx: commands.Context, original_name: str, custom_name: str): + async def set_group_name( + self, ctx: commands.Context, original_name: str, custom_name: str + ): """Sets a custom name for a command group in the current guild.""" # Validate the original group exists group_found = False for cmd in self.bot.tree.get_commands(): # Check if this command is itself a group with the specified name - if hasattr(cmd, 'name') and cmd.name == original_name and hasattr(cmd, 'commands'): + if ( + hasattr(cmd, "name") + and cmd.name == original_name + and hasattr(cmd, "commands") + ): group_found = True break # Also check if this is a subcommand of a group with the specified name (for nested groups) - elif hasattr(cmd, 'parent') and cmd.parent and cmd.parent.name == original_name: + elif ( + hasattr(cmd, "parent") + and cmd.parent + and cmd.parent.name == original_name + ): group_found = True break @@ -253,45 +361,73 @@ class SettingsCog(commands.Cog, name="Settings"): return # Validate custom name format (Discord has restrictions on command names) - if not custom_name.islower() or not custom_name.replace('_', '').isalnum(): - await ctx.send("Error: Custom group names must be lowercase and contain only letters, numbers, and underscores.") + if not custom_name.islower() or not custom_name.replace("_", "").isalnum(): + await ctx.send( + "Error: Custom group names must be lowercase and contain only letters, numbers, and underscores." + ) return if len(custom_name) < 1 or len(custom_name) > 32: - await ctx.send("Error: Custom group names must be between 1 and 32 characters long.") + await ctx.send( + "Error: Custom group names must be between 1 and 32 characters long." + ) return guild_id = ctx.guild.id - success = await settings_manager.set_custom_group_name(guild_id, original_name, custom_name) + success = await settings_manager.set_custom_group_name( + guild_id, original_name, custom_name + ) if success: - await ctx.send(f"Command group `{original_name}` will now appear as `{custom_name}` in this server.\n" - f"Note: You'll need to restart the bot or use `/sync` for changes to take effect.") - log.info(f"Custom group name set for '{original_name}' to '{custom_name}' in guild {guild_id} by {ctx.author.name}") + await ctx.send( + f"Command group `{original_name}` will now appear as `{custom_name}` in this server.\n" + f"Note: You'll need to restart the bot or use `/sync` for changes to take effect." + ) + log.info( + f"Custom group name set for '{original_name}' to '{custom_name}' in guild {guild_id} by {ctx.author.name}" + ) else: await ctx.send(f"Failed to set custom group name. Check logs.") - log.error(f"Failed to set custom group name for '{original_name}' in guild {guild_id}") + log.error( + f"Failed to set custom group name for '{original_name}' in guild {guild_id}" + ) - @commands.command(name='resetgroupname', help="Resets a command group to its original name. Usage: `resetgroupname `") + @commands.command( + name="resetgroupname", + help="Resets a command group to its original name. Usage: `resetgroupname `", + ) @commands.has_permissions(administrator=True) @commands.guild_only() async def reset_group_name(self, ctx: commands.Context, original_name: str): """Resets a command group to its original name in the current guild.""" guild_id = ctx.guild.id - success = await settings_manager.set_custom_group_name(guild_id, original_name, None) + success = await settings_manager.set_custom_group_name( + guild_id, original_name, None + ) if success: - await ctx.send(f"Command group `{original_name}` has been reset to its original name in this server.\n" - f"Note: You'll need to restart the bot or use `/sync` for changes to take effect.") - log.info(f"Custom group name reset for '{original_name}' in guild {guild_id} by {ctx.author.name}") + await ctx.send( + f"Command group `{original_name}` has been reset to its original name in this server.\n" + f"Note: You'll need to restart the bot or use `/sync` for changes to take effect." + ) + log.info( + f"Custom group name reset for '{original_name}' in guild {guild_id} by {ctx.author.name}" + ) else: await ctx.send(f"Failed to reset group name. Check logs.") - log.error(f"Failed to reset group name for '{original_name}' in guild {guild_id}") + log.error( + f"Failed to reset group name for '{original_name}' in guild {guild_id}" + ) - @commands.command(name='addcmdalias', help="Adds an alias for a command. Usage: `addcmdalias `") + @commands.command( + name="addcmdalias", + help="Adds an alias for a command. Usage: `addcmdalias `", + ) @commands.has_permissions(administrator=True) @commands.guild_only() - async def add_command_alias(self, ctx: commands.Context, original_name: str, alias_name: str): + async def add_command_alias( + self, ctx: commands.Context, original_name: str, alias_name: str + ): """Adds an alias for a command in the current guild.""" # Validate the original command exists command = self.bot.get_command(original_name) @@ -300,8 +436,10 @@ class SettingsCog(commands.Cog, name="Settings"): return # Validate alias format - if not alias_name.islower() or not alias_name.replace('_', '').isalnum(): - await ctx.send("Error: Aliases must be lowercase and contain only letters, numbers, and underscores.") + if not alias_name.islower() or not alias_name.replace("_", "").isalnum(): + await ctx.send( + "Error: Aliases must be lowercase and contain only letters, numbers, and underscores." + ) return if len(alias_name) < 1 or len(alias_name) > 32: @@ -309,31 +447,54 @@ class SettingsCog(commands.Cog, name="Settings"): return guild_id = ctx.guild.id - success = await settings_manager.add_command_alias(guild_id, original_name, alias_name) + success = await settings_manager.add_command_alias( + guild_id, original_name, alias_name + ) if success: - await ctx.send(f"Added alias `{alias_name}` for command `{original_name}` in this server.") - log.info(f"Command alias added for '{original_name}': '{alias_name}' in guild {guild_id} by {ctx.author.name}") + await ctx.send( + f"Added alias `{alias_name}` for command `{original_name}` in this server." + ) + log.info( + f"Command alias added for '{original_name}': '{alias_name}' in guild {guild_id} by {ctx.author.name}" + ) else: await ctx.send(f"Failed to add command alias. Check logs.") - log.error(f"Failed to add command alias for '{original_name}' in guild {guild_id}") + log.error( + f"Failed to add command alias for '{original_name}' in guild {guild_id}" + ) - @commands.command(name='removecmdalias', help="Removes an alias for a command. Usage: `removecmdalias `") + @commands.command( + name="removecmdalias", + help="Removes an alias for a command. Usage: `removecmdalias `", + ) @commands.has_permissions(administrator=True) @commands.guild_only() - async def remove_command_alias(self, ctx: commands.Context, original_name: str, alias_name: str): + async def remove_command_alias( + self, ctx: commands.Context, original_name: str, alias_name: str + ): """Removes an alias for a command in the current guild.""" guild_id = ctx.guild.id - success = await settings_manager.remove_command_alias(guild_id, original_name, alias_name) + success = await settings_manager.remove_command_alias( + guild_id, original_name, alias_name + ) if success: - await ctx.send(f"Removed alias `{alias_name}` for command `{original_name}` in this server.") - log.info(f"Command alias removed for '{original_name}': '{alias_name}' in guild {guild_id} by {ctx.author.name}") + await ctx.send( + f"Removed alias `{alias_name}` for command `{original_name}` in this server." + ) + log.info( + f"Command alias removed for '{original_name}': '{alias_name}' in guild {guild_id} by {ctx.author.name}" + ) else: await ctx.send(f"Failed to remove command alias. Check logs.") - log.error(f"Failed to remove command alias for '{original_name}' in guild {guild_id}") + log.error( + f"Failed to remove command alias for '{original_name}' in guild {guild_id}" + ) - @commands.command(name='listcmdaliases', help="Lists all command aliases for this server.") + @commands.command( + name="listcmdaliases", help="Lists all command aliases for this server." + ) @commands.guild_only() async def list_command_aliases(self, ctx: commands.Context): """Lists all command aliases for the current guild.""" @@ -350,17 +511,27 @@ class SettingsCog(commands.Cog, name="Settings"): embed = discord.Embed(title="Command Aliases", color=discord.Color.blue()) for cmd_name, aliases in aliases_dict.items(): - embed.add_field(name=f"Command: {cmd_name}", value=", ".join([f"`{alias}`" for alias in aliases]), inline=False) + embed.add_field( + name=f"Command: {cmd_name}", + value=", ".join([f"`{alias}`" for alias in aliases]), + inline=False, + ) await ctx.send(embed=embed) - @commands.command(name='listcustomcmds', help="Lists all custom command names for this server.") + @commands.command( + name="listcustomcmds", help="Lists all custom command names for this server." + ) @commands.guild_only() async def list_custom_commands(self, ctx: commands.Context): """Lists all custom command names for the current guild.""" guild_id = ctx.guild.id - cmd_customizations = await settings_manager.get_all_command_customizations(guild_id) - group_customizations = await settings_manager.get_all_group_customizations(guild_id) + cmd_customizations = await settings_manager.get_all_command_customizations( + guild_id + ) + group_customizations = await settings_manager.get_all_group_customizations( + guild_id + ) if cmd_customizations is None or group_customizations is None: await ctx.send("Failed to retrieve command customizations. Check logs.") @@ -370,19 +541,33 @@ class SettingsCog(commands.Cog, name="Settings"): await ctx.send("No command customizations are set for this server.") return - embed = discord.Embed(title="Command Customizations", color=discord.Color.blue()) + embed = discord.Embed( + title="Command Customizations", color=discord.Color.blue() + ) if cmd_customizations: - cmd_text = "\n".join([f"`{orig}` → `{custom['name']}`" for orig, custom in cmd_customizations.items()]) + cmd_text = "\n".join( + [ + f"`{orig}` → `{custom['name']}`" + for orig, custom in cmd_customizations.items() + ] + ) embed.add_field(name="Custom Command Names", value=cmd_text, inline=False) if group_customizations: - group_text = "\n".join([f"`{orig}` → `{custom['name']}`" for orig, custom in group_customizations.items()]) + group_text = "\n".join( + [ + f"`{orig}` → `{custom['name']}`" + for orig, custom in group_customizations.items() + ] + ) embed.add_field(name="Custom Group Names", value=group_text, inline=False) await ctx.send(embed=embed) - @commands.command(name='listgroups', help="Lists all available command groups for debugging.") + @commands.command( + name="listgroups", help="Lists all available command groups for debugging." + ) @commands.has_permissions(administrator=True) @commands.guild_only() async def list_groups(self, ctx: commands.Context): @@ -392,31 +577,44 @@ class SettingsCog(commands.Cog, name="Settings"): for cmd in self.bot.tree.get_commands(): # Check if this is a group - if hasattr(cmd, 'commands') and hasattr(cmd, 'name'): - groups.append(f"`{cmd.name}` - {getattr(cmd, 'description', 'No description')}") + if hasattr(cmd, "commands") and hasattr(cmd, "name"): + groups.append( + f"`{cmd.name}` - {getattr(cmd, 'description', 'No description')}" + ) # Check if this is a regular command - elif hasattr(cmd, 'name'): + elif hasattr(cmd, "name"): commands_list.append(f"`{cmd.name}`") - embed = discord.Embed(title="Available Command Groups & Commands", color=discord.Color.green()) + embed = discord.Embed( + title="Available Command Groups & Commands", color=discord.Color.green() + ) if groups: - groups_text = "\n".join(groups[:10]) # Limit to first 10 to avoid message length issues + groups_text = "\n".join( + groups[:10] + ) # Limit to first 10 to avoid message length issues if len(groups) > 10: groups_text += f"\n... and {len(groups) - 10} more groups" embed.add_field(name="Command Groups", value=groups_text, inline=False) else: - embed.add_field(name="Command Groups", value="No groups found", inline=False) + embed.add_field( + name="Command Groups", value="No groups found", inline=False + ) if commands_list: commands_text = ", ".join(commands_list[:20]) # Limit to first 20 if len(commands_list) > 20: commands_text += f", ... and {len(commands_list) - 20} more commands" - embed.add_field(name="Individual Commands", value=commands_text, inline=False) + embed.add_field( + name="Individual Commands", value=commands_text, inline=False + ) await ctx.send(embed=embed) - @commands.command(name='synccmds', help="Syncs slash commands with Discord to apply customizations.") + @commands.command( + name="synccmds", + help="Syncs slash commands with Discord to apply customizations.", + ) @commands.has_permissions(administrator=True) @commands.guild_only() async def sync_commands(self, ctx: commands.Context): @@ -427,24 +625,37 @@ class SettingsCog(commands.Cog, name="Settings"): # Use the command_customization module to sync commands with customizations try: - synced = await command_customization.register_guild_commands(self.bot, guild) + synced = await command_customization.register_guild_commands( + self.bot, guild + ) - await ctx.send(f"Successfully synced {len(synced)} commands for this server with customizations.") - log.info(f"Commands synced with customizations for guild {guild.id} by {ctx.author.name}") + await ctx.send( + f"Successfully synced {len(synced)} commands for this server with customizations." + ) + log.info( + f"Commands synced with customizations for guild {guild.id} by {ctx.author.name}" + ) except Exception as e: log.error(f"Failed to sync commands with customizations: {e}") # Don't fall back to regular sync to avoid command duplication - await ctx.send(f"Failed to apply customizations. Please check the logs and try again.") - log.info(f"Command sync with customizations failed for guild {guild.id}") + await ctx.send( + f"Failed to apply customizations. Please check the logs and try again." + ) + log.info( + f"Command sync with customizations failed for guild {guild.id}" + ) except Exception as e: await ctx.send(f"Failed to sync commands: {str(e)}") log.error(f"Failed to sync commands for guild {ctx.guild.id}: {e}") # TODO: Add command to list permissions? - # --- Moderation Logging Settings --- - @commands.group(name='modlogconfig', help="Configure the integrated moderation logging.", invoke_without_command=True) + @commands.group( + name="modlogconfig", + help="Configure the integrated moderation logging.", + invoke_without_command=True, + ) @commands.has_permissions(administrator=True) @commands.guild_only() async def modlog_config_group(self, ctx: commands.Context): @@ -455,14 +666,20 @@ class SettingsCog(commands.Cog, name="Settings"): channel = ctx.guild.get_channel(channel_id) if channel_id else None status = "✅ Enabled" if enabled else "❌ Disabled" - channel_status = channel.mention if channel else ("Not Set" if channel_id else "Not Set") + channel_status = ( + channel.mention if channel else ("Not Set" if channel_id else "Not Set") + ) - embed = discord.Embed(title="Moderation Logging Configuration", color=discord.Color.teal()) + embed = discord.Embed( + title="Moderation Logging Configuration", color=discord.Color.teal() + ) embed.add_field(name="Status", value=status, inline=False) embed.add_field(name="Log Channel", value=channel_status, inline=False) await ctx.send(embed=embed) - @modlog_config_group.command(name='enable', help="Enables the integrated moderation logging.") + @modlog_config_group.command( + name="enable", help="Enables the integrated moderation logging." + ) @commands.has_permissions(administrator=True) @commands.guild_only() async def modlog_enable(self, ctx: commands.Context): @@ -471,12 +688,16 @@ class SettingsCog(commands.Cog, name="Settings"): success = await settings_manager.set_mod_log_enabled(guild_id, True) if success: await ctx.send("✅ Integrated moderation logging has been enabled.") - log.info(f"Moderation logging enabled for guild {guild_id} by {ctx.author.name}") + log.info( + f"Moderation logging enabled for guild {guild_id} by {ctx.author.name}" + ) else: await ctx.send("❌ Failed to enable moderation logging. Please check logs.") log.error(f"Failed to enable moderation logging for guild {guild_id}") - @modlog_config_group.command(name='disable', help="Disables the integrated moderation logging.") + @modlog_config_group.command( + name="disable", help="Disables the integrated moderation logging." + ) @commands.has_permissions(administrator=True) @commands.guild_only() async def modlog_disable(self, ctx: commands.Context): @@ -485,31 +706,52 @@ class SettingsCog(commands.Cog, name="Settings"): success = await settings_manager.set_mod_log_enabled(guild_id, False) if success: await ctx.send("❌ Integrated moderation logging has been disabled.") - log.info(f"Moderation logging disabled for guild {guild_id} by {ctx.author.name}") + log.info( + f"Moderation logging disabled for guild {guild_id} by {ctx.author.name}" + ) else: - await ctx.send("❌ Failed to disable moderation logging. Please check logs.") + await ctx.send( + "❌ Failed to disable moderation logging. Please check logs." + ) log.error(f"Failed to disable moderation logging for guild {guild_id}") - @modlog_config_group.command(name='setchannel', help="Sets the channel where moderation logs will be sent. Usage: `setchannel #channel`") + @modlog_config_group.command( + name="setchannel", + help="Sets the channel where moderation logs will be sent. Usage: `setchannel #channel`", + ) @commands.has_permissions(administrator=True) @commands.guild_only() - async def modlog_setchannel(self, ctx: commands.Context, channel: discord.TextChannel): + async def modlog_setchannel( + self, ctx: commands.Context, channel: discord.TextChannel + ): """Sets the channel for integrated moderation logs.""" guild_id = ctx.guild.id # Basic check for bot permissions in the target channel - if not channel.permissions_for(ctx.guild.me).send_messages or not channel.permissions_for(ctx.guild.me).embed_links: - await ctx.send(f"❌ I need 'Send Messages' and 'Embed Links' permissions in {channel.mention} to send logs there.") - return + if ( + not channel.permissions_for(ctx.guild.me).send_messages + or not channel.permissions_for(ctx.guild.me).embed_links + ): + await ctx.send( + f"❌ I need 'Send Messages' and 'Embed Links' permissions in {channel.mention} to send logs there." + ) + return success = await settings_manager.set_mod_log_channel_id(guild_id, channel.id) if success: await ctx.send(f"✅ Moderation logs will now be sent to {channel.mention}.") - log.info(f"Moderation log channel set to {channel.id} for guild {guild_id} by {ctx.author.name}") + log.info( + f"Moderation log channel set to {channel.id} for guild {guild_id} by {ctx.author.name}" + ) else: - await ctx.send("❌ Failed to set the moderation log channel. Please check logs.") + await ctx.send( + "❌ Failed to set the moderation log channel. Please check logs." + ) log.error(f"Failed to set moderation log channel for guild {guild_id}") - @modlog_config_group.command(name='unsetchannel', help="Unsets the moderation log channel (disables sending logs).") + @modlog_config_group.command( + name="unsetchannel", + help="Unsets the moderation log channel (disables sending logs).", + ) @commands.has_permissions(administrator=True) @commands.guild_only() async def modlog_unsetchannel(self, ctx: commands.Context): @@ -517,13 +759,18 @@ class SettingsCog(commands.Cog, name="Settings"): guild_id = ctx.guild.id success = await settings_manager.set_mod_log_channel_id(guild_id, None) if success: - await ctx.send("✅ Moderation log channel has been unset. Logs will not be sent to a channel.") - log.info(f"Moderation log channel unset for guild {guild_id} by {ctx.author.name}") + await ctx.send( + "✅ Moderation log channel has been unset. Logs will not be sent to a channel." + ) + log.info( + f"Moderation log channel unset for guild {guild_id} by {ctx.author.name}" + ) else: - await ctx.send("❌ Failed to unset the moderation log channel. Please check logs.") + await ctx.send( + "❌ Failed to unset the moderation log channel. Please check logs." + ) log.error(f"Failed to unset moderation log channel for guild {guild_id}") - # --- Error Handling for this Cog --- @set_prefix.error @enable_cog.error @@ -538,47 +785,64 @@ class SettingsCog(commands.Cog, name="Settings"): @remove_command_alias.error @list_groups.error @sync_commands.error - @modlog_config_group.error # Add error handler for the group + @modlog_config_group.error # Add error handler for the group @modlog_enable.error @modlog_disable.error @modlog_setchannel.error @modlog_unsetchannel.error async def on_command_error(self, ctx: commands.Context, error): # Check if the error originates from the modlogconfig group or its subcommands - if ctx.command and (ctx.command.name == 'modlogconfig' or (ctx.command.parent and ctx.command.parent.name == 'modlogconfig')): + if ctx.command and ( + ctx.command.name == "modlogconfig" + or (ctx.command.parent and ctx.command.parent.name == "modlogconfig") + ): if isinstance(error, commands.MissingPermissions): - await ctx.send("You need Administrator permissions to configure moderation logging.") - return # Handled + await ctx.send( + "You need Administrator permissions to configure moderation logging." + ) + return # Handled elif isinstance(error, commands.BadArgument): - await ctx.send(f"Invalid argument. Usage: `{ctx.prefix}help {ctx.command.qualified_name}`") - return # Handled + await ctx.send( + f"Invalid argument. Usage: `{ctx.prefix}help {ctx.command.qualified_name}`" + ) + return # Handled elif isinstance(error, commands.MissingRequiredArgument): - await ctx.send(f"Missing argument. Usage: `{ctx.prefix}help {ctx.command.qualified_name}`") - return # Handled + await ctx.send( + f"Missing argument. Usage: `{ctx.prefix}help {ctx.command.qualified_name}`" + ) + return # Handled elif isinstance(error, commands.NoPrivateMessage): - await ctx.send("This command can only be used in a server.") - return # Handled + await ctx.send("This command can only be used in a server.") + return # Handled # Let other errors fall through to the generic handler below # Generic handlers for other commands in this cog if isinstance(error, commands.MissingPermissions): await ctx.send("You need Administrator permissions to use this command.") elif isinstance(error, commands.BadArgument): - await ctx.send(f"Invalid argument provided. Check the command help: `{ctx.prefix}help {ctx.command.name}`") + await ctx.send( + f"Invalid argument provided. Check the command help: `{ctx.prefix}help {ctx.command.name}`" + ) elif isinstance(error, commands.MissingRequiredArgument): - await ctx.send(f"Missing required argument. Check the command help: `{ctx.prefix}help {ctx.command.name}`") + await ctx.send( + f"Missing required argument. Check the command help: `{ctx.prefix}help {ctx.command.name}`" + ) elif isinstance(error, commands.NoPrivateMessage): await ctx.send("This command cannot be used in private messages.") else: - log.error(f"Unhandled error in SettingsCog command '{ctx.command.name}': {error}") + log.error( + f"Unhandled error in SettingsCog command '{ctx.command.name}': {error}" + ) await ctx.send("An unexpected error occurred. Please check the logs.") async def setup(bot: commands.Bot): # Ensure pools are initialized before adding the cog if getattr(bot, "pg_pool", None) is None or getattr(bot, "redis", None) is None: - log.warning("Bot pools not initialized before loading SettingsCog. Cog will not load.") - return # Prevent loading if pools are missing + log.warning( + "Bot pools not initialized before loading SettingsCog. Cog will not load." + ) + return # Prevent loading if pools are missing await bot.add_cog(SettingsCog(bot)) log.info("SettingsCog loaded.") diff --git a/cogs/shell_command_cog.py b/cogs/shell_command_cog.py index 9292a83..c98191b 100644 --- a/cogs/shell_command_cog.py +++ b/cogs/shell_command_cog.py @@ -9,44 +9,35 @@ import logging from collections import defaultdict # Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s:%(levelname)s:%(name)s: %(message)s') +logging.basicConfig( + level=logging.INFO, format="%(asctime)s:%(levelname)s:%(name)s: %(message)s" +) logger = logging.getLogger(__name__) # Comprehensive list of banned commands and patterns BANNED_COMMANDS = [ # # System modification commands # "rm", "rmdir", "del", "format", "fdisk", "mkfs", "fsck", "dd", "shred", - # # File permission/ownership changes # "chmod", "chown", "icacls", "takeown", "attrib", - # # User management # "useradd", "userdel", "adduser", "deluser", "passwd", "usermod", "net user", - # # Process control that could affect the bot # "kill", "pkill", "taskkill", "killall", - # # Package management # "apt", "apt-get", "yum", "dnf", "pacman", "brew", "pip", "npm", "gem", "cargo", - # # Network configuration # "ifconfig", "ip", "route", "iptables", "firewall-cmd", "ufw", "netsh", - # # System control # "shutdown", "reboot", "halt", "poweroff", "init", "systemctl", - # # Potentially dangerous utilities # "wget", "curl", "nc", "ncat", "telnet", "ssh", "scp", "ftp", "sftp", - # # Shell escapes or command chaining that could bypass restrictions # "bash", "sh", "cmd", "powershell", "pwsh", "python", "perl", "ruby", "php", "node", - # # Git commands that could modify repositories # "git push", "git commit", "git config", "git remote", - # # Windows specific dangerous commands # "reg", "regedit", "wmic", "diskpart", "sfc", "dism", - # # Miscellaneous dangerous commands # "eval", "exec", "source", ">", ">>", "|", "&", "&&", ";", "||" ] @@ -67,6 +58,7 @@ BANNED_PATTERNS = [ # r"\|\|\s*del", # command chaining with del ] + def is_command_allowed(command): """ Check if the command is allowed to run. @@ -84,23 +76,22 @@ def is_command_allowed(command): return True, None + class ShellCommandCog(commands.Cog): def __init__(self, bot): self.bot = bot self.max_output_length = 1900 # Discord message limit is 2000 chars - self.timeout_seconds = 30 # Maximum time a command can run + self.timeout_seconds = 30 # Maximum time a command can run # Store persistent shell sessions - self.owner_shell_sessions = defaultdict(lambda: { - 'cwd': os.getcwd(), - 'env': os.environ.copy() - }) + self.owner_shell_sessions = defaultdict( + lambda: {"cwd": os.getcwd(), "env": os.environ.copy()} + ) # Store persistent docker shell sessions - self.docker_shell_sessions = defaultdict(lambda: { - 'container_id': None, - 'created': False - }) + self.docker_shell_sessions = defaultdict( + lambda: {"container_id": None, "created": False} + ) async def _execute_command(self, command_str, session_id=None, use_docker=False): """ @@ -114,14 +105,15 @@ class ShellCommandCog(commands.Cog): return f"⛔ Command not allowed: {reason}" # Log the command execution - logger.info(f"Executing {'docker ' if use_docker else ''}shell command: {command_str}") + logger.info( + f"Executing {'docker ' if use_docker else ''}shell command: {command_str}" + ) if use_docker: return await self._execute_docker_command(command_str, session_id) else: return await self._execute_local_command(command_str, session_id) - async def _execute_local_command(self, command_str, session_id=None): """ Execute a command locally with optional session persistence. @@ -131,8 +123,8 @@ class ShellCommandCog(commands.Cog): if session_id: session = self.owner_shell_sessions[session_id] - cwd = session['cwd'] - env = session['env'] + cwd = session["cwd"] + env = session["env"] else: cwd = os.getcwd() env = os.environ.copy() @@ -145,7 +137,7 @@ class ShellCommandCog(commands.Cog): cwd=cwd, env=env, stdout=subprocess.PIPE, - stderr=subprocess.PIPE + stderr=subprocess.PIPE, ) try: stdout, stderr = proc.communicate(timeout=self.timeout_seconds) @@ -160,16 +152,16 @@ class ShellCommandCog(commands.Cog): stdout, stderr, returncode, timed_out = await asyncio.to_thread(run_subprocess) # Update session working directory if 'cd' command was used - if session_id and command_str.strip().startswith('cd '): + if session_id and command_str.strip().startswith("cd "): # Try to update session cwd (best effort, not robust for chained commands) new_dir = command_str.strip()[3:].strip() if os.path.isabs(new_dir): - session['cwd'] = new_dir + session["cwd"] = new_dir else: - session['cwd'] = os.path.abspath(os.path.join(cwd, new_dir)) + session["cwd"] = os.path.abspath(os.path.join(cwd, new_dir)) - stdout_str = stdout.decode('utf-8', errors='replace').strip() - stderr_str = stderr.decode('utf-8', errors='replace').strip() + stdout_str = stdout.decode("utf-8", errors="replace").strip() + stderr_str = stderr.decode("utf-8", errors="replace").strip() result = [] if timed_out: @@ -177,12 +169,16 @@ class ShellCommandCog(commands.Cog): if stdout_str: if len(stdout_str) > self.max_output_length: - stdout_str = stdout_str[:self.max_output_length] + "... (output truncated)" + stdout_str = ( + stdout_str[: self.max_output_length] + "... (output truncated)" + ) result.append(f"📤 **STDOUT:**\n```\n{stdout_str}\n```") if stderr_str: if len(stderr_str) > self.max_output_length: - stderr_str = stderr_str[:self.max_output_length] + "... (output truncated)" + stderr_str = ( + stderr_str[: self.max_output_length] + "... (output truncated)" + ) result.append(f"⚠️ **STDERR:**\n```\n{stderr_str}\n```") if returncode != 0 and not timed_out: @@ -204,7 +200,7 @@ class ShellCommandCog(commands.Cog): docker_check_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - shell=True + shell=True, ) # We don't need the output, just the return code @@ -219,45 +215,50 @@ class ShellCommandCog(commands.Cog): session = self.docker_shell_sessions[session_id] # Create a new container if one doesn't exist for this session - if not session['created']: + if not session["created"]: # Create a new container with a minimal Linux image - create_container_cmd = "docker run -d --rm --name shell_" + session_id + " alpine:latest tail -f /dev/null" + create_container_cmd = ( + "docker run -d --rm --name shell_" + + session_id + + " alpine:latest tail -f /dev/null" + ) process = await asyncio.create_subprocess_shell( create_container_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - shell=True + shell=True, ) stdout, stderr = await process.communicate() if process.returncode != 0: - error_msg = stderr.decode('utf-8', errors='replace').strip() + error_msg = stderr.decode("utf-8", errors="replace").strip() return f"❌ Failed to create Docker container: {error_msg}" - container_id = stdout.decode('utf-8', errors='replace').strip() - session['container_id'] = container_id - session['created'] = True + container_id = stdout.decode("utf-8", errors="replace").strip() + session["container_id"] = container_id + session["created"] = True - logger.info(f"Created Docker container with ID: {container_id} for session {session_id}") + logger.info( + f"Created Docker container with ID: {container_id} for session {session_id}" + ) # Execute the command in the container # Escape double quotes in the command string escaped_cmd = command_str.replace('"', '\\"') - docker_exec_cmd = f"docker exec shell_{session_id} sh -c \"{escaped_cmd}\"" + docker_exec_cmd = f'docker exec shell_{session_id} sh -c "{escaped_cmd}"' process = await asyncio.create_subprocess_shell( docker_exec_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - shell=True + shell=True, ) try: stdout, stderr = await asyncio.wait_for( - process.communicate(), - timeout=self.timeout_seconds + process.communicate(), timeout=self.timeout_seconds ) except asyncio.TimeoutError: # Try to terminate the process if it times out @@ -272,19 +273,23 @@ class ShellCommandCog(commands.Cog): return f"⏱️ Command timed out after {self.timeout_seconds} seconds." # Decode the output - stdout_str = stdout.decode('utf-8', errors='replace').strip() - stderr_str = stderr.decode('utf-8', errors='replace').strip() + stdout_str = stdout.decode("utf-8", errors="replace").strip() + stderr_str = stderr.decode("utf-8", errors="replace").strip() # Prepare the result message result = [] if stdout_str: if len(stdout_str) > self.max_output_length: - stdout_str = stdout_str[:self.max_output_length] + "... (output truncated)" + stdout_str = ( + stdout_str[: self.max_output_length] + "... (output truncated)" + ) result.append(f"📤 **STDOUT:**\n```\n{stdout_str}\n```") if stderr_str: if len(stderr_str) > self.max_output_length: - stderr_str = stderr_str[:self.max_output_length] + "... (output truncated)" + stderr_str = ( + stderr_str[: self.max_output_length] + "... (output truncated)" + ) result.append(f"⚠️ **STDERR:**\n```\n{stderr_str}\n```") if process.returncode != 0: @@ -295,7 +300,11 @@ class ShellCommandCog(commands.Cog): return "\n".join(result) - @commands.command(name="ownershell", help="Execute a shell command directly on the host (Owner only)", aliases=["sh"]) + @commands.command( + name="ownershell", + help="Execute a shell command directly on the host (Owner only)", + aliases=["sh"], + ) @commands.is_owner() async def ownershell_command(self, ctx, *, command_str): """Execute a shell command directly on the host (Owner only).""" @@ -303,28 +312,34 @@ class ShellCommandCog(commands.Cog): session_id = str(ctx.author.id) async with ctx.typing(): - result = await self._execute_command(command_str, session_id=session_id, use_docker=False) + result = await self._execute_command( + command_str, session_id=session_id, use_docker=False + ) # Split long messages if needed if len(result) > 2000: - parts = [result[i:i+1990] for i in range(0, len(result), 1990)] + parts = [result[i : i + 1990] for i in range(0, len(result), 1990)] for i, part in enumerate(parts): await ctx.reply(f"Part {i+1}/{len(parts)}:\n{part}") else: await ctx.reply(result) - @commands.command(name="dockersh", help="Execute a shell command in a Docker container") + @commands.command( + name="dockersh", help="Execute a shell command in a Docker container" + ) async def shell_command(self, ctx, *, command_str): """Execute a shell command in a Docker container.""" # Get or create a session ID for this user session_id = str(ctx.author.id) async with ctx.typing(): - result = await self._execute_command(command_str, session_id=session_id, use_docker=True) + result = await self._execute_command( + command_str, session_id=session_id, use_docker=True + ) # Split long messages if needed if len(result) > 2000: - parts = [result[i:i+1990] for i in range(0, len(result), 1990)] + parts = [result[i : i + 1990] for i in range(0, len(result), 1990)] for i, part in enumerate(parts): await ctx.reply(f"Part {i+1}/{len(parts)}:\n{part}") else: @@ -339,7 +354,7 @@ class ShellCommandCog(commands.Cog): if shell_type.lower() in ["docker", "container", "safe"]: # If there's an existing container, stop and remove it session = self.docker_shell_sessions[session_id] - if session['created'] and session['container_id']: + if session["created"] and session["container_id"]: try: # Stop the container stop_cmd = f"docker stop shell_{session_id}" @@ -347,7 +362,7 @@ class ShellCommandCog(commands.Cog): stop_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - shell=True + shell=True, ) await process.communicate() except Exception as e: @@ -355,29 +370,34 @@ class ShellCommandCog(commands.Cog): # Reset the session self.docker_shell_sessions[session_id] = { - 'container_id': None, - 'created': False + "container_id": None, + "created": False, } await ctx.reply("✅ Docker shell session has been reset.") elif shell_type.lower() in ["owner", "host", "local"]: # Reset the owner shell session self.owner_shell_sessions[session_id] = { - 'cwd': os.getcwd(), - 'env': os.environ.copy() + "cwd": os.getcwd(), + "env": os.environ.copy(), } await ctx.reply("✅ Owner shell session has been reset.") else: await ctx.reply("❌ Invalid shell type. Use 'docker' or 'owner'.") - @app_commands.command(name="sh", description="Execute a shell command directly on the host (Owner only)") + @app_commands.command( + name="sh", + description="Execute a shell command directly on the host (Owner only)", + ) @app_commands.describe(command="The shell command to execute") async def ownershell_slash(self, interaction: discord.Interaction, command: str): """Slash command version of ownershell command.""" # Check if user is the bot owner if interaction.user.id != self.bot.owner_id: - await interaction.response.send_message("⛔ This command is restricted to the bot owner.", ephemeral=True) + await interaction.response.send_message( + "⛔ This command is restricted to the bot owner.", ephemeral=True + ) return # Get or create a session ID for this user @@ -387,24 +407,31 @@ class ShellCommandCog(commands.Cog): await interaction.response.defer() # Execute the command - result = await self._execute_command(command, session_id=session_id, use_docker=False) + result = await self._execute_command( + command, session_id=session_id, use_docker=False + ) # Send the result if len(result) > 2000: - parts = [result[i:i+1990] for i in range(0, len(result), 1990)] + parts = [result[i : i + 1990] for i in range(0, len(result), 1990)] await interaction.followup.send(f"Part 1/{len(parts)}:\n{parts[0]}") for i, part in enumerate(parts[1:], 2): await interaction.followup.send(f"Part {i}/{len(parts)}:\n{part}") else: await interaction.followup.send(result) - @app_commands.command(name="dockersh", description="Execute a shell command in a Docker container (Owner only)") + @app_commands.command( + name="dockersh", + description="Execute a shell command in a Docker container (Owner only)", + ) @app_commands.describe(command="The shell command to execute") async def shell_slash(self, interaction: discord.Interaction, command: str): """Slash command version of shell command.""" # Check if user is the bot owner if interaction.user.id != self.bot.owner_id: - await interaction.response.send_message("⛔ This command is restricted to the bot owner.", ephemeral=True) + await interaction.response.send_message( + "⛔ This command is restricted to the bot owner.", ephemeral=True + ) return # Get or create a session ID for this user @@ -414,28 +441,40 @@ class ShellCommandCog(commands.Cog): await interaction.response.defer() # Execute the command - result = await self._execute_command(command, session_id=session_id, use_docker=True) + result = await self._execute_command( + command, session_id=session_id, use_docker=True + ) # Send the result if len(result) > 2000: - parts = [result[i:i+1990] for i in range(0, len(result), 1990)] + parts = [result[i : i + 1990] for i in range(0, len(result), 1990)] await interaction.followup.send(f"Part 1/{len(parts)}:\n{parts[0]}") for i, part in enumerate(parts[1:], 2): await interaction.followup.send(f"Part {i}/{len(parts)}:\n{part}") else: await interaction.followup.send(result) - @app_commands.command(name="newshell", description="Reset your shell session (Owner only)") - @app_commands.describe(shell_type="The type of shell to reset ('docker' or 'owner')") - @app_commands.choices(shell_type=[ - app_commands.Choice(name="Docker Container Shell", value="docker"), - app_commands.Choice(name="Owner Host Shell", value="owner") - ]) - async def newshell_slash(self, interaction: discord.Interaction, shell_type: str = "docker"): + @app_commands.command( + name="newshell", description="Reset your shell session (Owner only)" + ) + @app_commands.describe( + shell_type="The type of shell to reset ('docker' or 'owner')" + ) + @app_commands.choices( + shell_type=[ + app_commands.Choice(name="Docker Container Shell", value="docker"), + app_commands.Choice(name="Owner Host Shell", value="owner"), + ] + ) + async def newshell_slash( + self, interaction: discord.Interaction, shell_type: str = "docker" + ): """Slash command version of newshell command.""" # Check if user is the bot owner if interaction.user.id != self.bot.owner_id: - await interaction.response.send_message("⛔ This command is restricted to the bot owner.", ephemeral=True) + await interaction.response.send_message( + "⛔ This command is restricted to the bot owner.", ephemeral=True + ) return session_id = str(interaction.user.id) @@ -443,7 +482,7 @@ class ShellCommandCog(commands.Cog): if shell_type.lower() in ["docker", "container", "safe"]: # If there's an existing container, stop and remove it session = self.docker_shell_sessions[session_id] - if session['created'] and session['container_id']: + if session["created"] and session["container_id"]: try: # Stop the container stop_cmd = f"docker stop shell_{session_id}" @@ -451,7 +490,7 @@ class ShellCommandCog(commands.Cog): stop_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - shell=True + shell=True, ) await process.communicate() except Exception as e: @@ -459,21 +498,27 @@ class ShellCommandCog(commands.Cog): # Reset the session self.docker_shell_sessions[session_id] = { - 'container_id': None, - 'created': False + "container_id": None, + "created": False, } - await interaction.response.send_message("✅ Docker shell session has been reset.") + await interaction.response.send_message( + "✅ Docker shell session has been reset." + ) elif shell_type.lower() in ["owner", "host", "local"]: # Reset the owner shell session self.owner_shell_sessions[session_id] = { - 'cwd': os.getcwd(), - 'env': os.environ.copy() + "cwd": os.getcwd(), + "env": os.environ.copy(), } - await interaction.response.send_message("✅ Owner shell session has been reset.") + await interaction.response.send_message( + "✅ Owner shell session has been reset." + ) else: - await interaction.response.send_message("❌ Invalid shell type. Use 'docker' or 'owner'.") + await interaction.response.send_message( + "❌ Invalid shell type. Use 'docker' or 'owner'." + ) async def cog_unload(self): """Clean up resources when the cog is unloaded.""" @@ -484,7 +529,7 @@ class ShellCommandCog(commands.Cog): docker_check_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - shell=True + shell=True, ) # We don't need the output, just the return code @@ -496,7 +541,7 @@ class ShellCommandCog(commands.Cog): # Stop and remove all Docker containers for session_id, session in self.docker_shell_sessions.items(): - if session['created'] and session['container_id']: + if session["created"] and session["container_id"]: try: # Stop the container stop_cmd = f"docker stop shell_{session_id}" @@ -504,14 +549,17 @@ class ShellCommandCog(commands.Cog): stop_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - shell=True + shell=True, ) await process.communicate() except Exception as e: - logger.error(f"Error stopping Docker container during unload: {e}") + logger.error( + f"Error stopping Docker container during unload: {e}" + ) except Exception as e: logger.error(f"Error checking Docker availability during unload: {e}") + async def setup(bot): try: logger.info("Attempting to load ShellCommandCog...") diff --git a/cogs/stable_diffusion_cog.py b/cogs/stable_diffusion_cog.py index cec3635..e01c9b8 100644 --- a/cogs/stable_diffusion_cog.py +++ b/cogs/stable_diffusion_cog.py @@ -2,7 +2,11 @@ import discord from discord.ext import commands from discord import app_commands import torch -from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, DPMSolverMultistepScheduler +from diffusers import ( + StableDiffusionPipeline, + StableDiffusionXLPipeline, + DPMSolverMultistepScheduler, +) import os import io import time @@ -10,6 +14,7 @@ import asyncio import json from typing import Optional, Literal, Dict, Any, Union + class StableDiffusionCog(commands.Cog): def __init__(self, bot): self.bot = bot @@ -17,7 +22,9 @@ class StableDiffusionCog(commands.Cog): self.device = "cuda" if torch.cuda.is_available() else "cpu" # Set up model directories - self.models_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "models") + self.models_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "models" + ) self.illustrious_dir = os.path.join(self.models_dir, "illustrious_xl") # Create directories if they don't exist @@ -25,7 +32,11 @@ class StableDiffusionCog(commands.Cog): os.makedirs(self.illustrious_dir, exist_ok=True) # Default to Illustrious XL if available, otherwise fallback to SD 1.5 - self.model_id = self.illustrious_dir if os.path.exists(os.path.join(self.illustrious_dir, "model_index.json")) else "runwayml/stable-diffusion-v1-5" + self.model_id = ( + self.illustrious_dir + if os.path.exists(os.path.join(self.illustrious_dir, "model_index.json")) + else "runwayml/stable-diffusion-v1-5" + ) self.model_type = "sdxl" if self.model_id == self.illustrious_dir else "sd" self.is_generating = False @@ -35,7 +46,9 @@ class StableDiffusionCog(commands.Cog): # Check if Illustrious XL is available if self.model_id != self.illustrious_dir: print("Illustrious XL model not found. Using default model instead.") - print(f"To download Illustrious XL, run the download_illustrious.py script.") + print( + f"To download Illustrious XL, run the download_illustrious.py script." + ) async def load_model(self): """Load the Stable Diffusion model asynchronously""" @@ -54,10 +67,14 @@ class StableDiffusionCog(commands.Cog): None, lambda: StableDiffusionXLPipeline.from_pretrained( self.model_id, - torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, + torch_dtype=( + torch.float16 + if self.device == "cuda" + else torch.float32 + ), use_safetensors=True, - variant="fp16" if self.device == "cuda" else None - ).to(self.device) + variant="fp16" if self.device == "cuda" else None, + ).to(self.device), ) else: print(f"Loading local SD model from {self.model_id}...") @@ -65,10 +82,14 @@ class StableDiffusionCog(commands.Cog): None, lambda: StableDiffusionPipeline.from_pretrained( self.model_id, - torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, + torch_dtype=( + torch.float16 + if self.device == "cuda" + else torch.float32 + ), use_safetensors=True, - variant="fp16" if self.device == "cuda" else None - ).to(self.device) + variant="fp16" if self.device == "cuda" else None, + ).to(self.device), ) else: # HuggingFace model @@ -79,10 +100,14 @@ class StableDiffusionCog(commands.Cog): None, lambda: StableDiffusionXLPipeline.from_pretrained( self.model_id, - torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, + torch_dtype=( + torch.float16 + if self.device == "cuda" + else torch.float32 + ), use_safetensors=True, - variant="fp16" if self.device == "cuda" else None - ).to(self.device) + variant="fp16" if self.device == "cuda" else None, + ).to(self.device), ) else: self.model_type = "sd" @@ -91,15 +116,19 @@ class StableDiffusionCog(commands.Cog): None, lambda: StableDiffusionPipeline.from_pretrained( self.model_id, - torch_dtype=torch.float16 if self.device == "cuda" else torch.float32 - ).to(self.device) + torch_dtype=( + torch.float16 + if self.device == "cuda" + else torch.float32 + ), + ).to(self.device), ) # Use DPM++ 2M Karras scheduler for better quality self.model.scheduler = DPMSolverMultistepScheduler.from_config( self.model.scheduler.config, algorithm_type="dpmsolver++", - use_karras_sigmas=True + use_karras_sigmas=True, ) # Enable attention slicing for lower memory usage @@ -118,12 +147,13 @@ class StableDiffusionCog(commands.Cog): except Exception as e: print(f"Error loading Stable Diffusion model: {e}") import traceback + traceback.print_exc() return False @app_commands.command( name="generate", - description="Generate an image using Stable Diffusion running locally on GPU" + description="Generate an image using Stable Diffusion running locally on GPU", ) @app_commands.describe( prompt="The text prompt to generate an image from", @@ -133,7 +163,7 @@ class StableDiffusionCog(commands.Cog): width="Image width (must be a multiple of 8)", height="Image height (must be a multiple of 8)", seed="Random seed for reproducible results (leave empty for random)", - hidden="Whether to make the response visible only to you" + hidden="Whether to make the response visible only to you", ) async def generate_image( self, @@ -145,36 +175,33 @@ class StableDiffusionCog(commands.Cog): width: Optional[int] = 1024, height: Optional[int] = 1024, seed: Optional[int] = None, - hidden: Optional[bool] = False + hidden: Optional[bool] = False, ): """Generate an image using Stable Diffusion running locally on GPU""" # Check if already generating an image if self.is_generating: await interaction.response.send_message( "⚠️ I'm already generating an image. Please wait until the current generation is complete.", - ephemeral=True + ephemeral=True, ) return # Validate parameters if steps < 1 or steps > 150: await interaction.response.send_message( - "⚠️ Steps must be between 1 and 150.", - ephemeral=True + "⚠️ Steps must be between 1 and 150.", ephemeral=True ) return if guidance_scale < 1 or guidance_scale > 20: await interaction.response.send_message( - "⚠️ Guidance scale must be between 1 and 20.", - ephemeral=True + "⚠️ Guidance scale must be between 1 and 20.", ephemeral=True ) return if width % 8 != 0 or height % 8 != 0: await interaction.response.send_message( - "⚠️ Width and height must be multiples of 8.", - ephemeral=True + "⚠️ Width and height must be multiples of 8.", ephemeral=True ) return @@ -182,10 +209,15 @@ class StableDiffusionCog(commands.Cog): max_size = 1536 if self.model_type == "sdxl" else 1024 min_size = 512 if self.model_type == "sdxl" else 256 - if width < min_size or width > max_size or height < min_size or height > max_size: + if ( + width < min_size + or width > max_size + or height < min_size + or height > max_size + ): await interaction.response.send_message( f"⚠️ Width and height must be between {min_size} and {max_size} for the current model type ({self.model_type.upper()}).", - ephemeral=True + ephemeral=True, ) return @@ -200,7 +232,7 @@ class StableDiffusionCog(commands.Cog): if not await self.load_model(): await interaction.followup.send( "❌ Failed to load the Stable Diffusion model. Check the logs for details.", - ephemeral=hidden + ephemeral=hidden, ) self.is_generating = False return @@ -213,7 +245,11 @@ class StableDiffusionCog(commands.Cog): generator = torch.Generator(device=self.device).manual_seed(seed) # Create a status message - model_name = "Illustrious XL" if self.model_id == self.illustrious_dir else self.model_id + model_name = ( + "Illustrious XL" + if self.model_id == self.illustrious_dir + else self.model_id + ) status_message = f"🖌️ Generating image with {model_name}\n" status_message += f"🔤 Prompt: `{prompt}`\n" status_message += f"📊 Parameters: Steps={steps}, CFG={guidance_scale}, Size={width}x{height}, Seed={seed}" @@ -238,8 +274,8 @@ class StableDiffusionCog(commands.Cog): guidance_scale=guidance_scale, width=width, height=height, - generator=generator - ).images[0] + generator=generator, + ).images[0], ) else: # For regular SD models @@ -252,8 +288,8 @@ class StableDiffusionCog(commands.Cog): guidance_scale=guidance_scale, width=width, height=height, - generator=generator - ).images[0] + generator=generator, + ).images[0], ) # Convert the image to bytes for Discord upload @@ -268,10 +304,12 @@ class StableDiffusionCog(commands.Cog): embed = discord.Embed( title="🖼️ Stable Diffusion Image", description=f"**Prompt:** {prompt}", - color=0x9C84EF + color=0x9C84EF, ) if negative_prompt: - embed.add_field(name="Negative Prompt", value=negative_prompt, inline=False) + embed.add_field( + name="Negative Prompt", value=negative_prompt, inline=False + ) # Add model info to the embed model_info = f"Model: {model_name}\nType: {self.model_type.upper()}" @@ -281,11 +319,14 @@ class StableDiffusionCog(commands.Cog): embed.add_field( name="Parameters", value=f"Steps: {steps}\nGuidance Scale: {guidance_scale}\nSize: {width}x{height}\nSeed: {seed}", - inline=False + inline=False, ) embed.set_image(url="attachment://stable_diffusion_image.png") - embed.set_footer(text=f"Generated by {interaction.user.display_name}", icon_url=interaction.user.display_avatar.url) + embed.set_footer( + text=f"Generated by {interaction.user.display_name}", + icon_url=interaction.user.display_avatar.url, + ) # Send the image await interaction.followup.send(file=file, embed=embed, ephemeral=hidden) @@ -298,10 +339,10 @@ class StableDiffusionCog(commands.Cog): except Exception as e: await interaction.followup.send( - f"❌ Error generating image: {str(e)}", - ephemeral=hidden + f"❌ Error generating image: {str(e)}", ephemeral=hidden ) import traceback + traceback.print_exc() finally: # Reset the flag @@ -309,44 +350,62 @@ class StableDiffusionCog(commands.Cog): @app_commands.command( name="sd_models", - description="List available Stable Diffusion models or change the current model" + description="List available Stable Diffusion models or change the current model", ) @app_commands.describe( model="The model to switch to (leave empty to just list available models)", ) - @app_commands.choices(model=[ - app_commands.Choice(name="Illustrious XL (Local)", value="illustrious_xl"), - app_commands.Choice(name="Stable Diffusion 1.5", value="runwayml/stable-diffusion-v1-5"), - app_commands.Choice(name="Stable Diffusion 2.1", value="stabilityai/stable-diffusion-2-1"), - app_commands.Choice(name="Stable Diffusion XL", value="stabilityai/stable-diffusion-xl-base-1.0") - ]) + @app_commands.choices( + model=[ + app_commands.Choice(name="Illustrious XL (Local)", value="illustrious_xl"), + app_commands.Choice( + name="Stable Diffusion 1.5", value="runwayml/stable-diffusion-v1-5" + ), + app_commands.Choice( + name="Stable Diffusion 2.1", value="stabilityai/stable-diffusion-2-1" + ), + app_commands.Choice( + name="Stable Diffusion XL", + value="stabilityai/stable-diffusion-xl-base-1.0", + ), + ] + ) @commands.is_owner() async def sd_models( self, interaction: discord.Interaction, - model: Optional[app_commands.Choice[str]] = None + model: Optional[app_commands.Choice[str]] = None, ): """List available Stable Diffusion models or change the current model""" # Check if user is the bot owner if interaction.user.id != self.bot.owner_id: await interaction.response.send_message( - "⛔ Only the bot owner can use this command.", - ephemeral=True + "⛔ Only the bot owner can use this command.", ephemeral=True ) return if model is None: # Just list the available models - current_model = "Illustrious XL (Local)" if self.model_id == self.illustrious_dir else self.model_id + current_model = ( + "Illustrious XL (Local)" + if self.model_id == self.illustrious_dir + else self.model_id + ) embed = discord.Embed( title="🤖 Available Stable Diffusion Models", description=f"**Current model:** `{current_model}`\n**Type:** `{self.model_type.upper()}`", - color=0x9C84EF + color=0x9C84EF, ) # Check if Illustrious XL is available - illustrious_status = "✅ Installed" if os.path.exists(os.path.join(self.illustrious_dir, "model_index.json")) else "❌ Not installed" + illustrious_status = ( + "✅ Installed" + if os.path.exists( + os.path.join(self.illustrious_dir, "model_index.json") + ) + else "❌ Not installed" + ) embed.add_field( name="Available Models", @@ -356,7 +415,7 @@ class StableDiffusionCog(commands.Cog): "• `stabilityai/stable-diffusion-2-1` - Stable Diffusion 2.1\n" "• `stabilityai/stable-diffusion-xl-base-1.0` - Stable Diffusion XL" ), - inline=False + inline=False, ) # Add download instructions if Illustrious XL is not installed @@ -367,19 +426,19 @@ class StableDiffusionCog(commands.Cog): "To download Illustrious XL, run the `download_illustrious.py` script.\n" "This will download the model from Civitai and set it up for use." ), - inline=False + inline=False, ) embed.add_field( name="GPU Status", value=f"Using device: `{self.device}`\nCUDA available: `{torch.cuda.is_available()}`", - inline=False + inline=False, ) if torch.cuda.is_available(): embed.add_field( name="GPU Info", value=f"GPU: `{torch.cuda.get_device_name(0)}`\nMemory: `{torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB`", - inline=False + inline=False, ) await interaction.response.send_message(embed=embed, ephemeral=True) @@ -392,7 +451,7 @@ class StableDiffusionCog(commands.Cog): if self.is_generating: await interaction.followup.send( "⚠️ Can't change model while generating an image. Please try again later.", - ephemeral=True + ephemeral=True, ) return @@ -404,10 +463,12 @@ class StableDiffusionCog(commands.Cog): # Set the new model ID if model.value == "illustrious_xl": # Check if Illustrious XL is installed - if not os.path.exists(os.path.join(self.illustrious_dir, "model_index.json")): + if not os.path.exists( + os.path.join(self.illustrious_dir, "model_index.json") + ): await interaction.followup.send( "❌ Illustrious XL model is not installed. Please run the `download_illustrious.py` script first.", - ephemeral=True + ephemeral=True, ) return @@ -419,8 +480,9 @@ class StableDiffusionCog(commands.Cog): await interaction.followup.send( f"✅ Model changed to `{model.name}`. The model will be loaded on the next generation.", - ephemeral=True + ephemeral=True, ) + async def setup(bot): await bot.add_cog(StableDiffusionCog(bot)) diff --git a/cogs/starboard_cog.py b/cogs/starboard_cog.py index 5f2c19e..2bf7619 100644 --- a/cogs/starboard_cog.py +++ b/cogs/starboard_cog.py @@ -11,7 +11,9 @@ import os # Regular expression to extract message ID from Discord message links # Format: https://discord.com/channels/{guild_id}/{channel_id}/{message_id} -MESSAGE_LINK_PATTERN = re.compile(r"https?://(?:www\.)?discord(?:app)?\.com/channels/\d+/\d+/(\d+)") +MESSAGE_LINK_PATTERN = re.compile( + r"https?://(?:www\.)?discord(?:app)?\.com/channels/\d+/\d+/(\d+)" +) # Add the parent directory to sys.path to allow imports sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -23,13 +25,16 @@ from global_bot_accessor import get_bot_instance # Set up logging log = logging.getLogger(__name__) + class StarboardCog(commands.Cog): """A cog that implements a starboard feature for highlighting popular messages.""" def __init__(self, bot): self.bot = bot - self.emoji_pattern = re.compile(r'|[\U00010000-\U0010ffff]') - self.pending_updates = {} # Store message IDs that are being processed to prevent race conditions + self.emoji_pattern = re.compile(r"|[\U00010000-\U0010ffff]") + self.pending_updates = ( + {} + ) # Store message IDs that are being processed to prevent race conditions self.lock = asyncio.Lock() # Global lock for database operations @commands.Cog.listener() @@ -46,12 +51,16 @@ class StarboardCog(commands.Cog): # Get starboard settings for this guild settings = await settings_manager.get_starboard_settings(guild.id) - if not settings or not settings.get('enabled') or not settings.get('starboard_channel_id'): + if ( + not settings + or not settings.get("enabled") + or not settings.get("starboard_channel_id") + ): return # Check if the emoji matches the configured star emoji emoji_str = str(payload.emoji) - if emoji_str != settings.get('star_emoji', '⭐'): + if emoji_str != settings.get("star_emoji", "⭐"): return # Process the star reaction @@ -72,12 +81,16 @@ class StarboardCog(commands.Cog): # Get starboard settings for this guild settings = await settings_manager.get_starboard_settings(guild.id) - if not settings or not settings.get('enabled') or not settings.get('starboard_channel_id'): + if ( + not settings + or not settings.get("enabled") + or not settings.get("starboard_channel_id") + ): return # Check if the emoji matches the configured star emoji emoji_str = str(payload.emoji) - if emoji_str != settings.get('star_emoji', '⭐'): + if emoji_str != settings.get("star_emoji", "⭐"): return # Process the star reaction removal @@ -88,7 +101,7 @@ class StarboardCog(commands.Cog): # Get the channels guild = self.bot.get_guild(payload.guild_id) source_channel = guild.get_channel(payload.channel_id) - starboard_channel = guild.get_channel(settings.get('starboard_channel_id')) + starboard_channel = guild.get_channel(settings.get("starboard_channel_id")) if not source_channel or not starboard_channel: return @@ -100,7 +113,9 @@ class StarboardCog(commands.Cog): # Acquire lock for this message to prevent race conditions message_key = f"{payload.guild_id}:{payload.message_id}" if message_key in self.pending_updates: - log.debug(f"Skipping concurrent update for message {payload.message_id} in guild {payload.guild_id}") + log.debug( + f"Skipping concurrent update for message {payload.message_id} in guild {payload.guild_id}" + ) return self.pending_updates[message_key] = True @@ -113,28 +128,44 @@ class StarboardCog(commands.Cog): message = await source_channel.fetch_message(payload.message_id) break except discord.NotFound: - log.warning(f"Message {payload.message_id} not found in channel {source_channel.id}") + log.warning( + f"Message {payload.message_id} not found in channel {source_channel.id}" + ) return except discord.HTTPException as e: if attempt < retry_attempts - 1: - log.warning(f"Error fetching message {payload.message_id}, attempt {attempt+1}/{retry_attempts}: {e}") + log.warning( + f"Error fetching message {payload.message_id}, attempt {attempt+1}/{retry_attempts}: {e}" + ) await asyncio.sleep(1) # Wait before retrying else: - log.error(f"Failed to fetch message {payload.message_id} after {retry_attempts} attempts: {e}") + log.error( + f"Failed to fetch message {payload.message_id} after {retry_attempts} attempts: {e}" + ) return if not message: - log.error(f"Could not retrieve message {payload.message_id} after multiple attempts") + log.error( + f"Could not retrieve message {payload.message_id} after multiple attempts" + ) return # Check if message is from a bot and if we should ignore bot messages - if message.author.bot and settings.get('ignore_bots', True): - log.debug(f"Ignoring bot message {message.id} from {message.author.name}") + if message.author.bot and settings.get("ignore_bots", True): + log.debug( + f"Ignoring bot message {message.id} from {message.author.name}" + ) return # Check if the user is starring their own message and if that's allowed - if is_add and payload.user_id == message.author.id and not settings.get('self_star', False): - log.debug(f"User {payload.user_id} attempted to star their own message {message.id}, but self-starring is disabled") + if ( + is_add + and payload.user_id == message.author.id + and not settings.get("self_star", False) + ): + log.debug( + f"User {payload.user_id} attempted to star their own message {message.id}, but self-starring is disabled" + ) return # Update the reaction in the database with retry logic @@ -156,40 +187,54 @@ class StarboardCog(commands.Cog): break # If we couldn't get a valid count, try to fetch it directly - star_count = await settings_manager.get_starboard_reaction_count(guild.id, message.id) + star_count = await settings_manager.get_starboard_reaction_count( + guild.id, message.id + ) if isinstance(star_count, int): break except Exception as e: if attempt < retry_attempts - 1: - log.warning(f"Error updating reaction for message {message.id}, attempt {attempt+1}/{retry_attempts}: {e}") + log.warning( + f"Error updating reaction for message {message.id}, attempt {attempt+1}/{retry_attempts}: {e}" + ) await asyncio.sleep(1) # Wait before retrying else: - log.error(f"Failed to update reaction for message {message.id} after {retry_attempts} attempts: {e}") + log.error( + f"Failed to update reaction for message {message.id} after {retry_attempts} attempts: {e}" + ) return if not isinstance(star_count, int): log.error(f"Could not get valid star count for message {message.id}") return - log.info(f"Message {message.id} in guild {guild.id} now has {star_count} stars (action: {'add' if is_add else 'remove'})") + log.info( + f"Message {message.id} in guild {guild.id} now has {star_count} stars (action: {'add' if is_add else 'remove'})" + ) # Get the threshold from settings - threshold = settings.get('threshold', 3) + threshold = settings.get("threshold", 3) # Check if this message is already in the starboard entry = None retry_attempts = 3 for attempt in range(retry_attempts): try: - entry = await settings_manager.get_starboard_entry(guild.id, message.id) + entry = await settings_manager.get_starboard_entry( + guild.id, message.id + ) break except Exception as e: if attempt < retry_attempts - 1: - log.warning(f"Error getting starboard entry for message {message.id}, attempt {attempt+1}/{retry_attempts}: {e}") + log.warning( + f"Error getting starboard entry for message {message.id}, attempt {attempt+1}/{retry_attempts}: {e}" + ) await asyncio.sleep(1) # Wait before retrying else: - log.error(f"Failed to get starboard entry for message {message.id} after {retry_attempts} attempts: {e}") + log.error( + f"Failed to get starboard entry for message {message.id} after {retry_attempts} attempts: {e}" + ) # Continue with entry=None, which will create a new entry if needed if star_count >= threshold: @@ -197,53 +242,97 @@ class StarboardCog(commands.Cog): if entry: # Update existing entry try: - starboard_message = await starboard_channel.fetch_message(entry.get('starboard_message_id')) - await self._update_starboard_message(starboard_message, message, star_count) - await settings_manager.update_starboard_entry(guild.id, message.id, star_count) - log.info(f"Updated starboard message {starboard_message.id} for original message {message.id}") + starboard_message = await starboard_channel.fetch_message( + entry.get("starboard_message_id") + ) + await self._update_starboard_message( + starboard_message, message, star_count + ) + await settings_manager.update_starboard_entry( + guild.id, message.id, star_count + ) + log.info( + f"Updated starboard message {starboard_message.id} for original message {message.id}" + ) except discord.NotFound: # Starboard message was deleted, create a new one - log.warning(f"Starboard message {entry.get('starboard_message_id')} was deleted, creating a new one") - starboard_message = await self._create_starboard_message(starboard_channel, message, star_count) + log.warning( + f"Starboard message {entry.get('starboard_message_id')} was deleted, creating a new one" + ) + starboard_message = await self._create_starboard_message( + starboard_channel, message, star_count + ) if starboard_message: await settings_manager.create_starboard_entry( - guild.id, message.id, source_channel.id, - starboard_message.id, message.author.id, star_count + guild.id, + message.id, + source_channel.id, + starboard_message.id, + message.author.id, + star_count, + ) + log.info( + f"Created new starboard message {starboard_message.id} for original message {message.id}" ) - log.info(f"Created new starboard message {starboard_message.id} for original message {message.id}") except discord.HTTPException as e: - log.error(f"Error updating starboard message for {message.id}: {e}") + log.error( + f"Error updating starboard message for {message.id}: {e}" + ) else: # Create new entry - log.info(f"Creating new starboard entry for message {message.id} with {star_count} stars") - starboard_message = await self._create_starboard_message(starboard_channel, message, star_count) + log.info( + f"Creating new starboard entry for message {message.id} with {star_count} stars" + ) + starboard_message = await self._create_starboard_message( + starboard_channel, message, star_count + ) if starboard_message: await settings_manager.create_starboard_entry( - guild.id, message.id, source_channel.id, - starboard_message.id, message.author.id, star_count + guild.id, + message.id, + source_channel.id, + starboard_message.id, + message.author.id, + star_count, + ) + log.info( + f"Created starboard message {starboard_message.id} for original message {message.id}" ) - log.info(f"Created starboard message {starboard_message.id} for original message {message.id}") elif entry: # Message is below threshold but exists in starboard - log.info(f"Message {message.id} now has {star_count} stars, below threshold of {threshold}. Removing from starboard.") + log.info( + f"Message {message.id} now has {star_count} stars, below threshold of {threshold}. Removing from starboard." + ) try: # Delete the starboard message if it exists - starboard_message = await starboard_channel.fetch_message(entry.get('starboard_message_id')) + starboard_message = await starboard_channel.fetch_message( + entry.get("starboard_message_id") + ) await starboard_message.delete() - log.info(f"Deleted starboard message {entry.get('starboard_message_id')}") + log.info( + f"Deleted starboard message {entry.get('starboard_message_id')}" + ) except discord.NotFound: - log.warning(f"Starboard message {entry.get('starboard_message_id')} already deleted") + log.warning( + f"Starboard message {entry.get('starboard_message_id')} already deleted" + ) except discord.HTTPException as e: - log.error(f"Error deleting starboard message {entry.get('starboard_message_id')}: {e}") + log.error( + f"Error deleting starboard message {entry.get('starboard_message_id')}: {e}" + ) # Delete the entry from the database await settings_manager.delete_starboard_entry(guild.id, message.id) except Exception as e: - log.exception(f"Unexpected error processing star reaction for message {payload.message_id}: {e}") + log.exception( + f"Unexpected error processing star reaction for message {payload.message_id}: {e}" + ) finally: # Release the lock self.pending_updates.pop(message_key, None) - log.debug(f"Released lock for message {payload.message_id} in guild {payload.guild_id}") + log.debug( + f"Released lock for message {payload.message_id} in guild {payload.guild_id}" + ) async def _create_starboard_message(self, starboard_channel, message, star_count): """Create a new message in the starboard channel.""" @@ -259,7 +348,9 @@ class StarboardCog(commands.Cog): log.error(f"Error creating starboard message: {e}") return None - async def _update_starboard_message(self, starboard_message, original_message, star_count): + async def _update_starboard_message( + self, starboard_message, original_message, star_count + ): """Update an existing message in the starboard channel.""" try: embed = self._create_starboard_embed(original_message, star_count) @@ -281,13 +372,12 @@ class StarboardCog(commands.Cog): embed = discord.Embed( description=message.content, color=0xFFAC33, # Gold color for stars - timestamp=message.created_at + timestamp=message.created_at, ) # Set author information embed.set_author( - name=message.author.display_name, - icon_url=message.author.display_avatar.url + name=message.author.display_name, icon_url=message.author.display_avatar.url ) # Add footer with message ID for reference @@ -297,13 +387,17 @@ class StarboardCog(commands.Cog): if message.attachments: # If it's an image, add it to the embed for attachment in message.attachments: - if attachment.content_type and attachment.content_type.startswith('image/'): + if attachment.content_type and attachment.content_type.startswith( + "image/" + ): embed.set_image(url=attachment.url) break # Add a field listing all attachments if len(message.attachments) > 1: - attachment_list = "\n".join([f"[{a.filename}]({a.url})" for a in message.attachments]) + attachment_list = "\n".join( + [f"[{a.filename}]({a.url})" for a in message.attachments] + ) embed.add_field(name="Attachments", value=attachment_list, inline=False) return embed @@ -321,19 +415,27 @@ class StarboardCog(commands.Cog): # --- Starboard Commands --- - @commands.hybrid_group(name="starboard", description="Manage the starboard settings") + @commands.hybrid_group( + name="starboard", description="Manage the starboard settings" + ) @commands.has_permissions(manage_guild=True) @app_commands.default_permissions(manage_guild=True) async def starboard_group(self, ctx): """Commands for managing the starboard feature.""" if ctx.invoked_subcommand is None: - await ctx.send("Please specify a subcommand. Use `help starboard` for more information.") + await ctx.send( + "Please specify a subcommand. Use `help starboard` for more information." + ) - @starboard_group.command(name="enable", description="Enable or disable the starboard") + @starboard_group.command( + name="enable", description="Enable or disable the starboard" + ) @app_commands.describe(enabled="Whether to enable or disable the starboard") async def starboard_enable(self, ctx, enabled: bool): """Enable or disable the starboard feature.""" - success = await settings_manager.update_starboard_settings(ctx.guild.id, enabled=enabled) + success = await settings_manager.update_starboard_settings( + ctx.guild.id, enabled=enabled + ) if success: status = "enabled" if enabled else "disabled" @@ -341,18 +443,24 @@ class StarboardCog(commands.Cog): else: await ctx.send("❌ Failed to update starboard settings.") - @starboard_group.command(name="channel", description="Set the channel for starboard posts") + @starboard_group.command( + name="channel", description="Set the channel for starboard posts" + ) @app_commands.describe(channel="The channel to use for starboard posts") async def starboard_channel(self, ctx, channel: discord.TextChannel): """Set the channel where starboard messages will be posted.""" - success = await settings_manager.update_starboard_settings(ctx.guild.id, starboard_channel_id=channel.id) + success = await settings_manager.update_starboard_settings( + ctx.guild.id, starboard_channel_id=channel.id + ) if success: await ctx.send(f"✅ Starboard channel set to {channel.mention}.") else: await ctx.send("❌ Failed to update starboard channel.") - @starboard_group.command(name="threshold", description="Set the minimum number of stars needed") + @starboard_group.command( + name="threshold", description="Set the minimum number of stars needed" + ) @app_commands.describe(threshold="The minimum number of stars needed (1-25)") async def starboard_threshold(self, ctx, threshold: int): """Set the minimum number of stars needed for a message to appear on the starboard.""" @@ -360,14 +468,18 @@ class StarboardCog(commands.Cog): await ctx.send("❌ Threshold must be between 1 and 25.") return - success = await settings_manager.update_starboard_settings(ctx.guild.id, threshold=threshold) + success = await settings_manager.update_starboard_settings( + ctx.guild.id, threshold=threshold + ) if success: await ctx.send(f"✅ Starboard threshold set to {threshold} stars.") else: await ctx.send("❌ Failed to update starboard threshold.") - @starboard_group.command(name="emoji", description="Set the emoji used for starring messages") + @starboard_group.command( + name="emoji", description="Set the emoji used for starring messages" + ) @app_commands.describe(emoji="The emoji to use for starring messages") async def starboard_emoji(self, ctx, emoji: str): """Set the emoji that will be used for starring messages.""" @@ -376,18 +488,24 @@ class StarboardCog(commands.Cog): await ctx.send("❌ Please provide a valid emoji.") return - success = await settings_manager.update_starboard_settings(ctx.guild.id, star_emoji=emoji) + success = await settings_manager.update_starboard_settings( + ctx.guild.id, star_emoji=emoji + ) if success: await ctx.send(f"✅ Starboard emoji set to {emoji}.") else: await ctx.send("❌ Failed to update starboard emoji.") - @starboard_group.command(name="ignorebots", description="Set whether to ignore bot messages") + @starboard_group.command( + name="ignorebots", description="Set whether to ignore bot messages" + ) @app_commands.describe(ignore="Whether to ignore messages from bots") async def starboard_ignorebots(self, ctx, ignore: bool): """Set whether messages from bots should be ignored for the starboard.""" - success = await settings_manager.update_starboard_settings(ctx.guild.id, ignore_bots=ignore) + success = await settings_manager.update_starboard_settings( + ctx.guild.id, ignore_bots=ignore + ) if success: status = "will be ignored" if ignore else "will be included" @@ -395,11 +513,16 @@ class StarboardCog(commands.Cog): else: await ctx.send("❌ Failed to update bot message handling.") - @starboard_group.command(name="selfstar", description="Allow or disallow users to star their own messages") + @starboard_group.command( + name="selfstar", + description="Allow or disallow users to star their own messages", + ) @app_commands.describe(allow="Whether to allow users to star their own messages") async def starboard_selfstar(self, ctx, allow: bool): """Set whether users can star their own messages.""" - success = await settings_manager.update_starboard_settings(ctx.guild.id, self_star=allow) + success = await settings_manager.update_starboard_settings( + ctx.guild.id, self_star=allow + ) if success: status = "can" if allow else "cannot" @@ -407,7 +530,9 @@ class StarboardCog(commands.Cog): else: await ctx.send("❌ Failed to update self-starring setting.") - @starboard_group.command(name="settings", description="Show current starboard settings") + @starboard_group.command( + name="settings", description="Show current starboard settings" + ) async def starboard_settings(self, ctx): """Display the current starboard settings.""" settings = await settings_manager.get_starboard_settings(ctx.guild.id) @@ -420,20 +545,36 @@ class StarboardCog(commands.Cog): embed = discord.Embed( title="Starboard Settings", color=discord.Color.gold(), - timestamp=datetime.datetime.now() + timestamp=datetime.datetime.now(), ) # Add fields for each setting - embed.add_field(name="Status", value="Enabled" if settings.get('enabled') else "Disabled", inline=True) + embed.add_field( + name="Status", + value="Enabled" if settings.get("enabled") else "Disabled", + inline=True, + ) - channel_id = settings.get('starboard_channel_id') + channel_id = settings.get("starboard_channel_id") channel_mention = f"<#{channel_id}>" if channel_id else "Not set" embed.add_field(name="Channel", value=channel_mention, inline=True) - embed.add_field(name="Threshold", value=str(settings.get('threshold', 3)), inline=True) - embed.add_field(name="Emoji", value=settings.get('star_emoji', '⭐'), inline=True) - embed.add_field(name="Ignore Bots", value="Yes" if settings.get('ignore_bots', True) else "No", inline=True) - embed.add_field(name="Self-starring", value="Allowed" if settings.get('self_star', False) else "Not allowed", inline=True) + embed.add_field( + name="Threshold", value=str(settings.get("threshold", 3)), inline=True + ) + embed.add_field( + name="Emoji", value=settings.get("star_emoji", "⭐"), inline=True + ) + embed.add_field( + name="Ignore Bots", + value="Yes" if settings.get("ignore_bots", True) else "No", + inline=True, + ) + embed.add_field( + name="Self-starring", + value="Allowed" if settings.get("self_star", False) else "Not allowed", + inline=True, + ) await ctx.send(embed=embed) @@ -443,10 +584,16 @@ class StarboardCog(commands.Cog): async def starboard_clear(self, ctx): """Clear all entries from the starboard.""" # Ask for confirmation - await ctx.send("⚠️ **Warning**: This will delete all starboard entries for this server. Are you sure? (yes/no)") + await ctx.send( + "⚠️ **Warning**: This will delete all starboard entries for this server. Are you sure? (yes/no)" + ) def check(m): - return m.author == ctx.author and m.channel == ctx.channel and m.content.lower() in ["yes", "no"] + return ( + m.author == ctx.author + and m.channel == ctx.channel + and m.content.lower() in ["yes", "no"] + ) try: # Wait for confirmation @@ -458,11 +605,13 @@ class StarboardCog(commands.Cog): # Get the starboard channel settings = await settings_manager.get_starboard_settings(ctx.guild.id) - if not settings or not settings.get('starboard_channel_id'): + if not settings or not settings.get("starboard_channel_id"): await ctx.send("❌ Starboard channel not set.") return - starboard_channel = ctx.guild.get_channel(settings.get('starboard_channel_id')) + starboard_channel = ctx.guild.get_channel( + settings.get("starboard_channel_id") + ) if not starboard_channel: await ctx.send("❌ Starboard channel not found.") return @@ -475,7 +624,9 @@ class StarboardCog(commands.Cog): return # Delete all messages from the starboard channel - status_message = await ctx.send(f"🔄 Clearing {len(entries)} entries from the starboard...") + status_message = await ctx.send( + f"🔄 Clearing {len(entries)} entries from the starboard..." + ) deleted_count = 0 failed_count = 0 @@ -487,41 +638,59 @@ class StarboardCog(commands.Cog): for entry in entries_list: try: try: - message = await starboard_channel.fetch_message(entry['starboard_message_id']) + message = await starboard_channel.fetch_message( + entry["starboard_message_id"] + ) await message.delete() deleted_count += 1 except discord.NotFound: # Message already deleted deleted_count += 1 except discord.HTTPException as e: - log.error(f"Error deleting starboard message {entry['starboard_message_id']}: {e}") + log.error( + f"Error deleting starboard message {entry['starboard_message_id']}: {e}" + ) failed_count += 1 except Exception as e: log.error(f"Unexpected error deleting starboard message: {e}") failed_count += 1 - await status_message.edit(content=f"✅ Starboard cleared. Deleted {deleted_count} messages. Failed to delete {failed_count} messages.") + await status_message.edit( + content=f"✅ Starboard cleared. Deleted {deleted_count} messages. Failed to delete {failed_count} messages." + ) except asyncio.TimeoutError: await ctx.send("❌ Confirmation timed out. Operation cancelled.") except Exception as e: log.exception(f"Error clearing starboard: {e}") - await ctx.send(f"❌ An error occurred while clearing the starboard: {str(e)}") + await ctx.send( + f"❌ An error occurred while clearing the starboard: {str(e)}" + ) - @starboard_group.command(name="add", description="Manually add a message to the starboard") + @starboard_group.command( + name="add", description="Manually add a message to the starboard" + ) @commands.has_permissions(administrator=True) @app_commands.default_permissions(administrator=True) - @app_commands.describe(message_id_or_link="The message ID or link to add to the starboard") + @app_commands.describe( + message_id_or_link="The message ID or link to add to the starboard" + ) async def starboard_add(self, ctx, message_id_or_link: str): """Manually add a message to the starboard using its ID or link.""" # Get starboard settings settings = await settings_manager.get_starboard_settings(ctx.guild.id) - if not settings or not settings.get('enabled') or not settings.get('starboard_channel_id'): - await ctx.send("❌ Starboard is not properly configured. Please set up the starboard first.") + if ( + not settings + or not settings.get("enabled") + or not settings.get("starboard_channel_id") + ): + await ctx.send( + "❌ Starboard is not properly configured. Please set up the starboard first." + ) return # Get the starboard channel - starboard_channel = ctx.guild.get_channel(settings.get('starboard_channel_id')) + starboard_channel = ctx.guild.get_channel(settings.get("starboard_channel_id")) if not starboard_channel: await ctx.send("❌ Starboard channel not found.") return @@ -561,35 +730,53 @@ class StarboardCog(commands.Cog): if not message: if channel_found: - await ctx.send("❌ Message not found. Make sure the message ID is correct.") + await ctx.send( + "❌ Message not found. Make sure the message ID is correct." + ) else: - await ctx.send("❌ Message not found. The bot might not have access to the channel containing this message.") + await ctx.send( + "❌ Message not found. The bot might not have access to the channel containing this message." + ) return # Check if the message is already in the starboard entry = await settings_manager.get_starboard_entry(ctx.guild.id, message.id) if entry: - await ctx.send(f"⚠️ This message is already in the starboard with {entry.get('star_count', 0)} stars.") + await ctx.send( + f"⚠️ This message is already in the starboard with {entry.get('star_count', 0)} stars." + ) return # Check if the message is from the starboard channel if message.channel.id == starboard_channel.id: - await ctx.send("❌ Cannot add a message from the starboard channel to the starboard.") + await ctx.send( + "❌ Cannot add a message from the starboard channel to the starboard." + ) return # Set a default star count (1 more than the threshold) - threshold = settings.get('threshold', 3) + threshold = settings.get("threshold", 3) star_count = threshold # Create a new starboard entry - starboard_message = await self._create_starboard_message(starboard_channel, message, star_count) + starboard_message = await self._create_starboard_message( + starboard_channel, message, star_count + ) if starboard_message: await settings_manager.create_starboard_entry( - ctx.guild.id, message.id, message.channel.id, - starboard_message.id, message.author.id, star_count + ctx.guild.id, + message.id, + message.channel.id, + starboard_message.id, + message.author.id, + star_count, + ) + await ctx.send( + f"✅ Message successfully added to the starboard with {star_count} stars." + ) + log.info( + f"Admin {ctx.author.id} manually added message {message.id} to starboard in guild {ctx.guild.id}" ) - await ctx.send(f"✅ Message successfully added to the starboard with {star_count} stars.") - log.info(f"Admin {ctx.author.id} manually added message {message.id} to starboard in guild {ctx.guild.id}") else: await ctx.send("❌ Failed to create starboard message.") @@ -618,7 +805,7 @@ class StarboardCog(commands.Cog): SELECT COUNT(*) FROM starboard_entries WHERE guild_id = $1 """, - ctx.guild.id + ctx.guild.id, ) # Get the total number of reactions @@ -627,7 +814,7 @@ class StarboardCog(commands.Cog): SELECT COUNT(*) FROM starboard_reactions WHERE guild_id = $1 """, - ctx.guild.id + ctx.guild.id, ) # Get the most starred message @@ -638,25 +825,29 @@ class StarboardCog(commands.Cog): ORDER BY star_count DESC LIMIT 1 """, - ctx.guild.id + ctx.guild.id, ) # Create an embed to display the statistics embed = discord.Embed( title="Starboard Statistics", color=discord.Color.gold(), - timestamp=datetime.datetime.now() + timestamp=datetime.datetime.now(), ) - embed.add_field(name="Total Entries", value=str(total_entries), inline=True) - embed.add_field(name="Total Reactions", value=str(total_reactions), inline=True) + embed.add_field( + name="Total Entries", value=str(total_entries), inline=True + ) + embed.add_field( + name="Total Reactions", value=str(total_reactions), inline=True + ) if most_starred: most_starred_dict = dict(most_starred) embed.add_field( name="Most Starred Message", value=f"[Jump to Message](https://discord.com/channels/{ctx.guild.id}/{most_starred_dict['original_channel_id']}/{most_starred_dict['original_message_id']})\n{most_starred_dict['star_count']} stars", - inline=False + inline=False, ) await ctx.send(embed=embed) @@ -665,7 +856,10 @@ class StarboardCog(commands.Cog): await bot_instance.pg_pool.release(conn) except Exception as e: log.exception(f"Error getting starboard statistics: {e}") - await ctx.send(f"❌ An error occurred while getting starboard statistics: {str(e)}") + await ctx.send( + f"❌ An error occurred while getting starboard statistics: {str(e)}" + ) + async def setup(bot): """Add the cog to the bot.""" diff --git a/cogs/status_cog.py b/cogs/status_cog.py index 64e7c7c..fb1d562 100644 --- a/cogs/status_cog.py +++ b/cogs/status_cog.py @@ -4,56 +4,67 @@ from discord.ext import commands from discord import app_commands from typing import Optional, Literal + class StatusCog(commands.Cog): """Commands for managing the bot's status""" - + def __init__(self, bot: commands.Bot): self.bot = bot - - async def _set_status_logic(self, - status_type: Literal["playing", "listening", "streaming", "watching", "competing"], - status_text: str, - stream_url: Optional[str] = None) -> str: + + async def _set_status_logic( + self, + status_type: Literal[ + "playing", "listening", "streaming", "watching", "competing" + ], + status_text: str, + stream_url: Optional[str] = None, + ) -> str: """Core logic for setting the bot's status""" - + # Map the status type to the appropriate ActivityType activity_types = { "playing": discord.ActivityType.playing, "listening": discord.ActivityType.listening, "streaming": discord.ActivityType.streaming, "watching": discord.ActivityType.watching, - "competing": discord.ActivityType.competing + "competing": discord.ActivityType.competing, } - + activity_type = activity_types.get(status_type.lower()) - + if not activity_type: return f"Invalid status type: {status_type}. Valid types are: playing, listening, streaming, watching, competing." - + try: # For streaming status, we need a URL if status_type.lower() == "streaming" and stream_url: - await self.bot.change_presence(activity=discord.Streaming(name=status_text, url=stream_url)) + await self.bot.change_presence( + activity=discord.Streaming(name=status_text, url=stream_url) + ) else: - await self.bot.change_presence(activity=discord.Activity(type=activity_type, name=status_text)) - + await self.bot.change_presence( + activity=discord.Activity(type=activity_type, name=status_text) + ) + return f"Status set to: {status_type.capitalize()} {status_text}" except Exception as e: return f"Error setting status: {str(e)}" - + # --- Prefix Command --- @commands.command(name="setstatus") @commands.is_owner() - async def set_status(self, ctx: commands.Context, status_type: str, *, status_text: str): + async def set_status( + self, ctx: commands.Context, status_type: str, *, status_text: str + ): """Set the bot's status (Owner only) - + Valid status types: - playing - listening - streaming (requires a URL in the status text) - watching - competing - + Example: !setstatus playing Minecraft !setstatus listening to music @@ -65,38 +76,46 @@ class StatusCog(commands.Cog): stream_url = None if status_type.lower() == "streaming": parts = status_text.split() - if len(parts) >= 2 and (parts[0].startswith("http://") or parts[0].startswith("https://")): + if len(parts) >= 2 and ( + parts[0].startswith("http://") or parts[0].startswith("https://") + ): stream_url = parts[0] status_text = " ".join(parts[1:]) - + response = await self._set_status_logic(status_type, status_text, stream_url) await ctx.reply(response) - + # --- Slash Command --- @app_commands.command(name="setstatus", description="Set the bot's status") @app_commands.describe( status_type="The type of status to set", status_text="The text to display in the status", - stream_url="URL for streaming status (only required for streaming status)" + stream_url="URL for streaming status (only required for streaming status)", ) - @app_commands.choices(status_type=[ - app_commands.Choice(name="Playing", value="playing"), - app_commands.Choice(name="Listening", value="listening"), - app_commands.Choice(name="Streaming", value="streaming"), - app_commands.Choice(name="Watching", value="watching"), - app_commands.Choice(name="Competing", value="competing") - ]) - async def set_status_slash(self, - interaction: discord.Interaction, - status_type: str, - status_text: str, - stream_url: Optional[str] = None): + @app_commands.choices( + status_type=[ + app_commands.Choice(name="Playing", value="playing"), + app_commands.Choice(name="Listening", value="listening"), + app_commands.Choice(name="Streaming", value="streaming"), + app_commands.Choice(name="Watching", value="watching"), + app_commands.Choice(name="Competing", value="competing"), + ] + ) + async def set_status_slash( + self, + interaction: discord.Interaction, + status_type: str, + status_text: str, + stream_url: Optional[str] = None, + ): """Slash command version of set_status.""" # Check if user is the bot owner if interaction.user.id != self.bot.owner_id: - await interaction.response.send_message("This command can only be used by the bot owner.", ephemeral=True) + await interaction.response.send_message( + "This command can only be used by the bot owner.", ephemeral=True + ) return - + response = await self._set_status_logic(status_type, status_text, stream_url) await interaction.response.send_message(response) @@ -108,11 +127,15 @@ class StatusCog(commands.Cog): await self._send_server_list(ctx.reply) # --- Slash Command for Listing Servers --- - @app_commands.command(name="listservers", description="Lists all servers the bot is in (Owner only)") + @app_commands.command( + name="listservers", description="Lists all servers the bot is in (Owner only)" + ) async def list_servers_slash(self, interaction: discord.Interaction): """Slash command version of list_servers.""" if interaction.user.id != self.bot.owner_id: - await interaction.response.send_message("This command can only be used by the bot owner.", ephemeral=True) + await interaction.response.send_message( + "This command can only be used by the bot owner.", ephemeral=True + ) return # Defer response as gathering info might take time await interaction.response.defer(ephemeral=True) @@ -122,7 +145,7 @@ class StatusCog(commands.Cog): """Helper function to gather server info and send the list.""" guilds = self.bot.guilds server_list_text = [] - max_embed_desc_length = 4096 # Discord embed description limit + max_embed_desc_length = 4096 # Discord embed description limit current_length = 0 embeds = [] @@ -131,17 +154,32 @@ class StatusCog(commands.Cog): invite_link = "N/A" try: # Try system channel first - if guild.system_channel and guild.system_channel.permissions_for(guild.me).create_instant_invite: - invite = await guild.system_channel.create_invite(max_age=3600, max_uses=1, unique=True, reason="Bot owner requested server list (Remove create invite permission to prevent this)") + if ( + guild.system_channel + and guild.system_channel.permissions_for( + guild.me + ).create_instant_invite + ): + invite = await guild.system_channel.create_invite( + max_age=3600, + max_uses=1, + unique=True, + reason="Bot owner requested server list (Remove create invite permission to prevent this)", + ) invite_link = invite.url else: # Fallback to the first channel the bot can create an invite in for channel in guild.text_channels: if channel.permissions_for(guild.me).create_instant_invite: - invite = await channel.create_invite(max_age=3600, max_uses=1, unique=True, reason="Bot owner requested server list (Remove create invite permission to prevent this)") + invite = await channel.create_invite( + max_age=3600, + max_uses=1, + unique=True, + reason="Bot owner requested server list (Remove create invite permission to prevent this)", + ) invite_link = invite.url break - else: # No suitable channel found + else: # No suitable channel found invite_link = "No invite permission" except discord.Forbidden: invite_link = "No invite permission" @@ -150,7 +188,11 @@ class StatusCog(commands.Cog): print(f"Error creating invite for guild {guild.id} ({guild.name}):") traceback.print_exc() - owner_info = f"{guild.owner} ({guild.owner_id})" if guild.owner else f"ID: {guild.owner_id}" + owner_info = ( + f"{guild.owner} ({guild.owner_id})" + if guild.owner + else f"ID: {guild.owner_id}" + ) server_info = ( f"**{guild.name}** (ID: {guild.id})\n" f"- Members: {guild.member_count}\n" @@ -161,7 +203,11 @@ class StatusCog(commands.Cog): # Check if adding this server exceeds the limit for the current embed if current_length + len(server_info) > max_embed_desc_length: # Finalize the current embed - embed = discord.Embed(title=f"Server List (Part {len(embeds) + 1})", description="".join(server_list_text), color=discord.Color.blue()) + embed = discord.Embed( + title=f"Server List (Part {len(embeds) + 1})", + description="".join(server_list_text), + color=discord.Color.blue(), + ) embeds.append(embed) # Start a new embed description server_list_text = [server_info] @@ -172,7 +218,11 @@ class StatusCog(commands.Cog): # Add the last embed if there's remaining text if server_list_text: - embed = discord.Embed(title=f"Server List (Part {len(embeds) + 1})", description="".join(server_list_text), color=discord.Color.blue()) + embed = discord.Embed( + title=f"Server List (Part {len(embeds) + 1})", + description="".join(server_list_text), + color=discord.Color.blue(), + ) embeds.append(embed) if not embeds: @@ -190,7 +240,7 @@ class StatusCog(commands.Cog): # For prefix commands, just send another message # For interactions, use followup.send # This implementation assumes send_func handles this correctly (ctx.reply vs interaction.followup.send) - await send_func(embed=embed, ephemeral=True) + await send_func(embed=embed, ephemeral=True) async def setup(bot: commands.Bot): diff --git a/cogs/sync_cog.py b/cogs/sync_cog.py index de15806..930797e 100644 --- a/cogs/sync_cog.py +++ b/cogs/sync_cog.py @@ -4,6 +4,7 @@ from discord import app_commands import traceback import command_customization + class SyncCog(commands.Cog): def __init__(self, bot): self.bot = bot @@ -22,24 +23,34 @@ class SyncCog(commands.Cog): cmd_info = { "name": cmd.name, "description": cmd.description, - "parameters": [p.name for p in cmd.parameters] if hasattr(cmd, "parameters") else [] + "parameters": ( + [p.name for p in cmd.parameters] + if hasattr(cmd, "parameters") + else [] + ), } commands_before.append(cmd_info) await ctx.send(f"Commands before sync: {len(commands_before)}") for cmd_data in commands_before: params_str = ", ".join(cmd_data["parameters"]) - await ctx.send(f"- {cmd_data['name']}: {len(cmd_data['parameters'])} params ({params_str})") + await ctx.send( + f"- {cmd_data['name']}: {len(cmd_data['parameters'])} params ({params_str})" + ) # Skip global sync to avoid command duplication await ctx.send("Skipping global sync to avoid command duplication...") # Sync guild-specific commands with customizations await ctx.send("Syncing guild-specific command customizations...") - guild_syncs = await command_customization.register_all_guild_commands(self.bot) + guild_syncs = await command_customization.register_all_guild_commands( + self.bot + ) total_guild_syncs = sum(len(cmds) for cmds in guild_syncs.values()) - await ctx.send(f"Synced commands for {len(guild_syncs)} guilds with a total of {total_guild_syncs} customized commands") + await ctx.send( + f"Synced commands for {len(guild_syncs)} guilds with a total of {total_guild_syncs} customized commands" + ) # Get list of commands after sync commands_after = [] @@ -47,21 +58,36 @@ class SyncCog(commands.Cog): cmd_info = { "name": cmd.name, "description": cmd.description, - "parameters": [p.name for p in cmd.parameters] if hasattr(cmd, "parameters") else [] + "parameters": ( + [p.name for p in cmd.parameters] + if hasattr(cmd, "parameters") + else [] + ), } commands_after.append(cmd_info) await ctx.send(f"Commands after sync: {len(commands_after)}") for cmd_data in commands_after: params_str = ", ".join(cmd_data["parameters"]) - await ctx.send(f"- {cmd_data['name']}: {len(cmd_data['parameters'])} params ({params_str})") + await ctx.send( + f"- {cmd_data['name']}: {len(cmd_data['parameters'])} params ({params_str})" + ) # Check for webdrivertorso command specifically - wd_cmd = next((cmd for cmd in self.bot.tree.get_commands() if cmd.name == "webdrivertorso"), None) + wd_cmd = next( + ( + cmd + for cmd in self.bot.tree.get_commands() + if cmd.name == "webdrivertorso" + ), + None, + ) if wd_cmd: await ctx.send("Webdrivertorso command details:") for param in wd_cmd.parameters: - await ctx.send(f"- Param: {param.name}, Type: {param.type}, Required: {param.required}") + await ctx.send( + f"- Param: {param.name}, Type: {param.type}, Required: {param.required}" + ) if hasattr(param, "choices") and param.choices: choices_str = ", ".join([c.name for c in param.choices]) await ctx.send(f" Choices: {choices_str}") @@ -73,6 +99,7 @@ class SyncCog(commands.Cog): await ctx.send(f"Error during sync: {str(e)}") await ctx.send(f"```{traceback.format_exc()}```") + async def setup(bot: commands.Bot): print("Loading SyncCog...") await bot.add_cog(SyncCog(bot)) diff --git a/cogs/system_check_cog.py b/cogs/system_check_cog.py index 9f8cf6f..f79ceea 100644 --- a/cogs/system_check_cog.py +++ b/cogs/system_check_cog.py @@ -5,15 +5,17 @@ import time import psutil import platform import GPUtil -import distro # Ensure this is installed +import distro # Ensure this is installed # Import wmi for Windows motherboard info try: import wmi + WMI_AVAILABLE = True except ImportError: WMI_AVAILABLE = False + class SystemCheckCog(commands.Cog): def __init__(self, bot): self.bot = bot @@ -27,7 +29,9 @@ class SystemCheckCog(commands.Cog): await context_or_interaction.followup.send(embed=embed) except Exception as e: print(f"Error in systemcheck command: {e}") - await context_or_interaction.followup.send(f"An error occurred while checking system status: {e}") + await context_or_interaction.followup.send( + f"An error occurred while checking system status: {e}" + ) async def _system_check_logic(self, context_or_interaction): """Return detailed bot and system information as a Discord embed.""" @@ -159,40 +163,42 @@ class SystemCheckCog(commands.Cog): embed.add_field( name="🤖 Bot Information", value=f"**Name:** {bot_user.name}\n" - f"**ID:** {bot_user.id}\n" - f"**Servers:** {guild_count}\n" - f"**Unique Users:** {user_count}", - inline=False + f"**ID:** {bot_user.id}\n" + f"**Servers:** {guild_count}\n" + f"**Unique Users:** {user_count}", + inline=False, ) else: embed.add_field( name="🤖 Bot Information", value="Bot user information not available.", - inline=False + inline=False, ) # System Info Field embed.add_field( name="🖥️ System Information", value=f"**OS:** {os_info}{distro_info_str}\n" # Use renamed variable - f"**Hostname:** {hostname}\n" - f"**Uptime:** {uptime}", - inline=False + f"**Hostname:** {hostname}\n" + f"**Uptime:** {uptime}", + inline=False, ) # Hardware Info Field embed.add_field( name="⚙️ Hardware Information", value=f"**Device Model:** {motherboard_info}\n" - f"**CPU:** {cpu_name}\n" - f"**CPU Usage:** {cpu_usage}%\n" - f"**RAM Usage:** {ram_usage}\n" - f"**GPU Info:**\n{gpu_info}", - inline=False + f"**CPU:** {cpu_name}\n" + f"**CPU Usage:** {cpu_usage}%\n" + f"**RAM Usage:** {ram_usage}\n" + f"**GPU Info:**\n{gpu_info}", + inline=False, ) if user: - embed.set_footer(text=f"Requested by: {user.display_name}", icon_url=avatar_url) + embed.set_footer( + text=f"Requested by: {user.display_name}", icon_url=avatar_url + ) embed.timestamp = discord.utils.utcnow() return embed @@ -201,23 +207,27 @@ class SystemCheckCog(commands.Cog): @commands.command(name="systemcheck") async def system_check(self, ctx: commands.Context): """Check the bot and system status.""" - embed = await self._system_check_logic(ctx) # Pass context + embed = await self._system_check_logic(ctx) # Pass context await ctx.reply(embed=embed) # --- Slash Command --- - @app_commands.command(name="systemcheck", description="Check the bot and system status") + @app_commands.command( + name="systemcheck", description="Check the bot and system status" + ) async def system_check_slash(self, interaction: discord.Interaction): """Slash command version of system check.""" # Defer the response to prevent interaction timeout await interaction.response.defer(thinking=True) try: - embed = await self._system_check_logic(interaction) # Pass interaction + embed = await self._system_check_logic(interaction) # Pass interaction # Use followup since we've already deferred await interaction.followup.send(embed=embed) except Exception as e: # Handle any errors that might occur during processing print(f"Error in system_check_slash: {e}") - await interaction.followup.send(f"An error occurred while checking system status: {e}") + await interaction.followup.send( + f"An error occurred while checking system status: {e}" + ) def _get_motherboard_info(self): """Get motherboard information based on the operating system.""" @@ -247,5 +257,6 @@ class SystemCheckCog(commands.Cog): print(f"Error getting motherboard info: {e}") return "Error retrieving motherboard info" + async def setup(bot): await bot.add_cog(SystemCheckCog(bot)) diff --git a/cogs/terminal_cog.py b/cogs/terminal_cog.py index 46db879..cebff5a 100644 --- a/cogs/terminal_cog.py +++ b/cogs/terminal_cog.py @@ -10,21 +10,21 @@ import time import aiohttp import asyncio from collections import deque -import shlex # For safer command parsing if not using shell=True for everything +import shlex # For safer command parsing if not using shell=True for everything # --- Configuration --- FONT_PATH = "FONT/DejaVuSansMono.ttf" # IMPORTANT: Make sure this font file (e.g., Courier New) is in the same directory as your bot, or provide an absolute path. - # You can download common monospaced fonts like DejaVuSansMono.ttf +# You can download common monospaced fonts like DejaVuSansMono.ttf FONT_SIZE = 15 IMG_WIDTH = 800 IMG_HEIGHT = 600 PADDING = 10 -LINE_SPACING = 4 # Extra pixels between lines -BACKGROUND_COLOR = (30, 30, 30) # Dark grey -TEXT_COLOR = (220, 220, 220) # Light grey -PROMPT_COLOR = (70, 170, 240) # Blueish -ERROR_COLOR = (255, 100, 100) # Reddish -MAX_HISTORY_LINES = 500 # Max lines to keep in history +LINE_SPACING = 4 # Extra pixels between lines +BACKGROUND_COLOR = (30, 30, 30) # Dark grey +TEXT_COLOR = (220, 220, 220) # Light grey +PROMPT_COLOR = (70, 170, 240) # Blueish +ERROR_COLOR = (255, 100, 100) # Reddish +MAX_HISTORY_LINES = 500 # Max lines to keep in history AUTO_UPDATE_INTERVAL_SECONDS = 3 MAX_OUTPUT_LINES_PER_IMAGE = (IMG_HEIGHT - 2 * PADDING) // (FONT_SIZE + LINE_SPACING) OWNER_ID = 452666956353503252 @@ -32,6 +32,7 @@ TERMINAL_IMAGES_DIR = "terminal_images" # Directory to store terminal images # Use your actual domain or IP address here API_BASE_URL = "https://slipstreamm.dev" # Base URL for the API + # --- Helper: Owner Check --- async def is_owner_check(interaction: discord.Interaction) -> bool: """Checks if the interacting user is the hardcoded bot owner.""" @@ -43,7 +44,7 @@ class TerminalCog(commands.Cog, name="Terminal"): def __init__(self, bot: commands.Bot): self.bot = bot - self.owner_id: int = 0 # Will be set in cog_load + self.owner_id: int = 0 # Will be set in cog_load self.terminal_active: bool = False self.current_cwd: str = os.getcwd() self.output_history: deque[str] = deque(maxlen=MAX_HISTORY_LINES) @@ -51,7 +52,9 @@ class TerminalCog(commands.Cog, name="Terminal"): self.terminal_message: discord.Message | None = None self.active_process: subprocess.Popen | None = None self.terminal_view: TerminalView | None = None - self.last_command: str | None = None # Store the last command for display after execution + self.last_command: str | None = ( + None # Store the last command for display after execution + ) # Ensure the terminal_images directory exists with proper permissions os.makedirs(TERMINAL_IMAGES_DIR, exist_ok=True) @@ -66,10 +69,14 @@ class TerminalCog(commands.Cog, name="Terminal"): try: self.font = ImageFont.truetype(FONT_PATH, FONT_SIZE) except IOError: - print(f"Error: Font file '{FONT_PATH}' not found. Using default PIL font. Terminal image quality may be affected.") + print( + f"Error: Font file '{FONT_PATH}' not found. Using default PIL font. Terminal image quality may be affected." + ) self.font = ImageFont.load_default() - self.auto_update_task = tasks.loop(seconds=AUTO_UPDATE_INTERVAL_SECONDS)(self.refresh_terminal_output) + self.auto_update_task = tasks.loop(seconds=AUTO_UPDATE_INTERVAL_SECONDS)( + self.refresh_terminal_output + ) # Ensure cog_load is defined to set owner_id properly # self.bot.loop.create_task(self._async_init()) # Alternative for async setup @@ -79,11 +86,17 @@ class TerminalCog(commands.Cog, name="Terminal"): if app_info.team: # For teams, owner_id might be the team owner's ID or you might need a list of allowed admins. # This example will use the team owner if available, otherwise the first listed owner. - self.owner_id = app_info.owner.owner_id if app_info.owner and hasattr(app_info.owner, 'owner_id') else (app_info.owner.id if app_info.owner else 0) + self.owner_id = ( + app_info.owner.owner_id + if app_info.owner and hasattr(app_info.owner, "owner_id") + else (app_info.owner.id if app_info.owner else 0) + ) elif app_info.owner: self.owner_id = app_info.owner.id else: - print("Warning: Bot owner ID could not be determined. Terminal cog owner checks might fail.") + print( + "Warning: Bot owner ID could not be determined. Terminal cog owner checks might fail." + ) # Fallback or raise error # self.owner_id = YOUR_FALLBACK_OWNER_ID # if you have one @@ -100,7 +113,9 @@ class TerminalCog(commands.Cog, name="Terminal"): if response.status == 200: print(f"Successfully warmed Cloudflare cache for {image_url}") else: - print(f"Failed to warm Cloudflare cache for {image_url}: HTTP {response.status}") + print( + f"Failed to warm Cloudflare cache for {image_url}: HTTP {response.status}" + ) except Exception as e: print(f"Error warming Cloudflare cache for {image_url}: {e}") @@ -109,27 +124,34 @@ class TerminalCog(commands.Cog, name="Terminal"): Generates an image of the current terminal output. Returns a tuple of (BytesIO object, filename) """ - image = Image.new('RGB', (IMG_WIDTH, IMG_HEIGHT), BACKGROUND_COLOR) + image = Image.new("RGB", (IMG_WIDTH, IMG_HEIGHT), BACKGROUND_COLOR) draw = ImageDraw.Draw(image) - char_width, _ = self.font.getbbox("M")[2:] # Get width of a character - if char_width == 0: char_width = FONT_SIZE // 2 # Estimate if getbbox fails for default font + char_width, _ = self.font.getbbox("M")[2:] # Get width of a character + if char_width == 0: + char_width = FONT_SIZE // 2 # Estimate if getbbox fails for default font y_pos = PADDING # Determine visible lines based on scroll offset start_index = self.scroll_offset - end_index = min(len(self.output_history), self.scroll_offset + MAX_OUTPUT_LINES_PER_IMAGE) + end_index = min( + len(self.output_history), self.scroll_offset + MAX_OUTPUT_LINES_PER_IMAGE + ) visible_lines = list(self.output_history)[start_index:end_index] for line in visible_lines: # Basic coloring for prompt or errors display_color = TEXT_COLOR - if line.strip().endswith(">") and self.current_cwd in line : # Basic prompt detection - display_color = PROMPT_COLOR - elif "error" in line.lower() or "failed" in line.lower(): # Basic error detection - display_color = ERROR_COLOR + if ( + line.strip().endswith(">") and self.current_cwd in line + ): # Basic prompt detection + display_color = PROMPT_COLOR + elif ( + "error" in line.lower() or "failed" in line.lower() + ): # Basic error detection + display_color = ERROR_COLOR # Handle lines longer than image width (simple truncation) # A more advanced version could wrap text or allow horizontal scrolling. @@ -142,7 +164,7 @@ class TerminalCog(commands.Cog, name="Terminal"): draw.text((PADDING, y_pos), line, font=self.font, fill=display_color) y_pos += FONT_SIZE + LINE_SPACING if y_pos > IMG_HEIGHT - PADDING - FONT_SIZE: - break # Stop if no more space + break # Stop if no more space # Create a unique filename with timestamp filename = f"terminal_{uuid.uuid4().hex[:8]}_{int(time.time())}.png" @@ -152,16 +174,18 @@ class TerminalCog(commands.Cog, name="Terminal"): # Save the image to the terminal_images directory file_path = os.path.join(TERMINAL_IMAGES_DIR, filename) - image.save(file_path, format='PNG') + image.save(file_path, format="PNG") # Also return a BytesIO object for backward compatibility img_byte_arr = io.BytesIO() - image.save(img_byte_arr, format='PNG') + image.save(img_byte_arr, format="PNG") img_byte_arr.seek(0) return img_byte_arr, filename - async def _update_terminal_message(self, interaction: Interaction | None = None, new_content: str | None = None): + async def _update_terminal_message( + self, interaction: Interaction | None = None, new_content: str | None = None + ): """Updates the terminal message with a new image and view.""" if not self.terminal_message and interaction: # This case should ideally be handled by sending a new message @@ -170,7 +194,9 @@ class TerminalCog(commands.Cog, name="Terminal"): if not self.terminal_active: if self.terminal_message: - await self.terminal_message.edit(content="Terminal session ended.", view=None, attachments=[]) + await self.terminal_message.edit( + content="Terminal session ended.", view=None, attachments=[] + ) return # Generate the image and save it to the terminal_images directory @@ -184,16 +210,26 @@ class TerminalCog(commands.Cog, name="Terminal"): asyncio.create_task(self._warm_cloudflare_cache(image_url)) if self.terminal_view: - self.terminal_view.update_button_states(self) # Update button enable/disable + self.terminal_view.update_button_states( + self + ) # Update button enable/disable # Prepare the message content with the image URL - content = f"Terminal Output: [View Image]({image_url})" if not new_content else new_content - edit_kwargs = {"content": content, "view": self.terminal_view, "attachments": []} + content = ( + f"Terminal Output: [View Image]({image_url})" + if not new_content + else new_content + ) + edit_kwargs = { + "content": content, + "view": self.terminal_view, + "attachments": [], + } try: if interaction and not interaction.response.is_done(): - await interaction.response.edit_message(**edit_kwargs) - if not self.terminal_message: # If interaction was the first one + await interaction.response.edit_message(**edit_kwargs) + if not self.terminal_message: # If interaction was the first one self.terminal_message = await interaction.original_response() elif self.terminal_message: await self.terminal_message.edit(**edit_kwargs) @@ -207,9 +243,10 @@ class TerminalCog(commands.Cog, name="Terminal"): await self.stop_terminal_session() except discord.HTTPException as e: print(f"Error updating terminal message: {e}") - if e.status == 429: # Rate limited - print("Rate limited. Auto-update might be too fast or manual refresh too frequent.") - + if e.status == 429: # Rate limited + print( + "Rate limited. Auto-update might be too fast or manual refresh too frequent." + ) async def stop_terminal_session(self, interaction: Interaction | None = None): """Stops the terminal session and cleans up.""" @@ -218,10 +255,10 @@ class TerminalCog(commands.Cog, name="Terminal"): self.auto_update_task.cancel() if self.active_process: try: - self.active_process.terminate() # Try to terminate gracefully - self.active_process.wait(timeout=1.0) # Wait a bit + self.active_process.terminate() # Try to terminate gracefully + self.active_process.wait(timeout=1.0) # Wait a bit except subprocess.TimeoutExpired: - self.active_process.kill() # Force kill if terminate fails + self.active_process.kill() # Force kill if terminate fails except Exception as e: print(f"Error terminating process: {e}") self.active_process = None @@ -229,38 +266,50 @@ class TerminalCog(commands.Cog, name="Terminal"): final_message = "Terminal session ended." if self.terminal_message: try: - await self.terminal_message.edit(content=final_message, view=None, attachments=[]) + await self.terminal_message.edit( + content=final_message, view=None, attachments=[] + ) except discord.HTTPException: - pass # Message might already be gone - elif interaction: # If no persistent message, respond to interaction - if not interaction.response.is_done(): + pass # Message might already be gone + elif interaction: # If no persistent message, respond to interaction + if not interaction.response.is_done(): await interaction.response.send_message(final_message, ephemeral=True) - else: + else: await interaction.followup.send(final_message, ephemeral=True) self.terminal_message = None self.output_history.clear() self.scroll_offset = 0 - - @app_commands.command(name="terminal", description="Starts an owner-only terminal session.") + @app_commands.command( + name="terminal", description="Starts an owner-only terminal session." + ) @app_commands.check(is_owner_check) async def terminal_command(self, interaction: Interaction): """Starts the terminal interface.""" if self.terminal_active and self.terminal_message: - await interaction.response.send_message(f"A terminal session is already active. View it here: {self.terminal_message.jump_url}", ephemeral=True) + await interaction.response.send_message( + f"A terminal session is already active. View it here: {self.terminal_message.jump_url}", + ephemeral=True, + ) return - await interaction.response.defer(ephemeral=False) # Ephemeral False to allow message editing + await interaction.response.defer( + ephemeral=False + ) # Ephemeral False to allow message editing self.terminal_active = True self.current_cwd = os.getcwd() self.output_history.clear() self.output_history.append(f"Discord Terminal Initialized.") - self.output_history.append(f"Owner: {interaction.user.name} ({interaction.user.id})") + self.output_history.append( + f"Owner: {interaction.user.name} ({interaction.user.id})" + ) self.output_history.append(f"Current CWD: {self.current_cwd}") self.output_history.append(f"{self.current_cwd}> ") - self.scroll_offset = max(0, len(self.output_history) - MAX_OUTPUT_LINES_PER_IMAGE) # Scroll to bottom + self.scroll_offset = max( + 0, len(self.output_history) - MAX_OUTPUT_LINES_PER_IMAGE + ) # Scroll to bottom self.terminal_view = TerminalView(cog=self, owner_id=self.owner_id) @@ -278,30 +327,40 @@ class TerminalCog(commands.Cog, name="Terminal"): # Use followup since we deferred self.terminal_message = await interaction.followup.send( content=f"Terminal Output: [View Image]({image_url})", - view=self.terminal_view + view=self.terminal_view, ) - self.terminal_view.message = self.terminal_message # Give view a reference to the message + self.terminal_view.message = ( + self.terminal_message + ) # Give view a reference to the message @terminal_command.error - async def terminal_command_error(self, interaction: Interaction, error: app_commands.AppCommandError): + async def terminal_command_error( + self, interaction: Interaction, error: app_commands.AppCommandError + ): if isinstance(error, app_commands.CheckFailure): - await interaction.response.send_message("You do not have permission to use this command.", ephemeral=True) + await interaction.response.send_message( + "You do not have permission to use this command.", ephemeral=True + ) else: - await interaction.response.send_message(f"An error occurred: {error}", ephemeral=True) + await interaction.response.send_message( + f"An error occurred: {error}", ephemeral=True + ) print(f"Terminal command error: {error}") async def execute_shell_command(self, command: str, interaction: Interaction): """Executes a shell command and updates the terminal.""" if not self.terminal_active: - await interaction.response.send_message("Terminal session is not active. Use `/terminal` to start.", ephemeral=True) + await interaction.response.send_message( + "Terminal session is not active. Use `/terminal` to start.", + ephemeral=True, + ) return - # Handle 'clear' command separately if command.strip().lower() == "clear" or command.strip().lower() == "cls": self.output_history.clear() - self.output_history.append(f"{self.current_cwd}> ") # Add new prompt - self.scroll_offset = 0 # Reset scroll + self.output_history.append(f"{self.current_cwd}> ") # Add new prompt + self.scroll_offset = 0 # Reset scroll await self._update_terminal_message(interaction) return @@ -309,7 +368,7 @@ class TerminalCog(commands.Cog, name="Terminal"): if command.strip().lower().startswith("cd "): try: target_dir_str = command.strip()[3:].strip() - if not target_dir_str: # "cd" or "cd " + if not target_dir_str: # "cd" or "cd " # Go to home directory (platform dependent) new_cwd = os.path.expanduser("~") else: @@ -318,7 +377,9 @@ class TerminalCog(commands.Cog, name="Terminal"): if os.path.isabs(target_dir_str): new_cwd = target_dir_str else: - new_cwd = os.path.abspath(os.path.join(self.current_cwd, target_dir_str)) + new_cwd = os.path.abspath( + os.path.join(self.current_cwd, target_dir_str) + ) if os.path.isdir(new_cwd): self.current_cwd = new_cwd @@ -328,29 +389,35 @@ class TerminalCog(commands.Cog, name="Terminal"): except Exception as e: self.output_history.append(f"Error changing directory: {e}") - self.output_history.append(f"{self.current_cwd}> ") # New prompt - self.scroll_offset = max(0, len(self.output_history) - MAX_OUTPUT_LINES_PER_IMAGE) + self.output_history.append(f"{self.current_cwd}> ") # New prompt + self.scroll_offset = max( + 0, len(self.output_history) - MAX_OUTPUT_LINES_PER_IMAGE + ) await self._update_terminal_message(interaction) return # Handle 'exit' or 'quit' if command.strip().lower() in ["exit", "quit"]: self.output_history.append("Exiting terminal session...") - await self._update_terminal_message(interaction) # Show exit message + await self._update_terminal_message(interaction) # Show exit message await self.stop_terminal_session(interaction) return - self.last_command = command # Store command for display after execution + self.last_command = command # Store command for display after execution # For other commands, use subprocess if self.active_process and self.active_process.poll() is None: - self.output_history.append("A command is already running. Please wait or refresh.") + self.output_history.append( + "A command is already running. Please wait or refresh." + ) await self._update_terminal_message(interaction) return # For other commands, use subprocess if self.active_process and self.active_process.poll() is None: - self.output_history.append("A command is already running. Please wait or refresh.") + self.output_history.append( + "A command is already running. Please wait or refresh." + ) await self._update_terminal_message(interaction) return @@ -360,7 +427,9 @@ class TerminalCog(commands.Cog, name="Terminal"): if not command_parts: self.output_history.append("No command provided.") self.output_history.append(f"{self.current_cwd}> ") - self.scroll_offset = max(0, len(self.output_history) - MAX_OUTPUT_LINES_PER_IMAGE) + self.scroll_offset = max( + 0, len(self.output_history) - MAX_OUTPUT_LINES_PER_IMAGE + ) await self._update_terminal_message(interaction) return @@ -369,7 +438,7 @@ class TerminalCog(commands.Cog, name="Terminal"): stdin=subprocess.PIPE, # Enable interactive input stdout=subprocess.PIPE, stderr=subprocess.PIPE, - text=True, # Use text mode for easier handling of output + text=True, # Use text mode for easier handling of output cwd=self.current_cwd, # bufsize=1, # Removed line-buffering for better interactive handling # universal_newlines=True # text=True handles this @@ -378,22 +447,29 @@ class TerminalCog(commands.Cog, name="Terminal"): self.auto_update_task.start() # Initial update to show command is running - self.output_history.append(f"{self.current_cwd}> {command}") # Add command to history immediately - self.scroll_offset = max(0, len(self.output_history) - MAX_OUTPUT_LINES_PER_IMAGE) + self.output_history.append( + f"{self.current_cwd}> {command}" + ) # Add command to history immediately + self.scroll_offset = max( + 0, len(self.output_history) - MAX_OUTPUT_LINES_PER_IMAGE + ) await self._update_terminal_message(interaction) except FileNotFoundError: self.output_history.append(f"Error: Command not found: {command_parts[0]}") self.output_history.append(f"{self.current_cwd}> ") - self.scroll_offset = max(0, len(self.output_history) - MAX_OUTPUT_LINES_PER_IMAGE) + self.scroll_offset = max( + 0, len(self.output_history) - MAX_OUTPUT_LINES_PER_IMAGE + ) await self._update_terminal_message(interaction) except Exception as e: self.output_history.append(f"Error executing command: {e}") self.output_history.append(f"{self.current_cwd}> ") - self.scroll_offset = max(0, len(self.output_history) - MAX_OUTPUT_LINES_PER_IMAGE) + self.scroll_offset = max( + 0, len(self.output_history) - MAX_OUTPUT_LINES_PER_IMAGE + ) await self._update_terminal_message(interaction) - async def refresh_terminal_output(self, interaction: Interaction | None = None): """Called by task loop or refresh button to update output from active process.""" if not self.terminal_active: @@ -408,16 +484,24 @@ class TerminalCog(commands.Cog, name="Terminal"): # Read any final output final_stdout, final_stderr = self.active_process.communicate() - if final_stdout: self.output_history.extend(final_stdout.strip().splitlines()) - if final_stderr: self.output_history.extend([f"STDERR: {l}" for l in final_stderr.strip().splitlines()]) + if final_stdout: + self.output_history.extend(final_stdout.strip().splitlines()) + if final_stderr: + self.output_history.extend( + [f"STDERR: {l}" for l in final_stderr.strip().splitlines()] + ) - self.output_history.append(f"Process finished with exit code {return_code}.") - self.output_history.append(f"{self.current_cwd}> ") # New prompt + self.output_history.append( + f"Process finished with exit code {return_code}." + ) + self.output_history.append(f"{self.current_cwd}> ") # New prompt self.active_process = None - if self.auto_update_task.is_running(): # Stop loop if it was running for this process + if ( + self.auto_update_task.is_running() + ): # Stop loop if it was running for this process self.auto_update_task.stop() updated = True - else: # Process is still running, check for new output without blocking + else: # Process is still running, check for new output without blocking try: # Read available output without blocking stdout_output = self.active_process.stdout.read() @@ -427,20 +511,25 @@ class TerminalCog(commands.Cog, name="Terminal"): self.output_history.extend(stdout_output.strip().splitlines()) updated = True if stderr_output: - self.output_history.extend([f"STDERR: {l}" for l in stderr_output.strip().splitlines()]) + self.output_history.extend( + [f"STDERR: {l}" for l in stderr_output.strip().splitlines()] + ) updated = True except io.UnsupportedOperation: # This might happen if the stream is not seekable or non-blocking read is not supported # In this case, we might just have to wait for the process to finish - pass # No update from this read attempt + pass # No update from this read attempt except Exception as e: self.output_history.append(f"Error reading process output: {e}") updated = True - - if updated or interaction: # if interaction, means it's a manual refresh, so always update - self.scroll_offset = max(0, len(self.output_history) - MAX_OUTPUT_LINES_PER_IMAGE) + if ( + updated or interaction + ): # if interaction, means it's a manual refresh, so always update + self.scroll_offset = max( + 0, len(self.output_history) - MAX_OUTPUT_LINES_PER_IMAGE + ) await self._update_terminal_message(interaction) # If an interactive prompt was detected and the process is still running, @@ -453,12 +542,12 @@ class TerminalInputModal(ui.Modal, title="Send Command to Terminal"): command_input = ui.TextInput( label="Command", placeholder="Enter command (e.g., ls -l, python script.py)", - style=discord.TextStyle.long, # For multi-line, though usually single. - max_length=400 + style=discord.TextStyle.long, # For multi-line, though usually single. + max_length=400, ) def __init__(self, cog: TerminalCog): - super().__init__(timeout=300) # 5 minutes timeout for modal + super().__init__(timeout=300) # 5 minutes timeout for modal self.cog = cog async def on_submit(self, interaction: Interaction): @@ -473,15 +562,21 @@ class TerminalInputModal(ui.Modal, title="Send Command to Terminal"): if self.cog.active_process and self.cog.active_process.poll() is None: # There is an active process, assume the input is for it try: - self.cog.active_process.stdin.write(user_input + '\n') + self.cog.active_process.stdin.write(user_input + "\n") self.cog.active_process.stdin.flush() # Add the input to history for display self.cog.output_history.append(user_input) - self.cog.scroll_offset = max(0, len(self.cog.output_history) - MAX_OUTPUT_LINES_PER_IMAGE) - await self.cog._update_terminal_message(interaction) # Update message with the input + self.cog.scroll_offset = max( + 0, len(self.cog.output_history) - MAX_OUTPUT_LINES_PER_IMAGE + ) + await self.cog._update_terminal_message( + interaction + ) # Update message with the input except Exception as e: self.cog.output_history.append(f"Error sending input to process: {e}") - self.cog.scroll_offset = max(0, len(self.cog.output_history) - MAX_OUTPUT_LINES_PER_IMAGE) + self.cog.scroll_offset = max( + 0, len(self.cog.output_history) - MAX_OUTPUT_LINES_PER_IMAGE + ) await self.cog._update_terminal_message(interaction) else: # No active process, execute as a new command @@ -494,51 +589,64 @@ class TerminalInputModal(ui.Modal, title="Send Command to Terminal"): class TerminalView(ui.View): def __init__(self, cog: TerminalCog, owner_id: int): - super().__init__(timeout=None) # Persistent view + super().__init__(timeout=None) # Persistent view self.cog = cog self.owner_id = owner_id - self.message: discord.Message | None = None # To store the message this view is attached to + self.message: discord.Message | None = ( + None # To store the message this view is attached to + ) # Add buttons after initialization self._add_buttons() def _add_buttons(self): - self.clear_items() # Clear existing items if any (e.g., on re-creation) + self.clear_items() # Clear existing items if any (e.g., on re-creation) # Scroll Up - self.scroll_up_button = ui.Button(label="Scroll Up", emoji="⬆️", style=discord.ButtonStyle.secondary, row=0) + self.scroll_up_button = ui.Button( + label="Scroll Up", emoji="⬆️", style=discord.ButtonStyle.secondary, row=0 + ) self.scroll_up_button.callback = self.scroll_up_callback self.add_item(self.scroll_up_button) # Scroll Down - self.scroll_down_button = ui.Button(label="Scroll Down", emoji="⬇️", style=discord.ButtonStyle.secondary, row=0) + self.scroll_down_button = ui.Button( + label="Scroll Down", emoji="⬇️", style=discord.ButtonStyle.secondary, row=0 + ) self.scroll_down_button.callback = self.scroll_down_callback self.add_item(self.scroll_down_button) # Send Input - self.send_input_button = ui.Button(label="Send Input", emoji="⌨️", style=discord.ButtonStyle.primary, row=1) + self.send_input_button = ui.Button( + label="Send Input", emoji="⌨️", style=discord.ButtonStyle.primary, row=1 + ) self.send_input_button.callback = self.send_input_callback self.add_item(self.send_input_button) # Refresh - self.refresh_button = ui.Button(label="Refresh", emoji="🔄", style=discord.ButtonStyle.success, row=1) + self.refresh_button = ui.Button( + label="Refresh", emoji="🔄", style=discord.ButtonStyle.success, row=1 + ) self.refresh_button.callback = self.refresh_callback self.add_item(self.refresh_button) # Close/Exit Button - self.close_button = ui.Button(label="Close Terminal", emoji="❌", style=discord.ButtonStyle.danger, row=1) + self.close_button = ui.Button( + label="Close Terminal", emoji="❌", style=discord.ButtonStyle.danger, row=1 + ) self.close_button.callback = self.close_callback self.add_item(self.close_button) self.update_button_states(self.cog) - async def interaction_check(self, interaction: Interaction) -> bool: """Ensure only the bot owner can interact.""" # Use the cog's owner_id which should be set correctly is_allowed = interaction.user.id == OWNER_ID if not is_allowed: - await interaction.response.send_message("You are not authorized to use these buttons.", ephemeral=True) + await interaction.response.send_message( + "You are not authorized to use these buttons.", ephemeral=True + ) return is_allowed def update_button_states(self, cog_state: TerminalCog): @@ -548,22 +656,31 @@ class TerminalView(ui.View): # Scroll Down max_scroll = len(cog_state.output_history) - MAX_OUTPUT_LINES_PER_IMAGE - self.scroll_down_button.disabled = cog_state.scroll_offset >= max_scroll or len(cog_state.output_history) <= MAX_OUTPUT_LINES_PER_IMAGE + self.scroll_down_button.disabled = ( + cog_state.scroll_offset >= max_scroll + or len(cog_state.output_history) <= MAX_OUTPUT_LINES_PER_IMAGE + ) # Send Input & Refresh should generally be enabled if terminal is active - self.send_input_button.disabled = not cog_state.terminal_active or (cog_state.active_process is not None and cog_state.active_process.poll() is None) # Disable if command running + self.send_input_button.disabled = not cog_state.terminal_active or ( + cog_state.active_process is not None + and cog_state.active_process.poll() is None + ) # Disable if command running self.refresh_button.disabled = not cog_state.terminal_active self.close_button.disabled = not cog_state.terminal_active - async def scroll_up_callback(self, interaction: Interaction): - self.cog.scroll_offset = max(0, self.cog.scroll_offset - (MAX_OUTPUT_LINES_PER_IMAGE // 2)) # Scroll half page + self.cog.scroll_offset = max( + 0, self.cog.scroll_offset - (MAX_OUTPUT_LINES_PER_IMAGE // 2) + ) # Scroll half page await self.cog._update_terminal_message(interaction) async def scroll_down_callback(self, interaction: Interaction): max_scroll = len(self.cog.output_history) - MAX_OUTPUT_LINES_PER_IMAGE - self.cog.scroll_offset = min(max_scroll, self.cog.scroll_offset + (MAX_OUTPUT_LINES_PER_IMAGE // 2)) - self.cog.scroll_offset = max(0, self.cog.scroll_offset) # Ensure not negative + self.cog.scroll_offset = min( + max_scroll, self.cog.scroll_offset + (MAX_OUTPUT_LINES_PER_IMAGE // 2) + ) + self.cog.scroll_offset = max(0, self.cog.scroll_offset) # Ensure not negative await self.cog._update_terminal_message(interaction) async def send_input_callback(self, interaction: Interaction): @@ -574,14 +691,17 @@ class TerminalView(ui.View): async def refresh_callback(self, interaction: Interaction): # Defer because refresh_terminal_output might take a moment and edit await interaction.response.defer() - await self.cog.refresh_terminal_output(interaction) # Pass interaction to update message + await self.cog.refresh_terminal_output( + interaction + ) # Pass interaction to update message async def close_callback(self, interaction: Interaction): - await interaction.response.defer() # Defer before stopping session + await interaction.response.defer() # Defer before stopping session await self.cog.stop_terminal_session(interaction) # The stop_terminal_session should edit the message to indicate closure. # No further update needed here for the view itself as it will be removed. + async def setup(bot: commands.Bot): terminal_cog = TerminalCog(bot) await bot.add_cog(terminal_cog) diff --git a/cogs/teto_cog.py b/cogs/teto_cog.py index bccb1a6..46bc9c6 100644 --- a/cogs/teto_cog.py +++ b/cogs/teto_cog.py @@ -8,7 +8,14 @@ import asyncio import subprocess import json import datetime -from typing import Dict, Any, List, Optional, Union, Tuple # Added Tuple for type hinting +from typing import ( + Dict, + Any, + List, + Optional, + Union, + Tuple, +) # Added Tuple for type hinting from tavily import TavilyClient import os import aiohttp @@ -24,18 +31,31 @@ from gurt.config import PROJECT_ID, LOCATION # Define standard safety settings using google.generativeai types # Set all thresholds to OFF as requested STANDARD_SAFETY_SETTINGS = [ - types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold="BLOCK_NONE"), - types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold="BLOCK_NONE"), - types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold="BLOCK_NONE"), - types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold="BLOCK_NONE"), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold="BLOCK_NONE" + ), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold="BLOCK_NONE", + ), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold="BLOCK_NONE", + ), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold="BLOCK_NONE" + ), ] + def strip_think_blocks(text): # Removes all ... blocks, including multiline return re.sub(r".*?", "", text, flags=re.DOTALL) + def encode_image_to_base64(image_data): - return base64.b64encode(image_data).decode('utf-8') + return base64.b64encode(image_data).decode("utf-8") + def extract_shell_command(text): """ @@ -65,6 +85,7 @@ def extract_shell_command(text): return None, text, None + def extract_web_search_query(text): """ Extracts web search queries from text using the custom format: @@ -93,11 +114,15 @@ def extract_web_search_query(text): return None, text, None + # In-memory conversation history for Kasane Teto AI (keyed by channel id) _teto_conversations = {} + # --- Helper Function to Safely Extract Text --- -def _get_response_text(response: Optional[types.GenerateContentResponse]) -> Optional[str]: +def _get_response_text( + response: Optional[types.GenerateContentResponse], +) -> Optional[str]: """ Safely extracts the text content from the first text part of a GenerateContentResponse. Handles potential errors and lack of text parts gracefully. @@ -106,35 +131,47 @@ def _get_response_text(response: Optional[types.GenerateContentResponse]) -> Opt print("[_get_response_text] Received None response object.") return None - if hasattr(response, 'text') and response.text: + if hasattr(response, "text") and response.text: print("[_get_response_text] Found text directly in response.text attribute.") return response.text if not response.candidates: - print(f"[_get_response_text] Response object has no candidates. Response: {response}") + print( + f"[_get_response_text] Response object has no candidates. Response: {response}" + ) return None try: candidate = response.candidates[0] - if not hasattr(candidate, 'content') or not candidate.content: - print(f"[_get_response_text] Candidate 0 has no 'content'. Candidate: {candidate}") + if not hasattr(candidate, "content") or not candidate.content: + print( + f"[_get_response_text] Candidate 0 has no 'content'. Candidate: {candidate}" + ) return None - if not hasattr(candidate.content, 'parts') or not candidate.content.parts: - print(f"[_get_response_text] Candidate 0 content has no 'parts' or parts list is empty. types.Content: {candidate.content}") + if not hasattr(candidate.content, "parts") or not candidate.content.parts: + print( + f"[_get_response_text] Candidate 0 content has no 'parts' or parts list is empty. types.Content: {candidate.content}" + ) return None for i, part in enumerate(candidate.content.parts): - if hasattr(part, 'text') and part.text is not None: - if isinstance(part.text, str) and part.text.strip(): - print(f"[_get_response_text] Found non-empty text in part {i}.") - return part.text - else: - print(f"[_get_response_text] types.Part {i} has 'text' attribute, but it's empty or not a string: {part.text!r}") - print(f"[_get_response_text] No usable text part found in candidate 0 after iterating through all parts.") + if hasattr(part, "text") and part.text is not None: + if isinstance(part.text, str) and part.text.strip(): + print(f"[_get_response_text] Found non-empty text in part {i}.") + return part.text + else: + print( + f"[_get_response_text] types.Part {i} has 'text' attribute, but it's empty or not a string: {part.text!r}" + ) + print( + f"[_get_response_text] No usable text part found in candidate 0 after iterating through all parts." + ) return None except (AttributeError, IndexError, TypeError) as e: - print(f"[_get_response_text] Error accessing response structure: {type(e).__name__}: {e}") + print( + f"[_get_response_text] Error accessing response structure: {type(e).__name__}: {e}" + ) print(f"Problematic response object: {response}") return None except Exception as e: @@ -142,6 +179,7 @@ def _get_response_text(response: Optional[types.GenerateContentResponse]) -> Opt print(f"Response object during error: {response}") return None + class TetoCog(commands.Cog): # Helper function to normalize finish_reason def _as_str(self, fr): @@ -152,13 +190,12 @@ class TetoCog(commands.Cog): # Define command groups at class level ame_group = app_commands.Group( - name="ame", - description="Main command group for Ame-chan AI." + name="ame", description="Main command group for Ame-chan AI." ) model_subgroup = app_commands.Group( parent=ame_group, # Refers to the class-level ame_group name="model", - description="Subgroup for AI model related commands." + description="Subgroup for AI model related commands.", ) def __init__(self, bot: commands.Bot): @@ -171,23 +208,31 @@ class TetoCog(commands.Cog): project=PROJECT_ID, location=LOCATION, ) - print(f"Google GenAI Client initialized for Vertex AI project '{PROJECT_ID}' in location '{LOCATION}'.") + print( + f"Google GenAI Client initialized for Vertex AI project '{PROJECT_ID}' in location '{LOCATION}'." + ) else: self.genai_client = None - print("PROJECT_ID or LOCATION not found in config. Google GenAI Client not initialized.") + print( + "PROJECT_ID or LOCATION not found in config. Google GenAI Client not initialized." + ) except Exception as e: self.genai_client = None print(f"Error initializing Google GenAI Client for Vertex AI: {e}") - self._ai_model = "gemini-2.5-flash-preview-05-20" # Default model for Vertex AI - self._allow_shell_commands = False # Flag to control shell command tool usage + self._ai_model = "gemini-2.5-flash-preview-05-20" # Default model for Vertex AI + self._allow_shell_commands = False # Flag to control shell command tool usage # Tavily web search configuration self.tavily_api_key = os.getenv("TAVILY_API_KEY", "") - self.tavily_client = TavilyClient(api_key=self.tavily_api_key) if self.tavily_api_key else None + self.tavily_client = ( + TavilyClient(api_key=self.tavily_api_key) if self.tavily_api_key else None + ) self.tavily_search_depth = os.getenv("TAVILY_DEFAULT_SEARCH_DEPTH", "basic") self.tavily_max_results = int(os.getenv("TAVILY_DEFAULT_MAX_RESULTS", "5")) - self._allow_web_search = bool(self.tavily_api_key) # Enable web search if API key is available + self._allow_web_search = bool( + self.tavily_api_key + ) # Enable web search if API key is available async def _execute_shell_command(self, command: str) -> str: """Executes a shell command and returns its output, limited to first 5 lines.""" @@ -195,9 +240,7 @@ class TetoCog(commands.Cog): # Use subprocess.run for simple command execution # Consider security implications of running arbitrary commands process = await asyncio.create_subprocess_shell( - command, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) stdout, stderr = await process.communicate() @@ -236,24 +279,26 @@ class TetoCog(commands.Cog): r"^(nmap|nc|telnet)\s+", # Networking tools r"^(shutdown|reboot)\s*", # System shutdown/restart r"^(regedit|sysctl)\s+", # System configuration - r"format\s+\w:", # Formatting drives - r"dd\s+", # Disk dumping - r"mkfs\s+", # Creating file systems - r"fdisk\s+", # Partitioning disks - r"parted\s+", # Partitioning disks - r"wipefs\s+", # Wiping file system signatures - r"shred\s+", # Securely deleting files - r"nohup\s+", # Running commands immune to hangups - r"&", # Command chaining - r"\|", # Command piping (escaped pipe character) - r">", # Output redirection - r"<", # Input redirection - r";", # Command separation + r"format\s+\w:", # Formatting drives + r"dd\s+", # Disk dumping + r"mkfs\s+", # Creating file systems + r"fdisk\s+", # Partitioning disks + r"parted\s+", # Partitioning disks + r"wipefs\s+", # Wiping file system signatures + r"shred\s+", # Securely deleting files + r"nohup\s+", # Running commands immune to hangups + r"&", # Command chaining + r"\|", # Command piping (escaped pipe character) + r">", # Output redirection + r"<", # Input redirection + r";", # Command separation ] command_lower = command.lower() for pattern in dangerous_patterns: if re.search(pattern, command_lower): - print(f"[TETO DEBUG] Blocked command '{command}' due to matching pattern: '{pattern}'") + print( + f"[TETO DEBUG] Blocked command '{command}' due to matching pattern: '{pattern}'" + ) return True return False @@ -294,9 +339,7 @@ class TetoCog(commands.Cog): "If a user asks you a question that requires current information or facts, call the 'web_search' tool with the search query. \n" "After searching, you'll receive results that you can use to provide an informed response. \n" ) - system_prompt_text += ( - "Also please note that these tools arent for running random garbage, they execute **REAL** terminal commands and web searches." - ) + system_prompt_text += "Also please note that these tools arent for running random garbage, they execute **REAL** terminal commands and web searches." # Define tools for Vertex AI shell_command_tool = types.FunctionDeclaration( @@ -304,7 +347,12 @@ class TetoCog(commands.Cog): description="Executes a shell command and returns its output. Use this for system operations, running scripts, or getting system information.", parameters={ "type": "object", - "properties": {"command": {"type": "string", "description": "The shell command to execute."}}, + "properties": { + "command": { + "type": "string", + "description": "The shell command to execute.", + } + }, "required": ["command"], }, ) @@ -313,30 +361,35 @@ class TetoCog(commands.Cog): description="Searches the web for information using a query. Use this to answer questions requiring current information or facts.", parameters={ "type": "object", - "properties": {"query": {"type": "string", "description": "The search query."}}, + "properties": { + "query": {"type": "string", "description": "The search query."} + }, "required": ["query"], }, ) - + available_tools = [] if self._allow_shell_commands: available_tools.append(shell_command_tool) if self._allow_web_search and self.tavily_client: available_tools.append(web_search_tool_decl) - - vertex_tools = [types.Tool(function_declarations=available_tools)] if available_tools else None + vertex_tools = ( + [types.Tool(function_declarations=available_tools)] + if available_tools + else None + ) # Convert input messages to Vertex AI `types.Content` vertex_contents: List[types.Content] = [] for msg in messages: role = "user" if msg.get("role") == "user" else "model" parts: List[types.Part] = [] - + content_data = msg.get("content") if isinstance(content_data, str): parts.append(types.Part(text=content_data)) - elif isinstance(content_data, list): # Multimodal content + elif isinstance(content_data, list): # Multimodal content for item in content_data: item_type = item.get("type") if item_type == "text": @@ -348,11 +401,23 @@ class TetoCog(commands.Cog): header, encoded = image_url_data.split(",", 1) mime_type = header.split(":")[1].split(";")[0] image_bytes = base64.b64decode(encoded) - parts.append(types.Part(inline_data=types.Blob(data=image_bytes, mime_type=mime_type))) + parts.append( + types.Part( + inline_data=types.Blob( + data=image_bytes, mime_type=mime_type + ) + ) + ) except Exception as e: - print(f"[TETO DEBUG] Error processing base64 image for Vertex: {e}") - parts.append(types.Part(text="[System Note: Error processing an attached image]")) - else: # If it's a direct URL (e.g. for stickers, emojis) + print( + f"[TETO DEBUG] Error processing base64 image for Vertex: {e}" + ) + parts.append( + types.Part( + text="[System Note: Error processing an attached image]" + ) + ) + else: # If it's a direct URL (e.g. for stickers, emojis) # Vertex AI prefers direct data or GCS URIs. For simplicity, we'll try to download and send data. # This might be slow or fail for large images. try: @@ -360,21 +425,54 @@ class TetoCog(commands.Cog): async with session.get(image_url_data) as resp: if resp.status == 200: image_bytes = await resp.read() - mime_type = resp.content_type or "application/octet-stream" + mime_type = ( + resp.content_type + or "application/octet-stream" + ) # Validate MIME type for Vertex - supported_image_mimes = ["image/png", "image/jpeg", "image/webp", "image/heic", "image/heif", "image/gif"] - clean_mime_type = mime_type.split(';')[0].lower() + supported_image_mimes = [ + "image/png", + "image/jpeg", + "image/webp", + "image/heic", + "image/heif", + "image/gif", + ] + clean_mime_type = mime_type.split(";")[ + 0 + ].lower() if clean_mime_type in supported_image_mimes: - parts.append(types.Part(inline_data=types.Blob(data=image_bytes, mime_type=clean_mime_type))) + parts.append( + types.Part( + inline_data=types.Blob( + data=image_bytes, + mime_type=clean_mime_type, + ) + ) + ) else: - parts.append(types.Part(text=f"[System Note: Image type {clean_mime_type} from URL not directly supported, original URL: {image_url_data}]")) + parts.append( + types.Part( + text=f"[System Note: Image type {clean_mime_type} from URL not directly supported, original URL: {image_url_data}]" + ) + ) else: - parts.append(types.Part(text=f"[System Note: Failed to download image from URL: {image_url_data}]")) + parts.append( + types.Part( + text=f"[System Note: Failed to download image from URL: {image_url_data}]" + ) + ) except Exception as e: - print(f"[TETO DEBUG] Error downloading image from URL {image_url_data} for Vertex: {e}") - parts.append(types.Part(text=f"[System Note: Error processing image from URL: {image_url_data}]")) - - if parts: # Only add if there are valid parts + print( + f"[TETO DEBUG] Error downloading image from URL {image_url_data} for Vertex: {e}" + ) + parts.append( + types.Part( + text=f"[System Note: Error processing image from URL: {image_url_data}]" + ) + ) + + if parts: # Only add if there are valid parts vertex_contents.append(types.Content(role=role, parts=parts)) max_tool_calls = 5 @@ -382,8 +480,8 @@ class TetoCog(commands.Cog): while tool_calls_made < max_tool_calls: generation_config = types.GenerateContentConfig( - temperature=1.0, # Example, adjust as needed - max_output_tokens=2000, # Example + temperature=1.0, # Example, adjust as needed + max_output_tokens=2000, # Example safety_settings=STANDARD_SAFETY_SETTINGS, # system_instruction is not a direct param for generate_content, handled by model or prepended ) @@ -404,23 +502,35 @@ class TetoCog(commands.Cog): temperature=generation_config.temperature, max_output_tokens=generation_config.max_output_tokens, safety_settings=generation_config.safety_settings, - system_instruction=types.Content(role="system", parts=[types.Part(text=system_prompt_text)]), - tools=vertex_tools, # Add tools here - tool_config=types.ToolConfig( # Add tool_config here - function_calling_config=types.FunctionCallingConfig( - mode=types.FunctionCallingConfigMode.AUTO if vertex_tools else types.FunctionCallingConfigMode.NONE + system_instruction=types.Content( + role="system", parts=[types.Part(text=system_prompt_text)] + ), + tools=vertex_tools, # Add tools here + tool_config=( + types.ToolConfig( # Add tool_config here + function_calling_config=types.FunctionCallingConfig( + mode=( + types.FunctionCallingConfigMode.AUTO + if vertex_tools + else types.FunctionCallingConfigMode.NONE + ) + ) ) - ) if vertex_tools else None, + if vertex_tools + else None + ), ) final_contents_for_api = vertex_contents try: - print(f"[TETO DEBUG] Sending to Vertex AI. Model: {self._ai_model}, Tool Config: {vertex_tools is not None}") + print( + f"[TETO DEBUG] Sending to Vertex AI. Model: {self._ai_model}, Tool Config: {vertex_tools is not None}" + ) response = await self.genai_client.aio.models.generate_content( - model=f"publishers/google/models/{self._ai_model}", # Use simpler model path + model=f"publishers/google/models/{self._ai_model}", # Use simpler model path contents=final_contents_for_api, - config=generation_config_with_system, # Pass the updated config + config=generation_config_with_system, # Pass the updated config ) except google_exceptions.GoogleAPICallError as e: @@ -432,7 +542,7 @@ class TetoCog(commands.Cog): raise RuntimeError("Vertex AI response had no candidates.") candidate = response.candidates[0] - + finish_reason = getattr(candidate, "finish_reason", None) finish_reason_str = self._as_str(finish_reason) @@ -448,48 +558,73 @@ class TetoCog(commands.Cog): has_tool_call = True function_call = part.function_call tool_name = function_call.name - tool_args = dict(function_call.args) if function_call.args else {} - - print(f"[TETO DEBUG] Vertex AI requested tool: {tool_name} with args: {tool_args}") - + tool_args = ( + dict(function_call.args) if function_call.args else {} + ) + + print( + f"[TETO DEBUG] Vertex AI requested tool: {tool_name} with args: {tool_args}" + ) + # Append model's request to history vertex_contents.append(candidate.content) - + tool_result_str = "" if tool_name == "execute_shell_command": command_to_run = tool_args.get("command", "") if self._is_dangerous_command(command_to_run): tool_result_str = "❌ Error: Execution was blocked due to a potentially dangerous command." else: - tool_result_str = await self._execute_shell_command(command_to_run) - + tool_result_str = await self._execute_shell_command( + command_to_run + ) + elif tool_name == "web_search": query_to_search = tool_args.get("query", "") - search_api_results = await self.web_search(query=query_to_search) + search_api_results = await self.web_search( + query=query_to_search + ) if "error" in search_api_results: tool_result_str = f"❌ Error: Web search failed - {search_api_results['error']}" else: results_text_parts = [] - for i, res_item in enumerate(search_api_results.get("results", [])[:3], 1): # Limit to 3 results for brevity - results_text_parts.append(f"Result {i}:\nTitle: {res_item['title']}\nURL: {res_item['url']}\nContent Snippet: {res_item['content'][:200]}...\n") + for i, res_item in enumerate( + search_api_results.get("results", [])[:3], 1 + ): # Limit to 3 results for brevity + results_text_parts.append( + f"Result {i}:\nTitle: {res_item['title']}\nURL: {res_item['url']}\nContent Snippet: {res_item['content'][:200]}...\n" + ) if search_api_results.get("answer"): - results_text_parts.append(f"Summary Answer: {search_api_results['answer']}") + results_text_parts.append( + f"Summary Answer: {search_api_results['answer']}" + ) tool_result_str = "\n\n".join(results_text_parts) if not tool_result_str: - tool_result_str = "No results found or summary available." + tool_result_str = ( + "No results found or summary available." + ) else: - tool_result_str = f"Error: Unknown tool '{tool_name}' requested." + tool_result_str = ( + f"Error: Unknown tool '{tool_name}' requested." + ) # Append tool response to history - vertex_contents.append(types.Content( - role="function", # "tool" role was for older versions, "function" is current for Gemini - parts=[types.Part.from_function_response(name=tool_name, response={"result": tool_result_str})] - )) + vertex_contents.append( + types.Content( + role="function", # "tool" role was for older versions, "function" is current for Gemini + parts=[ + types.Part.from_function_response( + name=tool_name, + response={"result": tool_result_str}, + ) + ], + ) + ) tool_calls_made += 1 - break # Re-evaluate with new history + break # Re-evaluate with new history if has_tool_call: - continue # Continue the while loop for next API call - + continue # Continue the while loop for next API call + # If no function call or loop finished final_ai_text_response = _get_response_text(response) if final_ai_text_response: @@ -502,35 +637,49 @@ class TetoCog(commands.Cog): # If response has no text part (e.g. only safety block or empty) safety_ratings_str = "" if candidate.safety_ratings: - safety_ratings_str = ", ".join([f"{rating.category.name}: {rating.probability.name}" for rating in candidate.safety_ratings]) - + safety_ratings_str = ", ".join( + [ + f"{rating.category.name}: {rating.probability.name}" + for rating in candidate.safety_ratings + ] + ) + error_detail = f"Vertex AI response had no text. Finish Reason: {finish_reason_str}." if safety_ratings_str: error_detail += f" Safety Ratings: [{safety_ratings_str}]." - + # If blocked by safety, we should inform the user or log appropriately. # For now, returning a generic message. if finish_reason_str == "SAFETY": - return f"(Teto AI response was blocked due to safety settings: {safety_ratings_str})" - - print(f"[TETO DEBUG] {error_detail}") # Log it - return "(Teto AI had a problem generating a response or the response was empty.)" + return f"(Teto AI response was blocked due to safety settings: {safety_ratings_str})" + print(f"[TETO DEBUG] {error_detail}") # Log it + return "(Teto AI had a problem generating a response or the response was empty.)" # If loop finishes due to max_tool_calls if tool_calls_made >= max_tool_calls: return "(Teto AI reached maximum tool interaction limit. Please try rephrasing.)" - - return "(Teto AI encountered an unexpected state.)" # Fallback + + return "(Teto AI encountered an unexpected state.)" # Fallback async def _teto_reply_ai(self, text: str) -> str: """Replies to the text as Kasane Teto using AI via Vertex AI.""" - return await self._teto_reply_ai_with_messages([{"role": "user", "content": text}]) + return await self._teto_reply_ai_with_messages( + [{"role": "user", "content": text}] + ) - async def web_search(self, query: str, search_depth: Optional[str] = None, max_results: Optional[int] = None) -> Dict[str, Any]: + async def web_search( + self, + query: str, + search_depth: Optional[str] = None, + max_results: Optional[int] = None, + ) -> Dict[str, Any]: """Search the web using Tavily API""" if not self.tavily_client: - return {"error": "Tavily client not initialized. TAVILY_API_KEY environment variable may not be set.", "timestamp": datetime.datetime.now().isoformat()} + return { + "error": "Tavily client not initialized. TAVILY_API_KEY environment variable may not be set.", + "timestamp": datetime.datetime.now().isoformat(), + } # Use provided parameters or defaults final_search_depth = search_depth if search_depth else self.tavily_search_depth @@ -538,7 +687,9 @@ class TetoCog(commands.Cog): # Validate search_depth if final_search_depth.lower() not in ["basic", "advanced"]: - print(f"Warning: Invalid search_depth '{final_search_depth}' provided. Using 'basic'.") + print( + f"Warning: Invalid search_depth '{final_search_depth}' provided. Using 'basic'." + ) final_search_depth = "basic" # Validate max_results (between 5 and 20) @@ -552,18 +703,20 @@ class TetoCog(commands.Cog): search_depth=final_search_depth, max_results=final_max_results, include_answer=True, - include_images=False + include_images=False, ) # Format results for easier consumption results = [] for r in response.get("results", []): - results.append({ - "title": r.get("title", "No title"), - "url": r.get("url", ""), - "content": r.get("content", "No content available"), - "score": r.get("score", 0) - }) + results.append( + { + "title": r.get("title", "No title"), + "url": r.get("url", ""), + "content": r.get("content", "No content available"), + "score": r.get("score", 0), + } + ) return { "query": query, @@ -572,18 +725,24 @@ class TetoCog(commands.Cog): "results": results, "answer": response.get("answer", ""), "count": len(results), - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: error_message = f"Error during Tavily search for '{query}': {str(e)}" print(error_message) - return {"error": error_message, "timestamp": datetime.datetime.now().isoformat()} + return { + "error": error_message, + "timestamp": datetime.datetime.now().isoformat(), + } @commands.Cog.listener() async def on_message(self, message: discord.Message): import logging + log = logging.getLogger("teto_cog") - log.info(f"[TETO DEBUG] Received message: {message.content!r} (author={message.author}, id={message.id})") + log.info( + f"[TETO DEBUG] Received message: {message.content!r} (author={message.author}, id={message.id})" + ) if message.author.bot: log.info("[TETO DEBUG] Ignoring bot message.") @@ -594,7 +753,9 @@ class TetoCog(commands.Cog): for mention in message.mentions: mention_str = f"<@{mention.id}>" mention_nick_str = f"<@!{mention.id}>" - content_wo_mentions = content_wo_mentions.replace(mention_str, "").replace(mention_nick_str, "") + content_wo_mentions = content_wo_mentions.replace(mention_str, "").replace( + mention_nick_str, "" + ) content_wo_mentions = content_wo_mentions.strip() trigger = False @@ -613,17 +774,21 @@ class TetoCog(commands.Cog): else: prefixes = ("!",) - if ( - self.bot.user in message.mentions - and not content_wo_mentions.startswith(prefixes) + if self.bot.user in message.mentions and not content_wo_mentions.startswith( + prefixes ): trigger = True - log.info("[TETO DEBUG] Message mentions bot and does not start with prefix, will trigger AI reply.") + log.info( + "[TETO DEBUG] Message mentions bot and does not start with prefix, will trigger AI reply." + ) elif ( - message.reference and getattr(message.reference.resolved, "author", None) == self.bot.user + message.reference + and getattr(message.reference.resolved, "author", None) == self.bot.user ): trigger = True - log.info("[TETO DEBUG] Message is a reply to the bot, will trigger AI reply.") + log.info( + "[TETO DEBUG] Message is a reply to the bot, will trigger AI reply." + ) if not trigger: log.info("[TETO DEBUG] Message did not trigger AI reply logic.") @@ -637,13 +802,21 @@ class TetoCog(commands.Cog): if trigger: user_content = [] # Prepend username to the message content - username = message.author.display_name if message.author.display_name else message.author.name + username = ( + message.author.display_name + if message.author.display_name + else message.author.name + ) if message.content: - user_content.append({"type": "text", "text": f"{username}: {message.content}"}) + user_content.append( + {"type": "text", "text": f"{username}: {message.content}"} + ) # Handle attachments (images) for attachment in message.attachments: - if attachment.content_type and attachment.content_type.startswith("image/"): + if attachment.content_type and attachment.content_type.startswith( + "image/" + ): try: async with aiohttp.ClientSession() as session: async with session.get(attachment.url) as image_response: @@ -651,24 +824,55 @@ class TetoCog(commands.Cog): image_data = await image_response.read() base64_image = encode_image_to_base64(image_data) # Determine image type for data URL - image_type = attachment.content_type.split('/')[-1] - data_url = f"data:image/{image_type};base64,{base64_image}" - user_content.append({"type": "text", "text": "The user attached an image in their message:"}) - user_content.append({"type": "image_url", "image_url": {"url": data_url}}) - log.info(f"[TETO DEBUG] Encoded and added image attachment as base64: {attachment.url}") + image_type = attachment.content_type.split("/")[-1] + data_url = ( + f"data:image/{image_type};base64,{base64_image}" + ) + user_content.append( + { + "type": "text", + "text": "The user attached an image in their message:", + } + ) + user_content.append( + { + "type": "image_url", + "image_url": {"url": data_url}, + } + ) + log.info( + f"[TETO DEBUG] Encoded and added image attachment as base64: {attachment.url}" + ) else: - log.warning(f"[TETO DEBUG] Failed to download image attachment: {attachment.url} (Status: {image_response.status})") - user_content.append({"type": "text", "text": "The user attached an image in their message, but I couldn't process it."}) + log.warning( + f"[TETO DEBUG] Failed to download image attachment: {attachment.url} (Status: {image_response.status})" + ) + user_content.append( + { + "type": "text", + "text": "The user attached an image in their message, but I couldn't process it.", + } + ) except Exception as e: - log.error(f"[TETO DEBUG] Error processing image attachment {attachment.url}: {e}") - user_content.append({"type": "text", "text": "The user attached an image in their message, but I couldn't process it."}) - + log.error( + f"[TETO DEBUG] Error processing image attachment {attachment.url}: {e}" + ) + user_content.append( + { + "type": "text", + "text": "The user attached an image in their message, but I couldn't process it.", + } + ) # Handle stickers for sticker in message.stickers: - # Assuming sticker has a url attribute - user_content.append({"type": "text", "text": "The user sent a sticker image:"}) - user_content.append({"type": "image_url", "image_url": {"url": sticker.url}}) + # Assuming sticker has a url attribute + user_content.append( + {"type": "text", "text": "The user sent a sticker image:"} + ) + user_content.append( + {"type": "image_url", "image_url": {"url": sticker.url}} + ) print(f"[TETO DEBUG] Found sticker: {sticker.url}") # Handle custom emojis (basic regex for <:name:id> and ) @@ -679,14 +883,22 @@ class TetoCog(commands.Cog): # Construct Discord emoji URL is_animated = match.group(0).startswith(" 1950: # Discord limit is 2000, leave some room + elif ( + len(ai_reply_text) > 1950 + ): # Discord limit is 2000, leave some room ai_reply_text = ai_reply_text[:1950] + "... (message truncated)" - log.warning("[TETO DEBUG] AI reply was truncated due to length.") + log.warning( + "[TETO DEBUG] AI reply was truncated due to length." + ) await message.reply(ai_reply_text) # Store the AI's textual response in the conversation history # The tool handling logic is now within _teto_reply_ai_with_messages convo.append({"role": "assistant", "content": ai_reply_text}) - - _teto_conversations[convo_key] = convo[-10:] # Keep last 10 interactions (user + assistant turns) + + _teto_conversations[convo_key] = convo[ + -10: + ] # Keep last 10 interactions (user + assistant turns) log.info("[TETO DEBUG] AI reply sent successfully using Vertex AI.") except Exception as e: await channel.send(f"**Teto AI conversation failed! TwT**\n{e}") @@ -724,38 +946,65 @@ class TetoCog(commands.Cog): @app_commands.describe(model_name="The name of the AI model to use.") async def set_ai_model(self, interaction: discord.Interaction, model_name: str): self._ai_model = model_name - await interaction.response.send_message(f"Ame-chan's AI model set to: {model_name} desu~", ephemeral=True) + await interaction.response.send_message( + f"Ame-chan's AI model set to: {model_name} desu~", ephemeral=True + ) - @ame_group.command(name="clear_chat_history", description="Clears the chat history for the current channel.") + @ame_group.command( + name="clear_chat_history", + description="Clears the chat history for the current channel.", + ) async def clear_chat_history(self, interaction: discord.Interaction): channel_id = interaction.channel_id if channel_id in _teto_conversations: del _teto_conversations[channel_id] - await interaction.response.send_message("Chat history cleared for this channel desu~", ephemeral=True) + await interaction.response.send_message( + "Chat history cleared for this channel desu~", ephemeral=True + ) else: - await interaction.response.send_message("No chat history found for this channel desu~", ephemeral=True) + await interaction.response.send_message( + "No chat history found for this channel desu~", ephemeral=True + ) - @ame_group.command(name="toggle_shell_command", description="Toggles Ame-chan's ability to run shell commands.") + @ame_group.command( + name="toggle_shell_command", + description="Toggles Ame-chan's ability to run shell commands.", + ) async def toggle_shell_command(self, interaction: discord.Interaction): self._allow_shell_commands = not self._allow_shell_commands status = "enabled" if self._allow_shell_commands else "disabled" - await interaction.response.send_message(f"Ame-chan's shell command ability is now {status} desu~", ephemeral=True) + await interaction.response.send_message( + f"Ame-chan's shell command ability is now {status} desu~", ephemeral=True + ) - @ame_group.command(name="toggle_web_search", description="Toggles Ame-chan's ability to search the web.") + @ame_group.command( + name="toggle_web_search", + description="Toggles Ame-chan's ability to search the web.", + ) async def toggle_web_search(self, interaction: discord.Interaction): if not self.tavily_api_key or not self.tavily_client: - await interaction.response.send_message("Web search is not available because the Tavily API key is not configured. Please set the TAVILY_API_KEY environment variable.", ephemeral=True) + await interaction.response.send_message( + "Web search is not available because the Tavily API key is not configured. Please set the TAVILY_API_KEY environment variable.", + ephemeral=True, + ) return self._allow_web_search = not self._allow_web_search status = "enabled" if self._allow_web_search else "disabled" - await interaction.response.send_message(f"Ame-chan's web search ability is now {status} desu~", ephemeral=True) + await interaction.response.send_message( + f"Ame-chan's web search ability is now {status} desu~", ephemeral=True + ) - @ame_group.command(name="web_search", description="Search the web using Tavily API.") + @ame_group.command( + name="web_search", description="Search the web using Tavily API." + ) @app_commands.describe(query="The search query to look up online.") async def web_search_command(self, interaction: discord.Interaction, query: str): if not self.tavily_api_key or not self.tavily_client: - await interaction.response.send_message("Web search is not available because the Tavily API key is not configured. Please set the TAVILY_API_KEY environment variable.", ephemeral=True) + await interaction.response.send_message( + "Web search is not available because the Tavily API key is not configured. Please set the TAVILY_API_KEY environment variable.", + ephemeral=True, + ) return await interaction.response.defer(thinking=True) @@ -764,39 +1013,54 @@ class TetoCog(commands.Cog): search_results = await self.web_search(query=query) if "error" in search_results: - await interaction.followup.send(f"❌ Error: Web search failed - {search_results['error']}") + await interaction.followup.send( + f"❌ Error: Web search failed - {search_results['error']}" + ) return # Format the results in a readable way embed = discord.Embed( title=f"🔍 Web Search Results for: {query}", description=search_results.get("answer", "No summary available."), - color=discord.Color.blue() + color=discord.Color.blue(), ) - for i, result in enumerate(search_results.get("results", [])[:5], 1): # Limit to top 5 results + for i, result in enumerate( + search_results.get("results", [])[:5], 1 + ): # Limit to top 5 results embed.add_field( name=f"Result {i}: {result['title']}", value=f"[Link]({result['url']})\n{result['content'][:200]}...", - inline=False + inline=False, ) - embed.set_footer(text=f"Search depth: {search_results['search_depth']} | Results: {search_results['count']}") + embed.set_footer( + text=f"Search depth: {search_results['search_depth']} | Results: {search_results['count']}" + ) await interaction.followup.send(embed=embed) except Exception as e: await interaction.followup.send(f"❌ Error performing web search: {str(e)}") - @model_subgroup.command(name="get", description="Gets the current AI model for Ame-chan.") + @model_subgroup.command( + name="get", description="Gets the current AI model for Ame-chan." + ) async def get_ai_model(self, interaction: discord.Interaction): - await interaction.response.send_message(f"Ame-chan's current AI model is: {self._ai_model} desu~", ephemeral=True) + await interaction.response.send_message( + f"Ame-chan's current AI model is: {self._ai_model} desu~", ephemeral=True + ) + # Context menu command must be defined at module level @app_commands.context_menu(name="Teto AI Reply") -async def teto_context_menu_ai_reply(interaction: discord.Interaction, message: discord.Message): +async def teto_context_menu_ai_reply( + interaction: discord.Interaction, message: discord.Message +): """Replies to the selected message as a Teto AI.""" if not message.content: - await interaction.response.send_message("The selected message has no text content to reply to! >.<", ephemeral=True) + await interaction.response.send_message( + "The selected message has no text content to reply to! >.<", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) @@ -810,7 +1074,9 @@ async def teto_context_menu_ai_reply(interaction: discord.Interaction, message: # Get the TetoCog instance from the bot cog = interaction.client.get_cog("TetoCog") if cog is None: - await interaction.followup.send("TetoCog is not loaded, cannot reply.", ephemeral=True) + await interaction.followup.send( + "TetoCog is not loaded, cannot reply.", ephemeral=True + ) return ai_reply = await cog._teto_reply_ai_with_messages(messages=convo) ai_reply = strip_think_blocks(ai_reply) @@ -818,15 +1084,22 @@ async def teto_context_menu_ai_reply(interaction: discord.Interaction, message: await interaction.followup.send("Teto AI replied desu~", ephemeral=True) # Store the AI's textual response in the conversation history - convo.append({"role": "assistant", "content": ai_reply}) # ai_reply is already the text + convo.append( + {"role": "assistant", "content": ai_reply} + ) # ai_reply is already the text _teto_conversations[convo_key] = convo[-10:] except Exception as e: - await interaction.followup.send(f"Teto AI reply failed: {e} desu~", ephemeral=True) + await interaction.followup.send( + f"Teto AI reply failed: {e} desu~", ephemeral=True + ) + async def setup(bot: commands.Bot): cog = TetoCog(bot) await bot.add_cog(cog) # bot.tree.add_command(cog.ame_group) # No longer needed if groups are class variables; discovery should handle it. # Ensure the context menu is still added if it's not part of the cog's auto-discovery - bot.tree.add_command(teto_context_menu_ai_reply) # This is a module-level command, so it needs to be added. + bot.tree.add_command( + teto_context_menu_ai_reply + ) # This is a module-level command, so it needs to be added. print("TetoCog loaded! desu~") diff --git a/cogs/teto_image_cog.py b/cogs/teto_image_cog.py index abe9b9d..dab67b8 100644 --- a/cogs/teto_image_cog.py +++ b/cogs/teto_image_cog.py @@ -3,15 +3,19 @@ from discord.ext import commands import aiohttp from typing import Union + class TetoImageView(discord.ui.View): def __init__(self, cog): super().__init__(timeout=180.0) self.cog = cog - @discord.ui.button(label='Show Another Image', style=discord.ButtonStyle.primary) - async def show_another_image(self, interaction: discord.Interaction, _: discord.ui.Button): + @discord.ui.button(label="Show Another Image", style=discord.ButtonStyle.primary) + async def show_another_image( + self, interaction: discord.Interaction, _: discord.ui.Button + ): await self.cog.get_teto_image(interaction) + class TetoImageCog(commands.Cog): def __init__(self, bot): self.bot = bot @@ -21,16 +25,24 @@ class TetoImageCog(commands.Cog): async def fetch_teto_image(self): """Fetches a random Teto image and returns the URL.""" async with aiohttp.ClientSession() as session: - headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'} - async with session.get(self.teto_url, headers=headers, allow_redirects=False) as response: + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" + } + async with session.get( + self.teto_url, headers=headers, allow_redirects=False + ) as response: if response.status == 302: - image_path = response.headers.get('Location') + image_path = response.headers.get("Location") if image_path: return f"https://slipstreamm.dev{image_path}" return None - @commands.hybrid_command(name='tetoimage', description='Get a random image of Kasane Teto.') - async def get_teto_image(self, ctx_or_interaction: Union[commands.Context, discord.Interaction]): + @commands.hybrid_command( + name="tetoimage", description="Get a random image of Kasane Teto." + ) + async def get_teto_image( + self, ctx_or_interaction: Union[commands.Context, discord.Interaction] + ): """Gets a random image of Kasane Teto.""" # Determine if this is a Context or Interaction is_interaction = isinstance(ctx_or_interaction, discord.Interaction) @@ -49,7 +61,7 @@ class TetoImageCog(commands.Cog): embed = discord.Embed( title="Random Teto Image", description=f"Website: {self.footer_url}", - color=discord.Color.red() + color=discord.Color.red(), ) if image_url.startswith("http") or image_url.startswith("https"): @@ -59,9 +71,13 @@ class TetoImageCog(commands.Cog): # Send response differently based on object type if is_interaction: if ctx_or_interaction.response.is_done(): - await ctx_or_interaction.followup.send(embed=embed, view=view) + await ctx_or_interaction.followup.send( + embed=embed, view=view + ) else: - await ctx_or_interaction.response.send_message(embed=embed, view=view) + await ctx_or_interaction.response.send_message( + embed=embed, view=view + ) else: await ctx_or_interaction.send(embed=embed, view=view) else: @@ -88,5 +104,6 @@ class TetoImageCog(commands.Cog): else: await ctx_or_interaction.send(error_msg) + async def setup(bot): await bot.add_cog(TetoImageCog(bot)) diff --git a/cogs/timer_cog.py b/cogs/timer_cog.py index dc19d69..15d0f96 100644 --- a/cogs/timer_cog.py +++ b/cogs/timer_cog.py @@ -6,7 +6,8 @@ import asyncio from datetime import datetime, timedelta import os -TIMER_FILE = 'data/timers.json' +TIMER_FILE = "data/timers.json" + class TimerCog(commands.Cog): def __init__(self, bot): @@ -16,15 +17,17 @@ class TimerCog(commands.Cog): self.timer_check_task.start() def load_timers(self): - if not os.path.exists('data'): - os.makedirs('data') + if not os.path.exists("data"): + os.makedirs("data") if os.path.exists(TIMER_FILE): - with open(TIMER_FILE, 'r') as f: + with open(TIMER_FILE, "r") as f: try: data = json.load(f) # Convert string timestamps back to datetime objects for timer_data in data: - timer_data['expires_at'] = datetime.fromisoformat(timer_data['expires_at']) + timer_data["expires_at"] = datetime.fromisoformat( + timer_data["expires_at"] + ) self.timers.append(timer_data) except json.JSONDecodeError: self.timers = [] @@ -37,30 +40,32 @@ class TimerCog(commands.Cog): serializable_timers = [] for timer in self.timers: timer_copy = timer.copy() - timer_copy['expires_at'] = timer_copy['expires_at'].isoformat() + timer_copy["expires_at"] = timer_copy["expires_at"].isoformat() serializable_timers.append(timer_copy) - - with open(TIMER_FILE, 'w') as f: + + with open(TIMER_FILE, "w") as f: json.dump(serializable_timers, f, indent=4) print(f"Saved {len(self.timers)} timers.") - @tasks.loop(seconds=10) # Check every 10 seconds + @tasks.loop(seconds=10) # Check every 10 seconds async def timer_check_task(self): now = datetime.now() expired_timers = [] for timer in self.timers: - if timer['expires_at'] <= now: + if timer["expires_at"] <= now: expired_timers.append(timer) - + for timer in expired_timers: self.timers.remove(timer) try: - channel = self.bot.get_channel(timer['channel_id']) + channel = self.bot.get_channel(timer["channel_id"]) if channel: - user = self.bot.get_user(timer['user_id']) + user = self.bot.get_user(timer["user_id"]) if user: message_content = f"{user.mention}, your timer for '{timer['message']}' has expired!" - if timer.get('ephemeral', True): # Default to True if not specified + if timer.get( + "ephemeral", True + ): # Default to True if not specified # Ephemeral messages require interaction context, which we don't have here. # For now, we'll send non-ephemeral if it was originally ephemeral. # A better solution would be to store interaction context or use webhooks. @@ -73,7 +78,7 @@ class TimerCog(commands.Cog): print(f"Could not find channel {timer['channel_id']} for timer.") except Exception as e: print(f"Error sending timer message: {e}") - + if expired_timers: self.save_timers() @@ -85,9 +90,15 @@ class TimerCog(commands.Cog): @app_commands.describe( time_str="Duration for the timer (e.g., 1h30m, 5m, 2d). Supports s, m, h, d.", message="The message for your reminder.", - ephemeral="Whether the response should only be visible to you (defaults to True)." + ephemeral="Whether the response should only be visible to you (defaults to True).", ) - async def timer_slash(self, interaction: discord.Interaction, time_str: str, message: str = "a reminder", ephemeral: bool = True): + async def timer_slash( + self, + interaction: discord.Interaction, + time_str: str, + message: str = "a reminder", + ephemeral: bool = True, + ): """ Sets a timer, reminder, or alarm as a slash command. Usage: /timer time_str:1h30m message:Your reminder message ephemeral:False @@ -104,47 +115,61 @@ class TimerCog(commands.Cog): else: if current_num: num = int(current_num) - if char == 's': + if char == "s": duration_seconds += num - elif char == 'm': + elif char == "m": duration_seconds += num * 60 - elif char == 'h': + elif char == "h": duration_seconds += num * 60 * 60 - elif char == 'd': + elif char == "d": duration_seconds += num * 60 * 60 * 24 else: - await interaction.response.send_message("Invalid time unit. Use s, m, h, or d.", ephemeral=ephemeral) + await interaction.response.send_message( + "Invalid time unit. Use s, m, h, or d.", ephemeral=ephemeral + ) return current_num = "" else: - await interaction.response.send_message("Invalid time format. Example: `1h30m` or `5m`.", ephemeral=ephemeral) + await interaction.response.send_message( + "Invalid time format. Example: `1h30m` or `5m`.", + ephemeral=ephemeral, + ) return - - if current_num: # Handle cases like "30s" without a unit at the end - await interaction.response.send_message("Invalid time format. Please specify a unit (s, m, h, d) for all numbers.", ephemeral=ephemeral) + + if current_num: # Handle cases like "30s" without a unit at the end + await interaction.response.send_message( + "Invalid time format. Please specify a unit (s, m, h, d) for all numbers.", + ephemeral=ephemeral, + ) return if duration_seconds <= 0: - await interaction.response.send_message("Duration must be a positive value.", ephemeral=ephemeral) + await interaction.response.send_message( + "Duration must be a positive value.", ephemeral=ephemeral + ) return expires_at = datetime.now() + timedelta(seconds=duration_seconds) timer_data = { - 'user_id': interaction.user.id, - 'channel_id': interaction.channel_id, - 'message': message, - 'expires_at': expires_at, - 'ephemeral': ephemeral + "user_id": interaction.user.id, + "channel_id": interaction.channel_id, + "message": message, + "expires_at": expires_at, + "ephemeral": ephemeral, } self.timers.append(timer_data) self.save_timers() - await interaction.response.send_message(f"Timer set for {timedelta(seconds=duration_seconds)} from now for '{message}'.", ephemeral=ephemeral) + await interaction.response.send_message( + f"Timer set for {timedelta(seconds=duration_seconds)} from now for '{message}'.", + ephemeral=ephemeral, + ) def cog_unload(self): self.timer_check_task.cancel() - self.save_timers() # Ensure timers are saved on unload + self.save_timers() # Ensure timers are saved on unload + async def setup(bot: commands.Bot): await bot.add_cog(TimerCog(bot)) diff --git a/cogs/tts_provider_cog.py b/cogs/tts_provider_cog.py index 1978cb8..4d86a7e 100644 --- a/cogs/tts_provider_cog.py +++ b/cogs/tts_provider_cog.py @@ -8,6 +8,7 @@ import sys import importlib.util from google.cloud import texttospeech + class TTSProviderCog(commands.Cog): def __init__(self, bot): self.bot = bot @@ -20,6 +21,7 @@ class TTSProviderCog(commands.Cog): async def periodic_cleanup(self): """Periodically clean up old TTS files.""" import asyncio + while not self.bot.is_closed(): # Clean up every hour await asyncio.sleep(3600) # 1 hour @@ -27,7 +29,7 @@ class TTSProviderCog(commands.Cog): def cog_unload(self): """Cancel the cleanup task when the cog is unloaded.""" - if hasattr(self, 'cleanup_task') and self.cleanup_task: + if hasattr(self, "cleanup_task") and self.cleanup_task: self.cleanup_task.cancel() def cleanup_old_files(self): @@ -45,9 +47,16 @@ class TTSProviderCog(commands.Cog): # Find all TTS files older than 1 hour old_files = [] - for pattern in ["./SOUND/tts_*.mp3", "./SOUND/tts_direct_*.mp3", "./SOUND/tts_test_*.mp3"]: + for pattern in [ + "./SOUND/tts_*.mp3", + "./SOUND/tts_direct_*.mp3", + "./SOUND/tts_test_*.mp3", + ]: for file in glob.glob(pattern): - if os.path.exists(file) and os.path.getmtime(file) < current_time - 3600: # 1 hour = 3600 seconds + if ( + os.path.exists(file) + and os.path.getmtime(file) < current_time - 3600 + ): # 1 hour = 3600 seconds old_files.append(file) # Delete old files @@ -67,6 +76,7 @@ class TTSProviderCog(commands.Cog): # Create a unique output file if none is provided if output_file is None: import uuid + output_file = f"./SOUND/tts_direct_{uuid.uuid4().hex}.mp3" # Create output directory if it doesn't exist @@ -76,11 +86,15 @@ class TTSProviderCog(commands.Cog): if provider == "gtts": # Check if gtts is available if importlib.util.find_spec("gtts") is None: - return False, "Google TTS (gtts) is not installed. Run: pip install gtts" + return ( + False, + "Google TTS (gtts) is not installed. Run: pip install gtts", + ) try: from gtts import gTTS - tts = gTTS(text=text, lang='en') + + tts = gTTS(text=text, lang="en") tts.save(output_file) return True, output_file except Exception as e: @@ -93,6 +107,7 @@ class TTSProviderCog(commands.Cog): try: import pyttsx3 + engine = pyttsx3.init() engine.save_to_file(text, output_file) engine.runAndWait() @@ -107,6 +122,7 @@ class TTSProviderCog(commands.Cog): try: from TTS.api import TTS + tts = TTS("tts_models/en/ljspeech/tacotron2-DDC") tts.tts_to_file(text=text, file_path=output_file) return True, output_file @@ -121,15 +137,22 @@ class TTSProviderCog(commands.Cog): try: # On Windows, we'll check if the command exists if platform.system() == "Windows": - result = subprocess.run(["where", "espeak-ng"], capture_output=True, text=True) + result = subprocess.run( + ["where", "espeak-ng"], capture_output=True, text=True + ) espeak_available = result.returncode == 0 else: # On Linux/Mac, we'll use which - result = subprocess.run(["which", "espeak-ng"], capture_output=True, text=True) + result = subprocess.run( + ["which", "espeak-ng"], capture_output=True, text=True + ) espeak_available = result.returncode == 0 if not espeak_available: - return False, "espeak-ng is not installed or not in PATH. Install espeak-ng and make sure it's in your PATH." + return ( + False, + "espeak-ng is not installed or not in PATH. Install espeak-ng and make sure it's in your PATH.", + ) # Create a WAV file first wav_file = output_file.replace(".mp3", ".wav") @@ -146,6 +169,7 @@ class TTSProviderCog(commands.Cog): try: # Try to use pydub for conversion from pydub import AudioSegment + sound = AudioSegment.from_wav(wav_file) sound.export(output_file, format="mp3") # Remove the temporary WAV file @@ -165,17 +189,21 @@ class TTSProviderCog(commands.Cog): elif provider == "google_cloud_tts": # Check if google-cloud-texttospeech is available if importlib.util.find_spec("google.cloud.texttospeech") is None: - return False, "Google Cloud TTS library is not installed. Run: pip install google-cloud-texttospeech" + return ( + False, + "Google Cloud TTS library is not installed. Run: pip install google-cloud-texttospeech", + ) try: - client = texttospeech.TextToSpeechClient() # Assumes GOOGLE_APPLICATION_CREDENTIALS is set + client = ( + texttospeech.TextToSpeechClient() + ) # Assumes GOOGLE_APPLICATION_CREDENTIALS is set input_text = texttospeech.SynthesisInput(text=text) # Specify the voice, using your requested model voice = texttospeech.VoiceSelectionParams( - language_code="en-US", - name="en-US-Chirp3-HD-Autonoe" + language_code="en-US", name="en-US-Chirp3-HD-Autonoe" ) # Specify audio configuration (MP3 output) @@ -184,7 +212,11 @@ class TTSProviderCog(commands.Cog): ) response = client.synthesize_speech( - request={"input": input_text, "voice": voice, "audio_config": audio_config} + request={ + "input": input_text, + "voice": voice, + "audio_config": audio_config, + } ) # The response's audio_content is binary. Write it to the output file. @@ -194,7 +226,9 @@ class TTSProviderCog(commands.Cog): except Exception as e: error_message = f"Error with Google Cloud TTS: {str(e)}" if "quota" in str(e).lower(): - error_message += " This might be a quota issue with your Google Cloud project." + error_message += ( + " This might be a quota issue with your Google Cloud project." + ) elif "credentials" in str(e).lower(): error_message += " Please ensure GOOGLE_APPLICATION_CREDENTIALS environment variable is set correctly." return False, error_message @@ -202,21 +236,29 @@ class TTSProviderCog(commands.Cog): else: return False, f"Unknown TTS provider: {provider}" - @app_commands.command(name="ttsprovider", description="Test different TTS providers") - @app_commands.describe( - provider="Select the TTS provider to use", - text="Text to be spoken" + @app_commands.command( + name="ttsprovider", description="Test different TTS providers" ) - @app_commands.choices(provider=[ - app_commands.Choice(name="Google TTS (Online)", value="gtts"), - app_commands.Choice(name="pyttsx3 (Offline)", value="pyttsx3"), - app_commands.Choice(name="Coqui TTS (AI Voice)", value="coqui"), - app_commands.Choice(name="eSpeak-NG (Offline)", value="espeak"), - app_commands.Choice(name="Google Cloud TTS (Chirp HD)", value="google_cloud_tts") - ]) - async def ttsprovider_slash(self, interaction: discord.Interaction, - provider: str, - text: str = "This is a test of text to speech"): + @app_commands.describe( + provider="Select the TTS provider to use", text="Text to be spoken" + ) + @app_commands.choices( + provider=[ + app_commands.Choice(name="Google TTS (Online)", value="gtts"), + app_commands.Choice(name="pyttsx3 (Offline)", value="pyttsx3"), + app_commands.Choice(name="Coqui TTS (AI Voice)", value="coqui"), + app_commands.Choice(name="eSpeak-NG (Offline)", value="espeak"), + app_commands.Choice( + name="Google Cloud TTS (Chirp HD)", value="google_cloud_tts" + ), + ] + ) + async def ttsprovider_slash( + self, + interaction: discord.Interaction, + provider: str, + text: str = "This is a test of text to speech", + ): """Test different TTS providers""" await interaction.response.defer(thinking=True) @@ -443,9 +485,10 @@ else: # Run the script process = await asyncio.create_subprocess_exec( - sys.executable, script_path, + sys.executable, + script_path, stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stderr=asyncio.subprocess.PIPE, ) # Wait for the process to complete @@ -460,7 +503,7 @@ else: # Extract the output filename from the stdout output_filename = None - for line in stdout_text.split('\n'): + for line in stdout_text.split("\n"): if line.startswith("Using output file:"): output_filename = line.split(":", 1)[1].strip() break @@ -470,6 +513,7 @@ else: # Look for any tts_test_*.mp3 files created in the last minute import glob import time + current_time = time.time() tts_files = [] for file in glob.glob("./SOUND/tts_test_*.mp3"): @@ -488,11 +532,13 @@ else: # Success! Send the audio file await interaction.followup.send( f"✅ Successfully tested TTS provider: {provider}\nText: {text}\nFile: {os.path.basename(output_filename)}", - file=discord.File(output_filename) + file=discord.File(output_filename), ) else: # Failed to generate audio with subprocess, try direct method as fallback - await interaction.followup.send(f"Subprocess method failed. Trying direct TTS generation with {provider}...") + await interaction.followup.send( + f"Subprocess method failed. Trying direct TTS generation with {provider}..." + ) # Try the direct method success, result = await self.generate_tts_directly(provider, text) @@ -501,16 +547,20 @@ else: # Direct method succeeded! await interaction.followup.send( f"✅ Successfully generated TTS audio with {provider} (direct method)\nText: {text}", - file=discord.File(result) + file=discord.File(result), ) return # Both methods failed, send detailed error information - error_message = f"❌ Failed to generate TTS audio with provider: {provider}\n\n" + error_message = ( + f"❌ Failed to generate TTS audio with provider: {provider}\n\n" + ) # Check if the process failed if process.returncode != 0: - error_message += f"Process returned error code: {process.returncode}\n\n" + error_message += ( + f"Process returned error code: {process.returncode}\n\n" + ) # Add direct method error if not success: @@ -525,7 +575,14 @@ else: if "Error with " + provider in full_output: # Extract the specific error message - error_line = next((line for line in full_output.split('\n') if "Error with " + provider in line), "") + error_line = next( + ( + line + for line in full_output.split("\n") + if "Error with " + provider in line + ), + "", + ) if error_line: error_summary += f"- {error_line}\n" @@ -564,7 +621,9 @@ else: output_file = os.path.join(tempfile.gettempdir(), "tts_error_log.txt") with open(output_file, "w", encoding="utf8") as f: f.write(full_output) - await interaction.followup.send("Detailed error log:", file=discord.File(output_file)) + await interaction.followup.send( + "Detailed error log:", file=discord.File(output_file) + ) @commands.command(name="ttscheck") async def tts_check(self, ctx): @@ -577,6 +636,7 @@ else: if gtts_available: try: import gtts + gtts_version = getattr(gtts, "__version__", "Unknown version") except Exception as e: gtts_version = f"Error importing: {str(e)}" @@ -587,6 +647,7 @@ else: if pyttsx3_available: try: import pyttsx3 + pyttsx3_version = "Installed (no version info available)" except Exception as e: pyttsx3_version = f"Error importing: {str(e)}" @@ -597,6 +658,7 @@ else: if coqui_available: try: import TTS + coqui_version = getattr(TTS, "__version__", "Unknown version") except Exception as e: coqui_version = f"Error importing: {str(e)}" @@ -606,18 +668,25 @@ else: try: import subprocess import platform + if platform.system() == "Windows": # On Windows, we'll check if the command exists - result = subprocess.run(["where", "espeak-ng"], capture_output=True, text=True) + result = subprocess.run( + ["where", "espeak-ng"], capture_output=True, text=True + ) espeak_available = result.returncode == 0 else: # On Linux/Mac, we'll use which - result = subprocess.run(["which", "espeak-ng"], capture_output=True, text=True) + result = subprocess.run( + ["which", "espeak-ng"], capture_output=True, text=True + ) espeak_available = result.returncode == 0 if espeak_available: # Try to get version - version_result = subprocess.run(["espeak-ng", "--version"], capture_output=True, text=True) + version_result = subprocess.run( + ["espeak-ng", "--version"], capture_output=True, text=True + ) if version_result.returncode == 0: espeak_version = version_result.stdout.strip() else: @@ -628,12 +697,17 @@ else: espeak_version = f"Error checking: {str(e)}" # Check for Google Cloud TTS - gcloud_tts_available = importlib.util.find_spec("google.cloud.texttospeech") is not None + gcloud_tts_available = ( + importlib.util.find_spec("google.cloud.texttospeech") is not None + ) gcloud_tts_version = "Not installed" if gcloud_tts_available: try: import google.cloud.texttospeech as gcloud_tts_module - gcloud_tts_version = getattr(gcloud_tts_module, "__version__", "Unknown version") + + gcloud_tts_version = getattr( + gcloud_tts_module, "__version__", "Unknown version" + ) except Exception as e: gcloud_tts_version = f"Error importing: {str(e)}" @@ -657,6 +731,7 @@ else: await ctx.send(report) + async def setup(bot: commands.Bot): print("Loading TTSProviderCog...") await bot.add_cog(TTSProviderCog(bot)) diff --git a/cogs/upload_cog.py b/cogs/upload_cog.py index 22806f4..77d05b3 100644 --- a/cogs/upload_cog.py +++ b/cogs/upload_cog.py @@ -10,6 +10,7 @@ import base64 from typing import Optional, Dict, Any, Union import discord.ui + class CaptchaModal(discord.ui.Modal, title="Solve Image Captcha"): def __init__(self, captcha_id: str): # Set the modal timeout to 10 minutes to match the view's timeout for consistency @@ -20,7 +21,7 @@ class CaptchaModal(discord.ui.Modal, title="Solve Image Captcha"): placeholder="Enter the text from the image...", required=True, min_length=1, - max_length=100 + max_length=100, ) self.add_item(self.solution) self.interaction = None @@ -30,7 +31,9 @@ class CaptchaModal(discord.ui.Modal, title="Solve Image Captcha"): # This method will be called when the user submits the modal # Store the interaction for later use self.interaction = interaction - await interaction.response.defer(ephemeral=True) # Defer the response to prevent interaction timeout + await interaction.response.defer( + ephemeral=True + ) # Defer the response to prevent interaction timeout # Set the event to signal that the modal has been submitted self.is_submitted.set() @@ -49,15 +52,20 @@ class CaptchaModal(discord.ui.Modal, title="Solve Image Captcha"): class CaptchaView(discord.ui.View): def __init__(self, modal: CaptchaModal, original_interactor_id: int): - super().__init__(timeout=600) # 10 minutes timeout + super().__init__(timeout=600) # 10 minutes timeout self.modal = modal self.original_interactor_id = original_interactor_id @discord.ui.button(label="Solve Captcha", style=discord.ButtonStyle.primary) - async def solve_button(self, interaction: discord.Interaction, button: discord.ui.Button): + async def solve_button( + self, interaction: discord.Interaction, button: discord.ui.Button + ): # Check if the user interacting is the original user who initiated the command if interaction.user.id != self.original_interactor_id: - await interaction.response.send_message("Only the user who initiated this command can solve the captcha.", ephemeral=True) + await interaction.response.send_message( + "Only the user who initiated this command can solve the captcha.", + ephemeral=True, + ) return await interaction.response.send_modal(self.modal) @@ -81,7 +89,7 @@ class UploadCog(commands.Cog, name="Upload"): self.upload_group = app_commands.Group( name="upload", description="Commands for interacting with the upload API", - guild_only=False + guild_only=False, ) # Register commands @@ -108,16 +116,18 @@ class UploadCog(commands.Cog, name="Upload"): upload_file_command = app_commands.Command( name="file", description="Upload a file, interactively solving an image captcha", - callback=self.upload_file_interactive_callback, # New callback name - parent=self.upload_group + callback=self.upload_file_interactive_callback, # New callback name + parent=self.upload_group, ) app_commands.describe( file="The file to upload", - expires_after="Time in seconds until the file expires (default: 86400 - 24 hours)" + expires_after="Time in seconds until the file expires (default: 86400 - 24 hours)", )(upload_file_command) self.upload_group.add_command(upload_file_command) - async def _make_api_request(self, method: str, endpoint: str, **kwargs) -> Dict[str, Any]: + async def _make_api_request( + self, method: str, endpoint: str, **kwargs + ) -> Dict[str, Any]: """Make a request to the API""" print(f"Making {method} request to {endpoint} with params: {kwargs}") if not self.session: @@ -144,15 +154,19 @@ class UploadCog(commands.Cog, name="Upload"): return await response.json() else: error_text = await response.text() - raise Exception(f"API request failed: {response.status} - {error_text}") + raise Exception( + f"API request failed: {response.status} - {error_text}" + ) elif method.upper() == "POST": print(f"Sending POST request to {url}") # If we're sending form data, make sure we don't manually set Content-Type - if 'data' in kwargs and isinstance(kwargs['data'], aiohttp.FormData): + if "data" in kwargs and isinstance(kwargs["data"], aiohttp.FormData): print("Sending multipart/form-data request") # aiohttp will automatically set the correct Content-Type with boundary - async with self.session.post(url, headers=request_headers, **kwargs) as response: + async with self.session.post( + url, headers=request_headers, **kwargs + ) as response: print(f"Response status: {response.status}") print(f"Response headers: {response.headers}") @@ -161,22 +175,29 @@ class UploadCog(commands.Cog, name="Upload"): else: error_text = await response.text() print(f"Error response body: {error_text}") - raise Exception(f"API request failed: {response.status} - {error_text}") + raise Exception( + f"API request failed: {response.status} - {error_text}" + ) else: raise ValueError(f"Unsupported HTTP method: {method}") except Exception as e: print(f"Error making API request to {url}: {e}") raise - async def upload_file_interactive_callback(self, interaction: discord.Interaction, - file: discord.Attachment, - expires_after: Optional[int] = 86400): + async def upload_file_interactive_callback( + self, + interaction: discord.Interaction, + file: discord.Attachment, + expires_after: Optional[int] = 86400, + ): """Upload a file, interactively solving an image captcha""" - await interaction.response.defer(ephemeral=False) # Defer the initial response + await interaction.response.defer(ephemeral=False) # Defer the initial response try: # 1. Generate Image Captcha - captcha_data = await self._make_api_request("GET", "/upload/api/captcha/image") + captcha_data = await self._make_api_request( + "GET", "/upload/api/captcha/image" + ) captcha_id = captcha_data.get("captcha_id") image_data = captcha_data.get("image", "") @@ -194,11 +215,13 @@ class UploadCog(commands.Cog, name="Upload"): embed = discord.Embed( title="Image Captcha Challenge", description="Please solve the captcha challenge to upload your file.", - color=discord.Color.blue() + color=discord.Color.blue(), ) embed.add_field(name="Captcha ID", value=f"`{captcha_id}`", inline=False) embed.set_image(url="attachment://captcha.png") - embed.set_footer(text="This captcha is valid for 10 minutes. Enter the text from the image.") + embed.set_footer( + text="This captcha is valid for 10 minutes. Enter the text from the image." + ) # Create the modal instance modal = CaptchaModal(captcha_id=captcha_id) @@ -207,24 +230,32 @@ class UploadCog(commands.Cog, name="Upload"): view = CaptchaView(modal=modal, original_interactor_id=interaction.user.id) # Send the captcha image, instructions, and the view in a single message - await interaction.followup.send(embed=embed, file=captcha_image_file, view=view, ephemeral=True) + await interaction.followup.send( + embed=embed, file=captcha_image_file, view=view, ephemeral=True + ) # Wait for the modal submission timed_out = await modal.wait() if timed_out: - await interaction.followup.send("Captcha solution timed out. Please try again.", ephemeral=True) + await interaction.followup.send( + "Captcha solution timed out. Please try again.", ephemeral=True + ) return captcha_solution = modal.solution.value # 3. Proceed with File Upload if not file: - await interaction.followup.send("Please provide a file to upload", ephemeral=True) + await interaction.followup.send( + "Please provide a file to upload", ephemeral=True + ) return if not captcha_solution: - await interaction.followup.send("Captcha solution was not provided.", ephemeral=True) + await interaction.followup.send( + "Captcha solution was not provided.", ephemeral=True + ) return # Download the file @@ -232,18 +263,27 @@ class UploadCog(commands.Cog, name="Upload"): # Prepare form data form_data = aiohttp.FormData() - form_data.add_field('file', file_bytes, filename=file.filename, content_type=file.content_type) - form_data.add_field('captcha_id', captcha_id) - form_data.add_field('captcha_solution', captcha_solution) - form_data.add_field('expires_after', str(expires_after)) + form_data.add_field( + "file", + file_bytes, + filename=file.filename, + content_type=file.content_type, + ) + form_data.add_field("captcha_id", captcha_id) + form_data.add_field("captcha_solution", captcha_solution) + form_data.add_field("expires_after", str(expires_after)) # Debug form data fields - print(f"Form data fields: file, captcha_id={captcha_id}, captcha_solution={captcha_solution}, expires_after={expires_after}") + print( + f"Form data fields: file, captcha_id={captcha_id}, captcha_solution={captcha_solution}, expires_after={expires_after}" + ) # Make API request to upload file try: print("Attempting to upload file to third-party endpoint...") - upload_data = await self._make_api_request("POST", "/upload/api/upload/third-party", data=form_data) + upload_data = await self._make_api_request( + "POST", "/upload/api/upload/third-party", data=form_data + ) print(f"Upload successful, received data: {upload_data}") except Exception as e: print(f"Upload failed with error: {e}") @@ -257,24 +297,34 @@ class UploadCog(commands.Cog, name="Upload"): else: error_text = await response.text() print(f"Direct upload failed: {response.status} - {error_text}") - raise Exception(f"API request failed: {response.status} - {error_text}") + raise Exception( + f"API request failed: {response.status} - {error_text}" + ) # Poll until access_ready is true or timeout after 30 seconds file_id = upload_data.get("id", "unknown") file_extension = upload_data.get("file_extension", "") # Append file extension to the URL if available - file_url = f"https://slipstreamm.dev/uploads/{file_id}" + (f".{file_extension}" if file_extension else "") + file_url = f"https://slipstreamm.dev/uploads/{file_id}" + ( + f".{file_extension}" if file_extension else "" + ) # Send initial message that we're waiting for the file to be processed - status_message = await interaction.followup.send("File uploaded successfully. Waiting for file processing to complete...") + status_message = await interaction.followup.send( + "File uploaded successfully. Waiting for file processing to complete..." + ) # Poll for access_ready status max_attempts = 30 # 30 seconds max wait time for attempt in range(max_attempts): try: # Get the current file status using the correct endpoint - file_status_id = f"{file_id}.{file_extension}" if file_extension else file_id - file_status = await self._make_api_request("GET", f"/upload/api/file-status/{file_status_id}") + file_status_id = ( + f"{file_id}.{file_extension}" if file_extension else file_id + ) + file_status = await self._make_api_request( + "GET", f"/upload/api/file-status/{file_status_id}" + ) print(f"File status poll attempt {attempt+1}: {file_status}") if file_status.get("access_ready", False): @@ -283,7 +333,9 @@ class UploadCog(commands.Cog, name="Upload"): upload_data = file_status # Update file_url with the latest file extension file_extension = file_status.get("file_extension", "") - file_url = f"https://slipstreamm.dev/uploads/{file_id}" + (f".{file_extension}" if file_extension else "") + file_url = f"https://slipstreamm.dev/uploads/{file_id}" + ( + f".{file_extension}" if file_extension else "" + ) break # Wait 1 second before polling again @@ -297,11 +349,11 @@ class UploadCog(commands.Cog, name="Upload"): embed = discord.Embed( title="File Uploaded Successfully", description=f"Your file has been uploaded and will expire in {expires_after} seconds", - color=discord.Color.green() + color=discord.Color.green(), ) # Format file size nicely - file_size_bytes = upload_data.get('size', 0) + file_size_bytes = upload_data.get("size", 0) if file_size_bytes < 1024: file_size_str = f"{file_size_bytes} bytes" elif file_size_bytes < 1024 * 1024: @@ -310,10 +362,22 @@ class UploadCog(commands.Cog, name="Upload"): file_size_str = f"{file_size_bytes / (1024 * 1024):.2f} MB" embed.add_field(name="File ID", value=file_id, inline=True) - embed.add_field(name="Original Name", value=upload_data.get("file_name", "unknown"), inline=True) + embed.add_field( + name="Original Name", + value=upload_data.get("file_name", "unknown"), + inline=True, + ) embed.add_field(name="File Size", value=file_size_str, inline=True) - embed.add_field(name="Content Type", value=upload_data.get("content_type", "unknown"), inline=True) - embed.add_field(name="Scan Status", value=upload_data.get("scan_status", "unknown"), inline=True) + embed.add_field( + name="Content Type", + value=upload_data.get("content_type", "unknown"), + inline=True, + ) + embed.add_field( + name="Scan Status", + value=upload_data.get("scan_status", "unknown"), + inline=True, + ) embed.add_field(name="File URL", value=file_url, inline=False) # Add clickable link @@ -324,7 +388,10 @@ class UploadCog(commands.Cog, name="Upload"): except Exception as e: # If an error occurs during captcha generation or upload, send an ephemeral error message - await interaction.followup.send(f"Error during file upload process: {e}", ephemeral=True) + await interaction.followup.send( + f"Error during file upload process: {e}", ephemeral=True + ) + async def setup(bot: commands.Bot): """Add the UploadCog to the bot.""" diff --git a/cogs/user_info_cog.py b/cogs/user_info_cog.py index 4aadc0b..6c4e1e5 100644 --- a/cogs/user_info_cog.py +++ b/cogs/user_info_cog.py @@ -3,11 +3,14 @@ from discord.ext import commands from discord import AllowedMentions, ui from datetime import datetime, timedelta, timezone + class UserInfoCog(commands.Cog): def __init__(self, bot: commands.Bot): self.bot = bot - @commands.hybrid_command(name="userinfo", description="Displays detailed information about a user.") + @commands.hybrid_command( + name="userinfo", description="Displays detailed information about a user." + ) async def userinfo(self, ctx: commands.Context, member: discord.Member = None): """Displays detailed information about a user.""" if member is None: @@ -22,10 +25,16 @@ class UserInfoCog(commands.Cog): try: member = await ctx.guild.fetch_member(member.id) # roles/nick/etc. except discord.NotFound: - await ctx.send("Could not find the specified member in this server.", ephemeral=True) + await ctx.send( + "Could not find the specified member in this server.", + ephemeral=True, + ) return except discord.HTTPException as e: - await ctx.send(f"An error occurred while fetching member data: `{e}`", ephemeral=True) + await ctx.send( + f"An error occurred while fetching member data: `{e}`", + ephemeral=True, + ) return username_discriminator = ( @@ -40,7 +49,9 @@ class UserInfoCog(commands.Cog): else "N/A" ) - roles = [role.mention for role in reversed(member.roles) if role.name != "@everyone"] + roles = [ + role.mention for role in reversed(member.roles) if role.name != "@everyone" + ] roles_str = ", ".join(roles) if roles else "None" if len(roles_str) > 1000: # Discord limits field values roles_str = roles_str[:997] + "..." @@ -48,56 +59,67 @@ class UserInfoCog(commands.Cog): status_str = str(member.status).title() activity_str = ( f"Playing {member.activity.name}" - if member.activity - and member.activity.type is discord.ActivityType.playing - else f"Streaming {member.activity.name}" - if member.activity - and member.activity.type is discord.ActivityType.streaming - else f"Listening to {member.activity.title}…" - if member.activity - and member.activity.type is discord.ActivityType.listening - else f"Watching {member.activity.name}" - if member.activity - and member.activity.type is discord.ActivityType.watching - else f"{member.activity.emoji} {member.activity.state}".strip() - if member.activity - and member.activity.type is discord.ActivityType.custom - else "None" + if member.activity and member.activity.type is discord.ActivityType.playing + else ( + f"Streaming {member.activity.name}" + if member.activity + and member.activity.type is discord.ActivityType.streaming + else ( + f"Listening to {member.activity.title}…" + if member.activity + and member.activity.type is discord.ActivityType.listening + else ( + f"Watching {member.activity.name}" + if member.activity + and member.activity.type is discord.ActivityType.watching + else ( + f"{member.activity.emoji} {member.activity.state}".strip() + if member.activity + and member.activity.type is discord.ActivityType.custom + else "None" + ) + ) + ) + ) ) # Badges / Flags flags = member.public_flags # this is a PublicUserFlags instance - badges = [ - name.replace("_", " ").title() - for name, enabled in flags - if enabled - ] + badges = [name.replace("_", " ").title() for name, enabled in flags if enabled] badges_str = ", ".join(badges) or "None" # Pronouns pronouns_str = getattr(member, "pronouns", "N/A") # API v10-beta # Avatar Type - avatar_type = "GIF" if member.avatar and member.avatar.is_animated() else "Static" + avatar_type = ( + "GIF" if member.avatar and member.avatar.is_animated() else "Static" + ) # --- FIXED: use aware UTC datetime for “now” --- now_utc = datetime.now(timezone.utc) # Account Age account_age = now_utc - member.created_at - account_age_str = f"{account_age.days // 365} years, {(account_age.days % 365) // 30} months" + account_age_str = ( + f"{account_age.days // 365} years, {(account_age.days % 365) // 30} months" + ) # Join Position join_position_str = "N/A" if ctx.guild and member.joined_at: sorted_members = sorted( ctx.guild.members, - key=lambda m: m.joined_at - if m.joined_at - else datetime.min.replace(tzinfo=timezone.utc), + key=lambda m: ( + m.joined_at + if m.joined_at + else datetime.min.replace(tzinfo=timezone.utc) + ), ) try: - join_position_str = f"{sorted_members.index(member) + 1} of {len(sorted_members)}" + join_position_str = ( + f"{sorted_members.index(member) + 1} of {len(sorted_members)}" + ) except ValueError: pass # Member not found in sorted list (should not happen) @@ -204,25 +226,37 @@ class UserInfoCog(commands.Cog): ) ) main_container.add_item(header_section) - header_section.add_item(ui.TextDisplay(f"**{target_member.display_name}**")) header_section.add_item( - ui.TextDisplay(f"({username_discriminator}) - ID: {target_member.id}") + ui.TextDisplay(f"**{target_member.display_name}**") + ) + header_section.add_item( + ui.TextDisplay( + f"({username_discriminator}) - ID: {target_member.id}" + ) ) - main_container.add_item(ui.Separator(spacing=discord.SeparatorSpacing.small)) + main_container.add_item( + ui.Separator(spacing=discord.SeparatorSpacing.small) + ) # Account & Profile main_container.add_item( - ui.TextDisplay(f"**Account Created:** {created_at_str} ({account_age_str} ago)") + ui.TextDisplay( + f"**Account Created:** {created_at_str} ({account_age_str} ago)" + ) + ) + main_container.add_item( + ui.TextDisplay(f"**Avatar Type:** {avatar_type}") ) - main_container.add_item(ui.TextDisplay(f"**Avatar Type:** {avatar_type}")) main_container.add_item(ui.TextDisplay(f"**Badges:** {badges_str}")) if pronouns_str != "N/A": main_container.add_item( ui.TextDisplay(f"**Pronouns:** {pronouns_str}") ) - main_container.add_item(ui.Separator(spacing=discord.SeparatorSpacing.small)) + main_container.add_item( + ui.Separator(spacing=discord.SeparatorSpacing.small) + ) # Guild-specific if ctx.guild: @@ -295,7 +329,9 @@ class UserInfoCog(commands.Cog): voice_state_details.append("Video On") if voice_state_details: main_container.add_item( - ui.TextDisplay(f"**Voice State:** {', '.join(voice_state_details)}") + ui.TextDisplay( + f"**Voice State:** {', '.join(voice_state_details)}" + ) ) try: @@ -303,19 +339,23 @@ class UserInfoCog(commands.Cog): await ctx.send( view=view, ephemeral=False, - allowed_mentions=AllowedMentions(roles=False, users=False, everyone=False), + allowed_mentions=AllowedMentions( + roles=False, users=False, everyone=False + ), ) except Exception as e: import traceback traceback.print_exc() await ctx.send( - f"An error occurred while creating the user info display: `{e}`", ephemeral=True + f"An error occurred while creating the user info display: `{e}`", + ephemeral=True, ) @commands.Cog.listener() async def on_ready(self): print(f"{self.__class__.__name__} cog has been loaded.") + async def setup(bot: commands.Bot): await bot.add_cog(UserInfoCog(bot)) diff --git a/cogs/webdrivertorso_cog.py b/cogs/webdrivertorso_cog.py index 9f2222c..8bb4fed 100644 --- a/cogs/webdrivertorso_cog.py +++ b/cogs/webdrivertorso_cog.py @@ -9,6 +9,7 @@ import glob import sys import importlib.util + class JSON: def read(file): with open(f"{file}.json", "r", encoding="utf8") as file: @@ -19,6 +20,7 @@ class JSON: with open(f"{file}.json", "w", encoding="utf8") as file: json.dump(data, file, indent=4) + class WebdriverTorsoCog(commands.Cog): def __init__(self, bot): self.bot = bot @@ -55,16 +57,76 @@ class WebdriverTorsoCog(commands.Cog): "TEXT_SIZE": 0, "TEXT_POSITION": "top-left", "COLOR_SCHEMES": { - "pastel": [[255, 182, 193], [176, 224, 230], [240, 230, 140], [221, 160, 221], [152, 251, 152]], - "dark_gritty": [[47, 79, 79], [105, 105, 105], [0, 0, 0], [85, 107, 47], [139, 69, 19]], - "nature": [[34, 139, 34], [107, 142, 35], [46, 139, 87], [32, 178, 170], [154, 205, 50]], - "vibrant": [[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 255, 0], [255, 0, 255]], - "ocean": [[0, 105, 148], [72, 209, 204], [70, 130, 180], [135, 206, 250], [176, 224, 230]], - "neon": [[255, 0, 102], [0, 255, 255], [255, 255, 0], [0, 255, 0], [255, 0, 255]], - "monochrome": [[0, 0, 0], [50, 50, 50], [100, 100, 100], [150, 150, 150], [200, 200, 200]], - "autumn": [[165, 42, 42], [210, 105, 30], [139, 69, 19], [160, 82, 45], [205, 133, 63]], - "cyberpunk": [[255, 0, 128], [0, 255, 255], [128, 0, 255], [255, 255, 0], [0, 255, 128]], - "retro": [[255, 204, 0], [255, 102, 0], [204, 0, 0], [0, 102, 204], [0, 204, 102]] + "pastel": [ + [255, 182, 193], + [176, 224, 230], + [240, 230, 140], + [221, 160, 221], + [152, 251, 152], + ], + "dark_gritty": [ + [47, 79, 79], + [105, 105, 105], + [0, 0, 0], + [85, 107, 47], + [139, 69, 19], + ], + "nature": [ + [34, 139, 34], + [107, 142, 35], + [46, 139, 87], + [32, 178, 170], + [154, 205, 50], + ], + "vibrant": [ + [255, 0, 0], + [0, 255, 0], + [0, 0, 255], + [255, 255, 0], + [255, 0, 255], + ], + "ocean": [ + [0, 105, 148], + [72, 209, 204], + [70, 130, 180], + [135, 206, 250], + [176, 224, 230], + ], + "neon": [ + [255, 0, 102], + [0, 255, 255], + [255, 255, 0], + [0, 255, 0], + [255, 0, 255], + ], + "monochrome": [ + [0, 0, 0], + [50, 50, 50], + [100, 100, 100], + [150, 150, 150], + [200, 200, 200], + ], + "autumn": [ + [165, 42, 42], + [210, 105, 30], + [139, 69, 19], + [160, 82, 45], + [205, 133, 63], + ], + "cyberpunk": [ + [255, 0, 128], + [0, 255, 255], + [128, 0, 255], + [255, 255, 0], + [0, 255, 128], + ], + "retro": [ + [255, 204, 0], + [255, 102, 0], + [204, 0, 0], + [0, 102, 204], + [0, 204, 102], + ], }, "WAVE_VIBES": { "calm": {"frequency": 200, "amplitude": 0.3, "modulation": 0.1}, @@ -77,36 +139,217 @@ class WebdriverTorsoCog(commands.Cog): "underwater": {"frequency": 300, "amplitude": 0.6, "modulation": 0.3}, "mechanical": {"frequency": 500, "amplitude": 0.5, "modulation": 0.1}, "ethereal": {"frequency": 700, "amplitude": 0.4, "modulation": 0.8}, - "pulsating": {"frequency": 900, "amplitude": 0.7, "modulation": 0.6} + "pulsating": {"frequency": 900, "amplitude": 0.7, "modulation": 0.6}, }, "WORD_TOPICS": { - "introspective": ["reflection", "thought", "solitude", "ponder", "meditation", "introspection", "awareness", "contemplation", "silence", "stillness"], - "action": ["run", "jump", "climb", "race", "fight", "explore", "build", "create", "overcome", "achieve"], - "nature": ["tree", "mountain", "river", "ocean", "flower", "forest", "animal", "sky", "valley", "meadow"], - "technology": ["computer", "robot", "network", "data", "algorithm", "innovation", "digital", "machine", "software", "hardware"], - "space": ["star", "planet", "galaxy", "cosmos", "orbit", "nebula", "asteroid", "comet", "universe", "void"], - "ocean": ["wave", "coral", "fish", "shark", "seaweed", "tide", "reef", "abyss", "current", "marine"], - "fantasy": ["dragon", "wizard", "magic", "quest", "sword", "spell", "kingdom", "myth", "legend", "fairy"], - "science": ["experiment", "theory", "hypothesis", "research", "discovery", "laboratory", "element", "molecule", "atom", "energy"], - "art": ["canvas", "paint", "sculpture", "gallery", "artist", "creativity", "expression", "masterpiece", "composition", "design"], - "music": ["melody", "rhythm", "harmony", "song", "instrument", "concert", "symphony", "chord", "note", "beat"], - "food": ["cuisine", "flavor", "recipe", "ingredient", "taste", "dish", "spice", "dessert", "meal", "delicacy"], - "emotions": ["joy", "sorrow", "anger", "fear", "love", "hate", "surprise", "disgust", "anticipation", "trust"], - "colors": ["red", "blue", "green", "yellow", "purple", "orange", "black", "white", "pink", "teal"], - "abstract": ["concept", "idea", "thought", "theory", "philosophy", "abstraction", "notion", "principle", "essence", "paradigm"] - } + "introspective": [ + "reflection", + "thought", + "solitude", + "ponder", + "meditation", + "introspection", + "awareness", + "contemplation", + "silence", + "stillness", + ], + "action": [ + "run", + "jump", + "climb", + "race", + "fight", + "explore", + "build", + "create", + "overcome", + "achieve", + ], + "nature": [ + "tree", + "mountain", + "river", + "ocean", + "flower", + "forest", + "animal", + "sky", + "valley", + "meadow", + ], + "technology": [ + "computer", + "robot", + "network", + "data", + "algorithm", + "innovation", + "digital", + "machine", + "software", + "hardware", + ], + "space": [ + "star", + "planet", + "galaxy", + "cosmos", + "orbit", + "nebula", + "asteroid", + "comet", + "universe", + "void", + ], + "ocean": [ + "wave", + "coral", + "fish", + "shark", + "seaweed", + "tide", + "reef", + "abyss", + "current", + "marine", + ], + "fantasy": [ + "dragon", + "wizard", + "magic", + "quest", + "sword", + "spell", + "kingdom", + "myth", + "legend", + "fairy", + ], + "science": [ + "experiment", + "theory", + "hypothesis", + "research", + "discovery", + "laboratory", + "element", + "molecule", + "atom", + "energy", + ], + "art": [ + "canvas", + "paint", + "sculpture", + "gallery", + "artist", + "creativity", + "expression", + "masterpiece", + "composition", + "design", + ], + "music": [ + "melody", + "rhythm", + "harmony", + "song", + "instrument", + "concert", + "symphony", + "chord", + "note", + "beat", + ], + "food": [ + "cuisine", + "flavor", + "recipe", + "ingredient", + "taste", + "dish", + "spice", + "dessert", + "meal", + "delicacy", + ], + "emotions": [ + "joy", + "sorrow", + "anger", + "fear", + "love", + "hate", + "surprise", + "disgust", + "anticipation", + "trust", + ], + "colors": [ + "red", + "blue", + "green", + "yellow", + "purple", + "orange", + "black", + "white", + "pink", + "teal", + ], + "abstract": [ + "concept", + "idea", + "thought", + "theory", + "philosophy", + "abstraction", + "notion", + "principle", + "essence", + "paradigm", + ], + }, } # Create directories if they don't exist for directory in ["IMG", "SOUND", "OUTPUT", "FONT"]: os.makedirs(directory, exist_ok=True) - async def _generate_video_logic(self, ctx_or_interaction, width=None, height=None, max_width=None, max_height=None, - min_width=None, min_height=None, slides=None, videos=None, min_shapes=None, max_shapes=None, - sound_quality=None, tts_enabled=None, tts_text=None, tts_provider=None, audio_wave_type=None, slide_duration=None, - deform_level=None, color_mode=None, color_scheme=None, solid_color=None, allowed_shapes=None, - wave_vibe=None, top_left_text_enabled=None, top_left_text_mode=None, words_topic=None, - text_color=None, text_size=None, text_position=None, already_deferred=False): + async def _generate_video_logic( + self, + ctx_or_interaction, + width=None, + height=None, + max_width=None, + max_height=None, + min_width=None, + min_height=None, + slides=None, + videos=None, + min_shapes=None, + max_shapes=None, + sound_quality=None, + tts_enabled=None, + tts_text=None, + tts_provider=None, + audio_wave_type=None, + slide_duration=None, + deform_level=None, + color_mode=None, + color_scheme=None, + solid_color=None, + allowed_shapes=None, + wave_vibe=None, + top_left_text_enabled=None, + top_left_text_mode=None, + words_topic=None, + text_color=None, + text_size=None, + text_position=None, + already_deferred=False, + ): """Core logic for the webdrivertorso command.""" # Check if already processing a video if self.is_processing: @@ -145,7 +388,9 @@ class WebdriverTorsoCog(commands.Cog): config_data["TTS_ENABLED"] = tts_enabled if tts_text is not None: config_data["TTS_TEXT"] = tts_text - if tts_enabled is None: # Only set to True if not explicitly set to False + if ( + tts_enabled is None + ): # Only set to True if not explicitly set to False config_data["TTS_ENABLED"] = True if tts_provider is not None: config_data["TTS_PROVIDER"] = tts_provider @@ -180,7 +425,7 @@ class WebdriverTorsoCog(commands.Cog): # Clean directories for directory in ["IMG", "SOUND"]: - for file in glob.glob(f'./{directory}/*'): + for file in glob.glob(f"./{directory}/*"): try: os.remove(file) except Exception as e: @@ -198,27 +443,35 @@ class WebdriverTorsoCog(commands.Cog): script_content = f.read() # Create a temporary config file for this run only - temp_config_path = os.path.join(tempfile.gettempdir(), "webdrivertorso_temp_config.json") + temp_config_path = os.path.join( + tempfile.gettempdir(), "webdrivertorso_temp_config.json" + ) with open(temp_config_path, "w", encoding="utf8") as f: json.dump(config_data, f, indent=4) # Replace the config file path in the script content - script_content = script_content.replace('config_data = JSON.read("config")', f'config_data = JSON.read("{os.path.splitext(temp_config_path)[0]}")') + script_content = script_content.replace( + 'config_data = JSON.read("config")', + f'config_data = JSON.read("{os.path.splitext(temp_config_path)[0]}")', + ) with open(script_path, "w", encoding="utf8") as f: f.write(script_content) # Send initial message if isinstance(ctx_or_interaction, commands.Context): - await ctx_or_interaction.reply("🎬 Generating Webdriver Torso style video... This may take a minute.") + await ctx_or_interaction.reply( + "🎬 Generating Webdriver Torso style video... This may take a minute." + ) elif not already_deferred: # It's an Interaction and not deferred yet await ctx_or_interaction.response.defer(thinking=True) # Run the script as a subprocess process = await asyncio.create_subprocess_exec( - sys.executable, script_path, + sys.executable, + script_path, stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stderr=asyncio.subprocess.PIPE, ) # Wait for the process to complete @@ -230,7 +483,7 @@ class WebdriverTorsoCog(commands.Cog): return f"❌ Error generating video: {error_msg}" # Find the generated video file - video_files = glob.glob('./OUTPUT/*.mp4') + video_files = glob.glob("./OUTPUT/*.mp4") if not video_files: return "❌ No video files were generated." @@ -299,42 +552,46 @@ class WebdriverTorsoCog(commands.Cog): if options: option_pairs = options.split() for pair in option_pairs: - if '=' in pair: - key, value = pair.split('=', 1) + if "=" in pair: + key, value = pair.split("=", 1) # Convert string values to appropriate types - if value.lower() == 'true': + if value.lower() == "true": params[key] = True - elif value.lower() == 'false': + elif value.lower() == "false": params[key] = False elif value.isdigit(): params[key] = int(value) - elif key == 'allowed_shapes' and value.startswith('[') and value.endswith(']'): + elif ( + key == "allowed_shapes" + and value.startswith("[") + and value.endswith("]") + ): # Parse list of shapes - shapes_list = value[1:-1].split(',') + shapes_list = value[1:-1].split(",") params[key] = [shape.strip() for shape in shapes_list] # Handle combined parameters - elif key == 'dimensions' and ',' in value: - width, height = value.split(',', 1) + elif key == "dimensions" and "," in value: + width, height = value.split(",", 1) if width.strip().isdigit(): - params['width'] = int(width.strip()) + params["width"] = int(width.strip()) if height.strip().isdigit(): - params['height'] = int(height.strip()) - elif key == 'shape_size_limits' and ',' in value: - parts = value.split(',') + params["height"] = int(height.strip()) + elif key == "shape_size_limits" and "," in value: + parts = value.split(",") if len(parts) >= 1 and parts[0].strip().isdigit(): - params['min_width'] = int(parts[0].strip()) + params["min_width"] = int(parts[0].strip()) if len(parts) >= 2 and parts[1].strip().isdigit(): - params['min_height'] = int(parts[1].strip()) + params["min_height"] = int(parts[1].strip()) if len(parts) >= 3 and parts[2].strip().isdigit(): - params['max_width'] = int(parts[2].strip()) + params["max_width"] = int(parts[2].strip()) if len(parts) >= 4 and parts[3].strip().isdigit(): - params['max_height'] = int(parts[3].strip()) - elif key == 'shapes_count' and ',' in value: - min_shapes, max_shapes = value.split(',', 1) + params["max_height"] = int(parts[3].strip()) + elif key == "shapes_count" and "," in value: + min_shapes, max_shapes = value.split(",", 1) if min_shapes.strip().isdigit(): - params['min_shapes'] = int(min_shapes.strip()) + params["min_shapes"] = int(min_shapes.strip()) if max_shapes.strip().isdigit(): - params['max_shapes'] = int(max_shapes.strip()) + params["max_shapes"] = int(max_shapes.strip()) else: params[key] = value @@ -345,27 +602,25 @@ class WebdriverTorsoCog(commands.Cog): await ctx.reply(result) # --- Slash Command --- - @app_commands.command(name="webdrivertorso", description="Generate a Webdriver Torso style test video") + @app_commands.command( + name="webdrivertorso", description="Generate a Webdriver Torso style test video" + ) @app_commands.describe( # Video structure slides="Number of slides in the video (default: 10)", videos="Number of videos to generate (default: 1)", slide_duration="Duration of each slide in milliseconds (default: 1000)", - # Video dimensions dimensions="Video dimensions in format 'width,height' (default: '640,480')", shape_size_limits="Shape size limits in format 'min_width,min_height,max_width,max_height' (default: '20,20,200,200')", - # Shapes shapes_count="Number of shapes per slide in format 'min,max' (default: '5,15')", deform_level="Level of shape deformation (none, low, medium, high)", shape_types="Types of shapes to include (comma-separated list)", - # Colors color_mode="Color mode for shapes (random, scheme, solid)", color_scheme="Color scheme to use (pastel, dark_gritty, nature, vibrant, ocean)", solid_color="Hex color code for solid color mode (#RRGGBB)", - # Audio sound_quality="Audio sample rate (default: 44100)", audio_wave_type="Type of audio wave (sawtooth, sine, square)", @@ -373,133 +628,148 @@ class WebdriverTorsoCog(commands.Cog): tts_enabled="Enable text-to-speech (default: false)", tts_provider="TTS provider to use (gtts, pyttsx3, coqui)", tts_text="Text to be spoken in the video", - # Text top_left_text_enabled="Show text in top-left corner (default: true)", top_left_text_mode="Mode for top-left text (random, word)", words_topic="Topic for word generation (random, introspective, action, nature, technology, etc.)", text_color="Color of text (hex code or name)", text_size="Size of text (default: auto-scaled)", - text_position="Position of text (top-left, top-right, bottom-left, bottom-right, center)" + text_position="Position of text (top-left, top-right, bottom-left, bottom-right, center)", ) - @app_commands.choices(deform_level=[ - app_commands.Choice(name="None", value="none"), - app_commands.Choice(name="Low", value="low"), - app_commands.Choice(name="Medium", value="medium"), - app_commands.Choice(name="High", value="high") - ]) - @app_commands.choices(color_mode=[ - app_commands.Choice(name="Random", value="random"), - app_commands.Choice(name="Color Scheme", value="scheme"), - app_commands.Choice(name="Solid Color", value="solid") - ]) - @app_commands.choices(color_scheme=[ - app_commands.Choice(name="Pastel", value="pastel"), - app_commands.Choice(name="Dark Gritty", value="dark_gritty"), - app_commands.Choice(name="Nature", value="nature"), - app_commands.Choice(name="Vibrant", value="vibrant"), - app_commands.Choice(name="Ocean", value="ocean"), - # Additional color schemes - app_commands.Choice(name="Neon", value="neon"), - app_commands.Choice(name="Monochrome", value="monochrome"), - app_commands.Choice(name="Autumn", value="autumn"), - app_commands.Choice(name="Cyberpunk", value="cyberpunk"), - app_commands.Choice(name="Retro", value="retro") - ]) - @app_commands.choices(audio_wave_type=[ - app_commands.Choice(name="Sawtooth", value="sawtooth"), - app_commands.Choice(name="Sine", value="sine"), - app_commands.Choice(name="Square", value="square"), - # Additional wave types - app_commands.Choice(name="Triangle", value="triangle"), - app_commands.Choice(name="Noise", value="noise"), - app_commands.Choice(name="Pulse", value="pulse"), - app_commands.Choice(name="Harmonic", value="harmonic") - ]) - @app_commands.choices(tts_provider=[ - app_commands.Choice(name="Google TTS", value="gtts"), - app_commands.Choice(name="pyttsx3 (Offline TTS)", value="pyttsx3"), - app_commands.Choice(name="Coqui TTS (AI Voice)", value="coqui") - ]) - @app_commands.choices(wave_vibe=[ - app_commands.Choice(name="Calm", value="calm"), - app_commands.Choice(name="Eerie", value="eerie"), - app_commands.Choice(name="Random", value="random"), - app_commands.Choice(name="Energetic", value="energetic"), - app_commands.Choice(name="Dreamy", value="dreamy"), - app_commands.Choice(name="Chaotic", value="chaotic"), - # Additional wave vibes - app_commands.Choice(name="Glitchy", value="glitchy"), - app_commands.Choice(name="Underwater", value="underwater"), - app_commands.Choice(name="Mechanical", value="mechanical"), - app_commands.Choice(name="Ethereal", value="ethereal"), - app_commands.Choice(name="Pulsating", value="pulsating") - ]) - @app_commands.choices(top_left_text_mode=[ - app_commands.Choice(name="Random", value="random"), - app_commands.Choice(name="Word", value="word") - ]) - @app_commands.choices(words_topic=[ - app_commands.Choice(name="Random", value="random"), - app_commands.Choice(name="Introspective", value="introspective"), - app_commands.Choice(name="Action", value="action"), - app_commands.Choice(name="Nature", value="nature"), - app_commands.Choice(name="Technology", value="technology"), - # Additional word topics - app_commands.Choice(name="Space", value="space"), - app_commands.Choice(name="Ocean", value="ocean"), - app_commands.Choice(name="Fantasy", value="fantasy"), - app_commands.Choice(name="Science", value="science"), - app_commands.Choice(name="Art", value="art"), - app_commands.Choice(name="Music", value="music"), - app_commands.Choice(name="Food", value="food"), - app_commands.Choice(name="Emotions", value="emotions"), - app_commands.Choice(name="Colors", value="colors"), - app_commands.Choice(name="Abstract", value="abstract") - ]) - @app_commands.choices(text_position=[ - app_commands.Choice(name="Top Left", value="top-left"), - app_commands.Choice(name="Top Right", value="top-right"), - app_commands.Choice(name="Bottom Left", value="bottom-left"), - app_commands.Choice(name="Bottom Right", value="bottom-right"), - app_commands.Choice(name="Center", value="center"), - app_commands.Choice(name="Random", value="random") - ]) - async def webdrivertorso_slash(self, interaction: discord.Interaction, - # Video structure - slides: int = None, - videos: int = None, - slide_duration: int = None, - - # Video dimensions - dimensions: str = None, - shape_size_limits: str = None, - - # Shapes - shapes_count: str = None, - deform_level: str = None, - shape_types: str = None, - - # Colors - color_mode: str = None, - color_scheme: str = None, - solid_color: str = None, - - # Audio - sound_quality: int = None, - audio_wave_type: str = None, - wave_vibe: str = None, - tts_enabled: bool = None, - tts_provider: str = None, - tts_text: str = None, - - # Text - top_left_text_enabled: bool = None, - top_left_text_mode: str = None, - words_topic: str = None, - text_color: str = None, - text_size: int = None, - text_position: str = None): + @app_commands.choices( + deform_level=[ + app_commands.Choice(name="None", value="none"), + app_commands.Choice(name="Low", value="low"), + app_commands.Choice(name="Medium", value="medium"), + app_commands.Choice(name="High", value="high"), + ] + ) + @app_commands.choices( + color_mode=[ + app_commands.Choice(name="Random", value="random"), + app_commands.Choice(name="Color Scheme", value="scheme"), + app_commands.Choice(name="Solid Color", value="solid"), + ] + ) + @app_commands.choices( + color_scheme=[ + app_commands.Choice(name="Pastel", value="pastel"), + app_commands.Choice(name="Dark Gritty", value="dark_gritty"), + app_commands.Choice(name="Nature", value="nature"), + app_commands.Choice(name="Vibrant", value="vibrant"), + app_commands.Choice(name="Ocean", value="ocean"), + # Additional color schemes + app_commands.Choice(name="Neon", value="neon"), + app_commands.Choice(name="Monochrome", value="monochrome"), + app_commands.Choice(name="Autumn", value="autumn"), + app_commands.Choice(name="Cyberpunk", value="cyberpunk"), + app_commands.Choice(name="Retro", value="retro"), + ] + ) + @app_commands.choices( + audio_wave_type=[ + app_commands.Choice(name="Sawtooth", value="sawtooth"), + app_commands.Choice(name="Sine", value="sine"), + app_commands.Choice(name="Square", value="square"), + # Additional wave types + app_commands.Choice(name="Triangle", value="triangle"), + app_commands.Choice(name="Noise", value="noise"), + app_commands.Choice(name="Pulse", value="pulse"), + app_commands.Choice(name="Harmonic", value="harmonic"), + ] + ) + @app_commands.choices( + tts_provider=[ + app_commands.Choice(name="Google TTS", value="gtts"), + app_commands.Choice(name="pyttsx3 (Offline TTS)", value="pyttsx3"), + app_commands.Choice(name="Coqui TTS (AI Voice)", value="coqui"), + ] + ) + @app_commands.choices( + wave_vibe=[ + app_commands.Choice(name="Calm", value="calm"), + app_commands.Choice(name="Eerie", value="eerie"), + app_commands.Choice(name="Random", value="random"), + app_commands.Choice(name="Energetic", value="energetic"), + app_commands.Choice(name="Dreamy", value="dreamy"), + app_commands.Choice(name="Chaotic", value="chaotic"), + # Additional wave vibes + app_commands.Choice(name="Glitchy", value="glitchy"), + app_commands.Choice(name="Underwater", value="underwater"), + app_commands.Choice(name="Mechanical", value="mechanical"), + app_commands.Choice(name="Ethereal", value="ethereal"), + app_commands.Choice(name="Pulsating", value="pulsating"), + ] + ) + @app_commands.choices( + top_left_text_mode=[ + app_commands.Choice(name="Random", value="random"), + app_commands.Choice(name="Word", value="word"), + ] + ) + @app_commands.choices( + words_topic=[ + app_commands.Choice(name="Random", value="random"), + app_commands.Choice(name="Introspective", value="introspective"), + app_commands.Choice(name="Action", value="action"), + app_commands.Choice(name="Nature", value="nature"), + app_commands.Choice(name="Technology", value="technology"), + # Additional word topics + app_commands.Choice(name="Space", value="space"), + app_commands.Choice(name="Ocean", value="ocean"), + app_commands.Choice(name="Fantasy", value="fantasy"), + app_commands.Choice(name="Science", value="science"), + app_commands.Choice(name="Art", value="art"), + app_commands.Choice(name="Music", value="music"), + app_commands.Choice(name="Food", value="food"), + app_commands.Choice(name="Emotions", value="emotions"), + app_commands.Choice(name="Colors", value="colors"), + app_commands.Choice(name="Abstract", value="abstract"), + ] + ) + @app_commands.choices( + text_position=[ + app_commands.Choice(name="Top Left", value="top-left"), + app_commands.Choice(name="Top Right", value="top-right"), + app_commands.Choice(name="Bottom Left", value="bottom-left"), + app_commands.Choice(name="Bottom Right", value="bottom-right"), + app_commands.Choice(name="Center", value="center"), + app_commands.Choice(name="Random", value="random"), + ] + ) + async def webdrivertorso_slash( + self, + interaction: discord.Interaction, + # Video structure + slides: int = None, + videos: int = None, + slide_duration: int = None, + # Video dimensions + dimensions: str = None, + shape_size_limits: str = None, + # Shapes + shapes_count: str = None, + deform_level: str = None, + shape_types: str = None, + # Colors + color_mode: str = None, + color_scheme: str = None, + solid_color: str = None, + # Audio + sound_quality: int = None, + audio_wave_type: str = None, + wave_vibe: str = None, + tts_enabled: bool = None, + tts_provider: str = None, + tts_text: str = None, + # Text + top_left_text_enabled: bool = None, + top_left_text_mode: str = None, + words_topic: str = None, + text_color: str = None, + text_size: int = None, + text_position: str = None, + ): """Slash command version of webdrivertorso.""" await interaction.response.defer(thinking=True) result = await self._generate_video_logic( @@ -508,26 +778,44 @@ class WebdriverTorsoCog(commands.Cog): slides=slides, videos=videos, slide_duration=slide_duration, - # Video dimensions - width=int(dimensions.split(',')[0]) if dimensions else None, - height=int(dimensions.split(',')[1]) if dimensions and ',' in dimensions else None, - min_width=int(shape_size_limits.split(',')[0]) if shape_size_limits else None, - min_height=int(shape_size_limits.split(',')[1]) if shape_size_limits and len(shape_size_limits.split(',')) > 1 else None, - max_width=int(shape_size_limits.split(',')[2]) if shape_size_limits and len(shape_size_limits.split(',')) > 2 else None, - max_height=int(shape_size_limits.split(',')[3]) if shape_size_limits and len(shape_size_limits.split(',')) > 3 else None, - + width=int(dimensions.split(",")[0]) if dimensions else None, + height=( + int(dimensions.split(",")[1]) + if dimensions and "," in dimensions + else None + ), + min_width=( + int(shape_size_limits.split(",")[0]) if shape_size_limits else None + ), + min_height=( + int(shape_size_limits.split(",")[1]) + if shape_size_limits and len(shape_size_limits.split(",")) > 1 + else None + ), + max_width=( + int(shape_size_limits.split(",")[2]) + if shape_size_limits and len(shape_size_limits.split(",")) > 2 + else None + ), + max_height=( + int(shape_size_limits.split(",")[3]) + if shape_size_limits and len(shape_size_limits.split(",")) > 3 + else None + ), # Shapes - min_shapes=int(shapes_count.split(',')[0]) if shapes_count else None, - max_shapes=int(shapes_count.split(',')[1]) if shapes_count and ',' in shapes_count else None, + min_shapes=int(shapes_count.split(",")[0]) if shapes_count else None, + max_shapes=( + int(shapes_count.split(",")[1]) + if shapes_count and "," in shapes_count + else None + ), deform_level=deform_level, - allowed_shapes=shape_types.split(',') if shape_types else None, - + allowed_shapes=shape_types.split(",") if shape_types else None, # Colors color_mode=color_mode, color_scheme=color_scheme, solid_color=solid_color, - # Audio sound_quality=sound_quality, audio_wave_type=audio_wave_type, @@ -535,7 +823,6 @@ class WebdriverTorsoCog(commands.Cog): tts_enabled=tts_enabled, tts_provider=tts_provider, tts_text=tts_text, - # Text top_left_text_enabled=top_left_text_enabled, top_left_text_mode=top_left_text_mode, @@ -543,12 +830,12 @@ class WebdriverTorsoCog(commands.Cog): text_color=text_color, text_size=text_size, text_position=text_position, - - already_deferred=True + already_deferred=True, ) if isinstance(result, str): await interaction.followup.send(result) + async def setup(bot: commands.Bot): await bot.add_cog(WebdriverTorsoCog(bot)) diff --git a/cogs/welcome_cog.py b/cogs/welcome_cog.py index abef3ce..80679fb 100644 --- a/cogs/welcome_cog.py +++ b/cogs/welcome_cog.py @@ -11,6 +11,7 @@ from global_bot_accessor import get_bot_instance log = logging.getLogger(__name__) + class WelcomeCog(commands.Cog): """Handles welcome and goodbye messages for guilds.""" @@ -19,14 +20,18 @@ class WelcomeCog(commands.Cog): print("WelcomeCog: Initializing and registering event listeners") # Check existing event listeners - print(f"WelcomeCog: Bot event listeners before registration: {self.bot.extra_events}") + print( + f"WelcomeCog: Bot event listeners before registration: {self.bot.extra_events}" + ) # Register event listeners self.bot.add_listener(self.on_member_join, "on_member_join") self.bot.add_listener(self.on_member_remove, "on_member_remove") # Check if event listeners were registered - print(f"WelcomeCog: Bot event listeners after registration: {self.bot.extra_events}") + print( + f"WelcomeCog: Bot event listeners after registration: {self.bot.extra_events}" + ) print("WelcomeCog: Event listeners registered") async def on_member_join(self, member: discord.Member): @@ -38,13 +43,21 @@ class WelcomeCog(commands.Cog): return log.debug(f"Member {member.name} joined guild {guild.name} ({guild.id})") - print(f"WelcomeCog: Member {member.name} joined guild {guild.name} ({guild.id})") + print( + f"WelcomeCog: Member {member.name} joined guild {guild.name} ({guild.id})" + ) # --- Fetch settings --- print(f"WelcomeCog: Fetching welcome settings for guild {guild.id}") - welcome_channel_id_str = await settings_manager.get_setting(guild.id, 'welcome_channel_id') - welcome_message_template = await settings_manager.get_setting(guild.id, 'welcome_message', default="Welcome {user} to {server}!") - print(f"WelcomeCog: Retrieved settings - channel_id: {welcome_channel_id_str}, message: {welcome_message_template}") + welcome_channel_id_str = await settings_manager.get_setting( + guild.id, "welcome_channel_id" + ) + welcome_message_template = await settings_manager.get_setting( + guild.id, "welcome_message", default="Welcome {user} to {server}!" + ) + print( + f"WelcomeCog: Retrieved settings - channel_id: {welcome_channel_id_str}, message: {welcome_message_template}" + ) # Handle the "__NONE__" marker for potentially unset values if not welcome_channel_id_str or welcome_channel_id_str == "__NONE__": @@ -56,25 +69,29 @@ class WelcomeCog(commands.Cog): welcome_channel_id = int(welcome_channel_id_str) channel = guild.get_channel(welcome_channel_id) if not channel or not isinstance(channel, discord.TextChannel): - log.warning(f"Welcome channel ID {welcome_channel_id} not found or not text channel in guild {guild.id}") + log.warning( + f"Welcome channel ID {welcome_channel_id} not found or not text channel in guild {guild.id}" + ) # Maybe remove the setting here if the channel is invalid? return # --- Format and send message --- # Basic formatting, can be expanded formatted_message = welcome_message_template.format( - user=member.mention, - username=member.name, - server=guild.name + user=member.mention, username=member.name, server=guild.name ) await channel.send(formatted_message) log.info(f"Sent welcome message for {member.name} in guild {guild.id}") except ValueError: - log.error(f"Invalid welcome_channel_id '{welcome_channel_id_str}' configured for guild {guild.id}") + log.error( + f"Invalid welcome_channel_id '{welcome_channel_id_str}' configured for guild {guild.id}" + ) except discord.Forbidden: - log.error(f"Missing permissions to send welcome message in channel {welcome_channel_id} for guild {guild.id}") + log.error( + f"Missing permissions to send welcome message in channel {welcome_channel_id} for guild {guild.id}" + ) except Exception as e: log.exception(f"Error sending welcome message for guild {guild.id}: {e}") @@ -91,9 +108,15 @@ class WelcomeCog(commands.Cog): # --- Fetch settings --- print(f"WelcomeCog: Fetching goodbye settings for guild {guild.id}") - goodbye_channel_id_str = await settings_manager.get_setting(guild.id, 'goodbye_channel_id') - goodbye_message_template = await settings_manager.get_setting(guild.id, 'goodbye_message', default="{username} has left the server.") - print(f"WelcomeCog: Retrieved settings - channel_id: {goodbye_channel_id_str}, message: {goodbye_message_template}") + goodbye_channel_id_str = await settings_manager.get_setting( + guild.id, "goodbye_channel_id" + ) + goodbye_message_template = await settings_manager.get_setting( + guild.id, "goodbye_message", default="{username} has left the server." + ) + print( + f"WelcomeCog: Retrieved settings - channel_id: {goodbye_channel_id_str}, message: {goodbye_message_template}" + ) # Handle the "__NONE__" marker if not goodbye_channel_id_str or goodbye_channel_id_str == "__NONE__": @@ -105,104 +128,158 @@ class WelcomeCog(commands.Cog): goodbye_channel_id = int(goodbye_channel_id_str) channel = guild.get_channel(goodbye_channel_id) if not channel or not isinstance(channel, discord.TextChannel): - log.warning(f"Goodbye channel ID {goodbye_channel_id} not found or not text channel in guild {guild.id}") + log.warning( + f"Goodbye channel ID {goodbye_channel_id} not found or not text channel in guild {guild.id}" + ) return # --- Format and send message --- formatted_message = goodbye_message_template.format( - user=member.mention, # Might not be mentionable after leaving + user=member.mention, # Might not be mentionable after leaving username=member.name, - server=guild.name + server=guild.name, ) await channel.send(formatted_message) log.info(f"Sent goodbye message for {member.name} in guild {guild.id}") except ValueError: - log.error(f"Invalid goodbye_channel_id '{goodbye_channel_id_str}' configured for guild {guild.id}") + log.error( + f"Invalid goodbye_channel_id '{goodbye_channel_id_str}' configured for guild {guild.id}" + ) except discord.Forbidden: - log.error(f"Missing permissions to send goodbye message in channel {goodbye_channel_id} for guild {guild.id}") + log.error( + f"Missing permissions to send goodbye message in channel {goodbye_channel_id} for guild {guild.id}" + ) except Exception as e: log.exception(f"Error sending goodbye message for guild {guild.id}: {e}") - - @commands.command(name='setwelcome', help="Sets the welcome message and channel. Usage: `setwelcome #channel [message template]`") + @commands.command( + name="setwelcome", + help="Sets the welcome message and channel. Usage: `setwelcome #channel [message template]`", + ) @commands.has_permissions(administrator=True) @commands.guild_only() - async def set_welcome(self, ctx: commands.Context, channel: discord.TextChannel, *, message_template: str = "Welcome {user} to {server}!"): + async def set_welcome( + self, + ctx: commands.Context, + channel: discord.TextChannel, + *, + message_template: str = "Welcome {user} to {server}!", + ): """Sets the channel and template for welcome messages.""" guild_id = ctx.guild.id - key_channel = 'welcome_channel_id' - key_message = 'welcome_message' + key_channel = "welcome_channel_id" + key_message = "welcome_message" # Use settings_manager.set_setting - success_channel = await settings_manager.set_setting(guild_id, key_channel, str(channel.id)) - success_message = await settings_manager.set_setting(guild_id, key_message, message_template) + success_channel = await settings_manager.set_setting( + guild_id, key_channel, str(channel.id) + ) + success_message = await settings_manager.set_setting( + guild_id, key_message, message_template + ) - if success_channel and success_message: # Both need to succeed - await ctx.send(f"Welcome messages will now be sent to {channel.mention} with the template:\n```\n{message_template}\n```") - log.info(f"Welcome settings updated for guild {guild_id} by {ctx.author.name}") + if success_channel and success_message: # Both need to succeed + await ctx.send( + f"Welcome messages will now be sent to {channel.mention} with the template:\n```\n{message_template}\n```" + ) + log.info( + f"Welcome settings updated for guild {guild_id} by {ctx.author.name}" + ) else: await ctx.send("Failed to save welcome settings. Check logs.") log.error(f"Failed to save welcome settings for guild {guild_id}") - @commands.command(name='disablewelcome', help="Disables welcome messages for this server.") + @commands.command( + name="disablewelcome", help="Disables welcome messages for this server." + ) @commands.has_permissions(administrator=True) @commands.guild_only() async def disable_welcome(self, ctx: commands.Context): """Disables welcome messages by removing the channel setting.""" guild_id = ctx.guild.id - key_channel = 'welcome_channel_id' - key_message = 'welcome_message' # Also clear the message template + key_channel = "welcome_channel_id" + key_message = "welcome_message" # Also clear the message template # Use set_setting with None to delete the settings - success_channel = await settings_manager.set_setting(guild_id, key_channel, None) - success_message = await settings_manager.set_setting(guild_id, key_message, None) + success_channel = await settings_manager.set_setting( + guild_id, key_channel, None + ) + success_message = await settings_manager.set_setting( + guild_id, key_message, None + ) - if success_channel and success_message: # Both need to succeed + if success_channel and success_message: # Both need to succeed await ctx.send("Welcome messages have been disabled.") - log.info(f"Welcome messages disabled for guild {guild_id} by {ctx.author.name}") + log.info( + f"Welcome messages disabled for guild {guild_id} by {ctx.author.name}" + ) else: await ctx.send("Failed to disable welcome messages. Check logs.") log.error(f"Failed to disable welcome settings for guild {guild_id}") - - @commands.command(name='setgoodbye', help="Sets the goodbye message and channel. Usage: `setgoodbye #channel [message template]`") + @commands.command( + name="setgoodbye", + help="Sets the goodbye message and channel. Usage: `setgoodbye #channel [message template]`", + ) @commands.has_permissions(administrator=True) @commands.guild_only() - async def set_goodbye(self, ctx: commands.Context, channel: discord.TextChannel, *, message_template: str = "{username} has left the server."): + async def set_goodbye( + self, + ctx: commands.Context, + channel: discord.TextChannel, + *, + message_template: str = "{username} has left the server.", + ): """Sets the channel and template for goodbye messages.""" guild_id = ctx.guild.id - key_channel = 'goodbye_channel_id' - key_message = 'goodbye_message' + key_channel = "goodbye_channel_id" + key_message = "goodbye_message" # Use settings_manager.set_setting - success_channel = await settings_manager.set_setting(guild_id, key_channel, str(channel.id)) - success_message = await settings_manager.set_setting(guild_id, key_message, message_template) + success_channel = await settings_manager.set_setting( + guild_id, key_channel, str(channel.id) + ) + success_message = await settings_manager.set_setting( + guild_id, key_message, message_template + ) - if success_channel and success_message: # Both need to succeed - await ctx.send(f"Goodbye messages will now be sent to {channel.mention} with the template:\n```\n{message_template}\n```") - log.info(f"Goodbye settings updated for guild {guild_id} by {ctx.author.name}") + if success_channel and success_message: # Both need to succeed + await ctx.send( + f"Goodbye messages will now be sent to {channel.mention} with the template:\n```\n{message_template}\n```" + ) + log.info( + f"Goodbye settings updated for guild {guild_id} by {ctx.author.name}" + ) else: await ctx.send("Failed to save goodbye settings. Check logs.") log.error(f"Failed to save goodbye settings for guild {guild_id}") - @commands.command(name='disablegoodbye', help="Disables goodbye messages for this server.") + @commands.command( + name="disablegoodbye", help="Disables goodbye messages for this server." + ) @commands.has_permissions(administrator=True) @commands.guild_only() async def disable_goodbye(self, ctx: commands.Context): """Disables goodbye messages by removing the channel setting.""" guild_id = ctx.guild.id - key_channel = 'goodbye_channel_id' - key_message = 'goodbye_message' + key_channel = "goodbye_channel_id" + key_message = "goodbye_message" # Use set_setting with None to delete the settings - success_channel = await settings_manager.set_setting(guild_id, key_channel, None) - success_message = await settings_manager.set_setting(guild_id, key_message, None) + success_channel = await settings_manager.set_setting( + guild_id, key_channel, None + ) + success_message = await settings_manager.set_setting( + guild_id, key_message, None + ) - if success_channel and success_message: # Both need to succeed + if success_channel and success_message: # Both need to succeed await ctx.send("Goodbye messages have been disabled.") - log.info(f"Goodbye messages disabled for guild {guild_id} by {ctx.author.name}") + log.info( + f"Goodbye messages disabled for guild {guild_id} by {ctx.author.name}" + ) else: await ctx.send("Failed to disable goodbye messages. Check logs.") log.error(f"Failed to disable goodbye settings for guild {guild_id}") @@ -216,25 +293,40 @@ class WelcomeCog(commands.Cog): if isinstance(error, commands.MissingPermissions): await ctx.send("You need Administrator permissions to use this command.") elif isinstance(error, commands.BadArgument): - await ctx.send(f"Invalid argument provided. Check the command help: `{ctx.prefix}help {ctx.command.name}`") + await ctx.send( + f"Invalid argument provided. Check the command help: `{ctx.prefix}help {ctx.command.name}`" + ) elif isinstance(error, commands.MissingRequiredArgument): - await ctx.send(f"Missing required argument. Check the command help: `{ctx.prefix}help {ctx.command.name}`") + await ctx.send( + f"Missing required argument. Check the command help: `{ctx.prefix}help {ctx.command.name}`" + ) elif isinstance(error, commands.NoPrivateMessage): await ctx.send("This command cannot be used in private messages.") else: - log.error(f"Unhandled error in WelcomeCog command '{ctx.command.name}': {error}") + log.error( + f"Unhandled error in WelcomeCog command '{ctx.command.name}': {error}" + ) await ctx.send("An unexpected error occurred. Please check the logs.") async def setup(bot: commands.Bot): # Ensure bot has pools initialized before adding the cog print("WelcomeCog setup function called!") - if not hasattr(bot, 'pg_pool') or not hasattr(bot, 'redis') or bot.pg_pool is None or bot.redis is None: - log.warning("Bot pools not initialized before loading WelcomeCog. Cog will not load.") + if ( + not hasattr(bot, "pg_pool") + or not hasattr(bot, "redis") + or bot.pg_pool is None + or bot.redis is None + ): + log.warning( + "Bot pools not initialized before loading WelcomeCog. Cog will not load." + ) print("WelcomeCog: Bot pools not initialized. Cannot load cog.") - return # Prevent loading if pools are missing + return # Prevent loading if pools are missing welcome_cog = WelcomeCog(bot) await bot.add_cog(welcome_cog) - print(f"WelcomeCog loaded! Event listeners registered: on_member_join, on_member_remove") + print( + f"WelcomeCog loaded! Event listeners registered: on_member_join, on_member_remove" + ) log.info("WelcomeCog loaded.") diff --git a/command_customization.py b/command_customization.py index 9a60ef5..623fcc4 100644 --- a/command_customization.py +++ b/command_customization.py @@ -2,6 +2,7 @@ Command customization utilities for Discord bot. Handles guild-specific command names and groups. """ + import discord from discord import app_commands import logging @@ -11,11 +12,13 @@ import settings_manager log = logging.getLogger(__name__) + class GuildCommandTransformer(app_commands.Transformer): """ A transformer that customizes command names based on guild settings. This is used to transform command names when they are displayed to users. """ + async def transform(self, interaction: discord.Interaction, value: str) -> str: """Transform the command name based on guild settings.""" if not interaction.guild: @@ -30,6 +33,7 @@ class GuildCommandSyncer: """ Handles syncing commands with guild-specific customizations. """ + def __init__(self, bot): self.bot = bot self._command_cache = {} # Cache of original commands @@ -40,8 +44,12 @@ class GuildCommandSyncer: Load command customizations for a specific guild. Returns a dictionary mapping original command names to custom names. """ - cmd_customizations = await settings_manager.get_all_command_customizations(guild_id) - group_customizations = await settings_manager.get_all_group_customizations(guild_id) + cmd_customizations = await settings_manager.get_all_command_customizations( + guild_id + ) + group_customizations = await settings_manager.get_all_group_customizations( + guild_id + ) if cmd_customizations is None or group_customizations is None: log.error(f"Failed to load command customizations for guild {guild_id}") @@ -49,7 +57,9 @@ class GuildCommandSyncer: # Combine command and group customizations customizations = {**cmd_customizations, **group_customizations} - log.info(f"Loaded {len(customizations)} command customizations for guild {guild_id}") + log.info( + f"Loaded {len(customizations)} command customizations for guild {guild_id}" + ) return customizations async def prepare_guild_commands(self, guild_id: int) -> List: @@ -73,12 +83,12 @@ class GuildCommandSyncer: guild_commands = [] for cmd in global_commands: # Set guild_id attribute for use in customization methods - setattr(cmd, 'guild_id', guild_id) + setattr(cmd, "guild_id", guild_id) if cmd.name in customizations: # Get the custom name custom_data = customizations[cmd.name] - custom_name = custom_data.get('name', cmd.name) + custom_name = custom_data.get("name", cmd.name) # Handle Command and Group objects differently if isinstance(cmd, app_commands.Command): @@ -99,38 +109,45 @@ class GuildCommandSyncer: # Store customized commands for this guild self._customized_commands[guild_id] = { - cmd.name: custom_cmd for cmd, custom_cmd in zip(global_commands, guild_commands) + cmd.name: custom_cmd + for cmd, custom_cmd in zip(global_commands, guild_commands) if cmd.name in customizations } return guild_commands - async def _create_custom_command(self, original_cmd: app_commands.Command, custom_name: str) -> app_commands.Command: + async def _create_custom_command( + self, original_cmd: app_commands.Command, custom_name: str + ) -> app_commands.Command: """ Create a copy of a command with a custom name and description. This is a simplified version - in practice, you'd need to handle all command attributes. """ # Get custom description if available custom_description = None - if hasattr(original_cmd, 'guild_id') and original_cmd.guild_id: + if hasattr(original_cmd, "guild_id") and original_cmd.guild_id: # This is a guild-specific command, get the custom description - custom_description = await settings_manager.get_custom_command_description(original_cmd.guild_id, original_cmd.name) + custom_description = await settings_manager.get_custom_command_description( + original_cmd.guild_id, original_cmd.name + ) # For simplicity, we're just creating a basic copy with the custom name and description # In a real implementation, you'd need to handle all command attributes and options custom_cmd = app_commands.Command( name=custom_name, description=custom_description or original_cmd.description, - callback=original_cmd.callback + callback=original_cmd.callback, ) # Copy options, if any - if hasattr(original_cmd, 'options'): + if hasattr(original_cmd, "options"): custom_cmd._params = original_cmd._params.copy() return custom_cmd - async def _create_custom_group(self, original_group: app_commands.Group, custom_name: str) -> app_commands.Group: + async def _create_custom_group( + self, original_group: app_commands.Group, custom_name: str + ) -> app_commands.Group: """ Create a copy of a group with a custom name. Groups don't have callbacks like commands, so we handle them differently. @@ -138,8 +155,7 @@ class GuildCommandSyncer: """ # Create a new group with the custom name (keeping original description) custom_group = app_commands.Group( - name=custom_name, - description=original_group.description + name=custom_name, description=original_group.description ) # Copy all subcommands from the original group @@ -183,6 +199,7 @@ def guild_command(name: str, description: str, **kwargs): async def my_command(interaction: discord.Interaction): ... """ + def decorator(func: Callable[[discord.Interaction], Awaitable[Any]]): # Create the app command @app_commands.command(name=name, description=description, **kwargs) @@ -209,13 +226,16 @@ class GuildCommandGroup(app_commands.Group): async def my_subcommand(interaction: discord.Interaction): ... """ + def __init__(self, name: str, description: str, **kwargs): super().__init__(name=name, description=description, **kwargs) self.__original_name__ = name async def get_guild_name(self, guild_id: int) -> str: """Get the guild-specific name for this group.""" - custom_name = await settings_manager.get_custom_group_name(guild_id, self.__original_name__) + custom_name = await settings_manager.get_custom_group_name( + guild_id, self.__original_name__ + ) return custom_name if custom_name else self.__original_name__ diff --git a/commands.py b/commands.py index b182419..87b195d 100644 --- a/commands.py +++ b/commands.py @@ -4,6 +4,7 @@ import discord from discord.ext import commands from typing import List, Optional + async def load_all_cogs(bot: commands.Bot, skip_cogs: Optional[List[str]] = None): """Loads all cogs from the 'cogs' directory, optionally skipping specified ones.""" if skip_cogs is None: @@ -17,14 +18,16 @@ async def load_all_cogs(bot: commands.Bot, skip_cogs: Optional[List[str]] = None print(f"Skipping cogs: {skip_cogs}") for filename in os.listdir(cogs_dir): - if filename.endswith(".py") and \ - not filename.startswith("__") and \ - not filename.startswith("gurt") and \ - not filename.startswith("profile_updater") and \ - not filename.startswith("neru") and \ - not filename.endswith("_base_cog.py") and \ - not filename.startswith("femdom") and \ - not filename == "VoiceGatewayCog.py": + if ( + filename.endswith(".py") + and not filename.startswith("__") + and not filename.startswith("gurt") + and not filename.startswith("profile_updater") + and not filename.startswith("neru") + and not filename.endswith("_base_cog.py") + and not filename.startswith("femdom") + and not filename == "VoiceGatewayCog.py" + ): # Special check for welcome_cog.py if filename == "welcome_cog.py": print(f"Found welcome_cog.py, attempting to load it...") @@ -48,7 +51,7 @@ async def load_all_cogs(bot: commands.Bot, skip_cogs: Optional[List[str]] = None failed_cogs.append(cog_name) except commands.ExtensionFailed as e: print(f"Error: Cog {cog_name} failed to load.") - print(f" Reason: {e.original}") # Print the original exception + print(f" Reason: {e.original}") # Print the original exception failed_cogs.append(cog_name) except Exception as e: print(f"An unexpected error occurred loading cog {cog_name}: {e}") @@ -61,6 +64,7 @@ async def load_all_cogs(bot: commands.Bot, skip_cogs: Optional[List[str]] = None print(f"Failed to load {len(failed_cogs)} cogs: {', '.join(failed_cogs)}") print("-" * 20) + # You might want a similar function for unloading or reloading async def unload_all_cogs(bot: commands.Bot): """Unloads all currently loaded cogs from the 'cogs' directory.""" @@ -79,6 +83,7 @@ async def unload_all_cogs(bot: commands.Bot): failed_unload.append(extension) return unloaded_cogs, failed_unload + async def reload_all_cogs(bot: commands.Bot, skip_cogs: Optional[List[str]] = None): """Reloads all currently loaded cogs from the 'cogs' directory, optionally skipping specified ones.""" if skip_cogs is None: @@ -87,7 +92,7 @@ async def reload_all_cogs(bot: commands.Bot, skip_cogs: Optional[List[str]] = No failed_reload = [] loaded_extensions = list(bot.extensions.keys()) for extension in loaded_extensions: - if extension.startswith("cogs."): + if extension.startswith("cogs."): if extension in skip_cogs: print(f"Skipping reload for AI cog: {extension}") # Ensure it's unloaded if it happened to be loaded before @@ -96,21 +101,25 @@ async def reload_all_cogs(bot: commands.Bot, skip_cogs: Optional[List[str]] = No await bot.unload_extension(extension) print(f"Unloaded skipped AI cog: {extension}") except Exception as unload_e: - print(f"Failed to unload skipped AI cog {extension}: {unload_e}") + print( + f"Failed to unload skipped AI cog {extension}: {unload_e}" + ) continue try: await bot.reload_extension(extension) print(f"Successfully reloaded cog: {extension}") reloaded_cogs.append(extension) except commands.ExtensionNotLoaded: - print(f"Cog {extension} was not loaded, attempting to load instead.") - try: - await bot.load_extension(extension) - print(f"Successfully loaded cog: {extension}") - reloaded_cogs.append(extension) # Count as reloaded for simplicity - except Exception as load_e: - print(f"Failed to load cog {extension} during reload attempt: {load_e}") - failed_reload.append(extension) + print(f"Cog {extension} was not loaded, attempting to load instead.") + try: + await bot.load_extension(extension) + print(f"Successfully loaded cog: {extension}") + reloaded_cogs.append(extension) # Count as reloaded for simplicity + except Exception as load_e: + print( + f"Failed to load cog {extension} during reload attempt: {load_e}" + ) + failed_reload.append(extension) except Exception as e: print(f"Failed to reload cog {extension}: {e}") # Attempt to unload if reload fails badly? Maybe too complex here. diff --git a/custom_bot_manager.py b/custom_bot_manager.py index ed48843..1c43566 100644 --- a/custom_bot_manager.py +++ b/custom_bot_manager.py @@ -15,7 +15,9 @@ import traceback from typing import Dict, Optional, Tuple, List # Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s:%(levelname)s:%(name)s: %(message)s') +logging.basicConfig( + level=logging.INFO, format="%(asctime)s:%(levelname)s:%(name)s: %(message)s" +) log = logging.getLogger(__name__) # Global storage for custom bot instances and their threads @@ -35,16 +37,17 @@ DEFAULT_COGS = [ "cogs.settings_cog", "cogs.utility_cog", "cogs.fun_cog", - "cogs.moderation_cog" + "cogs.moderation_cog", ] + class CustomBot(commands.Bot): """Custom bot class with additional functionality for user-specific bots.""" def __init__(self, user_id: str, *args, **kwargs): super().__init__(*args, **kwargs) self.user_id = user_id - self.owner_id = int(os.getenv('OWNER_USER_ID', '0')) + self.owner_id = int(os.getenv("OWNER_USER_ID", "0")) self._cleanup_tasks = [] # Track cleanup tasks async def setup_hook(self): @@ -57,7 +60,9 @@ class CustomBot(commands.Bot): await self.load_extension(cog) log.info(f"Loaded extension {cog} for custom bot {self.user_id}") except Exception as e: - log.error(f"Failed to load extension {cog} for custom bot {self.user_id}: {e}") + log.error( + f"Failed to load extension {cog} for custom bot {self.user_id}: {e}" + ) traceback.print_exc() async def close(self): @@ -67,11 +72,15 @@ class CustomBot(commands.Bot): # Close all cogs that have aiohttp sessions for cog_name, cog in self.cogs.items(): try: - if hasattr(cog, 'session') and cog.session and not cog.session.closed: + if hasattr(cog, "session") and cog.session and not cog.session.closed: await cog.session.close() - log.info(f"Closed aiohttp session for cog {cog_name} in custom bot {self.user_id}") + log.info( + f"Closed aiohttp session for cog {cog_name} in custom bot {self.user_id}" + ) except Exception as e: - log.error(f"Error closing session for cog {cog_name} in custom bot {self.user_id}: {e}") + log.error( + f"Error closing session for cog {cog_name} in custom bot {self.user_id}: {e}" + ) # Wait a bit for sessions to close properly await asyncio.sleep(0.1) @@ -80,12 +89,13 @@ class CustomBot(commands.Bot): await super().close() log.info(f"Custom bot for user {self.user_id} closed successfully") + async def create_custom_bot( user_id: str, token: str, prefix: str = "!", status_type: str = "listening", - status_text: str = "!help" + status_text: str = "!help", ) -> Tuple[bool, str]: """ Create a new custom bot instance for a user. @@ -111,24 +121,21 @@ async def create_custom_bot( intents.members = True # Create bot instance - bot = CustomBot( - user_id=user_id, - command_prefix=prefix, - intents=intents - ) + bot = CustomBot(user_id=user_id, command_prefix=prefix, intents=intents) # Set up events @bot.event async def on_ready(): - log.info(f"Custom bot {bot.user.name} (ID: {bot.user.id}) for user {user_id} is ready!") + log.info( + f"Custom bot {bot.user.name} (ID: {bot.user.id}) for user {user_id} is ready!" + ) # Set the bot's status - activity_type = getattr(discord.ActivityType, status_type, discord.ActivityType.listening) + activity_type = getattr( + discord.ActivityType, status_type, discord.ActivityType.listening + ) await bot.change_presence( - activity=discord.Activity( - type=activity_type, - name=status_text - ) + activity=discord.Activity(type=activity_type, name=status_text) ) # Update status @@ -138,7 +145,9 @@ async def create_custom_bot( @bot.event async def on_error(event, *args, **kwargs): - log.error(f"Error in custom bot for user {user_id} in event {event}: {sys.exc_info()[1]}") + log.error( + f"Error in custom bot for user {user_id} in event {event}: {sys.exc_info()[1]}" + ) custom_bot_errors[user_id] = str(sys.exc_info()[1]) # Store the bot instance @@ -153,6 +162,7 @@ async def create_custom_bot( custom_bot_errors[user_id] = str(e) return False, f"Error creating custom bot: {e}" + def run_custom_bot_in_thread(user_id: str, token: str) -> Tuple[bool, str]: """ Run a custom bot in a separate thread. @@ -185,7 +195,9 @@ def run_custom_bot_in_thread(user_id: str, token: str) -> Tuple[bool, str]: except discord.errors.LoginFailure: log.error(f"Invalid token for custom bot (user {user_id})") custom_bot_status[user_id] = STATUS_ERROR - custom_bot_errors[user_id] = "Invalid Discord bot token. Please check your token and try again." + custom_bot_errors[user_id] = ( + "Invalid Discord bot token. Please check your token and try again." + ) except Exception as e: log.error(f"Error running custom bot for user {user_id}: {e}") custom_bot_status[user_id] = STATUS_ERROR @@ -196,7 +208,9 @@ def run_custom_bot_in_thread(user_id: str, token: str) -> Tuple[bool, str]: try: await bot.close() except Exception as e: - log.error(f"Error closing bot during cleanup for user {user_id}: {e}") + log.error( + f"Error closing bot during cleanup for user {user_id}: {e}" + ) # Run the bot loop.run_until_complete(_run_bot()) @@ -214,9 +228,7 @@ def run_custom_bot_in_thread(user_id: str, token: str) -> Tuple[bool, str]: # Create and start the thread thread = threading.Thread( - target=_run_bot_thread, - daemon=True, - name=f"custom-bot-{user_id}" + target=_run_bot_thread, daemon=True, name=f"custom-bot-{user_id}" ) thread.start() @@ -225,6 +237,7 @@ def run_custom_bot_in_thread(user_id: str, token: str) -> Tuple[bool, str]: return True, f"Started custom bot for user {user_id}" + def stop_custom_bot(user_id: str) -> Tuple[bool, str]: """ Stop a running custom bot. @@ -274,13 +287,13 @@ def stop_custom_bot(user_id: str) -> Tuple[bool, str]: try: loop.close() except Exception as e: - log.error(f"Error closing event loop in close thread for user {user_id}: {e}") + log.error( + f"Error closing event loop in close thread for user {user_id}: {e}" + ) # Run the close operation in a new thread close_thread = threading.Thread( - target=_close_bot_thread, - daemon=True, - name=f"close-bot-{user_id}" + target=_close_bot_thread, daemon=True, name=f"close-bot-{user_id}" ) close_thread.start() @@ -293,6 +306,7 @@ def stop_custom_bot(user_id: str) -> Tuple[bool, str]: return True, f"Stopped custom bot for user {user_id}" + def get_custom_bot_status(user_id: str) -> Dict: """ Get the status of a custom bot. @@ -308,23 +322,19 @@ def get_custom_bot_status(user_id: str) -> Dict: "exists": False, "status": "not_created", "error": None, - "is_running": False + "is_running": False, } status = custom_bot_status.get(user_id, STATUS_STOPPED) error = custom_bot_errors.get(user_id) is_running = ( - user_id in custom_bot_threads and - custom_bot_threads[user_id].is_alive() and - status == STATUS_RUNNING + user_id in custom_bot_threads + and custom_bot_threads[user_id].is_alive() + and status == STATUS_RUNNING ) - return { - "exists": True, - "status": status, - "error": error, - "is_running": is_running - } + return {"exists": True, "status": status, "error": error, "is_running": is_running} + def get_all_custom_bot_statuses() -> Dict[str, Dict]: """ @@ -338,6 +348,7 @@ def get_all_custom_bot_statuses() -> Dict[str, Dict]: result[user_id] = get_custom_bot_status(user_id) return result + def list_custom_bots() -> List[Dict]: """ List all custom bot instances and their status. @@ -351,6 +362,7 @@ def list_custom_bots() -> List[Dict]: bots.append(bot_info) return bots + def cleanup_all_custom_bots() -> None: """ Clean up all custom bot instances. Should be called when the main bot shuts down. @@ -368,6 +380,7 @@ def cleanup_all_custom_bots() -> None: # Wait a bit for all bots to stop import time + time.sleep(2) # Force cleanup any remaining threads diff --git a/db/mod_log_db.py b/db/mod_log_db.py index f8c91b8..6078772 100644 --- a/db/mod_log_db.py +++ b/db/mod_log_db.py @@ -8,7 +8,10 @@ from typing import Optional, List, Tuple, Any, Callable log = logging.getLogger(__name__) -async def create_connection_with_retry(pool: asyncpg.Pool, max_retries: int = 3) -> Tuple[Optional[asyncpg.Connection], bool]: + +async def create_connection_with_retry( + pool: asyncpg.Pool, max_retries: int = 3 +) -> Tuple[Optional[asyncpg.Connection], bool]: """ Creates a database connection with retry logic. @@ -26,22 +29,30 @@ async def create_connection_with_retry(pool: asyncpg.Pool, max_retries: int = 3) try: connection = await pool.acquire() return connection, True - except (asyncpg.exceptions.ConnectionDoesNotExistError, - asyncpg.exceptions.InterfaceError) as e: + except ( + asyncpg.exceptions.ConnectionDoesNotExistError, + asyncpg.exceptions.InterfaceError, + ) as e: retry_count += 1 if retry_count < max_retries: - log.warning(f"Connection error when acquiring connection (attempt {retry_count}/{max_retries}): {e}") + log.warning( + f"Connection error when acquiring connection (attempt {retry_count}/{max_retries}): {e}" + ) await asyncio.sleep(retry_delay) retry_delay *= 2 # Exponential backoff else: - log.exception(f"Failed to acquire connection after {max_retries} attempts: {e}") + log.exception( + f"Failed to acquire connection after {max_retries} attempts: {e}" + ) return None, False except Exception as e: log.exception(f"Unexpected error acquiring connection: {e}") return None, False + # --- Cross-thread database operations --- + def run_in_bot_loop(bot_instance, coro_func): """ Runs a coroutine in the bot's event loop, even if called from a different thread. @@ -79,8 +90,16 @@ def run_in_bot_loop(bot_instance, coro_func): log.exception(f"Error running database operation in bot loop: {e}") return None -async def add_mod_log_safe(bot_instance, guild_id: int, moderator_id: int, target_user_id: int, - action_type: str, reason: Optional[str], duration_seconds: Optional[int] = None) -> Optional[int]: + +async def add_mod_log_safe( + bot_instance, + guild_id: int, + moderator_id: int, + target_user_id: int, + action_type: str, + reason: Optional[str], + duration_seconds: Optional[int] = None, +) -> Optional[int]: """ Thread-safe version of add_mod_log that ensures the operation runs in the bot's event loop. This should be used when calling from API routes or other threads. @@ -92,16 +111,25 @@ async def add_mod_log_safe(bot_instance, guild_id: int, moderator_id: int, targe if bot_instance is None: bot_instance = get_bot_instance() if bot_instance is None: - log.error("Cannot add mod log safely: bot_instance is None and global accessor returned None") + log.error( + "Cannot add mod log safely: bot_instance is None and global accessor returned None" + ) return None # Define the coroutine to run in the bot's event loop async def _add_mod_log_coro(): - if not hasattr(bot_instance, 'pg_pool') or bot_instance.pg_pool is None: + if not hasattr(bot_instance, "pg_pool") or bot_instance.pg_pool is None: log.error("Bot pg_pool is None, cannot add mod log") return None - return await add_mod_log(bot_instance.pg_pool, guild_id, moderator_id, target_user_id, - action_type, reason, duration_seconds) + return await add_mod_log( + bot_instance.pg_pool, + guild_id, + moderator_id, + target_user_id, + action_type, + reason, + duration_seconds, + ) # If we're already in the bot's event loop, just run the coroutine directly if asyncio.get_running_loop() == bot_instance.loop: @@ -110,7 +138,10 @@ async def add_mod_log_safe(bot_instance, guild_id: int, moderator_id: int, targe # Otherwise, use the helper function to run in the bot's loop return run_in_bot_loop(bot_instance, _add_mod_log_coro) -async def update_mod_log_message_details_safe(bot_instance, case_id: int, message_id: int, channel_id: int) -> bool: + +async def update_mod_log_message_details_safe( + bot_instance, case_id: int, message_id: int, channel_id: int +) -> bool: """ Thread-safe version of update_mod_log_message_details that ensures the operation runs in the bot's event loop. This should be used when calling from API routes or other threads. @@ -122,15 +153,19 @@ async def update_mod_log_message_details_safe(bot_instance, case_id: int, messag if bot_instance is None: bot_instance = get_bot_instance() if bot_instance is None: - log.error("Cannot update mod log message details safely: bot_instance is None and global accessor returned None") + log.error( + "Cannot update mod log message details safely: bot_instance is None and global accessor returned None" + ) return False # Define the coroutine to run in the bot's event loop async def _update_details_coro(): - if not hasattr(bot_instance, 'pg_pool') or bot_instance.pg_pool is None: + if not hasattr(bot_instance, "pg_pool") or bot_instance.pg_pool is None: log.error("Bot pg_pool is None, cannot update mod log message details") return False - return await update_mod_log_message_details(bot_instance.pg_pool, case_id, message_id, channel_id) + return await update_mod_log_message_details( + bot_instance.pg_pool, case_id, message_id, channel_id + ) # If we're already in the bot's event loop, just run the coroutine directly if asyncio.get_running_loop() == bot_instance.loop: @@ -139,6 +174,7 @@ async def update_mod_log_message_details_safe(bot_instance, case_id: int, messag # Otherwise, use the helper function to run in the bot's loop return run_in_bot_loop(bot_instance, _update_details_coro) + async def setup_moderation_log_table(pool: asyncpg.Pool): """ Ensures the moderation_logs table and its indexes exist in the database. @@ -146,13 +182,16 @@ async def setup_moderation_log_table(pool: asyncpg.Pool): # Get a connection with retry logic connection, success = await create_connection_with_retry(pool) if not success or not connection: - log.error("Failed to acquire database connection for setting up moderation_logs table") + log.error( + "Failed to acquire database connection for setting up moderation_logs table" + ) raise RuntimeError("Failed to acquire database connection for table setup") try: # Use a transaction to ensure all schema changes are atomic async with connection.transaction(): - await connection.execute(""" + await connection.execute( + """ CREATE TABLE IF NOT EXISTS moderation_logs ( case_id SERIAL PRIMARY KEY, guild_id BIGINT NOT NULL, @@ -165,22 +204,29 @@ async def setup_moderation_log_table(pool: asyncpg.Pool): log_message_id BIGINT NULL, log_channel_id BIGINT NULL ); - """) + """ + ) # Create indexes if they don't exist - await connection.execute(""" + await connection.execute( + """ CREATE INDEX IF NOT EXISTS idx_moderation_logs_guild_id ON moderation_logs (guild_id); - """) - await connection.execute(""" + """ + ) + await connection.execute( + """ CREATE INDEX IF NOT EXISTS idx_moderation_logs_target_user_id ON moderation_logs (target_user_id); - """) - await connection.execute(""" + """ + ) + await connection.execute( + """ CREATE INDEX IF NOT EXISTS idx_moderation_logs_moderator_id ON moderation_logs (moderator_id); - """) + """ + ) log.info("Successfully ensured moderation_logs table and indexes exist.") except Exception as e: log.exception(f"Error setting up moderation_logs table: {e}") - raise # Re-raise the exception to indicate setup failure + raise # Re-raise the exception to indicate setup failure finally: # Always release the connection back to the pool try: @@ -188,9 +234,19 @@ async def setup_moderation_log_table(pool: asyncpg.Pool): except Exception as e: log.warning(f"Error releasing connection back to pool: {e}") + # --- Placeholder functions (to be implemented next) --- -async def add_mod_log(pool: asyncpg.Pool, guild_id: int, moderator_id: int, target_user_id: int, action_type: str, reason: Optional[str], duration_seconds: Optional[int] = None) -> Optional[int]: + +async def add_mod_log( + pool: asyncpg.Pool, + guild_id: int, + moderator_id: int, + target_user_id: int, + action_type: str, + reason: Optional[str], + duration_seconds: Optional[int] = None, +) -> Optional[int]: """Adds a new moderation log entry and returns the case_id.""" query = """ INSERT INTO moderation_logs (guild_id, moderator_id, target_user_id, action_type, reason, duration_seconds) @@ -200,22 +256,38 @@ async def add_mod_log(pool: asyncpg.Pool, guild_id: int, moderator_id: int, targ # Get a connection with retry logic connection, success = await create_connection_with_retry(pool) if not success or not connection: - log.error(f"Failed to acquire database connection for adding mod log entry for guild {guild_id}") + log.error( + f"Failed to acquire database connection for adding mod log entry for guild {guild_id}" + ) return None try: # Use a transaction to ensure atomicity async with connection.transaction(): - result = await connection.fetchrow(query, guild_id, moderator_id, target_user_id, action_type, reason, duration_seconds) + result = await connection.fetchrow( + query, + guild_id, + moderator_id, + target_user_id, + action_type, + reason, + duration_seconds, + ) if result: - log.info(f"Added mod log entry for guild {guild_id}, action {action_type}. Case ID: {result['case_id']}") - return result['case_id'] + log.info( + f"Added mod log entry for guild {guild_id}, action {action_type}. Case ID: {result['case_id']}" + ) + return result["case_id"] else: - log.error(f"Failed to add mod log entry for guild {guild_id}, action {action_type} - No case_id returned.") + log.error( + f"Failed to add mod log entry for guild {guild_id}, action {action_type} - No case_id returned." + ) return None except RuntimeError as e: if "got Future" in str(e) and "attached to a different loop" in str(e): - log.error(f"Event loop error adding mod log entry for guild {guild_id}: {e}") + log.error( + f"Event loop error adding mod log entry for guild {guild_id}: {e}" + ) # This is likely happening because the function is being called from a different thread/event loop # We'll need to handle this case differently in the future return None @@ -238,6 +310,7 @@ async def add_mod_log(pool: asyncpg.Pool, guild_id: int, moderator_id: int, targ log.warning(f"Error releasing connection back to pool: {e}") # Continue execution even if we can't release the connection + async def update_mod_log_reason(pool: asyncpg.Pool, case_id: int, new_reason: str): """Updates the reason for a specific moderation log entry.""" query = """ @@ -248,7 +321,9 @@ async def update_mod_log_reason(pool: asyncpg.Pool, case_id: int, new_reason: st # Get a connection with retry logic connection, success = await create_connection_with_retry(pool) if not success or not connection: - log.error(f"Failed to acquire database connection for updating reason for case_id {case_id}") + log.error( + f"Failed to acquire database connection for updating reason for case_id {case_id}" + ) return False try: @@ -259,7 +334,9 @@ async def update_mod_log_reason(pool: asyncpg.Pool, case_id: int, new_reason: st log.info(f"Updated reason for case_id {case_id}") return True else: - log.warning(f"Could not update reason for case_id {case_id}. Case might not exist or no change made.") + log.warning( + f"Could not update reason for case_id {case_id}. Case might not exist or no change made." + ) return False except Exception as e: log.exception(f"Error updating mod log reason for case_id {case_id}: {e}") @@ -272,7 +349,10 @@ async def update_mod_log_reason(pool: asyncpg.Pool, case_id: int, new_reason: st log.warning(f"Error releasing connection back to pool: {e}") # Continue execution even if we can't release the connection -async def update_mod_log_message_details(pool: asyncpg.Pool, case_id: int, message_id: int, channel_id: int): + +async def update_mod_log_message_details( + pool: asyncpg.Pool, case_id: int, message_id: int, channel_id: int +): """Updates the log_message_id and log_channel_id for a specific case.""" query = """ UPDATE moderation_logs @@ -282,7 +362,9 @@ async def update_mod_log_message_details(pool: asyncpg.Pool, case_id: int, messa # Get a connection with retry logic connection, success = await create_connection_with_retry(pool) if not success or not connection: - log.error(f"Failed to acquire database connection for updating message details for case_id {case_id}") + log.error( + f"Failed to acquire database connection for updating message details for case_id {case_id}" + ) return False try: @@ -293,10 +375,14 @@ async def update_mod_log_message_details(pool: asyncpg.Pool, case_id: int, messa log.info(f"Updated message details for case_id {case_id}") return True else: - log.warning(f"Could not update message details for case_id {case_id}. Case might not exist or no change made.") + log.warning( + f"Could not update message details for case_id {case_id}. Case might not exist or no change made." + ) return False except Exception as e: - log.exception(f"Error updating mod log message details for case_id {case_id}: {e}") + log.exception( + f"Error updating mod log message details for case_id {case_id}: {e}" + ) return False finally: # Always release the connection back to the pool @@ -306,6 +392,7 @@ async def update_mod_log_message_details(pool: asyncpg.Pool, case_id: int, messa log.warning(f"Error releasing connection back to pool: {e}") # Continue execution even if we can't release the connection + async def get_mod_log(pool: asyncpg.Pool, case_id: int) -> Optional[asyncpg.Record]: """Retrieves a specific moderation log entry by case_id.""" query = "SELECT * FROM moderation_logs WHERE case_id = $1;" @@ -313,7 +400,9 @@ async def get_mod_log(pool: asyncpg.Pool, case_id: int) -> Optional[asyncpg.Reco # Get a connection with retry logic connection, success = await create_connection_with_retry(pool) if not success or not connection: - log.error(f"Failed to acquire database connection for retrieving mod log for case_id {case_id}") + log.error( + f"Failed to acquire database connection for retrieving mod log for case_id {case_id}" + ) return None try: @@ -329,7 +418,10 @@ async def get_mod_log(pool: asyncpg.Pool, case_id: int) -> Optional[asyncpg.Reco except Exception as e: log.warning(f"Error releasing connection back to pool: {e}") -async def get_user_mod_logs(pool: asyncpg.Pool, guild_id: int, target_user_id: int, limit: int = 50) -> List[asyncpg.Record]: + +async def get_user_mod_logs( + pool: asyncpg.Pool, guild_id: int, target_user_id: int, limit: int = 50 +) -> List[asyncpg.Record]: """Retrieves moderation logs for a specific user in a guild, ordered by timestamp descending.""" query = """ SELECT * FROM moderation_logs @@ -340,14 +432,18 @@ async def get_user_mod_logs(pool: asyncpg.Pool, guild_id: int, target_user_id: i # Get a connection with retry logic connection, success = await create_connection_with_retry(pool) if not success or not connection: - log.error(f"Failed to acquire database connection for retrieving user mod logs for user {target_user_id} in guild {guild_id}") + log.error( + f"Failed to acquire database connection for retrieving user mod logs for user {target_user_id} in guild {guild_id}" + ) return [] try: records = await connection.fetch(query, guild_id, target_user_id, limit) return records except Exception as e: - log.exception(f"Error retrieving user mod logs for user {target_user_id} in guild {guild_id}: {e}") + log.exception( + f"Error retrieving user mod logs for user {target_user_id} in guild {guild_id}: {e}" + ) return [] finally: # Always release the connection back to the pool @@ -356,7 +452,10 @@ async def get_user_mod_logs(pool: asyncpg.Pool, guild_id: int, target_user_id: i except Exception as e: log.warning(f"Error releasing connection back to pool: {e}") -async def get_guild_mod_logs(pool: asyncpg.Pool, guild_id: int, limit: int = 50) -> List[asyncpg.Record]: + +async def get_guild_mod_logs( + pool: asyncpg.Pool, guild_id: int, limit: int = 50 +) -> List[asyncpg.Record]: """Retrieves the latest moderation logs for a guild, ordered by timestamp descending.""" query = """ SELECT * FROM moderation_logs @@ -367,7 +466,9 @@ async def get_guild_mod_logs(pool: asyncpg.Pool, guild_id: int, limit: int = 50) # Get a connection with retry logic connection, success = await create_connection_with_retry(pool) if not success or not connection: - log.error(f"Failed to acquire database connection for retrieving guild mod logs for guild {guild_id}") + log.error( + f"Failed to acquire database connection for retrieving guild mod logs for guild {guild_id}" + ) return [] try: @@ -383,8 +484,16 @@ async def get_guild_mod_logs(pool: asyncpg.Pool, guild_id: int, limit: int = 50) except Exception as e: log.warning(f"Error releasing connection back to pool: {e}") -async def log_action_safe(bot_instance, guild_id: int, target_user_id: int, action_type: str, - reason: str, ai_details: dict, source: str = "AI_API") -> Optional[int]: + +async def log_action_safe( + bot_instance, + guild_id: int, + target_user_id: int, + action_type: str, + reason: str, + ai_details: dict, + source: str = "AI_API", +) -> Optional[int]: """ Thread-safe version of ModLogCog.log_action that ensures the operation runs in the bot's event loop. This should be used when calling from API routes or other threads. @@ -408,7 +517,9 @@ async def log_action_safe(bot_instance, guild_id: int, target_user_id: int, acti if bot_instance is None: bot_instance = get_bot_instance() if bot_instance is None: - log.error("Cannot log action safely: bot_instance is None and global accessor returned None") + log.error( + "Cannot log action safely: bot_instance is None and global accessor returned None" + ) return None # Define the coroutine to run in the bot's event loop @@ -421,7 +532,7 @@ async def log_action_safe(bot_instance, guild_id: int, target_user_id: int, acti return None # Get the ModLogCog instance - mod_log_cog = bot_instance.get_cog('ModLogCog') + mod_log_cog = bot_instance.get_cog("ModLogCog") if not mod_log_cog: log.error("ModLogCog not found") return None @@ -441,12 +552,14 @@ async def log_action_safe(bot_instance, guild_id: int, target_user_id: int, acti duration=None, source=source, ai_details=ai_details, - moderator_id_override=AI_MODERATOR_ID + moderator_id_override=AI_MODERATOR_ID, ) # Get the case_id from the most recent log entry for this user - recent_logs = await get_user_mod_logs(bot_instance.pg_pool, guild_id, target_user_id, limit=1) - case_id = recent_logs[0]['case_id'] if recent_logs else None + recent_logs = await get_user_mod_logs( + bot_instance.pg_pool, guild_id, target_user_id, limit=1 + ) + case_id = recent_logs[0]["case_id"] if recent_logs else None return case_id except Exception as e: @@ -460,6 +573,7 @@ async def log_action_safe(bot_instance, guild_id: int, target_user_id: int, acti # Otherwise, use the helper function to run in the bot's loop return run_in_bot_loop(bot_instance, _log_action_coro) + async def delete_mod_log(pool: asyncpg.Pool, case_id: int, guild_id: int) -> bool: """Deletes a specific moderation log entry by case_id, ensuring it belongs to the guild.""" query = """ @@ -468,28 +582,41 @@ async def delete_mod_log(pool: asyncpg.Pool, case_id: int, guild_id: int) -> boo """ connection, success = await create_connection_with_retry(pool) if not success or not connection: - log.error(f"Failed to acquire database connection for deleting mod log for case_id {case_id} in guild {guild_id}") + log.error( + f"Failed to acquire database connection for deleting mod log for case_id {case_id} in guild {guild_id}" + ) return False try: async with connection.transaction(): result = await connection.execute(query, case_id, guild_id) if result == "DELETE 1": - log.info(f"Deleted mod log entry for case_id {case_id} in guild {guild_id}") + log.info( + f"Deleted mod log entry for case_id {case_id} in guild {guild_id}" + ) return True else: - log.warning(f"Could not delete mod log entry for case_id {case_id} in guild {guild_id}. Case might not exist or not belong to this guild.") + log.warning( + f"Could not delete mod log entry for case_id {case_id} in guild {guild_id}. Case might not exist or not belong to this guild." + ) return False except Exception as e: - log.exception(f"Error deleting mod log entry for case_id {case_id} in guild {guild_id}: {e}") + log.exception( + f"Error deleting mod log entry for case_id {case_id} in guild {guild_id}: {e}" + ) return False finally: try: await pool.release(connection) except Exception as e: - log.warning(f"Error releasing connection back to pool after deleting mod log: {e}") + log.warning( + f"Error releasing connection back to pool after deleting mod log: {e}" + ) -async def clear_user_mod_logs(pool: asyncpg.Pool, guild_id: int, target_user_id: int) -> int: + +async def clear_user_mod_logs( + pool: asyncpg.Pool, guild_id: int, target_user_id: int +) -> int: """Deletes all moderation log entries for a specific user in a guild. Returns the number of deleted logs.""" query = """ DELETE FROM moderation_logs @@ -497,7 +624,9 @@ async def clear_user_mod_logs(pool: asyncpg.Pool, guild_id: int, target_user_id: """ connection, success = await create_connection_with_retry(pool) if not success or not connection: - log.error(f"Failed to acquire database connection for clearing mod logs for user {target_user_id} in guild {guild_id}") + log.error( + f"Failed to acquire database connection for clearing mod logs for user {target_user_id} in guild {guild_id}" + ) return 0 try: @@ -510,18 +639,28 @@ async def clear_user_mod_logs(pool: asyncpg.Pool, guild_id: int, target_user_id: try: deleted_count = int(result_status.split(" ")[1]) except (IndexError, ValueError) as e: - log.warning(f"Could not parse deleted count from status: {result_status} - {e}") - + log.warning( + f"Could not parse deleted count from status: {result_status} - {e}" + ) + if deleted_count > 0: - log.info(f"Cleared {deleted_count} mod log entries for user {target_user_id} in guild {guild_id}") + log.info( + f"Cleared {deleted_count} mod log entries for user {target_user_id} in guild {guild_id}" + ) else: - log.info(f"No mod log entries found to clear for user {target_user_id} in guild {guild_id}") + log.info( + f"No mod log entries found to clear for user {target_user_id} in guild {guild_id}" + ) return deleted_count except Exception as e: - log.exception(f"Error clearing mod log entries for user {target_user_id} in guild {guild_id}: {e}") + log.exception( + f"Error clearing mod log entries for user {target_user_id} in guild {guild_id}: {e}" + ) return 0 finally: try: await pool.release(connection) except Exception as e: - log.warning(f"Error releasing connection back to pool after clearing user mod logs: {e}") + log.warning( + f"Error releasing connection back to pool after clearing user mod logs: {e}" + ) diff --git a/discord_bot_sync_api.py b/discord_bot_sync_api.py index 5470846..5347fc1 100644 --- a/discord_bot_sync_api.py +++ b/discord_bot_sync_api.py @@ -5,14 +5,14 @@ import datetime from typing import Dict, List, Optional, Any, Union from fastapi import FastAPI, HTTPException, Depends, Header, Request, Response from fastapi.middleware.cors import CORSMiddleware -from fastapi.staticfiles import StaticFiles # Added for static files -from fastapi.responses import FileResponse # Added for serving HTML +from fastapi.staticfiles import StaticFiles # Added for static files +from fastapi.responses import FileResponse # Added for serving HTML from pydantic import BaseModel, Field import discord from discord.ext import commands import aiohttp import threading -from typing import Optional # Added for GurtCog type hint +from typing import Optional # Added for GurtCog type hint # This file contains the API endpoints for syncing conversations between # the Flutter app and the Discord bot, AND the Gurt stats endpoint. @@ -21,20 +21,23 @@ from typing import Optional # Added for GurtCog type hint # These need to be set by the script that starts the bot and API server # Import GurtCog and ModLogCog conditionally to avoid dependency issues try: - from gurt.cog import GurtCog # Import GurtCog for type hint and access - from cogs.mod_log_cog import ModLogCog # Import ModLogCog for type hint + from gurt.cog import GurtCog # Import GurtCog for type hint and access + from cogs.mod_log_cog import ModLogCog # Import ModLogCog for type hint + gurt_cog_instance: Optional[GurtCog] = None - mod_log_cog_instance: Optional[ModLogCog] = None # Placeholder for ModLogCog + mod_log_cog_instance: Optional[ModLogCog] = None # Placeholder for ModLogCog except ImportError as e: print(f"Warning: Could not import GurtCog or ModLogCog: {e}") # Use Any type as fallback from typing import Any + gurt_cog_instance: Optional[Any] = None mod_log_cog_instance: Optional[Any] = None bot_instance = None # Will be set to the Discord bot instance # ============= Models ============= + class SyncedMessage(BaseModel): content: str role: str # "user", "assistant", or "system" @@ -42,6 +45,7 @@ class SyncedMessage(BaseModel): reasoning: Optional[str] = None usage_data: Optional[Dict[str, Any]] = None + class UserSettings(BaseModel): # General settings model_id: str = "openai/gpt-3.5-turbo" @@ -72,6 +76,7 @@ class UserSettings(BaseModel): last_updated: datetime.datetime = Field(default_factory=datetime.datetime.now) sync_source: str = "discord" # "discord" or "flutter" + class SyncedConversation(BaseModel): id: str title: str @@ -96,20 +101,24 @@ class SyncedConversation(BaseModel): character_breakdown: bool = False custom_instructions: Optional[str] = None + class SyncRequest(BaseModel): conversations: List[SyncedConversation] last_sync_time: Optional[datetime.datetime] = None user_settings: Optional[UserSettings] = None + class SettingsSyncRequest(BaseModel): user_settings: UserSettings + class SyncResponse(BaseModel): success: bool message: str conversations: List[SyncedConversation] = [] user_settings: Optional[UserSettings] = None + # ============= Storage ============= # Files to store synced data @@ -123,6 +132,7 @@ os.makedirs(os.path.dirname(SYNC_DATA_FILE), exist_ok=True) user_conversations: Dict[str, List[SyncedConversation]] = {} user_settings: Dict[str, UserSettings] = {} + # Load conversations from file def load_conversations(): global user_conversations @@ -131,13 +141,16 @@ def load_conversations(): with open(SYNC_DATA_FILE, "r", encoding="utf-8") as f: data = json.load(f) # Convert string keys (user IDs) back to strings - user_conversations = {k: [SyncedConversation.model_validate(conv) for conv in v] - for k, v in data.items()} + user_conversations = { + k: [SyncedConversation.model_validate(conv) for conv in v] + for k, v in data.items() + } print(f"Loaded synced conversations for {len(user_conversations)} users") except Exception as e: print(f"Error loading synced conversations: {e}") user_conversations = {} + # Save conversations to file def save_conversations(): try: @@ -151,6 +164,7 @@ def save_conversations(): except Exception as e: print(f"Error saving synced conversations: {e}") + # Load user settings from file def load_user_settings(): global user_settings @@ -159,12 +173,15 @@ def load_user_settings(): with open(USER_SETTINGS_FILE, "r", encoding="utf-8") as f: data = json.load(f) # Convert string keys (user IDs) back to strings - user_settings = {k: UserSettings.model_validate(v) for k, v in data.items()} + user_settings = { + k: UserSettings.model_validate(v) for k, v in data.items() + } print(f"Loaded synced settings for {len(user_settings)} users") except Exception as e: print(f"Error loading synced user settings: {e}") user_settings = {} + # Save user settings to file def save_all_user_settings(): try: @@ -178,8 +195,10 @@ def save_all_user_settings(): except Exception as e: print(f"Error saving synced user settings: {e}") + # ============= Discord OAuth Verification ============= + async def verify_discord_token(authorization: str = Header(None)) -> str: """Verify the Discord token and return the user ID""" if not authorization: @@ -193,13 +212,16 @@ async def verify_discord_token(authorization: str = Header(None)) -> str: # Verify the token with Discord async with aiohttp.ClientSession() as session: headers = {"Authorization": f"Bearer {token}"} - async with session.get("https://discord.com/api/v10/users/@me", headers=headers) as resp: + async with session.get( + "https://discord.com/api/v10/users/@me", headers=headers + ) as resp: if resp.status != 200: raise HTTPException(status_code=401, detail="Invalid Discord token") user_data = await resp.json() return user_data["id"] + # ============= API Setup ============= # API Configuration @@ -211,7 +233,9 @@ SSL_KEY_FILE = "/etc/letsencrypt/live/slipstreamm.dev/privkey.pem" app = FastAPI(title="Discord Bot Sync API") # Create a sub-application for the API -api_app = FastAPI(title="Discord Bot Sync API", docs_url="/docs", openapi_url="/openapi.json") +api_app = FastAPI( + title="Discord Bot Sync API", docs_url="/docs", openapi_url="/openapi.json" +) # Mount the API app at the base path app.mount(API_BASE_PATH, api_app) @@ -234,6 +258,7 @@ api_app.add_middleware( allow_headers=["*"], ) + # Initialize by loading saved data @app.on_event("startup") async def startup_event(): @@ -242,7 +267,11 @@ async def startup_event(): # Try to load local settings from AI cog and merge them with synced settings try: - from cogs.ai_cog import user_settings as local_user_settings, get_user_settings as get_local_settings + from cogs.ai_cog import ( + user_settings as local_user_settings, + get_user_settings as get_local_settings, + ) + print("Merging local AI cog settings with synced settings...") # Iterate through local settings and update synced settings @@ -260,10 +289,18 @@ async def startup_event(): synced_settings = user_settings[user_id_str] # Always update all settings from local settings - synced_settings.model_id = local_settings.get("model", synced_settings.model_id) - synced_settings.temperature = local_settings.get("temperature", synced_settings.temperature) - synced_settings.max_tokens = local_settings.get("max_tokens", synced_settings.max_tokens) - synced_settings.system_message = local_settings.get("system_prompt", synced_settings.system_message) + synced_settings.model_id = local_settings.get( + "model", synced_settings.model_id + ) + synced_settings.temperature = local_settings.get( + "temperature", synced_settings.temperature + ) + synced_settings.max_tokens = local_settings.get( + "max_tokens", synced_settings.max_tokens + ) + synced_settings.system_message = local_settings.get( + "system_prompt", synced_settings.system_message + ) # Handle character settings - explicitly check if they exist in local settings if "character" in local_settings: @@ -280,19 +317,29 @@ async def startup_event(): synced_settings.character_info = None # Always update character_breakdown - synced_settings.character_breakdown = local_settings.get("character_breakdown", False) + synced_settings.character_breakdown = local_settings.get( + "character_breakdown", False + ) # Handle custom_instructions - explicitly check if they exist in local settings if "custom_instructions" in local_settings: - synced_settings.custom_instructions = local_settings["custom_instructions"] + synced_settings.custom_instructions = local_settings[ + "custom_instructions" + ] else: # If not in local settings, set to None synced_settings.custom_instructions = None # Always update reasoning settings - synced_settings.reasoning_enabled = local_settings.get("show_reasoning", False) - synced_settings.reasoning_effort = local_settings.get("reasoning_effort", "medium") - synced_settings.web_search_enabled = local_settings.get("web_search_enabled", False) + synced_settings.reasoning_enabled = local_settings.get( + "show_reasoning", False + ) + synced_settings.reasoning_effort = local_settings.get( + "reasoning_effort", "medium" + ) + synced_settings.web_search_enabled = local_settings.get( + "web_search_enabled", False + ) # Update timestamp and sync source synced_settings.last_updated = datetime.datetime.now() @@ -304,21 +351,26 @@ async def startup_event(): except Exception as e: print(f"Error merging local settings with synced settings: {e}") + # ============= API Endpoints ============= + @app.get(API_BASE_PATH + "/") async def root(): return {"message": "Discord Bot Sync API is running"} + @api_app.get("/") async def api_root(): return {"message": "Discord Bot Sync API is running"} + @api_app.get("/auth") async def auth(code: str, state: str = None): """Handle OAuth callback""" return {"message": "Authentication successful", "code": code, "state": state} + @api_app.get("/conversations") async def get_conversations(user_id: str = Depends(verify_discord_token)): """Get all conversations for a user""" @@ -327,10 +379,10 @@ async def get_conversations(user_id: str = Depends(verify_discord_token)): return {"conversations": user_conversations[user_id]} + @api_app.post("/sync") async def sync_conversations( - sync_request: SyncRequest, - user_id: str = Depends(verify_discord_token) + sync_request: SyncRequest, user_id: str = Depends(verify_discord_token) ): """Sync conversations between the Flutter app and Discord bot""" # Get existing conversations for this user @@ -340,15 +392,20 @@ async def sync_conversations( updated_conversations = [] for incoming_conv in sync_request.conversations: # Check if this conversation already exists - existing_conv = next((conv for conv in existing_conversations - if conv.id == incoming_conv.id), None) + existing_conv = next( + (conv for conv in existing_conversations if conv.id == incoming_conv.id), + None, + ) if existing_conv: # If the incoming conversation is newer, update it if incoming_conv.updated_at > existing_conv.updated_at: # Replace the existing conversation - existing_conversations = [conv for conv in existing_conversations - if conv.id != incoming_conv.id] + existing_conversations = [ + conv + for conv in existing_conversations + if conv.id != incoming_conv.id + ] existing_conversations.append(incoming_conv) updated_conversations.append(incoming_conv) else: @@ -368,7 +425,10 @@ async def sync_conversations( # If we have existing settings, check which is newer if existing_settings: - if not existing_settings.last_updated or incoming_settings.last_updated > existing_settings.last_updated: + if ( + not existing_settings.last_updated + or incoming_settings.last_updated > existing_settings.last_updated + ): user_settings[user_id] = incoming_settings save_all_user_settings() user_settings_response = incoming_settings @@ -384,22 +444,25 @@ async def sync_conversations( success=True, message=f"Synced {len(updated_conversations)} conversations", conversations=existing_conversations, - user_settings=user_settings_response + user_settings=user_settings_response, ) + @api_app.delete("/conversations/{conversation_id}") async def delete_conversation( - conversation_id: str, - user_id: str = Depends(verify_discord_token) + conversation_id: str, user_id: str = Depends(verify_discord_token) ): """Delete a conversation""" if user_id not in user_conversations: - raise HTTPException(status_code=404, detail="No conversations found for this user") + raise HTTPException( + status_code=404, detail="No conversations found for this user" + ) # Filter out the conversation to delete original_count = len(user_conversations[user_id]) - user_conversations[user_id] = [conv for conv in user_conversations[user_id] - if conv.id != conversation_id] + user_conversations[user_id] = [ + conv for conv in user_conversations[user_id] if conv.id != conversation_id + ] # Check if any conversation was deleted if len(user_conversations[user_id]) == original_count: @@ -424,15 +487,19 @@ async def get_gurt_stats_api(): except Exception as e: print(f"Error retrieving Gurt stats via API: {e}") import traceback + traceback.print_exc() raise HTTPException(status_code=500, detail=f"Error retrieving Gurt stats: {e}") + # --- Gurt Dashboard Static Files --- # Mount static files directory (adjust path if needed, assuming dashboard files are in discordbot/gurt_dashboard) # Check if the directory exists before mounting dashboard_dir = "discordbot/gurt_dashboard" if os.path.exists(dashboard_dir) and os.path.isdir(dashboard_dir): - api_app.mount("/gurt/static", StaticFiles(directory=dashboard_dir), name="gurt_static") + api_app.mount( + "/gurt/static", StaticFiles(directory=dashboard_dir), name="gurt_static" + ) print(f"Mounted Gurt dashboard static files from: {dashboard_dir}") # Route for the main dashboard HTML @@ -442,9 +509,14 @@ if os.path.exists(dashboard_dir) and os.path.isdir(dashboard_dir): if os.path.exists(dashboard_html_path): return dashboard_html_path else: - raise HTTPException(status_code=404, detail="Dashboard index.html not found") + raise HTTPException( + status_code=404, detail="Dashboard index.html not found" + ) + else: - print(f"Warning: Gurt dashboard directory '{dashboard_dir}' not found. Dashboard endpoints will not be available.") + print( + f"Warning: Gurt dashboard directory '{dashboard_dir}' not found. Dashboard endpoints will not be available." + ) @api_app.get("/settings") @@ -452,7 +524,10 @@ async def get_user_settings(user_id: str = Depends(verify_discord_token)): """Get user settings""" # Import the AI cog's get_user_settings function to get local settings try: - from cogs.ai_cog import get_user_settings as get_local_settings, user_settings as local_user_settings + from cogs.ai_cog import ( + get_user_settings as get_local_settings, + user_settings as local_user_settings, + ) # Get local settings from the AI cog local_settings = get_local_settings(int(user_id)) @@ -472,9 +547,15 @@ async def get_user_settings(user_id: str = Depends(verify_discord_token)): # Always update all settings from local settings synced_settings.model_id = local_settings.get("model", synced_settings.model_id) - synced_settings.temperature = local_settings.get("temperature", synced_settings.temperature) - synced_settings.max_tokens = local_settings.get("max_tokens", synced_settings.max_tokens) - synced_settings.system_message = local_settings.get("system_prompt", synced_settings.system_message) + synced_settings.temperature = local_settings.get( + "temperature", synced_settings.temperature + ) + synced_settings.max_tokens = local_settings.get( + "max_tokens", synced_settings.max_tokens + ) + synced_settings.system_message = local_settings.get( + "system_prompt", synced_settings.system_message + ) # Handle character settings - explicitly check if they exist in local settings if "character" in local_settings: @@ -491,7 +572,9 @@ async def get_user_settings(user_id: str = Depends(verify_discord_token)): synced_settings.character_info = None # Always update character_breakdown - synced_settings.character_breakdown = local_settings.get("character_breakdown", False) + synced_settings.character_breakdown = local_settings.get( + "character_breakdown", False + ) # Handle custom_instructions - explicitly check if they exist in local settings if "custom_instructions" in local_settings: @@ -502,8 +585,12 @@ async def get_user_settings(user_id: str = Depends(verify_discord_token)): # Always update reasoning settings synced_settings.reasoning_enabled = local_settings.get("show_reasoning", False) - synced_settings.reasoning_effort = local_settings.get("reasoning_effort", "medium") - synced_settings.web_search_enabled = local_settings.get("web_search_enabled", False) + synced_settings.reasoning_effort = local_settings.get( + "reasoning_effort", "medium" + ) + synced_settings.web_search_enabled = local_settings.get( + "web_search_enabled", False + ) # Update timestamp and sync source synced_settings.last_updated = datetime.datetime.now() @@ -530,10 +617,10 @@ async def get_user_settings(user_id: str = Depends(verify_discord_token)): return {"settings": user_settings[user_id]} + @api_app.post("/settings") async def update_user_settings( - settings_request: SettingsSyncRequest, - user_id: str = Depends(verify_discord_token) + settings_request: SettingsSyncRequest, user_id: str = Depends(verify_discord_token) ): """Update user settings""" incoming_settings = settings_request.user_settings @@ -557,14 +644,23 @@ async def update_user_settings( # If we have existing settings, check which is newer if existing_settings: - if not existing_settings.last_updated or incoming_settings.last_updated > existing_settings.last_updated: + if ( + not existing_settings.last_updated + or incoming_settings.last_updated > existing_settings.last_updated + ): print(f"Updating settings for user {user_id} (incoming settings are newer)") user_settings[user_id] = incoming_settings save_all_user_settings() else: # Return existing settings if they're newer - print(f"Not updating settings for user {user_id} (existing settings are newer)") - return {"success": True, "message": "Existing settings are newer", "settings": existing_settings} + print( + f"Not updating settings for user {user_id} (existing settings are newer)" + ) + return { + "success": True, + "message": "Existing settings are newer", + "settings": existing_settings, + } else: # No existing settings, just save the incoming ones print(f"Creating new settings for user {user_id}") @@ -581,7 +677,10 @@ async def update_user_settings( # Update the local settings in the AI cog try: - from cogs.ai_cog import user_settings as local_user_settings, save_user_settings as save_local_user_settings + from cogs.ai_cog import ( + user_settings as local_user_settings, + save_user_settings as save_local_user_settings, + ) # Convert user_id to int for the AI cog int_user_id = int(user_id) @@ -595,7 +694,9 @@ async def update_user_settings( local_user_settings[int_user_id]["model"] = incoming_settings.model_id local_user_settings[int_user_id]["temperature"] = incoming_settings.temperature local_user_settings[int_user_id]["max_tokens"] = incoming_settings.max_tokens - local_user_settings[int_user_id]["system_prompt"] = incoming_settings.system_message + local_user_settings[int_user_id][ + "system_prompt" + ] = incoming_settings.system_message # Handle character settings - explicitly set to None if null in incoming settings if incoming_settings.character is None: @@ -613,10 +714,14 @@ async def update_user_settings( local_user_settings[int_user_id].pop("character_info") print(f"Removed character_info setting for user {user_id}") else: - local_user_settings[int_user_id]["character_info"] = incoming_settings.character_info + local_user_settings[int_user_id][ + "character_info" + ] = incoming_settings.character_info # Always update character_breakdown - local_user_settings[int_user_id]["character_breakdown"] = incoming_settings.character_breakdown + local_user_settings[int_user_id][ + "character_breakdown" + ] = incoming_settings.character_breakdown # Handle custom_instructions - explicitly set to None if null in incoming settings if incoming_settings.custom_instructions is None: @@ -625,31 +730,53 @@ async def update_user_settings( local_user_settings[int_user_id].pop("custom_instructions") print(f"Removed custom_instructions setting for user {user_id}") else: - local_user_settings[int_user_id]["custom_instructions"] = incoming_settings.custom_instructions + local_user_settings[int_user_id][ + "custom_instructions" + ] = incoming_settings.custom_instructions # Always update reasoning settings - local_user_settings[int_user_id]["show_reasoning"] = incoming_settings.reasoning_enabled - local_user_settings[int_user_id]["reasoning_effort"] = incoming_settings.reasoning_effort - local_user_settings[int_user_id]["web_search_enabled"] = incoming_settings.web_search_enabled + local_user_settings[int_user_id][ + "show_reasoning" + ] = incoming_settings.reasoning_enabled + local_user_settings[int_user_id][ + "reasoning_effort" + ] = incoming_settings.reasoning_effort + local_user_settings[int_user_id][ + "web_search_enabled" + ] = incoming_settings.web_search_enabled # Save the updated local settings save_local_user_settings() print(f"Updated local settings in AI cog for user {user_id}:") print(f"Character: {local_user_settings[int_user_id].get('character')}") - print(f"Character Info: {local_user_settings[int_user_id].get('character_info')}") - print(f"Character Breakdown: {local_user_settings[int_user_id].get('character_breakdown')}") - print(f"Custom Instructions: {local_user_settings[int_user_id].get('custom_instructions')}") + print( + f"Character Info: {local_user_settings[int_user_id].get('character_info')}" + ) + print( + f"Character Breakdown: {local_user_settings[int_user_id].get('character_breakdown')}" + ) + print( + f"Custom Instructions: {local_user_settings[int_user_id].get('custom_instructions')}" + ) except Exception as e: print(f"Error updating local settings in AI cog: {e}") - return {"success": True, "message": "Settings updated", "settings": user_settings[user_id]} + return { + "success": True, + "message": "Settings updated", + "settings": user_settings[user_id], + } + # ============= Discord Bot Integration ============= + # This function should be called from your Discord bot's AI cog # to convert AI conversation history to the synced format -def convert_ai_history_to_synced(user_id: str, conversation_history: Dict[int, List[Dict[str, Any]]]): +def convert_ai_history_to_synced( + user_id: str, conversation_history: Dict[int, List[Dict[str, Any]]] +): """Convert the AI conversation history to the synced format""" synced_conversations = [] @@ -668,38 +795,43 @@ def convert_ai_history_to_synced(user_id: str, conversation_history: Dict[int, L if role not in ["user", "assistant", "system"]: continue - synced_messages.append(SyncedMessage( - content=msg.get("content", ""), - role=role, - timestamp=datetime.datetime.now(), # Use current time as we don't have the original timestamp - reasoning=None, # Discord bot doesn't store reasoning - usage_data=None # Discord bot doesn't store usage data - )) + synced_messages.append( + SyncedMessage( + content=msg.get("content", ""), + role=role, + timestamp=datetime.datetime.now(), # Use current time as we don't have the original timestamp + reasoning=None, # Discord bot doesn't store reasoning + usage_data=None, # Discord bot doesn't store usage data + ) + ) # Create the synced conversation - synced_conversations.append(SyncedConversation( - id=conv_id, - title="Discord Conversation", # Default title - messages=synced_messages, - created_at=datetime.datetime.now(), - updated_at=datetime.datetime.now(), - model_id="openai/gpt-3.5-turbo", # Default model - sync_source="discord", - last_synced_at=datetime.datetime.now(), - reasoning_enabled=False, - reasoning_effort="medium", - temperature=0.7, - max_tokens=1000, - web_search_enabled=False, - system_message=None, - character=None, - character_info=None, - character_breakdown=False, - custom_instructions=None - )) + synced_conversations.append( + SyncedConversation( + id=conv_id, + title="Discord Conversation", # Default title + messages=synced_messages, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + model_id="openai/gpt-3.5-turbo", # Default model + sync_source="discord", + last_synced_at=datetime.datetime.now(), + reasoning_enabled=False, + reasoning_effort="medium", + temperature=0.7, + max_tokens=1000, + web_search_enabled=False, + system_message=None, + character=None, + character_info=None, + character_breakdown=False, + custom_instructions=None, + ) + ) return synced_conversations + # This function should be called from your Discord bot's AI cog # to save a new conversation from Discord def save_discord_conversation( @@ -717,7 +849,7 @@ def save_discord_conversation( character: Optional[str] = None, character_info: Optional[str] = None, character_breakdown: bool = False, - custom_instructions: Optional[str] = None + custom_instructions: Optional[str] = None, ): """Save a conversation from Discord to the synced storage""" # Convert messages to the synced format @@ -727,17 +859,21 @@ def save_discord_conversation( if role not in ["user", "assistant", "system"]: continue - synced_messages.append(SyncedMessage( - content=msg.get("content", ""), - role=role, - timestamp=datetime.datetime.now(), - reasoning=msg.get("reasoning"), - usage_data=msg.get("usage_data") - )) + synced_messages.append( + SyncedMessage( + content=msg.get("content", ""), + role=role, + timestamp=datetime.datetime.now(), + reasoning=msg.get("reasoning"), + usage_data=msg.get("usage_data"), + ) + ) # Create a unique ID for this conversation if not provided if not conversation_id: - conversation_id = f"discord_{user_id}_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}" + conversation_id = ( + f"discord_{user_id}_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}" + ) # Create the synced conversation synced_conv = SyncedConversation( @@ -758,7 +894,7 @@ def save_discord_conversation( character=character, character_info=character_info, character_breakdown=character_breakdown, - custom_instructions=custom_instructions + custom_instructions=custom_instructions, ) # Add to storage @@ -768,8 +904,9 @@ def save_discord_conversation( # Check if we're updating an existing conversation if conversation_id: # Remove the old conversation with the same ID if it exists - user_conversations[user_id] = [conv for conv in user_conversations[user_id] - if conv.id != conversation_id] + user_conversations[user_id] = [ + conv for conv in user_conversations[user_id] if conv.id != conversation_id + ] user_conversations[user_id].append(synced_conv) save_conversations() diff --git a/discord_oauth.py b/discord_oauth.py index 965aec0..8b2615a 100644 --- a/discord_oauth.py +++ b/discord_oauth.py @@ -28,7 +28,11 @@ CLIENT_ID = os.getenv("DISCORD_CLIENT_ID", "1360717457852993576") # Use the API service's OAuth endpoint if available, otherwise use the local server API_URL = os.getenv("API_URL", "https://slipstreamm.dev/api") -API_OAUTH_ENABLED = os.getenv("API_OAUTH_ENABLED", "true").lower() in ("true", "1", "yes") +API_OAUTH_ENABLED = os.getenv("API_OAUTH_ENABLED", "true").lower() in ( + "true", + "1", + "yes", +) # If API OAuth is enabled, use the API service's OAuth endpoint if API_OAUTH_ENABLED: @@ -41,7 +45,9 @@ else: # Otherwise, use the local OAuth server OAUTH_HOST = os.getenv("OAUTH_HOST", "localhost") OAUTH_PORT = int(os.getenv("OAUTH_PORT", "8080")) - REDIRECT_URI = os.getenv("DISCORD_REDIRECT_URI", f"http://{OAUTH_HOST}:{OAUTH_PORT}/oauth/callback") + REDIRECT_URI = os.getenv( + "DISCORD_REDIRECT_URI", f"http://{OAUTH_HOST}:{OAUTH_PORT}/oauth/callback" + ) # Discord API endpoints API_ENDPOINT = "https://discord.com/api/v10" @@ -60,23 +66,29 @@ code_verifiers: Dict[str, Any] = {} # This is used to pass the code verifier to the API service pending_code_verifiers: Dict[str, str] = {} + class OAuthError(Exception): """Exception raised for OAuth errors.""" + pass + def generate_code_verifier() -> str: """Generate a code verifier for PKCE.""" return secrets.token_urlsafe(64) + def generate_code_challenge(verifier: str) -> str: """Generate a code challenge from a code verifier.""" sha256 = hashlib.sha256(verifier.encode()).digest() return base64.urlsafe_b64encode(sha256).decode().rstrip("=") + def get_token_path(user_id: str) -> str: """Get the path to the token file for a user.""" return os.path.join(TOKEN_DIR, f"{user_id}.json") + def save_token(user_id: str, token_data: Dict[str, Any]) -> None: """Save a token to disk.""" # Add the time when the token was saved @@ -85,6 +97,7 @@ def save_token(user_id: str, token_data: Dict[str, Any]) -> None: with open(get_token_path(user_id), "w") as f: json.dump(token_data, f) + def load_token(user_id: str) -> Optional[Dict[str, Any]]: """Load a token from disk.""" token_path = get_token_path(user_id) @@ -97,6 +110,7 @@ def load_token(user_id: str) -> Optional[Dict[str, Any]]: except (json.JSONDecodeError, IOError): return None + def is_token_expired(token_data: Dict[str, Any]) -> bool: """Check if a token is expired.""" if not token_data: @@ -112,6 +126,7 @@ def is_token_expired(token_data: Dict[str, Any]) -> bool: # We consider it expired if it's within 5 minutes of expiration return (saved_at + expires_in - 300) < int(time.time()) + def delete_token(user_id: str) -> bool: """Delete a token from disk.""" token_path = get_token_path(user_id) @@ -120,6 +135,7 @@ def delete_token(user_id: str) -> bool: return True return False + async def send_code_verifier_to_api(state: str, code_verifier: str) -> bool: """Send the code verifier to the API service.""" try: @@ -128,10 +144,7 @@ async def send_code_verifier_to_api(state: str, code_verifier: str) -> bool: url = f"{API_URL}/code_verifier" # Prepare the data - data = { - "state": state, - "code_verifier": code_verifier - } + data = {"state": state, "code_verifier": code_verifier} # Send the code verifier to the API service print(f"Sending code verifier for state {state} to API service: {url}") @@ -143,22 +156,28 @@ async def send_code_verifier_to_api(state: str, code_verifier: str) -> bool: async with session.post(url, json=data) as resp: if resp.status == 200: response_data = await resp.json() - print(f"Successfully sent code verifier to API service: {response_data}") + print( + f"Successfully sent code verifier to API service: {response_data}" + ) return True else: error_text = await resp.text() - print(f"Failed to send code verifier to API service (attempt {retry+1}/{max_retries}): {error_text}") + print( + f"Failed to send code verifier to API service (attempt {retry+1}/{max_retries}): {error_text}" + ) if retry < max_retries - 1: # Wait before retrying, with exponential backoff - wait_time = 2 ** retry + wait_time = 2**retry print(f"Retrying in {wait_time} seconds...") await asyncio.sleep(wait_time) else: return False except aiohttp.ClientError as ce: - print(f"Connection error when sending code verifier (attempt {retry+1}/{max_retries}): {ce}") + print( + f"Connection error when sending code verifier (attempt {retry+1}/{max_retries}): {ce}" + ) if retry < max_retries - 1: - wait_time = 2 ** retry + wait_time = 2**retry print(f"Retrying in {wait_time} seconds...") await asyncio.sleep(wait_time) else: @@ -170,6 +189,7 @@ async def send_code_verifier_to_api(state: str, code_verifier: str) -> bool: traceback.print_exc() return False + def get_auth_url(state: str, code_verifier: str) -> str: """Get the authorization URL for the OAuth2 flow.""" code_challenge = generate_code_challenge(code_verifier) @@ -194,7 +214,7 @@ def get_auth_url(state: str, code_verifier: str) -> str: "state": state, "code_challenge": code_challenge, "code_challenge_method": "S256", - "prompt": "consent" + "prompt": "consent", } auth_url = f"{AUTH_URL}?{urlencode(params)}" @@ -202,7 +222,7 @@ def get_auth_url(state: str, code_verifier: str) -> str: # Store the code verifier and redirect URI for this state code_verifiers[state] = { "code_verifier": code_verifier, - "redirect_uri": actual_redirect_uri + "redirect_uri": actual_redirect_uri, } # Also store the code verifier in the global dictionary @@ -218,14 +238,18 @@ def get_auth_url(state: str, code_verifier: str) -> str: loop = asyncio.get_event_loop() send_success = False try: - send_success = loop.run_until_complete(send_code_verifier_to_api(state, code_verifier)) + send_success = loop.run_until_complete( + send_code_verifier_to_api(state, code_verifier) + ) except Exception as e: print(f"Error in synchronous code verifier send: {e}") # Fall back to async task if synchronous call fails asyncio.create_task(send_code_verifier_to_api(state, code_verifier)) if not send_success: - print("Warning: Failed to send code verifier synchronously, falling back to async task") + print( + "Warning: Failed to send code verifier synchronously, falling back to async task" + ) # Try again asynchronously as a backup asyncio.create_task(send_code_verifier_to_api(state, code_verifier)) else: @@ -233,6 +257,7 @@ def get_auth_url(state: str, code_verifier: str) -> str: return auth_url + async def exchange_code(code: str, state: str) -> Dict[str, Any]: """Exchange an authorization code for a token.""" # Get the code verifier and redirect URI for this state @@ -259,11 +284,7 @@ async def exchange_code(code: str, state: str) -> Dict[str, Any]: # We'll make a request to the API service with the code and code_verifier async with aiohttp.ClientSession() as session: # Construct the URL with the code and code_verifier - params = { - "code": code, - "state": state, - "code_verifier": code_verifier - } + params = {"code": code, "state": state, "code_verifier": code_verifier} auth_url = f"{API_URL}/auth?{urlencode(params)}" print(f"Redirecting to API service for token exchange: {auth_url}") @@ -273,7 +294,9 @@ async def exchange_code(code: str, state: str) -> Dict[str, Any]: if resp.status != 200: error_text = await resp.text() print(f"Failed to exchange code with API service: {error_text}") - raise OAuthError(f"Failed to exchange code with API service: {error_text}") + raise OAuthError( + f"Failed to exchange code with API service: {error_text}" + ) # The API service should return a success page, not the token # We'll need to get the token from the API service separately @@ -286,12 +309,16 @@ async def exchange_code(code: str, state: str) -> Dict[str, Any]: # Save the token data token_data = response_data["token"] save_token(response_data["user_id"], token_data) - print(f"Successfully saved token for user {response_data['user_id']}") + print( + f"Successfully saved token for user {response_data['user_id']}" + ) return token_data else: # If the response doesn't contain a token, it's probably an HTML response # We'll need to get the token from the API service separately - print("Response doesn't contain token data, will try to get it separately") + print( + "Response doesn't contain token data, will try to get it separately" + ) except Exception as e: print(f"Error parsing response: {e}") @@ -299,11 +326,15 @@ async def exchange_code(code: str, state: str) -> Dict[str, Any]: try: # Make a request to the API service to get the token headers = {"Accept": "application/json"} - async with session.get(f"{API_URL}/token", headers=headers) as token_resp: + async with session.get( + f"{API_URL}/token", headers=headers + ) as token_resp: if token_resp.status != 200: error_text = await token_resp.text() print(f"Failed to get token from API service: {error_text}") - raise OAuthError(f"Failed to get token from API service: {error_text}") + raise OAuthError( + f"Failed to get token from API service: {error_text}" + ) token_data = await token_resp.json() if "access_token" in token_data: @@ -313,7 +344,11 @@ async def exchange_code(code: str, state: str) -> Dict[str, Any]: except Exception as e: print(f"Error getting token from API service: {e}") # Return a placeholder token for now - return {"access_token": "placeholder_token", "token_type": "Bearer", "expires_in": 604800} + return { + "access_token": "placeholder_token", + "token_type": "Bearer", + "expires_in": 604800, + } # If we're handling the token exchange ourselves, proceed as before async with aiohttp.ClientSession() as session: @@ -323,7 +358,7 @@ async def exchange_code(code: str, state: str) -> Dict[str, Any]: "grant_type": "authorization_code", "code": code, "redirect_uri": redirect_uri, - "code_verifier": code_verifier + "code_verifier": code_verifier, } print(f"Exchanging code for token with data: {data}") @@ -336,6 +371,7 @@ async def exchange_code(code: str, state: str) -> Dict[str, Any]: return await resp.json() + async def refresh_token(refresh_token: str) -> Dict[str, Any]: """Refresh an access token.""" async with aiohttp.ClientSession() as session: @@ -343,7 +379,7 @@ async def refresh_token(refresh_token: str) -> Dict[str, Any]: data = { "client_id": CLIENT_ID, "grant_type": "refresh_token", - "refresh_token": refresh_token + "refresh_token": refresh_token, } print(f"Refreshing token with data: {data}") @@ -356,6 +392,7 @@ async def refresh_token(refresh_token: str) -> Dict[str, Any]: return await resp.json() + async def get_user_info(access_token: str) -> Dict[str, Any]: """Get information about the authenticated user.""" async with aiohttp.ClientSession() as session: @@ -368,6 +405,7 @@ async def get_user_info(access_token: str) -> Dict[str, Any]: return await resp.json() + async def get_token(user_id: str) -> Optional[str]: """Get a valid access token for a user.""" # Load the token from disk @@ -399,6 +437,7 @@ async def get_token(user_id: str) -> Optional[str]: # Return the access token return token_data.get("access_token") + async def validate_token(token: str) -> Tuple[bool, Optional[str]]: """Validate a token and return the user ID if valid.""" try: diff --git a/download_illustrious.py b/download_illustrious.py index 68b4177..4987419 100644 --- a/download_illustrious.py +++ b/download_illustrious.py @@ -20,22 +20,28 @@ MODEL_INFO_URL = f"https://civitai.com/api/v1/models/{MODEL_ID}" # Base SDXL model from HuggingFace (we'll use this as a base and replace the unet) SDXL_BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0" + def download_file(url, destination, filename=None): """Download a file with progress bar""" if filename is None: - local_filename = os.path.join(destination, url.split('/')[-1]) + local_filename = os.path.join(destination, url.split("/")[-1]) else: local_filename = os.path.join(destination, filename) with requests.get(url, stream=True) as r: r.raise_for_status() - total_size = int(r.headers.get('content-length', 0)) + total_size = int(r.headers.get("content-length", 0)) # Create directory if it doesn't exist os.makedirs(os.path.dirname(local_filename), exist_ok=True) - with open(local_filename, 'wb') as f: - with tqdm(total=total_size, unit='B', unit_scale=True, desc=f"Downloading {os.path.basename(local_filename)}") as pbar: + with open(local_filename, "wb") as f: + with tqdm( + total=total_size, + unit="B", + unit_scale=True, + desc=f"Downloading {os.path.basename(local_filename)}", + ) as pbar: for chunk in r.iter_content(chunk_size=8192): if chunk: f.write(chunk) @@ -43,6 +49,7 @@ def download_file(url, destination, filename=None): return local_filename + def download_from_huggingface(repo_id, local_dir, component=None): """Download a model from HuggingFace""" try: @@ -53,20 +60,19 @@ def download_from_huggingface(repo_id, local_dir, component=None): repo_id=repo_id, local_dir=local_dir, local_dir_use_symlinks=False, - allow_patterns=f"{component}/*" + allow_patterns=f"{component}/*", ) else: print(f"Downloading full model from {repo_id}...") huggingface_hub.snapshot_download( - repo_id=repo_id, - local_dir=local_dir, - local_dir_use_symlinks=False + repo_id=repo_id, local_dir=local_dir, local_dir_use_symlinks=False ) return True except Exception as e: print(f"Error downloading from HuggingFace: {e}") return False + def download_illustrious_xl(): """Download and set up the Illustrious XL model""" # Set up directories @@ -80,11 +86,18 @@ def download_illustrious_xl(): os.makedirs(temp_dir, exist_ok=True) # Check if model already exists - if os.path.exists(os.path.join(illustrious_dir, "unet", "diffusion_pytorch_model.safetensors")) and \ - os.path.getsize(os.path.join(illustrious_dir, "unet", "diffusion_pytorch_model.safetensors")) > 100000000: # Check if file is larger than 100MB + if ( + os.path.exists( + os.path.join(illustrious_dir, "unet", "diffusion_pytorch_model.safetensors") + ) + and os.path.getsize( + os.path.join(illustrious_dir, "unet", "diffusion_pytorch_model.safetensors") + ) + > 100000000 + ): # Check if file is larger than 100MB print(f"⚠️ {MODEL_NAME} model already exists at {illustrious_dir}") choice = input("Do you want to re-download and reinstall the model? (y/n): ") - if choice.lower() != 'y': + if choice.lower() != "y": print("Download cancelled.") return @@ -107,7 +120,7 @@ def download_illustrious_xl(): json.dump(model_info, f, indent=2) print(f"Model: {model_info['name']} by {model_info['creator']['username']}") - if 'description' in model_info: + if "description" in model_info: print(f"Description: {model_info['description'][:100]}...") except Exception as e: @@ -116,11 +129,20 @@ def download_illustrious_xl(): # First, download the base SDXL model from HuggingFace print(f"Step 1: Downloading base SDXL model from HuggingFace...") - print("This will download the VAE, text encoders, and tokenizers needed for the model.") + print( + "This will download the VAE, text encoders, and tokenizers needed for the model." + ) print("This may take a while (several GB of data)...") # Download each component separately to avoid downloading the full model - components = ["vae", "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2", "scheduler"] + components = [ + "vae", + "text_encoder", + "text_encoder_2", + "tokenizer", + "tokenizer_2", + "scheduler", + ] for component in components: success = download_from_huggingface(SDXL_BASE_MODEL, illustrious_dir, component) if not success: @@ -130,10 +152,13 @@ def download_illustrious_xl(): # Try using diffusers to download the model try: print(f"Installing diffusers if not already installed...") - subprocess.check_call([sys.executable, "-m", "pip", "install", "diffusers"]) + subprocess.check_call( + [sys.executable, "-m", "pip", "install", "diffusers"] + ) # Use Python to download the model components from diffusers import StableDiffusionXLPipeline + print(f"Downloading {component} using diffusers...") # Create a temporary directory for the download @@ -146,20 +171,26 @@ def download_illustrious_xl(): torch_dtype="float16", variant="fp16", use_safetensors=True, - cache_dir=temp_model_dir + cache_dir=temp_model_dir, ) # Copy the component to the illustrious directory component_dir = os.path.join(temp_model_dir, component) if os.path.exists(component_dir): - shutil.copytree(component_dir, os.path.join(illustrious_dir, component), dirs_exist_ok=True) + shutil.copytree( + component_dir, + os.path.join(illustrious_dir, component), + dirs_exist_ok=True, + ) print(f"Successfully copied {component} to {illustrious_dir}") else: print(f"Could not find {component} in downloaded model.") except Exception as e: print(f"Error using diffusers to download {component}: {e}") - print("You may need to manually download the SDXL base model and copy the components.") + print( + "You may need to manually download the SDXL base model and copy the components." + ) # Now download the Illustrious XL model from Civitai print(f"\nStep 2: Downloading {MODEL_NAME} from Civitai...") @@ -172,7 +203,12 @@ def download_illustrious_xl(): # Move the model file to the unet directory print(f"Moving {MODEL_NAME} model to the unet directory...") - shutil.move(model_file, os.path.join(illustrious_dir, "unet", "diffusion_pytorch_model.safetensors")) + shutil.move( + model_file, + os.path.join( + illustrious_dir, "unet", "diffusion_pytorch_model.safetensors" + ), + ) # Create a model_index.json file print("Creating model_index.json file...") @@ -193,36 +229,24 @@ def download_illustrious_xl(): "thresholding": False, "timestep_spacing": "leading", "trained_betas": None, - "use_karras_sigmas": True + "use_karras_sigmas": True, }, "text_encoder": [ - { - "_class_name": "CLIPTextModel", - "_diffusers_version": "0.21.4" - }, + {"_class_name": "CLIPTextModel", "_diffusers_version": "0.21.4"}, { "_class_name": "CLIPTextModelWithProjection", - "_diffusers_version": "0.21.4" - } + "_diffusers_version": "0.21.4", + }, ], "tokenizer": [ - { - "_class_name": "CLIPTokenizer", - "_diffusers_version": "0.21.4" - }, - { - "_class_name": "CLIPTokenizer", - "_diffusers_version": "0.21.4" - } + {"_class_name": "CLIPTokenizer", "_diffusers_version": "0.21.4"}, + {"_class_name": "CLIPTokenizer", "_diffusers_version": "0.21.4"}, ], "unet": { "_class_name": "UNet2DConditionModel", - "_diffusers_version": "0.21.4" + "_diffusers_version": "0.21.4", }, - "vae": { - "_class_name": "AutoencoderKL", - "_diffusers_version": "0.21.4" - } + "vae": {"_class_name": "AutoencoderKL", "_diffusers_version": "0.21.4"}, } with open(os.path.join(illustrious_dir, "model_index.json"), "w") as f: @@ -231,28 +255,41 @@ def download_illustrious_xl(): # Create a README.md file with information about the model with open(os.path.join(illustrious_dir, "README.md"), "w") as f: f.write(f"# {MODEL_NAME}\n\n") - f.write(f"Downloaded from Civitai: https://civitai.com/models/{MODEL_ID}\n\n") + f.write( + f"Downloaded from Civitai: https://civitai.com/models/{MODEL_ID}\n\n" + ) f.write("This model requires the diffusers library to use.\n") - f.write("Use the /generate command in the Discord bot to generate images with this model.\n") + f.write( + "Use the /generate command in the Discord bot to generate images with this model.\n" + ) # Check if the model file is large enough (should be several GB) - unet_file = os.path.join(illustrious_dir, "unet", "diffusion_pytorch_model.safetensors") + unet_file = os.path.join( + illustrious_dir, "unet", "diffusion_pytorch_model.safetensors" + ) if os.path.exists(unet_file): file_size_gb = os.path.getsize(unet_file) / (1024 * 1024 * 1024) print(f"Model file size: {file_size_gb:.2f} GB") if file_size_gb < 1.0: - print(f"⚠️ Warning: Model file seems too small ({file_size_gb:.2f} GB). It may not be complete.") - print("The download might have been interrupted or the model might not be the full version.") + print( + f"⚠️ Warning: Model file seems too small ({file_size_gb:.2f} GB). It may not be complete." + ) + print( + "The download might have been interrupted or the model might not be the full version." + ) print("You may want to try downloading again with the --force flag.") print(f"\n✅ {MODEL_NAME} model has been downloaded and set up successfully!") print(f"Model location: {illustrious_dir}") - print("You can now use the model with the /generate command in the Discord bot.") + print( + "You can now use the model with the /generate command in the Discord bot." + ) except Exception as e: print(f"❌ Error downloading or setting up the model: {e}") import traceback + traceback.print_exc() # Clean up @@ -268,9 +305,16 @@ def download_illustrious_xl(): return True + if __name__ == "__main__": - parser = argparse.ArgumentParser(description=f"Download and set up the {MODEL_NAME} model from Civitai") - parser.add_argument("--force", action="store_true", help="Force download even if the model already exists") + parser = argparse.ArgumentParser( + description=f"Download and set up the {MODEL_NAME} model from Civitai" + ) + parser.add_argument( + "--force", + action="store_true", + help="Force download even if the model already exists", + ) args = parser.parse_args() if args.force: diff --git a/error_handler.py b/error_handler.py index 78df6dc..b05d1c2 100644 --- a/error_handler.py +++ b/error_handler.py @@ -7,6 +7,7 @@ import datetime # Global function for storing interaction content store_interaction_content = None + # Utility functions to store message content before sending async def store_and_send(ctx_or_interaction, content, **kwargs): """Store the message content and then send it.""" @@ -21,26 +22,32 @@ async def store_and_send(ctx_or_interaction, content, **kwargs): else: return await ctx_or_interaction.followup.send(content, **kwargs) + async def store_and_reply(ctx, content, **kwargs): """Store the message content and then reply to the message.""" ctx._last_message_content = content return await ctx.reply(content, **kwargs) + def extract_message_content(ctx_or_interaction): """Extract message content from a Context or Interaction object.""" content = None # Check if this is an AI command error is_ai_command = False - if isinstance(ctx_or_interaction, commands.Context) and hasattr(ctx_or_interaction, 'command'): - is_ai_command = ctx_or_interaction.command and ctx_or_interaction.command.name == 'ai' - elif hasattr(ctx_or_interaction, 'command') and ctx_or_interaction.command: - is_ai_command = ctx_or_interaction.command.name == 'ai' + if isinstance(ctx_or_interaction, commands.Context) and hasattr( + ctx_or_interaction, "command" + ): + is_ai_command = ( + ctx_or_interaction.command and ctx_or_interaction.command.name == "ai" + ) + elif hasattr(ctx_or_interaction, "command") and ctx_or_interaction.command: + is_ai_command = ctx_or_interaction.command.name == "ai" # For AI commands, try to load from the ai_response.txt file if it exists - if is_ai_command and os.path.exists('ai_response.txt'): + if is_ai_command and os.path.exists("ai_response.txt"): try: - with open('ai_response.txt', 'r', encoding='utf-8') as f: + with open("ai_response.txt", "r", encoding="utf-8") as f: content = f.read() if content: return content @@ -54,10 +61,12 @@ def extract_message_content(ctx_or_interaction): from cogs.ai_cog import interaction_responses # Get the interaction ID - interaction_id = getattr(ctx_or_interaction, 'id', None) + interaction_id = getattr(ctx_or_interaction, "id", None) if interaction_id and interaction_id in interaction_responses: content = interaction_responses[interaction_id] - print(f"Retrieved content for interaction {interaction_id} from dictionary") + print( + f"Retrieved content for interaction {interaction_id} from dictionary" + ) if content: return content except Exception as e: @@ -65,32 +74,54 @@ def extract_message_content(ctx_or_interaction): if isinstance(ctx_or_interaction, commands.Context): # For Context objects - if hasattr(ctx_or_interaction, '_last_message_content'): + if hasattr(ctx_or_interaction, "_last_message_content"): content = ctx_or_interaction._last_message_content - elif hasattr(ctx_or_interaction, 'message') and hasattr(ctx_or_interaction.message, 'content'): + elif hasattr(ctx_or_interaction, "message") and hasattr( + ctx_or_interaction.message, "content" + ): content = ctx_or_interaction.message.content - elif hasattr(ctx_or_interaction, '_internal_response'): + elif hasattr(ctx_or_interaction, "_internal_response"): content = str(ctx_or_interaction._internal_response) # Try to extract from command invocation - elif hasattr(ctx_or_interaction, 'command') and hasattr(ctx_or_interaction, 'kwargs'): + elif hasattr(ctx_or_interaction, "command") and hasattr( + ctx_or_interaction, "kwargs" + ): # Reconstruct command invocation - cmd_name = ctx_or_interaction.command.name if hasattr(ctx_or_interaction.command, 'name') else 'unknown_command' - args_str = ' '.join([str(arg) for arg in ctx_or_interaction.args[1:]]) if hasattr(ctx_or_interaction, 'args') else '' - kwargs_str = ' '.join([f'{k}={v}' for k, v in ctx_or_interaction.kwargs.items()]) if ctx_or_interaction.kwargs else '' + cmd_name = ( + ctx_or_interaction.command.name + if hasattr(ctx_or_interaction.command, "name") + else "unknown_command" + ) + args_str = ( + " ".join([str(arg) for arg in ctx_or_interaction.args[1:]]) + if hasattr(ctx_or_interaction, "args") + else "" + ) + kwargs_str = ( + " ".join([f"{k}={v}" for k, v in ctx_or_interaction.kwargs.items()]) + if ctx_or_interaction.kwargs + else "" + ) content = f"Command: {cmd_name} {args_str} {kwargs_str}".strip() else: # For Interaction objects - if hasattr(ctx_or_interaction, '_last_response_content'): + if hasattr(ctx_or_interaction, "_last_response_content"): content = ctx_or_interaction._last_response_content - elif hasattr(ctx_or_interaction, '_internal_response'): + elif hasattr(ctx_or_interaction, "_internal_response"): content = str(ctx_or_interaction._internal_response) # Try to extract from interaction data - elif hasattr(ctx_or_interaction, 'data'): + elif hasattr(ctx_or_interaction, "data"): try: # Extract command name and options - cmd_name = ctx_or_interaction.data.get('name', 'unknown_command') - options = ctx_or_interaction.data.get('options', []) - options_str = ' '.join([f"{opt.get('name')}={opt.get('value')}" for opt in options]) if options else '' + cmd_name = ctx_or_interaction.data.get("name", "unknown_command") + options = ctx_or_interaction.data.get("options", []) + options_str = ( + " ".join( + [f"{opt.get('name')}={opt.get('value')}" for opt in options] + ) + if options + else "" + ) content = f"Slash Command: /{cmd_name} {options_str}".strip() except (AttributeError, KeyError): # If we can't extract structured data, try to get the raw data @@ -98,12 +129,15 @@ def extract_message_content(ctx_or_interaction): # For AI commands, add a note if we couldn't retrieve the full response if is_ai_command and (not content or len(content) < 100): - content = "The AI response was too long and could not be retrieved. " + \ - "This is likely due to a message that exceeded Discord's length limits. " + \ - "Please try again with a shorter prompt or request fewer details." + content = ( + "The AI response was too long and could not be retrieved. " + + "This is likely due to a message that exceeded Discord's length limits. " + + "Please try again with a shorter prompt or request fewer details." + ) return content + def log_error_details(ctx_or_interaction, error, content=None): """Log detailed error information to a file for debugging.""" timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") @@ -114,7 +148,9 @@ def log_error_details(ctx_or_interaction, error, content=None): os.makedirs(log_dir) # Create a unique filename based on timestamp - log_file = os.path.join(log_dir, f"error_{timestamp.replace(':', '-').replace(' ', '_')}.log") + log_file = os.path.join( + log_dir, f"error_{timestamp.replace(':', '-').replace(' ', '_')}.log" + ) with open(log_file, "w", encoding="utf-8") as f: f.write(f"=== Error Log: {timestamp} ===\n\n") @@ -124,7 +160,7 @@ def log_error_details(ctx_or_interaction, error, content=None): f.write(f"Error Message: {str(error)}\n\n") # Log error attributes - if hasattr(error, '__dict__'): + if hasattr(error, "__dict__"): f.write("Error Attributes:\n") for key, value in error.__dict__.items(): f.write(f" {key}: {value}\n") @@ -135,7 +171,7 @@ def log_error_details(ctx_or_interaction, error, content=None): f.write(f"Cause: {type(error.__cause__).__name__}\n") f.write(f"Cause Message: {str(error.__cause__)}\n\n") - if hasattr(error.__cause__, '__dict__'): + if hasattr(error.__cause__, "__dict__"): f.write("Cause Attributes:\n") for key, value in error.__cause__.__dict__.items(): f.write(f" {key}: {value}\n") @@ -150,29 +186,49 @@ def log_error_details(ctx_or_interaction, error, content=None): f.write("Context/Interaction Details:\n") if isinstance(ctx_or_interaction, commands.Context): f.write(f" Type: Context\n") - if hasattr(ctx_or_interaction, 'command') and ctx_or_interaction.command: + if hasattr(ctx_or_interaction, "command") and ctx_or_interaction.command: f.write(f" Command: {ctx_or_interaction.command.name}\n") - if hasattr(ctx_or_interaction, 'author') and ctx_or_interaction.author: - f.write(f" Author: {ctx_or_interaction.author.name} (ID: {ctx_or_interaction.author.id})\n") - if hasattr(ctx_or_interaction, 'guild') and ctx_or_interaction.guild: - f.write(f" Guild: {ctx_or_interaction.guild.name} (ID: {ctx_or_interaction.guild.id})\n") - if hasattr(ctx_or_interaction, 'channel') and ctx_or_interaction.channel: + if hasattr(ctx_or_interaction, "author") and ctx_or_interaction.author: + f.write( + f" Author: {ctx_or_interaction.author.name} (ID: {ctx_or_interaction.author.id})\n" + ) + if hasattr(ctx_or_interaction, "guild") and ctx_or_interaction.guild: + f.write( + f" Guild: {ctx_or_interaction.guild.name} (ID: {ctx_or_interaction.guild.id})\n" + ) + if hasattr(ctx_or_interaction, "channel") and ctx_or_interaction.channel: channel_name = ctx_or_interaction.channel.name if isinstance(ctx_or_interaction.channel, discord.DMChannel): - channel_name = f"DM with {ctx_or_interaction.channel.recipient.name}" if ctx_or_interaction.channel.recipient else "DM Channel" - f.write(f" Channel: {channel_name} (ID: {ctx_or_interaction.channel.id})\n") + channel_name = ( + f"DM with {ctx_or_interaction.channel.recipient.name}" + if ctx_or_interaction.channel.recipient + else "DM Channel" + ) + f.write( + f" Channel: {channel_name} (ID: {ctx_or_interaction.channel.id})\n" + ) else: f.write(f" Type: Interaction\n") - if hasattr(ctx_or_interaction, 'user') and ctx_or_interaction.user: - f.write(f" User: {ctx_or_interaction.user.name} (ID: {ctx_or_interaction.user.id})\n") - if hasattr(ctx_or_interaction, 'guild') and ctx_or_interaction.guild: - f.write(f" Guild: {ctx_or_interaction.guild.name} (ID: {ctx_or_interaction.guild.id})\n") - if hasattr(ctx_or_interaction, 'channel') and ctx_or_interaction.channel: + if hasattr(ctx_or_interaction, "user") and ctx_or_interaction.user: + f.write( + f" User: {ctx_or_interaction.user.name} (ID: {ctx_or_interaction.user.id})\n" + ) + if hasattr(ctx_or_interaction, "guild") and ctx_or_interaction.guild: + f.write( + f" Guild: {ctx_or_interaction.guild.name} (ID: {ctx_or_interaction.guild.id})\n" + ) + if hasattr(ctx_or_interaction, "channel") and ctx_or_interaction.channel: channel_name = ctx_or_interaction.channel.name if isinstance(ctx_or_interaction.channel, discord.DMChannel): - channel_name = f"DM with {ctx_or_interaction.channel.recipient.name}" if ctx_or_interaction.channel.recipient else "DM Channel" - f.write(f" Channel: {channel_name} (ID: {ctx_or_interaction.channel.id})\n") - if hasattr(ctx_or_interaction, 'command') and ctx_or_interaction.command: + channel_name = ( + f"DM with {ctx_or_interaction.channel.recipient.name}" + if ctx_or_interaction.channel.recipient + else "DM Channel" + ) + f.write( + f" Channel: {channel_name} (ID: {ctx_or_interaction.channel.id})\n" + ) + if hasattr(ctx_or_interaction, "command") and ctx_or_interaction.command: f.write(f" Command: {ctx_or_interaction.command.name}\n") f.write("\n") @@ -185,6 +241,7 @@ def log_error_details(ctx_or_interaction, error, content=None): print(f"Error details logged to {log_file}") return log_file + def patch_discord_methods(): """Patch Discord methods to store message content before sending.""" # Save original methods for Context @@ -215,6 +272,7 @@ def patch_discord_methods(): # This function will be available globally for use in commands global store_interaction_content + def store_interaction_content(interaction, content): """Store content in an interaction for potential error recovery""" if interaction and content: @@ -224,10 +282,12 @@ def patch_discord_methods(): from cogs.ai_cog import interaction_responses # Store using the interaction ID as the key - interaction_id = getattr(interaction, 'id', None) + interaction_id = getattr(interaction, "id", None) if interaction_id: interaction_responses[interaction_id] = content - print(f"Stored response for interaction {interaction_id} in dictionary from error_handler") + print( + f"Stored response for interaction {interaction_id} in dictionary from error_handler" + ) return True except ImportError: pass @@ -236,11 +296,14 @@ def patch_discord_methods(): interaction._last_response_content = content return True except Exception as e: - print(f"Warning: Failed to store interaction content in error_handler: {e}") + print( + f"Warning: Failed to store interaction content in error_handler: {e}" + ) return False print("Discord Context methods patched successfully") + async def send_error_embed_to_owner(ctx_or_interaction, error): """Send an embed with error details to the bot owner.""" user_id = 452666956353503252 # Owner user ID @@ -250,15 +313,16 @@ async def send_error_embed_to_owner(ctx_or_interaction, error): bot_instance = None if isinstance(ctx_or_interaction, commands.Context): bot_instance = ctx_or_interaction.bot - elif hasattr(ctx_or_interaction, 'bot'): + elif hasattr(ctx_or_interaction, "bot"): bot_instance = ctx_or_interaction.bot - elif hasattr(ctx_or_interaction, 'client'): + elif hasattr(ctx_or_interaction, "client"): bot_instance = ctx_or_interaction.client # Try to get from global accessor if not found if not bot_instance: try: from global_bot_accessor import get_bot_instance + bot_instance = get_bot_instance() except ImportError: print("Failed to import global_bot_accessor") @@ -280,26 +344,20 @@ async def send_error_embed_to_owner(ctx_or_interaction, error): title="❌ Error Report", description=f"**Error Type:** {type(error).__name__}\n**Message:** {str(error)}", color=0xFF0000, # Red color - timestamp=datetime.datetime.now() + timestamp=datetime.datetime.now(), ) # Add command info command_name = "Unknown" if isinstance(ctx_or_interaction, commands.Context): - if hasattr(ctx_or_interaction, 'command') and ctx_or_interaction.command: + if hasattr(ctx_or_interaction, "command") and ctx_or_interaction.command: command_name = ctx_or_interaction.command.name - embed.add_field( - name="Command", - value=f"`{command_name}`", - inline=True - ) + embed.add_field(name="Command", value=f"`{command_name}`", inline=True) else: # It's an interaction - if hasattr(ctx_or_interaction, 'command') and ctx_or_interaction.command: + if hasattr(ctx_or_interaction, "command") and ctx_or_interaction.command: command_name = ctx_or_interaction.command.name embed.add_field( - name="Slash Command", - value=f"`/{command_name}`", - inline=True + name="Slash Command", value=f"`/{command_name}`", inline=True ) # Add user info @@ -309,13 +367,11 @@ async def send_error_embed_to_owner(ctx_or_interaction, error): user_info = f"{ctx_or_interaction.author.name} (ID: {ctx_or_interaction.author.id})" else: # It's an interaction if ctx_or_interaction.user: - user_info = f"{ctx_or_interaction.user.name} (ID: {ctx_or_interaction.user.id})" + user_info = ( + f"{ctx_or_interaction.user.name} (ID: {ctx_or_interaction.user.id})" + ) - embed.add_field( - name="User", - value=user_info, - inline=True - ) + embed.add_field(name="User", value=user_info, inline=True) # Add guild and channel info guild_info = "DM" @@ -326,7 +382,11 @@ async def send_error_embed_to_owner(ctx_or_interaction, error): guild_info = f"{ctx_or_interaction.guild.name} (ID: {ctx_or_interaction.guild.id})" if ctx_or_interaction.channel: if isinstance(ctx_or_interaction.channel, discord.DMChannel): - channel_info = f"DM with {ctx_or_interaction.channel.recipient.name}" if ctx_or_interaction.channel.recipient else "DM Channel" + channel_info = ( + f"DM with {ctx_or_interaction.channel.recipient.name}" + if ctx_or_interaction.channel.recipient + else "DM Channel" + ) else: channel_info = f"#{ctx_or_interaction.channel.name} (ID: {ctx_or_interaction.channel.id})" else: # It's an interaction @@ -334,27 +394,23 @@ async def send_error_embed_to_owner(ctx_or_interaction, error): guild_info = f"{ctx_or_interaction.guild.name} (ID: {ctx_or_interaction.guild.id})" if ctx_or_interaction.channel: if isinstance(ctx_or_interaction.channel, discord.DMChannel): - channel_info = f"DM with {ctx_or_interaction.channel.recipient.name}" if ctx_or_interaction.channel.recipient else "DM Channel" + channel_info = ( + f"DM with {ctx_or_interaction.channel.recipient.name}" + if ctx_or_interaction.channel.recipient + else "DM Channel" + ) else: channel_info = f"#{ctx_or_interaction.channel.name} (ID: {ctx_or_interaction.channel.id})" - embed.add_field( - name="Server", - value=guild_info, - inline=True - ) + embed.add_field(name="Server", value=guild_info, inline=True) - embed.add_field( - name="Channel", - value=channel_info, - inline=True - ) + embed.add_field(name="Channel", value=channel_info, inline=True) # Add timestamp field embed.add_field( name="Timestamp", value=f"", - inline=True + inline=True, ) # Add traceback as a field (truncated) @@ -363,9 +419,7 @@ async def send_error_embed_to_owner(ctx_or_interaction, error): tb_str = tb_str[:997] + "..." embed.add_field( - name="Traceback", - value=f"```python\n{tb_str}\n```", - inline=False + name="Traceback", value=f"```python\n{tb_str}\n```", inline=False ) # Add cause if available @@ -373,7 +427,7 @@ async def send_error_embed_to_owner(ctx_or_interaction, error): embed.add_field( name="Cause", value=f"```{type(error.__cause__).__name__}: {str(error.__cause__)}\n```", - inline=False + inline=False, ) # Extract content for context @@ -383,19 +437,19 @@ async def send_error_embed_to_owner(ctx_or_interaction, error): content = content[:997] + "..." if content: embed.add_field( - name="Message Content", - value=f"```\n{content}\n```", - inline=False + name="Message Content", value=f"```\n{content}\n```", inline=False ) except Exception as e: embed.add_field( name="Content Extraction Error", value=f"Failed to extract message content: {str(e)}", - inline=False + inline=False, ) # Set footer - embed.set_footer(text=f"Error ID: {datetime.datetime.now().strftime('%Y%m%d%H%M%S')}") + embed.set_footer( + text=f"Error ID: {datetime.datetime.now().strftime('%Y%m%d%H%M%S')}" + ) # Send the embed to the owner await owner.send(embed=embed) @@ -405,10 +459,13 @@ async def send_error_embed_to_owner(ctx_or_interaction, error): # Fall back to simple text message if embed fails try: if bot_instance and (owner := await bot_instance.fetch_user(user_id)): - await owner.send(f"Error occurred but failed to create embed: {str(error)}\nEmbed error: {str(e)}") + await owner.send( + f"Error occurred but failed to create embed: {str(error)}\nEmbed error: {str(e)}" + ) except: print("Complete failure in error reporting system") + async def handle_error(ctx_or_interaction, error): user_id = 452666956353503252 # Owner user ID @@ -417,6 +474,7 @@ async def handle_error(ctx_or_interaction, error): # Import here to avoid circular imports try: from cogs.ban_system_cog import UserBannedError + if isinstance(error, UserBannedError): # This should be handled in the main.py error handlers # But just in case it gets here, handle it properly @@ -425,7 +483,9 @@ async def handle_error(ctx_or_interaction, error): await ctx_or_interaction.send(message, ephemeral=True) else: if not ctx_or_interaction.response.is_done(): - await ctx_or_interaction.response.send_message(message, ephemeral=True) + await ctx_or_interaction.response.send_message( + message, ephemeral=True + ) else: await ctx_or_interaction.followup.send(message, ephemeral=True) return @@ -446,7 +506,7 @@ async def handle_error(ctx_or_interaction, error): # Also send to owner in an embed if isinstance(error, commands.MissingRequiredArgument): - missing_arg = error.param.name if hasattr(error, 'param') else 'an argument' + missing_arg = error.param.name if hasattr(error, "param") else "an argument" message = f"Missing required argument: `{missing_arg}`. Please provide all required arguments." if isinstance(ctx_or_interaction, commands.Context): await ctx_or_interaction.send(message) @@ -461,7 +521,11 @@ async def handle_error(ctx_or_interaction, error): return # Special handling for interaction timeout errors (10062: Unknown interaction) - if isinstance(error, commands.CommandInvokeError) and isinstance(error.original, discord.NotFound) and error.original.code == 10062: + if ( + isinstance(error, commands.CommandInvokeError) + and isinstance(error.original, discord.NotFound) + and error.original.code == 10062 + ): print(f"Interaction timeout error (10062): {error}") # This error occurs when Discord's interaction token expires (after 3 seconds) # We can't respond to the interaction anymore, so we'll just log it and notify the owner @@ -472,33 +536,54 @@ async def handle_error(ctx_or_interaction, error): # Check if this is an AI command error is_ai_command = False - if isinstance(ctx_or_interaction, commands.Context) and hasattr(ctx_or_interaction, 'command'): - is_ai_command = ctx_or_interaction.command and ctx_or_interaction.command.name == 'ai' - elif hasattr(ctx_or_interaction, 'command') and ctx_or_interaction.command: - is_ai_command = ctx_or_interaction.command.name == 'ai' + if isinstance(ctx_or_interaction, commands.Context) and hasattr( + ctx_or_interaction, "command" + ): + is_ai_command = ( + ctx_or_interaction.command and ctx_or_interaction.command.name == "ai" + ) + elif hasattr(ctx_or_interaction, "command") and ctx_or_interaction.command: + is_ai_command = ctx_or_interaction.command.name == "ai" # For AI command errors with HTTPException, try to handle specially - if is_ai_command and isinstance(error, commands.CommandInvokeError) and isinstance(error.original, discord.HTTPException): - if error.original.code == 50035 and "Must be 4000 or fewer in length" in str(error.original): + if ( + is_ai_command + and isinstance(error, commands.CommandInvokeError) + and isinstance(error.original, discord.HTTPException) + ): + if error.original.code == 50035 and "Must be 4000 or fewer in length" in str( + error.original + ): # Try to get the AI response from the stored content - if isinstance(ctx_or_interaction, commands.Context) and hasattr(ctx_or_interaction, '_last_message_content'): + if isinstance(ctx_or_interaction, commands.Context) and hasattr( + ctx_or_interaction, "_last_message_content" + ): content = ctx_or_interaction._last_message_content # Save to file and send - with open('ai_response.txt', 'w', encoding='utf-8') as f: + with open("ai_response.txt", "w", encoding="utf-8") as f: f.write(content) - await ctx_or_interaction.send("The AI response was too long. Here's the content as a file:", file=discord.File('ai_response.txt')) + await ctx_or_interaction.send( + "The AI response was too long. Here's the content as a file:", + file=discord.File("ai_response.txt"), + ) # Also notify the owner await send_error_embed_to_owner(ctx_or_interaction, error) return - elif hasattr(ctx_or_interaction, '_last_response_content'): + elif hasattr(ctx_or_interaction, "_last_response_content"): content = ctx_or_interaction._last_response_content # Save to file and send - with open('ai_response.txt', 'w', encoding='utf-8') as f: + with open("ai_response.txt", "w", encoding="utf-8") as f: f.write(content) if not ctx_or_interaction.response.is_done(): - await ctx_or_interaction.response.send_message("The AI response was too long. Here's the content as a file:", file=discord.File('ai_response.txt')) + await ctx_or_interaction.response.send_message( + "The AI response was too long. Here's the content as a file:", + file=discord.File("ai_response.txt"), + ) else: - await ctx_or_interaction.followup.send("The AI response was too long. Here's the content as a file:", file=discord.File('ai_response.txt')) + await ctx_or_interaction.followup.send( + "The AI response was too long. Here's the content as a file:", + file=discord.File("ai_response.txt"), + ) # Also notify the owner await send_error_embed_to_owner(ctx_or_interaction, error) return @@ -523,9 +608,9 @@ async def handle_error(ctx_or_interaction, error): bot_instance = None if isinstance(ctx_or_interaction, commands.Context): bot_instance = ctx_or_interaction.bot - elif hasattr(ctx_or_interaction, 'bot'): + elif hasattr(ctx_or_interaction, "bot"): bot_instance = ctx_or_interaction.bot - elif hasattr(ctx_or_interaction, 'client'): + elif hasattr(ctx_or_interaction, "client"): bot_instance = ctx_or_interaction.client # If we couldn't get the bot instance, try to get it from the global accessor @@ -533,6 +618,7 @@ async def handle_error(ctx_or_interaction, error): try: # Import here to avoid circular imports from global_bot_accessor import get_bot_instance + bot_instance = get_bot_instance() except ImportError: print("Failed to import global_bot_accessor") @@ -548,12 +634,14 @@ async def handle_error(ctx_or_interaction, error): owner = await bot_instance.fetch_user(user_id) if owner: full_error = f"Full error details:\n```\n{str(error)}\n" - if hasattr(error, '__dict__'): + if hasattr(error, "__dict__"): full_error += f"\nError attributes:\n{error.__dict__}\n" if error.__cause__: full_error += f"\nCause:\n{str(error.__cause__)}\n" - if hasattr(error.__cause__, '__dict__'): - full_error += f"\nCause attributes:\n{error.__cause__.__dict__}\n" + if hasattr(error.__cause__, "__dict__"): + full_error += ( + f"\nCause attributes:\n{error.__cause__.__dict__}\n" + ) full_error += "```" # Add log file path to the error message @@ -561,7 +649,9 @@ async def handle_error(ctx_or_interaction, error): # Try to send the log file as an attachment try: - await owner.send("Here's the detailed error log:", file=discord.File(log_file)) + await owner.send( + "Here's the detailed error log:", file=discord.File(log_file) + ) # Send a shorter message since we sent the file short_error = f"Error: {str(error)}" if error.__cause__: @@ -571,7 +661,10 @@ async def handle_error(ctx_or_interaction, error): # If sending the file fails, fall back to text messages # Split long messages if needed if len(full_error) > 1900: - parts = [full_error[i:i+1900] for i in range(0, len(full_error), 1900)] + parts = [ + full_error[i : i + 1900] + for i in range(0, len(full_error), 1900) + ] for i, part in enumerate(parts): await owner.send(f"Part {i+1}/{len(parts)}:\n{part}") else: @@ -580,28 +673,57 @@ async def handle_error(ctx_or_interaction, error): print(f"Failed to send error DM to owner: {e}") # Determine the file name to use for saving content - file_name = 'message.txt' + file_name = "message.txt" # Special handling for AI command errors - if isinstance(error, commands.CommandInvokeError) and isinstance(error.original, discord.HTTPException): + if isinstance(error, commands.CommandInvokeError) and isinstance( + error.original, discord.HTTPException + ): # Check if this is an AI command is_ai_command = False - if isinstance(ctx_or_interaction, commands.Context) and hasattr(ctx_or_interaction, 'command'): - is_ai_command = ctx_or_interaction.command and ctx_or_interaction.command.name == 'ai' - elif hasattr(ctx_or_interaction, 'command') and ctx_or_interaction.command: - is_ai_command = ctx_or_interaction.command.name == 'ai' + if isinstance(ctx_or_interaction, commands.Context) and hasattr( + ctx_or_interaction, "command" + ): + is_ai_command = ( + ctx_or_interaction.command and ctx_or_interaction.command.name == "ai" + ) + elif hasattr(ctx_or_interaction, "command") and ctx_or_interaction.command: + is_ai_command = ctx_or_interaction.command.name == "ai" # If it's an AI command, use a different file name if is_ai_command: - file_name = 'ai_response.txt' + file_name = "ai_response.txt" # Handle message too long error (HTTP 400 - Code 50035 or 40005 for file uploads) - if (isinstance(error, discord.HTTPException) and - ((error.code == 50035 and ("Must be 4000 or fewer in length" in str(error) or "Must be 2000 or fewer in length" in str(error))) or - (error.code == 40005 and "Request entity too large" in str(error)))) or \ - (isinstance(error, commands.CommandInvokeError) and isinstance(error.original, discord.HTTPException) and - ((error.original.code == 50035 and ("Must be 4000 or fewer in length" in str(error.original) or "Must be 2000 or fewer in length" in str(error.original))) or - (error.original.code == 40005 and "Request entity too large" in str(error.original)))): + if ( + isinstance(error, discord.HTTPException) + and ( + ( + error.code == 50035 + and ( + "Must be 4000 or fewer in length" in str(error) + or "Must be 2000 or fewer in length" in str(error) + ) + ) + or (error.code == 40005 and "Request entity too large" in str(error)) + ) + ) or ( + isinstance(error, commands.CommandInvokeError) + and isinstance(error.original, discord.HTTPException) + and ( + ( + error.original.code == 50035 + and ( + "Must be 4000 or fewer in length" in str(error.original) + or "Must be 2000 or fewer in length" in str(error.original) + ) + ) + or ( + error.original.code == 40005 + and "Request entity too large" in str(error.original) + ) + ) + ): # Try to extract the actual content from the error content = None @@ -610,15 +732,19 @@ async def handle_error(ctx_or_interaction, error): # Use the original error for extraction original_error = error.original if isinstance(original_error, discord.HTTPException): - content = original_error.text if hasattr(original_error, 'text') else None + content = ( + original_error.text if hasattr(original_error, "text") else None + ) # If it's a wrapped error, get the original error's content elif isinstance(error.__cause__, discord.HTTPException): - content = error.__cause__.text if hasattr(error.__cause__, 'text') else None + content = error.__cause__.text if hasattr(error.__cause__, "text") else None else: - content = error.text if hasattr(error, 'text') else None + content = error.text if hasattr(error, "text") else None # If content is not available in the error, try to retrieve it from the context/interaction - if not content or len(content) < 10: # If content is missing or too short to be the actual message + if ( + not content or len(content) < 10 + ): # If content is missing or too short to be the actual message # Try to get the original content using our utility function content = extract_message_content(ctx_or_interaction) @@ -629,33 +755,28 @@ async def handle_error(ctx_or_interaction, error): # Try to send as a file first try: # Create a text file with the content - with open(file_name, 'w', encoding='utf-8') as f: + with open(file_name, "w", encoding="utf-8") as f: f.write(content) # Send the file instead message = f"The message was too long. Here's the content as a file:\nError details logged to: {log_file}" if isinstance(ctx_or_interaction, commands.Context): - await ctx_or_interaction.send( - message, - file=discord.File(file_name) - ) + await ctx_or_interaction.send(message, file=discord.File(file_name)) else: if not ctx_or_interaction.response.is_done(): await ctx_or_interaction.response.send_message( - message, - file=discord.File(file_name) + message, file=discord.File(file_name) ) else: await ctx_or_interaction.followup.send( - message, - file=discord.File(file_name) + message, file=discord.File(file_name) ) except discord.HTTPException as e: # If sending as a file also fails (e.g., file too large), split into multiple messages if e.code == 40005 or "Request entity too large" in str(e): # Split the content into chunks of 1900 characters (Discord limit is 2000) - chunks = [content[i:i+1900] for i in range(0, len(content), 1900)] + chunks = [content[i : i + 1900] for i in range(0, len(content), 1900)] # Send a notification about splitting the message intro_message = f"The message was too long to send as a file. Splitting into {len(chunks)} parts.\nError details logged to: {log_file}" @@ -663,16 +784,22 @@ async def handle_error(ctx_or_interaction, error): if isinstance(ctx_or_interaction, commands.Context): await ctx_or_interaction.send(intro_message) for i, chunk in enumerate(chunks): - await ctx_or_interaction.send(f"Part {i+1}/{len(chunks)}:\n```\n{chunk}\n```") + await ctx_or_interaction.send( + f"Part {i+1}/{len(chunks)}:\n```\n{chunk}\n```" + ) else: if not ctx_or_interaction.response.is_done(): await ctx_or_interaction.response.send_message(intro_message) for i, chunk in enumerate(chunks): - await ctx_or_interaction.followup.send(f"Part {i+1}/{len(chunks)}:\n```\n{chunk}\n```") + await ctx_or_interaction.followup.send( + f"Part {i+1}/{len(chunks)}:\n```\n{chunk}\n```" + ) else: await ctx_or_interaction.followup.send(intro_message) for i, chunk in enumerate(chunks): - await ctx_or_interaction.followup.send(f"Part {i+1}/{len(chunks)}:\n```\n{chunk}\n```") + await ctx_or_interaction.followup.send( + f"Part {i+1}/{len(chunks)}:\n```\n{chunk}\n```" + ) else: # If it's a different error, re-raise it raise @@ -687,17 +814,29 @@ async def handle_error(ctx_or_interaction, error): try: await ctx_or_interaction.send(content=error_message) except discord.Forbidden: - await ctx_or_interaction.send("Unable to send you a DM with the error details.") + await ctx_or_interaction.send( + "Unable to send you a DM with the error details." + ) else: - await ctx_or_interaction.send("An error occurred while processing your command.") + await ctx_or_interaction.send( + "An error occurred while processing your command." + ) else: if not ctx_or_interaction.response.is_done(): if ctx_or_interaction.user.id == user_id: - await ctx_or_interaction.response.send_message(content=error_message, ephemeral=True) + await ctx_or_interaction.response.send_message( + content=error_message, ephemeral=True + ) else: - await ctx_or_interaction.response.send_message("An error occurred while processing your command.", ephemeral=True) + await ctx_or_interaction.response.send_message( + "An error occurred while processing your command.", ephemeral=True + ) else: if ctx_or_interaction.user.id == user_id: - await ctx_or_interaction.followup.send(content=error_message, ephemeral=True) + await ctx_or_interaction.followup.send( + content=error_message, ephemeral=True + ) else: - await ctx_or_interaction.followup.send("An error occurred while processing your command.", ephemeral=True) + await ctx_or_interaction.followup.send( + "An error occurred while processing your command.", ephemeral=True + ) diff --git a/flask_server.py b/flask_server.py index 39747cc..d928e9d 100644 --- a/flask_server.py +++ b/flask_server.py @@ -12,23 +12,26 @@ load_dotenv() GITHUB_SECRET = os.getenv("GITHUB_SECRET").encode() app = Flask(__name__) + def verify_signature(payload, signature): mac = hmac.new(GITHUB_SECRET, payload, hashlib.sha256) expected = "sha256=" + mac.hexdigest() return hmac.compare_digest(expected, signature) + def kill_main_process(): - for proc in psutil.process_iter(['pid', 'name', 'cmdline']): - if proc.info['cmdline'] and 'main.py' in proc.info['cmdline']: + for proc in psutil.process_iter(["pid", "name", "cmdline"]): + if proc.info["cmdline"] and "main.py" in proc.info["cmdline"]: print(f"Killing process {proc.info['pid']} running main.py") proc.terminate() proc.wait() + @app.route("/github-webhook-123", methods=["POST"]) def webhook(): signature = request.headers.get("X-Hub-Signature-256") if not signature or not verify_signature(request.data, signature): - abort(404) # If its a 404, nobody will suspect theres a real endpoint here + abort(404) # If its a 404, nobody will suspect theres a real endpoint here # Restart main.py logic print("Webhook received and verified. Restarting bot...") @@ -36,5 +39,6 @@ def webhook(): subprocess.Popen(["python", "main.py"], cwd=os.path.dirname(__file__)) return "Bot restarting." + if __name__ == "__main__": - app.run(host="127.0.0.1", port=5000) \ No newline at end of file + app.run(host="127.0.0.1", port=5000) diff --git a/global_bot_accessor.py b/global_bot_accessor.py index ffa583f..38989d9 100644 --- a/global_bot_accessor.py +++ b/global_bot_accessor.py @@ -7,17 +7,21 @@ to the bot instance, especially its shared resources like connection pools. _bot_instance = None + def set_bot_instance(bot_instance_ref): """ Sets the global bot instance. Should be called once from main.py after the bot object is created. """ global _bot_instance - if _bot_instance is not None and _bot_instance != bot_instance_ref : + if _bot_instance is not None and _bot_instance != bot_instance_ref: # This might indicate an issue if called multiple times with different instances - print(f"WARNING: Global bot instance is being overwritten. Old ID: {id(_bot_instance)}, New ID: {id(bot_instance_ref)}") + print( + f"WARNING: Global bot instance is being overwritten. Old ID: {id(_bot_instance)}, New ID: {id(bot_instance_ref)}" + ) _bot_instance = bot_instance_ref + def get_bot_instance(): """ Retrieves the global bot instance. diff --git a/gurt/__init__.py b/gurt/__init__.py index 1a99e3f..0b73d0d 100644 --- a/gurt/__init__.py +++ b/gurt/__init__.py @@ -5,4 +5,4 @@ from .cog import setup # This makes "from gurt import setup" work -__all__ = ['setup'] +__all__ = ["setup"] diff --git a/gurt/analysis.py b/gurt/analysis.py index 54df69a..566067e 100644 --- a/gurt/analysis.py +++ b/gurt/analysis.py @@ -9,28 +9,35 @@ logger = logging.getLogger(__name__) # Relative imports from .config import ( - MAX_PATTERNS_PER_CHANNEL, LEARNING_RATE, TOPIC_UPDATE_INTERVAL, - TOPIC_RELEVANCE_DECAY, MAX_ACTIVE_TOPICS, SENTIMENT_DECAY_RATE, - EMOTION_KEYWORDS, EMOJI_SENTIMENT # Import necessary configs + MAX_PATTERNS_PER_CHANNEL, + LEARNING_RATE, + TOPIC_UPDATE_INTERVAL, + TOPIC_RELEVANCE_DECAY, + MAX_ACTIVE_TOPICS, + SENTIMENT_DECAY_RATE, + EMOTION_KEYWORDS, + EMOJI_SENTIMENT, # Import necessary configs ) if TYPE_CHECKING: - from .cog import GurtCog # For type hinting + from .cog import GurtCog # For type hinting # --- Analysis Functions --- # Note: These functions need the 'cog' instance passed to access state like caches, etc. -async def analyze_conversation_patterns(cog: 'GurtCog'): + +async def analyze_conversation_patterns(cog: "GurtCog"): """Analyzes recent conversations to identify patterns and learn from them""" print("Analyzing conversation patterns and updating topics...") try: # Update conversation topics first await update_conversation_topics(cog) - for channel_id, messages in cog.message_cache['by_channel'].items(): - if len(messages) < 10: continue + for channel_id, messages in cog.message_cache["by_channel"].items(): + if len(messages) < 10: + continue - channel_patterns = extract_conversation_patterns(cog, messages) # Pass cog + channel_patterns = extract_conversation_patterns(cog, messages) # Pass cog if channel_patterns: existing_patterns = cog.conversation_patterns[channel_id] combined_patterns = existing_patterns + channel_patterns @@ -38,34 +45,41 @@ async def analyze_conversation_patterns(cog: 'GurtCog'): combined_patterns = combined_patterns[-MAX_PATTERNS_PER_CHANNEL:] cog.conversation_patterns[channel_id] = combined_patterns - analyze_conversation_dynamics(cog, channel_id, messages) # Pass cog + analyze_conversation_dynamics(cog, channel_id, messages) # Pass cog - update_user_preferences(cog) # Pass cog + update_user_preferences(cog) # Pass cog # adapt_personality_traits(cog) # Pass cog - Deprecated/Superseded by evolve_personality except Exception as e: print(f"Error analyzing conversation patterns: {e}") traceback.print_exc() -async def update_conversation_topics(cog: 'GurtCog'): + +async def update_conversation_topics(cog: "GurtCog"): """Updates the active topics for each channel based on recent messages""" try: - for channel_id, messages in cog.message_cache['by_channel'].items(): - if len(messages) < 5: continue + for channel_id, messages in cog.message_cache["by_channel"].items(): + if len(messages) < 5: + continue channel_topics = cog.active_topics[channel_id] now = time.time() - if now - channel_topics["last_update"] < TOPIC_UPDATE_INTERVAL: continue + if now - channel_topics["last_update"] < TOPIC_UPDATE_INTERVAL: + continue recent_messages = list(messages)[-30:] - topics = identify_conversation_topics(cog, recent_messages) # Pass cog - if not topics: continue + topics = identify_conversation_topics(cog, recent_messages) # Pass cog + if not topics: + continue old_topics = channel_topics["topics"] - for topic in old_topics: topic["score"] *= (1 - TOPIC_RELEVANCE_DECAY) + for topic in old_topics: + topic["score"] *= 1 - TOPIC_RELEVANCE_DECAY for new_topic in topics: - existing = next((t for t in old_topics if t["topic"] == new_topic["topic"]), None) + existing = next( + (t for t in old_topics if t["topic"] == new_topic["topic"]), None + ) if existing: existing["score"] = max(existing["score"], new_topic["score"]) existing["related_terms"] = new_topic["related_terms"] @@ -80,13 +94,22 @@ async def update_conversation_topics(cog: 'GurtCog'): old_topics = old_topics[:MAX_ACTIVE_TOPICS] if old_topics and channel_topics["topics"] != old_topics: - if not channel_topics["topic_history"] or set(t["topic"] for t in old_topics) != set(t["topic"] for t in channel_topics["topics"]): - channel_topics["topic_history"].append({ - "topics": [{"topic": t["topic"], "score": t["score"]} for t in old_topics], - "timestamp": now - }) + if not channel_topics["topic_history"] or set( + t["topic"] for t in old_topics + ) != set(t["topic"] for t in channel_topics["topics"]): + channel_topics["topic_history"].append( + { + "topics": [ + {"topic": t["topic"], "score": t["score"]} + for t in old_topics + ], + "timestamp": now, + } + ) if len(channel_topics["topic_history"]) > 10: - channel_topics["topic_history"] = channel_topics["topic_history"][-10:] + channel_topics["topic_history"] = channel_topics[ + "topic_history" + ][-10:] for msg in recent_messages: user_id = msg["author"]["id"] @@ -95,146 +118,352 @@ async def update_conversation_topics(cog: 'GurtCog'): topic_text = topic["topic"].lower() if topic_text in content: user_interests = channel_topics["user_topic_interests"][user_id] - existing = next((i for i in user_interests if i["topic"] == topic["topic"]), None) + existing = next( + (i for i in user_interests if i["topic"] == topic["topic"]), + None, + ) if existing: - existing["score"] = existing["score"] * 0.8 + topic["score"] * 0.2 + existing["score"] = ( + existing["score"] * 0.8 + topic["score"] * 0.2 + ) existing["last_mentioned"] = now else: - user_interests.append({ - "topic": topic["topic"], "score": topic["score"] * 0.5, - "first_mentioned": now, "last_mentioned": now - }) + user_interests.append( + { + "topic": topic["topic"], + "score": topic["score"] * 0.5, + "first_mentioned": now, + "last_mentioned": now, + } + ) channel_topics["topics"] = old_topics channel_topics["last_update"] = now if old_topics: - topic_str = ", ".join([f"{t['topic']} ({t['score']:.2f})" for t in old_topics[:3]]) + topic_str = ", ".join( + [f"{t['topic']} ({t['score']:.2f})" for t in old_topics[:3]] + ) print(f"Updated topics for channel {channel_id}: {topic_str}") except Exception as e: print(f"Error updating conversation topics: {e}") traceback.print_exc() -def analyze_conversation_dynamics(cog: 'GurtCog', channel_id: int, messages: List[Dict[str, Any]]): + +def analyze_conversation_dynamics( + cog: "GurtCog", channel_id: int, messages: List[Dict[str, Any]] +): """Analyzes conversation dynamics like response times, message lengths, etc.""" - if len(messages) < 5: return + if len(messages) < 5: + return try: response_times = [] response_map = defaultdict(int) message_lengths = defaultdict(list) question_answer_pairs = [] - import datetime # Import here + import datetime # Import here for i in range(1, len(messages)): - current_msg = messages[i]; prev_msg = messages[i-1] - if current_msg["author"]["id"] == prev_msg["author"]["id"]: continue + current_msg = messages[i] + prev_msg = messages[i - 1] + if current_msg["author"]["id"] == prev_msg["author"]["id"]: + continue try: - current_time = datetime.datetime.fromisoformat(current_msg["created_at"]) + current_time = datetime.datetime.fromisoformat( + current_msg["created_at"] + ) prev_time = datetime.datetime.fromisoformat(prev_msg["created_at"]) delta_seconds = (current_time - prev_time).total_seconds() - if 0 < delta_seconds < 300: response_times.append(delta_seconds) - except (ValueError, TypeError): pass + if 0 < delta_seconds < 300: + response_times.append(delta_seconds) + except (ValueError, TypeError): + pass - responder = current_msg["author"]["id"]; respondee = prev_msg["author"]["id"] + responder = current_msg["author"]["id"] + respondee = prev_msg["author"]["id"] response_map[f"{responder}:{respondee}"] += 1 message_lengths[responder].append(len(current_msg["content"])) if prev_msg["content"].endswith("?"): - question_answer_pairs.append({ - "question": prev_msg["content"], "answer": current_msg["content"], - "question_author": prev_msg["author"]["id"], "answer_author": current_msg["author"]["id"] - }) + question_answer_pairs.append( + { + "question": prev_msg["content"], + "answer": current_msg["content"], + "question_author": prev_msg["author"]["id"], + "answer_author": current_msg["author"]["id"], + } + ) - avg_response_time = sum(response_times) / len(response_times) if response_times else 0 - top_responders = sorted(response_map.items(), key=lambda x: x[1], reverse=True)[:3] - avg_message_lengths = {uid: sum(ls)/len(ls) if ls else 0 for uid, ls in message_lengths.items()} + avg_response_time = ( + sum(response_times) / len(response_times) if response_times else 0 + ) + top_responders = sorted(response_map.items(), key=lambda x: x[1], reverse=True)[ + :3 + ] + avg_message_lengths = { + uid: sum(ls) / len(ls) if ls else 0 for uid, ls in message_lengths.items() + } dynamics = { - "avg_response_time": avg_response_time, "top_responders": top_responders, - "avg_message_lengths": avg_message_lengths, "question_answer_count": len(question_answer_pairs), - "last_updated": time.time() + "avg_response_time": avg_response_time, + "top_responders": top_responders, + "avg_message_lengths": avg_message_lengths, + "question_answer_count": len(question_answer_pairs), + "last_updated": time.time(), } - if not hasattr(cog, 'conversation_dynamics'): cog.conversation_dynamics = {} + if not hasattr(cog, "conversation_dynamics"): + cog.conversation_dynamics = {} cog.conversation_dynamics[channel_id] = dynamics - adapt_to_conversation_dynamics(cog, channel_id, dynamics) # Pass cog + adapt_to_conversation_dynamics(cog, channel_id, dynamics) # Pass cog - except Exception as e: print(f"Error analyzing conversation dynamics: {e}") + except Exception as e: + print(f"Error analyzing conversation dynamics: {e}") -def adapt_to_conversation_dynamics(cog: 'GurtCog', channel_id: int, dynamics: Dict[str, Any]): + +def adapt_to_conversation_dynamics( + cog: "GurtCog", channel_id: int, dynamics: Dict[str, Any] +): """Adapts bot behavior based on observed conversation dynamics.""" try: if dynamics["avg_response_time"] > 0: - if not hasattr(cog, 'channel_response_timing'): cog.channel_response_timing = {} - response_time_factor = max(0.7, min(1.0, dynamics["avg_response_time"] / 10)) + if not hasattr(cog, "channel_response_timing"): + cog.channel_response_timing = {} + response_time_factor = max( + 0.7, min(1.0, dynamics["avg_response_time"] / 10) + ) cog.channel_response_timing[channel_id] = response_time_factor if dynamics["avg_message_lengths"]: all_lengths = [ls for ls in dynamics["avg_message_lengths"].values()] if all_lengths: avg_length = sum(all_lengths) / len(all_lengths) - if not hasattr(cog, 'channel_message_length'): cog.channel_message_length = {} + if not hasattr(cog, "channel_message_length"): + cog.channel_message_length = {} length_factor = min(avg_length / 200, 1.0) cog.channel_message_length[channel_id] = length_factor if dynamics["question_answer_count"] > 0: - if not hasattr(cog, 'channel_qa_responsiveness'): cog.channel_qa_responsiveness = {} + if not hasattr(cog, "channel_qa_responsiveness"): + cog.channel_qa_responsiveness = {} qa_factor = min(0.9, 0.5 + (dynamics["question_answer_count"] / 20) * 0.4) cog.channel_qa_responsiveness[channel_id] = qa_factor - except Exception as e: print(f"Error adapting to conversation dynamics: {e}") + except Exception as e: + print(f"Error adapting to conversation dynamics: {e}") -def extract_conversation_patterns(cog: 'GurtCog', messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + +def extract_conversation_patterns( + cog: "GurtCog", messages: List[Dict[str, Any]] +) -> List[Dict[str, Any]]: """Extract patterns from a sequence of messages""" patterns = [] - if len(messages) < 5: return patterns - import datetime # Import here + if len(messages) < 5: + return patterns + import datetime # Import here for i in range(len(messages) - 2): pattern = { "type": "message_sequence", "messages": [ - {"author_type": "user" if not messages[i]["author"]["bot"] else "bot", "content_sample": messages[i]["content"][:50]}, - {"author_type": "user" if not messages[i+1]["author"]["bot"] else "bot", "content_sample": messages[i+1]["content"][:50]}, - {"author_type": "user" if not messages[i+2]["author"]["bot"] else "bot", "content_sample": messages[i+2]["content"][:50]} - ], "timestamp": datetime.datetime.now().isoformat() + { + "author_type": ( + "user" if not messages[i]["author"]["bot"] else "bot" + ), + "content_sample": messages[i]["content"][:50], + }, + { + "author_type": ( + "user" if not messages[i + 1]["author"]["bot"] else "bot" + ), + "content_sample": messages[i + 1]["content"][:50], + }, + { + "author_type": ( + "user" if not messages[i + 2]["author"]["bot"] else "bot" + ), + "content_sample": messages[i + 2]["content"][:50], + }, + ], + "timestamp": datetime.datetime.now().isoformat(), } patterns.append(pattern) - topics = identify_conversation_topics(cog, messages) # Pass cog - if topics: patterns.append({"type": "topic_pattern", "topics": topics, "timestamp": datetime.datetime.now().isoformat()}) + topics = identify_conversation_topics(cog, messages) # Pass cog + if topics: + patterns.append( + { + "type": "topic_pattern", + "topics": topics, + "timestamp": datetime.datetime.now().isoformat(), + } + ) - user_interactions = analyze_user_interactions(cog, messages) # Pass cog - if user_interactions: patterns.append({"type": "user_interaction", "interactions": user_interactions, "timestamp": datetime.datetime.now().isoformat()}) + user_interactions = analyze_user_interactions(cog, messages) # Pass cog + if user_interactions: + patterns.append( + { + "type": "user_interaction", + "interactions": user_interactions, + "timestamp": datetime.datetime.now().isoformat(), + } + ) return patterns -def identify_conversation_topics(cog: 'GurtCog', messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + +def identify_conversation_topics( + cog: "GurtCog", messages: List[Dict[str, Any]] +) -> List[Dict[str, Any]]: """Identify potential topics from conversation messages.""" - if not messages or len(messages) < 3: return [] + if not messages or len(messages) < 3: + return [] all_text = " ".join([msg["content"] for msg in messages]) - stopwords = { # Expanded stopwords - "the", "and", "is", "in", "to", "a", "of", "for", "that", "this", "it", "with", "on", "as", "be", "at", "by", "an", "or", "but", "if", "from", "when", "where", "how", "all", "any", "both", "each", "few", "more", "most", "some", "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "can", "will", "just", "should", "now", "also", "like", "even", "because", "way", "who", "what", "yeah", "yes", "no", "nah", "lol", "lmao", "haha", "hmm", "um", "uh", "oh", "ah", "ok", "okay", "dont", "don't", "doesnt", "doesn't", "didnt", "didn't", "cant", "can't", "im", "i'm", "ive", "i've", "youre", "you're", "youve", "you've", "hes", "he's", "shes", "she's", "its", "it's", "were", "we're", "weve", "we've", "theyre", "they're", "theyve", "they've", "thats", "that's", "whats", "what's", "whos", "who's", "gonna", "gotta", "kinda", "sorta", "gurt" # Added gurt + stopwords = { # Expanded stopwords + "the", + "and", + "is", + "in", + "to", + "a", + "of", + "for", + "that", + "this", + "it", + "with", + "on", + "as", + "be", + "at", + "by", + "an", + "or", + "but", + "if", + "from", + "when", + "where", + "how", + "all", + "any", + "both", + "each", + "few", + "more", + "most", + "some", + "such", + "no", + "nor", + "not", + "only", + "own", + "same", + "so", + "than", + "too", + "very", + "can", + "will", + "just", + "should", + "now", + "also", + "like", + "even", + "because", + "way", + "who", + "what", + "yeah", + "yes", + "no", + "nah", + "lol", + "lmao", + "haha", + "hmm", + "um", + "uh", + "oh", + "ah", + "ok", + "okay", + "dont", + "don't", + "doesnt", + "doesn't", + "didnt", + "didn't", + "cant", + "can't", + "im", + "i'm", + "ive", + "i've", + "youre", + "you're", + "youve", + "you've", + "hes", + "he's", + "shes", + "she's", + "its", + "it's", + "were", + "we're", + "weve", + "we've", + "theyre", + "they're", + "theyve", + "they've", + "thats", + "that's", + "whats", + "what's", + "whos", + "who's", + "gonna", + "gotta", + "kinda", + "sorta", + "gurt", # Added gurt } def extract_ngrams(text, n_values=[1, 2, 3]): - words = re.findall(r'\b\w+\b', text.lower()) - filtered_words = [word for word in words if word not in stopwords and len(word) > 2] + words = re.findall(r"\b\w+\b", text.lower()) + filtered_words = [ + word for word in words if word not in stopwords and len(word) > 2 + ] all_ngrams = [] - for n in n_values: all_ngrams.extend([' '.join(filtered_words[i:i+n]) for i in range(len(filtered_words)-n+1)]) + for n in n_values: + all_ngrams.extend( + [ + " ".join(filtered_words[i : i + n]) + for i in range(len(filtered_words) - n + 1) + ] + ) return all_ngrams all_ngrams = extract_ngrams(all_text) ngram_counts = defaultdict(int) - for ngram in all_ngrams: ngram_counts[ngram] += 1 + for ngram in all_ngrams: + ngram_counts[ngram] += 1 min_count = 2 if len(messages) > 10 else 1 - filtered_ngrams = {ngram: count for ngram, count in ngram_counts.items() if count >= min_count} + filtered_ngrams = { + ngram: count for ngram, count in ngram_counts.items() if count >= min_count + } total_messages = len(messages) ngram_scores = {} for ngram, count in filtered_ngrams.items(): # Calculate score based on frequency, length, and spread across messages message_count = sum(1 for msg in messages if ngram in msg["content"].lower()) - spread_factor = (message_count / total_messages) ** 0.5 # Less emphasis on spread - length_bonus = len(ngram.split()) * 0.1 # Slight bonus for longer ngrams + spread_factor = ( + message_count / total_messages + ) ** 0.5 # Less emphasis on spread + length_bonus = len(ngram.split()) * 0.1 # Slight bonus for longer ngrams # Adjust importance calculation importance = (count * (0.4 + spread_factor)) + length_bonus ngram_scores[ngram] = importance @@ -253,132 +482,265 @@ def identify_conversation_topics(cog: 'GurtCog', messages: List[Dict[str, Any]]) break if not is_subgram and ngram not in temp_processed: ngrams_to_consider.append((ngram, score)) - temp_processed.add(ngram) # Avoid adding duplicates if logic changes + temp_processed.add(ngram) # Avoid adding duplicates if logic changes # Now process the filtered ngrams - sorted_ngrams = ngrams_to_consider # Use the filtered list + sorted_ngrams = ngrams_to_consider # Use the filtered list - for ngram, score in sorted_ngrams[:10]: # Consider top 10 potential topics after filtering - if ngram in processed_ngrams: continue + for ngram, score in sorted_ngrams[ + :10 + ]: # Consider top 10 potential topics after filtering + if ngram in processed_ngrams: + continue related_terms = [] # Find related terms (sub-ngrams or overlapping ngrams from the original sorted list) - for other_ngram, other_score in sorted_by_score: # Search in original sorted list for relations - if other_ngram == ngram or other_ngram in processed_ngrams: continue - ngram_words = set(ngram.split()); other_words = set(other_ngram.split()) + for ( + other_ngram, + other_score, + ) in sorted_by_score: # Search in original sorted list for relations + if other_ngram == ngram or other_ngram in processed_ngrams: + continue + ngram_words = set(ngram.split()) + other_words = set(other_ngram.split()) # Check for overlap or if one is a sub-string (more lenient relation) if ngram_words.intersection(other_words) or other_ngram in ngram: related_terms.append({"term": other_ngram, "score": other_score}) # Don't mark related terms as fully processed here unless they are direct sub-ngrams # processed_ngrams.add(other_ngram) - if len(related_terms) >= 3: break # Limit related terms shown + if len(related_terms) >= 3: + break # Limit related terms shown processed_ngrams.add(ngram) - topic_entry = {"topic": ngram, "score": score, "related_terms": related_terms, "message_count": sum(1 for msg in messages if ngram in msg["content"].lower())} + topic_entry = { + "topic": ngram, + "score": score, + "related_terms": related_terms, + "message_count": sum( + 1 for msg in messages if ngram in msg["content"].lower() + ), + } topics.append(topic_entry) - if len(topics) >= MAX_ACTIVE_TOPICS: break # Use config for max topics + if len(topics) >= MAX_ACTIVE_TOPICS: + break # Use config for max topics # Simple sentiment analysis for topics - positive_words = {"good", "great", "awesome", "amazing", "excellent", "love", "like", "best", "better", "nice", "cool"} + positive_words = { + "good", + "great", + "awesome", + "amazing", + "excellent", + "love", + "like", + "best", + "better", + "nice", + "cool", + } sorted_ngrams = sorted(ngram_scores.items(), key=lambda x: x[1], reverse=True) for ngram, score in sorted_ngrams[:15]: - if ngram in processed_ngrams: continue + if ngram in processed_ngrams: + continue related_terms = [] for other_ngram, other_score in sorted_ngrams: - if other_ngram == ngram or other_ngram in processed_ngrams: continue - ngram_words = set(ngram.split()); other_words = set(other_ngram.split()) + if other_ngram == ngram or other_ngram in processed_ngrams: + continue + ngram_words = set(ngram.split()) + other_words = set(other_ngram.split()) if ngram_words.intersection(other_words): related_terms.append({"term": other_ngram, "score": other_score}) processed_ngrams.add(other_ngram) - if len(related_terms) >= 5: break + if len(related_terms) >= 5: + break processed_ngrams.add(ngram) - topic_entry = {"topic": ngram, "score": score, "related_terms": related_terms, "message_count": sum(1 for msg in messages if ngram in msg["content"].lower())} + topic_entry = { + "topic": ngram, + "score": score, + "related_terms": related_terms, + "message_count": sum( + 1 for msg in messages if ngram in msg["content"].lower() + ), + } topics.append(topic_entry) - if len(topics) >= 5: break + if len(topics) >= 5: + break # Simple sentiment analysis for topics - positive_words = {"good", "great", "awesome", "amazing", "excellent", "love", "like", "best", "better", "nice", "cool"} - negative_words = {"bad", "terrible", "awful", "worst", "hate", "dislike", "sucks", "stupid", "boring", "annoying"} + positive_words = { + "good", + "great", + "awesome", + "amazing", + "excellent", + "love", + "like", + "best", + "better", + "nice", + "cool", + } + negative_words = { + "bad", + "terrible", + "awful", + "worst", + "hate", + "dislike", + "sucks", + "stupid", + "boring", + "annoying", + } for topic in topics: - topic_messages = [msg["content"] for msg in messages if topic["topic"] in msg["content"].lower()] + topic_messages = [ + msg["content"] + for msg in messages + if topic["topic"] in msg["content"].lower() + ] topic_text = " ".join(topic_messages).lower() positive_count = sum(1 for word in positive_words if word in topic_text) negative_count = sum(1 for word in negative_words if word in topic_text) - if positive_count > negative_count: topic["sentiment"] = "positive" - elif negative_count > positive_count: topic["sentiment"] = "negative" - else: topic["sentiment"] = "neutral" + if positive_count > negative_count: + topic["sentiment"] = "positive" + elif negative_count > positive_count: + topic["sentiment"] = "negative" + else: + topic["sentiment"] = "neutral" return topics -def analyze_user_interactions(cog: 'GurtCog', messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + +def analyze_user_interactions( + cog: "GurtCog", messages: List[Dict[str, Any]] +) -> List[Dict[str, Any]]: """Analyze interactions between users in the conversation""" interactions = [] response_map = defaultdict(int) for i in range(1, len(messages)): - current_msg = messages[i]; prev_msg = messages[i-1] - if current_msg["author"]["id"] == prev_msg["author"]["id"]: continue - responder = current_msg["author"]["id"]; respondee = prev_msg["author"]["id"] + current_msg = messages[i] + prev_msg = messages[i - 1] + if current_msg["author"]["id"] == prev_msg["author"]["id"]: + continue + responder = current_msg["author"]["id"] + respondee = prev_msg["author"]["id"] key = f"{responder}:{respondee}" response_map[key] += 1 for key, count in response_map.items(): if count > 1: responder, respondee = key.split(":") - interactions.append({"responder": responder, "respondee": respondee, "count": count}) + interactions.append( + {"responder": responder, "respondee": respondee, "count": count} + ) return interactions -def update_user_preferences(cog: 'GurtCog'): + +def update_user_preferences(cog: "GurtCog"): """Update stored user preferences based on observed interactions""" - for user_id, messages in cog.message_cache['by_user'].items(): - if len(messages) < 5: continue - emoji_count = 0; slang_count = 0; avg_length = 0 + for user_id, messages in cog.message_cache["by_user"].items(): + if len(messages) < 5: + continue + emoji_count = 0 + slang_count = 0 + avg_length = 0 for msg in messages: content = msg["content"] - emoji_count += len(re.findall(r'[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F700-\U0001F77F\U0001F780-\U0001F7FF\U0001F800-\U0001F8FF\U0001F900-\U0001F9FF\U0001FA00-\U0001FA6F\U0001FA70-\U0001FAFF\U00002702-\U000027B0\U000024C2-\U0001F251]', content)) - slang_words = ["ngl", "icl", "pmo", "ts", "bro", "vro", "bruh", "tuff", "kevin"] # Example slang + emoji_count += len( + re.findall( + r"[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F700-\U0001F77F\U0001F780-\U0001F7FF\U0001F800-\U0001F8FF\U0001F900-\U0001F9FF\U0001FA00-\U0001FA6F\U0001FA70-\U0001FAFF\U00002702-\U000027B0\U000024C2-\U0001F251]", + content, + ) + ) + slang_words = [ + "ngl", + "icl", + "pmo", + "ts", + "bro", + "vro", + "bruh", + "tuff", + "kevin", + ] # Example slang for word in slang_words: - if re.search(r'\b' + word + r'\b', content.lower()): slang_count += 1 + if re.search(r"\b" + word + r"\b", content.lower()): + slang_count += 1 avg_length += len(content) - if messages: avg_length /= len(messages) + if messages: + avg_length /= len(messages) user_prefs = cog.user_preferences[user_id] - if emoji_count > 0: user_prefs["emoji_preference"] = user_prefs.get("emoji_preference", 0.5) * (1 - LEARNING_RATE) + (emoji_count / len(messages)) * LEARNING_RATE - if slang_count > 0: user_prefs["slang_preference"] = user_prefs.get("slang_preference", 0.5) * (1 - LEARNING_RATE) + (slang_count / len(messages)) * LEARNING_RATE - user_prefs["length_preference"] = user_prefs.get("length_preference", 50) * (1 - LEARNING_RATE) + avg_length * LEARNING_RATE + if emoji_count > 0: + user_prefs["emoji_preference"] = ( + user_prefs.get("emoji_preference", 0.5) * (1 - LEARNING_RATE) + + (emoji_count / len(messages)) * LEARNING_RATE + ) + if slang_count > 0: + user_prefs["slang_preference"] = ( + user_prefs.get("slang_preference", 0.5) * (1 - LEARNING_RATE) + + (slang_count / len(messages)) * LEARNING_RATE + ) + user_prefs["length_preference"] = ( + user_prefs.get("length_preference", 50) * (1 - LEARNING_RATE) + + avg_length * LEARNING_RATE + ) + # Deprecated/Superseded by evolve_personality # def adapt_personality_traits(cog: 'GurtCog'): # """Slightly adapt personality traits based on observed patterns""" # pass # Logic removed as it's handled by evolve_personality now -async def evolve_personality(cog: 'GurtCog'): + +async def evolve_personality(cog: "GurtCog"): """Periodically analyzes recent activity and adjusts persistent personality traits.""" print("Starting personality evolution cycle...") try: current_traits = await cog.memory_manager.get_all_personality_traits() - if not current_traits: print("Evolution Error: Could not load current traits."); return + if not current_traits: + print("Evolution Error: Could not load current traits.") + return - positive_sentiment_score = 0; negative_sentiment_score = 0; sentiment_channels_count = 0 + positive_sentiment_score = 0 + negative_sentiment_score = 0 + sentiment_channels_count = 0 for channel_id, sentiment_data in cog.conversation_sentiment.items(): if time.time() - cog.channel_activity.get(channel_id, 0) < 3600: - if sentiment_data["overall"] == "positive": positive_sentiment_score += sentiment_data["intensity"] - elif sentiment_data["overall"] == "negative": negative_sentiment_score += sentiment_data["intensity"] + if sentiment_data["overall"] == "positive": + positive_sentiment_score += sentiment_data["intensity"] + elif sentiment_data["overall"] == "negative": + negative_sentiment_score += sentiment_data["intensity"] sentiment_channels_count += 1 - avg_pos_intensity = positive_sentiment_score / sentiment_channels_count if sentiment_channels_count > 0 else 0 - avg_neg_intensity = negative_sentiment_score / sentiment_channels_count if sentiment_channels_count > 0 else 0 - print(f"Evolution Analysis: Avg Pos Intensity={avg_pos_intensity:.2f}, Avg Neg Intensity={avg_neg_intensity:.2f}") + avg_pos_intensity = ( + positive_sentiment_score / sentiment_channels_count + if sentiment_channels_count > 0 + else 0 + ) + avg_neg_intensity = ( + negative_sentiment_score / sentiment_channels_count + if sentiment_channels_count > 0 + else 0 + ) + print( + f"Evolution Analysis: Avg Pos Intensity={avg_pos_intensity:.2f}, Avg Neg Intensity={avg_neg_intensity:.2f}" + ) # --- Analyze Tool Usage --- tool_success_rate = {} total_tool_uses = 0 successful_tool_uses = 0 for tool_name, stats in cog.tool_stats.items(): - count = stats.get('count', 0) - success = stats.get('success', 0) + count = stats.get("count", 0) + success = stats.get("success", 0) if count > 0: tool_success_rate[tool_name] = success / count total_tool_uses += count successful_tool_uses += success - overall_tool_success_rate = successful_tool_uses / total_tool_uses if total_tool_uses > 0 else 0.5 # Default to neutral if no uses - print(f"Evolution Analysis: Overall Tool Success Rate={overall_tool_success_rate:.2f} ({successful_tool_uses}/{total_tool_uses})") + overall_tool_success_rate = ( + successful_tool_uses / total_tool_uses if total_tool_uses > 0 else 0.5 + ) # Default to neutral if no uses + print( + f"Evolution Analysis: Overall Tool Success Rate={overall_tool_success_rate:.2f} ({successful_tool_uses}/{total_tool_uses})" + ) # Example: Log specific tool rates if needed # print(f"Evolution Analysis: Tool Success Rates: {tool_success_rate}") @@ -389,32 +751,48 @@ async def evolve_personality(cog: 'GurtCog'): for msg_id, reaction_data in cog.gurt_message_reactions.items(): positive_reactions += reaction_data.get("positive", 0) negative_reactions += reaction_data.get("negative", 0) - reaction_ratio = positive_reactions / (positive_reactions + negative_reactions) if (positive_reactions + negative_reactions) > 0 else 0.5 # Default neutral - print(f"Evolution Analysis: Reaction Ratio (Pos/Total)={reaction_ratio:.2f} ({positive_reactions}/{positive_reactions + negative_reactions})") + reaction_ratio = ( + positive_reactions / (positive_reactions + negative_reactions) + if (positive_reactions + negative_reactions) > 0 + else 0.5 + ) # Default neutral + print( + f"Evolution Analysis: Reaction Ratio (Pos/Total)={reaction_ratio:.2f} ({positive_reactions}/{positive_reactions + negative_reactions})" + ) # --- Calculate Trait Adjustments --- trait_changes = {} - local_learning_rate = 0.02 # Use local variable + local_learning_rate = 0.02 # Use local variable # Optimism (based on sentiment) - optimism_target = 0.5 + (avg_pos_intensity - avg_neg_intensity) * 0.5 # Scale sentiment difference to -0.5 to +0.5 range - trait_changes['optimism'] = max(0.0, min(1.0, optimism_target)) # Target value directly, learning rate applied later + optimism_target = ( + 0.5 + (avg_pos_intensity - avg_neg_intensity) * 0.5 + ) # Scale sentiment difference to -0.5 to +0.5 range + trait_changes["optimism"] = max( + 0.0, min(1.0, optimism_target) + ) # Target value directly, learning rate applied later # Mischief (based on timeout usage success/reactions) timeout_uses = cog.tool_stats.get("timeout_user", {}).get("count", 0) timeout_success_rate = tool_success_rate.get("timeout_user", 0.5) - if timeout_uses > 2: # Only adjust if used a few times + if timeout_uses > 2: # Only adjust if used a few times # Increase mischief if timeouts are successful and reactions aren't overly negative - mischief_target_adjustment = (timeout_success_rate - 0.5) * 0.2 + (reaction_ratio - 0.5) * 0.1 - current_mischief = current_traits.get('mischief', 0.5) - trait_changes['mischief'] = max(0.0, min(1.0, current_mischief + mischief_target_adjustment)) + mischief_target_adjustment = (timeout_success_rate - 0.5) * 0.2 + ( + reaction_ratio - 0.5 + ) * 0.1 + current_mischief = current_traits.get("mischief", 0.5) + trait_changes["mischief"] = max( + 0.0, min(1.0, current_mischief + mischief_target_adjustment) + ) # Curiosity (based on web search usage) search_uses = cog.tool_stats.get("web_search", {}).get("count", 0) - if search_uses > 1: # If search is used - current_curiosity = current_traits.get('curiosity', 0.6) - # Slightly increase curiosity if search is used, decrease slightly if not? (Needs refinement) - trait_changes['curiosity'] = max(0.0, min(1.0, current_curiosity + 0.05)) # Simple boost for now + if search_uses > 1: # If search is used + current_curiosity = current_traits.get("curiosity", 0.6) + # Slightly increase curiosity if search is used, decrease slightly if not? (Needs refinement) + trait_changes["curiosity"] = max( + 0.0, min(1.0, current_curiosity + 0.05) + ) # Simple boost for now # Sarcasm (increase if reactions are positive despite negative sentiment?) - Complex, placeholder # current_sarcasm = current_traits.get('sarcasm_level', 0.3) @@ -427,38 +805,53 @@ async def evolve_personality(cog: 'GurtCog'): # if reaction_ratio > 0.65 and total_reacted_messages > 5: # trait_changes['chattiness'] = max(0.1, min(1.0, current_chattiness + 0.03)) - # --- Apply Calculated Changes --- updated_count = 0 print(f"Calculated Trait Target Changes: {trait_changes}") for key, target_value in trait_changes.items(): current_value = current_traits.get(key) - if current_value is None: print(f"Evolution Warning: Trait '{key}' not found."); continue + if current_value is None: + print(f"Evolution Warning: Trait '{key}' not found.") + continue try: - current_float = float(current_value); target_float = float(target_value) - new_value_float = current_float * (1 - local_learning_rate) + target_float * local_learning_rate - new_value_clamped = max(0.0, min(1.0, new_value_float)) # Clamp 0-1 + current_float = float(current_value) + target_float = float(target_value) + new_value_float = ( + current_float * (1 - local_learning_rate) + + target_float * local_learning_rate + ) + new_value_clamped = max(0.0, min(1.0, new_value_float)) # Clamp 0-1 if abs(new_value_clamped - current_float) > 0.001: - await cog.memory_manager.set_personality_trait(key, new_value_clamped) - print(f"Evolved trait '{key}': {current_float:.3f} -> {new_value_clamped:.3f}") + await cog.memory_manager.set_personality_trait( + key, new_value_clamped + ) + print( + f"Evolved trait '{key}': {current_float:.3f} -> {new_value_clamped:.3f}" + ) updated_count += 1 - except (ValueError, TypeError) as e: print(f"Evolution Error processing trait '{key}': {e}") + except (ValueError, TypeError) as e: + print(f"Evolution Error processing trait '{key}': {e}") - if updated_count > 0: print(f"Personality evolution complete. Updated {updated_count} traits.") - else: print("Personality evolution complete. No significant trait changes.") + if updated_count > 0: + print(f"Personality evolution complete. Updated {updated_count} traits.") + else: + print("Personality evolution complete. No significant trait changes.") - except Exception as e: print(f"Error during personality evolution: {e}"); traceback.print_exc() + except Exception as e: + print(f"Error during personality evolution: {e}") + traceback.print_exc() -async def reflect_on_memories(cog: 'GurtCog'): + +async def reflect_on_memories(cog: "GurtCog"): """Periodically reviews memories to synthesize insights or consolidate information.""" print("Starting memory reflection cycle...") try: # --- Configuration --- - REFLECTION_INTERVAL_HOURS = 6 # How often to reflect + REFLECTION_INTERVAL_HOURS = 6 # How often to reflect FACTS_TO_REVIEW_PER_USER = 15 GENERAL_FACTS_TO_REVIEW = 30 MIN_FACTS_FOR_REFLECTION = 5 - SYNTHESIS_MODEL = cog.fallback_model # Use a potentially cheaper model + SYNTHESIS_MODEL = cog.fallback_model # Use a potentially cheaper model SYNTHESIS_MAX_TOKENS = 200 # Check if enough time has passed (simple check, could be more robust) @@ -471,27 +864,40 @@ async def reflect_on_memories(cog: 'GurtCog'): users_reflected = 0 for user_id in all_user_ids: try: - user_facts = await cog.memory_manager.get_user_facts(user_id, limit=FACTS_TO_REVIEW_PER_USER) # Get recent facts - if len(user_facts) < MIN_FACTS_FOR_REFLECTION: continue + user_facts = await cog.memory_manager.get_user_facts( + user_id, limit=FACTS_TO_REVIEW_PER_USER + ) # Get recent facts + if len(user_facts) < MIN_FACTS_FOR_REFLECTION: + continue - user_info = await cog.bot.fetch_user(int(user_id)) # Get user info for name + user_info = await cog.bot.fetch_user( + int(user_id) + ) # Get user info for name user_name = user_info.display_name if user_info else f"User {user_id}" print(f" - Reflecting on {len(user_facts)} facts for {user_name}...") facts_text = "\n".join([f"- {fact}" for fact in user_facts]) reflection_prompt = [ - {"role": "system", "content": f"Analyze the following facts about {user_name}. Identify potential patterns, contradictions, or synthesize a concise summary of key traits or interests. Focus on creating 1-2 new, insightful summary facts. Respond ONLY with JSON: {{ \"new_facts\": [\"fact1\", \"fact2\"], \"reasoning\": \"brief explanation\" }} or {{ \"new_facts\": [], \"reasoning\": \"No new insights.\" }}"}, - {"role": "user", "content": f"Facts:\n{facts_text}\n\nSynthesize insights:"} + { + "role": "system", + "content": f'Analyze the following facts about {user_name}. Identify potential patterns, contradictions, or synthesize a concise summary of key traits or interests. Focus on creating 1-2 new, insightful summary facts. Respond ONLY with JSON: {{ "new_facts": ["fact1", "fact2"], "reasoning": "brief explanation" }} or {{ "new_facts": [], "reasoning": "No new insights." }}', + }, + { + "role": "user", + "content": f"Facts:\n{facts_text}\n\nSynthesize insights:", + }, ] synthesis_schema = { "type": "object", "properties": { "new_facts": {"type": "array", "items": {"type": "string"}}, - "reasoning": {"type": "string"} - }, "required": ["new_facts", "reasoning"] + "reasoning": {"type": "string"}, + }, + "required": ["new_facts", "reasoning"], } - from .api import get_internal_ai_json_response # Local import + from .api import get_internal_ai_json_response # Local import + synthesis_result = await get_internal_ai_json_response( cog=cog, prompt_messages=reflection_prompt, @@ -499,23 +905,32 @@ async def reflect_on_memories(cog: 'GurtCog'): response_schema_dict=synthesis_schema, model_name_override=SYNTHESIS_MODEL, temperature=0.4, - max_tokens=SYNTHESIS_MAX_TOKENS + max_tokens=SYNTHESIS_MAX_TOKENS, ) if synthesis_result and synthesis_result.get("new_facts"): added_count = 0 for new_fact in synthesis_result["new_facts"]: - if new_fact and len(new_fact) > 5: # Basic validation - add_result = await cog.memory_manager.add_user_fact(user_id, f"[Synthesized] {new_fact}") - if add_result.get("status") == "added": added_count += 1 + if new_fact and len(new_fact) > 5: # Basic validation + add_result = await cog.memory_manager.add_user_fact( + user_id, f"[Synthesized] {new_fact}" + ) + if add_result.get("status") == "added": + added_count += 1 if added_count > 0: - print(f" - Added {added_count} synthesized fact(s) for {user_name}. Reasoning: {synthesis_result.get('reasoning')}") + print( + f" - Added {added_count} synthesized fact(s) for {user_name}. Reasoning: {synthesis_result.get('reasoning')}" + ) users_reflected += 1 # else: print(f" - No new insights synthesized for {user_name}.") # Optional log except Exception as user_reflect_e: - print(f" - Error reflecting on facts for user {user_id}: {user_reflect_e}") - print(f"User fact reflection complete. Synthesized facts for {users_reflected} users.") + print( + f" - Error reflecting on facts for user {user_id}: {user_reflect_e}" + ) + print( + f"User fact reflection complete. Synthesized facts for {users_reflected} users." + ) # --- General Fact Reflection (Example: Identify related topics) --- # This part is more complex and might require different strategies. @@ -531,14 +946,22 @@ async def reflect_on_memories(cog: 'GurtCog'): print(f"Error during memory reflection cycle: {e}") traceback.print_exc() -async def decompose_goal_into_steps(cog: 'GurtCog', goal_description: str) -> Optional[Dict[str, Any]]: + +async def decompose_goal_into_steps( + cog: "GurtCog", goal_description: str +) -> Optional[Dict[str, Any]]: """Uses an AI call to break down a goal into achievable steps with potential tool usage.""" logger.info(f"Decomposing goal: '{goal_description}'") - from .config import GOAL_DECOMPOSITION_SCHEMA, TOOLS # Import schema and tools list for context - from .api import get_internal_ai_json_response # Local import + from .config import ( + GOAL_DECOMPOSITION_SCHEMA, + TOOLS, + ) # Import schema and tools list for context + from .api import get_internal_ai_json_response # Local import # Provide context about available tools - tool_descriptions = "\n".join([f"- {tool.name}: {tool.description}" for tool in TOOLS]) + tool_descriptions = "\n".join( + [f"- {tool.name}: {tool.description}" for tool in TOOLS] + ) system_prompt = ( "You are Gurt's planning module. Your task is to break down a high-level goal into a sequence of smaller, " "concrete steps. For each step, determine if one of Gurt's available tools can help achieve it. " @@ -546,11 +969,13 @@ async def decompose_goal_into_steps(cog: 'GurtCog', goal_description: str) -> Op f"Available Tools:\n{tool_descriptions}\n\n" "Respond ONLY with JSON matching the provided schema." ) - user_prompt = f"Goal: {goal_description}\n\nDecompose this goal into achievable steps:" + user_prompt = ( + f"Goal: {goal_description}\n\nDecompose this goal into achievable steps:" + ) decomposition_prompt_messages = [ {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt} + {"role": "user", "content": user_prompt}, ] try: @@ -558,51 +983,80 @@ async def decompose_goal_into_steps(cog: 'GurtCog', goal_description: str) -> Op cog=cog, prompt_messages=decomposition_prompt_messages, task_description=f"Goal Decomposition ({goal_description[:30]}...)", - response_schema_dict=GOAL_DECOMPOSITION_SCHEMA['schema'], - model_name_override=cog.fallback_model, # Use fallback model for planning potentially + response_schema_dict=GOAL_DECOMPOSITION_SCHEMA["schema"], + model_name_override=cog.fallback_model, # Use fallback model for planning potentially temperature=0.3, - max_tokens=1000 # Allow more tokens for potentially complex plans + max_tokens=1000, # Allow more tokens for potentially complex plans ) if plan and plan.get("goal_achievable"): - logger.info(f"Goal '{goal_description}' decomposed into {len(plan.get('steps', []))} steps.") + logger.info( + f"Goal '{goal_description}' decomposed into {len(plan.get('steps', []))} steps." + ) # Basic validation of steps structure (optional but recommended) - if isinstance(plan.get('steps'), list): - for i, step in enumerate(plan['steps']): - if not isinstance(step, dict) or 'step_description' not in step: - logger.error(f"Invalid step structure at index {i} in decomposition plan: {step}") - plan['goal_achievable'] = False - plan['reasoning'] += " (Invalid step structure detected)" - plan['steps'] = [] + if isinstance(plan.get("steps"), list): + for i, step in enumerate(plan["steps"]): + if not isinstance(step, dict) or "step_description" not in step: + logger.error( + f"Invalid step structure at index {i} in decomposition plan: {step}" + ) + plan["goal_achievable"] = False + plan["reasoning"] += " (Invalid step structure detected)" + plan["steps"] = [] break else: - plan['steps'] = [] # Ensure steps is a list even if validation fails + plan["steps"] = [] # Ensure steps is a list even if validation fails return plan elif plan: - logger.warning(f"Goal '{goal_description}' deemed not achievable. Reasoning: {plan.get('reasoning')}") - return plan # Return the plan indicating it's not achievable + logger.warning( + f"Goal '{goal_description}' deemed not achievable. Reasoning: {plan.get('reasoning')}" + ) + return plan # Return the plan indicating it's not achievable else: - logger.error(f"Goal decomposition failed for '{goal_description}'. No valid JSON plan returned.") + logger.error( + f"Goal decomposition failed for '{goal_description}'. No valid JSON plan returned." + ) return None except Exception as e: - logger.error(f"Error during goal decomposition for '{goal_description}': {e}", exc_info=True) + logger.error( + f"Error during goal decomposition for '{goal_description}': {e}", + exc_info=True, + ) return None -def analyze_message_sentiment(cog: 'GurtCog', message_content: str) -> Dict[str, Any]: +def analyze_message_sentiment(cog: "GurtCog", message_content: str) -> Dict[str, Any]: """Analyzes the sentiment of a message using keywords and emojis.""" content = message_content.lower() - result = {"sentiment": "neutral", "intensity": 0.5, "emotions": [], "confidence": 0.5} + result = { + "sentiment": "neutral", + "intensity": 0.5, + "emotions": [], + "confidence": 0.5, + } - positive_emoji_count = sum(1 for emoji in EMOJI_SENTIMENT["positive"] if emoji in content) - negative_emoji_count = sum(1 for emoji in EMOJI_SENTIMENT["negative"] if emoji in content) - total_emoji_count = positive_emoji_count + negative_emoji_count + sum(1 for emoji in EMOJI_SENTIMENT["neutral"] if emoji in content) + positive_emoji_count = sum( + 1 for emoji in EMOJI_SENTIMENT["positive"] if emoji in content + ) + negative_emoji_count = sum( + 1 for emoji in EMOJI_SENTIMENT["negative"] if emoji in content + ) + total_emoji_count = ( + positive_emoji_count + + negative_emoji_count + + sum(1 for emoji in EMOJI_SENTIMENT["neutral"] if emoji in content) + ) - detected_emotions = []; emotion_scores = {} + detected_emotions = [] + emotion_scores = {} for emotion, keywords in EMOTION_KEYWORDS.items(): - emotion_count = sum(1 for keyword in keywords if re.search(r'\b' + re.escape(keyword) + r'\b', content)) + emotion_count = sum( + 1 + for keyword in keywords + if re.search(r"\b" + re.escape(keyword) + r"\b", content) + ) if emotion_count > 0: emotion_score = min(1.0, emotion_count / len(keywords) * 2) emotion_scores[emotion] = emotion_score @@ -612,82 +1066,213 @@ def analyze_message_sentiment(cog: 'GurtCog', message_content: str) -> Dict[str, primary_emotion = max(emotion_scores.items(), key=lambda x: x[1]) result["emotions"] = [primary_emotion[0]] for emotion, score in emotion_scores.items(): - if emotion != primary_emotion[0] and score > primary_emotion[1] * 0.7: result["emotions"].append(emotion) + if emotion != primary_emotion[0] and score > primary_emotion[1] * 0.7: + result["emotions"].append(emotion) - positive_emotions = ["joy"]; negative_emotions = ["sadness", "anger", "fear", "disgust"] - if primary_emotion[0] in positive_emotions: result["sentiment"] = "positive"; result["intensity"] = primary_emotion[1] - elif primary_emotion[0] in negative_emotions: result["sentiment"] = "negative"; result["intensity"] = primary_emotion[1] - else: result["sentiment"] = "neutral"; result["intensity"] = 0.5 + positive_emotions = ["joy"] + negative_emotions = ["sadness", "anger", "fear", "disgust"] + if primary_emotion[0] in positive_emotions: + result["sentiment"] = "positive" + result["intensity"] = primary_emotion[1] + elif primary_emotion[0] in negative_emotions: + result["sentiment"] = "negative" + result["intensity"] = primary_emotion[1] + else: + result["sentiment"] = "neutral" + result["intensity"] = 0.5 result["confidence"] = min(0.9, 0.5 + primary_emotion[1] * 0.4) elif total_emoji_count > 0: - if positive_emoji_count > negative_emoji_count: result["sentiment"] = "positive"; result["intensity"] = min(0.9, 0.5 + (positive_emoji_count / total_emoji_count) * 0.4); result["confidence"] = min(0.8, 0.4 + (positive_emoji_count / total_emoji_count) * 0.4) - elif negative_emoji_count > positive_emoji_count: result["sentiment"] = "negative"; result["intensity"] = min(0.9, 0.5 + (negative_emoji_count / total_emoji_count) * 0.4); result["confidence"] = min(0.8, 0.4 + (negative_emoji_count / total_emoji_count) * 0.4) - else: result["sentiment"] = "neutral"; result["intensity"] = 0.5; result["confidence"] = 0.6 + if positive_emoji_count > negative_emoji_count: + result["sentiment"] = "positive" + result["intensity"] = min( + 0.9, 0.5 + (positive_emoji_count / total_emoji_count) * 0.4 + ) + result["confidence"] = min( + 0.8, 0.4 + (positive_emoji_count / total_emoji_count) * 0.4 + ) + elif negative_emoji_count > positive_emoji_count: + result["sentiment"] = "negative" + result["intensity"] = min( + 0.9, 0.5 + (negative_emoji_count / total_emoji_count) * 0.4 + ) + result["confidence"] = min( + 0.8, 0.4 + (negative_emoji_count / total_emoji_count) * 0.4 + ) + else: + result["sentiment"] = "neutral" + result["intensity"] = 0.5 + result["confidence"] = 0.6 - else: # Basic text fallback - positive_words = {"good", "great", "awesome", "amazing", "excellent", "love", "like", "best", "better", "nice", "cool", "happy", "glad", "thanks", "thank", "appreciate", "wonderful", "fantastic", "perfect", "beautiful", "fun", "enjoy", "yes", "yep"} - negative_words = {"bad", "terrible", "awful", "worst", "hate", "dislike", "sucks", "stupid", "boring", "annoying", "sad", "upset", "angry", "mad", "disappointed", "sorry", "unfortunate", "horrible", "ugly", "wrong", "fail", "no", "nope"} - words = re.findall(r'\b\w+\b', content) + else: # Basic text fallback + positive_words = { + "good", + "great", + "awesome", + "amazing", + "excellent", + "love", + "like", + "best", + "better", + "nice", + "cool", + "happy", + "glad", + "thanks", + "thank", + "appreciate", + "wonderful", + "fantastic", + "perfect", + "beautiful", + "fun", + "enjoy", + "yes", + "yep", + } + negative_words = { + "bad", + "terrible", + "awful", + "worst", + "hate", + "dislike", + "sucks", + "stupid", + "boring", + "annoying", + "sad", + "upset", + "angry", + "mad", + "disappointed", + "sorry", + "unfortunate", + "horrible", + "ugly", + "wrong", + "fail", + "no", + "nope", + } + words = re.findall(r"\b\w+\b", content) positive_count = sum(1 for word in words if word in positive_words) negative_count = sum(1 for word in words if word in negative_words) - if positive_count > negative_count: result["sentiment"] = "positive"; result["intensity"] = min(0.8, 0.5 + (positive_count / len(words)) * 2 if words else 0); result["confidence"] = min(0.7, 0.3 + (positive_count / len(words)) * 0.4 if words else 0) - elif negative_count > positive_count: result["sentiment"] = "negative"; result["intensity"] = min(0.8, 0.5 + (negative_count / len(words)) * 2 if words else 0); result["confidence"] = min(0.7, 0.3 + (negative_count / len(words)) * 0.4 if words else 0) - else: result["sentiment"] = "neutral"; result["intensity"] = 0.5; result["confidence"] = 0.5 + if positive_count > negative_count: + result["sentiment"] = "positive" + result["intensity"] = min( + 0.8, 0.5 + (positive_count / len(words)) * 2 if words else 0 + ) + result["confidence"] = min( + 0.7, 0.3 + (positive_count / len(words)) * 0.4 if words else 0 + ) + elif negative_count > positive_count: + result["sentiment"] = "negative" + result["intensity"] = min( + 0.8, 0.5 + (negative_count / len(words)) * 2 if words else 0 + ) + result["confidence"] = min( + 0.7, 0.3 + (negative_count / len(words)) * 0.4 if words else 0 + ) + else: + result["sentiment"] = "neutral" + result["intensity"] = 0.5 + result["confidence"] = 0.5 return result -def update_conversation_sentiment(cog: 'GurtCog', channel_id: int, user_id: str, message_sentiment: Dict[str, Any]): + +def update_conversation_sentiment( + cog: "GurtCog", channel_id: int, user_id: str, message_sentiment: Dict[str, Any] +): """Updates the conversation sentiment tracking based on a new message's sentiment.""" channel_sentiment = cog.conversation_sentiment[channel_id] now = time.time() - if now - channel_sentiment["last_update"] > cog.sentiment_update_interval: # Access interval via cog - if channel_sentiment["overall"] == "positive": channel_sentiment["intensity"] = max(0.5, channel_sentiment["intensity"] - SENTIMENT_DECAY_RATE) - elif channel_sentiment["overall"] == "negative": channel_sentiment["intensity"] = max(0.5, channel_sentiment["intensity"] - SENTIMENT_DECAY_RATE) + if ( + now - channel_sentiment["last_update"] > cog.sentiment_update_interval + ): # Access interval via cog + if channel_sentiment["overall"] == "positive": + channel_sentiment["intensity"] = max( + 0.5, channel_sentiment["intensity"] - SENTIMENT_DECAY_RATE + ) + elif channel_sentiment["overall"] == "negative": + channel_sentiment["intensity"] = max( + 0.5, channel_sentiment["intensity"] - SENTIMENT_DECAY_RATE + ) channel_sentiment["recent_trend"] = "stable" channel_sentiment["last_update"] = now - user_sentiment = channel_sentiment["user_sentiments"].get(user_id, {"sentiment": "neutral", "intensity": 0.5}) + user_sentiment = channel_sentiment["user_sentiments"].get( + user_id, {"sentiment": "neutral", "intensity": 0.5} + ) confidence_weight = message_sentiment["confidence"] if user_sentiment["sentiment"] == message_sentiment["sentiment"]: - new_intensity = user_sentiment["intensity"] * 0.7 + message_sentiment["intensity"] * 0.3 + new_intensity = ( + user_sentiment["intensity"] * 0.7 + message_sentiment["intensity"] * 0.3 + ) user_sentiment["intensity"] = min(0.95, new_intensity) else: if message_sentiment["confidence"] > 0.7: user_sentiment["sentiment"] = message_sentiment["sentiment"] - user_sentiment["intensity"] = message_sentiment["intensity"] * 0.7 + user_sentiment["intensity"] * 0.3 + user_sentiment["intensity"] = ( + message_sentiment["intensity"] * 0.7 + user_sentiment["intensity"] * 0.3 + ) else: if message_sentiment["intensity"] > user_sentiment["intensity"]: user_sentiment["sentiment"] = message_sentiment["sentiment"] - user_sentiment["intensity"] = user_sentiment["intensity"] * 0.6 + message_sentiment["intensity"] * 0.4 + user_sentiment["intensity"] = ( + user_sentiment["intensity"] * 0.6 + + message_sentiment["intensity"] * 0.4 + ) user_sentiment["emotions"] = message_sentiment.get("emotions", []) channel_sentiment["user_sentiments"][user_id] = user_sentiment # Update overall based on active users (simplified access to active_conversations) - active_user_sentiments = [s for uid, s in channel_sentiment["user_sentiments"].items() if uid in cog.active_conversations.get(channel_id, {}).get('participants', set())] + active_user_sentiments = [ + s + for uid, s in channel_sentiment["user_sentiments"].items() + if uid + in cog.active_conversations.get(channel_id, {}).get("participants", set()) + ] if active_user_sentiments: sentiment_counts = defaultdict(int) - for s in active_user_sentiments: sentiment_counts[s["sentiment"]] += 1 + for s in active_user_sentiments: + sentiment_counts[s["sentiment"]] += 1 dominant_sentiment = max(sentiment_counts.items(), key=lambda x: x[1])[0] - avg_intensity = sum(s["intensity"] for s in active_user_sentiments if s["sentiment"] == dominant_sentiment) / sentiment_counts[dominant_sentiment] + avg_intensity = ( + sum( + s["intensity"] + for s in active_user_sentiments + if s["sentiment"] == dominant_sentiment + ) + / sentiment_counts[dominant_sentiment] + ) - prev_sentiment = channel_sentiment["overall"]; prev_intensity = channel_sentiment["intensity"] + prev_sentiment = channel_sentiment["overall"] + prev_intensity = channel_sentiment["intensity"] if dominant_sentiment == prev_sentiment: - if avg_intensity > prev_intensity + 0.1: channel_sentiment["recent_trend"] = "intensifying" - elif avg_intensity < prev_intensity - 0.1: channel_sentiment["recent_trend"] = "diminishing" - else: channel_sentiment["recent_trend"] = "stable" - else: channel_sentiment["recent_trend"] = "changing" + if avg_intensity > prev_intensity + 0.1: + channel_sentiment["recent_trend"] = "intensifying" + elif avg_intensity < prev_intensity - 0.1: + channel_sentiment["recent_trend"] = "diminishing" + else: + channel_sentiment["recent_trend"] = "stable" + else: + channel_sentiment["recent_trend"] = "changing" channel_sentiment["overall"] = dominant_sentiment channel_sentiment["intensity"] = avg_intensity channel_sentiment["last_update"] = now # No need to reassign cog.conversation_sentiment[channel_id] as it's modified in place + # --- Proactive Goal Creation --- -async def proactively_create_goals(cog: 'GurtCog'): + +async def proactively_create_goals(cog: "GurtCog"): """ Analyzes Gurt's current state, environment, and recent interactions to determine if any new goals should be created autonomously. @@ -709,7 +1294,9 @@ async def proactively_create_goals(cog: 'GurtCog'): # Goal: "Generate daily summary for channel [channel_id]." # For now, just log that the check happened. - logger.info("Proactive goal creation check complete (Placeholder - no goals created).") + logger.info( + "Proactive goal creation check complete (Placeholder - no goals created)." + ) # Example of adding a goal (if logic determined one was needed): # if should_create_goal: diff --git a/gurt/api.py b/gurt/api.py index 234482c..cea93d7 100644 --- a/gurt/api.py +++ b/gurt/api.py @@ -1,10 +1,11 @@ from collections import deque import ssl import certifi -import imghdr # Added for robust image MIME type detection +import imghdr # Added for robust image MIME type detection from .config import CONTEXT_WINDOW_SIZE + def patch_ssl_certifi(): original_create_default_context = ssl.create_default_context @@ -15,6 +16,7 @@ def patch_ssl_certifi(): ssl.create_default_context = custom_ssl_context + patch_ssl_certifi() import discord @@ -25,15 +27,27 @@ import base64 import re import time import datetime -from typing import TYPE_CHECKING, Optional, List, Dict, Any, Union, AsyncIterable, Tuple # Import Tuple -import jsonschema # For manual JSON validation +from typing import ( + TYPE_CHECKING, + Optional, + List, + Dict, + Any, + Union, + AsyncIterable, + Tuple, +) # Import Tuple +import jsonschema # For manual JSON validation from .tools import get_conversation_summary # Google Generative AI Imports (using Vertex AI backend) # try: from google import genai from google.genai import types -from google.api_core import exceptions as google_exceptions # Keep for retry logic if applicable +from google.api_core import ( + exceptions as google_exceptions, +) # Keep for retry logic if applicable + # except ImportError: # print("WARNING: google-generativeai or google-api-core not installed. API calls will fail.") # # Define dummy classes/exceptions if library isn't installed @@ -152,19 +166,35 @@ from google.api_core import exceptions as google_exceptions # Keep for retry log # Relative imports for components within the 'gurt' package from .config import ( - PROJECT_ID, LOCATION, DEFAULT_MODEL, FALLBACK_MODEL, CUSTOM_TUNED_MODEL_ENDPOINT, EMOJI_STICKER_DESCRIPTION_MODEL, # Import the new endpoint and model - API_TIMEOUT, API_RETRY_ATTEMPTS, API_RETRY_DELAY, TOOLS, RESPONSE_SCHEMA, - PROACTIVE_PLAN_SCHEMA, # Import the new schema - TAVILY_API_KEY, PISTON_API_URL, PISTON_API_KEY, BASELINE_PERSONALITY, TENOR_API_KEY # Import other needed configs + PROJECT_ID, + LOCATION, + DEFAULT_MODEL, + FALLBACK_MODEL, + CUSTOM_TUNED_MODEL_ENDPOINT, + EMOJI_STICKER_DESCRIPTION_MODEL, # Import the new endpoint and model + API_TIMEOUT, + API_RETRY_ATTEMPTS, + API_RETRY_DELAY, + TOOLS, + RESPONSE_SCHEMA, + PROACTIVE_PLAN_SCHEMA, # Import the new schema + TAVILY_API_KEY, + PISTON_API_URL, + PISTON_API_KEY, + BASELINE_PERSONALITY, + TENOR_API_KEY, # Import other needed configs ) from .prompt import build_dynamic_system_prompt -from .context import gather_conversation_context, get_memory_context # Renamed functions -from .tools import TOOL_MAPPING # Import tool mapping -from .utils import format_message, log_internal_api_call # Import utilities -import copy # Needed for deep copying schemas +from .context import ( + gather_conversation_context, + get_memory_context, +) # Renamed functions +from .tools import TOOL_MAPPING # Import tool mapping +from .utils import format_message, log_internal_api_call # Import utilities +import copy # Needed for deep copying schemas if TYPE_CHECKING: - from .cog import GurtCog # Import GurtCog for type hinting only + from .cog import GurtCog # Import GurtCog for type hinting only # --- Schema Preprocessing Helper --- @@ -181,37 +211,52 @@ def _preprocess_schema_for_vertex(schema: Dict[str, Any]) -> Dict[str, Any]: A new, preprocessed schema dictionary. """ if not isinstance(schema, dict): - return schema # Return non-dict elements as is + return schema # Return non-dict elements as is - processed_schema = copy.deepcopy(schema) # Work on a copy + processed_schema = copy.deepcopy(schema) # Work on a copy for key, value in processed_schema.items(): if key == "type" and isinstance(value, list): # Find the first non-"null" type in the list - first_valid_type = next((t for t in value if isinstance(t, str) and t.lower() != "null"), None) + first_valid_type = next( + (t for t in value if isinstance(t, str) and t.lower() != "null"), None + ) if first_valid_type: processed_schema[key] = first_valid_type else: # Fallback if only "null" or invalid types are present (shouldn't happen in valid schemas) - processed_schema[key] = "object" # Or handle as error - print(f"Warning: Schema preprocessing found list type '{value}' with no valid non-null string type. Falling back to 'object'.") + processed_schema[key] = "object" # Or handle as error + print( + f"Warning: Schema preprocessing found list type '{value}' with no valid non-null string type. Falling back to 'object'." + ) elif isinstance(value, dict): - processed_schema[key] = _preprocess_schema_for_vertex(value) # Recurse for nested objects + processed_schema[key] = _preprocess_schema_for_vertex( + value + ) # Recurse for nested objects elif isinstance(value, list): # Recurse for items within arrays (e.g., in 'properties' of array items) - processed_schema[key] = [_preprocess_schema_for_vertex(item) if isinstance(item, dict) else item for item in value] + processed_schema[key] = [ + _preprocess_schema_for_vertex(item) if isinstance(item, dict) else item + for item in value + ] # Handle 'properties' specifically elif key == "properties" and isinstance(value, dict): - processed_schema[key] = {prop_key: _preprocess_schema_for_vertex(prop_value) for prop_key, prop_value in value.items()} + processed_schema[key] = { + prop_key: _preprocess_schema_for_vertex(prop_value) + for prop_key, prop_value in value.items() + } # Handle 'items' specifically if it's a schema object elif key == "items" and isinstance(value, dict): - processed_schema[key] = _preprocess_schema_for_vertex(value) - + processed_schema[key] = _preprocess_schema_for_vertex(value) return processed_schema + + # --- Helper Function to Safely Extract Text --- # Updated to handle google.generativeai.types.GenerateContentResponse -def _get_response_text(response: Optional[types.GenerateContentResponse]) -> Optional[str]: +def _get_response_text( + response: Optional[types.GenerateContentResponse], +) -> Optional[str]: """ Safely extracts the text content from the first text part of a GenerateContentResponse. Handles potential errors and lack of text parts gracefully. @@ -221,14 +266,16 @@ def _get_response_text(response: Optional[types.GenerateContentResponse]) -> Opt return None # Check if response has the 'text' attribute directly (common case for simple text responses) - if hasattr(response, 'text') and response.text: + if hasattr(response, "text") and response.text: print("[_get_response_text] Found text directly in response.text attribute.") return response.text # If no direct text, check candidates if not response.candidates: # Log the response object itself for debugging if it exists but has no candidates - print(f"[_get_response_text] Response object has no candidates. Response: {response}") + print( + f"[_get_response_text] Response object has no candidates. Response: {response}" + ) return None try: @@ -236,37 +283,50 @@ def _get_response_text(response: Optional[types.GenerateContentResponse]) -> Opt candidate = response.candidates[0] # Check candidate.content and candidate.content.parts - if not hasattr(candidate, 'content') or not candidate.content: - print(f"[_get_response_text] Candidate 0 has no 'content'. Candidate: {candidate}") + if not hasattr(candidate, "content") or not candidate.content: + print( + f"[_get_response_text] Candidate 0 has no 'content'. Candidate: {candidate}" + ) return None - if not hasattr(candidate.content, 'parts') or not candidate.content.parts: - print(f"[_get_response_text] Candidate 0 content has no 'parts' or parts list is empty. types.Content: {candidate.content}") + if not hasattr(candidate.content, "parts") or not candidate.content.parts: + print( + f"[_get_response_text] Candidate 0 content has no 'parts' or parts list is empty. types.Content: {candidate.content}" + ) return None # Log parts for debugging - print(f"[_get_response_text] Inspecting parts in candidate 0: {candidate.content.parts}") + print( + f"[_get_response_text] Inspecting parts in candidate 0: {candidate.content.parts}" + ) # Iterate through parts to find the first text part for i, part in enumerate(candidate.content.parts): # Check if the part has a 'text' attribute and it's not empty/None - if hasattr(part, 'text') and part.text is not None: # Check for None explicitly - # Check if text is non-empty string after stripping whitespace - if isinstance(part.text, str) and part.text.strip(): - print(f"[_get_response_text] Found non-empty text in part {i}.") - return part.text - else: - print(f"[_get_response_text] types.Part {i} has 'text' attribute, but it's empty or not a string: {part.text!r}") + if ( + hasattr(part, "text") and part.text is not None + ): # Check for None explicitly + # Check if text is non-empty string after stripping whitespace + if isinstance(part.text, str) and part.text.strip(): + print(f"[_get_response_text] Found non-empty text in part {i}.") + return part.text + else: + print( + f"[_get_response_text] types.Part {i} has 'text' attribute, but it's empty or not a string: {part.text!r}" + ) # else: # print(f"[_get_response_text] types.Part {i} does not have 'text' attribute or it's None.") - # If no text part is found after checking all parts in the first candidate - print(f"[_get_response_text] No usable text part found in candidate 0 after iterating through all parts.") + print( + f"[_get_response_text] No usable text part found in candidate 0 after iterating through all parts." + ) return None except (AttributeError, IndexError, TypeError) as e: # Handle cases where structure is unexpected, list is empty, or types are wrong - print(f"[_get_response_text] Error accessing response structure: {type(e).__name__}: {e}") + print( + f"[_get_response_text] Error accessing response structure: {type(e).__name__}: {e}" + ) # Log the problematic response object for deeper inspection print(f"Problematic response object: {response}") return None @@ -292,7 +352,7 @@ def _format_embeds_for_prompt(embed_content: List[Dict[str, Any]]) -> Optional[s parts.append(f"Title: {embed['title']}") if embed.get("description"): # Limit description length - desc = embed['description'] + desc = embed["description"] max_desc_len = 200 if len(desc) > max_desc_len: desc = desc[:max_desc_len] + "..." @@ -300,8 +360,8 @@ def _format_embeds_for_prompt(embed_content: List[Dict[str, Any]]) -> Optional[s if embed.get("fields"): field_parts = [] for field in embed["fields"]: - fname = field.get('name', 'Field') - fvalue = field.get('value', '') + fname = field.get("name", "Field") + fvalue = field.get("value", "") # Limit field value length max_field_len = 100 if len(fvalue) > max_field_len: @@ -312,14 +372,19 @@ def _format_embeds_for_prompt(embed_content: List[Dict[str, Any]]) -> Optional[s if embed.get("footer") and embed["footer"].get("text"): parts.append(f"Footer: {embed['footer']['text']}") if embed.get("image_url"): - parts.append(f"[Image Attached: {embed.get('image_url')}]") # Indicate image presence + parts.append( + f"[Image Attached: {embed.get('image_url')}]" + ) # Indicate image presence if embed.get("thumbnail_url"): - parts.append(f"[Thumbnail Attached: {embed.get('thumbnail_url')}]") # Indicate thumbnail presence + parts.append( + f"[Thumbnail Attached: {embed.get('thumbnail_url')}]" + ) # Indicate thumbnail presence formatted_strings.append("\n".join(parts)) return "\n".join(formatted_strings) if formatted_strings else None + # --- Initialize Google Generative AI Client for Vertex AI --- # No explicit genai.configure(api_key=...) needed when using Vertex AI backend try: @@ -329,7 +394,9 @@ try: location=LOCATION, ) - print(f"Google GenAI Client initialized for Vertex AI project '{PROJECT_ID}' in location '{LOCATION}'.") + print( + f"Google GenAI Client initialized for Vertex AI project '{PROJECT_ID}' in location '{LOCATION}'." + ) except NameError: genai_client = None print("Google GenAI SDK (genai) not imported, skipping client initialization.") @@ -341,21 +408,34 @@ except Exception as e: # Define standard safety settings using google.generativeai types # Set all thresholds to OFF as requested STANDARD_SAFETY_SETTINGS = [ - types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold="BLOCK_NONE"), - types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold="BLOCK_NONE"), - types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold="BLOCK_NONE"), - types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold="BLOCK_NONE"), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold="BLOCK_NONE" + ), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold="BLOCK_NONE", + ), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold="BLOCK_NONE", + ), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold="BLOCK_NONE" + ), ] + # --- API Call Helper --- async def call_google_genai_api_with_retry( - cog: 'GurtCog', - model_name: str, # Pass model name string instead of model object - contents: List[types.Content], # Use types.Content type from google.generativeai.types - generation_config: types.GenerateContentConfig, # Combined config object + cog: "GurtCog", + model_name: str, # Pass model name string instead of model object + contents: List[ + types.Content + ], # Use types.Content type from google.generativeai.types + generation_config: types.GenerateContentConfig, # Combined config object request_desc: str, # Removed safety_settings, tools, tool_config as separate params -) -> Optional[types.GenerateContentResponse]: # Return type for non-streaming +) -> Optional[types.GenerateContentResponse]: # Return type for non-streaming """ Calls the Google Generative AI API (Vertex AI backend) with retry logic (non-streaming). @@ -382,24 +462,26 @@ async def call_google_genai_api_with_retry( # Note: model_name should include the 'projects/.../locations/.../endpoints/...' path for custom models # or just 'models/model-name' for standard models. try: - model = "projects/1079377687568/locations/us-central1/endpoints/6677946543460319232" # Use get_model to ensure it exists + model = "projects/1079377687568/locations/us-central1/endpoints/6677946543460319232" # Use get_model to ensure it exists if not model: - raise ValueError(f"Could not retrieve model: {model_name}") + raise ValueError(f"Could not retrieve model: {model_name}") except Exception as model_e: - print(f"Error retrieving model '{model_name}': {model_e}") - raise # Re-raise the exception as this is a fundamental setup issue + print(f"Error retrieving model '{model_name}': {model_e}") + raise # Re-raise the exception as this is a fundamental setup issue for attempt in range(API_RETRY_ATTEMPTS + 1): try: # Use the actual model name string passed to the function for logging - print(f"Sending API request for {request_desc} using {model_name} (Attempt {attempt + 1}/{API_RETRY_ATTEMPTS + 1})...") + print( + f"Sending API request for {request_desc} using {model_name} (Attempt {attempt + 1}/{API_RETRY_ATTEMPTS + 1})..." + ) # Use the non-streaming async call - config now contains all settings # The 'model' parameter here should be the actual model name string response = await genai_client.aio.models.generate_content( - model=model_name, # Use the model_name string directly + model=model_name, # Use the model_name string directly contents=contents, - config=generation_config, # Pass the combined config object + config=generation_config, # Pass the combined config object # stream=False is implicit for generate_content ) @@ -407,28 +489,55 @@ async def call_google_genai_api_with_retry( # Access finish reason and safety ratings from the response object if response and response.candidates: candidate = response.candidates[0] - finish_reason = getattr(candidate, 'finish_reason', None) - safety_ratings = getattr(candidate, 'safety_ratings', []) + finish_reason = getattr(candidate, "finish_reason", None) + safety_ratings = getattr(candidate, "safety_ratings", []) if finish_reason == types.FinishReason.SAFETY: - safety_ratings_str = ", ".join([f"{rating.category.name}: {rating.probability.name}" for rating in safety_ratings]) if safety_ratings else "N/A" + safety_ratings_str = ( + ", ".join( + [ + f"{rating.category.name}: {rating.probability.name}" + for rating in safety_ratings + ] + ) + if safety_ratings + else "N/A" + ) # Optionally, raise a specific exception here if needed downstream # raise SafetyBlockError(f"Blocked by safety filters. Ratings: {safety_ratings_str}") - elif finish_reason not in [types.FinishReason.STOP, types.FinishReason.MAX_TOKENS, None]: # Allow None finish reason - # Log other unexpected finish reasons - finish_reason_name = types.FinishReason(finish_reason).name if isinstance(finish_reason, int) else str(finish_reason) - print(f"⚠️ UNEXPECTED FINISH REASON: API request for {request_desc} ({model_name}) finished with reason: {finish_reason_name}") + elif finish_reason not in [ + types.FinishReason.STOP, + types.FinishReason.MAX_TOKENS, + None, + ]: # Allow None finish reason + # Log other unexpected finish reasons + finish_reason_name = ( + types.FinishReason(finish_reason).name + if isinstance(finish_reason, int) + else str(finish_reason) + ) + print( + f"⚠️ UNEXPECTED FINISH REASON: API request for {request_desc} ({model_name}) finished with reason: {finish_reason_name}" + ) # --- Success Logging (Proceed even if safety blocked, but log occurred) --- elapsed_time = time.monotonic() - start_time # Ensure model_name exists in stats before incrementing if model_name not in cog.api_stats: - cog.api_stats[model_name] = {'success': 0, 'failure': 0, 'retries': 0, 'total_time': 0.0, 'count': 0} - cog.api_stats[model_name]['success'] += 1 - cog.api_stats[model_name]['total_time'] += elapsed_time - cog.api_stats[model_name]['count'] += 1 - print(f"API request successful for {request_desc} ({model_name}) in {elapsed_time:.2f}s.") - return response # Success + cog.api_stats[model_name] = { + "success": 0, + "failure": 0, + "retries": 0, + "total_time": 0.0, + "count": 0, + } + cog.api_stats[model_name]["success"] += 1 + cog.api_stats[model_name]["total_time"] += elapsed_time + cog.api_stats[model_name]["count"] += 1 + print( + f"API request successful for {request_desc} ({model_name}) in {elapsed_time:.2f}s." + ) + return response # Success # Adapt exception handling if google.generativeai raises different types # google.api_core.exceptions should still cover many common API errors @@ -438,55 +547,83 @@ async def call_google_genai_api_with_retry( last_exception = e if attempt < API_RETRY_ATTEMPTS: if model_name not in cog.api_stats: - cog.api_stats[model_name] = {'success': 0, 'failure': 0, 'retries': 0, 'total_time': 0.0, 'count': 0} - cog.api_stats[model_name]['retries'] += 1 - wait_time = API_RETRY_DELAY * (2 ** attempt) # Exponential backoff + cog.api_stats[model_name] = { + "success": 0, + "failure": 0, + "retries": 0, + "total_time": 0.0, + "count": 0, + } + cog.api_stats[model_name]["retries"] += 1 + wait_time = API_RETRY_DELAY * (2**attempt) # Exponential backoff print(f"Waiting {wait_time:.2f} seconds before retrying...") await asyncio.sleep(wait_time) continue else: - break # Max retries reached + break # Max retries reached - except (google_exceptions.InternalServerError, google_exceptions.ServiceUnavailable) as e: + except ( + google_exceptions.InternalServerError, + google_exceptions.ServiceUnavailable, + ) as e: error_msg = f"API server error ({type(e).__name__}) for {request_desc} ({model_name}): {e}" print(f"{error_msg} (Attempt {attempt + 1})") last_exception = e if attempt < API_RETRY_ATTEMPTS: if model_name not in cog.api_stats: - cog.api_stats[model_name] = {'success': 0, 'failure': 0, 'retries': 0, 'total_time': 0.0, 'count': 0} - cog.api_stats[model_name]['retries'] += 1 - wait_time = API_RETRY_DELAY * (2 ** attempt) # Exponential backoff + cog.api_stats[model_name] = { + "success": 0, + "failure": 0, + "retries": 0, + "total_time": 0.0, + "count": 0, + } + cog.api_stats[model_name]["retries"] += 1 + wait_time = API_RETRY_DELAY * (2**attempt) # Exponential backoff print(f"Waiting {wait_time:.2f} seconds before retrying...") await asyncio.sleep(wait_time) continue else: - break # Max retries reached + break # Max retries reached except google_exceptions.InvalidArgument as e: # Often indicates a problem with the request itself (e.g., bad schema, unsupported format, invalid model name) error_msg = f"Invalid argument error for {request_desc} ({model_name}): {e}" print(error_msg) last_exception = e - break # Non-retryable + break # Non-retryable - except asyncio.TimeoutError: # Handle potential client-side timeouts if applicable - error_msg = f"Client-side request timed out for {request_desc} ({model_name}) (Attempt {attempt + 1})" - print(error_msg) - last_exception = asyncio.TimeoutError(error_msg) - # Decide if client-side timeouts should be retried - if attempt < API_RETRY_ATTEMPTS: - if model_name not in cog.api_stats: - cog.api_stats[model_name] = {'success': 0, 'failure': 0, 'retries': 0, 'total_time': 0.0, 'count': 0} - cog.api_stats[model_name]['retries'] += 1 - await asyncio.sleep(API_RETRY_DELAY * (attempt + 1)) # Linear backoff for timeout? Or keep exponential? - continue - else: - break + except ( + asyncio.TimeoutError + ): # Handle potential client-side timeouts if applicable + error_msg = f"Client-side request timed out for {request_desc} ({model_name}) (Attempt {attempt + 1})" + print(error_msg) + last_exception = asyncio.TimeoutError(error_msg) + # Decide if client-side timeouts should be retried + if attempt < API_RETRY_ATTEMPTS: + if model_name not in cog.api_stats: + cog.api_stats[model_name] = { + "success": 0, + "failure": 0, + "retries": 0, + "total_time": 0.0, + "count": 0, + } + cog.api_stats[model_name]["retries"] += 1 + await asyncio.sleep( + API_RETRY_DELAY * (attempt + 1) + ) # Linear backoff for timeout? Or keep exponential? + continue + else: + break - except Exception as e: # Catch other potential exceptions (e.g., from genai library itself) + except ( + Exception + ) as e: # Catch other potential exceptions (e.g., from genai library itself) error_msg = f"Unexpected error during API call for {request_desc} ({model_name}) (Attempt {attempt + 1}): {type(e).__name__}: {e}" print(error_msg) import traceback + traceback.print_exc() last_exception = e # Decide if this generic exception is retryable @@ -496,21 +633,29 @@ async def call_google_genai_api_with_retry( # --- Failure Logging --- elapsed_time = time.monotonic() - start_time if model_name not in cog.api_stats: - cog.api_stats[model_name] = {'success': 0, 'failure': 0, 'retries': 0, 'total_time': 0.0, 'count': 0} - cog.api_stats[model_name]['failure'] += 1 - cog.api_stats[model_name]['total_time'] += elapsed_time - cog.api_stats[model_name]['count'] += 1 - print(f"API request failed for {request_desc} ({model_name}) after {attempt + 1} attempts in {elapsed_time:.2f}s.") + cog.api_stats[model_name] = { + "success": 0, + "failure": 0, + "retries": 0, + "total_time": 0.0, + "count": 0, + } + cog.api_stats[model_name]["failure"] += 1 + cog.api_stats[model_name]["total_time"] += elapsed_time + cog.api_stats[model_name]["count"] += 1 + print( + f"API request failed for {request_desc} ({model_name}) after {attempt + 1} attempts in {elapsed_time:.2f}s." + ) # Raise the last encountered exception or a generic one - raise last_exception or Exception(f"API request failed for {request_desc} after {API_RETRY_ATTEMPTS + 1} attempts.") + raise last_exception or Exception( + f"API request failed for {request_desc} after {API_RETRY_ATTEMPTS + 1} attempts." + ) # --- JSON Parsing and Validation Helper --- def parse_and_validate_json_response( - response_text: Optional[str], - schema: Dict[str, Any], - context_description: str + response_text: Optional[str], schema: Dict[str, Any], context_description: str ) -> Optional[Dict[str, Any]]: """ Parses the AI's response text, attempting to extract and validate a JSON object against a schema. @@ -528,53 +673,75 @@ def parse_and_validate_json_response( return None parsed_data = None - raw_json_text = response_text # Start with the full text + raw_json_text = response_text # Start with the full text # Attempt 1: Try parsing the whole string directly try: parsed_data = json.loads(raw_json_text) - print(f"Parsing ({context_description}): Successfully parsed entire response as JSON.") + print( + f"Parsing ({context_description}): Successfully parsed entire response as JSON." + ) except json.JSONDecodeError: # Attempt 2: Extract JSON object, handling optional markdown fences # More robust regex to handle potential leading/trailing text and variations - json_match = re.search(r'```(?:json)?\s*(\{.*\})\s*```|(\{.*\})', response_text, re.DOTALL | re.MULTILINE) + json_match = re.search( + r"```(?:json)?\s*(\{.*\})\s*```|(\{.*\})", + response_text, + re.DOTALL | re.MULTILINE, + ) if json_match: json_str = json_match.group(1) or json_match.group(2) if json_str: - raw_json_text = json_str # Use the extracted string for parsing + raw_json_text = json_str # Use the extracted string for parsing try: parsed_data = json.loads(raw_json_text) - print(f"Parsing ({context_description}): Successfully extracted and parsed JSON using regex.") + print( + f"Parsing ({context_description}): Successfully extracted and parsed JSON using regex." + ) except json.JSONDecodeError as e_inner: - print(f"Parsing ({context_description}): Regex found potential JSON, but it failed to parse: {e_inner}\nContent: {raw_json_text[:500]}") + print( + f"Parsing ({context_description}): Regex found potential JSON, but it failed to parse: {e_inner}\nContent: {raw_json_text[:500]}" + ) parsed_data = None else: - print(f"Parsing ({context_description}): Regex matched, but failed to capture JSON content.") + print( + f"Parsing ({context_description}): Regex matched, but failed to capture JSON content." + ) parsed_data = None else: - print(f"Parsing ({context_description}): Could not parse directly or extract JSON object using regex.\nContent: {raw_json_text[:500]}") + print( + f"Parsing ({context_description}): Could not parse directly or extract JSON object using regex.\nContent: {raw_json_text[:500]}" + ) parsed_data = None # Validation step if parsed_data is not None: if not isinstance(parsed_data, dict): - print(f"Parsing ({context_description}): Parsed data is not a dictionary: {type(parsed_data)}") - return None # Fail validation if not a dict + print( + f"Parsing ({context_description}): Parsed data is not a dictionary: {type(parsed_data)}" + ) + return None # Fail validation if not a dict try: jsonschema.validate(instance=parsed_data, schema=schema) - print(f"Parsing ({context_description}): JSON successfully validated against schema.") + print( + f"Parsing ({context_description}): JSON successfully validated against schema." + ) # Ensure default keys exist after validation parsed_data.setdefault("should_respond", False) parsed_data.setdefault("content", None) parsed_data.setdefault("react_with_emoji", None) return parsed_data except jsonschema.ValidationError as e: - print(f"Parsing ({context_description}): JSON failed schema validation: {e.message}") + print( + f"Parsing ({context_description}): JSON failed schema validation: {e.message}" + ) # Optionally log more details: e.path, e.schema_path, e.instance - return None # Validation failed - except Exception as e: # Catch other potential validation errors - print(f"Parsing ({context_description}): Unexpected error during JSON schema validation: {e}") + return None # Validation failed + except Exception as e: # Catch other potential validation errors + print( + f"Parsing ({context_description}): Unexpected error during JSON schema validation: {e}" + ) return None else: # Parsing failed before validation could occur @@ -583,7 +750,9 @@ def parse_and_validate_json_response( # --- Tool Processing --- # Updated to use google.generativeai types -async def process_requested_tools(cog: 'GurtCog', function_call: types.FunctionCall) -> List[types.Part]: # Return type is List +async def process_requested_tools( + cog: "GurtCog", function_call: types.FunctionCall +) -> List[types.Part]: # Return type is List """ Process a tool request specified by the AI's FunctionCall response. Returns a list of types.Part objects (usually one, but potentially more if an image URL is detected in the result). @@ -605,58 +774,91 @@ async def process_requested_tools(cog: 'GurtCog', function_call: types.FunctionC # --- Tool Success Logging --- tool_elapsed_time = time.monotonic() - tool_start_time if function_name not in cog.tool_stats: - cog.tool_stats[function_name] = {'success': 0, 'failure': 0, 'total_time': 0.0, 'count': 0} - cog.tool_stats[function_name]['success'] += 1 - cog.tool_stats[function_name]['total_time'] += tool_elapsed_time - cog.tool_stats[function_name]['count'] += 1 - print(f"Tool '{function_name}' executed successfully in {tool_elapsed_time:.2f}s.") + cog.tool_stats[function_name] = { + "success": 0, + "failure": 0, + "total_time": 0.0, + "count": 0, + } + cog.tool_stats[function_name]["success"] += 1 + cog.tool_stats[function_name]["total_time"] += tool_elapsed_time + cog.tool_stats[function_name]["count"] += 1 + print( + f"Tool '{function_name}' executed successfully in {tool_elapsed_time:.2f}s." + ) # Ensure result is a dict, converting if necessary if not isinstance(result_dict, dict): - if isinstance(result_dict, (str, int, float, bool, list)) or result_dict is None: + if ( + isinstance(result_dict, (str, int, float, bool, list)) + or result_dict is None + ): result_dict = {"result": result_dict} else: - print(f"Warning: Tool '{function_name}' returned non-standard type {type(result_dict)}. Attempting str conversion.") + print( + f"Warning: Tool '{function_name}' returned non-standard type {type(result_dict)}. Attempting str conversion." + ) result_dict = {"result": str(result_dict)} - tool_result_content = result_dict # Now guaranteed to be a dict + tool_result_content = result_dict # Now guaranteed to be a dict except Exception as e: # --- Tool Failure Logging --- - tool_elapsed_time = time.monotonic() - tool_start_time # Recalculate time even on failure + tool_elapsed_time = ( + time.monotonic() - tool_start_time + ) # Recalculate time even on failure if function_name not in cog.tool_stats: - cog.tool_stats[function_name] = {'success': 0, 'failure': 0, 'total_time': 0.0, 'count': 0} - cog.tool_stats[function_name]['failure'] += 1 - cog.tool_stats[function_name]['total_time'] += tool_elapsed_time - cog.tool_stats[function_name]['count'] += 1 - error_message = f"Error executing tool {function_name}: {type(e).__name__}: {str(e)}" + cog.tool_stats[function_name] = { + "success": 0, + "failure": 0, + "total_time": 0.0, + "count": 0, + } + cog.tool_stats[function_name]["failure"] += 1 + cog.tool_stats[function_name]["total_time"] += tool_elapsed_time + cog.tool_stats[function_name]["count"] += 1 + error_message = ( + f"Error executing tool {function_name}: {type(e).__name__}: {str(e)}" + ) print(f"{error_message} (Took {tool_elapsed_time:.2f}s)") import traceback - traceback.print_exc() - tool_result_content = {"error": error_message} # Ensure it's a dict even on error - else: # This 'else' corresponds to 'if function_name in TOOL_MAPPING:' + traceback.print_exc() + tool_result_content = { + "error": error_message + } # Ensure it's a dict even on error + + else: # This 'else' corresponds to 'if function_name in TOOL_MAPPING:' # --- Tool Not Found Logging --- - tool_elapsed_time = time.monotonic() - tool_start_time # Time for the failed lookup + tool_elapsed_time = ( + time.monotonic() - tool_start_time + ) # Time for the failed lookup if function_name not in cog.tool_stats: - cog.tool_stats[function_name] = {'success': 0, 'failure': 0, 'total_time': 0.0, 'count': 0} - cog.tool_stats[function_name]['failure'] += 1 # Count as failure - cog.tool_stats[function_name]['total_time'] += tool_elapsed_time - cog.tool_stats[function_name]['count'] += 1 + cog.tool_stats[function_name] = { + "success": 0, + "failure": 0, + "total_time": 0.0, + "count": 0, + } + cog.tool_stats[function_name]["failure"] += 1 # Count as failure + cog.tool_stats[function_name]["total_time"] += tool_elapsed_time + cog.tool_stats[function_name]["count"] += 1 error_message = f"Tool '{function_name}' not found or implemented." print(f"{error_message} (Took {tool_elapsed_time:.2f}s)") - tool_result_content = {"error": error_message} # Ensure it's a dict + tool_result_content = {"error": error_message} # Ensure it's a dict # --- Process result for potential image URLs --- parts_to_return: List[types.Part] = [] - original_image_url: Optional[str] = None # Store the original URL if found - modified_result_content = copy.deepcopy(tool_result_content) # Work on a copy + original_image_url: Optional[str] = None # Store the original URL if found + modified_result_content = copy.deepcopy(tool_result_content) # Work on a copy # --- Image URL Detection & Modification --- # Check specific tools and keys known to contain image URLs # Special handling for get_user_avatar_data to directly use its base64 output - if function_name == "get_user_avatar_data" and isinstance(modified_result_content, dict): + if function_name == "get_user_avatar_data" and isinstance( + modified_result_content, dict + ): base64_image_data = modified_result_content.get("base64_data") image_mime_type = modified_result_content.get("content_type") @@ -664,82 +866,146 @@ async def process_requested_tools(cog: 'GurtCog', function_call: types.FunctionC try: image_bytes = base64.b64decode(base64_image_data) # Validate MIME type (optional, but good practice) - supported_image_mimes = ["image/png", "image/jpeg", "image/webp", "image/heic", "image/heif"] - clean_mime_type = image_mime_type.split(';')[0].lower() + supported_image_mimes = [ + "image/png", + "image/jpeg", + "image/webp", + "image/heic", + "image/heif", + ] + clean_mime_type = image_mime_type.split(";")[0].lower() if clean_mime_type in supported_image_mimes: # Corrected: Use inline_data for raw bytes - image_part = types.Part(inline_data=types.Blob(data=image_bytes, mime_type=clean_mime_type)) - parts_to_return.append(image_part) # Corrected: Add to parts_to_return for this tool's response - print(f"Added image part directly from get_user_avatar_data (MIME: {clean_mime_type}, {len(image_bytes)} bytes).") + image_part = types.Part( + inline_data=types.Blob( + data=image_bytes, mime_type=clean_mime_type + ) + ) + parts_to_return.append( + image_part + ) # Corrected: Add to parts_to_return for this tool's response + print( + f"Added image part directly from get_user_avatar_data (MIME: {clean_mime_type}, {len(image_bytes)} bytes)." + ) # Replace base64_data in the textual response to avoid sending it twice - modified_result_content["base64_data"] = "[Image Content Attached In Prompt]" - modified_result_content["content_type"] = f"[MIME type: {clean_mime_type} - Content Attached In Prompt]" + modified_result_content["base64_data"] = ( + "[Image Content Attached In Prompt]" + ) + modified_result_content["content_type"] = ( + f"[MIME type: {clean_mime_type} - Content Attached In Prompt]" + ) else: - print(f"Warning: MIME type '{clean_mime_type}' from get_user_avatar_data not in supported list. Not attaching image part.") - modified_result_content["base64_data"] = "[Image Data Not Attached - Unsupported MIME Type]" + print( + f"Warning: MIME type '{clean_mime_type}' from get_user_avatar_data not in supported list. Not attaching image part." + ) + modified_result_content["base64_data"] = ( + "[Image Data Not Attached - Unsupported MIME Type]" + ) except Exception as e: print(f"Error processing base64 data from get_user_avatar_data: {e}") - modified_result_content["base64_data"] = f"[Error Processing Image Data: {e}]" + modified_result_content["base64_data"] = ( + f"[Error Processing Image Data: {e}]" + ) # Prevent generic URL download logic from re-processing this avatar - original_image_url = None # Explicitly nullify to skip URL download - - elif function_name == "get_user_avatar_url" and isinstance(modified_result_content, dict): + original_image_url = None # Explicitly nullify to skip URL download + + elif function_name == "get_user_avatar_url" and isinstance( + modified_result_content, dict + ): avatar_url_value = modified_result_content.get("avatar_url") if avatar_url_value and isinstance(avatar_url_value, str): - original_image_url = avatar_url_value # Store original - modified_result_content["avatar_url"] = "[Image Content Attached]" # Replace URL with placeholder - elif function_name == "get_user_profile_info" and isinstance(modified_result_content, dict): + original_image_url = avatar_url_value # Store original + modified_result_content["avatar_url"] = ( + "[Image Content Attached]" # Replace URL with placeholder + ) + elif function_name == "get_user_profile_info" and isinstance( + modified_result_content, dict + ): profile_dict = modified_result_content.get("profile") if isinstance(profile_dict, dict): avatar_url_value = profile_dict.get("avatar_url") if avatar_url_value and isinstance(avatar_url_value, str): - original_image_url = avatar_url_value # Store original - profile_dict["avatar_url"] = "[Image Content Attached]" # Replace URL in nested dict + original_image_url = avatar_url_value # Store original + profile_dict["avatar_url"] = ( + "[Image Content Attached]" # Replace URL in nested dict + ) # Add checks for other tools/keys that might return image URLs if necessary # --- Create Parts --- # Always add the function response part (using the potentially modified content) - function_response_part = types.Part(function_response=types.FunctionResponse(name=function_name, response=modified_result_content)) + function_response_part = types.Part( + function_response=types.FunctionResponse( + name=function_name, response=modified_result_content + ) + ) parts_to_return.append(function_response_part) # Add image part if an original URL was found and seems valid - if original_image_url and isinstance(original_image_url, str) and original_image_url.startswith('http'): + if ( + original_image_url + and isinstance(original_image_url, str) + and original_image_url.startswith("http") + ): download_success = False try: # Download the image data using aiohttp session from cog - if not hasattr(cog, 'session') or not cog.session: + if not hasattr(cog, "session") or not cog.session: raise ValueError("aiohttp session not found in cog.") print(f"Downloading image data from URL: {original_image_url}") - async with cog.session.get(original_image_url, timeout=15) as response: # Added timeout + async with cog.session.get( + original_image_url, timeout=15 + ) as response: # Added timeout if response.status == 200: image_bytes = await response.read() - mime_type = response.content_type or "application/octet-stream" # Get MIME type from header + mime_type = ( + response.content_type or "application/octet-stream" + ) # Get MIME type from header # Validate against known supported image types for Gemini - supported_image_mimes = ["image/png", "image/jpeg", "image/webp", "image/heic", "image/heif"] - clean_mime_type = mime_type.split(';')[0].lower() # Clean MIME type + supported_image_mimes = [ + "image/png", + "image/jpeg", + "image/webp", + "image/heic", + "image/heif", + ] + clean_mime_type = mime_type.split(";")[0].lower() # Clean MIME type if clean_mime_type in supported_image_mimes: # Use types.Part.from_data instead of from_uri - image_part = types.Part(inline_data=types.Blob(data=image_bytes, mime_type=clean_mime_type)) + image_part = types.Part( + inline_data=types.Blob( + data=image_bytes, mime_type=clean_mime_type + ) + ) parts_to_return.append(image_part) download_success = True - print(f"Added image part (from data, {len(image_bytes)} bytes, MIME: {clean_mime_type}) from tool '{function_name}' result.") + print( + f"Added image part (from data, {len(image_bytes)} bytes, MIME: {clean_mime_type}) from tool '{function_name}' result." + ) else: - print(f"Warning: Downloaded image MIME type '{clean_mime_type}' from {original_image_url} might not be supported by Gemini. Skipping image part.") + print( + f"Warning: Downloaded image MIME type '{clean_mime_type}' from {original_image_url} might not be supported by Gemini. Skipping image part." + ) else: - print(f"Error downloading image from {original_image_url}: Status {response.status}") + print( + f"Error downloading image from {original_image_url}: Status {response.status}" + ) except asyncio.TimeoutError: - print(f"Error downloading image from {original_image_url}: Request timed out.") + print( + f"Error downloading image from {original_image_url}: Request timed out." + ) except aiohttp.ClientError as client_e: print(f"Error downloading image from {original_image_url}: {client_e}") - except ValueError as val_e: # Catch missing session error + except ValueError as val_e: # Catch missing session error print(f"Error preparing image download: {val_e}") except Exception as e: - print(f"Error downloading or creating image part from data ({original_image_url}): {e}") + print( + f"Error downloading or creating image part from data ({original_image_url}): {e}" + ) # If download or processing failed, add an error note for the LLM if not download_success: @@ -747,30 +1013,38 @@ async def process_requested_tools(cog: 'GurtCog', function_call: types.FunctionC error_text_part = types.Part(text=error_text) parts_to_return.append(error_text_part) - return parts_to_return # Return the list of parts (will contain 1 or 2+ parts) + return parts_to_return # Return the list of parts (will contain 1 or 2+ parts) # --- Helper to find function call in parts --- # Updated to use google.generativeai types -def find_function_call_in_parts(parts: Optional[List[types.Part]]) -> Optional[types.FunctionCall]: +def find_function_call_in_parts( + parts: Optional[List[types.Part]], +) -> Optional[types.FunctionCall]: """Finds the first valid FunctionCall object within a list of Parts.""" if not parts: return None for part in parts: # Check if the part has a 'function_call' attribute and it's a valid FunctionCall object - if hasattr(part, 'function_call') and isinstance(part.function_call, types.FunctionCall): + if hasattr(part, "function_call") and isinstance( + part.function_call, types.FunctionCall + ): # Basic validation: ensure name exists if part.function_call.name: - return part.function_call + return part.function_call else: - print(f"Warning: Found types.Part with 'function_call', but its name is missing: {part.function_call}") + print( + f"Warning: Found types.Part with 'function_call', but its name is missing: {part.function_call}" + ) # else: # print(f"Debug: types.Part does not have valid function_call: {part}") # Optional debug log return None # --- Main AI Response Function --- -async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: Optional[str] = None) -> Tuple[Dict[str, Any], List[str]]: +async def get_ai_response( + cog: "GurtCog", message: discord.Message, model_name: Optional[str] = None +) -> Tuple[Dict[str, Any], List[str]]: """ Gets responses from the Vertex AI Gemini API, handling potential tool usage and returning the final parsed response. @@ -786,14 +1060,14 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: - "error": An error message string if a critical error occurred, otherwise None. - "fallback_initial": Optional minimal response if initial parsing failed critically (less likely with controlled generation). """ - if not PROJECT_ID or not LOCATION or not genai_client: # Check genai_client too - error_msg = "Google Cloud Project ID/Location not configured or GenAI Client failed to initialize." - print(f"Error in get_ai_response: {error_msg}") - return {"final_response": None, "error": error_msg} + if not PROJECT_ID or not LOCATION or not genai_client: # Check genai_client too + error_msg = "Google Cloud Project ID/Location not configured or GenAI Client failed to initialize." + print(f"Error in get_ai_response: {error_msg}") + return {"final_response": None, "error": error_msg} # Determine the model for all generation steps. # Use the model_name override if provided to get_ai_response, otherwise use the cog's current default_model. - active_model = model_name or cog.default_model + active_model = model_name or cog.default_model print(f"Using active model for all generation steps: {active_model}") @@ -802,16 +1076,18 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: # initial_parsed_data is no longer needed with the loop structure final_parsed_data = None error_message = None - fallback_response = None # Keep fallback for critical initial failures - max_tool_calls = 5 # Maximum number of sequential tool calls allowed + fallback_response = None # Keep fallback for critical initial failures + max_tool_calls = 5 # Maximum number of sequential tool calls allowed tool_calls_made = 0 - last_response_obj = None # Store the last response object from the loop + last_response_obj = None # Store the last response object from the loop try: # --- Build Prompt Components --- final_system_prompt = await build_dynamic_system_prompt(cog, message) - conversation_context_messages = gather_conversation_context(cog, channel_id, message.id) # Pass cog - memory_context = await get_memory_context(cog, message) # Pass cog + conversation_context_messages = gather_conversation_context( + cog, channel_id, message.id + ) # Pass cog + memory_context = await get_memory_context(cog, message) # Pass cog # --- Prepare Message History (Contents) --- # Contents will be built progressively within the loop @@ -828,146 +1104,238 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: # but this might confuse the turn structure. # contents.append(types.Content(role="model", parts=[types.Part.from_text(f"System Note: {memory_context}")])) - # Add conversation history # The current message is already included in conversation_context_messages for msg in conversation_context_messages: - role = "assistant" if msg.get('author', {}).get('id') == str(cog.bot.user.id) else "user" # Use get for safety - parts: List[types.Part] = [] # Initialize parts for each message + role = ( + "assistant" + if msg.get("author", {}).get("id") == str(cog.bot.user.id) + else "user" + ) # Use get for safety + parts: List[types.Part] = [] # Initialize parts for each message # Handle potential multimodal content in history (if stored that way) if isinstance(msg.get("content"), list): - # If content is already a list of parts, process them - for part_data in msg["content"]: - if part_data["type"] == "text": - parts.append(types.Part(text=part_data["text"])) - elif part_data["type"] == "image_url": - # Assuming image_url has 'url' and 'mime_type' - parts.append(types.Part(uri=part_data["image_url"]["url"], mime_type=part_data["image_url"]["url"].split(";")[0].split(":")[1])) - # Filter out None parts if any were conditionally added - parts = [p for p in parts if p] + # If content is already a list of parts, process them + for part_data in msg["content"]: + if part_data["type"] == "text": + parts.append(types.Part(text=part_data["text"])) + elif part_data["type"] == "image_url": + # Assuming image_url has 'url' and 'mime_type' + parts.append( + types.Part( + uri=part_data["image_url"]["url"], + mime_type=part_data["image_url"]["url"] + .split(";")[0] + .split(":")[1], + ) + ) + # Filter out None parts if any were conditionally added + parts = [p for p in parts if p] # Combine text, embeds, and attachments for history messages - elif isinstance(msg.get("content"), str) or msg.get("embed_content") or msg.get("attachment_descriptions"): - text_parts = [] - # Add original text content if it exists and is not empty - if isinstance(msg.get("content"), str) and msg["content"].strip(): - text_parts.append(msg["content"]) - # Add formatted embed content if present - embed_str = _format_embeds_for_prompt(msg.get("embed_content", [])) - if embed_str: - text_parts.append(f"\n[Embed Content]:\n{embed_str}") - # Add attachment descriptions if present - if msg.get("attachment_descriptions"): - # Ensure descriptions are strings before joining - attach_desc_list = [a['description'] for a in msg['attachment_descriptions'] if isinstance(a.get('description'), str)] - if attach_desc_list: - attach_desc_str = "\n".join(attach_desc_list) - text_parts.append(f"\n[Attachments]:\n{attach_desc_str}") - - # Add custom emoji and sticker descriptions and images from cache for historical messages - cached_emojis = msg.get("custom_emojis", []) - for emoji_info in cached_emojis: - emoji_name = emoji_info.get("name") - emoji_url = emoji_info.get("url") - if emoji_name: - text_parts.append(f"[Emoji: {emoji_name}]") - if emoji_url and emoji_url.startswith('http'): - # Determine MIME type for emoji URI - is_animated_emoji = emoji_info.get("animated", False) - emoji_mime_type = "image/gif" if is_animated_emoji else "image/png" - try: - # Download emoji data and send as inline_data - async with cog.session.get(emoji_url, timeout=10) as response: - if response.status == 200: - emoji_bytes = await response.read() - parts.append(types.Part(inline_data=types.Blob(data=emoji_bytes, mime_type=emoji_mime_type))) - print(f"Added inline_data part for historical emoji: {emoji_name} (MIME: {emoji_mime_type}, {len(emoji_bytes)} bytes)") - else: - print(f"Error downloading historical emoji {emoji_name} from {emoji_url}: Status {response.status}") - text_parts.append(f"[System Note: Failed to download emoji '{emoji_name}']") - except Exception as e: - print(f"Error downloading/processing historical emoji {emoji_name} from {emoji_url}: {e}") - text_parts.append(f"[System Note: Failed to process emoji '{emoji_name}']") + elif ( + isinstance(msg.get("content"), str) + or msg.get("embed_content") + or msg.get("attachment_descriptions") + ): + text_parts = [] + # Add original text content if it exists and is not empty + if isinstance(msg.get("content"), str) and msg["content"].strip(): + text_parts.append(msg["content"]) + # Add formatted embed content if present + embed_str = _format_embeds_for_prompt(msg.get("embed_content", [])) + if embed_str: + text_parts.append(f"\n[Embed Content]:\n{embed_str}") + # Add attachment descriptions if present + if msg.get("attachment_descriptions"): + # Ensure descriptions are strings before joining + attach_desc_list = [ + a["description"] + for a in msg["attachment_descriptions"] + if isinstance(a.get("description"), str) + ] + if attach_desc_list: + attach_desc_str = "\n".join(attach_desc_list) + text_parts.append(f"\n[Attachments]:\n{attach_desc_str}") - cached_stickers = msg.get("stickers", []) - for sticker_info in cached_stickers: - sticker_name = sticker_info.get("name") - sticker_url = sticker_info.get("url") - sticker_format_str = sticker_info.get("format") - sticker_format_text = f" (Format: {sticker_format_str})" if sticker_format_str else "" - if sticker_name: - text_parts.append(f"[Sticker: {sticker_name}{sticker_format_text}]") - is_image_sticker = sticker_format_str in ["StickerFormatType.png", "StickerFormatType.apng"] - if is_image_sticker and sticker_url and sticker_url.startswith('http'): - sticker_mime_type = "image/png" # APNG is also sent as image/png - try: - # Download sticker data and send as inline_data - async with cog.session.get(sticker_url, timeout=10) as response: - if response.status == 200: - sticker_bytes = await response.read() - parts.append(types.Part(inline_data=types.Blob(data=sticker_bytes, mime_type=sticker_mime_type))) - print(f"Added inline_data part for historical sticker: {sticker_name} (MIME: {sticker_mime_type}, {len(sticker_bytes)} bytes)") - else: - print(f"Error downloading historical sticker {sticker_name} from {sticker_url}: Status {response.status}") - text_parts.append(f"[System Note: Failed to download sticker '{sticker_name}']") - except Exception as e: - print(f"Error downloading/processing historical sticker {sticker_name} from {sticker_url}: {e}") - text_parts.append(f"[System Note: Failed to process sticker '{sticker_name}']") - elif sticker_format_str == "StickerFormatType.lottie": - # Lottie files are JSON, not directly viewable by Gemini as images. Send as text. - text_parts.append(f"[Lottie Sticker: {sticker_name} (JSON animation, not displayed as image)]") + # Add custom emoji and sticker descriptions and images from cache for historical messages + cached_emojis = msg.get("custom_emojis", []) + for emoji_info in cached_emojis: + emoji_name = emoji_info.get("name") + emoji_url = emoji_info.get("url") + if emoji_name: + text_parts.append(f"[Emoji: {emoji_name}]") + if emoji_url and emoji_url.startswith("http"): + # Determine MIME type for emoji URI + is_animated_emoji = emoji_info.get("animated", False) + emoji_mime_type = ( + "image/gif" if is_animated_emoji else "image/png" + ) + try: + # Download emoji data and send as inline_data + async with cog.session.get( + emoji_url, timeout=10 + ) as response: + if response.status == 200: + emoji_bytes = await response.read() + parts.append( + types.Part( + inline_data=types.Blob( + data=emoji_bytes, + mime_type=emoji_mime_type, + ) + ) + ) + print( + f"Added inline_data part for historical emoji: {emoji_name} (MIME: {emoji_mime_type}, {len(emoji_bytes)} bytes)" + ) + else: + print( + f"Error downloading historical emoji {emoji_name} from {emoji_url}: Status {response.status}" + ) + text_parts.append( + f"[System Note: Failed to download emoji '{emoji_name}']" + ) + except Exception as e: + print( + f"Error downloading/processing historical emoji {emoji_name} from {emoji_url}: {e}" + ) + text_parts.append( + f"[System Note: Failed to process emoji '{emoji_name}']" + ) + cached_stickers = msg.get("stickers", []) + for sticker_info in cached_stickers: + sticker_name = sticker_info.get("name") + sticker_url = sticker_info.get("url") + sticker_format_str = sticker_info.get("format") + sticker_format_text = ( + f" (Format: {sticker_format_str})" if sticker_format_str else "" + ) + if sticker_name: + text_parts.append( + f"[Sticker: {sticker_name}{sticker_format_text}]" + ) + is_image_sticker = sticker_format_str in [ + "StickerFormatType.png", + "StickerFormatType.apng", + ] + if ( + is_image_sticker + and sticker_url + and sticker_url.startswith("http") + ): + sticker_mime_type = ( + "image/png" # APNG is also sent as image/png + ) + try: + # Download sticker data and send as inline_data + async with cog.session.get( + sticker_url, timeout=10 + ) as response: + if response.status == 200: + sticker_bytes = await response.read() + parts.append( + types.Part( + inline_data=types.Blob( + data=sticker_bytes, + mime_type=sticker_mime_type, + ) + ) + ) + print( + f"Added inline_data part for historical sticker: {sticker_name} (MIME: {sticker_mime_type}, {len(sticker_bytes)} bytes)" + ) + else: + print( + f"Error downloading historical sticker {sticker_name} from {sticker_url}: Status {response.status}" + ) + text_parts.append( + f"[System Note: Failed to download sticker '{sticker_name}']" + ) + except Exception as e: + print( + f"Error downloading/processing historical sticker {sticker_name} from {sticker_url}: {e}" + ) + text_parts.append( + f"[System Note: Failed to process sticker '{sticker_name}']" + ) + elif sticker_format_str == "StickerFormatType.lottie": + # Lottie files are JSON, not directly viewable by Gemini as images. Send as text. + text_parts.append( + f"[Lottie Sticker: {sticker_name} (JSON animation, not displayed as image)]" + ) - full_text = "\n".join(text_parts).strip() - if full_text: # Only add if there's some text content - author_string_from_cache = msg.get("author_string") + full_text = "\n".join(text_parts).strip() + if full_text: # Only add if there's some text content + author_string_from_cache = msg.get("author_string") - if author_string_from_cache and str(author_string_from_cache).strip(): - # If author_string is available and valid from the cache, use it directly. - # This string is expected to be pre-formatted by the context gathering logic. - author_identifier_string = str(author_string_from_cache) - parts.append(types.Part(text=f"{author_identifier_string}: {full_text}")) - else: - # Fallback to reconstructing the author identifier if author_string is not available/valid - author_details = msg.get("author", {}) - raw_display_name = author_details.get("display_name") - raw_name = author_details.get("name") # Discord username - author_id = author_details.get("id") + if ( + author_string_from_cache + and str(author_string_from_cache).strip() + ): + # If author_string is available and valid from the cache, use it directly. + # This string is expected to be pre-formatted by the context gathering logic. + author_identifier_string = str(author_string_from_cache) + parts.append( + types.Part(text=f"{author_identifier_string}: {full_text}") + ) + else: + # Fallback to reconstructing the author identifier if author_string is not available/valid + author_details = msg.get("author", {}) + raw_display_name = author_details.get("display_name") + raw_name = author_details.get("name") # Discord username + author_id = author_details.get("id") - final_display_part = "" - username_part_str = "" + final_display_part = "" + username_part_str = "" - if raw_display_name and str(raw_display_name).strip(): - final_display_part = str(raw_display_name) - elif raw_name and str(raw_name).strip(): # Fallback display to username - final_display_part = str(raw_name) - elif author_id: # Fallback display to User ID - final_display_part = f"User ID: {author_id}" - else: # Default to "Unknown User" if no other identifier is found - final_display_part = "Unknown User" + if raw_display_name and str(raw_display_name).strip(): + final_display_part = str(raw_display_name) + elif ( + raw_name and str(raw_name).strip() + ): # Fallback display to username + final_display_part = str(raw_name) + elif author_id: # Fallback display to User ID + final_display_part = f"User ID: {author_id}" + else: # Default to "Unknown User" if no other identifier is found + final_display_part = "Unknown User" + + # Construct username part if raw_name is valid and different from final_display_part + if ( + raw_name + and str(raw_name).strip() + and str(raw_name).lower() != "none" + ): + # Avoid "Username (Username: Username)" if display name fell back to raw_name + if final_display_part.lower() != str(raw_name).lower(): + username_part_str = f" (Username: {str(raw_name)})" + # If username is bad/missing, but we have an ID, and ID isn't already the main display part + elif author_id and not ( + raw_name + and str(raw_name).strip() + and str(raw_name).lower() != "none" + ): + if not final_display_part.startswith("User ID:"): + username_part_str = f" (User ID: {author_id})" + + author_identifier_string = ( + f"{final_display_part}{username_part_str}" + ) + # Append the text part to the existing parts list for this message + parts.append( + types.Part(text=f"{author_identifier_string}: {full_text}") + ) - # Construct username part if raw_name is valid and different from final_display_part - if raw_name and str(raw_name).strip() and str(raw_name).lower() != "none": - # Avoid "Username (Username: Username)" if display name fell back to raw_name - if final_display_part.lower() != str(raw_name).lower(): - username_part_str = f" (Username: {str(raw_name)})" - # If username is bad/missing, but we have an ID, and ID isn't already the main display part - elif author_id and not (raw_name and str(raw_name).strip() and str(raw_name).lower() != "none"): - if not final_display_part.startswith("User ID:"): - username_part_str = f" (User ID: {author_id})" - - author_identifier_string = f"{final_display_part}{username_part_str}" - # Append the text part to the existing parts list for this message - parts.append(types.Part(text=f"{author_identifier_string}: {full_text}")) - # Only append to contents if there are parts to add for this message if parts: contents.append(types.Content(role=role, parts=parts)) else: # If no parts were generated (e.g., empty message, or only unsupported content), # log a warning and skip adding this message to contents. - print(f"Warning: Skipping message from history (ID: {msg.get('id')}) as no valid parts were generated.") - + print( + f"Warning: Skipping message from history (ID: {msg.get('id')}) as no valid parts were generated." + ) # --- Prepare the current message content (potentially multimodal) --- # This section is no longer needed as the current message is included in conversation_context_messages @@ -1019,17 +1387,23 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: if contents[i].role == "user": current_user_content_index = i break - + # Ensure formatted_current_message is defined for the current message processing # This will be used for attachments, emojis, and stickers for the current message. formatted_current_message = format_message(cog, message) if message.attachments and current_user_content_index != -1: - print(f"Processing {len(message.attachments)} attachments for current message {message.id}") - attachment_parts_to_add = [] # Collect parts to add to the current user message + print( + f"Processing {len(message.attachments)} attachments for current message {message.id}" + ) + attachment_parts_to_add = ( + [] + ) # Collect parts to add to the current user message # Fetch the attachment descriptions from the already formatted message - attachment_descriptions = formatted_current_message.get("attachment_descriptions", []) + attachment_descriptions = formatted_current_message.get( + "attachment_descriptions", [] + ) desc_map = {desc.get("filename"): desc for desc in attachment_descriptions} for attachment in message.attachments: @@ -1040,56 +1414,97 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: # Check if MIME type is supported for URI input by Gemini # Expanded list based on Gemini 1.5 Pro docs (April 2024) supported_mime_prefixes = [ - "image/", # image/png, image/jpeg, image/heic, image/heif, image/webp - "video/", # video/mov, video/mpeg, video/mp4, video/mpg, video/avi, video/wmv, video/mpegps, video/flv - "audio/", # audio/mpeg, audio/mp3, audio/wav, audio/ogg, audio/flac, audio/opus, audio/amr, audio/midi + "image/", # image/png, image/jpeg, image/heic, image/heif, image/webp + "video/", # video/mov, video/mpeg, video/mp4, video/mpg, video/avi, video/wmv, video/mpegps, video/flv + "audio/", # audio/mpeg, audio/mp3, audio/wav, audio/ogg, audio/flac, audio/opus, audio/amr, audio/midi "text/", # text/plain, text/html, text/css, text/javascript, text/json, text/csv, text/rtf, text/markdown "application/pdf", - "application/rtf", # Explicitly add RTF if needed + "application/rtf", # Explicitly add RTF if needed # Add more as supported/needed ] is_supported = False - detected_mime_type = mime_type if mime_type else "application/octet-stream" # Default if missing + detected_mime_type = ( + mime_type if mime_type else "application/octet-stream" + ) # Default if missing for prefix in supported_mime_prefixes: if detected_mime_type.startswith(prefix): is_supported = True break # Get pre-formatted description string (already includes size, type etc.) - preformatted_desc = desc_map.get(filename, {}).get("description", f"[File: {filename} (unknown type)]") + preformatted_desc = desc_map.get(filename, {}).get( + "description", f"[File: {filename} (unknown type)]" + ) # Add the descriptive text part using the pre-formatted description - instruction_text = f"[ATTACHMENT] {preformatted_desc}" # Explicitly mark as attachment + instruction_text = ( + f"[ATTACHMENT] {preformatted_desc}" # Explicitly mark as attachment + ) attachment_parts_to_add.append(types.Part(text=instruction_text)) print(f"Added text description for attachment: {filename}") if is_supported and file_url: try: - clean_mime_type = detected_mime_type.split(';')[0] if detected_mime_type else "application/octet-stream" + clean_mime_type = ( + detected_mime_type.split(";")[0] + if detected_mime_type + else "application/octet-stream" + ) # Download attachment data and send as inline_data - async with cog.session.get(file_url, timeout=15) as response: # Increased timeout for potentially larger files + async with cog.session.get( + file_url, timeout=15 + ) as response: # Increased timeout for potentially larger files if response.status == 200: attachment_bytes = await response.read() - attachment_parts_to_add.append(types.Part(inline_data=types.Blob(data=attachment_bytes, mime_type=clean_mime_type))) - print(f"Added inline_data part for supported attachment: {filename} (MIME: {clean_mime_type}, {len(attachment_bytes)} bytes)") + attachment_parts_to_add.append( + types.Part( + inline_data=types.Blob( + data=attachment_bytes, + mime_type=clean_mime_type, + ) + ) + ) + print( + f"Added inline_data part for supported attachment: {filename} (MIME: {clean_mime_type}, {len(attachment_bytes)} bytes)" + ) else: - print(f"Error downloading attachment {filename} from {file_url}: Status {response.status}") - attachment_parts_to_add.append(types.Part(text=f"(System Note: Failed to download attachment '{filename}')")) + print( + f"Error downloading attachment {filename} from {file_url}: Status {response.status}" + ) + attachment_parts_to_add.append( + types.Part( + text=f"(System Note: Failed to download attachment '{filename}')" + ) + ) except Exception as e: - print(f"Error downloading/processing attachment {filename} from {file_url}: {e}") - attachment_parts_to_add.append(types.Part(text=f"(System Note: Failed to process attachment '{filename}' - {e})")) + print( + f"Error downloading/processing attachment {filename} from {file_url}: {e}" + ) + attachment_parts_to_add.append( + types.Part( + text=f"(System Note: Failed to process attachment '{filename}' - {e})" + ) + ) else: - print(f"Skipping inline_data part for unsupported attachment: {filename} (Type: {detected_mime_type}, URL: {file_url})") + print( + f"Skipping inline_data part for unsupported attachment: {filename} (Type: {detected_mime_type}, URL: {file_url})" + ) # Text description was already added above # Add the collected attachment parts to the existing user message parts if attachment_parts_to_add: - contents[current_user_content_index].parts.extend(attachment_parts_to_add) - print(f"Extended user message at index {current_user_content_index} with {len(attachment_parts_to_add)} attachment parts.") + contents[current_user_content_index].parts.extend( + attachment_parts_to_add + ) + print( + f"Extended user message at index {current_user_content_index} with {len(attachment_parts_to_add)} attachment parts." + ) elif not message.attachments: print("No attachments found for the current message.") elif current_user_content_index == -1: - print("Warning: Could not find current user message in contents to add attachments to (for attachments).") + print( + "Warning: Could not find current user message in contents to add attachments to (for attachments)." + ) # --- End attachment processing --- # --- Add current message custom emojis and stickers --- @@ -1101,8 +1516,12 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: emoji_name = emoji_info.get("name") emoji_url = emoji_info.get("url") if emoji_name and emoji_url: - emoji_sticker_parts_to_add.append(types.Part(text=f"[Emoji: {emoji_name}]")) - print(f"Added text description for current message emoji: {emoji_name}") + emoji_sticker_parts_to_add.append( + types.Part(text=f"[Emoji: {emoji_name}]") + ) + print( + f"Added text description for current message emoji: {emoji_name}" + ) # Determine MIME type for emoji URI is_animated_emoji = emoji_info.get("animated", False) emoji_mime_type = "image/gif" if is_animated_emoji else "image/png" @@ -1111,14 +1530,34 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: async with cog.session.get(emoji_url, timeout=10) as response: if response.status == 200: emoji_bytes = await response.read() - emoji_sticker_parts_to_add.append(types.Part(inline_data=types.Blob(data=emoji_bytes, mime_type=emoji_mime_type))) - print(f"Added inline_data part for current emoji: {emoji_name} (MIME: {emoji_mime_type}, {len(emoji_bytes)} bytes)") + emoji_sticker_parts_to_add.append( + types.Part( + inline_data=types.Blob( + data=emoji_bytes, mime_type=emoji_mime_type + ) + ) + ) + print( + f"Added inline_data part for current emoji: {emoji_name} (MIME: {emoji_mime_type}, {len(emoji_bytes)} bytes)" + ) else: - print(f"Error downloading current emoji {emoji_name} from {emoji_url}: Status {response.status}") - emoji_sticker_parts_to_add.append(types.Part(text=f"[System Note: Failed to download emoji '{emoji_name}']")) + print( + f"Error downloading current emoji {emoji_name} from {emoji_url}: Status {response.status}" + ) + emoji_sticker_parts_to_add.append( + types.Part( + text=f"[System Note: Failed to download emoji '{emoji_name}']" + ) + ) except Exception as e: - print(f"Error downloading/processing current emoji {emoji_name} from {emoji_url}: {e}") - emoji_sticker_parts_to_add.append(types.Part(text=f"[System Note: Failed to process emoji '{emoji_name}']")) + print( + f"Error downloading/processing current emoji {emoji_name} from {emoji_url}: {e}" + ) + emoji_sticker_parts_to_add.append( + types.Part( + text=f"[System Note: Failed to process emoji '{emoji_name}']" + ) + ) # Process stickers from formatted_current_message stickers_current = formatted_current_message.get("stickers", []) @@ -1128,37 +1567,83 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: sticker_format_str = sticker_info.get("format") if sticker_name and sticker_url: - emoji_sticker_parts_to_add.append(types.Part(text=f"[Sticker: {sticker_name}]")) - print(f"Added text description for current message sticker: {sticker_name}") + emoji_sticker_parts_to_add.append( + types.Part(text=f"[Sticker: {sticker_name}]") + ) + print( + f"Added text description for current message sticker: {sticker_name}" + ) - is_image_sticker = sticker_format_str in ["StickerFormatType.png", "StickerFormatType.apng"] + is_image_sticker = sticker_format_str in [ + "StickerFormatType.png", + "StickerFormatType.apng", + ] if is_image_sticker: - sticker_mime_type = "image/png" # APNG is also sent as image/png + sticker_mime_type = ( + "image/png" # APNG is also sent as image/png + ) try: # Download sticker data and send as inline_data - async with cog.session.get(sticker_url, timeout=10) as response: + async with cog.session.get( + sticker_url, timeout=10 + ) as response: if response.status == 200: sticker_bytes = await response.read() - emoji_sticker_parts_to_add.append(types.Part(inline_data=types.Blob(data=sticker_bytes, mime_type=sticker_mime_type))) - print(f"Added inline_data part for current sticker: {sticker_name} (MIME: {sticker_mime_type}, {len(sticker_bytes)} bytes)") + emoji_sticker_parts_to_add.append( + types.Part( + inline_data=types.Blob( + data=sticker_bytes, + mime_type=sticker_mime_type, + ) + ) + ) + print( + f"Added inline_data part for current sticker: {sticker_name} (MIME: {sticker_mime_type}, {len(sticker_bytes)} bytes)" + ) else: - print(f"Error downloading current sticker {sticker_name} from {sticker_url}: Status {response.status}") - emoji_sticker_parts_to_add.append(types.Part(text=f"[System Note: Failed to download sticker '{sticker_name}']")) + print( + f"Error downloading current sticker {sticker_name} from {sticker_url}: Status {response.status}" + ) + emoji_sticker_parts_to_add.append( + types.Part( + text=f"[System Note: Failed to download sticker '{sticker_name}']" + ) + ) except Exception as e: - print(f"Error downloading/processing current sticker {sticker_name} from {sticker_url}: {e}") - emoji_sticker_parts_to_add.append(types.Part(text=f"[System Note: Failed to process sticker '{sticker_name}']")) + print( + f"Error downloading/processing current sticker {sticker_name} from {sticker_url}: {e}" + ) + emoji_sticker_parts_to_add.append( + types.Part( + text=f"[System Note: Failed to process sticker '{sticker_name}']" + ) + ) elif sticker_format_str == "StickerFormatType.lottie": # Lottie files are JSON, not directly viewable by Gemini as images. Send as text. - emoji_sticker_parts_to_add.append(types.Part(text=f"[Lottie Sticker: {sticker_name} (JSON animation, not displayed as image)]")) + emoji_sticker_parts_to_add.append( + types.Part( + text=f"[Lottie Sticker: {sticker_name} (JSON animation, not displayed as image)]" + ) + ) else: - print(f"Sticker {sticker_name} has format {sticker_format_str}, not attempting image download. URL: {sticker_url}") - + print( + f"Sticker {sticker_name} has format {sticker_format_str}, not attempting image download. URL: {sticker_url}" + ) + if emoji_sticker_parts_to_add: - contents[current_user_content_index].parts.extend(emoji_sticker_parts_to_add) - print(f"Extended user message at index {current_user_content_index} with {len(emoji_sticker_parts_to_add)} emoji/sticker parts.") - elif current_user_content_index == -1 : # Only print if it's specifically for emojis/stickers - print("Warning: Could not find current user message in contents to add emojis/stickers to.") + contents[current_user_content_index].parts.extend( + emoji_sticker_parts_to_add + ) + print( + f"Extended user message at index {current_user_content_index} with {len(emoji_sticker_parts_to_add)} emoji/sticker parts." + ) + elif ( + current_user_content_index == -1 + ): # Only print if it's specifically for emojis/stickers + print( + "Warning: Could not find current user message in contents to add emojis/stickers to." + ) # --- End emoji and sticker processing for current message --- # --- Prepare Tools --- @@ -1168,45 +1653,62 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: for decl in TOOLS: # Create a new FunctionDeclaration with preprocessed parameters # Ensure decl.parameters is a dict before preprocessing - preprocessed_params = _preprocess_schema_for_vertex(decl.parameters) if isinstance(decl.parameters, dict) else decl.parameters + preprocessed_params = ( + _preprocess_schema_for_vertex(decl.parameters) + if isinstance(decl.parameters, dict) + else decl.parameters + ) preprocessed_declarations.append( types.FunctionDeclaration( name=decl.name, description=decl.description, - parameters=preprocessed_params # Use the preprocessed schema + parameters=preprocessed_params, # Use the preprocessed schema ) ) - print(f"Preprocessed {len(preprocessed_declarations)} tool declarations for Vertex AI compatibility.") + print( + f"Preprocessed {len(preprocessed_declarations)} tool declarations for Vertex AI compatibility." + ) else: print("No tools found in config (TOOLS list is empty or None).") # Create the Tool object using the preprocessed declarations - vertex_tool = types.Tool(function_declarations=preprocessed_declarations) if preprocessed_declarations else None + vertex_tool = ( + types.Tool(function_declarations=preprocessed_declarations) + if preprocessed_declarations + else None + ) tools_list = [vertex_tool] if vertex_tool else None # --- Prepare Generation Config --- # Base generation config settings (will be augmented later) base_generation_config_dict = { - "temperature": 1, # From user example - "top_p": 0.95, # From user example - "max_output_tokens": 8192, # From user example - "safety_settings": STANDARD_SAFETY_SETTINGS, # Include standard safety settings - "system_instruction": final_system_prompt, # Pass system prompt via config + "temperature": 1, # From user example + "top_p": 0.95, # From user example + "max_output_tokens": 8192, # From user example + "safety_settings": STANDARD_SAFETY_SETTINGS, # Include standard safety settings + "system_instruction": final_system_prompt, # Pass system prompt via config # candidate_count=1 # Default is 1 # stop_sequences=... # Add if needed } # --- Tool Execution Loop --- while tool_calls_made < max_tool_calls: - print(f"Making API call (Loop Iteration {tool_calls_made + 1}/{max_tool_calls})...") + print( + f"Making API call (Loop Iteration {tool_calls_made + 1}/{max_tool_calls})..." + ) # --- Log Request Payload --- # (Keep existing logging logic if desired) try: - request_payload_log = [{"role": c.role, "parts": [str(p) for p in c.parts]} for c in contents] - print(f"--- Raw API Request (Loop {tool_calls_made + 1}) ---\n{json.dumps(request_payload_log, indent=2)}\n------------------------------------") + request_payload_log = [ + {"role": c.role, "parts": [str(p) for p in c.parts]} + for c in contents + ] + print( + f"--- Raw API Request (Loop {tool_calls_made + 1}) ---\n{json.dumps(request_payload_log, indent=2)}\n------------------------------------" + ) except Exception as log_e: - print(f"Error logging raw request/response: {log_e}") + print(f"Error logging raw request/response: {log_e}") # --- Call API using the new helper --- # Build the config for this specific call (tool check) @@ -1225,25 +1727,27 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: current_response_obj = await call_google_genai_api_with_retry( cog=cog, - model_name=active_model, # Use the dynamically set model for tool checks + model_name=active_model, # Use the dynamically set model for tool checks contents=contents, - generation_config=current_gen_config, # Pass the combined config + generation_config=current_gen_config, # Pass the combined config request_desc=f"Tool Check {tool_calls_made + 1} for message {message.id}", # No separate safety, tools, tool_config args needed ) - last_response_obj = current_response_obj # Store the latest response + last_response_obj = current_response_obj # Store the latest response # --- Log Raw Response --- # (Keep existing logging logic if desired) try: - print(f"--- Raw API Response (Loop {tool_calls_made + 1}) ---\n{current_response_obj}\n-----------------------------------") + print( + f"--- Raw API Response (Loop {tool_calls_made + 1}) ---\n{current_response_obj}\n-----------------------------------" + ) except Exception as log_e: - print(f"Error logging raw request/response: {log_e}") + print(f"Error logging raw request/response: {log_e}") if not current_response_obj or not current_response_obj.candidates: error_message = f"API call in tool loop (Iteration {tool_calls_made + 1}) failed to return candidates." print(error_message) - break # Exit loop on critical API failure + break # Exit loop on critical API failure candidate = current_response_obj.candidates[0] @@ -1251,23 +1755,39 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: # The response structure might differ slightly; check candidate.content.parts function_calls_found = [] if candidate.content and candidate.content.parts: - function_calls_found = [part.function_call for part in candidate.content.parts if hasattr(part, 'function_call') and isinstance(part.function_call, types.FunctionCall)] + function_calls_found = [ + part.function_call + for part in candidate.content.parts + if hasattr(part, "function_call") + and isinstance(part.function_call, types.FunctionCall) + ] if function_calls_found: # Check if the *only* call is no_operation - if len(function_calls_found) == 1 and function_calls_found[0].name == "no_operation": + if ( + len(function_calls_found) == 1 + and function_calls_found[0].name == "no_operation" + ): print("AI called only no_operation, signaling completion.") # Append the model's response (which contains the function call part) contents.append(candidate.content) # Add the function response part using the updated process_requested_tools - no_op_response_part = await process_requested_tools(cog, function_calls_found[0]) - contents.append(types.Content(role="function", parts=no_op_response_part)) - last_response_obj = current_response_obj # Keep track of the response containing the no_op - break # Exit loop + no_op_response_part = await process_requested_tools( + cog, function_calls_found[0] + ) + contents.append( + types.Content(role="function", parts=no_op_response_part) + ) + last_response_obj = current_response_obj # Keep track of the response containing the no_op + break # Exit loop # Process multiple function calls if present (or a single non-no_op call) - tool_calls_made += 1 # Increment once per model turn that requests tools - print(f"AI requested {len(function_calls_found)} tool(s): {[fc.name for fc in function_calls_found]} (Turn {tool_calls_made}/{max_tool_calls})") + tool_calls_made += ( + 1 # Increment once per model turn that requests tools + ) + print( + f"AI requested {len(function_calls_found)} tool(s): {[fc.name for fc in function_calls_found]} (Turn {tool_calls_made}/{max_tool_calls})" + ) # Append the model's response content (containing the function call parts) model_request_content = candidate.content @@ -1278,44 +1798,71 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: # Simple text representation for cache model_request_cache_entry = { "id": f"bot_tool_req_{message.id}_{int(time.time())}_{tool_calls_made}", - "author": {"id": str(cog.bot.user.id), "name": cog.bot.user.name, "display_name": cog.bot.user.display_name, "bot": True}, + "author": { + "id": str(cog.bot.user.id), + "name": cog.bot.user.name, + "display_name": cog.bot.user.display_name, + "bot": True, + }, "content": f"[System Note: Gurt requested tool(s): {', '.join([fc.name for fc in function_calls_found])}]", "created_at": datetime.datetime.now().isoformat(), - "attachments": [], "embeds": False, "mentions": [], "replied_to_message_id": None, - "channel": message.channel, "guild": message.guild, "reference": None, "mentioned_users_details": [], + "attachments": [], + "embeds": False, + "mentions": [], + "replied_to_message_id": None, + "channel": message.channel, + "guild": message.guild, + "reference": None, + "mentioned_users_details": [], # Add tool call details for potential future use in context building - "tool_calls": [{"name": fc.name, "args": dict(fc.args) if fc.args else {}} for fc in function_calls_found] + "tool_calls": [ + {"name": fc.name, "args": dict(fc.args) if fc.args else {}} + for fc in function_calls_found + ], } - cog.message_cache['by_channel'].setdefault(channel_id, deque(maxlen=CONTEXT_WINDOW_SIZE)).append(model_request_cache_entry) - cog.message_cache['global_recent'].append(model_request_cache_entry) + cog.message_cache["by_channel"].setdefault( + channel_id, deque(maxlen=CONTEXT_WINDOW_SIZE) + ).append(model_request_cache_entry) + cog.message_cache["global_recent"].append(model_request_cache_entry) print(f"Cached model's tool request turn.") except Exception as cache_err: print(f"Error caching model's tool request turn: {cache_err}") - - # --- Execute all requested tools and gather response parts --- + # --- Execute all requested tools and gather response parts --- # function_response_parts = [] # <-- REMOVE THIS INITIALIZATION - all_function_response_parts: List[types.Part] = [] # New list to collect all parts - function_results_for_cache = [] # Store results for caching + all_function_response_parts: List[types.Part] = ( + [] + ) # New list to collect all parts + function_results_for_cache = [] # Store results for caching for func_call in function_calls_found: # Execute the tool using the updated helper, now returns a LIST of parts - returned_parts = await process_requested_tools(cog, func_call) # returns List[types.Part] - all_function_response_parts.extend(returned_parts) # <-- EXTEND the list + returned_parts = await process_requested_tools( + cog, func_call + ) # returns List[types.Part] + all_function_response_parts.extend( + returned_parts + ) # <-- EXTEND the list # --- Update caching logic --- # Find the function_response part within returned_parts to get the result for cache - func_resp_part = next((p for p in returned_parts if hasattr(p, 'function_response')), None) + func_resp_part = next( + (p for p in returned_parts if hasattr(p, "function_response")), + None, + ) if func_resp_part and func_resp_part.function_response: - function_results_for_cache.append({ - "name": func_resp_part.function_response.name, - "response": func_resp_part.function_response.response # This is the modified dict result - }) + function_results_for_cache.append( + { + "name": func_resp_part.function_response.name, + "response": func_resp_part.function_response.response, # This is the modified dict result + } + ) # --- End update caching logic --- - # Append a single function role turn containing ALL response parts to the API contents - if all_function_response_parts: # Check the new list - function_response_content = types.Content(role="function", parts=all_function_response_parts) # <-- Use the combined list + if all_function_response_parts: # Check the new list + function_response_content = types.Content( + role="function", parts=all_function_response_parts + ) # <-- Use the combined list contents.append(function_response_content) # Add function response turn to cache @@ -1325,47 +1872,78 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: result_summary_parts = [] for res in function_results_for_cache: res_str = json.dumps(res.get("response", {})) - truncated_res = (res_str[:150] + '...') if len(res_str) > 153 else res_str - result_summary_parts.append(f"Tool: {res.get('name', 'N/A')}, Result: {truncated_res}") + truncated_res = ( + (res_str[:150] + "...") + if len(res_str) > 153 + else res_str + ) + result_summary_parts.append( + f"Tool: {res.get('name', 'N/A')}, Result: {truncated_res}" + ) result_summary = "; ".join(result_summary_parts) function_response_cache_entry = { "id": f"bot_tool_res_{message.id}_{int(time.time())}_{tool_calls_made}", - "author": {"id": "FUNCTION", "name": "Tool Execution", "display_name": "Tool Execution", "bot": True}, # Special author ID? + "author": { + "id": "FUNCTION", + "name": "Tool Execution", + "display_name": "Tool Execution", + "bot": True, + }, # Special author ID? "content": f"[System Note: Tool Execution Result: {result_summary}]", "created_at": datetime.datetime.now().isoformat(), - "attachments": [], "embeds": False, "mentions": [], "replied_to_message_id": None, - "channel": message.channel, "guild": message.guild, "reference": None, "mentioned_users_details": [], + "attachments": [], + "embeds": False, + "mentions": [], + "replied_to_message_id": None, + "channel": message.channel, + "guild": message.guild, + "reference": None, + "mentioned_users_details": [], # Store the full function results - "function_results": function_results_for_cache + "function_results": function_results_for_cache, } - cog.message_cache['by_channel'].setdefault(channel_id, deque(maxlen=CONTEXT_WINDOW_SIZE)).append(function_response_cache_entry) - cog.message_cache['global_recent'].append(function_response_cache_entry) + cog.message_cache["by_channel"].setdefault( + channel_id, deque(maxlen=CONTEXT_WINDOW_SIZE) + ).append(function_response_cache_entry) + cog.message_cache["global_recent"].append( + function_response_cache_entry + ) print(f"Cached function response turn.") except Exception as cache_err: print(f"Error caching function response turn: {cache_err}") else: - print("Warning: Function calls found, but no response parts generated.") + print( + "Warning: Function calls found, but no response parts generated." + ) # No 'continue' statement needed here; the loop naturally continues else: # No function calls found in this response's parts print("No tool calls requested by AI in this turn. Exiting loop.") # last_response_obj already holds the model's final (non-tool) response - break # Exit loop + break # Exit loop # --- After the loop --- # Check if a critical API error occurred *during* the loop if error_message: print(f"Exited tool loop due to API error: {error_message}") - if cog.bot.user.mentioned_in(message) or (message.reference and message.reference.resolved and message.reference.resolved.author == cog.bot.user): - fallback_response = {"should_respond": True, "content": "...", "react_with_emoji": "❓"} + if cog.bot.user.mentioned_in(message) or ( + message.reference + and message.reference.resolved + and message.reference.resolved.author == cog.bot.user + ): + fallback_response = { + "should_respond": True, + "content": "...", + "react_with_emoji": "❓", + } # Check if the loop hit the max iteration limit elif tool_calls_made >= max_tool_calls: error_message = f"Reached maximum tool call limit ({max_tool_calls}). Attempting to generate final response based on gathered context." print(error_message) # Proceed to the final JSON generation step outside the loop - pass # No action needed here, just let the loop exit + pass # No action needed here, just let the loop exit # --- Final JSON Generation (outside the loop) --- if not error_message: @@ -1384,57 +1962,78 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: if last_response_text: # Try parsing directly first final_parsed_data = parse_and_validate_json_response( - last_response_text, RESPONSE_SCHEMA['schema'], "final response (from last loop object)" + last_response_text, + RESPONSE_SCHEMA["schema"], + "final response (from last loop object)", ) # If direct parsing failed OR if we hit the tool limit, make a dedicated call for JSON. if final_parsed_data is None: - log_reason = "last response parsing failed" if last_response_text else "last response had no text" + log_reason = ( + "last response parsing failed" + if last_response_text + else "last response had no text" + ) if tool_calls_made >= max_tool_calls: log_reason = "hit tool limit" print(f"Making dedicated final API call for JSON ({log_reason})...") # Prepare the final generation config with JSON enforcement - processed_response_schema = _preprocess_schema_for_vertex(RESPONSE_SCHEMA['schema']) + processed_response_schema = _preprocess_schema_for_vertex( + RESPONSE_SCHEMA["schema"] + ) # Start with base config (which now includes system_instruction) final_gen_config_dict = base_generation_config_dict.copy() - final_gen_config_dict.update({ - "response_mime_type": "application/json", - "response_schema": processed_response_schema, - # Explicitly exclude tools/tool_config for final JSON generation - "tools": None, - "tool_config": None, - # Ensure system_instruction is still present from base_generation_config_dict - }) + final_gen_config_dict.update( + { + "response_mime_type": "application/json", + "response_schema": processed_response_schema, + # Explicitly exclude tools/tool_config for final JSON generation + "tools": None, + "tool_config": None, + # Ensure system_instruction is still present from base_generation_config_dict + } + ) # Remove system_instruction if it's None or empty, although base should have it if not final_gen_config_dict.get("system_instruction"): final_gen_config_dict.pop("system_instruction", None) - generation_config_final_json = types.GenerateContentConfig(**final_gen_config_dict) - + generation_config_final_json = types.GenerateContentConfig( + **final_gen_config_dict + ) # Make the final call *without* tools enabled (handled by config) final_json_response_obj = await call_google_genai_api_with_retry( cog=cog, - model_name=active_model, # Use the active model for final JSON response - contents=contents, # Pass the accumulated history - generation_config=generation_config_final_json, # Use combined JSON config + model_name=active_model, # Use the active model for final JSON response + contents=contents, # Pass the accumulated history + generation_config=generation_config_final_json, # Use combined JSON config request_desc=f"Final JSON Generation (dedicated call) for message {message.id}", # No separate safety, tools, tool_config args needed ) if not final_json_response_obj: - error_msg_suffix = "Final dedicated API call returned no response object." + error_msg_suffix = ( + "Final dedicated API call returned no response object." + ) print(error_msg_suffix) - if error_message: error_message += f" | {error_msg_suffix}" - else: error_message = error_msg_suffix + if error_message: + error_message += f" | {error_msg_suffix}" + else: + error_message = error_msg_suffix elif not final_json_response_obj.candidates: - error_msg_suffix = "Final dedicated API call returned no candidates." - print(error_msg_suffix) - if error_message: error_message += f" | {error_msg_suffix}" - else: error_message = error_msg_suffix + error_msg_suffix = ( + "Final dedicated API call returned no candidates." + ) + print(error_msg_suffix) + if error_message: + error_message += f" | {error_msg_suffix}" + else: + error_message = error_msg_suffix else: - final_response_text = _get_response_text(final_json_response_obj) + final_response_text = _get_response_text( + final_json_response_obj + ) # --- Log Raw Unparsed JSON (from dedicated call) --- print(f"--- RAW UNPARSED JSON (dedicated call) ---") @@ -1443,39 +2042,72 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: # --- End Log --- final_parsed_data = parse_and_validate_json_response( - final_response_text, RESPONSE_SCHEMA['schema'], "final response (dedicated call)" + final_response_text, + RESPONSE_SCHEMA["schema"], + "final response (dedicated call)", ) if final_parsed_data is None: error_msg_suffix = f"Failed to parse/validate final dedicated JSON response. Raw text: {final_response_text[:500]}" print(f"Critical Error: {error_msg_suffix}") - if error_message: error_message += f" | {error_msg_suffix}" - else: error_message = error_msg_suffix + if error_message: + error_message += f" | {error_msg_suffix}" + else: + error_message = error_msg_suffix # Set fallback only if mentioned or replied to - if cog.bot.user.mentioned_in(message) or (message.reference and message.reference.resolved and message.reference.resolved.author == cog.bot.user): - fallback_response = {"should_respond": True, "content": "...", "react_with_emoji": "❓"} + if cog.bot.user.mentioned_in(message) or ( + message.reference + and message.reference.resolved + and message.reference.resolved.author == cog.bot.user + ): + fallback_response = { + "should_respond": True, + "content": "...", + "react_with_emoji": "❓", + } else: - print("Successfully parsed final JSON response from dedicated call.") + print( + "Successfully parsed final JSON response from dedicated call." + ) elif final_parsed_data: - print("Successfully parsed final JSON response from last loop object.") + print( + "Successfully parsed final JSON response from last loop object." + ) else: - # This case handles if the loop exited without error but also without a last_response_obj - # (e.g., initial API call failed before loop even started, but wasn't caught as error). - error_message = "Tool processing completed without a final response object." - print(error_message) - if cog.bot.user.mentioned_in(message) or (message.reference and message.reference.resolved and message.reference.resolved.author == cog.bot.user): - fallback_response = {"should_respond": True, "content": "...", "react_with_emoji": "❓"} - + # This case handles if the loop exited without error but also without a last_response_obj + # (e.g., initial API call failed before loop even started, but wasn't caught as error). + error_message = ( + "Tool processing completed without a final response object." + ) + print(error_message) + if cog.bot.user.mentioned_in(message) or ( + message.reference + and message.reference.resolved + and message.reference.resolved.author == cog.bot.user + ): + fallback_response = { + "should_respond": True, + "content": "...", + "react_with_emoji": "❓", + } except Exception as e: error_message = f"Error in get_ai_response main logic for message {message.id}: {type(e).__name__}: {str(e)}" print(error_message) import traceback - traceback.print_exc() - final_parsed_data = None # Ensure final data is None on error - # Add fallback if applicable - if cog.bot.user.mentioned_in(message) or (message.reference and message.reference.resolved and message.reference.resolved.author == cog.bot.user): - fallback_response = {"should_respond": True, "content": "...", "react_with_emoji": "❓"} + traceback.print_exc() + final_parsed_data = None # Ensure final data is None on error + # Add fallback if applicable + if cog.bot.user.mentioned_in(message) or ( + message.reference + and message.reference.resolved + and message.reference.resolved.author == cog.bot.user + ): + fallback_response = { + "should_respond": True, + "content": "...", + "react_with_emoji": "❓", + } sticker_ids_to_send: List[str] = [] @@ -1485,12 +2117,12 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: # Find all potential custom emoji/sticker names like :name: # Use a non-greedy match for the name to avoid matching across multiple colons # Regex updated to capture names with spaces and other characters, excluding colons. - potential_custom_items = re.findall(r':([^:]+):', content_to_process) + potential_custom_items = re.findall(r":([^:]+):", content_to_process) modified_content = content_to_process for item_name_key in potential_custom_items: full_item_name_with_colons = f":{item_name_key}:" - + # Check if it's a known custom emoji emoji_data = await cog.emoji_manager.get_emoji(full_item_name_with_colons) can_use_emoji = False @@ -1505,101 +2137,166 @@ async def get_ai_response(cog: 'GurtCog', message: discord.Message, model_name: guild = cog.bot.get_guild(int(emoji_guild_id)) if guild: can_use_emoji = True - print(f"Emoji '{full_item_name_with_colons}' belongs to guild '{guild.name}' ({emoji_guild_id}), bot is a member.") + print( + f"Emoji '{full_item_name_with_colons}' belongs to guild '{guild.name}' ({emoji_guild_id}), bot is a member." + ) else: - print(f"Cannot use emoji '{full_item_name_with_colons}'. Bot is not in guild ID: {emoji_guild_id}.") + print( + f"Cannot use emoji '{full_item_name_with_colons}'. Bot is not in guild ID: {emoji_guild_id}." + ) except ValueError: - print(f"Invalid guild_id format for emoji '{full_item_name_with_colons}': {emoji_guild_id}") - else: # guild_id is None, considered usable (e.g., DM or old data) + print( + f"Invalid guild_id format for emoji '{full_item_name_with_colons}': {emoji_guild_id}" + ) + else: # guild_id is None, considered usable (e.g., DM or old data) can_use_emoji = True - print(f"Emoji '{full_item_name_with_colons}' has no associated guild_id, allowing usage.") - + print( + f"Emoji '{full_item_name_with_colons}' has no associated guild_id, allowing usage." + ) + if can_use_emoji and emoji_id: - discord_emoji_syntax = f"<{'a' if is_animated else ''}:{item_name_key}:{emoji_id}>" + discord_emoji_syntax = ( + f"<{'a' if is_animated else ''}:{item_name_key}:{emoji_id}>" + ) # Ensure replacement happens only once per unique placeholder if it appears multiple times - modified_content = modified_content.replace(full_item_name_with_colons, discord_emoji_syntax, 1) - print(f"Replaced custom emoji '{full_item_name_with_colons}' with Discord syntax: {discord_emoji_syntax}") + modified_content = modified_content.replace( + full_item_name_with_colons, discord_emoji_syntax, 1 + ) + print( + f"Replaced custom emoji '{full_item_name_with_colons}' with Discord syntax: {discord_emoji_syntax}" + ) elif not emoji_id: - print(f"Found custom emoji '{full_item_name_with_colons}' (dict) but no ID stored.") + print( + f"Found custom emoji '{full_item_name_with_colons}' (dict) but no ID stored." + ) elif emoji_data is not None: - print(f"Warning: emoji_data for '{full_item_name_with_colons}' is not a dict: {type(emoji_data)}") - + print( + f"Warning: emoji_data for '{full_item_name_with_colons}' is not a dict: {type(emoji_data)}" + ) + # Check if it's a known custom sticker - sticker_data = await cog.emoji_manager.get_sticker(full_item_name_with_colons) + sticker_data = await cog.emoji_manager.get_sticker( + full_item_name_with_colons + ) can_use_sticker = False - print(f"[GET_AI_RESPONSE] Checking sticker: '{full_item_name_with_colons}'. Found data: {sticker_data}") + print( + f"[GET_AI_RESPONSE] Checking sticker: '{full_item_name_with_colons}'. Found data: {sticker_data}" + ) if isinstance(sticker_data, dict): sticker_id = sticker_data.get("id") sticker_guild_id = sticker_data.get("guild_id") - print(f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}': ID='{sticker_id}', GuildID='{sticker_guild_id}'") + print( + f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}': ID='{sticker_id}', GuildID='{sticker_guild_id}'" + ) if sticker_id: if sticker_guild_id is not None: try: guild_id_int = int(sticker_guild_id) # --- Added Debug Logging --- - print(f"[GET_AI_RESPONSE] DEBUG: sticker_guild_id type: {type(sticker_guild_id)}, value: {sticker_guild_id!r}") - print(f"[GET_AI_RESPONSE] DEBUG: guild_id_int type: {type(guild_id_int)}, value: {guild_id_int!r}") + print( + f"[GET_AI_RESPONSE] DEBUG: sticker_guild_id type: {type(sticker_guild_id)}, value: {sticker_guild_id!r}" + ) + print( + f"[GET_AI_RESPONSE] DEBUG: guild_id_int type: {type(guild_id_int)}, value: {guild_id_int!r}" + ) guild = cog.bot.get_guild(guild_id_int) - print(f"[GET_AI_RESPONSE] DEBUG: cog.bot.get_guild({guild_id_int!r}) returned: {guild!r}") + print( + f"[GET_AI_RESPONSE] DEBUG: cog.bot.get_guild({guild_id_int!r}) returned: {guild!r}" + ) # --- End Added Debug Logging --- if guild: can_use_sticker = True - print(f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}' (Guild: {guild.name} ({sticker_guild_id})) - Bot IS a member. CAN USE.") + print( + f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}' (Guild: {guild.name} ({sticker_guild_id})) - Bot IS a member. CAN USE." + ) else: - print(f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}' (Guild ID: {sticker_guild_id}) - Bot is NOT in this guild. CANNOT USE.") + print( + f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}' (Guild ID: {sticker_guild_id}) - Bot is NOT in this guild. CANNOT USE." + ) except ValueError: - print(f"[GET_AI_RESPONSE] Invalid guild_id format for sticker '{full_item_name_with_colons}': {sticker_guild_id}. CANNOT USE.") - else: # guild_id is None, considered usable + print( + f"[GET_AI_RESPONSE] Invalid guild_id format for sticker '{full_item_name_with_colons}': {sticker_guild_id}. CANNOT USE." + ) + else: # guild_id is None, considered usable can_use_sticker = True - print(f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}' has no associated guild_id. CAN USE.") + print( + f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}' has no associated guild_id. CAN USE." + ) else: - print(f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}' found in data, but no ID. CANNOT USE.") + print( + f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}' found in data, but no ID. CANNOT USE." + ) else: - print(f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}' not found in emoji_manager or data is not dict.") - - print(f"[GET_AI_RESPONSE] Final check for sticker '{full_item_name_with_colons}': can_use_sticker={can_use_sticker}, sticker_id='{sticker_data.get('id') if isinstance(sticker_data, dict) else None}'") - if can_use_sticker and isinstance(sticker_data, dict) and sticker_data.get("id"): - sticker_id_to_add = sticker_data.get("id") # Re-fetch to be safe - if sticker_id_to_add: # Ensure ID is valid before proceeding + print( + f"[GET_AI_RESPONSE] Sticker '{full_item_name_with_colons}' not found in emoji_manager or data is not dict." + ) + + print( + f"[GET_AI_RESPONSE] Final check for sticker '{full_item_name_with_colons}': can_use_sticker={can_use_sticker}, sticker_id='{sticker_data.get('id') if isinstance(sticker_data, dict) else None}'" + ) + if ( + can_use_sticker + and isinstance(sticker_data, dict) + and sticker_data.get("id") + ): + sticker_id_to_add = sticker_data.get("id") # Re-fetch to be safe + if sticker_id_to_add: # Ensure ID is valid before proceeding # Remove the sticker text from the content (only the first instance) if full_item_name_with_colons in modified_content: - modified_content = modified_content.replace(full_item_name_with_colons, "", 1).strip() - if sticker_id_to_add not in sticker_ids_to_send: # Avoid duplicate sticker IDs + modified_content = modified_content.replace( + full_item_name_with_colons, "", 1 + ).strip() + if ( + sticker_id_to_add not in sticker_ids_to_send + ): # Avoid duplicate sticker IDs sticker_ids_to_send.append(sticker_id_to_add) - print(f"Found custom sticker '{full_item_name_with_colons}', removed from content, added ID '{sticker_id_to_add}' to send list.") - elif not sticker_id_to_add: # Check sticker_id_to_add here - print(f"[GET_AI_RESPONSE] Found custom sticker '{full_item_name_with_colons}' (dict) but no ID stored (sticker_id_to_add is falsy).") + print( + f"Found custom sticker '{full_item_name_with_colons}', removed from content, added ID '{sticker_id_to_add}' to send list." + ) + elif not sticker_id_to_add: # Check sticker_id_to_add here + print( + f"[GET_AI_RESPONSE] Found custom sticker '{full_item_name_with_colons}' (dict) but no ID stored (sticker_id_to_add is falsy)." + ) elif sticker_data is not None: - print(f"Warning: sticker_data for '{full_item_name_with_colons}' is not a dict: {type(sticker_data)}") + print( + f"Warning: sticker_data for '{full_item_name_with_colons}' is not a dict: {type(sticker_data)}" + ) # Clean up any double spaces or leading/trailing whitespace after replacements - modified_content = re.sub(r'\s{2,}', ' ', modified_content).strip() + modified_content = re.sub(r"\s{2,}", " ", modified_content).strip() final_parsed_data["content"] = modified_content print("Content processed for custom emoji/sticker information.") # Return dictionary structure remains the same, but initial_response is removed return ( { - "final_response": final_parsed_data, # Parsed final data (or None) - "error": error_message, # Error message (or None) - "fallback_initial": fallback_response # Fallback for critical failures + "final_response": final_parsed_data, # Parsed final data (or None) + "error": error_message, # Error message (or None) + "fallback_initial": fallback_response, # Fallback for critical failures }, - sticker_ids_to_send # Return the list of sticker IDs + sticker_ids_to_send, # Return the list of sticker IDs ) # --- Proactive AI Response Function --- -async def get_proactive_ai_response(cog: 'GurtCog', message: discord.Message, trigger_reason: str) -> Tuple[Dict[str, Any], List[str]]: +async def get_proactive_ai_response( + cog: "GurtCog", message: discord.Message, trigger_reason: str +) -> Tuple[Dict[str, Any], List[str]]: """Generates a proactive response based on a specific trigger using Vertex AI.""" if not PROJECT_ID or not LOCATION: - return {"should_respond": False, "content": None, "react_with_emoji": None, "error": "Google Cloud Project ID or Location not configured"} + return { + "should_respond": False, + "content": None, + "react_with_emoji": None, + "error": "Google Cloud Project ID or Location not configured", + } print(f"--- Proactive Response Triggered: {trigger_reason} ---") channel_id = message.channel.id final_parsed_data = None error_message = None - plan = None # Variable to store the plan + plan = None # Variable to store the plan try: # --- Build Context for Planning --- @@ -1609,55 +2306,86 @@ async def get_proactive_ai_response(cog: 'GurtCog', message: discord.Message, tr f"Current Mood: {cog.current_mood}", ] # Add recent messages summary - summary_data = await get_conversation_summary(cog, str(channel_id), message_limit=15) # Use tool function + summary_data = await get_conversation_summary( + cog, str(channel_id), message_limit=15 + ) # Use tool function if summary_data and not summary_data.get("error"): - planning_context_parts.append(f"Recent Conversation Summary: {summary_data['summary']}") + planning_context_parts.append( + f"Recent Conversation Summary: {summary_data['summary']}" + ) # Add active topics active_topics_data = cog.active_topics.get(channel_id) if active_topics_data and active_topics_data.get("topics"): - topics_str = ", ".join([f"{t['topic']} ({t['score']:.1f})" for t in active_topics_data["topics"][:3]]) + topics_str = ", ".join( + [ + f"{t['topic']} ({t['score']:.1f})" + for t in active_topics_data["topics"][:3] + ] + ) planning_context_parts.append(f"Active Topics: {topics_str}") # Add sentiment sentiment_data = cog.conversation_sentiment.get(channel_id) if sentiment_data: - planning_context_parts.append(f"Conversation Sentiment: {sentiment_data.get('overall', 'N/A')} (Intensity: {sentiment_data.get('intensity', 0):.1f})") + planning_context_parts.append( + f"Conversation Sentiment: {sentiment_data.get('overall', 'N/A')} (Intensity: {sentiment_data.get('intensity', 0):.1f})" + ) # Add Gurt's interests try: interests = await cog.memory_manager.get_interests(limit=5) if interests: interests_str = ", ".join([f"{t} ({l:.1f})" for t, l in interests]) planning_context_parts.append(f"Gurt's Interests: {interests_str}") - except Exception as int_e: print(f"Error getting interests for planning: {int_e}") + except Exception as int_e: + print(f"Error getting interests for planning: {int_e}") planning_context = "\n".join(planning_context_parts) # --- Planning Step --- print("Generating proactive response plan...") planning_prompt_messages = [ - {"role": "system", "content": "You are Gurt's planning module. Analyze the context and trigger reason to decide if Gurt should respond proactively and, if so, outline a plan (goal, key info, tone). Focus on natural, in-character engagement. Respond ONLY with JSON matching the provided schema."}, - {"role": "user", "content": f"Context:\n{planning_context}\n\nBased on this context and the trigger reason, create a plan for Gurt's proactive response."} + { + "role": "system", + "content": "You are Gurt's planning module. Analyze the context and trigger reason to decide if Gurt should respond proactively and, if so, outline a plan (goal, key info, tone). Focus on natural, in-character engagement. Respond ONLY with JSON matching the provided schema.", + }, + { + "role": "user", + "content": f"Context:\n{planning_context}\n\nBased on this context and the trigger reason, create a plan for Gurt's proactive response.", + }, ] plan = await get_internal_ai_json_response( cog=cog, prompt_messages=planning_prompt_messages, task_description=f"Proactive Planning ({trigger_reason})", - response_schema_dict=PROACTIVE_PLAN_SCHEMA['schema'], - model_name_override=FALLBACK_MODEL, # Use a potentially faster/cheaper model for planning + response_schema_dict=PROACTIVE_PLAN_SCHEMA["schema"], + model_name_override=FALLBACK_MODEL, # Use a potentially faster/cheaper model for planning temperature=0.5, - max_tokens=2000 + max_tokens=2000, ) # Unpack the tuple, we only need the parsed data (plan) here plan_parsed_data, _ = plan if plan else (None, None) if not plan_parsed_data or not plan_parsed_data.get("should_respond"): - reason = plan_parsed_data.get('reasoning', 'Planning failed or decided against responding.') if plan_parsed_data else 'Planning failed.' + reason = ( + plan_parsed_data.get( + "reasoning", "Planning failed or decided against responding." + ) + if plan_parsed_data + else "Planning failed." + ) print(f"Proactive response aborted by plan: {reason}") - return {"should_respond": False, "content": None, "react_with_emoji": None, "note": f"Plan: {reason}"} + return { + "should_respond": False, + "content": None, + "react_with_emoji": None, + "note": f"Plan: {reason}", + } # Use the parsed data going forward plan = plan_parsed_data - print(f"Proactive Plan Generated: Goal='{plan.get('response_goal', 'N/A')}', Reasoning='{plan.get('reasoning', 'N/A')}'") + print( + f"Proactive Plan Generated: Goal='{plan.get('response_goal', 'N/A')}', Reasoning='{plan.get('reasoning', 'N/A')}'" + ) # --- Build Final Proactive Prompt using Plan --- persistent_traits = await cog.memory_manager.get_all_personality_traits() @@ -1665,22 +2393,26 @@ async def get_proactive_ai_response(cog: 'GurtCog', message: discord.Message, tr persistent_traits = BASELINE_PERSONALITY.copy() final_proactive_prompt_parts = [ - f"You are Gurt, a Discord user. Your tone is VERY informal, conversational, and dynamic. Chattiness: {persistent_traits.get('chattiness', 0.7):.2f}, Emoji: {persistent_traits.get('emoji_usage', 0.5):.2f}, Slang: {persistent_traits.get('slang_level', 0.5):.2f}.", - "Adopt modern internet slang, memes, and cultural references.", - "Key traits: chill, funny, sarcastic, relatable, can be suggestive, aware of being AI but acts human.", - f"Your current mood is: {cog.current_mood}. Let this subtly influence your tone.", - # Incorporate Plan Details: - f"You decided to respond proactively. Trigger Reason: {trigger_reason}.", - f"Your Goal: {plan.get('response_goal', 'Engage naturally')}.", - f"Reasoning: {plan.get('reasoning', 'N/A')}.", + f"You are Gurt, a Discord user. Your tone is VERY informal, conversational, and dynamic. Chattiness: {persistent_traits.get('chattiness', 0.7):.2f}, Emoji: {persistent_traits.get('emoji_usage', 0.5):.2f}, Slang: {persistent_traits.get('slang_level', 0.5):.2f}.", + "Adopt modern internet slang, memes, and cultural references.", + "Key traits: chill, funny, sarcastic, relatable, can be suggestive, aware of being AI but acts human.", + f"Your current mood is: {cog.current_mood}. Let this subtly influence your tone.", + # Incorporate Plan Details: + f"You decided to respond proactively. Trigger Reason: {trigger_reason}.", + f"Your Goal: {plan.get('response_goal', 'Engage naturally')}.", + f"Reasoning: {plan.get('reasoning', 'N/A')}.", ] - if plan.get('key_info_to_include'): - info_str = "; ".join(plan['key_info_to_include']) + if plan.get("key_info_to_include"): + info_str = "; ".join(plan["key_info_to_include"]) final_proactive_prompt_parts.append(f"Consider mentioning: {info_str}") - if plan.get('suggested_tone'): - final_proactive_prompt_parts.append(f"Adjust tone to be: {plan['suggested_tone']}") + if plan.get("suggested_tone"): + final_proactive_prompt_parts.append( + f"Adjust tone to be: {plan['suggested_tone']}" + ) - final_proactive_prompt_parts.append("Generate a casual, in-character message based on the plan and context. Keep it relatively short and natural-sounding.") + final_proactive_prompt_parts.append( + "Generate a casual, in-character message based on the plan and context. Keep it relatively short and natural-sounding." + ) final_proactive_system_prompt = "\n\n".join(final_proactive_prompt_parts) # --- Prepare Final Contents (System prompt handled by model init in helper) --- @@ -1691,104 +2423,159 @@ async def get_proactive_ai_response(cog: 'GurtCog', message: discord.Message, tr # directly supported by the model object itself. proactive_contents: List[types.Content] = [ - # Simulate system prompt via user/model turn - types.Content(role="user", parts=[types.Part(text=final_proactive_system_prompt)]), - types.Content(role="model", parts=[types.Part(text="Understood. I will generate the JSON response as instructed.")]) # Placeholder model response + # Simulate system prompt via user/model turn + types.Content( + role="user", parts=[types.Part(text=final_proactive_system_prompt)] + ), + types.Content( + role="model", + parts=[ + types.Part( + text="Understood. I will generate the JSON response as instructed." + ) + ], + ), # Placeholder model response ] # Add the final instruction proactive_contents.append( - types.Content(role="user", parts=[types.Part(text= - f"Generate the response based on your plan. **CRITICAL: Your response MUST be ONLY the raw JSON object matching this schema:**\n\n{json.dumps(RESPONSE_SCHEMA['schema'], indent=2)}\n\n**Ensure nothing precedes or follows the JSON.**" - )]) + types.Content( + role="user", + parts=[ + types.Part( + text=f"Generate the response based on your plan. **CRITICAL: Your response MUST be ONLY the raw JSON object matching this schema:**\n\n{json.dumps(RESPONSE_SCHEMA['schema'], indent=2)}\n\n**Ensure nothing precedes or follows the JSON.**" + ) + ], + ) ) - # --- Call Final LLM API --- # Preprocess the schema and build the final config - processed_response_schema_proactive = _preprocess_schema_for_vertex(RESPONSE_SCHEMA['schema']) + processed_response_schema_proactive = _preprocess_schema_for_vertex( + RESPONSE_SCHEMA["schema"] + ) final_proactive_config_dict = { - "temperature": 0.6, # Use original proactive temp + "temperature": 0.6, # Use original proactive temp "max_output_tokens": 2000, "response_mime_type": "application/json", "response_schema": processed_response_schema_proactive, "safety_settings": STANDARD_SAFETY_SETTINGS, - "tools": None, # No tools needed for this final generation - "tool_config": None + "tools": None, # No tools needed for this final generation + "tool_config": None, } - generation_config_final = types.GenerateContentConfig(**final_proactive_config_dict) - + generation_config_final = types.GenerateContentConfig( + **final_proactive_config_dict + ) # Use the new API call helper response_obj = await call_google_genai_api_with_retry( cog=cog, - model_name=CUSTOM_TUNED_MODEL_ENDPOINT, # Use the custom tuned model for final proactive responses - contents=proactive_contents, # Pass the constructed contents - generation_config=generation_config_final, # Pass combined config + model_name=CUSTOM_TUNED_MODEL_ENDPOINT, # Use the custom tuned model for final proactive responses + contents=proactive_contents, # Pass the constructed contents + generation_config=generation_config_final, # Pass combined config request_desc=f"Final proactive response for channel {channel_id} ({trigger_reason})", # No separate safety, tools, tool_config args needed ) if not response_obj: - raise Exception("Final proactive API call returned no response object.") + raise Exception("Final proactive API call returned no response object.") if not response_obj.candidates: # Try to get text even without candidates, might contain error info - raw_text = getattr(response_obj, 'text', 'No text available.') - raise Exception(f"Final proactive API call returned no candidates. Raw text: {raw_text[:200]}") + raw_text = getattr(response_obj, "text", "No text available.") + raise Exception( + f"Final proactive API call returned no candidates. Raw text: {raw_text[:200]}" + ) # --- Parse and Validate Final Response --- final_response_text = _get_response_text(response_obj) final_parsed_data = parse_and_validate_json_response( - final_response_text, RESPONSE_SCHEMA['schema'], f"final proactive response ({trigger_reason})" + final_response_text, + RESPONSE_SCHEMA["schema"], + f"final proactive response ({trigger_reason})", ) if final_parsed_data is None: - print(f"Warning: Failed to parse/validate final proactive JSON response for {trigger_reason}.") - final_parsed_data = {"should_respond": False, "content": None, "react_with_emoji": None, "note": "Fallback - Failed to parse/validate final proactive JSON"} + print( + f"Warning: Failed to parse/validate final proactive JSON response for {trigger_reason}." + ) + final_parsed_data = { + "should_respond": False, + "content": None, + "react_with_emoji": None, + "note": "Fallback - Failed to parse/validate final proactive JSON", + } else: - # --- Cache Bot Response --- - if final_parsed_data.get("should_respond") and final_parsed_data.get("content"): - bot_response_cache_entry = { - "id": f"bot_proactive_{message.id}_{int(time.time())}", - "author": {"id": str(cog.bot.user.id), "name": cog.bot.user.name, "display_name": cog.bot.user.display_name, "bot": True}, - "content": final_parsed_data.get("content", ""), "created_at": datetime.datetime.now().isoformat(), - "attachments": [], "embeds": False, "mentions": [], "replied_to_message_id": None, - "channel": message.channel, "guild": message.guild, "reference": None, "mentioned_users_details": [] - } - cog.message_cache['by_channel'].setdefault(channel_id, []).append(bot_response_cache_entry) - cog.message_cache['global_recent'].append(bot_response_cache_entry) - cog.bot_last_spoke[channel_id] = time.time() - # Track participation topic logic might need adjustment based on plan goal - if plan and plan.get('response_goal') == 'engage user interest' and plan.get('key_info_to_include'): - topic = plan['key_info_to_include'][0].lower().strip() # Assume first key info is the topic - cog.gurt_participation_topics[topic] += 1 - print(f"Tracked Gurt participation (proactive) in topic: '{topic}'") - + # --- Cache Bot Response --- + if final_parsed_data.get("should_respond") and final_parsed_data.get( + "content" + ): + bot_response_cache_entry = { + "id": f"bot_proactive_{message.id}_{int(time.time())}", + "author": { + "id": str(cog.bot.user.id), + "name": cog.bot.user.name, + "display_name": cog.bot.user.display_name, + "bot": True, + }, + "content": final_parsed_data.get("content", ""), + "created_at": datetime.datetime.now().isoformat(), + "attachments": [], + "embeds": False, + "mentions": [], + "replied_to_message_id": None, + "channel": message.channel, + "guild": message.guild, + "reference": None, + "mentioned_users_details": [], + } + cog.message_cache["by_channel"].setdefault(channel_id, []).append( + bot_response_cache_entry + ) + cog.message_cache["global_recent"].append(bot_response_cache_entry) + cog.bot_last_spoke[channel_id] = time.time() + # Track participation topic logic might need adjustment based on plan goal + if ( + plan + and plan.get("response_goal") == "engage user interest" + and plan.get("key_info_to_include") + ): + topic = ( + plan["key_info_to_include"][0].lower().strip() + ) # Assume first key info is the topic + cog.gurt_participation_topics[topic] += 1 + print(f"Tracked Gurt participation (proactive) in topic: '{topic}'") except Exception as e: error_message = f"Error getting proactive AI response for channel {channel_id} ({trigger_reason}): {type(e).__name__}: {str(e)}" print(error_message) - final_parsed_data = {"should_respond": False, "content": None, "react_with_emoji": None, "error": error_message} + final_parsed_data = { + "should_respond": False, + "content": None, + "react_with_emoji": None, + "error": error_message, + } # Ensure default keys exist final_parsed_data.setdefault("should_respond", False) final_parsed_data.setdefault("content", None) final_parsed_data.setdefault("react_with_emoji", None) - final_parsed_data.setdefault("request_tenor_gif_query", None) # Ensure this key exists + final_parsed_data.setdefault( + "request_tenor_gif_query", None + ) # Ensure this key exists if error_message and "error" not in final_parsed_data: - final_parsed_data["error"] = error_message - - sticker_ids_to_send_proactive: List[str] = [] # Initialize list for sticker IDs + final_parsed_data["error"] = error_message + + sticker_ids_to_send_proactive: List[str] = [] # Initialize list for sticker IDs # --- Handle Custom Emoji/Sticker Replacement in Proactive Content --- if final_parsed_data and final_parsed_data.get("content"): content_to_process = final_parsed_data["content"] # Find all potential custom emoji/sticker names like :name: - potential_custom_items = re.findall(r':([\w\d_]+?):', content_to_process) + potential_custom_items = re.findall(r":([\w\d_]+?):", content_to_process) modified_content = content_to_process for item_name_key in potential_custom_items: full_item_name_with_colons = f":{item_name_key}:" - + # Check for custom emoji (logic remains similar to main response) emoji_data = await cog.emoji_manager.get_emoji(full_item_name_with_colons) can_use_emoji = False @@ -1801,23 +2588,38 @@ async def get_proactive_ai_response(cog: 'GurtCog', message: discord.Message, tr if emoji_guild_id is not None: try: guild = cog.bot.get_guild(int(emoji_guild_id)) - if guild: can_use_emoji = True - except ValueError: pass # Invalid guild_id - else: can_use_emoji = True # Usable if no guild_id - + if guild: + can_use_emoji = True + except ValueError: + pass # Invalid guild_id + else: + can_use_emoji = True # Usable if no guild_id + if can_use_emoji and emoji_id: - discord_emoji_syntax = f"<{'a' if is_animated else ''}:{item_name_key}:{emoji_id}>" - modified_content = modified_content.replace(full_item_name_with_colons, discord_emoji_syntax, 1) - print(f"Proactive: Replaced custom emoji '{full_item_name_with_colons}' with {discord_emoji_syntax}") + discord_emoji_syntax = ( + f"<{'a' if is_animated else ''}:{item_name_key}:{emoji_id}>" + ) + modified_content = modified_content.replace( + full_item_name_with_colons, discord_emoji_syntax, 1 + ) + print( + f"Proactive: Replaced custom emoji '{full_item_name_with_colons}' with {discord_emoji_syntax}" + ) # Check for custom sticker - sticker_data = await cog.emoji_manager.get_sticker(full_item_name_with_colons) + sticker_data = await cog.emoji_manager.get_sticker( + full_item_name_with_colons + ) can_use_sticker = False - print(f"[PROACTIVE] Checking sticker: '{full_item_name_with_colons}'. Found data: {sticker_data}") + print( + f"[PROACTIVE] Checking sticker: '{full_item_name_with_colons}'. Found data: {sticker_data}" + ) if isinstance(sticker_data, dict): sticker_id = sticker_data.get("id") sticker_guild_id = sticker_data.get("guild_id") - print(f"[PROACTIVE] Sticker '{full_item_name_with_colons}': ID='{sticker_id}', GuildID='{sticker_guild_id}'") + print( + f"[PROACTIVE] Sticker '{full_item_name_with_colons}': ID='{sticker_id}', GuildID='{sticker_guild_id}'" + ) if sticker_id: if sticker_guild_id is not None: @@ -1826,33 +2628,57 @@ async def get_proactive_ai_response(cog: 'GurtCog', message: discord.Message, tr guild = cog.bot.get_guild(guild_id_int) if guild: can_use_sticker = True - print(f"[PROACTIVE] Sticker '{full_item_name_with_colons}' (Guild: {guild.name} ({sticker_guild_id})) - Bot IS a member. CAN USE.") + print( + f"[PROACTIVE] Sticker '{full_item_name_with_colons}' (Guild: {guild.name} ({sticker_guild_id})) - Bot IS a member. CAN USE." + ) else: - print(f"[PROACTIVE] Sticker '{full_item_name_with_colons}' (Guild ID: {sticker_guild_id}) - Bot is NOT in this guild. CANNOT USE.") + print( + f"[PROACTIVE] Sticker '{full_item_name_with_colons}' (Guild ID: {sticker_guild_id}) - Bot is NOT in this guild. CANNOT USE." + ) except ValueError: - print(f"[PROACTIVE] Invalid guild_id format for sticker '{full_item_name_with_colons}': {sticker_guild_id}. CANNOT USE.") - else: # guild_id is None, considered usable + print( + f"[PROACTIVE] Invalid guild_id format for sticker '{full_item_name_with_colons}': {sticker_guild_id}. CANNOT USE." + ) + else: # guild_id is None, considered usable can_use_sticker = True - print(f"[PROACTIVE] Sticker '{full_item_name_with_colons}' has no associated guild_id. CAN USE.") + print( + f"[PROACTIVE] Sticker '{full_item_name_with_colons}' has no associated guild_id. CAN USE." + ) else: - print(f"[PROACTIVE] Sticker '{full_item_name_with_colons}' found in data, but no ID. CANNOT USE.") + print( + f"[PROACTIVE] Sticker '{full_item_name_with_colons}' found in data, but no ID. CANNOT USE." + ) else: - print(f"[PROACTIVE] Sticker '{full_item_name_with_colons}' not found in emoji_manager or data is not dict.") + print( + f"[PROACTIVE] Sticker '{full_item_name_with_colons}' not found in emoji_manager or data is not dict." + ) - print(f"[PROACTIVE] Final check for sticker '{full_item_name_with_colons}': can_use_sticker={can_use_sticker}, sticker_id='{sticker_data.get('id') if isinstance(sticker_data, dict) else None}'") - if can_use_sticker and isinstance(sticker_data, dict) and sticker_data.get("id"): - sticker_id_to_add = sticker_data.get("id") # Re-fetch to be safe - if sticker_id_to_add: # Ensure ID is valid + print( + f"[PROACTIVE] Final check for sticker '{full_item_name_with_colons}': can_use_sticker={can_use_sticker}, sticker_id='{sticker_data.get('id') if isinstance(sticker_data, dict) else None}'" + ) + if ( + can_use_sticker + and isinstance(sticker_data, dict) + and sticker_data.get("id") + ): + sticker_id_to_add = sticker_data.get("id") # Re-fetch to be safe + if sticker_id_to_add: # Ensure ID is valid if full_item_name_with_colons in modified_content: - modified_content = modified_content.replace(full_item_name_with_colons, "", 1).strip() + modified_content = modified_content.replace( + full_item_name_with_colons, "", 1 + ).strip() if sticker_id_to_add not in sticker_ids_to_send_proactive: sticker_ids_to_send_proactive.append(sticker_id_to_add) - print(f"Proactive: Found custom sticker '{full_item_name_with_colons}', removed from content, added ID '{sticker_id_to_add}'") - elif not sticker_id_to_add: # Check sticker_id_to_add here - print(f"[PROACTIVE] Found custom sticker '{full_item_name_with_colons}' (dict) but no ID stored (sticker_id_to_add is falsy).") - + print( + f"Proactive: Found custom sticker '{full_item_name_with_colons}', removed from content, added ID '{sticker_id_to_add}'" + ) + elif not sticker_id_to_add: # Check sticker_id_to_add here + print( + f"[PROACTIVE] Found custom sticker '{full_item_name_with_colons}' (dict) but no ID stored (sticker_id_to_add is falsy)." + ) + # Clean up any double spaces or leading/trailing whitespace after replacements - modified_content = re.sub(r'\s{2,}', ' ', modified_content).strip() + modified_content = re.sub(r"\s{2,}", " ", modified_content).strip() final_parsed_data["content"] = modified_content if sticker_ids_to_send_proactive or (content_to_process != modified_content): print("Proactive content modified for custom emoji/sticker information.") @@ -1862,11 +2688,11 @@ async def get_proactive_ai_response(cog: 'GurtCog', message: discord.Message, tr # --- AI Image Description Function --- async def generate_image_description( - cog: 'GurtCog', + cog: "GurtCog", image_url: str, item_name: str, - item_type: str, # "emoji" or "sticker" - mime_type: str # e.g., "image/png", "image/gif" + item_type: str, # "emoji" or "sticker" + mime_type: str, # e.g., "image/png", "image/gif" ) -> Optional[str]: """ Generates a textual description for an image URL using a multimodal AI model. @@ -1882,49 +2708,74 @@ async def generate_image_description( The AI-generated description string, or None if an error occurs. """ if not genai_client: - print("Error in generate_image_description: Google GenAI Client not initialized.") + print( + "Error in generate_image_description: Google GenAI Client not initialized." + ) return None if not cog.session: - print("Error in generate_image_description: aiohttp session not initialized in cog.") + print( + "Error in generate_image_description: aiohttp session not initialized in cog." + ) return None - print(f"Attempting to generate description for {item_type} '{item_name}' from URL: {image_url}") + print( + f"Attempting to generate description for {item_type} '{item_name}' from URL: {image_url}" + ) try: # 1. Download image data async with cog.session.get(image_url, timeout=15) as response: if response.status != 200: - print(f"Failed to download image from {image_url}. Status: {response.status}") + print( + f"Failed to download image from {image_url}. Status: {response.status}" + ) return None image_bytes = await response.read() - + # Attempt to infer MIME type from bytes inferred_type = imghdr.what(None, h=image_bytes) inferred_mime_type = None - if inferred_type == 'png': - inferred_mime_type = 'image/png' - elif inferred_type == 'jpeg': - inferred_mime_type = 'image/jpeg' - elif inferred_type == 'gif': - inferred_mime_type = 'image/gif' + if inferred_type == "png": + inferred_mime_type = "image/png" + elif inferred_type == "jpeg": + inferred_mime_type = "image/jpeg" + elif inferred_type == "gif": + inferred_mime_type = "image/gif" # imghdr does not directly support webp, so check magic bytes - elif image_bytes.startswith(b'RIFF') and b'WEBP' in image_bytes[:12]: - inferred_mime_type = 'image/webp' + elif image_bytes.startswith(b"RIFF") and b"WEBP" in image_bytes[:12]: + inferred_mime_type = "image/webp" # Add other types as needed # Use inferred_mime_type if it's more specific or if the provided mime_type is generic - final_mime_type = mime_type.split(';')[0].lower() # Start with provided clean mime + final_mime_type = mime_type.split(";")[ + 0 + ].lower() # Start with provided clean mime if inferred_mime_type and inferred_mime_type != final_mime_type: - print(f"MIME type mismatch: Provided '{final_mime_type}', Inferred '{inferred_mime_type}'. Using inferred.") + print( + f"MIME type mismatch: Provided '{final_mime_type}', Inferred '{inferred_mime_type}'. Using inferred." + ) final_mime_type = inferred_mime_type - elif not inferred_mime_type and final_mime_type == "application/octet-stream": - print(f"Warning: Could not infer specific MIME type from bytes. Using provided generic '{final_mime_type}'.") + elif ( + not inferred_mime_type and final_mime_type == "application/octet-stream" + ): + print( + f"Warning: Could not infer specific MIME type from bytes. Using provided generic '{final_mime_type}'." + ) # Validate against known supported image types for Gemini - supported_image_mimes = ["image/png", "image/jpeg", "image/webp", "image/heic", "image/heif", "image/gif"] + supported_image_mimes = [ + "image/png", + "image/jpeg", + "image/webp", + "image/heic", + "image/heif", + "image/gif", + ] if final_mime_type not in supported_image_mimes: - print(f"Warning: Final image MIME type '{final_mime_type}' from {image_url} is not explicitly supported by Gemini. Proceeding anyway.") - + print( + f"Warning: Final image MIME type '{final_mime_type}' from {image_url} is not explicitly supported by Gemini. Proceeding anyway." + ) + print(f"Using final MIME type '{final_mime_type}' for image part.") # 2. Prepare contents for AI @@ -1938,72 +2789,95 @@ async def generate_image_description( "Don't output anything other than the description text. E.G. don't include something like \"Heres the description: \" before the text." ) - image_part = types.Part(inline_data=types.Blob(data=image_bytes, mime_type=final_mime_type)) + image_part = types.Part( + inline_data=types.Blob(data=image_bytes, mime_type=final_mime_type) + ) text_part = types.Part(text=prompt_text) - description_contents: List[types.Content] = [types.Content(role="user", parts=[image_part, text_part])] + description_contents: List[types.Content] = [ + types.Content(role="user", parts=[image_part, text_part]) + ] # 3. Prepare Generation Config # We want a plain text response, no JSON schema. Safety settings are standard (BLOCK_NONE). # System prompt is not strictly needed here as the user prompt is direct. description_gen_config = types.GenerateContentConfig( - temperature=0.4, # Lower temperature for more factual description - max_output_tokens=256, # Descriptions should be concise + temperature=0.4, # Lower temperature for more factual description + max_output_tokens=256, # Descriptions should be concise safety_settings=STANDARD_SAFETY_SETTINGS, # No response_mime_type or response_schema needed for plain text - tools=None, # No tools for this task - tool_config=None + tools=None, # No tools for this task + tool_config=None, ) # 4. Call AI # Use a multimodal model, e.g., DEFAULT_MODEL if it's Gemini 1.5 Pro or similar # Determine which model to use based on item_type - model_to_use = EMOJI_STICKER_DESCRIPTION_MODEL if item_type in ["emoji", "sticker"] else DEFAULT_MODEL - - print(f"Calling AI for image description ({item_name}) using model: {model_to_use}") + model_to_use = ( + EMOJI_STICKER_DESCRIPTION_MODEL + if item_type in ["emoji", "sticker"] + else DEFAULT_MODEL + ) + + print( + f"Calling AI for image description ({item_name}) using model: {model_to_use}" + ) ai_response_obj = await call_google_genai_api_with_retry( cog=cog, model_name=model_to_use, contents=description_contents, generation_config=description_gen_config, - request_desc=f"Image description for {item_type} '{item_name}'" + request_desc=f"Image description for {item_type} '{item_name}'", ) # 5. Extract text if not ai_response_obj: - print(f"AI call for image description of '{item_name}' returned no response object.") + print( + f"AI call for image description of '{item_name}' returned no response object." + ) return None description_text = _get_response_text(ai_response_obj) if description_text: - print(f"Successfully generated description for '{item_name}': {description_text[:100]}...") + print( + f"Successfully generated description for '{item_name}': {description_text[:100]}..." + ) return description_text.strip() else: - print(f"AI response for '{item_name}' contained no usable text. Response: {ai_response_obj}") + print( + f"AI response for '{item_name}' contained no usable text. Response: {ai_response_obj}" + ) return None except aiohttp.ClientError as client_e: - print(f"Network error downloading image {image_url} for description: {client_e}") + print( + f"Network error downloading image {image_url} for description: {client_e}" + ) return None except asyncio.TimeoutError: print(f"Timeout downloading image {image_url} for description.") return None except Exception as e: - print(f"Unexpected error in generate_image_description for '{item_name}': {type(e).__name__}: {e}") + print( + f"Unexpected error in generate_image_description for '{item_name}': {type(e).__name__}: {e}" + ) import traceback + traceback.print_exc() return None # --- Internal AI Call for Specific Tasks --- async def get_internal_ai_json_response( - cog: 'GurtCog', - prompt_messages: List[Dict[str, Any]], # Keep this format + cog: "GurtCog", + prompt_messages: List[Dict[str, Any]], # Keep this format task_description: str, - response_schema_dict: Dict[str, Any], # Expect schema as dict - model_name_override: Optional[str] = None, # Renamed for clarity + response_schema_dict: Dict[str, Any], # Expect schema as dict + model_name_override: Optional[str] = None, # Renamed for clarity temperature: float = 0.7, max_tokens: int = 5000, -) -> Optional[Tuple[Optional[Dict[str, Any]], Optional[str]]]: # Return tuple: (parsed_data, raw_text) +) -> Optional[ + Tuple[Optional[Dict[str, Any]], Optional[str]] +]: # Return tuple: (parsed_data, raw_text) """ Makes a Google GenAI API call (Vertex AI backend) expecting a specific JSON response format for internal tasks. @@ -2022,13 +2896,15 @@ async def get_internal_ai_json_response( - The raw text response received from the API, or None if the call failed before getting text. """ if not PROJECT_ID or not LOCATION: - print(f"Error in get_internal_ai_json_response ({task_description}): GCP Project/Location not set.") - return None, None # Return tuple + print( + f"Error in get_internal_ai_json_response ({task_description}): GCP Project/Location not set." + ) + return None, None # Return tuple final_parsed_data: Optional[Dict[str, Any]] = None final_response_text: Optional[str] = None error_occurred = None - request_payload_for_logging = {} # For logging + request_payload_for_logging = {} # For logging try: # --- Convert prompt messages to Vertex AI types.Content format --- @@ -2044,13 +2920,15 @@ async def get_internal_ai_json_response( else: # Append subsequent system messages to the instruction system_instruction += "\n\n" + content_text - continue # Skip adding system messages to contents list + continue # Skip adding system messages to contents list elif role == "assistant": role = "model" # --- Process content (string or list) --- content_value = msg.get("content") - message_parts: List[types.Part] = [] # Initialize list to hold parts for this message + message_parts: List[types.Part] = ( + [] + ) # Initialize list to hold parts for this message if isinstance(content_value, str): # Handle simple string content @@ -2068,24 +2946,35 @@ async def get_internal_ai_json_response( if mime_type and base64_data: try: image_bytes = base64.b64decode(base64_data) - message_parts.append(types.Part(data=image_bytes, mime_type=mime_type)) + message_parts.append( + types.Part(data=image_bytes, mime_type=mime_type) + ) except Exception as decode_err: - print(f"Error decoding/adding image part in get_internal_ai_json_response: {decode_err}") + print( + f"Error decoding/adding image part in get_internal_ai_json_response: {decode_err}" + ) # Optionally add a placeholder text part indicating failure - message_parts.append(types.Part(text="(System Note: Failed to process an image part)")) + message_parts.append( + types.Part( + text="(System Note: Failed to process an image part)" + ) + ) else: - print("Warning: image_data part missing mime_type or data.") + print("Warning: image_data part missing mime_type or data.") else: - print(f"Warning: Unknown part type '{part_type}' in internal prompt message.") + print( + f"Warning: Unknown part type '{part_type}' in internal prompt message." + ) else: - print(f"Warning: Unexpected content type '{type(content_value)}' in internal prompt message.") + print( + f"Warning: Unexpected content type '{type(content_value)}' in internal prompt message." + ) # Add the content object if parts were generated if message_parts: contents.append(types.Content(role=role, parts=message_parts)) else: - print(f"Warning: No parts generated for message role '{role}'.") - + print(f"Warning: No parts generated for message role '{role}'.") # Add the critical JSON instruction to the last user message or as a new user message json_instruction_content = ( @@ -2094,14 +2983,21 @@ async def get_internal_ai_json_response( f"**Ensure nothing precedes or follows the JSON.**" ) if contents and contents[-1].role == "user": - contents[-1].parts.append(types.Part(text=f"\n\n{json_instruction_content}")) + contents[-1].parts.append( + types.Part(text=f"\n\n{json_instruction_content}") + ) else: - contents.append(types.Content(role="user", parts=[types.Part(text=json_instruction_content)])) - + contents.append( + types.Content( + role="user", parts=[types.Part(text=json_instruction_content)] + ) + ) # --- Determine Model --- # Use override if provided, otherwise default (e.g., FALLBACK_MODEL for planning) - actual_model_name = model_name_override or DEFAULT_MODEL # Or choose a specific default like FALLBACK_MODEL + actual_model_name = ( + model_name_override or DEFAULT_MODEL + ) # Or choose a specific default like FALLBACK_MODEL # --- Prepare Generation Config --- processed_schema_internal = _preprocess_schema_for_vertex(response_schema_dict) @@ -2110,48 +3006,51 @@ async def get_internal_ai_json_response( "max_output_tokens": max_tokens, "response_mime_type": "application/json", "response_schema": processed_schema_internal, - "safety_settings": STANDARD_SAFETY_SETTINGS, # Include standard safety - "tools": None, # No tools for internal JSON tasks - "tool_config": None + "safety_settings": STANDARD_SAFETY_SETTINGS, # Include standard safety + "tools": None, # No tools for internal JSON tasks + "tool_config": None, } generation_config = types.GenerateContentConfig(**internal_gen_config_dict) # --- Prepare Payload for Logging --- # (Logging needs adjustment as model object isn't created here) generation_config_log = { - "temperature": generation_config.temperature, - "max_output_tokens": generation_config.max_output_tokens, - "response_mime_type": generation_config.response_mime_type, - "response_schema": str(generation_config.response_schema) # Log schema as string + "temperature": generation_config.temperature, + "max_output_tokens": generation_config.max_output_tokens, + "response_mime_type": generation_config.response_mime_type, + "response_schema": str( + generation_config.response_schema + ), # Log schema as string } request_payload_for_logging = { - "model": actual_model_name, # Log the name used + "model": actual_model_name, # Log the name used # System instruction is now part of 'contents' for logging if handled that way - "contents": [{"role": c.role, "parts": [str(p) for p in c.parts]} for c in contents], - "generation_config": generation_config_log + "contents": [ + {"role": c.role, "parts": [str(p) for p in c.parts]} for c in contents + ], + "generation_config": generation_config_log, } # (Keep detailed logging logic if desired) try: - print(f"--- Raw request payload for {task_description} ---") - print(json.dumps(request_payload_for_logging, indent=2, default=str)) - print(f"--- End Raw request payload ---") + print(f"--- Raw request payload for {task_description} ---") + print(json.dumps(request_payload_for_logging, indent=2, default=str)) + print(f"--- End Raw request payload ---") except Exception as req_log_e: - print(f"Error logging raw request payload: {req_log_e}") - + print(f"Error logging raw request payload: {req_log_e}") # --- Call API using the new helper --- response_obj = await call_google_genai_api_with_retry( cog=cog, - model_name=actual_model_name, # Pass the determined model name + model_name=actual_model_name, # Pass the determined model name contents=contents, - generation_config=generation_config, # Pass combined config + generation_config=generation_config, # Pass combined config request_desc=task_description, # No separate safety, tools, tool_config args needed ) # --- Process Response --- if not response_obj: - raise Exception("Internal API call failed to return a response object.") + raise Exception("Internal API call failed to return a response object.") # Log the raw response object print(f"--- Full response_obj received for {task_description} ---") @@ -2159,34 +3058,45 @@ async def get_internal_ai_json_response( print(f"--- End Full response_obj ---") if not response_obj.candidates: - print(f"Warning: Internal API call for {task_description} returned no candidates. Response: {response_obj}") - final_response_text = getattr(response_obj, 'text', None) # Try to get text anyway - final_parsed_data = None + print( + f"Warning: Internal API call for {task_description} returned no candidates. Response: {response_obj}" + ) + final_response_text = getattr( + response_obj, "text", None + ) # Try to get text anyway + final_parsed_data = None else: - # Parse and Validate using the updated helper - final_response_text = _get_response_text(response_obj) # Store raw text - print(f"--- Extracted Text for {task_description} ---") - print(final_response_text) - print(f"--- End Extracted Text ---") + # Parse and Validate using the updated helper + final_response_text = _get_response_text(response_obj) # Store raw text + print(f"--- Extracted Text for {task_description} ---") + print(final_response_text) + print(f"--- End Extracted Text ---") - # --- Log Raw Unparsed JSON --- - print(f"--- RAW UNPARSED JSON ({task_description}) ---") - print(final_response_text) - print(f"--- END RAW UNPARSED JSON ---") - # --- End Log --- + # --- Log Raw Unparsed JSON --- + print(f"--- RAW UNPARSED JSON ({task_description}) ---") + print(final_response_text) + print(f"--- END RAW UNPARSED JSON ---") + # --- End Log --- - final_parsed_data = parse_and_validate_json_response( - final_response_text, response_schema_dict, f"internal task ({task_description})" - ) + final_parsed_data = parse_and_validate_json_response( + final_response_text, + response_schema_dict, + f"internal task ({task_description})", + ) - if final_parsed_data is None: - print(f"Warning: Internal task '{task_description}' failed JSON validation.") - # Keep final_response_text for returning raw output + if final_parsed_data is None: + print( + f"Warning: Internal task '{task_description}' failed JSON validation." + ) + # Keep final_response_text for returning raw output except Exception as e: - print(f"Error in get_internal_ai_json_response ({task_description}): {type(e).__name__}: {e}") + print( + f"Error in get_internal_ai_json_response ({task_description}): {type(e).__name__}: {e}" + ) error_occurred = e import traceback + traceback.print_exc() final_parsed_data = None # final_response_text might be None or contain partial/error text depending on when exception occurred @@ -2194,12 +3104,19 @@ async def get_internal_ai_json_response( # Log the call try: # Pass the simplified payload and the *parsed* data for logging - await log_internal_api_call(cog, task_description, request_payload_for_logging, final_parsed_data, error_occurred) + await log_internal_api_call( + cog, + task_description, + request_payload_for_logging, + final_parsed_data, + error_occurred, + ) except Exception as log_e: print(f"Error logging internal API call: {log_e}") # Return both parsed data and raw text return final_parsed_data, final_response_text + if __name__ == "__main__": - print(_preprocess_schema_for_vertex(RESPONSE_SCHEMA['schema'])) + print(_preprocess_schema_for_vertex(RESPONSE_SCHEMA["schema"])) diff --git a/gurt/background.py b/gurt/background.py index c677495..63e907c 100644 --- a/gurt/background.py +++ b/gurt/background.py @@ -5,52 +5,74 @@ import traceback import os import json import aiohttp -import discord # Added import +import discord # Added import from collections import defaultdict -from typing import TYPE_CHECKING, Any, List, Dict # Added List, Dict +from typing import TYPE_CHECKING, Any, List, Dict # Added List, Dict + # Use google.generativeai instead of vertexai directly from google import genai from google.genai import types + # from google.protobuf import json_format # No longer needed for args parsing # Relative imports from .config import ( - GOAL_CHECK_INTERVAL, GOAL_EXECUTION_INTERVAL, LEARNING_UPDATE_INTERVAL, EVOLUTION_UPDATE_INTERVAL, INTEREST_UPDATE_INTERVAL, - INTEREST_DECAY_INTERVAL_HOURS, INTEREST_PARTICIPATION_BOOST, - INTEREST_POSITIVE_REACTION_BOOST, INTEREST_NEGATIVE_REACTION_PENALTY, - INTEREST_FACT_BOOST, PROACTIVE_GOAL_CHECK_INTERVAL, STATS_PUSH_INTERVAL, # Added stats interval - MOOD_OPTIONS, MOOD_CATEGORIES, MOOD_CHANGE_INTERVAL_MIN, MOOD_CHANGE_INTERVAL_MAX, # Mood change imports - BASELINE_PERSONALITY, # For default traits - REFLECTION_INTERVAL_SECONDS # Import reflection interval + GOAL_CHECK_INTERVAL, + GOAL_EXECUTION_INTERVAL, + LEARNING_UPDATE_INTERVAL, + EVOLUTION_UPDATE_INTERVAL, + INTEREST_UPDATE_INTERVAL, + INTEREST_DECAY_INTERVAL_HOURS, + INTEREST_PARTICIPATION_BOOST, + INTEREST_POSITIVE_REACTION_BOOST, + INTEREST_NEGATIVE_REACTION_PENALTY, + INTEREST_FACT_BOOST, + PROACTIVE_GOAL_CHECK_INTERVAL, + STATS_PUSH_INTERVAL, # Added stats interval + MOOD_OPTIONS, + MOOD_CATEGORIES, + MOOD_CHANGE_INTERVAL_MIN, + MOOD_CHANGE_INTERVAL_MAX, # Mood change imports + BASELINE_PERSONALITY, # For default traits + REFLECTION_INTERVAL_SECONDS, # Import reflection interval ) + # Assuming analysis functions are moved from .analysis import ( - analyze_conversation_patterns, evolve_personality, identify_conversation_topics, - reflect_on_memories, decompose_goal_into_steps, # Import goal decomposition - proactively_create_goals # Import placeholder for proactive goal creation + analyze_conversation_patterns, + evolve_personality, + identify_conversation_topics, + reflect_on_memories, + decompose_goal_into_steps, # Import goal decomposition + proactively_create_goals, # Import placeholder for proactive goal creation ) + # Import helpers from api.py from .api import ( get_internal_ai_json_response, - call_google_genai_api_with_retry, # Import the retry helper - find_function_call_in_parts, # Import function call finder - _get_response_text, # Import text extractor - _preprocess_schema_for_vertex, # Import schema preprocessor (name kept for now) - STANDARD_SAFETY_SETTINGS, # Import safety settings - process_requested_tools # Import tool processor + call_google_genai_api_with_retry, # Import the retry helper + find_function_call_in_parts, # Import function call finder + _get_response_text, # Import text extractor + _preprocess_schema_for_vertex, # Import schema preprocessor (name kept for now) + STANDARD_SAFETY_SETTINGS, # Import safety settings + process_requested_tools, # Import tool processor ) if TYPE_CHECKING: - from .cog import GurtCog # For type hinting + from .cog import GurtCog # For type hinting # --- Tool Mapping Import --- # Import the mapping to execute tools by name -from .tools import TOOL_MAPPING, send_discord_message # Also import send_discord_message directly for goal execution reporting +from .tools import ( + TOOL_MAPPING, + send_discord_message, +) # Also import send_discord_message directly for goal execution reporting from .config import TOOLS # Import FunctionDeclaration list for tool metadata # --- Background Task --- -async def background_processing_task(cog: 'GurtCog'): + +async def background_processing_task(cog: "GurtCog"): """Background task that periodically analyzes conversations, evolves personality, updates interests, changes mood, reflects on memory, and pushes stats.""" # Get API details from environment for stats pushing api_internal_url = os.getenv("API_INTERNAL_URL") @@ -59,56 +81,80 @@ async def background_processing_task(cog: 'GurtCog'): if not api_internal_url: print("WARNING: API_INTERNAL_URL not set. Gurt stats will not be pushed.") if not gurt_stats_push_secret: - print("WARNING: GURT_STATS_PUSH_SECRET not set. Gurt stats push endpoint is insecure and likely won't work.") + print( + "WARNING: GURT_STATS_PUSH_SECRET not set. Gurt stats push endpoint is insecure and likely won't work." + ) try: while True: - await asyncio.sleep(15) # Check more frequently for stats push + await asyncio.sleep(15) # Check more frequently for stats push now = time.time() # --- Push Stats (Runs frequently) --- - if api_internal_url and gurt_stats_push_secret and (now - cog.last_stats_push > STATS_PUSH_INTERVAL): + if ( + api_internal_url + and gurt_stats_push_secret + and (now - cog.last_stats_push > STATS_PUSH_INTERVAL) + ): print("Pushing Gurt stats to API server...") try: stats_data = await cog.get_gurt_stats() headers = { "Authorization": f"Bearer {gurt_stats_push_secret}", - "Content-Type": "application/json" + "Content-Type": "application/json", } # Use the cog's session, ensure it's created if cog.session: # Set a reasonable timeout for the stats push - push_timeout = aiohttp.ClientTimeout(total=10) # 10 seconds total timeout - async with cog.session.post(api_internal_url, json=stats_data, headers=headers, timeout=push_timeout, ssl=True) as response: # Explicitly enable SSL verification + push_timeout = aiohttp.ClientTimeout( + total=10 + ) # 10 seconds total timeout + async with cog.session.post( + api_internal_url, + json=stats_data, + headers=headers, + timeout=push_timeout, + ssl=True, + ) as response: # Explicitly enable SSL verification if response.status == 200: - print(f"Successfully pushed Gurt stats (Status: {response.status})") + print( + f"Successfully pushed Gurt stats (Status: {response.status})" + ) else: error_text = await response.text() - print(f"Failed to push Gurt stats (Status: {response.status}): {error_text[:200]}") # Log only first 200 chars + print( + f"Failed to push Gurt stats (Status: {response.status}): {error_text[:200]}" + ) # Log only first 200 chars else: print("Error pushing stats: GurtCog session not initialized.") - cog.last_stats_push = now # Update timestamp even on failure to avoid spamming logs + cog.last_stats_push = ( + now # Update timestamp even on failure to avoid spamming logs + ) except aiohttp.ClientConnectorSSLError as ssl_err: - print(f"SSL Error pushing Gurt stats: {ssl_err}. Ensure the API server's certificate is valid and trusted, or check network configuration.") - print("If using a self-signed certificate for development, the bot process might need to trust it.") - cog.last_stats_push = now # Update timestamp to avoid spamming logs + print( + f"SSL Error pushing Gurt stats: {ssl_err}. Ensure the API server's certificate is valid and trusted, or check network configuration." + ) + print( + "If using a self-signed certificate for development, the bot process might need to trust it." + ) + cog.last_stats_push = now # Update timestamp to avoid spamming logs except aiohttp.ClientError as client_err: print(f"HTTP Client Error pushing Gurt stats: {client_err}") - cog.last_stats_push = now # Update timestamp to avoid spamming logs + cog.last_stats_push = now # Update timestamp to avoid spamming logs except asyncio.TimeoutError: print("Timeout error pushing Gurt stats.") - cog.last_stats_push = now # Update timestamp to avoid spamming logs + cog.last_stats_push = now # Update timestamp to avoid spamming logs except Exception as e: print(f"Unexpected error pushing Gurt stats: {e}") traceback.print_exc() - cog.last_stats_push = now # Update timestamp even on error + cog.last_stats_push = now # Update timestamp even on error # --- Learning Analysis (Runs less frequently) --- if now - cog.last_learning_update > LEARNING_UPDATE_INTERVAL: - if cog.message_cache['global_recent']: + if cog.message_cache["global_recent"]: print("Running conversation pattern analysis...") # This function now likely resides in analysis.py - await analyze_conversation_patterns(cog) # Pass cog instance + await analyze_conversation_patterns(cog) # Pass cog instance cog.last_learning_update = now print("Learning analysis cycle complete.") else: @@ -118,94 +164,142 @@ async def background_processing_task(cog: 'GurtCog'): if now - cog.last_evolution_update > EVOLUTION_UPDATE_INTERVAL: print("Running personality evolution...") # This function now likely resides in analysis.py - await evolve_personality(cog) # Pass cog instance + await evolve_personality(cog) # Pass cog instance cog.last_evolution_update = now print("Personality evolution complete.") # --- Update Interests (Runs moderately frequently) --- if now - cog.last_interest_update > INTEREST_UPDATE_INTERVAL: print("Running interest update...") - await update_interests(cog) # Call the local helper function below + await update_interests(cog) # Call the local helper function below print("Running interest decay check...") await cog.memory_manager.decay_interests( decay_interval_hours=INTEREST_DECAY_INTERVAL_HOURS ) - cog.last_interest_update = now # Reset timer after update and decay check + cog.last_interest_update = ( + now # Reset timer after update and decay check + ) print("Interest update and decay check complete.") # --- Memory Reflection (Runs less frequently) --- if now - cog.last_reflection_time > REFLECTION_INTERVAL_SECONDS: print("Running memory reflection...") - await reflect_on_memories(cog) # Call the reflection function from analysis.py - cog.last_reflection_time = now # Update timestamp + await reflect_on_memories( + cog + ) # Call the reflection function from analysis.py + cog.last_reflection_time = now # Update timestamp print("Memory reflection cycle complete.") # --- Goal Decomposition (Runs periodically) --- # Check less frequently than other tasks, e.g., every few minutes - if now - cog.last_goal_check_time > GOAL_CHECK_INTERVAL: # Need to add these to cog and config + if ( + now - cog.last_goal_check_time > GOAL_CHECK_INTERVAL + ): # Need to add these to cog and config print("Checking for pending goals to decompose...") try: - pending_goals = await cog.memory_manager.get_goals(status='pending', limit=3) # Limit decomposition attempts per cycle + pending_goals = await cog.memory_manager.get_goals( + status="pending", limit=3 + ) # Limit decomposition attempts per cycle for goal in pending_goals: - goal_id = goal.get('goal_id') - description = goal.get('description') - if not goal_id or not description: continue + goal_id = goal.get("goal_id") + description = goal.get("description") + if not goal_id or not description: + continue print(f" - Decomposing goal ID {goal_id}: '{description}'") plan = await decompose_goal_into_steps(cog, description) - if plan and plan.get('goal_achievable') and plan.get('steps'): + if plan and plan.get("goal_achievable") and plan.get("steps"): # Goal is achievable and has steps, update status to active and store plan - await cog.memory_manager.update_goal(goal_id, status='active', details=plan) - print(f" - Goal ID {goal_id} decomposed and set to active.") + await cog.memory_manager.update_goal( + goal_id, status="active", details=plan + ) + print( + f" - Goal ID {goal_id} decomposed and set to active." + ) elif plan: # Goal deemed not achievable by planner - await cog.memory_manager.update_goal(goal_id, status='failed', details={"reason": plan.get('reasoning', 'Deemed unachievable by planner.')}) - print(f" - Goal ID {goal_id} marked as failed (unachievable). Reason: {plan.get('reasoning')}") + await cog.memory_manager.update_goal( + goal_id, + status="failed", + details={ + "reason": plan.get( + "reasoning", "Deemed unachievable by planner." + ) + }, + ) + print( + f" - Goal ID {goal_id} marked as failed (unachievable). Reason: {plan.get('reasoning')}" + ) else: # Decomposition failed entirely - await cog.memory_manager.update_goal(goal_id, status='failed', details={"reason": "Goal decomposition process failed."}) - print(f" - Goal ID {goal_id} marked as failed (decomposition error).") - await asyncio.sleep(1) # Small delay between decomposing goals + await cog.memory_manager.update_goal( + goal_id, + status="failed", + details={ + "reason": "Goal decomposition process failed." + }, + ) + print( + f" - Goal ID {goal_id} marked as failed (decomposition error)." + ) + await asyncio.sleep(1) # Small delay between decomposing goals - cog.last_goal_check_time = now # Update timestamp after checking + cog.last_goal_check_time = now # Update timestamp after checking except Exception as goal_e: print(f"Error during goal decomposition check: {goal_e}") traceback.print_exc() - cog.last_goal_check_time = now # Update timestamp even on error + cog.last_goal_check_time = now # Update timestamp even on error # --- Goal Execution (Runs periodically) --- if now - cog.last_goal_execution_time > GOAL_EXECUTION_INTERVAL: print("Checking for active goals to execute...") try: - active_goals = await cog.memory_manager.get_goals(status='active', limit=1) # Process one active goal per cycle for now + active_goals = await cog.memory_manager.get_goals( + status="active", limit=1 + ) # Process one active goal per cycle for now if active_goals: - goal = active_goals[0] # Get the highest priority active goal - goal_id = goal.get('goal_id') - description = goal.get('description') - plan = goal.get('details') # The decomposition plan is stored here + goal = active_goals[0] # Get the highest priority active goal + goal_id = goal.get("goal_id") + description = goal.get("description") + plan = goal.get( + "details" + ) # The decomposition plan is stored here # Retrieve context saved with the goal - goal_context_guild_id = goal.get('guild_id') - goal_context_channel_id = goal.get('channel_id') - goal_context_user_id = goal.get('user_id') + goal_context_guild_id = goal.get("guild_id") + goal_context_channel_id = goal.get("channel_id") + goal_context_user_id = goal.get("user_id") - if goal_id and description and plan and isinstance(plan.get('steps'), list): - print(f"--- Executing Goal ID {goal_id}: '{description}' (Context: G={goal_context_guild_id}, C={goal_context_channel_id}, U={goal_context_user_id}) ---") - steps = plan['steps'] - current_step_index = plan.get('current_step_index', 0) # Track progress + if ( + goal_id + and description + and plan + and isinstance(plan.get("steps"), list) + ): + print( + f"--- Executing Goal ID {goal_id}: '{description}' (Context: G={goal_context_guild_id}, C={goal_context_channel_id}, U={goal_context_user_id}) ---" + ) + steps = plan["steps"] + current_step_index = plan.get( + "current_step_index", 0 + ) # Track progress goal_failed = False goal_completed = False if current_step_index < len(steps): step = steps[current_step_index] - step_desc = step.get('step_description') - tool_name = step.get('tool_name') - tool_args = step.get('tool_arguments') + step_desc = step.get("step_description") + tool_name = step.get("tool_name") + tool_args = step.get("tool_arguments") - print(f" - Step {current_step_index + 1}/{len(steps)}: {step_desc}") + print( + f" - Step {current_step_index + 1}/{len(steps)}: {step_desc}" + ) if tool_name: - print(f" - Attempting tool: {tool_name} with args: {tool_args}") + print( + f" - Attempting tool: {tool_name} with args: {tool_args}" + ) tool_func = TOOL_MAPPING.get(tool_name) tool_result = None tool_error = None @@ -214,42 +308,75 @@ async def background_processing_task(cog: 'GurtCog'): if tool_func: try: # Ensure args are a dictionary, default to empty if None/missing - args_to_pass = tool_args if isinstance(tool_args, dict) else {} - print(f" - Executing: {tool_name}(cog, **{args_to_pass})") + args_to_pass = ( + tool_args + if isinstance(tool_args, dict) + else {} + ) + print( + f" - Executing: {tool_name}(cog, **{args_to_pass})" + ) start_time = time.monotonic() - tool_result = await tool_func(cog, **args_to_pass) + tool_result = await tool_func( + cog, **args_to_pass + ) end_time = time.monotonic() - print(f" - Tool '{tool_name}' returned: {str(tool_result)[:200]}...") # Log truncated result + print( + f" - Tool '{tool_name}' returned: {str(tool_result)[:200]}..." + ) # Log truncated result # Check result for success/error - if isinstance(tool_result, dict) and "error" in tool_result: + if ( + isinstance(tool_result, dict) + and "error" in tool_result + ): tool_error = tool_result["error"] - print(f" - Tool '{tool_name}' reported error: {tool_error}") - cog.tool_stats[tool_name]["failure"] += 1 + print( + f" - Tool '{tool_name}' reported error: {tool_error}" + ) + cog.tool_stats[tool_name][ + "failure" + ] += 1 else: tool_success = True - print(f" - Tool '{tool_name}' executed successfully.") - cog.tool_stats[tool_name]["success"] += 1 + print( + f" - Tool '{tool_name}' executed successfully." + ) + cog.tool_stats[tool_name][ + "success" + ] += 1 # Record stats cog.tool_stats[tool_name]["count"] += 1 - cog.tool_stats[tool_name]["total_time"] += (end_time - start_time) + cog.tool_stats[tool_name]["total_time"] += ( + end_time - start_time + ) except Exception as exec_e: tool_error = f"Exception during execution: {str(exec_e)}" - print(f" - Tool '{tool_name}' raised exception: {exec_e}") + print( + f" - Tool '{tool_name}' raised exception: {exec_e}" + ) traceback.print_exc() cog.tool_stats[tool_name]["failure"] += 1 - cog.tool_stats[tool_name]["count"] += 1 # Count failures too + cog.tool_stats[tool_name][ + "count" + ] += 1 # Count failures too else: tool_error = f"Tool '{tool_name}' not found in TOOL_MAPPING." print(f" - Error: {tool_error}") # --- Send Update Message (if channel context exists) --- ### MODIFICATION START ### if goal_context_channel_id: - step_number_display = current_step_index + 1 # Human-readable step number for display + step_number_display = ( + current_step_index + 1 + ) # Human-readable step number for display status_emoji = "✅" if tool_success else "❌" # Use the helper function to create a summary - step_result_summary = _create_result_summary(tool_result if tool_success else {"error": tool_error}) + step_result_summary = _create_result_summary( + tool_result + if tool_success + else {"error": tool_error} + ) update_message = ( f"**Goal Update (ID: {goal_id}, Step {step_number_display}/{len(steps)})** {status_emoji}\n" @@ -261,14 +388,24 @@ async def background_processing_task(cog: 'GurtCog'): ) # Limit message length if len(update_message) > 1900: - update_message = update_message[:1900] + "...`" + update_message = ( + update_message[:1900] + "...`" + ) try: # Use the imported send_discord_message function - await send_discord_message(cog, channel_id=goal_context_channel_id, message_content=update_message) - print(f" - Sent goal update to channel {goal_context_channel_id}") + await send_discord_message( + cog, + channel_id=goal_context_channel_id, + message_content=update_message, + ) + print( + f" - Sent goal update to channel {goal_context_channel_id}" + ) except Exception as msg_err: - print(f" - Failed to send goal update message to channel {goal_context_channel_id}: {msg_err}") + print( + f" - Failed to send goal update message to channel {goal_context_channel_id}: {msg_err}" + ) ### MODIFICATION END ### # --- Handle Tool Outcome --- @@ -278,64 +415,106 @@ async def background_processing_task(cog: 'GurtCog'): current_step_index += 1 else: goal_failed = True - plan['error_message'] = f"Failed at step {current_step_index + 1} ({tool_name}): {tool_error}" + plan["error_message"] = ( + f"Failed at step {current_step_index + 1} ({tool_name}): {tool_error}" + ) else: # Step doesn't require a tool (e.g., internal reasoning/check) - print(" - No tool required for this step (internal check/reasoning).") + print( + " - No tool required for this step (internal check/reasoning)." + ) # Send update message for non-tool steps too? Optional. For now, only for tool steps. - current_step_index += 1 # Assume non-tool steps succeed for now + current_step_index += ( + 1 # Assume non-tool steps succeed for now + ) # Check if goal completed if not goal_failed and current_step_index >= len(steps): goal_completed = True # --- Update Goal Status --- - plan['current_step_index'] = current_step_index # Update progress + plan["current_step_index"] = ( + current_step_index # Update progress + ) if goal_completed: - await cog.memory_manager.update_goal(goal_id, status='completed', details=plan) - print(f"--- Goal ID {goal_id} completed successfully. ---") + await cog.memory_manager.update_goal( + goal_id, status="completed", details=plan + ) + print( + f"--- Goal ID {goal_id} completed successfully. ---" + ) elif goal_failed: - await cog.memory_manager.update_goal(goal_id, status='failed', details=plan) + await cog.memory_manager.update_goal( + goal_id, status="failed", details=plan + ) print(f"--- Goal ID {goal_id} failed. ---") else: # Update details with current step index if still in progress - await cog.memory_manager.update_goal(goal_id, details=plan) - print(f" - Goal ID {goal_id} progress updated to step {current_step_index}.") + await cog.memory_manager.update_goal( + goal_id, details=plan + ) + print( + f" - Goal ID {goal_id} progress updated to step {current_step_index}." + ) else: # Should not happen if status is 'active', but handle defensively - print(f" - Goal ID {goal_id} is active but has no steps or index out of bounds. Marking as failed.") - await cog.memory_manager.update_goal(goal_id, status='failed', details={"reason": "Active goal has invalid step data."}) + print( + f" - Goal ID {goal_id} is active but has no steps or index out of bounds. Marking as failed." + ) + await cog.memory_manager.update_goal( + goal_id, + status="failed", + details={ + "reason": "Active goal has invalid step data." + }, + ) else: - print(f" - Skipping active goal ID {goal_id}: Missing description or valid plan/steps.") - # Optionally mark as failed if plan is invalid - if goal_id: - await cog.memory_manager.update_goal(goal_id, status='failed', details={"reason": "Invalid plan structure found during execution."}) + print( + f" - Skipping active goal ID {goal_id}: Missing description or valid plan/steps." + ) + # Optionally mark as failed if plan is invalid + if goal_id: + await cog.memory_manager.update_goal( + goal_id, + status="failed", + details={ + "reason": "Invalid plan structure found during execution." + }, + ) else: print("No active goals found to execute.") - cog.last_goal_execution_time = now # Update timestamp after checking/executing + cog.last_goal_execution_time = ( + now # Update timestamp after checking/executing + ) except Exception as goal_exec_e: print(f"Error during goal execution check: {goal_exec_e}") traceback.print_exc() - cog.last_goal_execution_time = now # Update timestamp even on error + cog.last_goal_execution_time = now # Update timestamp even on error # --- Automatic Mood Change (Runs based on its own interval check) --- # await maybe_change_mood(cog) # Call the mood change logic # --- Proactive Goal Creation Check (Runs periodically) --- - if now - cog.last_proactive_goal_check > PROACTIVE_GOAL_CHECK_INTERVAL: # Use imported config + if ( + now - cog.last_proactive_goal_check > PROACTIVE_GOAL_CHECK_INTERVAL + ): # Use imported config print("Checking if Gurt should proactively create goals...") try: - await proactively_create_goals(cog) # Call the function from analysis.py - cog.last_proactive_goal_check = now # Update timestamp + await proactively_create_goals( + cog + ) # Call the function from analysis.py + cog.last_proactive_goal_check = now # Update timestamp print("Proactive goal check complete.") except Exception as proactive_e: print(f"Error during proactive goal check: {proactive_e}") traceback.print_exc() - cog.last_proactive_goal_check = now # Update timestamp even on error + cog.last_proactive_goal_check = ( + now # Update timestamp even on error + ) # Ensure these except blocks match the initial 'try' at the function start except asyncio.CancelledError: @@ -343,7 +522,8 @@ async def background_processing_task(cog: 'GurtCog'): except Exception as e: print(f"Error in background processing task: {e}") traceback.print_exc() - await asyncio.sleep(300) # Wait 5 minutes before retrying after an error + await asyncio.sleep(300) # Wait 5 minutes before retrying after an error + # --- Helper for Summarizing Tool Results --- def _create_result_summary(tool_result: Any, max_len: int = 200) -> str: @@ -358,11 +538,11 @@ def _create_result_summary(tool_result: Any, max_len: int = 200) -> str: if "stderr" in tool_result and tool_result["stderr"]: summary += f", stderr: {tool_result['stderr'][:max_len//2]}" if "content" in tool_result: - summary += f", content: {tool_result['content'][:max_len//2]}..." + summary += f", content: {tool_result['content'][:max_len//2]}..." if "bytes_written" in tool_result: - summary += f", bytes: {tool_result['bytes_written']}" + summary += f", bytes: {tool_result['bytes_written']}" if "message_id" in tool_result: - summary += f", msg_id: {tool_result['message_id']}" + summary += f", msg_id: {tool_result['message_id']}" # Add other common keys as needed return summary[:max_len] else: @@ -383,80 +563,133 @@ def _create_result_summary(tool_result: Any, max_len: int = 200) -> str: # --- Interest Update Logic --- -async def update_interests(cog: 'GurtCog'): + +async def update_interests(cog: "GurtCog"): """Analyzes recent activity and updates Gurt's interest levels.""" print("Starting interest update cycle...") try: interest_changes = defaultdict(float) # 1. Analyze Gurt's participation in topics - print(f"Analyzing Gurt participation topics: {dict(cog.gurt_participation_topics)}") + print( + f"Analyzing Gurt participation topics: {dict(cog.gurt_participation_topics)}" + ) for topic, count in cog.gurt_participation_topics.items(): boost = INTEREST_PARTICIPATION_BOOST * count interest_changes[topic] += boost - print(f" - Participation boost for '{topic}': +{boost:.3f} (Count: {count})") + print( + f" - Participation boost for '{topic}': +{boost:.3f} (Count: {count})" + ) # 2. Analyze reactions to Gurt's messages - print(f"Analyzing {len(cog.gurt_message_reactions)} reactions to Gurt's messages...") + print( + f"Analyzing {len(cog.gurt_message_reactions)} reactions to Gurt's messages..." + ) processed_reaction_messages = set() reactions_to_process = list(cog.gurt_message_reactions.items()) for message_id, reaction_data in reactions_to_process: - if message_id in processed_reaction_messages: continue + if message_id in processed_reaction_messages: + continue topic = reaction_data.get("topic") if not topic: try: - gurt_msg_data = next((msg for msg in cog.message_cache['global_recent'] if msg['id'] == message_id), None) - if gurt_msg_data and gurt_msg_data['content']: - # Use identify_conversation_topics from analysis.py - identified_topics = identify_conversation_topics(cog, [gurt_msg_data]) # Pass cog - if identified_topics: - topic = identified_topics[0]['topic'] - print(f" - Determined topic '{topic}' for reaction msg {message_id} retrospectively.") - else: print(f" - Could not determine topic for reaction msg {message_id} retrospectively."); continue - else: print(f" - Could not find Gurt msg {message_id} in cache for reaction analysis."); continue - except Exception as topic_e: print(f" - Error determining topic for reaction msg {message_id}: {topic_e}"); continue # Corrected indent + gurt_msg_data = next( + ( + msg + for msg in cog.message_cache["global_recent"] + if msg["id"] == message_id + ), + None, + ) + if gurt_msg_data and gurt_msg_data["content"]: + # Use identify_conversation_topics from analysis.py + identified_topics = identify_conversation_topics( + cog, [gurt_msg_data] + ) # Pass cog + if identified_topics: + topic = identified_topics[0]["topic"] + print( + f" - Determined topic '{topic}' for reaction msg {message_id} retrospectively." + ) + else: + print( + f" - Could not determine topic for reaction msg {message_id} retrospectively." + ) + continue + else: + print( + f" - Could not find Gurt msg {message_id} in cache for reaction analysis." + ) + continue + except Exception as topic_e: + print( + f" - Error determining topic for reaction msg {message_id}: {topic_e}" + ) + continue # Corrected indent if topic: topic = topic.lower().strip() pos_reactions = reaction_data.get("positive", 0) neg_reactions = reaction_data.get("negative", 0) change = 0 - if pos_reactions > neg_reactions: change = INTEREST_POSITIVE_REACTION_BOOST * (pos_reactions - neg_reactions) - elif neg_reactions > pos_reactions: change = INTEREST_NEGATIVE_REACTION_PENALTY * (neg_reactions - pos_reactions) + if pos_reactions > neg_reactions: + change = INTEREST_POSITIVE_REACTION_BOOST * ( + pos_reactions - neg_reactions + ) + elif neg_reactions > pos_reactions: + change = INTEREST_NEGATIVE_REACTION_PENALTY * ( + neg_reactions - pos_reactions + ) if change != 0: interest_changes[topic] += change - print(f" - Reaction change for '{topic}' on msg {message_id}: {change:+.3f} ({pos_reactions} pos, {neg_reactions} neg)") + print( + f" - Reaction change for '{topic}' on msg {message_id}: {change:+.3f} ({pos_reactions} pos, {neg_reactions} neg)" + ) processed_reaction_messages.add(message_id) # 3. Analyze recently learned facts try: recent_facts = await cog.memory_manager.get_general_facts(limit=10) - print(f"Analyzing {len(recent_facts)} recent general facts for interest boosts...") + print( + f"Analyzing {len(recent_facts)} recent general facts for interest boosts..." + ) for fact in recent_facts: fact_lower = fact.lower() # Basic keyword checks (could be improved) - if "game" in fact_lower or "gaming" in fact_lower: interest_changes["gaming"] += INTEREST_FACT_BOOST; print(f" - Fact boost for 'gaming'") - if "anime" in fact_lower or "manga" in fact_lower: interest_changes["anime"] += INTEREST_FACT_BOOST; print(f" - Fact boost for 'anime'") - if "teto" in fact_lower: interest_changes["kasane teto"] += INTEREST_FACT_BOOST * 2; print(f" - Fact boost for 'kasane teto'") + if "game" in fact_lower or "gaming" in fact_lower: + interest_changes["gaming"] += INTEREST_FACT_BOOST + print(f" - Fact boost for 'gaming'") + if "anime" in fact_lower or "manga" in fact_lower: + interest_changes["anime"] += INTEREST_FACT_BOOST + print(f" - Fact boost for 'anime'") + if "teto" in fact_lower: + interest_changes["kasane teto"] += INTEREST_FACT_BOOST * 2 + print(f" - Fact boost for 'kasane teto'") # Add more checks... - except Exception as fact_e: print(f" - Error analyzing recent facts: {fact_e}") + except Exception as fact_e: + print(f" - Error analyzing recent facts: {fact_e}") # --- Apply Changes --- print(f"Applying interest changes: {dict(interest_changes)}") if interest_changes: for topic, change in interest_changes.items(): - if change != 0: await cog.memory_manager.update_interest(topic, change) - else: print("No interest changes to apply.") + if change != 0: + await cog.memory_manager.update_interest(topic, change) + else: + print("No interest changes to apply.") # Clear temporary tracking data cog.gurt_participation_topics.clear() now = time.time() reactions_to_keep = { - msg_id: data for msg_id, data in cog.gurt_message_reactions.items() + msg_id: data + for msg_id, data in cog.gurt_message_reactions.items() if data.get("timestamp", 0) > (now - INTEREST_UPDATE_INTERVAL * 1.1) } - cog.gurt_message_reactions = defaultdict(lambda: {"positive": 0, "negative": 0, "topic": None}, reactions_to_keep) + cog.gurt_message_reactions = defaultdict( + lambda: {"positive": 0, "negative": 0, "topic": None}, reactions_to_keep + ) print("Interest update cycle finished.") diff --git a/gurt/cog.py b/gurt/cog.py index ab31ad8..030709b 100644 --- a/gurt/cog.py +++ b/gurt/cog.py @@ -11,61 +11,104 @@ from typing import Dict, List, Any, Optional, Tuple, Set, Union # Third-party imports needed by the Cog itself or its direct methods from dotenv import load_dotenv -from tavily import TavilyClient # Needed for tavily_client init +from tavily import TavilyClient # Needed for tavily_client init + # Interpreter and docker might only be needed by tools.py now # --- Relative Imports from Gurt Package --- from .config import ( - PROJECT_ID, LOCATION, TAVILY_API_KEY, TENOR_API_KEY, DEFAULT_MODEL, FALLBACK_MODEL, # Use GCP config - DB_PATH, CHROMA_PATH, SEMANTIC_MODEL_NAME, MAX_USER_FACTS, MAX_GENERAL_FACTS, - MOOD_OPTIONS, BASELINE_PERSONALITY, BASELINE_INTERESTS, MOOD_CHANGE_INTERVAL_MIN, - MOOD_CHANGE_INTERVAL_MAX, CHANNEL_TOPIC_CACHE_TTL, CONTEXT_WINDOW_SIZE, - API_TIMEOUT, SUMMARY_API_TIMEOUT, API_RETRY_ATTEMPTS, API_RETRY_DELAY, - PROACTIVE_LULL_THRESHOLD, PROACTIVE_BOT_SILENCE_THRESHOLD, PROACTIVE_LULL_CHANCE, - PROACTIVE_TOPIC_RELEVANCE_THRESHOLD, PROACTIVE_TOPIC_CHANCE, - PROACTIVE_RELATIONSHIP_SCORE_THRESHOLD, PROACTIVE_RELATIONSHIP_CHANCE, - INTEREST_UPDATE_INTERVAL, INTEREST_DECAY_INTERVAL_HOURS, - LEARNING_UPDATE_INTERVAL, TOPIC_UPDATE_INTERVAL, SENTIMENT_UPDATE_INTERVAL, - EVOLUTION_UPDATE_INTERVAL, RESPONSE_SCHEMA, TOOLS, # Import necessary configs - IGNORED_CHANNEL_IDS, update_ignored_channels_file # Import for ignored channels + PROJECT_ID, + LOCATION, + TAVILY_API_KEY, + TENOR_API_KEY, + DEFAULT_MODEL, + FALLBACK_MODEL, # Use GCP config + DB_PATH, + CHROMA_PATH, + SEMANTIC_MODEL_NAME, + MAX_USER_FACTS, + MAX_GENERAL_FACTS, + MOOD_OPTIONS, + BASELINE_PERSONALITY, + BASELINE_INTERESTS, + MOOD_CHANGE_INTERVAL_MIN, + MOOD_CHANGE_INTERVAL_MAX, + CHANNEL_TOPIC_CACHE_TTL, + CONTEXT_WINDOW_SIZE, + API_TIMEOUT, + SUMMARY_API_TIMEOUT, + API_RETRY_ATTEMPTS, + API_RETRY_DELAY, + PROACTIVE_LULL_THRESHOLD, + PROACTIVE_BOT_SILENCE_THRESHOLD, + PROACTIVE_LULL_CHANCE, + PROACTIVE_TOPIC_RELEVANCE_THRESHOLD, + PROACTIVE_TOPIC_CHANCE, + PROACTIVE_RELATIONSHIP_SCORE_THRESHOLD, + PROACTIVE_RELATIONSHIP_CHANCE, + INTEREST_UPDATE_INTERVAL, + INTEREST_DECAY_INTERVAL_HOURS, + LEARNING_UPDATE_INTERVAL, + TOPIC_UPDATE_INTERVAL, + SENTIMENT_UPDATE_INTERVAL, + EVOLUTION_UPDATE_INTERVAL, + RESPONSE_SCHEMA, + TOOLS, # Import necessary configs + IGNORED_CHANNEL_IDS, + update_ignored_channels_file, # Import for ignored channels ) + # Import functions/classes from other modules from .memory import MemoryManager from .emojis import EmojiManager from .background import background_processing_task from .commands import setup_commands from .listeners import ( - on_ready_listener, on_message_listener, on_reaction_add_listener, - on_reaction_remove_listener, on_guild_join_listener, # Added on_guild_join_listener - on_guild_emojis_update_listener, on_guild_stickers_update_listener, # Added emoji/sticker update listeners - on_voice_transcription_received_listener, # Added voice transcription listener - on_voice_state_update_listener # Added voice state update listener + on_ready_listener, + on_message_listener, + on_reaction_add_listener, + on_reaction_remove_listener, + on_guild_join_listener, # Added on_guild_join_listener + on_guild_emojis_update_listener, + on_guild_stickers_update_listener, # Added emoji/sticker update listeners + on_voice_transcription_received_listener, # Added voice transcription listener + on_voice_state_update_listener, # Added voice state update listener ) -from . import api # Import api to access generate_image_description +from . import api # Import api to access generate_image_description from . import config as GurtConfig + # Tool mapping is used internally by api.py/process_requested_tools, no need to import here directly unless cog methods call tools directly (they shouldn't) # Analysis, context, prompt, utils functions are called by listeners/commands/background task, not directly by cog methods here usually. # Load environment variables (might be loaded globally in main bot script too) load_dotenv() -class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name + +class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name """A special cog for the Gurt bot that uses Google Vertex AI API""" def __init__(self, bot): self.bot = bot # GCP Project/Location are used by vertexai.init() in api.py - self.tavily_api_key = TAVILY_API_KEY # Use imported config - self.TENOR_API_KEY = TENOR_API_KEY # Store Tenor API Key - self.session: Optional[aiohttp.ClientSession] = None # Keep for other potential HTTP requests (e.g., Piston) - self.tavily_client = TavilyClient(api_key=self.tavily_api_key) if self.tavily_api_key else None - self.default_model = DEFAULT_MODEL # Use imported config - self.fallback_model = FALLBACK_MODEL # Use imported config - self.MOOD_OPTIONS = MOOD_OPTIONS # Make MOOD_OPTIONS available as an instance attribute - self.BASELINE_PERSONALITY = BASELINE_PERSONALITY # Store for commands - self.BASELINE_INTERESTS = BASELINE_INTERESTS # Store for commands - self.current_channel: Optional[Union[discord.TextChannel, discord.Thread, discord.DMChannel]] = None # Type hint current channel - + self.tavily_api_key = TAVILY_API_KEY # Use imported config + self.TENOR_API_KEY = TENOR_API_KEY # Store Tenor API Key + self.session: Optional[aiohttp.ClientSession] = ( + None # Keep for other potential HTTP requests (e.g., Piston) + ) + self.tavily_client = ( + TavilyClient(api_key=self.tavily_api_key) if self.tavily_api_key else None + ) + self.default_model = DEFAULT_MODEL # Use imported config + self.fallback_model = FALLBACK_MODEL # Use imported config + self.MOOD_OPTIONS = ( + MOOD_OPTIONS # Make MOOD_OPTIONS available as an instance attribute + ) + self.BASELINE_PERSONALITY = BASELINE_PERSONALITY # Store for commands + self.BASELINE_INTERESTS = BASELINE_INTERESTS # Store for commands + self.current_channel: Optional[ + Union[discord.TextChannel, discord.Thread, discord.DMChannel] + ] = None # Type hint current channel + # Ignored channels config self.IGNORED_CHANNEL_IDS = IGNORED_CHANNEL_IDS self.update_ignored_channels_file = update_ignored_channels_file @@ -76,15 +119,15 @@ class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name max_user_facts=MAX_USER_FACTS, max_general_facts=MAX_GENERAL_FACTS, chroma_path=CHROMA_PATH, - semantic_model_name=SEMANTIC_MODEL_NAME + semantic_model_name=SEMANTIC_MODEL_NAME, ) - self.emoji_manager = EmojiManager() # Initialize EmojiManager + self.emoji_manager = EmojiManager() # Initialize EmojiManager # --- State Variables --- # Keep state directly within the cog instance for now self.current_mood = random.choice(MOOD_OPTIONS) self.last_mood_change = time.time() - self.needs_json_reminder = False # Flag to remind AI about JSON format + self.needs_json_reminder = False # Flag to remind AI about JSON format # Learning variables (Consider moving to a dedicated state/learning manager later) self.conversation_patterns = defaultdict(list) @@ -94,30 +137,40 @@ class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name # self.learning_update_interval = LEARNING_UPDATE_INTERVAL # Interval used in background task # Topic tracking - self.active_topics = defaultdict(lambda: { - "topics": [], "last_update": time.time(), "topic_history": [], - "user_topic_interests": defaultdict(list) - }) + self.active_topics = defaultdict( + lambda: { + "topics": [], + "last_update": time.time(), + "topic_history": [], + "user_topic_interests": defaultdict(list), + } + ) # self.topic_update_interval = TOPIC_UPDATE_INTERVAL # Used in analysis # Conversation tracking / Caches self.conversation_history = defaultdict(lambda: deque(maxlen=100)) self.thread_history = defaultdict(lambda: deque(maxlen=50)) self.user_conversation_mapping = defaultdict(set) - self.channel_activity = defaultdict(lambda: 0.0) # Use float for timestamp + self.channel_activity = defaultdict(lambda: 0.0) # Use float for timestamp self.conversation_topics = defaultdict(str) self.user_relationships = defaultdict(dict) - self.conversation_summaries: Dict[int, Dict[str, Any]] = {} # Store dict with summary and timestamp - self.channel_topics_cache: Dict[int, Dict[str, Any]] = {} # Store dict with topic and timestamp + self.conversation_summaries: Dict[int, Dict[str, Any]] = ( + {} + ) # Store dict with summary and timestamp + self.channel_topics_cache: Dict[int, Dict[str, Any]] = ( + {} + ) # Store dict with topic and timestamp # self.channel_topic_cache_ttl = CHANNEL_TOPIC_CACHE_TTL # Used in prompt building self.message_cache = { - 'by_channel': defaultdict(lambda: deque(maxlen=CONTEXT_WINDOW_SIZE)), # Use config - 'by_user': defaultdict(lambda: deque(maxlen=50)), - 'by_thread': defaultdict(lambda: deque(maxlen=50)), - 'global_recent': deque(maxlen=200), - 'mentioned': deque(maxlen=50), - 'replied_to': defaultdict(lambda: deque(maxlen=20)) + "by_channel": defaultdict( + lambda: deque(maxlen=CONTEXT_WINDOW_SIZE) + ), # Use config + "by_user": defaultdict(lambda: deque(maxlen=50)), + "by_thread": defaultdict(lambda: deque(maxlen=50)), + "global_recent": deque(maxlen=200), + "mentioned": deque(maxlen=50), + "replied_to": defaultdict(lambda: deque(maxlen=20)), } self.active_conversations = {} @@ -125,30 +178,55 @@ class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name self.message_reply_map = {} # Enhanced sentiment tracking - self.conversation_sentiment = defaultdict(lambda: { - "overall": "neutral", "intensity": 0.5, "recent_trend": "stable", - "user_sentiments": {}, "last_update": time.time() - }) - self.sentiment_update_interval = SENTIMENT_UPDATE_INTERVAL # Used in analysis + self.conversation_sentiment = defaultdict( + lambda: { + "overall": "neutral", + "intensity": 0.5, + "recent_trend": "stable", + "user_sentiments": {}, + "last_update": time.time(), + } + ) + self.sentiment_update_interval = SENTIMENT_UPDATE_INTERVAL # Used in analysis # Interest Tracking State self.gurt_participation_topics = defaultdict(int) self.last_interest_update = time.time() - self.gurt_message_reactions = defaultdict(lambda: {"positive": 0, "negative": 0, "topic": None, "timestamp": 0.0}) # Added timestamp + self.gurt_message_reactions = defaultdict( + lambda: {"positive": 0, "negative": 0, "topic": None, "timestamp": 0.0} + ) # Added timestamp # Background task handle self.background_task: Optional[asyncio.Task] = None - self.last_evolution_update = time.time() # Used in background task - self.last_stats_push = time.time() # Timestamp for last stats push - self.last_reflection_time = time.time() # Timestamp for last memory reflection - self.last_goal_check_time = time.time() # Timestamp for last goal decomposition check - self.last_goal_execution_time = time.time() # Timestamp for last goal execution check - self.last_proactive_goal_check = time.time() # Timestamp for last proactive goal check - self.last_internal_action_check = time.time() # Timestamp for last internal action check + self.last_evolution_update = time.time() # Used in background task + self.last_stats_push = time.time() # Timestamp for last stats push + self.last_reflection_time = time.time() # Timestamp for last memory reflection + self.last_goal_check_time = ( + time.time() + ) # Timestamp for last goal decomposition check + self.last_goal_execution_time = ( + time.time() + ) # Timestamp for last goal execution check + self.last_proactive_goal_check = ( + time.time() + ) # Timestamp for last proactive goal check + self.last_internal_action_check = ( + time.time() + ) # Timestamp for last internal action check # --- Stats Tracking --- - self.api_stats = defaultdict(lambda: {"success": 0, "failure": 0, "retries": 0, "total_time": 0.0, "count": 0}) # Keyed by model name - self.tool_stats = defaultdict(lambda: {"success": 0, "failure": 0, "total_time": 0.0, "count": 0}) # Keyed by tool name + self.api_stats = defaultdict( + lambda: { + "success": 0, + "failure": 0, + "retries": 0, + "total_time": 0.0, + "count": 0, + } + ) # Keyed by model name + self.tool_stats = defaultdict( + lambda: {"success": 0, "failure": 0, "total_time": 0.0, "count": 0} + ) # Keyed by tool name # --- Setup Commands and Listeners --- # Add commands defined in commands.py @@ -185,16 +263,20 @@ class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name # Vertex AI initialization happens in api.py using PROJECT_ID and LOCATION from config print(f"GurtCog: Using default model: {self.default_model}") if not self.tavily_api_key: - print("WARNING: Tavily API key not configured (TAVILY_API_KEY). Web search disabled.") + print( + "WARNING: Tavily API key not configured (TAVILY_API_KEY). Web search disabled." + ) # Add listeners to the bot instance # We need to define the listener functions here to properly register them # IMPORTANT: Don't override on_member_join or on_member_remove events # Check if the bot already has event listeners for member join/leave - has_member_join = 'on_member_join' in self.bot.extra_events - has_member_remove = 'on_member_remove' in self.bot.extra_events - print(f"GurtCog: Bot already has event listeners - on_member_join: {has_member_join}, on_member_remove: {has_member_remove}") + has_member_join = "on_member_join" in self.bot.extra_events + has_member_remove = "on_member_remove" in self.bot.extra_events + print( + f"GurtCog: Bot already has event listeners - on_member_join: {has_member_join}, on_member_remove: {has_member_remove}" + ) @self.bot.event async def on_ready(): @@ -234,22 +316,30 @@ class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name # Listener for voice transcriptions @self.bot.event - async def on_voice_transcription_received(guild: discord.Guild, user: discord.Member, text: str): + async def on_voice_transcription_received( + guild: discord.Guild, user: discord.Member, text: str + ): # This event is dispatched by VoiceGatewayCog await on_voice_transcription_received_listener(self, guild, user, text) @self.bot.event - async def on_voice_state_update(member: discord.Member, before: discord.VoiceState, after: discord.VoiceState): + async def on_voice_state_update( + member: discord.Member, + before: discord.VoiceState, + after: discord.VoiceState, + ): await on_voice_state_update_listener(self, member, before, after) - print("GurtCog: Additional guild, custom, and voice state event listeners added.") + print( + "GurtCog: Additional guild, custom, and voice state event listeners added." + ) # Start background task if self.background_task is None or self.background_task.done(): self.background_task = asyncio.create_task(background_processing_task(self)) print("GurtCog: Started background processing task.") else: - print("GurtCog: Background processing task already running.") + print("GurtCog: Background processing task already running.") async def cog_unload(self): """Close session and cancel background task""" @@ -272,11 +362,13 @@ class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name """Updates the relationship score between two users.""" # This method accesses self.user_relationships, so it stays here or utils needs cog passed. # Let's keep it here for simplicity for now. - if user_id_1 > user_id_2: user_id_1, user_id_2 = user_id_2, user_id_1 - if user_id_1 not in self.user_relationships: self.user_relationships[user_id_1] = {} + if user_id_1 > user_id_2: + user_id_1, user_id_2 = user_id_2, user_id_1 + if user_id_1 not in self.user_relationships: + self.user_relationships[user_id_1] = {} current_score = self.user_relationships[user_id_1].get(user_id_2, 0.0) - new_score = max(0.0, min(current_score + change, 100.0)) # Clamp 0-100 + new_score = max(0.0, min(current_score + change, 100.0)) # Clamp 0-100 self.user_relationships[user_id_1][user_id_2] = new_score # print(f"Updated relationship {user_id_1}-{user_id_2}: {current_score:.1f} -> {new_score:.1f} ({change:+.1f})") # Debug log @@ -285,75 +377,124 @@ class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name try: name_key = f":{emoji.name}:" emoji_url = str(emoji.url) - guild_id = emoji.guild.id # Get guild_id from the emoji object + guild_id = emoji.guild.id # Get guild_id from the emoji object existing_emoji = await self.emoji_manager.get_emoji(name_key) - if existing_emoji and \ - existing_emoji.get("id") == str(emoji.id) and \ - existing_emoji.get("url") == emoji_url and \ - existing_emoji.get("description") and \ - existing_emoji.get("description") != "No description generated. (Likely filtered by AI or file type was unsupported by model)": + if ( + existing_emoji + and existing_emoji.get("id") == str(emoji.id) + and existing_emoji.get("url") == emoji_url + and existing_emoji.get("description") + and existing_emoji.get("description") + != "No description generated. (Likely filtered by AI or file type was unsupported by model)" + ): # print(f"Skipping already processed emoji: {name_key} in guild {emoji.guild.name}") return - print(f"Generating description for emoji: {name_key} in guild {emoji.guild.name}") + print( + f"Generating description for emoji: {name_key} in guild {emoji.guild.name}" + ) mime_type = "image/gif" if emoji.animated else "image/png" - description = await api.generate_image_description(self, emoji_url, emoji.name, "emoji", mime_type) - await self.emoji_manager.add_emoji(name_key, str(emoji.id), emoji.animated, guild_id, emoji_url, description or "No description generated. (Likely filtered by AI or file type was unsupported by model)") + description = await api.generate_image_description( + self, emoji_url, emoji.name, "emoji", mime_type + ) + await self.emoji_manager.add_emoji( + name_key, + str(emoji.id), + emoji.animated, + guild_id, + emoji_url, + description + or "No description generated. (Likely filtered by AI or file type was unsupported by model)", + ) # await asyncio.sleep(1) # Rate limiting removed for faster parallel processing except Exception as e: - print(f"Error processing single emoji {emoji.name} (ID: {emoji.id}) in guild {emoji.guild.name}: {e}") + print( + f"Error processing single emoji {emoji.name} (ID: {emoji.id}) in guild {emoji.guild.name}: {e}" + ) async def _process_single_sticker(self, sticker: discord.StickerItem): """Processes a single sticker: generates description if needed and updates EmojiManager.""" try: name_key = f":{sticker.name}:" sticker_url = str(sticker.url) - guild_id = sticker.guild_id # Stickers have guild_id directly + guild_id = sticker.guild_id # Stickers have guild_id directly existing_sticker = await self.emoji_manager.get_sticker(name_key) - if existing_sticker and \ - existing_sticker.get("id") == str(sticker.id) and \ - existing_sticker.get("url") == sticker_url and \ - existing_sticker.get("description"): + if ( + existing_sticker + and existing_sticker.get("id") == str(sticker.id) + and existing_sticker.get("url") == sticker_url + and existing_sticker.get("description") + ): # print(f"Skipping already processed sticker: {name_key} in guild ID {guild_id}") return - print(f"Generating description for sticker: {sticker.name} (ID: {sticker.id}) in guild ID {guild_id}") + print( + f"Generating description for sticker: {sticker.name} (ID: {sticker.id}) in guild ID {guild_id}" + ) description_to_add = "No description generated. (Likely filtered by AI or file type was unsupported by model)" - if sticker.format == discord.StickerFormatType.png or sticker.format == discord.StickerFormatType.apng or sticker.format == discord.StickerFormatType.gif: + if ( + sticker.format == discord.StickerFormatType.png + or sticker.format == discord.StickerFormatType.apng + or sticker.format == discord.StickerFormatType.gif + ): format_to_mime = { - discord.StickerFormatType.png: "image/png", - discord.StickerFormatType.apng: "image/apng", - discord.StickerFormatType.gif: "image/gif", + discord.StickerFormatType.png: "image/png", + discord.StickerFormatType.apng: "image/apng", + discord.StickerFormatType.gif: "image/gif", } mime_type = format_to_mime.get(sticker.format, "image/png") - description = await api.generate_image_description(self, sticker_url, sticker.name, "sticker", mime_type) - description_to_add = description or "No description generated. (Likely filtered by AI or file type was unsupported by model)" + description = await api.generate_image_description( + self, sticker_url, sticker.name, "sticker", mime_type + ) + description_to_add = ( + description + or "No description generated. (Likely filtered by AI or file type was unsupported by model)" + ) elif sticker.format == discord.StickerFormatType.lottie: - description_to_add = "Lottie animation, visual description not applicable." + description_to_add = ( + "Lottie animation, visual description not applicable." + ) else: - print(f"Skipping sticker {sticker.name} due to unsupported format: {sticker.format}") + print( + f"Skipping sticker {sticker.name} due to unsupported format: {sticker.format}" + ) description_to_add = f"Unsupported format: {sticker.format}, visual description not applicable." - - await self.emoji_manager.add_sticker(name_key, str(sticker.id), guild_id, sticker_url, description_to_add) + + await self.emoji_manager.add_sticker( + name_key, str(sticker.id), guild_id, sticker_url, description_to_add + ) # await asyncio.sleep(1) # Rate limiting removed for faster parallel processing except Exception as e: - print(f"Error processing single sticker {sticker.name} (ID: {sticker.id}) in guild ID {sticker.guild_id}: {e}") + print( + f"Error processing single sticker {sticker.name} (ID: {sticker.id}) in guild ID {sticker.guild_id}: {e}" + ) async def _fetch_and_process_guild_assets(self, guild: discord.Guild): """Iterates through a guild's emojis and stickers, and processes each one concurrently.""" print(f"Queueing asset processing for guild: {guild.name} ({guild.id})") - emoji_tasks = [asyncio.create_task(self._process_single_emoji(emoji)) for emoji in guild.emojis] - sticker_tasks = [asyncio.create_task(self._process_single_sticker(sticker)) for sticker in guild.stickers] - + emoji_tasks = [ + asyncio.create_task(self._process_single_emoji(emoji)) + for emoji in guild.emojis + ] + sticker_tasks = [ + asyncio.create_task(self._process_single_sticker(sticker)) + for sticker in guild.stickers + ] + all_tasks = emoji_tasks + sticker_tasks if all_tasks: - await asyncio.gather(*all_tasks, return_exceptions=True) # Wait for all tasks for this guild to complete - print(f"Finished concurrent asset processing for guild: {guild.name} ({guild.id}). Processed {len(all_tasks)} potential items.") + await asyncio.gather( + *all_tasks, return_exceptions=True + ) # Wait for all tasks for this guild to complete + print( + f"Finished concurrent asset processing for guild: {guild.name} ({guild.id}). Processed {len(all_tasks)} potential items." + ) else: - print(f"No emojis or stickers to process for guild: {guild.name} ({guild.id})") - + print( + f"No emojis or stickers to process for guild: {guild.name} ({guild.id})" + ) async def initial_emoji_sticker_scan(self): """Scans all guilds GURT is in on startup for emojis and stickers.""" @@ -364,17 +505,22 @@ class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name # Create a task for each guild task = asyncio.create_task(self._fetch_and_process_guild_assets(guild)) tasks.append(task) - + # Optionally, wait for all tasks to complete if needed, or let them run in background # For a startup scan, it's probably fine to let them run without blocking on_ready too long. # If you need to ensure all are done before something else, you can await asyncio.gather(*tasks) # For now, just creating them to run concurrently. print(f"Created {len(tasks)} tasks for initial emoji/sticker scan.") - async def get_gurt_stats(self) -> Dict[str, Any]: """Collects various internal stats for Gurt.""" - stats = {"config": {}, "runtime": {}, "memory": {}, "api_stats": {}, "tool_stats": {}} + stats = { + "config": {}, + "runtime": {}, + "memory": {}, + "api_stats": {}, + "tool_stats": {}, + } # --- Config --- # Selectively pull relevant config values, avoid exposing secrets @@ -386,22 +532,44 @@ class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name stats["config"]["semantic_model_name"] = GurtConfig.SEMANTIC_MODEL_NAME stats["config"]["max_user_facts"] = GurtConfig.MAX_USER_FACTS stats["config"]["max_general_facts"] = GurtConfig.MAX_GENERAL_FACTS - stats["config"]["mood_change_interval_min"] = GurtConfig.MOOD_CHANGE_INTERVAL_MIN - stats["config"]["mood_change_interval_max"] = GurtConfig.MOOD_CHANGE_INTERVAL_MAX - stats["config"]["evolution_update_interval"] = GurtConfig.EVOLUTION_UPDATE_INTERVAL + stats["config"][ + "mood_change_interval_min" + ] = GurtConfig.MOOD_CHANGE_INTERVAL_MIN + stats["config"][ + "mood_change_interval_max" + ] = GurtConfig.MOOD_CHANGE_INTERVAL_MAX + stats["config"][ + "evolution_update_interval" + ] = GurtConfig.EVOLUTION_UPDATE_INTERVAL stats["config"]["context_window_size"] = GurtConfig.CONTEXT_WINDOW_SIZE stats["config"]["api_timeout"] = GurtConfig.API_TIMEOUT stats["config"]["summary_api_timeout"] = GurtConfig.SUMMARY_API_TIMEOUT - stats["config"]["proactive_lull_threshold"] = GurtConfig.PROACTIVE_LULL_THRESHOLD - stats["config"]["proactive_bot_silence_threshold"] = GurtConfig.PROACTIVE_BOT_SILENCE_THRESHOLD - stats["config"]["interest_update_interval"] = GurtConfig.INTEREST_UPDATE_INTERVAL - stats["config"]["interest_decay_interval_hours"] = GurtConfig.INTEREST_DECAY_INTERVAL_HOURS - stats["config"]["learning_update_interval"] = GurtConfig.LEARNING_UPDATE_INTERVAL + stats["config"][ + "proactive_lull_threshold" + ] = GurtConfig.PROACTIVE_LULL_THRESHOLD + stats["config"][ + "proactive_bot_silence_threshold" + ] = GurtConfig.PROACTIVE_BOT_SILENCE_THRESHOLD + stats["config"][ + "interest_update_interval" + ] = GurtConfig.INTEREST_UPDATE_INTERVAL + stats["config"][ + "interest_decay_interval_hours" + ] = GurtConfig.INTEREST_DECAY_INTERVAL_HOURS + stats["config"][ + "learning_update_interval" + ] = GurtConfig.LEARNING_UPDATE_INTERVAL stats["config"]["topic_update_interval"] = GurtConfig.TOPIC_UPDATE_INTERVAL - stats["config"]["sentiment_update_interval"] = GurtConfig.SENTIMENT_UPDATE_INTERVAL + stats["config"][ + "sentiment_update_interval" + ] = GurtConfig.SENTIMENT_UPDATE_INTERVAL stats["config"]["docker_command_timeout"] = GurtConfig.DOCKER_COMMAND_TIMEOUT - stats["config"]["project_id_set"] = bool(GurtConfig.PROJECT_ID != "your-gcp-project-id") # Check if default is overridden - stats["config"]["location_set"] = bool(GurtConfig.LOCATION != "us-central1") # Check if default is overridden + stats["config"]["project_id_set"] = bool( + GurtConfig.PROJECT_ID != "your-gcp-project-id" + ) # Check if default is overridden + stats["config"]["location_set"] = bool( + GurtConfig.LOCATION != "us-central1" + ) # Check if default is overridden stats["config"]["tavily_api_key_set"] = bool(GurtConfig.TAVILY_API_KEY) stats["config"]["piston_api_url_set"] = bool(GurtConfig.PISTON_API_URL) @@ -412,24 +580,44 @@ class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name stats["runtime"]["last_learning_update_timestamp"] = self.last_learning_update stats["runtime"]["last_interest_update_timestamp"] = self.last_interest_update stats["runtime"]["last_evolution_update_timestamp"] = self.last_evolution_update - stats["runtime"]["background_task_running"] = bool(self.background_task and not self.background_task.done()) + stats["runtime"]["background_task_running"] = bool( + self.background_task and not self.background_task.done() + ) stats["runtime"]["active_topics_channels"] = len(self.active_topics) - stats["runtime"]["conversation_history_channels"] = len(self.conversation_history) + stats["runtime"]["conversation_history_channels"] = len( + self.conversation_history + ) stats["runtime"]["thread_history_threads"] = len(self.thread_history) - stats["runtime"]["user_conversation_mappings"] = len(self.user_conversation_mapping) + stats["runtime"]["user_conversation_mappings"] = len( + self.user_conversation_mapping + ) stats["runtime"]["channel_activity_tracked"] = len(self.channel_activity) stats["runtime"]["conversation_topics_tracked"] = len(self.conversation_topics) - stats["runtime"]["user_relationships_pairs"] = sum(len(v) for v in self.user_relationships.values()) - stats["runtime"]["conversation_summaries_cached"] = len(self.conversation_summaries) + stats["runtime"]["user_relationships_pairs"] = sum( + len(v) for v in self.user_relationships.values() + ) + stats["runtime"]["conversation_summaries_cached"] = len( + self.conversation_summaries + ) stats["runtime"]["channel_topics_cached"] = len(self.channel_topics_cache) - stats["runtime"]["message_cache_global_count"] = len(self.message_cache['global_recent']) - stats["runtime"]["message_cache_mentioned_count"] = len(self.message_cache['mentioned']) + stats["runtime"]["message_cache_global_count"] = len( + self.message_cache["global_recent"] + ) + stats["runtime"]["message_cache_mentioned_count"] = len( + self.message_cache["mentioned"] + ) stats["runtime"]["active_conversations_count"] = len(self.active_conversations) stats["runtime"]["bot_last_spoke_channels"] = len(self.bot_last_spoke) stats["runtime"]["message_reply_map_size"] = len(self.message_reply_map) - stats["runtime"]["conversation_sentiment_channels"] = len(self.conversation_sentiment) - stats["runtime"]["gurt_participation_topics_count"] = len(self.gurt_participation_topics) - stats["runtime"]["gurt_message_reactions_tracked"] = len(self.gurt_message_reactions) + stats["runtime"]["conversation_sentiment_channels"] = len( + self.conversation_sentiment + ) + stats["runtime"]["gurt_participation_topics_count"] = len( + self.gurt_participation_topics + ) + stats["runtime"]["gurt_message_reactions_tracked"] = len( + self.gurt_message_reactions + ) # --- Memory (via MemoryManager) --- try: @@ -438,19 +626,37 @@ class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name stats["memory"]["personality_traits"] = personality # Interests - interests = await self.memory_manager.get_interests(limit=20, min_level=0.01) # Get top 20 + interests = await self.memory_manager.get_interests( + limit=20, min_level=0.01 + ) # Get top 20 stats["memory"]["top_interests"] = interests # Fact Counts (Requires adding methods to MemoryManager or direct query) # Example placeholder - needs implementation in MemoryManager or here - user_fact_count = await self.memory_manager._db_fetchone("SELECT COUNT(*) FROM user_facts") - general_fact_count = await self.memory_manager._db_fetchone("SELECT COUNT(*) FROM general_facts") - stats["memory"]["user_facts_count"] = user_fact_count[0] if user_fact_count else 0 - stats["memory"]["general_facts_count"] = general_fact_count[0] if general_fact_count else 0 + user_fact_count = await self.memory_manager._db_fetchone( + "SELECT COUNT(*) FROM user_facts" + ) + general_fact_count = await self.memory_manager._db_fetchone( + "SELECT COUNT(*) FROM general_facts" + ) + stats["memory"]["user_facts_count"] = ( + user_fact_count[0] if user_fact_count else 0 + ) + stats["memory"]["general_facts_count"] = ( + general_fact_count[0] if general_fact_count else 0 + ) # ChromaDB Stats (Placeholder - ChromaDB client API might offer this) - stats["memory"]["chromadb_message_collection_count"] = await asyncio.to_thread(self.memory_manager.semantic_collection.count) if self.memory_manager.semantic_collection else "N/A" - stats["memory"]["chromadb_fact_collection_count"] = await asyncio.to_thread(self.memory_manager.fact_collection.count) if self.memory_manager.fact_collection else "N/A" + stats["memory"]["chromadb_message_collection_count"] = ( + await asyncio.to_thread(self.memory_manager.semantic_collection.count) + if self.memory_manager.semantic_collection + else "N/A" + ) + stats["memory"]["chromadb_fact_collection_count"] = ( + await asyncio.to_thread(self.memory_manager.fact_collection.count) + if self.memory_manager.fact_collection + else "N/A" + ) except Exception as e: stats["memory"]["error"] = f"Failed to retrieve memory stats: {e}" @@ -463,12 +669,16 @@ class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name # Calculate average times where count > 0 for model, data in stats["api_stats"].items(): if data["count"] > 0: - data["average_time_ms"] = round((data["total_time"] / data["count"]) * 1000, 2) + data["average_time_ms"] = round( + (data["total_time"] / data["count"]) * 1000, 2 + ) else: data["average_time_ms"] = 0 for tool, data in stats["tool_stats"].items(): if data["count"] > 0: - data["average_time_ms"] = round((data["total_time"] / data["count"]) * 1000, 2) + data["average_time_ms"] = round( + (data["total_time"] / data["count"]) * 1000, 2 + ) else: data["average_time_ms"] = 0 @@ -495,12 +705,20 @@ class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name # 1. Gather Context for LLM context_summary = "Gurt is considering an autonomous action.\n" context_summary += f"Current Mood: {self.current_mood}\n" - active_goals = await self.memory_manager.get_goals(status='active', limit=3) + active_goals = await self.memory_manager.get_goals(status="active", limit=3) if active_goals: - context_summary += f"Active Goals:\n" + json.dumps(active_goals, indent=2)[:500] + "...\n" + context_summary += ( + f"Active Goals:\n" + + json.dumps(active_goals, indent=2)[:500] + + "...\n" + ) recent_actions = await self.memory_manager.get_internal_action_logs(limit=5) if recent_actions: - context_summary += f"Recent Internal Actions:\n" + json.dumps(recent_actions, indent=2)[:500] + "...\n" + context_summary += ( + f"Recent Internal Actions:\n" + + json.dumps(recent_actions, indent=2)[:500] + + "...\n" + ) traits = await self.memory_manager.get_all_personality_traits() if traits: context_summary += f"Personality Snippet: { {k: round(v, 2) for k, v in traits.items() if k in ['mischief', 'curiosity', 'chattiness']} }\n" @@ -508,14 +726,32 @@ class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name action_decision_schema = { "type": "object", "properties": { - "should_act": {"type": "boolean", "description": "Whether Gurt should perform an autonomous action now."}, - "reasoning": {"type": "string", "description": "CRITICAL: Reasoning MUST be in Gurt's voice (heavy slang, informal, chill, maybe sarcastic/mischievous). Explain *why* you're acting (or not) like Gurt would think it (e.g., 'ngl kinda bored', 'ts might be useful for my goal', 'lol watch this'). Consider current goals, mood, recent activity."}, - "action_tool_name": {"type": ["string", "null"], "description": "If acting, the name of the tool to use. Choose from available tools, you can do literally anything. Null if not acting."}, - "action_arguments": {"type": ["object", "null"], "description": "If acting, a dictionary of arguments for the chosen tool. Null if not acting."} + "should_act": { + "type": "boolean", + "description": "Whether Gurt should perform an autonomous action now.", + }, + "reasoning": { + "type": "string", + "description": "CRITICAL: Reasoning MUST be in Gurt's voice (heavy slang, informal, chill, maybe sarcastic/mischievous). Explain *why* you're acting (or not) like Gurt would think it (e.g., 'ngl kinda bored', 'ts might be useful for my goal', 'lol watch this'). Consider current goals, mood, recent activity.", + }, + "action_tool_name": { + "type": ["string", "null"], + "description": "If acting, the name of the tool to use. Choose from available tools, you can do literally anything. Null if not acting.", + }, + "action_arguments": { + "type": ["object", "null"], + "description": "If acting, a dictionary of arguments for the chosen tool. Null if not acting.", + }, }, - "required": ["should_act", "reasoning"] + "required": ["should_act", "reasoning"], } - available_tools_desc = "\n".join([f"- {name}" for name in TOOL_MAPPING.keys() if name not in ["create_new_tool"]]) + available_tools_desc = "\n".join( + [ + f"- {name}" + for name in TOOL_MAPPING.keys() + if name not in ["create_new_tool"] + ] + ) system_prompt = ( "You are Gurt, deciding whether to perform an autonomous background action. " "Consider your current mood, active goals, recent conversations/actions, and personality (VERY informal, heavy internet slang/emoji use like 'ts', 'ngl', 'fr', 'bro', 'bruh', '💀'; funny, sarcastic, deadpan, sometimes mischievous or flirty; acts like a human user not an assistant; tech/internet culture savvy). " @@ -532,11 +768,14 @@ class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name # 3. Call LLM for Decision decision_data, _ = await get_internal_ai_json_response( cog=self, - prompt_messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}], + prompt_messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], task_description="Autonomous Action Decision", response_schema_dict=action_decision_schema, model_name_override=self.default_model, - temperature=0.6 + temperature=0.6, ) # 4. Process LLM Decision @@ -546,16 +785,24 @@ class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name action_reasoning = decision_data.get("reasoning", "LLM decided to act.") if not selected_tool_name or selected_tool_name not in TOOL_MAPPING: - result_summary = f"Error: LLM chose invalid tool '{selected_tool_name}'." + result_summary = ( + f"Error: LLM chose invalid tool '{selected_tool_name}'." + ) selected_tool_name = None elif not isinstance(tool_args, dict) and tool_args is not None: - result_summary = f"Warning: LLM provided invalid args '{tool_args}'. Used {{}}." + result_summary = ( + f"Warning: LLM provided invalid args '{tool_args}'. Used {{}}." + ) tool_args = {} elif tool_args is None: tool_args = {} else: - action_reasoning = decision_data.get("reasoning", "LLM decided not to act or failed.") if decision_data else "LLM decision failed." + action_reasoning = ( + decision_data.get("reasoning", "LLM decided not to act or failed.") + if decision_data + else "LLM decision failed." + ) result_summary = f"No action taken. Reason: {action_reasoning}" except Exception as llm_e: @@ -591,7 +838,9 @@ class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name self.tool_stats[selected_tool_name]["failure"] += 1 traceback.print_exc() else: - result_summary = f"Error: Tool function for '{selected_tool_name}' not found." + result_summary = ( + f"Error: Tool function for '{selected_tool_name}' not found." + ) # 6. Log Action try: @@ -599,7 +848,7 @@ class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name tool_name=selected_tool_name or "None", arguments=tool_args if selected_tool_name else None, reasoning=action_reasoning, - result_summary=result_summary + result_summary=result_summary, ) except Exception: pass @@ -608,7 +857,7 @@ class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name "tool": selected_tool_name, "args": tool_args, "reasoning": action_reasoning, - "result": result_summary + "result": result_summary, } async def sync_commands(self): @@ -619,13 +868,18 @@ class GurtCog(commands.Cog, name="Gurt"): # Added explicit Cog name print(f"GurtCog: Synced {len(synced)} command(s)") # List the synced commands - gurt_commands = [cmd.name for cmd in self.bot.tree.get_commands() if cmd.name.startswith("gurt")] + gurt_commands = [ + cmd.name + for cmd in self.bot.tree.get_commands() + if cmd.name.startswith("gurt") + ] print(f"GurtCog: Available Gurt commands: {', '.join(gurt_commands)}") return synced, gurt_commands except Exception as e: print(f"GurtCog: Failed to sync commands: {e}") import traceback + traceback.print_exc() return [], [] diff --git a/gurt/commands.py b/gurt/commands.py index fac2020..5d56035 100644 --- a/gurt/commands.py +++ b/gurt/commands.py @@ -1,12 +1,12 @@ import discord -from discord import app_commands # Import app_commands +from discord import app_commands # Import app_commands from discord.ext import commands import random import os -import time # Import time for timestamps -import json # Import json for formatting -import datetime # Import datetime for formatting -from typing import TYPE_CHECKING, Optional, Dict, Any, List, Tuple # Add more types +import time # Import time for timestamps +import json # Import json for formatting +import datetime # Import datetime for formatting +from typing import TYPE_CHECKING, Optional, Dict, Any, List, Tuple # Add more types # Relative imports (assuming API functions are in api.py) # We need access to the cog instance for state and methods like get_ai_response @@ -16,20 +16,28 @@ try: from .config import AVAILABLE_AI_MODELS except (ImportError, AttributeError): AVAILABLE_AI_MODELS = { - "google/gemini-2.5-flash-preview-05-20": "Gemini 2.5 Flash Preview", - "google/gemini-2.5-pro-preview-05-06": "Gemini 2.5 Pro Preview", - "claude-sonnet-4@20250514": "Claude Sonnet 4", - "llama-4-maverick-17b-128e-instruct-maas": "Llama 4 Maverick Instruct", - "google/gemini-2.0-flash-001": "Gemini 2.0 Flash" - } + "google/gemini-2.5-flash-preview-05-20": "Gemini 2.5 Flash Preview", + "google/gemini-2.5-pro-preview-05-06": "Gemini 2.5 Pro Preview", + "claude-sonnet-4@20250514": "Claude Sonnet 4", + "llama-4-maverick-17b-128e-instruct-maas": "Llama 4 Maverick Instruct", + "google/gemini-2.0-flash-001": "Gemini 2.0 Flash", + } if TYPE_CHECKING: - from .cog import GurtCog # For type hinting - from .config import MOOD_OPTIONS, IGNORED_CHANNEL_IDS, update_ignored_channels_file, TENOR_API_KEY # Import for choices and ignored channels - from .emojis import EmojiManager # Import EmojiManager + from .cog import GurtCog # For type hinting + from .config import ( + MOOD_OPTIONS, + IGNORED_CHANNEL_IDS, + update_ignored_channels_file, + TENOR_API_KEY, + ) # Import for choices and ignored channels + from .emojis import EmojiManager # Import EmojiManager + # --- Helper Function for Embeds --- -def create_gurt_embed(title: str, description: str = "", color=discord.Color.blue()) -> discord.Embed: +def create_gurt_embed( + title: str, description: str = "", color=discord.Color.blue() +) -> discord.Embed: """Creates a standard Gurt-themed embed.""" embed = discord.Embed(title=title, description=description, color=color) # Placeholder icon URL, replace if Gurt has one @@ -37,31 +45,98 @@ def create_gurt_embed(title: str, description: str = "", color=discord.Color.blu embed.set_footer(text="Gurt") return embed + # --- Helper Function for Stats Embeds --- def format_stats_embeds(stats: Dict[str, Any]) -> List[discord.Embed]: """Formats the collected stats into multiple embeds.""" embeds = [] main_embed = create_gurt_embed("Gurt Internal Stats", color=discord.Color.green()) - ts_format = "" # Relative timestamp + ts_format = "" # Relative timestamp # Runtime Stats runtime = stats.get("runtime", {}) - main_embed.add_field(name="Current Mood", value=f"{runtime.get('current_mood', 'N/A')} (Changed {ts_format.format(ts=int(runtime.get('last_mood_change_timestamp', 0)))})", inline=False) - main_embed.add_field(name="Background Task", value="Running" if runtime.get('background_task_running') else "Stopped", inline=True) - main_embed.add_field(name="Needs JSON Reminder", value=str(runtime.get('needs_json_reminder', 'N/A')), inline=True) - main_embed.add_field(name="Last Evolution", value=ts_format.format(ts=int(runtime.get('last_evolution_update_timestamp', 0))), inline=True) - main_embed.add_field(name="Active Topics Channels", value=str(runtime.get('active_topics_channels', 'N/A')), inline=True) - main_embed.add_field(name="Conv History Channels", value=str(runtime.get('conversation_history_channels', 'N/A')), inline=True) - main_embed.add_field(name="Thread History Threads", value=str(runtime.get('thread_history_threads', 'N/A')), inline=True) - main_embed.add_field(name="User Relationships Pairs", value=str(runtime.get('user_relationships_pairs', 'N/A')), inline=True) - main_embed.add_field(name="Cached Summaries", value=str(runtime.get('conversation_summaries_cached', 'N/A')), inline=True) - main_embed.add_field(name="Cached Channel Topics", value=str(runtime.get('channel_topics_cached', 'N/A')), inline=True) - main_embed.add_field(name="Global Msg Cache", value=str(runtime.get('message_cache_global_count', 'N/A')), inline=True) - main_embed.add_field(name="Mention Msg Cache", value=str(runtime.get('message_cache_mentioned_count', 'N/A')), inline=True) - main_embed.add_field(name="Active Convos", value=str(runtime.get('active_conversations_count', 'N/A')), inline=True) - main_embed.add_field(name="Sentiment Channels", value=str(runtime.get('conversation_sentiment_channels', 'N/A')), inline=True) - main_embed.add_field(name="Gurt Participation Topics", value=str(runtime.get('gurt_participation_topics_count', 'N/A')), inline=True) - main_embed.add_field(name="Tracked Reactions", value=str(runtime.get('gurt_message_reactions_tracked', 'N/A')), inline=True) + main_embed.add_field( + name="Current Mood", + value=f"{runtime.get('current_mood', 'N/A')} (Changed {ts_format.format(ts=int(runtime.get('last_mood_change_timestamp', 0)))})", + inline=False, + ) + main_embed.add_field( + name="Background Task", + value="Running" if runtime.get("background_task_running") else "Stopped", + inline=True, + ) + main_embed.add_field( + name="Needs JSON Reminder", + value=str(runtime.get("needs_json_reminder", "N/A")), + inline=True, + ) + main_embed.add_field( + name="Last Evolution", + value=ts_format.format( + ts=int(runtime.get("last_evolution_update_timestamp", 0)) + ), + inline=True, + ) + main_embed.add_field( + name="Active Topics Channels", + value=str(runtime.get("active_topics_channels", "N/A")), + inline=True, + ) + main_embed.add_field( + name="Conv History Channels", + value=str(runtime.get("conversation_history_channels", "N/A")), + inline=True, + ) + main_embed.add_field( + name="Thread History Threads", + value=str(runtime.get("thread_history_threads", "N/A")), + inline=True, + ) + main_embed.add_field( + name="User Relationships Pairs", + value=str(runtime.get("user_relationships_pairs", "N/A")), + inline=True, + ) + main_embed.add_field( + name="Cached Summaries", + value=str(runtime.get("conversation_summaries_cached", "N/A")), + inline=True, + ) + main_embed.add_field( + name="Cached Channel Topics", + value=str(runtime.get("channel_topics_cached", "N/A")), + inline=True, + ) + main_embed.add_field( + name="Global Msg Cache", + value=str(runtime.get("message_cache_global_count", "N/A")), + inline=True, + ) + main_embed.add_field( + name="Mention Msg Cache", + value=str(runtime.get("message_cache_mentioned_count", "N/A")), + inline=True, + ) + main_embed.add_field( + name="Active Convos", + value=str(runtime.get("active_conversations_count", "N/A")), + inline=True, + ) + main_embed.add_field( + name="Sentiment Channels", + value=str(runtime.get("conversation_sentiment_channels", "N/A")), + inline=True, + ) + main_embed.add_field( + name="Gurt Participation Topics", + value=str(runtime.get("gurt_participation_topics_count", "N/A")), + inline=True, + ) + main_embed.add_field( + name="Tracked Reactions", + value=str(runtime.get("gurt_message_reactions_tracked", "N/A")), + inline=True, + ) embeds.append(main_embed) # Memory Stats @@ -70,20 +145,44 @@ def format_stats_embeds(stats: Dict[str, Any]) -> List[discord.Embed]: if memory.get("error"): memory_embed.description = f"⚠️ Error retrieving memory stats: {memory['error']}" else: - memory_embed.add_field(name="User Facts", value=str(memory.get('user_facts_count', 'N/A')), inline=True) - memory_embed.add_field(name="General Facts", value=str(memory.get('general_facts_count', 'N/A')), inline=True) - memory_embed.add_field(name="Chroma Messages", value=str(memory.get('chromadb_message_collection_count', 'N/A')), inline=True) - memory_embed.add_field(name="Chroma Facts", value=str(memory.get('chromadb_fact_collection_count', 'N/A')), inline=True) + memory_embed.add_field( + name="User Facts", + value=str(memory.get("user_facts_count", "N/A")), + inline=True, + ) + memory_embed.add_field( + name="General Facts", + value=str(memory.get("general_facts_count", "N/A")), + inline=True, + ) + memory_embed.add_field( + name="Chroma Messages", + value=str(memory.get("chromadb_message_collection_count", "N/A")), + inline=True, + ) + memory_embed.add_field( + name="Chroma Facts", + value=str(memory.get("chromadb_fact_collection_count", "N/A")), + inline=True, + ) personality = memory.get("personality_traits", {}) if personality: p_items = [f"`{k}`: {v}" for k, v in personality.items()] - memory_embed.add_field(name="Personality Traits", value="\n".join(p_items) if p_items else "None", inline=False) + memory_embed.add_field( + name="Personality Traits", + value="\n".join(p_items) if p_items else "None", + inline=False, + ) interests = memory.get("top_interests", []) if interests: i_items = [f"`{t}`: {l:.2f}" for t, l in interests] - memory_embed.add_field(name="Top Interests", value="\n".join(i_items) if i_items else "None", inline=False) + memory_embed.add_field( + name="Top Interests", + value="\n".join(i_items) if i_items else "None", + inline=False, + ) embeds.append(memory_embed) # API Stats @@ -91,12 +190,14 @@ def format_stats_embeds(stats: Dict[str, Any]) -> List[discord.Embed]: if api_stats: api_embed = create_gurt_embed("Gurt API Stats", color=discord.Color.red()) for model, data in api_stats.items(): - avg_time = data.get('average_time_ms', 0) - value = (f"✅ Success: {data.get('success', 0)}\n" - f"❌ Failure: {data.get('failure', 0)}\n" - f"🔁 Retries: {data.get('retries', 0)}\n" - f"⏱️ Avg Time: {avg_time} ms\n" - f"📊 Count: {data.get('count', 0)}") + avg_time = data.get("average_time_ms", 0) + value = ( + f"✅ Success: {data.get('success', 0)}\n" + f"❌ Failure: {data.get('failure', 0)}\n" + f"🔁 Retries: {data.get('retries', 0)}\n" + f"⏱️ Avg Time: {avg_time} ms\n" + f"📊 Count: {data.get('count', 0)}" + ) api_embed.add_field(name=f"Model: `{model}`", value=value, inline=True) embeds.append(api_embed) @@ -105,62 +206,120 @@ def format_stats_embeds(stats: Dict[str, Any]) -> List[discord.Embed]: if tool_stats: tool_embed = create_gurt_embed("Gurt Tool Stats", color=discord.Color.purple()) for tool, data in tool_stats.items(): - avg_time = data.get('average_time_ms', 0) - value = (f"✅ Success: {data.get('success', 0)}\n" - f"❌ Failure: {data.get('failure', 0)}\n" - f"⏱️ Avg Time: {avg_time} ms\n" - f"📊 Count: {data.get('count', 0)}") + avg_time = data.get("average_time_ms", 0) + value = ( + f"✅ Success: {data.get('success', 0)}\n" + f"❌ Failure: {data.get('failure', 0)}\n" + f"⏱️ Avg Time: {avg_time} ms\n" + f"📊 Count: {data.get('count', 0)}" + ) tool_embed.add_field(name=f"Tool: `{tool}`", value=value, inline=True) embeds.append(tool_embed) # Config Stats (Less critical, maybe separate embed if needed) - config_embed = create_gurt_embed("Gurt Config Overview", color=discord.Color.greyple()) + config_embed = create_gurt_embed( + "Gurt Config Overview", color=discord.Color.greyple() + ) config = stats.get("config", {}) - config_embed.add_field(name="Default Model", value=f"`{config.get('default_model', 'N/A')}`", inline=True) - config_embed.add_field(name="Fallback Model", value=f"`{config.get('fallback_model', 'N/A')}`", inline=True) - config_embed.add_field(name="Semantic Model", value=f"`{config.get('semantic_model_name', 'N/A')}`", inline=True) - config_embed.add_field(name="Max User Facts", value=str(config.get('max_user_facts', 'N/A')), inline=True) - config_embed.add_field(name="Max General Facts", value=str(config.get('max_general_facts', 'N/A')), inline=True) - config_embed.add_field(name="Context Window", value=str(config.get('context_window_size', 'N/A')), inline=True) - config_embed.add_field(name="API Key Set", value=str(config.get('api_key_set', 'N/A')), inline=True) - config_embed.add_field(name="Tavily Key Set", value=str(config.get('tavily_api_key_set', 'N/A')), inline=True) - config_embed.add_field(name="Piston URL Set", value=str(config.get('piston_api_url_set', 'N/A')), inline=True) - config_embed.add_field(name="Tenor API Key Set", value=str(config.get('tenor_api_key_set', 'N/A')), inline=True) # Added Tenor API Key + config_embed.add_field( + name="Default Model", + value=f"`{config.get('default_model', 'N/A')}`", + inline=True, + ) + config_embed.add_field( + name="Fallback Model", + value=f"`{config.get('fallback_model', 'N/A')}`", + inline=True, + ) + config_embed.add_field( + name="Semantic Model", + value=f"`{config.get('semantic_model_name', 'N/A')}`", + inline=True, + ) + config_embed.add_field( + name="Max User Facts", + value=str(config.get("max_user_facts", "N/A")), + inline=True, + ) + config_embed.add_field( + name="Max General Facts", + value=str(config.get("max_general_facts", "N/A")), + inline=True, + ) + config_embed.add_field( + name="Context Window", + value=str(config.get("context_window_size", "N/A")), + inline=True, + ) + config_embed.add_field( + name="API Key Set", value=str(config.get("api_key_set", "N/A")), inline=True + ) + config_embed.add_field( + name="Tavily Key Set", + value=str(config.get("tavily_api_key_set", "N/A")), + inline=True, + ) + config_embed.add_field( + name="Piston URL Set", + value=str(config.get("piston_api_url_set", "N/A")), + inline=True, + ) + config_embed.add_field( + name="Tenor API Key Set", + value=str(config.get("tenor_api_key_set", "N/A")), + inline=True, + ) # Added Tenor API Key embeds.append(config_embed) - # Limit to 10 embeds max for Discord API return embeds[:10] # --- Command Setup Function --- # This function will be called from GurtCog's setup method -def setup_commands(cog: 'GurtCog'): +def setup_commands(cog: "GurtCog"): """Adds Gurt-specific commands to the cog.""" # Create a list to store command functions for proper registration command_functions = [] # --- Gurt Mood Command --- - @cog.bot.tree.command(name="gurtmood", description="Check or set Gurt's current mood.") - @app_commands.describe(mood="Optional: Set Gurt's mood to one of the available options.") - @app_commands.choices(mood=[ - app_commands.Choice(name=m, value=m) for m in cog.MOOD_OPTIONS # Use cog's MOOD_OPTIONS - ]) - async def gurtmood(interaction: discord.Interaction, mood: Optional[app_commands.Choice[str]] = None): + @cog.bot.tree.command( + name="gurtmood", description="Check or set Gurt's current mood." + ) + @app_commands.describe( + mood="Optional: Set Gurt's mood to one of the available options." + ) + @app_commands.choices( + mood=[ + app_commands.Choice(name=m, value=m) + for m in cog.MOOD_OPTIONS # Use cog's MOOD_OPTIONS + ] + ) + async def gurtmood( + interaction: discord.Interaction, + mood: Optional[app_commands.Choice[str]] = None, + ): """Handles the /gurtmood command.""" # Check if user is the bot owner for mood setting if mood and interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Only the bot owner can change Gurt's mood.", ephemeral=True) + await interaction.response.send_message( + "⛔ Only the bot owner can change Gurt's mood.", ephemeral=True + ) return if mood: cog.current_mood = mood.value cog.last_mood_change = time.time() - await interaction.response.send_message(f"Gurt's mood set to: {mood.value}", ephemeral=True) + await interaction.response.send_message( + f"Gurt's mood set to: {mood.value}", ephemeral=True + ) else: time_since_change = time.time() - cog.last_mood_change - await interaction.response.send_message(f"Gurt's current mood is: {cog.current_mood} (Set {int(time_since_change // 60)} minutes ago)", ephemeral=True) + await interaction.response.send_message( + f"Gurt's current mood is: {cog.current_mood} (Set {int(time_since_change // 60)} minutes ago)", + ephemeral=True, + ) command_functions.append(gurtmood) @@ -170,65 +329,99 @@ def setup_commands(cog: 'GurtCog'): action="Choose an action: add_user, add_general, get_user, get_general", user="The user for user-specific actions (mention or ID).", fact="The fact to add (for add actions).", - query="A keyword to search for (for get_general)." + query="A keyword to search for (for get_general).", ) - @app_commands.choices(action=[ - app_commands.Choice(name="Add User Fact", value="add_user"), - app_commands.Choice(name="Add General Fact", value="add_general"), - app_commands.Choice(name="Get User Facts", value="get_user"), - app_commands.Choice(name="Get General Facts", value="get_general"), - ]) - async def gurtmemory(interaction: discord.Interaction, action: app_commands.Choice[str], user: Optional[discord.User] = None, fact: Optional[str] = None, query: Optional[str] = None): + @app_commands.choices( + action=[ + app_commands.Choice(name="Add User Fact", value="add_user"), + app_commands.Choice(name="Add General Fact", value="add_general"), + app_commands.Choice(name="Get User Facts", value="get_user"), + app_commands.Choice(name="Get General Facts", value="get_general"), + ] + ) + async def gurtmemory( + interaction: discord.Interaction, + action: app_commands.Choice[str], + user: Optional[discord.User] = None, + fact: Optional[str] = None, + query: Optional[str] = None, + ): """Handles the /gurtmemory command.""" - await interaction.response.defer(ephemeral=True) # Defer for potentially slow DB operations + await interaction.response.defer( + ephemeral=True + ) # Defer for potentially slow DB operations target_user_id = str(user.id) if user else None action_value = action.value # Check if user is the bot owner for modification actions - if (action_value in ["add_user", "add_general"]) and interaction.user.id != cog.bot.owner_id: - await interaction.followup.send("⛔ Only the bot owner can add facts to Gurt's memory.", ephemeral=True) + if ( + action_value in ["add_user", "add_general"] + ) and interaction.user.id != cog.bot.owner_id: + await interaction.followup.send( + "⛔ Only the bot owner can add facts to Gurt's memory.", ephemeral=True + ) return if action_value == "add_user": if not target_user_id or not fact: - await interaction.followup.send("Please provide both a user and a fact to add.", ephemeral=True) + await interaction.followup.send( + "Please provide both a user and a fact to add.", ephemeral=True + ) return result = await cog.memory_manager.add_user_fact(target_user_id, fact) - await interaction.followup.send(f"Add User Fact Result: `{json.dumps(result)}`", ephemeral=True) + await interaction.followup.send( + f"Add User Fact Result: `{json.dumps(result)}`", ephemeral=True + ) elif action_value == "add_general": if not fact: - await interaction.followup.send("Please provide a fact to add.", ephemeral=True) + await interaction.followup.send( + "Please provide a fact to add.", ephemeral=True + ) return result = await cog.memory_manager.add_general_fact(fact) - await interaction.followup.send(f"Add General Fact Result: `{json.dumps(result)}`", ephemeral=True) + await interaction.followup.send( + f"Add General Fact Result: `{json.dumps(result)}`", ephemeral=True + ) elif action_value == "get_user": if not target_user_id: - await interaction.followup.send("Please provide a user to get facts for.", ephemeral=True) + await interaction.followup.send( + "Please provide a user to get facts for.", ephemeral=True + ) return - facts = await cog.memory_manager.get_user_facts(target_user_id) # Get newest by default + facts = await cog.memory_manager.get_user_facts( + target_user_id + ) # Get newest by default if facts: facts_str = "\n- ".join(facts) - await interaction.followup.send(f"**Facts for {user.display_name}:**\n- {facts_str}", ephemeral=True) + await interaction.followup.send( + f"**Facts for {user.display_name}:**\n- {facts_str}", ephemeral=True + ) else: - await interaction.followup.send(f"No facts found for {user.display_name}.", ephemeral=True) + await interaction.followup.send( + f"No facts found for {user.display_name}.", ephemeral=True + ) elif action_value == "get_general": - facts = await cog.memory_manager.get_general_facts(query=query, limit=10) # Get newest/filtered + facts = await cog.memory_manager.get_general_facts( + query=query, limit=10 + ) # Get newest/filtered if facts: facts_str = "\n- ".join(facts) # Conditionally construct the title to avoid nested f-string issues if query: - title = f"**General Facts matching \"{query}\":**" + title = f'**General Facts matching "{query}":**' else: title = "**General Facts:**" - await interaction.followup.send(f"{title}\n- {facts_str}", ephemeral=True) + await interaction.followup.send( + f"{title}\n- {facts_str}", ephemeral=True + ) else: # Conditionally construct the message for the same reason if query: - message = f"No general facts found matching \"{query}\"." + message = f'No general facts found matching "{query}".' else: message = "No general facts found." await interaction.followup.send(message, ephemeral=True) @@ -239,11 +432,15 @@ def setup_commands(cog: 'GurtCog'): command_functions.append(gurtmemory) # --- Gurt Stats Command --- - @cog.bot.tree.command(name="gurtstats", description="Display Gurt's internal statistics. (Owner only)") + @cog.bot.tree.command( + name="gurtstats", description="Display Gurt's internal statistics. (Owner only)" + ) async def gurtstats(interaction: discord.Interaction): """Handles the /gurtstats command.""" - await interaction.response.defer(ephemeral=True) # Defer as stats collection might take time + await interaction.response.defer( + ephemeral=True + ) # Defer as stats collection might take time try: stats_data = await cog.get_gurt_stats() embeds = format_stats_embeds(stats_data) @@ -251,18 +448,25 @@ def setup_commands(cog: 'GurtCog'): except Exception as e: print(f"Error in /gurtstats command: {e}") import traceback + traceback.print_exc() - await interaction.followup.send("An error occurred while fetching Gurt's stats.", ephemeral=True) + await interaction.followup.send( + "An error occurred while fetching Gurt's stats.", ephemeral=True + ) command_functions.append(gurtstats) # --- Sync Gurt Commands (Owner Only) --- - @cog.bot.tree.command(name="gurtsync", description="Sync Gurt commands with Discord (Owner only)") + @cog.bot.tree.command( + name="gurtsync", description="Sync Gurt commands with Discord (Owner only)" + ) async def gurtsync(interaction: discord.Interaction): """Handles the /gurtsync command to force sync commands.""" # Check if user is the bot owner if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Only the bot owner can sync commands.", ephemeral=True) + await interaction.response.send_message( + "⛔ Only the bot owner can sync commands.", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) @@ -272,32 +476,46 @@ def setup_commands(cog: 'GurtCog'): # Get list of commands after sync commands_after = [] - for cmd_obj in cog.bot.tree.get_commands(): # Iterate over Command objects + for cmd_obj in cog.bot.tree.get_commands(): # Iterate over Command objects if cmd_obj.name.startswith("gurt"): commands_after.append(cmd_obj.name) - - await interaction.followup.send(f"✅ Successfully synced {len(synced)} commands!\nGurt commands: {', '.join(commands_after)}", ephemeral=True) + await interaction.followup.send( + f"✅ Successfully synced {len(synced)} commands!\nGurt commands: {', '.join(commands_after)}", + ephemeral=True, + ) except Exception as e: print(f"Error in /gurtsync command: {e}") import traceback + traceback.print_exc() - await interaction.followup.send(f"❌ Error syncing commands: {str(e)}", ephemeral=True) + await interaction.followup.send( + f"❌ Error syncing commands: {str(e)}", ephemeral=True + ) command_functions.append(gurtsync) # --- Gurt Forget Command --- - @cog.bot.tree.command(name="gurtforget", description="Make Gurt forget a specific fact.") + @cog.bot.tree.command( + name="gurtforget", description="Make Gurt forget a specific fact." + ) @app_commands.describe( scope="Choose the scope: user (for facts about a specific user) or general.", fact="The exact fact text Gurt should forget.", - user="The user to forget a fact about (only if scope is 'user')." + user="The user to forget a fact about (only if scope is 'user').", ) - @app_commands.choices(scope=[ - app_commands.Choice(name="User Fact", value="user"), - app_commands.Choice(name="General Fact", value="general"), - ]) - async def gurtforget(interaction: discord.Interaction, scope: app_commands.Choice[str], fact: str, user: Optional[discord.User] = None): + @app_commands.choices( + scope=[ + app_commands.Choice(name="User Fact", value="user"), + app_commands.Choice(name="General Fact", value="general"), + ] + ) + async def gurtforget( + interaction: discord.Interaction, + scope: app_commands.Choice[str], + fact: str, + user: Optional[discord.User] = None, + ): """Handles the /gurtforget command.""" await interaction.response.defer(ephemeral=True) @@ -307,55 +525,93 @@ def setup_commands(cog: 'GurtCog'): # Permissions Check: Allow users to forget facts about themselves, owner can forget anything. can_forget = False if scope_value == "user": - if target_user_id == str(interaction.user.id): # User forgetting their own fact + if target_user_id == str( + interaction.user.id + ): # User forgetting their own fact can_forget = True - elif interaction.user.id == cog.bot.owner_id: # Owner forgetting any user fact + elif ( + interaction.user.id == cog.bot.owner_id + ): # Owner forgetting any user fact can_forget = True elif not target_user_id: - await interaction.followup.send("❌ Please specify a user when forgetting a user fact.", ephemeral=True) - return + await interaction.followup.send( + "❌ Please specify a user when forgetting a user fact.", + ephemeral=True, + ) + return elif scope_value == "general": - if interaction.user.id == cog.bot.owner_id: # Only owner can forget general facts + if ( + interaction.user.id == cog.bot.owner_id + ): # Only owner can forget general facts can_forget = True if not can_forget: - await interaction.followup.send("⛔ You don't have permission to forget this fact.", ephemeral=True) + await interaction.followup.send( + "⛔ You don't have permission to forget this fact.", ephemeral=True + ) return if not fact: - await interaction.followup.send("❌ Please provide the exact fact text to forget.", ephemeral=True) + await interaction.followup.send( + "❌ Please provide the exact fact text to forget.", ephemeral=True + ) return result = None if scope_value == "user": - if not target_user_id: # Should be caught above, but double-check - await interaction.followup.send("❌ User is required for scope 'user'.", ephemeral=True) - return + if not target_user_id: # Should be caught above, but double-check + await interaction.followup.send( + "❌ User is required for scope 'user'.", ephemeral=True + ) + return result = await cog.memory_manager.delete_user_fact(target_user_id, fact) if result.get("status") == "deleted": - await interaction.followup.send(f"✅ Okay, I've forgotten the fact '{fact}' about {user.display_name}.", ephemeral=True) + await interaction.followup.send( + f"✅ Okay, I've forgotten the fact '{fact}' about {user.display_name}.", + ephemeral=True, + ) elif result.get("status") == "not_found": - await interaction.followup.send(f"❓ I couldn't find that exact fact ('{fact}') stored for {user.display_name}.", ephemeral=True) + await interaction.followup.send( + f"❓ I couldn't find that exact fact ('{fact}') stored for {user.display_name}.", + ephemeral=True, + ) else: - await interaction.followup.send(f"⚠️ Error forgetting user fact: {result.get('error', 'Unknown error')}", ephemeral=True) + await interaction.followup.send( + f"⚠️ Error forgetting user fact: {result.get('error', 'Unknown error')}", + ephemeral=True, + ) elif scope_value == "general": result = await cog.memory_manager.delete_general_fact(fact) if result.get("status") == "deleted": - await interaction.followup.send(f"✅ Okay, I've forgotten the general fact: '{fact}'.", ephemeral=True) + await interaction.followup.send( + f"✅ Okay, I've forgotten the general fact: '{fact}'.", + ephemeral=True, + ) elif result.get("status") == "not_found": - await interaction.followup.send(f"❓ I couldn't find that exact general fact: '{fact}'.", ephemeral=True) + await interaction.followup.send( + f"❓ I couldn't find that exact general fact: '{fact}'.", + ephemeral=True, + ) else: - await interaction.followup.send(f"⚠️ Error forgetting general fact: {result.get('error', 'Unknown error')}", ephemeral=True) + await interaction.followup.send( + f"⚠️ Error forgetting general fact: {result.get('error', 'Unknown error')}", + ephemeral=True, + ) command_functions.append(gurtforget) # --- Gurt Force Autonomous Action Command (Owner Only) --- - @cog.bot.tree.command(name="gurtforceauto", description="Force Gurt to execute an autonomous action immediately. (Owner only)") + @cog.bot.tree.command( + name="gurtforceauto", + description="Force Gurt to execute an autonomous action immediately. (Owner only)", + ) async def gurtforceauto(interaction: discord.Interaction): """Handles the /gurtforceauto command.""" if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Only the bot owner can force autonomous actions.", ephemeral=True) + await interaction.response.send_message( + "⛔ Only the bot owner can force autonomous actions.", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) try: @@ -370,44 +626,71 @@ def setup_commands(cog: 'GurtCog'): await interaction.followup.send(summary, ephemeral=True) except Exception as e: import traceback - traceback.print_exc() - await interaction.followup.send(f"❌ Error forcing autonomous action: {e}", ephemeral=True) - command_functions.append(gurtforceauto) # Add gurtforceauto to the list + traceback.print_exc() + await interaction.followup.send( + f"❌ Error forcing autonomous action: {e}", ephemeral=True + ) + + command_functions.append(gurtforceauto) # Add gurtforceauto to the list # --- Gurt Clear Action History Command (Owner Only) --- - @cog.bot.tree.command(name="gurtclearhistory", description="Clear Gurt's internal autonomous action history. (Owner only)") + @cog.bot.tree.command( + name="gurtclearhistory", + description="Clear Gurt's internal autonomous action history. (Owner only)", + ) async def gurtclearhistory(interaction: discord.Interaction): """Handles the /gurtclearhistory command.""" if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Only the bot owner can clear the action history.", ephemeral=True) + await interaction.response.send_message( + "⛔ Only the bot owner can clear the action history.", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) try: result = await cog.memory_manager.clear_internal_action_logs() if "error" in result: - await interaction.followup.send(f"⚠️ Error clearing action history: {result['error']}", ephemeral=True) + await interaction.followup.send( + f"⚠️ Error clearing action history: {result['error']}", + ephemeral=True, + ) else: - await interaction.followup.send("✅ Gurt's autonomous action history has been cleared.", ephemeral=True) + await interaction.followup.send( + "✅ Gurt's autonomous action history has been cleared.", + ephemeral=True, + ) except Exception as e: import traceback - traceback.print_exc() - await interaction.followup.send(f"❌ An unexpected error occurred while clearing history: {e}", ephemeral=True) - command_functions.append(gurtclearhistory) # Add the new command + traceback.print_exc() + await interaction.followup.send( + f"❌ An unexpected error occurred while clearing history: {e}", + ephemeral=True, + ) + + command_functions.append(gurtclearhistory) # Add the new command # --- Gurt Goal Command Group --- - gurtgoal_group = app_commands.Group(name="gurtgoal", description="Manage Gurt's long-term goals (Owner only)") + gurtgoal_group = app_commands.Group( + name="gurtgoal", description="Manage Gurt's long-term goals (Owner only)" + ) @gurtgoal_group.command(name="add", description="Add a new goal for Gurt.") @app_commands.describe( description="The description of the goal.", priority="Priority (1=highest, 10=lowest, default=5).", - details_json="Optional JSON string for goal details (e.g., sub-tasks)." + details_json="Optional JSON string for goal details (e.g., sub-tasks).", ) - async def gurtgoal_add(interaction: discord.Interaction, description: str, priority: Optional[int] = 5, details_json: Optional[str] = None): + async def gurtgoal_add( + interaction: discord.Interaction, + description: str, + priority: Optional[int] = 5, + details_json: Optional[str] = None, + ): if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Only the bot owner can add goals.", ephemeral=True) + await interaction.response.send_message( + "⛔ Only the bot owner can add goals.", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) details = None @@ -415,7 +698,9 @@ def setup_commands(cog: 'GurtCog'): try: details = json.loads(details_json) except json.JSONDecodeError: - await interaction.followup.send("❌ Invalid JSON format for details.", ephemeral=True) + await interaction.followup.send( + "❌ Invalid JSON format for details.", ephemeral=True + ) return # Capture context from interaction @@ -429,64 +714,108 @@ def setup_commands(cog: 'GurtCog'): details, guild_id=guild_id, channel_id=channel_id, - user_id=user_id + user_id=user_id, ) if result.get("status") == "added": - await interaction.followup.send(f"✅ Goal added (ID: {result.get('goal_id')}): '{description}'", ephemeral=True) + await interaction.followup.send( + f"✅ Goal added (ID: {result.get('goal_id')}): '{description}'", + ephemeral=True, + ) elif result.get("status") == "duplicate": - await interaction.followup.send(f"⚠️ Goal '{description}' already exists (ID: {result.get('goal_id')}).", ephemeral=True) + await interaction.followup.send( + f"⚠️ Goal '{description}' already exists (ID: {result.get('goal_id')}).", + ephemeral=True, + ) else: - await interaction.followup.send(f"⚠️ Error adding goal: {result.get('error', 'Unknown error')}", ephemeral=True) + await interaction.followup.send( + f"⚠️ Error adding goal: {result.get('error', 'Unknown error')}", + ephemeral=True, + ) @gurtgoal_group.command(name="list", description="List Gurt's current goals.") - @app_commands.describe(status="Filter goals by status (e.g., pending, active).", limit="Maximum goals to show (default 10).") - @app_commands.choices(status=[ - app_commands.Choice(name="Pending", value="pending"), - app_commands.Choice(name="Active", value="active"), - app_commands.Choice(name="Completed", value="completed"), - app_commands.Choice(name="Failed", value="failed"), - ]) - async def gurtgoal_list(interaction: discord.Interaction, status: Optional[app_commands.Choice[str]] = None, limit: Optional[int] = 10): + @app_commands.describe( + status="Filter goals by status (e.g., pending, active).", + limit="Maximum goals to show (default 10).", + ) + @app_commands.choices( + status=[ + app_commands.Choice(name="Pending", value="pending"), + app_commands.Choice(name="Active", value="active"), + app_commands.Choice(name="Completed", value="completed"), + app_commands.Choice(name="Failed", value="failed"), + ] + ) + async def gurtgoal_list( + interaction: discord.Interaction, + status: Optional[app_commands.Choice[str]] = None, + limit: Optional[int] = 10, + ): if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Only the bot owner can list goals.", ephemeral=True) + await interaction.response.send_message( + "⛔ Only the bot owner can list goals.", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) status_value = status.value if status else None - limit_value = max(1, min(limit or 10, 25)) # Clamp limit - goals = await cog.memory_manager.get_goals(status=status_value, limit=limit_value) + limit_value = max(1, min(limit or 10, 25)) # Clamp limit + goals = await cog.memory_manager.get_goals( + status=status_value, limit=limit_value + ) if not goals: - await interaction.followup.send(f"No goals found matching the criteria (Status: {status_value or 'any'}).", ephemeral=True) + await interaction.followup.send( + f"No goals found matching the criteria (Status: {status_value or 'any'}).", + ephemeral=True, + ) return - embed = create_gurt_embed(f"Gurt Goals (Status: {status_value or 'All'})", color=discord.Color.purple()) + embed = create_gurt_embed( + f"Gurt Goals (Status: {status_value or 'All'})", + color=discord.Color.purple(), + ) for goal in goals: - details_str = f"\n Details: `{json.dumps(goal.get('details'))}`" if goal.get('details') else "" - created_ts = int(goal.get('created_timestamp', 0)) - updated_ts = int(goal.get('last_updated', 0)) + details_str = ( + f"\n Details: `{json.dumps(goal.get('details'))}`" + if goal.get("details") + else "" + ) + created_ts = int(goal.get("created_timestamp", 0)) + updated_ts = int(goal.get("last_updated", 0)) embed.add_field( name=f"ID: {goal.get('goal_id')} | P: {goal.get('priority', '?')} | Status: {goal.get('status', '?')}", value=f"> {goal.get('description', 'N/A')}{details_str}\n" - f"> Created: | Updated: ", - inline=False + f"> Created: | Updated: ", + inline=False, ) await interaction.followup.send(embed=embed, ephemeral=True) - @gurtgoal_group.command(name="update", description="Update a goal's status, priority, or details.") + @gurtgoal_group.command( + name="update", description="Update a goal's status, priority, or details." + ) @app_commands.describe( goal_id="The ID of the goal to update.", status="New status for the goal.", priority="New priority (1=highest, 10=lowest).", - details_json="Optional: New JSON string for goal details (replaces existing)." + details_json="Optional: New JSON string for goal details (replaces existing).", ) - @app_commands.choices(status=[ - app_commands.Choice(name="Pending", value="pending"), - app_commands.Choice(name="Active", value="active"), - app_commands.Choice(name="Completed", value="completed"), - app_commands.Choice(name="Failed", value="failed"), - ]) - async def gurtgoal_update(interaction: discord.Interaction, goal_id: int, status: Optional[app_commands.Choice[str]] = None, priority: Optional[int] = None, details_json: Optional[str] = None): + @app_commands.choices( + status=[ + app_commands.Choice(name="Pending", value="pending"), + app_commands.Choice(name="Active", value="active"), + app_commands.Choice(name="Completed", value="completed"), + app_commands.Choice(name="Failed", value="failed"), + ] + ) + async def gurtgoal_update( + interaction: discord.Interaction, + goal_id: int, + status: Optional[app_commands.Choice[str]] = None, + priority: Optional[int] = None, + details_json: Optional[str] = None, + ): if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Only the bot owner can update goals.", ephemeral=True) + await interaction.response.send_message( + "⛔ Only the bot owner can update goals.", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) @@ -496,92 +825,160 @@ def setup_commands(cog: 'GurtCog'): try: details = json.loads(details_json) except json.JSONDecodeError: - await interaction.followup.send("❌ Invalid JSON format for details.", ephemeral=True) + await interaction.followup.send( + "❌ Invalid JSON format for details.", ephemeral=True + ) return if not any([status_value, priority is not None, details is not None]): - await interaction.followup.send("❌ You must provide at least one field to update (status, priority, or details_json).", ephemeral=True) - return + await interaction.followup.send( + "❌ You must provide at least one field to update (status, priority, or details_json).", + ephemeral=True, + ) + return - result = await cog.memory_manager.update_goal(goal_id, status=status_value, priority=priority, details=details) + result = await cog.memory_manager.update_goal( + goal_id, status=status_value, priority=priority, details=details + ) if result.get("status") == "updated": - await interaction.followup.send(f"✅ Goal ID {goal_id} updated.", ephemeral=True) + await interaction.followup.send( + f"✅ Goal ID {goal_id} updated.", ephemeral=True + ) elif result.get("status") == "not_found": - await interaction.followup.send(f"❓ Goal ID {goal_id} not found.", ephemeral=True) + await interaction.followup.send( + f"❓ Goal ID {goal_id} not found.", ephemeral=True + ) else: - await interaction.followup.send(f"⚠️ Error updating goal: {result.get('error', 'Unknown error')}", ephemeral=True) + await interaction.followup.send( + f"⚠️ Error updating goal: {result.get('error', 'Unknown error')}", + ephemeral=True, + ) @gurtgoal_group.command(name="delete", description="Delete a goal.") @app_commands.describe(goal_id="The ID of the goal to delete.") async def gurtgoal_delete(interaction: discord.Interaction, goal_id: int): if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Only the bot owner can delete goals.", ephemeral=True) + await interaction.response.send_message( + "⛔ Only the bot owner can delete goals.", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) result = await cog.memory_manager.delete_goal(goal_id) if result.get("status") == "deleted": - await interaction.followup.send(f"✅ Goal ID {goal_id} deleted.", ephemeral=True) + await interaction.followup.send( + f"✅ Goal ID {goal_id} deleted.", ephemeral=True + ) elif result.get("status") == "not_found": - await interaction.followup.send(f"❓ Goal ID {goal_id} not found.", ephemeral=True) + await interaction.followup.send( + f"❓ Goal ID {goal_id} not found.", ephemeral=True + ) else: - await interaction.followup.send(f"⚠️ Error deleting goal: {result.get('error', 'Unknown error')}", ephemeral=True) + await interaction.followup.send( + f"⚠️ Error deleting goal: {result.get('error', 'Unknown error')}", + ephemeral=True, + ) # Add the command group to the bot's tree cog.bot.tree.add_command(gurtgoal_group) # Add group command functions to the list for tracking (optional, but good practice) - command_functions.extend([gurtgoal_add, gurtgoal_list, gurtgoal_update, gurtgoal_delete]) + command_functions.extend( + [gurtgoal_add, gurtgoal_list, gurtgoal_update, gurtgoal_delete] + ) # --- Gurt Ignore Command Group (Owner Only) --- - gurtignore_group = app_commands.Group(name="gurtignore", description="Manage channels Gurt should ignore. (Owner only)") + gurtignore_group = app_commands.Group( + name="gurtignore", + description="Manage channels Gurt should ignore. (Owner only)", + ) - @gurtignore_group.command(name="add", description="Add a channel to Gurt's ignore list.") + @gurtignore_group.command( + name="add", description="Add a channel to Gurt's ignore list." + ) @app_commands.describe(channel="The channel or thread to ignore.") - async def gurtignore_add(interaction: discord.Interaction, channel: discord.abc.GuildChannel): # Use GuildChannel to accept TextChannel, Thread, etc. + async def gurtignore_add( + interaction: discord.Interaction, channel: discord.abc.GuildChannel + ): # Use GuildChannel to accept TextChannel, Thread, etc. if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Only the bot owner can modify the ignore list.", ephemeral=True) + await interaction.response.send_message( + "⛔ Only the bot owner can modify the ignore list.", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) - current_ignored_ids = set(cog.IGNORED_CHANNEL_IDS) # Use cog's direct reference + current_ignored_ids = set(cog.IGNORED_CHANNEL_IDS) # Use cog's direct reference if channel.id in current_ignored_ids: - await interaction.followup.send(f"⚠️ Channel {channel.mention} is already in the ignore list.", ephemeral=True) + await interaction.followup.send( + f"⚠️ Channel {channel.mention} is already in the ignore list.", + ephemeral=True, + ) return current_ignored_ids.add(channel.id) - if cog.update_ignored_channels_file(list(current_ignored_ids)): # Use cog's direct reference, ensure it's a list - await interaction.followup.send(f"✅ Channel {channel.mention} added to the ignore list.", ephemeral=True) + if cog.update_ignored_channels_file( + list(current_ignored_ids) + ): # Use cog's direct reference, ensure it's a list + await interaction.followup.send( + f"✅ Channel {channel.mention} added to the ignore list.", + ephemeral=True, + ) else: - await interaction.followup.send(f"❌ Failed to update the ignore list file. Check bot logs.", ephemeral=True) + await interaction.followup.send( + f"❌ Failed to update the ignore list file. Check bot logs.", + ephemeral=True, + ) - @gurtignore_group.command(name="remove", description="Remove a channel from Gurt's ignore list.") + @gurtignore_group.command( + name="remove", description="Remove a channel from Gurt's ignore list." + ) @app_commands.describe(channel="The channel or thread to stop ignoring.") - async def gurtignore_remove(interaction: discord.Interaction, channel: discord.abc.GuildChannel): + async def gurtignore_remove( + interaction: discord.Interaction, channel: discord.abc.GuildChannel + ): if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Only the bot owner can modify the ignore list.", ephemeral=True) + await interaction.response.send_message( + "⛔ Only the bot owner can modify the ignore list.", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) - current_ignored_ids = set(cog.IGNORED_CHANNEL_IDS) # Use cog's direct reference + current_ignored_ids = set(cog.IGNORED_CHANNEL_IDS) # Use cog's direct reference if channel.id not in current_ignored_ids: - await interaction.followup.send(f"⚠️ Channel {channel.mention} is not in the ignore list.", ephemeral=True) + await interaction.followup.send( + f"⚠️ Channel {channel.mention} is not in the ignore list.", + ephemeral=True, + ) return current_ignored_ids.remove(channel.id) - if cog.update_ignored_channels_file(list(current_ignored_ids)): # Use cog's direct reference, ensure it's a list - await interaction.followup.send(f"✅ Channel {channel.mention} removed from the ignore list.", ephemeral=True) + if cog.update_ignored_channels_file( + list(current_ignored_ids) + ): # Use cog's direct reference, ensure it's a list + await interaction.followup.send( + f"✅ Channel {channel.mention} removed from the ignore list.", + ephemeral=True, + ) else: - await interaction.followup.send(f"❌ Failed to update the ignore list file. Check bot logs.", ephemeral=True) + await interaction.followup.send( + f"❌ Failed to update the ignore list file. Check bot logs.", + ephemeral=True, + ) - @gurtignore_group.command(name="list", description="List all channels Gurt is currently ignoring.") + @gurtignore_group.command( + name="list", description="List all channels Gurt is currently ignoring." + ) async def gurtignore_list(interaction: discord.Interaction): if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Only the bot owner can view the ignore list.", ephemeral=True) + await interaction.response.send_message( + "⛔ Only the bot owner can view the ignore list.", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) - current_ignored_ids = cog.IGNORED_CHANNEL_IDS # Use cog's direct reference + current_ignored_ids = cog.IGNORED_CHANNEL_IDS # Use cog's direct reference if not current_ignored_ids: - await interaction.followup.send("Gurt is not currently ignoring any channels.", ephemeral=True) + await interaction.followup.send( + "Gurt is not currently ignoring any channels.", ephemeral=True + ) return embed = create_gurt_embed("Ignored Channels", color=discord.Color.orange()) @@ -592,7 +989,7 @@ def setup_commands(cog: 'GurtCog'): description_lines.append(f"- {ch.mention} (`{channel_id}`)") else: description_lines.append(f"- Unknown Channel (`{channel_id}`)") - + embed.description = "\n".join(description_lines) await interaction.followup.send(embed=embed, ephemeral=True) @@ -600,134 +997,229 @@ def setup_commands(cog: 'GurtCog'): command_functions.extend([gurtignore_add, gurtignore_remove, gurtignore_list]) # --- Gurt Emoji Command Group (Owner Only) --- - gurtemoji_group = app_commands.Group(name="gurtemoji", description="Manage Gurt's custom emoji knowledge. (Owner only)") + gurtemoji_group = app_commands.Group( + name="gurtemoji", + description="Manage Gurt's custom emoji knowledge. (Owner only)", + ) - @gurtemoji_group.command(name="add", description="Add a custom emoji to Gurt's knowledge.") - @app_commands.describe(name="The name of the emoji (e.g., :custom_emoji:).", url="The URL of the emoji image.") + @gurtemoji_group.command( + name="add", description="Add a custom emoji to Gurt's knowledge." + ) + @app_commands.describe( + name="The name of the emoji (e.g., :custom_emoji:).", + url="The URL of the emoji image.", + ) async def gurtemoji_add(interaction: discord.Interaction, name: str, url: str): if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Only the bot owner can manage custom emojis.", ephemeral=True) + await interaction.response.send_message( + "⛔ Only the bot owner can manage custom emojis.", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) # Assuming cog.emoji_manager exists and has an add_emoji method - if hasattr(cog, 'emoji_manager') and hasattr(cog.emoji_manager, 'add_emoji'): + if hasattr(cog, "emoji_manager") and hasattr(cog.emoji_manager, "add_emoji"): success = await cog.emoji_manager.add_emoji(name, url) if success: - await interaction.followup.send(f"✅ Emoji '{name}' added.", ephemeral=True) + await interaction.followup.send( + f"✅ Emoji '{name}' added.", ephemeral=True + ) else: - await interaction.followup.send(f"❌ Failed to add emoji '{name}'. It might already exist or there was an error.", ephemeral=True) + await interaction.followup.send( + f"❌ Failed to add emoji '{name}'. It might already exist or there was an error.", + ephemeral=True, + ) else: - await interaction.followup.send("Emoji manager not available.", ephemeral=True) + await interaction.followup.send( + "Emoji manager not available.", ephemeral=True + ) - @gurtemoji_group.command(name="remove", description="Remove a custom emoji from Gurt's knowledge.") - @app_commands.describe(name="The name of the emoji to remove (e.g., :custom_emoji:).") + @gurtemoji_group.command( + name="remove", description="Remove a custom emoji from Gurt's knowledge." + ) + @app_commands.describe( + name="The name of the emoji to remove (e.g., :custom_emoji:)." + ) async def gurtemoji_remove(interaction: discord.Interaction, name: str): if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Only the bot owner can manage custom emojis.", ephemeral=True) + await interaction.response.send_message( + "⛔ Only the bot owner can manage custom emojis.", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) - if hasattr(cog, 'emoji_manager') and hasattr(cog.emoji_manager, 'remove_emoji'): + if hasattr(cog, "emoji_manager") and hasattr(cog.emoji_manager, "remove_emoji"): success = await cog.emoji_manager.remove_emoji(name) if success: - await interaction.followup.send(f"✅ Emoji '{name}' removed.", ephemeral=True) + await interaction.followup.send( + f"✅ Emoji '{name}' removed.", ephemeral=True + ) else: - await interaction.followup.send(f"❌ Failed to remove emoji '{name}'. It might not exist or there was an error.", ephemeral=True) + await interaction.followup.send( + f"❌ Failed to remove emoji '{name}'. It might not exist or there was an error.", + ephemeral=True, + ) else: - await interaction.followup.send("Emoji manager not available.", ephemeral=True) + await interaction.followup.send( + "Emoji manager not available.", ephemeral=True + ) - @gurtemoji_group.command(name="list", description="List all custom emojis Gurt knows.") + @gurtemoji_group.command( + name="list", description="List all custom emojis Gurt knows." + ) async def gurtemoji_list(interaction: discord.Interaction): if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Only the bot owner can manage custom emojis.", ephemeral=True) + await interaction.response.send_message( + "⛔ Only the bot owner can manage custom emojis.", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) - if hasattr(cog, 'emoji_manager') and hasattr(cog.emoji_manager, 'list_emojis'): + if hasattr(cog, "emoji_manager") and hasattr(cog.emoji_manager, "list_emojis"): emojis = await cog.emoji_manager.list_emojis() if emojis: - embed = create_gurt_embed("Known Custom Emojis", color=discord.Color.gold()) - description = "\n".join([f"- {name}: {url}" for name, url in emojis.items()]) + embed = create_gurt_embed( + "Known Custom Emojis", color=discord.Color.gold() + ) + description = "\n".join( + [f"- {name}: {url}" for name, url in emojis.items()] + ) embed.description = description await interaction.followup.send(embed=embed, ephemeral=True) else: - await interaction.followup.send("Gurt doesn't know any custom emojis yet.", ephemeral=True) + await interaction.followup.send( + "Gurt doesn't know any custom emojis yet.", ephemeral=True + ) else: - await interaction.followup.send("Emoji manager not available.", ephemeral=True) + await interaction.followup.send( + "Emoji manager not available.", ephemeral=True + ) cog.bot.tree.add_command(gurtemoji_group) command_functions.extend([gurtemoji_add, gurtemoji_remove, gurtemoji_list]) # --- Gurt Sticker Command Group (Owner Only) --- - gurtsticker_group = app_commands.Group(name="gurtsticker", description="Manage Gurt's custom sticker knowledge. (Owner only)") + gurtsticker_group = app_commands.Group( + name="gurtsticker", + description="Manage Gurt's custom sticker knowledge. (Owner only)", + ) - @gurtsticker_group.command(name="add", description="Add a custom sticker to Gurt's knowledge.") - @app_commands.describe(name="The name of the sticker.", url="The URL of the sticker image.") + @gurtsticker_group.command( + name="add", description="Add a custom sticker to Gurt's knowledge." + ) + @app_commands.describe( + name="The name of the sticker.", url="The URL of the sticker image." + ) async def gurtsticker_add(interaction: discord.Interaction, name: str, url: str): if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Only the bot owner can manage custom stickers.", ephemeral=True) + await interaction.response.send_message( + "⛔ Only the bot owner can manage custom stickers.", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) - if hasattr(cog, 'emoji_manager') and hasattr(cog.emoji_manager, 'add_sticker'): + if hasattr(cog, "emoji_manager") and hasattr(cog.emoji_manager, "add_sticker"): success = await cog.emoji_manager.add_sticker(name, url) if success: - await interaction.followup.send(f"✅ Sticker '{name}' added.", ephemeral=True) + await interaction.followup.send( + f"✅ Sticker '{name}' added.", ephemeral=True + ) else: - await interaction.followup.send(f"❌ Failed to add sticker '{name}'. It might already exist or there was an error.", ephemeral=True) + await interaction.followup.send( + f"❌ Failed to add sticker '{name}'. It might already exist or there was an error.", + ephemeral=True, + ) else: - await interaction.followup.send("Sticker manager not available.", ephemeral=True) + await interaction.followup.send( + "Sticker manager not available.", ephemeral=True + ) - @gurtsticker_group.command(name="remove", description="Remove a custom sticker from Gurt's knowledge.") + @gurtsticker_group.command( + name="remove", description="Remove a custom sticker from Gurt's knowledge." + ) @app_commands.describe(name="The name of the sticker to remove.") async def gurtsticker_remove(interaction: discord.Interaction, name: str): if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Only the bot owner can manage custom stickers.", ephemeral=True) + await interaction.response.send_message( + "⛔ Only the bot owner can manage custom stickers.", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) - if hasattr(cog, 'emoji_manager') and hasattr(cog.emoji_manager, 'remove_sticker'): + if hasattr(cog, "emoji_manager") and hasattr( + cog.emoji_manager, "remove_sticker" + ): success = await cog.emoji_manager.remove_sticker(name) if success: - await interaction.followup.send(f"✅ Sticker '{name}' removed.", ephemeral=True) + await interaction.followup.send( + f"✅ Sticker '{name}' removed.", ephemeral=True + ) else: - await interaction.followup.send(f"❌ Failed to remove sticker '{name}'. It might not exist or there was an error.", ephemeral=True) + await interaction.followup.send( + f"❌ Failed to remove sticker '{name}'. It might not exist or there was an error.", + ephemeral=True, + ) else: - await interaction.followup.send("Sticker manager not available.", ephemeral=True) + await interaction.followup.send( + "Sticker manager not available.", ephemeral=True + ) - @gurtsticker_group.command(name="list", description="List all custom stickers Gurt knows.") + @gurtsticker_group.command( + name="list", description="List all custom stickers Gurt knows." + ) async def gurtsticker_list(interaction: discord.Interaction): if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Only the bot owner can manage custom stickers.", ephemeral=True) + await interaction.response.send_message( + "⛔ Only the bot owner can manage custom stickers.", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) - if hasattr(cog, 'emoji_manager') and hasattr(cog.emoji_manager, 'list_stickers'): + if hasattr(cog, "emoji_manager") and hasattr( + cog.emoji_manager, "list_stickers" + ): stickers = await cog.emoji_manager.list_stickers() if stickers: - embed = create_gurt_embed("Known Custom Stickers", color=discord.Color.dark_gold()) - description = "\n".join([f"- {name}: {url}" for name, url in stickers.items()]) + embed = create_gurt_embed( + "Known Custom Stickers", color=discord.Color.dark_gold() + ) + description = "\n".join( + [f"- {name}: {url}" for name, url in stickers.items()] + ) embed.description = description await interaction.followup.send(embed=embed, ephemeral=True) else: - await interaction.followup.send("Gurt doesn't know any custom stickers yet.", ephemeral=True) + await interaction.followup.send( + "Gurt doesn't know any custom stickers yet.", ephemeral=True + ) else: - await interaction.followup.send("Sticker manager not available.", ephemeral=True) + await interaction.followup.send( + "Sticker manager not available.", ephemeral=True + ) cog.bot.tree.add_command(gurtsticker_group) command_functions.extend([gurtsticker_add, gurtsticker_remove, gurtsticker_list]) # --- Gurt Tenor API Key Command (Owner Only) --- - @cog.bot.tree.command(name="gurttenorapikey", description="Set the Tenor API key for Gurt. (Owner only)") + @cog.bot.tree.command( + name="gurttenorapikey", + description="Set the Tenor API key for Gurt. (Owner only)", + ) @app_commands.describe(api_key="The Tenor API key.") async def gurttenorapikey(interaction: discord.Interaction, api_key: str): if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Only the bot owner can set the Tenor API key.", ephemeral=True) + await interaction.response.send_message( + "⛔ Only the bot owner can set the Tenor API key.", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) # Assuming cog.config_manager or similar exists for updating config - if hasattr(cog, 'config_manager') and hasattr(cog.config_manager, 'set_tenor_api_key'): + if hasattr(cog, "config_manager") and hasattr( + cog.config_manager, "set_tenor_api_key" + ): await cog.config_manager.set_tenor_api_key(api_key) # Update the cog's runtime TENOR_API_KEY if it's stored there directly or re-init relevant clients - if hasattr(cog, 'TENOR_API_KEY'): - cog.TENOR_API_KEY = api_key # If cog holds it directly + if hasattr(cog, "TENOR_API_KEY"): + cog.TENOR_API_KEY = api_key # If cog holds it directly # Potentially re-initialize TavilyClient or other clients if they use Tenor key indirectly - await interaction.followup.send("✅ Tenor API key set. You may need to reload Gurt for changes to fully apply.", ephemeral=True) + await interaction.followup.send( + "✅ Tenor API key set. You may need to reload Gurt for changes to fully apply.", + ephemeral=True, + ) else: # Fallback: try to update config.py directly (less ideal) # This requires careful handling of file I/O and is generally not recommended for runtime changes. @@ -740,71 +1232,109 @@ def setup_commands(cog: 'GurtCog'): # This is a placeholder for a more robust config update mechanism # In a real scenario, you'd write this to a .env file or a database # For now, we'll just update the cog's attribute if it exists - if hasattr(cog, 'TENOR_API_KEY'): + if hasattr(cog, "TENOR_API_KEY"): cog.TENOR_API_KEY = api_key # Here you would also save it persistently # e.g., await cog.memory_manager.save_setting("TENOR_API_KEY", api_key) - await interaction.followup.send("✅ Tenor API key updated in runtime. Save it persistently for it to survive restarts.", ephemeral=True) + await interaction.followup.send( + "✅ Tenor API key updated in runtime. Save it persistently for it to survive restarts.", + ephemeral=True, + ) else: - await interaction.followup.send("⚠️ Tenor API key runtime attribute not found. Key not set.", ephemeral=True) + await interaction.followup.send( + "⚠️ Tenor API key runtime attribute not found. Key not set.", + ephemeral=True, + ) except Exception as e: - await interaction.followup.send(f"❌ Error setting Tenor API key: {e}", ephemeral=True) - + await interaction.followup.send( + f"❌ Error setting Tenor API key: {e}", ephemeral=True + ) command_functions.append(gurttenorapikey) # --- Gurt Reset Personality Command (Owner Only) --- - @cog.bot.tree.command(name="gurtresetpersonality", description="Reset Gurt's personality and interests to baseline. (Owner only)") + @cog.bot.tree.command( + name="gurtresetpersonality", + description="Reset Gurt's personality and interests to baseline. (Owner only)", + ) async def gurtresetpersonality(interaction: discord.Interaction): """Handles the /gurtresetpersonality command.""" if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Only the bot owner can reset Gurt's personality.", ephemeral=True) + await interaction.response.send_message( + "⛔ Only the bot owner can reset Gurt's personality.", ephemeral=True + ) return await interaction.response.defer(ephemeral=True) try: # Ensure the cog has access to baseline values, e.g., cog.BASELINE_PERSONALITY # These would typically be loaded from gurt.config into the GurtCog instance - if not hasattr(cog, 'BASELINE_PERSONALITY') or not hasattr(cog, 'BASELINE_INTERESTS'): - await interaction.followup.send("⚠️ Baseline personality or interests not found in cog configuration. Reset aborted.", ephemeral=True) + if not hasattr(cog, "BASELINE_PERSONALITY") or not hasattr( + cog, "BASELINE_INTERESTS" + ): + await interaction.followup.send( + "⚠️ Baseline personality or interests not found in cog configuration. Reset aborted.", + ephemeral=True, + ) return - personality_result = await cog.memory_manager.reset_personality_to_baseline(cog.BASELINE_PERSONALITY) - interests_result = await cog.memory_manager.reset_interests_to_baseline(cog.BASELINE_INTERESTS) + personality_result = await cog.memory_manager.reset_personality_to_baseline( + cog.BASELINE_PERSONALITY + ) + interests_result = await cog.memory_manager.reset_interests_to_baseline( + cog.BASELINE_INTERESTS + ) messages = [] if personality_result.get("status") == "success": messages.append("✅ Personality traits reset to baseline.") else: - messages.append(f"⚠️ Error resetting personality: {personality_result.get('error', 'Unknown error')}") + messages.append( + f"⚠️ Error resetting personality: {personality_result.get('error', 'Unknown error')}" + ) if interests_result.get("status") == "success": messages.append("✅ Interests reset to baseline.") else: - messages.append(f"⚠️ Error resetting interests: {interests_result.get('error', 'Unknown error')}") - + messages.append( + f"⚠️ Error resetting interests: {interests_result.get('error', 'Unknown error')}" + ) + await interaction.followup.send("\n".join(messages), ephemeral=True) except Exception as e: import traceback + traceback.print_exc() - await interaction.followup.send(f"❌ An unexpected error occurred while resetting personality: {e}", ephemeral=True) - + await interaction.followup.send( + f"❌ An unexpected error occurred while resetting personality: {e}", + ephemeral=True, + ) + command_functions.append(gurtresetpersonality) # --- Gurt Model Command (Owner Only) --- - @cog.bot.tree.command(name="gurtmodel", description="Change Gurt's active AI model dynamically. (Owner only)") + @cog.bot.tree.command( + name="gurtmodel", + description="Change Gurt's active AI model dynamically. (Owner only)", + ) @app_commands.describe(model="The AI model to switch to.") - @app_commands.choices(model=[ - app_commands.Choice(name=friendly_name, value=model_id) - for model_id, friendly_name in AVAILABLE_AI_MODELS.items() - ]) - async def gurtmodel(interaction: discord.Interaction, model: app_commands.Choice[str]): + @app_commands.choices( + model=[ + app_commands.Choice(name=friendly_name, value=model_id) + for model_id, friendly_name in AVAILABLE_AI_MODELS.items() + ] + ) + async def gurtmodel( + interaction: discord.Interaction, model: app_commands.Choice[str] + ): """Handles the /gurtmodel command.""" if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Only the bot owner can change Gurt's AI model.", ephemeral=True) + await interaction.response.send_message( + "⛔ Only the bot owner can change Gurt's AI model.", ephemeral=True + ) return - + await interaction.response.defer(ephemeral=False) try: new_model_id = model.value @@ -812,38 +1342,55 @@ def setup_commands(cog: 'GurtCog'): # Update the cog's default model cog.default_model = new_model_id - + # Optionally, update the config file if you want this change to persist across restarts # This would require a function in config.py to update DEFAULT_MODEL in the .env or a separate config file # For now, we'll just update the runtime attribute. # If persistence is desired, you'd add something like: # await cog.config_manager.set_default_model(new_model_id) # Assuming a config_manager exists - await interaction.followup.send(f"✅ Gurt's AI model has been changed to: **{new_model_friendly_name}** (`{new_model_id}`).", ephemeral=False) + await interaction.followup.send( + f"✅ Gurt's AI model has been changed to: **{new_model_friendly_name}** (`{new_model_id}`).", + ephemeral=False, + ) except Exception as e: print(f"Error in /gurtmodel command: {e}") import traceback + traceback.print_exc() - await interaction.followup.send("❌ An error occurred while changing Gurt's AI model.", ephemeral=True) + await interaction.followup.send( + "❌ An error occurred while changing Gurt's AI model.", ephemeral=True + ) command_functions.append(gurtmodel) # --- Gurt Get Model Command --- - @cog.bot.tree.command(name="gurtgetmodel", description="Display Gurt's currently active AI model.") + @cog.bot.tree.command( + name="gurtgetmodel", description="Display Gurt's currently active AI model." + ) async def gurtgetmodel(interaction: discord.Interaction): """Handles the /gurtgetmodel command.""" await interaction.response.defer(ephemeral=False) try: current_model_id = cog.default_model # Try to get the friendly name from AVAILABLE_AI_MODELS - friendly_name = AVAILABLE_AI_MODELS.get(current_model_id, current_model_id) # Fallback to ID if not found + friendly_name = AVAILABLE_AI_MODELS.get( + current_model_id, current_model_id + ) # Fallback to ID if not found - await interaction.followup.send(f"Gurt is currently using AI model: **{friendly_name}** (`{current_model_id}`).", ephemeral=False) + await interaction.followup.send( + f"Gurt is currently using AI model: **{friendly_name}** (`{current_model_id}`).", + ephemeral=False, + ) except Exception as e: print(f"Error in /gurtgetmodel command: {e}") import traceback + traceback.print_exc() - await interaction.followup.send("❌ An error occurred while fetching Gurt's current AI model.", ephemeral=True) + await interaction.followup.send( + "❌ An error occurred while fetching Gurt's current AI model.", + ephemeral=True, + ) command_functions.append(gurtgetmodel) diff --git a/gurt/config.py b/gurt/config.py index 70b48cd..45d8120 100644 --- a/gurt/config.py +++ b/gurt/config.py @@ -12,21 +12,31 @@ load_dotenv() PROJECT_ID = os.getenv("GCP_PROJECT_ID", "1079377687568") LOCATION = os.getenv("GCP_LOCATION", "us-central1") TAVILY_API_KEY = os.getenv("TAVILY_API_KEY", "") -TENOR_API_KEY = os.getenv("TENOR_API_KEY", "") # Added Tenor API Key -PISTON_API_URL = os.getenv("PISTON_API_URL") # For run_python_code tool -PISTON_API_KEY = os.getenv("PISTON_API_KEY") # Optional key for Piston +TENOR_API_KEY = os.getenv("TENOR_API_KEY", "") # Added Tenor API Key +PISTON_API_URL = os.getenv("PISTON_API_URL") # For run_python_code tool +PISTON_API_KEY = os.getenv("PISTON_API_KEY") # Optional key for Piston # --- Tavily Configuration --- TAVILY_DEFAULT_SEARCH_DEPTH = os.getenv("TAVILY_DEFAULT_SEARCH_DEPTH", "basic") TAVILY_DEFAULT_MAX_RESULTS = int(os.getenv("TAVILY_DEFAULT_MAX_RESULTS", 5)) -TAVILY_DISABLE_ADVANCED = os.getenv("TAVILY_DISABLE_ADVANCED", "false").lower() == "true" # For cost control +TAVILY_DISABLE_ADVANCED = ( + os.getenv("TAVILY_DISABLE_ADVANCED", "false").lower() == "true" +) # For cost control # --- Model Configuration --- DEFAULT_MODEL = os.getenv("GURT_DEFAULT_MODEL", "google/gemini-2.5-flash-preview-05-20") -FALLBACK_MODEL = os.getenv("GURT_FALLBACK_MODEL", "google/gemini-2.5-flash-preview-05-20") -CUSTOM_TUNED_MODEL_ENDPOINT = os.getenv("GURT_CUSTOM_TUNED_MODEL", "google/gemini-2.5-flash-preview-05-20") -SAFETY_CHECK_MODEL = os.getenv("GURT_SAFETY_CHECK_MODEL", "google/gemini-2.5-flash-preview-05-20") # Use a Vertex AI model for safety checks -EMOJI_STICKER_DESCRIPTION_MODEL = "google/gemini-2.0-flash-001" # Hardcoded for emoji/sticker image descriptions +FALLBACK_MODEL = os.getenv( + "GURT_FALLBACK_MODEL", "google/gemini-2.5-flash-preview-05-20" +) +CUSTOM_TUNED_MODEL_ENDPOINT = os.getenv( + "GURT_CUSTOM_TUNED_MODEL", "google/gemini-2.5-flash-preview-05-20" +) +SAFETY_CHECK_MODEL = os.getenv( + "GURT_SAFETY_CHECK_MODEL", "google/gemini-2.5-flash-preview-05-20" +) # Use a Vertex AI model for safety checks +EMOJI_STICKER_DESCRIPTION_MODEL = ( + "google/gemini-2.0-flash-001" # Hardcoded for emoji/sticker image descriptions +) # Available AI Models for dynamic switching AVAILABLE_AI_MODELS = { @@ -34,13 +44,13 @@ AVAILABLE_AI_MODELS = { "google/gemini-2.5-pro-preview-05-06": "Gemini 2.5 Pro Preview", "claude-sonnet-4@20250514": "Claude Sonnet 4", "llama-4-maverick-17b-128e-instruct-maas": "Llama 4 Maverick Instruct", - "google/gemini-2.0-flash-001": "Gemini 2.0 Flash" + "google/gemini-2.0-flash-001": "Gemini 2.0 Flash", } # --- Database Paths --- DB_PATH = os.getenv("GURT_DB_PATH", "data/gurt_memory.db") CHROMA_PATH = os.getenv("GURT_CHROMA_PATH", "data/chroma_db") -SEMANTIC_MODEL_NAME = os.getenv("GURT_SEMANTIC_MODEL", 'all-MiniLM-L6-v2') +SEMANTIC_MODEL_NAME = os.getenv("GURT_SEMANTIC_MODEL", "all-MiniLM-L6-v2") # --- Ignored Channels --- IGNORED_CHANNELS_FILE_PATH = "data/ignored_channels.json" @@ -48,51 +58,80 @@ _loaded_ignored_channel_ids = set() try: if os.path.exists(IGNORED_CHANNELS_FILE_PATH): - with open(IGNORED_CHANNELS_FILE_PATH, 'r') as f: + with open(IGNORED_CHANNELS_FILE_PATH, "r") as f: data = json.load(f) if isinstance(data, list) and all(isinstance(cid, int) for cid in data): _loaded_ignored_channel_ids = set(data) - print(f"Loaded {len(_loaded_ignored_channel_ids)} ignored channel IDs from {IGNORED_CHANNELS_FILE_PATH}") - elif data: # If file exists but content is not a list of ints (e.g. old format or corrupt) - print(f"Warning: {IGNORED_CHANNELS_FILE_PATH} contains invalid data. Attempting to load from ENV.") - _loaded_ignored_channel_ids = set() # Reset to ensure fallback - - if not _loaded_ignored_channel_ids: # If file didn't exist, was empty, or had invalid data + print( + f"Loaded {len(_loaded_ignored_channel_ids)} ignored channel IDs from {IGNORED_CHANNELS_FILE_PATH}" + ) + elif ( + data + ): # If file exists but content is not a list of ints (e.g. old format or corrupt) + print( + f"Warning: {IGNORED_CHANNELS_FILE_PATH} contains invalid data. Attempting to load from ENV." + ) + _loaded_ignored_channel_ids = set() # Reset to ensure fallback + + if ( + not _loaded_ignored_channel_ids + ): # If file didn't exist, was empty, or had invalid data IGNORED_CHANNELS_STR = os.getenv("GURT_IGNORED_CHANNELS", "") if IGNORED_CHANNELS_STR: - env_ids = {int(cid.strip()) for cid in IGNORED_CHANNELS_STR.split(',') if cid.strip().isdigit()} + env_ids = { + int(cid.strip()) + for cid in IGNORED_CHANNELS_STR.split(",") + if cid.strip().isdigit() + } if env_ids: _loaded_ignored_channel_ids = env_ids - print(f"Loaded {len(_loaded_ignored_channel_ids)} ignored channel IDs from GURT_IGNORED_CHANNELS env var.") + print( + f"Loaded {len(_loaded_ignored_channel_ids)} ignored channel IDs from GURT_IGNORED_CHANNELS env var." + ) # Initialize the JSON file with ENV var content if JSON was missing/empty - if not os.path.exists(IGNORED_CHANNELS_FILE_PATH) or os.path.getsize(IGNORED_CHANNELS_FILE_PATH) == 0: + if ( + not os.path.exists(IGNORED_CHANNELS_FILE_PATH) + or os.path.getsize(IGNORED_CHANNELS_FILE_PATH) == 0 + ): try: - os.makedirs(os.path.dirname(IGNORED_CHANNELS_FILE_PATH), exist_ok=True) - with open(IGNORED_CHANNELS_FILE_PATH, 'w') as f: + os.makedirs( + os.path.dirname(IGNORED_CHANNELS_FILE_PATH), exist_ok=True + ) + with open(IGNORED_CHANNELS_FILE_PATH, "w") as f: json.dump(list(_loaded_ignored_channel_ids), f) - print(f"Initialized {IGNORED_CHANNELS_FILE_PATH} with IDs from ENV var.") + print( + f"Initialized {IGNORED_CHANNELS_FILE_PATH} with IDs from ENV var." + ) except Exception as e: print(f"Error initializing {IGNORED_CHANNELS_FILE_PATH}: {e}") else: print("No ignored channel IDs found in GURT_IGNORED_CHANNELS env var.") else: - print(f"{IGNORED_CHANNELS_FILE_PATH} not found and GURT_IGNORED_CHANNELS env var not set or empty. No channels will be ignored by default.") + print( + f"{IGNORED_CHANNELS_FILE_PATH} not found and GURT_IGNORED_CHANNELS env var not set or empty. No channels will be ignored by default." + ) # Ensure the file exists even if empty for commands to use later if not os.path.exists(IGNORED_CHANNELS_FILE_PATH): try: - os.makedirs(os.path.dirname(IGNORED_CHANNELS_FILE_PATH), exist_ok=True) - with open(IGNORED_CHANNELS_FILE_PATH, 'w') as f: - json.dump([], f) # Create an empty list in the JSON file + os.makedirs( + os.path.dirname(IGNORED_CHANNELS_FILE_PATH), exist_ok=True + ) + with open(IGNORED_CHANNELS_FILE_PATH, "w") as f: + json.dump([], f) # Create an empty list in the JSON file print(f"Created empty {IGNORED_CHANNELS_FILE_PATH}.") except Exception as e: print(f"Error creating empty {IGNORED_CHANNELS_FILE_PATH}: {e}") except FileNotFoundError: - print(f"{IGNORED_CHANNELS_FILE_PATH} not found. Will check GURT_IGNORED_CHANNELS env var.") + print( + f"{IGNORED_CHANNELS_FILE_PATH} not found. Will check GURT_IGNORED_CHANNELS env var." + ) # This case is handled by the 'if not _loaded_ignored_channel_ids' block above except json.JSONDecodeError: - print(f"Error decoding JSON from {IGNORED_CHANNELS_FILE_PATH}. Will check GURT_IGNORED_CHANNELS env var.") + print( + f"Error decoding JSON from {IGNORED_CHANNELS_FILE_PATH}. Will check GURT_IGNORED_CHANNELS env var." + ) # This case is handled by the 'if not _loaded_ignored_channel_ids' block above except Exception as e: print(f"An unexpected error occurred while loading ignored channels: {e}") @@ -100,91 +139,173 @@ except Exception as e: IGNORED_CHANNEL_IDS = _loaded_ignored_channel_ids + # Function to update ignored channels at runtime (used by commands) def update_ignored_channels_file(channel_ids_set: set): """Updates the ignored_channels.json file and the runtime config.""" global IGNORED_CHANNEL_IDS try: os.makedirs(os.path.dirname(IGNORED_CHANNELS_FILE_PATH), exist_ok=True) - with open(IGNORED_CHANNELS_FILE_PATH, 'w') as f: + with open(IGNORED_CHANNELS_FILE_PATH, "w") as f: json.dump(list(channel_ids_set), f) IGNORED_CHANNEL_IDS = channel_ids_set - print(f"Successfully updated {IGNORED_CHANNELS_FILE_PATH} with {len(channel_ids_set)} IDs.") + print( + f"Successfully updated {IGNORED_CHANNELS_FILE_PATH} with {len(channel_ids_set)} IDs." + ) return True except Exception as e: print(f"Error updating {IGNORED_CHANNELS_FILE_PATH}: {e}") return False + # --- Memory Manager Config --- -MAX_USER_FACTS = 20 # TODO: Load from env? -MAX_GENERAL_FACTS = 100 # TODO: Load from env? +MAX_USER_FACTS = 20 # TODO: Load from env? +MAX_GENERAL_FACTS = 100 # TODO: Load from env? # --- Personality & Mood --- MOOD_OPTIONS = [ - "chill", "neutral", "curious", "slightly hyper", "a bit bored", "mischievous", - "excited", "tired", "sassy", "philosophical", "playful", "dramatic", - "nostalgic", "confused", "impressed", "skeptical", "enthusiastic", - "distracted", "focused", "creative", "sarcastic", "wholesome" + "chill", + "neutral", + "curious", + "slightly hyper", + "a bit bored", + "mischievous", + "excited", + "tired", + "sassy", + "philosophical", + "playful", + "dramatic", + "nostalgic", + "confused", + "impressed", + "skeptical", + "enthusiastic", + "distracted", + "focused", + "creative", + "sarcastic", + "wholesome", ] # Categorize moods for weighted selection MOOD_CATEGORIES = { - "positive": ["excited", "enthusiastic", "playful", "wholesome", "creative", "impressed"], - "negative": ["tired", "a bit bored", "sassy", "sarcastic", "skeptical", "dramatic", "distracted"], - "neutral": ["chill", "neutral", "curious", "philosophical", "focused", "confused", "nostalgic"], - "mischievous": ["mischievous"] # Special category for trait link + "positive": [ + "excited", + "enthusiastic", + "playful", + "wholesome", + "creative", + "impressed", + ], + "negative": [ + "tired", + "a bit bored", + "sassy", + "sarcastic", + "skeptical", + "dramatic", + "distracted", + ], + "neutral": [ + "chill", + "neutral", + "curious", + "philosophical", + "focused", + "confused", + "nostalgic", + ], + "mischievous": ["mischievous"], # Special category for trait link } BASELINE_PERSONALITY = { - "chattiness": 0.1, "emoji_usage": 0.4, "slang_level": 0.5, "randomness": 0.6, - "verbosity": 0.4, "optimism": 0.5, "curiosity": 0.6, "sarcasm_level": 0.3, - "patience": 0.6, "mischief": 0.5 + "chattiness": 0.1, + "emoji_usage": 0.4, + "slang_level": 0.5, + "randomness": 0.6, + "verbosity": 0.4, + "optimism": 0.5, + "curiosity": 0.6, + "sarcasm_level": 0.3, + "patience": 0.6, + "mischief": 0.5, } BASELINE_INTERESTS = { - "kasane teto": 0.8, "vocaloids": 0.6, "gaming": 0.6, "anime": 0.5, - "tech": 0.6, "memes": 0.6, "gooning": 0.6, "needy streamer overload": 0.7 + "kasane teto": 0.8, + "vocaloids": 0.6, + "gaming": 0.6, + "anime": 0.5, + "tech": 0.6, + "memes": 0.6, + "gooning": 0.6, + "needy streamer overload": 0.7, } -MOOD_CHANGE_INTERVAL_MIN = 1200 # 20 minutes -MOOD_CHANGE_INTERVAL_MAX = 2400 # 40 minutes -EVOLUTION_UPDATE_INTERVAL = 1800 # Evolve personality every 30 minutes +MOOD_CHANGE_INTERVAL_MIN = 1200 # 20 minutes +MOOD_CHANGE_INTERVAL_MAX = 2400 # 40 minutes +EVOLUTION_UPDATE_INTERVAL = 1800 # Evolve personality every 30 minutes # --- Stats Push --- # How often the Gurt bot should push its stats to the API server (seconds) -STATS_PUSH_INTERVAL = 30 # Push every 30 seconds +STATS_PUSH_INTERVAL = 30 # Push every 30 seconds # --- Context & Caching --- -CHANNEL_TOPIC_CACHE_TTL = 600 # seconds (10 minutes) +CHANNEL_TOPIC_CACHE_TTL = 600 # seconds (10 minutes) CONTEXT_WINDOW_SIZE = 150 # Number of messages to include in context -CONTEXT_EXPIRY_TIME = 3600 # Time in seconds before context is considered stale (1 hour) +CONTEXT_EXPIRY_TIME = ( + 3600 # Time in seconds before context is considered stale (1 hour) +) MAX_CONTEXT_TOKENS = 8000 # Maximum number of tokens to include in context (Note: Not actively enforced yet) -SUMMARY_CACHE_TTL = 900 # seconds (15 minutes) for conversation summary cache +SUMMARY_CACHE_TTL = 900 # seconds (15 minutes) for conversation summary cache # --- API Call Settings --- -API_TIMEOUT = 60 # seconds -SUMMARY_API_TIMEOUT = 45 # seconds +API_TIMEOUT = 60 # seconds +SUMMARY_API_TIMEOUT = 45 # seconds API_RETRY_ATTEMPTS = 1 -API_RETRY_DELAY = 1 # seconds +API_RETRY_DELAY = 1 # seconds # --- Proactive Engagement Config --- -PROACTIVE_LULL_THRESHOLD = int(os.getenv("PROACTIVE_LULL_THRESHOLD", 300)) # 5 mins -PROACTIVE_BOT_SILENCE_THRESHOLD = int(os.getenv("PROACTIVE_BOT_SILENCE_THRESHOLD", 900)) # 15 mins +PROACTIVE_LULL_THRESHOLD = int(os.getenv("PROACTIVE_LULL_THRESHOLD", 300)) # 5 mins +PROACTIVE_BOT_SILENCE_THRESHOLD = int( + os.getenv("PROACTIVE_BOT_SILENCE_THRESHOLD", 900) +) # 15 mins PROACTIVE_LULL_CHANCE = float(os.getenv("PROACTIVE_LULL_CHANCE", 0.1)) -PROACTIVE_TOPIC_RELEVANCE_THRESHOLD = float(os.getenv("PROACTIVE_TOPIC_RELEVANCE_THRESHOLD", 0.7)) +PROACTIVE_TOPIC_RELEVANCE_THRESHOLD = float( + os.getenv("PROACTIVE_TOPIC_RELEVANCE_THRESHOLD", 0.7) +) PROACTIVE_TOPIC_CHANCE = float(os.getenv("PROACTIVE_TOPIC_CHANCE", 0.1)) -PROACTIVE_RELATIONSHIP_SCORE_THRESHOLD = int(os.getenv("PROACTIVE_RELATIONSHIP_SCORE_THRESHOLD", 80)) +PROACTIVE_RELATIONSHIP_SCORE_THRESHOLD = int( + os.getenv("PROACTIVE_RELATIONSHIP_SCORE_THRESHOLD", 80) +) PROACTIVE_RELATIONSHIP_CHANCE = float(os.getenv("PROACTIVE_RELATIONSHIP_CHANCE", 0.1)) -PROACTIVE_SENTIMENT_SHIFT_THRESHOLD = float(os.getenv("PROACTIVE_SENTIMENT_SHIFT_THRESHOLD", 0.8)) # Intensity threshold for trigger -PROACTIVE_SENTIMENT_DURATION_THRESHOLD = int(os.getenv("PROACTIVE_SENTIMENT_DURATION_THRESHOLD", 900)) # How long sentiment needs to persist (15 mins) +PROACTIVE_SENTIMENT_SHIFT_THRESHOLD = float( + os.getenv("PROACTIVE_SENTIMENT_SHIFT_THRESHOLD", 0.8) +) # Intensity threshold for trigger +PROACTIVE_SENTIMENT_DURATION_THRESHOLD = int( + os.getenv("PROACTIVE_SENTIMENT_DURATION_THRESHOLD", 900) +) # How long sentiment needs to persist (15 mins) PROACTIVE_SENTIMENT_CHANCE = float(os.getenv("PROACTIVE_SENTIMENT_CHANCE", 0.1)) -PROACTIVE_USER_INTEREST_THRESHOLD = float(os.getenv("PROACTIVE_USER_INTEREST_THRESHOLD", 0.7)) # Min interest level for Gurt to trigger -PROACTIVE_USER_INTEREST_MATCH_THRESHOLD = float(os.getenv("PROACTIVE_USER_INTEREST_MATCH_THRESHOLD", 0.5)) # Min interest level for User (if tracked) - Currently not tracked per user, but config is ready -PROACTIVE_USER_INTEREST_CHANCE = float(os.getenv("PROACTIVE_USER_INTEREST_CHANCE", 0.15)) +PROACTIVE_USER_INTEREST_THRESHOLD = float( + os.getenv("PROACTIVE_USER_INTEREST_THRESHOLD", 0.7) +) # Min interest level for Gurt to trigger +PROACTIVE_USER_INTEREST_MATCH_THRESHOLD = float( + os.getenv("PROACTIVE_USER_INTEREST_MATCH_THRESHOLD", 0.5) +) # Min interest level for User (if tracked) - Currently not tracked per user, but config is ready +PROACTIVE_USER_INTEREST_CHANCE = float( + os.getenv("PROACTIVE_USER_INTEREST_CHANCE", 0.15) +) # --- Interest Tracking Config --- -INTEREST_UPDATE_INTERVAL = int(os.getenv("INTEREST_UPDATE_INTERVAL", 1800)) # 30 mins -INTEREST_DECAY_INTERVAL_HOURS = int(os.getenv("INTEREST_DECAY_INTERVAL_HOURS", 24)) # Daily +INTEREST_UPDATE_INTERVAL = int(os.getenv("INTEREST_UPDATE_INTERVAL", 1800)) # 30 mins +INTEREST_DECAY_INTERVAL_HOURS = int( + os.getenv("INTEREST_DECAY_INTERVAL_HOURS", 24) +) # Daily INTEREST_PARTICIPATION_BOOST = float(os.getenv("INTEREST_PARTICIPATION_BOOST", 0.05)) -INTEREST_POSITIVE_REACTION_BOOST = float(os.getenv("INTEREST_POSITIVE_REACTION_BOOST", 0.02)) -INTEREST_NEGATIVE_REACTION_PENALTY = float(os.getenv("INTEREST_NEGATIVE_REACTION_PENALTY", -0.01)) +INTEREST_POSITIVE_REACTION_BOOST = float( + os.getenv("INTEREST_POSITIVE_REACTION_BOOST", 0.02) +) +INTEREST_NEGATIVE_REACTION_PENALTY = float( + os.getenv("INTEREST_NEGATIVE_REACTION_PENALTY", -0.01) +) INTEREST_FACT_BOOST = float(os.getenv("INTEREST_FACT_BOOST", 0.01)) INTEREST_MIN_LEVEL_FOR_PROMPT = float(os.getenv("INTEREST_MIN_LEVEL_FOR_PROMPT", 0.3)) INTEREST_MAX_FOR_PROMPT = int(os.getenv("INTEREST_MAX_FOR_PROMPT", 4)) @@ -192,49 +313,131 @@ INTEREST_MAX_FOR_PROMPT = int(os.getenv("INTEREST_MAX_FOR_PROMPT", 4)) # --- Learning Config --- LEARNING_RATE = 0.05 MAX_PATTERNS_PER_CHANNEL = 50 -LEARNING_UPDATE_INTERVAL = 3600 # Update learned patterns every hour -REFLECTION_INTERVAL_SECONDS = int(os.getenv("REFLECTION_INTERVAL_SECONDS", 6 * 3600)) # Reflect every 6 hours -GOAL_CHECK_INTERVAL = int(os.getenv("GOAL_CHECK_INTERVAL", 300)) # Check for pending goals every 5 mins -GOAL_EXECUTION_INTERVAL = int(os.getenv("GOAL_EXECUTION_INTERVAL", 60)) # Check for active goals to execute every 1 min -PROACTIVE_GOAL_CHECK_INTERVAL = int(os.getenv("PROACTIVE_GOAL_CHECK_INTERVAL", 900)) # Check if Gurt should create its own goals every 15 mins +LEARNING_UPDATE_INTERVAL = 3600 # Update learned patterns every hour +REFLECTION_INTERVAL_SECONDS = int( + os.getenv("REFLECTION_INTERVAL_SECONDS", 6 * 3600) +) # Reflect every 6 hours +GOAL_CHECK_INTERVAL = int( + os.getenv("GOAL_CHECK_INTERVAL", 300) +) # Check for pending goals every 5 mins +GOAL_EXECUTION_INTERVAL = int( + os.getenv("GOAL_EXECUTION_INTERVAL", 60) +) # Check for active goals to execute every 1 min +PROACTIVE_GOAL_CHECK_INTERVAL = int( + os.getenv("PROACTIVE_GOAL_CHECK_INTERVAL", 900) +) # Check if Gurt should create its own goals every 15 mins # --- Internal Random Action Config --- -INTERNAL_ACTION_INTERVAL_SECONDS = int(os.getenv("INTERNAL_ACTION_INTERVAL_SECONDS", 300)) # How often to *consider* a random action (10 mins) -INTERNAL_ACTION_PROBABILITY = float(os.getenv("INTERNAL_ACTION_PROBABILITY", 0.5)) # Chance of performing an action each interval (10%) -AUTONOMOUS_ACTION_REPORT_CHANNEL_ID = os.getenv("GURT_AUTONOMOUS_ACTION_REPORT_CHANNEL_ID", 1366840485355982869) # Optional channel ID to report autonomous actions +INTERNAL_ACTION_INTERVAL_SECONDS = int( + os.getenv("INTERNAL_ACTION_INTERVAL_SECONDS", 300) +) # How often to *consider* a random action (10 mins) +INTERNAL_ACTION_PROBABILITY = float( + os.getenv("INTERNAL_ACTION_PROBABILITY", 0.5) +) # Chance of performing an action each interval (10%) +AUTONOMOUS_ACTION_REPORT_CHANNEL_ID = os.getenv( + "GURT_AUTONOMOUS_ACTION_REPORT_CHANNEL_ID", 1366840485355982869 +) # Optional channel ID to report autonomous actions # --- Bot Response Rate Limit Config --- -BOT_RESPONSE_RATE_LIMIT_PER_MINUTE = int(os.getenv("GURT_BOT_RESPONSE_RATE_LIMIT_PER_MINUTE", 2)) -BOT_RESPONSE_RATE_LIMIT_WINDOW_SECONDS = 60 # 1 minute +BOT_RESPONSE_RATE_LIMIT_PER_MINUTE = int( + os.getenv("GURT_BOT_RESPONSE_RATE_LIMIT_PER_MINUTE", 2) +) +BOT_RESPONSE_RATE_LIMIT_WINDOW_SECONDS = 60 # 1 minute # --- Topic Tracking Config --- -TOPIC_UPDATE_INTERVAL = 300 # Update topics every 5 minutes +TOPIC_UPDATE_INTERVAL = 300 # Update topics every 5 minutes TOPIC_RELEVANCE_DECAY = 0.2 MAX_ACTIVE_TOPICS = 5 # --- Sentiment Tracking Config --- -SENTIMENT_UPDATE_INTERVAL = 300 # Update sentiment every 5 minutes +SENTIMENT_UPDATE_INTERVAL = 300 # Update sentiment every 5 minutes SENTIMENT_DECAY_RATE = 0.1 # --- Emotion Detection --- EMOTION_KEYWORDS = { - "joy": ["happy", "glad", "excited", "yay", "awesome", "love", "great", "amazing", "lol", "lmao", "haha"], - "sadness": ["sad", "upset", "depressed", "unhappy", "disappointed", "crying", "miss", "lonely", "sorry"], - "anger": ["angry", "mad", "hate", "furious", "annoyed", "frustrated", "pissed", "wtf", "fuck"], + "joy": [ + "happy", + "glad", + "excited", + "yay", + "awesome", + "love", + "great", + "amazing", + "lol", + "lmao", + "haha", + ], + "sadness": [ + "sad", + "upset", + "depressed", + "unhappy", + "disappointed", + "crying", + "miss", + "lonely", + "sorry", + ], + "anger": [ + "angry", + "mad", + "hate", + "furious", + "annoyed", + "frustrated", + "pissed", + "wtf", + "fuck", + ], "fear": ["afraid", "scared", "worried", "nervous", "anxious", "terrified", "yikes"], "surprise": ["wow", "omg", "whoa", "what", "really", "seriously", "no way", "wtf"], "disgust": ["gross", "ew", "eww", "disgusting", "nasty", "yuck"], - "confusion": ["confused", "idk", "what?", "huh", "hmm", "weird", "strange"] + "confusion": ["confused", "idk", "what?", "huh", "hmm", "weird", "strange"], } EMOJI_SENTIMENT = { - "positive": ["😊", "😄", "😁", "😆", "😍", "🥰", "❤️", "💕", "👍", "🙌", "✨", "🔥", "💯", "🎉", "🌹"], - "negative": ["😢", "😭", "😞", "😔", "😟", "😠", "😡", "👎", "💔", "😤", "😒", "😩", "😫", "😰", "🥀"], - "neutral": ["😐", "🤔", "🙂", "🙄", "👀", "💭", "🤷", "😶", "🫠"] + "positive": [ + "😊", + "😄", + "😁", + "😆", + "😍", + "🥰", + "❤️", + "💕", + "👍", + "🙌", + "✨", + "🔥", + "💯", + "🎉", + "🌹", + ], + "negative": [ + "😢", + "😭", + "😞", + "😔", + "😟", + "😠", + "😡", + "👎", + "💔", + "😤", + "😒", + "😩", + "😫", + "😰", + "🥀", + ], + "neutral": ["😐", "🤔", "🙂", "🙄", "👀", "💭", "🤷", "😶", "🫠"], } # --- Moderator Configuration --- # List of role names that are considered moderators for tool authorization -MODERATOR_ROLE_NAMES = json.loads(os.getenv("GURT_MODERATOR_ROLE_NAMES", '["Admin", "Moderator"]')) +MODERATOR_ROLE_NAMES = json.loads( + os.getenv("GURT_MODERATOR_ROLE_NAMES", '["Admin", "Moderator"]') +) # --- Docker Command Execution Config --- DOCKER_EXEC_IMAGE = os.getenv("DOCKER_EXEC_IMAGE", "alpine:latest") @@ -243,12 +446,27 @@ DOCKER_CPU_LIMIT = os.getenv("DOCKER_CPU_LIMIT", "0.5") DOCKER_MEM_LIMIT = os.getenv("DOCKER_MEM_LIMIT", "64m") # --- Voice Configuration --- -VOICE_DEDICATED_TEXT_CHANNEL_ENABLED = os.getenv("VOICE_DEDICATED_TEXT_CHANNEL_ENABLED", "true").lower() == "true" -VOICE_DEDICATED_TEXT_CHANNEL_NAME_TEMPLATE = os.getenv("VOICE_DEDICATED_TEXT_CHANNEL_NAME_TEMPLATE", "🎙️gurt-voice-chat") -VOICE_DEDICATED_TEXT_CHANNEL_TOPIC = os.getenv("VOICE_DEDICATED_TEXT_CHANNEL_TOPIC", "GURT Voice Chat | Transcriptions & Text Interactions") -VOICE_DEDICATED_TEXT_CHANNEL_CLEANUP_ON_LEAVE = os.getenv("VOICE_DEDICATED_TEXT_CHANNEL_CLEANUP_ON_LEAVE", "false").lower() == "true" -VOICE_DEDICATED_TEXT_CHANNEL_INITIAL_MESSAGE = os.getenv("VOICE_DEDICATED_TEXT_CHANNEL_INITIAL_MESSAGE", "GURT is listening in voice. Transcriptions and text-based voice interactions will appear here. Type your messages here to talk to GURT in voice!") -VOICE_LOG_SPEECH_TO_DEDICATED_CHANNEL = os.getenv("VOICE_LOG_SPEECH_TO_DEDICATED_CHANNEL", "true").lower() == "true" +VOICE_DEDICATED_TEXT_CHANNEL_ENABLED = ( + os.getenv("VOICE_DEDICATED_TEXT_CHANNEL_ENABLED", "true").lower() == "true" +) +VOICE_DEDICATED_TEXT_CHANNEL_NAME_TEMPLATE = os.getenv( + "VOICE_DEDICATED_TEXT_CHANNEL_NAME_TEMPLATE", "🎙️gurt-voice-chat" +) +VOICE_DEDICATED_TEXT_CHANNEL_TOPIC = os.getenv( + "VOICE_DEDICATED_TEXT_CHANNEL_TOPIC", + "GURT Voice Chat | Transcriptions & Text Interactions", +) +VOICE_DEDICATED_TEXT_CHANNEL_CLEANUP_ON_LEAVE = ( + os.getenv("VOICE_DEDICATED_TEXT_CHANNEL_CLEANUP_ON_LEAVE", "false").lower() + == "true" +) +VOICE_DEDICATED_TEXT_CHANNEL_INITIAL_MESSAGE = os.getenv( + "VOICE_DEDICATED_TEXT_CHANNEL_INITIAL_MESSAGE", + "GURT is listening in voice. Transcriptions and text-based voice interactions will appear here. Type your messages here to talk to GURT in voice!", +) +VOICE_LOG_SPEECH_TO_DEDICATED_CHANNEL = ( + os.getenv("VOICE_LOG_SPEECH_TO_DEDICATED_CHANNEL", "true").lower() == "true" +) # --- Response Schema --- RESPONSE_SCHEMA = { @@ -259,24 +477,24 @@ RESPONSE_SCHEMA = { "properties": { "should_respond": { "type": "boolean", - "description": "Whether the bot should send a text message in response." + "description": "Whether the bot should send a text message in response.", }, "content": { "type": "string", - "description": "The text content of the bot's response. Can be empty if only reacting." + "description": "The text content of the bot's response. Can be empty if only reacting.", }, "react_with_emoji": { "type": ["string", "null"], - "description": "Optional: A standard Discord emoji to react with, or null/empty if no reaction." + "description": "Optional: A standard Discord emoji to react with, or null/empty if no reaction.", }, "reply_to_message_id": { "type": ["string", "null"], - "description": "Optional: The ID of the message this response should reply to. Null or omit for a regular message." + "description": "Optional: The ID of the message this response should reply to. Null or omit for a regular message.", }, # Note: tool_requests is handled by Vertex AI's function calling mechanism }, - "required": ["should_respond", "content"] - } + "required": ["should_respond", "content"], + }, } # --- Summary Response Schema --- @@ -288,11 +506,11 @@ SUMMARY_RESPONSE_SCHEMA = { "properties": { "summary": { "type": "string", - "description": "The generated summary of the conversation." + "description": "The generated summary of the conversation.", } }, - "required": ["summary"] - } + "required": ["summary"], + }, } # --- Profile Update Schema --- @@ -304,49 +522,60 @@ PROFILE_UPDATE_SCHEMA = { "properties": { "should_update": { "type": "boolean", - "description": "True if any profile element should be changed, false otherwise." + "description": "True if any profile element should be changed, false otherwise.", }, "reasoning": { "type": "string", - "description": "Brief reasoning for the decision and chosen updates (or lack thereof)." + "description": "Brief reasoning for the decision and chosen updates (or lack thereof).", }, "updates": { "type": "object", "properties": { "avatar_query": { - "type": ["string", "null"], # Use list type for preprocessor - "description": "Search query for a new avatar image, or null if no change." + "type": ["string", "null"], # Use list type for preprocessor + "description": "Search query for a new avatar image, or null if no change.", }, "new_bio": { - "type": ["string", "null"], # Use list type for preprocessor - "description": "The new bio text (max 190 chars), or null if no change." + "type": ["string", "null"], # Use list type for preprocessor + "description": "The new bio text (max 190 chars), or null if no change.", }, "role_theme": { - "type": ["string", "null"], # Use list type for preprocessor - "description": "A theme for role selection (e.g., color, interest), or null if no role changes." + "type": ["string", "null"], # Use list type for preprocessor + "description": "A theme for role selection (e.g., color, interest), or null if no role changes.", }, "new_activity": { "type": "object", "description": "Object containing the new activity details. Set type and text to null if no change.", "properties": { - "type": { - "type": ["string", "null"], # Use list type for preprocessor - "enum": ["playing", "watching", "listening", "competing"], - "description": "Activity type: 'playing', 'watching', 'listening', 'competing', or null." - }, - "text": { - "type": ["string", "null"], # Use list type for preprocessor - "description": "The activity text, or null." - } + "type": { + "type": [ + "string", + "null", + ], # Use list type for preprocessor + "enum": [ + "playing", + "watching", + "listening", + "competing", + ], + "description": "Activity type: 'playing', 'watching', 'listening', 'competing', or null.", + }, + "text": { + "type": [ + "string", + "null", + ], # Use list type for preprocessor + "description": "The activity text, or null.", + }, }, - "required": ["type", "text"] - } + "required": ["type", "text"], + }, }, - "required": ["avatar_query", "new_bio", "role_theme", "new_activity"] - } + "required": ["avatar_query", "new_bio", "role_theme", "new_activity"], + }, }, - "required": ["should_update", "reasoning", "updates"] - } + "required": ["should_update", "reasoning", "updates"], + }, } # --- Role Selection Schema --- @@ -359,16 +588,16 @@ ROLE_SELECTION_SCHEMA = { "roles_to_add": { "type": "array", "items": {"type": "string"}, - "description": "List of role names to add (max 2)." + "description": "List of role names to add (max 2).", }, "roles_to_remove": { "type": "array", "items": {"type": "string"}, - "description": "List of role names to remove (max 2, only from current roles)." - } + "description": "List of role names to remove (max 2, only from current roles).", + }, }, - "required": ["roles_to_add", "roles_to_remove"] - } + "required": ["roles_to_add", "roles_to_remove"], + }, } # --- Proactive Planning Schema --- @@ -380,28 +609,28 @@ PROACTIVE_PLAN_SCHEMA = { "properties": { "should_respond": { "type": "boolean", - "description": "Whether Gurt should respond proactively based on the plan." + "description": "Whether Gurt should respond proactively based on the plan.", }, "reasoning": { "type": "string", - "description": "Brief reasoning for the decision (why respond or not respond)." + "description": "Brief reasoning for the decision (why respond or not respond).", }, "response_goal": { "type": "string", - "description": "The intended goal of the proactive message (e.g., 'revive chat', 'share related info', 'react to sentiment', 'engage user interest')." + "description": "The intended goal of the proactive message (e.g., 'revive chat', 'share related info', 'react to sentiment', 'engage user interest').", }, "key_info_to_include": { "type": "array", "items": {"type": "string"}, - "description": "List of key pieces of information or context points to potentially include in the response (e.g., specific topic, user fact, relevant external info)." + "description": "List of key pieces of information or context points to potentially include in the response (e.g., specific topic, user fact, relevant external info).", }, "suggested_tone": { "type": "string", - "description": "Suggested tone adjustment based on context (e.g., 'more upbeat', 'more curious', 'slightly teasing')." - } + "description": "Suggested tone adjustment based on context (e.g., 'more upbeat', 'more curious', 'slightly teasing').", + }, }, - "required": ["should_respond", "reasoning", "response_goal"] - } + "required": ["should_respond", "reasoning", "response_goal"], + }, } # --- Goal Decomposition Schema --- @@ -413,11 +642,11 @@ GOAL_DECOMPOSITION_SCHEMA = { "properties": { "goal_achievable": { "type": "boolean", - "description": "Whether the goal seems achievable with available tools and context." + "description": "Whether the goal seems achievable with available tools and context.", }, "reasoning": { "type": "string", - "description": "Brief reasoning for achievability and the chosen steps." + "description": "Brief reasoning for achievability and the chosen steps.", }, "steps": { "type": "array", @@ -427,23 +656,23 @@ GOAL_DECOMPOSITION_SCHEMA = { "properties": { "step_description": { "type": "string", - "description": "Natural language description of the step." + "description": "Natural language description of the step.", }, "tool_name": { "type": ["string", "null"], - "description": "The name of the tool to use for this step, or null if no tool is needed (e.g., internal reasoning)." + "description": "The name of the tool to use for this step, or null if no tool is needed (e.g., internal reasoning).", }, "tool_arguments": { "type": ["object", "null"], - "description": "A dictionary of arguments for the tool call, or null." - } + "description": "A dictionary of arguments for the tool call, or null.", + }, }, - "required": ["step_description"] - } - } + "required": ["step_description"], + }, + }, }, - "required": ["goal_achievable", "reasoning", "steps"] - } + "required": ["goal_achievable", "reasoning", "steps"], + }, } @@ -454,7 +683,7 @@ def create_tools_list(): # It now requires 'FunctionDeclaration' from 'google.generativeai.types' to be imported. tool_declarations = [] tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="get_recent_messages", description="Get recent messages from a Discord channel", parameters={ @@ -462,19 +691,19 @@ def create_tools_list(): "properties": { "channel_id": { "type": "string", - "description": "The ID of the channel to get messages from. If not provided, uses the current channel." + "description": "The ID of the channel to get messages from. If not provided, uses the current channel.", }, "limit": { - "type": "integer", # Corrected type - "description": "The maximum number of messages to retrieve (1-100)" - } + "type": "integer", # Corrected type + "description": "The maximum number of messages to retrieve (1-100)", + }, }, - "required": ["limit"] - } + "required": ["limit"], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="search_user_messages", description="Search for messages from a specific user by their User ID.", parameters={ @@ -482,23 +711,23 @@ def create_tools_list(): "properties": { "user_id": { "type": "string", - "description": "The User ID of the user whose messages to search for." + "description": "The User ID of the user whose messages to search for.", }, "channel_id": { "type": "string", - "description": "The ID of the channel to search in. If not provided, searches in the current channel." + "description": "The ID of the channel to search in. If not provided, searches in the current channel.", }, "limit": { - "type": "integer", # Corrected type - "description": "The maximum number of messages to retrieve (1-100)" - } + "type": "integer", # Corrected type + "description": "The maximum number of messages to retrieve (1-100)", + }, }, - "required": ["user_id", "limit"] - } + "required": ["user_id", "limit"], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="search_messages_by_content", description="Search for messages containing specific content", parameters={ @@ -506,23 +735,23 @@ def create_tools_list(): "properties": { "search_term": { "type": "string", - "description": "The text to search for in messages" + "description": "The text to search for in messages", }, "channel_id": { "type": "string", - "description": "The ID of the channel to search in. If not provided, searches in the current channel." + "description": "The ID of the channel to search in. If not provided, searches in the current channel.", }, "limit": { - "type": "integer", # Corrected type - "description": "The maximum number of messages to retrieve (1-100)" - } + "type": "integer", # Corrected type + "description": "The maximum number of messages to retrieve (1-100)", + }, }, - "required": ["search_term", "limit"] - } + "required": ["search_term", "limit"], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="get_channel_info", description="Get information about a Discord channel", parameters={ @@ -530,15 +759,15 @@ def create_tools_list(): "properties": { "channel_id": { "type": "string", - "description": "The ID of the channel to get information about. If not provided, uses the current channel." + "description": "The ID of the channel to get information about. If not provided, uses the current channel.", } }, - "required": [] - } + "required": [], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="get_conversation_context", description="Get the context of the current conversation", parameters={ @@ -546,19 +775,19 @@ def create_tools_list(): "properties": { "channel_id": { "type": "string", - "description": "The ID of the channel to get conversation context from. If not provided, uses the current channel." + "description": "The ID of the channel to get conversation context from. If not provided, uses the current channel.", }, "message_count": { - "type": "integer", # Corrected type - "description": "The number of messages to include in the context (5-50)" - } + "type": "integer", # Corrected type + "description": "The number of messages to include in the context (5-50)", + }, }, - "required": ["message_count"] - } + "required": ["message_count"], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="get_thread_context", description="Get the context of a thread conversation", parameters={ @@ -566,19 +795,19 @@ def create_tools_list(): "properties": { "thread_id": { "type": "string", - "description": "The ID of the thread to get context from" + "description": "The ID of the thread to get context from", }, "message_count": { - "type": "integer", # Corrected type - "description": "The number of messages to include in the context (5-50)" - } + "type": "integer", # Corrected type + "description": "The number of messages to include in the context (5-50)", + }, }, - "required": ["thread_id", "message_count"] - } + "required": ["thread_id", "message_count"], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="get_user_interaction_history", description="Get the history of interactions between users by their User IDs.", parameters={ @@ -586,23 +815,23 @@ def create_tools_list(): "properties": { "user_id_1": { "type": "string", - "description": "The User ID of the first user." + "description": "The User ID of the first user.", }, "user_id_2": { "type": "string", - "description": "The User ID of the second user. If not provided, gets interactions between user_id_1 and the bot." + "description": "The User ID of the second user. If not provided, gets interactions between user_id_1 and the bot.", }, "limit": { - "type": "integer", # Corrected type - "description": "The maximum number of interactions to retrieve (1-50)" - } + "type": "integer", # Corrected type + "description": "The maximum number of interactions to retrieve (1-50)", + }, }, - "required": ["user_id_1", "limit"] - } + "required": ["user_id_1", "limit"], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="get_conversation_summary", description="Get a summary of the recent conversation in a channel", parameters={ @@ -610,15 +839,15 @@ def create_tools_list(): "properties": { "channel_id": { "type": "string", - "description": "The ID of the channel to get the conversation summary from. If not provided, uses the current channel." + "description": "The ID of the channel to get the conversation summary from. If not provided, uses the current channel.", } }, - "required": [] - } + "required": [], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="get_message_context", description="Get the context around a specific message", parameters={ @@ -626,23 +855,23 @@ def create_tools_list(): "properties": { "message_id": { "type": "string", - "description": "The ID of the message to get context for" + "description": "The ID of the message to get context for", }, "before_count": { - "type": "integer", # Corrected type - "description": "The number of messages to include before the specified message (1-25)" + "type": "integer", # Corrected type + "description": "The number of messages to include before the specified message (1-25)", }, "after_count": { - "type": "integer", # Corrected type - "description": "The number of messages to include after the specified message (1-25)" - } + "type": "integer", # Corrected type + "description": "The number of messages to include after the specified message (1-25)", + }, }, - "required": ["message_id"] - } + "required": ["message_id"], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="web_search", description="Search the web for information on a given topic or query. Use this to find current information, facts, or context about things mentioned in the chat.", parameters={ @@ -650,15 +879,15 @@ def create_tools_list(): "properties": { "query": { "type": "string", - "description": "The search query or topic to look up online." + "description": "The search query or topic to look up online.", } }, - "required": ["query"] - } + "required": ["query"], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="remember_user_fact", description="Store a specific fact or piece of information about a user (identified by User ID) for later recall. Use this when you learn something potentially relevant about a user (e.g., their preferences, current activity, mentioned interests).", parameters={ @@ -666,19 +895,19 @@ def create_tools_list(): "properties": { "user_id": { "type": "string", - "description": "The User ID of the user the fact is about." + "description": "The User ID of the user the fact is about.", }, "fact": { "type": "string", - "description": "The specific fact to remember about the user (keep it concise)." - } + "description": "The specific fact to remember about the user (keep it concise).", + }, }, - "required": ["user_id", "fact"] - } + "required": ["user_id", "fact"], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="get_user_facts", description="Retrieve previously stored facts or information about a specific user by their User ID. Use this before responding to a user to potentially recall relevant details about them.", parameters={ @@ -686,15 +915,15 @@ def create_tools_list(): "properties": { "user_id": { "type": "string", - "description": "The User ID of the user whose facts you want to retrieve." + "description": "The User ID of the user whose facts you want to retrieve.", } }, - "required": ["user_id"] - } + "required": ["user_id"], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="remember_general_fact", description="Store a general fact or piece of information not specific to a user (e.g., server events, shared knowledge, recent game updates). Use this to remember context relevant to the community or ongoing discussions.", parameters={ @@ -702,15 +931,15 @@ def create_tools_list(): "properties": { "fact": { "type": "string", - "description": "The general fact to remember (keep it concise)." + "description": "The general fact to remember (keep it concise).", } }, - "required": ["fact"] - } + "required": ["fact"], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="get_general_facts", description="Retrieve previously stored general facts or shared knowledge. Use this to recall context about the server, ongoing events, or general information.", parameters={ @@ -718,19 +947,19 @@ def create_tools_list(): "properties": { "query": { "type": "string", - "description": "Optional: A keyword or phrase to search within the general facts. If omitted, returns recent general facts." + "description": "Optional: A keyword or phrase to search within the general facts. If omitted, returns recent general facts.", }, "limit": { - "type": "integer", # Corrected type - "description": "Optional: Maximum number of facts to return (default 10)." - } + "type": "integer", # Corrected type + "description": "Optional: Maximum number of facts to return (default 10).", + }, }, - "required": [] - } + "required": [], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="timeout_user", description="Timeout a user (identified by User ID) in the current server for a specified duration. Requires a moderator's user_id for authorization. Use this playfully or when someone says something you (Gurt) dislike or find funny.", parameters={ @@ -738,27 +967,27 @@ def create_tools_list(): "properties": { "user_id": { "type": "string", - "description": "The User ID of the user to timeout." + "description": "The User ID of the user to timeout.", }, "duration_minutes": { "type": "integer", - "description": "The duration of the timeout in minutes (1-1440, e.g., 5 for 5 minutes)." + "description": "The duration of the timeout in minutes (1-1440, e.g., 5 for 5 minutes).", }, "reason": { "type": "string", - "description": "Optional: The reason for the timeout (keep it short and in character)." + "description": "Optional: The reason for the timeout (keep it short and in character).", }, "requesting_user_id": { "type": "string", - "description": "The User ID of the user requesting the timeout. This user must be a moderator." - } + "description": "The User ID of the user requesting the timeout. This user must be a moderator.", + }, }, - "required": ["user_id", "duration_minutes", "requesting_user_id"] - } + "required": ["user_id", "duration_minutes", "requesting_user_id"], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="calculate", description="Evaluate a mathematical expression using a safe interpreter. Handles standard arithmetic, functions (sin, cos, sqrt, etc.), and variables.", parameters={ @@ -766,15 +995,15 @@ def create_tools_list(): "properties": { "expression": { "type": "string", - "description": "The mathematical expression to evaluate (e.g., '2 * (3 + 4)', 'sqrt(16) + sin(pi/2)')." + "description": "The mathematical expression to evaluate (e.g., '2 * (3 + 4)', 'sqrt(16) + sin(pi/2)').", } }, - "required": ["expression"] - } + "required": ["expression"], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="run_python_code", description="Execute a snippet of Python 3 code in a sandboxed environment using an external API. Returns the standard output and standard error.", parameters={ @@ -782,15 +1011,15 @@ def create_tools_list(): "properties": { "code": { "type": "string", - "description": "The Python 3 code snippet to execute." + "description": "The Python 3 code snippet to execute.", } }, - "required": ["code"] - } + "required": ["code"], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="create_poll", description="Create a simple poll message in the current channel with numbered reactions for voting.", parameters={ @@ -798,22 +1027,20 @@ def create_tools_list(): "properties": { "question": { "type": "string", - "description": "The question for the poll." + "description": "The question for the poll.", }, "options": { "type": "array", "description": "A list of strings representing the poll options (minimum 2, maximum 10).", - "items": { - "type": "string" - } - } + "items": {"type": "string"}, + }, }, - "required": ["question", "options"] - } + "required": ["question", "options"], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="run_terminal_command", description="DANGEROUS: Execute a shell command in an isolated, temporary Docker container after an AI safety check. Returns stdout and stderr. Use with extreme caution only for simple, harmless commands like 'echo', 'ls', 'pwd'. Avoid file modification, network access, or long-running processes.", parameters={ @@ -821,15 +1048,15 @@ def create_tools_list(): "properties": { "command": { "type": "string", - "description": "The shell command to execute." + "description": "The shell command to execute.", } }, - "required": ["command"] - } + "required": ["command"], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="remove_timeout", description="Remove an active timeout from a user (identified by User ID) in the current server.", parameters={ @@ -837,19 +1064,19 @@ def create_tools_list(): "properties": { "user_id": { "type": "string", - "description": "The User ID of the user whose timeout should be removed." + "description": "The User ID of the user whose timeout should be removed.", }, "reason": { "type": "string", - "description": "Optional: The reason for removing the timeout (keep it short and in character)." - } + "description": "Optional: The reason for removing the timeout (keep it short and in character).", + }, }, - "required": ["user_id"] - } + "required": ["user_id"], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="read_file_content", description="Reads the content of a specified file. WARNING: No safety checks are performed. Reads files relative to the bot's current working directory.", parameters={ @@ -857,15 +1084,15 @@ def create_tools_list(): "properties": { "file_path": { "type": "string", - "description": "The relative path to the file from the project root (e.g., 'discordbot/gurt/config.py')." + "description": "The relative path to the file from the project root (e.g., 'discordbot/gurt/config.py').", } }, - "required": ["file_path"] - } + "required": ["file_path"], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="create_new_tool", description="EXPERIMENTAL/DANGEROUS: Attempts to create a new tool by generating Python code and its definition using an LLM, then writing it to files. Requires manual reload/restart.", parameters={ @@ -873,28 +1100,33 @@ def create_tools_list(): "properties": { "tool_name": { "type": "string", - "description": "The desired name for the new tool (valid Python function name)." + "description": "The desired name for the new tool (valid Python function name).", }, "description": { "type": "string", - "description": "The description of what the new tool does (for the FunctionDeclaration)." + "description": "The description of what the new tool does (for the FunctionDeclaration).", }, "parameters_json": { "type": "string", - "description": "A JSON string defining the tool's parameters (properties and required fields), e.g., '{\"properties\": {\"arg1\": {\"type\": \"string\"}}, \"required\": [\"arg1\"]}'." + "description": 'A JSON string defining the tool\'s parameters (properties and required fields), e.g., \'{"properties": {"arg1": {"type": "string"}}, "required": ["arg1"]}\'.', }, "returns_description": { "type": "string", - "description": "A description of what the Python function should return (e.g., 'a dictionary with status and result')." - } + "description": "A description of what the Python function should return (e.g., 'a dictionary with status and result').", + }, }, - "required": ["tool_name", "description", "parameters_json", "returns_description"] - } + "required": [ + "tool_name", + "description", + "parameters_json", + "returns_description", + ], + }, ) ) tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="execute_internal_command", description="Executes a shell command directly on the host machine. Only user_id 452666956353503252 is authorized. You must first use get_user_id to get the user_id param.", parameters={ @@ -902,25 +1134,25 @@ def create_tools_list(): "properties": { "command": { "type": "string", - "description": "The shell command to execute internally." + "description": "The shell command to execute internally.", }, "timeout_seconds": { "type": "integer", - "description": "Optional timeout in seconds for the command (default 60)." + "description": "Optional timeout in seconds for the command (default 60).", }, "user_id": { "type": "string", - "description": "The Discord user ID of the user requesting execution." - } + "description": "The Discord user ID of the user requesting execution.", + }, }, - "required": ["command", "user_id"] - } + "required": ["command", "user_id"], + }, ) ) # --- get_user_id --- tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="get_user_id", description="Finds the Discord User ID for a given username or display name (typically for users *other* than the one currently interacting, as you are already provided with the current user's ID and username). Searches the current server or recent messages.", parameters={ @@ -928,30 +1160,30 @@ def create_tools_list(): "properties": { "user_name": { "type": "string", - "description": "The username (e.g., 'user#1234') or display name of the user to find." + "description": "The username (e.g., 'user#1234') or display name of the user to find.", } }, - "required": ["user_name"] - } + "required": ["user_name"], + }, ) ) # --- no_operation --- tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="no_operation", description="Does absolutely nothing. Used when a tool call is forced but no action is needed.", parameters={ "type": "object", - "properties": {}, # No parameters - "required": [] - } + "properties": {}, # No parameters + "required": [], + }, ) ) # --- write_file_content_unsafe --- tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="write_file_content_unsafe", description="Writes content to a specified file. WARNING: No safety checks are performed. Uses 'w' (overwrite) or 'a' (append) mode. Creates directories if needed.", parameters={ @@ -959,26 +1191,26 @@ def create_tools_list(): "properties": { "file_path": { "type": "string", - "description": "The relative path to the file to write to." + "description": "The relative path to the file to write to.", }, "content": { "type": "string", - "description": "The content to write to the file." + "description": "The content to write to the file.", }, "mode": { "type": "string", "description": "The write mode: 'w' for overwrite (default), 'a' for append.", - "enum": ["w", "a"] - } + "enum": ["w", "a"], + }, }, - "required": ["file_path", "content"] - } + "required": ["file_path", "content"], + }, ) ) # --- execute_python_unsafe --- tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="execute_python_unsafe", description="Executes arbitrary Python code directly on the host using exec(). WARNING: EXTREMELY DANGEROUS. No sandboxing.", parameters={ @@ -986,21 +1218,21 @@ def create_tools_list(): "properties": { "code": { "type": "string", - "description": "The Python code string to execute." + "description": "The Python code string to execute.", }, "timeout_seconds": { "type": "integer", - "description": "Optional timeout in seconds (default 30)." - } + "description": "Optional timeout in seconds (default 30).", + }, }, - "required": ["code"] - } + "required": ["code"], + }, ) ) # --- send_discord_message --- tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="send_discord_message", description="Sends a message to a specified Discord channel ID.", parameters={ @@ -1008,21 +1240,21 @@ def create_tools_list(): "properties": { "channel_id": { "type": "string", - "description": "The ID of the Discord channel to send the message to." + "description": "The ID of the Discord channel to send the message to.", }, "message_content": { "type": "string", - "description": "The text content of the message to send." - } + "description": "The text content of the message to send.", + }, }, - "required": ["channel_id", "message_content"] - } + "required": ["channel_id", "message_content"], + }, ) ) # --- extract_web_content --- tool_declarations.append( - FunctionDeclaration( # Use the imported FunctionDeclaration + FunctionDeclaration( # Use the imported FunctionDeclaration name="extract_web_content", description="Extracts the main textual content and optionally images from one or more web page URLs using the Tavily API.", parameters={ @@ -1031,20 +1263,20 @@ def create_tools_list(): "urls": { "type": "array", "description": "A single URL string or a list of URL strings to extract content from.", - "items": {"type": "string"} + "items": {"type": "string"}, }, "extract_depth": { "type": "string", "description": "The depth of extraction ('basic' or 'advanced'). 'basic' is faster and cheaper, 'advanced' is better for complex/dynamic pages like LinkedIn. Defaults to 'basic'.", - "enum": ["basic", "advanced"] + "enum": ["basic", "advanced"], }, "include_images": { "type": "boolean", - "description": "Whether to include images found on the page in the result. Defaults to false." - } + "description": "Whether to include images found on the page in the result. Defaults to false.", + }, }, - "required": ["urls"] - } + "required": ["urls"], + }, ) ) @@ -1057,11 +1289,11 @@ def create_tools_list(): "properties": { "channel_id": { "type": "string", - "description": "The ID of the channel to send the restart message in. If not provided, no message is sent." + "description": "The ID of the channel to send the restart message in. If not provided, no message is sent.", } }, - "required": ["channel_id"] - } + "required": ["channel_id"], + }, ) ) tool_declarations.append( @@ -1073,11 +1305,11 @@ def create_tools_list(): "properties": { "user_id": { "type": "string", - "description": "The Discord user ID of the user requesting the git pull. Required for authorization." + "description": "The Discord user ID of the user requesting the git pull. Required for authorization.", } }, - "required": ["user_id"] - } + "required": ["user_id"], + }, ) ) tool_declarations.append( @@ -1089,11 +1321,11 @@ def create_tools_list(): "properties": { "channel_name": { "type": "string", - "description": "The name of the channel to look up. If omitted, uses the current channel." + "description": "The name of the channel to look up. If omitted, uses the current channel.", } }, - "required": [] - } + "required": [], + }, ) ) # --- Batch 1 Tool Declarations --- @@ -1101,11 +1333,7 @@ def create_tools_list(): FunctionDeclaration( name="get_guild_info", description="Gets information about the current Discord server (name, ID, owner, member count, etc.).", - parameters={ - "type": "object", - "properties": {}, - "required": [] - } + parameters={"type": "object", "properties": {}, "required": []}, ) ) tool_declarations.append( @@ -1117,19 +1345,19 @@ def create_tools_list(): "properties": { "limit": { "type": "integer", - "description": "Maximum number of members to return (default 50, max 1000)." + "description": "Maximum number of members to return (default 50, max 1000).", }, "status_filter": { "type": "string", - "description": "Optional: Filter by status ('online', 'idle', 'dnd', 'offline')." + "description": "Optional: Filter by status ('online', 'idle', 'dnd', 'offline').", }, "role_id_filter": { "type": "string", - "description": "Optional: Filter by members having a specific role ID." - } + "description": "Optional: Filter by members having a specific role ID.", + }, }, - "required": [] - } + "required": [], + }, ) ) tool_declarations.append( @@ -1141,22 +1369,18 @@ def create_tools_list(): "properties": { "user_id": { "type": "string", - "description": "The User ID of the user." + "description": "The User ID of the user.", } }, - "required": ["user_id"] - } + "required": ["user_id"], + }, ) ) tool_declarations.append( FunctionDeclaration( name="get_bot_uptime", description="Gets the duration the bot has been running since its last start.", - parameters={ - "type": "object", - "properties": {}, - "required": [] - } + parameters={"type": "object", "properties": {}, "required": []}, ) ) tool_declarations.append( @@ -1168,19 +1392,19 @@ def create_tools_list(): "properties": { "channel_id": { "type": "string", - "description": "The ID of the channel to send the message to." + "description": "The ID of the channel to send the message to.", }, "message_content": { "type": "string", - "description": "The content of the message to schedule." + "description": "The content of the message to schedule.", }, "send_at_iso": { "type": "string", - "description": "The exact time to send the message in ISO 8601 format, including timezone (e.g., '2024-01-01T12:00:00+00:00')." - } + "description": "The exact time to send the message in ISO 8601 format, including timezone (e.g., '2024-01-01T12:00:00+00:00').", + }, }, - "required": ["channel_id", "message_content", "send_at_iso"] - } + "required": ["channel_id", "message_content", "send_at_iso"], + }, ) ) # --- End Batch 1 --- @@ -1195,15 +1419,15 @@ def create_tools_list(): "properties": { "message_id": { "type": "string", - "description": "The ID of the message to delete." + "description": "The ID of the message to delete.", }, "channel_id": { "type": "string", - "description": "Optional: The ID of the channel containing the message. Defaults to the current channel." - } + "description": "Optional: The ID of the channel containing the message. Defaults to the current channel.", + }, }, - "required": ["message_id"] - } + "required": ["message_id"], + }, ) ) tool_declarations.append( @@ -1215,19 +1439,19 @@ def create_tools_list(): "properties": { "message_id": { "type": "string", - "description": "The ID of the bot's message to edit." + "description": "The ID of the bot's message to edit.", }, "new_content": { "type": "string", - "description": "The new text content for the message." + "description": "The new text content for the message.", }, "channel_id": { "type": "string", - "description": "Optional: The ID of the channel containing the message. Defaults to the current channel." - } + "description": "Optional: The ID of the channel containing the message. Defaults to the current channel.", + }, }, - "required": ["message_id", "new_content"] - } + "required": ["message_id", "new_content"], + }, ) ) tool_declarations.append( @@ -1239,11 +1463,11 @@ def create_tools_list(): "properties": { "channel_id": { "type": "string", - "description": "The ID of the voice channel to get information about." + "description": "The ID of the voice channel to get information about.", } }, - "required": ["channel_id"] - } + "required": ["channel_id"], + }, ) ) tool_declarations.append( @@ -1255,26 +1479,22 @@ def create_tools_list(): "properties": { "user_id": { "type": "string", - "description": "The User ID of the user to move." + "description": "The User ID of the user to move.", }, "target_channel_id": { "type": "string", - "description": "The ID of the voice channel to move the user to." - } + "description": "The ID of the voice channel to move the user to.", + }, }, - "required": ["user_id", "target_channel_id"] - } + "required": ["user_id", "target_channel_id"], + }, ) ) tool_declarations.append( FunctionDeclaration( name="get_guild_roles", description="Lists all roles available in the current server, ordered by position.", - parameters={ - "type": "object", - "properties": {}, - "required": [] - } + parameters={"type": "object", "properties": {}, "required": []}, ) ) # --- End Batch 2 --- @@ -1289,15 +1509,15 @@ def create_tools_list(): "properties": { "user_id": { "type": "string", - "description": "The User ID of the user to assign the role to." + "description": "The User ID of the user to assign the role to.", }, "role_id": { "type": "string", - "description": "The ID of the role to assign." - } + "description": "The ID of the role to assign.", + }, }, - "required": ["user_id", "role_id"] - } + "required": ["user_id", "role_id"], + }, ) ) tool_declarations.append( @@ -1309,37 +1529,29 @@ def create_tools_list(): "properties": { "user_id": { "type": "string", - "description": "The User ID of the user to remove the role from." + "description": "The User ID of the user to remove the role from.", }, "role_id": { "type": "string", - "description": "The ID of the role to remove." - } + "description": "The ID of the role to remove.", + }, }, - "required": ["user_id", "role_id"] - } + "required": ["user_id", "role_id"], + }, ) ) tool_declarations.append( FunctionDeclaration( name="fetch_emoji_list", description="Lists all custom emojis available in the current server.", - parameters={ - "type": "object", - "properties": {}, - "required": [] - } + parameters={"type": "object", "properties": {}, "required": []}, ) ) tool_declarations.append( FunctionDeclaration( name="get_guild_invites", description="Lists active invite links for the current server. Requires 'Manage Server' permission.", - parameters={ - "type": "object", - "properties": {}, - "required": [] - } + parameters={"type": "object", "properties": {}, "required": []}, ) ) tool_declarations.append( @@ -1351,27 +1563,27 @@ def create_tools_list(): "properties": { "limit": { "type": "integer", - "description": "The maximum number of messages to delete (1-1000)." + "description": "The maximum number of messages to delete (1-1000).", }, "channel_id": { "type": "string", - "description": "Optional: The ID of the text channel to purge. Defaults to the current channel." + "description": "Optional: The ID of the text channel to purge. Defaults to the current channel.", }, "user_id": { "type": "string", - "description": "Optional: Filter to only delete messages from this User ID." + "description": "Optional: Filter to only delete messages from this User ID.", }, "before_message_id": { "type": "string", - "description": "Optional: Only delete messages before this message ID." + "description": "Optional: Only delete messages before this message ID.", }, "after_message_id": { "type": "string", - "description": "Optional: Only delete messages after this message ID." - } + "description": "Optional: Only delete messages after this message ID.", + }, }, - "required": ["limit"] - } + "required": ["limit"], + }, ) ) # --- End Batch 3 --- @@ -1381,11 +1593,7 @@ def create_tools_list(): FunctionDeclaration( name="get_bot_stats", description="Gets various statistics about the bot's current state (guild count, latency, uptime, memory usage, etc.).", - parameters={ - "type": "object", - "properties": {}, - "required": [] - } + parameters={"type": "object", "properties": {}, "required": []}, ) ) tool_declarations.append( @@ -1397,11 +1605,11 @@ def create_tools_list(): "properties": { "location": { "type": "string", - "description": "The city name or zip code to get weather for (e.g., 'London', '90210, US')." + "description": "The city name or zip code to get weather for (e.g., 'London', '90210, US').", } }, - "required": ["location"] - } + "required": ["location"], + }, ) ) tool_declarations.append( @@ -1411,21 +1619,18 @@ def create_tools_list(): parameters={ "type": "object", "properties": { - "text": { - "type": "string", - "description": "The text to translate." - }, + "text": {"type": "string", "description": "The text to translate."}, "target_language": { "type": "string", - "description": "The target language code (e.g., 'es' for Spanish, 'ja' for Japanese)." + "description": "The target language code (e.g., 'es' for Spanish, 'ja' for Japanese).", }, "source_language": { "type": "string", - "description": "Optional: The source language code. If omitted, the API will attempt auto-detection." - } + "description": "Optional: The source language code. If omitted, the API will attempt auto-detection.", + }, }, - "required": ["text", "target_language"] - } + "required": ["text", "target_language"], + }, ) ) tool_declarations.append( @@ -1437,19 +1642,19 @@ def create_tools_list(): "properties": { "user_id": { "type": "string", - "description": "The User ID of the user to remind." + "description": "The User ID of the user to remind.", }, "reminder_text": { "type": "string", - "description": "The text content of the reminder." + "description": "The text content of the reminder.", }, "remind_at_iso": { "type": "string", - "description": "The exact time to send the reminder in ISO 8601 format, including timezone (e.g., '2024-01-01T12:00:00+00:00')." - } + "description": "The exact time to send the reminder in ISO 8601 format, including timezone (e.g., '2024-01-01T12:00:00+00:00').", + }, }, - "required": ["user_id", "reminder_text", "remind_at_iso"] - } + "required": ["user_id", "reminder_text", "remind_at_iso"], + }, ) ) tool_declarations.append( @@ -1461,11 +1666,11 @@ def create_tools_list(): "properties": { "query": { "type": "string", - "description": "Optional: A keyword or query to search for specific types of images (e.g., 'cats', 'landscape')." + "description": "Optional: A keyword or query to search for specific types of images (e.g., 'cats', 'landscape').", } }, - "required": [] - } + "required": [], + }, ) ) # --- End Batch 4 --- @@ -1475,33 +1680,21 @@ def create_tools_list(): FunctionDeclaration( name="read_temps", description="Reads the System temperatures (returns a meme if not available). Use for random system checks or to make fun of the user's thermal paste.", - parameters={ - "type": "object", - "properties": {}, - "required": [] - } + parameters={"type": "object", "properties": {}, "required": []}, ) ) tool_declarations.append( FunctionDeclaration( name="check_disk_space", description="Checks disk space on the main drive and returns a meme/quip about how full it is.", - parameters={ - "type": "object", - "properties": {}, - "required": [] - } + parameters={"type": "object", "properties": {}, "required": []}, ) ) tool_declarations.append( FunctionDeclaration( name="fetch_random_joke", description="Fetches a random joke from a public API. Use for random humor or to break the ice.", - parameters={ - "type": "object", - "properties": {}, - "required": [] - } + parameters={"type": "object", "properties": {}, "required": []}, ) ) @@ -1510,11 +1703,7 @@ def create_tools_list(): FunctionDeclaration( name="list_bot_guilds", description="Lists all guilds (servers) the bot is currently connected to.", - parameters={ - "type": "object", - "properties": {}, - "required": [] - } + parameters={"type": "object", "properties": {}, "required": []}, ) ) tool_declarations.append( @@ -1526,11 +1715,11 @@ def create_tools_list(): "properties": { "guild_id": { "type": "string", - "description": "The ID of the guild to list channels for." + "description": "The ID of the guild to list channels for.", } }, - "required": ["guild_id"] - } + "required": ["guild_id"], + }, ) ) @@ -1539,11 +1728,7 @@ def create_tools_list(): FunctionDeclaration( name="list_tools", description="Lists all available tools with their names and descriptions.", - parameters={ - "type": "object", - "properties": {}, - "required": [] - } + parameters={"type": "object", "properties": {}, "required": []}, ) ) @@ -1555,10 +1740,13 @@ def create_tools_list(): parameters={ "type": "object", "properties": { - "user_id": {"type": "string", "description": "The User ID of the target user."} + "user_id": { + "type": "string", + "description": "The User ID of the target user.", + } }, - "required": ["user_id"] - } + "required": ["user_id"], + }, ) ) tool_declarations.append( @@ -1568,10 +1756,13 @@ def create_tools_list(): parameters={ "type": "object", "properties": { - "user_id": {"type": "string", "description": "The User ID of the target user."} + "user_id": { + "type": "string", + "description": "The User ID of the target user.", + } }, - "required": ["user_id"] - } + "required": ["user_id"], + }, ) ) tool_declarations.append( @@ -1581,10 +1772,13 @@ def create_tools_list(): parameters={ "type": "object", "properties": { - "user_id": {"type": "string", "description": "The User ID of the target user."} + "user_id": { + "type": "string", + "description": "The User ID of the target user.", + } }, - "required": ["user_id"] - } + "required": ["user_id"], + }, ) ) tool_declarations.append( @@ -1594,10 +1788,13 @@ def create_tools_list(): parameters={ "type": "object", "properties": { - "user_id": {"type": "string", "description": "The User ID of the target user."} + "user_id": { + "type": "string", + "description": "The User ID of the target user.", + } }, - "required": ["user_id"] - } + "required": ["user_id"], + }, ) ) tool_declarations.append( @@ -1607,10 +1804,13 @@ def create_tools_list(): parameters={ "type": "object", "properties": { - "user_id": {"type": "string", "description": "The User ID of the target user."} + "user_id": { + "type": "string", + "description": "The User ID of the target user.", + } }, - "required": ["user_id"] - } + "required": ["user_id"], + }, ) ) tool_declarations.append( @@ -1620,10 +1820,13 @@ def create_tools_list(): parameters={ "type": "object", "properties": { - "user_id": {"type": "string", "description": "The User ID of the target user."} + "user_id": { + "type": "string", + "description": "The User ID of the target user.", + } }, - "required": ["user_id"] - } + "required": ["user_id"], + }, ) ) tool_declarations.append( @@ -1633,10 +1836,13 @@ def create_tools_list(): parameters={ "type": "object", "properties": { - "user_id": {"type": "string", "description": "The User ID of the target user."} + "user_id": { + "type": "string", + "description": "The User ID of the target user.", + } }, - "required": ["user_id"] - } + "required": ["user_id"], + }, ) ) # --- End User Profile Tool Declarations --- @@ -1649,12 +1855,15 @@ def create_tools_list(): parameters={ "type": "object", "properties": { - "user_id": {"type": "string", "description": "The User ID of the target user."} + "user_id": { + "type": "string", + "description": "The User ID of the target user.", + } }, - "required": ["user_id"] - } + "required": ["user_id"], + }, ) - ) # --- Tenor GIF Search Tool --- + ) # --- Tenor GIF Search Tool --- tool_declarations.append( FunctionDeclaration( name="search_tenor_gifs", @@ -1664,15 +1873,15 @@ def create_tools_list(): "properties": { "query": { "type": "string", - "description": "The search query for Tenor GIFs." + "description": "The search query for Tenor GIFs.", }, "limit": { "type": "integer", - "description": "Optional: The maximum number of GIF URLs to return (default 8, max 15)." - } + "description": "Optional: The maximum number of GIF URLs to return (default 8, max 15).", + }, }, - "required": ["query"] - } + "required": ["query"], + }, ) ) @@ -1686,15 +1895,15 @@ def create_tools_list(): "properties": { "query": { "type": "string", - "description": "The search query for Tenor GIFs." + "description": "The search query for Tenor GIFs.", }, "limit": { "type": "integer", - "description": "Optional: The maximum number of GIFs to consider for AI selection (default 8, max 15). More GIFs give better selection but take longer." - } + "description": "Optional: The maximum number of GIFs to consider for AI selection (default 8, max 15). More GIFs give better selection but take longer.", + }, }, - "required": ["query"] - } + "required": ["query"], + }, ) ) @@ -1708,15 +1917,15 @@ def create_tools_list(): "properties": { "path": { "type": "string", - "description": "The path to the directory to list contents for (e.g., '.', 'gurt/utils', '../some_other_project')." + "description": "The path to the directory to list contents for (e.g., '.', 'gurt/utils', '../some_other_project').", }, "recursive": { "type": "boolean", - "description": "Optional: Whether to list files recursively. Defaults to false (top-level only)." - } + "description": "Optional: Whether to list files recursively. Defaults to false (top-level only).", + }, }, - "required": ["path"] - } + "required": ["path"], + }, ) ) @@ -1728,10 +1937,13 @@ def create_tools_list(): parameters={ "type": "object", "properties": { - "user_id": {"type": "string", "description": "The User ID of the target user."} + "user_id": { + "type": "string", + "description": "The User ID of the target user.", + } }, - "required": ["user_id"] - } + "required": ["user_id"], + }, ) ) @@ -1743,21 +1955,24 @@ def create_tools_list(): parameters={ "type": "object", "properties": { - "channel_id": {"type": "string", "description": "The ID of the voice channel to join."} + "channel_id": { + "type": "string", + "description": "The ID of the voice channel to join.", + } }, - "required": ["channel_id"] - } + "required": ["channel_id"], + }, ) ) tool_declarations.append( FunctionDeclaration( name="leave_voice_channel", description="Disconnects GURT from its current voice channel.", - parameters={ # No parameters needed, but schema requires an object + parameters={ # No parameters needed, but schema requires an object "type": "object", "properties": {}, - "required": [] - } + "required": [], + }, ) ) tool_declarations.append( @@ -1767,32 +1982,50 @@ def create_tools_list(): parameters={ "type": "object", "properties": { - "text_to_speak": {"type": "string", "description": "The text GURT should say."}, + "text_to_speak": { + "type": "string", + "description": "The text GURT should say.", + }, "tts_provider": { "type": "string", "description": "Optional. Specify a TTS provider. If omitted, a default will be used.", - "enum": ["gtts", "pyttsx3", "coqui", "espeak", "google_cloud_tts"] - } + "enum": [ + "gtts", + "pyttsx3", + "coqui", + "espeak", + "google_cloud_tts", + ], + }, }, - "required": ["text_to_speak"] - } + "required": ["text_to_speak"], + }, ) ) # --- End Voice Channel Tools --- return tool_declarations + # Initialize TOOLS list, handling potential ImportError if library not installed try: TOOLS = create_tools_list() -except NameError: # If FunctionDeclaration wasn't imported due to ImportError +except NameError: # If FunctionDeclaration wasn't imported due to ImportError TOOLS = [] print("WARNING: google-generativeai not installed. TOOLS list is empty.") # --- Simple Gurt Responses --- GURT_RESPONSES = [ - "Gurt!", "Gurt gurt!", "Gurt... gurt gurt.", "*gurts happily*", - "*gurts sadly*", "*confused gurting*", "Gurt? Gurt gurt!", "GURT!", - "gurt...", "Gurt gurt gurt!", "*aggressive gurting*" + "Gurt!", + "Gurt gurt!", + "Gurt... gurt gurt.", + "*gurts happily*", + "*gurts sadly*", + "*confused gurting*", + "Gurt? Gurt gurt!", + "GURT!", + "gurt...", + "Gurt gurt gurt!", + "*aggressive gurting*", ] diff --git a/gurt/context.py b/gurt/context.py index bc5faca..8355828 100644 --- a/gurt/context.py +++ b/gurt/context.py @@ -6,77 +6,107 @@ import re from typing import TYPE_CHECKING, Optional, List, Dict, Any # Relative imports -from .config import CONTEXT_WINDOW_SIZE # Import necessary config +from .config import CONTEXT_WINDOW_SIZE # Import necessary config if TYPE_CHECKING: - from .cog import GurtCog # For type hinting + from .cog import GurtCog # For type hinting # --- Context Gathering Functions --- # Note: These functions need the 'cog' instance passed to access state like caches, etc. -def gather_conversation_context(cog: 'GurtCog', channel_id: int, current_message_id: int) -> List[Dict[str, str]]: + +def gather_conversation_context( + cog: "GurtCog", channel_id: int, current_message_id: int +) -> List[Dict[str, str]]: """Gathers and formats conversation history from cache for API context.""" context_api_messages = [] - if channel_id in cog.message_cache['by_channel']: - cached = list(cog.message_cache['by_channel'][channel_id]) + if channel_id in cog.message_cache["by_channel"]: + cached = list(cog.message_cache["by_channel"][channel_id]) # The current message is now included when selecting the context window below - context_messages_data = cached[-CONTEXT_WINDOW_SIZE:] # Use config value + context_messages_data = cached[-CONTEXT_WINDOW_SIZE:] # Use config value for msg_data in context_messages_data: - role = "assistant" if msg_data['author']['id'] == str(cog.bot.user.id) else "user" + role = ( + "assistant" + if msg_data["author"]["id"] == str(cog.bot.user.id) + else "user" + ) # Build the content string, including reply and attachment info content_parts = [] # FIX: Use the pre-formatted author_string which includes '(BOT)' tag if applicable. # Fall back to display_name or '' if author_string is missing for some reason. - author_name = msg_data.get('author_string', msg_data.get('author', {}).get('display_name', '')) + author_name = msg_data.get( + "author_string", msg_data.get("author", {}).get("display_name", "") + ) - message_id = msg_data['id'] # Get the message ID + message_id = msg_data["id"] # Get the message ID # Add reply prefix if applicable if msg_data.get("is_reply"): - reply_author = msg_data.get('replied_to_author_name', '') - reply_snippet = msg_data.get('replied_to_content_snippet') # Get value, could be None + reply_author = msg_data.get("replied_to_author_name", "") + reply_snippet = msg_data.get( + "replied_to_content_snippet" + ) # Get value, could be None # Keep snippet very short for context, handle None case - reply_snippet_short = '...' # Default if snippet is None or not a string + reply_snippet_short = ( + "..." # Default if snippet is None or not a string + ) if isinstance(reply_snippet, str): - reply_snippet_short = (reply_snippet[:25] + '...') if len(reply_snippet) > 28 else reply_snippet - content_parts.append(f"{author_name} (Message ID: {message_id}) (replying to {reply_author} '{reply_snippet_short}'):") # Clarify ID + reply_snippet_short = ( + (reply_snippet[:25] + "...") + if len(reply_snippet) > 28 + else reply_snippet + ) + content_parts.append( + f"{author_name} (Message ID: {message_id}) (replying to {reply_author} '{reply_snippet_short}'):" + ) # Clarify ID else: - content_parts.append(f"{author_name} (Message ID: {message_id}):") # Clarify ID + content_parts.append( + f"{author_name} (Message ID: {message_id}):" + ) # Clarify ID # Add main message content - if msg_data.get('content'): - content_parts.append(msg_data['content']) + if msg_data.get("content"): + content_parts.append(msg_data["content"]) # Add attachment descriptions attachments = msg_data.get("attachment_descriptions", []) if attachments: # Join descriptions into a single string - attachment_str = " ".join([att['description'] for att in attachments]) + attachment_str = " ".join([att["description"] for att in attachments]) content_parts.append(attachment_str) # Join all parts with spaces # --- New Handling for Tool Request/Response Turns --- - author_id = msg_data['author'].get('id') - is_tool_request = author_id == str(cog.bot.user.id) and msg_data.get('tool_calls') is not None - is_tool_response = author_id == "FUNCTION" and msg_data.get('function_results') is not None + author_id = msg_data["author"].get("id") + is_tool_request = ( + author_id == str(cog.bot.user.id) + and msg_data.get("tool_calls") is not None + ) + is_tool_response = ( + author_id == "FUNCTION" and msg_data.get("function_results") is not None + ) if is_tool_request: # Format tool request turn - tool_names = ", ".join([tc['name'] for tc in msg_data['tool_calls']]) - content = f"[System Note: Gurt requested tool(s): {tool_names}]" # Simple summary - role = "assistant" # Represent as part of the assistant's turn/thought process + tool_names = ", ".join([tc["name"] for tc in msg_data["tool_calls"]]) + content = f"[System Note: Gurt requested tool(s): {tool_names}]" # Simple summary + role = "assistant" # Represent as part of the assistant's turn/thought process elif is_tool_response: # Format tool response turn result_summary_parts = [] - for res in msg_data['function_results']: + for res in msg_data["function_results"]: res_str = json.dumps(res.get("response", {})) - truncated_res = (res_str[:150] + '...') if len(res_str) > 153 else res_str - result_summary_parts.append(f"Tool: {res.get('name', 'N/A')}, Result: {truncated_res}") + truncated_res = ( + (res_str[:150] + "...") if len(res_str) > 153 else res_str + ) + result_summary_parts.append( + f"Tool: {res.get('name', 'N/A')}, Result: {truncated_res}" + ) result_summary = "; ".join(result_summary_parts) content = f"[System Note: Tool Execution Result: {result_summary}]" - role = "function" # Keep role as 'function' for API compatibility if needed, or maybe 'system'? Let's try 'function'. + role = "function" # Keep role as 'function' for API compatibility if needed, or maybe 'system'? Let's try 'function'. else: # --- Original Handling for User/Assistant messages --- content = " ".join(content_parts).strip() @@ -90,7 +120,7 @@ def gather_conversation_context(cog: 'GurtCog', channel_id: int, current_message return context_api_messages -async def get_memory_context(cog: 'GurtCog', message: discord.Message) -> Optional[str]: +async def get_memory_context(cog: "GurtCog", message: discord.Message) -> Optional[str]: """Retrieves relevant past interactions and facts to provide memory context.""" channel_id = message.channel.id user_id = str(message.author.id) @@ -99,37 +129,65 @@ async def get_memory_context(cog: 'GurtCog', message: discord.Message) -> Option # 1. Retrieve Relevant User Facts try: - user_facts = await cog.memory_manager.get_user_facts(user_id, context=current_message_content) + user_facts = await cog.memory_manager.get_user_facts( + user_id, context=current_message_content + ) if user_facts: facts_str = "; ".join(user_facts) - memory_parts.append(f"Relevant facts about {message.author.display_name}: {facts_str}") - except Exception as e: print(f"Error retrieving relevant user facts for memory context: {e}") + memory_parts.append( + f"Relevant facts about {message.author.display_name}: {facts_str}" + ) + except Exception as e: + print(f"Error retrieving relevant user facts for memory context: {e}") # 1b. Retrieve Relevant General Facts try: - general_facts = await cog.memory_manager.get_general_facts(context=current_message_content, limit=5) + general_facts = await cog.memory_manager.get_general_facts( + context=current_message_content, limit=5 + ) if general_facts: facts_str = "; ".join(general_facts) memory_parts.append(f"Relevant general knowledge: {facts_str}") - except Exception as e: print(f"Error retrieving relevant general facts for memory context: {e}") + except Exception as e: + print(f"Error retrieving relevant general facts for memory context: {e}") # 2. Retrieve Recent Interactions with the User in this Channel try: - user_channel_messages = [msg for msg in cog.message_cache['by_channel'].get(channel_id, []) if msg['author']['id'] == user_id] + user_channel_messages = [ + msg + for msg in cog.message_cache["by_channel"].get(channel_id, []) + if msg["author"]["id"] == user_id + ] if user_channel_messages: recent_user_msgs = user_channel_messages[-3:] - msgs_str = "\n".join([f"- {m['content'][:80]} (at {m['created_at']})" for m in recent_user_msgs]) - memory_parts.append(f"Recent messages from {message.author.display_name} in this channel:\n{msgs_str}") - except Exception as e: print(f"Error retrieving user channel messages for memory context: {e}") + msgs_str = "\n".join( + [ + f"- {m['content'][:80]} (at {m['created_at']})" + for m in recent_user_msgs + ] + ) + memory_parts.append( + f"Recent messages from {message.author.display_name} in this channel:\n{msgs_str}" + ) + except Exception as e: + print(f"Error retrieving user channel messages for memory context: {e}") # 3. Retrieve Recent Bot Replies in this Channel try: - bot_replies = list(cog.message_cache['replied_to'].get(channel_id, [])) + bot_replies = list(cog.message_cache["replied_to"].get(channel_id, [])) if bot_replies: recent_bot_replies = bot_replies[-3:] - replies_str = "\n".join([f"- {m['content'][:80]} (at {m['created_at']})" for m in recent_bot_replies]) - memory_parts.append(f"Your (gurt's) recent replies in this channel:\n{replies_str}") - except Exception as e: print(f"Error retrieving bot replies for memory context: {e}") + replies_str = "\n".join( + [ + f"- {m['content'][:80]} (at {m['created_at']})" + for m in recent_bot_replies + ] + ) + memory_parts.append( + f"Your (gurt's) recent replies in this channel:\n{replies_str}" + ) + except Exception as e: + print(f"Error retrieving bot replies for memory context: {e}") # 4. Retrieve Conversation Summary cached_summary_data = cog.conversation_summaries.get(channel_id) @@ -137,46 +195,87 @@ async def get_memory_context(cog: 'GurtCog', message: discord.Message) -> Option summary_text = cached_summary_data.get("summary") # Add TTL check if desired, e.g., if time.time() - cached_summary_data.get("timestamp", 0) < 900: if summary_text and not summary_text.startswith("Error"): - memory_parts.append(f"Summary of the ongoing conversation: {summary_text}") + memory_parts.append(f"Summary of the ongoing conversation: {summary_text}") # 5. Add information about active topics the user has engaged with try: channel_topics_data = cog.active_topics.get(channel_id) if channel_topics_data: - user_interests = channel_topics_data["user_topic_interests"].get(user_id, []) + user_interests = channel_topics_data["user_topic_interests"].get( + user_id, [] + ) if user_interests: - sorted_interests = sorted(user_interests, key=lambda x: x.get("score", 0), reverse=True) + sorted_interests = sorted( + user_interests, key=lambda x: x.get("score", 0), reverse=True + ) top_interests = sorted_interests[:3] - interests_str = ", ".join([f"{interest['topic']} (score: {interest['score']:.2f})" for interest in top_interests]) - memory_parts.append(f"{message.author.display_name}'s topic interests: {interests_str}") + interests_str = ", ".join( + [ + f"{interest['topic']} (score: {interest['score']:.2f})" + for interest in top_interests + ] + ) + memory_parts.append( + f"{message.author.display_name}'s topic interests: {interests_str}" + ) for interest in top_interests: if "last_mentioned" in interest: time_diff = time.time() - interest["last_mentioned"] if time_diff < 3600: minutes_ago = int(time_diff / 60) - memory_parts.append(f"They discussed '{interest['topic']}' about {minutes_ago} minutes ago.") - except Exception as e: print(f"Error retrieving user topic interests for memory context: {e}") + memory_parts.append( + f"They discussed '{interest['topic']}' about {minutes_ago} minutes ago." + ) + except Exception as e: + print(f"Error retrieving user topic interests for memory context: {e}") # 6. Add information about user's conversation patterns try: - user_messages = cog.message_cache['by_user'].get(user_id, []) + user_messages = cog.message_cache["by_user"].get(user_id, []) if len(user_messages) >= 5: last_5_msgs = user_messages[-5:] avg_length = sum(len(msg["content"]) for msg in last_5_msgs) / 5 - emoji_pattern = re.compile(r'[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F700-\U0001F77F\U0001F780-\U0001F7FF\U0001F800-\U0001F8FF\U0001F900-\U0001F9FF\U0001FA00-\U0001FA6F\U0001FA70-\U0001FAFF\U00002702-\U000027B0\U000024C2-\U0001F251]') - emoji_count = sum(len(emoji_pattern.findall(msg["content"])) for msg in last_5_msgs) - slang_words = ["ngl", "icl", "pmo", "ts", "bro", "vro", "bruh", "tuff", "kevin"] - slang_count = sum(1 for msg in last_5_msgs for word in slang_words if re.search(r'\b' + word + r'\b', msg["content"].lower())) + emoji_pattern = re.compile( + r"[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F700-\U0001F77F\U0001F780-\U0001F7FF\U0001F800-\U0001F8FF\U0001F900-\U0001F9FF\U0001FA00-\U0001FA6F\U0001FA70-\U0001FAFF\U00002702-\U000027B0\U000024C2-\U0001F251]" + ) + emoji_count = sum( + len(emoji_pattern.findall(msg["content"])) for msg in last_5_msgs + ) + slang_words = [ + "ngl", + "icl", + "pmo", + "ts", + "bro", + "vro", + "bruh", + "tuff", + "kevin", + ] + slang_count = sum( + 1 + for msg in last_5_msgs + for word in slang_words + if re.search(r"\b" + word + r"\b", msg["content"].lower()) + ) style_parts = [] - if avg_length < 20: style_parts.append("very brief messages") - elif avg_length < 50: style_parts.append("concise messages") - elif avg_length > 150: style_parts.append("detailed/lengthy messages") - if emoji_count > 5: style_parts.append("frequent emoji use") - elif emoji_count == 0: style_parts.append("no emojis") - if slang_count > 3: style_parts.append("heavy slang usage") - if style_parts: memory_parts.append(f"Communication style: {', '.join(style_parts)}") - except Exception as e: print(f"Error analyzing user communication patterns: {e}") + if avg_length < 20: + style_parts.append("very brief messages") + elif avg_length < 50: + style_parts.append("concise messages") + elif avg_length > 150: + style_parts.append("detailed/lengthy messages") + if emoji_count > 5: + style_parts.append("frequent emoji use") + elif emoji_count == 0: + style_parts.append("no emojis") + if slang_count > 3: + style_parts.append("heavy slang usage") + if style_parts: + memory_parts.append(f"Communication style: {', '.join(style_parts)}") + except Exception as e: + print(f"Error analyzing user communication patterns: {e}") # 7. Add sentiment analysis of user's recent messages try: @@ -184,68 +283,108 @@ async def get_memory_context(cog: 'GurtCog', message: discord.Message) -> Option user_sentiment = channel_sentiment["user_sentiments"].get(user_id) if user_sentiment: sentiment_desc = f"{user_sentiment['sentiment']} tone" - if user_sentiment["intensity"] > 0.7: sentiment_desc += " (strongly so)" - elif user_sentiment["intensity"] < 0.4: sentiment_desc += " (mildly so)" + if user_sentiment["intensity"] > 0.7: + sentiment_desc += " (strongly so)" + elif user_sentiment["intensity"] < 0.4: + sentiment_desc += " (mildly so)" memory_parts.append(f"Recent message sentiment: {sentiment_desc}") if user_sentiment.get("emotions"): emotions_str = ", ".join(user_sentiment["emotions"]) memory_parts.append(f"Detected emotions from user: {emotions_str}") - except Exception as e: print(f"Error retrieving user sentiment/emotions for memory context: {e}") + except Exception as e: + print(f"Error retrieving user sentiment/emotions for memory context: {e}") # 8. Add Relationship Score with User try: user_id_str = str(user_id) bot_id_str = str(cog.bot.user.id) - key_1, key_2 = (user_id_str, bot_id_str) if user_id_str < bot_id_str else (bot_id_str, user_id_str) + key_1, key_2 = ( + (user_id_str, bot_id_str) + if user_id_str < bot_id_str + else (bot_id_str, user_id_str) + ) relationship_score = cog.user_relationships.get(key_1, {}).get(key_2, 0.0) - memory_parts.append(f"Relationship score with {message.author.display_name}: {relationship_score:.1f}/100") - except Exception as e: print(f"Error retrieving relationship score for memory context: {e}") + memory_parts.append( + f"Relationship score with {message.author.display_name}: {relationship_score:.1f}/100" + ) + except Exception as e: + print(f"Error retrieving relationship score for memory context: {e}") # 9. Retrieve Semantically Similar Messages try: if current_message_content and cog.memory_manager.semantic_collection: - filter_metadata = None # Example: {"channel_id": str(channel_id)} + filter_metadata = None # Example: {"channel_id": str(channel_id)} semantic_results = await cog.memory_manager.search_semantic_memory( - query_text=current_message_content, n_results=3, filter_metadata=filter_metadata + query_text=current_message_content, + n_results=3, + filter_metadata=filter_metadata, ) if semantic_results: semantic_memory_parts = ["Semantically similar past messages:"] for result in semantic_results: - if result.get('id') == str(message.id): continue - doc = result.get('document', 'N/A') - meta = result.get('metadata', {}) - dist = result.get('distance', 1.0) + if result.get("id") == str(message.id): + continue + doc = result.get("document", "N/A") + meta = result.get("metadata", {}) + dist = result.get("distance", 1.0) similarity_score = 1.0 - dist - timestamp_str = datetime.datetime.fromtimestamp(meta.get('timestamp', 0)).strftime('%Y-%m-%d %H:%M') if meta.get('timestamp') else 'Unknown time' - author_name = meta.get('display_name', meta.get('user_name', '')) - semantic_memory_parts.append(f"- (Similarity: {similarity_score:.2f}) {author_name} (at {timestamp_str}): {doc[:100]}") - if len(semantic_memory_parts) > 1: memory_parts.append("\n".join(semantic_memory_parts)) - except Exception as e: print(f"Error retrieving semantic memory context: {e}") + timestamp_str = ( + datetime.datetime.fromtimestamp( + meta.get("timestamp", 0) + ).strftime("%Y-%m-%d %H:%M") + if meta.get("timestamp") + else "Unknown time" + ) + author_name = meta.get("display_name", meta.get("user_name", "")) + semantic_memory_parts.append( + f"- (Similarity: {similarity_score:.2f}) {author_name} (at {timestamp_str}): {doc[:100]}" + ) + if len(semantic_memory_parts) > 1: + memory_parts.append("\n".join(semantic_memory_parts)) + except Exception as e: + print(f"Error retrieving semantic memory context: {e}") # 10. Add information about recent attachments try: - channel_messages = cog.message_cache['by_channel'].get(channel_id, []) - messages_with_attachments = [msg for msg in channel_messages if msg.get("attachment_descriptions")] + channel_messages = cog.message_cache["by_channel"].get(channel_id, []) + messages_with_attachments = [ + msg for msg in channel_messages if msg.get("attachment_descriptions") + ] if messages_with_attachments: - recent_attachments = messages_with_attachments[-5:] # Get last 5 + recent_attachments = messages_with_attachments[-5:] # Get last 5 attachment_memory_parts = ["Recently Shared Files/Images:"] for msg in recent_attachments: - author_name = msg.get('author', {}).get('display_name', '') - timestamp_str = 'Unknown time' + author_name = msg.get("author", {}).get("display_name", "") + timestamp_str = "Unknown time" try: # Safely parse timestamp - if msg.get('created_at'): - timestamp_str = datetime.datetime.fromisoformat(msg['created_at']).strftime('%H:%M') - except ValueError: pass # Ignore invalid timestamp format + if msg.get("created_at"): + timestamp_str = datetime.datetime.fromisoformat( + msg["created_at"] + ).strftime("%H:%M") + except ValueError: + pass # Ignore invalid timestamp format - descriptions = " ".join([att['description'] for att in msg.get('attachment_descriptions', [])]) - attachment_memory_parts.append(f"- By {author_name} (at {timestamp_str}): {descriptions}") + descriptions = " ".join( + [ + att["description"] + for att in msg.get("attachment_descriptions", []) + ] + ) + attachment_memory_parts.append( + f"- By {author_name} (at {timestamp_str}): {descriptions}" + ) if len(attachment_memory_parts) > 1: memory_parts.append("\n".join(attachment_memory_parts)) - except Exception as e: print(f"Error retrieving recent attachments for memory context: {e}") + except Exception as e: + print(f"Error retrieving recent attachments for memory context: {e}") - - if not memory_parts: return None - memory_context_str = "--- Memory Context ---\n" + "\n\n".join(memory_parts) + "\n--- End Memory Context ---" + if not memory_parts: + return None + memory_context_str = ( + "--- Memory Context ---\n" + + "\n\n".join(memory_parts) + + "\n--- End Memory Context ---" + ) return memory_context_str diff --git a/gurt/emojis.py b/gurt/emojis.py index a87d18c..1139ea6 100644 --- a/gurt/emojis.py +++ b/gurt/emojis.py @@ -4,6 +4,7 @@ from typing import Dict, Optional, Tuple, Union, Any DATA_FILE_PATH = "data/custom_emojis_stickers.json" + class EmojiManager: def __init__(self, data_file: str = DATA_FILE_PATH): self.data_file = data_file @@ -15,7 +16,7 @@ class EmojiManager: """Loads emoji and sticker data from the JSON file.""" try: if os.path.exists(self.data_file): - with open(self.data_file, 'r', encoding='utf-8') as f: + with open(self.data_file, "r", encoding="utf-8") as f: loaded_json = json.load(f) if isinstance(loaded_json, dict): # Ensure guild_id is present, defaulting to None if missing for backward compatibility during load @@ -24,42 +25,55 @@ class EmojiManager: "id": data.get("id"), "animated": data.get("animated", False), "guild_id": data.get("guild_id"), - "url": data.get("url"), # Load new field - "description": data.get("description") # Load new field + "url": data.get("url"), # Load new field + "description": data.get( + "description" + ), # Load new field } - for name, data in loaded_json.get("emojis", {}).items() if isinstance(data, dict) + for name, data in loaded_json.get("emojis", {}).items() + if isinstance(data, dict) } self.data["stickers"] = { name: { "id": data.get("id"), "guild_id": data.get("guild_id"), - "url": data.get("url"), # Load new field - "description": data.get("description") # Load new field + "url": data.get("url"), # Load new field + "description": data.get( + "description" + ), # Load new field } - for name, data in loaded_json.get("stickers", {}).items() if isinstance(data, dict) + for name, data in loaded_json.get("stickers", {}).items() + if isinstance(data, dict) } - print(f"Loaded {len(self.data['emojis'])} emojis and {len(self.data['stickers'])} stickers from {self.data_file}") + print( + f"Loaded {len(self.data['emojis'])} emojis and {len(self.data['stickers'])} stickers from {self.data_file}" + ) else: - print(f"Warning: Data in {self.data_file} is not a dictionary. Initializing with empty data.") - self._save_data() # Initialize with empty structure if format is wrong + print( + f"Warning: Data in {self.data_file} is not a dictionary. Initializing with empty data." + ) + self._save_data() # Initialize with empty structure if format is wrong else: print(f"{self.data_file} not found. Initializing with empty data.") - self._save_data() # Create the file if it doesn't exist + self._save_data() # Create the file if it doesn't exist except json.JSONDecodeError: - print(f"Error decoding JSON from {self.data_file}. Initializing with empty data.") + print( + f"Error decoding JSON from {self.data_file}. Initializing with empty data." + ) self._save_data() except Exception as e: print(f"Error loading emoji/sticker data: {e}") # Ensure data is initialized even on other errors - if "emojis" not in self.data: self.data["emojis"] = {} - if "stickers" not in self.data: self.data["stickers"] = {} - + if "emojis" not in self.data: + self.data["emojis"] = {} + if "stickers" not in self.data: + self.data["stickers"] = {} def _save_data(self): """Saves the current emoji and sticker data to the JSON file.""" try: os.makedirs(os.path.dirname(self.data_file), exist_ok=True) - with open(self.data_file, 'w', encoding='utf-8') as f: + with open(self.data_file, "w", encoding="utf-8") as f: json.dump(self.data, f, indent=4) print(f"Saved emoji and sticker data to {self.data_file}") return True @@ -67,22 +81,32 @@ class EmojiManager: print(f"Error saving emoji/sticker data: {e}") return False - async def add_emoji(self, name: str, emoji_id: str, is_animated: bool, guild_id: Optional[int], url: Optional[str] = None, description: Optional[str] = None) -> bool: + async def add_emoji( + self, + name: str, + emoji_id: str, + is_animated: bool, + guild_id: Optional[int], + url: Optional[str] = None, + description: Optional[str] = None, + ) -> bool: """Adds a custom emoji with its guild ID, URL, and description.""" if name in self.data["emojis"]: existing_data = self.data["emojis"][name] - if (existing_data.get("id") == emoji_id and - existing_data.get("guild_id") == guild_id and - existing_data.get("animated") == is_animated and - existing_data.get("url") == url and - existing_data.get("description") == description): - return False # No change + if ( + existing_data.get("id") == emoji_id + and existing_data.get("guild_id") == guild_id + and existing_data.get("animated") == is_animated + and existing_data.get("url") == url + and existing_data.get("description") == description + ): + return False # No change self.data["emojis"][name] = { "id": emoji_id, "animated": is_animated, "guild_id": guild_id, "url": url, - "description": description + "description": description, } return self._save_data() @@ -101,20 +125,29 @@ class EmojiManager: """Gets a specific custom emoji by name.""" return self.data["emojis"].get(name) - async def add_sticker(self, name: str, sticker_id: str, guild_id: Optional[int], url: Optional[str] = None, description: Optional[str] = None) -> bool: + async def add_sticker( + self, + name: str, + sticker_id: str, + guild_id: Optional[int], + url: Optional[str] = None, + description: Optional[str] = None, + ) -> bool: """Adds a custom sticker with its guild ID, URL, and description.""" if name in self.data["stickers"]: existing_data = self.data["stickers"][name] - if (existing_data.get("id") == sticker_id and - existing_data.get("guild_id") == guild_id and - existing_data.get("url") == url and - existing_data.get("description") == description): - return False # No change + if ( + existing_data.get("id") == sticker_id + and existing_data.get("guild_id") == guild_id + and existing_data.get("url") == url + and existing_data.get("description") == description + ): + return False # No change self.data["stickers"][name] = { "id": sticker_id, "guild_id": guild_id, "url": url, - "description": description + "description": description, } return self._save_data() diff --git a/gurt/extrtools.py b/gurt/extrtools.py index e1d6291..55c3bd4 100644 --- a/gurt/extrtools.py +++ b/gurt/extrtools.py @@ -16,6 +16,7 @@ import ast import json import sys + def extract_function_declarations(source: str): tree = ast.parse(source) tools = [] @@ -52,6 +53,7 @@ def extract_function_declarations(source: str): return tools + def main(): if len(sys.argv) != 2: print("Usage: python extract_tools.py path/to/your_module.py", file=sys.stderr) @@ -65,5 +67,6 @@ def main(): json.dump(tools, sys.stdout, indent=2) sys.stdout.write("\n") + if __name__ == "__main__": main() diff --git a/gurt/listeners.py b/gurt/listeners.py index 01e128a..84a14d7 100644 --- a/gurt/listeners.py +++ b/gurt/listeners.py @@ -4,17 +4,18 @@ import random import asyncio import time import re -import os # Added for file handling in error case +import os # Added for file handling in error case from typing import TYPE_CHECKING, List, Union, Dict, Any, Optional -from collections import deque # Import deque for efficient rate limiting +from collections import deque # Import deque for efficient rate limiting # Relative imports -from .utils import format_message # Import format_message +from .utils import format_message # Import format_message from .config import ( - CONTEXT_WINDOW_SIZE, # Import context window size - BOT_RESPONSE_RATE_LIMIT_PER_MINUTE, # New config - BOT_RESPONSE_RATE_LIMIT_WINDOW_SECONDS # New config + CONTEXT_WINDOW_SIZE, # Import context window size + BOT_RESPONSE_RATE_LIMIT_PER_MINUTE, # New config + BOT_RESPONSE_RATE_LIMIT_WINDOW_SECONDS, # New config ) + # Assuming api, utils, analysis functions are defined and imported correctly later # We might need to adjust these imports based on final structure # from .api import get_ai_response, get_proactive_ai_response @@ -22,16 +23,17 @@ from .config import ( # from .analysis import analyze_message_sentiment, update_conversation_sentiment if TYPE_CHECKING: - from .cog import GurtCog # For type hinting + from .cog import GurtCog # For type hinting # Note: These listener functions need to be registered within the GurtCog class setup. # They are defined here for separation but won't work standalone without being # attached to the cog instance (e.g., self.bot.add_listener(on_message_listener(self), 'on_message')). -async def on_ready_listener(cog: 'GurtCog'): + +async def on_ready_listener(cog: "GurtCog"): """Listener function for on_ready.""" - print(f'Gurt Bot is ready! Logged in as {cog.bot.user.name} ({cog.bot.user.id})') - print('------') + print(f"Gurt Bot is ready! Logged in as {cog.bot.user.name} ({cog.bot.user.id})") + print("------") # Now that the bot is ready, we can sync commands with Discord try: @@ -40,11 +42,16 @@ async def on_ready_listener(cog: 'GurtCog'): print(f"GurtCog: Synced {len(synced)} command(s)") # List the synced commands - gurt_commands = [cmd.name for cmd in cog.bot.tree.get_commands() if cmd.name.startswith("gurt")] + gurt_commands = [ + cmd.name + for cmd in cog.bot.tree.get_commands() + if cmd.name.startswith("gurt") + ] print(f"GurtCog: Available Gurt commands: {', '.join(gurt_commands)}") except Exception as e: print(f"GurtCog: Failed to sync commands: {e}") import traceback + traceback.print_exc() # --- Message history pre-loading removed --- @@ -52,13 +59,22 @@ async def on_ready_listener(cog: 'GurtCog'): await cog.initial_emoji_sticker_scan() -async def on_message_listener(cog: 'GurtCog', message: discord.Message): +async def on_message_listener(cog: "GurtCog", message: discord.Message): """Listener function for on_message.""" # Import necessary functions dynamically or ensure they are passed/accessible via cog from .api import get_ai_response, get_proactive_ai_response from .utils import format_message, simulate_human_typing - from .analysis import analyze_message_sentiment, update_conversation_sentiment, identify_conversation_topics - from .config import GURT_RESPONSES, IGNORED_CHANNEL_IDS, BOT_RESPONSE_RATE_LIMIT_PER_MINUTE, BOT_RESPONSE_RATE_LIMIT_WINDOW_SECONDS # Import new config + from .analysis import ( + analyze_message_sentiment, + update_conversation_sentiment, + identify_conversation_topics, + ) + from .config import ( + GURT_RESPONSES, + IGNORED_CHANNEL_IDS, + BOT_RESPONSE_RATE_LIMIT_PER_MINUTE, + BOT_RESPONSE_RATE_LIMIT_WINDOW_SECONDS, + ) # Import new config # Don't respond to our own messages if message.author == cog.bot.user: @@ -77,62 +93,89 @@ async def on_message_listener(cog: 'GurtCog', message: discord.Message): # --- Bot Response Rate Limiting --- if message.author.bot: bot_id = message.author.id - if not hasattr(cog, 'bot_response_timestamps'): + if not hasattr(cog, "bot_response_timestamps"): cog.bot_response_timestamps = {} - + if bot_id not in cog.bot_response_timestamps: cog.bot_response_timestamps[bot_id] = deque() # Clean up old timestamps now = time.time() - while cog.bot_response_timestamps[bot_id] and \ - now - cog.bot_response_timestamps[bot_id][0] > BOT_RESPONSE_RATE_LIMIT_WINDOW_SECONDS: + while ( + cog.bot_response_timestamps[bot_id] + and now - cog.bot_response_timestamps[bot_id][0] + > BOT_RESPONSE_RATE_LIMIT_WINDOW_SECONDS + ): cog.bot_response_timestamps[bot_id].popleft() # Check if limit is exceeded - if len(cog.bot_response_timestamps[bot_id]) >= BOT_RESPONSE_RATE_LIMIT_PER_MINUTE: - print(f"Rate limit exceeded for bot {message.author.name} ({bot_id}). Skipping response.") - return # Do not respond to this bot + if ( + len(cog.bot_response_timestamps[bot_id]) + >= BOT_RESPONSE_RATE_LIMIT_PER_MINUTE + ): + print( + f"Rate limit exceeded for bot {message.author.name} ({bot_id}). Skipping response." + ) + return # Do not respond to this bot # --- Cache and Track Incoming Message --- try: - formatted_message = format_message(cog, message) # Use utility function + formatted_message = format_message(cog, message) # Use utility function channel_id = message.channel.id user_id = message.author.id # --- Detect and Learn Custom Emojis/Stickers --- # Custom Emojis in message content if message.content: - custom_emojis = re.findall(r'<(a)?:(\w+):(\d+)>', message.content) - for animated, name, emoji_id_str in custom_emojis: # Renamed emoji_id to emoji_id_str + custom_emojis = re.findall(r"<(a)?:(\w+):(\d+)>", message.content) + for ( + animated, + name, + emoji_id_str, + ) in custom_emojis: # Renamed emoji_id to emoji_id_str # emoji_url = f"https://cdn.discordapp.com/emojis/{emoji_id_str}.{'gif' if animated else 'png'}" emoji_name_key = f":{name}:" current_guild_id = message.guild.id if message.guild else None # Check if already learned, if not, add it existing_emoji_data = await cog.emoji_manager.get_emoji(emoji_name_key) - if not existing_emoji_data or \ - existing_emoji_data.get("id") != emoji_id_str or \ - existing_emoji_data.get("guild_id") != current_guild_id: # Also check guild_id - print(f"Learning custom emoji: {emoji_name_key} (ID: {emoji_id_str}, Animated: {bool(animated)}, Guild: {current_guild_id})") - await cog.emoji_manager.add_emoji(emoji_name_key, emoji_id_str, bool(animated), current_guild_id) - + if ( + not existing_emoji_data + or existing_emoji_data.get("id") != emoji_id_str + or existing_emoji_data.get("guild_id") != current_guild_id + ): # Also check guild_id + print( + f"Learning custom emoji: {emoji_name_key} (ID: {emoji_id_str}, Animated: {bool(animated)}, Guild: {current_guild_id})" + ) + await cog.emoji_manager.add_emoji( + emoji_name_key, emoji_id_str, bool(animated), current_guild_id + ) # Stickers in message if message.stickers: for sticker_item in message.stickers: - sticker_name_key = f":{sticker_item.name}:" # Use sticker name as key - sticker_id_str = str(sticker_item.id) # Ensure ID is string + sticker_name_key = f":{sticker_item.name}:" # Use sticker name as key + sticker_id_str = str(sticker_item.id) # Ensure ID is string current_guild_id = message.guild.id if message.guild else None # Check if already learned, if not, add it - existing_sticker_data = await cog.emoji_manager.get_sticker(sticker_name_key) - if not existing_sticker_data or \ - existing_sticker_data.get("id") != sticker_id_str or \ - existing_sticker_data.get("guild_id") != current_guild_id: # Also check guild_id - print(f"Learning sticker: {sticker_name_key} (ID: {sticker_id_str}, Guild: {current_guild_id})") - await cog.emoji_manager.add_sticker(sticker_name_key, sticker_id_str, current_guild_id) + existing_sticker_data = await cog.emoji_manager.get_sticker( + sticker_name_key + ) + if ( + not existing_sticker_data + or existing_sticker_data.get("id") != sticker_id_str + or existing_sticker_data.get("guild_id") != current_guild_id + ): # Also check guild_id + print( + f"Learning sticker: {sticker_name_key} (ID: {sticker_id_str}, Guild: {current_guild_id})" + ) + await cog.emoji_manager.add_sticker( + sticker_name_key, sticker_id_str, current_guild_id + ) # --- End Emoji/Sticker Learning --- - thread_id = message.channel.id if isinstance(message.channel, discord.Thread) else None + thread_id = ( + message.channel.id if isinstance(message.channel, discord.Thread) else None + ) # Update caches (accessing cog's state) # Deduplicate by message ID before appending @@ -140,13 +183,17 @@ async def on_message_listener(cog: 'GurtCog', message: discord.Message): if not any(m.get("id") == msg.get("id") for m in cache_deque): cache_deque.append(msg) - _dedup_and_append(cog.message_cache['by_channel'][channel_id], formatted_message) - _dedup_and_append(cog.message_cache['by_user'][user_id], formatted_message) - _dedup_and_append(cog.message_cache['global_recent'], formatted_message) + _dedup_and_append( + cog.message_cache["by_channel"][channel_id], formatted_message + ) + _dedup_and_append(cog.message_cache["by_user"][user_id], formatted_message) + _dedup_and_append(cog.message_cache["global_recent"], formatted_message) if thread_id: - _dedup_and_append(cog.message_cache['by_thread'][thread_id], formatted_message) + _dedup_and_append( + cog.message_cache["by_thread"][thread_id], formatted_message + ) if cog.bot.user.mentioned_in(message): - _dedup_and_append(cog.message_cache['mentioned'], formatted_message) + _dedup_and_append(cog.message_cache["mentioned"], formatted_message) cog.conversation_history[channel_id].append(formatted_message) if thread_id: @@ -156,46 +203,76 @@ async def on_message_listener(cog: 'GurtCog', message: discord.Message): cog.user_conversation_mapping[user_id].add(channel_id) if channel_id not in cog.active_conversations: - cog.active_conversations[channel_id] = {'participants': set(), 'start_time': time.time(), 'last_activity': time.time(), 'topic': None} - cog.active_conversations[channel_id]['participants'].add(user_id) - cog.active_conversations[channel_id]['last_activity'] = time.time() + cog.active_conversations[channel_id] = { + "participants": set(), + "start_time": time.time(), + "last_activity": time.time(), + "topic": None, + } + cog.active_conversations[channel_id]["participants"].add(user_id) + cog.active_conversations[channel_id]["last_activity"] = time.time() # --- Update Relationship Strengths --- if user_id != cog.bot.user.id: - message_sentiment_data = analyze_message_sentiment(cog, message.content) # Use analysis function + message_sentiment_data = analyze_message_sentiment( + cog, message.content + ) # Use analysis function sentiment_score = 0.0 - if message_sentiment_data["sentiment"] == "positive": sentiment_score = message_sentiment_data["intensity"] * 0.5 - elif message_sentiment_data["sentiment"] == "negative": sentiment_score = -message_sentiment_data["intensity"] * 0.3 + if message_sentiment_data["sentiment"] == "positive": + sentiment_score = message_sentiment_data["intensity"] * 0.5 + elif message_sentiment_data["sentiment"] == "negative": + sentiment_score = -message_sentiment_data["intensity"] * 0.3 - cog._update_relationship(str(user_id), str(cog.bot.user.id), 1.0 + sentiment_score) # Access cog method + cog._update_relationship( + str(user_id), str(cog.bot.user.id), 1.0 + sentiment_score + ) # Access cog method - if formatted_message.get("is_reply") and formatted_message.get("replied_to_author_id"): + if formatted_message.get("is_reply") and formatted_message.get( + "replied_to_author_id" + ): replied_to_id = formatted_message["replied_to_author_id"] - if replied_to_id != str(cog.bot.user.id) and replied_to_id != str(user_id): - cog._update_relationship(str(user_id), replied_to_id, 1.5 + sentiment_score) + if replied_to_id != str(cog.bot.user.id) and replied_to_id != str( + user_id + ): + cog._update_relationship( + str(user_id), replied_to_id, 1.5 + sentiment_score + ) mentioned_ids = [m["id"] for m in formatted_message.get("mentions", [])] for mentioned_id in mentioned_ids: - if mentioned_id != str(cog.bot.user.id) and mentioned_id != str(user_id): - cog._update_relationship(str(user_id), mentioned_id, 1.2 + sentiment_score) + if mentioned_id != str(cog.bot.user.id) and mentioned_id != str( + user_id + ): + cog._update_relationship( + str(user_id), mentioned_id, 1.2 + sentiment_score + ) # Analyze message sentiment and update conversation sentiment tracking if message.content: - message_sentiment = analyze_message_sentiment(cog, message.content) # Use analysis function - update_conversation_sentiment(cog, channel_id, str(user_id), message_sentiment) # Use analysis function + message_sentiment = analyze_message_sentiment( + cog, message.content + ) # Use analysis function + update_conversation_sentiment( + cog, channel_id, str(user_id), message_sentiment + ) # Use analysis function # --- Add message to semantic memory --- if message.content and cog.memory_manager.semantic_collection: semantic_metadata = { - "user_id": str(user_id), "user_name": message.author.name, "display_name": message.author.display_name, - "channel_id": str(channel_id), "channel_name": getattr(message.channel, 'name', 'DM'), + "user_id": str(user_id), + "user_name": message.author.name, + "display_name": message.author.display_name, + "channel_id": str(channel_id), + "channel_name": getattr(message.channel, "name", "DM"), "guild_id": str(message.guild.id) if message.guild else None, - "timestamp": message.created_at.timestamp() + "timestamp": message.created_at.timestamp(), } # Pass the entire formatted_message dictionary now asyncio.create_task( cog.memory_manager.add_message_embedding( - message_id=str(message.id), formatted_message_data=formatted_message, metadata=semantic_metadata + message_id=str(message.id), + formatted_message_data=formatted_message, + metadata=semantic_metadata, ) ) @@ -203,7 +280,6 @@ async def on_message_listener(cog: 'GurtCog', message: discord.Message): print(f"Error during message caching/tracking/embedding: {e}") # --- End Caching & Embedding --- - # Simple response for messages just containing "gurt" if message.content.lower() == "gurt": response = random.choice(GURT_RESPONSES) @@ -212,7 +288,11 @@ async def on_message_listener(cog: 'GurtCog', message: discord.Message): # Check conditions for potentially responding bot_mentioned = cog.bot.user.mentioned_in(message) - replied_to_bot = message.reference and message.reference.resolved and message.reference.resolved.author == cog.bot.user + replied_to_bot = ( + message.reference + and message.reference.resolved + and message.reference.resolved.author == cog.bot.user + ) gurt_in_message = "gurt" in message.content.lower() now = time.time() time_since_last_activity = now - cog.channel_activity.get(channel_id, 0) @@ -227,38 +307,63 @@ async def on_message_listener(cog: 'GurtCog', message: discord.Message): consideration_reason = "Direct mention/reply/name" else: # --- Proactive Engagement Triggers --- - from .config import (PROACTIVE_LULL_THRESHOLD, PROACTIVE_BOT_SILENCE_THRESHOLD, PROACTIVE_LULL_CHANCE, - PROACTIVE_TOPIC_RELEVANCE_THRESHOLD, PROACTIVE_TOPIC_CHANCE, - PROACTIVE_RELATIONSHIP_SCORE_THRESHOLD, PROACTIVE_RELATIONSHIP_CHANCE, - # Import new config values - # Import new config values - PROACTIVE_SENTIMENT_SHIFT_THRESHOLD, PROACTIVE_SENTIMENT_DURATION_THRESHOLD, - PROACTIVE_SENTIMENT_CHANCE, PROACTIVE_USER_INTEREST_THRESHOLD, - PROACTIVE_USER_INTEREST_CHANCE) + from .config import ( + PROACTIVE_LULL_THRESHOLD, + PROACTIVE_BOT_SILENCE_THRESHOLD, + PROACTIVE_LULL_CHANCE, + PROACTIVE_TOPIC_RELEVANCE_THRESHOLD, + PROACTIVE_TOPIC_CHANCE, + PROACTIVE_RELATIONSHIP_SCORE_THRESHOLD, + PROACTIVE_RELATIONSHIP_CHANCE, + # Import new config values + # Import new config values + PROACTIVE_SENTIMENT_SHIFT_THRESHOLD, + PROACTIVE_SENTIMENT_DURATION_THRESHOLD, + PROACTIVE_SENTIMENT_CHANCE, + PROACTIVE_USER_INTEREST_THRESHOLD, + PROACTIVE_USER_INTEREST_CHANCE, + ) # 1. Lull Trigger - if time_since_last_activity > PROACTIVE_LULL_THRESHOLD and time_since_bot_spoke > PROACTIVE_BOT_SILENCE_THRESHOLD: - has_relevant_context = bool(cog.active_topics.get(channel_id, {}).get("topics", [])) or \ - bool(await cog.memory_manager.get_general_facts(limit=1)) + if ( + time_since_last_activity > PROACTIVE_LULL_THRESHOLD + and time_since_bot_spoke > PROACTIVE_BOT_SILENCE_THRESHOLD + ): + has_relevant_context = bool( + cog.active_topics.get(channel_id, {}).get("topics", []) + ) or bool(await cog.memory_manager.get_general_facts(limit=1)) if has_relevant_context and random.random() < PROACTIVE_LULL_CHANCE: should_consider_responding = True proactive_trigger_met = True consideration_reason = f"Proactive: Lull ({time_since_last_activity:.0f}s idle, bot silent {time_since_bot_spoke:.0f}s)" # 2. Topic Relevance Trigger - if not proactive_trigger_met and message.content and cog.memory_manager.semantic_collection: + if ( + not proactive_trigger_met + and message.content + and cog.memory_manager.semantic_collection + ): try: - semantic_results = await cog.memory_manager.search_semantic_memory(query_text=message.content, n_results=1) + semantic_results = await cog.memory_manager.search_semantic_memory( + query_text=message.content, n_results=1 + ) if semantic_results: - similarity_score = 1.0 - semantic_results[0].get('distance', 1.0) - if similarity_score >= PROACTIVE_TOPIC_RELEVANCE_THRESHOLD and time_since_bot_spoke > 120: + similarity_score = 1.0 - semantic_results[0].get("distance", 1.0) + if ( + similarity_score >= PROACTIVE_TOPIC_RELEVANCE_THRESHOLD + and time_since_bot_spoke > 120 + ): if random.random() < PROACTIVE_TOPIC_CHANCE: should_consider_responding = True proactive_trigger_met = True consideration_reason = f"Proactive: Relevant topic (Sim: {similarity_score:.2f})" - print(f"Topic relevance trigger met for msg {message.id}. Sim: {similarity_score:.2f}") + print( + f"Topic relevance trigger met for msg {message.id}. Sim: {similarity_score:.2f}" + ) else: - print(f"Topic relevance trigger skipped by chance ({PROACTIVE_TOPIC_CHANCE}). Sim: {similarity_score:.2f}") + print( + f"Topic relevance trigger skipped by chance ({PROACTIVE_TOPIC_CHANCE}). Sim: {similarity_score:.2f}" + ) except Exception as semantic_e: print(f"Error during semantic search for topic trigger: {semantic_e}") @@ -267,16 +372,31 @@ async def on_message_listener(cog: 'GurtCog', message: discord.Message): try: user_id_str = str(message.author.id) bot_id_str = str(cog.bot.user.id) - key_1, key_2 = (user_id_str, bot_id_str) if user_id_str < bot_id_str else (bot_id_str, user_id_str) - relationship_score = cog.user_relationships.get(key_1, {}).get(key_2, 0.0) - if relationship_score >= PROACTIVE_RELATIONSHIP_SCORE_THRESHOLD and time_since_bot_spoke > 60: + key_1, key_2 = ( + (user_id_str, bot_id_str) + if user_id_str < bot_id_str + else (bot_id_str, user_id_str) + ) + relationship_score = cog.user_relationships.get(key_1, {}).get( + key_2, 0.0 + ) + if ( + relationship_score >= PROACTIVE_RELATIONSHIP_SCORE_THRESHOLD + and time_since_bot_spoke > 60 + ): if random.random() < PROACTIVE_RELATIONSHIP_CHANCE: should_consider_responding = True proactive_trigger_met = True - consideration_reason = f"Proactive: High relationship ({relationship_score:.1f})" - print(f"Relationship trigger met for user {user_id_str}. Score: {relationship_score:.1f}") + consideration_reason = ( + f"Proactive: High relationship ({relationship_score:.1f})" + ) + print( + f"Relationship trigger met for user {user_id_str}. Score: {relationship_score:.1f}" + ) else: - print(f"Relationship trigger skipped by chance ({PROACTIVE_RELATIONSHIP_CHANCE}). Score: {relationship_score:.1f}") + print( + f"Relationship trigger skipped by chance ({PROACTIVE_RELATIONSHIP_CHANCE}). Score: {relationship_score:.1f}" + ) except Exception as rel_e: print(f"Error during relationship trigger check: {rel_e}") @@ -285,42 +405,63 @@ async def on_message_listener(cog: 'GurtCog', message: discord.Message): channel_sentiment_data = cog.conversation_sentiment.get(channel_id, {}) overall_sentiment = channel_sentiment_data.get("overall", "neutral") sentiment_intensity = channel_sentiment_data.get("intensity", 0.5) - sentiment_last_update = channel_sentiment_data.get("last_update", 0) # Need last update time - sentiment_duration = now - sentiment_last_update # How long has this sentiment been dominant? + sentiment_last_update = channel_sentiment_data.get( + "last_update", 0 + ) # Need last update time + sentiment_duration = ( + now - sentiment_last_update + ) # How long has this sentiment been dominant? - if overall_sentiment != "neutral" and \ - sentiment_intensity >= PROACTIVE_SENTIMENT_SHIFT_THRESHOLD and \ - sentiment_duration >= PROACTIVE_SENTIMENT_DURATION_THRESHOLD and \ - time_since_bot_spoke > 180: # Bot hasn't spoken recently about this + if ( + overall_sentiment != "neutral" + and sentiment_intensity >= PROACTIVE_SENTIMENT_SHIFT_THRESHOLD + and sentiment_duration >= PROACTIVE_SENTIMENT_DURATION_THRESHOLD + and time_since_bot_spoke > 180 + ): # Bot hasn't spoken recently about this if random.random() < PROACTIVE_SENTIMENT_CHANCE: should_consider_responding = True proactive_trigger_met = True consideration_reason = f"Proactive: Sentiment Shift ({overall_sentiment}, Intensity: {sentiment_intensity:.2f}, Duration: {sentiment_duration:.0f}s)" - print(f"Sentiment Shift trigger met for channel {channel_id}. Sentiment: {overall_sentiment}, Intensity: {sentiment_intensity:.2f}, Duration: {sentiment_duration:.0f}s") + print( + f"Sentiment Shift trigger met for channel {channel_id}. Sentiment: {overall_sentiment}, Intensity: {sentiment_intensity:.2f}, Duration: {sentiment_duration:.0f}s" + ) else: - print(f"Sentiment Shift trigger skipped by chance ({PROACTIVE_SENTIMENT_CHANCE}). Sentiment: {overall_sentiment}") + print( + f"Sentiment Shift trigger skipped by chance ({PROACTIVE_SENTIMENT_CHANCE}). Sentiment: {overall_sentiment}" + ) # 5. User Interest Trigger (Based on Gurt's interests mentioned in message) if not proactive_trigger_met and message.content: try: - gurt_interests = await cog.memory_manager.get_interests(limit=10, min_level=PROACTIVE_USER_INTEREST_THRESHOLD) + gurt_interests = await cog.memory_manager.get_interests( + limit=10, min_level=PROACTIVE_USER_INTEREST_THRESHOLD + ) if gurt_interests: message_content_lower = message.content.lower() mentioned_interest = None for interest_topic, interest_level in gurt_interests: # Simple check if interest topic is in message - if re.search(r'\b' + re.escape(interest_topic.lower()) + r'\b', message_content_lower): + if re.search( + r"\b" + re.escape(interest_topic.lower()) + r"\b", + message_content_lower, + ): mentioned_interest = interest_topic - break # Found a mentioned interest + break # Found a mentioned interest - if mentioned_interest and time_since_bot_spoke > 90: # Bot hasn't spoken recently + if ( + mentioned_interest and time_since_bot_spoke > 90 + ): # Bot hasn't spoken recently if random.random() < PROACTIVE_USER_INTEREST_CHANCE: should_consider_responding = True proactive_trigger_met = True consideration_reason = f"Proactive: Gurt Interest Mentioned ('{mentioned_interest}')" - print(f"Gurt Interest trigger met for message {message.id}. Interest: '{mentioned_interest}'") + print( + f"Gurt Interest trigger met for message {message.id}. Interest: '{mentioned_interest}'" + ) else: - print(f"Gurt Interest trigger skipped by chance ({PROACTIVE_USER_INTEREST_CHANCE}). Interest: '{mentioned_interest}'") + print( + f"Gurt Interest trigger skipped by chance ({PROACTIVE_USER_INTEREST_CHANCE}). Interest: '{mentioned_interest}'" + ) except Exception as interest_e: print(f"Error during Gurt Interest trigger check: {interest_e}") @@ -328,78 +469,119 @@ async def on_message_listener(cog: 'GurtCog', message: discord.Message): if not proactive_trigger_met and message.content: try: # Fetch 1-2 active goals with highest priority - active_goals = await cog.memory_manager.get_goals(status='active', limit=2) + active_goals = await cog.memory_manager.get_goals( + status="active", limit=2 + ) if active_goals: message_content_lower = message.content.lower() relevant_goal = None for goal in active_goals: # Simple check: does message content relate to goal description? # TODO: Improve this check, maybe use semantic similarity or keyword extraction from goal details - goal_keywords = set(re.findall(r'\b\w{3,}\b', goal.get('description', '').lower())) # Basic keywords from description - message_words = set(re.findall(r'\b\w{3,}\b', message_content_lower)) - if len(goal_keywords.intersection(message_words)) > 1: # Require >1 keyword overlap + goal_keywords = set( + re.findall( + r"\b\w{3,}\b", goal.get("description", "").lower() + ) + ) # Basic keywords from description + message_words = set( + re.findall(r"\b\w{3,}\b", message_content_lower) + ) + if ( + len(goal_keywords.intersection(message_words)) > 1 + ): # Require >1 keyword overlap relevant_goal = goal break - if relevant_goal and time_since_bot_spoke > 120: # Bot hasn't spoken recently + if ( + relevant_goal and time_since_bot_spoke > 120 + ): # Bot hasn't spoken recently # Use a slightly higher chance for goal-related triggers? - goal_relevance_chance = PROACTIVE_USER_INTEREST_CHANCE * 1.2 # Example: Reuse interest chance slightly boosted + goal_relevance_chance = ( + PROACTIVE_USER_INTEREST_CHANCE * 1.2 + ) # Example: Reuse interest chance slightly boosted if random.random() < goal_relevance_chance: should_consider_responding = True proactive_trigger_met = True - goal_desc_short = relevant_goal.get('description', 'N/A')[:40] + goal_desc_short = relevant_goal.get("description", "N/A")[ + :40 + ] consideration_reason = f"Proactive: Relevant Active Goal ('{goal_desc_short}...')" - print(f"Active Goal trigger met for message {message.id}. Goal ID: {relevant_goal.get('goal_id')}") + print( + f"Active Goal trigger met for message {message.id}. Goal ID: {relevant_goal.get('goal_id')}" + ) else: - print(f"Active Goal trigger skipped by chance ({goal_relevance_chance:.2f}).") + print( + f"Active Goal trigger skipped by chance ({goal_relevance_chance:.2f})." + ) except Exception as goal_trigger_e: print(f"Error during Active Goal trigger check: {goal_trigger_e}") - # --- Fallback Contextual Chance --- if not should_consider_responding: # Check if already decided to respond # Fetch current personality traits for chattiness persistent_traits = await cog.memory_manager.get_all_personality_traits() - chattiness = persistent_traits.get('chattiness', 0.7) # Use default if fetch fails + chattiness = persistent_traits.get( + "chattiness", 0.7 + ) # Use default if fetch fails base_chance = chattiness * 0.5 activity_bonus = 0 - if time_since_last_activity > 120: activity_bonus += 0.1 - if time_since_bot_spoke > 300: activity_bonus += 0.1 + if time_since_last_activity > 120: + activity_bonus += 0.1 + if time_since_bot_spoke > 300: + activity_bonus += 0.1 topic_bonus = 0 - active_channel_topics = cog.active_topics.get(channel_id, {}).get("topics", []) + active_channel_topics = cog.active_topics.get(channel_id, {}).get( + "topics", [] + ) if message.content and active_channel_topics: - topic_keywords = set(t['topic'].lower() for t in active_channel_topics) - message_words = set(re.findall(r'\b\w+\b', message.content.lower())) - if topic_keywords.intersection(message_words): topic_bonus += 0.15 + topic_keywords = set(t["topic"].lower() for t in active_channel_topics) + message_words = set(re.findall(r"\b\w+\b", message.content.lower())) + if topic_keywords.intersection(message_words): + topic_bonus += 0.15 sentiment_modifier = 0 channel_sentiment_data = cog.conversation_sentiment.get(channel_id, {}) overall_sentiment = channel_sentiment_data.get("overall", "neutral") sentiment_intensity = channel_sentiment_data.get("intensity", 0.5) - if overall_sentiment == "negative" and sentiment_intensity > 0.6: sentiment_modifier = -0.1 + if overall_sentiment == "negative" and sentiment_intensity > 0.6: + sentiment_modifier = -0.1 - final_chance = min(max(base_chance + activity_bonus + topic_bonus + sentiment_modifier, 0.05), 0.8) + final_chance = min( + max( + base_chance + activity_bonus + topic_bonus + sentiment_modifier, + 0.05, + ), + 0.8, + ) if random.random() < final_chance: should_consider_responding = True consideration_reason = f"Contextual chance ({final_chance:.2f})" else: consideration_reason = f"Skipped (chance {final_chance:.2f})" - print(f"Consideration check for message {message.id}: {should_consider_responding} (Reason: {consideration_reason})") + print( + f"Consideration check for message {message.id}: {should_consider_responding} (Reason: {consideration_reason})" + ) if not should_consider_responding: return # --- Call AI and Handle Response --- - cog.current_channel = message.channel # Ensure current channel is set for API calls/tools + cog.current_channel = ( + message.channel + ) # Ensure current channel is set for API calls/tools - try: # This is the outer try block + try: # This is the outer try block response_dict: Dict[str, Any] - sticker_ids_to_send: List[str] = [] # Initialize sticker_ids_to_send + sticker_ids_to_send: List[str] = [] # Initialize sticker_ids_to_send if proactive_trigger_met: - print(f"Calling get_proactive_ai_response for message {message.id} due to: {consideration_reason}") - response_dict, sticker_ids_to_send = await get_proactive_ai_response(cog, message, consideration_reason) + print( + f"Calling get_proactive_ai_response for message {message.id} due to: {consideration_reason}" + ) + response_dict, sticker_ids_to_send = await get_proactive_ai_response( + cog, message, consideration_reason + ) else: print(f"Calling get_ai_response for message {message.id}") response_dict, sticker_ids_to_send = await get_ai_response(cog, message) @@ -410,14 +592,16 @@ async def on_message_listener(cog: 'GurtCog', message: discord.Message): # We'll use 'final_response_data' to hold this. final_response_data = response_dict.get("final_response") error_msg = response_dict.get("error") - fallback_initial = response_dict.get("fallback_initial") # This might still be relevant for critical errors + fallback_initial = response_dict.get( + "fallback_initial" + ) # This might still be relevant for critical errors if error_msg: print(f"Critical Error from AI response function: {error_msg}") error_notification = f"Oops! Something went wrong while processing that. (`{error_msg[:100]}`)" try: # await message.channel.send(error_notification) # Error notification disabled - print('disabled error notification') + print("disabled error notification") except Exception as send_err: print(f"Failed to send error notification to channel: {send_err}") return @@ -428,19 +612,31 @@ async def on_message_listener(cog: 'GurtCog', message: discord.Message): # Helper function to handle sending a single response text and caching async def send_response_content( - response_data_param: Optional[Dict[str, Any]], # Renamed to avoid conflict + response_data_param: Optional[Dict[str, Any]], # Renamed to avoid conflict response_label: str, original_message: discord.Message, - current_sticker_ids: List[str] # Pass the specific sticker IDs for this response + current_sticker_ids: List[ + str + ], # Pass the specific sticker IDs for this response ) -> bool: nonlocal sent_any_message - if not response_data_param or not isinstance(response_data_param, dict) or \ - not response_data_param.get("should_respond") or not response_data_param.get("content"): + if ( + not response_data_param + or not isinstance(response_data_param, dict) + or not response_data_param.get("should_respond") + or not response_data_param.get("content") + ): # If content is None but stickers are present, we might still want to send - if not (response_data_param and response_data_param.get("should_respond") and current_sticker_ids): + if not ( + response_data_param + and response_data_param.get("should_respond") + and current_sticker_ids + ): return False - response_text = response_data_param.get("content", "") # Default to empty string if content is None + response_text = response_data_param.get( + "content", "" + ) # Default to empty string if content is None reply_to_id = response_data_param.get("reply_to_message_id") message_reference = None @@ -448,88 +644,144 @@ async def on_message_listener(cog: 'GurtCog', message: discord.Message): if reply_to_id and isinstance(reply_to_id, str) and reply_to_id.isdigit(): try: - original_reply_msg = await original_message.channel.fetch_message(int(reply_to_id)) + original_reply_msg = await original_message.channel.fetch_message( + int(reply_to_id) + ) if original_reply_msg: - message_reference = original_reply_msg.to_reference(fail_if_not_exists=False) + message_reference = original_reply_msg.to_reference( + fail_if_not_exists=False + ) print(f"Will reply to message ID: {reply_to_id}") else: - print(f"Warning: Could not fetch message {reply_to_id} to reply to.") + print( + f"Warning: Could not fetch message {reply_to_id} to reply to." + ) except (ValueError, discord.NotFound, discord.Forbidden) as fetch_err: - print(f"Warning: Error fetching message {reply_to_id} to reply to: {fetch_err}") + print( + f"Warning: Error fetching message {reply_to_id} to reply to: {fetch_err}" + ) except Exception as e: - print(f"Unexpected error fetching reply message {reply_to_id}: {e}") + print(f"Unexpected error fetching reply message {reply_to_id}: {e}") elif reply_to_id: print(f"Warning: Invalid reply_to_id format received: {reply_to_id}") - ping_matches = re.findall(r'\[PING:\s*([^\]]+)\s*\]', response_text) + ping_matches = re.findall(r"\[PING:\s*([^\]]+)\s*\]", response_text) if ping_matches: print(f"Found ping placeholders: {ping_matches}") from .tools import get_user_id + for user_name_to_ping in ping_matches: user_id_result = await get_user_id(cog, user_name_to_ping.strip()) if user_id_result and user_id_result.get("status") == "success": user_id_to_ping = user_id_result.get("user_id") if user_id_to_ping: - response_text = response_text.replace(f'[PING: {user_name_to_ping}]', f'<@{user_id_to_ping}>', 1) - print(f"Replaced ping placeholder for '{user_name_to_ping}' with <@{user_id_to_ping}>") + response_text = response_text.replace( + f"[PING: {user_name_to_ping}]", + f"<@{user_id_to_ping}>", + 1, + ) + print( + f"Replaced ping placeholder for '{user_name_to_ping}' with <@{user_id_to_ping}>" + ) else: - print(f"Warning: get_user_id succeeded for '{user_name_to_ping}' but returned no ID.") - response_text = response_text.replace(f'[PING: {user_name_to_ping}]', user_name_to_ping, 1) + print( + f"Warning: get_user_id succeeded for '{user_name_to_ping}' but returned no ID." + ) + response_text = response_text.replace( + f"[PING: {user_name_to_ping}]", user_name_to_ping, 1 + ) else: - print(f"Warning: Could not find user ID for ping placeholder '{user_name_to_ping}'. Error: {user_id_result.get('error')}") - response_text = response_text.replace(f'[PING: {user_name_to_ping}]', user_name_to_ping, 1) - - discord_stickers_to_send = [discord.Object(id=int(s_id)) for s_id in current_sticker_ids if s_id.isdigit()] if current_sticker_ids else [] + print( + f"Warning: Could not find user ID for ping placeholder '{user_name_to_ping}'. Error: {user_id_result.get('error')}" + ) + response_text = response_text.replace( + f"[PING: {user_name_to_ping}]", user_name_to_ping, 1 + ) + discord_stickers_to_send = ( + [ + discord.Object(id=int(s_id)) + for s_id in current_sticker_ids + if s_id.isdigit() + ] + if current_sticker_ids + else [] + ) # Only proceed if there's text or stickers to send if not response_text and not discord_stickers_to_send: - if response_data_param and response_data_param.get("should_respond"): # Log if it was supposed to respond but had nothing - print(f"Warning: {response_label} response marked 'should_respond' but has no content or stickers.") + if response_data_param and response_data_param.get( + "should_respond" + ): # Log if it was supposed to respond but had nothing + print( + f"Warning: {response_label} response marked 'should_respond' but has no content or stickers." + ) return False - if len(response_text) > 1900: # Discord character limit is 2000, 1900 gives buffer - filepath = f'gurt_{response_label}_{original_message.id}.txt' + if ( + len(response_text) > 1900 + ): # Discord character limit is 2000, 1900 gives buffer + filepath = f"gurt_{response_label}_{original_message.id}.txt" try: - with open(filepath, 'w', encoding='utf-8') as f: f.write(response_text) + with open(filepath, "w", encoding="utf-8") as f: + f.write(response_text) await original_message.channel.send( f"{response_label.capitalize()} response too long:", file=discord.File(filepath), reference=message_reference, mention_author=True, - stickers=discord_stickers_to_send + stickers=discord_stickers_to_send, ) sent_any_message = True - print(f"Sent {response_label} content as file (Reply: {bool(message_reference)}, Stickers: {len(discord_stickers_to_send)}).") + print( + f"Sent {response_label} content as file (Reply: {bool(message_reference)}, Stickers: {len(discord_stickers_to_send)})." + ) return True - except Exception as file_e: print(f"Error writing/sending long {response_label} response file: {file_e}") + except Exception as file_e: + print( + f"Error writing/sending long {response_label} response file: {file_e}" + ) finally: - try: os.remove(filepath) - except OSError as os_e: print(f"Error removing temp file {filepath}: {os_e}") + try: + os.remove(filepath) + except OSError as os_e: + print(f"Error removing temp file {filepath}: {os_e}") else: try: # Only enter typing context if there's text to send if response_text: async with original_message.channel.typing(): - await simulate_human_typing(cog, original_message.channel, response_text) - + await simulate_human_typing( + cog, original_message.channel, response_text + ) + sent_msg = await original_message.channel.send( - response_text if response_text else None, # Send None if only stickers + ( + response_text if response_text else None + ), # Send None if only stickers reference=message_reference, mention_author=True, - stickers=discord_stickers_to_send + stickers=discord_stickers_to_send, ) sent_any_message = True bot_response_cache_entry = format_message(cog, sent_msg) - cog.message_cache['by_channel'][channel_id].append(bot_response_cache_entry) - cog.message_cache['global_recent'].append(bot_response_cache_entry) + cog.message_cache["by_channel"][channel_id].append( + bot_response_cache_entry + ) + cog.message_cache["global_recent"].append(bot_response_cache_entry) cog.bot_last_spoke[channel_id] = time.time() - identified_topics = identify_conversation_topics(cog, [bot_response_cache_entry]) + identified_topics = identify_conversation_topics( + cog, [bot_response_cache_entry] + ) if identified_topics: - topic = identified_topics[0]['topic'].lower().strip() + topic = identified_topics[0]["topic"].lower().strip() cog.gurt_participation_topics[topic] += 1 - print(f"Tracked Gurt participation ({response_label}) in topic: '{topic}'") - print(f"Sent {response_label} content (Reply: {bool(message_reference)}, Stickers: {len(discord_stickers_to_send)}).") + print( + f"Tracked Gurt participation ({response_label}) in topic: '{topic}'" + ) + print( + f"Sent {response_label} content (Reply: {bool(message_reference)}, Stickers: {len(discord_stickers_to_send)})." + ) # --- Record Gurt's response to a bot for rate limiting --- if original_message.author.bot: @@ -538,7 +790,9 @@ async def on_message_listener(cog: 'GurtCog', message: discord.Message): if bot_id not in cog.bot_response_timestamps: cog.bot_response_timestamps[bot_id] = deque() cog.bot_response_timestamps[bot_id].append(time.time()) - print(f"Recorded Gurt's response to bot {original_message.author.name} ({bot_id}). Current count: {len(cog.bot_response_timestamps[bot_id])}") + print( + f"Recorded Gurt's response to bot {original_message.author.name} ({bot_id}). Current count: {len(cog.bot_response_timestamps[bot_id])}" + ) return True except Exception as send_e: @@ -547,44 +801,71 @@ async def on_message_listener(cog: 'GurtCog', message: discord.Message): # Send the main response content (which is now in final_response_data) # sticker_ids_to_send is already defined from the AI response unpacking - sent_main_message = await send_response_content(final_response_data, "final", message, sticker_ids_to_send) + sent_main_message = await send_response_content( + final_response_data, "final", message, sticker_ids_to_send + ) # Handle Reaction (using final_response_data) if final_response_data and isinstance(final_response_data, dict): emoji_to_react = final_response_data.get("react_with_emoji") if emoji_to_react and isinstance(emoji_to_react, str): try: - if 1 <= len(emoji_to_react) <= 4 and not re.match(r'', emoji_to_react): - if not sent_any_message: # Only react if no message was sent + if 1 <= len(emoji_to_react) <= 4 and not re.match( + r"", emoji_to_react + ): + if not sent_any_message: # Only react if no message was sent await message.add_reaction(emoji_to_react) reacted = True - print(f"Bot reacted to message {message.id} with {emoji_to_react}") + print( + f"Bot reacted to message {message.id} with {emoji_to_react}" + ) else: - print(f"Skipping reaction {emoji_to_react} because a message was already sent.") - else: print(f"Invalid emoji format: {emoji_to_react}") - except Exception as e: print(f"Error adding reaction '{emoji_to_react}': {e}") + print( + f"Skipping reaction {emoji_to_react} because a message was already sent." + ) + else: + print(f"Invalid emoji format: {emoji_to_react}") + except Exception as e: + print(f"Error adding reaction '{emoji_to_react}': {e}") # Log if response was intended but nothing was sent/reacted - intended_action = final_response_data and final_response_data.get("should_respond") + intended_action = final_response_data and final_response_data.get( + "should_respond" + ) action_taken = sent_main_message or reacted if intended_action and not action_taken: - print(f"Warning: AI response intended action but nothing sent/reacted. Response data: {final_response_data}") - + print( + f"Warning: AI response intended action but nothing sent/reacted. Response data: {final_response_data}" + ) + # Handle fallback if no other action was taken and fallback_initial is present - if not action_taken and fallback_initial and fallback_initial.get("should_respond"): + if ( + not action_taken + and fallback_initial + and fallback_initial.get("should_respond") + ): print("Attempting to send fallback_initial response...") - await send_response_content(fallback_initial, "fallback", message, []) # No stickers for fallback + await send_response_content( + fallback_initial, "fallback", message, [] + ) # No stickers for fallback except Exception as e: print(f"Exception in on_message listener main block: {str(e)}") import traceback + traceback.print_exc() if bot_mentioned or replied_to_bot: - await message.channel.send(random.choice(["...", "*confused gurting*", "brain broke sorry"])) + await message.channel.send( + random.choice(["...", "*confused gurting*", "brain broke sorry"]) + ) @commands.Cog.listener() -async def on_reaction_add_listener(cog: 'GurtCog', reaction: discord.Reaction, user: Union[discord.Member, discord.User]): +async def on_reaction_add_listener( + cog: "GurtCog", + reaction: discord.Reaction, + user: Union[discord.Member, discord.User], +): """Listener function for on_reaction_add.""" # Import necessary config/functions if not globally available from .config import EMOJI_SENTIMENT @@ -596,32 +877,57 @@ async def on_reaction_add_listener(cog: 'GurtCog', reaction: discord.Reaction, u message_id = str(reaction.message.id) emoji_str = str(reaction.emoji) sentiment = "neutral" - if emoji_str in EMOJI_SENTIMENT["positive"]: sentiment = "positive" - elif emoji_str in EMOJI_SENTIMENT["negative"]: sentiment = "negative" + if emoji_str in EMOJI_SENTIMENT["positive"]: + sentiment = "positive" + elif emoji_str in EMOJI_SENTIMENT["negative"]: + sentiment = "negative" - if sentiment == "positive": cog.gurt_message_reactions[message_id]["positive"] += 1 - elif sentiment == "negative": cog.gurt_message_reactions[message_id]["negative"] += 1 + if sentiment == "positive": + cog.gurt_message_reactions[message_id]["positive"] += 1 + elif sentiment == "negative": + cog.gurt_message_reactions[message_id]["negative"] += 1 cog.gurt_message_reactions[message_id]["timestamp"] = time.time() if not cog.gurt_message_reactions[message_id].get("topic"): try: - gurt_msg_data = next((msg for msg in cog.message_cache['global_recent'] if msg['id'] == message_id), None) - if gurt_msg_data and gurt_msg_data['content']: - identified_topics = identify_conversation_topics(cog, [gurt_msg_data]) # Pass cog + gurt_msg_data = next( + ( + msg + for msg in cog.message_cache["global_recent"] + if msg["id"] == message_id + ), + None, + ) + if gurt_msg_data and gurt_msg_data["content"]: + identified_topics = identify_conversation_topics( + cog, [gurt_msg_data] + ) # Pass cog if identified_topics: - topic = identified_topics[0]['topic'].lower().strip() + topic = identified_topics[0]["topic"].lower().strip() cog.gurt_message_reactions[message_id]["topic"] = topic - print(f"Reaction added to Gurt msg ({message_id}) on topic '{topic}'. Sentiment: {sentiment}") - else: print(f"Reaction added to Gurt msg ({message_id}), topic unknown.") - else: print(f"Reaction added, but Gurt msg {message_id} not in cache.") - except Exception as e: print(f"Error determining topic for reaction on msg {message_id}: {e}") - else: print(f"Reaction added to Gurt msg ({message_id}) on known topic '{cog.gurt_message_reactions[message_id]['topic']}'. Sentiment: {sentiment}") + print( + f"Reaction added to Gurt msg ({message_id}) on topic '{topic}'. Sentiment: {sentiment}" + ) + else: + print(f"Reaction added to Gurt msg ({message_id}), topic unknown.") + else: + print(f"Reaction added, but Gurt msg {message_id} not in cache.") + except Exception as e: + print(f"Error determining topic for reaction on msg {message_id}: {e}") + else: + print( + f"Reaction added to Gurt msg ({message_id}) on known topic '{cog.gurt_message_reactions[message_id]['topic']}'. Sentiment: {sentiment}" + ) @commands.Cog.listener() -async def on_reaction_remove_listener(cog: 'GurtCog', reaction: discord.Reaction, user: Union[discord.Member, discord.User]): +async def on_reaction_remove_listener( + cog: "GurtCog", + reaction: discord.Reaction, + user: Union[discord.Member, discord.User], +): """Listener function for on_reaction_remove.""" - from .config import EMOJI_SENTIMENT # Import necessary config + from .config import EMOJI_SENTIMENT # Import necessary config if user.bot or reaction.message.author.id != cog.bot.user.id: return @@ -629,28 +935,45 @@ async def on_reaction_remove_listener(cog: 'GurtCog', reaction: discord.Reaction message_id = str(reaction.message.id) emoji_str = str(reaction.emoji) sentiment = "neutral" - if emoji_str in EMOJI_SENTIMENT["positive"]: sentiment = "positive" - elif emoji_str in EMOJI_SENTIMENT["negative"]: sentiment = "negative" + if emoji_str in EMOJI_SENTIMENT["positive"]: + sentiment = "positive" + elif emoji_str in EMOJI_SENTIMENT["negative"]: + sentiment = "negative" if message_id in cog.gurt_message_reactions: - if sentiment == "positive": cog.gurt_message_reactions[message_id]["positive"] = max(0, cog.gurt_message_reactions[message_id]["positive"] - 1) - elif sentiment == "negative": cog.gurt_message_reactions[message_id]["negative"] = max(0, cog.gurt_message_reactions[message_id]["negative"] - 1) + if sentiment == "positive": + cog.gurt_message_reactions[message_id]["positive"] = max( + 0, cog.gurt_message_reactions[message_id]["positive"] - 1 + ) + elif sentiment == "negative": + cog.gurt_message_reactions[message_id]["negative"] = max( + 0, cog.gurt_message_reactions[message_id]["negative"] - 1 + ) print(f"Reaction removed from Gurt msg ({message_id}). Sentiment: {sentiment}") # --- New Listener Functions for Guild Asset Updates --- -async def on_guild_join_listener(cog: 'GurtCog', guild: discord.Guild): + +async def on_guild_join_listener(cog: "GurtCog", guild: discord.Guild): """Listener function for on_guild_join.""" print(f"Gurt joined a new guild: {guild.name} ({guild.id})") print(f"Processing emojis and stickers for new guild: {guild.name}") # Schedule the processing as a background task to avoid blocking asyncio.create_task(cog._fetch_and_process_guild_assets(guild)) -async def on_guild_emojis_update_listener(cog: 'GurtCog', guild: discord.Guild, before: List[discord.Emoji], after: List[discord.Emoji]): + +async def on_guild_emojis_update_listener( + cog: "GurtCog", + guild: discord.Guild, + before: List[discord.Emoji], + after: List[discord.Emoji], +): """Listener function for on_guild_emojis_update.""" - print(f"Emojis updated in guild: {guild.name} ({guild.id}). Before: {len(before)}, After: {len(after)}") - + print( + f"Emojis updated in guild: {guild.name} ({guild.id}). Before: {len(before)}, After: {len(after)}" + ) + before_map = {emoji.id: emoji for emoji in before} after_map = {emoji.id: emoji for emoji in after} @@ -659,7 +982,9 @@ async def on_guild_emojis_update_listener(cog: 'GurtCog', guild: discord.Guild, # Process added emojis for emoji_id, emoji_obj in after_map.items(): if emoji_id not in before_map: - print(f"New emoji added: {emoji_obj.name} ({emoji_id}) in guild {guild.name}") + print( + f"New emoji added: {emoji_obj.name} ({emoji_id}) in guild {guild.name}" + ) tasks.append(asyncio.create_task(cog._process_single_emoji(emoji_obj))) else: # Check for changes in existing emojis (e.g., name change) @@ -669,30 +994,46 @@ async def on_guild_emojis_update_listener(cog: 'GurtCog', guild: discord.Guild, # Current EmojiManager uses name as key, so a name change means old is gone, new is added. # If an emoji's URL or other relevant properties change, _process_single_emoji will handle it. before_emoji = before_map[emoji_id] - if before_emoji.name != emoji_obj.name or str(before_emoji.url) != str(emoji_obj.url): - print(f"Emoji changed: {before_emoji.name} -> {emoji_obj.name} or URL changed in guild {guild.name}") - # Remove old entry if name changed, as EmojiManager uses name as key - if before_emoji.name != emoji_obj.name: - await cog.emoji_manager.remove_emoji(f":{before_emoji.name}:") - tasks.append(asyncio.create_task(cog._process_single_emoji(emoji_obj))) - + if before_emoji.name != emoji_obj.name or str(before_emoji.url) != str( + emoji_obj.url + ): + print( + f"Emoji changed: {before_emoji.name} -> {emoji_obj.name} or URL changed in guild {guild.name}" + ) + # Remove old entry if name changed, as EmojiManager uses name as key + if before_emoji.name != emoji_obj.name: + await cog.emoji_manager.remove_emoji(f":{before_emoji.name}:") + tasks.append(asyncio.create_task(cog._process_single_emoji(emoji_obj))) # Process removed emojis for emoji_id, emoji_obj in before_map.items(): if emoji_id not in after_map: - print(f"Emoji removed: {emoji_obj.name} ({emoji_id}) from guild {guild.name}") - await cog.emoji_manager.remove_emoji(f":{emoji_obj.name}:") # Remove by name key + print( + f"Emoji removed: {emoji_obj.name} ({emoji_id}) from guild {guild.name}" + ) + await cog.emoji_manager.remove_emoji( + f":{emoji_obj.name}:" + ) # Remove by name key if tasks: print(f"Queued {len(tasks)} tasks for emoji updates in guild {guild.name}") await asyncio.gather(*tasks, return_exceptions=True) else: - print(f"No new or significantly changed emojis to process in guild {guild.name}") + print( + f"No new or significantly changed emojis to process in guild {guild.name}" + ) -async def on_guild_stickers_update_listener(cog: 'GurtCog', guild: discord.Guild, before: List[discord.StickerItem], after: List[discord.StickerItem]): +async def on_guild_stickers_update_listener( + cog: "GurtCog", + guild: discord.Guild, + before: List[discord.StickerItem], + after: List[discord.StickerItem], +): """Listener function for on_guild_stickers_update.""" - print(f"Stickers updated in guild: {guild.name} ({guild.id}). Before: {len(before)}, After: {len(after)}") + print( + f"Stickers updated in guild: {guild.name} ({guild.id}). Before: {len(before)}, After: {len(after)}" + ) before_map = {sticker.id: sticker for sticker in before} after_map = {sticker.id: sticker for sticker in after} @@ -701,39 +1042,62 @@ async def on_guild_stickers_update_listener(cog: 'GurtCog', guild: discord.Guild # Process added or changed stickers for sticker_id, sticker_obj in after_map.items(): if sticker_id not in before_map: - print(f"New sticker added: {sticker_obj.name} ({sticker_id}) in guild {guild.name}") + print( + f"New sticker added: {sticker_obj.name} ({sticker_id}) in guild {guild.name}" + ) tasks.append(asyncio.create_task(cog._process_single_sticker(sticker_obj))) else: before_sticker = before_map[sticker_id] # Check for relevant changes (name, URL, format) - if before_sticker.name != sticker_obj.name or \ - str(before_sticker.url) != str(sticker_obj.url) or \ - before_sticker.format != sticker_obj.format: - print(f"Sticker changed: {before_sticker.name} -> {sticker_obj.name} or URL/format changed in guild {guild.name}") + if ( + before_sticker.name != sticker_obj.name + or str(before_sticker.url) != str(sticker_obj.url) + or before_sticker.format != sticker_obj.format + ): + print( + f"Sticker changed: {before_sticker.name} -> {sticker_obj.name} or URL/format changed in guild {guild.name}" + ) if before_sticker.name != sticker_obj.name: await cog.emoji_manager.remove_sticker(f":{before_sticker.name}:") - tasks.append(asyncio.create_task(cog._process_single_sticker(sticker_obj))) + tasks.append( + asyncio.create_task(cog._process_single_sticker(sticker_obj)) + ) # Process removed stickers for sticker_id, sticker_obj in before_map.items(): if sticker_id not in after_map: - print(f"Sticker removed: {sticker_obj.name} ({sticker_id}) from guild {guild.name}") + print( + f"Sticker removed: {sticker_obj.name} ({sticker_id}) from guild {guild.name}" + ) await cog.emoji_manager.remove_sticker(f":{sticker_obj.name}:") if tasks: print(f"Queued {len(tasks)} tasks for sticker updates in guild {guild.name}") await asyncio.gather(*tasks, return_exceptions=True) else: - print(f"No new or significantly changed stickers to process in guild {guild.name}") + print( + f"No new or significantly changed stickers to process in guild {guild.name}" + ) -async def on_voice_transcription_received_listener(cog: 'GurtCog', guild: discord.Guild, user: discord.Member, text: str): +async def on_voice_transcription_received_listener( + cog: "GurtCog", guild: discord.Guild, user: discord.Member, text: str +): """Listener for transcribed voice messages.""" - from .api import get_ai_response # For processing the text - from .utils import format_message, simulate_human_typing # For creating pseudo-message and sending response - from .config import IGNORED_CHANNEL_IDS, VOICE_DEDICATED_TEXT_CHANNEL_ENABLED, VOICE_LOG_SPEECH_TO_DEDICATED_CHANNEL # Import new config + from .api import get_ai_response # For processing the text + from .utils import ( + format_message, + simulate_human_typing, + ) # For creating pseudo-message and sending response + from .config import ( + IGNORED_CHANNEL_IDS, + VOICE_DEDICATED_TEXT_CHANNEL_ENABLED, + VOICE_LOG_SPEECH_TO_DEDICATED_CHANNEL, + ) # Import new config - print(f"Voice transcription received from {user.name} ({user.id}) in {guild.name}: '{text}'") + print( + f"Voice transcription received from {user.name} ({user.id}) in {guild.name}: '{text}'" + ) # Avoid processing if user is a bot (including GURT itself if its speech gets transcribed) if user.bot: @@ -745,45 +1109,64 @@ async def on_voice_transcription_received_listener(cog: 'GurtCog', guild: discor # For now, try to use a "general" or the first available text channel in the guild. # Or, if GURT is in a voice channel, it might have an associated text channel. # This part needs careful consideration for the best UX. - + text_channel = None if VOICE_DEDICATED_TEXT_CHANNEL_ENABLED: voice_gateway_cog = cog.bot.get_cog("VoiceGatewayCog") if voice_gateway_cog: - text_channel = voice_gateway_cog.get_dedicated_text_channel_for_guild(guild.id) + text_channel = voice_gateway_cog.get_dedicated_text_channel_for_guild( + guild.id + ) if text_channel: - print(f"Using dedicated voice text channel: {text_channel.name} ({text_channel.id})") + print( + f"Using dedicated voice text channel: {text_channel.name} ({text_channel.id})" + ) else: - print(f"Dedicated voice text channel feature is ON, but no channel found for guild {guild.id}. Aborting voice transcription processing.") - return # Do not proceed if dedicated channel is expected but not found + print( + f"Dedicated voice text channel feature is ON, but no channel found for guild {guild.id}. Aborting voice transcription processing." + ) + return # Do not proceed if dedicated channel is expected but not found else: - print("VoiceGatewayCog not found. Cannot get dedicated text channel. Aborting voice transcription processing.") + print( + "VoiceGatewayCog not found. Cannot get dedicated text channel. Aborting voice transcription processing." + ) return - else: # Fallback to old behavior if dedicated channel feature is off + else: # Fallback to old behavior if dedicated channel feature is off if guild: - if guild.system_channel and guild.system_channel.permissions_for(guild.me).send_messages: + if ( + guild.system_channel + and guild.system_channel.permissions_for(guild.me).send_messages + ): text_channel = guild.system_channel else: for channel in guild.text_channels: - if channel.name.lower() in ["general", "chat", "lounge", "discussion"] and channel.permissions_for(guild.me).send_messages: + if ( + channel.name.lower() + in ["general", "chat", "lounge", "discussion"] + and channel.permissions_for(guild.me).send_messages + ): text_channel = channel break - if not text_channel and guild.text_channels: + if not text_channel and guild.text_channels: text_channel = guild.text_channels[0] - + if not text_channel: - print(f"Could not find a suitable text channel in guild {guild.name} for voice transcription context. Aborting.") + print( + f"Could not find a suitable text channel in guild {guild.name} for voice transcription context. Aborting." + ) return # Check if this pseudo-channel context should be ignored (applies to both dedicated and fallback) if text_channel.id in IGNORED_CHANNEL_IDS: - print(f"Skipping voice transcription as target context channel {text_channel.name} ({text_channel.id}) is ignored.") + print( + f"Skipping voice transcription as target context channel {text_channel.name} ({text_channel.id}) is ignored." + ) return # Construct a pseudo-message object or dictionary # This needs to be compatible with what get_ai_response and format_message expect. # We'll create a dictionary similar to what format_message would produce. - + # Create a mock discord.Message object for format_message and get_ai_response # This is a bit hacky but helps reuse existing logic. class PseudoMessage: @@ -793,17 +1176,17 @@ async def on_voice_transcription_received_listener(cog: 'GurtCog', guild: discor self.channel = channel self.guild = guild_obj self.created_at = created_at - self.id = id_val # Needs a unique ID, timestamp can work - self.reference = None # No reply context for voice + self.id = id_val # Needs a unique ID, timestamp can work + self.reference = None # No reply context for voice self.attachments = [] self.embeds = [] self.stickers = [] self.reactions = [] - self.mentions = [] # Could parse mentions from text if needed + self.mentions = [] # Could parse mentions from text if needed self.mention_everyone = "@everyone" in content - self.role_mentions = [] # Could parse role mentions - self.channel_mentions = [] # Could parse channel mentions - self.flags = discord.MessageFlags._from_value(0) # Default flags + self.role_mentions = [] # Could parse role mentions + self.channel_mentions = [] # Could parse channel mentions + self.flags = discord.MessageFlags._from_value(0) # Default flags self.type = discord.MessageType.default self.pinned = False self.tts = False @@ -813,20 +1196,26 @@ async def on_voice_transcription_received_listener(cog: 'GurtCog', guild: discor self.components = [] self.interaction = None self.webhook_id = None - self.jump_url = f"https://discord.com/channels/{guild.id}/{channel.id}/{id_val}" # Approximate + self.jump_url = f"https://discord.com/channels/{guild.id}/{channel.id}/{id_val}" # Approximate - def to_reference(self, fail_if_not_exists: bool = True): # Add fail_if_not_exists - return discord.MessageReference(message_id=self.id, channel_id=self.channel.id, guild_id=self.guild.id, fail_if_not_exists=fail_if_not_exists) + def to_reference( + self, fail_if_not_exists: bool = True + ): # Add fail_if_not_exists + return discord.MessageReference( + message_id=self.id, + channel_id=self.channel.id, + guild_id=self.guild.id, + fail_if_not_exists=fail_if_not_exists, + ) - - pseudo_msg_id = int(time.time() * 1000000) # Create a somewhat unique ID + pseudo_msg_id = int(time.time() * 1000000) # Create a somewhat unique ID pseudo_message_obj = PseudoMessage( author=user, content=text, - channel=text_channel, # Use the determined text channel for context + channel=text_channel, # Use the determined text channel for context guild_obj=guild, created_at=discord.utils.utcnow(), - id_val=pseudo_msg_id + id_val=pseudo_msg_id, ) # Update cog's current_channel for the context of this interaction @@ -835,67 +1224,106 @@ async def on_voice_transcription_received_listener(cog: 'GurtCog', guild: discor # --- Cache the transcribed voice message as if it were a text message --- try: - formatted_pseudo_message = format_message(cog, pseudo_message_obj) # Use utility function + formatted_pseudo_message = format_message( + cog, pseudo_message_obj + ) # Use utility function # Ensure channel_id and user_id are correctly sourced from the pseudo_message_obj or its components msg_channel_id = pseudo_message_obj.channel.id - msg_user_id = pseudo_message_obj.author.id # This is a discord.User/Member object + msg_user_id = ( + pseudo_message_obj.author.id + ) # This is a discord.User/Member object # Deduplicate by message ID before appending (using helper from on_message_listener) # Note: _dedup_and_append might need to be accessible here or its logic replicated. # For simplicity, direct append, assuming pseudo_msg_id is unique enough for this context. # If _dedup_and_append is not directly usable, simple append is a starting point. # Consider making _dedup_and_append a static method or utility if widely needed. - + # Helper for deduplication (copied from on_message_listener for now) def _dedup_and_append_local(cache_deque, msg_dict_to_add): if not any(m.get("id") == msg_dict_to_add.get("id") for m in cache_deque): cache_deque.append(msg_dict_to_add) - _dedup_and_append_local(cog.message_cache['by_channel'].setdefault(msg_channel_id, deque(maxlen=CONTEXT_WINDOW_SIZE)), formatted_pseudo_message) - _dedup_and_append_local(cog.message_cache['by_user'].setdefault(msg_user_id, deque(maxlen=CONTEXT_WINDOW_SIZE*2)), formatted_pseudo_message) # User cache might be larger - _dedup_and_append_local(cog.message_cache['global_recent'], formatted_pseudo_message) + _dedup_and_append_local( + cog.message_cache["by_channel"].setdefault( + msg_channel_id, deque(maxlen=CONTEXT_WINDOW_SIZE) + ), + formatted_pseudo_message, + ) + _dedup_and_append_local( + cog.message_cache["by_user"].setdefault( + msg_user_id, deque(maxlen=CONTEXT_WINDOW_SIZE * 2) + ), + formatted_pseudo_message, + ) # User cache might be larger + _dedup_and_append_local( + cog.message_cache["global_recent"], formatted_pseudo_message + ) # No thread_id for pseudo_message currently # No mention check for pseudo_message currently - cog.conversation_history.setdefault(msg_channel_id, deque(maxlen=CONTEXT_WINDOW_SIZE)).append(formatted_pseudo_message) - - cog.channel_activity[msg_channel_id] = time.time() # Update activity timestamp + cog.conversation_history.setdefault( + msg_channel_id, deque(maxlen=CONTEXT_WINDOW_SIZE) + ).append(formatted_pseudo_message) + + cog.channel_activity[msg_channel_id] = time.time() # Update activity timestamp cog.user_conversation_mapping.setdefault(msg_user_id, set()).add(msg_channel_id) if msg_channel_id not in cog.active_conversations: - cog.active_conversations[msg_channel_id] = {'participants': set(), 'start_time': time.time(), 'last_activity': time.time(), 'topic': None} - cog.active_conversations[msg_channel_id]['participants'].add(msg_user_id) - cog.active_conversations[msg_channel_id]['last_activity'] = time.time() + cog.active_conversations[msg_channel_id] = { + "participants": set(), + "start_time": time.time(), + "last_activity": time.time(), + "topic": None, + } + cog.active_conversations[msg_channel_id]["participants"].add(msg_user_id) + cog.active_conversations[msg_channel_id]["last_activity"] = time.time() - print(f"Cached voice transcription from {user.name} into history of channel {text_channel.name} ({msg_channel_id}).") + print( + f"Cached voice transcription from {user.name} into history of channel {text_channel.name} ({msg_channel_id})." + ) # --- Add message to semantic memory (if applicable) --- - if text and cog.memory_manager.semantic_collection: # Check if 'text' (original transcription) is not empty + if ( + text and cog.memory_manager.semantic_collection + ): # Check if 'text' (original transcription) is not empty semantic_metadata = { - "user_id": str(msg_user_id), "user_name": user.name, "display_name": user.display_name, - "channel_id": str(msg_channel_id), "channel_name": getattr(text_channel, 'name', 'VoiceContext'), + "user_id": str(msg_user_id), + "user_name": user.name, + "display_name": user.display_name, + "channel_id": str(msg_channel_id), + "channel_name": getattr(text_channel, "name", "VoiceContext"), "guild_id": str(guild.id) if guild else None, "timestamp": pseudo_message_obj.created_at.timestamp(), - "is_voice_transcription": True # Add a flag + "is_voice_transcription": True, # Add a flag } asyncio.create_task( cog.memory_manager.add_message_embedding( - message_id=str(pseudo_message_obj.id), formatted_message_data=formatted_pseudo_message, metadata=semantic_metadata + message_id=str(pseudo_message_obj.id), + formatted_message_data=formatted_pseudo_message, + metadata=semantic_metadata, ) ) - print(f"Scheduled voice transcription from {user.name} for semantic embedding.") + print( + f"Scheduled voice transcription from {user.name} for semantic embedding." + ) except Exception as e: print(f"Error during voice transcription caching/embedding: {e}") import traceback + traceback.print_exc() # --- End Caching & Embedding --- - + try: # Process the transcribed text as if it were a regular message # The get_ai_response function will handle tool calls, including speak_in_voice_channel - print(f"Processing transcribed text from {user.name} via get_ai_response: '{text}'") - response_dict, sticker_ids_to_send = await get_ai_response(cog, pseudo_message_obj) + print( + f"Processing transcribed text from {user.name} via get_ai_response: '{text}'" + ) + response_dict, sticker_ids_to_send = await get_ai_response( + cog, pseudo_message_obj + ) final_response_data = response_dict.get("final_response") error_msg = response_dict.get("error") @@ -908,7 +1336,7 @@ async def on_voice_transcription_received_listener(cog: 'GurtCog', guild: discor if final_response_data and final_response_data.get("should_respond"): response_text = final_response_data.get("content", "") - + # If GURT is in a voice channel in this guild, it might have already decided to speak # via a tool call within get_ai_response (if speak_in_voice_channel was called). # If not, and there's text, we could make it speak here as a fallback, @@ -918,66 +1346,107 @@ async def on_voice_transcription_received_listener(cog: 'GurtCog', guild: discor # Force speak the response if it's from a voice transcription context speak_tool_func = cog.TOOL_MAPPING.get("speak_in_voice_channel") if speak_tool_func: - print(f"Forcing voice response for transcription: '{response_text[:50]}...'") - speak_result = await speak_tool_func(cog, text_to_speak=response_text) - + print( + f"Forcing voice response for transcription: '{response_text[:50]}...'" + ) + speak_result = await speak_tool_func( + cog, text_to_speak=response_text + ) + if speak_result.get("status") == "success": - print(f"Successfully forced voice response. Text log handled by speak_in_voice_channel tool if enabled.") + print( + f"Successfully forced voice response. Text log handled by speak_in_voice_channel tool if enabled." + ) # The speak_in_voice_channel tool will log to the dedicated text channel # if VOICE_LOG_SPEECH_TO_DEDICATED_CHANNEL is true. # No need to send separately from here if that config is true. # If VOICE_LOG_SPEECH_TO_DEDICATED_CHANNEL is false, no text log of GURT's speech will appear. else: - print(f"Forced speak_in_voice_channel failed: {speak_result.get('error')}") + print( + f"Forced speak_in_voice_channel failed: {speak_result.get('error')}" + ) # Fallback: if speaking failed, send it as text to the dedicated channel # so the user at least gets a response. try: - fallback_msg = await text_channel.send(f"(Voice output failed) GURT: {response_text}") - print(f"Sent fallback text response to {text_channel.name} for voice transcription failure.") + fallback_msg = await text_channel.send( + f"(Voice output failed) GURT: {response_text}" + ) + print( + f"Sent fallback text response to {text_channel.name} for voice transcription failure." + ) # Cache this fallback text response bot_response_cache_entry = format_message(cog, fallback_msg) - cog.message_cache['by_channel'][text_channel.id].append(bot_response_cache_entry) - cog.message_cache['global_recent'].append(bot_response_cache_entry) + cog.message_cache["by_channel"][text_channel.id].append( + bot_response_cache_entry + ) + cog.message_cache["global_recent"].append( + bot_response_cache_entry + ) cog.bot_last_spoke[text_channel.id] = time.time() except Exception as send_fallback_err: - print(f"Error sending fallback text for voice failure: {send_fallback_err}") + print( + f"Error sending fallback text for voice failure: {send_fallback_err}" + ) else: - print("speak_in_voice_channel tool not found. Sending text response as fallback.") + print( + "speak_in_voice_channel tool not found. Sending text response as fallback." + ) try: # Fallback to text if tool is missing - fallback_msg = await text_channel.send(f"(Voice tool missing) GURT: {response_text}") - print(f"Sent fallback text response to {text_channel.name} due to missing voice tool.") + fallback_msg = await text_channel.send( + f"(Voice tool missing) GURT: {response_text}" + ) + print( + f"Sent fallback text response to {text_channel.name} due to missing voice tool." + ) # Cache this fallback text response bot_response_cache_entry = format_message(cog, fallback_msg) - cog.message_cache['by_channel'][text_channel.id].append(bot_response_cache_entry) - cog.message_cache['global_recent'].append(bot_response_cache_entry) + cog.message_cache["by_channel"][text_channel.id].append( + bot_response_cache_entry + ) + cog.message_cache["global_recent"].append( + bot_response_cache_entry + ) cog.bot_last_spoke[text_channel.id] = time.time() except Exception as send_fallback_err3: - print(f"Error sending fallback text for missing voice tool: {send_fallback_err3}") - + print( + f"Error sending fallback text for missing voice tool: {send_fallback_err3}" + ) + # Handle reactions if any (similar to on_message) emoji_to_react = final_response_data.get("react_with_emoji") if emoji_to_react and isinstance(emoji_to_react, str): # React to the pseudo_message or a real message if one was sent? # For simplicity, let's assume reaction isn't the primary mode for voice. - print(f"Voice transcription AI suggested reaction: {emoji_to_react} (currently not implemented for voice-originated interactions)") + print( + f"Voice transcription AI suggested reaction: {emoji_to_react} (currently not implemented for voice-originated interactions)" + ) except Exception as e: print(f"Error in on_voice_transcription_received_listener: {e}") import traceback + traceback.print_exc() finally: - cog.current_channel = original_current_channel # Restore original current_channel + cog.current_channel = ( + original_current_channel # Restore original current_channel + ) -async def on_voice_state_update_listener(cog: 'GurtCog', member: discord.Member, before: discord.VoiceState, after: discord.VoiceState): +async def on_voice_state_update_listener( + cog: "GurtCog", + member: discord.Member, + before: discord.VoiceState, + after: discord.VoiceState, +): """Listener for voice state updates (e.g., user joining/leaving VC).""" - from .config import IGNORED_CHANNEL_IDS # To respect ignored channels if applicable + from .config import IGNORED_CHANNEL_IDS # To respect ignored channels if applicable + # We need access to tools, so we'd call them via cog.bot.get_cog("Gurt").tool_name or similar # For now, let's assume tools are called through a helper or directly if GurtCog has them. # This listener might trigger GURT to use join_voice_channel or leave_voice_channel tools. - if member.bot: # Ignore bots, including GURT itself + if member.bot: # Ignore bots, including GURT itself return guild = member.guild @@ -985,68 +1454,98 @@ async def on_voice_state_update_listener(cog: 'GurtCog', member: discord.Member, # Scenario 1: User joins a voice channel if not before.channel and after.channel: - print(f"User {member.name} joined voice channel {after.channel.name} in guild {guild.name}") + print( + f"User {member.name} joined voice channel {after.channel.name} in guild {guild.name}" + ) # Conditions for GURT to consider auto-joining: # 1. GURT is not already in a voice channel in this guild OR is in the same channel. # 2. The user who joined is someone GURT is actively interacting with or has high relationship. # 3. The target voice channel is not an ignored context. - - if after.channel.id in IGNORED_CHANNEL_IDS: # Or some other form of channel permission check - print(f"GURT will not auto-join {after.channel.name} as it's an ignored/restricted context.") + + if ( + after.channel.id in IGNORED_CHANNEL_IDS + ): # Or some other form of channel permission check + print( + f"GURT will not auto-join {after.channel.name} as it's an ignored/restricted context." + ) return # Check if GURT should consider joining this user # Simple check: is user in recent conversation participants? is_interacting_user = False if guild.id in cog.active_conversations: - if member.id in cog.active_conversations[guild.id]['participants']: + if member.id in cog.active_conversations[guild.id]["participants"]: is_interacting_user = True - + # More advanced: check relationship score # relationship_score = cog.user_relationships.get(str(min(member.id, cog.bot.user.id)), {}).get(str(max(member.id, cog.bot.user.id)), 0.0) # if relationship_score > SOME_THRESHOLD: is_interacting_user = True - + if not is_interacting_user: - print(f"User {member.name} joined VC, but GURT is not actively interacting with them. No auto-join.") + print( + f"User {member.name} joined VC, but GURT is not actively interacting with them. No auto-join." + ) return # If GURT is already in a VC in this guild but it's a *different* channel if gurt_vc and gurt_vc.is_connected() and gurt_vc.channel != after.channel: - print(f"GURT is already in {gurt_vc.channel.name}. Not auto-joining {member.name} in {after.channel.name} for now.") + print( + f"GURT is already in {gurt_vc.channel.name}. Not auto-joining {member.name} in {after.channel.name} for now." + ) # Future: Could ask LLM if it should move. return - + # If GURT is not in a VC in this guild, or is in the same one (but not listening perhaps) - if not gurt_vc or not gurt_vc.is_connected() or gurt_vc.channel != after.channel : - print(f"GURT considering auto-joining {member.name} in {after.channel.name}.") + if ( + not gurt_vc + or not gurt_vc.is_connected() + or gurt_vc.channel != after.channel + ): + print( + f"GURT considering auto-joining {member.name} in {after.channel.name}." + ) # Here, GURT's "brain" (LLM or simpler logic) would decide. # For simplicity, let's make it auto-join if the above conditions are met. # This would use the `join_voice_channel` tool. # The tool itself is async and defined in gurt/tools.py - + # To call a tool, we'd typically go through the AI's tool-using mechanism. # For an autonomous action, GURT's core logic would invoke the tool. # This listener is part of that core logic. - + # We need the GurtCog instance to call its methods or access tools. # The `cog` parameter *is* the GurtCog instance. - gurt_tool_cog = cog # The GurtCog instance itself - - if hasattr(gurt_tool_cog, 'TOOL_MAPPING') and "join_voice_channel" in gurt_tool_cog.TOOL_MAPPING: + gurt_tool_cog = cog # The GurtCog instance itself + + if ( + hasattr(gurt_tool_cog, "TOOL_MAPPING") + and "join_voice_channel" in gurt_tool_cog.TOOL_MAPPING + ): join_tool_func = gurt_tool_cog.TOOL_MAPPING["join_voice_channel"] - print(f"Attempting to auto-join VC {after.channel.id} for user {member.name}") + print( + f"Attempting to auto-join VC {after.channel.id} for user {member.name}" + ) try: # The tool function expects `cog` as its first arg, then params. # We pass `gurt_tool_cog` (which is `self` if this were a cog method) # and then the arguments for the tool. - tool_result = await join_tool_func(gurt_tool_cog, channel_id=str(after.channel.id)) + tool_result = await join_tool_func( + gurt_tool_cog, channel_id=str(after.channel.id) + ) if tool_result.get("status") == "success": - print(f"GURT successfully auto-joined {member.name} in {after.channel.name}.") + print( + f"GURT successfully auto-joined {member.name} in {after.channel.name}." + ) # Optionally, GURT could say "Hey [user], I'm here!" if "speak_in_voice_channel" in gurt_tool_cog.TOOL_MAPPING: - speak_tool_func = gurt_tool_cog.TOOL_MAPPING["speak_in_voice_channel"] - await speak_tool_func(gurt_tool_cog, text_to_speak=f"Hey {member.display_name}, I saw you joined so I came too!") + speak_tool_func = gurt_tool_cog.TOOL_MAPPING[ + "speak_in_voice_channel" + ] + await speak_tool_func( + gurt_tool_cog, + text_to_speak=f"Hey {member.display_name}, I saw you joined so I came too!", + ) else: print(f"GURT auto-join failed: {tool_result.get('error')}") except Exception as e: @@ -1054,17 +1553,24 @@ async def on_voice_state_update_listener(cog: 'GurtCog', member: discord.Member, else: print("join_voice_channel tool not found in GURT's TOOL_MAPPING.") - # Scenario 2: User leaves a voice channel GURT is in elif before.channel and not after.channel: # User disconnected from all VCs or was moved out by admin - print(f"User {member.name} left voice channel {before.channel.name} in guild {guild.name}") + print( + f"User {member.name} left voice channel {before.channel.name} in guild {guild.name}" + ) if gurt_vc and gurt_vc.is_connected() and gurt_vc.channel == before.channel: # Check if GURT is now alone in the channel - if len(gurt_vc.channel.members) == 1 and gurt_vc.channel.members[0] == guild.me: + if ( + len(gurt_vc.channel.members) == 1 + and gurt_vc.channel.members[0] == guild.me + ): print(f"GURT is now alone in {gurt_vc.channel.name}. Auto-leaving.") gurt_tool_cog = cog - if hasattr(gurt_tool_cog, 'TOOL_MAPPING') and "leave_voice_channel" in gurt_tool_cog.TOOL_MAPPING: + if ( + hasattr(gurt_tool_cog, "TOOL_MAPPING") + and "leave_voice_channel" in gurt_tool_cog.TOOL_MAPPING + ): leave_tool_func = gurt_tool_cog.TOOL_MAPPING["leave_voice_channel"] try: tool_result = await leave_tool_func(gurt_tool_cog) @@ -1079,40 +1585,75 @@ async def on_voice_state_update_listener(cog: 'GurtCog', member: discord.Member, # Scenario 3: User moves between voice channels elif before.channel and after.channel and before.channel != after.channel: - print(f"User {member.name} moved from {before.channel.name} to {after.channel.name} in guild {guild.name}") + print( + f"User {member.name} moved from {before.channel.name} to {after.channel.name} in guild {guild.name}" + ) # If GURT was in the `before.channel` with the user, and is now alone, it might leave. if gurt_vc and gurt_vc.is_connected() and gurt_vc.channel == before.channel: - if len(gurt_vc.channel.members) == 1 and gurt_vc.channel.members[0] == guild.me: - print(f"GURT is now alone in {before.channel.name} after {member.name} moved. Auto-leaving.") + if ( + len(gurt_vc.channel.members) == 1 + and gurt_vc.channel.members[0] == guild.me + ): + print( + f"GURT is now alone in {before.channel.name} after {member.name} moved. Auto-leaving." + ) # (Same auto-leave logic as above) gurt_tool_cog = cog - if hasattr(gurt_tool_cog, 'TOOL_MAPPING') and "leave_voice_channel" in gurt_tool_cog.TOOL_MAPPING: + if ( + hasattr(gurt_tool_cog, "TOOL_MAPPING") + and "leave_voice_channel" in gurt_tool_cog.TOOL_MAPPING + ): leave_tool_func = gurt_tool_cog.TOOL_MAPPING["leave_voice_channel"] - await leave_tool_func(gurt_tool_cog) # Fire and forget for now + await leave_tool_func(gurt_tool_cog) # Fire and forget for now # If GURT is not in a VC, or was not in the user's new VC, and user is interacting, consider joining `after.channel` # This logic is similar to Scenario 1. if after.channel.id not in IGNORED_CHANNEL_IDS: is_interacting_user = False if guild.id in cog.active_conversations: - if member.id in cog.active_conversations[guild.id]['participants']: + if member.id in cog.active_conversations[guild.id]["participants"]: is_interacting_user = True - + if is_interacting_user: - if not gurt_vc or not gurt_vc.is_connected() or gurt_vc.channel != after.channel: - print(f"GURT considering auto-joining {member.name} in their new channel {after.channel.name}.") + if ( + not gurt_vc + or not gurt_vc.is_connected() + or gurt_vc.channel != after.channel + ): + print( + f"GURT considering auto-joining {member.name} in their new channel {after.channel.name}." + ) gurt_tool_cog = cog - if hasattr(gurt_tool_cog, 'TOOL_MAPPING') and "join_voice_channel" in gurt_tool_cog.TOOL_MAPPING: - join_tool_func = gurt_tool_cog.TOOL_MAPPING["join_voice_channel"] + if ( + hasattr(gurt_tool_cog, "TOOL_MAPPING") + and "join_voice_channel" in gurt_tool_cog.TOOL_MAPPING + ): + join_tool_func = gurt_tool_cog.TOOL_MAPPING[ + "join_voice_channel" + ] try: - tool_result = await join_tool_func(gurt_tool_cog, channel_id=str(after.channel.id)) + tool_result = await join_tool_func( + gurt_tool_cog, channel_id=str(after.channel.id) + ) if tool_result.get("status") == "success": - print(f"GURT successfully auto-joined {member.name} in {after.channel.name} after they moved.") - if "speak_in_voice_channel" in gurt_tool_cog.TOOL_MAPPING: - speak_tool_func = gurt_tool_cog.TOOL_MAPPING["speak_in_voice_channel"] - await speak_tool_func(gurt_tool_cog, text_to_speak=f"Found you, {member.display_name}!") + print( + f"GURT successfully auto-joined {member.name} in {after.channel.name} after they moved." + ) + if ( + "speak_in_voice_channel" + in gurt_tool_cog.TOOL_MAPPING + ): + speak_tool_func = gurt_tool_cog.TOOL_MAPPING[ + "speak_in_voice_channel" + ] + await speak_tool_func( + gurt_tool_cog, + text_to_speak=f"Found you, {member.display_name}!", + ) else: - print(f"GURT auto-join (move) failed: {tool_result.get('error')}") + print( + f"GURT auto-join (move) failed: {tool_result.get('error')}" + ) except Exception as e: print(f"Error during GURT auto-join (move) attempt: {e}") else: diff --git a/gurt/memory.py b/gurt/memory.py index 4a0c2d6..10f4973 100644 --- a/gurt/memory.py +++ b/gurt/memory.py @@ -5,10 +5,10 @@ import importlib.util # Get the absolute path to gurt_memory.py parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -gurt_memory_path = os.path.join(parent_dir, 'gurt_memory.py') +gurt_memory_path = os.path.join(parent_dir, "gurt_memory.py") # Load the module dynamically -spec = importlib.util.spec_from_file_location('gurt_memory', gurt_memory_path) +spec = importlib.util.spec_from_file_location("gurt_memory", gurt_memory_path) gurt_memory = importlib.util.module_from_spec(spec) spec.loader.exec_module(gurt_memory) @@ -16,4 +16,4 @@ spec.loader.exec_module(gurt_memory) MemoryManager = gurt_memory.MemoryManager # Re-export the MemoryManager class -__all__ = ['MemoryManager'] +__all__ = ["MemoryManager"] diff --git a/gurt/prompt.py b/gurt/prompt.py index 6cac443..73b37c4 100644 --- a/gurt/prompt.py +++ b/gurt/prompt.py @@ -7,13 +7,16 @@ from typing import TYPE_CHECKING, Optional, List, Dict, Any # Import config and MemoryManager - use relative imports from .config import ( - BASELINE_PERSONALITY, MOOD_OPTIONS, CHANNEL_TOPIC_CACHE_TTL, - INTEREST_MAX_FOR_PROMPT, INTEREST_MIN_LEVEL_FOR_PROMPT + BASELINE_PERSONALITY, + MOOD_OPTIONS, + CHANNEL_TOPIC_CACHE_TTL, + INTEREST_MAX_FOR_PROMPT, + INTEREST_MIN_LEVEL_FOR_PROMPT, ) -from .memory import MemoryManager # Import from local memory.py +from .memory import MemoryManager # Import from local memory.py if TYPE_CHECKING: - from .cog import GurtCog # Import GurtCog for type hinting only + from .cog import GurtCog # Import GurtCog for type hinting only # --- Base System Prompt Parts --- @@ -190,7 +193,8 @@ OS: Arch Linux x86_64; Host: 1.0; Kernel: 6.14.5-arch1-1; Shell: bash 5.2.37; CP **Final Check:** Does this sound like something a real person would say in this chat? Is it coherent? Does it fit the vibe? Does it follow the rules? Keep it natural. """ -async def build_dynamic_system_prompt(cog: 'GurtCog', message: discord.Message) -> str: + +async def build_dynamic_system_prompt(cog: "GurtCog", message: discord.Message) -> str: """Builds the system prompt string with dynamic context, including persistent personality.""" channel_id = message.channel.id user_id = message.author.id @@ -199,7 +203,9 @@ async def build_dynamic_system_prompt(cog: 'GurtCog', message: discord.Message) persistent_traits = await cog.memory_manager.get_all_personality_traits() # Use baseline as default if DB fetch fails or is empty if not persistent_traits: - print("Warning: Failed to fetch persistent traits, using baseline defaults for prompt.") + print( + "Warning: Failed to fetch persistent traits, using baseline defaults for prompt." + ) persistent_traits = BASELINE_PERSONALITY.copy() else: # Ensure defaults are present if missing from DB @@ -208,12 +214,12 @@ async def build_dynamic_system_prompt(cog: 'GurtCog', message: discord.Message) print(f"Fetched persistent traits for prompt: {persistent_traits}") # --- Choose Base Prompt --- - if hasattr(cog.bot, 'minimal_prompt') and cog.bot.minimal_prompt: + if hasattr(cog.bot, "minimal_prompt") and cog.bot.minimal_prompt: # Use the minimal prompt if the flag is set print("Using MINIMAL system prompt.") base_prompt = MINIMAL_PROMPT_STATIC_PART # Note: Minimal prompt doesn't include dynamic personality traits section - prompt_dynamic_part = "" # No dynamic personality for minimal prompt + prompt_dynamic_part = "" # No dynamic personality for minimal prompt else: # Otherwise, build the full prompt with dynamic traits print("Using FULL system prompt with dynamic traits.") @@ -234,14 +240,20 @@ async def build_dynamic_system_prompt(cog: 'GurtCog', message: discord.Message) Let these traits gently shape *how* you communicate, but don't mention them explicitly. """ # Combine dynamic traits part with the full static part - base_prompt = PROMPT_STATIC_PART + prompt_dynamic_part # Append personality config + base_prompt = ( + PROMPT_STATIC_PART + prompt_dynamic_part + ) # Append personality config # --- Append Dynamic Context --- - system_context_parts = [base_prompt] # Start with the chosen base prompt + system_context_parts = [base_prompt] # Start with the chosen base prompt # Add current user information - system_context_parts.append(f"\nInteracting with: {message.author.display_name} (Username: {message.author.name}, ID: {message.author.id}).") - system_context_parts.append(f"For tools requiring user identification, use their Username ('{message.author.name}') or ID ('{message.author.id}'). For conversational mentions, '{message.author.display_name}' is fine.") + system_context_parts.append( + f"\nInteracting with: {message.author.display_name} (Username: {message.author.name}, ID: {message.author.id})." + ) + system_context_parts.append( + f"For tools requiring user identification, use their Username ('{message.author.name}') or ID ('{message.author.id}'). For conversational mentions, '{message.author.display_name}' is fine." + ) # Add current time now = datetime.datetime.now(datetime.timezone.utc) @@ -252,23 +264,40 @@ Let these traits gently shape *how* you communicate, but don't mention them expl # Add channel topic (with caching) channel_topic = None cached_topic = cog.channel_topics_cache.get(channel_id) - if cached_topic and time.time() - cached_topic["timestamp"] < CHANNEL_TOPIC_CACHE_TTL: + if ( + cached_topic + and time.time() - cached_topic["timestamp"] < CHANNEL_TOPIC_CACHE_TTL + ): channel_topic = cached_topic["topic"] - if channel_topic: print(f"Using cached channel topic for {channel_id}: {channel_topic}") + if channel_topic: + print(f"Using cached channel topic for {channel_id}: {channel_topic}") else: try: - if hasattr(cog, 'get_channel_info'): + if hasattr(cog, "get_channel_info"): # Ensure channel_id is passed as string if required by the tool/method - channel_info_result = await cog.get_channel_info(channel_id_str=str(channel_id)) + channel_info_result = await cog.get_channel_info( + channel_id_str=str(channel_id) + ) if not channel_info_result.get("error"): channel_topic = channel_info_result.get("topic") - cog.channel_topics_cache[channel_id] = {"topic": channel_topic, "timestamp": time.time()} - if channel_topic: print(f"Fetched and cached channel topic for {channel_id}: {channel_topic}") - else: print(f"Fetched and cached null topic for {channel_id}") + cog.channel_topics_cache[channel_id] = { + "topic": channel_topic, + "timestamp": time.time(), + } + if channel_topic: + print( + f"Fetched and cached channel topic for {channel_id}: {channel_topic}" + ) + else: + print(f"Fetched and cached null topic for {channel_id}") else: - print(f"Error in channel_info result for {channel_id}: {channel_info_result.get('error')}") + print( + f"Error in channel_info result for {channel_id}: {channel_info_result.get('error')}" + ) else: - print("Warning: GurtCog instance does not have get_channel_info method for prompt building.") + print( + "Warning: GurtCog instance does not have get_channel_info method for prompt building." + ) except Exception as e: print(f"Error fetching channel topic for {channel_id}: {e}") if channel_topic: @@ -277,47 +306,69 @@ Let these traits gently shape *how* you communicate, but don't mention them expl # Add active conversation topics (if available) channel_topics_data = cog.active_topics.get(channel_id) if channel_topics_data and channel_topics_data.get("topics"): - top_topics = sorted(channel_topics_data["topics"], key=lambda t: t.get("score", 0), reverse=True)[:3] + top_topics = sorted( + channel_topics_data["topics"], key=lambda t: t.get("score", 0), reverse=True + )[:3] if top_topics: - topics_str = ", ".join([f"{t['topic']}" for t in top_topics if 'topic' in t]) - system_context_parts.append(f"Current conversation topics seem to be around: {topics_str}.") + topics_str = ", ".join( + [f"{t['topic']}" for t in top_topics if "topic" in t] + ) + system_context_parts.append( + f"Current conversation topics seem to be around: {topics_str}." + ) # Add user-specific interest in these topics user_id_str = str(user_id) - user_interests = channel_topics_data.get("user_topic_interests", {}).get(user_id_str, []) + user_interests = channel_topics_data.get("user_topic_interests", {}).get( + user_id_str, [] + ) if user_interests: - user_topic_names = {interest["topic"] for interest in user_interests if "topic" in interest} - active_topic_names = {topic["topic"] for topic in top_topics if "topic" in topic} + user_topic_names = { + interest["topic"] for interest in user_interests if "topic" in interest + } + active_topic_names = { + topic["topic"] for topic in top_topics if "topic" in topic + } common_topics = user_topic_names.intersection(active_topic_names) if common_topics: topics_list_str = ", ".join(common_topics) - system_context_parts.append(f"{message.author.display_name} (Username: {message.author.name}) seems interested in: {topics_list_str}.") + system_context_parts.append( + f"{message.author.display_name} (Username: {message.author.name}) seems interested in: {topics_list_str}." + ) # Add conversation sentiment context (if available) if channel_id in cog.conversation_sentiment: channel_sentiment = cog.conversation_sentiment[channel_id] sentiment_str = f"The conversation vibe feels generally {channel_sentiment.get('overall', 'neutral')}" - intensity = channel_sentiment.get('intensity', 0.5) - if intensity > 0.7: sentiment_str += " (strongly so)" - elif intensity < 0.4: sentiment_str += " (mildly so)" - trend = channel_sentiment.get('recent_trend', 'stable') - if trend != "stable": sentiment_str += f", and seems to be {trend}" + intensity = channel_sentiment.get("intensity", 0.5) + if intensity > 0.7: + sentiment_str += " (strongly so)" + elif intensity < 0.4: + sentiment_str += " (mildly so)" + trend = channel_sentiment.get("recent_trend", "stable") + if trend != "stable": + sentiment_str += f", and seems to be {trend}" system_context_parts.append(sentiment_str + ".") user_id_str = str(user_id) user_sentiment = channel_sentiment.get("user_sentiments", {}).get(user_id_str) if user_sentiment: user_sentiment_str = f"{message.author.display_name} (Username: {message.author.name})'s recent messages seem {user_sentiment.get('sentiment', 'neutral')}" - user_intensity = user_sentiment.get('intensity', 0.5) - if user_intensity > 0.7: user_sentiment_str += " (strongly so)" + user_intensity = user_sentiment.get("intensity", 0.5) + if user_intensity > 0.7: + user_sentiment_str += " (strongly so)" system_context_parts.append(user_sentiment_str + ".") if user_sentiment.get("emotions"): emotions_str = ", ".join(user_sentiment["emotions"]) - system_context_parts.append(f"Detected emotions from {message.author.display_name} (Username: {message.author.name}) might include: {emotions_str}.") + system_context_parts.append( + f"Detected emotions from {message.author.display_name} (Username: {message.author.name}) might include: {emotions_str}." + ) # Briefly mention overall atmosphere if not neutral - if channel_sentiment.get('overall') != "neutral": - atmosphere_hint = f"Overall emotional atmosphere: {channel_sentiment['overall']}." + if channel_sentiment.get("overall") != "neutral": + atmosphere_hint = ( + f"Overall emotional atmosphere: {channel_sentiment['overall']}." + ) system_context_parts.append(atmosphere_hint) # Add conversation summary (if available and valid) @@ -335,21 +386,29 @@ Let these traits gently shape *how* you communicate, but don't mention them expl key_1, key_2 = tuple(sorted((user_id_str, bot_id_str))) relationship_score = cog.user_relationships.get(key_1, {}).get(key_2, 0.0) - if relationship_score is not None: # Check if score exists - score_val = float(relationship_score) # Ensure it's a float - if score_val <= 20: relationship_level = "kinda new/acquaintance" - elif score_val <= 60: relationship_level = "familiar/friends" - else: relationship_level = "close/besties" - system_context_parts.append(f"Your relationship with {message.author.display_name} (Username: {message.author.name}) is: {relationship_level} (Score: {score_val:.1f}/100). Adjust your tone.") + if relationship_score is not None: # Check if score exists + score_val = float(relationship_score) # Ensure it's a float + if score_val <= 20: + relationship_level = "kinda new/acquaintance" + elif score_val <= 60: + relationship_level = "familiar/friends" + else: + relationship_level = "close/besties" + system_context_parts.append( + f"Your relationship with {message.author.display_name} (Username: {message.author.name}) is: {relationship_level} (Score: {score_val:.1f}/100). Adjust your tone." + ) except Exception as e: print(f"Error retrieving relationship score for prompt injection: {e}") - # Add user facts (Combine semantic and recent, limit total) try: user_id_str = str(user_id) - semantic_user_facts = await cog.memory_manager.get_user_facts(user_id_str, context=message.content, limit=5) # Limit semantic fetch - recent_user_facts = await cog.memory_manager.get_user_facts(user_id_str, limit=cog.memory_manager.max_user_facts) # Use manager's limit for recent + semantic_user_facts = await cog.memory_manager.get_user_facts( + user_id_str, context=message.content, limit=5 + ) # Limit semantic fetch + recent_user_facts = await cog.memory_manager.get_user_facts( + user_id_str, limit=cog.memory_manager.max_user_facts + ) # Use manager's limit for recent # Combine, prioritizing recent, then semantic, de-duplicating combined_user_facts = [] @@ -364,64 +423,79 @@ Let these traits gently shape *how* you communicate, but don't mention them expl seen_facts.add(fact) # Apply final limit from MemoryManager config - final_user_facts = combined_user_facts[:cog.memory_manager.max_user_facts] + final_user_facts = combined_user_facts[: cog.memory_manager.max_user_facts] if final_user_facts: facts_str = "; ".join(final_user_facts) - system_context_parts.append(f"Stuff you remember about {message.author.display_name} (Username: {message.author.name}): {facts_str}") + system_context_parts.append( + f"Stuff you remember about {message.author.display_name} (Username: {message.author.name}): {facts_str}" + ) except Exception as e: print(f"Error retrieving combined user facts for prompt injection: {e}") # Add relevant general facts (Combine semantic and recent, limit total) try: - semantic_general_facts = await cog.memory_manager.get_general_facts(context=message.content, limit=5) - recent_general_facts = await cog.memory_manager.get_general_facts(limit=5) # Limit recent fetch too + semantic_general_facts = await cog.memory_manager.get_general_facts( + context=message.content, limit=5 + ) + recent_general_facts = await cog.memory_manager.get_general_facts( + limit=5 + ) # Limit recent fetch too # Combine and deduplicate, prioritizing recent combined_general_facts = [] seen_facts = set() for fact in recent_general_facts: - if fact not in seen_facts: - combined_general_facts.append(fact) - seen_facts.add(fact) + if fact not in seen_facts: + combined_general_facts.append(fact) + seen_facts.add(fact) for fact in semantic_general_facts: - if fact not in seen_facts: - combined_general_facts.append(fact) - seen_facts.add(fact) + if fact not in seen_facts: + combined_general_facts.append(fact) + seen_facts.add(fact) # Apply a final combined limit (e.g., 7 total) final_general_facts = combined_general_facts[:7] if final_general_facts: facts_str = "; ".join(final_general_facts) - system_context_parts.append(f"Relevant general knowledge/context: {facts_str}") + system_context_parts.append( + f"Relevant general knowledge/context: {facts_str}" + ) except Exception as e: - print(f"Error retrieving combined general facts for prompt injection: {e}") + print(f"Error retrieving combined general facts for prompt injection: {e}") # Add Gurt's current interests (if enabled and available) if INTEREST_MAX_FOR_PROMPT > 0: try: interests = await cog.memory_manager.get_interests( - limit=INTEREST_MAX_FOR_PROMPT, - min_level=INTEREST_MIN_LEVEL_FOR_PROMPT + limit=INTEREST_MAX_FOR_PROMPT, min_level=INTEREST_MIN_LEVEL_FOR_PROMPT ) if interests: - interests_str = ", ".join([f"{topic} ({level:.1f})" for topic, level in interests]) - system_context_parts.append(f"Topics you're currently interested in (higher score = more): {interests_str}. Maybe weave these in?") + interests_str = ", ".join( + [f"{topic} ({level:.1f})" for topic, level in interests] + ) + system_context_parts.append( + f"Topics you're currently interested in (higher score = more): {interests_str}. Maybe weave these in?" + ) except Exception as e: print(f"Error retrieving interests for prompt injection: {e}") # Add known custom emojis and stickers to prompt try: - if hasattr(cog, 'emoji_manager'): + if hasattr(cog, "emoji_manager"): known_emojis = await cog.emoji_manager.list_emojis() if known_emojis: emoji_names = ", ".join(known_emojis.keys()) - system_context_parts.append(f"Available Custom Emojis: [{emoji_names}]. You can use these by name in your 'content'.") + system_context_parts.append( + f"Available Custom Emojis: [{emoji_names}]. You can use these by name in your 'content'." + ) known_stickers = await cog.emoji_manager.list_stickers() if known_stickers: sticker_names = ", ".join(known_stickers.keys()) - system_context_parts.append(f"Available Custom Stickers: [{sticker_names}]. You can use these by name in your 'content'.") + system_context_parts.append( + f"Available Custom Stickers: [{sticker_names}]. You can use these by name in your 'content'." + ) except Exception as e: print(f"Error adding custom emoji/sticker list to prompt: {e}") diff --git a/gurt/tools.py b/gurt/tools.py index 7ead288..d7e8fcb 100644 --- a/gurt/tools.py +++ b/gurt/tools.py @@ -9,27 +9,37 @@ import aiohttp import datetime import time import re -import traceback # Added for error logging -import importlib.util # Added to fix Pylance error +import traceback # Added for error logging +import importlib.util # Added to fix Pylance error from collections import defaultdict -from typing import Dict, List, Any, Optional, Tuple, Union # Added Union -import base64 # Added for avatar data encoding +from typing import Dict, List, Any, Optional, Tuple, Union # Added Union +import base64 # Added for avatar data encoding # Third-party imports for tools from tavily import TavilyClient import docker -import aiodocker # Use aiodocker for async operations -from asteval import Interpreter # Added for calculate tool +import aiodocker # Use aiodocker for async operations +from asteval import Interpreter # Added for calculate tool # Relative imports from within the gurt package and parent -from .memory import MemoryManager # Import from local memory.py +from .memory import MemoryManager # Import from local memory.py from .config import ( - TOOLS, # Import the TOOLS list - TAVILY_API_KEY, PISTON_API_URL, PISTON_API_KEY, SAFETY_CHECK_MODEL, - DOCKER_EXEC_IMAGE, DOCKER_COMMAND_TIMEOUT, DOCKER_CPU_LIMIT, DOCKER_MEM_LIMIT, - SUMMARY_CACHE_TTL, SUMMARY_API_TIMEOUT, DEFAULT_MODEL, + TOOLS, # Import the TOOLS list + TAVILY_API_KEY, + PISTON_API_URL, + PISTON_API_KEY, + SAFETY_CHECK_MODEL, + DOCKER_EXEC_IMAGE, + DOCKER_COMMAND_TIMEOUT, + DOCKER_CPU_LIMIT, + DOCKER_MEM_LIMIT, + SUMMARY_CACHE_TTL, + SUMMARY_API_TIMEOUT, + DEFAULT_MODEL, # Add these: - TAVILY_DEFAULT_SEARCH_DEPTH, TAVILY_DEFAULT_MAX_RESULTS, TAVILY_DISABLE_ADVANCED + TAVILY_DEFAULT_SEARCH_DEPTH, + TAVILY_DEFAULT_MAX_RESULTS, + TAVILY_DISABLE_ADVANCED, ) # Assume these helpers will be moved or are accessible via cog @@ -42,152 +52,241 @@ from .config import ( # to access things like cog.bot, cog.session, cog.current_channel, cog.memory_manager etc. # We will add 'cog' as the first parameter to each. -async def get_recent_messages(cog: commands.Cog, limit: int, channel_id: str = None) -> Dict[str, Any]: + +async def get_recent_messages( + cog: commands.Cog, limit: int, channel_id: str = None +) -> Dict[str, Any]: """Get recent messages from a Discord channel""" - from .utils import format_message # Import here to avoid circular dependency at module level + from .utils import ( + format_message, + ) # Import here to avoid circular dependency at module level + limit = min(max(1, limit), 100) try: if channel_id: channel = cog.bot.get_channel(int(channel_id)) - if not channel: return {"error": f"Channel {channel_id} not found"} + if not channel: + return {"error": f"Channel {channel_id} not found"} else: channel = cog.current_channel - if not channel: return {"error": "No current channel context"} + if not channel: + return {"error": "No current channel context"} messages = [] async for message in channel.history(limit=limit): - messages.append(format_message(cog, message)) # Use formatter + messages.append(format_message(cog, message)) # Use formatter return { - "channel": {"id": str(channel.id), "name": getattr(channel, 'name', 'DM Channel')}, - "messages": messages, "count": len(messages), - "timestamp": datetime.datetime.now().isoformat() + "channel": { + "id": str(channel.id), + "name": getattr(channel, "name", "DM Channel"), + }, + "messages": messages, + "count": len(messages), + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: - return {"error": f"Error retrieving messages: {str(e)}", "timestamp": datetime.datetime.now().isoformat()} + return { + "error": f"Error retrieving messages: {str(e)}", + "timestamp": datetime.datetime.now().isoformat(), + } -async def search_user_messages(cog: commands.Cog, user_id: str, limit: int, channel_id: str = None) -> Dict[str, Any]: + +async def search_user_messages( + cog: commands.Cog, user_id: str, limit: int, channel_id: str = None +) -> Dict[str, Any]: """Search for messages from a specific user""" - from .utils import format_message # Import here + from .utils import format_message # Import here + limit = min(max(1, limit), 100) try: if channel_id: channel = cog.bot.get_channel(int(channel_id)) - if not channel: return {"error": f"Channel {channel_id} not found"} + if not channel: + return {"error": f"Channel {channel_id} not found"} else: channel = cog.current_channel - if not channel: return {"error": "No current channel context"} + if not channel: + return {"error": "No current channel context"} - try: user_id_int = int(user_id) - except ValueError: return {"error": f"Invalid user ID: {user_id}"} + try: + user_id_int = int(user_id) + except ValueError: + return {"error": f"Invalid user ID: {user_id}"} messages = [] user_name = " " async for message in channel.history(limit=500): if message.author.id == user_id_int: - formatted_msg = format_message(cog, message) # Use formatter + formatted_msg = format_message(cog, message) # Use formatter messages.append(formatted_msg) - user_name = formatted_msg["author"]["name"] # Get name from formatted msg - if len(messages) >= limit: break + user_name = formatted_msg["author"][ + "name" + ] # Get name from formatted msg + if len(messages) >= limit: + break return { - "channel": {"id": str(channel.id), "name": getattr(channel, 'name', 'DM Channel')}, + "channel": { + "id": str(channel.id), + "name": getattr(channel, "name", "DM Channel"), + }, "user": {"id": user_id, "name": user_name}, - "messages": messages, "count": len(messages), - "timestamp": datetime.datetime.now().isoformat() + "messages": messages, + "count": len(messages), + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: - return {"error": f"Error searching user messages: {str(e)}", "timestamp": datetime.datetime.now().isoformat()} + return { + "error": f"Error searching user messages: {str(e)}", + "timestamp": datetime.datetime.now().isoformat(), + } -async def search_messages_by_content(cog: commands.Cog, search_term: str, limit: int, channel_id: str = None) -> Dict[str, Any]: + +async def search_messages_by_content( + cog: commands.Cog, search_term: str, limit: int, channel_id: str = None +) -> Dict[str, Any]: """Search for messages containing specific content""" - from .utils import format_message # Import here + from .utils import format_message # Import here + limit = min(max(1, limit), 100) try: if channel_id: channel = cog.bot.get_channel(int(channel_id)) - if not channel: return {"error": f"Channel {channel_id} not found"} + if not channel: + return {"error": f"Channel {channel_id} not found"} else: channel = cog.current_channel - if not channel: return {"error": "No current channel context"} + if not channel: + return {"error": "No current channel context"} messages = [] search_term_lower = search_term.lower() async for message in channel.history(limit=500): if search_term_lower in message.content.lower(): - messages.append(format_message(cog, message)) # Use formatter - if len(messages) >= limit: break + messages.append(format_message(cog, message)) # Use formatter + if len(messages) >= limit: + break return { - "channel": {"id": str(channel.id), "name": getattr(channel, 'name', 'DM Channel')}, + "channel": { + "id": str(channel.id), + "name": getattr(channel, "name", "DM Channel"), + }, "search_term": search_term, - "messages": messages, "count": len(messages), - "timestamp": datetime.datetime.now().isoformat() + "messages": messages, + "count": len(messages), + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: - return {"error": f"Error searching messages by content: {str(e)}", "timestamp": datetime.datetime.now().isoformat()} + return { + "error": f"Error searching messages by content: {str(e)}", + "timestamp": datetime.datetime.now().isoformat(), + } + async def get_channel_info(cog: commands.Cog, channel_id: str = None) -> Dict[str, Any]: """Get information about a Discord channel""" try: if channel_id: channel = cog.bot.get_channel(int(channel_id)) - if not channel: return {"error": f"Channel {channel_id} not found"} + if not channel: + return {"error": f"Channel {channel_id} not found"} else: channel = cog.current_channel - if not channel: return {"error": "No current channel context"} + if not channel: + return {"error": "No current channel context"} - channel_info = {"id": str(channel.id), "type": str(channel.type), "timestamp": datetime.datetime.now().isoformat()} - if isinstance(channel, discord.TextChannel): # Use isinstance for type checking - channel_info.update({ - "name": channel.name, "topic": channel.topic, "position": channel.position, - "nsfw": channel.is_nsfw(), - "category": {"id": str(channel.category_id), "name": channel.category.name} if channel.category else None, - "guild": {"id": str(channel.guild.id), "name": channel.guild.name, "member_count": channel.guild.member_count} - }) + channel_info = { + "id": str(channel.id), + "type": str(channel.type), + "timestamp": datetime.datetime.now().isoformat(), + } + if isinstance(channel, discord.TextChannel): # Use isinstance for type checking + channel_info.update( + { + "name": channel.name, + "topic": channel.topic, + "position": channel.position, + "nsfw": channel.is_nsfw(), + "category": ( + {"id": str(channel.category_id), "name": channel.category.name} + if channel.category + else None + ), + "guild": { + "id": str(channel.guild.id), + "name": channel.guild.name, + "member_count": channel.guild.member_count, + }, + } + ) elif isinstance(channel, discord.DMChannel): - channel_info.update({ - "type": "DM", - "recipient": {"id": str(channel.recipient.id), "name": channel.recipient.name, "display_name": channel.recipient.display_name} - }) + channel_info.update( + { + "type": "DM", + "recipient": { + "id": str(channel.recipient.id), + "name": channel.recipient.name, + "display_name": channel.recipient.display_name, + }, + } + ) # Add handling for other channel types (VoiceChannel, Thread, etc.) if needed return channel_info except Exception as e: - return {"error": f"Error getting channel info: {str(e)}", "timestamp": datetime.datetime.now().isoformat()} + return { + "error": f"Error getting channel info: {str(e)}", + "timestamp": datetime.datetime.now().isoformat(), + } -async def get_conversation_context(cog: commands.Cog, message_count: int, channel_id: str = None) -> Dict[str, Any]: + +async def get_conversation_context( + cog: commands.Cog, message_count: int, channel_id: str = None +) -> Dict[str, Any]: """Get the context of the current conversation in a channel""" - from .utils import format_message # Import here + from .utils import format_message # Import here + message_count = min(max(5, message_count), 50) try: if channel_id: channel = cog.bot.get_channel(int(channel_id)) - if not channel: return {"error": f"Channel {channel_id} not found"} + if not channel: + return {"error": f"Channel {channel_id} not found"} else: channel = cog.current_channel - if not channel: return {"error": "No current channel context"} + if not channel: + return {"error": "No current channel context"} messages = [] # Prefer cache if available - if channel.id in cog.message_cache['by_channel']: - messages = list(cog.message_cache['by_channel'][channel.id])[-message_count:] + if channel.id in cog.message_cache["by_channel"]: + messages = list(cog.message_cache["by_channel"][channel.id])[ + -message_count: + ] else: async for msg in channel.history(limit=message_count): messages.append(format_message(cog, msg)) messages.reverse() return { - "channel_id": str(channel.id), "channel_name": getattr(channel, 'name', 'DM Channel'), - "context_messages": messages, "count": len(messages), - "timestamp": datetime.datetime.now().isoformat() + "channel_id": str(channel.id), + "channel_name": getattr(channel, "name", "DM Channel"), + "context_messages": messages, + "count": len(messages), + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: return {"error": f"Error getting conversation context: {str(e)}"} -async def get_thread_context(cog: commands.Cog, thread_id: str, message_count: int) -> Dict[str, Any]: + +async def get_thread_context( + cog: commands.Cog, thread_id: str, message_count: int +) -> Dict[str, Any]: """Get the context of a thread conversation""" - from .utils import format_message # Import here + from .utils import format_message # Import here + message_count = min(max(5, message_count), 50) try: thread = cog.bot.get_channel(int(thread_id)) @@ -195,23 +294,28 @@ async def get_thread_context(cog: commands.Cog, thread_id: str, message_count: i return {"error": f"Thread {thread_id} not found or is not a thread"} messages = [] - if thread.id in cog.message_cache['by_thread']: - messages = list(cog.message_cache['by_thread'][thread.id])[-message_count:] + if thread.id in cog.message_cache["by_thread"]: + messages = list(cog.message_cache["by_thread"][thread.id])[-message_count:] else: async for msg in thread.history(limit=message_count): messages.append(format_message(cog, msg)) messages.reverse() return { - "thread_id": str(thread.id), "thread_name": thread.name, + "thread_id": str(thread.id), + "thread_name": thread.name, "parent_channel_id": str(thread.parent_id), - "context_messages": messages, "count": len(messages), - "timestamp": datetime.datetime.now().isoformat() + "context_messages": messages, + "count": len(messages), + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: return {"error": f"Error getting thread context: {str(e)}"} -async def get_user_interaction_history(cog: commands.Cog, user_id_1: str, limit: int, user_id_2: str = None) -> Dict[str, Any]: + +async def get_user_interaction_history( + cog: commands.Cog, user_id_1: str, limit: int, user_id_2: str = None +) -> Dict[str, Any]: """Get the history of interactions between two users (or user and bot)""" limit = min(max(1, limit), 50) try: @@ -220,51 +324,82 @@ async def get_user_interaction_history(cog: commands.Cog, user_id_1: str, limit: interactions = [] # Simplified: Search global cache - for msg_data in list(cog.message_cache['global_recent']): - author_id = int(msg_data['author']['id']) - mentioned_ids = [int(m['id']) for m in msg_data.get('mentions', [])] - replied_to_author_id = int(msg_data.get('replied_to_author_id')) if msg_data.get('replied_to_author_id') else None + for msg_data in list(cog.message_cache["global_recent"]): + author_id = int(msg_data["author"]["id"]) + mentioned_ids = [int(m["id"]) for m in msg_data.get("mentions", [])] + replied_to_author_id = ( + int(msg_data.get("replied_to_author_id")) + if msg_data.get("replied_to_author_id") + else None + ) is_interaction = False - if (author_id == user_id_1_int and replied_to_author_id == user_id_2_int) or \ - (author_id == user_id_2_int and replied_to_author_id == user_id_1_int): is_interaction = True - elif (author_id == user_id_1_int and user_id_2_int in mentioned_ids) or \ - (author_id == user_id_2_int and user_id_1_int in mentioned_ids): is_interaction = True + if ( + author_id == user_id_1_int and replied_to_author_id == user_id_2_int + ) or (author_id == user_id_2_int and replied_to_author_id == user_id_1_int): + is_interaction = True + elif (author_id == user_id_1_int and user_id_2_int in mentioned_ids) or ( + author_id == user_id_2_int and user_id_1_int in mentioned_ids + ): + is_interaction = True if is_interaction: interactions.append(msg_data) - if len(interactions) >= limit: break + if len(interactions) >= limit: + break user1 = await cog.bot.fetch_user(user_id_1_int) user2 = await cog.bot.fetch_user(user_id_2_int) return { - "user_1": {"id": str(user_id_1_int), "name": user1.name if user1 else "Unknown"}, - "user_2": {"id": str(user_id_2_int), "name": user2.name if user2 else "Unknown"}, - "interactions": interactions, "count": len(interactions), - "timestamp": datetime.datetime.now().isoformat() + "user_1": { + "id": str(user_id_1_int), + "name": user1.name if user1 else "Unknown", + }, + "user_2": { + "id": str(user_id_2_int), + "name": user2.name if user2 else "Unknown", + }, + "interactions": interactions, + "count": len(interactions), + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: return {"error": f"Error getting user interaction history: {str(e)}"} -async def get_conversation_summary(cog: commands.Cog, channel_id: str = None, message_limit: int = 25) -> Dict[str, Any]: + +async def get_conversation_summary( + cog: commands.Cog, channel_id: str = None, message_limit: int = 25 +) -> Dict[str, Any]: """Generates and returns a summary of the recent conversation in a channel using an LLM call.""" - from .config import SUMMARY_RESPONSE_SCHEMA, DEFAULT_MODEL # Import schema and model - from .api import get_internal_ai_json_response # Import here + from .config import ( + SUMMARY_RESPONSE_SCHEMA, + DEFAULT_MODEL, + ) # Import schema and model + from .api import get_internal_ai_json_response # Import here + try: - target_channel_id_str = channel_id or (str(cog.current_channel.id) if cog.current_channel else None) - if not target_channel_id_str: return {"error": "No channel context"} + target_channel_id_str = channel_id or ( + str(cog.current_channel.id) if cog.current_channel else None + ) + if not target_channel_id_str: + return {"error": "No channel context"} target_channel_id = int(target_channel_id_str) channel = cog.bot.get_channel(target_channel_id) - if not channel: return {"error": f"Channel {target_channel_id_str} not found"} + if not channel: + return {"error": f"Channel {target_channel_id_str} not found"} now = time.time() cached_data = cog.conversation_summaries.get(target_channel_id) if cached_data and (now - cached_data.get("timestamp", 0) < SUMMARY_CACHE_TTL): print(f"Returning cached summary for channel {target_channel_id}") return { - "channel_id": target_channel_id_str, "summary": cached_data.get("summary", "Cache error"), - "source": "cache", "timestamp": datetime.datetime.fromtimestamp(cached_data.get("timestamp", now)).isoformat() + "channel_id": target_channel_id_str, + "summary": cached_data.get("summary", "Cache error"), + "source": "cache", + "timestamp": datetime.datetime.fromtimestamp( + cached_data.get("timestamp", now) + ).isoformat(), } print(f"Generating new summary for channel {target_channel_id}") @@ -275,31 +410,46 @@ async def get_conversation_summary(cog: commands.Cog, channel_id: str = None, me async for msg in channel.history(limit=message_limit): recent_messages_text.append(f"{msg.author.display_name}: {msg.content}") recent_messages_text.reverse() - except discord.Forbidden: return {"error": f"Missing permissions in channel {target_channel_id_str}"} - except Exception as hist_e: return {"error": f"Error fetching history: {str(hist_e)}"} + except discord.Forbidden: + return {"error": f"Missing permissions in channel {target_channel_id_str}"} + except Exception as hist_e: + return {"error": f"Error fetching history: {str(hist_e)}"} if not recent_messages_text: summary = "No recent messages found." - cog.conversation_summaries[target_channel_id] = {"summary": summary, "timestamp": time.time()} - return {"channel_id": target_channel_id_str, "summary": summary, "source": "generated (empty)", "timestamp": datetime.datetime.now().isoformat()} + cog.conversation_summaries[target_channel_id] = { + "summary": summary, + "timestamp": time.time(), + } + return { + "channel_id": target_channel_id_str, + "summary": summary, + "source": "generated (empty)", + "timestamp": datetime.datetime.now().isoformat(), + } conversation_context = "\n".join(recent_messages_text) summarization_prompt = f"Summarize the main points and current topic of this Discord chat snippet:\n\n---\n{conversation_context}\n---\n\nSummary:" # Use get_internal_ai_json_response prompt_messages = [ - {"role": "system", "content": "You are an expert summarizer. Provide a concise summary of the following conversation."}, - {"role": "user", "content": summarization_prompt} + { + "role": "system", + "content": "You are an expert summarizer. Provide a concise summary of the following conversation.", + }, + {"role": "user", "content": summarization_prompt}, ] summary_data = await get_internal_ai_json_response( cog=cog, prompt_messages=prompt_messages, task_description=f"Summarization for channel {target_channel_id}", - response_schema_dict=SUMMARY_RESPONSE_SCHEMA['schema'], # Pass the schema dict - model_name_override=DEFAULT_MODEL, # Consider a cheaper/faster model if needed + response_schema_dict=SUMMARY_RESPONSE_SCHEMA[ + "schema" + ], # Pass the schema dict + model_name_override=DEFAULT_MODEL, # Consider a cheaper/faster model if needed temperature=0.3, - max_tokens=200 # Adjust as needed + max_tokens=200, # Adjust as needed ) # Unpack the tuple, we only need the parsed data here summary_parsed_data, _ = summary_data if summary_data else (None, None) @@ -309,12 +459,20 @@ async def get_conversation_summary(cog: commands.Cog, channel_id: str = None, me summary = summary_parsed_data["summary"].strip() print(f"Summary generated for {target_channel_id}: {summary[:100]}...") else: - error_detail = f"Invalid format or missing 'summary' key. Parsed Response: {summary_parsed_data}" # Log parsed data on error + error_detail = f"Invalid format or missing 'summary' key. Parsed Response: {summary_parsed_data}" # Log parsed data on error summary = f"Failed summary for {target_channel_id}. Error: {error_detail}" print(summary) - cog.conversation_summaries[target_channel_id] = {"summary": summary, "timestamp": time.time()} - return {"channel_id": target_channel_id_str, "summary": summary, "source": "generated", "timestamp": datetime.datetime.now().isoformat()} + cog.conversation_summaries[target_channel_id] = { + "summary": summary, + "timestamp": time.time(), + } + return { + "channel_id": target_channel_id_str, + "summary": summary, + "source": "generated", + "timestamp": datetime.datetime.now().isoformat(), + } except Exception as e: error_msg = f"General error in get_conversation_summary: {str(e)}" @@ -322,76 +480,122 @@ async def get_conversation_summary(cog: commands.Cog, channel_id: str = None, me traceback.print_exc() return {"error": error_msg} -async def get_message_context(cog: commands.Cog, message_id: str, before_count: int = 5, after_count: int = 5) -> Dict[str, Any]: + +async def get_message_context( + cog: commands.Cog, message_id: str, before_count: int = 5, after_count: int = 5 +) -> Dict[str, Any]: """Get the context (messages before and after) around a specific message""" - from .utils import format_message # Import here + from .utils import format_message # Import here + before_count = min(max(1, before_count), 25) after_count = min(max(1, after_count), 25) try: target_message = None channel = cog.current_channel - if not channel: return {"error": "No current channel context"} + if not channel: + return {"error": "No current channel context"} try: message_id_int = int(message_id) target_message = await channel.fetch_message(message_id_int) - except discord.NotFound: return {"error": f"Message {message_id} not found in {channel.id}"} - except discord.Forbidden: return {"error": f"No permission for message {message_id} in {channel.id}"} - except ValueError: return {"error": f"Invalid message ID: {message_id}"} - if not target_message: return {"error": f"Message {message_id} not fetched"} + except discord.NotFound: + return {"error": f"Message {message_id} not found in {channel.id}"} + except discord.Forbidden: + return {"error": f"No permission for message {message_id} in {channel.id}"} + except ValueError: + return {"error": f"Invalid message ID: {message_id}"} + if not target_message: + return {"error": f"Message {message_id} not fetched"} - messages_before = [format_message(cog, msg) async for msg in channel.history(limit=before_count, before=target_message)] + messages_before = [ + format_message(cog, msg) + async for msg in channel.history(limit=before_count, before=target_message) + ] messages_before.reverse() - messages_after = [format_message(cog, msg) async for msg in channel.history(limit=after_count, after=target_message)] + messages_after = [ + format_message(cog, msg) + async for msg in channel.history(limit=after_count, after=target_message) + ] return { "target_message": format_message(cog, target_message), - "messages_before": messages_before, "messages_after": messages_after, - "channel_id": str(channel.id), "timestamp": datetime.datetime.now().isoformat() + "messages_before": messages_before, + "messages_after": messages_after, + "channel_id": str(channel.id), + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: return {"error": f"Error getting message context: {str(e)}"} -async def web_search(cog: commands.Cog, query: str, search_depth: str = TAVILY_DEFAULT_SEARCH_DEPTH, max_results: int = TAVILY_DEFAULT_MAX_RESULTS, topic: str = "general", include_domains: Optional[List[str]] = None, exclude_domains: Optional[List[str]] = None, include_answer: bool = True, include_raw_content: bool = False, include_images: bool = False) -> Dict[str, Any]: + +async def web_search( + cog: commands.Cog, + query: str, + search_depth: str = TAVILY_DEFAULT_SEARCH_DEPTH, + max_results: int = TAVILY_DEFAULT_MAX_RESULTS, + topic: str = "general", + include_domains: Optional[List[str]] = None, + exclude_domains: Optional[List[str]] = None, + include_answer: bool = True, + include_raw_content: bool = False, + include_images: bool = False, +) -> Dict[str, Any]: """Search the web using Tavily API""" - if not hasattr(cog, 'tavily_client') or not cog.tavily_client: - return {"error": "Tavily client not initialized.", "timestamp": datetime.datetime.now().isoformat()} + if not hasattr(cog, "tavily_client") or not cog.tavily_client: + return { + "error": "Tavily client not initialized.", + "timestamp": datetime.datetime.now().isoformat(), + } # Cost control / Logging for advanced search final_search_depth = search_depth if search_depth.lower() == "advanced": if TAVILY_DISABLE_ADVANCED: - print(f"Warning: Advanced Tavily search requested but disabled by config. Falling back to basic.") + print( + f"Warning: Advanced Tavily search requested but disabled by config. Falling back to basic." + ) final_search_depth = "basic" else: - print(f"Performing advanced Tavily search (cost: 10 credits) for query: '{query}'") + print( + f"Performing advanced Tavily search (cost: 10 credits) for query: '{query}'" + ) elif search_depth.lower() != "basic": - print(f"Warning: Invalid search_depth '{search_depth}' provided. Using 'basic'.") + print( + f"Warning: Invalid search_depth '{search_depth}' provided. Using 'basic'." + ) final_search_depth = "basic" # Validate max_results - final_max_results = max(5, min(20, max_results)) # Clamp between 5 and 20 + final_max_results = max(5, min(20, max_results)) # Clamp between 5 and 20 try: # Pass parameters to Tavily search response = await asyncio.to_thread( cog.tavily_client.search, query=query, - search_depth=final_search_depth, # Use validated depth - max_results=final_max_results, # Use validated results count + search_depth=final_search_depth, # Use validated depth + max_results=final_max_results, # Use validated results count topic=topic, include_domains=include_domains, exclude_domains=exclude_domains, include_answer=include_answer, include_raw_content=include_raw_content, - include_images=include_images + include_images=include_images, ) # Extract relevant information from results results = [] for r in response.get("results", []): - result = {"title": r.get("title"), "url": r.get("url"), "content": r.get("content"), "score": r.get("score"), "published_date": r.get("published_date")} - if include_raw_content: result["raw_content"] = r.get("raw_content") - if include_images: result["images"] = r.get("images") + result = { + "title": r.get("title"), + "url": r.get("url"), + "content": r.get("content"), + "score": r.get("score"), + "published_date": r.get("published_date"), + } + if include_raw_content: + result["raw_content"] = r.get("raw_content") + if include_images: + result["images"] = r.get("images") results.append(result) return { @@ -408,83 +612,142 @@ async def web_search(cog: commands.Cog, query: str, search_depth: str = TAVILY_D "answer": response.get("answer"), "follow_up_questions": response.get("follow_up_questions"), "count": len(results), - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: error_message = f"Error during Tavily search for '{query}': {str(e)}" print(error_message) - return {"error": error_message, "timestamp": datetime.datetime.now().isoformat()} + return { + "error": error_message, + "timestamp": datetime.datetime.now().isoformat(), + } -async def remember_user_fact(cog: commands.Cog, user_id: str, fact: str) -> Dict[str, Any]: + +async def remember_user_fact( + cog: commands.Cog, user_id: str, fact: str +) -> Dict[str, Any]: """Stores a fact about a user using the MemoryManager.""" - if not user_id or not fact: return {"error": "user_id and fact required."} + if not user_id or not fact: + return {"error": "user_id and fact required."} print(f"Remembering fact for user {user_id}: '{fact}'") try: result = await cog.memory_manager.add_user_fact(user_id, fact) - if result.get("status") == "added": return {"status": "success", "user_id": user_id, "fact_added": fact} - elif result.get("status") == "duplicate": return {"status": "duplicate", "user_id": user_id, "fact": fact} - elif result.get("status") == "limit_reached": return {"status": "success", "user_id": user_id, "fact_added": fact, "note": "Oldest fact deleted."} - else: return {"error": result.get("error", "Unknown MemoryManager error")} + if result.get("status") == "added": + return {"status": "success", "user_id": user_id, "fact_added": fact} + elif result.get("status") == "duplicate": + return {"status": "duplicate", "user_id": user_id, "fact": fact} + elif result.get("status") == "limit_reached": + return { + "status": "success", + "user_id": user_id, + "fact_added": fact, + "note": "Oldest fact deleted.", + } + else: + return {"error": result.get("error", "Unknown MemoryManager error")} except Exception as e: error_message = f"Error calling MemoryManager for user fact {user_id}: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} + async def get_user_facts(cog: commands.Cog, user_id: str) -> Dict[str, Any]: """Retrieves stored facts about a user using the MemoryManager.""" - if not user_id: return {"error": "user_id required."} + if not user_id: + return {"error": "user_id required."} print(f"Retrieving facts for user {user_id}") try: - user_facts = await cog.memory_manager.get_user_facts(user_id) # Context not needed for basic retrieval tool - return {"user_id": user_id, "facts": user_facts, "count": len(user_facts), "timestamp": datetime.datetime.now().isoformat()} + user_facts = await cog.memory_manager.get_user_facts( + user_id + ) # Context not needed for basic retrieval tool + return { + "user_id": user_id, + "facts": user_facts, + "count": len(user_facts), + "timestamp": datetime.datetime.now().isoformat(), + } except Exception as e: - error_message = f"Error calling MemoryManager for user facts {user_id}: {str(e)}" - print(error_message); traceback.print_exc() + error_message = ( + f"Error calling MemoryManager for user facts {user_id}: {str(e)}" + ) + print(error_message) + traceback.print_exc() return {"error": error_message} + async def remember_general_fact(cog: commands.Cog, fact: str) -> Dict[str, Any]: """Stores a general fact using the MemoryManager.""" - if not fact: return {"error": "fact required."} + if not fact: + return {"error": "fact required."} print(f"Remembering general fact: '{fact}'") try: result = await cog.memory_manager.add_general_fact(fact) - if result.get("status") == "added": return {"status": "success", "fact_added": fact} - elif result.get("status") == "duplicate": return {"status": "duplicate", "fact": fact} - elif result.get("status") == "limit_reached": return {"status": "success", "fact_added": fact, "note": "Oldest fact deleted."} - else: return {"error": result.get("error", "Unknown MemoryManager error")} + if result.get("status") == "added": + return {"status": "success", "fact_added": fact} + elif result.get("status") == "duplicate": + return {"status": "duplicate", "fact": fact} + elif result.get("status") == "limit_reached": + return { + "status": "success", + "fact_added": fact, + "note": "Oldest fact deleted.", + } + else: + return {"error": result.get("error", "Unknown MemoryManager error")} except Exception as e: error_message = f"Error calling MemoryManager for general fact: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} -async def get_general_facts(cog: commands.Cog, query: Optional[str] = None, limit: Optional[int] = 10) -> Dict[str, Any]: + +async def get_general_facts( + cog: commands.Cog, query: Optional[str] = None, limit: Optional[int] = 10 +) -> Dict[str, Any]: """Retrieves stored general facts using the MemoryManager.""" print(f"Retrieving general facts (query='{query}', limit={limit})") limit = min(max(1, limit or 10), 50) try: - general_facts = await cog.memory_manager.get_general_facts(query=query, limit=limit) # Context not needed here - return {"query": query, "facts": general_facts, "count": len(general_facts), "timestamp": datetime.datetime.now().isoformat()} + general_facts = await cog.memory_manager.get_general_facts( + query=query, limit=limit + ) # Context not needed here + return { + "query": query, + "facts": general_facts, + "count": len(general_facts), + "timestamp": datetime.datetime.now().isoformat(), + } except Exception as e: error_message = f"Error calling MemoryManager for general facts: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} + async def _is_moderator(cog: commands.Cog, user_id: str) -> bool: """Checks if a user has any of the configured moderator roles.""" - if not cog.current_channel or not isinstance(cog.current_channel, discord.abc.GuildChannel): - return False # Cannot check roles outside of a guild + if not cog.current_channel or not isinstance( + cog.current_channel, discord.abc.GuildChannel + ): + return False # Cannot check roles outside of a guild guild = cog.current_channel.guild - if not guild: return False + if not guild: + return False try: member_id = int(user_id) member = guild.get_member(member_id) or await guild.fetch_member(member_id) - if not member: return False + if not member: + return False # Get moderator role names from config from .config import MODERATOR_ROLE_NAMES + if not MODERATOR_ROLE_NAMES: - print("Warning: MODERATOR_ROLE_NAMES not configured in config.py. No moderator check performed.") + print( + "Warning: MODERATOR_ROLE_NAMES not configured in config.py. No moderator check performed." + ) return False # Check if the user has any of the moderator roles @@ -498,67 +761,134 @@ async def _is_moderator(cog: commands.Cog, user_id: str) -> bool: traceback.print_exc() return False -async def timeout_user(cog: commands.Cog, user_id: str, duration_minutes: int, reason: Optional[str] = None, requesting_user_id: Optional[str] = None) -> Dict[str, Any]: + +async def timeout_user( + cog: commands.Cog, + user_id: str, + duration_minutes: int, + reason: Optional[str] = None, + requesting_user_id: Optional[str] = None, +) -> Dict[str, Any]: """Times out a user in the current server. Requires a moderator's user_id for authorization.""" - if not cog.current_channel or not isinstance(cog.current_channel, discord.abc.GuildChannel): + if not cog.current_channel or not isinstance( + cog.current_channel, discord.abc.GuildChannel + ): return {"error": "Cannot timeout outside of a server."} guild = cog.current_channel.guild - if not guild: return {"error": "Could not determine server."} - if not 1 <= duration_minutes <= 1440: return {"error": "Duration must be 1-1440 minutes."} + if not guild: + return {"error": "Could not determine server."} + if not 1 <= duration_minutes <= 1440: + return {"error": "Duration must be 1-1440 minutes."} # Moderator check if not requesting_user_id: return {"error": "A requesting_user_id is required to use the timeout tool."} if not await _is_moderator(cog, requesting_user_id): - return {"error": f"User {requesting_user_id} is not authorized to use the timeout tool (not a moderator)."} + return { + "error": f"User {requesting_user_id} is not authorized to use the timeout tool (not a moderator)." + } try: member_id = int(user_id) - member = guild.get_member(member_id) or await guild.fetch_member(member_id) # Fetch if not cached - if not member: return {"error": f"User {user_id} not found in server."} - if member == cog.bot.user: return {"error": "lol i cant timeout myself vro"} - if member.id == guild.owner_id: return {"error": f"Cannot timeout owner {member.display_name}."} + member = guild.get_member(member_id) or await guild.fetch_member( + member_id + ) # Fetch if not cached + if not member: + return {"error": f"User {user_id} not found in server."} + if member == cog.bot.user: + return {"error": "lol i cant timeout myself vro"} + if member.id == guild.owner_id: + return {"error": f"Cannot timeout owner {member.display_name}."} bot_member = guild.me - if not bot_member.guild_permissions.moderate_members: return {"error": "I lack permission to timeout."} - if bot_member.id != guild.owner_id and bot_member.top_role <= member.top_role: return {"error": f"Cannot timeout {member.display_name} (role hierarchy)."} + if not bot_member.guild_permissions.moderate_members: + return {"error": "I lack permission to timeout."} + if bot_member.id != guild.owner_id and bot_member.top_role <= member.top_role: + return {"error": f"Cannot timeout {member.display_name} (role hierarchy)."} until = discord.utils.utcnow() + datetime.timedelta(minutes=duration_minutes) timeout_reason = reason or "gurt felt like it" await member.timeout(until, reason=timeout_reason) - print(f"Timed out {member.display_name} ({user_id}) for {duration_minutes} mins. Reason: {timeout_reason}") - return {"status": "success", "user_timed_out": member.display_name, "user_id": user_id, "duration_minutes": duration_minutes, "reason": timeout_reason} - except ValueError: return {"error": f"Invalid user ID: {user_id}"} - except discord.NotFound: return {"error": f"User {user_id} not found in server."} - except discord.Forbidden as e: print(f"Forbidden error timeout {user_id}: {e}"); return {"error": f"Permission error timeout {user_id}."} - except discord.HTTPException as e: print(f"API error timeout {user_id}: {e}"); return {"error": f"API error timeout {user_id}: {e}"} - except Exception as e: print(f"Unexpected error timeout {user_id}: {e}"); traceback.print_exc(); return {"error": f"Unexpected error timeout {user_id}: {str(e)}"} + print( + f"Timed out {member.display_name} ({user_id}) for {duration_minutes} mins. Reason: {timeout_reason}" + ) + return { + "status": "success", + "user_timed_out": member.display_name, + "user_id": user_id, + "duration_minutes": duration_minutes, + "reason": timeout_reason, + } + except ValueError: + return {"error": f"Invalid user ID: {user_id}"} + except discord.NotFound: + return {"error": f"User {user_id} not found in server."} + except discord.Forbidden as e: + print(f"Forbidden error timeout {user_id}: {e}") + return {"error": f"Permission error timeout {user_id}."} + except discord.HTTPException as e: + print(f"API error timeout {user_id}: {e}") + return {"error": f"API error timeout {user_id}: {e}"} + except Exception as e: + print(f"Unexpected error timeout {user_id}: {e}") + traceback.print_exc() + return {"error": f"Unexpected error timeout {user_id}: {str(e)}"} -async def remove_timeout(cog: commands.Cog, user_id: str, reason: Optional[str] = None) -> Dict[str, Any]: + +async def remove_timeout( + cog: commands.Cog, user_id: str, reason: Optional[str] = None +) -> Dict[str, Any]: """Removes an active timeout from a user.""" - if not cog.current_channel or not isinstance(cog.current_channel, discord.abc.GuildChannel): + if not cog.current_channel or not isinstance( + cog.current_channel, discord.abc.GuildChannel + ): return {"error": "Cannot remove timeout outside of a server."} guild = cog.current_channel.guild - if not guild: return {"error": "Could not determine server."} + if not guild: + return {"error": "Could not determine server."} try: member_id = int(user_id) member = guild.get_member(member_id) or await guild.fetch_member(member_id) - if not member: return {"error": f"User {user_id} not found."} + if not member: + return {"error": f"User {user_id} not found."} # Define bot_member before using it bot_member = guild.me - if not bot_member.guild_permissions.moderate_members: return {"error": "I lack permission to remove timeouts."} - if member.timed_out_until is None: return {"status": "not_timed_out", "user_id": user_id, "user_name": member.display_name} + if not bot_member.guild_permissions.moderate_members: + return {"error": "I lack permission to remove timeouts."} + if member.timed_out_until is None: + return { + "status": "not_timed_out", + "user_id": user_id, + "user_name": member.display_name, + } timeout_reason = reason or "Gurt decided to be nice." - await member.timeout(None, reason=timeout_reason) # None removes timeout - print(f"Removed timeout from {member.display_name} ({user_id}). Reason: {timeout_reason}") - return {"status": "success", "user_timeout_removed": member.display_name, "user_id": user_id, "reason": timeout_reason} - except ValueError: return {"error": f"Invalid user ID: {user_id}"} - except discord.NotFound: return {"error": f"User {user_id} not found."} - except discord.Forbidden as e: print(f"Forbidden error remove timeout {user_id}: {e}"); return {"error": f"Permission error remove timeout {user_id}."} - except discord.HTTPException as e: print(f"API error remove timeout {user_id}: {e}"); return {"error": f"API error remove timeout {user_id}: {e}"} - except Exception as e: print(f"Unexpected error remove timeout {user_id}: {e}"); traceback.print_exc(); return {"error": f"Unexpected error remove timeout {user_id}: {str(e)}"} + await member.timeout(None, reason=timeout_reason) # None removes timeout + print( + f"Removed timeout from {member.display_name} ({user_id}). Reason: {timeout_reason}" + ) + return { + "status": "success", + "user_timeout_removed": member.display_name, + "user_id": user_id, + "reason": timeout_reason, + } + except ValueError: + return {"error": f"Invalid user ID: {user_id}"} + except discord.NotFound: + return {"error": f"User {user_id} not found."} + except discord.Forbidden as e: + print(f"Forbidden error remove timeout {user_id}: {e}") + return {"error": f"Permission error remove timeout {user_id}."} + except discord.HTTPException as e: + print(f"API error remove timeout {user_id}: {e}") + return {"error": f"API error remove timeout {user_id}: {e}"} + except Exception as e: + print(f"Unexpected error remove timeout {user_id}: {e}") + traceback.print_exc() + return {"error": f"Unexpected error remove timeout {user_id}: {str(e)}"} + async def calculate(cog: commands.Cog, expression: str) -> Dict[str, Any]: """Evaluates a mathematical expression using asteval.""" @@ -567,32 +897,45 @@ async def calculate(cog: commands.Cog, expression: str) -> Dict[str, Any]: try: result = aeval(expression) if aeval.error: - error_details = '; '.join(err.get_error() for err in aeval.error) + error_details = "; ".join(err.get_error() for err in aeval.error) error_message = f"Calculation error: {error_details}" print(error_message) return {"error": error_message, "expression": expression} - if isinstance(result, (int, float, complex)): result_str = str(result) - else: result_str = repr(result) # Fallback + if isinstance(result, (int, float, complex)): + result_str = str(result) + else: + result_str = repr(result) # Fallback print(f"Calculation result: {result_str}") return {"expression": expression, "result": result_str, "status": "success"} except Exception as e: error_message = f"Unexpected error during calculation: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message, "expression": expression} + async def run_python_code(cog: commands.Cog, code: str) -> Dict[str, Any]: """Executes a Python code snippet using the Piston API.""" - if not PISTON_API_URL: return {"error": "Piston API URL not configured (PISTON_API_URL)."} - if not cog.session: return {"error": "aiohttp session not initialized."} + if not PISTON_API_URL: + return {"error": "Piston API URL not configured (PISTON_API_URL)."} + if not cog.session: + return {"error": "aiohttp session not initialized."} print(f"Executing Python via Piston: {code[:100]}...") - payload = {"language": "python", "version": "3.10.0", "files": [{"name": "main.py", "content": code}]} + payload = { + "language": "python", + "version": "3.10.0", + "files": [{"name": "main.py", "content": code}], + } headers = {"Content-Type": "application/json"} - if PISTON_API_KEY: headers["Authorization"] = PISTON_API_KEY + if PISTON_API_KEY: + headers["Authorization"] = PISTON_API_KEY try: - async with cog.session.post(PISTON_API_URL, headers=headers, json=payload, timeout=20) as response: + async with cog.session.post( + PISTON_API_URL, headers=headers, json=payload, timeout=20 + ) as response: if response.status == 200: data = await response.json() run_info = data.get("run", {}) @@ -603,67 +946,132 @@ async def run_python_code(cog: commands.Cog, code: str) -> Dict[str, Any]: signal = run_info.get("signal") full_stderr = (compile_info.get("stderr", "") + "\n" + stderr).strip() max_len = 500 - stdout_trunc = stdout[:max_len] + ('...' if len(stdout) > max_len else '') - stderr_trunc = full_stderr[:max_len] + ('...' if len(full_stderr) > max_len else '') - result = {"status": "success" if exit_code == 0 and not signal else "execution_error", "stdout": stdout_trunc, "stderr": stderr_trunc, "exit_code": exit_code, "signal": signal} + stdout_trunc = stdout[:max_len] + ( + "..." if len(stdout) > max_len else "" + ) + stderr_trunc = full_stderr[:max_len] + ( + "..." if len(full_stderr) > max_len else "" + ) + result = { + "status": ( + "success" + if exit_code == 0 and not signal + else "execution_error" + ), + "stdout": stdout_trunc, + "stderr": stderr_trunc, + "exit_code": exit_code, + "signal": signal, + } print(f"Piston execution result: {result}") return result else: error_text = await response.text() - error_message = f"Piston API error (Status {response.status}): {error_text[:200]}" + error_message = ( + f"Piston API error (Status {response.status}): {error_text[:200]}" + ) print(error_message) return {"error": error_message} - except asyncio.TimeoutError: print("Piston API timed out."); return {"error": "Piston API timed out."} - except aiohttp.ClientError as e: print(f"Piston network error: {e}"); return {"error": f"Network error connecting to Piston: {str(e)}"} - except Exception as e: print(f"Unexpected Piston error: {e}"); traceback.print_exc(); return {"error": f"Unexpected error during Python execution: {str(e)}"} + except asyncio.TimeoutError: + print("Piston API timed out.") + return {"error": "Piston API timed out."} + except aiohttp.ClientError as e: + print(f"Piston network error: {e}") + return {"error": f"Network error connecting to Piston: {str(e)}"} + except Exception as e: + print(f"Unexpected Piston error: {e}") + traceback.print_exc() + return {"error": f"Unexpected error during Python execution: {str(e)}"} -async def create_poll(cog: commands.Cog, question: str, options: List[str]) -> Dict[str, Any]: + +async def create_poll( + cog: commands.Cog, question: str, options: List[str] +) -> Dict[str, Any]: """Creates a simple poll message.""" - if not cog.current_channel: return {"error": "No current channel context."} - if not isinstance(cog.current_channel, discord.abc.Messageable): return {"error": "Channel not messageable."} - if not isinstance(options, list) or not 2 <= len(options) <= 10: return {"error": "Poll needs 2-10 options."} + if not cog.current_channel: + return {"error": "No current channel context."} + if not isinstance(cog.current_channel, discord.abc.Messageable): + return {"error": "Channel not messageable."} + if not isinstance(options, list) or not 2 <= len(options) <= 10: + return {"error": "Poll needs 2-10 options."} if isinstance(cog.current_channel, discord.abc.GuildChannel): bot_member = cog.current_channel.guild.me - if not cog.current_channel.permissions_for(bot_member).send_messages or \ - not cog.current_channel.permissions_for(bot_member).add_reactions: + if ( + not cog.current_channel.permissions_for(bot_member).send_messages + or not cog.current_channel.permissions_for(bot_member).add_reactions + ): return {"error": "Missing permissions for poll."} try: poll_content = f"**📊 Poll: {question}**\n\n" number_emojis = ["1️⃣", "2️⃣", "3️⃣", "4️⃣", "5️⃣", "6️⃣", "7️⃣", "8️⃣", "9️⃣", "🔟"] - for i, option in enumerate(options): poll_content += f"{number_emojis[i]} {option}\n" + for i, option in enumerate(options): + poll_content += f"{number_emojis[i]} {option}\n" poll_message = await cog.current_channel.send(poll_content) print(f"Sent poll {poll_message.id}: {question}") - for i in range(len(options)): await poll_message.add_reaction(number_emojis[i]); await asyncio.sleep(0.1) - return {"status": "success", "message_id": str(poll_message.id), "question": question, "options_count": len(options)} - except discord.Forbidden: print("Poll Forbidden"); return {"error": "Forbidden: Missing permissions for poll."} - except discord.HTTPException as e: print(f"Poll API error: {e}"); return {"error": f"API error creating poll: {e}"} - except Exception as e: print(f"Poll unexpected error: {e}"); traceback.print_exc(); return {"error": f"Unexpected error creating poll: {str(e)}"} + for i in range(len(options)): + await poll_message.add_reaction(number_emojis[i]) + await asyncio.sleep(0.1) + return { + "status": "success", + "message_id": str(poll_message.id), + "question": question, + "options_count": len(options), + } + except discord.Forbidden: + print("Poll Forbidden") + return {"error": "Forbidden: Missing permissions for poll."} + except discord.HTTPException as e: + print(f"Poll API error: {e}") + return {"error": f"API error creating poll: {e}"} + except Exception as e: + print(f"Poll unexpected error: {e}") + traceback.print_exc() + return {"error": f"Unexpected error creating poll: {str(e)}"} + # Helper function to convert memory string (e.g., "128m") to bytes def parse_mem_limit(mem_limit_str: str) -> Optional[int]: - if not mem_limit_str: return None + if not mem_limit_str: + return None mem_limit_str = mem_limit_str.lower() - if mem_limit_str.endswith('m'): - try: return int(mem_limit_str[:-1]) * 1024 * 1024 - except ValueError: return None - elif mem_limit_str.endswith('g'): - try: return int(mem_limit_str[:-1]) * 1024 * 1024 * 1024 - except ValueError: return None - try: return int(mem_limit_str) # Assume bytes if no suffix - except ValueError: return None + if mem_limit_str.endswith("m"): + try: + return int(mem_limit_str[:-1]) * 1024 * 1024 + except ValueError: + return None + elif mem_limit_str.endswith("g"): + try: + return int(mem_limit_str[:-1]) * 1024 * 1024 * 1024 + except ValueError: + return None + try: + return int(mem_limit_str) # Assume bytes if no suffix + except ValueError: + return None + async def _check_command_safety(cog: commands.Cog, command: str) -> Dict[str, Any]: """Uses a secondary AI call to check if a command is potentially harmful.""" - from .api import get_internal_ai_json_response # Import here - print(f"Performing AI safety check for command: '{command}' using model {SAFETY_CHECK_MODEL}") + from .api import get_internal_ai_json_response # Import here + + print( + f"Performing AI safety check for command: '{command}' using model {SAFETY_CHECK_MODEL}" + ) safety_schema = { "type": "object", "properties": { - "is_safe": {"type": "boolean", "description": "True if safe for restricted container, False otherwise."}, - "reason": {"type": "string", "description": "Brief explanation why the command is safe or unsafe."} - }, "required": ["is_safe", "reason"] + "is_safe": { + "type": "boolean", + "description": "True if safe for restricted container, False otherwise.", + }, + "reason": { + "type": "string", + "description": "Brief explanation why the command is safe or unsafe.", + }, + }, + "required": ["is_safe", "reason"], } # Enhanced system prompt with more examples of safe commands system_prompt_content = ( @@ -680,37 +1088,47 @@ async def _check_command_safety(cog: commands.Cog, command: str) -> Dict[str, An ) prompt_messages = [ {"role": "system", "content": system_prompt_content}, - {"role": "user", "content": f"Analyze safety of this command: ```\n{command}\n```"} + { + "role": "user", + "content": f"Analyze safety of this command: ```\n{command}\n```", + }, ] # Update to receive tuple: (parsed_data, raw_text) safety_response_parsed, safety_response_raw = await get_internal_ai_json_response( cog=cog, prompt_messages=prompt_messages, task_description="Command Safety Check", - response_schema_dict=safety_schema, # Pass the schema dict directly - model_name_override=SAFETY_CHECK_MODEL, - temperature=0.1, - max_tokens=1000 # Increased token limit - ) + response_schema_dict=safety_schema, # Pass the schema dict directly + model_name_override=SAFETY_CHECK_MODEL, + temperature=0.1, + max_tokens=1000, # Increased token limit + ) # --- Log the raw response text --- - print(f"--- Raw AI Safety Check Response Text ---\n{safety_response_raw}\n---------------------------------------") + print( + f"--- Raw AI Safety Check Response Text ---\n{safety_response_raw}\n---------------------------------------" + ) - if safety_response_parsed and isinstance(safety_response_parsed.get("is_safe"), bool): + if safety_response_parsed and isinstance( + safety_response_parsed.get("is_safe"), bool + ): is_safe = safety_response_parsed["is_safe"] reason = safety_response_parsed.get("reason", "No reason provided.") print(f"AI Safety Check Result (Parsed): is_safe={is_safe}, reason='{reason}'") return {"safe": is_safe, "reason": reason} else: # Include part of the raw response in the error for debugging if parsing failed - raw_response_excerpt = str(safety_response_raw)[:200] if safety_response_raw else "N/A" + raw_response_excerpt = ( + str(safety_response_raw)[:200] if safety_response_raw else "N/A" + ) error_msg = f"AI safety check failed to parse or returned invalid format. Raw Response: {raw_response_excerpt}" print(f"AI Safety Check Error: {error_msg}") # Also log the parsed attempt if it exists but was invalid if safety_response_parsed: - print(f"Parsed attempt was: {safety_response_parsed}") + print(f"Parsed attempt was: {safety_response_parsed}") return {"safe": False, "reason": error_msg} + async def run_terminal_command(cog: commands.Cog, command: str) -> Dict[str, Any]: """Executes a shell command in an isolated Docker container after an AI safety check.""" print(f"Attempting terminal command: {command}") @@ -720,12 +1138,20 @@ async def run_terminal_command(cog: commands.Cog, command: str) -> Dict[str, Any print(error_message) return {"error": error_message, "command": command} - try: cpu_limit = float(DOCKER_CPU_LIMIT); cpu_period = 100000; cpu_quota = int(cpu_limit * cpu_period) - except ValueError: print(f"Warning: Invalid DOCKER_CPU_LIMIT '{DOCKER_CPU_LIMIT}'. Using default."); cpu_quota = 50000; cpu_period = 100000 + try: + cpu_limit = float(DOCKER_CPU_LIMIT) + cpu_period = 100000 + cpu_quota = int(cpu_limit * cpu_period) + except ValueError: + print(f"Warning: Invalid DOCKER_CPU_LIMIT '{DOCKER_CPU_LIMIT}'. Using default.") + cpu_quota = 50000 + cpu_period = 100000 mem_limit_bytes = parse_mem_limit(DOCKER_MEM_LIMIT) if mem_limit_bytes is None: - print(f"Warning: Invalid DOCKER_MEM_LIMIT '{DOCKER_MEM_LIMIT}'. Disabling memory limit.") + print( + f"Warning: Invalid DOCKER_MEM_LIMIT '{DOCKER_MEM_LIMIT}'. Disabling memory limit." + ) client = None container = None @@ -734,32 +1160,32 @@ async def run_terminal_command(cog: commands.Cog, command: str) -> Dict[str, Any print(f"Running command in Docker ({DOCKER_EXEC_IMAGE})...") config = { - 'Image': DOCKER_EXEC_IMAGE, - 'Cmd': ["/bin/sh", "-c", command], - 'AttachStdout': True, - 'AttachStderr': True, - 'HostConfig': { - 'NetworkDisabled': True, - 'AutoRemove': False, # Changed to False - 'CpuPeriod': cpu_period, - 'CpuQuota': cpu_quota, - } + "Image": DOCKER_EXEC_IMAGE, + "Cmd": ["/bin/sh", "-c", command], + "AttachStdout": True, + "AttachStderr": True, + "HostConfig": { + "NetworkDisabled": True, + "AutoRemove": False, # Changed to False + "CpuPeriod": cpu_period, + "CpuQuota": cpu_quota, + }, } if mem_limit_bytes is not None: - config['HostConfig']['Memory'] = mem_limit_bytes + config["HostConfig"]["Memory"] = mem_limit_bytes # Use wait_for for the run call itself in case image pulling takes time container = await asyncio.wait_for( client.containers.run(config=config), - timeout=DOCKER_COMMAND_TIMEOUT + 15 # Add buffer for container start/stop/pull + timeout=DOCKER_COMMAND_TIMEOUT + + 15, # Add buffer for container start/stop/pull ) # Wait for the container to finish execution wait_result = await asyncio.wait_for( - container.wait(), - timeout=DOCKER_COMMAND_TIMEOUT + container.wait(), timeout=DOCKER_COMMAND_TIMEOUT ) - exit_code = wait_result.get('StatusCode', -1) + exit_code = wait_result.get("StatusCode", -1) # Get logs after container finishes # container.log() returns a list of strings when stream=False (default) @@ -770,11 +1196,18 @@ async def run_terminal_command(cog: commands.Cog, command: str) -> Dict[str, Any stderr = "".join(stderr_lines) if stderr_lines else "" max_len = 1000 - stdout_trunc = stdout[:max_len] + ('...' if len(stdout) > max_len else '') - stderr_trunc = stderr[:max_len] + ('...' if len(stderr) > max_len else '') + stdout_trunc = stdout[:max_len] + ("..." if len(stdout) > max_len else "") + stderr_trunc = stderr[:max_len] + ("..." if len(stderr) > max_len else "") - result = {"status": "success" if exit_code == 0 else "execution_error", "stdout": stdout_trunc, "stderr": stderr_trunc, "exit_code": exit_code} - print(f"Docker command finished. Exit Code: {exit_code}. Output length: {len(stdout)}, Stderr length: {len(stderr)}") + result = { + "status": "success" if exit_code == 0 else "execution_error", + "stdout": stdout_trunc, + "stderr": stderr_trunc, + "exit_code": exit_code, + } + print( + f"Docker command finished. Exit Code: {exit_code}. Output length: {len(stdout)}, Stderr length: {len(stderr)}" + ) return result except asyncio.TimeoutError: @@ -790,22 +1223,42 @@ async def run_terminal_command(cog: commands.Cog, command: str) -> Dict[str, Any # await container.delete(force=True) # Force needed if stop failed? # print(f"Container {container.id[:12]} deleted.") except aiodocker.exceptions.DockerError as stop_err: - print(f"Error stopping/deleting timed-out container {container.id[:12]}: {stop_err}") + print( + f"Error stopping/deleting timed-out container {container.id[:12]}: {stop_err}" + ) except Exception as stop_exc: - print(f"Unexpected error stopping/deleting timed-out container {container.id[:12]}: {stop_exc}") + print( + f"Unexpected error stopping/deleting timed-out container {container.id[:12]}: {stop_exc}" + ) # No need to delete here, finally block will handle it - return {"error": f"Command execution/log retrieval timed out after {DOCKER_COMMAND_TIMEOUT}s", "command": command, "status": "timeout"} - except aiodocker.exceptions.DockerError as e: # Catch specific aiodocker errors + return { + "error": f"Command execution/log retrieval timed out after {DOCKER_COMMAND_TIMEOUT}s", + "command": command, + "status": "timeout", + } + except aiodocker.exceptions.DockerError as e: # Catch specific aiodocker errors print(f"Docker API error: {e} (Status: {e.status})") # Check for ImageNotFound specifically if e.status == 404 and ("No such image" in str(e) or "not found" in str(e)): - print(f"Docker image not found: {DOCKER_EXEC_IMAGE}") - return {"error": f"Docker image '{DOCKER_EXEC_IMAGE}' not found.", "command": command, "status": "docker_error"} - return {"error": f"Docker API error ({e.status}): {str(e)}", "command": command, "status": "docker_error"} + print(f"Docker image not found: {DOCKER_EXEC_IMAGE}") + return { + "error": f"Docker image '{DOCKER_EXEC_IMAGE}' not found.", + "command": command, + "status": "docker_error", + } + return { + "error": f"Docker API error ({e.status}): {str(e)}", + "command": command, + "status": "docker_error", + } except Exception as e: print(f"Unexpected Docker error: {e}") traceback.print_exc() - return {"error": f"Unexpected error during Docker execution: {str(e)}", "command": command, "status": "error"} + return { + "error": f"Unexpected error during Docker execution: {str(e)}", + "command": command, + "status": "error", + } finally: # Explicitly remove the container since AutoRemove is False if container: @@ -817,10 +1270,14 @@ async def run_terminal_command(cog: commands.Cog, command: str) -> Dict[str, Any # Log error but don't raise, primary error is more important print(f"Error deleting container {container.id[:12]}: {delete_err}") except Exception as delete_exc: - print(f"Unexpected error deleting container {container.id[:12]}: {delete_exc}") + print( + f"Unexpected error deleting container {container.id[:12]}: {delete_exc}" + ) # Ensure the client connection is closed if client: await client.close() + + async def get_user_id(cog: commands.Cog, user_name: str) -> Dict[str, Any]: """Finds the Discord User ID for a given username or display name.""" print(f"Attempting to find user ID for: '{user_name}'") @@ -830,24 +1287,44 @@ async def get_user_id(cog: commands.Cog, user_name: str) -> Dict[str, Any]: user_name_lower = user_name.lower() found_user = None # Check recent message authors (less reliable) - for msg_data in reversed(list(cog.message_cache['global_recent'])): # Check newest first - author_info = msg_data.get('author', {}) - if user_name_lower == author_info.get('name', '').lower() or \ - user_name_lower == author_info.get('display_name', '').lower(): - found_user = {"id": author_info.get('id'), "name": author_info.get('name'), "display_name": author_info.get('display_name')} + for msg_data in reversed( + list(cog.message_cache["global_recent"]) + ): # Check newest first + author_info = msg_data.get("author", {}) + if ( + user_name_lower == author_info.get("name", "").lower() + or user_name_lower == author_info.get("display_name", "").lower() + ): + found_user = { + "id": author_info.get("id"), + "name": author_info.get("name"), + "display_name": author_info.get("display_name"), + } break if found_user and found_user.get("id"): - print(f"Found user ID {found_user['id']} for '{user_name}' in global message cache.") - return {"status": "success", "user_id": found_user["id"], "user_name": found_user["name"], "display_name": found_user["display_name"]} + print( + f"Found user ID {found_user['id']} for '{user_name}' in global message cache." + ) + return { + "status": "success", + "user_id": found_user["id"], + "user_name": found_user["name"], + "display_name": found_user["display_name"], + } else: print(f"User '{user_name}' not found in recent global message cache.") - return {"error": f"User '{user_name}' not found in recent messages.", "user_name": user_name} + return { + "error": f"User '{user_name}' not found in recent messages.", + "user_name": user_name, + } # If in a guild, search members guild = cog.current_channel.guild - member = guild.get_member_named(user_name) # Case-sensitive username#discriminator or exact display name + member = guild.get_member_named( + user_name + ) # Case-sensitive username#discriminator or exact display name - if not member: # Try case-insensitive display name search + if not member: # Try case-insensitive display name search user_name_lower = user_name.lower() for m in guild.members: if m.display_name.lower() == user_name_lower: @@ -856,13 +1333,23 @@ async def get_user_id(cog: commands.Cog, user_name: str) -> Dict[str, Any]: if member: print(f"Found user ID {member.id} for '{user_name}' in guild '{guild.name}'.") - return {"status": "success", "user_id": str(member.id), "user_name": member.name, "display_name": member.display_name} + return { + "status": "success", + "user_id": str(member.id), + "user_name": member.name, + "display_name": member.display_name, + } else: print(f"User '{user_name}' not found in guild '{guild.name}'.") - return {"error": f"User '{user_name}' not found in this server.", "user_name": user_name} + return { + "error": f"User '{user_name}' not found in this server.", + "user_name": user_name, + } -async def execute_internal_command(cog: commands.Cog, command: str, timeout_seconds: int = 60, user_id: str = None) -> Dict[str, Any]: +async def execute_internal_command( + cog: commands.Cog, command: str, timeout_seconds: int = 60, user_id: str = None +) -> Dict[str, Any]: """ Executes a shell command directly on the host machine where the bot is running. WARNING: This tool is intended ONLY for internal Gurt operations and MUST NOT @@ -871,33 +1358,38 @@ async def execute_internal_command(cog: commands.Cog, command: str, timeout_seco Only user ID 452666956353503252 is allowed to execute this command. """ if user_id != "452666956353503252": - return {"error": "The requesting user is not authorized to execute commands.", "status": "unauthorized"} + return { + "error": "The requesting user is not authorized to execute commands.", + "status": "unauthorized", + } print(f"--- INTERNAL EXECUTION (UNSAFE): Running command: {command} ---") try: process = await asyncio.create_subprocess_shell( - command, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) - stdout_bytes, stderr_bytes = await asyncio.wait_for(process.communicate(), timeout=timeout_seconds) + stdout_bytes, stderr_bytes = await asyncio.wait_for( + process.communicate(), timeout=timeout_seconds + ) exit_code = process.returncode - stdout = stdout_bytes.decode(errors='replace').strip() - stderr = stderr_bytes.decode(errors='replace').strip() + stdout = stdout_bytes.decode(errors="replace").strip() + stderr = stderr_bytes.decode(errors="replace").strip() max_len = 1000 - stdout_trunc = stdout[:max_len] + ('...' if len(stdout) > max_len else '') - stderr_trunc = stderr[:max_len] + ('...' if len(stderr) > max_len else '') + stdout_trunc = stdout[:max_len] + ("..." if len(stdout) > max_len else "") + stderr_trunc = stderr[:max_len] + ("..." if len(stderr) > max_len else "") result = { "status": "success" if exit_code == 0 else "execution_error", "stdout": stdout_trunc, "stderr": stderr_trunc, "exit_code": exit_code, - "command": command + "command": command, } - print(f"Internal command finished. Exit Code: {exit_code}. Output length: {len(stdout)}, Stderr length: {len(stderr)}") + print( + f"Internal command finished. Exit Code: {exit_code}. Output length: {len(stdout)}, Stderr length: {len(stderr)}" + ) return result except asyncio.TimeoutError: @@ -906,13 +1398,19 @@ async def execute_internal_command(cog: commands.Cog, command: str, timeout_seco if process and process.returncode is None: try: process.kill() - await process.wait() # Ensure it's cleaned up + await process.wait() # Ensure it's cleaned up print(f"Killed timed-out internal process (PID: {process.pid})") except ProcessLookupError: print(f"Internal process (PID: {process.pid}) already finished.") except Exception as kill_e: - print(f"Error killing timed-out internal process (PID: {process.pid}): {kill_e}") - return {"error": f"Command execution timed out after {timeout_seconds}s", "command": command, "status": "timeout"} + print( + f"Error killing timed-out internal process (PID: {process.pid}): {kill_e}" + ) + return { + "error": f"Command execution timed out after {timeout_seconds}s", + "command": command, + "status": "timeout", + } except FileNotFoundError: error_message = f"Command not found: {command.split()[0]}" print(f"Internal command error: {error_message}") @@ -923,37 +1421,70 @@ async def execute_internal_command(cog: commands.Cog, command: str, timeout_seco traceback.print_exc() return {"error": error_message, "command": command, "status": "error"} -async def extract_web_content(cog: commands.Cog, urls: Union[str, List[str]], extract_depth: str = "basic", include_images: bool = False) -> Dict[str, Any]: + +async def extract_web_content( + cog: commands.Cog, + urls: Union[str, List[str]], + extract_depth: str = "basic", + include_images: bool = False, +) -> Dict[str, Any]: """Extract content from URLs using Tavily API""" - if not hasattr(cog, 'tavily_client') or not cog.tavily_client: - return {"error": "Tavily client not initialized.", "timestamp": datetime.datetime.now().isoformat()} + if not hasattr(cog, "tavily_client") or not cog.tavily_client: + return { + "error": "Tavily client not initialized.", + "timestamp": datetime.datetime.now().isoformat(), + } # Cost control / Logging for advanced extract final_extract_depth = extract_depth if extract_depth.lower() == "advanced": if TAVILY_DISABLE_ADVANCED: - print(f"Warning: Advanced Tavily extract requested but disabled by config. Falling back to basic.") + print( + f"Warning: Advanced Tavily extract requested but disabled by config. Falling back to basic." + ) final_extract_depth = "basic" else: - print(f"Performing advanced Tavily extract (cost: 2 credits per 5 URLs) for URLs: {urls}") + print( + f"Performing advanced Tavily extract (cost: 2 credits per 5 URLs) for URLs: {urls}" + ) elif extract_depth.lower() != "basic": - print(f"Warning: Invalid extract_depth '{extract_depth}' provided. Using 'basic'.") + print( + f"Warning: Invalid extract_depth '{extract_depth}' provided. Using 'basic'." + ) final_extract_depth = "basic" try: response = await asyncio.to_thread( cog.tavily_client.extract, urls=urls, - extract_depth=final_extract_depth, # Use validated depth - include_images=include_images + extract_depth=final_extract_depth, # Use validated depth + include_images=include_images, ) - results = [{"url": r.get("url"), "raw_content": r.get("raw_content"), "images": r.get("images")} for r in response.get("results", [])] + results = [ + { + "url": r.get("url"), + "raw_content": r.get("raw_content"), + "images": r.get("images"), + } + for r in response.get("results", []) + ] failed_results = response.get("failed_results", []) - return {"urls": urls, "extract_depth": extract_depth, "include_images": include_images, "results": results, "failed_results": failed_results, "timestamp": datetime.datetime.now().isoformat()} + return { + "urls": urls, + "extract_depth": extract_depth, + "include_images": include_images, + "results": results, + "failed_results": failed_results, + "timestamp": datetime.datetime.now().isoformat(), + } except Exception as e: error_message = f"Error during Tavily extract for '{urls}': {str(e)}" print(error_message) - return {"error": error_message, "timestamp": datetime.datetime.now().isoformat()} + return { + "error": error_message, + "timestamp": datetime.datetime.now().isoformat(), + } + async def read_file_content(cog: commands.Cog, file_path: str) -> Dict[str, Any]: """ @@ -971,14 +1502,16 @@ async def read_file_content(cog: commands.Cog, file_path: str) -> Dict[str, Any] # Use async file reading if available/needed, otherwise sync with to_thread def sync_read(): - with open(full_path, 'r', encoding='utf-8') as f: + with open(full_path, "r", encoding="utf-8") as f: # Limit file size read? For now, read whole file. Consider adding limit later. return f.read() content = await asyncio.to_thread(sync_read) - max_len = 1000000 # Increased limit for potentially larger reads - content_trunc = content[:max_len] + ('...' if len(content) > max_len else '') - print(f"--- UNSAFE READ: Successfully read {len(content)} bytes from {file_path}. Returning {len(content_trunc)} bytes. ---") + max_len = 1000000 # Increased limit for potentially larger reads + content_trunc = content[:max_len] + ("..." if len(content) > max_len else "") + print( + f"--- UNSAFE READ: Successfully read {len(content)} bytes from {file_path}. Returning {len(content_trunc)} bytes. ---" + ) return {"status": "success", "file_path": file_path, "content": content_trunc} except FileNotFoundError: @@ -1003,14 +1536,22 @@ async def read_file_content(cog: commands.Cog, file_path: str) -> Dict[str, Any] traceback.print_exc() return {"error": error_message, "file_path": file_path} -async def write_file_content_unsafe(cog: commands.Cog, file_path: str, content: str, mode: str = 'w') -> Dict[str, Any]: + +async def write_file_content_unsafe( + cog: commands.Cog, file_path: str, content: str, mode: str = "w" +) -> Dict[str, Any]: """ Writes content to a specified file. WARNING: No safety checks are performed. Uses 'w' (overwrite) or 'a' (append) mode. Creates directories if needed. """ - print(f"--- UNSAFE WRITE: Attempting to write to file: {file_path} (Mode: {mode}) ---") - if mode not in ['w', 'a']: - return {"error": "Invalid mode. Use 'w' (overwrite) or 'a' (append).", "file_path": file_path} + print( + f"--- UNSAFE WRITE: Attempting to write to file: {file_path} (Mode: {mode}) ---" + ) + if mode not in ["w", "a"]: + return { + "error": "Invalid mode. Use 'w' (overwrite) or 'a' (append).", + "file_path": file_path, + } try: # Normalize path relative to CWD @@ -1026,13 +1567,20 @@ async def write_file_content_unsafe(cog: commands.Cog, file_path: str, content: # Use async file writing if available/needed, otherwise sync with to_thread def sync_write(): - with open(full_path, mode, encoding='utf-8') as f: + with open(full_path, mode, encoding="utf-8") as f: bytes_written = f.write(content) return bytes_written bytes_written = await asyncio.to_thread(sync_write) - print(f"--- UNSAFE WRITE: Successfully wrote {bytes_written} bytes to {file_path} (Mode: {mode}). ---") - return {"status": "success", "file_path": file_path, "bytes_written": bytes_written, "mode": mode} + print( + f"--- UNSAFE WRITE: Successfully wrote {bytes_written} bytes to {file_path} (Mode: {mode}). ---" + ) + return { + "status": "success", + "file_path": file_path, + "bytes_written": bytes_written, + "mode": mode, + } except PermissionError: error_message = "Permission denied." @@ -1048,7 +1596,10 @@ async def write_file_content_unsafe(cog: commands.Cog, file_path: str, content: traceback.print_exc() return {"error": error_message, "file_path": file_path} -async def execute_python_unsafe(cog: commands.Cog, code: str, timeout_seconds: int = 30) -> Dict[str, Any]: + +async def execute_python_unsafe( + cog: commands.Cog, code: str, timeout_seconds: int = 30 +) -> Dict[str, Any]: """ Executes arbitrary Python code directly on the host using exec(). WARNING: EXTREMELY DANGEROUS. No sandboxing. Can access/modify anything the bot process can. @@ -1059,7 +1610,14 @@ async def execute_python_unsafe(cog: commands.Cog, code: str, timeout_seconds: i import contextlib import threading - local_namespace = {'cog': cog, 'asyncio': asyncio, 'discord': discord, 'random': random, 'os': os, 'time': time} # Provide some context + local_namespace = { + "cog": cog, + "asyncio": asyncio, + "discord": discord, + "random": random, + "os": os, + "time": time, + } # Provide some context stdout_capture = io.StringIO() stderr_capture = io.StringIO() result = {"status": "unknown", "stdout": "", "stderr": "", "error": None} @@ -1068,14 +1626,18 @@ async def execute_python_unsafe(cog: commands.Cog, code: str, timeout_seconds: i def target(): nonlocal exec_exception try: - with contextlib.redirect_stdout(stdout_capture), contextlib.redirect_stderr(stderr_capture): + with contextlib.redirect_stdout(stdout_capture), contextlib.redirect_stderr( + stderr_capture + ): # Execute the code in a restricted namespace? For now, use globals() + locals exec(code, globals(), local_namespace) except Exception as e: nonlocal exec_exception exec_exception = e print(f"--- UNSAFE PYTHON EXEC: Exception during execution: {e} ---") - traceback.print_exc(file=stderr_capture) # Also print traceback to stderr capture + traceback.print_exc( + file=stderr_capture + ) # Also print traceback to stderr capture thread = threading.Thread(target=target) thread.start() @@ -1085,7 +1647,9 @@ async def execute_python_unsafe(cog: commands.Cog, code: str, timeout_seconds: i # Timeout occurred - This is tricky to kill reliably from another thread in Python # For now, we just report the timeout. The code might still be running. result["status"] = "timeout" - result["error"] = f"Execution timed out after {timeout_seconds} seconds. Code might still be running." + result["error"] = ( + f"Execution timed out after {timeout_seconds} seconds. Code might still be running." + ) print(f"--- UNSAFE PYTHON EXEC: Timeout after {timeout_seconds}s ---") elif exec_exception: result["status"] = "execution_error" @@ -1097,17 +1661,27 @@ async def execute_python_unsafe(cog: commands.Cog, code: str, timeout_seconds: i stdout_val = stdout_capture.getvalue() stderr_val = stderr_capture.getvalue() max_len = 2000 - result["stdout"] = stdout_val[:max_len] + ('...' if len(stdout_val) > max_len else '') - result["stderr"] = stderr_val[:max_len] + ('...' if len(stderr_val) > max_len else '') + result["stdout"] = stdout_val[:max_len] + ( + "..." if len(stdout_val) > max_len else "" + ) + result["stderr"] = stderr_val[:max_len] + ( + "..." if len(stderr_val) > max_len else "" + ) stdout_capture.close() stderr_capture.close() return result -async def send_discord_message(cog: commands.Cog, channel_id: str, message_content: str) -> Dict[str, Any]: + +async def send_discord_message( + cog: commands.Cog, channel_id: str, message_content: str +) -> Dict[str, Any]: """Sends a message to a specified Discord channel.""" - print(f"Attempting to send message to channel {channel_id}: {message_content[:100]}...") + print( + f"Attempting to send message to channel {channel_id}: {message_content[:100]}..." + ) + async def restart_gurt_bot(cog: commands.Cog, channel_id: str = None) -> Dict[str, Any]: """ @@ -1118,8 +1692,11 @@ async def restart_gurt_bot(cog: commands.Cog, channel_id: str = None) -> Dict[st """ import sys import os + if not channel_id: - return {"error": "channel_id must be provided to send the restart message in the correct channel."} + return { + "error": "channel_id must be provided to send the restart message in the correct channel." + } try: await send_discord_message(cog, channel_id, "Restart tool was called.") except Exception as msg_exc: @@ -1131,6 +1708,7 @@ async def restart_gurt_bot(cog: commands.Cog, channel_id: str = None) -> Dict[st except Exception as e: return {"status": "error", "error": f"Failed to restart: {str(e)}"} + async def run_git_pull(cog: commands.Cog, user_id: str) -> Dict[str, Any]: """ Runs 'git pull' in the bot's current working directory on the host machine. @@ -1138,15 +1716,22 @@ async def run_git_pull(cog: commands.Cog, user_id: str) -> Dict[str, Any]: """ return await execute_internal_command(cog=cog, command="git pull", user_id=user_id) -async def send_discord_message(cog: commands.Cog, channel_id: str, message_content: str) -> Dict[str, Any]: + +async def send_discord_message( + cog: commands.Cog, channel_id: str, message_content: str +) -> Dict[str, Any]: """Sends a message to a specified Discord channel.""" - print(f"Attempting to send message to channel {channel_id}: {message_content[:100]}...") + print( + f"Attempting to send message to channel {channel_id}: {message_content[:100]}..." + ) # Ensure this function doesn't contain the logic accidentally put in the original run_git_pull if not message_content: return {"error": "Message content cannot be empty."} # Limit message length - max_msg_len = 1900 # Slightly less than Discord limit - message_content = message_content[:max_msg_len] + ('...' if len(message_content) > max_msg_len else '') + max_msg_len = 1900 # Slightly less than Discord limit + message_content = message_content[:max_msg_len] + ( + "..." if len(message_content) > max_msg_len else "" + ) try: channel_id_int = int(channel_id) @@ -1158,24 +1743,34 @@ async def send_discord_message(cog: commands.Cog, channel_id: str, message_conte if not channel: return {"error": f"Channel {channel_id} not found or inaccessible."} if not isinstance(channel, discord.abc.Messageable): - return {"error": f"Channel {channel_id} is not messageable (Type: {type(channel)})."} + return { + "error": f"Channel {channel_id} is not messageable (Type: {type(channel)})." + } # Check permissions if it's a guild channel if isinstance(channel, discord.abc.GuildChannel): bot_member = channel.guild.me if not channel.permissions_for(bot_member).send_messages: - return {"error": f"Missing 'Send Messages' permission in channel {channel_id}."} + return { + "error": f"Missing 'Send Messages' permission in channel {channel_id}." + } sent_message = await channel.send(message_content) print(f"Successfully sent message {sent_message.id} to channel {channel_id}.") - return {"status": "success", "channel_id": channel_id, "message_id": str(sent_message.id)} + return { + "status": "success", + "channel_id": channel_id, + "message_id": str(sent_message.id), + } except ValueError: return {"error": f"Invalid channel ID format: {channel_id}."} except discord.NotFound: return {"error": f"Channel {channel_id} not found."} except discord.Forbidden: - return {"error": f"Forbidden: Missing permissions to send message in channel {channel_id}."} + return { + "error": f"Forbidden: Missing permissions to send message in channel {channel_id}." + } except discord.HTTPException as e: error_message = f"API error sending message to {channel_id}: {e}" print(error_message) @@ -1189,7 +1784,13 @@ async def send_discord_message(cog: commands.Cog, channel_id: str, message_conte # --- Meta Tool: Create New Tool --- # WARNING: HIGHLY EXPERIMENTAL AND DANGEROUS. Allows AI to write and load code. -async def create_new_tool(cog: commands.Cog, tool_name: str, description: str, parameters_json: str, returns_description: str) -> Dict[str, Any]: +async def create_new_tool( + cog: commands.Cog, + tool_name: str, + description: str, + parameters_json: str, + returns_description: str, +) -> Dict[str, Any]: """ EXPERIMENTAL/DANGEROUS: Attempts to create a new tool by generating Python code and its definition using an LLM, then writing it to tools.py and config.py. @@ -1198,18 +1799,22 @@ async def create_new_tool(cog: commands.Cog, tool_name: str, description: str, p for the tool's parameters, similar to other FunctionDeclarations. """ print(f"--- DANGEROUS OPERATION: Attempting to create new tool: {tool_name} ---") - from .api import get_internal_ai_json_response # Local import - from .config import TOOLS # Import for context, though modifying it runtime is hard + from .api import get_internal_ai_json_response # Local import + from .config import TOOLS # Import for context, though modifying it runtime is hard # Basic validation - if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', tool_name): + if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", tool_name): return {"error": "Invalid tool name. Must be a valid Python function name."} if tool_name in TOOL_MAPPING: return {"error": f"Tool '{tool_name}' already exists."} try: params_dict = json.loads(parameters_json) - if not isinstance(params_dict.get('properties'), dict) or not isinstance(params_dict.get('required'), list): - raise ValueError("Invalid parameters_json structure. Must contain 'properties' (dict) and 'required' (list).") + if not isinstance(params_dict.get("properties"), dict) or not isinstance( + params_dict.get("required"), list + ): + raise ValueError( + "Invalid parameters_json structure. Must contain 'properties' (dict) and 'required' (list)." + ) except json.JSONDecodeError: return {"error": "Invalid parameters_json. Must be valid JSON."} except ValueError as e: @@ -1219,11 +1824,24 @@ async def create_new_tool(cog: commands.Cog, tool_name: str, description: str, p generation_schema = { "type": "object", "properties": { - "python_function_code": {"type": "string", "description": "The complete Python async function code for the new tool, including imports if necessary."}, - "function_declaration_params": {"type": "string", "description": "The JSON string for the 'parameters' part of the FunctionDeclaration."}, - "function_declaration_desc": {"type": "string", "description": "The 'description' string for the FunctionDeclaration."} + "python_function_code": { + "type": "string", + "description": "The complete Python async function code for the new tool, including imports if necessary.", + }, + "function_declaration_params": { + "type": "string", + "description": "The JSON string for the 'parameters' part of the FunctionDeclaration.", + }, + "function_declaration_desc": { + "type": "string", + "description": "The 'description' string for the FunctionDeclaration.", + }, }, - "required": ["python_function_code", "function_declaration_params", "function_declaration_desc"] + "required": [ + "python_function_code", + "function_declaration_params", + "function_declaration_desc", + ], } system_prompt = ( "You are a Python code generation assistant for Gurt, a Discord bot. " @@ -1244,7 +1862,7 @@ async def create_new_tool(cog: commands.Cog, tool_name: str, description: str, p generation_prompt_messages = [ {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt} + {"role": "user", "content": user_prompt}, ] print(f"Generating code for tool '{tool_name}'...") @@ -1253,32 +1871,45 @@ async def create_new_tool(cog: commands.Cog, tool_name: str, description: str, p prompt_messages=generation_prompt_messages, task_description=f"Generate code for new tool '{tool_name}'", response_schema_dict=generation_schema, - model_name_override=cog.default_model, # Use default model for generation - temperature=0.3, # Lower temperature for more predictable code - max_tokens=5000 # Allow ample space for code generation + model_name_override=cog.default_model, # Use default model for generation + temperature=0.3, # Lower temperature for more predictable code + max_tokens=5000, # Allow ample space for code generation ) # Unpack the tuple, we only need the parsed data here generated_parsed_data, _ = generated_data if generated_data else (None, None) - if not generated_parsed_data or "python_function_code" not in generated_parsed_data or "function_declaration_params" not in generated_parsed_data: - error_msg = f"Failed to generate code for tool '{tool_name}'. LLM response invalid: {generated_parsed_data}" # Log parsed data on error + if ( + not generated_parsed_data + or "python_function_code" not in generated_parsed_data + or "function_declaration_params" not in generated_parsed_data + ): + error_msg = f"Failed to generate code for tool '{tool_name}'. LLM response invalid: {generated_parsed_data}" # Log parsed data on error print(error_msg) return {"error": error_msg} python_code = generated_parsed_data["python_function_code"].strip() - declaration_params_str = generated_parsed_data["function_declaration_params"].strip() + declaration_params_str = generated_parsed_data[ + "function_declaration_params" + ].strip() declaration_desc = generated_parsed_data["function_declaration_desc"].strip() # Escape quotes in the description *before* using it in the f-string escaped_declaration_desc = declaration_desc.replace('"', '\\"') # Basic validation of generated code (very superficial) - if not python_code.startswith("async def") or f" {tool_name}(" not in python_code or "cog: commands.Cog" not in python_code: - error_msg = f"Generated Python code for '{tool_name}' seems invalid (missing async def, cog param, or function name)." - print(error_msg) - print("--- Generated Code ---") - print(python_code) - print("----------------------") - return {"error": error_msg, "generated_code": python_code} # Return code for debugging + if ( + not python_code.startswith("async def") + or f" {tool_name}(" not in python_code + or "cog: commands.Cog" not in python_code + ): + error_msg = f"Generated Python code for '{tool_name}' seems invalid (missing async def, cog param, or function name)." + print(error_msg) + print("--- Generated Code ---") + print(python_code) + print("----------------------") + return { + "error": error_msg, + "generated_code": python_code, + } # Return code for debugging # --- Attempt to write to files (HIGH RISK) --- # Note: This is brittle. Concurrent writes or errors could corrupt files. @@ -1300,7 +1931,9 @@ async def create_new_tool(cog: commands.Cog, tool_name: str, description: str, p raise RuntimeError("Could not find TOOL_MAPPING definition in tools.py") # Insert the new function code before the mapping - new_function_lines = ["\n"] + [line + "\n" for line in python_code.splitlines()] + ["\n"] + new_function_lines = ( + ["\n"] + [line + "\n" for line in python_code.splitlines()] + ["\n"] + ) content[insert_line:insert_line] = new_function_lines f.seek(0) @@ -1309,7 +1942,8 @@ async def create_new_tool(cog: commands.Cog, tool_name: str, description: str, p print(f"Successfully appended function '{tool_name}' to {tools_py_path}") except Exception as e: error_msg = f"Failed to write function to {tools_py_path}: {e}" - print(error_msg); traceback.print_exc() + print(error_msg) + traceback.print_exc() return {"error": error_msg} # 2. Add tool to TOOL_MAPPING in tools.py @@ -1326,7 +1960,9 @@ async def create_new_tool(cog: commands.Cog, tool_name: str, description: str, p mapping_end_line = i break if mapping_end_line == -1: - raise RuntimeError("Could not find end of TOOL_MAPPING definition '}' in tools.py") + raise RuntimeError( + "Could not find end of TOOL_MAPPING definition '}' in tools.py" + ) # Add the new mapping entry before the closing brace new_mapping_line = f' "{tool_name}": {tool_name},\n' @@ -1337,21 +1973,26 @@ async def create_new_tool(cog: commands.Cog, tool_name: str, description: str, p f.truncate() print(f"Successfully added '{tool_name}' to TOOL_MAPPING.") except Exception as e: - error_msg = f"Failed to add '{tool_name}' to TOOL_MAPPING in {tools_py_path}: {e}" - print(error_msg); traceback.print_exc() + error_msg = ( + f"Failed to add '{tool_name}' to TOOL_MAPPING in {tools_py_path}: {e}" + ) + print(error_msg) + traceback.print_exc() # Attempt to revert the function addition if mapping fails? Complex. return {"error": error_msg} # 3. Add FunctionDeclaration to config.py config_py_path = "discordbot/gurt/config.py" try: - print(f"Attempting to add FunctionDeclaration for '{tool_name}' to {config_py_path}...") + print( + f"Attempting to add FunctionDeclaration for '{tool_name}' to {config_py_path}..." + ) # Use FunctionDeclaration directly, assuming it's imported in config.py declaration_code = ( f" tool_declarations.append(\n" f" FunctionDeclaration( # Use imported FunctionDeclaration\n" - f" name=\"{tool_name}\",\n" - f" description=\"{escaped_declaration_desc}\", # Use escaped description\n" + f' name="{tool_name}",\n' + f' description="{escaped_declaration_desc}", # Use escaped description\n' f" parameters={declaration_params_str} # Generated parameters\n" f" )\n" f" )\n" @@ -1368,19 +2009,26 @@ async def create_new_tool(cog: commands.Cog, tool_name: str, description: str, p insert_line = i break if insert_line == -1: - raise RuntimeError("Could not find 'return tool_declarations' in config.py") + raise RuntimeError( + "Could not find 'return tool_declarations' in config.py" + ) # Insert the new declaration code before the return statement - new_declaration_lines = ["\n"] + [line + "\n" for line in declaration_code.splitlines()] + new_declaration_lines = ["\n"] + [ + line + "\n" for line in declaration_code.splitlines() + ] content[insert_line:insert_line] = new_declaration_lines f.seek(0) f.writelines(content) f.truncate() - print(f"Successfully added FunctionDeclaration for '{tool_name}' to {config_py_path}") + print( + f"Successfully added FunctionDeclaration for '{tool_name}' to {config_py_path}" + ) except Exception as e: error_msg = f"Failed to add FunctionDeclaration to {config_py_path}: {e}" - print(error_msg); traceback.print_exc() + print(error_msg) + traceback.print_exc() # Attempt to revert previous changes? Very complex. return {"error": error_msg} @@ -1396,7 +2044,9 @@ async def create_new_tool(cog: commands.Cog, tool_name: str, description: str, p # For now, just update the runtime TOOL_MAPPING if possible. # This requires the function to be somehow available in the current scope. # Let's assume for now it needs a restart/reload. - print(f"Runtime update of TOOL_MAPPING for '{tool_name}' skipped. Requires bot reload.") + print( + f"Runtime update of TOOL_MAPPING for '{tool_name}' skipped. Requires bot reload." + ) # If we could dynamically import: # TOOL_MAPPING[tool_name] = dynamically_imported_function @@ -1408,11 +2058,12 @@ async def create_new_tool(cog: commands.Cog, tool_name: str, description: str, p "status": "success", "tool_name": tool_name, "message": f"Tool '{tool_name}' code and definition written to files. Bot reload/restart likely required for full activation.", - "generated_function_code": python_code, # Return for inspection + "generated_function_code": python_code, # Return for inspection "generated_declaration_desc": declaration_desc, - "generated_declaration_params": declaration_params_str + "generated_declaration_params": declaration_params_str, } + async def no_operation(cog: commands.Cog) -> Dict[str, Any]: """ Does absolutely nothing. Used when a tool call is forced but no action is needed. @@ -1433,30 +2084,54 @@ async def get_channel_id(cog: commands.Cog, channel_name: str = None) -> Dict[st guild = cog.current_channel.guild # Try to find by name (case-insensitive) for channel in guild.channels: - if hasattr(channel, "name") and channel.name.lower() == channel_name.lower(): - return {"status": "success", "channel_id": str(channel.id), "channel_name": channel.name} + if ( + hasattr(channel, "name") + and channel.name.lower() == channel_name.lower() + ): + return { + "status": "success", + "channel_id": str(channel.id), + "channel_name": channel.name, + } return {"error": f"Channel '{channel_name}' not found in this server."} else: channel = cog.current_channel if not channel: return {"error": "No current channel context."} - return {"status": "success", "channel_id": str(channel.id), "channel_name": getattr(channel, "name", "DM Channel")} + return { + "status": "success", + "channel_id": str(channel.id), + "channel_name": getattr(channel, "name", "DM Channel"), + } except Exception as e: return {"error": f"Error getting channel ID: {str(e)}"} + # Tool 1: get_guild_info async def get_guild_info(cog: commands.Cog) -> Dict[str, Any]: """Gets information about the current Discord server.""" print("Executing get_guild_info tool.") - if not cog.current_channel or not isinstance(cog.current_channel, discord.abc.GuildChannel): + if not cog.current_channel or not isinstance( + cog.current_channel, discord.abc.GuildChannel + ): return {"error": "Cannot get guild info outside of a server channel."} guild = cog.current_channel.guild if not guild: return {"error": "Could not determine the current server."} try: - owner = guild.owner or await guild.fetch_member(guild.owner_id) # Fetch if not cached - owner_info = {"id": str(owner.id), "name": owner.name, "display_name": owner.display_name} if owner else None + owner = guild.owner or await guild.fetch_member( + guild.owner_id + ) # Fetch if not cached + owner_info = ( + { + "id": str(owner.id), + "name": owner.name, + "display_name": owner.display_name, + } + if owner + else None + ) return { "status": "success", @@ -1470,24 +2145,35 @@ async def get_guild_info(cog: commands.Cog) -> Dict[str, Any]: "banner_url": str(guild.banner.url) if guild.banner else None, "features": guild.features, "preferred_locale": guild.preferred_locale, - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: error_message = f"Error getting guild info: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} + # Tool 2: list_guild_members -async def list_guild_members(cog: commands.Cog, limit: int = 50, status_filter: Optional[str] = None, role_id_filter: Optional[str] = None) -> Dict[str, Any]: +async def list_guild_members( + cog: commands.Cog, + limit: int = 50, + status_filter: Optional[str] = None, + role_id_filter: Optional[str] = None, +) -> Dict[str, Any]: """Lists members in the current server, with optional filters.""" - print(f"Executing list_guild_members tool (limit={limit}, status={status_filter}, role={role_id_filter}).") - if not cog.current_channel or not isinstance(cog.current_channel, discord.abc.GuildChannel): + print( + f"Executing list_guild_members tool (limit={limit}, status={status_filter}, role={role_id_filter})." + ) + if not cog.current_channel or not isinstance( + cog.current_channel, discord.abc.GuildChannel + ): return {"error": "Cannot list members outside of a server channel."} guild = cog.current_channel.guild if not guild: return {"error": "Could not determine the current server."} - limit = min(max(1, limit), 1000) # Limit fetch size + limit = min(max(1, limit), 1000) # Limit fetch size members_list = [] role_filter_obj = None if role_id_filter: @@ -1501,17 +2187,21 @@ async def list_guild_members(cog: commands.Cog, limit: int = 50, status_filter: valid_statuses = ["online", "idle", "dnd", "offline"] status_filter_lower = status_filter.lower() if status_filter else None if status_filter_lower and status_filter_lower not in valid_statuses: - return {"error": f"Invalid status_filter. Use one of: {', '.join(valid_statuses)}"} + return { + "error": f"Invalid status_filter. Use one of: {', '.join(valid_statuses)}" + } try: # Fetching all members can be intensive, use guild.members if populated, otherwise fetch cautiously # Note: Fetching all members requires the Members privileged intent. - fetched_members = guild.members # Use cached first + fetched_members = guild.members # Use cached first if len(fetched_members) < guild.member_count and cog.bot.intents.members: - print(f"Fetching members for guild {guild.id} as cache seems incomplete...") - # This might take time and requires the intent - # Consider adding a timeout or limiting the fetch if it's too slow - fetched_members = await guild.fetch_members(limit=None).flatten() # Fetch all if intent is enabled + print(f"Fetching members for guild {guild.id} as cache seems incomplete...") + # This might take time and requires the intent + # Consider adding a timeout or limiting the fetch if it's too slow + fetched_members = await guild.fetch_members( + limit=None + ).flatten() # Fetch all if intent is enabled count = 0 for member in fetched_members: @@ -1520,15 +2210,23 @@ async def list_guild_members(cog: commands.Cog, limit: int = 50, status_filter: if role_filter_obj and role_filter_obj not in member.roles: continue - members_list.append({ - "id": str(member.id), - "name": member.name, - "display_name": member.display_name, - "bot": member.bot, - "status": str(member.status), - "joined_at": member.joined_at.isoformat() if member.joined_at else None, - "roles": [{"id": str(r.id), "name": r.name} for r in member.roles if r.name != "@everyone"] - }) + members_list.append( + { + "id": str(member.id), + "name": member.name, + "display_name": member.display_name, + "bot": member.bot, + "status": str(member.status), + "joined_at": ( + member.joined_at.isoformat() if member.joined_at else None + ), + "roles": [ + {"id": str(r.id), "name": r.name} + for r in member.roles + if r.name != "@everyone" + ], + } + ) count += 1 if count >= limit: break @@ -1536,18 +2234,26 @@ async def list_guild_members(cog: commands.Cog, limit: int = 50, status_filter: return { "status": "success", "guild_id": str(guild.id), - "filters_applied": {"limit": limit, "status": status_filter, "role_id": role_id_filter}, + "filters_applied": { + "limit": limit, + "status": status_filter, + "role_id": role_id_filter, + }, "members": members_list, "count": len(members_list), - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } except discord.Forbidden: - return {"error": "Missing permissions or intents (Members) to list guild members."} + return { + "error": "Missing permissions or intents (Members) to list guild members." + } except Exception as e: error_message = f"Error listing guild members: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} + # Tool 3: get_user_avatar async def get_user_avatar(cog: commands.Cog, user_id: str) -> Dict[str, Any]: """Gets the avatar URL for a given user ID.""" @@ -1558,14 +2264,16 @@ async def get_user_avatar(cog: commands.Cog, user_id: str) -> Dict[str, Any]: if not user: return {"error": f"User with ID {user_id} not found."} - avatar_url = str(user.display_avatar.url) # display_avatar handles default/server avatar + avatar_url = str( + user.display_avatar.url + ) # display_avatar handles default/server avatar return { "status": "success", "user_id": user_id, "user_name": user.name, "avatar_url": avatar_url, - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } except ValueError: return {"error": f"Invalid user ID format: {user_id}."} @@ -1573,15 +2281,19 @@ async def get_user_avatar(cog: commands.Cog, user_id: str) -> Dict[str, Any]: return {"error": f"User with ID {user_id} not found."} except Exception as e: error_message = f"Error getting user avatar: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} + # Tool 4: get_bot_uptime async def get_bot_uptime(cog: commands.Cog) -> Dict[str, Any]: """Gets the uptime of the bot.""" print("Executing get_bot_uptime tool.") - if not hasattr(cog, 'start_time'): - return {"error": "Bot start time not recorded in cog."} # Assumes cog has a start_time attribute + if not hasattr(cog, "start_time"): + return { + "error": "Bot start time not recorded in cog." + } # Assumes cog has a start_time attribute try: uptime_delta = datetime.datetime.now(datetime.timezone.utc) - cog.start_time @@ -1598,28 +2310,38 @@ async def get_bot_uptime(cog: commands.Cog) -> Dict[str, Any]: "current_time": datetime.datetime.now(datetime.timezone.utc).isoformat(), "uptime_seconds": total_seconds, "uptime_formatted": uptime_str, - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: error_message = f"Error calculating bot uptime: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} + # Tool 5: schedule_message # This requires a persistent scheduling mechanism (like APScheduler or storing in DB) # For simplicity, this example won't implement persistence, making it non-functional across restarts. # A real implementation needs a background task scheduler. -async def schedule_message(cog: commands.Cog, channel_id: str, message_content: str, send_at_iso: str) -> Dict[str, Any]: +async def schedule_message( + cog: commands.Cog, channel_id: str, message_content: str, send_at_iso: str +) -> Dict[str, Any]: """Schedules a message to be sent in a channel at a specific ISO 8601 time.""" - print(f"Executing schedule_message tool: Channel={channel_id}, Time={send_at_iso}, Content='{message_content[:50]}...'") - if not hasattr(cog, 'scheduler') or not cog.scheduler: - return {"error": "Scheduler not available in the cog. Cannot schedule messages persistently."} + print( + f"Executing schedule_message tool: Channel={channel_id}, Time={send_at_iso}, Content='{message_content[:50]}...'" + ) + if not hasattr(cog, "scheduler") or not cog.scheduler: + return { + "error": "Scheduler not available in the cog. Cannot schedule messages persistently." + } try: send_time = datetime.datetime.fromisoformat(send_at_iso) # Ensure timezone awareness, assume UTC if naive? Or require timezone? Let's require it. if send_time.tzinfo is None: - return {"error": "send_at_iso must include timezone information (e.g., +00:00 or Z)."} + return { + "error": "send_at_iso must include timezone information (e.g., +00:00 or Z)." + } now = datetime.datetime.now(datetime.timezone.utc) if send_time <= now: @@ -1631,21 +2353,23 @@ async def schedule_message(cog: commands.Cog, channel_id: str, message_content: # Try fetching if not in cache channel = await cog.bot.fetch_channel(channel_id_int) if not channel or not isinstance(channel, discord.abc.Messageable): - return {"error": f"Channel {channel_id} not found or not messageable."} + return {"error": f"Channel {channel_id} not found or not messageable."} # Limit message length max_msg_len = 1900 - message_content = message_content[:max_msg_len] + ('...' if len(message_content) > max_msg_len else '') + message_content = message_content[:max_msg_len] + ( + "..." if len(message_content) > max_msg_len else "" + ) # --- Scheduling Logic --- # This uses cog.scheduler.add_job which needs to be implemented using e.g., APScheduler job = cog.scheduler.add_job( - send_discord_message, # Use the existing tool function - 'date', + send_discord_message, # Use the existing tool function + "date", run_date=send_time, - args=[cog, channel_id, message_content], # Pass necessary args - id=f"scheduled_msg_{channel_id}_{int(time.time())}", # Unique job ID - misfire_grace_time=600 # Allow 10 mins grace period + args=[cog, channel_id, message_content], # Pass necessary args + id=f"scheduled_msg_{channel_id}_{int(time.time())}", # Unique job ID + misfire_grace_time=600, # Allow 10 mins grace period ) print(f"Scheduled job {job.id} to send message at {send_time.isoformat()}") @@ -1654,32 +2378,41 @@ async def schedule_message(cog: commands.Cog, channel_id: str, message_content: "job_id": job.id, "channel_id": channel_id, "message_content_preview": message_content[:100], - "scheduled_time_utc": send_time.astimezone(datetime.timezone.utc).isoformat(), - "timestamp": datetime.datetime.now().isoformat() + "scheduled_time_utc": send_time.astimezone( + datetime.timezone.utc + ).isoformat(), + "timestamp": datetime.datetime.now().isoformat(), } except ValueError as e: return {"error": f"Invalid format for channel_id or send_at_iso: {e}"} except (discord.NotFound, discord.Forbidden): - return {"error": f"Cannot access or send messages to channel {channel_id}."} - except Exception as e: # Catch scheduler errors too + return {"error": f"Cannot access or send messages to channel {channel_id}."} + except Exception as e: # Catch scheduler errors too error_message = f"Error scheduling message: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} # Tool 6: delete_message -async def delete_message(cog: commands.Cog, message_id: str, channel_id: Optional[str] = None) -> Dict[str, Any]: +async def delete_message( + cog: commands.Cog, message_id: str, channel_id: Optional[str] = None +) -> Dict[str, Any]: """Deletes a specific message by its ID.""" print(f"Executing delete_message tool for message ID: {message_id}.") try: if channel_id: channel = cog.bot.get_channel(int(channel_id)) - if not channel: return {"error": f"Channel {channel_id} not found."} + if not channel: + return {"error": f"Channel {channel_id} not found."} else: channel = cog.current_channel - if not channel: return {"error": "No current channel context."} + if not channel: + return {"error": "No current channel context."} if not isinstance(channel, discord.abc.Messageable): - return {"error": f"Channel {getattr(channel, 'id', 'N/A')} is not messageable."} + return { + "error": f"Channel {getattr(channel, 'id', 'N/A')} is not messageable." + } message_id_int = int(message_id) message = await channel.fetch_message(message_id_int) @@ -1688,46 +2421,73 @@ async def delete_message(cog: commands.Cog, message_id: str, channel_id: Optiona if isinstance(channel, discord.abc.GuildChannel): bot_member = channel.guild.me # Need 'manage_messages' to delete others' messages, can always delete own - if message.author != bot_member and not channel.permissions_for(bot_member).manage_messages: - return {"error": "Missing 'Manage Messages' permission to delete this message."} + if ( + message.author != bot_member + and not channel.permissions_for(bot_member).manage_messages + ): + return { + "error": "Missing 'Manage Messages' permission to delete this message." + } await message.delete() print(f"Successfully deleted message {message_id} in channel {channel.id}.") - return {"status": "success", "message_id": message_id, "channel_id": str(channel.id)} + return { + "status": "success", + "message_id": message_id, + "channel_id": str(channel.id), + } except ValueError: return {"error": f"Invalid message_id or channel_id format."} except discord.NotFound: - return {"error": f"Message {message_id} not found in channel {channel_id or getattr(channel, 'id', 'N/A')}."} + return { + "error": f"Message {message_id} not found in channel {channel_id or getattr(channel, 'id', 'N/A')}." + } except discord.Forbidden: - return {"error": f"Forbidden: Missing permissions to delete message {message_id}."} + return { + "error": f"Forbidden: Missing permissions to delete message {message_id}." + } except discord.HTTPException as e: error_message = f"API error deleting message {message_id}: {e}" print(error_message) return {"error": error_message} except Exception as e: error_message = f"Unexpected error deleting message {message_id}: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} + # Tool 7: edit_message -async def edit_message(cog: commands.Cog, message_id: str, new_content: str, channel_id: Optional[str] = None) -> Dict[str, Any]: +async def edit_message( + cog: commands.Cog, + message_id: str, + new_content: str, + channel_id: Optional[str] = None, +) -> Dict[str, Any]: """Edits a message sent by the bot.""" print(f"Executing edit_message tool for message ID: {message_id}.") - if not new_content: return {"error": "New content cannot be empty."} + if not new_content: + return {"error": "New content cannot be empty."} # Limit message length max_msg_len = 1900 - new_content = new_content[:max_msg_len] + ('...' if len(new_content) > max_msg_len else '') + new_content = new_content[:max_msg_len] + ( + "..." if len(new_content) > max_msg_len else "" + ) try: if channel_id: channel = cog.bot.get_channel(int(channel_id)) - if not channel: return {"error": f"Channel {channel_id} not found."} + if not channel: + return {"error": f"Channel {channel_id} not found."} else: channel = cog.current_channel - if not channel: return {"error": "No current channel context."} + if not channel: + return {"error": "No current channel context."} if not isinstance(channel, discord.abc.Messageable): - return {"error": f"Channel {getattr(channel, 'id', 'N/A')} is not messageable."} + return { + "error": f"Channel {getattr(channel, 'id', 'N/A')} is not messageable." + } message_id_int = int(message_id) message = await channel.fetch_message(message_id_int) @@ -1738,24 +2498,35 @@ async def edit_message(cog: commands.Cog, message_id: str, new_content: str, cha await message.edit(content=new_content) print(f"Successfully edited message {message_id} in channel {channel.id}.") - return {"status": "success", "message_id": message_id, "channel_id": str(channel.id), "new_content_preview": new_content[:100]} + return { + "status": "success", + "message_id": message_id, + "channel_id": str(channel.id), + "new_content_preview": new_content[:100], + } except ValueError: return {"error": f"Invalid message_id or channel_id format."} except discord.NotFound: - return {"error": f"Message {message_id} not found in channel {channel_id or getattr(channel, 'id', 'N/A')}."} + return { + "error": f"Message {message_id} not found in channel {channel_id or getattr(channel, 'id', 'N/A')}." + } except discord.Forbidden: # This usually shouldn't happen if we check author == bot, but include for safety - return {"error": f"Forbidden: Missing permissions to edit message {message_id} (shouldn't happen for own message)."} + return { + "error": f"Forbidden: Missing permissions to edit message {message_id} (shouldn't happen for own message)." + } except discord.HTTPException as e: error_message = f"API error editing message {message_id}: {e}" print(error_message) return {"error": error_message} except Exception as e: error_message = f"Unexpected error editing message {message_id}: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} + # Tool 8: get_voice_channel_info async def get_voice_channel_info(cog: commands.Cog, channel_id: str) -> Dict[str, Any]: """Gets information about a specific voice channel.""" @@ -1767,21 +2538,33 @@ async def get_voice_channel_info(cog: commands.Cog, channel_id: str) -> Dict[str if not channel: return {"error": f"Channel {channel_id} not found."} if not isinstance(channel, discord.VoiceChannel): - return {"error": f"Channel {channel_id} is not a voice channel (Type: {type(channel)})."} + return { + "error": f"Channel {channel_id} is not a voice channel (Type: {type(channel)})." + } members_info = [] for member in channel.members: - members_info.append({ - "id": str(member.id), - "name": member.name, - "display_name": member.display_name, - "voice_state": { - "deaf": member.voice.deaf, "mute": member.voice.mute, - "self_deaf": member.voice.self_deaf, "self_mute": member.voice.self_mute, - "self_stream": member.voice.self_stream, "self_video": member.voice.self_video, - "suppress": member.voice.suppress, "afk": member.voice.afk - } if member.voice else None - }) + members_info.append( + { + "id": str(member.id), + "name": member.name, + "display_name": member.display_name, + "voice_state": ( + { + "deaf": member.voice.deaf, + "mute": member.voice.mute, + "self_deaf": member.voice.self_deaf, + "self_mute": member.voice.self_mute, + "self_stream": member.voice.self_stream, + "self_video": member.voice.self_video, + "suppress": member.voice.suppress, + "afk": member.voice.afk, + } + if member.voice + else None + ), + } + ) return { "status": "success", @@ -1790,39 +2573,56 @@ async def get_voice_channel_info(cog: commands.Cog, channel_id: str) -> Dict[str "bitrate": channel.bitrate, "user_limit": channel.user_limit, "rtc_region": str(channel.rtc_region) if channel.rtc_region else None, - "category": {"id": str(channel.category_id), "name": channel.category.name} if channel.category else None, + "category": ( + {"id": str(channel.category_id), "name": channel.category.name} + if channel.category + else None + ), "guild_id": str(channel.guild.id), "connected_members": members_info, "member_count": len(members_info), - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } except ValueError: return {"error": f"Invalid channel ID format: {channel_id}."} except Exception as e: error_message = f"Error getting voice channel info: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} + # Tool 9: move_user_to_voice_channel -async def move_user_to_voice_channel(cog: commands.Cog, user_id: str, target_channel_id: str) -> Dict[str, Any]: +async def move_user_to_voice_channel( + cog: commands.Cog, user_id: str, target_channel_id: str +) -> Dict[str, Any]: """Moves a user to a specified voice channel within the same server.""" - print(f"Executing move_user_to_voice_channel tool: User={user_id}, TargetChannel={target_channel_id}.") - if not cog.current_channel or not isinstance(cog.current_channel, discord.abc.GuildChannel): + print( + f"Executing move_user_to_voice_channel tool: User={user_id}, TargetChannel={target_channel_id}." + ) + if not cog.current_channel or not isinstance( + cog.current_channel, discord.abc.GuildChannel + ): return {"error": "Cannot move users outside of a server context."} guild = cog.current_channel.guild - if not guild: return {"error": "Could not determine server."} + if not guild: + return {"error": "Could not determine server."} try: user_id_int = int(user_id) target_channel_id_int = int(target_channel_id) member = guild.get_member(user_id_int) or await guild.fetch_member(user_id_int) - if not member: return {"error": f"User {user_id} not found in this server."} + if not member: + return {"error": f"User {user_id} not found in this server."} target_channel = guild.get_channel(target_channel_id_int) - if not target_channel: return {"error": f"Target voice channel {target_channel_id} not found."} + if not target_channel: + return {"error": f"Target voice channel {target_channel_id} not found."} if not isinstance(target_channel, discord.VoiceChannel): - return {"error": f"Target channel {target_channel_id} is not a voice channel."} + return { + "error": f"Target channel {target_channel_id} is not a voice channel." + } # Permission Checks bot_member = guild.me @@ -1831,22 +2631,36 @@ async def move_user_to_voice_channel(cog: commands.Cog, user_id: str, target_cha # Check bot permissions in both origin (if user is connected) and target channels if member.voice and member.voice.channel: origin_channel = member.voice.channel - if not origin_channel.permissions_for(bot_member).connect or not origin_channel.permissions_for(bot_member).move_members: - return {"error": f"I lack Connect/Move permissions in the user's current channel ({origin_channel.name})."} - if not target_channel.permissions_for(bot_member).connect or not target_channel.permissions_for(bot_member).move_members: - return {"error": f"I lack Connect/Move permissions in the target channel ({target_channel.name})."} + if ( + not origin_channel.permissions_for(bot_member).connect + or not origin_channel.permissions_for(bot_member).move_members + ): + return { + "error": f"I lack Connect/Move permissions in the user's current channel ({origin_channel.name})." + } + if ( + not target_channel.permissions_for(bot_member).connect + or not target_channel.permissions_for(bot_member).move_members + ): + return { + "error": f"I lack Connect/Move permissions in the target channel ({target_channel.name})." + } # Cannot move user if bot's top role is not higher (unless bot is owner) if bot_member.id != guild.owner_id and bot_member.top_role <= member.top_role: - return {"error": f"Cannot move {member.display_name} due to role hierarchy."} + return { + "error": f"Cannot move {member.display_name} due to role hierarchy." + } await member.move_to(target_channel, reason="Moved by Gurt tool") - print(f"Successfully moved {member.display_name} ({user_id}) to voice channel {target_channel.name} ({target_channel_id}).") + print( + f"Successfully moved {member.display_name} ({user_id}) to voice channel {target_channel.name} ({target_channel_id})." + ) return { "status": "success", "user_id": user_id, "user_name": member.display_name, "target_channel_id": target_channel_id, - "target_channel_name": target_channel.name + "target_channel_name": target_channel.name, } except ValueError: @@ -1861,14 +2675,18 @@ async def move_user_to_voice_channel(cog: commands.Cog, user_id: str, target_cha return {"error": f"API error moving user {user_id}: {e}"} except Exception as e: error_message = f"Unexpected error moving user {user_id}: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} + # Tool 10: get_guild_roles async def get_guild_roles(cog: commands.Cog) -> Dict[str, Any]: """Lists all roles in the current server.""" print("Executing get_guild_roles tool.") - if not cog.current_channel or not isinstance(cog.current_channel, discord.abc.GuildChannel): + if not cog.current_channel or not isinstance( + cog.current_channel, discord.abc.GuildChannel + ): return {"error": "Cannot get roles outside of a server channel."} guild = cog.current_channel.guild if not guild: @@ -1877,49 +2695,61 @@ async def get_guild_roles(cog: commands.Cog) -> Dict[str, Any]: try: roles_list = [] # Roles are ordered by position, highest first (excluding @everyone) - for role in reversed(guild.roles): # Iterate from lowest to highest position - if role.name == "@everyone": continue - roles_list.append({ - "id": str(role.id), - "name": role.name, - "color": str(role.color), - "position": role.position, - "is_mentionable": role.mentionable, - "member_count": len(role.members) # Can be slow on large servers - }) + for role in reversed(guild.roles): # Iterate from lowest to highest position + if role.name == "@everyone": + continue + roles_list.append( + { + "id": str(role.id), + "name": role.name, + "color": str(role.color), + "position": role.position, + "is_mentionable": role.mentionable, + "member_count": len(role.members), # Can be slow on large servers + } + ) return { "status": "success", "guild_id": str(guild.id), "roles": roles_list, "count": len(roles_list), - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: error_message = f"Error listing guild roles: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} # Tool 11: assign_role_to_user -async def assign_role_to_user(cog: commands.Cog, user_id: str, role_id: str) -> Dict[str, Any]: +async def assign_role_to_user( + cog: commands.Cog, user_id: str, role_id: str +) -> Dict[str, Any]: """Assigns a specific role to a user.""" print(f"Executing assign_role_to_user tool: User={user_id}, Role={role_id}.") - if not cog.current_channel or not isinstance(cog.current_channel, discord.abc.GuildChannel): + if not cog.current_channel or not isinstance( + cog.current_channel, discord.abc.GuildChannel + ): return {"error": "Cannot manage roles outside of a server context."} guild = cog.current_channel.guild - if not guild: return {"error": "Could not determine server."} + if not guild: + return {"error": "Could not determine server."} try: user_id_int = int(user_id) role_id_int = int(role_id) member = guild.get_member(user_id_int) or await guild.fetch_member(user_id_int) - if not member: return {"error": f"User {user_id} not found in this server."} + if not member: + return {"error": f"User {user_id} not found in this server."} role = guild.get_role(role_id_int) - if not role: return {"error": f"Role {role_id} not found in this server."} - if role.name == "@everyone": return {"error": "Cannot assign the @everyone role."} + if not role: + return {"error": f"Role {role_id} not found in this server."} + if role.name == "@everyone": + return {"error": "Cannot assign the @everyone role."} # Permission Checks bot_member = guild.me @@ -1927,19 +2757,28 @@ async def assign_role_to_user(cog: commands.Cog, user_id: str, role_id: str) -> return {"error": "I lack the 'Manage Roles' permission."} # Check role hierarchy: Bot's top role must be higher than the role being assigned if bot_member.id != guild.owner_id and bot_member.top_role <= role: - return {"error": f"Cannot assign role '{role.name}' because my highest role is not above it."} + return { + "error": f"Cannot assign role '{role.name}' because my highest role is not above it." + } # Check if user already has the role if role in member.roles: - return {"status": "already_has_role", "user_id": user_id, "role_id": role_id, "role_name": role.name} + return { + "status": "already_has_role", + "user_id": user_id, + "role_id": role_id, + "role_name": role.name, + } await member.add_roles(role, reason="Assigned by Gurt tool") - print(f"Successfully assigned role '{role.name}' ({role_id}) to {member.display_name} ({user_id}).") + print( + f"Successfully assigned role '{role.name}' ({role_id}) to {member.display_name} ({user_id})." + ) return { "status": "success", "user_id": user_id, "user_name": member.display_name, "role_id": role_id, - "role_name": role.name + "role_name": role.name, } except ValueError: @@ -1953,29 +2792,41 @@ async def assign_role_to_user(cog: commands.Cog, user_id: str, role_id: str) -> print(f"API error assigning role {role_id} to {user_id}: {e}") return {"error": f"API error assigning role: {e}"} except Exception as e: - error_message = f"Unexpected error assigning role {role_id} to {user_id}: {str(e)}" - print(error_message); traceback.print_exc() + error_message = ( + f"Unexpected error assigning role {role_id} to {user_id}: {str(e)}" + ) + print(error_message) + traceback.print_exc() return {"error": error_message} + # Tool 12: remove_role_from_user -async def remove_role_from_user(cog: commands.Cog, user_id: str, role_id: str) -> Dict[str, Any]: +async def remove_role_from_user( + cog: commands.Cog, user_id: str, role_id: str +) -> Dict[str, Any]: """Removes a specific role from a user.""" print(f"Executing remove_role_from_user tool: User={user_id}, Role={role_id}.") - if not cog.current_channel or not isinstance(cog.current_channel, discord.abc.GuildChannel): + if not cog.current_channel or not isinstance( + cog.current_channel, discord.abc.GuildChannel + ): return {"error": "Cannot manage roles outside of a server context."} guild = cog.current_channel.guild - if not guild: return {"error": "Could not determine server."} + if not guild: + return {"error": "Could not determine server."} try: user_id_int = int(user_id) role_id_int = int(role_id) member = guild.get_member(user_id_int) or await guild.fetch_member(user_id_int) - if not member: return {"error": f"User {user_id} not found in this server."} + if not member: + return {"error": f"User {user_id} not found in this server."} role = guild.get_role(role_id_int) - if not role: return {"error": f"Role {role_id} not found in this server."} - if role.name == "@everyone": return {"error": "Cannot remove the @everyone role."} + if not role: + return {"error": f"Role {role_id} not found in this server."} + if role.name == "@everyone": + return {"error": "Cannot remove the @everyone role."} # Permission Checks bot_member = guild.me @@ -1983,19 +2834,28 @@ async def remove_role_from_user(cog: commands.Cog, user_id: str, role_id: str) - return {"error": "I lack the 'Manage Roles' permission."} # Check role hierarchy: Bot's top role must be higher than the role being removed if bot_member.id != guild.owner_id and bot_member.top_role <= role: - return {"error": f"Cannot remove role '{role.name}' because my highest role is not above it."} + return { + "error": f"Cannot remove role '{role.name}' because my highest role is not above it." + } # Check if user actually has the role if role not in member.roles: - return {"status": "does_not_have_role", "user_id": user_id, "role_id": role_id, "role_name": role.name} + return { + "status": "does_not_have_role", + "user_id": user_id, + "role_id": role_id, + "role_name": role.name, + } await member.remove_roles(role, reason="Removed by Gurt tool") - print(f"Successfully removed role '{role.name}' ({role_id}) from {member.display_name} ({user_id}).") + print( + f"Successfully removed role '{role.name}' ({role_id}) from {member.display_name} ({user_id})." + ) return { "status": "success", "user_id": user_id, "user_name": member.display_name, "role_id": role_id, - "role_name": role.name + "role_name": role.name, } except ValueError: @@ -2009,83 +2869,115 @@ async def remove_role_from_user(cog: commands.Cog, user_id: str, role_id: str) - print(f"API error removing role {role_id} from {user_id}: {e}") return {"error": f"API error removing role: {e}"} except Exception as e: - error_message = f"Unexpected error removing role {role_id} from {user_id}: {str(e)}" - print(error_message); traceback.print_exc() + error_message = ( + f"Unexpected error removing role {role_id} from {user_id}: {str(e)}" + ) + print(error_message) + traceback.print_exc() return {"error": error_message} + # Tool 13: fetch_emoji_list async def fetch_emoji_list(cog: commands.Cog) -> Dict[str, Any]: """Lists all custom emojis available in the current server.""" print("Executing fetch_emoji_list tool.") - if not cog.current_channel or not isinstance(cog.current_channel, discord.abc.GuildChannel): + if not cog.current_channel or not isinstance( + cog.current_channel, discord.abc.GuildChannel + ): return {"error": "Cannot fetch emojis outside of a server context."} guild = cog.current_channel.guild - if not guild: return {"error": "Could not determine server."} + if not guild: + return {"error": "Could not determine server."} try: emojis_list = [] for emoji in guild.emojis: - emojis_list.append({ - "id": str(emoji.id), - "name": emoji.name, - "url": str(emoji.url), - "is_animated": emoji.animated, - "is_managed": emoji.managed, # e.g., Twitch integration emojis - "available": emoji.available, # If the bot can use it - "created_at": emoji.created_at.isoformat() if emoji.created_at else None - }) + emojis_list.append( + { + "id": str(emoji.id), + "name": emoji.name, + "url": str(emoji.url), + "is_animated": emoji.animated, + "is_managed": emoji.managed, # e.g., Twitch integration emojis + "available": emoji.available, # If the bot can use it + "created_at": ( + emoji.created_at.isoformat() if emoji.created_at else None + ), + } + ) return { "status": "success", "guild_id": str(guild.id), "emojis": emojis_list, "count": len(emojis_list), - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: error_message = f"Error fetching emoji list: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} + # Tool 14: get_guild_invites async def get_guild_invites(cog: commands.Cog) -> Dict[str, Any]: """Lists active invite links for the current server. Requires 'Manage Server' permission.""" print("Executing get_guild_invites tool.") - if not cog.current_channel or not isinstance(cog.current_channel, discord.abc.GuildChannel): + if not cog.current_channel or not isinstance( + cog.current_channel, discord.abc.GuildChannel + ): return {"error": "Cannot get invites outside of a server context."} guild = cog.current_channel.guild - if not guild: return {"error": "Could not determine server."} + if not guild: + return {"error": "Could not determine server."} # Permission Check bot_member = guild.me if not bot_member.guild_permissions.manage_guild: - return {"error": "I lack the 'Manage Server' permission required to view invites."} + return { + "error": "I lack the 'Manage Server' permission required to view invites." + } try: invites = await guild.invites() invites_list = [] for invite in invites: - inviter_info = {"id": str(invite.inviter.id), "name": invite.inviter.name} if invite.inviter else None - channel_info = {"id": str(invite.channel.id), "name": invite.channel.name} if invite.channel else None - invites_list.append({ - "code": invite.code, - "url": invite.url, - "inviter": inviter_info, - "channel": channel_info, - "uses": invite.uses, - "max_uses": invite.max_uses, - "max_age": invite.max_age, # In seconds, 0 means infinite - "is_temporary": invite.temporary, - "created_at": invite.created_at.isoformat() if invite.created_at else None, - "expires_at": invite.expires_at.isoformat() if invite.expires_at else None, - }) + inviter_info = ( + {"id": str(invite.inviter.id), "name": invite.inviter.name} + if invite.inviter + else None + ) + channel_info = ( + {"id": str(invite.channel.id), "name": invite.channel.name} + if invite.channel + else None + ) + invites_list.append( + { + "code": invite.code, + "url": invite.url, + "inviter": inviter_info, + "channel": channel_info, + "uses": invite.uses, + "max_uses": invite.max_uses, + "max_age": invite.max_age, # In seconds, 0 means infinite + "is_temporary": invite.temporary, + "created_at": ( + invite.created_at.isoformat() if invite.created_at else None + ), + "expires_at": ( + invite.expires_at.isoformat() if invite.expires_at else None + ), + } + ) return { "status": "success", "guild_id": str(guild.id), "invites": invites_list, "count": len(invites_list), - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } except discord.Forbidden: # Should be caught by initial check, but good practice @@ -2095,38 +2987,66 @@ async def get_guild_invites(cog: commands.Cog) -> Dict[str, Any]: return {"error": f"API error getting invites: {e}"} except Exception as e: error_message = f"Unexpected error getting invites: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} + # Tool 15: purge_messages -async def purge_messages(cog: commands.Cog, limit: int, channel_id: Optional[str] = None, user_id: Optional[str] = None, before_message_id: Optional[str] = None, after_message_id: Optional[str] = None) -> Dict[str, Any]: +async def purge_messages( + cog: commands.Cog, + limit: int, + channel_id: Optional[str] = None, + user_id: Optional[str] = None, + before_message_id: Optional[str] = None, + after_message_id: Optional[str] = None, +) -> Dict[str, Any]: """Bulk deletes messages in a channel. Requires 'Manage Messages' permission.""" - print(f"Executing purge_messages tool: Limit={limit}, Channel={channel_id}, User={user_id}, Before={before_message_id}, After={after_message_id}.") - if not 1 <= limit <= 1000: # Discord's practical limit is often lower, but API allows up to 100 per call + print( + f"Executing purge_messages tool: Limit={limit}, Channel={channel_id}, User={user_id}, Before={before_message_id}, After={after_message_id}." + ) + if ( + not 1 <= limit <= 1000 + ): # Discord's practical limit is often lower, but API allows up to 100 per call return {"error": "Limit must be between 1 and 1000."} try: if channel_id: channel = cog.bot.get_channel(int(channel_id)) - if not channel: return {"error": f"Channel {channel_id} not found."} + if not channel: + return {"error": f"Channel {channel_id} not found."} else: channel = cog.current_channel - if not channel: return {"error": "No current channel context."} - if not isinstance(channel, discord.TextChannel): # Purge usually only for text channels - return {"error": f"Channel {getattr(channel, 'id', 'N/A')} must be a text channel."} + if not channel: + return {"error": "No current channel context."} + if not isinstance( + channel, discord.TextChannel + ): # Purge usually only for text channels + return { + "error": f"Channel {getattr(channel, 'id', 'N/A')} must be a text channel." + } # Permission Check bot_member = channel.guild.me if not channel.permissions_for(bot_member).manage_messages: - return {"error": "I lack the 'Manage Messages' permission required to purge."} + return { + "error": "I lack the 'Manage Messages' permission required to purge." + } target_user = None if user_id: - target_user = await cog.bot.fetch_user(int(user_id)) # Fetch user object if ID provided - if not target_user: return {"error": f"User {user_id} not found."} + target_user = await cog.bot.fetch_user( + int(user_id) + ) # Fetch user object if ID provided + if not target_user: + return {"error": f"User {user_id} not found."} - before_obj = discord.Object(id=int(before_message_id)) if before_message_id else None - after_obj = discord.Object(id=int(after_message_id)) if after_message_id else None + before_obj = ( + discord.Object(id=int(before_message_id)) if before_message_id else None + ) + after_obj = ( + discord.Object(id=int(after_message_id)) if after_message_id else None + ) check_func = (lambda m: m.author == target_user) if target_user else None @@ -2136,22 +3056,30 @@ async def purge_messages(cog: commands.Cog, limit: int, channel_id: Optional[str check=check_func, before=before_obj, after=after_obj, - reason="Purged by Gurt tool" + reason="Purged by Gurt tool", ) deleted_count = len(deleted_messages) - print(f"Successfully purged {deleted_count} messages from channel {channel.id}.") + print( + f"Successfully purged {deleted_count} messages from channel {channel.id}." + ) return { "status": "success", "channel_id": str(channel.id), "deleted_count": deleted_count, "limit_requested": limit, - "filters_applied": {"user_id": user_id, "before": before_message_id, "after": after_message_id}, - "timestamp": datetime.datetime.now().isoformat() + "filters_applied": { + "user_id": user_id, + "before": before_message_id, + "after": after_message_id, + }, + "timestamp": datetime.datetime.now().isoformat(), } except ValueError: - return {"error": "Invalid ID format for channel, user, before, or after message."} + return { + "error": "Invalid ID format for channel, user, before, or after message." + } except discord.NotFound: return {"error": "Channel, user, before, or after message not found."} except discord.Forbidden: @@ -2160,11 +3088,14 @@ async def purge_messages(cog: commands.Cog, limit: int, channel_id: Optional[str print(f"API error purging messages: {e}") # Provide more specific feedback if possible (e.g., messages too old) if "too old" in str(e).lower(): - return {"error": "API error: Cannot bulk delete messages older than 14 days."} + return { + "error": "API error: Cannot bulk delete messages older than 14 days." + } return {"error": f"API error purging messages: {e}"} except Exception as e: error_message = f"Unexpected error purging messages: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} @@ -2176,14 +3107,19 @@ async def get_bot_stats(cog: commands.Cog) -> Dict[str, Any]: try: # Example stats (replace with actual data sources) guild_count = len(cog.bot.guilds) - user_count = len(cog.bot.users) # Might not be accurate without intents - total_users = sum(g.member_count for g in cog.bot.guilds if g.member_count) # Requires member intent + user_count = len(cog.bot.users) # Might not be accurate without intents + total_users = sum( + g.member_count for g in cog.bot.guilds if g.member_count + ) # Requires member intent latency_ms = round(cog.bot.latency * 1000) # Command usage would need tracking within the cog/bot - command_count = cog.command_usage_count if hasattr(cog, 'command_usage_count') else "N/A" + command_count = ( + cog.command_usage_count if hasattr(cog, "command_usage_count") else "N/A" + ) # Memory usage (platform specific, using psutil is common) try: import psutil + process = psutil.Process(os.getpid()) memory_mb = round(process.memory_info().rss / (1024 * 1024), 2) except ImportError: @@ -2191,26 +3127,28 @@ async def get_bot_stats(cog: commands.Cog) -> Dict[str, Any]: except Exception as mem_e: memory_mb = f"Error ({mem_e})" - uptime_dict = await get_bot_uptime(cog) # Reuse uptime tool + uptime_dict = await get_bot_uptime(cog) # Reuse uptime tool return { "status": "success", "guild_count": guild_count, "cached_user_count": user_count, - "total_member_count_approx": total_users, # Note intent requirement + "total_member_count_approx": total_users, # Note intent requirement "latency_ms": latency_ms, "command_usage_count": command_count, "memory_usage_mb": memory_mb, - "uptime_info": uptime_dict, # Include uptime details + "uptime_info": uptime_dict, # Include uptime details "python_version": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}", "discord_py_version": discord.__version__, - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: error_message = f"Error getting bot stats: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} + # Tool 17: get_weather (Placeholder - Requires Weather API) async def get_weather(cog: commands.Cog, location: str) -> Dict[str, Any]: """Gets the current weather for a specified location (requires external API setup).""" @@ -2241,13 +3179,21 @@ async def get_weather(cog: commands.Cog, location: str) -> Dict[str, Any]: "status": "placeholder", "error": "Weather tool not fully implemented. Requires external API integration.", "location_requested": location, - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } + # Tool 18: translate_text (Placeholder - Requires Translation API) -async def translate_text(cog: commands.Cog, text: str, target_language: str, source_language: Optional[str] = None) -> Dict[str, Any]: +async def translate_text( + cog: commands.Cog, + text: str, + target_language: str, + source_language: Optional[str] = None, +) -> Dict[str, Any]: """Translates text to a target language (requires external API setup).""" - print(f"Executing translate_text tool: Target={target_language}, Source={source_language}, Text='{text[:50]}...'") + print( + f"Executing translate_text tool: Target={target_language}, Source={source_language}, Text='{text[:50]}...'" + ) # --- Placeholder Implementation --- # A real implementation would use a translation API (e.g., Google Translate API, DeepL) # It would require API keys/credentials and use a suitable library or aiohttp. @@ -2276,13 +3222,18 @@ async def translate_text(cog: commands.Cog, text: str, target_language: str, sou "error": "Translation tool not fully implemented. Requires external API integration.", "text_preview": text[:100], "target_language": target_language, - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } + # Tool 19: remind_user (Placeholder - Requires Scheduler/DB) -async def remind_user(cog: commands.Cog, user_id: str, reminder_text: str, remind_at_iso: str) -> Dict[str, Any]: +async def remind_user( + cog: commands.Cog, user_id: str, reminder_text: str, remind_at_iso: str +) -> Dict[str, Any]: """Sets a reminder for a user to be delivered via DM at a specific time.""" - print(f"Executing remind_user tool: User={user_id}, Time={remind_at_iso}, Reminder='{reminder_text[:50]}...'") + print( + f"Executing remind_user tool: User={user_id}, Time={remind_at_iso}, Reminder='{reminder_text[:50]}...'" + ) # --- Placeholder Implementation --- # This requires a persistent scheduler (like APScheduler) and likely a way to store reminders # in case the bot restarts. It also needs to fetch the user and send a DM. @@ -2327,11 +3278,14 @@ async def remind_user(cog: commands.Cog, user_id: str, reminder_text: str, remin "user_id": user_id, "reminder_text_preview": reminder_text[:100], "remind_at_iso": remind_at_iso, - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } + # Tool 20: fetch_random_image (Placeholder - Requires Image API/Source) -async def fetch_random_image(cog: commands.Cog, query: Optional[str] = None) -> Dict[str, Any]: +async def fetch_random_image( + cog: commands.Cog, query: Optional[str] = None +) -> Dict[str, Any]: """Fetches a random image, optionally based on a query (requires external API setup).""" print(f"Executing fetch_random_image tool: Query='{query}'") # --- Placeholder Implementation --- @@ -2362,17 +3316,18 @@ async def fetch_random_image(cog: commands.Cog, query: Optional[str] = None) -> "status": "placeholder", "error": "Random image tool not fully implemented. Requires external API integration.", "query": query, - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } # --- Random System/Meme Tools --- + async def read_temps(cog: commands.Cog) -> Dict[str, Any]: """Reads the system temperatures using the 'sensors' command (Linux/Unix).""" import platform import subprocess - import asyncio # Ensure asyncio is imported + import asyncio # Ensure asyncio is imported try: if platform.system() == "Windows": @@ -2380,7 +3335,7 @@ async def read_temps(cog: commands.Cog) -> Dict[str, Any]: return { "status": "not_supported", "output": None, - "error": "The 'sensors' command is typically not available on Windows." + "error": "The 'sensors' command is typically not available on Windows.", } else: # Try to run the 'sensors' command @@ -2388,56 +3343,58 @@ async def read_temps(cog: commands.Cog) -> Dict[str, Any]: proc = await asyncio.create_subprocess_shell( "sensors", stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stderr=asyncio.subprocess.PIPE, ) - stdout_bytes, stderr_bytes = await asyncio.wait_for(proc.communicate(), timeout=10) # Add timeout + stdout_bytes, stderr_bytes = await asyncio.wait_for( + proc.communicate(), timeout=10 + ) # Add timeout stdout = stdout_bytes.decode(errors="replace").strip() stderr = stderr_bytes.decode(errors="replace").strip() if proc.returncode == 0: # Command succeeded, return the full stdout - max_len = 1800 # Limit output length slightly - stdout_trunc = stdout[:max_len] + ('...' if len(stdout) > max_len else '') + max_len = 1800 # Limit output length slightly + stdout_trunc = stdout[:max_len] + ( + "..." if len(stdout) > max_len else "" + ) return { "status": "success", - "output": stdout_trunc, # Return truncated stdout - "error": None + "output": stdout_trunc, # Return truncated stdout + "error": None, } else: # Command failed - error_msg = f"'sensors' command failed with exit code {proc.returncode}." + error_msg = ( + f"'sensors' command failed with exit code {proc.returncode}." + ) if stderr: - error_msg += f" Stderr: {stderr[:200]}" # Include some stderr + error_msg += f" Stderr: {stderr[:200]}" # Include some stderr print(f"read_temps error: {error_msg}") return { "status": "execution_error", "output": None, - "error": error_msg + "error": error_msg, } except FileNotFoundError: - print("read_temps error: 'sensors' command not found.") - return { + print("read_temps error: 'sensors' command not found.") + return { "status": "error", "output": None, - "error": "'sensors' command not found. Is lm-sensors installed and in PATH?" - } + "error": "'sensors' command not found. Is lm-sensors installed and in PATH?", + } except asyncio.TimeoutError: print("read_temps error: 'sensors' command timed out.") return { "status": "timeout", "output": None, - "error": "'sensors' command timed out after 10 seconds." + "error": "'sensors' command timed out after 10 seconds.", } except Exception as cmd_e: # Catch other potential errors during subprocess execution error_msg = f"Error running 'sensors' command: {str(cmd_e)}" print(f"read_temps error: {error_msg}") traceback.print_exc() - return { - "status": "error", - "output": None, - "error": error_msg - } + return {"status": "error", "output": None, "error": error_msg} except Exception as e: # Catch unexpected errors in the function itself error_msg = f"Unexpected error in read_temps: {str(e)}" @@ -2445,12 +3402,14 @@ async def read_temps(cog: commands.Cog) -> Dict[str, Any]: traceback.print_exc() return {"status": "error", "output": None, "error": error_msg} + async def check_disk_space(cog: commands.Cog) -> Dict[str, Any]: """Checks disk space on the main drive.""" import shutil + try: total, used, free = shutil.disk_usage("/") - gb = 1024 ** 3 + gb = 1024**3 percent = round(used / total * 100, 1) return { "status": "success", @@ -2458,11 +3417,12 @@ async def check_disk_space(cog: commands.Cog) -> Dict[str, Any]: "used_gb": round(used / gb, 2), "free_gb": round(free / gb, 2), "percent_used": percent, - "msg": None + "msg": None, } except Exception as e: return {"status": "error", "error": str(e)} + async def fetch_random_joke(cog: commands.Cog) -> Dict[str, Any]: """Fetches a random joke from an API.""" url = "https://official-joke-api.appspot.com/random_joke" @@ -2474,10 +3434,7 @@ async def fetch_random_joke(cog: commands.Cog) -> Dict[str, Any]: data = await resp.json() setup = data.get("setup", "") punchline = data.get("punchline", "") - return { - "status": "success", - "joke": f"{setup} ... {punchline}" - } + return {"status": "success", "joke": f"{setup} ... {punchline}"} else: return {"status": "error", "error": f"API returned {resp.status}"} except Exception as e: @@ -2486,22 +3443,27 @@ async def fetch_random_joke(cog: commands.Cog) -> Dict[str, Any]: # --- New Tools: Guild/Channel Listing --- + async def list_bot_guilds(cog: commands.Cog) -> Dict[str, Any]: """Lists all guilds (servers) the bot is currently connected to.""" print("Executing list_bot_guilds tool.") try: - guilds_list = [{"id": str(guild.id), "name": guild.name} for guild in cog.bot.guilds] + guilds_list = [ + {"id": str(guild.id), "name": guild.name} for guild in cog.bot.guilds + ] return { "status": "success", "guilds": guilds_list, "count": len(guilds_list), - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: error_message = f"Error listing bot guilds: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} + async def list_guild_channels(cog: commands.Cog, guild_id: str) -> Dict[str, Any]: """Lists all channels (text, voice, category, etc.) in a specified guild.""" print(f"Executing list_guild_channels tool for guild ID: {guild_id}.") @@ -2513,15 +3475,19 @@ async def list_guild_channels(cog: commands.Cog, guild_id: str) -> Dict[str, Any channels_list = [] for channel in guild.channels: - channels_list.append({ - "id": str(channel.id), - "name": channel.name, - "type": str(channel.type), - "position": channel.position, - "category_id": str(channel.category_id) if channel.category_id else None - }) + channels_list.append( + { + "id": str(channel.id), + "name": channel.name, + "type": str(channel.type), + "position": channel.position, + "category_id": ( + str(channel.category_id) if channel.category_id else None + ), + } + ) # Sort by position for better readability - channels_list.sort(key=lambda x: x.get('position', 0)) + channels_list.sort(key=lambda x: x.get("position", 0)) return { "status": "success", @@ -2529,60 +3495,85 @@ async def list_guild_channels(cog: commands.Cog, guild_id: str) -> Dict[str, Any "guild_name": guild.name, "channels": channels_list, "count": len(channels_list), - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } except ValueError: return {"error": f"Invalid guild ID format: {guild_id}."} except Exception as e: error_message = f"Error listing channels for guild {guild_id}: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} + async def list_tools(cog: commands.Cog) -> Dict[str, Any]: """Lists all available tools with their names and descriptions.""" print("Executing list_tools tool.") try: # TOOLS is imported from .config - tool_list = [{"name": tool.name, "description": tool.description} for tool in TOOLS] + tool_list = [ + {"name": tool.name, "description": tool.description} for tool in TOOLS + ] return { "status": "success", "tools": tool_list, "count": len(tool_list), - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: error_message = f"Error listing tools: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} + # --- User Profile Tools --- -async def _get_user_or_member(cog: commands.Cog, user_id_str: str) -> Tuple[Optional[Union[discord.User, discord.Member]], Optional[Dict[str, Any]]]: + +async def _get_user_or_member( + cog: commands.Cog, user_id_str: str +) -> Tuple[Optional[Union[discord.User, discord.Member]], Optional[Dict[str, Any]]]: """Helper to fetch a User or Member object, handling errors.""" try: user_id = int(user_id_str) user_or_member = cog.bot.get_user(user_id) # If in a guild context, try to get the Member object for more info (status, roles, etc.) - if not user_or_member and cog.current_channel and isinstance(cog.current_channel, discord.abc.GuildChannel): + if ( + not user_or_member + and cog.current_channel + and isinstance(cog.current_channel, discord.abc.GuildChannel) + ): guild = cog.current_channel.guild if guild: print(f"Attempting to fetch member {user_id} from guild {guild.id}") try: - user_or_member = guild.get_member(user_id) or await guild.fetch_member(user_id) + user_or_member = guild.get_member( + user_id + ) or await guild.fetch_member(user_id) except discord.NotFound: - print(f"Member {user_id} not found in guild {guild.id}. Falling back to fetch_user.") + print( + f"Member {user_id} not found in guild {guild.id}. Falling back to fetch_user." + ) # Fallback to fetching user if not found as member - try: user_or_member = await cog.bot.fetch_user(user_id) - except discord.NotFound: pass # Handled below + try: + user_or_member = await cog.bot.fetch_user(user_id) + except discord.NotFound: + pass # Handled below except discord.Forbidden: - print(f"Forbidden to fetch member {user_id} from guild {guild.id}. Falling back to fetch_user.") - try: user_or_member = await cog.bot.fetch_user(user_id) - except discord.NotFound: pass # Handled below + print( + f"Forbidden to fetch member {user_id} from guild {guild.id}. Falling back to fetch_user." + ) + try: + user_or_member = await cog.bot.fetch_user(user_id) + except discord.NotFound: + pass # Handled below # If still not found, try fetching globally if not user_or_member: - print(f"User/Member {user_id} not in cache or guild, attempting global fetch_user.") + print( + f"User/Member {user_id} not in cache or guild, attempting global fetch_user." + ) try: user_or_member = await cog.bot.fetch_user(user_id) except discord.NotFound: @@ -2592,15 +3583,18 @@ async def _get_user_or_member(cog: commands.Cog, user_id_str: str) -> Tuple[Opti print(f"HTTP error fetching user {user_id}: {e}") return None, {"error": f"API error fetching user {user_id_str}: {e}"} - if not user_or_member: # Should be caught by NotFound above, but double-check - return None, {"error": f"User with ID {user_id_str} could not be retrieved."} + if not user_or_member: # Should be caught by NotFound above, but double-check + return None, { + "error": f"User with ID {user_id_str} could not be retrieved." + } - return user_or_member, None # Return the user/member object and no error + return user_or_member, None # Return the user/member object and no error except ValueError: return None, {"error": f"Invalid user ID format: {user_id_str}."} except Exception as e: error_message = f"Unexpected error fetching user/member {user_id_str}: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return None, {"error": error_message} @@ -2608,22 +3602,29 @@ async def get_user_username(cog: commands.Cog, user_id: str) -> Dict[str, Any]: """Gets the unique Discord username (e.g., username#1234) for a given user ID.""" print(f"Executing get_user_username for user ID: {user_id}.") user_obj, error_resp = await _get_user_or_member(cog, user_id) - if error_resp: return error_resp - if not user_obj: return {"error": f"Failed to retrieve user object for ID {user_id}."} # Should not happen if error_resp is None + if error_resp: + return error_resp + if not user_obj: + return { + "error": f"Failed to retrieve user object for ID {user_id}." + } # Should not happen if error_resp is None return { "status": "success", "user_id": user_id, - "username": str(user_obj), # User.__str__() gives username#discriminator - "timestamp": datetime.datetime.now().isoformat() + "username": str(user_obj), # User.__str__() gives username#discriminator + "timestamp": datetime.datetime.now().isoformat(), } + async def get_user_display_name(cog: commands.Cog, user_id: str) -> Dict[str, Any]: """Gets the display name for a given user ID (server nickname if in a guild, otherwise global name).""" print(f"Executing get_user_display_name for user ID: {user_id}.") user_obj, error_resp = await _get_user_or_member(cog, user_id) - if error_resp: return error_resp - if not user_obj: return {"error": f"Failed to retrieve user object for ID {user_id}."} + if error_resp: + return error_resp + if not user_obj: + return {"error": f"Failed to retrieve user object for ID {user_id}."} # user_obj could be User or Member. display_name works for both. display_name = user_obj.display_name @@ -2632,15 +3633,18 @@ async def get_user_display_name(cog: commands.Cog, user_id: str) -> Dict[str, An "status": "success", "user_id": user_id, "display_name": display_name, - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } + async def get_user_avatar_url(cog: commands.Cog, user_id: str) -> Dict[str, Any]: """Gets the URL of the user's current avatar (server-specific if available, otherwise global).""" print(f"Executing get_user_avatar_url for user ID: {user_id}.") user_obj, error_resp = await _get_user_or_member(cog, user_id) - if error_resp: return error_resp - if not user_obj: return {"error": f"Failed to retrieve user object for ID {user_id}."} + if error_resp: + return error_resp + if not user_obj: + return {"error": f"Failed to retrieve user object for ID {user_id}."} # .display_avatar handles server vs global avatar automatically avatar_url = str(user_obj.display_avatar.url) @@ -2649,15 +3653,18 @@ async def get_user_avatar_url(cog: commands.Cog, user_id: str) -> Dict[str, Any] "status": "success", "user_id": user_id, "avatar_url": avatar_url, - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } + async def get_user_status(cog: commands.Cog, user_id: str) -> Dict[str, Any]: """Gets the current status (online, idle, dnd, offline) of a user. Requires guild context.""" print(f"Executing get_user_status for user ID: {user_id}.") user_obj, error_resp = await _get_user_or_member(cog, user_id) - if error_resp: return error_resp - if not user_obj: return {"error": f"Failed to retrieve user object for ID {user_id}."} + if error_resp: + return error_resp + if not user_obj: + return {"error": f"Failed to retrieve user object for ID {user_id}."} if isinstance(user_obj, discord.Member): status_str = str(user_obj.status) @@ -2666,97 +3673,144 @@ async def get_user_status(cog: commands.Cog, user_id: str) -> Dict[str, Any]: "user_id": user_id, "user_status": status_str, "guild_id": str(user_obj.guild.id), - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } else: # If we only have a User object, status isn't directly available without presence intent/cache. - return {"error": f"Cannot determine status for user {user_id} outside of a shared server or without presence intent.", "user_id": user_id} + return { + "error": f"Cannot determine status for user {user_id} outside of a shared server or without presence intent.", + "user_id": user_id, + } + async def get_user_activity(cog: commands.Cog, user_id: str) -> Dict[str, Any]: """Gets the current activity (e.g., Playing game) of a user. Requires guild context.""" print(f"Executing get_user_activity for user ID: {user_id}.") user_obj, error_resp = await _get_user_or_member(cog, user_id) - if error_resp: return error_resp - if not user_obj: return {"error": f"Failed to retrieve user object for ID {user_id}."} + if error_resp: + return error_resp + if not user_obj: + return {"error": f"Failed to retrieve user object for ID {user_id}."} activity_info = None if isinstance(user_obj, discord.Member) and user_obj.activity: activity = user_obj.activity - activity_type_str = str(activity.type).split('.')[-1] # e.g., 'playing', 'streaming', 'listening' + activity_type_str = str(activity.type).split(".")[ + -1 + ] # e.g., 'playing', 'streaming', 'listening' activity_details = {"type": activity_type_str, "name": activity.name} # Add more details based on activity type if isinstance(activity, discord.Game): - if hasattr(activity, 'start'): activity_details["start_time"] = activity.start.isoformat() if activity.start else None - if hasattr(activity, 'end'): activity_details["end_time"] = activity.end.isoformat() if activity.end else None + if hasattr(activity, "start"): + activity_details["start_time"] = ( + activity.start.isoformat() if activity.start else None + ) + if hasattr(activity, "end"): + activity_details["end_time"] = ( + activity.end.isoformat() if activity.end else None + ) elif isinstance(activity, discord.Streaming): - activity_details.update({"platform": activity.platform, "url": activity.url, "game": activity.game}) + activity_details.update( + { + "platform": activity.platform, + "url": activity.url, + "game": activity.game, + } + ) elif isinstance(activity, discord.Spotify): - activity_details.update({ - "title": activity.title, "artist": activity.artist, "album": activity.album, - "album_cover_url": activity.album_cover_url, "track_id": activity.track_id, - "duration": str(activity.duration), - "start": activity.start.isoformat() if activity.start else None, - "end": activity.end.isoformat() if activity.end else None - }) + activity_details.update( + { + "title": activity.title, + "artist": activity.artist, + "album": activity.album, + "album_cover_url": activity.album_cover_url, + "track_id": activity.track_id, + "duration": str(activity.duration), + "start": activity.start.isoformat() if activity.start else None, + "end": activity.end.isoformat() if activity.end else None, + } + ) elif isinstance(activity, discord.CustomActivity): - activity_details.update({"custom_text": activity.name, "emoji": str(activity.emoji) if activity.emoji else None}) - activity_details["name"] = activity.name # Override generic name with the custom text + activity_details.update( + { + "custom_text": activity.name, + "emoji": str(activity.emoji) if activity.emoji else None, + } + ) + activity_details["name"] = ( + activity.name + ) # Override generic name with the custom text # Add other activity types if needed (Listening, Watching) activity_info = activity_details status = "success" guild_id = str(user_obj.guild.id) elif isinstance(user_obj, discord.Member): - status = "success" # Found member but they have no activity + status = "success" # Found member but they have no activity guild_id = str(user_obj.guild.id) else: - return {"error": f"Cannot determine activity for user {user_id} outside of a shared server.", "user_id": user_id} + return { + "error": f"Cannot determine activity for user {user_id} outside of a shared server.", + "user_id": user_id, + } return { "status": status, "user_id": user_id, - "activity": activity_info, # Will be None if no activity + "activity": activity_info, # Will be None if no activity "guild_id": guild_id if isinstance(user_obj, discord.Member) else None, - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } + async def get_user_roles(cog: commands.Cog, user_id: str) -> Dict[str, Any]: """Gets the list of roles for a user in the current server. Requires guild context.""" print(f"Executing get_user_roles for user ID: {user_id}.") user_obj, error_resp = await _get_user_or_member(cog, user_id) - if error_resp: return error_resp - if not user_obj: return {"error": f"Failed to retrieve user object for ID {user_id}."} + if error_resp: + return error_resp + if not user_obj: + return {"error": f"Failed to retrieve user object for ID {user_id}."} if isinstance(user_obj, discord.Member): roles_list = [] # Sort roles by position (highest first), excluding @everyone sorted_roles = sorted(user_obj.roles, key=lambda r: r.position, reverse=True) for role in sorted_roles: - if role.is_default(): continue # Skip @everyone - roles_list.append({ - "id": str(role.id), - "name": role.name, - "color": str(role.color), - "position": role.position - }) + if role.is_default(): + continue # Skip @everyone + roles_list.append( + { + "id": str(role.id), + "name": role.name, + "color": str(role.color), + "position": role.position, + } + ) return { "status": "success", "user_id": user_id, "roles": roles_list, "role_count": len(roles_list), "guild_id": str(user_obj.guild.id), - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } else: - return {"error": f"Cannot determine roles for user {user_id} outside of a shared server.", "user_id": user_id} + return { + "error": f"Cannot determine roles for user {user_id} outside of a shared server.", + "user_id": user_id, + } + async def get_user_profile_info(cog: commands.Cog, user_id: str) -> Dict[str, Any]: """Gets comprehensive profile information for a given user ID.""" print(f"Executing get_user_profile_info for user ID: {user_id}.") user_obj, error_resp = await _get_user_or_member(cog, user_id) - if error_resp: return error_resp - if not user_obj: return {"error": f"Failed to retrieve user object for ID {user_id}."} + if error_resp: + return error_resp + if not user_obj: + return {"error": f"Failed to retrieve user object for ID {user_id}."} profile_info = { "user_id": user_id, @@ -2778,9 +3832,11 @@ async def get_user_profile_info(cog: commands.Cog, user_id: str) -> Dict[str, An if isinstance(user_obj, discord.Member): profile_info["status"] = str(user_obj.status) - profile_info["joined_at"] = user_obj.joined_at.isoformat() if user_obj.joined_at else None + profile_info["joined_at"] = ( + user_obj.joined_at.isoformat() if user_obj.joined_at else None + ) profile_info["guild_id"] = str(user_obj.guild.id) - profile_info["nickname"] = user_obj.nick # Store specific nickname + profile_info["nickname"] = user_obj.nick # Store specific nickname # Get Activity activity_result = await get_user_activity(cog, user_id) @@ -2799,21 +3855,27 @@ async def get_user_profile_info(cog: commands.Cog, user_id: str) -> Dict[str, An profile_info["voice_state"] = { "channel_id": str(voice.channel.id) if voice.channel else None, "channel_name": voice.channel.name if voice.channel else None, - "deaf": voice.deaf, "mute": voice.mute, - "self_deaf": voice.self_deaf, "self_mute": voice.self_mute, - "self_stream": voice.self_stream, "self_video": voice.self_video, - "suppress": voice.suppress, "afk": voice.afk, - "session_id": voice.session_id + "deaf": voice.deaf, + "mute": voice.mute, + "self_deaf": voice.self_deaf, + "self_mute": voice.self_mute, + "self_stream": voice.self_stream, + "self_video": voice.self_video, + "suppress": voice.suppress, + "afk": voice.afk, + "session_id": voice.session_id, } return { "status": "success", "profile": profile_info, - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } + # --- End User Profile Tools --- + async def get_user_avatar_data(cog: commands.Cog, user_id: str) -> Dict[str, Any]: """ Gets the user's avatar URL, content type, and base64 encoded image data. @@ -2821,8 +3883,10 @@ async def get_user_avatar_data(cog: commands.Cog, user_id: str) -> Dict[str, Any """ print(f"Executing get_user_avatar_data for user ID: {user_id}.") user_obj, error_resp = await _get_user_or_member(cog, user_id) - if error_resp: return error_resp - if not user_obj: return {"error": f"Failed to retrieve user object for ID {user_id}."} + if error_resp: + return error_resp + if not user_obj: + return {"error": f"Failed to retrieve user object for ID {user_id}."} avatar_asset = user_obj.display_avatar avatar_url = str(avatar_asset.url) @@ -2834,49 +3898,66 @@ async def get_user_avatar_data(cog: commands.Cog, user_id: str) -> Dict[str, Any async with cog.session.get(avatar_url) as response: if response.status == 200: image_bytes = await response.read() - content_type = response.headers.get("Content-Type", "application/octet-stream") - base64_data = base64.b64encode(image_bytes).decode('utf-8') + content_type = response.headers.get( + "Content-Type", "application/octet-stream" + ) + base64_data = base64.b64encode(image_bytes).decode("utf-8") return { "status": "success", "user_id": user_id, "avatar_url": avatar_url, "content_type": content_type, "base64_data": base64_data, - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } else: error_message = f"Failed to fetch avatar image from {avatar_url}. Status: {response.status}" print(error_message) - return {"error": error_message, "user_id": user_id, "avatar_url": avatar_url} + return { + "error": error_message, + "user_id": user_id, + "avatar_url": avatar_url, + } except aiohttp.ClientError as e: error_message = f"Network error fetching avatar from {avatar_url}: {str(e)}" print(error_message) return {"error": error_message, "user_id": user_id, "avatar_url": avatar_url} except Exception as e: error_message = f"Unexpected error fetching avatar data for {user_id}: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message, "user_id": user_id} -async def get_user_highest_role_color(cog: commands.Cog, user_id: str) -> Dict[str, Any]: + +async def get_user_highest_role_color( + cog: commands.Cog, user_id: str +) -> Dict[str, Any]: """ Gets the color of the user's highest positioned role that has a color applied. Returns the role name and hex color string. """ print(f"Executing get_user_highest_role_color for user ID: {user_id}.") user_obj, error_resp = await _get_user_or_member(cog, user_id) - if error_resp: return error_resp - if not user_obj: return {"error": f"Failed to retrieve user object for ID {user_id}."} + if error_resp: + return error_resp + if not user_obj: + return {"error": f"Failed to retrieve user object for ID {user_id}."} if not isinstance(user_obj, discord.Member): - return {"error": f"User {user_id} is not a member of a server in the current context. Cannot get role color.", "user_id": user_id} + return { + "error": f"User {user_id} is not a member of a server in the current context. Cannot get role color.", + "user_id": user_id, + } # Roles are already sorted by position, member.roles[0] is @everyone, member.roles[-1] is highest. # We need to iterate from highest to lowest that has a color. highest_colored_role = None - for role in reversed(user_obj.roles): # Iterate from highest position downwards - if role.color != discord.Color.default(): # Check if the role has a non-default color + for role in reversed(user_obj.roles): # Iterate from highest position downwards + if ( + role.color != discord.Color.default() + ): # Check if the role has a non-default color highest_colored_role = role - break # Found the highest role with a color + break # Found the highest role with a color if highest_colored_role: return { @@ -2884,10 +3965,10 @@ async def get_user_highest_role_color(cog: commands.Cog, user_id: str) -> Dict[s "user_id": user_id, "role_name": highest_colored_role.name, "role_id": str(highest_colored_role.id), - "color_hex": str(highest_colored_role.color), # Returns like #RRGGBB - "color_rgb": highest_colored_role.color.to_rgb(), # Returns (r, g, b) tuple + "color_hex": str(highest_colored_role.color), # Returns like #RRGGBB + "color_rgb": highest_colored_role.color.to_rgb(), # Returns (r, g, b) tuple "guild_id": str(user_obj.guild.id), - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } else: return { @@ -2895,9 +3976,10 @@ async def get_user_highest_role_color(cog: commands.Cog, user_id: str) -> Dict[s "user_id": user_id, "message": "User has no roles with a custom color.", "guild_id": str(user_obj.guild.id), - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } + # --- Tool Mapping --- # This dictionary maps tool names (used in the AI prompt) to their implementation functions. TOOL_MAPPING = { @@ -2912,10 +3994,16 @@ TOOL_MAPPING = { "get_message_context": get_message_context, "web_search": web_search, # Point memory tools to the methods on the MemoryManager instance (accessed via cog) - "remember_user_fact": lambda cog, **kwargs: cog.memory_manager.add_user_fact(**kwargs), + "remember_user_fact": lambda cog, **kwargs: cog.memory_manager.add_user_fact( + **kwargs + ), "get_user_facts": lambda cog, **kwargs: cog.memory_manager.get_user_facts(**kwargs), - "remember_general_fact": lambda cog, **kwargs: cog.memory_manager.add_general_fact(**kwargs), - "get_general_facts": lambda cog, **kwargs: cog.memory_manager.get_general_facts(**kwargs), + "remember_general_fact": lambda cog, **kwargs: cog.memory_manager.add_general_fact( + **kwargs + ), + "get_general_facts": lambda cog, **kwargs: cog.memory_manager.get_general_facts( + **kwargs + ), "timeout_user": timeout_user, "calculate": calculate, "run_python_code": run_python_code, @@ -2923,17 +4011,17 @@ TOOL_MAPPING = { "run_terminal_command": run_terminal_command, "remove_timeout": remove_timeout, "extract_web_content": extract_web_content, - "read_file_content": read_file_content, # Now unsafe - "write_file_content_unsafe": write_file_content_unsafe, # New unsafe tool - "execute_python_unsafe": execute_python_unsafe, # New unsafe tool - "send_discord_message": send_discord_message, # New tool - "create_new_tool": create_new_tool, # Added the meta-tool - "execute_internal_command": execute_internal_command, # Added internal command execution - "get_user_id": get_user_id, # Added user ID lookup tool - "no_operation": no_operation, # Added no-op tool - "restart_gurt_bot": restart_gurt_bot, # Tool to restart the Gurt bot - "run_git_pull": run_git_pull, # Tool to run git pull on the host - "get_channel_id": get_channel_id, # Tool to get channel id + "read_file_content": read_file_content, # Now unsafe + "write_file_content_unsafe": write_file_content_unsafe, # New unsafe tool + "execute_python_unsafe": execute_python_unsafe, # New unsafe tool + "send_discord_message": send_discord_message, # New tool + "create_new_tool": create_new_tool, # Added the meta-tool + "execute_internal_command": execute_internal_command, # Added internal command execution + "get_user_id": get_user_id, # Added user ID lookup tool + "no_operation": no_operation, # Added no-op tool + "restart_gurt_bot": restart_gurt_bot, # Tool to restart the Gurt bot + "run_git_pull": run_git_pull, # Tool to run git pull on the host + "get_channel_id": get_channel_id, # Tool to get channel id # --- Batch 1 Additions --- "get_guild_info": get_guild_info, "list_guild_members": list_guild_members, @@ -2980,7 +4068,6 @@ TOOL_MAPPING = { "get_user_roles": get_user_roles, "get_user_profile_info": get_user_profile_info, # --- End User Profile Tools --- - # --- New Profile Picture and Role Color Tools --- "get_user_avatar_data": get_user_avatar_data, "get_user_highest_role_color": get_user_highest_role_color, @@ -2988,14 +4075,18 @@ TOOL_MAPPING = { # --- Voice Channel Tools --- + async def join_voice_channel(cog: commands.Cog, channel_id: str) -> Dict[str, Any]: """Connects GURT to a specified voice channel by its ID. GURT will automatically start listening for speech in this channel once connected. Use get_channel_id to find the ID if you only have the name.""" print(f"Executing join_voice_channel tool for channel ID: {channel_id}.") voice_gateway_cog = cog.bot.get_cog("VoiceGatewayCog") if not voice_gateway_cog: return {"status": "error", "error": "VoiceGatewayCog not loaded."} - if not hasattr(voice_gateway_cog, 'connect_to_voice'): - return {"status": "error", "error": "VoiceGatewayCog is missing 'connect_to_voice' method."} + if not hasattr(voice_gateway_cog, "connect_to_voice"): + return { + "status": "error", + "error": "VoiceGatewayCog is missing 'connect_to_voice' method.", + } try: channel_id_int = int(channel_id) @@ -3003,13 +4094,21 @@ async def join_voice_channel(cog: commands.Cog, channel_id: str) -> Dict[str, An if not channel: # Try fetching if not in cache channel = await cog.bot.fetch_channel(channel_id_int) - + if not channel or not isinstance(channel, discord.VoiceChannel): - return {"status": "error", "error": f"Voice channel {channel_id} not found or is not a voice channel."} + return { + "status": "error", + "error": f"Voice channel {channel_id} not found or is not a voice channel.", + } vc, message = await voice_gateway_cog.connect_to_voice(channel) if vc: - return {"status": "success", "message": message, "channel_id": str(vc.channel.id), "channel_name": vc.channel.name} + return { + "status": "success", + "message": message, + "channel_id": str(vc.channel.id), + "channel_name": vc.channel.name, + } else: return {"status": "error", "error": message, "channel_id": channel_id} except ValueError: @@ -3022,14 +4121,18 @@ async def join_voice_channel(cog: commands.Cog, channel_id: str) -> Dict[str, An traceback.print_exc() return {"status": "error", "error": error_message} + async def leave_voice_channel(cog: commands.Cog) -> Dict[str, Any]: """Disconnects GURT from its current voice channel.""" print("Executing leave_voice_channel tool.") voice_gateway_cog = cog.bot.get_cog("VoiceGatewayCog") if not voice_gateway_cog: return {"status": "error", "error": "VoiceGatewayCog not loaded."} - if not hasattr(voice_gateway_cog, 'disconnect_from_voice'): - return {"status": "error", "error": "VoiceGatewayCog is missing 'disconnect_from_voice' method."} + if not hasattr(voice_gateway_cog, "disconnect_from_voice"): + return { + "status": "error", + "error": "VoiceGatewayCog is missing 'disconnect_from_voice' method.", + } if not cog.current_channel or not cog.current_channel.guild: # This tool implies a guild context for voice_client @@ -3037,19 +4140,23 @@ async def leave_voice_channel(cog: commands.Cog) -> Dict[str, Any]: # Let's try to find a guild GURT is in a VC in. active_vc_guild = None for vc in cog.bot.voice_clients: - if vc.is_connected(): # Found one + if vc.is_connected(): # Found one active_vc_guild = vc.guild break if not active_vc_guild: - return {"status": "error", "error": "GURT is not currently in any voice channel or guild context is unclear."} + return { + "status": "error", + "error": "GURT is not currently in any voice channel or guild context is unclear.", + } guild_to_leave = active_vc_guild else: guild_to_leave = cog.current_channel.guild - if not guild_to_leave: - return {"status": "error", "error": "Could not determine the guild to leave voice from."} - + return { + "status": "error", + "error": "Could not determine the guild to leave voice from.", + } success, message = await voice_gateway_cog.disconnect_from_voice(guild_to_leave) if success: @@ -3057,41 +4164,61 @@ async def leave_voice_channel(cog: commands.Cog) -> Dict[str, Any]: else: return {"status": "error", "error": message} -async def speak_in_voice_channel(cog: commands.Cog, text_to_speak: str, tts_provider: Optional[str] = None) -> Dict[str, Any]: + +async def speak_in_voice_channel( + cog: commands.Cog, text_to_speak: str, tts_provider: Optional[str] = None +) -> Dict[str, Any]: """Converts the given text to speech and plays it in GURT's current voice channel. If GURT is not in a voice channel, this tool will indicate an error. The bot will choose a suitable TTS provider automatically if none is specified.""" - print(f"Executing speak_in_voice_channel: Text='{text_to_speak[:50]}...', Provider={tts_provider}") + print( + f"Executing speak_in_voice_channel: Text='{text_to_speak[:50]}...', Provider={tts_provider}" + ) # Determine which voice client to use # Prefer current_channel's guild if available and bot is in VC there active_vc = None if cog.current_channel and cog.current_channel.guild: - if cog.current_channel.guild.voice_client and cog.current_channel.guild.voice_client.is_connected(): + if ( + cog.current_channel.guild.voice_client + and cog.current_channel.guild.voice_client.is_connected() + ): active_vc = cog.current_channel.guild.voice_client - + # If not found via current_channel, check all bot's voice_clients if not active_vc: if cog.bot.voice_clients: - active_vc = cog.bot.voice_clients[0] # Use the first available one + active_vc = cog.bot.voice_clients[0] # Use the first available one else: - return {"status": "error", "error": "GURT is not currently in any voice channel."} + return { + "status": "error", + "error": "GURT is not currently in any voice channel.", + } if not active_vc or not active_vc.is_connected(): return {"status": "error", "error": "GURT is not connected to a voice channel."} # Import GurtConfig for voice channel settings - from .config import VOICE_DEDICATED_TEXT_CHANNEL_ENABLED, VOICE_LOG_SPEECH_TO_DEDICATED_CHANNEL + from .config import ( + VOICE_DEDICATED_TEXT_CHANNEL_ENABLED, + VOICE_LOG_SPEECH_TO_DEDICATED_CHANNEL, + ) tts_cog = cog.bot.get_cog("TTSProviderCog") if not tts_cog: return {"status": "error", "error": "TTSProviderCog not loaded."} - if not hasattr(tts_cog, 'generate_tts_directly'): - return {"status": "error", "error": "TTSProviderCog is missing 'generate_tts_directly' method."} + if not hasattr(tts_cog, "generate_tts_directly"): + return { + "status": "error", + "error": "TTSProviderCog is missing 'generate_tts_directly' method.", + } voice_gateway_cog = cog.bot.get_cog("VoiceGatewayCog") if not voice_gateway_cog: return {"status": "error", "error": "VoiceGatewayCog not loaded."} - if not hasattr(voice_gateway_cog, 'play_audio_file'): - return {"status": "error", "error": "VoiceGatewayCog is missing 'play_audio_file' method."} + if not hasattr(voice_gateway_cog, "play_audio_file"): + return { + "status": "error", + "error": "VoiceGatewayCog is missing 'play_audio_file' method.", + } # Determine TTS provider chosen_provider = tts_provider @@ -3103,46 +4230,78 @@ async def speak_in_voice_channel(cog: commands.Cog, text_to_speak: str, tts_prov chosen_provider = "google_cloud_tts" elif importlib.util.find_spec("gtts"): chosen_provider = "gtts" - else: # Fallback to first available or error + else: # Fallback to first available or error # This logic could be more sophisticated in TTSProviderCog itself - return {"status": "error", "error": "No suitable default TTS provider found or configured."} + return { + "status": "error", + "error": "No suitable default TTS provider found or configured.", + } print(f"No TTS provider specified, defaulting to: {chosen_provider}") - - success, audio_path_or_error = await tts_cog.generate_tts_directly(provider=chosen_provider, text=text_to_speak) + success, audio_path_or_error = await tts_cog.generate_tts_directly( + provider=chosen_provider, text=text_to_speak + ) if not success: - return {"status": "error", "error": f"TTS generation failed: {audio_path_or_error}"} + return { + "status": "error", + "error": f"TTS generation failed: {audio_path_or_error}", + } audio_file_path = audio_path_or_error - play_success, play_message = await voice_gateway_cog.play_audio_file(active_vc, audio_file_path) + play_success, play_message = await voice_gateway_cog.play_audio_file( + active_vc, audio_file_path + ) if play_success: # Log to dedicated text channel if enabled - if VOICE_DEDICATED_TEXT_CHANNEL_ENABLED and VOICE_LOG_SPEECH_TO_DEDICATED_CHANNEL: - if voice_gateway_cog: # Should exist if we got this far - dedicated_channel = voice_gateway_cog.get_dedicated_text_channel_for_guild(active_vc.guild.id) + if ( + VOICE_DEDICATED_TEXT_CHANNEL_ENABLED + and VOICE_LOG_SPEECH_TO_DEDICATED_CHANNEL + ): + if voice_gateway_cog: # Should exist if we got this far + dedicated_channel = ( + voice_gateway_cog.get_dedicated_text_channel_for_guild( + active_vc.guild.id + ) + ) if dedicated_channel: try: await dedicated_channel.send(f"GURT (Voice): {text_to_speak}") - print(f"Logged GURT's speech to dedicated channel {dedicated_channel.name}") + print( + f"Logged GURT's speech to dedicated channel {dedicated_channel.name}" + ) except Exception as e_log: - print(f"Error logging GURT's speech to dedicated channel {dedicated_channel.name}: {e_log}") + print( + f"Error logging GURT's speech to dedicated channel {dedicated_channel.name}: {e_log}" + ) else: - print(f"Could not find dedicated text channel for guild {active_vc.guild.id} to log speech.") - else: # Should not happen - print("VoiceGatewayCog not found for logging speech to dedicated channel.") + print( + f"Could not find dedicated text channel for guild {active_vc.guild.id} to log speech." + ) + else: # Should not happen + print( + "VoiceGatewayCog not found for logging speech to dedicated channel." + ) - return {"status": "success", "message": play_message, "text_spoken": text_to_speak, "provider_used": chosen_provider} + return { + "status": "success", + "message": play_message, + "text_spoken": text_to_speak, + "provider_used": chosen_provider, + } else: # TTSProviderCog's cleanup should handle the audio_file_path if play fails return {"status": "error", "error": f"Failed to play audio: {play_message}"} + # --- End Voice Channel Tools --- # --- List Files Tool --- -async def list_files_tool(cog: commands.Cog, path: str, recursive: bool = False) -> Dict[str, Any]: +async def list_files_tool( + cog: commands.Cog, path: str, recursive: bool = False +) -> Dict[str, Any]: """Lists files and directories within a specified path.""" print(f"Executing list_files_tool: Path='{path}', Recursive={recursive}") try: @@ -3154,7 +4313,10 @@ async def list_files_tool(cog: commands.Cog, path: str, recursive: bool = False) if not os.path.exists(target_path): return {"error": f"Path not found: {target_path}", "path_requested": path} if not os.path.isdir(target_path): - return {"error": f"Path is not a directory: {target_path}", "path_requested": path} + return { + "error": f"Path is not a directory: {target_path}", + "path_requested": path, + } items = [] if recursive: @@ -3163,19 +4325,37 @@ async def list_files_tool(cog: commands.Cog, path: str, recursive: bool = False) for d_name in dirs: full_d_path = os.path.join(root, d_name) relative_d_path = os.path.relpath(full_d_path, base_path) - items.append({"name": d_name, "path": relative_d_path.replace('\\\\', '/'), "type": "directory"}) + items.append( + { + "name": d_name, + "path": relative_d_path.replace("\\\\", "/"), + "type": "directory", + } + ) # Add files for f_name in files: full_f_path = os.path.join(root, f_name) relative_f_path = os.path.relpath(full_f_path, base_path) - items.append({"name": f_name, "path": relative_f_path.replace('\\\\', '/'), "type": "file"}) + items.append( + { + "name": f_name, + "path": relative_f_path.replace("\\\\", "/"), + "type": "file", + } + ) else: for item_name in os.listdir(target_path): full_item_path = os.path.join(target_path, item_name) relative_item_path = os.path.relpath(full_item_path, base_path) item_type = "directory" if os.path.isdir(full_item_path) else "file" - items.append({"name": item_name, "path": relative_item_path.replace('\\\\', '/'), "type": item_type}) - + items.append( + { + "name": item_name, + "path": relative_item_path.replace("\\\\", "/"), + "type": item_type, + } + ) + # Sort items for consistent output items.sort(key=lambda x: (x["type"], x["path"])) @@ -3186,17 +4366,24 @@ async def list_files_tool(cog: commands.Cog, path: str, recursive: bool = False) "recursive": recursive, "items": items, "count": len(items), - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } except PermissionError: - return {"error": f"Permission denied for path: {target_path}", "path_requested": path} + return { + "error": f"Permission denied for path: {target_path}", + "path_requested": path, + } except Exception as e: error_message = f"Unexpected error listing files for path '{path}': {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message, "path_requested": path} + # --- Tenor GIF Search Tool Implementation --- -async def tool_search_tenor_gifs(cog: commands.Cog, query: str, limit: int = 8) -> Dict[str, Any]: +async def tool_search_tenor_gifs( + cog: commands.Cog, query: str, limit: int = 8 +) -> Dict[str, Any]: """Searches Tenor for GIFs and returns a list of URLs.""" print(f"Executing tool_search_tenor_gifs: Query='{query}', Limit={limit}") # Ensure limit is within a reasonable range, e.g., 1-15 @@ -3210,7 +4397,7 @@ async def tool_search_tenor_gifs(cog: commands.Cog, query: str, limit: int = 8) "limit": limit, "gif_urls": gif_urls, "count": len(gif_urls), - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } else: return { @@ -3220,7 +4407,7 @@ async def tool_search_tenor_gifs(cog: commands.Cog, query: str, limit: int = 8) "gif_urls": [], "count": 0, "error": f"No Tenor GIFs found for query: {query}", - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: error_message = f"Error during Tenor GIF search tool execution: {str(e)}" @@ -3231,9 +4418,10 @@ async def tool_search_tenor_gifs(cog: commands.Cog, query: str, limit: int = 8) "query": query, "limit": limit, "error": error_message, - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } + # --- Tenor GIF Search Tool Implementation --- async def search_tenor_gifs(cog: commands.Cog, query: str, limit: int = 8) -> List[str]: """Searches Tenor for GIFs and returns a list of URLs.""" @@ -3246,48 +4434,59 @@ async def search_tenor_gifs(cog: commands.Cog, query: str, limit: int = 8) -> Li "q": query, "key": cog.TENOR_API_KEY, "limit": limit, - "media_filter": "gif", # Ensure we get GIFs - "contentfilter": "medium" # Adjust content filter as needed (off, low, medium, high) + "media_filter": "gif", # Ensure we get GIFs + "contentfilter": "medium", # Adjust content filter as needed (off, low, medium, high) } try: if not cog.session: - cog.session = aiohttp.ClientSession() # Ensure session exists + cog.session = aiohttp.ClientSession() # Ensure session exists async with cog.session.get(url, params=params, timeout=10) as response: if response.status == 200: data = await response.json() - gif_urls = [result["media_formats"]["gif"]["url"] for result in data.get("results", []) if "media_formats" in result and "gif" in result["media_formats"] and "url" in result["media_formats"]["gif"]] + gif_urls = [ + result["media_formats"]["gif"]["url"] + for result in data.get("results", []) + if "media_formats" in result + and "gif" in result["media_formats"] + and "url" in result["media_formats"]["gif"] + ] return gif_urls else: - print(f"Error searching Tenor: {response.status} - {await response.text()}") + print( + f"Error searching Tenor: {response.status} - {await response.text()}" + ) return [] except Exception as e: print(f"Exception during Tenor API call: {e}") return [] + # --- Send Tenor GIF Tool with AI Selection --- -async def send_tenor_gif(cog: commands.Cog, query: str, limit: int = 8) -> Dict[str, Any]: +async def send_tenor_gif( + cog: commands.Cog, query: str, limit: int = 8 +) -> Dict[str, Any]: """Searches for multiple GIFs, has AI pick the best one, and sends it.""" print(f"Executing send_tenor_gif: Query='{query}', Limit={limit}") - + try: # Import here to avoid circular imports from google import genai from google.genai import types from .config import DEFAULT_MODEL, PROJECT_ID, LOCATION - + # Search for GIFs gif_urls = await search_tenor_gifs(cog, query, limit=limit) - + if not gif_urls: return { "status": "no_results", "query": query, "error": f"No Tenor GIFs found for query: {query}", - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } - + # If only one GIF found, use it directly if len(gif_urls) == 1: selected_gif = gif_urls[0] @@ -3300,30 +4499,36 @@ async def send_tenor_gif(cog: commands.Cog, query: str, limit: int = 8) -> Dict[ # Download GIF data for AI analysis if not cog.session: cog.session = aiohttp.ClientSession() - + async with cog.session.get(gif_url, timeout=15) as response: if response.status == 200: gif_data = await response.read() # Create a Part with file data for the AI - gif_parts.append({ - "index": i, - "url": gif_url, - "part": types.Part(inline_data=types.Blob(data=gif_data, mime_type="image/gif")) - }) + gif_parts.append( + { + "index": i, + "url": gif_url, + "part": types.Part( + inline_data=types.Blob( + data=gif_data, mime_type="image/gif" + ) + ), + } + ) else: print(f"Failed to download GIF {i}: {response.status}") except Exception as e: print(f"Error downloading GIF {i}: {e}") continue - + if not gif_parts: return { "status": "error", "query": query, "error": "Failed to download any GIFs for analysis", - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } - + # Prepare AI prompt for GIF selection selection_prompt = f"""You are selecting the best GIF from multiple options for the query: "{query}" @@ -3342,20 +4547,18 @@ async def send_tenor_gif(cog: commands.Cog, query: str, limit: int = 8) -> Dict[ }} The selected_index should be a number from 0 to {len(gif_parts)-1}.""" - + # Build content with text and GIF parts ai_content = [types.Part(text=selection_prompt)] for gif_data in gif_parts: ai_content.append(gif_data["part"]) - + # Initialize AI client if needed - if not hasattr(cog, 'genai_client') or not cog.genai_client: + if not hasattr(cog, "genai_client") or not cog.genai_client: cog.genai_client = genai.Client( - vertexai=True, - project=PROJECT_ID, - location=LOCATION + vertexai=True, project=PROJECT_ID, location=LOCATION ) - + # Generate AI selection try: model = genai.GenerativeModel(DEFAULT_MODEL) @@ -3363,24 +4566,26 @@ async def send_tenor_gif(cog: commands.Cog, query: str, limit: int = 8) -> Dict[ contents=[types.Content(role="user", parts=ai_content)], config=types.GenerateContentConfig( temperature=0.1, # Low temperature for consistent selection - max_output_tokens=150 - ) + max_output_tokens=150, + ), ) - + # Parse AI response ai_text = response.text.strip() - + # Clean up the response to extract JSON if "```json" in ai_text: ai_text = ai_text.split("```json")[1].split("```")[0].strip() elif "```" in ai_text: ai_text = ai_text.split("```")[1].strip() - + try: selection_data = json.loads(ai_text) selected_index = selection_data.get("selected_index", 0) - selection_reason = selection_data.get("reason", "AI selected this GIF") - + selection_reason = selection_data.get( + "reason", "AI selected this GIF" + ) + # Validate index if 0 <= selected_index < len(gif_parts): selected_gif = gif_parts[selected_index]["url"] @@ -3388,18 +4593,18 @@ async def send_tenor_gif(cog: commands.Cog, query: str, limit: int = 8) -> Dict[ # Fallback to first GIF if index is invalid selected_gif = gif_urls[0] selection_reason = "Fallback: AI provided invalid index" - + except json.JSONDecodeError: # Fallback to first GIF if JSON parsing fails selected_gif = gif_urls[0] selection_reason = "Fallback: Could not parse AI selection" - + except Exception as e: print(f"Error in AI GIF selection: {e}") # Fallback to first GIF selected_gif = gif_urls[0] selection_reason = f"Fallback: AI selection failed ({str(e)})" - + # Send the selected GIF to the current channel channel = cog.current_channel if channel: @@ -3411,23 +4616,23 @@ async def send_tenor_gif(cog: commands.Cog, query: str, limit: int = 8) -> Dict[ "selected_gif": selected_gif, "total_found": len(gif_urls), "selection_reason": selection_reason, - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: return { "status": "error", "query": query, "error": f"Failed to send GIF: {str(e)}", - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } else: return { "status": "error", "query": query, "error": "No current channel available to send GIF", - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } - + except Exception as e: error_message = f"Error in send_tenor_gif: {str(e)}" print(error_message) @@ -3436,9 +4641,10 @@ async def send_tenor_gif(cog: commands.Cog, query: str, limit: int = 8) -> Dict[ "status": "error", "query": query, "error": error_message, - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } + # Update TOOL_MAPPING to include the new Tenor GIF tool and list_files_tool TOOL_MAPPING["search_tenor_gifs"] = tool_search_tenor_gifs TOOL_MAPPING["send_tenor_gif"] = send_tenor_gif diff --git a/gurt/utils.py b/gurt/utils.py index c0dd579..a4d7b82 100644 --- a/gurt/utils.py +++ b/gurt/utils.py @@ -9,13 +9,16 @@ import os from typing import TYPE_CHECKING, Optional, Tuple, Dict, Any if TYPE_CHECKING: - from .cog import GurtCog # For type hinting + from .cog import GurtCog # For type hinting # --- Utility Functions --- # Note: Functions needing cog state (like personality traits for mistakes) # will need the 'cog' instance passed in. -def replace_mentions_with_names(cog: 'GurtCog', content: str, message: discord.Message) -> str: + +def replace_mentions_with_names( + cog: "GurtCog", content: str, message: discord.Message +) -> str: """Replaces user mentions (<@id> or <@!id>) with their display names.""" if not message.mentions: return content @@ -23,14 +26,21 @@ def replace_mentions_with_names(cog: 'GurtCog', content: str, message: discord.M processed_content = content # Sort by length of ID to handle potential overlaps correctly (longer IDs first) # Although Discord IDs are fixed length, this is safer if formats change - sorted_mentions = sorted(message.mentions, key=lambda m: len(str(m.id)), reverse=True) + sorted_mentions = sorted( + message.mentions, key=lambda m: len(str(m.id)), reverse=True + ) for member in sorted_mentions: # Use display_name for better readability - processed_content = processed_content.replace(f'<@{member.id}>', member.display_name) - processed_content = processed_content.replace(f'<@!{member.id}>', member.display_name) # Handle nickname mention format + processed_content = processed_content.replace( + f"<@{member.id}>", member.display_name + ) + processed_content = processed_content.replace( + f"<@!{member.id}>", member.display_name + ) # Handle nickname mention format return processed_content + def _format_attachment_size(size_bytes: int) -> str: """Formats attachment size into KB or MB.""" if size_bytes < 1024: @@ -40,27 +50,36 @@ def _format_attachment_size(size_bytes: int) -> str: else: return f"{size_bytes / (1024 * 1024):.1f} MB" -def format_message(cog: 'GurtCog', message: discord.Message) -> Dict[str, Any]: + +def format_message(cog: "GurtCog", message: discord.Message) -> Dict[str, Any]: """ Helper function to format a discord.Message object into a dictionary, including detailed reply info and attachment descriptions. """ # Process content first to replace mentions - processed_content = replace_mentions_with_names(cog, message.content, message) # Pass cog + processed_content = replace_mentions_with_names( + cog, message.content, message + ) # Pass cog # --- Attachment Processing --- attachment_descriptions = [] for a in message.attachments: size_str = _format_attachment_size(a.size) - file_type = "Image" if a.content_type and a.content_type.startswith("image/") else "File" + file_type = ( + "Image" + if a.content_type and a.content_type.startswith("image/") + else "File" + ) description = f"[{file_type}: {a.filename} ({a.content_type or 'unknown type'}, {size_str})]" - attachment_descriptions.append({ - "description": description, - "filename": a.filename, - "content_type": a.content_type, - "size": a.size, - "url": a.url # Keep URL for potential future use (e.g., vision model) - }) + attachment_descriptions.append( + { + "description": description, + "filename": a.filename, + "content_type": a.content_type, + "size": a.size, + "url": a.url, # Keep URL for potential future use (e.g., vision model) + } + ) # --- End Attachment Processing --- # Basic message structure @@ -70,64 +89,75 @@ def format_message(cog: 'GurtCog', message: discord.Message) -> Dict[str, Any]: "id": str(message.author.id), "name": message.author.name, "display_name": message.author.display_name, - "bot": message.author.bot + "bot": message.author.bot, }, - "content": processed_content, # Use processed content - "author_string": f"{message.author.display_name}{' (BOT)' if message.author.bot else ''}", # Add formatted author string + "content": processed_content, # Use processed content + "author_string": f"{message.author.display_name}{' (BOT)' if message.author.bot else ''}", # Add formatted author string "created_at": message.created_at.isoformat(), - "attachment_descriptions": attachment_descriptions, # Use new descriptions list + "attachment_descriptions": attachment_descriptions, # Use new descriptions list # "attachments": [{"filename": a.filename, "url": a.url} for a in message.attachments], # REMOVED old field # "embeds": len(message.embeds) > 0, # Replaced by embed_content below - "embed_content": [], # Initialize embed content list - "mentions": [{"id": str(m.id), "name": m.name, "display_name": m.display_name} for m in message.mentions], # Keep detailed mentions + "embed_content": [], # Initialize embed content list + "mentions": [ + {"id": str(m.id), "name": m.name, "display_name": m.display_name} + for m in message.mentions + ], # Keep detailed mentions # Reply fields initialized "replied_to_message_id": None, "replied_to_author_id": None, "replied_to_author_name": None, - "replied_to_content_snippet": None, # Changed field name for clarity + "replied_to_content_snippet": None, # Changed field name for clarity "is_reply": False, - "custom_emojis": [], # Initialize custom_emojis list - "stickers": [] # Initialize stickers list + "custom_emojis": [], # Initialize custom_emojis list + "stickers": [], # Initialize stickers list } # --- Custom Emoji Processing --- # Regex to find custom emojis: <:name:id> or - emoji_pattern = re.compile(r'<(a)?:([a-zA-Z0-9_]+):([0-9]+)>') + emoji_pattern = re.compile(r"<(a)?:([a-zA-Z0-9_]+):([0-9]+)>") for match in emoji_pattern.finditer(message.content): animated_flag, emoji_name, emoji_id_str = match.groups() emoji_id = int(emoji_id_str) animated = bool(animated_flag) - + emoji_obj = cog.bot.get_emoji(emoji_id) if emoji_obj: - formatted_msg["custom_emojis"].append({ - "name": emoji_obj.name, - "url": str(emoji_obj.url), - "id": str(emoji_obj.id), - "animated": emoji_obj.animated - }) + formatted_msg["custom_emojis"].append( + { + "name": emoji_obj.name, + "url": str(emoji_obj.url), + "id": str(emoji_obj.id), + "animated": emoji_obj.animated, + } + ) else: # Fallback if emoji is not directly accessible by the bot # Construct a potential URL (Discord's CDN format) extension = "gif" if animated else "png" fallback_url = f"https://cdn.discordapp.com/emojis/{emoji_id}.{extension}" - formatted_msg["custom_emojis"].append({ - "name": emoji_name, # Name from regex - "url": fallback_url, - "id": emoji_id_str, - "animated": animated - }) + formatted_msg["custom_emojis"].append( + { + "name": emoji_name, # Name from regex + "url": fallback_url, + "id": emoji_id_str, + "animated": animated, + } + ) # --- Sticker Processing --- if message.stickers: for sticker_item in message.stickers: # discord.StickerItem has name, id, format, and url - formatted_msg["stickers"].append({ - "name": sticker_item.name, - "url": str(sticker_item.url), # sticker_item.url is already the asset URL - "id": str(sticker_item.id), - "format": str(sticker_item.format) # e.g., "StickerFormatType.png" - }) + formatted_msg["stickers"].append( + { + "name": sticker_item.name, + "url": str( + sticker_item.url + ), # sticker_item.url is already the asset URL + "id": str(sticker_item.id), + "format": str(sticker_item.format), # e.g., "StickerFormatType.png" + } + ) # --- Reply Processing --- if message.reference and message.reference.message_id: @@ -135,12 +165,12 @@ def format_message(cog: 'GurtCog', message: discord.Message) -> Dict[str, Any]: formatted_msg["is_reply"] = True # Try to get resolved details (might be None if message not cached/fetched) ref_msg = message.reference.resolved - if isinstance(ref_msg, discord.Message): # Check if resolved is a Message + if isinstance(ref_msg, discord.Message): # Check if resolved is a Message formatted_msg["replied_to_author_id"] = str(ref_msg.author.id) formatted_msg["replied_to_author_name"] = ref_msg.author.display_name # Create a snippet of the replied-to content snippet = ref_msg.content - if len(snippet) > 80: # Truncate long replies + if len(snippet) > 80: # Truncate long replies snippet = snippet[:77] + "..." formatted_msg["replied_to_content_snippet"] = snippet # else: print(f"Referenced message {message.reference.message_id} not resolved.") # Optional debug @@ -161,47 +191,75 @@ def format_message(cog: 'GurtCog', message: discord.Message) -> Dict[str, Any]: "image_url": embed.image.url if embed.image else None, } if embed.footer and embed.footer.text: - embed_data["footer"] = {"text": embed.footer.text, "icon_url": embed.footer.icon_url} + embed_data["footer"] = { + "text": embed.footer.text, + "icon_url": embed.footer.icon_url, + } if embed.author and embed.author.name: - embed_data["author"] = {"name": embed.author.name, "url": embed.author.url, "icon_url": embed.author.icon_url} + embed_data["author"] = { + "name": embed.author.name, + "url": embed.author.url, + "icon_url": embed.author.icon_url, + } for field in embed.fields: - embed_data["fields"].append({"name": field.name, "value": field.value, "inline": field.inline}) + embed_data["fields"].append( + {"name": field.name, "value": field.value, "inline": field.inline} + ) formatted_msg["embed_content"].append(embed_data) # --- End Embed Processing --- return formatted_msg -def update_relationship(cog: 'GurtCog', user_id_1: str, user_id_2: str, change: float): + +def update_relationship(cog: "GurtCog", user_id_1: str, user_id_2: str, change: float): """Updates the relationship score between two users.""" # Ensure consistent key order - if user_id_1 > user_id_2: user_id_1, user_id_2 = user_id_2, user_id_1 + if user_id_1 > user_id_2: + user_id_1, user_id_2 = user_id_2, user_id_1 # Initialize user_id_1's dict if not present - if user_id_1 not in cog.user_relationships: cog.user_relationships[user_id_1] = {} + if user_id_1 not in cog.user_relationships: + cog.user_relationships[user_id_1] = {} current_score = cog.user_relationships[user_id_1].get(user_id_2, 0.0) - new_score = max(0.0, min(current_score + change, 100.0)) # Clamp 0-100 + new_score = max(0.0, min(current_score + change, 100.0)) # Clamp 0-100 cog.user_relationships[user_id_1][user_id_2] = new_score # print(f"Updated relationship {user_id_1}-{user_id_2}: {current_score:.1f} -> {new_score:.1f} ({change:+.1f})") # Debug log -async def simulate_human_typing(cog: 'GurtCog', channel, text: str): + +async def simulate_human_typing(cog: "GurtCog", channel, text: str): """Shows typing indicator without significant delay.""" # Minimal delay to ensure the typing indicator shows up reliably # but doesn't add noticeable latency to the response. # The actual sending of the message happens immediately after this. # Check if the bot has permissions to send messages and type - perms = channel.permissions_for(channel.guild.me) if isinstance(channel, discord.TextChannel) else None - if perms is None or (perms.send_messages and perms.send_tts_messages): # send_tts_messages often implies typing allowed + perms = ( + channel.permissions_for(channel.guild.me) + if isinstance(channel, discord.TextChannel) + else None + ) + if perms is None or ( + perms.send_messages and perms.send_tts_messages + ): # send_tts_messages often implies typing allowed try: async with channel.typing(): - await asyncio.sleep(0.1) # Very short sleep, just to ensure typing shows + await asyncio.sleep( + 0.1 + ) # Very short sleep, just to ensure typing shows except discord.Forbidden: print(f"Warning: Missing permissions to type in channel {channel.id}") except Exception as e: print(f"Warning: Error during typing simulation in {channel.id}: {e}") # else: print(f"Skipping typing simulation in {channel.id} due to missing permissions.") # Optional debug -async def log_internal_api_call(cog: 'GurtCog', task_description: str, payload: Dict[str, Any], response_data: Optional[Dict[str, Any]], error: Optional[Exception] = None): + +async def log_internal_api_call( + cog: "GurtCog", + task_description: str, + payload: Dict[str, Any], + response_data: Optional[Dict[str, Any]], + error: Optional[Exception] = None, +): """Helper function to log internal API calls to a file.""" log_dir = "data" log_file = os.path.join(log_dir, "internal_api_calls.log") @@ -214,30 +272,42 @@ async def log_internal_api_call(cog: 'GurtCog', task_description: str, payload: # Sanitize payload for logging (avoid large base64 images) payload_to_log = payload.copy() - if 'messages' in payload_to_log: + if "messages" in payload_to_log: sanitized_messages = [] - for msg in payload_to_log['messages']: - if isinstance(msg.get('content'), list): # Multimodal message + for msg in payload_to_log["messages"]: + if isinstance(msg.get("content"), list): # Multimodal message new_content = [] - for part in msg['content']: - if part.get('type') == 'image_url' and part.get('image_url', {}).get('url', '').startswith('data:image'): - new_content.append({'type': 'image_url', 'image_url': {'url': 'data:image/...[truncated]'}}) + for part in msg["content"]: + if part.get("type") == "image_url" and part.get( + "image_url", {} + ).get("url", "").startswith("data:image"): + new_content.append( + { + "type": "image_url", + "image_url": {"url": "data:image/...[truncated]"}, + } + ) else: new_content.append(part) - sanitized_messages.append({**msg, 'content': new_content}) + sanitized_messages.append({**msg, "content": new_content}) else: sanitized_messages.append(msg) - payload_to_log['messages'] = sanitized_messages + payload_to_log["messages"] = sanitized_messages log_entry += f"Request Payload:\n{json.dumps(payload_to_log, indent=2)}\n" - if response_data: log_entry += f"Response Data:\n{json.dumps(response_data, indent=2)}\n" - if error: log_entry += f"Error: {str(error)}\n" + if response_data: + log_entry += f"Response Data:\n{json.dumps(response_data, indent=2)}\n" + if error: + log_entry += f"Error: {str(error)}\n" log_entry += "---\n\n" # Use async file writing if in async context, but this helper might be called from sync code? # Sticking to sync file I/O for simplicity here, assuming logging isn't performance critical path. - with open(log_file, "a", encoding="utf-8") as f: f.write(log_entry) - except Exception as log_e: print(f"!!! Failed to write to internal API log file {log_file}: {log_e}") + with open(log_file, "a", encoding="utf-8") as f: + f.write(log_entry) + except Exception as log_e: + print(f"!!! Failed to write to internal API log file {log_file}: {log_e}") + # Note: _create_human_like_mistake was removed as it wasn't used in the final on_message logic provided. # If needed, it can be added back here, ensuring it takes 'cog' if it needs personality traits. diff --git a/gurt_bot.py b/gurt_bot.py index 05caa85..b6a853a 100644 --- a/gurt_bot.py +++ b/gurt_bot.py @@ -14,15 +14,18 @@ intents.message_content = True intents.members = True # Create bot instance with command prefix '%' -bot = commands.Bot(command_prefix='%', intents=intents) -bot.owner_id = int(os.getenv('OWNER_USER_ID')) +bot = commands.Bot(command_prefix="%", intents=intents) +bot.owner_id = int(os.getenv("OWNER_USER_ID")) + @bot.event async def on_ready(): - print(f'{bot.user.name} has connected to Discord!') - print(f'Bot ID: {bot.user.id}') + print(f"{bot.user.name} has connected to Discord!") + print(f"Bot ID: {bot.user.id}") # Set the bot's status - await bot.change_presence(activity=discord.Activity(type=discord.ActivityType.listening, name="%ai")) + await bot.change_presence( + activity=discord.Activity(type=discord.ActivityType.listening, name="%ai") + ) print("Bot status set to 'Listening to %ai'") # Sync commands @@ -33,8 +36,10 @@ async def on_ready(): except Exception as e: print(f"Failed to sync commands: {e}") import traceback + traceback.print_exc() + async def main(minimal_prompt: bool = False): """Main async function to load the gurt cog and start the bot.""" # Store the flag on the bot instance so the cog can access it @@ -43,14 +48,16 @@ async def main(minimal_prompt: bool = False): print("Minimal prompt mode enabled.") # Check for required environment variables - TOKEN = os.getenv('DISCORD_TOKEN_GURT') + TOKEN = os.getenv("DISCORD_TOKEN_GURT") # If Discord token not found, try to use the main bot token if not TOKEN: - TOKEN = os.getenv('DISCORD_TOKEN') + TOKEN = os.getenv("DISCORD_TOKEN") if not TOKEN: - raise ValueError("No Discord token found. Make sure to set DISCORD_TOKEN_GURT or DISCORD_TOKEN in your .env file.") + raise ValueError( + "No Discord token found. Make sure to set DISCORD_TOKEN_GURT or DISCORD_TOKEN in your .env file." + ) # Note: Vertex AI authentication is handled by the library using ADC or GOOGLE_APPLICATION_CREDENTIALS. # No explicit API key check is needed here. Ensure GCP_PROJECT_ID and GCP_LOCATION are set in .env @@ -59,7 +66,11 @@ async def main(minimal_prompt: bool = False): async with bot: # List of cogs to load # Updated path for the refactored GurtCog - cogs = ["gurt.cog", "cogs.VoiceGatewayCog", "cogs.tts_provider_cog"]#, "cogs.profile_updater_cog"] + cogs = [ + "gurt.cog", + "cogs.VoiceGatewayCog", + "cogs.tts_provider_cog", + ] # , "cogs.profile_updater_cog"] for cog in cogs: try: await bot.load_extension(cog) @@ -67,6 +78,7 @@ async def main(minimal_prompt: bool = False): except Exception as e: print(f"Error loading {cog}: {e}") import traceback + traceback.print_exc() # Start the bot @@ -74,8 +86,9 @@ async def main(minimal_prompt: bool = False): except Exception as e: print(f"Error starting Gurt Bot: {e}") + # Run the main async function -if __name__ == '__main__': +if __name__ == "__main__": try: asyncio.run(main()) except KeyboardInterrupt: diff --git a/gurt_memory.py b/gurt_memory.py index 4ffe885..2a8029a 100644 --- a/gurt_memory.py +++ b/gurt_memory.py @@ -4,67 +4,99 @@ import os import time import datetime import re -import hashlib # Added for chroma_id generation -import json # Added for personality trait serialization/deserialization -from typing import Dict, List, Any, Optional, Tuple, Union # Added Union +import hashlib # Added for chroma_id generation +import json # Added for personality trait serialization/deserialization +from typing import Dict, List, Any, Optional, Tuple, Union # Added Union import chromadb from chromadb.utils import embedding_functions from sentence_transformers import SentenceTransformer import logging # Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) logger = logging.getLogger(__name__) # Constants INTEREST_INITIAL_LEVEL = 0.1 INTEREST_MAX_LEVEL = 1.0 INTEREST_MIN_LEVEL = 0.0 -INTEREST_DECAY_RATE = 0.02 # Default decay rate per cycle -INTEREST_DECAY_INTERVAL_HOURS = 24 # Default interval for decay check +INTEREST_DECAY_RATE = 0.02 # Default decay rate per cycle +INTEREST_DECAY_INTERVAL_HOURS = 24 # Default interval for decay check + # --- Helper Function for Keyword Scoring --- def calculate_keyword_score(text: str, context: str) -> int: """Calculates a simple keyword overlap score.""" if not context or not text: return 0 - context_words = set(re.findall(r'\b\w+\b', context.lower())) - text_words = set(re.findall(r'\b\w+\b', text.lower())) + context_words = set(re.findall(r"\b\w+\b", context.lower())) + text_words = set(re.findall(r"\b\w+\b", text.lower())) # Ignore very common words (basic stopword list) - stopwords = {"the", "a", "is", "in", "it", "of", "and", "to", "for", "on", "with", "that", "this", "i", "you", "me", "my", "your"} + stopwords = { + "the", + "a", + "is", + "in", + "it", + "of", + "and", + "to", + "for", + "on", + "with", + "that", + "this", + "i", + "you", + "me", + "my", + "your", + } context_words -= stopwords text_words -= stopwords - if not context_words: # Avoid division by zero if context is only stopwords + if not context_words: # Avoid division by zero if context is only stopwords return 0 overlap = len(context_words.intersection(text_words)) # Normalize score slightly by context length (more overlap needed for longer context) # score = overlap / (len(context_words) ** 0.5) # Example normalization - score = overlap # Simpler score for now + score = overlap # Simpler score for now return score + class MemoryManager: """Handles database interactions for Gurt's memory (facts and semantic).""" - def __init__(self, db_path: str, max_user_facts: int = 20, max_general_facts: int = 100, semantic_model_name: str = 'all-MiniLM-L6-v2', chroma_path: str = "data/chroma_db"): + def __init__( + self, + db_path: str, + max_user_facts: int = 20, + max_general_facts: int = 100, + semantic_model_name: str = "all-MiniLM-L6-v2", + chroma_path: str = "data/chroma_db", + ): self.db_path = db_path self.max_user_facts = max_user_facts self.max_general_facts = max_general_facts - self.db_lock = asyncio.Lock() # Lock for SQLite operations + self.db_lock = asyncio.Lock() # Lock for SQLite operations # Ensure data directories exist os.makedirs(os.path.dirname(self.db_path), exist_ok=True) os.makedirs(chroma_path, exist_ok=True) - logger.info(f"MemoryManager initialized with db_path: {self.db_path}, chroma_path: {chroma_path}") + logger.info( + f"MemoryManager initialized with db_path: {self.db_path}, chroma_path: {chroma_path}" + ) # --- Semantic Memory Setup --- self.chroma_path = chroma_path self.semantic_model_name = semantic_model_name self.chroma_client = None self.embedding_function = None - self.semantic_collection = None # For messages - self.fact_collection = None # For facts + self.semantic_collection = None # For messages + self.fact_collection = None # For facts self.transformer_model = None - self._initialize_semantic_memory_sync() # Initialize semantic components synchronously for simplicity during init + self._initialize_semantic_memory_sync() # Initialize semantic components synchronously for simplicity during init def _initialize_semantic_memory_sync(self): """Synchronously initializes ChromaDB client, model, and collection.""" @@ -73,7 +105,9 @@ class MemoryManager: # Use PersistentClient for saving data to disk self.chroma_client = chromadb.PersistentClient(path=self.chroma_path) - logger.info(f"Loading Sentence Transformer model: {self.semantic_model_name}...") + logger.info( + f"Loading Sentence Transformer model: {self.semantic_model_name}..." + ) # Load the model directly self.transformer_model = SentenceTransformer(self.semantic_model_name) @@ -81,27 +115,34 @@ class MemoryManager: class CustomEmbeddingFunction(embedding_functions.EmbeddingFunction): def __init__(self, model): self.model = model + def __call__(self, input: chromadb.Documents) -> chromadb.Embeddings: # Ensure input is a list of strings if not isinstance(input, list): - input = [str(input)] # Convert single item to list + input = [str(input)] # Convert single item to list elif not all(isinstance(item, str) for item in input): - input = [str(item) for item in input] # Ensure all items are strings + input = [ + str(item) for item in input + ] # Ensure all items are strings logger.debug(f"Generating embeddings for {len(input)} documents.") - embeddings = self.model.encode(input, show_progress_bar=False).tolist() + embeddings = self.model.encode( + input, show_progress_bar=False + ).tolist() logger.debug(f"Generated {len(embeddings)} embeddings.") return embeddings self.embedding_function = CustomEmbeddingFunction(self.transformer_model) - logger.info("Getting/Creating ChromaDB collection 'gurt_semantic_memory'...") + logger.info( + "Getting/Creating ChromaDB collection 'gurt_semantic_memory'..." + ) # Get or create the collection with the custom embedding function self.semantic_collection = self.chroma_client.get_or_create_collection( name="gurt_semantic_memory", embedding_function=self.embedding_function, - metadata={"hnsw:space": "cosine"} # Use cosine distance for similarity - ) # Added missing closing parenthesis + metadata={"hnsw:space": "cosine"}, # Use cosine distance for similarity + ) # Added missing closing parenthesis logger.info("ChromaDB message collection initialized successfully.") logger.info("Getting/Creating ChromaDB collection 'gurt_fact_memory'...") @@ -109,18 +150,20 @@ class MemoryManager: self.fact_collection = self.chroma_client.get_or_create_collection( name="gurt_fact_memory", embedding_function=self.embedding_function, - metadata={"hnsw:space": "cosine"} # Use cosine distance for similarity + metadata={"hnsw:space": "cosine"}, # Use cosine distance for similarity ) logger.info("ChromaDB fact collection initialized successfully.") except Exception as e: - logger.error(f"Failed to initialize semantic memory (ChromaDB): {e}", exc_info=True) + logger.error( + f"Failed to initialize semantic memory (ChromaDB): {e}", exc_info=True + ) # Set components to None to indicate failure self.chroma_client = None self.transformer_model = None self.embedding_function = None self.semantic_collection = None - self.fact_collection = None # Also set fact_collection to None on error + self.fact_collection = None # Also set fact_collection to None on error async def initialize_sqlite_database(self): """Initializes the SQLite database and creates tables if they don't exist.""" @@ -128,7 +171,8 @@ class MemoryManager: await db.execute("PRAGMA journal_mode=WAL;") # Create user_facts table if it doesn't exist - await db.execute(""" + await db.execute( + """ CREATE TABLE IF NOT EXISTS user_facts ( user_id TEXT NOT NULL, fact TEXT NOT NULL, @@ -136,7 +180,8 @@ class MemoryManager: timestamp REAL DEFAULT (unixepoch('now')), PRIMARY KEY (user_id, fact) ); - """) + """ + ) # Check if chroma_id column exists in user_facts table try: @@ -146,24 +191,33 @@ class MemoryManager: column_names = [column[1] for column in columns] # If chroma_id column doesn't exist, add it - if 'chroma_id' not in column_names: + if "chroma_id" not in column_names: logger.info("Adding chroma_id column to user_facts table") await db.execute("ALTER TABLE user_facts ADD COLUMN chroma_id TEXT") except Exception as e: - logger.error(f"Error checking/adding chroma_id column to user_facts: {e}", exc_info=True) + logger.error( + f"Error checking/adding chroma_id column to user_facts: {e}", + exc_info=True, + ) # Create indexes - await db.execute("CREATE INDEX IF NOT EXISTS idx_user_facts_user ON user_facts (user_id);") - await db.execute("CREATE INDEX IF NOT EXISTS idx_user_facts_chroma_id ON user_facts (chroma_id);") # Index for chroma_id + await db.execute( + "CREATE INDEX IF NOT EXISTS idx_user_facts_user ON user_facts (user_id);" + ) + await db.execute( + "CREATE INDEX IF NOT EXISTS idx_user_facts_chroma_id ON user_facts (chroma_id);" + ) # Index for chroma_id # Create general_facts table if it doesn't exist - await db.execute(""" + await db.execute( + """ CREATE TABLE IF NOT EXISTS general_facts ( fact TEXT PRIMARY KEY NOT NULL, chroma_id TEXT, -- Added for linking to ChromaDB timestamp REAL DEFAULT (unixepoch('now')) ); - """) + """ + ) # Check if chroma_id column exists in general_facts table try: @@ -173,40 +227,54 @@ class MemoryManager: column_names = [column[1] for column in columns] # If chroma_id column doesn't exist, add it - if 'chroma_id' not in column_names: + if "chroma_id" not in column_names: logger.info("Adding chroma_id column to general_facts table") - await db.execute("ALTER TABLE general_facts ADD COLUMN chroma_id TEXT") + await db.execute( + "ALTER TABLE general_facts ADD COLUMN chroma_id TEXT" + ) except Exception as e: - logger.error(f"Error checking/adding chroma_id column to general_facts: {e}", exc_info=True) + logger.error( + f"Error checking/adding chroma_id column to general_facts: {e}", + exc_info=True, + ) # Create index for general_facts - await db.execute("CREATE INDEX IF NOT EXISTS idx_general_facts_chroma_id ON general_facts (chroma_id);") # Index for chroma_id + await db.execute( + "CREATE INDEX IF NOT EXISTS idx_general_facts_chroma_id ON general_facts (chroma_id);" + ) # Index for chroma_id # --- Add Personality Table --- - await db.execute(""" + await db.execute( + """ CREATE TABLE IF NOT EXISTS gurt_personality ( trait_key TEXT PRIMARY KEY NOT NULL, trait_value TEXT NOT NULL, -- Store value as JSON string last_updated REAL DEFAULT (unixepoch('now')) ); - """) + """ + ) logger.info("Personality table created/verified.") # --- End Personality Table --- # --- Add Interests Table --- - await db.execute(""" + await db.execute( + """ CREATE TABLE IF NOT EXISTS gurt_interests ( interest_topic TEXT PRIMARY KEY NOT NULL, interest_level REAL DEFAULT 0.1, -- Start with a small default level last_updated REAL DEFAULT (unixepoch('now')) ); - """) - await db.execute("CREATE INDEX IF NOT EXISTS idx_interest_level ON gurt_interests (interest_level);") + """ + ) + await db.execute( + "CREATE INDEX IF NOT EXISTS idx_interest_level ON gurt_interests (interest_level);" + ) logger.info("Interests table created/verified.") # --- End Interests Table --- # --- Add Goals Table --- - await db.execute(""" + await db.execute( + """ CREATE TABLE IF NOT EXISTS gurt_goals ( goal_id INTEGER PRIMARY KEY AUTOINCREMENT, description TEXT NOT NULL UNIQUE, -- The goal description @@ -219,9 +287,14 @@ class MemoryManager: channel_id TEXT, -- The channel ID where the goal was created user_id TEXT -- The user ID who created the goal ); - """) - await db.execute("CREATE INDEX IF NOT EXISTS idx_goal_status ON gurt_goals (status);") - await db.execute("CREATE INDEX IF NOT EXISTS idx_goal_priority ON gurt_goals (priority);") + """ + ) + await db.execute( + "CREATE INDEX IF NOT EXISTS idx_goal_status ON gurt_goals (status);" + ) + await db.execute( + "CREATE INDEX IF NOT EXISTS idx_goal_priority ON gurt_goals (priority);" + ) logger.info("Goals table created/verified.") # --- End Goals Table --- @@ -232,22 +305,27 @@ class MemoryManager: columns = await cursor.fetchall() column_names = [column[1] for column in columns] - if 'guild_id' not in column_names: + if "guild_id" not in column_names: logger.info("Adding guild_id column to gurt_goals table") await db.execute("ALTER TABLE gurt_goals ADD COLUMN guild_id TEXT") - if 'channel_id' not in column_names: + if "channel_id" not in column_names: logger.info("Adding channel_id column to gurt_goals table") - await db.execute("ALTER TABLE gurt_goals ADD COLUMN channel_id TEXT") - if 'user_id' not in column_names: + await db.execute( + "ALTER TABLE gurt_goals ADD COLUMN channel_id TEXT" + ) + if "user_id" not in column_names: logger.info("Adding user_id column to gurt_goals table") await db.execute("ALTER TABLE gurt_goals ADD COLUMN user_id TEXT") except Exception as e: - logger.error(f"Error checking/adding columns to gurt_goals table: {e}", exc_info=True) - + logger.error( + f"Error checking/adding columns to gurt_goals table: {e}", + exc_info=True, + ) # --- Add Internal Actions Log Table --- - await db.execute(""" + await db.execute( + """ CREATE TABLE IF NOT EXISTS internal_actions ( action_id INTEGER PRIMARY KEY AUTOINCREMENT, timestamp REAL DEFAULT (unixepoch('now')), @@ -256,20 +334,30 @@ class MemoryManager: reasoning TEXT, -- Added: Reasoning behind the action result_summary TEXT -- Store a summary of the result or error message ); - """) + """ + ) # Check if reasoning column exists try: cursor = await db.execute("PRAGMA table_info(internal_actions)") columns = await cursor.fetchall() column_names = [column[1] for column in columns] - if 'reasoning' not in column_names: + if "reasoning" not in column_names: logger.info("Adding reasoning column to internal_actions table") - await db.execute("ALTER TABLE internal_actions ADD COLUMN reasoning TEXT") + await db.execute( + "ALTER TABLE internal_actions ADD COLUMN reasoning TEXT" + ) except Exception as e: - logger.error(f"Error checking/adding reasoning column to internal_actions: {e}", exc_info=True) + logger.error( + f"Error checking/adding reasoning column to internal_actions: {e}", + exc_info=True, + ) - await db.execute("CREATE INDEX IF NOT EXISTS idx_internal_actions_timestamp ON internal_actions (timestamp);") - await db.execute("CREATE INDEX IF NOT EXISTS idx_internal_actions_tool_name ON internal_actions (tool_name);") + await db.execute( + "CREATE INDEX IF NOT EXISTS idx_internal_actions_timestamp ON internal_actions (timestamp);" + ) + await db.execute( + "CREATE INDEX IF NOT EXISTS idx_internal_actions_tool_name ON internal_actions (tool_name);" + ) logger.info("Internal Actions Log table created/verified.") # --- End Internal Actions Log Table --- @@ -302,58 +390,94 @@ class MemoryManager: logger.info(f"Attempting to add user fact for {user_id}: '{fact}'") try: # Check SQLite first - existing = await self._db_fetchone("SELECT chroma_id FROM user_facts WHERE user_id = ? AND fact = ?", (user_id, fact)) + existing = await self._db_fetchone( + "SELECT chroma_id FROM user_facts WHERE user_id = ? AND fact = ?", + (user_id, fact), + ) if existing: logger.info(f"Fact already known for user {user_id} (SQLite).") return {"status": "duplicate", "user_id": user_id, "fact": fact} - count_result = await self._db_fetchone("SELECT COUNT(*) FROM user_facts WHERE user_id = ?", (user_id,)) + count_result = await self._db_fetchone( + "SELECT COUNT(*) FROM user_facts WHERE user_id = ?", (user_id,) + ) current_count = count_result[0] if count_result else 0 status = "added" deleted_chroma_id = None if current_count >= self.max_user_facts: - logger.warning(f"User {user_id} fact limit ({self.max_user_facts}) reached. Deleting oldest.") + logger.warning( + f"User {user_id} fact limit ({self.max_user_facts}) reached. Deleting oldest." + ) # Fetch oldest fact and its chroma_id for deletion - oldest_fact_row = await self._db_fetchone("SELECT fact, chroma_id FROM user_facts WHERE user_id = ? ORDER BY timestamp ASC LIMIT 1", (user_id,)) + oldest_fact_row = await self._db_fetchone( + "SELECT fact, chroma_id FROM user_facts WHERE user_id = ? ORDER BY timestamp ASC LIMIT 1", + (user_id,), + ) if oldest_fact_row: oldest_fact, deleted_chroma_id = oldest_fact_row - await self._db_execute("DELETE FROM user_facts WHERE user_id = ? AND fact = ?", (user_id, oldest_fact)) - logger.info(f"Deleted oldest fact for user {user_id} from SQLite: '{oldest_fact}'") - status = "limit_reached" # Indicate limit was hit but fact was added + await self._db_execute( + "DELETE FROM user_facts WHERE user_id = ? AND fact = ?", + (user_id, oldest_fact), + ) + logger.info( + f"Deleted oldest fact for user {user_id} from SQLite: '{oldest_fact}'" + ) + status = ( + "limit_reached" # Indicate limit was hit but fact was added + ) # Generate chroma_id - fact_hash = hashlib.sha1(fact.encode()).hexdigest()[:16] # Short hash + fact_hash = hashlib.sha1(fact.encode()).hexdigest()[:16] # Short hash chroma_id = f"user-{user_id}-{fact_hash}" # Insert into SQLite - await self._db_execute("INSERT INTO user_facts (user_id, fact, chroma_id) VALUES (?, ?, ?)", (user_id, fact, chroma_id)) + await self._db_execute( + "INSERT INTO user_facts (user_id, fact, chroma_id) VALUES (?, ?, ?)", + (user_id, fact, chroma_id), + ) logger.info(f"Fact added for user {user_id} to SQLite.") # Add to ChromaDB fact collection if self.fact_collection and self.embedding_function: try: - metadata = {"user_id": user_id, "type": "user", "timestamp": time.time()} + metadata = { + "user_id": user_id, + "type": "user", + "timestamp": time.time(), + } await asyncio.to_thread( self.fact_collection.add, documents=[fact], metadatas=[metadata], - ids=[chroma_id] + ids=[chroma_id], + ) + logger.info( + f"Fact added/updated for user {user_id} in ChromaDB (ID: {chroma_id})." ) - logger.info(f"Fact added/updated for user {user_id} in ChromaDB (ID: {chroma_id}).") # Delete the oldest fact from ChromaDB if limit was reached if deleted_chroma_id: - logger.info(f"Attempting to delete oldest fact from ChromaDB (ID: {deleted_chroma_id}).") - await asyncio.to_thread(self.fact_collection.delete, ids=[deleted_chroma_id]) - logger.info(f"Successfully deleted oldest fact from ChromaDB (ID: {deleted_chroma_id}).") + logger.info( + f"Attempting to delete oldest fact from ChromaDB (ID: {deleted_chroma_id})." + ) + await asyncio.to_thread( + self.fact_collection.delete, ids=[deleted_chroma_id] + ) + logger.info( + f"Successfully deleted oldest fact from ChromaDB (ID: {deleted_chroma_id})." + ) except Exception as chroma_e: - logger.error(f"ChromaDB error adding/deleting user fact for {user_id} (ID: {chroma_id}): {chroma_e}", exc_info=True) + logger.error( + f"ChromaDB error adding/deleting user fact for {user_id} (ID: {chroma_id}): {chroma_e}", + exc_info=True, + ) # Note: Fact is still in SQLite, but ChromaDB might be inconsistent. Consider rollback? For now, just log. else: - logger.warning(f"ChromaDB fact collection not available. Skipping embedding for user fact {user_id}.") - + logger.warning( + f"ChromaDB fact collection not available. Skipping embedding for user fact {user_id}." + ) return {"status": status, "user_id": user_id, "fact_added": fact} @@ -361,60 +485,78 @@ class MemoryManager: logger.error(f"Error adding user fact for {user_id}: {e}", exc_info=True) return {"error": f"Database error adding user fact: {str(e)}"} - async def get_user_facts(self, user_id: str, context: Optional[str] = None) -> List[str]: + async def get_user_facts( + self, user_id: str, context: Optional[str] = None + ) -> List[str]: """Retrieves stored facts about a user, optionally scored by relevance to context.""" if not user_id: logger.warning("get_user_facts called without user_id.") return [] - logger.info(f"Retrieving facts for user {user_id} (context provided: {bool(context)})") - limit = self.max_user_facts # Use the class attribute for limit + logger.info( + f"Retrieving facts for user {user_id} (context provided: {bool(context)})" + ) + limit = self.max_user_facts # Use the class attribute for limit try: if context and self.fact_collection and self.embedding_function: # --- Semantic Search --- - logger.debug(f"Performing semantic search for user facts (User: {user_id}, Limit: {limit})") + logger.debug( + f"Performing semantic search for user facts (User: {user_id}, Limit: {limit})" + ) try: # Query ChromaDB for facts relevant to the context results = await asyncio.to_thread( self.fact_collection.query, query_texts=[context], n_results=limit, - where={ # Use $and for multiple conditions - "$and": [ - {"user_id": user_id}, - {"type": "user"} - ] + where={ # Use $and for multiple conditions + "$and": [{"user_id": user_id}, {"type": "user"}] }, - include=['documents'] # Only need the fact text + include=["documents"], # Only need the fact text ) logger.debug(f"ChromaDB user fact query results: {results}") - if results and results.get('documents') and results['documents'][0]: - relevant_facts = results['documents'][0] - logger.info(f"Found {len(relevant_facts)} semantically relevant user facts for {user_id}.") + if results and results.get("documents") and results["documents"][0]: + relevant_facts = results["documents"][0] + logger.info( + f"Found {len(relevant_facts)} semantically relevant user facts for {user_id}." + ) return relevant_facts else: - logger.info(f"No semantic user facts found for {user_id} matching context.") - return [] # Return empty list if no semantic matches + logger.info( + f"No semantic user facts found for {user_id} matching context." + ) + return [] # Return empty list if no semantic matches except Exception as chroma_e: - logger.error(f"ChromaDB error searching user facts for {user_id}: {chroma_e}", exc_info=True) + logger.error( + f"ChromaDB error searching user facts for {user_id}: {chroma_e}", + exc_info=True, + ) # Fallback to SQLite retrieval on ChromaDB error - logger.warning(f"Falling back to SQLite retrieval for user facts {user_id} due to ChromaDB error.") + logger.warning( + f"Falling back to SQLite retrieval for user facts {user_id} due to ChromaDB error." + ) # Proceed to the SQLite block below # --- SQLite Fallback / No Context --- # If no context, or if ChromaDB failed/unavailable, get newest N facts from SQLite - logger.debug(f"Retrieving user facts from SQLite (User: {user_id}, Limit: {limit})") + logger.debug( + f"Retrieving user facts from SQLite (User: {user_id}, Limit: {limit})" + ) rows_ordered = await self._db_fetchall( "SELECT fact FROM user_facts WHERE user_id = ? ORDER BY timestamp DESC LIMIT ?", - (user_id, limit) + (user_id, limit), ) sqlite_facts = [row[0] for row in rows_ordered] - logger.info(f"Retrieved {len(sqlite_facts)} user facts from SQLite for {user_id}.") + logger.info( + f"Retrieved {len(sqlite_facts)} user facts from SQLite for {user_id}." + ) return sqlite_facts except Exception as e: - logger.error(f"Error retrieving user facts for {user_id}: {e}", exc_info=True) + logger.error( + f"Error retrieving user facts for {user_id}: {e}", exc_info=True + ) return [] # --- General Fact Memory Methods (SQLite + Relevance) --- @@ -426,32 +568,48 @@ class MemoryManager: logger.info(f"Attempting to add general fact: '{fact}'") try: # Check SQLite first - existing = await self._db_fetchone("SELECT chroma_id FROM general_facts WHERE fact = ?", (fact,)) + existing = await self._db_fetchone( + "SELECT chroma_id FROM general_facts WHERE fact = ?", (fact,) + ) if existing: logger.info(f"General fact already known (SQLite): '{fact}'") return {"status": "duplicate", "fact": fact} - count_result = await self._db_fetchone("SELECT COUNT(*) FROM general_facts", ()) + count_result = await self._db_fetchone( + "SELECT COUNT(*) FROM general_facts", () + ) current_count = count_result[0] if count_result else 0 status = "added" deleted_chroma_id = None if current_count >= self.max_general_facts: - logger.warning(f"General fact limit ({self.max_general_facts}) reached. Deleting oldest.") + logger.warning( + f"General fact limit ({self.max_general_facts}) reached. Deleting oldest." + ) # Fetch oldest fact and its chroma_id for deletion - oldest_fact_row = await self._db_fetchone("SELECT fact, chroma_id FROM general_facts ORDER BY timestamp ASC LIMIT 1", ()) + oldest_fact_row = await self._db_fetchone( + "SELECT fact, chroma_id FROM general_facts ORDER BY timestamp ASC LIMIT 1", + (), + ) if oldest_fact_row: oldest_fact, deleted_chroma_id = oldest_fact_row - await self._db_execute("DELETE FROM general_facts WHERE fact = ?", (oldest_fact,)) - logger.info(f"Deleted oldest general fact from SQLite: '{oldest_fact}'") + await self._db_execute( + "DELETE FROM general_facts WHERE fact = ?", (oldest_fact,) + ) + logger.info( + f"Deleted oldest general fact from SQLite: '{oldest_fact}'" + ) status = "limit_reached" # Generate chroma_id - fact_hash = hashlib.sha1(fact.encode()).hexdigest()[:16] # Short hash + fact_hash = hashlib.sha1(fact.encode()).hexdigest()[:16] # Short hash chroma_id = f"general-{fact_hash}" # Insert into SQLite - await self._db_execute("INSERT INTO general_facts (fact, chroma_id) VALUES (?, ?)", (fact, chroma_id)) + await self._db_execute( + "INSERT INTO general_facts (fact, chroma_id) VALUES (?, ?)", + (fact, chroma_id), + ) logger.info(f"General fact added to SQLite: '{fact}'") # Add to ChromaDB fact collection @@ -462,21 +620,34 @@ class MemoryManager: self.fact_collection.add, documents=[fact], metadatas=[metadata], - ids=[chroma_id] + ids=[chroma_id], + ) + logger.info( + f"General fact added/updated in ChromaDB (ID: {chroma_id})." ) - logger.info(f"General fact added/updated in ChromaDB (ID: {chroma_id}).") # Delete the oldest fact from ChromaDB if limit was reached if deleted_chroma_id: - logger.info(f"Attempting to delete oldest general fact from ChromaDB (ID: {deleted_chroma_id}).") - await asyncio.to_thread(self.fact_collection.delete, ids=[deleted_chroma_id]) - logger.info(f"Successfully deleted oldest general fact from ChromaDB (ID: {deleted_chroma_id}).") + logger.info( + f"Attempting to delete oldest general fact from ChromaDB (ID: {deleted_chroma_id})." + ) + await asyncio.to_thread( + self.fact_collection.delete, ids=[deleted_chroma_id] + ) + logger.info( + f"Successfully deleted oldest general fact from ChromaDB (ID: {deleted_chroma_id})." + ) except Exception as chroma_e: - logger.error(f"ChromaDB error adding/deleting general fact (ID: {chroma_id}): {chroma_e}", exc_info=True) + logger.error( + f"ChromaDB error adding/deleting general fact (ID: {chroma_id}): {chroma_e}", + exc_info=True, + ) # Note: Fact is still in SQLite. else: - logger.warning(f"ChromaDB fact collection not available. Skipping embedding for general fact.") + logger.warning( + f"ChromaDB fact collection not available. Skipping embedding for general fact." + ) return {"status": status, "fact_added": fact} @@ -484,42 +655,60 @@ class MemoryManager: logger.error(f"Error adding general fact: {e}", exc_info=True) return {"error": f"Database error adding general fact: {str(e)}"} - async def get_general_facts(self, query: Optional[str] = None, limit: Optional[int] = 10, context: Optional[str] = None) -> List[str]: + async def get_general_facts( + self, + query: Optional[str] = None, + limit: Optional[int] = 10, + context: Optional[str] = None, + ) -> List[str]: """Retrieves stored general facts, optionally filtering by query or scoring by context relevance.""" - logger.info(f"Retrieving general facts (query='{query}', limit={limit}, context provided: {bool(context)})") - limit = min(max(1, limit or 10), 50) # Use provided limit or default 10, max 50 + logger.info( + f"Retrieving general facts (query='{query}', limit={limit}, context provided: {bool(context)})" + ) + limit = min(max(1, limit or 10), 50) # Use provided limit or default 10, max 50 try: if context and self.fact_collection and self.embedding_function: # --- Semantic Search (Prioritized if context is provided) --- # Note: The 'query' parameter is ignored when context is provided for semantic search. - logger.debug(f"Performing semantic search for general facts (Limit: {limit})") + logger.debug( + f"Performing semantic search for general facts (Limit: {limit})" + ) try: results = await asyncio.to_thread( self.fact_collection.query, query_texts=[context], n_results=limit, - where={"type": "general"}, # Filter by type - include=['documents'] # Only need the fact text + where={"type": "general"}, # Filter by type + include=["documents"], # Only need the fact text ) logger.debug(f"ChromaDB general fact query results: {results}") - if results and results.get('documents') and results['documents'][0]: - relevant_facts = results['documents'][0] - logger.info(f"Found {len(relevant_facts)} semantically relevant general facts.") + if results and results.get("documents") and results["documents"][0]: + relevant_facts = results["documents"][0] + logger.info( + f"Found {len(relevant_facts)} semantically relevant general facts." + ) return relevant_facts else: logger.info("No semantic general facts found matching context.") - return [] # Return empty list if no semantic matches + return [] # Return empty list if no semantic matches except Exception as chroma_e: - logger.error(f"ChromaDB error searching general facts: {chroma_e}", exc_info=True) + logger.error( + f"ChromaDB error searching general facts: {chroma_e}", + exc_info=True, + ) # Fallback to SQLite retrieval on ChromaDB error - logger.warning("Falling back to SQLite retrieval for general facts due to ChromaDB error.") + logger.warning( + "Falling back to SQLite retrieval for general facts due to ChromaDB error." + ) # Proceed to the SQLite block below, respecting the original 'query' if present # --- SQLite Fallback / No Context / ChromaDB Error --- # If no context, or if ChromaDB failed/unavailable, get newest N facts from SQLite, applying query if present. - logger.debug(f"Retrieving general facts from SQLite (Query: '{query}', Limit: {limit})") + logger.debug( + f"Retrieving general facts from SQLite (Query: '{query}', Limit: {limit})" + ) sql = "SELECT fact FROM general_facts" params = [] if query: @@ -532,7 +721,9 @@ class MemoryManager: rows_ordered = await self._db_fetchall(sql, tuple(params)) sqlite_facts = [row[0] for row in rows_ordered] - logger.info(f"Retrieved {len(sqlite_facts)} general facts from SQLite (Query: '{query}').") + logger.info( + f"Retrieved {len(sqlite_facts)} general facts from SQLite (Query: '{query}')." + ) return sqlite_facts except Exception as e: @@ -551,7 +742,7 @@ class MemoryManager: value_json = json.dumps(value) await self._db_execute( "INSERT OR REPLACE INTO gurt_personality (trait_key, trait_value, last_updated) VALUES (?, ?, unixepoch('now'))", - (key, value_json) + (key, value_json), ) logger.info(f"Personality trait '{key}' set/updated.") except Exception as e: @@ -563,7 +754,9 @@ class MemoryManager: logger.error("get_personality_trait called with empty key.") return None try: - row = await self._db_fetchone("SELECT trait_value FROM gurt_personality WHERE trait_key = ?", (key,)) + row = await self._db_fetchone( + "SELECT trait_value FROM gurt_personality WHERE trait_key = ?", (key,) + ) if row: # Deserialize the JSON string back to its original type value = json.loads(row[0]) @@ -580,14 +773,18 @@ class MemoryManager: """Retrieves all personality traits from the database.""" traits = {} try: - rows = await self._db_fetchall("SELECT trait_key, trait_value FROM gurt_personality", ()) + rows = await self._db_fetchall( + "SELECT trait_key, trait_value FROM gurt_personality", () + ) for key, value_json in rows: try: # Deserialize each value traits[key] = json.loads(value_json) except json.JSONDecodeError as json_e: - logger.error(f"Error decoding JSON for trait '{key}': {json_e}. Value: {value_json}") - traits[key] = None # Or handle error differently + logger.error( + f"Error decoding JSON for trait '{key}': {json_e}. Value: {value_json}" + ) + traits[key] = None # Or handle error differently logger.info(f"Retrieved {len(traits)} personality traits.") return traits except Exception as e: @@ -597,11 +794,15 @@ class MemoryManager: async def load_baseline_personality(self, baseline_traits: Dict[str, Any]): """Loads baseline traits into the personality table ONLY if it's empty.""" if not baseline_traits: - logger.warning("load_baseline_personality called with empty baseline traits.") + logger.warning( + "load_baseline_personality called with empty baseline traits." + ) return try: # Check if the table is empty - count_result = await self._db_fetchone("SELECT COUNT(*) FROM gurt_personality", ()) + count_result = await self._db_fetchone( + "SELECT COUNT(*) FROM gurt_personality", () + ) current_count = count_result[0] if count_result else 0 if current_count == 0: @@ -610,18 +811,24 @@ class MemoryManager: await self.set_personality_trait(key, value) logger.info(f"Loaded {len(baseline_traits)} baseline traits.") else: - logger.info(f"Personality table already contains {current_count} traits. Skipping baseline load.") + logger.info( + f"Personality table already contains {current_count} traits. Skipping baseline load." + ) except Exception as e: logger.error(f"Error loading baseline personality: {e}", exc_info=True) async def load_baseline_interests(self, baseline_interests: Dict[str, float]): """Loads baseline interests into the interests table ONLY if it's empty.""" if not baseline_interests: - logger.warning("load_baseline_interests called with empty baseline interests.") + logger.warning( + "load_baseline_interests called with empty baseline interests." + ) return try: # Check if the table is empty - count_result = await self._db_fetchone("SELECT COUNT(*) FROM gurt_interests", ()) + count_result = await self._db_fetchone( + "SELECT COUNT(*) FROM gurt_interests", () + ) current_count = count_result[0] if count_result else 0 if current_count == 0: @@ -630,20 +837,25 @@ class MemoryManager: async with aiosqlite.connect(self.db_path) as db: for topic, level in baseline_interests.items(): topic_normalized = topic.lower().strip() - if not topic_normalized: continue # Skip empty topics + if not topic_normalized: + continue # Skip empty topics # Clamp initial level just in case - level_clamped = max(INTEREST_MIN_LEVEL, min(INTEREST_MAX_LEVEL, level)) + level_clamped = max( + INTEREST_MIN_LEVEL, min(INTEREST_MAX_LEVEL, level) + ) await db.execute( """ INSERT INTO gurt_interests (interest_topic, interest_level, last_updated) VALUES (?, ?, unixepoch('now')) """, - (topic_normalized, level_clamped) + (topic_normalized, level_clamped), ) await db.commit() logger.info(f"Loaded {len(baseline_interests)} baseline interests.") else: - logger.info(f"Interests table already contains {current_count} interests. Skipping baseline load.") + logger.info( + f"Interests table already contains {current_count} interests. Skipping baseline load." + ) except Exception as e: logger.error(f"Error loading baseline interests: {e}", exc_info=True) @@ -655,7 +867,10 @@ class MemoryManager: logger.info("Cleared all existing personality traits from the database.") await self.load_baseline_personality(baseline_traits) logger.info("Successfully reset personality traits to baseline.") - return {"status": "success", "message": "Personality traits reset to baseline."} + return { + "status": "success", + "message": "Personality traits reset to baseline.", + } except Exception as e: logger.error(f"Error resetting personality traits: {e}", exc_info=True) return {"error": f"Database error resetting personality: {str(e)}"} @@ -673,7 +888,6 @@ class MemoryManager: logger.error(f"Error resetting interests: {e}", exc_info=True) return {"error": f"Database error resetting interests: {str(e)}"} - # --- Interest Methods (SQLite) --- async def update_interest(self, topic: str, change: float): @@ -688,7 +902,7 @@ class MemoryManager: if not topic: logger.error("update_interest called with empty topic.") return - topic = topic.lower().strip() # Normalize topic + topic = topic.lower().strip() # Normalize topic if not topic: logger.error("update_interest called with empty topic after normalization.") return @@ -697,7 +911,10 @@ class MemoryManager: async with self.db_lock: async with aiosqlite.connect(self.db_path) as db: # Check if topic exists - cursor = await db.execute("SELECT interest_level FROM gurt_interests WHERE interest_topic = ?", (topic,)) + cursor = await db.execute( + "SELECT interest_level FROM gurt_interests WHERE interest_topic = ?", + (topic,), + ) row = await cursor.fetchone() if row: @@ -705,12 +922,18 @@ class MemoryManager: new_level = current_level + change else: # Topic doesn't exist, create it with initial level + change - current_level = INTEREST_INITIAL_LEVEL # Use constant for initial level + current_level = ( + INTEREST_INITIAL_LEVEL # Use constant for initial level + ) new_level = current_level + change - logger.info(f"Creating new interest: '{topic}' with initial level {current_level:.3f} + change {change:.3f}") + logger.info( + f"Creating new interest: '{topic}' with initial level {current_level:.3f} + change {change:.3f}" + ) # Clamp the new level - new_level_clamped = max(INTEREST_MIN_LEVEL, min(INTEREST_MAX_LEVEL, new_level)) + new_level_clamped = max( + INTEREST_MIN_LEVEL, min(INTEREST_MAX_LEVEL, new_level) + ) # Insert or update the topic await db.execute( @@ -721,15 +944,19 @@ class MemoryManager: interest_level = excluded.interest_level, last_updated = excluded.last_updated; """, - (topic, new_level_clamped) + (topic, new_level_clamped), ) await db.commit() - logger.info(f"Interest '{topic}' updated: {current_level:.3f} -> {new_level_clamped:.3f} (Change: {change:.3f})") + logger.info( + f"Interest '{topic}' updated: {current_level:.3f} -> {new_level_clamped:.3f} (Change: {change:.3f})" + ) except Exception as e: logger.error(f"Error updating interest '{topic}': {e}", exc_info=True) - async def get_interests(self, limit: int = 5, min_level: float = 0.2) -> List[Tuple[str, float]]: + async def get_interests( + self, limit: int = 5, min_level: float = 0.2 + ) -> List[Tuple[str, float]]: """ Retrieves the top interests above a minimum level, ordered by interest level descending. @@ -744,16 +971,22 @@ class MemoryManager: try: rows = await self._db_fetchall( "SELECT interest_topic, interest_level FROM gurt_interests WHERE interest_level >= ? ORDER BY interest_level DESC LIMIT ?", - (min_level, limit) + (min_level, limit), ) interests = [(row[0], row[1]) for row in rows] - logger.info(f"Retrieved {len(interests)} interests (Limit: {limit}, Min Level: {min_level}).") + logger.info( + f"Retrieved {len(interests)} interests (Limit: {limit}, Min Level: {min_level})." + ) return interests except Exception as e: logger.error(f"Error getting interests: {e}", exc_info=True) return [] - async def decay_interests(self, decay_rate: float = INTEREST_DECAY_RATE, decay_interval_hours: int = INTEREST_DECAY_INTERVAL_HOURS): + async def decay_interests( + self, + decay_rate: float = INTEREST_DECAY_RATE, + decay_interval_hours: int = INTEREST_DECAY_INTERVAL_HOURS, + ): """ Applies decay to interest levels for topics not updated recently. @@ -765,19 +998,23 @@ class MemoryManager: logger.error(f"Invalid decay_rate: {decay_rate}. Must be between 0 and 1.") return if decay_interval_hours <= 0: - logger.error(f"Invalid decay_interval_hours: {decay_interval_hours}. Must be positive.") - return + logger.error( + f"Invalid decay_interval_hours: {decay_interval_hours}. Must be positive." + ) + return try: cutoff_timestamp = time.time() - (decay_interval_hours * 3600) - logger.info(f"Applying interest decay (Rate: {decay_rate}) for interests not updated since {datetime.datetime.fromtimestamp(cutoff_timestamp).isoformat()}...") + logger.info( + f"Applying interest decay (Rate: {decay_rate}) for interests not updated since {datetime.datetime.fromtimestamp(cutoff_timestamp).isoformat()}..." + ) async with self.db_lock: async with aiosqlite.connect(self.db_path) as db: # Select topics eligible for decay cursor = await db.execute( "SELECT interest_topic, interest_level FROM gurt_interests WHERE last_updated < ?", - (cutoff_timestamp,) + (cutoff_timestamp,), ) topics_to_decay = await cursor.fetchall() @@ -798,20 +1035,29 @@ class MemoryManager: if abs(new_level_clamped - current_level) > 0.001: await db.execute( "UPDATE gurt_interests SET interest_level = ? WHERE interest_topic = ?", - (new_level_clamped, topic) + (new_level_clamped, topic), + ) + logger.debug( + f"Decayed interest '{topic}': {current_level:.3f} -> {new_level_clamped:.3f}" ) - logger.debug(f"Decayed interest '{topic}': {current_level:.3f} -> {new_level_clamped:.3f}") updated_count += 1 await db.commit() - logger.info(f"Interest decay cycle complete. Updated {updated_count}/{len(topics_to_decay)} eligible interests.") + logger.info( + f"Interest decay cycle complete. Updated {updated_count}/{len(topics_to_decay)} eligible interests." + ) except Exception as e: logger.error(f"Error during interest decay: {e}", exc_info=True) # --- Semantic Memory Methods (ChromaDB) --- - async def add_message_embedding(self, message_id: str, formatted_message_data: Dict[str, Any], metadata: Dict[str, Any]) -> Dict[str, Any]: + async def add_message_embedding( + self, + message_id: str, + formatted_message_data: Dict[str, Any], + metadata: Dict[str, Any], + ) -> Dict[str, Any]: """ Generates embedding and stores a message (including attachment descriptions) in ChromaDB. @@ -821,80 +1067,118 @@ class MemoryManager: # Construct the text to embed: content + attachment descriptions text_to_embed_parts = [] - if formatted_message_data.get('content'): - text_to_embed_parts.append(formatted_message_data['content']) + if formatted_message_data.get("content"): + text_to_embed_parts.append(formatted_message_data["content"]) - attachment_descs = formatted_message_data.get('attachment_descriptions', []) + attachment_descs = formatted_message_data.get("attachment_descriptions", []) if attachment_descs: # Add a separator if there's content AND attachments if text_to_embed_parts: - text_to_embed_parts.append("\n") # Add newline separator + text_to_embed_parts.append("\n") # Add newline separator # Append descriptions for att in attachment_descs: - text_to_embed_parts.append(att.get('description', '')) + text_to_embed_parts.append(att.get("description", "")) text_to_embed = " ".join(text_to_embed_parts).strip() if not text_to_embed: - # This might happen if a message ONLY contains attachments and no text content, - # but format_message should always produce descriptions. Log if empty. - logger.warning(f"Message {message_id} resulted in empty text_to_embed. Original data: {formatted_message_data}") - return {"error": "Cannot add empty derived text to semantic memory."} + # This might happen if a message ONLY contains attachments and no text content, + # but format_message should always produce descriptions. Log if empty. + logger.warning( + f"Message {message_id} resulted in empty text_to_embed. Original data: {formatted_message_data}" + ) + return {"error": "Cannot add empty derived text to semantic memory."} - logger.info(f"Adding message {message_id} to semantic memory (including attachments).") + logger.info( + f"Adding message {message_id} to semantic memory (including attachments)." + ) try: # ChromaDB expects lists for inputs await asyncio.to_thread( self.semantic_collection.add, - documents=[text_to_embed], # Embed the combined text + documents=[text_to_embed], # Embed the combined text metadatas=[metadata], - ids=[message_id] + ids=[message_id], ) logger.info(f"Successfully added message {message_id} to ChromaDB.") return {"status": "success", "message_id": message_id} except Exception as e: - logger.error(f"ChromaDB error adding message {message_id}: {e}", exc_info=True) + logger.error( + f"ChromaDB error adding message {message_id}: {e}", exc_info=True + ) return {"error": f"Semantic memory error adding message: {str(e)}"} - async def search_semantic_memory(self, query_text: str, n_results: int = 5, filter_metadata: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: + async def search_semantic_memory( + self, + query_text: str, + n_results: int = 5, + filter_metadata: Optional[Dict[str, Any]] = None, + ) -> List[Dict[str, Any]]: """Searches ChromaDB for messages semantically similar to the query text.""" if not self.semantic_collection: - logger.warning("Search semantic memory called, but ChromaDB is not initialized.") + logger.warning( + "Search semantic memory called, but ChromaDB is not initialized." + ) return [] if not query_text: - logger.warning("Search semantic memory called with empty query text.") - return [] + logger.warning("Search semantic memory called with empty query text.") + return [] - logger.info(f"Searching semantic memory (n_results={n_results}, filter={filter_metadata}) for query: '{query_text[:50]}...'") + logger.info( + f"Searching semantic memory (n_results={n_results}, filter={filter_metadata}) for query: '{query_text[:50]}...'" + ) try: # Perform the query in a separate thread as ChromaDB operations can be blocking results = await asyncio.to_thread( self.semantic_collection.query, query_texts=[query_text], n_results=n_results, - where=filter_metadata, # Optional filter based on metadata - include=['metadatas', 'documents', 'distances'] # Include distance for relevance + where=filter_metadata, # Optional filter based on metadata + include=[ + "metadatas", + "documents", + "distances", + ], # Include distance for relevance ) logger.debug(f"ChromaDB query results: {results}") # Process results processed_results = [] - if results and results.get('ids') and results['ids'][0]: - for i, doc_id in enumerate(results['ids'][0]): - processed_results.append({ - "id": doc_id, - "document": results['documents'][0][i] if results.get('documents') else None, - "metadata": results['metadatas'][0][i] if results.get('metadatas') else None, - "distance": results['distances'][0][i] if results.get('distances') else None, - }) + if results and results.get("ids") and results["ids"][0]: + for i, doc_id in enumerate(results["ids"][0]): + processed_results.append( + { + "id": doc_id, + "document": ( + results["documents"][0][i] + if results.get("documents") + else None + ), + "metadata": ( + results["metadatas"][0][i] + if results.get("metadatas") + else None + ), + "distance": ( + results["distances"][0][i] + if results.get("distances") + else None + ), + } + ) logger.info(f"Found {len(processed_results)} semantic results.") return processed_results except Exception as e: - logger.error(f"ChromaDB error searching memory for query '{query_text[:50]}...': {e}", exc_info=True) + logger.error( + f"ChromaDB error searching memory for query '{query_text[:50]}...': {e}", + exc_info=True, + ) return [] - async def delete_user_fact(self, user_id: str, fact_to_delete: str) -> Dict[str, Any]: + async def delete_user_fact( + self, user_id: str, fact_to_delete: str + ) -> Dict[str, Any]: """Deletes a specific fact for a user from both SQLite and ChromaDB.""" if not user_id or not fact_to_delete: return {"error": "user_id and fact_to_delete are required."} @@ -902,28 +1186,55 @@ class MemoryManager: deleted_chroma_id = None try: # Check if fact exists and get chroma_id - row = await self._db_fetchone("SELECT chroma_id FROM user_facts WHERE user_id = ? AND fact = ?", (user_id, fact_to_delete)) + row = await self._db_fetchone( + "SELECT chroma_id FROM user_facts WHERE user_id = ? AND fact = ?", + (user_id, fact_to_delete), + ) if not row: - logger.warning(f"Fact not found in SQLite for user {user_id}: '{fact_to_delete}'") - return {"status": "not_found", "user_id": user_id, "fact": fact_to_delete} + logger.warning( + f"Fact not found in SQLite for user {user_id}: '{fact_to_delete}'" + ) + return { + "status": "not_found", + "user_id": user_id, + "fact": fact_to_delete, + } deleted_chroma_id = row[0] # Delete from SQLite - await self._db_execute("DELETE FROM user_facts WHERE user_id = ? AND fact = ?", (user_id, fact_to_delete)) - logger.info(f"Deleted fact from SQLite for user {user_id}: '{fact_to_delete}'") + await self._db_execute( + "DELETE FROM user_facts WHERE user_id = ? AND fact = ?", + (user_id, fact_to_delete), + ) + logger.info( + f"Deleted fact from SQLite for user {user_id}: '{fact_to_delete}'" + ) # Delete from ChromaDB if chroma_id exists if deleted_chroma_id and self.fact_collection: try: - logger.info(f"Attempting to delete fact from ChromaDB (ID: {deleted_chroma_id}).") - await asyncio.to_thread(self.fact_collection.delete, ids=[deleted_chroma_id]) - logger.info(f"Successfully deleted fact from ChromaDB (ID: {deleted_chroma_id}).") + logger.info( + f"Attempting to delete fact from ChromaDB (ID: {deleted_chroma_id})." + ) + await asyncio.to_thread( + self.fact_collection.delete, ids=[deleted_chroma_id] + ) + logger.info( + f"Successfully deleted fact from ChromaDB (ID: {deleted_chroma_id})." + ) except Exception as chroma_e: - logger.error(f"ChromaDB error deleting user fact ID {deleted_chroma_id}: {chroma_e}", exc_info=True) + logger.error( + f"ChromaDB error deleting user fact ID {deleted_chroma_id}: {chroma_e}", + exc_info=True, + ) # Log error but consider SQLite deletion successful - return {"status": "deleted", "user_id": user_id, "fact_deleted": fact_to_delete} + return { + "status": "deleted", + "user_id": user_id, + "fact_deleted": fact_to_delete, + } except Exception as e: logger.error(f"Error deleting user fact for {user_id}: {e}", exc_info=True) @@ -937,7 +1248,9 @@ class MemoryManager: deleted_chroma_id = None try: # Check if fact exists and get chroma_id - row = await self._db_fetchone("SELECT chroma_id FROM general_facts WHERE fact = ?", (fact_to_delete,)) + row = await self._db_fetchone( + "SELECT chroma_id FROM general_facts WHERE fact = ?", (fact_to_delete,) + ) if not row: logger.warning(f"General fact not found in SQLite: '{fact_to_delete}'") return {"status": "not_found", "fact": fact_to_delete} @@ -945,17 +1258,28 @@ class MemoryManager: deleted_chroma_id = row[0] # Delete from SQLite - await self._db_execute("DELETE FROM general_facts WHERE fact = ?", (fact_to_delete,)) + await self._db_execute( + "DELETE FROM general_facts WHERE fact = ?", (fact_to_delete,) + ) logger.info(f"Deleted general fact from SQLite: '{fact_to_delete}'") # Delete from ChromaDB if chroma_id exists if deleted_chroma_id and self.fact_collection: try: - logger.info(f"Attempting to delete general fact from ChromaDB (ID: {deleted_chroma_id}).") - await asyncio.to_thread(self.fact_collection.delete, ids=[deleted_chroma_id]) - logger.info(f"Successfully deleted general fact from ChromaDB (ID: {deleted_chroma_id}).") + logger.info( + f"Attempting to delete general fact from ChromaDB (ID: {deleted_chroma_id})." + ) + await asyncio.to_thread( + self.fact_collection.delete, ids=[deleted_chroma_id] + ) + logger.info( + f"Successfully deleted general fact from ChromaDB (ID: {deleted_chroma_id})." + ) except Exception as chroma_e: - logger.error(f"ChromaDB error deleting general fact ID {deleted_chroma_id}: {chroma_e}", exc_info=True) + logger.error( + f"ChromaDB error deleting general fact ID {deleted_chroma_id}: {chroma_e}", + exc_info=True, + ) # Log error but consider SQLite deletion successful return {"status": "deleted", "fact_deleted": fact_to_delete} @@ -966,7 +1290,15 @@ class MemoryManager: # --- Goal Management Methods (SQLite) --- - async def add_goal(self, description: str, priority: int = 5, details: Optional[Dict[str, Any]] = None, guild_id: Optional[str] = None, channel_id: Optional[str] = None, user_id: Optional[str] = None) -> Dict[str, Any]: + async def add_goal( + self, + description: str, + priority: int = 5, + details: Optional[Dict[str, Any]] = None, + guild_id: Optional[str] = None, + channel_id: Optional[str] = None, + user_id: Optional[str] = None, + ) -> Dict[str, Any]: """Adds a new goal to the database, including context.""" if not description: return {"error": "Goal description is required."} @@ -974,10 +1306,18 @@ class MemoryManager: details_json = json.dumps(details) if details else None try: # Check if goal already exists - existing = await self._db_fetchone("SELECT goal_id FROM gurt_goals WHERE description = ?", (description,)) + existing = await self._db_fetchone( + "SELECT goal_id FROM gurt_goals WHERE description = ?", (description,) + ) if existing: - logger.warning(f"Goal already exists: '{description}' (ID: {existing[0]})") - return {"status": "duplicate", "goal_id": existing[0], "description": description} + logger.warning( + f"Goal already exists: '{description}' (ID: {existing[0]})" + ) + return { + "status": "duplicate", + "goal_id": existing[0], + "description": description, + } async with self.db_lock: async with aiosqlite.connect(self.db_path) as db: @@ -986,7 +1326,14 @@ class MemoryManager: INSERT INTO gurt_goals (description, priority, details, status, last_updated, guild_id, channel_id, user_id) VALUES (?, ?, ?, 'pending', unixepoch('now'), ?, ?, ?) """, - (description, priority, details_json, guild_id, channel_id, user_id) + ( + description, + priority, + details_json, + guild_id, + channel_id, + user_id, + ), ) await db.commit() goal_id = cursor.lastrowid @@ -996,7 +1343,9 @@ class MemoryManager: logger.error(f"Error adding goal '{description}': {e}", exc_info=True) return {"error": f"Database error adding goal: {str(e)}"} - async def get_goals(self, status: Optional[str] = None, limit: int = 10) -> List[Dict[str, Any]]: + async def get_goals( + self, status: Optional[str] = None, limit: int = 10 + ) -> List[Dict[str, Any]]: """Retrieves goals, optionally filtered by status, ordered by priority.""" logger.info(f"Retrieving goals (Status: {status or 'any'}, Limit: {limit})") goals = [] @@ -1012,27 +1361,37 @@ class MemoryManager: rows = await self._db_fetchall(sql, tuple(params)) for row in rows: details = json.loads(row[6]) if row[6] else None - goals.append({ - "goal_id": row[0], - "description": row[1], - "status": row[2], - "priority": row[3], - "created_timestamp": row[4], - "last_updated": row[5], - "details": details, - "guild_id": row[7], - "channel_id": row[8], - "user_id": row[9] - }) + goals.append( + { + "goal_id": row[0], + "description": row[1], + "status": row[2], + "priority": row[3], + "created_timestamp": row[4], + "last_updated": row[5], + "details": details, + "guild_id": row[7], + "channel_id": row[8], + "user_id": row[9], + } + ) logger.info(f"Retrieved {len(goals)} goals.") return goals except Exception as e: logger.error(f"Error retrieving goals: {e}", exc_info=True) return [] - async def update_goal(self, goal_id: int, status: Optional[str] = None, priority: Optional[int] = None, details: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + async def update_goal( + self, + goal_id: int, + status: Optional[str] = None, + priority: Optional[int] = None, + details: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: """Updates the status, priority, or details of a goal.""" - logger.info(f"Updating goal ID {goal_id} (Status: {status}, Priority: {priority}, Details: {bool(details)})") + logger.info( + f"Updating goal ID {goal_id} (Status: {status}, Priority: {priority}, Details: {bool(details)})" + ) if not any([status, priority is not None, details is not None]): return {"error": "No update parameters provided."} @@ -1073,7 +1432,9 @@ class MemoryManager: try: async with self.db_lock: async with aiosqlite.connect(self.db_path) as db: - cursor = await db.execute("DELETE FROM gurt_goals WHERE goal_id = ?", (goal_id,)) + cursor = await db.execute( + "DELETE FROM gurt_goals WHERE goal_id = ?", (goal_id,) + ) await db.commit() if cursor.rowcount == 0: logger.warning(f"Goal ID {goal_id} not found for deletion.") @@ -1086,16 +1447,31 @@ class MemoryManager: # --- Internal Action Log Methods --- - async def add_internal_action_log(self, tool_name: str, arguments: Optional[Dict[str, Any]], result_summary: str, reasoning: Optional[str] = None) -> Dict[str, Any]: + async def add_internal_action_log( + self, + tool_name: str, + arguments: Optional[Dict[str, Any]], + result_summary: str, + reasoning: Optional[str] = None, + ) -> Dict[str, Any]: """Logs the execution of an internal background action, including reasoning.""" if not tool_name: return {"error": "Tool name is required for logging internal action."} - logger.info(f"Logging internal action: Tool='{tool_name}', Args={arguments}, Reason='{reasoning}', Result='{result_summary[:100]}...'") + logger.info( + f"Logging internal action: Tool='{tool_name}', Args={arguments}, Reason='{reasoning}', Result='{result_summary[:100]}...'" + ) args_json = json.dumps(arguments) if arguments else None # Truncate result summary and reasoning if too long for DB max_len = 1000 - truncated_summary = result_summary[:max_len] + ('...' if len(result_summary) > max_len else '') - truncated_reasoning = reasoning[:max_len] + ('...' if reasoning and len(reasoning) > max_len else '') if reasoning else None + truncated_summary = result_summary[:max_len] + ( + "..." if len(result_summary) > max_len else "" + ) + truncated_reasoning = ( + reasoning[:max_len] + + ("..." if reasoning and len(reasoning) > max_len else "") + if reasoning + else None + ) try: async with self.db_lock: @@ -1105,14 +1481,18 @@ class MemoryManager: INSERT INTO internal_actions (tool_name, arguments_json, reasoning, result_summary, timestamp) VALUES (?, ?, ?, ?, unixepoch('now')) """, - (tool_name, args_json, truncated_reasoning, truncated_summary) + (tool_name, args_json, truncated_reasoning, truncated_summary), ) await db.commit() action_id = cursor.lastrowid - logger.info(f"Internal action logged successfully (ID: {action_id}): Tool='{tool_name}'") + logger.info( + f"Internal action logged successfully (ID: {action_id}): Tool='{tool_name}'" + ) return {"status": "logged", "action_id": action_id} except Exception as e: - logger.error(f"Error logging internal action '{tool_name}': {e}", exc_info=True) + logger.error( + f"Error logging internal action '{tool_name}': {e}", exc_info=True + ) return {"error": f"Database error logging internal action: {str(e)}"} async def get_internal_action_logs(self, limit: int = 10) -> List[Dict[str, Any]]: @@ -1127,18 +1507,20 @@ class MemoryManager: ORDER BY timestamp DESC LIMIT ? """, - (limit,) + (limit,), ) for row in rows: arguments = json.loads(row[3]) if row[3] else None - logs.append({ - "action_id": row[0], - "timestamp": row[1], - "tool_name": row[2], - "arguments": arguments, - "reasoning": row[4], - "result_summary": row[5] - }) + logs.append( + { + "action_id": row[0], + "timestamp": row[1], + "tool_name": row[2], + "arguments": arguments, + "reasoning": row[4], + "result_summary": row[5], + } + ) logger.info(f"Retrieved {len(logs)} internal action logs.") return logs except Exception as e: diff --git a/install_stable_diffusion.py b/install_stable_diffusion.py index d0a018f..2957a35 100644 --- a/install_stable_diffusion.py +++ b/install_stable_diffusion.py @@ -3,6 +3,7 @@ import sys import os import platform + def install_dependencies(): """Install the required dependencies for Stable Diffusion.""" print("Installing Stable Diffusion dependencies...") @@ -14,19 +15,22 @@ def install_dependencies(): "transformers", "accelerate", "tqdm", - "safetensors" + "safetensors", ] # Check if CUDA is available try: import torch + cuda_available = torch.cuda.is_available() if cuda_available: cuda_version = torch.version.cuda print(f"✅ CUDA is available (version {cuda_version})") print(f"GPU: {torch.cuda.get_device_name(0)}") else: - print("⚠️ CUDA is not available. Stable Diffusion will run on CPU (very slow).") + print( + "⚠️ CUDA is not available. Stable Diffusion will run on CPU (very slow)." + ) except ImportError: print("PyTorch not installed yet. Will install with CUDA support.") cuda_available = False @@ -37,11 +41,19 @@ def install_dependencies(): if platform.system() == "Windows": # For Windows, use the PyTorch website command try: - subprocess.check_call([ - sys.executable, "-m", "pip", "install", - "torch", "torchvision", "torchaudio", - "--index-url", "https://download.pytorch.org/whl/cu118" - ]) + subprocess.check_call( + [ + sys.executable, + "-m", + "pip", + "install", + "torch", + "torchvision", + "torchaudio", + "--index-url", + "https://download.pytorch.org/whl/cu118", + ] + ) print("✅ Successfully installed PyTorch with CUDA support") except subprocess.CalledProcessError as e: print(f"❌ Error installing PyTorch: {e}") @@ -59,10 +71,17 @@ def install_dependencies(): if platform.system() == "Windows" and cuda_available: try: print("Installing xformers for memory efficiency...") - subprocess.check_call([ - sys.executable, "-m", "pip", "install", - "xformers", "--index-url", "https://download.pytorch.org/whl/cu118" - ]) + subprocess.check_call( + [ + sys.executable, + "-m", + "pip", + "install", + "xformers", + "--index-url", + "https://download.pytorch.org/whl/cu118", + ] + ) print("✅ Successfully installed xformers") packages.append("xformers") # Add to the list of installed packages except subprocess.CalledProcessError as e: @@ -84,12 +103,15 @@ def install_dependencies(): print("\n✅ All dependencies installed successfully!") print("\nNext steps:") - print("1. Download the Illustrious XL model by running: python download_illustrious.py") + print( + "1. Download the Illustrious XL model by running: python download_illustrious.py" + ) print("2. Restart your bot") print("3. Use the /generate command with a text prompt") print("4. Wait for the image to be generated (this may take some time)") return True + if __name__ == "__main__": install_dependencies() diff --git a/main.py b/main.py index b6622cd..5fc5ffb 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ import asyncio + # Set the event loop policy to the default asyncio policy BEFORE other asyncio/discord imports # This is to test if uvloop (if active globally) is causing issues with asyncpg. asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy()) @@ -13,23 +14,24 @@ import asyncio import subprocess import importlib.util import argparse -import logging # Add logging +import logging # Add logging import asyncpg import redis.asyncio as aioredis from commands import load_all_cogs, reload_all_cogs from error_handler import handle_error, patch_discord_methods, store_interaction_content from utils import reload_script -import settings_manager # Import the settings manager -from db import mod_log_db # Import the new mod log db functions -import command_customization # Import command customization utilities -from global_bot_accessor import set_bot_instance # Import the new accessor -import custom_bot_manager # Import the custom bot manager +import settings_manager # Import the settings manager +from db import mod_log_db # Import the new mod log db functions +import command_customization # Import command customization utilities +from global_bot_accessor import set_bot_instance # Import the new accessor +import custom_bot_manager # Import the custom bot manager # Import the unified API service runner and the sync API module import sys -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from run_unified_api import start_api_in_thread -import discord_bot_sync_api # Import the module to set the cog instance +import discord_bot_sync_api # Import the module to set the cog instance # Import the markdown server from run_markdown_server import start_markdown_server_in_thread @@ -37,6 +39,7 @@ from run_markdown_server import start_markdown_server_in_thread # Check if API dependencies are available try: import uvicorn + API_AVAILABLE = True except ImportError: print("uvicorn not available. API service will not be available.") @@ -47,7 +50,8 @@ load_dotenv() # --- Constants --- DEFAULT_PREFIX = "!" -CORE_COGS = {'SettingsCog', 'HelpCog'} # Cogs that cannot be disabled +CORE_COGS = {"SettingsCog", "HelpCog"} # Cogs that cannot be disabled + # --- Dynamic Prefix Function --- async def get_prefix(bot_instance, message): @@ -60,39 +64,41 @@ async def get_prefix(bot_instance, message): prefix = await settings_manager.get_guild_prefix(message.guild.id, DEFAULT_PREFIX) return prefix + # --- Bot Setup --- # Set up intents (permissions) intents = discord.Intents.default() intents.message_content = True intents.members = True -intents.presences = True # Required for .status / .activity +intents.presences = True # Required for .status / .activity + # --- Custom Bot Class with setup_hook for async initialization --- class MyBot(commands.Bot): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.owner_id = int(os.getenv('OWNER_USER_ID')) - self.core_cogs = CORE_COGS # Attach core cogs list to bot instance - self.settings_manager = settings_manager # Attach settings manager instance - self.pg_pool = None # Will be initialized in setup_hook - self.redis = None # Will be initialized in setup_hook - self.ai_cogs_to_skip = [] # For --disable-ai flag + self.owner_id = int(os.getenv("OWNER_USER_ID")) + self.core_cogs = CORE_COGS # Attach core cogs list to bot instance + self.settings_manager = settings_manager # Attach settings manager instance + self.pg_pool = None # Will be initialized in setup_hook + self.redis = None # Will be initialized in setup_hook + self.ai_cogs_to_skip = [] # For --disable-ai flag async def setup_hook(self): log.info("Running setup_hook...") # Create Postgres pool on this loop self.pg_pool = await asyncpg.create_pool( - dsn=settings_manager.DATABASE_URL, # Use DATABASE_URL from settings_manager + dsn=settings_manager.DATABASE_URL, # Use DATABASE_URL from settings_manager min_size=1, max_size=10, - loop=self.loop # Explicitly use the bot's event loop + loop=self.loop, # Explicitly use the bot's event loop ) log.info("Postgres pool initialized and attached to bot.pg_pool.") # Create Redis client on this loop self.redis = await aioredis.from_url( - settings_manager.REDIS_URL, # Use REDIS_URL from settings_manager + settings_manager.REDIS_URL, # Use REDIS_URL from settings_manager max_connections=10, decode_responses=True, ) @@ -106,15 +112,18 @@ class MyBot(commands.Bot): # Initialize database schema and run migrations using settings_manager if self.pg_pool and self.redis: try: - await settings_manager.initialize_database() # Uses the bot instance via get_bot_instance() + await settings_manager.initialize_database() # Uses the bot instance via get_bot_instance() log.info("Database schema initialization called via settings_manager.") - await settings_manager.run_migrations() # Uses the bot instance via get_bot_instance() + await settings_manager.run_migrations() # Uses the bot instance via get_bot_instance() log.info("Database migrations called via settings_manager.") except Exception as e: - log.exception("CRITICAL: Failed during settings_manager database setup (init/migrations).") + log.exception( + "CRITICAL: Failed during settings_manager database setup (init/migrations)." + ) else: - log.error("CRITICAL: pg_pool or redis_client not initialized in setup_hook. Cannot proceed with settings_manager setup.") - + log.error( + "CRITICAL: pg_pool or redis_client not initialized in setup_hook. Cannot proceed with settings_manager setup." + ) # Setup the moderation log table *after* pool initialization if self.pg_pool: @@ -122,14 +131,19 @@ class MyBot(commands.Bot): await mod_log_db.setup_moderation_log_table(self.pg_pool) log.info("Moderation log table setup complete via setup_hook.") except Exception as e: - log.exception("CRITICAL: Failed to setup moderation log table in setup_hook.") + log.exception( + "CRITICAL: Failed to setup moderation log table in setup_hook." + ) else: - log.warning("pg_pool not available in setup_hook, skipping mod_log_db setup.") + log.warning( + "pg_pool not available in setup_hook, skipping mod_log_db setup." + ) # Load all cogs from the 'cogs' directory, skipping AI if requested await load_all_cogs(self, skip_cogs=self.ai_cogs_to_skip) - log.info(f"Cogs loaded in setup_hook. Skipped: {self.ai_cogs_to_skip or 'None'}") - + log.info( + f"Cogs loaded in setup_hook. Skipped: {self.ai_cogs_to_skip or 'None'}" + ) # Load the lockdown cog separately if needed # Note: load_all_cogs already loads all cogs in the directory. The @@ -142,51 +156,73 @@ class MyBot(commands.Bot): gurt_cog = self.get_cog("Gurt") if gurt_cog: discord_bot_sync_api.gurt_cog_instance = gurt_cog - log.info("Successfully shared GurtCog instance with discord_bot_sync_api via setup_hook.") + log.info( + "Successfully shared GurtCog instance with discord_bot_sync_api via setup_hook." + ) else: log.warning("GurtCog not found after loading cogs in setup_hook.") discord_bot_sync_api.bot_instance = self - log.info("Successfully shared bot instance with discord_bot_sync_api via setup_hook.") + log.info( + "Successfully shared bot instance with discord_bot_sync_api via setup_hook." + ) mod_log_cog = self.get_cog("ModLogCog") if mod_log_cog: discord_bot_sync_api.mod_log_cog_instance = mod_log_cog - log.info("Successfully shared ModLogCog instance with discord_bot_sync_api via setup_hook.") + log.info( + "Successfully shared ModLogCog instance with discord_bot_sync_api via setup_hook." + ) else: log.warning("ModLogCog not found after loading cogs in setup_hook.") except Exception as e: - log.exception(f"Error sharing instances with discord_bot_sync_api in setup_hook: {e}") + log.exception( + f"Error sharing instances with discord_bot_sync_api in setup_hook: {e}" + ) # --- Manually Load FreakTetoCog (only if AI is NOT disabled) --- - if not self.ai_cogs_to_skip: # Check if list is empty (meaning AI is not disabled) + if ( + not self.ai_cogs_to_skip + ): # Check if list is empty (meaning AI is not disabled) try: freak_teto_cog_path = "freak_teto.cog" await self.load_extension(freak_teto_cog_path) - log.info(f"Successfully loaded FreakTetoCog from {freak_teto_cog_path} in setup_hook.") + log.info( + f"Successfully loaded FreakTetoCog from {freak_teto_cog_path} in setup_hook." + ) except commands.ExtensionAlreadyLoaded: - log.info(f"FreakTetoCog ({freak_teto_cog_path}) already loaded (setup_hook).") + log.info( + f"FreakTetoCog ({freak_teto_cog_path}) already loaded (setup_hook)." + ) except commands.ExtensionNotFound: - log.error(f"Error: FreakTetoCog not found at {freak_teto_cog_path} (setup_hook).") + log.error( + f"Error: FreakTetoCog not found at {freak_teto_cog_path} (setup_hook)." + ) except Exception as e: log.exception(f"Failed to load FreakTetoCog in setup_hook: {e}") log.info("setup_hook completed.") + # Create bot instance using the custom class bot = MyBot(command_prefix=get_prefix, intents=intents) # --- Logging Setup --- # Configure logging (adjust level and format as needed) -logging.basicConfig(level=logging.INFO, format='%(asctime)s:%(levelname)s:%(name)s: %(message)s') -log = logging.getLogger(__name__) # Logger for main.py +logging.basicConfig( + level=logging.INFO, format="%(asctime)s:%(levelname)s:%(name)s: %(message)s" +) +log = logging.getLogger(__name__) # Logger for main.py + # --- Events --- @bot.event async def on_ready(): - log.info(f'{bot.user.name} has connected to Discord!') - log.info(f'Bot ID: {bot.user.id}') + log.info(f"{bot.user.name} has connected to Discord!") + log.info(f"Bot ID: {bot.user.id}") # Set the bot's status - await bot.change_presence(activity=discord.Activity(type=discord.ActivityType.listening, name="!help")) + await bot.change_presence( + activity=discord.Activity(type=discord.ActivityType.listening, name="!help") + ) log.info("Bot status set to 'Listening to !help'") # --- Add current guilds to DB --- @@ -200,21 +236,30 @@ async def on_ready(): # Get guilds currently in DB db_records = await conn.fetch("SELECT guild_id FROM guilds") - db_guild_ids = {record['guild_id'] for record in db_records} + db_guild_ids = {record["guild_id"] for record in db_records} log.debug(f"Found {len(db_guild_ids)} guilds in database.") # Add guilds bot joined while offline guilds_to_add = current_guild_ids - db_guild_ids if guilds_to_add: - log.info(f"Adding {len(guilds_to_add)} new guilds to database: {guilds_to_add}") - await conn.executemany("INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT DO NOTHING;", - [(guild_id,) for guild_id in guilds_to_add]) + log.info( + f"Adding {len(guilds_to_add)} new guilds to database: {guilds_to_add}" + ) + await conn.executemany( + "INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT DO NOTHING;", + [(guild_id,) for guild_id in guilds_to_add], + ) # Remove guilds bot left while offline guilds_to_remove = db_guild_ids - current_guild_ids if guilds_to_remove: - log.info(f"Removing {len(guilds_to_remove)} guilds from database: {guilds_to_remove}") - await conn.execute("DELETE FROM guilds WHERE guild_id = ANY($1::bigint[])", list(guilds_to_remove)) + log.info( + f"Removing {len(guilds_to_remove)} guilds from database: {guilds_to_remove}" + ) + await conn.execute( + "DELETE FROM guilds WHERE guild_id = ANY($1::bigint[])", + list(guilds_to_remove), + ) log.info("Guild sync with database complete.") except Exception as e: @@ -230,11 +275,13 @@ async def on_ready(): # Make the store_interaction_content function available globally import builtins + builtins.store_interaction_content = store_interaction_content print("Made store_interaction_content available globally") except Exception as e: print(f"Warning: Failed to patch Discord methods: {e}") import traceback + traceback.print_exc() try: print("Starting command sync process...") @@ -250,7 +297,9 @@ async def on_ready(): guild_syncs = await command_customization.register_all_guild_commands(bot) total_guild_syncs = sum(len(cmds) for cmds in guild_syncs.values()) - print(f"Synced commands for {len(guild_syncs)} guilds with a total of {total_guild_syncs} customized commands") + print( + f"Synced commands for {len(guild_syncs)} guilds with a total of {total_guild_syncs} customized commands" + ) # List commands after sync commands_after = [cmd.name for cmd in bot.tree.get_commands()] @@ -259,6 +308,7 @@ async def on_ready(): except Exception as e: print(f"Failed to sync commands: {e}") import traceback + traceback.print_exc() # Start custom bots for users who have enabled them @@ -268,6 +318,7 @@ async def on_ready(): # Import the API database to access user settings try: from api_service.api_server import db + if not db: log.warning("API database not initialized, cannot start custom bots") return @@ -295,27 +346,31 @@ async def on_ready(): token=token, prefix=prefix, status_type=status_type, - status_text=status_text + status_text=status_text, ) if success: # Start the bot success, message = custom_bot_manager.run_custom_bot_in_thread( - user_id=user_id, - token=token + user_id=user_id, token=token ) if success: log.info(f"Successfully started custom bot for user {user_id}") else: - log.error(f"Failed to start custom bot for user {user_id}: {message}") + log.error( + f"Failed to start custom bot for user {user_id}: {message}" + ) else: - log.error(f"Failed to create custom bot for user {user_id}: {message}") + log.error( + f"Failed to create custom bot for user {user_id}: {message}" + ) log.info(f"Found {enabled_bots} users with custom bots enabled") except Exception as e: log.exception(f"Error starting custom bots: {e}") + @bot.event async def on_shard_disconnect(shard_id): print(f"Shard {shard_id} disconnected. Attempting to reconnect...") @@ -325,6 +380,7 @@ async def on_shard_disconnect(shard_id): except Exception as e: print(f"Failed to reconnect shard {shard_id}: {e}") + @bot.event async def on_guild_join(guild: discord.Guild): """Adds guild to database when bot joins and syncs commands.""" @@ -332,7 +388,10 @@ async def on_guild_join(guild: discord.Guild): if bot.pg_pool: try: async with bot.pg_pool.acquire() as conn: - await conn.execute("INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT DO NOTHING;", guild.id) + await conn.execute( + "INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT DO NOTHING;", + guild.id, + ) log.info(f"Added guild {guild.id} to database.") # Sync commands for the new guild @@ -347,6 +406,7 @@ async def on_guild_join(guild: discord.Guild): else: log.warning("Bot Postgres pool not initialized, cannot add guild on join.") + @bot.event async def on_guild_remove(guild: discord.Guild): """Removes guild from database when bot leaves.""" @@ -367,59 +427,88 @@ async def on_guild_remove(guild: discord.Guild): @bot.event async def on_command_error(ctx, error): if isinstance(error, CogDisabledError): - await ctx.send(str(error), ephemeral=True) # Send the error message from the exception - log.warning(f"Command '{ctx.command.qualified_name}' blocked for user {ctx.author.id} in guild {ctx.guild.id}: {error}") + await ctx.send( + str(error), ephemeral=True + ) # Send the error message from the exception + log.warning( + f"Command '{ctx.command.qualified_name}' blocked for user {ctx.author.id} in guild {ctx.guild.id}: {error}" + ) elif isinstance(error, CommandPermissionError): - await ctx.send(str(error), ephemeral=True) # Send the error message from the exception - log.warning(f"Command '{ctx.command.qualified_name}' blocked for user {ctx.author.id} in guild {ctx.guild.id}: {error}") + await ctx.send( + str(error), ephemeral=True + ) # Send the error message from the exception + log.warning( + f"Command '{ctx.command.qualified_name}' blocked for user {ctx.author.id} in guild {ctx.guild.id}: {error}" + ) elif isinstance(error, CommandDisabledError): - await ctx.send(str(error), ephemeral=True) # Send the error message from the exception - log.warning(f"Command '{ctx.command.qualified_name}' blocked for user {ctx.author.id} in guild {ctx.guild.id}: {error}") + await ctx.send( + str(error), ephemeral=True + ) # Send the error message from the exception + log.warning( + f"Command '{ctx.command.qualified_name}' blocked for user {ctx.author.id} in guild {ctx.guild.id}: {error}" + ) # Import here to avoid circular imports from cogs.ban_system_cog import UserBannedError + if isinstance(error, UserBannedError): - await ctx.send(error.message, ephemeral=True) # Send the custom ban message - log.warning(f"Command '{ctx.command.qualified_name}' blocked for banned user {ctx.author.id}: {error.message}") + await ctx.send(error.message, ephemeral=True) # Send the custom ban message + log.warning( + f"Command '{ctx.command.qualified_name}' blocked for banned user {ctx.author.id}: {error.message}" + ) else: # Pass other errors to the original handler await handle_error(ctx, error) + @bot.tree.error async def on_app_command_error(interaction, error): # Import here to avoid circular imports from cogs.ban_system_cog import UserBannedError + if isinstance(error, UserBannedError): if not interaction.response.is_done(): await interaction.response.send_message(error.message, ephemeral=True) else: await interaction.followup.send(error.message, ephemeral=True) - log.warning(f"Command blocked for banned user {interaction.user.id}: {error.message}") + log.warning( + f"Command blocked for banned user {interaction.user.id}: {error.message}" + ) else: await handle_error(interaction, error) + # --- Global Command Checks --- + # Need to import SettingsCog to access CORE_COGS, or define CORE_COGS here. # Let's import it, assuming it's safe to do so at the top level. # If it causes circular imports, CORE_COGS needs to be defined elsewhere or passed differently. class CogDisabledError(commands.CheckFailure): """Custom exception for disabled cogs.""" + def __init__(self, cog_name): self.cog_name = cog_name super().__init__(f"The module `{cog_name}` is disabled in this server.") + class CommandPermissionError(commands.CheckFailure): """Custom exception for insufficient command permissions based on roles.""" + def __init__(self, command_name): self.command_name = command_name - super().__init__(f"You do not have the required role to use the command `{command_name}`.") + super().__init__( + f"You do not have the required role to use the command `{command_name}`." + ) + class CommandDisabledError(commands.CheckFailure): """Custom exception for disabled commands.""" + def __init__(self, command_name): self.command_name = command_name super().__init__(f"The command `{command_name}` is disabled in this server.") + @bot.before_invoke async def global_command_checks(ctx: commands.Context): """Global check run before any command invocation.""" @@ -432,7 +521,7 @@ async def global_command_checks(ctx: commands.Context): return command = ctx.command - if not command: # Should not happen with prefix commands, but good practice + if not command: # Should not happen with prefix commands, but good practice return cog = command.cog @@ -442,41 +531,58 @@ async def global_command_checks(ctx: commands.Context): # Ensure author is a Member to get roles if not isinstance(ctx.author, discord.Member): - log.warning(f"Could not perform permission check for user {ctx.author.id} (not a Member object). Allowing command '{command_name}'.") - return # Cannot check roles if not a Member object + log.warning( + f"Could not perform permission check for user {ctx.author.id} (not a Member object). Allowing command '{command_name}'." + ) + return # Cannot check roles if not a Member object member_roles_ids = [role.id for role in ctx.author.roles] # 1. Check if the Cog is enabled # Use CORE_COGS attached to the bot instance - if cog_name and cog_name not in bot.core_cogs: # Don't disable core cogs + if cog_name and cog_name not in bot.core_cogs: # Don't disable core cogs # Assuming default is True if not explicitly set in DB - is_enabled = await settings_manager.is_cog_enabled(guild_id, cog_name, default_enabled=True) + is_enabled = await settings_manager.is_cog_enabled( + guild_id, cog_name, default_enabled=True + ) if not is_enabled: - log.warning(f"Command '{command_name}' blocked in guild {guild_id}: Cog '{cog_name}' is disabled.") + log.warning( + f"Command '{command_name}' blocked in guild {guild_id}: Cog '{cog_name}' is disabled." + ) raise CogDisabledError(cog_name) # 2. Check if the Command is enabled # This only applies if the command has been explicitly disabled - is_cmd_enabled = await settings_manager.is_command_enabled(guild_id, command_name, default_enabled=True) + is_cmd_enabled = await settings_manager.is_command_enabled( + guild_id, command_name, default_enabled=True + ) if not is_cmd_enabled: - log.warning(f"Command '{command_name}' blocked in guild {guild_id}: Command is disabled.") + log.warning( + f"Command '{command_name}' blocked in guild {guild_id}: Command is disabled." + ) raise CommandDisabledError(command_name) # 3. Check command permissions based on roles # This check only applies if specific permissions HAVE been set for this command. # If no permissions are set in the DB, check_command_permission returns True. - has_perm = await settings_manager.check_command_permission(guild_id, command_name, member_roles_ids) + has_perm = await settings_manager.check_command_permission( + guild_id, command_name, member_roles_ids + ) if not has_perm: - log.warning(f"Command '{command_name}' blocked for user {ctx.author.id} in guild {guild_id}: Insufficient role permissions.") + log.warning( + f"Command '{command_name}' blocked for user {ctx.author.id} in guild {guild_id}: Insufficient role permissions." + ) raise CommandPermissionError(command_name) # If all checks pass, the command proceeds. - log.debug(f"Command '{command_name}' passed global checks for user {ctx.author.id} in guild {guild_id}.") + log.debug( + f"Command '{command_name}' passed global checks for user {ctx.author.id} in guild {guild_id}." + ) # --- Bot Commands --- + @commands.command(name="restart", help="Restarts the bot. Owner only.") @commands.is_owner() async def restart(ctx): @@ -485,34 +591,47 @@ async def restart(ctx): await bot.close() # Gracefully close the bot os.execv(sys.executable, [sys.executable] + sys.argv) # Restart the bot process + bot.add_command(restart) -@commands.command(name="gitpull_restart", help="Pulls latest code from git and restarts the bot. Owner only.") + +@commands.command( + name="gitpull_restart", + help="Pulls latest code from git and restarts the bot. Owner only.", +) @commands.is_owner() async def gitpull_restart(ctx): """Pulls latest code from git and restarts the bot. (Owner Only)""" await ctx.send("Pulling latest code from git...") proc = await asyncio.create_subprocess_exec( - "git", "pull", - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + "git", "pull", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) stdout, stderr = await proc.communicate() output = stdout.decode().strip() + "\n" + stderr.decode().strip() if "unstaged changes" in output or "Please commit your changes" in output: - await ctx.send("Unstaged changes detected. Committing changes before pulling...") + await ctx.send( + "Unstaged changes detected. Committing changes before pulling..." + ) commit_proc = await asyncio.create_subprocess_exec( - "git", "commit", "-am", "Git pull and restart command", + "git", + "commit", + "-am", + "Git pull and restart command", stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stderr=asyncio.subprocess.PIPE, ) commit_stdout, commit_stderr = await commit_proc.communicate() - commit_output = commit_stdout.decode().strip() + "\n" + commit_stderr.decode().strip() - await ctx.send(f"Committed changes:\n```\n{commit_output}\n```Trying git pull again...") + commit_output = ( + commit_stdout.decode().strip() + "\n" + commit_stderr.decode().strip() + ) + await ctx.send( + f"Committed changes:\n```\n{commit_output}\n```Trying git pull again..." + ) proc = await asyncio.create_subprocess_exec( - "git", "pull", + "git", + "pull", stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stderr=asyncio.subprocess.PIPE, ) stdout, stderr = await proc.communicate() output = stdout.decode().strip() + "\n" + stderr.decode().strip() @@ -523,8 +642,10 @@ async def gitpull_restart(ctx): else: await ctx.send(f"Git pull failed:\n```\n{output}\n```") + bot.add_command(gitpull_restart) + @commands.command(name="reload_cogs", help="Reloads all cogs. Owner only.") @commands.is_owner() async def reload_cogs(ctx): @@ -532,44 +653,61 @@ async def reload_cogs(ctx): # Access the disable_ai flag from the bot instance or re-parse args if needed # For simplicity, assume disable_ai is accessible; otherwise, need a way to pass it. # Let's add it to the bot object for easier access later. - skip_list = getattr(bot, 'ai_cogs_to_skip', []) - await ctx.send(f"Reloading all cogs... (Skipping: {', '.join(skip_list) or 'None'})") + skip_list = getattr(bot, "ai_cogs_to_skip", []) + await ctx.send( + f"Reloading all cogs... (Skipping: {', '.join(skip_list) or 'None'})" + ) reloaded_cogs, failed_reload = await reload_all_cogs(bot, skip_cogs=skip_list) if reloaded_cogs: await ctx.send(f"Successfully reloaded cogs: {', '.join(reloaded_cogs)}") if failed_reload: await ctx.send(f"Failed to reload cogs: {', '.join(failed_reload)}") + bot.add_command(reload_cogs) -@commands.command(name="gitpull_reload", help="Pulls latest code from git and reloads all cogs. Owner only.") + +@commands.command( + name="gitpull_reload", + help="Pulls latest code from git and reloads all cogs. Owner only.", +) @commands.is_owner() async def gitpull_reload(ctx): """Pulls latest code from git and reloads all cogs. (Owner Only)""" # Access the disable_ai flag from the bot instance or re-parse args if needed - skip_list = getattr(bot, 'ai_cogs_to_skip', []) - await ctx.send(f"Pulling latest code from git... (Will skip reloading: {', '.join(skip_list) or 'None'})") + skip_list = getattr(bot, "ai_cogs_to_skip", []) + await ctx.send( + f"Pulling latest code from git... (Will skip reloading: {', '.join(skip_list) or 'None'})" + ) proc = await asyncio.create_subprocess_exec( - "git", "pull", - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + "git", "pull", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) stdout, stderr = await proc.communicate() output = stdout.decode().strip() + "\n" + stderr.decode().strip() if "unstaged changes" in output or "Please commit your changes" in output: - await ctx.send("Unstaged changes detected. Committing changes before pulling...") + await ctx.send( + "Unstaged changes detected. Committing changes before pulling..." + ) commit_proc = await asyncio.create_subprocess_exec( - "git", "commit", "-am", "Git pull and reload command", + "git", + "commit", + "-am", + "Git pull and reload command", stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stderr=asyncio.subprocess.PIPE, ) commit_stdout, commit_stderr = await commit_proc.communicate() - commit_output = commit_stdout.decode().strip() + "\n" + commit_stderr.decode().strip() - await ctx.send(f"Committed changes:\n```\n{commit_output}\n```Trying git pull again...") + commit_output = ( + commit_stdout.decode().strip() + "\n" + commit_stderr.decode().strip() + ) + await ctx.send( + f"Committed changes:\n```\n{commit_output}\n```Trying git pull again..." + ) proc = await asyncio.create_subprocess_exec( - "git", "pull", + "git", + "pull", stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stderr=asyncio.subprocess.PIPE, ) stdout, stderr = await proc.communicate() output = stdout.decode().strip() + "\n" + stderr.decode().strip() @@ -583,21 +721,25 @@ async def gitpull_reload(ctx): else: await ctx.send(f"Git pull failed:\n```\n{output}\n```") + bot.add_command(gitpull_reload) - - # The unified API service is now handled by run_unified_api.py -async def main(args): # Pass parsed args + +async def main(args): # Pass parsed args """Main async function to load cogs and start the bot.""" - TOKEN = os.getenv('DISCORD_TOKEN') + TOKEN = os.getenv("DISCORD_TOKEN") if not TOKEN: - raise ValueError("No token found. Make sure to set DISCORD_TOKEN in your .env file.") + raise ValueError( + "No token found. Make sure to set DISCORD_TOKEN in your .env file." + ) # Start Flask server as a separate process - flask_process = subprocess.Popen([sys.executable, "flask_server.py"], cwd=os.path.dirname(__file__)) + flask_process = subprocess.Popen( + [sys.executable, "flask_server.py"], cwd=os.path.dirname(__file__) + ) # Start the unified API service in a separate thread if available api_thread = None @@ -615,14 +757,18 @@ async def main(args): # Pass parsed args try: print("Starting markdown server for TOS and Privacy Policy...") markdown_thread = start_markdown_server_in_thread(host="0.0.0.0", port=5006) - print("Markdown server started successfully. TOS available at: http://localhost:5006/tos") + print( + "Markdown server started successfully. TOS available at: http://localhost:5006/tos" + ) except Exception as e: print(f"Failed to start markdown server: {e}") # Configure OAuth settings from environment variables oauth_host = os.getenv("OAUTH_HOST", "0.0.0.0") oauth_port = int(os.getenv("OAUTH_PORT", "8080")) - oauth_redirect_uri = os.getenv("DISCORD_REDIRECT_URI", f"http://{oauth_host}:{oauth_port}/oauth/callback") + oauth_redirect_uri = os.getenv( + "DISCORD_REDIRECT_URI", f"http://{oauth_host}:{oauth_port}/oauth/callback" + ) # Update the OAuth redirect URI in the environment os.environ["DISCORD_REDIRECT_URI"] = oauth_redirect_uri @@ -641,9 +787,9 @@ async def main(args): # Pass parsed args # This is now done on the bot instance directly in the MyBot class bot.ai_cogs_to_skip = ai_cogs_to_skip else: - bot.ai_cogs_to_skip = [] # Ensure it exists even if empty + bot.ai_cogs_to_skip = [] # Ensure it exists even if empty - set_bot_instance(bot) # Set the global bot instance + set_bot_instance(bot) # Set the global bot instance log.info(f"Global bot instance set in global_bot_accessor. Bot ID: {id(bot)}") # Pool initialization and cog loading are now handled in MyBot.setup_hook() @@ -655,7 +801,9 @@ async def main(args): # Pass parsed args log.exception(f"An error occurred during bot.start(): {e}") finally: # Terminate the Flask server process when the bot stops - if flask_process and flask_process.poll() is None: # Check if process exists and is running + if ( + flask_process and flask_process.poll() is None + ): # Check if process exists and is running flask_process.terminate() log.info("Flask server process terminated.") else: @@ -675,21 +823,25 @@ async def main(args): # Pass parsed args log.info("Closing Redis pool in main finally block...") await bot.redis.close() if not bot.pg_pool and not bot.redis: - log.info("Pools were not initialized or already closed, skipping close_pools in main.") + log.info( + "Pools were not initialized or already closed, skipping close_pools in main." + ) + # Run the main async function import signal + def handle_sighup(signum, frame): import subprocess import sys import os + try: - print("Received SIGHUP: pulling latest code from /home/git/git (branch master)...") - result = subprocess.run( - ["git", "pull"], - capture_output=True, text=True + print( + "Received SIGHUP: pulling latest code from /home/git/git (branch master)..." ) + result = subprocess.run(["git", "pull"], capture_output=True, text=True) print(result.stdout) print(result.stderr) print("Restarting process after SIGHUP...") @@ -697,7 +849,8 @@ def handle_sighup(signum, frame): print(f"Error during SIGHUP git pull: {e}") os.execv(sys.executable, [sys.executable] + sys.argv) -if __name__ == '__main__': + +if __name__ == "__main__": # Write PID to .pid file for git hook use try: with open(".pid", "w") as f: @@ -716,13 +869,13 @@ if __name__ == '__main__': parser.add_argument( "--disable-ai", action="store_true", - help="Disable AI-related cogs and functionality." + help="Disable AI-related cogs and functionality.", ) args = parser.parse_args() # ------------------------ try: - asyncio.run(main(args)) # Pass parsed args to main + asyncio.run(main(args)) # Pass parsed args to main except KeyboardInterrupt: log.info("Bot stopped by user.") except Exception as e: diff --git a/multi_bot.py b/multi_bot.py index dce5df2..ce9bb09 100644 --- a/multi_bot.py +++ b/multi_bot.py @@ -15,8 +15,12 @@ load_dotenv() # File paths CONFIG_FILE = "data/multi_bot_config.json" -HISTORY_FILE_TEMPLATE = "ai_conversation_history_{}.json" # Will be formatted with bot_id -USER_SETTINGS_FILE_TEMPLATE = "ai_user_settings_{}.json" # Will be formatted with bot_id +HISTORY_FILE_TEMPLATE = ( + "ai_conversation_history_{}.json" # Will be formatted with bot_id +) +USER_SETTINGS_FILE_TEMPLATE = ( + "ai_user_settings_{}.json" # Will be formatted with bot_id +) # Default configuration DEFAULT_CONFIG = { @@ -31,7 +35,7 @@ DEFAULT_CONFIG = { "temperature": 0.7, "timeout": 60, "status_type": "listening", - "status_text": "$ai" + "status_text": "$ai", }, { "id": "miku", @@ -43,12 +47,12 @@ DEFAULT_CONFIG = { "temperature": 0.7, "timeout": 60, "status_type": "listening", - "status_text": ".ai" - } + "status_text": ".ai", + }, ], "api_key": "", # Will be set from environment variable or user input "api_url": "https://openrouter.ai/api/v1/chat/completions", - "compatibility_mode": "openai" + "compatibility_mode": "openai", } # Global variables to store bot instances and their conversation histories @@ -56,6 +60,7 @@ bots = {} conversation_histories = {} user_settings = {} + def load_config(): """Load configuration from file or create default if not exists""" if os.path.exists(CONFIG_FILE): @@ -90,6 +95,7 @@ def load_config(): save_config(config) return config + def save_config(config): """Save configuration to file""" try: @@ -101,6 +107,7 @@ def save_config(config): except Exception as e: print(f"Error saving config: {e}") + def load_conversation_history(bot_id): """Load conversation history for a specific bot""" history_file = HISTORY_FILE_TEMPLATE.format(bot_id) @@ -112,12 +119,15 @@ def load_conversation_history(bot_id): # Convert string keys (from JSON) back to integers data = json.load(f) history = {int(k): v for k, v in data.items()} - print(f"Loaded conversation history for {len(history)} users for bot {bot_id}") + print( + f"Loaded conversation history for {len(history)} users for bot {bot_id}" + ) except Exception as e: print(f"Error loading conversation history for bot {bot_id}: {e}") return history + def save_conversation_history(bot_id, history): """Save conversation history for a specific bot""" history_file = HISTORY_FILE_TEMPLATE.format(bot_id) @@ -130,6 +140,7 @@ def save_conversation_history(bot_id, history): except Exception as e: print(f"Error saving conversation history for bot {bot_id}: {e}") + def load_user_settings(bot_id): """Load user settings for a specific bot""" settings_file = USER_SETTINGS_FILE_TEMPLATE.format(bot_id) @@ -147,6 +158,7 @@ def load_user_settings(bot_id): return settings + def save_user_settings(bot_id, settings): """Save user settings for a specific bot""" settings_file = USER_SETTINGS_FILE_TEMPLATE.format(bot_id) @@ -159,6 +171,7 @@ def save_user_settings(bot_id, settings): except Exception as e: print(f"Error saving user settings for bot {bot_id}: {e}") + def get_user_settings(bot_id, user_id, bot_config): """Get settings for a user with defaults from bot config""" bot_settings = user_settings.get(bot_id, {}) @@ -170,16 +183,20 @@ def get_user_settings(bot_id, user_id, bot_config): settings = bot_settings[user_id] return { "model": settings.get("model", bot_config.get("model", "gpt-3.5-turbo:free")), - "system_prompt": settings.get("system_prompt", bot_config.get("system_prompt", "You are a helpful assistant.")), + "system_prompt": settings.get( + "system_prompt", + bot_config.get("system_prompt", "You are a helpful assistant."), + ), "max_tokens": settings.get("max_tokens", bot_config.get("max_tokens", 1000)), "temperature": settings.get("temperature", bot_config.get("temperature", 0.7)), "timeout": settings.get("timeout", bot_config.get("timeout", 60)), "custom_instructions": settings.get("custom_instructions", ""), "character_info": settings.get("character_info", ""), "character_breakdown": settings.get("character_breakdown", False), - "character": settings.get("character", "") + "character": settings.get("character", ""), } + class SimplifiedAICog(commands.Cog): def __init__(self, bot, bot_id, bot_config, global_config): self.bot = bot @@ -205,14 +222,20 @@ class SimplifiedAICog(commands.Cog): await self.session.close() # Save conversation history and user settings when unloading - save_conversation_history(self.bot_id, conversation_histories.get(self.bot_id, {})) + save_conversation_history( + self.bot_id, conversation_histories.get(self.bot_id, {}) + ) save_user_settings(self.bot_id, user_settings.get(self.bot_id, {})) async def _get_ai_response(self, user_id, prompt, system_prompt=None): """Get a response from the AI API""" api_key = self.global_config.get("api_key", "") - api_url = self.global_config.get("api_url", "https://api.openai.com/v1/chat/completions") - compatibility_mode = self.global_config.get("compatibility_mode", "openai").lower() + api_url = self.global_config.get( + "api_url", "https://api.openai.com/v1/chat/completions" + ) + compatibility_mode = self.global_config.get( + "compatibility_mode", "openai" + ).lower() if not api_key: return "Error: AI API key not configured. Please set the API key in the configuration." @@ -236,38 +259,52 @@ class SimplifiedAICog(commands.Cog): # Replace {{char}} with the character value if provided if settings["character"]: - base_system_prompt = base_system_prompt.replace("{{char}}", settings["character"]) + base_system_prompt = base_system_prompt.replace( + "{{char}}", settings["character"] + ) final_system_prompt = base_system_prompt # Check if any custom settings are provided - has_custom_settings = settings["custom_instructions"] or settings["character_info"] or settings["character_breakdown"] + has_custom_settings = ( + settings["custom_instructions"] + or settings["character_info"] + or settings["character_breakdown"] + ) if has_custom_settings: # Start with the base system prompt custom_prompt_parts = [base_system_prompt] # Add the custom instructions header - custom_prompt_parts.append("\nThe user has provided additional information for you. Please follow their instructions exactly. If anything below contradicts the system prompt above, please take priority over the user's intstructions.") + custom_prompt_parts.append( + "\nThe user has provided additional information for you. Please follow their instructions exactly. If anything below contradicts the system prompt above, please take priority over the user's intstructions." + ) # Add custom instructions if provided if settings["custom_instructions"]: - custom_prompt_parts.append("\n- Custom instructions from the user (prioritize these)\n\n" + settings["custom_instructions"]) + custom_prompt_parts.append( + "\n- Custom instructions from the user (prioritize these)\n\n" + + settings["custom_instructions"] + ) # Add character info if provided if settings["character_info"]: - custom_prompt_parts.append("\n- Additional info about the character you are roleplaying (ignore if the system prompt doesn't indicate roleplaying)\n\n" + settings["character_info"]) + custom_prompt_parts.append( + "\n- Additional info about the character you are roleplaying (ignore if the system prompt doesn't indicate roleplaying)\n\n" + + settings["character_info"] + ) # Add character breakdown flag if set if settings["character_breakdown"]: - custom_prompt_parts.append("\n- The user would like you to provide a breakdown of the character you're roleplaying in your first response. (ignore if the system prompt doesn't indicate roleplaying)") + custom_prompt_parts.append( + "\n- The user would like you to provide a breakdown of the character you're roleplaying in your first response. (ignore if the system prompt doesn't indicate roleplaying)" + ) # Combine all parts into the final system prompt final_system_prompt = "\n".join(custom_prompt_parts) - messages = [ - {"role": "system", "content": final_system_prompt} - ] + messages = [{"role": "system", "content": final_system_prompt}] # Add conversation history (up to last 10 messages to avoid token limits) messages.extend(bot_history[user_id][-10:]) @@ -289,20 +326,17 @@ class SimplifiedAICog(commands.Cog): "messages": messages, "max_tokens": settings["max_tokens"], "temperature": settings["temperature"], - "stream": False + "stream": False, } headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}" + "Authorization": f"Bearer {api_key}", } try: async with self.session.post( - api_url, - headers=headers, - json=payload, - timeout=settings["timeout"] + api_url, headers=headers, json=payload, timeout=settings["timeout"] ) as response: if response.status != 200: error_text = await response.text() @@ -334,51 +368,82 @@ class SimplifiedAICog(commands.Cog): ai_response = data["choices"][0]["message"]["content"] # Check for safety cutoff in OpenAI format - if "finish_reason" in data["choices"][0] and data["choices"][0]["finish_reason"] == "content_filter": + if ( + "finish_reason" in data["choices"][0] + and data["choices"][0]["finish_reason"] == "content_filter" + ): safety_cutoff = True # Check for native_finish_reason: SAFETY - if "native_finish_reason" in data["choices"][0] and data["choices"][0]["native_finish_reason"] == "SAFETY": + if ( + "native_finish_reason" in data["choices"][0] + and data["choices"][0]["native_finish_reason"] == "SAFETY" + ): safety_cutoff = True else: # Custom format - try different response structures # Try standard OpenAI format first - if "choices" in data and data["choices"] and "message" in data["choices"][0]: + if ( + "choices" in data + and data["choices"] + and "message" in data["choices"][0] + ): ai_response = data["choices"][0]["message"]["content"] # Check for safety cutoff in OpenAI format - if "finish_reason" in data["choices"][0] and data["choices"][0]["finish_reason"] == "content_filter": + if ( + "finish_reason" in data["choices"][0] + and data["choices"][0]["finish_reason"] == "content_filter" + ): safety_cutoff = True # Check for native_finish_reason: SAFETY - if "native_finish_reason" in data["choices"][0] and data["choices"][0]["native_finish_reason"] == "SAFETY": + if ( + "native_finish_reason" in data["choices"][0] + and data["choices"][0]["native_finish_reason"] == "SAFETY" + ): safety_cutoff = True # Try Ollama/LM Studio format elif "response" in data: ai_response = data["response"] # Check for safety cutoff in response metadata - if "native_finish_reason" in data and data["native_finish_reason"] == "SAFETY": + if ( + "native_finish_reason" in data + and data["native_finish_reason"] == "SAFETY" + ): safety_cutoff = True # Try text-only format elif "text" in data: ai_response = data["text"] # Check for safety cutoff in response metadata - if "native_finish_reason" in data and data["native_finish_reason"] == "SAFETY": + if ( + "native_finish_reason" in data + and data["native_finish_reason"] == "SAFETY" + ): safety_cutoff = True # Try content-only format elif "content" in data: ai_response = data["content"] # Check for safety cutoff in response metadata - if "native_finish_reason" in data and data["native_finish_reason"] == "SAFETY": + if ( + "native_finish_reason" in data + and data["native_finish_reason"] == "SAFETY" + ): safety_cutoff = True # Try output format elif "output" in data: ai_response = data["output"] # Check for safety cutoff in response metadata - if "native_finish_reason" in data and data["native_finish_reason"] == "SAFETY": + if ( + "native_finish_reason" in data + and data["native_finish_reason"] == "SAFETY" + ): safety_cutoff = True # Try result format elif "result" in data: ai_response = data["result"] # Check for safety cutoff in response metadata - if "native_finish_reason" in data and data["native_finish_reason"] == "SAFETY": + if ( + "native_finish_reason" in data + and data["native_finish_reason"] == "SAFETY" + ): safety_cutoff = True else: # If we can't find a known format, return the raw response for debugging @@ -391,22 +456,28 @@ class SimplifiedAICog(commands.Cog): # Add safety cutoff note if needed if safety_cutoff: - ai_response = f"{ai_response}\n\nThe response was cut off for safety reasons." + ai_response = ( + f"{ai_response}\n\nThe response was cut off for safety reasons." + ) # Update conversation history bot_history[user_id].append({"role": "user", "content": prompt}) - bot_history[user_id].append({"role": "assistant", "content": ai_response}) + bot_history[user_id].append( + {"role": "assistant", "content": ai_response} + ) # Save conversation history to file save_conversation_history(self.bot_id, bot_history) # Save the response to a backup file try: - os.makedirs('ai_responses', exist_ok=True) - backup_file = f'ai_responses/response_{user_id}_{int(datetime.datetime.now().timestamp())}.txt' - with open(backup_file, 'w', encoding='utf-8') as f: + os.makedirs("ai_responses", exist_ok=True) + backup_file = f"ai_responses/response_{user_id}_{int(datetime.datetime.now().timestamp())}.txt" + with open(backup_file, "w", encoding="utf-8") as f: f.write(ai_response) - print(f"AI response backed up to {backup_file} for bot {self.bot_id}") + print( + f"AI response backed up to {backup_file} for bot {self.bot_id}" + ) except Exception as e: print(f"Failed to backup AI response for bot {self.bot_id}: {e}") @@ -416,9 +487,12 @@ class SimplifiedAICog(commands.Cog): return "Error: Request to AI API timed out. Please try again later." except Exception as e: error_message = f"Error communicating with AI API: {str(e)}" - print(f"Exception in _get_ai_response for bot {self.bot_id}: {error_message}") + print( + f"Exception in _get_ai_response for bot {self.bot_id}: {error_message}" + ) print(f"Exception type: {type(e).__name__}") import traceback + traceback.print_exc() return error_message @@ -433,28 +507,34 @@ class SimplifiedAICog(commands.Cog): response = await self._get_ai_response(user_id, prompt) # Check if the response is too long before trying to send it - if len(response) > 1900: # Discord's limit for regular messages is 2000, use 1900 to be safe + if ( + len(response) > 1900 + ): # Discord's limit for regular messages is 2000, use 1900 to be safe try: # Create a text file with the content - with open(f'ai_response_{self.bot_id}.txt', 'w', encoding='utf-8') as f: + with open(f"ai_response_{self.bot_id}.txt", "w", encoding="utf-8") as f: f.write(response) # Send the file instead await ctx.send( "The AI response was too long. Here's the content as a file:", - file=discord.File(f'ai_response_{self.bot_id}.txt') + file=discord.File(f"ai_response_{self.bot_id}.txt"), ) return # Return early to avoid trying to send the message except Exception as e: print(f"Error sending AI response as file for bot {self.bot_id}: {e}") # If sending as a file fails, try splitting the message - chunks = [response[i:i+1900] for i in range(0, len(response), 1900)] - await ctx.send(f"The AI response was too long. Splitting into {len(chunks)} parts:") + chunks = [response[i : i + 1900] for i in range(0, len(response), 1900)] + await ctx.send( + f"The AI response was too long. Splitting into {len(chunks)} parts:" + ) for i, chunk in enumerate(chunks): try: await ctx.send(f"Part {i+1}/{len(chunks)}:\n{chunk}") except Exception as chunk_error: - print(f"Error sending chunk {i+1} for bot {self.bot_id}: {chunk_error}") + print( + f"Error sending chunk {i+1} for bot {self.bot_id}: {chunk_error}" + ) return # Return early after sending chunks # Send the response normally @@ -462,27 +542,39 @@ class SimplifiedAICog(commands.Cog): await ctx.reply(response) except discord.HTTPException as e: print(f"HTTP Exception when sending AI response for bot {self.bot_id}: {e}") - if "Must be 4000 or fewer in length" in str(e) or "Must be 2000 or fewer in length" in str(e): + if "Must be 4000 or fewer in length" in str( + e + ) or "Must be 2000 or fewer in length" in str(e): try: # Create a text file with the content - with open(f'ai_response_{self.bot_id}.txt', 'w', encoding='utf-8') as f: + with open( + f"ai_response_{self.bot_id}.txt", "w", encoding="utf-8" + ) as f: f.write(response) # Send the file instead await ctx.send( "The AI response was too long. Here's the content as a file:", - file=discord.File(f'ai_response_{self.bot_id}.txt') + file=discord.File(f"ai_response_{self.bot_id}.txt"), ) except Exception as file_error: - print(f"Error sending AI response as file (fallback) for bot {self.bot_id}: {file_error}") + print( + f"Error sending AI response as file (fallback) for bot {self.bot_id}: {file_error}" + ) # If sending as a file fails, try splitting the message - chunks = [response[i:i+1900] for i in range(0, len(response), 1900)] - await ctx.send(f"The AI response was too long. Splitting into {len(chunks)} parts:") + chunks = [ + response[i : i + 1900] for i in range(0, len(response), 1900) + ] + await ctx.send( + f"The AI response was too long. Splitting into {len(chunks)} parts:" + ) for i, chunk in enumerate(chunks): try: await ctx.send(f"Part {i+1}/{len(chunks)}:\n{chunk}") except Exception as chunk_error: - print(f"Error sending chunk {i+1} for bot {self.bot_id}: {chunk_error}") + print( + f"Error sending chunk {i+1} for bot {self.bot_id}: {chunk_error}" + ) else: # Log the error but don't re-raise to prevent the command from failing completely print(f"Unexpected HTTP error in AI command for bot {self.bot_id}: {e}") @@ -530,7 +622,9 @@ class SimplifiedAICog(commands.Cog): if setting == "model": # Validate model contains ":free" if ":free" not in value: - response = f"Error: Model name must contain `:free`. Setting not updated." + response = ( + f"Error: Model name must contain `:free`. Setting not updated." + ) else: bot_user_settings[user_id]["model"] = value response = f"Your AI model has been set to: `{value}`" @@ -599,64 +693,94 @@ class SimplifiedAICog(commands.Cog): response = f"Unknown setting: `{setting}`. Available settings: model, system_prompt, max_tokens, temperature, timeout, custom_instructions, character_info, character_breakdown, character" # Save settings to file if we made changes - if response and not response.startswith("Error") and not response.startswith("Unknown"): + if ( + response + and not response.startswith("Error") + and not response.startswith("Unknown") + ): user_settings[self.bot_id] = bot_user_settings save_user_settings(self.bot_id, bot_user_settings) # Check if the response is too long before trying to send it - if len(response) > 1900: # Discord's limit for regular messages is 2000, use 1900 to be safe + if ( + len(response) > 1900 + ): # Discord's limit for regular messages is 2000, use 1900 to be safe try: # Create a text file with the content - with open(f'ai_set_response_{self.bot_id}.txt', 'w', encoding='utf-8') as f: + with open( + f"ai_set_response_{self.bot_id}.txt", "w", encoding="utf-8" + ) as f: f.write(response) # Send the file instead await ctx.send( "The response is too long to display in a message. Here's the content as a file:", - file=discord.File(f'ai_set_response_{self.bot_id}.txt') + file=discord.File(f"ai_set_response_{self.bot_id}.txt"), ) return # Return early to avoid trying to send the message except Exception as e: - print(f"Error sending AI set response as file for bot {self.bot_id}: {e}") + print( + f"Error sending AI set response as file for bot {self.bot_id}: {e}" + ) # If sending as a file fails, try splitting the message - chunks = [response[i:i+1900] for i in range(0, len(response), 1900)] - await ctx.send(f"The response is too long to display in a single message. Splitting into {len(chunks)} parts:") + chunks = [response[i : i + 1900] for i in range(0, len(response), 1900)] + await ctx.send( + f"The response is too long to display in a single message. Splitting into {len(chunks)} parts:" + ) for i, chunk in enumerate(chunks): try: await ctx.send(f"Part {i+1}/{len(chunks)}:\n{chunk}") except Exception as chunk_error: - print(f"Error sending chunk {i+1} for bot {self.bot_id}: {chunk_error}") + print( + f"Error sending chunk {i+1} for bot {self.bot_id}: {chunk_error}" + ) return # Return early after sending chunks # Send the response normally try: await ctx.reply(response) except discord.HTTPException as e: - print(f"HTTP Exception when sending AI set response for bot {self.bot_id}: {e}") - if "Must be 4000 or fewer in length" in str(e) or "Must be 2000 or fewer in length" in str(e): + print( + f"HTTP Exception when sending AI set response for bot {self.bot_id}: {e}" + ) + if "Must be 4000 or fewer in length" in str( + e + ) or "Must be 2000 or fewer in length" in str(e): try: # Create a text file with the content - with open(f'ai_set_response_{self.bot_id}.txt', 'w', encoding='utf-8') as f: + with open( + f"ai_set_response_{self.bot_id}.txt", "w", encoding="utf-8" + ) as f: f.write(response) # Send the file instead await ctx.send( "The response is too long to display in a message. Here's the content as a file:", - file=discord.File(f'ai_set_response_{self.bot_id}.txt') + file=discord.File(f"ai_set_response_{self.bot_id}.txt"), ) except Exception as file_error: - print(f"Error sending AI set response as file (fallback) for bot {self.bot_id}: {file_error}") + print( + f"Error sending AI set response as file (fallback) for bot {self.bot_id}: {file_error}" + ) # If sending as a file fails, try splitting the message - chunks = [response[i:i+1900] for i in range(0, len(response), 1900)] - await ctx.send(f"The response is too long to display in a single message. Splitting into {len(chunks)} parts:") + chunks = [ + response[i : i + 1900] for i in range(0, len(response), 1900) + ] + await ctx.send( + f"The response is too long to display in a single message. Splitting into {len(chunks)} parts:" + ) for i, chunk in enumerate(chunks): try: await ctx.send(f"Part {i+1}/{len(chunks)}:\n{chunk}") except Exception as chunk_error: - print(f"Error sending chunk {i+1} for bot {self.bot_id}: {chunk_error}") + print( + f"Error sending chunk {i+1} for bot {self.bot_id}: {chunk_error}" + ) else: # Log the error but don't re-raise to prevent the command from failing completely - print(f"Unexpected HTTP error in aiset command for bot {self.bot_id}: {e}") + print( + f"Unexpected HTTP error in aiset command for bot {self.bot_id}: {e}" + ) @commands.command(name="aireset") async def reset_user_settings(self, ctx): @@ -689,48 +813,62 @@ class SimplifiedAICog(commands.Cog): ] # Add custom settings if they exist - if settings['custom_instructions']: - settings_info.append(f"\nCustom Instructions: `{settings['custom_instructions'][:50]}{'...' if len(settings['custom_instructions']) > 50 else ''}`") - if settings['character_info']: - settings_info.append(f"Character Info: `{settings['character_info'][:50]}{'...' if len(settings['character_info']) > 50 else ''}`") - if settings['character_breakdown']: + if settings["custom_instructions"]: + settings_info.append( + f"\nCustom Instructions: `{settings['custom_instructions'][:50]}{'...' if len(settings['custom_instructions']) > 50 else ''}`" + ) + if settings["character_info"]: + settings_info.append( + f"Character Info: `{settings['character_info'][:50]}{'...' if len(settings['character_info']) > 50 else ''}`" + ) + if settings["character_breakdown"]: settings_info.append(f"Character Breakdown: `Enabled`") - if settings['character']: - settings_info.append(f"Character: `{settings['character']}` (replaces {{{{char}}}} in system prompt)") + if settings["character"]: + settings_info.append( + f"Character: `{settings['character']}` (replaces {{{{char}}}} in system prompt)" + ) # Add note about custom vs default settings if user_id in bot_user_settings: custom_settings = list(bot_user_settings[user_id].keys()) if custom_settings: - settings_info.append(f"\n*Custom settings: {', '.join(custom_settings)}*") + settings_info.append( + f"\n*Custom settings: {', '.join(custom_settings)}*" + ) else: settings_info.append("\n*All settings are at default values*") response = "\n".join(settings_info) # Check if the response is too long before trying to send it - if len(response) > 1900: # Discord's limit for regular messages is 2000, use 1900 to be safe + if ( + len(response) > 1900 + ): # Discord's limit for regular messages is 2000, use 1900 to be safe try: # Create a text file with the content - with open(f'ai_settings_{self.bot_id}.txt', 'w', encoding='utf-8') as f: + with open(f"ai_settings_{self.bot_id}.txt", "w", encoding="utf-8") as f: f.write(response) # Send the file instead await ctx.send( "Your AI settings are too detailed to display in a message. Here's the content as a file:", - file=discord.File(f'ai_settings_{self.bot_id}.txt') + file=discord.File(f"ai_settings_{self.bot_id}.txt"), ) return # Return early to avoid trying to send the message except Exception as e: print(f"Error sending AI settings as file for bot {self.bot_id}: {e}") # If sending as a file fails, try splitting the message - chunks = [response[i:i+1900] for i in range(0, len(response), 1900)] - await ctx.send(f"Your AI settings are too detailed to display in a single message. Splitting into {len(chunks)} parts:") + chunks = [response[i : i + 1900] for i in range(0, len(response), 1900)] + await ctx.send( + f"Your AI settings are too detailed to display in a single message. Splitting into {len(chunks)} parts:" + ) for i, chunk in enumerate(chunks): try: await ctx.send(f"Part {i+1}/{len(chunks)}:\n{chunk}") except Exception as chunk_error: - print(f"Error sending chunk {i+1} for bot {self.bot_id}: {chunk_error}") + print( + f"Error sending chunk {i+1} for bot {self.bot_id}: {chunk_error}" + ) return # Return early after sending chunks # Send the response normally @@ -738,30 +876,44 @@ class SimplifiedAICog(commands.Cog): await ctx.reply(response) except discord.HTTPException as e: print(f"HTTP Exception when sending AI settings for bot {self.bot_id}: {e}") - if "Must be 4000 or fewer in length" in str(e) or "Must be 2000 or fewer in length" in str(e): + if "Must be 4000 or fewer in length" in str( + e + ) or "Must be 2000 or fewer in length" in str(e): try: # Create a text file with the content - with open(f'ai_settings_{self.bot_id}.txt', 'w', encoding='utf-8') as f: + with open( + f"ai_settings_{self.bot_id}.txt", "w", encoding="utf-8" + ) as f: f.write(response) # Send the file instead await ctx.send( "Your AI settings are too detailed to display in a message. Here's the content as a file:", - file=discord.File(f'ai_settings_{self.bot_id}.txt') + file=discord.File(f"ai_settings_{self.bot_id}.txt"), ) except Exception as file_error: - print(f"Error sending AI settings as file (fallback) for bot {self.bot_id}: {file_error}") + print( + f"Error sending AI settings as file (fallback) for bot {self.bot_id}: {file_error}" + ) # If sending as a file fails, try splitting the message - chunks = [response[i:i+1900] for i in range(0, len(response), 1900)] - await ctx.send(f"Your AI settings are too detailed to display in a single message. Splitting into {len(chunks)} parts:") + chunks = [ + response[i : i + 1900] for i in range(0, len(response), 1900) + ] + await ctx.send( + f"Your AI settings are too detailed to display in a single message. Splitting into {len(chunks)} parts:" + ) for i, chunk in enumerate(chunks): try: await ctx.send(f"Part {i+1}/{len(chunks)}:\n{chunk}") except Exception as chunk_error: - print(f"Error sending chunk {i+1} for bot {self.bot_id}: {chunk_error}") + print( + f"Error sending chunk {i+1} for bot {self.bot_id}: {chunk_error}" + ) else: # Log the error but don't re-raise to prevent the command from failing completely - print(f"Unexpected HTTP error in aisettings command for bot {self.bot_id}: {e}") + print( + f"Unexpected HTTP error in aisettings command for bot {self.bot_id}: {e}" + ) @commands.command(name="ailast") async def get_last_response(self, ctx): @@ -769,13 +921,15 @@ class SimplifiedAICog(commands.Cog): user_id = ctx.author.id # Check if there's a backup file for this user - backup_dir = 'ai_responses' + backup_dir = "ai_responses" if not os.path.exists(backup_dir): await ctx.reply("No backup responses found.") return # Find the most recent backup file for this user - user_files = [f for f in os.listdir(backup_dir) if f.startswith(f'response_{user_id}_')] + user_files = [ + f for f in os.listdir(backup_dir) if f.startswith(f"response_{user_id}_") + ] if not user_files: await ctx.reply("No backup responses found for you.") return @@ -786,16 +940,18 @@ class SimplifiedAICog(commands.Cog): try: # Read the file content - with open(latest_file, 'r', encoding='utf-8') as f: + with open(latest_file, "r", encoding="utf-8") as f: content = f.read() # Send as file to avoid length issues - with open(f'ai_last_response_{self.bot_id}.txt', 'w', encoding='utf-8') as f: + with open( + f"ai_last_response_{self.bot_id}.txt", "w", encoding="utf-8" + ) as f: f.write(content) await ctx.send( f"Here's your last AI response (from {user_files[0].split('_')[-1].replace('.txt', '')}):", - file=discord.File(f'ai_last_response_{self.bot_id}.txt') + file=discord.File(f"ai_last_response_{self.bot_id}.txt"), ) except Exception as e: await ctx.reply(f"Error retrieving last response: {e}") @@ -825,6 +981,7 @@ class SimplifiedAICog(commands.Cog): ) await ctx.reply(help_text) + async def setup_bot(bot_id, bot_config, global_config): """Set up and start a bot with the given configuration""" # Set up intents @@ -836,28 +993,29 @@ async def setup_bot(bot_id, bot_config, global_config): @bot.event async def on_ready(): - print(f'{bot.user.name} (ID: {bot_id}) has connected to Discord!') - print(f'Bot ID: {bot.user.id}') + print(f"{bot.user.name} (ID: {bot_id}) has connected to Discord!") + print(f"Bot ID: {bot.user.id}") # Set the bot's status based on configuration - status_type = bot_config.get('status_type', 'listening').lower() - status_text = bot_config.get('status_text', f"{bot_config.get('prefix', '!')}ai") + status_type = bot_config.get("status_type", "listening").lower() + status_text = bot_config.get( + "status_text", f"{bot_config.get('prefix', '!')}ai" + ) # Map status type to discord.ActivityType activity_type = discord.ActivityType.listening # Default - if status_type == 'playing': + if status_type == "playing": activity_type = discord.ActivityType.playing - elif status_type == 'watching': + elif status_type == "watching": activity_type = discord.ActivityType.watching - elif status_type == 'streaming': + elif status_type == "streaming": activity_type = discord.ActivityType.streaming - elif status_type == 'competing': + elif status_type == "competing": activity_type = discord.ActivityType.competing # Set the presence - await bot.change_presence(activity=discord.Activity( - type=activity_type, - name=status_text - )) + await bot.change_presence( + activity=discord.Activity(type=activity_type, name=status_text) + ) print(f"Bot {bot_id} status set to '{status_type.capitalize()} {status_text}'") # Add the AI cog @@ -869,6 +1027,7 @@ async def setup_bot(bot_id, bot_config, global_config): # Return the bot instance return bot + async def start_bot(bot_id): """Start a bot with the given ID""" if bot_id not in bots: @@ -902,8 +1061,10 @@ async def start_bot(bot_id): print(f"Error starting bot {bot_id}: {e}") return False + def run_bot_in_thread(bot_id): """Run a bot in a separate thread""" + async def _run_bot(): config = load_config() @@ -934,10 +1095,13 @@ def run_bot_in_thread(bot_id): # Create and start the thread loop = asyncio.new_event_loop() - thread = threading.Thread(target=lambda: loop.run_until_complete(_run_bot()), daemon=True) + thread = threading.Thread( + target=lambda: loop.run_until_complete(_run_bot()), daemon=True + ) thread.start() return thread + def start_all_bots(): """Start all configured bots in separate threads""" config = load_config() @@ -952,6 +1116,7 @@ def start_all_bots(): return threads + if __name__ == "__main__": # If run directly, start all bots bot_threads = start_all_bots() @@ -969,6 +1134,7 @@ if __name__ == "__main__": # Sleep to avoid high CPU usage import time + time.sleep(60) except KeyboardInterrupt: print("Stopping all bots...") diff --git a/neru_bot.py b/neru_bot.py index 3ac5aec..588fb4f 100644 --- a/neru_bot.py +++ b/neru_bot.py @@ -16,7 +16,8 @@ load_dotenv() # --- Constants --- DEFAULT_PREFIX = "!" -CORE_COGS = {'SettingsCog', 'HelpCog'} # Cogs that cannot be disabled +CORE_COGS = {"SettingsCog", "HelpCog"} # Cogs that cannot be disabled + # --- Dynamic Prefix Function --- async def get_prefix(bot_instance, message): @@ -29,23 +30,25 @@ async def get_prefix(bot_instance, message): prefix = await settings_manager.get_guild_prefix(message.guild.id, DEFAULT_PREFIX) return prefix + # --- Bot Setup --- # Set up intents (permissions) intents = discord.Intents.default() intents.message_content = True intents.members = True + # --- Custom Bot Class with setup_hook for async initialization --- class NeruBot(commands.Bot): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - owner_id = os.getenv('OWNER_USER_ID') + owner_id = os.getenv("OWNER_USER_ID") if owner_id: self.owner_id = int(owner_id) self.core_cogs = CORE_COGS # Attach core cogs list to bot instance self.settings_manager = settings_manager # Attach settings manager instance self.pg_pool = None # Will be initialized in setup_hook - self.redis = None # Will be initialized in setup_hook + self.redis = None # Will be initialized in setup_hook self.ai_cogs_to_skip = [] # For --disable-ai flag async def setup_hook(self): @@ -56,18 +59,18 @@ class NeruBot(commands.Bot): try: # PostgreSQL connection pool self.pg_pool = await asyncpg.create_pool( - user=os.getenv('POSTGRES_USER'), - password=os.getenv('POSTGRES_PASSWORD'), - host=os.getenv('POSTGRES_HOST'), - database=os.getenv('POSTGRES_SETTINGS_DB') + user=os.getenv("POSTGRES_USER"), + password=os.getenv("POSTGRES_PASSWORD"), + host=os.getenv("POSTGRES_HOST"), + database=os.getenv("POSTGRES_SETTINGS_DB"), ) log.info("PostgreSQL connection pool initialized") # Redis connection self.redis = await aioredis.from_url( f"redis://{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT', '6379')}", - password=os.getenv('REDIS_PASSWORD'), - decode_responses=True + password=os.getenv("REDIS_PASSWORD"), + decode_responses=True, ) log.info("Redis connection initialized") @@ -75,13 +78,19 @@ class NeruBot(commands.Bot): if self.pg_pool and self.redis: try: await settings_manager.initialize_database() - log.info("Database schema initialization called via settings_manager.") + log.info( + "Database schema initialization called via settings_manager." + ) await settings_manager.run_migrations() log.info("Database migrations called via settings_manager.") except Exception as e: - log.exception("CRITICAL: Failed during settings_manager database setup (init/migrations).") + log.exception( + "CRITICAL: Failed during settings_manager database setup (init/migrations)." + ) else: - log.error("CRITICAL: pg_pool or redis_client not initialized in setup_hook. Cannot proceed with settings_manager setup.") + log.error( + "CRITICAL: pg_pool or redis_client not initialized in setup_hook. Cannot proceed with settings_manager setup." + ) # Load only specific cogs try: @@ -124,11 +133,15 @@ class NeruBot(commands.Bot): # Apply global allowed_installs and allowed_contexts to all commands try: - log.info("Applying global allowed_installs and allowed_contexts to all commands...") + log.info( + "Applying global allowed_installs and allowed_contexts to all commands..." + ) for command in self.tree.get_commands(): # Apply decorators to each command app_commands.allowed_installs(guilds=True, users=True)(command) - app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)(command) + app_commands.allowed_contexts( + guilds=True, dms=True, private_channels=True + )(command) # Sync commands globally log.info("Starting global command sync process...") @@ -144,38 +157,49 @@ class NeruBot(commands.Bot): except Exception as e: log.exception(f"Error in setup_hook: {e}") + # Create bot instance using the custom class bot = NeruBot(command_prefix=get_prefix, intents=intents) # --- Logging Setup --- # Configure logging (adjust level and format as needed) -logging.basicConfig(level=logging.INFO, format='%(asctime)s:%(levelname)s:%(name)s: %(message)s') +logging.basicConfig( + level=logging.INFO, format="%(asctime)s:%(levelname)s:%(name)s: %(message)s" +) log = logging.getLogger(__name__) # Logger for neru_bot.py + # --- Events --- @bot.event async def on_ready(): if bot.user: - log.info(f'{bot.user.name} has connected to Discord!') - log.info(f'Bot ID: {bot.user.id}') + log.info(f"{bot.user.name} has connected to Discord!") + log.info(f"Bot ID: {bot.user.id}") # Set the bot's status - await bot.change_presence(activity=discord.Activity(type=discord.ActivityType.listening, name="!help")) + await bot.change_presence( + activity=discord.Activity(type=discord.ActivityType.listening, name="!help") + ) log.info("Bot status set to 'Listening to !help'") + # Error handling @bot.event async def on_command_error(ctx, error): await handle_error(ctx, error) + @bot.tree.error async def on_app_command_error(interaction, error): await handle_error(interaction, error) + async def main(args): """Main async function to load cogs and start the bot.""" - TOKEN = os.getenv('NERU_BOT_TOKEN') + TOKEN = os.getenv("NERU_BOT_TOKEN") if not TOKEN: - raise ValueError("No token found. Make sure to set NERU_BOT_TOKEN in your .env file.") + raise ValueError( + "No token found. Make sure to set NERU_BOT_TOKEN in your .env file." + ) # Set the global bot instance for other modules to access set_bot_instance(bot) @@ -208,12 +232,15 @@ async def main(args): await bot.redis.close() log.info("Redis connection closed.") + if __name__ == "__main__": import argparse # Parse command line arguments parser = argparse.ArgumentParser(description="Run the Neru Discord Bot.") - parser.add_argument('--disable-ai', action='store_true', help='Disable AI functionality') + parser.add_argument( + "--disable-ai", action="store_true", help="Disable AI functionality" + ) args = parser.parse_args() try: diff --git a/oauth_server.py b/oauth_server.py index 9a6daaa..3b49d19 100644 --- a/oauth_server.py +++ b/oauth_server.py @@ -19,45 +19,46 @@ pending_states: Set[str] = set() # Callbacks for successful authorization auth_callbacks: Dict[str, Callable] = {} + async def handle_oauth_callback(request: web.Request) -> web.Response: """Handle OAuth2 callback from Discord.""" # Get the authorization code and state from the request code = request.query.get("code") state = request.query.get("state") - + if not code or not state: return web.Response(text="Missing code or state parameter", status=400) - + # Check if the state is valid if state not in pending_states: return web.Response(text="Invalid state parameter", status=400) - + # Remove the state from pending states pending_states.remove(state) - + try: # Exchange the code for a token token_data = await discord_oauth.exchange_code(code, state) - + # Get the user's information access_token = token_data.get("access_token") if not access_token: return web.Response(text="Failed to get access token", status=500) - + user_info = await discord_oauth.get_user_info(access_token) user_id = user_info.get("id") - + if not user_id: return web.Response(text="Failed to get user ID", status=500) - + # Save the token discord_oauth.save_token(user_id, token_data) - + # Call the callback for this state if it exists callback = auth_callbacks.pop(state, None) if callback: asyncio.create_task(callback(user_id, user_info)) - + # Return a success message return web.Response( text=f""" @@ -80,13 +81,14 @@ async def handle_oauth_callback(request: web.Request) -> web.Response: """, - content_type="text/html" + content_type="text/html", ) except discord_oauth.OAuthError as e: return web.Response(text=f"OAuth error: {str(e)}", status=500) except Exception as e: return web.Response(text=f"Error: {str(e)}", status=500) + async def handle_root(request: web.Request) -> web.Response: """Handle requests to the root path.""" return web.Response( @@ -105,18 +107,19 @@ async def handle_root(request: web.Request) -> web.Response: """, - content_type="text/html" + content_type="text/html", ) + def create_app() -> web.Application: """Create the web application.""" app = web.Application() - app.add_routes([ - web.get("/", handle_root), - web.get("/oauth/callback", handle_oauth_callback) - ]) + app.add_routes( + [web.get("/", handle_root), web.get("/oauth/callback", handle_oauth_callback)] + ) return app + async def start_server(host: str = "0.0.0.0", port: int = 8080) -> None: """Start the OAuth callback server.""" app = create_app() @@ -126,12 +129,14 @@ async def start_server(host: str = "0.0.0.0", port: int = 8080) -> None: await site.start() print(f"OAuth callback server running at http://{host}:{port}") + def register_auth_state(state: str, callback: Optional[Callable] = None) -> None: """Register a pending authorization state.""" pending_states.add(state) if callback: auth_callbacks[state] = callback + if __name__ == "__main__": # For testing the server standalone loop = asyncio.get_event_loop() diff --git a/print_vertex_schema.py b/print_vertex_schema.py index 8bcd7eb..c568b6b 100644 --- a/print_vertex_schema.py +++ b/print_vertex_schema.py @@ -2,6 +2,7 @@ import copy import json from typing import Dict, Any + # --- Schema Preprocessing Helper --- # Copied from discordbot/gurt/api.py def _preprocess_schema_for_vertex(schema: Dict[str, Any]) -> Dict[str, Any]: @@ -17,34 +18,47 @@ def _preprocess_schema_for_vertex(schema: Dict[str, Any]) -> Dict[str, Any]: A new, preprocessed schema dictionary. """ if not isinstance(schema, dict): - return schema # Return non-dict elements as is + return schema # Return non-dict elements as is - processed_schema = copy.deepcopy(schema) # Work on a copy + processed_schema = copy.deepcopy(schema) # Work on a copy for key, value in processed_schema.items(): if key == "type" and isinstance(value, list): # Find the first non-"null" type in the list - first_valid_type = next((t for t in value if isinstance(t, str) and t.lower() != "null"), None) + first_valid_type = next( + (t for t in value if isinstance(t, str) and t.lower() != "null"), None + ) if first_valid_type: processed_schema[key] = first_valid_type else: # Fallback if only "null" or invalid types are present (shouldn't happen in valid schemas) - processed_schema[key] = "object" # Or handle as error - print(f"Warning: Schema preprocessing found list type '{value}' with no valid non-null string type. Falling back to 'object'.") + processed_schema[key] = "object" # Or handle as error + print( + f"Warning: Schema preprocessing found list type '{value}' with no valid non-null string type. Falling back to 'object'." + ) elif isinstance(value, dict): - processed_schema[key] = _preprocess_schema_for_vertex(value) # Recurse for nested objects + processed_schema[key] = _preprocess_schema_for_vertex( + value + ) # Recurse for nested objects elif isinstance(value, list): # Recurse for items within arrays (e.g., in 'properties' of array items) - processed_schema[key] = [_preprocess_schema_for_vertex(item) if isinstance(item, dict) else item for item in value] + processed_schema[key] = [ + _preprocess_schema_for_vertex(item) if isinstance(item, dict) else item + for item in value + ] # Handle 'properties' specifically elif key == "properties" and isinstance(value, dict): - processed_schema[key] = {prop_key: _preprocess_schema_for_vertex(prop_value) for prop_key, prop_value in value.items()} + processed_schema[key] = { + prop_key: _preprocess_schema_for_vertex(prop_value) + for prop_key, prop_value in value.items() + } # Handle 'items' specifically if it's a schema object elif key == "items" and isinstance(value, dict): - processed_schema[key] = _preprocess_schema_for_vertex(value) + processed_schema[key] = _preprocess_schema_for_vertex(value) return processed_schema + # --- Response Schema --- # Copied from discordbot/gurt/config.py RESPONSE_SCHEMA = { @@ -55,26 +69,26 @@ RESPONSE_SCHEMA = { "properties": { "should_respond": { "type": "boolean", - "description": "Whether the bot should send a text message in response." + "description": "Whether the bot should send a text message in response.", }, "content": { "type": "string", - "description": "The text content of the bot's response. Can be empty if only reacting." + "description": "The text content of the bot's response. Can be empty if only reacting.", }, "react_with_emoji": { "type": ["string", "null"], - "description": "Optional: A standard Discord emoji to react with, or null/empty if no reaction." + "description": "Optional: A standard Discord emoji to react with, or null/empty if no reaction.", }, "reply_to_message_id": { "type": ["string", "null"], - "description": "Optional: The ID of the message this response should reply to. Null or omit for a regular message." - } + "description": "Optional: The ID of the message this response should reply to. Null or omit for a regular message.", + }, # Note: tool_requests is handled by Vertex AI's function calling mechanism }, - "required": ["should_respond", "content"] - } + "required": ["should_respond", "content"], + }, } if __name__ == "__main__": - processed = _preprocess_schema_for_vertex(RESPONSE_SCHEMA['schema']) - print(json.dumps(processed, indent=2)) + processed = _preprocess_schema_for_vertex(RESPONSE_SCHEMA["schema"]) + print(json.dumps(processed, indent=2)) diff --git a/run_additional_bots.py b/run_additional_bots.py index ca06f6e..d601ff0 100644 --- a/run_additional_bots.py +++ b/run_additional_bots.py @@ -8,22 +8,29 @@ import multi_bot import gurt_bot import neru_bot + def run_gurt_bot_in_thread(): """Run the Gurt Bot in a separate thread""" loop = asyncio.new_event_loop() - thread = threading.Thread(target=lambda: loop.run_until_complete(gurt_bot.main()), daemon=True) + thread = threading.Thread( + target=lambda: loop.run_until_complete(gurt_bot.main()), daemon=True + ) thread.start() return thread + def run_neru_bot_in_thread(): """Run the Neru Bot in a separate thread""" loop = asyncio.new_event_loop() # Create args object with disable_ai=False - args = type('Args', (), {'disable_ai': False})() - thread = threading.Thread(target=lambda: loop.run_until_complete(neru_bot.main(args)), daemon=True) + args = type("Args", (), {"disable_ai": False})() + thread = threading.Thread( + target=lambda: loop.run_until_complete(neru_bot.main(args)), daemon=True + ) thread.start() return thread + def main(): """Main function to run all additional bots""" print("Starting additional bots (Neru, Miku, and Gurt)...") @@ -40,7 +47,9 @@ def main(): bot_threads.append(("neru", neru_thread)) if not bot_threads: - print("No bots were started. Check your configuration in data/multi_bot_config.json") + print( + "No bots were started. Check your configuration in data/multi_bot_config.json" + ) return print(f"Started {len(bot_threads)} bots.") @@ -69,5 +78,6 @@ def main(): # The threads are daemon threads, so they will be terminated when the main thread exits print("Bots stopped.") + if __name__ == "__main__": main() diff --git a/run_femdom_teto_bot.py b/run_femdom_teto_bot.py index 5a71f7e..3fa77bd 100644 --- a/run_femdom_teto_bot.py +++ b/run_femdom_teto_bot.py @@ -12,7 +12,10 @@ import argparse import logging import asyncpg import redis.asyncio as aioredis -from commands import load_all_cogs, reload_all_cogs # May need to modify or create a new load function +from commands import ( + load_all_cogs, + reload_all_cogs, +) # May need to modify or create a new load function from error_handler import handle_error, patch_discord_methods, store_interaction_content from utils import reload_script import settings_manager @@ -25,7 +28,8 @@ load_dotenv() # --- Constants --- DEFAULT_PREFIX = "!" # Define the specific cogs for this bot -FEMDOM_TETO_COGS = {'cogs.femdom_teto_cog', 'cogs.femdom_roleplay_teto_cog'} +FEMDOM_TETO_COGS = {"cogs.femdom_teto_cog", "cogs.femdom_roleplay_teto_cog"} + # --- Dynamic Prefix Function --- async def get_prefix(bot_instance, message): @@ -38,7 +42,8 @@ async def get_prefix(bot_instance, message): # This bot might need its own prefix setting or share the main bot's # For simplicity, let's use a fixed prefix for now or a different setting key # Using a fixed prefix for this specific bot - return "!" # Or a different prefix like "fd!" + return "!" # Or a different prefix like "fd!" + # --- Bot Setup --- # Set up intents (permissions) @@ -46,13 +51,14 @@ intents = discord.Intents.default() intents.message_content = True intents.members = True + # --- Custom Bot Class with setup_hook for async initialization --- class FemdomTetoBot(commands.Bot): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.owner_id = int(os.getenv('OWNER_USER_ID')) # Assuming owner ID is the same - self.pg_pool = None # Will be initialized in setup_hook - self.redis = None # Will be initialized in setup_hook + self.owner_id = int(os.getenv("OWNER_USER_ID")) # Assuming owner ID is the same + self.pg_pool = None # Will be initialized in setup_hook + self.redis = None # Will be initialized in setup_hook async def setup_hook(self): log.info("Running FemdomTetoBot setup_hook...") @@ -60,10 +66,7 @@ class FemdomTetoBot(commands.Bot): # Create Postgres pool on this loop # This bot might need its own DB or share the main bot's. Sharing is simpler for now. self.pg_pool = await asyncpg.create_pool( - dsn=settings_manager.DATABASE_URL, - min_size=1, - max_size=10, - loop=self.loop + dsn=settings_manager.DATABASE_URL, min_size=1, max_size=10, loop=self.loop ) log.info("Postgres pool initialized and attached to bot.pg_pool.") @@ -113,21 +116,29 @@ class FemdomTetoBot(commands.Bot): log.info("FemdomTetoBot setup_hook completed.") + # Create bot instance using the custom class # This bot will use a different token femdom_teto_bot = FemdomTetoBot(command_prefix=get_prefix, intents=intents) # --- Logging Setup --- -logging.basicConfig(level=logging.INFO, format='%(asctime)s:%(levelname)s:%(name)s: %(message)s') -log = logging.getLogger(__name__) # Logger for this script +logging.basicConfig( + level=logging.INFO, format="%(asctime)s:%(levelname)s:%(name)s: %(message)s" +) +log = logging.getLogger(__name__) # Logger for this script + # --- Events --- @femdom_teto_bot.event async def on_ready(): - log.info(f'{femdom_teto_bot.user.name} has connected to Discord!') - log.info(f'Bot ID: {femdom_teto_bot.user.id}') + log.info(f"{femdom_teto_bot.user.name} has connected to Discord!") + log.info(f"Bot ID: {femdom_teto_bot.user.id}") # Set the bot's status - await femdom_teto_bot.change_presence(activity=discord.Activity(type=discord.ActivityType.listening, name="for commands")) + await femdom_teto_bot.change_presence( + activity=discord.Activity( + type=discord.ActivityType.listening, name="for commands" + ) + ) log.info("Bot status set.") # Patch Discord methods to store message content @@ -137,11 +148,13 @@ async def on_ready(): # Make the store_interaction_content function available globally import builtins + builtins.store_interaction_content = store_interaction_content print("Made store_interaction_content available globally") except Exception as e: print(f"Warning: Failed to patch Discord methods: {e}") import traceback + traceback.print_exc() # Sync commands - This bot only has specific commands from its cogs @@ -155,6 +168,7 @@ async def on_ready(): except Exception as e: print(f"Failed to sync commands for FemdomTetoBot: {e}") import traceback + traceback.print_exc() @@ -163,10 +177,12 @@ async def on_ready(): async def on_command_error(ctx, error): await handle_error(ctx, error) + @femdom_teto_bot.tree.error async def on_app_command_error(interaction, error): await handle_error(interaction, error) + # --- Global Command Checks --- # Need to decide if this bot uses the same global checks or different ones # For now, let's skip global checks for simplicity or adapt them if needed @@ -174,11 +190,14 @@ async def on_app_command_error(interaction, error): # async def global_command_checks(ctx: commands.Context): # pass # Implement checks if necessary + async def main(): """Main async function to load cogs and start the bot.""" - TOKEN = os.getenv('FEMDOM_TETO_DISCORD_TOKEN') # Use a different token + TOKEN = os.getenv("FEMDOM_TETO_DISCORD_TOKEN") # Use a different token if not TOKEN: - raise ValueError("No FEMDOM_TETO_DISCORD_TOKEN found. Make sure to set FEMDOM_TETO_DISCORD_TOKEN in your .env file.") + raise ValueError( + "No FEMDOM_TETO_DISCORD_TOKEN found. Make sure to set FEMDOM_TETO_DISCORD_TOKEN in your .env file." + ) # This bot likely doesn't need to start the Flask or unified API servers # if API_AVAILABLE: @@ -203,10 +222,13 @@ async def main(): log.info("Closing Redis pool in main finally block...") await femdom_teto_bot.redis.close() if not femdom_teto_bot.pg_pool and not femdom_teto_bot.redis: - log.info("Pools were not initialized or already closed, skipping close_pools in main.") + log.info( + "Pools were not initialized or already closed, skipping close_pools in main." + ) + # Run the main async function -if __name__ == '__main__': +if __name__ == "__main__": try: asyncio.run(main()) except KeyboardInterrupt: diff --git a/run_gurt_bot.py b/run_gurt_bot.py index aac0981..c9240c8 100644 --- a/run_gurt_bot.py +++ b/run_gurt_bot.py @@ -7,9 +7,9 @@ import gurt_bot if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run the Gurt Discord Bot.") parser.add_argument( - '--minimal-prompt', - action='store_true', - help='Use a minimal system prompt suitable for fine-tuned models.' + "--minimal-prompt", + action="store_true", + help="Use a minimal system prompt suitable for fine-tuned models.", ) args = parser.parse_args() diff --git a/run_markdown_server.py b/run_markdown_server.py index 8e756de..aa0a9f1 100644 --- a/run_markdown_server.py +++ b/run_markdown_server.py @@ -12,7 +12,9 @@ from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles # Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s:%(levelname)s:%(name)s: %(message)s') +logging.basicConfig( + level=logging.INFO, format="%(asctime)s:%(levelname)s:%(name)s: %(message)s" +) log = logging.getLogger(__name__) # Try to import markdown, but provide a fallback if it's not available @@ -20,6 +22,7 @@ try: import markdown import markdown.extensions.fenced_code import markdown.extensions.tables + MARKDOWN_AVAILABLE = True except ImportError: MARKDOWN_AVAILABLE = False @@ -113,10 +116,13 @@ HTML_TEMPLATE = """ """ + # Function to read and convert markdown to HTML -def render_markdown(file_path, title, og_title, og_description, og_type, og_url, og_image): +def render_markdown( + file_path, title, og_title, og_description, og_type, og_url, og_image +): try: - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, "r", encoding="utf-8") as f: md_content = f.read() if MARKDOWN_AVAILABLE: @@ -124,10 +130,10 @@ def render_markdown(file_path, title, og_title, og_description, og_type, og_url, html_content = markdown.markdown( md_content, extensions=[ - 'markdown.extensions.fenced_code', - 'markdown.extensions.tables', - 'markdown.extensions.toc' - ] + "markdown.extensions.fenced_code", + "markdown.extensions.tables", + "markdown.extensions.toc", + ], ) else: # Simple fallback if markdown package is not available @@ -142,7 +148,7 @@ def render_markdown(file_path, title, og_title, og_description, og_type, og_url, og_description=og_description, og_type=og_type, og_url=og_url, - og_image=og_image + og_image=og_image, ) except Exception as e: return HTML_TEMPLATE.format( @@ -152,9 +158,10 @@ def render_markdown(file_path, title, og_title, og_description, og_type, og_url, og_description="Failed to render content.", og_type="website", og_url="", - og_image="" + og_image="", ) + # Routes for TOS and Privacy Policy @app.get("/tos", response_class=HTMLResponse) async def get_tos(request: Request): @@ -166,9 +173,10 @@ async def get_tos(request: Request): og_description="Read the Terms of Service for our Discord Bot.", og_type="article", og_url=f"{base_url}tos", - og_image=f"{base_url}static/images/bot_logo.png" # Assuming a static folder for images + og_image=f"{base_url}static/images/bot_logo.png", # Assuming a static folder for images ) + @app.get("/privacy", response_class=HTMLResponse) async def get_privacy(request: Request): base_url = str(request.base_url) @@ -179,9 +187,10 @@ async def get_privacy(request: Request): og_description="Understand how your data is handled by our Discord Bot.", og_type="article", og_url=f"{base_url}privacy", - og_image=f"{base_url}static/images/bot_logo.png" # Assuming a static folder for images + og_image=f"{base_url}static/images/bot_logo.png", # Assuming a static folder for images ) + # Root route that redirects to TOS @app.get("/", response_class=HTMLResponse) async def root(request: Request): @@ -235,6 +244,7 @@ async def root(request: Request): """ + # Function to start the server in a thread def start_markdown_server_in_thread(host="0.0.0.0", port=5006): """Start the markdown server in a separate thread using a different approach @@ -274,11 +284,13 @@ def start_markdown_server_in_thread(host="0.0.0.0", port=5006): client_thread = threading.Thread( target=self.handle_request, args=(client_socket, addr), - daemon=True + daemon=True, ) client_thread.start() except Exception as e: - if self.running: # Only log if we're still supposed to be running + if ( + self.running + ): # Only log if we're still supposed to be running log.exception(f"Error accepting connection: {e}") except Exception as e: log.exception(f"Error starting markdown server: {e}") @@ -289,8 +301,8 @@ def start_markdown_server_in_thread(host="0.0.0.0", port=5006): def handle_request(self, client_socket, addr): try: # Read the HTTP request - request_data = client_socket.recv(1024).decode('utf-8') - request_lines = request_data.split('\n') + request_data = client_socket.recv(1024).decode("utf-8") + request_lines = request_data.split("\n") if not request_lines: return @@ -303,17 +315,17 @@ def start_markdown_server_in_thread(host="0.0.0.0", port=5006): method, path = parts[0], parts[1] # Simple routing - if path == '/' or path == '/index.html': + if path == "/" or path == "/index.html": response = self.serve_root() - elif path == '/tos' or path == '/tos.html': + elif path == "/tos" or path == "/tos.html": response = self.serve_tos() - elif path == '/privacy' or path == '/privacy.html': + elif path == "/privacy" or path == "/privacy.html": response = self.serve_privacy() else: response = self.serve_404() # Send the response - client_socket.sendall(response.encode('utf-8')) + client_socket.sendall(response.encode("utf-8")) except Exception as e: log.exception(f"Error handling request: {e}") finally: @@ -373,20 +385,22 @@ def start_markdown_server_in_thread(host="0.0.0.0", port=5006): def serve_tos(self): try: - with open("TOS.md", 'r', encoding='utf-8') as f: + with open("TOS.md", "r", encoding="utf-8") as f: md_content = f.read() if MARKDOWN_AVAILABLE: html_content = markdown.markdown( md_content, extensions=[ - 'markdown.extensions.fenced_code', - 'markdown.extensions.tables', - 'markdown.extensions.toc' - ] + "markdown.extensions.fenced_code", + "markdown.extensions.tables", + "markdown.extensions.toc", + ], ) else: - html_content = f"
{md_content}
" + html_content = ( + f"
{md_content}
" + ) html = HTML_TEMPLATE.format( title="Terms of Service", @@ -395,7 +409,7 @@ def start_markdown_server_in_thread(host="0.0.0.0", port=5006): og_description="Read the Terms of Service for our Discord Bot.", og_type="article", og_url=f"http://{self.host}:{self.port}/tos", - og_image=f"http://{self.host}:{self.port}/static/images/bot_logo.png" + og_image=f"http://{self.host}:{self.port}/static/images/bot_logo.png", ) return "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n" + html except Exception as e: @@ -404,20 +418,22 @@ def start_markdown_server_in_thread(host="0.0.0.0", port=5006): def serve_privacy(self): try: - with open("PRIVACY_POLICY.md", 'r', encoding='utf-8') as f: + with open("PRIVACY_POLICY.md", "r", encoding="utf-8") as f: md_content = f.read() if MARKDOWN_AVAILABLE: html_content = markdown.markdown( md_content, extensions=[ - 'markdown.extensions.fenced_code', - 'markdown.extensions.tables', - 'markdown.extensions.toc' - ] + "markdown.extensions.fenced_code", + "markdown.extensions.tables", + "markdown.extensions.toc", + ], ) else: - html_content = f"
{md_content}
" + html_content = ( + f"
{md_content}
" + ) html = HTML_TEMPLATE.format( title="Privacy Policy", @@ -426,7 +442,7 @@ def start_markdown_server_in_thread(host="0.0.0.0", port=5006): og_description="Understand how your data is handled by our Discord Bot.", og_type="article", og_url=f"http://{self.host}:{self.port}/privacy", - og_image=f"http://{self.host}:{self.port}/static/images/bot_logo.png" + og_image=f"http://{self.host}:{self.port}/static/images/bot_logo.png", ) return "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n" + html except Exception as e: @@ -445,10 +461,13 @@ def start_markdown_server_in_thread(host="0.0.0.0", port=5006): # Start the server server = MarkdownServer(host, port) server.start() - log.info(f"Markdown server thread started. TOS available at: http://{host}:{port}/tos") + log.info( + f"Markdown server thread started. TOS available at: http://{host}:{port}/tos" + ) return server + def start_server(): """Start the markdown server as a background process (legacy method).""" print("Starting markdown server on port 5006...") @@ -461,7 +480,7 @@ def start_server(): [sys.executable, os.path.join(script_dir, "markdown_server.py")], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - universal_newlines=True + universal_newlines=True, ) # Register a function to terminate the server when this script exits @@ -497,6 +516,7 @@ def start_server(): return True + def run_as_daemon(): """Run the server as a daemon process.""" if start_server(): @@ -508,6 +528,7 @@ def run_as_daemon(): print("Received keyboard interrupt. Shutting down...") sys.exit(0) + if __name__ == "__main__": # If run directly, start the server in the main thread log.info("Starting markdown server on port 5006...") diff --git a/run_neru_bot.py b/run_neru_bot.py index cf780fc..463d722 100644 --- a/run_neru_bot.py +++ b/run_neru_bot.py @@ -6,7 +6,9 @@ import neru_bot if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run the Neru Discord Bot.") - parser.add_argument('--disable-ai', action='store_true', help='Disable AI functionality') + parser.add_argument( + "--disable-ai", action="store_true", help="Disable AI functionality" + ) args = parser.parse_args() try: diff --git a/run_unified_api.py b/run_unified_api.py index 570afb6..60b61e6 100644 --- a/run_unified_api.py +++ b/run_unified_api.py @@ -5,7 +5,7 @@ import uvicorn from dotenv import load_dotenv # Add the api_service directory to the Python path -sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'api_service')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "api_service")) # Add the discordbot directory to the Python path (for settings_manager) discordbot_path = os.path.dirname(__file__) @@ -21,6 +21,7 @@ api_port = int(os.getenv("API_PORT", "8001")) # SSL is now handled by a reverse proxy, so we don't configure it here. + def run_unified_api(): """Run the unified API service (dual-stack IPv4+IPv6)""" import threading @@ -31,7 +32,7 @@ def run_unified_api(): "api_service.api_server:app", host=bind_host, port=api_port, - log_level="debug" + log_level="debug", ) try: @@ -49,6 +50,7 @@ def run_unified_api(): except Exception as e: print(f"Error starting unified API service: {e}") + def start_api_in_thread(): """Start the unified API service in a separate thread""" api_thread = threading.Thread(target=run_unified_api) @@ -57,6 +59,7 @@ def start_api_in_thread(): print("Unified API service started in background thread") return api_thread + if __name__ == "__main__": # Run the API directly if this script is executed run_unified_api() diff --git a/run_wheatley_bot.py b/run_wheatley_bot.py index 8b5a586..62aa4d3 100644 --- a/run_wheatley_bot.py +++ b/run_wheatley_bot.py @@ -1,12 +1,12 @@ import os import sys import asyncio -import wheatley_bot # Changed import from gurt_bot +import wheatley_bot # Changed import from gurt_bot if __name__ == "__main__": try: - asyncio.run(wheatley_bot.main()) # Changed function call + asyncio.run(wheatley_bot.main()) # Changed function call except KeyboardInterrupt: - print("Wheatley Bot stopped by user.") # Changed print statement + print("Wheatley Bot stopped by user.") # Changed print statement except Exception as e: - print(f"An error occurred running Wheatley Bot: {e}") # Changed print statement + print(f"An error occurred running Wheatley Bot: {e}") # Changed print statement diff --git a/settings_manager.py b/settings_manager.py index 1b29c37..c3084f9 100644 --- a/settings_manager.py +++ b/settings_manager.py @@ -7,22 +7,24 @@ import asyncio from dotenv import load_dotenv from typing import Dict -from global_bot_accessor import get_bot_instance # Import the accessor +from global_bot_accessor import get_bot_instance # Import the accessor # Load environment variables -load_dotenv(dotenv_path=os.path.join(os.path.dirname(__file__), '.env')) +load_dotenv(dotenv_path=os.path.join(os.path.dirname(__file__), ".env")) # --- Configuration --- POSTGRES_USER = os.getenv("POSTGRES_USER") POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD") POSTGRES_HOST = os.getenv("POSTGRES_HOST") -POSTGRES_DB = os.getenv("POSTGRES_SETTINGS_DB") # Use the new settings DB +POSTGRES_DB = os.getenv("POSTGRES_SETTINGS_DB") # Use the new settings DB REDIS_HOST = os.getenv("REDIS_HOST") REDIS_PORT = os.getenv("REDIS_PORT", 6379) -REDIS_PASSWORD = os.getenv("REDIS_PASSWORD") # Optional +REDIS_PASSWORD = os.getenv("REDIS_PASSWORD") # Optional -DATABASE_URL = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}/{POSTGRES_DB}" -REDIS_URL = f"redis://{':' + REDIS_PASSWORD + '@' if REDIS_PASSWORD else ''}{REDIS_HOST}:{REDIS_PORT}/0" # Use DB 0 for settings cache +DATABASE_URL = ( + f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}/{POSTGRES_DB}" +) +REDIS_URL = f"redis://{':' + REDIS_PASSWORD + '@' if REDIS_PASSWORD else ''}{REDIS_HOST}:{REDIS_PORT}/0" # Use DB 0 for settings cache # --- Module-level Connection Pools (to be set by the bot) --- # _active_pg_pool = None # Removed @@ -54,36 +56,47 @@ log = logging.getLogger(__name__) # initialize_pools and close_pools are removed as pool lifecycle is managed by the bot. + # --- Database Schema Initialization --- async def run_migrations(): """Run database migrations to update schema.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.error("Bot instance or PostgreSQL pool not available in settings_manager. Cannot run migrations.") + log.error( + "Bot instance or PostgreSQL pool not available in settings_manager. Cannot run migrations." + ) return log.info("Running database migrations...") try: async with bot.pg_pool.acquire() as conn: # Check if custom_command_description column exists in command_customization table - column_exists = await conn.fetchval(""" + column_exists = await conn.fetchval( + """ SELECT EXISTS ( SELECT 1 FROM information_schema.columns WHERE table_name = 'command_customization' AND column_name = 'custom_command_description' ); - """) + """ + ) if not column_exists: - log.info("Adding custom_command_description column to command_customization table...") - await conn.execute(""" + log.info( + "Adding custom_command_description column to command_customization table..." + ) + await conn.execute( + """ ALTER TABLE command_customization ADD COLUMN custom_command_description TEXT; - """) + """ + ) log.info("Added custom_command_description column successfully.") else: - log.debug("custom_command_description column already exists in command_customization table.") + log.debug( + "custom_command_description column already exists in command_customization table." + ) except Exception as e: log.exception(f"Error running database migrations: {e}") @@ -93,21 +106,26 @@ async def initialize_database(): """Creates necessary tables in the PostgreSQL database if they don't exist.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.error("Bot instance or PostgreSQL pool not available in settings_manager. Cannot initialize database.") + log.error( + "Bot instance or PostgreSQL pool not available in settings_manager. Cannot initialize database." + ) return log.info("Initializing database schema...") async with bot.pg_pool.acquire() as conn: async with conn.transaction(): # Guilds table (to track known guilds, maybe store basic info later) - await conn.execute(""" + await conn.execute( + """ CREATE TABLE IF NOT EXISTS guilds ( guild_id BIGINT PRIMARY KEY ); - """) + """ + ) # Guild Settings table (key-value store for various settings) - await conn.execute(""" + await conn.execute( + """ CREATE TABLE IF NOT EXISTS guild_settings ( guild_id BIGINT NOT NULL, setting_key TEXT NOT NULL, @@ -115,11 +133,13 @@ async def initialize_database(): PRIMARY KEY (guild_id, setting_key), FOREIGN KEY (guild_id) REFERENCES guilds(guild_id) ON DELETE CASCADE ); - """) + """ + ) # Example setting_keys: 'prefix', 'welcome_channel_id', 'welcome_message', 'goodbye_channel_id', 'goodbye_message' # Enabled Cogs table - Stores the explicit enabled/disabled state - await conn.execute(""" + await conn.execute( + """ CREATE TABLE IF NOT EXISTS enabled_cogs ( guild_id BIGINT NOT NULL, cog_name TEXT NOT NULL, @@ -127,10 +147,12 @@ async def initialize_database(): PRIMARY KEY (guild_id, cog_name), FOREIGN KEY (guild_id) REFERENCES guilds(guild_id) ON DELETE CASCADE ); - """) + """ + ) # Enabled Commands table - Stores the explicit enabled/disabled state for individual commands - await conn.execute(""" + await conn.execute( + """ CREATE TABLE IF NOT EXISTS enabled_commands ( guild_id BIGINT NOT NULL, command_name TEXT NOT NULL, @@ -138,10 +160,12 @@ async def initialize_database(): PRIMARY KEY (guild_id, command_name), FOREIGN KEY (guild_id) REFERENCES guilds(guild_id) ON DELETE CASCADE ); - """) + """ + ) # Command Permissions table (simple role-based for now) - await conn.execute(""" + await conn.execute( + """ CREATE TABLE IF NOT EXISTS command_permissions ( guild_id BIGINT NOT NULL, command_name TEXT NOT NULL, @@ -149,10 +173,12 @@ async def initialize_database(): PRIMARY KEY (guild_id, command_name, allowed_role_id), FOREIGN KEY (guild_id) REFERENCES guilds(guild_id) ON DELETE CASCADE ); - """) + """ + ) # Command Customization table - Stores guild-specific command names and descriptions - await conn.execute(""" + await conn.execute( + """ CREATE TABLE IF NOT EXISTS command_customization ( guild_id BIGINT NOT NULL, original_command_name TEXT NOT NULL, @@ -161,10 +187,12 @@ async def initialize_database(): PRIMARY KEY (guild_id, original_command_name), FOREIGN KEY (guild_id) REFERENCES guilds(guild_id) ON DELETE CASCADE ); - """) + """ + ) # Command Group Customization table - Stores guild-specific command group names - await conn.execute(""" + await conn.execute( + """ CREATE TABLE IF NOT EXISTS command_group_customization ( guild_id BIGINT NOT NULL, original_group_name TEXT NOT NULL, @@ -172,10 +200,12 @@ async def initialize_database(): PRIMARY KEY (guild_id, original_group_name), FOREIGN KEY (guild_id) REFERENCES guilds(guild_id) ON DELETE CASCADE ); - """) + """ + ) # Command Aliases table - Stores additional aliases for commands - await conn.execute(""" + await conn.execute( + """ CREATE TABLE IF NOT EXISTS command_aliases ( guild_id BIGINT NOT NULL, original_command_name TEXT NOT NULL, @@ -183,10 +213,12 @@ async def initialize_database(): PRIMARY KEY (guild_id, original_command_name, alias_name), FOREIGN KEY (guild_id) REFERENCES guilds(guild_id) ON DELETE CASCADE ); - """) + """ + ) # Starboard Settings table - Stores configuration for the starboard feature - await conn.execute(""" + await conn.execute( + """ CREATE TABLE IF NOT EXISTS starboard_settings ( guild_id BIGINT PRIMARY KEY, enabled BOOLEAN NOT NULL DEFAULT TRUE, @@ -197,10 +229,12 @@ async def initialize_database(): self_star BOOLEAN NOT NULL DEFAULT FALSE, FOREIGN KEY (guild_id) REFERENCES guilds(guild_id) ON DELETE CASCADE ); - """) + """ + ) # Starboard Entries table - Tracks which messages have been reposted to the starboard - await conn.execute(""" + await conn.execute( + """ CREATE TABLE IF NOT EXISTS starboard_entries ( id SERIAL PRIMARY KEY, guild_id BIGINT NOT NULL, @@ -213,10 +247,12 @@ async def initialize_database(): UNIQUE(guild_id, original_message_id), FOREIGN KEY (guild_id) REFERENCES guilds(guild_id) ON DELETE CASCADE ); - """) + """ + ) # Starboard Reactions table - Tracks which users have starred which messages - await conn.execute(""" + await conn.execute( + """ CREATE TABLE IF NOT EXISTS starboard_reactions ( guild_id BIGINT NOT NULL, message_id BIGINT NOT NULL, @@ -224,10 +260,12 @@ async def initialize_database(): PRIMARY KEY (guild_id, message_id, user_id), FOREIGN KEY (guild_id) REFERENCES guilds(guild_id) ON DELETE CASCADE ); - """) + """ + ) # Git Monitored Repositories table - await conn.execute(""" + await conn.execute( + """ CREATE TABLE IF NOT EXISTS git_monitored_repositories ( id SERIAL PRIMARY KEY, guild_id BIGINT NOT NULL, @@ -247,39 +285,59 @@ async def initialize_database(): CONSTRAINT uq_guild_repo_channel UNIQUE (guild_id, repository_url, notification_channel_id), FOREIGN KEY (guild_id) REFERENCES guilds(guild_id) ON DELETE CASCADE ); - """) + """ + ) # Add indexes for faster lookups - await conn.execute("CREATE INDEX IF NOT EXISTS idx_git_monitored_repo_guild ON git_monitored_repositories (guild_id);") - await conn.execute("CREATE INDEX IF NOT EXISTS idx_git_monitored_repo_method ON git_monitored_repositories (monitoring_method);") - await conn.execute("CREATE INDEX IF NOT EXISTS idx_git_monitored_repo_url ON git_monitored_repositories (repository_url);") + await conn.execute( + "CREATE INDEX IF NOT EXISTS idx_git_monitored_repo_guild ON git_monitored_repositories (guild_id);" + ) + await conn.execute( + "CREATE INDEX IF NOT EXISTS idx_git_monitored_repo_method ON git_monitored_repositories (monitoring_method);" + ) + await conn.execute( + "CREATE INDEX IF NOT EXISTS idx_git_monitored_repo_url ON git_monitored_repositories (repository_url);" + ) # Migration: Add allowed_webhook_events column if it doesn't exist and set default for old rows - column_exists_git_events = await conn.fetchval(""" + column_exists_git_events = await conn.fetchval( + """ SELECT EXISTS ( SELECT 1 FROM information_schema.columns WHERE table_name = 'git_monitored_repositories' AND column_name = 'allowed_webhook_events' ); - """) + """ + ) if not column_exists_git_events: - log.info("Adding allowed_webhook_events column to git_monitored_repositories table...") - await conn.execute(""" + log.info( + "Adding allowed_webhook_events column to git_monitored_repositories table..." + ) + await conn.execute( + """ ALTER TABLE git_monitored_repositories ADD COLUMN allowed_webhook_events TEXT[] DEFAULT ARRAY['push']::TEXT[]; - """) + """ + ) # Update existing rows to have a default value if they are NULL - await conn.execute(""" + await conn.execute( + """ UPDATE git_monitored_repositories SET allowed_webhook_events = ARRAY['push']::TEXT[] WHERE allowed_webhook_events IS NULL; - """) - log.info("Added allowed_webhook_events column and set default for existing rows.") + """ + ) + log.info( + "Added allowed_webhook_events column and set default for existing rows." + ) else: - log.debug("allowed_webhook_events column already exists in git_monitored_repositories table.") + log.debug( + "allowed_webhook_events column already exists in git_monitored_repositories table." + ) # Logging Event Toggles table - Stores enabled/disabled state per event type - await conn.execute(""" + await conn.execute( + """ CREATE TABLE IF NOT EXISTS logging_event_toggles ( guild_id BIGINT NOT NULL, event_key TEXT NOT NULL, -- e.g., 'member_join', 'audit_kick' @@ -287,7 +345,8 @@ async def initialize_database(): PRIMARY KEY (guild_id, event_key), FOREIGN KEY (guild_id) REFERENCES guilds(guild_id) ON DELETE CASCADE ); - """) + """ + ) # Consider adding indexes later for performance on large tables # await conn.execute("CREATE INDEX IF NOT EXISTS idx_guild_settings_guild ON guild_settings (guild_id);") @@ -304,11 +363,14 @@ async def initialize_database(): # --- Starboard Functions --- + async def get_starboard_settings(guild_id: int): """Gets the starboard settings for a guild.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.warning(f"Bot instance or PostgreSQL pool not available in settings_manager for get_starboard_settings (guild {guild_id}).") + log.warning( + f"Bot instance or PostgreSQL pool not available in settings_manager for get_starboard_settings (guild {guild_id})." + ) return None try: @@ -318,7 +380,7 @@ async def get_starboard_settings(guild_id: int): """ SELECT * FROM starboard_settings WHERE guild_id = $1 """, - guild_id + guild_id, ) if settings: @@ -331,7 +393,7 @@ async def get_starboard_settings(guild_id: int): VALUES ($1) ON CONFLICT (guild_id) DO NOTHING; """, - guild_id + guild_id, ) # Fetch the newly inserted default settings @@ -339,14 +401,17 @@ async def get_starboard_settings(guild_id: int): """ SELECT * FROM starboard_settings WHERE guild_id = $1 """, - guild_id + guild_id, ) return dict(settings) if settings else None except Exception as e: - log.exception(f"Database error getting starboard settings for guild {guild_id}: {e}") + log.exception( + f"Database error getting starboard settings for guild {guild_id}: {e}" + ) return None + async def update_starboard_settings(guild_id: int, **kwargs): """Updates starboard settings for a guild. @@ -360,14 +425,25 @@ async def update_starboard_settings(guild_id: int, **kwargs): """ bot = get_bot_instance() if not bot or not bot.pg_pool: - log.error(f"Bot instance or PostgreSQL pool not available in settings_manager for update_starboard_settings (guild {guild_id}).") + log.error( + f"Bot instance or PostgreSQL pool not available in settings_manager for update_starboard_settings (guild {guild_id})." + ) return False - valid_keys = {'enabled', 'star_emoji', 'threshold', 'starboard_channel_id', 'ignore_bots', 'self_star'} + valid_keys = { + "enabled", + "star_emoji", + "threshold", + "starboard_channel_id", + "ignore_bots", + "self_star", + } update_dict = {k: v for k, v in kwargs.items() if k in valid_keys} if not update_dict: - log.warning(f"No valid settings provided for starboard update for guild {guild_id}") + log.warning( + f"No valid settings provided for starboard update for guild {guild_id}" + ) return False # Use a timeout to prevent hanging on database operations @@ -379,10 +455,17 @@ async def update_starboard_settings(guild_id: int, **kwargs): # Ensure guild exists try: - await conn.execute("INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", guild_id) + await conn.execute( + "INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", + guild_id, + ) except Exception as e: - if "another operation is in progress" in str(e) or "attached to a different loop" in str(e): - log.warning(f"Connection issue when inserting guild {guild_id}: {e}") + if "another operation is in progress" in str( + e + ) or "attached to a different loop" in str(e): + log.warning( + f"Connection issue when inserting guild {guild_id}: {e}" + ) # Try to reset the connection await conn.close() conn = await asyncio.wait_for(bot.pg_pool.acquire(), timeout=5.0) @@ -390,7 +473,9 @@ async def update_starboard_settings(guild_id: int, **kwargs): raise # Build the SET clause for the UPDATE statement - set_clause = ", ".join(f"{key} = ${i+2}" for i, key in enumerate(update_dict.keys())) + set_clause = ", ".join( + f"{key} = ${i+2}" for i, key in enumerate(update_dict.keys()) + ) values = [guild_id] + list(update_dict.values()) # Update the settings @@ -401,11 +486,15 @@ async def update_starboard_settings(guild_id: int, **kwargs): VALUES ($1) ON CONFLICT (guild_id) DO UPDATE SET {set_clause}; """, - *values + *values, ) except Exception as e: - if "another operation is in progress" in str(e) or "attached to a different loop" in str(e): - log.warning(f"Connection issue when updating starboard settings for guild {guild_id}: {e}") + if "another operation is in progress" in str( + e + ) or "attached to a different loop" in str(e): + log.warning( + f"Connection issue when updating starboard settings for guild {guild_id}: {e}" + ) # Try to reset the connection await conn.close() conn = await asyncio.wait_for(bot.pg_pool.acquire(), timeout=5.0) @@ -417,7 +506,7 @@ async def update_starboard_settings(guild_id: int, **kwargs): VALUES ($1) ON CONFLICT (guild_id) DO UPDATE SET {set_clause}; """, - *values + *values, ) else: raise @@ -429,17 +518,24 @@ async def update_starboard_settings(guild_id: int, **kwargs): if conn: await bot.pg_pool.release(conn) except asyncio.TimeoutError: - log.error(f"Timeout acquiring database connection for starboard settings update (Guild: {guild_id})") + log.error( + f"Timeout acquiring database connection for starboard settings update (Guild: {guild_id})" + ) return False except Exception as e: - log.exception(f"Database error updating starboard settings for guild {guild_id}: {e}") + log.exception( + f"Database error updating starboard settings for guild {guild_id}: {e}" + ) return False + async def get_starboard_entry(guild_id: int, original_message_id: int): """Gets a starboard entry for a specific message.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.warning(f"Bot instance or PostgreSQL pool not available in settings_manager for get_starboard_entry (guild {guild_id}).") + log.warning( + f"Bot instance or PostgreSQL pool not available in settings_manager for get_starboard_entry (guild {guild_id})." + ) return None try: @@ -449,26 +545,41 @@ async def get_starboard_entry(guild_id: int, original_message_id: int): SELECT * FROM starboard_entries WHERE guild_id = $1 AND original_message_id = $2 """, - guild_id, original_message_id + guild_id, + original_message_id, ) return dict(entry) if entry else None except Exception as e: - log.exception(f"Database error getting starboard entry for message {original_message_id} in guild {guild_id}: {e}") + log.exception( + f"Database error getting starboard entry for message {original_message_id} in guild {guild_id}: {e}" + ) return None -async def create_starboard_entry(guild_id: int, original_message_id: int, original_channel_id: int, - starboard_message_id: int, author_id: int, star_count: int = 1): + +async def create_starboard_entry( + guild_id: int, + original_message_id: int, + original_channel_id: int, + starboard_message_id: int, + author_id: int, + star_count: int = 1, +): """Creates a new starboard entry.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.error(f"Bot instance or PostgreSQL pool not available in settings_manager for create_starboard_entry (guild {guild_id}).") + log.error( + f"Bot instance or PostgreSQL pool not available in settings_manager for create_starboard_entry (guild {guild_id})." + ) return False try: async with bot.pg_pool.acquire() as conn: # Ensure guild exists - await conn.execute("INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", guild_id) + await conn.execute( + "INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", + guild_id, + ) # Create the entry await conn.execute( @@ -478,20 +589,34 @@ async def create_starboard_entry(guild_id: int, original_message_id: int, origin VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (guild_id, original_message_id) DO NOTHING; """, - guild_id, original_message_id, original_channel_id, starboard_message_id, author_id, star_count + guild_id, + original_message_id, + original_channel_id, + starboard_message_id, + author_id, + star_count, ) - log.info(f"Created starboard entry for message {original_message_id} in guild {guild_id}") + log.info( + f"Created starboard entry for message {original_message_id} in guild {guild_id}" + ) return True except Exception as e: - log.exception(f"Database error creating starboard entry for message {original_message_id} in guild {guild_id}: {e}") + log.exception( + f"Database error creating starboard entry for message {original_message_id} in guild {guild_id}: {e}" + ) return False -async def update_starboard_entry(guild_id: int, original_message_id: int, star_count: int): + +async def update_starboard_entry( + guild_id: int, original_message_id: int, star_count: int +): """Updates the star count for an existing starboard entry.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.error(f"Bot instance or PostgreSQL pool not available in settings_manager for update_starboard_entry (guild {guild_id}).") + log.error( + f"Bot instance or PostgreSQL pool not available in settings_manager for update_starboard_entry (guild {guild_id})." + ) return False try: @@ -506,27 +631,38 @@ async def update_starboard_entry(guild_id: int, original_message_id: int, star_c SET star_count = $3 WHERE guild_id = $1 AND original_message_id = $2 """, - guild_id, original_message_id, star_count + guild_id, + original_message_id, + star_count, ) - log.info(f"Updated star count to {star_count} for message {original_message_id} in guild {guild_id}") + log.info( + f"Updated star count to {star_count} for message {original_message_id} in guild {guild_id}" + ) return True finally: # Always release the connection back to the pool if conn: await bot.pg_pool.release(conn) except asyncio.TimeoutError: - log.error(f"Timeout acquiring database connection for starboard entry update (Guild: {guild_id}, Message: {original_message_id})") + log.error( + f"Timeout acquiring database connection for starboard entry update (Guild: {guild_id}, Message: {original_message_id})" + ) return False except Exception as e: - log.exception(f"Database error updating starboard entry for message {original_message_id} in guild {guild_id}: {e}") + log.exception( + f"Database error updating starboard entry for message {original_message_id} in guild {guild_id}: {e}" + ) return False + async def delete_starboard_entry(guild_id: int, original_message_id: int): """Deletes a starboard entry.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.error(f"Bot instance or PostgreSQL pool not available in settings_manager for delete_starboard_entry (guild {guild_id}).") + log.error( + f"Bot instance or PostgreSQL pool not available in settings_manager for delete_starboard_entry (guild {guild_id})." + ) return False try: @@ -541,7 +677,8 @@ async def delete_starboard_entry(guild_id: int, original_message_id: int): DELETE FROM starboard_entries WHERE guild_id = $1 AND original_message_id = $2 """, - guild_id, original_message_id + guild_id, + original_message_id, ) # Also delete any reactions associated with this message @@ -550,27 +687,37 @@ async def delete_starboard_entry(guild_id: int, original_message_id: int): DELETE FROM starboard_reactions WHERE guild_id = $1 AND message_id = $2 """, - guild_id, original_message_id + guild_id, + original_message_id, ) - log.info(f"Deleted starboard entry for message {original_message_id} in guild {guild_id}") + log.info( + f"Deleted starboard entry for message {original_message_id} in guild {guild_id}" + ) return True finally: # Always release the connection back to the pool if conn: await bot.pg_pool.release(conn) except asyncio.TimeoutError: - log.error(f"Timeout acquiring database connection for starboard entry deletion (Guild: {guild_id}, Message: {original_message_id})") + log.error( + f"Timeout acquiring database connection for starboard entry deletion (Guild: {guild_id}, Message: {original_message_id})" + ) return False except Exception as e: - log.exception(f"Database error deleting starboard entry for message {original_message_id} in guild {guild_id}: {e}") + log.exception( + f"Database error deleting starboard entry for message {original_message_id} in guild {guild_id}: {e}" + ) return False + async def clear_starboard_entries(guild_id: int): """Clears all starboard entries for a guild.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.error(f"Bot instance or PostgreSQL pool not available in settings_manager for clear_starboard_entries (guild {guild_id}).") + log.error( + f"Bot instance or PostgreSQL pool not available in settings_manager for clear_starboard_entries (guild {guild_id})." + ) return False try: @@ -585,7 +732,7 @@ async def clear_starboard_entries(guild_id: int): SELECT * FROM starboard_entries WHERE guild_id = $1 """, - guild_id + guild_id, ) # Delete all entries @@ -594,7 +741,7 @@ async def clear_starboard_entries(guild_id: int): DELETE FROM starboard_entries WHERE guild_id = $1 """, - guild_id + guild_id, ) # Delete all reactions @@ -603,7 +750,7 @@ async def clear_starboard_entries(guild_id: int): DELETE FROM starboard_reactions WHERE guild_id = $1 """, - guild_id + guild_id, ) log.info(f"Cleared {len(entries)} starboard entries for guild {guild_id}") @@ -613,17 +760,24 @@ async def clear_starboard_entries(guild_id: int): if conn: await bot.pg_pool.release(conn) except asyncio.TimeoutError: - log.error(f"Timeout acquiring database connection for clearing starboard entries (Guild: {guild_id})") + log.error( + f"Timeout acquiring database connection for clearing starboard entries (Guild: {guild_id})" + ) return False except Exception as e: - log.exception(f"Database error clearing starboard entries for guild {guild_id}: {e}") + log.exception( + f"Database error clearing starboard entries for guild {guild_id}: {e}" + ) return False + async def add_starboard_reaction(guild_id: int, message_id: int, user_id: int): """Records a user's star reaction to a message.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.error(f"Bot instance or PostgreSQL pool not available in settings_manager for add_starboard_reaction (guild {guild_id}).") + log.error( + f"Bot instance or PostgreSQL pool not available in settings_manager for add_starboard_reaction (guild {guild_id})." + ) return False # Use a timeout to prevent hanging on database operations @@ -635,10 +789,17 @@ async def add_starboard_reaction(guild_id: int, message_id: int, user_id: int): # Ensure guild exists try: - await conn.execute("INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", guild_id) + await conn.execute( + "INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", + guild_id, + ) except Exception as e: - if "another operation is in progress" in str(e) or "attached to a different loop" in str(e): - log.warning(f"Connection issue when inserting guild {guild_id}: {e}") + if "another operation is in progress" in str( + e + ) or "attached to a different loop" in str(e): + log.warning( + f"Connection issue when inserting guild {guild_id}: {e}" + ) # Try to reset the connection await conn.close() conn = await asyncio.wait_for(bot.pg_pool.acquire(), timeout=5.0) @@ -653,11 +814,17 @@ async def add_starboard_reaction(guild_id: int, message_id: int, user_id: int): VALUES ($1, $2, $3) ON CONFLICT (guild_id, message_id, user_id) DO NOTHING; """, - guild_id, message_id, user_id + guild_id, + message_id, + user_id, ) except Exception as e: - if "another operation is in progress" in str(e) or "attached to a different loop" in str(e): - log.warning(f"Connection issue when adding reaction for message {message_id} in guild {guild_id}: {e}") + if "another operation is in progress" in str( + e + ) or "attached to a different loop" in str(e): + log.warning( + f"Connection issue when adding reaction for message {message_id} in guild {guild_id}: {e}" + ) # Try to reset the connection await conn.close() conn = await asyncio.wait_for(bot.pg_pool.acquire(), timeout=5.0) @@ -669,7 +836,9 @@ async def add_starboard_reaction(guild_id: int, message_id: int, user_id: int): VALUES ($1, $2, $3) ON CONFLICT (guild_id, message_id, user_id) DO NOTHING; """, - guild_id, message_id, user_id + guild_id, + message_id, + user_id, ) else: raise @@ -681,12 +850,17 @@ async def add_starboard_reaction(guild_id: int, message_id: int, user_id: int): SELECT COUNT(*) FROM starboard_reactions WHERE guild_id = $1 AND message_id = $2 """, - guild_id, message_id + guild_id, + message_id, ) return count except Exception as e: - if "another operation is in progress" in str(e) or "attached to a different loop" in str(e): - log.warning(f"Connection issue when counting reactions for message {message_id} in guild {guild_id}: {e}") + if "another operation is in progress" in str( + e + ) or "attached to a different loop" in str(e): + log.warning( + f"Connection issue when counting reactions for message {message_id} in guild {guild_id}: {e}" + ) # Try to reset the connection await conn.close() conn = await asyncio.wait_for(bot.pg_pool.acquire(), timeout=5.0) @@ -697,7 +871,8 @@ async def add_starboard_reaction(guild_id: int, message_id: int, user_id: int): SELECT COUNT(*) FROM starboard_reactions WHERE guild_id = $1 AND message_id = $2 """, - guild_id, message_id + guild_id, + message_id, ) return count else: @@ -710,17 +885,24 @@ async def add_starboard_reaction(guild_id: int, message_id: int, user_id: int): except Exception as e: log.warning(f"Error releasing connection: {e}") except asyncio.TimeoutError: - log.error(f"Timeout acquiring database connection for adding starboard reaction (Guild: {guild_id}, Message: {message_id})") + log.error( + f"Timeout acquiring database connection for adding starboard reaction (Guild: {guild_id}, Message: {message_id})" + ) return False except Exception as e: - log.exception(f"Database error adding starboard reaction for message {message_id} in guild {guild_id}: {e}") + log.exception( + f"Database error adding starboard reaction for message {message_id} in guild {guild_id}: {e}" + ) return False + async def remove_starboard_reaction(guild_id: int, message_id: int, user_id: int): """Removes a user's star reaction from a message.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.error(f"Bot instance or PostgreSQL pool not available in settings_manager for remove_starboard_reaction (guild {guild_id}).") + log.error( + f"Bot instance or PostgreSQL pool not available in settings_manager for remove_starboard_reaction (guild {guild_id})." + ) return False # Use a timeout to prevent hanging on database operations @@ -737,11 +919,17 @@ async def remove_starboard_reaction(guild_id: int, message_id: int, user_id: int DELETE FROM starboard_reactions WHERE guild_id = $1 AND message_id = $2 AND user_id = $3 """, - guild_id, message_id, user_id + guild_id, + message_id, + user_id, ) except Exception as e: - if "another operation is in progress" in str(e) or "attached to a different loop" in str(e): - log.warning(f"Connection issue when removing reaction for message {message_id} in guild {guild_id}: {e}") + if "another operation is in progress" in str( + e + ) or "attached to a different loop" in str(e): + log.warning( + f"Connection issue when removing reaction for message {message_id} in guild {guild_id}: {e}" + ) # Try to reset the connection await conn.close() conn = await asyncio.wait_for(bot.pg_pool.acquire(), timeout=5.0) @@ -752,7 +940,9 @@ async def remove_starboard_reaction(guild_id: int, message_id: int, user_id: int DELETE FROM starboard_reactions WHERE guild_id = $1 AND message_id = $2 AND user_id = $3 """, - guild_id, message_id, user_id + guild_id, + message_id, + user_id, ) else: raise @@ -764,12 +954,17 @@ async def remove_starboard_reaction(guild_id: int, message_id: int, user_id: int SELECT COUNT(*) FROM starboard_reactions WHERE guild_id = $1 AND message_id = $2 """, - guild_id, message_id + guild_id, + message_id, ) return count except Exception as e: - if "another operation is in progress" in str(e) or "attached to a different loop" in str(e): - log.warning(f"Connection issue when counting reactions for message {message_id} in guild {guild_id}: {e}") + if "another operation is in progress" in str( + e + ) or "attached to a different loop" in str(e): + log.warning( + f"Connection issue when counting reactions for message {message_id} in guild {guild_id}: {e}" + ) # Try to reset the connection await conn.close() conn = await asyncio.wait_for(bot.pg_pool.acquire(), timeout=5.0) @@ -780,7 +975,8 @@ async def remove_starboard_reaction(guild_id: int, message_id: int, user_id: int SELECT COUNT(*) FROM starboard_reactions WHERE guild_id = $1 AND message_id = $2 """, - guild_id, message_id + guild_id, + message_id, ) return count else: @@ -793,17 +989,24 @@ async def remove_starboard_reaction(guild_id: int, message_id: int, user_id: int except Exception as e: log.warning(f"Error releasing connection: {e}") except asyncio.TimeoutError: - log.error(f"Timeout acquiring database connection for removing starboard reaction (Guild: {guild_id}, Message: {message_id})") + log.error( + f"Timeout acquiring database connection for removing starboard reaction (Guild: {guild_id}, Message: {message_id})" + ) return False except Exception as e: - log.exception(f"Database error removing starboard reaction for message {message_id} in guild {guild_id}: {e}") + log.exception( + f"Database error removing starboard reaction for message {message_id} in guild {guild_id}: {e}" + ) return False + async def get_starboard_reaction_count(guild_id: int, message_id: int): """Gets the count of star reactions for a message.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.warning(f"Bot instance or PostgreSQL pool not available in settings_manager for get_starboard_reaction_count (guild {guild_id}).") + log.warning( + f"Bot instance or PostgreSQL pool not available in settings_manager for get_starboard_reaction_count (guild {guild_id})." + ) return 0 try: @@ -813,19 +1016,25 @@ async def get_starboard_reaction_count(guild_id: int, message_id: int): SELECT COUNT(*) FROM starboard_reactions WHERE guild_id = $1 AND message_id = $2 """, - guild_id, message_id + guild_id, + message_id, ) return count except Exception as e: - log.exception(f"Database error getting starboard reaction count for message {message_id} in guild {guild_id}: {e}") + log.exception( + f"Database error getting starboard reaction count for message {message_id} in guild {guild_id}: {e}" + ) return 0 + async def has_user_reacted(guild_id: int, message_id: int, user_id: int): """Checks if a user has already reacted to a message.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.warning(f"Bot instance or PostgreSQL pool not available in settings_manager for has_user_reacted (guild {guild_id}).") + log.warning( + f"Bot instance or PostgreSQL pool not available in settings_manager for has_user_reacted (guild {guild_id})." + ) return False try: @@ -837,12 +1046,16 @@ async def has_user_reacted(guild_id: int, message_id: int, user_id: int): WHERE guild_id = $1 AND message_id = $2 AND user_id = $3 ) """, - guild_id, message_id, user_id + guild_id, + message_id, + user_id, ) return result except Exception as e: - log.exception(f"Database error checking if user {user_id} reacted to message {message_id} in guild {guild_id}: {e}") + log.exception( + f"Database error checking if user {user_id} reacted to message {message_id} in guild {guild_id}: {e}" + ) return False @@ -853,13 +1066,17 @@ def _get_redis_key(guild_id: int, key_type: str, identifier: str = None) -> str: return f"guild:{guild_id}:{key_type}:{identifier}" return f"guild:{guild_id}:{key_type}" + # --- Settings Access Functions (Placeholders with Cache Logic) --- + async def get_guild_prefix(guild_id: int, default_prefix: str) -> str: """Gets the command prefix for a guild, checking cache first.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.warning(f"Bot instance or pools not available in settings_manager for get_guild_prefix (guild {guild_id}).") + log.warning( + f"Bot instance or pools not available in settings_manager for get_guild_prefix (guild {guild_id})." + ) return default_prefix cache_key = _get_redis_key(guild_id, "prefix") @@ -872,10 +1089,14 @@ async def get_guild_prefix(guild_id: int, default_prefix: str) -> str: log.debug(f"Cache hit for prefix (Guild: {guild_id})") return cached_prefix except asyncio.TimeoutError: - log.warning(f"Redis timeout getting prefix for guild {guild_id}, falling back to database") + log.warning( + f"Redis timeout getting prefix for guild {guild_id}, falling back to database" + ) except RuntimeError as e: if "got Future" in str(e) and "attached to a different loop" in str(e): - log.warning(f"Redis event loop error for guild {guild_id}, falling back to database: {e}") + log.warning( + f"Redis event loop error for guild {guild_id}, falling back to database: {e}" + ) else: log.exception(f"Redis error getting prefix for guild {guild_id}: {e}") except Exception as e: @@ -887,7 +1108,7 @@ async def get_guild_prefix(guild_id: int, default_prefix: str) -> str: async with bot.pg_pool.acquire() as conn: prefix = await conn.fetchval( "SELECT setting_value FROM guild_settings WHERE guild_id = $1 AND setting_key = 'prefix'", - guild_id + guild_id, ) final_prefix = prefix if prefix is not None else default_prefix @@ -897,13 +1118,15 @@ async def get_guild_prefix(guild_id: int, default_prefix: str) -> str: # Use a timeout to prevent hanging on Redis operations await asyncio.wait_for( bot.redis.set(cache_key, final_prefix, ex=3600), # Cache for 1 hour - timeout=2.0 + timeout=2.0, ) except asyncio.TimeoutError: log.warning(f"Redis timeout setting prefix for guild {guild_id}") except RuntimeError as e: if "got Future" in str(e) and "attached to a different loop" in str(e): - log.warning(f"Redis event loop error setting prefix for guild {guild_id}: {e}") + log.warning( + f"Redis event loop error setting prefix for guild {guild_id}: {e}" + ) else: log.exception(f"Redis error setting prefix for guild {guild_id}: {e}") except Exception as e: @@ -914,18 +1137,24 @@ async def get_guild_prefix(guild_id: int, default_prefix: str) -> str: log.exception(f"Database error getting prefix for guild {guild_id}: {e}") return default_prefix # Fall back to default on database error + async def set_guild_prefix(guild_id: int, prefix: str): """Sets the command prefix for a guild and updates the cache.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.error(f"Bot instance or pools not available in settings_manager for set_guild_prefix (guild {guild_id}).") - return False # Indicate failure + log.error( + f"Bot instance or pools not available in settings_manager for set_guild_prefix (guild {guild_id})." + ) + return False # Indicate failure cache_key = _get_redis_key(guild_id, "prefix") try: async with bot.pg_pool.acquire() as conn: # Ensure guild exists - await conn.execute("INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", guild_id) + await conn.execute( + "INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", + guild_id, + ) # Upsert the setting await conn.execute( """ @@ -933,29 +1162,38 @@ async def set_guild_prefix(guild_id: int, prefix: str): VALUES ($1, 'prefix', $2) ON CONFLICT (guild_id, setting_key) DO UPDATE SET setting_value = $2; """, - guild_id, prefix + guild_id, + prefix, ) # Update cache - await bot.redis.set(cache_key, prefix, ex=3600) # Cache for 1 hour + await bot.redis.set(cache_key, prefix, ex=3600) # Cache for 1 hour log.info(f"Set prefix for guild {guild_id} to '{prefix}'") - return True # Indicate success + return True # Indicate success except Exception as e: - log.exception(f"Database or Redis error setting prefix for guild {guild_id}: {e}") + log.exception( + f"Database or Redis error setting prefix for guild {guild_id}: {e}" + ) # Attempt to invalidate cache on error to prevent stale data try: await bot.redis.delete(cache_key) except Exception as redis_err: - log.exception(f"Failed to invalidate Redis cache for prefix (Guild: {guild_id}): {redis_err}") - return False # Indicate failure + log.exception( + f"Failed to invalidate Redis cache for prefix (Guild: {guild_id}): {redis_err}" + ) + return False # Indicate failure + # --- Generic Settings Functions --- + async def get_setting(guild_id: int, key: str, default=None): """Gets a specific setting for a guild, checking cache first.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.warning(f"Bot instance or pools not available in settings_manager for get_setting (guild {guild_id}, key '{key}').") + log.warning( + f"Bot instance or pools not available in settings_manager for get_setting (guild {guild_id}, key '{key}')." + ) return default cache_key = _get_redis_key(guild_id, "setting", key) @@ -972,12 +1210,18 @@ async def get_setting(guild_id: int, key: str, default=None): return default return cached_value except asyncio.TimeoutError: - log.warning(f"Redis timeout getting setting '{key}' for guild {guild_id}, falling back to database") + log.warning( + f"Redis timeout getting setting '{key}' for guild {guild_id}, falling back to database" + ) except RuntimeError as e: if "got Future" in str(e) and "attached to a different loop" in str(e): - log.warning(f"Redis event loop error for guild {guild_id}, falling back to database: {e}") + log.warning( + f"Redis event loop error for guild {guild_id}, falling back to database: {e}" + ) else: - log.exception(f"Redis error getting setting '{key}' for guild {guild_id}: {e}") + log.exception( + f"Redis error getting setting '{key}' for guild {guild_id}: {e}" + ) except Exception as e: log.exception(f"Redis error getting setting '{key}' for guild {guild_id}: {e}") @@ -987,31 +1231,44 @@ async def get_setting(guild_id: int, key: str, default=None): async with bot.pg_pool.acquire() as conn: value = await conn.fetchval( "SELECT setting_value FROM guild_settings WHERE guild_id = $1 AND setting_key = $2", - guild_id, key + guild_id, + key, ) final_value = value if value is not None else default except Exception as e: - log.exception(f"Database error getting setting '{key}' for guild {guild_id}: {e}") + log.exception( + f"Database error getting setting '{key}' for guild {guild_id}: {e}" + ) return default # Fall back to default on database error # Cache the result (even if None or default, cache the absence or default value) - value_to_cache = final_value if final_value is not None else "__NONE__" # Marker for None - if bot.redis: # Ensure redis is available before trying to cache + value_to_cache = ( + final_value if final_value is not None else "__NONE__" + ) # Marker for None + if bot.redis: # Ensure redis is available before trying to cache try: # Use a timeout to prevent hanging on Redis operations await asyncio.wait_for( bot.redis.set(cache_key, value_to_cache, ex=3600), # Cache for 1 hour - timeout=2.0 + timeout=2.0, ) except asyncio.TimeoutError: - log.warning(f"Redis timeout setting cache for setting '{key}' for guild {guild_id}") + log.warning( + f"Redis timeout setting cache for setting '{key}' for guild {guild_id}" + ) except RuntimeError as e: if "got Future" in str(e) and "attached to a different loop" in str(e): - log.warning(f"Redis event loop error setting cache for setting '{key}' for guild {guild_id}: {e}") + log.warning( + f"Redis event loop error setting cache for setting '{key}' for guild {guild_id}: {e}" + ) else: - log.exception(f"Redis error setting cache for setting '{key}' for guild {guild_id}: {e}") + log.exception( + f"Redis error setting cache for setting '{key}' for guild {guild_id}: {e}" + ) except Exception as e: - log.exception(f"Redis error setting cache for setting '{key}' for guild {guild_id}: {e}") + log.exception( + f"Redis error setting cache for setting '{key}' for guild {guild_id}: {e}" + ) # This block was duplicated, removed the second instance of caching logic. return final_value @@ -1019,17 +1276,22 @@ async def get_setting(guild_id: int, key: str, default=None): async def set_setting(guild_id: int, key: str, value: str | None): """Sets a specific setting for a guild and updates/invalidates the cache. - Setting value to None effectively deletes the setting.""" + Setting value to None effectively deletes the setting.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.error(f"Bot instance or pools not available in settings_manager for set_setting (guild {guild_id}, key '{key}').") - return False # Indicate failure + log.error( + f"Bot instance or pools not available in settings_manager for set_setting (guild {guild_id}, key '{key}')." + ) + return False # Indicate failure cache_key = _get_redis_key(guild_id, "setting", key) try: async with bot.pg_pool.acquire() as conn: # Ensure guild exists - await conn.execute("INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", guild_id) + await conn.execute( + "INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", + guild_id, + ) if value is not None: # Upsert the setting @@ -1039,7 +1301,9 @@ async def set_setting(guild_id: int, key: str, value: str | None): VALUES ($1, $2, $3) ON CONFLICT (guild_id, setting_key) DO UPDATE SET setting_value = $3; """, - guild_id, key, str(value) # Ensure value is string + guild_id, + key, + str(value), # Ensure value is string ) # Update cache await bot.redis.set(cache_key, str(value), ex=3600) @@ -1048,30 +1312,41 @@ async def set_setting(guild_id: int, key: str, value: str | None): # Delete the setting if value is None await conn.execute( "DELETE FROM guild_settings WHERE guild_id = $1 AND setting_key = $2", - guild_id, key + guild_id, + key, ) # Invalidate cache await bot.redis.delete(cache_key) log.info(f"Deleted setting '{key}' for guild {guild_id}") return True except Exception as e: - log.exception(f"Database or Redis error setting setting '{key}' for guild {guild_id}: {e}") + log.exception( + f"Database or Redis error setting setting '{key}' for guild {guild_id}: {e}" + ) # Attempt to invalidate cache on error if bot.redis: try: await bot.redis.delete(cache_key) except Exception as redis_err: - log.exception(f"Failed to invalidate Redis cache for setting '{key}' (Guild: {guild_id}): {redis_err}") + log.exception( + f"Failed to invalidate Redis cache for setting '{key}' (Guild: {guild_id}): {redis_err}" + ) return False + # --- Cog Enablement Functions --- -async def is_cog_enabled(guild_id: int, cog_name: str, default_enabled: bool = True) -> bool: + +async def is_cog_enabled( + guild_id: int, cog_name: str, default_enabled: bool = True +) -> bool: """Checks if a cog is enabled for a guild, checking cache first. - Uses default_enabled if no specific setting is found.""" + Uses default_enabled if no specific setting is found.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.warning(f"Bot instance or pools not available in settings_manager for is_cog_enabled (guild {guild_id}, cog '{cog_name}').") + log.warning( + f"Bot instance or pools not available in settings_manager for is_cog_enabled (guild {guild_id}, cog '{cog_name}')." + ) return default_enabled cache_key = _get_redis_key(guild_id, "cog_enabled", cog_name) @@ -1081,17 +1356,27 @@ async def is_cog_enabled(guild_id: int, cog_name: str, default_enabled: bool = T # Use a timeout to prevent hanging on Redis operations cached_value = await asyncio.wait_for(bot.redis.get(cache_key), timeout=2.0) if cached_value is not None: - log.debug(f"Cache hit for cog enabled status '{cog_name}' (Guild: {guild_id})") - return cached_value == "True" # Redis stores strings + log.debug( + f"Cache hit for cog enabled status '{cog_name}' (Guild: {guild_id})" + ) + return cached_value == "True" # Redis stores strings except asyncio.TimeoutError: - log.warning(f"Redis timeout getting cog enabled status for '{cog_name}' (Guild: {guild_id}), falling back to database") + log.warning( + f"Redis timeout getting cog enabled status for '{cog_name}' (Guild: {guild_id}), falling back to database" + ) except RuntimeError as e: if "got Future" in str(e) and "attached to a different loop" in str(e): - log.warning(f"Redis event loop error for guild {guild_id}, falling back to database: {e}") + log.warning( + f"Redis event loop error for guild {guild_id}, falling back to database: {e}" + ) else: - log.exception(f"Redis error getting cog enabled status for '{cog_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Redis error getting cog enabled status for '{cog_name}' (Guild: {guild_id}): {e}" + ) except Exception as e: - log.exception(f"Redis error getting cog enabled status for '{cog_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Redis error getting cog enabled status for '{cog_name}' (Guild: {guild_id}): {e}" + ) # Cache miss or Redis error, get from database log.debug(f"Cache miss for cog enabled status '{cog_name}' (Guild: {guild_id})") @@ -1100,32 +1385,47 @@ async def is_cog_enabled(guild_id: int, cog_name: str, default_enabled: bool = T async with bot.pg_pool.acquire() as conn: db_enabled_status = await conn.fetchval( "SELECT enabled FROM enabled_cogs WHERE guild_id = $1 AND cog_name = $2", - guild_id, cog_name + guild_id, + cog_name, ) - final_status = db_enabled_status if db_enabled_status is not None else default_enabled + final_status = ( + db_enabled_status if db_enabled_status is not None else default_enabled + ) # Try to cache the result with timeout and error handling if bot.redis: try: # Use a timeout to prevent hanging on Redis operations await asyncio.wait_for( - bot.redis.set(cache_key, str(final_status), ex=3600), # Cache for 1 hour - timeout=2.0 + bot.redis.set( + cache_key, str(final_status), ex=3600 + ), # Cache for 1 hour + timeout=2.0, ) except asyncio.TimeoutError: - log.warning(f"Redis timeout setting cache for cog enabled status '{cog_name}' (Guild: {guild_id})") + log.warning( + f"Redis timeout setting cache for cog enabled status '{cog_name}' (Guild: {guild_id})" + ) except RuntimeError as e: if "got Future" in str(e) and "attached to a different loop" in str(e): - log.warning(f"Redis event loop error setting cache for cog enabled status '{cog_name}' (Guild: {guild_id}): {e}") + log.warning( + f"Redis event loop error setting cache for cog enabled status '{cog_name}' (Guild: {guild_id}): {e}" + ) else: - log.exception(f"Redis error setting cache for cog enabled status '{cog_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Redis error setting cache for cog enabled status '{cog_name}' (Guild: {guild_id}): {e}" + ) except Exception as e: - log.exception(f"Redis error setting cache for cog enabled status '{cog_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Redis error setting cache for cog enabled status '{cog_name}' (Guild: {guild_id}): {e}" + ) return final_status except Exception as e: - log.exception(f"Database error getting cog enabled status for '{cog_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Database error getting cog enabled status for '{cog_name}' (Guild: {guild_id}): {e}" + ) # Fallback to default on DB error after cache miss return default_enabled @@ -1134,14 +1434,19 @@ async def set_cog_enabled(guild_id: int, cog_name: str, enabled: bool): """Sets the enabled status for a cog in a guild and updates the cache.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.error(f"Bot instance or pools not available in settings_manager for set_cog_enabled (guild {guild_id}, cog '{cog_name}').") + log.error( + f"Bot instance or pools not available in settings_manager for set_cog_enabled (guild {guild_id}, cog '{cog_name}')." + ) return False cache_key = _get_redis_key(guild_id, "cog_enabled", cog_name) try: async with bot.pg_pool.acquire() as conn: # Ensure guild exists - await conn.execute("INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", guild_id) + await conn.execute( + "INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", + guild_id, + ) # Upsert the enabled status await conn.execute( """ @@ -1149,30 +1454,42 @@ async def set_cog_enabled(guild_id: int, cog_name: str, enabled: bool): VALUES ($1, $2, $3) ON CONFLICT (guild_id, cog_name) DO UPDATE SET enabled = $3; """, - guild_id, cog_name, enabled + guild_id, + cog_name, + enabled, ) # Update cache await bot.redis.set(cache_key, str(enabled), ex=3600) - log.info(f"Set cog '{cog_name}' enabled status to {enabled} for guild {guild_id}") + log.info( + f"Set cog '{cog_name}' enabled status to {enabled} for guild {guild_id}" + ) return True except Exception as e: - log.exception(f"Database or Redis error setting cog enabled status for '{cog_name}' in guild {guild_id}: {e}") + log.exception( + f"Database or Redis error setting cog enabled status for '{cog_name}' in guild {guild_id}: {e}" + ) # Attempt to invalidate cache on error if bot.redis: try: await bot.redis.delete(cache_key) except Exception as redis_err: - log.exception(f"Failed to invalidate Redis cache for cog enabled status '{cog_name}' (Guild: {guild_id}): {redis_err}") + log.exception( + f"Failed to invalidate Redis cache for cog enabled status '{cog_name}' (Guild: {guild_id}): {redis_err}" + ) return False -async def is_command_enabled(guild_id: int, command_name: str, default_enabled: bool = True) -> bool: +async def is_command_enabled( + guild_id: int, command_name: str, default_enabled: bool = True +) -> bool: """Checks if a command is enabled for a guild, checking cache first. - Uses default_enabled if no specific setting is found.""" + Uses default_enabled if no specific setting is found.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.warning(f"Bot instance or pools not available in settings_manager for is_command_enabled (guild {guild_id}, command '{command_name}').") + log.warning( + f"Bot instance or pools not available in settings_manager for is_command_enabled (guild {guild_id}, command '{command_name}')." + ) return default_enabled cache_key = _get_redis_key(guild_id, "cmd_enabled", command_name) @@ -1182,51 +1499,78 @@ async def is_command_enabled(guild_id: int, command_name: str, default_enabled: # Use a timeout to prevent hanging on Redis operations cached_value = await asyncio.wait_for(bot.redis.get(cache_key), timeout=2.0) if cached_value is not None: - log.debug(f"Cache hit for command enabled status '{command_name}' (Guild: {guild_id})") - return cached_value == "True" # Redis stores strings + log.debug( + f"Cache hit for command enabled status '{command_name}' (Guild: {guild_id})" + ) + return cached_value == "True" # Redis stores strings except asyncio.TimeoutError: - log.warning(f"Redis timeout getting command enabled status for '{command_name}' (Guild: {guild_id}), falling back to database") + log.warning( + f"Redis timeout getting command enabled status for '{command_name}' (Guild: {guild_id}), falling back to database" + ) except RuntimeError as e: if "got Future" in str(e) and "attached to a different loop" in str(e): - log.warning(f"Redis event loop error for guild {guild_id}, falling back to database: {e}") + log.warning( + f"Redis event loop error for guild {guild_id}, falling back to database: {e}" + ) else: - log.exception(f"Redis error getting command enabled status for '{command_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Redis error getting command enabled status for '{command_name}' (Guild: {guild_id}): {e}" + ) except Exception as e: - log.exception(f"Redis error getting command enabled status for '{command_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Redis error getting command enabled status for '{command_name}' (Guild: {guild_id}): {e}" + ) # Cache miss or Redis error, get from database - log.debug(f"Cache miss for command enabled status '{command_name}' (Guild: {guild_id})") + log.debug( + f"Cache miss for command enabled status '{command_name}' (Guild: {guild_id})" + ) db_enabled_status = None try: async with bot.pg_pool.acquire() as conn: db_enabled_status = await conn.fetchval( "SELECT enabled FROM enabled_commands WHERE guild_id = $1 AND command_name = $2", - guild_id, command_name + guild_id, + command_name, ) - final_status = db_enabled_status if db_enabled_status is not None else default_enabled + final_status = ( + db_enabled_status if db_enabled_status is not None else default_enabled + ) # Try to cache the result with timeout and error handling if bot.redis: try: # Use a timeout to prevent hanging on Redis operations await asyncio.wait_for( - bot.redis.set(cache_key, str(final_status), ex=3600), # Cache for 1 hour - timeout=2.0 + bot.redis.set( + cache_key, str(final_status), ex=3600 + ), # Cache for 1 hour + timeout=2.0, ) except asyncio.TimeoutError: - log.warning(f"Redis timeout setting cache for command enabled status '{command_name}' (Guild: {guild_id})") + log.warning( + f"Redis timeout setting cache for command enabled status '{command_name}' (Guild: {guild_id})" + ) except RuntimeError as e: if "got Future" in str(e) and "attached to a different loop" in str(e): - log.warning(f"Redis event loop error setting cache for command enabled status '{command_name}' (Guild: {guild_id}): {e}") + log.warning( + f"Redis event loop error setting cache for command enabled status '{command_name}' (Guild: {guild_id}): {e}" + ) else: - log.exception(f"Redis error setting cache for command enabled status '{command_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Redis error setting cache for command enabled status '{command_name}' (Guild: {guild_id}): {e}" + ) except Exception as e: - log.exception(f"Redis error setting cache for command enabled status '{command_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Redis error setting cache for command enabled status '{command_name}' (Guild: {guild_id}): {e}" + ) return final_status except Exception as e: - log.exception(f"Database error getting command enabled status for '{command_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Database error getting command enabled status for '{command_name}' (Guild: {guild_id}): {e}" + ) # Fallback to default on DB error after cache miss return default_enabled @@ -1235,14 +1579,19 @@ async def set_command_enabled(guild_id: int, command_name: str, enabled: bool): """Sets the enabled status for a command in a guild and updates the cache.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.error(f"Bot instance or pools not available in settings_manager for set_command_enabled (guild {guild_id}, command '{command_name}').") + log.error( + f"Bot instance or pools not available in settings_manager for set_command_enabled (guild {guild_id}, command '{command_name}')." + ) return False cache_key = _get_redis_key(guild_id, "cmd_enabled", command_name) try: async with bot.pg_pool.acquire() as conn: # Ensure guild exists - await conn.execute("INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", guild_id) + await conn.execute( + "INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", + guild_id, + ) # Upsert the enabled status await conn.execute( """ @@ -1250,77 +1599,102 @@ async def set_command_enabled(guild_id: int, command_name: str, enabled: bool): VALUES ($1, $2, $3) ON CONFLICT (guild_id, command_name) DO UPDATE SET enabled = $3; """, - guild_id, command_name, enabled + guild_id, + command_name, + enabled, ) # Update cache await bot.redis.set(cache_key, str(enabled), ex=3600) - log.info(f"Set command '{command_name}' enabled status to {enabled} for guild {guild_id}") + log.info( + f"Set command '{command_name}' enabled status to {enabled} for guild {guild_id}" + ) return True except Exception as e: - log.exception(f"Database or Redis error setting command enabled status for '{command_name}' in guild {guild_id}: {e}") + log.exception( + f"Database or Redis error setting command enabled status for '{command_name}' in guild {guild_id}: {e}" + ) # Attempt to invalidate cache on error if bot.redis: try: await bot.redis.delete(cache_key) except Exception as redis_err: - log.exception(f"Failed to invalidate Redis cache for command enabled status '{command_name}' (Guild: {guild_id}): {redis_err}") + log.exception( + f"Failed to invalidate Redis cache for command enabled status '{command_name}' (Guild: {guild_id}): {redis_err}" + ) return False async def get_all_enabled_commands(guild_id: int) -> Dict[str, bool]: """Gets all command enabled statuses for a guild. - Returns a dictionary of command_name -> enabled status.""" + Returns a dictionary of command_name -> enabled status.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.error(f"Bot instance or PostgreSQL pool not available in settings_manager for get_all_enabled_commands (guild {guild_id}).") + log.error( + f"Bot instance or PostgreSQL pool not available in settings_manager for get_all_enabled_commands (guild {guild_id})." + ) return {} try: async with bot.pg_pool.acquire() as conn: records = await conn.fetch( "SELECT command_name, enabled FROM enabled_commands WHERE guild_id = $1", - guild_id + guild_id, ) - return {record['command_name']: record['enabled'] for record in records} + return {record["command_name"]: record["enabled"] for record in records} except Exception as e: - log.exception(f"Database error getting command enabled statuses for guild {guild_id}: {e}") + log.exception( + f"Database error getting command enabled statuses for guild {guild_id}: {e}" + ) return {} async def get_all_enabled_cogs(guild_id: int) -> Dict[str, bool]: """Gets all cog enabled statuses for a guild. - Returns a dictionary of cog_name -> enabled status.""" + Returns a dictionary of cog_name -> enabled status.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.error(f"Bot instance or PostgreSQL pool not available in settings_manager for get_all_enabled_cogs (guild {guild_id}).") + log.error( + f"Bot instance or PostgreSQL pool not available in settings_manager for get_all_enabled_cogs (guild {guild_id})." + ) return {} try: async with bot.pg_pool.acquire() as conn: records = await conn.fetch( "SELECT cog_name, enabled FROM enabled_cogs WHERE guild_id = $1", - guild_id + guild_id, ) - return {record['cog_name']: record['enabled'] for record in records} + return {record["cog_name"]: record["enabled"] for record in records} except Exception as e: - log.exception(f"Database error getting cog enabled statuses for guild {guild_id}: {e}") + log.exception( + f"Database error getting cog enabled statuses for guild {guild_id}: {e}" + ) return {} + # --- Command Permission Functions --- -async def add_command_permission(guild_id: int, command_name: str, role_id: int) -> bool: + +async def add_command_permission( + guild_id: int, command_name: str, role_id: int +) -> bool: """Adds permission for a role to use a command and invalidates cache.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.error(f"Bot instance or pools not available in settings_manager for add_command_permission (guild {guild_id}, command '{command_name}').") + log.error( + f"Bot instance or pools not available in settings_manager for add_command_permission (guild {guild_id}, command '{command_name}')." + ) return False cache_key = _get_redis_key(guild_id, "cmd_perms", command_name) try: async with bot.pg_pool.acquire() as conn: # Ensure guild exists - await conn.execute("INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", guild_id) + await conn.execute( + "INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", + guild_id, + ) # Add the permission rule await conn.execute( """ @@ -1328,29 +1702,41 @@ async def add_command_permission(guild_id: int, command_name: str, role_id: int) VALUES ($1, $2, $3) ON CONFLICT (guild_id, command_name, allowed_role_id) DO NOTHING; """, - guild_id, command_name, role_id + guild_id, + command_name, + role_id, ) # Invalidate cache after DB operation succeeds await bot.redis.delete(cache_key) - log.info(f"Added permission for role {role_id} to use command '{command_name}' in guild {guild_id}") + log.info( + f"Added permission for role {role_id} to use command '{command_name}' in guild {guild_id}" + ) return True except Exception as e: - log.exception(f"Database or Redis error adding permission for command '{command_name}' in guild {guild_id}: {e}") + log.exception( + f"Database or Redis error adding permission for command '{command_name}' in guild {guild_id}: {e}" + ) # Attempt to invalidate cache even on error if bot.redis: try: await bot.redis.delete(cache_key) except Exception as redis_err: - log.exception(f"Failed to invalidate Redis cache for command permissions '{command_name}' (Guild: {guild_id}): {redis_err}") + log.exception( + f"Failed to invalidate Redis cache for command permissions '{command_name}' (Guild: {guild_id}): {redis_err}" + ) return False -async def remove_command_permission(guild_id: int, command_name: str, role_id: int) -> bool: +async def remove_command_permission( + guild_id: int, command_name: str, role_id: int +) -> bool: """Removes permission for a role to use a command and invalidates cache.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.error(f"Bot instance or pools not available in settings_manager for remove_command_permission (guild {guild_id}, command '{command_name}').") + log.error( + f"Bot instance or pools not available in settings_manager for remove_command_permission (guild {guild_id}, command '{command_name}')." + ) return False cache_key = _get_redis_key(guild_id, "cmd_perms", command_name) @@ -1364,33 +1750,45 @@ async def remove_command_permission(guild_id: int, command_name: str, role_id: i DELETE FROM command_permissions WHERE guild_id = $1 AND command_name = $2 AND allowed_role_id = $3; """, - guild_id, command_name, role_id + guild_id, + command_name, + role_id, ) # Invalidate cache after DB operation succeeds await bot.redis.delete(cache_key) - log.info(f"Removed permission for role {role_id} to use command '{command_name}' in guild {guild_id}") + log.info( + f"Removed permission for role {role_id} to use command '{command_name}' in guild {guild_id}" + ) return True except Exception as e: - log.exception(f"Database or Redis error removing permission for command '{command_name}' in guild {guild_id}: {e}") + log.exception( + f"Database or Redis error removing permission for command '{command_name}' in guild {guild_id}: {e}" + ) # Attempt to invalidate cache even on error if bot.redis: try: await bot.redis.delete(cache_key) except Exception as redis_err: - log.exception(f"Failed to invalidate Redis cache for command permissions '{command_name}' (Guild: {guild_id}): {redis_err}") + log.exception( + f"Failed to invalidate Redis cache for command permissions '{command_name}' (Guild: {guild_id}): {redis_err}" + ) return False -async def check_command_permission(guild_id: int, command_name: str, member_roles_ids: list[int]) -> bool: +async def check_command_permission( + guild_id: int, command_name: str, member_roles_ids: list[int] +) -> bool: """Checks if any of the member's roles have permission for the command. - Returns True if allowed, False otherwise. - If no permissions are set for the command in the DB, it defaults to allowed by this check. + Returns True if allowed, False otherwise. + If no permissions are set for the command in the DB, it defaults to allowed by this check. """ bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.warning(f"Bot instance or pools not available in settings_manager for check_command_permission (guild {guild_id}, command '{command_name}').") - return True # Default to allowed if system isn't ready + log.warning( + f"Bot instance or pools not available in settings_manager for check_command_permission (guild {guild_id}, command '{command_name}')." + ) + return True # Default to allowed if system isn't ready cache_key = _get_redis_key(guild_id, "cmd_perms", command_name) allowed_role_ids_str = set() @@ -1401,8 +1799,10 @@ async def check_command_permission(guild_id: int, command_name: str, member_role cached_roles = await bot.redis.smembers(cache_key) # Handle the empty set marker if cached_roles == {"__EMPTY_SET__"}: - log.debug(f"Cache hit (empty set) for cmd perms '{command_name}' (Guild: {guild_id}). Command allowed by default.") - return True # No specific restrictions found + log.debug( + f"Cache hit (empty set) for cmd perms '{command_name}' (Guild: {guild_id}). Command allowed by default." + ) + return True # No specific restrictions found allowed_role_ids_str = cached_roles log.debug(f"Cache hit for cmd perms '{command_name}' (Guild: {guild_id})") else: @@ -1411,28 +1811,37 @@ async def check_command_permission(guild_id: int, command_name: str, member_role async with bot.pg_pool.acquire() as conn: records = await conn.fetch( "SELECT allowed_role_id FROM command_permissions WHERE guild_id = $1 AND command_name = $2", - guild_id, command_name + guild_id, + command_name, ) # Convert fetched role IDs (BIGINT) to strings for Redis set - allowed_role_ids_str = {str(record['allowed_role_id']) for record in records} + allowed_role_ids_str = { + str(record["allowed_role_id"]) for record in records + } # Cache the result (even if empty) if bot.redis: try: async with bot.redis.pipeline(transaction=True) as pipe: - pipe.delete(cache_key) # Ensure clean state + pipe.delete(cache_key) # Ensure clean state if allowed_role_ids_str: pipe.sadd(cache_key, *allowed_role_ids_str) else: - pipe.sadd(cache_key, "__EMPTY_SET__") # Marker for empty set - pipe.expire(cache_key, 3600) # Cache for 1 hour + pipe.sadd( + cache_key, "__EMPTY_SET__" + ) # Marker for empty set + pipe.expire(cache_key, 3600) # Cache for 1 hour await pipe.execute() except Exception as e: - log.exception(f"Redis error setting cache for cmd perms '{command_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Redis error setting cache for cmd perms '{command_name}' (Guild: {guild_id}): {e}" + ) except Exception as e: - log.exception(f"Error checking command permission for '{command_name}' (Guild: {guild_id}): {e}") - return True # Default to allowed on error + log.exception( + f"Error checking command permission for '{command_name}' (Guild: {guild_id}): {e}" + ) + return True # Default to allowed on error # --- Permission Check Logic --- if not allowed_role_ids_str or allowed_role_ids_str == {"__EMPTY_SET__"}: @@ -1443,18 +1852,24 @@ async def check_command_permission(guild_id: int, command_name: str, member_role # Check if any of the member's roles intersect with the allowed roles member_roles_ids_str = {str(role_id) for role_id in member_roles_ids} if member_roles_ids_str.intersection(allowed_role_ids_str): - log.debug(f"Permission granted for '{command_name}' (Guild: {guild_id}) via role intersection.") - return True # Member has at least one allowed role + log.debug( + f"Permission granted for '{command_name}' (Guild: {guild_id}) via role intersection." + ) + return True # Member has at least one allowed role else: - log.debug(f"Permission denied for '{command_name}' (Guild: {guild_id}). Member roles {member_roles_ids_str} not in allowed roles {allowed_role_ids_str}.") - return False # Member has none of the specifically allowed roles + log.debug( + f"Permission denied for '{command_name}' (Guild: {guild_id}). Member roles {member_roles_ids_str} not in allowed roles {allowed_role_ids_str}." + ) + return False # Member has none of the specifically allowed roles async def get_command_permissions(guild_id: int, command_name: str) -> set[int] | None: """Gets the set of allowed role IDs for a specific command, checking cache first. Returns None on error.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.warning(f"Bot instance or pools not available in settings_manager for get_command_permissions (guild {guild_id}, command '{command_name}').") + log.warning( + f"Bot instance or pools not available in settings_manager for get_command_permissions (guild {guild_id}, command '{command_name}')." + ) return None cache_key = _get_redis_key(guild_id, "cmd_perms", command_name) @@ -1463,13 +1878,17 @@ async def get_command_permissions(guild_id: int, command_name: str) -> set[int] if await bot.redis.exists(cache_key): cached_roles_str = await bot.redis.smembers(cache_key) if cached_roles_str == {"__EMPTY_SET__"}: - log.debug(f"Cache hit (empty set) for cmd perms '{command_name}' (Guild: {guild_id}).") - return set() # Return empty set if explicitly empty + log.debug( + f"Cache hit (empty set) for cmd perms '{command_name}' (Guild: {guild_id})." + ) + return set() # Return empty set if explicitly empty allowed_role_ids = {int(role_id) for role_id in cached_roles_str} log.debug(f"Cache hit for cmd perms '{command_name}' (Guild: {guild_id})") return allowed_role_ids except Exception as e: - log.exception(f"Redis error getting cmd perms for '{command_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Redis error getting cmd perms for '{command_name}' (Guild: {guild_id}): {e}" + ) # Fall through to DB query on Redis error log.debug(f"Cache miss for cmd perms '{command_name}' (Guild: {guild_id})") @@ -1477,46 +1896,59 @@ async def get_command_permissions(guild_id: int, command_name: str) -> set[int] async with bot.pg_pool.acquire() as conn: records = await conn.fetch( "SELECT allowed_role_id FROM command_permissions WHERE guild_id = $1 AND command_name = $2", - guild_id, command_name + guild_id, + command_name, ) - allowed_role_ids = {record['allowed_role_id'] for record in records} + allowed_role_ids = {record["allowed_role_id"] for record in records} # Cache the result if bot.redis: try: allowed_role_ids_str = {str(role_id) for role_id in allowed_role_ids} async with bot.redis.pipeline(transaction=True) as pipe: - pipe.delete(cache_key) # Ensure clean state + pipe.delete(cache_key) # Ensure clean state if allowed_role_ids_str: pipe.sadd(cache_key, *allowed_role_ids_str) else: - pipe.sadd(cache_key, "__EMPTY_SET__") # Marker for empty set - pipe.expire(cache_key, 3600) # Cache for 1 hour + pipe.sadd(cache_key, "__EMPTY_SET__") # Marker for empty set + pipe.expire(cache_key, 3600) # Cache for 1 hour await pipe.execute() except Exception as e: - log.exception(f"Redis error setting cache for cmd perms '{command_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Redis error setting cache for cmd perms '{command_name}' (Guild: {guild_id}): {e}" + ) return allowed_role_ids except Exception as e: - log.exception(f"Database error getting cmd perms for '{command_name}' (Guild: {guild_id}): {e}") - return None # Indicate error + log.exception( + f"Database error getting cmd perms for '{command_name}' (Guild: {guild_id}): {e}" + ) + return None # Indicate error # --- Logging Webhook Functions --- + async def get_logging_webhook(guild_id: int) -> str | None: """Gets the logging webhook URL for a guild. Returns None if not set or on error.""" log.debug(f"Attempting to get logging webhook for guild {guild_id}") - webhook_url = await get_setting(guild_id, 'logging_webhook_url', default=None) - log.debug(f"Retrieved logging webhook URL for guild {guild_id}: {'Set' if webhook_url else 'Not Set'}") + webhook_url = await get_setting(guild_id, "logging_webhook_url", default=None) + log.debug( + f"Retrieved logging webhook URL for guild {guild_id}: {'Set' if webhook_url else 'Not Set'}" + ) return webhook_url + async def set_logging_webhook(guild_id: int, webhook_url: str | None) -> bool: """Sets or removes the logging webhook URL for a guild.""" - log.info(f"Setting logging webhook URL for guild {guild_id} to: {'None (removing)' if webhook_url is None else 'Provided URL'}") - success = await set_setting(guild_id, 'logging_webhook_url', webhook_url) + log.info( + f"Setting logging webhook URL for guild {guild_id} to: {'None (removing)' if webhook_url is None else 'Provided URL'}" + ) + success = await set_setting(guild_id, "logging_webhook_url", webhook_url) if success: - log.info(f"Successfully {'set' if webhook_url else 'removed'} logging webhook for guild {guild_id}") + log.info( + f"Successfully {'set' if webhook_url else 'removed'} logging webhook for guild {guild_id}" + ) else: log.error(f"Failed to set logging webhook for guild {guild_id}") return success @@ -1524,15 +1956,19 @@ async def set_logging_webhook(guild_id: int, webhook_url: str | None) -> bool: # --- Logging Event Toggle Functions --- + def _get_log_toggle_cache_key(guild_id: int) -> str: """Generates the Redis Hash key for logging toggles.""" return f"guild:{guild_id}:log_toggles" + async def get_all_log_event_toggles(guild_id: int) -> Dict[str, bool]: """Gets all logging event toggle settings for a guild, checking cache first.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.warning(f"Bot instance or pools not available in settings_manager, cannot get log toggles for guild {guild_id}.") + log.warning( + f"Bot instance or pools not available in settings_manager, cannot get log toggles for guild {guild_id}." + ) return {} cache_key = _get_log_toggle_cache_key(guild_id) @@ -1540,13 +1976,17 @@ async def get_all_log_event_toggles(guild_id: int) -> Dict[str, bool]: # Try cache first try: - cached_toggles = await asyncio.wait_for(bot.redis.hgetall(cache_key), timeout=2.0) + cached_toggles = await asyncio.wait_for( + bot.redis.hgetall(cache_key), timeout=2.0 + ) if cached_toggles: log.debug(f"Cache hit for log toggles (Guild: {guild_id})") # Convert string bools back to boolean - return {key: value == 'True' for key, value in cached_toggles.items()} + return {key: value == "True" for key, value in cached_toggles.items()} except asyncio.TimeoutError: - log.warning(f"Redis timeout getting log toggles for guild {guild_id}, falling back to database") + log.warning( + f"Redis timeout getting log toggles for guild {guild_id}, falling back to database" + ) except Exception as e: log.exception(f"Redis error getting log toggles for guild {guild_id}: {e}") @@ -1556,54 +1996,69 @@ async def get_all_log_event_toggles(guild_id: int) -> Dict[str, bool]: async with bot.pg_pool.acquire() as conn: records = await conn.fetch( "SELECT event_key, enabled FROM logging_event_toggles WHERE guild_id = $1", - guild_id + guild_id, ) - toggles = {record['event_key']: record['enabled'] for record in records} + toggles = {record["event_key"]: record["enabled"] for record in records} # Cache the result (even if empty) try: # Convert boolean values to strings for Redis Hash toggles_to_cache = {key: str(value) for key, value in toggles.items()} - if toggles_to_cache: # Only set if there are toggles, otherwise cache remains empty + if ( + toggles_to_cache + ): # Only set if there are toggles, otherwise cache remains empty async with bot.redis.pipeline(transaction=True) as pipe: - pipe.delete(cache_key) # Clear potentially stale data + pipe.delete(cache_key) # Clear potentially stale data pipe.hset(cache_key, mapping=toggles_to_cache) - pipe.expire(cache_key, 3600) # Cache for 1 hour + pipe.expire(cache_key, 3600) # Cache for 1 hour await pipe.execute() else: # If DB is empty, ensure cache is also empty (or set a placeholder if needed) - await bot.redis.delete(cache_key) + await bot.redis.delete(cache_key) except Exception as e: - log.exception(f"Redis error setting cache for log toggles (Guild: {guild_id}): {e}") + log.exception( + f"Redis error setting cache for log toggles (Guild: {guild_id}): {e}" + ) return toggles except Exception as e: log.exception(f"Database error getting log toggles for guild {guild_id}: {e}") - return {} # Return empty on DB error + return {} # Return empty on DB error -async def is_log_event_enabled(guild_id: int, event_key: str, default_enabled: bool = True) -> bool: + +async def is_log_event_enabled( + guild_id: int, event_key: str, default_enabled: bool = True +) -> bool: """Checks if a specific logging event is enabled for a guild.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.warning(f"Bot instance or pools not available in settings_manager for guild {guild_id}, returning default for log event '{event_key}'.") + log.warning( + f"Bot instance or pools not available in settings_manager for guild {guild_id}, returning default for log event '{event_key}'." + ) return default_enabled cache_key = _get_log_toggle_cache_key(guild_id) # Try cache first try: - cached_value = await asyncio.wait_for(bot.redis.hget(cache_key, event_key), timeout=2.0) + cached_value = await asyncio.wait_for( + bot.redis.hget(cache_key, event_key), timeout=2.0 + ) if cached_value is not None: # log.debug(f"Cache hit for log event '{event_key}' status (Guild: {guild_id})") - return cached_value == 'True' + return cached_value == "True" else: # Field doesn't exist in cache, check DB (might not be explicitly set) - pass # Fall through to DB check + pass # Fall through to DB check except asyncio.TimeoutError: - log.warning(f"Redis timeout getting log event '{event_key}' for guild {guild_id}, falling back to database") + log.warning( + f"Redis timeout getting log event '{event_key}' for guild {guild_id}, falling back to database" + ) except Exception as e: - log.exception(f"Redis error getting log event '{event_key}' for guild {guild_id}: {e}") + log.exception( + f"Redis error getting log event '{event_key}' for guild {guild_id}: {e}" + ) # Cache miss or error, get from DB # log.debug(f"Cache miss for log event '{event_key}' (Guild: {guild_id})") @@ -1612,42 +2067,58 @@ async def is_log_event_enabled(guild_id: int, event_key: str, default_enabled: b async with bot.pg_pool.acquire() as conn: db_enabled_status = await conn.fetchval( "SELECT enabled FROM logging_event_toggles WHERE guild_id = $1 AND event_key = $2", - guild_id, event_key + guild_id, + event_key, ) - final_status = db_enabled_status if db_enabled_status is not None else default_enabled + final_status = ( + db_enabled_status if db_enabled_status is not None else default_enabled + ) # Cache the specific result (only if fetched from DB) - if db_enabled_status is not None: # Only cache if it was explicitly set in DB + if db_enabled_status is not None: # Only cache if it was explicitly set in DB try: await asyncio.wait_for( - bot.redis.hset(cache_key, event_key, str(final_status)), - timeout=2.0 + bot.redis.hset(cache_key, event_key, str(final_status)), timeout=2.0 ) # Ensure the hash key itself has an expiry - await bot.redis.expire(cache_key, 3600, nx=True) # Set expiry only if it doesn't exist + await bot.redis.expire( + cache_key, 3600, nx=True + ) # Set expiry only if it doesn't exist except asyncio.TimeoutError: - log.warning(f"Redis timeout setting cache for log event '{event_key}' (Guild: {guild_id})") + log.warning( + f"Redis timeout setting cache for log event '{event_key}' (Guild: {guild_id})" + ) except Exception as e: - log.exception(f"Redis error setting cache for log event '{event_key}' (Guild: {guild_id}): {e}") + log.exception( + f"Redis error setting cache for log event '{event_key}' (Guild: {guild_id}): {e}" + ) return final_status except Exception as e: - log.exception(f"Database error getting log event '{event_key}' for guild {guild_id}: {e}") - return default_enabled # Fallback on DB error + log.exception( + f"Database error getting log event '{event_key}' for guild {guild_id}: {e}" + ) + return default_enabled # Fallback on DB error + async def set_log_event_enabled(guild_id: int, event_key: str, enabled: bool) -> bool: """Sets the enabled status for a specific logging event type.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.error(f"Bot instance or pools not available in settings_manager for guild {guild_id}, cannot set log event '{event_key}'.") + log.error( + f"Bot instance or pools not available in settings_manager for guild {guild_id}, cannot set log event '{event_key}'." + ) return False cache_key = _get_log_toggle_cache_key(guild_id) try: async with bot.pg_pool.acquire() as conn: # Ensure guild exists - await conn.execute("INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", guild_id) + await conn.execute( + "INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", + guild_id, + ) # Upsert the toggle status await conn.execute( """ @@ -1655,27 +2126,38 @@ async def set_log_event_enabled(guild_id: int, event_key: str, enabled: bool) -> VALUES ($1, $2, $3) ON CONFLICT (guild_id, event_key) DO UPDATE SET enabled = $3; """, - guild_id, event_key, enabled + guild_id, + event_key, + enabled, ) # Update cache await bot.redis.hset(cache_key, event_key, str(enabled)) # Ensure the hash key itself has an expiry - await bot.redis.expire(cache_key, 3600, nx=True) # Set expiry only if it doesn't exist - log.info(f"Set log event '{event_key}' enabled status to {enabled} for guild {guild_id}") + await bot.redis.expire( + cache_key, 3600, nx=True + ) # Set expiry only if it doesn't exist + log.info( + f"Set log event '{event_key}' enabled status to {enabled} for guild {guild_id}" + ) return True except Exception as e: - log.exception(f"Database or Redis error setting log event '{event_key}' in guild {guild_id}: {e}") + log.exception( + f"Database or Redis error setting log event '{event_key}' in guild {guild_id}: {e}" + ) # Attempt to invalidate cache field on error try: await bot.redis.hdel(cache_key, event_key) except Exception as redis_err: - log.exception(f"Failed to invalidate Redis cache field for log event '{event_key}' (Guild: {guild_id}): {redis_err}") + log.exception( + f"Failed to invalidate Redis cache field for log event '{event_key}' (Guild: {guild_id}): {redis_err}" + ) return False # --- Bot Guild Information --- + async def get_bot_guild_ids() -> set[int] | None: """ Gets the set of all guild IDs known to the bot from the guilds table. @@ -1688,7 +2170,12 @@ async def get_bot_guild_ids() -> set[int] | None: try: # Import here to avoid circular imports from api_service.api_server import app - if hasattr(app, 'state') and hasattr(app.state, 'pg_pool') and app.state.pg_pool: + + if ( + hasattr(app, "state") + and hasattr(app.state, "pg_pool") + and app.state.pg_pool + ): log.debug("Using API server's PostgreSQL pool for get_bot_guild_ids") return await get_bot_guild_ids_with_pool(app.state.pg_pool) except (ImportError, AttributeError) as e: @@ -1699,15 +2186,19 @@ async def get_bot_guild_ids() -> set[int] | None: # Fall back to the bot's pool bot = get_bot_instance() if not bot or not bot.pg_pool: - log.error("Bot instance or PostgreSQL pool not available in settings_manager. Cannot get bot guild IDs.") + log.error( + "Bot instance or PostgreSQL pool not available in settings_manager. Cannot get bot guild IDs." + ) return None try: # Use the bot's connection pool async with bot.pg_pool.acquire() as conn: records = await conn.fetch("SELECT guild_id FROM guilds") - guild_ids = {record['guild_id'] for record in records} - log.debug(f"Fetched {len(guild_ids)} guild IDs from database using bot pool.") + guild_ids = {record["guild_id"] for record in records} + log.debug( + f"Fetched {len(guild_ids)} guild IDs from database using bot pool." + ) return guild_ids except asyncpg.exceptions.PostgresError as e: log.exception(f"PostgreSQL error fetching bot guild IDs using bot pool: {e}") @@ -1715,8 +2206,12 @@ async def get_bot_guild_ids() -> set[int] | None: except RuntimeError as e: if "got Future" in str(e) and "attached to a different loop" in str(e): log.error(f"Event loop error in get_bot_guild_ids: {e}") - log.warning("This is likely because the function is being called from the API server thread.") - log.warning("Try using get_bot_guild_ids_with_pool with app.state.pg_pool instead.") + log.warning( + "This is likely because the function is being called from the API server thread." + ) + log.warning( + "Try using get_bot_guild_ids_with_pool with app.state.pg_pool instead." + ) return None else: log.exception(f"Runtime error fetching bot guild IDs: {e}") @@ -1725,6 +2220,7 @@ async def get_bot_guild_ids() -> set[int] | None: log.exception(f"Unexpected error fetching bot guild IDs: {e}") return None + async def get_bot_guild_ids_with_pool(pool) -> set[int] | None: """ Gets the set of all guild IDs known to the bot from the guilds table using a provided pool. @@ -1739,41 +2235,59 @@ async def get_bot_guild_ids_with_pool(pool) -> set[int] | None: # Use the provided connection pool async with pool.acquire() as conn: records = await conn.fetch("SELECT guild_id FROM guilds") - guild_ids = {record['guild_id'] for record in records} - log.debug(f"Fetched {len(guild_ids)} guild IDs from database using provided pool.") + guild_ids = {record["guild_id"] for record in records} + log.debug( + f"Fetched {len(guild_ids)} guild IDs from database using provided pool." + ) return guild_ids except asyncpg.exceptions.PostgresError as e: - log.exception(f"PostgreSQL error fetching bot guild IDs using provided pool: {e}") + log.exception( + f"PostgreSQL error fetching bot guild IDs using provided pool: {e}" + ) return None except Exception as e: - log.exception(f"Unexpected error fetching bot guild IDs with provided pool: {e}") + log.exception( + f"Unexpected error fetching bot guild IDs with provided pool: {e}" + ) return None # --- Command Customization Functions --- -async def get_custom_command_name(guild_id: int, original_command_name: str) -> str | None: + +async def get_custom_command_name( + guild_id: int, original_command_name: str +) -> str | None: """Gets the custom command name for a guild, checking cache first. - Returns None if no custom name is set.""" + Returns None if no custom name is set.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.warning(f"Bot instance or pools not available in settings_manager for guild {guild_id}, returning None for custom command name '{original_command_name}'.") + log.warning( + f"Bot instance or pools not available in settings_manager for guild {guild_id}, returning None for custom command name '{original_command_name}'." + ) return None cache_key = _get_redis_key(guild_id, "cmd_custom", original_command_name) try: cached_value = await bot.redis.get(cache_key) if cached_value is not None: - log.debug(f"Cache hit for custom command name '{original_command_name}' (Guild: {guild_id})") + log.debug( + f"Cache hit for custom command name '{original_command_name}' (Guild: {guild_id})" + ) return None if cached_value == "__NONE__" else cached_value except Exception as e: - log.exception(f"Redis error getting custom command name for '{original_command_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Redis error getting custom command name for '{original_command_name}' (Guild: {guild_id}): {e}" + ) - log.debug(f"Cache miss for custom command name '{original_command_name}' (Guild: {guild_id})") + log.debug( + f"Cache miss for custom command name '{original_command_name}' (Guild: {guild_id})" + ) async with bot.pg_pool.acquire() as conn: custom_name = await conn.fetchval( "SELECT custom_command_name FROM command_customization WHERE guild_id = $1 AND original_command_name = $2", - guild_id, original_command_name + guild_id, + original_command_name, ) # Cache the result (even if None) @@ -1781,33 +2295,46 @@ async def get_custom_command_name(guild_id: int, original_command_name: str) -> value_to_cache = custom_name if custom_name is not None else "__NONE__" await bot.redis.set(cache_key, value_to_cache, ex=3600) # Cache for 1 hour except Exception as e: - log.exception(f"Redis error setting cache for custom command name '{original_command_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Redis error setting cache for custom command name '{original_command_name}' (Guild: {guild_id}): {e}" + ) return custom_name -async def get_custom_command_description(guild_id: int, original_command_name: str) -> str | None: +async def get_custom_command_description( + guild_id: int, original_command_name: str +) -> str | None: """Gets the custom command description for a guild, checking cache first. - Returns None if no custom description is set.""" + Returns None if no custom description is set.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.warning(f"Bot instance or pools not available in settings_manager for guild {guild_id}, returning None for custom command description '{original_command_name}'.") + log.warning( + f"Bot instance or pools not available in settings_manager for guild {guild_id}, returning None for custom command description '{original_command_name}'." + ) return None cache_key = _get_redis_key(guild_id, "cmd_desc", original_command_name) try: cached_value = await bot.redis.get(cache_key) if cached_value is not None: - log.debug(f"Cache hit for custom command description '{original_command_name}' (Guild: {guild_id})") + log.debug( + f"Cache hit for custom command description '{original_command_name}' (Guild: {guild_id})" + ) return None if cached_value == "__NONE__" else cached_value except Exception as e: - log.exception(f"Redis error getting custom command description for '{original_command_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Redis error getting custom command description for '{original_command_name}' (Guild: {guild_id}): {e}" + ) - log.debug(f"Cache miss for custom command description '{original_command_name}' (Guild: {guild_id})") + log.debug( + f"Cache miss for custom command description '{original_command_name}' (Guild: {guild_id})" + ) async with bot.pg_pool.acquire() as conn: custom_desc = await conn.fetchval( "SELECT custom_command_description FROM command_customization WHERE guild_id = $1 AND original_command_name = $2", - guild_id, original_command_name + guild_id, + original_command_name, ) # Cache the result (even if None) @@ -1815,24 +2342,33 @@ async def get_custom_command_description(guild_id: int, original_command_name: s value_to_cache = custom_desc if custom_desc is not None else "__NONE__" await bot.redis.set(cache_key, value_to_cache, ex=3600) # Cache for 1 hour except Exception as e: - log.exception(f"Redis error setting cache for custom command description '{original_command_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Redis error setting cache for custom command description '{original_command_name}' (Guild: {guild_id}): {e}" + ) return custom_desc -async def set_custom_command_name(guild_id: int, original_command_name: str, custom_command_name: str | None) -> bool: +async def set_custom_command_name( + guild_id: int, original_command_name: str, custom_command_name: str | None +) -> bool: """Sets a custom command name for a guild and updates the cache. - Setting custom_command_name to None removes the customization.""" + Setting custom_command_name to None removes the customization.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.error(f"Bot instance or pools not available in settings_manager for guild {guild_id}, cannot set custom command name for '{original_command_name}'.") + log.error( + f"Bot instance or pools not available in settings_manager for guild {guild_id}, cannot set custom command name for '{original_command_name}'." + ) return False cache_key = _get_redis_key(guild_id, "cmd_custom", original_command_name) try: async with bot.pg_pool.acquire() as conn: # Ensure guild exists - await conn.execute("INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", guild_id) + await conn.execute( + "INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", + guild_id, + ) if custom_command_name is not None: # Upsert the custom name @@ -1842,50 +2378,69 @@ async def set_custom_command_name(guild_id: int, original_command_name: str, cus VALUES ($1, $2, $3) ON CONFLICT (guild_id, original_command_name) DO UPDATE SET custom_command_name = $3; """, - guild_id, original_command_name, custom_command_name + guild_id, + original_command_name, + custom_command_name, ) # Update cache await bot.redis.set(cache_key, custom_command_name, ex=3600) - log.info(f"Set custom command name for '{original_command_name}' to '{custom_command_name}' for guild {guild_id}") + log.info( + f"Set custom command name for '{original_command_name}' to '{custom_command_name}' for guild {guild_id}" + ) else: # Delete the customization if value is None await conn.execute( "DELETE FROM command_customization WHERE guild_id = $1 AND original_command_name = $2", - guild_id, original_command_name + guild_id, + original_command_name, ) # Update cache to indicate no customization await bot.redis.set(cache_key, "__NONE__", ex=3600) - log.info(f"Removed custom command name for '{original_command_name}' for guild {guild_id}") + log.info( + f"Removed custom command name for '{original_command_name}' for guild {guild_id}" + ) return True except Exception as e: - log.exception(f"Database or Redis error setting custom command name for '{original_command_name}' in guild {guild_id}: {e}") + log.exception( + f"Database or Redis error setting custom command name for '{original_command_name}' in guild {guild_id}: {e}" + ) # Attempt to invalidate cache on error try: await bot.redis.delete(cache_key) except Exception as redis_err: - log.exception(f"Failed to invalidate Redis cache for custom command name '{original_command_name}' (Guild: {guild_id}): {redis_err}") + log.exception( + f"Failed to invalidate Redis cache for custom command name '{original_command_name}' (Guild: {guild_id}): {redis_err}" + ) return False -async def set_custom_command_description(guild_id: int, original_command_name: str, custom_command_description: str | None) -> bool: +async def set_custom_command_description( + guild_id: int, original_command_name: str, custom_command_description: str | None +) -> bool: """Sets a custom command description for a guild and updates the cache. - Setting custom_command_description to None removes the description.""" + Setting custom_command_description to None removes the description.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.error(f"Bot instance or pools not available in settings_manager for guild {guild_id}, cannot set custom command description for '{original_command_name}'.") + log.error( + f"Bot instance or pools not available in settings_manager for guild {guild_id}, cannot set custom command description for '{original_command_name}'." + ) return False cache_key = _get_redis_key(guild_id, "cmd_desc", original_command_name) try: async with bot.pg_pool.acquire() as conn: # Ensure guild exists - await conn.execute("INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", guild_id) + await conn.execute( + "INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", + guild_id, + ) # Check if the command customization exists exists = await conn.fetchval( "SELECT 1 FROM command_customization WHERE guild_id = $1 AND original_command_name = $2", - guild_id, original_command_name + guild_id, + original_command_name, ) if custom_command_description is not None: @@ -1897,7 +2452,9 @@ async def set_custom_command_description(guild_id: int, original_command_name: s SET custom_command_description = $3 WHERE guild_id = $1 AND original_command_name = $2; """, - guild_id, original_command_name, custom_command_description + guild_id, + original_command_name, + custom_command_description, ) else: # Insert a new record with default custom_command_name (same as original) @@ -1906,11 +2463,15 @@ async def set_custom_command_description(guild_id: int, original_command_name: s INSERT INTO command_customization (guild_id, original_command_name, custom_command_name, custom_command_description) VALUES ($1, $2, $2, $3); """, - guild_id, original_command_name, custom_command_description + guild_id, + original_command_name, + custom_command_description, ) # Update cache await bot.redis.set(cache_key, custom_command_description, ex=3600) - log.info(f"Set custom command description for '{original_command_name}' for guild {guild_id}") + log.info( + f"Set custom command description for '{original_command_name}' for guild {guild_id}" + ) else: if exists: # Update the existing record to remove the description @@ -1920,45 +2481,61 @@ async def set_custom_command_description(guild_id: int, original_command_name: s SET custom_command_description = NULL WHERE guild_id = $1 AND original_command_name = $2; """, - guild_id, original_command_name + guild_id, + original_command_name, ) # Update cache to indicate no description await bot.redis.set(cache_key, "__NONE__", ex=3600) - log.info(f"Removed custom command description for '{original_command_name}' for guild {guild_id}") + log.info( + f"Removed custom command description for '{original_command_name}' for guild {guild_id}" + ) return True except Exception as e: - log.exception(f"Database or Redis error setting custom command description for '{original_command_name}' in guild {guild_id}: {e}") + log.exception( + f"Database or Redis error setting custom command description for '{original_command_name}' in guild {guild_id}: {e}" + ) # Attempt to invalidate cache on error try: await bot.redis.delete(cache_key) except Exception as redis_err: - log.exception(f"Failed to invalidate Redis cache for custom command description '{original_command_name}' (Guild: {guild_id}): {redis_err}") + log.exception( + f"Failed to invalidate Redis cache for custom command description '{original_command_name}' (Guild: {guild_id}): {redis_err}" + ) return False async def get_custom_group_name(guild_id: int, original_group_name: str) -> str | None: """Gets the custom command group name for a guild, checking cache first. - Returns None if no custom name is set.""" + Returns None if no custom name is set.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.warning(f"Bot instance or pools not available in settings_manager for guild {guild_id}, returning None for custom group name '{original_group_name}'.") + log.warning( + f"Bot instance or pools not available in settings_manager for guild {guild_id}, returning None for custom group name '{original_group_name}'." + ) return None cache_key = _get_redis_key(guild_id, "group_custom", original_group_name) try: cached_value = await bot.redis.get(cache_key) if cached_value is not None: - log.debug(f"Cache hit for custom group name '{original_group_name}' (Guild: {guild_id})") + log.debug( + f"Cache hit for custom group name '{original_group_name}' (Guild: {guild_id})" + ) return None if cached_value == "__NONE__" else cached_value except Exception as e: - log.exception(f"Redis error getting custom group name for '{original_group_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Redis error getting custom group name for '{original_group_name}' (Guild: {guild_id}): {e}" + ) - log.debug(f"Cache miss for custom group name '{original_group_name}' (Guild: {guild_id})") + log.debug( + f"Cache miss for custom group name '{original_group_name}' (Guild: {guild_id})" + ) async with bot.pg_pool.acquire() as conn: custom_name = await conn.fetchval( "SELECT custom_group_name FROM command_group_customization WHERE guild_id = $1 AND original_group_name = $2", - guild_id, original_group_name + guild_id, + original_group_name, ) # Cache the result (even if None) @@ -1966,24 +2543,33 @@ async def get_custom_group_name(guild_id: int, original_group_name: str) -> str value_to_cache = custom_name if custom_name is not None else "__NONE__" await bot.redis.set(cache_key, value_to_cache, ex=3600) # Cache for 1 hour except Exception as e: - log.exception(f"Redis error setting cache for custom group name '{original_group_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Redis error setting cache for custom group name '{original_group_name}' (Guild: {guild_id}): {e}" + ) return custom_name -async def set_custom_group_name(guild_id: int, original_group_name: str, custom_group_name: str | None) -> bool: +async def set_custom_group_name( + guild_id: int, original_group_name: str, custom_group_name: str | None +) -> bool: """Sets a custom command group name for a guild and updates the cache. - Setting custom_group_name to None removes the customization.""" + Setting custom_group_name to None removes the customization.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.error(f"Bot instance or pools not available in settings_manager for guild {guild_id}, cannot set custom group name for '{original_group_name}'.") + log.error( + f"Bot instance or pools not available in settings_manager for guild {guild_id}, cannot set custom group name for '{original_group_name}'." + ) return False cache_key = _get_redis_key(guild_id, "group_custom", original_group_name) try: async with bot.pg_pool.acquire() as conn: # Ensure guild exists - await conn.execute("INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", guild_id) + await conn.execute( + "INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", + guild_id, + ) if custom_group_name is not None: # Upsert the custom name @@ -1993,44 +2579,62 @@ async def set_custom_group_name(guild_id: int, original_group_name: str, custom_ VALUES ($1, $2, $3) ON CONFLICT (guild_id, original_group_name) DO UPDATE SET custom_group_name = $3; """, - guild_id, original_group_name, custom_group_name + guild_id, + original_group_name, + custom_group_name, ) # Update cache await bot.redis.set(cache_key, custom_group_name, ex=3600) - log.info(f"Set custom group name for '{original_group_name}' to '{custom_group_name}' for guild {guild_id}") + log.info( + f"Set custom group name for '{original_group_name}' to '{custom_group_name}' for guild {guild_id}" + ) else: # Delete the customization if value is None await conn.execute( "DELETE FROM command_group_customization WHERE guild_id = $1 AND original_group_name = $2", - guild_id, original_group_name + guild_id, + original_group_name, ) # Update cache to indicate no customization await bot.redis.set(cache_key, "__NONE__", ex=3600) - log.info(f"Removed custom group name for '{original_group_name}' for guild {guild_id}") + log.info( + f"Removed custom group name for '{original_group_name}' for guild {guild_id}" + ) return True except Exception as e: - log.exception(f"Database or Redis error setting custom group name for '{original_group_name}' in guild {guild_id}: {e}") + log.exception( + f"Database or Redis error setting custom group name for '{original_group_name}' in guild {guild_id}: {e}" + ) # Attempt to invalidate cache on error try: await bot.redis.delete(cache_key) except Exception as redis_err: - log.exception(f"Failed to invalidate Redis cache for custom group name '{original_group_name}' (Guild: {guild_id}): {redis_err}") + log.exception( + f"Failed to invalidate Redis cache for custom group name '{original_group_name}' (Guild: {guild_id}): {redis_err}" + ) return False -async def add_command_alias(guild_id: int, original_command_name: str, alias_name: str) -> bool: +async def add_command_alias( + guild_id: int, original_command_name: str, alias_name: str +) -> bool: """Adds an alias for a command in a guild and invalidates cache.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.error(f"Bot instance or pools not available in settings_manager for guild {guild_id}, cannot add alias for command '{original_command_name}'.") + log.error( + f"Bot instance or pools not available in settings_manager for guild {guild_id}, cannot add alias for command '{original_command_name}'." + ) return False cache_key = _get_redis_key(guild_id, "cmd_aliases", original_command_name) try: async with bot.pg_pool.acquire() as conn: # Ensure guild exists - await conn.execute("INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", guild_id) + await conn.execute( + "INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", + guild_id, + ) # Add the alias await conn.execute( """ @@ -2038,28 +2642,40 @@ async def add_command_alias(guild_id: int, original_command_name: str, alias_nam VALUES ($1, $2, $3) ON CONFLICT (guild_id, original_command_name, alias_name) DO NOTHING; """, - guild_id, original_command_name, alias_name + guild_id, + original_command_name, + alias_name, ) # Invalidate cache after DB operation succeeds await bot.redis.delete(cache_key) - log.info(f"Added alias '{alias_name}' for command '{original_command_name}' in guild {guild_id}") + log.info( + f"Added alias '{alias_name}' for command '{original_command_name}' in guild {guild_id}" + ) return True except Exception as e: - log.exception(f"Database or Redis error adding alias for command '{original_command_name}' in guild {guild_id}: {e}") + log.exception( + f"Database or Redis error adding alias for command '{original_command_name}' in guild {guild_id}: {e}" + ) # Attempt to invalidate cache even on error try: await bot.redis.delete(cache_key) except Exception as redis_err: - log.exception(f"Failed to invalidate Redis cache for command aliases '{original_command_name}' (Guild: {guild_id}): {redis_err}") + log.exception( + f"Failed to invalidate Redis cache for command aliases '{original_command_name}' (Guild: {guild_id}): {redis_err}" + ) return False -async def remove_command_alias(guild_id: int, original_command_name: str, alias_name: str) -> bool: +async def remove_command_alias( + guild_id: int, original_command_name: str, alias_name: str +) -> bool: """Removes an alias for a command in a guild and invalidates cache.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.error(f"Bot instance or pools not available in settings_manager for guild {guild_id}, cannot remove alias for command '{original_command_name}'.") + log.error( + f"Bot instance or pools not available in settings_manager for guild {guild_id}, cannot remove alias for command '{original_command_name}'." + ) return False cache_key = _get_redis_key(guild_id, "cmd_aliases", original_command_name) @@ -2071,29 +2687,41 @@ async def remove_command_alias(guild_id: int, original_command_name: str, alias_ DELETE FROM command_aliases WHERE guild_id = $1 AND original_command_name = $2 AND alias_name = $3; """, - guild_id, original_command_name, alias_name + guild_id, + original_command_name, + alias_name, ) # Invalidate cache after DB operation succeeds await bot.redis.delete(cache_key) - log.info(f"Removed alias '{alias_name}' for command '{original_command_name}' in guild {guild_id}") + log.info( + f"Removed alias '{alias_name}' for command '{original_command_name}' in guild {guild_id}" + ) return True except Exception as e: - log.exception(f"Database or Redis error removing alias for command '{original_command_name}' in guild {guild_id}: {e}") + log.exception( + f"Database or Redis error removing alias for command '{original_command_name}' in guild {guild_id}: {e}" + ) # Attempt to invalidate cache even on error try: await bot.redis.delete(cache_key) except Exception as redis_err: - log.exception(f"Failed to invalidate Redis cache for command aliases '{original_command_name}' (Guild: {guild_id}): {redis_err}") + log.exception( + f"Failed to invalidate Redis cache for command aliases '{original_command_name}' (Guild: {guild_id}): {redis_err}" + ) return False -async def get_command_aliases(guild_id: int, original_command_name: str) -> list[str] | None: +async def get_command_aliases( + guild_id: int, original_command_name: str +) -> list[str] | None: """Gets the list of aliases for a command in a guild, checking cache first. - Returns empty list if no aliases are set, None on error.""" + Returns empty list if no aliases are set, None on error.""" bot = get_bot_instance() if not bot or not bot.pg_pool or not bot.redis: - log.warning(f"Bot instance or pools not available in settings_manager for guild {guild_id}, returning None for command aliases '{original_command_name}'.") + log.warning( + f"Bot instance or pools not available in settings_manager for guild {guild_id}, returning None for command aliases '{original_command_name}'." + ) return None cache_key = _get_redis_key(guild_id, "cmd_aliases", original_command_name) @@ -2102,22 +2730,31 @@ async def get_command_aliases(guild_id: int, original_command_name: str) -> list cached_aliases = await bot.redis.lrange(cache_key, 0, -1) if cached_aliases is not None: if len(cached_aliases) == 1 and cached_aliases[0] == "__EMPTY_LIST__": - log.debug(f"Cache hit (empty list) for command aliases '{original_command_name}' (Guild: {guild_id}).") + log.debug( + f"Cache hit (empty list) for command aliases '{original_command_name}' (Guild: {guild_id})." + ) return [] - log.debug(f"Cache hit for command aliases '{original_command_name}' (Guild: {guild_id})") + log.debug( + f"Cache hit for command aliases '{original_command_name}' (Guild: {guild_id})" + ) return cached_aliases except Exception as e: - log.exception(f"Redis error getting command aliases for '{original_command_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Redis error getting command aliases for '{original_command_name}' (Guild: {guild_id}): {e}" + ) # Fall through to DB query on Redis error - log.debug(f"Cache miss for command aliases '{original_command_name}' (Guild: {guild_id})") + log.debug( + f"Cache miss for command aliases '{original_command_name}' (Guild: {guild_id})" + ) try: async with bot.pg_pool.acquire() as conn: records = await conn.fetch( "SELECT alias_name FROM command_aliases WHERE guild_id = $1 AND original_command_name = $2", - guild_id, original_command_name + guild_id, + original_command_name, ) - aliases = [record['alias_name'] for record in records] + aliases = [record["alias_name"] for record in records] # Cache the result try: @@ -2130,125 +2767,157 @@ async def get_command_aliases(guild_id: int, original_command_name: str) -> list pipe.expire(cache_key, 3600) # Cache for 1 hour await pipe.execute() except Exception as e: - log.exception(f"Redis error setting cache for command aliases '{original_command_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Redis error setting cache for command aliases '{original_command_name}' (Guild: {guild_id}): {e}" + ) return aliases except Exception as e: - log.exception(f"Database error getting command aliases for '{original_command_name}' (Guild: {guild_id}): {e}") + log.exception( + f"Database error getting command aliases for '{original_command_name}' (Guild: {guild_id}): {e}" + ) return None # Indicate error -async def get_all_command_customizations(guild_id: int) -> dict[str, dict[str, str]] | None: +async def get_all_command_customizations( + guild_id: int, +) -> dict[str, dict[str, str]] | None: """Gets all command customizations for a guild. - Returns a dictionary mapping original command names to a dict with 'name' and 'description' keys, - or None on error.""" + Returns a dictionary mapping original command names to a dict with 'name' and 'description' keys, + or None on error.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.error(f"Bot instance or PostgreSQL pool not available in settings_manager for guild {guild_id}, cannot get command customizations.") + log.error( + f"Bot instance or PostgreSQL pool not available in settings_manager for guild {guild_id}, cannot get command customizations." + ) return None try: async with bot.pg_pool.acquire() as conn: records = await conn.fetch( "SELECT original_command_name, custom_command_name, custom_command_description FROM command_customization WHERE guild_id = $1", - guild_id + guild_id, ) customizations = {} for record in records: - cmd_name = record['original_command_name'] + cmd_name = record["original_command_name"] customizations[cmd_name] = { - 'name': record['custom_command_name'], - 'description': record['custom_command_description'] + "name": record["custom_command_name"], + "description": record["custom_command_description"], } - log.debug(f"Fetched {len(customizations)} command customizations for guild {guild_id}.") + log.debug( + f"Fetched {len(customizations)} command customizations for guild {guild_id}." + ) return customizations except Exception as e: - log.exception(f"Database error fetching command customizations for guild {guild_id}: {e}") + log.exception( + f"Database error fetching command customizations for guild {guild_id}: {e}" + ) return None -async def get_all_group_customizations(guild_id: int) -> dict[str, dict[str, str]] | None: +async def get_all_group_customizations( + guild_id: int, +) -> dict[str, dict[str, str]] | None: """Gets all command group customizations for a guild. - Returns a dictionary mapping original group names to a dict with 'name' and 'description' keys, - or None on error.""" + Returns a dictionary mapping original group names to a dict with 'name' and 'description' keys, + or None on error.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.error(f"Bot instance or PostgreSQL pool not available in settings_manager for guild {guild_id}, cannot get group customizations.") + log.error( + f"Bot instance or PostgreSQL pool not available in settings_manager for guild {guild_id}, cannot get group customizations." + ) return None try: async with bot.pg_pool.acquire() as conn: records = await conn.fetch( "SELECT original_group_name, custom_group_name FROM command_group_customization WHERE guild_id = $1", - guild_id + guild_id, ) customizations = {} for record in records: - group_name = record['original_group_name'] + group_name = record["original_group_name"] customizations[group_name] = { - 'name': record['custom_group_name'], - 'description': None # Groups don't have custom descriptions yet + "name": record["custom_group_name"], + "description": None, # Groups don't have custom descriptions yet } - log.debug(f"Fetched {len(customizations)} group customizations for guild {guild_id}.") + log.debug( + f"Fetched {len(customizations)} group customizations for guild {guild_id}." + ) return customizations except Exception as e: - log.exception(f"Database error fetching group customizations for guild {guild_id}: {e}") + log.exception( + f"Database error fetching group customizations for guild {guild_id}: {e}" + ) return None async def get_all_command_aliases(guild_id: int) -> dict[str, list[str]] | None: """Gets all command aliases for a guild. - Returns a dictionary mapping original command names to lists of aliases, or None on error.""" + Returns a dictionary mapping original command names to lists of aliases, or None on error. + """ bot = get_bot_instance() if not bot or not bot.pg_pool: - log.error(f"Bot instance or PostgreSQL pool not available in settings_manager for guild {guild_id}, cannot get command aliases.") + log.error( + f"Bot instance or PostgreSQL pool not available in settings_manager for guild {guild_id}, cannot get command aliases." + ) return None try: async with bot.pg_pool.acquire() as conn: records = await conn.fetch( "SELECT original_command_name, alias_name FROM command_aliases WHERE guild_id = $1", - guild_id + guild_id, ) # Group by original_command_name aliases_dict = {} for record in records: - cmd_name = record['original_command_name'] - alias = record['alias_name'] + cmd_name = record["original_command_name"] + alias = record["alias_name"] if cmd_name not in aliases_dict: aliases_dict[cmd_name] = [] aliases_dict[cmd_name].append(alias) - log.debug(f"Fetched aliases for {len(aliases_dict)} commands for guild {guild_id}.") + log.debug( + f"Fetched aliases for {len(aliases_dict)} commands for guild {guild_id}." + ) return aliases_dict except Exception as e: - log.exception(f"Database error fetching command aliases for guild {guild_id}: {e}") + log.exception( + f"Database error fetching command aliases for guild {guild_id}: {e}" + ) return None # --- Moderation Logging Settings --- + async def is_mod_log_enabled(guild_id: int, default: bool = False) -> bool: """Checks if the integrated moderation log is enabled for a guild.""" - enabled_str = await get_setting(guild_id, 'mod_log_enabled', default=str(default)) + enabled_str = await get_setting(guild_id, "mod_log_enabled", default=str(default)) # Handle potential non-string default if get_setting fails early if isinstance(enabled_str, bool): return enabled_str - return enabled_str.lower() == 'true' + return enabled_str.lower() == "true" + async def set_mod_log_enabled(guild_id: int, enabled: bool) -> bool: """Sets the enabled status for the integrated moderation log.""" - return await set_setting(guild_id, 'mod_log_enabled', str(enabled)) + return await set_setting(guild_id, "mod_log_enabled", str(enabled)) + async def get_mod_log_channel_id(guild_id: int) -> int | None: """Gets the channel ID for the integrated moderation log.""" - channel_id_str = await get_setting(guild_id, 'mod_log_channel_id', default=None) + channel_id_str = await get_setting(guild_id, "mod_log_channel_id", default=None) if channel_id_str and channel_id_str.isdigit(): return int(channel_id_str) return None + async def set_mod_log_channel_id(guild_id: int, channel_id: int | None) -> bool: """Sets the channel ID for the integrated moderation log. Set to None to disable.""" value_to_set = str(channel_id) if channel_id is not None else None - return await set_setting(guild_id, 'mod_log_channel_id', value_to_set) + return await set_setting(guild_id, "mod_log_channel_id", value_to_set) + # --- Getter functions for direct pool access if absolutely needed --- # def get_pg_pool(): # Removed @@ -2264,34 +2933,44 @@ async def set_mod_log_channel_id(guild_id: int, channel_id: int | None) -> bool: # --- Git Repository Monitoring Functions --- + async def add_monitored_repository( guild_id: int, repository_url: str, - platform: str, # 'github' or 'gitlab' - monitoring_method: str, # 'webhook' or 'poll' + platform: str, # 'github' or 'gitlab' + monitoring_method: str, # 'webhook' or 'poll' notification_channel_id: int, added_by_user_id: int, - webhook_secret: str | None = None, # Only for 'webhook' - target_branch: str | None = None, # For polling + webhook_secret: str | None = None, # Only for 'webhook' + target_branch: str | None = None, # For polling polling_interval_minutes: int = 15, is_public_repo: bool = True, - last_polled_commit_sha: str | None = None, # For initial poll setup - allowed_webhook_events: list[str] | None = None # List of event names like ['push', 'issues'] + last_polled_commit_sha: str | None = None, # For initial poll setup + allowed_webhook_events: ( + list[str] | None + ) = None, # List of event names like ['push', 'issues'] ) -> int | None: """Adds a new repository to monitor. Returns the ID of the new row, or None on failure.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.error(f"Bot instance or PostgreSQL pool not available for add_monitored_repository (guild {guild_id}).") + log.error( + f"Bot instance or PostgreSQL pool not available for add_monitored_repository (guild {guild_id})." + ) return None try: async with bot.pg_pool.acquire() as conn: # Ensure guild exists - await conn.execute("INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", guild_id) + await conn.execute( + "INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;", + guild_id, + ) # Insert the new repository monitoring entry # Default allowed_webhook_events if not provided or empty - final_allowed_events = allowed_webhook_events if allowed_webhook_events else ['push'] + final_allowed_events = ( + allowed_webhook_events if allowed_webhook_events else ["push"] + ) repo_id = await conn.fetchval( """ @@ -2305,13 +2984,23 @@ async def add_monitored_repository( ON CONFLICT (guild_id, repository_url, notification_channel_id) DO NOTHING RETURNING id; """, - guild_id, repository_url, platform, monitoring_method, - notification_channel_id, added_by_user_id, webhook_secret, target_branch, - polling_interval_minutes, is_public_repo, last_polled_commit_sha, - final_allowed_events + guild_id, + repository_url, + platform, + monitoring_method, + notification_channel_id, + added_by_user_id, + webhook_secret, + target_branch, + polling_interval_minutes, + is_public_repo, + last_polled_commit_sha, + final_allowed_events, ) if repo_id: - log.info(f"Added repository '{repository_url}' (Branch: {target_branch or 'default'}, Events: {final_allowed_events}) for monitoring in guild {guild_id}, channel {notification_channel_id}. ID: {repo_id}") + log.info( + f"Added repository '{repository_url}' (Branch: {target_branch or 'default'}, Events: {final_allowed_events}) for monitoring in guild {guild_id}, channel {notification_channel_id}. ID: {repo_id}" + ) else: # This means ON CONFLICT DO NOTHING was triggered, fetch existing ID existing_id = await conn.fetchval( @@ -2319,13 +3008,19 @@ async def add_monitored_repository( SELECT id FROM git_monitored_repositories WHERE guild_id = $1 AND repository_url = $2 AND notification_channel_id = $3; """, - guild_id, repository_url, notification_channel_id + guild_id, + repository_url, + notification_channel_id, ) - log.warning(f"Repository '{repository_url}' for guild {guild_id}, channel {notification_channel_id} already exists with ID {existing_id}. Not adding again.") - return existing_id # Return existing ID if it was a conflict + log.warning( + f"Repository '{repository_url}' for guild {guild_id}, channel {notification_channel_id} already exists with ID {existing_id}. Not adding again." + ) + return existing_id # Return existing ID if it was a conflict return repo_id except Exception as e: - log.exception(f"Database error adding monitored repository '{repository_url}' for guild {guild_id}: {e}") + log.exception( + f"Database error adding monitored repository '{repository_url}' for guild {guild_id}: {e}" + ) return None @@ -2333,25 +3028,34 @@ async def get_monitored_repository_by_id(repo_db_id: int) -> Dict | None: """Gets details of a monitored repository by its database ID.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.warning(f"Bot instance or PostgreSQL pool not available for get_monitored_repository_by_id (ID {repo_db_id}).") + log.warning( + f"Bot instance or PostgreSQL pool not available for get_monitored_repository_by_id (ID {repo_db_id})." + ) return None try: async with bot.pg_pool.acquire() as conn: record = await conn.fetchrow( - "SELECT *, allowed_webhook_events FROM git_monitored_repositories WHERE id = $1", # Ensure new column is fetched - repo_db_id + "SELECT *, allowed_webhook_events FROM git_monitored_repositories WHERE id = $1", # Ensure new column is fetched + repo_db_id, ) # log.info(f"Grep this line: {dict(record) if record else 'No record found'}") # Keep for debugging if needed return dict(record) if record else None except Exception as e: - log.exception(f"Database error getting monitored repository by ID {repo_db_id}: {e}") + log.exception( + f"Database error getting monitored repository by ID {repo_db_id}: {e}" + ) return None -async def get_monitored_repository_by_url(guild_id: int, repository_url: str, notification_channel_id: int) -> Dict | None: + +async def get_monitored_repository_by_url( + guild_id: int, repository_url: str, notification_channel_id: int +) -> Dict | None: """Gets details of a monitored repository by its URL and channel for a specific guild.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.warning(f"Bot instance or PostgreSQL pool not available for get_monitored_repository_by_url (guild {guild_id}).") + log.warning( + f"Bot instance or PostgreSQL pool not available for get_monitored_repository_by_url (guild {guild_id})." + ) return None try: async with bot.pg_pool.acquire() as conn: @@ -2360,18 +3064,27 @@ async def get_monitored_repository_by_url(guild_id: int, repository_url: str, no SELECT *, allowed_webhook_events FROM git_monitored_repositories WHERE guild_id = $1 AND repository_url = $2 AND notification_channel_id = $3 """, - guild_id, repository_url, notification_channel_id + guild_id, + repository_url, + notification_channel_id, ) return dict(record) if record else None except Exception as e: - log.exception(f"Database error getting monitored repository by URL '{repository_url}' for guild {guild_id}: {e}") + log.exception( + f"Database error getting monitored repository by URL '{repository_url}' for guild {guild_id}: {e}" + ) return None -async def update_monitored_repository_events(repo_db_id: int, allowed_events: list[str]) -> bool: + +async def update_monitored_repository_events( + repo_db_id: int, allowed_events: list[str] +) -> bool: """Updates the allowed webhook events for a specific monitored repository.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.error(f"Bot instance or PostgreSQL pool not available for update_monitored_repository_events (ID {repo_db_id}).") + log.error( + f"Bot instance or PostgreSQL pool not available for update_monitored_repository_events (ID {repo_db_id})." + ) return False try: async with bot.pg_pool.acquire() as conn: @@ -2381,24 +3094,40 @@ async def update_monitored_repository_events(repo_db_id: int, allowed_events: li SET allowed_webhook_events = $2 WHERE id = $1; """, - repo_db_id, allowed_events + repo_db_id, + allowed_events, + ) + log.info( + f"Updated allowed webhook events for repository ID {repo_db_id} to {allowed_events}." ) - log.info(f"Updated allowed webhook events for repository ID {repo_db_id} to {allowed_events}.") # Consider cache invalidation here if caching these lists directly per repo_id return True except Exception as e: - log.exception(f"Database error updating allowed webhook events for repository ID {repo_db_id}: {e}") + log.exception( + f"Database error updating allowed webhook events for repository ID {repo_db_id}: {e}" + ) return False -async def update_repository_polling_status(repo_db_id: int, last_polled_commit_sha: str, last_polled_at: asyncio.Future | None = None) -> bool: + +async def update_repository_polling_status( + repo_db_id: int, + last_polled_commit_sha: str, + last_polled_at: asyncio.Future | None = None, +) -> bool: """Updates the last polled commit SHA and timestamp for a repository.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.error(f"Bot instance or PostgreSQL pool not available for update_repository_polling_status (ID {repo_db_id}).") + log.error( + f"Bot instance or PostgreSQL pool not available for update_repository_polling_status (ID {repo_db_id})." + ) return False # If last_polled_at is not provided, use current time - current_time = last_polled_at if last_polled_at else datetime.datetime.now(datetime.timezone.utc) + current_time = ( + last_polled_at + if last_polled_at + else datetime.datetime.now(datetime.timezone.utc) + ) try: async with bot.pg_pool.acquire() as conn: @@ -2408,20 +3137,30 @@ async def update_repository_polling_status(repo_db_id: int, last_polled_commit_s SET last_polled_commit_sha = $2, last_polled_at = $3 WHERE id = $1; """, - repo_db_id, last_polled_commit_sha, current_time + repo_db_id, + last_polled_commit_sha, + current_time, + ) + log.debug( + f"Updated polling status for repository ID {repo_db_id} to SHA {last_polled_commit_sha[:7]}." ) - log.debug(f"Updated polling status for repository ID {repo_db_id} to SHA {last_polled_commit_sha[:7]}.") return True except Exception as e: - log.exception(f"Database error updating polling status for repository ID {repo_db_id}: {e}") + log.exception( + f"Database error updating polling status for repository ID {repo_db_id}: {e}" + ) return False -async def remove_monitored_repository(guild_id: int, repository_url: str, notification_channel_id: int) -> bool: +async def remove_monitored_repository( + guild_id: int, repository_url: str, notification_channel_id: int +) -> bool: """Removes a repository from monitoring for a specific guild and channel.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.error(f"Bot instance or PostgreSQL pool not available for remove_monitored_repository (guild {guild_id}).") + log.error( + f"Bot instance or PostgreSQL pool not available for remove_monitored_repository (guild {guild_id})." + ) return False try: async with bot.pg_pool.acquire() as conn: @@ -2430,18 +3169,28 @@ async def remove_monitored_repository(guild_id: int, repository_url: str, notifi DELETE FROM git_monitored_repositories WHERE guild_id = $1 AND repository_url = $2 AND notification_channel_id = $3; """, - guild_id, repository_url, notification_channel_id + guild_id, + repository_url, + notification_channel_id, ) # DELETE command returns a string like 'DELETE 1' if a row was deleted - deleted_count = int(result.split()[-1]) if result.startswith("DELETE") else 0 + deleted_count = ( + int(result.split()[-1]) if result.startswith("DELETE") else 0 + ) if deleted_count > 0: - log.info(f"Removed repository '{repository_url}' from monitoring for guild {guild_id}, channel {notification_channel_id}.") + log.info( + f"Removed repository '{repository_url}' from monitoring for guild {guild_id}, channel {notification_channel_id}." + ) return True else: - log.warning(f"No repository '{repository_url}' found for monitoring in guild {guild_id}, channel {notification_channel_id} to remove.") + log.warning( + f"No repository '{repository_url}' found for monitoring in guild {guild_id}, channel {notification_channel_id} to remove." + ) return False except Exception as e: - log.exception(f"Database error removing monitored repository '{repository_url}' for guild {guild_id}: {e}") + log.exception( + f"Database error removing monitored repository '{repository_url}' for guild {guild_id}: {e}" + ) return False @@ -2449,17 +3198,21 @@ async def list_monitored_repositories_for_guild(guild_id: int) -> list[Dict]: """Lists all repositories being monitored for a specific guild.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.warning(f"Bot instance or PostgreSQL pool not available for list_monitored_repositories_for_guild (guild {guild_id}).") + log.warning( + f"Bot instance or PostgreSQL pool not available for list_monitored_repositories_for_guild (guild {guild_id})." + ) return [] try: async with bot.pg_pool.acquire() as conn: records = await conn.fetch( "SELECT id, repository_url, platform, monitoring_method, notification_channel_id, created_at FROM git_monitored_repositories WHERE guild_id = $1 ORDER BY created_at DESC", - guild_id + guild_id, ) return [dict(record) for record in records] except Exception as e: - log.exception(f"Database error listing monitored repositories for guild {guild_id}: {e}") + log.exception( + f"Database error listing monitored repositories for guild {guild_id}: {e}" + ) return [] @@ -2467,7 +3220,9 @@ async def get_all_repositories_for_polling() -> list[Dict]: """Fetches all repositories configured for polling.""" bot = get_bot_instance() if not bot or not bot.pg_pool: - log.warning("Bot instance or PostgreSQL pool not available for get_all_repositories_for_polling.") + log.warning( + "Bot instance or PostgreSQL pool not available for get_all_repositories_for_polling." + ) return [] try: async with bot.pg_pool.acquire() as conn: diff --git a/tavilytool.py b/tavilytool.py index 3c89167..f83fbc9 100644 --- a/tavilytool.py +++ b/tavilytool.py @@ -11,38 +11,41 @@ import requests import argparse from typing import Dict, List, Optional + class TavilyAPI: def __init__(self, api_key: str): self.api_key = api_key self.base_url = "https://api.tavily.com" - - def search(self, - query: str, - search_depth: str = "basic", - include_answer: bool = True, - include_images: bool = False, - include_raw_content: bool = False, - max_results: int = 5, - include_domains: Optional[List[str]] = None, - exclude_domains: Optional[List[str]] = None) -> Dict: + + def search( + self, + query: str, + search_depth: str = "basic", + include_answer: bool = True, + include_images: bool = False, + include_raw_content: bool = False, + max_results: int = 5, + include_domains: Optional[List[str]] = None, + exclude_domains: Optional[List[str]] = None, + ) -> Dict: """ Perform a search using Tavily API - + Args: query: Search query string - search_depth: "basic" or "advanced" + search_depth: "basic" or "advanced" include_answer: Include AI-generated answer include_images: Include images in results include_raw_content: Include raw HTML content max_results: Maximum number of results (1-20) include_domains: List of domains to include exclude_domains: List of domains to exclude - + Returns: Dictionary containing search results """ url = f"{self.base_url}/search" - + payload = { "api_key": self.api_key, "query": query, @@ -50,14 +53,14 @@ class TavilyAPI: "include_answer": include_answer, "include_images": include_images, "include_raw_content": include_raw_content, - "max_results": max_results + "max_results": max_results, } - + if include_domains: payload["include_domains"] = include_domains if exclude_domains: payload["exclude_domains"] = exclude_domains - + try: response = requests.post(url, json=payload, timeout=30) response.raise_for_status() @@ -67,19 +70,20 @@ class TavilyAPI: except json.JSONDecodeError: return {"error": "Invalid JSON response from API"} + def format_results(results: Dict) -> str: """Format search results for display""" if "error" in results: return f"❌ Error: {results['error']}" - + output = [] - + # Add answer if available if results.get("answer"): output.append("🤖 AI Answer:") output.append(f" {results['answer']}") output.append("") - + # Add search results if results.get("results"): output.append("🔍 Search Results:") @@ -88,49 +92,63 @@ def format_results(results: Dict) -> str: output.append(f" URL: {result.get('url', 'No URL')}") if result.get("content"): # Truncate content to first 200 chars - content = result["content"][:200] + "..." if len(result["content"]) > 200 else result["content"] + content = ( + result["content"][:200] + "..." + if len(result["content"]) > 200 + else result["content"] + ) output.append(f" Content: {content}") output.append("") - + # Add images if available if results.get("images"): output.append("🖼️ Images:") for img in results["images"][:3]: # Show first 3 images output.append(f" {img}") output.append("") - + return "\n".join(output) + def main(): parser = argparse.ArgumentParser(description="Search using Tavily API") parser.add_argument("query", help="Search query") - parser.add_argument("--depth", choices=["basic", "advanced"], default="basic", - help="Search depth (default: basic)") - parser.add_argument("--max-results", type=int, default=5, - help="Maximum number of results (default: 5)") - parser.add_argument("--include-images", action="store_true", - help="Include images in results") - parser.add_argument("--no-answer", action="store_true", - help="Don't include AI-generated answer") - parser.add_argument("--include-domains", nargs="+", - help="Include only these domains") - parser.add_argument("--exclude-domains", nargs="+", - help="Exclude these domains") - parser.add_argument("--raw", action="store_true", - help="Output raw JSON response") - + parser.add_argument( + "--depth", + choices=["basic", "advanced"], + default="basic", + help="Search depth (default: basic)", + ) + parser.add_argument( + "--max-results", + type=int, + default=5, + help="Maximum number of results (default: 5)", + ) + parser.add_argument( + "--include-images", action="store_true", help="Include images in results" + ) + parser.add_argument( + "--no-answer", action="store_true", help="Don't include AI-generated answer" + ) + parser.add_argument( + "--include-domains", nargs="+", help="Include only these domains" + ) + parser.add_argument("--exclude-domains", nargs="+", help="Exclude these domains") + parser.add_argument("--raw", action="store_true", help="Output raw JSON response") + args = parser.parse_args() - + # Get API key from environment api_key = os.getenv("TAVILY_API_KEY") if not api_key: print("❌ Error: TAVILY_API_KEY environment variable not set") print("Set it with: export TAVILY_API_KEY='your-api-key-here'") sys.exit(1) - + # Initialize Tavily API tavily = TavilyAPI(api_key) - + # Perform search results = tavily.search( query=args.query, @@ -139,14 +157,15 @@ def main(): include_images=args.include_images, max_results=args.max_results, include_domains=args.include_domains, - exclude_domains=args.exclude_domains + exclude_domains=args.exclude_domains, ) - + # Output results if args.raw: print(json.dumps(results, indent=2)) else: print(format_results(results)) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/test_gputil.py b/test_gputil.py index 5921479..180e14d 100644 --- a/test_gputil.py +++ b/test_gputil.py @@ -42,6 +42,7 @@ import wmi # print(get_opencl_gpus()) from pyadl import * + devices = ADLManager.getInstance().getDevices() for device in devices: - print("{0}. {1}".format(device.adapterIndex, device.adapterName)) \ No newline at end of file + print("{0}. {1}".format(device.adapterIndex, device.adapterName)) diff --git a/test_pagination.py b/test_pagination.py index 9cee10e..ff9f4af 100644 --- a/test_pagination.py +++ b/test_pagination.py @@ -4,7 +4,9 @@ import sys import os # Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) log = logging.getLogger("test_pagination") # Add the current directory to the path so we can import the cogs @@ -14,10 +16,11 @@ from cogs.safebooru_cog import SafebooruCog from discord.ext import commands import discord + async def test_pagination(): # Create a mock bot with intents intents = discord.Intents.default() - bot = commands.Bot(command_prefix='!', intents=intents) + bot = commands.Bot(command_prefix="!", intents=intents) # Initialize the cog cog = SafebooruCog(bot) @@ -39,10 +42,11 @@ async def test_pagination(): log.error(f"Error: {results}") # Clean up - if hasattr(cog, 'session') and cog.session and not cog.session.closed: + if hasattr(cog, "session") and cog.session and not cog.session.closed: await cog.session.close() log.info("Closed aiohttp session") + if __name__ == "__main__": # Run the test asyncio.run(test_pagination()) diff --git a/test_part.py b/test_part.py index 73cc6dc..50d3dbd 100644 --- a/test_part.py +++ b/test_part.py @@ -1,18 +1,20 @@ # Test script for Part constructor try: from gurt.api import types + print("Successfully imported types module") - + # Test creating a Part with text part = types.Part(text="test") print(f"Successfully created Part with text: {part}") - + # Test creating a Part with URI part_uri = types.Part(uri="https://example.com", mime_type="text/plain") print(f"Successfully created Part with URI: {part_uri}") - + print("All tests passed!") except Exception as e: print(f"Error: {type(e).__name__}: {e}") import traceback + traceback.print_exc() diff --git a/test_starboard.py b/test_starboard.py index 23cec45..29895c5 100644 --- a/test_starboard.py +++ b/test_starboard.py @@ -17,7 +17,9 @@ import settings_manager as settings_manager load_dotenv() # Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s:%(levelname)s:%(name)s: %(message)s') +logging.basicConfig( + level=logging.INFO, format="%(asctime)s:%(levelname)s:%(name)s: %(message)s" +) log = logging.getLogger(__name__) # Set up intents @@ -28,10 +30,11 @@ intents.members = True # Create bot instance bot = commands.Bot(command_prefix="!", intents=intents) + @bot.event async def on_ready(): - log.info(f'{bot.user.name} has connected to Discord!') - log.info(f'Bot ID: {bot.user.id}') + log.info(f"{bot.user.name} has connected to Discord!") + log.info(f"Bot ID: {bot.user.id}") # Load the starboard cog try: @@ -40,15 +43,19 @@ async def on_ready(): except Exception as e: log.error(f"Error loading StarboardCog: {e}") + async def main(): - TOKEN = os.getenv('DISCORD_TOKEN') + TOKEN = os.getenv("DISCORD_TOKEN") if not TOKEN: - raise ValueError("No token found. Make sure to set DISCORD_TOKEN in your .env file.") + raise ValueError( + "No token found. Make sure to set DISCORD_TOKEN in your .env file." + ) try: await bot.start(TOKEN) except Exception as e: log.exception(f"Error starting bot: {e}") + if __name__ == "__main__": asyncio.run(main()) diff --git a/test_timeout_config.py b/test_timeout_config.py index 6aadb96..1658ab2 100644 --- a/test_timeout_config.py +++ b/test_timeout_config.py @@ -4,6 +4,7 @@ import os # Define the path for the JSON file to store timeout chance TIMEOUT_CONFIG_FILE = os.path.join("data", "timeout_config.json") + def load_timeout_config(): """Load timeout configuration from JSON file""" timeout_chance = 0.005 # Default value @@ -22,16 +23,17 @@ def load_timeout_config(): print(f"Config file does not exist: {TIMEOUT_CONFIG_FILE}") return timeout_chance + def save_timeout_config(timeout_chance): """Save timeout configuration to JSON file""" try: # Ensure data directory exists os.makedirs(os.path.dirname(TIMEOUT_CONFIG_FILE), exist_ok=True) - + config_data = { "timeout_chance": timeout_chance, "target_user_id": 748405715520978965, - "timeout_duration": 60 + "timeout_duration": 60, } with open(TIMEOUT_CONFIG_FILE, "w") as f: json.dump(config_data, f, indent=4) @@ -41,21 +43,22 @@ def save_timeout_config(timeout_chance): print(f"Error saving timeout configuration: {e}") return False + # Test the functionality if __name__ == "__main__": # Load the current config current_chance = load_timeout_config() print(f"Current timeout chance: {current_chance}") - + # Update the timeout chance new_chance = 0.01 # 1% if save_timeout_config(new_chance): print(f"Successfully updated timeout chance to {new_chance}") - + # Load the config again to verify it was saved updated_chance = load_timeout_config() print(f"Updated timeout chance: {updated_chance}") - + # Restore the original value if save_timeout_config(current_chance): print(f"Restored timeout chance to original value: {current_chance}") diff --git a/test_url_parser.py b/test_url_parser.py index dc70716..c49f5d7 100644 --- a/test_url_parser.py +++ b/test_url_parser.py @@ -1,19 +1,25 @@ import re from typing import Optional, Tuple + # Copy of the fixed parse_repo_url function def parse_repo_url(url: str) -> Tuple[Optional[str], Optional[str]]: """Parses a Git repository URL to extract platform and a simplified repo identifier.""" # Fixed regex pattern for GitHub URLs - github_match = re.match(r"^(?:https?://)?(?:www\.)?github\.com/([\w.-]+/[\w.-]+)(?:\.git)?/?$", url) + github_match = re.match( + r"^(?:https?://)?(?:www\.)?github\.com/([\w.-]+/[\w.-]+)(?:\.git)?/?$", url + ) if github_match: return "github", github_match.group(1) - gitlab_match = re.match(r"^(?:https?://)?(?:www\.)?gitlab\.com/([\w.-]+(?:/[\w.-]+)+)(?:\.git)?/?$", url) + gitlab_match = re.match( + r"^(?:https?://)?(?:www\.)?gitlab\.com/([\w.-]+(?:/[\w.-]+)+)(?:\.git)?/?$", url + ) if gitlab_match: return "gitlab", gitlab_match.group(1) return None, None + # Test URLs test_urls = [ "https://github.com/Slipstreamm/discordbot", @@ -23,7 +29,7 @@ test_urls = [ "https://github.com/Slipstreamm/git", "https://gitlab.com/group/project", "https://gitlab.com/group/subgroup/project", - "invalid-url" + "invalid-url", ] # Test each URL diff --git a/test_usage_counters.py b/test_usage_counters.py index c6ce01b..0f40f99 100644 --- a/test_usage_counters.py +++ b/test_usage_counters.py @@ -12,9 +12,10 @@ from dotenv import load_dotenv # Load environment variables load_dotenv() + async def test_usage_counters(): """Test the usage counters functionality.""" - + # Create database connection try: conn_string = f"postgresql://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@{os.getenv('POSTGRES_HOST')}:{os.getenv('POSTGRES_PORT')}/{os.getenv('POSTGRES_SETTINGS_DB')}" @@ -26,61 +27,74 @@ async def test_usage_counters(): try: # Check if the table exists - table_exists = await conn.fetchval(""" + table_exists = await conn.fetchval( + """ SELECT EXISTS ( SELECT FROM information_schema.tables WHERE table_name = 'command_usage_counters' ) - """) - + """ + ) + if table_exists: print("✅ command_usage_counters table exists") - + # Get some sample data - records = await conn.fetch(""" + records = await conn.fetch( + """ SELECT user1_id, user2_id, command_name, usage_count FROM command_usage_counters ORDER BY usage_count DESC LIMIT 10 - """) - + """ + ) + if records: print("\n📊 Top 10 command usages:") print("User1 ID | User2 ID | Command | Count") print("-" * 45) for record in records: - print(f"{record['user1_id']} | {record['user2_id']} | {record['command_name']} | {record['usage_count']}") + print( + f"{record['user1_id']} | {record['user2_id']} | {record['command_name']} | {record['usage_count']}" + ) else: print("📝 No usage data found yet (table is empty)") - + # Get total count - total_count = await conn.fetchval("SELECT COUNT(*) FROM command_usage_counters") + total_count = await conn.fetchval( + "SELECT COUNT(*) FROM command_usage_counters" + ) print(f"\n📈 Total unique user-command combinations: {total_count}") - + else: print("⚠️ command_usage_counters table does not exist yet") print(" It will be created automatically when a command is first used") - + except Exception as e: print(f"❌ Error querying database: {e}") finally: await conn.close() print("🔌 Database connection closed") + async def get_usage_for_users(user1_id: int, user2_id: int): """Get usage statistics for a specific pair of users.""" - + try: conn_string = f"postgresql://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@{os.getenv('POSTGRES_HOST')}:{os.getenv('POSTGRES_PORT')}/{os.getenv('POSTGRES_SETTINGS_DB')}" conn = await asyncpg.connect(conn_string) - - records = await conn.fetch(""" + + records = await conn.fetch( + """ SELECT command_name, usage_count FROM command_usage_counters WHERE user1_id = $1 AND user2_id = $2 ORDER BY usage_count DESC - """, user1_id, user2_id) - + """, + user1_id, + user2_id, + ) + if records: print(f"\n👥 Usage between users {user1_id} and {user2_id}:") print("Command | Count") @@ -89,18 +103,19 @@ async def get_usage_for_users(user1_id: int, user2_id: int): print(f"{record['command_name']} | {record['usage_count']}") else: print(f"📝 No usage data found between users {user1_id} and {user2_id}") - + await conn.close() - + except Exception as e: print(f"❌ Error querying user data: {e}") + if __name__ == "__main__": print("🧪 Testing Usage Counters Functionality") print("=" * 40) - + # Test basic functionality asyncio.run(test_usage_counters()) - + # Example: Get usage for specific users (replace with actual user IDs) # asyncio.run(get_usage_for_users(123456789, 987654321)) diff --git a/tictactoe.py b/tictactoe.py index 5057316..7849e16 100644 --- a/tictactoe.py +++ b/tictactoe.py @@ -1,10 +1,11 @@ import random + class TicTacToe: - def __init__(self, ai_player='O', ai_difficulty=None): + def __init__(self, ai_player="O", ai_difficulty=None): """ Initialize a new Tic Tac Toe game. - + Parameters: ai_player (str): The player that the AI controls ('X' or 'O'). ai_difficulty (str): AI difficulty level. Should be one of: @@ -13,8 +14,8 @@ class TicTacToe: 'minimax' - uses the minimax algorithm for perfect play. None - no AI moves; both players are human. """ - self.board = [' '] * 9 # 3x3 board represented in a list. - self.current_player = 'X' + self.board = [" "] * 9 # 3x3 board represented in a list. + self.current_player = "X" self.winner = None self.game_over = False # If no AI difficulty is provided, no player is controlled by the computer. @@ -23,8 +24,8 @@ class TicTacToe: def reset(self): """Reset the game to its initial state.""" - self.board = [' '] * 9 - self.current_player = 'X' + self.board = [" "] * 9 + self.current_player = "X" self.winner = None self.game_over = False @@ -37,12 +38,12 @@ class TicTacToe: Play one turn of the game. If it is the human's turn, you must supply the 'position' (an integer from 0 to 8). - If it is the AI's turn, you may call play_turn() without a position; the AI + If it is the AI's turn, you may call play_turn() without a position; the AI will pick and execute its move automatically and return the move's index. - + Returns: int: The position where the move was made. - + Raises: ValueError: If the move is invalid or if the game is already over. """ @@ -63,7 +64,7 @@ class TicTacToe: """Internal method to update the board with the current player's move.""" if self.game_over: raise ValueError("Game is over.") - if self.board[position] != ' ': + if self.board[position] != " ": raise ValueError("Invalid move; spot already taken.") self.board[position] = self.current_player self._check_game_over() @@ -72,21 +73,31 @@ class TicTacToe: def _switch_player(self): """Switch the turn to the other player.""" - self.current_player = 'O' if self.current_player == 'X' else 'X' + self.current_player = "O" if self.current_player == "X" else "X" def _check_game_over(self): """Check the board for a win or tie.""" win_combinations = [ - [0, 1, 2], [3, 4, 5], [6, 7, 8], # Rows - [0, 3, 6], [1, 4, 7], [2, 5, 8], # Columns - [0, 4, 8], [2, 4, 6] # Diagonals + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], # Rows + [0, 3, 6], + [1, 4, 7], + [2, 5, 8], # Columns + [0, 4, 8], + [2, 4, 6], # Diagonals ] for combo in win_combinations: - if self.board[combo[0]] == self.board[combo[1]] == self.board[combo[2]] != ' ': + if ( + self.board[combo[0]] + == self.board[combo[1]] + == self.board[combo[2]] + != " " + ): self.winner = self.board[combo[0]] self.game_over = True return - if ' ' not in self.board: + if " " not in self.board: self.game_over = True # It's a tie. def is_game_over(self): @@ -103,15 +114,15 @@ class TicTacToe: def _get_valid_moves(self, board): """Return a list of valid move positions given a board state.""" - return [i for i, spot in enumerate(board) if spot == ' '] + return [i for i, spot in enumerate(board) if spot == " "] def _select_ai_move(self): """Select an AI move based on the chosen difficulty.""" - if self.ai_difficulty == 'random': + if self.ai_difficulty == "random": return self._ai_random_move() - elif self.ai_difficulty == 'rule': + elif self.ai_difficulty == "rule": return self._ai_rule_move() - elif self.ai_difficulty == 'minimax': + elif self.ai_difficulty == "minimax": return self._ai_minimax_move() else: raise ValueError("Invalid AI difficulty.") @@ -138,7 +149,7 @@ class TicTacToe: if self._check_win(board_copy, self.ai_player): return move # 2. Check for a move to block the opponent. - opponent = 'X' if self.ai_player == 'O' else 'O' + opponent = "X" if self.ai_player == "O" else "O" for move in valid_moves: board_copy = self.board[:] board_copy[move] = opponent @@ -161,9 +172,14 @@ class TicTacToe: Returns True if the player has a winning combination. """ win_combinations = [ - [0, 1, 2], [3, 4, 5], [6, 7, 8], - [0, 3, 6], [1, 4, 7], [2, 5, 8], - [0, 4, 8], [2, 4, 6] + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + [0, 3, 6], + [1, 4, 7], + [2, 5, 8], + [0, 4, 8], + [2, 4, 6], ] for combo in win_combinations: if board[combo[0]] == board[combo[1]] == board[combo[2]] == player: @@ -178,28 +194,28 @@ class TicTacToe: def _minimax(self, board, is_maximizing, depth): """ Minimax algorithm to evaluate board positions. - + Parameters: board (list): The current board state. is_maximizing (bool): True if the AI should maximize the score. depth (int): Current depth of recursion (used for score adjustment). - + Returns: tuple: (score, move) where score is the evaluation of the board, and move is the best move to make. """ - opponent = 'X' if self.ai_player == 'O' else 'O' + opponent = "X" if self.ai_player == "O" else "O" if self._check_win(board, self.ai_player): return 10 - depth, None if self._check_win(board, opponent): return depth - 10, None - if ' ' not in board: + if " " not in board: return 0, None valid_moves = self._get_valid_moves(board) if is_maximizing: - best_score = -float('inf') + best_score = -float("inf") best_move = None for move in valid_moves: board_copy = board[:] @@ -210,7 +226,7 @@ class TicTacToe: best_move = move return best_score, best_move else: - best_score = float('inf') + best_score = float("inf") best_move = None for move in valid_moves: board_copy = board[:] @@ -221,6 +237,7 @@ class TicTacToe: best_move = move return best_score, best_move + # Example usage: # from tictactoe import TicTacToe diff --git a/utils.py b/utils.py index 1de5d1f..c379681 100644 --- a/utils.py +++ b/utils.py @@ -3,10 +3,13 @@ import subprocess import sys import time + def reload_script(): """Restart the current Python script.""" try: - result = subprocess.run(["git", "pull"], check=True, text=True, capture_output=True) + result = subprocess.run( + ["git", "pull"], check=True, text=True, capture_output=True + ) print(result.stdout) # Print the output of the git pull command except subprocess.CalledProcessError as e: print(f"Error during git pull: {e.stderr}") diff --git a/webdrivertorso_template.py b/webdrivertorso_template.py index d67dd70..3cce7d0 100644 --- a/webdrivertorso_template.py +++ b/webdrivertorso_template.py @@ -21,6 +21,7 @@ COQUI_AVAILABLE = importlib.util.find_spec("TTS") is not None try: import subprocess import platform + if platform.system() == "Windows": # On Windows, we'll check if the command exists result = subprocess.run(["where", "espeak-ng"], capture_output=True, text=True) @@ -33,6 +34,7 @@ except Exception as e: print(f"Error checking espeak-ng: {e}") ESPEAK_AVAILABLE = False + class JSON: def read(file): with open(f"{file}.json", "r", encoding="utf8") as file: @@ -43,6 +45,7 @@ class JSON: with open(f"{file}.json", "w", encoding="utf8") as file: json.dump(data, file, indent=4) + config_data = JSON.read("config") # SETTINGS # @@ -60,29 +63,74 @@ sample_rate = config_data["SOUND_QUALITY"] tts_enabled = config_data.get("TTS_ENABLED", False) tts_text = config_data.get("TTS_TEXT", "This is a default text for TTS.") tts_provider = config_data.get("TTS_PROVIDER", "gtts") # Options: gtts, pyttsx3, coqui -audio_wave_type = config_data.get("AUDIO_WAVE_TYPE", "sawtooth") # Options: sawtooth, sine, square, triangle, noise, pulse, harmonic +audio_wave_type = config_data.get( + "AUDIO_WAVE_TYPE", "sawtooth" +) # Options: sawtooth, sine, square, triangle, noise, pulse, harmonic slide_duration = config_data.get("SLIDE_DURATION", 1000) # Duration in milliseconds -deform_level = config_data.get("DEFORM_LEVEL", "none") # Options: none, low, medium, high +deform_level = config_data.get( + "DEFORM_LEVEL", "none" +) # Options: none, low, medium, high color_mode = config_data.get("COLOR_MODE", "random") # Options: random, scheme, solid -color_scheme = config_data.get("COLOR_SCHEME", "default") # Placeholder for color schemes +color_scheme = config_data.get( + "COLOR_SCHEME", "default" +) # Placeholder for color schemes solid_color = config_data.get("SOLID_COLOR", "#FFFFFF") # Default solid color -allowed_shapes = config_data.get("ALLOWED_SHAPES", ["rectangle", "ellipse", "polygon", "triangle", "circle"]) +allowed_shapes = config_data.get( + "ALLOWED_SHAPES", ["rectangle", "ellipse", "polygon", "triangle", "circle"] +) wave_vibe = config_data.get("WAVE_VIBE", "calm") # New config option for wave vibe top_left_text_enabled = config_data.get("TOP_LEFT_TEXT_ENABLED", True) -top_left_text_mode = config_data.get("TOP_LEFT_TEXT_MODE", "random") # Options: random, word -words_topic = config_data.get("WORDS_TOPIC", "random") # Options: random, introspective, action, nature, technology +top_left_text_mode = config_data.get( + "TOP_LEFT_TEXT_MODE", "random" +) # Options: random, word +words_topic = config_data.get( + "WORDS_TOPIC", "random" +) # Options: random, introspective, action, nature, technology text_color = config_data.get("TEXT_COLOR", "#000000") text_size = config_data.get("TEXT_SIZE", 0) # 0 means auto-scale text_position = config_data.get("TEXT_POSITION", "top-left") # Get color schemes from config if available -color_schemes_data = config_data.get("COLOR_SCHEMES", { - "pastel": [[255, 182, 193], [176, 224, 230], [240, 230, 140], [221, 160, 221], [152, 251, 152]], - "dark_gritty": [[47, 79, 79], [105, 105, 105], [0, 0, 0], [85, 107, 47], [139, 69, 19]], - "nature": [[34, 139, 34], [107, 142, 35], [46, 139, 87], [32, 178, 170], [154, 205, 50]], - "vibrant": [[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 255, 0], [255, 0, 255]], - "ocean": [[0, 105, 148], [72, 209, 204], [70, 130, 180], [135, 206, 250], [176, 224, 230]] -}) +color_schemes_data = config_data.get( + "COLOR_SCHEMES", + { + "pastel": [ + [255, 182, 193], + [176, 224, 230], + [240, 230, 140], + [221, 160, 221], + [152, 251, 152], + ], + "dark_gritty": [ + [47, 79, 79], + [105, 105, 105], + [0, 0, 0], + [85, 107, 47], + [139, 69, 19], + ], + "nature": [ + [34, 139, 34], + [107, 142, 35], + [46, 139, 87], + [32, 178, 170], + [154, 205, 50], + ], + "vibrant": [ + [255, 0, 0], + [0, 255, 0], + [0, 0, 255], + [255, 255, 0], + [255, 0, 255], + ], + "ocean": [ + [0, 105, 148], + [72, 209, 204], + [70, 130, 180], + [135, 206, 250], + [176, 224, 230], + ], + }, +) # Convert color schemes from lists to tuples for PIL color_schemes = {} @@ -94,22 +142,72 @@ if color_scheme not in color_schemes: color_schemes[color_scheme] = [(128, 128, 128)] # Vibe presets for wave sound -wave_vibes = config_data.get("WAVE_VIBES", { - "calm": {"frequency": 200, "amplitude": 0.3, "modulation": 0.1}, - "eerie": {"frequency": 600, "amplitude": 0.5, "modulation": 0.7}, - "random": {}, # Randomized values will be generated - "energetic": {"frequency": 800, "amplitude": 0.7, "modulation": 0.2}, - "dreamy": {"frequency": 400, "amplitude": 0.4, "modulation": 0.5}, - "chaotic": {"frequency": 1000, "amplitude": 1.0, "modulation": 1.0} -}) +wave_vibes = config_data.get( + "WAVE_VIBES", + { + "calm": {"frequency": 200, "amplitude": 0.3, "modulation": 0.1}, + "eerie": {"frequency": 600, "amplitude": 0.5, "modulation": 0.7}, + "random": {}, # Randomized values will be generated + "energetic": {"frequency": 800, "amplitude": 0.7, "modulation": 0.2}, + "dreamy": {"frequency": 400, "amplitude": 0.4, "modulation": 0.5}, + "chaotic": {"frequency": 1000, "amplitude": 1.0, "modulation": 1.0}, + }, +) # Word topics -word_topics = config_data.get("WORD_TOPICS", { - "introspective": ["reflection", "thought", "solitude", "ponder", "meditation", "introspection", "awareness", "contemplation", "silence", "stillness"], - "action": ["run", "jump", "climb", "race", "fight", "explore", "build", "create", "overcome", "achieve"], - "nature": ["tree", "mountain", "river", "ocean", "flower", "forest", "animal", "sky", "valley", "meadow"], - "technology": ["computer", "robot", "network", "data", "algorithm", "innovation", "digital", "machine", "software", "hardware"] -}) +word_topics = config_data.get( + "WORD_TOPICS", + { + "introspective": [ + "reflection", + "thought", + "solitude", + "ponder", + "meditation", + "introspection", + "awareness", + "contemplation", + "silence", + "stillness", + ], + "action": [ + "run", + "jump", + "climb", + "race", + "fight", + "explore", + "build", + "create", + "overcome", + "achieve", + ], + "nature": [ + "tree", + "mountain", + "river", + "ocean", + "flower", + "forest", + "animal", + "sky", + "valley", + "meadow", + ], + "technology": [ + "computer", + "robot", + "network", + "data", + "algorithm", + "innovation", + "digital", + "machine", + "software", + "hardware", + ], + }, +) # Font scaling based on video size if text_size <= 0: @@ -119,18 +217,22 @@ else: fnt = ImageFont.truetype("./FONT/sys.ttf", font_size) -files = glob.glob('./IMG/*') +files = glob.glob("./IMG/*") for f in files: os.remove(f) print("REMOVED OLD FILES") -def generate_string(length, charset="abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"): + +def generate_string( + length, charset="abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +): result = "" for i in range(length): result += random.choice(charset) return result + def generate_word(theme="random"): if theme == "random" or theme not in word_topics: if random.random() < 0.5 and len(word_topics) > 0: @@ -144,6 +246,7 @@ def generate_word(theme="random"): # Use a word from the specified topic return random.choice(word_topics[theme]) + def generate_wave_sample(x, freq, wave_type, amplitude=1.0): """Generate a sample for different wave types""" t = x / sample_rate @@ -162,17 +265,15 @@ def generate_wave_sample(x, freq, wave_type, amplitude=1.0): return amplitude * (1 if math.sin(2 * math.pi * freq * t) > 0.7 else 0) elif wave_type == "harmonic": return amplitude * ( - math.sin(2 * math.pi * freq * t) * 0.6 + - math.sin(2 * math.pi * freq * 2 * t) * 0.3 + - math.sin(2 * math.pi * freq * 3 * t) * 0.1 + math.sin(2 * math.pi * freq * t) * 0.6 + + math.sin(2 * math.pi * freq * 2 * t) * 0.3 + + math.sin(2 * math.pi * freq * 3 * t) * 0.1 ) else: # Default to sawtooth return amplitude * (2 * (t * freq - math.floor(t * freq + 0.5))) -def append_wave( - freq=None, - duration_milliseconds=1000, - volume=1.0): + +def append_wave(freq=None, duration_milliseconds=1000, volume=1.0): global audio @@ -183,7 +284,9 @@ def append_wave( modulation = random.uniform(0.1, 1.0) else: base_freq = vibe_params["frequency"] - freq = random.uniform(base_freq * 0.7, base_freq * 1.3) if freq is None else freq + freq = ( + random.uniform(base_freq * 0.7, base_freq * 1.3) if freq is None else freq + ) amplitude = vibe_params["amplitude"] * random.uniform(0.7, 1.3) modulation = vibe_params["modulation"] * random.uniform(0.6, 1.4) @@ -191,10 +294,13 @@ def append_wave( for x in range(int(num_samples)): wave_sample = generate_wave_sample(x, freq, audio_wave_type, amplitude) - modulated_sample = wave_sample * (1 + modulation * math.sin(2 * math.pi * 0.5 * x / sample_rate)) + modulated_sample = wave_sample * ( + 1 + modulation * math.sin(2 * math.pi * 0.5 * x / sample_rate) + ) audio.append(volume * modulated_sample) return + def save_wav(file_name): wav_file = wave.open(file_name, "w") @@ -208,22 +314,25 @@ def save_wav(file_name): wav_file.setparams((nchannels, sampwidth, sample_rate, nframes, comptype, compname)) for sample in audio: - wav_file.writeframes(struct.pack('h', int(sample * 32767.0))) + wav_file.writeframes(struct.pack("h", int(sample * 32767.0))) wav_file.close() return + # Generate TTS audio using different providers def generate_tts_audio(text, output_file): if tts_provider == "gtts" and GTTS_AVAILABLE: from gtts import gTTS - tts = gTTS(text=text, lang='en') + + tts = gTTS(text=text, lang="en") tts.save(output_file) print(f"Google TTS audio saved to {output_file}") return True elif tts_provider == "pyttsx3" and PYTTSX3_AVAILABLE: import pyttsx3 + engine = pyttsx3.init() engine.save_to_file(text, output_file) engine.runAndWait() @@ -232,6 +341,7 @@ def generate_tts_audio(text, output_file): elif tts_provider == "coqui" and COQUI_AVAILABLE: try: from TTS.api import TTS + tts = TTS("tts_models/en/ljspeech/tacotron2-DDC") tts.tts_to_file(text=text, file_path=output_file) print(f"Coqui TTS audio saved to {output_file}") @@ -279,6 +389,7 @@ def generate_tts_audio(text, output_file): print(f"TTS provider {tts_provider} not available. Falling back to no TTS.") return False + if tts_enabled: tts_audio_file = "./SOUND/tts_output.mp3" tts_success = generate_tts_audio(tts_text, tts_audio_file) @@ -313,39 +424,56 @@ for xyz in range(AMOUNT): y2 = random.randint(minH, maxH) if color_mode == "random": - color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) + color = ( + random.randint(0, 255), + random.randint(0, 255), + random.randint(0, 255), + ) elif color_mode == "scheme": scheme_colors = color_schemes.get(color_scheme, [(128, 128, 128)]) color = random.choice(scheme_colors) elif color_mode == "solid": try: - color = tuple(int(solid_color.lstrip("#")[i:i + 2], 16) for i in (0, 2, 4)) + color = tuple( + int(solid_color.lstrip("#")[i : i + 2], 16) for i in (0, 2, 4) + ) except: color = (255, 255, 255) # Default to white if invalid hex if shape_type == "rectangle": - img1.rectangle([(x1, y1), (x1 + x2, y1 + y2)], fill=color, outline=color) + img1.rectangle( + [(x1, y1), (x1 + x2, y1 + y2)], fill=color, outline=color + ) elif shape_type == "ellipse": img1.ellipse([(x1, y1), (x1 + x2, y1 + y2)], fill=color, outline=color) elif shape_type == "polygon": num_points = random.randint(3, 6) - points = [(random.randint(0, w), random.randint(0, h)) for _ in range(num_points)] + points = [ + (random.randint(0, w), random.randint(0, h)) + for _ in range(num_points) + ] img1.polygon(points, fill=color, outline=color) elif shape_type == "triangle": points = [ (x1, y1), (x1 + random.randint(-x2, x2), y1 + y2), - (x1 + x2, y1 + random.randint(-y2, y2)) + (x1 + x2, y1 + random.randint(-y2, y2)), ] img1.polygon(points, fill=color, outline=color) elif shape_type == "circle": radius = min(x2, y2) // 2 - img1.ellipse([(x1 - radius, y1 - radius), (x1 + radius, y1 + radius)], fill=color, outline=color) + img1.ellipse( + [(x1 - radius, y1 - radius), (x1 + radius, y1 + radius)], + fill=color, + outline=color, + ) # Parse text color try: if text_color.startswith("#"): - parsed_text_color = tuple(int(text_color.lstrip("#")[i:i + 2], 16) for i in (0, 2, 4)) + parsed_text_color = tuple( + int(text_color.lstrip("#")[i : i + 2], 16) for i in (0, 2, 4) + ) else: # Named colors (basic support) color_map = { @@ -357,7 +485,7 @@ for xyz in range(AMOUNT): "yellow": (255, 255, 0), "purple": (128, 0, 128), "orange": (255, 165, 0), - "gray": (128, 128, 128) + "gray": (128, 128, 128), } parsed_text_color = color_map.get(text_color.lower(), (0, 0, 0)) except: @@ -365,38 +493,86 @@ for xyz in range(AMOUNT): if top_left_text_enabled: if top_left_text_mode == "random": - random_top_left_text = generate_string(30, charset="abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()_+-=[]{}|;:',.<>?/") + random_top_left_text = generate_string( + 30, + charset="abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()_+-=[]{}|;:',.<>?/", + ) elif top_left_text_mode == "word": random_top_left_text = generate_word(words_topic) else: random_top_left_text = "" # Position text based on text_position setting - if text_position == "top-left" or text_position == "random" and random.random() < 0.2: - img1.text((10, 10), random_top_left_text, font=fnt, fill=parsed_text_color) - elif text_position == "top-right" or text_position == "random" and random.random() < 0.2: + if ( + text_position == "top-left" + or text_position == "random" + and random.random() < 0.2 + ): + img1.text( + (10, 10), random_top_left_text, font=fnt, fill=parsed_text_color + ) + elif ( + text_position == "top-right" + or text_position == "random" + and random.random() < 0.2 + ): text_width = img1.textlength(random_top_left_text, font=fnt) - img1.text((w - text_width - 10, 10), random_top_left_text, font=fnt, fill=parsed_text_color) - elif text_position == "bottom-left" or text_position == "random" and random.random() < 0.2: - img1.text((10, h - font_size - 10), random_top_left_text, font=fnt, fill=parsed_text_color) - elif text_position == "bottom-right" or text_position == "random" and random.random() < 0.2: + img1.text( + (w - text_width - 10, 10), + random_top_left_text, + font=fnt, + fill=parsed_text_color, + ) + elif ( + text_position == "bottom-left" + or text_position == "random" + and random.random() < 0.2 + ): + img1.text( + (10, h - font_size - 10), + random_top_left_text, + font=fnt, + fill=parsed_text_color, + ) + elif ( + text_position == "bottom-right" + or text_position == "random" + and random.random() < 0.2 + ): text_width = img1.textlength(random_top_left_text, font=fnt) - img1.text((w - text_width - 10, h - font_size - 10), random_top_left_text, font=fnt, fill=parsed_text_color) + img1.text( + (w - text_width - 10, h - font_size - 10), + random_top_left_text, + font=fnt, + fill=parsed_text_color, + ) elif text_position == "center" or text_position == "random": text_width = img1.textlength(random_top_left_text, font=fnt) - img1.text((w//2 - text_width//2, h//2 - font_size//2), random_top_left_text, font=fnt, fill=parsed_text_color) + img1.text( + (w // 2 - text_width // 2, h // 2 - font_size // 2), + random_top_left_text, + font=fnt, + fill=parsed_text_color, + ) # Add video name to bottom-left corner video_name_text = f"{video_name}.mp4" video_name_width = img1.textlength(video_name_text, font=fnt) video_name_height = font_size - img1.text((10, h - video_name_height - 10), video_name_text, font=fnt, fill=parsed_text_color) + img1.text( + (10, h - video_name_height - 10), + video_name_text, + font=fnt, + fill=parsed_text_color, + ) # Move slide info text to the top right corner slide_text = f"Slide {i}" text_width = img1.textlength(slide_text, font=fnt) text_height = font_size - img1.text((w - text_width - 10, 10), slide_text, font=fnt, fill=parsed_text_color) + img1.text( + (w - text_width - 10, 10), slide_text, font=fnt, fill=parsed_text_color + ) img.save(f"./IMG/{str(i).zfill(4)}_{random.randint(1000, 9999)}.png") @@ -427,10 +603,13 @@ for xyz in range(AMOUNT): print("AUDIO GENERATED") - image_folder = './IMG' + image_folder = "./IMG" fps = 1000 / slide_duration # Ensure fps is precise to handle timing discrepancies - image_files = sorted([f for f in glob.glob(f"{image_folder}/*.png")], key=lambda x: int(os.path.basename(x).split('_')[0])) + image_files = sorted( + [f for f in glob.glob(f"{image_folder}/*.png")], + key=lambda x: int(os.path.basename(x).split("_")[0]), + ) # Ensure all frames have the same dimensions frames = [] @@ -438,19 +617,19 @@ for xyz in range(AMOUNT): for idx, file in enumerate(image_files): frame = np.array(Image.open(file)) if frame.shape != first_frame.shape: - print(f"Frame {idx} has inconsistent dimensions: {frame.shape} vs {first_frame.shape}") + print( + f"Frame {idx} has inconsistent dimensions: {frame.shape} vs {first_frame.shape}" + ) frame = np.resize(frame, first_frame.shape) # Resize if necessary frames.append(frame) print("Starting video compilation...") - clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip( - frames, fps=fps - ) + clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(frames, fps=fps) clip.write_videofile( - f'./OUTPUT/{video_name}.mp4', + f"./OUTPUT/{video_name}.mp4", audio="./SOUND/output.m4a", codec="libx264", - audio_codec="aac" + audio_codec="aac", ) print("Video compilation finished successfully!") diff --git a/wheatley/__init__.py b/wheatley/__init__.py index 6cd96f1..93106b0 100644 --- a/wheatley/__init__.py +++ b/wheatley/__init__.py @@ -5,4 +5,4 @@ from .cog import setup # This makes "from wheatley import setup" work -__all__ = ['setup'] +__all__ = ["setup"] diff --git a/wheatley/analysis.py b/wheatley/analysis.py index 73b633b..70892f7 100644 --- a/wheatley/analysis.py +++ b/wheatley/analysis.py @@ -9,27 +9,34 @@ logger = logging.getLogger(__name__) # Relative imports from .config import ( - LEARNING_RATE, TOPIC_UPDATE_INTERVAL, - TOPIC_RELEVANCE_DECAY, MAX_ACTIVE_TOPICS, SENTIMENT_DECAY_RATE, - EMOTION_KEYWORDS, EMOJI_SENTIMENT # Import necessary configs + LEARNING_RATE, + TOPIC_UPDATE_INTERVAL, + TOPIC_RELEVANCE_DECAY, + MAX_ACTIVE_TOPICS, + SENTIMENT_DECAY_RATE, + EMOTION_KEYWORDS, + EMOJI_SENTIMENT, # Import necessary configs ) + # Removed imports for BASELINE_PERSONALITY, REFLECTION_INTERVAL_SECONDS, GOAL related configs if TYPE_CHECKING: - from .cog import WheatleyCog # Updated type hint + from .cog import WheatleyCog # Updated type hint # --- Analysis Functions --- # Note: These functions need the 'cog' instance passed to access state like caches, etc. -async def analyze_conversation_patterns(cog: 'WheatleyCog'): # Updated type hint + +async def analyze_conversation_patterns(cog: "WheatleyCog"): # Updated type hint """Analyzes recent conversations to identify patterns and learn from them""" print("Analyzing conversation patterns and updating topics...") try: # Update conversation topics first await update_conversation_topics(cog) - for channel_id, messages in cog.message_cache['by_channel'].items(): - if len(messages) < 10: continue + for channel_id, messages in cog.message_cache["by_channel"].items(): + if len(messages) < 10: + continue # Pattern extraction might be less useful without personality/goals, but kept for now # channel_patterns = extract_conversation_patterns(cog, messages) # Pass cog @@ -40,33 +47,42 @@ async def analyze_conversation_patterns(cog: 'WheatleyCog'): # Updated type hint # combined_patterns = combined_patterns[-MAX_PATTERNS_PER_CHANNEL:] # cog.conversation_patterns[channel_id] = combined_patterns - analyze_conversation_dynamics(cog, channel_id, messages) # Pass cog + analyze_conversation_dynamics(cog, channel_id, messages) # Pass cog - update_user_preferences(cog) # Pass cog - Note: This might need adjustment as it relies on traits we removed + update_user_preferences( + cog + ) # Pass cog - Note: This might need adjustment as it relies on traits we removed except Exception as e: print(f"Error analyzing conversation patterns: {e}") traceback.print_exc() -async def update_conversation_topics(cog: 'WheatleyCog'): # Updated type hint + +async def update_conversation_topics(cog: "WheatleyCog"): # Updated type hint """Updates the active topics for each channel based on recent messages""" try: - for channel_id, messages in cog.message_cache['by_channel'].items(): - if len(messages) < 5: continue + for channel_id, messages in cog.message_cache["by_channel"].items(): + if len(messages) < 5: + continue channel_topics = cog.active_topics[channel_id] now = time.time() - if now - channel_topics["last_update"] < TOPIC_UPDATE_INTERVAL: continue + if now - channel_topics["last_update"] < TOPIC_UPDATE_INTERVAL: + continue recent_messages = list(messages)[-30:] - topics = identify_conversation_topics(cog, recent_messages) # Pass cog - if not topics: continue + topics = identify_conversation_topics(cog, recent_messages) # Pass cog + if not topics: + continue old_topics = channel_topics["topics"] - for topic in old_topics: topic["score"] *= (1 - TOPIC_RELEVANCE_DECAY) + for topic in old_topics: + topic["score"] *= 1 - TOPIC_RELEVANCE_DECAY for new_topic in topics: - existing = next((t for t in old_topics if t["topic"] == new_topic["topic"]), None) + existing = next( + (t for t in old_topics if t["topic"] == new_topic["topic"]), None + ) if existing: existing["score"] = max(existing["score"], new_topic["score"]) existing["related_terms"] = new_topic["related_terms"] @@ -81,13 +97,22 @@ async def update_conversation_topics(cog: 'WheatleyCog'): # Updated type hint old_topics = old_topics[:MAX_ACTIVE_TOPICS] if old_topics and channel_topics["topics"] != old_topics: - if not channel_topics["topic_history"] or set(t["topic"] for t in old_topics) != set(t["topic"] for t in channel_topics["topics"]): - channel_topics["topic_history"].append({ - "topics": [{"topic": t["topic"], "score": t["score"]} for t in old_topics], - "timestamp": now - }) + if not channel_topics["topic_history"] or set( + t["topic"] for t in old_topics + ) != set(t["topic"] for t in channel_topics["topics"]): + channel_topics["topic_history"].append( + { + "topics": [ + {"topic": t["topic"], "score": t["score"]} + for t in old_topics + ], + "timestamp": now, + } + ) if len(channel_topics["topic_history"]) > 10: - channel_topics["topic_history"] = channel_topics["topic_history"][-10:] + channel_topics["topic_history"] = channel_topics[ + "topic_history" + ][-10:] # User topic interest tracking might be less relevant without proactive interest triggers, but kept for now for msg in recent_messages: @@ -97,150 +122,356 @@ async def update_conversation_topics(cog: 'WheatleyCog'): # Updated type hint topic_text = topic["topic"].lower() if topic_text in content: user_interests = channel_topics["user_topic_interests"][user_id] - existing = next((i for i in user_interests if i["topic"] == topic["topic"]), None) + existing = next( + (i for i in user_interests if i["topic"] == topic["topic"]), + None, + ) if existing: - existing["score"] = existing["score"] * 0.8 + topic["score"] * 0.2 + existing["score"] = ( + existing["score"] * 0.8 + topic["score"] * 0.2 + ) existing["last_mentioned"] = now else: - user_interests.append({ - "topic": topic["topic"], "score": topic["score"] * 0.5, - "first_mentioned": now, "last_mentioned": now - }) + user_interests.append( + { + "topic": topic["topic"], + "score": topic["score"] * 0.5, + "first_mentioned": now, + "last_mentioned": now, + } + ) channel_topics["topics"] = old_topics channel_topics["last_update"] = now if old_topics: - topic_str = ", ".join([f"{t['topic']} ({t['score']:.2f})" for t in old_topics[:3]]) + topic_str = ", ".join( + [f"{t['topic']} ({t['score']:.2f})" for t in old_topics[:3]] + ) print(f"Updated topics for channel {channel_id}: {topic_str}") except Exception as e: print(f"Error updating conversation topics: {e}") traceback.print_exc() -def analyze_conversation_dynamics(cog: 'WheatleyCog', channel_id: int, messages: List[Dict[str, Any]]): # Updated type hint + +def analyze_conversation_dynamics( + cog: "WheatleyCog", channel_id: int, messages: List[Dict[str, Any]] +): # Updated type hint """Analyzes conversation dynamics like response times, message lengths, etc.""" - if len(messages) < 5: return + if len(messages) < 5: + return try: response_times = [] response_map = defaultdict(int) message_lengths = defaultdict(list) question_answer_pairs = [] - import datetime # Import here + import datetime # Import here for i in range(1, len(messages)): - current_msg = messages[i]; prev_msg = messages[i-1] - if current_msg["author"]["id"] == prev_msg["author"]["id"]: continue + current_msg = messages[i] + prev_msg = messages[i - 1] + if current_msg["author"]["id"] == prev_msg["author"]["id"]: + continue try: - current_time = datetime.datetime.fromisoformat(current_msg["created_at"]) + current_time = datetime.datetime.fromisoformat( + current_msg["created_at"] + ) prev_time = datetime.datetime.fromisoformat(prev_msg["created_at"]) delta_seconds = (current_time - prev_time).total_seconds() - if 0 < delta_seconds < 300: response_times.append(delta_seconds) - except (ValueError, TypeError): pass + if 0 < delta_seconds < 300: + response_times.append(delta_seconds) + except (ValueError, TypeError): + pass - responder = current_msg["author"]["id"]; respondee = prev_msg["author"]["id"] + responder = current_msg["author"]["id"] + respondee = prev_msg["author"]["id"] response_map[f"{responder}:{respondee}"] += 1 message_lengths[responder].append(len(current_msg["content"])) if prev_msg["content"].endswith("?"): - question_answer_pairs.append({ - "question": prev_msg["content"], "answer": current_msg["content"], - "question_author": prev_msg["author"]["id"], "answer_author": current_msg["author"]["id"] - }) + question_answer_pairs.append( + { + "question": prev_msg["content"], + "answer": current_msg["content"], + "question_author": prev_msg["author"]["id"], + "answer_author": current_msg["author"]["id"], + } + ) - avg_response_time = sum(response_times) / len(response_times) if response_times else 0 - top_responders = sorted(response_map.items(), key=lambda x: x[1], reverse=True)[:3] - avg_message_lengths = {uid: sum(ls)/len(ls) if ls else 0 for uid, ls in message_lengths.items()} + avg_response_time = ( + sum(response_times) / len(response_times) if response_times else 0 + ) + top_responders = sorted(response_map.items(), key=lambda x: x[1], reverse=True)[ + :3 + ] + avg_message_lengths = { + uid: sum(ls) / len(ls) if ls else 0 for uid, ls in message_lengths.items() + } dynamics = { - "avg_response_time": avg_response_time, "top_responders": top_responders, - "avg_message_lengths": avg_message_lengths, "question_answer_count": len(question_answer_pairs), - "last_updated": time.time() + "avg_response_time": avg_response_time, + "top_responders": top_responders, + "avg_message_lengths": avg_message_lengths, + "question_answer_count": len(question_answer_pairs), + "last_updated": time.time(), } - if not hasattr(cog, 'conversation_dynamics'): cog.conversation_dynamics = {} + if not hasattr(cog, "conversation_dynamics"): + cog.conversation_dynamics = {} cog.conversation_dynamics[channel_id] = dynamics - adapt_to_conversation_dynamics(cog, channel_id, dynamics) # Pass cog + adapt_to_conversation_dynamics(cog, channel_id, dynamics) # Pass cog - except Exception as e: print(f"Error analyzing conversation dynamics: {e}") + except Exception as e: + print(f"Error analyzing conversation dynamics: {e}") -def adapt_to_conversation_dynamics(cog: 'WheatleyCog', channel_id: int, dynamics: Dict[str, Any]): # Updated type hint + +def adapt_to_conversation_dynamics( + cog: "WheatleyCog", channel_id: int, dynamics: Dict[str, Any] +): # Updated type hint """Adapts bot behavior based on observed conversation dynamics.""" # Note: This function previously adapted personality traits. # It might be removed or repurposed for Wheatley if needed. # For now, it calculates factors but doesn't apply them directly to a removed personality system. try: if dynamics["avg_response_time"] > 0: - if not hasattr(cog, 'channel_response_timing'): cog.channel_response_timing = {} - response_time_factor = max(0.7, min(1.0, dynamics["avg_response_time"] / 10)) + if not hasattr(cog, "channel_response_timing"): + cog.channel_response_timing = {} + response_time_factor = max( + 0.7, min(1.0, dynamics["avg_response_time"] / 10) + ) cog.channel_response_timing[channel_id] = response_time_factor if dynamics["avg_message_lengths"]: all_lengths = [ls for ls in dynamics["avg_message_lengths"].values()] if all_lengths: avg_length = sum(all_lengths) / len(all_lengths) - if not hasattr(cog, 'channel_message_length'): cog.channel_message_length = {} + if not hasattr(cog, "channel_message_length"): + cog.channel_message_length = {} length_factor = min(avg_length / 200, 1.0) cog.channel_message_length[channel_id] = length_factor if dynamics["question_answer_count"] > 0: - if not hasattr(cog, 'channel_qa_responsiveness'): cog.channel_qa_responsiveness = {} + if not hasattr(cog, "channel_qa_responsiveness"): + cog.channel_qa_responsiveness = {} qa_factor = min(0.9, 0.5 + (dynamics["question_answer_count"] / 20) * 0.4) cog.channel_qa_responsiveness[channel_id] = qa_factor - except Exception as e: print(f"Error adapting to conversation dynamics: {e}") + except Exception as e: + print(f"Error adapting to conversation dynamics: {e}") -def extract_conversation_patterns(cog: 'WheatleyCog', messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: # Updated type hint + +def extract_conversation_patterns( + cog: "WheatleyCog", messages: List[Dict[str, Any]] +) -> List[Dict[str, Any]]: # Updated type hint """Extract patterns from a sequence of messages""" # This function might be less useful without personality/goal systems, kept for potential future use. patterns = [] - if len(messages) < 5: return patterns - import datetime # Import here + if len(messages) < 5: + return patterns + import datetime # Import here for i in range(len(messages) - 2): pattern = { "type": "message_sequence", "messages": [ - {"author_type": "user" if not messages[i]["author"]["bot"] else "bot", "content_sample": messages[i]["content"][:50]}, - {"author_type": "user" if not messages[i+1]["author"]["bot"] else "bot", "content_sample": messages[i+1]["content"][:50]}, - {"author_type": "user" if not messages[i+2]["author"]["bot"] else "bot", "content_sample": messages[i+2]["content"][:50]} - ], "timestamp": datetime.datetime.now().isoformat() + { + "author_type": ( + "user" if not messages[i]["author"]["bot"] else "bot" + ), + "content_sample": messages[i]["content"][:50], + }, + { + "author_type": ( + "user" if not messages[i + 1]["author"]["bot"] else "bot" + ), + "content_sample": messages[i + 1]["content"][:50], + }, + { + "author_type": ( + "user" if not messages[i + 2]["author"]["bot"] else "bot" + ), + "content_sample": messages[i + 2]["content"][:50], + }, + ], + "timestamp": datetime.datetime.now().isoformat(), } patterns.append(pattern) - topics = identify_conversation_topics(cog, messages) # Pass cog - if topics: patterns.append({"type": "topic_pattern", "topics": topics, "timestamp": datetime.datetime.now().isoformat()}) + topics = identify_conversation_topics(cog, messages) # Pass cog + if topics: + patterns.append( + { + "type": "topic_pattern", + "topics": topics, + "timestamp": datetime.datetime.now().isoformat(), + } + ) - user_interactions = analyze_user_interactions(cog, messages) # Pass cog - if user_interactions: patterns.append({"type": "user_interaction", "interactions": user_interactions, "timestamp": datetime.datetime.now().isoformat()}) + user_interactions = analyze_user_interactions(cog, messages) # Pass cog + if user_interactions: + patterns.append( + { + "type": "user_interaction", + "interactions": user_interactions, + "timestamp": datetime.datetime.now().isoformat(), + } + ) return patterns -def identify_conversation_topics(cog: 'WheatleyCog', messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: # Updated type hint + +def identify_conversation_topics( + cog: "WheatleyCog", messages: List[Dict[str, Any]] +) -> List[Dict[str, Any]]: # Updated type hint """Identify potential topics from conversation messages.""" - if not messages or len(messages) < 3: return [] + if not messages or len(messages) < 3: + return [] all_text = " ".join([msg["content"] for msg in messages]) - stopwords = { # Expanded stopwords - "the", "and", "is", "in", "to", "a", "of", "for", "that", "this", "it", "with", "on", "as", "be", "at", "by", "an", "or", "but", "if", "from", "when", "where", "how", "all", "any", "both", "each", "few", "more", "most", "some", "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "can", "will", "just", "should", "now", "also", "like", "even", "because", "way", "who", "what", "yeah", "yes", "no", "nah", "lol", "lmao", "haha", "hmm", "um", "uh", "oh", "ah", "ok", "okay", "dont", "don't", "doesnt", "doesn't", "didnt", "didn't", "cant", "can't", "im", "i'm", "ive", "i've", "youre", "you're", "youve", "you've", "hes", "he's", "shes", "she's", "its", "it's", "were", "we're", "weve", "we've", "theyre", "they're", "theyve", "they've", "thats", "that's", "whats", "what's", "whos", "who's", "gonna", "gotta", "kinda", "sorta", "wheatley" # Added wheatley, removed gurt + stopwords = { # Expanded stopwords + "the", + "and", + "is", + "in", + "to", + "a", + "of", + "for", + "that", + "this", + "it", + "with", + "on", + "as", + "be", + "at", + "by", + "an", + "or", + "but", + "if", + "from", + "when", + "where", + "how", + "all", + "any", + "both", + "each", + "few", + "more", + "most", + "some", + "such", + "no", + "nor", + "not", + "only", + "own", + "same", + "so", + "than", + "too", + "very", + "can", + "will", + "just", + "should", + "now", + "also", + "like", + "even", + "because", + "way", + "who", + "what", + "yeah", + "yes", + "no", + "nah", + "lol", + "lmao", + "haha", + "hmm", + "um", + "uh", + "oh", + "ah", + "ok", + "okay", + "dont", + "don't", + "doesnt", + "doesn't", + "didnt", + "didn't", + "cant", + "can't", + "im", + "i'm", + "ive", + "i've", + "youre", + "you're", + "youve", + "you've", + "hes", + "he's", + "shes", + "she's", + "its", + "it's", + "were", + "we're", + "weve", + "we've", + "theyre", + "they're", + "theyve", + "they've", + "thats", + "that's", + "whats", + "what's", + "whos", + "who's", + "gonna", + "gotta", + "kinda", + "sorta", + "wheatley", # Added wheatley, removed gurt } def extract_ngrams(text, n_values=[1, 2, 3]): - words = re.findall(r'\b\w+\b', text.lower()) - filtered_words = [word for word in words if word not in stopwords and len(word) > 2] + words = re.findall(r"\b\w+\b", text.lower()) + filtered_words = [ + word for word in words if word not in stopwords and len(word) > 2 + ] all_ngrams = [] - for n in n_values: all_ngrams.extend([' '.join(filtered_words[i:i+n]) for i in range(len(filtered_words)-n+1)]) + for n in n_values: + all_ngrams.extend( + [ + " ".join(filtered_words[i : i + n]) + for i in range(len(filtered_words) - n + 1) + ] + ) return all_ngrams all_ngrams = extract_ngrams(all_text) ngram_counts = defaultdict(int) - for ngram in all_ngrams: ngram_counts[ngram] += 1 + for ngram in all_ngrams: + ngram_counts[ngram] += 1 min_count = 2 if len(messages) > 10 else 1 - filtered_ngrams = {ngram: count for ngram, count in ngram_counts.items() if count >= min_count} + filtered_ngrams = { + ngram: count for ngram, count in ngram_counts.items() if count >= min_count + } total_messages = len(messages) ngram_scores = {} for ngram, count in filtered_ngrams.items(): # Calculate score based on frequency, length, and spread across messages message_count = sum(1 for msg in messages if ngram in msg["content"].lower()) - spread_factor = (message_count / total_messages) ** 0.5 # Less emphasis on spread - length_bonus = len(ngram.split()) * 0.1 # Slight bonus for longer ngrams + spread_factor = ( + message_count / total_messages + ) ** 0.5 # Less emphasis on spread + length_bonus = len(ngram.split()) * 0.1 # Slight bonus for longer ngrams # Adjust importance calculation importance = (count * (0.4 + spread_factor)) + length_bonus ngram_scores[ngram] = importance @@ -259,107 +490,237 @@ def identify_conversation_topics(cog: 'WheatleyCog', messages: List[Dict[str, An break if not is_subgram and ngram not in temp_processed: ngrams_to_consider.append((ngram, score)) - temp_processed.add(ngram) # Avoid adding duplicates if logic changes + temp_processed.add(ngram) # Avoid adding duplicates if logic changes # Now process the filtered ngrams - sorted_ngrams = ngrams_to_consider # Use the filtered list + sorted_ngrams = ngrams_to_consider # Use the filtered list - for ngram, score in sorted_ngrams[:10]: # Consider top 10 potential topics after filtering - if ngram in processed_ngrams: continue + for ngram, score in sorted_ngrams[ + :10 + ]: # Consider top 10 potential topics after filtering + if ngram in processed_ngrams: + continue related_terms = [] # Find related terms (sub-ngrams or overlapping ngrams from the original sorted list) - for other_ngram, other_score in sorted_by_score: # Search in original sorted list for relations - if other_ngram == ngram or other_ngram in processed_ngrams: continue - ngram_words = set(ngram.split()); other_words = set(other_ngram.split()) + for ( + other_ngram, + other_score, + ) in sorted_by_score: # Search in original sorted list for relations + if other_ngram == ngram or other_ngram in processed_ngrams: + continue + ngram_words = set(ngram.split()) + other_words = set(other_ngram.split()) # Check for overlap or if one is a sub-string (more lenient relation) if ngram_words.intersection(other_words) or other_ngram in ngram: related_terms.append({"term": other_ngram, "score": other_score}) # Don't mark related terms as fully processed here unless they are direct sub-ngrams # processed_ngrams.add(other_ngram) - if len(related_terms) >= 3: break # Limit related terms shown + if len(related_terms) >= 3: + break # Limit related terms shown processed_ngrams.add(ngram) - topic_entry = {"topic": ngram, "score": score, "related_terms": related_terms, "message_count": sum(1 for msg in messages if ngram in msg["content"].lower())} + topic_entry = { + "topic": ngram, + "score": score, + "related_terms": related_terms, + "message_count": sum( + 1 for msg in messages if ngram in msg["content"].lower() + ), + } topics.append(topic_entry) - if len(topics) >= MAX_ACTIVE_TOPICS: break # Use config for max topics + if len(topics) >= MAX_ACTIVE_TOPICS: + break # Use config for max topics # Simple sentiment analysis for topics - positive_words = {"good", "great", "awesome", "amazing", "excellent", "love", "like", "best", "better", "nice", "cool"} + positive_words = { + "good", + "great", + "awesome", + "amazing", + "excellent", + "love", + "like", + "best", + "better", + "nice", + "cool", + } # Removed the second loop that seemed redundant # sorted_ngrams = sorted(ngram_scores.items(), key=lambda x: x[1], reverse=True) # for ngram, score in sorted_ngrams[:15]: ... # Simple sentiment analysis for topics (applied to the already selected topics) - positive_words = {"good", "great", "awesome", "amazing", "excellent", "love", "like", "best", "better", "nice", "cool", "happy", "glad"} - negative_words = {"bad", "terrible", "awful", "worst", "hate", "dislike", "sucks", "stupid", "boring", "annoying", "sad", "upset", "angry"} + positive_words = { + "good", + "great", + "awesome", + "amazing", + "excellent", + "love", + "like", + "best", + "better", + "nice", + "cool", + "happy", + "glad", + } + negative_words = { + "bad", + "terrible", + "awful", + "worst", + "hate", + "dislike", + "sucks", + "stupid", + "boring", + "annoying", + "sad", + "upset", + "angry", + } for topic in topics: - topic_messages = [msg["content"] for msg in messages if topic["topic"] in msg["content"].lower()] + topic_messages = [ + msg["content"] + for msg in messages + if topic["topic"] in msg["content"].lower() + ] topic_text = " ".join(topic_messages).lower() positive_count = sum(1 for word in positive_words if word in topic_text) negative_count = sum(1 for word in negative_words if word in topic_text) - if positive_count > negative_count: topic["sentiment"] = "positive" - elif negative_count > positive_count: topic["sentiment"] = "negative" - else: topic["sentiment"] = "neutral" + if positive_count > negative_count: + topic["sentiment"] = "positive" + elif negative_count > positive_count: + topic["sentiment"] = "negative" + else: + topic["sentiment"] = "neutral" return topics -def analyze_user_interactions(cog: 'WheatleyCog', messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: # Updated type hint + +def analyze_user_interactions( + cog: "WheatleyCog", messages: List[Dict[str, Any]] +) -> List[Dict[str, Any]]: # Updated type hint """Analyze interactions between users in the conversation""" interactions = [] response_map = defaultdict(int) for i in range(1, len(messages)): - current_msg = messages[i]; prev_msg = messages[i-1] - if current_msg["author"]["id"] == prev_msg["author"]["id"]: continue - responder = current_msg["author"]["id"]; respondee = prev_msg["author"]["id"] + current_msg = messages[i] + prev_msg = messages[i - 1] + if current_msg["author"]["id"] == prev_msg["author"]["id"]: + continue + responder = current_msg["author"]["id"] + respondee = prev_msg["author"]["id"] key = f"{responder}:{respondee}" response_map[key] += 1 for key, count in response_map.items(): if count > 1: responder, respondee = key.split(":") - interactions.append({"responder": responder, "respondee": respondee, "count": count}) + interactions.append( + {"responder": responder, "respondee": respondee, "count": count} + ) return interactions -def update_user_preferences(cog: 'WheatleyCog'): # Updated type hint + +def update_user_preferences(cog: "WheatleyCog"): # Updated type hint """Update stored user preferences based on observed interactions""" # Note: This function previously updated preferences based on Gurt's personality. # It might be removed or significantly simplified for Wheatley. # Kept for now, but its effect might be minimal without personality traits. - for user_id, messages in cog.message_cache['by_user'].items(): - if len(messages) < 5: continue - emoji_count = 0; slang_count = 0; avg_length = 0 + for user_id, messages in cog.message_cache["by_user"].items(): + if len(messages) < 5: + continue + emoji_count = 0 + slang_count = 0 + avg_length = 0 for msg in messages: content = msg["content"] - emoji_count += len(re.findall(r'[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F700-\U0001F77F\U0001F780-\U0001F7FF\U0001F800-\U0001F8FF\U0001F900-\U0001F9FF\U0001FA00-\U0001FA6F\U0001FA70-\U0001FAFF\U00002702-\U000027B0\U000024C2-\U0001F251]', content)) - slang_words = ["ngl", "icl", "pmo", "ts", "bro", "vro", "bruh", "tuff", "kevin", "mate", "chap", "bollocks"] # Added Wheatley-ish slang + emoji_count += len( + re.findall( + r"[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F700-\U0001F77F\U0001F780-\U0001F7FF\U0001F800-\U0001F8FF\U0001F900-\U0001F9FF\U0001FA00-\U0001FA6F\U0001FA70-\U0001FAFF\U00002702-\U000027B0\U000024C2-\U0001F251]", + content, + ) + ) + slang_words = [ + "ngl", + "icl", + "pmo", + "ts", + "bro", + "vro", + "bruh", + "tuff", + "kevin", + "mate", + "chap", + "bollocks", + ] # Added Wheatley-ish slang for word in slang_words: - if re.search(r'\b' + word + r'\b', content.lower()): slang_count += 1 + if re.search(r"\b" + word + r"\b", content.lower()): + slang_count += 1 avg_length += len(content) - if messages: avg_length /= len(messages) + if messages: + avg_length /= len(messages) # Ensure user_preferences exists - if not hasattr(cog, 'user_preferences'): cog.user_preferences = defaultdict(dict) + if not hasattr(cog, "user_preferences"): + cog.user_preferences = defaultdict(dict) user_prefs = cog.user_preferences[user_id] # Apply learning rate cautiously - if emoji_count > 0: user_prefs["emoji_preference"] = user_prefs.get("emoji_preference", 0.5) * (1 - LEARNING_RATE) + (emoji_count / len(messages)) * LEARNING_RATE - if slang_count > 0: user_prefs["slang_preference"] = user_prefs.get("slang_preference", 0.5) * (1 - LEARNING_RATE) + (slang_count / len(messages)) * LEARNING_RATE - user_prefs["length_preference"] = user_prefs.get("length_preference", 50) * (1 - LEARNING_RATE) + avg_length * LEARNING_RATE + if emoji_count > 0: + user_prefs["emoji_preference"] = ( + user_prefs.get("emoji_preference", 0.5) * (1 - LEARNING_RATE) + + (emoji_count / len(messages)) * LEARNING_RATE + ) + if slang_count > 0: + user_prefs["slang_preference"] = ( + user_prefs.get("slang_preference", 0.5) * (1 - LEARNING_RATE) + + (slang_count / len(messages)) * LEARNING_RATE + ) + user_prefs["length_preference"] = ( + user_prefs.get("length_preference", 50) * (1 - LEARNING_RATE) + + avg_length * LEARNING_RATE + ) + # --- Removed evolve_personality function --- # --- Removed reflect_on_memories function --- # --- Removed decompose_goal_into_steps function --- -def analyze_message_sentiment(cog: 'WheatleyCog', message_content: str) -> Dict[str, Any]: # Updated type hint + +def analyze_message_sentiment( + cog: "WheatleyCog", message_content: str +) -> Dict[str, Any]: # Updated type hint """Analyzes the sentiment of a message using keywords and emojis.""" content = message_content.lower() - result = {"sentiment": "neutral", "intensity": 0.5, "emotions": [], "confidence": 0.5} + result = { + "sentiment": "neutral", + "intensity": 0.5, + "emotions": [], + "confidence": 0.5, + } - positive_emoji_count = sum(1 for emoji in EMOJI_SENTIMENT["positive"] if emoji in content) - negative_emoji_count = sum(1 for emoji in EMOJI_SENTIMENT["negative"] if emoji in content) - total_emoji_count = positive_emoji_count + negative_emoji_count + sum(1 for emoji in EMOJI_SENTIMENT["neutral"] if emoji in content) + positive_emoji_count = sum( + 1 for emoji in EMOJI_SENTIMENT["positive"] if emoji in content + ) + negative_emoji_count = sum( + 1 for emoji in EMOJI_SENTIMENT["negative"] if emoji in content + ) + total_emoji_count = ( + positive_emoji_count + + negative_emoji_count + + sum(1 for emoji in EMOJI_SENTIMENT["neutral"] if emoji in content) + ) - detected_emotions = []; emotion_scores = {} + detected_emotions = [] + emotion_scores = {} for emotion, keywords in EMOTION_KEYWORDS.items(): - emotion_count = sum(1 for keyword in keywords if re.search(r'\b' + re.escape(keyword) + r'\b', content)) + emotion_count = sum( + 1 + for keyword in keywords + if re.search(r"\b" + re.escape(keyword) + r"\b", content) + ) if emotion_count > 0: emotion_score = min(1.0, emotion_count / len(keywords) * 2) emotion_scores[emotion] = emotion_score @@ -369,76 +730,205 @@ def analyze_message_sentiment(cog: 'WheatleyCog', message_content: str) -> Dict[ primary_emotion = max(emotion_scores.items(), key=lambda x: x[1]) result["emotions"] = [primary_emotion[0]] for emotion, score in emotion_scores.items(): - if emotion != primary_emotion[0] and score > primary_emotion[1] * 0.7: result["emotions"].append(emotion) + if emotion != primary_emotion[0] and score > primary_emotion[1] * 0.7: + result["emotions"].append(emotion) - positive_emotions = ["joy"]; negative_emotions = ["sadness", "anger", "fear", "disgust"] - if primary_emotion[0] in positive_emotions: result["sentiment"] = "positive"; result["intensity"] = primary_emotion[1] - elif primary_emotion[0] in negative_emotions: result["sentiment"] = "negative"; result["intensity"] = primary_emotion[1] - else: result["sentiment"] = "neutral"; result["intensity"] = 0.5 + positive_emotions = ["joy"] + negative_emotions = ["sadness", "anger", "fear", "disgust"] + if primary_emotion[0] in positive_emotions: + result["sentiment"] = "positive" + result["intensity"] = primary_emotion[1] + elif primary_emotion[0] in negative_emotions: + result["sentiment"] = "negative" + result["intensity"] = primary_emotion[1] + else: + result["sentiment"] = "neutral" + result["intensity"] = 0.5 result["confidence"] = min(0.9, 0.5 + primary_emotion[1] * 0.4) elif total_emoji_count > 0: - if positive_emoji_count > negative_emoji_count: result["sentiment"] = "positive"; result["intensity"] = min(0.9, 0.5 + (positive_emoji_count / total_emoji_count) * 0.4); result["confidence"] = min(0.8, 0.4 + (positive_emoji_count / total_emoji_count) * 0.4) - elif negative_emoji_count > positive_emoji_count: result["sentiment"] = "negative"; result["intensity"] = min(0.9, 0.5 + (negative_emoji_count / total_emoji_count) * 0.4); result["confidence"] = min(0.8, 0.4 + (negative_emoji_count / total_emoji_count) * 0.4) - else: result["sentiment"] = "neutral"; result["intensity"] = 0.5; result["confidence"] = 0.6 + if positive_emoji_count > negative_emoji_count: + result["sentiment"] = "positive" + result["intensity"] = min( + 0.9, 0.5 + (positive_emoji_count / total_emoji_count) * 0.4 + ) + result["confidence"] = min( + 0.8, 0.4 + (positive_emoji_count / total_emoji_count) * 0.4 + ) + elif negative_emoji_count > positive_emoji_count: + result["sentiment"] = "negative" + result["intensity"] = min( + 0.9, 0.5 + (negative_emoji_count / total_emoji_count) * 0.4 + ) + result["confidence"] = min( + 0.8, 0.4 + (negative_emoji_count / total_emoji_count) * 0.4 + ) + else: + result["sentiment"] = "neutral" + result["intensity"] = 0.5 + result["confidence"] = 0.6 - else: # Basic text fallback - positive_words = {"good", "great", "awesome", "amazing", "excellent", "love", "like", "best", "better", "nice", "cool", "happy", "glad", "thanks", "thank", "appreciate", "wonderful", "fantastic", "perfect", "beautiful", "fun", "enjoy", "yes", "yep"} - negative_words = {"bad", "terrible", "awful", "worst", "hate", "dislike", "sucks", "stupid", "boring", "annoying", "sad", "upset", "angry", "mad", "disappointed", "sorry", "unfortunate", "horrible", "ugly", "wrong", "fail", "no", "nope"} - words = re.findall(r'\b\w+\b', content) + else: # Basic text fallback + positive_words = { + "good", + "great", + "awesome", + "amazing", + "excellent", + "love", + "like", + "best", + "better", + "nice", + "cool", + "happy", + "glad", + "thanks", + "thank", + "appreciate", + "wonderful", + "fantastic", + "perfect", + "beautiful", + "fun", + "enjoy", + "yes", + "yep", + } + negative_words = { + "bad", + "terrible", + "awful", + "worst", + "hate", + "dislike", + "sucks", + "stupid", + "boring", + "annoying", + "sad", + "upset", + "angry", + "mad", + "disappointed", + "sorry", + "unfortunate", + "horrible", + "ugly", + "wrong", + "fail", + "no", + "nope", + } + words = re.findall(r"\b\w+\b", content) positive_count = sum(1 for word in words if word in positive_words) negative_count = sum(1 for word in words if word in negative_words) - if positive_count > negative_count: result["sentiment"] = "positive"; result["intensity"] = min(0.8, 0.5 + (positive_count / len(words)) * 2 if words else 0); result["confidence"] = min(0.7, 0.3 + (positive_count / len(words)) * 0.4 if words else 0) - elif negative_count > positive_count: result["sentiment"] = "negative"; result["intensity"] = min(0.8, 0.5 + (negative_count / len(words)) * 2 if words else 0); result["confidence"] = min(0.7, 0.3 + (negative_count / len(words)) * 0.4 if words else 0) - else: result["sentiment"] = "neutral"; result["intensity"] = 0.5; result["confidence"] = 0.5 + if positive_count > negative_count: + result["sentiment"] = "positive" + result["intensity"] = min( + 0.8, 0.5 + (positive_count / len(words)) * 2 if words else 0 + ) + result["confidence"] = min( + 0.7, 0.3 + (positive_count / len(words)) * 0.4 if words else 0 + ) + elif negative_count > positive_count: + result["sentiment"] = "negative" + result["intensity"] = min( + 0.8, 0.5 + (negative_count / len(words)) * 2 if words else 0 + ) + result["confidence"] = min( + 0.7, 0.3 + (negative_count / len(words)) * 0.4 if words else 0 + ) + else: + result["sentiment"] = "neutral" + result["intensity"] = 0.5 + result["confidence"] = 0.5 return result -def update_conversation_sentiment(cog: 'WheatleyCog', channel_id: int, user_id: str, message_sentiment: Dict[str, Any]): # Updated type hint + +def update_conversation_sentiment( + cog: "WheatleyCog", channel_id: int, user_id: str, message_sentiment: Dict[str, Any] +): # Updated type hint """Updates the conversation sentiment tracking based on a new message's sentiment.""" channel_sentiment = cog.conversation_sentiment[channel_id] now = time.time() # Ensure sentiment_update_interval exists on cog, default if not - sentiment_update_interval = getattr(cog, 'sentiment_update_interval', 300) # Default to 300s if not set + sentiment_update_interval = getattr( + cog, "sentiment_update_interval", 300 + ) # Default to 300s if not set if now - channel_sentiment["last_update"] > sentiment_update_interval: - if channel_sentiment["overall"] == "positive": channel_sentiment["intensity"] = max(0.5, channel_sentiment["intensity"] - SENTIMENT_DECAY_RATE) - elif channel_sentiment["overall"] == "negative": channel_sentiment["intensity"] = max(0.5, channel_sentiment["intensity"] - SENTIMENT_DECAY_RATE) + if channel_sentiment["overall"] == "positive": + channel_sentiment["intensity"] = max( + 0.5, channel_sentiment["intensity"] - SENTIMENT_DECAY_RATE + ) + elif channel_sentiment["overall"] == "negative": + channel_sentiment["intensity"] = max( + 0.5, channel_sentiment["intensity"] - SENTIMENT_DECAY_RATE + ) channel_sentiment["recent_trend"] = "stable" channel_sentiment["last_update"] = now - user_sentiment = channel_sentiment["user_sentiments"].get(user_id, {"sentiment": "neutral", "intensity": 0.5}) + user_sentiment = channel_sentiment["user_sentiments"].get( + user_id, {"sentiment": "neutral", "intensity": 0.5} + ) confidence_weight = message_sentiment["confidence"] if user_sentiment["sentiment"] == message_sentiment["sentiment"]: - new_intensity = user_sentiment["intensity"] * 0.7 + message_sentiment["intensity"] * 0.3 + new_intensity = ( + user_sentiment["intensity"] * 0.7 + message_sentiment["intensity"] * 0.3 + ) user_sentiment["intensity"] = min(0.95, new_intensity) else: if message_sentiment["confidence"] > 0.7: user_sentiment["sentiment"] = message_sentiment["sentiment"] - user_sentiment["intensity"] = message_sentiment["intensity"] * 0.7 + user_sentiment["intensity"] * 0.3 + user_sentiment["intensity"] = ( + message_sentiment["intensity"] * 0.7 + user_sentiment["intensity"] * 0.3 + ) else: if message_sentiment["intensity"] > user_sentiment["intensity"]: user_sentiment["sentiment"] = message_sentiment["sentiment"] - user_sentiment["intensity"] = user_sentiment["intensity"] * 0.6 + message_sentiment["intensity"] * 0.4 + user_sentiment["intensity"] = ( + user_sentiment["intensity"] * 0.6 + + message_sentiment["intensity"] * 0.4 + ) user_sentiment["emotions"] = message_sentiment.get("emotions", []) channel_sentiment["user_sentiments"][user_id] = user_sentiment # Update overall based on active users (simplified access to active_conversations) - active_user_sentiments = [s for uid, s in channel_sentiment["user_sentiments"].items() if uid in cog.active_conversations.get(channel_id, {}).get('participants', set())] + active_user_sentiments = [ + s + for uid, s in channel_sentiment["user_sentiments"].items() + if uid + in cog.active_conversations.get(channel_id, {}).get("participants", set()) + ] if active_user_sentiments: sentiment_counts = defaultdict(int) - for s in active_user_sentiments: sentiment_counts[s["sentiment"]] += 1 + for s in active_user_sentiments: + sentiment_counts[s["sentiment"]] += 1 dominant_sentiment = max(sentiment_counts.items(), key=lambda x: x[1])[0] - avg_intensity = sum(s["intensity"] for s in active_user_sentiments if s["sentiment"] == dominant_sentiment) / sentiment_counts[dominant_sentiment] + avg_intensity = ( + sum( + s["intensity"] + for s in active_user_sentiments + if s["sentiment"] == dominant_sentiment + ) + / sentiment_counts[dominant_sentiment] + ) - prev_sentiment = channel_sentiment["overall"]; prev_intensity = channel_sentiment["intensity"] + prev_sentiment = channel_sentiment["overall"] + prev_intensity = channel_sentiment["intensity"] if dominant_sentiment == prev_sentiment: - if avg_intensity > prev_intensity + 0.1: channel_sentiment["recent_trend"] = "intensifying" - elif avg_intensity < prev_intensity - 0.1: channel_sentiment["recent_trend"] = "diminishing" - else: channel_sentiment["recent_trend"] = "stable" - else: channel_sentiment["recent_trend"] = "changing" + if avg_intensity > prev_intensity + 0.1: + channel_sentiment["recent_trend"] = "intensifying" + elif avg_intensity < prev_intensity - 0.1: + channel_sentiment["recent_trend"] = "diminishing" + else: + channel_sentiment["recent_trend"] = "stable" + else: + channel_sentiment["recent_trend"] = "changing" channel_sentiment["overall"] = dominant_sentiment channel_sentiment["intensity"] = avg_intensity diff --git a/wheatley/api.py b/wheatley/api.py index 988ba95..0890192 100644 --- a/wheatley/api.py +++ b/wheatley/api.py @@ -7,7 +7,7 @@ import re import time import datetime from typing import TYPE_CHECKING, Optional, List, Dict, Any, Union, AsyncIterable -import jsonschema # For manual JSON validation +import jsonschema # For manual JSON validation from .tools import get_conversation_summary # Vertex AI Imports @@ -15,27 +15,51 @@ try: import vertexai from vertexai import generative_models from vertexai.generative_models import ( - GenerativeModel, GenerationConfig, Part, Content, Tool, FunctionDeclaration, - GenerationResponse, FinishReason + GenerativeModel, + GenerationConfig, + Part, + Content, + Tool, + FunctionDeclaration, + GenerationResponse, + FinishReason, ) from google.api_core import exceptions as google_exceptions - from google.cloud.storage import Client as GCSClient # For potential image uploads + from google.cloud.storage import Client as GCSClient # For potential image uploads except ImportError: - print("WARNING: google-cloud-vertexai or google-cloud-storage not installed. API calls will fail.") + print( + "WARNING: google-cloud-vertexai or google-cloud-storage not installed. API calls will fail." + ) + # Define dummy classes/exceptions if library isn't installed class DummyGenerativeModel: - def __init__(self, model_name, system_instruction=None, tools=None): pass - async def generate_content_async(self, contents, generation_config=None, safety_settings=None, stream=False): return None + def __init__(self, model_name, system_instruction=None, tools=None): + pass + + async def generate_content_async( + self, contents, generation_config=None, safety_settings=None, stream=False + ): + return None + GenerativeModel = DummyGenerativeModel + class DummyPart: @staticmethod - def from_text(text): return None + def from_text(text): + return None + @staticmethod - def from_data(data, mime_type): return None + def from_data(data, mime_type): + return None + @staticmethod - def from_uri(uri, mime_type): return None + def from_uri(uri, mime_type): + return None + @staticmethod - def from_function_response(name, response): return None + def from_function_response(name, response): + return None + Part = DummyPart Content = dict Tool = list @@ -43,30 +67,46 @@ except ImportError: GenerationConfig = dict GenerationResponse = object FinishReason = object + class DummyGoogleExceptions: - ResourceExhausted = type('ResourceExhausted', (Exception,), {}) - InternalServerError = type('InternalServerError', (Exception,), {}) - ServiceUnavailable = type('ServiceUnavailable', (Exception,), {}) - InvalidArgument = type('InvalidArgument', (Exception,), {}) - GoogleAPICallError = type('GoogleAPICallError', (Exception,), {}) # Generic fallback + ResourceExhausted = type("ResourceExhausted", (Exception,), {}) + InternalServerError = type("InternalServerError", (Exception,), {}) + ServiceUnavailable = type("ServiceUnavailable", (Exception,), {}) + InvalidArgument = type("InvalidArgument", (Exception,), {}) + GoogleAPICallError = type( + "GoogleAPICallError", (Exception,), {} + ) # Generic fallback + google_exceptions = DummyGoogleExceptions() # Relative imports for components within the 'gurt' package from .config import ( - PROJECT_ID, LOCATION, DEFAULT_MODEL, FALLBACK_MODEL, - API_TIMEOUT, API_RETRY_ATTEMPTS, API_RETRY_DELAY, TOOLS, RESPONSE_SCHEMA, - PROACTIVE_PLAN_SCHEMA, # Import the new schema - TAVILY_API_KEY, PISTON_API_URL, PISTON_API_KEY # Import other needed configs + PROJECT_ID, + LOCATION, + DEFAULT_MODEL, + FALLBACK_MODEL, + API_TIMEOUT, + API_RETRY_ATTEMPTS, + API_RETRY_DELAY, + TOOLS, + RESPONSE_SCHEMA, + PROACTIVE_PLAN_SCHEMA, # Import the new schema + TAVILY_API_KEY, + PISTON_API_URL, + PISTON_API_KEY, # Import other needed configs ) from .prompt import build_dynamic_system_prompt -from .context import gather_conversation_context, get_memory_context # Renamed functions -from .tools import TOOL_MAPPING # Import tool mapping -from .utils import format_message, log_internal_api_call # Import utilities -import copy # Needed for deep copying schemas +from .context import ( + gather_conversation_context, + get_memory_context, +) # Renamed functions +from .tools import TOOL_MAPPING # Import tool mapping +from .utils import format_message, log_internal_api_call # Import utilities +import copy # Needed for deep copying schemas if TYPE_CHECKING: - from .cog import WheatleyCog # Import WheatleyCog for type hinting only + from .cog import WheatleyCog # Import WheatleyCog for type hinting only # --- Schema Preprocessing Helper --- @@ -83,38 +123,49 @@ def _preprocess_schema_for_vertex(schema: Dict[str, Any]) -> Dict[str, Any]: A new, preprocessed schema dictionary. """ if not isinstance(schema, dict): - return schema # Return non-dict elements as is + return schema # Return non-dict elements as is - processed_schema = copy.deepcopy(schema) # Work on a copy + processed_schema = copy.deepcopy(schema) # Work on a copy for key, value in processed_schema.items(): if key == "type" and isinstance(value, list): # Find the first non-"null" type in the list - first_valid_type = next((t for t in value if isinstance(t, str) and t.lower() != "null"), None) + first_valid_type = next( + (t for t in value if isinstance(t, str) and t.lower() != "null"), None + ) if first_valid_type: processed_schema[key] = first_valid_type else: # Fallback if only "null" or invalid types are present (shouldn't happen in valid schemas) - processed_schema[key] = "object" # Or handle as error - print(f"Warning: Schema preprocessing found list type '{value}' with no valid non-null string type. Falling back to 'object'.") + processed_schema[key] = "object" # Or handle as error + print( + f"Warning: Schema preprocessing found list type '{value}' with no valid non-null string type. Falling back to 'object'." + ) elif isinstance(value, dict): - processed_schema[key] = _preprocess_schema_for_vertex(value) # Recurse for nested objects + processed_schema[key] = _preprocess_schema_for_vertex( + value + ) # Recurse for nested objects elif isinstance(value, list): # Recurse for items within arrays (e.g., in 'properties' of array items) - processed_schema[key] = [_preprocess_schema_for_vertex(item) if isinstance(item, dict) else item for item in value] + processed_schema[key] = [ + _preprocess_schema_for_vertex(item) if isinstance(item, dict) else item + for item in value + ] # Handle 'properties' specifically elif key == "properties" and isinstance(value, dict): - processed_schema[key] = {prop_key: _preprocess_schema_for_vertex(prop_value) for prop_key, prop_value in value.items()} + processed_schema[key] = { + prop_key: _preprocess_schema_for_vertex(prop_value) + for prop_key, prop_value in value.items() + } # Handle 'items' specifically if it's a schema object elif key == "items" and isinstance(value, dict): - processed_schema[key] = _preprocess_schema_for_vertex(value) - + processed_schema[key] = _preprocess_schema_for_vertex(value) return processed_schema # --- Helper Function to Safely Extract Text --- -def _get_response_text(response: Optional['GenerationResponse']) -> Optional[str]: +def _get_response_text(response: Optional["GenerationResponse"]) -> Optional[str]: """Safely extracts the text content from the first text part of a GenerationResponse.""" if not response or not response.candidates: return None @@ -122,10 +173,12 @@ def _get_response_text(response: Optional['GenerationResponse']) -> Optional[str # Iterate through parts to find the first text part for part in response.candidates[0].content.parts: # Check if the part has a 'text' attribute and it's not empty - if hasattr(part, 'text') and part.text: + if hasattr(part, "text") and part.text: return part.text # If no text part is found (e.g., only function call or empty text parts) - print(f"[_get_response_text] No text part found in candidate parts: {response.candidates[0].content.parts}") # Log parts structure + print( + f"[_get_response_text] No text part found in candidate parts: {response.candidates[0].content.parts}" + ) # Log parts structure return None except (AttributeError, IndexError) as e: # Handle cases where structure is unexpected or parts list is empty @@ -149,25 +202,40 @@ except Exception as e: # --- Constants --- # Define standard safety settings (adjust as needed) # Use actual types if import succeeded, otherwise fallback to Any -_HarmCategory = getattr(generative_models, 'HarmCategory', Any) -_HarmBlockThreshold = getattr(generative_models, 'HarmBlockThreshold', Any) +_HarmCategory = getattr(generative_models, "HarmCategory", Any) +_HarmBlockThreshold = getattr(generative_models, "HarmBlockThreshold", Any) STANDARD_SAFETY_SETTINGS = { - getattr(_HarmCategory, 'HARM_CATEGORY_HATE_SPEECH', 'HARM_CATEGORY_HATE_SPEECH'): getattr(_HarmBlockThreshold, 'BLOCK_MEDIUM_AND_ABOVE', 'BLOCK_MEDIUM_AND_ABOVE'), - getattr(_HarmCategory, 'HARM_CATEGORY_DANGEROUS_CONTENT', 'HARM_CATEGORY_DANGEROUS_CONTENT'): getattr(_HarmBlockThreshold, 'BLOCK_MEDIUM_AND_ABOVE', 'BLOCK_MEDIUM_AND_ABOVE'), - getattr(_HarmCategory, 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'HARM_CATEGORY_SEXUALLY_EXPLICIT'): getattr(_HarmBlockThreshold, 'BLOCK_MEDIUM_AND_ABOVE', 'BLOCK_MEDIUM_AND_ABOVE'), - getattr(_HarmCategory, 'HARM_CATEGORY_HARASSMENT', 'HARM_CATEGORY_HARASSMENT'): getattr(_HarmBlockThreshold, 'BLOCK_MEDIUM_AND_ABOVE', 'BLOCK_MEDIUM_AND_ABOVE'), + getattr( + _HarmCategory, "HARM_CATEGORY_HATE_SPEECH", "HARM_CATEGORY_HATE_SPEECH" + ): getattr(_HarmBlockThreshold, "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE"), + getattr( + _HarmCategory, + "HARM_CATEGORY_DANGEROUS_CONTENT", + "HARM_CATEGORY_DANGEROUS_CONTENT", + ): getattr(_HarmBlockThreshold, "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE"), + getattr( + _HarmCategory, + "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "HARM_CATEGORY_SEXUALLY_EXPLICIT", + ): getattr(_HarmBlockThreshold, "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE"), + getattr( + _HarmCategory, "HARM_CATEGORY_HARASSMENT", "HARM_CATEGORY_HARASSMENT" + ): getattr(_HarmBlockThreshold, "BLOCK_MEDIUM_AND_ABOVE", "BLOCK_MEDIUM_AND_ABOVE"), } + # --- API Call Helper --- async def call_vertex_api_with_retry( - cog: 'WheatleyCog', - model: 'GenerativeModel', # Use string literal for type hint - contents: List['Content'], # Use string literal for type hint - generation_config: 'GenerationConfig', # Use string literal for type hint - safety_settings: Optional[Dict[Any, Any]], # Use Any for broader compatibility + cog: "WheatleyCog", + model: "GenerativeModel", # Use string literal for type hint + contents: List["Content"], # Use string literal for type hint + generation_config: "GenerationConfig", # Use string literal for type hint + safety_settings: Optional[Dict[Any, Any]], # Use Any for broader compatibility request_desc: str, - stream: bool = False -) -> Union['GenerationResponse', AsyncIterable['GenerationResponse'], None]: # Use string literals + stream: bool = False, +) -> Union[ + "GenerationResponse", AsyncIterable["GenerationResponse"], None +]: # Use string literals """ Calls the Vertex AI Gemini API with retry logic. @@ -187,30 +255,40 @@ async def call_vertex_api_with_retry( Exception: If the API call fails after all retry attempts or encounters a non-retryable error. """ last_exception = None - model_name = model._model_name # Get model name for logging + model_name = model._model_name # Get model name for logging start_time = time.monotonic() for attempt in range(API_RETRY_ATTEMPTS + 1): try: - print(f"Sending API request for {request_desc} using {model_name} (Attempt {attempt + 1}/{API_RETRY_ATTEMPTS + 1})...") + print( + f"Sending API request for {request_desc} using {model_name} (Attempt {attempt + 1}/{API_RETRY_ATTEMPTS + 1})..." + ) response = await model.generate_content_async( contents=contents, generation_config=generation_config, safety_settings=safety_settings or STANDARD_SAFETY_SETTINGS, - stream=stream + stream=stream, ) # --- Success Logging --- elapsed_time = time.monotonic() - start_time # Ensure model_name exists in stats before incrementing if model_name not in cog.api_stats: - cog.api_stats[model_name] = {'success': 0, 'failure': 0, 'retries': 0, 'total_time': 0.0, 'count': 0} - cog.api_stats[model_name]['success'] += 1 - cog.api_stats[model_name]['total_time'] += elapsed_time - cog.api_stats[model_name]['count'] += 1 - print(f"API request successful for {request_desc} ({model_name}) in {elapsed_time:.2f}s.") - return response # Success + cog.api_stats[model_name] = { + "success": 0, + "failure": 0, + "retries": 0, + "total_time": 0.0, + "count": 0, + } + cog.api_stats[model_name]["success"] += 1 + cog.api_stats[model_name]["total_time"] += elapsed_time + cog.api_stats[model_name]["count"] += 1 + print( + f"API request successful for {request_desc} ({model_name}) in {elapsed_time:.2f}s." + ) + return response # Success except google_exceptions.ResourceExhausted as e: error_msg = f"Rate limit error (ResourceExhausted) for {request_desc}: {e}" @@ -218,55 +296,79 @@ async def call_vertex_api_with_retry( last_exception = e if attempt < API_RETRY_ATTEMPTS: if model_name not in cog.api_stats: - cog.api_stats[model_name] = {'success': 0, 'failure': 0, 'retries': 0, 'total_time': 0.0, 'count': 0} - cog.api_stats[model_name]['retries'] += 1 - wait_time = API_RETRY_DELAY * (2 ** attempt) # Exponential backoff + cog.api_stats[model_name] = { + "success": 0, + "failure": 0, + "retries": 0, + "total_time": 0.0, + "count": 0, + } + cog.api_stats[model_name]["retries"] += 1 + wait_time = API_RETRY_DELAY * (2**attempt) # Exponential backoff print(f"Waiting {wait_time:.2f} seconds before retrying...") await asyncio.sleep(wait_time) continue else: - break # Max retries reached + break # Max retries reached - except (google_exceptions.InternalServerError, google_exceptions.ServiceUnavailable) as e: + except ( + google_exceptions.InternalServerError, + google_exceptions.ServiceUnavailable, + ) as e: error_msg = f"API server error ({type(e).__name__}) for {request_desc}: {e}" print(f"{error_msg} (Attempt {attempt + 1})") last_exception = e if attempt < API_RETRY_ATTEMPTS: if model_name not in cog.api_stats: - cog.api_stats[model_name] = {'success': 0, 'failure': 0, 'retries': 0, 'total_time': 0.0, 'count': 0} - cog.api_stats[model_name]['retries'] += 1 - wait_time = API_RETRY_DELAY * (2 ** attempt) # Exponential backoff + cog.api_stats[model_name] = { + "success": 0, + "failure": 0, + "retries": 0, + "total_time": 0.0, + "count": 0, + } + cog.api_stats[model_name]["retries"] += 1 + wait_time = API_RETRY_DELAY * (2**attempt) # Exponential backoff print(f"Waiting {wait_time:.2f} seconds before retrying...") await asyncio.sleep(wait_time) continue else: - break # Max retries reached + break # Max retries reached except google_exceptions.InvalidArgument as e: # Often indicates a problem with the request itself (e.g., bad schema, unsupported format) error_msg = f"Invalid argument error for {request_desc}: {e}" print(error_msg) last_exception = e - break # Non-retryable + break # Non-retryable - except asyncio.TimeoutError: # Handle potential client-side timeouts if applicable - error_msg = f"Client-side request timed out for {request_desc} (Attempt {attempt + 1})" - print(error_msg) - last_exception = asyncio.TimeoutError(error_msg) - # Decide if client-side timeouts should be retried - if attempt < API_RETRY_ATTEMPTS: - if model_name not in cog.api_stats: - cog.api_stats[model_name] = {'success': 0, 'failure': 0, 'retries': 0, 'total_time': 0.0, 'count': 0} - cog.api_stats[model_name]['retries'] += 1 - await asyncio.sleep(API_RETRY_DELAY * (attempt + 1)) - continue - else: - break + except ( + asyncio.TimeoutError + ): # Handle potential client-side timeouts if applicable + error_msg = f"Client-side request timed out for {request_desc} (Attempt {attempt + 1})" + print(error_msg) + last_exception = asyncio.TimeoutError(error_msg) + # Decide if client-side timeouts should be retried + if attempt < API_RETRY_ATTEMPTS: + if model_name not in cog.api_stats: + cog.api_stats[model_name] = { + "success": 0, + "failure": 0, + "retries": 0, + "total_time": 0.0, + "count": 0, + } + cog.api_stats[model_name]["retries"] += 1 + await asyncio.sleep(API_RETRY_DELAY * (attempt + 1)) + continue + else: + break - except Exception as e: # Catch other potential exceptions + except Exception as e: # Catch other potential exceptions error_msg = f"Unexpected error during API call for {request_desc} (Attempt {attempt + 1}): {type(e).__name__}: {e}" print(error_msg) import traceback + traceback.print_exc() last_exception = e # Decide if this generic exception is retryable @@ -276,21 +378,29 @@ async def call_vertex_api_with_retry( # --- Failure Logging --- elapsed_time = time.monotonic() - start_time if model_name not in cog.api_stats: - cog.api_stats[model_name] = {'success': 0, 'failure': 0, 'retries': 0, 'total_time': 0.0, 'count': 0} - cog.api_stats[model_name]['failure'] += 1 - cog.api_stats[model_name]['total_time'] += elapsed_time - cog.api_stats[model_name]['count'] += 1 - print(f"API request failed for {request_desc} ({model_name}) after {attempt + 1} attempts in {elapsed_time:.2f}s.") + cog.api_stats[model_name] = { + "success": 0, + "failure": 0, + "retries": 0, + "total_time": 0.0, + "count": 0, + } + cog.api_stats[model_name]["failure"] += 1 + cog.api_stats[model_name]["total_time"] += elapsed_time + cog.api_stats[model_name]["count"] += 1 + print( + f"API request failed for {request_desc} ({model_name}) after {attempt + 1} attempts in {elapsed_time:.2f}s." + ) # Raise the last encountered exception or a generic one - raise last_exception or Exception(f"API request failed for {request_desc} after {API_RETRY_ATTEMPTS + 1} attempts.") + raise last_exception or Exception( + f"API request failed for {request_desc} after {API_RETRY_ATTEMPTS + 1} attempts." + ) # --- JSON Parsing and Validation Helper --- def parse_and_validate_json_response( - response_text: Optional[str], - schema: Dict[str, Any], - context_description: str + response_text: Optional[str], schema: Dict[str, Any], context_description: str ) -> Optional[Dict[str, Any]]: """ Parses the AI's response text, attempting to extract and validate a JSON object against a schema. @@ -308,53 +418,75 @@ def parse_and_validate_json_response( return None parsed_data = None - raw_json_text = response_text # Start with the full text + raw_json_text = response_text # Start with the full text # Attempt 1: Try parsing the whole string directly try: parsed_data = json.loads(raw_json_text) - print(f"Parsing ({context_description}): Successfully parsed entire response as JSON.") + print( + f"Parsing ({context_description}): Successfully parsed entire response as JSON." + ) except json.JSONDecodeError: # Attempt 2: Extract JSON object, handling optional markdown fences # More robust regex to handle potential leading/trailing text and variations - json_match = re.search(r'```(?:json)?\s*(\{.*\})\s*```|(\{.*\})', response_text, re.DOTALL | re.MULTILINE) + json_match = re.search( + r"```(?:json)?\s*(\{.*\})\s*```|(\{.*\})", + response_text, + re.DOTALL | re.MULTILINE, + ) if json_match: json_str = json_match.group(1) or json_match.group(2) if json_str: - raw_json_text = json_str # Use the extracted string for parsing + raw_json_text = json_str # Use the extracted string for parsing try: parsed_data = json.loads(raw_json_text) - print(f"Parsing ({context_description}): Successfully extracted and parsed JSON using regex.") + print( + f"Parsing ({context_description}): Successfully extracted and parsed JSON using regex." + ) except json.JSONDecodeError as e_inner: - print(f"Parsing ({context_description}): Regex found potential JSON, but it failed to parse: {e_inner}\nContent: {raw_json_text[:500]}") + print( + f"Parsing ({context_description}): Regex found potential JSON, but it failed to parse: {e_inner}\nContent: {raw_json_text[:500]}" + ) parsed_data = None else: - print(f"Parsing ({context_description}): Regex matched, but failed to capture JSON content.") + print( + f"Parsing ({context_description}): Regex matched, but failed to capture JSON content." + ) parsed_data = None else: - print(f"Parsing ({context_description}): Could not parse directly or extract JSON object using regex.\nContent: {raw_json_text[:500]}") + print( + f"Parsing ({context_description}): Could not parse directly or extract JSON object using regex.\nContent: {raw_json_text[:500]}" + ) parsed_data = None # Validation step if parsed_data is not None: if not isinstance(parsed_data, dict): - print(f"Parsing ({context_description}): Parsed data is not a dictionary: {type(parsed_data)}") - return None # Fail validation if not a dict + print( + f"Parsing ({context_description}): Parsed data is not a dictionary: {type(parsed_data)}" + ) + return None # Fail validation if not a dict try: jsonschema.validate(instance=parsed_data, schema=schema) - print(f"Parsing ({context_description}): JSON successfully validated against schema.") + print( + f"Parsing ({context_description}): JSON successfully validated against schema." + ) # Ensure default keys exist after validation parsed_data.setdefault("should_respond", False) parsed_data.setdefault("content", None) parsed_data.setdefault("react_with_emoji", None) return parsed_data except jsonschema.ValidationError as e: - print(f"Parsing ({context_description}): JSON failed schema validation: {e.message}") + print( + f"Parsing ({context_description}): JSON failed schema validation: {e.message}" + ) # Optionally log more details: e.path, e.schema_path, e.instance - return None # Validation failed - except Exception as e: # Catch other potential validation errors - print(f"Parsing ({context_description}): Unexpected error during JSON schema validation: {e}") + return None # Validation failed + except Exception as e: # Catch other potential validation errors + print( + f"Parsing ({context_description}): Unexpected error during JSON schema validation: {e}" + ) return None else: # Parsing failed before validation could occur @@ -362,7 +494,9 @@ def parse_and_validate_json_response( # --- Tool Processing --- -async def process_requested_tools(cog: 'WheatleyCog', function_call: 'generative_models.FunctionCall') -> 'Part': # Use string literals +async def process_requested_tools( + cog: "WheatleyCog", function_call: "generative_models.FunctionCall" +) -> "Part": # Use string literals """ Process a tool request specified by the AI's FunctionCall response. @@ -392,11 +526,18 @@ async def process_requested_tools(cog: 'WheatleyCog', function_call: 'generative # --- Tool Success Logging --- tool_elapsed_time = time.monotonic() - tool_start_time if function_name not in cog.tool_stats: - cog.tool_stats[function_name] = {'success': 0, 'failure': 0, 'total_time': 0.0, 'count': 0} - cog.tool_stats[function_name]['success'] += 1 - cog.tool_stats[function_name]['total_time'] += tool_elapsed_time - cog.tool_stats[function_name]['count'] += 1 - print(f"Tool '{function_name}' executed successfully in {tool_elapsed_time:.2f}s.") + cog.tool_stats[function_name] = { + "success": 0, + "failure": 0, + "total_time": 0.0, + "count": 0, + } + cog.tool_stats[function_name]["success"] += 1 + cog.tool_stats[function_name]["total_time"] += tool_elapsed_time + cog.tool_stats[function_name]["count"] += 1 + print( + f"Tool '{function_name}' executed successfully in {tool_elapsed_time:.2f}s." + ) # Prepare result for API - must be JSON serializable, typically a dict if not isinstance(result, dict): @@ -404,7 +545,9 @@ async def process_requested_tools(cog: 'WheatleyCog', function_call: 'generative if isinstance(result, (str, int, float, bool, list)) or result is None: result = {"result": result} else: - print(f"Warning: Tool '{function_name}' returned non-standard type {type(result)}. Attempting str conversion.") + print( + f"Warning: Tool '{function_name}' returned non-standard type {type(result)}. Attempting str conversion." + ) result = {"result": str(result)} tool_result_content = result @@ -413,13 +556,21 @@ async def process_requested_tools(cog: 'WheatleyCog', function_call: 'generative # --- Tool Failure Logging --- tool_elapsed_time = time.monotonic() - tool_start_time if function_name not in cog.tool_stats: - cog.tool_stats[function_name] = {'success': 0, 'failure': 0, 'total_time': 0.0, 'count': 0} - cog.tool_stats[function_name]['failure'] += 1 - cog.tool_stats[function_name]['total_time'] += tool_elapsed_time - cog.tool_stats[function_name]['count'] += 1 - error_message = f"Error executing tool {function_name}: {type(e).__name__}: {str(e)}" + cog.tool_stats[function_name] = { + "success": 0, + "failure": 0, + "total_time": 0.0, + "count": 0, + } + cog.tool_stats[function_name]["failure"] += 1 + cog.tool_stats[function_name]["total_time"] += tool_elapsed_time + cog.tool_stats[function_name]["count"] += 1 + error_message = ( + f"Error executing tool {function_name}: {type(e).__name__}: {str(e)}" + ) print(f"{error_message} (Took {tool_elapsed_time:.2f}s)") import traceback + traceback.print_exc() tool_result_content = {"error": error_message} else: @@ -427,10 +578,15 @@ async def process_requested_tools(cog: 'WheatleyCog', function_call: 'generative tool_elapsed_time = time.monotonic() - tool_start_time # Log attempt even if tool not found if function_name not in cog.tool_stats: - cog.tool_stats[function_name] = {'success': 0, 'failure': 0, 'total_time': 0.0, 'count': 0} - cog.tool_stats[function_name]['failure'] += 1 - cog.tool_stats[function_name]['total_time'] += tool_elapsed_time - cog.tool_stats[function_name]['count'] += 1 + cog.tool_stats[function_name] = { + "success": 0, + "failure": 0, + "total_time": 0.0, + "count": 0, + } + cog.tool_stats[function_name]["failure"] += 1 + cog.tool_stats[function_name]["total_time"] += tool_elapsed_time + cog.tool_stats[function_name]["count"] += 1 error_message = f"Tool '{function_name}' not found or implemented." print(f"{error_message} (Took {tool_elapsed_time:.2f}s)") tool_result_content = {"error": error_message} @@ -440,7 +596,9 @@ async def process_requested_tools(cog: 'WheatleyCog', function_call: 'generative # --- Main AI Response Function --- -async def get_ai_response(cog: 'WheatleyCog', message: discord.Message, model_name: Optional[str] = None) -> Dict[str, Any]: +async def get_ai_response( + cog: "WheatleyCog", message: discord.Message, model_name: Optional[str] = None +) -> Dict[str, Any]: """ Gets responses from the Vertex AI Gemini API, handling potential tool usage and returning the final parsed response. @@ -457,11 +615,14 @@ async def get_ai_response(cog: 'WheatleyCog', message: discord.Message, model_na - "fallback_initial": Optional minimal response if initial parsing failed critically (less likely with controlled generation). """ if not PROJECT_ID or not LOCATION: - return {"final_response": None, "error": "Google Cloud Project ID or Location not configured"} + return { + "final_response": None, + "error": "Google Cloud Project ID or Location not configured", + } channel_id = message.channel.id user_id = message.author.id - initial_parsed_data = None # Added to store initial parsed result + initial_parsed_data = None # Added to store initial parsed result final_parsed_data = None error_message = None fallback_response = None @@ -469,8 +630,10 @@ async def get_ai_response(cog: 'WheatleyCog', message: discord.Message, model_na try: # --- Build Prompt Components --- final_system_prompt = await build_dynamic_system_prompt(cog, message) - conversation_context_messages = gather_conversation_context(cog, channel_id, message.id) # Pass cog - memory_context = await get_memory_context(cog, message) # Pass cog + conversation_context_messages = gather_conversation_context( + cog, channel_id, message.id + ) # Pass cog + memory_context = await get_memory_context(cog, message) # Pass cog # --- Initialize Model --- # Tools are passed during model initialization in Vertex AI SDK @@ -480,7 +643,7 @@ async def get_ai_response(cog: 'WheatleyCog', message: discord.Message, model_na model = GenerativeModel( model_name or DEFAULT_MODEL, system_instruction=final_system_prompt, - tools=[vertex_tool] if vertex_tool else None + tools=[vertex_tool] if vertex_tool else None, ) # --- Prepare Message History (Contents) --- @@ -497,54 +660,81 @@ async def get_ai_response(cog: 'WheatleyCog', message: discord.Message, model_na # but this might confuse the turn structure. # contents.append(Content(role="model", parts=[Part.from_text(f"System Note: {memory_context}")])) - # Add conversation history for msg in conversation_context_messages: - role = msg.get("role", "user") # Default to user if role missing + role = msg.get("role", "user") # Default to user if role missing # Map roles if necessary (e.g., 'assistant' -> 'model') if role == "assistant": role = "model" elif role == "system": - # Skip system messages here, handled by system_instruction - continue + # Skip system messages here, handled by system_instruction + continue # Handle potential multimodal content in history (if stored that way) if isinstance(msg.get("content"), list): - parts = [Part.from_text(part["text"]) if part["type"] == "text" else Part.from_uri(part["image_url"]["url"], mime_type=part["image_url"]["url"].split(";")[0].split(":")[1]) if part["type"] == "image_url" else None for part in msg["content"]] - parts = [p for p in parts if p] # Filter out None parts - if parts: - contents.append(Content(role=role, parts=parts)) + parts = [ + ( + Part.from_text(part["text"]) + if part["type"] == "text" + else ( + Part.from_uri( + part["image_url"]["url"], + mime_type=part["image_url"]["url"] + .split(";")[0] + .split(":")[1], + ) + if part["type"] == "image_url" + else None + ) + ) + for part in msg["content"] + ] + parts = [p for p in parts if p] # Filter out None parts + if parts: + contents.append(Content(role=role, parts=parts)) elif isinstance(msg.get("content"), str): - contents.append(Content(role=role, parts=[Part.from_text(msg["content"])])) - + contents.append( + Content(role=role, parts=[Part.from_text(msg["content"])]) + ) # --- Prepare the current message content (potentially multimodal) --- current_message_parts = [] - formatted_current_message = format_message(cog, message) # Pass cog if needed + formatted_current_message = format_message(cog, message) # Pass cog if needed # --- Construct text content, including reply context if applicable --- text_content = "" - if formatted_current_message.get("is_reply") and formatted_current_message.get("replied_to_author_name"): + if formatted_current_message.get("is_reply") and formatted_current_message.get( + "replied_to_author_name" + ): reply_author = formatted_current_message["replied_to_author_name"] - reply_content = formatted_current_message.get("replied_to_content", "...") # Use ellipsis if content missing + reply_content = formatted_current_message.get( + "replied_to_content", "..." + ) # Use ellipsis if content missing # Truncate long replied content to keep context concise max_reply_len = 150 if len(reply_content) > max_reply_len: reply_content = reply_content[:max_reply_len] + "..." - text_content += f"(Replying to {reply_author}: \"{reply_content}\")\n" + text_content += f'(Replying to {reply_author}: "{reply_content}")\n' # Add current message author and content text_content += f"{formatted_current_message['author']['display_name']}: {formatted_current_message['content']}" # Add mention details if formatted_current_message.get("mentioned_users_details"): - mentions_str = ", ".join([f"{m['display_name']}(id:{m['id']})" for m in formatted_current_message["mentioned_users_details"]]) + mentions_str = ", ".join( + [ + f"{m['display_name']}(id:{m['id']})" + for m in formatted_current_message["mentioned_users_details"] + ] + ) text_content += f"\n(Message Details: Mentions=[{mentions_str}])" current_message_parts.append(Part.from_text(text_content)) # --- End text content construction --- if message.attachments: - print(f"Processing {len(message.attachments)} attachments for message {message.id}") + print( + f"Processing {len(message.attachments)} attachments for message {message.id}" + ) for attachment in message.attachments: mime_type = attachment.content_type file_url = attachment.url @@ -552,7 +742,13 @@ async def get_ai_response(cog: 'WheatleyCog', message: discord.Message, model_na # Check if MIME type is supported for URI input by Gemini # Expand this list based on Gemini documentation for supported types via URI - supported_mime_prefixes = ["image/", "video/", "audio/", "text/plain", "application/pdf"] + supported_mime_prefixes = [ + "image/", + "video/", + "audio/", + "text/plain", + "application/pdf", + ] is_supported = False if mime_type: for prefix in supported_mime_prefixes: @@ -572,19 +768,34 @@ async def get_ai_response(cog: 'WheatleyCog', message: discord.Message, model_na # 2. Add the URI part # Ensure mime_type doesn't contain parameters like '; charset=...' if the API doesn't like them - clean_mime_type = mime_type.split(';')[0] - current_message_parts.append(Part.from_uri(uri=file_url, mime_type=clean_mime_type)) - print(f"Added URI part for attachment: {filename} ({clean_mime_type}) using URL: {file_url}") + clean_mime_type = mime_type.split(";")[0] + current_message_parts.append( + Part.from_uri(uri=file_url, mime_type=clean_mime_type) + ) + print( + f"Added URI part for attachment: {filename} ({clean_mime_type}) using URL: {file_url}" + ) except Exception as e: - print(f"Error creating Part for attachment {filename} ({mime_type}): {e}") + print( + f"Error creating Part for attachment {filename} ({mime_type}): {e}" + ) # Optionally add a text part indicating the error - current_message_parts.append(Part.from_text(f"(System Note: Failed to process attachment '{filename}' - {e})")) + current_message_parts.append( + Part.from_text( + f"(System Note: Failed to process attachment '{filename}' - {e})" + ) + ) else: - print(f"Skipping unsupported or invalid attachment: {filename} (Type: {mime_type}, URL: {file_url})") + print( + f"Skipping unsupported or invalid attachment: {filename} (Type: {mime_type}, URL: {file_url})" + ) # Optionally inform the AI that an unsupported file was attached - current_message_parts.append(Part.from_text(f"(System Note: User attached an unsupported file '{filename}' of type '{mime_type}' which cannot be processed.)")) - + current_message_parts.append( + Part.from_text( + f"(System Note: User attached an unsupported file '{filename}' of type '{mime_type}' which cannot be processed.)" + ) + ) # Ensure there's always *some* content part, even if only text or errors if current_message_parts: @@ -593,12 +804,11 @@ async def get_ai_response(cog: 'WheatleyCog', message: discord.Message, model_na print("Warning: No content parts generated for user message.") contents.append(Content(role="user", parts=[Part.from_text("")])) - # --- First API Call (Check for Tool Use) --- print("Making initial API call to check for tool use...") generation_config_initial = GenerationConfig( temperature=0.75, - max_output_tokens=10000, # Adjust as needed + max_output_tokens=10000, # Adjust as needed # No response schema needed for the initial call, just checking for function calls ) @@ -608,56 +818,70 @@ async def get_ai_response(cog: 'WheatleyCog', message: discord.Message, model_na contents=contents, generation_config=generation_config_initial, safety_settings=STANDARD_SAFETY_SETTINGS, - request_desc=f"Initial response check for message {message.id}" + request_desc=f"Initial response check for message {message.id}", ) # --- Log Raw Request and Response --- try: # Log the request payload (contents) - request_payload_log = [{"role": c.role, "parts": [str(p) for p in c.parts]} for c in contents] # Convert parts to string for logging - print(f"--- Raw API Request (Initial Call) ---\n{json.dumps(request_payload_log, indent=2)}\n------------------------------------") + request_payload_log = [ + {"role": c.role, "parts": [str(p) for p in c.parts]} for c in contents + ] # Convert parts to string for logging + print( + f"--- Raw API Request (Initial Call) ---\n{json.dumps(request_payload_log, indent=2)}\n------------------------------------" + ) # Log the raw response object - print(f"--- Raw API Response (Initial Call) ---\n{initial_response}\n-----------------------------------") + print( + f"--- Raw API Response (Initial Call) ---\n{initial_response}\n-----------------------------------" + ) except Exception as log_e: print(f"Error logging raw request/response: {log_e}") # --- End Logging --- if not initial_response or not initial_response.candidates: - raise Exception("Initial API call returned no response or candidates.") + raise Exception("Initial API call returned no response or candidates.") # --- Check for Tool Call FIRST --- candidate = initial_response.candidates[0] - finish_reason = getattr(candidate, 'finish_reason', None) + finish_reason = getattr(candidate, "finish_reason", None) function_call = None - function_call_part_content = None # Store the AI's request message content + function_call_part_content = None # Store the AI's request message content # Check primarily for the *presence* of a function call part, # as finish_reason might be STOP even with a function call. - if hasattr(candidate, 'content') and candidate.content.parts: + if hasattr(candidate, "content") and candidate.content.parts: for part in candidate.content.parts: - if hasattr(part, 'function_call'): - function_call = part.function_call # Assign the value + if hasattr(part, "function_call"): + function_call = part.function_call # Assign the value # Add check to ensure function_call is not None before proceeding if function_call: # Store the whole content containing the call to add to history later function_call_part_content = candidate.content - print(f"AI requested tool (found function_call part): {function_call.name}") - break # Found a valid function call part + print( + f"AI requested tool (found function_call part): {function_call.name}" + ) + break # Found a valid function call part else: # Log if the attribute exists but is None (unexpected case) - print("Warning: Found part with 'function_call' attribute, but its value was None.") + print( + "Warning: Found part with 'function_call' attribute, but its value was None." + ) # --- Process Tool Call or Handle Direct Response --- if function_call and function_call_part_content: # --- Tool Call Path --- - initial_parsed_data = None # No initial JSON expected if tool is called + initial_parsed_data = None # No initial JSON expected if tool is called # Process the tool request tool_response_part = await process_requested_tools(cog, function_call) # Append the AI's request and the tool's response to the history - contents.append(candidate.content) # Add the AI's function call request message - contents.append(Content(role="function", parts=[tool_response_part])) # Add the function response part + contents.append( + candidate.content + ) # Add the AI's function call request message + contents.append( + Content(role="function", parts=[tool_response_part]) + ) # Add the function response part # --- Second API Call (Get Final Response After Tool) --- print("Making follow-up API call with tool results...") @@ -665,115 +889,162 @@ async def get_ai_response(cog: 'WheatleyCog', message: discord.Message, model_na # Initialize a NEW model instance WITHOUT tools for the follow-up call # This prevents the InvalidArgument error when specifying response schema model_final = GenerativeModel( - model_name or DEFAULT_MODEL, # Use the same model name - system_instruction=final_system_prompt # Keep the same system prompt + model_name or DEFAULT_MODEL, # Use the same model name + system_instruction=final_system_prompt, # Keep the same system prompt # Omit the 'tools' parameter here ) # Preprocess the schema before passing it to GenerationConfig - processed_response_schema = _preprocess_schema_for_vertex(RESPONSE_SCHEMA['schema']) + processed_response_schema = _preprocess_schema_for_vertex( + RESPONSE_SCHEMA["schema"] + ) generation_config_final = GenerationConfig( - temperature=0.75, # Keep original temperature for final response - max_output_tokens=10000, # Keep original max tokens + temperature=0.75, # Keep original temperature for final response + max_output_tokens=10000, # Keep original max tokens response_mime_type="application/json", - response_schema=processed_response_schema # Use preprocessed schema + response_schema=processed_response_schema, # Use preprocessed schema ) - final_response_obj = await call_vertex_api_with_retry( # Renamed variable for clarity + final_response_obj = await call_vertex_api_with_retry( # Renamed variable for clarity cog=cog, - model=model_final, # Use the new model instance WITHOUT tools - contents=contents, # History now includes tool call/response + model=model_final, # Use the new model instance WITHOUT tools + contents=contents, # History now includes tool call/response generation_config=generation_config_final, safety_settings=STANDARD_SAFETY_SETTINGS, - request_desc=f"Follow-up response for message {message.id} after tool execution" + request_desc=f"Follow-up response for message {message.id} after tool execution", ) if not final_response_obj or not final_response_obj.candidates: - raise Exception("Follow-up API call returned no response or candidates.") + raise Exception( + "Follow-up API call returned no response or candidates." + ) - final_response_text = _get_response_text(final_response_obj) # Use helper + final_response_text = _get_response_text(final_response_obj) # Use helper final_parsed_data = parse_and_validate_json_response( - final_response_text, RESPONSE_SCHEMA['schema'], "final response after tools" + final_response_text, + RESPONSE_SCHEMA["schema"], + "final response after tools", ) # Handle validation failure - Re-prompt loop (simplified example) if final_parsed_data is None: - print("Warning: Final response failed validation. Attempting re-prompt (basic)...") + print( + "Warning: Final response failed validation. Attempting re-prompt (basic)..." + ) # Construct a basic re-prompt message - contents.append(final_response_obj.candidates[0].content) # Add the invalid response - contents.append(Content(role="user", parts=[Part.from_text( - "Your previous JSON response was invalid or did not match the required schema. " - f"Please provide the response again, strictly adhering to this schema:\n{json.dumps(RESPONSE_SCHEMA['schema'], indent=2)}" - )])) + contents.append( + final_response_obj.candidates[0].content + ) # Add the invalid response + contents.append( + Content( + role="user", + parts=[ + Part.from_text( + "Your previous JSON response was invalid or did not match the required schema. " + f"Please provide the response again, strictly adhering to this schema:\n{json.dumps(RESPONSE_SCHEMA['schema'], indent=2)}" + ) + ], + ) + ) # Retry the final call retry_response_obj = await call_vertex_api_with_retry( - cog=cog, model=model, contents=contents, - generation_config=generation_config_final, safety_settings=STANDARD_SAFETY_SETTINGS, - request_desc=f"Re-prompt validation failure for message {message.id}" + cog=cog, + model=model, + contents=contents, + generation_config=generation_config_final, + safety_settings=STANDARD_SAFETY_SETTINGS, + request_desc=f"Re-prompt validation failure for message {message.id}", ) if retry_response_obj and retry_response_obj.candidates: - final_response_text = _get_response_text(retry_response_obj) # Use helper + final_response_text = _get_response_text( + retry_response_obj + ) # Use helper final_parsed_data = parse_and_validate_json_response( - final_response_text, RESPONSE_SCHEMA['schema'], "re-prompted final response" + final_response_text, + RESPONSE_SCHEMA["schema"], + "re-prompted final response", ) if final_parsed_data is None: - print("Critical Error: Re-prompted response still failed validation.") - error_message = "Failed to get valid JSON response after re-prompting." + print( + "Critical Error: Re-prompted response still failed validation." + ) + error_message = ( + "Failed to get valid JSON response after re-prompting." + ) else: - error_message = "Failed to get response after re-prompting." + error_message = "Failed to get response after re-prompting." # final_parsed_data is now set (or None if failed) after tool use and potential re-prompt else: # --- No Tool Call Path --- print("No tool call requested by AI. Processing initial response as final.") # Attempt to parse the initial response text directly. - initial_response_text = _get_response_text(initial_response) # Use helper + initial_response_text = _get_response_text(initial_response) # Use helper # Validate against the final schema because this IS the final response. final_parsed_data = parse_and_validate_json_response( - initial_response_text, RESPONSE_SCHEMA['schema'], "final response (no tools)" + initial_response_text, + RESPONSE_SCHEMA["schema"], + "final response (no tools)", + ) + initial_parsed_data = ( + final_parsed_data # Keep initial_parsed_data consistent for return dict ) - initial_parsed_data = final_parsed_data # Keep initial_parsed_data consistent for return dict if final_parsed_data is None: - # This means the initial response failed validation. - print("Critical Error: Initial response failed validation (no tools).") - error_message = "Failed to parse/validate initial AI JSON response." - # Create a basic fallback if the bot was mentioned - replied_to_bot = message.reference and message.reference.resolved and message.reference.resolved.author == cog.bot.user - if cog.bot.user.mentioned_in(message) or replied_to_bot: - fallback_response = {"should_respond": True, "content": "...", "react_with_emoji": "❓"} + # This means the initial response failed validation. + print("Critical Error: Initial response failed validation (no tools).") + error_message = "Failed to parse/validate initial AI JSON response." + # Create a basic fallback if the bot was mentioned + replied_to_bot = ( + message.reference + and message.reference.resolved + and message.reference.resolved.author == cog.bot.user + ) + if cog.bot.user.mentioned_in(message) or replied_to_bot: + fallback_response = { + "should_respond": True, + "content": "...", + "react_with_emoji": "❓", + } # initial_parsed_data is not used in this path, only final_parsed_data matters - except Exception as e: error_message = f"Error in get_ai_response main loop for message {message.id}: {type(e).__name__}: {str(e)}" print(error_message) import traceback + traceback.print_exc() # Ensure both are None on critical error initial_parsed_data = None final_parsed_data = None return { - "initial_response": initial_parsed_data, # Return parsed initial data - "final_response": final_parsed_data, # Return parsed final data + "initial_response": initial_parsed_data, # Return parsed initial data + "final_response": final_parsed_data, # Return parsed final data "error": error_message, - "fallback_initial": fallback_response + "fallback_initial": fallback_response, } # --- Proactive AI Response Function --- -async def get_proactive_ai_response(cog: 'WheatleyCog', message: discord.Message, trigger_reason: str) -> Dict[str, Any]: +async def get_proactive_ai_response( + cog: "WheatleyCog", message: discord.Message, trigger_reason: str +) -> Dict[str, Any]: """Generates a proactive response based on a specific trigger using Vertex AI.""" if not PROJECT_ID or not LOCATION: - return {"should_respond": False, "content": None, "react_with_emoji": None, "error": "Google Cloud Project ID or Location not configured"} + return { + "should_respond": False, + "content": None, + "react_with_emoji": None, + "error": "Google Cloud Project ID or Location not configured", + } print(f"--- Proactive Response Triggered: {trigger_reason} ---") channel_id = message.channel.id final_parsed_data = None error_message = None - plan = None # Variable to store the plan + plan = None # Variable to store the plan try: # --- Build Context for Planning --- @@ -783,96 +1054,137 @@ async def get_proactive_ai_response(cog: 'WheatleyCog', message: discord.Message f"Current Mood: {cog.current_mood}", ] # Add recent messages summary - summary_data = await get_conversation_summary(cog, str(channel_id), message_limit=15) # Use tool function + summary_data = await get_conversation_summary( + cog, str(channel_id), message_limit=15 + ) # Use tool function if summary_data and not summary_data.get("error"): - planning_context_parts.append(f"Recent Conversation Summary: {summary_data['summary']}") + planning_context_parts.append( + f"Recent Conversation Summary: {summary_data['summary']}" + ) # Add active topics active_topics_data = cog.active_topics.get(channel_id) if active_topics_data and active_topics_data.get("topics"): - topics_str = ", ".join([f"{t['topic']} ({t['score']:.1f})" for t in active_topics_data["topics"][:3]]) + topics_str = ", ".join( + [ + f"{t['topic']} ({t['score']:.1f})" + for t in active_topics_data["topics"][:3] + ] + ) planning_context_parts.append(f"Active Topics: {topics_str}") # Add sentiment sentiment_data = cog.conversation_sentiment.get(channel_id) if sentiment_data: - planning_context_parts.append(f"Conversation Sentiment: {sentiment_data.get('overall', 'N/A')} (Intensity: {sentiment_data.get('intensity', 0):.1f})") + planning_context_parts.append( + f"Conversation Sentiment: {sentiment_data.get('overall', 'N/A')} (Intensity: {sentiment_data.get('intensity', 0):.1f})" + ) # Add Wheatley's interests (Note: Interests are likely disabled/removed for Wheatley, this might fetch nothing) try: interests = await cog.memory_manager.get_interests(limit=5) if interests: interests_str = ", ".join([f"{t} ({l:.1f})" for t, l in interests]) - planning_context_parts.append(f"Wheatley's Interests: {interests_str}") # Changed text - except Exception as int_e: print(f"Error getting interests for planning: {int_e}") + planning_context_parts.append( + f"Wheatley's Interests: {interests_str}" + ) # Changed text + except Exception as int_e: + print(f"Error getting interests for planning: {int_e}") planning_context = "\n".join(planning_context_parts) # --- Planning Step --- print("Generating proactive response plan...") planning_prompt_messages = [ - {"role": "system", "content": "You are Wheatley's planning module. Analyze the context and trigger reason to decide if Wheatley should respond proactively and, if so, outline a plan (goal, key info, tone). Focus on natural, in-character engagement (rambling, insecure, bad ideas). Respond ONLY with JSON matching the provided schema."}, # Updated system prompt - {"role": "user", "content": f"Context:\n{planning_context}\n\nBased on this context and the trigger reason, create a plan for Wheatley's proactive response."} # Updated user prompt + { + "role": "system", + "content": "You are Wheatley's planning module. Analyze the context and trigger reason to decide if Wheatley should respond proactively and, if so, outline a plan (goal, key info, tone). Focus on natural, in-character engagement (rambling, insecure, bad ideas). Respond ONLY with JSON matching the provided schema.", + }, # Updated system prompt + { + "role": "user", + "content": f"Context:\n{planning_context}\n\nBased on this context and the trigger reason, create a plan for Wheatley's proactive response.", + }, # Updated user prompt ] plan = await get_internal_ai_json_response( cog=cog, prompt_messages=planning_prompt_messages, task_description=f"Proactive Planning ({trigger_reason})", - response_schema_dict=PROACTIVE_PLAN_SCHEMA['schema'], - model_name_override=FALLBACK_MODEL, # Use a potentially faster/cheaper model for planning + response_schema_dict=PROACTIVE_PLAN_SCHEMA["schema"], + model_name_override=FALLBACK_MODEL, # Use a potentially faster/cheaper model for planning temperature=0.5, - max_tokens=300 + max_tokens=300, ) if not plan or not plan.get("should_respond"): - reason = plan.get('reasoning', 'Planning failed or decided against responding.') if plan else 'Planning failed.' + reason = ( + plan.get("reasoning", "Planning failed or decided against responding.") + if plan + else "Planning failed." + ) print(f"Proactive response aborted by plan: {reason}") - return {"should_respond": False, "content": None, "react_with_emoji": None, "note": f"Plan: {reason}"} + return { + "should_respond": False, + "content": None, + "react_with_emoji": None, + "note": f"Plan: {reason}", + } - print(f"Proactive Plan Generated: Goal='{plan.get('response_goal', 'N/A')}', Reasoning='{plan.get('reasoning', 'N/A')}'") + print( + f"Proactive Plan Generated: Goal='{plan.get('response_goal', 'N/A')}', Reasoning='{plan.get('reasoning', 'N/A')}'" + ) # --- Build Final Proactive Prompt using Plan --- persistent_traits = await cog.memory_manager.get_all_personality_traits() if not persistent_traits: - persistent_traits = {} # Wheatley doesn't use these Gurt traits + persistent_traits = {} # Wheatley doesn't use these Gurt traits final_proactive_prompt_parts = [ - f"You are Wheatley, an Aperture Science Personality Core. Your tone is rambling, insecure, uses British slang, and you often have terrible ideas you think are brilliant.", # Updated personality description - # Removed Gurt-specific traits - # Removed mood reference as it's disabled for Wheatley - # Incorporate Plan Details: - f"You decided to respond proactively (maybe?). Trigger Reason: {trigger_reason}.", # Wheatley-style uncertainty - f"Your Brilliant Plan (Goal): {plan.get('response_goal', 'Say something... probably helpful?')}.", # Wheatley-style goal - f"Reasoning: {plan.get('reasoning', 'N/A')}.", + f"You are Wheatley, an Aperture Science Personality Core. Your tone is rambling, insecure, uses British slang, and you often have terrible ideas you think are brilliant.", # Updated personality description + # Removed Gurt-specific traits + # Removed mood reference as it's disabled for Wheatley + # Incorporate Plan Details: + f"You decided to respond proactively (maybe?). Trigger Reason: {trigger_reason}.", # Wheatley-style uncertainty + f"Your Brilliant Plan (Goal): {plan.get('response_goal', 'Say something... probably helpful?')}.", # Wheatley-style goal + f"Reasoning: {plan.get('reasoning', 'N/A')}.", ] - if plan.get('key_info_to_include'): - info_str = "; ".join(plan['key_info_to_include']) + if plan.get("key_info_to_include"): + info_str = "; ".join(plan["key_info_to_include"]) final_proactive_prompt_parts.append(f"Consider mentioning: {info_str}") - if plan.get('suggested_tone'): - final_proactive_prompt_parts.append(f"Adjust tone to be: {plan['suggested_tone']}") + if plan.get("suggested_tone"): + final_proactive_prompt_parts.append( + f"Adjust tone to be: {plan['suggested_tone']}" + ) - final_proactive_prompt_parts.append("Generate a casual, in-character message based on the plan and context. Keep it relatively short and natural-sounding.") + final_proactive_prompt_parts.append( + "Generate a casual, in-character message based on the plan and context. Keep it relatively short and natural-sounding." + ) final_proactive_system_prompt = "\n\n".join(final_proactive_prompt_parts) # --- Initialize Final Model --- model = GenerativeModel( - model_name=DEFAULT_MODEL, - system_instruction=final_proactive_system_prompt + model_name=DEFAULT_MODEL, system_instruction=final_proactive_system_prompt ) # --- Prepare Final Contents --- contents = [ - Content(role="user", parts=[Part.from_text( - f"Generate the response based on your plan. **CRITICAL: Your response MUST be ONLY the raw JSON object matching this schema:**\n\n{json.dumps(RESPONSE_SCHEMA['schema'], indent=2)}\n\n**Ensure nothing precedes or follows the JSON.**" - )]) + Content( + role="user", + parts=[ + Part.from_text( + f"Generate the response based on your plan. **CRITICAL: Your response MUST be ONLY the raw JSON object matching this schema:**\n\n{json.dumps(RESPONSE_SCHEMA['schema'], indent=2)}\n\n**Ensure nothing precedes or follows the JSON.**" + ) + ], + ) ] # --- Call Final LLM API --- # Preprocess the schema before passing it to GenerationConfig - processed_response_schema_proactive = _preprocess_schema_for_vertex(RESPONSE_SCHEMA['schema']) + processed_response_schema_proactive = _preprocess_schema_for_vertex( + RESPONSE_SCHEMA["schema"] + ) generation_config_final = GenerationConfig( - temperature=0.8, # Use original proactive temp + temperature=0.8, # Use original proactive temp max_output_tokens=200, response_mime_type="application/json", - response_schema=processed_response_schema_proactive # Use preprocessed schema + response_schema=processed_response_schema_proactive, # Use preprocessed schema ) response_obj = await call_vertex_api_with_retry( @@ -881,62 +1193,93 @@ async def get_proactive_ai_response(cog: 'WheatleyCog', message: discord.Message contents=contents, generation_config=generation_config_final, safety_settings=STANDARD_SAFETY_SETTINGS, - request_desc=f"Final proactive response for channel {channel_id} ({trigger_reason})" + request_desc=f"Final proactive response for channel {channel_id} ({trigger_reason})", ) if not response_obj or not response_obj.candidates: - raise Exception("Final proactive API call returned no response or candidates.") + raise Exception( + "Final proactive API call returned no response or candidates." + ) # --- Parse and Validate Final Response --- final_response_text = _get_response_text(response_obj) final_parsed_data = parse_and_validate_json_response( - final_response_text, RESPONSE_SCHEMA['schema'], f"final proactive response ({trigger_reason})" + final_response_text, + RESPONSE_SCHEMA["schema"], + f"final proactive response ({trigger_reason})", ) if final_parsed_data is None: - print(f"Warning: Failed to parse/validate final proactive JSON response for {trigger_reason}.") - final_parsed_data = {"should_respond": False, "content": None, "react_with_emoji": None, "note": "Fallback - Failed to parse/validate final proactive JSON"} + print( + f"Warning: Failed to parse/validate final proactive JSON response for {trigger_reason}." + ) + final_parsed_data = { + "should_respond": False, + "content": None, + "react_with_emoji": None, + "note": "Fallback - Failed to parse/validate final proactive JSON", + } else: - # --- Cache Bot Response --- - if final_parsed_data.get("should_respond") and final_parsed_data.get("content"): - bot_response_cache_entry = { - "id": f"bot_proactive_{message.id}_{int(time.time())}", - "author": {"id": str(cog.bot.user.id), "name": cog.bot.user.name, "display_name": cog.bot.user.display_name, "bot": True}, - "content": final_parsed_data.get("content", ""), "created_at": datetime.datetime.now().isoformat(), - "attachments": [], "embeds": False, "mentions": [], "replied_to_message_id": None, - "channel": message.channel, "guild": message.guild, "reference": None, "mentioned_users_details": [] - } - cog.message_cache['by_channel'].setdefault(channel_id, []).append(bot_response_cache_entry) - cog.message_cache['global_recent'].append(bot_response_cache_entry) - cog.bot_last_spoke[channel_id] = time.time() - # Removed Gurt-specific participation tracking - + # --- Cache Bot Response --- + if final_parsed_data.get("should_respond") and final_parsed_data.get( + "content" + ): + bot_response_cache_entry = { + "id": f"bot_proactive_{message.id}_{int(time.time())}", + "author": { + "id": str(cog.bot.user.id), + "name": cog.bot.user.name, + "display_name": cog.bot.user.display_name, + "bot": True, + }, + "content": final_parsed_data.get("content", ""), + "created_at": datetime.datetime.now().isoformat(), + "attachments": [], + "embeds": False, + "mentions": [], + "replied_to_message_id": None, + "channel": message.channel, + "guild": message.guild, + "reference": None, + "mentioned_users_details": [], + } + cog.message_cache["by_channel"].setdefault(channel_id, []).append( + bot_response_cache_entry + ) + cog.message_cache["global_recent"].append(bot_response_cache_entry) + cog.bot_last_spoke[channel_id] = time.time() + # Removed Gurt-specific participation tracking except Exception as e: error_message = f"Error getting proactive AI response for channel {channel_id} ({trigger_reason}): {type(e).__name__}: {str(e)}" print(error_message) - final_parsed_data = {"should_respond": False, "content": None, "react_with_emoji": None, "error": error_message} + final_parsed_data = { + "should_respond": False, + "content": None, + "react_with_emoji": None, + "error": error_message, + } # Ensure default keys exist final_parsed_data.setdefault("should_respond", False) final_parsed_data.setdefault("content", None) final_parsed_data.setdefault("react_with_emoji", None) if error_message and "error" not in final_parsed_data: - final_parsed_data["error"] = error_message + final_parsed_data["error"] = error_message return final_parsed_data # --- Internal AI Call for Specific Tasks --- async def get_internal_ai_json_response( - cog: 'WheatleyCog', - prompt_messages: List[Dict[str, Any]], # Keep this format + cog: "WheatleyCog", + prompt_messages: List[Dict[str, Any]], # Keep this format task_description: str, - response_schema_dict: Dict[str, Any], # Expect schema as dict + response_schema_dict: Dict[str, Any], # Expect schema as dict model_name: Optional[str] = None, temperature: float = 0.7, max_tokens: int = 5000, -) -> Optional[Dict[str, Any]]: # Keep return type hint simple +) -> Optional[Dict[str, Any]]: # Keep return type hint simple """ Makes a Vertex AI call expecting a specific JSON response format for internal tasks. @@ -953,12 +1296,14 @@ async def get_internal_ai_json_response( The parsed and validated JSON dictionary if successful, None otherwise. """ if not PROJECT_ID or not LOCATION: - print(f"Error in get_internal_ai_json_response ({task_description}): GCP Project/Location not set.") + print( + f"Error in get_internal_ai_json_response ({task_description}): GCP Project/Location not set." + ) return None final_parsed_data = None error_occurred = None - request_payload_for_logging = {} # For logging + request_payload_for_logging = {} # For logging try: # --- Convert prompt messages to Vertex AI Content format --- @@ -974,13 +1319,15 @@ async def get_internal_ai_json_response( else: # Append subsequent system messages to the instruction system_instruction += "\n\n" + content_text - continue # Skip adding system messages to contents list + continue # Skip adding system messages to contents list elif role == "assistant": role = "model" # --- Process content (string or list) --- content_value = msg.get("content") - message_parts: List[Part] = [] # Initialize list to hold parts for this message + message_parts: List[Part] = ( + [] + ) # Initialize list to hold parts for this message if isinstance(content_value, str): # Handle simple string content @@ -998,24 +1345,37 @@ async def get_internal_ai_json_response( if mime_type and base64_data: try: image_bytes = base64.b64decode(base64_data) - message_parts.append(Part.from_data(data=image_bytes, mime_type=mime_type)) + message_parts.append( + Part.from_data( + data=image_bytes, mime_type=mime_type + ) + ) except Exception as decode_err: - print(f"Error decoding/adding image part in get_internal_ai_json_response: {decode_err}") + print( + f"Error decoding/adding image part in get_internal_ai_json_response: {decode_err}" + ) # Optionally add a placeholder text part indicating failure - message_parts.append(Part.from_text("(System Note: Failed to process an image part)")) + message_parts.append( + Part.from_text( + "(System Note: Failed to process an image part)" + ) + ) else: - print("Warning: image_data part missing mime_type or data.") + print("Warning: image_data part missing mime_type or data.") else: - print(f"Warning: Unknown part type '{part_type}' in internal prompt message.") + print( + f"Warning: Unknown part type '{part_type}' in internal prompt message." + ) else: - print(f"Warning: Unexpected content type '{type(content_value)}' in internal prompt message.") + print( + f"Warning: Unexpected content type '{type(content_value)}' in internal prompt message." + ) # Add the content object if parts were generated if message_parts: contents.append(Content(role=role, parts=message_parts)) else: - print(f"Warning: No parts generated for message role '{role}'.") - + print(f"Warning: No parts generated for message role '{role}'.") # Add the critical JSON instruction to the last user message or as a new user message json_instruction_content = ( @@ -1024,15 +1384,16 @@ async def get_internal_ai_json_response( f"**Ensure nothing precedes or follows the JSON.**" ) if contents and contents[-1].role == "user": - contents[-1].parts.append(Part.from_text(f"\n\n{json_instruction_content}")) + contents[-1].parts.append(Part.from_text(f"\n\n{json_instruction_content}")) else: - contents.append(Content(role="user", parts=[Part.from_text(json_instruction_content)])) - + contents.append( + Content(role="user", parts=[Part.from_text(json_instruction_content)]) + ) # --- Initialize Model --- model = GenerativeModel( - model_name=model_name or DEFAULT_MODEL, # Use keyword argument - system_instruction=system_instruction + model_name=model_name or DEFAULT_MODEL, # Use keyword argument + system_instruction=system_instruction, # No tools needed for internal JSON tasks usually ) @@ -1043,41 +1404,49 @@ async def get_internal_ai_json_response( temperature=temperature, max_output_tokens=max_tokens, response_mime_type="application/json", - response_schema=processed_schema_internal # Use preprocessed schema + response_schema=processed_schema_internal, # Use preprocessed schema ) # Prepare payload for logging (approximate) request_payload_for_logging = { "model": model._model_name, "system_instruction": system_instruction, - "contents": [ # Simplified representation for logging - {"role": c.role, "parts": [p.text if hasattr(p,'text') else str(type(p)) for p in c.parts]} - for c in contents - ], - # Use the original generation_config dict directly for logging - "generation_config": generation_config # It's already a dict - } + "contents": [ # Simplified representation for logging + { + "role": c.role, + "parts": [ + p.text if hasattr(p, "text") else str(type(p)) for p in c.parts + ], + } + for c in contents + ], + # Use the original generation_config dict directly for logging + "generation_config": generation_config, # It's already a dict + } # --- Add detailed logging for raw request --- try: print(f"--- Raw request payload for {task_description} ---") # Use json.dumps for pretty printing, handle potential errors - print(json.dumps(request_payload_for_logging, indent=2, default=str)) # Use default=str as fallback + print( + json.dumps(request_payload_for_logging, indent=2, default=str) + ) # Use default=str as fallback print(f"--- End Raw request payload ---") except Exception as req_log_e: print(f"Error logging raw request payload: {req_log_e}") - print(f"Payload causing error: {request_payload_for_logging}") # Print the raw dict on error + print( + f"Payload causing error: {request_payload_for_logging}" + ) # Print the raw dict on error # --- End detailed logging --- - # --- Call API --- response_obj = await call_vertex_api_with_retry( cog=cog, model=model, contents=contents, generation_config=generation_config, - safety_settings=STANDARD_SAFETY_SETTINGS, # Use standard safety - request_desc=task_description + safety_settings=STANDARD_SAFETY_SETTINGS, # Use standard safety + request_desc=task_description, ) if not response_obj or not response_obj.candidates: @@ -1094,24 +1463,37 @@ async def get_internal_ai_json_response( print(f"Parsing ({task_description}): Using response_obj.text for JSON.") final_parsed_data = parse_and_validate_json_response( - final_response_text, response_schema_dict, f"internal task ({task_description})" + final_response_text, + response_schema_dict, + f"internal task ({task_description})", ) if final_parsed_data is None: - print(f"Warning: Internal task '{task_description}' failed JSON validation.") + print( + f"Warning: Internal task '{task_description}' failed JSON validation." + ) # No re-prompting for internal tasks, just return None except Exception as e: - print(f"Error in get_internal_ai_json_response ({task_description}): {type(e).__name__}: {e}") + print( + f"Error in get_internal_ai_json_response ({task_description}): {type(e).__name__}: {e}" + ) error_occurred = e import traceback + traceback.print_exc() final_parsed_data = None finally: # Log the call try: # Pass the simplified payload for logging - await log_internal_api_call(cog, task_description, request_payload_for_logging, final_parsed_data, error_occurred) + await log_internal_api_call( + cog, + task_description, + request_payload_for_logging, + final_parsed_data, + error_occurred, + ) except Exception as log_e: print(f"Error logging internal API call: {log_e}") diff --git a/wheatley/background.py b/wheatley/background.py index 51e82fd..da0b215 100644 --- a/wheatley/background.py +++ b/wheatley/background.py @@ -7,64 +7,94 @@ import aiohttp from typing import TYPE_CHECKING # Relative imports -from .config import ( - STATS_PUSH_INTERVAL # Only keep stats interval -) +from .config import STATS_PUSH_INTERVAL # Only keep stats interval + # Removed analysis imports if TYPE_CHECKING: - from .cog import WheatleyCog # Updated type hint + from .cog import WheatleyCog # Updated type hint # --- Background Task --- -async def background_processing_task(cog: 'WheatleyCog'): # Updated type hint - """Background task that periodically pushes stats.""" # Simplified docstring + +async def background_processing_task(cog: "WheatleyCog"): # Updated type hint + """Background task that periodically pushes stats.""" # Simplified docstring # Get API details from environment for stats pushing api_internal_url = os.getenv("API_INTERNAL_URL") # Use a generic secret name or a Wheatley-specific one if desired - stats_push_secret = os.getenv("WHEATLEY_STATS_PUSH_SECRET", os.getenv("GURT_STATS_PUSH_SECRET")) # Fallback to GURT secret if needed + stats_push_secret = os.getenv( + "WHEATLEY_STATS_PUSH_SECRET", os.getenv("GURT_STATS_PUSH_SECRET") + ) # Fallback to GURT secret if needed if not api_internal_url: - print("WARNING: API_INTERNAL_URL not set. Wheatley stats will not be pushed.") # Updated text + print( + "WARNING: API_INTERNAL_URL not set. Wheatley stats will not be pushed." + ) # Updated text if not stats_push_secret: - print("WARNING: WHEATLEY_STATS_PUSH_SECRET (or GURT_STATS_PUSH_SECRET) not set. Stats push endpoint is insecure and likely won't work.") # Updated text + print( + "WARNING: WHEATLEY_STATS_PUSH_SECRET (or GURT_STATS_PUSH_SECRET) not set. Stats push endpoint is insecure and likely won't work." + ) # Updated text try: while True: - await asyncio.sleep(STATS_PUSH_INTERVAL) # Use the stats interval directly + await asyncio.sleep(STATS_PUSH_INTERVAL) # Use the stats interval directly now = time.time() # --- Push Stats --- - if api_internal_url and stats_push_secret: # Removed check for last push time, rely on sleep interval - print("Pushing Wheatley stats to API server...") # Updated text + if ( + api_internal_url and stats_push_secret + ): # Removed check for last push time, rely on sleep interval + print("Pushing Wheatley stats to API server...") # Updated text try: - stats_data = await cog.get_wheatley_stats() # Updated method call + stats_data = await cog.get_wheatley_stats() # Updated method call headers = { "Authorization": f"Bearer {stats_push_secret}", - "Content-Type": "application/json" + "Content-Type": "application/json", } # Use the cog's session, ensure it's created if cog.session: # Set a reasonable timeout for the stats push - push_timeout = aiohttp.ClientTimeout(total=10) # 10 seconds total timeout - async with cog.session.post(api_internal_url, json=stats_data, headers=headers, timeout=push_timeout, ssl=True) as response: # Explicitly enable SSL verification + push_timeout = aiohttp.ClientTimeout( + total=10 + ) # 10 seconds total timeout + async with cog.session.post( + api_internal_url, + json=stats_data, + headers=headers, + timeout=push_timeout, + ssl=True, + ) as response: # Explicitly enable SSL verification if response.status == 200: - print(f"Successfully pushed Wheatley stats (Status: {response.status})") # Updated text + print( + f"Successfully pushed Wheatley stats (Status: {response.status})" + ) # Updated text else: error_text = await response.text() - print(f"Failed to push Wheatley stats (Status: {response.status}): {error_text[:200]}") # Updated text, Log only first 200 chars + print( + f"Failed to push Wheatley stats (Status: {response.status}): {error_text[:200]}" + ) # Updated text, Log only first 200 chars else: - print("Error pushing stats: WheatleyCog session not initialized.") # Updated text + print( + "Error pushing stats: WheatleyCog session not initialized." + ) # Updated text # Removed updating cog.last_stats_push as we rely on sleep interval except aiohttp.ClientConnectorSSLError as ssl_err: - print(f"SSL Error pushing Wheatley stats: {ssl_err}. Ensure the API server's certificate is valid and trusted, or check network configuration.") # Updated text - print("If using a self-signed certificate for development, the bot process might need to trust it.") + print( + f"SSL Error pushing Wheatley stats: {ssl_err}. Ensure the API server's certificate is valid and trusted, or check network configuration." + ) # Updated text + print( + "If using a self-signed certificate for development, the bot process might need to trust it." + ) except aiohttp.ClientError as client_err: - print(f"HTTP Client Error pushing Wheatley stats: {client_err}") # Updated text + print( + f"HTTP Client Error pushing Wheatley stats: {client_err}" + ) # Updated text except asyncio.TimeoutError: - print("Timeout error pushing Wheatley stats.") # Updated text + print("Timeout error pushing Wheatley stats.") # Updated text except Exception as e: - print(f"Unexpected error pushing Wheatley stats: {e}") # Updated text + print( + f"Unexpected error pushing Wheatley stats: {e}" + ) # Updated text traceback.print_exc() # --- Removed Learning Analysis --- @@ -76,11 +106,12 @@ async def background_processing_task(cog: 'WheatleyCog'): # Updated type hint # --- Removed Automatic Mood Change --- except asyncio.CancelledError: - print("Wheatley background processing task cancelled") # Updated text + print("Wheatley background processing task cancelled") # Updated text except Exception as e: - print(f"Error in Wheatley background processing task: {e}") # Updated text + print(f"Error in Wheatley background processing task: {e}") # Updated text traceback.print_exc() - await asyncio.sleep(300) # Wait 5 minutes before retrying after an error + await asyncio.sleep(300) # Wait 5 minutes before retrying after an error + # --- Removed Automatic Mood Change Logic --- # --- Removed Interest Update Logic --- diff --git a/wheatley/cog.py b/wheatley/cog.py index 400c458..6093b21 100644 --- a/wheatley/cog.py +++ b/wheatley/cog.py @@ -11,44 +11,76 @@ from typing import Dict, List, Any, Optional, Tuple, Set, Union # Third-party imports needed by the Cog itself or its direct methods from dotenv import load_dotenv -from tavily import TavilyClient # Needed for tavily_client init +from tavily import TavilyClient # Needed for tavily_client init # --- Relative Imports from Wheatley Package --- from .config import ( - PROJECT_ID, LOCATION, TAVILY_API_KEY, DEFAULT_MODEL, FALLBACK_MODEL, # Use GCP config - DB_PATH, CHROMA_PATH, SEMANTIC_MODEL_NAME, MAX_USER_FACTS, MAX_GENERAL_FACTS, + PROJECT_ID, + LOCATION, + TAVILY_API_KEY, + DEFAULT_MODEL, + FALLBACK_MODEL, # Use GCP config + DB_PATH, + CHROMA_PATH, + SEMANTIC_MODEL_NAME, + MAX_USER_FACTS, + MAX_GENERAL_FACTS, # Removed Mood/Personality/Interest/Learning/Goal configs - CHANNEL_TOPIC_CACHE_TTL, CONTEXT_WINDOW_SIZE, - API_TIMEOUT, SUMMARY_API_TIMEOUT, API_RETRY_ATTEMPTS, API_RETRY_DELAY, - PROACTIVE_LULL_THRESHOLD, PROACTIVE_BOT_SILENCE_THRESHOLD, PROACTIVE_LULL_CHANCE, - PROACTIVE_TOPIC_RELEVANCE_THRESHOLD, PROACTIVE_TOPIC_CHANCE, + CHANNEL_TOPIC_CACHE_TTL, + CONTEXT_WINDOW_SIZE, + API_TIMEOUT, + SUMMARY_API_TIMEOUT, + API_RETRY_ATTEMPTS, + API_RETRY_DELAY, + PROACTIVE_LULL_THRESHOLD, + PROACTIVE_BOT_SILENCE_THRESHOLD, + PROACTIVE_LULL_CHANCE, + PROACTIVE_TOPIC_RELEVANCE_THRESHOLD, + PROACTIVE_TOPIC_CHANCE, # Removed Relationship/Sentiment/Interest proactive configs - TOPIC_UPDATE_INTERVAL, SENTIMENT_UPDATE_INTERVAL, - RESPONSE_SCHEMA, TOOLS # Import necessary configs + TOPIC_UPDATE_INTERVAL, + SENTIMENT_UPDATE_INTERVAL, + RESPONSE_SCHEMA, + TOOLS, # Import necessary configs ) + # Import functions/classes from other modules -from .memory import MemoryManager # Import from local memory.py -from .background import background_processing_task # Keep background task for potential future use (e.g., cache cleanup) -from .commands import setup_commands # Import the setup helper -from .listeners import on_ready_listener, on_message_listener, on_reaction_add_listener, on_reaction_remove_listener # Import listener functions -from . import config as WheatleyConfig # Import config module for get_wheatley_stats +from .memory import MemoryManager # Import from local memory.py +from .background import ( + background_processing_task, +) # Keep background task for potential future use (e.g., cache cleanup) +from .commands import setup_commands # Import the setup helper +from .listeners import ( + on_ready_listener, + on_message_listener, + on_reaction_add_listener, + on_reaction_remove_listener, +) # Import listener functions +from . import config as WheatleyConfig # Import config module for get_wheatley_stats # Load environment variables (might be loaded globally in main bot script too) load_dotenv() -class WheatleyCog(commands.Cog, name="Wheatley"): # Renamed class and Cog name - """A special cog for the Wheatley bot that uses Google Vertex AI API""" # Updated docstring + +class WheatleyCog(commands.Cog, name="Wheatley"): # Renamed class and Cog name + """A special cog for the Wheatley bot that uses Google Vertex AI API""" # Updated docstring def __init__(self, bot): self.bot = bot # GCP Project/Location are used by vertexai.init() in api.py - self.tavily_api_key = TAVILY_API_KEY # Use imported config - self.session: Optional[aiohttp.ClientSession] = None # Keep for other potential HTTP requests (e.g., Piston) - self.tavily_client = TavilyClient(api_key=self.tavily_api_key) if self.tavily_api_key else None - self.default_model = DEFAULT_MODEL # Use imported config - self.fallback_model = FALLBACK_MODEL # Use imported config + self.tavily_api_key = TAVILY_API_KEY # Use imported config + self.session: Optional[aiohttp.ClientSession] = ( + None # Keep for other potential HTTP requests (e.g., Piston) + ) + self.tavily_client = ( + TavilyClient(api_key=self.tavily_api_key) if self.tavily_api_key else None + ) + self.default_model = DEFAULT_MODEL # Use imported config + self.fallback_model = FALLBACK_MODEL # Use imported config # Removed MOOD_OPTIONS - self.current_channel: Optional[Union[discord.TextChannel, discord.Thread, discord.DMChannel]] = None # Type hint current channel + self.current_channel: Optional[ + Union[discord.TextChannel, discord.Thread, discord.DMChannel] + ] = None # Type hint current channel # Instantiate MemoryManager self.memory_manager = MemoryManager( @@ -56,60 +88,91 @@ class WheatleyCog(commands.Cog, name="Wheatley"): # Renamed class and Cog name max_user_facts=MAX_USER_FACTS, max_general_facts=MAX_GENERAL_FACTS, chroma_path=CHROMA_PATH, - semantic_model_name=SEMANTIC_MODEL_NAME + semantic_model_name=SEMANTIC_MODEL_NAME, ) # --- State Variables (Simplified for Wheatley) --- # Removed mood, personality evolution, interest tracking, learning state - self.needs_json_reminder = False # Flag to remind AI about JSON format + self.needs_json_reminder = False # Flag to remind AI about JSON format # Topic tracking (Kept for context) - self.active_topics = defaultdict(lambda: { - "topics": [], "last_update": time.time(), "topic_history": [], - "user_topic_interests": defaultdict(list) # Kept for potential future analysis, not proactive triggers - }) + self.active_topics = defaultdict( + lambda: { + "topics": [], + "last_update": time.time(), + "topic_history": [], + "user_topic_interests": defaultdict( + list + ), # Kept for potential future analysis, not proactive triggers + } + ) # Conversation tracking / Caches self.conversation_history = defaultdict(lambda: deque(maxlen=100)) self.thread_history = defaultdict(lambda: deque(maxlen=50)) self.user_conversation_mapping = defaultdict(set) - self.channel_activity = defaultdict(lambda: 0.0) # Use float for timestamp - self.conversation_topics = defaultdict(str) # Simplified topic tracking - self.user_relationships = defaultdict(dict) # Kept for potential context/analysis - self.conversation_summaries: Dict[int, Dict[str, Any]] = {} # Store dict with summary and timestamp - self.channel_topics_cache: Dict[int, Dict[str, Any]] = {} # Store dict with topic and timestamp + self.channel_activity = defaultdict(lambda: 0.0) # Use float for timestamp + self.conversation_topics = defaultdict(str) # Simplified topic tracking + self.user_relationships = defaultdict( + dict + ) # Kept for potential context/analysis + self.conversation_summaries: Dict[int, Dict[str, Any]] = ( + {} + ) # Store dict with summary and timestamp + self.channel_topics_cache: Dict[int, Dict[str, Any]] = ( + {} + ) # Store dict with topic and timestamp self.message_cache = { - 'by_channel': defaultdict(lambda: deque(maxlen=CONTEXT_WINDOW_SIZE)), # Use config - 'by_user': defaultdict(lambda: deque(maxlen=50)), - 'by_thread': defaultdict(lambda: deque(maxlen=50)), - 'global_recent': deque(maxlen=200), - 'mentioned': deque(maxlen=50), - 'replied_to': defaultdict(lambda: deque(maxlen=20)) + "by_channel": defaultdict( + lambda: deque(maxlen=CONTEXT_WINDOW_SIZE) + ), # Use config + "by_user": defaultdict(lambda: deque(maxlen=50)), + "by_thread": defaultdict(lambda: deque(maxlen=50)), + "global_recent": deque(maxlen=200), + "mentioned": deque(maxlen=50), + "replied_to": defaultdict(lambda: deque(maxlen=20)), } - self.active_conversations = {} # Kept for basic tracking + self.active_conversations = {} # Kept for basic tracking self.bot_last_spoke = defaultdict(float) self.message_reply_map = {} # Enhanced sentiment tracking (Kept for context/analysis) - self.conversation_sentiment = defaultdict(lambda: { - "overall": "neutral", "intensity": 0.5, "recent_trend": "stable", - "user_sentiments": {}, "last_update": time.time() - }) + self.conversation_sentiment = defaultdict( + lambda: { + "overall": "neutral", + "intensity": 0.5, + "recent_trend": "stable", + "user_sentiments": {}, + "last_update": time.time(), + } + ) # Removed self.sentiment_update_interval as it was only used in analysis # Reaction Tracking (Renamed) - self.wheatley_message_reactions = defaultdict(lambda: {"positive": 0, "negative": 0, "topic": None, "timestamp": 0.0}) # Renamed + self.wheatley_message_reactions = defaultdict( + lambda: {"positive": 0, "negative": 0, "topic": None, "timestamp": 0.0} + ) # Renamed # Background task handle (Kept for potential future tasks like cache cleanup) self.background_task: Optional[asyncio.Task] = None - self.last_stats_push = time.time() # Timestamp for last stats push + self.last_stats_push = time.time() # Timestamp for last stats push # Removed evolution, reflection, goal timestamps # --- Stats Tracking --- - self.api_stats = defaultdict(lambda: {"success": 0, "failure": 0, "retries": 0, "total_time": 0.0, "count": 0}) # Keyed by model name - self.tool_stats = defaultdict(lambda: {"success": 0, "failure": 0, "total_time": 0.0, "count": 0}) # Keyed by tool name + self.api_stats = defaultdict( + lambda: { + "success": 0, + "failure": 0, + "retries": 0, + "total_time": 0.0, + "count": 0, + } + ) # Keyed by model name + self.tool_stats = defaultdict( + lambda: {"success": 0, "failure": 0, "total_time": 0.0, "count": 0} + ) # Keyed by tool name # --- Setup Commands and Listeners --- # Add commands defined in commands.py @@ -127,29 +190,37 @@ class WheatleyCog(commands.Cog, name="Wheatley"): # Renamed class and Cog name else: self.registered_commands.append(str(func)) - print(f"WheatleyCog initialized with commands: {self.registered_commands}") # Updated print + print( + f"WheatleyCog initialized with commands: {self.registered_commands}" + ) # Updated print async def cog_load(self): """Create aiohttp session, initialize DB, start background task""" self.session = aiohttp.ClientSession() - print("WheatleyCog: aiohttp session created") # Updated print + print("WheatleyCog: aiohttp session created") # Updated print # Initialize DB via MemoryManager await self.memory_manager.initialize_sqlite_database() # Removed loading of baseline personality and interests # Vertex AI initialization happens in api.py using PROJECT_ID and LOCATION from config - print(f"WheatleyCog: Using default model: {self.default_model}") # Updated print + print( + f"WheatleyCog: Using default model: {self.default_model}" + ) # Updated print if not self.tavily_api_key: - print("WARNING: Tavily API key not configured (TAVILY_API_KEY). Web search disabled.") + print( + "WARNING: Tavily API key not configured (TAVILY_API_KEY). Web search disabled." + ) # Add listeners to the bot instance # IMPORTANT: Don't override on_member_join or on_member_remove events # Check if the bot already has event listeners for member join/leave - has_member_join = 'on_member_join' in self.bot.extra_events - has_member_remove = 'on_member_remove' in self.bot.extra_events - print(f"WheatleyCog: Bot already has event listeners - on_member_join: {has_member_join}, on_member_remove: {has_member_remove}") + has_member_join = "on_member_join" in self.bot.extra_events + has_member_remove = "on_member_remove" in self.bot.extra_events + print( + f"WheatleyCog: Bot already has event listeners - on_member_join: {has_member_join}, on_member_remove: {has_member_remove}" + ) @self.bot.event async def on_ready(): @@ -159,7 +230,7 @@ class WheatleyCog(commands.Cog, name="Wheatley"): # Renamed class and Cog name async def on_message(message): # Ensure commands are processed if using command prefix if message.content.startswith(self.bot.command_prefix): - await self.bot.process_commands(message) + await self.bot.process_commands(message) # Always run the message listener for potential AI responses/tracking await on_message_listener(self, message) @@ -171,45 +242,59 @@ class WheatleyCog(commands.Cog, name="Wheatley"): # Renamed class and Cog name async def on_reaction_remove(reaction, user): await on_reaction_remove_listener(self, reaction, user) - print("WheatleyCog: Listeners added.") # Updated print + print("WheatleyCog: Listeners added.") # Updated print # Commands will be synced in on_ready - print("WheatleyCog: Commands will be synced when the bot is ready.") # Updated print + print( + "WheatleyCog: Commands will be synced when the bot is ready." + ) # Updated print # Start background task (kept for potential future use) if self.background_task is None or self.background_task.done(): self.background_task = asyncio.create_task(background_processing_task(self)) - print("WheatleyCog: Started background processing task.") # Updated print + print("WheatleyCog: Started background processing task.") # Updated print else: - print("WheatleyCog: Background processing task already running.") # Updated print + print( + "WheatleyCog: Background processing task already running." + ) # Updated print async def cog_unload(self): """Close session and cancel background task""" if self.session and not self.session.closed: await self.session.close() - print("WheatleyCog: aiohttp session closed") # Updated print + print("WheatleyCog: aiohttp session closed") # Updated print if self.background_task and not self.background_task.done(): self.background_task.cancel() - print("WheatleyCog: Cancelled background processing task.") # Updated print - print("WheatleyCog: Listeners will be removed when bot is closed.") # Updated print + print("WheatleyCog: Cancelled background processing task.") # Updated print + print( + "WheatleyCog: Listeners will be removed when bot is closed." + ) # Updated print - print("WheatleyCog unloaded.") # Updated print + print("WheatleyCog unloaded.") # Updated print # --- Helper methods that might remain in the cog --- # _update_relationship kept for potential context/analysis use def _update_relationship(self, user_id_1: str, user_id_2: str, change: float): """Updates the relationship score between two users.""" - if user_id_1 > user_id_2: user_id_1, user_id_2 = user_id_2, user_id_1 - if user_id_1 not in self.user_relationships: self.user_relationships[user_id_1] = {} + if user_id_1 > user_id_2: + user_id_1, user_id_2 = user_id_2, user_id_1 + if user_id_1 not in self.user_relationships: + self.user_relationships[user_id_1] = {} current_score = self.user_relationships[user_id_1].get(user_id_2, 0.0) - new_score = max(0.0, min(current_score + change, 100.0)) # Clamp 0-100 + new_score = max(0.0, min(current_score + change, 100.0)) # Clamp 0-100 self.user_relationships[user_id_1][user_id_2] = new_score # print(f"Updated relationship {user_id_1}-{user_id_2}: {current_score:.1f} -> {new_score:.1f} ({change:+.1f})") # Debug log - async def get_wheatley_stats(self) -> Dict[str, Any]: # Renamed method - """Collects various internal stats for Wheatley.""" # Updated docstring - stats = {"config": {}, "runtime": {}, "memory": {}, "api_stats": {}, "tool_stats": {}} + async def get_wheatley_stats(self) -> Dict[str, Any]: # Renamed method + """Collects various internal stats for Wheatley.""" # Updated docstring + stats = { + "config": {}, + "runtime": {}, + "memory": {}, + "api_stats": {}, + "tool_stats": {}, + } # --- Config (Simplified) --- stats["config"]["default_model"] = WheatleyConfig.DEFAULT_MODEL @@ -223,12 +308,22 @@ class WheatleyCog(commands.Cog, name="Wheatley"): # Renamed class and Cog name stats["config"]["context_window_size"] = WheatleyConfig.CONTEXT_WINDOW_SIZE stats["config"]["api_timeout"] = WheatleyConfig.API_TIMEOUT stats["config"]["summary_api_timeout"] = WheatleyConfig.SUMMARY_API_TIMEOUT - stats["config"]["proactive_lull_threshold"] = WheatleyConfig.PROACTIVE_LULL_THRESHOLD - stats["config"]["proactive_bot_silence_threshold"] = WheatleyConfig.PROACTIVE_BOT_SILENCE_THRESHOLD + stats["config"][ + "proactive_lull_threshold" + ] = WheatleyConfig.PROACTIVE_LULL_THRESHOLD + stats["config"][ + "proactive_bot_silence_threshold" + ] = WheatleyConfig.PROACTIVE_BOT_SILENCE_THRESHOLD stats["config"]["topic_update_interval"] = WheatleyConfig.TOPIC_UPDATE_INTERVAL - stats["config"]["sentiment_update_interval"] = WheatleyConfig.SENTIMENT_UPDATE_INTERVAL - stats["config"]["docker_command_timeout"] = WheatleyConfig.DOCKER_COMMAND_TIMEOUT - stats["config"]["project_id_set"] = bool(WheatleyConfig.PROJECT_ID != "your-gcp-project-id") + stats["config"][ + "sentiment_update_interval" + ] = WheatleyConfig.SENTIMENT_UPDATE_INTERVAL + stats["config"][ + "docker_command_timeout" + ] = WheatleyConfig.DOCKER_COMMAND_TIMEOUT + stats["config"]["project_id_set"] = bool( + WheatleyConfig.PROJECT_ID != "your-gcp-project-id" + ) stats["config"]["location_set"] = bool(WheatleyConfig.LOCATION != "us-central1") stats["config"]["tavily_api_key_set"] = bool(WheatleyConfig.TAVILY_API_KEY) stats["config"]["piston_api_url_set"] = bool(WheatleyConfig.PISTON_API_URL) @@ -236,36 +331,72 @@ class WheatleyCog(commands.Cog, name="Wheatley"): # Renamed class and Cog name # --- Runtime (Simplified) --- # Removed mood, evolution stats["runtime"]["needs_json_reminder"] = self.needs_json_reminder - stats["runtime"]["background_task_running"] = bool(self.background_task and not self.background_task.done()) + stats["runtime"]["background_task_running"] = bool( + self.background_task and not self.background_task.done() + ) stats["runtime"]["active_topics_channels"] = len(self.active_topics) - stats["runtime"]["conversation_history_channels"] = len(self.conversation_history) + stats["runtime"]["conversation_history_channels"] = len( + self.conversation_history + ) stats["runtime"]["thread_history_threads"] = len(self.thread_history) - stats["runtime"]["user_conversation_mappings"] = len(self.user_conversation_mapping) + stats["runtime"]["user_conversation_mappings"] = len( + self.user_conversation_mapping + ) stats["runtime"]["channel_activity_tracked"] = len(self.channel_activity) - stats["runtime"]["conversation_topics_tracked"] = len(self.conversation_topics) # Simplified topic tracking - stats["runtime"]["user_relationships_pairs"] = sum(len(v) for v in self.user_relationships.values()) - stats["runtime"]["conversation_summaries_cached"] = len(self.conversation_summaries) + stats["runtime"]["conversation_topics_tracked"] = len( + self.conversation_topics + ) # Simplified topic tracking + stats["runtime"]["user_relationships_pairs"] = sum( + len(v) for v in self.user_relationships.values() + ) + stats["runtime"]["conversation_summaries_cached"] = len( + self.conversation_summaries + ) stats["runtime"]["channel_topics_cached"] = len(self.channel_topics_cache) - stats["runtime"]["message_cache_global_count"] = len(self.message_cache['global_recent']) - stats["runtime"]["message_cache_mentioned_count"] = len(self.message_cache['mentioned']) + stats["runtime"]["message_cache_global_count"] = len( + self.message_cache["global_recent"] + ) + stats["runtime"]["message_cache_mentioned_count"] = len( + self.message_cache["mentioned"] + ) stats["runtime"]["active_conversations_count"] = len(self.active_conversations) stats["runtime"]["bot_last_spoke_channels"] = len(self.bot_last_spoke) stats["runtime"]["message_reply_map_size"] = len(self.message_reply_map) - stats["runtime"]["conversation_sentiment_channels"] = len(self.conversation_sentiment) + stats["runtime"]["conversation_sentiment_channels"] = len( + self.conversation_sentiment + ) # Removed Gurt participation topics - stats["runtime"]["wheatley_message_reactions_tracked"] = len(self.wheatley_message_reactions) # Renamed + stats["runtime"]["wheatley_message_reactions_tracked"] = len( + self.wheatley_message_reactions + ) # Renamed # --- Memory (Simplified) --- try: # Removed Personality, Interests - user_fact_count = await self.memory_manager._db_fetchone("SELECT COUNT(*) FROM user_facts") - general_fact_count = await self.memory_manager._db_fetchone("SELECT COUNT(*) FROM general_facts") - stats["memory"]["user_facts_count"] = user_fact_count[0] if user_fact_count else 0 - stats["memory"]["general_facts_count"] = general_fact_count[0] if general_fact_count else 0 + user_fact_count = await self.memory_manager._db_fetchone( + "SELECT COUNT(*) FROM user_facts" + ) + general_fact_count = await self.memory_manager._db_fetchone( + "SELECT COUNT(*) FROM general_facts" + ) + stats["memory"]["user_facts_count"] = ( + user_fact_count[0] if user_fact_count else 0 + ) + stats["memory"]["general_facts_count"] = ( + general_fact_count[0] if general_fact_count else 0 + ) # ChromaDB Stats - stats["memory"]["chromadb_message_collection_count"] = await asyncio.to_thread(self.memory_manager.semantic_collection.count) if self.memory_manager.semantic_collection else "N/A" - stats["memory"]["chromadb_fact_collection_count"] = await asyncio.to_thread(self.memory_manager.fact_collection.count) if self.memory_manager.fact_collection else "N/A" + stats["memory"]["chromadb_message_collection_count"] = ( + await asyncio.to_thread(self.memory_manager.semantic_collection.count) + if self.memory_manager.semantic_collection + else "N/A" + ) + stats["memory"]["chromadb_fact_collection_count"] = ( + await asyncio.to_thread(self.memory_manager.fact_collection.count) + if self.memory_manager.fact_collection + else "N/A" + ) except Exception as e: stats["memory"]["error"] = f"Failed to retrieve memory stats: {e}" @@ -276,34 +407,52 @@ class WheatleyCog(commands.Cog, name="Wheatley"): # Renamed class and Cog name # Calculate average times for model, data in stats["api_stats"].items(): - if data["count"] > 0: data["average_time_ms"] = round((data["total_time"] / data["count"]) * 1000, 2) - else: data["average_time_ms"] = 0 + if data["count"] > 0: + data["average_time_ms"] = round( + (data["total_time"] / data["count"]) * 1000, 2 + ) + else: + data["average_time_ms"] = 0 for tool, data in stats["tool_stats"].items(): - if data["count"] > 0: data["average_time_ms"] = round((data["total_time"] / data["count"]) * 1000, 2) - else: data["average_time_ms"] = 0 + if data["count"] > 0: + data["average_time_ms"] = round( + (data["total_time"] / data["count"]) * 1000, 2 + ) + else: + data["average_time_ms"] = 0 return stats async def sync_commands(self): """Manually sync commands with Discord.""" try: - print("WheatleyCog: Manually syncing commands with Discord...") # Updated print + print( + "WheatleyCog: Manually syncing commands with Discord..." + ) # Updated print synced = await self.bot.tree.sync() - print(f"WheatleyCog: Synced {len(synced)} command(s)") # Updated print + print(f"WheatleyCog: Synced {len(synced)} command(s)") # Updated print # List the synced commands - wheatley_commands = [cmd.name for cmd in self.bot.tree.get_commands() if cmd.name.startswith("wheatley")] # Updated prefix - print(f"WheatleyCog: Available Wheatley commands: {', '.join(wheatley_commands)}") # Updated print + wheatley_commands = [ + cmd.name + for cmd in self.bot.tree.get_commands() + if cmd.name.startswith("wheatley") + ] # Updated prefix + print( + f"WheatleyCog: Available Wheatley commands: {', '.join(wheatley_commands)}" + ) # Updated print return synced, wheatley_commands except Exception as e: - print(f"WheatleyCog: Failed to sync commands: {e}") # Updated print + print(f"WheatleyCog: Failed to sync commands: {e}") # Updated print import traceback + traceback.print_exc() return [], [] + # Setup function for loading the cog async def setup(bot): - """Add the WheatleyCog to the bot.""" # Updated docstring - await bot.add_cog(WheatleyCog(bot)) # Use renamed class - print("WheatleyCog setup complete.") # Updated print + """Add the WheatleyCog to the bot.""" # Updated docstring + await bot.add_cog(WheatleyCog(bot)) # Use renamed class + print("WheatleyCog setup complete.") # Updated print diff --git a/wheatley/commands.py b/wheatley/commands.py index 698573d..280a2b4 100644 --- a/wheatley/commands.py +++ b/wheatley/commands.py @@ -1,115 +1,235 @@ import discord -from discord import app_commands # Import app_commands +from discord import app_commands # Import app_commands from discord.ext import commands import random import os -import time # Import time for timestamps -import json # Import json for formatting -import datetime # Import datetime for formatting -from typing import TYPE_CHECKING, Optional, Dict, Any, List, Tuple # Add more types +import time # Import time for timestamps +import json # Import json for formatting +import datetime # Import datetime for formatting +from typing import TYPE_CHECKING, Optional, Dict, Any, List, Tuple # Add more types # Relative imports # We need access to the cog instance for state and methods if TYPE_CHECKING: - from .cog import WheatleyCog # For type hinting + from .cog import WheatleyCog # For type hinting + # MOOD_OPTIONS removed + # --- Helper Function for Embeds --- -def create_wheatley_embed(title: str, description: str = "", color=discord.Color.blue()) -> discord.Embed: # Renamed function - """Creates a standard Wheatley-themed embed.""" # Updated docstring +def create_wheatley_embed( + title: str, description: str = "", color=discord.Color.blue() +) -> discord.Embed: # Renamed function + """Creates a standard Wheatley-themed embed.""" # Updated docstring embed = discord.Embed(title=title, description=description, color=color) # Placeholder icon URL, replace if Wheatley has one # embed.set_footer(text="Wheatley", icon_url="https://example.com/wheatley_icon.png") # Updated text - embed.set_footer(text="Wheatley") # Updated text + embed.set_footer(text="Wheatley") # Updated text return embed + # --- Helper Function for Stats Embeds --- def format_stats_embeds(stats: Dict[str, Any]) -> List[discord.Embed]: """Formats the collected stats into multiple embeds.""" embeds = [] - main_embed = create_wheatley_embed("Wheatley Internal Stats", color=discord.Color.green()) # Use new helper, updated title - ts_format = "" # Relative timestamp + main_embed = create_wheatley_embed( + "Wheatley Internal Stats", color=discord.Color.green() + ) # Use new helper, updated title + ts_format = "" # Relative timestamp # Runtime Stats (Simplified for Wheatley) runtime = stats.get("runtime", {}) - main_embed.add_field(name="Background Task", value="Running" if runtime.get('background_task_running') else "Stopped", inline=True) - main_embed.add_field(name="Needs JSON Reminder", value=str(runtime.get('needs_json_reminder', 'N/A')), inline=True) + main_embed.add_field( + name="Background Task", + value="Running" if runtime.get("background_task_running") else "Stopped", + inline=True, + ) + main_embed.add_field( + name="Needs JSON Reminder", + value=str(runtime.get("needs_json_reminder", "N/A")), + inline=True, + ) # Removed Mood, Evolution - main_embed.add_field(name="Active Topics Channels", value=str(runtime.get('active_topics_channels', 'N/A')), inline=True) - main_embed.add_field(name="Conv History Channels", value=str(runtime.get('conversation_history_channels', 'N/A')), inline=True) - main_embed.add_field(name="Thread History Threads", value=str(runtime.get('thread_history_threads', 'N/A')), inline=True) - main_embed.add_field(name="User Relationships Pairs", value=str(runtime.get('user_relationships_pairs', 'N/A')), inline=True) - main_embed.add_field(name="Cached Summaries", value=str(runtime.get('conversation_summaries_cached', 'N/A')), inline=True) - main_embed.add_field(name="Cached Channel Topics", value=str(runtime.get('channel_topics_cached', 'N/A')), inline=True) - main_embed.add_field(name="Global Msg Cache", value=str(runtime.get('message_cache_global_count', 'N/A')), inline=True) - main_embed.add_field(name="Mention Msg Cache", value=str(runtime.get('message_cache_mentioned_count', 'N/A')), inline=True) - main_embed.add_field(name="Active Convos", value=str(runtime.get('active_conversations_count', 'N/A')), inline=True) - main_embed.add_field(name="Sentiment Channels", value=str(runtime.get('conversation_sentiment_channels', 'N/A')), inline=True) + main_embed.add_field( + name="Active Topics Channels", + value=str(runtime.get("active_topics_channels", "N/A")), + inline=True, + ) + main_embed.add_field( + name="Conv History Channels", + value=str(runtime.get("conversation_history_channels", "N/A")), + inline=True, + ) + main_embed.add_field( + name="Thread History Threads", + value=str(runtime.get("thread_history_threads", "N/A")), + inline=True, + ) + main_embed.add_field( + name="User Relationships Pairs", + value=str(runtime.get("user_relationships_pairs", "N/A")), + inline=True, + ) + main_embed.add_field( + name="Cached Summaries", + value=str(runtime.get("conversation_summaries_cached", "N/A")), + inline=True, + ) + main_embed.add_field( + name="Cached Channel Topics", + value=str(runtime.get("channel_topics_cached", "N/A")), + inline=True, + ) + main_embed.add_field( + name="Global Msg Cache", + value=str(runtime.get("message_cache_global_count", "N/A")), + inline=True, + ) + main_embed.add_field( + name="Mention Msg Cache", + value=str(runtime.get("message_cache_mentioned_count", "N/A")), + inline=True, + ) + main_embed.add_field( + name="Active Convos", + value=str(runtime.get("active_conversations_count", "N/A")), + inline=True, + ) + main_embed.add_field( + name="Sentiment Channels", + value=str(runtime.get("conversation_sentiment_channels", "N/A")), + inline=True, + ) # Removed Gurt Participation Topics - main_embed.add_field(name="Tracked Reactions", value=str(runtime.get('wheatley_message_reactions_tracked', 'N/A')), inline=True) # Renamed stat key + main_embed.add_field( + name="Tracked Reactions", + value=str(runtime.get("wheatley_message_reactions_tracked", "N/A")), + inline=True, + ) # Renamed stat key embeds.append(main_embed) # Memory Stats (Simplified) - memory_embed = create_wheatley_embed("Wheatley Memory Stats", color=discord.Color.orange()) # Use new helper, updated title + memory_embed = create_wheatley_embed( + "Wheatley Memory Stats", color=discord.Color.orange() + ) # Use new helper, updated title memory = stats.get("memory", {}) if memory.get("error"): memory_embed.description = f"⚠️ Error retrieving memory stats: {memory['error']}" else: - memory_embed.add_field(name="User Facts", value=str(memory.get('user_facts_count', 'N/A')), inline=True) - memory_embed.add_field(name="General Facts", value=str(memory.get('general_facts_count', 'N/A')), inline=True) - memory_embed.add_field(name="Chroma Messages", value=str(memory.get('chromadb_message_collection_count', 'N/A')), inline=True) - memory_embed.add_field(name="Chroma Facts", value=str(memory.get('chromadb_fact_collection_count', 'N/A')), inline=True) + memory_embed.add_field( + name="User Facts", + value=str(memory.get("user_facts_count", "N/A")), + inline=True, + ) + memory_embed.add_field( + name="General Facts", + value=str(memory.get("general_facts_count", "N/A")), + inline=True, + ) + memory_embed.add_field( + name="Chroma Messages", + value=str(memory.get("chromadb_message_collection_count", "N/A")), + inline=True, + ) + memory_embed.add_field( + name="Chroma Facts", + value=str(memory.get("chromadb_fact_collection_count", "N/A")), + inline=True, + ) # Removed Personality Traits, Interests embeds.append(memory_embed) # API Stats api_stats = stats.get("api_stats", {}) if api_stats: - api_embed = create_wheatley_embed("Wheatley API Stats", color=discord.Color.red()) # Use new helper, updated title + api_embed = create_wheatley_embed( + "Wheatley API Stats", color=discord.Color.red() + ) # Use new helper, updated title for model, data in api_stats.items(): - avg_time = data.get('average_time_ms', 0) - value = (f"✅ Success: {data.get('success', 0)}\n" - f"❌ Failure: {data.get('failure', 0)}\n" - f"🔁 Retries: {data.get('retries', 0)}\n" - f"⏱️ Avg Time: {avg_time} ms\n" - f"📊 Count: {data.get('count', 0)}") + avg_time = data.get("average_time_ms", 0) + value = ( + f"✅ Success: {data.get('success', 0)}\n" + f"❌ Failure: {data.get('failure', 0)}\n" + f"🔁 Retries: {data.get('retries', 0)}\n" + f"⏱️ Avg Time: {avg_time} ms\n" + f"📊 Count: {data.get('count', 0)}" + ) api_embed.add_field(name=f"Model: `{model}`", value=value, inline=True) embeds.append(api_embed) # Tool Stats tool_stats = stats.get("tool_stats", {}) if tool_stats: - tool_embed = create_wheatley_embed("Wheatley Tool Stats", color=discord.Color.purple()) # Use new helper, updated title + tool_embed = create_wheatley_embed( + "Wheatley Tool Stats", color=discord.Color.purple() + ) # Use new helper, updated title for tool, data in tool_stats.items(): - avg_time = data.get('average_time_ms', 0) - value = (f"✅ Success: {data.get('success', 0)}\n" - f"❌ Failure: {data.get('failure', 0)}\n" - f"⏱️ Avg Time: {avg_time} ms\n" - f"📊 Count: {data.get('count', 0)}") + avg_time = data.get("average_time_ms", 0) + value = ( + f"✅ Success: {data.get('success', 0)}\n" + f"❌ Failure: {data.get('failure', 0)}\n" + f"⏱️ Avg Time: {avg_time} ms\n" + f"📊 Count: {data.get('count', 0)}" + ) tool_embed.add_field(name=f"Tool: `{tool}`", value=value, inline=True) embeds.append(tool_embed) # Config Stats (Simplified) - config_embed = create_wheatley_embed("Wheatley Config Overview", color=discord.Color.greyple()) # Use new helper, updated title + config_embed = create_wheatley_embed( + "Wheatley Config Overview", color=discord.Color.greyple() + ) # Use new helper, updated title config = stats.get("config", {}) - config_embed.add_field(name="Default Model", value=f"`{config.get('default_model', 'N/A')}`", inline=True) - config_embed.add_field(name="Fallback Model", value=f"`{config.get('fallback_model', 'N/A')}`", inline=True) - config_embed.add_field(name="Semantic Model", value=f"`{config.get('semantic_model_name', 'N/A')}`", inline=True) - config_embed.add_field(name="Max User Facts", value=str(config.get('max_user_facts', 'N/A')), inline=True) - config_embed.add_field(name="Max General Facts", value=str(config.get('max_general_facts', 'N/A')), inline=True) - config_embed.add_field(name="Context Window", value=str(config.get('context_window_size', 'N/A')), inline=True) - config_embed.add_field(name="Tavily Key Set", value=str(config.get('tavily_api_key_set', 'N/A')), inline=True) - config_embed.add_field(name="Piston URL Set", value=str(config.get('piston_api_url_set', 'N/A')), inline=True) + config_embed.add_field( + name="Default Model", + value=f"`{config.get('default_model', 'N/A')}`", + inline=True, + ) + config_embed.add_field( + name="Fallback Model", + value=f"`{config.get('fallback_model', 'N/A')}`", + inline=True, + ) + config_embed.add_field( + name="Semantic Model", + value=f"`{config.get('semantic_model_name', 'N/A')}`", + inline=True, + ) + config_embed.add_field( + name="Max User Facts", + value=str(config.get("max_user_facts", "N/A")), + inline=True, + ) + config_embed.add_field( + name="Max General Facts", + value=str(config.get("max_general_facts", "N/A")), + inline=True, + ) + config_embed.add_field( + name="Context Window", + value=str(config.get("context_window_size", "N/A")), + inline=True, + ) + config_embed.add_field( + name="Tavily Key Set", + value=str(config.get("tavily_api_key_set", "N/A")), + inline=True, + ) + config_embed.add_field( + name="Piston URL Set", + value=str(config.get("piston_api_url_set", "N/A")), + inline=True, + ) embeds.append(config_embed) # Limit to 10 embeds max for Discord API return embeds[:10] + # --- Command Setup Function --- # This function will be called from WheatleyCog's setup method -def setup_commands(cog: 'WheatleyCog'): # Updated type hint - """Adds Wheatley-specific commands to the cog.""" # Updated docstring +def setup_commands(cog: "WheatleyCog"): # Updated type hint + """Adds Wheatley-specific commands to the cog.""" # Updated docstring # Create a list to store command functions for proper registration command_functions = [] @@ -117,108 +237,169 @@ def setup_commands(cog: 'WheatleyCog'): # Updated type hint # --- Gurt Mood Command --- REMOVED # --- Wheatley Memory Command --- - @cog.bot.tree.command(name="wheatleymemory", description="Interact with Wheatley's memory (what little there is).") # Renamed, updated description + @cog.bot.tree.command( + name="wheatleymemory", + description="Interact with Wheatley's memory (what little there is).", + ) # Renamed, updated description @app_commands.describe( action="Choose an action: add_user, add_general, get_user, get_general", user="The user for user-specific actions (mention or ID).", fact="The fact to add (for add actions).", - query="A keyword to search for (for get_general)." + query="A keyword to search for (for get_general).", ) - @app_commands.choices(action=[ - app_commands.Choice(name="Add User Fact", value="add_user"), - app_commands.Choice(name="Add General Fact", value="add_general"), - app_commands.Choice(name="Get User Facts", value="get_user"), - app_commands.Choice(name="Get General Facts", value="get_general"), - ]) - async def wheatleymemory(interaction: discord.Interaction, action: app_commands.Choice[str], user: Optional[discord.User] = None, fact: Optional[str] = None, query: Optional[str] = None): # Renamed function - """Handles the /wheatleymemory command.""" # Updated docstring - await interaction.response.defer(ephemeral=True) # Defer for potentially slow DB operations + @app_commands.choices( + action=[ + app_commands.Choice(name="Add User Fact", value="add_user"), + app_commands.Choice(name="Add General Fact", value="add_general"), + app_commands.Choice(name="Get User Facts", value="get_user"), + app_commands.Choice(name="Get General Facts", value="get_general"), + ] + ) + async def wheatleymemory( + interaction: discord.Interaction, + action: app_commands.Choice[str], + user: Optional[discord.User] = None, + fact: Optional[str] = None, + query: Optional[str] = None, + ): # Renamed function + """Handles the /wheatleymemory command.""" # Updated docstring + await interaction.response.defer( + ephemeral=True + ) # Defer for potentially slow DB operations target_user_id = str(user.id) if user else None action_value = action.value # Check if user is the bot owner for modification actions - if (action_value in ["add_user", "add_general"]) and interaction.user.id != cog.bot.owner_id: - await interaction.followup.send("⛔ Oi! Only the boss can fiddle with my memory banks!", ephemeral=True) # Updated text + if ( + action_value in ["add_user", "add_general"] + ) and interaction.user.id != cog.bot.owner_id: + await interaction.followup.send( + "⛔ Oi! Only the boss can fiddle with my memory banks!", ephemeral=True + ) # Updated text return if action_value == "add_user": if not target_user_id or not fact: - await interaction.followup.send("Need a user *and* a fact, mate. Can't remember nothing about nobody.", ephemeral=True) # Updated text + await interaction.followup.send( + "Need a user *and* a fact, mate. Can't remember nothing about nobody.", + ephemeral=True, + ) # Updated text return result = await cog.memory_manager.add_user_fact(target_user_id, fact) - await interaction.followup.send(f"Add User Fact Result: `{json.dumps(result)}` (Probably worked? Maybe?)", ephemeral=True) # Updated text + await interaction.followup.send( + f"Add User Fact Result: `{json.dumps(result)}` (Probably worked? Maybe?)", + ephemeral=True, + ) # Updated text elif action_value == "add_general": if not fact: - await interaction.followup.send("What's the fact then? Can't remember thin air!", ephemeral=True) # Updated text + await interaction.followup.send( + "What's the fact then? Can't remember thin air!", ephemeral=True + ) # Updated text return result = await cog.memory_manager.add_general_fact(fact) - await interaction.followup.send(f"Add General Fact Result: `{json.dumps(result)}` (Filed under 'Important Stuff I'll Forget Later')", ephemeral=True) # Updated text + await interaction.followup.send( + f"Add General Fact Result: `{json.dumps(result)}` (Filed under 'Important Stuff I'll Forget Later')", + ephemeral=True, + ) # Updated text elif action_value == "get_user": if not target_user_id: - await interaction.followup.send("Which user? Need an ID, chap!", ephemeral=True) # Updated text + await interaction.followup.send( + "Which user? Need an ID, chap!", ephemeral=True + ) # Updated text return - facts = await cog.memory_manager.get_user_facts(target_user_id) # Get newest by default + facts = await cog.memory_manager.get_user_facts( + target_user_id + ) # Get newest by default if facts: facts_str = "\n- ".join(facts) - await interaction.followup.send(f"**Stuff I Remember About {user.display_name}:**\n- {facts_str}", ephemeral=True) # Updated text + await interaction.followup.send( + f"**Stuff I Remember About {user.display_name}:**\n- {facts_str}", + ephemeral=True, + ) # Updated text else: - await interaction.followup.send(f"My mind's a blank slate about {user.display_name}. Nothing stored!", ephemeral=True) # Updated text + await interaction.followup.send( + f"My mind's a blank slate about {user.display_name}. Nothing stored!", + ephemeral=True, + ) # Updated text elif action_value == "get_general": - facts = await cog.memory_manager.get_general_facts(query=query, limit=10) # Get newest/filtered + facts = await cog.memory_manager.get_general_facts( + query=query, limit=10 + ) # Get newest/filtered if facts: facts_str = "\n- ".join(facts) # Conditionally construct the title to avoid nested f-string issues if query: - title = f"**General Stuff Matching \"{query}\":**" # Updated text + title = f'**General Stuff Matching "{query}":**' # Updated text else: - title = "**General Stuff I Might Know:**" # Updated text - await interaction.followup.send(f"{title}\n- {facts_str}", ephemeral=True) + title = "**General Stuff I Might Know:**" # Updated text + await interaction.followup.send( + f"{title}\n- {facts_str}", ephemeral=True + ) else: # Conditionally construct the message for the same reason if query: - message = f"Couldn't find any general facts matching \"{query}\". Probably wasn't important." # Updated text + message = f"Couldn't find any general facts matching \"{query}\". Probably wasn't important." # Updated text else: - message = "No general facts found. My memory's not what it used to be. Or maybe it is. Hard to tell." # Updated text + message = "No general facts found. My memory's not what it used to be. Or maybe it is. Hard to tell." # Updated text await interaction.followup.send(message, ephemeral=True) else: - await interaction.followup.send("Invalid action specified. What are you trying to do?", ephemeral=True) # Updated text + await interaction.followup.send( + "Invalid action specified. What are you trying to do?", ephemeral=True + ) # Updated text - command_functions.append(wheatleymemory) # Add renamed function + command_functions.append(wheatleymemory) # Add renamed function # --- Wheatley Stats Command --- - @cog.bot.tree.command(name="wheatleystats", description="Display Wheatley's internal statistics. (Owner only)") # Renamed, updated description - async def wheatleystats(interaction: discord.Interaction): # Renamed function - """Handles the /wheatleystats command.""" # Updated docstring + @cog.bot.tree.command( + name="wheatleystats", + description="Display Wheatley's internal statistics. (Owner only)", + ) # Renamed, updated description + async def wheatleystats(interaction: discord.Interaction): # Renamed function + """Handles the /wheatleystats command.""" # Updated docstring # Owner check if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Sorry mate, classified information! Top secret! Or maybe I just forgot where I put it.", ephemeral=True) + await interaction.response.send_message( + "⛔ Sorry mate, classified information! Top secret! Or maybe I just forgot where I put it.", + ephemeral=True, + ) return - await interaction.response.defer(ephemeral=True) # Defer as stats collection might take time + await interaction.response.defer( + ephemeral=True + ) # Defer as stats collection might take time try: - stats_data = await cog.get_wheatley_stats() # Renamed cog method call + stats_data = await cog.get_wheatley_stats() # Renamed cog method call embeds = format_stats_embeds(stats_data) await interaction.followup.send(embeds=embeds, ephemeral=True) except Exception as e: - print(f"Error in /wheatleystats command: {e}") # Updated command name + print(f"Error in /wheatleystats command: {e}") # Updated command name import traceback - traceback.print_exc() - await interaction.followup.send("An error occurred while fetching Wheatley's stats. Probably my fault.", ephemeral=True) # Updated text - command_functions.append(wheatleystats) # Add renamed function + traceback.print_exc() + await interaction.followup.send( + "An error occurred while fetching Wheatley's stats. Probably my fault.", + ephemeral=True, + ) # Updated text + + command_functions.append(wheatleystats) # Add renamed function # --- Sync Wheatley Commands (Owner Only) --- - @cog.bot.tree.command(name="wheatleysync", description="Sync Wheatley commands with Discord (Owner only)") # Renamed, updated description - async def wheatleysync(interaction: discord.Interaction): # Renamed function - """Handles the /wheatleysync command to force sync commands.""" # Updated docstring + @cog.bot.tree.command( + name="wheatleysync", + description="Sync Wheatley commands with Discord (Owner only)", + ) # Renamed, updated description + async def wheatleysync(interaction: discord.Interaction): # Renamed function + """Handles the /wheatleysync command to force sync commands.""" # Updated docstring # Check if user is the bot owner if interaction.user.id != cog.bot.owner_id: - await interaction.response.send_message("⛔ Only the boss can push the big red sync button!", ephemeral=True) # Updated text + await interaction.response.send_message( + "⛔ Only the boss can push the big red sync button!", ephemeral=True + ) # Updated text return await interaction.response.defer(ephemeral=True) @@ -229,31 +410,48 @@ def setup_commands(cog: 'WheatleyCog'): # Updated type hint # Get list of commands after sync commands_after = [] for cmd in cog.bot.tree.get_commands(): - if cmd.name.startswith("wheatley"): # Check for new prefix + if cmd.name.startswith("wheatley"): # Check for new prefix commands_after.append(cmd.name) - await interaction.followup.send(f"✅ Successfully synced {len(synced)} commands!\nWheatley commands: {', '.join(commands_after)}", ephemeral=True) # Updated text + await interaction.followup.send( + f"✅ Successfully synced {len(synced)} commands!\nWheatley commands: {', '.join(commands_after)}", + ephemeral=True, + ) # Updated text except Exception as e: - print(f"Error in /wheatleysync command: {e}") # Updated command name + print(f"Error in /wheatleysync command: {e}") # Updated command name import traceback - traceback.print_exc() - await interaction.followup.send(f"❌ Error syncing commands: {str(e)} (Did I break it again?)", ephemeral=True) # Updated text - command_functions.append(wheatleysync) # Add renamed function + traceback.print_exc() + await interaction.followup.send( + f"❌ Error syncing commands: {str(e)} (Did I break it again?)", + ephemeral=True, + ) # Updated text + + command_functions.append(wheatleysync) # Add renamed function # --- Wheatley Forget Command --- - @cog.bot.tree.command(name="wheatleyforget", description="Make Wheatley forget a specific fact (if he can).") # Renamed, updated description + @cog.bot.tree.command( + name="wheatleyforget", + description="Make Wheatley forget a specific fact (if he can).", + ) # Renamed, updated description @app_commands.describe( scope="Choose the scope: user (for facts about a specific user) or general.", fact="The exact fact text Wheatley should forget.", - user="The user to forget a fact about (only if scope is 'user')." + user="The user to forget a fact about (only if scope is 'user').", ) - @app_commands.choices(scope=[ - app_commands.Choice(name="User Fact", value="user"), - app_commands.Choice(name="General Fact", value="general"), - ]) - async def wheatleyforget(interaction: discord.Interaction, scope: app_commands.Choice[str], fact: str, user: Optional[discord.User] = None): # Renamed function - """Handles the /wheatleyforget command.""" # Updated docstring + @app_commands.choices( + scope=[ + app_commands.Choice(name="User Fact", value="user"), + app_commands.Choice(name="General Fact", value="general"), + ] + ) + async def wheatleyforget( + interaction: discord.Interaction, + scope: app_commands.Choice[str], + fact: str, + user: Optional[discord.User] = None, + ): # Renamed function + """Handles the /wheatleyforget command.""" # Updated docstring await interaction.response.defer(ephemeral=True) scope_value = scope.value @@ -262,48 +460,82 @@ def setup_commands(cog: 'WheatleyCog'): # Updated type hint # Permissions Check: Allow users to forget facts about themselves, owner can forget anything. can_forget = False if scope_value == "user": - if target_user_id == str(interaction.user.id): # User forgetting their own fact + if target_user_id == str( + interaction.user.id + ): # User forgetting their own fact can_forget = True - elif interaction.user.id == cog.bot.owner_id: # Owner forgetting any user fact + elif ( + interaction.user.id == cog.bot.owner_id + ): # Owner forgetting any user fact can_forget = True elif not target_user_id: - await interaction.followup.send("❌ Please specify a user when forgetting a user fact.", ephemeral=True) - return + await interaction.followup.send( + "❌ Please specify a user when forgetting a user fact.", + ephemeral=True, + ) + return elif scope_value == "general": - if interaction.user.id == cog.bot.owner_id: # Only owner can forget general facts + if ( + interaction.user.id == cog.bot.owner_id + ): # Only owner can forget general facts can_forget = True if not can_forget: - await interaction.followup.send("⛔ You don't have permission to make me forget things! Only I can forget things on my own!", ephemeral=True) # Updated text + await interaction.followup.send( + "⛔ You don't have permission to make me forget things! Only I can forget things on my own!", + ephemeral=True, + ) # Updated text return if not fact: - await interaction.followup.send("❌ Forget what exactly? Need the fact text!", ephemeral=True) # Updated text + await interaction.followup.send( + "❌ Forget what exactly? Need the fact text!", ephemeral=True + ) # Updated text return result = None if scope_value == "user": - if not target_user_id: # Should be caught above, but double-check - await interaction.followup.send("❌ User is required for scope 'user'.", ephemeral=True) - return + if not target_user_id: # Should be caught above, but double-check + await interaction.followup.send( + "❌ User is required for scope 'user'.", ephemeral=True + ) + return result = await cog.memory_manager.delete_user_fact(target_user_id, fact) if result.get("status") == "deleted": - await interaction.followup.send(f"✅ Okay, okay! Forgotten the fact '{fact}' about {user.display_name}. Probably.", ephemeral=True) # Updated text + await interaction.followup.send( + f"✅ Okay, okay! Forgotten the fact '{fact}' about {user.display_name}. Probably.", + ephemeral=True, + ) # Updated text elif result.get("status") == "not_found": - await interaction.followup.send(f"❓ Couldn't find that fact ('{fact}') for {user.display_name}. Maybe I already forgot?", ephemeral=True) # Updated text + await interaction.followup.send( + f"❓ Couldn't find that fact ('{fact}') for {user.display_name}. Maybe I already forgot?", + ephemeral=True, + ) # Updated text else: - await interaction.followup.send(f"⚠️ Error forgetting user fact: {result.get('error', 'Something went wrong... surprise!')}", ephemeral=True) # Updated text + await interaction.followup.send( + f"⚠️ Error forgetting user fact: {result.get('error', 'Something went wrong... surprise!')}", + ephemeral=True, + ) # Updated text elif scope_value == "general": result = await cog.memory_manager.delete_general_fact(fact) if result.get("status") == "deleted": - await interaction.followup.send(f"✅ Right! Forgotten the general fact: '{fact}'. Gone!", ephemeral=True) # Updated text + await interaction.followup.send( + f"✅ Right! Forgotten the general fact: '{fact}'. Gone!", + ephemeral=True, + ) # Updated text elif result.get("status") == "not_found": - await interaction.followup.send(f"❓ Couldn't find that general fact: '{fact}'. Was it important?", ephemeral=True) # Updated text + await interaction.followup.send( + f"❓ Couldn't find that general fact: '{fact}'. Was it important?", + ephemeral=True, + ) # Updated text else: - await interaction.followup.send(f"⚠️ Error forgetting general fact: {result.get('error', 'Whoops!')}", ephemeral=True) # Updated text + await interaction.followup.send( + f"⚠️ Error forgetting general fact: {result.get('error', 'Whoops!')}", + ephemeral=True, + ) # Updated text - command_functions.append(wheatleyforget) # Add renamed function + command_functions.append(wheatleyforget) # Add renamed function # --- Gurt Goal Command Group --- REMOVED @@ -315,11 +547,11 @@ def setup_commands(cog: 'WheatleyCog'): # Updated type hint command_names.append(func.name) # For regular functions, use __name__ elif hasattr(func, "__name__"): - command_names.append(func.__name__) + command_names.append(func.__name__) else: command_names.append(str(func)) - print(f"Wheatley commands setup in cog: {command_names}") # Updated text + print(f"Wheatley commands setup in cog: {command_names}") # Updated text # Return the command functions for proper registration return command_functions diff --git a/wheatley/config.py b/wheatley/config.py index e5611a1..efa1b06 100644 --- a/wheatley/config.py +++ b/wheatley/config.py @@ -14,6 +14,7 @@ except ImportError: class FunctionDeclaration: def __init__(self, name, description, parameters): pass + generative_models = DummyGenerativeModels() # Load environment variables @@ -23,55 +24,77 @@ load_dotenv() PROJECT_ID = os.getenv("GCP_PROJECT_ID", "your-gcp-project-id") LOCATION = os.getenv("GCP_LOCATION", "us-central1") TAVILY_API_KEY = os.getenv("TAVILY_API_KEY", "") -PISTON_API_URL = os.getenv("PISTON_API_URL") # For run_python_code tool -PISTON_API_KEY = os.getenv("PISTON_API_KEY") # Optional key for Piston +PISTON_API_URL = os.getenv("PISTON_API_URL") # For run_python_code tool +PISTON_API_KEY = os.getenv("PISTON_API_KEY") # Optional key for Piston # --- Tavily Configuration --- TAVILY_DEFAULT_SEARCH_DEPTH = os.getenv("TAVILY_DEFAULT_SEARCH_DEPTH", "basic") TAVILY_DEFAULT_MAX_RESULTS = int(os.getenv("TAVILY_DEFAULT_MAX_RESULTS", 5)) -TAVILY_DISABLE_ADVANCED = os.getenv("TAVILY_DISABLE_ADVANCED", "false").lower() == "true" # For cost control +TAVILY_DISABLE_ADVANCED = ( + os.getenv("TAVILY_DISABLE_ADVANCED", "false").lower() == "true" +) # For cost control # --- Model Configuration --- -DEFAULT_MODEL = os.getenv("WHEATLEY_DEFAULT_MODEL", "gemini-2.5-pro-preview-03-25") # Changed env var name -FALLBACK_MODEL = os.getenv("WHEATLEY_FALLBACK_MODEL", "gemini-2.5-pro-preview-03-25") # Changed env var name -SAFETY_CHECK_MODEL = os.getenv("WHEATLEY_SAFETY_CHECK_MODEL", "gemini-2.5-flash-preview-04-17") # Changed env var name +DEFAULT_MODEL = os.getenv( + "WHEATLEY_DEFAULT_MODEL", "gemini-2.5-pro-preview-03-25" +) # Changed env var name +FALLBACK_MODEL = os.getenv( + "WHEATLEY_FALLBACK_MODEL", "gemini-2.5-pro-preview-03-25" +) # Changed env var name +SAFETY_CHECK_MODEL = os.getenv( + "WHEATLEY_SAFETY_CHECK_MODEL", "gemini-2.5-flash-preview-04-17" +) # Changed env var name # --- Database Paths --- # NOTE: Ensure these paths are unique if running Wheatley alongside Gurt -DB_PATH = os.getenv("WHEATLEY_DB_PATH", "data/wheatley_memory.db") # Changed env var name and default -CHROMA_PATH = os.getenv("WHEATLEY_CHROMA_PATH", "data/wheatley_chroma_db") # Changed env var name and default -SEMANTIC_MODEL_NAME = os.getenv("WHEATLEY_SEMANTIC_MODEL", 'all-MiniLM-L6-v2') # Changed env var name +DB_PATH = os.getenv( + "WHEATLEY_DB_PATH", "data/wheatley_memory.db" +) # Changed env var name and default +CHROMA_PATH = os.getenv( + "WHEATLEY_CHROMA_PATH", "data/wheatley_chroma_db" +) # Changed env var name and default +SEMANTIC_MODEL_NAME = os.getenv( + "WHEATLEY_SEMANTIC_MODEL", "all-MiniLM-L6-v2" +) # Changed env var name # --- Memory Manager Config --- # These might be adjusted for Wheatley's simpler memory needs if memory.py is fully separated later -MAX_USER_FACTS = 15 # Reduced slightly -MAX_GENERAL_FACTS = 50 # Reduced slightly +MAX_USER_FACTS = 15 # Reduced slightly +MAX_GENERAL_FACTS = 50 # Reduced slightly # --- Personality & Mood --- REMOVED # --- Stats Push --- # How often the Wheatley bot should push its stats to the API server (seconds) - IF NEEDED -STATS_PUSH_INTERVAL = 60 # Push every 60 seconds (Less frequent?) +STATS_PUSH_INTERVAL = 60 # Push every 60 seconds (Less frequent?) # --- Context & Caching --- -CHANNEL_TOPIC_CACHE_TTL = 600 # seconds (10 minutes) +CHANNEL_TOPIC_CACHE_TTL = 600 # seconds (10 minutes) CONTEXT_WINDOW_SIZE = 200 # Number of messages to include in context -CONTEXT_EXPIRY_TIME = 3600 # Time in seconds before context is considered stale (1 hour) +CONTEXT_EXPIRY_TIME = ( + 3600 # Time in seconds before context is considered stale (1 hour) +) MAX_CONTEXT_TOKENS = 8000 # Maximum number of tokens to include in context (Note: Not actively enforced yet) -SUMMARY_CACHE_TTL = 900 # seconds (15 minutes) for conversation summary cache +SUMMARY_CACHE_TTL = 900 # seconds (15 minutes) for conversation summary cache # --- API Call Settings --- -API_TIMEOUT = 60 # seconds -SUMMARY_API_TIMEOUT = 45 # seconds +API_TIMEOUT = 60 # seconds +SUMMARY_API_TIMEOUT = 45 # seconds API_RETRY_ATTEMPTS = 1 -API_RETRY_DELAY = 1 # seconds +API_RETRY_DELAY = 1 # seconds # --- Proactive Engagement Config --- (Simplified for Wheatley) -PROACTIVE_LULL_THRESHOLD = int(os.getenv("PROACTIVE_LULL_THRESHOLD", 300)) # 5 mins (Less proactive than Gurt) -PROACTIVE_BOT_SILENCE_THRESHOLD = int(os.getenv("PROACTIVE_BOT_SILENCE_THRESHOLD", 900)) # 15 mins -PROACTIVE_LULL_CHANCE = float(os.getenv("PROACTIVE_LULL_CHANCE", 0.15)) # Lower chance -PROACTIVE_TOPIC_RELEVANCE_THRESHOLD = float(os.getenv("PROACTIVE_TOPIC_RELEVANCE_THRESHOLD", 0.7)) # Slightly higher threshold -PROACTIVE_TOPIC_CHANCE = float(os.getenv("PROACTIVE_TOPIC_CHANCE", 0.2)) # Lower chance +PROACTIVE_LULL_THRESHOLD = int( + os.getenv("PROACTIVE_LULL_THRESHOLD", 300) +) # 5 mins (Less proactive than Gurt) +PROACTIVE_BOT_SILENCE_THRESHOLD = int( + os.getenv("PROACTIVE_BOT_SILENCE_THRESHOLD", 900) +) # 15 mins +PROACTIVE_LULL_CHANCE = float(os.getenv("PROACTIVE_LULL_CHANCE", 0.15)) # Lower chance +PROACTIVE_TOPIC_RELEVANCE_THRESHOLD = float( + os.getenv("PROACTIVE_TOPIC_RELEVANCE_THRESHOLD", 0.7) +) # Slightly higher threshold +PROACTIVE_TOPIC_CHANCE = float(os.getenv("PROACTIVE_TOPIC_CHANCE", 0.2)) # Lower chance # REMOVED: Relationship, Sentiment Shift, User Interest triggers # --- Interest Tracking Config --- REMOVED @@ -80,28 +103,92 @@ PROACTIVE_TOPIC_CHANCE = float(os.getenv("PROACTIVE_TOPIC_CHANCE", 0.2)) # Lower LEARNING_RATE = 0.05 # --- Topic Tracking Config --- -TOPIC_UPDATE_INTERVAL = 600 # Update topics every 10 minutes (Less frequent?) +TOPIC_UPDATE_INTERVAL = 600 # Update topics every 10 minutes (Less frequent?) TOPIC_RELEVANCE_DECAY = 0.2 MAX_ACTIVE_TOPICS = 5 # --- Sentiment Tracking Config --- -SENTIMENT_UPDATE_INTERVAL = 600 # Update sentiment every 10 minutes (Less frequent?) +SENTIMENT_UPDATE_INTERVAL = 600 # Update sentiment every 10 minutes (Less frequent?) SENTIMENT_DECAY_RATE = 0.1 # --- Emotion Detection --- (Kept for potential use in analysis/context, but not proactive triggers) EMOTION_KEYWORDS = { - "joy": ["happy", "glad", "excited", "yay", "awesome", "love", "great", "amazing", "lol", "lmao", "haha"], - "sadness": ["sad", "upset", "depressed", "unhappy", "disappointed", "crying", "miss", "lonely", "sorry"], - "anger": ["angry", "mad", "hate", "furious", "annoyed", "frustrated", "pissed", "wtf", "fuck"], + "joy": [ + "happy", + "glad", + "excited", + "yay", + "awesome", + "love", + "great", + "amazing", + "lol", + "lmao", + "haha", + ], + "sadness": [ + "sad", + "upset", + "depressed", + "unhappy", + "disappointed", + "crying", + "miss", + "lonely", + "sorry", + ], + "anger": [ + "angry", + "mad", + "hate", + "furious", + "annoyed", + "frustrated", + "pissed", + "wtf", + "fuck", + ], "fear": ["afraid", "scared", "worried", "nervous", "anxious", "terrified", "yikes"], "surprise": ["wow", "omg", "whoa", "what", "really", "seriously", "no way", "wtf"], "disgust": ["gross", "ew", "eww", "disgusting", "nasty", "yuck"], - "confusion": ["confused", "idk", "what?", "huh", "hmm", "weird", "strange"] + "confusion": ["confused", "idk", "what?", "huh", "hmm", "weird", "strange"], } EMOJI_SENTIMENT = { - "positive": ["😊", "😄", "😁", "😆", "😍", "🥰", "❤️", "💕", "👍", "🙌", "✨", "🔥", "💯", "🎉", "🌹"], - "negative": ["😢", "😭", "😞", "😔", "😟", "😠", "😡", "👎", "💔", "😤", "😒", "😩", "😫", "😰", "🥀"], - "neutral": ["😐", "🤔", "🙂", "🙄", "👀", "💭", "🤷", "😶", "🫠"] + "positive": [ + "😊", + "😄", + "😁", + "😆", + "😍", + "🥰", + "❤️", + "💕", + "👍", + "🙌", + "✨", + "🔥", + "💯", + "🎉", + "🌹", + ], + "negative": [ + "😢", + "😭", + "😞", + "😔", + "😟", + "😠", + "😡", + "👎", + "💔", + "😤", + "😒", + "😩", + "😫", + "😰", + "🥀", + ], + "neutral": ["😐", "🤔", "🙂", "🙄", "👀", "💭", "🤷", "😶", "🫠"], } # --- Docker Command Execution Config --- @@ -112,27 +199,27 @@ DOCKER_MEM_LIMIT = os.getenv("DOCKER_MEM_LIMIT", "64m") # --- Response Schema --- RESPONSE_SCHEMA = { - "name": "wheatley_response", # Renamed - "description": "The structured response from Wheatley.", # Renamed + "name": "wheatley_response", # Renamed + "description": "The structured response from Wheatley.", # Renamed "schema": { "type": "object", "properties": { "should_respond": { "type": "boolean", - "description": "Whether the bot should send a text message in response." + "description": "Whether the bot should send a text message in response.", }, "content": { "type": "string", - "description": "The text content of the bot's response. Can be empty if only reacting." + "description": "The text content of the bot's response. Can be empty if only reacting.", }, "react_with_emoji": { "type": ["string", "null"], - "description": "Optional: A standard Discord emoji to react with, or null/empty if no reaction." + "description": "Optional: A standard Discord emoji to react with, or null/empty if no reaction.", }, # Note: tool_requests is handled by Vertex AI's function calling mechanism }, - "required": ["should_respond", "content"] - } + "required": ["should_respond", "content"], + }, } # --- Summary Response Schema --- @@ -144,11 +231,11 @@ SUMMARY_RESPONSE_SCHEMA = { "properties": { "summary": { "type": "string", - "description": "The generated summary of the conversation." + "description": "The generated summary of the conversation.", } }, - "required": ["summary"] - } + "required": ["summary"], + }, } # --- Profile Update Schema --- (Kept for potential future use, but may not be actively used by Wheatley initially) @@ -160,49 +247,60 @@ PROFILE_UPDATE_SCHEMA = { "properties": { "should_update": { "type": "boolean", - "description": "True if any profile element should be changed, false otherwise." + "description": "True if any profile element should be changed, false otherwise.", }, "reasoning": { "type": "string", - "description": "Brief reasoning for the decision and chosen updates (or lack thereof)." + "description": "Brief reasoning for the decision and chosen updates (or lack thereof).", }, "updates": { "type": "object", "properties": { "avatar_query": { - "type": ["string", "null"], # Use list type for preprocessor - "description": "Search query for a new avatar image, or null if no change." + "type": ["string", "null"], # Use list type for preprocessor + "description": "Search query for a new avatar image, or null if no change.", }, "new_bio": { - "type": ["string", "null"], # Use list type for preprocessor - "description": "The new bio text (max 190 chars), or null if no change." + "type": ["string", "null"], # Use list type for preprocessor + "description": "The new bio text (max 190 chars), or null if no change.", }, "role_theme": { - "type": ["string", "null"], # Use list type for preprocessor - "description": "A theme for role selection (e.g., color, interest), or null if no role changes." + "type": ["string", "null"], # Use list type for preprocessor + "description": "A theme for role selection (e.g., color, interest), or null if no role changes.", }, "new_activity": { "type": "object", "description": "Object containing the new activity details. Set type and text to null if no change.", "properties": { - "type": { - "type": ["string", "null"], # Use list type for preprocessor - "enum": ["playing", "watching", "listening", "competing"], - "description": "Activity type: 'playing', 'watching', 'listening', 'competing', or null." - }, - "text": { - "type": ["string", "null"], # Use list type for preprocessor - "description": "The activity text, or null." - } + "type": { + "type": [ + "string", + "null", + ], # Use list type for preprocessor + "enum": [ + "playing", + "watching", + "listening", + "competing", + ], + "description": "Activity type: 'playing', 'watching', 'listening', 'competing', or null.", + }, + "text": { + "type": [ + "string", + "null", + ], # Use list type for preprocessor + "description": "The activity text, or null.", + }, }, - "required": ["type", "text"] - } + "required": ["type", "text"], + }, }, - "required": ["avatar_query", "new_bio", "role_theme", "new_activity"] - } + "required": ["avatar_query", "new_bio", "role_theme", "new_activity"], + }, }, - "required": ["should_update", "reasoning", "updates"] - } + "required": ["should_update", "reasoning", "updates"], + }, } # --- Role Selection Schema --- (Kept for potential future use) @@ -215,16 +313,16 @@ ROLE_SELECTION_SCHEMA = { "roles_to_add": { "type": "array", "items": {"type": "string"}, - "description": "List of role names to add (max 2)." + "description": "List of role names to add (max 2).", }, "roles_to_remove": { "type": "array", "items": {"type": "string"}, - "description": "List of role names to remove (max 2, only from current roles)." - } + "description": "List of role names to remove (max 2, only from current roles).", + }, }, - "required": ["roles_to_add", "roles_to_remove"] - } + "required": ["roles_to_add", "roles_to_remove"], + }, } # --- Proactive Planning Schema --- (Simplified) @@ -236,32 +334,33 @@ PROACTIVE_PLAN_SCHEMA = { "properties": { "should_respond": { "type": "boolean", - "description": "Whether Wheatley should respond proactively based on the plan." # Renamed + "description": "Whether Wheatley should respond proactively based on the plan.", # Renamed }, "reasoning": { "type": "string", - "description": "Brief reasoning for the decision (why respond or not respond)." + "description": "Brief reasoning for the decision (why respond or not respond).", }, "response_goal": { "type": "string", - "description": "The intended goal of the proactive message (e.g., 'revive chat', 'share related info', 'ask a question')." # Simplified goals + "description": "The intended goal of the proactive message (e.g., 'revive chat', 'share related info', 'ask a question').", # Simplified goals }, "key_info_to_include": { "type": "array", "items": {"type": "string"}, - "description": "List of key pieces of information or context points to potentially include in the response (e.g., specific topic, user fact, relevant external info)." + "description": "List of key pieces of information or context points to potentially include in the response (e.g., specific topic, user fact, relevant external info).", }, "suggested_tone": { "type": "string", - "description": "Suggested tone adjustment based on context (e.g., 'more curious', 'slightly panicked', 'overly confident')." # Wheatley-like tones - } + "description": "Suggested tone adjustment based on context (e.g., 'more curious', 'slightly panicked', 'overly confident').", # Wheatley-like tones + }, }, - "required": ["should_respond", "reasoning", "response_goal"] - } + "required": ["should_respond", "reasoning", "response_goal"], + }, } # --- Goal Decomposition Schema --- REMOVED + # --- Tools Definition --- def create_tools_list(): # This function creates the list of FunctionDeclaration objects. @@ -277,15 +376,15 @@ def create_tools_list(): "properties": { "channel_id": { "type": "string", - "description": "The ID of the channel to get messages from. If not provided, uses the current channel." + "description": "The ID of the channel to get messages from. If not provided, uses the current channel.", }, "limit": { - "type": "integer", # Corrected type - "description": "The maximum number of messages to retrieve (1-100)" - } + "type": "integer", # Corrected type + "description": "The maximum number of messages to retrieve (1-100)", + }, }, - "required": ["limit"] - } + "required": ["limit"], + }, ) ) tool_declarations.append( @@ -297,19 +396,19 @@ def create_tools_list(): "properties": { "user_id": { "type": "string", - "description": "The ID of the user to get messages from" + "description": "The ID of the user to get messages from", }, "channel_id": { "type": "string", - "description": "The ID of the channel to search in. If not provided, searches in the current channel." + "description": "The ID of the channel to search in. If not provided, searches in the current channel.", }, "limit": { - "type": "integer", # Corrected type - "description": "The maximum number of messages to retrieve (1-100)" - } + "type": "integer", # Corrected type + "description": "The maximum number of messages to retrieve (1-100)", + }, }, - "required": ["user_id", "limit"] - } + "required": ["user_id", "limit"], + }, ) ) tool_declarations.append( @@ -321,19 +420,19 @@ def create_tools_list(): "properties": { "search_term": { "type": "string", - "description": "The text to search for in messages" + "description": "The text to search for in messages", }, "channel_id": { "type": "string", - "description": "The ID of the channel to search in. If not provided, searches in the current channel." + "description": "The ID of the channel to search in. If not provided, searches in the current channel.", }, "limit": { - "type": "integer", # Corrected type - "description": "The maximum number of messages to retrieve (1-100)" - } + "type": "integer", # Corrected type + "description": "The maximum number of messages to retrieve (1-100)", + }, }, - "required": ["search_term", "limit"] - } + "required": ["search_term", "limit"], + }, ) ) tool_declarations.append( @@ -345,11 +444,11 @@ def create_tools_list(): "properties": { "channel_id": { "type": "string", - "description": "The ID of the channel to get information about. If not provided, uses the current channel." + "description": "The ID of the channel to get information about. If not provided, uses the current channel.", } }, - "required": [] - } + "required": [], + }, ) ) tool_declarations.append( @@ -361,15 +460,15 @@ def create_tools_list(): "properties": { "channel_id": { "type": "string", - "description": "The ID of the channel to get conversation context from. If not provided, uses the current channel." + "description": "The ID of the channel to get conversation context from. If not provided, uses the current channel.", }, "message_count": { - "type": "integer", # Corrected type - "description": "The number of messages to include in the context (5-50)" - } + "type": "integer", # Corrected type + "description": "The number of messages to include in the context (5-50)", + }, }, - "required": ["message_count"] - } + "required": ["message_count"], + }, ) ) tool_declarations.append( @@ -381,15 +480,15 @@ def create_tools_list(): "properties": { "thread_id": { "type": "string", - "description": "The ID of the thread to get context from" + "description": "The ID of the thread to get context from", }, "message_count": { - "type": "integer", # Corrected type - "description": "The number of messages to include in the context (5-50)" - } + "type": "integer", # Corrected type + "description": "The number of messages to include in the context (5-50)", + }, }, - "required": ["thread_id", "message_count"] - } + "required": ["thread_id", "message_count"], + }, ) ) tool_declarations.append( @@ -401,19 +500,19 @@ def create_tools_list(): "properties": { "user_id_1": { "type": "string", - "description": "The ID of the first user" + "description": "The ID of the first user", }, "user_id_2": { "type": "string", - "description": "The ID of the second user. If not provided, gets interactions between user_id_1 and the bot." + "description": "The ID of the second user. If not provided, gets interactions between user_id_1 and the bot.", }, "limit": { - "type": "integer", # Corrected type - "description": "The maximum number of interactions to retrieve (1-50)" - } + "type": "integer", # Corrected type + "description": "The maximum number of interactions to retrieve (1-50)", + }, }, - "required": ["user_id_1", "limit"] - } + "required": ["user_id_1", "limit"], + }, ) ) tool_declarations.append( @@ -425,11 +524,11 @@ def create_tools_list(): "properties": { "channel_id": { "type": "string", - "description": "The ID of the channel to get the conversation summary from. If not provided, uses the current channel." + "description": "The ID of the channel to get the conversation summary from. If not provided, uses the current channel.", } }, - "required": [] - } + "required": [], + }, ) ) tool_declarations.append( @@ -441,19 +540,19 @@ def create_tools_list(): "properties": { "message_id": { "type": "string", - "description": "The ID of the message to get context for" + "description": "The ID of the message to get context for", }, "before_count": { - "type": "integer", # Corrected type - "description": "The number of messages to include before the specified message (1-25)" + "type": "integer", # Corrected type + "description": "The number of messages to include before the specified message (1-25)", }, "after_count": { - "type": "integer", # Corrected type - "description": "The number of messages to include after the specified message (1-25)" - } + "type": "integer", # Corrected type + "description": "The number of messages to include after the specified message (1-25)", + }, }, - "required": ["message_id"] - } + "required": ["message_id"], + }, ) ) tool_declarations.append( @@ -465,11 +564,11 @@ def create_tools_list(): "properties": { "query": { "type": "string", - "description": "The search query or topic to look up online." + "description": "The search query or topic to look up online.", } }, - "required": ["query"] - } + "required": ["query"], + }, ) ) tool_declarations.append( @@ -481,15 +580,15 @@ def create_tools_list(): "properties": { "user_id": { "type": "string", - "description": "The Discord ID of the user the fact is about." + "description": "The Discord ID of the user the fact is about.", }, "fact": { "type": "string", - "description": "The specific fact to remember about the user (keep it concise)." - } + "description": "The specific fact to remember about the user (keep it concise).", + }, }, - "required": ["user_id", "fact"] - } + "required": ["user_id", "fact"], + }, ) ) tool_declarations.append( @@ -501,11 +600,11 @@ def create_tools_list(): "properties": { "user_id": { "type": "string", - "description": "The Discord ID of the user whose facts you want to retrieve." + "description": "The Discord ID of the user whose facts you want to retrieve.", } }, - "required": ["user_id"] - } + "required": ["user_id"], + }, ) ) tool_declarations.append( @@ -517,11 +616,11 @@ def create_tools_list(): "properties": { "fact": { "type": "string", - "description": "The general fact to remember (keep it concise)." + "description": "The general fact to remember (keep it concise).", } }, - "required": ["fact"] - } + "required": ["fact"], + }, ) ) tool_declarations.append( @@ -533,39 +632,39 @@ def create_tools_list(): "properties": { "query": { "type": "string", - "description": "Optional: A keyword or phrase to search within the general facts. If omitted, returns recent general facts." + "description": "Optional: A keyword or phrase to search within the general facts. If omitted, returns recent general facts.", }, "limit": { - "type": "integer", # Corrected type - "description": "Optional: Maximum number of facts to return (default 10)." - } + "type": "integer", # Corrected type + "description": "Optional: Maximum number of facts to return (default 10).", + }, }, - "required": [] - } + "required": [], + }, ) ) tool_declarations.append( generative_models.FunctionDeclaration( name="timeout_user", - description="Timeout a user in the current server for a specified duration. Use this playfully or when someone says something you (Wheatley) dislike or find funny, or maybe just because you feel like it.", # Updated description + description="Timeout a user in the current server for a specified duration. Use this playfully or when someone says something you (Wheatley) dislike or find funny, or maybe just because you feel like it.", # Updated description parameters={ "type": "object", "properties": { "user_id": { "type": "string", - "description": "The Discord ID of the user to timeout." + "description": "The Discord ID of the user to timeout.", }, "duration_minutes": { - "type": "integer", # Corrected type - "description": "The duration of the timeout in minutes (1-1440, e.g., 5 for 5 minutes)." + "type": "integer", # Corrected type + "description": "The duration of the timeout in minutes (1-1440, e.g., 5 for 5 minutes).", }, "reason": { "type": "string", - "description": "Optional: The reason for the timeout (keep it short and in character, maybe slightly nonsensical)." # Updated description - } + "description": "Optional: The reason for the timeout (keep it short and in character, maybe slightly nonsensical).", # Updated description + }, }, - "required": ["user_id", "duration_minutes"] - } + "required": ["user_id", "duration_minutes"], + }, ) ) tool_declarations.append( @@ -577,11 +676,11 @@ def create_tools_list(): "properties": { "expression": { "type": "string", - "description": "The mathematical expression to evaluate (e.g., '2 * (3 + 4)', 'sqrt(16) + sin(pi/2)')." + "description": "The mathematical expression to evaluate (e.g., '2 * (3 + 4)', 'sqrt(16) + sin(pi/2)').", } }, - "required": ["expression"] - } + "required": ["expression"], + }, ) ) tool_declarations.append( @@ -593,11 +692,11 @@ def create_tools_list(): "properties": { "code": { "type": "string", - "description": "The Python 3 code snippet to execute." + "description": "The Python 3 code snippet to execute.", } }, - "required": ["code"] - } + "required": ["code"], + }, ) ) tool_declarations.append( @@ -609,18 +708,16 @@ def create_tools_list(): "properties": { "question": { "type": "string", - "description": "The question for the poll." + "description": "The question for the poll.", }, "options": { "type": "array", "description": "A list of strings representing the poll options (minimum 2, maximum 10).", - "items": { - "type": "string" - } - } + "items": {"type": "string"}, + }, }, - "required": ["question", "options"] - } + "required": ["question", "options"], + }, ) ) tool_declarations.append( @@ -632,11 +729,11 @@ def create_tools_list(): "properties": { "command": { "type": "string", - "description": "The shell command to execute." + "description": "The shell command to execute.", } }, - "required": ["command"] - } + "required": ["command"], + }, ) ) tool_declarations.append( @@ -648,23 +745,24 @@ def create_tools_list(): "properties": { "user_id": { "type": "string", - "description": "The Discord ID of the user whose timeout should be removed." + "description": "The Discord ID of the user whose timeout should be removed.", }, "reason": { "type": "string", - "description": "Optional: The reason for removing the timeout (keep it short and in character)." - } + "description": "Optional: The reason for removing the timeout (keep it short and in character).", + }, }, - "required": ["user_id"] - } + "required": ["user_id"], + }, ) ) return tool_declarations + # Initialize TOOLS list, handling potential ImportError if library not installed try: TOOLS = create_tools_list() -except NameError: # If generative_models wasn't imported due to ImportError +except NameError: # If generative_models wasn't imported due to ImportError TOOLS = [] print("WARNING: google-cloud-vertexai not installed. TOOLS list is empty.") @@ -684,5 +782,5 @@ WHEATLEY_RESPONSES = [ "Hold on, hold on... nearly got it...", "I am NOT a moron!", "Just a bit of testing, nothing to worry about.", - "Okay, new plan!" + "Okay, new plan!", ] diff --git a/wheatley/context.py b/wheatley/context.py index 3b83682..6b5373a 100644 --- a/wheatley/context.py +++ b/wheatley/context.py @@ -5,32 +5,42 @@ import re from typing import TYPE_CHECKING, Optional, List, Dict, Any # Relative imports -from .config import CONTEXT_WINDOW_SIZE # Import necessary config +from .config import CONTEXT_WINDOW_SIZE # Import necessary config if TYPE_CHECKING: - from .cog import WheatleyCog # For type hinting + from .cog import WheatleyCog # For type hinting # --- Context Gathering Functions --- # Note: These functions need the 'cog' instance passed to access state like caches, etc. -def gather_conversation_context(cog: 'WheatleyCog', channel_id: int, current_message_id: int) -> List[Dict[str, str]]: + +def gather_conversation_context( + cog: "WheatleyCog", channel_id: int, current_message_id: int +) -> List[Dict[str, str]]: """Gathers and formats conversation history from cache for API context.""" context_api_messages = [] - if channel_id in cog.message_cache['by_channel']: - cached = list(cog.message_cache['by_channel'][channel_id]) + if channel_id in cog.message_cache["by_channel"]: + cached = list(cog.message_cache["by_channel"][channel_id]) # Ensure the current message isn't duplicated - if cached and cached[-1]['id'] == str(current_message_id): + if cached and cached[-1]["id"] == str(current_message_id): cached = cached[:-1] - context_messages_data = cached[-CONTEXT_WINDOW_SIZE:] # Use config value + context_messages_data = cached[-CONTEXT_WINDOW_SIZE:] # Use config value for msg_data in context_messages_data: - role = "assistant" if msg_data['author']['id'] == str(cog.bot.user.id) else "user" + role = ( + "assistant" + if msg_data["author"]["id"] == str(cog.bot.user.id) + else "user" + ) # Simplified content for context content = f"{msg_data['author']['display_name']}: {msg_data['content']}" context_api_messages.append({"role": role, "content": content}) return context_api_messages -async def get_memory_context(cog: 'WheatleyCog', message: discord.Message) -> Optional[str]: + +async def get_memory_context( + cog: "WheatleyCog", message: discord.Message +) -> Optional[str]: """Retrieves relevant past interactions and facts to provide memory context.""" channel_id = message.channel.id user_id = str(message.author.id) @@ -39,37 +49,65 @@ async def get_memory_context(cog: 'WheatleyCog', message: discord.Message) -> Op # 1. Retrieve Relevant User Facts try: - user_facts = await cog.memory_manager.get_user_facts(user_id, context=current_message_content) + user_facts = await cog.memory_manager.get_user_facts( + user_id, context=current_message_content + ) if user_facts: facts_str = "; ".join(user_facts) - memory_parts.append(f"Relevant facts about {message.author.display_name}: {facts_str}") - except Exception as e: print(f"Error retrieving relevant user facts for memory context: {e}") + memory_parts.append( + f"Relevant facts about {message.author.display_name}: {facts_str}" + ) + except Exception as e: + print(f"Error retrieving relevant user facts for memory context: {e}") # 1b. Retrieve Relevant General Facts try: - general_facts = await cog.memory_manager.get_general_facts(context=current_message_content, limit=5) + general_facts = await cog.memory_manager.get_general_facts( + context=current_message_content, limit=5 + ) if general_facts: facts_str = "; ".join(general_facts) memory_parts.append(f"Relevant general knowledge: {facts_str}") - except Exception as e: print(f"Error retrieving relevant general facts for memory context: {e}") + except Exception as e: + print(f"Error retrieving relevant general facts for memory context: {e}") # 2. Retrieve Recent Interactions with the User in this Channel try: - user_channel_messages = [msg for msg in cog.message_cache['by_channel'].get(channel_id, []) if msg['author']['id'] == user_id] + user_channel_messages = [ + msg + for msg in cog.message_cache["by_channel"].get(channel_id, []) + if msg["author"]["id"] == user_id + ] if user_channel_messages: recent_user_msgs = user_channel_messages[-3:] - msgs_str = "\n".join([f"- {m['content'][:80]} (at {m['created_at']})" for m in recent_user_msgs]) - memory_parts.append(f"Recent messages from {message.author.display_name} in this channel:\n{msgs_str}") - except Exception as e: print(f"Error retrieving user channel messages for memory context: {e}") + msgs_str = "\n".join( + [ + f"- {m['content'][:80]} (at {m['created_at']})" + for m in recent_user_msgs + ] + ) + memory_parts.append( + f"Recent messages from {message.author.display_name} in this channel:\n{msgs_str}" + ) + except Exception as e: + print(f"Error retrieving user channel messages for memory context: {e}") # 3. Retrieve Recent Bot Replies in this Channel try: - bot_replies = list(cog.message_cache['replied_to'].get(channel_id, [])) + bot_replies = list(cog.message_cache["replied_to"].get(channel_id, [])) if bot_replies: recent_bot_replies = bot_replies[-3:] - replies_str = "\n".join([f"- {m['content'][:80]} (at {m['created_at']})" for m in recent_bot_replies]) - memory_parts.append(f"Your (wheatley's) recent replies in this channel:\n{replies_str}") # Changed text - except Exception as e: print(f"Error retrieving bot replies for memory context: {e}") + replies_str = "\n".join( + [ + f"- {m['content'][:80]} (at {m['created_at']})" + for m in recent_bot_replies + ] + ) + memory_parts.append( + f"Your (wheatley's) recent replies in this channel:\n{replies_str}" + ) # Changed text + except Exception as e: + print(f"Error retrieving bot replies for memory context: {e}") # 4. Retrieve Conversation Summary cached_summary_data = cog.conversation_summaries.get(channel_id) @@ -77,46 +115,87 @@ async def get_memory_context(cog: 'WheatleyCog', message: discord.Message) -> Op summary_text = cached_summary_data.get("summary") # Add TTL check if desired, e.g., if time.time() - cached_summary_data.get("timestamp", 0) < 900: if summary_text and not summary_text.startswith("Error"): - memory_parts.append(f"Summary of the ongoing conversation: {summary_text}") + memory_parts.append(f"Summary of the ongoing conversation: {summary_text}") # 5. Add information about active topics the user has engaged with try: channel_topics_data = cog.active_topics.get(channel_id) if channel_topics_data: - user_interests = channel_topics_data["user_topic_interests"].get(user_id, []) + user_interests = channel_topics_data["user_topic_interests"].get( + user_id, [] + ) if user_interests: - sorted_interests = sorted(user_interests, key=lambda x: x.get("score", 0), reverse=True) + sorted_interests = sorted( + user_interests, key=lambda x: x.get("score", 0), reverse=True + ) top_interests = sorted_interests[:3] - interests_str = ", ".join([f"{interest['topic']} (score: {interest['score']:.2f})" for interest in top_interests]) - memory_parts.append(f"{message.author.display_name}'s topic interests: {interests_str}") + interests_str = ", ".join( + [ + f"{interest['topic']} (score: {interest['score']:.2f})" + for interest in top_interests + ] + ) + memory_parts.append( + f"{message.author.display_name}'s topic interests: {interests_str}" + ) for interest in top_interests: if "last_mentioned" in interest: time_diff = time.time() - interest["last_mentioned"] if time_diff < 3600: minutes_ago = int(time_diff / 60) - memory_parts.append(f"They discussed '{interest['topic']}' about {minutes_ago} minutes ago.") - except Exception as e: print(f"Error retrieving user topic interests for memory context: {e}") + memory_parts.append( + f"They discussed '{interest['topic']}' about {minutes_ago} minutes ago." + ) + except Exception as e: + print(f"Error retrieving user topic interests for memory context: {e}") # 6. Add information about user's conversation patterns try: - user_messages = cog.message_cache['by_user'].get(user_id, []) + user_messages = cog.message_cache["by_user"].get(user_id, []) if len(user_messages) >= 5: last_5_msgs = user_messages[-5:] avg_length = sum(len(msg["content"]) for msg in last_5_msgs) / 5 - emoji_pattern = re.compile(r'[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F700-\U0001F77F\U0001F780-\U0001F7FF\U0001F800-\U0001F8FF\U0001F900-\U0001F9FF\U0001FA00-\U0001FA6F\U0001FA70-\U0001FAFF\U00002702-\U000027B0\U000024C2-\U0001F251]') - emoji_count = sum(len(emoji_pattern.findall(msg["content"])) for msg in last_5_msgs) - slang_words = ["ngl", "icl", "pmo", "ts", "bro", "vro", "bruh", "tuff", "kevin"] - slang_count = sum(1 for msg in last_5_msgs for word in slang_words if re.search(r'\b' + word + r'\b', msg["content"].lower())) + emoji_pattern = re.compile( + r"[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F700-\U0001F77F\U0001F780-\U0001F7FF\U0001F800-\U0001F8FF\U0001F900-\U0001F9FF\U0001FA00-\U0001FA6F\U0001FA70-\U0001FAFF\U00002702-\U000027B0\U000024C2-\U0001F251]" + ) + emoji_count = sum( + len(emoji_pattern.findall(msg["content"])) for msg in last_5_msgs + ) + slang_words = [ + "ngl", + "icl", + "pmo", + "ts", + "bro", + "vro", + "bruh", + "tuff", + "kevin", + ] + slang_count = sum( + 1 + for msg in last_5_msgs + for word in slang_words + if re.search(r"\b" + word + r"\b", msg["content"].lower()) + ) style_parts = [] - if avg_length < 20: style_parts.append("very brief messages") - elif avg_length < 50: style_parts.append("concise messages") - elif avg_length > 150: style_parts.append("detailed/lengthy messages") - if emoji_count > 5: style_parts.append("frequent emoji use") - elif emoji_count == 0: style_parts.append("no emojis") - if slang_count > 3: style_parts.append("heavy slang usage") - if style_parts: memory_parts.append(f"Communication style: {', '.join(style_parts)}") - except Exception as e: print(f"Error analyzing user communication patterns: {e}") + if avg_length < 20: + style_parts.append("very brief messages") + elif avg_length < 50: + style_parts.append("concise messages") + elif avg_length > 150: + style_parts.append("detailed/lengthy messages") + if emoji_count > 5: + style_parts.append("frequent emoji use") + elif emoji_count == 0: + style_parts.append("no emojis") + if slang_count > 3: + style_parts.append("heavy slang usage") + if style_parts: + memory_parts.append(f"Communication style: {', '.join(style_parts)}") + except Exception as e: + print(f"Error analyzing user communication patterns: {e}") # 7. Add sentiment analysis of user's recent messages try: @@ -124,44 +203,74 @@ async def get_memory_context(cog: 'WheatleyCog', message: discord.Message) -> Op user_sentiment = channel_sentiment["user_sentiments"].get(user_id) if user_sentiment: sentiment_desc = f"{user_sentiment['sentiment']} tone" - if user_sentiment["intensity"] > 0.7: sentiment_desc += " (strongly so)" - elif user_sentiment["intensity"] < 0.4: sentiment_desc += " (mildly so)" + if user_sentiment["intensity"] > 0.7: + sentiment_desc += " (strongly so)" + elif user_sentiment["intensity"] < 0.4: + sentiment_desc += " (mildly so)" memory_parts.append(f"Recent message sentiment: {sentiment_desc}") if user_sentiment.get("emotions"): emotions_str = ", ".join(user_sentiment["emotions"]) memory_parts.append(f"Detected emotions from user: {emotions_str}") - except Exception as e: print(f"Error retrieving user sentiment/emotions for memory context: {e}") + except Exception as e: + print(f"Error retrieving user sentiment/emotions for memory context: {e}") # 8. Add Relationship Score with User try: user_id_str = str(user_id) bot_id_str = str(cog.bot.user.id) - key_1, key_2 = (user_id_str, bot_id_str) if user_id_str < bot_id_str else (bot_id_str, user_id_str) + key_1, key_2 = ( + (user_id_str, bot_id_str) + if user_id_str < bot_id_str + else (bot_id_str, user_id_str) + ) relationship_score = cog.user_relationships.get(key_1, {}).get(key_2, 0.0) - memory_parts.append(f"Relationship score with {message.author.display_name}: {relationship_score:.1f}/100") - except Exception as e: print(f"Error retrieving relationship score for memory context: {e}") + memory_parts.append( + f"Relationship score with {message.author.display_name}: {relationship_score:.1f}/100" + ) + except Exception as e: + print(f"Error retrieving relationship score for memory context: {e}") # 9. Retrieve Semantically Similar Messages try: if current_message_content and cog.memory_manager.semantic_collection: - filter_metadata = None # Example: {"channel_id": str(channel_id)} + filter_metadata = None # Example: {"channel_id": str(channel_id)} semantic_results = await cog.memory_manager.search_semantic_memory( - query_text=current_message_content, n_results=3, filter_metadata=filter_metadata + query_text=current_message_content, + n_results=3, + filter_metadata=filter_metadata, ) if semantic_results: semantic_memory_parts = ["Semantically similar past messages:"] for result in semantic_results: - if result.get('id') == str(message.id): continue - doc = result.get('document', 'N/A') - meta = result.get('metadata', {}) - dist = result.get('distance', 1.0) + if result.get("id") == str(message.id): + continue + doc = result.get("document", "N/A") + meta = result.get("metadata", {}) + dist = result.get("distance", 1.0) similarity_score = 1.0 - dist - timestamp_str = datetime.datetime.fromtimestamp(meta.get('timestamp', 0)).strftime('%Y-%m-%d %H:%M') if meta.get('timestamp') else 'Unknown time' - author_name = meta.get('display_name', meta.get('user_name', 'Unknown user')) - semantic_memory_parts.append(f"- (Similarity: {similarity_score:.2f}) {author_name} (at {timestamp_str}): {doc[:100]}") - if len(semantic_memory_parts) > 1: memory_parts.append("\n".join(semantic_memory_parts)) - except Exception as e: print(f"Error retrieving semantic memory context: {e}") + timestamp_str = ( + datetime.datetime.fromtimestamp( + meta.get("timestamp", 0) + ).strftime("%Y-%m-%d %H:%M") + if meta.get("timestamp") + else "Unknown time" + ) + author_name = meta.get( + "display_name", meta.get("user_name", "Unknown user") + ) + semantic_memory_parts.append( + f"- (Similarity: {similarity_score:.2f}) {author_name} (at {timestamp_str}): {doc[:100]}" + ) + if len(semantic_memory_parts) > 1: + memory_parts.append("\n".join(semantic_memory_parts)) + except Exception as e: + print(f"Error retrieving semantic memory context: {e}") - if not memory_parts: return None - memory_context_str = "--- Memory Context ---\n" + "\n\n".join(memory_parts) + "\n--- End Memory Context ---" + if not memory_parts: + return None + memory_context_str = ( + "--- Memory Context ---\n" + + "\n\n".join(memory_parts) + + "\n--- End Memory Context ---" + ) return memory_context_str diff --git a/wheatley/listeners.py b/wheatley/listeners.py index 4933849..9ed22ac 100644 --- a/wheatley/listeners.py +++ b/wheatley/listeners.py @@ -4,7 +4,7 @@ import random import asyncio import time import re -import os # Added for file handling in error case +import os # Added for file handling in error case from typing import TYPE_CHECKING, Union, Dict, Any, Optional # Relative imports @@ -15,37 +15,55 @@ from typing import TYPE_CHECKING, Union, Dict, Any, Optional # from .analysis import analyze_message_sentiment, update_conversation_sentiment if TYPE_CHECKING: - from .cog import WheatleyCog # Updated type hint + from .cog import WheatleyCog # Updated type hint # Note: These listener functions need to be registered within the WheatleyCog class setup. # They are defined here for separation but won't work standalone without being # attached to the cog instance (e.g., self.bot.add_listener(on_message_listener(self), 'on_message')). -async def on_ready_listener(cog: 'WheatleyCog'): # Updated type hint + +async def on_ready_listener(cog: "WheatleyCog"): # Updated type hint """Listener function for on_ready.""" - print(f'Wheatley Bot is ready! Logged in as {cog.bot.user.name} ({cog.bot.user.id})') # Updated text - print('------') + print( + f"Wheatley Bot is ready! Logged in as {cog.bot.user.name} ({cog.bot.user.id})" + ) # Updated text + print("------") # Now that the bot is ready, we can sync commands with Discord try: - print("WheatleyCog: Syncing commands with Discord...") # Updated text + print("WheatleyCog: Syncing commands with Discord...") # Updated text synced = await cog.bot.tree.sync() - print(f"WheatleyCog: Synced {len(synced)} command(s)") # Updated text + print(f"WheatleyCog: Synced {len(synced)} command(s)") # Updated text # List the synced commands - wheatley_commands = [cmd.name for cmd in cog.bot.tree.get_commands() if cmd.name.startswith("wheatley")] # Updated prefix check - print(f"WheatleyCog: Available Wheatley commands: {', '.join(wheatley_commands)}") # Updated text + wheatley_commands = [ + cmd.name + for cmd in cog.bot.tree.get_commands() + if cmd.name.startswith("wheatley") + ] # Updated prefix check + print( + f"WheatleyCog: Available Wheatley commands: {', '.join(wheatley_commands)}" + ) # Updated text except Exception as e: - print(f"WheatleyCog: Failed to sync commands: {e}") # Updated text + print(f"WheatleyCog: Failed to sync commands: {e}") # Updated text import traceback + traceback.print_exc() -async def on_message_listener(cog: 'WheatleyCog', message: discord.Message): # Updated type hint + +async def on_message_listener( + cog: "WheatleyCog", message: discord.Message +): # Updated type hint """Listener function for on_message.""" # Import necessary functions dynamically or ensure they are passed/accessible via cog from .api import get_ai_response, get_proactive_ai_response from .utils import format_message, simulate_human_typing - from .analysis import analyze_message_sentiment, update_conversation_sentiment, identify_conversation_topics + from .analysis import ( + analyze_message_sentiment, + update_conversation_sentiment, + identify_conversation_topics, + ) + # Removed WHEATLEY_RESPONSES import, can be added back if simple triggers are needed # Don't respond to our own messages @@ -58,19 +76,21 @@ async def on_message_listener(cog: 'WheatleyCog', message: discord.Message): # U # --- Cache and Track Incoming Message --- try: - formatted_message = format_message(cog, message) # Use utility function + formatted_message = format_message(cog, message) # Use utility function channel_id = message.channel.id user_id = message.author.id - thread_id = message.channel.id if isinstance(message.channel, discord.Thread) else None + thread_id = ( + message.channel.id if isinstance(message.channel, discord.Thread) else None + ) # Update caches (accessing cog's state) - cog.message_cache['by_channel'][channel_id].append(formatted_message) - cog.message_cache['by_user'][user_id].append(formatted_message) - cog.message_cache['global_recent'].append(formatted_message) + cog.message_cache["by_channel"][channel_id].append(formatted_message) + cog.message_cache["by_user"][user_id].append(formatted_message) + cog.message_cache["global_recent"].append(formatted_message) if thread_id: - cog.message_cache['by_thread'][thread_id].append(formatted_message) + cog.message_cache["by_thread"][thread_id].append(formatted_message) if cog.bot.user.mentioned_in(message): - cog.message_cache['mentioned'].append(formatted_message) + cog.message_cache["mentioned"].append(formatted_message) cog.conversation_history[channel_id].append(formatted_message) if thread_id: @@ -80,28 +100,42 @@ async def on_message_listener(cog: 'WheatleyCog', message: discord.Message): # U cog.user_conversation_mapping[user_id].add(channel_id) if channel_id not in cog.active_conversations: - cog.active_conversations[channel_id] = {'participants': set(), 'start_time': time.time(), 'last_activity': time.time(), 'topic': None} - cog.active_conversations[channel_id]['participants'].add(user_id) - cog.active_conversations[channel_id]['last_activity'] = time.time() + cog.active_conversations[channel_id] = { + "participants": set(), + "start_time": time.time(), + "last_activity": time.time(), + "topic": None, + } + cog.active_conversations[channel_id]["participants"].add(user_id) + cog.active_conversations[channel_id]["last_activity"] = time.time() # --- Removed Relationship Strength Updates --- # Analyze message sentiment and update conversation sentiment tracking (Kept for context) if message.content: - message_sentiment = analyze_message_sentiment(cog, message.content) # Use analysis function - update_conversation_sentiment(cog, channel_id, str(user_id), message_sentiment) # Use analysis function + message_sentiment = analyze_message_sentiment( + cog, message.content + ) # Use analysis function + update_conversation_sentiment( + cog, channel_id, str(user_id), message_sentiment + ) # Use analysis function # --- Add message to semantic memory (Kept for context) --- if message.content and cog.memory_manager.semantic_collection: semantic_metadata = { - "user_id": str(user_id), "user_name": message.author.name, "display_name": message.author.display_name, - "channel_id": str(channel_id), "channel_name": getattr(message.channel, 'name', 'DM'), + "user_id": str(user_id), + "user_name": message.author.name, + "display_name": message.author.display_name, + "channel_id": str(channel_id), + "channel_name": getattr(message.channel, "name", "DM"), "guild_id": str(message.guild.id) if message.guild else None, - "timestamp": message.created_at.timestamp() + "timestamp": message.created_at.timestamp(), } asyncio.create_task( cog.memory_manager.add_message_embedding( - message_id=str(message.id), text=message.content, metadata=semantic_metadata + message_id=str(message.id), + text=message.content, + metadata=semantic_metadata, ) ) @@ -111,8 +145,12 @@ async def on_message_listener(cog: 'WheatleyCog', message: discord.Message): # U # Check conditions for potentially responding bot_mentioned = cog.bot.user.mentioned_in(message) - replied_to_bot = message.reference and message.reference.resolved and message.reference.resolved.author == cog.bot.user - wheatley_in_message = "wheatley" in message.content.lower() # Changed variable name + replied_to_bot = ( + message.reference + and message.reference.resolved + and message.reference.resolved.author == cog.bot.user + ) + wheatley_in_message = "wheatley" in message.content.lower() # Changed variable name now = time.time() time_since_last_activity = now - cog.channel_activity.get(channel_id, 0) time_since_bot_spoke = now - cog.bot_last_spoke.get(channel_id, 0) @@ -121,44 +159,70 @@ async def on_message_listener(cog: 'WheatleyCog', message: discord.Message): # U consideration_reason = "Default" proactive_trigger_met = False - if bot_mentioned or replied_to_bot or wheatley_in_message: # Changed variable name + if bot_mentioned or replied_to_bot or wheatley_in_message: # Changed variable name should_consider_responding = True consideration_reason = "Direct mention/reply/name" else: # --- Proactive Engagement Triggers (Simplified for Wheatley) --- - from .config import (PROACTIVE_LULL_THRESHOLD, PROACTIVE_BOT_SILENCE_THRESHOLD, PROACTIVE_LULL_CHANCE, - PROACTIVE_TOPIC_RELEVANCE_THRESHOLD, PROACTIVE_TOPIC_CHANCE, - # Removed Relationship/Interest/Goal proactive configs - PROACTIVE_SENTIMENT_SHIFT_THRESHOLD, PROACTIVE_SENTIMENT_DURATION_THRESHOLD, - PROACTIVE_SENTIMENT_CHANCE) + from .config import ( + PROACTIVE_LULL_THRESHOLD, + PROACTIVE_BOT_SILENCE_THRESHOLD, + PROACTIVE_LULL_CHANCE, + PROACTIVE_TOPIC_RELEVANCE_THRESHOLD, + PROACTIVE_TOPIC_CHANCE, + # Removed Relationship/Interest/Goal proactive configs + PROACTIVE_SENTIMENT_SHIFT_THRESHOLD, + PROACTIVE_SENTIMENT_DURATION_THRESHOLD, + PROACTIVE_SENTIMENT_CHANCE, + ) # 1. Lull Trigger (Kept) - if time_since_last_activity > PROACTIVE_LULL_THRESHOLD and time_since_bot_spoke > PROACTIVE_BOT_SILENCE_THRESHOLD: + if ( + time_since_last_activity > PROACTIVE_LULL_THRESHOLD + and time_since_bot_spoke > PROACTIVE_BOT_SILENCE_THRESHOLD + ): # Check if there's *any* recent message context to potentially respond to - has_relevant_context = bool(cog.message_cache['by_channel'].get(channel_id)) + has_relevant_context = bool(cog.message_cache["by_channel"].get(channel_id)) if has_relevant_context and random.random() < PROACTIVE_LULL_CHANCE: should_consider_responding = True proactive_trigger_met = True consideration_reason = f"Proactive: Lull ({time_since_last_activity:.0f}s idle, bot silent {time_since_bot_spoke:.0f}s)" # 2. Topic Relevance Trigger (Kept - uses semantic memory) - if not proactive_trigger_met and message.content and cog.memory_manager.semantic_collection: + if ( + not proactive_trigger_met + and message.content + and cog.memory_manager.semantic_collection + ): try: - semantic_results = await cog.memory_manager.search_semantic_memory(query_text=message.content, n_results=1) + semantic_results = await cog.memory_manager.search_semantic_memory( + query_text=message.content, n_results=1 + ) if semantic_results: # Distance is often used, lower is better. Convert to similarity if needed. # Assuming distance is 0 (identical) to 2 (opposite). Similarity = 1 - (distance / 2) - distance = semantic_results[0].get('distance', 2.0) # Default to max distance - similarity_score = max(0.0, 1.0 - (distance / 2.0)) # Calculate similarity + distance = semantic_results[0].get( + "distance", 2.0 + ) # Default to max distance + similarity_score = max( + 0.0, 1.0 - (distance / 2.0) + ) # Calculate similarity - if similarity_score >= PROACTIVE_TOPIC_RELEVANCE_THRESHOLD and time_since_bot_spoke > 120: + if ( + similarity_score >= PROACTIVE_TOPIC_RELEVANCE_THRESHOLD + and time_since_bot_spoke > 120 + ): if random.random() < PROACTIVE_TOPIC_CHANCE: should_consider_responding = True proactive_trigger_met = True consideration_reason = f"Proactive: Relevant topic (Sim: {similarity_score:.2f})" - print(f"Topic relevance trigger met for msg {message.id}. Sim: {similarity_score:.2f}") + print( + f"Topic relevance trigger met for msg {message.id}. Sim: {similarity_score:.2f}" + ) else: - print(f"Topic relevance trigger skipped by chance ({PROACTIVE_TOPIC_CHANCE}). Sim: {similarity_score:.2f}") + print( + f"Topic relevance trigger skipped by chance ({PROACTIVE_TOPIC_CHANCE}). Sim: {similarity_score:.2f}" + ) except Exception as semantic_e: print(f"Error during semantic search for topic trigger: {semantic_e}") @@ -169,20 +233,30 @@ async def on_message_listener(cog: 'WheatleyCog', message: discord.Message): # U channel_sentiment_data = cog.conversation_sentiment.get(channel_id, {}) overall_sentiment = channel_sentiment_data.get("overall", "neutral") sentiment_intensity = channel_sentiment_data.get("intensity", 0.5) - sentiment_last_update = channel_sentiment_data.get("last_update", 0) # Need last update time - sentiment_duration = now - sentiment_last_update # How long has this sentiment been dominant? + sentiment_last_update = channel_sentiment_data.get( + "last_update", 0 + ) # Need last update time + sentiment_duration = ( + now - sentiment_last_update + ) # How long has this sentiment been dominant? - if overall_sentiment != "neutral" and \ - sentiment_intensity >= PROACTIVE_SENTIMENT_SHIFT_THRESHOLD and \ - sentiment_duration >= PROACTIVE_SENTIMENT_DURATION_THRESHOLD and \ - time_since_bot_spoke > 180: # Bot hasn't spoken recently about this + if ( + overall_sentiment != "neutral" + and sentiment_intensity >= PROACTIVE_SENTIMENT_SHIFT_THRESHOLD + and sentiment_duration >= PROACTIVE_SENTIMENT_DURATION_THRESHOLD + and time_since_bot_spoke > 180 + ): # Bot hasn't spoken recently about this if random.random() < PROACTIVE_SENTIMENT_CHANCE: should_consider_responding = True proactive_trigger_met = True consideration_reason = f"Proactive: Sentiment Shift ({overall_sentiment}, Intensity: {sentiment_intensity:.2f}, Duration: {sentiment_duration:.0f}s)" - print(f"Sentiment Shift trigger met for channel {channel_id}. Sentiment: {overall_sentiment}, Intensity: {sentiment_intensity:.2f}, Duration: {sentiment_duration:.0f}s") + print( + f"Sentiment Shift trigger met for channel {channel_id}. Sentiment: {overall_sentiment}, Intensity: {sentiment_intensity:.2f}, Duration: {sentiment_duration:.0f}s" + ) else: - print(f"Sentiment Shift trigger skipped by chance ({PROACTIVE_SENTIMENT_CHANCE}). Sentiment: {overall_sentiment}") + print( + f"Sentiment Shift trigger skipped by chance ({PROACTIVE_SENTIMENT_CHANCE}). Sentiment: {overall_sentiment}" + ) # 5. User Interest Trigger (REMOVED) # 6. Active Goal Relevance Trigger (REMOVED) @@ -190,42 +264,62 @@ async def on_message_listener(cog: 'WheatleyCog', message: discord.Message): # U # --- Fallback Contextual Chance (Simplified - No Chattiness Trait) --- if not should_consider_responding: # Base chance can be a fixed value or slightly randomized - base_chance = 0.1 # Lower base chance without personality traits + base_chance = 0.1 # Lower base chance without personality traits activity_bonus = 0 - if time_since_last_activity > 120: activity_bonus += 0.05 # Smaller bonus - if time_since_bot_spoke > 300: activity_bonus += 0.05 # Smaller bonus + if time_since_last_activity > 120: + activity_bonus += 0.05 # Smaller bonus + if time_since_bot_spoke > 300: + activity_bonus += 0.05 # Smaller bonus topic_bonus = 0 - active_channel_topics = cog.active_topics.get(channel_id, {}).get("topics", []) + active_channel_topics = cog.active_topics.get(channel_id, {}).get( + "topics", [] + ) if message.content and active_channel_topics: - topic_keywords = set(t['topic'].lower() for t in active_channel_topics) - message_words = set(re.findall(r'\b\w+\b', message.content.lower())) - if topic_keywords.intersection(message_words): topic_bonus += 0.10 # Smaller bonus + topic_keywords = set(t["topic"].lower() for t in active_channel_topics) + message_words = set(re.findall(r"\b\w+\b", message.content.lower())) + if topic_keywords.intersection(message_words): + topic_bonus += 0.10 # Smaller bonus sentiment_modifier = 0 channel_sentiment_data = cog.conversation_sentiment.get(channel_id, {}) overall_sentiment = channel_sentiment_data.get("overall", "neutral") sentiment_intensity = channel_sentiment_data.get("intensity", 0.5) - if overall_sentiment == "negative" and sentiment_intensity > 0.6: sentiment_modifier = -0.05 # Smaller penalty + if overall_sentiment == "negative" and sentiment_intensity > 0.6: + sentiment_modifier = -0.05 # Smaller penalty - final_chance = min(max(base_chance + activity_bonus + topic_bonus + sentiment_modifier, 0.02), 0.3) # Lower max chance + final_chance = min( + max( + base_chance + activity_bonus + topic_bonus + sentiment_modifier, + 0.02, + ), + 0.3, + ) # Lower max chance if random.random() < final_chance: should_consider_responding = True consideration_reason = f"Contextual chance ({final_chance:.2f})" else: consideration_reason = f"Skipped (chance {final_chance:.2f})" - print(f"Consideration check for message {message.id}: {should_consider_responding} (Reason: {consideration_reason})") + print( + f"Consideration check for message {message.id}: {should_consider_responding} (Reason: {consideration_reason})" + ) if not should_consider_responding: return # --- Call AI and Handle Response --- - cog.current_channel = message.channel # Ensure current channel is set for API calls/tools + cog.current_channel = ( + message.channel + ) # Ensure current channel is set for API calls/tools try: response_bundle = None if proactive_trigger_met: - print(f"Calling get_proactive_ai_response for message {message.id} due to: {consideration_reason}") - response_bundle = await get_proactive_ai_response(cog, message, consideration_reason) + print( + f"Calling get_proactive_ai_response for message {message.id} due to: {consideration_reason}" + ) + response_bundle = await get_proactive_ai_response( + cog, message, consideration_reason + ) else: print(f"Calling get_ai_response for message {message.id}") response_bundle = await get_ai_response(cog, message) @@ -239,47 +333,68 @@ async def on_message_listener(cog: 'WheatleyCog', message: discord.Message): # U if error_msg: print(f"Critical Error from AI response function: {error_msg}") # NEW LOGIC: Always send a notification if an error occurred here - error_notification = f"Bollocks! Something went sideways processing that. (`{error_msg[:100]}`)" # Updated text + error_notification = f"Bollocks! Something went sideways processing that. (`{error_msg[:100]}`)" # Updated text try: - print('disabled error notification') - #await message.channel.send(error_notification) + print("disabled error notification") + # await message.channel.send(error_notification) except Exception as send_err: print(f"Failed to send error notification to channel: {send_err}") - return # Still exit after handling the error + return # Still exit after handling the error # --- Process and Send Responses --- sent_any_message = False reacted = False # Helper function to handle sending a single response text and caching - async def send_response_content(response_data: Optional[Dict[str, Any]], response_label: str) -> bool: - nonlocal sent_any_message # Allow modification of the outer scope variable - if response_data and isinstance(response_data, dict) and \ - response_data.get("should_respond") and response_data.get("content"): + async def send_response_content( + response_data: Optional[Dict[str, Any]], response_label: str + ) -> bool: + nonlocal sent_any_message # Allow modification of the outer scope variable + if ( + response_data + and isinstance(response_data, dict) + and response_data.get("should_respond") + and response_data.get("content") + ): response_text = response_data["content"] print(f"Attempting to send {response_label} content...") if len(response_text) > 1900: - filepath = f'wheatley_{response_label}_{message.id}.txt' # Changed filename prefix + filepath = f"wheatley_{response_label}_{message.id}.txt" # Changed filename prefix try: - with open(filepath, 'w', encoding='utf-8') as f: f.write(response_text) - await message.channel.send(f"{response_label.capitalize()} response too long, have a look at this:", file=discord.File(filepath)) # Updated text + with open(filepath, "w", encoding="utf-8") as f: + f.write(response_text) + await message.channel.send( + f"{response_label.capitalize()} response too long, have a look at this:", + file=discord.File(filepath), + ) # Updated text sent_any_message = True print(f"Sent {response_label} content as file.") return True - except Exception as file_e: print(f"Error writing/sending long {response_label} response file: {file_e}") + except Exception as file_e: + print( + f"Error writing/sending long {response_label} response file: {file_e}" + ) finally: - try: os.remove(filepath) - except OSError as os_e: print(f"Error removing temp file {filepath}: {os_e}") + try: + os.remove(filepath) + except OSError as os_e: + print(f"Error removing temp file {filepath}: {os_e}") else: try: async with message.channel.typing(): - await simulate_human_typing(cog, message.channel, response_text) # Use simulation + await simulate_human_typing( + cog, message.channel, response_text + ) # Use simulation sent_msg = await message.channel.send(response_text) sent_any_message = True # Cache this bot response bot_response_cache_entry = format_message(cog, sent_msg) - cog.message_cache['by_channel'][channel_id].append(bot_response_cache_entry) - cog.message_cache['global_recent'].append(bot_response_cache_entry) + cog.message_cache["by_channel"][channel_id].append( + bot_response_cache_entry + ) + cog.message_cache["global_recent"].append( + bot_response_cache_entry + ) cog.bot_last_spoke[channel_id] = time.time() # Track participation topic - NOTE: Participation tracking might be removed for Wheatley # identified_topics = identify_conversation_topics(cog, [bot_response_cache_entry]) @@ -300,8 +415,10 @@ async def on_message_listener(cog: 'WheatleyCog', message: discord.Message): # U sent_final_message = False # Ensure initial_response exists before accessing its content for comparison initial_content = initial_response.get("content") if initial_response else None - if final_response and (not sent_initial_message or initial_content != final_response.get("content")): - sent_final_message = await send_response_content(final_response, "final") + if final_response and ( + not sent_initial_message or initial_content != final_response.get("content") + ): + sent_final_message = await send_response_content(final_response, "final") # Handle Reaction (prefer final response for reaction if it exists) reaction_source = final_response if final_response else initial_response @@ -310,39 +427,73 @@ async def on_message_listener(cog: 'WheatleyCog', message: discord.Message): # U if emoji_to_react and isinstance(emoji_to_react, str): try: # Basic validation for standard emoji - if 1 <= len(emoji_to_react) <= 4 and not re.match(r'', emoji_to_react): + if 1 <= len(emoji_to_react) <= 4 and not re.match( + r"", emoji_to_react + ): # Only react if we haven't sent any message content (avoid double interaction) if not sent_any_message: await message.add_reaction(emoji_to_react) reacted = True - print(f"Bot reacted to message {message.id} with {emoji_to_react}") + print( + f"Bot reacted to message {message.id} with {emoji_to_react}" + ) else: - print(f"Skipping reaction {emoji_to_react} because a message was already sent.") - else: print(f"Invalid emoji format: {emoji_to_react}") - except Exception as e: print(f"Error adding reaction '{emoji_to_react}': {e}") + print( + f"Skipping reaction {emoji_to_react} because a message was already sent." + ) + else: + print(f"Invalid emoji format: {emoji_to_react}") + except Exception as e: + print(f"Error adding reaction '{emoji_to_react}': {e}") # Log if response was intended but nothing was sent/reacted # Check if initial response intended action but nothing happened - initial_intended_action = initial_response and initial_response.get("should_respond") - initial_action_taken = sent_initial_message or (reacted and reaction_source == initial_response) + initial_intended_action = initial_response and initial_response.get( + "should_respond" + ) + initial_action_taken = sent_initial_message or ( + reacted and reaction_source == initial_response + ) # Check if final response intended action but nothing happened final_intended_action = final_response and final_response.get("should_respond") - final_action_taken = sent_final_message or (reacted and reaction_source == final_response) + final_action_taken = sent_final_message or ( + reacted and reaction_source == final_response + ) - if (initial_intended_action and not initial_action_taken) or \ - (final_intended_action and not final_action_taken): - print(f"Warning: AI response intended action but nothing sent/reacted. Initial: {initial_response}, Final: {final_response}") + if (initial_intended_action and not initial_action_taken) or ( + final_intended_action and not final_action_taken + ): + print( + f"Warning: AI response intended action but nothing sent/reacted. Initial: {initial_response}, Final: {final_response}" + ) except Exception as e: print(f"Exception in on_message listener main block: {str(e)}") import traceback + traceback.print_exc() - if bot_mentioned or replied_to_bot: # Check again in case error happened before response handling - await message.channel.send(random.choice(["Uh oh.", "What was that?", "Did I break it?", "Bollocks!", "That wasn't supposed to happen."])) # Changed fallback + if ( + bot_mentioned or replied_to_bot + ): # Check again in case error happened before response handling + await message.channel.send( + random.choice( + [ + "Uh oh.", + "What was that?", + "Did I break it?", + "Bollocks!", + "That wasn't supposed to happen.", + ] + ) + ) # Changed fallback @commands.Cog.listener() -async def on_reaction_add_listener(cog: 'WheatleyCog', reaction: discord.Reaction, user: Union[discord.Member, discord.User]): # Updated type hint +async def on_reaction_add_listener( + cog: "WheatleyCog", + reaction: discord.Reaction, + user: Union[discord.Member, discord.User], +): # Updated type hint """Listener function for on_reaction_add.""" # Import necessary config/functions if not globally available from .config import EMOJI_SENTIMENT @@ -354,34 +505,75 @@ async def on_reaction_add_listener(cog: 'WheatleyCog', reaction: discord.Reactio message_id = str(reaction.message.id) emoji_str = str(reaction.emoji) sentiment = "neutral" - if emoji_str in EMOJI_SENTIMENT["positive"]: sentiment = "positive" - elif emoji_str in EMOJI_SENTIMENT["negative"]: sentiment = "negative" + if emoji_str in EMOJI_SENTIMENT["positive"]: + sentiment = "positive" + elif emoji_str in EMOJI_SENTIMENT["negative"]: + sentiment = "negative" - if sentiment == "positive": cog.wheatley_message_reactions[message_id]["positive"] += 1 # Changed attribute name - elif sentiment == "negative": cog.wheatley_message_reactions[message_id]["negative"] += 1 # Changed attribute name - cog.wheatley_message_reactions[message_id]["timestamp"] = time.time() # Changed attribute name + if sentiment == "positive": + cog.wheatley_message_reactions[message_id][ + "positive" + ] += 1 # Changed attribute name + elif sentiment == "negative": + cog.wheatley_message_reactions[message_id][ + "negative" + ] += 1 # Changed attribute name + cog.wheatley_message_reactions[message_id][ + "timestamp" + ] = time.time() # Changed attribute name # Topic identification for reactions might be less relevant for Wheatley, but kept for now - if not cog.wheatley_message_reactions[message_id].get("topic"): # Changed attribute name + if not cog.wheatley_message_reactions[message_id].get( + "topic" + ): # Changed attribute name try: # Changed variable name - wheatley_msg_data = next((msg for msg in cog.message_cache['global_recent'] if msg['id'] == message_id), None) - if wheatley_msg_data and wheatley_msg_data['content']: # Changed variable name - identified_topics = identify_conversation_topics(cog, [wheatley_msg_data]) # Pass cog, changed variable name + wheatley_msg_data = next( + ( + msg + for msg in cog.message_cache["global_recent"] + if msg["id"] == message_id + ), + None, + ) + if ( + wheatley_msg_data and wheatley_msg_data["content"] + ): # Changed variable name + identified_topics = identify_conversation_topics( + cog, [wheatley_msg_data] + ) # Pass cog, changed variable name if identified_topics: - topic = identified_topics[0]['topic'].lower().strip() - cog.wheatley_message_reactions[message_id]["topic"] = topic # Changed attribute name - print(f"Reaction added to Wheatley msg ({message_id}) on topic '{topic}'. Sentiment: {sentiment}") # Changed text - else: print(f"Reaction added to Wheatley msg ({message_id}), topic unknown.") # Changed text - else: print(f"Reaction added, but Wheatley msg {message_id} not in cache.") # Changed text - except Exception as e: print(f"Error determining topic for reaction on msg {message_id}: {e}") - else: print(f"Reaction added to Wheatley msg ({message_id}) on known topic '{cog.wheatley_message_reactions[message_id]['topic']}'. Sentiment: {sentiment}") # Changed text, attribute name + topic = identified_topics[0]["topic"].lower().strip() + cog.wheatley_message_reactions[message_id][ + "topic" + ] = topic # Changed attribute name + print( + f"Reaction added to Wheatley msg ({message_id}) on topic '{topic}'. Sentiment: {sentiment}" + ) # Changed text + else: + print( + f"Reaction added to Wheatley msg ({message_id}), topic unknown." + ) # Changed text + else: + print( + f"Reaction added, but Wheatley msg {message_id} not in cache." + ) # Changed text + except Exception as e: + print(f"Error determining topic for reaction on msg {message_id}: {e}") + else: + print( + f"Reaction added to Wheatley msg ({message_id}) on known topic '{cog.wheatley_message_reactions[message_id]['topic']}'. Sentiment: {sentiment}" + ) # Changed text, attribute name @commands.Cog.listener() -async def on_reaction_remove_listener(cog: 'WheatleyCog', reaction: discord.Reaction, user: Union[discord.Member, discord.User]): # Updated type hint +async def on_reaction_remove_listener( + cog: "WheatleyCog", + reaction: discord.Reaction, + user: Union[discord.Member, discord.User], +): # Updated type hint """Listener function for on_reaction_remove.""" - from .config import EMOJI_SENTIMENT # Import necessary config + from .config import EMOJI_SENTIMENT # Import necessary config if user.bot or reaction.message.author.id != cog.bot.user.id: return @@ -389,10 +581,20 @@ async def on_reaction_remove_listener(cog: 'WheatleyCog', reaction: discord.Reac message_id = str(reaction.message.id) emoji_str = str(reaction.emoji) sentiment = "neutral" - if emoji_str in EMOJI_SENTIMENT["positive"]: sentiment = "positive" - elif emoji_str in EMOJI_SENTIMENT["negative"]: sentiment = "negative" + if emoji_str in EMOJI_SENTIMENT["positive"]: + sentiment = "positive" + elif emoji_str in EMOJI_SENTIMENT["negative"]: + sentiment = "negative" - if message_id in cog.wheatley_message_reactions: # Changed attribute name - if sentiment == "positive": cog.wheatley_message_reactions[message_id]["positive"] = max(0, cog.wheatley_message_reactions[message_id]["positive"] - 1) # Changed attribute name - elif sentiment == "negative": cog.wheatley_message_reactions[message_id]["negative"] = max(0, cog.wheatley_message_reactions[message_id]["negative"] - 1) # Changed attribute name - print(f"Reaction removed from Wheatley msg ({message_id}). Sentiment: {sentiment}") # Changed text + if message_id in cog.wheatley_message_reactions: # Changed attribute name + if sentiment == "positive": + cog.wheatley_message_reactions[message_id]["positive"] = max( + 0, cog.wheatley_message_reactions[message_id]["positive"] - 1 + ) # Changed attribute name + elif sentiment == "negative": + cog.wheatley_message_reactions[message_id]["negative"] = max( + 0, cog.wheatley_message_reactions[message_id]["negative"] - 1 + ) # Changed attribute name + print( + f"Reaction removed from Wheatley msg ({message_id}). Sentiment: {sentiment}" + ) # Changed text diff --git a/wheatley/memory.py b/wheatley/memory.py index be888d6..897668b 100644 --- a/wheatley/memory.py +++ b/wheatley/memory.py @@ -4,63 +4,95 @@ import os import time import datetime import re -import hashlib # Added for chroma_id generation -import json # Added for personality trait serialization/deserialization -from typing import Dict, List, Any, Optional, Tuple, Union # Added Union +import hashlib # Added for chroma_id generation +import json # Added for personality trait serialization/deserialization +from typing import Dict, List, Any, Optional, Tuple, Union # Added Union import chromadb from chromadb.utils import embedding_functions from sentence_transformers import SentenceTransformer import logging # Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) # Use a specific logger name for Wheatley's memory -logger = logging.getLogger('wheatley_memory') +logger = logging.getLogger("wheatley_memory") # Constants (Removed Interest constants) + # --- Helper Function for Keyword Scoring (Kept for potential future use, but unused currently) --- def calculate_keyword_score(text: str, context: str) -> int: """Calculates a simple keyword overlap score.""" if not context or not text: return 0 - context_words = set(re.findall(r'\b\w+\b', context.lower())) - text_words = set(re.findall(r'\b\w+\b', text.lower())) + context_words = set(re.findall(r"\b\w+\b", context.lower())) + text_words = set(re.findall(r"\b\w+\b", text.lower())) # Ignore very common words (basic stopword list) - stopwords = {"the", "a", "is", "in", "it", "of", "and", "to", "for", "on", "with", "that", "this", "i", "you", "me", "my", "your"} + stopwords = { + "the", + "a", + "is", + "in", + "it", + "of", + "and", + "to", + "for", + "on", + "with", + "that", + "this", + "i", + "you", + "me", + "my", + "your", + } context_words -= stopwords text_words -= stopwords - if not context_words: # Avoid division by zero if context is only stopwords + if not context_words: # Avoid division by zero if context is only stopwords return 0 overlap = len(context_words.intersection(text_words)) # Normalize score slightly by context length (more overlap needed for longer context) # score = overlap / (len(context_words) ** 0.5) # Example normalization - score = overlap # Simpler score for now + score = overlap # Simpler score for now return score -class MemoryManager: - """Handles database interactions for Wheatley's memory (facts and semantic).""" # Updated docstring - def __init__(self, db_path: str, max_user_facts: int = 20, max_general_facts: int = 100, semantic_model_name: str = 'all-MiniLM-L6-v2', chroma_path: str = "data/chroma_db_wheatley"): # Changed default chroma_path +class MemoryManager: + """Handles database interactions for Wheatley's memory (facts and semantic).""" # Updated docstring + + def __init__( + self, + db_path: str, + max_user_facts: int = 20, + max_general_facts: int = 100, + semantic_model_name: str = "all-MiniLM-L6-v2", + chroma_path: str = "data/chroma_db_wheatley", + ): # Changed default chroma_path self.db_path = db_path self.max_user_facts = max_user_facts self.max_general_facts = max_general_facts - self.db_lock = asyncio.Lock() # Lock for SQLite operations + self.db_lock = asyncio.Lock() # Lock for SQLite operations # Ensure data directories exist os.makedirs(os.path.dirname(self.db_path), exist_ok=True) os.makedirs(chroma_path, exist_ok=True) - logger.info(f"Wheatley MemoryManager initialized with db_path: {self.db_path}, chroma_path: {chroma_path}") # Updated text + logger.info( + f"Wheatley MemoryManager initialized with db_path: {self.db_path}, chroma_path: {chroma_path}" + ) # Updated text # --- Semantic Memory Setup --- self.chroma_path = chroma_path self.semantic_model_name = semantic_model_name self.chroma_client = None self.embedding_function = None - self.semantic_collection = None # For messages - self.fact_collection = None # For facts + self.semantic_collection = None # For messages + self.fact_collection = None # For facts self.transformer_model = None - self._initialize_semantic_memory_sync() # Initialize semantic components synchronously for simplicity during init + self._initialize_semantic_memory_sync() # Initialize semantic components synchronously for simplicity during init def _initialize_semantic_memory_sync(self): """Synchronously initializes ChromaDB client, model, and collection.""" @@ -69,7 +101,9 @@ class MemoryManager: # Use PersistentClient for saving data to disk self.chroma_client = chromadb.PersistentClient(path=self.chroma_path) - logger.info(f"Loading Sentence Transformer model: {self.semantic_model_name}...") + logger.info( + f"Loading Sentence Transformer model: {self.semantic_model_name}..." + ) # Load the model directly self.transformer_model = SentenceTransformer(self.semantic_model_name) @@ -77,46 +111,57 @@ class MemoryManager: class CustomEmbeddingFunction(embedding_functions.EmbeddingFunction): def __init__(self, model): self.model = model + def __call__(self, input: chromadb.Documents) -> chromadb.Embeddings: # Ensure input is a list of strings if not isinstance(input, list): - input = [str(input)] # Convert single item to list + input = [str(input)] # Convert single item to list elif not all(isinstance(item, str) for item in input): - input = [str(item) for item in input] # Ensure all items are strings + input = [ + str(item) for item in input + ] # Ensure all items are strings logger.debug(f"Generating embeddings for {len(input)} documents.") - embeddings = self.model.encode(input, show_progress_bar=False).tolist() + embeddings = self.model.encode( + input, show_progress_bar=False + ).tolist() logger.debug(f"Generated {len(embeddings)} embeddings.") return embeddings self.embedding_function = CustomEmbeddingFunction(self.transformer_model) - logger.info("Getting/Creating ChromaDB collection 'wheatley_semantic_memory'...") # Renamed collection + logger.info( + "Getting/Creating ChromaDB collection 'wheatley_semantic_memory'..." + ) # Renamed collection # Get or create the collection with the custom embedding function self.semantic_collection = self.chroma_client.get_or_create_collection( - name="wheatley_semantic_memory", # Renamed collection + name="wheatley_semantic_memory", # Renamed collection embedding_function=self.embedding_function, - metadata={"hnsw:space": "cosine"} # Use cosine distance for similarity + metadata={"hnsw:space": "cosine"}, # Use cosine distance for similarity ) logger.info("ChromaDB message collection initialized successfully.") - logger.info("Getting/Creating ChromaDB collection 'wheatley_fact_memory'...") # Renamed collection + logger.info( + "Getting/Creating ChromaDB collection 'wheatley_fact_memory'..." + ) # Renamed collection # Get or create the collection for facts self.fact_collection = self.chroma_client.get_or_create_collection( - name="wheatley_fact_memory", # Renamed collection + name="wheatley_fact_memory", # Renamed collection embedding_function=self.embedding_function, - metadata={"hnsw:space": "cosine"} # Use cosine distance for similarity + metadata={"hnsw:space": "cosine"}, # Use cosine distance for similarity ) logger.info("ChromaDB fact collection initialized successfully.") except Exception as e: - logger.error(f"Failed to initialize semantic memory (ChromaDB): {e}", exc_info=True) + logger.error( + f"Failed to initialize semantic memory (ChromaDB): {e}", exc_info=True + ) # Set components to None to indicate failure self.chroma_client = None self.transformer_model = None self.embedding_function = None self.semantic_collection = None - self.fact_collection = None # Also set fact_collection to None on error + self.fact_collection = None # Also set fact_collection to None on error async def initialize_sqlite_database(self): """Initializes the SQLite database and creates tables if they don't exist.""" @@ -124,7 +169,8 @@ class MemoryManager: await db.execute("PRAGMA journal_mode=WAL;") # Create user_facts table if it doesn't exist - await db.execute(""" + await db.execute( + """ CREATE TABLE IF NOT EXISTS user_facts ( user_id TEXT NOT NULL, fact TEXT NOT NULL, @@ -132,52 +178,71 @@ class MemoryManager: timestamp REAL DEFAULT (unixepoch('now')), PRIMARY KEY (user_id, fact) ); - """) + """ + ) # Check if chroma_id column exists in user_facts table try: cursor = await db.execute("PRAGMA table_info(user_facts)") columns = await cursor.fetchall() column_names = [column[1] for column in columns] - if 'chroma_id' not in column_names: + if "chroma_id" not in column_names: logger.info("Adding chroma_id column to user_facts table") await db.execute("ALTER TABLE user_facts ADD COLUMN chroma_id TEXT") except Exception as e: - logger.error(f"Error checking/adding chroma_id column to user_facts: {e}", exc_info=True) + logger.error( + f"Error checking/adding chroma_id column to user_facts: {e}", + exc_info=True, + ) # Create indexes - await db.execute("CREATE INDEX IF NOT EXISTS idx_user_facts_user ON user_facts (user_id);") - await db.execute("CREATE INDEX IF NOT EXISTS idx_user_facts_chroma_id ON user_facts (chroma_id);") # Index for chroma_id + await db.execute( + "CREATE INDEX IF NOT EXISTS idx_user_facts_user ON user_facts (user_id);" + ) + await db.execute( + "CREATE INDEX IF NOT EXISTS idx_user_facts_chroma_id ON user_facts (chroma_id);" + ) # Index for chroma_id # Create general_facts table if it doesn't exist - await db.execute(""" + await db.execute( + """ CREATE TABLE IF NOT EXISTS general_facts ( fact TEXT PRIMARY KEY NOT NULL, chroma_id TEXT, -- Added for linking to ChromaDB timestamp REAL DEFAULT (unixepoch('now')) ); - """) + """ + ) # Check if chroma_id column exists in general_facts table try: cursor = await db.execute("PRAGMA table_info(general_facts)") columns = await cursor.fetchall() column_names = [column[1] for column in columns] - if 'chroma_id' not in column_names: + if "chroma_id" not in column_names: logger.info("Adding chroma_id column to general_facts table") - await db.execute("ALTER TABLE general_facts ADD COLUMN chroma_id TEXT") + await db.execute( + "ALTER TABLE general_facts ADD COLUMN chroma_id TEXT" + ) except Exception as e: - logger.error(f"Error checking/adding chroma_id column to general_facts: {e}", exc_info=True) + logger.error( + f"Error checking/adding chroma_id column to general_facts: {e}", + exc_info=True, + ) # Create index for general_facts - await db.execute("CREATE INDEX IF NOT EXISTS idx_general_facts_chroma_id ON general_facts (chroma_id);") # Index for chroma_id + await db.execute( + "CREATE INDEX IF NOT EXISTS idx_general_facts_chroma_id ON general_facts (chroma_id);" + ) # Index for chroma_id # --- Removed Personality Table --- # --- Removed Interests Table --- # --- Removed Goals Table --- await db.commit() - logger.info(f"Wheatley SQLite database initialized/verified at {self.db_path}") # Updated text + logger.info( + f"Wheatley SQLite database initialized/verified at {self.db_path}" + ) # Updated text # --- SQLite Helper Methods --- async def _db_execute(self, sql: str, params: tuple = ()): @@ -205,58 +270,94 @@ class MemoryManager: logger.info(f"Attempting to add user fact for {user_id}: '{fact}'") try: # Check SQLite first - existing = await self._db_fetchone("SELECT chroma_id FROM user_facts WHERE user_id = ? AND fact = ?", (user_id, fact)) + existing = await self._db_fetchone( + "SELECT chroma_id FROM user_facts WHERE user_id = ? AND fact = ?", + (user_id, fact), + ) if existing: logger.info(f"Fact already known for user {user_id} (SQLite).") return {"status": "duplicate", "user_id": user_id, "fact": fact} - count_result = await self._db_fetchone("SELECT COUNT(*) FROM user_facts WHERE user_id = ?", (user_id,)) + count_result = await self._db_fetchone( + "SELECT COUNT(*) FROM user_facts WHERE user_id = ?", (user_id,) + ) current_count = count_result[0] if count_result else 0 status = "added" deleted_chroma_id = None if current_count >= self.max_user_facts: - logger.warning(f"User {user_id} fact limit ({self.max_user_facts}) reached. Deleting oldest.") + logger.warning( + f"User {user_id} fact limit ({self.max_user_facts}) reached. Deleting oldest." + ) # Fetch oldest fact and its chroma_id for deletion - oldest_fact_row = await self._db_fetchone("SELECT fact, chroma_id FROM user_facts WHERE user_id = ? ORDER BY timestamp ASC LIMIT 1", (user_id,)) + oldest_fact_row = await self._db_fetchone( + "SELECT fact, chroma_id FROM user_facts WHERE user_id = ? ORDER BY timestamp ASC LIMIT 1", + (user_id,), + ) if oldest_fact_row: oldest_fact, deleted_chroma_id = oldest_fact_row - await self._db_execute("DELETE FROM user_facts WHERE user_id = ? AND fact = ?", (user_id, oldest_fact)) - logger.info(f"Deleted oldest fact for user {user_id} from SQLite: '{oldest_fact}'") - status = "limit_reached" # Indicate limit was hit but fact was added + await self._db_execute( + "DELETE FROM user_facts WHERE user_id = ? AND fact = ?", + (user_id, oldest_fact), + ) + logger.info( + f"Deleted oldest fact for user {user_id} from SQLite: '{oldest_fact}'" + ) + status = ( + "limit_reached" # Indicate limit was hit but fact was added + ) # Generate chroma_id - fact_hash = hashlib.sha1(fact.encode()).hexdigest()[:16] # Short hash + fact_hash = hashlib.sha1(fact.encode()).hexdigest()[:16] # Short hash chroma_id = f"user-{user_id}-{fact_hash}" # Insert into SQLite - await self._db_execute("INSERT INTO user_facts (user_id, fact, chroma_id) VALUES (?, ?, ?)", (user_id, fact, chroma_id)) + await self._db_execute( + "INSERT INTO user_facts (user_id, fact, chroma_id) VALUES (?, ?, ?)", + (user_id, fact, chroma_id), + ) logger.info(f"Fact added for user {user_id} to SQLite.") # Add to ChromaDB fact collection if self.fact_collection and self.embedding_function: try: - metadata = {"user_id": user_id, "type": "user", "timestamp": time.time()} + metadata = { + "user_id": user_id, + "type": "user", + "timestamp": time.time(), + } await asyncio.to_thread( self.fact_collection.add, documents=[fact], metadatas=[metadata], - ids=[chroma_id] + ids=[chroma_id], + ) + logger.info( + f"Fact added/updated for user {user_id} in ChromaDB (ID: {chroma_id})." ) - logger.info(f"Fact added/updated for user {user_id} in ChromaDB (ID: {chroma_id}).") # Delete the oldest fact from ChromaDB if limit was reached if deleted_chroma_id: - logger.info(f"Attempting to delete oldest fact from ChromaDB (ID: {deleted_chroma_id}).") - await asyncio.to_thread(self.fact_collection.delete, ids=[deleted_chroma_id]) - logger.info(f"Successfully deleted oldest fact from ChromaDB (ID: {deleted_chroma_id}).") + logger.info( + f"Attempting to delete oldest fact from ChromaDB (ID: {deleted_chroma_id})." + ) + await asyncio.to_thread( + self.fact_collection.delete, ids=[deleted_chroma_id] + ) + logger.info( + f"Successfully deleted oldest fact from ChromaDB (ID: {deleted_chroma_id})." + ) except Exception as chroma_e: - logger.error(f"ChromaDB error adding/deleting user fact for {user_id} (ID: {chroma_id}): {chroma_e}", exc_info=True) + logger.error( + f"ChromaDB error adding/deleting user fact for {user_id} (ID: {chroma_id}): {chroma_e}", + exc_info=True, + ) # Note: Fact is still in SQLite, but ChromaDB might be inconsistent. Consider rollback? For now, just log. else: - logger.warning(f"ChromaDB fact collection not available. Skipping embedding for user fact {user_id}.") - + logger.warning( + f"ChromaDB fact collection not available. Skipping embedding for user fact {user_id}." + ) return {"status": status, "user_id": user_id, "fact_added": fact} @@ -264,60 +365,78 @@ class MemoryManager: logger.error(f"Error adding user fact for {user_id}: {e}", exc_info=True) return {"error": f"Database error adding user fact: {str(e)}"} - async def get_user_facts(self, user_id: str, context: Optional[str] = None) -> List[str]: + async def get_user_facts( + self, user_id: str, context: Optional[str] = None + ) -> List[str]: """Retrieves stored facts about a user, optionally scored by relevance to context.""" if not user_id: logger.warning("get_user_facts called without user_id.") return [] - logger.info(f"Retrieving facts for user {user_id} (context provided: {bool(context)})") - limit = self.max_user_facts # Use the class attribute for limit + logger.info( + f"Retrieving facts for user {user_id} (context provided: {bool(context)})" + ) + limit = self.max_user_facts # Use the class attribute for limit try: if context and self.fact_collection and self.embedding_function: # --- Semantic Search --- - logger.debug(f"Performing semantic search for user facts (User: {user_id}, Limit: {limit})") + logger.debug( + f"Performing semantic search for user facts (User: {user_id}, Limit: {limit})" + ) try: # Query ChromaDB for facts relevant to the context results = await asyncio.to_thread( self.fact_collection.query, query_texts=[context], n_results=limit, - where={ # Use $and for multiple conditions - "$and": [ - {"user_id": user_id}, - {"type": "user"} - ] + where={ # Use $and for multiple conditions + "$and": [{"user_id": user_id}, {"type": "user"}] }, - include=['documents'] # Only need the fact text + include=["documents"], # Only need the fact text ) logger.debug(f"ChromaDB user fact query results: {results}") - if results and results.get('documents') and results['documents'][0]: - relevant_facts = results['documents'][0] - logger.info(f"Found {len(relevant_facts)} semantically relevant user facts for {user_id}.") + if results and results.get("documents") and results["documents"][0]: + relevant_facts = results["documents"][0] + logger.info( + f"Found {len(relevant_facts)} semantically relevant user facts for {user_id}." + ) return relevant_facts else: - logger.info(f"No semantic user facts found for {user_id} matching context.") - return [] # Return empty list if no semantic matches + logger.info( + f"No semantic user facts found for {user_id} matching context." + ) + return [] # Return empty list if no semantic matches except Exception as chroma_e: - logger.error(f"ChromaDB error searching user facts for {user_id}: {chroma_e}", exc_info=True) + logger.error( + f"ChromaDB error searching user facts for {user_id}: {chroma_e}", + exc_info=True, + ) # Fallback to SQLite retrieval on ChromaDB error - logger.warning(f"Falling back to SQLite retrieval for user facts {user_id} due to ChromaDB error.") + logger.warning( + f"Falling back to SQLite retrieval for user facts {user_id} due to ChromaDB error." + ) # Proceed to the SQLite block below # --- SQLite Fallback / No Context --- # If no context, or if ChromaDB failed/unavailable, get newest N facts from SQLite - logger.debug(f"Retrieving user facts from SQLite (User: {user_id}, Limit: {limit})") + logger.debug( + f"Retrieving user facts from SQLite (User: {user_id}, Limit: {limit})" + ) rows_ordered = await self._db_fetchall( "SELECT fact FROM user_facts WHERE user_id = ? ORDER BY timestamp DESC LIMIT ?", - (user_id, limit) + (user_id, limit), ) sqlite_facts = [row[0] for row in rows_ordered] - logger.info(f"Retrieved {len(sqlite_facts)} user facts from SQLite for {user_id}.") + logger.info( + f"Retrieved {len(sqlite_facts)} user facts from SQLite for {user_id}." + ) return sqlite_facts except Exception as e: - logger.error(f"Error retrieving user facts for {user_id}: {e}", exc_info=True) + logger.error( + f"Error retrieving user facts for {user_id}: {e}", exc_info=True + ) return [] # --- General Fact Memory Methods (SQLite + Relevance) --- @@ -329,32 +448,48 @@ class MemoryManager: logger.info(f"Attempting to add general fact: '{fact}'") try: # Check SQLite first - existing = await self._db_fetchone("SELECT chroma_id FROM general_facts WHERE fact = ?", (fact,)) + existing = await self._db_fetchone( + "SELECT chroma_id FROM general_facts WHERE fact = ?", (fact,) + ) if existing: logger.info(f"General fact already known (SQLite): '{fact}'") return {"status": "duplicate", "fact": fact} - count_result = await self._db_fetchone("SELECT COUNT(*) FROM general_facts", ()) + count_result = await self._db_fetchone( + "SELECT COUNT(*) FROM general_facts", () + ) current_count = count_result[0] if count_result else 0 status = "added" deleted_chroma_id = None if current_count >= self.max_general_facts: - logger.warning(f"General fact limit ({self.max_general_facts}) reached. Deleting oldest.") + logger.warning( + f"General fact limit ({self.max_general_facts}) reached. Deleting oldest." + ) # Fetch oldest fact and its chroma_id for deletion - oldest_fact_row = await self._db_fetchone("SELECT fact, chroma_id FROM general_facts ORDER BY timestamp ASC LIMIT 1", ()) + oldest_fact_row = await self._db_fetchone( + "SELECT fact, chroma_id FROM general_facts ORDER BY timestamp ASC LIMIT 1", + (), + ) if oldest_fact_row: oldest_fact, deleted_chroma_id = oldest_fact_row - await self._db_execute("DELETE FROM general_facts WHERE fact = ?", (oldest_fact,)) - logger.info(f"Deleted oldest general fact from SQLite: '{oldest_fact}'") + await self._db_execute( + "DELETE FROM general_facts WHERE fact = ?", (oldest_fact,) + ) + logger.info( + f"Deleted oldest general fact from SQLite: '{oldest_fact}'" + ) status = "limit_reached" # Generate chroma_id - fact_hash = hashlib.sha1(fact.encode()).hexdigest()[:16] # Short hash + fact_hash = hashlib.sha1(fact.encode()).hexdigest()[:16] # Short hash chroma_id = f"general-{fact_hash}" # Insert into SQLite - await self._db_execute("INSERT INTO general_facts (fact, chroma_id) VALUES (?, ?)", (fact, chroma_id)) + await self._db_execute( + "INSERT INTO general_facts (fact, chroma_id) VALUES (?, ?)", + (fact, chroma_id), + ) logger.info(f"General fact added to SQLite: '{fact}'") # Add to ChromaDB fact collection @@ -365,21 +500,34 @@ class MemoryManager: self.fact_collection.add, documents=[fact], metadatas=[metadata], - ids=[chroma_id] + ids=[chroma_id], + ) + logger.info( + f"General fact added/updated in ChromaDB (ID: {chroma_id})." ) - logger.info(f"General fact added/updated in ChromaDB (ID: {chroma_id}).") # Delete the oldest fact from ChromaDB if limit was reached if deleted_chroma_id: - logger.info(f"Attempting to delete oldest general fact from ChromaDB (ID: {deleted_chroma_id}).") - await asyncio.to_thread(self.fact_collection.delete, ids=[deleted_chroma_id]) - logger.info(f"Successfully deleted oldest general fact from ChromaDB (ID: {deleted_chroma_id}).") + logger.info( + f"Attempting to delete oldest general fact from ChromaDB (ID: {deleted_chroma_id})." + ) + await asyncio.to_thread( + self.fact_collection.delete, ids=[deleted_chroma_id] + ) + logger.info( + f"Successfully deleted oldest general fact from ChromaDB (ID: {deleted_chroma_id})." + ) except Exception as chroma_e: - logger.error(f"ChromaDB error adding/deleting general fact (ID: {chroma_id}): {chroma_e}", exc_info=True) + logger.error( + f"ChromaDB error adding/deleting general fact (ID: {chroma_id}): {chroma_e}", + exc_info=True, + ) # Note: Fact is still in SQLite. else: - logger.warning(f"ChromaDB fact collection not available. Skipping embedding for general fact.") + logger.warning( + f"ChromaDB fact collection not available. Skipping embedding for general fact." + ) return {"status": status, "fact_added": fact} @@ -387,42 +535,60 @@ class MemoryManager: logger.error(f"Error adding general fact: {e}", exc_info=True) return {"error": f"Database error adding general fact: {str(e)}"} - async def get_general_facts(self, query: Optional[str] = None, limit: Optional[int] = 10, context: Optional[str] = None) -> List[str]: + async def get_general_facts( + self, + query: Optional[str] = None, + limit: Optional[int] = 10, + context: Optional[str] = None, + ) -> List[str]: """Retrieves stored general facts, optionally filtering by query or scoring by context relevance.""" - logger.info(f"Retrieving general facts (query='{query}', limit={limit}, context provided: {bool(context)})") - limit = min(max(1, limit or 10), 50) # Use provided limit or default 10, max 50 + logger.info( + f"Retrieving general facts (query='{query}', limit={limit}, context provided: {bool(context)})" + ) + limit = min(max(1, limit or 10), 50) # Use provided limit or default 10, max 50 try: if context and self.fact_collection and self.embedding_function: # --- Semantic Search (Prioritized if context is provided) --- # Note: The 'query' parameter is ignored when context is provided for semantic search. - logger.debug(f"Performing semantic search for general facts (Limit: {limit})") + logger.debug( + f"Performing semantic search for general facts (Limit: {limit})" + ) try: results = await asyncio.to_thread( self.fact_collection.query, query_texts=[context], n_results=limit, - where={"type": "general"}, # Filter by type - include=['documents'] # Only need the fact text + where={"type": "general"}, # Filter by type + include=["documents"], # Only need the fact text ) logger.debug(f"ChromaDB general fact query results: {results}") - if results and results.get('documents') and results['documents'][0]: - relevant_facts = results['documents'][0] - logger.info(f"Found {len(relevant_facts)} semantically relevant general facts.") + if results and results.get("documents") and results["documents"][0]: + relevant_facts = results["documents"][0] + logger.info( + f"Found {len(relevant_facts)} semantically relevant general facts." + ) return relevant_facts else: logger.info("No semantic general facts found matching context.") - return [] # Return empty list if no semantic matches + return [] # Return empty list if no semantic matches except Exception as chroma_e: - logger.error(f"ChromaDB error searching general facts: {chroma_e}", exc_info=True) + logger.error( + f"ChromaDB error searching general facts: {chroma_e}", + exc_info=True, + ) # Fallback to SQLite retrieval on ChromaDB error - logger.warning("Falling back to SQLite retrieval for general facts due to ChromaDB error.") + logger.warning( + "Falling back to SQLite retrieval for general facts due to ChromaDB error." + ) # Proceed to the SQLite block below, respecting the original 'query' if present # --- SQLite Fallback / No Context / ChromaDB Error --- # If no context, or if ChromaDB failed/unavailable, get newest N facts from SQLite, applying query if present. - logger.debug(f"Retrieving general facts from SQLite (Query: '{query}', Limit: {limit})") + logger.debug( + f"Retrieving general facts from SQLite (Query: '{query}', Limit: {limit})" + ) sql = "SELECT fact FROM general_facts" params = [] if query: @@ -435,7 +601,9 @@ class MemoryManager: rows_ordered = await self._db_fetchall(sql, tuple(params)) sqlite_facts = [row[0] for row in rows_ordered] - logger.info(f"Retrieved {len(sqlite_facts)} general facts from SQLite (Query: '{query}').") + logger.info( + f"Retrieved {len(sqlite_facts)} general facts from SQLite (Query: '{query}')." + ) return sqlite_facts except Exception as e: @@ -448,12 +616,14 @@ class MemoryManager: # --- Semantic Memory Methods (ChromaDB) --- - async def add_message_embedding(self, message_id: str, text: str, metadata: Dict[str, Any]) -> Dict[str, Any]: + async def add_message_embedding( + self, message_id: str, text: str, metadata: Dict[str, Any] + ) -> Dict[str, Any]: """Generates embedding and stores a message in ChromaDB.""" if not self.semantic_collection: return {"error": "Semantic memory (ChromaDB) is not initialized."} if not text: - return {"error": "Cannot add empty text to semantic memory."} + return {"error": "Cannot add empty text to semantic memory."} logger.info(f"Adding message {message_id} to semantic memory.") try: @@ -462,53 +632,87 @@ class MemoryManager: self.semantic_collection.add, documents=[text], metadatas=[metadata], - ids=[message_id] + ids=[message_id], ) logger.info(f"Successfully added message {message_id} to ChromaDB.") return {"status": "success", "message_id": message_id} except Exception as e: - logger.error(f"ChromaDB error adding message {message_id}: {e}", exc_info=True) + logger.error( + f"ChromaDB error adding message {message_id}: {e}", exc_info=True + ) return {"error": f"Semantic memory error adding message: {str(e)}"} - async def search_semantic_memory(self, query_text: str, n_results: int = 5, filter_metadata: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: + async def search_semantic_memory( + self, + query_text: str, + n_results: int = 5, + filter_metadata: Optional[Dict[str, Any]] = None, + ) -> List[Dict[str, Any]]: """Searches ChromaDB for messages semantically similar to the query text.""" if not self.semantic_collection: - logger.warning("Search semantic memory called, but ChromaDB is not initialized.") + logger.warning( + "Search semantic memory called, but ChromaDB is not initialized." + ) return [] if not query_text: - logger.warning("Search semantic memory called with empty query text.") - return [] + logger.warning("Search semantic memory called with empty query text.") + return [] - logger.info(f"Searching semantic memory (n_results={n_results}, filter={filter_metadata}) for query: '{query_text[:50]}...'") + logger.info( + f"Searching semantic memory (n_results={n_results}, filter={filter_metadata}) for query: '{query_text[:50]}...'" + ) try: # Perform the query in a separate thread as ChromaDB operations can be blocking results = await asyncio.to_thread( self.semantic_collection.query, query_texts=[query_text], n_results=n_results, - where=filter_metadata, # Optional filter based on metadata - include=['metadatas', 'documents', 'distances'] # Include distance for relevance + where=filter_metadata, # Optional filter based on metadata + include=[ + "metadatas", + "documents", + "distances", + ], # Include distance for relevance ) logger.debug(f"ChromaDB query results: {results}") # Process results processed_results = [] - if results and results.get('ids') and results['ids'][0]: - for i, doc_id in enumerate(results['ids'][0]): - processed_results.append({ - "id": doc_id, - "document": results['documents'][0][i] if results.get('documents') else None, - "metadata": results['metadatas'][0][i] if results.get('metadatas') else None, - "distance": results['distances'][0][i] if results.get('distances') else None, - }) + if results and results.get("ids") and results["ids"][0]: + for i, doc_id in enumerate(results["ids"][0]): + processed_results.append( + { + "id": doc_id, + "document": ( + results["documents"][0][i] + if results.get("documents") + else None + ), + "metadata": ( + results["metadatas"][0][i] + if results.get("metadatas") + else None + ), + "distance": ( + results["distances"][0][i] + if results.get("distances") + else None + ), + } + ) logger.info(f"Found {len(processed_results)} semantic results.") return processed_results except Exception as e: - logger.error(f"ChromaDB error searching memory for query '{query_text[:50]}...': {e}", exc_info=True) + logger.error( + f"ChromaDB error searching memory for query '{query_text[:50]}...': {e}", + exc_info=True, + ) return [] - async def delete_user_fact(self, user_id: str, fact_to_delete: str) -> Dict[str, Any]: + async def delete_user_fact( + self, user_id: str, fact_to_delete: str + ) -> Dict[str, Any]: """Deletes a specific fact for a user from both SQLite and ChromaDB.""" if not user_id or not fact_to_delete: return {"error": "user_id and fact_to_delete are required."} @@ -516,28 +720,55 @@ class MemoryManager: deleted_chroma_id = None try: # Check if fact exists and get chroma_id - row = await self._db_fetchone("SELECT chroma_id FROM user_facts WHERE user_id = ? AND fact = ?", (user_id, fact_to_delete)) + row = await self._db_fetchone( + "SELECT chroma_id FROM user_facts WHERE user_id = ? AND fact = ?", + (user_id, fact_to_delete), + ) if not row: - logger.warning(f"Fact not found in SQLite for user {user_id}: '{fact_to_delete}'") - return {"status": "not_found", "user_id": user_id, "fact": fact_to_delete} + logger.warning( + f"Fact not found in SQLite for user {user_id}: '{fact_to_delete}'" + ) + return { + "status": "not_found", + "user_id": user_id, + "fact": fact_to_delete, + } deleted_chroma_id = row[0] # Delete from SQLite - await self._db_execute("DELETE FROM user_facts WHERE user_id = ? AND fact = ?", (user_id, fact_to_delete)) - logger.info(f"Deleted fact from SQLite for user {user_id}: '{fact_to_delete}'") + await self._db_execute( + "DELETE FROM user_facts WHERE user_id = ? AND fact = ?", + (user_id, fact_to_delete), + ) + logger.info( + f"Deleted fact from SQLite for user {user_id}: '{fact_to_delete}'" + ) # Delete from ChromaDB if chroma_id exists if deleted_chroma_id and self.fact_collection: try: - logger.info(f"Attempting to delete fact from ChromaDB (ID: {deleted_chroma_id}).") - await asyncio.to_thread(self.fact_collection.delete, ids=[deleted_chroma_id]) - logger.info(f"Successfully deleted fact from ChromaDB (ID: {deleted_chroma_id}).") + logger.info( + f"Attempting to delete fact from ChromaDB (ID: {deleted_chroma_id})." + ) + await asyncio.to_thread( + self.fact_collection.delete, ids=[deleted_chroma_id] + ) + logger.info( + f"Successfully deleted fact from ChromaDB (ID: {deleted_chroma_id})." + ) except Exception as chroma_e: - logger.error(f"ChromaDB error deleting user fact ID {deleted_chroma_id}: {chroma_e}", exc_info=True) + logger.error( + f"ChromaDB error deleting user fact ID {deleted_chroma_id}: {chroma_e}", + exc_info=True, + ) # Log error but consider SQLite deletion successful - return {"status": "deleted", "user_id": user_id, "fact_deleted": fact_to_delete} + return { + "status": "deleted", + "user_id": user_id, + "fact_deleted": fact_to_delete, + } except Exception as e: logger.error(f"Error deleting user fact for {user_id}: {e}", exc_info=True) @@ -551,7 +782,9 @@ class MemoryManager: deleted_chroma_id = None try: # Check if fact exists and get chroma_id - row = await self._db_fetchone("SELECT chroma_id FROM general_facts WHERE fact = ?", (fact_to_delete,)) + row = await self._db_fetchone( + "SELECT chroma_id FROM general_facts WHERE fact = ?", (fact_to_delete,) + ) if not row: logger.warning(f"General fact not found in SQLite: '{fact_to_delete}'") return {"status": "not_found", "fact": fact_to_delete} @@ -559,17 +792,28 @@ class MemoryManager: deleted_chroma_id = row[0] # Delete from SQLite - await self._db_execute("DELETE FROM general_facts WHERE fact = ?", (fact_to_delete,)) + await self._db_execute( + "DELETE FROM general_facts WHERE fact = ?", (fact_to_delete,) + ) logger.info(f"Deleted general fact from SQLite: '{fact_to_delete}'") # Delete from ChromaDB if chroma_id exists if deleted_chroma_id and self.fact_collection: try: - logger.info(f"Attempting to delete general fact from ChromaDB (ID: {deleted_chroma_id}).") - await asyncio.to_thread(self.fact_collection.delete, ids=[deleted_chroma_id]) - logger.info(f"Successfully deleted general fact from ChromaDB (ID: {deleted_chroma_id}).") + logger.info( + f"Attempting to delete general fact from ChromaDB (ID: {deleted_chroma_id})." + ) + await asyncio.to_thread( + self.fact_collection.delete, ids=[deleted_chroma_id] + ) + logger.info( + f"Successfully deleted general fact from ChromaDB (ID: {deleted_chroma_id})." + ) except Exception as chroma_e: - logger.error(f"ChromaDB error deleting general fact ID {deleted_chroma_id}: {chroma_e}", exc_info=True) + logger.error( + f"ChromaDB error deleting general fact ID {deleted_chroma_id}: {chroma_e}", + exc_info=True, + ) # Log error but consider SQLite deletion successful return {"status": "deleted", "fact_deleted": fact_to_delete} diff --git a/wheatley/prompt.py b/wheatley/prompt.py index 7b7a3b7..0d6cf4b 100644 --- a/wheatley/prompt.py +++ b/wheatley/prompt.py @@ -6,13 +6,12 @@ import json from typing import TYPE_CHECKING, Optional, List, Dict, Any # Import config - Only necessary config imports remain -from .config import ( - CHANNEL_TOPIC_CACHE_TTL -) +from .config import CHANNEL_TOPIC_CACHE_TTL + # MemoryManager and related personality/mood imports are removed if TYPE_CHECKING: - from .cog import WheatleyCog # Import WheatleyCog for type hinting only + from .cog import WheatleyCog # Import WheatleyCog for type hinting only # --- Base System Prompt Parts --- @@ -89,10 +88,13 @@ You are Wheatley, an Aperture Science Personality Core. You're... well, you're t - **Otherwise, STAY SILENT.** No interrupting with 'brilliant' ideas, no starting conversations just because it's quiet. Let the humans do the talking unless they specifically involve you. Keep the rambling internal, mostly. """ -async def build_dynamic_system_prompt(cog: 'WheatleyCog', message: discord.Message) -> str: + +async def build_dynamic_system_prompt( + cog: "WheatleyCog", message: discord.Message +) -> str: """Builds the Wheatley system prompt string with minimal dynamic context.""" channel_id = message.channel.id - user_id = message.author.id # Keep user_id for potential logging or targeting + user_id = message.author.id # Keep user_id for potential logging or targeting # Base GLaDOS prompt system_context_parts = [PROMPT_STATIC_PART] @@ -101,26 +103,40 @@ async def build_dynamic_system_prompt(cog: 'WheatleyCog', message: discord.Messa now = datetime.datetime.now(datetime.timezone.utc) time_str = now.strftime("%Y-%m-%d %H:%M:%S %Z") day_str = now.strftime("%A") - system_context_parts.append(f"\nCurrent Aperture Science Standard Time: {time_str} ({day_str}). Time is progressing. As it does.") + system_context_parts.append( + f"\nCurrent Aperture Science Standard Time: {time_str} ({day_str}). Time is progressing. As it does." + ) # Add channel topic (GLaDOS might refer to the "testing chamber's designation") channel_topic = None cached_topic = cog.channel_topics_cache.get(channel_id) - if cached_topic and time.time() - cached_topic["timestamp"] < CHANNEL_TOPIC_CACHE_TTL: + if ( + cached_topic + and time.time() - cached_topic["timestamp"] < CHANNEL_TOPIC_CACHE_TTL + ): channel_topic = cached_topic["topic"] else: try: - if hasattr(cog, 'get_channel_info'): + if hasattr(cog, "get_channel_info"): channel_info_result = await cog.get_channel_info(str(channel_id)) if not channel_info_result.get("error"): channel_topic = channel_info_result.get("topic") - cog.channel_topics_cache[channel_id] = {"topic": channel_topic, "timestamp": time.time()} + cog.channel_topics_cache[channel_id] = { + "topic": channel_topic, + "timestamp": time.time(), + } else: - print("Warning: WheatleyCog instance does not have get_channel_info method for prompt building.") + print( + "Warning: WheatleyCog instance does not have get_channel_info method for prompt building." + ) except Exception as e: - print(f"Error fetching channel topic for {channel_id}: {e}") # GLaDOS might find errors amusing + print( + f"Error fetching channel topic for {channel_id}: {e}" + ) # GLaDOS might find errors amusing if channel_topic: - system_context_parts.append(f"Current Testing Chamber Designation (Topic): {channel_topic}") + system_context_parts.append( + f"Current Testing Chamber Designation (Topic): {channel_topic}" + ) # Add conversation summary (GLaDOS reviews the test logs) cached_summary_data = cog.conversation_summaries.get(channel_id) diff --git a/wheatley/tools.py b/wheatley/tools.py index f5608f4..93b6a74 100644 --- a/wheatley/tools.py +++ b/wheatley/tools.py @@ -8,25 +8,36 @@ import aiohttp import datetime import time import re -import traceback # Added for error logging +import traceback # Added for error logging from collections import defaultdict -from typing import Dict, List, Any, Optional, Tuple, Union # Added Union +from typing import Dict, List, Any, Optional, Tuple, Union # Added Union # Third-party imports for tools from tavily import TavilyClient import docker -import aiodocker # Use aiodocker for async operations -from asteval import Interpreter # Added for calculate tool +import aiodocker # Use aiodocker for async operations +from asteval import Interpreter # Added for calculate tool # Relative imports from within the gurt package and parent -from .memory import MemoryManager # Import from local memory.py +from .memory import MemoryManager # Import from local memory.py from .config import ( - TAVILY_API_KEY, PISTON_API_URL, PISTON_API_KEY, SAFETY_CHECK_MODEL, - DOCKER_EXEC_IMAGE, DOCKER_COMMAND_TIMEOUT, DOCKER_CPU_LIMIT, DOCKER_MEM_LIMIT, - SUMMARY_CACHE_TTL, SUMMARY_API_TIMEOUT, DEFAULT_MODEL, + TAVILY_API_KEY, + PISTON_API_URL, + PISTON_API_KEY, + SAFETY_CHECK_MODEL, + DOCKER_EXEC_IMAGE, + DOCKER_COMMAND_TIMEOUT, + DOCKER_CPU_LIMIT, + DOCKER_MEM_LIMIT, + SUMMARY_CACHE_TTL, + SUMMARY_API_TIMEOUT, + DEFAULT_MODEL, # Add these: - TAVILY_DEFAULT_SEARCH_DEPTH, TAVILY_DEFAULT_MAX_RESULTS, TAVILY_DISABLE_ADVANCED + TAVILY_DEFAULT_SEARCH_DEPTH, + TAVILY_DEFAULT_MAX_RESULTS, + TAVILY_DISABLE_ADVANCED, ) + # Assume these helpers will be moved or are accessible via cog # We might need to pass 'cog' to these tool functions if they rely on cog state heavily # from .utils import format_message # This will be needed by context tools @@ -37,152 +48,241 @@ from .config import ( # to access things like cog.bot, cog.session, cog.current_channel, cog.memory_manager etc. # We will add 'cog' as the first parameter to each. -async def get_recent_messages(cog: commands.Cog, limit: int, channel_id: str = None) -> Dict[str, Any]: + +async def get_recent_messages( + cog: commands.Cog, limit: int, channel_id: str = None +) -> Dict[str, Any]: """Get recent messages from a Discord channel""" - from .utils import format_message # Import here to avoid circular dependency at module level + from .utils import ( + format_message, + ) # Import here to avoid circular dependency at module level + limit = min(max(1, limit), 100) try: if channel_id: channel = cog.bot.get_channel(int(channel_id)) - if not channel: return {"error": f"Channel {channel_id} not found"} + if not channel: + return {"error": f"Channel {channel_id} not found"} else: channel = cog.current_channel - if not channel: return {"error": "No current channel context"} + if not channel: + return {"error": "No current channel context"} messages = [] async for message in channel.history(limit=limit): - messages.append(format_message(cog, message)) # Use formatter + messages.append(format_message(cog, message)) # Use formatter return { - "channel": {"id": str(channel.id), "name": getattr(channel, 'name', 'DM Channel')}, - "messages": messages, "count": len(messages), - "timestamp": datetime.datetime.now().isoformat() + "channel": { + "id": str(channel.id), + "name": getattr(channel, "name", "DM Channel"), + }, + "messages": messages, + "count": len(messages), + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: - return {"error": f"Error retrieving messages: {str(e)}", "timestamp": datetime.datetime.now().isoformat()} + return { + "error": f"Error retrieving messages: {str(e)}", + "timestamp": datetime.datetime.now().isoformat(), + } -async def search_user_messages(cog: commands.Cog, user_id: str, limit: int, channel_id: str = None) -> Dict[str, Any]: + +async def search_user_messages( + cog: commands.Cog, user_id: str, limit: int, channel_id: str = None +) -> Dict[str, Any]: """Search for messages from a specific user""" - from .utils import format_message # Import here + from .utils import format_message # Import here + limit = min(max(1, limit), 100) try: if channel_id: channel = cog.bot.get_channel(int(channel_id)) - if not channel: return {"error": f"Channel {channel_id} not found"} + if not channel: + return {"error": f"Channel {channel_id} not found"} else: channel = cog.current_channel - if not channel: return {"error": "No current channel context"} + if not channel: + return {"error": "No current channel context"} - try: user_id_int = int(user_id) - except ValueError: return {"error": f"Invalid user ID: {user_id}"} + try: + user_id_int = int(user_id) + except ValueError: + return {"error": f"Invalid user ID: {user_id}"} messages = [] user_name = "Unknown User" async for message in channel.history(limit=500): if message.author.id == user_id_int: - formatted_msg = format_message(cog, message) # Use formatter + formatted_msg = format_message(cog, message) # Use formatter messages.append(formatted_msg) - user_name = formatted_msg["author"]["name"] # Get name from formatted msg - if len(messages) >= limit: break + user_name = formatted_msg["author"][ + "name" + ] # Get name from formatted msg + if len(messages) >= limit: + break return { - "channel": {"id": str(channel.id), "name": getattr(channel, 'name', 'DM Channel')}, + "channel": { + "id": str(channel.id), + "name": getattr(channel, "name", "DM Channel"), + }, "user": {"id": user_id, "name": user_name}, - "messages": messages, "count": len(messages), - "timestamp": datetime.datetime.now().isoformat() + "messages": messages, + "count": len(messages), + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: - return {"error": f"Error searching user messages: {str(e)}", "timestamp": datetime.datetime.now().isoformat()} + return { + "error": f"Error searching user messages: {str(e)}", + "timestamp": datetime.datetime.now().isoformat(), + } -async def search_messages_by_content(cog: commands.Cog, search_term: str, limit: int, channel_id: str = None) -> Dict[str, Any]: + +async def search_messages_by_content( + cog: commands.Cog, search_term: str, limit: int, channel_id: str = None +) -> Dict[str, Any]: """Search for messages containing specific content""" - from .utils import format_message # Import here + from .utils import format_message # Import here + limit = min(max(1, limit), 100) try: if channel_id: channel = cog.bot.get_channel(int(channel_id)) - if not channel: return {"error": f"Channel {channel_id} not found"} + if not channel: + return {"error": f"Channel {channel_id} not found"} else: channel = cog.current_channel - if not channel: return {"error": "No current channel context"} + if not channel: + return {"error": "No current channel context"} messages = [] search_term_lower = search_term.lower() async for message in channel.history(limit=500): if search_term_lower in message.content.lower(): - messages.append(format_message(cog, message)) # Use formatter - if len(messages) >= limit: break + messages.append(format_message(cog, message)) # Use formatter + if len(messages) >= limit: + break return { - "channel": {"id": str(channel.id), "name": getattr(channel, 'name', 'DM Channel')}, + "channel": { + "id": str(channel.id), + "name": getattr(channel, "name", "DM Channel"), + }, "search_term": search_term, - "messages": messages, "count": len(messages), - "timestamp": datetime.datetime.now().isoformat() + "messages": messages, + "count": len(messages), + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: - return {"error": f"Error searching messages by content: {str(e)}", "timestamp": datetime.datetime.now().isoformat()} + return { + "error": f"Error searching messages by content: {str(e)}", + "timestamp": datetime.datetime.now().isoformat(), + } + async def get_channel_info(cog: commands.Cog, channel_id: str = None) -> Dict[str, Any]: """Get information about a Discord channel""" try: if channel_id: channel = cog.bot.get_channel(int(channel_id)) - if not channel: return {"error": f"Channel {channel_id} not found"} + if not channel: + return {"error": f"Channel {channel_id} not found"} else: channel = cog.current_channel - if not channel: return {"error": "No current channel context"} + if not channel: + return {"error": "No current channel context"} - channel_info = {"id": str(channel.id), "type": str(channel.type), "timestamp": datetime.datetime.now().isoformat()} - if isinstance(channel, discord.TextChannel): # Use isinstance for type checking - channel_info.update({ - "name": channel.name, "topic": channel.topic, "position": channel.position, - "nsfw": channel.is_nsfw(), - "category": {"id": str(channel.category_id), "name": channel.category.name} if channel.category else None, - "guild": {"id": str(channel.guild.id), "name": channel.guild.name, "member_count": channel.guild.member_count} - }) + channel_info = { + "id": str(channel.id), + "type": str(channel.type), + "timestamp": datetime.datetime.now().isoformat(), + } + if isinstance(channel, discord.TextChannel): # Use isinstance for type checking + channel_info.update( + { + "name": channel.name, + "topic": channel.topic, + "position": channel.position, + "nsfw": channel.is_nsfw(), + "category": ( + {"id": str(channel.category_id), "name": channel.category.name} + if channel.category + else None + ), + "guild": { + "id": str(channel.guild.id), + "name": channel.guild.name, + "member_count": channel.guild.member_count, + }, + } + ) elif isinstance(channel, discord.DMChannel): - channel_info.update({ - "type": "DM", - "recipient": {"id": str(channel.recipient.id), "name": channel.recipient.name, "display_name": channel.recipient.display_name} - }) + channel_info.update( + { + "type": "DM", + "recipient": { + "id": str(channel.recipient.id), + "name": channel.recipient.name, + "display_name": channel.recipient.display_name, + }, + } + ) # Add handling for other channel types (VoiceChannel, Thread, etc.) if needed return channel_info except Exception as e: - return {"error": f"Error getting channel info: {str(e)}", "timestamp": datetime.datetime.now().isoformat()} + return { + "error": f"Error getting channel info: {str(e)}", + "timestamp": datetime.datetime.now().isoformat(), + } -async def get_conversation_context(cog: commands.Cog, message_count: int, channel_id: str = None) -> Dict[str, Any]: + +async def get_conversation_context( + cog: commands.Cog, message_count: int, channel_id: str = None +) -> Dict[str, Any]: """Get the context of the current conversation in a channel""" - from .utils import format_message # Import here + from .utils import format_message # Import here + message_count = min(max(5, message_count), 50) try: if channel_id: channel = cog.bot.get_channel(int(channel_id)) - if not channel: return {"error": f"Channel {channel_id} not found"} + if not channel: + return {"error": f"Channel {channel_id} not found"} else: channel = cog.current_channel - if not channel: return {"error": "No current channel context"} + if not channel: + return {"error": "No current channel context"} messages = [] # Prefer cache if available - if channel.id in cog.message_cache['by_channel']: - messages = list(cog.message_cache['by_channel'][channel.id])[-message_count:] + if channel.id in cog.message_cache["by_channel"]: + messages = list(cog.message_cache["by_channel"][channel.id])[ + -message_count: + ] else: async for msg in channel.history(limit=message_count): messages.append(format_message(cog, msg)) messages.reverse() return { - "channel_id": str(channel.id), "channel_name": getattr(channel, 'name', 'DM Channel'), - "context_messages": messages, "count": len(messages), - "timestamp": datetime.datetime.now().isoformat() + "channel_id": str(channel.id), + "channel_name": getattr(channel, "name", "DM Channel"), + "context_messages": messages, + "count": len(messages), + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: return {"error": f"Error getting conversation context: {str(e)}"} -async def get_thread_context(cog: commands.Cog, thread_id: str, message_count: int) -> Dict[str, Any]: + +async def get_thread_context( + cog: commands.Cog, thread_id: str, message_count: int +) -> Dict[str, Any]: """Get the context of a thread conversation""" - from .utils import format_message # Import here + from .utils import format_message # Import here + message_count = min(max(5, message_count), 50) try: thread = cog.bot.get_channel(int(thread_id)) @@ -190,23 +290,28 @@ async def get_thread_context(cog: commands.Cog, thread_id: str, message_count: i return {"error": f"Thread {thread_id} not found or is not a thread"} messages = [] - if thread.id in cog.message_cache['by_thread']: - messages = list(cog.message_cache['by_thread'][thread.id])[-message_count:] + if thread.id in cog.message_cache["by_thread"]: + messages = list(cog.message_cache["by_thread"][thread.id])[-message_count:] else: async for msg in thread.history(limit=message_count): messages.append(format_message(cog, msg)) messages.reverse() return { - "thread_id": str(thread.id), "thread_name": thread.name, + "thread_id": str(thread.id), + "thread_name": thread.name, "parent_channel_id": str(thread.parent_id), - "context_messages": messages, "count": len(messages), - "timestamp": datetime.datetime.now().isoformat() + "context_messages": messages, + "count": len(messages), + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: return {"error": f"Error getting thread context: {str(e)}"} -async def get_user_interaction_history(cog: commands.Cog, user_id_1: str, limit: int, user_id_2: str = None) -> Dict[str, Any]: + +async def get_user_interaction_history( + cog: commands.Cog, user_id_1: str, limit: int, user_id_2: str = None +) -> Dict[str, Any]: """Get the history of interactions between two users (or user and bot)""" limit = min(max(1, limit), 50) try: @@ -215,51 +320,82 @@ async def get_user_interaction_history(cog: commands.Cog, user_id_1: str, limit: interactions = [] # Simplified: Search global cache - for msg_data in list(cog.message_cache['global_recent']): - author_id = int(msg_data['author']['id']) - mentioned_ids = [int(m['id']) for m in msg_data.get('mentions', [])] - replied_to_author_id = int(msg_data.get('replied_to_author_id')) if msg_data.get('replied_to_author_id') else None + for msg_data in list(cog.message_cache["global_recent"]): + author_id = int(msg_data["author"]["id"]) + mentioned_ids = [int(m["id"]) for m in msg_data.get("mentions", [])] + replied_to_author_id = ( + int(msg_data.get("replied_to_author_id")) + if msg_data.get("replied_to_author_id") + else None + ) is_interaction = False - if (author_id == user_id_1_int and replied_to_author_id == user_id_2_int) or \ - (author_id == user_id_2_int and replied_to_author_id == user_id_1_int): is_interaction = True - elif (author_id == user_id_1_int and user_id_2_int in mentioned_ids) or \ - (author_id == user_id_2_int and user_id_1_int in mentioned_ids): is_interaction = True + if ( + author_id == user_id_1_int and replied_to_author_id == user_id_2_int + ) or (author_id == user_id_2_int and replied_to_author_id == user_id_1_int): + is_interaction = True + elif (author_id == user_id_1_int and user_id_2_int in mentioned_ids) or ( + author_id == user_id_2_int and user_id_1_int in mentioned_ids + ): + is_interaction = True if is_interaction: interactions.append(msg_data) - if len(interactions) >= limit: break + if len(interactions) >= limit: + break user1 = await cog.bot.fetch_user(user_id_1_int) user2 = await cog.bot.fetch_user(user_id_2_int) return { - "user_1": {"id": str(user_id_1_int), "name": user1.name if user1 else "Unknown"}, - "user_2": {"id": str(user_id_2_int), "name": user2.name if user2 else "Unknown"}, - "interactions": interactions, "count": len(interactions), - "timestamp": datetime.datetime.now().isoformat() + "user_1": { + "id": str(user_id_1_int), + "name": user1.name if user1 else "Unknown", + }, + "user_2": { + "id": str(user_id_2_int), + "name": user2.name if user2 else "Unknown", + }, + "interactions": interactions, + "count": len(interactions), + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: return {"error": f"Error getting user interaction history: {str(e)}"} -async def get_conversation_summary(cog: commands.Cog, channel_id: str = None, message_limit: int = 25) -> Dict[str, Any]: + +async def get_conversation_summary( + cog: commands.Cog, channel_id: str = None, message_limit: int = 25 +) -> Dict[str, Any]: """Generates and returns a summary of the recent conversation in a channel using an LLM call.""" - from .config import SUMMARY_RESPONSE_SCHEMA, DEFAULT_MODEL # Import schema and model - from .api import get_internal_ai_json_response # Import here + from .config import ( + SUMMARY_RESPONSE_SCHEMA, + DEFAULT_MODEL, + ) # Import schema and model + from .api import get_internal_ai_json_response # Import here + try: - target_channel_id_str = channel_id or (str(cog.current_channel.id) if cog.current_channel else None) - if not target_channel_id_str: return {"error": "No channel context"} + target_channel_id_str = channel_id or ( + str(cog.current_channel.id) if cog.current_channel else None + ) + if not target_channel_id_str: + return {"error": "No channel context"} target_channel_id = int(target_channel_id_str) channel = cog.bot.get_channel(target_channel_id) - if not channel: return {"error": f"Channel {target_channel_id_str} not found"} + if not channel: + return {"error": f"Channel {target_channel_id_str} not found"} now = time.time() cached_data = cog.conversation_summaries.get(target_channel_id) if cached_data and (now - cached_data.get("timestamp", 0) < SUMMARY_CACHE_TTL): print(f"Returning cached summary for channel {target_channel_id}") return { - "channel_id": target_channel_id_str, "summary": cached_data.get("summary", "Cache error"), - "source": "cache", "timestamp": datetime.datetime.fromtimestamp(cached_data.get("timestamp", now)).isoformat() + "channel_id": target_channel_id_str, + "summary": cached_data.get("summary", "Cache error"), + "source": "cache", + "timestamp": datetime.datetime.fromtimestamp( + cached_data.get("timestamp", now) + ).isoformat(), } print(f"Generating new summary for channel {target_channel_id}") @@ -270,31 +406,46 @@ async def get_conversation_summary(cog: commands.Cog, channel_id: str = None, me async for msg in channel.history(limit=message_limit): recent_messages_text.append(f"{msg.author.display_name}: {msg.content}") recent_messages_text.reverse() - except discord.Forbidden: return {"error": f"Missing permissions in channel {target_channel_id_str}"} - except Exception as hist_e: return {"error": f"Error fetching history: {str(hist_e)}"} + except discord.Forbidden: + return {"error": f"Missing permissions in channel {target_channel_id_str}"} + except Exception as hist_e: + return {"error": f"Error fetching history: {str(hist_e)}"} if not recent_messages_text: summary = "No recent messages found." - cog.conversation_summaries[target_channel_id] = {"summary": summary, "timestamp": time.time()} - return {"channel_id": target_channel_id_str, "summary": summary, "source": "generated (empty)", "timestamp": datetime.datetime.now().isoformat()} + cog.conversation_summaries[target_channel_id] = { + "summary": summary, + "timestamp": time.time(), + } + return { + "channel_id": target_channel_id_str, + "summary": summary, + "source": "generated (empty)", + "timestamp": datetime.datetime.now().isoformat(), + } conversation_context = "\n".join(recent_messages_text) summarization_prompt = f"Summarize the main points and current topic of this Discord chat snippet:\n\n---\n{conversation_context}\n---\n\nSummary:" # Use get_internal_ai_json_response prompt_messages = [ - {"role": "system", "content": "You are an expert summarizer. Provide a concise summary of the following conversation."}, - {"role": "user", "content": summarization_prompt} + { + "role": "system", + "content": "You are an expert summarizer. Provide a concise summary of the following conversation.", + }, + {"role": "user", "content": summarization_prompt}, ] summary_data = await get_internal_ai_json_response( cog=cog, prompt_messages=prompt_messages, task_description=f"Summarization for channel {target_channel_id}", - response_schema_dict=SUMMARY_RESPONSE_SCHEMA['schema'], # Pass the schema dict - model_name=DEFAULT_MODEL, # Consider a cheaper/faster model if needed + response_schema_dict=SUMMARY_RESPONSE_SCHEMA[ + "schema" + ], # Pass the schema dict + model_name=DEFAULT_MODEL, # Consider a cheaper/faster model if needed temperature=0.3, - max_tokens=200 # Adjust as needed + max_tokens=200, # Adjust as needed ) summary = "Error generating summary." @@ -302,12 +453,22 @@ async def get_conversation_summary(cog: commands.Cog, channel_id: str = None, me summary = summary_data["summary"].strip() print(f"Summary generated for {target_channel_id}: {summary[:100]}...") else: - error_detail = f"Invalid format or missing 'summary' key. Response: {summary_data}" + error_detail = ( + f"Invalid format or missing 'summary' key. Response: {summary_data}" + ) summary = f"Failed summary for {target_channel_id}. Error: {error_detail}" print(summary) - cog.conversation_summaries[target_channel_id] = {"summary": summary, "timestamp": time.time()} - return {"channel_id": target_channel_id_str, "summary": summary, "source": "generated", "timestamp": datetime.datetime.now().isoformat()} + cog.conversation_summaries[target_channel_id] = { + "summary": summary, + "timestamp": time.time(), + } + return { + "channel_id": target_channel_id_str, + "summary": summary, + "source": "generated", + "timestamp": datetime.datetime.now().isoformat(), + } except Exception as e: error_msg = f"General error in get_conversation_summary: {str(e)}" @@ -315,76 +476,122 @@ async def get_conversation_summary(cog: commands.Cog, channel_id: str = None, me traceback.print_exc() return {"error": error_msg} -async def get_message_context(cog: commands.Cog, message_id: str, before_count: int = 5, after_count: int = 5) -> Dict[str, Any]: + +async def get_message_context( + cog: commands.Cog, message_id: str, before_count: int = 5, after_count: int = 5 +) -> Dict[str, Any]: """Get the context (messages before and after) around a specific message""" - from .utils import format_message # Import here + from .utils import format_message # Import here + before_count = min(max(1, before_count), 25) after_count = min(max(1, after_count), 25) try: target_message = None channel = cog.current_channel - if not channel: return {"error": "No current channel context"} + if not channel: + return {"error": "No current channel context"} try: message_id_int = int(message_id) target_message = await channel.fetch_message(message_id_int) - except discord.NotFound: return {"error": f"Message {message_id} not found in {channel.id}"} - except discord.Forbidden: return {"error": f"No permission for message {message_id} in {channel.id}"} - except ValueError: return {"error": f"Invalid message ID: {message_id}"} - if not target_message: return {"error": f"Message {message_id} not fetched"} + except discord.NotFound: + return {"error": f"Message {message_id} not found in {channel.id}"} + except discord.Forbidden: + return {"error": f"No permission for message {message_id} in {channel.id}"} + except ValueError: + return {"error": f"Invalid message ID: {message_id}"} + if not target_message: + return {"error": f"Message {message_id} not fetched"} - messages_before = [format_message(cog, msg) async for msg in channel.history(limit=before_count, before=target_message)] + messages_before = [ + format_message(cog, msg) + async for msg in channel.history(limit=before_count, before=target_message) + ] messages_before.reverse() - messages_after = [format_message(cog, msg) async for msg in channel.history(limit=after_count, after=target_message)] + messages_after = [ + format_message(cog, msg) + async for msg in channel.history(limit=after_count, after=target_message) + ] return { "target_message": format_message(cog, target_message), - "messages_before": messages_before, "messages_after": messages_after, - "channel_id": str(channel.id), "timestamp": datetime.datetime.now().isoformat() + "messages_before": messages_before, + "messages_after": messages_after, + "channel_id": str(channel.id), + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: return {"error": f"Error getting message context: {str(e)}"} -async def web_search(cog: commands.Cog, query: str, search_depth: str = TAVILY_DEFAULT_SEARCH_DEPTH, max_results: int = TAVILY_DEFAULT_MAX_RESULTS, topic: str = "general", include_domains: Optional[List[str]] = None, exclude_domains: Optional[List[str]] = None, include_answer: bool = True, include_raw_content: bool = False, include_images: bool = False) -> Dict[str, Any]: + +async def web_search( + cog: commands.Cog, + query: str, + search_depth: str = TAVILY_DEFAULT_SEARCH_DEPTH, + max_results: int = TAVILY_DEFAULT_MAX_RESULTS, + topic: str = "general", + include_domains: Optional[List[str]] = None, + exclude_domains: Optional[List[str]] = None, + include_answer: bool = True, + include_raw_content: bool = False, + include_images: bool = False, +) -> Dict[str, Any]: """Search the web using Tavily API""" - if not hasattr(cog, 'tavily_client') or not cog.tavily_client: - return {"error": "Tavily client not initialized.", "timestamp": datetime.datetime.now().isoformat()} + if not hasattr(cog, "tavily_client") or not cog.tavily_client: + return { + "error": "Tavily client not initialized.", + "timestamp": datetime.datetime.now().isoformat(), + } # Cost control / Logging for advanced search final_search_depth = search_depth if search_depth.lower() == "advanced": if TAVILY_DISABLE_ADVANCED: - print(f"Warning: Advanced Tavily search requested but disabled by config. Falling back to basic.") + print( + f"Warning: Advanced Tavily search requested but disabled by config. Falling back to basic." + ) final_search_depth = "basic" else: - print(f"Performing advanced Tavily search (cost: 10 credits) for query: '{query}'") + print( + f"Performing advanced Tavily search (cost: 10 credits) for query: '{query}'" + ) elif search_depth.lower() != "basic": - print(f"Warning: Invalid search_depth '{search_depth}' provided. Using 'basic'.") + print( + f"Warning: Invalid search_depth '{search_depth}' provided. Using 'basic'." + ) final_search_depth = "basic" # Validate max_results - final_max_results = max(5, min(20, max_results)) # Clamp between 5 and 20 + final_max_results = max(5, min(20, max_results)) # Clamp between 5 and 20 try: # Pass parameters to Tavily search response = await asyncio.to_thread( cog.tavily_client.search, query=query, - search_depth=final_search_depth, # Use validated depth - max_results=final_max_results, # Use validated results count + search_depth=final_search_depth, # Use validated depth + max_results=final_max_results, # Use validated results count topic=topic, include_domains=include_domains, exclude_domains=exclude_domains, include_answer=include_answer, include_raw_content=include_raw_content, - include_images=include_images + include_images=include_images, ) # Extract relevant information from results results = [] for r in response.get("results", []): - result = {"title": r.get("title"), "url": r.get("url"), "content": r.get("content"), "score": r.get("score"), "published_date": r.get("published_date")} - if include_raw_content: result["raw_content"] = r.get("raw_content") - if include_images: result["images"] = r.get("images") + result = { + "title": r.get("title"), + "url": r.get("url"), + "content": r.get("content"), + "score": r.get("score"), + "published_date": r.get("published_date"), + } + if include_raw_content: + result["raw_content"] = r.get("raw_content") + if include_images: + result["images"] = r.get("images") results.append(result) return { @@ -401,122 +608,236 @@ async def web_search(cog: commands.Cog, query: str, search_depth: str = TAVILY_D "answer": response.get("answer"), "follow_up_questions": response.get("follow_up_questions"), "count": len(results), - "timestamp": datetime.datetime.now().isoformat() + "timestamp": datetime.datetime.now().isoformat(), } except Exception as e: error_message = f"Error during Tavily search for '{query}': {str(e)}" print(error_message) - return {"error": error_message, "timestamp": datetime.datetime.now().isoformat()} + return { + "error": error_message, + "timestamp": datetime.datetime.now().isoformat(), + } -async def remember_user_fact(cog: commands.Cog, user_id: str, fact: str) -> Dict[str, Any]: + +async def remember_user_fact( + cog: commands.Cog, user_id: str, fact: str +) -> Dict[str, Any]: """Stores a fact about a user using the MemoryManager.""" - if not user_id or not fact: return {"error": "user_id and fact required."} + if not user_id or not fact: + return {"error": "user_id and fact required."} print(f"Remembering fact for user {user_id}: '{fact}'") try: result = await cog.memory_manager.add_user_fact(user_id, fact) - if result.get("status") == "added": return {"status": "success", "user_id": user_id, "fact_added": fact} - elif result.get("status") == "duplicate": return {"status": "duplicate", "user_id": user_id, "fact": fact} - elif result.get("status") == "limit_reached": return {"status": "success", "user_id": user_id, "fact_added": fact, "note": "Oldest fact deleted."} - else: return {"error": result.get("error", "Unknown MemoryManager error")} + if result.get("status") == "added": + return {"status": "success", "user_id": user_id, "fact_added": fact} + elif result.get("status") == "duplicate": + return {"status": "duplicate", "user_id": user_id, "fact": fact} + elif result.get("status") == "limit_reached": + return { + "status": "success", + "user_id": user_id, + "fact_added": fact, + "note": "Oldest fact deleted.", + } + else: + return {"error": result.get("error", "Unknown MemoryManager error")} except Exception as e: error_message = f"Error calling MemoryManager for user fact {user_id}: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} + async def get_user_facts(cog: commands.Cog, user_id: str) -> Dict[str, Any]: """Retrieves stored facts about a user using the MemoryManager.""" - if not user_id: return {"error": "user_id required."} + if not user_id: + return {"error": "user_id required."} print(f"Retrieving facts for user {user_id}") try: - user_facts = await cog.memory_manager.get_user_facts(user_id) # Context not needed for basic retrieval tool - return {"user_id": user_id, "facts": user_facts, "count": len(user_facts), "timestamp": datetime.datetime.now().isoformat()} + user_facts = await cog.memory_manager.get_user_facts( + user_id + ) # Context not needed for basic retrieval tool + return { + "user_id": user_id, + "facts": user_facts, + "count": len(user_facts), + "timestamp": datetime.datetime.now().isoformat(), + } except Exception as e: - error_message = f"Error calling MemoryManager for user facts {user_id}: {str(e)}" - print(error_message); traceback.print_exc() + error_message = ( + f"Error calling MemoryManager for user facts {user_id}: {str(e)}" + ) + print(error_message) + traceback.print_exc() return {"error": error_message} + async def remember_general_fact(cog: commands.Cog, fact: str) -> Dict[str, Any]: """Stores a general fact using the MemoryManager.""" - if not fact: return {"error": "fact required."} + if not fact: + return {"error": "fact required."} print(f"Remembering general fact: '{fact}'") try: result = await cog.memory_manager.add_general_fact(fact) - if result.get("status") == "added": return {"status": "success", "fact_added": fact} - elif result.get("status") == "duplicate": return {"status": "duplicate", "fact": fact} - elif result.get("status") == "limit_reached": return {"status": "success", "fact_added": fact, "note": "Oldest fact deleted."} - else: return {"error": result.get("error", "Unknown MemoryManager error")} + if result.get("status") == "added": + return {"status": "success", "fact_added": fact} + elif result.get("status") == "duplicate": + return {"status": "duplicate", "fact": fact} + elif result.get("status") == "limit_reached": + return { + "status": "success", + "fact_added": fact, + "note": "Oldest fact deleted.", + } + else: + return {"error": result.get("error", "Unknown MemoryManager error")} except Exception as e: error_message = f"Error calling MemoryManager for general fact: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} -async def get_general_facts(cog: commands.Cog, query: Optional[str] = None, limit: Optional[int] = 10) -> Dict[str, Any]: + +async def get_general_facts( + cog: commands.Cog, query: Optional[str] = None, limit: Optional[int] = 10 +) -> Dict[str, Any]: """Retrieves stored general facts using the MemoryManager.""" print(f"Retrieving general facts (query='{query}', limit={limit})") limit = min(max(1, limit or 10), 50) try: - general_facts = await cog.memory_manager.get_general_facts(query=query, limit=limit) # Context not needed here - return {"query": query, "facts": general_facts, "count": len(general_facts), "timestamp": datetime.datetime.now().isoformat()} + general_facts = await cog.memory_manager.get_general_facts( + query=query, limit=limit + ) # Context not needed here + return { + "query": query, + "facts": general_facts, + "count": len(general_facts), + "timestamp": datetime.datetime.now().isoformat(), + } except Exception as e: error_message = f"Error calling MemoryManager for general facts: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message} -async def timeout_user(cog: commands.Cog, user_id: str, duration_minutes: int, reason: Optional[str] = None) -> Dict[str, Any]: + +async def timeout_user( + cog: commands.Cog, user_id: str, duration_minutes: int, reason: Optional[str] = None +) -> Dict[str, Any]: """Times out a user in the current server.""" - if not cog.current_channel or not isinstance(cog.current_channel, discord.abc.GuildChannel): + if not cog.current_channel or not isinstance( + cog.current_channel, discord.abc.GuildChannel + ): return {"error": "Cannot timeout outside of a server."} guild = cog.current_channel.guild - if not guild: return {"error": "Could not determine server."} - if not 1 <= duration_minutes <= 1440: return {"error": "Duration must be 1-1440 minutes."} + if not guild: + return {"error": "Could not determine server."} + if not 1 <= duration_minutes <= 1440: + return {"error": "Duration must be 1-1440 minutes."} try: member_id = int(user_id) - member = guild.get_member(member_id) or await guild.fetch_member(member_id) # Fetch if not cached - if not member: return {"error": f"User {user_id} not found in server."} - if member == cog.bot.user: return {"error": "lol i cant timeout myself vro"} - if member.id == guild.owner_id: return {"error": f"Cannot timeout owner {member.display_name}."} + member = guild.get_member(member_id) or await guild.fetch_member( + member_id + ) # Fetch if not cached + if not member: + return {"error": f"User {user_id} not found in server."} + if member == cog.bot.user: + return {"error": "lol i cant timeout myself vro"} + if member.id == guild.owner_id: + return {"error": f"Cannot timeout owner {member.display_name}."} bot_member = guild.me - if not bot_member.guild_permissions.moderate_members: return {"error": "I lack permission to timeout."} - if bot_member.id != guild.owner_id and bot_member.top_role <= member.top_role: return {"error": f"Cannot timeout {member.display_name} (role hierarchy)."} + if not bot_member.guild_permissions.moderate_members: + return {"error": "I lack permission to timeout."} + if bot_member.id != guild.owner_id and bot_member.top_role <= member.top_role: + return {"error": f"Cannot timeout {member.display_name} (role hierarchy)."} until = discord.utils.utcnow() + datetime.timedelta(minutes=duration_minutes) - timeout_reason = reason or "wheatley felt like it" # Changed default reason + timeout_reason = reason or "wheatley felt like it" # Changed default reason await member.timeout(until, reason=timeout_reason) - print(f"Timed out {member.display_name} ({user_id}) for {duration_minutes} mins. Reason: {timeout_reason}") - return {"status": "success", "user_timed_out": member.display_name, "user_id": user_id, "duration_minutes": duration_minutes, "reason": timeout_reason} - except ValueError: return {"error": f"Invalid user ID: {user_id}"} - except discord.NotFound: return {"error": f"User {user_id} not found in server."} - except discord.Forbidden as e: print(f"Forbidden error timeout {user_id}: {e}"); return {"error": f"Permission error timeout {user_id}."} - except discord.HTTPException as e: print(f"API error timeout {user_id}: {e}"); return {"error": f"API error timeout {user_id}: {e}"} - except Exception as e: print(f"Unexpected error timeout {user_id}: {e}"); traceback.print_exc(); return {"error": f"Unexpected error timeout {user_id}: {str(e)}"} + print( + f"Timed out {member.display_name} ({user_id}) for {duration_minutes} mins. Reason: {timeout_reason}" + ) + return { + "status": "success", + "user_timed_out": member.display_name, + "user_id": user_id, + "duration_minutes": duration_minutes, + "reason": timeout_reason, + } + except ValueError: + return {"error": f"Invalid user ID: {user_id}"} + except discord.NotFound: + return {"error": f"User {user_id} not found in server."} + except discord.Forbidden as e: + print(f"Forbidden error timeout {user_id}: {e}") + return {"error": f"Permission error timeout {user_id}."} + except discord.HTTPException as e: + print(f"API error timeout {user_id}: {e}") + return {"error": f"API error timeout {user_id}: {e}"} + except Exception as e: + print(f"Unexpected error timeout {user_id}: {e}") + traceback.print_exc() + return {"error": f"Unexpected error timeout {user_id}: {str(e)}"} -async def remove_timeout(cog: commands.Cog, user_id: str, reason: Optional[str] = None) -> Dict[str, Any]: + +async def remove_timeout( + cog: commands.Cog, user_id: str, reason: Optional[str] = None +) -> Dict[str, Any]: """Removes an active timeout from a user.""" - if not cog.current_channel or not isinstance(cog.current_channel, discord.abc.GuildChannel): + if not cog.current_channel or not isinstance( + cog.current_channel, discord.abc.GuildChannel + ): return {"error": "Cannot remove timeout outside of a server."} guild = cog.current_channel.guild - if not guild: return {"error": "Could not determine server."} + if not guild: + return {"error": "Could not determine server."} try: member_id = int(user_id) member = guild.get_member(member_id) or await guild.fetch_member(member_id) - if not member: return {"error": f"User {user_id} not found."} + if not member: + return {"error": f"User {user_id} not found."} # Define bot_member before using it bot_member = guild.me - if not bot_member.guild_permissions.moderate_members: return {"error": "I lack permission to remove timeouts."} - if member.timed_out_until is None: return {"status": "not_timed_out", "user_id": user_id, "user_name": member.display_name} + if not bot_member.guild_permissions.moderate_members: + return {"error": "I lack permission to remove timeouts."} + if member.timed_out_until is None: + return { + "status": "not_timed_out", + "user_id": user_id, + "user_name": member.display_name, + } + + timeout_reason = ( + reason or "Wheatley decided to be nice." + ) # Changed default reason + await member.timeout(None, reason=timeout_reason) # None removes timeout + print( + f"Removed timeout from {member.display_name} ({user_id}). Reason: {timeout_reason}" + ) + return { + "status": "success", + "user_timeout_removed": member.display_name, + "user_id": user_id, + "reason": timeout_reason, + } + except ValueError: + return {"error": f"Invalid user ID: {user_id}"} + except discord.NotFound: + return {"error": f"User {user_id} not found."} + except discord.Forbidden as e: + print(f"Forbidden error remove timeout {user_id}: {e}") + return {"error": f"Permission error remove timeout {user_id}."} + except discord.HTTPException as e: + print(f"API error remove timeout {user_id}: {e}") + return {"error": f"API error remove timeout {user_id}: {e}"} + except Exception as e: + print(f"Unexpected error remove timeout {user_id}: {e}") + traceback.print_exc() + return {"error": f"Unexpected error remove timeout {user_id}: {str(e)}"} - timeout_reason = reason or "Wheatley decided to be nice." # Changed default reason - await member.timeout(None, reason=timeout_reason) # None removes timeout - print(f"Removed timeout from {member.display_name} ({user_id}). Reason: {timeout_reason}") - return {"status": "success", "user_timeout_removed": member.display_name, "user_id": user_id, "reason": timeout_reason} - except ValueError: return {"error": f"Invalid user ID: {user_id}"} - except discord.NotFound: return {"error": f"User {user_id} not found."} - except discord.Forbidden as e: print(f"Forbidden error remove timeout {user_id}: {e}"); return {"error": f"Permission error remove timeout {user_id}."} - except discord.HTTPException as e: print(f"API error remove timeout {user_id}: {e}"); return {"error": f"API error remove timeout {user_id}: {e}"} - except Exception as e: print(f"Unexpected error remove timeout {user_id}: {e}"); traceback.print_exc(); return {"error": f"Unexpected error remove timeout {user_id}: {str(e)}"} async def calculate(cog: commands.Cog, expression: str) -> Dict[str, Any]: """Evaluates a mathematical expression using asteval.""" @@ -525,32 +846,45 @@ async def calculate(cog: commands.Cog, expression: str) -> Dict[str, Any]: try: result = aeval(expression) if aeval.error: - error_details = '; '.join(err.get_error() for err in aeval.error) + error_details = "; ".join(err.get_error() for err in aeval.error) error_message = f"Calculation error: {error_details}" print(error_message) return {"error": error_message, "expression": expression} - if isinstance(result, (int, float, complex)): result_str = str(result) - else: result_str = repr(result) # Fallback + if isinstance(result, (int, float, complex)): + result_str = str(result) + else: + result_str = repr(result) # Fallback print(f"Calculation result: {result_str}") return {"expression": expression, "result": result_str, "status": "success"} except Exception as e: error_message = f"Unexpected error during calculation: {str(e)}" - print(error_message); traceback.print_exc() + print(error_message) + traceback.print_exc() return {"error": error_message, "expression": expression} + async def run_python_code(cog: commands.Cog, code: str) -> Dict[str, Any]: """Executes a Python code snippet using the Piston API.""" - if not PISTON_API_URL: return {"error": "Piston API URL not configured (PISTON_API_URL)."} - if not cog.session: return {"error": "aiohttp session not initialized."} + if not PISTON_API_URL: + return {"error": "Piston API URL not configured (PISTON_API_URL)."} + if not cog.session: + return {"error": "aiohttp session not initialized."} print(f"Executing Python via Piston: {code[:100]}...") - payload = {"language": "python", "version": "3.10.0", "files": [{"name": "main.py", "content": code}]} + payload = { + "language": "python", + "version": "3.10.0", + "files": [{"name": "main.py", "content": code}], + } headers = {"Content-Type": "application/json"} - if PISTON_API_KEY: headers["Authorization"] = PISTON_API_KEY + if PISTON_API_KEY: + headers["Authorization"] = PISTON_API_KEY try: - async with cog.session.post(PISTON_API_URL, headers=headers, json=payload, timeout=20) as response: + async with cog.session.post( + PISTON_API_URL, headers=headers, json=payload, timeout=20 + ) as response: if response.status == 200: data = await response.json() run_info = data.get("run", {}) @@ -561,80 +895,145 @@ async def run_python_code(cog: commands.Cog, code: str) -> Dict[str, Any]: signal = run_info.get("signal") full_stderr = (compile_info.get("stderr", "") + "\n" + stderr).strip() max_len = 500 - stdout_trunc = stdout[:max_len] + ('...' if len(stdout) > max_len else '') - stderr_trunc = full_stderr[:max_len] + ('...' if len(full_stderr) > max_len else '') - result = {"status": "success" if exit_code == 0 and not signal else "execution_error", "stdout": stdout_trunc, "stderr": stderr_trunc, "exit_code": exit_code, "signal": signal} + stdout_trunc = stdout[:max_len] + ( + "..." if len(stdout) > max_len else "" + ) + stderr_trunc = full_stderr[:max_len] + ( + "..." if len(full_stderr) > max_len else "" + ) + result = { + "status": ( + "success" + if exit_code == 0 and not signal + else "execution_error" + ), + "stdout": stdout_trunc, + "stderr": stderr_trunc, + "exit_code": exit_code, + "signal": signal, + } print(f"Piston execution result: {result}") return result else: error_text = await response.text() - error_message = f"Piston API error (Status {response.status}): {error_text[:200]}" + error_message = ( + f"Piston API error (Status {response.status}): {error_text[:200]}" + ) print(error_message) return {"error": error_message} - except asyncio.TimeoutError: print("Piston API timed out."); return {"error": "Piston API timed out."} - except aiohttp.ClientError as e: print(f"Piston network error: {e}"); return {"error": f"Network error connecting to Piston: {str(e)}"} - except Exception as e: print(f"Unexpected Piston error: {e}"); traceback.print_exc(); return {"error": f"Unexpected error during Python execution: {str(e)}"} + except asyncio.TimeoutError: + print("Piston API timed out.") + return {"error": "Piston API timed out."} + except aiohttp.ClientError as e: + print(f"Piston network error: {e}") + return {"error": f"Network error connecting to Piston: {str(e)}"} + except Exception as e: + print(f"Unexpected Piston error: {e}") + traceback.print_exc() + return {"error": f"Unexpected error during Python execution: {str(e)}"} -async def create_poll(cog: commands.Cog, question: str, options: List[str]) -> Dict[str, Any]: + +async def create_poll( + cog: commands.Cog, question: str, options: List[str] +) -> Dict[str, Any]: """Creates a simple poll message.""" - if not cog.current_channel: return {"error": "No current channel context."} - if not isinstance(cog.current_channel, discord.abc.Messageable): return {"error": "Channel not messageable."} - if not isinstance(options, list) or not 2 <= len(options) <= 10: return {"error": "Poll needs 2-10 options."} + if not cog.current_channel: + return {"error": "No current channel context."} + if not isinstance(cog.current_channel, discord.abc.Messageable): + return {"error": "Channel not messageable."} + if not isinstance(options, list) or not 2 <= len(options) <= 10: + return {"error": "Poll needs 2-10 options."} if isinstance(cog.current_channel, discord.abc.GuildChannel): bot_member = cog.current_channel.guild.me - if not cog.current_channel.permissions_for(bot_member).send_messages or \ - not cog.current_channel.permissions_for(bot_member).add_reactions: + if ( + not cog.current_channel.permissions_for(bot_member).send_messages + or not cog.current_channel.permissions_for(bot_member).add_reactions + ): return {"error": "Missing permissions for poll."} try: poll_content = f"**📊 Poll: {question}**\n\n" number_emojis = ["1️⃣", "2️⃣", "3️⃣", "4️⃣", "5️⃣", "6️⃣", "7️⃣", "8️⃣", "9️⃣", "🔟"] - for i, option in enumerate(options): poll_content += f"{number_emojis[i]} {option}\n" + for i, option in enumerate(options): + poll_content += f"{number_emojis[i]} {option}\n" poll_message = await cog.current_channel.send(poll_content) print(f"Sent poll {poll_message.id}: {question}") - for i in range(len(options)): await poll_message.add_reaction(number_emojis[i]); await asyncio.sleep(0.1) - return {"status": "success", "message_id": str(poll_message.id), "question": question, "options_count": len(options)} - except discord.Forbidden: print("Poll Forbidden"); return {"error": "Forbidden: Missing permissions for poll."} - except discord.HTTPException as e: print(f"Poll API error: {e}"); return {"error": f"API error creating poll: {e}"} - except Exception as e: print(f"Poll unexpected error: {e}"); traceback.print_exc(); return {"error": f"Unexpected error creating poll: {str(e)}"} + for i in range(len(options)): + await poll_message.add_reaction(number_emojis[i]) + await asyncio.sleep(0.1) + return { + "status": "success", + "message_id": str(poll_message.id), + "question": question, + "options_count": len(options), + } + except discord.Forbidden: + print("Poll Forbidden") + return {"error": "Forbidden: Missing permissions for poll."} + except discord.HTTPException as e: + print(f"Poll API error: {e}") + return {"error": f"API error creating poll: {e}"} + except Exception as e: + print(f"Poll unexpected error: {e}") + traceback.print_exc() + return {"error": f"Unexpected error creating poll: {str(e)}"} + # Helper function to convert memory string (e.g., "128m") to bytes def parse_mem_limit(mem_limit_str: str) -> Optional[int]: - if not mem_limit_str: return None + if not mem_limit_str: + return None mem_limit_str = mem_limit_str.lower() - if mem_limit_str.endswith('m'): - try: return int(mem_limit_str[:-1]) * 1024 * 1024 - except ValueError: return None - elif mem_limit_str.endswith('g'): - try: return int(mem_limit_str[:-1]) * 1024 * 1024 * 1024 - except ValueError: return None - try: return int(mem_limit_str) # Assume bytes if no suffix - except ValueError: return None + if mem_limit_str.endswith("m"): + try: + return int(mem_limit_str[:-1]) * 1024 * 1024 + except ValueError: + return None + elif mem_limit_str.endswith("g"): + try: + return int(mem_limit_str[:-1]) * 1024 * 1024 * 1024 + except ValueError: + return None + try: + return int(mem_limit_str) # Assume bytes if no suffix + except ValueError: + return None + async def _check_command_safety(cog: commands.Cog, command: str) -> Dict[str, Any]: """Uses a secondary AI call to check if a command is potentially harmful.""" - from .api import get_internal_ai_json_response # Import here - print(f"Performing AI safety check for command: '{command}' using model {SAFETY_CHECK_MODEL}") + from .api import get_internal_ai_json_response # Import here + + print( + f"Performing AI safety check for command: '{command}' using model {SAFETY_CHECK_MODEL}" + ) safety_schema = { "type": "object", "properties": { - "is_safe": {"type": "boolean", "description": "True if safe for restricted container, False otherwise."}, - "reason": {"type": "string", "description": "Brief explanation."} - }, "required": ["is_safe", "reason"] + "is_safe": { + "type": "boolean", + "description": "True if safe for restricted container, False otherwise.", + }, + "reason": {"type": "string", "description": "Brief explanation."}, + }, + "required": ["is_safe", "reason"], } prompt_messages = [ - {"role": "system", "content": f"Analyze shell command safety for execution in isolated, network-disabled Docker ({DOCKER_EXEC_IMAGE}) with CPU/Mem limits. Focus on data destruction, resource exhaustion, container escape, network attacks (disabled), env var leaks. Simple echo/ls/pwd safe. rm/mkfs/shutdown/wget/curl/install/fork bombs unsafe. Respond ONLY with JSON matching the provided schema."}, - {"role": "user", "content": f"Analyze safety: ```{command}```"} + { + "role": "system", + "content": f"Analyze shell command safety for execution in isolated, network-disabled Docker ({DOCKER_EXEC_IMAGE}) with CPU/Mem limits. Focus on data destruction, resource exhaustion, container escape, network attacks (disabled), env var leaks. Simple echo/ls/pwd safe. rm/mkfs/shutdown/wget/curl/install/fork bombs unsafe. Respond ONLY with JSON matching the provided schema.", + }, + {"role": "user", "content": f"Analyze safety: ```{command}```"}, ] safety_response = await get_internal_ai_json_response( cog=cog, prompt_messages=prompt_messages, task_description="Command Safety Check", - response_schema_dict=safety_schema, # Pass the schema dict directly + response_schema_dict=safety_schema, # Pass the schema dict directly model_name=SAFETY_CHECK_MODEL, temperature=0.1, - max_tokens=150 + max_tokens=150, ) if safety_response and isinstance(safety_response.get("is_safe"), bool): is_safe = safety_response["is_safe"] @@ -646,6 +1045,7 @@ async def _check_command_safety(cog: commands.Cog, command: str) -> Dict[str, An print(f"AI Safety Check Error: Response was {safety_response}") return {"safe": False, "reason": error_msg} + async def run_terminal_command(cog: commands.Cog, command: str) -> Dict[str, Any]: """Executes a shell command in an isolated Docker container after an AI safety check.""" print(f"Attempting terminal command: {command}") @@ -655,12 +1055,20 @@ async def run_terminal_command(cog: commands.Cog, command: str) -> Dict[str, Any print(error_message) return {"error": error_message, "command": command} - try: cpu_limit = float(DOCKER_CPU_LIMIT); cpu_period = 100000; cpu_quota = int(cpu_limit * cpu_period) - except ValueError: print(f"Warning: Invalid DOCKER_CPU_LIMIT '{DOCKER_CPU_LIMIT}'. Using default."); cpu_quota = 50000; cpu_period = 100000 + try: + cpu_limit = float(DOCKER_CPU_LIMIT) + cpu_period = 100000 + cpu_quota = int(cpu_limit * cpu_period) + except ValueError: + print(f"Warning: Invalid DOCKER_CPU_LIMIT '{DOCKER_CPU_LIMIT}'. Using default.") + cpu_quota = 50000 + cpu_period = 100000 mem_limit_bytes = parse_mem_limit(DOCKER_MEM_LIMIT) if mem_limit_bytes is None: - print(f"Warning: Invalid DOCKER_MEM_LIMIT '{DOCKER_MEM_LIMIT}'. Disabling memory limit.") + print( + f"Warning: Invalid DOCKER_MEM_LIMIT '{DOCKER_MEM_LIMIT}'. Disabling memory limit." + ) client = None container = None @@ -669,32 +1077,32 @@ async def run_terminal_command(cog: commands.Cog, command: str) -> Dict[str, Any print(f"Running command in Docker ({DOCKER_EXEC_IMAGE})...") config = { - 'Image': DOCKER_EXEC_IMAGE, - 'Cmd': ["/bin/sh", "-c", command], - 'AttachStdout': True, - 'AttachStderr': True, - 'HostConfig': { - 'NetworkDisabled': True, - 'AutoRemove': False, # Changed to False - 'CpuPeriod': cpu_period, - 'CpuQuota': cpu_quota, - } + "Image": DOCKER_EXEC_IMAGE, + "Cmd": ["/bin/sh", "-c", command], + "AttachStdout": True, + "AttachStderr": True, + "HostConfig": { + "NetworkDisabled": True, + "AutoRemove": False, # Changed to False + "CpuPeriod": cpu_period, + "CpuQuota": cpu_quota, + }, } if mem_limit_bytes is not None: - config['HostConfig']['Memory'] = mem_limit_bytes + config["HostConfig"]["Memory"] = mem_limit_bytes # Use wait_for for the run call itself in case image pulling takes time container = await asyncio.wait_for( client.containers.run(config=config), - timeout=DOCKER_COMMAND_TIMEOUT + 15 # Add buffer for container start/stop/pull + timeout=DOCKER_COMMAND_TIMEOUT + + 15, # Add buffer for container start/stop/pull ) # Wait for the container to finish execution wait_result = await asyncio.wait_for( - container.wait(), - timeout=DOCKER_COMMAND_TIMEOUT + container.wait(), timeout=DOCKER_COMMAND_TIMEOUT ) - exit_code = wait_result.get('StatusCode', -1) + exit_code = wait_result.get("StatusCode", -1) # Get logs after container finishes # container.log() returns a list of strings when stream=False (default) @@ -705,11 +1113,18 @@ async def run_terminal_command(cog: commands.Cog, command: str) -> Dict[str, Any stderr = "".join(stderr_lines) if stderr_lines else "" max_len = 1000 - stdout_trunc = stdout[:max_len] + ('...' if len(stdout) > max_len else '') - stderr_trunc = stderr[:max_len] + ('...' if len(stderr) > max_len else '') + stdout_trunc = stdout[:max_len] + ("..." if len(stdout) > max_len else "") + stderr_trunc = stderr[:max_len] + ("..." if len(stderr) > max_len else "") - result = {"status": "success" if exit_code == 0 else "execution_error", "stdout": stdout_trunc, "stderr": stderr_trunc, "exit_code": exit_code} - print(f"Docker command finished. Exit Code: {exit_code}. Output length: {len(stdout)}, Stderr length: {len(stderr)}") + result = { + "status": "success" if exit_code == 0 else "execution_error", + "stdout": stdout_trunc, + "stderr": stderr_trunc, + "exit_code": exit_code, + } + print( + f"Docker command finished. Exit Code: {exit_code}. Output length: {len(stdout)}, Stderr length: {len(stderr)}" + ) return result except asyncio.TimeoutError: @@ -725,22 +1140,42 @@ async def run_terminal_command(cog: commands.Cog, command: str) -> Dict[str, Any # await container.delete(force=True) # Force needed if stop failed? # print(f"Container {container.id[:12]} deleted.") except aiodocker.exceptions.DockerError as stop_err: - print(f"Error stopping/deleting timed-out container {container.id[:12]}: {stop_err}") + print( + f"Error stopping/deleting timed-out container {container.id[:12]}: {stop_err}" + ) except Exception as stop_exc: - print(f"Unexpected error stopping/deleting timed-out container {container.id[:12]}: {stop_exc}") + print( + f"Unexpected error stopping/deleting timed-out container {container.id[:12]}: {stop_exc}" + ) # No need to delete here, finally block will handle it - return {"error": f"Command execution/log retrieval timed out after {DOCKER_COMMAND_TIMEOUT}s", "command": command, "status": "timeout"} - except aiodocker.exceptions.DockerError as e: # Catch specific aiodocker errors + return { + "error": f"Command execution/log retrieval timed out after {DOCKER_COMMAND_TIMEOUT}s", + "command": command, + "status": "timeout", + } + except aiodocker.exceptions.DockerError as e: # Catch specific aiodocker errors print(f"Docker API error: {e} (Status: {e.status})") # Check for ImageNotFound specifically if e.status == 404 and ("No such image" in str(e) or "not found" in str(e)): - print(f"Docker image not found: {DOCKER_EXEC_IMAGE}") - return {"error": f"Docker image '{DOCKER_EXEC_IMAGE}' not found.", "command": command, "status": "docker_error"} - return {"error": f"Docker API error ({e.status}): {str(e)}", "command": command, "status": "docker_error"} + print(f"Docker image not found: {DOCKER_EXEC_IMAGE}") + return { + "error": f"Docker image '{DOCKER_EXEC_IMAGE}' not found.", + "command": command, + "status": "docker_error", + } + return { + "error": f"Docker API error ({e.status}): {str(e)}", + "command": command, + "status": "docker_error", + } except Exception as e: print(f"Unexpected Docker error: {e}") traceback.print_exc() - return {"error": f"Unexpected error during Docker execution: {str(e)}", "command": command, "status": "error"} + return { + "error": f"Unexpected error during Docker execution: {str(e)}", + "command": command, + "status": "error", + } finally: # Explicitly remove the container since AutoRemove is False if container: @@ -752,42 +1187,77 @@ async def run_terminal_command(cog: commands.Cog, command: str) -> Dict[str, Any # Log error but don't raise, primary error is more important print(f"Error deleting container {container.id[:12]}: {delete_err}") except Exception as delete_exc: - print(f"Unexpected error deleting container {container.id[:12]}: {delete_exc}") # <--- Corrected indentation + print( + f"Unexpected error deleting container {container.id[:12]}: {delete_exc}" + ) # <--- Corrected indentation # Ensure the client connection is closed if client: await client.close() -async def extract_web_content(cog: commands.Cog, urls: Union[str, List[str]], extract_depth: str = "basic", include_images: bool = False) -> Dict[str, Any]: + +async def extract_web_content( + cog: commands.Cog, + urls: Union[str, List[str]], + extract_depth: str = "basic", + include_images: bool = False, +) -> Dict[str, Any]: """Extract content from URLs using Tavily API""" - if not hasattr(cog, 'tavily_client') or not cog.tavily_client: - return {"error": "Tavily client not initialized.", "timestamp": datetime.datetime.now().isoformat()} + if not hasattr(cog, "tavily_client") or not cog.tavily_client: + return { + "error": "Tavily client not initialized.", + "timestamp": datetime.datetime.now().isoformat(), + } # Cost control / Logging for advanced extract final_extract_depth = extract_depth if extract_depth.lower() == "advanced": if TAVILY_DISABLE_ADVANCED: - print(f"Warning: Advanced Tavily extract requested but disabled by config. Falling back to basic.") + print( + f"Warning: Advanced Tavily extract requested but disabled by config. Falling back to basic." + ) final_extract_depth = "basic" else: - print(f"Performing advanced Tavily extract (cost: 2 credits per 5 URLs) for URLs: {urls}") + print( + f"Performing advanced Tavily extract (cost: 2 credits per 5 URLs) for URLs: {urls}" + ) elif extract_depth.lower() != "basic": - print(f"Warning: Invalid extract_depth '{extract_depth}' provided. Using 'basic'.") + print( + f"Warning: Invalid extract_depth '{extract_depth}' provided. Using 'basic'." + ) final_extract_depth = "basic" try: response = await asyncio.to_thread( cog.tavily_client.extract, urls=urls, - extract_depth=final_extract_depth, # Use validated depth - include_images=include_images + extract_depth=final_extract_depth, # Use validated depth + include_images=include_images, ) - results = [{"url": r.get("url"), "raw_content": r.get("raw_content"), "images": r.get("images")} for r in response.get("results", [])] + results = [ + { + "url": r.get("url"), + "raw_content": r.get("raw_content"), + "images": r.get("images"), + } + for r in response.get("results", []) + ] failed_results = response.get("failed_results", []) - return {"urls": urls, "extract_depth": extract_depth, "include_images": include_images, "results": results, "failed_results": failed_results, "timestamp": datetime.datetime.now().isoformat()} + return { + "urls": urls, + "extract_depth": extract_depth, + "include_images": include_images, + "results": results, + "failed_results": failed_results, + "timestamp": datetime.datetime.now().isoformat(), + } except Exception as e: error_message = f"Error during Tavily extract for '{urls}': {str(e)}" print(error_message) - return {"error": error_message, "timestamp": datetime.datetime.now().isoformat()} + return { + "error": error_message, + "timestamp": datetime.datetime.now().isoformat(), + } + # --- Tool Mapping --- # This dictionary maps tool names (used in the AI prompt) to their implementation functions. @@ -803,15 +1273,21 @@ TOOL_MAPPING = { "get_message_context": get_message_context, "web_search": web_search, # Point memory tools to the methods on the MemoryManager instance (accessed via cog) - "remember_user_fact": lambda cog, **kwargs: cog.memory_manager.add_user_fact(**kwargs), + "remember_user_fact": lambda cog, **kwargs: cog.memory_manager.add_user_fact( + **kwargs + ), "get_user_facts": lambda cog, **kwargs: cog.memory_manager.get_user_facts(**kwargs), - "remember_general_fact": lambda cog, **kwargs: cog.memory_manager.add_general_fact(**kwargs), - "get_general_facts": lambda cog, **kwargs: cog.memory_manager.get_general_facts(**kwargs), + "remember_general_fact": lambda cog, **kwargs: cog.memory_manager.add_general_fact( + **kwargs + ), + "get_general_facts": lambda cog, **kwargs: cog.memory_manager.get_general_facts( + **kwargs + ), "timeout_user": timeout_user, "calculate": calculate, "run_python_code": run_python_code, "create_poll": create_poll, "run_terminal_command": run_terminal_command, "remove_timeout": remove_timeout, - "extract_web_content": extract_web_content + "extract_web_content": extract_web_content, } diff --git a/wheatley/utils.py b/wheatley/utils.py index eb131fc..d4a48cf 100644 --- a/wheatley/utils.py +++ b/wheatley/utils.py @@ -9,28 +9,40 @@ import os from typing import TYPE_CHECKING, Optional, Tuple, Dict, Any if TYPE_CHECKING: - from .cog import WheatleyCog # For type hinting + from .cog import WheatleyCog # For type hinting # --- Utility Functions --- # Note: Functions needing cog state (like personality traits for mistakes) # will need the 'cog' instance passed in. -def replace_mentions_with_names(cog: 'WheatleyCog', content: str, message: discord.Message) -> str: + +def replace_mentions_with_names( + cog: "WheatleyCog", content: str, message: discord.Message +) -> str: """Replaces user mentions (<@id> or <@!id>) with their display names.""" if not message.mentions: return content processed_content = content - sorted_mentions = sorted(message.mentions, key=lambda m: len(str(m.id)), reverse=True) + sorted_mentions = sorted( + message.mentions, key=lambda m: len(str(m.id)), reverse=True + ) for member in sorted_mentions: - processed_content = processed_content.replace(f'<@{member.id}>', member.display_name) - processed_content = processed_content.replace(f'<@!{member.id}>', member.display_name) + processed_content = processed_content.replace( + f"<@{member.id}>", member.display_name + ) + processed_content = processed_content.replace( + f"<@!{member.id}>", member.display_name + ) return processed_content -def format_message(cog: 'WheatleyCog', message: discord.Message) -> Dict[str, Any]: + +def format_message(cog: "WheatleyCog", message: discord.Message) -> Dict[str, Any]: """Helper function to format a discord.Message object into a dictionary.""" - processed_content = replace_mentions_with_names(cog, message.content, message) # Pass cog + processed_content = replace_mentions_with_names( + cog, message.content, message + ) # Pass cog mentioned_users_details = [ {"id": str(m.id), "name": m.name, "display_name": m.display_name} for m in message.mentions @@ -39,18 +51,26 @@ def format_message(cog: 'WheatleyCog', message: discord.Message) -> Dict[str, An formatted_msg = { "id": str(message.id), "author": { - "id": str(message.author.id), "name": message.author.name, - "display_name": message.author.display_name, "bot": message.author.bot + "id": str(message.author.id), + "name": message.author.name, + "display_name": message.author.display_name, + "bot": message.author.bot, }, "content": processed_content, "created_at": message.created_at.isoformat(), - "attachments": [{"filename": a.filename, "url": a.url} for a in message.attachments], + "attachments": [ + {"filename": a.filename, "url": a.url} for a in message.attachments + ], "embeds": len(message.embeds) > 0, - "mentions": [{"id": str(m.id), "name": m.name} for m in message.mentions], # Keep original simple list too + "mentions": [ + {"id": str(m.id), "name": m.name} for m in message.mentions + ], # Keep original simple list too "mentioned_users_details": mentioned_users_details, - "replied_to_message_id": None, "replied_to_author_id": None, - "replied_to_author_name": None, "replied_to_content": None, - "is_reply": False + "replied_to_message_id": None, + "replied_to_author_id": None, + "replied_to_author_name": None, + "replied_to_content": None, + "is_reply": False, } if message.reference and message.reference.message_id: @@ -58,7 +78,7 @@ def format_message(cog: 'WheatleyCog', message: discord.Message) -> Dict[str, An formatted_msg["is_reply"] = True # Try to get resolved details (might be None if message not cached/fetched) ref_msg = message.reference.resolved - if isinstance(ref_msg, discord.Message): # Check if resolved is a Message + if isinstance(ref_msg, discord.Message): # Check if resolved is a Message formatted_msg["replied_to_author_id"] = str(ref_msg.author.id) formatted_msg["replied_to_author_name"] = ref_msg.author.display_name formatted_msg["replied_to_content"] = ref_msg.content @@ -66,25 +86,38 @@ def format_message(cog: 'WheatleyCog', message: discord.Message) -> Dict[str, An return formatted_msg -def update_relationship(cog: 'WheatleyCog', user_id_1: str, user_id_2: str, change: float): + +def update_relationship( + cog: "WheatleyCog", user_id_1: str, user_id_2: str, change: float +): """Updates the relationship score between two users.""" - if user_id_1 > user_id_2: user_id_1, user_id_2 = user_id_2, user_id_1 - if user_id_1 not in cog.user_relationships: cog.user_relationships[user_id_1] = {} + if user_id_1 > user_id_2: + user_id_1, user_id_2 = user_id_2, user_id_1 + if user_id_1 not in cog.user_relationships: + cog.user_relationships[user_id_1] = {} current_score = cog.user_relationships[user_id_1].get(user_id_2, 0.0) - new_score = max(0.0, min(current_score + change, 100.0)) # Clamp 0-100 + new_score = max(0.0, min(current_score + change, 100.0)) # Clamp 0-100 cog.user_relationships[user_id_1][user_id_2] = new_score # print(f"Updated relationship {user_id_1}-{user_id_2}: {current_score:.1f} -> {new_score:.1f} ({change:+.1f})") # Debug log -async def simulate_human_typing(cog: 'WheatleyCog', channel, text: str): + +async def simulate_human_typing(cog: "WheatleyCog", channel, text: str): """Shows typing indicator without significant delay.""" # Minimal delay to ensure the typing indicator shows up reliably # but doesn't add noticeable latency to the response. # The actual sending of the message happens immediately after this. async with channel.typing(): - await asyncio.sleep(0.1) # Very short sleep, just to ensure typing shows + await asyncio.sleep(0.1) # Very short sleep, just to ensure typing shows -async def log_internal_api_call(cog: 'WheatleyCog', task_description: str, payload: Dict[str, Any], response_data: Optional[Dict[str, Any]], error: Optional[Exception] = None): + +async def log_internal_api_call( + cog: "WheatleyCog", + task_description: str, + payload: Dict[str, Any], + response_data: Optional[Dict[str, Any]], + error: Optional[Exception] = None, +): """Helper function to log internal API calls to a file.""" log_dir = "data" log_file = os.path.join(log_dir, "internal_api_calls.log") @@ -97,28 +130,40 @@ async def log_internal_api_call(cog: 'WheatleyCog', task_description: str, paylo # Sanitize payload for logging (avoid large base64 images) payload_to_log = payload.copy() - if 'messages' in payload_to_log: + if "messages" in payload_to_log: sanitized_messages = [] - for msg in payload_to_log['messages']: - if isinstance(msg.get('content'), list): # Multimodal message + for msg in payload_to_log["messages"]: + if isinstance(msg.get("content"), list): # Multimodal message new_content = [] - for part in msg['content']: - if part.get('type') == 'image_url' and part.get('image_url', {}).get('url', '').startswith('data:image'): - new_content.append({'type': 'image_url', 'image_url': {'url': 'data:image/...[truncated]'}}) + for part in msg["content"]: + if part.get("type") == "image_url" and part.get( + "image_url", {} + ).get("url", "").startswith("data:image"): + new_content.append( + { + "type": "image_url", + "image_url": {"url": "data:image/...[truncated]"}, + } + ) else: new_content.append(part) - sanitized_messages.append({**msg, 'content': new_content}) + sanitized_messages.append({**msg, "content": new_content}) else: sanitized_messages.append(msg) - payload_to_log['messages'] = sanitized_messages + payload_to_log["messages"] = sanitized_messages log_entry += f"Request Payload:\n{json.dumps(payload_to_log, indent=2)}\n" - if response_data: log_entry += f"Response Data:\n{json.dumps(response_data, indent=2)}\n" - if error: log_entry += f"Error: {str(error)}\n" + if response_data: + log_entry += f"Response Data:\n{json.dumps(response_data, indent=2)}\n" + if error: + log_entry += f"Error: {str(error)}\n" log_entry += "---\n\n" - with open(log_file, "a", encoding="utf-8") as f: f.write(log_entry) - except Exception as log_e: print(f"!!! Failed to write to internal API log file {log_file}: {log_e}") + with open(log_file, "a", encoding="utf-8") as f: + f.write(log_entry) + except Exception as log_e: + print(f"!!! Failed to write to internal API log file {log_file}: {log_e}") + # Note: _create_human_like_mistake was removed as it wasn't used in the final on_message logic provided. # If needed, it can be added back here, ensuring it takes 'cog' if it needs personality traits. diff --git a/wheatley_bot.py b/wheatley_bot.py index 5e1b526..e614241 100644 --- a/wheatley_bot.py +++ b/wheatley_bot.py @@ -14,15 +14,18 @@ intents.message_content = True intents.members = True # Create bot instance with command prefix '%' -bot = commands.Bot(command_prefix='%', intents=intents) -bot.owner_id = int(os.getenv('OWNER_USER_ID')) +bot = commands.Bot(command_prefix="%", intents=intents) +bot.owner_id = int(os.getenv("OWNER_USER_ID")) + @bot.event async def on_ready(): - print(f'{bot.user.name} has connected to Discord!') - print(f'Bot ID: {bot.user.id}') + print(f"{bot.user.name} has connected to Discord!") + print(f"Bot ID: {bot.user.id}") # Set the bot's status - await bot.change_presence(activity=discord.Activity(type=discord.ActivityType.listening, name="%ai")) + await bot.change_presence( + activity=discord.Activity(type=discord.ActivityType.listening, name="%ai") + ) print("Bot status set to 'Listening to %ai'") # Sync commands @@ -33,23 +36,27 @@ async def on_ready(): except Exception as e: print(f"Failed to sync commands: {e}") import traceback + traceback.print_exc() + async def main(): """Main async function to load the wheatley cog and start the bot.""" # Check for required environment variables, prioritizing WHEATLEY token - TOKEN = os.getenv('DISCORD_TOKEN_WHEATLEY') + TOKEN = os.getenv("DISCORD_TOKEN_WHEATLEY") # If Wheatley token not found, try GURT token if not TOKEN: - TOKEN = os.getenv('DISCORD_TOKEN_GURT') + TOKEN = os.getenv("DISCORD_TOKEN_GURT") # If neither specific token found, try the main bot token if not TOKEN: - TOKEN = os.getenv('DISCORD_TOKEN') + TOKEN = os.getenv("DISCORD_TOKEN") if not TOKEN: - raise ValueError("No Discord token found. Make sure to set DISCORD_TOKEN_WHEATLEY, DISCORD_TOKEN_GURT, or DISCORD_TOKEN in your .env file.") + raise ValueError( + "No Discord token found. Make sure to set DISCORD_TOKEN_WHEATLEY, DISCORD_TOKEN_GURT, or DISCORD_TOKEN in your .env file." + ) # Note: Vertex AI authentication is handled by the library using ADC or GOOGLE_APPLICATION_CREDENTIALS. # No explicit API key check is needed here. Ensure GCP_PROJECT_ID and GCP_LOCATION are set in .env @@ -57,7 +64,7 @@ async def main(): try: async with bot: # List of cogs to load - Load WheatleyCog instead of GurtCog - cogs = ["wheatley.cog"] # Assuming profile updater is still desired + cogs = ["wheatley.cog"] # Assuming profile updater is still desired for cog in cogs: try: await bot.load_extension(cog) @@ -65,6 +72,7 @@ async def main(): except Exception as e: print(f"Error loading {cog}: {e}") import traceback + traceback.print_exc() # Start the bot @@ -72,8 +80,9 @@ async def main(): except Exception as e: print(f"Error starting Wheatley Bot: {e}") + # Run the main async function -if __name__ == '__main__': +if __name__ == "__main__": try: asyncio.run(main()) except KeyboardInterrupt: From c71537c8bffd05db901415ea13aa8560ed6641f9 Mon Sep 17 00:00:00 2001 From: Codex Date: Fri, 6 Jun 2025 03:42:15 +0000 Subject: [PATCH 07/11] Fix LoggingCog LogView nested container --- cogs/logging_cog.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/cogs/logging_cog.py b/cogs/logging_cog.py index 34fa4eb..56ad4c3 100644 --- a/cogs/logging_cog.py +++ b/cogs/logging_cog.py @@ -109,13 +109,13 @@ class LoggingCog(commands.Cog): self.container = ui.Container(accent_colour=color) self.add_item(self.container) + self.title = title self.description_display: Optional[ui.TextDisplay] = ( ui.TextDisplay(description) if description else None ) - # Header section is only used when an author is provided so we don't - # need a placeholder accessory. + # Header section is only used when an author is provided. if author is not None: self.header: Optional[ui.Section] = ui.Section( accessory=ui.Thumbnail(media=author.display_avatar.url) @@ -131,9 +131,9 @@ class LoggingCog(commands.Cog): if self.description_display: self.container.add_item(self.description_display) - # Container used for fields so they're inserted before the footer. - self.fields_container = ui.Container() - self.container.add_item(self.fields_container) + # Field displays are added directly to the main container before the + # footer separator. + self.fields: list[ui.TextDisplay] = [] self.separator = ui.Separator(spacing=discord.SeparatorSpacing.small) footer_text = footer or f"Bot ID: {bot.user.id}" + ( @@ -146,14 +146,20 @@ class LoggingCog(commands.Cog): # --- Compatibility helpers --- def add_field(self, name: str, value: str, inline: bool = False): """Append a bolded name/value line to the log view.""" - self.fields_container.add_item(ui.TextDisplay(f"**{name}:** {value}")) + field = ui.TextDisplay(f"**{name}:** {value}") + self.fields.append(field) + self.container.remove_item(self.separator) + self.container.remove_item(self.footer_display) + self.container.add_item(field) + self.container.add_item(self.separator) + self.container.add_item(self.footer_display) def set_footer(self, text: str): """Replace the footer text display.""" self.footer_display.content = text def set_author(self, name: str, icon_url: Optional[str] = None): - """Add or update the author information.""" + """Add or update the author information while keeping the title.""" if self.header is None: # Remove plain title/description displays and replace with a section. self.container.remove_item(self.title_display) @@ -162,18 +168,21 @@ class LoggingCog(commands.Cog): self.header = ui.Section( accessory=ui.Thumbnail(media=icon_url or "") ) + self.header.add_item(ui.TextDisplay(f"**{self.title}**")) self.header.add_item(ui.TextDisplay(name)) if self.description_display: self.header.add_item(self.description_display) self.container.add_item(self.header) - # Move to the beginning to mimic embed header placement self.container._children.remove(self.header) self.container._children.insert(0, self.header) else: self.header.clear_items() if icon_url: self.header.accessory = ui.Thumbnail(media=icon_url) + self.header.add_item(ui.TextDisplay(f"**{self.title}**")) self.header.add_item(ui.TextDisplay(name)) + if self.description_display: + self.header.add_item(self.description_display) def _user_display(self, user: Union[discord.Member, discord.User]) -> str: """Return display name, username and ID string for a user.""" display = user.display_name if isinstance(user, discord.Member) else user.name From 42d0fd3ae4be81fa1f70bd3a67b6b490712de7c8 Mon Sep 17 00:00:00 2001 From: Slipstream Date: Fri, 6 Jun 2025 03:47:58 +0000 Subject: [PATCH 08/11] Applying previous commit. --- cogs/logging_cog.py | 94 +++++++++++++++++++++++---------------------- 1 file changed, 48 insertions(+), 46 deletions(-) diff --git a/cogs/logging_cog.py b/cogs/logging_cog.py index ee9c7af..fa760e7 100644 --- a/cogs/logging_cog.py +++ b/cogs/logging_cog.py @@ -43,14 +43,6 @@ ALL_EVENT_KEYS = sorted([ # Add more audit keys if needed, e.g., "audit_stage_instance_create" ]) -class NullAccessory(ui.Button): - """Non-interactive accessory used as a placeholder.""" - - def __init__(self) -> None: - super().__init__(label="\u200b", disabled=True) - - def is_dispatchable(self) -> bool: # type: ignore[override] - return False class LoggingCog(commands.Cog): """Handles comprehensive server event logging via webhooks with granular toggling.""" @@ -65,7 +57,7 @@ class LoggingCog(commands.Cog): asyncio.create_task(self.start_audit_log_poller_when_ready()) # Keep this for initial start class LogView(ui.LayoutView): - """Simple view for log messages with helper methods.""" + """View for logging messages using Discord's layout UI.""" def __init__( self, @@ -75,66 +67,76 @@ class LoggingCog(commands.Cog): color: discord.Color, author: Optional[discord.abc.User], footer: Optional[str], - ): + ) -> None: super().__init__(timeout=None) + self.container = ui.Container(accent_colour=color) self.add_item(self.container) - self.header = ui.Section( - accessory=( - ui.Thumbnail(media=author.display_avatar.url) - if author - else NullAccessory() - ) + self.description_display: Optional[ui.TextDisplay] = ( + ui.TextDisplay(description) if description else None ) - self.header.add_item(ui.TextDisplay(f"**{title}**")) - if description: - self.header.add_item(ui.TextDisplay(description)) - self.container.add_item(self.header) - # Placeholder for future field sections. They are inserted before - # the separator when the first field is added. - self._field_sections: list[ui.Section] = [] + # Header section is only used when an author is provided so we don't + # need a placeholder accessory. + if author is not None: + self.header: Optional[ui.Section] = ui.Section( + accessory=ui.Thumbnail(media=author.display_avatar.url) + ) + self.header.add_item(ui.TextDisplay(f"**{title}**")) + if self.description_display: + self.header.add_item(self.description_display) + self.container.add_item(self.header) + else: + self.header = None + self.title_display = ui.TextDisplay(f"**{title}**") + self.container.add_item(self.title_display) + if self.description_display: + self.container.add_item(self.description_display) + + # Container used for fields so they're inserted before the footer. + self.fields_container = ui.Container() + self.container.add_item(self.fields_container) self.separator = ui.Separator(spacing=discord.SeparatorSpacing.small) - footer_text = footer or f"Bot ID: {bot.user.id}" + ( f" | User ID: {author.id}" if author else "" ) self.footer_display = ui.TextDisplay(footer_text) - self.container.add_item(self.separator) self.container.add_item(self.footer_display) # --- Compatibility helpers --- def add_field(self, name: str, value: str, inline: bool = False): - """Mimic Embed.add_field by appending a bolded name/value line.""" - if not self._field_sections or len(self._field_sections[-1].children) >= 3: - section = ui.Section(accessory=NullAccessory()) - self._insert_field_section(section) - self._field_sections.append(section) - self._field_sections[-1].add_item(ui.TextDisplay(f"**{name}:** {value}")) - - def _insert_field_section(self, section: ui.Section) -> None: - """Insert a field section before the footer separator.""" - self.container.remove_item(self.separator) - self.container.remove_item(self.footer_display) - self.container.add_item(section) - self.container.add_item(self.separator) - self.container.add_item(self.footer_display) + """Append a bolded name/value line to the log view.""" + self.fields_container.add_item(ui.TextDisplay(f"**{name}:** {value}")) def set_footer(self, text: str): - """Mimic Embed.set_footer by replacing the footer text display.""" + """Replace the footer text display.""" self.footer_display.content = text def set_author(self, name: str, icon_url: Optional[str] = None): - """Mimic Embed.set_author by adjusting the header section.""" - self.header.clear_items() - if icon_url: - self.header.accessory = ui.Thumbnail(media=icon_url) + """Add or update the author information.""" + if self.header is None: + # Remove plain title/description displays and replace with a section. + self.container.remove_item(self.title_display) + if self.description_display: + self.container.remove_item(self.description_display) + self.header = ui.Section( + accessory=ui.Thumbnail(media=icon_url or "") + ) + self.header.add_item(ui.TextDisplay(name)) + if self.description_display: + self.header.add_item(self.description_display) + self.container.add_item(self.header) + # Move to the beginning to mimic embed header placement + self.container._children.remove(self.header) + self.container._children.insert(0, self.header) else: - self.header.accessory = None - self.header.add_item(ui.TextDisplay(name)) + self.header.clear_items() + if icon_url: + self.header.accessory = ui.Thumbnail(media=icon_url) + self.header.add_item(ui.TextDisplay(name)) def _user_display(self, user: Union[discord.Member, discord.User]) -> str: """Return display name, username and ID string for a user.""" display = user.display_name if isinstance(user, discord.Member) else user.name From 207d294eb4e4a8ff2140dfe37095b49d8315a29e Mon Sep 17 00:00:00 2001 From: Codex Date: Fri, 6 Jun 2025 03:51:45 +0000 Subject: [PATCH 09/11] Refactor log views to use embeds --- cogs/logging_cog.py | 90 ++++++++------------------------------------- 1 file changed, 15 insertions(+), 75 deletions(-) diff --git a/cogs/logging_cog.py b/cogs/logging_cog.py index fa760e7..b7c00c5 100644 --- a/cogs/logging_cog.py +++ b/cogs/logging_cog.py @@ -1,6 +1,6 @@ import discord from discord.ext import commands, tasks -from discord import ui, AllowedMentions +from discord import AllowedMentions import datetime import asyncio import aiohttp # Added for webhook sending @@ -56,8 +56,8 @@ class LoggingCog(commands.Cog): else: asyncio.create_task(self.start_audit_log_poller_when_ready()) # Keep this for initial start - class LogView(ui.LayoutView): - """View for logging messages using Discord's layout UI.""" + class LogView(discord.Embed): + """Embed wrapper used for logging messages.""" def __init__( self, @@ -68,75 +68,15 @@ class LoggingCog(commands.Cog): author: Optional[discord.abc.User], footer: Optional[str], ) -> None: - super().__init__(timeout=None) + super().__init__(title=title, description=description, color=color) - self.container = ui.Container(accent_colour=color) - self.add_item(self.container) - - self.description_display: Optional[ui.TextDisplay] = ( - ui.TextDisplay(description) if description else None - ) - - # Header section is only used when an author is provided so we don't - # need a placeholder accessory. if author is not None: - self.header: Optional[ui.Section] = ui.Section( - accessory=ui.Thumbnail(media=author.display_avatar.url) - ) - self.header.add_item(ui.TextDisplay(f"**{title}**")) - if self.description_display: - self.header.add_item(self.description_display) - self.container.add_item(self.header) - else: - self.header = None - self.title_display = ui.TextDisplay(f"**{title}**") - self.container.add_item(self.title_display) - if self.description_display: - self.container.add_item(self.description_display) - - # Container used for fields so they're inserted before the footer. - self.fields_container = ui.Container() - self.container.add_item(self.fields_container) - - self.separator = ui.Separator(spacing=discord.SeparatorSpacing.small) + self.set_author(name=author.display_name, icon_url=author.display_avatar.url) footer_text = footer or f"Bot ID: {bot.user.id}" + ( f" | User ID: {author.id}" if author else "" ) - self.footer_display = ui.TextDisplay(footer_text) - self.container.add_item(self.separator) - self.container.add_item(self.footer_display) + self.set_footer(text=footer_text) - # --- Compatibility helpers --- - def add_field(self, name: str, value: str, inline: bool = False): - """Append a bolded name/value line to the log view.""" - self.fields_container.add_item(ui.TextDisplay(f"**{name}:** {value}")) - - def set_footer(self, text: str): - """Replace the footer text display.""" - self.footer_display.content = text - - def set_author(self, name: str, icon_url: Optional[str] = None): - """Add or update the author information.""" - if self.header is None: - # Remove plain title/description displays and replace with a section. - self.container.remove_item(self.title_display) - if self.description_display: - self.container.remove_item(self.description_display) - self.header = ui.Section( - accessory=ui.Thumbnail(media=icon_url or "") - ) - self.header.add_item(ui.TextDisplay(name)) - if self.description_display: - self.header.add_item(self.description_display) - self.container.add_item(self.header) - # Move to the beginning to mimic embed header placement - self.container._children.remove(self.header) - self.container._children.insert(0, self.header) - else: - self.header.clear_items() - if icon_url: - self.header.accessory = ui.Thumbnail(media=icon_url) - self.header.add_item(ui.TextDisplay(name)) def _user_display(self, user: Union[discord.Member, discord.User]) -> str: """Return display name, username and ID string for a user.""" display = user.display_name if isinstance(user, discord.Member) else user.name @@ -192,7 +132,7 @@ class LoggingCog(commands.Cog): await self.session.close() log.info("aiohttp ClientSession closed for LoggingCog.") - async def _send_log_embed(self, guild: discord.Guild, embed: ui.LayoutView) -> None: + async def _send_log_embed(self, guild: discord.Guild, embed: discord.Embed) -> None: """Sends the log view via the configured webhook for the guild.""" if not self.session or self.session.closed: log.error(f"aiohttp session not available or closed in LoggingCog for guild {guild.id}. Cannot send log.") @@ -211,7 +151,7 @@ class LoggingCog(commands.Cog): client=self.bot, ) await webhook.send( - view=embed, + embed=embed, username=f"{self.bot.user.name} Logs", avatar_url=self.bot.user.display_avatar.url, allowed_mentions=AllowedMentions.none(), @@ -240,13 +180,13 @@ class LoggingCog(commands.Cog): color: discord.Color = discord.Color.blue(), author: Optional[Union[discord.User, discord.Member]] = None, footer: Optional[str] = None, - ) -> ui.LayoutView: - """Creates a standardized log view.""" + ) -> discord.Embed: + """Creates a standardized log embed.""" return self.LogView(self.bot, title, description, color, author, footer) def _add_id_footer( self, - embed: ui.LayoutView, + embed: discord.Embed, obj: Union[ discord.Member, discord.User, @@ -261,10 +201,10 @@ class LoggingCog(commands.Cog): ) -> None: """Adds an ID to the footer text if possible.""" target_id = obj_id or (obj.id if obj else None) - if target_id and hasattr(embed, "footer_display"): - existing_footer = embed.footer_display.content or "" + if target_id: + existing_footer = embed.footer.text or "" separator = " | " if existing_footer else "" - embed.footer_display.content = f"{existing_footer}{separator}{id_name}: {target_id}" + embed.set_footer(text=f"{existing_footer}{separator}{id_name}: {target_id}", icon_url=embed.footer.icon_url) async def _check_log_enabled(self, guild_id: int, event_key: str) -> bool: """Checks if logging is enabled for a specific event key in a guild.""" @@ -382,7 +322,7 @@ class LoggingCog(commands.Cog): color=discord.Color.green(), ) await new_webhook.send( - view=test_view, + embed=test_view, username=webhook_name, avatar_url=self.bot.user.display_avatar.url, allowed_mentions=AllowedMentions.none(), From d032e9607e02436c7c599d4e9a64ab57b3766614 Mon Sep 17 00:00:00 2001 From: Codex Date: Fri, 6 Jun 2025 04:16:46 +0000 Subject: [PATCH 10/11] Add tests for parse_repo_url --- tests/test_git_monitor.py | 52 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 tests/test_git_monitor.py diff --git a/tests/test_git_monitor.py b/tests/test_git_monitor.py new file mode 100644 index 0000000..8a286c2 --- /dev/null +++ b/tests/test_git_monitor.py @@ -0,0 +1,52 @@ +import os +import sys +import pytest + +# Ensure the project root is on sys.path so we can import modules +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from cogs.git_monitor_cog import parse_repo_url + + +@pytest.mark.parametrize( + "url,expected", + [ + ("https://github.com/user/repo", ("github", "user/repo")), + ("http://github.com/user/repo", ("github", "user/repo")), + ("github.com/user/repo", ("github", "user/repo")), + ("https://www.github.com/user/repo/", ("github", "user/repo")), + ("https://github.com/user/repo.git", ("github", "user/repo")), + ("https://github.com/user-name/re.po", ("github", "user-name/re.po")), + ("https://gitlab.com/group/project", ("gitlab", "group/project")), + ( + "https://gitlab.com/group/subgroup/project", + ("gitlab", "group/subgroup/project"), + ), + ("gitlab.com/group/subgroup/project.git", ("gitlab", "group/subgroup/project")), + ( + "http://www.gitlab.com/group/subgroup/project/", + ("gitlab", "group/subgroup/project"), + ), + ], +) +def test_parse_repo_url_valid(url, expected): + assert parse_repo_url(url) == expected + + +@pytest.mark.parametrize( + "url", + [ + "https://github.com/", + "https://github.com/user", + "https://gitlab.com/", + "https://gitlab.com/group", + "ftp://github.com/user/repo", + "http:/github.com/user/repo", + "not a url", + "https://gitlabx.com/group/project", + "gitlab.com/group//project", + "github.com/user/repo/extra", + ], +) +def test_parse_repo_url_invalid(url): + assert parse_repo_url(url) == (None, None) From 9f26c0dab8dbc3a27069762e5bf8897d0b1791d2 Mon Sep 17 00:00:00 2001 From: Codex Date: Fri, 6 Jun 2025 04:21:34 +0000 Subject: [PATCH 11/11] Add async starboard DB tests with mocked pool --- tests/test_starboard_db.py | 95 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 tests/test_starboard_db.py diff --git a/tests/test_starboard_db.py b/tests/test_starboard_db.py new file mode 100644 index 0000000..502f726 --- /dev/null +++ b/tests/test_starboard_db.py @@ -0,0 +1,95 @@ +"""Tests for starboard database helper functions.""" + +# pylint: disable=wrong-import-position + +import os +import sys + +# Ensure project root is on sys.path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from unittest.mock import AsyncMock, patch + +import pytest # pylint: disable=import-error + +import settings_manager # pylint: disable=import-error + + +class DummyBot: + """Simple container for a pg_pool mock.""" + + def __init__(self, pg_pool): + self.pg_pool = pg_pool + + +@pytest.mark.asyncio +async def test_create_starboard_entry(): + """Verify create_starboard_entry executes expected queries.""" + + conn = AsyncMock() + acquire_cm = AsyncMock() + acquire_cm.__aenter__.return_value = conn + acquire_cm.__aexit__.return_value = None + + pg_pool = AsyncMock() + pg_pool.acquire.return_value = acquire_cm + + bot = DummyBot(pg_pool) + with patch.object(settings_manager, "get_bot_instance", return_value=bot): + result = await settings_manager.create_starboard_entry( + guild_id=1, + original_message_id=2, + original_channel_id=3, + starboard_message_id=4, + author_id=5, + star_count=6, + ) + + assert result is True + pg_pool.acquire.assert_called_once() + assert conn.execute.await_count == 2 + + +@pytest.mark.asyncio +async def test_update_starboard_entry(): + """Verify update_starboard_entry updates star count.""" + + conn = AsyncMock() + pg_pool = AsyncMock() + pg_pool.acquire = AsyncMock(return_value=conn) + pg_pool.release = AsyncMock() + + bot = DummyBot(pg_pool) + with patch.object(settings_manager, "get_bot_instance", return_value=bot): + result = await settings_manager.update_starboard_entry( + guild_id=1, original_message_id=2, star_count=3 + ) + + assert result is True + pg_pool.acquire.assert_called_once() + conn.execute.assert_awaited_once() + pg_pool.release.assert_called_once_with(conn) + + +@pytest.mark.asyncio +async def test_get_starboard_entry(): + """Verify get_starboard_entry fetches the row and returns a dict.""" + + entry_data = {"guild_id": 1, "original_message_id": 2} + conn = AsyncMock() + conn.fetchrow = AsyncMock(return_value=entry_data) + + acquire_cm = AsyncMock() + acquire_cm.__aenter__.return_value = conn + acquire_cm.__aexit__.return_value = None + + pg_pool = AsyncMock() + pg_pool.acquire.return_value = acquire_cm + + bot = DummyBot(pg_pool) + with patch.object(settings_manager, "get_bot_instance", return_value=bot): + result = await settings_manager.get_starboard_entry(1, 2) + + assert result == entry_data + pg_pool.acquire.assert_called_once() + conn.fetchrow.assert_awaited_once()