dispatch.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Container
from copy import deepcopy
from functools import partial
from typing import Any, Optional, TypeVar

import torch
import torch.distributed as dist
from compressed_tensors.offload.cache import OffloadCache
from compressed_tensors.offload.module import offload_module, remove_module_offload
from compressed_tensors.offload.utils import (
    get_module_device,
    get_module_sizes,
    module_size,
)
from compressed_tensors.utils import getattr_chain
from compressed_tensors.utils.binary_search import SearchFailureError, max_binary_search
from compressed_tensors.utils.helpers import deprecated
from loguru import logger
from torch._prims_common import DeviceLikeType
from tqdm import tqdm
from transformers import PreTrainedModel

__all__ = [
    "set_onload_device",
    "offload_model",
    "dispatch_with_map",
    "get_device_map",
    "dispatch_model",
    "remove_dispatch",
    "get_device_memory",
    "DeviceMap",
]

ModelType = TypeVar("ModelType", bound=torch.nn.Module)

# maps a submodule name to its (onload_device, offload_device) pair; the offload
# entry may also be the string "disk"
DeviceMap = dict[str, tuple[torch.device | None, torch.device | str | None]]


def set_onload_device(
    model: ModelType,
    onload_device: torch.device | str,
) -> ModelType:
    """
    Modify the dispatch of a model so that its modules onload to the provided
    `onload_device`. Existing offloaded tensors will not be moved. If a module is
    not already offloaded, it will be offloaded to its current device.

    :param model: model to dispatch
    :param onload_device: device to move weights to during the forward pass
    :return: dispatched model
    """
    for module in model.modules():
        if isinstance(module._parameters, OffloadCache):
            # already offloaded: retarget where tensors are moved for forward
            module._parameters.onload_device = onload_device
            module._buffers.onload_device = onload_device
        else:
            # not yet offloaded: offload in place, keeping the current device
            offload_device = get_module_device(module, torch.device("cpu"))
            offload_module(module, onload_device, offload_device)

    return model
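
# Usage sketch (hypothetical checkpoint name; assumes a CUDA device is present):
#
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
#   model = set_onload_device(model, "cuda:0")
#   # weights remain offloaded (here: on cpu) and move to cuda:0 during forward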


@deprecated("set_onload_device")
def offload_model(
    model: ModelType,
    onload_device: torch.device | str,
    offload_device: Any = None,
) -> ModelType:
    """
    .. deprecated::
        Use :func:`set_onload_device` instead.
    """
    return set_onload_device(model, onload_device)


def dispatch_with_map(
    model: torch.nn.Module,
    device_map: DeviceMap,
    offload_dir: Optional[str] = None,
    show_progress: bool = True,
):
    """
    Dispatch a model according to the provided device map

    :param model: model to dispatch
    :param device_map: device map specifying the onload and offload device of each
        module
    :param offload_dir: optional directory used for disk offloading
    :param show_progress: whether to show a tqdm progress bar
    """
    for name, (onload_device, offload_device) in tqdm(
        list(device_map.items()), desc="Dispatching model", disable=(not show_progress)
    ):
        module = model.get_submodule(name)
        if offload_device == "disk":
            offload_module(
                module, onload_device, offload_device, offload_dir=offload_dir
            )
        elif offload_device is not None:
            offload_module(module, onload_device, offload_device)


def get_device_map(
    model: torch.nn.Module, default_device: DeviceLikeType = torch.device("cpu")
) -> DeviceMap:
    """
    Get the device map of a CT-offloaded model

    :param model: model to get the device map of
    :param default_device: the default onload/offload device used when a module has
        no parameters
    :return: device map specifying the onload and offload device of all modules
    """
    from compressed_tensors.offload import get_execution_device, get_offloaded_device

    return {
        name: (
            get_execution_device(module, default_device),
            get_offloaded_device(module, default_device),
        )
        for name, module in model.named_modules(remove_duplicate=False)
    }
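
# Usage sketch: capture the current dispatch so it can later be re-applied with
# `dispatch_with_map`.
#
#   saved_map = get_device_map(model)
#   onload, offload = saved_map["model.layers.0"]  # hypothetical submodule name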


def dispatch_model(
    model: ModelType,
    device_memory: dict[torch.device, int] | None = None,
    extra_memory: int | None = None,
    no_split_modules: Container[str] | None = None,
) -> ModelType:
    """
    Dispatch a model for autoregressive generation. Modules are dispatched evenly
    across available devices and kept onloaded if possible. If onloading the entire
    model is not possible, some modules may be offloaded. Any existing offloads
    will be removed.

    Disclaimers:
    * Optimal runtime assumes that modules are called in order of `model.modules()`

    :param model: model to dispatch
    :param device_memory: optional dictionary mapping torch devices to available
        memory. If none is provided, all available devices will be used
    :param extra_memory: the amount of memory to reserve for activations
    :param no_split_modules: names of module classes which should not be split
        across multiple devices
    :return: dispatched model
    """
    # infer no_split_modules
    if no_split_modules is None:
        no_split_modules = getattr(model, "_no_split_modules", tuple())

    # collect devices
    if device_memory is None:
        device_memory: dict[torch.device, int] = get_device_memory()
    if len(device_memory) <= 0:
        raise MemoryError("Did not find any devices to dispatch model to")

    # collect module sizes
    sizes = get_module_sizes(model, no_split_modules)
    if len(sizes) <= 0:
        raise ValueError("Model does not have any modules")

    # estimate memory requirement
    if extra_memory is None:
        # fragmentation, kv cache, embeddings, etc.
        extra_memory = max(module_size(model) * 0.05, 1e9)

        # activations
        if isinstance(model, PreTrainedModel):
            extra_memory += (
                1  # batch_size
                * 2048  # seq_len
                * getattr_chain(model, "config.intermediate_size", 256)
                * getattr(model, "dtype", torch.bfloat16).itemsize
            )

    # search for the best dispatch which maximizes extra memory across devices
    try:
        max_extra_memory = min(device_memory.values())
        extra_memory, (dispatch, _) = max_binary_search(
            fn=partial(_get_greedy_dispatch, sizes, device_memory),
            cond=(lambda result: len(result[0]) == len(sizes)),
            start=extra_memory,
            end=max_extra_memory,
        )

    # fallback: create a cpu dispatch
    except SearchFailureError:
        dispatch, device_memory = _get_greedy_dispatch(
            sizes, device_memory, extra_memory
        )
        assert len(dispatch) < len(sizes)
        last_device = dispatch[-1][1] if len(dispatch) else list(device_memory)[0]
        sizes_dict = {module: size for module, size in sizes}
        largest_offloaded_module = max(size for _, size in sizes[len(dispatch) :])

        # pop off modules until all offloaded modules can fit in the last device
        while largest_offloaded_module + extra_memory > device_memory[last_device]:
            if len(dispatch) <= 0:
                raise ValueError(
                    f"Cannot fit no_split module of size {largest_offloaded_module} "
                    f"bytes into any device: {device_memory}"
                )

            module, last_device, _ = dispatch.pop(-1)
            device_memory[last_device] += sizes_dict[module]
            largest_offloaded_module = max(largest_offloaded_module, sizes_dict[module])

        # fill dispatch back with cpu offloading
        for module, _ in list(sizes[len(dispatch) :]):
            dispatch.append((module, last_device, "cpu"))

        logger.warning("Forced to offload modules due to insufficient gpu resources")

    # dispatch
    assert len(dispatch) == len(sizes)
    dispatch_dict = {
        submodule: (onload, offload)
        for module, onload, offload in dispatch
        for submodule in module.modules()
    }
    for module in model.modules():
        remove_module_offload(module, onload_tensors=True)
        if module in dispatch_dict:
            onload, offload = dispatch_dict[module]
            offload_module(module, onload, offload)

    logger.debug(f"Dispatched model with {extra_memory} bytes of extra memory")
    return model
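
# Usage sketch (illustrative budgets): dispatch across two GPUs with explicit
# per-device limits and a 2 GiB activation reserve. Omit both arguments to let
# the function detect devices and estimate the reserve itself.
#
#   model = dispatch_model(
#       model,
#       device_memory={
#           torch.device("cuda", 0): 40 * 2**30,
#           torch.device("cuda", 1): 40 * 2**30,
#       },
#       extra_memory=2 * 2**30,
#   )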


def get_device_memory() -> dict[torch.device, int]:
    """
    Get the total memory of all available accelerator devices. Returns accelerator
    device memory when available, otherwise falls back to CPU with system RAM.

    :return: mapping from torch device to total memory in bytes
    """
    if not torch.accelerator.is_available():
        import os

        total_ram = os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES")
        return {torch.device("cpu"): total_ram}

    accel_type = torch.accelerator.current_accelerator().type
    if dist.is_available() and dist.is_initialized():
        logger.info("Detected distributed context. Dispatching to local rank gpu")
        device_memory = torch.accelerator.get_memory_info(
            torch.accelerator.current_device_index()
        )[1]
        return {torch.device(accel_type): device_memory}

    return {
        torch.device(accel_type, idx): torch.accelerator.get_memory_info(idx)[1]
        for idx in range(torch.accelerator.device_count())
    }
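
# Illustrative return values (not measured):
#   {torch.device("cuda", 0): 85899345920, torch.device("cuda", 1): 85899345920}
#   {torch.device("cpu"): 67108864000}  # no accelerator available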


def remove_dispatch(
    module: torch.nn.Module, onload_tensors: bool = False
) -> torch.nn.Module:
    """
    Remove any existing dispatches from a module

    :param module: module to remove dispatches from
    :param onload_tensors: whether to move tensors to the onload device, or keep
        them on the offload device. Defaults to False
    :return: module with offloading functionality removed
    """
    for submodule in module.modules():
        remove_module_offload(submodule, onload_tensors)

    return module
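
# Usage sketch: fully onload a dispatched model (e.g. before saving), discarding
# its offloading hooks.
#
#   model = remove_dispatch(model, onload_tensors=True)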


def _get_greedy_dispatch(
    sizes: list[tuple[torch.nn.Module, int]],
    device_memory: dict[torch.device, int],
    extra_memory: int = 0,
) -> tuple[
    list[tuple[torch.nn.Module, torch.device, torch.device]], dict[torch.device, int]
]:
    """
    Greedily assign modules to devices in order, reserving `extra_memory` bytes on
    each device. Returns the (possibly partial) dispatch and the memory remaining
    on each device.
    """
    dispatch = list()
    memory_remaining = deepcopy(device_memory)

    device_index = 0
    devices = list(memory_remaining.keys())
    if len(devices) <= 0:
        raise ValueError("Cannot dispatch without any devices")

    for module, size in sizes:
        while True:
            # ran out of devices: return the partial dispatch
            if device_index >= len(devices):
                return dispatch, memory_remaining

            # module does not fit: advance to the next device (never backtrack)
            device = devices[device_index]
            if size > memory_remaining[device] - extra_memory:
                device_index += 1
                continue

            # module fits: onload and offload devices match, so it stays onloaded
            dispatch.append((module, device, device))
            memory_remaining[device] -= size
            break

    return dispatch, memory_remaining
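
# Worked trace (illustrative sizes, in arbitrary byte units): two devices with
# 10 units each, extra_memory=2, modules sized [("a", 5), ("b", 4), ("c", 6)]:
#   "a": 5 <= 10 - 2 -> device 0, 5 units remain
#   "b": 4 >   5 - 2 -> advance; 4 <= 10 - 2 -> device 1, 6 units remain
#   "c": 6 >   6 - 2 -> no devices left -> return partial dispatch ["a", "b"]
# The device pointer never backtracks, which is why `dispatch_model` assumes
# modules execute in `model.modules()` order.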