diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
index 642d71b18c9e..f0b14226adaf 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
@@ -29,10 +29,12 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("")
private val plusOrMinus = 2
def render(request: HttpServletRequest): Seq[Node] = {
- val requestedPage = Option(request.getParameter("page")).getOrElse("1").toInt
+ val requestedPage = Option(UIUtils.stripXSS(request.getParameter("page"))).getOrElse("1").toInt
val requestedFirst = (requestedPage - 1) * pageSize
+
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
val requestedIncomplete =
- Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean
+ Option(UIUtils.stripXSS(request.getParameter("showIncomplete"))).getOrElse("false").toBoolean
val allApps = parent.getApplicationList()
.filter(_.attempts.head.completed != requestedIncomplete)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index f405aa2bdc8b..c89fdc84ac95 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -33,7 +33,8 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
/** Executor details for a particular application */
def render(request: HttpServletRequest): Seq[Node] = {
- val appId = request.getParameter("appId")
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val appId = UIUtils.stripXSS(request.getParameter("appId"))
val state = master.askWithRetry[MasterStateResponse](RequestMasterState)
val app = state.activeApps.find(_.id == appId).getOrElse({
state.completedApps.find(_.id == appId).getOrElse(null)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
index ee539dd1f511..c9d8754dacd4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -57,8 +57,10 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = {
if (parent.killEnabled &&
parent.master.securityMgr.checkModifyPermissions(request.getRemoteUser)) {
- val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean
- val id = Option(request.getParameter("id"))
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val killFlag =
+ Option(UIUtils.stripXSS(request.getParameter("terminate"))).getOrElse("false").toBoolean
+ val id = Option(UIUtils.stripXSS(request.getParameter("id")))
if (id.isDefined && killFlag) {
action(id.get)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
index bc67fd460d9a..775f27fcaa87 100644
--- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
@@ -30,7 +30,8 @@ import org.apache.spark.ui.{UIUtils, WebUIPage}
private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") {
override def render(request: HttpServletRequest): Seq[Node] = {
- val driverId = request.getParameter("id")
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val driverId = UIUtils.stripXSS(request.getParameter("id"))
require(driverId != null && driverId.nonEmpty, "Missing id parameter")
val state = parent.scheduler.getDriverState(driverId)
@@ -96,22 +97,22 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")
Mesos Slave ID |
{state.slaveId.getValue} |
-
- | Mesos Task ID |
- {state.taskId.getValue} |
-
-
- | Launch Time |
- {state.startDate} |
-
-
- | Finish Time |
- {state.finishDate.map(_.toString).getOrElse("")} |
-
-
- | Last Task Status |
- {state.mesosTaskStatus.map(_.toString).getOrElse("")} |
-
+
+ | Mesos Task ID |
+ {state.taskId.getValue} |
+
+
+ | Launch Time |
+ {state.startDate} |
+
+
+ | Finish Time |
+ {state.finishDate.map(_.toString).getOrElse("")} |
+
+
+ | Last Task Status |
+ {state.mesosTaskStatus.map(_.toString).getOrElse("")} |
+
}.getOrElse(Seq[Node]())
}
@@ -127,39 +128,39 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")
| Main class | {command.mainClass} |
-
- | Arguments | {command.arguments.mkString(" ")} |
-
-
- | Class path entries | {command.classPathEntries.mkString(" ")} |
-
-
- | Java options | {command.javaOpts.mkString((" "))} |
-
-
- | Library path entries | {command.libraryPathEntries.mkString((" "))} |
-
+
+ | Arguments | {command.arguments.mkString(" ")} |
+
+
+ | Class path entries | {command.classPathEntries.mkString(" ")} |
+
+
+ | Java options | {command.javaOpts.mkString((" "))} |
+
+
+ | Library path entries | {command.libraryPathEntries.mkString((" "))} |
+
}
private def driverRow(driver: MesosDriverDescription): Seq[Node] = {
| Name | {driver.name} |
-
- | Id | {driver.submissionId} |
-
-
- | Cores | {driver.cores} |
-
-
- | Memory | {driver.mem} |
-
-
- | Submitted | {driver.submissionDate} |
-
-
- | Supervise | {driver.supervise} |
-
+
+ | Id | {driver.submissionId} |
+
+
+ | Cores | {driver.cores} |
+
+
+ | Memory | {driver.mem} |
+
+
+ | Submitted | {driver.submissionDate} |
+
+
+ | Supervise | {driver.supervise} |
+
}
private def retryRow(retryState: Option[MesosClusterRetryState]): Seq[Node] = {
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
index 15f88e79cbf1..b74095b2a2a6 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
@@ -33,15 +33,18 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
private val workDir = new File(parent.workDir.toURI.normalize().getPath)
private val supportedLogTypes = Set("stderr", "stdout")
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
def renderLog(request: HttpServletRequest): String = {
val defaultBytes = 100 * 1024
- val appId = Option(request.getParameter("appId"))
- val executorId = Option(request.getParameter("executorId"))
- val driverId = Option(request.getParameter("driverId"))
- val logType = request.getParameter("logType")
- val offset = Option(request.getParameter("offset")).map(_.toLong)
- val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
+ val appId = Option(UIUtils.stripXSS(request.getParameter("appId")))
+ val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId")))
+ val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId")))
+ val logType = UIUtils.stripXSS(request.getParameter("logType"))
+ val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong)
+ val byteLength =
+ Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt)
+ .getOrElse(defaultBytes)
val logDir = (appId, executorId, driverId) match {
case (Some(a), Some(e), None) =>
@@ -57,14 +60,17 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
pre + logText
}
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
def render(request: HttpServletRequest): Seq[Node] = {
val defaultBytes = 100 * 1024
- val appId = Option(request.getParameter("appId"))
- val executorId = Option(request.getParameter("executorId"))
- val driverId = Option(request.getParameter("driverId"))
- val logType = request.getParameter("logType")
- val offset = Option(request.getParameter("offset")).map(_.toLong)
- val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
+ val appId = Option(UIUtils.stripXSS(request.getParameter("appId")))
+ val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId")))
+ val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId")))
+ val logType = UIUtils.stripXSS(request.getParameter("logType"))
+ val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong)
+ val byteLength =
+ Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt)
+ .getOrElse(defaultBytes)
val (logDir, params, pageName) = (appId, executorId, driverId) match {
case (Some(a), Some(e), None) =>
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index 1949c4b3cbf4..f0a8d436a7c4 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -25,6 +25,9 @@ import scala.util.control.NonFatal
import scala.xml._
import scala.xml.transform.{RewriteRule, RuleTransformer}
+import org.apache.spark.Logging
+import org.apache.commons.lang3.StringEscapeUtils
+
import org.apache.spark.Logging
import org.apache.spark.ui.scope.RDDOperationGraph
@@ -34,6 +37,8 @@ private[spark] object UIUtils extends Logging {
val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped"
val TABLE_CLASS_STRIPED_SORTABLE = TABLE_CLASS_STRIPED + " sortable"
+ private val NEWLINE_AND_SINGLE_QUOTE_REGEX = raw"(?i)(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)".r
+
// SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use.
private val dateFormat = new ThreadLocal[SimpleDateFormat]() {
override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
@@ -467,4 +472,22 @@ private[spark] object UIUtils extends Logging {
}
param
}
+
+ /**
+ * Remove suspicious characters of user input to prevent Cross-Site scripting (XSS) attacks
+ *
+ * For more information about XSS testing:
+ * https://www.owasp.org/index.php/XSS_Filter_Evasion_Cheat_Sheet and
+ * https://www.owasp.org/index.php/Testing_for_Reflected_Cross_site_scripting_(OTG-INPVAL-001)
+ */
+ def stripXSS(requestParameter: String): String = {
+ if (requestParameter == null) {
+ null
+ } else {
+ // Remove new lines and single quotes, followed by escaping HTML version 4.0
+ StringEscapeUtils.escapeHtml4(
+ NEWLINE_AND_SINGLE_QUOTE_REGEX.replaceAllIn(requestParameter, ""))
+ }
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
index 32980544347a..7604c14137aa 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
@@ -28,8 +28,10 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage
private val sc = parent.sc
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
def render(request: HttpServletRequest): Seq[Node] = {
- val executorId = Option(request.getParameter("executorId")).map { executorId =>
+ val executorId =
+ Option(UIUtils.stripXSS(request.getParameter("executorId"))).map { executorId =>
UIUtils.decodeURLParameter(executorId)
}.getOrElse {
throw new IllegalArgumentException(s"Missing executorId parameter")
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
index 8c6a6681eabb..4f41cb7d3dcc 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
@@ -188,7 +188,8 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") {
val listener = parent.jobProgresslistener
listener.synchronized {
- val parameterId = request.getParameter("id")
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")
val jobId = parameterId.toInt
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
index 77ca60b000a9..b159c5c96d9e 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
@@ -16,7 +16,6 @@
*/
package org.apache.spark.ui.jobs
-
import org.apache.spark.scheduler.SchedulingMode
import org.apache.spark.ui.{SparkUI, SparkUITab}
@@ -33,4 +32,5 @@ private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") {
attachPage(new AllJobsPage(this))
attachPage(new JobPage(this))
+
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
index 778272a6da1e..b293f7d7a301 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
@@ -31,7 +31,8 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") {
def render(request: HttpServletRequest): Seq[Node] = {
listener.synchronized {
- val poolName = Option(request.getParameter("poolname")).map { poolname =>
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val poolName = Option(UIUtils.stripXSS(request.getParameter("poolname"))).map { poolname =>
UIUtils.decodeURLParameter(poolname)
}.getOrElse {
throw new IllegalArgumentException(s"Missing poolname parameter")
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 5183c80ab452..8b367523e01e 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -87,16 +87,17 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
def render(request: HttpServletRequest): Seq[Node] = {
progressListener.synchronized {
- val parameterId = request.getParameter("id")
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")
- val parameterAttempt = request.getParameter("attempt")
+ val parameterAttempt = UIUtils.stripXSS(request.getParameter("attempt"))
require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter")
- val parameterTaskPage = request.getParameter("task.page")
- val parameterTaskSortColumn = request.getParameter("task.sort")
- val parameterTaskSortDesc = request.getParameter("task.desc")
- val parameterTaskPageSize = request.getParameter("task.pageSize")
+ val parameterTaskPage = UIUtils.stripXSS(request.getParameter("task.page"))
+ val parameterTaskSortColumn = UIUtils.stripXSS(request.getParameter("task.sort"))
+ val parameterTaskSortDesc = UIUtils.stripXSS(request.getParameter("task.desc"))
+ val parameterTaskPageSize = UIUtils.stripXSS(request.getParameter("task.pageSize"))
val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1)
val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn =>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
index 5989f0035b27..c7eb34e0c7d4 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs
import javax.servlet.http.HttpServletRequest
import org.apache.spark.scheduler.SchedulingMode
-import org.apache.spark.ui.{SparkUI, SparkUITab}
+import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils}
/** Web UI showing progress status of all stages in the given SparkContext. */
private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages") {
@@ -38,8 +38,10 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages"
def handleKillRequest(request: HttpServletRequest): Unit = {
if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) {
- val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean
- val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt
+ val killFlag = Option(UIUtils.stripXSS(request.getParameter("terminate")))
+ .getOrElse("false").toBoolean
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val stageId = Option(UIUtils.stripXSS(request.getParameter("id"))).getOrElse("-1").toInt
if (stageId >= 0 && killFlag && progressListener.activeStages.contains(stageId)) {
sc.get.cancelStage(stageId)
}
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
index fd6cc3ed759b..0e870f7a6fcc 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
@@ -31,13 +31,14 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") {
private val listener = parent.listener
def render(request: HttpServletRequest): Seq[Node] = {
- val parameterId = request.getParameter("id")
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")
- val parameterBlockPage = request.getParameter("block.page")
- val parameterBlockSortColumn = request.getParameter("block.sort")
- val parameterBlockSortDesc = request.getParameter("block.desc")
- val parameterBlockPageSize = request.getParameter("block.pageSize")
+ val parameterBlockPage = UIUtils.stripXSS(request.getParameter("block.page"))
+ val parameterBlockSortColumn = UIUtils.stripXSS(request.getParameter("block.sort"))
+ val parameterBlockSortDesc = UIUtils.stripXSS(request.getParameter("block.desc"))
+ val parameterBlockPageSize = UIUtils.stripXSS(request.getParameter("block.pageSize"))
val blockPage = Option(parameterBlockPage).map(_.toInt).getOrElse(1)
val blockSortColumn = Option(parameterBlockSortColumn).getOrElse("Block Name")
diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
index bc8a5d494dbd..d983d0d96e95 100644
--- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
@@ -81,6 +81,45 @@ class UIUtilsSuite extends SparkFunSuite {
assert(decoded2 === decodeURLParameter(decoded2))
}
+ test("SPARK-20393: Prevent newline characters in parameters.") {
+ val encoding = "Encoding:base64%0d%0a%0d%0aPGh0bWw%2bjcmlwdD48L2h0bWw%2b"
+ val stripEncoding = "Encoding:base64PGh0bWw%2bjcmlwdD48L2h0bWw%2b"
+
+ assert(stripEncoding === stripXSS(encoding))
+ }
+
+ test("SPARK-20393: Prevent script from parameters running on page.") {
+ val scriptAlert = """>"'>