Skip to content

docs(jax): correct postprocessing_vectorize default in sample_jax_nuts docstring#8283

Open
xodn348 wants to merge 1 commit into
pymc-devs:mainfrom
xodn348:fix/jax-postprocessing-vectorize-docstring
Open

docs(jax): correct postprocessing_vectorize default in sample_jax_nuts docstring#8283
xodn348 wants to merge 1 commit into
pymc-devs:mainfrom
xodn348:fix/jax-postprocessing-vectorize-docstring

Conversation

@xodn348
Copy link
Copy Markdown

@xodn348 xodn348 commented May 8, 2026

Description

Fixes a docstring inaccuracy in sample_numpyro_nuts / sample_blackjax_nuts (the shared parameter section in pymc/sampling/jax.py).

The parameter description for postprocessing_vectorize states default "scan", but the actual runtime default is "vmap". When postprocessing_vectorize=None (i.e., the caller uses the default), the function body sets it to "vmap":

else:
    postprocessing_vectorize = "vmap"

This mismatch led a user (#8238) to remove their explicit postprocessing_vectorize="scan" argument (after reading that "scan" was the default) and accidentally switch to "vmap", causing an out-of-memory error.

The one-character change corrects default "scan"default "vmap" so the docstring matches the actual behaviour.

Related Issue

Checklist

  • Checked that the pre-commit linting/style checks pass (pre-commit run --files pymc/sampling/jax.py — all hooks passed)
  • Included tests that prove the fix is effective or that the new feature works — N/A: this is a docstring-only correction with no behaviour change; the existing code path is fully tested by the existing JAX test suite
  • Added necessary documentation (docstrings and/or example notebooks) — the fix IS the docstring correction
  • If you are a pro: each commit corresponds to a relevant logical change

Type of change

  • Documentation

Local verification

$ pre-commit run --files pymc/sampling/jax.py
check for merge conflicts............................................................Passed
debug statements (python)............................................................Passed
fix end of files.....................................................................Passed
don't commit to branch...............................................................Passed
trim trailing whitespace.............................................................Passed
check blanket noqa...................................................................Passed
check blanket type ignore............................................................Passed
check for not-real mock methods......................................................Passed
use logger.warning(..................................................................Passed
type annotations not comments........................................................Passed
no unicode replacement chars.........................................................Passed
Apply Apache 2.0 License.............................................................Passed
ruff check...........................................................................Passed
ruff format..........................................................................Passed

$ ruff check pymc/sampling/jax.py
All checks passed!

$ ruff format --check pymc/sampling/jax.py
1 file already formatted

=== LOCAL_TEST_PASSED ===

…s docstring

The docstring claimed the default was 'scan', but the function body
sets it to 'vmap' when None is supplied (the actual default). Fixes
the misleading documentation that caused users to accidentally rely
on the wrong default after removing an explicit argument.

Fixes pymc-devs#8238
@read-the-docs-community
Copy link
Copy Markdown

Documentation build overview

📚 pymc | 🛠️ Build #32600982 | 📁 Comparing 62d8d38 against latest (151672a)

  🔍 Preview build  

3 files changed
± glossary.html
± api/generated/pymc.sampling.jax.sample_blackjax_nuts.html
± api/generated/pymc.sampling.jax.sample_numpyro_nuts.html

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

postprocessing_vectorize default does not match docs/warnings

1 participant