Skip to content

Commit 9a4b75a

Browse files
joshloyaljnothman
authored andcommitted
ENH Document how to get valid parameters in set_params
Issue scikit-learn#4690
1 parent 64d3f45 commit 9a4b75a

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

sklearn/base.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,15 +254,19 @@ def set_params(self, **params):
254254
# nested objects case
255255
name, sub_name = split
256256
if name not in valid_params:
257-
raise ValueError('Invalid parameter %s for estimator %s' %
257+
raise ValueError('Invalid parameter %s for estimator %s. '
258+
'Check the list of available parameters '
259+
'with `estimator.get_params().keys()`.' %
258260
(name, self))
259261
sub_object = valid_params[name]
260262
sub_object.set_params(**{sub_name: value})
261263
else:
262264
# simple objects case
263265
if key not in valid_params:
264-
raise ValueError('Invalid parameter %s ' 'for estimator %s'
265-
% (key, self.__class__.__name__))
266+
raise ValueError('Invalid parameter %s for estimator %s. '
267+
'Check the list of available parameters '
268+
'with `estimator.get_params().keys()`.' %
269+
(key, self.__class__.__name__))
266270
setattr(self, key, value)
267271
return self
268272

sklearn/tests/test_pipeline.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from scipy import sparse
66

77
from sklearn.externals.six.moves import zip
8-
from sklearn.utils.testing import assert_raises, assert_raises_regex
8+
from sklearn.utils.testing import assert_raises, assert_raises_regex, assert_raise_message
99
from sklearn.utils.testing import assert_equal
1010
from sklearn.utils.testing import assert_false
1111
from sklearn.utils.testing import assert_true
@@ -165,6 +165,27 @@ def test_pipeline_fit_params():
165165
assert_true(pipe.named_steps['transf'].b is None)
166166

167167

168+
def test_pipeline_raise_set_params_error():
169+
# Test pipeline raises set params error message for nested models.
170+
pipe = Pipeline([('cls', LinearRegression())])
171+
172+
# expected error message
173+
error_msg = ('Invalid parameter %s for estimator %s. '
174+
'Check the list of available parameters '
175+
'with `estimator.get_params().keys()`.' )
176+
177+
assert_raise_message(ValueError,
178+
error_msg % ('fake', 'Pipeline'),
179+
pipe.set_params,
180+
fake='nope')
181+
182+
# nested model check
183+
assert_raise_message(ValueError,
184+
error_msg % ("fake", pipe),
185+
pipe.set_params,
186+
fake__estimator='nope')
187+
188+
168189
def test_pipeline_methods_pca_svm():
169190
# Test the various methods of the pipeline (pca + svm).
170191
iris = load_iris()

0 commit comments

Comments
 (0)