Commit 1c95ade
committed
Fix bf16 dtype mismatch in ZeRO-3 with zero_quantized_weights
When using ZeRO-3 with zero_quantized_weights=True and bf16 enabled,
the dequantized weights were incorrectly cast to fp16 instead of
preserving the original bf16 dtype. This caused RuntimeError during
training with BERT and similar models.
The fix adds original_dtype tracking to AllGatherCoalescedHandle,
mirroring the existing pattern in AllGatherHandle, to ensure weights
are converted back to their original dtype after dequantization.
Fixes #7775
Signed-off-by: juyterman1000 <[email protected]>1 parent 491a38c commit 1c95ade
1 file changed
Lines changed: 12 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
713 | 713 | | |
714 | 714 | | |
715 | 715 | | |
| 716 | + | |
716 | 717 | | |
717 | 718 | | |
718 | 719 | | |
| |||
721 | 722 | | |
722 | 723 | | |
723 | 724 | | |
| 725 | + | |
724 | 726 | | |
725 | 727 | | |
726 | 728 | | |
| |||
735 | 737 | | |
736 | 738 | | |
737 | 739 | | |
738 | | - | |
739 | | - | |
| 740 | + | |
| 741 | + | |
| 742 | + | |
| 743 | + | |
| 744 | + | |
| 745 | + | |
| 746 | + | |
740 | 747 | | |
741 | 748 | | |
742 | 749 | | |
| |||
1469 | 1476 | | |
1470 | 1477 | | |
1471 | 1478 | | |
| 1479 | + | |
| 1480 | + | |
1472 | 1481 | | |
1473 | 1482 | | |
1474 | 1483 | | |
1475 | 1484 | | |
1476 | 1485 | | |
1477 | 1486 | | |
1478 | 1487 | | |
| 1488 | + | |
1479 | 1489 | | |
1480 | 1490 | | |
1481 | 1491 | | |
| |||
0 commit comments