diff --git a/disagreement/audio.py b/disagreement/audio.py index 9e58530..cd1eadf 100644 --- a/disagreement/audio.py +++ b/disagreement/audio.py @@ -5,6 +5,7 @@ from __future__ import annotations import asyncio import contextlib import io +import shlex from typing import Optional, Union @@ -35,15 +36,27 @@ class FFmpegAudioSource(AudioSource): A filename, URL, or file-like object to read from. """ - def __init__(self, source: Union[str, io.BufferedIOBase]): + def __init__( + self, + source: Union[str, io.BufferedIOBase], + *, + before_options: Optional[str] = None, + options: Optional[str] = None, + volume: float = 1.0, + ): self.source = source + self.before_options = before_options + self.options = options + self.volume = volume self.process: Optional[asyncio.subprocess.Process] = None self._feeder: Optional[asyncio.Task] = None async def _spawn(self) -> None: if isinstance(self.source, str): - args = [ - "ffmpeg", + args = ["ffmpeg"] + if self.before_options: + args += shlex.split(self.before_options) + args += [ "-i", self.source, "-f", @@ -54,14 +67,18 @@ class FFmpegAudioSource(AudioSource): "2", "pipe:1", ] + if self.options: + args += shlex.split(self.options) self.process = await asyncio.create_subprocess_exec( *args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.DEVNULL, ) else: - args = [ - "ffmpeg", + args = ["ffmpeg"] + if self.before_options: + args += shlex.split(self.before_options) + args += [ "-i", "pipe:0", "-f", @@ -72,6 +89,8 @@ class FFmpegAudioSource(AudioSource): "2", "pipe:1", ] + if self.options: + args += shlex.split(self.options) self.process = await asyncio.create_subprocess_exec( *args, stdin=asyncio.subprocess.PIPE, @@ -115,6 +134,7 @@ class FFmpegAudioSource(AudioSource): with contextlib.suppress(Exception): self.source.close() + class AudioSink: """Abstract base class for audio sinks.""" diff --git a/disagreement/voice_client.py b/disagreement/voice_client.py index c771869..8696665 100644 --- a/disagreement/voice_client.py +++ b/disagreement/voice_client.py @@ -7,9 +7,26 @@ import asyncio import contextlib import socket import threading +from array import array + + +def _apply_volume(data: bytes, volume: float) -> bytes: + samples = array("h") + samples.frombytes(data) + for i, sample in enumerate(samples): + scaled = int(sample * volume) + if scaled > 32767: + scaled = 32767 + elif scaled < -32768: + scaled = -32768 + samples[i] = scaled + return samples.tobytes() + + from typing import TYPE_CHECKING, Optional, Sequence import aiohttp + # The following import is correct, but may be flagged by Pylance if the virtual # environment is not configured correctly. from nacl.secret import SecretBox @@ -180,6 +197,9 @@ class VoiceClient: data = await self._current_source.read() if not data: break + volume = getattr(self._current_source, "volume", 1.0) + if volume != 1.0: + data = _apply_volume(data, volume) await self.send_audio_frame(data) finally: await self._current_source.close() diff --git a/tests/test_voice_client.py b/tests/test_voice_client.py index da052bc..1c6cb98 100644 --- a/tests/test_voice_client.py +++ b/tests/test_voice_client.py @@ -1,8 +1,11 @@ import asyncio +import io +from array import array import pytest +from disagreement.audio import AudioSource, FFmpegAudioSource + from disagreement.voice_client import VoiceClient -from disagreement.audio import AudioSource from disagreement.client import Client @@ -137,3 +140,68 @@ async def test_play_and_switch_sources(): await vc.play(DummySource([b"c"])) assert udp.sent == [b"a", b"b", b"c"] + + +@pytest.mark.asyncio +async def test_ffmpeg_source_custom_options(monkeypatch): + captured = {} + + class DummyProcess: + def __init__(self): + self.stdout = io.BytesIO(b"") + + async def wait(self): + return 0 + + async def fake_exec(*args, **kwargs): + captured["args"] = args + return DummyProcess() + + monkeypatch.setattr(asyncio, "create_subprocess_exec", fake_exec) + src = FFmpegAudioSource( + "file.mp3", before_options="-reconnect 1", options="-vn", volume=0.5 + ) + + await src._spawn() + + cmd = captured["args"] + assert "-reconnect" in cmd + assert "-vn" in cmd + assert src.volume == 0.5 + + +@pytest.mark.asyncio +async def test_voice_client_volume_scaling(monkeypatch): + ws = DummyWebSocket( + [ + {"d": {"heartbeat_interval": 50}}, + {"d": {"ssrc": 1, "ip": "127.0.0.1", "port": 4000}}, + {"d": {"secret_key": []}}, + ] + ) + udp = DummyUDP() + vc = VoiceClient( + client=DummyVoiceClient(), + endpoint="ws://localhost", + session_id="sess", + token="tok", + guild_id=1, + user_id=2, + ws=ws, + udp=udp, + ) + await vc.connect() + vc._heartbeat_task.cancel() + + chunk = b"\x10\x00\x10\x00" + src = DummySource([chunk]) + src.volume = 0.5 + + await vc.play(src) + + samples = array("h") + samples.frombytes(chunk) + samples[0] = int(samples[0] * 0.5) + samples[1] = int(samples[1] * 0.5) + expected = samples.tobytes() + assert udp.sent == [expected]