diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
index 28fb64f7cd0e0..184ddac9a71a6 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
@@ -32,94 +32,11 @@ public final class CalendarInterval implements Serializable {
public static final long MICROS_PER_DAY = MICROS_PER_HOUR * 24;
public static final long MICROS_PER_WEEK = MICROS_PER_DAY * 7;
- /**
- * A function to generate regex which matches interval string's unit part like "3 years".
- *
- * First, we can leave out some units in interval string, and we only care about the value of
- * unit, so here we use non-capturing group to wrap the actual regex.
- * At the beginning of the actual regex, we should match spaces before the unit part.
- * Next is the number part, starts with an optional "-" to represent negative value. We use
- * capturing group to wrap this part as we need the value later.
- * Finally is the unit name, ends with an optional "s".
- */
- private static String unitRegex(String unit) {
- return "(?:\\s+(-?\\d+)\\s+" + unit + "s?)?";
- }
-
- private static Pattern p = Pattern.compile("interval" + unitRegex("year") + unitRegex("month") +
- unitRegex("week") + unitRegex("day") + unitRegex("hour") + unitRegex("minute") +
- unitRegex("second") + unitRegex("millisecond") + unitRegex("microsecond"),
- Pattern.CASE_INSENSITIVE);
-
- private static Pattern yearMonthPattern =
- Pattern.compile("^(?:['|\"])?([+|-])?(\\d+)-(\\d+)(?:['|\"])?$");
+ private static Pattern yearMonthPattern = Pattern.compile(
+ "^([+|-])?(\\d+)-(\\d+)$");
private static Pattern dayTimePattern = Pattern.compile(
- "^(?:['|\"])?([+|-])?((\\d+) )?((\\d+):)?(\\d+):(\\d+)(\\.(\\d+))?(?:['|\"])?$");
-
- private static Pattern quoteTrimPattern = Pattern.compile("^(?:['|\"])?(.*?)(?:['|\"])?$");
-
- private static long toLong(String s) {
- if (s == null) {
- return 0;
- } else {
- return Long.parseLong(s);
- }
- }
-
- /**
- * Convert a string to CalendarInterval. Return null if the input string is not a valid interval.
- * This method is case-insensitive.
- */
- public static CalendarInterval fromString(String s) {
- try {
- return fromCaseInsensitiveString(s);
- } catch (IllegalArgumentException e) {
- return null;
- }
- }
-
- /**
- * Convert a string to CalendarInterval. This method can handle
- * strings without the `interval` prefix and throws IllegalArgumentException
- * when the input string is not a valid interval.
- *
- * @throws IllegalArgumentException if the string is not a valid internal.
- */
- public static CalendarInterval fromCaseInsensitiveString(String s) {
- if (s == null) {
- throw new IllegalArgumentException("Interval cannot be null");
- }
- String trimmed = s.trim();
- if (trimmed.isEmpty()) {
- throw new IllegalArgumentException("Interval cannot be blank");
- }
- String prefix = "interval";
- String intervalStr = trimmed;
- // Checks the given interval string does not start with the `interval` prefix
- if (!intervalStr.regionMatches(true, 0, prefix, 0, prefix.length())) {
- // Prepend `interval` if it does not present because
- // the regular expression strictly require it.
- intervalStr = prefix + " " + trimmed;
- } else if (intervalStr.length() == prefix.length()) {
- throw new IllegalArgumentException("Interval string must have time units");
- }
-
- Matcher m = p.matcher(intervalStr);
- if (!m.matches()) {
- throw new IllegalArgumentException("Invalid interval: " + s);
- }
-
- long months = toLong(m.group(1)) * 12 + toLong(m.group(2));
- long microseconds = toLong(m.group(3)) * MICROS_PER_WEEK;
- microseconds += toLong(m.group(4)) * MICROS_PER_DAY;
- microseconds += toLong(m.group(5)) * MICROS_PER_HOUR;
- microseconds += toLong(m.group(6)) * MICROS_PER_MINUTE;
- microseconds += toLong(m.group(7)) * MICROS_PER_SECOND;
- microseconds += toLong(m.group(8)) * MICROS_PER_MILLI;
- microseconds += toLong(m.group(9));
- return new CalendarInterval((int) months, microseconds);
- }
+ "^([+|-])?((\\d+) )?((\\d+):)?(\\d+):(\\d+)(\\.(\\d+))?$");
public static long toLongWithRange(String fieldName,
String s, long minValue, long maxValue) throws IllegalArgumentException {
@@ -242,72 +159,59 @@ public static CalendarInterval fromDayTimeString(String s, String from, String t
return result;
}
- public static CalendarInterval fromSingleUnitString(String unit, String s)
+ public static CalendarInterval fromUnitStrings(String[] units, String[] values)
throws IllegalArgumentException {
+ assert units.length == values.length;
+ int months = 0;
+ long microseconds = 0;
- CalendarInterval result = null;
- if (s == null) {
- throw new IllegalArgumentException(String.format("Interval %s string was null", unit));
- }
- s = s.trim();
- Matcher m = quoteTrimPattern.matcher(s);
- if (!m.matches()) {
- throw new IllegalArgumentException(
- "Interval string does not match day-time format of 'd h:m:s.n': " + s);
- } else {
+ for (int i = 0; i < units.length; i++) {
try {
- switch (unit) {
+ switch (units[i]) {
case "year":
- int year = (int) toLongWithRange("year", m.group(1),
- Integer.MIN_VALUE / 12, Integer.MAX_VALUE / 12);
- result = new CalendarInterval(year * 12, 0L);
+ months = Math.addExact(months, Math.multiplyExact(Integer.parseInt(values[i]), 12));
break;
case "month":
- int month = (int) toLongWithRange("month", m.group(1),
- Integer.MIN_VALUE, Integer.MAX_VALUE);
- result = new CalendarInterval(month, 0L);
+ months = Math.addExact(months, Integer.parseInt(values[i]));
break;
case "week":
- long week = toLongWithRange("week", m.group(1),
- Long.MIN_VALUE / MICROS_PER_WEEK, Long.MAX_VALUE / MICROS_PER_WEEK);
- result = new CalendarInterval(0, week * MICROS_PER_WEEK);
+ microseconds = Math.addExact(
+ microseconds,
+ Math.multiplyExact(Long.parseLong(values[i]), MICROS_PER_WEEK));
break;
case "day":
- long day = toLongWithRange("day", m.group(1),
- Long.MIN_VALUE / MICROS_PER_DAY, Long.MAX_VALUE / MICROS_PER_DAY);
- result = new CalendarInterval(0, day * MICROS_PER_DAY);
+ microseconds = Math.addExact(
+ microseconds,
+ Math.multiplyExact(Long.parseLong(values[i]), MICROS_PER_DAY));
break;
case "hour":
- long hour = toLongWithRange("hour", m.group(1),
- Long.MIN_VALUE / MICROS_PER_HOUR, Long.MAX_VALUE / MICROS_PER_HOUR);
- result = new CalendarInterval(0, hour * MICROS_PER_HOUR);
+ microseconds = Math.addExact(
+ microseconds,
+ Math.multiplyExact(Long.parseLong(values[i]), MICROS_PER_HOUR));
break;
case "minute":
- long minute = toLongWithRange("minute", m.group(1),
- Long.MIN_VALUE / MICROS_PER_MINUTE, Long.MAX_VALUE / MICROS_PER_MINUTE);
- result = new CalendarInterval(0, minute * MICROS_PER_MINUTE);
+ microseconds = Math.addExact(
+ microseconds,
+ Math.multiplyExact(Long.parseLong(values[i]), MICROS_PER_MINUTE));
break;
case "second": {
- long micros = parseSecondNano(m.group(1));
- result = new CalendarInterval(0, micros);
+ microseconds = Math.addExact(microseconds, parseSecondNano(values[i]));
break;
}
case "millisecond":
- long millisecond = toLongWithRange("millisecond", m.group(1),
- Long.MIN_VALUE / MICROS_PER_MILLI, Long.MAX_VALUE / MICROS_PER_MILLI);
- result = new CalendarInterval(0, millisecond * MICROS_PER_MILLI);
+ microseconds = Math.addExact(
+ microseconds,
+ Math.multiplyExact(Long.parseLong(values[i]), MICROS_PER_MILLI));
break;
- case "microsecond": {
- long micros = Long.parseLong(m.group(1));
- result = new CalendarInterval(0, micros);
+ case "microsecond":
+ microseconds = Math.addExact(microseconds, Long.parseLong(values[i]));
break;
- }
}
} catch (Exception e) {
throw new IllegalArgumentException("Error parsing interval string: " + e.getMessage(), e);
}
}
- return result;
+ return new CalendarInterval(months, microseconds);
}
/**
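Note: as a minimal usage sketch (Scala, illustrative only, not part of the patch), the new `fromUnitStrings` entry point pairs unit names with their string values; the expected result below mirrors the `-5 years 23 month` case from the removed `fromStringTest`:

    import org.apache.spark.unsafe.types.CalendarInterval

    // "-5 years 23 months": months = -5 * 12 + 23 = -37, no microsecond component
    val interval = CalendarInterval.fromUnitStrings(
      Array("year", "month"),
      Array("-5", "23"))
    assert(interval == new CalendarInterval(-37, 0L))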
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java
index 587071332ce47..9f3262bf2aaa4 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java
@@ -19,8 +19,6 @@
import org.junit.Test;
-import java.util.Arrays;
-
import static org.junit.Assert.*;
import static org.apache.spark.unsafe.types.CalendarInterval.*;
@@ -62,72 +60,6 @@ public void toStringTest() {
assertEquals("interval 2 years 10 months 3 weeks 13 hours 123 microseconds", i.toString());
}
- @Test
- public void fromStringTest() {
- testSingleUnit("year", 3, 36, 0);
- testSingleUnit("month", 3, 3, 0);
- testSingleUnit("week", 3, 0, 3 * MICROS_PER_WEEK);
- testSingleUnit("day", 3, 0, 3 * MICROS_PER_DAY);
- testSingleUnit("hour", 3, 0, 3 * MICROS_PER_HOUR);
- testSingleUnit("minute", 3, 0, 3 * MICROS_PER_MINUTE);
- testSingleUnit("second", 3, 0, 3 * MICROS_PER_SECOND);
- testSingleUnit("millisecond", 3, 0, 3 * MICROS_PER_MILLI);
- testSingleUnit("microsecond", 3, 0, 3);
-
- CalendarInterval result = new CalendarInterval(-5 * 12 + 23, 0);
- Arrays.asList(
- "interval -5 years 23 month",
- " -5 years 23 month",
- "interval -5 years 23 month ",
- " -5 years 23 month ",
- " interval -5 years 23 month ").forEach(input ->
- assertEquals(fromString(input), result)
- );
-
- // Error cases
- Arrays.asList(
- "interval 3month 1 hour",
- "3month 1 hour",
- "interval 3 moth 1 hour",
- "3 moth 1 hour",
- "interval",
- "int",
- "",
- null).forEach(input -> assertNull(fromString(input)));
- }
-
- @Test
- public void fromCaseInsensitiveStringTest() {
- for (String input : new String[]{"5 MINUTES", "5 minutes", "5 Minutes"}) {
- assertEquals(fromCaseInsensitiveString(input), new CalendarInterval(0, 5L * 60 * 1_000_000));
- }
-
- for (String input : new String[]{null, "", " "}) {
- try {
- fromCaseInsensitiveString(input);
- fail("Expected to throw an exception for the invalid input");
- } catch (IllegalArgumentException e) {
- String msg = e.getMessage();
- if (input == null) assertTrue(msg.contains("cannot be null"));
- else assertTrue(msg.contains("cannot be blank"));
- }
- }
-
- for (String input : new String[]{"interval", "interval1 day", "foo", "foo 1 day"}) {
- try {
- fromCaseInsensitiveString(input);
- fail("Expected to throw an exception for the invalid input");
- } catch (IllegalArgumentException e) {
- String msg = e.getMessage();
- if (input.trim().equalsIgnoreCase("interval")) {
- assertTrue(msg.contains("Interval string must have time units"));
- } else {
- assertTrue(msg.contains("Invalid interval:"));
- }
- }
- }
- }
-
@Test
public void fromYearMonthStringTest() {
String input;
@@ -194,107 +126,25 @@ public void fromDayTimeStringTest() {
}
}
- @Test
- public void fromSingleUnitStringTest() {
- String input;
- CalendarInterval i;
-
- input = "12";
- i = new CalendarInterval(12 * 12, 0L);
- assertEquals(fromSingleUnitString("year", input), i);
-
- input = "100";
- i = new CalendarInterval(0, 100 * MICROS_PER_DAY);
- assertEquals(fromSingleUnitString("day", input), i);
-
- input = "1999.38888";
- i = new CalendarInterval(0, 1999 * MICROS_PER_SECOND + 38);
- assertEquals(fromSingleUnitString("second", input), i);
-
- try {
- input = String.valueOf(Integer.MAX_VALUE);
- fromSingleUnitString("year", input);
- fail("Expected to throw an exception for the invalid input");
- } catch (IllegalArgumentException e) {
- assertTrue(e.getMessage().contains("outside range"));
- }
-
- try {
- input = String.valueOf(Long.MAX_VALUE / MICROS_PER_HOUR + 1);
- fromSingleUnitString("hour", input);
- fail("Expected to throw an exception for the invalid input");
- } catch (IllegalArgumentException e) {
- assertTrue(e.getMessage().contains("outside range"));
- }
- }
-
@Test
public void addTest() {
- String input = "interval 3 month 1 hour";
- String input2 = "interval 2 month 100 hour";
-
- CalendarInterval interval = fromString(input);
- CalendarInterval interval2 = fromString(input2);
-
- assertEquals(interval.add(interval2), new CalendarInterval(5, 101 * MICROS_PER_HOUR));
+ CalendarInterval input1 = new CalendarInterval(3, 1 * MICROS_PER_HOUR);
+ CalendarInterval input2 = new CalendarInterval(2, 100 * MICROS_PER_HOUR);
+ assertEquals(input1.add(input2), new CalendarInterval(5, 101 * MICROS_PER_HOUR));
- input = "interval -10 month -81 hour";
- input2 = "interval 75 month 200 hour";
-
- interval = fromString(input);
- interval2 = fromString(input2);
-
- assertEquals(interval.add(interval2), new CalendarInterval(65, 119 * MICROS_PER_HOUR));
+ input1 = new CalendarInterval(-10, -81 * MICROS_PER_HOUR);
+ input2 = new CalendarInterval(75, 200 * MICROS_PER_HOUR);
+ assertEquals(input1.add(input2), new CalendarInterval(65, 119 * MICROS_PER_HOUR));
}
@Test
public void subtractTest() {
- String input = "interval 3 month 1 hour";
- String input2 = "interval 2 month 100 hour";
-
- CalendarInterval interval = fromString(input);
- CalendarInterval interval2 = fromString(input2);
-
- assertEquals(interval.subtract(interval2), new CalendarInterval(1, -99 * MICROS_PER_HOUR));
-
- input = "interval -10 month -81 hour";
- input2 = "interval 75 month 200 hour";
-
- interval = fromString(input);
- interval2 = fromString(input2);
-
- assertEquals(interval.subtract(interval2), new CalendarInterval(-85, -281 * MICROS_PER_HOUR));
- }
-
- private static void testSingleUnit(String unit, int number, int months, long microseconds) {
- Arrays.asList("interval ", "").forEach(prefix -> {
- String input1 = prefix + number + " " + unit;
- String input2 = prefix + number + " " + unit + "s";
- CalendarInterval result = new CalendarInterval(months, microseconds);
- assertEquals(fromString(input1), result);
- assertEquals(fromString(input2), result);
- });
- }
-
- @Test
- public void fromStringCaseSensitivityTest() {
- testSingleUnit("YEAR", 3, 36, 0);
- testSingleUnit("Month", 3, 3, 0);
- testSingleUnit("Week", 3, 0, 3 * MICROS_PER_WEEK);
- testSingleUnit("DAY", 3, 0, 3 * MICROS_PER_DAY);
- testSingleUnit("HouR", 3, 0, 3 * MICROS_PER_HOUR);
- testSingleUnit("MiNuTe", 3, 0, 3 * MICROS_PER_MINUTE);
- testSingleUnit("Second", 3, 0, 3 * MICROS_PER_SECOND);
- testSingleUnit("MilliSecond", 3, 0, 3 * MICROS_PER_MILLI);
- testSingleUnit("MicroSecond", 3, 0, 3);
-
- String input;
-
- input = "INTERVAL -5 YEARS 23 MONTHS";
- CalendarInterval result = new CalendarInterval(-5 * 12 + 23, 0);
- assertEquals(fromString(input), result);
+ CalendarInterval input1 = new CalendarInterval(3, 1 * MICROS_PER_HOUR);
+ CalendarInterval input2 = new CalendarInterval(2, 100 * MICROS_PER_HOUR);
+ assertEquals(input1.subtract(input2), new CalendarInterval(1, -99 * MICROS_PER_HOUR));
- assertNull(fromString("INTERVAL"));
- assertNull(fromString(" Interval "));
+ input1 = new CalendarInterval(-10, -81 * MICROS_PER_HOUR);
+ input2 = new CalendarInterval(75, 200 * MICROS_PER_HOUR);
+ assertEquals(input1.subtract(input2), new CalendarInterval(-85, -281 * MICROS_PER_HOUR));
}
}
diff --git a/core/benchmarks/MapStatusesSerDeserBenchmark-jdk11-results.txt b/core/benchmarks/MapStatusesSerDeserBenchmark-jdk11-results.txt
index 7a6cfb7b23b94..db23cf5c12ea7 100644
--- a/core/benchmarks/MapStatusesSerDeserBenchmark-jdk11-results.txt
+++ b/core/benchmarks/MapStatusesSerDeserBenchmark-jdk11-results.txt
@@ -2,10 +2,10 @@ OpenJDK 64-Bit Server VM 11.0.4+11-post-Ubuntu-1ubuntu218.04.3 on Linux 4.15.0-1
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
200000 MapOutputs, 10 blocks w/ broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Serialization 205 213 13 1.0 1023.6 1.0X
-Deserialization 908 939 27 0.2 4540.2 0.2X
+Serialization 170 178 9 1.2 849.7 1.0X
+Deserialization 530 535 9 0.4 2651.1 0.3X
-Compressed Serialized MapStatus sizes: 400 bytes
+Compressed Serialized MapStatus sizes: 411 bytes
Compressed Serialized Broadcast MapStatus sizes: 2 MB
@@ -13,8 +13,8 @@ OpenJDK 64-Bit Server VM 11.0.4+11-post-Ubuntu-1ubuntu218.04.3 on Linux 4.15.0-1
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
200000 MapOutputs, 10 blocks w/o broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Serialization 195 204 24 1.0 976.9 1.0X
-Deserialization 913 940 33 0.2 4566.7 0.2X
+Serialization 157 165 7 1.3 785.4 1.0X
+Deserialization 495 588 79 0.4 2476.7 0.3X
Compressed Serialized MapStatus sizes: 2 MB
Compressed Serialized Broadcast MapStatus sizes: 0 bytes
@@ -24,21 +24,21 @@ OpenJDK 64-Bit Server VM 11.0.4+11-post-Ubuntu-1ubuntu218.04.3 on Linux 4.15.0-1
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
200000 MapOutputs, 100 blocks w/ broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Serialization 616 619 3 0.3 3079.1 1.0X
-Deserialization 936 954 22 0.2 4680.5 0.7X
+Serialization 344 351 4 0.6 1720.4 1.0X
+Deserialization 527 579 99 0.4 2635.9 0.7X
-Compressed Serialized MapStatus sizes: 418 bytes
-Compressed Serialized Broadcast MapStatus sizes: 14 MB
+Compressed Serialized MapStatus sizes: 427 bytes
+Compressed Serialized Broadcast MapStatus sizes: 13 MB
OpenJDK 64-Bit Server VM 11.0.4+11-post-Ubuntu-1ubuntu218.04.3 on Linux 4.15.0-1044-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
200000 MapOutputs, 100 blocks w/o broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Serialization 586 588 3 0.3 2928.8 1.0X
-Deserialization 929 933 4 0.2 4647.0 0.6X
+Serialization 317 321 4 0.6 1583.8 1.0X
+Deserialization 530 540 15 0.4 2648.3 0.6X
-Compressed Serialized MapStatus sizes: 14 MB
+Compressed Serialized MapStatus sizes: 13 MB
Compressed Serialized Broadcast MapStatus sizes: 0 bytes
@@ -46,21 +46,21 @@ OpenJDK 64-Bit Server VM 11.0.4+11-post-Ubuntu-1ubuntu218.04.3 on Linux 4.15.0-1
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
200000 MapOutputs, 1000 blocks w/ broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Serialization 4740 4916 249 0.0 23698.5 1.0X
-Deserialization 1578 1597 27 0.1 7890.6 3.0X
+Serialization 1738 1849 156 0.1 8692.0 1.0X
+Deserialization 946 977 33 0.2 4730.2 1.8X
-Compressed Serialized MapStatus sizes: 546 bytes
-Compressed Serialized Broadcast MapStatus sizes: 123 MB
+Compressed Serialized MapStatus sizes: 556 bytes
+Compressed Serialized Broadcast MapStatus sizes: 121 MB
OpenJDK 64-Bit Server VM 11.0.4+11-post-Ubuntu-1ubuntu218.04.3 on Linux 4.15.0-1044-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
200000 MapOutputs, 1000 blocks w/o broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Serialization 4492 4573 115 0.0 22458.3 1.0X
-Deserialization 1533 1547 20 0.1 7664.8 2.9X
+Serialization 1379 1432 76 0.1 6892.6 1.0X
+Deserialization 929 941 19 0.2 4645.5 1.5X
-Compressed Serialized MapStatus sizes: 123 MB
+Compressed Serialized MapStatus sizes: 121 MB
Compressed Serialized Broadcast MapStatus sizes: 0 bytes
diff --git a/core/benchmarks/MapStatusesSerDeserBenchmark-results.txt b/core/benchmarks/MapStatusesSerDeserBenchmark-results.txt
index 0c649694f6b6e..053f4bf771923 100644
--- a/core/benchmarks/MapStatusesSerDeserBenchmark-results.txt
+++ b/core/benchmarks/MapStatusesSerDeserBenchmark-results.txt
@@ -2,10 +2,10 @@ OpenJDK 64-Bit Server VM 1.8.0_222-8u222-b10-1ubuntu1~18.04.1-b10 on Linux 4.15.
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
200000 MapOutputs, 10 blocks w/ broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Serialization 236 245 18 0.8 1179.1 1.0X
-Deserialization 842 885 37 0.2 4211.4 0.3X
+Serialization 178 187 15 1.1 887.5 1.0X
+Deserialization 530 558 32 0.4 2647.5 0.3X
-Compressed Serialized MapStatus sizes: 400 bytes
+Compressed Serialized MapStatus sizes: 411 bytes
Compressed Serialized Broadcast MapStatus sizes: 2 MB
@@ -13,8 +13,8 @@ OpenJDK 64-Bit Server VM 1.8.0_222-8u222-b10-1ubuntu1~18.04.1-b10 on Linux 4.15.
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
200000 MapOutputs, 10 blocks w/o broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Serialization 213 219 8 0.9 1065.1 1.0X
-Deserialization 846 870 33 0.2 4228.6 0.3X
+Serialization 167 175 7 1.2 835.7 1.0X
+Deserialization 523 537 22 0.4 2616.2 0.3X
Compressed Serialized MapStatus sizes: 2 MB
Compressed Serialized Broadcast MapStatus sizes: 0 bytes
@@ -24,21 +24,21 @@ OpenJDK 64-Bit Server VM 1.8.0_222-8u222-b10-1ubuntu1~18.04.1-b10 on Linux 4.15.
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
200000 MapOutputs, 100 blocks w/ broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Serialization 624 709 167 0.3 3121.1 1.0X
-Deserialization 885 908 22 0.2 4427.0 0.7X
+Serialization 351 416 147 0.6 1754.4 1.0X
+Deserialization 546 551 8 0.4 2727.6 0.6X
-Compressed Serialized MapStatus sizes: 418 bytes
-Compressed Serialized Broadcast MapStatus sizes: 14 MB
+Compressed Serialized MapStatus sizes: 427 bytes
+Compressed Serialized Broadcast MapStatus sizes: 13 MB
OpenJDK 64-Bit Server VM 1.8.0_222-8u222-b10-1ubuntu1~18.04.1-b10 on Linux 4.15.0-1044-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
200000 MapOutputs, 100 blocks w/o broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Serialization 603 604 2 0.3 3014.9 1.0X
-Deserialization 892 895 5 0.2 4458.7 0.7X
+Serialization 320 321 1 0.6 1598.0 1.0X
+Deserialization 542 549 7 0.4 2709.0 0.6X
-Compressed Serialized MapStatus sizes: 14 MB
+Compressed Serialized MapStatus sizes: 13 MB
Compressed Serialized Broadcast MapStatus sizes: 0 bytes
@@ -46,21 +46,21 @@ OpenJDK 64-Bit Server VM 1.8.0_222-8u222-b10-1ubuntu1~18.04.1-b10 on Linux 4.15.
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
200000 MapOutputs, 1000 blocks w/ broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Serialization 4612 4945 471 0.0 23061.0 1.0X
-Deserialization 1493 1495 2 0.1 7466.3 3.1X
+Serialization 1671 1877 290 0.1 8357.3 1.0X
+Deserialization 943 970 32 0.2 4715.8 1.8X
-Compressed Serialized MapStatus sizes: 546 bytes
-Compressed Serialized Broadcast MapStatus sizes: 123 MB
+Compressed Serialized MapStatus sizes: 556 bytes
+Compressed Serialized Broadcast MapStatus sizes: 121 MB
OpenJDK 64-Bit Server VM 1.8.0_222-8u222-b10-1ubuntu1~18.04.1-b10 on Linux 4.15.0-1044-aws
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
200000 MapOutputs, 1000 blocks w/o broadcast: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-Serialization 4452 4595 202 0.0 22261.4 1.0X
-Deserialization 1464 1477 18 0.1 7321.4 3.0X
+Serialization 1373 1436 89 0.1 6865.0 1.0X
+Deserialization 940 970 37 0.2 4699.1 1.5X
-Compressed Serialized MapStatus sizes: 123 MB
+Compressed Serialized MapStatus sizes: 121 MB
Compressed Serialized Broadcast MapStatus sizes: 0 bytes
diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js
index 9960d5c34d1fc..ecd580e5c64aa 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js
@@ -97,9 +97,14 @@ sorttable = {
sorttable.reverse(this.sorttable_tbody);
this.className = this.className.replace('sorttable_sorted',
'sorttable_sorted_reverse');
- this.removeChild(document.getElementById('sorttable_sortfwdind'));
+ rowlists = this.parentNode.getElementsByTagName("span");
+ for (var j=0; j < rowlists.length; j++) {
+ if (rowlists[j].className.search(/\bsorttable_sortfwdind\b/) !== -1) {
+ rowlists[j].parentNode.removeChild(rowlists[j]);
+ }
+ }
sortrevind = document.createElement('span');
- sortrevind.id = "sorttable_sortrevind";
+ sortrevind.className = "sorttable_sortrevind";
sortrevind.innerHTML = stIsIE ? ' 5' : ' ▾';
this.appendChild(sortrevind);
return;
@@ -110,9 +115,14 @@ sorttable = {
sorttable.reverse(this.sorttable_tbody);
this.className = this.className.replace('sorttable_sorted_reverse',
'sorttable_sorted');
- this.removeChild(document.getElementById('sorttable_sortrevind'));
+ rowlists = this.parentNode.getElementsByTagName("span");
+ for (var j=0; j < rowlists.length; j++) {
+ if (rowlists[j].className.search(/\bsorttable_sortrevind\b/) !== -1) {
+ rowlists[j].parentNode.removeChild(rowlists[j]);
+ }
+ }
sortfwdind = document.createElement('span');
- sortfwdind.id = "sorttable_sortfwdind";
+ sortfwdind.className = "sorttable_sortfwdind";
sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▴';
this.appendChild(sortfwdind);
return;
@@ -126,14 +136,17 @@ sorttable = {
cell.className = cell.className.replace('sorttable_sorted','');
}
});
- sortfwdind = document.getElementById('sorttable_sortfwdind');
- if (sortfwdind) { sortfwdind.parentNode.removeChild(sortfwdind); }
- sortrevind = document.getElementById('sorttable_sortrevind');
- if (sortrevind) { sortrevind.parentNode.removeChild(sortrevind); }
+ rowlists = this.parentNode.getElementsByTagName("span");
+ for (var j=0; j < rowlists.length; j++) {
+ if (rowlists[j].className.search(/\bsorttable_sortfwdind\b/) !== -1
+ || rowlists[j].className.search(/\bsorttable_sortrevind\b/) !== -1) {
+ rowlists[j].parentNode.removeChild(rowlists[j]);
+ }
+ }
this.className += ' sorttable_sorted';
sortfwdind = document.createElement('span');
- sortfwdind.id = "sorttable_sortfwdind";
+ sortfwdind.className = "sorttable_sortfwdind";
sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▴';
this.appendChild(sortfwdind);
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js
index 89622106ff1f0..cf04db28804c1 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js
@@ -87,4 +87,11 @@ $(function() {
collapseTablePageLoad('collapse-aggregated-runningExecutions','aggregated-runningExecutions');
collapseTablePageLoad('collapse-aggregated-completedExecutions','aggregated-completedExecutions');
collapseTablePageLoad('collapse-aggregated-failedExecutions','aggregated-failedExecutions');
-});
\ No newline at end of file
+});
+
+$(function() {
+ // Show/hide full job description on click event.
+ $(".description-input").click(function() {
+ $(this).toggleClass("description-input-full");
+ });
+});
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 6f4a6239a09ed..873efa76468ed 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -28,13 +28,12 @@ import scala.concurrent.duration.Duration
import scala.reflect.ClassTag
import scala.util.control.NonFatal
-import com.github.luben.zstd.ZstdInputStream
-import com.github.luben.zstd.ZstdOutputStream
import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOutputStream}
import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
+import org.apache.spark.io.CompressionCodec
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, MapStatus}
import org.apache.spark.shuffle.MetadataFetchFailedException
@@ -195,7 +194,8 @@ private class ShuffleStatus(numPartitions: Int) {
def serializedMapStatus(
broadcastManager: BroadcastManager,
isLocal: Boolean,
- minBroadcastSize: Int): Array[Byte] = {
+ minBroadcastSize: Int,
+ conf: SparkConf): Array[Byte] = {
var result: Array[Byte] = null
withReadLock {
@@ -207,7 +207,7 @@ private class ShuffleStatus(numPartitions: Int) {
if (result == null) withWriteLock {
if (cachedSerializedMapStatus == null) {
val serResult = MapOutputTracker.serializeMapStatuses(
- mapStatuses, broadcastManager, isLocal, minBroadcastSize)
+ mapStatuses, broadcastManager, isLocal, minBroadcastSize, conf)
cachedSerializedMapStatus = serResult._1
cachedSerializedBroadcast = serResult._2
}
@@ -450,7 +450,8 @@ private[spark] class MapOutputTrackerMaster(
" to " + hostPort)
val shuffleStatus = shuffleStatuses.get(shuffleId).head
context.reply(
- shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast))
+ shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast,
+ conf))
} catch {
case NonFatal(e) => logError(e.getMessage, e)
}
@@ -799,7 +800,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
endPartition: Int)
: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
- val statuses = getStatuses(shuffleId)
+ val statuses = getStatuses(shuffleId, conf)
try {
MapOutputTracker.convertMapStatuses(
shuffleId, startPartition, endPartition, statuses)
@@ -818,7 +819,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, mapIndex $mapIndex" +
s"partitions $startPartition-$endPartition")
- val statuses = getStatuses(shuffleId)
+ val statuses = getStatuses(shuffleId, conf)
try {
MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition,
statuses, Some(mapIndex))
@@ -836,7 +837,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
*
* (It would be nice to remove this restriction in the future.)
*/
- private def getStatuses(shuffleId: Int): Array[MapStatus] = {
+ private def getStatuses(shuffleId: Int, conf: SparkConf): Array[MapStatus] = {
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
@@ -846,7 +847,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
if (fetchedStatuses == null) {
logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
- fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
+ fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes, conf)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
}
@@ -890,16 +891,20 @@ private[spark] object MapOutputTracker extends Logging {
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using Zstd. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
- def serializeMapStatuses(statuses: Array[MapStatus], broadcastManager: BroadcastManager,
- isLocal: Boolean, minBroadcastSize: Int): (Array[Byte], Broadcast[Array[Byte]]) = {
+ def serializeMapStatuses(
+ statuses: Array[MapStatus],
+ broadcastManager: BroadcastManager,
+ isLocal: Boolean,
+ minBroadcastSize: Int,
+ conf: SparkConf): (Array[Byte], Broadcast[Array[Byte]]) = {
// Using `org.apache.commons.io.output.ByteArrayOutputStream` instead of the standard one
// This implementation doesn't reallocate the whole memory block but allocates
// additional buffers. This way no buffers need to be garbage collected and
// the contents don't have to be copied to the new buffer.
val out = new ApacheByteArrayOutputStream()
- val compressedOut = new ApacheByteArrayOutputStream()
-
- val objOut = new ObjectOutputStream(out)
+ out.write(DIRECT)
+ val codec = CompressionCodec.createCodec(conf, "zstd")
+ val objOut = new ObjectOutputStream(codec.compressedOutputStream(out))
Utils.tryWithSafeFinally {
// Since statuses can be modified in parallel, sync on it
statuses.synchronized {
@@ -908,42 +913,21 @@ private[spark] object MapOutputTracker extends Logging {
} {
objOut.close()
}
-
- val arr: Array[Byte] = {
- val zos = new ZstdOutputStream(compressedOut)
- Utils.tryWithSafeFinally {
- compressedOut.write(DIRECT)
- // `out.writeTo(zos)` will write the uncompressed data from `out` to `zos`
- // without copying to avoid unnecessary allocation and copy of byte[].
- out.writeTo(zos)
- } {
- zos.close()
- }
- compressedOut.toByteArray
- }
+ val arr = out.toByteArray
if (arr.length >= minBroadcastSize) {
// Use broadcast instead.
// Important arr(0) is the tag == DIRECT, ignore that while deserializing !
val bcast = broadcastManager.newBroadcast(arr, isLocal)
// toByteArray creates copy, so we can reuse out
out.reset()
- val oos = new ObjectOutputStream(out)
+ out.write(BROADCAST)
+ val oos = new ObjectOutputStream(codec.compressedOutputStream(out))
Utils.tryWithSafeFinally {
oos.writeObject(bcast)
} {
oos.close()
}
- val outArr = {
- compressedOut.reset()
- val zos = new ZstdOutputStream(compressedOut)
- Utils.tryWithSafeFinally {
- compressedOut.write(BROADCAST)
- out.writeTo(zos)
- } {
- zos.close()
- }
- compressedOut.toByteArray
- }
+ val outArr = out.toByteArray
logInfo("Broadcast mapstatuses size = " + outArr.length + ", actual size = " + arr.length)
(outArr, bcast)
} else {
@@ -952,11 +936,15 @@ private[spark] object MapOutputTracker extends Logging {
}
// Opposite of serializeMapStatuses.
- def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = {
+ def deserializeMapStatuses(bytes: Array[Byte], conf: SparkConf): Array[MapStatus] = {
assert (bytes.length > 0)
def deserializeObject(arr: Array[Byte], off: Int, len: Int): AnyRef = {
- val objIn = new ObjectInputStream(new ZstdInputStream(
+ val codec = CompressionCodec.createCodec(conf, "zstd")
+ // The ZStd codec is wrapped in a `BufferedInputStream` which avoids the excessive
+ // overhead of a JNI call when decompressing small amounts of data for each element
+ // of `MapStatuses`
+ val objIn = new ObjectInputStream(codec.compressedInputStream(
new ByteArrayInputStream(arr, off, len)))
Utils.tryWithSafeFinally {
objIn.readObject()
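Note: a minimal sketch (Scala, illustrative only) of the framing now shared by `serializeMapStatuses` and `deserializeMapStatuses`: one uncompressed tag byte (DIRECT or BROADCAST) followed by a zstd-compressed object stream. It assumes compilation inside the `org.apache.spark` package, since `CompressionCodec.createCodec` is package-private; the object name is hypothetical:

    package org.apache.spark

    import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream}

    import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOutputStream}

    import org.apache.spark.io.CompressionCodec

    object MapStatusFramingSketch {
      def roundTrip(conf: SparkConf, payload: String): String = {
        val out = new ApacheByteArrayOutputStream()
        out.write(0)  // tag byte, written before the compressed stream starts
        val codec = CompressionCodec.createCodec(conf, "zstd")
        val objOut = new ObjectOutputStream(codec.compressedOutputStream(out))
        objOut.writeObject(payload)
        objOut.close()
        val arr = out.toByteArray

        // Read side: skip the tag byte, then decompress the remainder.
        val objIn = new ObjectInputStream(codec.compressedInputStream(
          new ByteArrayInputStream(arr, 1, arr.length - 1)))
        try objIn.readObject().asInstanceOf[String] finally objIn.close()
      }
    }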
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala
index e59bf3f0eaf44..f60d940b8c82a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala
@@ -317,8 +317,7 @@ private class ErrorServlet extends RestServlet {
versionMismatch = true
s"Unknown protocol version '$unknownVersion'."
case _ =>
- // never reached
- s"Malformed path $path."
+ "Malformed path."
}
msg += s" Please submit requests through http://[host]:[port]/$serverVersion/submissions/..."
val error = handleError(msg)
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 36211dc2ed4f8..444a1544777a1 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -804,7 +804,7 @@ package object config {
.doc("Expire time in minutes for caching preferred locations of checkpointed RDD." +
"Caching preferred locations can relieve query loading to DFS and save the query " +
"time. The drawback is that the cached locations can be possibly outdated and " +
- "lose data locality. If this config is not specified or is 0, it will not cache.")
+ "lose data locality. If this config is not specified, it will not cache.")
.timeConf(TimeUnit.MINUTES)
.checkValue(_ > 0, "The expire time for caching preferred locations cannot be non-positive.")
.createOptional
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala
index 42802f7113a19..b70ea0073c9a0 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala
@@ -54,5 +54,27 @@ class RDDBarrier[T: ClassTag] private[spark] (rdd: RDD[T]) {
)
}
+ /**
+ * :: Experimental ::
+ * Returns a new RDD by applying a function to each partition of the wrapped RDD, while tracking
+ * the index of the original partition. All tasks are launched together in a barrier stage.
+ * The interface is the same as [[org.apache.spark.rdd.RDD#mapPartitionsWithIndex]].
+ * Please see the API doc there.
+ * @see [[org.apache.spark.BarrierTaskContext]]
+ */
+ @Experimental
+ @Since("3.0.0")
+ def mapPartitionsWithIndex[S: ClassTag](
+ f: (Int, Iterator[T]) => Iterator[S],
+ preservesPartitioning: Boolean = false): RDD[S] = rdd.withScope {
+ val cleanedF = rdd.sparkContext.clean(f)
+ new MapPartitionsRDD(
+ rdd,
+ (_: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(index, iter),
+ preservesPartitioning,
+ isFromBarrier = true
+ )
+ }
+
// TODO: [SPARK-25247] add extra conf to RDDBarrier, e.g., timeout.
}
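Note: a hypothetical spark-shell usage sketch of the new barrier API, condensing the behavior the added RDDBarrierSuite test verifies (assumes a live `SparkContext` named `sc` with at least 4 task slots):

    val rdd = sc.parallelize(1 to 12, 4)
    val indices = rdd.barrier()
      .mapPartitionsWithIndex((index, iter) => Iterator(index))
      .collect()
    // indices: Array(0, 1, 2, 3)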
diff --git a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala
index 53afe141981f4..5dbef88e73a9e 100644
--- a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala
+++ b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala
@@ -67,19 +67,20 @@ object MapStatusesSerDeserBenchmark extends BenchmarkBase {
var serializedBroadcastSizes = 0
val (serializedMapStatus, serializedBroadcast) = MapOutputTracker.serializeMapStatuses(
- shuffleStatus.mapStatuses, tracker.broadcastManager, tracker.isLocal, minBroadcastSize)
+ shuffleStatus.mapStatuses, tracker.broadcastManager, tracker.isLocal, minBroadcastSize,
+ sc.getConf)
serializedMapStatusSizes = serializedMapStatus.length
if (serializedBroadcast != null) {
serializedBroadcastSizes = serializedBroadcast.value.length
}
benchmark.addCase("Serialization") { _ =>
- MapOutputTracker.serializeMapStatuses(
- shuffleStatus.mapStatuses, tracker.broadcastManager, tracker.isLocal, minBroadcastSize)
+ MapOutputTracker.serializeMapStatuses(shuffleStatus.mapStatuses, tracker.broadcastManager,
+ tracker.isLocal, minBroadcastSize, sc.getConf)
}
benchmark.addCase("Deserialization") { _ =>
- val result = MapOutputTracker.deserializeMapStatuses(serializedMapStatus)
+ val result = MapOutputTracker.deserializeMapStatuses(serializedMapStatus, sc.getConf)
assert(result.length == numMaps)
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala
index 2f6c4d6a42ea3..f048f95430138 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala
@@ -29,6 +29,15 @@ class RDDBarrierSuite extends SparkFunSuite with SharedSparkContext {
assert(rdd2.isBarrier())
}
+ test("RDDBarrier mapPartitionsWithIndex") {
+ val rdd = sc.parallelize(1 to 12, 4)
+ assert(rdd.isBarrier() === false)
+
+ val rdd2 = rdd.barrier().mapPartitionsWithIndex((index, iter) => Iterator(index))
+ assert(rdd2.isBarrier())
+ assert(rdd2.collect().toList === List(0, 1, 2, 3))
+ }
+
test("create an RDDBarrier in the middle of a chain of RDDs") {
val rdd = sc.parallelize(1 to 10, 4).map(x => x * 2)
val rdd2 = rdd.barrier().mapPartitions(iter => iter).map(x => (x, x + 1))
diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh
index 61951e73f4bab..1f6fdb2a55ff4 100755
--- a/dev/create-release/release-build.sh
+++ b/dev/create-release/release-build.sh
@@ -280,6 +280,8 @@ if [[ "$1" == "package" ]]; then
BINARY_PKGS_ARGS["without-hadoop"]="-Phadoop-provided"
if [[ $SPARK_VERSION < "3.0." ]]; then
BINARY_PKGS_ARGS["hadoop2.6"]="-Phadoop-2.6 $HIVE_PROFILES"
+ else
+ BINARY_PKGS_ARGS["hadoop3.2"]="-Phadoop-3.2 $HIVE_PROFILES"
fi
fi
diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7
index 26dc6e7bd8bf9..f21e76bf4331a 100644
--- a/dev/deps/spark-deps-hadoop-2.7
+++ b/dev/deps/spark-deps-hadoop-2.7
@@ -149,10 +149,10 @@ lz4-java-1.6.0.jar
machinist_2.12-0.6.8.jar
macro-compat_2.12-1.1.1.jar
mesos-1.4.0-shaded-protobuf.jar
-metrics-core-3.1.5.jar
-metrics-graphite-3.1.5.jar
-metrics-json-3.1.5.jar
-metrics-jvm-3.1.5.jar
+metrics-core-3.2.6.jar
+metrics-graphite-3.2.6.jar
+metrics-json-3.2.6.jar
+metrics-jvm-3.2.6.jar
minlog-1.3.0.jar
netty-all-4.1.42.Final.jar
objenesis-2.5.1.jar
diff --git a/dev/deps/spark-deps-hadoop-3.2 b/dev/deps/spark-deps-hadoop-3.2
index a92b7124cb4af..3ecc3c2b0d35a 100644
--- a/dev/deps/spark-deps-hadoop-3.2
+++ b/dev/deps/spark-deps-hadoop-3.2
@@ -179,10 +179,10 @@ lz4-java-1.6.0.jar
machinist_2.12-0.6.8.jar
macro-compat_2.12-1.1.1.jar
mesos-1.4.0-shaded-protobuf.jar
-metrics-core-3.1.5.jar
-metrics-graphite-3.1.5.jar
-metrics-json-3.1.5.jar
-metrics-jvm-3.1.5.jar
+metrics-core-3.2.6.jar
+metrics-graphite-3.2.6.jar
+metrics-json-3.2.6.jar
+metrics-jvm-3.2.6.jar
minlog-1.3.0.jar
mssql-jdbc-6.2.1.jre7.jar
netty-all-4.1.42.Final.jar
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index c7ea065b28ed8..1443584ccbcb8 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -329,6 +329,7 @@ def __hash__(self):
"pyspark.tests.test_join",
"pyspark.tests.test_profiler",
"pyspark.tests.test_rdd",
+ "pyspark.tests.test_rddbarrier",
"pyspark.tests.test_readwrite",
"pyspark.tests.test_serializers",
"pyspark.tests.test_shuffle",
diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb
index 2d1a9547e3731..f95e4e2f97792 100644
--- a/docs/_plugins/copy_api_dirs.rb
+++ b/docs/_plugins/copy_api_dirs.rb
@@ -96,9 +96,9 @@
end
# End updating JavaDoc files for badge post-processing
- puts "Copying jquery.js from Scala API to Java API for page post-processing of badges"
- jquery_src_file = "./api/scala/lib/jquery.js"
- jquery_dest_file = "./api/java/lib/jquery.js"
+ puts "Copying jquery.min.js from Scala API to Java API for page post-processing of badges"
+ jquery_src_file = "./api/scala/lib/jquery.min.js"
+ jquery_dest_file = "./api/java/lib/jquery.min.js"
mkdir_p("./api/java/lib")
cp(jquery_src_file, jquery_dest_file)
diff --git a/docs/index.md b/docs/index.md
index edb1c421fb794..9e8af0d5f8e2b 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -47,6 +47,7 @@ locally on one machine --- all you need is to have `java` installed on your syst
or the `JAVA_HOME` environment variable pointing to a Java installation.
Spark runs on Java 8/11, Scala 2.12, Python 2.7+/3.4+ and R 3.1+.
+Java 8 prior to version 8u92 support is deprecated as of Spark 3.0.0.
Python 2 support is deprecated as of Spark 3.0.0.
R prior to version 3.4 support is deprecated as of Spark 3.0.0.
For the Scala API, Spark {{site.SPARK_VERSION}}
diff --git a/docs/sql-data-sources-orc.md b/docs/sql-data-sources-orc.md
index 45bff17c6cf2b..bddffe02602e8 100644
--- a/docs/sql-data-sources-orc.md
+++ b/docs/sql-data-sources-orc.md
@@ -31,7 +31,7 @@ the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is also
spark.sql.orc.impl |
native |
- The name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4. `hive` means the ORC library in Hive 1.2.1. |
+ The name of ORC implementation. It can be one of native and hive. native means the native ORC support. hive means the ORC library in Hive. |
spark.sql.orc.enableVectorizedReader |
diff --git a/docs/sql-getting-started.md b/docs/sql-getting-started.md
index 5d18c48879f93..0ded2654719c1 100644
--- a/docs/sql-getting-started.md
+++ b/docs/sql-getting-started.md
@@ -346,6 +346,9 @@ For example:
+## Scalar Functions
+(to be filled soon)
+
## Aggregations
The [built-in DataFrames functions](api/scala/index.html#org.apache.spark.sql.functions$) provide common
diff --git a/docs/sql-keywords.md b/docs/sql-keywords.md
index 7a0e3efee8ffa..b4f8d8be11c4f 100644
--- a/docs/sql-keywords.md
+++ b/docs/sql-keywords.md
@@ -210,6 +210,7 @@ Below is a list of all the keywords in Spark SQL.
| PRECEDING | non-reserved | non-reserved | non-reserved |
| PRIMARY | reserved | non-reserved | reserved |
| PRINCIPALS | non-reserved | non-reserved | non-reserved |
+ | PROPERTIES | non-reserved | non-reserved | non-reserved |
| PURGE | non-reserved | non-reserved | non-reserved |
| QUERY | non-reserved | non-reserved | non-reserved |
| RANGE | non-reserved | non-reserved | reserved |
diff --git a/docs/sql-ref-syntax-ddl-alter-table.md b/docs/sql-ref-syntax-ddl-alter-table.md
index 7fcd397915825..e311691c6b801 100644
--- a/docs/sql-ref-syntax-ddl-alter-table.md
+++ b/docs/sql-ref-syntax-ddl-alter-table.md
@@ -19,4 +19,240 @@ license: |
limitations under the License.
---
-**This page is under construction**
+### Description
+`ALTER TABLE` statement changes the schema or properties of a table.
+
+### RENAME
+`ALTER TABLE RENAME` statement changes the table name of an existing table in the database.
+
+#### Syntax
+{% highlight sql %}
+ALTER TABLE [db_name.]old_table_name RENAME TO [db_name.]new_table_name
+
+ALTER TABLE table_name PARTITION partition_spec RENAME TO PARTITION partition_spec;
+
+{% endhighlight %}
+
+#### Parameters
+
+ old_table_name
+ - Name of an existing table.
+
+
+ db_name
+ - Name of the existing database.
+
+
+
+ new_table_name
+ - The new name for the table.
+
+
+
+ partition_spec
+ - Partition to be renamed.
+
+
+
+### ADD COLUMNS
+`ALTER TABLE ADD COLUMNS` statement adds the specified columns to an existing table.
+
+#### Syntax
+{% highlight sql %}
+ALTER TABLE table_name ADD COLUMNS (col_spec[, col_spec ...])
+{% endhighlight %}
+
+#### Parameters
+
+ table_name
+ - The name of an existing table.
+
+
+
+
+ COLUMNS (col_spec)
+ - Specifies the columns to be added.
+
+
+
+### SET AND UNSET
+
+#### SET TABLE PROPERTIES
+`ALTER TABLE SET` command is used for setting the table properties. If a particular property was already set,
+this overrides the old value with the new one.
+
+`ALTER TABLE UNSET` is used to drop the table property.
+
+##### Syntax
+{% highlight sql %}
+
+--Set Table Properties
+ALTER TABLE table_name SET TBLPROPERTIES (key1=val1, key2=val2, ...)
+
+--Unset Table Properties
+ALTER TABLE table_name UNSET TBLPROPERTIES [IF EXISTS] (key1, key2, ...)
+
+{% endhighlight %}
+
+#### SET SERDE
+`ALTER TABLE SET` command is used for setting the SERDE or SERDE properties in Hive tables. If a particular property was already set,
+this overrides the old value with the new one.
+
+##### Syntax
+{% highlight sql %}
+
+--Set SERDE Properties
+ALTER TABLE table_name [PARTITION part_spec]
+ SET SERDEPROPERTIES (key1=val1, key2=val2, ...)
+
+ALTER TABLE table_name [PARTITION part_spec] SET SERDE serde_class_name
+ [WITH SERDEPROPERTIES (key1=val1, key2=val2, ...)]
+
+{% endhighlight %}
+
+#### SET LOCATION And SET FILE FORMAT
+`ALTER TABLE SET` command can also be used for changing the file location and file format for
+existing tables.
+
+##### Syntax
+{% highlight sql %}
+
+--Changing File Format
+ALTER TABLE table_name [PARTITION partition_spec] SET FILEFORMAT file_format;
+
+--Changing File Location
+ALTER TABLE table_name [PARTITION partition_spec] SET LOCATION 'new_location';
+
+{% endhighlight %}
+
+#### Parameters
+
+ table_name
+ - The name of an existing table.
+
+
+
+ PARTITION (part_spec)
+ - Specifies the partition on which the property has to be set.
+
+
+
+ SERDEPROPERTIES (key1=val1, key2=val2, ...)
+ - Specifies the SERDE properties to be set.
+
+
+
+### Examples
+{% highlight sql %}
+
+--RENAME table
+DESC student;
++--------------------------+------------+----------+--+
+| col_name | data_type | comment |
++--------------------------+------------+----------+--+
+| name | string | NULL |
+| rollno | int | NULL |
+| age | int | NULL |
+| # Partition Information | | |
+| # col_name | data_type | comment |
+| age | int | NULL |
++--------------------------+------------+----------+--+
+
+ALTER TABLE Student RENAME TO StudentInfo;
+
+--After Renaming the table
+
+DESC StudentInfo;
++--------------------------+------------+----------+--+
+| col_name | data_type | comment |
++--------------------------+------------+----------+--+
+| name | string | NULL |
+| rollno | int | NULL |
+| age | int | NULL |
+| # Partition Information | | |
+| # col_name | data_type | comment |
+| age | int | NULL |
++--------------------------+------------+----------+--+
+
+--RENAME partition
+
+SHOW PARTITIONS StudentInfo;
++------------+--+
+| partition |
++------------+--+
+| age=10 |
+| age=11 |
+| age=12 |
++------------+--+
+
+ALTER TABLE default.StudentInfo PARTITION (age='10') RENAME TO PARTITION (age='15');
+
+--After renaming Partition
+SHOW PARTITIONS StudentInfo;
++------------+--+
+| partition |
++------------+--+
+| age=11 |
+| age=12 |
+| age=15 |
++------------+--+
+
+-- Add new column to a table
+
+DESC StudentInfo;
++--------------------------+------------+----------+--+
+| col_name | data_type | comment |
++--------------------------+------------+----------+--+
+| name | string | NULL |
+| rollno | int | NULL |
+| age | int | NULL |
+| # Partition Information | | |
+| # col_name | data_type | comment |
+| age | int | NULL |
++--------------------------+------------+----------+--+
+
+ALTER TABLE StudentInfo ADD columns (LastName string, DOB timestamp);
+
+--After Adding New columns to the table
+DESC StudentInfo;
++--------------------------+------------+----------+--+
+| col_name | data_type | comment |
++--------------------------+------------+----------+--+
+| name | string | NULL |
+| rollno | int | NULL |
+| LastName | string | NULL |
+| DOB | timestamp | NULL |
+| age | int | NULL |
+| # Partition Information | | |
+| # col_name | data_type | comment |
+| age | int | NULL |
++--------------------------+------------+----------+--+
+
+
+--Change the fileformat
+ALTER TABLE loc_orc SET fileformat orc;
+
+ALTER TABLE p1 partition (month=2, day=2) SET fileformat parquet;
+
+--Change the file Location
+ALTER TABLE dbx.tab1 PARTITION (a='1', b='2') SET LOCATION '/path/to/part/ways'
+
+-- SET SERDE/ SERDE Properties
+ALTER TABLE test_tab SET SERDE 'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe';
+
+ALTER TABLE dbx.tab1 SET SERDE 'org.apache.madoop' WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')
+
+--SET TABLE PROPERTIES
+ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('winner' = 'loser')
+
+--DROP TABLE PROPERTIES
+ALTER TABLE dbx.tab1 UNSET TBLPROPERTIES ('winner')
+
+{% endhighlight %}
+
+
+### Related Statements
+- [CREATE TABLE](sql-ref-syntax-ddl-create-table.html)
+- [DROP TABLE](sql-ref-syntax-ddl-drop-table.html)
+
+
diff --git a/docs/sql-ref-syntax-ddl-create-function.md b/docs/sql-ref-syntax-ddl-create-function.md
index f95a9eba42c2f..4c09ebafb1f5d 100644
--- a/docs/sql-ref-syntax-ddl-create-function.md
+++ b/docs/sql-ref-syntax-ddl-create-function.md
@@ -19,4 +19,153 @@ license: |
limitations under the License.
---
-**This page is under construction**
+### Description
+The `CREATE FUNCTION` statement is used to create a temporary or permanent function
+in Spark. Temporary functions are scoped at a session level whereas permanent
+functions are created in the persistent catalog and are made available to
+all sessions. The resources specified in the `USING` clause are made available
+to all executors when they are executed for the first time. In addition to the
+SQL interface, Spark allows users to create custom user defined scalar and
+aggregate functions using Scala, Python and Java APIs. Please refer to
+[scalar functions](sql-getting-started.html#scalar-functions) and
+[aggregate functions](sql-getting-started.html#aggregations) for more information.
+
+### Syntax
+{% highlight sql %}
+CREATE [ OR REPLACE ] [ TEMPORARY ] FUNCTION [ IF NOT EXISTS ]
+ function_name AS class_name [ resource_locations ]
+{% endhighlight %}
+
+### Parameters
+
+ OR REPLACE
+ -
+ If specified, the resources for the function are reloaded. This is mainly useful
+ to pick up any changes made to the implementation of the function. This
+ parameter is mutually exclusive to IF NOT EXISTS and can not
+ be specified together.
+
+ TEMPORARY
+ -
+ Indicates the scope of the function being created. When TEMPORARY is specified, the
+ created function is valid and visible in the current session. No persistent
+ entry is made in the catalog for these kinds of functions.
+
+ IF NOT EXISTS
+ -
+ If specified, creates the function only when it does not exist. The creation
+ of the function succeeds (no error is thrown) if the specified function already
+ exists in the system. This parameter is mutually exclusive to OR REPLACE
+ and can not be specified together.
+
+ function_name
+ -
+ Specifies the name of the function to be created. The function name may be
+ optionally qualified with a database name.
+ Syntax:
+
+ [database_name.]function_name
+
+
+ class_name
+ -
+ Specifies the name of the class that provides the implementation for function to be created.
+ The implementing class should extend one of the base classes as follows:
+
+ - Should extend UDF or UDAF in org.apache.hadoop.hive.ql.exec package.
+ - Should extend AbstractGenericUDAFResolver, GenericUDF, or
+ GenericUDTF in org.apache.hadoop.hive.ql.udf.generic package.
+ - Should extend UserDefinedAggregateFunction in org.apache.spark.sql.expressions package.
+
+
+ resource_locations
+ -
+ Specifies the list of resources that contain the implementation of the function
+ along with its dependencies.
+ Syntax:
+
+ USING { { (JAR | FILE ) resource_uri} , ...}
+
+
+
+
+### Examples
+{% highlight sql %}
+-- 1. Create a simple UDF `SimpleUdf` that increments the supplied integral value by 10.
+-- import org.apache.hadoop.hive.ql.exec.UDF;
+-- public class SimpleUdf extends UDF {
+-- public int evaluate(int value) {
+-- return value + 10;
+-- }
+-- }
+-- 2. Compile and place it in a JAR file called `SimpleUdf.jar` in /tmp.
+
+-- Create a table called `test` and insert two rows.
+CREATE TABLE test(c1 INT);
+INSERT INTO test VALUES (1), (2);
+
+-- Create a permanent function called `simple_udf`.
+CREATE FUNCTION simple_udf AS 'SimpleUdf'
+ USING JAR '/tmp/SimpleUdf.jar';
+
+-- Verify that the function is in the registry.
+SHOW USER FUNCTIONS;
+ +------------------+
+ | function|
+ +------------------+
+ |default.simple_udf|
+ +------------------+
+
+-- Invoke the function. Every selected value should be incremented by 10.
+SELECT simple_udf(c1) AS function_return_value FROM test;
+ +---------------------+
+ |function_return_value|
+ +---------------------+
+ | 11|
+ | 12|
+ +---------------------+
+
+-- Create a temporary function.
+CREATE TEMPORARY FUNCTION simple_temp_udf AS 'SimpleUdf'
+ USING JAR '/tmp/SimpleUdf.jar';
+
+-- Verify that the newly created temporary function is in the registry.
+-- Please note that the temporary function does not have a qualified
+-- database associated with it.
+SHOW USER FUNCTIONS;
+ +------------------+
+ | function|
+ +------------------+
+ |default.simple_udf|
+ | simple_temp_udf|
+ +------------------+
+
+-- 1. Modify `SimpleUdf`'s implementation to increment the supplied integral value by 20.
+-- import org.apache.hadoop.hive.ql.exec.UDF;
+
+-- public class SimpleUdfR extends UDF {
+-- public int evaluate(int value) {
+-- return value + 20;
+-- }
+-- }
+-- 2. Compile and place it in a jar file called `SimpleUdfR.jar` in /tmp.
+
+-- Replace the implementation of `simple_udf`
+CREATE OR REPLACE FUNCTION simple_udf AS 'SimpleUdfR'
+ USING JAR '/tmp/SimpleUdfR.jar';
+
+-- Invoke the function. Every selected value should be incremented by 20.
+SELECT simple_udf(c1) AS function_return_value FROM test;
++---------------------+
+|function_return_value|
++---------------------+
+| 21|
+| 22|
++---------------------+
+
+{% endhighlight %}
+
+### Related statements
+- [SHOW FUNCTIONS](sql-ref-syntax-aux-show-functions.html)
+- [DESCRIBE FUNCTION](sql-ref-syntax-aux-describe-function.html)
+- [DROP FUNCTION](sql-ref-syntax-ddl-drop-function.html)
diff --git a/docs/sql-ref-syntax-dml-load.md b/docs/sql-ref-syntax-dml-load.md
index fd25ba314e0b6..c2a6102db4aad 100644
--- a/docs/sql-ref-syntax-dml-load.md
+++ b/docs/sql-ref-syntax-dml-load.md
@@ -1,7 +1,7 @@
---
layout: global
-title: LOAD
-displayTitle: LOAD
+title: LOAD DATA
+displayTitle: LOAD DATA
license: |
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
@@ -19,4 +19,101 @@ license: |
limitations under the License.
---
-**This page is under construction**
+### Description
+The `LOAD DATA` statement loads data into a table from a user-specified directory or file. If a directory is specified, all files in the directory are loaded. If a file is specified, only that single file is loaded. Additionally, the `LOAD DATA` statement takes an optional partition specification. When a partition is specified, the data files (when the input source is a directory) or the single file (when the input source is a file) are loaded into that partition of the target table.
+
+### Syntax
+{% highlight sql %}
+LOAD DATA [ LOCAL ] INPATH path [ OVERWRITE ] INTO TABLE table_name
+ [ PARTITION ( partition_col_name = partition_col_val [ , ... ] ) ]
+{% endhighlight %}
+
+### Parameters
+
+ path
+ - The path of the file or directory to load. It can be either an absolute or a relative path.
+
+
+
+ table_name
+ - The name of an existing table.
+
+
+
+ PARTITION ( partition_col_name = partition_col_val [ , ... ] )
+ - Specifies one or more partition column and value pairs.
+
+
+
+ LOCAL
+ - If specified, the INPATH is resolved against the local file system instead of
+   the default file system, which is typically distributed storage.
+
+
+
+ OVERWRITE
+ - By default, new data is appended to the table. If OVERWRITE is used, the table
+   is instead overwritten with the new data.
+
+
+### Examples
+{% highlight sql %}
+ -- Example without partition specification.
+ -- Assuming the students table has already been created and populated.
+ SELECT * FROM students;
+
+ + -------------- + ------------------------------ + -------------- +
+ | name | address | student_id |
+ + -------------- + ------------------------------ + -------------- +
+ | Amy Smith | 123 Park Ave, San Jose | 111111 |
+ + -------------- + ------------------------------ + -------------- +
+
+ CREATE TABLE test_load (name VARCHAR(64), address VARCHAR(64), student_id INT);
+
+ -- Assuming the students table is in '/user/hive/warehouse/'
+ LOAD DATA LOCAL INPATH '/user/hive/warehouse/students' OVERWRITE INTO TABLE test_load;
+
+ SELECT * FROM test_load;
+
+ + -------------- + ------------------------------ + -------------- +
+ | name | address | student_id |
+ + -------------- + ------------------------------ + -------------- +
+ | Amy Smith | 123 Park Ave, San Jose | 111111 |
+ + -------------- + ------------------------------ + -------------- +
+
+ -- Example with partition specification.
+ CREATE TABLE test_partition (c1 INT, c2 INT, c3 INT) USING HIVE PARTITIONED BY (c2, c3);
+
+ INSERT INTO test_partition PARTITION (c2 = 2, c3 = 3) VALUES (1);
+
+ INSERT INTO test_partition PARTITION (c2 = 5, c3 = 6) VALUES (4);
+
+ INSERT INTO test_partition PARTITION (c2 = 8, c3 = 9) VALUES (7);
+
+ SELECT * FROM test_partition;
+
+ + ------- + ------- + ----- +
+ | c1 | c2 | c3 |
+ + ------- + ------- + ----- +
+ | 1 | 2 | 3 |
+ + ------- + ------- + ----- +
+ | 4 | 5 | 6 |
+ + ------- + ------- + ----- +
+ | 7 | 8 | 9 |
+ + ------- + ------- + ----- +
+
+ CREATE TABLE test_load_partition (c1 INT, c2 INT, c3 INT) USING HIVE PARTITIONED BY (c2, c3);
+
+ -- Assuming the test_partition table is in '/user/hive/warehouse/'
+ LOAD DATA LOCAL INPATH '/user/hive/warehouse/test_partition/c2=2/c3=3'
+ OVERWRITE INTO TABLE test_load_partition PARTITION (c2=2, c3=3);
+
+ SELECT * FROM test_load_partition;
+
+ + ------- + ------- + ----- +
+ | c1 | c2 | c3 |
+ + ------- + ------- + ----- +
+ | 1 | 2 | 3 |
+ + ------- + ------- + ----- +
+
+
+{% endhighlight %}
+
diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md
index 89732d309aa27..badf0429545f3 100644
--- a/docs/structured-streaming-kafka-integration.md
+++ b/docs/structured-streaming-kafka-integration.md
@@ -614,6 +614,10 @@ The Dataframe being written to Kafka should have the following columns in schema
topic (*optional) |
string |
+
+ | partition (optional) | int |
+
\* The topic column is required if the "topic" configuration option is not specified.
@@ -622,6 +626,12 @@ a ```null``` valued key column will be automatically added (see Kafka semantics
how ```null``` valued key values are handled). If a topic column exists then its value
is used as the topic when writing the given row to Kafka, unless the "topic" configuration
option is set i.e., the "topic" configuration option overrides the topic column.
+If a "partition" column is not specified (or its value is ```null```)
+then the partition is calculated by the Kafka producer.
+A Kafka partitioner can be specified in Spark by setting the
+```kafka.partitioner.class``` option. If not present, Kafka default partitioner
+will be used.
+
The following options must be set for the Kafka sink
for both batch and streaming queries.
diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml
index a4956ff5ee9cc..aff79b8b8e642 100644
--- a/external/docker-integration-tests/pom.xml
+++ b/external/docker-integration-tests/pom.xml
@@ -106,6 +106,14 @@
test-jar
test
+
+    <dependency>
+      <groupId>org.glassfish.jersey.bundles.repackaged</groupId>
+      <artifactId>jersey-guava</artifactId>
+      <version>2.25.1</version>
+      <scope>test</scope>
+    </dependency>
mysql
mysql-connector-java
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
index 9cd5c4ec41a52..bba1b5275269b 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.tags.DockerTest
@DockerTest
class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
override val db = new DatabaseOnDocker {
- override val imageName = "mysql:5.7.9"
+ override val imageName = "mysql:5.7.28"
override val env = Map(
"MYSQL_ROOT_PASSWORD" -> "rootpass"
)
@@ -39,6 +39,8 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
}
override def dataPreparation(conn: Connection): Unit = {
+ // Since MySQL 5.7.14+, we need to disable strict mode
+ conn.prepareStatement("SET GLOBAL sql_mode = ''").executeUpdate()
conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
conn.prepareStatement("CREATE TABLE tbl (x INTEGER, y TEXT(8))").executeUpdate()
conn.prepareStatement("INSERT INTO tbl VALUES (42,'fred')").executeUpdate()
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
index 89da9a1de6f74..599f00def0750 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.tags.DockerTest
@DockerTest
class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
override val db = new DatabaseOnDocker {
- override val imageName = "postgres:11.4"
+ override val imageName = "postgres:12.0-alpine"
override val env = Map(
"POSTGRES_PASSWORD" -> "rootpass"
)
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
index b423ddc959c1b..5bdc1b5fe9f37 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
@@ -27,7 +27,7 @@ import org.apache.kafka.common.header.internals.RecordHeader
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}
-import org.apache.spark.sql.types.{BinaryType, StringType}
+import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType}
/**
* Writes out data in a single Spark task, without any concerns about how
@@ -92,8 +92,10 @@ private[kafka010] abstract class KafkaRowWriter(
throw new NullPointerException(s"null topic present in the data. Use the " +
s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.")
}
+ val partition: Integer =
+ if (projectedRow.isNullAt(4)) null else projectedRow.getInt(4)
val record = if (projectedRow.isNullAt(3)) {
- new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, null, key, value)
+ new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, partition, key, value)
} else {
val headerArray = projectedRow.getArray(3)
val headers = (0 until headerArray.numElements()).map { i =>
@@ -101,7 +103,8 @@ private[kafka010] abstract class KafkaRowWriter(
new RecordHeader(struct.getUTF8String(0).toString, struct.getBinary(1))
.asInstanceOf[Header]
}
- new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, null, key, value, headers.asJava)
+ new ProducerRecord[Array[Byte], Array[Byte]](
+ topic.toString, partition, key, value, headers.asJava)
}
producer.send(record, callback)
}
@@ -156,12 +159,23 @@ private[kafka010] abstract class KafkaRowWriter(
throw new IllegalStateException(s"${KafkaWriter.HEADERS_ATTRIBUTE_NAME} " +
s"attribute unsupported type ${t.catalogString}")
}
+ val partitionExpression =
+ inputSchema.find(_.name == KafkaWriter.PARTITION_ATTRIBUTE_NAME)
+ .getOrElse(Literal(null, IntegerType))
+ partitionExpression.dataType match {
+ case IntegerType => // good
+ case t =>
+ throw new IllegalStateException(s"${KafkaWriter.PARTITION_ATTRIBUTE_NAME} " +
+ s"attribute unsupported type $t. ${KafkaWriter.PARTITION_ATTRIBUTE_NAME} " +
+ s"must be a ${IntegerType.catalogString}")
+ }
UnsafeProjection.create(
Seq(
topicExpression,
Cast(keyExpression, BinaryType),
Cast(valueExpression, BinaryType),
- headersExpression
+ headersExpression,
+ partitionExpression
),
inputSchema
)
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
index bbb060356f730..9b0d11f137ce2 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.QueryExecution
-import org.apache.spark.sql.types.{BinaryType, MapType, StringType}
+import org.apache.spark.sql.types.{BinaryType, IntegerType, MapType, StringType}
import org.apache.spark.util.Utils
/**
@@ -41,6 +41,7 @@ private[kafka010] object KafkaWriter extends Logging {
val KEY_ATTRIBUTE_NAME: String = "key"
val VALUE_ATTRIBUTE_NAME: String = "value"
val HEADERS_ATTRIBUTE_NAME: String = "headers"
+ val PARTITION_ATTRIBUTE_NAME: String = "partition"
override def toString: String = "KafkaWriter"
@@ -86,6 +87,14 @@ private[kafka010] object KafkaWriter extends Logging {
throw new AnalysisException(s"$HEADERS_ATTRIBUTE_NAME attribute type " +
s"must be a ${KafkaRecordToRowConverter.headersType.catalogString}")
}
+ schema.find(_.name == PARTITION_ATTRIBUTE_NAME).getOrElse(
+ Literal(null, IntegerType)
+ ).dataType match {
+ case IntegerType => // good
+ case _ =>
+ throw new AnalysisException(s"$PARTITION_ATTRIBUTE_NAME attribute type " +
+ s"must be an ${IntegerType.catalogString}")
+ }
}
def write(
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
index 65adbd6b9887c..cbf4952406c01 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
@@ -286,6 +286,15 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest {
}
assert(ex3.getMessage.toLowerCase(Locale.ROOT).contains(
"key attribute type must be a string or binary"))
+
+ val ex4 = intercept[AnalysisException] {
+ /* partition field wrong type */
+ createKafkaWriter(input.toDF())(
+ withSelectExpr = s"'$topic' as topic", "value as partition", "value"
+ )
+ }
+ assert(ex4.getMessage.toLowerCase(Locale.ROOT).contains(
+ "partition attribute type must be an int"))
}
test("streaming - write to non-existing topic") {
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
index d77b9a3b6a9e1..aacb10f5197b0 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
@@ -22,6 +22,8 @@ import java.util.Locale
import java.util.concurrent.atomic.AtomicInteger
import org.apache.kafka.clients.producer.ProducerConfig
+import org.apache.kafka.clients.producer.internals.DefaultPartitioner
+import org.apache.kafka.common.Cluster
import org.apache.kafka.common.serialization.ByteArraySerializer
import org.scalatest.time.SpanSugar._
@@ -33,7 +35,7 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{BinaryType, DataType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, StringType, StructField, StructType}
abstract class KafkaSinkSuiteBase extends QueryTest with SharedSparkSession with KafkaTest {
protected var testUtils: KafkaTestUtils = _
@@ -293,6 +295,21 @@ class KafkaSinkStreamingSuite extends KafkaSinkSuiteBase with StreamTest {
}
assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
"key attribute type must be a string or binary"))
+
+ try {
+ ex = intercept[StreamingQueryException] {
+ /* partition field wrong type */
+ writer = createKafkaWriter(input.toDF())(
+ withSelectExpr = s"'$topic' as topic", "value", "value as partition"
+ )
+ input.addData("1", "2", "3", "4", "5")
+ writer.processAllAvailable()
+ }
+ } finally {
+ writer.stop()
+ }
+ assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
+ "partition attribute type must be an int"))
}
test("streaming - write to non-existing topic") {
@@ -418,6 +435,65 @@ abstract class KafkaSinkBatchSuiteBase extends KafkaSinkSuiteBase {
)
}
+ def writeToKafka(df: DataFrame, topic: String, options: Map[String, String] = Map.empty): Unit = {
+ df
+ .write
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("topic", topic)
+ .options(options)
+ .mode("append")
+ .save()
+ }
+
+ def partitionsInTopic(topic: String): Set[Int] = {
+ createKafkaReader(topic)
+ .select("partition")
+ .map(_.getInt(0))
+ .collect()
+ .toSet
+ }
+
+ test("batch - partition column and partitioner priorities") {
+ val nrPartitions = 4
+ val topic1 = newTopic()
+ val topic2 = newTopic()
+ val topic3 = newTopic()
+ val topic4 = newTopic()
+ testUtils.createTopic(topic1, nrPartitions)
+ testUtils.createTopic(topic2, nrPartitions)
+ testUtils.createTopic(topic3, nrPartitions)
+ testUtils.createTopic(topic4, nrPartitions)
+ val customKafkaPartitionerConf = Map(
+ "kafka.partitioner.class" -> "org.apache.spark.sql.kafka010.TestKafkaPartitioner"
+ )
+
+ val df = (0 until 5).map(n => (topic1, s"$n", s"$n")).toDF("topic", "key", "value")
+
+ // default kafka partitioner
+ writeToKafka(df, topic1)
+ val partitionsInTopic1 = partitionsInTopic(topic1)
+ assert(partitionsInTopic1.size > 1)
+
+ // custom partitioner (always returns 0) overrides default partitioner
+ writeToKafka(df, topic2, customKafkaPartitionerConf)
+ val partitionsInTopic2 = partitionsInTopic(topic2)
+ assert(partitionsInTopic2.size == 1)
+ assert(partitionsInTopic2.head == 0)
+
+ // partition column overrides custom partitioner
+ val dfWithCustomPartition = df.withColumn("partition", lit(2))
+ writeToKafka(dfWithCustomPartition, topic3, customKafkaPartitionerConf)
+ val partitionsInTopic3 = partitionsInTopic(topic3)
+ assert(partitionsInTopic3.size == 1)
+ assert(partitionsInTopic3.head == 2)
+
+ // when the partition column value is null, it is ignored
+ val dfWithNullPartitions = df.withColumn("partition", lit(null).cast(IntegerType))
+ writeToKafka(dfWithNullPartitions, topic4)
+ assert(partitionsInTopic(topic4) == partitionsInTopic1)
+ }
+
test("batch - null topic field value, and no topic option") {
val df = Seq[(String, String)](null.asInstanceOf[String] -> "1").toDF("topic", "value")
val ex = intercept[SparkException] {
@@ -515,3 +591,13 @@ class KafkaSinkBatchSuiteV2 extends KafkaSinkBatchSuiteBase {
}
}
}
+
+class TestKafkaPartitioner extends DefaultPartitioner {
+ override def partition(
+ topic: String,
+ key: Any,
+ keyBytes: Array[Byte],
+ value: Any,
+ valueBytes: Array[Byte],
+ cluster: Cluster): Int = 0
+}
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala
index bbb72bf9973e3..6c745987b4c23 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala
@@ -67,6 +67,8 @@ class KafkaTestUtils(
secure: Boolean = false) extends Logging {
private val JAVA_AUTH_CONFIG = "java.security.auth.login.config"
+ private val IBM_KRB_DEBUG_CONFIG = "com.ibm.security.krb5.Krb5Debug"
+ private val SUN_KRB_DEBUG_CONFIG = "sun.security.krb5.debug"
private val localCanonicalHostName = InetAddress.getLoopbackAddress().getCanonicalHostName()
logInfo(s"Local host name is $localCanonicalHostName")
@@ -133,6 +135,7 @@ class KafkaTestUtils(
private def setUpMiniKdc(): Unit = {
val kdcDir = Utils.createTempDir()
val kdcConf = MiniKdc.createConf()
+ kdcConf.setProperty(MiniKdc.DEBUG, "true")
kdc = new MiniKdc(kdcConf, kdcDir)
kdc.start()
kdcReady = true
@@ -238,6 +241,7 @@ class KafkaTestUtils(
}
if (secure) {
+ setupKrbDebug()
setUpMiniKdc()
val jaasConfigFile = createKeytabsAndJaasConfigFile()
System.setProperty(JAVA_AUTH_CONFIG, jaasConfigFile)
@@ -252,6 +256,14 @@ class KafkaTestUtils(
}
}
+ private def setupKrbDebug(): Unit = {
+ if (System.getProperty("java.vendor").contains("IBM")) {
+ System.setProperty(IBM_KRB_DEBUG_CONFIG, "all")
+ } else {
+ System.setProperty(SUN_KRB_DEBUG_CONFIG, "true")
+ }
+ }
+
/** Teardown the whole servers, including Kafka broker and Zookeeper */
def teardown(): Unit = {
if (leakDetector != null) {
@@ -303,6 +315,15 @@ class KafkaTestUtils(
kdc.stop()
}
UserGroupInformation.reset()
+ teardownKrbDebug()
+ }
+
+ private def teardownKrbDebug(): Unit = {
+ if (System.getProperty("java.vendor").contains("IBM")) {
+ System.clearProperty(IBM_KRB_DEBUG_CONFIG)
+ } else {
+ System.clearProperty(SUN_KRB_DEBUG_CONFIG)
+ }
}
/** Create a Kafka topic and wait until it is propagated to the whole cluster */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 9ac673078d4ad..3bff236677e6b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -53,7 +53,7 @@ private[spark] trait ClassifierParams
val validateInstance = (instance: Instance) => {
val label = instance.label
require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
- s" dataset with invalid label $label. Labels must be integers in range" +
+ s" dataset with invalid label $label. Labels must be integers in range" +
s" [0, $numClasses).")
}
extractInstances(dataset, validateInstance)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 09f81b0dcbdae..5bc45f2b02a4b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -23,7 +23,7 @@ import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
@@ -34,7 +34,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
/**
@@ -79,6 +79,10 @@ class GBTClassifier @Since("1.4.0") (
@Since("1.4.0")
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
+ /** @group setParam */
+ @Since("3.0.0")
+ def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
/** @group setParam */
@Since("1.4.0")
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
@@ -152,36 +156,34 @@ class GBTClassifier @Since("1.4.0") (
set(validationIndicatorCol, value)
}
+ /**
+ * Sets the value of param [[weightCol]].
+ * If this is not set or empty, we treat all instance weights as 1.0.
+ * By default the weightCol is not set, so all instances have weight 1.0.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
override protected def train(
dataset: Dataset[_]): GBTClassificationModel = instrumented { instr =>
- val categoricalFeatures: Map[Int, Int] =
- MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-
val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty
- // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
- // 2 classes now. This lets us provide a more precise error message.
- val convert2LabeledPoint = (dataset: Dataset[_]) => {
- dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
- case Row(label: Double, features: Vector) =>
- require(label == 0 || label == 1, s"GBTClassifier was given" +
- s" dataset with invalid label $label. Labels must be in {0,1}; note that" +
- s" GBTClassifier currently only supports binary classification.")
- LabeledPoint(label, features)
- }
+ val validateInstance = (instance: Instance) => {
+ val label = instance.label
+ require(label == 0 || label == 1, s"GBTClassifier was given" +
+ s" dataset with invalid label $label. Labels must be in {0,1}; note that" +
+ s" GBTClassifier currently only supports binary classification.")
}
val (trainDataset, validationDataset) = if (withValidation) {
- (
- convert2LabeledPoint(dataset.filter(not(col($(validationIndicatorCol))))),
- convert2LabeledPoint(dataset.filter(col($(validationIndicatorCol))))
- )
+ (extractInstances(dataset.filter(not(col($(validationIndicatorCol)))), validateInstance),
+ extractInstances(dataset.filter(col($(validationIndicatorCol))), validateInstance))
} else {
- (convert2LabeledPoint(dataset), null)
+ (extractInstances(dataset, validateInstance), null)
}
- val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
-
val numClasses = 2
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
@@ -191,18 +193,21 @@ class GBTClassifier @Since("1.4.0") (
instr.logPipelineStage(this)
instr.logDataset(dataset)
- instr.logParams(this, labelCol, featuresCol, predictionCol, leafCol, impurity,
- lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
- seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy,
- validationIndicatorCol, validationTol)
+ instr.logParams(this, labelCol, weightCol, featuresCol, predictionCol, leafCol,
+ impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain,
+ minInstancesPerNode, minWeightFractionPerNode, seed, stepSize, subsamplingRate, cacheNodeIds,
+ checkpointInterval, featureSubsetStrategy, validationIndicatorCol, validationTol)
instr.logNumClasses(numClasses)
+ val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+ val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
val (baseLearners, learnerWeights) = if (withValidation) {
GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy,
$(seed), $(featureSubsetStrategy))
} else {
GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy))
}
+ baseLearners.foreach(copyValues(_))
val numFeatures = baseLearners.head.numFeatures
instr.logNumFeatures(numFeatures)
@@ -373,12 +378,9 @@ class GBTClassificationModel private[ml](
*/
@Since("2.4.0")
def evaluateEachIteration(dataset: Dataset[_]): Array[Double] = {
- val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
- case Row(label: Double, features: Vector) => LabeledPoint(label, features)
- }
+ val data = extractInstances(dataset)
GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, loss,
- OldAlgo.Classification
- )
+ OldAlgo.Classification)
}
@Since("2.0.0")
@@ -422,10 +424,9 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int]
val numTrees = (metadata.metadata \ numTreesKey).extract[Int]
- val trees: Array[DecisionTreeRegressionModel] = treesData.map {
+ val trees = treesData.map {
case (treeMetadata, root) =>
- val tree =
- new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+ val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
treeMetadata.getAndSetParams(tree)
tree
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 731b43b67813f..245cda35d8ade 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -143,6 +143,7 @@ class RandomForestClassifier @Since("1.4.0") (
val trees = RandomForest
.run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeClassificationModel])
+ trees.foreach(copyValues(_))
val numFeatures = trees.head.numFeatures
instr.logNumClasses(numClasses)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala
index dd56fbbfa2b63..11d0c4689cbba 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala
@@ -26,7 +26,7 @@ import org.apache.spark.ml.linalg.Vector
* @param weight The weight of this instance.
* @param features The vector of features for this data point.
*/
-private[ml] case class Instance(label: Double, weight: Double, features: Vector)
+private[spark] case class Instance(label: Double, weight: Double, features: Vector)
/**
* Case class that represents an instance of data point with
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index aa4ab5903f711..eb78d8224fc3f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types.StructType
* Params for [[QuantileDiscretizer]].
*/
private[feature] trait QuantileDiscretizerBase extends Params
- with HasHandleInvalid with HasInputCol with HasOutputCol {
+ with HasHandleInvalid with HasInputCol with HasOutputCol with HasInputCols with HasOutputCols {
/**
* Number of buckets (quantiles, or categories) into which data points are grouped. Must
@@ -129,8 +129,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
*/
@Since("1.6.0")
final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val uid: String)
- extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable
- with HasInputCols with HasOutputCols {
+ extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable {
@Since("1.6.0")
def this() = this(Identifiable.randomUID("quantileDiscretizer"))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 602b5fac20d3b..05851d5116751 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -23,7 +23,7 @@ import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
@@ -132,15 +132,14 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
/** (private[ml]) Train a decision tree on an RDD */
private[ml] def train(
- data: RDD[LabeledPoint],
+ data: RDD[Instance],
oldStrategy: OldStrategy,
featureSubsetStrategy: String): DecisionTreeRegressionModel = instrumented { instr =>
instr.logPipelineStage(this)
instr.logDataset(data)
instr.logParams(this, params: _*)
- val instances = data.map(_.toInstance)
- val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
+ val trees = RandomForest.run(data, oldStrategy, numTrees = 1,
featureSubsetStrategy, seed = $(seed), instr = Some(instr), parentUID = Some(uid))
trees.head.asInstanceOf[DecisionTreeRegressionModel]
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 00c0bc9f5e282..9c38647642a61 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -24,7 +24,6 @@ import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
@@ -34,7 +33,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
-import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
+import org.apache.spark.sql.{Column, DataFrame, Dataset}
import org.apache.spark.sql.functions._
/**
@@ -78,6 +77,10 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.4.0")
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
+ /** @group setParam */
+ @Since("3.0.0")
+ def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
/** @group setParam */
@Since("1.4.0")
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
@@ -151,29 +154,35 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
set(validationIndicatorCol, value)
}
- override protected def train(dataset: Dataset[_]): GBTRegressionModel = instrumented { instr =>
- val categoricalFeatures: Map[Int, Int] =
- MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+ /**
+ * Sets the value of param [[weightCol]].
+ * If this is not set or empty, we treat all instance weights as 1.0.
+ * By default the weightCol is not set, so all instances have weight 1.0.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+ override protected def train(dataset: Dataset[_]): GBTRegressionModel = instrumented { instr =>
val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty
val (trainDataset, validationDataset) = if (withValidation) {
- (
- extractLabeledPoints(dataset.filter(not(col($(validationIndicatorCol))))),
- extractLabeledPoints(dataset.filter(col($(validationIndicatorCol))))
- )
+ (extractInstances(dataset.filter(not(col($(validationIndicatorCol))))),
+ extractInstances(dataset.filter(col($(validationIndicatorCol)))))
} else {
- (extractLabeledPoints(dataset), null)
+ (extractInstances(dataset), null)
}
- val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
instr.logPipelineStage(this)
instr.logDataset(dataset)
- instr.logParams(this, labelCol, featuresCol, predictionCol, leafCol, impurity, lossType,
- maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
- seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy,
- validationIndicatorCol, validationTol)
+ instr.logParams(this, labelCol, featuresCol, predictionCol, leafCol, weightCol, impurity,
+ lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
+ minWeightFractionPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval,
+ featureSubsetStrategy, validationIndicatorCol, validationTol)
+ val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+ val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
val (baseLearners, learnerWeights) = if (withValidation) {
GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy,
$(seed), $(featureSubsetStrategy))
@@ -181,6 +190,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
GradientBoostedTrees.run(trainDataset, boostingStrategy,
$(seed), $(featureSubsetStrategy))
}
+ baseLearners.foreach(copyValues(_))
val numFeatures = baseLearners.head.numFeatures
instr.logNumFeatures(numFeatures)
@@ -322,9 +332,7 @@ class GBTRegressionModel private[ml](
*/
@Since("2.4.0")
def evaluateEachIteration(dataset: Dataset[_], loss: String): Array[Double] = {
- val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
- case Row(label: Double, features: Vector) => LabeledPoint(label, features)
- }
+ val data = extractInstances(dataset)
GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights,
convertToOldLossType(loss), OldAlgo.Regression)
}
@@ -367,10 +375,9 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] {
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numTrees = (metadata.metadata \ "numTrees").extract[Int]
- val trees: Array[DecisionTreeRegressionModel] = treesData.map {
+ val trees = treesData.map {
case (treeMetadata, root) =>
- val tree =
- new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+ val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
treeMetadata.getAndSetParams(tree)
tree
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 938aa5acac086..8f78fc1da18c8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -130,6 +130,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
val trees = RandomForest
.run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeRegressionModel])
+ trees.foreach(copyValues(_))
val numFeatures = trees.head.numFeatures
instr.logNamedValue(Instrumentation.loggerTags.numFeatures, numFeatures)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index c31334c92e1c9..744708258b0ac 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml.tree.impl
import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -34,13 +34,13 @@ private[spark] object GradientBoostedTrees extends Logging {
/**
* Method to train a gradient boosting model
- * @param input Training dataset: RDD of `LabeledPoint`.
+ * @param input Training dataset: RDD of `Instance`.
* @param seed Random seed.
* @return tuple of ensemble models and weights:
* (array of decision tree models, array of model weights)
*/
def run(
- input: RDD[LabeledPoint],
+ input: RDD[Instance],
boostingStrategy: OldBoostingStrategy,
seed: Long,
featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = {
@@ -51,7 +51,7 @@ private[spark] object GradientBoostedTrees extends Logging {
seed, featureSubsetStrategy)
case OldAlgo.Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
- val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val remappedInput = input.map(x => Instance((x.label * 2) - 1, x.weight, x.features))
GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
seed, featureSubsetStrategy)
case _ =>
@@ -61,7 +61,7 @@ private[spark] object GradientBoostedTrees extends Logging {
/**
* Method to validate a gradient boosting model
- * @param input Training dataset: RDD of `LabeledPoint`.
+ * @param input Training dataset: RDD of `Instance`.
* @param validationInput Validation dataset.
* This dataset should be different from the training dataset,
* but it should follow the same distribution.
@@ -72,8 +72,8 @@ private[spark] object GradientBoostedTrees extends Logging {
* (array of decision tree models, array of model weights)
*/
def runWithValidation(
- input: RDD[LabeledPoint],
- validationInput: RDD[LabeledPoint],
+ input: RDD[Instance],
+ validationInput: RDD[Instance],
boostingStrategy: OldBoostingStrategy,
seed: Long,
featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = {
@@ -85,9 +85,9 @@ private[spark] object GradientBoostedTrees extends Logging {
case OldAlgo.Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(
- x => new LabeledPoint((x.label * 2) - 1, x.features))
+ x => Instance((x.label * 2) - 1, x.weight, x.features))
val remappedValidationInput = validationInput.map(
- x => new LabeledPoint((x.label * 2) - 1, x.features))
+ x => Instance((x.label * 2) - 1, x.weight, x.features))
GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
validate = true, seed, featureSubsetStrategy)
case _ =>
@@ -106,13 +106,13 @@ private[spark] object GradientBoostedTrees extends Logging {
* corresponding to every sample.
*/
def computeInitialPredictionAndError(
- data: RDD[LabeledPoint],
+ data: RDD[Instance],
initTreeWeight: Double,
initTree: DecisionTreeRegressionModel,
loss: OldLoss): RDD[(Double, Double)] = {
- data.map { lp =>
- val pred = updatePrediction(lp.features, 0.0, initTree, initTreeWeight)
- val error = loss.computeError(pred, lp.label)
+ data.map { case Instance(label, _, features) =>
+ val pred = updatePrediction(features, 0.0, initTree, initTreeWeight)
+ val error = loss.computeError(pred, label)
(pred, error)
}
}
@@ -129,20 +129,17 @@ private[spark] object GradientBoostedTrees extends Logging {
* corresponding to each sample.
*/
def updatePredictionError(
- data: RDD[LabeledPoint],
+ data: RDD[Instance],
predictionAndError: RDD[(Double, Double)],
treeWeight: Double,
tree: DecisionTreeRegressionModel,
loss: OldLoss): RDD[(Double, Double)] = {
-
- val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
- iter.map { case (lp, (pred, error)) =>
- val newPred = updatePrediction(lp.features, pred, tree, treeWeight)
- val newError = loss.computeError(newPred, lp.label)
+ data.zip(predictionAndError).map {
+ case (Instance(label, _, features), (pred, _)) =>
+ val newPred = updatePrediction(features, pred, tree, treeWeight)
+ val newError = loss.computeError(newPred, label)
(newPred, newError)
- }
}
- newPredError
}
/**
@@ -166,29 +163,50 @@ private[spark] object GradientBoostedTrees extends Logging {
* Method to calculate error of the base learner for the gradient boosting calculation.
* Note: This method is not used by the gradient boosting algorithm but is useful for debugging
* purposes.
- * @param data Training dataset: RDD of `LabeledPoint`.
+ * @param data Training dataset: RDD of `Instance`.
* @param trees Boosted Decision Tree models
* @param treeWeights Learning rates at each boosting iteration.
* @param loss evaluation metric.
* @return Measure of model error on data
*/
- def computeError(
- data: RDD[LabeledPoint],
+ def computeWeightedError(
+ data: RDD[Instance],
trees: Array[DecisionTreeRegressionModel],
treeWeights: Array[Double],
loss: OldLoss): Double = {
- data.map { lp =>
+ val (errSum, weightSum) = data.map { case Instance(label, weight, features) =>
val predicted = trees.zip(treeWeights).foldLeft(0.0) { case (acc, (model, weight)) =>
- updatePrediction(lp.features, acc, model, weight)
+ updatePrediction(features, acc, model, weight)
}
- loss.computeError(predicted, lp.label)
- }.mean()
+ (loss.computeError(predicted, label) * weight, weight)
+ }.treeReduce { case ((err1, weight1), (err2, weight2)) =>
+ (err1 + err2, weight1 + weight2)
+ }
+ errSum / weightSum
+ }
+
+ /**
+ * Method to calculate error of the base learner for the gradient boosting calculation.
+ * @param data Training dataset: RDD of `Instance`.
+ * @param predError Prediction and error.
+ * @return Measure of model error on data
+ */
+ def computeWeightedError(
+ data: RDD[Instance],
+ predError: RDD[(Double, Double)]): Double = {
+ val (errSum, weightSum) = data.zip(predError).map {
+ case (Instance(_, weight, _), (_, err)) =>
+ (err * weight, weight)
+ }.treeReduce { case ((err1, weight1), (err2, weight2)) =>
+ (err1 + err2, weight1 + weight2)
+ }
+ errSum / weightSum
}
/**
* Method to compute error or loss for every iteration of gradient boosting.
*
- * @param data RDD of `LabeledPoint`
+ * @param data RDD of `Instance`
* @param trees Boosted Decision Tree models
* @param treeWeights Learning rates at each boosting iteration.
* @param loss evaluation metric.
@@ -197,41 +215,34 @@ private[spark] object GradientBoostedTrees extends Logging {
* containing the first i+1 trees
*/
def evaluateEachIteration(
- data: RDD[LabeledPoint],
+ data: RDD[Instance],
trees: Array[DecisionTreeRegressionModel],
treeWeights: Array[Double],
loss: OldLoss,
algo: OldAlgo.Value): Array[Double] = {
-
- val sc = data.sparkContext
val remappedData = algo match {
- case OldAlgo.Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ case OldAlgo.Classification =>
+ data.map(x => Instance((x.label * 2) - 1, x.weight, x.features))
case _ => data
}
- val broadcastTrees = sc.broadcast(trees)
- val localTreeWeights = treeWeights
- val treesIndices = trees.indices
-
- val dataCount = remappedData.count()
- val evaluation = remappedData.map { point =>
- treesIndices.map { idx =>
- val prediction = broadcastTrees.value(idx)
- .rootNode
- .predictImpl(point.features)
- .prediction
- prediction * localTreeWeights(idx)
+ val numTrees = trees.length
+ val (errSum, weightSum) = remappedData.mapPartitions { iter =>
+ iter.map { case Instance(label, weight, features) =>
+ val pred = Array.tabulate(numTrees) { i =>
+ trees(i).rootNode.predictImpl(features)
+ .prediction * treeWeights(i)
+ }
+ val err = pred.scanLeft(0.0)(_ + _).drop(1)
+ .map(p => loss.computeError(p, label) * weight)
+ (err, weight)
}
- .scanLeft(0.0)(_ + _).drop(1)
- .map(prediction => loss.computeError(prediction, point.label))
+ }.treeReduce { case ((err1, weight1), (err2, weight2)) =>
+ (0 until numTrees).foreach(i => err1(i) += err2(i))
+ (err1, weight1 + weight2)
}
- .aggregate(treesIndices.map(_ => 0.0))(
- (aggregated, row) => treesIndices.map(idx => aggregated(idx) + row(idx)),
- (a, b) => treesIndices.map(idx => a(idx) + b(idx)))
- .map(_ / dataCount)
- broadcastTrees.destroy()
- evaluation.toArray
+ errSum.map(_ / weightSum)
}
/**
@@ -245,8 +256,8 @@ private[spark] object GradientBoostedTrees extends Logging {
* (array of decision tree models, array of model weights)
*/
def boost(
- input: RDD[LabeledPoint],
- validationInput: RDD[LabeledPoint],
+ input: RDD[Instance],
+ validationInput: RDD[Instance],
boostingStrategy: OldBoostingStrategy,
validate: Boolean,
seed: Long,
@@ -280,8 +291,10 @@ private[spark] object GradientBoostedTrees extends Logging {
}
// Prepare periodic checkpointers
+ // Note: this is checkpointing the unweighted training error
val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
treeStrategy.getCheckpointInterval, input.sparkContext)
+ // Note: this is checkpointing the unweighted validation error
val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
treeStrategy.getCheckpointInterval, input.sparkContext)
@@ -299,26 +312,29 @@ private[spark] object GradientBoostedTrees extends Logging {
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = firstTreeWeight
- var predError: RDD[(Double, Double)] =
- computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
+ var predError = computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
predErrorCheckpointer.update(predError)
- logDebug("error of gbt = " + predError.values.mean())
+ logDebug("error of gbt = " + computeWeightedError(input, predError))
// Note: A model of type regression is used since we require raw prediction
timer.stop("building tree 0")
- var validatePredError: RDD[(Double, Double)] =
+ var validatePredError =
computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
if (validate) validatePredErrorCheckpointer.update(validatePredError)
- var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
+ var bestValidateError = if (validate) {
+ computeWeightedError(validationInput, validatePredError)
+ } else {
+ 0.0
+ }
var bestM = 1
var m = 1
var doneLearning = false
while (m < numIterations && !doneLearning) {
// Update data with pseudo-residuals
- val data = predError.zip(input).map { case ((pred, _), point) =>
- LabeledPoint(-loss.gradient(pred, point.label), point.features)
+ val data = predError.zip(input).map { case ((pred, _), Instance(label, weight, features)) =>
+ Instance(-loss.gradient(pred, label), weight, features)
}
timer.start(s"building tree $m")
@@ -339,7 +355,7 @@ private[spark] object GradientBoostedTrees extends Logging {
predError = updatePredictionError(
input, predError, baseLearnerWeights(m), baseLearners(m), loss)
predErrorCheckpointer.update(predError)
- logDebug("error of gbt = " + predError.values.mean())
+ logDebug("error of gbt = " + computeWeightedError(input, predError))
if (validate) {
// Stop training early if
@@ -350,7 +366,7 @@ private[spark] object GradientBoostedTrees extends Logging {
validatePredError = updatePredictionError(
validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
validatePredErrorCheckpointer.update(validatePredError)
- val currentValidateError = validatePredError.values.mean()
+ val currentValidateError = computeWeightedError(validationInput, validatePredError)
if (bestValidateError - currentValidateError < validationTol * Math.max(
currentValidateError, 0.01)) {
doneLearning = true
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index d24d8da0dab48..d57f1b36a572c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.{LabeledPoint => NewLabeledPoint}
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.tree.impl.{GradientBoostedTrees => NewGBT}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
@@ -67,8 +67,9 @@ class GradientBoostedTrees private[spark] (
@Since("1.2.0")
def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
val algo = boostingStrategy.treeStrategy.algo
- val (trees, treeWeights) = NewGBT.run(input.map { point =>
- NewLabeledPoint(point.label, point.features.asML)
+ val (trees, treeWeights) = NewGBT.run(input.map {
+ case LabeledPoint(label, features) =>
+ Instance(label, 1.0, features.asML)
}, boostingStrategy, seed.toLong, "all")
new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
}
@@ -97,10 +98,12 @@ class GradientBoostedTrees private[spark] (
input: RDD[LabeledPoint],
validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = {
val algo = boostingStrategy.treeStrategy.algo
- val (trees, treeWeights) = NewGBT.runWithValidation(input.map { point =>
- NewLabeledPoint(point.label, point.features.asML)
- }, validationInput.map { point =>
- NewLabeledPoint(point.label, point.features.asML)
+ val (trees, treeWeights) = NewGBT.runWithValidation(input.map {
+ case LabeledPoint(label, features) =>
+ Instance(label, 1.0, features.asML)
+ }, validationInput.map {
+ case LabeledPoint(label, features) =>
+ Instance(label, 1.0, features.asML)
}, boostingStrategy, seed.toLong, "all")
new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index af3dd201d3b51..fdca71f8911c6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.ml.classification
import com.github.fommil.netlib.BLAS
import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.classification.LinearSVCSuite.generateSVMInput
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
@@ -52,8 +53,10 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
private var data: RDD[LabeledPoint] = _
private var trainData: RDD[LabeledPoint] = _
private var validationData: RDD[LabeledPoint] = _
+ private var binaryDataset: DataFrame = _
private val eps: Double = 1e-5
private val absEps: Double = 1e-8
+ private val seed = 42
override def beforeAll(): Unit = {
super.beforeAll()
@@ -65,6 +68,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
validationData =
sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2)
.map(_.asML)
+ binaryDataset = generateSVMInput(0.01, Array[Double](-1.5, 1.0), 1000, seed).toDF()
}
test("params") {
@@ -362,7 +366,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
test("Tests of feature subset strategy") {
val numClasses = 2
val gbt = new GBTClassifier()
- .setSeed(42)
+ .setSeed(seed)
.setMaxDepth(3)
.setMaxIter(5)
.setFeatureSubsetStrategy("all")
@@ -397,13 +401,15 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures, model3.numClasses)
val evalArr = model3.evaluateEachIteration(validationData.toDF)
- val remappedValidationData = validationData.map(
- x => new LabeledPoint((x.label * 2) - 1, x.features))
- val lossErr1 = GradientBoostedTrees.computeError(remappedValidationData,
+ val remappedValidationData = validationData.map {
+ case LabeledPoint(label, features) =>
+ Instance(label * 2 - 1, 1.0, features)
+ }
+ val lossErr1 = GradientBoostedTrees.computeWeightedError(remappedValidationData,
model1.trees, model1.treeWeights, model1.getOldLossType)
- val lossErr2 = GradientBoostedTrees.computeError(remappedValidationData,
+ val lossErr2 = GradientBoostedTrees.computeWeightedError(remappedValidationData,
model2.trees, model2.treeWeights, model2.getOldLossType)
- val lossErr3 = GradientBoostedTrees.computeError(remappedValidationData,
+ val lossErr3 = GradientBoostedTrees.computeWeightedError(remappedValidationData,
model3.trees, model3.treeWeights, model3.getOldLossType)
assert(evalArr(0) ~== lossErr1 relTol 1E-3)
@@ -433,16 +439,19 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
assert(modelWithValidation.numTrees < numIter)
val (errorWithoutValidation, errorWithValidation) = {
- val remappedRdd = validationData.map(x => new LabeledPoint(2 * x.label - 1, x.features))
- (GradientBoostedTrees.computeError(remappedRdd, modelWithoutValidation.trees,
+ val remappedRdd = validationData.map {
+ case LabeledPoint(label, features) =>
+ Instance(label * 2 - 1, 1.0, features)
+ }
+ (GradientBoostedTrees.computeWeightedError(remappedRdd, modelWithoutValidation.trees,
modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType),
- GradientBoostedTrees.computeError(remappedRdd, modelWithValidation.trees,
+ GradientBoostedTrees.computeWeightedError(remappedRdd, modelWithValidation.trees,
modelWithValidation.treeWeights, modelWithValidation.getOldLossType))
}
assert(errorWithValidation < errorWithoutValidation)
val evaluationArray = GradientBoostedTrees
- .evaluateEachIteration(validationData, modelWithoutValidation.trees,
+ .evaluateEachIteration(validationData.map(_.toInstance), modelWithoutValidation.trees,
modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType,
OldAlgo.Classification)
assert(evaluationArray.length === numIter)
@@ -456,6 +465,52 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
}
}
+ test("tree params") {
+ val categoricalFeatures = Map.empty[Int, Int]
+ val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
+ val gbt = new GBTClassifier()
+ .setMaxDepth(2)
+ .setCheckpointInterval(5)
+ .setSeed(123)
+ val model = gbt.fit(df)
+
+ model.trees.foreach { tree =>
+ assert(tree.getMaxDepth === model.getMaxDepth)
+ assert(tree.getCheckpointInterval === model.getCheckpointInterval)
+ assert(tree.getSeed === model.getSeed)
+ }
+ }
+
+ test("training with sample weights") {
+ val df = binaryDataset
+ val numClasses = 2
+ val predEquals = (x: Double, y: Double) => x == y
+ // (maxIter, maxDepth)
+ val testParams = Seq(
+ (5, 5),
+ (5, 10)
+ )
+
+ for ((maxIter, maxDepth) <- testParams) {
+ val estimator = new GBTClassifier()
+ .setMaxIter(maxIter)
+ .setMaxDepth(maxDepth)
+ .setSeed(seed)
+ .setMinWeightFractionPerNode(0.049)
+
+ MLTestingUtils.testArbitrarilyScaledWeights[GBTClassificationModel,
+ GBTClassifier](df.as[LabeledPoint], estimator,
+ MLTestingUtils.modelPredictionEquals(df, predEquals, 0.7))
+ MLTestingUtils.testOutliersWithSmallWeights[GBTClassificationModel,
+ GBTClassifier](df.as[LabeledPoint], estimator,
+ numClasses, MLTestingUtils.modelPredictionEquals(df, predEquals, 0.8),
+ outlierRatio = 2)
+ MLTestingUtils.testOversamplingVsWeighting[GBTClassificationModel,
+ GBTClassifier](df.as[LabeledPoint], estimator,
+ MLTestingUtils.modelPredictionEquals(df, predEquals, 0.7), seed)
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 2b5a9a396effd..d2b8751360e9e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -1425,8 +1425,6 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
}
test("multinomial logistic regression with zero variance (SPARK-21681)") {
- val sqlContext = multinomialDatasetWithZeroVar.sqlContext
- import sqlContext.implicits._
val mlr = new LogisticRegression().setFamily("multinomial").setFitIntercept(true)
.setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true).setWeightCol("weight")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index f03ed0b76eb80..5958bfcf5ea6d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -230,6 +230,26 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
}
}
+ test("tree params") {
+ val rdd = orderedLabeledPoints5_20
+ val rf = new RandomForestClassifier()
+ .setImpurity("entropy")
+ .setMaxDepth(3)
+ .setNumTrees(3)
+ .setSeed(123)
+ val categoricalFeatures = Map.empty[Int, Int]
+ val numClasses = 2
+
+ val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
+ val model = rf.fit(df)
+
+ model.trees.foreach { tree =>
+ assert(tree.getMaxDepth === model.getMaxDepth)
+ assert(tree.getSeed === model.getSeed)
+ assert(tree.getImpurity === model.getImpurity)
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 60007975c3b52..b772a3b7737d0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.lit
@@ -46,6 +47,8 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
private var data: RDD[LabeledPoint] = _
private var trainData: RDD[LabeledPoint] = _
private var validationData: RDD[LabeledPoint] = _
+ private var linearRegressionData: DataFrame = _
+ private val seed = 42
override def beforeAll(): Unit = {
super.beforeAll()
@@ -57,6 +60,9 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
validationData =
sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2)
.map(_.asML)
+ linearRegressionData = sc.parallelize(LinearDataGenerator.generateLinearInput(
+ intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3),
+ xVariance = Array(0.7, 1.2), nPoints = 1000, seed, eps = 0.5), 2).map(_.asML).toDF()
}
test("Regression with continuous features") {
@@ -202,7 +208,7 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
val gbt = new GBTRegressor()
.setMaxDepth(3)
.setMaxIter(5)
- .setSeed(42)
+ .setSeed(seed)
.setFeatureSubsetStrategy("all")
// In this data, feature 1 is very important.
@@ -237,11 +243,11 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
for (evalLossType <- GBTRegressor.supportedLossTypes) {
val evalArr = model3.evaluateEachIteration(validationData.toDF, evalLossType)
- val lossErr1 = GradientBoostedTrees.computeError(validationData,
+ val lossErr1 = GradientBoostedTrees.computeWeightedError(validationData.map(_.toInstance),
model1.trees, model1.treeWeights, model1.convertToOldLossType(evalLossType))
- val lossErr2 = GradientBoostedTrees.computeError(validationData,
+ val lossErr2 = GradientBoostedTrees.computeWeightedError(validationData.map(_.toInstance),
model2.trees, model2.treeWeights, model2.convertToOldLossType(evalLossType))
- val lossErr3 = GradientBoostedTrees.computeError(validationData,
+ val lossErr3 = GradientBoostedTrees.computeWeightedError(validationData.map(_.toInstance),
model3.trees, model3.treeWeights, model3.convertToOldLossType(evalLossType))
assert(evalArr(0) ~== lossErr1 relTol 1E-3)
@@ -272,17 +278,19 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
// early stop
assert(modelWithValidation.numTrees < numIter)
- val errorWithoutValidation = GradientBoostedTrees.computeError(validationData,
+ val errorWithoutValidation = GradientBoostedTrees.computeWeightedError(
+ validationData.map(_.toInstance),
modelWithoutValidation.trees, modelWithoutValidation.treeWeights,
modelWithoutValidation.getOldLossType)
- val errorWithValidation = GradientBoostedTrees.computeError(validationData,
+ val errorWithValidation = GradientBoostedTrees.computeWeightedError(
+ validationData.map(_.toInstance),
modelWithValidation.trees, modelWithValidation.treeWeights,
modelWithValidation.getOldLossType)
assert(errorWithValidation < errorWithoutValidation)
val evaluationArray = GradientBoostedTrees
- .evaluateEachIteration(validationData, modelWithoutValidation.trees,
+ .evaluateEachIteration(validationData.map(_.toInstance), modelWithoutValidation.trees,
modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType,
OldAlgo.Regression)
assert(evaluationArray.length === numIter)
@@ -296,7 +304,50 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
}
}
- /////////////////////////////////////////////////////////////////////////////
+ test("tree params") {
+ val gbt = new GBTRegressor()
+ .setMaxDepth(2)
+ .setCheckpointInterval(5)
+ .setSeed(123)
+ val model = gbt.fit(trainData.toDF)
+
+ model.trees.foreach { tree =>
+ assert(tree.getMaxDepth === model.getMaxDepth)
+ assert(tree.getCheckpointInterval === model.getCheckpointInterval)
+ assert(tree.getSeed === model.getSeed)
+ }
+ }
+
+ test("training with sample weights") {
+ val df = linearRegressionData
+ val numClasses = 0
+ // (maxIter, maxDepth)
+ val testParams = Seq(
+ (5, 5),
+ (5, 10)
+ )
+
+ for ((maxIter, maxDepth) <- testParams) {
+ val estimator = new GBTRegressor()
+ .setMaxIter(maxIter)
+ .setMaxDepth(maxDepth)
+ .setSeed(seed)
+ .setMinWeightFractionPerNode(0.1)
+
+ MLTestingUtils.testArbitrarilyScaledWeights[GBTRegressionModel,
+ GBTRegressor](df.as[LabeledPoint], estimator,
+ MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.95))
+ MLTestingUtils.testOutliersWithSmallWeights[GBTRegressionModel,
+ GBTRegressor](df.as[LabeledPoint], estimator, numClasses,
+ MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.95),
+ outlierRatio = 2)
+ MLTestingUtils.testOversamplingVsWeighting[GBTRegressionModel,
+ GBTRegressor](df.as[LabeledPoint], estimator,
+ MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.01, 0.95), seed)
+ }
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 0243e8d2335ee..f3b0f0470e579 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -139,6 +139,25 @@ class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{
}
}
+ test("tree params") {
+ val rf = new RandomForestRegressor()
+ .setImpurity("variance")
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setNumTrees(3)
+ .setSeed(123)
+
+ val df = orderedLabeledPoints50_1000.toDF()
+ val model = rf.fit(df)
+
+ model.trees.foreach { tree =>
+ assert(tree.getMaxDepth === model.getMaxDepth)
+ assert(tree.getSeed === model.getSeed)
+ assert(tree.getImpurity === model.getImpurity)
+ assert(tree.getMaxBins === model.getMaxBins)
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
index 366d5ec3a53fb..18fc1407557f1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.tree.impl
import org.apache.spark.SparkFunSuite
import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.mllib.tree.{GradientBoostedTreesSuite => OldGBTSuite}
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
@@ -32,15 +32,12 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
*/
class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
- import testImplicits._
-
test("runWithValidation stops early and performs better on a validation dataset") {
// Set numIterations large enough so that it stops early.
val numIterations = 20
- val trainRdd = sc.parallelize(OldGBTSuite.trainData, 2).map(_.asML)
- val validateRdd = sc.parallelize(OldGBTSuite.validateData, 2).map(_.asML)
- val trainDF = trainRdd.toDF()
- val validateDF = validateRdd.toDF()
+ val trainRdd = sc.parallelize(OldGBTSuite.trainData, 2).map(_.asML.toInstance)
+ val validateRdd = sc.parallelize(OldGBTSuite.validateData, 2).map(_.asML.toInstance)
+ val seed = 42
val algos = Array(Regression, Regression, Classification)
val losses = Array(SquaredError, AbsoluteError, LogLoss)
@@ -50,21 +47,21 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
val boostingStrategy =
new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
val (validateTrees, validateTreeWeights) = GradientBoostedTrees
- .runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L, "all")
+ .runWithValidation(trainRdd, validateRdd, boostingStrategy, seed, "all")
val numTrees = validateTrees.length
assert(numTrees !== numIterations)
// Test that it performs better on the validation dataset.
- val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L, "all")
+ val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, seed, "all")
val (errorWithoutValidation, errorWithValidation) = {
if (algo == Classification) {
- val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
- (GradientBoostedTrees.computeError(remappedRdd, trees, treeWeights, loss),
- GradientBoostedTrees.computeError(remappedRdd, validateTrees,
+ val remappedRdd = validateRdd.map(x => Instance(2 * x.label - 1, x.weight, x.features))
+ (GradientBoostedTrees.computeWeightedError(remappedRdd, trees, treeWeights, loss),
+ GradientBoostedTrees.computeWeightedError(remappedRdd, validateTrees,
validateTreeWeights, loss))
} else {
- (GradientBoostedTrees.computeError(validateRdd, trees, treeWeights, loss),
- GradientBoostedTrees.computeError(validateRdd, validateTrees,
+ (GradientBoostedTrees.computeWeightedError(validateRdd, trees, treeWeights, loss),
+ GradientBoostedTrees.computeWeightedError(validateRdd, validateTrees,
validateTreeWeights, loss))
}
}
diff --git a/pom.xml b/pom.xml
index 69b5b79b7b071..f1a7cb3d106f1 100644
--- a/pom.xml
+++ b/pom.xml
@@ -136,7 +136,7 @@
1.2.1
- 2.3.0
+ 2.3.1
10.12.1.1
1.10.1
1.5.6
@@ -148,7 +148,7 @@
0.9.3
2.4.0
2.0.8
- 3.1.5
+ 3.2.6
1.8.2
hadoop2
1.8.10
diff --git a/project/plugins.sbt b/project/plugins.sbt
index d1fe59a47217c..02525c27b6aac 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -32,6 +32,9 @@ addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.9.2")
addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0")
+// SPARK-29560 Only sbt-mima-plugin needs this repo
+resolvers += Resolver.url("bintray",
+ new java.net.URL("https://dl.bintray.com/typesafe/sbt-plugins"))(Resolver.defaultIvyPatterns)
addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.3.0")
// sbt 1.0.0 support: https://github.com/AlpineNow/junit_xml_listener/issues/6
diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py
index 7df5f6c748ad1..09d3a5e7cfb6f 100644
--- a/python/pyspark/cloudpickle.py
+++ b/python/pyspark/cloudpickle.py
@@ -44,7 +44,6 @@
import dis
from functools import partial
-import importlib
import io
import itertools
import logging
@@ -56,12 +55,26 @@
import traceback
import types
import weakref
+import uuid
+import threading
+
+
+try:
+ from enum import Enum
+except ImportError:
+ Enum = None
# cloudpickle is meant for inter process communication: we expect all
# communicating processes to run the same Python version hence we favor
# communication speed over compatibility:
DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL
+# Track the provenance of reconstructed dynamic classes to make it possible to
+# reconstruct instances from the matching singleton class definition when
+# appropriate and preserve the usual "isinstance" semantics of Python objects.
+_DYNAMIC_CLASS_TRACKER_BY_CLASS = weakref.WeakKeyDictionary()
+_DYNAMIC_CLASS_TRACKER_BY_ID = weakref.WeakValueDictionary()
+_DYNAMIC_CLASS_TRACKER_LOCK = threading.Lock()
if sys.version_info[0] < 3: # pragma: no branch
from pickle import Pickler
@@ -71,12 +84,37 @@
from StringIO import StringIO
string_types = (basestring,) # noqa
PY3 = False
+ PY2 = True
+ PY2_WRAPPER_DESCRIPTOR_TYPE = type(object.__init__)
+ PY2_METHOD_WRAPPER_TYPE = type(object.__eq__)
+ PY2_CLASS_DICT_BLACKLIST = (PY2_METHOD_WRAPPER_TYPE,
+ PY2_WRAPPER_DESCRIPTOR_TYPE)
else:
types.ClassType = type
from pickle import _Pickler as Pickler
from io import BytesIO as StringIO
string_types = (str,)
PY3 = True
+ PY2 = False
+
+
+def _ensure_tracking(class_def):
+ with _DYNAMIC_CLASS_TRACKER_LOCK:
+ class_tracker_id = _DYNAMIC_CLASS_TRACKER_BY_CLASS.get(class_def)
+ if class_tracker_id is None:
+ class_tracker_id = uuid.uuid4().hex
+ _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
+ _DYNAMIC_CLASS_TRACKER_BY_ID[class_tracker_id] = class_def
+ return class_tracker_id
+
+
+def _lookup_class_or_track(class_tracker_id, class_def):
+ if class_tracker_id is not None:
+ with _DYNAMIC_CLASS_TRACKER_LOCK:
+ class_def = _DYNAMIC_CLASS_TRACKER_BY_ID.setdefault(
+ class_tracker_id, class_def)
+ _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
+ return class_def
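# Illustrative sketch (not part of this patch): how the two helpers above are
# meant to behave. The first definition registered under a tracker id wins, so
# classes reconstructed from separate pickles collapse onto one class object
# and the usual isinstance() semantics are preserved. Assumes _ensure_tracking
# and _lookup_class_or_track are importable as defined above.

class A(object):
    pass

tracker_id = _ensure_tracking(A)           # first sighting: a fresh uuid4 hex id
assert _ensure_tracking(A) == tracker_id   # same class -> same id

class A_skeleton(object):                  # a reconstructed "copy" of the same class
    pass

# The original class already owns tracker_id, so the skeleton is discarded and
# the tracked definition is returned instead.
assert _lookup_class_or_track(tracker_id, A_skeleton) is A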
def _make_cell_set_template_code():
@@ -112,7 +150,7 @@ def inner(value):
# NOTE: we are marking the cell variable as a free variable intentionally
# so that we simulate an inner function instead of the outer function. This
# is what gives us the ``nonlocal`` behavior in a Python 2 compatible way.
- if not PY3: # pragma: no branch
+ if PY2: # pragma: no branch
return types.CodeType(
co.co_argcount,
co.co_nlocals,
@@ -130,24 +168,43 @@ def inner(value):
(),
)
else:
- return types.CodeType(
- co.co_argcount,
- co.co_kwonlyargcount,
- co.co_nlocals,
- co.co_stacksize,
- co.co_flags,
- co.co_code,
- co.co_consts,
- co.co_names,
- co.co_varnames,
- co.co_filename,
- co.co_name,
- co.co_firstlineno,
- co.co_lnotab,
- co.co_cellvars, # this is the trickery
- (),
- )
-
+ if hasattr(types.CodeType, "co_posonlyargcount"): # pragma: no branch
+ return types.CodeType(
+ co.co_argcount,
+ co.co_posonlyargcount, # Python3.8 with PEP570
+ co.co_kwonlyargcount,
+ co.co_nlocals,
+ co.co_stacksize,
+ co.co_flags,
+ co.co_code,
+ co.co_consts,
+ co.co_names,
+ co.co_varnames,
+ co.co_filename,
+ co.co_name,
+ co.co_firstlineno,
+ co.co_lnotab,
+ co.co_cellvars, # this is the trickery
+ (),
+ )
+ else:
+ return types.CodeType(
+ co.co_argcount,
+ co.co_kwonlyargcount,
+ co.co_nlocals,
+ co.co_stacksize,
+ co.co_flags,
+ co.co_code,
+ co.co_consts,
+ co.co_names,
+ co.co_varnames,
+ co.co_filename,
+ co.co_name,
+ co.co_firstlineno,
+ co.co_lnotab,
+ co.co_cellvars, # this is the trickery
+ (),
+ )
_cell_set_template_code = _make_cell_set_template_code()
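# Illustrative sketch (not part of this patch): why the co_posonlyargcount
# branch above exists. Python 3.8 (PEP 570) added a positional-only argument
# count to code objects, so CodeType takes one extra constructor argument that
# older interpreters do not accept.

import sys
import types

def _sample(a, b):
    return a + b

_code = _sample.__code__
if hasattr(types.CodeType, "co_posonlyargcount"):
    print(sys.version_info[:2], "positional-only args:", _code.co_posonlyargcount)
else:
    print(sys.version_info[:2], "CodeType has no co_posonlyargcount here")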
@@ -220,7 +277,7 @@ def _walk_global_ops(code):
global-referencing instructions in *code*.
"""
code = getattr(code, 'co_code', b'')
- if not PY3: # pragma: no branch
+ if PY2: # pragma: no branch
code = map(ord, code)
n = len(code)
@@ -250,6 +307,39 @@ def _walk_global_ops(code):
yield op, instr.arg
+def _extract_class_dict(cls):
+ """Retrieve a copy of the dict of a class without the inherited methods"""
+ clsdict = dict(cls.__dict__) # copy dict proxy to a dict
+ if len(cls.__bases__) == 1:
+ inherited_dict = cls.__bases__[0].__dict__
+ else:
+ inherited_dict = {}
+ for base in reversed(cls.__bases__):
+ inherited_dict.update(base.__dict__)
+ to_remove = []
+ for name, value in clsdict.items():
+ try:
+ base_value = inherited_dict[name]
+ if value is base_value:
+ to_remove.append(name)
+ elif PY2:
+ # backward compat for Python 2
+ if hasattr(value, "im_func"):
+ if value.im_func is getattr(base_value, "im_func", None):
+ to_remove.append(name)
+ elif isinstance(value, PY2_CLASS_DICT_BLACKLIST):
+ # On Python 2 we have no way to pickle those specific
+ # methods types nor to check that they are actually
+ # inherited. So we assume that they are always inherited
+ # from builtin types.
+ to_remove.append(name)
+ except KeyError:
+ pass
+ for name in to_remove:
+ clsdict.pop(name)
+ return clsdict
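# Illustrative sketch (not part of this patch): what _extract_class_dict
# filters out. Attributes whose value is identical to the one on the (single)
# base class are treated as inherited and dropped, so only what the class
# itself defines gets pickled.

class _Base(object):
    def ping(self):
        return "ping"

class _Child(_Base):
    ping = _Base.ping          # same function object as on the base class
    def pong(self):
        return "pong"

_d = _extract_class_dict(_Child)
assert "pong" in _d            # defined on _Child itself, kept
assert "ping" not in _d        # identical to _Base.ping, filtered out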
+
+
class CloudPickler(Pickler):
dispatch = Pickler.dispatch.copy()
@@ -277,7 +367,7 @@ def save_memoryview(self, obj):
dispatch[memoryview] = save_memoryview
- if not PY3: # pragma: no branch
+ if PY2: # pragma: no branch
def save_buffer(self, obj):
self.save(str(obj))
@@ -300,12 +390,23 @@ def save_codeobject(self, obj):
Save a code object
"""
if PY3: # pragma: no branch
- args = (
- obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
- obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, obj.co_varnames,
- obj.co_filename, obj.co_name, obj.co_firstlineno, obj.co_lnotab, obj.co_freevars,
- obj.co_cellvars
- )
+ if hasattr(obj, "co_posonlyargcount"): # pragma: no branch
+ args = (
+ obj.co_argcount, obj.co_posonlyargcount,
+ obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
+ obj.co_flags, obj.co_code, obj.co_consts, obj.co_names,
+ obj.co_varnames, obj.co_filename, obj.co_name,
+ obj.co_firstlineno, obj.co_lnotab, obj.co_freevars,
+ obj.co_cellvars
+ )
+ else:
+ args = (
+ obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals,
+ obj.co_stacksize, obj.co_flags, obj.co_code, obj.co_consts,
+ obj.co_names, obj.co_varnames, obj.co_filename,
+ obj.co_name, obj.co_firstlineno, obj.co_lnotab,
+ obj.co_freevars, obj.co_cellvars
+ )
else:
args = (
obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code,
@@ -460,15 +561,40 @@ def func():
# then discards the reference to it
self.write(pickle.POP)
- def save_dynamic_class(self, obj):
+ def _save_dynamic_enum(self, obj, clsdict):
+ """Special handling for dynamic Enum subclasses
+
+ Use a dedicated Enum constructor (inspired by EnumMeta.__call__) as the
+ EnumMeta metaclass has complex initialization that makes the Enum
+ subclasses hold references to their own instances.
"""
- Save a class that can't be stored as module global.
+ members = dict((e.name, e.value) for e in obj)
+
+ # Python 2.7 with enum34 can have no qualname:
+ qualname = getattr(obj, "__qualname__", None)
+
+ self.save_reduce(_make_skeleton_enum,
+ (obj.__bases__, obj.__name__, qualname, members,
+ obj.__module__, _ensure_tracking(obj), None),
+ obj=obj)
+
+ # Clean up the clsdict that will be passed to _rehydrate_skeleton_class:
+ # Those attributes are already handled by the metaclass.
+ for attrname in ["_generate_next_value_", "_member_names_",
+ "_member_map_", "_member_type_",
+ "_value2member_map_"]:
+ clsdict.pop(attrname, None)
+ for member in members:
+ clsdict.pop(member)
+
+ def save_dynamic_class(self, obj):
+ """Save a class that can't be stored as module global.
This method is used to serialize classes that are defined inside
functions, or that otherwise can't be serialized as attribute lookups
from global modules.
"""
- clsdict = dict(obj.__dict__) # copy dict proxy to a dict
+ clsdict = _extract_class_dict(obj)
clsdict.pop('__weakref__', None)
# For ABCMeta in python3.7+, remove _abc_impl as it is not picklable.
@@ -496,8 +622,8 @@ def save_dynamic_class(self, obj):
for k in obj.__slots__:
clsdict.pop(k, None)
- # If type overrides __dict__ as a property, include it in the type kwargs.
- # In Python 2, we can't set this attribute after construction.
+ # If type overrides __dict__ as a property, include it in the type
+ # kwargs. In Python 2, we can't set this attribute after construction.
__dict__ = clsdict.pop('__dict__', None)
if isinstance(__dict__, property):
type_kwargs['__dict__'] = __dict__
@@ -524,8 +650,16 @@ def save_dynamic_class(self, obj):
write(pickle.MARK)
# Create and memoize a skeleton class with obj's name and bases.
- tp = type(obj)
- self.save_reduce(tp, (obj.__name__, obj.__bases__, type_kwargs), obj=obj)
+ if Enum is not None and issubclass(obj, Enum):
+ # Special handling of Enum subclasses
+ self._save_dynamic_enum(obj, clsdict)
+ else:
+ # "Regular" class definition:
+ tp = type(obj)
+ self.save_reduce(_make_skeleton_class,
+ (tp, obj.__name__, obj.__bases__, type_kwargs,
+ _ensure_tracking(obj), None),
+ obj=obj)
# Now save the rest of obj's __dict__. Any references to obj
# encountered while saving will point to the skeleton class.
@@ -778,7 +912,7 @@ def save_inst(self, obj):
save(stuff)
write(pickle.BUILD)
- if not PY3: # pragma: no branch
+ if PY2: # pragma: no branch
dispatch[types.InstanceType] = save_inst
def save_property(self, obj):
@@ -1119,6 +1253,22 @@ def _make_skel_func(code, cell_count, base_globals=None):
return types.FunctionType(code, base_globals, None, None, closure)
+def _make_skeleton_class(type_constructor, name, bases, type_kwargs,
+ class_tracker_id, extra):
+ """Build dynamic class with an empty __dict__ to be filled once memoized
+
+ If class_tracker_id is not None, try to lookup an existing class definition
+ matching that id. If none is found, track a newly reconstructed class
+ definition under that id so that other instances stemming from the same
+ class id will also reuse this class definition.
+
+ The "extra" variable is meant to be a dict (or None) that can be used for
+ forward compatibility should the need arise.
+ """
+ skeleton_class = type_constructor(name, bases, type_kwargs)
+ return _lookup_class_or_track(class_tracker_id, skeleton_class)
+
+
def _rehydrate_skeleton_class(skeleton_class, class_dict):
"""Put attributes from `class_dict` back on `skeleton_class`.
@@ -1137,6 +1287,39 @@ def _rehydrate_skeleton_class(skeleton_class, class_dict):
return skeleton_class
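# Illustrative sketch (not part of this patch): a dynamic class is rebuilt in
# two steps, an empty skeleton first (so self-references can be memoized) and
# then its attributes re-attached. Assumes _make_skeleton_class and
# _rehydrate_skeleton_class are importable as defined here.

Skel = _make_skeleton_class(type, "Skel", (object,), {}, None, None)
Skel = _rehydrate_skeleton_class(
    Skel, {"answer": 42, "get_answer": lambda self: self.answer})
assert Skel().get_answer() == 42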
+def _make_skeleton_enum(bases, name, qualname, members, module,
+ class_tracker_id, extra):
+ """Build dynamic enum with an empty __dict__ to be filled once memoized
+
+ The creation of the enum class is inspired by the code of
+ EnumMeta._create_.
+
+ If class_tracker_id is not None, try to lookup an existing enum definition
+ matching that id. If none is found, track a newly reconstructed enum
+ definition under that id so that other instances stemming from the same
+ class id will also reuse this enum definition.
+
+ The "extra" variable is meant to be a dict (or None) that can be used for
+ forward compatibility should the need arise.
+ """
+ # enums always inherit from their base Enum class at the last position in
+ # the list of base classes:
+ enum_base = bases[-1]
+ metacls = enum_base.__class__
+ classdict = metacls.__prepare__(name, bases)
+
+ for member_name, member_value in members.items():
+ classdict[member_name] = member_value
+ enum_class = metacls.__new__(metacls, name, bases, classdict)
+ enum_class.__module__ = module
+
+ # Python 2.7 compat
+ if qualname is not None:
+ enum_class.__qualname__ = qualname
+
+ return _lookup_class_or_track(class_tracker_id, enum_class)
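# Illustrative sketch (not part of this patch): rebuilding an Enum the way
# _save_dynamic_enum round-trips it. A plain members mapping plus the base
# classes is enough for _make_skeleton_enum to recreate the class through
# EnumMeta, after which members behave like normal enum instances.

from enum import Enum as _Enum

_Color = _make_skeleton_enum(
    (_Enum,), "Color", "Color", {"RED": 1, "GREEN": 2},
    __name__, None, None)
assert _Color.RED.value == 1
assert isinstance(_Color.GREEN, _Color)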
+
+
def _is_dynamic(module):
"""
Return True if the module is a special module that cannot be imported by its
@@ -1176,4 +1359,4 @@ def _reduce_method_descriptor(obj):
import copy_reg as copyreg
except ImportError:
import copyreg
- copyreg.pickle(method_descriptor, _reduce_method_descriptor)
+ copyreg.pickle(method_descriptor, _reduce_method_descriptor)
\ No newline at end of file
diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index 82ff81c58d3c6..542cb25172ead 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -194,6 +194,18 @@ class UnaryTransformer(HasInputCol, HasOutputCol, Transformer):
.. versionadded:: 2.3.0
"""
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
@abstractmethod
def createTransformFunc(self):
"""
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index d0c821329471f..c5cdf35729dd8 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -177,7 +177,19 @@ class LinearSVC(JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadable
>>> df = sc.parallelize([
... Row(label=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
... Row(label=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()
- >>> svm = LinearSVC(maxIter=5, regParam=0.01)
+ >>> svm = LinearSVC()
+ >>> svm.getMaxIter()
+ 100
+ >>> svm.setMaxIter(5)
+ LinearSVC...
+ >>> svm.getMaxIter()
+ 5
+ >>> svm.getRegParam()
+ 0.0
+ >>> svm.setRegParam(0.01)
+ LinearSVC...
+ >>> svm.getRegParam()
+ 0.01
>>> model = svm.fit(df)
>>> model.setPredictionCol("newPrediction")
LinearSVC...
@@ -257,6 +269,62 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
def _create_model(self, java_model):
return LinearSVCModel(java_model)
+ @since("2.2.0")
+ def setMaxIter(self, value):
+ """
+ Sets the value of :py:attr:`maxIter`.
+ """
+ return self._set(maxIter=value)
+
+ @since("2.2.0")
+ def setRegParam(self, value):
+ """
+ Sets the value of :py:attr:`regParam`.
+ """
+ return self._set(regParam=value)
+
+ @since("2.2.0")
+ def setTol(self, value):
+ """
+ Sets the value of :py:attr:`tol`.
+ """
+ return self._set(tol=value)
+
+ @since("2.2.0")
+ def setFitIntercept(self, value):
+ """
+ Sets the value of :py:attr:`fitIntercept`.
+ """
+ return self._set(fitIntercept=value)
+
+ @since("2.2.0")
+ def setStandardization(self, value):
+ """
+ Sets the value of :py:attr:`standardization`.
+ """
+ return self._set(standardization=value)
+
+ @since("2.2.0")
+ def setThreshold(self, value):
+ """
+ Sets the value of :py:attr:`threshold`.
+ """
+ return self._set(threshold=value)
+
+ @since("2.2.0")
+ def setWeightCol(self, value):
+ """
+ Sets the value of :py:attr:`weightCol`.
+ """
+ return self._set(weightCol=value)
+
+ @since("2.2.0")
+ def setAggregationDepth(self, value):
+ """
+ Sets the value of :py:attr:`aggregationDepth`.
+ """
+ return self._set(aggregationDepth=value)
+
class LinearSVCModel(JavaClassificationModel, _LinearSVCParams, JavaMLWritable, JavaMLReadable):
"""
@@ -265,6 +333,13 @@ class LinearSVCModel(JavaClassificationModel, _LinearSVCParams, JavaMLWritable,
.. versionadded:: 2.2.0
"""
+ @since("3.0.0")
+ def setThreshold(self, value):
+ """
+ Sets the value of :py:attr:`threshold`.
+ """
+ return self._set(threshold=value)
+
@property
@since("2.2.0")
def coefficients(self):
@@ -454,7 +529,18 @@ class LogisticRegression(JavaProbabilisticClassifier, _LogisticRegressionParams,
... Row(label=0.0, weight=2.0, features=Vectors.dense(1.0, 2.0)),
... Row(label=1.0, weight=3.0, features=Vectors.dense(2.0, 1.0)),
... Row(label=0.0, weight=4.0, features=Vectors.dense(3.0, 3.0))]).toDF()
- >>> blor = LogisticRegression(regParam=0.01, weightCol="weight")
+ >>> blor = LogisticRegression(weightCol="weight")
+ >>> blor.getRegParam()
+ 0.0
+ >>> blor.setRegParam(0.01)
+ LogisticRegression...
+ >>> blor.getRegParam()
+ 0.01
+ >>> blor.setMaxIter(10)
+ LogisticRegression...
+ >>> blor.getMaxIter()
+ 10
+ >>> blor.clear(blor.maxIter)
>>> blorModel = blor.fit(bdf)
>>> blorModel.setFeaturesCol("features")
LogisticRegressionModel...
@@ -603,6 +689,54 @@ def setUpperBoundsOnIntercepts(self, value):
"""
return self._set(upperBoundsOnIntercepts=value)
+ def setMaxIter(self, value):
+ """
+ Sets the value of :py:attr:`maxIter`.
+ """
+ return self._set(maxIter=value)
+
+ def setRegParam(self, value):
+ """
+ Sets the value of :py:attr:`regParam`.
+ """
+ return self._set(regParam=value)
+
+ def setTol(self, value):
+ """
+ Sets the value of :py:attr:`tol`.
+ """
+ return self._set(tol=value)
+
+ def setElasticNetParam(self, value):
+ """
+ Sets the value of :py:attr:`elasticNetParam`.
+ """
+ return self._set(elasticNetParam=value)
+
+ def setFitIntercept(self, value):
+ """
+ Sets the value of :py:attr:`fitIntercept`.
+ """
+ return self._set(fitIntercept=value)
+
+ def setStandardization(self, value):
+ """
+ Sets the value of :py:attr:`standardization`.
+ """
+ return self._set(standardization=value)
+
+ def setWeightCol(self, value):
+ """
+ Sets the value of :py:attr:`weightCol`.
+ """
+ return self._set(weightCol=value)
+
+ def setAggregationDepth(self, value):
+ """
+ Sets the value of :py:attr:`aggregationDepth`.
+ """
+ return self._set(aggregationDepth=value)
+
class LogisticRegressionModel(JavaProbabilisticClassificationModel, _LogisticRegressionParams,
JavaMLWritable, JavaMLReadable, HasTrainingSummary):
@@ -1148,6 +1282,27 @@ def setImpurity(self, value):
"""
return self._set(impurity=value)
+ @since("1.4.0")
+ def setCheckpointInterval(self, value):
+ """
+ Sets the value of :py:attr:`checkpointInterval`.
+ """
+ return self._set(checkpointInterval=value)
+
+ @since("1.4.0")
+ def setSeed(self, value):
+ """
+ Sets the value of :py:attr:`seed`.
+ """
+ return self._set(seed=value)
+
+ @since("3.0.0")
+ def setWeightCol(self, value):
+ """
+ Sets the value of :py:attr:`weightCol`.
+ """
+ return self._set(weightCol=value)
+
@inherit_doc
class DecisionTreeClassificationModel(_DecisionTreeModel, JavaProbabilisticClassificationModel,
@@ -1366,6 +1521,18 @@ def setFeatureSubsetStrategy(self, value):
"""
return self._set(featureSubsetStrategy=value)
+ def setSeed(self, value):
+ """
+ Sets the value of :py:attr:`seed`.
+ """
+ return self._set(seed=value)
+
+ def setCheckpointInterval(self, value):
+ """
+ Sets the value of :py:attr:`checkpointInterval`.
+ """
+ return self._set(checkpointInterval=value)
+
class RandomForestClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel,
_RandomForestClassifierParams, JavaMLWritable,
@@ -1451,6 +1618,10 @@ class GBTClassifier(JavaProbabilisticClassifier, _GBTClassifierParams,
>>> td = si_model.transform(df)
>>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42,
... leafCol="leafId")
+ >>> gbt.setMaxIter(5)
+ GBTClassifier...
+ >>> gbt.getMaxIter()
+ 5
>>> gbt.getFeatureSubsetStrategy()
'all'
>>> model = gbt.fit(td)
@@ -1630,6 +1801,34 @@ def setValidationIndicatorCol(self, value):
"""
return self._set(validationIndicatorCol=value)
+ @since("1.4.0")
+ def setMaxIter(self, value):
+ """
+ Sets the value of :py:attr:`maxIter`.
+ """
+ return self._set(maxIter=value)
+
+ @since("1.4.0")
+ def setCheckpointInterval(self, value):
+ """
+ Sets the value of :py:attr:`checkpointInterval`.
+ """
+ return self._set(checkpointInterval=value)
+
+ @since("1.4.0")
+ def setSeed(self, value):
+ """
+ Sets the value of :py:attr:`seed`.
+ """
+ return self._set(seed=value)
+
+ @since("1.4.0")
+ def setStepSize(self, value):
+ """
+ Sets the value of :py:attr:`stepSize`.
+ """
+ return self._set(stepSize=value)
+
class GBTClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel,
_GBTClassifierParams, JavaMLWritable, JavaMLReadable):
@@ -1723,10 +1922,6 @@ class NaiveBayes(JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds,
>>> model = nb.fit(df)
>>> model.setFeaturesCol("features")
NaiveBayes_...
- >>> model.setLabelCol("newLabel")
- NaiveBayes_...
- >>> model.getLabelCol()
- 'newLabel'
>>> model.getSmoothing()
1.0
>>> model.pi
@@ -1814,6 +2009,12 @@ def setModelType(self, value):
"""
return self._set(modelType=value)
+ def setWeightCol(self, value):
+ """
+ Sets the value of :py:attr:`weightCol`.
+ """
+ return self._set(weightCol=value)
+
class NaiveBayesModel(JavaProbabilisticClassificationModel, _NaiveBayesParams, JavaMLWritable,
JavaMLReadable):
@@ -1906,7 +2107,11 @@ class MultilayerPerceptronClassifier(JavaProbabilisticClassifier, _MultilayerPer
... (1.0, Vectors.dense([0.0, 1.0])),
... (1.0, Vectors.dense([1.0, 0.0])),
... (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"])
- >>> mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[2, 2, 2], blockSize=1, seed=123)
+ >>> mlp = MultilayerPerceptronClassifier(layers=[2, 2, 2], blockSize=1, seed=123)
+ >>> mlp.setMaxIter(100)
+ MultilayerPerceptronClassifier...
+ >>> mlp.getMaxIter()
+ 100
>>> model = mlp.fit(df)
>>> model.setFeaturesCol("features")
MultilayerPerceptronClassifier...
@@ -2000,6 +2205,31 @@ def setBlockSize(self, value):
"""
return self._set(blockSize=value)
+ @since("2.0.0")
+ def setInitialWeights(self, value):
+ """
+ Sets the value of :py:attr:`initialWeights`.
+ """
+ return self._set(initialWeights=value)
+
+ def setMaxIter(self, value):
+ """
+ Sets the value of :py:attr:`maxIter`.
+ """
+ return self._set(maxIter=value)
+
+ def setSeed(self, value):
+ """
+ Sets the value of :py:attr:`seed`.
+ """
+ return self._set(seed=value)
+
+ def setTol(self, value):
+ """
+ Sets the value of :py:attr:`tol`.
+ """
+ return self._set(tol=value)
+
@since("2.0.0")
def setStepSize(self, value):
"""
@@ -2007,12 +2237,11 @@ def setStepSize(self, value):
"""
return self._set(stepSize=value)
- @since("2.0.0")
- def setInitialWeights(self, value):
+ def setSolver(self, value):
"""
- Sets the value of :py:attr:`initialWeights`.
+ Sets the value of :py:attr:`solver`.
"""
- return self._set(initialWeights=value)
+ return self._set(solver=value)
class MultilayerPerceptronClassificationModel(JavaProbabilisticClassificationModel, JavaMLWritable,
@@ -2134,6 +2363,42 @@ def setClassifier(self, value):
"""
return self._set(classifier=value)
+ def setLabelCol(self, value):
+ """
+ Sets the value of :py:attr:`labelCol`.
+ """
+ return self._set(labelCol=value)
+
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
+ def setRawPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`rawPredictionCol`.
+ """
+ return self._set(rawPredictionCol=value)
+
+ def setWeightCol(self, value):
+ """
+ Sets the value of :py:attr:`weightCol`.
+ """
+ return self._set(weightCol=value)
+
+ def setParallelism(self, value):
+ """
+ Sets the value of :py:attr:`parallelism`.
+ """
+ return self._set(parallelism=value)
+
def _fit(self, dataset):
labelCol = self.getLabelCol()
featuresCol = self.getFeaturesCol()
@@ -2287,6 +2552,43 @@ class OneVsRestModel(Model, _OneVsRestParams, JavaMLReadable, JavaMLWritable):
.. versionadded:: 2.0.0
"""
+ @since("2.0.0")
+ def setClassifier(self, value):
+ """
+ Sets the value of :py:attr:`classifier`.
+ """
+ return self._set(classifier=value)
+
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
+ def setRawPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`rawPredictionCol`.
+ """
+ return self._set(rawPredictionCol=value)
+
+ def setLabelCol(self, value):
+ """
+ Sets the value of :py:attr:`labelCol`.
+ """
+ return self._set(labelCol=value)
+
+ def setWeightCol(self, value):
+ """
+ Sets the value of :py:attr:`weightCol`.
+ """
+ return self._set(weightCol=value)
+
def __init__(self, models):
super(OneVsRestModel, self).__init__()
self.models = models
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index cbbbd36955dc0..bb73dc78c4ab4 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -123,6 +123,27 @@ class GaussianMixtureModel(JavaModel, _GaussianMixtureParams, JavaMLWritable, Ja
.. versionadded:: 2.0.0
"""
+ @since("3.0.0")
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
+ @since("3.0.0")
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
+ @since("3.0.0")
+ def setProbabilityCol(self, value):
+ """
+ Sets the value of :py:attr:`probabilityCol`.
+ """
+ return self._set(probabilityCol=value)
+
@property
@since("2.0.0")
def weights(self):
@@ -200,8 +221,13 @@ class GaussianMixture(JavaEstimator, _GaussianMixtureParams, JavaMLWritable, Jav
... (Vectors.dense([-0.83, -0.68]),),
... (Vectors.dense([-0.91, -0.76]),)]
>>> df = spark.createDataFrame(data, ["features"])
- >>> gm = GaussianMixture(k=3, tol=0.0001,
- ... maxIter=10, seed=10)
+ >>> gm = GaussianMixture(k=3, tol=0.0001, seed=10)
+ >>> gm.getMaxIter()
+ 100
+ >>> gm.setMaxIter(10)
+ GaussianMixture...
+ >>> gm.getMaxIter()
+ 10
>>> model = gm.fit(df)
>>> model.getFeaturesCol()
'features'
@@ -290,6 +316,48 @@ def setK(self, value):
"""
return self._set(k=value)
+ @since("2.0.0")
+ def setMaxIter(self, value):
+ """
+ Sets the value of :py:attr:`maxIter`.
+ """
+ return self._set(maxIter=value)
+
+ @since("2.0.0")
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
+ @since("2.0.0")
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
+ @since("2.0.0")
+ def setProbabilityCol(self, value):
+ """
+ Sets the value of :py:attr:`probabilityCol`.
+ """
+ return self._set(probabilityCol=value)
+
+ @since("2.0.0")
+ def setSeed(self, value):
+ """
+ Sets the value of :py:attr:`seed`.
+ """
+ return self._set(seed=value)
+
+ @since("2.0.0")
+ def setTol(self, value):
+ """
+ Sets the value of :py:attr:`tol`.
+ """
+ return self._set(tol=value)
+
class GaussianMixtureSummary(ClusteringSummary):
"""
@@ -389,6 +457,20 @@ class KMeansModel(JavaModel, _KMeansParams, GeneralJavaMLWritable, JavaMLReadabl
.. versionadded:: 1.5.0
"""
+ @since("3.0.0")
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
+ @since("3.0.0")
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
@since("1.5.0")
def clusterCenters(self):
"""Get the cluster centers, represented as a list of NumPy arrays."""
@@ -425,7 +507,14 @@ class KMeans(JavaEstimator, _KMeansParams, JavaMLWritable, JavaMLReadable):
>>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
... (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
>>> df = spark.createDataFrame(data, ["features"])
- >>> kmeans = KMeans(k=2, seed=1)
+ >>> kmeans = KMeans(k=2)
+ >>> kmeans.setSeed(1)
+ KMeans...
+ >>> kmeans.setMaxIter(10)
+ KMeans...
+ >>> kmeans.getMaxIter()
+ 10
+ >>> kmeans.clear(kmeans.maxIter)
>>> model = kmeans.fit(df)
>>> model.getDistanceMeasure()
'euclidean'
@@ -531,6 +620,41 @@ def setDistanceMeasure(self, value):
"""
return self._set(distanceMeasure=value)
+ @since("1.5.0")
+ def setMaxIter(self, value):
+ """
+ Sets the value of :py:attr:`maxIter`.
+ """
+ return self._set(maxIter=value)
+
+ @since("1.5.0")
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
+ @since("1.5.0")
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
+ @since("1.5.0")
+ def setSeed(self, value):
+ """
+ Sets the value of :py:attr:`seed`.
+ """
+ return self._set(seed=value)
+
+ @since("1.5.0")
+ def setTol(self, value):
+ """
+ Sets the value of :py:attr:`tol`.
+ """
+ return self._set(tol=value)
+
@inherit_doc
class _BisectingKMeansParams(HasMaxIter, HasFeaturesCol, HasSeed, HasPredictionCol,
@@ -571,6 +695,20 @@ class BisectingKMeansModel(JavaModel, _BisectingKMeansParams, JavaMLWritable, Ja
.. versionadded:: 2.0.0
"""
+ @since("3.0.0")
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
+ @since("3.0.0")
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
@since("2.0.0")
def clusterCenters(self):
"""Get the cluster centers, represented as a list of NumPy arrays."""
@@ -629,6 +767,16 @@ class BisectingKMeans(JavaEstimator, _BisectingKMeansParams, JavaMLWritable, Jav
... (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
>>> df = spark.createDataFrame(data, ["features"])
>>> bkm = BisectingKMeans(k=2, minDivisibleClusterSize=1.0)
+ >>> bkm.setMaxIter(10)
+ BisectingKMeans...
+ >>> bkm.getMaxIter()
+ 10
+ >>> bkm.clear(bkm.maxIter)
+ >>> bkm.setSeed(1)
+ BisectingKMeans...
+ >>> bkm.getSeed()
+ 1
+ >>> bkm.clear(bkm.seed)
>>> model = bkm.fit(df)
>>> model.getMaxIter()
20
@@ -723,6 +871,34 @@ def setDistanceMeasure(self, value):
"""
return self._set(distanceMeasure=value)
+ @since("2.0.0")
+ def setMaxIter(self, value):
+ """
+ Sets the value of :py:attr:`maxIter`.
+ """
+ return self._set(maxIter=value)
+
+ @since("2.0.0")
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
+ @since("2.0.0")
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
+ @since("2.0.0")
+ def setSeed(self, value):
+ """
+ Sets the value of :py:attr:`seed`.
+ """
+ return self._set(seed=value)
+
def _create_model(self, java_model):
return BisectingKMeansModel(java_model)
@@ -873,6 +1049,31 @@ class LDAModel(JavaModel, _LDAParams):
.. versionadded:: 2.0.0
"""
+ @since("3.0.0")
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
+ @since("3.0.0")
+ def setSeed(self, value):
+ """
+ Sets the value of :py:attr:`seed`.
+ """
+ return self._set(seed=value)
+
+ @since("3.0.0")
+ def setTopicDistributionCol(self, value):
+ """
+ Sets the value of :py:attr:`topicDistributionCol`.
+
+ >>> algo = LDA().setTopicDistributionCol("topicDistributionCol")
+ >>> algo.getTopicDistributionCol()
+ 'topicDistributionCol'
+ """
+ return self._set(topicDistributionCol=value)
+
@since("2.0.0")
def isDistributed(self):
"""
@@ -1045,6 +1246,11 @@ class LDA(JavaEstimator, _LDAParams, JavaMLReadable, JavaMLWritable):
>>> df = spark.createDataFrame([[1, Vectors.dense([0.0, 1.0])],
... [2, SparseVector(2, {0: 1.0})],], ["id", "features"])
>>> lda = LDA(k=2, seed=1, optimizer="em")
+ >>> lda.setMaxIter(10)
+ LDA...
+ >>> lda.getMaxIter()
+ 10
+ >>> lda.clear(lda.maxIter)
>>> model = lda.fit(df)
>>> model.getTopicDistributionCol()
'topicDistribution'
@@ -1125,6 +1331,20 @@ def setParams(self, featuresCol="features", maxIter=20, seed=None, checkpointInt
kwargs = self._input_kwargs
return self._set(**kwargs)
+ @since("2.0.0")
+ def setCheckpointInterval(self, value):
+ """
+ Sets the value of :py:attr:`checkpointInterval`.
+ """
+ return self._set(checkpointInterval=value)
+
+ @since("2.0.0")
+ def setSeed(self, value):
+ """
+ Sets the value of :py:attr:`seed`.
+ """
+ return self._set(seed=value)
+
@since("2.0.0")
def setK(self, value):
"""
@@ -1236,6 +1456,20 @@ def setKeepLastCheckpoint(self, value):
"""
return self._set(keepLastCheckpoint=value)
+ @since("2.0.0")
+ def setMaxIter(self, value):
+ """
+ Sets the value of :py:attr:`maxIter`.
+ """
+ return self._set(maxIter=value)
+
+ @since("2.0.0")
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
@inherit_doc
class _PowerIterationClusteringParams(HasMaxIter, HasWeightCol):
@@ -1392,6 +1626,20 @@ def setDstCol(self, value):
"""
return self._set(dstCol=value)
+ @since("2.4.0")
+ def setMaxIter(self, value):
+ """
+ Sets the value of :py:attr:`maxIter`.
+ """
+ return self._set(maxIter=value)
+
+ @since("2.4.0")
+ def setWeightCol(self, value):
+ """
+ Sets the value of :py:attr:`weightCol`.
+ """
+ return self._set(weightCol=value)
+
@since("2.4.0")
def assignClusters(self, dataset):
"""
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index cdd9be7bf11b3..6539e2abaed12 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -119,7 +119,9 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
... [(0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)])
>>> dataset = spark.createDataFrame(scoreAndLabels, ["raw", "label"])
...
- >>> evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw")
+ >>> evaluator = BinaryClassificationEvaluator()
+ >>> evaluator.setRawPredictionCol("raw")
+ BinaryClassificationEvaluator...
>>> evaluator.evaluate(dataset)
0.70...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"})
@@ -196,6 +198,25 @@ def getNumBins(self):
"""
return self.getOrDefault(self.numBins)
+ def setLabelCol(self, value):
+ """
+ Sets the value of :py:attr:`labelCol`.
+ """
+ return self._set(labelCol=value)
+
+ def setRawPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`rawPredictionCol`.
+ """
+ return self._set(rawPredictionCol=value)
+
+ @since("3.0.0")
+ def setWeightCol(self, value):
+ """
+ Sets the value of :py:attr:`weightCol`.
+ """
+ return self._set(weightCol=value)
+
@keyword_only
@since("1.4.0")
def setParams(self, rawPredictionCol="rawPrediction", labelCol="label",
@@ -220,7 +241,9 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeigh
... (-25.98418959, -22.0), (30.69731842, 33.0), (74.69283752, 71.0)]
>>> dataset = spark.createDataFrame(scoreAndLabels, ["raw", "label"])
...
- >>> evaluator = RegressionEvaluator(predictionCol="raw")
+ >>> evaluator = RegressionEvaluator()
+ >>> evaluator.setPredictionCol("raw")
+ RegressionEvaluator...
>>> evaluator.evaluate(dataset)
2.842...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"})
@@ -299,6 +322,25 @@ def getThroughOrigin(self):
"""
return self.getOrDefault(self.throughOrigin)
+ def setLabelCol(self, value):
+ """
+ Sets the value of :py:attr:`labelCol`.
+ """
+ return self._set(labelCol=value)
+
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
+ @since("3.0.0")
+ def setWeightCol(self, value):
+ """
+ Sets the value of :py:attr:`weightCol`.
+ """
+ return self._set(weightCol=value)
+
@keyword_only
@since("1.4.0")
def setParams(self, predictionCol="prediction", labelCol="label",
@@ -322,7 +364,9 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
>>> scoreAndLabels = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]
>>> dataset = spark.createDataFrame(scoreAndLabels, ["prediction", "label"])
- >>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction")
+ >>> evaluator = MulticlassClassificationEvaluator()
+ >>> evaluator.setPredictionCol("prediction")
+ MulticlassClassificationEvaluator...
>>> evaluator.evaluate(dataset)
0.66...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
@@ -453,6 +497,32 @@ def getEps(self):
"""
return self.getOrDefault(self.eps)
+ def setLabelCol(self, value):
+ """
+ Sets the value of :py:attr:`labelCol`.
+ """
+ return self._set(labelCol=value)
+
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
+ @since("3.0.0")
+ def setProbabilityCol(self, value):
+ """
+ Sets the value of :py:attr:`probabilityCol`.
+ """
+ return self._set(probabilityCol=value)
+
+ @since("3.0.0")
+ def setWeightCol(self, value):
+ """
+ Sets the value of :py:attr:`weightCol`.
+ """
+ return self._set(weightCol=value)
+
@keyword_only
@since("1.5.0")
def setParams(self, predictionCol="prediction", labelCol="label",
@@ -482,7 +552,9 @@ class MultilabelClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
... ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])]
>>> dataset = spark.createDataFrame(scoreAndLabels, ["prediction", "label"])
...
- >>> evaluator = MultilabelClassificationEvaluator(predictionCol="prediction")
+ >>> evaluator = MultilabelClassificationEvaluator()
+ >>> evaluator.setPredictionCol("prediction")
+ MultilabelClassificationEvaluator...
>>> evaluator.evaluate(dataset)
0.63...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
@@ -549,6 +621,20 @@ def getMetricLabel(self):
"""
return self.getOrDefault(self.metricLabel)
+ @since("3.0.0")
+ def setLabelCol(self, value):
+ """
+ Sets the value of :py:attr:`labelCol`.
+ """
+ return self._set(labelCol=value)
+
+ @since("3.0.0")
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
@keyword_only
@since("3.0.0")
def setParams(self, predictionCol="prediction", labelCol="label",
@@ -581,7 +667,9 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol,
... ([10.5, 11.5], 1.0), ([1.0, 1.0], 0.0), ([8.0, 6.0], 1.0)])
>>> dataset = spark.createDataFrame(featureAndPredictions, ["features", "prediction"])
...
- >>> evaluator = ClusteringEvaluator(predictionCol="prediction")
+ >>> evaluator = ClusteringEvaluator()
+ >>> evaluator.setPredictionCol("prediction")
+ ClusteringEvaluator...
>>> evaluator.evaluate(dataset)
0.9079...
>>> ce_path = temp_path + "/ce"
@@ -613,6 +701,18 @@ def __init__(self, predictionCol="prediction", featuresCol="features",
kwargs = self._input_kwargs
self._set(**kwargs)
+ @keyword_only
+ @since("2.3.0")
+ def setParams(self, predictionCol="prediction", featuresCol="features",
+ metricName="silhouette", distanceMeasure="squaredEuclidean"):
+ """
+ setParams(self, predictionCol="prediction", featuresCol="features", \
+ metricName="silhouette", distanceMeasure="squaredEuclidean")
+ Sets params for clustering evaluator.
+ """
+ kwargs = self._input_kwargs
+ return self._set(**kwargs)
+
@since("2.3.0")
def setMetricName(self, value):
"""
@@ -627,18 +727,6 @@ def getMetricName(self):
"""
return self.getOrDefault(self.metricName)
- @keyword_only
- @since("2.3.0")
- def setParams(self, predictionCol="prediction", featuresCol="features",
- metricName="silhouette", distanceMeasure="squaredEuclidean"):
- """
- setParams(self, predictionCol="prediction", featuresCol="features", \
- metricName="silhouette", distanceMeasure="squaredEuclidean")
- Sets params for clustering evaluator.
- """
- kwargs = self._input_kwargs
- return self._set(**kwargs)
-
@since("2.4.0")
def setDistanceMeasure(self, value):
"""
@@ -653,6 +741,18 @@ def getDistanceMeasure(self):
"""
return self.getOrDefault(self.distanceMeasure)
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
@inherit_doc
class RankingEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
@@ -669,7 +769,9 @@ class RankingEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
... ([1.0, 2.0, 3.0, 4.0, 5.0], [])]
>>> dataset = spark.createDataFrame(scoreAndLabels, ["prediction", "label"])
...
- >>> evaluator = RankingEvaluator(predictionCol="prediction")
+ >>> evaluator = RankingEvaluator()
+ >>> evaluator.setPredictionCol("prediction")
+ RankingEvaluator...
>>> evaluator.evaluate(dataset)
0.35...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "precisionAtK", evaluator.k: 2})
@@ -734,6 +836,20 @@ def getK(self):
"""
return self.getOrDefault(self.k)
+ @since("3.0.0")
+ def setLabelCol(self, value):
+ """
+ Sets the value of :py:attr:`labelCol`.
+ """
+ return self._set(labelCol=value)
+
+ @since("3.0.0")
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
@keyword_only
@since("3.0.0")
def setParams(self, predictionCol="prediction", labelCol="label",
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index a0883f1d54fed..11bb7941b5d9a 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -76,6 +76,12 @@ class Binarizer(JavaTransformer, HasThreshold, HasThresholds, HasInputCol, HasOu
>>> df = spark.createDataFrame([(0.5,)], ["values"])
>>> binarizer = Binarizer(threshold=1.0, inputCol="values", outputCol="features")
+ >>> binarizer.setThreshold(1.0)
+ Binarizer...
+ >>> binarizer.setInputCol("values")
+ Binarizer...
+ >>> binarizer.setOutputCol("features")
+ Binarizer...
>>> binarizer.transform(df).head().features
0.0
>>> binarizer.setParams(outputCol="freqs").transform(df).head().freqs
@@ -154,6 +160,32 @@ def setThresholds(self, value):
"""
return self._set(thresholds=value)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ @since("3.0.0")
+ def setInputCols(self, value):
+ """
+ Sets the value of :py:attr:`inputCols`.
+ """
+ return self._set(inputCols=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
+ @since("3.0.0")
+ def setOutputCols(self, value):
+ """
+ Sets the value of :py:attr:`outputCols`.
+ """
+ return self._set(outputCols=value)
+
class _LSHParams(HasInputCol, HasOutputCol):
"""
@@ -183,12 +215,36 @@ def setNumHashTables(self, value):
"""
return self._set(numHashTables=value)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
class _LSHModel(JavaModel, _LSHParams):
"""
Mixin for Locality Sensitive Hashing (LSH) models.
"""
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
def approxNearestNeighbors(self, dataset, key, numNearestNeighbors, distCol="distCol"):
"""
Given a large dataset and an item, approximately find at most k items which have the
@@ -269,8 +325,15 @@ class BucketedRandomProjectionLSH(_LSH, _BucketedRandomProjectionLSHParams,
... (2, Vectors.dense([1.0, -1.0 ]),),
... (3, Vectors.dense([1.0, 1.0]),)]
>>> df = spark.createDataFrame(data, ["id", "features"])
- >>> brp = BucketedRandomProjectionLSH(inputCol="features", outputCol="hashes",
- ... seed=12345, bucketLength=1.0)
+ >>> brp = BucketedRandomProjectionLSH()
+ >>> brp.setInputCol("features")
+ BucketedRandomProjectionLSH...
+ >>> brp.setOutputCol("hashes")
+ BucketedRandomProjectionLSH...
+ >>> brp.setSeed(12345)
+ BucketedRandomProjectionLSH...
+ >>> brp.setBucketLength(1.0)
+ BucketedRandomProjectionLSH...
>>> model = brp.fit(df)
>>> model.getBucketLength()
1.0
@@ -350,6 +413,24 @@ def setBucketLength(self, value):
"""
return self._set(bucketLength=value)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
+ def setSeed(self, value):
+ """
+ Sets the value of :py:attr:`seed`.
+ """
+ return self._set(seed=value)
+
def _create_model(self, java_model):
return BucketedRandomProjectionLSHModel(java_model)
@@ -366,6 +447,20 @@ class BucketedRandomProjectionLSHModel(_LSHModel, _BucketedRandomProjectionLSHPa
.. versionadded:: 2.2.0
"""
+ @since("3.0.0")
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ @since("3.0.0")
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
@inherit_doc
class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols,
@@ -380,8 +475,13 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOu
>>> values = [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, float("nan")),
... (float("nan"), 1.0), (float("nan"), 0.0)]
>>> df = spark.createDataFrame(values, ["values1", "values2"])
- >>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")],
- ... inputCol="values1", outputCol="buckets")
+ >>> bucketizer = Bucketizer()
+ >>> bucketizer.setSplits([-float("inf"), 0.5, 1.4, float("inf")])
+ Bucketizer...
+ >>> bucketizer.setInputCol("values1")
+ Bucketizer...
+ >>> bucketizer.setOutputCol("buckets")
+ Bucketizer...
>>> bucketed = bucketizer.setHandleInvalid("keep").transform(df).collect()
>>> bucketed = bucketizer.setHandleInvalid("keep").transform(df.select("values1"))
>>> bucketed.show(truncate=False)
@@ -510,6 +610,38 @@ def getSplitsArray(self):
"""
return self.getOrDefault(self.splitsArray)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ @since("3.0.0")
+ def setInputCols(self, value):
+ """
+ Sets the value of :py:attr:`inputCols`.
+ """
+ return self._set(inputCols=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
+ @since("3.0.0")
+ def setOutputCols(self, value):
+ """
+ Sets the value of :py:attr:`outputCols`.
+ """
+ return self._set(outputCols=value)
+
+ def setHandleInvalid(self, value):
+ """
+ Sets the value of :py:attr:`handleInvalid`.
+ """
+ return self._set(handleInvalid=value)
+
class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol):
"""
@@ -595,7 +727,11 @@ class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, Jav
>>> df = spark.createDataFrame(
... [(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])],
... ["label", "raw"])
- >>> cv = CountVectorizer(inputCol="raw", outputCol="vectors")
+ >>> cv = CountVectorizer()
+ >>> cv.setInputCol("raw")
+ CountVectorizer...
+ >>> cv.setOutputCol("vectors")
+ CountVectorizer...
>>> model = cv.fit(df)
>>> model.transform(df).show(truncate=False)
+-----+---------------+-------------------------+
@@ -695,6 +831,18 @@ def setBinary(self, value):
"""
return self._set(binary=value)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
def _create_model(self, java_model):
return CountVectorizerModel(java_model)
@@ -707,6 +855,34 @@ class CountVectorizerModel(JavaModel, _CountVectorizerParams, JavaMLReadable, Ja
.. versionadded:: 1.6.0
"""
+ @since("3.0.0")
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ @since("3.0.0")
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
+ @since("3.0.0")
+ def setMinTF(self, value):
+ """
+ Sets the value of :py:attr:`minTF`.
+ """
+ return self._set(minTF=value)
+
+ @since("3.0.0")
+ def setBinary(self, value):
+ """
+ Sets the value of :py:attr:`binary`.
+ """
+ return self._set(binary=value)
+
@classmethod
@since("2.4.0")
def from_vocabulary(cls, vocabulary, inputCol, outputCol=None, minTF=None, binary=None):
@@ -766,7 +942,13 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWrit
>>> from pyspark.ml.linalg import Vectors
>>> df1 = spark.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], ["vec"])
- >>> dct = DCT(inverse=False, inputCol="vec", outputCol="resultVec")
+ >>> dct = DCT()
+ >>> dct.setInverse(False)
+ DCT...
+ >>> dct.setInputCol("vec")
+ DCT...
+ >>> dct.setOutputCol("resultVec")
+ DCT...
>>> df2 = dct.transform(df1)
>>> df2.head().resultVec
DenseVector([10.969..., -0.707..., -2.041...])
@@ -820,6 +1002,18 @@ def getInverse(self):
"""
return self.getOrDefault(self.inverse)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
@inherit_doc
class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
@@ -831,8 +1025,13 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReada
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(Vectors.dense([2.0, 1.0, 3.0]),)], ["values"])
- >>> ep = ElementwiseProduct(scalingVec=Vectors.dense([1.0, 2.0, 3.0]),
- ... inputCol="values", outputCol="eprod")
+ >>> ep = ElementwiseProduct()
+ >>> ep.setScalingVec(Vectors.dense([1.0, 2.0, 3.0]))
+ ElementwiseProduct...
+ >>> ep.setInputCol("values")
+ ElementwiseProduct...
+ >>> ep.setOutputCol("eprod")
+ ElementwiseProduct...
>>> ep.transform(df).head().eprod
DenseVector([2.0, 2.0, 9.0])
>>> ep.setParams(scalingVec=Vectors.dense([2.0, 3.0, 5.0])).transform(df).head().eprod
@@ -884,6 +1083,18 @@ def getScalingVec(self):
"""
return self.getOrDefault(self.scalingVec)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
@inherit_doc
class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, JavaMLReadable,
@@ -923,7 +1134,11 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures,
>>> data = [(2.0, True, "1", "foo"), (3.0, False, "2", "bar")]
>>> cols = ["real", "bool", "stringNum", "string"]
>>> df = spark.createDataFrame(data, cols)
- >>> hasher = FeatureHasher(inputCols=cols, outputCol="features")
+ >>> hasher = FeatureHasher()
+ >>> hasher.setInputCols(cols)
+ FeatureHasher...
+ >>> hasher.setOutputCol("features")
+ FeatureHasher...
>>> hasher.transform(df).head().features
SparseVector(262144, {174475: 2.0, 247670: 1.0, 257907: 1.0, 262126: 1.0})
>>> hasher.setCategoricalCols(["real"]).transform(df).head().features
@@ -978,6 +1193,24 @@ def getCategoricalCols(self):
"""
return self.getOrDefault(self.categoricalCols)
+ def setInputCols(self, value):
+ """
+ Sets the value of :py:attr:`inputCols`.
+ """
+ return self._set(inputCols=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
+ def setNumFeatures(self, value):
+ """
+ Sets the value of :py:attr:`numFeatures`.
+ """
+ return self._set(numFeatures=value)
+
@inherit_doc
class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, JavaMLReadable,
@@ -991,7 +1224,9 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, Java
otherwise the features will not be mapped evenly to the columns.
>>> df = spark.createDataFrame([(["a", "b", "c"],)], ["words"])
- >>> hashingTF = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
+ >>> hashingTF = HashingTF(inputCol="words", outputCol="features")
+ >>> hashingTF.setNumFeatures(10)
+ HashingTF...
>>> hashingTF.transform(df).head().features
SparseVector(10, {5: 1.0, 7: 1.0, 8: 1.0})
>>> hashingTF.setParams(outputCol="freqs").transform(df).head().freqs
@@ -1050,6 +1285,24 @@ def getBinary(self):
"""
return self.getOrDefault(self.binary)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
+ def setNumFeatures(self, value):
+ """
+ Sets the value of :py:attr:`numFeatures`.
+ """
+ return self._set(numFeatures=value)
+
@since("3.0.0")
def indexOf(self, term):
"""
@@ -1086,7 +1339,11 @@ class IDF(JavaEstimator, _IDFParams, JavaMLReadable, JavaMLWritable):
>>> from pyspark.ml.linalg import DenseVector
>>> df = spark.createDataFrame([(DenseVector([1.0, 2.0]),),
... (DenseVector([0.0, 1.0]),), (DenseVector([3.0, 0.2]),)], ["tf"])
- >>> idf = IDF(minDocFreq=3, inputCol="tf", outputCol="idf")
+ >>> idf = IDF(minDocFreq=3)
+ >>> idf.setInputCol("tf")
+ IDF...
+ >>> idf.setOutputCol("idf")
+ IDF...
>>> model = idf.fit(df)
>>> model.getMinDocFreq()
3
@@ -1145,6 +1402,18 @@ def setMinDocFreq(self, value):
"""
return self._set(minDocFreq=value)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
def _create_model(self, java_model):
return IDFModel(java_model)
@@ -1156,6 +1425,20 @@ class IDFModel(JavaModel, _IDFParams, JavaMLReadable, JavaMLWritable):
.. versionadded:: 1.4.0
"""
+ @since("3.0.0")
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ @since("3.0.0")
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
@property
@since("2.0.0")
def idf(self):
@@ -1228,7 +1511,11 @@ class Imputer(JavaEstimator, _ImputerParams, JavaMLReadable, JavaMLWritable):
>>> df = spark.createDataFrame([(1.0, float("nan")), (2.0, float("nan")), (float("nan"), 3.0),
... (4.0, 4.0), (5.0, 5.0)], ["a", "b"])
- >>> imputer = Imputer(inputCols=["a", "b"], outputCols=["out_a", "out_b"])
+ >>> imputer = Imputer()
+ >>> imputer.setInputCols(["a", "b"])
+ Imputer...
+ >>> imputer.setOutputCols(["out_a", "out_b"])
+ Imputer...
>>> model = imputer.fit(df)
>>> model.getStrategy()
'mean'
@@ -1308,6 +1595,20 @@ def setMissingValue(self, value):
"""
return self._set(missingValue=value)
+ @since("2.2.0")
+ def setInputCols(self, value):
+ """
+ Sets the value of :py:attr:`inputCols`.
+ """
+ return self._set(inputCols=value)
+
+ @since("2.2.0")
+ def setOutputCols(self, value):
+ """
+ Sets the value of :py:attr:`outputCols`.
+ """
+ return self._set(outputCols=value)
+
def _create_model(self, java_model):
return ImputerModel(java_model)
@@ -1319,6 +1620,20 @@ class ImputerModel(JavaModel, _ImputerParams, JavaMLReadable, JavaMLWritable):
.. versionadded:: 2.2.0
"""
+ @since("3.0.0")
+ def setInputCols(self, value):
+ """
+ Sets the value of :py:attr:`inputCols`.
+ """
+ return self._set(inputCols=value)
+
+ @since("3.0.0")
+ def setOutputCols(self, value):
+ """
+ Sets the value of :py:attr:`outputCols`.
+ """
+ return self._set(outputCols=value)
+
@property
@since("2.2.0")
def surrogateDF(self):
@@ -1342,7 +1657,11 @@ class Interaction(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadable, J
with four categories, the output would then be `Vector(0, 0, 0, 0, 3, 4, 0, 0)`.
>>> df = spark.createDataFrame([(0.0, 1.0), (2.0, 3.0)], ["a", "b"])
- >>> interaction = Interaction(inputCols=["a", "b"], outputCol="ab")
+ >>> interaction = Interaction()
+ >>> interaction.setInputCols(["a", "b"])
+ Interaction...
+ >>> interaction.setOutputCol("ab")
+ Interaction...
>>> interaction.transform(df).show()
+---+---+-----+
| a| b| ab|
@@ -1381,6 +1700,20 @@ def setParams(self, inputCols=None, outputCol=None):
kwargs = self._input_kwargs
return self._set(**kwargs)
+ @since("3.0.0")
+ def setInputCols(self, value):
+ """
+ Sets the value of :py:attr:`inputCols`.
+ """
+ return self._set(inputCols=value)
+
+ @since("3.0.0")
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
class _MaxAbsScalerParams(HasInputCol, HasOutputCol):
"""
@@ -1400,7 +1733,9 @@ class MaxAbsScaler(JavaEstimator, _MaxAbsScalerParams, JavaMLReadable, JavaMLWri
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(Vectors.dense([1.0]),), (Vectors.dense([2.0]),)], ["a"])
- >>> maScaler = MaxAbsScaler(inputCol="a", outputCol="scaled")
+ >>> maScaler = MaxAbsScaler(outputCol="scaled")
+ >>> maScaler.setInputCol("a")
+ MaxAbsScaler...
>>> model = maScaler.fit(df)
>>> model.setOutputCol("scaledOutput")
MaxAbsScaler...
@@ -1449,6 +1784,18 @@ def setParams(self, inputCol=None, outputCol=None):
kwargs = self._input_kwargs
return self._set(**kwargs)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
def _create_model(self, java_model):
return MaxAbsScalerModel(java_model)
@@ -1460,6 +1807,20 @@ class MaxAbsScalerModel(JavaModel, _MaxAbsScalerParams, JavaMLReadable, JavaMLWr
.. versionadded:: 2.0.0
"""
+ @since("3.0.0")
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ @since("3.0.0")
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
@property
@since("2.0.0")
def maxAbs(self):
@@ -1487,7 +1848,13 @@ class MinHashLSH(_LSH, HasInputCol, HasOutputCol, HasSeed, JavaMLReadable, JavaM
... (1, Vectors.sparse(6, [2, 3, 4], [1.0, 1.0, 1.0]),),
... (2, Vectors.sparse(6, [0, 2, 4], [1.0, 1.0, 1.0]),)]
>>> df = spark.createDataFrame(data, ["id", "features"])
- >>> mh = MinHashLSH(inputCol="features", outputCol="hashes", seed=12345)
+ >>> mh = MinHashLSH()
+ >>> mh.setInputCol("features")
+ MinHashLSH...
+ >>> mh.setOutputCol("hashes")
+ MinHashLSH...
+ >>> mh.setSeed(12345)
+ MinHashLSH...
>>> model = mh.fit(df)
>>> model.transform(df).head()
Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([6179668...
@@ -1544,6 +1911,12 @@ def setParams(self, inputCol=None, outputCol=None, seed=None, numHashTables=1):
kwargs = self._input_kwargs
return self._set(**kwargs)
+ def setSeed(self, value):
+ """
+ Sets the value of :py:attr:`seed`.
+ """
+ return self._set(seed=value)
+
def _create_model(self, java_model):
return MinHashLSHModel(java_model)
@@ -1606,7 +1979,9 @@ class MinMaxScaler(JavaEstimator, _MinMaxScalerParams, JavaMLReadable, JavaMLWri
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"])
- >>> mmScaler = MinMaxScaler(inputCol="a", outputCol="scaled")
+ >>> mmScaler = MinMaxScaler(outputCol="scaled")
+ >>> mmScaler.setInputCol("a")
+ MinMaxScaler...
>>> model = mmScaler.fit(df)
>>> model.setOutputCol("scaledOutput")
MinMaxScaler...
@@ -1675,6 +2050,18 @@ def setMax(self, value):
"""
return self._set(max=value)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
def _create_model(self, java_model):
return MinMaxScalerModel(java_model)
@@ -1686,6 +2073,34 @@ class MinMaxScalerModel(JavaModel, _MinMaxScalerParams, JavaMLReadable, JavaMLWr
.. versionadded:: 1.6.0
"""
+ @since("3.0.0")
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ @since("3.0.0")
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
+ @since("3.0.0")
+ def setMin(self, value):
+ """
+ Sets the value of :py:attr:`min`.
+ """
+ return self._set(min=value)
+
+ @since("3.0.0")
+ def setMax(self, value):
+ """
+ Sets the value of :py:attr:`max`.
+ """
+ return self._set(max=value)
+
@property
@since("2.0.0")
def originalMin(self):
@@ -1716,7 +2131,11 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWr
returned.
>>> df = spark.createDataFrame([Row(inputTokens=["a", "b", "c", "d", "e"])])
- >>> ngram = NGram(n=2, inputCol="inputTokens", outputCol="nGrams")
+ >>> ngram = NGram(n=2)
+ >>> ngram.setInputCol("inputTokens")
+ NGram...
+ >>> ngram.setOutputCol("nGrams")
+ NGram...
>>> ngram.transform(df).head()
Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b', u'b c', u'c d', u'd e'])
>>> # Change n-gram length
@@ -1779,6 +2198,18 @@ def getN(self):
"""
return self.getOrDefault(self.n)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
@inherit_doc
class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
@@ -1788,7 +2219,11 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
>>> from pyspark.ml.linalg import Vectors
>>> svec = Vectors.sparse(4, {1: 4.0, 3: 3.0})
>>> df = spark.createDataFrame([(Vectors.dense([3.0, -4.0]), svec)], ["dense", "sparse"])
- >>> normalizer = Normalizer(p=2.0, inputCol="dense", outputCol="features")
+ >>> normalizer = Normalizer(p=2.0)
+ >>> normalizer.setInputCol("dense")
+ Normalizer...
+ >>> normalizer.setOutputCol("features")
+ Normalizer...
>>> normalizer.transform(df).head().features
DenseVector([0.6, -0.8])
>>> normalizer.setParams(inputCol="sparse", outputCol="freqs").transform(df).head().freqs
@@ -1843,6 +2278,18 @@ def getP(self):
"""
return self.getOrDefault(self.p)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
class _OneHotEncoderParams(HasInputCols, HasOutputCols, HasHandleInvalid):
"""
@@ -1895,7 +2342,11 @@ class OneHotEncoder(JavaEstimator, _OneHotEncoderParams, JavaMLReadable, JavaMLW
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"])
- >>> ohe = OneHotEncoder(inputCols=["input"], outputCols=["output"])
+ >>> ohe = OneHotEncoder()
+ >>> ohe.setInputCols(["input"])
+ OneHotEncoder...
+ >>> ohe.setOutputCols(["output"])
+ OneHotEncoder...
>>> model = ohe.fit(df)
>>> model.getHandleInvalid()
'error'
@@ -1944,6 +2395,27 @@ def setDropLast(self, value):
"""
return self._set(dropLast=value)
+ @since("3.0.0")
+ def setInputCols(self, value):
+ """
+ Sets the value of :py:attr:`inputCols`.
+ """
+ return self._set(inputCols=value)
+
+ @since("3.0.0")
+ def setOutputCols(self, value):
+ """
+ Sets the value of :py:attr:`outputCols`.
+ """
+ return self._set(outputCols=value)
+
+ @since("3.0.0")
+ def setHandleInvalid(self, value):
+ """
+ Sets the value of :py:attr:`handleInvalid`.
+ """
+ return self._set(handleInvalid=value)
+
def _create_model(self, java_model):
return OneHotEncoderModel(java_model)
@@ -1955,6 +2427,34 @@ class OneHotEncoderModel(JavaModel, _OneHotEncoderParams, JavaMLReadable, JavaML
.. versionadded:: 2.3.0
"""
+ @since("3.0.0")
+ def setDropLast(self, value):
+ """
+ Sets the value of :py:attr:`dropLast`.
+ """
+ return self._set(dropLast=value)
+
+ @since("3.0.0")
+ def setInputCols(self, value):
+ """
+ Sets the value of :py:attr:`inputCols`.
+ """
+ return self._set(inputCols=value)
+
+ @since("3.0.0")
+ def setOutputCols(self, value):
+ """
+ Sets the value of :py:attr:`outputCols`.
+ """
+ return self._set(outputCols=value)
+
+ @since("3.0.0")
+ def setHandleInvalid(self, value):
+ """
+ Sets the value of :py:attr:`handleInvalid`.
+ """
+ return self._set(handleInvalid=value)
+
@property
@since("2.3.0")
def categorySizes(self):
@@ -1977,7 +2477,11 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLRead
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(Vectors.dense([0.5, 2.0]),)], ["dense"])
- >>> px = PolynomialExpansion(degree=2, inputCol="dense", outputCol="expanded")
+ >>> px = PolynomialExpansion(degree=2)
+ >>> px.setInputCol("dense")
+ PolynomialExpansion...
+ >>> px.setOutputCol("expanded")
+ PolynomialExpansion...
>>> px.transform(df).head().expanded
DenseVector([0.5, 0.25, 2.0, 1.0, 4.0])
>>> px.setParams(outputCol="test").transform(df).head().test
@@ -2030,6 +2534,18 @@ def getDegree(self):
"""
return self.getOrDefault(self.degree)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
@inherit_doc
class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols,
@@ -2060,8 +2576,13 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasInputCols
>>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)]
>>> df1 = spark.createDataFrame(values, ["values"])
- >>> qds1 = QuantileDiscretizer(numBuckets=2,
- ... inputCol="values", outputCol="buckets", relativeError=0.01, handleInvalid="error")
+ >>> qds1 = QuantileDiscretizer(inputCol="values", outputCol="buckets")
+ >>> qds1.setNumBuckets(2)
+ QuantileDiscretizer...
+ >>> qds1.setRelativeError(0.01)
+ QuantileDiscretizer...
+ >>> qds1.setHandleInvalid("error")
+ QuantileDiscretizer...
>>> qds1.getRelativeError()
0.01
>>> bucketizer = qds1.fit(df1)
@@ -2213,6 +2734,38 @@ def getRelativeError(self):
"""
return self.getOrDefault(self.relativeError)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ @since("3.0.0")
+ def setInputCols(self, value):
+ """
+ Sets the value of :py:attr:`inputCols`.
+ """
+ return self._set(inputCols=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
+ @since("3.0.0")
+ def setOutputCols(self, value):
+ """
+ Sets the value of :py:attr:`outputCols`.
+ """
+ return self._set(outputCols=value)
+
+ def setHandleInvalid(self, value):
+ """
+ Sets the value of :py:attr:`handleInvalid`.
+ """
+ return self._set(handleInvalid=value)
+
def _create_model(self, java_model):
"""
Private method to convert the java_model to a Python model.
@@ -2292,7 +2845,11 @@ class RobustScaler(JavaEstimator, _RobustScalerParams, JavaMLReadable, JavaMLWri
... (3, Vectors.dense([3.0, -3.0]),),
... (4, Vectors.dense([4.0, -4.0]),),]
>>> df = spark.createDataFrame(data, ["id", "features"])
- >>> scaler = RobustScaler(inputCol="features", outputCol="scaled")
+ >>> scaler = RobustScaler()
+ >>> scaler.setInputCol("features")
+ RobustScaler...
+ >>> scaler.setOutputCol("scaled")
+ RobustScaler...
>>> model = scaler.fit(df)
>>> model.setOutputCol("output")
RobustScaler...
@@ -2373,6 +2930,20 @@ def setWithScaling(self, value):
"""
return self._set(withScaling=value)
+ @since("3.0.0")
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ @since("3.0.0")
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
def _create_model(self, java_model):
return RobustScalerModel(java_model)
@@ -2384,6 +2955,20 @@ class RobustScalerModel(JavaModel, _RobustScalerParams, JavaMLReadable, JavaMLWr
.. versionadded:: 3.0.0
"""
+ @since("3.0.0")
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ @since("3.0.0")
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
@property
@since("3.0.0")
def median(self):
@@ -2413,7 +2998,11 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
It returns an array of strings that can be empty.
>>> df = spark.createDataFrame([("A B c",)], ["text"])
- >>> reTokenizer = RegexTokenizer(inputCol="text", outputCol="words")
+ >>> reTokenizer = RegexTokenizer()
+ >>> reTokenizer.setInputCol("text")
+ RegexTokenizer...
+ >>> reTokenizer.setOutputCol("words")
+ RegexTokenizer...
>>> reTokenizer.transform(df).head()
Row(text=u'A B c', words=[u'a', u'b', u'c'])
>>> # Change a parameter.
@@ -2530,6 +3119,18 @@ def getToLowercase(self):
"""
return self.getOrDefault(self.toLowercase)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
@inherit_doc
class SQLTransformer(JavaTransformer, JavaMLReadable, JavaMLWritable):
@@ -2629,7 +3230,11 @@ class StandardScaler(JavaEstimator, _StandardScalerParams, JavaMLReadable, JavaM
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"])
- >>> standardScaler = StandardScaler(inputCol="a", outputCol="scaled")
+ >>> standardScaler = StandardScaler()
+ >>> standardScaler.setInputCol("a")
+ StandardScaler...
+ >>> standardScaler.setOutputCol("scaled")
+ StandardScaler...
>>> model = standardScaler.fit(df)
>>> model.getInputCol()
'a'
@@ -2694,6 +3299,18 @@ def setWithStd(self, value):
"""
return self._set(withStd=value)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
def _create_model(self, java_model):
return StandardScalerModel(java_model)
@@ -2705,6 +3322,18 @@ class StandardScalerModel(JavaModel, _StandardScalerParams, JavaMLReadable, Java
.. versionadded:: 1.4.0
"""
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
@property
@since("2.0.0")
def std(self):
@@ -2765,8 +3394,10 @@ class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLW
so the most frequent label gets index 0. The ordering behavior is controlled by
setting :py:attr:`stringOrderType`. Its default value is 'frequencyDesc'.
- >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="error",
+ >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed",
... stringOrderType="frequencyDesc")
+ >>> stringIndexer.setHandleInvalid("error")
+ StringIndexer...
>>> model = stringIndexer.fit(stringIndDf)
>>> td = model.transform(stringIndDf)
>>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
@@ -2866,6 +3497,38 @@ def setStringOrderType(self, value):
"""
return self._set(stringOrderType=value)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ @since("3.0.0")
+ def setInputCols(self, value):
+ """
+ Sets the value of :py:attr:`inputCols`.
+ """
+ return self._set(inputCols=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
+ @since("3.0.0")
+ def setOutputCols(self, value):
+ """
+ Sets the value of :py:attr:`outputCols`.
+ """
+ return self._set(outputCols=value)
+
+ def setHandleInvalid(self, value):
+ """
+ Sets the value of :py:attr:`handleInvalid`.
+ """
+ return self._set(handleInvalid=value)
+
class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, JavaMLWritable):
"""
@@ -2874,6 +3537,39 @@ class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, JavaML
.. versionadded:: 1.4.0
"""
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ @since("3.0.0")
+ def setInputCols(self, value):
+ """
+ Sets the value of :py:attr:`inputCols`.
+ """
+ return self._set(inputCols=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
+ @since("3.0.0")
+ def setOutputCols(self, value):
+ """
+ Sets the value of :py:attr:`outputCols`.
+ """
+ return self._set(outputCols=value)
+
+ @since("2.4.0")
+ def setHandleInvalid(self, value):
+ """
+ Sets the value of :py:attr:`handleInvalid`.
+ """
+ return self._set(handleInvalid=value)
+
@classmethod
@since("2.4.0")
def from_labels(cls, labels, inputCol, outputCol=None, handleInvalid=None):
@@ -2921,13 +3617,6 @@ def labels(self):
"""
return self._call_java("labels")
- @since("2.4.0")
- def setHandleInvalid(self, value):
- """
- Sets the value of :py:attr:`handleInvalid`.
- """
- return self._set(handleInvalid=value)
-
@inherit_doc
class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
@@ -2981,6 +3670,18 @@ def getLabels(self):
"""
return self.getOrDefault(self.labels)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
@@ -2989,7 +3690,11 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
.. note:: null values from input array are preserved unless adding null to stopWords explicitly.
>>> df = spark.createDataFrame([(["a", "b", "c"],)], ["text"])
- >>> remover = StopWordsRemover(inputCol="text", outputCol="words", stopWords=["b"])
+ >>> remover = StopWordsRemover(stopWords=["b"])
+ >>> remover.setInputCol("text")
+ StopWordsRemover...
+ >>> remover.setOutputCol("words")
+ StopWordsRemover...
>>> remover.transform(df).head().words == ['a', 'c']
True
>>> stopWordsRemoverPath = temp_path + "/stopwords-remover"
@@ -3079,6 +3784,18 @@ def getLocale(self):
"""
return self.getOrDefault(self.locale)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
@staticmethod
@since("2.0.0")
def loadDefaultStopWords(language):
@@ -3099,7 +3816,9 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Java
splits it by white spaces.
>>> df = spark.createDataFrame([("a b c",)], ["text"])
- >>> tokenizer = Tokenizer(inputCol="text", outputCol="words")
+ >>> tokenizer = Tokenizer(outputCol="words")
+ >>> tokenizer.setInputCol("text")
+ Tokenizer...
>>> tokenizer.transform(df).head()
Row(text=u'a b c', words=[u'a', u'b', u'c'])
>>> # Change a parameter.
@@ -3144,6 +3863,18 @@ def setParams(self, inputCol=None, outputCol=None):
kwargs = self._input_kwargs
return self._set(**kwargs)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
@inherit_doc
class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, HasHandleInvalid, JavaMLReadable,
@@ -3152,7 +3883,9 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, HasHandleInva
A feature transformer that merges multiple columns into a vector column.
>>> df = spark.createDataFrame([(1, 0, 3)], ["a", "b", "c"])
- >>> vecAssembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features")
+ >>> vecAssembler = VectorAssembler(outputCol="features")
+ >>> vecAssembler.setInputCols(["a", "b", "c"])
+ VectorAssembler...
>>> vecAssembler.transform(df).head().features
DenseVector([1.0, 0.0, 3.0])
>>> vecAssembler.setParams(outputCol="freqs").transform(df).head().freqs
@@ -3220,6 +3953,24 @@ def setParams(self, inputCols=None, outputCol=None, handleInvalid="error"):
kwargs = self._input_kwargs
return self._set(**kwargs)
+ def setInputCols(self, value):
+ """
+ Sets the value of :py:attr:`inputCols`.
+ """
+ return self._set(inputCols=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
+ def setHandleInvalid(self, value):
+ """
+ Sets the value of :py:attr:`handleInvalid`.
+ """
+ return self._set(handleInvalid=value)
+
class _VectorIndexerParams(HasInputCol, HasOutputCol, HasHandleInvalid):
"""
@@ -3288,7 +4039,9 @@ class VectorIndexer(JavaEstimator, _VectorIndexerParams, JavaMLReadable, JavaMLW
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(Vectors.dense([-1.0, 0.0]),),
... (Vectors.dense([0.0, 1.0]),), (Vectors.dense([0.0, 2.0]),)], ["a"])
- >>> indexer = VectorIndexer(maxCategories=2, inputCol="a", outputCol="indexed")
+ >>> indexer = VectorIndexer(maxCategories=2, inputCol="a")
+ >>> indexer.setOutputCol("indexed")
+ VectorIndexer...
>>> model = indexer.fit(df)
>>> indexer.getHandleInvalid()
'error'
@@ -3359,6 +4112,24 @@ def setMaxCategories(self, value):
"""
return self._set(maxCategories=value)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
+ def setHandleInvalid(self, value):
+ """
+ Sets the value of :py:attr:`handleInvalid`.
+ """
+ return self._set(handleInvalid=value)
+
def _create_model(self, java_model):
return VectorIndexerModel(java_model)
@@ -3380,6 +4151,20 @@ class VectorIndexerModel(JavaModel, _VectorIndexerParams, JavaMLReadable, JavaML
.. versionadded:: 1.4.0
"""
+ @since("3.0.0")
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ @since("3.0.0")
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
@property
@since("1.4.0")
def numFeatures(self):
@@ -3417,7 +4202,9 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, J
... (Vectors.dense([-2.0, 2.3, 0.0, 0.0, 1.0]),),
... (Vectors.dense([0.0, 0.0, 0.0, 0.0, 0.0]),),
... (Vectors.dense([0.6, -1.1, -3.0, 4.5, 3.3]),)], ["features"])
- >>> vs = VectorSlicer(inputCol="features", outputCol="sliced", indices=[1, 4])
+ >>> vs = VectorSlicer(outputCol="sliced", indices=[1, 4])
+ >>> vs.setInputCol("features")
+ VectorSlicer...
>>> vs.transform(df).head().sliced
DenseVector([2.3, 1.0])
>>> vectorSlicerPath = temp_path + "/vector-slicer"
@@ -3488,6 +4275,18 @@ def getNames(self):
"""
return self.getOrDefault(self.names)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
class _Word2VecParams(HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCol):
"""
@@ -3560,6 +4359,11 @@ class Word2Vec(JavaEstimator, _Word2VecParams, JavaMLReadable, JavaMLWritable):
>>> sent = ("a b " * 100 + "a c " * 10).split(" ")
>>> doc = spark.createDataFrame([(sent,), (sent,)], ["sentence"])
>>> word2Vec = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model")
+ >>> word2Vec.setMaxIter(10)
+ Word2Vec...
+ >>> word2Vec.getMaxIter()
+ 10
+ >>> word2Vec.clear(word2Vec.maxIter)
>>> model = word2Vec.fit(doc)
>>> model.getMinCount()
5
@@ -3666,12 +4470,36 @@ def setMaxSentenceLength(self, value):
"""
return self._set(maxSentenceLength=value)
- @since("2.0.0")
- def getMaxSentenceLength(self):
+ def setMaxIter(self, value):
"""
- Gets the value of maxSentenceLength or its default value.
+ Sets the value of :py:attr:`maxIter`.
"""
- return self.getOrDefault(self.maxSentenceLength)
+ return self._set(maxIter=value)
+
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
+ def setSeed(self, value):
+ """
+ Sets the value of :py:attr:`seed`.
+ """
+ return self._set(seed=value)
+
+ @since("1.4.0")
+ def setStepSize(self, value):
+ """
+ Sets the value of :py:attr:`stepSize`.
+ """
+ return self._set(stepSize=value)
def _create_model(self, java_model):
return Word2VecModel(java_model)
@@ -3692,6 +4520,18 @@ def getVectors(self):
"""
return self._call_java("getVectors")
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
@since("1.5.0")
def findSynonyms(self, word, num):
"""
@@ -3747,7 +4587,9 @@ class PCA(JavaEstimator, _PCAParams, JavaMLReadable, JavaMLWritable):
... (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),),
... (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)]
>>> df = spark.createDataFrame(data,["features"])
- >>> pca = PCA(k=2, inputCol="features", outputCol="pca_features")
+ >>> pca = PCA(k=2, inputCol="features")
+ >>> pca.setOutputCol("pca_features")
+ PCA...
>>> model = pca.fit(df)
>>> model.getK()
2
@@ -3800,6 +4642,18 @@ def setK(self, value):
"""
return self._set(k=value)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
def _create_model(self, java_model):
return PCAModel(java_model)
@@ -3811,6 +4665,20 @@ class PCAModel(JavaModel, _PCAParams, JavaMLReadable, JavaMLWritable):
.. versionadded:: 1.5.0
"""
+ @since("3.0.0")
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ @since("3.0.0")
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
@property
@since("2.0.0")
def pc(self):
@@ -4001,6 +4869,24 @@ def setStringIndexerOrderType(self, value):
"""
return self._set(stringIndexerOrderType=value)
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
+ def setLabelCol(self, value):
+ """
+ Sets the value of :py:attr:`labelCol`.
+ """
+ return self._set(labelCol=value)
+
+ def setHandleInvalid(self, value):
+ """
+ Sets the value of :py:attr:`handleInvalid`.
+ """
+ return self._set(handleInvalid=value)
+
def _create_model(self, java_model):
return RFormulaModel(java_model)
@@ -4228,6 +5114,24 @@ def setFwe(self, value):
"""
return self._set(fwe=value)
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
+ def setLabelCol(self, value):
+ """
+ Sets the value of :py:attr:`labelCol`.
+ """
+ return self._set(labelCol=value)
+
def _create_model(self, java_model):
return ChiSqSelectorModel(java_model)
@@ -4239,6 +5143,20 @@ class ChiSqSelectorModel(JavaModel, _ChiSqSelectorParams, JavaMLReadable, JavaML
.. versionadded:: 2.0.0
"""
+ @since("3.0.0")
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
+ @since("3.0.0")
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ return self._set(outputCol=value)
+
@property
@since("2.0.0")
def selectedFeatures(self):
@@ -4323,6 +5241,18 @@ def setSize(self, value):
""" Sets size param, the size of vectors in `inputCol`."""
return self._set(size=value)
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ return self._set(inputCol=value)
+
+ def setHandleInvalid(self, value):
+ """
+ Sets the value of :py:attr:`handleInvalid`.
+ """
+ return self._set(handleInvalid=value)
+
if __name__ == "__main__":
import doctest
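
Taken together, the feature.py hunks above replace constructor-keyword configuration in the doctests with explicit setter calls. A brief usage sketch follows (editorial aside, not part of the patch): since each added setter delegates to Params._set and returns self, the same configuration can also be chained; this assumes a SparkSession bound to `spark`, exactly as the doctests do.

    >>> from pyspark.ml.feature import Binarizer
    >>> df = spark.createDataFrame([(0.5,)], ["values"])
    >>> # chain the new setters; each call returns the Binarizer itself
    >>> binarizer = Binarizer(threshold=1.0).setInputCol("values").setOutputCol("features")
    >>> binarizer.transform(df).head().features
    0.0

The result matches the Binarizer doctest above; only the call style differs.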
diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py
index 652acbb34a901..5b34d555484d1 100644
--- a/python/pyspark/ml/fpm.py
+++ b/python/pyspark/ml/fpm.py
@@ -102,6 +102,13 @@ def setMinConfidence(self, value):
"""
return self._set(minConfidence=value)
+ @since("3.0.0")
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
@property
@since("2.2.0")
def freqItemsets(self):
@@ -239,6 +246,12 @@ def setMinConfidence(self, value):
"""
return self._set(minConfidence=value)
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
def _create_model(self, java_model):
return FPGrowthModel(java_model)
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index c99ec3f467ac6..8ea94e4760007 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -81,12 +81,6 @@ def _gen_param_code(name, doc, defaultValueStr):
"""
# TODO: How to correctly inherit instance attributes?
template = '''
- def set$Name(self, value):
- """
- Sets the value of :py:attr:`$name`.
- """
- return self._set($name=value)
-
def get$Name(self):
"""
Gets the value of $name or its default value.
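
The template change above is the root of the file-wide edits that follow: _gen_param_code no longer emits a set$Name method, so the regenerated mixins in shared.py (next diff) keep only getters, while each concrete estimator or transformer defines its own setter. A minimal sketch of the resulting shape (editorial; MyTransformer is hypothetical, and the mixin body only mirrors the generated code):

    from pyspark.ml.param import Param, Params, TypeConverters
    from pyspark.ml.wrapper import JavaTransformer

    class HasInputCol(Params):
        """
        Mixin for param inputCol: input column name. (Generated code keeps the getter only.)
        """
        inputCol = Param(Params._dummy(), "inputCol", "input column name.",
                         typeConverter=TypeConverters.toString)

        def getInputCol(self):
            """
            Gets the value of inputCol or its default value.
            """
            return self.getOrDefault(self.inputCol)

    class MyTransformer(JavaTransformer, HasInputCol):
        """Hypothetical concrete class; Binarizer, NGram, etc. follow this shape."""

        def setInputCol(self, value):
            """
            Sets the value of :py:attr:`inputCol`.
            """
            return self._set(inputCol=value)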
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index 771b4bcd9ba02..26d74fab6975a 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -30,12 +30,6 @@ class HasMaxIter(Params):
def __init__(self):
super(HasMaxIter, self).__init__()
- def setMaxIter(self, value):
- """
- Sets the value of :py:attr:`maxIter`.
- """
- return self._set(maxIter=value)
-
def getMaxIter(self):
"""
Gets the value of maxIter or its default value.
@@ -53,12 +47,6 @@ class HasRegParam(Params):
def __init__(self):
super(HasRegParam, self).__init__()
- def setRegParam(self, value):
- """
- Sets the value of :py:attr:`regParam`.
- """
- return self._set(regParam=value)
-
def getRegParam(self):
"""
Gets the value of regParam or its default value.
@@ -77,12 +65,6 @@ def __init__(self):
super(HasFeaturesCol, self).__init__()
self._setDefault(featuresCol='features')
- def setFeaturesCol(self, value):
- """
- Sets the value of :py:attr:`featuresCol`.
- """
- return self._set(featuresCol=value)
-
def getFeaturesCol(self):
"""
Gets the value of featuresCol or its default value.
@@ -101,12 +83,6 @@ def __init__(self):
super(HasLabelCol, self).__init__()
self._setDefault(labelCol='label')
- def setLabelCol(self, value):
- """
- Sets the value of :py:attr:`labelCol`.
- """
- return self._set(labelCol=value)
-
def getLabelCol(self):
"""
Gets the value of labelCol or its default value.
@@ -125,12 +101,6 @@ def __init__(self):
super(HasPredictionCol, self).__init__()
self._setDefault(predictionCol='prediction')
- def setPredictionCol(self, value):
- """
- Sets the value of :py:attr:`predictionCol`.
- """
- return self._set(predictionCol=value)
-
def getPredictionCol(self):
"""
Gets the value of predictionCol or its default value.
@@ -149,12 +119,6 @@ def __init__(self):
super(HasProbabilityCol, self).__init__()
self._setDefault(probabilityCol='probability')
- def setProbabilityCol(self, value):
- """
- Sets the value of :py:attr:`probabilityCol`.
- """
- return self._set(probabilityCol=value)
-
def getProbabilityCol(self):
"""
Gets the value of probabilityCol or its default value.
@@ -173,12 +137,6 @@ def __init__(self):
super(HasRawPredictionCol, self).__init__()
self._setDefault(rawPredictionCol='rawPrediction')
- def setRawPredictionCol(self, value):
- """
- Sets the value of :py:attr:`rawPredictionCol`.
- """
- return self._set(rawPredictionCol=value)
-
def getRawPredictionCol(self):
"""
Gets the value of rawPredictionCol or its default value.
@@ -196,12 +154,6 @@ class HasInputCol(Params):
def __init__(self):
super(HasInputCol, self).__init__()
- def setInputCol(self, value):
- """
- Sets the value of :py:attr:`inputCol`.
- """
- return self._set(inputCol=value)
-
def getInputCol(self):
"""
Gets the value of inputCol or its default value.
@@ -219,12 +171,6 @@ class HasInputCols(Params):
def __init__(self):
super(HasInputCols, self).__init__()
- def setInputCols(self, value):
- """
- Sets the value of :py:attr:`inputCols`.
- """
- return self._set(inputCols=value)
-
def getInputCols(self):
"""
Gets the value of inputCols or its default value.
@@ -243,12 +189,6 @@ def __init__(self):
super(HasOutputCol, self).__init__()
self._setDefault(outputCol=self.uid + '__output')
- def setOutputCol(self, value):
- """
- Sets the value of :py:attr:`outputCol`.
- """
- return self._set(outputCol=value)
-
def getOutputCol(self):
"""
Gets the value of outputCol or its default value.
@@ -266,12 +206,6 @@ class HasOutputCols(Params):
def __init__(self):
super(HasOutputCols, self).__init__()
- def setOutputCols(self, value):
- """
- Sets the value of :py:attr:`outputCols`.
- """
- return self._set(outputCols=value)
-
def getOutputCols(self):
"""
Gets the value of outputCols or its default value.
@@ -290,12 +224,6 @@ def __init__(self):
super(HasNumFeatures, self).__init__()
self._setDefault(numFeatures=262144)
- def setNumFeatures(self, value):
- """
- Sets the value of :py:attr:`numFeatures`.
- """
- return self._set(numFeatures=value)
-
def getNumFeatures(self):
"""
Gets the value of numFeatures or its default value.
@@ -313,12 +241,6 @@ class HasCheckpointInterval(Params):
def __init__(self):
super(HasCheckpointInterval, self).__init__()
- def setCheckpointInterval(self, value):
- """
- Sets the value of :py:attr:`checkpointInterval`.
- """
- return self._set(checkpointInterval=value)
-
def getCheckpointInterval(self):
"""
Gets the value of checkpointInterval or its default value.
@@ -337,12 +259,6 @@ def __init__(self):
super(HasSeed, self).__init__()
self._setDefault(seed=hash(type(self).__name__))
- def setSeed(self, value):
- """
- Sets the value of :py:attr:`seed`.
- """
- return self._set(seed=value)
-
def getSeed(self):
"""
Gets the value of seed or its default value.
@@ -360,12 +276,6 @@ class HasTol(Params):
def __init__(self):
super(HasTol, self).__init__()
- def setTol(self, value):
- """
- Sets the value of :py:attr:`tol`.
- """
- return self._set(tol=value)
-
def getTol(self):
"""
Gets the value of tol or its default value.
@@ -383,12 +293,6 @@ class HasStepSize(Params):
def __init__(self):
super(HasStepSize, self).__init__()
- def setStepSize(self, value):
- """
- Sets the value of :py:attr:`stepSize`.
- """
- return self._set(stepSize=value)
-
def getStepSize(self):
"""
Gets the value of stepSize or its default value.
@@ -406,12 +310,6 @@ class HasHandleInvalid(Params):
def __init__(self):
super(HasHandleInvalid, self).__init__()
- def setHandleInvalid(self, value):
- """
- Sets the value of :py:attr:`handleInvalid`.
- """
- return self._set(handleInvalid=value)
-
def getHandleInvalid(self):
"""
Gets the value of handleInvalid or its default value.
@@ -430,12 +328,6 @@ def __init__(self):
super(HasElasticNetParam, self).__init__()
self._setDefault(elasticNetParam=0.0)
- def setElasticNetParam(self, value):
- """
- Sets the value of :py:attr:`elasticNetParam`.
- """
- return self._set(elasticNetParam=value)
-
def getElasticNetParam(self):
"""
Gets the value of elasticNetParam or its default value.
@@ -454,12 +346,6 @@ def __init__(self):
super(HasFitIntercept, self).__init__()
self._setDefault(fitIntercept=True)
- def setFitIntercept(self, value):
- """
- Sets the value of :py:attr:`fitIntercept`.
- """
- return self._set(fitIntercept=value)
-
def getFitIntercept(self):
"""
Gets the value of fitIntercept or its default value.
@@ -478,12 +364,6 @@ def __init__(self):
super(HasStandardization, self).__init__()
self._setDefault(standardization=True)
- def setStandardization(self, value):
- """
- Sets the value of :py:attr:`standardization`.
- """
- return self._set(standardization=value)
-
def getStandardization(self):
"""
Gets the value of standardization or its default value.
@@ -501,12 +381,6 @@ class HasThresholds(Params):
def __init__(self):
super(HasThresholds, self).__init__()
- def setThresholds(self, value):
- """
- Sets the value of :py:attr:`thresholds`.
- """
- return self._set(thresholds=value)
-
def getThresholds(self):
"""
Gets the value of thresholds or its default value.
@@ -525,12 +399,6 @@ def __init__(self):
super(HasThreshold, self).__init__()
self._setDefault(threshold=0.5)
- def setThreshold(self, value):
- """
- Sets the value of :py:attr:`threshold`.
- """
- return self._set(threshold=value)
-
def getThreshold(self):
"""
Gets the value of threshold or its default value.
@@ -548,12 +416,6 @@ class HasWeightCol(Params):
def __init__(self):
super(HasWeightCol, self).__init__()
- def setWeightCol(self, value):
- """
- Sets the value of :py:attr:`weightCol`.
- """
- return self._set(weightCol=value)
-
def getWeightCol(self):
"""
Gets the value of weightCol or its default value.
@@ -572,12 +434,6 @@ def __init__(self):
super(HasSolver, self).__init__()
self._setDefault(solver='auto')
- def setSolver(self, value):
- """
- Sets the value of :py:attr:`solver`.
- """
- return self._set(solver=value)
-
def getSolver(self):
"""
Gets the value of solver or its default value.
@@ -595,12 +451,6 @@ class HasVarianceCol(Params):
def __init__(self):
super(HasVarianceCol, self).__init__()
- def setVarianceCol(self, value):
- """
- Sets the value of :py:attr:`varianceCol`.
- """
- return self._set(varianceCol=value)
-
def getVarianceCol(self):
"""
Gets the value of varianceCol or its default value.
@@ -619,12 +469,6 @@ def __init__(self):
super(HasAggregationDepth, self).__init__()
self._setDefault(aggregationDepth=2)
- def setAggregationDepth(self, value):
- """
- Sets the value of :py:attr:`aggregationDepth`.
- """
- return self._set(aggregationDepth=value)
-
def getAggregationDepth(self):
"""
Gets the value of aggregationDepth or its default value.
@@ -643,12 +487,6 @@ def __init__(self):
super(HasParallelism, self).__init__()
self._setDefault(parallelism=1)
- def setParallelism(self, value):
- """
- Sets the value of :py:attr:`parallelism`.
- """
- return self._set(parallelism=value)
-
def getParallelism(self):
"""
Gets the value of parallelism or its default value.
@@ -667,12 +505,6 @@ def __init__(self):
super(HasCollectSubModels, self).__init__()
self._setDefault(collectSubModels=False)
- def setCollectSubModels(self, value):
- """
- Sets the value of :py:attr:`collectSubModels`.
- """
- return self._set(collectSubModels=value)
-
def getCollectSubModels(self):
"""
Gets the value of collectSubModels or its default value.
@@ -690,12 +522,6 @@ class HasLoss(Params):
def __init__(self):
super(HasLoss, self).__init__()
- def setLoss(self, value):
- """
- Sets the value of :py:attr:`loss`.
- """
- return self._set(loss=value)
-
def getLoss(self):
"""
Gets the value of loss or its default value.
@@ -714,12 +540,6 @@ def __init__(self):
super(HasDistanceMeasure, self).__init__()
self._setDefault(distanceMeasure='euclidean')
- def setDistanceMeasure(self, value):
- """
- Sets the value of :py:attr:`distanceMeasure`.
- """
- return self._set(distanceMeasure=value)
-
def getDistanceMeasure(self):
"""
Gets the value of distanceMeasure or its default value.
@@ -737,12 +557,6 @@ class HasValidationIndicatorCol(Params):
def __init__(self):
super(HasValidationIndicatorCol, self).__init__()
- def setValidationIndicatorCol(self, value):
- """
- Sets the value of :py:attr:`validationIndicatorCol`.
- """
- return self._set(validationIndicatorCol=value)
-
def getValidationIndicatorCol(self):
"""
Gets the value of validationIndicatorCol or its default value.
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index df9c765457ec1..3ebd0ac2765f3 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -212,7 +212,16 @@ class ALS(JavaEstimator, _ALSParams, JavaMLWritable, JavaMLReadable):
>>> df = spark.createDataFrame(
... [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
... ["user", "item", "rating"])
- >>> als = ALS(rank=10, maxIter=5, seed=0)
+ >>> als = ALS(rank=10, seed=0)
+ >>> als.setMaxIter(5)
+ ALS...
+ >>> als.getMaxIter()
+ 5
+ >>> als.setRegParam(0.1)
+ ALS...
+ >>> als.getRegParam()
+ 0.1
+ >>> als.clear(als.regParam)
>>> model = als.fit(df)
>>> model.getUserCol()
'user'
@@ -402,6 +411,36 @@ def setColdStartStrategy(self, value):
"""
return self._set(coldStartStrategy=value)
+ def setMaxIter(self, value):
+ """
+ Sets the value of :py:attr:`maxIter`.
+ """
+ return self._set(maxIter=value)
+
+ def setRegParam(self, value):
+ """
+ Sets the value of :py:attr:`regParam`.
+ """
+ return self._set(regParam=value)
+
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
+ def setCheckpointInterval(self, value):
+ """
+ Sets the value of :py:attr:`checkpointInterval`.
+ """
+ return self._set(checkpointInterval=value)
+
+ def setSeed(self, value):
+ """
+ Sets the value of :py:attr:`seed`.
+ """
+ return self._set(seed=value)
+
class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable, JavaMLReadable):
"""
@@ -431,6 +470,13 @@ def setColdStartStrategy(self, value):
"""
return self._set(coldStartStrategy=value)
+ @since("3.0.0")
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
@property
@since("1.4.0")
def rank(self):
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 147ebed1d633a..08e68d8bc3044 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -92,7 +92,17 @@ class LinearRegression(JavaPredictor, _LinearRegressionParams, JavaMLWritable, J
>>> df = spark.createDataFrame([
... (1.0, 2.0, Vectors.dense(1.0)),
... (0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
- >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight")
+ >>> lr = LinearRegression(regParam=0.0, solver="normal", weightCol="weight")
+ >>> lr.setMaxIter(5)
+ LinearRegression...
+ >>> lr.getMaxIter()
+ 5
+ >>> lr.setRegParam(0.1)
+ LinearRegression...
+ >>> lr.getRegParam()
+ 0.1
+ >>> lr.setRegParam(0.0)
+ LinearRegression...
>>> model = lr.fit(df)
>>> model.setFeaturesCol("features")
LinearRegression...
@@ -179,6 +189,66 @@ def setEpsilon(self, value):
"""
return self._set(epsilon=value)
+ def setMaxIter(self, value):
+ """
+ Sets the value of :py:attr:`maxIter`.
+ """
+ return self._set(maxIter=value)
+
+ def setRegParam(self, value):
+ """
+ Sets the value of :py:attr:`regParam`.
+ """
+ return self._set(regParam=value)
+
+ def setTol(self, value):
+ """
+ Sets the value of :py:attr:`tol`.
+ """
+ return self._set(tol=value)
+
+ def setElasticNetParam(self, value):
+ """
+ Sets the value of :py:attr:`elasticNetParam`.
+ """
+ return self._set(elasticNetParam=value)
+
+ def setFitIntercept(self, value):
+ """
+ Sets the value of :py:attr:`fitIntercept`.
+ """
+ return self._set(fitIntercept=value)
+
+ def setStandardization(self, value):
+ """
+ Sets the value of :py:attr:`standardization`.
+ """
+ return self._set(standardization=value)
+
+ def setWeightCol(self, value):
+ """
+ Sets the value of :py:attr:`weightCol`.
+ """
+ return self._set(weightCol=value)
+
+ def setSolver(self, value):
+ """
+ Sets the value of :py:attr:`solver`.
+ """
+ return self._set(solver=value)
+
+ def setAggregationDepth(self, value):
+ """
+ Sets the value of :py:attr:`aggregationDepth`.
+ """
+ return self._set(aggregationDepth=value)
+
+ def setLoss(self, value):
+ """
+ Sets the value of :py:attr:`loss`.
+ """
+ return self._set(loss=value)
+
class LinearRegressionModel(JavaPredictionModel, _LinearRegressionParams, GeneralJavaMLWritable,
JavaMLReadable, HasTrainingSummary):
@@ -522,10 +592,6 @@ class IsotonicRegression(JavaEstimator, _IsotonicRegressionParams, HasWeightCol,
>>> model = ir.fit(df)
>>> model.setFeaturesCol("features")
IsotonicRegression...
- >>> model.setLabelCol("newLabel")
- IsotonicRegression...
- >>> model.getLabelCol()
- 'newLabel'
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
@@ -586,6 +652,34 @@ def setFeatureIndex(self, value):
"""
return self._set(featureIndex=value)
+ @since("1.6.0")
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
+ @since("1.6.0")
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
+ @since("1.6.0")
+ def setLabelCol(self, value):
+ """
+ Sets the value of :py:attr:`labelCol`.
+ """
+ return self._set(labelCol=value)
+
+ @since("1.6.0")
+ def setWeightCol(self, value):
+ """
+ Sets the value of :py:attr:`weightCol`.
+ """
+ return self._set(weightCol=value)
+
class IsotonicRegressionModel(JavaModel, _IsotonicRegressionParams, JavaMLWritable,
JavaMLReadable):
@@ -595,6 +689,26 @@ class IsotonicRegressionModel(JavaModel, _IsotonicRegressionParams, JavaMLWritab
.. versionadded:: 1.6.0
"""
+ @since("3.0.0")
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
+ @since("3.0.0")
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
+ def setFeatureIndex(self, value):
+ """
+ Sets the value of :py:attr:`featureIndex`.
+ """
+ return self._set(featureIndex=value)
+
@property
@since("1.6.0")
def boundaries(self):
@@ -635,7 +749,9 @@ class DecisionTreeRegressor(JavaPredictor, _DecisionTreeRegressorParams, JavaMLW
>>> df = spark.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
- >>> dt = DecisionTreeRegressor(maxDepth=2, varianceCol="variance")
+ >>> dt = DecisionTreeRegressor(maxDepth=2)
+ >>> dt.setVarianceCol("variance")
+ DecisionTreeRegressor...
>>> model = dt.fit(df)
>>> model.getVarianceCol()
'variance'
@@ -732,18 +848,21 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
def _create_model(self, java_model):
return DecisionTreeRegressionModel(java_model)
+ @since("1.4.0")
def setMaxDepth(self, value):
"""
Sets the value of :py:attr:`maxDepth`.
"""
return self._set(maxDepth=value)
+ @since("1.4.0")
def setMaxBins(self, value):
"""
Sets the value of :py:attr:`maxBins`.
"""
return self._set(maxBins=value)
+ @since("1.4.0")
def setMinInstancesPerNode(self, value):
"""
Sets the value of :py:attr:`minInstancesPerNode`.
@@ -757,18 +876,21 @@ def setMinWeightFractionPerNode(self, value):
"""
return self._set(minWeightFractionPerNode=value)
+ @since("1.4.0")
def setMinInfoGain(self, value):
"""
Sets the value of :py:attr:`minInfoGain`.
"""
return self._set(minInfoGain=value)
+ @since("1.4.0")
def setMaxMemoryInMB(self, value):
"""
Sets the value of :py:attr:`maxMemoryInMB`.
"""
return self._set(maxMemoryInMB=value)
+ @since("1.4.0")
def setCacheNodeIds(self, value):
"""
Sets the value of :py:attr:`cacheNodeIds`.
@@ -782,6 +904,34 @@ def setImpurity(self, value):
"""
return self._set(impurity=value)
+ @since("1.4.0")
+ def setCheckpointInterval(self, value):
+ """
+ Sets the value of :py:attr:`checkpointInterval`.
+ """
+ return self._set(checkpointInterval=value)
+
+ @since("1.4.0")
+ def setSeed(self, value):
+ """
+ Sets the value of :py:attr:`seed`.
+ """
+ return self._set(seed=value)
+
+ @since("3.0.0")
+ def setWeightCol(self, value):
+ """
+ Sets the value of :py:attr:`weightCol`.
+ """
+ return self._set(weightCol=value)
+
+ @since("2.0.0")
+ def setVarianceCol(self, value):
+ """
+ Sets the value of :py:attr:`varianceCol`.
+ """
+ return self._set(varianceCol=value)
+
@inherit_doc
class DecisionTreeRegressionModel(_DecisionTreeModel, _DecisionTreeRegressorParams,
@@ -792,6 +942,13 @@ class DecisionTreeRegressionModel(_DecisionTreeModel, _DecisionTreeRegressorPara
.. versionadded:: 1.4.0
"""
+ @since("3.0.0")
+ def setVarianceCol(self, value):
+ """
+ Sets the value of :py:attr:`varianceCol`.
+ """
+ return self._set(varianceCol=value)
+
@property
@since("2.0.0")
def featureImportances(self):
@@ -836,7 +993,9 @@ class RandomForestRegressor(JavaPredictor, _RandomForestRegressorParams, JavaMLW
>>> df = spark.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
- >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42)
+ >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2)
+ >>> rf.setSeed(42)
+ RandomForestRegressor...
>>> model = rf.fit(df)
>>> model.getSeed()
42
@@ -987,6 +1146,18 @@ def setFeatureSubsetStrategy(self, value):
"""
return self._set(featureSubsetStrategy=value)
+ def setCheckpointInterval(self, value):
+ """
+ Sets the value of :py:attr:`checkpointInterval`.
+ """
+ return self._set(checkpointInterval=value)
+
+ def setSeed(self, value):
+ """
+ Sets the value of :py:attr:`seed`.
+ """
+ return self._set(seed=value)
+
class RandomForestRegressionModel(_TreeEnsembleModel, _RandomForestRegressorParams,
JavaMLWritable, JavaMLReadable):
@@ -1052,7 +1223,11 @@ class GBTRegressor(JavaPredictor, _GBTRegressorParams, JavaMLWritable, JavaMLRea
>>> df = spark.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
- >>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42, leafCol="leafId")
+ >>> gbt = GBTRegressor(maxDepth=2, seed=42, leafCol="leafId")
+ >>> gbt.setMaxIter(5)
+ GBTRegressor...
+ >>> gbt.getMaxIter()
+ 5
>>> print(gbt.getImpurity())
variance
>>> print(gbt.getFeatureSubsetStrategy())
@@ -1152,36 +1327,42 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
def _create_model(self, java_model):
return GBTRegressionModel(java_model)
+ @since("1.4.0")
def setMaxDepth(self, value):
"""
Sets the value of :py:attr:`maxDepth`.
"""
return self._set(maxDepth=value)
+ @since("1.4.0")
def setMaxBins(self, value):
"""
Sets the value of :py:attr:`maxBins`.
"""
return self._set(maxBins=value)
+ @since("1.4.0")
def setMinInstancesPerNode(self, value):
"""
Sets the value of :py:attr:`minInstancesPerNode`.
"""
return self._set(minInstancesPerNode=value)
+ @since("1.4.0")
def setMinInfoGain(self, value):
"""
Sets the value of :py:attr:`minInfoGain`.
"""
return self._set(minInfoGain=value)
+ @since("1.4.0")
def setMaxMemoryInMB(self, value):
"""
Sets the value of :py:attr:`maxMemoryInMB`.
"""
return self._set(maxMemoryInMB=value)
+ @since("1.4.0")
def setCacheNodeIds(self, value):
"""
Sets the value of :py:attr:`cacheNodeIds`.
@@ -1223,6 +1404,34 @@ def setValidationIndicatorCol(self, value):
"""
return self._set(validationIndicatorCol=value)
+ @since("1.4.0")
+ def setMaxIter(self, value):
+ """
+ Sets the value of :py:attr:`maxIter`.
+ """
+ return self._set(maxIter=value)
+
+ @since("1.4.0")
+ def setCheckpointInterval(self, value):
+ """
+ Sets the value of :py:attr:`checkpointInterval`.
+ """
+ return self._set(checkpointInterval=value)
+
+ @since("1.4.0")
+ def setSeed(self, value):
+ """
+ Sets the value of :py:attr:`seed`.
+ """
+ return self._set(seed=value)
+
+ @since("1.4.0")
+ def setStepSize(self, value):
+ """
+ Sets the value of :py:attr:`stepSize`.
+ """
+ return self._set(stepSize=value)
+
class GBTRegressionModel(_TreeEnsembleModel, _GBTRegressorParams, JavaMLWritable, JavaMLReadable):
"""
@@ -1330,6 +1539,11 @@ class AFTSurvivalRegression(JavaEstimator, _AFTSurvivalRegressionParams,
... (1.0, Vectors.dense(1.0), 1.0),
... (1e-40, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"])
>>> aftsr = AFTSurvivalRegression()
+ >>> aftsr.setMaxIter(10)
+ AFTSurvivalRegression...
+ >>> aftsr.getMaxIter()
+ 10
+ >>> aftsr.clear(aftsr.maxIter)
>>> model = aftsr.fit(df)
>>> model.setFeaturesCol("features")
AFTSurvivalRegression...
@@ -1422,6 +1636,55 @@ def setQuantilesCol(self, value):
"""
return self._set(quantilesCol=value)
+ @since("1.6.0")
+ def setMaxIter(self, value):
+ """
+ Sets the value of :py:attr:`maxIter`.
+ """
+ return self._set(maxIter=value)
+
+ @since("1.6.0")
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
+ @since("1.6.0")
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
+ @since("1.6.0")
+ def setLabelCol(self, value):
+ """
+ Sets the value of :py:attr:`labelCol`.
+ """
+ return self._set(labelCol=value)
+
+ @since("1.6.0")
+ def setTol(self, value):
+ """
+ Sets the value of :py:attr:`tol`.
+ """
+ return self._set(tol=value)
+
+ @since("1.6.0")
+ def setFitIntercept(self, value):
+ """
+ Sets the value of :py:attr:`fitIntercept`.
+ """
+ return self._set(fitIntercept=value)
+
+ @since("2.1.0")
+ def setAggregationDepth(self, value):
+ """
+ Sets the value of :py:attr:`aggregationDepth`.
+ """
+ return self._set(aggregationDepth=value)
+
class AFTSurvivalRegressionModel(JavaModel, _AFTSurvivalRegressionParams,
JavaMLWritable, JavaMLReadable):
@@ -1431,6 +1694,34 @@ class AFTSurvivalRegressionModel(JavaModel, _AFTSurvivalRegressionParams,
.. versionadded:: 1.6.0
"""
+ @since("3.0.0")
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ return self._set(featuresCol=value)
+
+ @since("3.0.0")
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ return self._set(predictionCol=value)
+
+ @since("3.0.0")
+ def setQuantileProbabilities(self, value):
+ """
+ Sets the value of :py:attr:`quantileProbabilities`.
+ """
+ return self._set(quantileProbabilities=value)
+
+ @since("3.0.0")
+ def setQuantilesCol(self, value):
+ """
+ Sets the value of :py:attr:`quantilesCol`.
+ """
+ return self._set(quantilesCol=value)
+
@property
@since("2.0.0")
def coefficients(self):
@@ -1577,6 +1868,16 @@ class GeneralizedLinearRegression(JavaPredictor, _GeneralizedLinearRegressionPar
... (2.0, Vectors.dense(0.0, 0.0)),
... (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"])
>>> glr = GeneralizedLinearRegression(family="gaussian", link="identity", linkPredictionCol="p")
+ >>> glr.setRegParam(0.1)
+ GeneralizedLinearRegression...
+ >>> glr.getRegParam()
+ 0.1
+ >>> glr.clear(glr.regParam)
+ >>> glr.setMaxIter(10)
+ GeneralizedLinearRegression...
+ >>> glr.getMaxIter()
+ 10
+ >>> glr.clear(glr.maxIter)
>>> model = glr.fit(df)
>>> model.setFeaturesCol("features")
GeneralizedLinearRegression...
@@ -1690,6 +1991,48 @@ def setOffsetCol(self, value):
"""
return self._set(offsetCol=value)
+ @since("2.0.0")
+ def setMaxIter(self, value):
+ """
+ Sets the value of :py:attr:`maxIter`.
+ """
+ return self._set(maxIter=value)
+
+ @since("2.0.0")
+ def setRegParam(self, value):
+ """
+ Sets the value of :py:attr:`regParam`.
+ """
+ return self._set(regParam=value)
+
+ @since("2.0.0")
+ def setTol(self, value):
+ """
+ Sets the value of :py:attr:`tol`.
+ """
+ return self._set(tol=value)
+
+ @since("2.2.0")
+ def setFitIntercept(self, value):
+ """
+ Sets the value of :py:attr:`fitIntercept`.
+ """
+ return self._set(fitIntercept=value)
+
+ @since("2.0.0")
+ def setWeightCol(self, value):
+ """
+ Sets the value of :py:attr:`weightCol`.
+ """
+ return self._set(weightCol=value)
+
+ @since("2.0.0")
+ def setSolver(self, value):
+ """
+ Sets the value of :py:attr:`solver`.
+ """
+ return self._set(solver=value)
+
class GeneralizedLinearRegressionModel(JavaPredictionModel, _GeneralizedLinearRegressionParams,
JavaMLWritable, JavaMLReadable, HasTrainingSummary):
@@ -1699,6 +2042,13 @@ class GeneralizedLinearRegressionModel(JavaPredictionModel, _GeneralizedLinearRe
.. versionadded:: 2.0.0
"""
+ @since("3.0.0")
+ def setLinkPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`linkPredictionCol`.
+ """
+ return self._set(linkPredictionCol=value)
+
@property
@since("2.0.0")
def coefficients(self):
diff --git a/python/pyspark/ml/tests/test_param.py b/python/pyspark/ml/tests/test_param.py
index 4c7f01484dc21..75cd903b5d6d7 100644
--- a/python/pyspark/ml/tests/test_param.py
+++ b/python/pyspark/ml/tests/test_param.py
@@ -221,13 +221,6 @@ def test_params(self):
self.assertFalse(testParams.isSet(maxIter))
self.assertTrue(testParams.isDefined(maxIter))
self.assertEqual(testParams.getMaxIter(), 10)
- testParams.setMaxIter(100)
- self.assertTrue(testParams.isSet(maxIter))
- self.assertEqual(testParams.getMaxIter(), 100)
- testParams.clear(maxIter)
- self.assertFalse(testParams.isSet(maxIter))
- self.assertEqual(testParams.getMaxIter(), 10)
- testParams.setMaxIter(100)
self.assertTrue(testParams.hasParam(inputCol.name))
self.assertFalse(testParams.hasDefault(inputCol))
@@ -244,13 +237,12 @@ def test_params(self):
# Since the default is normally random, set it to a known number for debug str
testParams._setDefault(seed=41)
- testParams.setSeed(43)
self.assertEqual(
testParams.explainParams(),
"\n".join(["inputCol: input column name. (undefined)",
- "maxIter: max number of iterations (>= 0). (default: 10, current: 100)",
- "seed: random seed. (default: 41, current: 43)"]))
+ "maxIter: max number of iterations (>= 0). (default: 10)",
+ "seed: random seed. (default: 41)"]))
def test_clear_param(self):
df = self.spark.createDataFrame([(Vectors.dense([1.0]),), (Vectors.dense([2.0]),)], ["a"])
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 8052163acd00a..16c376296c20d 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -299,6 +299,24 @@ def setNumFolds(self, value):
"""
return self._set(numFolds=value)
+ def setSeed(self, value):
+ """
+ Sets the value of :py:attr:`seed`.
+ """
+ return self._set(seed=value)
+
+ def setParallelism(self, value):
+ """
+ Sets the value of :py:attr:`parallelism`.
+ """
+ return self._set(parallelism=value)
+
+ def setCollectSubModels(self, value):
+ """
+ Sets the value of :py:attr:`collectSubModels`.
+ """
+ return self._set(collectSubModels=value)
+
def _fit(self, dataset):
est = self.getOrDefault(self.estimator)
epm = self.getOrDefault(self.estimatorParamMaps)
@@ -643,6 +661,24 @@ def setTrainRatio(self, value):
"""
return self._set(trainRatio=value)
+ def setSeed(self, value):
+ """
+ Sets the value of :py:attr:`seed`.
+ """
+ return self._set(seed=value)
+
+ def setParallelism(self, value):
+ """
+ Sets the value of :py:attr:`parallelism`.
+ """
+ return self._set(parallelism=value)
+
+ def setCollectSubModels(self, value):
+ """
+ Sets the value of :py:attr:`collectSubModels`.
+ """
+ return self._set(collectSubModels=value)
+
def _fit(self, dataset):
est = self.getOrDefault(self.estimator)
epm = self.getOrDefault(self.estimatorParamMaps)
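
A hedged sketch of the new tuning setters (seed, parallelism, collectSubModels); the estimator, grid, and evaluator below are placeholders and nothing is fitted.

    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

    lr = LogisticRegression()
    grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 10]).build()
    cv = CrossValidator(estimator=lr, estimatorParamMaps=grid,
                        evaluator=BinaryClassificationEvaluator())
    # The new setters mirror the constructor keywords and return self.
    cv.setSeed(42).setParallelism(2).setCollectSubModels(True)
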
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 1edffaa4ca168..52ab86c0d88ee 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2535,6 +2535,20 @@ def func(s, iterator):
return f(iterator)
return PipelinedRDD(self.rdd, func, preservesPartitioning, isFromBarrier=True)
+ def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
+ """
+ .. note:: Experimental
+
+ Returns a new RDD by applying a function to each partition of the wrapped RDD,
+ while tracking the index of the original partition. All tasks are launched
+ together in a barrier stage.
+ The interface is the same as :func:`RDD.mapPartitionsWithIndex`.
+ Please see the API doc there.
+
+ .. versionadded:: 3.0.0
+ """
+ return PipelinedRDD(self.rdd, f, preservesPartitioning, isFromBarrier=True)
+
class PipelinedRDD(RDD):
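
Usage sketch for the new RDDBarrier.mapPartitionsWithIndex (not part of the patch); it mirrors the unit test added below and assumes an active SparkContext.

    from pyspark import SparkContext

    sc = SparkContext.getOrCreate()

    def f(index, iterator):
        # Each barrier task receives the index of its original partition.
        yield index

    rdd = sc.parallelize(range(12), 4)
    print(rdd.barrier().mapPartitionsWithIndex(f).collect())  # [0, 1, 2, 3]
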
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index f92face2d0573..18fd7de7ee547 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -788,7 +788,7 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options)
@since(1.4)
def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None,
- lineSep=None, encoding=None):
+ lineSep=None, encoding=None, ignoreNullFields=None):
"""Saves the content of the :class:`DataFrame` in JSON format
(`JSON Lines text format or newline-delimited JSON <http://jsonlines.org/>`_) at the
specified path.
@@ -817,13 +817,15 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm
the default UTF-8 charset will be used.
:param lineSep: defines the line separator that should be used for writing. If None is
set, it uses the default value, ``\\n``.
+ :param ignoreNullFields: Whether to ignore null fields when generating JSON objects.
+ If None is set, it uses the default value, ``true``.
>>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
"""
self.mode(mode)
self._set_opts(
compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat,
- lineSep=lineSep, encoding=encoding)
+ lineSep=lineSep, encoding=encoding, ignoreNullFields=ignoreNullFields)
self._jwrite.json(path)
@since(1.4)
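
A hedged sketch of the new ignoreNullFields writer option; the data and output path are placeholders.

    import os
    import tempfile
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, None), (2, "b")], ["id", "name"])

    # Keep null fields in the generated JSON instead of dropping them (the default is true).
    df.write.json(os.path.join(tempfile.mkdtemp(), "data"), ignoreNullFields=False)
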
diff --git a/python/pyspark/tests/test_rddbarrier.py b/python/pyspark/tests/test_rddbarrier.py
new file mode 100644
index 0000000000000..8534fb4abb876
--- /dev/null
+++ b/python/pyspark/tests/test_rddbarrier.py
@@ -0,0 +1,50 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pyspark.testing.utils import ReusedPySparkTestCase
+
+
+class RDDBarrierTests(ReusedPySparkTestCase):
+ def test_map_partitions(self):
+ """Test RDDBarrier.mapPartitions"""
+ rdd = self.sc.parallelize(range(12), 4)
+ self.assertFalse(rdd._is_barrier())
+
+ rdd1 = rdd.barrier().mapPartitions(lambda it: it)
+ self.assertTrue(rdd1._is_barrier())
+
+ def test_map_partitions_with_index(self):
+ """Test RDDBarrier.mapPartitionsWithIndex"""
+ rdd = self.sc.parallelize(range(12), 4)
+ self.assertFalse(rdd._is_barrier())
+
+ def f(index, iterator):
+ yield index
+ rdd1 = rdd.barrier().mapPartitionsWithIndex(f)
+ self.assertTrue(rdd1._is_barrier())
+ self.assertEqual(rdd1.collect(), [0, 1, 2, 3])
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.tests.test_rddbarrier import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/setup.py b/python/setup.py
index ee5c32683efae..ea672309703b6 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -230,6 +230,7 @@ def _supports_symlinks():
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
+ 'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: Implementation :: CPython',
'Programming Language :: Python :: Implementation :: PyPy']
)
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index 1839203e3b235..c97eb3c935be6 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -79,14 +79,18 @@ singleTableSchema
: colTypeList EOF
;
+singleInterval
+ : INTERVAL? (intervalValue intervalUnit)+ EOF
+ ;
+
statement
: query #statementDefault
| ctes? dmlStatementNoWith #dmlStatement
| USE NAMESPACE? multipartIdentifier #use
- | CREATE database (IF NOT EXISTS)? db=errorCapturingIdentifier
+ | CREATE (database | NAMESPACE) (IF NOT EXISTS)? multipartIdentifier
((COMMENT comment=STRING) |
locationSpec |
- (WITH DBPROPERTIES tablePropertyList))* #createDatabase
+ (WITH (DBPROPERTIES | PROPERTIES) tablePropertyList))* #createNamespace
| ALTER database db=errorCapturingIdentifier
SET DBPROPERTIES tablePropertyList #setDatabaseProperties
| ALTER database db=errorCapturingIdentifier
@@ -194,24 +198,24 @@ statement
('(' key=tablePropertyKey ')')? #showTblProperties
| SHOW COLUMNS (FROM | IN) tableIdentifier
((FROM | IN) db=errorCapturingIdentifier)? #showColumns
- | SHOW PARTITIONS tableIdentifier partitionSpec? #showPartitions
+ | SHOW PARTITIONS multipartIdentifier partitionSpec? #showPartitions
| SHOW identifier? FUNCTIONS
(LIKE? (qualifiedName | pattern=STRING))? #showFunctions
- | SHOW CREATE TABLE tableIdentifier #showCreateTable
+ | SHOW CREATE TABLE multipartIdentifier #showCreateTable
| (DESC | DESCRIBE) FUNCTION EXTENDED? describeFuncName #describeFunction
| (DESC | DESCRIBE) database EXTENDED? db=errorCapturingIdentifier #describeDatabase
| (DESC | DESCRIBE) TABLE? option=(EXTENDED | FORMATTED)?
- multipartIdentifier partitionSpec? describeColName? #describeTable
+ multipartIdentifier partitionSpec? describeColName? #describeTable
| (DESC | DESCRIBE) QUERY? query #describeQuery
- | REFRESH TABLE tableIdentifier #refreshTable
+ | REFRESH TABLE multipartIdentifier #refreshTable
| REFRESH (STRING | .*?) #refreshResource
- | CACHE LAZY? TABLE tableIdentifier
+ | CACHE LAZY? TABLE multipartIdentifier
(OPTIONS options=tablePropertyList)? (AS? query)? #cacheTable
- | UNCACHE TABLE (IF EXISTS)? tableIdentifier #uncacheTable
+ | UNCACHE TABLE (IF EXISTS)? multipartIdentifier #uncacheTable
| CLEAR CACHE #clearCache
| LOAD DATA LOCAL? INPATH path=STRING OVERWRITE? INTO TABLE
tableIdentifier partitionSpec? #loadData
- | TRUNCATE TABLE tableIdentifier partitionSpec? #truncateTable
+ | TRUNCATE TABLE multipartIdentifier partitionSpec? #truncateTable
| MSCK REPAIR TABLE multipartIdentifier #repairTable
| op=(ADD | LIST) identifier .*? #manageResource
| SET ROLE .*? #failNativeCommand
@@ -1039,6 +1043,7 @@ ansiNonReserved
| POSITION
| PRECEDING
| PRINCIPALS
+ | PROPERTIES
| PURGE
| QUERY
| RANGE
@@ -1299,6 +1304,7 @@ nonReserved
| PRECEDING
| PRIMARY
| PRINCIPALS
+ | PROPERTIES
| PURGE
| QUERY
| RANGE
@@ -1564,6 +1570,7 @@ POSITION: 'POSITION';
PRECEDING: 'PRECEDING';
PRIMARY: 'PRIMARY';
PRINCIPALS: 'PRINCIPALS';
+PROPERTIES: 'PROPERTIES';
PURGE: 'PURGE';
QUERY: 'QUERY';
RANGE: 'RANGE';
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b913a9618d6eb..21bf926af50d7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1697,6 +1697,8 @@ class Analyzer(
// Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.
case q: UnaryNode if q.childrenResolved =>
resolveSubQueries(q, q.children)
+ case j: Join if j.childrenResolved =>
+ resolveSubQueries(j, Seq(j, j.left, j.right))
case s: SupportsSubquery if s.childrenResolved =>
resolveSubQueries(s, s.children)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 6a5d938f0fdc6..d9dc9ebbcaf3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -601,10 +601,10 @@ trait CheckAnalysis extends PredicateHelper {
case inSubqueryOrExistsSubquery =>
plan match {
- case _: Filter | _: SupportsSubquery => // Ok
+ case _: Filter | _: SupportsSubquery | _: Join => // Ok
case _ =>
failAnalysis(s"IN/EXISTS predicate sub-queries can only be used in" +
- s" Filter and a few commands: $plan")
+ s" Filter/Join and a few commands: $plan")
}
}
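
A hedged sketch of the query shape this Analyzer/CheckAnalysis change is meant to allow, namely an EXISTS subquery inside a JOIN condition; table and column names are illustrative only.

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    spark.createDataFrame([(1,), (2,)], ["id"]).createOrReplaceTempView("t1")
    spark.createDataFrame([(1,), (3,)], ["id"]).createOrReplaceTempView("t2")
    spark.createDataFrame([(1,)], ["id"]).createOrReplaceTempView("t3")

    # EXISTS in the join condition is resolved by the new Join case above.
    spark.sql("""
        SELECT t1.id
        FROM t1 JOIN t2
          ON t1.id = t2.id AND EXISTS (SELECT 1 FROM t3 WHERE t3.id = t1.id)
    """).show()
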
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 04e8963944fda..019e1a08779e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -486,6 +486,7 @@ object FunctionRegistry {
expression[CurrentDatabase]("current_database"),
expression[CallMethodViaReflection]("reflect"),
expression[CallMethodViaReflection]("java_method"),
+ expression[Version]("version"),
// grouping sets
expression[Cube]("cube"),
@@ -527,6 +528,7 @@ object FunctionRegistry {
expression[BitwiseCount]("bit_count"),
expression[BitAndAgg]("bit_and"),
expression[BitOrAgg]("bit_or"),
+ expression[BitXorAgg]("bit_xor"),
// json
expression[StructsToJson]("to_json"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
index 13a79a82a3858..9803fda0678ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
@@ -137,6 +137,9 @@ class ResolveCatalogs(val catalogManager: CatalogManager)
writeOptions = c.options.filterKeys(_ != "path"),
ignoreIfExists = c.ifNotExists)
+ case RefreshTableStatement(NonSessionCatalog(catalog, tableName)) =>
+ RefreshTable(catalog.asTableCatalog, tableName.asIdentifier)
+
case c @ ReplaceTableStatement(
NonSessionCatalog(catalog, tableName), _, _, _, _, _, _, _, _, _) =>
ReplaceTable(
@@ -168,6 +171,13 @@ class ResolveCatalogs(val catalogManager: CatalogManager)
s"Can not specify catalog `${catalog.name}` for view ${viewName.quoted} " +
s"because view support in catalog has not been implemented yet")
+ case c @ CreateNamespaceStatement(NonSessionCatalog(catalog, nameParts), _, _) =>
+ CreateNamespace(
+ catalog.asNamespaceCatalog,
+ nameParts,
+ c.ifNotExists,
+ c.properties)
+
case ShowNamespacesStatement(Some(CatalogAndNamespace(catalog, namespace)), pattern) =>
ShowNamespaces(catalog.asNamespaceCatalog, namespace, pattern)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index d71f300dd26dd..862b2bb515a19 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -23,7 +23,7 @@ import java.util.Locale
import java.util.concurrent.TimeUnit._
import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.{InternalRow, WalkedTypePath}
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.UTF8StringBuilder
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper}
object Cast {
@@ -466,7 +466,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
// IntervalConverter
private[this] def castToInterval(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, s => CalendarInterval.fromString(s.toString))
+ buildCast[UTF8String](_, s => IntervalUtils.safeFromString(s.toString))
}
// LongConverter
@@ -1213,8 +1213,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
private[this] def castToIntervalCode(from: DataType): CastFunction = from match {
case StringType =>
+ val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
(c, evPrim, evNull) =>
- code"""$evPrim = CalendarInterval.fromString($c.toString());
+ code"""$evPrim = $util.safeFromString($c.toString());
if(${evPrim} == null) {
${evNull} = true;
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index eaaf94baac216..300f075d32763 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -127,12 +127,6 @@ object UnsafeProjection
InterpretedUnsafeProjection.createProjection(in)
}
- protected def toUnsafeExprs(exprs: Seq[Expression]): Seq[Expression] = {
- exprs.map(_ transform {
- case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
- })
- }
-
/**
* Returns an UnsafeProjection for given StructType.
*
@@ -153,7 +147,7 @@ object UnsafeProjection
* Returns an UnsafeProjection for given sequence of bound Expressions.
*/
def create(exprs: Seq[Expression]): UnsafeProjection = {
- createObject(toUnsafeExprs(exprs))
+ createObject(exprs)
}
def create(expr: Expression): UnsafeProjection = create(Seq(expr))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
index 9aae678deb4bc..b9ec933f31493 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
@@ -22,8 +22,8 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
case class TimeWindow(
timeColumn: Expression,
@@ -102,7 +102,7 @@ object TimeWindow {
* precision.
*/
private def getIntervalInMicroSeconds(interval: String): Long = {
- val cal = CalendarInterval.fromCaseInsensitiveString(interval)
+ val cal = IntervalUtils.fromString(interval)
if (cal.months > 0) {
throw new IllegalArgumentException(
s"Intervals greater than a month is not supported ($interval).")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala
index c7fdb15130c4f..b69b341b0ee3e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala
@@ -98,7 +98,7 @@ abstract class MaxMinBy extends DeclarativeAggregate {
> SELECT _FUNC_(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y);
b
""",
- since = "3.0")
+ since = "3.0.0")
case class MaxBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMinBy {
override protected def funcName: String = "max_by"
@@ -116,7 +116,7 @@ case class MaxBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMin
> SELECT _FUNC_(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y);
a
""",
- since = "3.0")
+ since = "3.0.0")
case class MinBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMinBy {
override protected def funcName: String = "min_by"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala
index 131fa2eb50555..b77c3bd9cbde4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala
@@ -17,20 +17,14 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BitwiseAnd, BitwiseOr, ExpectsInputTypes, Expression, ExpressionDescription, If, IsNull, Literal}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryArithmetic, BitwiseAnd, BitwiseOr, BitwiseXor, ExpectsInputTypes, Expression, ExpressionDescription, If, IsNull, Literal}
import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegralType}
-@ExpressionDescription(
- usage = "_FUNC_(expr) - Returns the bitwise AND of all non-null input values, or null if none.",
- examples = """
- Examples:
- > SELECT _FUNC_(col) FROM VALUES (3), (5) AS tab(col);
- 1
- """,
- since = "3.0.0")
-case class BitAndAgg(child: Expression) extends DeclarativeAggregate with ExpectsInputTypes {
+abstract class BitAggregate extends DeclarativeAggregate with ExpectsInputTypes {
- override def nodeName: String = "bit_and"
+ val child: Expression
+
+ def bitOperator(left: Expression, right: Expression): BinaryArithmetic
override def children: Seq[Expression] = child :: Nil
@@ -40,23 +34,40 @@ case class BitAndAgg(child: Expression) extends DeclarativeAggregate with Expect
override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)
- private lazy val bitAnd = AttributeReference("bit_and", child.dataType)()
-
- override lazy val aggBufferAttributes: Seq[AttributeReference] = bitAnd :: Nil
+ private lazy val bitAgg = AttributeReference(nodeName, child.dataType)()
override lazy val initialValues: Seq[Literal] = Literal.create(null, dataType) :: Nil
+ override lazy val aggBufferAttributes: Seq[AttributeReference] = bitAgg :: Nil
+
+ override lazy val evaluateExpression: AttributeReference = bitAgg
+
override lazy val updateExpressions: Seq[Expression] =
- If(IsNull(bitAnd),
+ If(IsNull(bitAgg),
child,
- If(IsNull(child), bitAnd, BitwiseAnd(bitAnd, child))) :: Nil
+ If(IsNull(child), bitAgg, bitOperator(bitAgg, child))) :: Nil
override lazy val mergeExpressions: Seq[Expression] =
- If(IsNull(bitAnd.left),
- bitAnd.right,
- If(IsNull(bitAnd.right), bitAnd.left, BitwiseAnd(bitAnd.left, bitAnd.right))) :: Nil
+ If(IsNull(bitAgg.left),
+ bitAgg.right,
+ If(IsNull(bitAgg.right), bitAgg.left, bitOperator(bitAgg.left, bitAgg.right))) :: Nil
+}
+
+@ExpressionDescription(
+ usage = "_FUNC_(expr) - Returns the bitwise AND of all non-null input values, or null if none.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(col) FROM VALUES (3), (5) AS tab(col);
+ 1
+ """,
+ since = "3.0.0")
+case class BitAndAgg(child: Expression) extends BitAggregate {
- override lazy val evaluateExpression: AttributeReference = bitAnd
+ override def nodeName: String = "bit_and"
+
+ override def bitOperator(left: Expression, right: Expression): BinaryArithmetic = {
+ BitwiseAnd(left, right)
+ }
}
@ExpressionDescription(
@@ -67,33 +78,28 @@ case class BitAndAgg(child: Expression) extends DeclarativeAggregate with Expect
7
""",
since = "3.0.0")
-case class BitOrAgg(child: Expression) extends DeclarativeAggregate with ExpectsInputTypes {
+case class BitOrAgg(child: Expression) extends BitAggregate {
override def nodeName: String = "bit_or"
- override def children: Seq[Expression] = child :: Nil
-
- override def nullable: Boolean = true
-
- override def dataType: DataType = child.dataType
-
- override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)
-
- private lazy val bitOr = AttributeReference("bit_or", child.dataType)()
-
- override lazy val aggBufferAttributes: Seq[AttributeReference] = bitOr :: Nil
-
- override lazy val initialValues: Seq[Literal] = Literal.create(null, dataType) :: Nil
+ override def bitOperator(left: Expression, right: Expression): BinaryArithmetic = {
+ BitwiseOr(left, right)
+ }
+}
- override lazy val updateExpressions: Seq[Expression] =
- If(IsNull(bitOr),
- child,
- If(IsNull(child), bitOr, BitwiseOr(bitOr, child))) :: Nil
+@ExpressionDescription(
+ usage = "_FUNC_(expr) - Returns the bitwise XOR of all non-null input values, or null if none.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(col) FROM VALUES (3), (5) AS tab(col);
+ 6
+ """,
+ since = "3.0.0")
+case class BitXorAgg(child: Expression) extends BitAggregate {
- override lazy val mergeExpressions: Seq[Expression] =
- If(IsNull(bitOr.left),
- bitOr.right,
- If(IsNull(bitOr.right), bitOr.left, BitwiseOr(bitOr.left, bitOr.right))) :: Nil
+ override def nodeName: String = "bit_xor"
- override lazy val evaluateExpression: AttributeReference = bitOr
+ override def bitOperator(left: Expression, right: Expression): BinaryArithmetic = {
+ BitwiseXor(left, right)
+ }
}
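
Quick check of the new bit_xor aggregate via SQL, mirroring the example in the ExpressionDescription above.

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    # Expect a single row containing 6 (binary 011 XOR 101 = 110).
    spark.sql("SELECT bit_xor(col) FROM VALUES (3), (5) AS tab(col)").show()
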
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index cae3c0528e136..3f722e8537c36 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -295,9 +295,20 @@ object CreateStruct extends FunctionBuilder {
}
/**
- * Common base class for both [[CreateNamedStruct]] and [[CreateNamedStructUnsafe]].
+ * Creates a struct with the given field names and values
+ *
+ * @param children Seq(name1, val1, name2, val2, ...)
*/
-trait CreateNamedStructLike extends Expression {
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_("a", 1, "b", 2, "c", 3);
+ {"a":1,"b":2,"c":3}
+ """)
+// scalastyle:on line.size.limit
+case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
lazy val (nameExprs, valExprs) = children.grouped(2).map {
case Seq(name, value) => (name, value)
}.toList.unzip
@@ -348,23 +359,6 @@ trait CreateNamedStructLike extends Expression {
override def eval(input: InternalRow): Any = {
InternalRow(valExprs.map(_.eval(input)): _*)
}
-}
-
-/**
- * Creates a struct with the given field names and values
- *
- * @param children Seq(name1, val1, name2, val2, ...)
- */
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.",
- examples = """
- Examples:
- > SELECT _FUNC_("a", 1, "b", 2, "c", 3);
- {"a":1,"b":2,"c":3}
- """)
-// scalastyle:on line.size.limit
-case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStructLike {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val rowClass = classOf[GenericInternalRow].getName
@@ -397,22 +391,6 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
override def prettyName: String = "named_struct"
}
-/**
- * Creates a struct with the given field names and values. This is a variant that returns
- * UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with
- * this expression automatically at runtime.
- *
- * @param children Seq(name1, val1, name2, val2, ...)
- */
-case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike {
- override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val eval = GenerateUnsafeProjection.createCode(ctx, valExprs)
- ExprCode(code = eval.code, isNull = FalseLiteral, value = eval.value)
- }
-
- override def prettyName: String = "named_struct_unsafe"
-}
-
/**
* Creates a map after splitting the input text into key/value pairs using delimiters
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 2af2b13ad77f5..b8c23a1f08912 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import java.util.UUID
-
+import org.apache.spark.{SPARK_REVISION, SPARK_VERSION_SHORT}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -164,3 +163,17 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta
override def freshCopy(): Uuid = Uuid(randomSeed)
}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """_FUNC_() - Returns the Spark version. The string contains 2 fields, the first being a release version and the second being a git revision.""",
+ since = "3.0.0")
+// scalastyle:on line.size.limit
+case class Version() extends LeafExpression with CodegenFallback {
+ override def nullable: Boolean = false
+ override def foldable: Boolean = true
+ override def dataType: DataType = StringType
+ override def eval(input: InternalRow): Any = {
+ UTF8String.fromString(SPARK_VERSION_SHORT + " " + SPARK_REVISION)
+ }
+}
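
A quick sanity call of the new version() SQL function added above; the exact output depends on the build.

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    # Returns one row such as "3.0.0 <git revision>".
    spark.sql("SELECT version()").show(truncate=False)
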
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index e7bfb77e46c26..4952540f1132d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -78,8 +78,8 @@ private[sql] class JSONOptions(
val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false)
// Whether to ignore null fields during json generating
- val ignoreNullFields = parameters.getOrElse("ignoreNullFields",
- SQLConf.get.jsonGeneratorIgnoreNullFields).toBoolean
+ val ignoreNullFields = parameters.get("ignoreNullFields").map(_.toBoolean)
+ .getOrElse(SQLConf.get.jsonGeneratorIgnoreNullFields)
// A language tag in IETF BCP 47 format
val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
index db7d6d3254bd2..1743565ccb6c1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
/**
- * Simplify redundant [[CreateNamedStructLike]], [[CreateArray]] and [[CreateMap]] expressions.
+ * Simplify redundant [[CreateNamedStruct]], [[CreateArray]] and [[CreateMap]] expressions.
*/
object SimplifyExtractValueOps extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -37,8 +37,8 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
case a: Aggregate => a
case p => p.transformExpressionsUp {
// Remove redundant field extraction.
- case GetStructField(createNamedStructLike: CreateNamedStructLike, ordinal, _) =>
- createNamedStructLike.valExprs(ordinal)
+ case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
+ createNamedStruct.valExprs(ordinal)
// Remove redundant array indexing.
case GetArrayStructFields(CreateArray(elems), field, ordinal, _, _) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
index b036092cf1fcc..ea01d9e63eef7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.optimizer
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window}
@@ -114,9 +114,6 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
case CreateNamedStruct(children) =>
CreateNamedStruct(children.map(normalize))
- case CreateNamedStructUnsafe(children) =>
- CreateNamedStructUnsafe(children.map(normalize))
-
case CreateArray(children) =>
CreateArray(children.map(normalize))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 0a6737ba42118..36ad796c08a38 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -227,8 +227,8 @@ object OptimizeIn extends Rule[LogicalPlan] {
if (newList.length == 1
// TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed,
// TODO: we exclude them in this rule.
- && !v.isInstanceOf[CreateNamedStructLike]
- && !newList.head.isInstanceOf[CreateNamedStructLike]) {
+ && !v.isInstanceOf[CreateNamedStruct]
+ && !newList.head.isInstanceOf[CreateNamedStruct]) {
EqualTo(v, newList.head)
} else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) {
val hSet = newList.map(e => e.eval(EmptyRow))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 8af7cf9ad8008..4fa479f083e10 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
+import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -100,6 +101,23 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
withOrigin(ctx)(StructType(visitColTypeList(ctx.colTypeList)))
}
+ override def visitSingleInterval(ctx: SingleIntervalContext): CalendarInterval = {
+ withOrigin(ctx) {
+ val units = ctx.intervalUnit().asScala.map {
+ u => normalizeInternalUnit(u.getText.toLowerCase(Locale.ROOT))
+ }.toArray
+ val values = ctx.intervalValue().asScala.map(getIntervalValue).toArray
+ try {
+ CalendarInterval.fromUnitStrings(units, values)
+ } catch {
+ case i: IllegalArgumentException =>
+ val e = new ParseException(i.getMessage, ctx)
+ e.setStackTrace(i.getStackTrace)
+ throw e
+ }
+ }
+ }
+
/* ********************************************************************************************
* Plan parsing
* ******************************************************************************************** */
@@ -1770,7 +1788,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
toLiteral(stringToTimestamp(_, zoneId), TimestampType)
case "INTERVAL" =>
val interval = try {
- CalendarInterval.fromCaseInsensitiveString(value)
+ IntervalUtils.fromString(value)
} catch {
case e: IllegalArgumentException =>
val ex = new ParseException("Cannot parse the INTERVAL value: " + value, ctx)
@@ -1930,15 +1948,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
*/
override def visitIntervalField(ctx: IntervalFieldContext): CalendarInterval = withOrigin(ctx) {
import ctx._
- val s = value.getText
+ val s = getIntervalValue(value)
try {
val unitText = unit.getText.toLowerCase(Locale.ROOT)
val interval = (unitText, Option(to).map(_.getText.toLowerCase(Locale.ROOT))) match {
- case (u, None) if u.endsWith("s") =>
- // Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/...
- CalendarInterval.fromSingleUnitString(u.substring(0, u.length - 1), s)
case (u, None) =>
- CalendarInterval.fromSingleUnitString(u, s)
+ CalendarInterval.fromUnitStrings(Array(normalizeInternalUnit(u)), Array(s))
case ("year", Some("month")) =>
CalendarInterval.fromYearMonthString(s)
case ("day", Some("hour")) =>
@@ -1967,6 +1982,19 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
}
}
+ private def getIntervalValue(value: IntervalValueContext): String = {
+ if (value.STRING() != null) {
+ string(value.STRING())
+ } else {
+ value.getText
+ }
+ }
+
+ // Handle plural forms, e.g.: yearS/monthS/weekS/dayS/hourS/minuteS/secondS/...
+ private def normalizeInternalUnit(s: String): String = {
+ if (s.endsWith("s")) s.substring(0, s.length - 1) else s
+ }
+
/* ********************************************************************************************
* DataType parsing
* ******************************************************************************************** */
@@ -2307,6 +2335,46 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
}
}
+ /**
+ * Create a [[CreateNamespaceStatement]] command.
+ *
+ * For example:
+ * {{{
+ * CREATE NAMESPACE [IF NOT EXISTS] ns1.ns2.ns3
+ * create_namespace_clauses;
+ *
+ * create_namespace_clauses (order insensitive):
+ * [COMMENT namespace_comment]
+ * [LOCATION path]
+ * [WITH PROPERTIES (key1=val1, key2=val2, ...)]
+ * }}}
+ */
+ override def visitCreateNamespace(ctx: CreateNamespaceContext): LogicalPlan = withOrigin(ctx) {
+ checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx)
+ checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx)
+ checkDuplicateClauses(ctx.PROPERTIES, "WITH PROPERTIES", ctx)
+ checkDuplicateClauses(ctx.DBPROPERTIES, "WITH DBPROPERTIES", ctx)
+
+ if (!ctx.PROPERTIES.isEmpty && !ctx.DBPROPERTIES.isEmpty) {
+ throw new ParseException("Either PROPERTIES or DBPROPERTIES is allowed.", ctx)
+ }
+
+ var properties = ctx.tablePropertyList.asScala.headOption
+ .map(visitPropertyKeyValues)
+ .getOrElse(Map.empty)
+ Option(ctx.comment).map(string).foreach {
+ properties += CreateNamespaceStatement.COMMENT_PROPERTY_KEY -> _
+ }
+ ctx.locationSpec.asScala.headOption.map(visitLocationSpec).foreach {
+ properties += CreateNamespaceStatement.LOCATION_PROPERTY_KEY -> _
+ }
+
+ CreateNamespaceStatement(
+ visitMultipartIdentifier(ctx.multipartIdentifier),
+ ctx.EXISTS != null,
+ properties)
+ }
+
/**
* Create a [[ShowNamespacesStatement]] command.
*/
@@ -2728,4 +2796,85 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
override def visitRepairTable(ctx: RepairTableContext): LogicalPlan = withOrigin(ctx) {
RepairTableStatement(visitMultipartIdentifier(ctx.multipartIdentifier()))
}
+
+ /**
+ * Creates a [[ShowCreateTableStatement]]
+ */
+ override def visitShowCreateTable(ctx: ShowCreateTableContext): LogicalPlan = withOrigin(ctx) {
+ ShowCreateTableStatement(visitMultipartIdentifier(ctx.multipartIdentifier()))
+ }
+
+ /**
+ * Create a [[CacheTableStatement]].
+ *
+ * For example:
+ * {{{
+ * CACHE [LAZY] TABLE multi_part_name
+ * [OPTIONS tablePropertyList] [[AS] query]
+ * }}}
+ */
+ override def visitCacheTable(ctx: CacheTableContext): LogicalPlan = withOrigin(ctx) {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+ val query = Option(ctx.query).map(plan)
+ val tableName = visitMultipartIdentifier(ctx.multipartIdentifier)
+ if (query.isDefined && tableName.length > 1) {
+ val catalogAndNamespace = tableName.init
+ throw new ParseException("It is not allowed to add catalog/namespace " +
+ s"prefix ${catalogAndNamespace.quoted} to " +
+ "the table name in CACHE TABLE AS SELECT", ctx)
+ }
+ val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty)
+ CacheTableStatement(tableName, query, ctx.LAZY != null, options)
+ }
+
+ /**
+ * Create an [[UncacheTableStatement]] logical plan.
+ */
+ override def visitUncacheTable(ctx: UncacheTableContext): LogicalPlan = withOrigin(ctx) {
+ UncacheTableStatement(visitMultipartIdentifier(ctx.multipartIdentifier), ctx.EXISTS != null)
+ }
+
+ /**
+ * Create a [[TruncateTableStatement]] command.
+ *
+ * For example:
+ * {{{
+ * TRUNCATE TABLE multi_part_name [PARTITION (partcol1=val1, partcol2=val2 ...)]
+ * }}}
+ */
+ override def visitTruncateTable(ctx: TruncateTableContext): LogicalPlan = withOrigin(ctx) {
+ TruncateTableStatement(
+ visitMultipartIdentifier(ctx.multipartIdentifier),
+ Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec))
+ }
+
+ /**
+ * A command for users to list the partition names of a table. If a partition spec is
+ * specified, only the partitions that match the spec are returned. Otherwise all
+ * partitions of the table are returned.
+ *
+ * This function creates a [[ShowPartitionsStatement]] logical plan
+ *
+ * The syntax of using this command in SQL is:
+ * {{{
+ * SHOW PARTITIONS multi_part_name [partition_spec];
+ * }}}
+ */
+ override def visitShowPartitions(ctx: ShowPartitionsContext): LogicalPlan = withOrigin(ctx) {
+ val table = visitMultipartIdentifier(ctx.multipartIdentifier)
+ val partitionKeys = Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec)
+ ShowPartitionsStatement(table, partitionKeys)
+ }
+
+ /**
+ * Create a [[RefreshTableStatement]].
+ *
+ * For example:
+ * {{{
+ * REFRESH TABLE multi_part_name
+ * }}}
+ */
+ override def visitRefreshTable(ctx: RefreshTableContext): LogicalPlan = withOrigin(ctx) {
+ RefreshTableStatement(visitMultipartIdentifier(ctx.multipartIdentifier()))
+ }
}
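A rough sketch of what the new table-maintenance visitors produce, again assuming the CatalystSqlParser entry point used by DDLParserSuite; the expected statements match the parser tests added later in this diff:

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.{CacheTableStatement, TruncateTableStatement}

// CACHE TABLE keeps the storage-level options as a plain string map.
val cached = CatalystSqlParser.parsePlan(
  "CACHE LAZY TABLE a.b.c OPTIONS('storageLevel' 'DISK_ONLY')")
assert(cached == CacheTableStatement(
  Seq("a", "b", "c"), None, isLazy = true, Map("storageLevel" -> "DISK_ONLY")))

// TRUNCATE TABLE carries an optional partition spec.
val truncated = CatalystSqlParser.parsePlan(
  "TRUNCATE TABLE a.b.c PARTITION (ds = '2017-06-10')")
assert(truncated ==
  TruncateTableStatement(Seq("a", "b", "c"), Some(Map("ds" -> "2017-06-10"))))
```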
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
index 85998e33140d0..b66cae7979416 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
@@ -29,11 +29,20 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.unsafe.types.CalendarInterval
/**
* Base SQL parsing infrastructure.
*/
-abstract class AbstractSqlParser extends ParserInterface with Logging {
+abstract class AbstractSqlParser(conf: SQLConf) extends ParserInterface with Logging {
+
+ /**
+   * Creates a [[CalendarInterval]] for a given SQL string. Throws a [[ParseException]] if the
+   * SQL string does not have a valid interval format.
+ */
+ def parseInterval(sqlText: String): CalendarInterval = parse(sqlText) { parser =>
+ astBuilder.visitSingleInterval(parser.singleInterval())
+ }
/** Creates/Resolves DataType for a given SQL string. */
override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
@@ -91,16 +100,16 @@ abstract class AbstractSqlParser extends ParserInterface with Logging {
val lexer = new SqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command)))
lexer.removeErrorListeners()
lexer.addErrorListener(ParseErrorListener)
- lexer.legacy_setops_precedence_enbled = SQLConf.get.setOpsPrecedenceEnforced
- lexer.ansi = SQLConf.get.ansiEnabled
+ lexer.legacy_setops_precedence_enbled = conf.setOpsPrecedenceEnforced
+ lexer.ansi = conf.ansiEnabled
val tokenStream = new CommonTokenStream(lexer)
val parser = new SqlBaseParser(tokenStream)
parser.addParseListener(PostProcessor)
parser.removeErrorListeners()
parser.addErrorListener(ParseErrorListener)
- parser.legacy_setops_precedence_enbled = SQLConf.get.setOpsPrecedenceEnforced
- parser.ansi = SQLConf.get.ansiEnabled
+ parser.legacy_setops_precedence_enbled = conf.setOpsPrecedenceEnforced
+ parser.ansi = conf.ansiEnabled
try {
try {
@@ -134,12 +143,12 @@ abstract class AbstractSqlParser extends ParserInterface with Logging {
/**
* Concrete SQL parser for Catalyst-only SQL statements.
*/
-class CatalystSqlParser(conf: SQLConf) extends AbstractSqlParser {
+class CatalystSqlParser(conf: SQLConf) extends AbstractSqlParser(conf) {
val astBuilder = new AstBuilder(conf)
}
/** For test-only. */
-object CatalystSqlParser extends AbstractSqlParser {
+object CatalystSqlParser extends AbstractSqlParser(SQLConf.get) {
val astBuilder = new AstBuilder(SQLConf.get)
}
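Usage sketch for the new parseInterval entry point: interval text now goes through the ANTLR grammar rather than the regex removed from CalendarInterval, and invalid input fails with a ParseException instead of returning null.

```scala
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
import org.apache.spark.unsafe.types.CalendarInterval

val i = CatalystSqlParser.parseInterval("interval 3 days 2 hours")
assert(i.months == 0)
assert(i.microseconds ==
  3 * CalendarInterval.MICROS_PER_DAY + 2 * CalendarInterval.MICROS_PER_HOUR)

// Invalid text surfaces as a ParseException rather than a silent null.
val failed = try { CatalystSqlParser.parseInterval("foo 1 day"); false } catch {
  case _: ParseException => true
}
assert(failed)
```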
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
index 72d5cbb7d9045..655e87fce4e26 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
@@ -282,6 +282,19 @@ case class InsertIntoStatement(
case class ShowTablesStatement(namespace: Option[Seq[String]], pattern: Option[String])
extends ParsedStatement
+/**
+ * A CREATE NAMESPACE statement, as parsed from SQL.
+ */
+case class CreateNamespaceStatement(
+ namespace: Seq[String],
+ ifNotExists: Boolean,
+ properties: Map[String, String]) extends ParsedStatement
+
+object CreateNamespaceStatement {
+ val COMMENT_PROPERTY_KEY: String = "comment"
+ val LOCATION_PROPERTY_KEY: String = "location"
+}
+
/**
* A SHOW NAMESPACES statement, as parsed from SQL.
*/
@@ -316,3 +329,43 @@ case class AnalyzeColumnStatement(
* A REPAIR TABLE statement, as parsed from SQL
*/
case class RepairTableStatement(tableName: Seq[String]) extends ParsedStatement
+
+/**
+ * A SHOW CREATE TABLE statement, as parsed from SQL.
+ */
+case class ShowCreateTableStatement(tableName: Seq[String]) extends ParsedStatement
+
+/**
+ * A CACHE TABLE statement, as parsed from SQL.
+ */
+case class CacheTableStatement(
+ tableName: Seq[String],
+ plan: Option[LogicalPlan],
+ isLazy: Boolean,
+ options: Map[String, String]) extends ParsedStatement
+
+/**
+ * An UNCACHE TABLE statement, as parsed from SQL.
+ */
+case class UncacheTableStatement(
+ tableName: Seq[String],
+ ifExists: Boolean) extends ParsedStatement
+
+/**
+ * A TRUNCATE TABLE statement, as parsed from SQL.
+ */
+case class TruncateTableStatement(
+ tableName: Seq[String],
+ partitionSpec: Option[TablePartitionSpec]) extends ParsedStatement
+
+/**
+ * A SHOW PARTITIONS statement, as parsed from SQL.
+ */
+case class ShowPartitionsStatement(
+ tableName: Seq[String],
+ partitionSpec: Option[TablePartitionSpec]) extends ParsedStatement
+
+/**
+ * A REFRESH TABLE statement, as parsed from SQL.
+ */
+case class RefreshTableStatement(tableName: Seq[String]) extends ParsedStatement
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index f89dfb1ec47d8..d80c1c034a867 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -237,6 +237,14 @@ case class ReplaceTableAsSelect(
}
}
+/**
+ * The logical plan of the CREATE NAMESPACE command that works for v2 catalogs.
+ */
+case class CreateNamespace(
+ catalog: SupportsNamespaces,
+ namespace: Seq[String],
+ ifNotExists: Boolean,
+ properties: Map[String, String]) extends Command
/**
* The logical plan of the SHOW NAMESPACES command that works for v2 catalogs.
@@ -340,3 +348,10 @@ case class SetCatalogAndNamespace(
catalogManager: CatalogManager,
catalogName: Option[String],
namespace: Option[Seq[String]]) extends Command
+
+/**
+ * The logical plan of the REFRESH TABLE command that works for v2 catalogs.
+ */
+case class RefreshTable(
+ catalog: TableCatalog,
+ ident: Identifier) extends Command
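RefreshTable is only the logical node; the execution side is outside this hunk. Conceptually (a hypothetical sketch, not the actual exec node) refreshing reduces to asking the v2 catalog to drop whatever it has cached for the identifier:

```scala
import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}

// Hypothetical helper: what a RefreshTable execution is expected to boil down to.
def refresh(catalog: TableCatalog, ident: Identifier): Unit = {
  catalog.invalidateTable(ident)
}
```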
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
index 78d188f81f628..14fd153e15f58 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -17,21 +17,24 @@
package org.apache.spark.sql.catalyst.util
+import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.CalendarInterval
object IntervalUtils {
- val MONTHS_PER_YEAR: Int = 12
- val MONTHS_PER_QUARTER: Byte = 3
- val YEARS_PER_MILLENNIUM: Int = 1000
- val YEARS_PER_CENTURY: Int = 100
- val YEARS_PER_DECADE: Int = 10
- val MICROS_PER_HOUR: Long = DateTimeUtils.MILLIS_PER_HOUR * DateTimeUtils.MICROS_PER_MILLIS
- val MICROS_PER_MINUTE: Long = DateTimeUtils.MILLIS_PER_MINUTE * DateTimeUtils.MICROS_PER_MILLIS
- val DAYS_PER_MONTH: Byte = 30
- val MICROS_PER_MONTH: Long = DAYS_PER_MONTH * DateTimeUtils.SECONDS_PER_DAY
+ final val MONTHS_PER_YEAR: Int = 12
+ final val MONTHS_PER_QUARTER: Byte = 3
+ final val YEARS_PER_MILLENNIUM: Int = 1000
+ final val YEARS_PER_CENTURY: Int = 100
+ final val YEARS_PER_DECADE: Int = 10
+ final val MICROS_PER_HOUR: Long =
+ DateTimeUtils.MILLIS_PER_HOUR * DateTimeUtils.MICROS_PER_MILLIS
+ final val MICROS_PER_MINUTE: Long =
+ DateTimeUtils.MILLIS_PER_MINUTE * DateTimeUtils.MICROS_PER_MILLIS
+ final val DAYS_PER_MONTH: Byte = 30
+ final val MICROS_PER_MONTH: Long = DAYS_PER_MONTH * DateTimeUtils.SECONDS_PER_DAY
/* 365.25 days per year assumes leap year every four years */
- val MICROS_PER_YEAR: Long = (36525L * DateTimeUtils.MICROS_PER_DAY) / 100
+ final val MICROS_PER_YEAR: Long = (36525L * DateTimeUtils.MICROS_PER_DAY) / 100
def getYears(interval: CalendarInterval): Int = {
interval.months / MONTHS_PER_YEAR
@@ -88,4 +91,32 @@ object IntervalUtils {
result += MICROS_PER_MONTH * (interval.months % MONTHS_PER_YEAR)
Decimal(result, 18, 6)
}
+
+ /**
+ * Converts a string to [[CalendarInterval]] case-insensitively.
+ *
+   * @throws IllegalArgumentException if the input string is not in a valid interval format.
+ */
+ def fromString(str: String): CalendarInterval = {
+ if (str == null) throw new IllegalArgumentException("Interval string cannot be null")
+ try {
+ CatalystSqlParser.parseInterval(str)
+ } catch {
+ case e: ParseException =>
+ val ex = new IllegalArgumentException(s"Invalid interval string: $str\n" + e.message)
+ ex.setStackTrace(e.getStackTrace)
+ throw ex
+ }
+ }
+
+ /**
+   * A safe version of `fromString`. It returns null for an invalid input string.
+ */
+ def safeFromString(str: String): CalendarInterval = {
+ try {
+ fromString(str)
+ } catch {
+ case _: IllegalArgumentException => null
+ }
+ }
}
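Quick usage sketch, mirroring the IntervalUtilsSuite cases added later in this diff: fromString now routes through the SQL parser (so units can appear in any order, with or without the `interval` prefix), while safeFromString keeps the old lenient null-on-failure contract.

```scala
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.unsafe.types.CalendarInterval

assert(IntervalUtils.fromString("1 day 1 year") ==
  new CalendarInterval(12, CalendarInterval.MICROS_PER_DAY))

// safeFromString swallows the IllegalArgumentException and returns null.
assert(IntervalUtils.safeFromString("foo 1 day") == null)
```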
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala
similarity index 100%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 4944099fcc0d8..a228d9f064a1e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -626,8 +626,8 @@ object SQLConf {
.createWithDefault("snappy")
val ORC_IMPLEMENTATION = buildConf("spark.sql.orc.impl")
- .doc("When native, use the native version of ORC support instead of the ORC library in Hive " +
- "1.2.1. It is 'hive' by default prior to Spark 2.4.")
+    .doc("When native, use the native version of ORC support instead of the ORC library in Hive. " +
+ "It is 'hive' by default prior to Spark 2.4.")
.internal()
.stringConf
.checkValues(Set("hive", "native"))
@@ -980,7 +980,9 @@ object SQLConf {
.createWithDefault(true)
val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes")
- .doc("The maximum number of bytes to pack into a single partition when reading files.")
+ .doc("The maximum number of bytes to pack into a single partition when reading files. " +
+ "This configuration is effective only when using file-based sources such as Parquet, JSON " +
+ "and ORC.")
.bytesConf(ByteUnit.BYTE)
.createWithDefault(128 * 1024 * 1024) // parquet.block.size
@@ -989,19 +991,24 @@ object SQLConf {
.doc("The estimated cost to open a file, measured by the number of bytes could be scanned in" +
" the same time. This is used when putting multiple files into a partition. It's better to" +
" over estimated, then the partitions with small files will be faster than partitions with" +
- " bigger files (which is scheduled first).")
+ " bigger files (which is scheduled first). This configuration is effective only when using" +
+ " file-based sources such as Parquet, JSON and ORC.")
.longConf
.createWithDefault(4 * 1024 * 1024)
val IGNORE_CORRUPT_FILES = buildConf("spark.sql.files.ignoreCorruptFiles")
.doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " +
- "encountering corrupted files and the contents that have been read will still be returned.")
+ "encountering corrupted files and the contents that have been read will still be returned. " +
+ "This configuration is effective only when using file-based sources such as Parquet, JSON " +
+ "and ORC.")
.booleanConf
.createWithDefault(false)
val IGNORE_MISSING_FILES = buildConf("spark.sql.files.ignoreMissingFiles")
.doc("Whether to ignore missing files. If true, the Spark jobs will continue to run when " +
- "encountering missing files and the contents that have been read will still be returned.")
+ "encountering missing files and the contents that have been read will still be returned. " +
+ "This configuration is effective only when using file-based sources such as Parquet, JSON " +
+ "and ORC.")
.booleanConf
.createWithDefault(false)
@@ -1189,9 +1196,11 @@ object SQLConf {
val JSON_GENERATOR_IGNORE_NULL_FIELDS =
buildConf("spark.sql.jsonGenerator.ignoreNullFields")
- .doc("If false, JacksonGenerator will generate null for null fields in Struct.")
- .stringConf
- .createWithDefault("true")
+ .doc("Whether to ignore null fields when generating JSON objects in JSON data source and " +
+ "JSON functions such as to_json. " +
+ "If false, it generates null for null fields in JSON objects.")
+ .booleanConf
+ .createWithDefault(true)
val FILE_SINK_LOG_DELETION = buildConf("spark.sql.streaming.fileSink.log.deletion")
.internal()
@@ -2385,7 +2394,7 @@ class SQLConf extends Serializable with Logging {
def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE)
- def jsonGeneratorIgnoreNullFields: String = getConf(SQLConf.JSON_GENERATOR_IGNORE_NULL_FIELDS)
+ def jsonGeneratorIgnoreNullFields: Boolean = getConf(SQLConf.JSON_GENERATOR_IGNORE_NULL_FIELDS)
def parallelFileListingInStatsComputation: Boolean =
getConf(SQLConf.PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION)
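Since spark.sql.jsonGenerator.ignoreNullFields is now a proper boolean conf, it can be set with a boolean value; a minimal illustration of the behaviour its doc describes, assuming a running SparkSession named `spark`:

```scala
import org.apache.spark.sql.functions.{lit, struct, to_json}

// With the flag disabled, null fields are kept in the generated JSON.
spark.conf.set("spark.sql.jsonGenerator.ignoreNullFields", false)
val df = spark.range(1)
  .select(to_json(struct(lit(null).cast("string").as("a"), lit(1).as("b"))).as("json"))
df.show(false)  // expected: {"a":null,"b":1}
```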
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala
index 74a8590b5eefe..5aa80e1a9bd7f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery}
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Project}
+import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.logical._
/**
* Unit tests for [[ResolveSubquery]].
@@ -29,8 +30,10 @@ class ResolveSubquerySuite extends AnalysisTest {
val a = 'a.int
val b = 'b.int
+ val c = 'c.int
val t1 = LocalRelation(a)
val t2 = LocalRelation(b)
+ val t3 = LocalRelation(c)
test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") {
val expr = Filter(
@@ -41,4 +44,13 @@ class ResolveSubquerySuite extends AnalysisTest {
assert(m.contains(
"Expressions referencing the outer query are not supported outside of WHERE/HAVING clauses"))
}
+
+ test("SPARK-29145 Support subquery in join condition") {
+ val expr = Join(t1,
+ t2,
+ Inner,
+ Some(InSubquery(Seq(a), ListQuery(Project(Seq(UnresolvedAttribute("c")), t3)))),
+ JoinHint.NONE)
+ assertAnalysisSuccess(expr)
+ }
}
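For context, the Join built in the SPARK-29145 test above corresponds to SQL of roughly this shape, assuming t1/t2/t3 are registered as temp views with columns a/b/c:

```scala
// IN-subquery in a join condition, previously rejected during analysis.
spark.sql("SELECT * FROM t1 JOIN t2 ON a IN (SELECT c FROM t3)")
```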
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 603073b40d7aa..e10aa60d52cf8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
+import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, IntervalUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
@@ -720,7 +720,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
Literal(Timestamp.valueOf("2018-01-02 00:00:00")),
- Literal(CalendarInterval.fromString("interval 12 hours"))),
+ Literal(IntervalUtils.fromString("interval 12 hours"))),
Seq(
Timestamp.valueOf("2018-01-01 00:00:00"),
Timestamp.valueOf("2018-01-01 12:00:00"),
@@ -729,7 +729,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
Literal(Timestamp.valueOf("2018-01-02 00:00:01")),
- Literal(CalendarInterval.fromString("interval 12 hours"))),
+ Literal(IntervalUtils.fromString("interval 12 hours"))),
Seq(
Timestamp.valueOf("2018-01-01 00:00:00"),
Timestamp.valueOf("2018-01-01 12:00:00"),
@@ -738,7 +738,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-01-02 00:00:00")),
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
- Literal(CalendarInterval.fromString("interval 12 hours").negate())),
+ Literal(IntervalUtils.fromString("interval 12 hours").negate())),
Seq(
Timestamp.valueOf("2018-01-02 00:00:00"),
Timestamp.valueOf("2018-01-01 12:00:00"),
@@ -747,7 +747,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-01-02 00:00:00")),
Literal(Timestamp.valueOf("2017-12-31 23:59:59")),
- Literal(CalendarInterval.fromString("interval 12 hours").negate())),
+ Literal(IntervalUtils.fromString("interval 12 hours").negate())),
Seq(
Timestamp.valueOf("2018-01-02 00:00:00"),
Timestamp.valueOf("2018-01-01 12:00:00"),
@@ -756,7 +756,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
Literal(Timestamp.valueOf("2018-03-01 00:00:00")),
- Literal(CalendarInterval.fromString("interval 1 month"))),
+ Literal(IntervalUtils.fromString("interval 1 month"))),
Seq(
Timestamp.valueOf("2018-01-01 00:00:00"),
Timestamp.valueOf("2018-02-01 00:00:00"),
@@ -765,7 +765,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-03-01 00:00:00")),
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
- Literal(CalendarInterval.fromString("interval 1 month").negate())),
+ Literal(IntervalUtils.fromString("interval 1 month").negate())),
Seq(
Timestamp.valueOf("2018-03-01 00:00:00"),
Timestamp.valueOf("2018-02-01 00:00:00"),
@@ -774,7 +774,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-03-03 00:00:00")),
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
- Literal(CalendarInterval.fromString("interval 1 month 1 day").negate())),
+ Literal(IntervalUtils.fromString("interval 1 month 1 day").negate())),
Seq(
Timestamp.valueOf("2018-03-03 00:00:00"),
Timestamp.valueOf("2018-02-02 00:00:00"),
@@ -783,7 +783,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-01-31 00:00:00")),
Literal(Timestamp.valueOf("2018-04-30 00:00:00")),
- Literal(CalendarInterval.fromString("interval 1 month"))),
+ Literal(IntervalUtils.fromString("interval 1 month"))),
Seq(
Timestamp.valueOf("2018-01-31 00:00:00"),
Timestamp.valueOf("2018-02-28 00:00:00"),
@@ -793,7 +793,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
Literal(Timestamp.valueOf("2018-03-01 00:00:00")),
- Literal(CalendarInterval.fromString("interval 1 month 1 second"))),
+ Literal(IntervalUtils.fromString("interval 1 month 1 second"))),
Seq(
Timestamp.valueOf("2018-01-01 00:00:00"),
Timestamp.valueOf("2018-02-01 00:00:01")))
@@ -801,7 +801,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
Literal(Timestamp.valueOf("2018-03-01 00:04:06")),
- Literal(CalendarInterval.fromString("interval 1 month 2 minutes 3 seconds"))),
+ Literal(IntervalUtils.fromString("interval 1 month 2 minutes 3 seconds"))),
Seq(
Timestamp.valueOf("2018-01-01 00:00:00"),
Timestamp.valueOf("2018-02-01 00:02:03"),
@@ -839,7 +839,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-03-25 01:30:00")),
Literal(Timestamp.valueOf("2018-03-25 03:30:00")),
- Literal(CalendarInterval.fromString("interval 30 minutes"))),
+ Literal(IntervalUtils.fromString("interval 30 minutes"))),
Seq(
Timestamp.valueOf("2018-03-25 01:30:00"),
Timestamp.valueOf("2018-03-25 03:00:00"),
@@ -849,7 +849,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-10-28 01:30:00")),
Literal(Timestamp.valueOf("2018-10-28 03:30:00")),
- Literal(CalendarInterval.fromString("interval 30 minutes"))),
+ Literal(IntervalUtils.fromString("interval 30 minutes"))),
Seq(
Timestamp.valueOf("2018-10-28 01:30:00"),
noDST(Timestamp.valueOf("2018-10-28 02:00:00")),
@@ -866,7 +866,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Date.valueOf("2018-01-01")),
Literal(Date.valueOf("2018-01-05")),
- Literal(CalendarInterval.fromString("interval 2 days"))),
+ Literal(IntervalUtils.fromString("interval 2 days"))),
Seq(
Date.valueOf("2018-01-01"),
Date.valueOf("2018-01-03"),
@@ -875,7 +875,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Date.valueOf("2018-01-01")),
Literal(Date.valueOf("2018-03-01")),
- Literal(CalendarInterval.fromString("interval 1 month"))),
+ Literal(IntervalUtils.fromString("interval 1 month"))),
Seq(
Date.valueOf("2018-01-01"),
Date.valueOf("2018-02-01"),
@@ -884,7 +884,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Date.valueOf("2018-01-31")),
Literal(Date.valueOf("2018-04-30")),
- Literal(CalendarInterval.fromString("interval 1 month"))),
+ Literal(IntervalUtils.fromString("interval 1 month"))),
Seq(
Date.valueOf("2018-01-31"),
Date.valueOf("2018-02-28"),
@@ -905,14 +905,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
new Sequence(
Literal(Date.valueOf("1970-01-02")),
Literal(Date.valueOf("1970-01-01")),
- Literal(CalendarInterval.fromString("interval 1 day"))),
+ Literal(IntervalUtils.fromString("interval 1 day"))),
EmptyRow, "sequence boundaries: 1 to 0 by 1")
checkExceptionInExpression[IllegalArgumentException](
new Sequence(
Literal(Date.valueOf("1970-01-01")),
Literal(Date.valueOf("1970-02-01")),
- Literal(CalendarInterval.fromString("interval 1 month").negate())),
+ Literal(IntervalUtils.fromString("interval 1 month").negate())),
EmptyRow,
s"sequence boundaries: 0 to 2678400000000 by -${28 * CalendarInterval.MICROS_PER_DAY}")
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 0c4438987cd2a..9039cd6451590 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -369,7 +369,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
val b = AttributeReference("b", IntegerType)()
checkMetadata(CreateStruct(Seq(a, b)))
checkMetadata(CreateNamedStruct(Seq("a", a, "b", b)))
- checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b)))
}
test("StringToMap") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index e893e863b3675..6abadd77bd41a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
-import org.apache.spark.sql.catalyst.util.{DateTimeUtils, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT
import org.apache.spark.sql.internal.SQLConf
@@ -1075,16 +1075,16 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(SubtractTimestamps(Literal(end), Literal(end)),
new CalendarInterval(0, 0))
checkEvaluation(SubtractTimestamps(Literal(end), Literal(Instant.EPOCH)),
- CalendarInterval.fromString("interval 18173 days " +
+ IntervalUtils.fromString("interval 18173 days " +
"11 hours 4 minutes 1 seconds 123 milliseconds 456 microseconds"))
checkEvaluation(SubtractTimestamps(Literal(Instant.EPOCH), Literal(end)),
- CalendarInterval.fromString("interval -18173 days " +
+ IntervalUtils.fromString("interval -18173 days " +
"-11 hours -4 minutes -1 seconds -123 milliseconds -456 microseconds"))
checkEvaluation(
SubtractTimestamps(
Literal(Instant.parse("9999-12-31T23:59:59.999999Z")),
Literal(Instant.parse("0001-01-01T00:00:00Z"))),
- CalendarInterval.fromString("interval 521722 weeks 4 days " +
+ IntervalUtils.fromString("interval 521722 weeks 4 days " +
"23 hours 59 minutes 59 seconds 999 milliseconds 999 microseconds"))
}
@@ -1093,18 +1093,18 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(SubtractDates(Literal(end), Literal(end)),
new CalendarInterval(0, 0))
checkEvaluation(SubtractDates(Literal(end.plusDays(1)), Literal(end)),
- CalendarInterval.fromString("interval 1 days"))
+ IntervalUtils.fromString("interval 1 days"))
checkEvaluation(SubtractDates(Literal(end.minusDays(1)), Literal(end)),
- CalendarInterval.fromString("interval -1 days"))
+ IntervalUtils.fromString("interval -1 days"))
val epochDate = Literal(LocalDate.ofEpochDay(0))
checkEvaluation(SubtractDates(Literal(end), epochDate),
- CalendarInterval.fromString("interval 49 years 9 months 4 days"))
+ IntervalUtils.fromString("interval 49 years 9 months 4 days"))
checkEvaluation(SubtractDates(epochDate, Literal(end)),
- CalendarInterval.fromString("interval -49 years -9 months -4 days"))
+ IntervalUtils.fromString("interval -49 years -9 months -4 days"))
checkEvaluation(
SubtractDates(
Literal(LocalDate.of(10000, 1, 1)),
Literal(LocalDate.of(1, 1, 1))),
- CalendarInterval.fromString("interval 9999 years"))
+ IntervalUtils.fromString("interval 9999 years"))
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
index f90c98be0b3fd..4b2da73abe562 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
@@ -30,9 +30,9 @@ import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData, IntervalUtils}
import org.apache.spark.sql.types.{ArrayType, StructType, _}
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.unsafe.types.UTF8String
class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val random = new scala.util.Random
@@ -252,7 +252,7 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("hive-hash for CalendarInterval type") {
def checkHiveHashForIntervalType(interval: String, expected: Long): Unit = {
- checkHiveHash(CalendarInterval.fromString(interval), CalendarIntervalType, expected)
+ checkHiveHash(IntervalUtils.fromString(interval), CalendarIntervalType, expected)
}
// ----- MICROSEC -----
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
index 078ec88800215..818ee239dbbf8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
@@ -20,12 +20,12 @@ package org.apache.spark.sql.catalyst.expressions
import scala.language.implicitConversions
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.types.Decimal
-import org.apache.spark.unsafe.types.CalendarInterval
class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
implicit def interval(s: String): Literal = {
- Literal(CalendarInterval.fromString("interval " + s))
+ Literal(IntervalUtils.fromString("interval " + s))
}
test("millenniums") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
index 0d594eb10962e..23ba9c6ec7388 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
@@ -20,9 +20,9 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -56,7 +56,7 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper {
testBothCodegenAndInterpreted("variable-length types") {
val proj = createMutableProjection(variableLengthTypes)
- val scalaValues = Seq("abc", BigDecimal(10), CalendarInterval.fromString("interval 1 day"),
+ val scalaValues = Seq("abc", BigDecimal(10), IntervalUtils.fromString("interval 1 day"),
Array[Byte](1, 2), Array("123", "456"), Map(1 -> "a", 2 -> "b"), Row(1, "a"),
new java.lang.Integer(5))
val inputRow = InternalRow.fromSeq(scalaValues.zip(variableLengthTypes).map {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index a171885471a36..4ccd4f7ce798d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -35,8 +35,7 @@ import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData, IntervalUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -486,7 +485,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
("abcd".getBytes, BinaryType),
("abcd", StringType),
(BigDecimal.valueOf(10), DecimalType.IntDecimal),
- (CalendarInterval.fromString("interval 3 day"), CalendarIntervalType),
+ (IntervalUtils.fromString("interval 3 day"), CalendarIntervalType),
(java.math.BigDecimal.valueOf(10), DecimalType.BigIntDecimal),
(Array(3, 2, 1), ArrayType(IntegerType))
).foreach { case (input, dt) =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 323a3a901689f..20e77254ecdad 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -531,7 +531,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB
// Simple tests
val inputRow = InternalRow.fromSeq(Seq(
false, 3.toByte, 15.toShort, -83, 129L, 1.0f, 8.0, UTF8String.fromString("test"),
- Decimal(255), CalendarInterval.fromString("interval 1 day"), Array[Byte](1, 2)
+ Decimal(255), IntervalUtils.fromString("interval 1 day"), Array[Byte](1, 2)
))
val fields1 = Array(
BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
index 0eaf74f655065..da01c612b350e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
@@ -845,6 +845,90 @@ class DDLParserSuite extends AnalysisTest {
ShowTablesStatement(Some(Seq("tbl")), Some("*dog*")))
}
+ test("create namespace -- backward compatibility with DATABASE/DBPROPERTIES") {
+ val expected = CreateNamespaceStatement(
+ Seq("a", "b", "c"),
+ ifNotExists = true,
+ Map(
+ "a" -> "a",
+ "b" -> "b",
+ "c" -> "c",
+ "comment" -> "namespace_comment",
+ "location" -> "/home/user/db"))
+
+ comparePlans(
+ parsePlan(
+ """
+ |CREATE NAMESPACE IF NOT EXISTS a.b.c
+ |WITH PROPERTIES ('a'='a', 'b'='b', 'c'='c')
+ |COMMENT 'namespace_comment' LOCATION '/home/user/db'
+ """.stripMargin),
+ expected)
+
+ comparePlans(
+ parsePlan(
+ """
+ |CREATE DATABASE IF NOT EXISTS a.b.c
+ |WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')
+ |COMMENT 'namespace_comment' LOCATION '/home/user/db'
+ """.stripMargin),
+ expected)
+ }
+
+ test("create namespace -- check duplicates") {
+ def createDatabase(duplicateClause: String): String = {
+ s"""
+ |CREATE NAMESPACE IF NOT EXISTS a.b.c
+ |$duplicateClause
+ |$duplicateClause
+ """.stripMargin
+ }
+ val sql1 = createDatabase("COMMENT 'namespace_comment'")
+ val sql2 = createDatabase("LOCATION '/home/user/db'")
+ val sql3 = createDatabase("WITH PROPERTIES ('a'='a', 'b'='b', 'c'='c')")
+ val sql4 = createDatabase("WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')")
+
+ intercept(sql1, "Found duplicate clauses: COMMENT")
+ intercept(sql2, "Found duplicate clauses: LOCATION")
+ intercept(sql3, "Found duplicate clauses: WITH PROPERTIES")
+ intercept(sql4, "Found duplicate clauses: WITH DBPROPERTIES")
+ }
+
+ test("create namespace - property values must be set") {
+ assertUnsupported(
+ sql = "CREATE NAMESPACE a.b.c WITH PROPERTIES('key_without_value', 'key_with_value'='x')",
+ containsThesePhrases = Seq("key_without_value"))
+ }
+
+ test("create namespace -- either PROPERTIES or DBPROPERTIES is allowed") {
+ val sql =
+ s"""
+ |CREATE NAMESPACE IF NOT EXISTS a.b.c
+ |WITH PROPERTIES ('a'='a', 'b'='b', 'c'='c')
+ |WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')
+ """.stripMargin
+ intercept(sql, "Either PROPERTIES or DBPROPERTIES is allowed")
+ }
+
+ test("create namespace - support for other types in PROPERTIES") {
+ val sql =
+ """
+ |CREATE NAMESPACE a.b.c
+ |LOCATION '/home/user/db'
+ |WITH PROPERTIES ('a'=1, 'b'=0.1, 'c'=TRUE)
+ """.stripMargin
+ comparePlans(
+ parsePlan(sql),
+ CreateNamespaceStatement(
+ Seq("a", "b", "c"),
+ ifNotExists = false,
+ Map(
+ "a" -> "1",
+ "b" -> "0.1",
+ "c" -> "true",
+ "location" -> "/home/user/db")))
+ }
+
test("show databases: basic") {
comparePlans(
parsePlan("SHOW DATABASES"),
@@ -955,12 +1039,87 @@ class DDLParserSuite extends AnalysisTest {
"missing 'COLUMNS' at ''")
}
- test("MSCK REPAIR table") {
+ test("MSCK REPAIR TABLE") {
comparePlans(
parsePlan("MSCK REPAIR TABLE a.b.c"),
RepairTableStatement(Seq("a", "b", "c")))
}
+ test("SHOW CREATE table") {
+ comparePlans(
+ parsePlan("SHOW CREATE TABLE a.b.c"),
+ ShowCreateTableStatement(Seq("a", "b", "c")))
+ }
+
+ test("CACHE TABLE") {
+ comparePlans(
+ parsePlan("CACHE TABLE a.b.c"),
+ CacheTableStatement(Seq("a", "b", "c"), None, false, Map.empty))
+
+ comparePlans(
+ parsePlan("CACHE LAZY TABLE a.b.c"),
+ CacheTableStatement(Seq("a", "b", "c"), None, true, Map.empty))
+
+ comparePlans(
+ parsePlan("CACHE LAZY TABLE a.b.c OPTIONS('storageLevel' 'DISK_ONLY')"),
+ CacheTableStatement(Seq("a", "b", "c"), None, true, Map("storageLevel" -> "DISK_ONLY")))
+
+ intercept("CACHE TABLE a.b.c AS SELECT * FROM testData",
+ "It is not allowed to add catalog/namespace prefix a.b")
+ }
+
+ test("UNCACHE TABLE") {
+ comparePlans(
+ parsePlan("UNCACHE TABLE a.b.c"),
+ UncacheTableStatement(Seq("a", "b", "c"), ifExists = false))
+
+ comparePlans(
+ parsePlan("UNCACHE TABLE IF EXISTS a.b.c"),
+ UncacheTableStatement(Seq("a", "b", "c"), ifExists = true))
+ }
+
+ test("TRUNCATE table") {
+ comparePlans(
+ parsePlan("TRUNCATE TABLE a.b.c"),
+ TruncateTableStatement(Seq("a", "b", "c"), None))
+
+ comparePlans(
+ parsePlan("TRUNCATE TABLE a.b.c PARTITION(ds='2017-06-10')"),
+ TruncateTableStatement(Seq("a", "b", "c"), Some(Map("ds" -> "2017-06-10"))))
+ }
+
+ test("SHOW PARTITIONS") {
+ val sql1 = "SHOW PARTITIONS t1"
+ val sql2 = "SHOW PARTITIONS db1.t1"
+ val sql3 = "SHOW PARTITIONS t1 PARTITION(partcol1='partvalue', partcol2='partvalue')"
+ val sql4 = "SHOW PARTITIONS a.b.c"
+ val sql5 = "SHOW PARTITIONS a.b.c PARTITION(ds='2017-06-10')"
+
+ val parsed1 = parsePlan(sql1)
+ val expected1 = ShowPartitionsStatement(Seq("t1"), None)
+ val parsed2 = parsePlan(sql2)
+ val expected2 = ShowPartitionsStatement(Seq("db1", "t1"), None)
+ val parsed3 = parsePlan(sql3)
+ val expected3 = ShowPartitionsStatement(Seq("t1"),
+ Some(Map("partcol1" -> "partvalue", "partcol2" -> "partvalue")))
+ val parsed4 = parsePlan(sql4)
+ val expected4 = ShowPartitionsStatement(Seq("a", "b", "c"), None)
+ val parsed5 = parsePlan(sql5)
+ val expected5 = ShowPartitionsStatement(Seq("a", "b", "c"), Some(Map("ds" -> "2017-06-10")))
+
+ comparePlans(parsed1, expected1)
+ comparePlans(parsed2, expected2)
+ comparePlans(parsed3, expected3)
+ comparePlans(parsed4, expected4)
+ comparePlans(parsed5, expected5)
+ }
+
+ test("REFRESH TABLE") {
+ comparePlans(
+ parsePlan("REFRESH TABLE a.b.c"),
+ RefreshTableStatement(Seq("a", "b", "c")))
+ }
+
private case class TableSpec(
name: Seq[String],
schema: Option[StructType],
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index e6eabcc1f3022..86b3aa8190b45 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
-import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
+import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, IntervalUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@@ -432,7 +432,7 @@ class ExpressionParserSuite extends AnalysisTest {
intercept("timestamP '2016-33-11 20:54:00.000'", "Cannot parse the TIMESTAMP value")
// Interval.
- val intervalLiteral = Literal(CalendarInterval.fromString("interval 3 month 1 hour"))
+ val intervalLiteral = Literal(IntervalUtils.fromString("interval 3 month 1 hour"))
assertEqual("InterVal 'interval 3 month 1 hour'", intervalLiteral)
assertEqual("INTERVAL '3 month 1 hour'", intervalLiteral)
intercept("Interval 'interval 3 monthsss 1 hoursss'", "Cannot parse the INTERVAL value")
@@ -597,7 +597,7 @@ class ExpressionParserSuite extends AnalysisTest {
"microsecond")
def intervalLiteral(u: String, s: String): Literal = {
- Literal(CalendarInterval.fromSingleUnitString(u, s))
+ Literal(CalendarInterval.fromUnitStrings(Array(u), Array(s)))
}
test("intervals") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala
index 07f77ea889dba..c6434f2bdd3ec 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala
@@ -50,7 +50,7 @@ class ParserUtilsSuite extends SparkFunSuite {
|WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')
""".stripMargin
) { parser =>
- parser.statement().asInstanceOf[CreateDatabaseContext]
+ parser.statement().asInstanceOf[CreateNamespaceContext]
}
val emptyContext = buildContext("") { parser =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
index 1abbca6c8cd29..0eaf538231284 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
@@ -555,12 +555,12 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers {
// There are some days are skipped entirely in some timezone, skip them here.
val skipped_days = Map[String, Set[Int]](
- "Kwajalein" -> Set(8632, 8633),
+ "Kwajalein" -> Set(8632, 8633, 8634),
"Pacific/Apia" -> Set(15338),
"Pacific/Enderbury" -> Set(9130, 9131),
"Pacific/Fakaofo" -> Set(15338),
"Pacific/Kiritimati" -> Set(9130, 9131),
- "Pacific/Kwajalein" -> Set(8632, 8633),
+ "Pacific/Kwajalein" -> Set(8632, 8633, 8634),
"MIT" -> Set(15338))
for (tz <- ALL_TIMEZONES) {
val skipped = skipped_days.getOrElse(tz.getID, Set.empty)
@@ -586,12 +586,15 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers {
val now = instantToMicros(LocalDateTime.now(zoneId).atZone(zoneId).toInstant)
toTimestamp("NOW", zoneId).get should be (now +- tolerance)
assert(toTimestamp("now UTC", zoneId) === None)
- val today = instantToMicros(LocalDateTime.now(zoneId)
+ val localToday = LocalDateTime.now(zoneId)
.`with`(LocalTime.MIDNIGHT)
- .atZone(zoneId).toInstant)
- toTimestamp(" Yesterday", zoneId).get should be (today - MICROS_PER_DAY +- tolerance)
+ .atZone(zoneId)
+ val yesterday = instantToMicros(localToday.minusDays(1).toInstant)
+ toTimestamp(" Yesterday", zoneId).get should be (yesterday +- tolerance)
+ val today = instantToMicros(localToday.toInstant)
toTimestamp("Today ", zoneId).get should be (today +- tolerance)
- toTimestamp(" tomorrow CET ", zoneId).get should be (today + MICROS_PER_DAY +- tolerance)
+ val tomorrow = instantToMicros(localToday.plusDays(1).toInstant)
+ toTimestamp(" tomorrow CET ", zoneId).get should be (tomorrow +- tolerance)
}
}
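The switch from `today +- MICROS_PER_DAY` to `minusDays`/`plusDays` matters because a calendar day is not always 24 hours in a zoned calendar; a quick standalone illustration (Europe/Paris spring-forward, chosen only as an example):

```scala
import java.time.{LocalDate, ZoneId}
import java.time.temporal.ChronoUnit

val zone = ZoneId.of("Europe/Paris")
val midnight = LocalDate.of(2019, 3, 31).atStartOfDay(zone)  // DST starts on this day
val hours = ChronoUnit.HOURS.between(midnight, midnight.plusDays(1))
assert(hours == 23)  // "tomorrow" is 23 hours away, not MICROS_PER_DAY microseconds
```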
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
new file mode 100644
index 0000000000000..e48779af3c9aa
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.IntervalUtils.fromString
+import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.unsafe.types.CalendarInterval._
+
+class IntervalUtilsSuite extends SparkFunSuite {
+
+ test("fromString: basic") {
+ testSingleUnit("YEAR", 3, 36, 0)
+ testSingleUnit("Month", 3, 3, 0)
+ testSingleUnit("Week", 3, 0, 3 * MICROS_PER_WEEK)
+ testSingleUnit("DAY", 3, 0, 3 * MICROS_PER_DAY)
+ testSingleUnit("HouR", 3, 0, 3 * MICROS_PER_HOUR)
+ testSingleUnit("MiNuTe", 3, 0, 3 * MICROS_PER_MINUTE)
+ testSingleUnit("Second", 3, 0, 3 * MICROS_PER_SECOND)
+ testSingleUnit("MilliSecond", 3, 0, 3 * MICROS_PER_MILLI)
+ testSingleUnit("MicroSecond", 3, 0, 3)
+
+ for (input <- Seq(null, "", " ")) {
+ try {
+ fromString(input)
+ fail("Expected to throw an exception for the invalid input")
+ } catch {
+ case e: IllegalArgumentException =>
+ val msg = e.getMessage
+ if (input == null) {
+ assert(msg.contains("cannot be null"))
+ }
+ }
+ }
+
+ for (input <- Seq("interval", "interval1 day", "foo", "foo 1 day")) {
+ try {
+ fromString(input)
+ fail("Expected to throw an exception for the invalid input")
+ } catch {
+ case e: IllegalArgumentException =>
+ val msg = e.getMessage
+ assert(msg.contains("Invalid interval string"))
+ }
+ }
+ }
+
+ test("fromString: random order field") {
+ val input = "1 day 1 year"
+ val result = new CalendarInterval(12, MICROS_PER_DAY)
+ assert(fromString(input) == result)
+ }
+
+ test("fromString: duplicated fields") {
+ val input = "1 day 1 day"
+ val result = new CalendarInterval(0, 2 * MICROS_PER_DAY)
+ assert(fromString(input) == result)
+ }
+
+ test("fromString: value with +/-") {
+ val input = "+1 year -1 day"
+ val result = new CalendarInterval(12, -MICROS_PER_DAY)
+ assert(fromString(input) == result)
+ }
+
+ private def testSingleUnit(unit: String, number: Int, months: Int, microseconds: Long): Unit = {
+ for (prefix <- Seq("interval ", "")) {
+ val input1 = prefix + number + " " + unit
+ val input2 = prefix + number + " " + unit + "s"
+ val result = new CalendarInterval(months, microseconds)
+ assert(fromString(input1) == result)
+ assert(fromString(input2) == result)
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTableCatalog.scala
index 8724a38d08d1f..ece903a4c2838 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTableCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTableCatalog.scala
@@ -34,6 +34,8 @@ class BasicInMemoryTableCatalog extends TableCatalog {
protected val tables: util.Map[Identifier, InMemoryTable] =
new ConcurrentHashMap[Identifier, InMemoryTable]()
+ private val invalidatedTables: util.Set[Identifier] = ConcurrentHashMap.newKeySet()
+
private var _name: Option[String] = None
override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {
@@ -55,6 +57,10 @@ class BasicInMemoryTableCatalog extends TableCatalog {
}
}
+ override def invalidateTable(ident: Identifier): Unit = {
+ invalidatedTables.add(ident)
+ }
+
override def createTable(
ident: Identifier,
schema: StructType,
@@ -104,6 +110,10 @@ class BasicInMemoryTableCatalog extends TableCatalog {
}
}
+ def isTableInvalidated(ident: Identifier): Boolean = {
+ invalidatedTables.contains(ident)
+ }
+
def clearTables(): Unit = {
tables.clear()
}
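A hypothetical test fragment showing how the new hooks fit together: whatever executes RefreshTable should hand the identifier to invalidateTable, which isTableInvalidated can then observe.

```scala
import org.apache.spark.sql.connector.catalog.Identifier

// Uses the BasicInMemoryTableCatalog defined in this file.
val catalog = new BasicInMemoryTableCatalog
val ident = Identifier.of(Array("ns"), "tbl")
catalog.invalidateTable(ident)           // what a REFRESH TABLE execution is expected to do
assert(catalog.isTableInvalidated(ident))
```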
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala
index 170daa6277c49..84581c0badd86 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.Matchers
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, TimestampFormatter}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, instantToMicros, MICROS_PER_DAY}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, instantToMicros}
import org.apache.spark.sql.internal.SQLConf
class TimestampFormatterSuite extends SparkFunSuite with SQLHelper with Matchers {
@@ -146,12 +146,15 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper with Matchers
assert(formatter.parse("EPOCH") === 0)
val now = instantToMicros(LocalDateTime.now(zoneId).atZone(zoneId).toInstant)
formatter.parse("now") should be (now +- tolerance)
- val today = instantToMicros(LocalDateTime.now(zoneId)
+ val localToday = LocalDateTime.now(zoneId)
.`with`(LocalTime.MIDNIGHT)
- .atZone(zoneId).toInstant)
- formatter.parse("yesterday CET") should be (today - MICROS_PER_DAY +- tolerance)
+ .atZone(zoneId)
+ val yesterday = instantToMicros(localToday.minusDays(1).toInstant)
+ formatter.parse("yesterday CET") should be (yesterday +- tolerance)
+ val today = instantToMicros(localToday.toInstant)
formatter.parse(" TODAY ") should be (today +- tolerance)
- formatter.parse("Tomorrow ") should be (today + MICROS_PER_DAY +- tolerance)
+ val tomorrow = instantToMicros(localToday.plusDays(1).toInstant)
+ formatter.parse("Tomorrow ") should be (tomorrow +- tolerance)
}
}
}
diff --git a/sql/core/benchmarks/IntervalBenchmark-jdk11-results.txt b/sql/core/benchmarks/IntervalBenchmark-jdk11-results.txt
index 2a3903200a8ac..221ac42022a15 100644
--- a/sql/core/benchmarks/IntervalBenchmark-jdk11-results.txt
+++ b/sql/core/benchmarks/IntervalBenchmark-jdk11-results.txt
@@ -1,25 +1,25 @@
-OpenJDK 64-Bit Server VM 11.0.2+9 on Mac OS X 10.15
-Intel(R) Core(TM) i7-4850HQ CPU @ 2.30GHz
+OpenJDK 64-Bit Server VM 11.0.4+11-post-Ubuntu-1ubuntu218.04.3 on Linux 4.15.0-1044-aws
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
cast strings to intervals: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-string w/ interval 471 513 57 2.1 470.7 1.0X
-string w/o interval 437 444 8 2.3 436.9 1.1X
-1 units w/ interval 726 758 45 1.4 726.3 0.6X
-1 units w/o interval 712 717 5 1.4 711.7 0.7X
-2 units w/ interval 926 935 12 1.1 925.9 0.5X
-2 units w/o interval 943 947 3 1.1 943.4 0.5X
-3 units w/ interval 1089 1116 31 0.9 1089.0 0.4X
-3 units w/o interval 1105 1108 3 0.9 1105.1 0.4X
-4 units w/ interval 1260 1261 1 0.8 1260.4 0.4X
-4 units w/o interval 1276 1277 1 0.8 1275.9 0.4X
-5 units w/ interval 1436 1445 11 0.7 1435.6 0.3X
-5 units w/o interval 1455 1463 6 0.7 1455.5 0.3X
-6 units w/ interval 1634 1639 4 0.6 1634.4 0.3X
-6 units w/o interval 1642 1644 3 0.6 1641.7 0.3X
-7 units w/ interval 1829 1838 8 0.5 1828.6 0.3X
-7 units w/o interval 1850 1853 4 0.5 1849.5 0.3X
-8 units w/ interval 2065 2070 5 0.5 2065.4 0.2X
-8 units w/o interval 2070 2090 21 0.5 2070.0 0.2X
-9 units w/ interval 2279 2290 10 0.4 2278.7 0.2X
-9 units w/o interval 2276 2285 8 0.4 2275.7 0.2X
+prepare string w/ interval 672 728 64 1.5 672.1 1.0X
+prepare string w/o interval 580 602 19 1.7 580.4 1.2X
+1 units w/ interval 9450 9575 138 0.1 9449.6 0.1X
+1 units w/o interval 8948 8968 19 0.1 8948.3 0.1X
+2 units w/ interval 10947 10966 19 0.1 10947.1 0.1X
+2 units w/o interval 10470 10489 26 0.1 10469.5 0.1X
+3 units w/ interval 12265 12333 72 0.1 12264.5 0.1X
+3 units w/o interval 12001 12004 3 0.1 12000.6 0.1X
+4 units w/ interval 13749 13828 69 0.1 13748.5 0.0X
+4 units w/o interval 13467 13479 15 0.1 13467.3 0.0X
+5 units w/ interval 15392 15446 51 0.1 15392.1 0.0X
+5 units w/o interval 15090 15107 29 0.1 15089.7 0.0X
+6 units w/ interval 16696 16714 20 0.1 16695.9 0.0X
+6 units w/o interval 16361 16366 5 0.1 16361.4 0.0X
+7 units w/ interval 18190 18270 71 0.1 18190.2 0.0X
+7 units w/o interval 17757 17767 9 0.1 17756.7 0.0X
+8 units w/ interval 19821 19870 43 0.1 19820.7 0.0X
+8 units w/o interval 19479 19555 97 0.1 19479.5 0.0X
+9 units w/ interval 21417 21481 56 0.0 21417.1 0.0X
+9 units w/o interval 21058 21131 86 0.0 21058.2 0.0X
diff --git a/sql/core/benchmarks/IntervalBenchmark-results.txt b/sql/core/benchmarks/IntervalBenchmark-results.txt
index 9010b980c07b5..60e8e5198353c 100644
--- a/sql/core/benchmarks/IntervalBenchmark-results.txt
+++ b/sql/core/benchmarks/IntervalBenchmark-results.txt
@@ -1,25 +1,26 @@
-Java HotSpot(TM) 64-Bit Server VM 1.8.0_202-b08 on Mac OS X 10.15
-Intel(R) Core(TM) i7-4850HQ CPU @ 2.30GHz
+OpenJDK 64-Bit Server VM 1.8.0_222-8u222-b10-1ubuntu1~18.04.1-b10 on Linux 4.15.0-1044-aws
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
cast strings to intervals: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
-string w/ interval 420 435 18 2.4 419.8 1.0X
-string w/o interval 359 365 10 2.8 358.7 1.2X
-1 units w/ interval 752 759 8 1.3 752.0 0.6X
-1 units w/o interval 762 766 4 1.3 762.0 0.6X
-2 units w/ interval 961 970 8 1.0 960.7 0.4X
-2 units w/o interval 970 976 9 1.0 970.2 0.4X
-3 units w/ interval 1130 1136 7 0.9 1130.4 0.4X
-3 units w/o interval 1150 1158 9 0.9 1150.3 0.4X
-4 units w/ interval 1333 1336 3 0.7 1333.5 0.3X
-4 units w/o interval 1354 1359 4 0.7 1354.5 0.3X
-5 units w/ interval 1523 1525 2 0.7 1523.3 0.3X
-5 units w/o interval 1549 1551 3 0.6 1549.4 0.3X
-6 units w/ interval 1661 1663 2 0.6 1660.8 0.3X
-6 units w/o interval 1691 1704 13 0.6 1691.2 0.2X
-7 units w/ interval 1811 1817 8 0.6 1810.6 0.2X
-7 units w/o interval 1853 1854 1 0.5 1853.2 0.2X
-8 units w/ interval 2029 2037 8 0.5 2028.7 0.2X
-8 units w/o interval 2075 2075 1 0.5 2074.5 0.2X
-9 units w/ interval 2170 2175 5 0.5 2170.0 0.2X
-9 units w/o interval 2204 2212 8 0.5 2203.6 0.2X
+prepare string w/ interval 596 647 61 1.7 596.0 1.0X
+prepare string w/o interval 530 554 22 1.9 530.2 1.1X
+1 units w/ interval 9168 9243 66 0.1 9167.8 0.1X
+1 units w/o interval 8740 8744 5 0.1 8740.2 0.1X
+2 units w/ interval 10815 10874 52 0.1 10815.0 0.1X
+2 units w/o interval 10413 10419 11 0.1 10412.8 0.1X
+3 units w/ interval 12490 12530 37 0.1 12490.3 0.0X
+3 units w/o interval 12173 12180 9 0.1 12172.8 0.0X
+4 units w/ interval 13788 13834 43 0.1 13788.0 0.0X
+4 units w/o interval 13445 13456 10 0.1 13445.5 0.0X
+5 units w/ interval 15313 15330 15 0.1 15312.7 0.0X
+5 units w/o interval 14928 14942 16 0.1 14928.0 0.0X
+6 units w/ interval 16959 17003 42 0.1 16959.1 0.0X
+6 units w/o interval 16623 16627 5 0.1 16623.3 0.0X
+7 units w/ interval 18955 18972 21 0.1 18955.4 0.0X
+7 units w/o interval 18454 18462 7 0.1 18454.1 0.0X
+8 units w/ interval 20835 20843 8 0.0 20835.4 0.0X
+8 units w/o interval 20446 20463 19 0.0 20445.7 0.0X
+9 units w/ interval 22981 23031 43 0.0 22981.4 0.0X
+9 units w/o interval 22581 22603 25 0.0 22581.1 0.0X
+
diff --git a/sql/core/benchmarks/MetricsAggregationBenchmark-jdk11-results.txt b/sql/core/benchmarks/MetricsAggregationBenchmark-jdk11-results.txt
new file mode 100644
index 0000000000000..e33ed30eaa559
--- /dev/null
+++ b/sql/core/benchmarks/MetricsAggregationBenchmark-jdk11-results.txt
@@ -0,0 +1,12 @@
+OpenJDK 64-Bit Server VM 11.0.4+11 on Linux 4.15.0-66-generic
+Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz
+metrics aggregation (50 metrics, 100000 tasks per stage): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+1 stage(s) 672 841 179 0.0 671888474.0 1.0X
+2 stage(s) 1700 1842 201 0.0 1699591662.0 0.4X
+3 stage(s) 2601 2776 247 0.0 2601465786.0 0.3X
+
+Stage Count Stage Proc. Time Aggreg. Time
+ 1 436 164
+ 2 537 354
+ 3 480 602
diff --git a/sql/core/benchmarks/MetricsAggregationBenchmark-results.txt b/sql/core/benchmarks/MetricsAggregationBenchmark-results.txt
new file mode 100644
index 0000000000000..4fae928258d32
--- /dev/null
+++ b/sql/core/benchmarks/MetricsAggregationBenchmark-results.txt
@@ -0,0 +1,12 @@
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Linux 4.15.0-66-generic
+Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz
+metrics aggregation (50 metrics, 100000 tasks per stage): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+1 stage(s) 740 883 147 0.0 740089816.0 1.0X
+2 stage(s) 1661 1943 399 0.0 1660649192.0 0.4X
+3 stage(s) 2711 2967 362 0.0 2711110178.0 0.3X
+
+Stage Count Stage Proc. Time Aggreg. Time
+ 1 405 179
+ 2 375 414
+ 3 364 644
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 863d80b5cb9c5..90b55a8586de7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -20,7 +20,6 @@
import java.io.IOException;
import java.util.function.Supplier;
-import scala.collection.AbstractIterator;
import scala.collection.Iterator;
import scala.math.Ordering;
@@ -52,6 +51,12 @@ public final class UnsafeExternalRowSorter {
private final UnsafeExternalRowSorter.PrefixComputer prefixComputer;
private final UnsafeExternalSorter sorter;
+ // This flag tracks whether cleanupResources() has been called. After the cleanup work,
+ // iterator.next should always return false. Downstream operators trigger the resource
+ // cleanup when they find there is no need to keep the iterator any more.
+ // See more details in SPARK-21492.
+ private boolean isReleased = false;
+
public abstract static class PrefixComputer {
public static class Prefix {
@@ -157,11 +162,12 @@ public long getSortTimeNanos() {
return sorter.getSortTimeNanos();
}
- private void cleanupResources() {
+ public void cleanupResources() {
+ isReleased = true;
sorter.cleanupResources();
}
- public Iterator<UnsafeRow> sort() throws IOException {
+ public Iterator<InternalRow> sort() throws IOException {
try {
final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
if (!sortedIterator.hasNext()) {
@@ -169,31 +175,32 @@ public Iterator sort() throws IOException {
// here in order to prevent memory leaks.
cleanupResources();
}
- return new AbstractIterator<UnsafeRow>() {
+ return new RowIterator() {
private final int numFields = schema.length();
private UnsafeRow row = new UnsafeRow(numFields);
@Override
- public boolean hasNext() {
- return sortedIterator.hasNext();
- }
-
- @Override
- public UnsafeRow next() {
+ public boolean advanceNext() {
try {
- sortedIterator.loadNext();
- row.pointTo(
- sortedIterator.getBaseObject(),
- sortedIterator.getBaseOffset(),
- sortedIterator.getRecordLength());
- if (!hasNext()) {
- UnsafeRow copy = row.copy(); // so that we don't have dangling pointers to freed page
- row = null; // so that we don't keep references to the base object
- cleanupResources();
- return copy;
+ if (!isReleased && sortedIterator.hasNext()) {
+ sortedIterator.loadNext();
+ row.pointTo(
+ sortedIterator.getBaseObject(),
+ sortedIterator.getBaseOffset(),
+ sortedIterator.getRecordLength());
+ // This keeps the original fix for SPARK-9364: a use-after-free bug that appeared
+ // when returning the last row from an iterator. For example, in
+ // [[GroupedIterator]], we still use the last row after traversing the iterator
+ // in `fetchNextGroupIterator`.
+ if (!sortedIterator.hasNext()) {
+ row = row.copy(); // so that we don't have dangling pointers to freed page
+ cleanupResources();
+ }
+ return true;
} else {
- return row;
+ row = null; // so that we don't keep references to the base object
+ return false;
}
} catch (IOException e) {
cleanupResources();
@@ -203,14 +210,18 @@ public UnsafeRow next() {
}
throw new RuntimeException("Exception should have been re-thrown in next()");
}
- };
+
+ @Override
+ public UnsafeRow getRow() { return row; }
+
+ }.toScala();
} catch (IOException e) {
cleanupResources();
throw e;
}
}
- public Iterator<UnsafeRow> sort(Iterator<UnsafeRow> inputIterator) throws IOException {
+ public Iterator<InternalRow> sort(Iterator<UnsafeRow> inputIterator) throws IOException {
while (inputIterator.hasNext()) {
insertRow(inputIterator.next());
}
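
A minimal, self-contained sketch of the release-flag guard used above (toy names, not Spark's classes): once the owner releases the underlying resources, the iterator reports exhaustion instead of touching freed memory, and it releases eagerly after handing out its last element.

    // Sketch only: illustrates the "release flag" idea, not UnsafeExternalRowSorter itself.
    final class ReleasableIterator[T](underlying: Iterator[T], release: () => Unit)
        extends Iterator[T] {
      private var released = false

      def close(): Unit = {
        if (!released) {
          released = true
          release()
        }
      }

      override def hasNext: Boolean = !released && underlying.hasNext

      override def next(): T = {
        if (!hasNext) throw new NoSuchElementException("iterator released or exhausted")
        val value = underlying.next()
        // Release eagerly once the last element has been produced, mirroring cleanupResources().
        if (!underlying.hasNext) close()
        value
      }
    }

    object ReleasableIteratorDemo extends App {
      val it = new ReleasableIterator(Iterator(1, 2, 3), () => println("resources released"))
      println(it.toList)   // "resources released" is printed while the last element is consumed
      println(it.hasNext)  // false, and stays false after the release
    }
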
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 7b903a3f7f148..ed10843b08596 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -200,7 +200,7 @@ class Column(val expr: Expression) extends Logging {
UnresolvedAlias(a, Some(Column.generateAlias))
// Wait until the struct is resolved. This will generate a nicer looking alias.
- case struct: CreateNamedStructLike => UnresolvedAlias(struct)
+ case struct: CreateNamedStruct => UnresolvedAlias(struct)
case expr: Expression => Alias(expr, toPrettySQL(expr))()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 4f88cc6daa331..68127c27a8cc2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -687,6 +687,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* `encoding` (by default it is not set): specifies encoding (charset) of saved json
* files. If it is not set, the UTF-8 charset will be used.
* `lineSep` (default `\n`): defines the line separator that should be used for writing.
+ * `ignoreNullFields` (default `true`): Whether to ignore null fields
+ * when generating JSON objects.
*
*
* @since 1.4.0
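
A usage sketch for the `ignoreNullFields` option documented above (spark-shell style; assumes a SparkSession named `spark` is in scope, and the output paths are placeholders):

    import spark.implicits._

    val df = Seq(("a", Some(1)), ("b", None)).toDF("k", "v")

    // Default (ignoreNullFields = true): the second record is written as {"k":"b"}
    df.write.mode("overwrite").json("/tmp/json-nulls-dropped")

    // Keep nulls explicitly: the second record is written as {"k":"b","v":null}
    df.write.mode("overwrite").option("ignoreNullFields", "false").json("/tmp/json-nulls-kept")
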
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 076270a9f1c6b..5f6e0a82be4ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -46,6 +46,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
+import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter, ArrowConverters}
import org.apache.spark.sql.execution.command._
@@ -724,7 +725,7 @@ class Dataset[T] private[sql](
def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withTypedPlan {
val parsedDelay =
try {
- CalendarInterval.fromCaseInsensitiveString(delayThreshold)
+ IntervalUtils.fromString(delayThreshold)
} catch {
case e: IllegalArgumentException =>
throw new AnalysisException(
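
The change above only swaps the parser used for the delay string; from the caller's side withWatermark still accepts an interval-like string. A small usage sketch (spark-shell style; assumes a SparkSession named `spark`, with the rate source standing in for a real stream):

    import org.apache.spark.sql.functions.window
    import spark.implicits._

    val counts = spark.readStream
      .format("rate").option("rowsPerSecond", "5").load()   // columns: timestamp, value
      .withWatermark("timestamp", "10 minutes")             // the delay string parsed above
      .groupBy(window($"timestamp", "5 minutes"))
      .count()
    // counts.writeStream.outputMode("append").format("console").start() would run the query
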
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
index 72f539f72008d..e7e34b1ef3127 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
@@ -24,8 +24,8 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog, TableChange, V1Table}
import org.apache.spark.sql.connector.expressions.Transform
-import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableRecoverPartitionsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, AnalyzeColumnCommand, AnalyzePartitionCommand, AnalyzeTableCommand, DescribeColumnCommand, DescribeTableCommand, DropTableCommand, ShowTablesCommand}
-import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource}
+import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableRecoverPartitionsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, AnalyzeColumnCommand, AnalyzePartitionCommand, AnalyzeTableCommand, CacheTableCommand, CreateDatabaseCommand, DescribeColumnCommand, DescribeTableCommand, DropTableCommand, ShowCreateTableCommand, ShowPartitionsCommand, ShowTablesCommand, TruncateTableCommand, UncacheTableCommand}
+import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, RefreshTable}
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{HIVE_TYPE_STRING, HiveStringType, MetadataBuilder, StructField, StructType}
@@ -216,6 +216,9 @@ class ResolveSessionCatalog(
ignoreIfExists = c.ifNotExists)
}
+ case RefreshTableStatement(SessionCatalog(_, tableName)) =>
+ RefreshTable(tableName.asTableIdentifier)
+
// For REPLACE TABLE [AS SELECT], we should fail if the catalog is resolved to the
// session catalog and the table provider is not v2.
case c @ ReplaceTableStatement(
@@ -255,6 +258,19 @@ class ResolveSessionCatalog(
case DropViewStatement(SessionCatalog(catalog, viewName), ifExists) =>
DropTableCommand(viewName.asTableIdentifier, ifExists, isView = true, purge = false)
+ case c @ CreateNamespaceStatement(SessionCatalog(catalog, nameParts), _, _) =>
+ if (nameParts.length != 1) {
+ throw new AnalysisException(
+ s"The database name is not valid: ${nameParts.quoted}")
+ }
+
+ val comment = c.properties.get(CreateNamespaceStatement.COMMENT_PROPERTY_KEY)
+ val location = c.properties.get(CreateNamespaceStatement.LOCATION_PROPERTY_KEY)
+ val newProperties = c.properties -
+ CreateNamespaceStatement.COMMENT_PROPERTY_KEY -
+ CreateNamespaceStatement.LOCATION_PROPERTY_KEY
+ CreateDatabaseCommand(nameParts.head, c.ifNotExists, location, comment, newProperties)
+
case ShowTablesStatement(Some(SessionCatalog(catalog, nameParts)), pattern) =>
if (nameParts.length != 1) {
throw new AnalysisException(
@@ -282,6 +298,30 @@ class ResolveSessionCatalog(
AlterTableRecoverPartitionsCommand(
v1TableName.asTableIdentifier,
"MSCK REPAIR TABLE")
+
+ case ShowCreateTableStatement(tableName) =>
+ val v1TableName = parseV1Table(tableName, "SHOW CREATE TABLE")
+ ShowCreateTableCommand(v1TableName.asTableIdentifier)
+
+ case CacheTableStatement(tableName, plan, isLazy, options) =>
+ val v1TableName = parseV1Table(tableName, "CACHE TABLE")
+ CacheTableCommand(v1TableName.asTableIdentifier, plan, isLazy, options)
+
+ case UncacheTableStatement(tableName, ifExists) =>
+ val v1TableName = parseV1Table(tableName, "UNCACHE TABLE")
+ UncacheTableCommand(v1TableName.asTableIdentifier, ifExists)
+
+ case TruncateTableStatement(tableName, partitionSpec) =>
+ val v1TableName = parseV1Table(tableName, "TRUNCATE TABLE")
+ TruncateTableCommand(
+ v1TableName.asTableIdentifier,
+ partitionSpec)
+
+ case ShowPartitionsStatement(tableName, partitionSpec) =>
+ val v1TableName = parseV1Table(tableName, "SHOW PARTITIONS")
+ ShowPartitionsCommand(
+ v1TableName.asTableIdentifier,
+ partitionSpec)
}
private def parseV1Table(tableName: Seq[String], sql: String): Seq[String] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
index 9d1636ccf2718..b41a4ff766672 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
@@ -454,6 +454,7 @@ case class RowToColumnarExec(child: SparkPlan) extends UnaryExecNode {
override def next(): ColumnarBatch = {
cb.setNumRows(0)
+ vectors.foreach(_.reset())
var rowCount = 0
while (rowCount < numRows && rowIterator.hasNext) {
val row = rowIterator.next()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index 0a955d6a75235..32d21d05e5f73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -62,6 +62,14 @@ case class SortExec(
"peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
+ private[sql] var rowSorter: UnsafeExternalRowSorter = _
+
+ /**
+ * This method gets invoked only once for each SortExec instance to initialize an
+ * UnsafeExternalRowSorter; both `plan.execute` and code generation use it.
+ * In the code generation path, we need to call this function outside the class, so it
+ * has to be public.
+ */
def createSorter(): UnsafeExternalRowSorter = {
val ordering = newOrdering(sortOrder, output)
@@ -87,13 +95,13 @@ case class SortExec(
}
val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
- val sorter = UnsafeExternalRowSorter.create(
+ rowSorter = UnsafeExternalRowSorter.create(
schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort)
if (testSpillFrequency > 0) {
- sorter.setTestSpillFrequency(testSpillFrequency)
+ rowSorter.setTestSpillFrequency(testSpillFrequency)
}
- sorter
+ rowSorter
}
protected override def doExecute(): RDD[InternalRow] = {
@@ -181,4 +189,17 @@ case class SortExec(
|$sorterVariable.insertRow((UnsafeRow)${row.value});
""".stripMargin
}
+
+ /**
+ * In SortExec, we override cleanupResources to close the UnsafeExternalRowSorter.
+ */
+ override protected[sql] def cleanupResources(): Unit = {
+ if (rowSorter != null) {
+ // It is possible for rowSorter to be null here: for example, when the iterator in the
+ // current task is empty, the downstream physical node (like SortMergeJoinExec) will
+ // trigger cleanupResources before rowSorter is initialized in createSorter.
+ rowSorter.cleanupResources()
+ }
+ super.cleanupResources()
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index b4cdf9e16b7e5..125f76282e3df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -507,6 +507,15 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
}
newOrdering(order, Seq.empty)
}
+
+ /**
+ * Cleans up the resources used by the physical operator (if any). In general, all the resources
+ * should be cleaned up when the task finishes but operators like SortMergeJoinExec and LimitExec
+ * may want eager cleanup to free up tight resources (e.g., memory).
+ */
+ protected[sql] def cleanupResources(): Unit = {
+ children.foreach(_.cleanupResources())
+ }
}
trait LeafExecNode extends SparkPlan {
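
The two changes above (SortExec keeping a handle on its row sorter, SparkPlan gaining a default cleanupResources that recurses into children) combine into one pattern: any operator can ask its subtree to release memory early. A stripped-down sketch of that shape with toy classes, not Spark's:

    // Sketch: eager, tree-wide resource cleanup.
    abstract class Operator(val children: Seq[Operator]) {
      // Default: propagate cleanup to children, like SparkPlan.cleanupResources().
      def cleanupResources(): Unit = children.foreach(_.cleanupResources())
    }

    final class Sorter {
      def cleanupResources(): Unit = println("sorter buffers freed")
    }

    final class SortOp(child: Operator) extends Operator(Seq(child)) {
      // Created lazily during execution, so it may still be null when cleanup is triggered.
      private var rowSorter: Sorter = _
      def createSorter(): Sorter = { rowSorter = new Sorter; rowSorter }

      override def cleanupResources(): Unit = {
        if (rowSorter != null) rowSorter.cleanupResources()
        super.cleanupResources()
      }
    }

    final class LeafOp extends Operator(Nil)

    object CleanupDemo extends App {
      val plan = new SortOp(new LeafOp)
      plan.cleanupResources()   // safe even before createSorter()
      plan.createSorter()
      plan.cleanupResources()   // now also frees the sorter
    }
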
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 3e7a54877cae8..20894b39ce5d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.types.StructType
/**
* Concrete parser for Spark SQL statements.
*/
-class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser {
+class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser(conf) {
val astBuilder = new SparkSqlAstBuilder(conf)
private val substitutor = new VariableSubstitution(conf)
@@ -135,38 +135,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
ShowColumnsCommand(Option(ctx.db).map(_.getText), visitTableIdentifier(ctx.tableIdentifier))
}
- /**
- * A command for users to list the partition names of a table. If partition spec is specified,
- * partitions that match the spec are returned. Otherwise an empty result set is returned.
- *
- * This function creates a [[ShowPartitionsCommand]] logical plan
- *
- * The syntax of using this command in SQL is:
- * {{{
- * SHOW PARTITIONS table_identifier [partition_spec];
- * }}}
- */
- override def visitShowPartitions(ctx: ShowPartitionsContext): LogicalPlan = withOrigin(ctx) {
- val table = visitTableIdentifier(ctx.tableIdentifier)
- val partitionKeys = Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec)
- ShowPartitionsCommand(table, partitionKeys)
- }
-
- /**
- * Creates a [[ShowCreateTableCommand]]
- */
- override def visitShowCreateTable(ctx: ShowCreateTableContext): LogicalPlan = withOrigin(ctx) {
- val table = visitTableIdentifier(ctx.tableIdentifier())
- ShowCreateTableCommand(table)
- }
-
- /**
- * Create a [[RefreshTable]] logical plan.
- */
- override def visitRefreshTable(ctx: RefreshTableContext): LogicalPlan = withOrigin(ctx) {
- RefreshTable(visitTableIdentifier(ctx.tableIdentifier))
- }
-
/**
* Create a [[RefreshResource]] logical plan.
*/
@@ -189,28 +157,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
unquotedPath
}
- /**
- * Create a [[CacheTableCommand]] logical plan.
- */
- override def visitCacheTable(ctx: CacheTableContext): LogicalPlan = withOrigin(ctx) {
- val query = Option(ctx.query).map(plan)
- val tableIdent = visitTableIdentifier(ctx.tableIdentifier)
- if (query.isDefined && tableIdent.database.isDefined) {
- val database = tableIdent.database.get
- throw new ParseException(s"It is not allowed to add database prefix `$database` to " +
- s"the table name in CACHE TABLE AS SELECT", ctx)
- }
- val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty)
- CacheTableCommand(tableIdent, query, ctx.LAZY != null, options)
- }
-
- /**
- * Create an [[UncacheTableCommand]] logical plan.
- */
- override def visitUncacheTable(ctx: UncacheTableContext): LogicalPlan = withOrigin(ctx) {
- UncacheTableCommand(visitTableIdentifier(ctx.tableIdentifier), ctx.EXISTS != null)
- }
-
/**
* Create a [[ClearCacheCommand]] logical plan.
*/
@@ -346,47 +292,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
)
}
- /**
- * Create a [[TruncateTableCommand]] command.
- *
- * For example:
- * {{{
- * TRUNCATE TABLE tablename [PARTITION (partcol1=val1, partcol2=val2 ...)]
- * }}}
- */
- override def visitTruncateTable(ctx: TruncateTableContext): LogicalPlan = withOrigin(ctx) {
- TruncateTableCommand(
- visitTableIdentifier(ctx.tableIdentifier),
- Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec))
- }
-
- /**
- * Create a [[CreateDatabaseCommand]] command.
- *
- * For example:
- * {{{
- * CREATE DATABASE [IF NOT EXISTS] database_name
- * create_database_clauses;
- *
- * create_database_clauses (order insensitive):
- * [COMMENT database_comment]
- * [LOCATION path]
- * [WITH DBPROPERTIES (key1=val1, key2=val2, ...)]
- * }}}
- */
- override def visitCreateDatabase(ctx: CreateDatabaseContext): LogicalPlan = withOrigin(ctx) {
- checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx)
- checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx)
- checkDuplicateClauses(ctx.DBPROPERTIES, "WITH DBPROPERTIES", ctx)
-
- CreateDatabaseCommand(
- ctx.db.getText,
- ctx.EXISTS != null,
- ctx.locationSpec.asScala.headOption.map(visitLocationSpec),
- Option(ctx.comment).map(string),
- ctx.tablePropertyList.asScala.headOption.map(visitPropertyKeyValues).getOrElse(Map.empty))
- }
-
/**
* Create an [[AlterDatabasePropertiesCommand]] command.
*
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index f45e3560b2cf1..f01947d8f5ed6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -92,6 +92,15 @@ case class AdaptiveSparkPlanExec(
// optimizations should be stage-independent.
@transient private val queryStageOptimizerRules: Seq[Rule[SparkPlan]] = Seq(
ReuseAdaptiveSubquery(conf, subqueryCache),
+
+ // When adding local shuffle readers in `OptimizeLocalShuffleReader`, we revert all the local
+ // readers if additional shuffles are introduced. This may be too conservative: maybe there is
+ // only one local reader that introduces a shuffle, and we could still keep the other local
+ // readers. Here we re-execute this rule on the sub-plan-tree of a query stage, to make sure the
+ // necessary local readers are added before executing the query stage.
+ // This rule must be executed before `ReduceNumShufflePartitions`, as local shuffle readers
+ // can't change the number of partitions.
+ OptimizeLocalShuffleReader(conf),
ReduceNumShufflePartitions(conf),
ApplyColumnarRulesAndInsertTransitions(session.sessionState.conf,
session.sessionState.columnarRules),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 0f5f1591623af..e9b8fae7cd735 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -378,8 +378,6 @@ case class DataSource(
// This is a non-streaming file based datasource.
case (format: FileFormat, _) =>
- val globbedPaths =
- checkAndGlobPathIfNecessary(checkEmptyGlobPath = true, checkFilesExist = checkFilesExist)
val useCatalogFileIndex = sparkSession.sqlContext.conf.manageFilesourcePartitions &&
catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog &&
catalogTable.get.partitionColumnNames.nonEmpty
@@ -391,6 +389,8 @@ case class DataSource(
catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(defaultTableSize))
(index, catalogTable.get.dataSchema, catalogTable.get.partitionSchema)
} else {
+ val globbedPaths = checkAndGlobPathIfNecessary(
+ checkEmptyGlobPath = true, checkFilesExist = checkFilesExist)
val index = createInMemoryFileIndex(globbedPaths)
val (resultDataSchema, resultPartitionSchema) =
getOrInferFileFormatSchema(format, () => index)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index d184f3cb71b1a..5d1feaed81a9a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -147,14 +147,7 @@ class JDBCOptions(
""".stripMargin
)
- val fetchSize = {
- val size = parameters.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt
- require(size >= 0,
- s"Invalid value `${size.toString}` for parameter " +
- s"`$JDBC_BATCH_FETCH_SIZE`. The minimum value is 0. When the value is 0, " +
- "the JDBC driver ignores the value and does the estimates.")
- size
- }
+ val fetchSize = parameters.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt
// ------------------------------------------------------------
// Optional parameters only for writing
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 86a27b5afc250..55ca4e3624bdd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -605,6 +605,13 @@ object JdbcUtils extends Logging {
* implementation changes elsewhere might easily render such a closure
* non-Serializable. Instead, we explicitly close over all variables that
* are used.
+ *
+ * Note that this method records task output metrics. It assumes the method is
+ * running in a task. For now, we only record the number of rows being written,
+ * because there is no good way to measure the total bytes being written. Only
+ * effective outputs are taken into account: for example, the metric is not updated
+ * if the database supports transactions and the transaction is rolled back, but it is
+ * updated even on error when transactions are not supported, as there are dirty outputs.
*/
def savePartition(
getConnection: () => Connection,
@@ -615,7 +622,9 @@ object JdbcUtils extends Logging {
batchSize: Int,
dialect: JdbcDialect,
isolationLevel: Int,
- options: JDBCOptions): Iterator[Byte] = {
+ options: JDBCOptions): Unit = {
+ val outMetrics = TaskContext.get().taskMetrics().outputMetrics
+
val conn = getConnection()
var committed = false
@@ -643,7 +652,7 @@ object JdbcUtils extends Logging {
}
}
val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE
-
+ var totalRowCount = 0
try {
if (supportsTransactions) {
conn.setAutoCommit(false) // Everything in the same db transaction.
@@ -672,6 +681,7 @@ object JdbcUtils extends Logging {
}
stmt.addBatch()
rowCount += 1
+ totalRowCount += 1
if (rowCount % batchSize == 0) {
stmt.executeBatch()
rowCount = 0
@@ -687,7 +697,6 @@ object JdbcUtils extends Logging {
conn.commit()
}
committed = true
- Iterator.empty
} catch {
case e: SQLException =>
val cause = e.getNextException
@@ -715,9 +724,13 @@ object JdbcUtils extends Logging {
// tell the user about another problem.
if (supportsTransactions) {
conn.rollback()
+ } else {
+ outMetrics.setRecordsWritten(totalRowCount)
}
conn.close()
} else {
+ outMetrics.setRecordsWritten(totalRowCount)
+
// The stage must succeed. We cannot propagate any exception close() might throw.
try {
conn.close()
@@ -840,10 +853,10 @@ object JdbcUtils extends Logging {
case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n)
case _ => df
}
- repartitionedDF.rdd.foreachPartition(iterator => savePartition(
+ repartitionedDF.rdd.foreachPartition { iterator => savePartition(
getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel,
options)
- )
+ }
}
/**
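
To make the metric note above concrete, a standalone JDBC sketch (plain java.sql, no Spark; the connection URL and table t(id, name) are placeholders) of the same bookkeeping: rows are counted as they are batched, but the count only becomes an "effective" result once the transaction commits, and a rollback reports nothing.

    import java.sql.DriverManager

    object JdbcWriteSketch {
      def writeBatch(url: String, rows: Seq[(Int, String)], batchSize: Int = 100): Long = {
        val conn = DriverManager.getConnection(url)   // placeholder URL, e.g. an embedded H2 DB
        var written = 0L
        try {
          conn.setAutoCommit(false)                   // everything in one transaction
          val stmt = conn.prepareStatement("INSERT INTO t(id, name) VALUES (?, ?)")
          try {
            var inBatch = 0
            rows.foreach { case (id, name) =>
              stmt.setInt(1, id)
              stmt.setString(2, name)
              stmt.addBatch()
              inBatch += 1
              written += 1
              if (inBatch == batchSize) { stmt.executeBatch(); inBatch = 0 }
            }
            if (inBatch > 0) stmt.executeBatch()
          } finally {
            stmt.close()
          }
          conn.commit()
          written                                     // only reached when the rows actually landed
        } catch {
          case e: Exception =>
            conn.rollback()                           // nothing effective was written
            throw e
        } finally {
          conn.close()
        }
      }
    }
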
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateNamespaceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateNamespaceExec.scala
new file mode 100644
index 0000000000000..0f69f85dd8376
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateNamespaceExec.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import scala.collection.JavaConverters.mapAsJavaMapConverter
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.NamespaceAlreadyExistsException
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.connector.catalog.SupportsNamespaces
+
+/**
+ * Physical plan node for creating a namespace.
+ */
+case class CreateNamespaceExec(
+ catalog: SupportsNamespaces,
+ namespace: Seq[String],
+ ifNotExists: Boolean,
+ private var properties: Map[String, String])
+ extends V2CommandExec {
+ override protected def run(): Seq[InternalRow] = {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+ val ns = namespace.toArray
+ if (!catalog.namespaceExists(ns)) {
+ try {
+ catalog.createNamespace(ns, properties.asJava)
+ } catch {
+ case _: NamespaceAlreadyExistsException if ifNotExists =>
+ logWarning(s"Namespace ${namespace.quoted} was created concurrently. Ignoring.")
+ }
+ } else if (!ifNotExists) {
+ throw new NamespaceAlreadyExistsException(ns)
+ }
+
+ Seq.empty
+ }
+
+ override def output: Seq[Attribute] = Seq.empty
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index c8d29520bcfce..4a7cb7db45ded 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable
import org.apache.spark.sql.{AnalysisException, Strategy}
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, PredicateHelper, SubqueryExpression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
-import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, AppendData, CreateTableAsSelect, CreateV2Table, DeleteFromTable, DescribeTable, DropTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Repartition, ReplaceTable, ReplaceTableAsSelect, SetCatalogAndNamespace, ShowNamespaces, ShowTables}
+import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, AppendData, CreateNamespace, CreateTableAsSelect, CreateV2Table, DeleteFromTable, DescribeTable, DropTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, RefreshTable, Repartition, ReplaceTable, ReplaceTableAsSelect, SetCatalogAndNamespace, ShowNamespaces, ShowTables}
import org.apache.spark.sql.connector.catalog.{StagingTableCatalog, TableCapability}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream}
@@ -193,6 +193,9 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
catalog, ident, parts, query, planLater(query), props, writeOptions, ifNotExists) :: Nil
}
+ case RefreshTable(catalog, ident) =>
+ RefreshTableExec(catalog, ident) :: Nil
+
case ReplaceTable(catalog, ident, schema, parts, props, orCreate) =>
catalog match {
case staging: StagingTableCatalog =>
@@ -289,6 +292,9 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
case AlterTable(catalog, ident, _, changes) =>
AlterTableExec(catalog, ident, changes) :: Nil
+ case CreateNamespace(catalog, namespace, ifNotExists, properties) =>
+ CreateNamespaceExec(catalog, namespace, ifNotExists, properties) :: Nil
+
case r: ShowNamespaces =>
ShowNamespacesExec(r.output, r.catalog, r.namespace, r.pattern) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RefreshTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RefreshTableExec.scala
new file mode 100644
index 0000000000000..2a19ff304a9e0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RefreshTableExec.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
+
+case class RefreshTableExec(
+ catalog: TableCatalog,
+ ident: Identifier) extends V2CommandExec {
+ override protected def run(): Seq[InternalRow] = {
+ catalog.invalidateTable(ident)
+ Seq.empty
+ }
+
+ override def output: Seq[Attribute] = Seq.empty
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 189727a9bc88d..26fb0e5ffb1af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -191,7 +191,8 @@ case class SortMergeJoinExec(
RowIterator.fromScala(leftIter),
RowIterator.fromScala(rightIter),
inMemoryThreshold,
- spillThreshold
+ spillThreshold,
+ cleanupResources
)
private[this] val joinRow = new JoinedRow
@@ -235,7 +236,8 @@ case class SortMergeJoinExec(
streamedIter = RowIterator.fromScala(leftIter),
bufferedIter = RowIterator.fromScala(rightIter),
inMemoryThreshold,
- spillThreshold
+ spillThreshold,
+ cleanupResources
)
val rightNullRow = new GenericInternalRow(right.output.length)
new LeftOuterIterator(
@@ -249,7 +251,8 @@ case class SortMergeJoinExec(
streamedIter = RowIterator.fromScala(rightIter),
bufferedIter = RowIterator.fromScala(leftIter),
inMemoryThreshold,
- spillThreshold
+ spillThreshold,
+ cleanupResources
)
val leftNullRow = new GenericInternalRow(left.output.length)
new RightOuterIterator(
@@ -283,7 +286,8 @@ case class SortMergeJoinExec(
RowIterator.fromScala(leftIter),
RowIterator.fromScala(rightIter),
inMemoryThreshold,
- spillThreshold
+ spillThreshold,
+ cleanupResources
)
private[this] val joinRow = new JoinedRow
@@ -318,7 +322,8 @@ case class SortMergeJoinExec(
RowIterator.fromScala(leftIter),
RowIterator.fromScala(rightIter),
inMemoryThreshold,
- spillThreshold
+ spillThreshold,
+ cleanupResources
)
private[this] val joinRow = new JoinedRow
@@ -360,7 +365,8 @@ case class SortMergeJoinExec(
RowIterator.fromScala(leftIter),
RowIterator.fromScala(rightIter),
inMemoryThreshold,
- spillThreshold
+ spillThreshold,
+ cleanupResources
)
private[this] val joinRow = new JoinedRow
@@ -640,6 +646,9 @@ case class SortMergeJoinExec(
(evaluateVariables(leftVars), "")
}
+ val thisPlan = ctx.addReferenceObj("plan", this)
+ val eagerCleanup = s"$thisPlan.cleanupResources();"
+
s"""
|while (findNextInnerJoinRows($leftInput, $rightInput)) {
| ${leftVarDecl.mkString("\n")}
@@ -653,6 +662,7 @@ case class SortMergeJoinExec(
| }
| if (shouldStop()) return;
|}
+ |$eagerCleanup
""".stripMargin
}
}
@@ -678,6 +688,7 @@ case class SortMergeJoinExec(
* @param inMemoryThreshold Threshold for number of rows guaranteed to be held in memory by
* internal buffer
* @param spillThreshold Threshold for number of rows to be spilled by internal buffer
+ * @param eagerCleanupResources the eager cleanup function to be invoked when no join row is found
*/
private[joins] class SortMergeJoinScanner(
streamedKeyGenerator: Projection,
@@ -686,7 +697,8 @@ private[joins] class SortMergeJoinScanner(
streamedIter: RowIterator,
bufferedIter: RowIterator,
inMemoryThreshold: Int,
- spillThreshold: Int) {
+ spillThreshold: Int,
+ eagerCleanupResources: () => Unit) {
private[this] var streamedRow: InternalRow = _
private[this] var streamedRowKey: InternalRow = _
private[this] var bufferedRow: InternalRow = _
@@ -710,7 +722,8 @@ private[joins] class SortMergeJoinScanner(
def getBufferedMatches: ExternalAppendOnlyUnsafeRowArray = bufferedMatches
/**
- * Advances both input iterators, stopping when we have found rows with matching join keys.
+ * Advances both input iterators, stopping when we have found rows with matching join keys. If no
+ * join rows are found, try to do the eager resource cleanup.
* @return true if matching rows have been found and false otherwise. If this returns true, then
* [[getStreamedRow]] and [[getBufferedMatches]] can be called to construct the join
* results.
@@ -720,7 +733,7 @@ private[joins] class SortMergeJoinScanner(
// Advance the streamed side of the join until we find the next row whose join key contains
// no nulls or we hit the end of the streamed iterator.
}
- if (streamedRow == null) {
+ val found = if (streamedRow == null) {
// We have consumed the entire streamed iterator, so there can be no more matches.
matchJoinKey = null
bufferedMatches.clear()
@@ -760,17 +773,19 @@ private[joins] class SortMergeJoinScanner(
true
}
}
+ if (!found) eagerCleanupResources()
+ found
}
/**
* Advances the streamed input iterator and buffers all rows from the buffered input that
- * have matching keys.
+ * have matching keys. If no join rows are found, try to do the eager resource cleanup.
* @return true if the streamed iterator returned a row, false otherwise. If this returns true,
* then [[getStreamedRow]] and [[getBufferedMatches]] can be called to produce the outer
* join results.
*/
final def findNextOuterJoinRows(): Boolean = {
- if (!advancedStreamed()) {
+ val found = if (!advancedStreamed()) {
// We have consumed the entire streamed iterator, so there can be no more matches.
matchJoinKey = null
bufferedMatches.clear()
@@ -800,6 +815,8 @@ private[joins] class SortMergeJoinScanner(
// If there is a streamed input then we always return true
true
}
+ if (!found) eagerCleanupResources()
+ found
}
// --- Private methods --------------------------------------------------------------------------
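
The scanner changes above thread a cleanup callback through the join so that, the moment one side is exhausted and no further match is possible, upstream sort buffers can be released instead of waiting for the end of the task. A toy illustration of that callback shape (a simplified merge scan, not the real SortMergeJoinScanner):

    final class MergeScanner[K: Ordering](
        left: Iterator[(K, String)],
        right: Iterator[(K, String)],
        eagerCleanup: () => Unit) {

      private val ord = implicitly[Ordering[K]]
      private val lhs = left.buffered
      private val rhs = right.buffered
      private var cleaned = false

      /** Returns the next matching pair; fires eagerCleanup() once no further match can exist. */
      def findNextMatch(): Option[(K, String, String)] = {
        while (lhs.hasNext && rhs.hasNext) {
          val cmp = ord.compare(lhs.head._1, rhs.head._1)
          if (cmp == 0) {
            val (k, l) = lhs.next()
            val (_, r) = rhs.next()
            return Some((k, l, r))
          } else if (cmp < 0) lhs.next() else rhs.next()
        }
        if (!cleaned) { cleaned = true; eagerCleanup() }   // one side exhausted: free buffers now
        None
      }
    }

    // new MergeScanner(Iterator((1, "a"), (3, "b")), Iterator((1, "x"), (2, "y")), () => println("freed"))
    //   returns Some((1, "a", "x")) first; a later call prints "freed" and returns None.
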
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 19809b07508d9..b7f0ab2969e45 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.metric
import java.text.NumberFormat
-import java.util.Locale
+import java.util.{Arrays, Locale}
import scala.concurrent.duration._
@@ -150,7 +150,7 @@ object SQLMetrics {
* A function that defines how we aggregate the final accumulator results among all tasks,
* and represent it in string for a SQL physical operator.
*/
- def stringValue(metricsType: String, values: Seq[Long]): String = {
+ def stringValue(metricsType: String, values: Array[Long]): String = {
if (metricsType == SUM_METRIC) {
val numberFormat = NumberFormat.getIntegerInstance(Locale.US)
numberFormat.format(values.sum)
@@ -162,8 +162,9 @@ object SQLMetrics {
val metric = if (validValues.isEmpty) {
Seq.fill(3)(0L)
} else {
- val sorted = validValues.sorted
- Seq(sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+ Arrays.sort(validValues)
+ Seq(validValues(0), validValues(validValues.length / 2),
+ validValues(validValues.length - 1))
}
metric.map(v => numberFormat.format(v.toDouble / baseForAvgMetric))
}
@@ -184,8 +185,9 @@ object SQLMetrics {
val metric = if (validValues.isEmpty) {
Seq.fill(4)(0L)
} else {
- val sorted = validValues.sorted
- Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+ Arrays.sort(validValues)
+ Seq(validValues.sum, validValues(0), validValues(validValues.length / 2),
+ validValues(validValues.length - 1))
}
metric.map(strFormat)
}
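
The stringValue change above switches from sorting an immutable Seq (which copies the data) to sorting a primitive Array[Long] in place with java.util.Arrays.sort, which matters when an execution has many thousands of task values per metric. A tiny standalone sketch of the sum/min/median/max summary over such an array:

    import java.util.Arrays

    object MetricSummary {
      // Note: sorts `values` in place, so callers must not rely on the original order afterwards.
      def summarize(values: Array[Long]): (Long, Long, Long, Long) = {
        require(values.nonEmpty, "no task values")
        val sum = values.sum
        Arrays.sort(values)
        (sum, values(0), values(values.length / 2), values(values.length - 1))
      }
    }

    // MetricSummary.summarize(Array(4L, 1L, 9L, 2L)) == (16, 1, 4, 9)   // (sum, min, median, max)
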
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
index dda9d41f630e6..d191a79187f28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
@@ -21,6 +21,7 @@ import java.sql.Date
import java.util.concurrent.TimeUnit
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout}
+import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.execution.streaming.GroupStateImpl._
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}
import org.apache.spark.unsafe.types.CalendarInterval
@@ -159,7 +160,7 @@ private[sql] class GroupStateImpl[S] private(
def getTimeoutTimestamp: Long = timeoutTimestamp
private def parseDuration(duration: String): Long = {
- val cal = CalendarInterval.fromCaseInsensitiveString(duration)
+ val cal = IntervalUtils.fromString(duration)
if (cal.milliseconds < 0 || cal.months < 0) {
throw new IllegalArgumentException(s"Provided duration ($duration) is not positive")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index af52af0d1d7e6..b8e18b89b54bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -77,7 +77,8 @@ class IncrementalExecution(
*/
override
lazy val optimizedPlan: LogicalPlan = tracker.measurePhase(QueryPlanningTracker.OPTIMIZATION) {
- sparkSession.sessionState.optimizer.execute(withCachedData) transformAllExpressions {
+ sparkSession.sessionState.optimizer.executeAndTrack(withCachedData,
+ tracker) transformAllExpressions {
case ts @ CurrentBatchTimestamp(timestamp, _, _) =>
logInfo(s"Current batch timestamp = $timestamp")
ts.toLiteral
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala
index 2bdb3402c14b1..daa70a12ba0e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala
@@ -21,8 +21,8 @@ import java.util.concurrent.TimeUnit
import scala.concurrent.duration.Duration
+import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.streaming.Trigger
-import org.apache.spark.unsafe.types.CalendarInterval
private object Triggers {
def validate(intervalMs: Long): Unit = {
@@ -30,7 +30,7 @@ private object Triggers {
}
def convert(interval: String): Long = {
- val cal = CalendarInterval.fromCaseInsensitiveString(interval)
+ val cal = IntervalUtils.fromString(interval)
if (cal.months > 0) {
throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval")
}
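
Triggers.convert above is what eventually parses the interval string given to a processing-time trigger; month and year units are rejected because they have no fixed millisecond length. A short usage sketch (spark-shell style; assumes a SparkSession named `spark`, with the rate source and console sink used only to have something runnable):

    import org.apache.spark.sql.streaming.Trigger

    val query = spark.readStream.format("rate").load()
      .writeStream
      .format("console")
      .trigger(Trigger.ProcessingTime("10 seconds"))   // parsed through Triggers.convert above
      // .trigger(Trigger.ProcessingTime("1 month"))   // would be rejected: month/year unsupported
      .start()

    query.awaitTermination(30000)   // let the sketch run for ~30 seconds
    query.stop()
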
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
index 2c4a7eacdf10b..da526612e7bcf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
@@ -16,10 +16,11 @@
*/
package org.apache.spark.sql.execution.ui
-import java.util.{Date, NoSuchElementException}
+import java.util.{Arrays, Date, NoSuchElementException}
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._
+import scala.collection.mutable
import org.apache.spark.{JobExecutionStatus, SparkConf}
import org.apache.spark.internal.Logging
@@ -29,6 +30,7 @@ import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.metric._
import org.apache.spark.sql.internal.StaticSQLConf._
import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity}
+import org.apache.spark.util.collection.OpenHashMap
class SQLAppStatusListener(
conf: SparkConf,
@@ -103,8 +105,10 @@ class SQLAppStatusListener(
// Record the accumulator IDs for the stages of this job, so that the code that keeps
// track of the metrics knows which accumulators to look at.
val accumIds = exec.metrics.map(_.accumulatorId).toSet
- event.stageIds.foreach { id =>
- stageMetrics.put(id, new LiveStageMetrics(id, 0, accumIds, new ConcurrentHashMap()))
+ if (accumIds.nonEmpty) {
+ event.stageInfos.foreach { stage =>
+ stageMetrics.put(stage.stageId, new LiveStageMetrics(0, stage.numTasks, accumIds))
+ }
}
exec.jobs = exec.jobs + (jobId -> JobExecutionStatus.RUNNING)
@@ -118,9 +122,11 @@ class SQLAppStatusListener(
}
// Reset the metrics tracking object for the new attempt.
- Option(stageMetrics.get(event.stageInfo.stageId)).foreach { metrics =>
- metrics.taskMetrics.clear()
- metrics.attemptId = event.stageInfo.attemptNumber
+ Option(stageMetrics.get(event.stageInfo.stageId)).foreach { stage =>
+ if (stage.attemptId != event.stageInfo.attemptNumber) {
+ stageMetrics.put(event.stageInfo.stageId,
+ new LiveStageMetrics(event.stageInfo.attemptNumber, stage.numTasks, stage.accumulatorIds))
+ }
}
}
@@ -140,7 +146,16 @@ class SQLAppStatusListener(
override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = {
event.accumUpdates.foreach { case (taskId, stageId, attemptId, accumUpdates) =>
- updateStageMetrics(stageId, attemptId, taskId, accumUpdates, false)
+ updateStageMetrics(stageId, attemptId, taskId, SQLAppStatusListener.UNKNOWN_INDEX,
+ accumUpdates, false)
+ }
+ }
+
+ override def onTaskStart(event: SparkListenerTaskStart): Unit = {
+ Option(stageMetrics.get(event.stageId)).foreach { stage =>
+ if (stage.attemptId == event.stageAttemptId) {
+ stage.registerTask(event.taskInfo.taskId, event.taskInfo.index)
+ }
}
}
@@ -165,7 +180,7 @@ class SQLAppStatusListener(
} else {
info.accumulables
}
- updateStageMetrics(event.stageId, event.stageAttemptId, info.taskId, accums,
+ updateStageMetrics(event.stageId, event.stageAttemptId, info.taskId, info.index, accums,
info.successful)
}
@@ -181,17 +196,40 @@ class SQLAppStatusListener(
private def aggregateMetrics(exec: LiveExecutionData): Map[Long, String] = {
val metricTypes = exec.metrics.map { m => (m.accumulatorId, m.metricType) }.toMap
- val metrics = exec.stages.toSeq
+
+ val taskMetrics = exec.stages.toSeq
.flatMap { stageId => Option(stageMetrics.get(stageId)) }
- .flatMap(_.taskMetrics.values().asScala)
- .flatMap { metrics => metrics.ids.zip(metrics.values) }
-
- val aggregatedMetrics = (metrics ++ exec.driverAccumUpdates.toSeq)
- .filter { case (id, _) => metricTypes.contains(id) }
- .groupBy(_._1)
- .map { case (id, values) =>
- id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2))
+ .flatMap(_.metricValues())
+
+ val allMetrics = new mutable.HashMap[Long, Array[Long]]()
+
+ taskMetrics.foreach { case (id, values) =>
+ val prev = allMetrics.getOrElse(id, null)
+ val updated = if (prev != null) {
+ prev ++ values
+ } else {
+ values
}
+ allMetrics(id) = updated
+ }
+
+ exec.driverAccumUpdates.foreach { case (id, value) =>
+ if (metricTypes.contains(id)) {
+ val prev = allMetrics.getOrElse(id, null)
+ val updated = if (prev != null) {
+ val _copy = Arrays.copyOf(prev, prev.length + 1)
+ _copy(prev.length) = value
+ _copy
+ } else {
+ Array(value)
+ }
+ allMetrics(id) = updated
+ }
+ }
+
+ val aggregatedMetrics = allMetrics.map { case (id, values) =>
+ id -> SQLMetrics.stringValue(metricTypes(id), values)
+ }.toMap
// Check the execution again for whether the aggregated metrics data has been calculated.
// This can happen if the UI is requesting this data, and the onExecutionEnd handler is
@@ -208,43 +246,13 @@ class SQLAppStatusListener(
stageId: Int,
attemptId: Int,
taskId: Long,
+ taskIdx: Int,
accumUpdates: Seq[AccumulableInfo],
succeeded: Boolean): Unit = {
Option(stageMetrics.get(stageId)).foreach { metrics =>
- if (metrics.attemptId != attemptId || metrics.accumulatorIds.isEmpty) {
- return
- }
-
- val oldTaskMetrics = metrics.taskMetrics.get(taskId)
- if (oldTaskMetrics != null && oldTaskMetrics.succeeded) {
- return
+ if (metrics.attemptId == attemptId) {
+ metrics.updateTaskMetrics(taskId, taskIdx, succeeded, accumUpdates)
}
-
- val updates = accumUpdates
- .filter { acc => acc.update.isDefined && metrics.accumulatorIds.contains(acc.id) }
- .sortBy(_.id)
-
- if (updates.isEmpty) {
- return
- }
-
- val ids = new Array[Long](updates.size)
- val values = new Array[Long](updates.size)
- updates.zipWithIndex.foreach { case (acc, idx) =>
- ids(idx) = acc.id
- // In a live application, accumulators have Long values, but when reading from event
- // logs, they have String values. For now, assume all accumulators are Long and covert
- // accordingly.
- values(idx) = acc.update.get match {
- case s: String => s.toLong
- case l: Long => l
- case o => throw new IllegalArgumentException(s"Unexpected: $o")
- }
- }
-
- // TODO: storing metrics by task ID can cause metrics for the same task index to be
- // counted multiple times, for example due to speculation or re-attempts.
- metrics.taskMetrics.put(taskId, new LiveTaskMetrics(ids, values, succeeded))
}
}
@@ -425,12 +433,76 @@ private class LiveExecutionData(val executionId: Long) extends LiveEntity {
}
private class LiveStageMetrics(
- val stageId: Int,
- var attemptId: Int,
- val accumulatorIds: Set[Long],
- val taskMetrics: ConcurrentHashMap[Long, LiveTaskMetrics])
-
-private class LiveTaskMetrics(
- val ids: Array[Long],
- val values: Array[Long],
- val succeeded: Boolean)
+ val attemptId: Int,
+ val numTasks: Int,
+ val accumulatorIds: Set[Long]) {
+
+ /**
+ * Mapping of task IDs to their respective index. Note this may contain more elements than the
+ * stage's number of tasks, if speculative execution is on.
+ */
+ private val taskIndices = new OpenHashMap[Long, Int]()
+
+ /** Bit set tracking which indices have been successfully computed. */
+ private val completedIndices = new mutable.BitSet()
+
+ /**
+ * Task metrics values for the stage. Maps the metric ID to the metric values for each
+ * index. For each metric ID, there will be the same number of values as the number
+ * of indices. This relies on `SQLMetrics.stringValue` treating 0 as a neutral value,
+ * independent of the actual metric type.
+ */
+ private val taskMetrics = new ConcurrentHashMap[Long, Array[Long]]()
+
+ def registerTask(taskId: Long, taskIdx: Int): Unit = {
+ taskIndices.update(taskId, taskIdx)
+ }
+
+ def updateTaskMetrics(
+ taskId: Long,
+ eventIdx: Int,
+ finished: Boolean,
+ accumUpdates: Seq[AccumulableInfo]): Unit = {
+ val taskIdx = if (eventIdx == SQLAppStatusListener.UNKNOWN_INDEX) {
+ if (!taskIndices.contains(taskId)) {
+ // We probably missed the start event for the task, just ignore it.
+ return
+ }
+ taskIndices(taskId)
+ } else {
+ // Here we can recover from a missing task start event. Just register the task again.
+ registerTask(taskId, eventIdx)
+ eventIdx
+ }
+
+ if (completedIndices.contains(taskIdx)) {
+ return
+ }
+
+ accumUpdates
+ .filter { acc => acc.update.isDefined && accumulatorIds.contains(acc.id) }
+ .foreach { acc =>
+ // In a live application, accumulators have Long values, but when reading from event
+ // logs, they have String values. For now, assume all accumulators are Long and convert
+ // accordingly.
+ val value = acc.update.get match {
+ case s: String => s.toLong
+ case l: Long => l
+ case o => throw new IllegalArgumentException(s"Unexpected: $o")
+ }
+
+ val metricValues = taskMetrics.computeIfAbsent(acc.id, _ => new Array(numTasks))
+ metricValues(taskIdx) = value
+ }
+
+ if (finished) {
+ completedIndices += taskIdx
+ }
+ }
+
+ def metricValues(): Seq[(Long, Array[Long])] = taskMetrics.asScala.toSeq
+}
+
+private object SQLAppStatusListener {
+ val UNKNOWN_INDEX = -1
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
index f1a648176c3b3..d097f9f18f89b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
@@ -18,7 +18,8 @@
package org.apache.spark.sql.internal
import java.net.URL
-import java.util.Locale
+import java.util.{Locale, UUID}
+import java.util.concurrent.ConcurrentHashMap
import scala.reflect.ClassTag
import scala.util.control.NonFatal
@@ -33,6 +34,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.execution.CacheManager
import org.apache.spark.sql.execution.ui.{SQLAppStatusListener, SQLAppStatusStore, SQLTab}
import org.apache.spark.sql.internal.StaticSQLConf._
+import org.apache.spark.sql.streaming.StreamingQueryManager
import org.apache.spark.status.ElementTrackingStore
import org.apache.spark.util.Utils
@@ -110,6 +112,12 @@ private[sql] class SharedState(
*/
val cacheManager: CacheManager = new CacheManager
+ /**
+ * A map of active streaming queries to the session specific StreamingQueryManager that manages
+ * the lifecycle of that stream.
+ */
+ private[sql] val activeStreamingQueries = new ConcurrentHashMap[UUID, StreamingQueryManager]()
+
/**
* A status store to query SQL status/metrics of this Spark application, based on SQL-specific
* [[org.apache.spark.scheduler.SparkListenerEvent]]s.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
index 9abe38dfda0be..9b43a83e7b94a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
@@ -352,8 +352,10 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
}
}
- // Make sure no other query with same id is active
- if (activeQueries.values.exists(_.id == query.id)) {
+ // Make sure no other query with same id is active across all sessions
+ val activeOption =
+ Option(sparkSession.sharedState.activeStreamingQueries.putIfAbsent(query.id, this))
+ if (activeOption.isDefined || activeQueries.values.exists(_.id == query.id)) {
throw new IllegalStateException(
s"Cannot start query with id ${query.id} as another query with same id is " +
s"already active. Perhaps you are attempting to restart a query from checkpoint " +
@@ -370,9 +372,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
query.streamingQuery.start()
} catch {
case e: Throwable =>
- activeQueriesLock.synchronized {
- activeQueries -= query.id
- }
+ unregisterTerminatedStream(query.id)
throw e
}
query
@@ -380,9 +380,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
/** Notify (by the StreamingQuery) that the query has been terminated */
private[sql] def notifyQueryTermination(terminatedQuery: StreamingQuery): Unit = {
- activeQueriesLock.synchronized {
- activeQueries -= terminatedQuery.id
- }
+ unregisterTerminatedStream(terminatedQuery.id)
awaitTerminationLock.synchronized {
if (lastTerminatedQuery == null || terminatedQuery.exception.nonEmpty) {
lastTerminatedQuery = terminatedQuery
@@ -391,4 +389,12 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
}
stateStoreCoordinator.deactivateInstances(terminatedQuery.runId)
}
+
+ private def unregisterTerminatedStream(terminatedQueryId: UUID): Unit = {
+ activeQueriesLock.synchronized {
+ // remove from shared state only if the streaming query manager also matches
+ sparkSession.sharedState.activeStreamingQueries.remove(terminatedQueryId, this)
+ activeQueries -= terminatedQueryId
+ }
+ }
}
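unregisterTerminatedStream uses the two-argument ConcurrentHashMap.remove(key, value), so a manager that lost the registration race cannot evict the entry held by the winning manager. A small self-contained illustration of that property:

    import java.util.UUID
    import java.util.concurrent.ConcurrentHashMap

    // Why the conditional remove matters (illustrative only): a manager that
    // failed to register must not evict the entry owned by the winner.
    object ConditionalRemoveDemo {
      def main(args: Array[String]): Unit = {
        val map = new ConcurrentHashMap[UUID, String]()
        val id = UUID.randomUUID()

        map.putIfAbsent(id, "managerA")    // managerA wins the registration
        map.putIfAbsent(id, "managerB")    // managerB loses; entry unchanged

        map.remove(id, "managerB")         // no-op: value does not match
        assert(map.get(id) == "managerA")  // managerA's registration survives

        map.remove(id, "managerA")         // managerA cleans up on termination
        assert(map.isEmpty)
      }
    }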
diff --git a/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql b/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql
index 993eecf0f89b6..5e665e4c0c384 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql
@@ -37,3 +37,34 @@ select bit_count(-9223372036854775808L);
-- other illegal arguments
select bit_count("bit count");
select bit_count('a');
+
+-- test for bit_xor
+--
+CREATE OR REPLACE TEMPORARY VIEW bitwise_test AS SELECT * FROM VALUES
+ (1, 1, 1, 1L),
+ (2, 3, 4, null),
+ (7, 7, 7, 3L) AS bitwise_test(b1, b2, b3, b4);
+
+-- empty case
+SELECT BIT_XOR(b3) AS n1 FROM bitwise_test where 1 = 0;
+
+-- null case
+SELECT BIT_XOR(b4) AS n1 FROM bitwise_test where b4 is null;
+
+-- the numeric suffix in each column alias shows the expected answer
+SELECT
+ BIT_XOR(cast(b1 as tinyint)) AS a4,
+ BIT_XOR(cast(b2 as smallint)) AS b5,
+ BIT_XOR(b3) AS c2,
+ BIT_XOR(b4) AS d2,
+ BIT_XOR(distinct b4) AS e2
+FROM bitwise_test;
+
+-- group by
+SELECT bit_xor(b3) FROM bitwise_test GROUP BY b1 & 1;
+
+-- having
+SELECT b1, bit_xor(b2) FROM bitwise_test GROUP BY b1 HAVING bit_and(b2) < 7;
+
+-- window
+SELECT b1, b2, bit_xor(b2) OVER (PARTITION BY b1 ORDER BY b2) FROM bitwise_test;
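The aliases in the multi-column query encode the expected answers (a4 -> 4, b5 -> 5, and so on). Those values can be cross-checked with plain XOR folds outside SQL; a small Scala sketch over the same rows, dropping nulls the way the aggregate ignores them:

    val b1 = Seq(1, 2, 7)
    val b2 = Seq(1, 3, 7)
    val b3 = Seq(1, 4, 7)
    val b4 = Seq(Some(1L), None, Some(3L))

    println(b1.reduce(_ ^ _))          // 4 -> a4
    println(b2.reduce(_ ^ _))          // 5 -> b5
    println(b3.reduce(_ ^ _))          // 2 -> c2
    println(b4.flatten.reduce(_ ^ _))  // 2 -> d2, and e2 too: the distinct values are 1 and 3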
diff --git a/sql/core/src/test/resources/sql-tests/inputs/literals.sql b/sql/core/src/test/resources/sql-tests/inputs/literals.sql
index 816386c483209..0f95f85237828 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/literals.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/literals.sql
@@ -85,6 +85,19 @@ select timestamp '2016-33-11 20:54:00.000';
-- interval
select interval 13.123456789 seconds, interval -13.123456789 second;
select interval 1 year 2 month 3 week 4 day 5 hour 6 minute 7 seconds 8 millisecond, 9 microsecond;
+select interval '30' year '25' month '-100' day '40' hour '80' minute '299.889987299' second;
+select interval '0 0:0:0.1' day to second;
+select interval '10-9' year to month;
+select interval '20 15:40:32.99899999' day to hour;
+select interval '20 15:40:32.99899999' day to minute;
+select interval '20 15:40:32.99899999' day to second;
+select interval '15:40:32.99899999' hour to minute;
+select interval '15:40.99899999' hour to second;
+select interval '15:40' hour to second;
+select interval '15:40:32.99899999' hour to second;
+select interval '20 40:32.99899999' minute to second;
+select interval '40:32.99899999' minute to second;
+select interval '40:32' minute to second;
-- ns is not supported
select interval 10 nanoseconds;
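The new cases exercise the ANSI `interval '<value>' <start> to <end>` literal forms. One of them run through spark.sql as a quick sketch (the alias is added here for readability; the expected rendering is the one recorded in literals.sql.out further below):

    // Sketch: the '10-9' year-to-month literal parses to 10 years 9 months.
    val df = spark.sql("select interval '10-9' year to month as ym")
    df.show(truncate = false)
    // expected single value: interval 10 years 9 months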
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part1.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part1.sql
new file mode 100644
index 0000000000000..ae2a015ada245
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part1.sql
@@ -0,0 +1,352 @@
+-- Portions Copyright (c) 1996-2019, PostgreSQL Global Development Group
+--
+-- Window Functions Testing
+-- https://github.com/postgres/postgres/blob/REL_12_STABLE/src/test/regress/sql/window.sql#L1-L319
+
+CREATE TEMPORARY VIEW tenk2 AS SELECT * FROM tenk1;
+
+-- [SPARK-29540] Thrift in some cases can't parse string to date
+-- CREATE TABLE empsalary (
+-- depname string,
+-- empno integer,
+-- salary int,
+-- enroll_date date
+-- ) USING parquet;
+
+-- [SPARK-29540] Thrift in some cases can't parse string to date
+-- INSERT INTO empsalary VALUES ('develop', 10, 5200, '2007-08-01');
+-- INSERT INTO empsalary VALUES ('sales', 1, 5000, '2006-10-01');
+-- INSERT INTO empsalary VALUES ('personnel', 5, 3500, '2007-12-10');
+-- INSERT INTO empsalary VALUES ('sales', 4, 4800, '2007-08-08');
+-- INSERT INTO empsalary VALUES ('personnel', 2, 3900, '2006-12-23');
+-- INSERT INTO empsalary VALUES ('develop', 7, 4200, '2008-01-01');
+-- INSERT INTO empsalary VALUES ('develop', 9, 4500, '2008-01-01');
+-- INSERT INTO empsalary VALUES ('sales', 3, 4800, '2007-08-01');
+-- INSERT INTO empsalary VALUES ('develop', 8, 6000, '2006-10-01');
+-- INSERT INTO empsalary VALUES ('develop', 11, 5200, '2007-08-15');
+
+-- [SPARK-29540] Thrift in some cases can't parse string to date
+-- SELECT depname, empno, salary, sum(salary) OVER (PARTITION BY depname) FROM empsalary ORDER BY depname, salary;
+
+-- [SPARK-29540] Thrift in some cases can't parse string to date
+-- SELECT depname, empno, salary, rank() OVER (PARTITION BY depname ORDER BY salary) FROM empsalary;
+
+-- with GROUP BY
+SELECT four, ten, SUM(SUM(four)) OVER (PARTITION BY four), AVG(ten) FROM tenk1
+GROUP BY four, ten ORDER BY four, ten;
+
+-- [SPARK-29540] Thrift in some cases can't parse string to date
+-- SELECT depname, empno, salary, sum(salary) OVER w FROM empsalary WINDOW w AS (PARTITION BY depname);
+
+-- [SPARK-28064] Order by does not accept a call to rank()
+-- SELECT depname, empno, salary, rank() OVER w FROM empsalary WINDOW w AS (PARTITION BY depname ORDER BY salary) ORDER BY rank() OVER w;
+
+-- empty window specification
+SELECT COUNT(*) OVER () FROM tenk1 WHERE unique2 < 10;
+
+SELECT COUNT(*) OVER w FROM tenk1 WHERE unique2 < 10 WINDOW w AS ();
+
+-- no window operation
+SELECT four FROM tenk1 WHERE FALSE WINDOW w AS (PARTITION BY ten);
+
+-- cumulative aggregate
+SELECT sum(four) OVER (PARTITION BY ten ORDER BY unique2) AS sum_1, ten, four FROM tenk1 WHERE unique2 < 10;
+
+SELECT row_number() OVER (ORDER BY unique2) FROM tenk1 WHERE unique2 < 10;
+
+SELECT rank() OVER (PARTITION BY four ORDER BY ten) AS rank_1, ten, four FROM tenk1 WHERE unique2 < 10;
+
+SELECT dense_rank() OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10;
+
+SELECT percent_rank() OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10;
+
+SELECT cume_dist() OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10;
+
+SELECT ntile(3) OVER (ORDER BY ten, four), ten, four FROM tenk1 WHERE unique2 < 10;
+
+-- [SPARK-28065] ntile does not accept NULL as input
+-- SELECT ntile(NULL) OVER (ORDER BY ten, four), ten, four FROM tenk1 LIMIT 2;
+
+SELECT lag(ten) OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10;
+
+-- [SPARK-28068] `lag` second argument must be a literal in Spark
+-- SELECT lag(ten, four) OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10;
+
+-- [SPARK-28068] `lag` second argument must be a literal in Spark
+-- SELECT lag(ten, four, 0) OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10;
+
+SELECT lead(ten) OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10;
+
+SELECT lead(ten * 2, 1) OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10;
+
+SELECT lead(ten * 2, 1, -1) OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10;
+
+SELECT first(ten) OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10;
+
+-- last returns the last row of the frame, which is CURRENT ROW in an ORDER BY window.
+SELECT last(four) OVER (ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10;
+
+SELECT last(ten) OVER (PARTITION BY four), ten, four FROM
+(SELECT * FROM tenk1 WHERE unique2 < 10 ORDER BY four, ten)s
+ORDER BY four, ten;
+
+-- [SPARK-27951] ANSI SQL: NTH_VALUE function
+-- SELECT nth_value(ten, four + 1) OVER (PARTITION BY four), ten, four
+-- FROM (SELECT * FROM tenk1 WHERE unique2 < 10 ORDER BY four, ten)s;
+
+SELECT ten, two, sum(hundred) AS gsum, sum(sum(hundred)) OVER (PARTITION BY two ORDER BY ten) AS wsum
+FROM tenk1 GROUP BY ten, two;
+
+SELECT count(*) OVER (PARTITION BY four), four FROM (SELECT * FROM tenk1 WHERE two = 1)s WHERE unique2 < 10;
+
+SELECT (count(*) OVER (PARTITION BY four ORDER BY ten) +
+ sum(hundred) OVER (PARTITION BY four ORDER BY ten)) AS cntsum
+ FROM tenk1 WHERE unique2 < 10;
+
+-- opexpr with different windows evaluation.
+SELECT * FROM(
+ SELECT count(*) OVER (PARTITION BY four ORDER BY ten) +
+ sum(hundred) OVER (PARTITION BY two ORDER BY ten) AS total,
+ count(*) OVER (PARTITION BY four ORDER BY ten) AS fourcount,
+ sum(hundred) OVER (PARTITION BY two ORDER BY ten) AS twosum
+ FROM tenk1
+)sub WHERE total <> fourcount + twosum;
+
+SELECT avg(four) OVER (PARTITION BY four ORDER BY thousand / 100) FROM tenk1 WHERE unique2 < 10;
+
+SELECT ten, two, sum(hundred) AS gsum, sum(sum(hundred)) OVER win AS wsum
+FROM tenk1 GROUP BY ten, two WINDOW win AS (PARTITION BY two ORDER BY ten);
+
+-- [SPARK-29540] Thrift in some cases can't parse string to date
+-- more than one window with GROUP BY
+-- SELECT sum(salary),
+-- row_number() OVER (ORDER BY depname),
+-- sum(sum(salary)) OVER (ORDER BY depname DESC)
+-- FROM empsalary GROUP BY depname;
+
+-- [SPARK-29540] Thrift in some cases can't parse string to date
+-- identical windows with different names
+-- SELECT sum(salary) OVER w1, count(*) OVER w2
+-- FROM empsalary WINDOW w1 AS (ORDER BY salary), w2 AS (ORDER BY salary);
+
+-- subplan
+-- [SPARK-28379] Correlated scalar subqueries must be aggregated
+-- SELECT lead(ten, (SELECT two FROM tenk1 WHERE s.unique2 = unique2)) OVER (PARTITION BY four ORDER BY ten)
+-- FROM tenk1 s WHERE unique2 < 10;
+
+-- empty table
+SELECT count(*) OVER (PARTITION BY four) FROM (SELECT * FROM tenk1 WHERE FALSE)s;
+
+-- [SPARK-29540] Thrift in some cases can't parse string to date
+-- mixture of agg/wfunc in the same window
+-- SELECT sum(salary) OVER w, rank() OVER w FROM empsalary WINDOW w AS (PARTITION BY depname ORDER BY salary DESC);
+
+-- Cannot safely cast 'enroll_date': StringType to DateType;
+-- SELECT empno, depname, salary, bonus, depadj, MIN(bonus) OVER (ORDER BY empno), MAX(depadj) OVER () FROM(
+-- SELECT *,
+-- CASE WHEN enroll_date < '2008-01-01' THEN 2008 - extract(year FROM enroll_date) END * 500 AS bonus,
+-- CASE WHEN
+-- AVG(salary) OVER (PARTITION BY depname) < salary
+-- THEN 200 END AS depadj FROM empsalary
+-- )s;
+
+create temporary view int4_tbl as select * from values
+ (0),
+ (123456),
+ (-123456),
+ (2147483647),
+ (-2147483647)
+ as int4_tbl(f1);
+
+-- window function over ungrouped agg over empty row set (bug before 9.1)
+SELECT SUM(COUNT(f1)) OVER () FROM int4_tbl WHERE f1=42;
+
+-- window function with ORDER BY an expression involving aggregates (9.1 bug)
+select ten,
+ sum(unique1) + sum(unique2) as res,
+ rank() over (order by sum(unique1) + sum(unique2)) as rank
+from tenk1
+group by ten order by ten;
+
+-- window and aggregate with GROUP BY expression (9.2 bug)
+-- explain
+-- select first(max(x)) over (), y
+-- from (select unique1 as x, ten+four as y from tenk1) ss
+-- group by y;
+
+-- test non-default frame specifications
+SELECT four, ten,
+sum(ten) over (partition by four order by ten),
+last(ten) over (partition by four order by ten)
+FROM (select distinct ten, four from tenk1) ss;
+
+SELECT four, ten,
+sum(ten) over (partition by four order by ten range between unbounded preceding and current row),
+last(ten) over (partition by four order by ten range between unbounded preceding and current row)
+FROM (select distinct ten, four from tenk1) ss;
+
+SELECT four, ten,
+sum(ten) over (partition by four order by ten range between unbounded preceding and unbounded following),
+last(ten) over (partition by four order by ten range between unbounded preceding and unbounded following)
+FROM (select distinct ten, four from tenk1) ss;
+
+-- [SPARK-29451] Some queries with divisions in SQL windows are failing in Thrift
+-- SELECT four, ten/4 as two,
+-- sum(ten/4) over (partition by four order by ten/4 range between unbounded preceding and current row),
+-- last(ten/4) over (partition by four order by ten/4 range between unbounded preceding and current row)
+-- FROM (select distinct ten, four from tenk1) ss;
+
+-- [SPARK-29451] Some queries with divisions in SQL windows are failing in Thrift
+-- SELECT four, ten/4 as two,
+-- sum(ten/4) over (partition by four order by ten/4 rows between unbounded preceding and current row),
+-- last(ten/4) over (partition by four order by ten/4 rows between unbounded preceding and current row)
+-- FROM (select distinct ten, four from tenk1) ss;
+
+SELECT sum(unique1) over (order by four range between current row and unbounded following),
+unique1, four
+FROM tenk1 WHERE unique1 < 10;
+
+SELECT sum(unique1) over (rows between current row and unbounded following),
+unique1, four
+FROM tenk1 WHERE unique1 < 10;
+
+SELECT sum(unique1) over (rows between 2 preceding and 2 following),
+unique1, four
+FROM tenk1 WHERE unique1 < 10;
+
+-- [SPARK-28428] Spark `exclude` always expecting `()`
+-- SELECT sum(unique1) over (rows between 2 preceding and 2 following exclude no others),
+-- unique1, four
+-- FROM tenk1 WHERE unique1 < 10;
+
+-- [SPARK-28428] Spark `exclude` always expecting `()`
+-- SELECT sum(unique1) over (rows between 2 preceding and 2 following exclude current row),
+-- unique1, four
+-- FROM tenk1 WHERE unique1 < 10;
+
+-- [SPARK-28428] Spark `exclude` always expecting `()`
+-- SELECT sum(unique1) over (rows between 2 preceding and 2 following exclude group),
+-- unique1, four
+-- FROM tenk1 WHERE unique1 < 10;
+
+-- [SPARK-28428] Spark `exclude` always expecting `()`
+-- SELECT sum(unique1) over (rows between 2 preceding and 2 following exclude ties),
+-- unique1, four
+-- FROM tenk1 WHERE unique1 < 10;
+
+-- [SPARK-28428] Spark `exclude` always expecting `()`
+-- SELECT first(unique1) over (ORDER BY four rows between current row and 2 following exclude current row),
+-- unique1, four
+-- FROM tenk1 WHERE unique1 < 10;
+
+-- [SPARK-28428] Spark `exclude` always expecting `()`
+-- SELECT first(unique1) over (ORDER BY four rows between current row and 2 following exclude group),
+-- unique1, four
+-- FROM tenk1 WHERE unique1 < 10;
+
+-- [SPARK-28428] Spark `exclude` always expecting `()`
+-- SELECT first(unique1) over (ORDER BY four rows between current row and 2 following exclude ties),
+-- unique1, four
+-- FROM tenk1 WHERE unique1 < 10;
+
+-- [SPARK-28428] Spark `exclude` always expecting `()`
+-- SELECT last(unique1) over (ORDER BY four rows between current row and 2 following exclude current row),
+-- unique1, four
+-- FROM tenk1 WHERE unique1 < 10;
+
+-- [SPARK-28428] Spark `exclude` always expecting `()`
+-- SELECT last(unique1) over (ORDER BY four rows between current row and 2 following exclude group),
+-- unique1, four
+-- FROM tenk1 WHERE unique1 < 10;
+
+-- [SPARK-28428] Spark `exclude` always expecting `()`
+-- SELECT last(unique1) over (ORDER BY four rows between current row and 2 following exclude ties),
+-- unique1, four
+-- FROM tenk1 WHERE unique1 < 10;
+
+SELECT sum(unique1) over (rows between 2 preceding and 1 preceding),
+unique1, four
+FROM tenk1 WHERE unique1 < 10;
+
+SELECT sum(unique1) over (rows between 1 following and 3 following),
+unique1, four
+FROM tenk1 WHERE unique1 < 10;
+
+SELECT sum(unique1) over (rows between unbounded preceding and 1 following),
+unique1, four
+FROM tenk1 WHERE unique1 < 10;
+
+-- [SPARK-28428] Spark `exclude` always expecting `()`
+-- SELECT sum(unique1) over (w range between current row and unbounded following),
+-- unique1, four
+-- FROM tenk1 WHERE unique1 < 10 WINDOW w AS (order by four);
+
+-- [SPARK-28428] Spark `exclude` always expecting `()`
+-- SELECT sum(unique1) over (w range between unbounded preceding and current row exclude current row),
+-- unique1, four
+-- FROM tenk1 WHERE unique1 < 10 WINDOW w AS (order by four);
+
+-- [SPARK-28428] Spark `exclude` always expecting `()`
+-- SELECT sum(unique1) over (w range between unbounded preceding and current row exclude group),
+-- unique1, four
+-- FROM tenk1 WHERE unique1 < 10 WINDOW w AS (order by four);
+
+-- [SPARK-28428] Spark `exclude` always expecting `()`
+-- SELECT sum(unique1) over (w range between unbounded preceding and current row exclude ties),
+-- unique1, four
+-- FROM tenk1 WHERE unique1 < 10 WINDOW w AS (order by four);
+
+-- [SPARK-27951] ANSI SQL: NTH_VALUE function
+-- SELECT first_value(unique1) over w,
+-- nth_value(unique1, 2) over w AS nth_2,
+-- last_value(unique1) over w, unique1, four
+-- FROM tenk1 WHERE unique1 < 10
+-- WINDOW w AS (order by four range between current row and unbounded following);
+
+-- [SPARK-28501] Frame bound value must be a literal.
+-- SELECT sum(unique1) over
+-- (order by unique1
+-- rows (SELECT unique1 FROM tenk1 ORDER BY unique1 LIMIT 1) + 1 PRECEDING),
+-- unique1
+-- FROM tenk1 WHERE unique1 < 10;
+
+CREATE TEMP VIEW v_window AS
+SELECT i.id, sum(i.id) over (order by i.id rows between 1 preceding and 1 following) as sum_rows
+FROM range(1, 11) i;
+
+SELECT * FROM v_window;
+
+-- [SPARK-28428] Spark `exclude` always expecting `()`
+-- CREATE OR REPLACE TEMP VIEW v_window AS
+-- SELECT i, sum(i) over (order by i rows between 1 preceding and 1 following
+-- exclude current row) as sum_rows FROM range(1, 10) i;
+
+-- SELECT * FROM v_window;
+
+-- [SPARK-28428] Spark `exclude` always expecting `()`
+-- CREATE OR REPLACE TEMP VIEW v_window AS
+-- SELECT i, sum(i) over (order by i rows between 1 preceding and 1 following
+-- exclude group) as sum_rows FROM range(1, 10) i;
+-- SELECT * FROM v_window;
+
+-- [SPARK-28428] Spark `exclude` always expecting `()`
+-- CREATE OR REPLACE TEMP VIEW v_window AS
+-- SELECT i, sum(i) over (order by i rows between 1 preceding and 1 following
+-- exclude ties) as sum_rows FROM generate_series(1, 10) i;
+
+-- [SPARK-28428] Spark `exclude` always expecting `()`
+-- CREATE OR REPLACE TEMP VIEW v_window AS
+-- SELECT i, sum(i) over (order by i rows between 1 preceding and 1 following
+-- exclude no others) as sum_rows FROM generate_series(1, 10) i;
+-- SELECT * FROM v_window;
+
+-- [SPARK-28648] Adds support to `groups` unit type in window clauses
+-- CREATE OR REPLACE TEMP VIEW v_window AS
+-- SELECT i.id, sum(i.id) over (order by i.id groups between 1 preceding and 1 following) as sum_rows FROM range(1, 11) i;
+-- SELECT * FROM v_window;
+
+DROP VIEW v_window;
+-- [SPARK-29540] Thrift in some cases can't parse string to date
+-- DROP TABLE empsalary;
+DROP VIEW tenk2;
+DROP VIEW int4_tbl;
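Several of the frame specifications above also have direct DataFrame-API equivalents; a sketch of the `range between unbounded preceding and current row` case, assuming `tenk1` is registered in the session the same way the test harness registers it:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{last, sum}
    import spark.implicits._

    val distinctTenFour = spark.table("tenk1").select("ten", "four").distinct()

    // Cumulative frame per partition, mirroring the SQL above.
    val w = Window.partitionBy("four").orderBy("ten")
      .rangeBetween(Window.unboundedPreceding, Window.currentRow)

    distinctTenFour
      .select(
        $"four",
        $"ten",
        sum("ten").over(w).as("sum_ten"),
        last("ten").over(w).as("last_ten"))
      .show()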
diff --git a/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out b/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out
index 7cbd26e87bd2b..42c22a317eb46 100644
--- a/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 20
+-- Number of queries: 27
-- !query 0
@@ -162,3 +162,72 @@ struct<>
-- !query 19 output
org.apache.spark.sql.AnalysisException
cannot resolve 'bit_count('a')' due to data type mismatch: argument 1 requires (integral or boolean) type, however, ''a'' is of string type.; line 1 pos 7
+
+
+-- !query 20
+CREATE OR REPLACE TEMPORARY VIEW bitwise_test AS SELECT * FROM VALUES
+ (1, 1, 1, 1L),
+ (2, 3, 4, null),
+ (7, 7, 7, 3L) AS bitwise_test(b1, b2, b3, b4)
+-- !query 20 schema
+struct<>
+-- !query 20 output
+
+
+
+-- !query 21
+SELECT BIT_XOR(b3) AS n1 FROM bitwise_test where 1 = 0
+-- !query 21 schema
+struct
+-- !query 21 output
+NULL
+
+
+-- !query 22
+SELECT BIT_XOR(b4) AS n1 FROM bitwise_test where b4 is null
+-- !query 22 schema
+struct
+-- !query 22 output
+NULL
+
+
+-- !query 23
+SELECT
+ BIT_XOR(cast(b1 as tinyint)) AS a4,
+ BIT_XOR(cast(b2 as smallint)) AS b5,
+ BIT_XOR(b3) AS c2,
+ BIT_XOR(b4) AS d2,
+ BIT_XOR(distinct b4) AS e2
+FROM bitwise_test
+-- !query 23 schema
+struct
+-- !query 23 output
+4 5 2 2 2
+
+
+-- !query 24
+SELECT bit_xor(b3) FROM bitwise_test GROUP BY b1 & 1
+-- !query 24 schema
+struct
+-- !query 24 output
+4
+6
+
+
+-- !query 25
+SELECT b1, bit_xor(b2) FROM bitwise_test GROUP BY b1 HAVING bit_and(b2) < 7
+-- !query 25 schema
+struct
+-- !query 25 output
+1 1
+2 3
+
+
+-- !query 26
+SELECT b1, b2, bit_xor(b2) OVER (PARTITION BY b1 ORDER BY b2) FROM bitwise_test
+-- !query 26 schema
+struct
+-- !query 26 output
+1 1 1
+2 3 3
+7 7 7
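In query 26 the window has an ORDER BY, so the default frame is RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW; since every b1 value appears exactly once, the cumulative bit_xor degenerates to b2 itself, which is why the third column mirrors the second. A sketch reproducing this with a temporary view (the view name is illustrative):

    import spark.implicits._

    Seq((1, 1), (2, 3), (7, 7)).toDF("b1", "b2").createOrReplaceTempView("bitwise_demo")

    spark.sql(
      """SELECT b1, b2, bit_xor(b2) OVER (PARTITION BY b1 ORDER BY b2) AS x
        |FROM bitwise_demo""".stripMargin).show()
    // Each partition holds one row, so x equals b2: (1,1,1), (2,3,3), (7,7,7).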
diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out
index aef23963da374..fd6e51b2385de 100644
--- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 46
+-- Number of queries: 59
-- !query 0
@@ -337,10 +337,114 @@ interval 1 years 2 months 3 weeks 4 days 5 hours 6 minutes 7 seconds 8 milliseco
-- !query 36
-select interval 10 nanoseconds
+select interval '30' year '25' month '-100' day '40' hour '80' minute '299.889987299' second
-- !query 36 schema
-struct<>
+struct
-- !query 36 output
+interval 32 years 1 months -14 weeks -6 hours -35 minutes -110 milliseconds -13 microseconds
+
+
+-- !query 37
+select interval '0 0:0:0.1' day to second
+-- !query 37 schema
+struct
+-- !query 37 output
+interval 100 milliseconds
+
+
+-- !query 38
+select interval '10-9' year to month
+-- !query 38 schema
+struct
+-- !query 38 output
+interval 10 years 9 months
+
+
+-- !query 39
+select interval '20 15:40:32.99899999' day to hour
+-- !query 39 schema
+struct
+-- !query 39 output
+interval 2 weeks 6 days 15 hours
+
+
+-- !query 40
+select interval '20 15:40:32.99899999' day to minute
+-- !query 40 schema
+struct
+-- !query 40 output
+interval 2 weeks 6 days 15 hours 40 minutes
+
+
+-- !query 41
+select interval '20 15:40:32.99899999' day to second
+-- !query 41 schema
+struct
+-- !query 41 output
+interval 2 weeks 6 days 15 hours 40 minutes 32 seconds 998 milliseconds 999 microseconds
+
+
+-- !query 42
+select interval '15:40:32.99899999' hour to minute
+-- !query 42 schema
+struct
+-- !query 42 output
+interval 15 hours 40 minutes
+
+
+-- !query 43
+select interval '15:40.99899999' hour to second
+-- !query 43 schema
+struct
+-- !query 43 output
+interval 15 minutes 40 seconds 998 milliseconds 999 microseconds
+
+
+-- !query 44
+select interval '15:40' hour to second
+-- !query 44 schema
+struct
+-- !query 44 output
+interval 15 hours 40 minutes
+
+
+-- !query 45
+select interval '15:40:32.99899999' hour to second
+-- !query 45 schema
+struct
+-- !query 45 output
+interval 15 hours 40 minutes 32 seconds 998 milliseconds 999 microseconds
+
+
+-- !query 46
+select interval '20 40:32.99899999' minute to second
+-- !query 46 schema
+struct
+-- !query 46 output
+interval 2 weeks 6 days 40 minutes 32 seconds 998 milliseconds 999 microseconds
+
+
+-- !query 47
+select interval '40:32.99899999' minute to second
+-- !query 47 schema
+struct
+-- !query 47 output
+interval 40 minutes 32 seconds 998 milliseconds 999 microseconds
+
+
+-- !query 48
+select interval '40:32' minute to second
+-- !query 48 schema
+struct
+-- !query 48 output
+interval 40 minutes 32 seconds
+
+
+-- !query 49
+select interval 10 nanoseconds
+-- !query 49 schema
+struct<>
+-- !query 49 output
org.apache.spark.sql.catalyst.parser.ParseException
no viable alternative at input 'interval 10 nanoseconds'(line 1, pos 19)
@@ -350,11 +454,11 @@ select interval 10 nanoseconds
-------------------^^^
--- !query 37
+-- !query 50
select GEO '(10,-6)'
--- !query 37 schema
+-- !query 50 schema
struct<>
--- !query 37 output
+-- !query 50 output
org.apache.spark.sql.catalyst.parser.ParseException
Literals of type 'GEO' are currently not supported.(line 1, pos 7)
@@ -364,19 +468,19 @@ select GEO '(10,-6)'
-------^^^
--- !query 38
+-- !query 51
select 90912830918230182310293801923652346786BD, 123.0E-28BD, 123.08BD
--- !query 38 schema
+-- !query 51 schema
struct<90912830918230182310293801923652346786:decimal(38,0),1.230E-26:decimal(29,29),123.08:decimal(5,2)>
--- !query 38 output
+-- !query 51 output
90912830918230182310293801923652346786 0.0000000000000000000000000123 123.08
--- !query 39
+-- !query 52
select 1.20E-38BD
--- !query 39 schema
+-- !query 52 schema
struct<>
--- !query 39 output
+-- !query 52 output
org.apache.spark.sql.catalyst.parser.ParseException
decimal can only support precision up to 38(line 1, pos 7)
@@ -386,19 +490,19 @@ select 1.20E-38BD
-------^^^
--- !query 40
+-- !query 53
select x'2379ACFe'
--- !query 40 schema
+-- !query 53 schema
struct
--- !query 40 output
+-- !query 53 output
#y��
--- !query 41
+-- !query 54
select X'XuZ'
--- !query 41 schema
+-- !query 54 schema
struct<>
--- !query 41 output
+-- !query 54 output
org.apache.spark.sql.catalyst.parser.ParseException
contains illegal character for hexBinary: 0XuZ(line 1, pos 7)
@@ -408,33 +512,33 @@ select X'XuZ'
-------^^^
--- !query 42
+-- !query 55
SELECT 3.14, -3.14, 3.14e8, 3.14e-8, -3.14e8, -3.14e-8, 3.14e+8, 3.14E8, 3.14E-8
--- !query 42 schema
+-- !query 55 schema
struct<3.14:decimal(3,2),-3.14:decimal(3,2),3.14E+8:decimal(3,-6),3.14E-8:decimal(10,10),-3.14E+8:decimal(3,-6),-3.14E-8:decimal(10,10),3.14E+8:decimal(3,-6),3.14E+8:decimal(3,-6),3.14E-8:decimal(10,10)>
--- !query 42 output
+-- !query 55 output
3.14 -3.14 314000000 0.0000000314 -314000000 -0.0000000314 314000000 314000000 0.0000000314
--- !query 43
+-- !query 56
select map(1, interval 1 day, 2, interval 3 week)
--- !query 43 schema
+-- !query 56 schema
struct