Commits (33)
53713c9  Fix for issue #6352 (tracer0tong, Feb 17, 2016)
65a2b8f  Fixed codestyle (tracer0tong, Feb 18, 2016)
eb242c2  ENH: FeatureHasher now accepts string values. (devashishd12, Jan 15, 2016)
876f123  Do not ignore files starting with _ in nose (lesteve, Feb 29, 2016)
2e7d9ad  FIX: improve docs of randomized lasso (Mar 7, 2016)
bf81451  Fix consistency in docs and docstring (Mar 7, 2016)
a754e09  Added ref to Bach and improved docs (Mar 7, 2016)
f81e5aa  Try to fix link to pdf (Mar 7, 2016)
3a83071  fix x and y order (Mar 7, 2016)
4eca0c9  updated info for cross_val_score (ohld, Mar 14, 2016)
c2eaf75  Merge pull request #6173 from dsquareindia/featurehasher_fix (MechCoder, Mar 19, 2016)
b64e992  Merge pull request #6542 from ohld/make_scorer-link (glouppe, Mar 19, 2016)
e2e6bde  Merge pull request #6498 from clamus/rand-lasso-fix-6493 (glouppe, Mar 19, 2016)
9691824  Merge pull request #6466 from lesteve/nose-ignore-files-tweak (glouppe, Mar 19, 2016)
7580746  Fix broken link in ABOUT (bryandeng, Mar 20, 2016)
bd6b313  Merge pull request #6565 from bryandeng/doc-link (agramfort, Mar 20, 2016)
549474d  [gardening] Fix NameError ("estimator" not defined). Remove unused va… (practicalswift, Mar 20, 2016)
e228581  Merge pull request #6566 from practicalswift/fix-nameerror-and-remove… (jnothman, Mar 21, 2016)
e9492b7  LabelBinarizer single label case now works for sparse and dense case (devashishd12, Jan 24, 2016)
528533d  MAINT: Simplify n_features_to_select in RFECV (MechCoder, Mar 21, 2016)
945cb7e  Merge pull request #6221 from dsquareindia/LabelBinarizer_fix (MechCoder, Mar 21, 2016)
54af09e  Fixing typos in logistic regression docs (hlin117, Mar 21, 2016)
146f461  Merge pull request #6575 from hlin117/logregdocs (agramfort, Mar 22, 2016)
07a6433  Fix typo in html target (Mar 22, 2016)
b3c2219  Add the possibility to add prior to Gaussian Naive Bayes (Jan 18, 2016)
5a046c7  Update whatsnew (MechCoder, Mar 22, 2016)
5d92bd5  Merge pull request #6579 from nlathia/issue-6541 (TomDLT, Mar 23, 2016)
65b570b  Update scorer.py (lizsz, Mar 19, 2016)
56d625f  Merge pull request #6569 from MechCoder/minor (TomDLT, Mar 23, 2016)
22d7cd5  Make dump_svmlight_file support sparse y (yenchenlin, Feb 18, 2016)
eed5fc5  Merge pull request #6395 from yenchenlin1994/make-dump_svmlight_file-… (TomDLT, Mar 23, 2016)
afc058f  Merge pull request #6376 from tracer0tong/issue_6352 (TomDLT, Mar 24, 2016)
612cd9e  ENH: Support data centering in LogisticRegression (kernc, Mar 17, 2016)
ENH: Support data centering in LogisticRegression
kernc committed Mar 24, 2016
commit 612cd9ef1a247821ae3e89a14389c326fd70a487
37 changes: 22 additions & 15 deletions sklearn/linear_model/base.py
@@ -142,7 +142,7 @@ def center_data(X, y, fit_intercept, normalize=False, copy=True,
 
 
 def _preprocess_data(X, y, fit_intercept, normalize=False, copy=True,
-                     sample_weight=None, return_mean=False):
+                     sample_weight=None, return_mean=False, center_y=True):
     """
     Centers data to have mean zero along axis 0. If fit_intercept=False or if
     the X is a sparse matrix, no centering is done, but normalization can still
@@ -196,8 +196,11 @@ def _preprocess_data(X, y, fit_intercept, normalize=False, copy=True,
                                              return_norm=True)
         else:
             X_scale = np.ones(X.shape[1])
-        y_offset = np.average(y, axis=0, weights=sample_weight)
-        y = y - y_offset
+
+        y_offset = 0
+        if center_y:
+            y_offset = np.average(y, axis=0, weights=sample_weight)
+            y = y - y_offset
     else:
         X_offset = np.zeros(X.shape[1])
         X_scale = np.ones(X.shape[1])
@@ -222,7 +225,22 @@ def _rescale_data(X, y, sample_weight):
     return X, y
 
 
-class LinearModel(six.with_metaclass(ABCMeta, BaseEstimator)):
+class LinearNormalizerMixin(object):
+    _preprocess_data = staticmethod(_preprocess_data)
+
+    def _set_intercept(self, X_offset, y_offset, X_scale, intercept_=None):
+        """Set the intercept_
+        """
+        if self.fit_intercept:
+            self.coef_ = self.coef_ / X_scale
+            if intercept_ is None:
+                self.intercept_ = y_offset - np.dot(X_offset, self.coef_.T)
+        else:
+            self.intercept_ = 0 if intercept_ is None else intercept_
+
+
+class LinearModel(six.with_metaclass(ABCMeta, BaseEstimator),
+                  LinearNormalizerMixin):
     """Base class for Linear Models"""
 
     @abstractmethod
@@ -267,17 +285,6 @@ def predict(self, X):
         """
         return self._decision_function(X)
 
-    _preprocess_data = staticmethod(_preprocess_data)
-
-    def _set_intercept(self, X_offset, y_offset, X_scale):
-        """Set the intercept_
-        """
-        if self.fit_intercept:
-            self.coef_ = self.coef_ / X_scale
-            self.intercept_ = y_offset - np.dot(X_offset, self.coef_.T)
-        else:
-            self.intercept_ = 0.
-
 
 # XXX Should this derive from LinearModel? It should be a mixin, not an ABC.
 # Maybe the n_features checking can be moved to LinearModel.
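Taken together, the base.py changes split preprocessing from intercept recovery: _preprocess_data gains a center_y flag so classifiers can center (and optionally normalize) X without shifting y, and _set_intercept moves into the shared LinearNormalizerMixin with an optional intercept_ override. A minimal sketch of the dense, unweighted path (illustrative only; preprocess_sketch is not the scikit-learn helper itself):

import numpy as np

def preprocess_sketch(X, y, fit_intercept=True, center_y=True):
    # Mirrors the dense branch of _preprocess_data: no sample weights,
    # no normalization (X_scale stays all ones).
    X = np.asarray(X, dtype=float).copy()
    if fit_intercept:
        X_offset = np.average(X, axis=0)   # per-feature mean
        X -= X_offset                      # X is centered whenever an intercept is fit
    else:
        X_offset = np.zeros(X.shape[1])
    X_scale = np.ones(X.shape[1])
    y_offset = 0
    if center_y:                           # regression path; classifiers pass center_y=False
        y_offset = np.average(y, axis=0)
        y = y - y_offset
    return X, y, X_offset, y_offset, X_scale

# _set_intercept then maps a solution fitted on the preprocessed data back
# to the original scale:
#     coef_ = coef_ / X_scale
#     intercept_ = y_offset - X_offset @ coef_.T   # when no solver intercept is passed in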
38 changes: 32 additions & 6 deletions sklearn/linear_model/logistic.py
@@ -16,7 +16,8 @@
 import numpy as np
 from scipy import optimize, sparse
 
-from .base import LinearClassifierMixin, SparseCoefMixin, BaseEstimator
+from .base import (LinearClassifierMixin, SparseCoefMixin, BaseEstimator,
+                   LinearNormalizerMixin)
 from .sag import sag_solver
 from ..feature_selection.from_model import _LearntSelectorMixin
 from ..preprocessing import LabelEncoder, LabelBinarizer
@@ -949,7 +950,8 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10,
 
 
 class LogisticRegression(BaseEstimator, LinearClassifierMixin,
-                         _LearntSelectorMixin, SparseCoefMixin):
+                         _LearntSelectorMixin, SparseCoefMixin,
+                         LinearNormalizerMixin):
     """Logistic Regression (aka logit, MaxEnt) classifier.
 
     In the multiclass case, the training algorithm uses the one-vs-rest (OvR)
@@ -988,7 +990,9 @@ class LogisticRegression(BaseEstimator, LinearClassifierMixin,
 
     fit_intercept : bool, default: True
         Specifies if a constant (a.k.a. bias or intercept) should be
-        added to the decision function.
+        added to the decision function. If set to false, no intercept
+        will be used in calculations (e.g. data is expected to be already
+        centered).
 
     intercept_scaling : float, default 1.
         Useful only when the solver 'liblinear' is used
@@ -1003,6 +1007,19 @@ class LogisticRegression(BaseEstimator, LinearClassifierMixin,
         To lessen the effect of regularization on synthetic feature weight
         (and therefore on the intercept) intercept_scaling has to be increased.
 
+    normalize : boolean, optional, default: False
+        If True, the regressors X will be normalized before regression.
+        This parameter is ignored when `fit_intercept` is set to False.
+        When the regressors are normalized, note that this makes the
+        hyperparameters learnt more robust and almost independent of the number
+        of samples. The same property is not valid for standardized data.
+        However, if you wish to standardize, please use
+        `preprocessing.StandardScaler` before calling `fit` on an estimator
+        with `normalize=False`.
+
+    copy_X : boolean, optional, default: True
+        If True, X will be copied; else, it may be overwritten.
+
     class_weight : dict or 'balanced', default: None
         Weights associated with classes in the form ``{class_label: weight}``.
         If not given, all classes are supposed to have weight one.
@@ -1116,16 +1133,19 @@ class LogisticRegression(BaseEstimator, LinearClassifierMixin,
     """
 
     def __init__(self, penalty='l2', dual=False, tol=1e-4, C=1.0,
-                 fit_intercept=True, intercept_scaling=1, class_weight=None,
-                 random_state=None, solver='liblinear', max_iter=100,
-                 multi_class='ovr', verbose=0, warm_start=False, n_jobs=1):
+                 fit_intercept=True, intercept_scaling=1, normalize=False,
+                 copy_X=True, class_weight=None, random_state=None,
+                 solver='liblinear', max_iter=100, multi_class='ovr',
+                 verbose=0, warm_start=False, n_jobs=1):
 
         self.penalty = penalty
         self.dual = dual
         self.tol = tol
         self.C = C
         self.fit_intercept = fit_intercept
         self.intercept_scaling = intercept_scaling
+        self.normalize = normalize
+        self.copy_X = copy_X
         self.class_weight = class_weight
         self.random_state = random_state
         self.solver = solver
@@ -1178,13 +1198,18 @@ def fit(self, X, y, sample_weight=None):
         _check_solver_option(self.solver, self.multi_class, self.penalty,
                              self.dual)
 
+        X, _, X_offset, y_offset, X_scale = self._preprocess_data(
+            X, y, self.fit_intercept, self.normalize, self.copy_X,
+            sample_weight=sample_weight, center_y=False)
+
         if self.solver == 'liblinear':
            self.coef_, self.intercept_, n_iter_ = _fit_liblinear(
                 X, y, self.C, self.fit_intercept, self.intercept_scaling,
                 self.class_weight, self.penalty, self.dual, self.verbose,
                 self.max_iter, self.tol, self.random_state,
                 sample_weight=sample_weight)
             self.n_iter_ = np.array([n_iter_])
+            self._set_intercept(X_offset, y_offset, X_scale, self.intercept_)
             return self
 
         if self.solver == 'sag':
@@ -1253,6 +1278,7 @@ def fit(self, X, y, sample_weight=None):
         if self.fit_intercept:
             self.intercept_ = self.coef_[:, -1]
             self.coef_ = self.coef_[:, :-1]
+        self._set_intercept(X_offset, y_offset, X_scale, self.intercept_)
 
         return self
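For context, here is how the proposed API would be exercised. Note that the normalize and copy_X parameters of LogisticRegression exist only on this pull request's branch and were never merged into a scikit-learn release, so this sketch runs only against the PR's code:

import numpy as np
from sklearn.linear_model import LogisticRegression  # the PR branch's version

# Features on very different scales; centering/normalization happens inside fit().
rng = np.random.RandomState(0)
X = rng.randn(200, 3) * np.array([1.0, 10.0, 100.0]) + 5.0
y = (X[:, 0] + 0.1 * X[:, 1] > 5.0).astype(int)

clf = LogisticRegression(fit_intercept=True, normalize=True, copy_X=True)
clf.fit(X, y)

# fit() first centers (and, with normalize=True, scales) X via
# _preprocess_data(..., center_y=False); after the solver runs,
# _set_intercept(X_offset, y_offset, X_scale, self.intercept_) rescales
# coef_ by X_scale so the reported coefficients refer to the raw features.
print(clf.coef_, clf.intercept_)
print(clf.predict(X[:5]))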