Skip to content

Commit 3db5815

Browse files
authored
Add log space fallback for PPF func (#2)
* Add log space fallback for PPF func * PR feedback * small fixes * add global fallback flag to avoid scipy instantiation problems * poetry lock * pull function definition out of loop * randomize test order to catch any global fallback flag issues * Bump version=0.1.7 * don't unbox to python type for a single scalar argument * formatting
1 parent 8f03cad commit 3db5815

7 files changed

Lines changed: 449 additions & 32 deletions

File tree

.github/workflows/pytest-poetry.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ jobs:
2525
cache-dependency-path: poetry.lock
2626
- name: Install dependencies
2727
run: poetry install --no-interaction
28-
- name: Run tests
29-
run: poetry run pytest
3028
- name: run black
3129
run: poetry run black . --check
3230
- name: Run ruff
33-
run: poetry run ruff check .
31+
run: poetry run ruff check .
32+
- name: Run tests
33+
run: poetry run pytest --random-order

betapert/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
from betapert import funcs
1111

12+
FALLBACK = None
13+
1214

1315
class PERT(scipy.stats.rv_continuous):
1416
"""The `PERT distribution <https://en.wikipedia.org/wiki/PERT_distribution>`_ is defined by the
@@ -60,7 +62,7 @@ def _stats(self, mini, mode, maxi):
6062
return funcs.stats(mini, mode, maxi)
6163

6264
def _ppf(self, q, mini, mode, maxi):
63-
return funcs.ppf(q, mini, mode, maxi)
65+
return funcs.ppf(q, mini, mode, maxi, fallback=FALLBACK)
6466

6567
def _rvs(self, mini, mode, maxi, size=None, random_state=None):
6668
return funcs.rvs(mini, mode, maxi, size=size, random_state=random_state)
@@ -116,7 +118,7 @@ def _stats(self, mini, mode, maxi, lambd):
116118
return funcs.stats(mini, mode, maxi, lambd)
117119

118120
def _ppf(self, q, mini, mode, maxi, lambd):
119-
return funcs.ppf(q, mini, mode, maxi, lambd)
121+
return funcs.ppf(q, mini, mode, maxi, lambd, fallback=FALLBACK)
120122

121123
def _rvs(self, mini, mode, maxi, lambd, size=None, random_state=None):
122124
return funcs.rvs(mini, mode, maxi, lambd, size=size, random_state=random_state)

betapert/funcs.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,63 @@
22
33
This module contains the core mathematical functions used by the PERT and modified PERT distribution
44
classes. Each function takes the distribution parameters (minimum, mode, maximum, and optionally
5-
lambda) and implementsa specific statistical operation like pdf, cdf, etc.
5+
lambda) and implements a specific statistical operation like pdf, cdf, etc.
66
"""
77

88
import numpy as np
9+
import scipy.optimize
910
import scipy.stats
1011

12+
# Avoid log(0) or log(1) which would cause -inf or 0
13+
_CLIP_EPSILON = 1e-15
14+
_BRENTQ_BOUND = 1e-10
15+
16+
17+
def _ppf_fallback_log_space(q, mini, mode, maxi, lambd):
18+
"""Use log-space to avoid numerical issues with extreme probabilities"""
19+
alpha, beta = _calc_alpha_beta(mini, mode, maxi, lambd)
20+
21+
# Handle scalar and array inputs consistently
22+
_q = np.atleast_1d(q)
23+
results = np.zeros_like(_q, dtype=float)
24+
25+
# Define the equation to solve: log(CDF(x)) - log(q) = 0
26+
def make_log_cdf_eq(qi_val):
27+
log_qi = np.log(np.clip(qi_val, _CLIP_EPSILON, 1 - _CLIP_EPSILON))
28+
29+
def log_cdf_eq(x_normalized):
30+
# Ensure x_normalized stays in [0,1]
31+
x_clamped = np.clip(x_normalized, _CLIP_EPSILON, 1 - _CLIP_EPSILON)
32+
return scipy.stats.beta.logcdf(x_clamped, alpha, beta) - log_qi
33+
34+
return log_cdf_eq
35+
36+
for i, qi in enumerate(_q.flat):
37+
try:
38+
# Use brentq instead of fsolve, guaranteed convergence within bounds
39+
x_normalized = scipy.optimize.brentq(
40+
make_log_cdf_eq(qi),
41+
_BRENTQ_BOUND,
42+
1 - _BRENTQ_BOUND,
43+
)
44+
results.flat[i] = mini + (maxi - mini) * x_normalized
45+
46+
except (ValueError, RuntimeError):
47+
# ValueError: Invalid function values, convergence issues, or invalid bounds
48+
# RuntimeError: Maximum iterations exceeded, numerical problems
49+
# Fallback to clamped ppf if log-space fails
50+
qi_safe = np.clip(qi, _CLIP_EPSILON, 1 - _CLIP_EPSILON)
51+
x_normalized = scipy.stats.beta.ppf(qi_safe, alpha, beta)
52+
results[i] = mini + (maxi - mini) * x_normalized
53+
54+
# Returns scalar for scalar input, array for array input
55+
return results[0] if np.isscalar(q) else results
56+
57+
58+
_ppf_fallbacks = {
59+
"log": _ppf_fallback_log_space,
60+
}
61+
1162

1263
def _calc_alpha_beta(mini, mode, maxi, lambd):
1364
"""Calculate alpha and beta parameters for the underlying beta distribution.
@@ -42,9 +93,13 @@ def sf(x, mini, mode, maxi, lambd=4):
4293
return scipy.stats.beta.sf((x - mini) / (maxi - mini), alpha, beta)
4394

4495

45-
def ppf(q, mini, mode, maxi, lambd=4):
96+
def ppf(q, mini, mode, maxi, lambd=4, *, fallback=None):
4697
alpha, beta = _calc_alpha_beta(mini, mode, maxi, lambd)
47-
return mini + (maxi - mini) * scipy.stats.beta.ppf(q, alpha, beta)
98+
_beta_ppf = mini + (maxi - mini) * scipy.stats.beta.ppf(q, alpha, beta)
99+
# Use fallback if any values are NaN
100+
if fallback is not None and np.any(np.atleast_1d(np.isnan(_beta_ppf))):
101+
return _ppf_fallbacks[fallback](q, mini, mode, maxi, lambd)
102+
return _beta_ppf
48103

49104

50105
def isf(q, mini, mode, maxi, lambd=4):

poetry.lock

Lines changed: 35 additions & 20 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[tool]
22
[tool.poetry]
33
name = "beta-pert-dist-scipy"
4-
version = "0.1.6"
4+
version = "0.1.7"
55
homepage = "https://github.com/hbmartin/betapert"
66
description = "Top-level package for beta-PERT distribution."
77
authors = ["Tom Adamczewski <[email protected]>", "Harold Martin [email protected]"]
@@ -19,14 +19,15 @@ packages = [
1919

2020
[tool.poetry.dependencies]
2121
python = ">=3.11"
22-
scipy = "^1.14.0"
22+
scipy = ">=1.14.1"
2323

2424
[tool.poetry.group.dev.dependencies]
2525
coverage = "*"
2626
pytest = ">=7.2.0"
2727
black = {extras = ["d"], version = "*"}
2828
matplotlib = "^3.9.0"
2929
ruff = "^0.12.3"
30+
pytest-random-order = "^1.2.0"
3031

3132

3233

@@ -54,7 +55,7 @@ select = ["ALL"]
5455
ignore = ["ANN001", "ANN201", "ANN202", "ARG001", "D203", "D205", "D213", "D400", "D415", "PLR0913"]
5556

5657
[tool.ruff.lint.per-file-ignores]
57-
"tests/*.py" = ["ARG002", "D100", "D101", "D102", "D103", "D104", "D200", "D212", "D401", "D404", "E501", "E731", "NPY002", "PLR2004", "RET503", "S101"]
58+
"tests/*.py" = ["ANN002", "ANN003", "ARG002", "ANN206", "D100", "D101", "D102", "D103", "D104", "D200", "D212", "D401", "D404", "E501", "E731", "EM101", "NPY002", "PLR2004", "PT011", "RET503", "S101", "SLF001", "TRY003"]
5859
"betapert/funcs.py" = ["D103", "RUF002", "RUF003"]
5960

6061
[tool.ruff.format]

pytest.ini

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
[pytest]
2-
addopts = --doctest-modules
2+
addopts = --doctest-modules
3+
filterwarnings = error::pytest.PytestCollectionWarning

0 commit comments

Comments
 (0)