Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
cleanup
  • Loading branch information
gaborbarna committed Feb 13, 2018
commit c230e9746647d48e34ffcafb3515a9def3835718
30 changes: 18 additions & 12 deletions core/src/main/scala/ste/selector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ sealed trait DataTypeSelector[A] {
}

object DataTypeSelector {
type Prefixes = Seq[Prefix]
type Prefixes = List[Prefix]
type Select = (DataFrame, Option[Prefixes]) => DataFrame

def pure[A](s: Select): DataTypeSelector[A] =
Expand Down Expand Up @@ -120,11 +120,11 @@ trait SelectorImplicits {
tSelector: AnnotatedStructTypeSelector[T]
): AnnotatedStructTypeSelector[FieldType[K, H] :: T] = AnnotatedStructTypeSelector.pure { (df, parentPrefixes, flatten) =>
val fieldName = witness.value.name
val prefixes = parentPrefixes.map(_.map(_.addSuffix(fieldName))).getOrElse(Seq(Prefix(fieldName)))
val prefixes = parentPrefixes.map(_.map(_.addSuffix(fieldName))).getOrElse(List(Prefix(fieldName)))
val childPrefixes = getChildPrefixes(prefixes, flatten.head)
val dfHead = hSelector.value.select(df, Some(childPrefixes))
val dfNested = flatten.head.map { fl =>
val fields = dfHead.schema.fields.map(f => Prefix(f.name)).toSeq
val fields = dfHead.schema.fields.map(f => Prefix(f.name)).toList
val restCols = fields.filter(f => !childPrefixes.exists(_.isParentOf(f))).map(f => dfHead(f.quotedString))
val structs = childPrefixes.map { p =>
val cols = fields.filter(_.isChildrenOf(p)).map(f => dfHead(f.quotedString).as(f.getSuffix))
Expand All @@ -137,14 +137,14 @@ trait SelectorImplicits {
tSelector.select(dfNested, parentPrefixes, flatten.tail)
}

private def getChildPrefixes(prefixes: Seq[Prefix], flatten: Option[Flatten]) =
flatten.map(_ match {
case Flatten(times, _) if times > 1 => (0 until times).flatMap(i => prefixes.map(_.addSuffix(i)))
case Flatten(_, keys) if keys.nonEmpty => keys.flatMap(k => prefixes.map(_.addSuffix(k)))
private def getChildPrefixes(prefixes: List[Prefix], flatten: Option[Flatten]): List[Prefix] =
flatten.map {
case Flatten(times, _) if times > 1 => (0 until times).flatMap(i => prefixes.map(_.addSuffix(i))).toList
case Flatten(_, keys) if keys.nonEmpty => keys.flatMap(k => prefixes.map(_.addSuffix(k))).toList
case Flatten(_, _) => prefixes
}).getOrElse(prefixes)
}.getOrElse(prefixes)

private def getNestedColumns(prefixes: Seq[Prefix], df: DataFrame, flatten: Flatten): Map[Prefix, Column] =
private def getNestedColumns(prefixes: List[Prefix], df: DataFrame, flatten: Flatten): Map[Prefix, Column] =
prefixes.groupBy(_.getParent).map { case (prefix, groupedPrefixes) =>
val colName = prefix.toString
val cols = groupedPrefixes.map(p => df(p.quotedString))
Expand All @@ -155,16 +155,16 @@ trait SelectorImplicits {
}
}(breakOut)

private def orderedSelect(df: DataFrame, nestedCols: Map[Prefix, Column], fields: Seq[Prefix]) = {
private def orderedSelect(df: DataFrame, nestedCols: Map[Prefix, Column], fields: List[Prefix]): DataFrame = {
@tailrec
def loop(nestedCols: Map[Prefix, Column], fields: Seq[Prefix], cols: Seq[Column]): Seq[Column] = fields match {
def loop(nestedCols: Map[Prefix, Column], fields: List[Prefix], cols: List[Column]): List[Column] = fields match {
case Nil => cols.reverse
case hd +: tail => nestedCols.find { case (p, _) => p.isParentOf(hd) } match {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since these are `List`s now, you can match on `hd :: tail` instead of `hd +: tail` here; the same applies to the `+:` usages below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, but unfortunately shapeless's `::` (the HList cons used in this file) shadows List's `::` here, so I kept `+:`.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, too bad :(

case Some((p, c)) => loop(nestedCols - p, fields.dropWhile(_.isChildrenOf(p)), c +: cols)
case None => loop(nestedCols, tail, df(hd.quotedString) +: cols)
}
}
val cols = loop(nestedCols, fields, Seq[Column]())
val cols = loop(nestedCols, fields, List[Column]())
df.select(cols :_*)
}

Expand Down Expand Up @@ -209,6 +209,12 @@ trait SelectorImplicits {
): DataTypeSelector[collection.Map[K, V]] = DataTypeSelector.pure { (df, prefixes) =>
s.select(df, prefixes)
}

/** Selector instance for `Map[K, V]`.
  *
  * The selection is forwarded unchanged to the selector of the value type `V`;
  * the key type `K` takes no part in it.
  *
  * @param s selector for the map's value type
  * @return a selector that applies `s.select` to the same DataFrame and prefixes
  */
implicit def immutableMapSelector[K, V](
  implicit s: DataTypeSelector[V]
): DataTypeSelector[Map[K, V]] = {
  // Name the delegating function first, then lift it into a selector.
  val forwardToValues: DataTypeSelector.Select =
    (frame, parentPrefixes) => s.select(frame, parentPrefixes)
  DataTypeSelector.pure(forwardToValues)
}
}

object DFUtils {
Expand Down
1 change: 0 additions & 1 deletion core/src/test/scala/ste/StructTypeSelectorSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ package ste
import org.apache.spark.sql.types._
import org.apache.spark.sql.SparkSession
import org.scalatest.{ FlatSpec, Matchers }
import scala.collection
import ste._
import StructTypeEncoder._
import StructTypeSelector._
Expand Down