Skip to content

Commit fe83fa0

Browse files
committed
Rewrite session and decoder to match reference I/O spec
- Encoder: separate pre_cache, [B,D,T] output, new_pre_cache loopback
- Decoder: float32 h/c, names targets/h_in/c_in, output [1,640,1]
- Joint: argmax baked in, token_id output, [B,D,1] inputs
- RNNT decoder: strided [B,D,T] frame copy, maxSymbolsPerStep=2

WIP: batch transcription produces empty text — need to debug encoder output handling (8 frames vs 4, encoded_length mismatch)
1 parent b01d58b commit fe83fa0

2 files changed

Lines changed: 105 additions & 272 deletions

File tree

Sources/ParakeetStreamingASR/RNNTGreedyDecoder.swift

Lines changed: 27 additions & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -29,30 +29,19 @@ struct RNNTDecodeResult {
2929

3030
/// Greedy RNNT decoder for Parakeet EOU streaming ASR.
3131
///
32-
/// Unlike the TDT decoder, RNNT has no duration bins — blank always advances
33-
/// by one encoder frame. EOU is detected when the model emits token ID 1024.
32+
/// Matches reference implementation:
33+
/// - Encoder output: [B, D, T] (channels-first)
34+
/// - Decoder: float32 h/c, output [B, D, 1]
35+
/// - Joint: argmax baked in, outputs token_id directly
3436
struct RNNTGreedyDecoder {
3537
let config: ParakeetEOUConfig
3638
let decoder: MLModel
3739
let joint: MLModel
3840

39-
/// Maximum non-blank tokens per encoder frame (safety limit).
40-
private let maxSymbolsPerStep = 10
41+
/// Maximum non-blank tokens per encoder frame.
42+
private let maxSymbolsPerStep = 2 // Matches reference
4143

4244
/// Decode encoder output with persistent LSTM state.
43-
///
44-
/// - Parameters:
45-
/// - encoded: Encoder output MLMultiArray, shape `[1, encoderHidden, T]` (channels-first)
46-
/// - encodedLength: Number of valid encoder frames
47-
/// - h: LSTM hidden state (mutated in place)
48-
/// - c: LSTM cell state (mutated in place)
49-
/// - decoderOutput: Previous decoder output (mutated in place)
50-
/// - decoderProvider: Reusable feature provider for decoder
51-
/// - jointProvider: Reusable feature provider for joint
52-
/// - tokenArray: Pre-allocated token MLMultiArray [1, 1]
53-
/// - encSlice: Pre-allocated encoder slice [1, 1, encoderHidden]
54-
/// - argmaxBuf: Pre-allocated Float buffer for vDSP argmax
55-
/// - Returns: Decode result with tokens, log-probs, and EOU flag
5645
func decode(
5746
encoded: MLMultiArray,
5847
encodedLength: Int,
@@ -62,29 +51,25 @@ struct RNNTGreedyDecoder {
6251
decoderProvider: ReusableFeatureProvider,
6352
jointProvider: ReusableFeatureProvider,
6453
tokenArray: MLMultiArray,
65-
encSlice: MLMultiArray,
66-
argmaxBuf: UnsafeMutablePointer<Float>
54+
encSlice: MLMultiArray
6755
) throws -> RNNTDecodeResult {
6856
var tokens = [Int]()
6957
var tokenLogProbs = [Float]()
7058
var eouDetected = false
7159

7260
let tokenPtr = tokenArray.dataPointer.assumingMemoryBound(to: Int32.self)
73-
let totalClasses = config.vocabSize + 1 // vocab + blank
7461

7562
for t in 0..<encodedLength {
76-
// Extract encoder frame at position t — encoder output is [1, D, T] (channels-first)
63+
// Extract encoder frame t from [B, D, T] → [B, D, 1]
7764
copyEncoderFrame(from: encoded, at: t, to: encSlice)
7865

7966
for _ in 0..<maxSymbolsPerStep {
80-
// Joint network: (encoder_slice, decoder_output) → logits
81-
jointProvider.update("encoder_output", encSlice)
82-
jointProvider.update("decoder_output", decoderOutput)
67+
// Joint: (encoder_step [1,D,1], decoder_step [1,D,1]) → token_id
68+
jointProvider.update("encoder_step", encSlice)
69+
jointProvider.update("decoder_step", decoderOutput)
8370
let jointOut = try joint.prediction(from: jointProvider)
84-
let logits = jointOut.featureValue(for: "logits")!.multiArrayValue!
85-
86-
let tokenId = argmax(logits, count: totalClasses, floatBuf: argmaxBuf)
87-
71+
let tokenIdArray = jointOut.featureValue(for: "token_id")!.multiArrayValue!
72+
let tokenId = Int(tokenIdArray[0].int32Value)
8873

8974
if tokenId == config.blankTokenId {
9075
break // Advance to next encoder frame
@@ -97,15 +82,16 @@ struct RNNTGreedyDecoder {
9782

9883
// Emit token
9984
tokens.append(tokenId)
100-
let logProb = logSoftmax(logits, tokenId: tokenId, count: totalClasses, floatBuf: argmaxBuf)
101-
tokenLogProbs.append(logProb)
85+
// Get probability for confidence
86+
let probArray = jointOut.featureValue(for: "token_prob")!.multiArrayValue!
87+
tokenLogProbs.append(Float(truncating: probArray[0]))
10288

10389
// Update decoder LSTM with emitted token
10490
tokenPtr.pointee = Int32(tokenId)
105-
decoderProvider.update("h", h)
106-
decoderProvider.update("c", c)
91+
decoderProvider.update("h_in", h)
92+
decoderProvider.update("c_in", c)
10793
let decOut = try decoder.prediction(from: decoderProvider)
108-
decoderOutput = decOut.featureValue(for: "decoder_output")!.multiArrayValue!
94+
decoderOutput = decOut.featureValue(for: "decoder")!.multiArrayValue!
10995
h = decOut.featureValue(for: "h_out")!.multiArrayValue!
11096
c = decOut.featureValue(for: "c_out")!.multiArrayValue!
11197
}
@@ -118,43 +104,15 @@ struct RNNTGreedyDecoder {
118104

119105
// MARK: - Array Operations
120106

121-
/// Copy encoder frame at time `t` from [B, T, D] layout.
122-
/// Output slice is [1, 1, D] for joint network input.
107+
/// Copy encoder frame at time `t` from [B, D, T] layout to [B, D, 1].
123108
private func copyEncoderFrame(from encoded: MLMultiArray, at t: Int, to slice: MLMultiArray) {
124109
let hidden = config.encoderHidden
125-
// encoded is [1, T, D] — frame t is contiguous at offset t * D
126-
let src = encoded.dataPointer.advanced(by: t * hidden * MemoryLayout<Float16>.stride)
127-
memcpy(slice.dataPointer, src, hidden * MemoryLayout<Float16>.stride)
128-
}
129-
130-
private func logSoftmax(_ array: MLMultiArray, tokenId: Int, count: Int, floatBuf: UnsafeMutablePointer<Float>) -> Float {
131-
let ptr = array.dataPointer.assumingMemoryBound(to: Float16.self)
132-
for i in 0..<count { floatBuf[i] = Float(ptr[i]) }
133-
134-
var maxVal: Float = 0
135-
var maxIdx: vDSP_Length = 0
136-
vDSP_maxvi(floatBuf, 1, &maxVal, &maxIdx, vDSP_Length(count))
137-
138-
var negMax = -maxVal
139-
vDSP_vsadd(floatBuf, 1, &negMax, floatBuf, 1, vDSP_Length(count))
140-
141-
var n = Int32(count)
142-
vvexpf(floatBuf, floatBuf, &n)
143-
144-
var sumExp: Float = 0
145-
vDSP_sve(floatBuf, 1, &sumExp, vDSP_Length(count))
146-
147-
let logSumExp = log(sumExp) + maxVal
148-
let logit = Float(ptr[tokenId])
149-
return logit - logSumExp
150-
}
151-
152-
private func argmax(_ array: MLMultiArray, count: Int, floatBuf: UnsafeMutablePointer<Float>) -> Int {
153-
let ptr = array.dataPointer.assumingMemoryBound(to: Float16.self)
154-
for i in 0..<count { floatBuf[i] = Float(ptr[i]) }
155-
var maxVal: Float = 0
156-
var maxIdx: vDSP_Length = 0
157-
vDSP_maxvi(floatBuf, 1, &maxVal, &maxIdx, vDSP_Length(count))
158-
return Int(maxIdx)
110+
let totalFrames = encoded.shape[2].intValue
111+
// encoded is [1, D, T] float16 from CoreML — copy with stride
112+
let src = encoded.dataPointer.assumingMemoryBound(to: Float16.self)
113+
let dst = slice.dataPointer.assumingMemoryBound(to: Float.self)
114+
for d in 0..<hidden {
115+
dst[d] = Float(src[d * totalFrames + t])
116+
}
159117
}
160118
}

0 commit comments

Comments
 (0)