Skip to content

Commit 65bd653

Browse files
authored
prevent erroneous array in ppf with _scalar_if_array_all_equal (#5)
* prevent erroneous array in `ppf` with _scalar_if_array_all_equal * PR comments
1 parent ab5f580 commit 65bd653

2 files changed

Lines changed: 14 additions & 4 deletions

File tree

betapert/funcs.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,18 @@ def log_cdf_eq(x_normalized):
8383
}
8484

8585

86+
def _scalar_if_array_all_equal(array: np.ndarray | float) -> np.float64 | np.ndarray | float:
87+
if isinstance(array, np.ndarray) and array.size != 0 and np.all(array == array[0]):
88+
return array[0]
89+
return array
90+
91+
8692
def _calc_alpha_beta(
8793
mini: np.float64 | np.ndarray | float,
8894
mode: np.float64 | np.ndarray | float,
8995
maxi: np.float64 | np.ndarray | float,
9096
lambd: np.float64 | np.ndarray | float,
91-
) -> tuple[np.float64 | np.ndarray | float | int, np.float64 | np.ndarray | float | int]:
97+
) -> tuple[np.float64 | np.ndarray | float, np.float64 | np.ndarray | float]:
9298
"""Calculate alpha and beta parameters for the underlying beta distribution.
9399
94100
Args:
@@ -107,10 +113,9 @@ def _calc_alpha_beta(
107113
if DEBUG and any(isinstance(x, np.ndarray) for x in (mini, mode, maxi, lambd)):
108114
sys.stderr.write("CAB: unexpected arrays in method parameters\n")
109115
if isinstance(alpha, np.ndarray) and isinstance(beta, np.ndarray):
110-
if np.all(alpha == alpha[0]) and np.all(beta == beta[0]):
111-
return alpha[0], beta[0]
112116
if DEBUG:
113117
sys.stderr.write(f"CAB: Unexpected arrays: alpha={alpha}, beta={beta}\n")
118+
return _scalar_if_array_all_equal(alpha), _scalar_if_array_all_equal(beta)
114119
return alpha, beta
115120

116121

@@ -130,6 +135,11 @@ def sf(x, mini, mode, maxi, lambd=4):
130135

131136

132137
def ppf(q, mini, mode, maxi, lambd=4, *, fallback=None):
138+
mini = _scalar_if_array_all_equal(mini)
139+
mode = _scalar_if_array_all_equal(mode)
140+
maxi = _scalar_if_array_all_equal(maxi)
141+
lambd = _scalar_if_array_all_equal(lambd)
142+
133143
alpha, beta = _calc_alpha_beta(mini, mode, maxi, lambd)
134144
_beta_ppf = mini + (maxi - mini) * scipy.stats.beta.ppf(q, alpha, beta)
135145
# Use fallback if any values are NaN

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[tool]
22
[tool.poetry]
33
name = "beta-pert-dist-scipy"
4-
version = "0.2.0"
4+
version = "0.2.1"
55
homepage = "https://github.com/hbmartin/betapert"
66
description = "Top-level package for beta-PERT distribution."
77
authors = ["Tom Adamczewski <tadamczewskipublic@gmail.com>", "Harold Martin harold.martin@gmail.com"]

0 commit comments

Comments
 (0)