Skip to content

Commit db4230a

Browse files
arjoly authored and larsmans committed
TST improve parameter checking + use nose assert_ functions
1 parent defa6c4 commit db4230a

File tree

1 file changed

+52
-34
lines changed

1 file changed

+52
-34
lines changed

sklearn/feature_selection/tests/test_feature_select.py

Lines changed: 52 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,17 @@
66
import numpy as np
77
from scipy import stats, sparse
88

9-
from nose.tools import (assert_equal, assert_almost_equal, assert_raises,
10-
assert_true)
11-
from numpy.testing import assert_array_equal, assert_array_almost_equal
12-
from sklearn.utils.testing import assert_not_in, ignore_warnings
13-
9+
from sklearn.utils.testing import assert_equal
10+
from sklearn.utils.testing import assert_almost_equal
11+
from sklearn.utils.testing import assert_raises
12+
from sklearn.utils.testing import assert_true
13+
from sklearn.utils.testing import assert_array_equal
14+
from sklearn.utils.testing import assert_array_almost_equal
15+
from sklearn.utils.testing import assert_not_in
16+
from sklearn.utils.testing import assert_less
17+
from sklearn.utils.testing import ignore_warnings
1418
from sklearn.utils import safe_mask
19+
1520
from sklearn.datasets.samples_generator import (make_classification,
1621
make_regression)
1722
from sklearn.feature_selection import (chi2, f_classif, f_oneway, f_regression,
@@ -60,11 +65,11 @@ def test_f_classif():
6065

6166
F, pv = f_classif(X, y)
6267
F_sparse, pv_sparse = f_classif(sparse.csr_matrix(X), y)
63-
assert(F > 0).all()
64-
assert(pv > 0).all()
65-
assert(pv < 1).all()
66-
assert(pv[:5] < 0.05).all()
67-
assert(pv[5:] > 1.e-4).all()
68+
assert_true((F > 0).all())
69+
assert_true((pv > 0).all())
70+
assert_true((pv < 1).all())
71+
assert_true((pv[:5] < 0.05).all())
72+
assert_true((pv[5:] > 1.e-4).all())
6873
assert_array_almost_equal(F_sparse, F)
6974
assert_array_almost_equal(pv_sparse, pv)
7075

@@ -78,11 +83,11 @@ def test_f_regression():
7883
shuffle=False, random_state=0)
7984

8085
F, pv = f_regression(X, y)
81-
assert(F > 0).all()
82-
assert(pv > 0).all()
83-
assert(pv < 1).all()
84-
assert(pv[:5] < 0.05).all()
85-
assert(pv[5:] > 1.e-4).all()
86+
assert_true((F > 0).all())
87+
assert_true((pv > 0).all())
88+
assert_true((pv < 1).all())
89+
assert_true((pv[:5] < 0.05).all())
90+
assert_true((pv[5:] > 1.e-4).all())
8691

8792
# again without centering, compare with sparse
8893
F, pv = f_regression(X, y, center=False)
@@ -137,11 +142,11 @@ def test_f_classif_multi_class():
137142
class_sep=10, shuffle=False, random_state=0)
138143

139144
F, pv = f_classif(X, y)
140-
assert(F > 0).all()
141-
assert(pv > 0).all()
142-
assert(pv < 1).all()
143-
assert(pv[:5] < 0.05).all()
144-
assert(pv[5:] > 1.e-5).all()
145+
assert_true((F > 0).all())
146+
assert_true((pv > 0).all())
147+
assert_true((pv < 1).all())
148+
assert_true((pv[:5] < 0.05).all())
149+
assert_true((pv[5:] > 1.e-4).all())
145150

146151

147152
def test_select_percentile_classif():
@@ -316,7 +321,7 @@ def test_select_fwe_classif():
316321
support = univariate_filter.get_support()
317322
gtruth = np.zeros(20)
318323
gtruth[:5] = 1
319-
assert(np.sum(np.abs(support - gtruth)) < 2)
324+
assert_array_almost_equal(support, gtruth)
320325

321326

322327
##############################################################################
@@ -380,8 +385,12 @@ def test_invalid_percentile():
380385
X, y = make_regression(n_samples=10, n_features=20,
381386
n_informative=2, shuffle=False, random_state=0)
382387

383-
assert_raises(ValueError, SelectPercentile(percentile=101).fit, X, y)
384388
assert_raises(ValueError, SelectPercentile(percentile=-1).fit, X, y)
389+
assert_raises(ValueError, SelectPercentile(percentile=101).fit, X, y)
390+
assert_raises(ValueError, GenericUnivariateSelect(mode='percentile',
391+
param=-1).fit, X, y)
392+
assert_raises(ValueError, GenericUnivariateSelect(mode='percentile',
393+
param=101).fit, X, y)
385394

386395

387396
def test_select_kbest_regression():
@@ -422,8 +431,8 @@ def test_select_fpr_regression():
422431
support = univariate_filter.get_support()
423432
gtruth = np.zeros(20)
424433
gtruth[:5] = 1
425-
assert(support[:5] == 1).all()
426-
assert(np.sum(support[5:] == 1) < 3)
434+
assert_array_equal(support[:5], np.ones((5, ), dtype=np.bool))
435+
assert_less(np.sum(support[5:] == 1), 3)
427436

428437

429438
def test_select_fdr_regression():
@@ -463,8 +472,8 @@ def test_select_fwe_regression():
463472
support = univariate_filter.get_support()
464473
gtruth = np.zeros(20)
465474
gtruth[:5] = 1
466-
assert(support[:5] == 1).all()
467-
assert(np.sum(support[5:] == 1) < 2)
475+
assert_array_equal(support[:5], np.ones((5, ), dtype=np.bool))
476+
assert_less(np.sum(support[5:] == 1), 2)
468477

469478

470479
def test_selectkbest_tiebreaking():
@@ -538,9 +547,7 @@ def test_nans():
538547
"""Assert that SelectKBest and SelectPercentile can handle NaNs."""
539548
# First feature has zero variance to confuse f_classif (ANOVA) and
540549
# make it return a NaN.
541-
X = [[0, 1, 0],
542-
[0, -1, -1],
543-
[0, .5, .5]]
550+
X = [[0, 1, 0], [0, -1, -1], [0, .5, .5]]
544551
y = [1, 0, 1]
545552

546553
for select in (SelectKBest(f_classif, 2),
@@ -550,10 +557,21 @@ def test_nans():
550557

551558

552559
def test_score_func_error():
553-
X = [[0, 1, 0],
554-
[0, -1, -1],
555-
[0, .5, .5]]
560+
X = [[0, 1, 0], [0, -1, -1], [0, .5, .5]]
561+
y = [1, 0, 1]
562+
563+
for SelectFeatures in [SelectKBest, SelectPercentile, SelectFwe,
564+
SelectFdr, SelectFpr, GenericUnivariateSelect]:
565+
assert_raises(TypeError, SelectFeatures(score_func=10).fit, X, y)
566+
567+
568+
def test_invalid_k():
569+
X = [[0, 1, 0], [0, -1, -1], [0, .5, .5]]
556570
y = [1, 0, 1]
557571

558-
# test that score-func needs to be a callable
559-
assert_raises(TypeError, SelectKBest(score_func=10).fit, X, y)
572+
assert_raises(ValueError, SelectKBest(k=-1).fit, X, y)
573+
assert_raises(ValueError, SelectKBest(k=4).fit, X, y)
574+
assert_raises(ValueError,
575+
GenericUnivariateSelect(mode='k_best', param=-1).fit, X, y)
576+
assert_raises(ValueError,
577+
GenericUnivariateSelect(mode='k_best', param=4).fit, X, y)

0 commit comments

Comments (0)