@@ -29,30 +29,19 @@ struct RNNTDecodeResult {
2929
3030/// Greedy RNNT decoder for Parakeet EOU streaming ASR.
3131///
32- /// Unlike the TDT decoder, RNNT has no duration bins — blank always advances
33- /// by one encoder frame. EOU is detected when the model emits token ID 1024.
32+ /// Matches reference implementation:
33+ /// - Encoder output: [B, D, T] (channels-first)
34+ /// - Decoder: float32 h/c, output [B, D, 1]
35+ /// - Joint: argmax baked in, outputs token_id directly
3436struct RNNTGreedyDecoder {
3537 let config : ParakeetEOUConfig
3638 let decoder : MLModel
3739 let joint : MLModel
3840
39- /// Maximum non-blank tokens per encoder frame (safety limit) .
40- private let maxSymbolsPerStep = 10
41+ /// Maximum non-blank tokens per encoder frame.
42+ private let maxSymbolsPerStep = 2 // Matches reference
4143
4244 /// Decode encoder output with persistent LSTM state.
43- ///
44- /// - Parameters:
45- /// - encoded: Encoder output MLMultiArray, shape `[1, encoderHidden, T]` (channels-first)
46- /// - encodedLength: Number of valid encoder frames
47- /// - h: LSTM hidden state (mutated in place)
48- /// - c: LSTM cell state (mutated in place)
49- /// - decoderOutput: Previous decoder output (mutated in place)
50- /// - decoderProvider: Reusable feature provider for decoder
51- /// - jointProvider: Reusable feature provider for joint
52- /// - tokenArray: Pre-allocated token MLMultiArray [1, 1]
53- /// - encSlice: Pre-allocated encoder slice [1, 1, encoderHidden]
54- /// - argmaxBuf: Pre-allocated Float buffer for vDSP argmax
55- /// - Returns: Decode result with tokens, log-probs, and EOU flag
5645 func decode(
5746 encoded: MLMultiArray ,
5847 encodedLength: Int ,
@@ -62,29 +51,25 @@ struct RNNTGreedyDecoder {
6251 decoderProvider: ReusableFeatureProvider ,
6352 jointProvider: ReusableFeatureProvider ,
6453 tokenArray: MLMultiArray ,
65- encSlice: MLMultiArray ,
66- argmaxBuf: UnsafeMutablePointer < Float >
54+ encSlice: MLMultiArray
6755 ) throws -> RNNTDecodeResult {
6856 var tokens = [ Int] ( )
6957 var tokenLogProbs = [ Float] ( )
7058 var eouDetected = false
7159
7260 let tokenPtr = tokenArray. dataPointer. assumingMemoryBound ( to: Int32 . self)
73- let totalClasses = config. vocabSize + 1 // vocab + blank
7461
7562 for t in 0 ..< encodedLength {
76- // Extract encoder frame at position t — encoder output is [1 , D, T] (channels-first)
63+ // Extract encoder frame t from [B, D, T] → [B , D, 1]
7764 copyEncoderFrame ( from: encoded, at: t, to: encSlice)
7865
7966 for _ in 0 ..< maxSymbolsPerStep {
80- // Joint network : (encoder_slice, decoder_output ) → logits
81- jointProvider. update ( " encoder_output " , encSlice)
82- jointProvider. update ( " decoder_output " , decoderOutput)
67+ // Joint: (encoder_step [1,D,1], decoder_step [1,D,1] ) → token_id
68+ jointProvider. update ( " encoder_step " , encSlice)
69+ jointProvider. update ( " decoder_step " , decoderOutput)
8370 let jointOut = try joint. prediction ( from: jointProvider)
84- let logits = jointOut. featureValue ( for: " logits " ) !. multiArrayValue!
85-
86- let tokenId = argmax ( logits, count: totalClasses, floatBuf: argmaxBuf)
87-
71+ let tokenIdArray = jointOut. featureValue ( for: " token_id " ) !. multiArrayValue!
72+ let tokenId = Int ( tokenIdArray [ 0 ] . int32Value)
8873
8974 if tokenId == config. blankTokenId {
9075 break // Advance to next encoder frame
@@ -97,15 +82,16 @@ struct RNNTGreedyDecoder {
9782
9883 // Emit token
9984 tokens. append ( tokenId)
100- let logProb = logSoftmax ( logits, tokenId: tokenId, count: totalClasses, floatBuf: argmaxBuf)
101- tokenLogProbs. append ( logProb)
85+ // Get probability for confidence
86+ let probArray = jointOut. featureValue ( for: " token_prob " ) !. multiArrayValue!
87+ tokenLogProbs. append ( Float ( truncating: probArray [ 0 ] ) )
10288
10389 // Update decoder LSTM with emitted token
10490 tokenPtr. pointee = Int32 ( tokenId)
105- decoderProvider. update ( " h " , h)
106- decoderProvider. update ( " c " , c)
91+ decoderProvider. update ( " h_in " , h)
92+ decoderProvider. update ( " c_in " , c)
10793 let decOut = try decoder. prediction ( from: decoderProvider)
108- decoderOutput = decOut. featureValue ( for: " decoder_output " ) !. multiArrayValue!
94+ decoderOutput = decOut. featureValue ( for: " decoder " ) !. multiArrayValue!
10995 h = decOut. featureValue ( for: " h_out " ) !. multiArrayValue!
11096 c = decOut. featureValue ( for: " c_out " ) !. multiArrayValue!
11197 }
@@ -118,43 +104,15 @@ struct RNNTGreedyDecoder {
118104
119105 // MARK: - Array Operations
120106
121- /// Copy encoder frame at time `t` from [B, T, D] layout.
122- /// Output slice is [1, 1, D] for joint network input.
107+ /// Copy encoder frame at time `t` from [B, D, T] layout to [B, D, 1].
123108 private func copyEncoderFrame( from encoded: MLMultiArray , at t: Int , to slice: MLMultiArray ) {
124109 let hidden = config. encoderHidden
125- // encoded is [1, T, D] — frame t is contiguous at offset t * D
126- let src = encoded. dataPointer. advanced ( by: t * hidden * MemoryLayout< Float16> . stride)
127- memcpy ( slice. dataPointer, src, hidden * MemoryLayout< Float16> . stride)
128- }
129-
130- private func logSoftmax( _ array: MLMultiArray , tokenId: Int , count: Int , floatBuf: UnsafeMutablePointer < Float > ) -> Float {
131- let ptr = array. dataPointer. assumingMemoryBound ( to: Float16 . self)
132- for i in 0 ..< count { floatBuf [ i] = Float ( ptr [ i] ) }
133-
134- var maxVal : Float = 0
135- var maxIdx : vDSP_Length = 0
136- vDSP_maxvi ( floatBuf, 1 , & maxVal, & maxIdx, vDSP_Length ( count) )
137-
138- var negMax = - maxVal
139- vDSP_vsadd ( floatBuf, 1 , & negMax, floatBuf, 1 , vDSP_Length ( count) )
140-
141- var n = Int32 ( count)
142- vvexpf ( floatBuf, floatBuf, & n)
143-
144- var sumExp : Float = 0
145- vDSP_sve ( floatBuf, 1 , & sumExp, vDSP_Length ( count) )
146-
147- let logSumExp = log ( sumExp) + maxVal
148- let logit = Float ( ptr [ tokenId] )
149- return logit - logSumExp
150- }
151-
152- private func argmax( _ array: MLMultiArray , count: Int , floatBuf: UnsafeMutablePointer < Float > ) -> Int {
153- let ptr = array. dataPointer. assumingMemoryBound ( to: Float16 . self)
154- for i in 0 ..< count { floatBuf [ i] = Float ( ptr [ i] ) }
155- var maxVal : Float = 0
156- var maxIdx : vDSP_Length = 0
157- vDSP_maxvi ( floatBuf, 1 , & maxVal, & maxIdx, vDSP_Length ( count) )
158- return Int ( maxIdx)
110+ let totalFrames = encoded. shape [ 2 ] . intValue
111+ // encoded is [1, D, T] float16 from CoreML — copy with stride
112+ let src = encoded. dataPointer. assumingMemoryBound ( to: Float16 . self)
113+ let dst = slice. dataPointer. assumingMemoryBound ( to: Float . self)
114+ for d in 0 ..< hidden {
115+ dst [ d] = Float ( src [ d * totalFrames + t] )
116+ }
159117 }
160118}
0 commit comments