
Commit d5dc307

[SPARK-34432][SQL][TEST][FOLLOWUP] Add a java implementation of simple writable data source in DataSourceV2Suite
### What changes were proposed in this pull request?

This is a followup of #19269. In #19269 there is only a Scala implementation of the simple writable data source in `DataSourceV2Suite`. This PR adds a Java implementation of it.

### Why are the changes needed?

To improve test coverage.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing test suites.
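
For context, both the Scala and the new Java data source are driven through the ordinary DataFrame reader/writer API in `DataSourceV2Suite`. The snippet below is only a sketch of that round trip, loosely based on the test shown further down in this diff: the `spark.range(...)` input, the `append` save mode, and the final count assertion are illustrative, while `withTempPath` and the empty-read assertion come from the suite itself.

```scala
// Sketch: exercise both implementations through the same DataFrame API.
Seq(classOf[SimpleWritableDataSource], classOf[JavaSimpleWritableDataSource]).foreach { cls =>
  withTempPath { file =>
    val path = file.getCanonicalPath
    // An empty directory reads back as an empty result.
    assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty)
    // Write a small two-column (i: long, j: long) DataFrame, then read it back.
    spark.range(10).selectExpr("id AS i", "-id AS j")
      .write.format(cls.getName).option("path", path).mode("append").save()
    assert(spark.read.format(cls.getName).option("path", path).load().count() == 10)
  }
}
```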
1 parent 7ea3a33

2 files changed, +375 -2 lines changed

JavaSimpleWritableDataSource.java (new file, package test.org.apache.spark.sql.connector)

Lines changed: 374 additions & 0 deletions
@@ -0,0 +1,374 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package test.org.apache.spark.sql.connector;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.*;
import org.apache.spark.deploy.SparkHadoopUtil;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.connector.TestingV2Source;
import org.apache.spark.sql.connector.catalog.SessionConfigSupport;
import org.apache.spark.sql.connector.catalog.SupportsWrite;
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.read.InputPartition;
import org.apache.spark.sql.connector.read.PartitionReader;
import org.apache.spark.sql.connector.read.PartitionReaderFactory;
import org.apache.spark.sql.connector.read.ScanBuilder;
import org.apache.spark.sql.connector.write.*;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import org.apache.spark.util.SerializableConfiguration;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.Iterator;

public class JavaSimpleWritableDataSource implements TestingV2Source, SessionConfigSupport {

  private final StructType tableSchema = new StructType().add("i", "long").add("j", "long");

  @Override
  public String keyPrefix() {
    return "javaSimpleWritableDataSource";
  }

  @Override
  public Table getTable(CaseInsensitiveStringMap options) {
    return new MyTable(options);
  }
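
  // An InputPartition that only carries the path of one data file; the file itself is
  // opened and read by JavaCSVReaderFactory.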
  static class JavaCSVInputPartitionReader implements InputPartition {
    private String path;

    JavaCSVInputPartitionReader(String path) {
      this.path = path;
    }

    public String getPath() {
      return path;
    }

    public void setPath(String path) {
      this.path = path;
    }
  }
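
  // Creates a PartitionReader that opens the partition's file and parses each
  // comma-separated line into a two-column InternalRow of longs.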
  static class JavaCSVReaderFactory implements PartitionReaderFactory {

    private final SerializableConfiguration conf;

    JavaCSVReaderFactory(SerializableConfiguration conf) {
      this.conf = conf;
    }

    @Override
    public PartitionReader<InternalRow> createReader(InputPartition partition) {
      String path = ((JavaCSVInputPartitionReader) partition).getPath();
      Path filePath = new Path(path);
      try {
        FileSystem fs = filePath.getFileSystem(conf.value());
        return new PartitionReader<InternalRow>() {
          private final FSDataInputStream inputStream = fs.open(filePath);
          private final Iterator<String> lines =
              new BufferedReader(new InputStreamReader(inputStream)).lines().iterator();
          private String currentLine = "";

          @Override
          public boolean next() {
            if (lines.hasNext()) {
              currentLine = lines.next();
              return true;
            } else {
              return false;
            }
          }

          @Override
          public InternalRow get() {
            Object[] objects =
                Arrays.stream(currentLine.split(","))
                    .map(String::trim)
                    .map(Long::parseLong)
                    .toArray();
            return new GenericInternalRow(objects);
          }

          @Override
          public void close() throws IOException {
            inputStream.close();
          }
        };
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }
  }
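
  // Driver-side counter used by MyBatchWrite: reset when a batch write starts and
  // incremented once for every data writer whose commit message reaches the driver.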
  static class JavaSimpleCounter {
    private static Integer count = 0;

    public static void increaseCounter() {
      count += 1;
    }

    public static int getCounter() {
      return count;
    }

    public static void resetCounter() {
      count = 0;
    }
  }
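
  // Creates one JavaCSVDataWriter per write task, staging its output under
  // <path>/_temporary/<jobId>/<jobId>-<partitionId>-<taskId>.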
  static class JavaCSVDataWriterFactory implements DataWriterFactory {
    private final String path;
    private final String jobId;
    private final SerializableConfiguration conf;

    JavaCSVDataWriterFactory(String path, String jobId, SerializableConfiguration conf) {
      this.path = path;
      this.jobId = jobId;
      this.conf = conf;
    }

    @Override
    public DataWriter<InternalRow> createWriter(int partitionId, long taskId) {
      try {
        Path jobPath = new Path(new Path(path, "_temporary"), jobId);
        Path filePath = new Path(jobPath, String.format("%s-%d-%d", jobId, partitionId, taskId));
        FileSystem fs = filePath.getFileSystem(conf.value());
        return new JavaCSVDataWriter(fs, filePath);
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }
  }
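
  // Task-level writer: each row is written as an "i,j" text line; commit() closes the
  // output stream, abort() closes it and deletes the partially written staging file.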
  static class JavaCSVDataWriter implements DataWriter<InternalRow> {
    private final FileSystem fs;
    private final Path file;
    private final FSDataOutputStream out;

    JavaCSVDataWriter(FileSystem fs, Path file) throws IOException {
      this.fs = fs;
      this.file = file;
      out = fs.create(file);
    }

    @Override
    public void write(InternalRow record) throws IOException {
      out.writeBytes(String.format("%d,%d\n", record.getLong(0), record.getLong(1)));
    }

    @Override
    public WriterCommitMessage commit() throws IOException {
      out.close();
      return null;
    }

    @Override
    public void abort() throws IOException {
      try {
        out.close();
      } finally {
        fs.delete(file, false);
      }
    }

    @Override
    public void close() {}
  }
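
  // Plans one input partition per data file under the table path, skipping names that
  // start with "_" or "." (for example the _temporary staging directory).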
  class MyScanBuilder extends JavaSimpleScanBuilder {
    private final String path;
    private final Configuration conf;

    MyScanBuilder(String path, Configuration conf) {
      this.path = path;
      this.conf = conf;
    }

    @Override
    public InputPartition[] planInputPartitions() {
      Path dataPath = new Path(this.path);
      try {
        FileSystem fs = dataPath.getFileSystem(conf);
        if (fs.exists(dataPath)) {
          return Arrays.stream(fs.listStatus(dataPath))
              .filter(status -> {
                String name = status.getPath().getName();
                return !name.startsWith("_") && !name.startsWith(".");
              })
              .map(f -> new JavaCSVInputPartitionReader(f.getPath().toUri().toString()))
              .toArray(InputPartition[]::new);
        } else {
          return new InputPartition[0];
        }
      } catch (Exception e) {
        throw new RuntimeException(e);
      }
    }

    @Override
    public StructType readSchema() {
      return tableSchema;
    }

    @Override
    public PartitionReaderFactory createReaderFactory() {
      SerializableConfiguration serializableConf = new SerializableConfiguration(conf);
      return new JavaCSVReaderFactory(serializableConf);
    }
  }
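
  // Builds the logical write. Implementing SupportsTruncate lets an overwrite-style
  // write request that existing data be removed; the actual delete happens in
  // MyWrite.toBatch().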
  static class MyWriteBuilder implements WriteBuilder, SupportsTruncate {
    private final String path;
    private final String queryId;
    private boolean needTruncate = false;

    MyWriteBuilder(String path, LogicalWriteInfo info) {
      this.path = path;
      this.queryId = info.queryId();
    }

    @Override
    public WriteBuilder truncate() {
      this.needTruncate = true;
      return this;
    }

    @Override
    public Write build() {
      return new MyWrite(path, queryId, needTruncate);
    }
  }
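
  // Physical write: if truncate was requested by the builder, toBatch() deletes the
  // existing table path before handing the job off to MyBatchWrite.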
  static class MyWrite implements Write {
    private final String path;
    private final String queryId;
    private final boolean needTruncate;

    MyWrite(String path, String queryId, boolean needTruncate) {
      this.path = path;
      this.queryId = queryId;
      this.needTruncate = needTruncate;
    }

    @Override
    public BatchWrite toBatch() {
      Path hadoopPath = new Path(path);
      Configuration hadoopConf = SparkHadoopUtil.get().conf();
      try {
        FileSystem fs = hadoopPath.getFileSystem(hadoopConf);
        if (needTruncate) {
          fs.delete(hadoopPath, true);
        }
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
      String pathStr = hadoopPath.toUri().toString();
      return new MyBatchWrite(queryId, pathStr, hadoopConf);
    }
  }
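
  // Job-level commit protocol: commit() renames the staged task files from
  // _temporary/<queryId> into the final path and removes the staging directory;
  // abort() just deletes the staging directory.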
  static class MyBatchWrite implements BatchWrite {

    private final String queryId;
    private final String path;
    private final Configuration conf;

    MyBatchWrite(String queryId, String path, Configuration conf) {
      this.queryId = queryId;
      this.path = path;
      this.conf = conf;
    }

    @Override
    public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) {
      JavaSimpleCounter.resetCounter();
      return new JavaCSVDataWriterFactory(path, queryId, new SerializableConfiguration(conf));
    }

    @Override
    public void onDataWriterCommit(WriterCommitMessage message) {
      JavaSimpleCounter.increaseCounter();
    }

    @Override
    public void commit(WriterCommitMessage[] messages) {
      Path finalPath = new Path(this.path);
      Path jobPath = new Path(new Path(finalPath, "_temporary"), queryId);
      try {
        FileSystem fs = jobPath.getFileSystem(conf);
        FileStatus[] fileStatuses = fs.listStatus(jobPath);
        try {
          for (FileStatus status : fileStatuses) {
            Path file = status.getPath();
            Path dest = new Path(finalPath, file.getName());
            if (!fs.rename(file, dest)) {
              throw new IOException(String.format("failed to rename(%s, %s)", file, dest));
            }
          }
        } finally {
          fs.delete(jobPath, true);
        }
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }

    @Override
    public void abort(WriterCommitMessage[] messages) {
      try {
        Path jobPath = new Path(new Path(this.path, "_temporary"), queryId);
        FileSystem fs = jobPath.getFileSystem(conf);
        fs.delete(jobPath, true);
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }
  }
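
  // The Table returned by getTable(): batch-readable via MyScanBuilder and
  // batch-writable via MyWriteBuilder, with its location taken from the "path" option.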
  class MyTable extends JavaSimpleBatchTable implements SupportsWrite {
    private final String path;
    private final Configuration conf = SparkHadoopUtil.get().conf();

    MyTable(CaseInsensitiveStringMap options) {
      this.path = options.get("path");
    }

    @Override
    public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
      return new MyScanBuilder(new Path(path).toUri().toString(), conf);
    }

    @Override
    public WriteBuilder newWriteBuilder(LogicalWriteInfo info) {
      return new MyWriteBuilder(path, info);
    }

    @Override
    public StructType schema() {
      return tableSchema;
    }
  }
}

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala

Lines changed: 1 addition & 2 deletions
@@ -228,8 +228,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
   }
 
   test("simple writable data source") {
-    // TODO: java implementation.
-    Seq(classOf[SimpleWritableDataSource]).foreach { cls =>
+    Seq(classOf[SimpleWritableDataSource], classOf[JavaSimpleWritableDataSource]).foreach { cls =>
       withTempPath { file =>
         val path = file.getCanonicalPath
         assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty)
