Skip to content

Commit 2063c43

Browse files
committed
[KYUUBI #3222][FOLLOWUP] Introdude JdbcUtils to simplify code
1 parent 97b14f8 commit 2063c43

File tree

4 files changed

+251
-199
lines changed

4 files changed

+251
-199
lines changed

kyuubi-common/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@
119119
<artifactId>HikariCP</artifactId>
120120
</dependency>
121121

122+
<dependency>
123+
<groupId>com.jakewharton.fliptables</groupId>
124+
<artifactId>fliptables</artifactId>
125+
</dependency>
126+
122127
<dependency>
123128
<groupId>org.apache.hadoop</groupId>
124129
<artifactId>hadoop-minikdc</artifactId>

kyuubi-common/src/main/scala/org/apache/kyuubi/service/authentication/JdbcAuthenticationProviderImpl.scala

Lines changed: 61 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -17,33 +17,48 @@
1717

1818
package org.apache.kyuubi.service.authentication
1919

20-
import java.sql.{Connection, PreparedStatement, Statement}
2120
import java.util.Properties
2221
import javax.security.sasl.AuthenticationException
22+
import javax.sql.DataSource
2323

2424
import com.zaxxer.hikari.{HikariConfig, HikariDataSource}
2525
import org.apache.commons.lang3.StringUtils
2626

2727
import org.apache.kyuubi.Logging
2828
import org.apache.kyuubi.config.KyuubiConf
2929
import org.apache.kyuubi.config.KyuubiConf._
30+
import org.apache.kyuubi.util.JdbcUtils
3031

3132
class JdbcAuthenticationProviderImpl(conf: KyuubiConf) extends PasswdAuthenticationProvider
3233
with Logging {
3334

34-
private val driverClass = conf.get(AUTHENTICATION_JDBC_DRIVER)
35-
private val jdbcUrl = conf.get(AUTHENTICATION_JDBC_URL)
36-
private val jdbcUsername = conf.get(AUTHENTICATION_JDBC_USERNAME)
37-
private val jdbcUserPassword = conf.get(AUTHENTICATION_JDBC_PASSWORD)
38-
private val authQuerySql = conf.get(AUTHENTICATION_JDBC_QUERY)
39-
4035
private val SQL_PLACEHOLDER_REGEX = """\$\{.+?}""".r
4136
private val USERNAME_SQL_PLACEHOLDER = "${username}"
4237
private val PASSWORD_SQL_PLACEHOLDER = "${password}"
4338

39+
private val driverClass = conf.get(AUTHENTICATION_JDBC_DRIVER)
40+
private val jdbcUrl = conf.get(AUTHENTICATION_JDBC_URL)
41+
private val username = conf.get(AUTHENTICATION_JDBC_USERNAME)
42+
private val password = conf.get(AUTHENTICATION_JDBC_PASSWORD)
43+
private val authQuery = conf.get(AUTHENTICATION_JDBC_QUERY)
44+
45+
private val redactedPasswd = password match {
46+
case Some(value) => s"${"*" * value.length}(length: ${value.length})"
47+
case None => "(empty)"
48+
}
49+
4450
checkJdbcConfigs()
4551

46-
private[kyuubi] val hikariDataSource = getHikariDataSource
52+
implicit private[kyuubi] val ds: DataSource = {
53+
val datasourceProperties = new Properties()
54+
val hikariConfig = new HikariConfig(datasourceProperties)
55+
hikariConfig.setDriverClassName(driverClass.orNull)
56+
hikariConfig.setJdbcUrl(jdbcUrl.orNull)
57+
hikariConfig.setUsername(username.orNull)
58+
hikariConfig.setPassword(password.orNull)
59+
hikariConfig.setPoolName("jdbc-auth-pool")
60+
new HikariDataSource(hikariConfig)
61+
}
4762

4863
/**
4964
* The authenticate method is called by the Kyuubi Server authentication layer
@@ -62,37 +77,27 @@ class JdbcAuthenticationProviderImpl(conf: KyuubiConf) extends PasswdAuthenticat
6277
s" or contains blank space")
6378
}
6479

65-
if (StringUtils.isBlank(password)) {
66-
throw new AuthenticationException(s"Error validating, password is null" +
67-
s" or contains blank space")
68-
}
69-
70-
var connection: Connection = null
71-
var queryStatement: PreparedStatement = null
72-
7380
try {
74-
connection = hikariDataSource.getConnection
75-
76-
queryStatement = getAndPrepareQueryStatement(connection, user, password)
77-
78-
val resultSet = queryStatement.executeQuery()
79-
80-
if (resultSet == null || !resultSet.next()) {
81-
// auth failed
82-
throw new AuthenticationException(s"Password does not match or no such user. user:" +
83-
s" $user , password length: ${password.length}")
81+
debug(s"prepared auth query: $preparedQuery")
82+
JdbcUtils.executeQuery(preparedQuery) { stmt =>
83+
stmt.setMaxRows(1) // minimum result size required for authentication
84+
queryPlaceholders.zipWithIndex.foreach {
85+
case (USERNAME_SQL_PLACEHOLDER, i) => stmt.setString(i + 1, user)
86+
case (PASSWORD_SQL_PLACEHOLDER, i) => stmt.setString(i + 1, password)
87+
case (p, _) => throw new IllegalArgumentException(
88+
s"Unrecognized placeholder in Query SQL: $p")
89+
}
90+
} { resultSet =>
91+
if (resultSet == null || !resultSet.next()) {
92+
throw new AuthenticationException("Password does not match or no such user. " +
93+
s"user: $user, password: $redactedPasswd")
94+
}
8495
}
85-
86-
// auth passed
87-
8896
} catch {
89-
case e: AuthenticationException =>
90-
throw e
91-
case e: Exception =>
92-
error("Cannot get user info", e);
93-
throw e
94-
} finally {
95-
closeDbConnection(connection, queryStatement)
97+
case rethrow: AuthenticationException =>
98+
throw rethrow
99+
case rethrow: Exception =>
100+
throw new AuthenticationException("Cannot get user info", rethrow)
96101
}
97102
}
98103

@@ -101,104 +106,31 @@ class JdbcAuthenticationProviderImpl(conf: KyuubiConf) extends PasswdAuthenticat
101106

102107
debug(configLog("Driver Class", driverClass.orNull))
103108
debug(configLog("JDBC URL", jdbcUrl.orNull))
104-
debug(configLog("Database username", jdbcUsername.orNull))
105-
debug(configLog("Database password length", jdbcUserPassword.getOrElse("").length.toString))
106-
debug(configLog("Query SQL", authQuerySql.orNull))
109+
debug(configLog("Database username", username.orNull))
110+
debug(configLog("Database password", redactedPasswd))
111+
debug(configLog("Query SQL", authQuery.orNull))
107112

108113
// Check if JDBC parameters valid
109-
if (driverClass.isEmpty) {
110-
throw new IllegalArgumentException("JDBC driver class is not configured.")
111-
}
112-
113-
if (jdbcUrl.isEmpty) {
114-
throw new IllegalArgumentException("JDBC url is not configured")
115-
}
116-
117-
if (jdbcUsername.isEmpty || jdbcUserPassword.isEmpty) {
118-
throw new IllegalArgumentException("JDBC username or password is not configured")
114+
require(driverClass.nonEmpty, "JDBC driver class is not configured.")
115+
require(jdbcUrl.nonEmpty, "JDBC url is not configured.")
116+
require(username.nonEmpty, "JDBC username is not configured")
117+
// allow empty password
118+
require(authQuery.nonEmpty, "Query SQL is not configured")
119+
120+
val query = authQuery.get.trim.toLowerCase
121+
// allow simple select query sql only, complex query like CTE is not allowed
122+
require(query.startsWith("select"), "Query SQL must start with 'SELECT'")
123+
if (!query.contains("where")) {
124+
warn("Query SQL does not contains 'WHERE' keyword")
119125
}
120-
121-
// Check Query SQL
122-
if (authQuerySql.isEmpty) {
123-
throw new IllegalArgumentException("Query SQL is not configured")
124-
}
125-
val querySqlInLowerCase = authQuerySql.get.trim.toLowerCase
126-
if (!querySqlInLowerCase.startsWith("select")) { // allow select query sql only
127-
throw new IllegalArgumentException("Query SQL must start with \"SELECT\"");
128-
}
129-
if (!querySqlInLowerCase.contains("where")) {
130-
warn("Query SQL does not contains \"WHERE\" keyword");
131-
}
132-
if (!querySqlInLowerCase.contains("${username}")) {
133-
warn("Query SQL does not contains \"${username}\" placeholder");
134-
}
135-
}
136-
137-
private def getPlaceholderList(sql: String): List[String] = {
138-
SQL_PLACEHOLDER_REGEX.findAllMatchIn(sql)
139-
.map(m => m.matched)
140-
.toList
141-
}
142-
143-
private def getAndPrepareQueryStatement(
144-
connection: Connection,
145-
user: String,
146-
password: String): PreparedStatement = {
147-
148-
val preparedSql: String = {
149-
SQL_PLACEHOLDER_REGEX.replaceAllIn(authQuerySql.get, "?")
150-
}
151-
debug(s"prepared auth query sql: $preparedSql")
152-
153-
val stmt = connection.prepareStatement(preparedSql)
154-
stmt.setMaxRows(1) // minimum result size required for authentication
155-
156-
// Extract placeholder list and fill parameters to placeholders
157-
val placeholderList: List[String] = getPlaceholderList(authQuerySql.get)
158-
for (i <- placeholderList.indices) {
159-
val param = placeholderList(i) match {
160-
case USERNAME_SQL_PLACEHOLDER => user
161-
case PASSWORD_SQL_PLACEHOLDER => password
162-
case otherPlaceholder =>
163-
throw new IllegalArgumentException(
164-
s"Unrecognized Placeholder In Query SQL: $otherPlaceholder")
165-
}
166-
167-
stmt.setString(i + 1, param)
168-
}
169-
170-
stmt
171-
}
172-
173-
private def closeDbConnection(connection: Connection, statement: Statement): Unit = {
174-
if (statement != null && !statement.isClosed) {
175-
try {
176-
statement.close()
177-
} catch {
178-
case e: Exception =>
179-
error("Cannot close PreparedStatement to auth database ", e)
180-
}
181-
}
182-
183-
if (connection != null && !connection.isClosed) {
184-
try {
185-
connection.close()
186-
} catch {
187-
case e: Exception =>
188-
error("Cannot close connection to auth database ", e)
189-
}
126+
if (!query.contains("${username}")) {
127+
warn("Query SQL does not contains '${username}' placeholder")
190128
}
191129
}
192130

193-
private def getHikariDataSource: HikariDataSource = {
194-
val datasourceProperties = new Properties()
195-
val hikariConfig = new HikariConfig(datasourceProperties)
196-
hikariConfig.setDriverClassName(driverClass.orNull)
197-
hikariConfig.setJdbcUrl(jdbcUrl.orNull)
198-
hikariConfig.setUsername(jdbcUsername.orNull)
199-
hikariConfig.setPassword(jdbcUserPassword.orNull)
200-
hikariConfig.setPoolName("jdbc-auth-pool")
131+
private def preparedQuery: String =
132+
SQL_PLACEHOLDER_REGEX.replaceAllIn(authQuery.get, "?")
201133

202-
new HikariDataSource(hikariConfig)
203-
}
134+
private def queryPlaceholders: Iterator[String] =
135+
SQL_PLACEHOLDER_REGEX.findAllMatchIn(authQuery.get).map(_.matched)
204136
}
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.kyuubi.util
19+
20+
import java.sql.{Connection, PreparedStatement, ResultSet, ResultSetMetaData}
21+
import java.util
22+
import javax.sql.DataSource
23+
24+
import scala.util.control.NonFatal
25+
26+
import com.jakewharton.fliptables.FlipTable
27+
28+
import org.apache.kyuubi.Logging
29+
30+
object JdbcUtils extends Logging {
31+
32+
def close(c: AutoCloseable): Unit = {
33+
if (c != null) {
34+
try {
35+
c.close()
36+
} catch {
37+
case NonFatal(t) => warn(s"Error on closing", t)
38+
}
39+
}
40+
}
41+
42+
def withCloseable[R, C <: AutoCloseable](c: C)(block: C => R): R = {
43+
try {
44+
block(c)
45+
} finally {
46+
close(c)
47+
}
48+
}
49+
50+
def withConnection[R](block: Connection => R)(implicit ds: DataSource): R = {
51+
withCloseable(ds.getConnection)(block)
52+
}
53+
54+
def execute(
55+
sqlTemplate: String)(
56+
setParameters: PreparedStatement => Unit = _ => {})(
57+
implicit ds: DataSource): Boolean = withConnection { conn =>
58+
withCloseable(conn.prepareStatement(sqlTemplate)) { pStmt =>
59+
setParameters(pStmt)
60+
pStmt.execute()
61+
}
62+
}
63+
64+
def executeUpdate(
65+
sqlTemplate: String)(
66+
setParameters: PreparedStatement => Unit = _ => {})(
67+
implicit ds: DataSource): Int = withConnection { conn =>
68+
withCloseable(conn.prepareStatement(sqlTemplate)) { pStmt =>
69+
setParameters(pStmt)
70+
pStmt.executeUpdate()
71+
}
72+
}
73+
74+
def executeQuery[R](
75+
sqlTemplate: String)(
76+
setParameters: PreparedStatement => Unit = _ => {})(
77+
processResultSet: ResultSet => R)(
78+
implicit ds: DataSource): R = withConnection { conn =>
79+
withCloseable(conn.prepareStatement(sqlTemplate)) { pStmt =>
80+
setParameters(pStmt)
81+
withCloseable(pStmt.executeQuery()) { rs =>
82+
processResultSet(rs)
83+
}
84+
}
85+
}
86+
87+
def executeQueryWithRowMapper[R](
88+
sqlTemplate: String)(
89+
setParameters: PreparedStatement => Unit = _ => {})(
90+
rowMapper: ResultSet => R)(
91+
implicit ds: DataSource): Seq[R] = withConnection { conn =>
92+
withCloseable(conn.prepareStatement(sqlTemplate)) { pStmt =>
93+
setParameters(pStmt)
94+
withCloseable(pStmt.executeQuery()) { rs =>
95+
val builder = Seq.newBuilder[R]
96+
while (rs.next()) builder += rowMapper(rs)
97+
builder.result
98+
}
99+
}
100+
}
101+
102+
def queryAndRenderResultSet(sql: String)(implicit ds: DataSource): String =
103+
withConnection { conn =>
104+
withCloseable(conn.prepareStatement(sql).executeQuery()) { rs =>
105+
renderResultSet(rs)
106+
}
107+
}
108+
109+
private def renderResultSet(resultSet: ResultSet): String = {
110+
if (resultSet == null) throw new NullPointerException("resultSet == null")
111+
val headers: util.List[String] = new util.ArrayList[String]
112+
val resultSetMetaData: ResultSetMetaData = resultSet.getMetaData
113+
val columnCount: Int = resultSetMetaData.getColumnCount
114+
for (column <- 0 until columnCount) {
115+
headers.add(resultSetMetaData.getColumnName(column + 1))
116+
}
117+
val data: util.List[Array[String]] = new util.ArrayList[Array[String]]
118+
while ({
119+
resultSet.next
120+
}) {
121+
val rowData: Array[String] = new Array[String](columnCount)
122+
for (column <- 0 until columnCount) {
123+
rowData(column) = resultSet.getString(column + 1)
124+
}
125+
data.add(rowData)
126+
}
127+
val headerArray: Array[String] = headers.toArray(new Array[String](headers.size))
128+
val dataArray: Array[Array[String]] = data.toArray(new Array[Array[String]](data.size))
129+
FlipTable.of(headerArray, dataArray)
130+
}
131+
}

0 commit comments

Comments
 (0)