Skip to content

Commit a1ece5c

Browse files
committed
fix(quant): extend GPU quantization restriction to CUDA backend
1 parent 0792aac commit a1ece5c

2 files changed

Lines changed: 25 additions & 11 deletions

File tree

crates/infer-deepseek/src/model/mod.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,19 @@ impl ImageProjector {
240240
let quant = QuantizationState::global();
241241
let config = quant.config();
242242
let mut qmatmul: Option<std::sync::Arc<QMatMul>> = None;
243+
// GPU fast-fail: disallow runtime quantization on Metal/CUDA for projector as well.
244+
if (weight.device().is_metal() || weight.device().is_cuda())
245+
&& config.kind.is_enabled()
246+
&& quant.enabled_for(LinearLayerGroup::Projector)
247+
{
248+
anyhow::bail!(
249+
"GPU backend: runtime quantization is disabled on Metal/CUDA. Refusing to fallback.\n\
250+
Disable quantization (DEEPSEEK_OCR_QUANT=none) or run on CPU.\n\
251+
Context: module=projector, in_dim={}, backend={}",
252+
input_dim,
253+
crate::quantization::backend_label(&weight.device())
254+
);
255+
}
243256
if quant.enabled_for(LinearLayerGroup::Projector) {
244257
match config.kind {
245258
QuantizationKind::Q8_0 => {

crates/infer-deepseek/src/transformer/weights.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -444,18 +444,19 @@ fn maybe_quantize_linear(
444444
) -> Result<Option<Arc<QMatMul>>> {
445445
let quant = QuantizationState::global();
446446
let config = quant.config();
447-
// Disable runtime quantization entirely on Metal to avoid MPS kernel issues.
448-
if weight.device().is_metal() {
449-
tracing::trace!(
450-
tensor = tensor_name,
451-
?group,
452-
action = "fallback",
453-
reason = "metal_disabled",
454-
backend = crate::quantization::backend_label(&weight.device()),
455-
"quant-linear"
447+
// GPU fast-fail: if quant is requested for this group on Metal/CUDA, error out (awaiting upstream kernel fixes).
448+
if (weight.device().is_metal() || weight.device().is_cuda())
449+
&& config.kind.is_enabled()
450+
&& quant.enabled_for(group)
451+
{
452+
anyhow::bail!(
453+
"GPU backend: runtime quantization is disabled on Metal/CUDA. Refusing to fallback.\n\
454+
Disable quantization (DEEPSEEK_OCR_QUANT=none) or run on CPU.\n\
455+
Context: tensor={}, group={:?}, backend={}",
456+
tensor_name,
457+
group,
458+
crate::quantization::backend_label(&weight.device())
456459
);
457-
quant.record_attempt(module, QuantizationOutcome::Fallback);
458-
return Ok(None);
459460
}
460461
if !quant.enabled_for(group) {
461462
trace!(

0 commit comments

Comments
 (0)