diff --git a/ci/.travis_install.sh b/ci/.travis_install.sh index 267d8e601..b3adb8237 100755 --- a/ci/.travis_install.sh +++ b/ci/.travis_install.sh @@ -53,6 +53,8 @@ fi pip install update_checker pip install tqdm +pip install scikit-mdr +pip install skrebate if [[ "$COVERAGE" == "true" ]]; then pip install coverage coveralls diff --git a/tests.py b/tests.py index 21163100d..788cb6654 100644 --- a/tests.py +++ b/tests.py @@ -40,7 +40,7 @@ from datetime import datetime import subprocess -from sklearn.datasets import load_digits, load_boston +from sklearn.datasets import load_digits, load_boston, make_classification from sklearn.model_selection import train_test_split, cross_val_score from deap import creator from tqdm import tqdm @@ -175,7 +175,7 @@ def test_set_params_2(): assert tpot_obj.generations == 3 -def test_lite_params(): +def test_config_dict_params(): """Assert that TPOT uses TPOT's lite dictionary of operators when config_dict is \'TPOT light\' or \'TPOT MDR\'""" tpot_obj = TPOTClassifier(config_dict='TPOT light') assert tpot_obj.config_dict == classifier_config_dict_light @@ -426,6 +426,7 @@ def test_warm_start(): assert tpot_obj._pop == first_pop + def test_fit(): """Assert that the TPOT fit function provides an optimized pipeline""" tpot_obj = TPOTClassifier(random_state=42, population_size=1, offspring_size=2, generations=1, verbosity=0) @@ -434,6 +435,7 @@ def test_fit(): assert isinstance(tpot_obj._optimized_pipeline, creator.Individual) assert not (tpot_obj._start_datetime is None) + def test_fit2(): """Assert that the TPOT fit function provides an optimized pipeline when config_dict is \'TPOT light\'""" tpot_obj = TPOTClassifier(random_state=42, population_size=1, offspring_size=2, generations=1, verbosity=0, config_dict='TPOT light') @@ -442,6 +444,17 @@ def test_fit2(): assert isinstance(tpot_obj._optimized_pipeline, creator.Individual) assert not (tpot_obj._start_datetime is None) + +def test_fit3(): + """Assert that the TPOT fit function provides an optimized pipeline when config_dict is \'TPOT MDR\'""" + X, y = make_classification(n_samples=50, n_features=10, random_state=42, n_classes=2) # binary classification problem + tpot_obj = TPOTClassifier(random_state=42, population_size=1, offspring_size=2, generations=1, verbosity=0, config_dict='TPOT MDR') + tpot_obj.fit(X, y) + + assert isinstance(tpot_obj._optimized_pipeline, creator.Individual) + assert not (tpot_obj._start_datetime is None) + + def testTPOTOperatorClassFactory(): """Assert that the TPOT operators class factory""" test_config_dict = { diff --git a/tpot/config_classifier_mdr.py b/tpot/config_classifier_mdr.py index a90a9a702..402a066d9 100644 --- a/tpot/config_classifier_mdr.py +++ b/tpot/config_classifier_mdr.py @@ -30,7 +30,7 @@ # Classifiers - 'mdr.MDR': { + 'mdr.MDRClassifier': { 'tie_break': [0, 1], 'default_label': [0, 1] },