diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py
index 7d14f0529557..e2733cc34524 100644
--- a/python/pyspark/ml/image.py
+++ b/python/pyspark/ml/image.py
@@ -180,9 +180,8 @@ def readImages(self, path, recursive=False, numPartitions=-1,
 
         .. versionadded:: 2.3.0
         """
-        ctx = SparkContext._active_spark_context
-        spark = SparkSession(ctx)
-        image_schema = ctx._jvm.org.apache.spark.ml.image.ImageSchema
+        spark = SparkSession.builder.getOrCreate()
+        image_schema = spark._jvm.org.apache.spark.ml.image.ImageSchema
         jsession = spark._jsparkSession
         jresult = image_schema.readImages(path, jsession, recursive, numPartitions,
                                           dropImageFailures, float(sampleRatio), seed)
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 2258d61c9533..4f99028f5e86 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -67,7 +67,7 @@
 from pyspark.ml.util import *
 from pyspark.ml.wrapper import JavaParams, JavaWrapper
 from pyspark.serializers import PickleSerializer
-from pyspark.sql import DataFrame, Row, SparkSession
+from pyspark.sql import DataFrame, Row, SparkSession, HiveContext
 from pyspark.sql.functions import rand
 from pyspark.sql.types import DoubleType, IntegerType
 from pyspark.storagelevel import *
@@ -1837,6 +1837,35 @@ def test_read_images(self):
         self.assertEqual(ImageSchema.undefinedImageType, "Undefined")
 
 
+class ImageReaderTest2(PySparkTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        PySparkTestCase.setUpClass()
+        # Note that here we enable Hive's support.
+        try:
+            cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
+        except py4j.protocol.Py4JError:
+            cls.tearDownClass()
+            raise unittest.SkipTest("Hive is not available")
+        except TypeError:
+            cls.tearDownClass()
+            raise unittest.SkipTest("Hive is not available")
+        cls.spark = HiveContext._createForTesting(cls.sc)
+
+    @classmethod
+    def tearDownClass(cls):
+        PySparkTestCase.tearDownClass()
+        cls.spark.sparkSession.stop()
+
+    def test_read_images_multiple_times(self):
+        # This test case is to check if `ImageSchema.readImages` tries to
+        # initiate Hive client multiple times. See SPARK-22651.
+        data_path = 'data/mllib/images/kittens'
+        ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
+        ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
+
+
 class ALSTest(SparkSessionTestCase):
 
     def test_storage_levels(self):
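
A minimal sketch (not part of the patch) of why the image.py change avoids the repeated Hive client initialization tracked in SPARK-22651: constructing `SparkSession(ctx)` builds a fresh session wrapper (and JVM-side session state) on every `readImages` call, whereas `SparkSession.builder.getOrCreate()` hands back the already-active session. The variable names below are illustrative only.

# Illustrative only; assumes a normal PySpark environment with an active session.
from pyspark.sql import SparkSession

spark_a = SparkSession.builder.getOrCreate()
spark_b = SparkSession.builder.getOrCreate()

# getOrCreate() returns the already-instantiated session, so both names refer
# to the same Python object and the same underlying JVM session state; any
# Hive client tied to that state is therefore created at most once.
assert spark_a is spark_b

# By contrast, the old code path did roughly the equivalent of
#     SparkSession(SparkContext._active_spark_context)
# on every readImages() call, constructing a new session wrapper each time.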