discordbot/oauth_server.py
Slipstreamm d64da1aa9a refactor(oauth): Refine authentication success messages
Update various success messages and help text related to OAuth. Instead
of broadly stating the bot gains general API access, the language now
specifies that it gains access to "scopes" or "additional OAuth scopes"
as allowed during the authentication flow, providing a more precise
understanding of the permissions granted.
2025-06-14 12:37:55 -06:00

144 lines
4.6 KiB
Python

"""
OAuth2 callback server for the Discord bot.
This module provides a simple web server to handle OAuth2 callbacks
from Discord. It uses aiohttp to create an asynchronous web server
that can run alongside the Discord bot.
"""
import os
import asyncio
import aiohttp
from aiohttp import web
import discord_oauth
from typing import Dict, Optional, Set, Callable
# Optional per-state callables, keyed by the OAuth2 ``state`` string they were
# registered under; read (and removed) by handle_oauth_callback.
auth_callbacks: Dict[str, Callable] = dict()
# OAuth2 ``state`` values registered via register_auth_state() that have not
# yet been redeemed by a callback request.
pending_states: Set[str] = set()
# Strong references to callback tasks spawned by handle_oauth_callback.
# The event loop keeps only weak references to tasks, so a fire-and-forget
# task with no other reference may be garbage-collected before it finishes.
_background_tasks: Set[asyncio.Task] = set()

# Page shown to the user after a successful authorization.
_SUCCESS_PAGE = """
<html>
<head>
<title>Authentication Successful</title>
<style>
body { font-family: Arial, sans-serif; text-align: center; padding: 50px; }
.success { color: green; }
.info { margin-top: 20px; }
</style>
</head>
<body>
<h1 class="success">Authentication Successful!</h1>
<p>You have successfully authenticated with Discord.</p>
<div class="info">
<p>You can now close this window and return to Discord.</p>
</div>
</body>
</html>
"""


async def handle_oauth_callback(request: web.Request) -> web.Response:
    """Handle the OAuth2 redirect from Discord.

    Validates the ``state`` query parameter against ``pending_states``,
    exchanges the authorization ``code`` for a token via ``discord_oauth``,
    fetches the authenticating user's info, saves the token, and schedules
    the callback registered for that state (if any) as a background task.

    Args:
        request: The incoming callback request. Reads the ``code``,
            ``state``, ``error`` and ``error_description`` query parameters.

    Returns:
        An HTML success page; or a plain-text error response with status 400
        (authorization denied, missing or unknown parameters) or 500 (token
        exchange, user lookup, or other failure).
    """
    code = request.query.get("code")
    state = request.query.get("state")

    error = request.query.get("error")
    if error:
        # An OAuth2 error redirect (e.g. the user clicked "Cancel") carries
        # ?error=...&state=... instead of a code (RFC 6749 §4.1.2.1). That
        # flow can never complete, so forget its state and callback.
        if state is not None:
            pending_states.discard(state)
            auth_callbacks.pop(state, None)
        description = request.query.get("error_description") or error
        return web.Response(text=f"Authorization failed: {description}", status=400)

    if not code or not state:
        return web.Response(text="Missing code or state parameter", status=400)

    if state not in pending_states:
        return web.Response(text="Invalid state parameter", status=400)

    # States are single-use. Consume the state and its callback together, up
    # front, so a failure below cannot leave a stale entry in auth_callbacks.
    pending_states.remove(state)
    callback = auth_callbacks.pop(state, None)

    try:
        token_data = await discord_oauth.exchange_code(code, state)

        access_token = token_data.get("access_token")
        if not access_token:
            return web.Response(text="Failed to get access token", status=500)

        user_info = await discord_oauth.get_user_info(access_token)
        user_id = user_info.get("id")
        if not user_id:
            return web.Response(text="Failed to get user ID", status=500)

        discord_oauth.save_token(user_id, token_data)

        if callback is not None:
            # Run the callback without delaying the HTTP response, but keep a
            # reference until it completes so it is not garbage-collected.
            task = asyncio.create_task(callback(user_id, user_info))
            _background_tasks.add(task)
            task.add_done_callback(_background_tasks.discard)
    except discord_oauth.OAuthError as e:
        return web.Response(text=f"OAuth error: {e}", status=500)
    except Exception as e:
        return web.Response(text=f"Error: {e}", status=500)

    return web.Response(text=_SUCCESS_PAGE, content_type="text/html")
async def handle_root(request: web.Request) -> web.Response:
    """Serve a static informational HTML page at ``/``.

    The request itself is not inspected; every visitor receives the same page
    explaining that this server exists only to receive OAuth callbacks.
    """
    page = """
<html>
<head>
<title>Discord Bot OAuth Server</title>
<style>
body { font-family: Arial, sans-serif; text-align: center; padding: 50px; }
</style>
</head>
<body>
<h1>Discord Bot OAuth Server</h1>
<p>This server handles OAuth callbacks for the Discord bot.</p>
<p>You should not access this page directly.</p>
</body>
</html>
"""
    return web.Response(content_type="text/html", text=page)
def create_app() -> web.Application:
    """Build the aiohttp application and register its GET routes.

    Routes:
        ``/``                -> handle_root
        ``/oauth/callback``  -> handle_oauth_callback
    """
    route_table = [
        web.get("/", handle_root),
        web.get("/oauth/callback", handle_oauth_callback),
    ]
    application = web.Application()
    application.add_routes(route_table)
    return application
async def start_server(host: str = "0.0.0.0", port: int = 8080) -> web.AppRunner:
    """Start the OAuth callback server on ``host``:``port`` and return.

    The server keeps running on the current event loop after this coroutine
    returns.

    Args:
        host: Interface to bind. The default, ``"0.0.0.0"``, listens on all
            IPv4 interfaces.
        port: TCP port to listen on.

    Returns:
        The ``web.AppRunner`` hosting the app; ``await runner.cleanup()``
        stops the server. Callers that ignore the return value (as with the
        previous ``None`` return) are unaffected.

    Raises:
        Any exception raised while starting the site (for example when the
        port cannot be bound) propagates after the runner has been cleaned up.
    """
    app = create_app()
    runner = web.AppRunner(app)
    await runner.setup()
    try:
        site = web.TCPSite(runner, host, port)
        await site.start()
    except BaseException:
        # Don't leak the runner's resources if the site fails to start.
        await runner.cleanup()
        raise
    print(f"OAuth callback server running at http://{host}:{port}")
    return runner
def register_auth_state(state: str, callback: Optional[Callable] = None) -> None:
    """Record ``state`` as an expected OAuth2 state value.

    handle_oauth_callback rejects redirects whose ``state`` was not
    registered here.

    Args:
        state: The opaque state string embedded in the authorization URL.
        callback: Optional callable to store under ``state`` in
            ``auth_callbacks``. It is only stored when truthy.
    """
    pending_states.add(state)
    if not callback:
        return
    auth_callbacks[state] = callback
if __name__ == "__main__":
    # Standalone mode for local testing: start the server and keep the
    # process alive until interrupted.
    # asyncio.get_event_loop() is deprecated when no loop is running (newer
    # Pythons raise instead of creating one), so drive it via asyncio.run().

    async def _serve_forever() -> None:
        runner = await start_server()
        try:
            # Park until the process is interrupted (Ctrl+C cancels this).
            await asyncio.Event().wait()
        finally:
            # start_server() may or may not hand back its runner; clean it up
            # when it does so the listening socket is released.
            if runner is not None:
                await runner.cleanup()

    asyncio.run(_serve_forever())