Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/pyspark/mllib/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def parse(s):
s = s[start + 1: end]

try:
values = [float(val) for val in s.split(',')]
values = [float(val) for val in s.split(',') if val]
except ValueError:
raise ValueError("Unable to parse values from %s" % s)
return DenseVector(values)
Expand Down Expand Up @@ -586,7 +586,7 @@ def parse(s):
new_s = s[ind_start + 1: ind_end]
ind_list = new_s.split(',')
try:
indices = [int(ind) for ind in ind_list]
indices = [int(ind) for ind in ind_list if ind]
except ValueError:
raise ValueError("Unable to parse indices from %s." % new_s)
s = s[ind_end + 1:].strip()
Expand All @@ -599,7 +599,7 @@ def parse(s):
raise ValueError("Values array should end with ']'.")
val_list = s[val_start + 1: val_end].split(',')
try:
values = [float(val) for val in val_list]
values = [float(val) for val in val_list if val]
except ValueError:
raise ValueError("Unable to parse values from %s." % s)
return SparseVector(size, indices, values)
Expand Down
6 changes: 6 additions & 0 deletions python/pyspark/mllib/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,9 +393,15 @@ def test_dense_matrix_is_transposed(self):
self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9]))

def test_parse_vector(self):
a = DenseVector([])
self.assertEqual(str(a), '[]')
self.assertEqual(Vectors.parse(str(a)), a)
a = DenseVector([3, 4, 6, 7])
self.assertTrue(str(a), '[3.0,4.0,6.0,7.0]')
self.assertTrue(Vectors.parse(str(a)), a)
a = SparseVector(4, [], [])
self.assertEqual(str(a), '(4,[],[])')
self.assertEqual(SparseVector.parse(str(a)), a)
a = SparseVector(4, [0, 2], [3, 4])
self.assertTrue(str(a), '(4,[0,2],[3.0,4.0])')
self.assertTrue(Vectors.parse(str(a)), a)
Expand Down