Commit 11cb3ec

Merge branch 'master' into tohtana/aws-test-torch-version-input
2 parents e6ac6cd + ecb26a5

20 files changed: 557 additions & 74 deletions

csrc/compile/deepcompile.cpp

Lines changed: 5 additions & 3 deletions
@@ -99,12 +99,12 @@ at::Tensor reduce_grad(at::Tensor grad_tensor, long graph_id, long ds_id)

     if (sync_after_reduce) { c10::cuda::device_synchronize(); }

-    return at::Tensor();
+    return torch::empty({0}, grad_tensor.options());
 }

 at::Tensor reduce_grad_meta(at::Tensor grad_tensor, long graph_id, long ds_id)
 {
-    return at::Tensor();
+    return torch::empty({0}, grad_tensor.options());
 }

 void free_tensors(std::vector<at::Tensor> tensors)
@@ -179,10 +179,12 @@ void start_backward(bool update)
     for (auto& it : executors) { it.second->startBackward(update); }
 }

-void end_backward(long graph_id)
+void end_backward(const c10::IValue& deps, long graph_id)
 {
     auto executor = getExecutor<CustomOpExecutor>(graph_id, executors);
     executor->endBackward();
 }

+void end_backward_meta(const c10::IValue& deps, long graph_id) {}
+
 } // namespace dc
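
Why the return-value change matters: a default-constructed at::Tensor is an undefined tensor that carries no dtype or device for the compile stack to propagate, while torch::empty({0}, grad_tensor.options()) is a real zero-element tensor with full metadata. A minimal Python sketch of the same sentinel pattern (the function name is illustrative, not DeepSpeed's API):

    import torch

    def reduce_grad_sketch(grad: torch.Tensor) -> torch.Tensor:
        # ... the real op launches the gradient reduction here ...
        # A 0-element result still carries dtype/device, so fake-tensor and
        # meta-device propagation can reason about the op's output.
        return torch.empty(0, dtype=grad.dtype, device=grad.device)

    g = torch.randn(4, dtype=torch.float32)
    out = reduce_grad_sketch(g)
    assert out.numel() == 0 and out.dtype == g.dtype

The new end_backward_meta mirrors this: the meta kernel must exist so tracing on meta tensors sees the same schema as eager execution, even though it does nothing.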

csrc/compile/init.cpp

Lines changed: 6 additions & 3 deletions
@@ -24,7 +24,7 @@ TORCH_LIBRARY(dc, m)
     m.def("wait_reload(Tensor a, int id, int id) -> Tensor");
     m.def("offload_parameter(Tensor a, int id, int id) -> ()");
     m.def("reload_parameter(Tensor a, int id, int id) -> ()");
-    m.def("end_backward(int graph_id) -> ()");
+    m.def("end_backward(Any deps, int graph_id) -> ()");

     m.def("test_call(Tensor a) -> Tensor");
 }
@@ -43,6 +43,7 @@ TORCH_LIBRARY_IMPL(dc, CPU, m)
     m.impl("wait_reload", &dc::wait_reload);
     m.impl("offload_parameter", &dc::offload_parameter);
     m.impl("reload_parameter", &dc::reload_parameter);
+    m.impl("end_backward", &dc::end_backward);

     m.impl("test_call", &dc::test_call);
 }
@@ -61,6 +62,7 @@ TORCH_LIBRARY_IMPL(dc, CUDA, m)
     m.impl("wait_reload", &dc::wait_reload);
     m.impl("offload_parameter", &dc::offload_parameter);
     m.impl("reload_parameter", &dc::reload_parameter);
+    m.impl("end_backward", &dc::end_backward);

     m.impl("test_call", &dc::test_call);
 }
@@ -75,10 +77,11 @@ TORCH_LIBRARY_IMPL(dc, Meta, m)
     m.impl("free_tensors", &dc::free_tensors_meta);
     m.impl("reload_parameter", &dc::reload_parameter_meta);
     m.impl("offload_parameter", &dc::offload_parameter_meta);
+    m.impl("end_backward", &dc::end_backward_meta);
 }

-// The "Undefined" dispatch key is for operations whose arguments do not contain
-// a tensor.
+// end_backward may be invoked with dependency placeholders that have already
+// become None, in which case the dispatcher sees no tensor arguments.
 TORCH_LIBRARY_IMPL(dc, Undefined, m) { m.impl("end_backward", &dc::end_backward); }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
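
The registration pattern above binds one schema to per-dispatch-key kernels: real kernels for CPU and CUDA, a side-effect-free kernel for Meta (used while tracing), and an extra Undefined-key binding for calls whose deps have all decayed to None. A rough Python mirror using torch.library under a made-up "dc_sketch" namespace; the real schema's "Any deps" is simplified to "Tensor[]" here, and the Undefined-key registration has no counterpart in this sketch:

    import torch

    lib = torch.library.Library("dc_sketch", "DEF")
    lib.define("end_backward(Tensor[] deps, int graph_id) -> ()")

    def end_backward_impl(deps, graph_id):
        pass  # the real kernel flushes per-graph backward bookkeeping

    def end_backward_meta(deps, graph_id):
        pass  # meta kernel: no side effects, only satisfies tracing

    lib.impl("end_backward", end_backward_impl, "CPU")
    lib.impl("end_backward", end_backward_impl, "CUDA")
    lib.impl("end_backward", end_backward_meta, "Meta")

    torch.ops.dc_sketch.end_backward([torch.zeros(1)], 0)  # dispatches to CPU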

csrc/compile/z3.h

Lines changed: 2 additions & 1 deletion
@@ -53,5 +53,6 @@ void reload_parameter(at::Tensor tensor, long graph_id, long id);
 void offload_parameter(at::Tensor tensor, long graph_id, long id);
 void reload_parameter_meta(at::Tensor tensor, long graph_id, long id);
 void offload_parameter_meta(at::Tensor tensor, long graph_id, long id);
-void end_backward(long graph_id);
+void end_backward(const c10::IValue& deps, long graph_id);
+void end_backward_meta(const c10::IValue& deps, long graph_id);
 } // namespace dc

deepspeed/compile/fx.py

Lines changed: 18 additions & 0 deletions
@@ -8,6 +8,7 @@

 import torch
 from torch.fx import Node, Graph, GraphModule
+from torch.fx.node import map_aggregate

 from .util import get_last_uses

@@ -19,6 +20,23 @@ def get_output_node(graph: Graph):
     raise ValueError("No output node found")


+def add_end_backward(graph: Graph, graph_id: int):
+    reduce_nodes = [n for n in graph.nodes if n.target == torch.ops.dc.reduce_grad.default]
+    if len(reduce_nodes) == 0:
+        return
+
+    with graph.inserting_before(get_output_node(graph)):
+        graph.create_node("call_function", torch.ops.dc.end_backward.default, (reduce_nodes, graph_id))
+
+
+def replace_reduce_outputs_with_none(graph: Graph):
+    output_node = get_output_node(graph)
+    new_outputs = map_aggregate(
+        output_node.args[0], lambda n: None
+        if isinstance(n, Node) and n.target == torch.ops.dc.reduce_grad.default else n)
+    output_node.args = (new_outputs, )
+
+
 def move_primals_to_head(graph: Graph):

     # Move primals to the head of the graph
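
add_end_backward threads the reduce_grad nodes into end_backward's arguments, making the dependency explicit in the graph rather than implicit in execution order, and replace_reduce_outputs_with_none drops the now-unneeded reduce results from the graph outputs. The same two moves can be tried on a toy FX graph, with torch.relu standing in for torch.ops.dc.reduce_grad and a local marker function for end_backward (both stand-ins are hypothetical):

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.node import map_aggregate

    def marker(deps, graph_id):  # stand-in for torch.ops.dc.end_backward
        return None

    def fn(x, y):
        return torch.relu(x), torch.relu(y)

    gm = symbolic_trace(fn)
    graph = gm.graph
    output_node = next(n for n in graph.nodes if n.op == "output")
    relu_nodes = [n for n in graph.nodes if n.target == torch.relu]

    # Passing the nodes as an argument makes the sink depend on them, so
    # dead-code elimination and scheduling cannot drop or reorder them.
    with graph.inserting_before(output_node):
        graph.call_function(marker, (relu_nodes, 0))

    # Mirror replace_reduce_outputs_with_none: the results are consumed by the
    # sink, so the graph no longer needs to return them.
    new_outputs = map_aggregate(output_node.args[0],
                                lambda n: None if n in relu_nodes else n)
    output_node.args = (new_outputs, )
    graph.lint()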

deepspeed/compile/inductor.py

Lines changed: 1 addition & 0 deletions
@@ -212,6 +212,7 @@ def register_fallback_no_reuse(op_overload,
                                never_reuse_output=True,
                                force_free_input=True)
     register_fallback_no_reuse(torch.ops.dc.free_tensors.default, never_reuse_input=True, never_reuse_output=True)
+    register_fallback_no_reuse(torch.ops.dc.end_backward.default, never_reuse_input=True, never_reuse_output=False)

     if not hasattr(Scheduler, "is_dc_patched") or not Scheduler.is_dc_patched:
         Scheduler.is_dc_patched = True

deepspeed/compile/passes/selective_gather.py

Lines changed: 92 additions & 27 deletions
@@ -4,13 +4,14 @@
 # DeepSpeed Team

 from collections import defaultdict
-from typing import List, Tuple
+from typing import Dict, List, Tuple

 import torch
 from torch.fx import GraphModule

 import deepspeed.comm as dist
 from deepspeed.accelerator import get_accelerator
+from deepspeed.utils import log_dist

 from ..util import get_deepcompile_handle
 from ..graph_param import DSGraphParamManager
@@ -19,11 +20,44 @@

 max_alloc_mem = 0
 last_optimize_step = 0
+MEM_MARGIN = 0.1
+
+
+def print_rank_0(message):
+    log_dist(message, ranks=[0])
+
+
+def _compute_persistence_budget(all_graph_mem_records: List[List[Tuple[str, int, int, int]]], total_mem: int,
+                                mem_margin: float) -> Dict[str, int]:
+    usable_mem = int(total_mem * (1 - mem_margin))
+    non_empty_records = [mem_records for mem_records in all_graph_mem_records if mem_records]
+
+    if not non_empty_records:
+        return {
+            "usable_mem": usable_mem,
+            "peak_resident_alloc": 0,
+            "transient_peak": 0,
+            "available_mem": 0,
+            "profiled_list_count": 0,
+        }
+
+    # Persistent parameters add to live allocations that remain resident past an op boundary.
+    peak_resident_alloc = max(record[1] for mem_records in non_empty_records for record in mem_records)
+    transient_peak = max(record[3] for mem_records in non_empty_records for record in mem_records)
+
+    return {
+        "usable_mem": usable_mem,
+        "peak_resident_alloc": peak_resident_alloc,
+        "transient_peak": transient_peak,
+        "available_mem": max(0, usable_mem - peak_resident_alloc),
+        "profiled_list_count": len(non_empty_records),
+    }


 def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
                      create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager,
                      bwd: bool) -> GraphModule:
+    target_graph_id = graph_id

     if not bwd:
         return gm
@@ -38,19 +72,21 @@ def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int
     if last_backward_graph_id is None or graph_id != last_backward_graph_id:
         return gm

-    peak_mem = 0
-    for graph_id, prof in profiling_results.items():
-        # Use peak memory
-        fwd_max_mem = max(m[3] for m in prof.fwd_mem)
-        bwd_max_mem = max(m[3] for m in prof.bwd_mem) if len(prof.bwd_mem) > 0 else 0
-        peak_mem = max(peak_mem, fwd_max_mem, bwd_max_mem)
-        if dist.get_rank() == 0:
-            print(
-                f"selective_gather graph_id={graph_id} max_mem={peak_mem} fwd_max_mem={fwd_max_mem} bwd_max_mem={bwd_max_mem}"
-            )
+    all_graph_mem_records = []
+    for profile_graph_id, prof in profiling_results.items():
+        all_graph_mem_records.extend([prof.fwd_mem, prof.bwd_mem])
+
+        fwd_peak_resident = max((m[1] for m in prof.fwd_mem), default=0)
+        fwd_transient_peak = max((m[3] for m in prof.fwd_mem), default=0)
+        bwd_peak_resident = max((m[1] for m in prof.bwd_mem), default=0)
+        bwd_transient_peak = max((m[3] for m in prof.bwd_mem), default=0)
+
+        print_rank_0(f"selective_gather graph_id={profile_graph_id} "
+                     f"fwd_peak_resident={fwd_peak_resident} fwd_transient_peak={fwd_transient_peak} "
+                     f"bwd_peak_resident={bwd_peak_resident} bwd_transient_peak={bwd_transient_peak}")

     persistent_ds_ids = set()
-    for graph_id, pm in param_manager.items():
+    for param_graph_id, pm in param_manager.items():
         for name, ds_param in pm.params.items():
             if ds_param.param.ds_persist:
                 persistent_ds_ids.add(pm.ds_ids[name])
@@ -60,13 +96,13 @@ def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int
     ds_id_to_prof_dtime = defaultdict(float)
     ds_id_to_prof_wtime = defaultdict(float)

-    for graph_id, pm in param_manager.items():
+    for param_graph_id, pm in param_manager.items():
         params = pm.params
         for param_name, param in params.items():
             ds_id = pm.ds_ids[param_name]
             ds_id_to_size[ds_id] = param.numel * param.dtype.itemsize

-        profile = profiling_results[graph_id]
+        profile = profiling_results[param_graph_id]
         for n in profile.fwd_graph.nodes:
             if n.target == torch.ops.dc.allgather_param.default:
                 assert "tensor_size" in n.meta
@@ -100,39 +136,68 @@ def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int
     #     f"ds_id={ds_id} time_per_size={ds_id_to_time[ds_id] / ds_id_to_size[ds_id]:.5f} dtime={dtime_in_sec:.3f} wtime={wtime_in_sec:.3f} size={size_in_mb:.2f}MB bw={size_in_mb/dtime_in_sec:.2f}MB/s"
     # )

-    sorted_ds_ids = {ds_id: ds_id_to_size[ds_id] for ds_id in ds_ids}
-
     accelerator = get_accelerator()
     total_mem = accelerator.total_memory()
-    vals_to_bcast = torch.tensor([total_mem], device=torch.device(get_accelerator().current_device()))
+    current_available_mem = accelerator.available_memory()
+    vals_to_bcast = torch.tensor([total_mem, current_available_mem],
+                                 device=torch.device(get_accelerator().current_device()))
     dist.all_reduce(vals_to_bcast, dist.ReduceOp.MIN)
     total_mem = vals_to_bcast[0].item()
+    current_available_mem = vals_to_bcast[1].item()

-    MEM_MARGIN = 0.1
-    available_mem = total_mem * (1 - MEM_MARGIN) - peak_mem
-
-    if dist.get_rank() == 0:
-        print(
-            f"selective_gather max_mem={peak_mem} total_mem={total_mem} MEM_MARGIN={MEM_MARGIN} available_mem={available_mem}"
-        )
+    budget = _compute_persistence_budget(all_graph_mem_records, total_mem, MEM_MARGIN)
+    available_mem = int(current_available_mem * (1 - MEM_MARGIN))

     ds_id_to_param = {}
     for g_id, g_pm in param_manager.items():
         for name, ds_param in g_pm.params.items():
             ds_id_to_param[g_pm.ds_ids[name]] = ds_param.param

+    candidate_bytes = sum(ds_id_to_size[ds_id] for ds_id in ds_ids)
+    persistent_bytes = sum(ds_id_to_size.get(ds_id, 0) for ds_id in persistent_ds_ids)
+
+    print_rank_0(
+        f"selective_gather target_graph_id={target_graph_id} profiled_mem_lists={budget['profiled_list_count']} "
+        f"total_mem={total_mem} usable_mem={budget['usable_mem']} peak_resident_alloc={budget['peak_resident_alloc']} "
+        f"transient_peak={budget['transient_peak']} current_available_mem={current_available_mem} "
+        f"usable_available_mem={available_mem} "
+        f"persistent_count={len(persistent_ds_ids)} persistent_bytes={persistent_bytes} "
+        f"candidate_count={len(ds_ids)} candidate_bytes={candidate_bytes}")
+
+    if budget["profiled_list_count"] == 0:
+        print_rank_0("selective_gather no profiling data; skipping persistence update")
+        return gm
+
+    if len(ds_ids) == 0:
+        print_rank_0("selective_gather no candidates to persist")
+        return gm
+
+    if available_mem == 0:
+        print_rank_0("selective_gather no currently available memory for new persistent params")
+        return gm
+
     persistent_mem = 0
+    selected_count = 0
     nz3 = get_deepcompile_handle()
-    for ds_id, size in sorted_ds_ids.items():
+    for ds_id in ds_ids:
+        size = ds_id_to_size[ds_id]
         if persistent_mem + size > available_mem:
             break
         persistent_mem += size
+        selected_count += 1

         param_obj = ds_id_to_param[ds_id]

         nz3.set_persistent(ds_id)
-        if dist.get_rank() == 0:
-            print(f"Set persistent: {ds_id} size: {size} persistent_mem: {persistent_mem} shape: {param_obj.ds_shape}")
+        print_rank_0(
+            f"Set persistent: {ds_id} size: {size} persistent_mem: {persistent_mem} shape: {param_obj.ds_shape}")
+
+    if selected_count == 0:
+        smallest_candidate = min(ds_id_to_size[ds_id] for ds_id in ds_ids)
+        print_rank_0(f"selective_gather selected no new params: available_mem={available_mem} "
+                     f"smallest_candidate={smallest_candidate}")
+    else:
+        print_rank_0(f"selective_gather selected_count={selected_count} selected_bytes={persistent_mem}")

     return gm
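
The budget arithmetic is easier to follow with concrete numbers. A worked sketch with invented values, using the record layout implied by the indices the pass reads (index 1 for resident allocations, index 3 for the transient peak; index 2 is unused here):

    GiB = 1024**3
    MEM_MARGIN = 0.1
    total_mem = 16 * GiB
    records = [("some_op", 6 * GiB, 0, 9 * GiB)]  # resident=6 GiB, transient=9 GiB

    usable_mem = int(total_mem * (1 - MEM_MARGIN))            # ~14.4 GiB
    peak_resident_alloc = max(r[1] for r in records)          # 6 GiB
    available_mem = max(0, usable_mem - peak_resident_alloc)  # ~8.4 GiB

    # Greedy selection in candidate order, mirroring the loop over ds_ids:
    candidate_sizes = {10: 5 * GiB, 11: 4 * GiB, 12: 1 * GiB}
    persistent_mem, selected = 0, []
    for ds_id, size in candidate_sizes.items():
        if persistent_mem + size > available_mem:
            break  # the first oversized candidate stops the scan
        persistent_mem += size
        selected.append(ds_id)

    print(selected, persistent_mem // GiB)  # [10] 5

Note that the greedy loop stops at the first candidate that does not fit, so ds_id 12 (1 GiB) is never considered even though it would fit; this matches the break in the pass above.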

deepspeed/compile/passes/zero1_compile.py

Lines changed: 3 additions & 3 deletions
@@ -9,7 +9,7 @@
 from torch.fx import GraphModule

 from ..util import get_deepcompile_handle
-from ..fx import add_postprocess, move_primals_to_head, _make_node_meta, get_output_node
+from ..fx import add_postprocess, move_primals_to_head, _make_node_meta, add_end_backward, replace_reduce_outputs_with_none

 NAME = "zero1_compile"

@@ -50,8 +50,8 @@ def add_z1_reduce_bw(gm: GraphModule, graph_id: int, param_manager) -> GraphModule

     gm.graph = move_primals_to_head(graph)

-    with gm.graph.inserting_before(get_output_node(gm.graph)):
-        gm.graph.create_node("call_function", torch.ops.dc.end_backward.default, (graph_id, ))
+    add_end_backward(gm.graph, graph_id)
+    replace_reduce_outputs_with_none(gm.graph)

     return gm


deepspeed/compile/passes/zero3_compile.py

Lines changed: 3 additions & 3 deletions
@@ -11,7 +11,7 @@
 from torch.fx import Graph, Node, GraphModule

 from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses, is_cast_op
-from ..fx import add_postprocess, _make_node_meta, get_output_node, move_primals_to_head
+from ..fx import add_postprocess, _make_node_meta, get_output_node, move_primals_to_head, add_end_backward, replace_reduce_outputs_with_none
 from ..profilers.graph_profile import ProfilingInterpreter
 from ..list_schedule import fast_free_schedule

@@ -209,8 +209,8 @@ def add_z3_gather_release_bw(gm: GraphModule,
                              0,  # unused
                              debug_log=debug_log)

-    with gm.graph.inserting_before(get_output_node(gm.graph)):
-        gm.graph.create_node("call_function", torch.ops.dc.end_backward.default, (graph_id, ))
+    add_end_backward(gm.graph, graph_id)
+    replace_reduce_outputs_with_none(gm.graph)

     return gm

deepspeed/compile/patch_fake_tensor.py

Lines changed: 23 additions & 0 deletions
@@ -29,6 +29,17 @@ def wrap_if_ds_param(t):
     return t


+def _get_guard_sizes_strides(t):
+    if hasattr(t, "ds_id"):
+        # ZeRO-3 may temporarily all-gather a parameter during tracing, but the
+        # stable module state used by TorchDynamo guards is the released
+        # partitioned form, where DeepSpeed resets param.data to empty(0).
+        released = torch.empty(0, dtype=t.dtype, device=t.device)
+        return released.size(), released.stride()
+
+    return t.size(), t.stride()
+
+
 def patch_fake_tensor():
     # dynamo tracer uses wrap_to_fake_tensor_and_record
     # Wrapping FakeTensorMode.from_tensor is not sufficient as dynamo generates SymbolicContext before calling from_tensor
@@ -37,8 +48,20 @@ def patch_fake_tensor():
     def wrap_to_fake_tensor_and_record_wrapper(t, *args, **kwargs):
         dummy_tensor = wrap_if_ds_param(t)
         ret = original_wrap_to_fake_tensor_and_record(dummy_tensor, *args, **kwargs)
+        tx = kwargs.get("tx") if "tx" in kwargs else args[0]
+        source = kwargs.get("source")
         if tracing_context := torch._guards.TracingContext.try_get():
             tracing_context.tensor_to_context[t] = tracing_context.tensor_to_context.pop(dummy_tensor)
+            if source is not None:
+                # Keep the full ds_shape symbolic context from the dummy tensor, but
+                # use the stable released ZeRO-3 parameter representation for
+                # TorchDynamo's tensor-match guards. PyTorch 2.9 started enforcing
+                # those guards for parameters during build_guards().
+                size, stride = _get_guard_sizes_strides(t)
+                tx.output.input_source_to_sizes_strides[source] = {
+                    "size": size,
+                    "stride": stride,
+                }
         return ret

 torch._dynamo.variables.builder.wrap_to_fake_tensor_and_record = wrap_to_fake_tensor_and_record_wrapper
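
The guard shape logic can be checked without DeepSpeed: a released ZeRO-3 parameter has param.data reset to a zero-element tensor, so the stable guard shape is torch.Size([0]) with stride (1,). A standalone re-statement of _get_guard_sizes_strides for local experimentation (the ds_id attribute, which DeepSpeed stamps on partitioned parameters, is set by hand here):

    import torch

    def guard_sizes_strides(t):
        # Mirrors _get_guard_sizes_strides above.
        if hasattr(t, "ds_id"):
            released = torch.empty(0, dtype=t.dtype, device=t.device)
            return released.size(), released.stride()
        return t.size(), t.stride()

    p = torch.nn.Parameter(torch.randn(4, 4))
    p.ds_id = 0  # pretend this is a partitioned ZeRO-3 parameter
    print(guard_sizes_strides(p))                  # (torch.Size([0]), (1,))
    print(guard_sizes_strides(torch.randn(2, 3)))  # (torch.Size([2, 3]), (3, 1))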
