370 lines
18 KiB
Python
370 lines
18 KiB
Python
import discord
|
|
from discord.ext import commands
|
|
import asyncio
|
|
import os
|
|
import tempfile
|
|
import wave # For saving audio data
|
|
|
|
# Attempt to import STT, VAD, and Opus libraries
|
|
# Optional dependency: Whisper for speech-to-text.
try:
    import whisper
except ImportError:
    print("Whisper library not found. Please install with 'pip install openai-whisper'")
    whisper = None

# Optional dependency: WebRTC voice-activity detection.
try:
    import webrtcvad
except ImportError:
    # The PyPI package name is 'webrtcvad' (the previous hint pointed at a
    # non-existent package name).
    print("webrtcvad library not found. Please install with 'pip install webrtcvad'")
    webrtcvad = None

# Optional dependency: opuslib for decoding Discord's Opus voice packets.
try:
    from opuslib import Decoder as OpusDecoder
    from opuslib import OPUS_APPLICATION_VOIP, OPUS_SIGNAL_VOICE
except ImportError:
    import traceback
    print("opuslib library not found. Please install with 'pip install opuslib' (requires Opus C library).")
    print(f"Import error traceback: {traceback.format_exc()}")
    OpusDecoder = None
|
|
|
|
|
|
# Options handed to FFmpeg when streaming playback audio.
FFMPEG_OPTIONS = {
    'before_options': '-reconnect 1 -reconnect_streamed 1 -reconnect_delay_max 5',
    'options': '-vn',
}

# PCM format used throughout the capture pipeline (Whisper prefers 16 kHz mono).
SAMPLE_RATE = 16000
CHANNELS = 1          # mono
SAMPLE_WIDTH = 2      # bytes per sample (16-bit PCM)
SAMPLES_PER_MS = SAMPLE_RATE // 1000  # 16 samples per millisecond at 16 kHz

# Voice-activity-detection configuration.
VAD_MODE = 3              # webrtcvad aggressiveness: 0 (lenient) .. 3 (strict)
FRAME_DURATION_MS = 30    # VAD frame length; webrtcvad accepts 10, 20 or 30 ms
BYTES_PER_FRAME = SAMPLES_PER_MS * FRAME_DURATION_MS * CHANNELS * SAMPLE_WIDTH

# Opus framing: Discord delivers 20 ms Opus frames.
OPUS_FRAME_SIZE_MS = 20
OPUS_SAMPLES_PER_FRAME = SAMPLES_PER_MS * OPUS_FRAME_SIZE_MS  # 320 samples @ 16 kHz
OPUS_BUFFER_SIZE = OPUS_SAMPLES_PER_FRAME * CHANNELS * SAMPLE_WIDTH  # PCM bytes per Opus frame

# End-of-speech heuristics.
SILENCE_THRESHOLD_FRAMES = 25   # 25 * 30 ms = ~750 ms of silence ends a segment
MAX_SPEECH_DURATION_S = 15      # hard cap on a single speech segment
MAX_SPEECH_FRAMES = (MAX_SPEECH_DURATION_S * 1000) // FRAME_DURATION_MS
|
|
|
|
|
|
class VoiceAudioSink(discord.AudioSink):
    """Audio sink that turns per-user Discord voice packets into speech segments.

    Each incoming Opus packet is decoded to 16 kHz mono PCM, classified as
    speech or silence by webrtcvad, and buffered per SSRC. A segment is
    flushed to the cog's ``process_audio_segment`` when the speaker goes
    quiet for ``SILENCE_THRESHOLD_FRAMES`` consecutive frames, or after
    ``MAX_SPEECH_FRAMES`` of continuous speech.
    """

    def __init__(self, cog_instance, voice_client: discord.VoiceClient):
        super().__init__()
        self.cog = cog_instance
        self.voice_client = voice_client  # Needed for ssrc->user mapping and guild
        # ssrc -> {'buffer', 'speaking', 'silent_frames', 'speech_frames',
        #          'decoder': OpusDecoder, 'vad': webrtcvad.Vad}
        self.user_audio_data = {}

        if not OpusDecoder:
            print("OpusDecoder not available. AudioSink will not function correctly.")
        if not webrtcvad:
            print("VAD library not loaded. STT might be less efficient or not work as intended.")

    def _dispatch_segment(self, user_id: int, pcm: bytes):
        """Hand a finished PCM segment to the cog for transcription.

        BUGFIX: ``write()`` runs on the voice receive thread, where no asyncio
        event loop is running, so ``asyncio.create_task`` would raise
        "no running event loop". ``run_coroutine_threadsafe`` is the correct
        way to submit a coroutine to the client's loop from another thread.
        """
        asyncio.run_coroutine_threadsafe(
            self.cog.process_audio_segment(user_id, pcm, self.voice_client.guild),
            self.voice_client.loop,
        )

    def _flush(self, entry: dict, user_id: int):
        """Dispatch the buffered audio for one speaker and reset their state."""
        self._dispatch_segment(user_id, bytes(entry['buffer']))
        entry['buffer'].clear()
        entry['speaking'] = False
        entry['speech_frames'] = 0

    def write(self, ssrc: int, data: bytes):  # data is opus encoded
        """Decode one Opus packet and advance the speaker's VAD state machine."""
        if not OpusDecoder or not webrtcvad or not self.voice_client:
            return

        user = self.voice_client.ssrc_map.get(ssrc)
        if not user:  # Unknown SSRC or user left: drop any stale state
            if ssrc in self.user_audio_data:
                del self.user_audio_data[ssrc]
            return

        user_id = user.id

        if ssrc not in self.user_audio_data:
            self.user_audio_data[ssrc] = {
                'buffer': bytearray(),
                'speaking': False,
                'silent_frames': 0,
                'speech_frames': 0,
                'decoder': OpusDecoder(SAMPLE_RATE, CHANNELS),  # Decode to 16kHz mono
                'vad': webrtcvad.Vad(VAD_MODE) if webrtcvad else None,
            }

        entry = self.user_audio_data[ssrc]

        try:
            # Discord delivers 20 ms Opus frames; at 16 kHz mono that decodes
            # to 320 samples = 640 bytes of 16-bit PCM.
            pcm_data = entry['decoder'].decode(data, OPUS_SAMPLES_PER_FRAME, decode_fec=False)
        except Exception as e:
            print(f"Opus decoding error for SSRC {ssrc} (User {user_id}): {e}")
            return

        # webrtcvad only accepts 10/20/30 ms frames; a decoded 20 ms frame is
        # fed directly. If the length is unexpected (sample-rate or frame
        # duration mismatch), is_speech() raises and we fall back below.
        if entry['vad']:
            try:
                is_speech = entry['vad'].is_speech(pcm_data, SAMPLE_RATE)
            except Exception:
                # Assume speech on VAD failure so audio is not silently dropped.
                is_speech = True
        else:  # No VAD available: treat everything as speech
            is_speech = True

        if is_speech:
            entry['buffer'].extend(pcm_data)
            entry['speaking'] = True
            entry['silent_frames'] = 0
            entry['speech_frames'] += 1
            # Cap segment length so a monologue cannot grow the buffer unboundedly.
            if entry['speech_frames'] >= MAX_SPEECH_FRAMES:
                self._flush(entry, user_id)
        elif entry['speaking']:  # Was speaking, now silence
            entry['buffer'].extend(pcm_data)  # Keep the trailing silence for context
            entry['silent_frames'] += 1
            if entry['silent_frames'] >= SILENCE_THRESHOLD_FRAMES:
                self._flush(entry, user_id)
                entry['silent_frames'] = 0
        # Silence while not already speaking is ignored.

    def cleanup(self):
        """Flush any remaining buffered audio and drop all per-user state."""
        print("VoiceAudioSink cleanup called.")
        for ssrc, data in self.user_audio_data.items():
            # If there's buffered audio when cleaning up, process it
            if data['buffer']:
                user = self.voice_client.ssrc_map.get(ssrc)
                if user:
                    print(f"Processing remaining audio for SSRC {ssrc} (User {user.id}) on cleanup.")
                    self._dispatch_segment(user.id, bytes(data['buffer']))
        self.user_audio_data.clear()
|
|
|
|
|
|
class VoiceGatewayCog(commands.Cog):
    """Manages voice connections, audio capture, playback, and Whisper STT.

    Dispatches a ``voice_transcription_received`` bot event with
    ``(guild, user, text)`` whenever a captured speech segment transcribes
    to non-empty text.
    """

    def __init__(self, bot):
        self.bot = bot
        self.active_sinks = {}  # guild_id -> VoiceAudioSink
        self.whisper_model = None
        if whisper:
            try:
                # Load a smaller model initially; can be made configurable.
                self.whisper_model = whisper.load_model("base")
                print("Whisper model 'base' loaded successfully.")
            except Exception as e:
                print(f"Error loading Whisper model: {e}. STT will not be available.")
                self.whisper_model = None
        else:
            print("Whisper library not available. STT functionality will be disabled.")

    async def cog_load(self):
        print("VoiceGatewayCog loaded!")

    async def cog_unload(self):
        """Stop listening, flush every sink, and disconnect from all voice channels."""
        print("Unloading VoiceGatewayCog...")
        for vc in list(self.bot.voice_clients):  # Iterate over a copy
            guild_id = vc.guild.id
            if guild_id in self.active_sinks:
                if vc.is_connected():
                    vc.stop_listening()  # Stop listening before cleanup
                self.active_sinks[guild_id].cleanup()
                del self.active_sinks[guild_id]
            if vc.is_connected():
                await vc.disconnect(force=True)
        print("VoiceGatewayCog unloaded and disconnected from voice channels.")

    async def connect_to_voice(self, channel: discord.VoiceChannel):
        """Connect (or move) to *channel* and start listening.

        Returns ``(voice_client, message)``; ``voice_client`` is ``None`` on
        failure.
        """
        if not channel:
            return None, "Channel not provided."

        guild = channel.guild
        voice_client = guild.voice_client

        if voice_client and voice_client.is_connected():
            if voice_client.channel == channel:
                print(f"Already connected to {channel.name} in {guild.name}.")
                # Ensure listening is active if already connected
                if guild.id not in self.active_sinks or not voice_client.is_listening():
                    self.start_listening_for_vc(voice_client)
                return voice_client, "Already connected to this channel."
            else:
                await voice_client.move_to(channel)
                print(f"Moved to {channel.name} in {guild.name}.")
                # Restart listening in the new channel
                self.start_listening_for_vc(voice_client)
        else:
            try:
                voice_client = await channel.connect(timeout=10.0)
                print(f"Connected to {channel.name} in {guild.name}.")
                self.start_listening_for_vc(voice_client)
            except asyncio.TimeoutError:
                return None, f"Timeout trying to connect to {channel.name}."
            except Exception as e:
                return None, f"Error connecting to {channel.name}: {str(e)}"

        if not voice_client:  # Should not happen if connect succeeded
            return None, "Failed to establish voice client after connection."

        return voice_client, f"Successfully connected and listening in {channel.name}."

    def start_listening_for_vc(self, voice_client: discord.VoiceClient):
        """Start or restart listening for *voice_client* with a fresh sink."""
        guild_id = voice_client.guild.id
        if guild_id in self.active_sinks:
            # Tear down the old sink so state from a previous channel or
            # session cannot leak into the new one.
            if voice_client.is_listening():
                voice_client.stop_listening()
            self.active_sinks[guild_id].cleanup()
        self.active_sinks[guild_id] = VoiceAudioSink(self, voice_client)

        if not voice_client.is_listening():
            voice_client.listen(self.active_sinks[guild_id])
            print(f"Started listening in {voice_client.channel.name} for guild {guild_id}")
        else:
            print(f"Already listening in {voice_client.channel.name} for guild {guild_id}")

    async def disconnect_from_voice(self, guild: discord.Guild):
        """Disconnect from voice in *guild*, flushing the active sink first.

        Returns ``(success, message)``.
        """
        voice_client = guild.voice_client
        if voice_client and voice_client.is_connected():
            if voice_client.is_listening():
                voice_client.stop_listening()

            guild_id = guild.id
            if guild_id in self.active_sinks:
                self.active_sinks[guild_id].cleanup()
                del self.active_sinks[guild_id]

            await voice_client.disconnect(force=True)
            print(f"Disconnected from voice in {guild.name}.")
            return True, f"Disconnected from voice in {guild.name}."
        return False, "Not connected to voice in this guild."

    async def play_audio_file(self, voice_client: discord.VoiceClient, audio_file_path: str):
        """Play *audio_file_path* through *voice_client* via FFmpeg.

        Any audio already playing is stopped first. Returns ``(success, message)``.
        """
        if not voice_client or not voice_client.is_connected():
            print("Error: Voice client not connected.")
            return False, "Voice client not connected."

        if not os.path.exists(audio_file_path):
            print(f"Error: Audio file not found at {audio_file_path}")
            return False, "Audio file not found."

        if voice_client.is_playing():
            voice_client.stop()  # Stop current audio if any

        try:
            audio_source = discord.FFmpegPCMAudio(audio_file_path, **FFMPEG_OPTIONS)
            voice_client.play(audio_source, after=lambda e: self.after_audio_playback(e, audio_file_path))
            print(f"Playing audio: {audio_file_path}")
            return True, f"Playing {os.path.basename(audio_file_path)}"
        except Exception as e:
            print(f"Error creating/playing FFmpegPCMAudio source for {audio_file_path}: {e}")
            return False, f"Error playing audio: {str(e)}"

    def after_audio_playback(self, error, audio_file_path):
        """Playback-finished callback: log the outcome; file deletion is not ours."""
        if error:
            print(f"Error during audio playback for {audio_file_path}: {error}")
        else:
            print(f"Finished playing {audio_file_path}")
        # TTSProviderCog's cleanup will handle deleting the file.

    async def process_audio_segment(self, user_id: int, audio_data: bytes, guild: discord.Guild):
        """Transcribe one PCM speech segment with Whisper and dispatch the text.

        *audio_data* is raw 16 kHz mono 16-bit PCM. It is written to a
        temporary WAV file (Whisper accepts a path), transcribed off the
        event loop, and the file is always removed afterwards.
        """
        if not self.whisper_model or not audio_data:
            if not audio_data:
                print(f"process_audio_segment called for user {user_id} with empty audio_data.")
            return

        wav_file_path = None  # Defined up front so `finally` can always check it
        try:
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav:
                wav_file_path = tmp_wav.name
                with wave.open(tmp_wav, 'wb') as wf:
                    wf.setnchannels(CHANNELS)
                    wf.setsampwidth(SAMPLE_WIDTH)
                    wf.setframerate(SAMPLE_RATE)
                    wf.writeframes(audio_data)

            # BUGFIX: run_in_executor only forwards *positional* arguments, so
            # the previous `fp16=False` keyword raised TypeError. Wrap the
            # blocking call so the keyword reaches transcribe().
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None,  # Default ThreadPoolExecutor; transcribe() is blocking
                lambda: self.whisper_model.transcribe(wav_file_path, fp16=False),
            )
            transcribed_text = result["text"].strip()

            if transcribed_text:  # Only dispatch if there's actual text
                user = guild.get_member(user_id) or await self.bot.fetch_user(user_id)
                print(f"Transcription for {user.name} ({user_id}) in {guild.name}: {transcribed_text}")
                self.bot.dispatch("voice_transcription_received", guild, user, transcribed_text)

        except Exception as e:
            print(f"Error processing audio segment for user {user_id}: {e}")
        finally:
            if wav_file_path and os.path.exists(wav_file_path):
                os.remove(wav_file_path)
|
|
|
|
|
|
async def setup(bot: commands.Bot):
    """Cog entry point: load VoiceGatewayCog only if FFmpeg is available."""
    try:
        # BUGFIX: use exec (not shell). With create_subprocess_shell a missing
        # ffmpeg binary never raises FileNotFoundError (the shell absorbs it),
        # making the except branch below unreachable; exec also avoids shell
        # string parsing entirely.
        process = await asyncio.create_subprocess_exec(
            "ffmpeg", "-version",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, stderr = await process.communicate()
        if process.returncode == 0:
            print("FFmpeg found. VoiceGatewayCog can be loaded.")
            await bot.add_cog(VoiceGatewayCog(bot))
            print("VoiceGatewayCog loaded successfully!")
        else:
            print("FFmpeg not found or not working correctly. VoiceGatewayCog will not be loaded.")
            print(f"FFmpeg check stdout: {stdout.decode(errors='ignore')}")
            print(f"FFmpeg check stderr: {stderr.decode(errors='ignore')}")

    except FileNotFoundError:
        print("FFmpeg command not found. VoiceGatewayCog will not be loaded. Please install FFmpeg and ensure it's in your system's PATH.")
    except Exception as e:
        print(f"An error occurred while checking for FFmpeg: {e}. VoiceGatewayCog will not be loaded.")
|