Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 25 additions & 18 deletions src/haliax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,26 +193,32 @@ def __getitem__(cls, item: NamedArrayAxesSpec) -> typing.Any:
@dataclass(frozen=True)
class NamedArray(metaclass=NamedArrayMeta):
array: jnp.ndarray
axes: tuple[Axis, ...]
axis_names: tuple[str, ...]

def __init__(self, array: jnp.ndarray, axes: AxisSpec):
def __init__(self, array: jnp.ndarray, axes: AxisSelection):
object.__setattr__(self, "array", array)
if isinstance(axes, Mapping):
axes = tuple(Axis(name, size) for name, size in axes.items())
object.__setattr__(self, "axes", ensure_tuple(axes))
axes = tuple(Axis(name, size) for name, size in axes.items()) # type: ignore[arg-type]
axes = ensure_tuple(axes)

axis_names: list[str] = []
for axis in axes:
if isinstance(axis, Axis):
axis_names.append(axis.name)
elif isinstance(axis, str):
axis_names.append(axis)
else:
raise TypeError(f"Expected Axis or str, got {type(axis)}")

# ensure axes are all Axis objects
# TODO: anonymous positional axes?
for axis in self.axes:
if not isinstance(axis, Axis):
raise TypeError(f"Expected Axis, got {type(axis)}")
if len(set(axis_names)) != len(axis_names):
raise ValueError(f"Axes must be unique, but {axis_names} are not")

# ensure unique axes for now
if len(set(a.name for a in self.axes)) != len(self.axes):
raise ValueError(f"Axes must be unique, but {self.axes} are not")
object.__setattr__(self, "axis_names", tuple(axis_names))

if are_shape_checks_enabled():
self._ensure_shape_matches_axes()
@property
def axes(self) -> tuple[Axis, ...]:
shape = jnp.shape(self.array)
return tuple(Axis(name, int(sz)) for name, sz in zip(self.axis_names, shape))

def _ensure_shape_matches_axes(self):
"""This is typically called automatically, but sometimes we need to call it manually if
Expand Down Expand Up @@ -253,7 +259,7 @@ def scalar(self) -> jnp.ndarray:

@ft.cached_property
def shape(self) -> Dict[str, int]:
return {axis.name: axis.size for axis in self.axes}
return {name: size for name, size in zip(self.axis_names, jnp.shape(self.array))}

dtype = property(lambda self: self.array.dtype)
"""The dtype of the underlying array"""
Expand All @@ -265,7 +271,7 @@ def shape(self) -> Dict[str, int]:
"""The number of bytes in the underlying array"""

def tree_flatten(self) -> Any:
return ((self.array,), self.axes)
return ((self.array,), self.axis_names)

@classmethod
def tree_unflatten(cls, aux, tree: Any) -> Any:
Expand Down Expand Up @@ -1587,8 +1593,9 @@ def flatten(array: NamedArray, new_axis_name: AxisSelector) -> NamedArray:
def named(a, axis: AxisSelection) -> NamedArray:
"""Creates a NamedArray from a numpy array and a list of axes."""
a = jnp.asarray(a)
axes = check_shape(a.shape, axis)
return NamedArray(a, axes)
check_shape(a.shape, axis)
axis_names = tuple(axis_name(ax) for ax in axis_spec_to_tuple(axis))
return NamedArray(a, axis_names)


# Broadcasting Support
Expand Down
32 changes: 22 additions & 10 deletions src/haliax/debug.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
from typing import List, Tuple, Union, Sequence
from typing import List, Tuple, Union, Mapping

import equinox as eqx
import jax
Expand All @@ -8,7 +8,6 @@


from haliax.core import NamedArray
from haliax.axis import Axis
from haliax.util import is_jax_or_hax_array_like

from ._src.util import IdentityMap
Expand Down Expand Up @@ -131,13 +130,18 @@ def _pspec_parts(spec_part) -> str:
return str(spec_part)


def visualize_named_sharding(axes: Sequence[Axis], sharding: jax.sharding.Sharding) -> None:
def visualize_named_sharding(
shape: Mapping[str, int], sharding: jax.sharding.Sharding
) -> None:
"""Visualize the sharding for a set of named axes.

This extends :func:`jax.debug.visualize_sharding` to handle arrays with more
than two dimensions by falling back to a textual description when necessary.
"""

axes = list(shape.keys())
values = list(shape.values())

try:
pspec = sharding.spec # type: ignore[attr-defined]
except Exception:
Expand All @@ -148,11 +152,10 @@ def visualize_named_sharding(axes: Sequence[Axis], sharding: jax.sharding.Shardi

if num_sharded <= 2:
try:
jax.debug.visualize_sharding([ax.size for ax in axes], sharding)
jax.debug.visualize_sharding(values, sharding)
except Exception:
pass

mapping = ", ".join(f"{ax.name}->{part}" for ax, part in zip(axes, parts))
mapping = ", ".join(f"{name}->{part}" for name, part in zip(axes, parts))
print(mapping)


Expand All @@ -169,21 +172,30 @@ def visualize_shardings(tree) -> None:
def _show(x):
if isinstance(x, NamedArray):
arr = x.array
axes = x.axes
named_shape = x.shape
else:
arr = x
axes = None
named_shape = None

def cb(sh):
if axes is not None:
visualize_named_sharding(axes, sh)
if named_shape is not None:
visualize_named_sharding(named_shape, sh)
else:
try:
jax.debug.visualize_sharding(arr.shape, sh)
except Exception:
pass

jax.debug.inspect_array_sharding(arr, callback=cb)
if named_shape is not None:
try:
sh = arr.sharding
pspec = sh.spec # type: ignore[attr-defined]
except Exception:
pspec = (None,) * len(named_shape)
parts = [_pspec_parts(p) for p in pspec]
mapping = ", ".join(f"{name}->{part}" for name, part in zip(named_shape.keys(), parts))
print(mapping)
return x

htu.tree_map(_show, tree, is_leaf=is_jax_or_hax_array_like)