Skip to content

Commit d62eb69

Browse files
committed
exclude default patterns
Signed-off-by: Masahiro Tanaka <[email protected]>
1 parent 1b86e1a commit d62eb69

1 file changed

Lines changed: 24 additions & 15 deletions

File tree

deepspeed/module_inject/auto_tp.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -493,20 +493,6 @@ def _slice_embedding(self, child, name, conv_linear_layer):
493493
if getattr(child, "replaced", False) == True:
494494
return
495495

496-
# When using partition_config (custom patterns), only partition embeddings if
497-
# explicitly specified in layer_specs. This is consistent with how _replace()
498-
# handles Linear layers - unmatched layers should not be automatically partitioned.
499-
if self.partition_config is not None:
500-
param_name = name + ".weight" if not name.endswith(".weight") else name
501-
model_type = self._get_model_type()
502-
spec = self.partition_config.find_matching_spec(param_name, model_type)
503-
if spec is None:
504-
# No pattern matched - skip partitioning this embedding
505-
return child
506-
if spec.partition_type == PartitionType.SKIP:
507-
return child
508-
# If explicitly specified, proceed with partitioning
509-
510496
mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
511497

512498
if hasattr(child.weight, 'ds_tensor'):
@@ -570,7 +556,30 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
570556
continue
571557
if len(child._buffers) != 0 and self.state_dict is not None:
572558
Loading.load_buffer(child, self.state_dict, checking_key)
573-
if child.__class__ in self.linear_policies:
559+
560+
# When using partition_config (custom patterns/presets), use pattern-based routing
561+
# instead of linear_policies. This keeps all pattern logic centralized here.
562+
if self.partition_config is not None:
563+
full_name = prev_name + '.' + name if prev_name else name
564+
if isinstance(child, nn.Linear):
565+
new_child = self._replace_with_config(child, full_name)
566+
if new_child is not None:
567+
setattr(r_module, name, new_child)
568+
elif isinstance(child, nn.Embedding):
569+
# Check if embedding matches any pattern
570+
param_name = full_name + ".weight"
571+
model_type = self._get_model_type()
572+
spec = self.partition_config.find_matching_spec(param_name, model_type)
573+
if spec is not None and spec.partition_type != PartitionType.SKIP:
574+
new_child = self._slice_embedding(child, full_name, False)
575+
if new_child is not None:
576+
setattr(r_module, name, new_child)
577+
# If no pattern matched or skip, leave embedding unchanged
578+
else:
579+
self.update_mp_params(child)
580+
self._replace_module(child, name, class_name)
581+
# Traditional path: use linear_policies for type-based routing
582+
elif child.__class__ in self.linear_policies:
574583
setattr(r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
575584
self.conv_linear_layer))
576585
elif any(isinstance(child, lp) for lp in self.linear_policies):

0 commit comments

Comments
 (0)