From 94fe49a68f5222d64cbb5e55a20f7a3ae598d1e7 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Tue, 25 Nov 2025 16:46:23 +0200 Subject: [PATCH 1/4] [WIP] Working with conversion kernel --- llama-tornado | 1 - set_paths | 4 ++-- .../gpullama3/inference/InferenceCore.java | 2 +- .../gpullama3/inference/state/LlamaState.java | 3 +++ .../gpullama3/inference/state/State.java | 6 ++++++ .../model/loader/LlamaModelLoader.java | 2 +- .../model/loader/MistralModelLoader.java | 2 +- .../kernels/TransformerComputeKernels.java | 11 ++++++++++ .../tornadovm/layers/Activation.java | 21 +++++++++++++------ .../layers/type/fp16/Qwen2FP16FFNLayers.java | 7 ++++--- 10 files changed, 44 insertions(+), 15 deletions(-) diff --git a/llama-tornado b/llama-tornado index b59473f2..00a393bd 100755 --- a/llama-tornado +++ b/llama-tornado @@ -76,7 +76,6 @@ class LlamaRunner: "-Dtornado.load.tornado.implementation=uk.ac.manchester.tornado.runtime.common.Tornado", "-Dtornado.load.annotation.implementation=uk.ac.manchester.tornado.annotation.ASMClassVisitor", "-Dtornado.load.annotation.parallel=uk.ac.manchester.tornado.api.annotations.Parallel", - "-Dtornado.tvm.maxbytecodesize=65536" ] cmd.extend(tornado_config) diff --git a/set_paths b/set_paths index fd807c5e..fe79810e 100644 --- a/set_paths +++ b/set_paths @@ -6,10 +6,10 @@ # Resolve root of this project (LLaMA3) and TornadoVM export LLAMA_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -export TORNADO_ROOT="${LLAMA_ROOT}/external/tornadovm" +#export TORNADO_ROOT="${LLAMA_ROOT}/external/tornadovm" # Set the path to TornadoVM SDK binaries -export TORNADO_SDK="${TORNADO_ROOT}/bin/sdk" +#export TORNADO_SDK="${TORNADO_ROOT}/bin/sdk" # Add TornadoVM and LLaMA bin directories to PATH export PATH="${PATH}:${TORNADO_SDK}:${LLAMA_ROOT}" diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java index 8104e561..33f8c0e8 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java @@ -583,7 +583,7 @@ public static FloatArray forwardTornadoVM(Model model, State state, int token, i final Configuration configuration = model.configuration(); final TornadoWeights weights = (TornadoWeights) model.weights(); - MemorySegment.copy(weights.getTokenEmbeddingTable().asFloatArray().getSegment(), (long) token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES); + MemorySegment.copy(weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(), (long) token * configuration.dim() * Short.BYTES, state.embeddingX.getSegment(), 0, configuration.dim() * Short.BYTES); return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position); } diff --git a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java index 9f9fdcdb..38af3877 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java @@ -4,6 +4,7 @@ import org.beehive.gpullama3.tensor.standard.FloatTensor; import org.beehive.gpullama3.model.Configuration; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; import java.util.stream.Stream; @@ -52,6 +53,8 @@ protected StateFields createStateFields(Configuration 
config) { fields.wrapHb = new FloatArray(config.hiddenDim()); fields.wrapHb2 = new FloatArray(config.hiddenDim()); + fields.embeddingX = new HalfFloatArray(config.dim()); + fields.wrapLogits = new FloatArray(config.vocabularySize()); fields.wrapQ = new FloatArray(config.dim()); fields.wrapK = new FloatArray(config.dim()); diff --git a/src/main/java/org/beehive/gpullama3/inference/state/State.java b/src/main/java/org/beehive/gpullama3/inference/state/State.java index 01d94936..5b245f15 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/State.java @@ -3,6 +3,7 @@ import org.beehive.gpullama3.tensor.standard.FloatTensor; import org.beehive.gpullama3.model.Configuration; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; /** @@ -57,6 +58,8 @@ public abstract class State { public final FloatArray wrapValueCache; // FloatArray wrapper for the value cache, optimized for TornadoVM. public final IntArray positionHolder; + public final HalfFloatArray embeddingX; + // store inter public int localSize; public FloatArray temp; // Temporary buffer for intermediate calculations, size adjusted for local workgroup size. @@ -108,6 +111,8 @@ protected State(Configuration config, int batchsize) { this.temp = fields.temp; this.tempFFN = fields.tempFFN; this.tempLogits = fields.tempLogits; + + this.embeddingX = fields.embeddingX; } // Abstract method - subclasses implement their specific allocation logic and sizes @@ -121,6 +126,7 @@ protected static class StateFields { public FloatArray wrapQ, wrapK, wrapV, wrapAtt, wrapKeyCache, wrapValueCache; public IntArray positionHolder; public FloatArray temp, tempFFN, tempLogits; + public HalfFloatArray embeddingX; } @Override diff --git a/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java index aa3a3894..4605e56c 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java @@ -116,7 +116,7 @@ protected Weights createTornadoVMWeights(Map tensorEntr // Load all tensors uniformly as TornadoTensor hierarchy return new LlamaTornadoWeights( - loadTornadoTensorAsFP32(tokenEmbeddings), + loadTornadoTensor(tokenEmbeddings), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), diff --git a/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java index 0b9ba3d2..e31d2b04 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java @@ -110,7 +110,7 @@ protected Weights createTornadoVMWeights(Map tensorEntr // Load all tensors uniformly as TornadoTensor hierarchy return new LlamaTornadoWeights( - loadTornadoTensorAsFP32(tokenEmbeddings), + loadTornadoTensor(tokenEmbeddings), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." 
+ i + ".attn_q.weight")), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java index 7f69e496..4f203a4c 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java @@ -1,8 +1,11 @@ package org.beehive.gpullama3.tornadovm.kernels; import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.annotations.Parallel; import uk.ac.manchester.tornado.api.math.TornadoMath; +import uk.ac.manchester.tornado.api.types.HalfFloat; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; public class TransformerComputeKernels { @@ -19,6 +22,14 @@ public static void emptyTaskToForceCopyIn(FloatArray buffer) { } } + public static void copyfp15tofp32(KernelContext context, HalfFloatArray x, FloatArray wrapX) { + int i = context.globalIdx; + if (i < wrapX.getSize()) { + wrapX.set(i, x.get(i).getFloat32()); + } + } + + /** * Performs RMS (Root Mean Square) normalization using parallel reduction. * This is a two-phase reduction: first within work groups, then across work groups. diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java index 16783829..172faec3 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java @@ -7,8 +7,10 @@ import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; public class Activation extends AbstractLayer { @@ -17,16 +19,23 @@ public class Activation extends AbstractLayer { public Activation(String taskGraphHandle, State state, Weights weights, Configuration config) { super(taskGraphHandle, state, weights, config); - // formatter:off - this.activationUpdate = new TaskGraph(taskGraphHandle).transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapX) - .task("updateX", TransformerComputeKernels::emptyTaskToForceCopyIn, state.wrapX).persistOnDevice(state.wrapX); - // formatter:on + KernelContext kernelContext = new KernelContext(); + + // @formatter:off + this.activationUpdate = new TaskGraph(taskGraphHandle) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) +// .task("updateX", TransformerComputeKernels::emptyTaskToForceCopyIn, state.wrapX) + .task("updateX", TransformerComputeKernels::copyfp15tofp32, kernelContext, state.embeddingX, state.wrapX) + .persistOnDevice(state.wrapX); + // @formatter:on } @Override public GridScheduler updateGridScheduler(GridScheduler scheduler) { - WorkerGrid singleWorker = WorkerGridFactory.createSingleWorker(); - scheduler.addWorkerGrid("activationUpdate.updateX", singleWorker); +// WorkerGrid singleWorker = WorkerGridFactory.createSingleWorker(); + WorkerGrid worker = new WorkerGrid1D(config.dim()); + worker.setLocalWork(256, 
1, 1); + scheduler.addWorkerGrid("activationUpdate.updateX", worker); return scheduler; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java index 858848ea..cd64b1dd 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java @@ -187,9 +187,10 @@ TaskGraph setupSingleQwen2FFNLayer(Qwen2TornadoWeights weights, int layerIndex) .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen2State.wrapXb, qwen2State.wrapK, weights.wkLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen2State.wrapXb, qwen2State.wrapV, weights.wvLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), - LOCAL_WORK_GROUP_SIZE_ALLOC).task("qbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapQ, weights.q_biasLayered[layerIndex].asFloatArray(), config.dim()) - .task("kbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapK, weights.k_biasLayered[layerIndex].asFloatArray(), config.kvDim()) - .task("vbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapV, weights.v_biasLayered[layerIndex].asFloatArray(), config.kvDim()) + LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("qbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapQ, weights.q_biasLayered[layerIndex].asFloatArray(), config.dim()) //TODO: + .task("kbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapK, weights.k_biasLayered[layerIndex].asFloatArray(), config.kvDim()) //TODO: CHECK THREADS + .task("vbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapV, weights.v_biasLayered[layerIndex].asFloatArray(), config.kvDim()) //TODO: .task("rope", Qwen3Kernels::ropeRotation, context, qwen2State.positionHolder, qwen2State.wrapQ, qwen2State.wrapK, config.numberOfKeyValueHeads(), config.headSize()) .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, qwen2State.wrapKeyCache, qwen2State.wrapK, qwen2State.wrapValueCache, qwen2State.wrapV, qwen2State.positionHolder, config.kvDim(), layerIndex, config.contextLength()) From 7a08bbf777379b7eb9031e8232564631c8b789e7 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Tue, 25 Nov 2025 17:11:26 +0200 Subject: [PATCH 2/4] [WIP] Reduced precision embeddings --- .../kernels/TransformerComputeKernels.java | 15 +++++++ .../tornadovm/layers/Activation.java | 5 ++- .../layers/type/fp16/LlamaFP16FFNLayers.java | 44 +++++++++++-------- 3 files changed, 43 insertions(+), 21 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java index 4f203a4c..461da7bf 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java @@ -29,6 +29,21 @@ public static void copyfp15tofp32(KernelContext context, HalfFloatArray x, Float } } + public static void copyfp15tofp32Vec4(KernelContext context, HalfFloatArray x, FloatArray wrapX) { + int i = context.globalIdx * 4; // Process 4 elements per thread + if (i + 3 < wrapX.getSize()) { + wrapX.set(i, x.get(i).getFloat32()); + 
wrapX.set(i + 1, x.get(i + 1).getFloat32()); + wrapX.set(i + 2, x.get(i + 2).getFloat32()); + wrapX.set(i + 3, x.get(i + 3).getFloat32()); + } else { + // Handle remainder + for (int j = i; j < wrapX.getSize(); j++) { + wrapX.set(j, x.get(j).getFloat32()); + } + } + } + /** * Performs RMS (Root Mean Square) normalization using parallel reduction. diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java index 172faec3..c559d9f1 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java @@ -24,7 +24,7 @@ public Activation(String taskGraphHandle, State state, Weights weights, Configur // @formatter:off this.activationUpdate = new TaskGraph(taskGraphHandle) .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) -// .task("updateX", TransformerComputeKernels::emptyTaskToForceCopyIn, state.wrapX) +// .task("updateX", TransformerComputeKernels::copyfp15tofp32, state.wrapX) .task("updateX", TransformerComputeKernels::copyfp15tofp32, kernelContext, state.embeddingX, state.wrapX) .persistOnDevice(state.wrapX); // @formatter:on @@ -34,8 +34,9 @@ public Activation(String taskGraphHandle, State state, Weights weights, Configur public GridScheduler updateGridScheduler(GridScheduler scheduler) { // WorkerGrid singleWorker = WorkerGridFactory.createSingleWorker(); WorkerGrid worker = new WorkerGrid1D(config.dim()); - worker.setLocalWork(256, 1, 1); + worker.setLocalWork(128, 1, 1); scheduler.addWorkerGrid("activationUpdate.updateX", worker); + return scheduler; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index 96acd650..36c6d308 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -21,7 +21,8 @@ public class LlamaFP16FFNLayers extends AbstractFFNLayers { TaskGraph ffnTaskGraphs; GridScheduler scheduler; - List ffnLayerTaskGraphs; + List ffnLayerTaskGraphs; + public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Configuration config, SchedulerType schedulerType) { super(taskGraph, state, weights, config, schedulerType); this.ffnLayerTaskGraphs = setupFFNLayered(); @@ -29,7 +30,7 @@ public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Config @Override public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { - WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim()/2, 128); + WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 128); WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; @@ -64,12 +65,12 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) } @Override - public GridScheduler getGridScheduler() { + public GridScheduler getGridScheduler() { return scheduler; } @Override - public TaskGraph getTaskGraph() { + public TaskGraph getTaskGraph() { return ffnTaskGraphs; } @@ -87,15 +88,16 @@ List setupFFNLayered() { state.tempFFN.init(0.0f); var numLayers = config.numberOfLayers(); - return IntStream.range(0, numLayers) - .mapToObj(i -> { - var ffnLayer = 
setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i); - if (i == numLayers - 1) setupLastID(ffnLayer.getTaskGraphName()); - return ffnLayer.snapshot(); - }) - .toList(); + return IntStream.range(0, numLayers).mapToObj(i -> { + var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i); + if (i == numLayers - 1) { + setupLastID(ffnLayer.getTaskGraphName()); + } + return ffnLayer.snapshot(); + }).toList(); } + // @formatter:off TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) { var layerTaskGraphName = "layer_" + layerIndex; TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); @@ -113,10 +115,10 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); unifiedLayer .task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (shouldUseFinalNormalization()) { - unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp, - config.dim(), config.rmsNormEps()); - } +// if (shouldUseFinalNormalization()) { +// unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp, +// config.dim(), config.rmsNormEps()); +// } unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp) .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) @@ -131,16 +133,18 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); - if (shouldUseFinalNormalization()) { - unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, config.dim(), config.rmsNormEps()); - } +// if (shouldUseFinalNormalization()) { +// unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, config.dim(), config.rmsNormEps()); +// } unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN) .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(), weights.w3Layered[layerIndex].asHalfFloatArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), 
config.hiddenDim(), - config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX); + config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .persistOnDevice(state.wrapX); return unifiedLayer; } + // @formatter:on protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { // First layer: Transfer initial data to device (one-time transfer) @@ -164,6 +168,7 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye return unifiedLayer; } + // @formatter:off private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) { if (schedulerType == SchedulerType.NVIDIA) { return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, @@ -175,4 +180,5 @@ private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) { config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), config.contextLength(), state.positionHolder, state.wrapAtt, layerIndex, config.contextLength()); } } + // @formatter:on } From 47ae5ef4ab4e263ad8bc8ca86f71c407d4de0199 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 26 Nov 2025 11:22:55 +0200 Subject: [PATCH 3/4] [WIP] Reduced precision embeddings work-in-progress --- .../kernels/TransformerComputeKernels.java | 25 ++++++------------- .../tornadovm/layers/Activation.java | 7 +++--- .../layers/type/fp16/LlamaFP16FFNLayers.java | 1 + 3 files changed, 11 insertions(+), 22 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java index 461da7bf..14030cf9 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java @@ -1,7 +1,6 @@ package org.beehive.gpullama3.tornadovm.kernels; import uk.ac.manchester.tornado.api.KernelContext; -import uk.ac.manchester.tornado.api.annotations.Parallel; import uk.ac.manchester.tornado.api.math.TornadoMath; import uk.ac.manchester.tornado.api.types.HalfFloat; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; @@ -22,26 +21,16 @@ public static void emptyTaskToForceCopyIn(FloatArray buffer) { } } - public static void copyfp15tofp32(KernelContext context, HalfFloatArray x, FloatArray wrapX) { - int i = context.globalIdx; - if (i < wrapX.getSize()) { - wrapX.set(i, x.get(i).getFloat32()); + public static void emptyTaskToForceCopyIn(HalfFloatArray buffer) { + float dummy = buffer.get(0).getFloat32(); + if (dummy > Float.MAX_VALUE) { + buffer.set(0, new HalfFloat(dummy)); } } - public static void copyfp15tofp32Vec4(KernelContext context, HalfFloatArray x, FloatArray wrapX) { - int i = context.globalIdx * 4; // Process 4 elements per thread - if (i + 3 < wrapX.getSize()) { - wrapX.set(i, x.get(i).getFloat32()); - wrapX.set(i + 1, x.get(i + 1).getFloat32()); - wrapX.set(i + 2, x.get(i + 2).getFloat32()); - wrapX.set(i + 3, x.get(i + 3).getFloat32()); - } else { - // Handle remainder - for (int j = i; j < wrapX.getSize(); j++) { - wrapX.set(j, x.get(j).getFloat32()); - } - } + public static void convertFP16toFP32(KernelContext context, HalfFloatArray x, FloatArray wrapX) { + int i = context.globalIdx; + wrapX.set(i, x.get(i).getFloat32()); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java index c559d9f1..3e815d5a 100644 --- 
a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java @@ -4,7 +4,6 @@ import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.KernelContext; @@ -24,9 +23,9 @@ public Activation(String taskGraphHandle, State state, Weights weights, Configur // @formatter:off this.activationUpdate = new TaskGraph(taskGraphHandle) .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) -// .task("updateX", TransformerComputeKernels::copyfp15tofp32, state.wrapX) - .task("updateX", TransformerComputeKernels::copyfp15tofp32, kernelContext, state.embeddingX, state.wrapX) - .persistOnDevice(state.wrapX); + .task("updateX", TransformerComputeKernels::emptyTaskToForceCopyIn, state.embeddingX) +// .task("updateX", TransformerComputeKernels::convertFP16toFP32, kernelContext, state.embeddingX, state.wrapX) + .persistOnDevice(state.embeddingX); // @formatter:on } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index 36c6d308..1df45fc5 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -119,6 +119,7 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, // unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp, // config.dim(), config.rmsNormEps()); // } + unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp) .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) From a172f43eb7769a6db421bccfd29a7f56f0412a2f Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 26 Nov 2025 15:01:37 +0200 Subject: [PATCH 4/4] Introduce FP32-to-FP16 conversion kernel and optimize matrix-vector operations for HalfFloat arrays. Extend state with `embeddingXlogits` support and refine logits task graph for TornadoVM layers. 
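The FP32 <-> FP16 round trip performed by the new convertFP32toFP16 /
convertFP16toFP32 kernels can be checked on the host with plain Java
(Float.floatToFloat16 / Float.float16ToFloat, available since Java 20),
independently of TornadoVM's HalfFloat type. A minimal sketch; the class name
and sample values are illustrative only:

    public class HalfPrecisionRoundTrip {
        public static void main(String[] args) {
            float[] x = { 0.1f, -1.5f, 3.14159f, 65504.0f };
            for (float v : x) {
                short bits = Float.floatToFloat16(v);    // FP32 -> FP16 (round to nearest even)
                float back = Float.float16ToFloat(bits); // FP16 -> FP32 (exact widening)
                System.out.printf("%.6f -> 0x%04x -> %.6f%n", v, bits, back);
            }
        }
    }
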
--- llama-tornado | 3 + .../gpullama3/inference/state/LlamaState.java | 1 + .../gpullama3/inference/state/State.java | 4 +- .../kernels/TransformerComputeKernels.java | 6 ++ .../TransformerComputeKernelsLayered.java | 62 +++++++++++++++++++ .../tornadovm/layers/Activation.java | 6 +- .../layers/type/fp16/LogitsFP16Layer.java | 7 ++- 7 files changed, 83 insertions(+), 6 deletions(-) diff --git a/llama-tornado b/llama-tornado index 00a393bd..2a88c13b 100755 --- a/llama-tornado +++ b/llama-tornado @@ -67,6 +67,9 @@ class LlamaRunner: "-Djdk.module.showModuleResolution=false", "--module-path", self.module_path_colon_sep([".", f"{self.tornado_sdk}/share/java/tornado"]), + # "-Dgraal.Dump=*:5", + # "-Dgraal.PrintGraph=Network", + # "-Dgraal.PrintBackendCFG=true", ] # TornadoVM configuration diff --git a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java index 38af3877..83d9772d 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java @@ -54,6 +54,7 @@ protected StateFields createStateFields(Configuration config) { fields.wrapHb2 = new FloatArray(config.hiddenDim()); fields.embeddingX = new HalfFloatArray(config.dim()); + fields.embeddingXlogits = new HalfFloatArray(config.dim()); fields.wrapLogits = new FloatArray(config.vocabularySize()); fields.wrapQ = new FloatArray(config.dim()); diff --git a/src/main/java/org/beehive/gpullama3/inference/state/State.java b/src/main/java/org/beehive/gpullama3/inference/state/State.java index 5b245f15..9bb4ac48 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/State.java @@ -59,6 +59,7 @@ public abstract class State { public final IntArray positionHolder; public final HalfFloatArray embeddingX; + public final HalfFloatArray embeddingXlogits; // store inter public int localSize; @@ -113,6 +114,7 @@ protected State(Configuration config, int batchsize) { this.tempLogits = fields.tempLogits; this.embeddingX = fields.embeddingX; + this.embeddingXlogits = fields.embeddingXlogits; } // Abstract method - subclasses implement their specific allocation logic and sizes @@ -126,7 +128,7 @@ protected static class StateFields { public FloatArray wrapQ, wrapK, wrapV, wrapAtt, wrapKeyCache, wrapValueCache; public IntArray positionHolder; public FloatArray temp, tempFFN, tempLogits; - public HalfFloatArray embeddingX; + public HalfFloatArray embeddingX, embeddingXlogits; } @Override diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java index 14030cf9..709856d0 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java @@ -33,6 +33,12 @@ public static void convertFP16toFP32(KernelContext context, HalfFloatArray x, Fl wrapX.set(i, x.get(i).getFloat32()); } + public static void convertFP32toFP16(KernelContext context, FloatArray wrapX, HalfFloatArray x) { + int i = context.globalIdx; + float valInput = wrapX.get(i); + HalfFloat val = new HalfFloat(valInput); + x.set(i,val); + } /** * Performs RMS (Root Mean Square) normalization using parallel reduction. 
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java index dfe4ef27..4905ba6c 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java @@ -3,6 +3,7 @@ import uk.ac.manchester.tornado.api.KernelContext; import uk.ac.manchester.tornado.api.annotations.Parallel; import uk.ac.manchester.tornado.api.math.TornadoMath; +import uk.ac.manchester.tornado.api.types.HalfFloat; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.Int8Array; @@ -690,6 +691,32 @@ public static void matrixVectorGeneric( hb.set(rowId, sum); } } + + + public static void matrixVectorGeneric( + KernelContext context, + HalfFloatArray x, + FloatArray hb, // output + HalfFloatArray w, + int dim1, // inner loop + int dim0, // outer loop + int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + int localSize = localWorkGroupSize; + + // Early exit if this workgroup is beyond our output dimension + if (rowId >= dim0) { + return; + } + float sum = matrixVectorRowMajorOptimized(context, localSize, x, w, dim1); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + hb.set(rowId, sum); + } + } // @formatter:on /** @@ -878,6 +905,41 @@ public static float matrixVectorRowMajorOptimized(KernelContext context, int loc return localSum[0]; } + public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) { + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Allocate local memory for reduction + float[] localSum = context.allocateFloatLocalArray(localSize); + + int rowOffset = rowId * n; + + // Each thread calculates partial dot product + float partialSum = 0.0f; + for (int j = localId; j < n; j += localSize) { + int matrixIdx = rowOffset + j; +// HalfFloat inter = HalfFloat.mult(w.get(matrixIdx), x.get(j)); +// partialSum = HalfFloat.add(partialSum, inter); + partialSum += w.get(matrixIdx).getFloat32() * x.get(j).getFloat32(); +// partialSum += inter; + } + + // Store partial sum in local memory + localSum[localId] = partialSum; + context.localBarrier(); + + // Parallel reduction within workgroup + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { +// localSum[localId] = HalfFloat.add(localSum[localId], localSum[localId + stride]); + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + return localSum[0]; + } + // Second kernel - Combines partial sums and computes final normalization public static void reductionFinalNormalization(KernelContext context, FloatArray output, int size, float ermsNorm) { int gid = context.globalIdx; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java index 3e815d5a..fe20ef50 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java @@ -23,9 +23,9 @@ public Activation(String taskGraphHandle, State state, Weights weights, Configur // @formatter:off 
this.activationUpdate = new TaskGraph(taskGraphHandle) .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) - .task("updateX", TransformerComputeKernels::emptyTaskToForceCopyIn, state.embeddingX) -// .task("updateX", TransformerComputeKernels::convertFP16toFP32, kernelContext, state.embeddingX, state.wrapX) - .persistOnDevice(state.embeddingX); +// .task("updateX", TransformerComputeKernels::emptyTaskToForceCopyIn, state.embeddingX) + .task("updateX", TransformerComputeKernels::convertFP16toFP32, kernelContext, state.embeddingX, state.wrapX) + .persistOnDevice(state.wrapX); // @formatter:on } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java index a674c1c5..3db2b074 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java @@ -29,6 +29,7 @@ public LogitsFP16Layer(String name, State state, Weights weights, Configuration super(name, state, weights, config); this.lastTaskGraphID = lastTaskGraphID; state.tempLogits.init(0.0f); + state.embeddingXlogits.clear(); var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor"); this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config); this.schedulerType = schedulerType; @@ -39,14 +40,15 @@ public LogitsFP16Layer(String name, State state, Weights weights, Configuration */ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { TaskGraph logits = new TaskGraph("logits"); - logits.consumeFromDevice(lastTaskGraphID, state.wrapX).transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits) + logits.consumeFromDevice(lastTaskGraphID, state.wrapX).transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits, state.embeddingXlogits) .transferToDevice(DataTransferMode.FIRST_EXECUTION, context, state.wrapLogits, weights.wclsByteArray.asHalfFloatArray(), weights.rms_final_weight_as_floatArray.asFloatArray()) .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); if (schedulerType == SchedulerType.NON_NVIDIA) { logits.task("reductionFinalNormalizationLogits", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, config.dim(), config.rmsNormEps()); } logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits) - .task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapX, state.wrapLogits, weights.wclsByteArray.asHalfFloatArray(), config.dim(), config.vocabularySize(), + .task("convert", TransformerComputeKernels::convertFP32toFP16, context, state.wrapX, state.embeddingXlogits) + .task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.embeddingX, state.wrapLogits, weights.wclsByteArray.asHalfFloatArray(), config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); return logits; @@ -68,6 +70,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) 
tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker); tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", logitsRMS); tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS); + tornadoForwardScheduler.addWorkerGrid("logits.convert", logitsRMS); return tornadoForwardScheduler; }
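
For reference, the HalfFloatArray overload of matrixVectorGeneric added above
computes, for each output row r, y[r] = sum over j of w[r * cols + j] * x[j],
widening the FP16 inputs to FP32 before the multiply-accumulate. A sequential
sketch of the same computation in plain Java (Java 20+ for
Float.float16ToFloat; the class and parameter names are illustrative, not part
of this patch) that can act as a CPU oracle when validating the GPU kernel:

    // Row-major FP16 matrix times FP16 vector with FP32 accumulation,
    // mirroring what each workgroup of the HalfFloatArray matrixVectorGeneric
    // overload computes for its assigned row. Illustrative reference only.
    public final class Fp16MatVecReference {
        public static float[] matVec(short[] wFp16, short[] xFp16, int rows, int cols) {
            float[] y = new float[rows];
            for (int r = 0; r < rows; r++) {
                int rowOffset = r * cols;          // row-major layout, as in the kernel
                float acc = 0.0f;                  // FP32 accumulator
                for (int j = 0; j < cols; j++) {
                    acc += Float.float16ToFloat(wFp16[rowOffset + j])
                         * Float.float16ToFloat(xFp16[j]);
                }
                y[r] = acc;
            }
            return y;
        }
    }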