@@ -1078,6 +1078,46 @@ impl Qwen3TTS {
         StreamingSession::new(self, &input_ids, speaker, language, options)
     }
 
+    /// Synthesize speech using a text-described voice (VoiceDesign model), streaming.
+    ///
+    /// Same as [`Self::synthesize_voice_design`] but returns a streaming session
+    /// that yields audio chunks as they are generated.
+    ///
+    /// The instruct text is tokenized with ChatML framing:
+    /// `<|im_start|>user\n{instruct}<|im_end|>\n`
+    ///
+    /// # Arguments
+    ///
+    /// * `text` - Text to synthesize
+    /// * `instruct` - Natural language voice description (e.g., "A cheerful young female voice")
+    /// * `language` - Target language
+    /// * `options` - Synthesis options (temperature, top_k, chunk_frames, etc.)
+    pub fn synthesize_voice_design_streaming(
+        &self,
+        text: &str,
+        instruct: &str,
+        language: Language,
+        options: SynthesisOptions,
+    ) -> Result<StreamingSession<'_>> {
+        if let Some(ref mt) = self.model_type {
+            if *mt != ModelType::VoiceDesign {
+                tracing::warn!(
+                    "Using VoiceDesign synthesis on a {:?} model. This model was not trained \
+                     for text-described voice conditioning — output may be unpredictable.",
+                    mt
+                );
+            }
+        }
+
+        let input_ids = self.text_tokenizer.encode(text)?;
+
+        // Tokenize instruct with ChatML user framing: <|im_start|>user\n{instruct}<|im_end|>\n
+        let instruct_text = format!("<|im_start|>user\n{}<|im_end|>\n", instruct);
+        let instruct_ids = self.text_tokenizer.encode(&instruct_text)?;
+
+        StreamingSession::new_voice_design(self, &input_ids, &instruct_ids, language, options)
+    }
+
     // ── Voice cloning API ─────────────────────────────────────────────────
 
     /// Create a voice clone prompt from reference audio.
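
For readers unfamiliar with the ChatML framing mentioned in the doc comment above, the sketch below spells out exactly what string ends up being tokenized as the instruct prompt. The helper name `chatml_user_frame` and the example instruct string are illustrative only; the crate itself just calls `format!` inline as shown in the hunk.

```rust
// Standalone illustration of the ChatML user framing applied to `instruct`
// before it is passed to the text tokenizer. Mirrors the `format!` call above.
fn chatml_user_frame(instruct: &str) -> String {
    format!("<|im_start|>user\n{}<|im_end|>\n", instruct)
}

fn main() {
    let framed = chatml_user_frame("A cheerful young female voice");
    assert_eq!(
        framed,
        "<|im_start|>user\nA cheerful young female voice<|im_end|>\n"
    );
}
```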
@@ -1475,18 +1515,84 @@ impl<'a> StreamingSession<'a> {
         language: Language,
         options: SynthesisOptions,
     ) -> Result<Self> {
-        let mut sampling_ctx = generation::SamplingContext::new(options.seed);
+        let sampling_ctx = generation::SamplingContext::new(options.seed);
         let config = options.to_gen_config();
 
         let (trailing_text_hidden, trailing_text_len, tts_pad_embed) =
             model.build_trailing_text(input_ids)?;
 
-        // Prefill with CustomVoice format
         let mut kv_caches = model.talker.new_kv_caches(config.max_new_tokens + 256);
-        let (hidden, logits) =
+        let prefill_result =
             model
                 .talker
                 .prefill_custom_voice(input_ids, speaker, language, &mut kv_caches)?;
+
+        Self::from_prefill(
+            model,
+            config,
+            sampling_ctx,
+            kv_caches,
+            prefill_result,
+            trailing_text_hidden,
+            trailing_text_len,
+            tts_pad_embed,
+            options.chunk_frames,
+        )
+    }
+
+    /// Create a streaming session using voice design (text-described voice).
+    ///
+    /// Uses `prefill_voice_design` instead of `prefill_custom_voice` to condition
+    /// on a natural language voice description rather than a predefined speaker.
+    fn new_voice_design(
+        model: &'a Qwen3TTS,
+        input_ids: &[u32],
+        instruct_ids: &[u32],
+        language: Language,
+        options: SynthesisOptions,
+    ) -> Result<Self> {
+        let sampling_ctx = generation::SamplingContext::new(options.seed);
+        let config = options.to_gen_config();
+
+        let (trailing_text_hidden, trailing_text_len, tts_pad_embed) =
+            model.build_trailing_text(input_ids)?;
+
+        let mut kv_caches = model.talker.new_kv_caches(config.max_new_tokens + 256);
+        let prefill_result =
+            model
+                .talker
+                .prefill_voice_design(input_ids, instruct_ids, language, &mut kv_caches)?;
+
+        Self::from_prefill(
+            model,
+            config,
+            sampling_ctx,
+            kv_caches,
+            prefill_result,
+            trailing_text_hidden,
+            trailing_text_len,
+            tts_pad_embed,
+            options.chunk_frames,
+        )
+    }
+
+    /// Shared post-prefill constructor.
+    ///
+    /// Extracts `last_hidden` from the prefill result, builds the suppression and
+    /// penalty masks, samples the first semantic token, and assembles the session.
+    #[allow(clippy::too_many_arguments)]
+    fn from_prefill(
+        model: &'a Qwen3TTS,
+        config: generation::GenerationConfig,
+        mut sampling_ctx: generation::SamplingContext,
+        kv_caches: Vec<AnyKVCache>,
+        prefill_result: (Tensor, Tensor),
+        trailing_text_hidden: Tensor,
+        trailing_text_len: usize,
+        tts_pad_embed: Tensor,
+        chunk_frames: usize,
+    ) -> Result<Self> {
+        let (hidden, logits) = prefill_result;
         let prefill_len = hidden.dim(1)?;
         let last_hidden = hidden.i((.., prefill_len - 1..prefill_len, ..))?;
 
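
A note on the range slice in the shared constructor: indexing with `prefill_len - 1..prefill_len` rather than a bare integer keeps the sequence axis, so `last_hidden` stays three-dimensional, presumably so it can be fed back as a single decode step. A minimal candle sketch of that behavior (the batch and hidden sizes here are arbitrary):

```rust
use candle_core::{DType, Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    // Stand-in for the prefill hidden states: [batch, seq, hidden].
    let hidden = Tensor::zeros((1, 7, 4), DType::F32, &Device::Cpu)?;
    let prefill_len = hidden.dim(1)?;
    // Range index keeps the sequence dimension, now of length 1 ...
    let last = hidden.i((.., prefill_len - 1..prefill_len, ..))?;
    assert_eq!(last.dims(), &[1, 1, 4]);
    // ... whereas an integer index would drop it.
    let squeezed = hidden.i((.., prefill_len - 1, ..))?;
    assert_eq!(squeezed.dims(), &[1, 4]);
    Ok(())
}
```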
@@ -1526,7 +1632,7 @@ impl<'a> StreamingSession<'a> {
             current_token_tensor: if done { None } else { Some(first_token) },
             frames_generated: 0,
             frame_buffer: Vec::new(),
-            chunk_frames: options.chunk_frames,
+            chunk_frames,
             done,
             trailing_text_hidden,
             trailing_text_len,
@@ -1982,4 +2088,48 @@ mod tests {
         let dtype = compute_dtype_for_device(&Device::Cpu);
         assert_eq!(dtype, DType::F32);
     }
+
+    #[test]
+    fn test_update_penalty_mask() {
+        let device = Device::Cpu;
+        let vocab_size = 3072;
+        let mut mask = Tensor::zeros((1, vocab_size), DType::F32, &device).unwrap();
+
+        Qwen3TTS::update_penalty_mask(&mut mask, 42, vocab_size).unwrap();
+
+        let vals: Vec<f32> = mask.flatten_all().unwrap().to_vec1().unwrap();
+        assert_eq!(vals[42], 1.0);
+        // Neighboring positions should be untouched
+        assert_eq!(vals[41], 0.0);
+        assert_eq!(vals[43], 0.0);
+    }
+
+    #[test]
+    fn test_update_penalty_mask_out_of_range() {
+        let device = Device::Cpu;
+        let vocab_size = 3072;
+        let mut mask = Tensor::zeros((1, vocab_size), DType::F32, &device).unwrap();
+
+        // Token beyond vocab_size should be a no-op (no panic)
+        Qwen3TTS::update_penalty_mask(&mut mask, 9999, vocab_size).unwrap();
+
+        let sum: f32 = mask.sum_all().unwrap().to_scalar().unwrap();
+        assert_eq!(sum, 0.0);
+    }
+
+    #[test]
+    fn test_suppression_mask_deterministic() {
+        let device = Device::Cpu;
+        let vocab = codec_tokens::CODEC_VOCAB_SIZE;
+        let mask1 = generation::build_suppression_mask(vocab, CODEC_EOS_TOKEN_ID, &device).unwrap();
+        let mask2 = generation::build_suppression_mask(vocab, CODEC_EOS_TOKEN_ID, &device).unwrap();
+
+        // Apply both masks to uniform logits and verify identical output
+        let logits = Tensor::ones((1, vocab), DType::F32, &device).unwrap();
+        let out1 = generation::apply_token_suppression_with_mask(&logits, &mask1).unwrap();
+        let out2 = generation::apply_token_suppression_with_mask(&logits, &mask2).unwrap();
+        let v1: Vec<f32> = out1.flatten_all().unwrap().to_vec1().unwrap();
+        let v2: Vec<f32> = out2.flatten_all().unwrap().to_vec1().unwrap();
+        assert_eq!(v1, v2);
+    }
 }
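
The new tests only assert the contents of the penalty and suppression masks; how a mask of this shape is ultimately applied to the logits is not part of this diff. As a point of reference, the sketch below shows one common way a binary penalty mask can be used, a presence-style penalty that subtracts a constant from the logits of previously emitted tokens. The function name and the penalty rule are illustrative assumptions, not the crate's actual sampling code.

```rust
use candle_core::{DType, Device, Result, Tensor};

/// Hypothetical presence-style penalty: lower the logits of tokens whose mask
/// entry is 1.0 (tokens already generated), leave the rest untouched.
fn apply_presence_penalty(logits: &Tensor, mask: &Tensor, penalty: f64) -> Result<Tensor> {
    logits.sub(&(mask * penalty)?)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let vocab = 8;
    let mut flags = vec![0f32; vocab];
    flags[3] = 1.0; // pretend token 3 was already emitted
    let mask = Tensor::from_vec(flags, (1, vocab), &device)?;
    let logits = Tensor::ones((1, vocab), DType::F32, &device)?;

    let penalized = apply_presence_penalty(&logits, &mask, 2.0)?;
    let v: Vec<f32> = penalized.flatten_all()?.to_vec1()?;
    assert_eq!(v[3], -1.0); // penalized: 1.0 - 2.0
    assert_eq!(v[2], 1.0); // untouched
    Ok(())
}
```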