Skip to content

Commit 2f69574

Browse files
add sample_weight to RidgeClassifier
1 parent 8453daa commit 2f69574

File tree

3 files changed

+41
-6
lines changed

3 files changed

+41
-6
lines changed

doc/whats_new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ Enhancements
7171
It is now possible to ignore one or more labels, such as where
7272
a multiclass problem has a majority class to ignore. By `Joel Nothman`_.
7373

74+
- Add ``sample_weight`` support to :class:`linear_model.RidgeClassifier`.
75+
By `Trevor Stephens`_.
7476

7577
Bug fixes
7678
.........

sklearn/linear_model/ridge.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ def __init__(self, alpha=1.0, fit_intercept=True, normalize=False,
572572
copy_X=copy_X, max_iter=max_iter, tol=tol, solver=solver)
573573
self.class_weight = class_weight
574574

575-
def fit(self, X, y):
575+
def fit(self, X, y, sample_weight=None):
576576
"""Fit Ridge regression model.
577577
578578
Parameters
@@ -583,20 +583,24 @@ def fit(self, X, y):
583583
y : array-like, shape = [n_samples]
584584
Target values
585585
586+
sample_weight : float or numpy array of shape (n_samples,)
587+
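Sample weight. A float applies the same weight to every sample; None is treated as 1. The weights are multiplied by the per-sample weights derived from ``class_weight``.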
588+
586589
Returns
587590
-------
588591
self : returns an instance of self.
589592
"""
593+
if sample_weight is None:
594+
sample_weight = 1.
595+
590596
self._label_binarizer = LabelBinarizer(pos_label=1, neg_label=-1)
591597
Y = self._label_binarizer.fit_transform(y)
592598
if not self._label_binarizer.y_type_.startswith('multilabel'):
593599
y = column_or_1d(y, warn=True)
594600

595-
if self.class_weight:
596-
# get the class weight corresponding to each sample
597-
sample_weight = compute_sample_weight(self.class_weight, y)
598-
else:
599-
sample_weight = None
601+
# modify the sample weights with the corresponding class weight
602+
sample_weight = (sample_weight *
603+
compute_sample_weight(self.class_weight, y))
600604

601605
super(RidgeClassifier, self).fit(X, Y, sample_weight=sample_weight)
602606
return self

sklearn/linear_model/tests/test_ridge.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,35 @@ def test_class_weights():
487487
assert_array_almost_equal(clf.intercept_, clfa.intercept_)
488488

489489

490+
def test_class_weight_vs_sample_weight():
491+
"""Check class_weights resemble sample_weights behavior."""
492+
for clf in (RidgeClassifier, RidgeClassifierCV):
493+
494+
# Iris is balanced, so no effect expected for using 'balanced' weights
495+
clf1 = clf()
496+
clf1.fit(iris.data, iris.target)
497+
clf2 = clf(class_weight='balanced')
498+
clf2.fit(iris.data, iris.target)
499+
assert_almost_equal(clf1.coef_, clf2.coef_)
500+
501+
# Inflate importance of class 1, check against user-defined weights
502+
sample_weight = np.ones(iris.target.shape)
503+
sample_weight[iris.target == 1] *= 100
504+
class_weight = {0: 1., 1: 100., 2: 1.}
505+
clf1 = clf()
506+
clf1.fit(iris.data, iris.target, sample_weight)
507+
clf2 = clf(class_weight=class_weight)
508+
clf2.fit(iris.data, iris.target)
509+
assert_almost_equal(clf1.coef_, clf2.coef_)
510+
511+
# Check that sample_weight and class_weight are multiplicative
512+
clf1 = clf()
513+
clf1.fit(iris.data, iris.target, sample_weight ** 2)
514+
clf2 = clf(class_weight=class_weight)
515+
clf2.fit(iris.data, iris.target, sample_weight)
516+
assert_almost_equal(clf1.coef_, clf2.coef_)
517+
518+
490519
def test_class_weights_cv():
491520
# Test class weights for cross validated ridge classifier.
492521
X = np.array([[-1.0, -1.0], [-1.0, 0], [-.8, -1.0],

0 commit comments

Comments
 (0)