Commits (35)
ca2b28a
Implement FP16 support in TornadoVM by introducing HalfFloat arrays, …
mikepapadim Dec 3, 2025
f0411ae
Introduce matrix-vector kernel with residual addition and enhance FP1…
mikepapadim Dec 3, 2025
6334ac3
Fused Q/K/V matrix-vector multiplication into a single kernel to redu…
mikepapadim Dec 3, 2025
46218a7
Fuse RoPE rotation and KV cache copy into a single kernel, update tas…
mikepapadim Dec 3, 2025
b48ec62
Add `mapContextWithQuantize` kernel, integrate into task graph, and d…
mikepapadim Dec 3, 2025
943da78
Refactor logits task graph to optimize kernel setup, update worker gr…
mikepapadim Dec 4, 2025
386dddc
Refactor FP16 FFN layers to streamline task graph setup, update worke…
mikepapadim Dec 4, 2025
b202bb4
Refactor FP16 FFN layers to streamline task graph setup, update worke…
mikepapadim Dec 4, 2025
3eba3b3
Refactor `LogitsFP16Layer` task graph to improve readability, optimiz…
mikepapadim Dec 4, 2025
2e010b1
Add `fusedFeedForwardWithSiLUAndGLUActivation` kernel for HalfFloat a…
mikepapadim Dec 4, 2025
4aef300
Document Transformer Layer Task Flow for `LlamaFP16FFNLayers` with de…
mikepapadim Dec 4, 2025
177ec9d
Set default profiler dump directory relative to `LLAMA_ROOT` when not…
mikepapadim Dec 4, 2025
a1c94fb
Add `fusedRmsNormFFNGateUp` kernel and update FP16 FFN task graph to …
mikepapadim Dec 4, 2025
577b6b1
Increase `BLOCK_SIZE_C` to 16 for Transformer kernel and update FP16 …
mikepapadim Dec 4, 2025
d5c1206
Increase `ropeWithCacheWorker` local work group size to 512 in FP16 F…
mikepapadim Dec 4, 2025
f91108c
Add fused kernels for Qwen3: `ropeRotationWithCacheCopy`, `fusedQKVMa…
mikepapadim Dec 4, 2025
67050bb
Merge branch 'feat/deq-n-compute' of github.com:beehive-lab/GPULlama3…
mikepapadim Dec 4, 2025
cfa3ba0
Add fused Q and K RMSNorm kernel and refactor task graph to consolida…
mikepapadim Dec 4, 2025
abf12d4
Refactor Qwen3 FP16 FFN layers to streamline worker grid setup, updat…
mikepapadim Dec 4, 2025
042b0b5
Add `processHeadsFlashAttentionOptV2` kernel with static memory size …
mikepapadim Dec 4, 2025
1cbe03a
Refactor Qwen3 FP16 FFN layers: remove unused imports, replace explic…
mikepapadim Dec 4, 2025
a4bc159
Refactor Qwen2 FP16 task graph: consolidate attention and FFN tasks w…
mikepapadim Dec 4, 2025
e15c229
Add `fusedQKvBiasAddition` kernel, refactor Qwen2 FP16 task graph to …
mikepapadim Dec 4, 2025
e7d79c9
Add support for HalfFloatArray in Phi3State and initialize FP16 wrapp…
mikepapadim Dec 4, 2025
02b1a2c
Add `splitQKV` and `splitGateUpSiLU` worker grids to Phi3 FP16 FFN la…
mikepapadim Dec 4, 2025
428e5cc
Refactor Phi3 FP16 FFN layers: replace `createRoPEWorker` with generi…
mikepapadim Dec 4, 2025
6c1ac6f
Add Phi3-specific fused kernels for RMSNorm+QKV and RMSNorm+Gate/Up, …
mikepapadim Dec 4, 2025
ed74652
Replace `splitQKV` kernel with `fusedRmsNormQKVMatmulDirect`, refacto…
mikepapadim Dec 4, 2025
8b52fbe
Remove unused `splitQKV` and RMS Apply+QKV Projection kernels, update…
mikepapadim Dec 4, 2025
977f0ba
Add `fusedRmsNormFFNGateUpSiLU` kernel to optimize Phi3 FFN flow, rep…
mikepapadim Dec 4, 2025
7e19032
Remove unused `splitQKV` and `splitGateUpSiLU` workers, clean up comm…
mikepapadim Dec 4, 2025
1e46405
Refactor Phi3 FP16 FFN layer task graph: improve readability by adjus…
mikepapadim Dec 4, 2025
7c63dc4
Refactor LogitsFP16Layer: streamline task graph setup, consolidate gr…
mikepapadim Dec 4, 2025
1a98725
Refactor `TransformerComputeKernelsLayered`: replace `matrixVectorRow…
mikepapadim Dec 7, 2025
d1ec408
Refactor `TransformerComputeKernelsLayered`: rename `matrixVectorRowM…
mikepapadim Dec 7, 2025
Refactor logits task graph to optimize kernel setup, update worker grids, and deprecate redundant tasks in FP16 layer.
mikepapadim committed Dec 4, 2025
commit 943da78ff7718e98da40299f823f0ca57990d676
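For context on the API this commit exercises: in TornadoVM, a `TaskGraph` registers kernels under string task names, and a `GridScheduler` attaches worker grids to them using keys of the form `<graphName>.<taskName>`. Renaming a task therefore forces a matching update of the scheduler key, which is exactly what the diff below does for the logits graph. The following is a minimal, hedged sketch of that wiring pattern; the kernel, sizes, and names are illustrative and not taken from the project.

import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

public class TaskGraphWiringSketch {

    // Illustrative kernel: one work-item per element, driven by the grid scheduler.
    public static void scale(KernelContext context, FloatArray in, FloatArray out, float factor) {
        int i = context.globalIdx;
        out.set(i, in.get(i) * factor);
    }

    public static void main(String[] args) {
        KernelContext context = new KernelContext();
        FloatArray in = new FloatArray(1024);
        FloatArray out = new FloatArray(1024);

        // Tasks are registered under string names on a named graph.
        TaskGraph graph = new TaskGraph("logits_sketch")
                .transferToDevice(DataTransferMode.FIRST_EXECUTION, context, in)
                .task("scale", TaskGraphWiringSketch::scale, context, in, out, 2.0f)
                .transferToHost(DataTransferMode.EVERY_EXECUTION, out);

        // Worker grids are keyed by "<graphName>.<taskName>", so renaming a task
        // (e.g. "projection" -> "vocab_proj") must be mirrored here.
        WorkerGrid worker = new WorkerGrid1D(1024);
        worker.setLocalWork(128, 1, 1);
        GridScheduler scheduler = new GridScheduler();
        scheduler.addWorkerGrid("logits_sketch.scale", worker);

        // Execution (omitted): snapshot the graph into an ImmutableTaskGraph and run it
        // through a TornadoExecutionPlan configured with this scheduler.
    }
}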
@@ -2,8 +2,8 @@

 import org.beehive.gpullama3.inference.state.State;
 import org.beehive.gpullama3.inference.weights.Weights;
-import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
 import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
 import org.beehive.gpullama3.model.Configuration;
 import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
 import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
@@ -28,7 +28,7 @@ public class LogitsFP16Layer extends AbstractLayer {
     public LogitsFP16Layer(String name, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) {
         super(name, state, weights, config);
         this.lastTaskGraphID = lastTaskGraphID;
-        state.tempLogits.init(0.0f);
+        state.tempLogits.clear();

         var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor");
         this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config);
@@ -40,18 +40,20 @@ public LogitsFP16Layer(String name, State state, Weights weights, Configuration
      */
     private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) {
         TaskGraph logits = new TaskGraph("logits");
-        logits.consumeFromDevice(lastTaskGraphID, state.wrapX).transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits, state.wrapXFP16)
-                .transferToDevice(DataTransferMode.FIRST_EXECUTION, context, state.wrapLogits, weights.wclsByteArray.asHalfFloatArray(), weights.rms_final_weight_as_floatArray.asFloatArray())
-                .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
-        if (schedulerType == SchedulerType.NON_NVIDIA) {
-            logits.task("reductionFinalNormalizationLogits", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, config.dim(), config.rmsNormEps());
-        }
-        logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits)
-                .task("dequantizeX", TransformerComputeKernels::convertFP32toFP16v2, context, state.wrapX, state.wrapXFP16)
-                .task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, //
-                        context, state.wrapXFP16, state.wrapLogits, //
-                        weights.wclsByteArray.asHalfFloatArray(), config.dim(), config.vocabularySize(), //
-                        LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); //
+        logits.consumeFromDevice(lastTaskGraphID, state.wrapX) //
+                .transferToDevice(DataTransferMode.FIRST_EXECUTION, context, //
+                        state.wrapLogits, state.wrapXbFP16, //
+                        weights.wclsByteArray.asHalfFloatArray(), //
+                        weights.rms_final_weight_as_floatArray.asFloatArray()) //
+                .task("rms_reduce", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
+        if (schedulerType == SchedulerType.NON_NVIDIA) {
+            logits.task("rms_finalize", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, config.dim(), config.rmsNormEps());
+        }
+        logits.task("rms_apply_fp16", TransformerComputeKernels::mapContextWithQuantizeLogits, context, state.wrapXbFP16, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits)
+                .task("vocab_proj", TransformerComputeKernelsLayered::matrixVectorGeneric, //
+                        context, state.wrapXbFP16, state.wrapLogits, //
+                        weights.wclsByteArray.asHalfFloatArray(), config.dim(), config.vocabularySize(), //
+                        LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); //
         logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
         return logits;
     }
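The reshaped graph above drops the standalone `dequantizeX` task and no longer streams `state.tempLogits`/`state.wrapXFP16` to the device on every execution: the RMSNorm apply step now writes its result straight into the FP16 buffer (`state.wrapXbFP16`) that the vocabulary projection consumes. The body of `mapContextWithQuantizeLogits` is not part of this diff; the following is only a hedged sketch of what such a fused apply-and-quantize kernel could look like, assuming TornadoVM's `HalfFloat`/`HalfFloatArray` types.

import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.types.HalfFloat;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;

public class FusedRmsApplyQuantizeSketch {

    // One work-item per element: apply the inverse-RMS scale produced by the
    // preceding reduction task ("rms_reduce"), weight it, and store FP16 directly,
    // so no separate FP32 -> FP16 conversion pass is needed.
    public static void rmsApplyAndQuantize(KernelContext context, HalfFloatArray xbFP16,
            FloatArray x, FloatArray rmsWeight, FloatArray tempLogits) {
        int i = context.globalIdx;
        float invRms = tempLogits.get(0);                     // scale from the reduction step
        float value = rmsWeight.get(i) * (invRms * x.get(i)); // RMSNorm apply
        xbFP16.set(i, new HalfFloat(value));                  // quantize to FP16 in the same pass
    }
}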
@@ -69,10 +71,9 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
         WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
         vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);

-        tornadoForwardScheduler.addWorkerGrid("logits.dequantizeX", logitsRMS);
-        tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker);
-        tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", logitsRMS);
-        tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS);
+        tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker);
+        tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS);
+        tornadoForwardScheduler.addWorkerGrid("logits.rms_apply_fp16", logitsRMS);
         return tornadoForwardScheduler;
     }
