Skip to content

Commit 122a36a

Browse files
committed
[python-package] accept a pre-built Dataset as X in LGBMModel.fit
1 parent 6d7d06e commit 122a36a

4 files changed

Lines changed: 783 additions & 107 deletions

File tree

python-package/lightgbm/dask.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
import numpy as np
2020
import scipy.sparse as ss
2121

22-
from .basic import LightGBMError, _choose_param_value, _ConfigAliases, _log_info, _log_warning
22+
from .basic import Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _log_info, _log_warning
2323
from .compat import (
2424
DASK_INSTALLED,
2525
PANDAS_INSTALLED,
@@ -1080,7 +1080,7 @@ def _lgb_dask_fit(
10801080
*,
10811081
model_factory: Type[LGBMModel],
10821082
X: _DaskMatrixLike,
1083-
y: _DaskCollection,
1083+
y: Optional[_DaskCollection] = None,
10841084
sample_weight: Optional[_DaskVectorLike] = None,
10851085
init_score: Optional[_DaskCollection] = None,
10861086
group: Optional[_DaskVectorLike] = None,
@@ -1096,6 +1096,8 @@ def _lgb_dask_fit(
10961096
eval_at: Optional[Union[List[int], Tuple[int, ...]]] = None,
10971097
**kwargs: Any,
10981098
) -> "_DaskLGBMModel":
1099+
if isinstance(X, Dataset):
1100+
raise LightGBMError("Passing a pre-built lightgbm.Dataset as X is not supported by DaskLGBM estimators")
10991101
if not DASK_INSTALLED:
11001102
raise LightGBMError("dask is required for lightgbm.dask")
11011103
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
@@ -1104,6 +1106,7 @@ def _lgb_dask_fit(
11041106
params = self.get_params(True) # type: ignore[attr-defined]
11051107
params.pop("client", None)
11061108

1109+
assert y is not None, "DaskLGBM requires y (the Dataset-only path is rejected above)"
11071110
model = _train(
11081111
client=_get_dask_client(self.client),
11091112
data=X,
@@ -1219,7 +1222,7 @@ def __getstate__(self) -> Dict[Any, Any]:
12191222
def fit( # type: ignore[override]
12201223
self,
12211224
X: _DaskMatrixLike,
1222-
y: _DaskCollection,
1225+
y: Optional[_DaskCollection] = None,
12231226
sample_weight: Optional[_DaskVectorLike] = None,
12241227
init_score: Optional[_DaskCollection] = None,
12251228
eval_set: Optional[List[Tuple[_DaskMatrixLike, _DaskCollection]]] = None,
@@ -1427,7 +1430,7 @@ def __getstate__(self) -> Dict[Any, Any]:
14271430
def fit( # type: ignore[override]
14281431
self,
14291432
X: _DaskMatrixLike,
1430-
y: _DaskCollection,
1433+
y: Optional[_DaskCollection] = None,
14311434
sample_weight: Optional[_DaskVectorLike] = None,
14321435
init_score: Optional[_DaskVectorLike] = None,
14331436
eval_set: Optional[List[Tuple[_DaskMatrixLike, _DaskCollection]]] = None,
@@ -1600,7 +1603,7 @@ def __getstate__(self) -> Dict[Any, Any]:
16001603
def fit( # type: ignore[override]
16011604
self,
16021605
X: _DaskMatrixLike,
1603-
y: _DaskCollection,
1606+
y: Optional[_DaskCollection] = None,
16041607
sample_weight: Optional[_DaskVectorLike] = None,
16051608
init_score: Optional[_DaskVectorLike] = None,
16061609
group: Optional[_DaskVectorLike] = None,

0 commit comments

Comments
 (0)