Commit 4efff76

Add LightGBM dtreeviz tree visualization support (#346)

* Add LightGBM dtreeviz dashboard support (closes #118)
* Document LightGBM decision-tree visualization support

1 parent 451357f commit 4efff76

File tree

12 files changed: +451 −25 lines changed

README.md

Lines changed: 1 addition & 1 deletion

@@ -104,7 +104,7 @@ The library includes:
 - *Permutation importances* (how much does the model metric deteriorate when you shuffle a feature?)
 - *Partial dependence plots* (how does the model prediction change when you vary a single feature?)
 - *Shap interaction values* (decompose the shap value into a direct effect and interaction effects)
-- For Random Forests and xgboost models: visualisation of individual decision trees
+- For Random Forest, XGBoost, and LightGBM models: visualisation of individual decision trees
 - Plus for classifiers: precision plots, confusion matrix, ROC AUC plot, PR AUC plot, etc.
 - For regression models: goodness-of-fit plots, residual plots, etc.

RELEASE_NOTES.md

Lines changed: 3 additions & 0 deletions

@@ -12,6 +12,8 @@
 - Fix XGBoost multiclass decision-path summary wording to display `prediction (logodds)` when explainer `model_output='logodds'`.
 - Fix issue #256: add robust multiclass probability fallback for classifiers that expose `decision_function` but not `predict_proba` (e.g. `LinearSVC`), and use it consistently across kernel SHAP, prediction helpers, PDP, and permutation scorer paths.
 - Prevent multiclass class-count mismatches when user-provided/broken `predict_proba` outputs do not match model class count by falling back to `decision_function`-based probabilities.
+- Fix issue #118: add LightGBM decision-tree visualization support (dtreeviz) across explainer auto-detection, tree plotting, and decision-path rendering in dashboard tree tabs.
+- Fix dtreeviz callback rendering on macOS by switching matplotlib to a non-interactive backend for off-main-thread tree rendering to prevent dashboard 500 errors.

 ### Tests
 - Add regression tests for LightGBM with string categorical features covering dashboard initialization, `get_shap_row(...)`, unseen categorical values in `X_row`, and regression dashboard initialization.
@@ -22,6 +24,7 @@
 - Add explainer-method unit tests for binary-like onehot detection, transformed feature-name deduping, inferred pipeline cats, and pipeline extraction warning text.
 - Add regression tests for issue #256 covering multiclass `LinearSVC` with kernel SHAP, PDP, and permutation-importances flows using `decision_function` fallback.
 - Add guard tests to confirm multiclass `predict_proba` models (logistic regression) keep working for PDP and permutation-importances paths.
+- Add LightGBM tree-visualization regression tests (shadow trees, decision paths, plot_trees, and dtreeviz render contracts) in the boosting-model test suite.

 ### Improvements
 - Add pipeline feature-name cleanup options: `strip_pipeline_prefix=True` and `feature_name_fn=...` for sklearn/imblearn pipeline transformed output columns.
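The macOS fix noted above rests on the fact that GUI matplotlib backends (such as "MacOSX") may only be driven from the main thread, while dashboard callbacks render trees on worker threads. A minimal stdlib sketch of that backend decision; `pick_backend` is a hypothetical helper for illustration, not the library's actual code:

```python
import threading

def pick_backend(current_backend: str) -> str:
    # GUI backends such as "MacOSX" can only be driven from the main thread.
    # Dashboard callbacks run on worker threads, so fall back to the
    # non-interactive "Agg" backend there (i.e. matplotlib.use("Agg"))
    # to avoid rendering errors that surface as dashboard 500s.
    if threading.current_thread() is threading.main_thread():
        return current_backend
    return "Agg"

main_choice = pick_backend("MacOSX")  # called from the main thread

worker_choice = []
t = threading.Thread(target=lambda: worker_choice.append(pick_backend("MacOSX")))
t.start()
t.join()
```

The key design point is that the backend switch only needs to apply to the off-main-thread rendering path; main-thread usage keeps whatever backend is configured.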

TODO.md

Lines changed: 5 additions & 5 deletions

@@ -12,22 +12,23 @@
 - [S][Hub][#146/#342] hub.to_yaml integrate_dashboard_yamls honors pickle_type and dumps integrated explainer artifacts.
 - [M][Explainers][#294] align/explain multiclass logodds between Contributions Plot and Prediction Box (+ PDP highlight and XGBoost decision path wording alignment).
 - [M][Explainers/Methods/Docs][#213] improve sklearn/imblearn pipeline support: feature-name cleanup (`strip_pipeline_prefix`, `feature_name_fn`), auto-detect onehot groups (`auto_detect_pipeline_cats`), accept binary-like scaled onehot columns in `cats`, preserve transformed index, add warnings/docs/tests.
+- [M][Explainers/Methods/Tests/Docs][#256] improve multiclass LinearSVC support/docs with decision_function probability fallback and regression coverage for SHAP/PDP/permutation flows.
+- [M][Explainers/Methods/Components/Tests][#118] add LightGBM tree visualization support (dtreeviz), including tree explainer wiring, dashboard tree tabs, and regression coverage.

 **Now**
-- [M][Explainers][#118] add LightGBM tree visualization support (dtreeviz).
+- [M][Dashboard][#161] more flexible instantiate_component (no explainer needed for non-ExplainerComponents).

 **Next**
-- [M][Dashboard][#263/#161] more flexible instantiate_component (no explainer needed for non-ExplainerComponents).
+- [M] add ExtraTrees and GradientBoostingClassifier to tree visualizers.

 **Backlog: Explainers**
 - [M] add plain language explanations for plots (in_words + UI toggle).
 - [S] pass n_jobs to pdp_isolate.
 - [M] add ExtraTrees and GradientBoostingClassifier to tree visualizers.
-- [M][#118] add LightGBM tree visualization support (dtreeviz).

 **Backlog: Dashboard**
 - [S] make poweredby right-aligned.
-- [M][#263/#161] more flexible instantiate_component (no explainer needed for non-ExplainerComponents).
+- [M][#161] more flexible instantiate_component (no explainer needed for non-ExplainerComponents).
 - [M] add TablePopout.
 - [M][#247] add EDA-style feature histograms/bar charts/correlation graphs.
 - [M/L] add cost calculator/optimizer for classifier models (confusion matrix weights, Youden J).
@@ -54,7 +55,6 @@
 - [M] support SamplingExplainer, PartitionExplainer, PermutationExplainer, AdditiveExplainer.
 - [M] support LimeTabularExplainer.
 - [M] investigate method from https://arxiv.org/abs/2006.04750.
-- [M][#256] improve multiclass LinearSVC support/docs (class-count mismatch with SHAP output).
 - [M][#229] clarify/add support path for Poisson and Gamma regression explainers.

 **Backlog: Plots**

docs/source/deployment.rst

Lines changed: 1 addition & 1 deletion

@@ -126,7 +126,7 @@ And you need to tell heroku how to start your server in ``Procfile``::
 Graphviz buildpack
 ------------------

-If you want to visualize individual trees inside your ``RandomForest`` or ``xgboost``
+If you want to visualize individual trees inside your ``RandomForest``, ``xgboost``, or ``lightgbm``
 model using the ``dtreeviz`` package you will
 need to make sure that ``graphviz`` is installed on your ``heroku`` dyno by
 adding the following buildpack (as well as the ``python`` buildpack):

docs/source/explainers.rst

Lines changed: 12 additions & 12 deletions

@@ -456,10 +456,10 @@ plot_residuals_vs_feature
 DecisionTree Plots
 ------------------

-There are additional mixin classes specifically for ``sklearn`` ``RandomForests``
-and for xgboost models that define additional methods and plots to investigate and visualize
-individual decision trees within the ensemblke. These
-uses the ``dtreeviz`` library to visualize individual decision trees.
+There are additional mixin classes specifically for ``sklearn`` ``RandomForests``,
+``xgboost``, and ``lightgbm`` models that define additional methods and plots to
+investigate and visualize individual decision trees within the ensemble. These
+use the ``dtreeviz`` library to visualize individual decision trees.

 You can get a pd.DataFrame summary of the path that a specific index row took
 through a specific decision tree.
@@ -476,9 +476,9 @@ And for dtreeviz visualization of individual decision trees (svg format)::
     explainer.decisiontree_file(tree_idx, index)
     explainer.decisiontree_encoded(tree_idx, index)

-These methods are part of the ``RandomForestExplainer`` and XGBExplainer`` mixin
-classes that get automatically loaded when you pass either a RandomForest
-or XGBoost model.
+These methods are part of the ``RandomForestExplainer``, ``XGBExplainer``, and
+``LGBMExplainer`` mixin classes that get automatically loaded when you pass a
+RandomForest, XGBoost, or LightGBM model.

 plot_trees
@@ -661,12 +661,12 @@ restrict candidate rows by feature values before selecting a random index::
 .. automethod:: explainerdashboard.explainers.RegressionExplainer.random_index

-RandomForest and XGBoost outputs
---------------------------------
+RandomForest, XGBoost, and LightGBM outputs
+-------------------------------------------

-For RandomForest and XGBoost models mixin classes that visualize individual
-decision trees will be loaded: ``RandomForestExplainer`` and ``XGBExplainer``
-with the following additional methods::
+For RandomForest, XGBoost, and LightGBM models mixin classes that visualize
+individual decision trees will be loaded: ``RandomForestExplainer``,
+``XGBExplainer``, and ``LGBMExplainer`` with the following additional methods::

     decisiontree_df(tree_idx, index, pos_label=None)
     decisiontree_summary_df(tree_idx, index, round=2, pos_label=None)

docs/source/index.rst

Lines changed: 2 additions & 1 deletion

@@ -12,7 +12,8 @@ with just two lines of code.

 It allows you to investigate SHAP values, permutation importances,
 interaction effects, partial dependence plots, all kinds of performance plots,
-and even individual decision trees inside a random forest. With ``explainerdashboard`` any data
+and even individual decision trees inside random forest, XGBoost, and LightGBM models.
+With ``explainerdashboard`` any data
 scientist can create an interactive explainable AI web app in minutes,
 without having to know anything about web development or deployment.

explainerdashboard/dashboard_components/decisiontree_components.py

Lines changed: 10 additions & 2 deletions

@@ -8,7 +8,7 @@
 from dash.exceptions import PreventUpdate
 import dash_bootstrap_components as dbc

-from ..explainers import RandomForestExplainer, XGBExplainer
+from ..explainers import RandomForestExplainer, XGBExplainer, LGBMExplainer
 from ..dashboard_methods import *
 from .. import to_html

@@ -94,12 +94,20 @@ def __init__(
         elif isinstance(self.explainer, XGBExplainer):
             if self.description is None:
-                self.description = """
-                Shows the marginal contributions of each decision tree in an 
+                self.description = """
+                Shows the marginal contributions of each decision tree in an
                 xgboost ensemble to the final prediction. This demonstrates that
                 an xgboost model is simply a sum of individual decision trees.
                 """
             if self.subtitle == "Displaying individual decision trees":
                 self.subtitle += " inside xgboost model"
+        elif isinstance(self.explainer, LGBMExplainer):
+            if self.description is None:
+                self.description = """
+                Shows the marginal contributions of each decision tree in a
+                LightGBM ensemble to the final prediction.
+                """
+            if self.subtitle == "Displaying individual decision trees":
+                self.subtitle += " inside LightGBM model"
         else:
             if self.description is None:
                 self.description = ""
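The per-model description/subtitle dispatch in this component boils down to an `isinstance` chain. A self-contained sketch; the explainer classes here are empty stubs standing in for the real mixins, and `tree_subtitle` is a hypothetical extraction, not the component's actual method:

```python
class RandomForestExplainer:  # stub stand-ins for the real explainer mixins
    pass

class XGBExplainer:
    pass

class LGBMExplainer:
    pass

DEFAULT_SUBTITLE = "Displaying individual decision trees"

def tree_subtitle(explainer, subtitle=DEFAULT_SUBTITLE):
    # Mirror the component's dispatch: only extend the subtitle when the
    # user has not overridden the default text.
    if subtitle != DEFAULT_SUBTITLE:
        return subtitle
    if isinstance(explainer, XGBExplainer):
        return subtitle + " inside xgboost model"
    if isinstance(explainer, LGBMExplainer):
        return subtitle + " inside LightGBM model"
    return subtitle

lgbm_subtitle = tree_subtitle(LGBMExplainer())
rf_subtitle = tree_subtitle(RandomForestExplainer())
```

Checking the explainer's concrete mixin class (rather than, say, the model object) keeps the dashboard component decoupled from the underlying model libraries.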

explainerdashboard/explainer_methods.py

Lines changed: 99 additions & 1 deletion

@@ -40,6 +40,7 @@
     "get_xgboost_path_df",
     "get_xgboost_path_summary_df",
     "get_xgboost_preds_df",
+    "get_lgbm_preds_df",
     "get_multiclass_logodds_scores",
     "get_xgboost_output_label",
     "_ensure_numeric_predictions",  # Internal helper for XGBoost 3.0+ compatibility

@@ -2165,7 +2166,14 @@ def node_pred_proba(node):
     else:

         def node_mean(node):
-            return decision_tree.tree_model.tree_.value[node.id].item()
+            try:
+                return decision_tree.tree_model.tree_.value[node.id].item()
+            except Exception:
+                node_samples = decision_tree.get_node_samples()
+                sample_idxs = node_samples.get(node.id, [])
+                if len(sample_idxs) == 0:
+                    return np.nan
+                return float(np.asarray(decision_tree.y_train)[sample_idxs].mean())

     for node in nodes:
         if not node.isleaf():
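The `node_mean` fallback above can be exercised in isolation: prefer the value stored on the tree node, and only average the training targets of the samples that reached the node when `.item()` fails. A sketch under stated assumptions; `node_mean_fallback` is a hypothetical free function mirroring the nested closure, fed plain arrays instead of a dtreeviz shadow tree:

```python
import numpy as np

def node_mean_fallback(tree_values, node_samples, y_train, node_id):
    # Prefer the value stored on the sklearn tree node; .item() raises
    # ValueError for multi-element value arrays (e.g. multi-output trees),
    # in which case we average the training targets of the samples that
    # reached this node, returning NaN for a node with no samples.
    try:
        return tree_values[node_id].item()
    except (ValueError, IndexError):
        sample_idxs = node_samples.get(node_id, [])
        if len(sample_idxs) == 0:
            return float("nan")
        return float(np.asarray(y_train)[sample_idxs].mean())

# Single-element node value: .item() succeeds.
direct = node_mean_fallback(np.array([[1.5]]), {}, [], 0)
# Two-element node value: .item() raises, so average y over samples 0 and 2.
averaged = node_mean_fallback(
    np.array([[1.0, 2.0]]), {0: [0, 2]}, [1.0, 5.0, 3.0], 0
)
```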
@@ -2549,3 +2557,93 @@ def get_xgboost_preds_df(xgbmodel, X_row, pos_label=1):
         0, "pred_proba"
     ]
     return xgboost_preds_df
+
+
+def get_lgbm_preds_df(lgbmodel, X_row, pos_label=1):
+    """Returns cumulative per-tree predictions for a LightGBM model.
+
+    Args:
+        lgbmodel: fitted LightGBM sklearn-compatible model
+            (i.e. LGBMClassifier or LGBMRegressor)
+        X_row: a single row of data, e.g X_train.iloc[0]
+        pos_label: for classifier the label to be used as positive label.
+            Defaults to 1.
+
+    Returns:
+        pd.DataFrame
+    """
+    if safe_isinstance(lgbmodel, "lightgbm.sklearn.LGBMClassifier"):
+        is_classifier = True
+        n_classes = len(lgbmodel.classes_)
+        n_trees = lgbmodel.booster_.num_trees()
+        if n_classes > 2:
+            n_trees = int(n_trees / n_classes)
+    elif safe_isinstance(lgbmodel, "lightgbm.sklearn.LGBMRegressor"):
+        is_classifier = False
+        n_trees = lgbmodel.booster_.num_trees()
+    else:
+        raise ValueError("Pass either an LGBMClassifier or LGBMRegressor!")
+
+    if is_classifier:
+        if n_classes == 2:
+            if pos_label not in (0, 1):
+                raise ValueError("pos_label should be either 0 or 1!")
+
+            margins = []
+            for i in range(1, n_trees + 1):
+                margin_raw = lgbmodel.predict(X_row, raw_score=True, num_iteration=i)[0]
+                margin_raw = _ensure_numeric_predictions(margin_raw)
+                if isinstance(margin_raw, np.ndarray):
+                    margin_raw = (
+                        margin_raw.item()
+                        if margin_raw.ndim == 0
+                        else float(margin_raw[0])
+                    )
+                margin = float(margin_raw)
+                margins.append(margin if pos_label == 1 else -margin)
+
+            pred_probas = (np.exp(margins) / (1 + np.exp(margins))).tolist()
+            base_score = 0.0
+            base_proba = 0.5
+            preds = margins
+        else:
+            if pos_label < 0 or pos_label >= n_classes:
+                raise ValueError(
+                    f"pos_label={pos_label}, but should be >= 0 and <= {n_classes - 1}!"
+                )
+            margins = []
+            for i in range(1, n_trees + 1):
+                margin_raw = lgbmodel.predict(X_row, raw_score=True, num_iteration=i)[0]
+                margin_raw = _ensure_numeric_predictions(margin_raw)
+                margin = np.asarray(margin_raw, dtype=float)
+                margins.append(margin)
+
+            preds = [float(margin[pos_label]) for margin in margins]
+            pred_probas = [
+                float((np.exp(margin) / np.exp(margin).sum())[pos_label])
+                for margin in margins
+            ]
+            base_score = 0.0
+            base_proba = 1.0 / n_classes
+    else:
+        preds = []
+        for i in range(1, n_trees + 1):
+            pred_raw = lgbmodel.predict(X_row, raw_score=True, num_iteration=i)[0]
+            pred_raw = _ensure_numeric_predictions(pred_raw)
+            if isinstance(pred_raw, np.ndarray):
+                pred_raw = pred_raw.item() if pred_raw.ndim == 0 else float(pred_raw[0])
+            preds.append(float(pred_raw))
+        base_score = 0.0
+
+    lgbm_preds_df = pd.DataFrame(
+        dict(tree=range(-1, n_trees), pred=[base_score] + preds)
+    )
+    lgbm_preds_df["pred_diff"] = lgbm_preds_df.pred.diff()
+    lgbm_preds_df.loc[0, "pred_diff"] = lgbm_preds_df.loc[0, "pred"]
+
+    if is_classifier:
+        lgbm_preds_df["pred_proba"] = [base_proba] + pred_probas
+        lgbm_preds_df["pred_proba_diff"] = lgbm_preds_df.pred_proba.diff()
+        lgbm_preds_df.loc[0, "pred_proba_diff"] = lgbm_preds_df.loc[0, "pred_proba"]
+
+    return lgbm_preds_df
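For the binary-classifier branch, the table construction above reduces to: take the raw margin after each additional tree (via `raw_score=True, num_iteration=i`), apply a sigmoid to get probabilities, and compute diff columns for per-tree contributions. A minimal numpy/pandas sketch of just that table-building step; `lgbm_preds_df_from_margins` is a hypothetical helper fed with made-up margins instead of a fitted model:

```python
import numpy as np
import pandas as pd

def lgbm_preds_df_from_margins(margins, base_score=0.0, base_proba=0.5):
    # Row for tree=-1 holds the base score/proba; row i holds the
    # cumulative prediction after i+1 trees. The *_diff columns are the
    # marginal contribution of each individual tree.
    margins = np.asarray(margins, dtype=float)
    pred_probas = np.exp(margins) / (1 + np.exp(margins))  # sigmoid of raw score
    df = pd.DataFrame(
        dict(tree=range(-1, len(margins)), pred=[base_score] + margins.tolist())
    )
    df["pred_diff"] = df.pred.diff()
    df.loc[0, "pred_diff"] = df.loc[0, "pred"]
    df["pred_proba"] = [base_proba] + pred_probas.tolist()
    df["pred_proba_diff"] = df.pred_proba.diff()
    df.loc[0, "pred_proba_diff"] = df.loc[0, "pred_proba"]
    return df

# Cumulative raw scores after 1, 2, and 3 trees of a toy binary classifier.
df = lgbm_preds_df_from_margins([0.4, 0.1, 0.6])
```

Note that (unlike XGBoost's `ntree_limit`-style slicing of margins) this calls `predict` once per iteration count in the real implementation, so the cost is quadratic in the number of trees; for a single row that is usually acceptable.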

explainerdashboard/explainer_plots.py

Lines changed: 5 additions & 2 deletions

@@ -2930,6 +2930,7 @@ def plotly_xgboost_trees(
     target="",
     units="",
     higher_is_better=True,
+    model_name="xgboost",
 ):
     """Generate a plot showing the prediction of every single tree inside an XGBoost model

@@ -2944,6 +2945,8 @@
         units (str, optional): Units of target variable. Defaults to "".
         higher_is_better (bool, optional): up is green, down is red. If False then
             flip the colors.
+        model_name (str, optional): model family label used in chart titles.
+            Defaults to "xgboost".

     Returns:
         Plotly fig

@@ -3041,10 +3044,10 @@
     )

     if target:
-        title = f"Individual xgboost decision trees predicting {target}"
+        title = f"Individual {model_name} decision trees predicting {target}"
         yaxis_title = f"Predicted {target} {f'({units})' if units else ''}"
     else:
-        title = "Individual xgboost decision trees"
+        title = f"Individual {model_name} decision trees"
         yaxis_title = f"Predicted outcome ({units})" if units else "Predicted outcome"

     layout = go.Layout(
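The new `model_name` parameter only feeds the chart title strings, which lets the same plotting function serve both XGBoost and LightGBM explainers. A sketch of just that selection; `tree_plot_title` is a hypothetical extraction of the four changed lines, not a function the library exposes:

```python
def tree_plot_title(model_name: str = "xgboost", target: str = "") -> str:
    # Mirrors the title selection in plotly_xgboost_trees: mention the
    # target variable when one is given, and label the model family
    # (e.g. "xgboost" or "LightGBM") either way.
    if target:
        return f"Individual {model_name} decision trees predicting {target}"
    return f"Individual {model_name} decision trees"

lgbm_title = tree_plot_title("LightGBM", "price")
default_title = tree_plot_title()
```

Defaulting `model_name` to "xgboost" keeps every existing call site backward compatible while letting the LightGBM code path pass its own label.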
