diff --git a/NEWS.md b/NEWS.md index 7264cfa51..eabcbdd4e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -58,6 +58,7 @@ The function interface remains unchanged. ## Model changes +- Refactored Stan code for efficiency (@bob-carpenter, #1273). - MCMC runs are now initialised with parameter values drawn from a distribution that approximates their prior distributions. - Added an option to compute growth rates using an estimator by Parag et al. (2022) based on total infectiousness rather than new infections, see `growth_method` argument in rt_opts(). - Added support for fitting the susceptible population size. diff --git a/inst/stan/functions/convolve.stan b/inst/stan/functions/convolve.stan index 2a84df929..b11a0cf6e 100644 --- a/inst/stan/functions/convolve.stan +++ b/inst/stan/functions/convolve.stan @@ -62,7 +62,6 @@ array[] int calc_conv_indices_len(int s, int xlen, int ylen) { vector convolve_with_rev_pmf(vector x, vector y, int len) { int xlen = num_elements(x); int ylen = num_elements(y); - vector[len] z; if (xlen + ylen - 1 < len) { reject("convolve_with_rev_pmf: len is longer than x and y convolved"); @@ -72,16 +71,16 @@ vector convolve_with_rev_pmf(vector x, vector y, int len) { reject("convolve_with_rev_pmf: len is shorter than x"); } + vector[len] z; + for (s in 1:xlen) { array[4] int indices = calc_conv_indices_xlen(s, xlen, ylen); z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]); } - if (len > xlen) { - for (s in (xlen + 1):len) { - array[4] int indices = calc_conv_indices_len(s, xlen, ylen); - z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]); - } + for (s in (xlen + 1):len) { // zero iterations unless len > xlen + array[4] int indices = calc_conv_indices_len(s, xlen, ylen); + z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]); } return z; diff --git a/inst/stan/functions/gaussian_process.stan b/inst/stan/functions/gaussian_process.stan index 2f067b4bd..6480dccc1 100644 --- a/inst/stan/functions/gaussian_process.stan +++ b/inst/stan/functions/gaussian_process.stan @@ -23,6 +23,21 @@ vector diagSPD_EQ(real alpha, real rho, real L, int M) { return factor * exp(exponent * square(indices)); } +/** + * Index set for M basis functions of length L for Matern kernel. + * + * The function returns pow(pi() / 2 / L * linspaced_vector(M, 1, M), 2), + * or equivalently, square(pi() / (2 * L) * linspaced_vector(M, 1, M)). + * + * @param L Length of the interval + * @param M Number of basis functions + * @return Linearly spaced M-vector + */ +vector matern_indices(int M, real L) { + real factor = pi() / (2 * L); + return square(linspaced_vector(M, factor, factor * M)); +} + /** * Spectral density for 1/2 Matern (Ornstein-Uhlenbeck) kernel * @@ -35,10 +50,8 @@ vector diagSPD_EQ(real alpha, real rho, real L, int M) { * @ingroup estimates_smoothing */ vector diagSPD_Matern12(real alpha, real rho, real L, int M) { - vector[M] indices = linspaced_vector(M, 1, M); - real factor = 2; - vector[M] denom = rho * ((1 / rho)^2 + pow(pi() / 2 / L * indices, 2)); - return alpha * sqrt(factor * inv(denom)); + vector[M] denom = 1 / rho + rho * matern_indices(M, L); + return alpha * sqrt(2 ./ denom); } /** @@ -53,10 +66,9 @@ vector diagSPD_Matern12(real alpha, real rho, real L, int M) { * @ingroup estimates_smoothing */ vector diagSPD_Matern32(real alpha, real rho, real L, int M) { - vector[M] indices = linspaced_vector(M, 1, M); - real factor = 2 * alpha * pow(sqrt(3) / rho, 1.5); - vector[M] denom = (sqrt(3) / rho)^2 + pow((pi() / 2 / L) * indices, 2); - return factor * inv(denom); + real factor = 2 * alpha * (sqrt(3) / rho)^1.5; + vector[M] denom = 3 / square(rho) + matern_indices(M, L); + return factor ./ denom; } /** @@ -71,11 +83,9 @@ vector diagSPD_Matern32(real alpha, real rho, real L, int M) { * @ingroup estimates_smoothing */ vector diagSPD_Matern52(real alpha, real rho, real L, int M) { - vector[M] indices = linspaced_vector(M, 1, M); real factor = 16 * pow(sqrt(5) / rho, 5); - vector[M] denom = - 3 * pow((sqrt(5) / rho)^2 + pow((pi() / 2 / L) * indices, 2), 3); - return alpha * sqrt(factor * inv(denom)); + vector[M] denom = 3 * pow(5 / square(rho) + matern_indices(M, L), 3); + return alpha * sqrt(factor ./ denom); } /** @@ -92,10 +102,11 @@ vector diagSPD_Periodic(real alpha, real rho, int M) { real a = inv_square(rho); vector[M] indices = linspaced_vector(M, 1, M); vector[M] q = exp( - log(alpha) + 0.5 * - (log(2) - a + to_vector(log_modified_bessel_first_kind(indices, a))) + log(alpha) + + 0.5 * (log2() - a + log_modified_bessel_first_kind(indices, a)) ); return append_row(q, q); + } /** @@ -129,11 +140,11 @@ matrix PHI(int N, int M, real L, vector x) { * * @ingroup estimates_smoothing */ + matrix PHI_periodic(int N, int M, real w0, vector x) { - matrix[N, M] mw0x = diag_post_multiply( - rep_matrix(w0 * x, M), linspaced_vector(M, 1, M) - ); - return append_col(cos(mw0x), sin(mw0x)); + row_vector[M] k = linspaced_row_vector(M, 1, M); + matrix[N, M] w0xk = (w0 * x) * k; + return append_col(cos(w0xk), sin(w0xk)); } /** @@ -153,9 +164,7 @@ matrix PHI_periodic(int N, int M, real w0, vector x) { int setup_noise(int ot_h, int t, int horizon, int estimate_r, int stationary, int future_fixed, int fixed_from) { int noise_time = estimate_r > 0 ? (stationary > 0 ? ot_h : ot_h - 1) : t; - int noise_terms = - future_fixed > 0 ? (noise_time - horizon + fixed_from) : noise_time; - return noise_terms; + return future_fixed > 0 ? (noise_time - horizon + fixed_from) : noise_time; } /** @@ -210,7 +219,7 @@ vector update_gp(matrix PHI, int M, real L, real alpha, } else if (nu == 2.5) { diagSPD = diagSPD_Matern52(alpha, rho, L, M); } else { - reject("nu must be one of 1/2, 3/2 or 5/2; found nu=", nu); + reject("nu must be one of 0.5, 1.5, or 2.5; found nu=", nu); } } return PHI * (diagSPD .* eta);