@@ -493,20 +493,6 @@ def _slice_embedding(self, child, name, conv_linear_layer):
493493 if getattr (child , "replaced" , False ) == True :
494494 return
495495
496- # When using partition_config (custom patterns), only partition embeddings if
497- # explicitly specified in layer_specs. This is consistent with how _replace()
498- # handles Linear layers - unmatched layers should not be automatically partitioned.
499- if self .partition_config is not None :
500- param_name = name + ".weight" if not name .endswith (".weight" ) else name
501- model_type = self ._get_model_type ()
502- spec = self .partition_config .find_matching_spec (param_name , model_type )
503- if spec is None :
504- # No pattern matched - skip partitioning this embedding
505- return child
506- if spec .partition_type == PartitionType .SKIP :
507- return child
508- # If explicitly specified, proceed with partitioning
509-
510496 mp_replace = ReplaceWithTensorSlicing (mp_group = self .mp_group )
511497
512498 if hasattr (child .weight , 'ds_tensor' ):
@@ -570,7 +556,30 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
570556 continue
571557 if len (child ._buffers ) != 0 and self .state_dict is not None :
572558 Loading .load_buffer (child , self .state_dict , checking_key )
573- if child .__class__ in self .linear_policies :
559+
560+ # When using partition_config (custom patterns/presets), use pattern-based routing
561+ # instead of linear_policies. This keeps all pattern logic centralized here.
562+ if self .partition_config is not None :
563+ full_name = prev_name + '.' + name if prev_name else name
564+ if isinstance (child , nn .Linear ):
565+ new_child = self ._replace_with_config (child , full_name )
566+ if new_child is not None :
567+ setattr (r_module , name , new_child )
568+ elif isinstance (child , nn .Embedding ):
569+ # Check if embedding matches any pattern
570+ param_name = full_name + ".weight"
571+ model_type = self ._get_model_type ()
572+ spec = self .partition_config .find_matching_spec (param_name , model_type )
573+ if spec is not None and spec .partition_type != PartitionType .SKIP :
574+ new_child = self ._slice_embedding (child , full_name , False )
575+ if new_child is not None :
576+ setattr (r_module , name , new_child )
577+ # If no pattern matched or skip, leave embedding unchanged
578+ else :
579+ self .update_mp_params (child )
580+ self ._replace_module (child , name , class_name )
581+ # Traditional path: use linear_policies for type-based routing
582+ elif child .__class__ in self .linear_policies :
574583 setattr (r_module , name , self .linear_policies [child .__class__ ](child , prev_name + '.' + name ,
575584 self .conv_linear_layer ))
576585 elif any (isinstance (child , lp ) for lp in self .linear_policies ):
0 commit comments