feat: Refactor audio processing in VoiceGatewayCog to remove OpusDecoder and integrate FFmpeg for audio conversion
This commit is contained in:
parent
c0c65fe3d1
commit
3824ba9a6c
@ -4,8 +4,10 @@ import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
import wave # For saving audio data
|
||||
import subprocess # For audio conversion
|
||||
from discord.ext import voice_recv # For receiving voice
|
||||
|
||||
# Attempt to import STT, VAD, and Opus libraries
|
||||
# Attempt to import STT and VAD libraries
|
||||
try:
|
||||
import whisper
|
||||
except ImportError:
|
||||
@ -18,14 +20,7 @@ except ImportError:
|
||||
print("webrtcvad library not found. Please install with 'pip install webrtc-voice-activity-detector'")
|
||||
webrtcvad = None
|
||||
|
||||
try:
|
||||
from opuslib import Decoder as OpusDecoder
|
||||
except ImportError:
|
||||
import traceback
|
||||
print("opuslib library not found. Please install with 'pip install opuslib' (requires Opus C library).")
|
||||
print(f"Import error traceback: {traceback.format_exc()}")
|
||||
OpusDecoder = None
|
||||
|
||||
# OpusDecoder is no longer needed as discord-ext-voice-recv provides PCM.
|
||||
|
||||
FFMPEG_OPTIONS = {
|
||||
'before_options': '-reconnect 1 -reconnect_streamed 1 -reconnect_delay_max 5',
|
||||
@ -39,9 +34,7 @@ SAMPLE_WIDTH = 2 # 16-bit audio (2 bytes per sample)
|
||||
VAD_MODE = 3 # VAD aggressiveness (0-3, 3 is most aggressive)
|
||||
FRAME_DURATION_MS = 30 # Duration of a frame in ms for VAD (10, 20, or 30)
|
||||
BYTES_PER_FRAME = (SAMPLE_RATE // 1000) * FRAME_DURATION_MS * CHANNELS * SAMPLE_WIDTH
|
||||
OPUS_FRAME_SIZE_MS = 20 # Opus typically uses 20ms frames
|
||||
OPUS_SAMPLES_PER_FRAME = (SAMPLE_RATE // 1000) * OPUS_FRAME_SIZE_MS # e.g. 16000/1000 * 20 = 320 samples for 16kHz
|
||||
OPUS_BUFFER_SIZE = OPUS_SAMPLES_PER_FRAME * CHANNELS * SAMPLE_WIDTH # Bytes for PCM buffer for one Opus frame
|
||||
# OPUS constants removed as Opus decoding is no longer handled here.
|
||||
|
||||
# Silence detection parameters
|
||||
SILENCE_THRESHOLD_FRAMES = 25 # Number of consecutive silent VAD frames to consider end of speech (e.g., 25 * 30ms = 750ms)
|
||||
@ -49,81 +42,148 @@ MAX_SPEECH_DURATION_S = 15 # Max duration of a single speech segment to proce
|
||||
MAX_SPEECH_FRAMES = (MAX_SPEECH_DURATION_S * 1000) // FRAME_DURATION_MS
|
||||
|
||||
|
||||
class VoiceAudioSink(discord.AudioSink):
|
||||
def __init__(self, cog_instance, voice_client: discord.VoiceClient):
|
||||
# Helper function for audio conversion
|
||||
def _convert_audio_to_16khz_mono(raw_pcm_data_48k_stereo: bytes) -> bytes:
|
||||
"""
|
||||
Converts raw 48kHz stereo PCM data to 16kHz mono PCM data using FFmpeg.
|
||||
"""
|
||||
input_temp_file = None
|
||||
output_temp_file = None
|
||||
converted_audio_data = b""
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix=".raw", delete=False) as tmp_in:
|
||||
input_temp_file = tmp_in.name
|
||||
tmp_in.write(raw_pcm_data_48k_stereo)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_out:
|
||||
output_temp_file = tmp_out.name
|
||||
|
||||
command = [
|
||||
'ffmpeg',
|
||||
'-f', 's16le', # Input format: signed 16-bit little-endian PCM
|
||||
'-ac', '2', # Input channels: stereo
|
||||
'-ar', '48000', # Input sample rate: 48kHz
|
||||
'-i', input_temp_file,
|
||||
'-ac', str(CHANNELS), # Output channels (e.g., 1 for mono)
|
||||
'-ar', str(SAMPLE_RATE), # Output sample rate (e.g., 16000)
|
||||
'-sample_fmt', 's16',# Output sample format
|
||||
'-y', # Overwrite output file if it exists
|
||||
output_temp_file
|
||||
]
|
||||
|
||||
process = subprocess.run(command, capture_output=True, check=False)
|
||||
|
||||
if process.returncode != 0:
|
||||
print(f"FFmpeg error during audio conversion. Return code: {process.returncode}")
|
||||
print(f"FFmpeg stdout: {process.stdout.decode(errors='ignore')}")
|
||||
print(f"FFmpeg stderr: {process.stderr.decode(errors='ignore')}")
|
||||
return b""
|
||||
|
||||
with open(output_temp_file, 'rb') as f_out:
|
||||
with wave.open(f_out, 'rb') as wf:
|
||||
if wf.getnchannels() == CHANNELS and \
|
||||
wf.getframerate() == SAMPLE_RATE and \
|
||||
wf.getsampwidth() == SAMPLE_WIDTH:
|
||||
converted_audio_data = wf.readframes(wf.getnframes())
|
||||
else:
|
||||
print(f"Warning: Converted WAV file format mismatch. Expected {CHANNELS}ch, {SAMPLE_RATE}Hz, {SAMPLE_WIDTH}bytes/sample.")
|
||||
print(f"Got: {wf.getnchannels()}ch, {wf.getframerate()}Hz, {wf.getsampwidth()}bytes/sample.")
|
||||
return b""
|
||||
except FileNotFoundError:
|
||||
print("FFmpeg command not found. Please ensure FFmpeg is installed and in your system's PATH.")
|
||||
return b""
|
||||
except Exception as e:
|
||||
print(f"Error during audio conversion: {e}")
|
||||
return b""
|
||||
finally:
|
||||
if input_temp_file and os.path.exists(input_temp_file):
|
||||
os.remove(input_temp_file)
|
||||
if output_temp_file and os.path.exists(output_temp_file):
|
||||
os.remove(output_temp_file)
|
||||
|
||||
return converted_audio_data
|
||||
|
||||
|
||||
class VoiceAudioSink(voice_recv.AudioSink): # Inherit from voice_recv.AudioSink
|
||||
def __init__(self, cog_instance, voice_client: voice_recv.VoiceRecvClient): # Updated type hint
|
||||
super().__init__()
|
||||
self.cog = cog_instance
|
||||
self.voice_client = voice_client # Store the voice_client
|
||||
self.user_audio_data = {} # {ssrc: {'buffer': bytearray, 'speaking': False, 'silent_frames': 0, 'speech_frames': 0, 'decoder': OpusDecoder, 'vad': VAD_instance}}
|
||||
# user_audio_data now keyed by user_id, 'decoder' removed
|
||||
self.user_audio_data = {} # {user_id: {'buffer': bytearray, 'speaking': False, 'silent_frames': 0, 'speech_frames': 0, 'vad': VAD_instance}}
|
||||
|
||||
if not OpusDecoder:
|
||||
print("OpusDecoder not available. AudioSink will not function correctly.")
|
||||
# OpusDecoder check removed
|
||||
if not webrtcvad:
|
||||
print("VAD library not loaded. STT might be less efficient or not work as intended.")
|
||||
|
||||
def write(self, ssrc: int, data: bytes): # data is opus encoded
|
||||
if not OpusDecoder or not webrtcvad or not self.voice_client:
|
||||
return
|
||||
|
||||
user = self.voice_client.ssrc_map.get(ssrc)
|
||||
if not user: # Unknown SSRC or user left
|
||||
# Clean up if user data exists for this SSRC
|
||||
if ssrc in self.user_audio_data:
|
||||
del self.user_audio_data[ssrc]
|
||||
# Signature changed: user object directly, data is raw 48kHz stereo PCM
|
||||
def write(self, user: discord.User, pcm_data_48k_stereo: bytes):
|
||||
if not webrtcvad or not self.voice_client or not user: # OpusDecoder check removed, user check added
|
||||
return
|
||||
|
||||
user_id = user.id
|
||||
user_id = user.id # Get user_id from the user object
|
||||
|
||||
if ssrc not in self.user_audio_data:
|
||||
self.user_audio_data[ssrc] = {
|
||||
if user_id not in self.user_audio_data:
|
||||
self.user_audio_data[user_id] = {
|
||||
'buffer': bytearray(),
|
||||
'speaking': False,
|
||||
'silent_frames': 0,
|
||||
'speech_frames': 0,
|
||||
'decoder': OpusDecoder(SAMPLE_RATE, CHANNELS), # Decode to 16kHz mono
|
||||
# 'decoder' removed
|
||||
'vad': webrtcvad.Vad(VAD_MODE) if webrtcvad else None
|
||||
}
|
||||
|
||||
entry = self.user_audio_data[ssrc]
|
||||
|
||||
try:
|
||||
# Decode Opus to PCM. Opus data is typically 20ms frames.
|
||||
# Max frame size for opuslib decoder is 2 bytes/sample * 1 channel * 120ms * 48kHz = 11520 bytes
|
||||
# We expect 20ms frames from Discord.
|
||||
# The decoder needs to know the length of the PCM buffer it can write to.
|
||||
# For 16kHz, 1 channel, 20ms: 320 samples * 2 bytes/sample = 640 bytes.
|
||||
pcm_data = entry['decoder'].decode(data, OPUS_SAMPLES_PER_FRAME, decode_fec=False)
|
||||
except Exception as e:
|
||||
print(f"Opus decoding error for SSRC {ssrc} (User {user_id}): {e}")
|
||||
entry = self.user_audio_data[user_id]
|
||||
|
||||
# Convert incoming 48kHz stereo PCM to 16kHz mono PCM
|
||||
pcm_data = _convert_audio_to_16khz_mono(pcm_data_48k_stereo)
|
||||
if not pcm_data: # Conversion failed or returned empty bytes
|
||||
# print(f"Audio conversion failed for user {user_id}. Skipping frame.")
|
||||
return
|
||||
|
||||
# VAD processing expects frames of 10, 20, or 30 ms.
|
||||
# Our pcm_data is likely 20ms if decoded correctly.
|
||||
# pcm_data is now 16kHz mono, hopefully in appropriate chunks from conversion.
|
||||
# We need to ensure it's split into VAD-compatible frame lengths if not already.
|
||||
# If pcm_data is 20ms at 16kHz, its length is 640 bytes.
|
||||
# If pcm_data (now 16kHz mono) is a 20ms chunk, its length is 640 bytes.
|
||||
# A 10ms frame at 16kHz is 320 bytes. A 30ms frame is 960 bytes.
|
||||
# Let's assume pcm_data is one 20ms frame. We can feed it directly if VAD supports 20ms.
|
||||
# Or split it into two 10ms frames. Let's use 20ms frames for VAD.
|
||||
|
||||
# Ensure frame_length for VAD is correct (e.g. 20ms at 16kHz = 320 samples = 640 bytes)
|
||||
frame_length_for_vad_20ms = (SAMPLE_RATE // 1000) * 20 * CHANNELS * SAMPLE_WIDTH # 640 bytes for 20ms @ 16kHz
|
||||
# Ensure frame_length for VAD is correct (e.g. 20ms at 16kHz mono = 640 bytes)
|
||||
# This constant could be defined at class or module level.
|
||||
# For a 20ms frame, which is typical for voice packets:
|
||||
frame_length_for_vad_20ms = (SAMPLE_RATE // 1000) * 20 * CHANNELS * SAMPLE_WIDTH
|
||||
|
||||
if len(pcm_data) != frame_length_for_vad_20ms:
|
||||
# This might happen if opus frame duration is not 20ms or sample rate mismatch
|
||||
# print(f"Warning: PCM data length {len(pcm_data)} not expected {frame_length_for_vad_20ms} for SSRC {ssrc}. Skipping VAD for this frame.")
|
||||
# For simplicity, if frame size is unexpected, we might skip or buffer differently.
|
||||
# For now, let's assume it's mostly correct.
|
||||
# A more robust solution would handle partial frames or resample/reframe.
|
||||
pass
|
||||
if len(pcm_data) % frame_length_for_vad_20ms != 0 and len(pcm_data) > 0 : # Check if it's a multiple, or handle if not.
|
||||
# This might happen if the converted chunk size isn't exactly what VAD expects per call.
|
||||
# For now, we'll try to process it. A more robust solution might buffer/segment pcm_data
|
||||
# into exact 10, 20, or 30ms chunks for VAD.
|
||||
# print(f"Warning: PCM data length {len(pcm_data)} after conversion is not an exact multiple of VAD frame size {frame_length_for_vad_20ms} for User {user_id}. Trying to process.")
|
||||
pass # Continue, VAD might handle it or error.
|
||||
|
||||
# Process VAD in chunks if pcm_data is longer than one VAD frame
|
||||
# For simplicity, let's assume pcm_data is one processable chunk for now.
|
||||
# If pcm_data can be multiple VAD frames, iterate through it.
|
||||
# Current VAD logic processes the whole pcm_data chunk at once.
|
||||
# This is okay if pcm_data is already a single VAD frame (e.g. 20ms).
|
||||
|
||||
if entry['vad']:
|
||||
try:
|
||||
is_speech = entry['vad'].is_speech(pcm_data, SAMPLE_RATE)
|
||||
# Ensure pcm_data is a valid frame for VAD (e.g. 10, 20, 30 ms)
|
||||
# If pcm_data is, for example, 640 bytes (20ms at 16kHz mono), it's fine.
|
||||
if len(pcm_data) == frame_length_for_vad_20ms: # Common case
|
||||
is_speech = entry['vad'].is_speech(pcm_data, SAMPLE_RATE)
|
||||
elif len(pcm_data) > 0 : # If not standard, but has data, try (might error)
|
||||
# print(f"VAD processing for User {user_id} with non-standard PCM length {len(pcm_data)}. May error.")
|
||||
# This path is risky if VAD is strict. For now, we assume it's handled or errors.
|
||||
# A robust way: segment pcm_data into valid VAD frames.
|
||||
# For now, let's assume the chunk from conversion is one such frame.
|
||||
is_speech = entry['vad'].is_speech(pcm_data, SAMPLE_RATE) # This might fail if len is not 10/20/30ms worth
|
||||
else: # No data
|
||||
is_speech = False
|
||||
|
||||
except Exception as e: # webrtcvad can raise errors on invalid frame length
|
||||
# print(f"VAD error for SSRC {ssrc} (User {user_id}) with PCM length {len(pcm_data)}: {e}. Defaulting to speech=True for this frame.")
|
||||
# Fallback: if VAD fails, assume it's speech to avoid losing data, or handle more gracefully.
|
||||
is_speech = True # Or False, depending on desired behavior on error
|
||||
# print(f"VAD error for User {user_id} with PCM length {len(pcm_data)}: {e}. Defaulting to speech=True for this frame.")
|
||||
is_speech = True # Fallback: if VAD fails, assume it's speech
|
||||
else: # No VAD
|
||||
is_speech = True
|
||||
|
||||
@ -133,7 +193,7 @@ class VoiceAudioSink(discord.AudioSink):
|
||||
entry['silent_frames'] = 0
|
||||
entry['speech_frames'] += 1
|
||||
if entry['speech_frames'] >= MAX_SPEECH_FRAMES:
|
||||
# print(f"Max speech frames reached for SSRC {ssrc}. Processing segment.")
|
||||
# print(f"Max speech frames reached for User {user_id}. Processing segment.")
|
||||
asyncio.create_task(self.cog.process_audio_segment(user_id, bytes(entry['buffer']), self.voice_client.guild))
|
||||
entry['buffer'].clear()
|
||||
entry['speaking'] = False
|
||||
@ -142,27 +202,27 @@ class VoiceAudioSink(discord.AudioSink):
|
||||
entry['buffer'].extend(pcm_data) # Add this last silent frame for context
|
||||
entry['silent_frames'] += 1
|
||||
if entry['silent_frames'] >= SILENCE_THRESHOLD_FRAMES:
|
||||
# print(f"Silence threshold reached for SSRC {ssrc}. Processing segment.")
|
||||
# print(f"Silence threshold reached for User {user_id}. Processing segment.")
|
||||
asyncio.create_task(self.cog.process_audio_segment(user_id, bytes(entry['buffer']), self.voice_client.guild))
|
||||
entry['buffer'].clear()
|
||||
entry['speaking'] = False
|
||||
entry['speech_frames'] = 0
|
||||
entry['silent_frames'] = 0
|
||||
# If not is_speech and not entry['speaking'], do nothing (ignore silence)
|
||||
# else:
|
||||
# If buffer has old data and user stopped talking long ago, clear it?
|
||||
# This part can be tricky to avoid cutting off speech.
|
||||
# The current logic processes on silence *after* speech.
|
||||
|
||||
def cleanup(self):
|
||||
print("VoiceAudioSink cleanup called.")
|
||||
for ssrc, data in self.user_audio_data.items():
|
||||
# If there's buffered audio when cleaning up, process it
|
||||
if data['buffer']:
|
||||
user = self.voice_client.ssrc_map.get(ssrc)
|
||||
if user:
|
||||
print(f"Processing remaining audio for SSRC {ssrc} (User {user.id}) on cleanup.")
|
||||
asyncio.create_task(self.cog.process_audio_segment(user.id, bytes(data['buffer']), self.voice_client.guild))
|
||||
# Iterate over a copy of items if modifications occur, or handle user_id directly
|
||||
for user_id, data_entry in list(self.user_audio_data.items()):
|
||||
if data_entry['buffer']:
|
||||
# user object is not directly available here, but process_audio_segment takes user_id
|
||||
# We need the guild, which should be available from self.voice_client
|
||||
if self.voice_client and self.voice_client.guild:
|
||||
guild = self.voice_client.guild
|
||||
print(f"Processing remaining audio for User ID {user_id} on cleanup.")
|
||||
asyncio.create_task(self.cog.process_audio_segment(user_id, bytes(data_entry['buffer']), guild))
|
||||
else:
|
||||
print(f"Cannot process remaining audio for User ID {user_id}: voice_client or guild not available.")
|
||||
self.user_audio_data.clear()
|
||||
|
||||
|
||||
@ -191,8 +251,13 @@ class VoiceGatewayCog(commands.Cog):
|
||||
for vc in list(self.bot.voice_clients): # Iterate over a copy
|
||||
guild_id = vc.guild.id
|
||||
if guild_id in self.active_sinks:
|
||||
if vc.is_connected():
|
||||
vc.stop_listening() # Stop listening before cleanup
|
||||
# Ensure vc is an instance of VoiceRecvClient or compatible for stop_listening
|
||||
if vc.is_connected() and hasattr(vc, 'is_listening') and vc.is_listening():
|
||||
# Check if stop_listening exists, VoiceRecvClient might have different API
|
||||
if hasattr(vc, 'stop_listening'):
|
||||
vc.stop_listening()
|
||||
else: # Or equivalent for VoiceRecvClient
|
||||
pass # May need specific cleanup for voice_recv
|
||||
self.active_sinks[guild_id].cleanup()
|
||||
del self.active_sinks[guild_id]
|
||||
if vc.is_connected():
|
||||
@ -205,24 +270,47 @@ class VoiceGatewayCog(commands.Cog):
|
||||
return None, "Channel not provided."
|
||||
|
||||
guild = channel.guild
|
||||
voice_client = guild.voice_client
|
||||
voice_client = guild.voice_client # This will be VoiceRecvClient if already connected by this cog
|
||||
|
||||
if voice_client and voice_client.is_connected():
|
||||
if voice_client.channel == channel:
|
||||
print(f"Already connected to {channel.name} in {guild.name}.")
|
||||
# Ensure listening is active if already connected
|
||||
if guild.id not in self.active_sinks or not voice_client.is_listening():
|
||||
self.start_listening_for_vc(voice_client)
|
||||
# Check if it's a VoiceRecvClient instance
|
||||
if isinstance(voice_client, voice_recv.VoiceRecvClient):
|
||||
if guild.id not in self.active_sinks or not voice_client.is_listening():
|
||||
self.start_listening_for_vc(voice_client)
|
||||
else: # If it's a regular VoiceClient, we need to reconnect with VoiceRecvClient
|
||||
print(f"Reconnecting with VoiceRecvClient to {channel.name}.")
|
||||
await voice_client.disconnect(force=True)
|
||||
try: # Reconnect with VoiceRecvClient
|
||||
voice_client = await channel.connect(cls=voice_recv.VoiceRecvClient, timeout=10.0)
|
||||
print(f"Reconnected to {channel.name} in {guild.name} with VoiceRecvClient.")
|
||||
self.start_listening_for_vc(voice_client)
|
||||
except asyncio.TimeoutError:
|
||||
return None, f"Timeout trying to reconnect to {channel.name} with VoiceRecvClient."
|
||||
except Exception as e:
|
||||
return None, f"Error reconnecting to {channel.name} with VoiceRecvClient: {str(e)}"
|
||||
|
||||
return voice_client, "Already connected to this channel."
|
||||
else:
|
||||
await voice_client.move_to(channel)
|
||||
print(f"Moved to {channel.name} in {guild.name}.")
|
||||
# Restart listening in the new channel
|
||||
self.start_listening_for_vc(voice_client)
|
||||
# Handling move_to for VoiceRecvClient might need care.
|
||||
# Simplest: disconnect and reconnect with VoiceRecvClient to the new channel.
|
||||
print(f"Moving to {channel.name} in {guild.name}. Reconnecting with VoiceRecvClient.")
|
||||
await voice_client.disconnect(force=True)
|
||||
try:
|
||||
voice_client = await channel.connect(cls=voice_recv.VoiceRecvClient, timeout=10.0)
|
||||
print(f"Moved and reconnected to {channel.name} in {guild.name} with VoiceRecvClient.")
|
||||
self.start_listening_for_vc(voice_client)
|
||||
except asyncio.TimeoutError:
|
||||
return None, f"Timeout trying to move and connect to {channel.name}."
|
||||
except Exception as e:
|
||||
return None, f"Error moving and connecting to {channel.name}: {str(e)}"
|
||||
else:
|
||||
try:
|
||||
voice_client = await channel.connect(timeout=10.0) # Added timeout
|
||||
print(f"Connected to {channel.name} in {guild.name}.")
|
||||
# Connect using VoiceRecvClient
|
||||
voice_client = await channel.connect(cls=voice_recv.VoiceRecvClient, timeout=10.0)
|
||||
print(f"Connected to {channel.name} in {guild.name} with VoiceRecvClient.")
|
||||
self.start_listening_for_vc(voice_client)
|
||||
except asyncio.TimeoutError:
|
||||
return None, f"Timeout trying to connect to {channel.name}."
|
||||
|
Loading…
x
Reference in New Issue
Block a user