diff --git a/crates/argmin/src/solver/neldermead/mod.rs b/crates/argmin/src/solver/neldermead/mod.rs index 3a1d04971..a6b9068f1 100644 --- a/crates/argmin/src/solver/neldermead/mod.rs +++ b/crates/argmin/src/solver/neldermead/mod.rs @@ -401,6 +401,12 @@ where } } } else { + if xr_cost.is_nan() { + return Err(argmin_error!( + ConditionViolated, + format!("`NelderMead`: Cost function returned NaN") + )); + } return Err(argmin_error!( PotentialBug, "`NelderMead`: Reached unreachable point." @@ -868,4 +874,20 @@ mod tests { assert_relative_eq!(nm.params[2].0[1], 0.0f64, epsilon = f64::EPSILON); assert_relative_eq!(nm.params[2].1, 1.00f64, epsilon = f64::EPSILON); } + + #[test] + fn test_nan_error() { + let params: Vec> = vec![vec![-1.0, 0.0], vec![0.0, 1.0], vec![f64::NAN, f64::NAN]]; + let mut nm: NelderMead<_, f64> = NelderMead::new(params); + let state: IterState, (), (), (), (), f64> = IterState::new(); + let mut problem = Problem::new(MwProblem {}); + let (state, _) = nm.init(&mut problem, state).unwrap(); + + let res = nm.next_iter(&mut problem, state); + assert_error!( + res, + ArgminError, + "Condition violated: \"`NelderMead`: Cost function returned NaN\"" + ); + } }