44# DeepSpeed Team
55
66from collections import defaultdict
7- from typing import List , Tuple
7+ from typing import Dict , List , Tuple
88
99import torch
1010from torch .fx import GraphModule
1111
1212import deepspeed .comm as dist
1313from deepspeed .accelerator import get_accelerator
14+ from deepspeed .utils import log_dist
1415
1516from ..util import get_deepcompile_handle
1617from ..graph_param import DSGraphParamManager
1920
2021max_alloc_mem = 0
2122last_optimize_step = 0
23+ MEM_MARGIN = 0.1
24+
25+
def print_rank_0(message):
    """Log *message* only on global rank 0, via deepspeed's ``log_dist``."""
    log_dist(message, ranks=[0])
28+
29+
30+ def _compute_persistence_budget (all_graph_mem_records : List [List [Tuple [str , int , int , int ]]], total_mem : int ,
31+ mem_margin : float ) -> Dict [str , int ]:
32+ usable_mem = int (total_mem * (1 - mem_margin ))
33+ non_empty_records = [mem_records for mem_records in all_graph_mem_records if mem_records ]
34+
35+ if not non_empty_records :
36+ return {
37+ "usable_mem" : usable_mem ,
38+ "peak_resident_alloc" : 0 ,
39+ "transient_peak" : 0 ,
40+ "available_mem" : 0 ,
41+ "profiled_list_count" : 0 ,
42+ }
43+
44+ # Persistent parameters add to live allocations that remain resident past an op boundary.
45+ peak_resident_alloc = max (record [1 ] for mem_records in non_empty_records for record in mem_records )
46+ transient_peak = max (record [3 ] for mem_records in non_empty_records for record in mem_records )
47+
48+ return {
49+ "usable_mem" : usable_mem ,
50+ "peak_resident_alloc" : peak_resident_alloc ,
51+ "transient_peak" : transient_peak ,
52+ "available_mem" : max (0 , usable_mem - peak_resident_alloc ),
53+ "profiled_list_count" : len (non_empty_records ),
54+ }
2255
2356
2457def selective_gather (gm : GraphModule , graph_id : int , graph_order : List [Tuple [int , bool ]], profiling_results ,
2558 create_inputs_fn , mem_budget : float , param_manager : DSGraphParamManager ,
2659 bwd : bool ) -> GraphModule :
60+ target_graph_id = graph_id
2761
2862 if not bwd :
2963 return gm
@@ -38,19 +72,21 @@ def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int
3872 if last_backward_graph_id is None or graph_id != last_backward_graph_id :
3973 return gm
4074
41- peak_mem = 0
42- for graph_id , prof in profiling_results .items ():
43- # Use peak memory
44- fwd_max_mem = max (m [3 ] for m in prof .fwd_mem )
45- bwd_max_mem = max (m [3 ] for m in prof .bwd_mem ) if len (prof .bwd_mem ) > 0 else 0
46- peak_mem = max (peak_mem , fwd_max_mem , bwd_max_mem )
47- if dist .get_rank () == 0 :
48- print (
49- f"selective_gather graph_id={ graph_id } max_mem={ peak_mem } fwd_max_mem={ fwd_max_mem } bwd_max_mem={ bwd_max_mem } "
50- )
75+ all_graph_mem_records = []
76+ for profile_graph_id , prof in profiling_results .items ():
77+ all_graph_mem_records .extend ([prof .fwd_mem , prof .bwd_mem ])
78+
79+ fwd_peak_resident = max ((m [1 ] for m in prof .fwd_mem ), default = 0 )
80+ fwd_transient_peak = max ((m [3 ] for m in prof .fwd_mem ), default = 0 )
81+ bwd_peak_resident = max ((m [1 ] for m in prof .bwd_mem ), default = 0 )
82+ bwd_transient_peak = max ((m [3 ] for m in prof .bwd_mem ), default = 0 )
83+
84+ print_rank_0 (f"selective_gather graph_id={ profile_graph_id } "
85+ f"fwd_peak_resident={ fwd_peak_resident } fwd_transient_peak={ fwd_transient_peak } "
86+ f"bwd_peak_resident={ bwd_peak_resident } bwd_transient_peak={ bwd_transient_peak } " )
5187
5288 persistent_ds_ids = set ()
53- for graph_id , pm in param_manager .items ():
89+ for param_graph_id , pm in param_manager .items ():
5490 for name , ds_param in pm .params .items ():
5591 if ds_param .param .ds_persist :
5692 persistent_ds_ids .add (pm .ds_ids [name ])
@@ -60,13 +96,13 @@ def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int
6096 ds_id_to_prof_dtime = defaultdict (float )
6197 ds_id_to_prof_wtime = defaultdict (float )
6298
63- for graph_id , pm in param_manager .items ():
99+ for param_graph_id , pm in param_manager .items ():
64100 params = pm .params
65101 for param_name , param in params .items ():
66102 ds_id = pm .ds_ids [param_name ]
67103 ds_id_to_size [ds_id ] = param .numel * param .dtype .itemsize
68104
69- profile = profiling_results [graph_id ]
105+ profile = profiling_results [param_graph_id ]
70106 for n in profile .fwd_graph .nodes :
71107 if n .target == torch .ops .dc .allgather_param .default :
72108 assert "tensor_size" in n .meta
@@ -100,39 +136,68 @@ def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int
100136 # f"ds_id={ds_id} time_per_size={ds_id_to_time[ds_id] / ds_id_to_size[ds_id]:.5f} dtime={dtime_in_sec:.3f} wtime={wtime_in_sec:.3f} size={size_in_mb:.2f}MB bw={size_in_mb/dtime_in_sec:.2f}MB/s"
101137 # )
102138
103- sorted_ds_ids = {ds_id : ds_id_to_size [ds_id ] for ds_id in ds_ids }
104-
105139 accelerator = get_accelerator ()
106140 total_mem = accelerator .total_memory ()
107- vals_to_bcast = torch .tensor ([total_mem ], device = torch .device (get_accelerator ().current_device ()))
141+ current_available_mem = accelerator .available_memory ()
142+ vals_to_bcast = torch .tensor ([total_mem , current_available_mem ],
143+ device = torch .device (get_accelerator ().current_device ()))
108144 dist .all_reduce (vals_to_bcast , dist .ReduceOp .MIN )
109145 total_mem = vals_to_bcast [0 ].item ()
146+ current_available_mem = vals_to_bcast [1 ].item ()
110147
111- MEM_MARGIN = 0.1
112- available_mem = total_mem * (1 - MEM_MARGIN ) - peak_mem
113-
114- if dist .get_rank () == 0 :
115- print (
116- f"selective_gather max_mem={ peak_mem } total_mem={ total_mem } MEM_MARGIN={ MEM_MARGIN } available_mem={ available_mem } "
117- )
148+ budget = _compute_persistence_budget (all_graph_mem_records , total_mem , MEM_MARGIN )
149+ available_mem = int (current_available_mem * (1 - MEM_MARGIN ))
118150
119151 ds_id_to_param = {}
120152 for g_id , g_pm in param_manager .items ():
121153 for name , ds_param in g_pm .params .items ():
122154 ds_id_to_param [g_pm .ds_ids [name ]] = ds_param .param
123155
156+ candidate_bytes = sum (ds_id_to_size [ds_id ] for ds_id in ds_ids )
157+ persistent_bytes = sum (ds_id_to_size .get (ds_id , 0 ) for ds_id in persistent_ds_ids )
158+
159+ print_rank_0 (
160+ f"selective_gather target_graph_id={ target_graph_id } profiled_mem_lists={ budget ['profiled_list_count' ]} "
161+ f"total_mem={ total_mem } usable_mem={ budget ['usable_mem' ]} peak_resident_alloc={ budget ['peak_resident_alloc' ]} "
162+ f"transient_peak={ budget ['transient_peak' ]} current_available_mem={ current_available_mem } "
163+ f"usable_available_mem={ available_mem } "
164+ f"persistent_count={ len (persistent_ds_ids )} persistent_bytes={ persistent_bytes } "
165+ f"candidate_count={ len (ds_ids )} candidate_bytes={ candidate_bytes } " )
166+
167+ if budget ["profiled_list_count" ] == 0 :
168+ print_rank_0 ("selective_gather no profiling data; skipping persistence update" )
169+ return gm
170+
171+ if len (ds_ids ) == 0 :
172+ print_rank_0 ("selective_gather no candidates to persist" )
173+ return gm
174+
175+ if available_mem == 0 :
176+ print_rank_0 ("selective_gather no currently available memory for new persistent params" )
177+ return gm
178+
124179 persistent_mem = 0
180+ selected_count = 0
125181 nz3 = get_deepcompile_handle ()
126- for ds_id , size in sorted_ds_ids .items ():
182+ for ds_id in ds_ids :
183+ size = ds_id_to_size [ds_id ]
127184 if persistent_mem + size > available_mem :
128185 break
129186 persistent_mem += size
187+ selected_count += 1
130188
131189 param_obj = ds_id_to_param [ds_id ]
132190
133191 nz3 .set_persistent (ds_id )
134- if dist .get_rank () == 0 :
135- print (f"Set persistent: { ds_id } size: { size } persistent_mem: { persistent_mem } shape: { param_obj .ds_shape } " )
192+ print_rank_0 (
193+ f"Set persistent: { ds_id } size: { size } persistent_mem: { persistent_mem } shape: { param_obj .ds_shape } " )
194+
195+ if selected_count == 0 :
196+ smallest_candidate = min (ds_id_to_size [ds_id ] for ds_id in ds_ids )
197+ print_rank_0 (f"selective_gather selected no new params: available_mem={ available_mem } "
198+ f"smallest_candidate={ smallest_candidate } " )
199+ else :
200+ print_rank_0 (f"selective_gather selected_count={ selected_count } selected_bytes={ persistent_mem } " )
136201
137202 return gm
138203
0 commit comments