24 changes: 14 additions & 10 deletions examples/qwen2/pretrain_qwen.py
@@ -41,6 +41,7 @@
 
 from megatron_patch.data.utils import get_batch_on_this_tp_rank_original, get_batch_on_this_tp_rank_idxmap_sft
 from megatron_patch.model.qwen2.layer_specs import (
+    get_gpt_decoder_block_spec,
     get_gpt_layer_local_spec,
     get_gpt_layer_with_transformer_engine_spec,
 )
@@ -62,17 +63,20 @@ def model_provider(
 
     config = core_transformer_config_from_args(args, Qwen2TransformerConfig)
     use_te = args.transformer_impl == "transformer_engine"
 
-    if use_te:
-        print_rank_0("building qwen2 model in TE...")
-        transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
-            args.num_experts, args.moe_grouped_gemm, args.qk_layernorm
-        )
+    if args.num_experts:
+        # Define the decoder block spec
+        transformer_layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=use_te)
     else:
-        print_rank_0("building qwen2 model in Mcore...")
-        transformer_layer_spec = get_gpt_layer_local_spec(
-            args.num_experts, args.moe_grouped_gemm, args.qk_layernorm
-        )
+        if use_te:
+            print_rank_0("building qwen2 model in TE...")
+            transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
+                args.num_experts, args.moe_grouped_gemm, args.qk_layernorm
+            )
+        else:
+            print_rank_0("building qwen2 model in Mcore...")
+            transformer_layer_spec = get_gpt_layer_local_spec(
+                args.num_experts, args.moe_grouped_gemm, args.qk_layernorm
+            )
 
     model = GPTModel(
         config=config,
37 changes: 35 additions & 2 deletions megatron_patch/arguments.py
@@ -19,6 +19,33 @@
 import torch.nn.functional as F
 from megatron.core.transformer import TransformerConfig
 
+def moe_freq_type(x):
+    """Frequency between MoE layers and dense layers.
+
+    Accepts either:
+    - An integer N: represents a 1:N ratio, i.e. one expert layer for every N-1 dense layers.
+    - A string "N": same as above, but provided as a string.
+    - A string containing a Python list expression that defines a custom pattern, e.g.
+      "([1]*3+[0]*1)*3" evaluates to [1,1,1,0,1,1,1,0,1,1,1,0],
+      where 1 indicates an expert layer and 0 indicates a dense layer.
+      This allows defining arbitrary patterns of expert and dense layers.
+      The pattern length must match the total number of transformer layers.
+      Examples:
+          "([0]+[1]*23)": 1 dense layer followed by 23 expert layers.
+          "([1]*3+[0]*2)*2": three expert layers followed by two dense layers, repeated twice.
+    """
+    if isinstance(x, int):
+        return x
+    assert isinstance(x, str)
+    if '[' in x:
+        # A custom pattern expressed as a Python list expression.
+        pattern = eval(x)
+        return pattern
+    else:
+        # A single int provided as a string.
+        return int(x)
+
+
 def core_transformer_config_from_args(args, config_class=None):
     # Config class.
     config_class = config_class or TransformerConfig
@@ -493,8 +520,14 @@ def get_patch_args(parser):
     group.add_argument("--qk-nope-head-dim", type=int, default=None)
     group.add_argument("--qk-rope-head-dim", type=int, default=None)
     group.add_argument("--num-shared-experts", type=int, default=None)
-    group.add_argument("--moe-layer-freq", type=int, default=1)
-
+    group.add_argument('--moe-layer-freq', type=moe_freq_type, default=1,
+                       help='Frequency between MoE layers and dense layers. Accepts either: '
+                            '- An integer N: represents a 1:N ratio, i.e. one expert layer for every N-1 dense layers. '
+                            '- A string containing a Python list expression that defines a custom pattern, e.g. '
+                            '"([1]*3+[0]*1)*3" evaluates to [1,1,1,0,1,1,1,0,1,1,1,0], '
+                            'where 1 indicates an expert layer and 0 indicates a dense layer. '
+                            'Examples: "([0]+[1]*23)": 1 dense layer followed by 23 expert layers; '
+                            '"([1]*3+[0]*2)*2": three expert layers followed by two dense layers, repeated twice.')
     patch_if_not_exist(
         group,
         "--rotary-scaling-factor", type=int, default=1
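A minimal check of the new parser, assuming megatron_patch.arguments is importable on PYTHONPATH; the pattern strings are the ones used in the docstring above:

    # moe_freq_type turns the CLI value into either an int ratio or an explicit 0/1 layer pattern.
    from megatron_patch.arguments import moe_freq_type

    assert moe_freq_type(2) == 2      # plain 1:2 ratio passed as an int
    assert moe_freq_type("2") == 2    # the same ratio passed as a string
    # A Python list expression is evaluated into an explicit expert/dense pattern.
    assert moe_freq_type("([1]*3+[0]*1)*3") == [1, 1, 1, 0] * 3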
83 changes: 83 additions & 0 deletions megatron_patch/model/qwen2/layer_specs.py
@@ -35,8 +35,10 @@
 from megatron.core.transformer.identity_op import IdentityOp
 
 from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.transformer_config import TransformerConfig
 
 
+from .transformer_block import get_num_layers_to_build, TransformerBlockSubmodules
 from .transformer.mlp import MLP, MLPSubmodules
 from .transformer.attention import SelfAttention, SelfAttentionSubmodules
 from .moe.moe_layer import MoELayer
Expand Down Expand Up @@ -138,3 +140,84 @@ def _get_mlp_module_spec(
else None
),
)


def get_gpt_decoder_block_spec(
config: TransformerConfig, use_transformer_engine: bool
) -> TransformerBlockSubmodules:
"""GPT block spec."""
if use_transformer_engine:
layer_norm_impl = TENorm
else:
layer_norm_impl = LNImpl

# Layer specs.
dense_layer_spec = (
get_gpt_layer_with_transformer_engine_spec(
num_experts=None,
moe_grouped_gemm=False,
qk_layernorm=config.qk_layernorm,
)
if use_transformer_engine
else get_gpt_layer_local_spec(
num_experts=None,
moe_grouped_gemm=False,
qk_layernorm=config.qk_layernorm,
)
)
moe_layer_spec = (
get_gpt_layer_with_transformer_engine_spec(
num_experts=config.num_moe_experts,
moe_grouped_gemm=config.moe_grouped_gemm,
qk_layernorm=config.qk_layernorm,
)
if use_transformer_engine
else get_gpt_layer_local_spec(
num_experts=config.num_moe_experts,
moe_grouped_gemm=config.moe_grouped_gemm,
qk_layernorm=config.qk_layernorm,
)
)

+    # Parse config.moe_layer_freq to determine the pattern of expert/dense layers.
+    # 0 stands for a dense layer, 1 stands for an expert layer.
+    # For an integer N: build a pattern with one expert layer every N layers.
+    # For a list (already parsed from a pattern string such as "[1,0,1]"): use it directly.
+    if isinstance(config.moe_layer_freq, int):
+        moe_layer_pattern = [
+            1 if i % config.moe_layer_freq == 0 else 0 for i in range(config.num_layers)
+        ]
+    elif isinstance(config.moe_layer_freq, list):
+        moe_layer_pattern = config.moe_layer_freq
+        assert len(moe_layer_pattern) == config.num_layers, (
+            f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, "
+            f"expected {config.num_layers}, "
+            f"current moe layer pattern: {config.moe_layer_freq}"
+        )
+    else:
+        raise ValueError(
+            f"Invalid moe_layer_freq: {type(config.moe_layer_freq)}, {config.moe_layer_freq}"
+        )
+
+    # Create the layer specs for the model.
+    layer_specs = []
+    for layer_number in range(config.num_layers):
+        if moe_layer_pattern[layer_number] == 1:
+            layer_specs.append(moe_layer_spec)
+        elif moe_layer_pattern[layer_number] == 0:
+            layer_specs.append(dense_layer_spec)
+        else:
+            raise ValueError(f"Invalid layer pattern: {moe_layer_pattern}")
+
+    # Slice the layer specs to only include the layers that are built in this pipeline stage.
+    # Note: MCore layer_number starts at 1.
+    # offset = TransformerLayer._get_layer_offset(config)
+    offset = TransformerLayer.get_layer_offset(config)
+    num_layers_to_build = get_num_layers_to_build(config)
+    layer_specs = layer_specs[offset : offset + num_layers_to_build]
+
+    # Block spec.
+    # block_spec = TransformerBlockSubmodules(layer_specs=layer_specs, layer_norm=layer_norm_impl)
+    block_spec = TransformerBlockSubmodules(layer_specs=layer_specs)
+
+    return block_spec
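A standalone sketch of the pattern expansion performed by get_gpt_decoder_block_spec, with the config replaced by illustrative values; expand_pattern is a hypothetical helper mirroring the logic above, not part of the patch:

    def expand_pattern(moe_layer_freq, num_layers):
        # Integer N: one expert layer every N layers (1 means every layer is MoE).
        if isinstance(moe_layer_freq, int):
            return [1 if i % moe_layer_freq == 0 else 0 for i in range(num_layers)]
        # Explicit list: must cover every transformer layer.
        assert isinstance(moe_layer_freq, list) and len(moe_layer_freq) == num_layers
        return moe_layer_freq

    print(expand_pattern(1, 8))              # [1, 1, 1, 1, 1, 1, 1, 1] -> every layer is MoE
    print(expand_pattern(2, 8))              # [1, 0, 1, 0, 1, 0, 1, 0]
    print(expand_pattern([0] + [1] * 7, 8))  # one dense layer followed by seven MoE layers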
8 changes: 7 additions & 1 deletion megatron_patch/model/qwen2/transformer_block.py
@@ -446,12 +446,18 @@ def sharded_state_dict(
         non_homogeneous_layers = metadata is not None and metadata.get(
             'non_homogeneous_layers', False
         )
+        if isinstance(self.config.moe_layer_freq, int):
+            if self.config.moe_layer_freq > 1:
+                non_homogeneous_layers = True
+        elif isinstance(self.config.moe_layer_freq, list):
+            non_homogeneous_layers = True
+
         sharded_state_dict = {}
 
         layer_prefix = f'{prefix}layers.'
         num_layers = self.config.num_layers
         for layer in self.layers:
-            offset = layer._get_layer_offset()
+            offset = TransformerLayer.get_layer_offset(self.config)
 
             global_layer_offset = layer.layer_number - 1  # self.layer_number starts at 1
             state_dict_prefix = f'{layer_prefix}{global_layer_offset - offset}.'  # module list index in TransformerBlock
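The checkpointing change above can be summarised in isolation; needs_non_homogeneous_layers below is a hypothetical stand-in for the inline logic, not an API of the patch:

    def needs_non_homogeneous_layers(moe_layer_freq):
        # A mixed dense/MoE stack means layers are not interchangeable in the sharded state dict.
        if isinstance(moe_layer_freq, int):
            return moe_layer_freq > 1
        return isinstance(moe_layer_freq, list)

    assert needs_non_homogeneous_layers(1) is False          # homogeneous stack
    assert needs_non_homogeneous_layers(2) is True           # alternating dense/MoE layers
    assert needs_non_homogeneous_layers([0, 1, 1]) is True   # explicit per-layer pattern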
29 changes: 29 additions & 0 deletions megatron_patch/model/qwen2/transformer_layer.py
@@ -165,6 +165,35 @@ def _get_layer_offset(self):
                 offset = 0
 
         return offset
+
+
+    @staticmethod
+    def get_layer_offset(config: TransformerConfig):
+        """Compute the global layer offset for this pipeline (and virtual pipeline) rank."""
+        pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
+
+        num_layers_per_pipeline_rank = (
+            config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
+        )
+
+        if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
+            vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
+            vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
+
+            total_num_layers = config.num_layers
+            num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
+            total_virtual_chunks = total_num_layers // vp_size
+            offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank)
+
+        else:
+            # Each stage gets a contiguous set of layers.
+            if parallel_state.get_pipeline_model_parallel_world_size() > 1:
+                offset = pipeline_rank * num_layers_per_pipeline_rank
+            else:
+                offset = 0
+
+        return offset
 
 
     def forward(
         self,
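The offset arithmetic in get_layer_offset can be checked with a small mirror that takes the parallel sizes as plain arguments; layer_offset is an illustrative helper and the values below are examples, not defaults:

    def layer_offset(num_layers, pp_rank, pp_size, vp_rank=None, vp_size=None):
        layers_per_pp_rank = num_layers // pp_size
        if vp_size is not None:
            layers_per_virtual_rank = layers_per_pp_rank // vp_size
            total_virtual_chunks = num_layers // vp_size
            return vp_rank * total_virtual_chunks + pp_rank * layers_per_virtual_rank
        return pp_rank * layers_per_pp_rank if pp_size > 1 else 0

    # 24 layers over 4 pipeline stages: rank 2 starts at global layer index 12.
    assert layer_offset(24, pp_rank=2, pp_size=4) == 12
    # With 2 virtual chunks per stage: vp_rank 1, pp_rank 2 starts at 1*12 + 2*3 = 18.
    assert layer_offset(24, pp_rank=2, pp_size=4, vp_rank=1, vp_size=2) == 18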
@@ -128,6 +128,7 @@ SLIDING_WINDOW=131072
 EXTRA_VOCAB_SIZE=293
 
 moe_options=" \
+                    --moe-layer-freq 1 \
                     --moe-router-topk ${NUM_EXPERTS_PER_TOPK} \
                     --num-experts ${NUM_EXPERTS} \
                     --target-expert-model-parallel-size ${EP}\