Skip to content

fix(hmc): forward dtype to QuadPotentialDiagAdapt in default potential creation#8280

Open
xodn348 wants to merge 1 commit into
pymc-devs:mainfrom
xodn348:fix/base-hmc-dtype-not-passed-to-quad-potential
Open

fix(hmc): forward dtype to QuadPotentialDiagAdapt in default potential creation#8280
xodn348 wants to merge 1 commit into
pymc-devs:mainfrom
xodn348:fix/base-hmc-dtype-not-passed-to-quad-potential

Conversation

@xodn348
Copy link
Copy Markdown

@xodn348 xodn348 commented May 6, 2026

Summary

BaseHMC.__init__ accepts a dtype parameter and correctly forwards it to the parent class, but silently drops it when constructing the default QuadPotentialDiagAdapt. The potential therefore always falls back to pytensor.config.floatX, ignoring the caller's explicit dtype request.

One-line fix: add dtype=dtype to the QuadPotentialDiagAdapt(...) call in base_hmc.py line 169.

Issue

Fixes #8213

Local verification

Verified against pymc 5.28.5 with the patch applied to the installed package.

potential.dtype = float32
PASS: potential.dtype = float32 as expected
potential.dtype = float64, expected = float64
PASS: potential.dtype = float64 (floatX) when dtype=None

All tests passed.
=== LOCAL_TEST_PASSED ===

Three tests were run:

  1. pm.NUTS(dtype="float32").potential.dtype == np.dtype("float32") — passes after fix, would fail before.
  2. pm.NUTS().potential.dtype == np.dtype(pytensor.config.floatX) — backward-compatibility preserved.
  3. pm.HamiltonianMC(dtype="float32").potential.dtype == np.dtype("float32") — same fix covers HMC since both use BaseHMC.

Risk

Low. Single-argument addition to an existing constructor call. The behaviour for dtype=None (the default, used by virtually all existing code) is unchanged — QuadPotentialDiagAdapt already defaults to pytensor.config.floatX when dtype is None.

…l creation

BaseHMC.__init__ accepts a dtype parameter and passes it to the parent
class, but omits it when constructing the default QuadPotentialDiagAdapt.
This causes the potential to silently fall back to pytensor.config.floatX
even when the caller explicitly requests a different dtype.

Fixes pymc-devs#8213
@github-actions github-actions Bot added the bug label May 6, 2026
@read-the-docs-community
Copy link
Copy Markdown

Documentation build overview

📚 pymc | 🛠️ Build #32561089 | 📁 Comparing 488d3bd against latest (151672a)

  🔍 Preview build  

1 file changed
± glossary.html

@ricardoV94
Copy link
Copy Markdown
Member

run pre-commit

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

BUG: BaseHMC doesn't pass dtype to ‎QuadPotentialDiagAdapt

2 participants