Skip to content

Commit a7ed56b

Browse files
committed
Apply fit kwargs in place; remove snapshot/restore.
1 parent 122a36a commit a7ed56b

2 files changed

Lines changed: 83 additions & 290 deletions

File tree

python-package/lightgbm/sklearn.py

Lines changed: 73 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -152,64 +152,6 @@ def _get_weight_from_constructed_dataset(dataset: Dataset) -> Optional[np.ndarra
152152
return weight
153153

154154

155-
def _snapshot_mutable_fields(dataset: Dataset) -> Dict[str, Optional[np.ndarray]]:
156-
# snapshot fields fit() may overwrite, so they can be restored after train()
157-
def _copy_or_none(value: Any) -> Optional[np.ndarray]:
158-
return np.asarray(value).copy() if value is not None else None
159-
160-
return {
161-
"label": _copy_or_none(dataset.get_label()),
162-
"weight": _copy_or_none(dataset.get_weight()),
163-
"group": _copy_or_none(dataset.get_group()),
164-
"init_score": _copy_or_none(dataset.get_init_score()),
165-
}
166-
167-
168-
def _restore_mutable_fields(dataset: Dataset, snapshot: Dict[str, Optional[np.ndarray]]) -> None:
169-
# reset the Python attr too so the lazy-cached get_*() reflects the C++ clear
170-
if snapshot["label"] is not None:
171-
dataset.set_label(snapshot["label"])
172-
if snapshot["weight"] is not None:
173-
dataset.set_weight(snapshot["weight"])
174-
else:
175-
dataset.weight = None
176-
dataset.set_field("weight", None)
177-
if snapshot["group"] is not None:
178-
dataset.set_group(snapshot["group"])
179-
else:
180-
dataset.group = None
181-
dataset.set_field("group", None)
182-
if snapshot["init_score"] is not None:
183-
dataset.set_init_score(snapshot["init_score"])
184-
else:
185-
dataset.init_score = None
186-
dataset.set_field("init_score", None)
187-
188-
189-
def _best_effort_restore(dataset: Dataset, snapshot: Dict[str, Optional[np.ndarray]], context: str) -> None:
190-
# restore that swallows any error rather than mask the caller's primary exception;
191-
# failures are surfaced via _log_warning so the user sees them in the log
192-
try:
193-
_restore_mutable_fields(dataset, snapshot)
194-
except Exception as restore_err: # noqa: BLE001
195-
_log_warning(f"Failed to restore a {context} Dataset field after fit: {restore_err}")
196-
197-
198-
def _set_eval_label(
199-
dataset: Dataset,
200-
label: Any,
201-
snapshots: List[Tuple[Dataset, Dict[str, Optional[np.ndarray]]]],
202-
) -> None:
203-
# roll back on failure
204-
dataset.construct()
205-
snapshots.append((dataset, _snapshot_mutable_fields(dataset)))
206-
try:
207-
dataset.set_label(label)
208-
except BaseException:
209-
_best_effort_restore(dataset, snapshots.pop()[1], "eval")
210-
raise
211-
212-
213155
class _ObjectiveFunctionWrapper:
214156
"""Proxy class for objective function."""
215157

@@ -451,13 +393,14 @@ def __call__(
451393
452394
Notes
453395
-----
454-
When ``X`` is a pre-built ``lightgbm.Dataset``, ``y`` may be ``None``;
396+
When ``X`` is a pre-built ``lightgbm.Dataset``, ``y`` may be ``None``.
455397
``y`` / ``sample_weight`` / ``group`` / ``init_score`` passed to ``fit()``
456-
are applied for the fit and rolled back on return. Binning parameters on the
457-
estimator (``max_bin``, ``min_data_in_bin``, etc.) are ignored because the
458-
Dataset's binning is frozen at construction time; build validation
459-
Datasets with ``reference=<training Dataset>`` to share it. The
460-
sklearn-level validation that runs on the array path
398+
are applied to the Dataset in place via the ``set_*`` API; an omitted
399+
kwarg leaves the corresponding field on the Dataset unchanged. Binning
400+
parameters on the estimator (``max_bin``, ``min_data_in_bin``, etc.)
401+
are ignored because the Dataset's binning is frozen at construction
402+
time; build validation Datasets with ``reference=<training Dataset>``
403+
to share it. The sklearn-level validation that runs on the array path
461404
(``ensure_min_samples``, ``_LGBMCheckSampleWeight``, etc.) is not
462405
re-applied, matching ``lightgbm.train()``.
463406
"""
@@ -1128,8 +1071,6 @@ def fit(
11281071
"""Docstring is set after definition, using a template."""
11291072
params = self._process_params(stage="fit")
11301073

1131-
dataset_snapshots: List[Tuple[Dataset, Dict[str, Optional[np.ndarray]]]] = []
1132-
11331074
# Do not modify original args in fit function
11341075
# Refer to https://github.com/lightgbm-org/LightGBM/pull/2619
11351076
eval_metric_list: List[Union[str, _LGBM_ScikitCustomEvalFunction]]
@@ -1171,32 +1112,25 @@ def fit(
11711112
# construct now so n_features_in_ is known and label can be read back for class weight
11721113
train_set.construct()
11731114
self.n_features_in_ = train_set.num_feature()
1174-
# snapshot + try/except: mutations below must roll back on failure so the user's
1175-
# Dataset is not left half-written if e.g. set_label raises on a length mismatch
1176-
dataset_snapshots.append((train_set, _snapshot_mutable_fields(train_set)))
1177-
try:
1178-
if y is not None:
1179-
train_set.set_label(y)
1180-
if group is not None:
1181-
train_set.set_group(group)
1182-
if init_score is not None:
1183-
train_set.set_init_score(init_score)
1184-
1185-
if self._class_weight is None:
1186-
self._class_weight = self.class_weight
1187-
if self._class_weight is not None:
1188-
y_for_class_weight = _get_label_from_constructed_dataset(train_set) if y is None else y
1189-
class_sample_weight = _LGBMComputeSampleWeight(self._class_weight, y_for_class_weight)
1190-
if sample_weight is None or len(sample_weight) == 0:
1191-
sample_weight = class_sample_weight
1192-
else:
1193-
sample_weight = np.multiply(sample_weight, class_sample_weight)
1115+
if y is not None:
1116+
train_set.set_label(y)
1117+
1118+
if self._class_weight is None:
1119+
self._class_weight = self.class_weight
1120+
if self._class_weight is not None:
1121+
y_for_class_weight = _get_label_from_constructed_dataset(train_set) if y is None else y
1122+
class_sample_weight = _LGBMComputeSampleWeight(self._class_weight, y_for_class_weight)
1123+
if sample_weight is None or len(sample_weight) == 0:
1124+
sample_weight = class_sample_weight
1125+
else:
1126+
sample_weight = np.multiply(sample_weight, class_sample_weight)
11941127

1195-
if sample_weight is not None:
1196-
train_set.set_weight(sample_weight)
1197-
except BaseException:
1198-
_best_effort_restore(train_set, dataset_snapshots.pop()[1], "training")
1199-
raise
1128+
if sample_weight is not None:
1129+
train_set.set_weight(sample_weight)
1130+
if group is not None:
1131+
train_set.set_group(group)
1132+
if init_score is not None:
1133+
train_set.set_init_score(init_score)
12001134
else:
12011135
train_set = self._build_train_set_from_array(
12021136
X, y, sample_weight, group, init_score, categorical_feature, feature_name, params
@@ -1233,7 +1167,8 @@ def fit(
12331167
f"pass reference=train_set when constructing it to use the same binning"
12341168
)
12351169
if valid_y is not None:
1236-
_set_eval_label(valid_x, valid_y, dataset_snapshots)
1170+
valid_x.construct()
1171+
valid_x.set_label(valid_y)
12371172
valid_set = valid_x
12381173
elif valid_x is X and valid_y is y:
12391174
# reduce cost for prediction training data
@@ -1290,38 +1225,33 @@ def fit(
12901225
evals_result: _EvalResultDict = {}
12911226
callbacks.append(record_evaluation(evals_result))
12921227

1293-
try:
1294-
self._Booster = train(
1295-
params=params,
1296-
train_set=train_set,
1297-
num_boost_round=self.n_estimators,
1298-
valid_sets=valid_sets,
1299-
valid_names=eval_names,
1300-
feval=eval_metrics_callable, # type: ignore[arg-type]
1301-
init_model=init_model,
1302-
callbacks=callbacks,
1303-
)
1228+
self._Booster = train(
1229+
params=params,
1230+
train_set=train_set,
1231+
num_boost_round=self.n_estimators,
1232+
valid_sets=valid_sets,
1233+
valid_names=eval_names,
1234+
feval=eval_metrics_callable, # type: ignore[arg-type]
1235+
init_model=init_model,
1236+
callbacks=callbacks,
1237+
)
1238+
1239+
# This populates the property self.n_features_, the number of features in the fitted model,
1240+
# and so should only be set after fitting.
1241+
#
1242+
# The related property self._n_features_in, which populates self.n_features_in_,
1243+
# is set BEFORE fitting.
1244+
self._n_features = self._Booster.num_feature()
1245+
1246+
self._evals_result = evals_result
1247+
self._best_iteration = self._Booster.best_iteration
1248+
self._best_score = self._Booster.best_score
1249+
1250+
self.fitted_ = True
13041251

1305-
# This populates the property self.n_features_, the number of features in the fitted model,
1306-
# and so should only be set after fitting.
1307-
#
1308-
# The related property self._n_features_in, which populates self.n_features_in_,
1309-
# is set BEFORE fitting.
1310-
self._n_features = self._Booster.num_feature()
1311-
1312-
self._evals_result = evals_result
1313-
self._best_iteration = self._Booster.best_iteration
1314-
self._best_score = self._Booster.best_score
1315-
1316-
self.fitted_ = True
1317-
1318-
# free dataset
1319-
self._Booster.free_dataset()
1320-
del train_set, valid_sets
1321-
finally:
1322-
# restore any user-passed Dataset fields we mutated via set_*, so fit is non-mutating
1323-
for snap_ds, snapshot in dataset_snapshots:
1324-
_best_effort_restore(snap_ds, snapshot, "training")
1252+
# free dataset
1253+
self._Booster.free_dataset()
1254+
del train_set, valid_sets
13251255
return self
13261256

13271257
fit.__doc__ = (
@@ -1808,7 +1738,6 @@ def fit( # type: ignore[override]
18081738

18091739
# do not modify args, as it causes errors in model selection tools
18101740
valid_sets: Optional[List[_LGBM_ScikitValidSet]] = None
1811-
eval_dataset_snapshots: List[Tuple[Dataset, Dict[str, Optional[np.ndarray]]]] = []
18121741
if eval_set is not None:
18131742
if isinstance(eval_set, tuple):
18141743
eval_set = [eval_set]
@@ -1823,35 +1752,32 @@ def fit( # type: ignore[override]
18231752
valid_x, valid_y = valid_data[0], valid_data[1]
18241753
if isinstance(valid_x, Dataset):
18251754
if valid_y is not None:
1826-
_set_eval_label(valid_x, self._le.transform(valid_y), eval_dataset_snapshots)
1755+
valid_x.construct()
1756+
valid_x.set_label(self._le.transform(valid_y))
18271757
valid_sets.append(valid_x) # type: ignore[arg-type]
18281758
elif valid_x is X and valid_y is y:
18291759
valid_sets.append((valid_x, _y))
18301760
else:
18311761
valid_sets.append((valid_x, self._le.transform(valid_y)))
18321762

1833-
try:
1834-
super().fit(
1835-
X,
1836-
_y,
1837-
sample_weight=sample_weight,
1838-
init_score=init_score,
1839-
eval_set=valid_sets,
1840-
eval_names=eval_names,
1841-
eval_X=eval_X,
1842-
eval_y=eval_y,
1843-
eval_sample_weight=eval_sample_weight,
1844-
eval_class_weight=eval_class_weight,
1845-
eval_init_score=eval_init_score,
1846-
eval_metric=eval_metric,
1847-
feature_name=feature_name,
1848-
categorical_feature=categorical_feature,
1849-
callbacks=callbacks,
1850-
init_model=init_model,
1851-
)
1852-
finally:
1853-
for snap_ds, snapshot in eval_dataset_snapshots:
1854-
_best_effort_restore(snap_ds, snapshot, "eval")
1763+
super().fit(
1764+
X,
1765+
_y,
1766+
sample_weight=sample_weight,
1767+
init_score=init_score,
1768+
eval_set=valid_sets,
1769+
eval_names=eval_names,
1770+
eval_X=eval_X,
1771+
eval_y=eval_y,
1772+
eval_sample_weight=eval_sample_weight,
1773+
eval_class_weight=eval_class_weight,
1774+
eval_init_score=eval_init_score,
1775+
eval_metric=eval_metric,
1776+
feature_name=feature_name,
1777+
categorical_feature=categorical_feature,
1778+
callbacks=callbacks,
1779+
init_model=init_model,
1780+
)
18551781
return self
18561782

18571783
_base_doc = LGBMModel.fit.__doc__.replace("self : LGBMModel", "self : LGBMClassifier") # type: ignore

0 commit comments

Comments
 (0)