Skip to content

Commit ad16bb4

Browse files
reiinakanojnothman
authored andcommitted
[MRG+1] Fix cross_val_predict behavior for binary classification in decision_function (Fixes scikit-learn#9589) (scikit-learn#9593)
* fix cross_val_predict for binary classification in decision_function * Add unit tests * Add unit tests * Add unit tests * better fix * fix conflict * fix broken * only calculate n_classes if one of 'decision_function', 'predict_proba', 'predict_log_proba' * add test for SVC ovo in cross_val_predict * flake8 fix * fix case of ovo and imbalanced folds for binary classification * change assert_raises to assert_raise_message for ovo case * fix flake8 linetoo long * add comments and clearer tests * improve comments and error message for OvO * fix .format error with L * use assert_raises_regex for better error message * raise error in decision_function special cases. change predict_log_proba missing classes to minimum numpy value * fix broken tests due to special cases of decision_function * add modified test for decision_function behavior that does not trigger edge cases * fix typos * fix typos * escape regex . * escape regex . * address comments. one unaddressed comment * simplify code * flake * wrong classes range * address comments. adjust error message * add warning * change warning to runtimewarning * add test for the warning * Use assert_warns_message rather than assert_warns Other minor fixes * Note on class-absent replacement values * Improve error message
1 parent 2a816fb commit ad16bb4

File tree

2 files changed

+147
-15
lines changed

2 files changed

+147
-15
lines changed

sklearn/model_selection/_validation.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,15 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
644644
predictions : ndarray
645645
This is the result of calling ``method``
646646
647+
Notes
648+
-----
649+
In the case that one or more classes are absent in a training portion, a
650+
default score needs to be assigned to all instances for that class if
651+
``method`` produces columns per class, as in {'decision_function',
652+
'predict_proba', 'predict_log_proba'}. For ``predict_proba`` this value is
653+
0. In order to ensure finite output, we approximate negative infinity by
654+
the minimum finite float value for the dtype in other cases.
655+
647656
Examples
648657
--------
649658
>>> from sklearn import datasets, linear_model
@@ -746,12 +755,49 @@ def _fit_and_predict(estimator, X, y, train, test, verbose, fit_params,
746755
predictions = func(X_test)
747756
if method in ['decision_function', 'predict_proba', 'predict_log_proba']:
748757
n_classes = len(set(y))
749-
predictions_ = np.zeros((_num_samples(X_test), n_classes))
750-
if method == 'decision_function' and len(estimator.classes_) == 2:
751-
predictions_[:, estimator.classes_[-1]] = predictions
752-
else:
753-
predictions_[:, estimator.classes_] = predictions
754-
predictions = predictions_
758+
if n_classes != len(estimator.classes_):
759+
recommendation = (
760+
'To fix this, use a cross-validation '
761+
'technique resulting in properly '
762+
'stratified folds')
763+
warnings.warn('Number of classes in training fold ({}) does '
764+
'not match total number of classes ({}). '
765+
'Results may not be appropriate for your use case. '
766+
'{}'.format(len(estimator.classes_),
767+
n_classes, recommendation),
768+
RuntimeWarning)
769+
if method == 'decision_function':
770+
if (predictions.ndim == 2 and
771+
predictions.shape[1] != len(estimator.classes_)):
772+
# This handles the case when the shape of predictions
773+
# does not match the number of classes used to train
774+
# it with. This case is found when sklearn.svm.SVC is
775+
# set to `decision_function_shape='ovo'`.
776+
raise ValueError('Output shape {} of {} does not match '
777+
'number of classes ({}) in fold. '
778+
'Irregular decision_function outputs '
779+
'are not currently supported by '
780+
'cross_val_predict'.format(
781+
predictions.shape, method,
782+
len(estimator.classes_),
783+
recommendation))
784+
if len(estimator.classes_) <= 2:
785+
# In this special case, `predictions` contains a 1D array.
786+
raise ValueError('Only {} class/es in training fold, this '
787+
'is not supported for decision_function '
788+
'with imbalanced folds. {}'.format(
789+
len(estimator.classes_),
790+
recommendation))
791+
792+
float_min = np.finfo(predictions.dtype).min
793+
default_values = {'decision_function': float_min,
794+
'predict_log_proba': float_min,
795+
'predict_proba': 0}
796+
predictions_for_all_classes = np.full((_num_samples(predictions),
797+
n_classes),
798+
default_values[method])
799+
predictions_for_all_classes[:, estimator.classes_] = predictions
800+
predictions = predictions_for_all_classes
755801
return predictions, test
756802

757803

sklearn/model_selection/tests/test_validation.py

Lines changed: 95 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from sklearn.utils.testing import assert_array_almost_equal
2525
from sklearn.utils.testing import assert_array_equal
2626
from sklearn.utils.testing import assert_warns
27+
from sklearn.utils.testing import assert_warns_message
2728
from sklearn.utils.mocking import CheckingClassifier, MockDataFrame
2829

2930
from sklearn.model_selection import cross_val_score
@@ -44,6 +45,7 @@
4445
from sklearn.datasets import make_regression
4546
from sklearn.datasets import load_boston
4647
from sklearn.datasets import load_iris
48+
from sklearn.datasets import load_digits
4749
from sklearn.metrics import explained_variance_score
4850
from sklearn.metrics import make_scorer
4951
from sklearn.metrics import accuracy_score
@@ -54,7 +56,7 @@
5456
from sklearn.metrics.scorer import check_scoring
5557

5658
from sklearn.linear_model import Ridge, LogisticRegression, SGDClassifier
57-
from sklearn.linear_model import PassiveAggressiveClassifier
59+
from sklearn.linear_model import PassiveAggressiveClassifier, RidgeClassifier
5860
from sklearn.neighbors import KNeighborsClassifier
5961
from sklearn.svm import SVC
6062
from sklearn.cluster import KMeans
@@ -800,6 +802,89 @@ def split(self, X, y=None, groups=None):
800802

801803
assert_raises(ValueError, cross_val_predict, est, X, y, cv=BadCV())
802804

805+
X, y = load_iris(return_X_y=True)
806+
807+
warning_message = ('Number of classes in training fold (2) does '
808+
'not match total number of classes (3). '
809+
'Results may not be appropriate for your use case.')
810+
assert_warns_message(RuntimeWarning, warning_message,
811+
cross_val_predict, LogisticRegression(),
812+
X, y, method='predict_proba', cv=KFold(2))
813+
814+
815+
def test_cross_val_predict_decision_function_shape():
816+
X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
817+
818+
preds = cross_val_predict(LogisticRegression(), X, y,
819+
method='decision_function')
820+
assert_equal(preds.shape, (50,))
821+
822+
X, y = load_iris(return_X_y=True)
823+
824+
preds = cross_val_predict(LogisticRegression(), X, y,
825+
method='decision_function')
826+
assert_equal(preds.shape, (150, 3))
827+
828+
# This specifically tests imbalanced splits for binary
829+
# classification with decision_function. This is only
830+
# applicable to classifiers that can be fit on a single
831+
# class.
832+
X = X[:100]
833+
y = y[:100]
834+
assert_raise_message(ValueError,
835+
'Only 1 class/es in training fold, this'
836+
' is not supported for decision_function'
837+
' with imbalanced folds. To fix '
838+
'this, use a cross-validation technique '
839+
'resulting in properly stratified folds',
840+
cross_val_predict, RidgeClassifier(), X, y,
841+
method='decision_function', cv=KFold(2))
842+
843+
X, y = load_digits(return_X_y=True)
844+
est = SVC(kernel='linear', decision_function_shape='ovo')
845+
846+
preds = cross_val_predict(est,
847+
X, y,
848+
method='decision_function')
849+
assert_equal(preds.shape, (1797, 45))
850+
851+
ind = np.argsort(y)
852+
X, y = X[ind], y[ind]
853+
assert_raises_regex(ValueError,
854+
'Output shape \(599L?, 21L?\) of decision_function '
855+
'does not match number of classes \(7\) in fold. '
856+
'Irregular decision_function .*',
857+
cross_val_predict, est, X, y,
858+
cv=KFold(n_splits=3), method='decision_function')
859+
860+
861+
def test_cross_val_predict_predict_proba_shape():
862+
X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
863+
864+
preds = cross_val_predict(LogisticRegression(), X, y,
865+
method='predict_proba')
866+
assert_equal(preds.shape, (50, 2))
867+
868+
X, y = load_iris(return_X_y=True)
869+
870+
preds = cross_val_predict(LogisticRegression(), X, y,
871+
method='predict_proba')
872+
assert_equal(preds.shape, (150, 3))
873+
874+
875+
def test_cross_val_predict_predict_log_proba_shape():
876+
X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
877+
878+
preds = cross_val_predict(LogisticRegression(), X, y,
879+
method='predict_log_proba')
880+
assert_equal(preds.shape, (50, 2))
881+
882+
X, y = load_iris(return_X_y=True)
883+
884+
preds = cross_val_predict(LogisticRegression(), X, y,
885+
method='predict_log_proba')
886+
assert_equal(preds.shape, (150, 3))
887+
803888

804889
def test_cross_val_predict_input_types():
805890
iris = load_iris()
@@ -1241,21 +1326,22 @@ def get_expected_predictions(X, y, cv, classes, est, method):
12411326
est.fit(X[train], y[train])
12421327
expected_predictions_ = func(X[test])
12431328
# To avoid 2 dimensional indexing
1244-
exp_pred_test = np.zeros((len(test), classes))
1245-
if method is 'decision_function' and len(est.classes_) == 2:
1246-
exp_pred_test[:, est.classes_[-1]] = expected_predictions_
1329+
if method is 'predict_proba':
1330+
exp_pred_test = np.zeros((len(test), classes))
12471331
else:
1248-
exp_pred_test[:, est.classes_] = expected_predictions_
1332+
exp_pred_test = np.full((len(test), classes),
1333+
np.finfo(expected_predictions.dtype).min)
1334+
exp_pred_test[:, est.classes_] = expected_predictions_
12491335
expected_predictions[test] = exp_pred_test
12501336

12511337
return expected_predictions
12521338

12531339

12541340
def test_cross_val_predict_class_subset():
12551341

1256-
X = np.arange(8).reshape(4, 2)
1257-
y = np.array([0, 0, 1, 2])
1258-
classes = 3
1342+
X = np.arange(200).reshape(100, 2)
1343+
y = np.array([x//10 for x in range(100)])
1344+
classes = 10
12591345

12601346
kfold3 = KFold(n_splits=3)
12611347
kfold4 = KFold(n_splits=4)
@@ -1283,7 +1369,7 @@ def test_cross_val_predict_class_subset():
12831369
assert_array_almost_equal(expected_predictions, predictions)
12841370

12851371
# Testing unordered labels
1286-
y = [1, 1, -4, 6]
1372+
y = shuffle(np.repeat(range(10), 10), random_state=0)
12871373
predictions = cross_val_predict(est, X, y, method=method,
12881374
cv=kfold3)
12891375
y = le.fit_transform(y)

0 commit comments

Comments
 (0)