Changes from 1 commit
33 commits
53713c9
Fix for issue #6352
tracer0tong Feb 17, 2016
65a2b8f
Fixed codestyle
tracer0tong Feb 18, 2016
eb242c2
ENH: FeatureHasher now accepts string values.
devashishd12 Jan 15, 2016
876f123
Do not ignore files starting with _ in nose
lesteve Feb 29, 2016
2e7d9ad
FIX: improve docs of randomized lasso
Mar 7, 2016
bf81451
Fix consistency in docs and docstring
Mar 7, 2016
a754e09
Added ref to Bach and improved docs
Mar 7, 2016
f81e5aa
Try to fix link to pdf
Mar 7, 2016
3a83071
fix x and y order
Mar 7, 2016
4eca0c9
updated info for cross_val_score
ohld Mar 14, 2016
c2eaf75
Merge pull request #6173 from dsquareindia/featurehasher_fix
MechCoder Mar 19, 2016
b64e992
Merge pull request #6542 from ohld/make_scorer-link
glouppe Mar 19, 2016
e2e6bde
Merge pull request #6498 from clamus/rand-lasso-fix-6493
glouppe Mar 19, 2016
9691824
Merge pull request #6466 from lesteve/nose-ignore-files-tweak
glouppe Mar 19, 2016
7580746
Fix broken link in ABOUT
bryandeng Mar 20, 2016
bd6b313
Merge pull request #6565 from bryandeng/doc-link
agramfort Mar 20, 2016
549474d
[gardening] Fix NameError ("estimator" not defined). Remove unused va…
practicalswift Mar 20, 2016
e228581
Merge pull request #6566 from practicalswift/fix-nameerror-and-remove…
jnothman Mar 21, 2016
e9492b7
LabelBinarizer single label case now works for sparse and dense case
devashishd12 Jan 24, 2016
528533d
MAINT: Simplify n_features_to_select in RFECV
MechCoder Mar 21, 2016
945cb7e
Merge pull request #6221 from dsquareindia/LabelBinarizer_fix
MechCoder Mar 21, 2016
54af09e
Fixing typos in logistic regression docs
hlin117 Mar 21, 2016
146f461
Merge pull request #6575 from hlin117/logregdocs
agramfort Mar 22, 2016
07a6433
Fix typo in html target
Mar 22, 2016
b3c2219
Add the possibility to add prior to Gaussian Naive Bayes
Jan 18, 2016
5a046c7
Update whatsnew
MechCoder Mar 22, 2016
5d92bd5
Merge pull request #6579 from nlathia/issue-6541
TomDLT Mar 23, 2016
65b570b
Update scorer.py
lizsz Mar 19, 2016
56d625f
Merge pull request #6569 from MechCoder/minor
TomDLT Mar 23, 2016
22d7cd5
Make dump_svmlight_file support sparse y
yenchenlin Feb 18, 2016
eed5fc5
Merge pull request #6395 from yenchenlin1994/make-dump_svmlight_file-…
TomDLT Mar 23, 2016
afc058f
Merge pull request #6376 from tracer0tong/issue_6352
TomDLT Mar 24, 2016
612cd9e
ENH: Support data centering in LogisticRegression
kernc Mar 17, 2016
Make dump_svmlight_file support sparse y
yenchenlin committed Mar 23, 2016
commit 22d7cd5cd07e301461d635aa4235c7d10135a4a9
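
For orientation, a minimal sketch of what this commit enables: passing a sparse target matrix straight to dump_svmlight_file for a multilabel dump, instead of densifying it first. The toy data and the printed output are illustrative assumptions inferred from the diff below, not part of the patch.

import numpy as np
import scipy.sparse as sp
from io import BytesIO
from sklearn.datasets import dump_svmlight_file

# Two samples, three features.
X = np.array([[1.0, 0.0, 3.0],
              [0.0, 5.0, 0.0]])

# Multilabel indicator targets as a CSR matrix of shape (n_samples, n_labels);
# with this branch, no .toarray() conversion should be needed before dumping.
y = sp.csr_matrix([[0, 1],
                   [1, 1]])

f = BytesIO()
dump_svmlight_file(X, y, f, multilabel=True)
print(f.getvalue().decode("ascii"))
# Expected, following the label/feature patterns used in _dump_svmlight:
# 1 0:1 2:3
# 0,1 1:5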
46 changes: 33 additions & 13 deletions sklearn/datasets/svmlight_format.py
@@ -276,7 +276,8 @@ def load_svmlight_files(files, n_features=None, dtype=np.float64,
 
 
 def _dump_svmlight(X, y, f, multilabel, one_based, comment, query_id):
-    is_sp = int(hasattr(X, "tocsr"))
+    X_is_sp = int(hasattr(X, "tocsr"))
+    y_is_sp = int(hasattr(y, "tocsr"))
     if X.dtype.kind == 'i':
         value_pattern = u("%d:%d")
     else:
@@ -302,7 +303,7 @@ def _dump_svmlight(X, y, f, multilabel, one_based, comment, query_id):
         f.writelines(b("# %s\n" % line) for line in comment.splitlines())
 
     for i in range(X.shape[0]):
-        if is_sp:
+        if X_is_sp:
             span = slice(X.indptr[i], X.indptr[i + 1])
             row = zip(X.indices[span], X.data[span])
         else:
@@ -312,10 +313,16 @@ def _dump_svmlight(X, y, f, multilabel, one_based, comment, query_id):
         s = " ".join(value_pattern % (j + one_based, x) for j, x in row)
 
         if multilabel:
-            nz_labels = np.where(y[i] != 0)[0]
+            if y_is_sp:
+                nz_labels = y[i].nonzero()[1]
+            else:
+                nz_labels = np.where(y[i] != 0)[0]
             labels_str = ",".join(label_pattern % j for j in nz_labels)
         else:
-            labels_str = label_pattern % y[i]
+            if y_is_sp:
+                labels_str = label_pattern % y.data[i]
+            else:
+                labels_str = label_pattern % y[i]
 
         if query_id is not None:
             feat = (labels_str, query_id[i], s)
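
A note on the sparse branch above: indexing row i of a CSR matrix returns a 1 x n_labels matrix whose nonzero() gives a (rows, cols) pair, so [1] selects the nonzero label columns. A toy check with made-up labels, not taken from the patch:

import scipy.sparse as sp

y = sp.csr_matrix([[0, 1, 1],
                   [1, 0, 0]])
print(y[0].nonzero()[1])  # [1 2] -- the nonzero label columns of sample 0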
@@ -341,9 +348,10 @@ def dump_svmlight_file(X, y, f, zero_based=True, comment=None, query_id=None,
         Training vectors, where n_samples is the number of samples and
         n_features is the number of features.
 
-    y : array-like, shape = [n_samples] or [n_samples, n_labels]
-        Target values. Class labels must be an integer or float, or array-like
-        objects of integer or float for multilabel classifications.
+    y : {array-like, sparse matrix}, shape = [n_samples (, n_labels)]
+        Target values. Class labels must be an
+        integer or float, or array-like objects of integer or float for
+        multilabel classifications.
 
     f : string or file-like in binary mode
         If string, specifies the path that will contain the data.
@@ -385,19 +393,31 @@ def dump_svmlight_file(X, y, f, zero_based=True, comment=None, query_id=None,
         if six.b("\0") in comment:
             raise ValueError("comment string contains NUL byte")
 
-    y = np.asarray(y)
-    if y.ndim != 1 and not multilabel:
-        raise ValueError("expected y of shape (n_samples,), got %r"
-                         % (y.shape,))
+    yval = check_array(y, accept_sparse='csr', ensure_2d=False)
+    if sp.issparse(yval):
+        if yval.shape[1] != 1 and not multilabel:
+            raise ValueError("expected y of shape (n_samples, 1),"
+                             " got %r" % (yval.shape,))
+    else:
+        if yval.ndim != 1 and not multilabel:
+            raise ValueError("expected y of shape (n_samples,), got %r"
+                             % (yval.shape,))
 
     Xval = check_array(X, accept_sparse='csr')
-    if Xval.shape[0] != y.shape[0]:
+    if Xval.shape[0] != yval.shape[0]:
         raise ValueError("X.shape[0] and y.shape[0] should be the same, got"
-                         " %r and %r instead." % (Xval.shape[0], y.shape[0]))
+                         " %r and %r instead." % (Xval.shape[0], yval.shape[0]))
 
     # We had some issues with CSR matrices with unsorted indices (e.g. #1501),
     # so sort them here, but first make sure we don't modify the user's X.
     # TODO We can do this cheaper; sorted_indices copies the whole matrix.
+    if yval is y and hasattr(yval, "sorted_indices"):
+        y = yval.sorted_indices()
+    else:
+        y = yval
+        if hasattr(y, "sort_indices"):
+            y.sort_indices()
+
     if Xval is X and hasattr(Xval, "sorted_indices"):
         X = Xval.sorted_indices()
     else:
118 changes: 68 additions & 50 deletions sklearn/datasets/tests/test_svmlight_format.py
@@ -2,6 +2,7 @@
 import gzip
 from io import BytesIO
 import numpy as np
+import scipy.sparse as sp
 import os
 import shutil
 from tempfile import NamedTemporaryFile
@@ -200,67 +201,84 @@ def test_invalid_filename():
 
 
 def test_dump():
-    Xs, y = load_svmlight_file(datafile)
-    Xd = Xs.toarray()
+    X_sparse, y_dense = load_svmlight_file(datafile)
+    X_dense = X_sparse.toarray()
+    y_sparse = sp.csr_matrix(y_dense)
 
     # slicing a csr_matrix can unsort its .indices, so test that we sort
     # those correctly
-    Xsliced = Xs[np.arange(Xs.shape[0])]
-
-    for X in (Xs, Xd, Xsliced):
-        for zero_based in (True, False):
-            for dtype in [np.float32, np.float64, np.int32]:
-                f = BytesIO()
-                # we need to pass a comment to get the version info in;
-                # LibSVM doesn't grok comments so they're not put in by
-                # default anymore.
-                dump_svmlight_file(X.astype(dtype), y, f, comment="test",
-                                   zero_based=zero_based)
-                f.seek(0)
-
-                comment = f.readline()
-                try:
-                    comment = str(comment, "utf-8")
-                except TypeError:  # fails in Python 2.x
-                    pass
-
-                assert_in("scikit-learn %s" % sklearn.__version__, comment)
-
-                comment = f.readline()
-                try:
-                    comment = str(comment, "utf-8")
-                except TypeError:  # fails in Python 2.x
-                    pass
-
-                assert_in(["one", "zero"][zero_based] + "-based", comment)
-
-                X2, y2 = load_svmlight_file(f, dtype=dtype,
-                                            zero_based=zero_based)
-                assert_equal(X2.dtype, dtype)
-                assert_array_equal(X2.sorted_indices().indices, X2.indices)
-                if dtype == np.float32:
-                    assert_array_almost_equal(
-                        # allow a rounding error at the last decimal place
-                        Xd.astype(dtype), X2.toarray(), 4)
-                else:
-                    assert_array_almost_equal(
-                        # allow a rounding error at the last decimal place
-                        Xd.astype(dtype), X2.toarray(), 15)
-                assert_array_equal(y, y2)
+    X_sliced = X_sparse[np.arange(X_sparse.shape[0])]
+    y_sliced = y_sparse[np.arange(y_sparse.shape[0])]
+
+    for X in (X_sparse, X_dense, X_sliced):
+        for y in (y_sparse, y_dense, y_sliced):
+            for zero_based in (True, False):
+                for dtype in [np.float32, np.float64, np.int32]:
+                    f = BytesIO()
+                    # we need to pass a comment to get the version info in;
+                    # LibSVM doesn't grok comments so they're not put in by
+                    # default anymore.
+
+                    if (sp.issparse(y) and y.shape[0] == 1):
+                        # make sure y's shape is: (n_samples, n_labels)
+                        # when it is sparse
+                        y = y.T
+
+                    dump_svmlight_file(X.astype(dtype), y, f, comment="test",
+                                       zero_based=zero_based)
+                    f.seek(0)
+
+                    comment = f.readline()
+                    try:
+                        comment = str(comment, "utf-8")
+                    except TypeError:  # fails in Python 2.x
+                        pass
+
+                    assert_in("scikit-learn %s" % sklearn.__version__, comment)
+
+                    comment = f.readline()
+                    try:
+                        comment = str(comment, "utf-8")
+                    except TypeError:  # fails in Python 2.x
+                        pass
+
+                    assert_in(["one", "zero"][zero_based] + "-based", comment)
+
+                    X2, y2 = load_svmlight_file(f, dtype=dtype,
+                                                zero_based=zero_based)
+                    assert_equal(X2.dtype, dtype)
+                    assert_array_equal(X2.sorted_indices().indices, X2.indices)
+
+                    X2_dense = X2.toarray()
+
+                    if dtype == np.float32:
+                        # allow a rounding error at the last decimal place
+                        assert_array_almost_equal(
+                            X_dense.astype(dtype), X2_dense, 4)
+                        assert_array_almost_equal(
+                            y_dense.astype(dtype), y2, 4)
+                    else:
+                        # allow a rounding error at the last decimal place
+                        assert_array_almost_equal(
+                            X_dense.astype(dtype), X2_dense, 15)
+                        assert_array_almost_equal(
+                            y_dense.astype(dtype), y2, 15)
 
 
 def test_dump_multilabel():
     X = [[1, 0, 3, 0, 5],
          [0, 0, 0, 0, 0],
          [0, 5, 0, 1, 0]]
-    y = [[0, 1, 0], [1, 0, 1], [1, 1, 0]]
-    f = BytesIO()
-    dump_svmlight_file(X, y, f, multilabel=True)
-    f.seek(0)
-    # make sure it dumps multilabel correctly
-    assert_equal(f.readline(), b("1 0:1 2:3 4:5\n"))
-    assert_equal(f.readline(), b("0,2 \n"))
-    assert_equal(f.readline(), b("0,1 1:5 3:1\n"))
+    y_dense = [[0, 1, 0], [1, 0, 1], [1, 1, 0]]
+    y_sparse = sp.csr_matrix(y_dense)
+    for y in [y_dense, y_sparse]:
+        f = BytesIO()
+        dump_svmlight_file(X, y, f, multilabel=True)
+        f.seek(0)
+        # make sure it dumps multilabel correctly
+        assert_equal(f.readline(), b("1 0:1 2:3 4:5\n"))
+        assert_equal(f.readline(), b("0,2 \n"))
+        assert_equal(f.readline(), b("0,1 1:5 3:1\n"))
 
 
 def test_dump_concise():
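
One more note on the y.shape[0] == 1 transpose in test_dump above: wrapping a 1-D label vector in sp.csr_matrix yields a row vector of shape (1, n_samples), while the dump expects one row per sample. A quick sketch of that pitfall, on toy values:

import numpy as np
import scipy.sparse as sp

y = np.array([1, 0, 2])
y_sp = sp.csr_matrix(y)  # shape (1, 3): a single row, not one row per sample
y_sp = y_sp.T            # shape (3, 1): matches the expected (n_samples, 1)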