Port MOSAIC likelihood calculation to Python for Dask worker-side scoring #47

@gilesjohnr


Problem

On the local R parallel path (PSOCK/FORK), likelihood is computed on-worker alongside the LASER simulation — no bottleneck. But on the Dask/Coiled path, workers run LASER and return full time-series matrices (~22KB per sim) back to the R orchestrator, which then computes the likelihood serially at ~9-16 sims/sec. This post-processing takes longer than the simulations themselves — 60-90 seconds per 1,000-sim batch and ~40 minutes for a 24K-sim predictive batch — while the cluster sits idle.

Proposal

Port the full R calc_model_likelihood() to Python inside laser-cholera -- including the core NB likelihood and all shape terms (peak timing, peak magnitude, cumulative progression, WIS) -- so Dask workers compute the likelihood on-worker immediately after each LASER iteration. Workers return {sim_id, log_likelihood, param_vector} instead of full time-series matrices. The R package already has robust config-to-vector serialization (convert_config_to_matrix() / convert_matrix_to_config()) that handles the param vector round-trip, so the orchestrator integration is straightforward.

What already exists

likelihood.py has get_model_likelihood() with core NB and Poisson distribution functions. The Analyzer class calls it on the final tick. This foundation needs to be upgraded to match the R implementation.

The R side was audited, cleaned up, and normalized in MOSAIC-pkg v0.22.0 (commit 9109cd5). Guardrails, legacy max terms, boolean toggles, and several bugs were removed/fixed. Shape terms were T-normalized so that weight parameters are meaningful. The R code at that commit is the authoritative spec for this port.

Components to port

The R function has a core NB likelihood (always on) plus 4 optional shape terms. Shape terms are enabled by setting their weight > 0 (no separate boolean toggles). All weights default to 0 (OFF). Ablation tests (test_44 series) showed WIS is the most effective shape term; peaks and cumulative are available for specialized use cases.

Shape terms are T-normalized so weight = 0.25 means 25% of NB core influence:

  • NB core: sums T dnbinom() terms → O(T). No scaling.
  • Peaks: T / N_peaks scaling (raw output is O(N_peaks))
  • Cumulative: T scaling (helper divides each eval by end_idx → O(1) raw output)
  • WIS: T scaling (returns time-averaged scalar → O(1) raw output)

1. Core NB time-series likelihood (always on) — main loop

Per-location, per-outcome (cases, deaths) Negative Binomial log-likelihood with:

Dispersion estimation uses nb_size_from_obs_weighted(), a weighted method-of-moments estimator with a Bessel correction and a k_min floor:

m = sum(w * x) / sum(w)
denom = sum(w) - sum(w^2) / sum(w)       # Bessel correction
v = sum(w * (x - m)^2) / denom
k = m^2 / (v - m)
k = clamp(k, k_min=3, k_max=1e5)

Key design: k is estimated from observed data only, so all simulations are evaluated against the same noise model. Minimum 3 finite observations required per location/outcome.
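
A minimal NumPy sketch of this estimator (the name mirrors the R helper; the exact signature, and leaving the 3-observation rule to the caller, are assumptions rather than the final API):

import numpy as np

def nb_size_from_obs_weighted(x, w, k_min=3.0, k_max=1e5):
    # Weighted method-of-moments NB dispersion, estimated from observed data only (sketch).
    # The >= 3 finite-observation requirement is assumed to be enforced by the caller.
    finite = np.isfinite(x) & np.isfinite(w)
    x, w = x[finite], w[finite]
    sw = w.sum()
    m = np.sum(w * x) / sw                 # weighted mean
    denom = sw - np.sum(w ** 2) / sw       # Bessel correction for the weighted variance
    v = np.sum(w * (x - m) ** 2) / denom
    if v <= m:                             # var <= mean: signal the Poisson fallback
        return np.inf
    return float(np.clip(m ** 2 / (v - m), k_min, k_max))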

NB formula, as implemented in calc_log_likelihood_negbin():

LL = lgamma(y + k) - lgamma(k) - lgamma(y + 1) + k * log(k / (k + mu)) + y * log(mu / (k + mu))

Poisson fallback when k=Inf (var <= mean) — calc_log_likelihood_poisson(): LL = y * log(mu) - mu - lgamma(y + 1). Python must explicitly branch on k == Inf since scipy.stats.nbinom doesn't handle infinite size.

Zero-prediction proportional penalties: When est <= 0 and obs > 0: penalty = -obs * log(1e6) (~-13.8 per observed unit). When both zero: LL = 0. Epsilon floor for positive estimates: 1e-10.

Observed values are rounded before the integer check (cross-language float safety).
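
A hedged sketch of the element-level NB/Poisson log-likelihood with these penalties (time weights are omitted for brevity; the function name is illustrative and the R code at v0.22.0 remains the authoritative spec):

import numpy as np
from scipy.special import gammaln

def nb_loglik_elementwise(obs, est, k, eps=1e-10):
    # Per-element NB (or Poisson when k == inf) log-likelihood with penalties (sketch).
    y = np.round(np.asarray(obs, dtype=float))           # round before integer use
    mu = np.asarray(est, dtype=float)
    ll = np.zeros_like(y)

    zero_pred = (mu <= 0) & (y > 0)                      # zero prediction, nonzero observation
    ll[zero_pred] = -y[zero_pred] * np.log(1e6)          # ~ -13.8 per observed unit
    ll[(mu <= 0) & (y <= 0)] = 0.0                       # both zero: no contribution

    ok = mu > 0
    mu_ok = np.maximum(mu[ok], eps)                      # epsilon floor for positive estimates
    y_ok = y[ok]
    if np.isinf(k):                                      # Poisson fallback (var <= mean)
        ll[ok] = y_ok * np.log(mu_ok) - mu_ok - gammaln(y_ok + 1)
    else:
        ll[ok] = (gammaln(y_ok + k) - gammaln(k) - gammaln(y_ok + 1)
                  + k * np.log(k / (k + mu_ok)) + y_ok * np.log(mu_ok / (k + mu_ok)))
    return ll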

Key behaviors that apply across all components (main loop):

  • Minimum observation threshold: Requires >= 3 finite observations per location/outcome (have_cases/have_deaths). If fewer, that outcome's contribution is 0 for ALL components (core NB, peaks, cumulative, WIS), not just the core.
  • Weight masking: mask_weights() zeros out weights_time entries where obs or est is non-finite before passing to the NB function (see the sketch after this list).
  • k uses weights_time: The dispersion k is estimated via nb_size_from_obs_weighted(obs, weights_time, ...), not uniform weights.
  • Separate k_min for cases vs deaths: nb_k_min_cases and nb_k_min_deaths are independent parameters (both default 3).
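
A small sketch of the masking and minimum-observation behavior described in the list above (helper names mirror the R side but are illustrative):

import numpy as np

def mask_weights(w, obs, est):
    # Zero out time weights wherever obs or est is non-finite (sketch).
    w = np.array(w, dtype=float, copy=True)
    w[~(np.isfinite(obs) & np.isfinite(est))] = 0.0
    return w

def has_enough_obs(obs, min_n=3):
    # An outcome contributes (to every component) only with >= min_n finite observations.
    return int(np.isfinite(obs).sum()) >= min_n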

2. Peak timing term (weight > 0 to enable) — .calc_peak_timing_from_indices()

Uses precomputed peak indices to score how well the model places outbreaks in time:

For each known peak index in observed data:
  window = peak_idx +/- 14 timesteps
  est_peak_idx = argmax(est_vec[window])
  time_diff = (est_peak_idx - peak_idx) / timestep_to_weeks
  LL += dnorm(time_diff, 0, sigma_peak_time=1, log=TRUE)

Unmatched peaks (window < 3 points) are skipped (contribute 0).

timestep_to_weeks = 7 for daily data, 1 for weekly data (auto-detected from date sequence).

Default weight: 0 (OFF). Suggested operational value: 0.25.
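
A short Python sketch of the peak-timing term, assuming 0-based peak indices and simple clamping of the window at the series boundaries (both assumptions; the R helper is authoritative):

import numpy as np
from scipy.stats import norm

def peak_timing_ll(obs_peaks, est_vec, sigma_peak_time=1.0, timestep_to_weeks=7, window=14):
    # Normal log-density of the (in weeks) timing offset at each known observed peak (sketch).
    ll = 0.0
    for peak_idx in obs_peaks:
        lo, hi = max(0, peak_idx - window), min(len(est_vec), peak_idx + window + 1)
        if hi - lo < 3:                                   # unmatched peak: skip
            continue
        est_peak_idx = lo + int(np.argmax(est_vec[lo:hi]))
        time_diff = (est_peak_idx - peak_idx) / timestep_to_weeks
        ll += norm.logpdf(time_diff, loc=0.0, scale=sigma_peak_time)
    return ll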

3. Peak magnitude term (weight > 0 to enable) — .calc_peak_magnitude_from_indices()

Scores how well the model matches outbreak peak heights with adaptive tolerance:

For each matched peak (same +/- 14 window):
  obs_peak_val = max(obs_vec[window])
  est_peak_val = max(est_vec[window])
  adaptive_sigma = sigma_peak_log * sqrt(100 / max(obs_peak_val, 100))
  LL += dnorm(log(est_peak) - log(obs_peak), 0, adaptive_sigma, log=TRUE)

The adaptive sigma gives more tolerance to small peaks (inherently noisier). sigma_peak_log default: 0.5.

Default weight: 0 (OFF). Suggested operational value: 0.25.
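
Analogously, a sketch of the peak-magnitude term (the zero-peak guard is an assumption added so the logs are defined):

import numpy as np
from scipy.stats import norm

def peak_magnitude_ll(obs_peaks, obs_vec, est_vec, sigma_peak_log=0.5, window=14):
    # Score peak heights on the log scale with an adaptive sigma (sketch).
    ll = 0.0
    for peak_idx in obs_peaks:
        lo, hi = max(0, peak_idx - window), min(len(obs_vec), peak_idx + window + 1)
        if hi - lo < 3:                                   # unmatched peak: skip
            continue
        obs_peak = float(np.max(obs_vec[lo:hi]))
        est_peak = float(np.max(est_vec[lo:hi]))
        if obs_peak <= 0 or est_peak <= 0:                # guard the logs (assumption)
            continue
        adaptive_sigma = sigma_peak_log * np.sqrt(100.0 / max(obs_peak, 100.0))
        ll += norm.logpdf(np.log(est_peak) - np.log(obs_peak), loc=0.0, scale=adaptive_sigma)
    return ll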

4. Cumulative progression term (weight > 0 to enable) — ll_cumulative_progressive_nb()

NB likelihood evaluated at cumulative fractions of the time series, with per-observation normalization:

For each tp in [0.25, 0.5, 0.75, 1.0]:
  end_idx = round(n_timesteps * tp)
  cum_obs = sum(obs[1:end_idx])
  cum_est = sum(est[1:end_idx])
  cum_k = k_data * end_idx                # dispersion scales with n summed
  LL_tp = dnbinom(round(cum_obs), mu=cum_est, size=cum_k, log=TRUE)
  LL_tp = LL_tp / end_idx                 # per-observation normalization

Return mean(LL_tp across timepoints)

The /end_idx normalization converts the cumulative NB LL (which grows with count magnitude) to a per-observation scale, making the helper output O(1). T-scaling at assembly then brings it to O(T) like other terms.

Zero-prediction handling: same proportional penalty as core NB (-obs * log(1e6) / end_idx).

Default weight: 0 (OFF). Suggested operational value: 0.25.
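
A sketch of the cumulative term using scipy's NB parameterization (size, p) with p = k / (k + mu); NaN handling is simplified:

import numpy as np
from scipy.stats import nbinom

def cumulative_progression_ll(obs, est, k_data, timepoints=(0.25, 0.5, 0.75, 1.0)):
    # Per-observation-normalized cumulative NB log-likelihood, averaged over timepoints (sketch).
    n = len(obs)
    per_tp = []
    for tp in timepoints:
        end_idx = int(round(n * tp))
        cum_obs = float(np.nansum(obs[:end_idx]))
        cum_est = float(np.nansum(est[:end_idx]))
        if cum_est <= 0:
            ll_tp = -cum_obs * np.log(1e6) if cum_obs > 0 else 0.0   # proportional penalty
        else:
            cum_k = k_data * end_idx                     # dispersion scales with n summed
            p = cum_k / (cum_k + cum_est)                # scipy: size = cum_k, p = size/(size+mu)
            ll_tp = nbinom.logpmf(int(round(cum_obs)), cum_k, p)
        per_tp.append(ll_tp / end_idx)                   # per-observation normalization
    return float(np.mean(per_tp))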

5. Weighted Interval Score (weight > 0 to enable) — compute_wis_parametric_row()

Probabilistic scoring using parametric NB/Poisson quantiles per Bracher et al. (2021):

quantiles = [0.025, 0.25, 0.5, 0.75, 0.975]
For each timestep:
  qfun = qnbinom(p, mu=est, size=k) if k < Inf else qpois(p, lambda=est)

MAE_term = 0.5 * weighted_mean(|y - qfun(0.5)|)     # NOTE: 0.5 coefficient
For each quantile pair (pL, pU = 1 - pL):           # alpha = 2 * pL
  IS = (qU - qL) + (2/alpha) * max(0, qL - y) + (2/alpha) * max(0, y - qU)
  sum_IS += (alpha/2) * weighted_mean(IS)

WIS = (MAE_term + sum_IS) / (K + 0.5)               # K = number of quantile pairs

Quantile pair-matching uses tolerance-based comparison (abs(uppers - (1-p)) < 1e-8), not exact float equality. WIS is negated to contribute as a log-likelihood term (it's a loss; lower = better). Estimated values floored at 1e-12 before computing quantiles. Returns NA (treated as 0 contribution) when all data is non-finite.

Default weight: 0 (OFF). Suggested operational value: 0.10 (provides trajectory-shape regularization; ablation tests show this is the most effective shape term).
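
A sketch of the WIS computation; non-finite masking and the tolerance-based quantile pairing are simplified here (pairs are formed directly as (p, 1 - p)):

import numpy as np
from scipy.stats import nbinom, poisson

def wis_parametric(y, mu, k, w=None, quantiles=(0.025, 0.25, 0.5, 0.75, 0.975)):
    # Weighted Interval Score from parametric NB/Poisson quantiles (sketch, per Bracher et al. 2021).
    y = np.asarray(y, dtype=float)
    mu = np.maximum(np.asarray(mu, dtype=float), 1e-12)   # floor estimates before quantiles
    w = np.ones_like(y) if w is None else np.asarray(w, dtype=float)

    def qfun(p):
        if np.isinf(k):
            return poisson.ppf(p, mu)
        return nbinom.ppf(p, k, k / (k + mu))

    def wmean(x):
        return np.sum(w * x) / np.sum(w)

    lowers = [p for p in quantiles if p < 0.5]
    score = 0.5 * wmean(np.abs(y - qfun(0.5)))             # MAE term, 0.5 coefficient
    for p_lo in lowers:
        alpha = 2.0 * p_lo
        q_lo, q_hi = qfun(p_lo), qfun(1.0 - p_lo)
        interval = ((q_hi - q_lo)
                    + (2 / alpha) * np.maximum(0, q_lo - y)
                    + (2 / alpha) * np.maximum(0, y - q_hi))
        score += (alpha / 2.0) * wmean(interval)
    wis = score / (len(lowers) + 0.5)                      # K + 0.5 divisor
    return -wis                                            # negate: WIS is a loss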

Assembly formula — calc_model_likelihood() lines 245-284

T = n_time_steps
peak_scale = T / N_peaks_j      # per-location; 0 if no peaks

ll_loc = wc * NB_cases + wd * NB_deaths                                       [core, O(T)]
       + peak_scale * w_pt  * (wc * pt_c  + wd * pt_d)                        [peaks, O(T)]
       + peak_scale * w_pm  * (wc * pm_c  + wd * pm_d)                        [peaks, O(T)]
       + T          * w_cum * (wc * cum_c + wd * cum_d)                        [cumulative, O(T)]
       + T          * w_wis * (wc * wis_c + wd * wis_d)                        [WIS, O(T)]

ll_total = sum(weights_location[j] * ll_loc[j])

Shape terms with weight=0 contribute nothing (no computation is skipped, but 0 * anything = 0). weight_cases/weight_deaths apply multiplicatively to EVERY component. Non-finite per-location LL is replaced with -Inf (zero importance weight).
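
Putting it together, a per-location assembly sketch matching the formula above (argument names are illustrative, not the final API):

import numpy as np

def assemble_ll(T, nb_c, nb_d, pt_c, pt_d, pm_c, pm_d, cum_c, cum_d, wis_c, wis_d,
                n_peaks, wc=1.0, wd=1.0, w_pt=0.0, w_pm=0.0, w_cum=0.0, w_wis=0.0,
                weights_location=None):
    # Per-location assembly with T-normalized shape terms (sketch of the formula above).
    peak_scale = np.where(n_peaks > 0, T / np.maximum(n_peaks, 1), 0.0)  # 0 if no peaks
    ll_loc = (wc * nb_c + wd * nb_d
              + peak_scale * w_pt  * (wc * pt_c  + wd * pt_d)
              + peak_scale * w_pm  * (wc * pm_c  + wd * pm_d)
              + T          * w_cum * (wc * cum_c + wd * cum_d)
              + T          * w_wis * (wc * wis_c + wd * wis_d))
    ll_loc = np.where(np.isfinite(ll_loc), ll_loc, -np.inf)   # non-finite -> -Inf
    wloc = np.ones_like(ll_loc) if weights_location is None else weights_location
    return float(np.sum(wloc * ll_loc))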

Suggested function signature

def get_model_likelihood(
    obs_cases,                  # np.ndarray [n_locations x n_time_steps]
    sim_cases,                  # np.ndarray [n_locations x n_time_steps]
    obs_deaths,                 # np.ndarray [n_locations x n_time_steps]
    sim_deaths,                 # np.ndarray [n_locations x n_time_steps]
    weight_cases=1.0,
    weight_deaths=1.0,
    weights_location=None,      # [n_locations] or None for uniform
    weights_time=None,          # [n_time_steps] or None for uniform
    # --- shape term weights (0 = OFF; 0.25 = 25% of NB core influence) ---
    weight_peak_timing=0.0,
    weight_peak_magnitude=0.0,
    weight_cumulative_total=0.0,
    weight_wis=0.0,
    # --- peak controls ---
    sigma_peak_time=1.0,
    sigma_peak_log=0.5,
    peak_indices_by_loc=None,   # list of int arrays, precomputed from MOSAIC::epidemic_peaks
    timestep_to_weeks=7,        # 7 for daily data, 1 for weekly
    # --- WIS ---
    wis_quantiles=(0.025, 0.25, 0.5, 0.75, 0.975),
    # --- cumulative ---
    cumulative_timepoints=(0.25, 0.5, 0.75, 1.0),
    # --- NB controls ---
    nb_k_min_cases=3,
    nb_k_min_deaths=3,
    verbose=False,
) -> float:

No separate boolean toggles — weight > 0 enables the term.

Implementation notes

Peak indices: Peak timing/magnitude terms need precomputed peak indices per location. The R orchestrator will compute these from the epidemic_peaks dataset and pass them as peak_indices_by_loc (list of int arrays) via the scattered base config. No need to ship the dataset to Python.

Likelihood settings: All likelihood parameters (weights, sigma values) are driven by the mosaic_control_defaults() control object in run_MOSAIC(). The R orchestrator will serialize the relevant likelihood settings and pass them to the Python worker alongside the base config.

Config serialization: The R package has round-trip config-to-vector conversion (convert_config_to_matrix() / convert_matrix_to_config()) that flattens location-specific vectors with ISO suffixes (e.g., beta_j0_tot_ETH). The param vector returned by workers uses this naming convention.

Dask integration: Once get_model_likelihood() is upgraded, mosaic_dask_worker.py calls it after each LASER iteration and returns the scalar LL instead of matrices. The R orchestrator changes are a separate PR in MOSAIC-pkg.
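
Illustratively, the worker-side flow could look like the sketch below; run_laser() is a placeholder for the existing LASER call, and the task name and return schema are assumptions, not the current mosaic_dask_worker.py API:

def run_one_sim(sim_id, param_vector, base_config):
    # Hypothetical worker task: run LASER, score on-worker, return a small record.
    results = run_laser(base_config, param_vector)           # placeholder for the LASER call
    ll = get_model_likelihood(
        obs_cases=base_config["obs_cases"], sim_cases=results["cases"],
        obs_deaths=base_config["obs_deaths"], sim_deaths=results["deaths"],
    )
    return {"sim_id": sim_id, "log_likelihood": ll, "param_vector": param_vector}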

Numerical equivalence reference values

The R test suite (test-calc_model_likelihood_reference.R) pins exact numerical outputs. The Python implementation must match these within floating-point tolerance. Run this in R to reproduce:

library(MOSAIC)

# Reference data: 2 locations x 10 timesteps
obs_c <- matrix(c(10,20,30,40,50,60,70,80,90,100,
                   5,10,15,20,25,30,35,40,45, 50), nrow = 2, byrow = TRUE)
est_c <- matrix(c(12,18,35,38,55,58,72,78,88,105,
                   6, 9,14,22,23,32,33,42,43, 52), nrow = 2, byrow = TRUE)
obs_d <- round(obs_c * 0.05)
est_d <- round(est_c * 0.05)

# Core NB only (all weights default to 0)
calc_model_likelihood(obs_c, est_c, obs_d, est_d)
# [1] -100.9024

# Core NB + cumulative (weight_cumulative_total=0.25)
calc_model_likelihood(obs_c, est_c, obs_d, est_d, weight_cumulative_total = 0.25)
# [1] -107.0092

# Core NB + WIS (weight_wis=0.10)
calc_model_likelihood(obs_c, est_c, obs_d, est_d, weight_wis = 0.10)
# [1] -110.5464

# Perfect match
calc_model_likelihood(obs_c, obs_c, obs_d, obs_d)
# [1] -99.2949

# Element-level NB (k=3)
calc_log_likelihood_negbin(
  observed = c(10, 20, 30, 40, 50), estimated = c(12, 18, 35, 38, 55),
  k = 3, weights = NULL, verbose = FALSE)
# [1] -18.7032

# Element-level Poisson
calc_log_likelihood_poisson(
  observed = c(10, 20, 30), estimated = c(10, 20, 30),
  weights = NULL, verbose = FALSE)
# [1] -7.1218

All values should match to tolerance 1e-4.

Full R test suite — port these to pytest (a skeleton for the pinned reference values is sketched after the list):
MOSAIC-pkg/tests/testthat/

  • test-calc_model_likelihood_reference.R — pinned numerical values (start here)
  • test-calc_model_likelihood.R — structural and integration tests
  • test-calc_model_likelihood_extreme.R — extreme inputs and edge cases
  • test-calc_log_likelihood_negbin.R — NB distribution function
  • test-nb_size_from_obs_weighted.R — dispersion estimation
  • test-compute_wis_parametric_row.R — WIS scoring
  • test-ll_cumulative_progressive_nb.R — cumulative progression
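
A pytest skeleton for the pinned reference values could start like this (the import path is an assumption; the matrices transcribe the R reference data above):

import numpy as np
import pytest

from laser_cholera.likelihood import get_model_likelihood  # assumed import path

obs_c = np.array([[10, 20, 30, 40, 50, 60, 70, 80, 90, 100],
                  [ 5, 10, 15, 20, 25, 30, 35, 40, 45,  50]], dtype=float)
est_c = np.array([[12, 18, 35, 38, 55, 58, 72, 78, 88, 105],
                  [ 6,  9, 14, 22, 23, 32, 33, 42, 43,  52]], dtype=float)
obs_d = np.round(obs_c * 0.05)
est_d = np.round(est_c * 0.05)

@pytest.mark.parametrize("kwargs,expected", [
    ({}, -100.9024),                                    # core NB only
    ({"weight_cumulative_total": 0.25}, -107.0092),     # core NB + cumulative
    ({"weight_wis": 0.10}, -110.5464),                  # core NB + WIS
])
def test_reference_values(kwargs, expected):
    ll = get_model_likelihood(obs_c, est_c, obs_d, est_d, **kwargs)
    assert ll == pytest.approx(expected, abs=1e-4)

def test_perfect_match():
    ll = get_model_likelihood(obs_c, obs_c, obs_d, obs_d)
    assert ll == pytest.approx(-99.2949, abs=1e-4)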
