@@ -152,64 +152,6 @@ def _get_weight_from_constructed_dataset(dataset: Dataset) -> Optional[np.ndarra
     return weight


-def _snapshot_mutable_fields(dataset: Dataset) -> Dict[str, Optional[np.ndarray]]:
-    # snapshot fields fit() may overwrite, so they can be restored after train()
-    def _copy_or_none(value: Any) -> Optional[np.ndarray]:
-        return np.asarray(value).copy() if value is not None else None
-
-    return {
-        "label": _copy_or_none(dataset.get_label()),
-        "weight": _copy_or_none(dataset.get_weight()),
-        "group": _copy_or_none(dataset.get_group()),
-        "init_score": _copy_or_none(dataset.get_init_score()),
-    }
-
-
-def _restore_mutable_fields(dataset: Dataset, snapshot: Dict[str, Optional[np.ndarray]]) -> None:
-    # reset the Python attr too so the lazy-cached get_*() reflects the C++ clear
-    if snapshot["label"] is not None:
-        dataset.set_label(snapshot["label"])
-    if snapshot["weight"] is not None:
-        dataset.set_weight(snapshot["weight"])
-    else:
-        dataset.weight = None
-        dataset.set_field("weight", None)
-    if snapshot["group"] is not None:
-        dataset.set_group(snapshot["group"])
-    else:
-        dataset.group = None
-        dataset.set_field("group", None)
-    if snapshot["init_score"] is not None:
-        dataset.set_init_score(snapshot["init_score"])
-    else:
-        dataset.init_score = None
-        dataset.set_field("init_score", None)
-
-
-def _best_effort_restore(dataset: Dataset, snapshot: Dict[str, Optional[np.ndarray]], context: str) -> None:
-    # restore that swallows any error rather than mask the caller's primary exception;
-    # failures are surfaced via _log_warning so the user sees them in the log
-    try:
-        _restore_mutable_fields(dataset, snapshot)
-    except Exception as restore_err:  # noqa: BLE001
-        _log_warning(f"Failed to restore a {context} Dataset field after fit: {restore_err}")
-
-
-def _set_eval_label(
-    dataset: Dataset,
-    label: Any,
-    snapshots: List[Tuple[Dataset, Dict[str, Optional[np.ndarray]]]],
-) -> None:
-    # roll back on failure
-    dataset.construct()
-    snapshots.append((dataset, _snapshot_mutable_fields(dataset)))
-    try:
-        dataset.set_label(label)
-    except BaseException:
-        _best_effort_restore(dataset, snapshots.pop()[1], "eval")
-        raise
-
-
 class _ObjectiveFunctionWrapper:
     """Proxy class for objective function."""

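With these snapshot/restore helpers removed, `fit()` on a pre-built `Dataset` is mutating: anything written through `set_*` stays on the caller's object after training. A minimal sketch of what a caller now observes (the estimator, shapes, and labels are illustrative, and it assumes the Dataset-accepting `fit()` path this diff modifies):

```python
import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(0)
ds = lgb.Dataset(rng.normal(size=(100, 4)), label=np.zeros(100))

model = lgb.LGBMRegressor(n_estimators=5)
model.fit(ds, rng.normal(size=100))  # y is written onto ds via set_label()

# no rollback anymore: the label passed to fit() persists on ds
assert not np.allclose(ds.get_label(), 0.0)
```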
@@ -451,13 +393,14 @@ def __call__(

         Notes
         -----
-        When ``X`` is a pre-built ``lightgbm.Dataset``, ``y`` may be ``None``;
+        When ``X`` is a pre-built ``lightgbm.Dataset``, ``y`` may be ``None``.
         ``y`` / ``sample_weight`` / ``group`` / ``init_score`` passed to ``fit()``
-        are applied for the fit and rolled back on return. Binning parameters on the
-        estimator (``max_bin``, ``min_data_in_bin``, etc.) are ignored because the
-        Dataset's binning is frozen at construction time; build validation
-        Datasets with ``reference=<training Dataset>`` to share it. The
-        sklearn-level validation that runs on the array path
+        are applied to the Dataset in place via the ``set_*`` API; an omitted
+        kwarg leaves the corresponding field on the Dataset unchanged. Binning
+        parameters on the estimator (``max_bin``, ``min_data_in_bin``, etc.)
+        are ignored because the Dataset's binning is frozen at construction
+        time; build validation Datasets with ``reference=<training Dataset>``
+        to share it. The sklearn-level validation that runs on the array path
         (``ensure_min_samples``, ``_LGBMCheckSampleWeight``, etc.) is not
         re-applied, matching ``lightgbm.train()``.
         """
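A short usage sketch of the documented contract (assuming the patched, Dataset-accepting `fit()`): labels live on the Datasets, `y` is passed as `None`, and the validation Dataset shares the training Dataset's frozen binning via `reference=`:

```python
import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(0)
train_ds = lgb.Dataset(rng.normal(size=(200, 5)), label=rng.normal(size=200))
# reference= shares the training Dataset's bin mappings, as the Notes require
valid_ds = lgb.Dataset(rng.normal(size=(50, 5)), label=rng.normal(size=50), reference=train_ds)

model = lgb.LGBMRegressor(n_estimators=10)
# y=None: labels are already on the Datasets, so no set_label() call is made
model.fit(train_ds, None, eval_set=[(valid_ds, None)])
```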
@@ -1128,8 +1071,6 @@ def fit(
         """Docstring is set after definition, using a template."""
         params = self._process_params(stage="fit")

-        dataset_snapshots: List[Tuple[Dataset, Dict[str, Optional[np.ndarray]]]] = []
-
         # Do not modify original args in fit function
         # Refer to https://github.com/lightgbm-org/LightGBM/pull/2619
         eval_metric_list: List[Union[str, _LGBM_ScikitCustomEvalFunction]]
@@ -1171,32 +1112,25 @@ def fit(
             # construct now so n_features_in_ is known and label can be read back for class weight
             train_set.construct()
             self.n_features_in_ = train_set.num_feature()
-            # snapshot + try/except: mutations below must roll back on failure so the user's
-            # Dataset is not left half-written if e.g. set_label raises on a length mismatch
-            dataset_snapshots.append((train_set, _snapshot_mutable_fields(train_set)))
-            try:
-                if y is not None:
-                    train_set.set_label(y)
-                if group is not None:
-                    train_set.set_group(group)
-                if init_score is not None:
-                    train_set.set_init_score(init_score)
-
-                if self._class_weight is None:
-                    self._class_weight = self.class_weight
-                if self._class_weight is not None:
-                    y_for_class_weight = _get_label_from_constructed_dataset(train_set) if y is None else y
-                    class_sample_weight = _LGBMComputeSampleWeight(self._class_weight, y_for_class_weight)
-                    if sample_weight is None or len(sample_weight) == 0:
-                        sample_weight = class_sample_weight
-                    else:
-                        sample_weight = np.multiply(sample_weight, class_sample_weight)
+            if y is not None:
+                train_set.set_label(y)
+
+            if self._class_weight is None:
+                self._class_weight = self.class_weight
+            if self._class_weight is not None:
+                y_for_class_weight = _get_label_from_constructed_dataset(train_set) if y is None else y
+                class_sample_weight = _LGBMComputeSampleWeight(self._class_weight, y_for_class_weight)
+                if sample_weight is None or len(sample_weight) == 0:
+                    sample_weight = class_sample_weight
+                else:
+                    sample_weight = np.multiply(sample_weight, class_sample_weight)

-                if sample_weight is not None:
-                    train_set.set_weight(sample_weight)
-            except BaseException:
-                _best_effort_restore(train_set, dataset_snapshots.pop()[1], "training")
-                raise
+            if sample_weight is not None:
+                train_set.set_weight(sample_weight)
+            if group is not None:
+                train_set.set_group(group)
+            if init_score is not None:
+                train_set.set_init_score(init_score)
         else:
             train_set = self._build_train_set_from_array(
                 X, y, sample_weight, group, init_score, categorical_feature, feature_name, params
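The class-weight folding above composes multiplicatively with any explicit `sample_weight`. A worked sketch of that arithmetic, using sklearn's `compute_sample_weight` (which `_LGBMComputeSampleWeight` aliases in `lightgbm/compat.py`):

```python
import numpy as np
from sklearn.utils.class_weight import compute_sample_weight

y = np.array([0, 0, 0, 1])
sample_weight = np.array([1.0, 1.0, 2.0, 1.0])

# per-row weight drawn from each row's class: {0: 1.0, 1: 3.0}
class_sample_weight = compute_sample_weight({0: 1.0, 1: 3.0}, y)
print(class_sample_weight)                              # [1. 1. 1. 3.]
print(np.multiply(sample_weight, class_sample_weight))  # [1. 1. 2. 3.]
```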
@@ -1233,7 +1167,8 @@ def fit(
                         f"pass reference=train_set when constructing it to use the same binning"
                     )
                     if valid_y is not None:
-                        _set_eval_label(valid_x, valid_y, dataset_snapshots)
+                        valid_x.construct()
+                        valid_x.set_label(valid_y)
                     valid_set = valid_x
                 elif valid_x is X and valid_y is y:
                     # reduce cost for prediction training data
@@ -1290,38 +1225,33 @@ def fit(
         evals_result: _EvalResultDict = {}
         callbacks.append(record_evaluation(evals_result))

-        try:
-            self._Booster = train(
-                params=params,
-                train_set=train_set,
-                num_boost_round=self.n_estimators,
-                valid_sets=valid_sets,
-                valid_names=eval_names,
-                feval=eval_metrics_callable,  # type: ignore[arg-type]
-                init_model=init_model,
-                callbacks=callbacks,
-            )
+        self._Booster = train(
+            params=params,
+            train_set=train_set,
+            num_boost_round=self.n_estimators,
+            valid_sets=valid_sets,
+            valid_names=eval_names,
+            feval=eval_metrics_callable,  # type: ignore[arg-type]
+            init_model=init_model,
+            callbacks=callbacks,
+        )
+
+        # This populates the property self.n_features_, the number of features in the fitted model,
+        # and so should only be set after fitting.
+        #
+        # The related property self._n_features_in, which populates self.n_features_in_,
+        # is set BEFORE fitting.
+        self._n_features = self._Booster.num_feature()
+
+        self._evals_result = evals_result
+        self._best_iteration = self._Booster.best_iteration
+        self._best_score = self._Booster.best_score
+
+        self.fitted_ = True

-            # This populates the property self.n_features_, the number of features in the fitted model,
-            # and so should only be set after fitting.
-            #
-            # The related property self._n_features_in, which populates self.n_features_in_,
-            # is set BEFORE fitting.
-            self._n_features = self._Booster.num_feature()
-
-            self._evals_result = evals_result
-            self._best_iteration = self._Booster.best_iteration
-            self._best_score = self._Booster.best_score
-
-            self.fitted_ = True
-
-            # free dataset
-            self._Booster.free_dataset()
-            del train_set, valid_sets
-        finally:
-            # restore any user-passed Dataset fields we mutated via set_*, so fit is non-mutating
-            for snap_ds, snapshot in dataset_snapshots:
-                _best_effort_restore(snap_ds, snapshot, "training")
+        # free dataset
+        self._Booster.free_dataset()
+        del train_set, valid_sets
         return self

     fit.__doc__ = (
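Since the `finally:` restore loop is gone, callers who share one `Dataset` across fits and relied on the old non-mutating behavior must snapshot fields themselves. A hypothetical helper sketching that pattern for the label field (not part of this patch):

```python
import numpy as np
import lightgbm as lgb

def fit_restoring_label(model, dataset: lgb.Dataset, new_label: np.ndarray):
    """Fit with new_label, then put the original label back on the Dataset."""
    dataset.construct()
    old_label = np.asarray(dataset.get_label()).copy()
    try:
        return model.fit(dataset, new_label)
    finally:
        dataset.set_label(old_label)  # restore even if fit() raises
```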
@@ -1808,7 +1738,6 @@ def fit( # type: ignore[override]

         # do not modify args, as it causes errors in model selection tools
         valid_sets: Optional[List[_LGBM_ScikitValidSet]] = None
-        eval_dataset_snapshots: List[Tuple[Dataset, Dict[str, Optional[np.ndarray]]]] = []
         if eval_set is not None:
             if isinstance(eval_set, tuple):
                 eval_set = [eval_set]
@@ -1823,35 +1752,32 @@ def fit( # type: ignore[override]
                 valid_x, valid_y = valid_data[0], valid_data[1]
                 if isinstance(valid_x, Dataset):
                     if valid_y is not None:
-                        _set_eval_label(valid_x, self._le.transform(valid_y), eval_dataset_snapshots)
+                        valid_x.construct()
+                        valid_x.set_label(self._le.transform(valid_y))
                     valid_sets.append(valid_x)  # type: ignore[arg-type]
                 elif valid_x is X and valid_y is y:
                     valid_sets.append((valid_x, _y))
                 else:
                     valid_sets.append((valid_x, self._le.transform(valid_y)))

-        try:
-            super().fit(
-                X,
-                _y,
-                sample_weight=sample_weight,
-                init_score=init_score,
-                eval_set=valid_sets,
-                eval_names=eval_names,
-                eval_X=eval_X,
-                eval_y=eval_y,
-                eval_sample_weight=eval_sample_weight,
-                eval_class_weight=eval_class_weight,
-                eval_init_score=eval_init_score,
-                eval_metric=eval_metric,
-                feature_name=feature_name,
-                categorical_feature=categorical_feature,
-                callbacks=callbacks,
-                init_model=init_model,
-            )
-        finally:
-            for snap_ds, snapshot in eval_dataset_snapshots:
-                _best_effort_restore(snap_ds, snapshot, "eval")
+        super().fit(
+            X,
+            _y,
+            sample_weight=sample_weight,
+            init_score=init_score,
+            eval_set=valid_sets,
+            eval_names=eval_names,
+            eval_X=eval_X,
+            eval_y=eval_y,
+            eval_sample_weight=eval_sample_weight,
+            eval_class_weight=eval_class_weight,
+            eval_init_score=eval_init_score,
+            eval_metric=eval_metric,
+            feature_name=feature_name,
+            categorical_feature=categorical_feature,
+            callbacks=callbacks,
+            init_model=init_model,
+        )
         return self

     _base_doc = LGBMModel.fit.__doc__.replace("self : LGBMModel", "self : LGBMClassifier")  # type: ignore
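A sketch of the classifier-specific branch above (assuming the patched `fit()`): string targets for a Dataset-based eval set are run through the fitted label encoder before being written in place with `set_label()`:

```python
import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(0)
train_ds = lgb.Dataset(rng.normal(size=(100, 3)))
valid_ds = lgb.Dataset(rng.normal(size=(30, 3)), reference=train_ds)

clf = lgb.LGBMClassifier(n_estimators=5)
# valid_y is encoded via self._le.transform() and set on valid_ds in place
clf.fit(
    train_ds,
    rng.choice(["cat", "dog"], size=100),
    eval_set=[(valid_ds, rng.choice(["cat", "dog"], size=30))],
)
```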