From 0eed122f02586caa29a99fe1f3c1bfa59b0edc3f Mon Sep 17 00:00:00 2001 From: Slipstream Date: Tue, 10 Jun 2025 15:44:54 -0600 Subject: [PATCH] Refactor voice client with audio sources (#19) --- disagreement/__init__.py | 1 + disagreement/audio.py | 116 +++++++++++++++++++++++++++++++++++ disagreement/voice_client.py | 55 +++++++++++------ docs/voice_client.md | 13 +++- docs/voice_features.md | 3 +- tests/test_voice_client.py | 31 ++++++++++ 6 files changed, 196 insertions(+), 23 deletions(-) create mode 100644 disagreement/audio.py diff --git a/disagreement/__init__.py b/disagreement/__init__.py index 327cb86..f0eca2b 100644 --- a/disagreement/__init__.py +++ b/disagreement/__init__.py @@ -19,6 +19,7 @@ __version__ = "0.0.2" from .client import Client from .models import Message, User, Reaction from .voice_client import VoiceClient +from .audio import AudioSource, FFmpegAudioSource from .typing import Typing from .errors import ( DisagreementException, diff --git a/disagreement/audio.py b/disagreement/audio.py new file mode 100644 index 0000000..f70369c --- /dev/null +++ b/disagreement/audio.py @@ -0,0 +1,116 @@ +"""Audio source abstractions for the voice client.""" + +from __future__ import annotations + +import asyncio +import contextlib +import io +from typing import Optional, Union + + +class AudioSource: + """Abstract base class for audio sources.""" + + async def read(self) -> bytes: + """Read the next chunk of PCM audio. + + Subclasses must implement this and return raw PCM data + at 48kHz stereo (3840 byte chunks). + """ + + raise NotImplementedError + + async def close(self) -> None: + """Cleanup the source when playback ends.""" + + return None + + +class FFmpegAudioSource(AudioSource): + """Decode audio using FFmpeg. + + Parameters + ---------- + source: + A filename, URL, or file-like object to read from. + """ + + def __init__(self, source: Union[str, io.BufferedIOBase]): + self.source = source + self.process: Optional[asyncio.subprocess.Process] = None + self._feeder: Optional[asyncio.Task] = None + + async def _spawn(self) -> None: + if isinstance(self.source, str): + args = [ + "ffmpeg", + "-i", + self.source, + "-f", + "s16le", + "-ar", + "48000", + "-ac", + "2", + "pipe:1", + ] + self.process = await asyncio.create_subprocess_exec( + *args, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.DEVNULL, + ) + else: + args = [ + "ffmpeg", + "-i", + "pipe:0", + "-f", + "s16le", + "-ar", + "48000", + "-ac", + "2", + "pipe:1", + ] + self.process = await asyncio.create_subprocess_exec( + *args, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.DEVNULL, + ) + assert self.process.stdin is not None + self._feeder = asyncio.create_task(self._feed()) + + async def _feed(self) -> None: + assert isinstance(self.source, io.BufferedIOBase) + assert self.process is not None + assert self.process.stdin is not None + while True: + data = await asyncio.to_thread(self.source.read, 4096) + if not data: + break + self.process.stdin.write(data) + await self.process.stdin.drain() + self.process.stdin.close() + + async def read(self) -> bytes: + if self.process is None: + await self._spawn() + assert self.process is not None + assert self.process.stdout is not None + data = await self.process.stdout.read(3840) + if not data: + await self.close() + return data + + async def close(self) -> None: + if self._feeder: + self._feeder.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._feeder + if self.process: + await self.process.wait() + self.process = None + if isinstance(self.source, io.IOBase): + with contextlib.suppress(Exception): + self.source.close() diff --git a/disagreement/voice_client.py b/disagreement/voice_client.py index 6e7008f..db3444b 100644 --- a/disagreement/voice_client.py +++ b/disagreement/voice_client.py @@ -10,6 +10,8 @@ from typing import Optional, Sequence import aiohttp +from .audio import AudioSource, FFmpegAudioSource + class VoiceClient: """Handles the Discord voice WebSocket connection and UDP streaming.""" @@ -43,6 +45,8 @@ class VoiceClient: self.secret_key: Optional[Sequence[int]] = None self._server_ip: Optional[str] = None self._server_port: Optional[int] = None + self._current_source: Optional[AudioSource] = None + self._play_task: Optional[asyncio.Task] = None async def connect(self) -> None: if self._ws is None: @@ -107,34 +111,45 @@ class VoiceClient: raise RuntimeError("UDP socket not initialised") self._udp.send(frame) - async def play_file(self, filename: str) -> None: - """|coro| Stream an audio file to the voice connection using FFmpeg.""" - - process = await asyncio.create_subprocess_exec( - "ffmpeg", - "-i", - filename, - "-f", - "s16le", - "-ar", - "48000", - "-ac", - "2", - "pipe:1", - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.DEVNULL, - ) - assert process.stdout is not None + async def _play_loop(self) -> None: + assert self._current_source is not None try: while True: - data = await process.stdout.read(3840) + data = await self._current_source.read() if not data: break await self.send_audio_frame(data) finally: - await process.wait() + await self._current_source.close() + self._current_source = None + self._play_task = None + + async def stop(self) -> None: + if self._play_task: + self._play_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._play_task + self._play_task = None + if self._current_source: + await self._current_source.close() + self._current_source = None + + async def play(self, source: AudioSource, *, wait: bool = True) -> None: + """|coro| Play an :class:`AudioSource` on the voice connection.""" + + await self.stop() + self._current_source = source + self._play_task = self._loop.create_task(self._play_loop()) + if wait: + await self._play_task + + async def play_file(self, filename: str, *, wait: bool = True) -> None: + """|coro| Stream an audio file or URL using FFmpeg.""" + + await self.play(FFmpegAudioSource(filename), wait=wait) async def close(self) -> None: + await self.stop() if self._heartbeat_task: self._heartbeat_task.cancel() with contextlib.suppress(asyncio.CancelledError): diff --git a/docs/voice_client.md b/docs/voice_client.md index d7bf2a2..f8a1060 100644 --- a/docs/voice_client.md +++ b/docs/voice_client.md @@ -26,10 +26,19 @@ After connecting you can send raw Opus frames: await vc.send_audio_frame(opus_bytes) ``` -Or stream a file using FFmpeg: +Or stream audio using an :class:`AudioSource`: ```python -await vc.play_file("welcome.mp3") +from disagreement import FFmpegAudioSource + +source = FFmpegAudioSource("welcome.mp3") +await vc.play(source) +``` + +You can switch sources while connected: + +```python +await vc.play(FFmpegAudioSource("other.mp3")) ``` Call `await vc.close()` when finished. diff --git a/docs/voice_features.md b/docs/voice_features.md index ce261a3..82265f6 100644 --- a/docs/voice_features.md +++ b/docs/voice_features.md @@ -1,10 +1,11 @@ # Voice Features -Disagreement includes experimental support for connecting to voice channels. You can join a voice channel and play audio using an FFmpeg subprocess. +Disagreement includes experimental support for connecting to voice channels. You can join a voice channel and play audio using an :class:`AudioSource`. ```python voice = await client.join_voice(guild_id, channel_id) await voice.play_file("welcome.mp3") +await voice.play_file("another.mp3") # switch sources while connected await voice.close() ``` diff --git a/tests/test_voice_client.py b/tests/test_voice_client.py index 4f10e50..fae0289 100644 --- a/tests/test_voice_client.py +++ b/tests/test_voice_client.py @@ -2,6 +2,7 @@ import asyncio import pytest from disagreement.voice_client import VoiceClient +from disagreement.audio import AudioSource class DummyWebSocket: @@ -39,6 +40,16 @@ class DummyUDP: pass +class DummySource(AudioSource): + def __init__(self, chunks): + self.chunks = list(chunks) + + async def read(self) -> bytes: + if self.chunks: + return self.chunks.pop(0) + return b"" + + @pytest.mark.asyncio async def test_voice_client_handshake(): hello = {"d": {"heartbeat_interval": 50}} @@ -73,3 +84,23 @@ async def test_send_audio_frame(): await vc.send_audio_frame(b"abc") assert udp.sent[-1] == b"abc" + + +@pytest.mark.asyncio +async def test_play_and_switch_sources(): + ws = DummyWebSocket( + [ + {"d": {"heartbeat_interval": 50}}, + {"d": {"ssrc": 1, "ip": "127.0.0.1", "port": 4000}}, + {"d": {"secret_key": []}}, + ] + ) + udp = DummyUDP() + vc = VoiceClient("ws://localhost", "sess", "tok", 1, 2, ws=ws, udp=udp) + await vc.connect() + vc._heartbeat_task.cancel() + + await vc.play(DummySource([b"a", b"b"])) + await vc.play(DummySource([b"c"])) + + assert udp.sent == [b"a", b"b", b"c"]