Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Use *AggMsgSerializer for all RDD[(VertexId, VD)] shuffles
  • Loading branch information
ankurdave committed Apr 21, 2014
commit cccff338630c814364c5cc2ec55c032f448d2ee4
11 changes: 6 additions & 5 deletions graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import org.apache.spark.util.collection.PrimitiveVector
import org.apache.spark.graphx.impl.MsgRDDFunctions
import org.apache.spark.graphx.impl.RoutingTableMessage
import org.apache.spark.graphx.impl.RoutingTableMessageRDDFunctions._
import org.apache.spark.graphx.impl.VertexRDDFunctions._
import org.apache.spark.graphx.impl.RoutingTablePartition
import org.apache.spark.graphx.impl.ShippableVertexPartition
import org.apache.spark.graphx.impl.VertexAttributeBlock
Expand Down Expand Up @@ -227,7 +228,7 @@ class VertexRDD[@specialized VD: ClassTag](
case _ =>
new VertexRDD[VD3](
partitionsRDD.zipPartitions(
other.partitionBy(this.partitioner.get), preservesPartitioning = true) {
other.copartitionWithVertices(this.partitioner.get), preservesPartitioning = true) {
(partIter, msgs) => partIter.map(_.leftJoin(msgs)(f))
}
)
Expand Down Expand Up @@ -271,7 +272,7 @@ class VertexRDD[@specialized VD: ClassTag](
case _ =>
new VertexRDD(
partitionsRDD.zipPartitions(
other.partitionBy(this.partitioner.get), preservesPartitioning = true) {
other.copartitionWithVertices(this.partitioner.get), preservesPartitioning = true) {
(partIter, msgs) => partIter.map(_.innerJoin(msgs)(f))
}
)
Expand All @@ -291,7 +292,7 @@ class VertexRDD[@specialized VD: ClassTag](
*/
def aggregateUsingIndex[VD2: ClassTag](
messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] = {
val shuffled = MsgRDDFunctions.partitionForAggregation(messages, this.partitioner.get)
val shuffled = messages.copartitionWithVertices(this.partitioner.get)
val parts = partitionsRDD.zipPartitions(shuffled, true) { (thisIter, msgIter) =>
thisIter.map(_.aggregateUsingIndex(msgIter, reduceFunc))
}
Expand Down Expand Up @@ -329,7 +330,7 @@ object VertexRDD {
def apply[VD: ClassTag](vertices: RDD[(VertexId, VD)]): VertexRDD[VD] = {
val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match {
case Some(p) => vertices
case None => vertices.partitionBy(new HashPartitioner(vertices.partitions.size))
case None => vertices.copartitionWithVertices(new HashPartitioner(vertices.partitions.size))
}
val vertexPartitions = vPartitioned.mapPartitions(
iter => Iterator(ShippableVertexPartition(iter)),
Expand All @@ -356,7 +357,7 @@ object VertexRDD {
): VertexRDD[VD] = {
val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match {
case Some(p) => vertices
case None => vertices.partitionBy(new HashPartitioner(vertices.partitions.size))
case None => vertices.copartitionWithVertices(new HashPartitioner(vertices.partitions.size))
}
val routingTables = createRoutingTables(edges, vPartitioned.partitioner.get)
val vertexPartitions = vPartitioned.zipPartitions(routingTables, preservesPartitioning = true) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ class MsgRDDFunctions[T: ClassTag](self: RDD[MessageToPartition[T]]) {

}


private[graphx]
object MsgRDDFunctions {
implicit def rdd2PartitionRDDFunctions[T: ClassTag](rdd: RDD[MessageToPartition[T]]) = {
Expand All @@ -98,18 +97,28 @@ object MsgRDDFunctions {
implicit def rdd2vertexMessageRDDFunctions[T: ClassTag](rdd: RDD[VertexBroadcastMsg[T]]) = {
new VertexBroadcastMsgRDDFunctions(rdd)
}
}

def partitionForAggregation[T: ClassTag](msgs: RDD[(VertexId, T)], partitioner: Partitioner) = {
val rdd = new ShuffledRDD[VertexId, T, (VertexId, T)](msgs, partitioner)
private[graphx]
class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) {
  /**
   * Shuffles `self` so its entries are co-partitioned with a vertex RDD that
   * uses `partitioner`, selecting a specialized aggregation-message serializer
   * when the vertex attribute type is a primitive (`Int`, `Long`, `Double`).
   *
   * @param partitioner the vertex partitioner to co-partition with
   * @return a shuffled RDD keyed by [[VertexId]] under `partitioner`
   */
  def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = {
    val shuffled = new ShuffledRDD[VertexId, VD, (VertexId, VD)](self, partitioner)
    // Primitive attribute types get a compact custom serializer; anything else
    // falls through to the default serializer configured on the RDD.
    classTag[VD] match {
      case ClassTag.Int    => shuffled.setSerializer(new IntAggMsgSerializer)
      case ClassTag.Long   => shuffled.setSerializer(new LongAggMsgSerializer)
      case ClassTag.Double => shuffled.setSerializer(new DoubleAggMsgSerializer)
      case _               => // non-primitive VD: keep the default serializer
    }
    shuffled
  }
}

private[graphx]
object VertexRDDFunctions {
  /**
   * Implicitly enriches an `RDD[(VertexId, VD)]` with [[VertexRDDFunctions]]
   * so callers can invoke `copartitionWithVertices` directly on the RDD.
   *
   * The result type is stated explicitly: implicit definitions with inferred
   * result types are fragile (their visibility depends on type-inference
   * order) and draw a compiler warning in newer Scala versions.
   */
  implicit def rdd2VertexRDDFunctions[VD: ClassTag](
      rdd: RDD[(VertexId, VD)]): VertexRDDFunctions[VD] = {
    new VertexRDDFunctions(rdd)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable {
val a = readVarLong(optimizePositive = false)
val b = readUnsignedVarInt()
val c = s.read()
new RoutingTableMessage(a, b, c).asInstanceOf[T]
if (c == -1) throw new EOFException
new RoutingTableMessage(a, b, c.toByte).asInstanceOf[T]
}
}
}
Expand Down