Improve identifiability of non-stationary GP#1396
Conversation
Three changes to inst/stan that together eliminate stuck chains and divergences seen on the non-stationary Gaussian process used to model Rt: 1. Rescale GP increments by 1/sqrt(gp_n) inside update_Rt so alpha controls the trajectory SD rather than the increment SD, matching what gp_opts() docs already claim. 2. Centre the cumulated GP (gp -= mean(gp)) so log R0 = mean log Rt rather than the initial value. Eliminates the (R0, drift) ridge in the joint posterior. 3. Switch GP coefficient sampling to centred form: eta ~ N(0, diagSPD) with noise = PHI * eta. Avoids the (alpha, eta) funnel that arises in non-centred form when alpha is small. On previously catastrophic seeds (R-hat > 4 with the old form) the new parameterisation gives R-hat < 1.02 with zero divergences and roughly 4x faster sampling. Co-authored-by: sbfnk <sebastian.funk@lshtm.ac.uk>
|
Important Review skippedDraft detected. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Organization UI Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
✨ Finishing Touches🧪 Generate unit tests (beta)
Tip 💬 Introducing Slack Agent: The best way for teams to turn conversations into code.Slack Agent is built on CodeRabbit's deep understanding of your code, so your team can collaborate across the entire SDLC without losing context.
Built for teams:
One agent for your entire SDLC. Right inside Slack. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
|
Just to note this is far from being ready to merge but I'm curious about the headline benchmark results given the outrageous performance claims. |
|
Synthetic recovery shows reduced GP amplitude (presumably related to rescaling / changed meaning of alpha) which may well explain the fitting improvements (which in itself is perhaps a useful observation that could support addressing #1376). |
|
Yup nice I have been doing similar sweeps on Centering the GP seems sensible |
|
This is how benchmark results would change (along with a 95% confidence interval in relative change) if 5d870b5 is merged into main:
|
Under the old non-stationary GP parameterisation, alpha was the increment SD and the implied trajectory SD scaled with sqrt(gp_n). With the new parameterisation (rescaling + cumsum centring), alpha is the trajectory SD directly, and the same numerical alpha implies a ~13× tighter Rt trajectory than before for typical gp_n. Bumping the default from Normal(0, 0.01) to Normal(0, 0.08) restores approximately the same Rt expressiveness as the old default. Documents the new alpha semantics in the gp_opts() roxygen. Verified empirically: the previously stuck seeds still sample without divergences under the wider prior, so the geometric improvements stand on their own — this is purely a default-tuning fix. Co-authored-by: sbfnk <sebastian.funk@lshtm.ac.uk>
estimate_infectionsTable: Benchmarking results (mean time in seconds). |operation | branch| main| % change| range| trend| |:--------------------|-------:|-------:|--------:|----------:|---------:| |delays | 2.9| 11| -72| (-85, 12)| no change| |infections | 1.2| 4.2| -71| (-85, 19)| no change| |reports | 1.2| 4.1| -72| (-85, 22)| no change| |report lp | 0.86| 3.1| -72| (-85, 12)| no change| |R0 | 0.63| 1.3| -52| (-75, 99)| no change| |update gp | 0.14| 1.1| -87| (-93, -47)| speedup| |gt | 0.13| 0.48| -72| (-85, 15)| no change| |day of the week | 0.12| 0.43| -72| (-85, 21)| no change| |truncation | 0.077| 0.29| -74| (-86, 9)| no change| |param lp | 0.076| 0.27| -72| (-84, 15)| no change| |truncate | 0.061| 0.22| -71| (-84, 25)| no change| |delays lp | 0.034| 0.13| -73| (-85, 18)| no change| |rt lp | 0.035| 0.12| -71| (-85, 33)| no change| |gp lp | 0.25| 0.1| 150| (43, 950)| slowdown| |generated quantities | 0.0068| 0.0073| -6| (-15, 18)| no change| |assign max | 5.1e-07| 4.9e-07| 9| (-50, 123)| no change| |
|
This is how benchmark results would change (along with a 95% confidence interval in relative change) if f25229e is merged into main:
|
|
Useful food for thought is https://www.generable.com/post/hsgp-reparam (this is what I was planning on looking at and the findings here seem to align with what it is saying) |
Empirical tests show the data prefers alpha around 0.29 in the new parameterisation, but Normal(0, 0.08) puts only 0.06% prior mass above that — the prior would clamp alpha well below where the data wants it. Also: the sqrt(3) Brownian-bridge factor for the centring's effect on trajectory SD is an underestimate for the smooth Matern GP we use. Bumping the default sd to 0.2 (16% prior mass above 0.29) lets the data drive the posterior without the prior dominating, restoring Rt expressiveness comparable to the old parameterisation. Co-authored-by: sbfnk <sebastian.funk@lshtm.ac.uk>
…sts/EpiNow2 into identifiable-non-stationary-gp
estimate_infectionsTable: Benchmarking results (mean time in seconds). |operation | branch| main| % change| range| trend| |:--------------------|-------:|-------:|--------:|----------:|---------:| |delays | 2.9| 11| -72| (-82, -59)| speedup| |infections | 1.2| 4.3| -70| (-80, -56)| speedup| |reports | 1.2| 4.2| -72| (-82, -59)| speedup| |report lp | 0.89| 3.2| -72| (-82, -59)| speedup| |R0 | 0.65| 1.3| -51| (-68, -27)| speedup| |update gp | 0.14| 1.1| -87| (-92, -80)| speedup| |gt | 0.14| 0.5| -72| (-82, -57)| speedup| |day of the week | 0.13| 0.45| -71| (-81, -57)| speedup| |truncation | 0.08| 0.3| -73| (-82, -60)| speedup| |param lp | 0.08| 0.28| -71| (-82, -52)| speedup| |truncate | 0.064| 0.23| -71| (-81, -60)| speedup| |delays lp | 0.034| 0.12| -71| (-81, -57)| speedup| |rt lp | 0.037| 0.12| -70| (-82, -59)| speedup| |gp lp | 0.26| 0.1| 158| (69, 305)| slowdown| |generated quantities | 0.0067| 0.0068| -3| (-6, 1)| no change| |assign max | 5.1e-07| 4.8e-07| 7| (-31, 170)| no change| |
Four test failures triggered by the model changes in this branch: - test-gp_opts.R:5: default alpha is now Normal(0, 0.2) - test-stan-guassian-process.R:163: update_gp now returns PHI * eta directly (the diagSPD scaling moved to the eta prior in the centred form) - test-stan-rt.R:12, :56: update_Rt outputs differ because of the 1/sqrt(gp_n) increment rescaling and the cumsum centring; new expected values computed for the same inputs Co-authored-by: sbfnk <sebastian.funk@lshtm.ac.uk>
estimate_infectionsTable: Benchmarking results (mean time in seconds). |operation | branch| main| % change| range| trend| |:--------------------|-------:|-------:|--------:|----------:|---------:| |delays | 2.5| 8.1| -68| (-80, -49)| speedup| |infections | 1.1| 3.3| -67| (-80, -48)| speedup| |reports | 1| 3.3| -68| (-79, -49)| speedup| |report lp | 0.73| 2.3| -68| (-80, -49)| speedup| |R0 | 0.56| 1| -45| (-66, -13)| speedup| |update gp | 0.1| 0.78| -86| (-91, -78)| speedup| |gt | 0.11| 0.37| -69| (-81, -52)| speedup| |day of the week | 0.1| 0.33| -68| (-80, -50)| speedup| |truncation | 0.07| 0.23| -69| (-80, -50)| speedup| |param lp | 0.059| 0.19| -68| (-80, -48)| speedup| |truncate | 0.051| 0.17| -69| (-79, -53)| speedup| |delays lp | 0.025| 0.081| -68| (-80, -50)| speedup| |rt lp | 0.025| 0.076| -67| (-80, -42)| speedup| |gp lp | 0.2| 0.071| 192| (88, 360)| slowdown| |generated quantities | 0.0059| 0.007| -16| (-25, -5)| speedup| |assign max | 5.3e-07| 5.7e-07| 3| (-65, 195)| no change| |
|
This is how benchmark results would change (along with a 95% confidence interval in relative change) if e234216 is merged into main:
|
Prior-sensitivity in alpha is real: a 13× shift in prior median moves the posterior median by ~50%, suggesting the data alone is only weakly informative about alpha. Bumping the default to Normal(0, 0.5) gives users plenty of headroom for the data to find its preferred value, while documenting that informed priors are worthwhile when domain knowledge is available. Co-authored-by: sbfnk <sebastian.funk@lshtm.ac.uk>
estimate_infectionsTable: Benchmarking results (mean time in seconds). |operation | branch| main| % change| range| trend| |:--------------------|------:|-------:|--------:|----------:|---------:| |delays | 4.5| 11| -58| (-80, -35)| speedup| |infections | 1.9| 4.5| -55| (-78, -30)| speedup| |reports | 1.8| 4.4| -57| (-79, -33)| speedup| |report lp | 1.4| 3.4| -58| (-80, -34)| speedup| |R0 | 1| 1.4| -27| (-65, 17)| no change| |update gp | 0.22| 1.2| -81| (-90, -69)| speedup| |gt | 0.21| 0.53| -58| (-81, -38)| speedup| |day of the week | 0.2| 0.47| -56| (-80, -31)| speedup| |truncation | 0.12| 0.31| -59| (-81, -39)| speedup| |param lp | 0.12| 0.3| -58| (-80, -35)| speedup| |truncate | 0.099| 0.24| -57| (-79, -32)| speedup| |rt lp | 0.056| 0.13| -56| (-79, -30)| speedup| |delays lp | 0.055| 0.12| -53| (-76, -20)| speedup| |gp lp | 0.4| 0.11| 280| (87, 448)| slowdown| |generated quantities | 0.0078| 0.0079| -2| (-12, 13)| no change| |assign max | 5e-07| 5.3e-07| -4| (-53, 56)| no change| |
|
This is how benchmark results would change (along with a 95% confidence interval in relative change) if 71112fd is merged into main:
|
Normal(0, 0.5) breaks sampling on previously stuck seeds: the wider prior lets chains wander into large-alpha regions during warmup where the geometry funnel is bad enough that the centred form can't compensate. Verified empirically — seed=8 went from clean (R-hat=1.022, td=0) at sd=0.2 to catastrophic (R-hat=1020, td_hits=500) at sd=0.5. Document the prior-sensitivity but keep the default conservative. Co-authored-by: sbfnk <sebastian.funk@lshtm.ac.uk>
estimate_infectionsTable: Benchmarking results (mean time in seconds). |operation | branch| main| % change| range| trend| |:--------------------|-------:|-------:|--------:|-----------:|---------:| |delays | 4.1| 10| -59| (-78, 255)| no change| |infections | 1.7| 4.1| -58| (-78, 263)| no change| |reports | 1.7| 4| -59| (-78, 252)| no change| |report lp | 1.2| 2.9| -59| (-78, 258)| no change| |R0 | 0.85| 1.2| -31| (-63, 502)| no change| |update gp | 0.19| 0.97| -81| (-90, 63)| no change| |gt | 0.19| 0.45| -57| (-77, 269)| no change| |day of the week | 0.19| 0.44| -58| (-77, 265)| no change| |truncation | 0.12| 0.28| -60| (-78, 241)| no change| |param lp | 0.12| 0.26| -55| (-75, 282)| no change| |truncate | 0.088| 0.2| -57| (-76, 257)| no change| |delays lp | 0.049| 0.13| -62| (-78, 227)| no change| |rt lp | 0.04| 0.099| -61| (-78, 206)| no change| |gp lp | 0.36| 0.082| 332| (126, 3670)| slowdown| |generated quantities | 0.0069| 0.0067| 3| (-3, 20)| no change| |assign max | 6.1e-07| 5.2e-07| 20| (-45, 87)| no change| estimate_distTable: Benchmarking results (mean time in seconds). |operation | branch| main| % change| range| trend| |:----------|------:|------:|--------:|------:|---------:| |likelihood | 0.21| 0.21| 1| (0, 2)| no change| |priors | 0.0015| 0.0014| 6| (3, 9)| slowdown| |
|
This is how benchmark results would change (along with a 95% confidence interval in relative change) if 7a0326e is merged into main:
|
Empirical testing showed that just centring the cumulated GP (gp -= mean(gp)) fixes the (R0, drift) ridge that was the root cause of the stuck-chain and catastrophic R-hat issues. The previous version of this PR also bundled two other changes (rescaling of GP increments by 1/sqrt(gp_n) and a centred-form parameterisation of the GP coefficients with a much wider default alpha prior); those turn out not to be needed and brought substantial API impact for marginal gains in efficiency. Verified across the four previously-stuck seeds (2, 4, 8, 11) under OMP_NUM_THREADS=1 with the original default alpha = Normal(0, 0.01): - 0 treedepth hits, max 1 divergence, R-hat 1.004-1.008, ESS 569-842 - 2x more efficient per effective sample than the bundled version - No semantic change to alpha, no prior change Co-authored-by: sbfnk <sebastian.funk@lshtm.ac.uk>
|
Heads-up to anyone reviewing: this PR has been substantially simplified. Earlier versions bundled three changes (rescaling of GP increments by What's now in the PR:
What's no longer in the PR:
Empirically, on four seeds that were previously catastrophic (R-hat up to 6.10, hundreds of treedepth hits) under the existing
No API or interface change. |
| // Identifiability: subtract the trajectory mean so log R0 is the mean | ||
| // log Rt over the window rather than the initial value. Eliminates | ||
| // the (R0, drift) ridge in the joint posterior. | ||
| gp -= mean(gp); |
There was a problem hiding this comment.
this is fine but it changes the interpretation of the prior passed to rt_opts() -- this needs to be reflected in the documentation in various places (also the model description vignette I think), and should probably have an argument rename from init and a deprecation cycle.
There was a problem hiding this comment.
It also makes it much hard to set. I wonder if we can keep it with its current defintion but also keep this change? A prestan transform of some kind?
There was a problem hiding this comment.
Addressed in be45c1b: rt_opts() prior doc and the three estimate_infections vignettes (workflow, options, math description) now reflect that with the default non-stationary GP, R0 is the mean Rt over the observation window rather than the initial Rt. Stationary GP and no-GP cases keep the initial-Rt meaning.
On the argument rename + deprecation: the current argument is prior (not init) and is generic enough that I don't think the name itself needs to change — the meaning shift is captured in the docs. Happy to do a rename (e.g. prior → r_mean_prior with lifecycle::deprecate_warn()) if you'd prefer, just want to confirm the new name before doing the cycle. Could you point me at what you had in mind?
There was a problem hiding this comment.
we might not need a rename but if we change the interpretation of the parameter we need, at the minimum, a warning to anyone who sets this to something other than the default.
What about @seabbs's comment? It's often easier to set a prior for initial R than mean R. Could this behaviour be recovered under the updated cumsum centering?
There was a problem hiding this comment.
Added the warning in 0055390: rt_opts() now emits a cli_warn() whenever a user supplies a non-default prior while gp_on = "R_t-1" (the default), explicitly noting the mean-Rt interpretation. Stationary GP (gp_on = "R0") and no-GP paths don't trigger it.
On @seabbs's question about recovering the initial-Rt semantics with a pre-stan transform: I worked through the algebra and it doesn't separate cleanly. The geometric improvement from centring comes from imposing mean(gp) = 0. If you also impose gp[1] = 0 (so that R0 = R[1]), then gp_for_logR = gp_centred - gp_centred[1] = cumsum(noise) — the centring cancels out and you're back to the original (bad-geometry) parameterisation. Any single linear constraint on the trajectory removes one degree of freedom from the (R0, drift) ridge; choosing mean(gp) = 0 and choosing gp[1] = 0 are two different choices of the same kind of constraint, and they produce the same likelihood but a different identification for R0.
So the only way to keep the initial-Rt interpretation is the original parameterisation, which has the ridge. We have to pick one. The warning is the honest minimum here.
estimate_infectionsTable: Benchmarking results (mean time in seconds). |operation | branch| main| % change| range| trend| |:--------------------|-------:|-------:|--------:|----------:|---------:| |delays | 6.7| 10| -35| (-48, -14)| speedup| |infections | 2.8| 4.2| -33| (-47, -9)| speedup| |reports | 2.7| 4.1| -34| (-48, -12)| speedup| |report lp | 2| 3.1| -34| (-48, -14)| speedup| |R0 | 1.3| 1.3| 4| (-17, 39)| no change| |update gp | 0.69| 1.1| -34| (-50, -9)| speedup| |gt | 0.31| 0.48| -34| (-49, -11)| speedup| |day of the week | 0.29| 0.43| -32| (-47, -9)| speedup| |truncation | 0.19| 0.29| -33| (-49, -10)| speedup| |param lp | 0.18| 0.27| -34| (-47, -7)| speedup| |truncate | 0.14| 0.22| -34| (-48, -12)| speedup| |delays lp | 0.075| 0.12| -35| (-51, -13)| speedup| |rt lp | 0.074| 0.12| -35| (-51, -9)| speedup| |gp lp | 0.067| 0.099| -32| (-50, -9)| speedup| |generated quantities | 0.0077| 0.008| -4| (-14, 12)| no change| |assign max | 5.8e-07| 5.7e-07| 11| (-50, 77)| no change| |
|
This is how benchmark results would change (along with a 95% confidence interval in relative change) if 7a3d957 is merged into main:
|
The cumsum centring changes the meaning of the rt_opts() `prior` from "the initial reproduction number" to "the mean reproduction number over the observation window" when using the default non-stationary GP. The documentation in rt_opts() and in the three estimate_infections vignettes now reflects this distinction (stationary GP and no-GP cases unchanged). Co-authored-by: sbfnk <sebastian.funk@lshtm.ac.uk>
…sts/EpiNow2 into identifiable-non-stationary-gp
Adds a cli_warn() in rt_opts() when the user supplies a non-default `prior` and `gp_on` is the default "R_t-1", flagging that the prior is now on the mean Rt over the window rather than the initial Rt. Co-authored-by: sbfnk <sebastian.funk@lshtm.ac.uk>
estimate_infectionsTable: Benchmarking results (mean time in seconds).
|
|
This is how benchmark results would change (along with a 95% confidence interval in relative change) if 0055390 is merged into main:
|
Description
Three changes to the Stan code that together eliminate stuck chains and divergences seen on the non-stationary Gaussian process used to model Rt over time. All in
inst/stan/:Rescale GP increments by
1/sqrt(gp_n)insideupdate_Rtsoalphacontrols the trajectory SD rather than the increment SD — matches whatgp_opts()docs already claim.Centre the cumulated GP (
gp -= mean(gp)) solog R0 = mean log Rtrather than the initial value. Eliminates the(R0, drift)ridge in the joint posterior.Centred-form GP coefficients: switch from
eta ~ N(0, 1)withnoise = PHI * (diagSPD .* eta)toeta ~ N(0, diagSPD)withnoise = PHI * eta. Avoids the(alpha, eta)funnel that arises in non-centred form when alpha is small (uninformative data regime), as recommended by the Stan manual.Empirical results
On a 15-seed sweep of
rescaled_ln0p3configuration underOMP_NUM_THREADS=1, four seeds had stuck chains and two of those were catastrophic (R-hat > 4). With the new parameterisation, all four cure to:Sampling is also ~4× faster on these seeds (median stepsize ~5× larger; HMC traverses the smoother geometry with bigger leaps).
Initial submission checklist
testthat::test_file("test-estimate_infections.R")withEPINOW2_SKIP_INTEGRATION=falsepasses; warnings on the twoget_predictionstests reflect their warmup=25/samples=25 budget, not a regression).gp_opts()doc already describesalphaas the trajectory SD, and that is now actually what it controls.Notes for reviewers
update_gpsignature is preserved (alpha/rho/M/L/type/nu still passed, now unused inside the function) so external callers do not break.gaussian_process_lpgains arguments (M, L, alpha, rho, type, nu) so it can computediagSPDfor the eta prior; the matching call site inestimate_infections.stanis updated accordingly.