diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java index 28fb64f7cd0e..d4a6d7f95763 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -32,25 +32,6 @@ public final class CalendarInterval implements Serializable { public static final long MICROS_PER_DAY = MICROS_PER_HOUR * 24; public static final long MICROS_PER_WEEK = MICROS_PER_DAY * 7; - /** - * A function to generate regex which matches interval string's unit part like "3 years". - * - * First, we can leave out some units in interval string, and we only care about the value of - * unit, so here we use non-capturing group to wrap the actual regex. - * At the beginning of the actual regex, we should match spaces before the unit part. - * Next is the number part, starts with an optional "-" to represent negative value. We use - * capturing group to wrap this part as we need the value later. - * Finally is the unit name, ends with an optional "s". - */ - private static String unitRegex(String unit) { - return "(?:\\s+(-?\\d+)\\s+" + unit + "s?)?"; - } - - private static Pattern p = Pattern.compile("interval" + unitRegex("year") + unitRegex("month") + - unitRegex("week") + unitRegex("day") + unitRegex("hour") + unitRegex("minute") + - unitRegex("second") + unitRegex("millisecond") + unitRegex("microsecond"), - Pattern.CASE_INSENSITIVE); - private static Pattern yearMonthPattern = Pattern.compile("^(?:['|\"])?([+|-])?(\\d+)-(\\d+)(?:['|\"])?$"); @@ -59,14 +40,6 @@ private static String unitRegex(String unit) { private static Pattern quoteTrimPattern = Pattern.compile("^(?:['|\"])?(.*?)(?:['|\"])?$"); - private static long toLong(String s) { - if (s == null) { - return 0; - } else { - return Long.parseLong(s); - } - } - /** * Convert a string to CalendarInterval. Return null if the input string is not a valid interval. * This method is case-insensitive. @@ -79,6 +52,9 @@ public static CalendarInterval fromString(String s) { } } + private enum ParsingState {START, UNIT_VALUE, UNIT_NAME} + private enum Unit {YEAR, MONTH, WEEK, DAY, HOUR, MINUTE, SECOND, MILLISECOND, MICROSECOND} + /** * Convert a string to CalendarInterval. This method can handle * strings without the `interval` prefix and throws IllegalArgumentException @@ -94,31 +70,88 @@ public static CalendarInterval fromCaseInsensitiveString(String s) { if (trimmed.isEmpty()) { throw new IllegalArgumentException("Interval cannot be blank"); } - String prefix = "interval"; - String intervalStr = trimmed; - // Checks the given interval string does not start with the `interval` prefix - if (!intervalStr.regionMatches(true, 0, prefix, 0, prefix.length())) { - // Prepend `interval` if it does not present because - // the regular expression strictly require it. - intervalStr = prefix + " " + trimmed; - } else if (intervalStr.length() == prefix.length()) { - throw new IllegalArgumentException("Interval string must have time units"); + String[] splits = trimmed.split("\\s+"); + ParsingState state = ParsingState.START; + long currentValue = 0; + boolean[] parsedUnits = new boolean[Unit.values().length]; + long months = 0; + long microseconds = 0; + + for (String split: splits) { + switch (state) { + case START: + if (split.equalsIgnoreCase("interval")) { + if (splits.length == 1) { + throw new IllegalArgumentException("Interval string must have time units"); + } + state = ParsingState.UNIT_VALUE; + break; + } + // falls through + case UNIT_VALUE: + try { + currentValue = Long.parseLong(split); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid interval unit value: " + split); + } + state = ParsingState.UNIT_NAME; + break; + case UNIT_NAME: + String upsplit = split.toUpperCase(); + if (upsplit.endsWith("S")) { + upsplit = upsplit.substring(0, upsplit.length() - 1); + } + Unit currentUnit; + try { + currentUnit = Unit.valueOf(upsplit); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Invalid interval unit name: " + split); + } + int index = currentUnit.ordinal(); + if (parsedUnits[index]) { + throw new IllegalArgumentException("Interval units must be unique"); + } else { + parsedUnits[index] = true; + } + switch (currentUnit) { + case YEAR: + months += 12 * currentValue; + break; + case MONTH: + months += currentValue; + break; + case WEEK: + microseconds += currentValue * MICROS_PER_WEEK; + break; + case DAY: + microseconds += currentValue * MICROS_PER_DAY; + break; + case HOUR: + microseconds += currentValue * MICROS_PER_HOUR; + break; + case MINUTE: + microseconds += currentValue * MICROS_PER_MINUTE; + break; + case SECOND: + microseconds += currentValue * MICROS_PER_SECOND; + break; + case MILLISECOND: + microseconds += currentValue * MICROS_PER_MILLI; + break; + case MICROSECOND: + microseconds += currentValue; + break; + } + state = ParsingState.UNIT_VALUE; + break; + } } - Matcher m = p.matcher(intervalStr); - if (!m.matches()) { + if (state != ParsingState.UNIT_VALUE) { throw new IllegalArgumentException("Invalid interval: " + s); } - long months = toLong(m.group(1)) * 12 + toLong(m.group(2)); - long microseconds = toLong(m.group(3)) * MICROS_PER_WEEK; - microseconds += toLong(m.group(4)) * MICROS_PER_DAY; - microseconds += toLong(m.group(5)) * MICROS_PER_HOUR; - microseconds += toLong(m.group(6)) * MICROS_PER_MINUTE; - microseconds += toLong(m.group(7)) * MICROS_PER_SECOND; - microseconds += toLong(m.group(8)) * MICROS_PER_MILLI; - microseconds += toLong(m.group(9)); - return new CalendarInterval((int) months, microseconds); + return new CalendarInterval(Math.toIntExact(months), microseconds); } public static long toLongWithRange(String fieldName, diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java index 587071332ce4..8e163d713e74 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java @@ -113,7 +113,7 @@ public void fromCaseInsensitiveStringTest() { } } - for (String input : new String[]{"interval", "interval1 day", "foo", "foo 1 day"}) { + for (String input : new String[]{"interval", "interval1 day", "foo", "foo 1 day", "1 dday"}) { try { fromCaseInsensitiveString(input); fail("Expected to throw an exception for the invalid input"); @@ -122,7 +122,7 @@ public void fromCaseInsensitiveStringTest() { if (input.trim().equalsIgnoreCase("interval")) { assertTrue(msg.contains("Interval string must have time units")); } else { - assertTrue(msg.contains("Invalid interval:")); + assertTrue(msg.contains("Invalid interval")); } } } @@ -297,4 +297,43 @@ public void fromStringCaseSensitivityTest() { assertNull(fromString("INTERVAL")); assertNull(fromString(" Interval ")); } + + @Test + public void uniqueUnitTest() { + String[] inputs = new String[]{ + "1 year 2 years", + "2 months 1 month", + "interval 1 month 2 weeks 1 day 1 week", + "interval 1 day 2 weeks 3 days", + " 1 hour 1 hour", + "6 minutes 1 Minute ", + "7 SECONDS 1 Second", + "3 MilliSECONDS 1 MilliSecond", + "8 microseconds 10 MICROSECONDS" + }; + for (String input : inputs) { + try { + fromCaseInsensitiveString(input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("Interval units must be unique")); + } + } + } + + @Test + public void unorderedUnitsTest() { + Arrays.asList( + "interval 23 month -5 years ", + "23 month -5 years").forEach(input -> + assertEquals(fromString(input), new CalendarInterval(-5 * 12 + 23, 0)) + ); + Arrays.asList( + "interval 1 microsecond 2 milliseconds 3 seconds 4 minutes 5 hours", + "1 microsecond 2 milliseconds 3 seconds 4 minutes 5 hours", + "1 microsecond 5 hours 2 milliseconds 4 minutes 3 seconds ").forEach(input -> + assertEquals(fromString(input), + fromString("5 hours 4 minutes 3 seconds 2 milliseconds 1 microseconds")) + ); + } }