Skip to content

Commit fdde7d0

Browse files
committed
[SPARK-16348][ML][MLLIB][PYTHON] Use full classpaths for pyspark ML JVM calls
## What changes were proposed in this pull request?

Issue: Omitting the full classpath can cause problems when calling JVM methods or classes from PySpark. This PR changes all uses of `jvm.X` in `pyspark.ml` and `pyspark.mllib` to use the full classpath for `X`.

## How was this patch tested?

Existing unit tests, plus manual testing in an environment where this was an issue.

Author: Joseph K. Bradley <joseph@databricks.com>

Closes apache#14023 from jkbradley/SPARK-16348.
1 parent 59f9c1b commit fdde7d0

File tree

8 files changed: +28 additions, −26 deletions

python/pyspark/ml/common.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def _to_java_object_rdd(rdd):
6363
RDD is serialized in batch or not.
6464
"""
6565
rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
66-
return rdd.ctx._jvm.MLSerDe.pythonToJava(rdd._jrdd, True)
66+
return rdd.ctx._jvm.org.apache.spark.ml.python.MLSerDe.pythonToJava(rdd._jrdd, True)
6767

6868

6969
def _py2java(sc, obj):
@@ -82,7 +82,7 @@ def _py2java(sc, obj):
8282
pass
8383
else:
8484
data = bytearray(PickleSerializer().dumps(obj))
85-
obj = sc._jvm.MLSerDe.loads(data)
85+
obj = sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(data)
8686
return obj
8787

8888

@@ -95,17 +95,17 @@ def _java2py(sc, r, encoding="bytes"):
9595
clsName = 'JavaRDD'
9696

9797
if clsName == 'JavaRDD':
98-
jrdd = sc._jvm.MLSerDe.javaToPython(r)
98+
jrdd = sc._jvm.org.apache.spark.ml.python.MLSerDe.javaToPython(r)
9999
return RDD(jrdd, sc)
100100

101101
if clsName == 'Dataset':
102102
return DataFrame(r, SQLContext.getOrCreate(sc))
103103

104104
if clsName in _picklable_classes:
105-
r = sc._jvm.MLSerDe.dumps(r)
105+
r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r)
106106
elif isinstance(r, (JavaArray, JavaList)):
107107
try:
108-
r = sc._jvm.MLSerDe.dumps(r)
108+
r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r)
109109
except Py4JJavaError:
110110
pass # not pickable
111111

python/pyspark/ml/tests.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1195,12 +1195,12 @@ class VectorTests(MLlibTestCase):
11951195

11961196
def _test_serialize(self, v):
11971197
self.assertEqual(v, ser.loads(ser.dumps(v)))
1198-
jvec = self.sc._jvm.MLSerDe.loads(bytearray(ser.dumps(v)))
1199-
nv = ser.loads(bytes(self.sc._jvm.MLSerDe.dumps(jvec)))
1198+
jvec = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v)))
1199+
nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec)))
12001200
self.assertEqual(v, nv)
12011201
vs = [v] * 100
1202-
jvecs = self.sc._jvm.MLSerDe.loads(bytearray(ser.dumps(vs)))
1203-
nvs = ser.loads(bytes(self.sc._jvm.MLSerDe.dumps(jvecs)))
1202+
jvecs = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(vs)))
1203+
nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvecs)))
12041204
self.assertEqual(vs, nvs)
12051205

12061206
def test_serialize(self):

python/pyspark/mllib/clustering.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ def load(cls, sc, path):
507507
Path to where the model is stored.
508508
"""
509509
model = cls._load_java(sc, path)
510-
wrapper = sc._jvm.GaussianMixtureModelWrapper(model)
510+
wrapper = sc._jvm.org.apache.spark.mllib.api.python.GaussianMixtureModelWrapper(model)
511511
return cls(wrapper)
512512

513513

@@ -638,7 +638,8 @@ def load(cls, sc, path):
638638
Load a model from the given path.
639639
"""
640640
model = cls._load_java(sc, path)
641-
wrapper = sc._jvm.PowerIterationClusteringModelWrapper(model)
641+
wrapper =\
642+
sc._jvm.org.apache.spark.mllib.api.python.PowerIterationClusteringModelWrapper(model)
642643
return PowerIterationClusteringModel(wrapper)
643644

644645

python/pyspark/mllib/common.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def _to_java_object_rdd(rdd):
6666
RDD is serialized in batch or not.
6767
"""
6868
rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
69-
return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True)
69+
return rdd.ctx._jvm.org.apache.spark.mllib.api.python.SerDe.pythonToJava(rdd._jrdd, True)
7070

7171

7272
def _py2java(sc, obj):
@@ -85,7 +85,7 @@ def _py2java(sc, obj):
8585
pass
8686
else:
8787
data = bytearray(PickleSerializer().dumps(obj))
88-
obj = sc._jvm.SerDe.loads(data)
88+
obj = sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(data)
8989
return obj
9090

9191

@@ -98,17 +98,17 @@ def _java2py(sc, r, encoding="bytes"):
9898
clsName = 'JavaRDD'
9999

100100
if clsName == 'JavaRDD':
101-
jrdd = sc._jvm.SerDe.javaToPython(r)
101+
jrdd = sc._jvm.org.apache.spark.mllib.api.python.SerDe.javaToPython(r)
102102
return RDD(jrdd, sc)
103103

104104
if clsName == 'Dataset':
105105
return DataFrame(r, SQLContext.getOrCreate(sc))
106106

107107
if clsName in _picklable_classes:
108-
r = sc._jvm.SerDe.dumps(r)
108+
r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r)
109109
elif isinstance(r, (JavaArray, JavaList)):
110110
try:
111-
r = sc._jvm.SerDe.dumps(r)
111+
r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r)
112112
except Py4JJavaError:
113113
pass # not pickable
114114

python/pyspark/mllib/feature.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ def load(cls, sc, path):
553553
"""
554554
jmodel = sc._jvm.org.apache.spark.mllib.feature \
555555
.Word2VecModel.load(sc._jsc.sc(), path)
556-
model = sc._jvm.Word2VecModelWrapper(jmodel)
556+
model = sc._jvm.org.apache.spark.mllib.api.python.Word2VecModelWrapper(jmodel)
557557
return Word2VecModel(model)
558558

559559

python/pyspark/mllib/fpm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def load(cls, sc, path):
6464
Load a model from the given path.
6565
"""
6666
model = cls._load_java(sc, path)
67-
wrapper = sc._jvm.FPGrowthModelWrapper(model)
67+
wrapper = sc._jvm.org.apache.spark.mllib.api.python.FPGrowthModelWrapper(model)
6868
return FPGrowthModel(wrapper)
6969

7070

python/pyspark/mllib/recommendation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def rank(self):
207207
def load(cls, sc, path):
208208
"""Load a model from the given path"""
209209
model = cls._load_java(sc, path)
210-
wrapper = sc._jvm.MatrixFactorizationModelWrapper(model)
210+
wrapper = sc._jvm.org.apache.spark.mllib.api.python.MatrixFactorizationModelWrapper(model)
211211
return MatrixFactorizationModel(wrapper)
212212

213213

python/pyspark/mllib/tests.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,12 @@ class VectorTests(MLlibTestCase):
150150

151151
def _test_serialize(self, v):
152152
self.assertEqual(v, ser.loads(ser.dumps(v)))
153-
jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v)))
154-
nv = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvec)))
153+
jvec = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(v)))
154+
nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvec)))
155155
self.assertEqual(v, nv)
156156
vs = [v] * 100
157-
jvecs = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(vs)))
158-
nvs = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvecs)))
157+
jvecs = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(vs)))
158+
nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvecs)))
159159
self.assertEqual(vs, nvs)
160160

161161
def test_serialize(self):
@@ -1650,16 +1650,17 @@ class ALSTests(MLlibTestCase):
16501650

16511651
def test_als_ratings_serialize(self):
16521652
r = Rating(7, 1123, 3.14)
1653-
jr = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(r)))
1654-
nr = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jr)))
1653+
jr = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r)))
1654+
nr = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jr)))
16551655
self.assertEqual(r.user, nr.user)
16561656
self.assertEqual(r.product, nr.product)
16571657
self.assertAlmostEqual(r.rating, nr.rating, 2)
16581658

16591659
def test_als_ratings_id_long_error(self):
16601660
r = Rating(1205640308657491975, 50233468418, 1.0)
16611661
# rating user id exceeds max int value, should fail when pickled
1662-
self.assertRaises(Py4JJavaError, self.sc._jvm.SerDe.loads, bytearray(ser.dumps(r)))
1662+
self.assertRaises(Py4JJavaError, self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads,
1663+
bytearray(ser.dumps(r)))
16631664

16641665

16651666
class HashingTFTest(MLlibTestCase):

0 comments