1919import numpy as np
2020import scipy .sparse as ss
2121
22- from .basic import LightGBMError , _choose_param_value , _ConfigAliases , _log_info , _log_warning
22+ from .basic import Dataset , LightGBMError , _choose_param_value , _ConfigAliases , _log_info , _log_warning
2323from .compat import (
2424 DASK_INSTALLED ,
2525 PANDAS_INSTALLED ,
@@ -1080,7 +1080,7 @@ def _lgb_dask_fit(
10801080 * ,
10811081 model_factory : Type [LGBMModel ],
10821082 X : _DaskMatrixLike ,
1083- y : _DaskCollection ,
1083+ y : Optional [ _DaskCollection ] = None ,
10841084 sample_weight : Optional [_DaskVectorLike ] = None ,
10851085 init_score : Optional [_DaskCollection ] = None ,
10861086 group : Optional [_DaskVectorLike ] = None ,
@@ -1096,6 +1096,8 @@ def _lgb_dask_fit(
10961096 eval_at : Optional [Union [List [int ], Tuple [int , ...]]] = None ,
10971097 ** kwargs : Any ,
10981098 ) -> "_DaskLGBMModel" :
1099+ if isinstance (X , Dataset ):
1100+ raise LightGBMError ("Passing a pre-built lightgbm.Dataset as X is not supported by DaskLGBM estimators" )
10991101 if not DASK_INSTALLED :
11001102 raise LightGBMError ("dask is required for lightgbm.dask" )
11011103 if not all ((DASK_INSTALLED , PANDAS_INSTALLED , SKLEARN_INSTALLED )):
@@ -1104,6 +1106,7 @@ def _lgb_dask_fit(
11041106 params = self .get_params (True ) # type: ignore[attr-defined]
11051107 params .pop ("client" , None )
11061108
1109+ assert y is not None , "DaskLGBM requires y (the Dataset-only path is rejected above)"
11071110 model = _train (
11081111 client = _get_dask_client (self .client ),
11091112 data = X ,
@@ -1219,7 +1222,7 @@ def __getstate__(self) -> Dict[Any, Any]:
12191222 def fit ( # type: ignore[override]
12201223 self ,
12211224 X : _DaskMatrixLike ,
1222- y : _DaskCollection ,
1225+ y : Optional [ _DaskCollection ] = None ,
12231226 sample_weight : Optional [_DaskVectorLike ] = None ,
12241227 init_score : Optional [_DaskCollection ] = None ,
12251228 eval_set : Optional [List [Tuple [_DaskMatrixLike , _DaskCollection ]]] = None ,
@@ -1427,7 +1430,7 @@ def __getstate__(self) -> Dict[Any, Any]:
14271430 def fit ( # type: ignore[override]
14281431 self ,
14291432 X : _DaskMatrixLike ,
1430- y : _DaskCollection ,
1433+ y : Optional [ _DaskCollection ] = None ,
14311434 sample_weight : Optional [_DaskVectorLike ] = None ,
14321435 init_score : Optional [_DaskVectorLike ] = None ,
14331436 eval_set : Optional [List [Tuple [_DaskMatrixLike , _DaskCollection ]]] = None ,
@@ -1600,7 +1603,7 @@ def __getstate__(self) -> Dict[Any, Any]:
16001603 def fit ( # type: ignore[override]
16011604 self ,
16021605 X : _DaskMatrixLike ,
1603- y : _DaskCollection ,
1606+ y : Optional [ _DaskCollection ] = None ,
16041607 sample_weight : Optional [_DaskVectorLike ] = None ,
16051608 init_score : Optional [_DaskVectorLike ] = None ,
16061609 group : Optional [_DaskVectorLike ] = None ,
0 commit comments