Skip to content

Commit e45ffcc

Browse files
Arash Parsazzcclp
authored andcommitted
[SPARK-14739][PYSPARK] Fix Vectors parser bugs
## What changes were proposed in this pull request? The PySpark deserialization has a bug that shows while deserializing all zero sparse vectors. This fix filters out empty string tokens before casting, hence properly stringified SparseVectors successfully get parsed. ## How was this patch tested? Standard unit-tests similar to other methods. Author: Arash Parsa <arash@ip-192-168-50-106.ec2.internal> Author: Arash Parsa <arashpa@gmail.com> Author: Vishnu Prasad <vishnu667@gmail.com> Author: Vishnu Prasad S <vishnu667@gmail.com> Closes apache#12516 from arashpa/SPARK-14739. (cherry picked from commit 2b8906c) Signed-off-by: Sean Owen <sowen@cloudera.com> (cherry picked from commit 1cda10b)
1 parent 4d90ecf commit e45ffcc

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

python/pyspark/mllib/linalg/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def parse(s):
293293
s = s[start + 1: end]
294294

295295
try:
296-
values = [float(val) for val in s.split(',')]
296+
values = [float(val) for val in s.split(',') if val]
297297
except ValueError:
298298
raise ValueError("Unable to parse values from %s" % s)
299299
return DenseVector(values)
@@ -584,7 +584,7 @@ def parse(s):
584584
new_s = s[ind_start + 1: ind_end]
585585
ind_list = new_s.split(',')
586586
try:
587-
indices = [int(ind) for ind in ind_list]
587+
indices = [int(ind) for ind in ind_list if ind]
588588
except ValueError:
589589
raise ValueError("Unable to parse indices from %s." % new_s)
590590
s = s[ind_end + 1:].strip()
@@ -597,7 +597,7 @@ def parse(s):
597597
raise ValueError("Values array should end with ']'.")
598598
val_list = s[val_start + 1: val_end].split(',')
599599
try:
600-
values = [float(val) for val in val_list]
600+
values = [float(val) for val in val_list if val]
601601
except ValueError:
602602
raise ValueError("Unable to parse values from %s." % s)
603603
return SparseVector(size, indices, values)

python/pyspark/mllib/tests.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -388,14 +388,20 @@ def test_dense_matrix_is_transposed(self):
388388
self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9]))
389389

390390
def test_parse_vector(self):
391+
a = DenseVector([])
392+
self.assertEqual(str(a), '[]')
393+
self.assertEqual(Vectors.parse(str(a)), a)
391394
a = DenseVector([3, 4, 6, 7])
392-
self.assertTrue(str(a), '[3.0,4.0,6.0,7.0]')
393-
self.assertTrue(Vectors.parse(str(a)), a)
395+
self.assertEqual(str(a), '[3.0,4.0,6.0,7.0]')
396+
self.assertEqual(Vectors.parse(str(a)), a)
397+
a = SparseVector(4, [], [])
398+
self.assertEqual(str(a), '(4,[],[])')
399+
self.assertEqual(SparseVector.parse(str(a)), a)
394400
a = SparseVector(4, [0, 2], [3, 4])
395-
self.assertTrue(str(a), '(4,[0,2],[3.0,4.0])')
396-
self.assertTrue(Vectors.parse(str(a)), a)
401+
self.assertEqual(str(a), '(4,[0,2],[3.0,4.0])')
402+
self.assertEqual(Vectors.parse(str(a)), a)
397403
a = SparseVector(10, [0, 1], [4, 5])
398-
self.assertTrue(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a)
404+
self.assertEqual(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a)
399405

400406
def test_norms(self):
401407
a = DenseVector([0, 2, 3, -1])

0 commit comments

Comments
 (0)