Changes from 1 commit

Commits (33)
53713c9
Fix for issue #6352
tracer0tong Feb 17, 2016
65a2b8f
Fixed codestyle
tracer0tong Feb 18, 2016
eb242c2
ENH: FeatureHasher now accepts string values.
devashishd12 Jan 15, 2016
876f123
Do not ignore files starting with _ in nose
lesteve Feb 29, 2016
2e7d9ad
FIX: improve docs of randomized lasso
Mar 7, 2016
bf81451
Fix consistency in docs and docstring
Mar 7, 2016
a754e09
Added ref to Bach and improved docs
Mar 7, 2016
f81e5aa
Try to fix link to pdf
Mar 7, 2016
3a83071
fix x and y order
Mar 7, 2016
4eca0c9
updated info for cross_val_score
ohld Mar 14, 2016
c2eaf75
Merge pull request #6173 from dsquareindia/featurehasher_fix
MechCoder Mar 19, 2016
b64e992
Merge pull request #6542 from ohld/make_scorer-link
glouppe Mar 19, 2016
e2e6bde
Merge pull request #6498 from clamus/rand-lasso-fix-6493
glouppe Mar 19, 2016
9691824
Merge pull request #6466 from lesteve/nose-ignore-files-tweak
glouppe Mar 19, 2016
7580746
Fix broken link in ABOUT
bryandeng Mar 20, 2016
bd6b313
Merge pull request #6565 from bryandeng/doc-link
agramfort Mar 20, 2016
549474d
[gardening] Fix NameError ("estimator" not defined). Remove unused va…
practicalswift Mar 20, 2016
e228581
Merge pull request #6566 from practicalswift/fix-nameerror-and-remove…
jnothman Mar 21, 2016
e9492b7
LabelBinarizer single label case now works for sparse and dense case
devashishd12 Jan 24, 2016
528533d
MAINT: Simplify n_features_to_select in RFECV
MechCoder Mar 21, 2016
945cb7e
Merge pull request #6221 from dsquareindia/LabelBinarizer_fix
MechCoder Mar 21, 2016
54af09e
Fixing typos in logistic regression docs
hlin117 Mar 21, 2016
146f461
Merge pull request #6575 from hlin117/logregdocs
agramfort Mar 22, 2016
07a6433
Fix typo in html target
Mar 22, 2016
b3c2219
Add the possibility to add prior to Gaussian Naive Bayes
Jan 18, 2016
5a046c7
Update whatsnew
MechCoder Mar 22, 2016
5d92bd5
Merge pull request #6579 from nlathia/issue-6541
TomDLT Mar 23, 2016
65b570b
Update scorer.py
lizsz Mar 19, 2016
56d625f
Merge pull request #6569 from MechCoder/minor
TomDLT Mar 23, 2016
22d7cd5
Make dump_svmlight_file support sparse y
yenchenlin Feb 18, 2016
eed5fc5
Merge pull request #6395 from yenchenlin1994/make-dump_svmlight_file-…
TomDLT Mar 23, 2016
afc058f
Merge pull request #6376 from tracer0tong/issue_6352
TomDLT Mar 24, 2016
612cd9e
ENH: Support data centering in LogisticRegression
kernc Mar 17, 2016
Add the possibility to add prior to Gaussian Naive Bayes
Added the following things:
* Test when the sum of the priors is not equal to one
* Test the prediction in case of a large bias towards one class
* Test class_prior_ explicitly
* Move the function that updates the class prior into the GaussianNB class
* Remove the updating of the class prior before actually computing the mean and variance

Address comments for PR scikit-learn#6180 - Correct the documentation

Address comments for PR scikit-learn#6180 - Improve the class prior initialisation and updating
We modify the code to:
* Initialise self.class_prior_ for the different possibilities (class_prior given or not, fit_prior True or False)
* Update self.class_prior_ only when no class_prior is given and fit_prior is True

Address comments PR scikit-learn#6180 - Remove useless line

Fix the file according to PEP8 conventions

Update the API to have only class_prior in GaussianNB

Correct prior fitting to use sample counts rather than the number of classes

Change the name of priors and correct the division-by-zero warning

Update the API

Remove functions which were called only once

Add an additional test for a part of GaussianNB that was not covered

Correct doc formatting

Correct spelling
Guillaume Lemaitre authored and MechCoder committed Mar 22, 2016
commit b3c22194b1db90fb6c0effe697fe2ae4d65f7318
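For readers skimming the diff, a minimal sketch of the behaviour this commit adds (the toy X and y below are illustrative, not taken from the PR):

    import numpy as np
    from sklearn.naive_bayes import GaussianNB

    X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]])
    y = np.array([1, 1, 1, 2, 2, 2])

    # Default: priors=None, so the class prior is estimated from the data.
    clf = GaussianNB().fit(X, y)
    print(clf.class_prior_)   # [0.5 0.5] for this balanced toy set

    # Explicit prior: class_prior_ is fixed and never re-estimated from X, y.
    clf = GaussianNB(priors=[0.3, 0.7]).fit(X, y)
    print(clf.class_prior_)   # [0.3 0.7]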
61 changes: 48 additions & 13 deletions sklearn/naive_bayes.py
@@ -115,6 +115,12 @@ class GaussianNB(BaseNB):

Read more in the :ref:`User Guide <gaussian_naive_bayes>`.

Parameters
----------
priors : array-like, shape (n_classes,)
Prior probabilities of the classes. If specified, the priors are not
adjusted according to the data.

Attributes
----------
class_prior_ : array, shape (n_classes,)
@@ -137,16 +143,19 @@ class GaussianNB(BaseNB):
>>> from sklearn.naive_bayes import GaussianNB
>>> clf = GaussianNB()
>>> clf.fit(X, Y)
GaussianNB()
GaussianNB(priors=None)
>>> print(clf.predict([[-0.8, -1]]))
[1]
>>> clf_pf = GaussianNB()
>>> clf_pf.partial_fit(X, Y, np.unique(Y))
GaussianNB()
GaussianNB(priors=None)
>>> print(clf_pf.predict([[-0.8, -1]]))
[1]
"""

def __init__(self, priors=None):
self.priors = priors

def fit(self, X, y, sample_weight=None):
"""Fit Gaussian Naive Bayes according to X, y

@@ -341,8 +350,29 @@ def _partial_fit(self, X, y, classes=None, _refit=False,
n_classes = len(self.classes_)
self.theta_ = np.zeros((n_classes, n_features))
self.sigma_ = np.zeros((n_classes, n_features))
self.class_prior_ = np.zeros(n_classes)
self.class_count_ = np.zeros(n_classes)

self.class_count_ = np.zeros(n_classes, dtype=np.float64)

# Initialise the class prior
n_classes = len(self.classes_)
# Take into account the priors
if self.priors is not None:
priors = np.asarray(self.priors)
# Check that the provided priors match the number of classes
if len(priors) != n_classes:
raise ValueError('Number of priors must match number of'
' classes.')
# Check that the sum is 1
if priors.sum() != 1.0:
raise ValueError('The sum of the priors should be 1.')
# Check that the priors are non-negative
if (priors < 0).any():
raise ValueError('Priors must be non-negative.')
self.class_prior_ = priors
else:
# Initialize the priors to zeros for each class
self.class_prior_ = np.zeros(len(self.classes_),
dtype=np.float64)
else:
if X.shape[1] != self.theta_.shape[1]:
msg = "Number of features %d does not match previous data %d."
@@ -380,7 +410,12 @@ def _partial_fit(self, X, y, classes=None, _refit=False,
self.class_count_[i] += N_i

self.sigma_[:, :] += epsilon
self.class_prior_[:] = self.class_count_ / np.sum(self.class_count_)

# Update only if no priors are provided
if self.priors is None:
# Empirical prior, with sample_weight taken into account
self.class_prior_ = self.class_count_ / self.class_count_.sum()

return self

def _joint_log_likelihood(self, X):
@@ -417,8 +452,8 @@ def _update_class_log_prior(self, class_prior=None):
self.class_log_prior_ = np.log(class_prior)
elif self.fit_prior:
# empirical prior, with sample_weight taken into account
self.class_log_prior_ = (np.log(self.class_count_)
- np.log(self.class_count_.sum()))
self.class_log_prior_ = (np.log(self.class_count_) -
np.log(self.class_count_.sum()))
else:
self.class_log_prior_ = np.zeros(n_classes) - np.log(n_classes)

@@ -661,16 +696,16 @@ def _update_feature_log_prob(self):
smoothed_fc = self.feature_count_ + self.alpha
smoothed_cc = smoothed_fc.sum(axis=1)

self.feature_log_prob_ = (np.log(smoothed_fc)
- np.log(smoothed_cc.reshape(-1, 1)))
self.feature_log_prob_ = (np.log(smoothed_fc) -
np.log(smoothed_cc.reshape(-1, 1)))

def _joint_log_likelihood(self, X):
"""Calculate the posterior log probability of the samples X"""
check_is_fitted(self, "classes_")

X = check_array(X, accept_sparse='csr')
return (safe_sparse_dot(X, self.feature_log_prob_.T)
+ self.class_log_prior_)
return (safe_sparse_dot(X, self.feature_log_prob_.T) +
self.class_log_prior_)


class BernoulliNB(BaseDiscreteNB):
@@ -763,8 +798,8 @@ def _update_feature_log_prob(self):
smoothed_fc = self.feature_count_ + self.alpha
smoothed_cc = self.class_count_ + self.alpha * 2

self.feature_log_prob_ = (np.log(smoothed_fc)
- np.log(smoothed_cc.reshape(-1, 1)))
self.feature_log_prob_ = (np.log(smoothed_fc) -
np.log(smoothed_cc.reshape(-1, 1)))

def _joint_log_likelihood(self, X):
"""Calculate the posterior log probability of the samples X"""
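The three validation branches in _partial_fit are easiest to read off from the errors they raise; a hedged sketch using the same toy X, y as above (each call raises, so run them one at a time):

    from sklearn.naive_bayes import GaussianNB

    # Wrong number of priors for a two-class problem.
    GaussianNB(priors=[0.25, 0.25, 0.25, 0.25]).fit(X, y)  # ValueError

    # Priors that do not sum to 1.
    GaussianNB(priors=[2.0, 1.0]).fit(X, y)                # ValueError

    # Negative prior (sums to 1, so only the sign check trips).
    GaussianNB(priors=[-1.0, 2.0]).fit(X, y)               # ValueError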
4 changes: 2 additions & 2 deletions sklearn/pipeline.py
@@ -377,10 +377,10 @@ def make_pipeline(*steps):
--------
>>> from sklearn.naive_bayes import GaussianNB
>>> from sklearn.preprocessing import StandardScaler
>>> make_pipeline(StandardScaler(), GaussianNB()) # doctest: +NORMALIZE_WHITESPACE
>>> make_pipeline(StandardScaler(), GaussianNB(priors=None)) # doctest: +NORMALIZE_WHITESPACE
Pipeline(steps=[('standardscaler',
StandardScaler(copy=True, with_mean=True, with_std=True)),
('gaussiannb', GaussianNB())])
('gaussiannb', GaussianNB(priors=None))])

Returns
-------
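The pipeline doctest changes only because the estimator repr now prints the new constructor argument; a quick way to confirm this via the standard estimator API (not something this PR adds):

    from sklearn.naive_bayes import GaussianNB

    print(GaussianNB())              # GaussianNB(priors=None)
    print(GaussianNB().get_params()) # {'priors': None}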
58 changes: 58 additions & 0 deletions sklearn/tests/test_naive_bayes.py
@@ -94,6 +94,64 @@ def test_gnb_sample_weight():
assert_array_almost_equal(clf_dupl.sigma_, clf_sw.sigma_)


def test_gnb_neg_priors():
"""Test whether an error is raised in case of negative priors"""
clf = GaussianNB(priors=np.array([-1., 2.]))
assert_raises(ValueError, clf.fit, X, y)


def test_gnb_priors():
"""Test whether the class prior override is properly used"""
clf = GaussianNB(priors=np.array([0.3, 0.7])).fit(X, y)
assert_array_almost_equal(clf.predict_proba([[-0.1, -0.1]]),
np.array([[0.825303662161683,
0.174696337838317]]), 8)
assert_array_equal(clf.class_prior_, np.array([0.3, 0.7]))


def test_gnb_wrong_nb_priors():
""" Test whether an error is raised if the number of prior is different
from the number of class"""
clf = GaussianNB(priors=np.array([.25, .25, .25, .25]))
assert_raises(ValueError, clf.fit, X, y)


def test_gnb_prior_greater_one():
"""Test if an error is raised if the sum of prior greater than one"""
clf = GaussianNB(priors=np.array([2., 1.]))
assert_raises(ValueError, clf.fit, X, y)


def test_gnb_prior_large_bias():
"""Test if good prediction when class prior favor largely one class"""
clf = GaussianNB(priors=np.array([0.01, 0.99]))
clf.fit(X, y)
assert_equal(clf.predict([[-0.1, -0.1]]), np.array([2]))


def test_check_update_with_no_data():
""" Test when the partial fit is called without any data"""
# Create an empty array
prev_points = 100
mean = 0.
var = 1.
x_empty = np.empty((0, X.shape[1]))
tmean, tvar = GaussianNB._update_mean_variance(prev_points, mean,
var, x_empty)
assert_equal(tmean, mean)
assert_equal(tvar, var)
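For context, GaussianNB._update_mean_variance folds a batch of new points into running per-feature statistics; the test above checks the empty-batch edge case. A standalone sketch of the pooled-update arithmetic (illustrative names, not the library's internals):

    import numpy as np

    def combine_mean_var(n_past, mu, var, X_new):
        # Empty batch: nothing to fold in, past statistics are unchanged.
        if X_new.shape[0] == 0:
            return mu, var
        n_new = X_new.shape[0]
        new_mu = np.mean(X_new, axis=0)
        new_var = np.var(X_new, axis=0)
        n_total = n_past + n_new
        # Combined mean is the count-weighted average of the two means.
        total_mu = (n_past * mu + n_new * new_mu) / n_total
        # Combined variance: within-batch sums of squared differences
        # plus a between-batch correction term.
        old_ssd = n_past * var
        new_ssd = n_new * new_var
        total_ssd = (old_ssd + new_ssd +
                     (n_past * n_new / n_total) * (mu - new_mu) ** 2)
        return total_mu, total_ssd / n_total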


def test_gnb_pfit_wrong_nb_features():
"""Test whether an error is raised when the number of feature changes
between two partial fit"""
clf = GaussianNB()
# Fit the GNB for the first time
clf.fit(X, y)
# Partial fit a second time with an inconsistent number of features
assert_raises(ValueError, clf.partial_fit, np.hstack((X, X)), y)


def test_discrete_prior():
# Test whether class priors are properly set.
for cls in [BernoulliNB, MultinomialNB]: