# discordbot/api_service/database.py
# 2025-06-05 21:31:06 -06:00
#
# 481 lines
# 18 KiB
# Python

import os
import json
import datetime
from typing import Dict, List, Optional, Any
# Use absolute import for api_models
from api_service.api_models import (
Conversation,
UserSettings,
Message,
RoleCategoryPreset,
GuildRoleCategoryConfig,
UserCustomColorRole,
)
# ============= Database Class =============
class Database:
    """JSON-file-backed store for the API service.

    Every collection is held in memory as a dict and mirrored to its own
    JSON file inside ``data_dir``.  Each mutating method rewrites the whole
    file of the collection it touched.

    Files are written to a temporary sibling first and then moved over the
    real file with ``os.replace``, so a failed serialization never truncates
    previously saved data.  No locking is performed; concurrent writers are
    not coordinated.
    """

    def __init__(self, data_dir="data"):
        """Set up file paths, create ``data_dir`` and load existing data.

        Args:
            data_dir: Directory holding the JSON files; created if missing.
        """
        self.data_dir = data_dir
        self.conversations_file = os.path.join(data_dir, "conversations.json")
        self.settings_file = os.path.join(data_dir, "user_settings.json")
        self.tokens_file = os.path.join(data_dir, "user_tokens.json")
        self.role_presets_file = os.path.join(data_dir, "role_category_presets.json")
        self.guild_role_configs_file = os.path.join(
            data_dir, "guild_role_category_configs.json"
        )
        self.user_color_roles_file = os.path.join(
            data_dir, "user_custom_color_roles.json"
        )

        # Create data directory if it doesn't exist
        os.makedirs(data_dir, exist_ok=True)

        # In-memory storage
        # user_id -> conversation_id -> Conversation
        self.conversations: Dict[str, Dict[str, Conversation]] = {}
        # user_id -> UserSettings
        self.user_settings: Dict[str, UserSettings] = {}
        # user_id -> token_data
        self.user_tokens: Dict[str, Dict[str, Any]] = {}
        # preset_id -> RoleCategoryPreset
        self.role_category_presets: Dict[str, RoleCategoryPreset] = {}
        # guild_id -> List[GuildRoleCategoryConfig]
        self.guild_role_category_configs: Dict[str, List[GuildRoleCategoryConfig]] = {}
        # guild_id -> user_id -> UserCustomColorRole
        self.user_custom_color_roles: Dict[str, Dict[str, UserCustomColorRole]] = {}

        # Load data from files
        self.load_data()

    # ============= Internal JSON Helpers =============
    @staticmethod
    def _read_json_object(path: str) -> Dict[str, Any]:
        """Read ``path`` and return its top-level JSON object.

        Raises:
            OSError: If the file cannot be opened or read.
            ValueError: If the content is not valid JSON, or if the root of
                the document is not a JSON object.
        """
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, dict):
            raise ValueError(
                f"expected a JSON object at the top level, got {type(data).__name__}"
            )
        return data

    @staticmethod
    def _atomic_write_json(path: str, payload: Any) -> None:
        """Serialize ``payload`` to ``path`` without truncating it on failure.

        The JSON is written to ``<path>.tmp`` and moved into place with
        ``os.replace`` only after serialization succeeded.  On any error the
        temporary file is removed and the exception is re-raised, leaving
        the previous content of ``path`` untouched.
        """
        tmp_path = f"{path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(payload, f, indent=2, default=str, ensure_ascii=False)
            os.replace(tmp_path, path)
        except BaseException:
            # Best-effort cleanup; the original error is what matters.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def load_data(self):
        """Load all data from files"""
        self.load_conversations()
        self.load_user_settings()
        self.load_user_tokens()
        self.load_role_category_presets()
        self.load_guild_role_category_configs()
        self.load_user_custom_color_roles()

    def save_data(self):
        """Save all data to files"""
        self.save_conversations()
        self.save_all_user_settings()
        self.save_user_tokens()
        self.save_role_category_presets()
        self.save_guild_role_category_configs()
        self.save_user_custom_color_roles()

    # ============= Role Selector Data Load/Save Methods =============
    def load_role_category_presets(self):
        """Load role category presets from file.

        The in-memory presets are reset to an empty dict when the file is
        missing or cannot be read/validated.
        """
        if not os.path.exists(self.role_presets_file):
            self.role_category_presets = {}  # Initialize if file doesn't exist
            return
        try:
            data = self._read_json_object(self.role_presets_file)
            self.role_category_presets = {
                preset_id: RoleCategoryPreset.model_validate(preset_data)
                for preset_id, preset_data in data.items()
            }
            print(
                f"Loaded {len(self.role_category_presets)} role category presets."
            )
        except Exception as e:
            print(f"Error loading role category presets: {e}")
            self.role_category_presets = {}

    def save_role_category_presets(self):
        """Save role category presets to file; errors are printed, not raised."""
        try:
            serializable_data = {
                preset_id: preset.model_dump()
                for preset_id, preset in self.role_category_presets.items()
            }
            self._atomic_write_json(self.role_presets_file, serializable_data)
        except Exception as e:
            print(f"Error saving role category presets: {e}")

    def load_guild_role_category_configs(self):
        """Load guild role category configs from file.

        The in-memory configs are reset to an empty dict when the file is
        missing or cannot be read/validated.
        """
        if not os.path.exists(self.guild_role_configs_file):
            self.guild_role_category_configs = {}
            return
        try:
            data = self._read_json_object(self.guild_role_configs_file)
            self.guild_role_category_configs = {
                guild_id: [
                    GuildRoleCategoryConfig.model_validate(config_data)
                    for config_data in configs_list
                ]
                for guild_id, configs_list in data.items()
            }
            print(
                f"Loaded guild role category configs for {len(self.guild_role_category_configs)} guilds."
            )
        except Exception as e:
            print(f"Error loading guild role category configs: {e}")
            self.guild_role_category_configs = {}

    def save_guild_role_category_configs(self):
        """Save guild role category configs to file; errors are printed, not raised."""
        try:
            serializable_data = {
                guild_id: [config.model_dump() for config in configs_list]
                for guild_id, configs_list in self.guild_role_category_configs.items()
            }
            self._atomic_write_json(self.guild_role_configs_file, serializable_data)
        except Exception as e:
            print(f"Error saving guild role category configs: {e}")

    def load_user_custom_color_roles(self):
        """Load user custom color roles from file.

        The in-memory roles are reset to an empty dict when the file is
        missing or cannot be read/validated.
        """
        if not os.path.exists(self.user_color_roles_file):
            self.user_custom_color_roles = {}
            return
        try:
            data = self._read_json_object(self.user_color_roles_file)
            self.user_custom_color_roles = {
                guild_id: {
                    user_id: UserCustomColorRole.model_validate(role_data)
                    for user_id, role_data in user_roles.items()
                }
                for guild_id, user_roles in data.items()
            }
            print(
                f"Loaded user custom color roles for {len(self.user_custom_color_roles)} guilds."
            )
        except Exception as e:
            print(f"Error loading user custom color roles: {e}")
            self.user_custom_color_roles = {}

    def save_user_custom_color_roles(self):
        """Save user custom color roles to file; errors are printed, not raised."""
        try:
            serializable_data = {
                guild_id: {
                    user_id: role.model_dump() for user_id, role in user_roles.items()
                }
                for guild_id, user_roles in self.user_custom_color_roles.items()
            }
            self._atomic_write_json(self.user_color_roles_file, serializable_data)
        except Exception as e:
            print(f"Error saving user custom color roles: {e}")

    # ============= Existing Data Load/Save Methods =============
    def load_conversations(self):
        """Load conversations from file.

        A missing file leaves the current in-memory conversations untouched;
        an unreadable or invalid file resets them to an empty dict.
        """
        if not os.path.exists(self.conversations_file):
            return
        try:
            data = self._read_json_object(self.conversations_file)
            # Convert to Conversation objects
            self.conversations = {
                user_id: {
                    conv_id: Conversation.model_validate(conv_data)
                    for conv_id, conv_data in user_convs.items()
                }
                for user_id, user_convs in data.items()
            }
            print(f"Loaded conversations for {len(self.conversations)} users")
        except Exception as e:
            print(f"Error loading conversations: {e}")
            self.conversations = {}

    def save_conversations(self):
        """Save conversations to file; errors are printed, not raised."""
        try:
            # Convert to JSON-serializable format
            serializable_data = {
                user_id: {
                    conv_id: conv.model_dump() for conv_id, conv in user_convs.items()
                }
                for user_id, user_convs in self.conversations.items()
            }
            self._atomic_write_json(self.conversations_file, serializable_data)
        except Exception as e:
            print(f"Error saving conversations: {e}")

    def load_user_settings(self):
        """Load user settings from file.

        A missing file leaves the current in-memory settings untouched; an
        unreadable or invalid file resets them to an empty dict.
        """
        if not os.path.exists(self.settings_file):
            return
        try:
            data = self._read_json_object(self.settings_file)
            # Convert to UserSettings objects
            self.user_settings = {
                user_id: UserSettings.model_validate(settings_data)
                for user_id, settings_data in data.items()
            }
            print(f"Loaded settings for {len(self.user_settings)} users")
        except Exception as e:
            print(f"Error loading user settings: {e}")
            self.user_settings = {}

    def save_all_user_settings(self):
        """Save all user settings to file; errors are printed, not raised."""
        try:
            # Convert to JSON-serializable format
            serializable_data = {
                user_id: settings.model_dump()
                for user_id, settings in self.user_settings.items()
            }
            self._atomic_write_json(self.settings_file, serializable_data)
        except Exception as e:
            print(f"Error saving user settings: {e}")

    # ============= Conversation Methods =============
    def get_user_conversations(self, user_id: str) -> List[Conversation]:
        """Get all conversations for a user (empty list if the user has none)."""
        return list(self.conversations.get(user_id, {}).values())

    def get_conversation(
        self, user_id: str, conversation_id: str
    ) -> Optional[Conversation]:
        """Get a specific conversation for a user, or ``None`` if absent."""
        return self.conversations.get(user_id, {}).get(conversation_id)

    def save_conversation(
        self, user_id: str, conversation: Conversation
    ) -> Conversation:
        """Store a conversation under ``user_id`` and persist all conversations.

        ``conversation.updated_at`` is overwritten with the current naive
        local time before storing.  Returns the same conversation object.
        """
        # Update the timestamp
        conversation.updated_at = datetime.datetime.now()
        # Initialize user's conversations dict if it doesn't exist
        if user_id not in self.conversations:
            self.conversations[user_id] = {}
        # Save the conversation
        self.conversations[user_id][conversation.id] = conversation
        # Save to disk
        self.save_conversations()
        return conversation

    def delete_conversation(self, user_id: str, conversation_id: str) -> bool:
        """Delete a conversation for a user; return whether it existed."""
        if (
            user_id in self.conversations
            and conversation_id in self.conversations[user_id]
        ):
            del self.conversations[user_id][conversation_id]
            self.save_conversations()
            return True
        return False

    # ============= Role Category Preset Methods =============
    def get_role_category_preset(self, preset_id: str) -> Optional[RoleCategoryPreset]:
        """Get a specific role category preset by ID."""
        return self.role_category_presets.get(preset_id)

    def get_all_role_category_presets(self) -> List[RoleCategoryPreset]:
        """Get all role category presets."""
        return list(self.role_category_presets.values())

    def save_role_category_preset(
        self, preset: RoleCategoryPreset
    ) -> RoleCategoryPreset:
        """Store a preset under ``preset.id`` and persist all presets."""
        self.role_category_presets[preset.id] = preset
        self.save_role_category_presets()
        return preset

    def delete_role_category_preset(self, preset_id: str) -> bool:
        """Delete a role category preset; return whether it existed."""
        if preset_id in self.role_category_presets:
            del self.role_category_presets[preset_id]
            self.save_role_category_presets()
            return True
        return False

    # ============= Guild Role Category Config Methods =============
    def get_guild_role_category_configs(
        self, guild_id: str
    ) -> List[GuildRoleCategoryConfig]:
        """Get all role category configurations for a specific guild.

        For a known guild this is the stored list itself (not a copy); for
        an unknown guild a new empty list is returned.
        """
        return self.guild_role_category_configs.get(guild_id, [])

    def get_all_guild_role_category_configs(
        self,
    ) -> Dict[str, List[GuildRoleCategoryConfig]]:
        """Get all role category configurations for all guilds.

        Returns the internal mapping itself, not a copy.
        """
        return self.guild_role_category_configs

    def get_guild_role_category_config(
        self, guild_id: str, category_id: str
    ) -> Optional[GuildRoleCategoryConfig]:
        """Get a specific role category configuration for a guild."""
        for config in self.get_guild_role_category_configs(guild_id):
            if config.category_id == category_id:
                return config
        return None

    def save_guild_role_category_config(
        self, config: GuildRoleCategoryConfig
    ) -> GuildRoleCategoryConfig:
        """Insert or replace a guild's role category configuration.

        Any stored config with the same ``category_id`` is dropped and the
        new one is appended at the end of the guild's list.
        """
        guild_id = config.guild_id
        if guild_id not in self.guild_role_category_configs:
            self.guild_role_category_configs[guild_id] = []
        # Remove existing config with the same category_id if it exists, then add the new/updated one
        self.guild_role_category_configs[guild_id] = [
            c
            for c in self.guild_role_category_configs[guild_id]
            if c.category_id != config.category_id
        ]
        self.guild_role_category_configs[guild_id].append(config)
        self.save_guild_role_category_configs()
        return config

    def delete_guild_role_category_config(
        self, guild_id: str, category_id: str
    ) -> bool:
        """Delete a specific role category configuration for a guild.

        Returns ``True`` only if at least one matching config was removed.
        """
        if guild_id in self.guild_role_category_configs:
            initial_len = len(self.guild_role_category_configs[guild_id])
            self.guild_role_category_configs[guild_id] = [
                c
                for c in self.guild_role_category_configs[guild_id]
                if c.category_id != category_id
            ]
            if len(self.guild_role_category_configs[guild_id]) < initial_len:
                self.save_guild_role_category_configs()
                return True
        return False

    # ============= User Custom Color Role Methods =============
    def get_user_custom_color_role(
        self, guild_id: str, user_id: str
    ) -> Optional[UserCustomColorRole]:
        """Get a user's custom color role in a specific guild."""
        return self.user_custom_color_roles.get(guild_id, {}).get(user_id)

    def save_user_custom_color_role(
        self, color_role: UserCustomColorRole
    ) -> UserCustomColorRole:
        """Store a user's custom color role and persist all color roles.

        ``color_role.last_updated`` is overwritten with the current naive
        local time before storing.
        """
        guild_id = color_role.guild_id
        user_id = color_role.user_id
        if guild_id not in self.user_custom_color_roles:
            self.user_custom_color_roles[guild_id] = {}
        color_role.last_updated = datetime.datetime.now()
        self.user_custom_color_roles[guild_id][user_id] = color_role
        self.save_user_custom_color_roles()
        return color_role

    def delete_user_custom_color_role(self, guild_id: str, user_id: str) -> bool:
        """Delete a user's custom color role in a guild; return whether it existed."""
        if (
            guild_id in self.user_custom_color_roles
            and user_id in self.user_custom_color_roles[guild_id]
        ):
            del self.user_custom_color_roles[guild_id][user_id]
            self.save_user_custom_color_roles()
            return True
        return False

    # ============= User Settings Methods =============
    def get_user_settings(self, user_id: str) -> UserSettings:
        """Get settings for a user, creating default settings if they don't exist.

        Newly created defaults are kept in memory only; nothing is written
        to disk by this method.
        """
        if user_id not in self.user_settings:
            self.user_settings[user_id] = UserSettings()
        return self.user_settings[user_id]

    def save_user_settings(self, user_id: str, settings: UserSettings) -> UserSettings:
        """Store settings for a user and persist all user settings.

        ``settings.last_updated`` is overwritten with the current naive
        local time before storing.
        """
        # Update the timestamp
        settings.last_updated = datetime.datetime.now()
        # Save the settings
        self.user_settings[user_id] = settings
        # Save to disk
        self.save_all_user_settings()
        return settings

    # ============= User Tokens Methods =============
    def load_user_tokens(self):
        """Load user tokens from file.

        A missing file leaves the current in-memory tokens untouched.  An
        unreadable file, invalid JSON, or a document whose root is not a
        JSON object resets them to an empty dict, so the lookup methods
        always operate on a dict.
        """
        if not os.path.exists(self.tokens_file):
            return
        try:
            self.user_tokens = self._read_json_object(self.tokens_file)
            print(f"Loaded tokens for {len(self.user_tokens)} users")
        except Exception as e:
            print(f"Error loading user tokens: {e}")
            self.user_tokens = {}

    def save_user_tokens(self):
        """Save user tokens to file; errors are printed, not raised."""
        try:
            self._atomic_write_json(self.tokens_file, self.user_tokens)
        except Exception as e:
            print(f"Error saving user tokens: {e}")

    def get_user_token(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Get token data for a user, or ``None`` if nothing is stored."""
        return self.user_tokens.get(user_id)

    def save_user_token(
        self, user_id: str, token_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Store token data for a user and persist all tokens.

        The passed dict is modified in place: a ``"saved_at"`` key holding
        the current naive local time in ISO format is added (or replaced).
        The same dict is returned.
        """
        # Add the time when the token was saved
        token_data["saved_at"] = datetime.datetime.now().isoformat()
        # Save the token data
        self.user_tokens[user_id] = token_data
        # Save to disk
        self.save_user_tokens()
        return token_data

    def delete_user_token(self, user_id: str) -> bool:
        """Delete token data for a user; return whether it existed."""
        if user_id in self.user_tokens:
            del self.user_tokens[user_id]
            self.save_user_tokens()
            return True
        return False