Skip to content

Commit 31a96e7

Browse files
juntaoclaude
and committed
Cast BF16 weights to F32 on CPU for libtorch compatibility
libtorch CPU backend cannot do BF16 matmul. When running on CPU, cast all weights to F32 during loading. Signed-off-by: Michael Yuan <michael@secondstate.io> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 857d458 commit 31a96e7

1 file changed

Lines changed: 10 additions & 2 deletions

File tree

src/model/weights.rs

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@ use std::path::Path;
1212

1313
use crate::config::VoxtralConfig;
1414
use crate::error::{Result, VoxtralError};
15-
use crate::tensor::{Device, Tensor};
15+
use crate::tensor::{DType, Device, Tensor};
1616

1717
use super::backbone::{Backbone, BackboneConfig};
1818
use super::codec::Codec;
@@ -56,13 +56,21 @@ pub fn load_model_weights(
5656
model_dir.display()
5757
);
5858

59+
// On CPU, libtorch cannot do BF16 matmul — cast all weights to F32.
60+
let need_f32 = matches!(device, Device::Cpu);
61+
5962
// Load all tensors into a single HashMap
6063
let mut all_weights: HashMap<String, Tensor> = HashMap::new();
6164
for path in &safetensors_files {
6265
tracing::debug!("Loading {}", path.display());
6366
let tensors = Tensor::load_safetensors(path)?;
6467
for (name, tensor) in tensors {
65-
all_weights.insert(name, tensor.to_device(device));
68+
let tensor = if need_f32 {
69+
tensor.to_dtype(DType::Float32).to_device(device)
70+
} else {
71+
tensor.to_device(device)
72+
};
73+
all_weights.insert(name, tensor);
6674
}
6775
}
6876

0 commit comments

Comments (0)