Skip to content
27 changes: 21 additions & 6 deletions python/pyspark/mllib/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,17 @@ def _ssc_wait(start_time, end_time, sleep_time):
while time() - start_time < end_time:
sleep(0.01)

@staticmethod
def _ssc_wait_checked(start_time, end_time, term_check):
    """
    Poll ``term_check`` until it succeeds or the time budget is used up.

    :param start_time: Reference timestamp, as returned by ``time()``.
    :param end_time: Maximum number of seconds to wait, measured from
                     ``start_time``.
    :param term_check: Function which checks a termination condition.
                       If true, this method returns early.
    :return: True if ``term_check`` reported success, False if the
             wait timed out first. Callers may ignore the result.
    """
    while time() - start_time < end_time:
        if term_check():
            return True
        sleep(0.01)
    # The condition may have become true during the final sleep, or the
    # budget may already have been spent on entry; check one last time so
    # the returned flag reflects the state at the moment we give up.
    return bool(term_check())


def _squared_distance(a, b):
if isinstance(a, Vector):
Expand Down Expand Up @@ -1001,8 +1012,10 @@ def test_accuracy_for_single_center(self):

t = time()
self.ssc.start()
self._ssc_wait(t, 10.0, 0.01)
self.assertEquals(stkm.latestModel().clusterWeights, [25.0])
def termCheck():
    """Report whether the latest model's cluster weights equal [25.0]."""
    weights = stkm.latestModel().clusterWeights
    return weights == [25.0]
self._ssc_wait_checked(t, 20.0, termCheck)
self.assertTrue(termCheck())
realCenters = array_sum(array(centers), axis=0)
for i in range(5):
modelCenters = stkm.latestModel().centers[0][i]
Expand Down Expand Up @@ -1041,10 +1054,12 @@ def test_trainOn_model(self):
self.ssc.start()

# Give enough time to train the model.
self._ssc_wait(t, 6.0, 0.01)
finalModel = stkm.latestModel()
self.assertTrue(all(finalModel.centers == array(initCenters)))
self.assertEquals(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0])
def termCheck():
    """
    Report whether training has reached the expected final state.

    :return: Truthy once the latest model's centers all equal
             ``initCenters`` and its cluster weights are
             [5.0, 5.0, 5.0, 5.0]; falsy otherwise.
    """
    finalModel = stkm.latestModel()
    # The result must be returned: without it the check always yields
    # None, so the wait never exits early and the assertion cannot pass.
    return (all(finalModel.centers == array(initCenters)) and
            finalModel.clusterWeights == [5.0, 5.0, 5.0, 5.0])
self._ssc_wait_checked(t, 20.0, termCheck)
self.assertTrue(termCheck())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is still a slight possibility that between the last time term_check() is called in _ssc_wait_checked and the next time it is called in this method, another batch may have been processed, which would fail the test unnecessarily. So a better approach would be for the _ssc_wait_checked method to return True if term_check() has succeeded within the timeout, and otherwise return False. Then there is no need to check term_check() once again.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For these tests, they should pass whenever all batches have been processed, so the current setup should be safe. I'm actually thinking of copying the checks so that assertions print out more useful error messages. (I don't see a great way to avoid copying the checks if I want them for both early stopping & useful error messages.)


def test_predictOn_model(self):
"""Test that the model predicts correctly on toy data."""
Expand Down