metadata) {
/**
* Loads the language model based on the given options.
- *
- * If Ahead-of-Time (AOT) mode is enabled, attempts to use a pre-loaded compiled model. Otherwise, loads the model from the specified path using the model loader.
- *
*
- * @param options
- * the parsed CLI options containing model path and max token limit
+ * If Ahead-of-Time (AOT) mode is enabled, attempts to use a pre-loaded compiled model.
+ * Otherwise, loads the model from the specified path using the model loader.
+ *
+ * @param options the parsed CLI options containing model path and max token limit
* @return the loaded {@link Model} instance
- * @throws IOException
- * if the model fails to load
- * @throws IllegalStateException
- * if AOT loading is enabled but the preloaded model is unavailable
+ * @throws IOException if the model fails to load
+ * @throws IllegalStateException if AOT loading is enabled but the preloaded model is unavailable
*/
public static Model loadModel(Options options) throws IOException {
- if (USE_AOT) {
- Model model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens());
- if (model == null) {
- throw new IllegalStateException("Failed to load precompiled AOT model.");
- }
- return model;
- }
- return ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true, options.useTornadovm());
- }
+ Path ggufPath = options.modelPath();
+ int contextLength = options.maxTokens();
+ boolean useTornadovm = options.useTornadovm();
- public static Model loadModel(Path ggufPath, int contextLength, boolean loadWeights, boolean useTornadovm) throws IOException {
// initial load of metadata from gguf file
- GGUF gguf = GGUF.loadModel(ggufPath);
- FileChannel fileChannel = FileChannel.open(ggufPath, StandardOpenOption.READ);
+ GGUF gguf = GGUF.loadGGUFMetadata(ggufPath);
// detect model type
ModelType modelType = detectModelType(gguf.getMetadata());
// model type-specific load
- return modelType.loadModel(fileChannel, gguf, contextLength, loadWeights, useTornadovm);
+ return modelType.loadModel(gguf.getFileChannel(), gguf, contextLength, useTornadovm);
}
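For reviewers, a minimal sketch of driving the consolidated entry point. Options and its accessors (modelPath(), maxTokens(), useTornadovm()) appear in this diff; the parseOptions factory named below is an assumption for illustration only:

    // Hypothetical driver; only ModelLoader.loadModel(Options) is from this PR.
    import org.beehive.gpullama3.Options;
    import org.beehive.gpullama3.model.Model;
    import org.beehive.gpullama3.model.loader.ModelLoader;

    public final class LoadModelExample {
        public static void main(String[] args) throws java.io.IOException {
            Options options = Options.parseOptions(args); // assumed CLI parser entry point
            // One call now covers GGUF metadata load, model-type detection and weight loading.
            Model model = ModelLoader.loadModel(options);
            System.out.println("Loaded " + model.getClass().getSimpleName());
        }
    }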
- public static FloatTensor loadQuantized(GGMLTensorEntry entry) {
+ /**
+ * Dispatches loading of a standard (non-TornadoVM) tensor based on its GGML type.
+ * Used in the CPU path.
+ */
+ public static FloatTensor loadTensor(GGMLTensorEntry entry) {
GGMLType ggmlType = entry.ggmlType();
return switch (ggmlType) {
- case F32 -> new F32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
+ case F32 -> new FP32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
case Q8_0 -> new Q8_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
case Q4_0 -> new Q4_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
- case F16 -> new F16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
+ case F16 -> new FP16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
default -> throw new UnsupportedOperationException("Quantization format " + ggmlType);
};
}
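A short usage sketch of the renamed CPU-path dispatcher; the tensor name is illustrative, and the GGUF calls are the ones used elsewhere in this diff:

    // Sketch: read one value from a (possibly quantized) tensor on the CPU path.
    Map<String, GGMLTensorEntry> tensorEntries =
            GGUF.loadTensors(gguf.getFileChannel(), gguf.getTensorDataOffset(), gguf.getTensorInfos());
    FloatTensor norm = ModelLoader.loadTensor(tensorEntries.get("output_norm.weight"));
    float first = norm.getFloat(0); // Q8_0/Q4_0 entries dequantize on access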
+ /**
+ * Loads an array of standard tensors by delegating to {@link #loadTensor(GGMLTensorEntry)}.
+ * Used in the CPU path.
+ */
+ public static FloatTensor[] loadArrayOfTensors(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
+ FloatTensor[] array = new FloatTensor[size];
+ for (int i = 0; i < size; i++) {
+ array[i] = loadTensor(getTensorEntry.apply(i));
+ }
+ return array;
+ }
+
+ /**
+ * Dispatches loading of a TornadoVM-compatible tensor based on its GGML type.
+ * Used in the GPU path.
+ */
+ public static TornadoTensor loadTornadoTensor(GGMLTensorEntry entry) {
+ GGMLType ggmlType = entry.ggmlType();
+ return switch (ggmlType) {
+ case F32 -> FP32TornadoTensor.fromTornadoMemorySegment(entry.memorySegment());
+ case F16 -> FP16TornadoTensor.fromTornadoMemorySegment(entry.memorySegment());
+ case Q8_0 -> Q8_0TornadoTensor.createAsQ8_0(entry);
+ case Q4_0 -> throw new UnsupportedOperationException("Q4_0 format not supported yet");
+ default -> throw new UnsupportedOperationException("Quantization format " + ggmlType);
+ };
+ }
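The GPU-path counterpart in the same sketch style; type() and asHalfFloatArray() are the accessors that loadTornadoTensorAsFP32 below relies on:

    // Sketch: dispatch lands on a TornadoVM-backed representation per GGML type.
    TornadoTensor tt = ModelLoader.loadTornadoTensor(entry);
    switch (tt.type()) {
        case F32 -> { /* FloatArray-backed, ready for FP32 kernels */ }
        case F16 -> { HalfFloatArray hfa = tt.asHalfFloatArray(); /* FP16 kept as-is */ }
        case Q8_0 -> { /* remains block-quantized (see Q8_0TornadoTensor) */ }
        default -> { /* Q4_0 and others are rejected by the dispatcher above */ }
    }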
+
+ /**
+ * Loads an array of TornadoVM tensors by delegating to {@link #loadTornadoTensor(GGMLTensorEntry)}.
+ * Used in the GPU path.
+ */
+ public static TornadoTensor[] loadArrayOfTornadoTensors(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
+ TornadoTensor[] array = new TornadoTensor[size];
+ for (int i = 0; i < size; i++) {
+ array[i] = loadTornadoTensor(getTensorEntry.apply(i));
+ }
+ return array;
+ }
+
+ /**
+ * Loads a tensor and manually converts it to FP32 ({@code FloatArray}).
+ * Used for embeddings, which are currently treated as FP32.
+ * TODO: the per-element conversion is very slow and should be removed.
+ */
+ public static TornadoTensor loadTornadoTensorAsFP32(GGMLTensorEntry entry) {
+ TornadoTensor tensor = loadTornadoTensor(entry);
+ return switch (tensor.type()) {
+ case F32 -> tensor;
+ case F16 -> {
+ HalfFloatArray tensorHFA = tensor.asHalfFloatArray();
+ int numOfElements = tensorHFA.getSize();
+ FloatArray tensorFA = new FloatArray(numOfElements);
+ for (int i = 0; i < numOfElements; i++) {
+ tensorFA.set(i, tensorHFA.get(i).getFloat32());
+ }
+ yield new FP32TornadoTensor(tensorFA);
+ }
+ case Q8_0 -> Q8_0TornadoTensor.createAsFP32(entry);
+ default -> throw new UnsupportedOperationException("Unsupported tensor type: " + tensor.type());
+ };
+ }
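A back-of-envelope figure for the TODO above, to make the cost concrete. Taking Phi-3-mini's shapes as an illustration (vocab_size 32064, dim 3072; figures for orientation only):

    long elements = 32_064L * 3_072L;  // illustrative embedding-table size
    System.out.println(elements);      // 98500608 boxed HalfFloat -> float conversions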
+
+ // Helper methods
+
public static FloatArray[] loadArrayAsFloatArray(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
FloatArray[] array = new FloatArray[size];
for (int i = 0; i < size; i++) {
@@ -132,7 +179,6 @@ public static FloatArray[] loadArrayAsFloatArray(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
HalfFloatArray[] array = new HalfFloatArray[size];
@@ -142,7 +188,13 @@ public static HalfFloatArray[] loadArrayAsHalfFloatArray(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
+ Q8_0TornadoTensor[] array = new Q8_0TornadoTensor[size];
+ for (int i = 0; i < size; i++) {
+ array[i] = Q8_0TornadoTensor.createAsQ8_0(getTensorEntry.apply(i));
+ }
+ return array;
+ }
public static FloatArray floatBufferToFloatArray(GGMLTensorEntry tensorEntry) {
if (tensorEntry.ggmlType() == GGMLType.F32) {
@@ -152,7 +204,6 @@ public static FloatArray floatBufferToFloatArray(GGMLTensorEntry tensorEntry) {
throw new UnsupportedOperationException("Conversion to FloatArray from " + tensorEntry.ggmlType());
}
}
- //@formatter:on
public static FloatArray[] loadArrayAsFloatArrayFromBuffer(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
FloatArray[] array = new FloatArray[size];
@@ -163,7 +214,7 @@ public static FloatArray[] loadArrayAsFloatArrayFromBuffer(int size, IntFunction
}
public static ByteArray createByteArrayFromTensor(GGMLTensorEntry entry) {
- FloatTensor tensor = loadQuantized(entry);
+ FloatTensor tensor = loadTensor(entry);
return ByteArray.fromSegment(tensor.asMemorySegment());
}
@@ -178,7 +229,7 @@ public static FloatArray loadTensorAsFloatArray(GGMLTensorEntry entry) {
return array;
} else {
// For quantized formats, we need to load through FloatTensor
- FloatTensor tensor = loadQuantized(entry);
+ FloatTensor tensor = loadTensor(entry);
FloatArray array = new FloatArray(tensor.size());
for (int i = 0; i < tensor.size(); i++) {
array.set(i, tensor.getFloat(i));
@@ -193,7 +244,7 @@ public static HalfFloatArray loadTensorAsHalfFloatArray(GGMLTensorEntry entry) {
return null;
} else {
// For quantized formats, we need to load through FloatTensor
- FloatTensor tensor = loadQuantized(entry);
+ FloatTensor tensor = loadTensor(entry);
HalfFloatArray array = new HalfFloatArray(tensor.size());
for (int i = 0; i < tensor.size(); i++) {
HalfFloat x = new HalfFloat(tensor.getFloat(i));
@@ -203,14 +254,6 @@ public static HalfFloatArray loadTensorAsHalfFloatArray(GGMLTensorEntry entry) {
}
}
- public static FloatTensor[] loadArrayOfQuantized(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
- FloatTensor[] array = new FloatTensor[size];
- for (int i = 0; i < size; i++) {
- array[i] = loadQuantized(getTensorEntry.apply(i));
- }
- return array;
- }
-
public static FloatBuffer[] loadArrayOfFloatBuffer(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
FloatBuffer[] array = new FloatBuffer[size];
for (int i = 0; i < size; i++) {
@@ -226,95 +269,4 @@ public static FloatBuffer toFloatBuffer(GGMLTensorEntry tensorEntry) {
default -> throw new UnsupportedOperationException("Conversion to " + ggmlType);
};
}
-
- public abstract Model loadModel();
-
- //@formatter:off
- public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config) {
- boolean ropeScaling = tensorEntries.containsKey("rope_freqs");
- RopeConfig ropeConfig = new RopeConfig(8.0f, // scaleFactor
- 1.0f, // loFreqFactor
- 3.0f, // hiFreqFactor
- 8192 // oldContextLength
- );
-
- Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(
- config.contextLength(), // Maximum sequence length the model can process
- config.headSize(), // Dimension of each attention head
- config.ropeTheta(), // Base frequency parameter (typically 10000.0)
- ropeScaling, // Whether to apply frequency scaling (determined by model type)
- ropeConfig.scaleFactor, // Scale factor for extending context length (NTK-aware scaling)
- ropeConfig.loFreqFactor, // Low frequency scaling factor for better long-range dependencies
- ropeConfig.hiFreqFactor, // High frequency scaling factor for preserving local precision
- ropeConfig.oldContextLength // Original context length the model was trained with
- );
-
- GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
- GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
-
- if (useTornadovm) {
- if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
- System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
- }
- return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
- } else {
- return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
- }
- }
-
- public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
- GGMLTensorEntry outputWeight) {
- return new LlamaTornadoWeights(
- // Load directly to TornadoVM format
- loadTensorAsFloatArray(tokenEmbeddings), loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
- loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
- FloatArray.fromArray(ropeFreqs.first()), FloatArray.fromArray(ropeFreqs.second()), loadTensorAsHalfFloatArray(outputWeight), outputWeight.ggmlType()) {
- };
- }
-
- /**
- * Creates weights in standard format only
- */
- public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
- GGMLTensorEntry outputWeight) {
- return new LlamaStandardWeights(
- loadQuantized(tokenEmbeddings),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
- loadQuantized(tensorEntries.get("output_norm.weight")),
- new ArrayFloatTensor(ropeFreqs.first()),
- new ArrayFloatTensor(ropeFreqs.second()),
- loadQuantized(outputWeight),
- outputWeight.ggmlType());
- }
-
- // Helper class to encapsulate RoPE configuration parameters
- private static class RopeConfig {
- final float scaleFactor;
- final float loFreqFactor;
- final float hiFreqFactor;
- final int oldContextLength;
-
- RopeConfig(float scaleFactor, float loFreqFactor, float hiFreqFactor, int oldContextLength) {
- this.scaleFactor = scaleFactor;
- this.loFreqFactor = loFreqFactor;
- this.hiFreqFactor = hiFreqFactor;
- this.oldContextLength = oldContextLength;
- }
- }
-
}
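The removed loadWeights above documents the precomputeFreqsCis arguments one by one. For reference, a standalone sketch of the unscaled case (ropeScaling == false), mirroring the standard RoPE definition rather than the project's exact RoPE internals:

    // freq_i = theta^(-2i/headSize); the table holds cos/sin of position * freq_i.
    static float[][] freqsCis(int contextLength, int headSize, double theta) {
        int half = headSize / 2;
        float[] real = new float[contextLength * half];
        float[] imag = new float[contextLength * half];
        for (int pos = 0; pos < contextLength; pos++) {
            for (int i = 0; i < half; i++) {
                double freq = Math.pow(theta, -2.0 * i / headSize);
                real[pos * half + i] = (float) Math.cos(pos * freq);
                imag[pos * half + i] = (float) Math.sin(pos * freq);
            }
        }
        return new float[][] { real, imag };
    }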
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java
index d6b431c5..f32249ed 100644
--- a/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java
+++ b/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java
@@ -1,161 +1,157 @@
package org.beehive.gpullama3.model.loader;
-import org.beehive.gpullama3.LlamaApp;
-import org.beehive.gpullama3.Options;
-import org.beehive.gpullama3.auxiliary.Timer;
-import org.beehive.gpullama3.core.model.GGMLType;
-import org.beehive.gpullama3.core.model.GGUF;
-import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
-import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
-import org.beehive.gpullama3.core.types.Pair;
+import org.beehive.gpullama3.tensor.GGMLType;
+import org.beehive.gpullama3.tensor.GGUF;
+import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
+import org.beehive.gpullama3.tensor.tornado.FP32TornadoTensor;
+import org.beehive.gpullama3.tensor.GGMLTensorEntry;
+import org.beehive.gpullama3.auxiliary.Pair;
import org.beehive.gpullama3.inference.operation.RoPE;
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.inference.weights.standard.Phi3StandardWeights;
import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights;
-import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.format.ChatFormat;
import org.beehive.gpullama3.model.phi3.Phi3;
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
-import org.beehive.gpullama3.tokenizer.impl.Phi3Tokenizer;
-import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
-import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
+import org.beehive.gpullama3.tokenizer.Phi3Tokenizer;
+import org.beehive.gpullama3.tokenizer.Tokenizer;
+import org.beehive.gpullama3.tokenizer.Vocabulary;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
-import java.io.IOException;
import java.nio.channels.FileChannel;
import java.util.Map;
-public class Phi3ModelLoader extends ModelLoader {
- public Phi3ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
- super(fileChannel, gguf, contextLength, loadWeights, useTornadovm);
+import static org.beehive.gpullama3.model.loader.ModelLoader.*;
+
+public class Phi3ModelLoader extends AbstractModelLoader {
+ private int modelContextLength;
+
+ public Phi3ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) {
+ super(fileChannel, gguf, contextLength, useTornadovm);
}
- // @formatter:off
@Override
- public Phi3 loadModel() {
- try {
- Map<String, Object> metadata = gguf.getMetadata();
- final String modelPrefix = "phi3.";
+ protected Vocabulary loadVocabulary(Map<String, Object> metadata) {
+ return Vocabulary.loadPhi3Vocabulary(metadata);
+ }
- Vocabulary vocabulary = Vocabulary.loadPhi3Vocabulary(metadata);
+ @Override
+ protected Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
+ if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
Tokenizer tokenizer = new Phi3Tokenizer(metadata, vocabulary);
-
- if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
- System.out.println("Tokenizer: " + tokenizer.getClass().getSimpleName());
- }
-
- int modelContextLength = (int) metadata.get(modelPrefix + "context_length");
- if (contextLength < 0 || modelContextLength < contextLength) {
- contextLength = modelContextLength;
- }
-
- Phi3Configuration config = new Phi3Configuration(
- (int) metadata.get(modelPrefix + "embedding_length"), // dim
- (int) metadata.get(modelPrefix + "feed_forward_length"), // hidden_dim
- (int) metadata.get(modelPrefix + "block_count"), // n_layers
- (int) metadata.get(modelPrefix + "attention.head_count"), // n_heads
-
- metadata.containsKey(modelPrefix + "attention.head_count_kv")
- ? (int) metadata.get(modelPrefix + "attention.head_count_kv")
- : (int) metadata.get(modelPrefix + "attention.head_count"), // n_kv_heads
-
- vocabulary.size(), // vocab_size
- contextLength, // context_length (user-specified, not model)
- (float) metadata.getOrDefault(modelPrefix + "attention.layer_norm_rms_epsilon", 1e-5f), // rms_norm_eps
- (float) metadata.getOrDefault(modelPrefix + "rope.freq_base", 10000f) // rope_theta
- );
-
- Weights weights = null;
- if (loadWeights) {
- Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
- weights = loadWeights(tensorEntries, config, modelContextLength);
- }
-
- // Phi3 chat tokens
- ChatFormat.ChatTokens chatTokens = new ChatFormat.ChatTokens(
- "<|system|>", "<|end|>", "<|user|>", "<|end|>", "<|assistant|>"
- );
-
- return new Phi3(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
- } catch (IOException e) {
- throw new RuntimeException(e);
+ System.out.println("Tokenizer: " + tokenizer.getClass().getSimpleName());
+ return tokenizer;
}
+ return new Phi3Tokenizer(metadata, vocabulary);
}
- // @formatter:on
// @formatter:off
- private Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, int modelContextLength) {
+ @Override
+ protected Phi3Configuration createConfiguration(Map<String, Object> metadata) {
+ final String modelPrefix = "phi3.";
+
+ modelContextLength = (int) metadata.get(modelPrefix + "context_length");
+ int finalContextLength = (contextLength < 0 || modelContextLength < contextLength) ? modelContextLength : contextLength;
+
+ var config = new Phi3Configuration(
+ (int) metadata.get(modelPrefix + "embedding_length"), // dim
+ (int) metadata.get(modelPrefix + "feed_forward_length"), // hidden_dim
+ (int) metadata.get(modelPrefix + "block_count"), // n_layers
+ (int) metadata.get(modelPrefix + "attention.head_count"), // n_heads
+
+ metadata.containsKey(modelPrefix + "attention.head_count_kv")
+ ? (int) metadata.get(modelPrefix + "attention.head_count_kv")
+ : (int) metadata.get(modelPrefix + "attention.head_count"), // n_kv_heads
+
+ vocabulary.size(), // vocab_size
+ finalContextLength, // context_length (user-specified, clamped to the model's maximum)
+ (float) metadata.getOrDefault(modelPrefix + "attention.layer_norm_rms_epsilon", 1e-5f), // rms_norm_eps
+ (float) metadata.getOrDefault(modelPrefix + "rope.freq_base", 10000f) // rope_theta
+ );
+ return config;
+ }
+ // @formatter:on
+
+ // @formatter:off
+ @Override
+ protected Pair<float[], float[]> precomputeRopeFrequencies(Phi3Configuration config) {
// Calculate head size from dim and numberOfHeads
int headSize = config.dim() / config.numberOfHeads();
- Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(
- modelContextLength, // Use model context length for RoPE precomputation
- headSize, // Calculated head size
+ return RoPE.precomputeFreqsCis(
+ modelContextLength, // Use model context length for RoPE precomputation
+ headSize, // Calculated head size
config.ropeTheta(),
- false, // Phi3 uses standard RoPE, not neox-style based on reference
- 8, 1, 3, 8192 // Additional RoPE parameters from reference
+ false, // Phi3 uses standard RoPE, not neox-style based on reference
+ 8,
+ 1,
+ 3,
+ 8192 // Additional RoPE parameters from reference
);
+ }
+ // @formatter:on
- GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
- GGMLTensorEntry outputWeight = tensorEntries.get("output.weight"); // Phi3 always has separate output weight
+ @Override
+ protected Phi3 createModel(Phi3Configuration config, Tokenizer tokenizer, Weights weights) {
+ // Phi3 chat tokens
+ ChatFormat.ChatTokens chatTokens = new ChatFormat.ChatTokens("<|system|>", "<|end|>", "<|user|>", "<|end|>", "<|assistant|>");
- if (useTornadovm) {
- if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
- System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
- }
- return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
- } else {
- return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
- }
+ return new Phi3(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
}
- // @formatter:on
// @formatter:off
@Override
- public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config,
- Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
- GGMLTensorEntry outputWeight) {
- return new Phi3TornadoWeights(
- loadTensorAsFloatArray(tokenEmbeddings),
- loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), // Combined QKV
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
- loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // wDown
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // wUp (not combined in reference)
- floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
- FloatArray.fromArray(ropeFreqs.first()),
- FloatArray.fromArray(ropeFreqs.second()),
- loadTensorAsHalfFloatArray(outputWeight),
- outputWeight.ggmlType()
+ protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Phi3Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) {
+ float[] ropeFreqsReal = ropeFreqs.first();
+ float[] ropeFreqsImag = ropeFreqs.second();
+
+ final int nl = config.numberOfLayers();
+
+ return new Phi3StandardWeights(
+ loadTensor(tokenEmbeddings), // token_embedding_table
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // rms_att_weight (as FloatTensor[])
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), // wqkv (combined)
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), // rms_ffn_weight (as FloatTensor[])
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // wDown
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // wUp (separate, not combined)
+ loadTensor(tensorEntries.get("output_norm.weight")), // rms_final_weight (as FloatTensor)
+ new ArrayFloatTensor(ropeFreqsReal), // freq_cis_real
+ new ArrayFloatTensor(ropeFreqsImag), // freq_cis_imag
+ loadTensor(outputWeight), // wcls
+ outputWeight.ggmlType() // weightType
);
}
// @formatter:on
// @formatter:off
@Override
- public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries,
- Configuration config,
- Pair<float[], float[]> ropeFreqs,
- GGMLTensorEntry tokenEmbeddings,
- GGMLTensorEntry outputWeight) {
- float[] ropeFreqsReal = ropeFreqs.first();
- float[] ropeFreqsImag = ropeFreqs.second();
+ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Phi3Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) {
+ GGMLType ggmlType = outputWeight.ggmlType();
- return new Phi3StandardWeights(
- loadQuantized(tokenEmbeddings), // token_embedding_table
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // rms_att_weight (as FloatTensor[])
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), // wqkv (combined)
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), // rms_ffn_weight (as FloatTensor[])
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // wDown
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // wUp (separate, not combined)
- loadQuantized(tensorEntries.get("output_norm.weight")), // rms_final_weight (as FloatTensor)
- new ArrayFloatTensor(ropeFreqsReal), // freq_cis_real
- new ArrayFloatTensor(ropeFreqsImag), // freq_cis_imag
- loadQuantized(outputWeight), // wcls
- outputWeight.ggmlType() // weightType
+ if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
+ System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")");
+ }
+
+ // Validate supported types
+ if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) {
+ throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights.");
+ }
+
+ final int nl = config.numberOfLayers();
+
+ // Load all tensors uniformly as TornadoTensor hierarchy
+ return new Phi3TornadoWeights(
+ loadTornadoTensorAsFP32(tokenEmbeddings),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), // fp32
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
+ loadTornadoTensor(tensorEntries.get("output_norm.weight")), // fp32
+ new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.first())),
+ new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.second())),
+ loadTornadoTensor(outputWeight),
+ ggmlType
);
}
// @formatter:on
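Phi3 is the only loader in this diff whose attention projections arrive fused (attn_qkv.weight). A hedged sketch of how such a fused matrix splits at inference time, assuming the conventional [Q | K | V] row layout (an assumption about the GGUF export, not something stated in this diff):

    // Illustrative Phi-3-mini shapes: dim = 3072, n_heads = n_kv_heads = 32.
    int dim = 3072, kvDim = 3072;
    int qEnd = dim;            // rows [0, qEnd)    -> Q projection
    int kEnd = qEnd + kvDim;   // rows [qEnd, kEnd) -> K projection
    int vEnd = kEnd + kvDim;   // rows [kEnd, vEnd) -> V projection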
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java
index 0fdcce3c..c957c029 100644
--- a/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java
+++ b/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java
@@ -1,163 +1,163 @@
package org.beehive.gpullama3.model.loader;
-import org.beehive.gpullama3.Options;
-import org.beehive.gpullama3.core.model.GGMLType;
-import org.beehive.gpullama3.core.model.GGUF;
-import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
-import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
-import org.beehive.gpullama3.core.types.Pair;
+import org.beehive.gpullama3.tensor.GGMLType;
+import org.beehive.gpullama3.tensor.GGUF;
+import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
+import org.beehive.gpullama3.tensor.tornado.FP32TornadoTensor;
+import org.beehive.gpullama3.tensor.GGMLTensorEntry;
+import org.beehive.gpullama3.auxiliary.Pair;
import org.beehive.gpullama3.inference.operation.RoPE;
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.inference.weights.standard.Qwen2StandardWeights;
import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
-import org.beehive.gpullama3.model.Configuration;
-import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.format.ChatFormat;
import org.beehive.gpullama3.model.format.ChatFormat.ChatTokens;
import org.beehive.gpullama3.model.qwen2.Qwen2;
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
-import org.beehive.gpullama3.tokenizer.impl.Qwen3Tokenizer;
-import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
-import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
+import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer;
+import org.beehive.gpullama3.tokenizer.Tokenizer;
+import org.beehive.gpullama3.tokenizer.Vocabulary;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
-import java.io.IOException;
import java.nio.channels.FileChannel;
import java.util.Map;
-import static org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary.loadQwen3Vocabulary;
+import static org.beehive.gpullama3.model.loader.ModelLoader.*;
-public class Qwen2ModelLoader extends ModelLoader {
+public class Qwen2ModelLoader extends AbstractModelLoader {
- public Qwen2ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
- super(fileChannel, gguf, contextLength, loadWeights, useTornadovm);
+ public Qwen2ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) {
+ super(fileChannel, gguf, contextLength, useTornadovm);
}
@Override
- public Model loadModel() {
- Map<String, Object> metadata = gguf.getMetadata();
- String basename = (String) metadata.get("general.basename");
-
- String modelName = "DeepSeek-R1-Distill-Qwen".equals(basename) ? "DeepSeek-R1-Distill-Qwen" : "Qwen2.5";
-
- try {
- // reuse method of Qwen3
- Vocabulary vocabulary = loadQwen3Vocabulary(metadata);
- boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename"));
- Tokenizer tokenizer = new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1DistillQwen);
-
- int modelContextLength = (int) metadata.get("qwen2.context_length");
- if (contextLength < 0 || modelContextLength < contextLength) {
- contextLength = modelContextLength;
- }
-
- int numberOfKeyValueHeads = metadata.containsKey("qwen2.attention.head_count_kv") ? (int) metadata.get("qwen2.attention.head_count_kv") : (int) metadata.get("qwen2.attention.head_count");
- Qwen2Configuration config = new Qwen2Configuration((int) metadata.get("qwen2.embedding_length"), // dim
- (int) metadata.get("qwen2.feed_forward_length"), // hiddendim
- (int) metadata.get("qwen2.block_count"), // numberOfLayers
- (int) metadata.get("qwen2.attention.head_count"), // numberOfHeads
-
- numberOfKeyValueHeads, // numberOfKeyValueHeads
- numberOfKeyValueHeads, // numberOfHeadsKey
- numberOfKeyValueHeads, // numberOfHeadsValue
-
- vocabulary.size(), modelContextLength, contextLength, false, (float) metadata.get("qwen2.attention.layer_norm_rms_epsilon"), (float) metadata.get("qwen2.rope.freq_base"));
-
- Weights weights = null;
- if (loadWeights) {
- Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
- weights = loadWeights(tensorEntries, config);
- }
- // Qwen2.5-Coder uses <|endoftext|> as stop-token.
- ChatTokens chatTokens = isDeepSeekR1DistillQwen
- ? new ChatTokens("<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "")
- : new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
- return new Qwen2(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
+ protected Vocabulary loadVocabulary(Map<String, Object> metadata) {
+ return Vocabulary.loadQwen3Vocabulary(metadata);
+ }
+
+ @Override
+ protected Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
+ boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename"));
+ return new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1DistillQwen);
}
// @formatter:off
@Override
- public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config) {
- Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(
- config.contextLengthModel(),
- config.headSize(),
- config.ropeTheta(),
+ protected Qwen2Configuration createConfiguration(Map<String, Object> metadata) {
+ int modelContextLength = (int) metadata.get("qwen2.context_length");
+ int finalContextLength = (contextLength < 0 || modelContextLength < contextLength) ? modelContextLength : contextLength;
+
+ int numberOfKeyValueHeads = metadata.containsKey("qwen2.attention.head_count_kv") ? (int) metadata.get("qwen2.attention.head_count_kv") : (int) metadata.get("qwen2.attention.head_count");
+ int vocabSize = vocabulary.size();
+
+ return new Qwen2Configuration(
+ (int) metadata.get("qwen2.embedding_length"), // dim
+ (int) metadata.get("qwen2.feed_forward_length"), // hiddendim
+ (int) metadata.get("qwen2.block_count"), // numberOfLayers
+ (int) metadata.get("qwen2.attention.head_count"), // numberOfHeads
+
+ numberOfKeyValueHeads, // numberOfKeyValueHeads
+ numberOfKeyValueHeads, // numberOfHeadsKey
+ numberOfKeyValueHeads, // numberOfHeadsValue
+
+ vocabSize,
+ modelContextLength,
+ finalContextLength,
false,
- 8,
- 1,
- 3,
- 8192
+ (float) metadata.get("qwen2.attention.layer_norm_rms_epsilon"),
+ (float) metadata.get("qwen2.rope.freq_base")
);
+ }
+ // @formatter:on
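The GGUF metadata map is stringly typed, hence the repeated casts above. A hypothetical helper (not part of this PR) shown only to document the access pattern the loaders share:

    static int intMeta(Map<String, Object> metadata, String key) {
        Object value = metadata.get(key);
        if (value == null) {
            throw new IllegalArgumentException("missing GGUF key: " + key);
        }
        return ((Number) value).intValue();
    }
    // usage: int dim = intMeta(metadata, "qwen2.embedding_length");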
- GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
- GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
+ @Override
+ protected Pair<float[], float[]> precomputeRopeFrequencies(Qwen2Configuration config) {
+ return RoPE.precomputeFreqsCis(config.contextLengthModel(), config.headSize(), config.ropeTheta(), false, 8, 1, 3, 8192);
+ }
- if (useTornadovm) {
- if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
- System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
- }
- return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
- } else {
- return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
- }
+ // @formatter:off
+ @Override
+ protected Qwen2 createModel(Qwen2Configuration config, Tokenizer tokenizer, Weights weights) {
+ Map<String, Object> metadata = gguf.getMetadata();
+ boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename"));
+ // Qwen2.5-Coder uses <|endoftext|> as stop-token.
+ ChatTokens chatTokens = isDeepSeekR1DistillQwen ? new ChatTokens("<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "")
+ : new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
+ return new Qwen2(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
}
+ // @formatter:on
+ // @formatter:off
@Override
- public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
- GGMLTensorEntry outputWeight) {
+ protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Qwen2Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight) {
+
+ final int nl = config.numberOfLayers();
+
return new Qwen2StandardWeights(
- loadQuantized(tokenEmbeddings),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
-
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.bias")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.bias")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.bias")),
-
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
- loadQuantized(tensorEntries.get("output_norm.weight")),
+ loadTensor(tokenEmbeddings),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.bias")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.bias")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.bias")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
+ loadTensor(tensorEntries.get("output_norm.weight")),
new ArrayFloatTensor(ropeFreqs.first()),
new ArrayFloatTensor(ropeFreqs.second()),
- loadQuantized(outputWeight),
- outputWeight.ggmlType());
+ loadTensor(outputWeight),
+ outputWeight.ggmlType()
+ );
}
+ // @formatter:on
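Qwen2 is the one family here with separate q/k/v biases, and it uses grouped-query attention, which is why numberOfKeyValueHeads is threaded through three constructor slots in createConfiguration above. Illustrative shapes (roughly Qwen2.5-7B; figures for orientation, not asserted by this diff):

    int dim = 3_584, nHeads = 28, nKvHeads = 4;
    int headSize = dim / nHeads;      // 128
    int kvDim = headSize * nKvHeads;  // 512: row count of wk, wv and the k/v biases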
+ // @formatter:off
@Override
- public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
- GGMLTensorEntry outputWeight) {
+ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Qwen2Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight) {
+ GGMLType ggmlType = outputWeight.ggmlType();
+
+ if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
+ System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")");
+ }
+
+ // Validate supported types
+ if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) {
+ throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights.");
+ }
+
+ final int nl = config.numberOfLayers();
+
+ // Load all tensors uniformly as TornadoTensor hierarchy
return new Qwen2TornadoWeights(
- loadTensorAsFloatArray(tokenEmbeddings),
- loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
- // Qwen2-specific: qkv bias
- loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.bias")),
- loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.bias")),
- loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.bias")),
-
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
- loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3
- floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
- FloatArray.fromArray(ropeFreqs.first()),
- FloatArray.fromArray(ropeFreqs.second()),
- loadTensorAsHalfFloatArray(outputWeight),
- outputWeight.ggmlType()
+ loadTornadoTensorAsFP32(tokenEmbeddings),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
+ // Qwen2-specific: qkv bias (always F32)
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.bias")), // fp32
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.bias")), // fp32
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.bias")), // fp32
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), // fp32
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
+ loadTornadoTensor(tensorEntries.get("output_norm.weight")), // fp32
+ new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.first())),
+ new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.second())),
+ loadTornadoTensor(outputWeight),
+ ggmlType
);
- }
- // @formatter:on
+ }
+ // @formatter:on
}
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java
index 8671b8ef..008af2b3 100644
--- a/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java
+++ b/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java
@@ -1,175 +1,162 @@
package org.beehive.gpullama3.model.loader;
-import org.beehive.gpullama3.Options;
-import org.beehive.gpullama3.auxiliary.Timer;
-import org.beehive.gpullama3.core.model.GGMLType;
-import org.beehive.gpullama3.core.model.GGUF;
-import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
-import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
-import org.beehive.gpullama3.core.types.Pair;
+import org.beehive.gpullama3.tensor.GGMLType;
+import org.beehive.gpullama3.tensor.GGUF;
+import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
+import org.beehive.gpullama3.tensor.tornado.FP32TornadoTensor;
+import org.beehive.gpullama3.tensor.GGMLTensorEntry;
+import org.beehive.gpullama3.auxiliary.Pair;
import org.beehive.gpullama3.inference.operation.RoPE;
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.inference.weights.standard.Qwen3StandardWeights;
import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights;
-import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.format.ChatFormat;
import org.beehive.gpullama3.model.format.ChatFormat.ChatTokens;
import org.beehive.gpullama3.model.qwen3.Qwen3;
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
-import org.beehive.gpullama3.tokenizer.impl.Qwen3Tokenizer;
-import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
-import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
+import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer;
+import org.beehive.gpullama3.tokenizer.Tokenizer;
+import org.beehive.gpullama3.tokenizer.Vocabulary;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
-import java.io.IOException;
import java.nio.channels.FileChannel;
import java.util.Map;
-import static org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary.loadQwen3Vocabulary;
+import static org.beehive.gpullama3.model.loader.ModelLoader.*;
+import static org.beehive.gpullama3.tokenizer.Vocabulary.loadQwen3Vocabulary;
-public class Qwen3ModelLoader extends ModelLoader {
+public class Qwen3ModelLoader extends AbstractModelLoader {
- public Qwen3ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
- super(fileChannel, gguf, contextLength, loadWeights, useTornadovm);
+ public Qwen3ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) {
+ super(fileChannel, gguf, contextLength, useTornadovm);
}
- // @formatter:off
@Override
- public Qwen3 loadModel() {
- try {
- Map<String, Object> metadata = gguf.getMetadata();
-
- Vocabulary vocabulary = loadQwen3Vocabulary(metadata);
- boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename"));
- Tokenizer tokenizer = new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1DistillQwen);
-
- int modelContextLength = (int) metadata.get("qwen3.context_length");
- if (contextLength < 0 || modelContextLength < contextLength) {
- contextLength = modelContextLength;
- }
-
- Qwen3Configuration config = new Qwen3Configuration(
- (int) metadata.get("qwen3.embedding_length"),
- (int) metadata.get("qwen3.feed_forward_length"),
- (int) metadata.get("qwen3.block_count"),
- (int) metadata.get("qwen3.attention.head_count"),
-
- metadata.containsKey("qwen3.attention.head_count_kv")
- ? (int) metadata.get("qwen3.attention.head_count_kv")
- : (int) metadata.get("qwen3.attention.head_count"),
- (int) metadata.get("qwen3.attention.key_length"),
- (int) metadata.get("qwen3.attention.value_length"),
-
- vocabulary.size(),
- modelContextLength, contextLength,
- false,
- (float) metadata.get("qwen3.attention.layer_norm_rms_epsilon"),
- (float) metadata.get("qwen3.rope.freq_base")
- );
-
- Weights weights = null;
- if (loadWeights) {
- Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
- weights = loadWeights(tensorEntries, config);
- }
- // Qwen2.5-coder uses <|endoftext|> as stop-token.
- ChatTokens chatTokens = isDeepSeekR1DistillQwen ?
- new ChatTokens( "<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "") :
- new ChatTokens( "<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
- return new Qwen3(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
+ protected Vocabulary loadVocabulary(Map<String, Object> metadata) {
+ return loadQwen3Vocabulary(metadata);
+ }
+
+ @Override
+ protected Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
+ boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename"));
+ return new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1DistillQwen);
}
- // @formatter:on
// @formatter:off
@Override
- public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config) {
- Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(
- config.contextLengthModel(),
- config.numberOfHeadsKey(),
- config.ropeTheta(),
+ protected Qwen3Configuration createConfiguration(Map<String, Object> metadata) {
+ int modelContextLength = (int) metadata.get("qwen3.context_length");
+ int finalContextLength = (contextLength < 0 || modelContextLength < contextLength) ? modelContextLength : contextLength;
+
+ int vocabSize = vocabulary.size();
+
+ return new Qwen3Configuration(
+ (int) metadata.get("qwen3.embedding_length"),
+ (int) metadata.get("qwen3.feed_forward_length"),
+ (int) metadata.get("qwen3.block_count"),
+ (int) metadata.get("qwen3.attention.head_count"),
+
+ metadata.containsKey("qwen3.attention.head_count_kv") ?
+ (int) metadata.get("qwen3.attention.head_count_kv") :
+ (int) metadata.get("qwen3.attention.head_count"),
+ (int) metadata.get("qwen3.attention.key_length"),
+ (int) metadata.get("qwen3.attention.value_length"),
+
+ vocabSize,
+ modelContextLength,
+ finalContextLength,
false,
- 0,
- 0,
- 0,
- 0
+ (float) metadata.get("qwen3.attention.layer_norm_rms_epsilon"),
+ (float) metadata.get("qwen3.rope.freq_base")
);
-
- GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
- GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
-
- if (useTornadovm) {
- if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
- System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
- }
- return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
- } else {
- return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
- }
}
// @formatter:on
+ @Override
+ protected Pair<float[], float[]> precomputeRopeFrequencies(Qwen3Configuration config) {
+ return RoPE.precomputeFreqsCis(config.contextLengthModel(), config.numberOfHeadsKey(), config.ropeTheta(), false, 0, 0, 0, 0);
+ }
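Note the second argument: Qwen2 passes config.headSize(), while Qwen3 passes config.numberOfHeadsKey(), filled from qwen3.attention.key_length. Qwen3 decouples the attention head dimension from dim / n_heads, so the two can differ; a tiny sketch (shapes roughly Qwen3-0.6B, for orientation only):

    int dim = 1_024, nHeads = 16, keyLength = 128;
    int naiveHeadSize = dim / nHeads;  // 64 -- not what RoPE should span here
    int ropeDims = keyLength;          // 128 -- what precomputeFreqsCis receives above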
+
// @formatter:off
@Override
- public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
- GGMLTensorEntry outputWeight) {
- return new Qwen3TornadoWeights(
- loadTensorAsFloatArray(tokenEmbeddings),
- loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
- loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k_norm.weight")), // attnKNorm
- loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q_norm.weight")), // attnQNorm
- loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3
- floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
- FloatArray.fromArray(ropeFreqs.first()),
- FloatArray.fromArray(ropeFreqs.second()),
- loadTensorAsHalfFloatArray(outputWeight),
- outputWeight.ggmlType()
- );
+ protected Qwen3 createModel(Qwen3Configuration config, Tokenizer tokenizer, Weights weights) {
+ Map<String, Object> metadata = gguf.getMetadata();
+ boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename"));
+ // Qwen2.5-coder uses <|endoftext|> as stop-token.
+ ChatTokens chatTokens = isDeepSeekR1DistillQwen ? new ChatTokens("<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "")
+ : new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
+ return new Qwen3(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
}
- // @formatter:on
+ // @formatter:on
// @formatter:off
@Override
- public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries,
- Configuration config,
- Pair<float[], float[]> ropeFreqs,
- GGMLTensorEntry tokenEmbeddings,
- GGMLTensorEntry outputWeight) {
+ protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Qwen3Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight) {
float[] ropeFreqsReal = ropeFreqs.first();
float[] ropeFreqsImag = ropeFreqs.second();
+
+ final int nl = config.numberOfLayers();
+
return new Qwen3StandardWeights(
- loadQuantized(tokenEmbeddings),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // rms_att_weight
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")), // wq
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")), // wk
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")), // wv
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
-
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k_norm.weight")), // attnKNorm
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q_norm.weight")), // attnQNorm
-
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), //rms_ffn_weight
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3
- loadQuantized(tensorEntries.get("output_norm.weight")), // rms_final_weight
+ loadTensor(tokenEmbeddings),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // rms_att_weight
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), // wq
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), // wk
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), // wv
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k_norm.weight")), // attnKNorm
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q_norm.weight")), // attnQNorm
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), //rms_ffn_weight
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3
+ loadTensor(tensorEntries.get("output_norm.weight")), // rms_final_weight
new ArrayFloatTensor(ropeFreqsReal),
new ArrayFloatTensor(ropeFreqsImag),
tensorEntries.containsKey("output.weight")
- ? ModelLoader.loadQuantized(tensorEntries.get("output.weight"))
- : loadQuantized(tokenEmbeddings), // weights are shared
+ ? ModelLoader.loadTensor(tensorEntries.get("output.weight"))
+ : loadTensor(tokenEmbeddings), // weights are shared
null
);
}
// @formatter:on
+
+ // @formatter:off
+ @Override
+ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Qwen3Configuration config,
+ Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight) {
+ GGMLType ggmlType = outputWeight.ggmlType();
+
+ if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
+ System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")");
+ }
+
+ final int nl = config.numberOfLayers();
+
+ return new Qwen3TornadoWeights(
+ loadTornadoTensorAsFP32(tokenEmbeddings),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
+ // Qwen3-specific: attnKNorm and attnQNorm (always F32)
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k_norm.weight")), // fp32
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q_norm.weight")), // fp32
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), // fp32
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
+ loadTornadoTensor(tensorEntries.get("output_norm.weight")), // fp32
+ new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.first())),
+ new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.second())),
+ loadTornadoTensor(outputWeight),
+ ggmlType
+ );
+ }
+ // @formatter:on
}
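
The per-layer lambdas above all build names of the form "blk.<layer>.<suffix>", which is the GGUF tensor-naming convention. As a sketch (the helper below is hypothetical, not part of this patch), the repetition could be factored out:

    // Hypothetical helper; GGMLTensorEntry and the entries map come from this patch.
    import java.util.Map;
    import java.util.function.IntFunction;

    final class TensorNames {
        // Maps a layer index to the entry for "blk.<i>.<suffix>", e.g.
        // perLayer(entries, "attn_q.weight").apply(0) -> entries.get("blk.0.attn_q.weight")
        static IntFunction<GGMLTensorEntry> perLayer(Map<String, GGMLTensorEntry> entries, String suffix) {
            return i -> entries.get("blk." + i + "." + suffix);
        }
    }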
diff --git a/src/main/java/org/beehive/gpullama3/model/mistral/Mistral.java b/src/main/java/org/beehive/gpullama3/model/mistral/Mistral.java
index 8176b85b..931f4317 100644
--- a/src/main/java/org/beehive/gpullama3/model/mistral/Mistral.java
+++ b/src/main/java/org/beehive/gpullama3/model/mistral/Mistral.java
@@ -9,8 +9,8 @@
import org.beehive.gpullama3.model.AbstractModel;
import org.beehive.gpullama3.model.ModelType;
import org.beehive.gpullama3.model.format.ChatFormat;
-import org.beehive.gpullama3.tokenizer.impl.MistralTokenizer;
-import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
+import org.beehive.gpullama3.tokenizer.MistralTokenizer;
+import org.beehive.gpullama3.tokenizer.Tokenizer;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
import java.util.List;
diff --git a/src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java b/src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java
index 1ee4ce46..3328a55f 100644
--- a/src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java
+++ b/src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java
@@ -9,8 +9,8 @@
import org.beehive.gpullama3.model.AbstractModel;
import org.beehive.gpullama3.model.ModelType;
import org.beehive.gpullama3.model.format.ChatFormat;
-import org.beehive.gpullama3.tokenizer.impl.Phi3Tokenizer;
-import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
+import org.beehive.gpullama3.tokenizer.Phi3Tokenizer;
+import org.beehive.gpullama3.tokenizer.Tokenizer;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
import java.util.List;
diff --git a/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2.java b/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2.java
index e8fcb581..92fdf564 100644
--- a/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2.java
+++ b/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2.java
@@ -9,8 +9,8 @@
import org.beehive.gpullama3.model.AbstractModel;
import org.beehive.gpullama3.model.ModelType;
import org.beehive.gpullama3.model.format.ChatFormat;
-import org.beehive.gpullama3.tokenizer.impl.Qwen3Tokenizer;
-import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
+import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer;
+import org.beehive.gpullama3.tokenizer.Tokenizer;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
import java.util.List;
diff --git a/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java b/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java
index bf90c13d..cf16b3cc 100644
--- a/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java
+++ b/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java
@@ -9,8 +9,8 @@
import org.beehive.gpullama3.model.AbstractModel;
import org.beehive.gpullama3.model.ModelType;
import org.beehive.gpullama3.model.format.ChatFormat;
-import org.beehive.gpullama3.tokenizer.impl.Qwen3Tokenizer;
-import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
+import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer;
+import org.beehive.gpullama3.tokenizer.Tokenizer;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
import java.util.List;
diff --git a/src/main/java/org/beehive/gpullama3/core/types/Float16.java b/src/main/java/org/beehive/gpullama3/tensor/Float16.java
similarity index 62%
rename from src/main/java/org/beehive/gpullama3/core/types/Float16.java
rename to src/main/java/org/beehive/gpullama3/tensor/Float16.java
index 6639a41b..fb171317 100644
--- a/src/main/java/org/beehive/gpullama3/core/types/Float16.java
+++ b/src/main/java/org/beehive/gpullama3/tensor/Float16.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.core.types;
+package org.beehive.gpullama3.tensor;
public final class Float16 {
public static final int BYTES = 2;
diff --git a/src/main/java/org/beehive/gpullama3/core/model/tensor/GGMLTensorEntry.java b/src/main/java/org/beehive/gpullama3/tensor/GGMLTensorEntry.java
similarity index 67%
rename from src/main/java/org/beehive/gpullama3/core/model/tensor/GGMLTensorEntry.java
rename to src/main/java/org/beehive/gpullama3/tensor/GGMLTensorEntry.java
index 8098aa11..9af9b10f 100644
--- a/src/main/java/org/beehive/gpullama3/core/model/tensor/GGMLTensorEntry.java
+++ b/src/main/java/org/beehive/gpullama3/tensor/GGMLTensorEntry.java
@@ -1,6 +1,4 @@
-package org.beehive.gpullama3.core.model.tensor;
-
-import org.beehive.gpullama3.core.model.GGMLType;
+package org.beehive.gpullama3.tensor;
import java.lang.foreign.MemorySegment;
diff --git a/src/main/java/org/beehive/gpullama3/core/model/GGMLType.java b/src/main/java/org/beehive/gpullama3/tensor/GGMLType.java
similarity index 98%
rename from src/main/java/org/beehive/gpullama3/core/model/GGMLType.java
rename to src/main/java/org/beehive/gpullama3/tensor/GGMLType.java
index 972a4f52..f1888bb2 100644
--- a/src/main/java/org/beehive/gpullama3/core/model/GGMLType.java
+++ b/src/main/java/org/beehive/gpullama3/tensor/GGMLType.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.core.model;
+package org.beehive.gpullama3.tensor;
public enum GGMLType {
// Floating point types
diff --git a/src/main/java/org/beehive/gpullama3/core/model/GGUF.java b/src/main/java/org/beehive/gpullama3/tensor/GGUF.java
similarity index 64%
rename from src/main/java/org/beehive/gpullama3/core/model/GGUF.java
rename to src/main/java/org/beehive/gpullama3/tensor/GGUF.java
index c32cdc1d..9cdc5b7d 100644
--- a/src/main/java/org/beehive/gpullama3/core/model/GGUF.java
+++ b/src/main/java/org/beehive/gpullama3/tensor/GGUF.java
@@ -1,15 +1,14 @@
-package org.beehive.gpullama3.core.model;
+package org.beehive.gpullama3.tensor;
-import org.beehive.gpullama3.auxiliary.Timer;
-import org.beehive.gpullama3.core.model.tensor.FloatTensor;
-import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
-import org.beehive.gpullama3.core.types.MetadataValueType;
-import org.beehive.gpullama3.core.types.Pair;
+import org.beehive.gpullama3.tensor.standard.FloatTensor;
+import org.beehive.gpullama3.auxiliary.Pair;
+import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
+import java.lang.foreign.ValueLayout;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
@@ -20,7 +19,11 @@
import java.util.List;
import java.util.Map;
+import static java.nio.file.StandardOpenOption.READ;
+import static java.nio.file.StandardOpenOption.WRITE;
+
public final class GGUF {
+ private FileChannel fileChannel;
private static final int GGUF_MAGIC = 0x46554747;
private static final int DEFAULT_ALIGNMENT = 32; // must be a power of 2
private static final List<Integer> SUPPORTED_GGUF_VERSIONS = List.of(2, 3);
@@ -37,38 +40,159 @@ public final class GGUF {
private Map<String, GGUFTensorInfo> tensorInfos;
private long tensorDataOffset;
- public static GGUF loadModel(Path modelPath) throws IOException {
+ public static GGUF loadGGUFMetadata(Path modelPath) throws IOException {
// file existence check
if (!Files.exists(modelPath)) {
throw new FileNotFoundException("Model file not found: " + modelPath);
}
- // second check to make sure that nothing goes wrong during model loading
- try (FileChannel fileChannel = FileChannel.open(modelPath);
- ) {
+ // Open the file and keep the channel for later tensor mapping. READ+WRITE is
+ // required because loadTensorsTornado maps with MapMode.PRIVATE, which needs a
+ // channel opened for both reading and writing.
+ final FileChannel fileChannel;
+ try {
+ fileChannel = FileChannel.open(modelPath, READ, WRITE);
+ } catch (Exception e) {
+ throw new RuntimeException("Failed to open file channel for " + modelPath, e);
+ }
+
+ // Read and store the gguf metadata
+ try {
GGUF gguf = new GGUF();
+ gguf.fileChannel = fileChannel;
- gguf.loadModelImpl(fileChannel);
+ // The header of the file.
+ gguf.readHeader(fileChannel); // gguf_header_t header;
+ // Tensor infos, which can be used to locate the tensor data.
+ // gguf_tensor_info_t tensor_infos[header.tensor_count];
+ gguf.tensorInfos = HashMap.newHashMap(gguf.tensorCount);
+ for (int i = 0; i < gguf.tensorCount; ++i) {
+ GGUF.GGUFTensorInfo ti = gguf.readTensorInfo(fileChannel);
+ assert !gguf.tensorInfos.containsKey(ti.name);
+ gguf.tensorInfos.put(ti.name, ti);
+ }
+ // Padding to the nearest multiple of `ALIGNMENT`.
+ // uint8_t _padding[ALIGNMENT - (sizeof(header + tensor_infos) % ALIGNMENT)];
+ long _padding = (gguf.getAlignment() - (fileChannel.position() % gguf.getAlignment())) % gguf.getAlignment();
+ fileChannel.position(fileChannel.position() + _padding);
+ // Tensor data.
+ //
+ // This is arbitrary binary data corresponding to the weights of the model. This data should be close
+ // or identical to the data in the original model file, but may be different due to quantization or
+ // other optimizations for inference. Any such deviations should be recorded in the metadata or as
+ // part of the architecture definition.
+ //
+ // Each tensor's data must be stored within this array, and located through its `tensor_infos` entry.
+ // The offset of each tensor's data must be a multiple of `ALIGNMENT`, and the space between tensors
+ // should be padded to `ALIGNMENT` bytes.
+ // uint8_t tensor_data[];
+ gguf.tensorDataOffset = fileChannel.position();
return gguf;
} catch (Exception e) {
throw new RuntimeException("Unexpected error while loading GGUF model from " + modelPath, e);
}
}
- public static Map loadTensors(FileChannel fileChannel, long tensorDataOffset, Map tensorInfos) throws IOException {
+ /**
+ * Loads tensor data from a given file channel based on the tensor metadata information.
+ * The mapping is read-only and creates standard memory segments for each tensor.
+ *
+ * @param fileChannel the channel from which tensor storage is read
+ * @param tensorDataOffset the absolute byte offset of the GGUF tensor-data section
+ * @param tensorInfos metadata describing all GGUF tensors
+ * @return a map from tensor name to {@link GGMLTensorEntry} containing
+ * standard memory segments for each tensor
+ * @throws IOException if memory mapping fails or the channel cannot be read
+ */
+ public static Map<String, GGMLTensorEntry> loadTensorsStandard(FileChannel fileChannel, long tensorDataOffset, Map<String, GGUFTensorInfo> tensorInfos) throws IOException {
Arena arena = Arena.ofAuto();
- MemorySegment tensorData = fileChannel.map(FileChannel.MapMode.READ_ONLY, tensorDataOffset, fileChannel.size() - tensorDataOffset, arena);
+
+ // absolute file offset where the tensor-data section begins
+ long mappingOffset = tensorDataOffset;
+ // size of the entire tensor-data section
+ long mappingSize = fileChannel.size() - tensorDataOffset;
+
+ MemorySegment tensorData = fileChannel.map(FileChannel.MapMode.READ_ONLY, mappingOffset, mappingSize, arena);
+
Map<String, GGMLTensorEntry> tensorEntries = HashMap.newHashMap(tensorInfos.size());
+
for (Map.Entry<String, GGUFTensorInfo> entry : tensorInfos.entrySet()) {
GGUFTensorInfo ti = entry.getValue();
+
+ // skip rope_freqs.weight (not needed for inference)
+ if (ti.name().equals("rope_freqs.weight")) {
+ continue;
+ }
+
int numberOfElements = FloatTensor.numberOfElements(ti.dimensions());
int sizeInBytes = Math.toIntExact(ti.ggmlType().byteSizeFor(numberOfElements));
- MemorySegment memorySegment = tensorData.asSlice(ti.offset(), sizeInBytes);
+
+ // per-tensor slice offset; ti.offset() is relative to tensor-data start
+ long offset = ti.offset();
+
+ // per-tensor slice segment
+ MemorySegment memorySegment = tensorData.asSlice(offset, sizeInBytes);
+
tensorEntries.put(ti.name(), new GGMLTensorEntry(tensorData, ti.name(), ti.ggmlType(), ti.dimensions(), memorySegment));
}
return tensorEntries;
}
+ /**
+ * Loads GGUF tensor data using a TornadoVM-compatible memory layout.
+ *
+ * This method parses the GGUF tensor list and memory-maps each tensor
+ * in {@link TornadoNativeArray} layout directly from the underlying {@link FileChannel}.
+ * For compatibility with {@link TornadoNativeArray} layout, an additional header is required at
+ * the start of each tensor region. To satisfy this requirement, each tensor
+ * is mapped using {@link FileChannel.MapMode#PRIVATE} starting 16 bytes
+ * before the actual tensor position, providing a writable header region
+ * without modifying the underlying GGUF file.
+ *
+ * @param fileChannel the channel from which tensor storage is read
+ * @param tensorDataOffset the absolute byte offset of the GGUF tensor-data section
+ * @param tensorInfos metadata describing all GGUF tensors
+ * @return a map from tensor name to {@link GGMLTensorEntry} containing
+ * TornadoVM-compatible memory segments for each tensor
+ * @throws IOException if memory mapping fails or the channel cannot be read
+ */
+ public static Map<String, GGMLTensorEntry> loadTensorsTornado(FileChannel fileChannel, long tensorDataOffset, Map<String, GGUFTensorInfo> tensorInfos) throws IOException {
+
+ Arena arena = Arena.ofAuto();
+ Map<String, GGMLTensorEntry> tensorEntries = HashMap.newHashMap(tensorInfos.size());
+
+ for (Map.Entry<String, GGUFTensorInfo> entry : tensorInfos.entrySet()) {
+ GGUFTensorInfo ti = entry.getValue();
+
+ // skip rope_freqs.weight (not required for inference)
+ if (ti.name().equals("rope_freqs.weight")) {
+ continue;
+ }
+
+ int numberOfElements = FloatTensor.numberOfElements(ti.dimensions());
+ int sizeInBytes = Math.toIntExact(ti.ggmlType().byteSizeFor(numberOfElements));
+
+ // absolute tensor offset - relative to start of the file
+ long mappingOffset = tensorDataOffset + ti.offset();
+
+ // create memory segment in TornadoVM NativeArray layout:
+ // TornadoNativeArray.ARRAY_HEADER (16-byte) + tensor data
+ long headerBytes = TornadoNativeArray.ARRAY_HEADER;
+
+ // start 16 bytes before the tensor position to include header space
+ long offset = mappingOffset - headerBytes;
+ long size = sizeInBytes + headerBytes;
+ MemorySegment memorySegment = fileChannel.map(FileChannel.MapMode.PRIVATE, offset, size, arena);
+
+ // zero out the 16-byte header
+ for (int i = 0; i < headerBytes; i++) {
+ memorySegment.set(ValueLayout.JAVA_BYTE, i, (byte) 0);
+ }
+
+ // store tornado-compatible segment
+ tensorEntries.put(ti.name(), new GGMLTensorEntry(memorySegment, ti.name(), ti.ggmlType(), ti.dimensions(), memorySegment));
+ }
+ return tensorEntries;
+ }
+
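
To make the PRIVATE-mapping arithmetic above concrete, a worked example with illustrative numbers (only TornadoNativeArray.ARRAY_HEADER == 16 is assumed from the code above):

    long tensorDataOffset = 1_048_576L;               // start of the tensor-data section
    long tiOffset = 4_096L;                           // ti.offset(), relative to that section
    long headerBytes = 16L;                           // TornadoNativeArray.ARRAY_HEADER
    long mappingOffset = tensorDataOffset + tiOffset; // 1_052_672: first tensor byte in the file
    long start = mappingOffset - headerBytes;         // 1_052_656: where the mapping begins
    // The first 16 bytes of the mapping are zeroed to form the TornadoNativeArray header;
    // MapMode.PRIVATE is copy-on-write, so those writes never reach the GGUF file on disk.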
public Map<String, GGUFTensorInfo> getTensorInfos() {
return tensorInfos;
}
@@ -81,34 +205,8 @@ public Map<String, Object> getMetadata() {
return metadata;
}
- private void loadModelImpl(FileChannel fileChannel) throws IOException {
- // The header of the file.
- readHeader(fileChannel); // gguf_header_t header;
- // Tensor infos, which can be used to locate the tensor data.
- // gguf_tensor_info_t tensor_infos[header.tensor_count];
- this.tensorInfos = HashMap.newHashMap(tensorCount);
- for (int i = 0; i < tensorCount; ++i) {
- GGUF.GGUFTensorInfo ti = readTensorInfo(fileChannel);
- assert !tensorInfos.containsKey(ti.name);
- tensorInfos.put(ti.name, ti);
- }
- // Padding to the nearest multiple of `ALIGNMENT`.
- // uint8_t _padding[ALIGNMENT - (sizeof(header + tensor_infos) % ALIGNMENT)];
- //long _padding = -fileChannel.position() & (ALIGNMENT - 1);
- long _padding = getAlignment() - (fileChannel.position() % getAlignment());
- fileChannel.position(fileChannel.position() + _padding);
- // Tensor data.
- //
- // This is arbitrary binary data corresponding to the weights of the model. This data should be close
- // or identical to the data in the original model file, but may be different due to quantization or
- // other optimizations for inference. Any such deviations should be recorded in the metadata or as
- // part of the architecture definition.
- //
- // Each tensor's data must be stored within this array, and located through its `tensor_infos` entry.
- // The offset of each tensor's data must be a multiple of `ALIGNMENT`, and the space between tensors
- // should be padded to `ALIGNMENT` bytes.
- // uint8_t tensor_data[];
- this.tensorDataOffset = fileChannel.position();
+ public FileChannel getFileChannel() {
+ return fileChannel;
}
private GGMLType readGGMLType(FileChannel fileChannel) throws IOException {
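
Taken together, the reworked GGUF API separates metadata parsing from tensor mapping. A minimal usage sketch, assuming a simple getTensorDataOffset() accessor for the tensorDataOffset field (the accessor itself is not shown in this hunk):

    GGUF gguf = GGUF.loadGGUFMetadata(Path.of("model.gguf"));
    // CPU path: one read-only mapping over the whole tensor-data section
    Map<String, GGMLTensorEntry> cpuTensors = GGUF.loadTensorsStandard(
            gguf.getFileChannel(), gguf.getTensorDataOffset(), gguf.getTensorInfos());
    // GPU path: per-tensor copy-on-write mappings carrying a TornadoNativeArray header
    Map<String, GGMLTensorEntry> gpuTensors = GGUF.loadTensorsTornado(
            gguf.getFileChannel(), gguf.getTensorDataOffset(), gguf.getTensorInfos());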
diff --git a/src/main/java/org/beehive/gpullama3/core/types/MetadataValueType.java b/src/main/java/org/beehive/gpullama3/tensor/MetadataValueType.java
similarity index 97%
rename from src/main/java/org/beehive/gpullama3/core/types/MetadataValueType.java
rename to src/main/java/org/beehive/gpullama3/tensor/MetadataValueType.java
index 911f364d..f7e08346 100644
--- a/src/main/java/org/beehive/gpullama3/core/types/MetadataValueType.java
+++ b/src/main/java/org/beehive/gpullama3/tensor/MetadataValueType.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.core.types;
+package org.beehive.gpullama3.tensor;
public enum MetadataValueType {
// The value is a 8-bit unsigned integer.
diff --git a/src/main/java/org/beehive/gpullama3/core/model/tensor/ArrayFloatTensor.java b/src/main/java/org/beehive/gpullama3/tensor/standard/ArrayFloatTensor.java
similarity index 93%
rename from src/main/java/org/beehive/gpullama3/core/model/tensor/ArrayFloatTensor.java
rename to src/main/java/org/beehive/gpullama3/tensor/standard/ArrayFloatTensor.java
index 1214967f..d25623cc 100644
--- a/src/main/java/org/beehive/gpullama3/core/model/tensor/ArrayFloatTensor.java
+++ b/src/main/java/org/beehive/gpullama3/tensor/standard/ArrayFloatTensor.java
@@ -1,6 +1,6 @@
-package org.beehive.gpullama3.core.model.tensor;
+package org.beehive.gpullama3.tensor.standard;
-import org.beehive.gpullama3.core.model.GGMLType;
+import org.beehive.gpullama3.tensor.GGMLType;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorSpecies;
diff --git a/src/main/java/org/beehive/gpullama3/core/model/tensor/F16FloatTensor.java b/src/main/java/org/beehive/gpullama3/tensor/standard/FP16FloatTensor.java
similarity index 92%
rename from src/main/java/org/beehive/gpullama3/core/model/tensor/F16FloatTensor.java
rename to src/main/java/org/beehive/gpullama3/tensor/standard/FP16FloatTensor.java
index 9e7ec8bf..88587072 100644
--- a/src/main/java/org/beehive/gpullama3/core/model/tensor/F16FloatTensor.java
+++ b/src/main/java/org/beehive/gpullama3/tensor/standard/FP16FloatTensor.java
@@ -1,6 +1,6 @@
-package org.beehive.gpullama3.core.model.tensor;
+package org.beehive.gpullama3.tensor.standard;
-import org.beehive.gpullama3.core.model.GGMLType;
+import org.beehive.gpullama3.tensor.GGMLType;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorOperators;
@@ -9,12 +9,12 @@
import java.lang.foreign.MemorySegment;
import java.nio.ByteOrder;
-public final class F16FloatTensor extends FloatTensor {
+public final class FP16FloatTensor extends FloatTensor {
final int size;
final MemorySegment memorySegment;
- public F16FloatTensor(int size, MemorySegment memorySegment) {
+ public FP16FloatTensor(int size, MemorySegment memorySegment) {
this.size = size;
this.memorySegment = memorySegment;
}
@@ -59,7 +59,7 @@ public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) {
}
}
- private static float vectorDot(F16FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) {
+ private static float vectorDot(FP16FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) {
assert S_SPECIES_HALF.length() == F_SPECIES.length();
FloatVector val = FloatVector.zero(F_SPECIES);
int upperBound = F_SPECIES.loopBound(size);
diff --git a/src/main/java/org/beehive/gpullama3/core/model/tensor/F32FloatTensor.java b/src/main/java/org/beehive/gpullama3/tensor/standard/FP32FloatTensor.java
similarity index 82%
rename from src/main/java/org/beehive/gpullama3/core/model/tensor/F32FloatTensor.java
rename to src/main/java/org/beehive/gpullama3/tensor/standard/FP32FloatTensor.java
index f188e9f5..2deff33e 100644
--- a/src/main/java/org/beehive/gpullama3/core/model/tensor/F32FloatTensor.java
+++ b/src/main/java/org/beehive/gpullama3/tensor/standard/FP32FloatTensor.java
@@ -1,17 +1,17 @@
-package org.beehive.gpullama3.core.model.tensor;
+package org.beehive.gpullama3.tensor.standard;
-import org.beehive.gpullama3.core.model.GGMLType;
+import org.beehive.gpullama3.tensor.GGMLType;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorSpecies;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
-public final class F32FloatTensor extends FloatTensor {
+public final class FP32FloatTensor extends FloatTensor {
final int size;
final MemorySegment segment;
- public F32FloatTensor(int size, MemorySegment segment) {
+ public FP32FloatTensor(int size, MemorySegment segment) {
this.size = size;
this.segment = segment;
}
diff --git a/src/main/java/org/beehive/gpullama3/core/model/tensor/FloatTensor.java b/src/main/java/org/beehive/gpullama3/tensor/standard/FloatTensor.java
similarity index 98%
rename from src/main/java/org/beehive/gpullama3/core/model/tensor/FloatTensor.java
rename to src/main/java/org/beehive/gpullama3/tensor/standard/FloatTensor.java
index f0c7f2cf..d91ab964 100644
--- a/src/main/java/org/beehive/gpullama3/core/model/tensor/FloatTensor.java
+++ b/src/main/java/org/beehive/gpullama3/tensor/standard/FloatTensor.java
@@ -1,7 +1,7 @@
-package org.beehive.gpullama3.core.model.tensor;
+package org.beehive.gpullama3.tensor.standard;
import org.beehive.gpullama3.auxiliary.Parallel;
-import org.beehive.gpullama3.core.model.GGMLType;
+import org.beehive.gpullama3.tensor.GGMLType;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorShape;
import jdk.incubator.vector.VectorSpecies;
diff --git a/src/main/java/org/beehive/gpullama3/core/model/tensor/Q4_0FloatTensor.java b/src/main/java/org/beehive/gpullama3/tensor/standard/Q4_0FloatTensor.java
similarity index 97%
rename from src/main/java/org/beehive/gpullama3/core/model/tensor/Q4_0FloatTensor.java
rename to src/main/java/org/beehive/gpullama3/tensor/standard/Q4_0FloatTensor.java
index 8396e611..eadfc1ab 100644
--- a/src/main/java/org/beehive/gpullama3/core/model/tensor/Q4_0FloatTensor.java
+++ b/src/main/java/org/beehive/gpullama3/tensor/standard/Q4_0FloatTensor.java
@@ -1,8 +1,8 @@
-package org.beehive.gpullama3.core.model.tensor;
+package org.beehive.gpullama3.tensor.standard;
import org.beehive.gpullama3.LlamaApp;
-import org.beehive.gpullama3.core.model.GGMLType;
-import org.beehive.gpullama3.core.types.Float16;
+import org.beehive.gpullama3.tensor.GGMLType;
+import org.beehive.gpullama3.tensor.Float16;
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorOperators;
diff --git a/src/main/java/org/beehive/gpullama3/core/model/tensor/Q8_0FloatTensor.java b/src/main/java/org/beehive/gpullama3/tensor/standard/Q8_0FloatTensor.java
similarity index 97%
rename from src/main/java/org/beehive/gpullama3/core/model/tensor/Q8_0FloatTensor.java
rename to src/main/java/org/beehive/gpullama3/tensor/standard/Q8_0FloatTensor.java
index 63a214af..9067bde0 100644
--- a/src/main/java/org/beehive/gpullama3/core/model/tensor/Q8_0FloatTensor.java
+++ b/src/main/java/org/beehive/gpullama3/tensor/standard/Q8_0FloatTensor.java
@@ -1,8 +1,8 @@
-package org.beehive.gpullama3.core.model.tensor;
+package org.beehive.gpullama3.tensor.standard;
-import org.beehive.gpullama3.core.model.GGMLType;
-import org.beehive.gpullama3.core.types.Float16;
+import org.beehive.gpullama3.tensor.GGMLType;
+import org.beehive.gpullama3.tensor.Float16;
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorOperators;
diff --git a/src/main/java/org/beehive/gpullama3/tensor/tornado/FP16TornadoTensor.java b/src/main/java/org/beehive/gpullama3/tensor/tornado/FP16TornadoTensor.java
new file mode 100644
index 00000000..bcf1e3df
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tensor/tornado/FP16TornadoTensor.java
@@ -0,0 +1,30 @@
+package org.beehive.gpullama3.tensor.tornado;
+
+import org.beehive.gpullama3.tensor.GGMLType;
+import uk.ac.manchester.tornado.api.types.HalfFloat;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
+
+import java.lang.foreign.MemorySegment;
+
+public class FP16TornadoTensor extends TornadoTensor {
+ private final HalfFloatArray tornadoNativeArray;
+
+ public FP16TornadoTensor(HalfFloatArray halfFloatArray) {
+ this.tornadoNativeArray = halfFloatArray;
+ }
+
+ public static FP16TornadoTensor fromTornadoMemorySegment(MemorySegment segment) {
+ return new FP16TornadoTensor(HalfFloatArray.fromSegmentShallow(segment));
+ }
+
+ @Override
+ public HalfFloatArray asHalfFloatArray() {
+ return tornadoNativeArray;
+ }
+
+ @Override
+ public GGMLType type() {
+ return GGMLType.F16;
+ }
+}
+
diff --git a/src/main/java/org/beehive/gpullama3/tensor/tornado/FP32TornadoTensor.java b/src/main/java/org/beehive/gpullama3/tensor/tornado/FP32TornadoTensor.java
new file mode 100644
index 00000000..a1520c36
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tensor/tornado/FP32TornadoTensor.java
@@ -0,0 +1,29 @@
+package org.beehive.gpullama3.tensor.tornado;
+
+import org.beehive.gpullama3.tensor.GGMLType;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+
+import java.lang.foreign.MemorySegment;
+
+public class FP32TornadoTensor extends TornadoTensor {
+ private final FloatArray tornadoNativeArray;
+
+ public FP32TornadoTensor(FloatArray floatArray) {
+ this.tornadoNativeArray = floatArray;
+ }
+
+ public static FP32TornadoTensor fromTornadoMemorySegment(MemorySegment segment) {
+ return new FP32TornadoTensor(FloatArray.fromSegmentShallow(segment));
+ }
+
+ @Override
+ public FloatArray asFloatArray() {
+ return tornadoNativeArray;
+ }
+
+ @Override
+ public GGMLType type() {
+ return GGMLType.F32;
+ }
+
+}
diff --git a/src/main/java/org/beehive/gpullama3/tensor/tornado/Q8_0TornadoTensor.java b/src/main/java/org/beehive/gpullama3/tensor/tornado/Q8_0TornadoTensor.java
new file mode 100644
index 00000000..296e7bfa
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tensor/tornado/Q8_0TornadoTensor.java
@@ -0,0 +1,195 @@
+package org.beehive.gpullama3.tensor.tornado;
+
+import org.beehive.gpullama3.tensor.GGMLTensorEntry;
+import org.beehive.gpullama3.tensor.GGMLType;
+import org.beehive.gpullama3.tensor.standard.FloatTensor;
+import uk.ac.manchester.tornado.api.types.HalfFloat;
+import uk.ac.manchester.tornado.api.types.arrays.*;
+
+import java.lang.foreign.MemorySegment;
+import java.lang.foreign.ValueLayout;
+import java.nio.ByteOrder;
+import java.util.concurrent.*;
+import java.util.stream.IntStream;
+
+public class Q8_0TornadoTensor extends TornadoTensor {
+
+ private final int size;
+ private final HalfFloatArray scales; // One per 32-element block
+ private final Int8Array quants; // Quantized int8 values
+ private MemorySegment segment;
+
+ public Q8_0TornadoTensor(int size, HalfFloatArray scales, Int8Array quants, MemorySegment segment) {
+ this.size = size;
+ this.scales = scales;
+ this.quants = quants;
+ this.segment = segment;
+ }
+
+ public int getSize() {
+ return size;
+ }
+
+ /**
+ * Returns the scale factors for GPU kernels.
+ *
+ * @return HalfFloatArray containing fp16 scale factors
+ */
+ public HalfFloatArray getScales() {
+ return scales;
+ }
+
+ /**
+ * Returns the quantized values for GPU kernels.
+ *
+ * @return Int8Array containing quantized int8 values
+ */
+ public Int8Array getQuants() {
+ return quants;
+ }
+
+ @Override
+ public GGMLType type() {
+ return GGMLType.Q8_0;
+ }
+
+ public MemorySegment asMemorySegment() {
+ return segment;
+ }
+
+ /**
+ * Dequantizes and returns a single float value.
+ *
+ * @param index Element index
+ * @return Dequantized float value
+ */
+ public float getFloat(int index) {
+ assert 0 <= index;
+ int blockIdx = index / GGMLType.Q8_0.getBlockSize();
+ float scale = scales.get(blockIdx).getFloat32();
+ byte quant = quants.get(index);
+ return quant * scale;
+ }
+
+ /**
+ * Creates a Q8_0TornadoTensor from a GGMLTensorEntry, keeping the quantized representation.
+ */
+ public static Q8_0TornadoTensor createAsQ8_0(GGMLTensorEntry entry) {
+ if (entry.ggmlType() != GGMLType.Q8_0) {
+ throw new IllegalArgumentException("Expected Q8_0 tensor, got: " + entry.ggmlType() + " for tensor: " + entry.name());
+ }
+
+ int[] shape = entry.shape();
+ int size = FloatTensor.numberOfElements(shape);
+ int numBlocks = size / GGMLType.Q8_0.getBlockSize();
+
+ if (size % GGMLType.Q8_0.getBlockSize() != 0) {
+ throw new IllegalArgumentException("Q8_0 tensor size must be multiple of " + GGMLType.Q8_0.getBlockSize() + ", got: " + size + " for tensor: " + entry.name());
+ }
+
+ // TODO: fix Q8_0 loading in the tornado layout;
+ // for now we work around it by stripping the
+ // tornado header from the memory segment
+ MemorySegment q8Segment = entry.memorySegment().asSlice(TornadoNativeArray.ARRAY_HEADER);
+
+ // allocate the arrays for quantized data (int8) and scales (fp16)
+ HalfFloatArray scales = new HalfFloatArray(numBlocks);
+ Int8Array quants = new Int8Array(size);
+
+ // unpack Q8_0 blocks: [2 bytes fp16 scale][32 bytes int8 quants]
+ ValueLayout.OfShort shortLayout = ValueLayout.JAVA_SHORT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
+ ValueLayout.OfByte byteLayout = ValueLayout.JAVA_BYTE;
+
+ // element-wise copy and unpack from MemorySegment to HalfFloatArray scales and Int8Array quants
+ // use parallel streams and unroll inner loop for better performance
+ IntStream.range(0, numBlocks)
+ .parallel()
+ .forEach(block -> {
+ // TODO: use GGML type method for the 34L size
+ long blockOffset = block * 34L; // 34 bytes per block
+
+ // read fp16 scale (first 2 bytes of block)
+ short scaleRaw = q8Segment.get(shortLayout, blockOffset);
+ scales.set(block, new HalfFloat(scaleRaw));
+ int blockStart = block * 32;
+
+ // read 32 int8 quantized values (remaining bytes of block)
+ // TODO: use GGML type method for the 32 size
+ for (int i = 0; i < 32; i += 4) {
+ // unroll inner loop for better performance
+ byte q0 = q8Segment.get(byteLayout, blockOffset + 2 + i);
+ byte q1 = q8Segment.get(byteLayout, blockOffset + 2 + i + 1);
+ byte q2 = q8Segment.get(byteLayout, blockOffset + 2 + i + 2);
+ byte q3 = q8Segment.get(byteLayout, blockOffset + 2 + i + 3);
+
+ quants.set(blockStart + i, q0);
+ quants.set(blockStart + i + 1, q1);
+ quants.set(blockStart + i + 2, q2);
+ quants.set(blockStart + i + 3, q3);
+ }
+ });
+
+ return new Q8_0TornadoTensor(size, scales, quants, q8Segment);
+ }
+
+ /**
+ * Creates an {@link FP32TornadoTensor} by fully dequantizing a Q8_0 GGMLTensorEntry.
+ * NOTE: workaround to support FP32 inference on quantized weights.
+ */
+ public static FP32TornadoTensor createAsFP32(GGMLTensorEntry entry) {
+ if (entry.ggmlType() != GGMLType.Q8_0) {
+ throw new IllegalArgumentException("Expected Q8_0 tensor, got: " + entry.ggmlType() + " for tensor: " + entry.name());
+ }
+
+ int[] shape = entry.shape();
+ int size = FloatTensor.numberOfElements(shape);
+ int numBlocks = size / GGMLType.Q8_0.getBlockSize();
+
+ if (size % GGMLType.Q8_0.getBlockSize() != 0) {
+ throw new IllegalArgumentException("Q8_0 tensor size must be multiple of " + GGMLType.Q8_0.getBlockSize() + ", got: " + size + " for tensor: " + entry.name());
+ }
+
+ // TODO: fix Q8_0 loading in the tornado layout;
+ // for now we work around it by stripping the
+ // tornado header from the memory segment
+ MemorySegment q8Segment = entry.memorySegment().asSlice(TornadoNativeArray.ARRAY_HEADER);
+
+ // allocate the FloatArray to store the result
+ FloatArray floatArray = new FloatArray(size);
+
+ // unpack Q8_0 blocks: [2 bytes fp16 scale][32 bytes int8 quants]
+ ValueLayout.OfShort shortLayout = ValueLayout.JAVA_SHORT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
+ ValueLayout.OfByte byteLayout = ValueLayout.JAVA_BYTE;
+
+ // element-wise dequantization and copy from MemorySegment to FloatArray
+ // use parallel streams and unroll inner loop for better performance
+ IntStream.range(0, numBlocks)
+ .parallel()
+ .forEach(block -> {
+ // TODO: use GGML type method for the 34L size
+ long blockOffset = block * 34L; // 34 bytes per block
+
+ // read fp16 scale (first 2 bytes of block) and convert to float
+ short scaleRaw = q8Segment.get(shortLayout, blockOffset);
+ float scale = Float.float16ToFloat(scaleRaw);
+ int blockStart = block * 32;
+
+ // read 32 int8 quantized values (remaining bytes of block)
+ // TODO: use GGML type method for the 32 size
+ for (int i = 0; i < 32; i += 4) {
+ // unroll inner loop for better performance
+ byte q0 = q8Segment.get(byteLayout, blockOffset + 2 + i);
+ byte q1 = q8Segment.get(byteLayout, blockOffset + 2 + i + 1);
+ byte q2 = q8Segment.get(byteLayout, blockOffset + 2 + i + 2);
+ byte q3 = q8Segment.get(byteLayout, blockOffset + 2 + i + 3);
+
+ floatArray.set(blockStart + i, q0 * scale);
+ floatArray.set(blockStart + i + 1, q1 * scale);
+ floatArray.set(blockStart + i + 2, q2 * scale);
+ floatArray.set(blockStart + i + 3, q3 * scale);
+ }
+ });
+
+ return new FP32TornadoTensor(floatArray);
+ }
+}
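
Both factory methods above unpack the same 34-byte Q8_0 block layout. A self-contained sketch of that layout using only the FFM API (no TornadoVM types), for reference:

    import java.lang.foreign.MemorySegment;
    import java.lang.foreign.ValueLayout;
    import java.nio.ByteOrder;

    // One Q8_0 block: [2-byte little-endian fp16 scale][32 int8 quants] = 34 bytes.
    static float[] dequantizeBlock(MemorySegment block) {
        var shortLE = ValueLayout.JAVA_SHORT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
        float scale = Float.float16ToFloat(block.get(shortLE, 0));
        float[] out = new float[32];
        for (int i = 0; i < 32; i++) {
            out[i] = block.get(ValueLayout.JAVA_BYTE, 2 + i) * scale;
        }
        return out;
    }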
diff --git a/src/main/java/org/beehive/gpullama3/tensor/tornado/TornadoTensor.java b/src/main/java/org/beehive/gpullama3/tensor/tornado/TornadoTensor.java
new file mode 100644
index 00000000..30ae9d15
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tensor/tornado/TornadoTensor.java
@@ -0,0 +1,51 @@
+package org.beehive.gpullama3.tensor.tornado;
+
+import org.beehive.gpullama3.tensor.GGMLType;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.Int8Array;
+
+/**
+ * Base class for TornadoVM-compatible tensor types.
+ * These tensors wrap TornadoVM native arrays for GPU execution.
+ */
+public abstract class TornadoTensor {
+
+ public abstract GGMLType type();
+
+ /**
+ * Get as FloatArray (for F32 tensors).
+ *
+ * @throws UnsupportedOperationException if not F32
+ */
+ public FloatArray asFloatArray() {
+ throw new UnsupportedOperationException("Not a FloatArray tensor: " + this.getClass().getSimpleName());
+ }
+
+ /**
+ * Get as HalfFloatArray (for F16 tensors).
+ *
+ * @throws UnsupportedOperationException if not F16
+ */
+ public HalfFloatArray asHalfFloatArray() {
+ throw new UnsupportedOperationException("Not a HalfFloatArray tensor: " + this.getClass().getSimpleName());
+ }
+
+ /**
+ * Get quantized scales (for Q8_0 tensors).
+ *
+ * @throws UnsupportedOperationException if not quantized
+ */
+ public HalfFloatArray getScales() {
+ throw new UnsupportedOperationException("Not a quantized tensor: " + this.getClass().getSimpleName());
+ }
+
+ /**
+ * Get quantized values (for Q8_0 tensors).
+ *
+ * @throws UnsupportedOperationException if not quantized
+ */
+ public Int8Array getQuants() {
+ throw new UnsupportedOperationException("Not a quantized tensor: " + this.getClass().getSimpleName());
+ }
+}
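
The base class gives call sites one handle over all representations and fails fast on a mismatch. A short usage sketch (the MemorySegment is assumed to already be in TornadoNativeArray layout):

    TornadoTensor t = FP32TornadoTensor.fromTornadoMemorySegment(segment);
    FloatArray data = t.asFloatArray(); // OK: FP32TornadoTensor overrides asFloatArray()
    // t.asHalfFloatArray();            // would throw UnsupportedOperationException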
diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/impl/LlamaTokenizer.java b/src/main/java/org/beehive/gpullama3/tokenizer/LlamaTokenizer.java
similarity index 86%
rename from src/main/java/org/beehive/gpullama3/tokenizer/impl/LlamaTokenizer.java
rename to src/main/java/org/beehive/gpullama3/tokenizer/LlamaTokenizer.java
index 9575ff76..36a78f1e 100644
--- a/src/main/java/org/beehive/gpullama3/tokenizer/impl/LlamaTokenizer.java
+++ b/src/main/java/org/beehive/gpullama3/tokenizer/LlamaTokenizer.java
@@ -1,10 +1,15 @@
-package org.beehive.gpullama3.tokenizer.impl;
+package org.beehive.gpullama3.tokenizer;
-import org.beehive.gpullama3.core.types.Pair;
-import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
+import org.beehive.gpullama3.auxiliary.Pair;
import java.nio.charset.StandardCharsets;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
@@ -13,19 +18,18 @@
/**
* GPT-2-style BPE tokenizer (even though it's called "llama") with an explicit merges list.
*
- * BPE (Byte Pair Encoding):
- * A sub-word tokenization algorithm that iteratively merges the most frequent pairs of symbols in a corpus to build a vocabulary of common character sequences.
+ * BPE (Byte Pair Encoding): A sub-word tokenization algorithm that iteratively merges the most frequent pairs of symbols in a corpus to build a vocabulary of common character sequences.
*
- * GPT-2-style tokenization:
- * Applies BPE at the byte level, ensuring all UTF-8 inputs are representable and using tokens that preserve leading spaces (e.g., 'Ġthe').
+ * GPT-2-style tokenization: Applies BPE at the byte level, ensuring all UTF-8 inputs are representable and using tokens that preserve leading spaces (e.g., 'Ġthe').
*
- * Explicit merges list:
- * A fixed sequence of learned merge rules that deterministically reconstructs the tokenizer’s vocabulary during inference without retraining.
+ * Explicit merges list: A fixed sequence of learned merge rules that deterministically reconstructs the tokenizer’s vocabulary during inference without retraining.
*
* Based on minbpe, algorithmically follows along the
* GPT 2 tokenizer
*/
public class LlamaTokenizer implements Tokenizer {
+ static final Map<Integer, Integer> BYTE_ENCODER = bytesToUnicode();
+ static final Map<Integer, Integer> BYTE_DECODER = BYTE_ENCODER.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
private static final String LLAMA_3_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
// general fields
private final Pattern compiledPattern;
@@ -34,28 +38,6 @@ public class LlamaTokenizer implements Tokenizer {
private final Map<Pair<Integer, Integer>, Integer> merges;
private final Map<String, Integer> specialTokens;
- public String regexPattern() {
- if (compiledPattern == null) {
- return null;
- }
- return compiledPattern.pattern();
- }
-
- @Override
- public Map getSpecialTokens() {
- return specialTokens;
- }
-
- @Override
- public boolean isSpecialToken(int tokenIndex) {
- return specialTokens.containsValue(tokenIndex);
- }
-
- @Override
- public boolean shouldDisplayToken(int token) {
- return !isSpecialToken(token);
- }
-
public LlamaTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
// load from metadata
String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges");
@@ -83,16 +65,85 @@ public LlamaTokenizer(Map metadata, Vocabulary vocabulary) {
}
}
+ private static List<String> findAll(Pattern pattern, String text) {
+ List<String> allMatches = new ArrayList<>();
+ Matcher matcher = pattern.matcher(text);
+ while (matcher.find()) {
+ allMatches.add(matcher.group());
+ }
+ return allMatches;
+ }
+
+ private static List<Integer> merge(List<Integer> ids, Pair<Integer, Integer> pair, int idx) {
+ List<Integer> newids = new ArrayList<>();
+ int i = 0;
+ while (i < ids.size()) {
+ // if not at the very last position AND the pair matches, replace it
+ if (ids.get(i).equals(pair.first()) && i < ids.size() - 1 && ids.get(i + 1).equals(pair.second())) {
+ newids.add(idx);
+ i += 2;
+ } else {
+ newids.add(ids.get(i));
+ i += 1;
+ }
+ }
+ return newids;
+ }
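
For example, assuming Pair exposes a (first, second) constructor as used elsewhere in this patch, merging the pair (1, 2) into token 99 rewrites [1, 2, 3, 1, 2] as [99, 3, 99]:

    List<Integer> merged = merge(List.of(1, 2, 3, 1, 2), new Pair<>(1, 2), 99);
    // merged.equals(List.of(99, 3, 99))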
+
+ /**
+ * Returns a mapping from UTF-8 bytes to unicode code points. The reversible BPE codes work on unicode strings, which means you need a large number of unicode characters in your vocab if
+ * you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. This is a significant percentage of your normal, say, 32K BPE vocab.
+ * To avoid that, we want lookup tables between UTF-8 bytes and unicode strings, and we avoid mapping to whitespace/control characters the BPE code barfs on.
+ */
+ private static Map<Integer, Integer> bytesToUnicode() {
+ List<Integer> bs = new ArrayList<>();
+ IntStream.rangeClosed('!', '~').forEach(bs::add);
+ IntStream.rangeClosed('¡', '¬').forEach(bs::add);
+ IntStream.rangeClosed('®', 'ÿ').forEach(bs::add);
+
+ List<Integer> cs = new ArrayList<>(bs);
+ int n = 0;
+ for (int b = 0; b < 256; ++b) {
+ if (!bs.contains(b)) {
+ bs.add(b);
+ cs.add(256 + n);
+ n += 1;
+ }
+ }
+
+ // return dict(zip(bs, cs))
+ return IntStream.range(0, bs.size()).boxed().collect(Collectors.toMap(bs::get, cs::get));
+ }
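
Concretely, the 188 printable bytes map to themselves and the remaining 68 bytes are shifted into the 256+ range; for instance the space byte 0x20 maps to 256 + 32 = 0x120 ('Ġ'), which is why GPT-2-style tokens render leading spaces as 'Ġ':

    Map<Integer, Integer> enc = bytesToUnicode();
    assert enc.get((int) 'A') == (int) 'A'; // printable bytes: identity
    assert enc.get(0x20) == 0x120;          // space -> 'Ġ'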
+
+ public String regexPattern() {
+ if (compiledPattern == null) {
+ return null;
+ }
+ return compiledPattern.pattern();
+ }
+
+ @Override
+ public Map<String, Integer> getSpecialTokens() {
+ return specialTokens;
+ }
+
+ @Override
+ public boolean isSpecialToken(int tokenIndex) {
+ return specialTokens.containsValue(tokenIndex);
+ }
+
+ @Override
+ public boolean shouldDisplayToken(int token) {
+ return !isSpecialToken(token);
+ }
+
private int[] encodeImpl(String text) {
return encode(text, Set.of()).stream().mapToInt(i -> i).toArray();
}
/**
- * Unlike {@link #encodeOrdinary(String)}, this function handles special tokens.
- * allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens
- * if none_raise, then an error is raised if any special token is encountered in text
- * this is the default tiktoken behavior right now as well
- * any other behavior is either annoying, or a major footgun.
+ * Unlike {@link #encodeOrdinary(String)}, this function handles special tokens. allowed_special can be "all"|"none"|"none_raise" or a custom set of special tokens; if "none_raise", an error is
+ * raised if any special token is encountered in text. This is the default tiktoken behavior right now as well; any other behavior is either annoying, or a major footgun.
*/
public List<Integer> encode(String text, Set<String> allowedSpecial) {
// decode the user desire w.r.t. handling of special tokens
@@ -108,10 +159,7 @@ public List encode(String text, Set allowedSpecial) {
// based on the occurrence of any exact match with any of the special tokens
// we can use re.split for this. note that surrounding the pattern with ()
// makes it into a capturing group, so the special tokens will be included
- String specialPattern = special
- .stream()
- .map(Pattern::quote)
- .collect(Collectors.joining("|", "(", ")"));
+ String specialPattern = special.stream().map(Pattern::quote).collect(Collectors.joining("|", "(", ")"));
String[] specialChunks = text.split(specialPattern);
// now all the special characters are separated from the rest of the text
@@ -129,15 +177,6 @@ public List encode(String text, Set allowedSpecial) {
return ids;
}
- private static List findAll(Pattern pattern, String text) {
- List allMatches = new ArrayList<>();
- Matcher matcher = pattern.matcher(text);
- while (matcher.find()) {
- allMatches.add(matcher.group());
- }
- return allMatches;
- }
-
/**
* Encoding that ignores any special tokens.
*/
@@ -189,22 +228,6 @@ private List encodeChunk(String chunk) {
return ids;
}
- private static List merge(List ids, Pair pair, int idx) {
- List newids = new ArrayList<>();
- int i = 0;
- while (i < ids.size()) {
- // if not at the very last position AND the pair matches, replace it
- if (ids.get(i).equals(pair.first()) && i < ids.size() - 1 && ids.get(i + 1).equals(pair.second())) {
- newids.add(idx);
- i += 2;
- } else {
- newids.add(ids.get(i));
- i += 1;
- }
- }
- return newids;
- }
-
public String decodeImpl(List<Integer> tokens) {
StringBuilder sb = new StringBuilder();
for (int token : tokens) {
@@ -214,38 +237,6 @@ public String decodeImpl(List tokens) {
return sb.toString();
}
- /**
- * Returns list of utf-8 byte and a corresponding list of unicode strings.
- * The reversible bpe codes work on unicode strings.
- * This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
- * When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
- * This is a significant percentage of your normal, say, 32K bpe vocab.
- * To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
- * And avoids mapping to whitespace/control characters the bpe code barfs on.
- */
- private static Map bytesToUnicode() {
- List bs = new ArrayList<>();
- IntStream.rangeClosed('!', '~').forEach(bs::add);
- IntStream.rangeClosed('¡', '¬').forEach(bs::add);
- IntStream.rangeClosed('®', 'ÿ').forEach(bs::add);
-
- List cs = new ArrayList<>(bs);
- int n = 0;
- for (int b = 0; b < 256; ++b) {
- if (!bs.contains(b)) {
- bs.add(b);
- cs.add(256 + n);
- n += 1;
- }
- }
-
- // return dict(zip(bs, cs))
- return IntStream.range(0, bs.size()).boxed().collect(Collectors.toMap(bs::get, cs::get));
- }
-
- static final Map BYTE_ENCODER = bytesToUnicode();
- static final Map BYTE_DECODER = BYTE_ENCODER.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
-
public int[] encode(String text) {
StringBuilder sb = new StringBuilder();
byte[] bytes = text.getBytes(StandardCharsets.UTF_8);
diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/impl/MistralTokenizer.java b/src/main/java/org/beehive/gpullama3/tokenizer/MistralTokenizer.java
similarity index 85%
rename from src/main/java/org/beehive/gpullama3/tokenizer/impl/MistralTokenizer.java
rename to src/main/java/org/beehive/gpullama3/tokenizer/MistralTokenizer.java
index c4264a1b..940318f9 100644
--- a/src/main/java/org/beehive/gpullama3/tokenizer/impl/MistralTokenizer.java
+++ b/src/main/java/org/beehive/gpullama3/tokenizer/MistralTokenizer.java
@@ -1,9 +1,12 @@
-package org.beehive.gpullama3.tokenizer.impl;
-
-import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
+package org.beehive.gpullama3.tokenizer;
import java.nio.charset.StandardCharsets;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@@ -11,18 +14,12 @@
/**
* TikToken-style BPE tokenizer with byte fallback.
*
- * TikToken-style:
- * A Byte Pair Encoding (BPE) strategy that converts text to UTF-8 bytes.
- * Frequent pairs of bytes (or tokens) are merged according to a learned vocabulary.
- * This reduces long words into common subwords or whole-word tokens.
- * If a word or character isn't found, it falls back to byte-level tokens.
+ * TikToken-style: A Byte Pair Encoding (BPE) strategy that converts text to UTF-8 bytes. Frequent pairs of bytes (or tokens) are merged according to a learned vocabulary. This reduces long words into
+ * common subwords or whole-word tokens. If a word or character isn't found, it falls back to byte-level tokens.
*
- * Byte fallback:
- * A fail-safe mechanism.
- * It ensures every byte has a token, so any input (even unknown words, misspellings, foreign languages, emojis, or binary) can be tokenized.
- * If a token is not found in the merges or vocabulary, it will fall back to the individual byte.
- * Each byte is wrapped as a special token like <0xF0> — these are part of the tokenizer’s extended vocabulary.
- * This guarantees reversibility: every string can be tokenized and decoded back exactly.
+ * Byte fallback: A fail-safe mechanism. It ensures every byte has a token, so any input (even unknown words, misspellings, foreign languages, emojis, or binary) can be tokenized. If a token is not
+ * found in the merges or vocabulary, it will fall back to the individual byte. Each byte is wrapped as a special token like <0xF0> — these are part of the tokenizer’s extended vocabulary. This
+ * guarantees reversibility: every string can be tokenized and decoded back exactly.
*/
public class MistralTokenizer implements Tokenizer {
private static final String MISTRAL_PATTERN = "\\S+|\\s+";
@@ -34,6 +31,26 @@ public class MistralTokenizer implements Tokenizer {
private final int[] tokenType;
private final int byte0;
+ // @formatter:off
+ public MistralTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
+ // load from metadata
+ int[] tokenTypes = (int[]) metadata.get("tokenizer.ggml.token_type");
+ List<Integer> specialTokensList = IntStream.range(0, vocabulary.size()).filter(t -> tokenTypes[t] != 1 && tokenTypes[t] != 6).boxed().toList();
+ Map<String, Integer> specialTokens =
+ IntStream.range(0, specialTokensList.size())
+ .boxed()
+ .collect(Collectors.toMap(
+ t -> vocabulary.get(t),
+ t -> t)
+ );
+ // init tokenizer object fields
+ this.vocabulary = vocabulary;
+ this.compiledPattern = null;
+ this.specialTokens = new HashMap<>(specialTokens);
+ this.tokenType = tokenTypes;
+ this.byte0 = vocabulary.getIndex("<0x00>").orElseThrow();
+ }
+
public String regexPattern() {
if (compiledPattern == null) {
return null;
@@ -60,26 +77,6 @@ public boolean shouldDisplayToken(int token) {
public int getTokenType(int tokenIndex) {
return tokenType[tokenIndex];
}
-
- // @formatter:off
- public MistralTokenizer(Map metadata, Vocabulary vocabulary) {
- // load from metadata
- int[] tokenTypes = (int[]) metadata.get("tokenizer.ggml.token_type");
- List specialTokensList = IntStream.range(0, vocabulary.size()).filter(t -> tokenTypes[t] != 1 && tokenTypes[t] != 6).boxed().toList();
- Map specialTokens =
- IntStream.range(0, specialTokensList.size())
- .boxed()
- .collect(Collectors.toMap(
- t -> vocabulary.get(t),
- t -> t)
- );
- // init tokenizer object fields
- this.vocabulary = vocabulary;
- this.compiledPattern = null;
- this.specialTokens = new HashMap<>(specialTokens);
- this.tokenType = tokenTypes;
- this.byte0 = vocabulary.getIndex("<0x00>").orElseThrow();
- }
// @formatter:on
private List<Integer> encodeImpl(String text) {
diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/impl/Phi3Tokenizer.java b/src/main/java/org/beehive/gpullama3/tokenizer/Phi3Tokenizer.java
similarity index 97%
rename from src/main/java/org/beehive/gpullama3/tokenizer/impl/Phi3Tokenizer.java
rename to src/main/java/org/beehive/gpullama3/tokenizer/Phi3Tokenizer.java
index e8e12d92..4b5167c0 100644
--- a/src/main/java/org/beehive/gpullama3/tokenizer/impl/Phi3Tokenizer.java
+++ b/src/main/java/org/beehive/gpullama3/tokenizer/Phi3Tokenizer.java
@@ -1,7 +1,6 @@
-package org.beehive.gpullama3.tokenizer.impl;
+package org.beehive.gpullama3.tokenizer;
-import org.beehive.gpullama3.core.types.Pair;
-import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
+import org.beehive.gpullama3.auxiliary.Pair;
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/impl/Qwen3Tokenizer.java b/src/main/java/org/beehive/gpullama3/tokenizer/Qwen3Tokenizer.java
similarity index 98%
rename from src/main/java/org/beehive/gpullama3/tokenizer/impl/Qwen3Tokenizer.java
rename to src/main/java/org/beehive/gpullama3/tokenizer/Qwen3Tokenizer.java
index 0f8751fb..077dd536 100644
--- a/src/main/java/org/beehive/gpullama3/tokenizer/impl/Qwen3Tokenizer.java
+++ b/src/main/java/org/beehive/gpullama3/tokenizer/Qwen3Tokenizer.java
@@ -1,8 +1,7 @@
-package org.beehive.gpullama3.tokenizer.impl;
+package org.beehive.gpullama3.tokenizer;
import org.beehive.gpullama3.auxiliary.Utf8Mask;
-import org.beehive.gpullama3.core.types.Pair;
-import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
+import org.beehive.gpullama3.auxiliary.Pair;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
@@ -18,13 +17,14 @@
import java.util.stream.IntStream;
public class Qwen3Tokenizer implements Tokenizer {
+ static final Map<Integer, Integer> BYTE_ENCODER = bytesToUnicode();
+ static final Map<Integer, Integer> BYTE_DECODER = BYTE_ENCODER.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
private final static String QWEN3_PATTERN = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
private final Pattern compiledPattern;
private final Vocabulary vocabulary;
private final Map<Pair<Integer, Integer>, Integer> merges;
private final Map<String, Integer> specialTokens;
private final int[] tokenTypes;
-
/** buffer to store incomplete UTF-8 sequence */
private final byte[] bufUtf8 = new byte[4];
/** index in UTF-8 buffer */
@@ -32,38 +32,6 @@ public class Qwen3Tokenizer implements Tokenizer {
/** current UTF-8 mask */
private Utf8Mask currUtf8Mask;
- @Override
- public String regexPattern() {
- if (compiledPattern == null) {
- return null;
- }
- return compiledPattern.pattern();
- }
-
- @Override
- public Map getSpecialTokens() {
- return specialTokens;
- }
-
- @Override
- public boolean isSpecialToken(int tokenIndex) {
- return specialTokens.containsValue(tokenIndex);
- }
-
- @Override
- public boolean shouldDisplayToken(int token) {
- int tokenType = getTokenType(token);
- // tokenType 4 allows the display of reasoning ( ... <\think> )
- return tokenType == 1 || tokenType == 4 || tokenType == 6;
- }
-
- public int getTokenType(int tokenIndex) {
- if (tokenTypes == null) {
- throw new IllegalStateException("Qwen3Tokenizer hasn't been constructed using tokenTypes");
- }
- return tokenTypes[tokenIndex];
- }
-
// @formatter:off
public Qwen3Tokenizer(Map<String, Object> metadata, Vocabulary vocabulary, boolean isDeepSeekR1DistillQwen) {
int[] tokenTypes = (int[]) metadata.get("tokenizer.ggml.token_type");
@@ -106,11 +74,6 @@ public Qwen3Tokenizer(Map metadata, Vocabulary vocabulary, boole
this.merges.put(pair, mergeIndex);
}
}
- // @formatter:on
-
- private int[] encodeImpl(String text) {
- return encode(text, Set.of()).stream().mapToInt(i -> i).toArray();
- }
static List<String> findAll(Pattern pattern, String text) {
List<String> allMatches = new ArrayList<>();
@@ -121,6 +84,92 @@ static List findAll(Pattern pattern, String text) {
return allMatches;
}
+ static List<Integer> merge(List<Integer> ids, Pair<Integer, Integer> pair, int idx) {
+ List<Integer> newids = new ArrayList<>();
+ int i = 0;
+ while (i < ids.size()) {
+ // if not at the very last position AND the pair matches, replace it
+ if (ids.get(i).equals(pair.first()) && i < ids.size() - 1 && ids.get(i + 1).equals(pair.second())) {
+ newids.add(idx);
+ i += 2;
+ } else {
+ newids.add(ids.get(i));
+ i += 1;
+ }
+ }
+ return newids;
+ }
+
+ /**
+ * Returns a mapping from UTF-8 bytes to unicode code points.
+ * The reversible BPE codes work on unicode strings.
+ * This means you need a large number of unicode characters in your vocab if you want to avoid UNKs.
+ * When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ * This is a significant percentage of your normal, say, 32K BPE vocab.
+ * To avoid that, we want lookup tables between UTF-8 bytes and unicode strings.
+ * It also avoids mapping to whitespace/control characters the BPE code barfs on.
+ */
+ static Map<Integer, Integer> bytesToUnicode() {
+ List<Integer> bs = new ArrayList<>();
+ IntStream.rangeClosed('!', '~').forEach(bs::add);
+ IntStream.rangeClosed('¡', '¬').forEach(bs::add);
+ IntStream.rangeClosed('®', 'ÿ').forEach(bs::add);
+
+ List<Integer> cs = new ArrayList<>(bs);
+ int n = 0;
+ for (int b = 0; b < 256; ++b) {
+ if (!bs.contains(b)) {
+ bs.add(b);
+ cs.add(256 + n);
+ n += 1;
+ }
+ }
+
+ // return dict(zip(bs, cs))
+ return IntStream.range(0, bs.size())
+ .boxed()
+ .collect(Collectors.toMap(bs::get, cs::get));
+ }
+ // @formatter:on
+
+ @Override
+ public String regexPattern() {
+ if (compiledPattern == null) {
+ return null;
+ }
+ return compiledPattern.pattern();
+ }
+
+ @Override
+ public Map<String, Integer> getSpecialTokens() {
+ return specialTokens;
+ }
+
+ @Override
+ public boolean isSpecialToken(int tokenIndex) {
+ return specialTokens.containsValue(tokenIndex);
+ }
+
+ @Override
+ public boolean shouldDisplayToken(int token) {
+ int tokenType = getTokenType(token);
+ // tokenType 4 allows the display of reasoning (<think> ... </think>)
+ return tokenType == 1 || tokenType == 4 || tokenType == 6;
+ }
+
+ public int getTokenType(int tokenIndex) {
+ if (tokenTypes == null) {
+ throw new IllegalStateException("Qwen3Tokenizer hasn't been constructed using tokenTypes");
+ }
+ return tokenTypes[tokenIndex];
+ }
+
+ private int[] encodeImpl(String text) {
+ return encode(text, Set.of()).stream().mapToInt(i -> i).toArray();
+ }
+
+ // @formatter:off
+
/**
* Encoding that ignores any special tokens.
*/
@@ -135,6 +184,7 @@ public List encodeOrdinary(String text) {
}
return ids;
}
+ // @formatter:on
private Map<Pair<Integer, Integer>, Integer> getStats(List<Integer> ids) {
Map<Pair<Integer, Integer>, Integer> map = new HashMap<>();
@@ -172,58 +222,6 @@ private List encodeChunk(String chunk) {
return ids;
}
- static List merge(List ids, Pair pair, int idx) {
- List newids = new ArrayList<>();
- int i = 0;
- while (i < ids.size()) {
- // if not at the very last position AND the pair matches, replace it
- if (ids.get(i).equals(pair.first()) && i < ids.size() - 1 && ids.get(i + 1).equals(pair.second())) {
- newids.add(idx);
- i += 2;
- } else {
- newids.add(ids.get(i));
- i += 1;
- }
- }
- return newids;
- }
-
- // @formatter:off
- /**
- * Returns a map from UTF-8 bytes to corresponding Unicode code points.
- * The reversible BPE codes work on Unicode strings.
- * This means you need a large number of Unicode characters in your vocab if you want to avoid UNKs.
- * When you're at something like a 10B-token dataset you end up needing around 5K for decent coverage.
- * That is a significant percentage of your normal, say, 32K BPE vocab.
- * To avoid that, we use lookup tables between UTF-8 bytes and Unicode strings,
- * while avoiding the whitespace/control characters the BPE code barfs on.
- */
- static Map<Integer, Integer> bytesToUnicode() {
- List<Integer> bs = new ArrayList<>();
- IntStream.rangeClosed('!', '~').forEach(bs::add);
- IntStream.rangeClosed('¡', '¬').forEach(bs::add);
- IntStream.rangeClosed('®', 'ÿ').forEach(bs::add);
-
- List<Integer> cs = new ArrayList<>(bs);
- int n = 0;
- for (int b = 0; b < 256; ++b) {
- if (!bs.contains(b)) {
- bs.add(b);
- cs.add(256 + n);
- n += 1;
- }
- }
-
- // return dict(zip(bs, cs))
- return IntStream.range(0, bs.size())
- .boxed()
- .collect(Collectors.toMap(bs::get, cs::get));
- }
- // @formatter:on
-
- static final Map<Integer, Integer> BYTE_ENCODER = bytesToUnicode();
- static final Map<Integer, Integer> BYTE_DECODER = BYTE_ENCODER.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
-
public int[] encode(String text) {
StringBuilder sb = new StringBuilder();
byte[] bytes = text.getBytes(StandardCharsets.UTF_8);
@@ -290,8 +288,6 @@ public List<Integer> encodeAsList(String text) {
return Arrays.stream(encode(text)).boxed().toList();
}
-
-
public String decodeImpl(List<Integer> tokens) {
StringBuilder sb = new StringBuilder();
for (int token : tokens) {
diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/impl/Tokenizer.java b/src/main/java/org/beehive/gpullama3/tokenizer/Tokenizer.java
similarity index 89%
rename from src/main/java/org/beehive/gpullama3/tokenizer/impl/Tokenizer.java
rename to src/main/java/org/beehive/gpullama3/tokenizer/Tokenizer.java
index 8419019d..ec67c5f5 100644
--- a/src/main/java/org/beehive/gpullama3/tokenizer/impl/Tokenizer.java
+++ b/src/main/java/org/beehive/gpullama3/tokenizer/Tokenizer.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.tokenizer.impl;
+package org.beehive.gpullama3.tokenizer;
import java.util.HexFormat;
import java.util.List;
@@ -6,27 +6,6 @@
import java.util.Set;
public interface Tokenizer {
- String regexPattern();
-
- Map<String, Integer> getSpecialTokens();
-
- boolean isSpecialToken(int tokenIndex);
-
- /**
- * Determines if a token should be displayed during streaming output.
- * This filters out special tokens, control characters, or other non-displayable content.
- *
- * @param token the token to check
- * @return true if the token should be displayed to the user, false otherwise
- */
- boolean shouldDisplayToken(int token);
-
- List<Integer> encode(String text, Set<String> allowedSpecial);
-
- List<Integer> encodeAsList(String text);
-
- String decode(List<Integer> tokens);
-
// Utility method for all tokenizers, implemented as static.
static String replaceControlCharacters(int[] codePoints) {
// we don't want to print control characters
@@ -49,5 +28,26 @@ static String replaceControlCharacters(String str) {
return replaceControlCharacters(str.codePoints().toArray());
}
+ String regexPattern();
+
+ Map<String, Integer> getSpecialTokens();
+
+ boolean isSpecialToken(int tokenIndex);
+
+ /**
+ * Determines if a token should be displayed during streaming output. This filters out special tokens, control characters, or other non-displayable content.
+ *
+ * @param token
+ * the token to check
+ * @return true if the token should be displayed to the user, false otherwise
+ */
+ boolean shouldDisplayToken(int token);
+
+ List<Integer> encode(String text, Set<String> allowedSpecial);
+
+ List<Integer> encodeAsList(String text);
+
+ String decode(List<Integer> tokens);
+
}
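Since replaceControlCharacters stays on the interface as a static utility, any caller can sanitize decoded text without knowing the concrete tokenizer. A small sketch of the expected behavior (the exact escape format follows the HexFormat-based implementation elided in this hunk, so treat the output as indicative):

    String safe = Tokenizer.replaceControlCharacters("ding\u0007\n");
    // '\u0007' (BEL) is escaped to the literal text "\\u0007"; '\n' is kept as-is.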
diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/vocabulary/Vocabulary.java b/src/main/java/org/beehive/gpullama3/tokenizer/Vocabulary.java
similarity index 98%
rename from src/main/java/org/beehive/gpullama3/tokenizer/vocabulary/Vocabulary.java
rename to src/main/java/org/beehive/gpullama3/tokenizer/Vocabulary.java
index 474b4b77..1a867569 100644
--- a/src/main/java/org/beehive/gpullama3/tokenizer/vocabulary/Vocabulary.java
+++ b/src/main/java/org/beehive/gpullama3/tokenizer/Vocabulary.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.tokenizer.vocabulary;
+package org.beehive.gpullama3.tokenizer;
import java.util.Arrays;
import java.util.Map;
@@ -18,15 +18,6 @@ public Vocabulary(String[] vocabulary, float[] scores) {
}
// @formatter:on
- public String get(int tokenIndex) {
- return tokens[tokenIndex];
- }
-
- public OptionalInt getIndex(String token) {
- Integer value = tokenToIndex.get(token);
- return value != null ? OptionalInt.of(value) : OptionalInt.empty();
- }
-
public static Vocabulary loadLlamaVocabulary(Map<String, Object> metadata) {
String[] tokens = (String[]) metadata.get("tokenizer.ggml.tokens");
return new Vocabulary(tokens, null);
@@ -51,6 +42,15 @@ public static Vocabulary loadPhi3Vocabulary(Map<String, Object> metadata) {
return new Vocabulary(tokens, scores);
}
+ public String get(int tokenIndex) {
+ return tokens[tokenIndex];
+ }
+
+ public OptionalInt getIndex(String token) {
+ Integer value = tokenToIndex.get(token);
+ return value != null ? OptionalInt.of(value) : OptionalInt.empty();
+ }
+
public int size() {
return tokens.length;
}
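Typical use of the reordered accessors, sketched under the assumption that GGUF metadata is already at hand and that the token string below exists in the model's vocab (both hypothetical here):

    Vocabulary vocab = Vocabulary.loadLlamaVocabulary(gguf.getMetadata());
    int bos = vocab.getIndex("<|begin_of_text|>").orElseThrow();
    String piece = vocab.get(bos); // round-trips back to "<|begin_of_text|>"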
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/GPULLlama3TypeException.java b/src/main/java/org/beehive/gpullama3/tornadovm/GPULLlama3TypeException.java
new file mode 100644
index 00000000..78962a2c
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/GPULLlama3TypeException.java
@@ -0,0 +1,7 @@
+package org.beehive.gpullama3.tornadovm;
+
+public class GPULLlama3TypeException extends IllegalArgumentException {
+ public GPULLlama3TypeException(String message) {
+ super(message);
+ }
+}
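A plausible call site for the new exception type (illustrative; the diff itself doesn't show where it is thrown): GPU-path type dispatch can now fail with a domain-specific error while remaining catchable as IllegalArgumentException.

    switch (entry.ggmlType()) {
        case F16, Q8_0, Q4_0 -> { /* supported on the TornadoVM path */ }
        default -> throw new GPULLlama3TypeException("Unsupported GGML type: " + entry.ggmlType());
    }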
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/GenericLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/GenericLayerPlanner.java
new file mode 100644
index 00000000..5a151212
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/GenericLayerPlanner.java
@@ -0,0 +1,14 @@
+package org.beehive.gpullama3.tornadovm;
+
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+
+import java.util.List;
+
+public interface GenericLayerPlanner {
+
+ List<ImmutableTaskGraph> getImmutableTaskGraphs();
+
+ GridScheduler getGridScheduler();
+
+}
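The interface is the common denominator of the per-model planners deleted below: each one ultimately produces a list of immutable task graphs plus a grid scheduler. A skeletal implementation sketch (class name hypothetical):

    public class MyModelLayerPlanner implements GenericLayerPlanner {
        private final List<ImmutableTaskGraph> taskGraphs = new ArrayList<>();
        private final GridScheduler scheduler = new GridScheduler();

        @Override
        public List<ImmutableTaskGraph> getImmutableTaskGraphs() { return taskGraphs; }

        @Override
        public GridScheduler getGridScheduler() { return scheduler; }
    }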
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlanner.java
deleted file mode 100644
index 6cfdb821..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlanner.java
+++ /dev/null
@@ -1,355 +0,0 @@
-package org.beehive.gpullama3.tornadovm;
-
-import org.beehive.gpullama3.auxiliary.Tuple2;
-import org.beehive.gpullama3.inference.state.Phi3State;
-import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.model.phi3.Phi3Configuration;
-import uk.ac.manchester.tornado.api.GridScheduler;
-import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
-import uk.ac.manchester.tornado.api.TaskGraph;
-import uk.ac.manchester.tornado.api.WorkerGrid;
-import uk.ac.manchester.tornado.api.WorkerGrid1D;
-import uk.ac.manchester.tornado.api.enums.DataTransferMode;
-
-import java.util.ArrayList;
-import java.util.List;
-
-public class Phi3TornadoVMLayerPlanner extends TornadoVMLayerPlanner<Phi3State, Phi3Configuration, Phi3TornadoWeights> {
-
- /**
- * Constructs a TornadoVMLayerPlanner for the given Phi3 model.
- *
- * @param state
- * The state object containing model tensors and buffers
- * @param model
- * The Phi3 model instance containing configuration and weights
- */
- public Phi3TornadoVMLayerPlanner(Phi3State state, Model model) {
- super(state, model);
- }
-
- public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLayered() {
- List<ImmutableTaskGraph> taskGraphs = new ArrayList<>();
-
- state.temp.init(0.0f);
- state.tempFFN.init(0.0f);
- state.tempLogits.init(0.0f);
- final int opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize());
-
- // @formatter:off
- TaskGraph activationUpdate = new TaskGraph("activationUpdate")
- .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapX)
- .task("updateX", TransformerComputeKernels::emptyTaskToForceCopyIn, state.wrapX)
- .persistOnDevice(state.wrapX);
- taskGraphs.add(activationUpdate.snapshot());
-
- TaskGraph unifiedLayer = null;
- for (int layerIndex = 0; layerIndex < config.numberOfLayers(); layerIndex++) {
- unifiedLayer = new TaskGraph("layer_" + layerIndex);
- unifiedLayer.consumeFromDevice(state.wrapX);
- unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
- weights.rms_att_weightLayered[layerIndex],
- weights.wqkvLayered[layerIndex],
- weights.woLayered[layerIndex],
- weights.rms_ffn_weightLayered[layerIndex],
- weights.wDownLayered[layerIndex],
- weights.wUpLayered[layerIndex]
- );
- unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
- unifiedLayer.task("reductionsOneBlock" , TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp,
- state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
- .task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb,
- state.wrapX, weights.rms_att_weightLayered[layerIndex], state.temp)
- .task("qkvmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapQkv,
- weights.wqkvLayered[layerIndex], config.dim(), opSize, LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("splitQKV", TransformerComputeKernelsLayered::splitQKV,
- state.wrapQkv, state.wrapQ, state.wrapK, state.wrapV,
- config.dim(), config.headSize() * config.numberOfKeyValueHeads())
- .task("rope", TransformerComputeKernelsLayered::ropeRotationPhi3,context,
- state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(),
- config.headSize())
- .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache,
- state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), layerIndex, config.contextLength())
- .task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, context,
- state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
- config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(),
- state.positionHolder, layerIndex, config.contextLength())
- .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
- state.wrapXb, state.wrapX, weights.woLayered[layerIndex], config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN,
- state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
- .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb,
- state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN)
- .task("wGateUp", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
- state.wrapXb, state.wrapHb, weights.wUpLayered[layerIndex], config.dim(), 2 * config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("gateUpSiLU", TransformerComputeKernelsLayered::splitGateUpAndSiLU,
- state.wrapHb, state.wrapHbG, state.wrapHbU, config.hiddenDim())
- .task("wDown", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
- state.wrapHbU, state.wrapX, weights.wDownLayered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .persistOnDevice(
- state.wrapX
- );
- taskGraphs.add(unifiedLayer.snapshot());
- }
-
- TaskGraph lastUnifiedLayer = unifiedLayer;
- TaskGraph logits = new TaskGraph("logits")
- .consumeFromDevice(lastUnifiedLayer.getTaskGraphName(),
- state.wrapX
- )
- .transferToDevice(DataTransferMode.EVERY_EXECUTION,
- state.tempLogits
- )
- .transferToDevice(DataTransferMode.FIRST_EXECUTION,
- context,
- state.wrapLogits,
- weights.wclsHalfFloat,
- weights.rms_final_weight_as_floatArray
- )
- .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits,
- state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
- .task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX,
- weights.rms_final_weight_as_floatArray, state.tempLogits);
- logits = configureQuantizedMatrixVectorFinalWeight(logits);
- logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
- taskGraphs.add(logits.snapshot());
- // @formatter:on
-
- return new Tuple2<>(taskGraphs, setupGridSchedulersLayered());
- }
-
- // @formatter:off
- /**
- * Configures the final projection layer in the task graph based on weight quantization type.
- *
- * This method adds a "projection" task to compute the final logits by performing a
- * matrix-vector multiplication between the model's output embeddings and the classifier
- * weights (wcls). The computation kernel used depends on the quantization format.
- *
- * Supported quantization types:
- * - F16: half-precision floating point
- * - Q8_0: 8-bit quantization with uniform scaling per 32-element block
- * - Q4_0: 4-bit quantization with uniform scaling per 32-element block
- *
- * The task multiplies:
- * - weights.wclsByteArray: Quantized classifier weights (vocab_size x dim)
- * - state.wrapX: Current layer output (dim)
- * - Result: state.wrapLogits: Raw logits (vocab_size)
- *
- * @param logits The existing task graph to extend with the projection operation
- * @return The modified task graph with the projection task added
- * @throws UnsupportedOperationException if weights.getWeightType() is not F16, Q8_0 or Q4_0
- */
- // @formatter:on
- protected TaskGraph configureQuantizedMatrixVectorFinalWeight(TaskGraph logits) {
- switch (weights.getWeightType()) {
- case F16:
- case Q8_0:
- case Q4_0:
- logits.task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, //
- context, state.wrapX, state.wrapLogits, weights.wclsHalfFloat, //
- config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); //
- break;
- default:
- throw new UnsupportedOperationException("Unsupported weight quantization type: " + weights.getWeightType() + ". Only F16, Q8_0 and Q4_0 are supported.");
- }
- return logits;
- }
-
- /**
- * Configures data transfer operations for a specific layer in the neural network task graph.
- *
- * This method manages GPU memory transfers with optimized data movement strategies that minimize data movement by:
- * 1. Using one-time transfers for static data
- * 2. Reusing intermediate results already on GPU from previous layers
- * 3. Only transferring dynamic data that changes per execution
- *
- * @param unifiedLayer
- * The task graph representing this layer's operations
- * @param layerIndex
- * Index of the current layer (0-based)
- * @return The configured task graph with appropriate data transfer operations
- */
- protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
- // First layer: Transfer initial data to device (one-time transfer)
- if (layerIndex == 0) {
- // Transfer all attention-related data: query, key, value matrices and their caches
- unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); //
- unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, //
- context, state.wrapXb, state.wrapXb2, //
- state.wrapQ, state.wrapK, state.wrapV, //
- state.wrapKeyCache, state.wrapValueCache, //
- state.wrapAtt, state.wrapHb, //
- state.wrapHbG, state.wrapHbU, state.wrapQkv); //
- } else {
- // Subsequent layers: Consume data already on device from previous layer
- unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, //
- state.wrapQ, state.wrapK, state.wrapV, //
- state.wrapKeyCache, state.wrapValueCache, //
- state.wrapAtt, state.wrapHb, //
- state.positionHolder, //
- state.wrapHbG, state.wrapHbU, state.wrapQkv);
- }
- return unifiedLayer;
- }
-
- // @formatter:off
- /**
- * Sets up the grid scheduler configuration for a layered neural network forward pass.
- *
- * This method creates and configures worker grids for different types of GPU operations
- * in the transformer/ML model pipeline. Each worker grid defines how work should be
- * distributed across GPU threads (OpenCL work-items or CUDA threads).
- *
- * The method creates several worker profiles:
- * - Single thread operations (activation updates)
- * - RoPE (Rotary Position Embedding) operations
- * - Matrix multiplications with different dimensions
- * - RMS normalization operations
- * - Parallel attention computations
- * - Cache copying operations
- * - Vocabulary projections
- *
- * Each worker grid maps to equivalent OpenCL NDRange or CUDA grid/block configurations:
- * - setGlobalWork() ≈ OpenCL global_work_size ≈ CUDA grid dimensions × block dimensions
- * - setLocalWork() ≈ OpenCL local_work_size ≈ CUDA block dimensions
- *
- * @return GridScheduler configured with all necessary worker grids for the model layers
- */
- // @formatter:on
- private GridScheduler setupGridSchedulersLayered() {
- GridScheduler tornadoForwardScheduler = new GridScheduler();
-
- // Single worker for tasks running with a single thread
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[1,1,1], localWorkSize=[1,1,1])
- // CUDA equivalent: kernel<<<1, 1>>>
- WorkerGrid singleWorker = new WorkerGrid1D(1);
- singleWorker.setGlobalWork(1, 1, 1);
- singleWorker.setLocalWork(1, 1, 1);
-
- // config.dim / 2 Worker for RoPE
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim/2,1,1], localWorkSize=[128,1,1])
- // CUDA equivalent: kernel<<<(config.dim/2)/128, 128>>>
- WorkerGrid ropeWorker = new WorkerGrid1D(config.dim() / 2);
- ropeWorker.setGlobalWork(config.dim() / 2, 1, 1);
- ropeWorker.setLocalWork(128, 1, 1);
-
- // config.dim Worker for Row major access
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1])
- // CUDA equivalent: kernel<<<config.dim, LOCAL_WORK_GROUP_SIZE_ALLOC>>>
- int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
- WorkerGrid configDimRowMajorGlobalWorker = new WorkerGrid1D(configDimRowMajorGlobal);
- configDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
-
- final int opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize());
-
- int qkvmatmulDimRowMajorGlobal = opSize * LOCAL_WORK_GROUP_SIZE_ALLOC;
- WorkerGrid qkvDimRowMajorGlobalWorker = new WorkerGrid1D(qkvmatmulDimRowMajorGlobal);
- qkvDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
-
- // config.kvDim Worker for Row major access
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.kvDim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1])
- // CUDA equivalent: kernel<<<config.kvDim, LOCAL_WORK_GROUP_SIZE_ALLOC>>>
- int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
- WorkerGrid configKvDimRowMajorGlobalWorker = new WorkerGrid1D(configKvDimRowMajorGlobal);
- configKvDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
-
- // config.hiddenDim * 32 Worker for Row major access
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.hiddenDim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1])
- // CUDA equivalent: kernel<<<config.hiddenDim, LOCAL_WORK_GROUP_SIZE_ALLOC>>>
- int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
- WorkerGrid configHiddenDimRowMajorWorker = new WorkerGrid1D(configHiddenDimRowMajor);
- configHiddenDimRowMajorWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
-
- int wgetUPDimRowMajor = 2 * config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
- WorkerGrid wgetHiddenDimRowMajorWorker = new WorkerGrid1D(wgetUPDimRowMajor);
- wgetHiddenDimRowMajorWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
-
- // RMSNorm worker configuration
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim,1,1], localWorkSize=[256,1,1])
- // CUDA equivalent: kernel<<<config.dim/256, 256>>>
- WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim());
- rmsNormWorker.setGlobalWork(config.dim(), 1, 1); // Set global work size to total dimension
- rmsNormWorker.setLocalWork(256, 1, 1); // Set local work size to 256 (standard efficient size)
-
- // Parallel attention worker configuration
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.numberOfHeads*8,1,1], localWorkSize=[8,1,1])
- // CUDA equivalent: kernel<<<config.numberOfHeads, 8>>>
- WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads());
- // the global work size is numberOfHeads * localWorkGroupSize, where the localWorkGroupSize is currently 8
- parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * 8, 1, 1);
- parallelAttentionWorker.setLocalWork(8, 1, 1); // Set local work size to 8 (for parallel attention)
-
- // Copy to caches worker configuration
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim,1,1], localWorkSize=[128,1,1])
- // CUDA equivalent: kernel<<<config.dim/128, 128>>>
- WorkerGrid copyToCachesWorker = new WorkerGrid1D(config.kvDim());
- copyToCachesWorker.setGlobalWork(config.dim(), 1, 1);
- copyToCachesWorker.setLocalWork(128, 1, 1); // Set local work size to 128 (for copying to caches)
-
- // Q copy worker configuration
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim,1,1], localWorkSize=[128,1,1])
- // CUDA equivalent: kernel<<<config.dim/128, 128>>>
- WorkerGrid copyQWorker = new WorkerGrid1D(config.dim());
- copyQWorker.setGlobalWork(config.dim(), 1, 1);
- copyQWorker.setLocalWork(128, 1, 1);
-
- // K copy worker configuration
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[kvSize,1,1], localWorkSize=[128,1,1])
- // CUDA equivalent: kernel<<<kvSize/128, 128>>>
- int kvSize = config.headSize() * config.numberOfKeyValueHeads();
- WorkerGrid copyKWorker = new WorkerGrid1D(kvSize);
- copyKWorker.setGlobalWork(kvSize, 1, 1);
- copyKWorker.setLocalWork(128, 1, 1);
-
- // V copy worker configuration
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[kvSize,1,1], localWorkSize=[128,1,1])
- // CUDA equivalent: kernel<<<kvSize/128, 128>>>
- WorkerGrid copyVWorker = new WorkerGrid1D(kvSize);
- copyVWorker.setGlobalWork(kvSize, 1, 1);
- copyVWorker.setLocalWork(128, 1, 1);
-
- WorkerGrid hiddenDimWorker = new WorkerGrid1D(config.hiddenDim());
- hiddenDimWorker.setGlobalWork(config.hiddenDim(), 1, 1);
- hiddenDimWorker.setLocalWork(128, 1, 1);
-
- WorkerGrid splitGateUpSiLUWorker = new WorkerGrid1D(config.hiddenDim());
- splitGateUpSiLUWorker.setGlobalWork(config.hiddenDim(), 1, 1);
- splitGateUpSiLUWorker.setLocalWork(128, 1, 1);
-
- // Total work size is dimQ + 2*dimKV (same as opSize)
- WorkerGrid splitQKVWorker = new WorkerGrid1D(opSize);
- splitQKVWorker.setGlobalWork(opSize, 1, 1);
- splitQKVWorker.setLocalWork(128, 1, 1);
-
- // Map workers to tasks
- tornadoForwardScheduler.addWorkerGrid("activationUpdate.updateX", singleWorker);
- for (int i = 0; i < config.numberOfLayers(); i++) {
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkvmatmul", qkvDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".wDown", configDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".wGateUp", wgetHiddenDimRowMajorWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
- // New FFN tasks
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".gateUpSiLU", splitGateUpSiLUWorker);
- }
-
- // Vocabulary worker configuration
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.vocabularySize*LOCAL_WORK_GROUP_SIZE_ALLOC*THREAD_SCALE_FOR_LOGITS,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC*THREAD_SCALE_FOR_LOGITS,1,1])
- // CUDA equivalent: kernel<<<config.vocabularySize, LOCAL_WORK_GROUP_SIZE_ALLOC*THREAD_SCALE_FOR_LOGITS>>>
- int vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
- WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
- vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
-
- tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker);
- tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", rmsNormWorker);
- tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", rmsNormWorker);
-
- return tornadoForwardScheduler;
- }
-}
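For orientation, the Tuple2 these planners returned was consumed roughly like this (a sketch assuming TornadoVM's standard TornadoExecutionPlan API and that Tuple2 exposes getFirst()/getSecond(); the real driver lives elsewhere in this repo):

    Tuple2<List<ImmutableTaskGraph>, GridScheduler> plan = planner.setupTornadoForwardPlanLayered();
    TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(plan.getFirst().toArray(new ImmutableTaskGraph[0]));
    executionPlan.withGridScheduler(plan.getSecond()).execute();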
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2TornadoVMLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2TornadoVMLayerPlanner.java
deleted file mode 100644
index 1f9d547b..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2TornadoVMLayerPlanner.java
+++ /dev/null
@@ -1,250 +0,0 @@
-package org.beehive.gpullama3.tornadovm;
-
-import org.beehive.gpullama3.auxiliary.Tuple2;
-import org.beehive.gpullama3.inference.state.Qwen2State;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
-import uk.ac.manchester.tornado.api.GridScheduler;
-import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
-import uk.ac.manchester.tornado.api.TaskGraph;
-import uk.ac.manchester.tornado.api.WorkerGrid;
-import uk.ac.manchester.tornado.api.WorkerGrid1D;
-import uk.ac.manchester.tornado.api.WorkerGrid2D;
-import uk.ac.manchester.tornado.api.enums.DataTransferMode;
-
-import java.util.ArrayList;
-import java.util.List;
-
-public class Qwen2TornadoVMLayerPlanner extends TornadoVMLayerPlanner<Qwen2State, Qwen2Configuration, Qwen2TornadoWeights> {
-
- /**
- * Constructs a TornadoVMLayerPlanner for the given Qwen2 model.
- *
- * @param state
- * The state object containing model tensors and buffers
- * @param model
- * The Qwen2 model instance containing configuration and weights
- */
- public Qwen2TornadoVMLayerPlanner(Qwen2State state, Model model) {
- super(state, model);
- }
-
- @Override
- public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLayered() {
- List<ImmutableTaskGraph> taskGraphs = new ArrayList<>();
-
- state.temp.init(0.0f);
- state.tempFFN.init(0.0f);
- state.tempLogits.init(0.0f);
- state.wrapLogits.init(0.0f);
-
- TaskGraph activationUpdate = new TaskGraph("activationUpdate")
- .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapX)
- .task("updateX", TransformerComputeKernels::emptyTaskToForceCopyIn, state.wrapX)
- .persistOnDevice(state.wrapX);
- taskGraphs.add(activationUpdate.snapshot());
-
- TaskGraph unifiedLayer = null;
- for (int layerIndex = 0; layerIndex < config.numberOfLayers(); layerIndex++) {
- unifiedLayer = new TaskGraph("layer_" + layerIndex);
- unifiedLayer.consumeFromDevice(state.wrapX);
- unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
- //Copy-in weights per layer for batched-layered layout
- weights.rms_att_weightLayered[layerIndex],
- weights.wqLayered[layerIndex],
- weights.wkLayered[layerIndex],
- weights.wvLayered[layerIndex],
- weights.woLayered[layerIndex],
- weights.q_biasLayered[layerIndex],
- weights.k_biasLayered[layerIndex],
- weights.v_biasLayered[layerIndex],
- weights.rms_ffn_weightLayered[layerIndex],
- weights.w1Layered[layerIndex],
- weights.w2Layered[layerIndex],
- weights.w3Layered[layerIndex]
- );
- unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
-
- unifiedLayer.task("reductionsOneBlock" , TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp,
- state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
- .task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb,
- state.wrapX, weights.rms_att_weightLayered[layerIndex], state.temp)
- .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
- state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex], config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
- state.wrapXb, state.wrapK, weights.wkLayered[layerIndex], config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
- state.wrapXb, state.wrapV, weights.wvLayered[layerIndex], config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("qbias", TransformerComputeKernelsLayered::addInPlace, state.wrapQ, weights.q_biasLayered[layerIndex], config.dim())
- .task("kbias", TransformerComputeKernelsLayered::addInPlace, state.wrapK, weights.k_biasLayered[layerIndex], config.kvDim())
- .task("vbias", TransformerComputeKernelsLayered::addInPlace, state.wrapV, weights.v_biasLayered[layerIndex], config.kvDim())
- .task("rope", Qwen3Kernels::ropeRotation,context, state.positionHolder, state.wrapQ, state.wrapK, config.numberOfKeyValueHeads(),
- config.headSize())
- .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache,
- state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), layerIndex, config.contextLength())
- .task("parallel-attention", Qwen2Kernels::processHeadsFlashAttention, context,
- state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
- config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(),
- state.positionHolder, layerIndex, config.contextLength())
- .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
- state.wrapXb, state.wrapX, weights.woLayered[layerIndex], config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN,
- state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
- .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb,
- state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN)
- .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context,
- state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex], weights.w3Layered[layerIndex], config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
- state.wrapHb, state.wrapX, weights.w2Layered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .persistOnDevice(
- state.wrapX
- );
- taskGraphs.add(unifiedLayer.snapshot());
- }
-
- TaskGraph lastUnifiedLayer = unifiedLayer;
- TaskGraph logits = new TaskGraph("logits")
- .consumeFromDevice(lastUnifiedLayer.getTaskGraphName(),
- state.wrapX
- )
- .transferToDevice(DataTransferMode.EVERY_EXECUTION,
- state.tempLogits
- )
- .transferToDevice(DataTransferMode.FIRST_EXECUTION,
- context,
- state.wrapLogits,
- weights.wclsHalfFloat,
- weights.rms_final_weight_as_floatArray
- )
- .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits,
- state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
- .task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX,
- weights.rms_final_weight_as_floatArray, state.tempLogits);
- logits = configureQuantizedMatrixVectorFinalWeight(logits);
- logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
- taskGraphs.add(logits.snapshot());
- // @formatter:on
-
- return new Tuple2<>(taskGraphs, setupQwen2GridSchedulersLayeredNonNvidia());
- }
-
- @Override
- public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLayeredNonNvidia() {
- return setupTornadoForwardPlanLayered();
- }
-
- private GridScheduler setupQwen2GridSchedulersLayeredNonNvidia() {
- //throw new UnsupportedOperationException("setupQwen2GridSchedulersLayeredNonNvidia Not supported yet.");
- GridScheduler tornadoForwardScheduler = new GridScheduler();
-
- // Single worker for tasks running with a single thread
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[1,1,1], localWorkSize=[1,1,1])
- // CUDA equivalent: kernel<<<1, 1>>>
- WorkerGrid singleWorker = new WorkerGrid1D(1);
- singleWorker.setGlobalWork(1, 1, 1);
- singleWorker.setLocalWork(1, 1, 1);
-
- // 2D RoPE worker over (numberOfHeads, headSize/2)
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[h,ic,1], localWorkSize=[1,1,1])
- // CUDA equivalent: kernel<<<dim3(h, ic), dim3(1, 1)>>>
- int h = config.numberOfHeads();
- int ic = config.headSize() / 2;
- WorkerGrid ropeWorker = new WorkerGrid2D(h, ic);
- ropeWorker.setGlobalWork(h, ic, 1);
- ropeWorker.setLocalWork(1, 1, 1);
-
- // config.dim Worker for Row major access
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1])
- // CUDA equivalent: kernel<<<config.dim, LOCAL_WORK_GROUP_SIZE_ALLOC>>>
- int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
- WorkerGrid configDimRowMajorGlobalWorker = new WorkerGrid1D(configDimRowMajorGlobal);
- configDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
-
- // config.kvDim Worker for Row major access
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.kvDim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1])
- // CUDA equivalent: kernel<<<config.kvDim, LOCAL_WORK_GROUP_SIZE_ALLOC>>>
- int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
- WorkerGrid configKvDimRowMajorGlobalWorker = new WorkerGrid1D(configKvDimRowMajorGlobal);
- configKvDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
-
- WorkerGrid qBiasWorker = new WorkerGrid1D(config.dim());
- qBiasWorker.setGlobalWork(config.dim(), 1, 1);
- qBiasWorker.setLocalWork(config.dim() / 8, 1, 1);
- WorkerGrid kvBiasWorker = new WorkerGrid1D(config.kvDim());
- kvBiasWorker.setGlobalWork(config.kvDim(), 1, 1);
- kvBiasWorker.setLocalWork(32, 1, 1);
-
- // config.hiddenDim * 32 Worker for Row major access
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.hiddenDim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1])
- // CUDA equivalent: kernel<<<config.hiddenDim, LOCAL_WORK_GROUP_SIZE_ALLOC>>>
- int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
- WorkerGrid configHiddenDimRowMajorWorker = new WorkerGrid1D(configHiddenDimRowMajor);
- configHiddenDimRowMajorWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
-
- // RMSNorm worker configuration
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim,1,1], localWorkSize=[32,1,1])
- // CUDA equivalent: kernel<<<config.dim/32, 32>>>
- WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim());
- rmsNormWorker.setGlobalWork(config.dim(), 1, 1); // Set global work size to total dimension
- rmsNormWorker.setLocalWork(32, 1, 1); // Set local work size to 32
-
- // Parallel attention worker configuration
- // Calculate optimal local work size based on head dimension
- int optimalLocalSize = Math.min(config.headSize(), 64); // Start with 64 threads per head
- if (config.headSize() % optimalLocalSize != 0) {
- // Find largest divisor of headSize <= 64
- for (int size = 64; size >= 1; size--) {
- if (config.headSize() % size == 0) {
- optimalLocalSize = size;
- break;
- }
- }
- }
-
- WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads());
- parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * optimalLocalSize, 1, 1);
- parallelAttentionWorker.setLocalWork(optimalLocalSize, 1, 1);
-
- // Copy to caches worker configuration
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.kvDim,1,1], localWorkSize=[32,1,1])
- // CUDA equivalent: kernel<<<config.kvDim/32, 32>>>
- WorkerGrid copyToCachesWorker = new WorkerGrid1D(config.kvDim());
- copyToCachesWorker.setGlobalWork(config.kvDim(), 1, 1);
- copyToCachesWorker.setLocalWork(32, 1, 1); // Set local work size to 32 (for copying to caches)
-
- // Map workers to tasks
- tornadoForwardScheduler.addWorkerGrid("activationUpdate.updateX", singleWorker);
- for (int i = 0; i < config.numberOfLayers(); i++) {
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qbias", qBiasWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kbias", kvBiasWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vbias", kvBiasWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
- tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
- }
-
- // Vocabulary worker configuration
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.vocabularySize*LOCAL_WORK_GROUP_SIZE_ALLOC*THREAD_SCALE_FOR_LOGITS,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC*THREAD_SCALE_FOR_LOGITS,1,1])
- // CUDA equivalent: kernel<<<config.vocabularySize, LOCAL_WORK_GROUP_SIZE_ALLOC*THREAD_SCALE_FOR_LOGITS>>>
- int vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
- WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
- vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
-
- tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker);
- tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", rmsNormWorker);
- tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", rmsNormWorker);
-
- return tornadoForwardScheduler;
- }
-}
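The attention local-size logic in the deleted Qwen2 planner boils down to "largest divisor of headSize that does not exceed 64". Restated as a standalone helper for clarity (an equivalent sketch, not code from the repo):

    static int largestDivisorUpTo(int n, int cap) {
        for (int size = Math.min(n, cap); size >= 1; size--) {
            if (n % size == 0) {
                return size; // first hit from the top is the largest divisor <= cap
            }
        }
        return 1;
    }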
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3TornadoVMLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3TornadoVMLayerPlanner.java
deleted file mode 100644
index 57d08a90..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3TornadoVMLayerPlanner.java
+++ /dev/null
@@ -1,383 +0,0 @@
-package org.beehive.gpullama3.tornadovm;
-
-import org.beehive.gpullama3.auxiliary.Tuple2;
-import org.beehive.gpullama3.inference.state.Qwen3State;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
-import uk.ac.manchester.tornado.api.GridScheduler;
-import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
-import uk.ac.manchester.tornado.api.TaskGraph;
-import uk.ac.manchester.tornado.api.WorkerGrid;
-import uk.ac.manchester.tornado.api.WorkerGrid1D;
-import uk.ac.manchester.tornado.api.WorkerGrid2D;
-import uk.ac.manchester.tornado.api.enums.DataTransferMode;
-
-import java.util.ArrayList;
-import java.util.List;
-
-public class Qwen3TornadoVMLayerPlanner extends TornadoVMLayerPlanner<Qwen3State, Qwen3Configuration, Qwen3TornadoWeights> {
-
- private final int nHeadKv;
- private final int nEmbdHeadK;
- private final int nEmbdHeadV;
- private final int nEmbdVGqa;
- private final int nEmbdHead;
- private final int nEmbdGqa;
- private final int gqa;
-
- public Qwen3TornadoVMLayerPlanner(Qwen3State state, Model model) {
- super(state, model);
-
- this.nHeadKv = config.numberOfKeyValueHeads();
- this.nEmbdHeadK = config.numberOfHeadsKey();
- this.nEmbdHeadV = config.numberOfHeadsValue(); // n_embd_head_v = n_embd / n_head; %s.attention.value_length
- this.nEmbdVGqa = nEmbdHeadV * nHeadKv; // n_embd_v_gqa = n_embd_head_v * n_head_kv
- this.nEmbdHead = nEmbdHeadV;
- this.nEmbdGqa = nEmbdVGqa;
- this.gqa = config.numberOfHeads() / config.numberOfKeyValueHeads(); // integer multiplier of the kv sharing in multiquery
- }
-
- // @formatter:off
- @Override
- protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
- if (layerIndex == 0) {
- unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION,
- state.positionHolder, state.temp, state.tempFFN,
- state.tempQcur, state.tempKcur); //
- unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, //
- context, state.wrapXb, state.wrapXb2, //
- state.wrapQ, state.wrapK, state.wrapV, //
- state.wrapKeyCache, state.wrapValueCache, //
- state.wrapAtt, state.wrapHb);//
- } else {
- // Subsequent layers: Consume data already on device from previous layer
- unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, //
- state.wrapQ, state.wrapK, state.wrapV, //
- state.wrapKeyCache, state.wrapValueCache, //
- state.wrapAtt, state.wrapHb, //
- state.positionHolder //
- );
- }
- return unifiedLayer;
- }
- // @formatter:on
-
- // @formatter:off
- @Override
- public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLayered() {
- List<ImmutableTaskGraph> taskGraphs = new ArrayList<>();
-
- state.temp.init(0.0f);
- state.tempFFN.init(0.0f);
- state.tempLogits.init(0.0f);
- state.wrapLogits.init(0.0f);
-
- TaskGraph activationUpdate = new TaskGraph("activationUpdate")
- .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapX)
- .task("updateX", TransformerComputeKernels::emptyTaskToForceCopyIn, state.wrapX)
- .persistOnDevice(state.wrapX);
- taskGraphs.add(activationUpdate.snapshot());
-
- TaskGraph unifiedLayer = null;
- for (int layerIndex = 0; layerIndex < config.numberOfLayers(); layerIndex++) {
- unifiedLayer = new TaskGraph("layer_" + layerIndex);
- unifiedLayer.consumeFromDevice(state.wrapX);
- unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
- //Copy-in weights per layer for batched-layered layout
- weights.rms_att_weightLayered[layerIndex],
- weights.wqLayered[layerIndex],
- weights.wkLayered[layerIndex],
- weights.wvLayered[layerIndex],
- weights.woLayered[layerIndex],
- //rms_att_KNormLayered
- weights.rms_att_KNormLayered[layerIndex],
- //rms_att_QNormLayered
- weights.rms_att_QNormLayered[layerIndex],
- weights.rms_ffn_weightLayered[layerIndex],
- weights.w1Layered[layerIndex],
- weights.w2Layered[layerIndex],
- weights.w3Layered[layerIndex]
- );
- unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
- unifiedLayer.task("reductionsOneBlock",
- TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
- context,
- state.temp,
- state.wrapX, // in
- config.dim(),
- config.rmsNormEps(),
- state.localSize)
- .task("mapContext",
- TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
- context,
- state.wrapXb, // out
- state.wrapX,
- weights.rms_att_weightLayered[layerIndex],
- state.temp);
-
- int qDim0 = nEmbdHeadK * config.numberOfHeads();
- int kvDim0 = nEmbdGqa;
- int qkvDim1 = config.dim();
- unifiedLayer.task("qmatmul",
- TransformerComputeKernelsLayered::matrixVectorGeneric,
- context,
- state.wrapXb,
- state.wrapQ, // output
- weights.wqLayered[layerIndex],
- qkvDim1,
- qDim0,
- LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("kmatmul",
- TransformerComputeKernelsLayered::matrixVectorGeneric,
- context,
- state.wrapXb,
- state.wrapK, // output
- weights.wkLayered[layerIndex],
- qkvDim1,
- kvDim0,
- LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("vmatmul",
- TransformerComputeKernelsLayered::matrixVectorGeneric,
- context,
- state.wrapXb,
- state.wrapV, // output
- weights.wvLayered[layerIndex],
- qkvDim1,
- kvDim0,
- LOCAL_WORK_GROUP_SIZE_ALLOC);
-
- // Qcur rmsnorm
- unifiedLayer
- .task("rmsnormReduction_Qcur",
- Qwen3Kernels::rmsnormWithParallelOffset,
- context,
- state.tempQcur, // output
- state.wrapQ, // input
- state.localSize, // currently 128; should be derived from nEmbdHead
- nEmbdHead, // for normalization
- config.rmsNormEps()) // for normalization
- .task("rmsnormMapIndexInPlace_Qcur",
- Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset,
- context,
- state.wrapQ, // output
- weights.rms_att_QNormLayered[layerIndex],
- nEmbdHead,
- state.tempQcur);
-
- // Kcur rmsnorm
- unifiedLayer
- .task("rmsnormReduction_Kcur",
- Qwen3Kernels::rmsnormWithParallelOffset,
- context,
- state.tempKcur, // output
- state.wrapK, // input
- state.localSize, // currently 128; should be derived from nEmbdHead
- nEmbdHead, // for normalization
- config.rmsNormEps()) // for normalization
- .task("rmsnormMapIndexInPlace_Kcur",
- Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset,
- context,
- state.wrapK, // output
- weights.rms_att_KNormLayered[layerIndex],
- nEmbdHead,
- state.tempKcur);
-
- // rope rotation task graph
- unifiedLayer.task("ropeRotation",
- Qwen3Kernels::ropeRotation,
- context,
- state.positionHolder,
- state.wrapQ, // out
- state.wrapK, // out
- config.numberOfKeyValueHeads(),
- nEmbdHead);
-
- unifiedLayer.task("copyToCaches",
- TransformerComputeKernelsLayered::copyToCache,
- state.wrapKeyCache, // out
- state.wrapK, // in
- state.wrapValueCache, // out
- state.wrapV, // in
- state.positionHolder,
- nEmbdGqa,
- layerIndex,
- config.contextLength());
-
- unifiedLayer.task("parallel-attention",
- TransformerComputeKernelsLayered::processHeadsFlashAttentionOpt,
- context,
- state.wrapQ,
- state.wrapKeyCache,
- state.wrapValueCache,
- state.wrapXb, // out
- config.numberOfHeads(),
- nEmbdHead,
- nEmbdGqa,
- gqa,
- state.positionHolder,
- layerIndex,
- config.contextLength());
-
- unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual,
- context,
- state.wrapXb, // vector
- state.wrapX, // out, should be [1024]
- weights.woLayered[layerIndex], // matrix
- nEmbdHeadK * config.numberOfHeads(), // dim1 = 2048
- config.dim(), // dim0 = 1024
- LOCAL_WORK_GROUP_SIZE_ALLOC);
-
- unifiedLayer.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
- context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
- .task("reductionFinalNormalizationFFN" , TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN,
- config.dim(), config.rmsNormEps())
- .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb,
- state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN);
-
- unifiedLayer.task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context,
- state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex], weights.w3Layered[layerIndex], config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
- state.wrapHb, state.wrapX, weights.w2Layered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .persistOnDevice(
- state.wrapX
- );
- taskGraphs.add(unifiedLayer.snapshot());
- }
-
- TaskGraph lastUnifiedLayer = unifiedLayer;
- TaskGraph logits = new TaskGraph("logits")
- .consumeFromDevice(lastUnifiedLayer.getTaskGraphName(),
- state.wrapX
- )
- .transferToDevice(DataTransferMode.EVERY_EXECUTION,
- state.tempLogits,
- state.wrapLogits
- )
- .transferToDevice(DataTransferMode.FIRST_EXECUTION,
- context,
- weights.wclsHalfFloat,
- weights.rms_final_weight_as_floatArray
- )
- .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer,
- context,
- state.tempLogits,
- state.wrapX,
- config.dim(),
- config.rmsNormEps(),
- state.localSize)
- .task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX,
- weights.rms_final_weight_as_floatArray, state.tempLogits);
- logits = configureQuantizedMatrixVectorFinalWeight(logits);
- logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
- taskGraphs.add(logits.snapshot());
-
- return new Tuple2<>(taskGraphs, setupQwen3GridSchedulersLayeredNonNvidia());
-
- }
- // @formatter:on
-
- @Override
- public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLayeredNonNvidia() {
- return setupTornadoForwardPlanLayered();
- }
-
- private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
- GridScheduler gridScheduler = new GridScheduler();
-
- WorkerGrid singleWorker = new WorkerGrid1D(1);
- singleWorker.setGlobalWork(1, 1, 1);
- singleWorker.setLocalWork(1, 1, 1);
-
- WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim());
- rmsNormWorker.setGlobalWork(config.dim(), 1, 1); // Set global work size to total dimension
- rmsNormWorker.setLocalWork(state.localSize, 1, 1); // Set local work size to state.localSize
-
- int matmulQGlobal = nEmbdHeadK * config.numberOfHeads() * LOCAL_WORK_GROUP_SIZE_ALLOC;
- WorkerGrid matmulQRowMajorWorker = new WorkerGrid1D(matmulQGlobal);
- matmulQRowMajorWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
-
- int matmulKVGlobal = nEmbdGqa * LOCAL_WORK_GROUP_SIZE_ALLOC;
- WorkerGrid matmulKVRowMajorWorker = new WorkerGrid1D(matmulKVGlobal);
- matmulKVRowMajorWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
-
- WorkerGrid curWorker = new WorkerGrid1D(nEmbdHead); // nEmbdHead = 128
- curWorker.setGlobalWork(nEmbdHead, 1, 1); // Set global work size to total dimension
- curWorker.setLocalWork(128, 1, 1); // Set local work size to 128
-
- // Qcur
- WorkerGrid qCurWorker = new WorkerGrid1D(config.numberOfHeads() * nEmbdHead);
- qCurWorker.setLocalWork(nEmbdHead, 1, 1);
-
- // Kcur
- WorkerGrid kCurWorker = new WorkerGrid1D(config.numberOfKeyValueHeads() * nEmbdHead);
- kCurWorker.setLocalWork(nEmbdHead, 1, 1);
-
- int h = config.numberOfHeads();
- int ic = nEmbdHead / 2;
- WorkerGrid ropeWorker = new WorkerGrid2D(h, ic);
- ropeWorker.setGlobalWork(h, ic, 1);
- ropeWorker.setLocalWork(8, 1, 1);
-
- WorkerGrid copyToCachesWorker = new WorkerGrid1D(nEmbdGqa);
- copyToCachesWorker.setGlobalWork(nEmbdGqa, 1, 1);
- copyToCachesWorker.setLocalWork(128, 1, 1);
-
- // Parallel attention worker configuration
- WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads());
- parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * 32, 1, 1);
- parallelAttentionWorker.setLocalWork(32, 1, 1);
-
- int matmul1Global = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
- WorkerGrid matmul1Worker = new WorkerGrid1D(matmul1Global);
- matmul1Worker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
-
- int fusedFFNW1W3Global = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
- WorkerGrid fusedFFNW1W3Worker = new WorkerGrid1D(fusedFFNW1W3Global);
- fusedFFNW1W3Worker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
-
- int projectionTwoGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
- WorkerGrid projectionTwoWorker = new WorkerGrid1D(projectionTwoGlobal);
- projectionTwoWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
-
- // Map workers to tasks
- gridScheduler.addWorkerGrid("activationUpdate.updateX", singleWorker);
- for (int i = 0; i < config.numberOfLayers(); i++) {
- gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker);
-
- gridScheduler.addWorkerGrid("layer_" + i + ".qmatmul", matmulQRowMajorWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".kmatmul", matmulKVRowMajorWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".vmatmul", matmulKVRowMajorWorker);
-
- // Qcur
- gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Qcur", qCurWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Qcur", qCurWorker);
-
- // Kcur
- gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Kcur", kCurWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Kcur", kCurWorker);
-
- gridScheduler.addWorkerGrid("layer_" + i + ".ropeRotation", ropeWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".matmul1", matmul1Worker);
- gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker);
- gridScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", fusedFFNW1W3Worker);
- gridScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", projectionTwoWorker);
- }
-
- int vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
- WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
- vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
-
- gridScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", rmsNormWorker);
- gridScheduler.addWorkerGrid("logits.mapContextLogits", rmsNormWorker);
-
- gridScheduler.addWorkerGrid("logits.projection", vocabWorker);
-
- return gridScheduler;
- }
-
-}
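Across all three deleted planners, the row-major matmul workers follow one pattern: global size = rows * LOCAL_WORK_GROUP_SIZE_ALLOC with local size LOCAL_WORK_GROUP_SIZE_ALLOC, i.e. one work-group per output row. A hedged helper that captures the pattern (not code from the repo):

    static WorkerGrid rowMajorWorker(int rows, int groupSize) {
        WorkerGrid grid = new WorkerGrid1D(rows * groupSize); // one group of groupSize threads per row
        grid.setLocalWork(groupSize, 1, 1);
        return grid;
    }
    // e.g. rowMajorWorker(config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) reproduces configDimRowMajorGlobalWorker.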
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMLayerPlanner.java
deleted file mode 100644
index 54764389..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMLayerPlanner.java
+++ /dev/null
@@ -1,543 +0,0 @@
-package org.beehive.gpullama3.tornadovm;
-
-import org.beehive.gpullama3.auxiliary.Tuple2;
-import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
-import org.beehive.gpullama3.model.Configuration;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.inference.state.State;
-import uk.ac.manchester.tornado.api.GridScheduler;
-import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
-import uk.ac.manchester.tornado.api.KernelContext;
-import uk.ac.manchester.tornado.api.TaskGraph;
-import uk.ac.manchester.tornado.api.WorkerGrid;
-import uk.ac.manchester.tornado.api.WorkerGrid1D;
-import uk.ac.manchester.tornado.api.enums.DataTransferMode;
-
-import java.util.ArrayList;
-import java.util.List;
-
-// @formatter:off
- /**
- * TornadoVMLayerPlanner orchestrates the execution planning for transformer model inference
- * on GPU using the TornadoVM framework.
- *
- * This class is responsible for:
- * - Creating task graphs for each layer of the neural network
- * - Managing GPU memory transfers between layers
- * - Configuring worker grids for optimal GPU utilization
- * - Setting up the execution schedule for the entire forward pass
- *
- * The planner implements a layered approach where:
- * - Each layer is represented as a separate TaskGraph
- * - Data transfers are optimized to minimize host-device communication
- * - Worker grids are configured for different types of operations (attention, FFN, etc.)
- * - The entire pipeline is scheduled to run efficiently on GPU
- *
- * Key optimizations include:
- * - One-time transfer of static data (weights, caches)
- * - Per-execution transfer of dynamic data (position, activations)
- * - Device-to-device data consumption between layers
- * - Parallelized attention computation across heads
- *
- * @see TaskGraph
- * @see GridScheduler
- */
- // @formatter:on
- public class TornadoVMLayerPlanner<S extends State, C extends Configuration, W extends TornadoWeights> {
- protected static final int LOCAL_WORK_GROUP_SIZE_ALLOC = 32;
- protected static final int THREAD_SCALE_FOR_LOGITS = 8;
-
- protected final S state;
- protected final C config;
- protected final W weights;
- protected final KernelContext context;
-
- /**
- * Constructs a TornadoVMLayerPlanner for the given model.
- *
- * @param state
- * The state object containing model tensors and buffers
- * @param model
- * The model instance containing configuration and weights
- */
- public TornadoVMLayerPlanner(S state, Model model) {
- this.state = state;
- this.config = (C) model.configuration();
- this.weights = (W) model.weights();
- this.context = new KernelContext();
- }
-
- public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLayered() {
- List<ImmutableTaskGraph> taskGraphs = new ArrayList<>();
-
- state.temp.init(0.0f);
- state.tempFFN.init(0.0f);
- state.tempLogits.init(0.0f);
-
- // @formatter:off
- TaskGraph activationUpdate = new TaskGraph("activationUpdate")
- .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapX)
- .task("updateX", TransformerComputeKernels::emptyTaskToForceCopyIn, state.wrapX)
- .persistOnDevice(state.wrapX);
- taskGraphs.add(activationUpdate.snapshot());
-
- TaskGraph unifiedLayer = null;
-         for (int layerIndex = 0; layerIndex < config.numberOfLayers(); layerIndex++) {
- unifiedLayer = new TaskGraph("layer_" + layerIndex);
- unifiedLayer.consumeFromDevice(state.wrapX);
- unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
- //Copy-in weights per layer for batched-layered layout
- weights.rms_att_weightLayered[layerIndex],
- weights.wqLayered[layerIndex],
- weights.wkLayered[layerIndex],
- weights.wvLayered[layerIndex],
- weights.woLayered[layerIndex],
- weights.rms_ffn_weightLayered[layerIndex],
- weights.w1Layered[layerIndex],
- weights.w2Layered[layerIndex],
- weights.w3Layered[layerIndex]
- );
- unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
- unifiedLayer.task("reductionsOneBlock" , TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp,
- state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
- .task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb,
- state.wrapX, weights.rms_att_weightLayered[layerIndex], state.temp)
- .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
- state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex], config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
- state.wrapXb, state.wrapK, weights.wkLayered[layerIndex], config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
- state.wrapXb, state.wrapV, weights.wvLayered[layerIndex], config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("rope", TransformerComputeKernelsLayered::ropeRotation,context,
- state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(),
- config.headSize())
- .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache,
- state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), layerIndex, config.contextLength())
- .task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, context,
- state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
- config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(),
- state.positionHolder, layerIndex, config.contextLength())
- .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
- state.wrapXb, state.wrapX, weights.woLayered[layerIndex], config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN,
- state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
- .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb,
- state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN)
- .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context,
- state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex], weights.w3Layered[layerIndex], config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
- state.wrapHb, state.wrapX, weights.w2Layered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
- .persistOnDevice(
- state.wrapX
- );
- taskGraphs.add(unifiedLayer.snapshot());
- }
-
- TaskGraph lastUnifiedLayer = unifiedLayer;
- TaskGraph logits = new TaskGraph("logits")
- .consumeFromDevice(lastUnifiedLayer.getTaskGraphName(),
- state.wrapX
- )
- .transferToDevice(DataTransferMode.EVERY_EXECUTION,
- state.tempLogits
- )
- .transferToDevice(DataTransferMode.FIRST_EXECUTION,
- context,
- state.wrapLogits,
- weights.wclsHalfFloat,
- weights.rms_final_weight_as_floatArray
- )
- .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits,
- state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
- .task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX,
- weights.rms_final_weight_as_floatArray, state.tempLogits);
- logits = configureQuantizedMatrixVectorFinalWeight(logits);
- logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
- taskGraphs.add(logits.snapshot());
- // @formatter:on
-
- return new Tuple2<>(taskGraphs, setupGridSchedulersLayered());
- }
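-
-     // A minimal usage sketch for the plan produced above, assuming Tuple2 exposes
-     // getFirst()/getSecond() accessors (hypothetical names) and the standard TornadoVM
-     // execution-plan API:
-     //
-     //     Tuple2<List<ImmutableTaskGraph>, GridScheduler> plan =
-     //             new TornadoVMLayerPlanner<>(state, model).setupTornadoForwardPlanLayered();
-     //     TornadoExecutionPlan executionPlan =
-     //             new TornadoExecutionPlan(plan.getFirst().toArray(new ImmutableTaskGraph[0]));
-     //     executionPlan.withGridScheduler(plan.getSecond()).execute();   // one full forward pass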
-
- // @formatter:off
- /**
- * Configures the final projection layer in the task graph based on weight quantization type.
- *
- * This method adds a "projection" task to compute the final logits by performing a
- * matrix-vector multiplication between the model's output embeddings and the classifier
- * weights (wcls). The computation kernel used depends on the quantization format.
- *
-     * Supported weight types:
-     * - F16: 16-bit half-precision floating point
-     * - Q8_0: 8-bit quantization with uniform scaling per 32-element block
-     * - Q4_0: 4-bit quantization with uniform scaling per 32-element block
-     *
-     * The task multiplies:
-     * - weights.wclsHalfFloat: Classifier weights (vocab_size x dim)
-     * - state.wrapX: Current layer output (dim)
-     * - Result: state.wrapLogits: Raw logits (vocab_size)
-     *
-     * @param logits The existing task graph to extend with the projection operation
-     * @return The modified task graph with the projection task added
-     * @throws UnsupportedOperationException If weights.weightType is not F16, Q8_0, or Q4_0
- */
- // @formatter:on
- protected TaskGraph configureQuantizedMatrixVectorFinalWeight(TaskGraph logits) {
- switch (weights.getWeightType()) {
- case F16:
- case Q8_0:
- case Q4_0:
- logits.task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, //
- context, state.wrapX, state.wrapLogits, weights.wclsHalfFloat, //
- config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); //
- break;
- default:
-                 throw new UnsupportedOperationException("Unsupported weight quantization type: " + weights.getWeightType() + ". Only F16, Q8_0, and Q4_0 are supported.");
- }
- return logits;
- }
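-
-     // For reference, the scalar math the "projection" task performs, written as a plain
-     // CPU loop (a sketch; buffers simplified to float arrays, row-major weights assumed):
-     //
-     //     for (int v = 0; v < vocabularySize; v++) {    // one logit per vocabulary row
-     //         float sum = 0f;
-     //         for (int d = 0; d < dim; d++) {
-     //             sum += wcls[v * dim + d] * x[d];      // dot(wcls[v, :], x)
-     //         }
-     //         logits[v] = sum;
-     //     }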
-
- /**
- * Configures data transfer operations for a specific layer in the neural network task graph.
- *
-     * This method manages GPU memory transfers with an optimized data movement strategy
-     * that minimizes host-device traffic by:
-     * 1. Using one-time transfers for static data
-     * 2. Reusing intermediate results already on GPU from previous layers
-     * 3. Only transferring dynamic data that changes per execution
- *
- * @param unifiedLayer
- * The task graph representing this layer's operations
- * @param layerIndex
- * Index of the current layer (0-based)
- * @return The configured task graph with appropriate data transfer operations
- */
- protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
- // First layer: Transfer initial data to device (one-time transfer)
- if (layerIndex == 0) {
- // Transfer all attention-related data: query, key, value matrices and their caches
- unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); //
- unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, //
- context, state.wrapXb, state.wrapXb2, //
- state.wrapQ, state.wrapK, state.wrapV, //
- state.wrapKeyCache, state.wrapValueCache, //
- state.wrapAtt, state.wrapHb); //
- } else {
- // Subsequent layers: Consume data already on device from previous layer
- unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, //
- state.wrapQ, state.wrapK, state.wrapV, //
- state.wrapKeyCache, state.wrapValueCache, //
- state.wrapAtt, state.wrapHb, //
- state.positionHolder //
- );
- }
- return unifiedLayer;
- }
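-
-     // The two transfer modes used above differ only in when the copy happens. A sketch with
-     // hypothetical buffer names:
-     //
-     //     .transferToDevice(DataTransferMode.FIRST_EXECUTION, weightsBuf)   // copied once, cached on device
-     //     .transferToDevice(DataTransferMode.EVERY_EXECUTION, positionBuf)  // re-copied before each execution
-     //     .consumeFromDevice(activationsBuf)   // no copy at all: reuse what a previous graph persisted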
-
- // @formatter:off
- /**
- * Sets up the grid scheduler configuration for a layered neural network forward pass.
- *
- * This method creates and configures worker grids for different types of GPU operations
- * in the transformer/ML model pipeline. Each worker grid defines how work should be
- * distributed across GPU threads (OpenCL work-items or CUDA threads).
- *
- * The method creates several worker profiles:
- * - Single thread operations (activation updates)
- * - RoPE (Rotary Position Embedding) operations
- * - Matrix multiplications with different dimensions
- * - RMS normalization operations
- * - Parallel attention computations
- * - Cache copying operations
- * - Vocabulary projections
- *
- * Each worker grid maps to equivalent OpenCL NDRange or CUDA grid/block configurations:
- * - setGlobalWork() ≈ OpenCL global_work_size ≈ CUDA grid dimensions × block dimensions
- * - setLocalWork() ≈ OpenCL local_work_size ≈ CUDA block dimensions
- *
- * @return GridScheduler configured with all necessary worker grids for the model layers
- */
- // @formatter:on
- private GridScheduler setupGridSchedulersLayered() {
- GridScheduler tornadoForwardScheduler = new GridScheduler();
-
- // Single worker for tasks running with a single thread
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[1,1,1], localWorkSize=[1,1,1])
-         // CUDA equivalent: kernel<<<1, 1>>>
- WorkerGrid singleWorker = new WorkerGrid1D(1);
- singleWorker.setGlobalWork(1, 1, 1);
- singleWorker.setLocalWork(1, 1, 1);
-
- // config.dim / 2 Worker for RoPE
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim/2,1,1], localWorkSize=[128,1,1])
-         // CUDA equivalent: kernel<<<(config.dim/2)/128, 128>>>
- WorkerGrid ropeWorker = new WorkerGrid1D(config.dim() / 2);
- ropeWorker.setGlobalWork(config.dim() / 2, 1, 1);
- ropeWorker.setLocalWork(128, 1, 1);
-
- // config.dim Worker for Row major access
- // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1])
- // CUDA equivalent: kernel<<