improve Gradio streaming audio playback
- Optimize audio chunk sizes with progressive growth strategy
- Adjust audio processing logic for smoother real-time playback
Mereithhh committed Nov 4, 2024
1 parent 29afed6 commit 2b03c2b
Showing 2 changed files with 105 additions and 3 deletions.
93 changes: 93 additions & 0 deletions audio_process.py
@@ -0,0 +1,93 @@
import io

import librosa
import numpy as np
import soundfile as sf

# Split the audio stream at silence points to prevent the playback stuttering
# caused by AAC encoder frame padding when streaming audio through Gradio audio components.
class AudioStreamProcessor:
    def __init__(self, sr=22050, min_silence_duration=0.1, threshold_db=-40):
        self.sr = sr
        self.min_silence_duration = min_silence_duration
        self.threshold_db = threshold_db
        self.buffer = np.array([])

    def process(self, audio_data, last=False):
        """
        Add audio data to the buffer and emit a chunk when possible.

        params:
            audio_data: audio data as a numpy array
            last: whether this is the last chunk of data
        returns:
            WAV-encoded bytes for the split-off chunk, or None if no split point is found
        """
        # Append the new data to the buffer
        self.buffer = np.concatenate([self.buffer, audio_data]) if len(self.buffer) > 0 else audio_data

        if last:
            # Final chunk: flush everything that remains in the buffer
            result = self.buffer
            self.buffer = np.array([])
            return self._to_wav_bytes(result)

        # Find a silence boundary to split at
        split_point = self._find_silence_boundary(self.buffer)

        if split_point is not None:
            # Extend the split point to the end of the silence segment
            silence_end = self._find_silence_end(split_point)
            result = self.buffer[:silence_end]
            self.buffer = self.buffer[silence_end:]
            return self._to_wav_bytes(result)

        return None

    def _find_silence_boundary(self, audio):
        """
        Find the starting point of a silence segment in the audio.
        """
        # Convert the amplitude envelope to decibels, relative to the peak
        db = librosa.amplitude_to_db(np.abs(audio), ref=np.max)

        # Indices of samples below the silence threshold
        silence_points = np.where(db < self.threshold_db)[0]

        if len(silence_points) == 0:
            return None

        # Minimum number of consecutive silent samples
        # (e.g. 0.1 s at 22050 Hz is 2205 samples)
        min_silence_samples = int(self.min_silence_duration * self.sr)

        # Search backwards for the start of a continuous silence segment
        for i in range(len(silence_points) - min_silence_samples, -1, -1):
            if np.all(np.diff(silence_points[i:i + min_silence_samples]) == 1):
                return silence_points[i]

        return None

    def _find_silence_end(self, start_point):
        """
        Find the end point of the silence segment that starts at start_point.
        """
        db = librosa.amplitude_to_db(np.abs(self.buffer[start_point:]), ref=np.max)
        # Indices of samples that rise back above the silence threshold
        non_silence_points = np.where(db >= self.threshold_db)[0]

        if len(non_silence_points) == 0:
            return len(self.buffer)

        return start_point + non_silence_points[0]

    def _to_wav_bytes(self, audio_data):
        """
        Encode the audio data as in-memory WAV bytes.
        """
        wav_buffer = io.BytesIO()
        sf.write(wav_buffer, audio_data, self.sr, format='WAV')
        return wav_buffer.getvalue()
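
For reference, a minimal driver loop showing how the processor is meant to be fed (a sketch, not part of the commit). The noise chunks are placeholders; since random noise contains no sustained silence, only the final last=True call is guaranteed to emit bytes:

import numpy as np

from audio_process import AudioStreamProcessor

processor = AudioStreamProcessor(sr=22050)
# Placeholder input: one second of noise per chunk, standing in for decoder output.
chunks = [np.random.uniform(-1, 1, 22050).astype(np.float32) for _ in range(3)]

for i, chunk in enumerate(chunks):
    wav_bytes = processor.process(chunk, last=(i == len(chunks) - 1))
    if wav_bytes is not None:
        # Each emitted segment is a complete WAV file, so a streaming
        # client can decode and play it independently of the others.
        print(f"emitted segment of {len(wav_bytes)} bytes")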


15 changes: 12 additions & 3 deletions web_demo.py
@@ -23,6 +23,7 @@
 audio_token_pattern = re.compile(r"<\|audio_(\d+)\|>")
 
 from flow_inference import AudioDecoder
+from audio_process import AudioStreamProcessor
 
 if __name__ == "__main__":
     parser = ArgumentParser()
@@ -125,13 +126,18 @@ def inference_fn(
         tts_mels = []
         prev_mel = None
         is_finalize = False
-        block_size = 10
+        block_size_list = [25, 50, 100, 150, 200]
+        block_size_idx = 0
+        block_size = block_size_list[block_size_idx]
+        audio_processor = AudioStreamProcessor()
         for chunk in response.iter_lines():
             token_id = json.loads(chunk)["token_id"]
             if token_id == end_token_id:
                 is_finalize = True
             if len(audio_tokens) >= block_size or (is_finalize and audio_tokens):
-                block_size = 20
+                if block_size_idx < len(block_size_list) - 1:
+                    block_size_idx += 1
+                    block_size = block_size_list[block_size_idx]
                 tts_token = torch.tensor(audio_tokens, device=device).unsqueeze(0)
 
                 if prev_mel is not None:
@@ -143,9 +149,12 @@
                                                               finalize=is_finalize)
                 prev_mel = tts_mel
 
+                audio_bytes = audio_processor.process(tts_speech.clone().cpu().numpy()[0], last=is_finalize)
+
                 tts_speechs.append(tts_speech.squeeze())
                 tts_mels.append(tts_mel)
-                yield history, inputs, '', '', (22050, tts_speech.squeeze().cpu().numpy()), None
+                if audio_bytes:
+                    yield history, inputs, '', '', audio_bytes, None
                 flow_prompt_speech_token = torch.cat((flow_prompt_speech_token, tts_token), dim=-1)
                 audio_tokens = []
             if not is_finalize:
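
Taken on its own, the block-size schedule reads as a latency/throughput trade-off: the first flush happens after only 25 tokens so audio reaches the player quickly, and each later flush waits for a larger block (50, 100, 150, then a steady 200), which presumably amortizes vocoder calls once playback is underway. A minimal sketch of the schedule, with a dummy counter standing in for the streamed token ids:

block_size_list = [25, 50, 100, 150, 200]
block_size_idx = 0
block_size = block_size_list[block_size_idx]

audio_tokens = []
for token_id in range(500):  # stand-in for the real token stream
    audio_tokens.append(token_id)
    if len(audio_tokens) >= block_size:
        print(f"flush {len(audio_tokens)} tokens")
        # Grow the threshold until it reaches its final value.
        if block_size_idx < len(block_size_list) - 1:
            block_size_idx += 1
            block_size = block_size_list[block_size_idx]
        audio_tokens = []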
