Skip to content

Commit b8bd93a

Browse files
BlaziusMaximusOrbax Authors
authored and committed
Optimize safetensors loading with 1D contiguous reads and ICI resharding.
PiperOrigin-RevId: 892406249
1 parent 31f0377 commit b8bd93a

File tree

1 file changed

+191
-119
lines changed

1 file changed

+191
-119
lines changed

checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py

Lines changed: 191 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,30 @@
1414

1515
"""Defines `SafetensorsLayout`, a class to handle Safetensors checkpoint formats."""
1616

17+
import asyncio
1718
import collections
1819
import json
19-
from typing import Any, Awaitable, Sequence, cast
20+
from typing import Any, Awaitable, Sequence
2021

2122
import jax
2223
import jax.numpy as jnp
2324
import numpy as np
24-
from orbax.checkpoint._src.arrays import numpy_utils
25+
from orbax.checkpoint._src.multihost import multihost as multihost_v0
2526
from orbax.checkpoint._src.path import async_path
2627
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
2728
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
2829
from orbax.checkpoint.experimental.v1._src.path import types
30+
from orbax.checkpoint.experimental.v1._src.synchronization import multihost
2931

3032
CheckpointLayout = checkpoint_layout.CheckpointLayout
3133
InvalidLayoutError = checkpoint_layout.InvalidLayoutError
3234
Path = types.Path
3335

3436
HEADER_NUM_BYTES = 8
3537
SAFETENSORS_SUFFIX = ".safetensors"
38+
MAX_GAP_SIZE_BYTES = (
39+
32 * 1024 * 1024
40+
) # 32 MB gap allowed between tensors in a coalesced read block
3641

3742

3843
def _get_dtypes() -> dict[str, Any]:
@@ -92,75 +97,6 @@ def _get_array_properties(info: dict[str, Any]) -> tuple[tuple[int, ...], Any]:
9297
return shape, dtype
9398

9499

95-
async def _read_non_contiguous_slice(
    f: async_path.AsyncFile,
    idx: tuple[slice, ...],
    stored_shape: tuple[int, ...],
    stored_dtype: np.dtype,
    tensor_file_offset: int,
) -> np.ndarray:
  """Reads an n-dimensional slice whose bytes are not contiguous on disk.

  In a row-major layout only the innermost dimension is stored contiguously,
  so the requested slice is assembled by recursing over the outer dimensions
  and issuing one seek+read per innermost run of elements.

  Args:
    f: The asynchronous file object (binary read mode).
    idx: A tuple of resolved slice objects, one per tensor dimension.
    stored_shape: The full shape of the stored tensor.
    stored_dtype: The element dtype of the stored tensor.
    tensor_file_offset: Byte offset of the tensor's first element in the file.

  Returns:
    A NumPy array holding exactly the requested slice.
  """
  elem_size = np.dtype(stored_dtype).itemsize

  # 0-d tensors carry an empty index tuple: read a single scalar.
  if not idx:
    await f.seek(tensor_file_offset)
    raw = await f.read(elem_size)
    return np.frombuffer(raw, dtype=stored_dtype).reshape(())

  # Row-major byte strides of the full stored tensor: the stride of an axis
  # is how many bytes separate consecutive indices along that axis.
  ndim = len(stored_shape)
  strides = [0] * ndim
  running = elem_size
  for axis in reversed(range(ndim)):
    strides[axis] = running
    running *= stored_shape[axis]

  async def _gather(axis: int, offset: int) -> bytes:
    # TODO(b/438763866) - @zachmeyers to consider alternative methods.
    sl = idx[axis]

    # The innermost dimension is contiguous on disk: one read suffices.
    if axis == ndim - 1:
      await f.seek(tensor_file_offset + offset + sl.start * strides[axis])
      return await f.read((sl.stop - sl.start) * elem_size)

    # Outer dimensions: walk each selected index and recurse inward.
    pieces = [
        await _gather(axis + 1, offset + i * strides[axis])
        for i in range(sl.start, sl.stop)
    ]
    return b"".join(pieces)

  payload = await _gather(0, 0)
  out_shape = numpy_utils.slice_shape(idx)
  return np.frombuffer(payload, dtype=stored_dtype).reshape(out_shape)
162-
163-
164100
async def _load_safetensors_as_numpy(path: Path) -> dict[str, np.ndarray]:
165101
"""Loads tensors from a safetensors file into host NumPy arrays."""
166102
header, data_start_offset = await _read_safetensors_header(path)
@@ -179,65 +115,201 @@ async def _load_safetensors_as_numpy(path: Path) -> dict[str, np.ndarray]:
179115
return tensors
180116

181117

118+
def _create_non_sharded_array(
119+
raw_data: memoryview | bytes,
120+
abstract_leaf: Any,
121+
stored_shape: tuple[int, ...],
122+
stored_dtype: Any,
123+
) -> jax.Array:
124+
"""Creates a non-sharded JAX array from raw bytes."""
125+
np_array = np.frombuffer(raw_data, dtype=stored_dtype).reshape(stored_shape)
126+
target_dtype = abstract_leaf.dtype
127+
if np_array.dtype != target_dtype:
128+
np_array = np_array.astype(target_dtype)
129+
return jax.device_put(np_array)
130+
131+
132+
def _create_sharded_array(
133+
raw_data: memoryview | bytes,
134+
abstract_leaf: Any,
135+
stored_shape: tuple[int, ...],
136+
stored_dtype: Any,
137+
num_hosts: int,
138+
host_id: int,
139+
flat_sharding: jax.sharding.NamedSharding,
140+
) -> jax.Array:
141+
"""Creates a sharded JAX array from raw bytes."""
142+
sharding = abstract_leaf.sharding
143+
target_dtype = abstract_leaf.dtype
144+
145+
# Use 1D flat contiguous read + reshard logic for maximum IO throughput.
146+
total_elements = int(np.prod(stored_shape)) if stored_shape else 1
147+
148+
# Calculate padding
149+
elements_per_host = (total_elements + num_hosts - 1) // num_hosts
150+
padded_elements = elements_per_host * num_hosts
151+
152+
start_idx = host_id * elements_per_host
153+
end_idx = min((host_id + 1) * elements_per_host, total_elements)
154+
num_elements_to_read = max(0, end_idx - start_idx)
155+
156+
local_data = np.frombuffer(raw_data, dtype=stored_dtype)
157+
if local_data.dtype != target_dtype:
158+
local_data = local_data.astype(target_dtype)
159+
160+
if num_elements_to_read < elements_per_host:
161+
local_data = np.pad(
162+
local_data, (0, elements_per_host - num_elements_to_read)
163+
)
164+
165+
# Put local data on all addressable devices in the flat sharding
166+
local_arrays = [
167+
jax.device_put(local_data, d) for d in flat_sharding.addressable_devices
168+
]
169+
170+
# Create the 1D sharded array
171+
flat_array = jax.make_array_from_single_device_arrays(
172+
(padded_elements,), flat_sharding, local_arrays
173+
)
174+
175+
# Slice off the padding and reshape
176+
if padded_elements > total_elements:
177+
flat_array = flat_array[:total_elements]
178+
179+
reshaped_array = flat_array.reshape(stored_shape)
180+
181+
# Reshard to the target sharding
182+
target_array = jax.device_put(reshaped_array, sharding)
183+
184+
return target_array
185+
186+
187+
async def _load_non_sharded_array(
    path: Path,
    abstract_leaf: Any,
    header_info: dict[str, Any],
    data_start_offset: int,
) -> jax.Array:
  """Reads one tensor's full byte range and builds an unsharded array.

  Args:
    path: The safetensors file to read from.
    abstract_leaf: Abstract value carrying the target dtype.
    header_info: This tensor's entry from the safetensors header.
    data_start_offset: Byte offset where the file's data section begins.

  Returns:
    The tensor on the default device.
  """
  shape, dtype = _get_array_properties(header_info)
  begin, end = header_info["data_offsets"]

  # Single contiguous read of the whole tensor.
  async with async_path.open_file(path, mode="rb") as f:
    await f.seek(data_start_offset + begin)
    payload = await f.read(end - begin)

  return _create_non_sharded_array(payload, abstract_leaf, shape, dtype)
206+
207+
208+
async def _load_sharded_array(
    path: Path,
    abstract_leaf: Any,
    header_info: dict[str, Any],
    data_start_offset: int,
    num_hosts: int,
    host_id: int,
    flat_sharding: jax.sharding.NamedSharding,
) -> jax.Array:
  """Reads this host's contiguous chunk of one tensor and shards it.

  The flattened tensor is divided into `num_hosts` equal chunks (the last
  possibly short); this host performs a single contiguous read of its own
  chunk only, then delegates assembly and resharding to
  `_create_sharded_array`.

  Args:
    path: The safetensors file to read from.
    abstract_leaf: Abstract value carrying the target sharding and dtype.
    header_info: This tensor's entry from the safetensors header.
    data_start_offset: Byte offset where the file's data section begins.
    num_hosts: Total number of participating hosts.
    host_id: Index of this host in [0, num_hosts).
    flat_sharding: 1D sharding placing one equal chunk per host.

  Returns:
    The full tensor with the target sharding.
  """
  shape, dtype = _get_array_properties(header_info)
  tensor_begin = header_info["data_offsets"][0]

  total = int(np.prod(shape)) if shape else 1
  per_host = (total + num_hosts - 1) // num_hosts
  chunk_start = host_id * per_host
  chunk_end = min(chunk_start + per_host, total)
  chunk_len = max(0, chunk_end - chunk_start)
  width = np.dtype(dtype).itemsize

  # One contiguous read of exactly this host's element range.
  async with async_path.open_file(path, mode="rb") as f:
    await f.seek(tensor_begin + data_start_offset + chunk_start * width)
    payload = await f.read(chunk_len * width)

  return _create_sharded_array(
      payload,
      abstract_leaf,
      shape,
      dtype,
      num_hosts,
      host_id,
      flat_sharding,
  )
244+
245+
182246
async def _load_safetensors_on_device(
    path: Path, abstract_pytree: dict[str, Any]
) -> dict[str, jax.Array]:
  """Loads tensors from a safetensors file into on-device JAX arrays.

  Tensors with a target sharding are read with 1D contiguous per-host reads
  and resharded on-device; unsharded tensors are read whole. All tensors are
  loaded concurrently.

  Args:
    path: The safetensors file to read from.
    abstract_pytree: Maps tensor names to abstract values carrying the target
      `sharding` and `dtype`.

  Returns:
    Maps tensor names to restored on-device arrays.

  Raises:
    KeyError: If a requested tensor is absent from the file header.
    ValueError: If hosts have unequal device counts.
  """
  header, data_start_offset = await _read_safetensors_header(path)

  num_hosts = multihost.process_count()
  host_id = jax.process_index()

  # Group every global device under its owning host so we can build a
  # rectangular (hosts, devices) mesh for the flat host-sharded reads.
  devices_by_host = [
      [
          d
          for d in jax.devices()
          if multihost_v0.process_index_from_device(d) == i
      ]
      for i in range(num_hosts)
  ]

  # A ragged mesh would silently mis-shard; fail loudly instead. Uneven
  # per-host device counts are rare but possible.
  expected_count = len(devices_by_host[0])
  for group in devices_by_host:
    if len(group) != expected_count:
      raise ValueError("Number of devices must be the same across all hosts.")

  host_mesh = jax.sharding.Mesh(
      np.array(devices_by_host), ("hosts", "devices")
  )
  flat_sharding = jax.sharding.NamedSharding(
      host_mesh, jax.sharding.PartitionSpec("hosts")
  )

  async def _load_tensor(
      tensor_name: str, abstract_leaf: Any
  ) -> tuple[str, jax.Array]:
    # No target sharding: read the whole tensor onto the default device.
    if abstract_leaf.sharding is None:
      array = await _load_non_sharded_array(
          path,
          abstract_leaf,
          header[tensor_name],
          data_start_offset,
      )
    else:
      # We have a target sharding: per-host contiguous read + reshard.
      array = await _load_sharded_array(
          path,
          abstract_leaf,
          header[tensor_name],
          data_start_offset,
          num_hosts,
          host_id,
          flat_sharding,
      )
    return tensor_name, array

  pending = []
  for tensor_name, abstract_leaf in abstract_pytree.items():
    if tensor_name not in header:
      raise KeyError(
          f"Tensor '{tensor_name}' not found in safetensors header of {path}."
      )
    pending.append(_load_tensor(tensor_name, abstract_leaf))

  # Load every tensor concurrently.
  loaded = await asyncio.gather(*pending)
  return dict(loaded)
242314

243315

0 commit comments

Comments
 (0)