Commit 9ad265e

Merge branch 'releases' into dfsg
* releases: (57 commits)
  REL increased version number, added whatsnew
  DOC some copyediting
  COSMIT pep8
  ENH set the random state to avoid heisenfailures
  ENH rewrite radius-NN classifier's outlier handling
  BUG in RadiusNeighborClassifier outlier handling
  BUG allow outlier_label=0 in RadiusNeighborClassifier
  TST stronger tests for arbitrary classes. make explicit what works and what doesn't.
  BF: explicitly mark train_test_split as not the one for nosetesting
  COSMIT update mailmap
  ENH fix test_estimators_overwrite_params to test regressors and transformers
  ENH speed up RBFSampler by ~10%
  TST fit_transform(X)==fit(X).transform(X)
  BUG in KernelPCA: wrong default value for gamma
  raise ValueError if division through zero in LogOddsEstimator
  fix: deviance computation in BinomialDeviance was wrong (ignored cases where y == 0) - thanks to ChrisBeaumont for reporting this issue
  fix: map labels to {0, 1}
  BUG fix broken grid search example
  DOC: add example and ref to lars_path in lasso_path
  BUG: highly-degenerate roc curves
  ...

Conflicts:
  sklearn/externals/joblib/__init__.py
  sklearn/externals/joblib/hashing.py
  sklearn/externals/joblib/test/test_hashing.py
2 parents b4f780b + 1bee1b8 commit 9ad265e

68 files changed: +1378, -733 lines

.gitattributes

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
 /sklearn/utils/arraybuilder.c -diff
 /sklearn/utils/arrayfuncs.c -diff
 /sklearn/utils/graph_shortest_path.c -diff
+/sklearn/utils/lgamma.c -diff
 /sklearn/utils/murmurhash.c -diff
 /sklearn/utils/seq_dataset.c -diff
 /sklearn/utils/sparsefuncs.c -diff

.mailmap

Lines changed: 12 additions & 0 deletions
@@ -68,3 +68,15 @@ Tim Sheerman-Chase <[email protected]> Tim Sheerman-Chase <ts00051@t
 Vincent Schut <[email protected]> Vincent Schut <vincent@TIMO.(none)>
 
 
+
+Hrishikesh Huilgolkar <[email protected]> <hrishikesh@QE-IND-WKS007.(none)>
+
+Brian Cheung <[email protected]> cow <briancheung>
+Brian Cheung <[email protected]> cow <cow@rusty.(none)>
+Diego Molla <[email protected]> <diego@diego-desktop.(none)>
+Michael EICKENBERG <[email protected]> Michael EICKENBERG <[email protected]>
+Michael EICKENBERG <[email protected]> Michael <[email protected]>
+
+
+
+X006 <x006@x006-icsl.(none)> x006 <x006@x006laptop.(none)>

benchmarks/bench_covertype.py

Lines changed: 23 additions & 46 deletions
@@ -44,25 +44,29 @@
 
 print __doc__
 
-# Author: Peter Prettenhoer <[email protected]>
+# Author: Peter Prettenhofer <[email protected]>
 # License: BSD Style.
 
-# $Id$
-
-from time import time
+import logging
 import os
 import sys
-import numpy as np
+from time import time
 from optparse import OptionParser
 
+import numpy as np
+
+from sklearn.datasets import fetch_covtype
 from sklearn.svm import LinearSVC
 from sklearn.linear_model import SGDClassifier
 from sklearn.naive_bayes import GaussianNB
 from sklearn.tree import DecisionTreeClassifier
 from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier
 from sklearn import metrics
 from sklearn.externals.joblib import Memory
-from sklearn.utils import check_random_state
+
+logging.basicConfig(level=logging.INFO,
+                    format='%(asctime)s %(levelname)s %(message)s')
+logger = logging.getLogger(__name__)
 
 op = OptionParser()
 op.add_option("--classifiers",
@@ -80,8 +84,7 @@
 # estimators.
 op.add_option("--random-seed",
               dest="random_seed", default=13, type=int,
-              help="Common seed used by random number generator."
-              )
+              help="Common seed used by random number generator.")
 
 op.print_help()
 
@@ -97,57 +100,31 @@
 joblib_cache_folder = os.path.join(bench_folder, 'bench_covertype_data')
 m = Memory(joblib_cache_folder, mmap_mode='r')
 
-# Set seed for rng
-rng = check_random_state(opts.random_seed)
-
 
 # Load the data, then cache and memmap the train/test split
 @m.cache
 def load_data(dtype=np.float32, order='F'):
-    ######################################################################
-    ## Download the data, if not already on disk
-    if not os.path.exists(original_archive):
-        # Download the data
-        import urllib
-        print "Downloading data, Please Wait (11MB)..."
-        opener = urllib.urlopen(
-            'http://archive.ics.uci.edu/ml/'
-            'machine-learning-databases/covtype/covtype.data.gz')
-        open(original_archive, 'wb').write(opener.read())
-
     ######################################################################
     ## Load dataset
     print("Loading dataset...")
-    import gzip
-    f = gzip.open(original_archive)
-    X = np.fromstring(f.read().replace(",", " "), dtype=dtype, sep=" ",
-                      count=-1)
-    X = X.reshape((581012, 55))
+    data = fetch_covtype(download_if_missing=True, shuffle=True,
+                         random_state=opts.random_seed)
+    X, y = data.data, data.target
     if order.lower() == 'f':
         X = np.asfortranarray(X)
-    f.close()
 
     # class 1 vs. all others.
-    y = np.ones(X.shape[0]) * -1
-    y[np.where(X[:, -1] == 1)] = 1
-    X = X[:, :-1]
+    y[np.where(y != 1)] = -1
 
     ######################################################################
     ## Create train-test split (as [Joachims, 2006])
-    print("Creating train-test split...")
-    idx = np.arange(X.shape[0])
-    rng.shuffle(idx)
-    train_idx = idx[:522911]
-    test_idx = idx[522911:]
+    logger.info("Creating train-test split...")
+    n_train = 522911
 
-    X_train = X[train_idx]
-    y_train = y[train_idx]
-    X_test = X[test_idx]
-    y_test = y[test_idx]
-
-    # free memory
-    del X
-    del y
+    X_train = X[:n_train]
+    y_train = y[:n_train]
+    X_test = X[n_train:]
+    y_test = y[n_train:]
 
     ######################################################################
     ## Standardize first 10 features (the numerical ones)
@@ -204,7 +181,7 @@ def benchmark(clf):
     'dual': False,
     'tol': 1e-3,
     "random_state": opts.random_seed,
-    }
+}
 classifiers['liblinear'] = LinearSVC(**liblinear_parameters)
 
 ######################################################################
@@ -218,7 +195,7 @@ def benchmark(clf):
     'n_iter': 2,
     'n_jobs': opts.n_jobs,
     "random_state": opts.random_seed,
-    }
+}
 classifiers['SGD'] = SGDClassifier(**sgd_parameters)
 
 ######################################################################
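
For reference, a minimal sketch of the loading pattern the rewritten benchmark now relies on; ``random_seed`` stands in for the script's ``--random-seed`` option and the float32 cast mirrors the ``dtype=np.float32`` default:

    # Sketch of the benchmark's data loading, outside the script.
    import numpy as np
    from sklearn.datasets import fetch_covtype

    random_seed = 13  # stand-in for --random-seed
    data = fetch_covtype(download_if_missing=True, shuffle=True,
                         random_state=random_seed)
    X, y = data.data.astype(np.float32), data.target

    # Class 1 vs. all others, as in the benchmark.
    y[y != 1] = -1

    # Fixed-size head/tail split ([Joachims, 2006] sizes); the data is
    # already shuffled inside fetch_covtype.
    n_train = 522911
    X_train, y_train = X[:n_train], y[:n_train]
    X_test, y_test = X[n_train:], y[n_train:]

Moving the shuffle into ``fetch_covtype(random_state=...)`` removes the module-level ``rng`` and keeps the train/test split reproducible for a given seed.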

doc/conf.py

Lines changed: 2 additions & 2 deletions
@@ -67,7 +67,7 @@
 # built documents.
 #
 # The short X.Y version.
-version = '0.13'
+version = '0.13.1'
 # The full version, including alpha/beta/rc tags.
 import sklearn
 release = sklearn.__version__
@@ -121,7 +121,7 @@
 # further. For a list of options available for each theme, see the
 # documentation.
 html_theme_options = {'oldversion': False, 'collapsiblesidebar': True,
-                      'google_analytics': True, 'surveybanner': True}
+                      'google_analytics': True}
 
 # Add any paths that contain custom themes here, relative to this directory.
 html_theme_path = ['themes']

doc/datasets/covtype.rst

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+
+.. _covtype:
+
+Forest covertypes
+=================
+
+The samples in this dataset correspond to 30×30m patches of forest in the US,
+collected for the task of predicting each patch's cover type,
+i.e. the dominant species of tree.
+There are seven covertypes, making this a multiclass classification problem.
+Each sample has 54 features, described on the
+`dataset's homepage <http://archive.ics.uci.edu/ml/datasets/Covertype>`_.
+Some of the features are boolean indicators,
+while others are discrete or continuous measurements.
+
+``sklearn.datasets.fetch_covtype`` will load the covertype dataset;
+it returns a ``Bunch`` object with the feature matrix in the ``data`` member
+and the target values in ``target``.
+The dataset will be downloaded from the web if necessary.
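
A hedged usage sketch of the loader this new page documents (the expected shape follows from the 54-feature description above):

    # Access pattern for the Bunch returned by fetch_covtype.
    import numpy as np
    from sklearn.datasets import fetch_covtype

    covtype = fetch_covtype(download_if_missing=True)
    X = covtype.data        # feature matrix, one row per 30x30m patch
    y = covtype.target      # one of the seven cover types per sample
    print(X.shape)          # expected (581012, 54)
    print(np.unique(y))     # the seven class labels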

doc/datasets/index.rst

Lines changed: 1 addition & 0 deletions
@@ -180,3 +180,4 @@ features::
 
 .. include:: labeled_faces.rst
 
+.. include:: covtype.rst

doc/modules/classes.rst

Lines changed: 1 addition & 0 deletions
@@ -938,6 +938,7 @@ Pairwise metrics
    :template: function.rst
 
    preprocessing.add_dummy_feature
+   preprocessing.balance_weights
    preprocessing.binarize
    preprocessing.normalize
    preprocessing.scale

doc/modules/feature_extraction.rst

Lines changed: 18 additions & 8 deletions
@@ -124,9 +124,9 @@ and has no ``inverse_transform`` method.
 
 Since the hash function might cause collisions between (unrelated) features,
 a signed hash function is used and the sign of the hash value
-determines the sign of the value stored in the output matrix for a feature;
-this way, collisions are likely to cancel out rather than accumulate error,
-and the expected mean of any output feature's value is zero
+determines the sign of the value stored in the output matrix for a feature.
+This way, collisions are likely to cancel out rather than accumulate error,
+and the expected mean of any output feature's value is zero.
 
 If ``non_negative=True`` is passed to the constructor,
 the absolute value is taken.
@@ -139,14 +139,20 @@ or ``chi2`` feature selectors that expect non-negative inputs.
 ``(feature, value)`` pairs, or strings,
 depending on the constructor parameter ``input_type``.
 Mapping are treated as lists of ``(feature, value)`` pairs,
-while single strings have an implicit value of 1.
-If a feature occurs multiple times in a sample, the values will be summed.
+while single strings have an implicit value of 1,
+so ``['feat1', 'feat2', 'feat3']`` is interpreted as
+``[('feat1', 1), ('feat2', 1), ('feat3', 1)]``.
+If a single feature occurs multiple times in a sample,
+the associated values will be summed
+(so ``('feat', 2)`` and ``('feat', 3.5)`` become ``('feat', 5.5)``).
+The output from :class:`FeatureHasher` is always a ``scipy.sparse`` matrix
+in the CSR format.
+
 Feature hashing can be employed in document classification,
 but unlike :class:`text.CountVectorizer`,
 :class:`FeatureHasher` does not do word
-splitting or any other preprocessing except Unicode-to-UTF-8 encoding.
-The output from :class:`FeatureHasher` is always a ``scipy.sparse`` matrix
-in the CSR format.
+splitting or any other preprocessing except Unicode-to-UTF-8 encoding;
+see :ref:`hashing_vectorizer`, below, for a combined tokenizer/hasher.
 
 As an example, consider a word-level natural language processing task
 that needs features extracted from ``(token, part_of_speech)`` pairs.
@@ -193,6 +199,10 @@ to determine the column index and sign of a feature, respectively.
 The present implementation works under the assumption
 that the sign bit of MurmurHash3 is independent of its other bits.
 
+Since a simple modulo is used to transform the hash function to a column index,
+it is advisable to use a power of two as the ``n_features`` parameter;
+otherwise the features will not be mapped evenly to the columns.
+
 
 .. topic:: References:
 
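To make the new ``(feature, value)`` wording concrete, a small illustrative sketch (the feature names are invented; ``input_type='pair'`` is used so repeated features within one sample can be shown):

    # Repeated features in a sample are summed; output is scipy.sparse CSR;
    # n_features is a power of two, as advised above.
    from sklearn.feature_extraction import FeatureHasher

    h = FeatureHasher(n_features=2 ** 10, input_type='pair')
    X = h.transform([[('feat', 2), ('feat', 3.5)],
                     [('feat1', 1), ('feat2', 1), ('feat3', 1)]])
    print(X.shape)           # (2, 1024)
    print(abs(X[0]).sum())   # 5.5 -- the two 'feat' values were summed
                             # (abs() because the signed hash may flip the sign)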
doc/modules/linear_model.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,7 @@ Passive Aggressive Algorithms
707707
=============================
708708

709709
The passive-aggressive algorithms are a family of algorithms for large-scale
710-
learning. They are similar to the Pereptron in that they do not require a
710+
learning. They are similar to the Perceptron in that they do not require a
711711
learning rate. However, contrary to the Perceptron, they include a
712712
regularization parameter ``C``.
713713

doc/modules/metrics.rst

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 .. _metrics:
 
-Metrics, Affinities and Kernels
-===============================
+Pairwise metrics, Affinities and Kernels
+========================================
 
 The :mod:`sklearn.metrics.pairwise` submodule implements utilities to evaluate
 pairwise distances or affinity of sets of samples.
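
As a minimal sketch of what the renamed section covers (toy data; the RBF ``gamma`` is an arbitrary choice):

    # Pairwise distance and kernel matrices from sklearn.metrics.pairwise.
    import numpy as np
    from sklearn.metrics.pairwise import pairwise_distances, pairwise_kernels

    X = np.array([[0., 1.], [1., 0.], [2., 2.]])
    D = pairwise_distances(X, metric='euclidean')     # (3, 3) distance matrix
    K = pairwise_kernels(X, metric='rbf', gamma=0.5)  # (3, 3) affinity matrix
    print(D.shape, K.shape)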
