discordbot/discord_oauth.py
2025-06-05 21:31:06 -06:00

449 lines
17 KiB
Python

"""
Discord OAuth2 implementation for the Discord bot.
This module handles the OAuth2 flow for authenticating users with Discord,
including generating authorization URLs, exchanging codes for tokens,
and managing token storage and refresh.
"""
import os
import json
import time
import secrets
import hashlib
import base64
import aiohttp
import asyncio
import traceback
from typing import Dict, Optional, Tuple, Any
from urllib.parse import urlencode
from dotenv import load_dotenv
# Pull settings from a local .env file (if any) into the environment before
# any of the values below are read.
load_dotenv()

# --- OAuth2 client configuration --------------------------------------------
# This is a public client: only a client ID is sent, never a client secret.
CLIENT_ID = os.getenv("DISCORD_CLIENT_ID", "1360717457852993576")

# Base URL of the companion API service.
API_URL = os.getenv("API_URL", "https://slipstreamm.dev/api")

# "true", "1" or "yes" (case-insensitive) selects the API-hosted OAuth flow.
API_OAUTH_ENABLED = os.getenv("API_OAUTH_ENABLED", "true").lower() in {"true", "1", "yes"}

if not API_OAUTH_ENABLED:
    # Local mode: Discord redirects to the local OAuth server's callback route.
    # NOTE: OAUTH_HOST / OAUTH_PORT only exist in this mode.
    OAUTH_HOST = os.getenv("OAUTH_HOST", "localhost")
    OAUTH_PORT = int(os.getenv("OAUTH_PORT", "8080"))
    REDIRECT_URI = os.getenv(
        "DISCORD_REDIRECT_URI", f"http://{OAUTH_HOST}:{OAUTH_PORT}/oauth/callback"
    )
else:
    # API mode: Discord redirects to the API service's /auth endpoint.
    # NOTE: API_AUTH_ENDPOINT only exists in this mode.
    API_AUTH_ENDPOINT = f"{API_URL}/auth"
    REDIRECT_URI = os.getenv("DISCORD_REDIRECT_URI", API_AUTH_ENDPOINT)

# --- Discord REST endpoints --------------------------------------------------
API_ENDPOINT = "https://discord.com/api/v10"
TOKEN_URL, AUTH_URL, USER_URL = (
    f"{API_ENDPOINT}/oauth2/token",
    f"{API_ENDPOINT}/oauth2/authorize",
    f"{API_ENDPOINT}/users/@me",
)

# --- Token persistence -------------------------------------------------------
# One JSON file per user lives in a "tokens" folder beside this module; the
# folder is created at import time.
TOKEN_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tokens")
os.makedirs(TOKEN_DIR, exist_ok=True)

# --- In-memory PKCE state ----------------------------------------------------
# state -> {"code_verifier": ..., "redirect_uri": ...} as written by
# get_auth_url (exchange_code also tolerates a bare verifier string).
code_verifiers: Dict[str, Any] = {}
# state -> raw code verifier; kept so the verifier can be handed to the API
# service.
pending_code_verifiers: Dict[str, str] = {}
class OAuthError(Exception):
    """Raised when a step of the Discord OAuth2 flow fails (bad state, HTTP error, ...)."""
def generate_code_verifier() -> str:
    """Return a fresh random PKCE code verifier.

    The verifier is 64 bytes from the ``secrets`` CSPRNG, encoded as URL-safe
    base64 without padding, which gives 86 characters.
    """
    verifier = secrets.token_urlsafe(64)
    return verifier
def generate_code_challenge(verifier: str) -> str:
    """Derive the PKCE ``S256`` code challenge for *verifier*.

    The result is BASE64URL(SHA-256(verifier)) with the trailing ``=``
    padding stripped.
    """
    hasher = hashlib.sha256()
    hasher.update(verifier.encode())
    encoded = base64.urlsafe_b64encode(hasher.digest())
    return encoded.decode().rstrip("=")
def get_token_path(user_id: str) -> str:
    """Return the path of the JSON file in TOKEN_DIR that holds *user_id*'s token."""
    filename = f"{user_id}.json"
    return os.path.join(TOKEN_DIR, filename)
def save_token(user_id: str, token_data: Dict[str, Any]) -> None:
    """Persist *token_data* for *user_id* as JSON in the token directory.

    A ``saved_at`` Unix timestamp is added to *token_data* in place before it
    is written; ``is_token_expired`` reads that field to compute expiry.

    The write is atomic and owner-only. The JSON goes to a temporary file in
    the same directory (created by ``tempfile.mkstemp``, which on POSIX gives
    it mode 0o600), and that file is then moved over the target with
    ``os.replace``. As a result:

    * a crash or a serialization error never leaves a truncated token file
      behind;
    * the access and refresh tokens are not readable by other local users.

    Raises:
        OSError: if the token file cannot be written.
        TypeError: if *token_data* is not JSON-serializable (the existing
            token file is left untouched in that case).
    """
    import tempfile

    token_data["saved_at"] = int(time.time())
    path = get_token_path(user_id)
    fd, tmp_path = tempfile.mkstemp(
        dir=os.path.dirname(path), prefix=".token-", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(token_data, f)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave orphaned temp files behind when the write fails.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def load_token(user_id: str) -> Optional[Dict[str, Any]]:
    """Load *user_id*'s stored token data from disk.

    Returns:
        The decoded token dict. Returns ``None`` in any of these cases:
        no token file exists, the file cannot be read or decoded, or its
        JSON content is not an object.
    """
    try:
        # Open directly instead of checking os.path.exists first, so a file
        # deleted concurrently cannot turn into an unhandled error.
        with open(get_token_path(user_id), "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError covers a missing or unreadable file. ValueError covers
        # JSONDecodeError and UnicodeDecodeError (a corrupt file). All of
        # these are treated as "no token".
        return None
    # Callers treat the result as a mapping (they call .get on it), so other
    # JSON types such as lists or strings are rejected.
    return data if isinstance(data, dict) else None
def is_token_expired(token_data: Dict[str, Any]) -> bool:
    """Report whether *token_data* is expired or about to expire.

    A token counts as expired once the current time is past
    ``saved_at + expires_in - 300``, i.e. within five minutes of its real
    expiry. Empty or falsy data is always expired, and missing
    ``saved_at``/``expires_in`` fields default to 0.
    """
    if token_data:
        safety_margin = 300
        expiry_deadline = (
            token_data.get("saved_at", 0)
            + token_data.get("expires_in", 0)
            - safety_margin
        )
        return int(time.time()) > expiry_deadline
    return True
def delete_token(user_id: str) -> bool:
    """Remove *user_id*'s stored token file.

    Returns:
        True if a file was removed, False if there was no file to remove.

    Raises:
        OSError: for failures other than a missing file (e.g. permissions).
    """
    try:
        # Remove directly rather than checking os.path.exists first: the
        # file may disappear between the check and the removal.
        os.remove(get_token_path(user_id))
    except FileNotFoundError:
        return False
    return True
async def send_code_verifier_to_api(
    state: str, code_verifier: str, *, timeout: float = 10.0
) -> bool:
    """Register the PKCE code verifier for *state* with the API service.

    POSTs ``{"state": ..., "code_verifier": ...}`` as JSON to
    ``{API_URL}/code_verifier``. It makes up to 3 attempts, backing off
    exponentially (1 s, then 2 s) after a non-200 response, a connection
    error or a timeout.

    Args:
        state: OAuth2 ``state`` value the verifier belongs to.
        code_verifier: PKCE code verifier to register.
        timeout: Total seconds allowed per HTTP request. Without an explicit
            bound, aiohttp's default (5 minutes) applies, which can stall
            callers that wait on this coroutine synchronously.

    Returns:
        True as soon as the API answers HTTP 200, whatever the response body
        contains. False if every attempt failed. Unexpected errors are
        logged and reported as False rather than raised.
    """
    url = f"{API_URL}/code_verifier"
    payload = {"state": state, "code_verifier": code_verifier}
    max_retries = 3
    print(f"Sending code verifier for state {state} to API service: {url}")
    try:
        client_timeout = aiohttp.ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            for attempt in range(1, max_retries + 1):
                try:
                    async with session.post(url, json=payload) as resp:
                        # Read the raw body instead of calling resp.json(). A
                        # 200 means the verifier was accepted even if the body
                        # is not JSON, so parsing must not turn success into a
                        # retry.
                        body = (await resp.read()).decode("utf-8", errors="replace")
                        if resp.status == 200:
                            print(
                                f"Successfully sent code verifier to API service: {body}"
                            )
                            return True
                        print(
                            f"Failed to send code verifier to API service (attempt {attempt}/{max_retries}): {body}"
                        )
                except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
                    print(
                        f"Connection error when sending code verifier (attempt {attempt}/{max_retries}): {exc}"
                    )
                if attempt < max_retries:
                    wait_time = 2 ** (attempt - 1)
                    print(f"Retrying in {wait_time} seconds...")
                    await asyncio.sleep(wait_time)
    except Exception as e:
        print(f"Error sending code verifier to API service: {e}")
        traceback.print_exc()
    return False
# Strong references to fire-and-forget verifier uploads. The event loop only
# keeps weak references to tasks, so an unreferenced task may be garbage
# collected before it finishes.
_background_tasks: set = set()


def _dispatch_code_verifier(state: str, code_verifier: str) -> None:
    """Deliver *code_verifier* for *state* to the API service exactly once."""
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None

    if running_loop is not None:
        # Called from async code (e.g. a bot command handler). That loop is
        # busy running the caller, so blocking on it is impossible
        # (run_until_complete would raise). Schedule a single background task.
        task = running_loop.create_task(send_code_verifier_to_api(state, code_verifier))
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)
        print(f"Scheduled background send of code verifier for state {state} to API service")
        return

    # No running loop: block until delivery finishes so the verifier is
    # registered before the URL is handed out. A private loop is used so that
    # no loop the caller may have installed is replaced or closed.
    private_loop = asyncio.new_event_loop()
    try:
        send_success = private_loop.run_until_complete(
            send_code_verifier_to_api(state, code_verifier)
        )
    finally:
        private_loop.close()
    if send_success:
        print(f"Successfully sent code verifier for state {state} to API service")
    else:
        # send_code_verifier_to_api has already retried; with no running loop
        # there is nowhere to schedule a further background attempt.
        print(f"Warning: Failed to send code verifier for state {state} to API service")


def get_auth_url(state: str, code_verifier: str) -> str:
    """Build the Discord authorization URL for a PKCE (S256) OAuth2 flow.

    Side effects:
        * Stores ``{"code_verifier", "redirect_uri"}`` for *state* in
          ``code_verifiers``.
        * Stores the raw verifier in ``pending_code_verifiers``.
        * When API OAuth is enabled, delivers the verifier to the API service
          exactly once. This happens synchronously when no event loop is
          running in this thread, and as one background task otherwise: a
          synchronous function cannot wait on the loop that is executing it.
          Async callers that need a delivery guarantee should await
          ``send_code_verifier_to_api`` themselves.

    Args:
        state: Opaque OAuth2 ``state`` value identifying this flow.
        code_verifier: PKCE code verifier; its S256 challenge goes in the URL.

    Returns:
        The full ``/oauth2/authorize`` URL to send the user to.
    """
    code_challenge = generate_code_challenge(code_verifier)

    if API_OAUTH_ENABLED:
        # The redirect URI must exactly match one registered for the Discord
        # application, so no query parameters are appended to it.
        actual_redirect_uri = API_AUTH_ENDPOINT
        print(f"Using API OAuth with redirect URI: {actual_redirect_uri}")
    else:
        actual_redirect_uri = REDIRECT_URI
        print(f"Using local OAuth server with redirect URI: {actual_redirect_uri}")

    params = {
        "client_id": CLIENT_ID,
        "redirect_uri": actual_redirect_uri,
        "response_type": "code",
        "scope": "identify",
        "state": state,
        "code_challenge": code_challenge,
        "code_challenge_method": "S256",
        "prompt": "consent",
    }
    auth_url = f"{AUTH_URL}?{urlencode(params)}"

    code_verifiers[state] = {
        "code_verifier": code_verifier,
        "redirect_uri": actual_redirect_uri,
    }
    pending_code_verifiers[state] = code_verifier
    # The verifier is a secret, so not even a prefix of it is logged.
    print(f"Stored code verifier for state {state}")

    if API_OAUTH_ENABLED:
        # The API service needs the verifier to complete the PKCE exchange.
        _dispatch_code_verifier(state, code_verifier)
    return auth_url
async def _exchange_code_via_api(
    code: str, state: str, code_verifier: str
) -> Dict[str, Any]:
    """Have the API service redeem *code* and return the resulting token data.

    The API's ``/auth`` endpoint is asked to redeem the code. If its reply
    is JSON with ``token`` and ``user_id``, the token is saved and returned.
    Otherwise (e.g. an HTML success page) the token is fetched from
    ``{API_URL}/token``.

    Raises:
        OAuthError: if the API rejects the code or no valid token is obtained.
        aiohttp.ClientError: if the initial ``/auth`` request cannot be made.
    """
    async with aiohttp.ClientSession() as session:
        params = {"code": code, "state": state, "code_verifier": code_verifier}
        auth_url = f"{API_URL}/auth?{urlencode(params)}"
        # The full URL carries the code and verifier, so log only the endpoint.
        print(f"Redirecting to API service for token exchange: {API_URL}/auth (state {state})")
        async with session.get(auth_url) as resp:
            if resp.status != 200:
                error_text = await resp.text()
                print(f"Failed to exchange code with API service: {error_text}")
                raise OAuthError(f"Failed to exchange code with API service: {error_text}")
            print("Successfully exchanged code with API service")
            try:
                response_data = await resp.json()
            except (aiohttp.ClientError, ValueError) as e:
                # Non-JSON content type or malformed JSON (e.g. an HTML page).
                print(f"Error parsing response: {e}")
                response_data = None

        if isinstance(response_data, dict):
            token_data = response_data.get("token")
            user_id = response_data.get("user_id")
            if isinstance(token_data, dict) and user_id is not None:
                try:
                    save_token(user_id, token_data)
                    print(f"Successfully saved token for user {user_id}")
                except OSError as e:
                    # The token itself is valid; failing to persist it should
                    # not throw it away.
                    print(f"Failed to save token for user {user_id}: {e}")
                return token_data
        print("Response doesn't contain token data, will try to get it separately")

        headers = {"Accept": "application/json"}
        try:
            async with session.get(f"{API_URL}/token", headers=headers) as token_resp:
                if token_resp.status != 200:
                    error_text = await token_resp.text()
                    print(f"Failed to get token from API service: {error_text}")
                    raise OAuthError(f"Failed to get token from API service: {error_text}")
                token_data = await token_resp.json()
        except (aiohttp.ClientError, ValueError) as e:
            print(f"Error getting token from API service: {e}")
            raise OAuthError(f"Error getting token from API service: {e}") from e

        # Never hand back a fabricated token: callers would store and use it.
        if not isinstance(token_data, dict) or "access_token" not in token_data:
            raise OAuthError("API service didn't return a valid token")
        return token_data


async def _exchange_code_directly(
    code: str, redirect_uri: str, code_verifier: str
) -> Dict[str, Any]:
    """Redeem *code* at Discord's token endpoint as a public PKCE client.

    No client secret is sent. Returns Discord's decoded JSON token response.

    Raises:
        OAuthError: if Discord answers with a non-200 status.
    """
    data = {
        "client_id": CLIENT_ID,
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": redirect_uri,
        "code_verifier": code_verifier,
    }
    # The payload holds the authorization code and verifier, so it is not logged.
    print(f"Exchanging code for token with redirect URI: {redirect_uri}")
    async with aiohttp.ClientSession() as session:
        async with session.post(TOKEN_URL, data=data) as resp:
            if resp.status != 200:
                error_text = await resp.text()
                print(f"Failed to exchange code: {error_text}")
                raise OAuthError(f"Failed to exchange code: {error_text}")
            return await resp.json()


async def exchange_code(code: str, state: str) -> Dict[str, Any]:
    """Exchange an authorization *code* for an OAuth2 token.

    The PKCE verifier and redirect URI recorded for *state* are consumed
    (popped from ``code_verifiers``), so each state can be redeemed only
    once. Which path is used depends on the redirect URI:

    * If API OAuth is enabled and the redirect URI points at ``API_URL``,
      the API service performs the exchange. There, the token is saved to
      disk when the API returns it together with a user id.
    * Otherwise the code is redeemed directly at Discord's token endpoint,
      and the response is returned without being saved.

    Returns:
        The token data dict.

    Raises:
        OAuthError: unknown state, missing verifier, or a failed exchange.
    """
    state_data = code_verifiers.pop(state, None)
    if not state_data:
        raise OAuthError("Invalid state parameter")

    if isinstance(state_data, dict):
        code_verifier = state_data.get("code_verifier")
        redirect_uri = state_data.get("redirect_uri") or REDIRECT_URI
    else:
        # Legacy entries stored the bare verifier string.
        code_verifier = state_data
        redirect_uri = REDIRECT_URI

    if not code_verifier:
        raise OAuthError("Missing code verifier")

    if API_OAUTH_ENABLED and redirect_uri.startswith(API_URL):
        return await _exchange_code_via_api(code, state, code_verifier)
    return await _exchange_code_directly(code, redirect_uri, code_verifier)
async def refresh_token(refresh_token: str) -> Dict[str, Any]:
    """Exchange *refresh_token* for a new token set at Discord's token endpoint.

    The request is made as a public client (``client_id`` only, no secret).

    Returns:
        Discord's decoded JSON token response.

    Raises:
        OAuthError: if Discord answers with a non-200 status.
    """
    data = {
        "client_id": CLIENT_ID,
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
    }
    # The refresh token is a long-lived credential: never write it to logs.
    print("Refreshing access token")
    async with aiohttp.ClientSession() as session:
        async with session.post(TOKEN_URL, data=data) as resp:
            if resp.status != 200:
                error_text = await resp.text()
                print(f"Failed to refresh token: {error_text}")
                raise OAuthError(f"Failed to refresh token: {error_text}")
            return await resp.json()
async def get_user_info(access_token: str) -> Dict[str, Any]:
    """Fetch Discord's ``/users/@me`` profile for the holder of *access_token*.

    Raises:
        OAuthError: if Discord answers with a non-200 status.
    """
    auth_headers = {"Authorization": f"Bearer {access_token}"}
    async with aiohttp.ClientSession() as session, session.get(
        USER_URL, headers=auth_headers
    ) as resp:
        if resp.status == 200:
            return await resp.json()
        failure_body = await resp.text()
        raise OAuthError(f"Failed to get user info: {failure_body}")
async def get_token(user_id: str) -> Optional[str]:
    """Return a usable access token for *user_id*, refreshing it when needed.

    Returns:
        One of:

        * the stored access token, if it is not (nearly) expired;
        * otherwise, the access token obtained by refreshing, after the new
          token data has been saved back to disk;
        * ``None`` if no token is stored or there is no refresh token;
        * ``None`` if the refresh fails with OAuthError, in which case the
          stored token is also deleted.
    """
    token_data = load_token(user_id)
    if not token_data:
        return None

    if not is_token_expired(token_data):
        return token_data.get("access_token")

    stored_refresh = token_data.get("refresh_token")
    if not stored_refresh:
        return None

    try:
        new_token_data = await refresh_token(stored_refresh)
    except OAuthError:
        # The stored credentials are unusable; force a fresh login.
        delete_token(user_id)
        return None

    # RFC 6749 §6: the server MAY omit a new refresh token, in which case the
    # old one remains valid. Keep it so later refreshes still work.
    new_token_data.setdefault("refresh_token", stored_refresh)
    save_token(user_id, new_token_data)
    return new_token_data.get("access_token")
async def validate_token(token: str) -> Tuple[bool, Optional[str]]:
    """Check *token* against Discord by looking up the user it belongs to.

    Returns:
        ``(True, user_id)`` when the lookup succeeds; user_id is the profile's
        ``id`` field, or None if that field is absent.
        ``(False, None)`` when the lookup raises OAuthError.
    """
    try:
        profile = await get_user_info(token)
    except OAuthError:
        return False, None
    return True, profile.get("id")