Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Use assert_raises in tests
  • Loading branch information
teaearlgraycold committed Apr 22, 2017
commit 2b8636d6b3d2c8afbc2960b24f15759e0b0d04ea
97 changes: 27 additions & 70 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from sklearn.model_selection import train_test_split, cross_val_score
from deap import creator
from tqdm import tqdm
from nose.tools import assert_raises

# Set up the MNIST data set for testing
mnist_data = load_digits()
Expand Down Expand Up @@ -124,27 +125,24 @@ def test_init_default_scoring():

def test_invaild_score_warning():
"""Assert that the TPOT fit function raises a ValueError when the scoring metrics is not available in SCORERS."""
try:
TPOTClassifier(scoring='balanced_accuray') # typo for balanced_accuracy
assert False
except ValueError:
pass
try:
TPOTClassifier(scoring='balanced_accuracy') # correct one
assert True
except Exception:
assert False
# Mis-spelled scorer
assert_raises(ValueError, TPOTClassifier, scoring='balanced_accuray')
# Correctly spelled
TPOTClassifier(scoring='balanced_accuracy')


def test_invaild_dataset_warning():
"""Assert that the TPOT fit function raises a ValueError when dataset is not in right format."""
tpot_obj = TPOTClassifier(random_state=42, population_size=1, offspring_size=2, generations=1, verbosity=0)
bad_training_classes = training_classes.reshape((1, len(training_classes))) # common mistake in classes
try:
tpot_obj.fit(training_features, bad_training_classes) # typo for balanced_accuracy
assert False
except ValueError:
pass
tpot_obj = TPOTClassifier(
random_state=42,
population_size=1,
offspring_size=2,
generations=1,
verbosity=0
)
# common mistake in classes
bad_training_classes = training_classes.reshape((1, len(training_classes)))
assert_raises(ValueError, tpot_obj.fit, training_features, bad_training_classes)


def test_init_max_time_mins():
Expand Down Expand Up @@ -201,11 +199,7 @@ def test_lite_params():
tpot_obj = TPOTRegressor(config_dict='TPOT light')
assert tpot_obj.config_dict == regressor_config_dict_light

try:
tpot_obj = TPOTRegressor(config_dict='TPOT MDR')
assert False
except TypeError:
assert True
assert_raises(TypeError, TPOTRegressor, config_dict='TPOT MDR')


def test_random_ind():
Expand Down Expand Up @@ -250,12 +244,7 @@ def test_random_ind_2():
def test_score():
"""Assert that the TPOT score function raises a RuntimeError when no optimized pipeline exists."""
tpot_obj = TPOTClassifier()

try:
tpot_obj.score(testing_features, testing_classes)
assert False # Should be unreachable
except RuntimeError:
pass
assert_raises(RuntimeError, tpot_obj.score, testing_features, testing_classes)


def test_score_2():
Expand Down Expand Up @@ -358,12 +347,7 @@ def test_sample_weight_func():
def test_predict():
"""Assert that the TPOT predict function raises a RuntimeError when no optimized pipeline exists."""
tpot_obj = TPOTClassifier()

try:
tpot_obj.predict(testing_features)
assert False # Should be unreachable
except RuntimeError:
pass
assert_raises(RuntimeError, tpot_obj.predict, testing_features)


def test_predict_2():
Expand All @@ -381,7 +365,6 @@ def test_predict_2():
tpot_obj._optimized_pipeline = creator.Individual.from_string(pipeline_string, tpot_obj._pset)
tpot_obj._fitted_pipeline = tpot_obj._toolbox.compile(expr=tpot_obj._optimized_pipeline)
tpot_obj._fitted_pipeline.fit(training_features, training_classes)

result = tpot_obj.predict(testing_features)

assert result.shape == (testing_features.shape[0],)
Expand Down Expand Up @@ -424,16 +407,11 @@ def test_predict_proba2():
tpot_obj._fitted_pipeline.fit(training_features, training_classes)

result = tpot_obj.predict_proba(testing_features)
rows = result.shape[0]
columns = result.shape[1]
rows, columns = result.shape

try:
for i in range(rows):
for j in range(columns):
float_range(result[i][j])
assert True
except Exception:
assert False
for i in range(rows):
for j in range(columns):
float_range(result[i][j])


def test_warm_start():
Expand Down Expand Up @@ -545,12 +523,7 @@ def test_operators():
def test_export():
"""Assert that TPOT's export function throws a RuntimeError when no optimized pipeline exists."""
tpot_obj = TPOTClassifier()

try:
tpot_obj.export("test_export.py")
assert False # Should be unreachable
except RuntimeError:
pass
assert_raises(RuntimeError, tpot_obj.export, "test_export.py")


def test_generate_pipeline_code():
Expand Down Expand Up @@ -813,11 +786,7 @@ def test_gen():

def test_positive_integer():
"""Assert that the TPOT CLI interface's integer parsing throws an exception when n < 0."""
try:
positive_integer('-1')
assert False # Should be unreachable
except Exception:
pass
assert_raises(Exception, positive_integer, '-1')


def test_positive_integer_2():
Expand All @@ -827,11 +796,7 @@ def test_positive_integer_2():

def test_positive_integer_3():
"""Assert that the TPOT CLI interface's integer parsing throws an exception when n is not an integer."""
try:
positive_integer('foobar')
assert False # Should be unreachable
except Exception:
pass
assert_raises(Exception, positive_integer, 'foobar')


def test_float_range():
Expand All @@ -841,17 +806,9 @@ def test_float_range():

def test_float_range_2():
"""Assert that the TPOT CLI interface's float range throws an exception when input it out of range."""
try:
float_range('2.0')
assert False # Should be unreachable
except Exception:
pass
assert_raises(Exception, float_range, '2.0')


def test_float_range_3():
"""Assert that the TPOT CLI interface's float range throws an exception when input is not a float."""
try:
float_range('foobar')
assert False # Should be unreachable
except Exception:
pass
assert_raises(Exception, float_range, 'foobar')