Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
7fd1242
Enhance has_jax() to validate version and platform support
thegialeo Feb 26, 2026
659c01d
Allow _parse_version to compare versions with differing lengths by ze…
thegialeo Feb 26, 2026
49a4445
Refactor has_jax() for improved readability and error handling
thegialeo Feb 26, 2026
ed96ecd
Add sync reminder comment for JAX version constraints
thegialeo Feb 26, 2026
d340c71
add unit test for helper function _parse_version in util.py
thegialeo Feb 26, 2026
7a197ba
add unit tests for helper function _is_version_in_range in util.py
thegialeo Feb 26, 2026
3914941
add unit tests for has_jax() function in util.py
thegialeo Feb 26, 2026
e8d3208
chore: auto-update via `nox -s pre-commit`
thegialeo Feb 26, 2026
8ed78ff
style: pre-commit fixes
pre-commit-ci[bot] Feb 26, 2026
72dd0f0
add PR changes to CHANGELOG.md
thegialeo Feb 26, 2026
dba8080
update change wording to be more user impact focused
thegialeo Feb 26, 2026
7edf037
Merge branch 'main' into issue-5381-jax-import-crash
martinjrobins Feb 26, 2026
1d97f52
isolate test_pybamm_import into a forked subprocess run
thegialeo Feb 27, 2026
23716ed
Temporarily skip test_pybamm_import to investigate CI flakiness
thegialeo Feb 28, 2026
f2843f3
update skip reason comment for test_pybamm_import
thegialeo Mar 3, 2026
26a54a4
Merge branch 'main' into issue-5381-jax-import-crash
thegialeo Mar 3, 2026
4f39634
fix: normalize 4-part jax/jaxlib versions (e.g. 0.9.0.1) before range…
thegialeo Mar 4, 2026
2c75f3d
Merge branch 'main' into issue-5381-jax-import-crash
thegialeo Mar 6, 2026
fa31c7d
Merge branch 'main' into issue-5381-jax-import-crash
thegialeo Mar 11, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ as initial conditions. ([#5311](https://github.com/pybamm-team/PyBaMM/pull/5311)
- Optimize state mapper for multi-step experiments by pre-calculating mapper during setup. ([#5380](https://github.com/pybamm-team/PyBaMM/pull/5380))

## Bug fixes

- PyBaMM no longer crashes on import if an incompatible version of JAX is installed (e.g., installed for other packages or numeric computations). This ensures JAX is truly an optional dependency: users who do not intend to use JAX can safely import PyBaMM, with a warning shown if JAX-dependent features are disabled. ([#5398](https://github.com/pybamm-team/PyBaMM/pull/5398))
- Fixed a bug in the exchange current density calculation for MSMR models. ([#5404](https://github.com/pybamm-team/PyBaMM/pull/5404))
- Fixed a bug where when converting `ExpressionFunctionParameter` to source code, `Interpolant` objects were being reduced to just their input variable names (e.g., sto) instead of preserving the full constructor call with data arrays. ([#5393](https://github.com/pybamm-team/PyBaMM/pull/5393))

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ cite = ["pybtex>=0.25.0"]
bpx = ["bpx>=0.5.0,<0.6.0"]
# Low-overhead progress bars
tqdm = ["tqdm"]
# When updating the jax version constraints below, also update MIN_VERSION and MAX_VERSION in src/pybamm/util.py has_jax()
jax = ["jax>=0.7.0, <0.9.0; python_version >= '3.11' and (sys_platform != 'darwin' or platform_machine != 'x86_64')"]
# Contains all optional dependencies, except for jax, and dev dependencies
all = [
Expand Down
116 changes: 110 additions & 6 deletions src/pybamm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import pathlib
import pickle
import re
import timeit
from warnings import warn

Expand Down Expand Up @@ -347,19 +348,122 @@ def get_parameters_filepath(path):
return os.path.join(pybamm.__path__[0], path)


def has_jax():
def _parse_version(version_str, length=3):
"""Parse version string into tuple of integers for comparison,
ignoring suffix (e.g., "0.7.0rc1" -> (0, 7, 0)).

Parameters
----------
version_str : str
The version string to parse.
length : int, optional
The desired length of the output tuple, defaults to 3.

Returns
-------
tuple of int
The parsed version as a tuple of integers normalized to the specified length.
Missing components are padded with zeros and extra components are truncated.
"""
Check if jax and jaxlib are installed with the correct versions
parsed = tuple(
int(re.match(r"\d+", part).group())
for part in version_str.split(".")
if re.match(r"\d+", part)
)

return (parsed + (0,) * length)[:length]


def _is_version_in_range(version_tuple, min_version, max_version):
"""Check if version is within the supported range [min_version, max_version).

Parameters
----------
version_tuple : tuple of int
The version to check, parsed as a tuple of integers.
min_version : tuple of int
The minimum supported version, inclusive, parsed as a tuple of integers.
max_version : tuple of int
The maximum supported version, exclusive, parsed as a tuple of integers.

Returns
-------
bool
True if jax and jaxlib are installed with the correct versions, False if otherwise
True if version_tuple is >= min_version and < max_version, False otherwise.
"""
return version_tuple >= min_version and version_tuple < max_version


def has_jax():
"""
return (importlib.util.find_spec("jax") is not None) and (
importlib.util.find_spec("jaxlib") is not None
)
Check if jax and jaxlib are installed with correct versions on a supported platform.

Returns
-------
bool
True if jax and jaxlib are installed with supported versions on a supported platform
(Linux, Windows, or macOS with Apple Silicon), False otherwise.

Notes
-----
This function checks that jax and jaxlib are installed with versions >= 0.7.0 and < 0.9.0,
and the platform is not macOS intel x86_64. These constraints should be kept in sync with
the jax optional dependency constraint in pyproject.toml. If versions or platform are
unsupported, a warning is emitted and False is returned to treat JAX as unavailable,
rather than raising a hard error.
"""
# Check if modules are available
if any(importlib.util.find_spec(pkg) is None for pkg in ("jax", "jaxlib")):
return False

# Check platform: JAX is not supported on macOS x86_64 (Intel)
# see https://docs.jax.dev/en/latest/changelog.html#jax-0-5-0-jan-17-2025
if is_macos_intel():
warn(
"JAX is not supported on macOS with Intel (x86_64) processors. "
"JAX dropped macOS x86_64 support in version 0.5.0. "
"To use JAX with PyBaMM, you need macOS with Apple Silicon (M-series), Linux, or Windows.",
UserWarning,
stacklevel=2,
)
return False

try:
jax_version = importlib.metadata.version("jax")
jaxlib_version = importlib.metadata.version("jaxlib")

# When updating these version constraints, also update the jax dependency in pyproject.toml
MIN_VERSION = (0, 7, 0)
MAX_VERSION = (0, 9, 0) # exclusive

jax_parsed = _parse_version(jax_version)
jaxlib_parsed = _parse_version(jaxlib_version)

# Check if both jax and jaxlib are within supported version range
if _is_version_in_range(
jax_parsed, MIN_VERSION, MAX_VERSION
) and _is_version_in_range(jaxlib_parsed, MIN_VERSION, MAX_VERSION):
return True

warn(
f"JAX version {jax_version} and/or jaxlib version {jaxlib_version} are not supported. "
f"Supported versions are >= {'.'.join(map(str, MIN_VERSION))}, < {'.'.join(map(str, MAX_VERSION))}. "
f"JAX features will be unavailable. To use JAX with PyBaMM, install compatible versions "
f"via: pip install 'pybamm[jax]'",
UserWarning,
stacklevel=2,
)
return False

except importlib.metadata.PackageNotFoundError:
# If package metadata cannot be found, treat JAX as unavailable
warn(
"JAX version information could not be retrieved. "
"JAX features will be unavailable.",
UserWarning,
stacklevel=2,
)
return False


def is_macos_intel():
Expand Down
99 changes: 99 additions & 0 deletions tests/unit/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ def test_import_optional_dependency(self):
for import_pkg in present_optional_import_deps:
sys.modules[import_pkg] = modules[import_pkg]

@pytest.mark.skip(
reason="Skipped until refactor to prevent CI flakiness (see issue #5402)"
)
def test_pybamm_import(self):
optional_distribution_deps = get_optional_distribution_deps("pybamm")
present_optional_import_deps = get_present_optional_import_deps(
Expand Down Expand Up @@ -166,6 +169,102 @@ def test_optional_dependencies(self):
"or list them as required."
)

@pytest.mark.parametrize(
"input_str,expected",
[
("0.7.0", (0, 7, 0)),
("0.7", (0, 7, 0)),
("1.2.3rc1", (1, 2, 3)),
("2.0", (2, 0, 0)),
("10.4.5.dev0", (10, 4, 5)),
("0.8.0-beta", (0, 8, 0)),
("0.9.0.1", (0, 9, 0)),
("0.8.1.7", (0, 8, 1)),
],
)
def test_parse_version(self, input_str, expected):
assert pybamm.util._parse_version(input_str) == expected

@pytest.mark.parametrize(
"version,min_ver,max_ver,expected",
[
# Version within range
((0, 7, 5), (0, 7, 0), (0, 9, 0), True),
((0, 8, 0), (0, 7, 0), (0, 9, 0), True),
# Version equal to min (inclusive)
((0, 7, 0), (0, 7, 0), (0, 9, 0), True),
# Version equal to max (exclusive)
((0, 9, 0), (0, 7, 0), (0, 9, 0), False),
# Version below min
((0, 6, 9), (0, 7, 0), (0, 9, 0), False),
# Version above max
((0, 9, 1), (0, 7, 0), (0, 9, 0), False),
((1, 0, 0), (0, 7, 0), (0, 9, 0), False),
# Edge cases
((0, 7, 0), (0, 7, 0), (0, 7, 1), True),
((0, 7, 1), (0, 7, 0), (0, 7, 1), False),
],
)
def test_is_version_in_range(self, version, min_ver, max_ver, expected):
assert pybamm.util._is_version_in_range(version, min_ver, max_ver) == expected

def test_has_jax_not_installed(self, monkeypatch):
# Simulate jax not installed
monkeypatch.setattr("importlib.util.find_spec", lambda name: None)
assert pybamm.util.has_jax() is False

def test_has_jax_macos_intel(self, monkeypatch):
# Simulate jax installed but macOS Intel
monkeypatch.setattr("importlib.util.find_spec", lambda name: True)
monkeypatch.setattr("pybamm.util.is_macos_intel", lambda: True)
with pytest.warns(UserWarning):
assert pybamm.util.has_jax() is False

def test_has_jax_version_supported(self, monkeypatch):
# Simulate jax installed on supported platform with correct versions
monkeypatch.setattr("importlib.util.find_spec", lambda name: True)
monkeypatch.setattr("pybamm.util.is_macos_intel", lambda: False)
monkeypatch.setattr(
"importlib.metadata.version",
lambda name: "0.8.0", # valid version
)
assert pybamm.util.has_jax() is True

def test_has_jax_version_unsupported(self, monkeypatch):
# Simulate jax installed with unsupported version
monkeypatch.setattr("importlib.util.find_spec", lambda name: True)
monkeypatch.setattr("pybamm.util.is_macos_intel", lambda: False)
monkeypatch.setattr(
"importlib.metadata.version",
lambda name: "0.9.0", # too high
)
with pytest.warns(UserWarning):
assert pybamm.util.has_jax() is False

def test_has_jax_version_unsupported_four_part(self, monkeypatch):
# Simulate jax installed with unsupported four-part version
monkeypatch.setattr("importlib.util.find_spec", lambda name: True)
monkeypatch.setattr("pybamm.util.is_macos_intel", lambda: False)
monkeypatch.setattr(
"importlib.metadata.version",
lambda name: "0.9.0.1", # normalizes to 0.9.0 (too high)
)
with pytest.warns(UserWarning):
assert pybamm.util.has_jax() is False

def test_has_jax_version_error(self, monkeypatch):
# Simulate error reading version (PackageNotFoundError)
import importlib.metadata

def mock_version(name):
raise importlib.metadata.PackageNotFoundError(name)

monkeypatch.setattr("importlib.util.find_spec", lambda name: True)
monkeypatch.setattr("pybamm.util.is_macos_intel", lambda: False)
monkeypatch.setattr("importlib.metadata.version", mock_version)
with pytest.warns(UserWarning):
assert pybamm.util.has_jax() is False


class TestSearch:
def test_url_gets_to_stdout(self, mocker):
Expand Down