Skip to content

Commit 06e94ad

Browse files
SamsagaxjameslambStrikerRUS
authored
[python-package] Load parameters from model string (#6852)
* [python-package] Test serialization and deserialization from in-memory string Test case for #6851 * [python-package] Fill in `params` when loading from in-memory string This fixes #6851 by using the same workaround as when loading the model from a file. * test_basic: use rng instead of legacy numpy RandomState * test_basic: remove debug prints leftovers Co-authored-by: James Lamb <[email protected]> * test_basic: add boolean, array of float and array of integers to testcase Co-authored-by: James Lamb <[email protected]> * test_basic: make a cheaper model (2 rounds with 7 leaves each) Co-authored-by: James Lamb <[email protected]> * test_basic: bugfix typos * python_package_test: move string load test from basic to engine * test_engine: catch params ignored warnings * test_engine: be explicit about parameters assertion * test_engine: shush linter complaint * test_basic: delete empty line Co-authored-by: James Lamb <[email protected]> * test_engine: even cheaper model with less features Co-authored-by: James Lamb <[email protected]> * test_engine: delete redundant assert Co-authored-by: James Lamb <[email protected]> * test_engine: run pre-commit and take it's word for it * test_engine: be explicit in an E712 compliant way * test_engine.py: pass different value as argument to make sure it is ignored --------- Co-authored-by: James Lamb <[email protected]> Co-authored-by: Nikita Titov <[email protected]>
1 parent 0bbb02f commit 06e94ad

2 files changed

Lines changed: 48 additions & 1 deletion

File tree

python-package/lightgbm/basic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3715,6 +3715,9 @@ def __init__(
37153715
params = self._get_loaded_param()
37163716
elif model_str is not None:
37173717
self.model_from_string(model_str)
3718+
if params:
3719+
_log_warning("Ignoring params argument, using parameters from model string.")
3720+
params = self._get_loaded_param()
37183721
else:
37193722
raise TypeError(
37203723
"Need at least one training dataset or model file or model string to create Booster instance"

tests/python_package_test/test_engine.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1498,7 +1498,7 @@ def test_parameters_are_loaded_from_model_file(tmp_path, capsys, rng):
14981498
assert bst.params["categorical_feature"] == [1, 2]
14991499

15001500
# check that passing parameters to the constructor raises warning and ignores them
1501-
with pytest.warns(UserWarning, match="Ignoring params argument"):
1501+
with pytest.warns(UserWarning, match="Ignoring params argument, using parameters from model file."):
15021502
bst2 = lgb.Booster(params={"num_leaves": 7}, model_file=model_file)
15031503
assert bst.params == bst2.params
15041504

@@ -1508,6 +1508,50 @@ def test_parameters_are_loaded_from_model_file(tmp_path, capsys, rng):
15081508
np.testing.assert_allclose(preds, orig_preds)
15091509

15101510

1511+
def test_string_serialized_params_retrieval(rng):
1512+
# Random train data
1513+
train_x = rng.random((500, 3))
1514+
train_y = rng.integers(0, 1, 500)
1515+
train_data = lgb.Dataset(train_x, train_y)
1516+
1517+
# Parameters
1518+
params = {
1519+
"boosting": "gbdt",
1520+
"deterministic": True,
1521+
"feature_contri": [0.5] * train_x.shape[1],
1522+
"interaction_constraints": [[0, 1], [0]],
1523+
"objective": "binary",
1524+
"metric": ["auc"],
1525+
"num_leaves": 7,
1526+
"learning_rate": 0.05,
1527+
"feature_fraction": 0.9,
1528+
"bagging_fraction": 0.8,
1529+
"bagging_freq": 5,
1530+
"verbosity": -100,
1531+
}
1532+
1533+
# train a model and serialize it to a string in memory
1534+
model = lgb.train(params, train_data, num_boost_round=2)
1535+
model_serialized = model.model_to_string()
1536+
1537+
# load a new model with the string
1538+
with pytest.warns(UserWarning, match="Ignoring params argument, using parameters from model string."):
1539+
new_model = lgb.Booster(params={"num_leaves": 32}, model_str=model_serialized)
1540+
1541+
assert new_model.params["boosting"] == "gbdt"
1542+
assert new_model.params["deterministic"] is True
1543+
assert new_model.params["feature_contri"] == [0.5] * train_x.shape[1]
1544+
assert new_model.params["interaction_constraints"] == [[0, 1], [0]]
1545+
assert new_model.params["objective"] == "binary"
1546+
assert new_model.params["metric"] == ["auc"]
1547+
assert new_model.params["num_leaves"] == 7
1548+
assert new_model.params["learning_rate"] == 0.05
1549+
assert new_model.params["feature_fraction"] == 0.9
1550+
assert new_model.params["bagging_fraction"] == 0.8
1551+
assert new_model.params["bagging_freq"] == 5
1552+
assert new_model.params["verbosity"] == -100
1553+
1554+
15111555
def test_save_load_copy_pickle(tmp_path):
15121556
def train_and_predict(init_model=None, return_model=False):
15131557
X, y = make_synthetic_regression()

0 commit comments

Comments
 (0)