Commit 7a08bbf

[WIP] Reduced precision embeddings
1 parent 94fe49a commit 7a08bbf

3 files changed, +43 −21 lines

src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java

Lines changed: 15 additions & 0 deletions
@@ -29,6 +29,21 @@ public static void copyfp15tofp32(KernelContext context, HalfFloatArray x, Float
         }
     }
 
+    public static void copyfp15tofp32Vec4(KernelContext context, HalfFloatArray x, FloatArray wrapX) {
+        int i = context.globalIdx * 4; // Process 4 elements per thread
+        if (i + 3 < wrapX.getSize()) {
+            wrapX.set(i, x.get(i).getFloat32());
+            wrapX.set(i + 1, x.get(i + 1).getFloat32());
+            wrapX.set(i + 2, x.get(i + 2).getFloat32());
+            wrapX.set(i + 3, x.get(i + 3).getFloat32());
+        } else {
+            // Handle remainder
+            for (int j = i; j < wrapX.getSize(); j++) {
+                wrapX.set(j, x.get(j).getFloat32());
+            }
+        }
+    }
+
 
     /**
      * Performs RMS (Root Mean Square) normalization using parallel reduction.
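
Note: this commit adds the Vec4 kernel but does not yet schedule it; the activation task graph in the next file still calls the scalar copyfp15tofp32. The sketch below is hypothetical, not part of this commit: it shows how the vectorized variant could be wired in, assuming the same Activation fields shown in that file and that config.dim() is a multiple of 4 * 128, so each thread converts four embedding values and the grid covers config.dim() / 4 threads.

    // Hypothetical wiring of the Vec4 kernel (not part of this commit).
    this.activationUpdate = new TaskGraph(taskGraphHandle)
            .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
            .task("updateX", TransformerComputeKernels::copyfp15tofp32Vec4, kernelContext, state.embeddingX, state.wrapX)
            .persistOnDevice(state.wrapX);

    // Quarter-sized worker grid for the vectorized copy.
    WorkerGrid vec4Worker = new WorkerGrid1D(config.dim() / 4);
    vec4Worker.setLocalWork(128, 1, 1);
    scheduler.addWorkerGrid("activationUpdate.updateX", vec4Worker);

If config.dim() were not a multiple of 4, the global size would need to be rounded up so the kernel's remainder branch still gets a thread for the trailing elements.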

src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java

Lines changed: 3 additions & 2 deletions
@@ -24,7 +24,7 @@ public Activation(String taskGraphHandle, State state, Weights weights, Configur
         // @formatter:off
         this.activationUpdate = new TaskGraph(taskGraphHandle)
                 .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
-                // .task("updateX", TransformerComputeKernels::emptyTaskToForceCopyIn, state.wrapX)
+                // .task("updateX", TransformerComputeKernels::copyfp15tofp32, state.wrapX)
                 .task("updateX", TransformerComputeKernels::copyfp15tofp32, kernelContext, state.embeddingX, state.wrapX)
                 .persistOnDevice(state.wrapX);
         // @formatter:on
@@ -34,8 +34,9 @@ public Activation(String taskGraphHandle, State state, Weights weights, Configur
     public GridScheduler updateGridScheduler(GridScheduler scheduler) {
         // WorkerGrid singleWorker = WorkerGridFactory.createSingleWorker();
         WorkerGrid worker = new WorkerGrid1D(config.dim());
-        worker.setLocalWork(256, 1, 1);
+        worker.setLocalWork(128, 1, 1);
         scheduler.addWorkerGrid("activationUpdate.updateX", worker);
+
         return scheduler;
     }
 
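
The local work-group size for the updateX task drops from 256 to 128 here, while the worker still spans config.dim() global threads, so the local size has to divide the model dimension evenly (for a hypothetical dim of 2048, that gives 16 work-groups of 128). A small guard that makes this assumption explicit, illustrative only and not in this commit:

    // Hypothetical guard (not in this commit): the 1D worker covers config.dim()
    // global threads, so the chosen local size should divide it evenly.
    int localSize = 128; // value set by this commit
    if (config.dim() % localSize != 0) {
        throw new IllegalStateException("config.dim() must be a multiple of " + localSize);
    }
    WorkerGrid worker = new WorkerGrid1D(config.dim());
    worker.setLocalWork(localSize, 1, 1);
    scheduler.addWorkerGrid("activationUpdate.updateX", worker);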

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java

Lines changed: 25 additions & 19 deletions
@@ -21,15 +21,16 @@ public class LlamaFP16FFNLayers extends AbstractFFNLayers {
 
     TaskGraph ffnTaskGraphs;
     GridScheduler scheduler;
-    List<ImmutableTaskGraph> ffnLayerTaskGraphs;
+    List<ImmutableTaskGraph> ffnLayerTaskGraphs;
+
     public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Configuration config, SchedulerType schedulerType) {
         super(taskGraph, state, weights, config, schedulerType);
         this.ffnLayerTaskGraphs = setupFFNLayered();
     }
 
     @Override
     public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
-        WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim()/2, 128);
+        WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 128);
         WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256);
 
         int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
@@ -64,12 +65,12 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
     }
 
     @Override
-    public GridScheduler getGridScheduler() {
+    public GridScheduler getGridScheduler() {
         return scheduler;
     }
 
     @Override
-    public TaskGraph getTaskGraph() {
+    public TaskGraph getTaskGraph() {
         return ffnTaskGraphs;
     }
 
@@ -87,15 +88,16 @@ List<ImmutableTaskGraph> setupFFNLayered() {
         state.tempFFN.init(0.0f);
         var numLayers = config.numberOfLayers();
 
-        return IntStream.range(0, numLayers)
-                .mapToObj(i -> {
-                    var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i);
-                    if (i == numLayers - 1) setupLastID(ffnLayer.getTaskGraphName());
-                    return ffnLayer.snapshot();
-                })
-                .toList();
+        return IntStream.range(0, numLayers).mapToObj(i -> {
+            var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i);
+            if (i == numLayers - 1) {
+                setupLastID(ffnLayer.getTaskGraphName());
+            }
+            return ffnLayer.snapshot();
+        }).toList();
     }
 
+    // @formatter:off
     TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) {
         var layerTaskGraphName = "layer_" + layerIndex;
         TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName);
@@ -113,10 +115,10 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
         unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
         unifiedLayer
                 .task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
-        if (shouldUseFinalNormalization()) {
-            unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp,
-                    config.dim(), config.rmsNormEps());
-        }
+        // if (shouldUseFinalNormalization()) {
+        //     unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp,
+        //             config.dim(), config.rmsNormEps());
+        // }
         unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp)
                 .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(),
                         LOCAL_WORK_GROUP_SIZE_ALLOC)
@@ -131,16 +133,18 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
         unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(),
                         LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
-        if (shouldUseFinalNormalization()) {
-            unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, config.dim(), config.rmsNormEps());
-        }
+        // if (shouldUseFinalNormalization()) {
+        //     unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, config.dim(), config.rmsNormEps());
+        // }
         unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN)
                 .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(),
                         weights.w3Layered[layerIndex].asHalfFloatArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), config.hiddenDim(),
-                        config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX);
+                        config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .persistOnDevice(state.wrapX);
         return unifiedLayer;
     }
+    // @formatter:on
 
     protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
         // First layer: Transfer initial data to device (one-time transfer)
@@ -164,6 +168,7 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye
         return unifiedLayer;
     }
 
+    // @formatter:off
    private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) {
        if (schedulerType == SchedulerType.NVIDIA) {
            return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention,
@@ -175,4 +180,5 @@ private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) {
                    config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), config.contextLength(), state.positionHolder, state.wrapAtt, layerIndex, config.contextLength());
        }
    }
+    // @formatter:on
 }
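
For context, the ImmutableTaskGraph snapshots returned by setupFFNLayered() are consumed outside this diff; a minimal sketch of how such per-layer snapshots are typically executed, assuming TornadoVM's standard TornadoExecutionPlan API and the GridScheduler built by updateGridScheduler (illustrative only, not part of this commit):

    // Illustrative only: run each per-layer snapshot with the configured grid.
    // Assumes TornadoVM's TornadoExecutionPlan API; error handling omitted.
    for (ImmutableTaskGraph layerGraph : ffnLayerTaskGraphs) {
        TornadoExecutionPlan plan = new TornadoExecutionPlan(layerGraph);
        plan.withGridScheduler(scheduler).execute();
    }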
