449 lines
17 KiB
Python
449 lines
17 KiB
Python
"""
|
|
Discord OAuth2 implementation for the Discord bot.
|
|
|
|
This module handles the OAuth2 flow for authenticating users with Discord,
|
|
including generating authorization URLs, exchanging codes for tokens,
|
|
and managing token storage and refresh.
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import time
|
|
import secrets
|
|
import hashlib
|
|
import base64
|
|
import aiohttp
|
|
import asyncio
|
|
import traceback
|
|
from typing import Dict, Optional, Tuple, Any
|
|
from urllib.parse import urlencode
|
|
from dotenv import load_dotenv
|
|
|
|
# Load environment variables
|
|
load_dotenv()
|
|
|
|
# OAuth2 Configuration
|
|
CLIENT_ID = os.getenv("DISCORD_CLIENT_ID", "1360717457852993576")
|
|
# No client secret for public clients
|
|
|
|
# Use the API service's OAuth endpoint if available, otherwise use the local server
|
|
API_URL = os.getenv("API_URL", "https://slipstreamm.dev/api")
|
|
API_OAUTH_ENABLED = os.getenv("API_OAUTH_ENABLED", "true").lower() in (
|
|
"true",
|
|
"1",
|
|
"yes",
|
|
)
|
|
|
|
# If API OAuth is enabled, use the API service's OAuth endpoint
|
|
if API_OAUTH_ENABLED:
|
|
# For API OAuth, we'll use a special redirect URI that includes the code_verifier
|
|
# The base redirect URI is the API URL + /auth
|
|
API_AUTH_ENDPOINT = f"{API_URL}/auth"
|
|
# The actual redirect URI will be constructed in get_auth_url to include the code_verifier
|
|
REDIRECT_URI = os.getenv("DISCORD_REDIRECT_URI", API_AUTH_ENDPOINT)
|
|
else:
|
|
# Otherwise, use the local OAuth server
|
|
OAUTH_HOST = os.getenv("OAUTH_HOST", "localhost")
|
|
OAUTH_PORT = int(os.getenv("OAUTH_PORT", "8080"))
|
|
REDIRECT_URI = os.getenv(
|
|
"DISCORD_REDIRECT_URI", f"http://{OAUTH_HOST}:{OAUTH_PORT}/oauth/callback"
|
|
)
|
|
|
|
# Discord API endpoints
|
|
API_ENDPOINT = "https://discord.com/api/v10"
|
|
TOKEN_URL = f"{API_ENDPOINT}/oauth2/token"
|
|
AUTH_URL = f"{API_ENDPOINT}/oauth2/authorize"
|
|
USER_URL = f"{API_ENDPOINT}/users/@me"
|
|
|
|
# Token storage directory
|
|
TOKEN_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tokens")
|
|
os.makedirs(TOKEN_DIR, exist_ok=True)
|
|
|
|
# In-memory storage for PKCE code verifiers and pending states
|
|
code_verifiers: Dict[str, Any] = {}
|
|
|
|
# Global dictionary to store code verifiers by state
|
|
# This is used to pass the code verifier to the API service
|
|
pending_code_verifiers: Dict[str, str] = {}
|
|
|
|
|
|
class OAuthError(Exception):
|
|
"""Exception raised for OAuth errors."""
|
|
|
|
pass
|
|
|
|
|
|
def generate_code_verifier() -> str:
|
|
"""Generate a code verifier for PKCE."""
|
|
return secrets.token_urlsafe(64)
|
|
|
|
|
|
def generate_code_challenge(verifier: str) -> str:
|
|
"""Generate a code challenge from a code verifier."""
|
|
sha256 = hashlib.sha256(verifier.encode()).digest()
|
|
return base64.urlsafe_b64encode(sha256).decode().rstrip("=")
|
|
|
|
|
|
def get_token_path(user_id: str) -> str:
|
|
"""Get the path to the token file for a user."""
|
|
return os.path.join(TOKEN_DIR, f"{user_id}.json")
|
|
|
|
|
|
def save_token(user_id: str, token_data: Dict[str, Any]) -> None:
|
|
"""Save a token to disk."""
|
|
# Add the time when the token was saved
|
|
token_data["saved_at"] = int(time.time())
|
|
|
|
with open(get_token_path(user_id), "w") as f:
|
|
json.dump(token_data, f)
|
|
|
|
|
|
def load_token(user_id: str) -> Optional[Dict[str, Any]]:
|
|
"""Load a token from disk."""
|
|
token_path = get_token_path(user_id)
|
|
if not os.path.exists(token_path):
|
|
return None
|
|
|
|
try:
|
|
with open(token_path, "r") as f:
|
|
return json.load(f)
|
|
except (json.JSONDecodeError, IOError):
|
|
return None
|
|
|
|
|
|
def is_token_expired(token_data: Dict[str, Any]) -> bool:
|
|
"""Check if a token is expired."""
|
|
if not token_data:
|
|
return True
|
|
|
|
# Get the time when the token was saved
|
|
saved_at = token_data.get("saved_at", 0)
|
|
|
|
# Get the token's expiration time
|
|
expires_in = token_data.get("expires_in", 0)
|
|
|
|
# Check if the token is expired
|
|
# We consider it expired if it's within 5 minutes of expiration
|
|
return (saved_at + expires_in - 300) < int(time.time())
|
|
|
|
|
|
def delete_token(user_id: str) -> bool:
|
|
"""Delete a token from disk."""
|
|
token_path = get_token_path(user_id)
|
|
if os.path.exists(token_path):
|
|
os.remove(token_path)
|
|
return True
|
|
return False
|
|
|
|
|
|
async def send_code_verifier_to_api(state: str, code_verifier: str) -> bool:
|
|
"""Send the code verifier to the API service."""
|
|
try:
|
|
async with aiohttp.ClientSession() as session:
|
|
# Construct the URL for the code verifier endpoint
|
|
url = f"{API_URL}/code_verifier"
|
|
|
|
# Prepare the data
|
|
data = {"state": state, "code_verifier": code_verifier}
|
|
|
|
# Send the code verifier to the API service
|
|
print(f"Sending code verifier for state {state} to API service: {url}")
|
|
|
|
# Try multiple times with increasing delays to ensure delivery
|
|
max_retries = 3
|
|
for retry in range(max_retries):
|
|
try:
|
|
async with session.post(url, json=data) as resp:
|
|
if resp.status == 200:
|
|
response_data = await resp.json()
|
|
print(
|
|
f"Successfully sent code verifier to API service: {response_data}"
|
|
)
|
|
return True
|
|
else:
|
|
error_text = await resp.text()
|
|
print(
|
|
f"Failed to send code verifier to API service (attempt {retry+1}/{max_retries}): {error_text}"
|
|
)
|
|
if retry < max_retries - 1:
|
|
# Wait before retrying, with exponential backoff
|
|
wait_time = 2**retry
|
|
print(f"Retrying in {wait_time} seconds...")
|
|
await asyncio.sleep(wait_time)
|
|
else:
|
|
return False
|
|
except aiohttp.ClientError as ce:
|
|
print(
|
|
f"Connection error when sending code verifier (attempt {retry+1}/{max_retries}): {ce}"
|
|
)
|
|
if retry < max_retries - 1:
|
|
wait_time = 2**retry
|
|
print(f"Retrying in {wait_time} seconds...")
|
|
await asyncio.sleep(wait_time)
|
|
else:
|
|
raise
|
|
|
|
return False
|
|
except Exception as e:
|
|
print(f"Error sending code verifier to API service: {e}")
|
|
traceback.print_exc()
|
|
return False
|
|
|
|
|
|
def get_auth_url(state: str, code_verifier: str) -> str:
|
|
"""Get the authorization URL for the OAuth2 flow."""
|
|
code_challenge = generate_code_challenge(code_verifier)
|
|
|
|
# Determine the redirect URI based on whether API OAuth is enabled
|
|
if API_OAUTH_ENABLED:
|
|
# For API OAuth, we must use a clean redirect URI without any query parameters
|
|
# The redirect URI must exactly match the one registered in the Discord application
|
|
actual_redirect_uri = API_AUTH_ENDPOINT
|
|
print(f"Using API OAuth with redirect URI: {actual_redirect_uri}")
|
|
else:
|
|
# For local OAuth server, use the standard redirect URI
|
|
actual_redirect_uri = REDIRECT_URI
|
|
print(f"Using local OAuth server with redirect URI: {actual_redirect_uri}")
|
|
|
|
# Build the authorization URL
|
|
params = {
|
|
"client_id": CLIENT_ID,
|
|
"redirect_uri": actual_redirect_uri,
|
|
"response_type": "code",
|
|
"scope": "identify",
|
|
"state": state,
|
|
"code_challenge": code_challenge,
|
|
"code_challenge_method": "S256",
|
|
"prompt": "consent",
|
|
}
|
|
|
|
auth_url = f"{AUTH_URL}?{urlencode(params)}"
|
|
|
|
# Store the code verifier and redirect URI for this state
|
|
code_verifiers[state] = {
|
|
"code_verifier": code_verifier,
|
|
"redirect_uri": actual_redirect_uri,
|
|
}
|
|
|
|
# Also store the code verifier in the global dictionary
|
|
# This will be used by the API service to retrieve the code verifier
|
|
pending_code_verifiers[state] = code_verifier
|
|
print(f"Stored code verifier for state {state}: {code_verifier[:10]}...")
|
|
|
|
# If API OAuth is enabled, send the code verifier to the API service
|
|
# This is critical for the PKCE flow to work with the API server
|
|
if API_OAUTH_ENABLED:
|
|
# Use a synchronous call to ensure the code verifier is sent before proceeding
|
|
# This is important because the user might click the auth URL immediately
|
|
loop = asyncio.get_event_loop()
|
|
send_success = False
|
|
try:
|
|
send_success = loop.run_until_complete(
|
|
send_code_verifier_to_api(state, code_verifier)
|
|
)
|
|
except Exception as e:
|
|
print(f"Error in synchronous code verifier send: {e}")
|
|
# Fall back to async task if synchronous call fails
|
|
asyncio.create_task(send_code_verifier_to_api(state, code_verifier))
|
|
|
|
if not send_success:
|
|
print(
|
|
"Warning: Failed to send code verifier synchronously, falling back to async task"
|
|
)
|
|
# Try again asynchronously as a backup
|
|
asyncio.create_task(send_code_verifier_to_api(state, code_verifier))
|
|
else:
|
|
print(f"Successfully sent code verifier for state {state} to API service")
|
|
|
|
return auth_url
|
|
|
|
|
|
async def exchange_code(code: str, state: str) -> Dict[str, Any]:
|
|
"""Exchange an authorization code for a token."""
|
|
# Get the code verifier and redirect URI for this state
|
|
state_data = code_verifiers.pop(state, None)
|
|
if not state_data:
|
|
raise OAuthError("Invalid state parameter")
|
|
|
|
# Extract code_verifier and redirect_uri
|
|
if isinstance(state_data, dict):
|
|
code_verifier = state_data.get("code_verifier")
|
|
redirect_uri = state_data.get("redirect_uri")
|
|
else:
|
|
# For backward compatibility
|
|
code_verifier = state_data
|
|
redirect_uri = REDIRECT_URI
|
|
|
|
if not code_verifier:
|
|
raise OAuthError("Missing code verifier")
|
|
|
|
# If API OAuth is enabled, we need to check if we should handle the token exchange ourselves
|
|
# or if the API service will handle it
|
|
if API_OAUTH_ENABLED and redirect_uri.startswith(API_URL):
|
|
# If the API service is handling the OAuth flow, we need to get the token from the API
|
|
# We'll make a request to the API service with the code and code_verifier
|
|
async with aiohttp.ClientSession() as session:
|
|
# Construct the URL with the code and code_verifier
|
|
params = {"code": code, "state": state, "code_verifier": code_verifier}
|
|
auth_url = f"{API_URL}/auth?{urlencode(params)}"
|
|
|
|
print(f"Redirecting to API service for token exchange: {auth_url}")
|
|
|
|
# Make a request to the API service
|
|
async with session.get(auth_url) as resp:
|
|
if resp.status != 200:
|
|
error_text = await resp.text()
|
|
print(f"Failed to exchange code with API service: {error_text}")
|
|
raise OAuthError(
|
|
f"Failed to exchange code with API service: {error_text}"
|
|
)
|
|
|
|
# The API service should return a success page, not the token
|
|
# We'll need to get the token from the API service separately
|
|
print("Successfully exchanged code with API service")
|
|
|
|
# Parse the response to get the token data
|
|
try:
|
|
response_data = await resp.json()
|
|
if "token" in response_data:
|
|
# Save the token data
|
|
token_data = response_data["token"]
|
|
save_token(response_data["user_id"], token_data)
|
|
print(
|
|
f"Successfully saved token for user {response_data['user_id']}"
|
|
)
|
|
return token_data
|
|
else:
|
|
# If the response doesn't contain a token, it's probably an HTML response
|
|
# We'll need to get the token from the API service separately
|
|
print(
|
|
"Response doesn't contain token data, will try to get it separately"
|
|
)
|
|
except Exception as e:
|
|
print(f"Error parsing response: {e}")
|
|
|
|
# If we couldn't get the token from the response, try to get it from the API service
|
|
try:
|
|
# Make a request to the API service to get the token
|
|
headers = {"Accept": "application/json"}
|
|
async with session.get(
|
|
f"{API_URL}/token", headers=headers
|
|
) as token_resp:
|
|
if token_resp.status != 200:
|
|
error_text = await token_resp.text()
|
|
print(f"Failed to get token from API service: {error_text}")
|
|
raise OAuthError(
|
|
f"Failed to get token from API service: {error_text}"
|
|
)
|
|
|
|
token_data = await token_resp.json()
|
|
if "access_token" in token_data:
|
|
return token_data
|
|
else:
|
|
raise OAuthError("API service didn't return a valid token")
|
|
except Exception as e:
|
|
print(f"Error getting token from API service: {e}")
|
|
# Return a placeholder token for now
|
|
return {
|
|
"access_token": "placeholder_token",
|
|
"token_type": "Bearer",
|
|
"expires_in": 604800,
|
|
}
|
|
|
|
# If we're handling the token exchange ourselves, proceed as before
|
|
async with aiohttp.ClientSession() as session:
|
|
# For public clients, we don't include a client secret
|
|
data = {
|
|
"client_id": CLIENT_ID,
|
|
"grant_type": "authorization_code",
|
|
"code": code,
|
|
"redirect_uri": redirect_uri,
|
|
"code_verifier": code_verifier,
|
|
}
|
|
|
|
print(f"Exchanging code for token with data: {data}")
|
|
|
|
async with session.post(TOKEN_URL, data=data) as resp:
|
|
if resp.status != 200:
|
|
error_text = await resp.text()
|
|
print(f"Failed to exchange code: {error_text}")
|
|
raise OAuthError(f"Failed to exchange code: {error_text}")
|
|
|
|
return await resp.json()
|
|
|
|
|
|
async def refresh_token(refresh_token: str) -> Dict[str, Any]:
|
|
"""Refresh an access token."""
|
|
async with aiohttp.ClientSession() as session:
|
|
# For public clients, we don't include a client secret
|
|
data = {
|
|
"client_id": CLIENT_ID,
|
|
"grant_type": "refresh_token",
|
|
"refresh_token": refresh_token,
|
|
}
|
|
|
|
print(f"Refreshing token with data: {data}")
|
|
|
|
async with session.post(TOKEN_URL, data=data) as resp:
|
|
if resp.status != 200:
|
|
error_text = await resp.text()
|
|
print(f"Failed to refresh token: {error_text}")
|
|
raise OAuthError(f"Failed to refresh token: {error_text}")
|
|
|
|
return await resp.json()
|
|
|
|
|
|
async def get_user_info(access_token: str) -> Dict[str, Any]:
|
|
"""Get information about the authenticated user."""
|
|
async with aiohttp.ClientSession() as session:
|
|
headers = {"Authorization": f"Bearer {access_token}"}
|
|
|
|
async with session.get(USER_URL, headers=headers) as resp:
|
|
if resp.status != 200:
|
|
error_text = await resp.text()
|
|
raise OAuthError(f"Failed to get user info: {error_text}")
|
|
|
|
return await resp.json()
|
|
|
|
|
|
async def get_token(user_id: str) -> Optional[str]:
|
|
"""Get a valid access token for a user."""
|
|
# Load the token from disk
|
|
token_data = load_token(user_id)
|
|
if not token_data:
|
|
return None
|
|
|
|
# Check if the token is expired
|
|
if is_token_expired(token_data):
|
|
# Try to refresh the token
|
|
refresh_token_str = token_data.get("refresh_token")
|
|
if not refresh_token_str:
|
|
return None
|
|
|
|
try:
|
|
# Refresh the token
|
|
new_token_data = await refresh_token(refresh_token_str)
|
|
|
|
# Save the new token
|
|
save_token(user_id, new_token_data)
|
|
|
|
# Return the new access token
|
|
return new_token_data.get("access_token")
|
|
except OAuthError:
|
|
# If refreshing fails, delete the token and return None
|
|
delete_token(user_id)
|
|
return None
|
|
|
|
# Return the access token
|
|
return token_data.get("access_token")
|
|
|
|
|
|
async def validate_token(token: str) -> Tuple[bool, Optional[str]]:
|
|
"""Validate a token and return the user ID if valid."""
|
|
try:
|
|
# Get user info to validate the token
|
|
user_info = await get_user_info(token)
|
|
return True, user_info.get("id")
|
|
except OAuthError:
|
|
return False, None
|