Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions tests.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
classifier_config_dict = {
'sklearn.naive_bayes.GaussianNB': {
},

'sklearn.naive_bayes.BernoulliNB': {
'alpha': [1e-3, 1e-2, 1e-1, 1., 10., 100.],
'fit_prior': [True, False]
},

'sklearn.naive_bayes.MultinomialNB': {
'alpha': [1e-3, 1e-2, 1e-1, 1., 10., 100.],
'fit_prior': [True, False]
}
}

31 changes: 29 additions & 2 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ def test_set_params_2():
assert tpot_obj.generations == 3


def test_lite_params():
"""Assert that TPOT uses TPOT's lite dictionary of operators when config_dict is 'TPOT light' or 'TPOT MDR'."""
def test_conf_dict():
"""Assert that TPOT uses the pre-configured dictionary of operators when config_dict is 'TPOT light' or 'TPOT MDR'."""
tpot_obj = TPOTClassifier(config_dict='TPOT light')
assert tpot_obj.config_dict == classifier_config_dict_light

Expand All @@ -196,6 +196,33 @@ def test_lite_params():
assert_raises(TypeError, TPOTRegressor, config_dict='TPOT MDR')


def test_conf_dict_2():
"""Assert that TPOT uses a custom dictionary of operators when config_dict is Python dictionary."""
tpot_obj = TPOTClassifier(config_dict=tpot_mdr_classifier_config_dict)
assert tpot_obj.config_dict == tpot_mdr_classifier_config_dict


def test_conf_dict_3():
"""Assert that TPOT uses a custom dictionary of operators when config_dict is the path of Python dictionary."""
tpot_obj = TPOTRegressor(config_dict='tests.conf')
tested_config_dict = {
'sklearn.naive_bayes.GaussianNB': {
},

'sklearn.naive_bayes.BernoulliNB': {
'alpha': [1e-3, 1e-2, 1e-1, 1., 10., 100.],
'fit_prior': [True, False]
},

'sklearn.naive_bayes.MultinomialNB': {
'alpha': [1e-3, 1e-2, 1e-1, 1., 10., 100.],
'fit_prior': [True, False]
}
}
assert isinstance(tpot_obj.config_dict, dict)
assert tpot_obj.config_dict == tested_config_dict


def test_random_ind():
"""Assert that the TPOTClassifier can generate the same pipeline with same random seed."""
tpot_obj = TPOTClassifier(random_state=43)
Expand Down
15 changes: 8 additions & 7 deletions tpot/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,9 +909,10 @@ def _generate(self, pset, min_, max_, condition, type_=None):
except IndexError:
_, _, traceback = sys.exc_info()
raise IndexError(
'The gp.generate function tried to add a terminal of '
'type \'%s\', but there is none available.' % (type_,)
).with_traceback(traceback)
'The gp.generate function tried to add '
'a terminal of type {}, but there is'
'none available. {}'.format(type_, traceback)
)
if inspect.isclass(term):
term = term()
expr.append(term)
Expand All @@ -921,11 +922,11 @@ def _generate(self, pset, min_, max_, condition, type_=None):
except IndexError:
_, _, traceback = sys.exc_info()
raise IndexError(
'The gp.generate function tried to add a terminal of '
'type \'%s\', but there is none available.' % (type_,)
).with_traceback(traceback)
'The gp.generate function tried to add '
'a primitive of type {}, but there is'
'none available. {}'.format(type_, traceback)
)
expr.append(prim)
for arg in reversed(prim.args):
stack.append((depth+1, arg))

return expr