
Commit e8217ff

[Misc] add multi modal model fallback in generic model config parse (#589)
1 parent: 4e4208e

4 files changed: 404 additions & 22 deletions


pkg/hfutil/modelconfig/deepseek_vl.go

Lines changed: 0 additions & 21 deletions
@@ -200,27 +200,6 @@ func (c *DeepSeekVLConfig) GetArchitecture() string {
     return "DeepseekVLForCausalLM"
 }
 
-// Helper function to estimate MoE model parameters
-func estimateMoEParams(hiddenSize, numLayers, intermediateSize, moeIntermediateSize, nRoutedExperts, nSharedExperts, vocabSize int) int64 {
-    // Embeddings
-    params := int64(hiddenSize * vocabSize)
-
-    // For each layer
-    params += int64(numLayers) * (
-        // Self-attention
-        int64(4*hiddenSize*hiddenSize) +
-        // Shared experts
-        int64(nSharedExperts*2*hiddenSize*intermediateSize) +
-        // Routed experts
-        int64(nRoutedExperts*2*hiddenSize*moeIntermediateSize) +
-        // Router
-        int64(hiddenSize*nRoutedExperts) +
-        // Layer norms
-        int64(2*hiddenSize))
-
-    return params
-}
-
 // Register the DeepSeek VL model handlers
 func init() {
     RegisterModelLoader("deepseek_vl_v2", func(configPath string) (HuggingFaceModel, error) {

pkg/hfutil/modelconfig/interface.go

Lines changed: 162 additions & 1 deletion
@@ -228,6 +228,8 @@ var (
 // GenericModelConfig is a fallback configuration for unsupported model types.
 // It provides basic functionality by parsing common fields from the config.json
 // and attempting to get parameter count from safetensors files.
+// For multimodal models with nested configs (text_config, llm_config, language_config),
+// loadGenericModelConfig probes nested sub-configs to fill zero-valued fields.
 type GenericModelConfig struct {
     BaseModelConfig
 
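For context on what the new doc comment describes, here is the general shape of config.json this fallback is meant to cope with. The model and all field values below are hypothetical, purely illustrative, wrapped in a Go constant so it could sit alongside in-package tests:

package modelconfig

// exampleMultimodalConfigJSON is a hypothetical (not a real model) multimodal
// config.json: the top level carries no usable LLM fields, and the language
// model parameters live under "text_config" alongside a "vision_config".
const exampleMultimodalConfigJSON = `{
  "model_type": "some_vlm",
  "architectures": ["SomeVLMForConditionalGeneration"],
  "vision_config": { "hidden_size": 1024, "num_hidden_layers": 24 },
  "text_config": {
    "hidden_size": 4096,
    "num_hidden_layers": 32,
    "num_attention_heads": 32,
    "intermediate_size": 11008,
    "max_position_embeddings": 32768,
    "vocab_size": 128000
  }
}`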
@@ -237,10 +239,19 @@ type GenericModelConfig struct {
     NumAttentionHeads     int `json:"num_attention_heads"`
     IntermediateSize      int `json:"intermediate_size"`
     MaxPositionEmbeddings int `json:"max_position_embeddings"`
+    MaxSequenceLength     int `json:"max_sequence_length"`
     VocabSize             int `json:"vocab_size"`
 
+    // MoE fields (populated from top-level or nested config)
+    NRoutedExperts      int `json:"n_routed_experts"`
+    NSharedExperts      int `json:"n_shared_experts"`
+    MoeIntermediateSize int `json:"moe_intermediate_size"`
+
     // Quantization config (optional)
     QuantizationConfig *QuantizationConfig `json:"quantization_config,omitempty"`
+
+    // Set during loading when vision sub-config is detected
+    hasVisionConfig bool
 }
 
 // GetParameterCount attempts to get parameter count from safetensors, falls back to estimation
@@ -255,6 +266,10 @@ func (c *GenericModelConfig) GetParameterCount() int64 {
 
     // Fallback: estimate from architecture if we have the necessary fields
     if c.HiddenSize > 0 && c.NumHiddenLayers > 0 {
+        if c.NRoutedExperts > 0 {
+            return estimateMoEParams(c.HiddenSize, c.NumHiddenLayers, c.IntermediateSize,
+                c.MoeIntermediateSize, c.NRoutedExperts, c.NSharedExperts, c.VocabSize)
+        }
         return estimateGenericParams(c.HiddenSize, c.NumHiddenLayers, c.IntermediateSize, c.VocabSize)
     }
 
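Roughly what the MoE branch produces, as an in-package sketch (the package name modelconfig is assumed from the directory). The config numbers are hypothetical, and the expected range simply reflects the accounting in estimateMoEParams, added further down in this file:

package modelconfig

import "testing"

// Hypothetical MoE shape (not a real config.json): 27 layers, hidden_size 2048,
// 64 routed + 2 shared experts, moe_intermediate_size 1408, vocab 102400.
func TestEstimateMoEParamsRoughMagnitude(t *testing.T) {
	got := estimateMoEParams(2048, 27, 10944, 1408, 64, 2, 102400)
	// Embeddings ~0.21B plus 27 layers at ~0.476B each, roughly 13.05B parameters.
	if got < 12_000_000_000 || got > 14_000_000_000 {
		t.Fatalf("estimate outside expected ballpark: %d", got)
	}
}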
@@ -278,6 +293,139 @@ func estimateGenericParams(hiddenSize, numLayers, intermediateSize, vocabSize in
     return embeddingParams + totalLayerParams
 }
 
+// estimateMoEParams estimates parameter count for Mixture-of-Experts models.
+// It accounts for per-expert FFN weights, shared experts, and the router.
+func estimateMoEParams(hiddenSize, numLayers, intermediateSize, moeIntermediateSize, nRoutedExperts, nSharedExperts, vocabSize int) int64 {
+    if moeIntermediateSize == 0 {
+        moeIntermediateSize = intermediateSize
+    }
+
+    // Embeddings
+    params := int64(hiddenSize * vocabSize)
+
+    // For each layer
+    params += int64(numLayers) * (
+        // Self-attention
+        int64(4*hiddenSize*hiddenSize) +
+        // Shared experts
+        int64(nSharedExperts*2*hiddenSize*intermediateSize) +
+        // Routed experts
+        int64(nRoutedExperts*2*hiddenSize*moeIntermediateSize) +
+        // Router
+        int64(hiddenSize*nRoutedExperts) +
+        // Layer norms
+        int64(2*hiddenSize))
+
+    return params
+}
+
+// nestedLLMConfigKeys lists the JSON keys under which multimodal models
+// commonly store their language/LLM sub-configuration.
+var nestedLLMConfigKeys = []string{"text_config", "llm_config", "language_config"}
+
+// probeNestedConfig attempts to fill zero-valued fields in config from
+// nested sub-configurations commonly found in multimodal model configs.
+// It only fills fields that are zero/empty, so it is safe to call unconditionally.
+func probeNestedConfig(data []byte, config *GenericModelConfig) {
+    var raw map[string]json.RawMessage
+    if err := json.Unmarshal(data, &raw); err != nil {
+        return
+    }
+
+    // Probe for nested LLM/text config
+    for _, key := range nestedLLMConfigKeys {
+        sub, ok := raw[key]
+        if !ok {
+            continue
+        }
+        var nested struct {
+            HiddenSize            int    `json:"hidden_size"`
+            NumHiddenLayers       int    `json:"num_hidden_layers"`
+            NumAttentionHeads     int    `json:"num_attention_heads"`
+            IntermediateSize      int    `json:"intermediate_size"`
+            MaxPositionEmbeddings int    `json:"max_position_embeddings"`
+            VocabSize             int    `json:"vocab_size"`
+            TransformersVersion   string `json:"transformers_version"`
+            TorchDtype            string `json:"torch_dtype"`
+            // MoE fields — try all common JSON key names via multiple fields
+            NRoutedExperts      int `json:"n_routed_experts"`
+            NumLocalExperts     int `json:"num_local_experts"`
+            NumExperts          int `json:"num_experts"`
+            NSharedExperts      int `json:"n_shared_experts"`
+            MoeIntermediateSize int `json:"moe_intermediate_size"`
+        }
+        if err := json.Unmarshal(sub, &nested); err != nil {
+            continue
+        }
+
+        // Fill zero-valued fields from nested config
+        if config.HiddenSize == 0 {
+            config.HiddenSize = nested.HiddenSize
+        }
+        if config.NumHiddenLayers == 0 {
+            config.NumHiddenLayers = nested.NumHiddenLayers
+        }
+        if config.NumAttentionHeads == 0 {
+            config.NumAttentionHeads = nested.NumAttentionHeads
+        }
+        if config.IntermediateSize == 0 {
+            config.IntermediateSize = nested.IntermediateSize
+        }
+        if config.MaxPositionEmbeddings == 0 {
+            config.MaxPositionEmbeddings = nested.MaxPositionEmbeddings
+        }
+        if config.VocabSize == 0 {
+            config.VocabSize = nested.VocabSize
+        }
+        if config.TransformerVersion == "" {
+            config.TransformerVersion = nested.TransformersVersion
+        }
+        if config.TorchDtype == "" {
+            config.TorchDtype = nested.TorchDtype
+        }
+
+        // MoE fields — resolve the different JSON key names
+        if config.NRoutedExperts == 0 {
+            if nested.NRoutedExperts > 0 {
+                config.NRoutedExperts = nested.NRoutedExperts
+            } else if nested.NumLocalExperts > 0 {
+                config.NRoutedExperts = nested.NumLocalExperts
+            } else if nested.NumExperts > 0 {
+                config.NRoutedExperts = nested.NumExperts
+            }
+        }
+        if config.NSharedExperts == 0 {
+            config.NSharedExperts = nested.NSharedExperts
+        }
+        if config.MoeIntermediateSize == 0 {
+            config.MoeIntermediateSize = nested.MoeIntermediateSize
+        }
+
+        break // Use the first matching nested config
+    }
+
+    // Resolve top-level MoE field name variants (num_local_experts, num_experts)
+    // that don't match the GenericModelConfig JSON tags
+    if config.NRoutedExperts == 0 {
+        var topLevel struct {
+            NumLocalExperts int `json:"num_local_experts"`
+            NumExperts      int `json:"num_experts"`
+        }
+        if err := json.Unmarshal(data, &topLevel); err == nil {
+            if topLevel.NumLocalExperts > 0 {
+                config.NRoutedExperts = topLevel.NumLocalExperts
+            } else if topLevel.NumExperts > 0 {
+                config.NRoutedExperts = topLevel.NumExperts
+            }
+        }
+    }
+
+    // Detect vision config presence
+    if _, ok := raw["vision_config"]; ok {
+        config.hasVisionConfig = true
+    }
+}
+
 func (c *GenericModelConfig) GetQuantizationType() string {
     if c.QuantizationConfig != nil && c.QuantizationConfig.QuantMethod != "" {
         return c.QuantizationConfig.QuantMethod
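A sketch of how the probing behaves, written as an in-package test (package name modelconfig assumed from the directory). The config below is hypothetical and mainly exercises the num_local_experts key-name resolution and the vision_config detection:

package modelconfig

import "testing"

// Hypothetical multimodal config: no usable top-level LLM fields, expert count
// published as "num_local_experts" rather than "n_routed_experts".
const hypotheticalVLMConfig = `{
  "vision_config": { "hidden_size": 1024 },
  "text_config": {
    "hidden_size": 4096,
    "num_hidden_layers": 32,
    "intermediate_size": 14336,
    "max_position_embeddings": 32768,
    "vocab_size": 128000,
    "num_local_experts": 8,
    "moe_intermediate_size": 14336
  }
}`

func TestProbeNestedConfigFillsZeroFields(t *testing.T) {
	var cfg GenericModelConfig
	probeNestedConfig([]byte(hypotheticalVLMConfig), &cfg)

	if cfg.HiddenSize != 4096 || cfg.NumHiddenLayers != 32 {
		t.Fatalf("text_config fields were not picked up: %+v", cfg)
	}
	if cfg.NRoutedExperts != 8 {
		t.Fatalf("num_local_experts was not resolved to NRoutedExperts: %d", cfg.NRoutedExperts)
	}
	if !cfg.HasVision() {
		t.Fatal("vision_config presence was not detected")
	}
}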
@@ -286,7 +434,15 @@ func (c *GenericModelConfig) GetQuantizationType() string {
 }
 
 func (c *GenericModelConfig) GetContextLength() int {
-    return c.MaxPositionEmbeddings
+    if c.MaxPositionEmbeddings > 0 {
+        return c.MaxPositionEmbeddings
+    }
+    return c.MaxSequenceLength
+}
+
+// HasVision returns true if a vision sub-config was detected during loading
+func (c *GenericModelConfig) HasVision() bool {
+    return c.hasVisionConfig
 }
 
 func (c *GenericModelConfig) GetModelSizeBytes() int64 {
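And a small in-package sketch of the new context-length behavior, with hypothetical values: max_position_embeddings still wins when present, max_sequence_length is only a fallback.

package modelconfig

import "testing"

func TestGetContextLengthFallback(t *testing.T) {
	// Hypothetical config that only publishes max_sequence_length.
	c := GenericModelConfig{MaxSequenceLength: 8192}
	if got := c.GetContextLength(); got != 8192 {
		t.Fatalf("expected fallback to MaxSequenceLength, got %d", got)
	}

	// max_position_embeddings takes precedence once it is set.
	c.MaxPositionEmbeddings = 4096
	if got := c.GetContextLength(); got != 4096 {
		t.Fatalf("expected MaxPositionEmbeddings to win, got %d", got)
	}
}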
@@ -462,6 +618,11 @@ func loadGenericModelConfig(configPath string) (HuggingFaceModel, error) {
     }
 
     config.ConfigPath = configPath
+
+    // Probe nested sub-configs (text_config, llm_config, language_config)
+    // to fill zero-valued fields for multimodal models
+    probeNestedConfig(data, &config)
+
     return &config, nil
 }
 
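End to end, the generic loader now yields usable values for such a model. A hedged in-package sketch: the config contents are hypothetical, and it assumes loadGenericModelConfig only needs a readable config.json path, as the hunk above suggests.

package modelconfig

import (
	"os"
	"path/filepath"
	"testing"
)

func TestLoadGenericModelConfigMultimodal(t *testing.T) {
	// Hypothetical multimodal config.json with the LLM fields nested under text_config.
	cfgJSON := `{
	  "model_type": "some_vlm",
	  "vision_config": { "hidden_size": 1024 },
	  "text_config": {
	    "hidden_size": 4096,
	    "num_hidden_layers": 32,
	    "max_position_embeddings": 32768,
	    "vocab_size": 128000
	  }
	}`
	path := filepath.Join(t.TempDir(), "config.json")
	if err := os.WriteFile(path, []byte(cfgJSON), 0o644); err != nil {
		t.Fatal(err)
	}

	model, err := loadGenericModelConfig(path)
	if err != nil {
		t.Fatal(err)
	}
	cfg, ok := model.(*GenericModelConfig)
	if !ok {
		t.Fatalf("expected *GenericModelConfig, got %T", model)
	}
	if cfg.GetContextLength() != 32768 || !cfg.HasVision() {
		t.Fatalf("nested probing did not populate the config: %+v", cfg)
	}
}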