diff --git a/disagreement/client.py b/disagreement/client.py index 5fb4c7f..e5604a3 100644 --- a/disagreement/client.py +++ b/disagreement/client.py @@ -1112,6 +1112,7 @@ class Client: session_id = state["session_id"] voice = VoiceClient( + self, endpoint, session_id, token, diff --git a/disagreement/voice_client.py b/disagreement/voice_client.py index db3444b..38b8d09 100644 --- a/disagreement/voice_client.py +++ b/disagreement/voice_client.py @@ -6,11 +6,19 @@ from __future__ import annotations import asyncio import contextlib import socket -from typing import Optional, Sequence +import threading +from typing import TYPE_CHECKING, Optional, Sequence import aiohttp +# The following import is correct, but may be flagged by Pylance if the virtual +# environment is not configured correctly. +from nacl.secret import SecretBox -from .audio import AudioSource, FFmpegAudioSource +from .audio import AudioSink, AudioSource, FFmpegAudioSource +from .models import User + +if TYPE_CHECKING: + from .client import Client class VoiceClient: @@ -18,6 +26,7 @@ class VoiceClient: def __init__( self, + client: Client, endpoint: str, session_id: str, token: str, @@ -29,6 +38,7 @@ class VoiceClient: loop: Optional[asyncio.AbstractEventLoop] = None, verbose: bool = False, ) -> None: + self.client = client self.endpoint = endpoint self.session_id = session_id self.token = token @@ -38,8 +48,14 @@ class VoiceClient: self._udp = udp self._session: Optional[aiohttp.ClientSession] = None self._heartbeat_task: Optional[asyncio.Task] = None + self._receive_task: Optional[asyncio.Task] = None + self._udp_receive_thread: Optional[threading.Thread] = None self._heartbeat_interval: Optional[float] = None - self._loop = loop or asyncio.get_event_loop() + try: + self._loop = loop or asyncio.get_running_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) self.verbose = verbose self.ssrc: Optional[int] = None self.secret_key: Optional[Sequence[int]] = None @@ -47,6 +63,9 @@ class VoiceClient: self._server_port: Optional[int] = None self._current_source: Optional[AudioSource] = None self._play_task: Optional[asyncio.Task] = None + self._sink: Optional[AudioSink] = None + self._ssrc_map: dict[int, int] = {} + self._ssrc_lock = threading.Lock() async def connect(self) -> None: if self._ws is None: @@ -106,6 +125,49 @@ class VoiceClient: except asyncio.CancelledError: pass + async def _receive_loop(self) -> None: + assert self._ws is not None + while True: + try: + msg = await self._ws.receive_json() + op = msg.get("op") + data = msg.get("d") + if op == 5: # Speaking + user_id = int(data["user_id"]) + ssrc = data["ssrc"] + with self._ssrc_lock: + self._ssrc_map[ssrc] = user_id + except (asyncio.CancelledError, aiohttp.ClientError): + break + + def _udp_receive_loop(self) -> None: + assert self._udp is not None + assert self.secret_key is not None + box = SecretBox(bytes(self.secret_key)) + while True: + try: + packet = self._udp.recv(4096) + if len(packet) < 12: + continue + + ssrc = int.from_bytes(packet[8:12], "big") + with self._ssrc_lock: + if ssrc not in self._ssrc_map: + continue + user_id = self._ssrc_map[ssrc] + user = self.client._users.get(str(user_id)) + if not user: + continue + + decrypted = box.decrypt(packet[12:]) + if self._sink: + self._sink.write(user, decrypted) + except (socket.error, asyncio.CancelledError): + break + except Exception as e: + if self.verbose: + print(f"Error in UDP receive loop: {e}") + async def send_audio_frame(self, frame: bytes) -> None: if not self._udp: raise RuntimeError("UDP socket not initialised") @@ -148,15 +210,35 @@ class VoiceClient: await self.play(FFmpegAudioSource(filename), wait=wait) + def listen(self, sink: AudioSink) -> None: + """Start listening to voice and routing to a sink.""" + if not isinstance(sink, AudioSink): + raise TypeError("sink must be an AudioSink instance") + + self._sink = sink + if not self._udp_receive_thread: + self._udp_receive_thread = threading.Thread( + target=self._udp_receive_loop, daemon=True + ) + self._udp_receive_thread.start() + async def close(self) -> None: await self.stop() if self._heartbeat_task: self._heartbeat_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self._heartbeat_task + if self._receive_task: + self._receive_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._receive_task if self._ws: await self._ws.close() if self._session: await self._session.close() if self._udp: self._udp.close() + if self._udp_receive_thread: + self._udp_receive_thread.join(timeout=1) + if self._sink: + self._sink.close() diff --git a/pyproject.toml b/pyproject.toml index 1c006e8..aefbedb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ classifiers = [ dependencies = [ "aiohttp>=3.9.0,<4.0.0", + "PyNaCl>=1.5.0,<2.0.0", ] [project.optional-dependencies]