Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
flattened annotation support
  • Loading branch information
gaborbarna committed Feb 3, 2018
commit 36d5eb5cf47a1a1a610c7562f348a32fbb43495a
52 changes: 36 additions & 16 deletions core/src/main/scala/ste/encoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,25 @@ import scala.annotation.StaticAnnotation
import scala.collection.generic.IsTraversableOnce

final class Meta(val metadata: Metadata) extends StaticAnnotation
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we have made this a case class?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think we should do that, and remove the val keyword.

final case class Flattened(times: Int = 1, keys: Seq[String] = Seq()) extends StaticAnnotation

@annotation.implicitNotFound("""
Type ${A} does not have a DataTypeEncoder defined in the library.
You need to define one yourself.
""")
sealed trait DataTypeEncoder[A] {
def encode: DataType
def fields: Option[Seq[StructField]]
def nullable: Boolean
}

object DataTypeEncoder {
def apply[A](implicit enc: DataTypeEncoder[A]): DataTypeEncoder[A] = enc

def pure[A](dt: DataType, isNullable: Boolean = false): DataTypeEncoder[A] =
def pure[A](dt: DataType, f: Option[Seq[StructField]] = None, isNullable: Boolean = false): DataTypeEncoder[A] =
new DataTypeEncoder[A] {
def encode: DataType = dt
def fields: Option[Seq[StructField]] = f
def nullable: Boolean = isNullable
}
}
Expand All @@ -56,6 +59,7 @@ object DataTypeEncoder {
""")
sealed trait StructTypeEncoder[A] extends DataTypeEncoder[A] {
def encode: StructType
def fields: Option[Seq[StructField]]
def nullable: Boolean
}

Expand All @@ -65,6 +69,7 @@ object StructTypeEncoder extends MediumPriorityImplicits {
def pure[A](st: StructType, isNullable: Boolean = false): StructTypeEncoder[A] =
new StructTypeEncoder[A] {
def encode: StructType = st
def fields: Option[Seq[StructField]] = Some(st.fields.toSeq)
def nullable: Boolean = isNullable
}
}
Expand All @@ -80,7 +85,7 @@ sealed trait AnnotatedStructTypeEncoder[A] {
}

object AnnotatedStructTypeEncoder extends MediumPriorityImplicits {
type Encode = Seq[Metadata] => StructType
type Encode = (Seq[Metadata], Seq[Option[Flattened]]) => StructType

def pure[A](enc: Encode): AnnotatedStructTypeEncoder[A] =
new AnnotatedStructTypeEncoder[A] {
Expand All @@ -89,29 +94,44 @@ object AnnotatedStructTypeEncoder extends MediumPriorityImplicits {
}

trait LowPriorityImplicits {
implicit val hnilEncoder: AnnotatedStructTypeEncoder[HNil] = AnnotatedStructTypeEncoder.pure(_ => StructType(Nil))
implicit val hnilEncoder: AnnotatedStructTypeEncoder[HNil] =
AnnotatedStructTypeEncoder.pure((_, _) => StructType(Nil))
implicit def hconsEncoder[K <: Symbol, H, T <: HList](
implicit
witness: Witness.Aux[K],
hEncoder: Lazy[DataTypeEncoder[H]],
tEncoder: AnnotatedStructTypeEncoder[T]
): AnnotatedStructTypeEncoder[FieldType[K, H] :: T] = AnnotatedStructTypeEncoder.pure { metadata =>
): AnnotatedStructTypeEncoder[FieldType[K, H] :: T] = AnnotatedStructTypeEncoder.pure { (metadata, flattened) =>
val fieldName = witness.value.name
val head = hEncoder.value.encode
val nullable = hEncoder.value.nullable
val tail = tEncoder.encode(metadata.tail)
StructType(StructField(fieldName, head, nullable, metadata.head) +: tail.fields)
val fields = flattened.head.flatMap(f => hEncoder.value.fields.map(flattenFields(_, fieldName, f))).getOrElse(
Seq(StructField(fieldName, hEncoder.value.encode, hEncoder.value.nullable, metadata.head)))
val tail = tEncoder.encode(metadata.tail, flattened.tail)
StructType(fields ++ tail.fields)
}

implicit def recordEncoder[A, H <: HList, HA <: HList](
private def flattenFields(fields: Seq[StructField], prefix: String, flattened: Flattened) = flattened match {
case Flattened(times, _) if times > 1 =>
(0 until times).flatMap(i => fields.map(prefixStructField(_, s"$prefix$i")))
case Flattened(_, keys) if keys.nonEmpty =>
keys.flatMap(k => fields.map(prefixStructField(_, s"$k.$prefix")))
case Flattened(_, _) => fields.map(prefixStructField(_, prefix))
}

private def prefixStructField(f: StructField, prefix: String) =
f.copy(name = s"$prefix.${f.name}")

implicit def recordEncoder[A, H <: HList, HA <: HList, HF <: HList](
implicit
generic: LabelledGeneric.Aux[A, H],
annotations: Annotations.Aux[Meta, A, HA],
metaAnnotations: Annotations.Aux[Meta, A, HA],
flattenedAnnotations: Annotations.Aux[Flattened, A, HF],
hEncoder: Lazy[AnnotatedStructTypeEncoder[H]],
toList: ToList[HA, Option[Meta]]
metaToList: ToList[HA, Option[Meta]],
flattenedToList: ToList[HF, Option[Flattened]]
): StructTypeEncoder[A] = {
val metadata = annotations().toList[Option[Meta]].map(extractMetadata)
StructTypeEncoder.pure(hEncoder.value.encode(metadata))
val metadata = metaAnnotations().toList[Option[Meta]].map(extractMetadata)
val flattened = flattenedAnnotations().toList[Option[Flattened]]
StructTypeEncoder.pure(hEncoder.value.encode(metadata, flattened))
}

private val extractMetadata: Option[Meta] => Metadata =
Expand Down Expand Up @@ -153,16 +173,16 @@ trait MediumPriorityImplicits extends LowPriorityImplicits {
enc: DataTypeEncoder[A0],
is: IsTraversableOnce[C[A0]] { type A = A0 }
): DataTypeEncoder[C[A0]] =
DataTypeEncoder.pure(ArrayType(enc.encode))
DataTypeEncoder.pure(ArrayType(enc.encode), enc.fields)
implicit def mapEncoder[K, V](
implicit
kEnc: DataTypeEncoder[K],
vEnc: DataTypeEncoder[V]
): DataTypeEncoder[Map[K, V]] =
DataTypeEncoder.pure(MapType(kEnc.encode, vEnc.encode))
DataTypeEncoder.pure(MapType(kEnc.encode, vEnc.encode), vEnc.fields)
implicit def optionEncoder[V](
implicit
enc: DataTypeEncoder[V]
): DataTypeEncoder[Option[V]] =
DataTypeEncoder.pure(enc.encode, true)
DataTypeEncoder.pure(enc.encode, isNullable = true)
}
15 changes: 12 additions & 3 deletions core/src/test/scala/ste/StructTypeEncoderSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,18 @@ class StructTypeEncoderSpec extends FlatSpec with Matchers {
.build

case class Foo(a: String, @Meta(metadata) b: Int)
StructTypeEncoder[Foo].encode shouldBe StructType(
StructField("a", StringType, false) ::
StructField("b", IntegerType, false, metadata) :: Nil
case class Bar(@Flattened(2) a: Seq[Foo], @Flattened(1, Seq("x", "y")) b: Map[Symbol, Foo], @Flattened c: Foo)
StructTypeEncoder[Bar].encode shouldBe StructType(
StructField("a0.a", StringType, false) ::
StructField("a0.b", IntegerType, false, metadata) ::
StructField("a1.a", StringType, false) ::
StructField("a1.b", IntegerType, false, metadata) ::
StructField("x.b.a", StringType, false) ::
StructField("x.b.b", IntegerType, false, metadata) ::
StructField("y.b.a", StringType, false) ::
StructField("y.b.b", IntegerType, false, metadata) ::
StructField("c.a", StringType, false) ::
StructField("c.b", IntegerType, false, metadata) :: Nil
)
}
}