Commit 711ceee

Merge pull request #9 from TrevorS/streaming-voice-design
feat: streaming VoiceDesign support + constructor dedup
2 parents ce382fb + f903bcb commit 711ceee

2 files changed: 312 additions & 4 deletions

src/lib.rs

Lines changed: 154 additions & 4 deletions
@@ -1078,6 +1078,46 @@ impl Qwen3TTS {
         StreamingSession::new(self, &input_ids, speaker, language, options)
     }
 
+    /// Synthesize speech using a text-described voice (VoiceDesign model), streaming.
+    ///
+    /// Same as [`Self::synthesize_voice_design`] but returns a streaming session
+    /// that yields audio chunks as they are generated.
+    ///
+    /// The instruct text is tokenized with ChatML framing:
+    /// `<|im_start|>user\n{instruct}<|im_end|>\n`
+    ///
+    /// # Arguments
+    ///
+    /// * `text` - Text to synthesize
+    /// * `instruct` - Natural language voice description (e.g., "A cheerful young female voice")
+    /// * `language` - Target language
+    /// * `options` - Synthesis options (temperature, top_k, chunk_frames, etc.)
+    pub fn synthesize_voice_design_streaming(
+        &self,
+        text: &str,
+        instruct: &str,
+        language: Language,
+        options: SynthesisOptions,
+    ) -> Result<StreamingSession<'_>> {
+        if let Some(ref mt) = self.model_type {
+            if *mt != ModelType::VoiceDesign {
+                tracing::warn!(
+                    "Using VoiceDesign synthesis on a {:?} model. This model was not trained \
+                     for text-described voice conditioning — output may be unpredictable.",
+                    mt
+                );
+            }
+        }
+
+        let input_ids = self.text_tokenizer.encode(text)?;
+
+        // Tokenize instruct with ChatML user framing: <|im_start|>user\n{instruct}<|im_end|>\n
+        let instruct_text = format!("<|im_start|>user\n{}<|im_end|>\n", instruct);
+        let instruct_ids = self.text_tokenizer.encode(&instruct_text)?;
+
+        StreamingSession::new_voice_design(self, &input_ids, &instruct_ids, language, options)
+    }
+
     // ── Voice cloning API ─────────────────────────────────────────────────
 
     /// Create a voice clone prompt from reference audio.
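For reference, a minimal caller-side sketch of the new streaming VoiceDesign entry point, condensed from the `tests/streaming_e2e.rs` file added later in this commit. The model directory, the example text, and the `expect`-based error handling are assumptions mirroring that test setup rather than part of the library API; chunk size is controlled by `options.chunk_frames` (the e2e tests use 10).

```rust
use qwen3_tts::{Language, Qwen3TTS, SynthesisOptions};

fn main() {
    // Device and model path mirror the e2e test setup; the path is an assumption
    // about where the VoiceDesign weights live on the caller's machine.
    let device = qwen3_tts::auto_device().expect("auto_device failed");
    let model = Qwen3TTS::from_pretrained("test_data/models/1.7B-VoiceDesign", device)
        .expect("model load failed");

    let options = SynthesisOptions {
        seed: Some(42),
        chunk_frames: 10,
        ..Default::default()
    };

    let session = model
        .synthesize_voice_design_streaming(
            "Hello from the streaming VoiceDesign path.",
            "A deep male voice with a calm and steady tone",
            Language::English,
            options,
        )
        .expect("streaming session creation failed");

    // The session is an iterator; each item is one decoded audio chunk.
    for chunk_result in session {
        let audio = chunk_result.expect("chunk generation failed");
        println!("chunk: {} samples ({:.2}s)", audio.len(), audio.duration());
    }
}
```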
@@ -1475,18 +1515,84 @@ impl<'a> StreamingSession<'a> {
         language: Language,
         options: SynthesisOptions,
     ) -> Result<Self> {
-        let mut sampling_ctx = generation::SamplingContext::new(options.seed);
+        let sampling_ctx = generation::SamplingContext::new(options.seed);
         let config = options.to_gen_config();
 
         let (trailing_text_hidden, trailing_text_len, tts_pad_embed) =
             model.build_trailing_text(input_ids)?;
 
-        // Prefill with CustomVoice format
         let mut kv_caches = model.talker.new_kv_caches(config.max_new_tokens + 256);
-        let (hidden, logits) =
+        let prefill_result =
             model
                 .talker
                 .prefill_custom_voice(input_ids, speaker, language, &mut kv_caches)?;
+
+        Self::from_prefill(
+            model,
+            config,
+            sampling_ctx,
+            kv_caches,
+            prefill_result,
+            trailing_text_hidden,
+            trailing_text_len,
+            tts_pad_embed,
+            options.chunk_frames,
+        )
+    }
+
+    /// Create a streaming session using voice design (text-described voice).
+    ///
+    /// Uses `prefill_voice_design` instead of `prefill_custom_voice` to condition
+    /// on a natural language voice description rather than a predefined speaker.
+    fn new_voice_design(
+        model: &'a Qwen3TTS,
+        input_ids: &[u32],
+        instruct_ids: &[u32],
+        language: Language,
+        options: SynthesisOptions,
+    ) -> Result<Self> {
+        let sampling_ctx = generation::SamplingContext::new(options.seed);
+        let config = options.to_gen_config();
+
+        let (trailing_text_hidden, trailing_text_len, tts_pad_embed) =
+            model.build_trailing_text(input_ids)?;
+
+        let mut kv_caches = model.talker.new_kv_caches(config.max_new_tokens + 256);
+        let prefill_result =
+            model
+                .talker
+                .prefill_voice_design(input_ids, instruct_ids, language, &mut kv_caches)?;
+
+        Self::from_prefill(
+            model,
+            config,
+            sampling_ctx,
+            kv_caches,
+            prefill_result,
+            trailing_text_hidden,
+            trailing_text_len,
+            tts_pad_embed,
+            options.chunk_frames,
+        )
+    }
+
+    /// Shared post-prefill constructor.
+    ///
+    /// Extracts `last_hidden` from the prefill result, builds the suppression and
+    /// penalty masks, samples the first semantic token, and assembles the session.
+    #[allow(clippy::too_many_arguments)]
+    fn from_prefill(
+        model: &'a Qwen3TTS,
+        config: generation::GenerationConfig,
+        mut sampling_ctx: generation::SamplingContext,
+        kv_caches: Vec<AnyKVCache>,
+        prefill_result: (Tensor, Tensor),
+        trailing_text_hidden: Tensor,
+        trailing_text_len: usize,
+        tts_pad_embed: Tensor,
+        chunk_frames: usize,
+    ) -> Result<Self> {
+        let (hidden, logits) = prefill_result;
         let prefill_len = hidden.dim(1)?;
         let last_hidden = hidden.i((.., prefill_len - 1..prefill_len, ..))?;
 
@@ -1526,7 +1632,7 @@ impl<'a> StreamingSession<'a> {
             current_token_tensor: if done { None } else { Some(first_token) },
             frames_generated: 0,
             frame_buffer: Vec::new(),
-            chunk_frames: options.chunk_frames,
+            chunk_frames,
             done,
             trailing_text_hidden,
             trailing_text_len,
@@ -1982,4 +2088,48 @@ mod tests {
         let dtype = compute_dtype_for_device(&Device::Cpu);
         assert_eq!(dtype, DType::F32);
     }
+
+    #[test]
+    fn test_update_penalty_mask() {
+        let device = Device::Cpu;
+        let vocab_size = 3072;
+        let mut mask = Tensor::zeros((1, vocab_size), DType::F32, &device).unwrap();
+
+        Qwen3TTS::update_penalty_mask(&mut mask, 42, vocab_size).unwrap();
+
+        let vals: Vec<f32> = mask.flatten_all().unwrap().to_vec1().unwrap();
+        assert_eq!(vals[42], 1.0);
+        // Neighboring positions should be untouched
+        assert_eq!(vals[41], 0.0);
+        assert_eq!(vals[43], 0.0);
+    }
+
+    #[test]
+    fn test_update_penalty_mask_out_of_range() {
+        let device = Device::Cpu;
+        let vocab_size = 3072;
+        let mut mask = Tensor::zeros((1, vocab_size), DType::F32, &device).unwrap();
+
+        // Token beyond vocab_size should be a no-op (no panic)
+        Qwen3TTS::update_penalty_mask(&mut mask, 9999, vocab_size).unwrap();
+
+        let sum: f32 = mask.sum_all().unwrap().to_scalar().unwrap();
+        assert_eq!(sum, 0.0);
+    }
+
+    #[test]
+    fn test_suppression_mask_deterministic() {
+        let device = Device::Cpu;
+        let vocab = codec_tokens::CODEC_VOCAB_SIZE;
+        let mask1 = generation::build_suppression_mask(vocab, CODEC_EOS_TOKEN_ID, &device).unwrap();
+        let mask2 = generation::build_suppression_mask(vocab, CODEC_EOS_TOKEN_ID, &device).unwrap();
+
+        // Apply both masks to uniform logits and verify identical output
+        let logits = Tensor::ones((1, vocab), DType::F32, &device).unwrap();
+        let out1 = generation::apply_token_suppression_with_mask(&logits, &mask1).unwrap();
+        let out2 = generation::apply_token_suppression_with_mask(&logits, &mask2).unwrap();
+        let v1: Vec<f32> = out1.flatten_all().unwrap().to_vec1().unwrap();
+        let v2: Vec<f32> = out2.flatten_all().unwrap().to_vec1().unwrap();
+        assert_eq!(v1, v2);
+    }
 }
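As a plain-Rust illustration of the contract these unit tests pin down, the sketch below reproduces the asserted behaviour of `update_penalty_mask` over a flat `Vec<f32>` instead of the crate's `(1, vocab_size)` tensor. It is illustrative only, not the actual implementation: the tests only guarantee that the target slot is set to 1.0 and that out-of-range token ids are a silent no-op.

```rust
/// Illustrative sketch only: mirrors what the tests assert about
/// `update_penalty_mask`, using a flat slice instead of a tensor.
fn update_penalty_mask_sketch(mask: &mut [f32], token: usize, vocab_size: usize) {
    // Out-of-range token ids are silently ignored (no panic, mask unchanged).
    if token < vocab_size {
        mask[token] = 1.0;
    }
}

fn main() {
    let vocab_size = 3072;
    let mut mask = vec![0.0_f32; vocab_size];

    update_penalty_mask_sketch(&mut mask, 42, vocab_size);
    assert_eq!(mask[42], 1.0);
    assert_eq!(mask[41], 0.0); // neighbours untouched

    update_penalty_mask_sketch(&mut mask, 9999, vocab_size);
    let sum: f32 = mask.iter().sum();
    assert_eq!(sum, 1.0); // the out-of-range call changed nothing
}
```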

tests/streaming_e2e.rs

Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
+//! End-to-end streaming tests.
+//!
+//! These tests require real model weights and a CUDA GPU.
+//! Run with:
+//!     cargo test --release --features cuda --test streaming_e2e -- --ignored --nocapture
+
+use qwen3_tts::{Language, Qwen3TTS, Speaker, SynthesisOptions};
+
+fn load_model(model_dir: &str) -> Qwen3TTS {
+    let device = qwen3_tts::auto_device().expect("auto_device failed");
+    Qwen3TTS::from_pretrained(model_dir, device).expect("model load failed")
+}
+
+#[test]
+#[ignore = "requires model weights + GPU"]
+fn test_streaming_custom_voice() {
+    let model = load_model("test_data/models/1.7B-CustomVoice");
+
+    let options = SynthesisOptions {
+        seed: Some(42),
+        chunk_frames: 10,
+        ..Default::default()
+    };
+
+    let session = model
+        .synthesize_streaming(
+            "Hello, this is a streaming test.",
+            Speaker::Ryan,
+            Language::English,
+            options,
+        )
+        .expect("streaming session creation failed");
+
+    let mut total_samples = 0usize;
+    let mut chunk_count = 0usize;
+    for chunk_result in session {
+        let audio = chunk_result.expect("chunk generation failed");
+        assert!(audio.len() > 0, "chunk {chunk_count} was empty");
+        assert_eq!(audio.sample_rate, 24000);
+        total_samples += audio.len();
+        chunk_count += 1;
+        println!(
+            "  CustomVoice streaming chunk {}: {} samples ({:.2}s)",
+            chunk_count,
+            audio.len(),
+            audio.duration()
+        );
+    }
+
+    println!(
+        "CustomVoice streaming: {} chunks, {:.2}s total",
+        chunk_count,
+        total_samples as f32 / 24000.0
+    );
+    assert!(chunk_count > 0, "no chunks generated");
+    assert!(total_samples > 0, "no audio samples generated");
+}
+
+#[test]
+#[ignore = "requires model weights + GPU"]
+fn test_streaming_voice_design() {
+    let model = load_model("test_data/models/1.7B-VoiceDesign");
+
+    let options = SynthesisOptions {
+        seed: Some(42),
+        chunk_frames: 10,
+        ..Default::default()
+    };
+
+    let session = model
+        .synthesize_voice_design_streaming(
+            "Hello, this is a streaming voice design test.",
+            "A deep male voice with a calm and steady tone",
+            Language::English,
+            options,
+        )
+        .expect("streaming session creation failed");
+
+    let mut total_samples = 0usize;
+    let mut chunk_count = 0usize;
+    for chunk_result in session {
+        let audio = chunk_result.expect("chunk generation failed");
+        assert!(audio.len() > 0, "chunk {chunk_count} was empty");
+        assert_eq!(audio.sample_rate, 24000);
+        total_samples += audio.len();
+        chunk_count += 1;
+        println!(
+            "  VoiceDesign streaming chunk {}: {} samples ({:.2}s)",
+            chunk_count,
+            audio.len(),
+            audio.duration()
+        );
+    }
+
+    println!(
+        "VoiceDesign streaming: {} chunks, {:.2}s total",
+        chunk_count,
+        total_samples as f32 / 24000.0
+    );
+    assert!(chunk_count > 0, "no chunks generated");
+    assert!(total_samples > 0, "no audio samples generated");
+}
+
+#[test]
+#[ignore = "requires model weights + GPU"]
+fn test_streaming_matches_non_streaming() {
+    // Verify that streaming and non-streaming produce the same number of
+    // samples for the same seed (deterministic generation).
+    let model = load_model("test_data/models/1.7B-CustomVoice");
+
+    let make_options = || SynthesisOptions {
+        seed: Some(123),
+        chunk_frames: 10,
+        ..Default::default()
+    };
+
+    // Non-streaming
+    let audio_non_streaming = model
+        .synthesize_with_voice(
+            "Determinism test.",
+            Speaker::Ryan,
+            Language::English,
+            Some(make_options()),
+        )
+        .expect("non-streaming synthesis failed");
+
+    // Streaming — collect all chunks
+    let session = model
+        .synthesize_streaming(
+            "Determinism test.",
+            Speaker::Ryan,
+            Language::English,
+            make_options(),
+        )
+        .expect("streaming session creation failed");
+
+    let mut streaming_samples: Vec<f32> = Vec::new();
+    for chunk_result in session {
+        let audio = chunk_result.expect("chunk failed");
+        streaming_samples.extend_from_slice(&audio.samples);
+    }
+
+    // The total sample count should match (same frames decoded, same model)
+    println!(
+        "Non-streaming: {} samples, Streaming: {} samples",
+        audio_non_streaming.len(),
+        streaming_samples.len()
+    );
+
+    // They won't be sample-identical because streaming decodes in chunks
+    // (decoder sees fewer frames of context per chunk), but frame count
+    // and total sample count should match.
+    assert_eq!(
+        audio_non_streaming.len(),
+        streaming_samples.len(),
+        "streaming and non-streaming produced different sample counts"
+    );
+}
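A consumer that wants one contiguous buffer rather than per-chunk playback can reuse the chunk-collection pattern from `test_streaming_matches_non_streaming`. The helper below is a hypothetical sketch built only from calls the tests above exercise; the model path is the same assumption the tests make, and 24 kHz is the chunk sample rate the tests assert.

```rust
use qwen3_tts::{Language, Qwen3TTS, Speaker, SynthesisOptions};

/// Hypothetical helper: collect every streaming chunk into one contiguous
/// sample buffer, mirroring the loop in the determinism test above.
fn collect_streaming_samples(model: &Qwen3TTS, text: &str) -> Vec<f32> {
    let options = SynthesisOptions {
        seed: Some(123),
        chunk_frames: 10,
        ..Default::default()
    };

    let session = model
        .synthesize_streaming(text, Speaker::Ryan, Language::English, options)
        .expect("streaming session creation failed");

    let mut samples = Vec::new();
    for chunk_result in session {
        let audio = chunk_result.expect("chunk generation failed");
        samples.extend_from_slice(&audio.samples);
    }
    samples
}

fn main() {
    // Assumed model path, matching the e2e tests in this commit.
    let device = qwen3_tts::auto_device().expect("auto_device failed");
    let model = Qwen3TTS::from_pretrained("test_data/models/1.7B-CustomVoice", device)
        .expect("model load failed");

    let samples = collect_streaming_samples(&model, "Determinism test.");
    // Chunks decode at 24 kHz, so duration follows directly from the sample count.
    println!("{:.2}s of audio collected", samples.len() as f32 / 24_000.0);
}
```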
