Skip to content

Commit 0a2aa16

Browse files
committed
Use model-managed pre_cache input/output, exact 64 mel frames
- Encoder now has separate pre_cache [1,128,9] input and new_pre_cache output — model handles concatenation internally
- Remove manual mel prepending and cache rotation code
- Chunk samples = (melFrames-1) * hopLength = 10080 → exactly 64 mel frames
- All discrepancies with reference implementation resolved: symmetric Hann window, zero padding, centered FFT, FFT/4 scaling, float32 decoder/joint, separate pre_cache, 64 mel frames
1 parent d550d8d commit 0a2aa16

1 file changed

Lines changed: 17 additions & 69 deletions

File tree

Sources/ParakeetStreamingASR/StreamingSession.swift

Lines changed: 17 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@ public class StreamingSession {
1717
private let rnntDecoder: RNNTGreedyDecoder
1818

1919
// Encoder cache state
20+
private var preCache: MLMultiArray // [1, 128, preCacheSize] — looped back from encoder
2021
private var cacheLastChannel: MLMultiArray
2122
private var cacheLastTime: MLMultiArray
2223
private var cacheLastChannelLen: MLMultiArray
23-
// Pre-encode mel cache: last preCacheSize mel frames from previous chunk
24-
private var preEncodeMelCache: [Float]
2524

2625
// Decoder LSTM state
2726
private var h: MLMultiArray
@@ -64,11 +63,12 @@ public class StreamingSession {
6463
let hidden = config.encoderHidden
6564
let attCtx = config.attentionContext
6665
let convCache = config.convCacheSize
66+
let preCacheSize = config.streaming.preCacheSize
6767

68-
// Pre-encode mel cache: zeros for first chunk
69-
preEncodeMelCache = [Float](repeating: 0,
70-
count: config.numMelBins * config.streaming.preCacheSize)
71-
68+
preCache = try MLMultiArray(
69+
shape: [1, config.numMelBins as NSNumber, preCacheSize as NSNumber], dataType: .float32)
70+
memset(preCache.dataPointer, 0,
71+
config.numMelBins * preCacheSize * MemoryLayout<Float>.stride)
7272
cacheLastChannel = try MLMultiArray(
7373
shape: [layers, 1, attCtx, hidden] as [NSNumber], dataType: .float32)
7474
cacheLastTime = try MLMultiArray(
@@ -124,7 +124,7 @@ public class StreamingSession {
124124

125125
sampleBuffer.append(contentsOf: samples)
126126

127-
let samplesPerChunk = config.streaming.melFrames * config.hopLength
127+
let samplesPerChunk = (config.streaming.melFrames - 1) * config.hopLength
128128
var results: [ParakeetStreamingASRModel.PartialTranscript] = []
129129

130130
while sampleBuffer.count >= samplesPerChunk {
@@ -147,7 +147,7 @@ public class StreamingSession {
147147
// Process remaining buffered samples
148148
if !sampleBuffer.isEmpty && !eouDetected {
149149
// Pad to full chunk size
150-
let samplesPerChunk = config.streaming.melFrames * config.hopLength
150+
let samplesPerChunk = (config.streaming.melFrames - 1) * config.hopLength
151151
let padded = sampleBuffer + [Float](repeating: 0, count: max(0, samplesPerChunk - sampleBuffer.count))
152152
sampleBuffer.removeAll()
153153
if let partial = try processChunk(Array(padded.prefix(samplesPerChunk))) {
@@ -195,30 +195,23 @@ public class StreamingSession {
195195
}
196196
guard melLength > 0 else { return nil }
197197

198-
// Truncate/pad chunk mel to exact expected frame count
198+
// Truncate/pad mel to exact expected frame count
199199
let expectedFrames = config.streaming.melFrames
200200
let actualMelFrames = rawMel.shape[2].intValue
201-
let chunkMel: MLMultiArray
201+
let mel: MLMultiArray
202202
if actualMelFrames > expectedFrames {
203-
chunkMel = try truncateMel(rawMel, to: expectedFrames)
203+
mel = try truncateMel(rawMel, to: expectedFrames)
204204
} else if actualMelFrames < expectedFrames {
205-
chunkMel = try padMel(rawMel, actualLength: actualMelFrames, targetLength: expectedFrames)
205+
mel = try padMel(rawMel, actualLength: actualMelFrames, targetLength: expectedFrames)
206206
} else {
207-
chunkMel = rawMel
207+
mel = rawMel
208208
}
209209

210-
// Prepend pre-encode mel cache to chunk mel for encoder input
211-
let preCacheSize = config.streaming.preCacheSize
212-
let totalFrames = preCacheSize + expectedFrames
213-
let mel = try prependMelCache(chunkMel, expectedFrames: expectedFrames, totalFrames: totalFrames)
214-
215-
// Save last preCacheSize frames of chunk mel for next iteration
216-
savePreEncodeMelCache(from: chunkMel, frames: expectedFrames)
217-
218-
// Run cache-aware encoder
210+
// Run encoder — pre_cache is a separate input, model concatenates internally
219211
let encoderInput = try MLDictionaryFeatureProvider(dictionary: [
220212
"audio_signal": MLFeatureValue(multiArray: mel),
221213
"audio_length": MLFeatureValue(multiArray: makeInt32Array(value: Int32(expectedFrames))),
214+
"pre_cache": MLFeatureValue(multiArray: preCache),
222215
"cache_last_channel": MLFeatureValue(multiArray: cacheLastChannel),
223216
"cache_last_time": MLFeatureValue(multiArray: cacheLastTime),
224217
"cache_last_channel_len": MLFeatureValue(multiArray: cacheLastChannelLen),
@@ -232,7 +225,8 @@ public class StreamingSession {
232225
let actualFrames = encoded.shape[1].intValue
233226
let encodedLength = min(reportedLength, actualFrames)
234227

235-
// Update encoder caches
228+
// Update encoder caches (including pre_cache loopback)
229+
preCache = encoderOutput.featureValue(for: "new_pre_cache")!.multiArrayValue!
236230
cacheLastChannel = encoderOutput.featureValue(for: "new_cache_last_channel")!.multiArrayValue!
237231
cacheLastTime = encoderOutput.featureValue(for: "new_cache_last_time")!.multiArrayValue!
238232
cacheLastChannelLen = encoderOutput.featureValue(for: "new_cache_last_channel_len")!.multiArrayValue!
@@ -318,52 +312,6 @@ public class StreamingSession {
318312

319313
// MARK: - Pre-encode Mel Cache
320314

321-
/// Prepend pre-encode mel cache to chunk mel, creating [1, 128, totalFrames].
322-
private func prependMelCache(_ chunkMel: MLMultiArray, expectedFrames: Int, totalFrames: Int) throws -> MLMultiArray {
323-
let numBins = config.numMelBins
324-
let preCacheSize = config.streaming.preCacheSize
325-
let mel = try MLMultiArray(
326-
shape: [1, numBins as NSNumber, totalFrames as NSNumber], dataType: .float32)
327-
let dst = mel.dataPointer.assumingMemoryBound(to: Float.self)
328-
let src = chunkMel.dataPointer.assumingMemoryBound(to: Float.self)
329-
330-
for bin in 0..<numBins {
331-
let dstOffset = bin * totalFrames
332-
let cacheOffset = bin * preCacheSize
333-
let srcOffset = bin * expectedFrames
334-
335-
// Copy pre-encode cache (preCacheSize frames)
336-
memcpy(dst.advanced(by: dstOffset),
337-
preEncodeMelCache.withUnsafeBufferPointer { $0.baseAddress! }.advanced(by: cacheOffset),
338-
preCacheSize * MemoryLayout<Float>.stride)
339-
340-
// Copy chunk mel (expectedFrames frames)
341-
memcpy(dst.advanced(by: dstOffset + preCacheSize),
342-
src.advanced(by: srcOffset),
343-
expectedFrames * MemoryLayout<Float>.stride)
344-
}
345-
return mel
346-
}
347-
348-
/// Save last preCacheSize frames of chunk mel for next iteration.
349-
private func savePreEncodeMelCache(from chunkMel: MLMultiArray, frames: Int) {
350-
let numBins = config.numMelBins
351-
let preCacheSize = config.streaming.preCacheSize
352-
let src = chunkMel.dataPointer.assumingMemoryBound(to: Float.self)
353-
let startFrame = max(0, frames - preCacheSize)
354-
let copyFrames = min(preCacheSize, frames)
355-
356-
for bin in 0..<numBins {
357-
let srcOffset = bin * frames + startFrame
358-
let dstOffset = bin * preCacheSize + (preCacheSize - copyFrames)
359-
preEncodeMelCache.withUnsafeMutableBufferPointer { buf in
360-
memcpy(buf.baseAddress!.advanced(by: dstOffset),
361-
src.advanced(by: srcOffset),
362-
copyFrames * MemoryLayout<Float>.stride)
363-
}
364-
}
365-
}
366-
367315
private func makeInt32Array(value: Int32) throws -> MLMultiArray {
368316
let array = try MLMultiArray(shape: [1], dataType: .int32)
369317
array[0] = NSNumber(value: value)

0 commit comments

Comments
 (0)