Skip to content

Commit 06f4033

Browse files
committed
make test stable
1 parent ee56c7b commit 06f4033

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

python/pyspark/mllib/clustering.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,20 +202,25 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):
202202
203203
>>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1,
204204
... 0.9,0.8,0.75,0.935,
205-
... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2))
205+
... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2), 2)
206206
>>> model = GaussianMixture.train(clusterdata_1, 3, convergenceTol=0.0001,
207207
... maxIterations=50, seed=10)
208208
>>> labels = model.predict(clusterdata_1).collect()
209209
>>> labels[0]==labels[1]
210210
False
211211
>>> labels[1]==labels[2]
212-
True
212+
False
213213
>>> labels[4]==labels[5]
214214
True
215215
>>> model.predict([-0.1,-0.05])
216216
0
217-
>>> model.predictSoft([-0.1,-0.05])
218-
array([ 0.985..., 0.005..., 0.009...])
217+
>>> softPredicted = model.predictSoft([-0.1,-0.05])
218+
>>> abs(softPredicted[0] - 1.0) < 0.001
219+
True
220+
>>> abs(softPredicted[1] - 0.0) < 0.001
221+
True
222+
>>> abs(softPredicted[2] - 0.0) < 0.001
223+
True
219224
220225
>>> path = tempfile.mkdtemp()
221226
>>> model.save(sc, path)

0 commit comments

Comments
 (0)