fix(hmc): forward dtype to QuadPotentialDiagAdapt in default potential creation#8280
Open
xodn348 wants to merge 1 commit into
Open
fix(hmc): forward dtype to QuadPotentialDiagAdapt in default potential creation#8280xodn348 wants to merge 1 commit into
xodn348 wants to merge 1 commit into
Conversation
…l creation BaseHMC.__init__ accepts a dtype parameter and passes it to the parent class, but omits it when constructing the default QuadPotentialDiagAdapt. This causes the potential to silently fall back to pytensor.config.floatX even when the caller explicitly requests a different dtype. Fixes pymc-devs#8213
Member
|
run pre-commit |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
BaseHMC.__init__accepts adtypeparameter and correctly forwards it to the parent class, but silently drops it when constructing the defaultQuadPotentialDiagAdapt. The potential therefore always falls back topytensor.config.floatX, ignoring the caller's explicitdtyperequest.One-line fix: add
dtype=dtypeto theQuadPotentialDiagAdapt(...)call inbase_hmc.pyline 169.Issue
Fixes #8213
Local verification
Verified against pymc 5.28.5 with the patch applied to the installed package.
Three tests were run:
pm.NUTS(dtype="float32").potential.dtype == np.dtype("float32")— passes after fix, would fail before.pm.NUTS().potential.dtype == np.dtype(pytensor.config.floatX)— backward-compatibility preserved.pm.HamiltonianMC(dtype="float32").potential.dtype == np.dtype("float32")— same fix coversHMCsince both useBaseHMC.Risk
Low. Single-argument addition to an existing constructor call. The behaviour for
dtype=None(the default, used by virtually all existing code) is unchanged —QuadPotentialDiagAdaptalready defaults topytensor.config.floatXwhendtypeisNone.