ML-Benchmark prediction submission: Prithvi-UNet-orog #46

@midatm1234

Description

Institution ID

JPL

Institution Name

Jet Propulsion Laboratory

Emulator identifier (emulator_id)

Prithvi-UNet-orog

Emulator Description

Our emulator, Prithvi-UNet, is a Prithvi-WxC-based statistical downscaling model for CORDEX-ML experiments. It maps coarse atmospheric predictors (u, v, q, t, z) at 850, 700, and 500 hPa, together with static orography, to high-resolution precipitation (pr) and daily maximum temperature (tasmax). In the current configuration, the model uses a Prithvi-WxC encoder backbone (best_rmse_UNET_small.pt from IBM's granite-wxc repo) with a convolutional encoder-decoder downscaling head and pixel-shuffle upsampling.

Workflow
Compute training-set scalars for predictors and targets → preprocess/regrid data onto the training target grid → train/fine-tune the downscaling model → generate predictions. The saved scalar files are then reused during training and inference.

Preprocessing and normalization
Two preprocessing components are central to this emulator setup. First, the model uses saved normalization arrays for inputs_mean, inputs_std, targets_mean, and targets_std, and the model configuration points input_mu, input_sigma, target_mu, and target_sigma to those same files. The scalar-generation script computes these quantities channel-wise from the training data.
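For illustration, the channel-wise scalar computation could look like the following sketch. The array names and shapes here are hypothetical; the actual scalar-generation script in the repo may differ:

```python
import numpy as np

# Hypothetical training tensors: (samples, channels, lat, lon)
train_inputs = np.random.rand(8, 16, 32, 32).astype(np.float32)
train_targets = np.random.rand(8, 2, 128, 128).astype(np.float32)

# Channel-wise statistics over all samples and grid points;
# these would be saved (e.g. as .npy files) and pointed to by
# input_mu / input_sigma / target_mu / target_sigma in the config.
inputs_mean = train_inputs.mean(axis=(0, 2, 3))
inputs_std = train_inputs.std(axis=(0, 2, 3))
targets_mean = train_targets.mean(axis=(0, 2, 3))
targets_std = train_targets.std(axis=(0, 2, 3))
```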

Second, the coarse predictors are regridded before training and inference so they are spatially aligned with the high-resolution target grid. Rather than describing this as part of the emulator architecture itself, it is more accurate to describe regridding as a preprocessing step in the CORDEX_ML workflow.
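The submission does not specify the interpolation scheme used for regridding, so the following is only a minimal sketch of aligning a coarse field with a higher-resolution target grid via bilinear upsampling (the actual workflow may use a dedicated regridding tool):

```python
import numpy as np
from scipy.ndimage import zoom

# Hypothetical coarse predictor field on a 32x32 grid
coarse = np.random.rand(32, 32)

# Bilinear interpolation (order=1) onto a 4x finer 128x128 grid
fine = zoom(coarse, 4, order=1)
```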

Target scaling details

In the current v6 configuration, dynamic predictors use training-set channel normalization:

x_norm = (x - input_mu) / (input_sigma + eps)

For target variables, tasmax uses gridpoint z-score normalization:

tasmax_norm = (tasmax - tasmax_mu) / tasmax_sigma

Precipitation (pr) does not use z-score or log1p normalization in v6. It uses divide-only scaling by the training-set global 95th percentile, pr_95th, with target_mu(pr)=0 and target_sigma(pr)=pr_95th:

pr_norm = pr / pr_95th
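The three normalizations above can be written directly as functions. The value of `eps` is an assumption; the submission does not state it:

```python
EPS = 1e-6  # assumed small constant; exact value not stated above

def normalize_inputs(x, input_mu, input_sigma):
    # Training-set channel normalization for dynamic predictors
    return (x - input_mu) / (input_sigma + EPS)

def normalize_tasmax(tasmax, tasmax_mu, tasmax_sigma):
    # Gridpoint z-score normalization for tasmax
    return (tasmax - tasmax_mu) / tasmax_sigma

def normalize_pr(pr, pr_95th):
    # Divide-only scaling: target_mu(pr) = 0, target_sigma(pr) = pr_95th
    return pr / pr_95th
```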

The precipitation output is modeled with a hurdle formulation: an occurrence branch predicts wet probability and an amount branch predicts positive precipitation amount. The raw amount output, z_amount, is passed through softplus:

amount_norm = log(1 + exp(z_amount))

The wet probability is:

p_wet = sigmoid(z_wet)

The denormalized precipitation prediction is:

if p_wet >= 0.5:
    pr = pr_95th * amount_norm
else:
    pr = 0
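Putting the inference-time pieces together, the hurdle denormalization can be sketched in NumPy as follows (function names are illustrative, not taken from the repo):

```python
import numpy as np

def softplus(z):
    # Numerically stable log(1 + exp(z))
    return np.logaddexp(0.0, z)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def denormalize_pr(z_wet, z_amount, pr_95th):
    """Hurdle-model inference: zero precipitation unless p_wet >= 0.5."""
    p_wet = sigmoid(z_wet)
    amount_norm = softplus(z_amount)
    return np.where(p_wet >= 0.5, pr_95th * amount_norm, 0.0)
```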

During training, the precipitation-specific loss is the sum of an occurrence loss and an amount loss. The occurrence target is wet = 1 when pr > 0.01, otherwise wet = 0.

occurrence_loss = BCEWithLogits(z_wet, wet)

amount_target_norm = pr / pr_95th

amount_loss = SmoothL1(amount_norm, amount_target_norm), computed on wet pixels only

L_pr = occurrence_loss + amount_loss

The occurrence and amount loss weights are both 1.0.
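The precipitation loss can be illustrated with NumPy stand-ins for PyTorch's BCEWithLogits and SmoothL1 losses (a sketch only; the actual training code presumably uses the torch implementations):

```python
import numpy as np

def bce_with_logits(z, y):
    # Stable binary cross-entropy with logits, averaged over all pixels
    return np.mean(np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z))))

def smooth_l1(pred, target, beta=1.0):
    d = np.abs(pred - target)
    return np.where(d < beta, 0.5 * d**2 / beta, d - 0.5 * beta)

def precipitation_loss(z_wet, z_amount, pr, pr_95th, wet_thresh=0.01):
    wet = (pr > wet_thresh).astype(float)        # occurrence target
    occurrence_loss = bce_with_logits(z_wet, wet)
    amount_norm = np.logaddexp(0.0, z_amount)    # softplus of raw amount output
    amount_target_norm = pr / pr_95th
    mask = wet.astype(bool)                      # amount loss on wet pixels only
    amount_loss = smooth_l1(amount_norm[mask], amount_target_norm[mask]).mean() if mask.any() else 0.0
    return occurrence_loss + amount_loss         # both loss weights are 1.0
```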

Hardware and Training Details

Fine-tuning for ten epochs required about 30 hours, and inference over the 20-year evaluation period required approximately two hours on a system with four NVIDIA A10 GPUs.

Stochastic/Probabilistic Output

no

Reference URL

https://arxiv.org/abs/2409.13598

Repository for Reproducibility

https://github.com/midatm1234/granite-wxc/tree/CORDEX_ML

Additional Notes

Submission without orography: #14

v2 submission addressed the following issues.

  • NaN predictor values in the South Africa domain.
  • Mishandling of the orography input in part of the workflow.

v3 submission includes the following updates.

  • Improved smoothness in the predicted tasmax distributions
  • Implementation of softplus normalization for pr

v4 Submission

  • Extended Fine-tuning: Increased training duration from 5 epochs (v3) to 10 epochs.

v5 Submission

  • log1p precipitation normalization
  • Batch size is 32
  • Loss function includes distribution differences

v6 Submission (current)

  • Predictor normalization: x_norm = (x - input_mu) / (input_sigma + eps)
  • tasmax uses gridpoint z-score normalization.
  • pr uses divide-only scaling by the training-set global 95th percentile.
  • pr is modeled with a hurdle head: wet occurrence probability plus positive amount prediction.
  • Precipitation loss combines BCEWithLogits occurrence loss and SmoothL1 amount loss on wet pixels only.
  • Wet threshold changed from 0.5 to 0.01; inference wet-probability threshold is 0.5.
