Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Checkpoint
  • Loading branch information
akrambek committed Oct 18, 2024
commit 5d4414d415c846d310f6271868bd12e76c693179
2 changes: 1 addition & 1 deletion incubator/binding-pgsql-kafka/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
</licenses>

<properties>
<jacoco.coverage.ratio>0.82</jacoco.coverage.ratio>
<jacoco.coverage.ratio>0.86</jacoco.coverage.ratio>
<jacoco.missed.count>0</jacoco.missed.count>
</properties>

Expand Down
2 changes: 1 addition & 1 deletion incubator/binding-pgsql/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
</licenses>

<properties>
<jacoco.coverage.ratio>0.89</jacoco.coverage.ratio>
<jacoco.coverage.ratio>0.91</jacoco.coverage.ratio>
<jacoco.missed.count>0</jacoco.missed.count>
</properties>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1939,6 +1939,19 @@ createfunctionstmt
)? createfunc_opt_list
;

opt_type_parameters
: '<' type_parameters '>'
|
;

type_parameters
: type_parameter (COMMA type_parameter)*
;

type_parameter
: colid typename
;

opt_or_replace
: OR REPLACE
|
Expand Down Expand Up @@ -1996,7 +2009,6 @@ func_return

func_type
: typename
| SETOF? (builtin_function_name | type_function_name | LEFT | RIGHT) attrs PERCENT TYPE_P
;

func_arg_with_default
Expand Down Expand Up @@ -3400,7 +3412,7 @@ consttypename
;

generictype
: (builtin_function_name | type_function_name | LEFT | RIGHT) attrs? opt_type_modifiers
: (builtin_function_name | type_function_name | LEFT | RIGHT) attrs? opt_type_modifiers opt_type_parameters
;

opt_type_modifiers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public PgsqlParser()
this.commandListener = new SqlCommandListener();
this.createTableListener = new SqlCreateTableTopicListener();
this.createStreamListener = new SqlCreateStreamListener();
this.createFunctionListener = new SqlCreateFunctionListener();
this.createFunctionListener = new SqlCreateFunctionListener(tokens);
this.createMaterializedViewListener = new SqlCreateMaterializedViewListener(tokens);
this.dropListener = new SqlDropListener();
parser.setErrorHandler(errorStrategy);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ public void enterCreatematviewstmt(
public void enterCreatefunctionstmt(
PostgreSqlParser.CreatefunctionstmtContext ctx)
{
command = "CREATE FUNCTION";
String functionBody = ctx.getText();

if (!functionBody.contains("$$"))
{
command = "CREATE FUNCTION";
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import java.util.ArrayList;
import java.util.List;

import org.antlr.v4.runtime.TokenStream;

import io.aklivity.zilla.runtime.binding.pgsql.parser.PostgreSqlParser;
import io.aklivity.zilla.runtime.binding.pgsql.parser.PostgreSqlParserBaseListener;
import io.aklivity.zilla.runtime.binding.pgsql.parser.module.FunctionArgument;
Expand All @@ -27,11 +29,19 @@ public class SqlCreateFunctionListener extends PostgreSqlParserBaseListener
private final List<FunctionArgument> arguments = new ArrayList<>();
private final List<FunctionArgument> tables = new ArrayList<>();

private final TokenStream tokens;

private String name;
private String returnType;
private String asFunction;
private String language;

public SqlCreateFunctionListener(
TokenStream tokens)
{
this.tokens = tokens;
}

public FunctionInfo functionInfo()
{
return new FunctionInfo(name, arguments, returnType, tables, asFunction, language);
Expand Down Expand Up @@ -66,7 +76,7 @@ public void enterTable_func_column(
public void enterFunc_type(
PostgreSqlParser.Func_typeContext ctx)
{
returnType = ctx.typename().getText();
returnType = tokens.getText(ctx.typename());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,98 @@ public void shouldParseCreateFunctionWithLanguage()
assertEquals("python", functionInfo.language());
}

@Test
public void shouldParseCreateFunctionWithStructReturnType()
{
String sql = "CREATE FUNCTION test_function(int) RETURNS struct<key varchar, value varchar>" +
" LANGUAGE python AS 'test_function';";
FunctionInfo functionInfo = parser.parseCreateFunction(sql);
assertNotNull(functionInfo);
assertEquals("test_function", functionInfo.name());
assertEquals("struct<key varchar, value varchar>", functionInfo.returnType());
assertEquals("python", functionInfo.language());
}

@Test(expected = ParseCancellationException.class)
public void shouldHandleInvalidCreateFunction()
{
String sql = "CREATE FUNCTION test_function()";
parser.parseCreateFunction(sql);
}

@Test
public void shouldParseCreateTableWithUniqueConstraint()
{
String sql = "CREATE TABLE test (id INT UNIQUE, name VARCHAR(100));";
TableInfo tableInfo = parser.parseCreateTable(sql);
assertNotNull(tableInfo);
assertEquals(2, tableInfo.columns().size());
assertTrue(tableInfo.columns().containsKey("id"));
assertTrue(tableInfo.columns().containsKey("name"));
}

@Test
public void shouldParseCreateTableWithForeignKey()
{
String sql = "CREATE TABLE test (id INT, name VARCHAR(100), CONSTRAINT fk_name FOREIGN KEY (name)" +
" REFERENCES other_table(name));";
TableInfo tableInfo = parser.parseCreateTable(sql);
assertNotNull(tableInfo);
assertEquals(2, tableInfo.columns().size());
assertTrue(tableInfo.columns().containsKey("id"));
assertTrue(tableInfo.columns().containsKey("name"));
}

@Test
public void shouldParseCreateTableWithCheckConstraint()
{
String sql = "CREATE TABLE test (id INT, name VARCHAR(100), CHECK (id > 0));";
TableInfo tableInfo = parser.parseCreateTable(sql);
assertNotNull(tableInfo);
assertEquals(2, tableInfo.columns().size());
assertTrue(tableInfo.columns().containsKey("id"));
assertTrue(tableInfo.columns().containsKey("name"));
}

@Test
public void shouldHandleInvalidCreateTableWithMissingColumns()
{
String sql = "CREATE TABLE test ();";
parser.parseCreateTable(sql);
}

@Test
public void shouldParseCreateTableWithDefaultValues()
{
String sql = "CREATE TABLE test (id INT DEFAULT 0, name VARCHAR(100) DEFAULT 'unknown');";
TableInfo tableInfo = parser.parseCreateTable(sql);
assertNotNull(tableInfo);
assertEquals(2, tableInfo.columns().size());
assertEquals("INT", tableInfo.columns().get("id"));
assertEquals("VARCHAR(100)", tableInfo.columns().get("name"));
}

@Test
public void shouldParseCreateTableWithNotNullConstraint()
{
String sql = "CREATE TABLE test (id INT NOT NULL, name VARCHAR(100) NOT NULL);";
TableInfo tableInfo = parser.parseCreateTable(sql);
assertNotNull(tableInfo);
assertEquals(2, tableInfo.columns().size());
assertTrue(tableInfo.columns().containsKey("id"));
assertTrue(tableInfo.columns().containsKey("name"));
}

@Test
public void shouldParseCreateTableWithMultipleConstraints()
{
String sql = "CREATE TABLE test (id INT PRIMARY KEY, name VARCHAR(100) UNIQUE, age INT CHECK (age > 0));";
TableInfo tableInfo = parser.parseCreateTable(sql);
assertNotNull(tableInfo);
assertEquals(3, tableInfo.columns().size());
assertTrue(tableInfo.primaryKeys().contains("id"));
assertTrue(tableInfo.columns().containsKey("name"));
assertTrue(tableInfo.columns().containsKey("age"));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ write zilla:data.ext ${pgsql:dataEx()
.build()
.build()}
write "CREATE FUNCTION key_value(bytea)\n"
"RETURNS struct < key varchar , value varchar >\n"
"AS key_value\n"
"RETURNS struct<key varchar, value varchar>\n"
"AS 'key_value'\n"
"LANGUAGE python\n"
"USING LINK 'http://localhost:8816';"
[0x00]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ read zilla:data.ext ${pgsql:dataEx()
.build()
.build()}
read "CREATE FUNCTION key_value(bytea)\n"
"RETURNS struct < key varchar , value varchar >\n"
"AS key_value\n"
"RETURNS struct<key varchar, value varchar>\n"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question here about spaces? Checking if they are no longer permitted or not.

"AS 'key_value'\n"
"LANGUAGE python\n"
"USING LINK 'http://localhost:8816';"
[0x00]
Expand Down
2 changes: 1 addition & 1 deletion incubator/binding-risingwave/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
</licenses>

<properties>
<jacoco.coverage.ratio>0.86</jacoco.coverage.ratio>
<jacoco.coverage.ratio>0.92</jacoco.coverage.ratio>
<jacoco.missed.count>0</jacoco.missed.count>
</properties>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,22 +62,22 @@ public String generate(
{
String functionName = functionInfo.name();
String asFunction = functionInfo.asFunction();
List<FunctionArgument> arguments = functionInfo.arguments();
List<FunctionArgument> tables = functionInfo.tables();

fieldBuilder.setLength(0);

functionInfo.arguments()
arguments
.forEach(arg -> fieldBuilder.append(
arg.name() != null
? "%s %s, ".formatted(arg.name(), arg.type())
: "%s, ".formatted(arg.type())));

if (!functionInfo.arguments().isEmpty())
if (!arguments.isEmpty())
{
fieldBuilder.delete(fieldBuilder.length() - 2, fieldBuilder.length());
}
String arguments = fieldBuilder.toString();

String funcArguments = fieldBuilder.toString();

String language = functionInfo.language() != null ? functionInfo.language() : "java";
String server = "python".equalsIgnoreCase(language) ? pythonServer : javaServer;
Expand All @@ -97,6 +97,6 @@ public String generate(
returnType = fieldBuilder.toString();
}

return sqlFormat.formatted(functionName, arguments, returnType, asFunction, language, server);
return sqlFormat.formatted(functionName, funcArguments, returnType, asFunction, language, server);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;

import io.aklivity.zilla.runtime.binding.pgsql.parser.module.StreamInfo;
import io.aklivity.zilla.runtime.binding.pgsql.parser.module.TableInfo;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@

public final class RisingwaveProxyFactory implements RisingwaveStreamFactory
{
private static final String SPLIT_STATEMENTS = "(?<=;)(?!\\s*\\x00)";
private static final int END_OF_FIELD = 0x00;

private static final int FLAGS_INIT = 0x02;
Expand All @@ -83,6 +82,8 @@ public final class RisingwaveProxyFactory implements RisingwaveStreamFactory
private static final Consumer<OctetsFW.Builder> EMPTY_EXTENSION = ex -> {};

private final PgsqlParser parser = new PgsqlParser();
private final List<String> statements = new ArrayList<>();
private final StringBuilder currentStatement = new StringBuilder();

private final BeginFW beginRO = new BeginFW();
private final DataFW dataRO = new DataFW();
Expand Down Expand Up @@ -725,16 +726,16 @@ private void doParseQuery(
final MutableDirectBuffer parserBuffer = bufferPool.buffer(parserSlot);

String sql = parserBuffer.getStringWithoutLengthAscii(0, parserSlotOffset);
String[] statements = sql.split(SPLIT_STATEMENTS);

int length = statements.length;
if (length > 0)
{
String statement = statements[0];
String command = parser.parseCommand(statement);
final PgsqlTransform transform = clientTransforms.get(RisingwaveCommandType.valueOf(command.getBytes()));
transform.transform(this, traceId, authorizationId, statement);
}
splitStatements(sql)
.stream()
.findFirst()
.ifPresent(s ->
{
String statement = s;
String command = parser.parseCommand(statement);
final PgsqlTransform transform = clientTransforms.get(RisingwaveCommandType.valueOf(command.getBytes()));
transform.transform(this, traceId, authorizationId, statement);
});
}
}

Expand Down Expand Up @@ -1792,6 +1793,55 @@ private void proxyDataCommand(
server.doAppData(routedId, traceId, authorization, flags, buffer, offset, limit, extension);
}

public List<String> splitStatements(
String sql)
{
statements.clear();
currentStatement.setLength(0);

boolean inDollarQuotes = false;
int length = sql.length();

for (int i = 0; i < length; i++)
{
char c = sql.charAt(i);
currentStatement.append(c);

if (c == '$' && i + 1 < length && sql.charAt(i + 1) == '$')
{
inDollarQuotes = !inDollarQuotes;
currentStatement.append(sql.charAt(++i));
}
else if (c == ';' && !inDollarQuotes)
{
int j = i + 1;
while (j < length && Character.isWhitespace(sql.charAt(j)))
{
currentStatement.append(sql.charAt(j));
j++;
}
if (j < length && sql.charAt(j) == '\0')
{
currentStatement.append(sql.charAt(j));
i = j; // Move the main loop index forward
}
else
{
statements.add(currentStatement.toString());
currentStatement.setLength(0);
}
}
}

if (!currentStatement.isEmpty())
{
statements.add(currentStatement.toString());
}

return statements;
}


@FunctionalInterface
private interface PgsqlTransform
{
Expand Down
Loading