@@ -21,15 +21,16 @@ public class LlamaFP16FFNLayers extends AbstractFFNLayers {
 
     TaskGraph ffnTaskGraphs;
     GridScheduler scheduler;
-    List<ImmutableTaskGraph> ffnLayerTaskGraphs;
+    List<ImmutableTaskGraph> ffnLayerTaskGraphs;
+
     public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Configuration config, SchedulerType schedulerType) {
         super(taskGraph, state, weights, config, schedulerType);
         this.ffnLayerTaskGraphs = setupFFNLayered();
     }
 
     @Override
     public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
-        WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim()/2, 128);
+        WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 128);
         WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256);
 
         int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
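WorkerGridFactory is project-specific, but the underlying TornadoVM idea is that a WorkerGrid pairs a global thread count with a local work-group size, and a GridScheduler binds that shape to a task through its "graphName.taskName" id. A minimal sketch against the public TornadoVM API; the task id below is illustrative, not taken from this commit:

import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;

public class GridSetupSketch {
    public static GridScheduler buildScheduler(int dim) {
        GridScheduler scheduler = new GridScheduler();
        // Same shape as ropeWorker above: dim/2 global threads, 128 per work-group.
        WorkerGrid ropeWorker = new WorkerGrid1D(dim / 2);
        ropeWorker.setLocalWork(128, 1, 1);
        // Bind the grid to a task by "graphName.taskName" (illustrative id).
        scheduler.setWorkerGrid("layer_0.rope", ropeWorker);
        return scheduler;
    }
}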
@@ -64,12 +65,12 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
     }
 
     @Override
-    public GridScheduler getGridScheduler() {
+    public GridScheduler getGridScheduler() {
         return scheduler;
     }
 
     @Override
-    public TaskGraph getTaskGraph() {
+    public TaskGraph getTaskGraph() {
         return ffnTaskGraphs;
     }
 
@@ -87,15 +88,16 @@ List<ImmutableTaskGraph> setupFFNLayered() {
         state.tempFFN.init(0.0f);
         var numLayers = config.numberOfLayers();
 
-        return IntStream.range(0, numLayers)
-                .mapToObj(i -> {
-                    var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i);
-                    if (i == numLayers - 1) setupLastID(ffnLayer.getTaskGraphName());
-                    return ffnLayer.snapshot();
-                })
-                .toList();
+        return IntStream.range(0, numLayers).mapToObj(i -> {
+            var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i);
+            if (i == numLayers - 1) {
+                setupLastID(ffnLayer.getTaskGraphName());
+            }
+            return ffnLayer.snapshot();
+        }).toList();
     }
 
+    // @formatter:off
     TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) {
         var layerTaskGraphName = "layer_" + layerIndex;
         TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName);
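For context on the snapshot() call in setupFFNLayered above: a mutable TaskGraph is frozen into an ImmutableTaskGraph, which an execution plan can then run repeatedly. A minimal, self-contained sketch under the standard TornadoVM API; the kernel and sizes are placeholders, not from this repository:

import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

public class SnapshotSketch {
    // Placeholder kernel so the graph has a task to run.
    public static void scale(FloatArray data, float factor) {
        for (@Parallel int i = 0; i < data.getSize(); i++) {
            data.set(i, data.get(i) * factor);
        }
    }

    public static void main(String[] args) {
        FloatArray data = new FloatArray(1024);
        data.init(1.0f);

        TaskGraph tg = new TaskGraph("layer_demo")
                .transferToDevice(DataTransferMode.FIRST_EXECUTION, data)
                .task("scale", SnapshotSketch::scale, data, 2.0f)
                .transferToHost(DataTransferMode.EVERY_EXECUTION, data);

        ImmutableTaskGraph frozen = tg.snapshot(); // reusable, immutable view
        new TornadoExecutionPlan(frozen).execute();
    }
}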
@@ -113,10 +115,10 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
         unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
         unifiedLayer
                 .task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
-        if (shouldUseFinalNormalization()) {
-            unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp,
-                    config.dim(), config.rmsNormEps());
-        }
+        // if (shouldUseFinalNormalization()) {
+        //     unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp,
+        //             config.dim(), config.rmsNormEps());
+        // }
         unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp)
                 .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(),
                         LOCAL_WORK_GROUP_SIZE_ALLOC)
@@ -131,16 +133,18 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
         unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(),
                 LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
-        if (shouldUseFinalNormalization()) {
-            unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, config.dim(), config.rmsNormEps());
-        }
+        // if (shouldUseFinalNormalization()) {
+        //     unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, config.dim(), config.rmsNormEps());
+        // }
         unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN)
                 .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(),
                         weights.w3Layered[layerIndex].asHalfFloatArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), config.hiddenDim(),
-                        config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX);
+                        config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .persistOnDevice(state.wrapX);
         return unifiedLayer;
     }
+    // @formatter:on
 
     protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
         // First layer: Transfer initial data to device (one-time transfer)
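The comment above marks the first-layer special case; in TornadoVM terms the distinction is expressed with DataTransferMode, paired with the persistOnDevice call at the end of setupSingleFFNLayer. A hedged sketch of that idiom (the buffer and method names are placeholders; consumeFromDevice is assumed from the public API, not shown in this commit):

import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

public class TransferSketch {
    public static TaskGraph configureTransfers(TaskGraph graph, int layerIndex, FloatArray activations) {
        if (layerIndex == 0) {
            // First layer: one-time host-to-device copy; the buffer then stays resident.
            return graph.transferToDevice(DataTransferMode.FIRST_EXECUTION, activations);
        }
        // Later layers: reuse what the previous layer's graph persisted on-device,
        // avoiding a host round-trip between layers.
        return graph.consumeFromDevice(activations);
    }
}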
@@ -164,6 +168,7 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye
         return unifiedLayer;
     }
 
+    // @formatter:off
     private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) {
         if (schedulerType == SchedulerType.NVIDIA) {
             return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention,
@@ -175,4 +180,5 @@ private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) {
                     config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), config.contextLength(), state.positionHolder, state.wrapAtt, layerIndex, config.contextLength());
         }
     }
+    // @formatter:on
 }
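Putting the pieces together, a caller would run the per-layer immutable graphs produced by setupFFNLayered under the scheduler built above. An illustrative sketch of that wiring, not lifted from this repository:

import java.util.List;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;

public class ForwardPassSketch {
    public static void runFFNLayers(List<ImmutableTaskGraph> layerGraphs, GridScheduler scheduler) {
        // One plan over all layer graphs; because each graph ends with
        // persistOnDevice(state.wrapX), activations flow layer to layer on-device.
        TornadoExecutionPlan plan = new TornadoExecutionPlan(layerGraphs.toArray(new ImmutableTaskGraph[0]));
        plan.withGridScheduler(scheduler).execute();
    }
}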