|
40 | 40 | "get_xgboost_path_df", |
41 | 41 | "get_xgboost_path_summary_df", |
42 | 42 | "get_xgboost_preds_df", |
| 43 | + "get_lgbm_preds_df", |
43 | 44 | "get_multiclass_logodds_scores", |
44 | 45 | "get_xgboost_output_label", |
45 | 46 | "_ensure_numeric_predictions", # Internal helper for XGBoost 3.0+ compatibility |
@@ -2165,7 +2166,14 @@ def node_pred_proba(node): |
2165 | 2166 | else: |
2166 | 2167 |
|
2167 | 2168 | def node_mean(node): |
2168 | | - return decision_tree.tree_model.tree_.value[node.id].item() |
| 2169 | + try: |
| 2170 | + return decision_tree.tree_model.tree_.value[node.id].item() |
| 2171 | + except Exception: |
| 2172 | + node_samples = decision_tree.get_node_samples() |
| 2173 | + sample_idxs = node_samples.get(node.id, []) |
| 2174 | + if len(sample_idxs) == 0: |
| 2175 | + return np.nan |
| 2176 | + return float(np.asarray(decision_tree.y_train)[sample_idxs].mean()) |
2169 | 2177 |
|
2170 | 2178 | for node in nodes: |
2171 | 2179 | if not node.isleaf(): |
@@ -2549,3 +2557,93 @@ def get_xgboost_preds_df(xgbmodel, X_row, pos_label=1): |
2549 | 2557 | 0, "pred_proba" |
2550 | 2558 | ] |
2551 | 2559 | return xgboost_preds_df |
| 2560 | + |
| 2561 | + |
| 2562 | +def get_lgbm_preds_df(lgbmodel, X_row, pos_label=1): |
| 2563 | + """Returns cumulative per-tree predictions for a LightGBM model. |
| 2564 | +
|
| 2565 | + Args: |
| 2566 | + lgbmodel: fitted LightGBM sklearn-compatible model |
| 2567 | + (i.e. LGBMClassifier or LGBMRegressor) |
| 2568 | + X_row: a single row of data, e.g X_train.iloc[0] |
| 2569 | + pos_label: for classifier the label to be used as positive label |
| 2570 | + Defaults to 1. |
| 2571 | +
|
| 2572 | + Returns: |
| 2573 | + pd.DataFrame |
| 2574 | + """ |
| 2575 | + if safe_isinstance(lgbmodel, "lightgbm.sklearn.LGBMClassifier"): |
| 2576 | + is_classifier = True |
| 2577 | + n_classes = len(lgbmodel.classes_) |
| 2578 | + n_trees = lgbmodel.booster_.num_trees() |
| 2579 | + if n_classes > 2: |
| 2580 | + n_trees = int(n_trees / n_classes) |
| 2581 | + elif safe_isinstance(lgbmodel, "lightgbm.sklearn.LGBMRegressor"): |
| 2582 | + is_classifier = False |
| 2583 | + n_trees = lgbmodel.booster_.num_trees() |
| 2584 | + else: |
| 2585 | + raise ValueError("Pass either an LGBMClassifier or LGBMRegressor!") |
| 2586 | + |
| 2587 | + if is_classifier: |
| 2588 | + if n_classes == 2: |
| 2589 | + if pos_label not in (0, 1): |
| 2590 | + raise ValueError("pos_label should be either 0 or 1!") |
| 2591 | + |
| 2592 | + margins = [] |
| 2593 | + for i in range(1, n_trees + 1): |
| 2594 | + margin_raw = lgbmodel.predict(X_row, raw_score=True, num_iteration=i)[0] |
| 2595 | + margin_raw = _ensure_numeric_predictions(margin_raw) |
| 2596 | + if isinstance(margin_raw, np.ndarray): |
| 2597 | + margin_raw = ( |
| 2598 | + margin_raw.item() |
| 2599 | + if margin_raw.ndim == 0 |
| 2600 | + else float(margin_raw[0]) |
| 2601 | + ) |
| 2602 | + margin = float(margin_raw) |
| 2603 | + margins.append(margin if pos_label == 1 else -margin) |
| 2604 | + |
| 2605 | + pred_probas = (np.exp(margins) / (1 + np.exp(margins))).tolist() |
| 2606 | + base_score = 0.0 |
| 2607 | + base_proba = 0.5 |
| 2608 | + preds = margins |
| 2609 | + else: |
| 2610 | + if pos_label < 0 or pos_label >= n_classes: |
| 2611 | + raise ValueError( |
| 2612 | + f"pos_label={pos_label}, but should be >= 0 and <= {n_classes - 1}!" |
| 2613 | + ) |
| 2614 | + margins = [] |
| 2615 | + for i in range(1, n_trees + 1): |
| 2616 | + margin_raw = lgbmodel.predict(X_row, raw_score=True, num_iteration=i)[0] |
| 2617 | + margin_raw = _ensure_numeric_predictions(margin_raw) |
| 2618 | + margin = np.asarray(margin_raw, dtype=float) |
| 2619 | + margins.append(margin) |
| 2620 | + |
| 2621 | + preds = [float(margin[pos_label]) for margin in margins] |
| 2622 | + pred_probas = [ |
| 2623 | + float((np.exp(margin) / np.exp(margin).sum())[pos_label]) |
| 2624 | + for margin in margins |
| 2625 | + ] |
| 2626 | + base_score = 0.0 |
| 2627 | + base_proba = 1.0 / n_classes |
| 2628 | + else: |
| 2629 | + preds = [] |
| 2630 | + for i in range(1, n_trees + 1): |
| 2631 | + pred_raw = lgbmodel.predict(X_row, raw_score=True, num_iteration=i)[0] |
| 2632 | + pred_raw = _ensure_numeric_predictions(pred_raw) |
| 2633 | + if isinstance(pred_raw, np.ndarray): |
| 2634 | + pred_raw = pred_raw.item() if pred_raw.ndim == 0 else float(pred_raw[0]) |
| 2635 | + preds.append(float(pred_raw)) |
| 2636 | + base_score = 0.0 |
| 2637 | + |
| 2638 | + lgbm_preds_df = pd.DataFrame( |
| 2639 | + dict(tree=range(-1, n_trees), pred=[base_score] + preds) |
| 2640 | + ) |
| 2641 | + lgbm_preds_df["pred_diff"] = lgbm_preds_df.pred.diff() |
| 2642 | + lgbm_preds_df.loc[0, "pred_diff"] = lgbm_preds_df.loc[0, "pred"] |
| 2643 | + |
| 2644 | + if is_classifier: |
| 2645 | + lgbm_preds_df["pred_proba"] = [base_proba] + pred_probas |
| 2646 | + lgbm_preds_df["pred_proba_diff"] = lgbm_preds_df.pred_proba.diff() |
| 2647 | + lgbm_preds_df.loc[0, "pred_proba_diff"] = lgbm_preds_df.loc[0, "pred_proba"] |
| 2648 | + |
| 2649 | + return lgbm_preds_df |
0 commit comments