Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .compat import (
CFFI_INSTALLED,
PANDAS_INSTALLED,
POLARS_INSTALLED,
PYARROW_INSTALLED,
arrow_cffi,
arrow_is_boolean,
Expand All @@ -43,6 +44,8 @@
pd_CategoricalDtype,
pd_DataFrame,
pd_Series,
pl_DataFrame,
pl_Series,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -129,6 +132,8 @@
pd_DataFrame,
pa_Array,
pa_ChunkedArray,
pl_DataFrame,
pl_Series,
]
_LGBM_PredictDataType = Union[
str,
Expand Down Expand Up @@ -872,6 +877,47 @@ def _data_from_pandas(
)


def _data_from_polars(
    data: "pl_DataFrame",
    feature_name: _LGBM_FeatureNameConfiguration,
    categorical_feature: _LGBM_CategoricalFeatureConfiguration,
    pandas_categorical: Optional[List[List]],
) -> Tuple["pa_Table", List[str], Union[List[str], List[int]], List[List]]:
    """Convert a polars DataFrame into an Arrow Table plus LightGBM feature metadata.

    Mirrors ``_data_from_pandas``: resolves ``"auto"`` feature/categorical
    configuration from the frame, encodes Categorical columns as their numeric
    codes (aligned with the training categories for validation data), and hands
    the result to LightGBM as a pyarrow Table.

    Parameters
    ----------
    data : pl.DataFrame
        Input frame; must be non-empty.
    feature_name : list of str, or 'auto'
        Feature names, or ``"auto"`` to take them from ``data.columns``.
    categorical_feature : list of str or int, or 'auto'
        Categorical features, or ``"auto"`` to use the frame's Categorical columns.
    pandas_categorical : list of list, or None
        ``None`` for training data (categories are recorded); for validation
        data, the category lists recorded at train time.

    Returns
    -------
    tuple of (pa.Table, feature names, categorical features, category lists).

    Raises
    ------
    LightGBMError
        If polars or pyarrow is not installed.
    ValueError
        If ``data`` is empty, or validation categoricals do not match training.
    """
    if not POLARS_INSTALLED:
        raise LightGBMError("Polars is not installed, cannot convert Polars DataFrame to Arrow Table.")
    if not PYARROW_INSTALLED:
        raise LightGBMError("Cannot use Polars input with LightGBM without 'pyarrow' installed.")

    import polars as pl

    if data.height < 1:
        raise ValueError("Input data must be non empty.")

    # determine feature names
    if feature_name == "auto":
        feature_name = data.columns

    # determine categorical features
    # NOTE: compare dtypes directly instead of their string form -- newer polars
    # versions render the dtype as e.g. "Categorical(ordering='physical')", so a
    # string comparison against "Categorical" would silently find no columns.
    cat_cols = [col for col in data.columns if data[col].dtype == pl.Categorical]

    # Columns still carrying a Categorical dtype after alignment; only these get
    # the to_physical() code extraction below.
    cols_to_encode = list(cat_cols)
    if pandas_categorical is None:  # train dataset: record the category lists
        pandas_categorical = [list(data[col].cat.get_categories()) for col in cat_cols]
    else:  # validation dataset: align codes with the training categories
        if len(cat_cols) != len(pandas_categorical):
            raise ValueError("train and valid dataset categorical_feature do not match.")
        for col, category in zip(cat_cols, pandas_categorical):
            if list(data[col].cat.get_categories()) != list(category):
                # re-encode with the training category order; values unseen at
                # train time fall through to the NaN default
                mapping = {v: i for i, v in enumerate(category)}
                data = data.with_columns(pl.col(col).replace_strict(mapping, default=np.nan))
                # already float-coded now; casting its NaNs to an integer dtype
                # below would raise, so drop it from the physical-encoding pass
                cols_to_encode.remove(col)
    if cols_to_encode:
        # Int32, not Int8: polars stores category codes as UInt32, so Int8 would
        # overflow past 127 categories. fill_null(np.nan) upcasts to Float64,
        # matching the NaN-for-missing convention of the pandas code path.
        data = data.with_columns(
            pl.col(col).to_physical().cast(pl.Int32).fill_null(np.nan) for col in cols_to_encode
        )

    # use cat cols from DataFrame
    if categorical_feature == "auto":
        categorical_feature = cat_cols

    return data.to_arrow(), feature_name, categorical_feature, pandas_categorical


def _dump_pandas_categorical(
pandas_categorical: Optional[List[List]],
file_name: Optional[Union[str, Path]] = None,
Expand Down Expand Up @@ -1161,6 +1207,13 @@ def predict(
categorical_feature="auto",
pandas_categorical=self.pandas_categorical,
)[0]
elif isinstance(data, pl_DataFrame):
data = _data_from_polars(
data=data,
feature_name="auto",
categorical_feature="auto",
pandas_categorical=self.pandas_categorical,
)[0]

predict_type = _C_API_PREDICT_NORMAL
if raw_score:
Expand Down Expand Up @@ -2135,6 +2188,13 @@ def _lazy_init(
categorical_feature=categorical_feature,
pandas_categorical=self.pandas_categorical,
)
elif isinstance(data, pl_DataFrame):
data, feature_name, categorical_feature, self.pandas_categorical = _data_from_polars(
data=data,
feature_name=feature_name,
categorical_feature=categorical_feature,
pandas_categorical=self.pandas_categorical,
)
elif _is_pyarrow_table(data) and feature_name == "auto":
feature_name = data.column_names

Expand Down Expand Up @@ -3088,6 +3148,12 @@ def set_label(self, label: Optional[_LGBM_LabelType]) -> "Dataset":
if len(label.columns) > 1:
raise ValueError("DataFrame for label cannot have multiple columns")
label_array = np.ravel(_pandas_to_numpy(label, target_dtype=np.float32))
if isinstance(label, pl_DataFrame):
if len(label.columns) > 1:
raise ValueError("DataFrame for label cannot have multiple columns")
label_array = label[label.columns[0]].to_arrow()
elif isinstance(label, pl_Series):
label_array = label.to_arrow()
elif _is_pyarrow_array(label):
label_array = label
else:
Expand Down
29 changes: 29 additions & 0 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,35 @@ def __init__(self, *args: Any, **kwargs: Any):
pass


"""polars"""
try:
from polars import Categorical as pl_Categorical
from polars import DataFrame as pl_DataFrame
from polars import Int32 as pl_Int32
from polars import Series as pl_Series
from polars import String as pl_String

POLARS_INSTALLED = True
except ImportError:
POLARS_INSTALLED = False

class pl_DataFrame: # type: ignore
"""Dummy class for pl.DataFrame."""

def __init__(self, *args: Any, **kwargs: Any):
pass

class pl_Series: # type: ignore
"""Dummy class for pl.Series."""

def __init__(self, *args: Any, **kwargs: Any):
pass

pl_Categorical = None
pl_Int32 = None
pl_String = None


"""cpu_count()"""
try:
from joblib import cpu_count
Expand Down
5 changes: 3 additions & 2 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
_sklearn_version,
pa_Table,
pd_DataFrame,
pl_DataFrame,
)
from .engine import train

Expand Down Expand Up @@ -1013,7 +1014,7 @@ def fit(
params["metric"] = [e for e in eval_metrics_builtin if e not in params["metric"]] + params["metric"]
params["metric"] = [metric for metric in params["metric"] if metric is not None]

if not isinstance(X, (pd_DataFrame, pa_Table)):
if not isinstance(X, (pd_DataFrame, pa_Table, pl_DataFrame)):
_X, _y = _LGBMValidateData(
self,
X,
Expand Down Expand Up @@ -1181,7 +1182,7 @@ def predict(
"""Docstring is set after definition, using a template."""
if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError("Estimator not fitted, call fit before exploiting the model.")
if not isinstance(X, (pd_DataFrame, pa_Table)):
if not isinstance(X, (pd_DataFrame, pa_Table, pl_DataFrame)):
X = _LGBMValidateData(
self,
X,
Expand Down