|
26 | 26 |
|
27 | 27 | from ..utils.validation import check_array, check_consistent_length |
28 | 28 | from ..utils.validation import column_or_1d |
| 29 | +from ..externals.six import string_types |
29 | 30 |
|
30 | 31 | import warnings |
31 | 32 |
|
@@ -162,11 +163,12 @@ def mean_absolute_error(y_true, y_pred, |
162 | 163 | y_true, y_pred, multioutput) |
163 | 164 | output_errors = np.average(np.abs(y_pred - y_true), |
164 | 165 | weights=sample_weight, axis=0) |
165 | | - if multioutput == 'raw_values': |
166 | | - return output_errors |
167 | | - elif multioutput == 'uniform_average': |
168 | | - # pass None as weights to np.average: uniform mean |
169 | | - multioutput = None |
| 166 | + if isinstance(multioutput, string_types): |
| 167 | + if multioutput == 'raw_values': |
| 168 | + return output_errors |
| 169 | + elif multioutput == 'uniform_average': |
| 170 | + # pass None as weights to np.average: uniform mean |
| 171 | + multioutput = None |
170 | 172 |
|
171 | 173 | return np.average(output_errors, weights=multioutput) |
172 | 174 |
|
@@ -229,11 +231,12 @@ def mean_squared_error(y_true, y_pred, |
229 | 231 | y_true, y_pred, multioutput) |
230 | 232 | output_errors = np.average((y_true - y_pred) ** 2, axis=0, |
231 | 233 | weights=sample_weight) |
232 | | - if multioutput == 'raw_values': |
233 | | - return output_errors |
234 | | - elif multioutput == 'uniform_average': |
235 | | - # pass None as weights to np.average: uniform mean |
236 | | - multioutput = None |
| 234 | + if isinstance(multioutput, string_types): |
| 235 | + if multioutput == 'raw_values': |
| 236 | + return output_errors |
| 237 | + elif multioutput == 'uniform_average': |
| 238 | + # pass None as weights to np.average: uniform mean |
| 239 | + multioutput = None |
237 | 240 |
|
238 | 241 | return np.average(output_errors, weights=multioutput) |
239 | 242 |
|
@@ -464,20 +467,21 @@ def r2_score(y_true, y_pred, |
464 | 467 | "to 'uniform_average' in 0.18.", |
465 | 468 | DeprecationWarning) |
466 | 469 | multioutput = 'variance_weighted' |
467 | | - if multioutput == 'raw_values': |
468 | | - # return scores individually |
469 | | - return output_scores |
470 | | - elif multioutput == 'uniform_average': |
471 | | - # passing None as weights results is uniform mean |
472 | | - avg_weights = None |
473 | | - elif multioutput == 'variance_weighted': |
474 | | - avg_weights = denominator |
475 | | - # avoid fail on constant y or one-element arrays |
476 | | - if not np.any(nonzero_denominator): |
477 | | - if not np.any(nonzero_numerator): |
478 | | - return 1.0 |
479 | | - else: |
480 | | - return 0.0 |
| 470 | + if isinstance(multioutput, string_types): |
| 471 | + if multioutput == 'raw_values': |
| 472 | + # return scores individually |
| 473 | + return output_scores |
| 474 | + elif multioutput == 'uniform_average': |
| 475 | + # passing None as weights results is uniform mean |
| 476 | + avg_weights = None |
| 477 | + elif multioutput == 'variance_weighted': |
| 478 | + avg_weights = denominator |
| 479 | + # avoid fail on constant y or one-element arrays |
| 480 | + if not np.any(nonzero_denominator): |
| 481 | + if not np.any(nonzero_numerator): |
| 482 | + return 1.0 |
| 483 | + else: |
| 484 | + return 0.0 |
481 | 485 | else: |
482 | 486 | avg_weights = multioutput |
483 | 487 |
|
|
0 commit comments