@@ -17,11 +17,10 @@ public class StreamingSession {
1717 private let rnntDecoder : RNNTGreedyDecoder
1818
1919 // Encoder cache state
20+ private var preCache : MLMultiArray // [1, 128, preCacheSize] — looped back from encoder
2021 private var cacheLastChannel : MLMultiArray
2122 private var cacheLastTime : MLMultiArray
2223 private var cacheLastChannelLen : MLMultiArray
23- // Pre-encode mel cache: last preCacheSize mel frames from previous chunk
24- private var preEncodeMelCache : [ Float ]
2524
2625 // Decoder LSTM state
2726 private var h : MLMultiArray
@@ -64,11 +63,12 @@ public class StreamingSession {
6463 let hidden = config. encoderHidden
6564 let attCtx = config. attentionContext
6665 let convCache = config. convCacheSize
66+ let preCacheSize = config. streaming. preCacheSize
6767
68- // Pre-encode mel cache: zeros for first chunk
69- preEncodeMelCache = [ Float ] ( repeating : 0 ,
70- count : config . numMelBins * config . streaming . preCacheSize )
71-
68+ preCache = try MLMultiArray (
69+ shape : [ 1 , config . numMelBins as NSNumber , preCacheSize as NSNumber ] , dataType : . float32 )
70+ memset ( preCache . dataPointer , 0 ,
71+ config . numMelBins * preCacheSize * MemoryLayout < Float > . stride )
7272 cacheLastChannel = try MLMultiArray (
7373 shape: [ layers, 1 , attCtx, hidden] as [ NSNumber ] , dataType: . float32)
7474 cacheLastTime = try MLMultiArray (
@@ -124,7 +124,7 @@ public class StreamingSession {
124124
125125 sampleBuffer. append ( contentsOf: samples)
126126
127- let samplesPerChunk = config. streaming. melFrames * config. hopLength
127+ let samplesPerChunk = ( config. streaming. melFrames - 1 ) * config. hopLength
128128 var results : [ ParakeetStreamingASRModel . PartialTranscript ] = [ ]
129129
130130 while sampleBuffer. count >= samplesPerChunk {
@@ -147,7 +147,7 @@ public class StreamingSession {
147147 // Process remaining buffered samples
148148 if !sampleBuffer. isEmpty && !eouDetected {
149149 // Pad to full chunk size
150- let samplesPerChunk = config. streaming. melFrames * config. hopLength
150+ let samplesPerChunk = ( config. streaming. melFrames - 1 ) * config. hopLength
151151 let padded = sampleBuffer + [ Float] ( repeating: 0 , count: max ( 0 , samplesPerChunk - sampleBuffer. count) )
152152 sampleBuffer. removeAll ( )
153153 if let partial = try processChunk ( Array ( padded. prefix ( samplesPerChunk) ) ) {
@@ -195,30 +195,23 @@ public class StreamingSession {
195195 }
196196 guard melLength > 0 else { return nil }
197197
198- // Truncate/pad chunk mel to exact expected frame count
198+ // Truncate/pad mel to exact expected frame count
199199 let expectedFrames = config. streaming. melFrames
200200 let actualMelFrames = rawMel. shape [ 2 ] . intValue
201- let chunkMel : MLMultiArray
201+ let mel : MLMultiArray
202202 if actualMelFrames > expectedFrames {
203- chunkMel = try truncateMel ( rawMel, to: expectedFrames)
203+ mel = try truncateMel ( rawMel, to: expectedFrames)
204204 } else if actualMelFrames < expectedFrames {
205- chunkMel = try padMel ( rawMel, actualLength: actualMelFrames, targetLength: expectedFrames)
205+ mel = try padMel ( rawMel, actualLength: actualMelFrames, targetLength: expectedFrames)
206206 } else {
207- chunkMel = rawMel
207+ mel = rawMel
208208 }
209209
210- // Prepend pre-encode mel cache to chunk mel for encoder input
211- let preCacheSize = config. streaming. preCacheSize
212- let totalFrames = preCacheSize + expectedFrames
213- let mel = try prependMelCache ( chunkMel, expectedFrames: expectedFrames, totalFrames: totalFrames)
214-
215- // Save last preCacheSize frames of chunk mel for next iteration
216- savePreEncodeMelCache ( from: chunkMel, frames: expectedFrames)
217-
218- // Run cache-aware encoder
210+ // Run encoder — pre_cache is a separate input, model concatenates internally
219211 let encoderInput = try MLDictionaryFeatureProvider ( dictionary: [
220212 " audio_signal " : MLFeatureValue ( multiArray: mel) ,
221213 " audio_length " : MLFeatureValue ( multiArray: makeInt32Array ( value: Int32 ( expectedFrames) ) ) ,
214+ " pre_cache " : MLFeatureValue ( multiArray: preCache) ,
222215 " cache_last_channel " : MLFeatureValue ( multiArray: cacheLastChannel) ,
223216 " cache_last_time " : MLFeatureValue ( multiArray: cacheLastTime) ,
224217 " cache_last_channel_len " : MLFeatureValue ( multiArray: cacheLastChannelLen) ,
@@ -232,7 +225,8 @@ public class StreamingSession {
232225 let actualFrames = encoded. shape [ 1 ] . intValue
233226 let encodedLength = min ( reportedLength, actualFrames)
234227
235- // Update encoder caches
228+ // Update encoder caches (including pre_cache loopback)
229+ preCache = encoderOutput. featureValue ( for: " new_pre_cache " ) !. multiArrayValue!
236230 cacheLastChannel = encoderOutput. featureValue ( for: " new_cache_last_channel " ) !. multiArrayValue!
237231 cacheLastTime = encoderOutput. featureValue ( for: " new_cache_last_time " ) !. multiArrayValue!
238232 cacheLastChannelLen = encoderOutput. featureValue ( for: " new_cache_last_channel_len " ) !. multiArrayValue!
@@ -318,52 +312,6 @@ public class StreamingSession {
318312
319313 // MARK: - Helpers
320314
321- /// Prepend pre-encode mel cache to chunk mel, creating [1, 128, totalFrames].
322- private func prependMelCache( _ chunkMel: MLMultiArray , expectedFrames: Int , totalFrames: Int ) throws -> MLMultiArray {
323- let numBins = config. numMelBins
324- let preCacheSize = config. streaming. preCacheSize
325- let mel = try MLMultiArray (
326- shape: [ 1 , numBins as NSNumber , totalFrames as NSNumber ] , dataType: . float32)
327- let dst = mel. dataPointer. assumingMemoryBound ( to: Float . self)
328- let src = chunkMel. dataPointer. assumingMemoryBound ( to: Float . self)
329-
330- for bin in 0 ..< numBins {
331- let dstOffset = bin * totalFrames
332- let cacheOffset = bin * preCacheSize
333- let srcOffset = bin * expectedFrames
334-
335- // Copy pre-encode cache (preCacheSize frames)
336- memcpy ( dst. advanced ( by: dstOffset) ,
337- preEncodeMelCache. withUnsafeBufferPointer { $0. baseAddress! } . advanced ( by: cacheOffset) ,
338- preCacheSize * MemoryLayout< Float> . stride)
339-
340- // Copy chunk mel (expectedFrames frames)
341- memcpy ( dst. advanced ( by: dstOffset + preCacheSize) ,
342- src. advanced ( by: srcOffset) ,
343- expectedFrames * MemoryLayout< Float> . stride)
344- }
345- return mel
346- }
347-
348- /// Save last preCacheSize frames of chunk mel for next iteration.
349- private func savePreEncodeMelCache( from chunkMel: MLMultiArray , frames: Int ) {
350- let numBins = config. numMelBins
351- let preCacheSize = config. streaming. preCacheSize
352- let src = chunkMel. dataPointer. assumingMemoryBound ( to: Float . self)
353- let startFrame = max ( 0 , frames - preCacheSize)
354- let copyFrames = min ( preCacheSize, frames)
355-
356- for bin in 0 ..< numBins {
357- let srcOffset = bin * frames + startFrame
358- let dstOffset = bin * preCacheSize + ( preCacheSize - copyFrames)
359- preEncodeMelCache. withUnsafeMutableBufferPointer { buf in
360- memcpy ( buf. baseAddress!. advanced ( by: dstOffset) ,
361- src. advanced ( by: srcOffset) ,
362- copyFrames * MemoryLayout< Float> . stride)
363- }
364- }
365- }
366-
367315 private func makeInt32Array( value: Int32 ) throws -> MLMultiArray {
368316 let array = try MLMultiArray ( shape: [ 1 ] , dataType: . int32)
369317 array [ 0 ] = NSNumber ( value: value)
0 commit comments