1414
1515"""Defines `SafetensorsLayout`, a class to handle Safetensors checkpoint formats."""
1616
17+ import asyncio
1718import collections
1819import json
19- from typing import Any , Awaitable , Sequence , cast
20+ from typing import Any , Awaitable , Sequence
2021
2122import jax
2223import jax .numpy as jnp
2324import numpy as np
24- from orbax .checkpoint ._src .arrays import numpy_utils
25+ from orbax .checkpoint ._src .multihost import multihost as multihost_v0
2526from orbax .checkpoint ._src .path import async_path
2627from orbax .checkpoint .experimental .v1 ._src .layout import checkpoint_layout
2728from orbax .checkpoint .experimental .v1 ._src .metadata import types as metadata_types
2829from orbax .checkpoint .experimental .v1 ._src .path import types
30+ from orbax .checkpoint .experimental .v1 ._src .synchronization import multihost
2931
3032CheckpointLayout = checkpoint_layout .CheckpointLayout
3133InvalidLayoutError = checkpoint_layout .InvalidLayoutError
3234Path = types .Path
3335
3436HEADER_NUM_BYTES = 8
3537SAFETENSORS_SUFFIX = ".safetensors"
38+ MAX_GAP_SIZE_BYTES = (
39+ 32 * 1024 * 1024
40+ ) # 32 MB gap allowed between tensors in a coalesced read block
3641
3742
3843def _get_dtypes () -> dict [str , Any ]:
@@ -92,75 +97,6 @@ def _get_array_properties(info: dict[str, Any]) -> tuple[tuple[int, ...], Any]:
9297 return shape , dtype
9398
9499
95- async def _read_non_contiguous_slice (
96- f : async_path .AsyncFile ,
97- idx : tuple [slice , ...],
98- stored_shape : tuple [int , ...],
99- stored_dtype : np .dtype ,
100- tensor_file_offset : int ,
101- ) -> np .ndarray :
102- """Reads a slice of a tensor from a file.
103-
104- This function solves the problem of reading a multi-dimensional slice from an
105- array where the slice's data is not stored as a single, contiguous block in
106- the file. It does so by recursively "walking" the dimensions of the slice.
107-
108- Args:
109- f: The asynchronous file object (binary read mode)
110- idx: A tuple of slice objects representing the n-dimensional slice to
111- read.
112- stored_shape: The shape of the tensor.
113- stored_dtype: The `dtype` of the tensor.
114- tensor_file_offset: The starting byte offset of the tensor's data within
115- the file.
116-
117- Returns:
118- The specific tensor slice.
119- """
120- # Handle 0-d scalar case
121- if not idx :
122- await f .seek (tensor_file_offset )
123- num_bytes = np .dtype (stored_dtype ).itemsize
124- scalar_bytes = await f .read (num_bytes )
125- # Reshape to () to create a 0-D NumPy array.
126- return np .frombuffer (scalar_bytes , dtype = stored_dtype ).reshape (())
127-
128- itemsize = np .dtype (stored_dtype ).itemsize
129-
130- # Calculate the byte strides for the full tensor. The stride for a
131- # dimension is the number of bytes to "jump" to get to the next element
132- # in that dimension while keeping all other indices the same.
133- global_strides = [itemsize ] * len (stored_shape )
134- for i in range (len (stored_shape ) - 2 , - 1 , - 1 ):
135- global_strides [i ] = global_strides [i + 1 ] * stored_shape [i + 1 ]
136-
137- async def _read_slice_recursively (dim : int , base_offset : int ) -> bytes :
138- # TODO(b/438763866) - @zachmeyers to consider alternative methods.
139- s = idx [dim ] # The slice for the current dimension.
140-
141- # If we are at the last dimension, the data is contiguous.
142- if dim == len (stored_shape ) - 1 :
143- start = base_offset + s .start * global_strides [dim ]
144- num_bytes = (s .stop - s .start ) * itemsize
145- await f .seek (tensor_file_offset + start )
146- return cast (bytes , await f .read (num_bytes ))
147-
148- # For all other dimensions, iterate through the indices
149- # of the slice and make a recursive call for the next dimension.
150- chunks = []
151- for i in range (s .start , s .stop ):
152- offset = base_offset + i * global_strides [dim ]
153- chunk = await _read_slice_recursively (dim + 1 , offset )
154- chunks .append (chunk )
155-
156- return b"" .join (chunks )
157-
158- # Start the recursive reading process from the first dimension.
159- slice_bytes = await _read_slice_recursively (dim = 0 , base_offset = 0 )
160- shard_shape = numpy_utils .slice_shape (idx )
161- return np .frombuffer (slice_bytes , dtype = stored_dtype ).reshape (shard_shape )
162-
163-
164100async def _load_safetensors_as_numpy (path : Path ) -> dict [str , np .ndarray ]:
165101 """Loads tensors from a safetensors file into host NumPy arrays."""
166102 header , data_start_offset = await _read_safetensors_header (path )
@@ -179,65 +115,201 @@ async def _load_safetensors_as_numpy(path: Path) -> dict[str, np.ndarray]:
179115 return tensors
180116
181117
118+ def _create_non_sharded_array (
119+ raw_data : memoryview | bytes ,
120+ abstract_leaf : Any ,
121+ stored_shape : tuple [int , ...],
122+ stored_dtype : Any ,
123+ ) -> jax .Array :
124+ """Creates a non-sharded JAX array from raw bytes."""
125+ np_array = np .frombuffer (raw_data , dtype = stored_dtype ).reshape (stored_shape )
126+ target_dtype = abstract_leaf .dtype
127+ if np_array .dtype != target_dtype :
128+ np_array = np_array .astype (target_dtype )
129+ return jax .device_put (np_array )
130+
131+
132+ def _create_sharded_array (
133+ raw_data : memoryview | bytes ,
134+ abstract_leaf : Any ,
135+ stored_shape : tuple [int , ...],
136+ stored_dtype : Any ,
137+ num_hosts : int ,
138+ host_id : int ,
139+ flat_sharding : jax .sharding .NamedSharding ,
140+ ) -> jax .Array :
141+ """Creates a sharded JAX array from raw bytes."""
142+ sharding = abstract_leaf .sharding
143+ target_dtype = abstract_leaf .dtype
144+
145+ # Use 1D flat contiguous read + reshard logic for maximum IO throughput.
146+ total_elements = int (np .prod (stored_shape )) if stored_shape else 1
147+
148+ # Calculate padding
149+ elements_per_host = (total_elements + num_hosts - 1 ) // num_hosts
150+ padded_elements = elements_per_host * num_hosts
151+
152+ start_idx = host_id * elements_per_host
153+ end_idx = min ((host_id + 1 ) * elements_per_host , total_elements )
154+ num_elements_to_read = max (0 , end_idx - start_idx )
155+
156+ local_data = np .frombuffer (raw_data , dtype = stored_dtype )
157+ if local_data .dtype != target_dtype :
158+ local_data = local_data .astype (target_dtype )
159+
160+ if num_elements_to_read < elements_per_host :
161+ local_data = np .pad (
162+ local_data , (0 , elements_per_host - num_elements_to_read )
163+ )
164+
165+ # Put local data on all addressable devices in the flat sharding
166+ local_arrays = [
167+ jax .device_put (local_data , d ) for d in flat_sharding .addressable_devices
168+ ]
169+
170+ # Create the 1D sharded array
171+ flat_array = jax .make_array_from_single_device_arrays (
172+ (padded_elements ,), flat_sharding , local_arrays
173+ )
174+
175+ # Slice off the padding and reshape
176+ if padded_elements > total_elements :
177+ flat_array = flat_array [:total_elements ]
178+
179+ reshaped_array = flat_array .reshape (stored_shape )
180+
181+ # Reshard to the target sharding
182+ target_array = jax .device_put (reshaped_array , sharding )
183+
184+ return target_array
185+
186+
async def _load_non_sharded_array(
    path: Path,
    abstract_leaf: Any,
    header_info: dict[str, Any],
    data_start_offset: int,
) -> jax.Array:
  """Reads one tensor's full byte range and builds an unsharded array.

  Args:
    path: Path to the safetensors file.
    abstract_leaf: Abstract value carrying the target `dtype`.
    header_info: This tensor's header entry (shape, dtype, `data_offsets`).
    data_start_offset: Byte offset where the data section begins in the file.

  Returns:
    The restored on-device `jax.Array`.
  """
  shape, dtype = _get_array_properties(header_info)
  begin, end = header_info["data_offsets"]
  byte_count = end - begin
  # One contiguous read covers the entire tensor.
  async with async_path.open_file(path, mode="rb") as f:
    await f.seek(data_start_offset + begin)
    payload = await f.read(byte_count)

  return _create_non_sharded_array(payload, abstract_leaf, shape, dtype)
206+
207+
async def _load_sharded_array(
    path: Path,
    abstract_leaf: Any,
    header_info: dict[str, Any],
    data_start_offset: int,
    num_hosts: int,
    host_id: int,
    flat_sharding: jax.sharding.NamedSharding,
) -> jax.Array:
  """Reads only this host's contiguous chunk of a tensor and shards it.

  Args:
    path: Path to the safetensors file.
    abstract_leaf: Abstract value carrying the target `dtype` and `sharding`.
    header_info: This tensor's header entry (shape, dtype, `data_offsets`).
    data_start_offset: Byte offset where the data section begins in the file.
    num_hosts: Total number of participating hosts.
    host_id: Index of this host in `[0, num_hosts)`.
    flat_sharding: 1-D sharding that places one chunk per host.

  Returns:
    The restored `jax.Array` with the target sharding.
  """
  shape, dtype = _get_array_properties(header_info)
  offsets = header_info["data_offsets"]

  # Mirror the flat element split used by `_create_sharded_array` so this
  # host reads exactly the elements of its own chunk.
  element_count = int(np.prod(shape)) if shape else 1
  per_host = (element_count + num_hosts - 1) // num_hosts
  begin_idx = host_id * per_host
  end_idx = min(begin_idx + per_host, element_count)
  read_count = max(0, end_idx - begin_idx)
  itemsize = np.dtype(dtype).itemsize

  read_start = offsets[0] + data_start_offset + begin_idx * itemsize
  async with async_path.open_file(path, mode="rb") as f:
    await f.seek(read_start)
    payload = await f.read(read_count * itemsize)

  return _create_sharded_array(
      payload,
      abstract_leaf,
      shape,
      dtype,
      num_hosts,
      host_id,
      flat_sharding,
  )
244+
245+
async def _load_safetensors_on_device(
    path: Path, abstract_pytree: dict[str, Any]
) -> dict[str, jax.Array]:
  """Loads tensors from a safetensors file into on-device JAX arrays.

  Args:
    path: Path to the `.safetensors` file.
    abstract_pytree: Mapping of tensor name to an abstract leaf carrying the
      target `dtype` and a target `sharding` (or None for an unsharded load).

  Returns:
    Mapping of tensor name to the restored `jax.Array`.

  Raises:
    KeyError: If a requested tensor name is missing from the file header.
    ValueError: If hosts do not all have the same number of local devices.
  """
  header, data_start_offset = await _read_safetensors_header(path)
  restored_pytree = {}

  # NOTE(review): host count comes from v1 `multihost` while the host index
  # comes from `jax.process_index()` — presumably these agree; confirm.
  num_hosts = multihost.process_count()
  host_id = jax.process_index()

  # Build an initial mesh grouping all global devices by host.
  devices_by_host = []
  for i in range(num_hosts):
    devices_by_host.append([
        d
        for d in jax.devices()
        if multihost_v0.process_index_from_device(d) == i
    ])

  # Ensure uniform mesh shape (in case of uneven device counts, which is
  # rare); `np.array` below requires a rectangular device grid.
  num_devices_per_host = len(devices_by_host[0])
  for d in devices_by_host:
    if len(d) != num_devices_per_host:
      raise ValueError("Number of devices must be the same across all hosts.")

  initial_mesh = jax.sharding.Mesh(
      np.array(devices_by_host), ("hosts", "devices")
  )
  # 1-D sharding over the "hosts" axis only: each host owns one chunk of a
  # flattened tensor, shared by all sharded loads below.
  flat_sharding = jax.sharding.NamedSharding(
      initial_mesh, jax.sharding.PartitionSpec("hosts")
  )

  async def _load_tensor(
      tensor_name: str, abstract_leaf: Any
  ) -> tuple[str, jax.Array]:
    # Dispatch on whether the caller requested a sharded placement.
    if abstract_leaf.sharding is None:
      tensor = await _load_non_sharded_array(
          path,
          abstract_leaf,
          header[tensor_name],
          data_start_offset,
      )
    else:
      # We have a target sharding.
      tensor = await _load_sharded_array(
          path,
          abstract_leaf,
          header[tensor_name],
          data_start_offset,
          num_hosts,
          host_id,
          flat_sharding,
      )
    return tensor_name, tensor

  # Validate all names up front so a missing tensor fails fast, before any
  # reads are scheduled.
  tasks = []
  for tensor_name, abstract_leaf in abstract_pytree.items():
    if tensor_name not in header:
      raise KeyError(
          f"Tensor '{tensor_name}' not found in safetensors header of {path}."
      )
    tasks.append(_load_tensor(tensor_name, abstract_leaf))

  # Load all tensors concurrently; each coroutine opens the file itself.
  results = await asyncio.gather(*tasks)
  for tensor_name, tensor in results:
    restored_pytree[tensor_name] = tensor

  return restored_pytree
242314
243315
0 commit comments