diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index f8289c1fdcda..41c511a11048 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -993,7 +993,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { return null; } - int numInputBytes = 0; // total number of bytes from the inputs + long numInputBytes = 0L; // total number of bytes from the inputs int numInputs = 0; // number of non-null inputs for (int i = 0; i < inputs.length; i++) { if (inputs[i] != null) { @@ -1009,7 +1009,8 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { // Allocate a new byte array, and copy the inputs one by one into it. // The size of the new array is the size of all inputs, plus the separators. - final byte[] result = new byte[numInputBytes + (numInputs - 1) * separator.numBytes]; + int resultSize = Math.toIntExact(numInputBytes + (numInputs - 1) * (long)separator.numBytes); + final byte[] result = new byte[resultSize]; int offset = 0; for (int i = 0, j = 0; i < inputs.length; i++) {